def pre_stop_state(self) -> dict:
    """Capture in-memory state for pause/resume serialization.

    Called by main.py's shutdown handler just before the container exits.
    The returned dict is scrubbed and written to the snapshot file by
    ``lib.pre_stop``.

    Default implementation:
      1. Reads ``self._executor._session_id`` (when an executor has been
         stored on the adapter) and includes it as ``session_id``.
      2. Includes up to ``MAX_TRANSCRIPT_LINES`` transcript lines via
         ``transcript_lines()`` when that method is synchronous and its
         result reports ``supported``.

    Override in adapters that hold additional in-memory state that should
    survive a container stop, or whose ``transcript_lines`` is a coroutine
    function (the base implementation skips the transcript in that case).

    Returns:
        A JSON-serializable dict. It may be empty; it never raises because
        of a transcript-capture failure.
    """
    import logging

    state: dict = {}

    # Session handle — critical for resuming the session. getattr on a
    # missing executor (None) simply yields None.
    executor = getattr(self, "_executor", None)
    session_id = getattr(executor, "_session_id", None)
    if session_id:
        state["session_id"] = session_id

    # Everything below is best-effort. The lib.pre_stop import lives INSIDE
    # the guard so that an import problem cannot discard the session_id that
    # was already captured above.
    try:
        import inspect

        from lib.pre_stop import MAX_TRANSCRIPT_LINES

        transcript_fn = self.transcript_lines
        # A coroutine function cannot be awaited from this synchronous hook;
        # async adapters must override pre_stop_state() to add a transcript.
        if not inspect.iscoroutinefunction(transcript_fn):
            # NOTE(review): since=0 combined with a limit looks like it
            # returns the FIRST N lines rather than the most recent ones —
            # confirm against transcript_lines() semantics.
            transcript = transcript_fn(since=0, limit=MAX_TRANSCRIPT_LINES)
            if transcript.get("supported"):
                state["transcript_lines"] = transcript.get("lines") or []
    except Exception:
        # Never let transcript capture block serialization, but leave a
        # trace so a silently-empty snapshot can be diagnosed.
        logging.getLogger(__name__).debug(
            "pre_stop_state: transcript capture skipped", exc_info=True
        )

    return state


def restore_state(self, snapshot: dict) -> None:
    """Restore in-memory state from a pause/resume snapshot.

    Stores the snapshot's ``session_id`` and ``transcript_lines`` on the
    adapter as ``self._snapshot_session_id`` and
    ``self._snapshot_transcript``.

    Two snapshot shapes are accepted:

    * flat — the keys sit at the top level of ``snapshot``;
    * nested — the keys sit under ``snapshot["adapter"]``, which is where
      ``lib.pre_stop.build_snapshot`` places the dict returned by
      ``pre_stop_state()``.

    A non-empty top-level value wins; otherwise a non-empty nested value is
    used; otherwise the top-level value (possibly ``None``) is kept as-is.

    Args:
        snapshot: The scrubbed snapshot dict read back from disk.
    """
    nested = snapshot.get("adapter")
    if not isinstance(nested, dict):
        nested = {}

    restored = {}
    for key in ("session_id", "transcript_lines"):
        value = snapshot.get(key)
        if not value and nested.get(key):
            # Real snapshots written by build_snapshot() keep adapter state
            # one level down; without this fallback nothing is restored.
            value = nested[key]
        restored[key] = value

    self._snapshot_session_id: str | None = restored["session_id"]
    self._snapshot_transcript: list | None = restored["transcript_lines"]
