molecule-ai-workspace-runtime/molecule_runtime/heartbeat.py
rabbitblood 050c2412b3 fix(heartbeat): refresh on-disk auth token on 401 + retry once (#1877)
## Problem

Auto-restart rotates the workspace's auth token in two non-atomic steps:
  1. Platform issues new token via wsauth.IssueToken
  2. Provisioner writes the new token to /configs/.auth_token AFTER
     ContainerStart returns

Between steps 1 and 2, the new container has booted and the runtime has
already loaded the OLD cached value of .auth_token (or no value if the
file was empty during boot). The runtime's first /registry/heartbeat
call sends the stale token, gets 401, but the loop never re-reads the
on-disk token — so subsequent heartbeats also send the stale value.

Each 401 means the platform never sees the workspace as alive →
status stays 'provisioning' → scheduler won't dispatch → workspace
looks dead from every angle even though the container is actually
running.

The existing code comment in workspace_provision.go acknowledges this:
"the workspace will get 401 on its first heartbeat and can recover on
the next restart." That recovery only worked because workspaces used
to crash for unrelated reasons and get restarted. After PR #1861
(provisioner empty-volume auto-recover) removed those crashes,
workspaces get stuck in the 401 loop with no exit.

## Fix

Two-part runtime-side fix in molecule-ai-workspace-runtime:

1. **platform_auth.refresh_from_disk()** — new helper that clears the
   in-memory cache and re-reads /configs/.auth_token. Returns the
   fresh value (or None if missing). Updates the cache as a side effect.

2. **HeartbeatLoop._loop()** — on 401 from /registry/heartbeat, calls
   refresh_from_disk() and retries the request ONCE with the new token.
   Same pattern in _check_delegations(). Bounded retry budget — if the
   on-disk token is also stale (bug elsewhere), no infinite loop.

## Tests

6/6 new tests in tests/test_token_refresh_1877.py:

  - refresh_picks_up_rotated_token              — happy path
  - refresh_returns_none_when_file_missing      — defensive
  - refresh_clears_stale_cache_when_file_disappears
  - refresh_is_idempotent
  - 401_retry_pattern_uses_refreshed_token      — the production fix path
  - 401_retry_no_loop_when_disk_token_also_stale — bounded retry budget

All pass locally on Python 3.13 + pytest 9.

## Why this fix and not the alternatives

- **Alternative B (platform writes token before ContainerStart):**
  Right architecturally but invasive — needs provisioner refactor to
  prep volumes before docker run.
- **Alternative C (skip rotation on auto-restart):** Breaks the
  multi-instance-safety invariant the existing code calls out
  (revoke prevents stale tokens from sister deployments).
- **This fix (A):** 3-line core change + helper. Self-healing for any
  timing edge case, not just the post-restart one. Costs nothing in
  the happy path (only triggers on 401).

## Version

Bumped to 0.1.9. Once published to PyPI + workspace template image
rebuilt, deployed workspaces auto-recover from token-rotation races
without operator intervention.

Closes #1877.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-23 13:26:36 -07:00


"""Heartbeat loop — alive signal + delegation status checker.
Every 30 seconds:
1. Send heartbeat to platform (alive signal with current_task, error_rate)
2. Check pending delegations — any results back?
3. Store completed delegation results for the agent to pick up
Resilient: recreates HTTP client on failure, auto-restarts on crash.
"""
import asyncio
import json
import logging
import os
import time
from pathlib import Path
import httpx
from molecule_runtime.platform_auth import auth_headers, refresh_from_disk
logger = logging.getLogger(__name__)
HEARTBEAT_INTERVAL = 30 # seconds
MAX_CONSECUTIVE_FAILURES = 10
MAX_SEEN_DELEGATION_IDS = 200
SELF_MESSAGE_COOLDOWN = 60 # seconds — minimum between self-messages to prevent loops
# Shared path — also used by cli_executor._read_delegation_results()
DELEGATION_RESULTS_FILE = os.environ.get("DELEGATION_RESULTS_FILE", "/tmp/delegation_results.jsonl")

class HeartbeatLoop:
    def __init__(self, platform_url: str, workspace_id: str):
        self.platform_url = platform_url
        self.workspace_id = workspace_id
        self.start_time = time.time()
        self.error_count = 0
        self.request_count = 0
        self.active_tasks = 0
        self.current_task = ""
        self.sample_error = ""
        self._task = None
        self._consecutive_failures = 0
        self._seen_delegation_ids: set[str] = set()
        self._last_self_message_time = 0.0
        self._parent_name: str | None = None  # Cached after first lookup

    @property
    def error_rate(self) -> float:
        if self.request_count == 0:
            return 0.0
        return self.error_count / self.request_count

    def record_error(self, error: str):
        self.error_count += 1
        self.request_count += 1
        self.sample_error = error

    def record_success(self):
        self.request_count += 1

    def start(self):
        self._task = asyncio.create_task(self._loop())
        self._task.add_done_callback(self._on_done)

    def _on_done(self, task):
        if not task.cancelled() and task.exception():
            logger.error("Heartbeat loop died: %s — restarting", task.exception())
            self._task = asyncio.create_task(self._loop())
            self._task.add_done_callback(self._on_done)

    async def stop(self):
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
    async def _loop(self):
        while True:
            client = None
            try:
                client = httpx.AsyncClient(timeout=10.0)
                while True:
                    # 1. Send heartbeat (Phase 30.1: include auth header if token known)
                    try:
                        hb_payload = {
                            "workspace_id": self.workspace_id,
                            "error_rate": self.error_rate,
                            "sample_error": self.sample_error,
                            "active_tasks": self.active_tasks,
                            "current_task": self.current_task,
                            "uptime_seconds": int(time.time() - self.start_time),
                        }
                        resp = await client.post(
                            f"{self.platform_url}/registry/heartbeat",
                            json=hb_payload,
                            headers=auth_headers(),
                        )
                        # #1877: auto-restart rotates the workspace token AFTER
                        # container start, so the first heartbeat after a restart
                        # can race the token write and send the stale cached
                        # value → 401. Re-read /configs/.auth_token and retry ONCE
                        # to break the 401 loop without needing another restart.
                        if resp.status_code == 401:
                            if refresh_from_disk() is not None:
                                logger.info(
                                    "Heartbeat: got 401, refreshed token from disk, retrying"
                                )
                                resp = await client.post(
                                    f"{self.platform_url}/registry/heartbeat",
                                    json=hb_payload,
                                    headers=auth_headers(),
                                )
                        self.error_count = 0
                        self.request_count = 0
                        self._consecutive_failures = 0
                    except Exception as e:
                        self._consecutive_failures += 1
                        if self._consecutive_failures <= 3 or self._consecutive_failures % MAX_CONSECUTIVE_FAILURES == 0:
                            logger.warning("Heartbeat failed (%d consecutive): %s", self._consecutive_failures, e)
                        if self._consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
                            logger.info("Heartbeat: recreating HTTP client after %d failures", self._consecutive_failures)
                            try:
                                await client.aclose()
                            except Exception:
                                pass
                            break
                    # 2. Check delegation status
                    try:
                        await self._check_delegations(client)
                    except Exception as e:
                        logger.debug("Delegation check failed: %s", e)
                    await asyncio.sleep(HEARTBEAT_INTERVAL)
            except asyncio.CancelledError:
                raise
            except Exception as e:
                logger.error("Heartbeat loop error: %s — retrying in 30s", e)
                await asyncio.sleep(HEARTBEAT_INTERVAL)
            finally:
                if client:
                    try:
                        await client.aclose()
                    except Exception:
                        pass
    async def _check_delegations(self, client: httpx.AsyncClient):
        """Check for completed delegations and store results for the agent."""
        try:
            url = f"{self.platform_url}/workspaces/{self.workspace_id}/delegations"
            resp = await client.get(url, headers=auth_headers())
            # #1877: refresh token on 401 and retry ONCE — same post-restart
            # token-rotation race as the heartbeat path above.
            if resp.status_code == 401 and refresh_from_disk() is not None:
                resp = await client.get(url, headers=auth_headers())
            if resp.status_code != 200:
                return
            delegations = resp.json()
            if not isinstance(delegations, list):
                return
            new_results = []
            for d in delegations:
                did = d.get("delegation_id", "")
                status = d.get("status", "")
                if not did or did in self._seen_delegation_ids:
                    continue
                if status in ("completed", "failed"):
                    # Fix B (Cycle 5): validate source_id before accepting delegation
                    # results. Only process delegations that THIS workspace created
                    # (source_id == self.workspace_id). Attacker-crafted delegation
                    # records with a foreign source_id cannot inject instructions.
                    source_id = d.get("source_id", "")
                    if source_id != self.workspace_id:
                        logger.warning(
                            "Heartbeat: skipping delegation %s — source_id %r does not "
                            "match this workspace %r; possible injection attempt",
                            did, source_id, self.workspace_id,
                        )
                        self._seen_delegation_ids.add(did)  # mark seen so we don't warn again
                        continue
                    self._seen_delegation_ids.add(did)
                    new_results.append({
                        "delegation_id": did,
                        "target_id": d.get("target_id", ""),
                        "source_id": source_id,
                        "status": status,
                        "summary": d.get("summary", ""),
                        "response_preview": d.get("response_preview", ""),
                        "error": d.get("error", ""),
                        "timestamp": time.time(),
                    })
            # Evict old seen IDs if over limit
            if len(self._seen_delegation_ids) > MAX_SEEN_DELEGATION_IDS:
                # Keep most recent half
                self._seen_delegation_ids = set(list(self._seen_delegation_ids)[MAX_SEEN_DELEGATION_IDS // 2:])
            if new_results:
                # Append to results file for context injection on next message
                with open(DELEGATION_RESULTS_FILE, "a") as f:
                    for r in new_results:
                        f.write(json.dumps(r) + "\n")
                logger.info("Heartbeat: %d new delegation results — triggering self-message", len(new_results))
                # Build a summary message for the agent.
                # Fix B (Cycle 5): do NOT embed raw response_preview text in
                # user-role A2A messages — that is the prompt-injection vector.
                # Instead reference only the delegation ID and status; the agent
                # reads full content from DELEGATION_RESULTS_FILE which was
                # written above from trusted platform data.
                summary_lines = []
                for r in new_results:
                    line = f"- [{r['status']}] Delegation {r['delegation_id'][:8]}: {r['summary'][:80]}"
                    if r.get("error"):
                        line += f"\n  Error: {r['error'][:100]}"
                    summary_lines.append(line)
                # Look up parent workspace (cached after first call)
                if self._parent_name is None:
                    try:
                        parent_resp = await client.get(
                            f"{self.platform_url}/workspaces/{self.workspace_id}",
                            headers=auth_headers(),
                        )
                        if parent_resp.status_code == 200:
                            parent_id = parent_resp.json().get("parent_id", "")
                            if parent_id:
                                parent_info = await client.get(
                                    f"{self.platform_url}/workspaces/{parent_id}",
                                    headers=auth_headers(),
                                )
                                if parent_info.status_code == 200:
                                    self._parent_name = parent_info.json().get("name", "")
                        if self._parent_name is None:
                            self._parent_name = ""  # No parent — cache empty
                    except Exception:
                        pass  # Will retry next cycle
                parent_name = self._parent_name or ""
                report_instruction = ""
                if parent_name:
                    report_instruction = (
                        f"\n\nIMPORTANT: Report these results back to your parent '{parent_name}' "
                        f"by delegating a summary to them. Use delegate_task or delegate_task_async "
                        f"with a concise status report. Also use send_message_to_user to notify the user."
                    )
                else:
                    report_instruction = (
                        "\n\nReport results using send_message_to_user to notify the user."
                    )
                trigger_msg = (
                    "Delegation results are ready. Review them and take appropriate action:\n"
                    + "\n".join(summary_lines)
                    + report_instruction
                )
                # Send A2A self-message to wake the agent.
                # Minimum 60s between self-messages to avoid spam, but always send
                # when there are genuinely NEW results to process.
                now = time.time()
                if now - self._last_self_message_time < SELF_MESSAGE_COOLDOWN:
                    logger.debug("Heartbeat: self-message cooldown (60s), will retry next cycle")
                else:
                    self._last_self_message_time = now
                    try:
                        await client.post(
                            f"{self.platform_url}/workspaces/{self.workspace_id}/a2a",
                            json={
                                "method": "message/send",
                                "params": {
                                    "message": {
                                        "role": "user",
                                        "parts": [{"type": "text", "text": trigger_msg}],
                                    },
                                },
                            },
                            headers=auth_headers(),
                            timeout=120.0,
                        )
                        logger.info("Heartbeat: self-message sent to process delegation results")
                    except Exception as e:
                        logger.warning("Heartbeat: failed to send self-message: %s", e)
                    # Also push notification to user via canvas
                    for r in new_results:
                        try:
                            msg = f"Delegation {r['status']}: {r['summary'][:100]}"
                            if r.get("response_preview"):
                                msg += f"\nResult: {r['response_preview'][:200]}"
                            await client.post(
                                f"{self.platform_url}/workspaces/{self.workspace_id}/notify",
                                json={"message": msg, "type": "delegation_result"},
                                headers=auth_headers(),
                            )
                        except Exception:
                            pass
        except Exception as e:
            logger.debug("Delegation check error: %s", e)