feat(workspace): pre-stop serialization for pause/resume (closes #1391)

Add a pre-stop hook that captures agent state before container exit and
writes a scrubbed snapshot to /configs/.agent_snapshot.json. On restart,
the snapshot is loaded and the adapter's restore_state() is called before
the A2A server starts.

- New lib/pre_stop.py: build_snapshot / write_snapshot / read_snapshot /
  delete_snapshot + _scrub_value deep-scrubber (uses lib.snapshot_scrub
  to redact API keys, tokens, and sandbox output before persisting)
- BaseAdapter.pre_stop_state(): captures _executor._session_id and recent
  transcript_lines; overridden by adapters with richer in-memory state
- BaseAdapter.restore_state(): stores snapshot fields as adapter attrs
  for create_executor() to pick up
- main.py: calls pre_stop serialization in finally block (after server
  serves) and restore_state() after adapter setup, before server starts
- Added 12 unit tests covering scrub, read/write, adapter integration

Co-authored-by: Molecule AI Infra-Runtime-BE <infra-runtime-be@agents.moleculesai.app>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
molecule-ai[bot] 2026-04-21 12:40:44 +00:00 committed by GitHub
parent 7dd66c91e0
commit 4675402e58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 565 additions and 1 deletions

View File

@ -132,6 +132,77 @@ class BaseAdapter(ABC):
"source": None,
}
def pre_stop_state(self) -> dict:
    """Capture in-memory state for pause/resume serialization.

    Called by main.py's shutdown handler just before the container exits.
    Returns a dict that will be scrubbed (via lib.snapshot_scrub) and
    written to /configs/.agent_snapshot.json.

    Default implementation:

    1. Attempts to read ``self._executor._session_id`` (set by
       create_executor) and includes it as ``session_id``.
    2. Includes up to 200 recent transcript lines via transcript_lines().

    Override in adapters that hold additional in-memory state that
    should survive a container stop.

    Returns:
        A JSON-serializable dict. All string values are scrubbed before
        persisting, so it is safe to include raw content from the
        agent's context.
    """
    state: dict = {}

    # Session handle — critical for resuming the Claude Code session.
    executor = getattr(self, "_executor", None)
    if executor is not None:
        session_id = getattr(executor, "_session_id", None)
        if session_id:
            state["session_id"] = session_id

    # Recent conversation log — captures where the agent left off.
    # transcript_lines() may be async; call it synchronously if possible,
    # otherwise let async adapters override pre_stop_state entirely.
    try:
        import inspect as _inspect

        # Imported lazily and INSIDE the best-effort guard: a missing or
        # broken lib.pre_stop must never abort serialization (previously a
        # failed import here lost the session_id captured above too).
        try:
            from lib.pre_stop import MAX_TRANSCRIPT_LINES as _line_limit
        except Exception:
            _line_limit = 200  # mirrors lib.pre_stop.MAX_TRANSCRIPT_LINES

        transcript_fn = self.transcript_lines
        if _inspect.iscoroutinefunction(transcript_fn):
            # Async adapter — override pre_stop_state() for transcript access.
            # The base impl still captures session_id above.
            pass
        else:
            transcript = transcript_fn(since=0, limit=_line_limit)
            if transcript.get("supported"):
                state["transcript_lines"] = transcript.get("lines", [])
    except Exception:
        # Best-effort: never let transcript capture failure block serialization.
        pass
    return state
def restore_state(self, snapshot: dict) -> None:
    """Restore in-memory state from a pause/resume snapshot.

    Called by main.py on first boot when /configs/.agent_snapshot.json
    exists. Gives the adapter a chance to restore session handles,
    conversation context, or any other in-memory state before the A2A
    server starts accepting requests.

    The base implementation stashes ``snapshot["session_id"]`` and
    ``snapshot["transcript_lines"]`` on the adapter as
    ``self._snapshot_session_id`` / ``self._snapshot_transcript`` so that
    ``create_executor()`` or the executor itself can pick them up.

    Args:
        snapshot: The scrubbed snapshot dict previously written by
            pre_stop_state(). All secrets have already been redacted.
    """
    restored_session = snapshot.get("session_id")
    restored_transcript = snapshot.get("transcript_lines")
    self._snapshot_session_id: str | None = restored_session
    self._snapshot_transcript: list | None = restored_transcript
def register_subagent_hook(self, name: str, spec: dict) -> None:
    """Register a sub-agent hook.

    The base adapter ignores the request; DeepAgents overrides this to
    actually register a sub-agent.
    """
    return None