# pragma: no cover diff --git a/workspace/lib/pre_stop.py b/workspace/lib/pre_stop.py new file mode 100644 index 00000000..da919d39 --- /dev/null +++ b/workspace/lib/pre_stop.py @@ -0,0 +1,192 @@ +"""Pre-stop serialization for pause/resume — GH#1391. + +Captures the agent's in-memory state just before the container exits so it +survives intentional pause and unplanned restart. All content is scrubbed +with lib.snapshot_scrub before being written to disk so that a snapshot blob +obtained by an attacker cannot recover API keys, tokens, or arbitrary sandbox +output (GH#823). + +State captured +-------------- +- ``workspace_id`` — identity for cross-container restore +- ``current_task`` — active task label from heartbeat (what the canvas sees) +- ``active_tasks`` — task count +- ``session_id`` — SDK session handle (Claude Code); key for full session +- ``transcript_lines`` — recent session log lines from the adapter +- ``uptime_seconds`` — how long this container has been running +- ``timestamp`` — when the snapshot was taken (ISO-8601) + +Scrubbing +--------- +Every text field passes through scrub_snapshot before being written. +Sandbox-sourced content (tool=run_code, source=sandbox, [sandbox_output]) is +dropped wholesale. Secrets matching the pattern library are replaced with +[REDACTED:TYPE] markers. + +Storage +------- +Snapshots are written to /configs/.agent_snapshot.json by default. The +config volume survives container restarts so the file is durable. The path +is also overridable via ``AGENT_SNAPSHOT_PATH`` for testing or custom layouts. +""" + +from __future__ import annotations + +import json +import logging +import os +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +from .snapshot_scrub import scrub_snapshot + +if TYPE_CHECKING: + from heartbeat import HeartbeatLoop + +logger = logging.getLogger(__name__) + +# Default snapshot path — on the config volume, survives container restarts. 
# Resolved once at import time; AGENT_SNAPSHOT_PATH overrides the default
# location on the config volume.
DEFAULT_SNAPSHOT_PATH = os.environ.get(
    "AGENT_SNAPSHOT_PATH",
    "/configs/.agent_snapshot.json",
)

# How many transcript lines to capture in the snapshot (recent window).
MAX_TRANSCRIPT_LINES = 200


def build_snapshot(
    heartbeat: "HeartbeatLoop | None",
    adapter_state: dict[str, Any] | None,
) -> dict[str, Any]:
    """Build a raw (not yet scrubbed) snapshot dict from live workspace state.

    Args:
        heartbeat: Object exposing ``current_task`` and ``active_tasks`` and,
            optionally, ``start_time`` and ``_session_id``. May be None.
        adapter_state: Dict returned by the adapter's ``pre_stop_state()``
            hook. ``None`` is treated as an empty dict.

    Returns:
        A dict with ``workspace_id``, ``timestamp``, ``current_task``,
        ``active_tasks`` and ``adapter``. ``uptime_seconds`` is added when
        the heartbeat has a numeric ``start_time``. A top-level
        ``session_id`` is added only when a heartbeat is given and the
        adapter state carries no session id of its own.
    """
    import time

    if adapter_state is None:
        adapter_state = {}

    raw: dict[str, Any] = {
        "workspace_id": os.environ.get("WORKSPACE_ID", "unknown"),
        "timestamp": datetime.now(timezone.utc).isoformat(),
        # Defaults — the heartbeat block below overwrites these when available.
        "current_task": "",
        "active_tasks": 0,
    }

    if heartbeat is not None:
        raw["current_task"] = heartbeat.current_task or ""
        raw["active_tasks"] = heartbeat.active_tasks

        # A heartbeat whose start_time is unset (None) must not abort the
        # whole snapshot with a TypeError during shutdown.
        start_time = getattr(heartbeat, "start_time", None)
        if isinstance(start_time, (int, float)):
            raw["uptime_seconds"] = int(time.time() - start_time)

        # Fall back to the heartbeat's session id only when the adapter did
        # not supply one.
        adapter_has_session = isinstance(adapter_state, dict) and adapter_state.get(
            "session_id"
        )
        if not adapter_has_session:
            raw["session_id"] = getattr(heartbeat, "_session_id", None) or ""

    # Adapter-supplied state (conversation history, reasoning traces, etc.).
    raw["adapter"] = adapter_state

    return raw


