forked from molecule-ai/molecule-core
feat(workspace): pre-stop serialization for pause/resume (closes #1386)
Add a pre-stop hook that captures agent state before container exit and writes a scrubbed snapshot to /configs/.agent_snapshot.json. On restart, the snapshot is loaded and the adapter's restore_state() is called before the A2A server starts. - New lib/pre_stop.py: build_snapshot / write_snapshot / read_snapshot / delete_snapshot + _scrub_value deep-scrubber (uses lib.snapshot_scrub to redact API keys, tokens, and sandbox output before persisting) - BaseAdapter.pre_stop_state(): captures _executor._session_id and recent transcript_lines; overridden by adapters with richer in-memory state - BaseAdapter.restore_state(): stores snapshot fields as adapter attrs for create_executor() to pick up - main.py: calls pre_stop serialization in finally block (after server serves) and restore_state() after adapter setup, before server starts - Added 12 unit tests covering scrub, read/write, adapter integration Co-authored-by: Molecule AI Infra-Runtime-BE <infra-runtime-be@agents.moleculesai.app> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
7dd66c91e0
commit
4675402e58
@ -132,6 +132,77 @@ class BaseAdapter(ABC):
|
|||||||
"source": None,
|
"source": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def pre_stop_state(self) -> dict:
|
||||||
|
"""Capture in-memory state for pause/resume serialization.
|
||||||
|
|
||||||
|
Called by main.py's shutdown handler just before the container exits.
|
||||||
|
Returns a dict that will be scrubbed (via lib.snapshot_scrub) and
|
||||||
|
written to /configs/.agent_snapshot.json.
|
||||||
|
|
||||||
|
Default implementation:
|
||||||
|
1. Attempts to read ``self._executor._session_id`` (set by
|
||||||
|
create_executor) and includes it as ``session_id``.
|
||||||
|
2. Includes up to 200 recent transcript lines via transcript_lines().
|
||||||
|
|
||||||
|
Override in adapters that hold additional in-memory state that
|
||||||
|
should survive a container stop.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A JSON-serializable dict. All string values are scrubbed before
|
||||||
|
persisting, so it is safe to include raw content from the
|
||||||
|
agent's context.
|
||||||
|
"""
|
||||||
|
from lib.pre_stop import MAX_TRANSCRIPT_LINES
|
||||||
|
|
||||||
|
state: dict = {}
|
||||||
|
|
||||||
|
# Session handle — critical for resuming the Claude Code session.
|
||||||
|
executor = getattr(self, "_executor", None)
|
||||||
|
if executor is not None:
|
||||||
|
session_id = getattr(executor, "_session_id", None)
|
||||||
|
if session_id:
|
||||||
|
state["session_id"] = session_id
|
||||||
|
|
||||||
|
# Recent conversation log — captures where the agent left off.
|
||||||
|
# transcript_lines() may be async; call it synchronously if possible,
|
||||||
|
# otherwise let async adapters override pre_stop_state entirely.
|
||||||
|
try:
|
||||||
|
import inspect as _inspect
|
||||||
|
transcript_fn = self.transcript_lines
|
||||||
|
if _inspect.iscoroutinefunction(transcript_fn):
|
||||||
|
# Async adapter — override pre_stop_state() for transcript access.
|
||||||
|
# The base impl still captures session_id above.
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
transcript = transcript_fn(since=0, limit=MAX_TRANSCRIPT_LINES)
|
||||||
|
if transcript.get("supported"):
|
||||||
|
state["transcript_lines"] = transcript.get("lines", [])
|
||||||
|
except Exception:
|
||||||
|
# Best-effort: never let transcript capture failure block serialization.
|
||||||
|
pass
|
||||||
|
|
||||||
|
return state
|
||||||
|
|
||||||
|
def restore_state(self, snapshot: dict) -> None:
|
||||||
|
"""Restore in-memory state from a pause/resume snapshot.
|
||||||
|
|
||||||
|
Called by main.py on first boot when /configs/.agent_snapshot.json
|
||||||
|
exists. Gives the adapter a chance to restore session handles,
|
||||||
|
conversation context, or any other in-memory state before the A2A
|
||||||
|
server starts accepting requests.
|
||||||
|
|
||||||
|
Default implementation stores ``snapshot["session_id"]`` and
|
||||||
|
``snapshot["transcript_lines"]`` as ``self._snapshot_session_id``
|
||||||
|
and ``self._snapshot_transcript`` so that ``create_executor()`` or
|
||||||
|
the executor itself can pick them up.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
snapshot: The scrubbed snapshot dict previously written by
|
||||||
|
pre_stop_state(). All secrets have already been redacted.
|
||||||
|
"""
|
||||||
|
self._snapshot_session_id: str | None = snapshot.get("session_id")
|
||||||
|
self._snapshot_transcript: list | None = snapshot.get("transcript_lines")
|
||||||
|
|
||||||
def register_subagent_hook(self, name: str, spec: dict) -> None:
|
def register_subagent_hook(self, name: str, spec: dict) -> None:
|
||||||
"""Default no-op. DeepAgents overrides to register a sub-agent."""
|
"""Default no-op. DeepAgents overrides to register a sub-agent."""
|
||||||
return None
|
return None
|
||||||
@ -305,5 +376,9 @@ class BaseAdapter(ABC):
|
|||||||
async def create_executor(self, config: AdapterConfig) -> AgentExecutor:
|
async def create_executor(self, config: AdapterConfig) -> AgentExecutor:
|
||||||
"""Create and return an AgentExecutor ready for A2A integration.
|
"""Create and return an AgentExecutor ready for A2A integration.
|
||||||
The returned executor's execute() method will be called by the
|
The returned executor's execute() method will be called by the
|
||||||
A2A server's DefaultRequestHandler."""
|
A2A server's DefaultRequestHandler.
|
||||||
|
|
||||||
|
Subclasses should also store the returned executor as ``self._executor``
|
||||||
|
so ``pre_stop_state()`` can access it for serialization.
|
||||||
|
"""
|
||||||
... # pragma: no cover
|
... # pragma: no cover
|
||||||
|
|||||||
192
workspace/lib/pre_stop.py
Normal file
192
workspace/lib/pre_stop.py
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
"""Pre-stop serialization for pause/resume — GH#1391.
|
||||||
|
|
||||||
|
Captures the agent's in-memory state just before the container exits so it
|
||||||
|
survives intentional pause and unplanned restart. All content is scrubbed
|
||||||
|
with lib.snapshot_scrub before being written to disk so that a snapshot blob
|
||||||
|
obtained by an attacker cannot recover API keys, tokens, or arbitrary sandbox
|
||||||
|
output (GH#823).
|
||||||
|
|
||||||
|
State captured
|
||||||
|
--------------
|
||||||
|
- ``workspace_id`` — identity for cross-container restore
|
||||||
|
- ``current_task`` — active task label from heartbeat (what the canvas sees)
|
||||||
|
- ``active_tasks`` — task count
|
||||||
|
- ``session_id`` — SDK session handle (Claude Code); key for full session
|
||||||
|
- ``transcript_lines`` — recent session log lines from the adapter
|
||||||
|
- ``uptime_seconds`` — how long this container has been running
|
||||||
|
- ``timestamp`` — when the snapshot was taken (ISO-8601)
|
||||||
|
|
||||||
|
Scrubbing
|
||||||
|
---------
|
||||||
|
Every text field passes through scrub_snapshot before being written.
|
||||||
|
Sandbox-sourced content (tool=run_code, source=sandbox, [sandbox_output]) is
|
||||||
|
dropped wholesale. Secrets matching the pattern library are replaced with
|
||||||
|
[REDACTED:TYPE] markers.
|
||||||
|
|
||||||
|
Storage
|
||||||
|
-------
|
||||||
|
Snapshots are written to /configs/.agent_snapshot.json by default. The
|
||||||
|
config volume survives container restarts so the file is durable. The path
|
||||||
|
is also overridable via ``AGENT_SNAPSHOT_PATH`` for testing or custom layouts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from .snapshot_scrub import scrub_snapshot
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from heartbeat import HeartbeatLoop
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Default snapshot path — on the config volume, survives container restarts.
|
||||||
|
DEFAULT_SNAPSHOT_PATH = os.environ.get(
|
||||||
|
"AGENT_SNAPSHOT_PATH",
|
||||||
|
"/configs/.agent_snapshot.json",
|
||||||
|
)
|
||||||
|
|
||||||
|
# How many transcript lines to capture in the snapshot (recent window).
|
||||||
|
MAX_TRANSCRIPT_LINES = 200
|
||||||
|
|
||||||
|
|
||||||
|
def build_snapshot(
|
||||||
|
heartbeat: "HeartbeatLoop | None",
|
||||||
|
adapter_state: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Build a raw snapshot dict from live workspace state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
heartbeat: HeartbeatLoop instance; provides current_task, session_id, etc.
|
||||||
|
adapter_state: Arbitrary state dict from the adapter's pre_stop_state() hook.
|
||||||
|
Keys are free-form; all string values in nested dicts/lists are
|
||||||
|
scrubbed before writing.
|
||||||
|
|
||||||
|
Returns a raw (not yet scrubbed) snapshot dict.
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
raw: dict[str, Any] = {
|
||||||
|
"workspace_id": os.environ.get("WORKSPACE_ID", "unknown"),
|
||||||
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||||
|
# Defaults — heartbeat block below overwrites these when available:
|
||||||
|
"current_task": "",
|
||||||
|
"active_tasks": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
if heartbeat is not None:
|
||||||
|
raw["current_task"] = heartbeat.current_task or ""
|
||||||
|
raw["active_tasks"] = heartbeat.active_tasks
|
||||||
|
if hasattr(heartbeat, "start_time"):
|
||||||
|
raw["uptime_seconds"] = int(time.time() - heartbeat.start_time)
|
||||||
|
# session_id lives in the adapter but we also accept it via heartbeat
|
||||||
|
# for convenience (avoids requiring every adapter to pass it separately).
|
||||||
|
if not adapter_state.get("session_id"):
|
||||||
|
raw["session_id"] = getattr(heartbeat, "_session_id", None) or ""
|
||||||
|
|
||||||
|
# Adapter-supplied state (conversation history, reasoning traces, etc.)
|
||||||
|
raw["adapter"] = adapter_state
|
||||||
|
|
||||||
|
return raw
|
||||||
|
|
||||||
|
|
||||||
|
def _scrub_value(value: Any) -> Any:
|
||||||
|
"""Recursively scrub all secret patterns from a value.
|
||||||
|
|
||||||
|
- Strings: scrub_content() replaces patterns with [REDACTED:TYPE].
|
||||||
|
- Dicts: return a new dict with all values scrubbed recursively.
|
||||||
|
- Lists: drop entries that are sandbox content; scrub remaining items.
|
||||||
|
- Other: pass through unchanged.
|
||||||
|
"""
|
||||||
|
from .snapshot_scrub import is_sandbox_content, scrub_content
|
||||||
|
|
||||||
|
if isinstance(value, str):
|
||||||
|
return scrub_content(value)
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return {k: _scrub_value(v) for k, v in value.items()}
|
||||||
|
if isinstance(value, list):
|
||||||
|
result = []
|
||||||
|
for item in value:
|
||||||
|
if isinstance(item, str) and is_sandbox_content(item):
|
||||||
|
continue # Drop sandbox entries wholesale
|
||||||
|
result.append(_scrub_value(item))
|
||||||
|
return result
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def write_snapshot(
|
||||||
|
snapshot: dict[str, Any],
|
||||||
|
path: str | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Scrub and write a snapshot to disk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
snapshot: Raw snapshot dict from build_snapshot().
|
||||||
|
path: Target file path (default: DEFAULT_SNAPSHOT_PATH).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the snapshot was written successfully; False on any error.
|
||||||
|
Errors are logged but never raise — pre-stop serialization must be
|
||||||
|
best-effort to avoid blocking shutdown.
|
||||||
|
"""
|
||||||
|
target = path or DEFAULT_SNAPSHOT_PATH
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Deep-scrub every string value in the snapshot to remove API keys,
|
||||||
|
# tokens, and arbitrary sandbox output before writing to disk.
|
||||||
|
scrubbed = _scrub_value(snapshot)
|
||||||
|
|
||||||
|
# Ensure parent directory exists.
|
||||||
|
parent = os.path.dirname(target)
|
||||||
|
if parent:
|
||||||
|
os.makedirs(parent, exist_ok=True)
|
||||||
|
|
||||||
|
with open(target, "w") as f:
|
||||||
|
json.dump(scrubbed, f, indent=2, default=str)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Pre-stop snapshot written: %s (workspace=%s, task=%r, lines=%d)",
|
||||||
|
target,
|
||||||
|
scrubbed.get("workspace_id", "?"),
|
||||||
|
scrubbed.get("current_task", ""),
|
||||||
|
len(scrubbed.get("adapter", {}).get("transcript_lines", [])),
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Pre-stop snapshot write failed (%s): %s", target, exc)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def read_snapshot(
|
||||||
|
path: str | None = None,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Read and return a previously-written snapshot, or None if absent/invalid."""
|
||||||
|
target = path or DEFAULT_SNAPSHOT_PATH
|
||||||
|
|
||||||
|
if not os.path.exists(target):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(target) as f:
|
||||||
|
return json.load(f)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("Snapshot read failed (%s): %s", target, exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def delete_snapshot(path: str | None = None) -> None:
|
||||||
|
"""Remove a snapshot file. Idempotent — no error if absent."""
|
||||||
|
target = path or DEFAULT_SNAPSHOT_PATH
|
||||||
|
try:
|
||||||
|
os.remove(target)
|
||||||
|
logger.debug("Snapshot deleted: %s", target)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Snapshot delete failed (%s): %s", target, exc)
|
||||||
@ -124,6 +124,21 @@ async def main(): # pragma: no cover
|
|||||||
try:
|
try:
|
||||||
await adapter.setup(adapter_config)
|
await adapter.setup(adapter_config)
|
||||||
executor = await adapter.create_executor(adapter_config)
|
executor = await adapter.create_executor(adapter_config)
|
||||||
|
|
||||||
|
# 5b. Restore from pre-stop snapshot if one exists (GH#1391).
|
||||||
|
# The snapshot is scrubbed before being written, so secrets are
|
||||||
|
# already redacted — restore_state must not re-expose them.
|
||||||
|
from lib.pre_stop import read_snapshot
|
||||||
|
snapshot = read_snapshot()
|
||||||
|
if snapshot:
|
||||||
|
try:
|
||||||
|
adapter.restore_state(snapshot)
|
||||||
|
print(
|
||||||
|
f"Pre-stop snapshot restored: task={snapshot.get('current_task', '')!r}, "
|
||||||
|
f"uptime={snapshot.get('uptime_seconds', 0)}s"
|
||||||
|
)
|
||||||
|
except Exception as restore_err:
|
||||||
|
print(f"Warning: snapshot restore failed (continuing): {restore_err}")
|
||||||
except Exception:
|
except Exception:
|
||||||
# heartbeat hasn't started yet but may have async tasks pending
|
# heartbeat hasn't started yet but may have async tasks pending
|
||||||
if hasattr(heartbeat, "stop"):
|
if hasattr(heartbeat, "stop"):
|
||||||
@ -543,6 +558,18 @@ async def main(): # pragma: no cover
|
|||||||
try:
|
try:
|
||||||
await server.serve()
|
await server.serve()
|
||||||
finally:
|
finally:
|
||||||
|
# 10d. Pre-stop serialization — GH#1391.
|
||||||
|
# Capture in-memory state before the container exits so it survives
|
||||||
|
# intentional pause and unplanned restart. All content is scrubbed
|
||||||
|
# via lib.snapshot_scrub before being written to the config volume.
|
||||||
|
try:
|
||||||
|
from lib.pre_stop import build_snapshot, write_snapshot
|
||||||
|
adapter_state = adapter.pre_stop_state() if adapter else {}
|
||||||
|
snapshot = build_snapshot(heartbeat, adapter_state)
|
||||||
|
write_snapshot(snapshot)
|
||||||
|
except Exception as pre_stop_err:
|
||||||
|
print(f"Warning: pre-stop serialization failed (continuing): {pre_stop_err}")
|
||||||
|
|
||||||
# Cancel initial prompt if still running
|
# Cancel initial prompt if still running
|
||||||
if initial_prompt_task and not initial_prompt_task.done():
|
if initial_prompt_task and not initial_prompt_task.done():
|
||||||
initial_prompt_task.cancel()
|
initial_prompt_task.cancel()
|
||||||
|
|||||||
270
workspace/tests/test_pre_stop.py
Normal file
270
workspace/tests/test_pre_stop.py
Normal file
@ -0,0 +1,270 @@
|
|||||||
|
"""Tests for lib.pre_stop — GH#1391 pre-stop serialization."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class _MockHeartbeat:
|
||||||
|
"""Minimal heartbeat for testing — matches heartbeat.HeartbeatLoop shape."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.current_task = "Implementing feature X"
|
||||||
|
self.active_tasks = 1
|
||||||
|
self.start_time = 1000.0
|
||||||
|
self._session_id = None
|
||||||
|
|
||||||
|
|
||||||
|
class _MockAdapter:
|
||||||
|
"""Minimal adapter that returns known pre_stop_state for testing."""
|
||||||
|
|
||||||
|
def pre_stop_state(self):
|
||||||
|
return {
|
||||||
|
"session_id": "sess_abc123xyz",
|
||||||
|
"transcript_lines": [
|
||||||
|
"User: hello",
|
||||||
|
"Agent: Hi! How can I help?",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_snapshot_basic():
|
||||||
|
"""build_snapshot returns workspace_id, timestamp, and heartbeat fields."""
|
||||||
|
from lib.pre_stop import build_snapshot
|
||||||
|
|
||||||
|
hb = _MockHeartbeat()
|
||||||
|
adapter_state = {"session_id": "sess_abc", "transcript_lines": ["line1"]}
|
||||||
|
snap = build_snapshot(hb, adapter_state)
|
||||||
|
|
||||||
|
assert snap["workspace_id"] == os.environ.get("WORKSPACE_ID", "unknown")
|
||||||
|
assert "timestamp" in snap
|
||||||
|
assert snap["current_task"] == "Implementing feature X"
|
||||||
|
assert snap["active_tasks"] == 1
|
||||||
|
assert snap["adapter"] == adapter_state
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_snapshot_none_heartbeat():
|
||||||
|
"""build_snapshot handles None heartbeat gracefully."""
|
||||||
|
from lib.pre_stop import build_snapshot
|
||||||
|
|
||||||
|
snap = build_snapshot(None, {"session_id": "sess_xyz"})
|
||||||
|
assert snap["current_task"] == ""
|
||||||
|
assert snap["active_tasks"] == 0
|
||||||
|
# session_id is NOT promoted to top-level when heartbeat is absent;
|
||||||
|
# it stays nested inside adapter.
|
||||||
|
assert "session_id" not in snap
|
||||||
|
assert snap["adapter"]["session_id"] == "sess_xyz"
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_snapshot_scrubbed_secrets():
|
||||||
|
"""Snapshot content with API keys is scrubbed by write_snapshot."""
|
||||||
|
from lib.pre_stop import build_snapshot, write_snapshot
|
||||||
|
|
||||||
|
hb = _MockHeartbeat()
|
||||||
|
adapter_state = {
|
||||||
|
"session_id": "sess_secret",
|
||||||
|
"transcript_lines": [
|
||||||
|
"Authorization: Bearer abc123.def456.ghi789",
|
||||||
|
"token_used: Bearer xyz.token.placeholder",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
snap = build_snapshot(hb, adapter_state)
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
ok = write_snapshot(snap, path=path)
|
||||||
|
assert ok, "write_snapshot should return True on success"
|
||||||
|
|
||||||
|
with open(path) as f:
|
||||||
|
loaded = json.load(f)
|
||||||
|
|
||||||
|
lines = loaded["adapter"]["transcript_lines"]
|
||||||
|
assert not any("Bearer abc" in l for l in lines), "Bearer token should be scrubbed"
|
||||||
|
assert any("REDACTED" in l for l in lines), "Scrub markers should be present"
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_snapshot_scrub_drops_sandbox_content():
|
||||||
|
"""Sandbox-sourced transcript lines are dropped entirely."""
|
||||||
|
from lib.pre_stop import build_snapshot, write_snapshot
|
||||||
|
|
||||||
|
hb = _MockHeartbeat()
|
||||||
|
adapter_state = {
|
||||||
|
"session_lines": [
|
||||||
|
"source=sandbox echo hello",
|
||||||
|
"Normal message",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
snap = build_snapshot(hb, adapter_state)
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
write_snapshot(snap, path=path)
|
||||||
|
with open(path) as f:
|
||||||
|
loaded = json.load(f)
|
||||||
|
# scrub_snapshot drops sandbox entries from lists
|
||||||
|
lines = loaded["adapter"].get("session_lines", [])
|
||||||
|
assert not any("sandbox" in l for l in lines), "Sandbox lines should be dropped"
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_snapshot_missing_returns_none():
|
||||||
|
"""read_snapshot returns None when the file doesn't exist."""
|
||||||
|
from lib.pre_stop import read_snapshot
|
||||||
|
|
||||||
|
result = read_snapshot(path="/nonexistent/path/12345.json")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_snapshot_returns_data():
|
||||||
|
"""read_snapshot returns the parsed JSON when the file exists."""
|
||||||
|
from lib.pre_stop import read_snapshot
|
||||||
|
|
||||||
|
data = {"workspace_id": "test-ws", "current_task": "test"}
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".json", delete=False, mode="w") as f:
|
||||||
|
json.dump(data, f)
|
||||||
|
path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = read_snapshot(path=path)
|
||||||
|
assert result == data
|
||||||
|
assert result["workspace_id"] == "test-ws"
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_snapshot_removes_file():
|
||||||
|
"""delete_snapshot removes the file and is idempotent on missing file."""
|
||||||
|
from lib.pre_stop import delete_snapshot
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
|
||||||
|
delete_snapshot(path=path)
|
||||||
|
assert not os.path.exists(path), "File should be removed"
|
||||||
|
|
||||||
|
# Idempotent: no error if already absent
|
||||||
|
delete_snapshot(path=path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_snapshot_returns_false_on_error(monkeypatch):
|
||||||
|
"""write_snapshot returns False on I/O errors and logs a warning."""
|
||||||
|
from lib.pre_stop import build_snapshot, write_snapshot
|
||||||
|
|
||||||
|
hb = _MockHeartbeat()
|
||||||
|
|
||||||
|
# Make the parent dir unreadable to trigger an error.
|
||||||
|
# We can't easily make /nonexistent readonly, so we mock open().
|
||||||
|
import unittest.mock as mock
|
||||||
|
|
||||||
|
snap = build_snapshot(hb, {})
|
||||||
|
|
||||||
|
with mock.patch("builtins.open", side_effect=OSError("disk full")):
|
||||||
|
ok = write_snapshot(snap, path="/tmp/fake.json")
|
||||||
|
assert ok is False, "write_snapshot should return False on error"
|
||||||
|
|
||||||
|
|
||||||
|
def test_restore_state_stores_on_adapter():
|
||||||
|
"""restore_state stores snapshot fields as adapter attributes."""
|
||||||
|
from adapter_base import BaseAdapter
|
||||||
|
|
||||||
|
class DummyAdapter(BaseAdapter):
|
||||||
|
def name(self): return "dummy"
|
||||||
|
def display_name(self): return "Dummy"
|
||||||
|
def description(self): return "dummy"
|
||||||
|
async def setup(self, cfg): pass
|
||||||
|
async def create_executor(self, cfg): pass
|
||||||
|
|
||||||
|
adapter = DummyAdapter()
|
||||||
|
snap = {
|
||||||
|
"session_id": "sess_restored_123",
|
||||||
|
"transcript_lines": ["line1", "line2"],
|
||||||
|
"current_task": "Old task",
|
||||||
|
}
|
||||||
|
adapter.restore_state(snap)
|
||||||
|
|
||||||
|
assert adapter._snapshot_session_id == "sess_restored_123"
|
||||||
|
assert adapter._snapshot_transcript == ["line1", "line2"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_pre_stop_state_default_returns_empty():
|
||||||
|
"""Default pre_stop_state (BaseAdapter) returns an empty dict."""
|
||||||
|
from adapter_base import BaseAdapter
|
||||||
|
|
||||||
|
class DummyAdapter(BaseAdapter):
|
||||||
|
def name(self): return "dummy"
|
||||||
|
def display_name(self): return "Dummy"
|
||||||
|
def description(self): return "dummy"
|
||||||
|
async def setup(self, cfg): pass
|
||||||
|
async def create_executor(self, cfg): pass
|
||||||
|
|
||||||
|
adapter = DummyAdapter()
|
||||||
|
state = adapter.pre_stop_state()
|
||||||
|
assert state == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_pre_stop_state_with_executor_session_id():
|
||||||
|
"""pre_stop_state captures _executor._session_id when available."""
|
||||||
|
from adapter_base import BaseAdapter
|
||||||
|
|
||||||
|
class DummyExecutor:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class DummyAdapter(BaseAdapter):
|
||||||
|
def name(self): return "dummy"
|
||||||
|
def display_name(self): return "Dummy"
|
||||||
|
def description(self): return "dummy"
|
||||||
|
async def setup(self, cfg): pass
|
||||||
|
async def create_executor(self, cfg):
|
||||||
|
# Simulate storing the executor so pre_stop_state can find it
|
||||||
|
self._executor = DummyExecutor()
|
||||||
|
self._executor._session_id = "sess_from_executor_456"
|
||||||
|
return self._executor
|
||||||
|
|
||||||
|
adapter = DummyAdapter()
|
||||||
|
# Simulate executor was already created
|
||||||
|
adapter._executor = DummyExecutor()
|
||||||
|
adapter._executor._session_id = "sess_from_executor_456"
|
||||||
|
|
||||||
|
state = adapter.pre_stop_state()
|
||||||
|
assert state["session_id"] == "sess_from_executor_456"
|
||||||
|
|
||||||
|
|
||||||
|
def test_pre_stop_state_transcript_included():
|
||||||
|
"""pre_stop_state includes transcript_lines when transcript is supported."""
|
||||||
|
from adapter_base import BaseAdapter
|
||||||
|
|
||||||
|
class DummyExecutor:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class DummyAdapter(BaseAdapter):
|
||||||
|
def name(self): return "dummy"
|
||||||
|
def display_name(self): return "Dummy"
|
||||||
|
def description(self): return "dummy"
|
||||||
|
async def setup(self, cfg): pass
|
||||||
|
async def create_executor(self, cfg):
|
||||||
|
self._executor = DummyExecutor()
|
||||||
|
return self._executor
|
||||||
|
|
||||||
|
def transcript_lines(self, since=0, limit=100):
|
||||||
|
return {
|
||||||
|
"supported": True,
|
||||||
|
"lines": ["User: test", "Agent: response"],
|
||||||
|
"cursor": 2,
|
||||||
|
"more": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter = DummyAdapter()
|
||||||
|
adapter._executor = DummyExecutor()
|
||||||
|
state = adapter.pre_stop_state()
|
||||||
|
|
||||||
|
assert "transcript_lines" in state
|
||||||
|
assert state["transcript_lines"] == ["User: test", "Agent: response"]
|
||||||
Loading…
Reference in New Issue
Block a user