@ -305,5 +376,9 @@ class BaseAdapter(ABC):
async def create_executor(self, config: AdapterConfig) -> AgentExecutor:
    """Build and return an AgentExecutor ready for A2A integration.

    The A2A server's DefaultRequestHandler drives the returned
    executor's execute() method.

    Subclasses should also store the returned executor as
    ``self._executor`` so ``pre_stop_state()`` can access it for
    serialization.
    """
    ...  # pragma: no cover

192
workspace/lib/pre_stop.py Normal file
View File

@ -0,0 +1,192 @@
"""Pre-stop serialization for pause/resume — GH#1391.
Captures the agent's in-memory state just before the container exits so it
survives intentional pause and unplanned restart. All content is scrubbed
with lib.snapshot_scrub before being written to disk so that a snapshot blob
obtained by an attacker cannot recover API keys, tokens, or arbitrary sandbox
output (GH#823).
State captured
--------------
- ``workspace_id`` identity for cross-container restore
- ``current_task`` active task label from heartbeat (what the canvas sees)
- ``active_tasks`` task count
- ``session_id`` SDK session handle (Claude Code); key for full session
- ``transcript_lines`` recent session log lines from the adapter
- ``uptime_seconds`` how long this container has been running
- ``timestamp`` when the snapshot was taken (ISO-8601)
Scrubbing
---------
Every text field passes through scrub_snapshot before being written.
Sandbox-sourced content (tool=run_code, source=sandbox, [sandbox_output]) is
dropped wholesale. Secrets matching the pattern library are replaced with
[REDACTED:TYPE] markers.
Storage
-------
Snapshots are written to /configs/.agent_snapshot.json by default. The
config volume survives container restarts so the file is durable. The path
is also overridable via ``AGENT_SNAPSHOT_PATH`` for testing or custom layouts.
"""
from __future__ import annotations
import json
import logging
import os
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
from .snapshot_scrub import scrub_snapshot
if TYPE_CHECKING:
from heartbeat import HeartbeatLoop
logger = logging.getLogger(__name__)
# Default snapshot path — on the config volume, survives container restarts.
# AGENT_SNAPSHOT_PATH overrides it (tests, custom layouts); note the env var
# is read once at import time, so changing it later has no effect here.
DEFAULT_SNAPSHOT_PATH = os.environ.get(
    "AGENT_SNAPSHOT_PATH",
    "/configs/.agent_snapshot.json",
)
# How many transcript lines to capture in the snapshot (recent window).
# BaseAdapter.pre_stop_state() imports this and passes it as
# transcript_lines(limit=...).
MAX_TRANSCRIPT_LINES = 200
def build_snapshot(
    heartbeat: "HeartbeatLoop | None",
    adapter_state: dict[str, Any],
) -> dict[str, Any]:
    """Assemble a raw (not yet scrubbed) snapshot dict from live workspace state.

    Args:
        heartbeat: HeartbeatLoop instance; provides current_task, session_id, etc.
        adapter_state: Arbitrary state dict from the adapter's pre_stop_state()
            hook. Keys are free-form; all string values in nested dicts/lists
            are scrubbed before writing.

    Returns:
        The raw snapshot dict — scrubbing happens later in write_snapshot().
    """
    import time

    snapshot: dict[str, Any] = {
        "workspace_id": os.environ.get("WORKSPACE_ID", "unknown"),
        "timestamp": datetime.now(timezone.utc).isoformat(),
        # Defaults — overwritten from the heartbeat below when available:
        "current_task": "",
        "active_tasks": 0,
    }

    if heartbeat is not None:
        snapshot["current_task"] = heartbeat.current_task or ""
        snapshot["active_tasks"] = heartbeat.active_tasks
        if hasattr(heartbeat, "start_time"):
            snapshot["uptime_seconds"] = int(time.time() - heartbeat.start_time)
        # session_id lives in the adapter but we also accept it via heartbeat
        # for convenience (avoids requiring every adapter to pass it separately).
        if not adapter_state.get("session_id"):
            snapshot["session_id"] = getattr(heartbeat, "_session_id", None) or ""

    # Adapter-supplied state (conversation history, reasoning traces, etc.)
    snapshot["adapter"] = adapter_state
    return snapshot