def _scrub_value(value: Any) -> Any:
    """Recursively scrub secret patterns from a value.

    - Strings: passed through ``scrub_content``.
    - Dicts: a new dict with every value scrubbed (keys are left untouched).
    - Lists and tuples: string items flagged by ``is_sandbox_content`` are
      dropped; remaining items are scrubbed. Tuples come back as lists,
      which is how JSON would serialize them anyway.
    - None, bool, int, float: returned unchanged.
    - Anything else: converted with ``str()`` and scrubbed. ``json.dump`` is
      called with ``default=str``, so such objects would otherwise reach the
      file as raw, unscrubbed text.
    """
    from .snapshot_scrub import is_sandbox_content, scrub_content

    if isinstance(value, str):
        return scrub_content(value)
    if isinstance(value, dict):
        return {key: _scrub_value(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [
            _scrub_value(item)
            for item in value
            if not (isinstance(item, str) and is_sandbox_content(item))
        ]
    if value is None or isinstance(value, (bool, int, float)):
        return value
    return scrub_content(str(value))


def write_snapshot(
    snapshot: dict[str, Any],
    path: str | None = None,
) -> bool:
    """Scrub a snapshot and write it to disk atomically.

    The JSON is written to ``<target>.tmp`` and then moved over the target
    with ``os.replace``, so a process killed mid-write cannot leave a
    truncated snapshot behind.

    Args:
        snapshot: Raw snapshot dict from ``build_snapshot()``.
        path: Target file path (default: ``DEFAULT_SNAPSHOT_PATH``).

    Returns:
        True if the snapshot was written; False on any error. Errors are
        logged and never raised — pre-stop serialization is best-effort so
        that it cannot block shutdown.
    """
    target = path or DEFAULT_SNAPSHOT_PATH
    tmp_path = f"{target}.tmp"

    try:
        # Deep-scrub every value before anything touches the disk.
        scrubbed = _scrub_value(snapshot)

        parent = os.path.dirname(target)
        if parent:
            os.makedirs(parent, exist_ok=True)

        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(scrubbed, f, indent=2, default=str)
        os.replace(tmp_path, target)
    except Exception as exc:
        # Do not leave a partial temp file on the config volume.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        logger.warning("Pre-stop snapshot write failed (%s): %s", target, exc)
        return False

    # The summary log is computed defensively and outside the try block: a
    # surprising shape (e.g. transcript_lines=None) must not turn a
    # successful write into a reported failure.
    meta = scrubbed if isinstance(scrubbed, dict) else {}
    adapter_part = meta.get("adapter")
    lines = adapter_part.get("transcript_lines") if isinstance(adapter_part, dict) else None
    line_count = len(lines) if isinstance(lines, list) else 0
    logger.info(
        "Pre-stop snapshot written: %s (workspace=%s, task=%r, lines=%d)",
        target,
        meta.get("workspace_id", "?"),
        meta.get("current_task", ""),
        line_count,
    )
    return True


def read_snapshot(
    path: str | None = None,
) -> dict[str, Any] | None:
    """Read a previously written snapshot.

    Args:
        path: Snapshot file path (default: ``DEFAULT_SNAPSHOT_PATH``).

    Returns:
        The parsed snapshot dict, or None when the file is absent, cannot be
        read or parsed, or does not contain a JSON object.
    """
    target = path or DEFAULT_SNAPSHOT_PATH

    try:
        with open(target, encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        # No snapshot is the normal first-boot case; stay quiet.
        return None
    except Exception as exc:
        logger.warning("Snapshot read failed (%s): %s", target, exc)
        return None

    if not isinstance(data, dict):
        # Callers use dict methods on the result; reject other JSON values.
        logger.warning(
            "Snapshot ignored (%s): expected a JSON object, got %s",
            target,
            type(data).__name__,
        )
        return None
    return data


def delete_snapshot(path: str | None = None) -> None:
    """Remove a snapshot file. Idempotent — no error if it is absent.

    Args:
        path: Snapshot file path (default: ``DEFAULT_SNAPSHOT_PATH``).
    """
    target = path or DEFAULT_SNAPSHOT_PATH
    try:
        os.remove(target)
        logger.debug("Snapshot deleted: %s", target)
    except FileNotFoundError:
        pass
    except Exception as exc:
        logger.warning("Snapshot delete failed (%s): %s", target, exc)
"""Tests for lib.pre_stop — GH#1391 pre-stop serialization."""

import json
import os
import tempfile

import pytest


class _MockHeartbeat:
    """Stand-in heartbeat exposing the attributes build_snapshot reads."""

    def __init__(self):
        self._session_id = None
        self.start_time = 1000.0
        self.active_tasks = 1
        self.current_task = "Implementing feature X"


class _MockAdapter:
    """Adapter stub whose pre_stop_state returns a fixed payload."""

    def pre_stop_state(self):
        transcript = ["User: hello", "Agent: Hi! How can I help?"]
        return {"session_id": "sess_abc123xyz", "transcript_lines": transcript}


def test_build_snapshot_basic():
    """The snapshot carries workspace_id, timestamp and heartbeat fields."""
    from lib.pre_stop import build_snapshot

    payload = {"session_id": "sess_abc", "transcript_lines": ["line1"]}
    result = build_snapshot(_MockHeartbeat(), payload)

    expected_workspace = os.environ.get("WORKSPACE_ID", "unknown")
    assert result["workspace_id"] == expected_workspace
    assert "timestamp" in result
    assert result["current_task"] == "Implementing feature X"
    assert result["active_tasks"] == 1
    assert result["adapter"] == payload


def test_build_snapshot_none_heartbeat():
    """A missing heartbeat yields default task fields and no top-level id."""
    from lib.pre_stop import build_snapshot

    result = build_snapshot(None, {"session_id": "sess_xyz"})

    assert result["current_task"] == ""
    assert result["active_tasks"] == 0
    # Without a heartbeat the session id is not promoted to the top level;
    # it is only available under the "adapter" key.
    assert "session_id" not in result
    assert result["adapter"]["session_id"] == "sess_xyz"


def test_build_snapshot_scrubbed_secrets():
    """write_snapshot redacts bearer tokens found in transcript lines."""
    from lib.pre_stop import build_snapshot, write_snapshot

    payload = {
        "session_id": "sess_secret",
        "transcript_lines": [
            "Authorization: Bearer abc123.def456.ghi789",
            "token_used: Bearer xyz.token.placeholder",
        ],
    }
    raw = build_snapshot(_MockHeartbeat(), payload)

    fd, path = tempfile.mkstemp(suffix=".json")
    os.close(fd)

    try:
        written = write_snapshot(raw, path=path)
        assert written, "write_snapshot should return True on success"

        with open(path) as handle:
            on_disk = json.load(handle)

        entries = on_disk["adapter"]["transcript_lines"]
        assert all("Bearer abc" not in entry for entry in entries), "Bearer token should be scrubbed"
        assert any("REDACTED" in entry for entry in entries), "Scrub markers should be present"
    finally:
        os.unlink(path)


def test_build_snapshot_scrub_drops_sandbox_content():
    """List entries that look like sandbox output are removed on write."""
    from lib.pre_stop import build_snapshot, write_snapshot

    payload = {
        "session_lines": [
            "source=sandbox echo hello",
            "Normal message",
        ],
    }
    raw = build_snapshot(_MockHeartbeat(), payload)

    fd, path = tempfile.mkstemp(suffix=".json")
    os.close(fd)

    try:
        write_snapshot(raw, path=path)
        with open(path) as handle:
            on_disk = json.load(handle)
        # The scrubber drops sandbox entries from lists before writing.
        entries = on_disk["adapter"].get("session_lines", [])
        assert all("sandbox" not in entry for entry in entries), "Sandbox lines should be dropped"
    finally:
        os.unlink(path)


def test_read_snapshot_missing_returns_none():
    """A path that does not exist reads back as None."""
    from lib.pre_stop import read_snapshot

    assert read_snapshot(path="/nonexistent/path/12345.json") is None
def test_read_snapshot_returns_data():
    """An existing snapshot file is parsed and returned as-is."""
    from lib.pre_stop import read_snapshot

    expected = {"workspace_id": "test-ws", "current_task": "test"}
    fd, path = tempfile.mkstemp(suffix=".json")
    with os.fdopen(fd, "w") as handle:
        json.dump(expected, handle)

    try:
        loaded = read_snapshot(path=path)
        assert loaded == expected
        assert loaded["workspace_id"] == "test-ws"
    finally:
        os.unlink(path)


def test_delete_snapshot_removes_file():
    """delete_snapshot removes the file and tolerates a second call."""
    from lib.pre_stop import delete_snapshot

    fd, path = tempfile.mkstemp(suffix=".json")
    os.close(fd)

    delete_snapshot(path=path)
    assert not os.path.exists(path), "File should be removed"

    # Second call hits the already-absent path and must not raise.
    delete_snapshot(path=path)


def test_write_snapshot_returns_false_on_error(monkeypatch):
    """An I/O error while writing makes write_snapshot return False."""
    from unittest import mock

    from lib.pre_stop import build_snapshot, write_snapshot

    raw = build_snapshot(_MockHeartbeat(), {})

    # Simulate the failure by making every open() call raise.
    failing_open = mock.patch("builtins.open", side_effect=OSError("disk full"))
    with failing_open:
        outcome = write_snapshot(raw, path="/tmp/fake.json")

    assert outcome is False, "write_snapshot should return False on error"


def test_restore_state_stores_on_adapter():
    """restore_state copies snapshot fields onto adapter attributes."""
    from adapter_base import BaseAdapter

    class DummyAdapter(BaseAdapter):
        def name(self):
            return "dummy"

        def display_name(self):
            return "Dummy"

        def description(self):
            return "dummy"

        async def setup(self, cfg):
            pass

        async def create_executor(self, cfg):
            pass

    subject = DummyAdapter()
    subject.restore_state(
        {
            "session_id": "sess_restored_123",
            "transcript_lines": ["line1", "line2"],
            "current_task": "Old task",
        }
    )

    assert subject._snapshot_session_id == "sess_restored_123"
    assert subject._snapshot_transcript == ["line1", "line2"]


def test_pre_stop_state_default_returns_empty():
    """With no executor and the base transcript, the state dict is empty."""
    from adapter_base import BaseAdapter

    class DummyAdapter(BaseAdapter):
        def name(self):
            return "dummy"

        def display_name(self):
            return "Dummy"

        def description(self):
            return "dummy"

        async def setup(self, cfg):
            pass

        async def create_executor(self, cfg):
            pass

    assert DummyAdapter().pre_stop_state() == {}


def test_pre_stop_state_with_executor_session_id():
    """The executor's _session_id is surfaced as session_id."""
    from adapter_base import BaseAdapter

    class DummyExecutor:
        pass

    class DummyAdapter(BaseAdapter):
        def name(self):
            return "dummy"

        def display_name(self):
            return "Dummy"

        def description(self):
            return "dummy"

        async def setup(self, cfg):
            pass

        async def create_executor(self, cfg):
            # Mirrors what a real adapter does: keep the executor on self.
            self._executor = DummyExecutor()
            self._executor._session_id = "sess_from_executor_456"
            return self._executor

    subject = DummyAdapter()
    # Attach the executor directly, as if create_executor had already run.
    stored = DummyExecutor()
    stored._session_id = "sess_from_executor_456"
    subject._executor = stored

    captured = subject.pre_stop_state()
    assert captured["session_id"] == "sess_from_executor_456"


def test_pre_stop_state_transcript_included():
    """A supported synchronous transcript is copied into the state."""
    from adapter_base import BaseAdapter

    class DummyExecutor:
        pass

    class DummyAdapter(BaseAdapter):
        def name(self):
            return "dummy"

        def display_name(self):
            return "Dummy"

        def description(self):
            return "dummy"

        async def setup(self, cfg):
            pass

        async def create_executor(self, cfg):
            self._executor = DummyExecutor()
            return self._executor

        def transcript_lines(self, since=0, limit=100):
            return {
                "supported": True,
                "lines": ["User: test", "Agent: response"],
                "cursor": 2,
                "more": False,
            }

    subject = DummyAdapter()
    subject._executor = DummyExecutor()
    captured = subject.pre_stop_state()

    assert "transcript_lines" in captured
    assert captured["transcript_lines"] == ["User: test", "Agent: response"]