def _scrub_value(value: Any) -> Any:
    """Recursively scrub all secret patterns from *value*.

    - Strings: scrub_content() replaces patterns with [REDACTED:TYPE].
    - Dicts: rebuilt with every value scrubbed recursively.
    - Lists: sandbox-sourced string entries are dropped wholesale; the
      remaining items are scrubbed recursively.
    - Anything else: passed through unchanged.
    """
    from .snapshot_scrub import is_sandbox_content, scrub_content

    if isinstance(value, str):
        return scrub_content(value)
    if isinstance(value, dict):
        return {key: _scrub_value(item) for key, item in value.items()}
    if isinstance(value, list):
        return [
            _scrub_value(item)
            for item in value
            if not (isinstance(item, str) and is_sandbox_content(item))
        ]
    return value
def write_snapshot(
    snapshot: dict[str, Any],
    path: str | None = None,
) -> bool:
    """Scrub and write a snapshot to disk.

    The file is written to a temporary sibling path and then moved into
    place with os.replace(), so a crash mid-write can never leave a
    truncated snapshot at the target path.

    Args:
        snapshot: Raw snapshot dict from build_snapshot().
        path: Target file path (default: DEFAULT_SNAPSHOT_PATH).

    Returns:
        True if the snapshot was written successfully; False on any error.
        Errors are logged but never raised — pre-stop serialization must be
        best-effort to avoid blocking shutdown.
    """
    target = path or DEFAULT_SNAPSHOT_PATH
    tmp_target = target + ".tmp"
    try:
        # Deep-scrub every string value in the snapshot to remove API keys,
        # tokens, and arbitrary sandbox output before writing to disk.
        scrubbed = _scrub_value(snapshot)
        # Ensure parent directory exists.
        parent = os.path.dirname(target)
        if parent:
            os.makedirs(parent, exist_ok=True)
        # Write-then-rename keeps the snapshot atomic; default=str stringifies
        # any non-JSON-serializable values instead of failing.
        with open(tmp_target, "w") as f:
            json.dump(scrubbed, f, indent=2, default=str)
        os.replace(tmp_target, target)
        logger.info(
            "Pre-stop snapshot written: %s (workspace=%s, task=%r, lines=%d)",
            target,
            scrubbed.get("workspace_id", "?"),
            scrubbed.get("current_task", ""),
            len(scrubbed.get("adapter", {}).get("transcript_lines", [])),
        )
        return True
    except Exception as exc:
        logger.warning("Pre-stop snapshot write failed (%s): %s", target, exc)
        # Best-effort cleanup of a partially-written temp file.
        try:
            os.remove(tmp_target)
        except OSError:
            pass
        return False
def read_snapshot(
    path: str | None = None,
) -> dict[str, Any] | None:
    """Load a previously-written snapshot; None when absent or unparseable."""
    target = path if path else DEFAULT_SNAPSHOT_PATH
    if os.path.exists(target):
        try:
            with open(target) as handle:
                return json.load(handle)
        except Exception as exc:
            logger.debug("Snapshot read failed (%s): %s", target, exc)
    return None
def delete_snapshot(path: str | None = None) -> None:
    """Remove a snapshot file. Idempotent — no error if absent."""
    target = path if path else DEFAULT_SNAPSHOT_PATH
    try:
        os.remove(target)
    except FileNotFoundError:
        # Already gone — deletion is idempotent by design.
        return
    except Exception as exc:
        logger.warning("Snapshot delete failed (%s): %s", target, exc)
    else:
        logger.debug("Snapshot deleted: %s", target)

View File

@ -124,6 +124,21 @@ async def main(): # pragma: no cover
try:
await adapter.setup(adapter_config)
executor = await adapter.create_executor(adapter_config)
# 5b. Restore from pre-stop snapshot if one exists (GH#1391).
# The snapshot is scrubbed before being written, so secrets are
# already redacted — restore_state must not re-expose them.
from lib.pre_stop import read_snapshot
snapshot = read_snapshot()
if snapshot:
try:
adapter.restore_state(snapshot)
print(
f"Pre-stop snapshot restored: task={snapshot.get('current_task', '')!r}, "
f"uptime={snapshot.get('uptime_seconds', 0)}s"
)
except Exception as restore_err:
print(f"Warning: snapshot restore failed (continuing): {restore_err}")
except Exception:
# heartbeat hasn't started yet but may have async tasks pending
if hasattr(heartbeat, "stop"):
@ -543,6 +558,18 @@ async def main(): # pragma: no cover
try:
await server.serve()
finally:
# 10d. Pre-stop serialization — GH#1391.
# Capture in-memory state before the container exits so it survives
# intentional pause and unplanned restart. All content is scrubbed
# via lib.snapshot_scrub before being written to the config volume.
try:
from lib.pre_stop import build_snapshot, write_snapshot
adapter_state = adapter.pre_stop_state() if adapter else {}
snapshot = build_snapshot(heartbeat, adapter_state)
write_snapshot(snapshot)
except Exception as pre_stop_err:
print(f"Warning: pre-stop serialization failed (continuing): {pre_stop_err}")
# Cancel initial prompt if still running
if initial_prompt_task and not initial_prompt_task.done():
initial_prompt_task.cancel()

View File

@ -0,0 +1,270 @@
"""Tests for lib.pre_stop — GH#1391 pre-stop serialization."""
import json
import os
import tempfile
import pytest
class _MockHeartbeat:
"""Minimal heartbeat for testing — matches heartbeat.HeartbeatLoop shape."""
def __init__(self):
self.current_task = "Implementing feature X"
self.active_tasks = 1
self.start_time = 1000.0
self._session_id = None
class _MockAdapter:
"""Minimal adapter that returns known pre_stop_state for testing."""
def pre_stop_state(self):
return {
"session_id": "sess_abc123xyz",
"transcript_lines": [
"User: hello",
"Agent: Hi! How can I help?",
],
}
def test_build_snapshot_basic():
    """build_snapshot returns workspace_id, timestamp, and heartbeat fields."""
    from lib.pre_stop import build_snapshot

    state = {"session_id": "sess_abc", "transcript_lines": ["line1"]}
    snap = build_snapshot(_MockHeartbeat(), state)

    assert snap["workspace_id"] == os.environ.get("WORKSPACE_ID", "unknown")
    assert "timestamp" in snap
    assert snap["current_task"] == "Implementing feature X"
    assert snap["active_tasks"] == 1
    assert snap["adapter"] == state
def test_build_snapshot_none_heartbeat():
    """build_snapshot handles None heartbeat gracefully."""
    from lib.pre_stop import build_snapshot

    snap = build_snapshot(None, {"session_id": "sess_xyz"})

    assert snap["current_task"] == ""
    assert snap["active_tasks"] == 0
    # Without a heartbeat the session_id stays nested under "adapter"; it is
    # never promoted to the top level.
    assert "session_id" not in snap
    assert snap["adapter"]["session_id"] == "sess_xyz"
def test_build_snapshot_scrubbed_secrets():
    """Snapshot content with API keys is scrubbed by write_snapshot."""
    from lib.pre_stop import build_snapshot, write_snapshot

    snap = build_snapshot(
        _MockHeartbeat(),
        {
            "session_id": "sess_secret",
            "transcript_lines": [
                "Authorization: Bearer abc123.def456.ghi789",
                "token_used: Bearer xyz.token.placeholder",
            ],
        },
    )
    with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as tmp:
        path = tmp.name
    try:
        assert write_snapshot(snap, path=path), "write_snapshot should return True on success"
        with open(path) as fh:
            persisted = json.load(fh)
        entries = persisted["adapter"]["transcript_lines"]
        assert not any("Bearer abc" in entry for entry in entries), "Bearer token should be scrubbed"
        assert any("REDACTED" in entry for entry in entries), "Scrub markers should be present"
    finally:
        os.unlink(path)
def test_build_snapshot_scrub_drops_sandbox_content():
    """Sandbox-sourced transcript lines are dropped entirely."""
    from lib.pre_stop import build_snapshot, write_snapshot

    snap = build_snapshot(
        _MockHeartbeat(),
        {"session_lines": ["source=sandbox echo hello", "Normal message"]},
    )
    with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as tmp:
        path = tmp.name
    try:
        write_snapshot(snap, path=path)
        with open(path) as fh:
            persisted = json.load(fh)
        # scrub_snapshot drops sandbox entries from lists entirely.
        entries = persisted["adapter"].get("session_lines", [])
        assert not any("sandbox" in entry for entry in entries), "Sandbox lines should be dropped"
    finally:
        os.unlink(path)
def test_read_snapshot_missing_returns_none():
    """read_snapshot returns None when the file doesn't exist."""
    from lib.pre_stop import read_snapshot

    assert read_snapshot(path="/nonexistent/path/12345.json") is None
def test_read_snapshot_returns_data():
    """read_snapshot returns the parsed JSON when the file exists."""
    from lib.pre_stop import read_snapshot

    payload = {"workspace_id": "test-ws", "current_task": "test"}
    with tempfile.NamedTemporaryFile(suffix=".json", delete=False, mode="w") as tmp:
        json.dump(payload, tmp)
        path = tmp.name
    try:
        loaded = read_snapshot(path=path)
        assert loaded == payload
        assert loaded["workspace_id"] == "test-ws"
    finally:
        os.unlink(path)
def test_delete_snapshot_removes_file():
    """delete_snapshot removes the file and is idempotent on missing file."""
    from lib.pre_stop import delete_snapshot

    with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as tmp:
        path = tmp.name
    delete_snapshot(path=path)
    assert not os.path.exists(path), "File should be removed"
    # Second call must not raise even though the file is already gone.
    delete_snapshot(path=path)
def test_write_snapshot_returns_false_on_error():
    """write_snapshot returns False on I/O errors and logs a warning."""
    import unittest.mock as mock

    from lib.pre_stop import build_snapshot, write_snapshot

    snap = build_snapshot(_MockHeartbeat(), {})
    # Force the I/O failure by patching open(); write_snapshot must swallow
    # the error and report failure instead of raising.
    # (Previously this test requested the unused `monkeypatch` fixture and
    # carried a comment describing a chmod-based approach it never used.)
    with mock.patch("builtins.open", side_effect=OSError("disk full")):
        ok = write_snapshot(snap, path="/tmp/fake.json")
    assert ok is False, "write_snapshot should return False on error"
def test_restore_state_stores_on_adapter():
    """restore_state stores snapshot fields as adapter attributes."""
    from adapter_base import BaseAdapter

    class DummyAdapter(BaseAdapter):
        def name(self): return "dummy"
        def display_name(self): return "Dummy"
        def description(self): return "dummy"
        async def setup(self, cfg): pass
        async def create_executor(self, cfg): pass

    adapter = DummyAdapter()
    adapter.restore_state(
        {
            "session_id": "sess_restored_123",
            "transcript_lines": ["line1", "line2"],
            "current_task": "Old task",
        }
    )
    assert adapter._snapshot_session_id == "sess_restored_123"
    assert adapter._snapshot_transcript == ["line1", "line2"]
def test_pre_stop_state_default_returns_empty():
    """Default pre_stop_state (BaseAdapter) returns an empty dict."""
    from adapter_base import BaseAdapter

    class DummyAdapter(BaseAdapter):
        def name(self): return "dummy"
        def display_name(self): return "Dummy"
        def description(self): return "dummy"
        async def setup(self, cfg): pass
        async def create_executor(self, cfg): pass

    # No executor, no transcript — nothing to capture.
    assert DummyAdapter().pre_stop_state() == {}
def test_pre_stop_state_with_executor_session_id():
    """pre_stop_state captures _executor._session_id when available."""
    from adapter_base import BaseAdapter

    class DummyExecutor:
        pass

    class DummyAdapter(BaseAdapter):
        def name(self): return "dummy"
        def display_name(self): return "Dummy"
        def description(self): return "dummy"
        async def setup(self, cfg): pass
        async def create_executor(self, cfg):
            # Simulate storing the executor so pre_stop_state can find it
            self._executor = DummyExecutor()
            self._executor._session_id = "sess_from_executor_456"
            return self._executor

    adapter = DummyAdapter()
    # Attach the executor directly, as create_executor would have done.
    executor = DummyExecutor()
    executor._session_id = "sess_from_executor_456"
    adapter._executor = executor
    assert adapter.pre_stop_state()["session_id"] == "sess_from_executor_456"
def test_pre_stop_state_transcript_included():
    """pre_stop_state includes transcript_lines when transcript is supported."""
    from adapter_base import BaseAdapter

    class DummyExecutor:
        pass

    class DummyAdapter(BaseAdapter):
        def name(self): return "dummy"
        def display_name(self): return "Dummy"
        def description(self): return "dummy"
        async def setup(self, cfg): pass
        async def create_executor(self, cfg):
            self._executor = DummyExecutor()
            return self._executor
        def transcript_lines(self, since=0, limit=100):
            return {
                "supported": True,
                "lines": ["User: test", "Agent: response"],
                "cursor": 2,
                "more": False,
            }

    adapter = DummyAdapter()
    adapter._executor = DummyExecutor()
    captured = adapter.pre_stop_state()
    assert "transcript_lines" in captured
    assert captured["transcript_lines"] == ["User: test", "Agent: response"]