feat(workspace): pre-stop serialization for pause/resume (closes #1386)
Add a pre-stop hook that captures agent state before container exit and writes a scrubbed snapshot to /configs/.agent_snapshot.json. On restart, the snapshot is loaded and the adapter's restore_state() is called before the A2A server starts. - New lib/pre_stop.py: build_snapshot / write_snapshot / read_snapshot / delete_snapshot + _scrub_value deep-scrubber (uses lib.snapshot_scrub to redact API keys, tokens, and sandbox output before persisting) - BaseAdapter.pre_stop_state(): captures _executor._session_id and recent transcript_lines; overridden by adapters with richer in-memory state - BaseAdapter.restore_state(): stores snapshot fields as adapter attrs for create_executor() to pick up - main.py: calls pre_stop serialization in finally block (after server serves) and restore_state() after adapter setup, before server starts - Added 12 unit tests covering scrub, read/write, adapter integration Co-authored-by: Molecule AI Infra-Runtime-BE <infra-runtime-be@agents.moleculesai.app> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
7dd66c91e0
commit
4675402e58
@ -132,6 +132,77 @@ class BaseAdapter(ABC):
|
||||
"source": None,
|
||||
}
|
||||
|
||||
def pre_stop_state(self) -> dict:
|
||||
"""Capture in-memory state for pause/resume serialization.
|
||||
|
||||
Called by main.py's shutdown handler just before the container exits.
|
||||
Returns a dict that will be scrubbed (via lib.snapshot_scrub) and
|
||||
written to /configs/.agent_snapshot.json.
|
||||
|
||||
Default implementation:
|
||||
1. Attempts to read ``self._executor._session_id`` (set by
|
||||
create_executor) and includes it as ``session_id``.
|
||||
2. Includes up to 200 recent transcript lines via transcript_lines().
|
||||
|
||||
Override in adapters that hold additional in-memory state that
|
||||
should survive a container stop.
|
||||
|
||||
Returns:
|
||||
A JSON-serializable dict. All string values are scrubbed before
|
||||
persisting, so it is safe to include raw content from the
|
||||
agent's context.
|
||||
"""
|
||||
from lib.pre_stop import MAX_TRANSCRIPT_LINES
|
||||
|
||||
state: dict = {}
|
||||
|
||||
# Session handle — critical for resuming the Claude Code session.
|
||||
executor = getattr(self, "_executor", None)
|
||||
if executor is not None:
|
||||
session_id = getattr(executor, "_session_id", None)
|
||||
if session_id:
|
||||
state["session_id"] = session_id
|
||||
|
||||
# Recent conversation log — captures where the agent left off.
|
||||
# transcript_lines() may be async; call it synchronously if possible,
|
||||
# otherwise let async adapters override pre_stop_state entirely.
|
||||
try:
|
||||
import inspect as _inspect
|
||||
transcript_fn = self.transcript_lines
|
||||
if _inspect.iscoroutinefunction(transcript_fn):
|
||||
# Async adapter — override pre_stop_state() for transcript access.
|
||||
# The base impl still captures session_id above.
|
||||
pass
|
||||
else:
|
||||
transcript = transcript_fn(since=0, limit=MAX_TRANSCRIPT_LINES)
|
||||
if transcript.get("supported"):
|
||||
state["transcript_lines"] = transcript.get("lines", [])
|
||||
except Exception:
|
||||
# Best-effort: never let transcript capture failure block serialization.
|
||||
pass
|
||||
|
||||
return state
|
||||
|
||||
def restore_state(self, snapshot: dict) -> None:
|
||||
"""Restore in-memory state from a pause/resume snapshot.
|
||||
|
||||
Called by main.py on first boot when /configs/.agent_snapshot.json
|
||||
exists. Gives the adapter a chance to restore session handles,
|
||||
conversation context, or any other in-memory state before the A2A
|
||||
server starts accepting requests.
|
||||
|
||||
Default implementation stores ``snapshot["session_id"]`` and
|
||||
``snapshot["transcript_lines"]`` as ``self._snapshot_session_id``
|
||||
and ``self._snapshot_transcript`` so that ``create_executor()`` or
|
||||
the executor itself can pick them up.
|
||||
|
||||
Args:
|
||||
snapshot: The scrubbed snapshot dict previously written by
|
||||
pre_stop_state(). All secrets have already been redacted.
|
||||
"""
|
||||
self._snapshot_session_id: str | None = snapshot.get("session_id")
|
||||
self._snapshot_transcript: list | None = snapshot.get("transcript_lines")
|
||||
|
||||
def register_subagent_hook(self, name: str, spec: dict) -> None:
|
||||
"""Default no-op. DeepAgents overrides to register a sub-agent."""
|
||||
return None
|
||||
@ -305,5 +376,9 @@ class BaseAdapter(ABC):
|
||||
async def create_executor(self, config: AdapterConfig) -> AgentExecutor:
|
||||
"""Create and return an AgentExecutor ready for A2A integration.
|
||||
The returned executor's execute() method will be called by the
|
||||
A2A server's DefaultRequestHandler."""
|
||||
A2A server's DefaultRequestHandler.
|
||||
|
||||
Subclasses should also store the returned executor as ``self._executor``
|
||||
so ``pre_stop_state()`` can access it for serialization.
|
||||
"""
|
||||
... # pragma: no cover
|
||||
|
||||
192
workspace/lib/pre_stop.py
Normal file
192
workspace/lib/pre_stop.py
Normal file
@ -0,0 +1,192 @@
|
||||
"""Pre-stop serialization for pause/resume — GH#1391.
|
||||
|
||||
Captures the agent's in-memory state just before the container exits so it
|
||||
survives intentional pause and unplanned restart. All content is scrubbed
|
||||
with lib.snapshot_scrub before being written to disk so that a snapshot blob
|
||||
obtained by an attacker cannot recover API keys, tokens, or arbitrary sandbox
|
||||
output (GH#823).
|
||||
|
||||
State captured
|
||||
--------------
|
||||
- ``workspace_id`` — identity for cross-container restore
|
||||
- ``current_task`` — active task label from heartbeat (what the canvas sees)
|
||||
- ``active_tasks`` — task count
|
||||
- ``session_id`` — SDK session handle (Claude Code); key for full session
|
||||
- ``transcript_lines`` — recent session log lines from the adapter
|
||||
- ``uptime_seconds`` — how long this container has been running
|
||||
- ``timestamp`` — when the snapshot was taken (ISO-8601)
|
||||
|
||||
Scrubbing
|
||||
---------
|
||||
Every text field passes through scrub_snapshot before being written.
|
||||
Sandbox-sourced content (tool=run_code, source=sandbox, [sandbox_output]) is
|
||||
dropped wholesale. Secrets matching the pattern library are replaced with
|
||||
[REDACTED:TYPE] markers.
|
||||
|
||||
Storage
|
||||
-------
|
||||
Snapshots are written to /configs/.agent_snapshot.json by default. The
|
||||
config volume survives container restarts so the file is durable. The path
|
||||
is also overridable via ``AGENT_SNAPSHOT_PATH`` for testing or custom layouts.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from .snapshot_scrub import scrub_snapshot
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from heartbeat import HeartbeatLoop
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default snapshot path — on the config volume, survives container restarts.
|
||||
DEFAULT_SNAPSHOT_PATH = os.environ.get(
|
||||
"AGENT_SNAPSHOT_PATH",
|
||||
"/configs/.agent_snapshot.json",
|
||||
)
|
||||
|
||||
# How many transcript lines to capture in the snapshot (recent window).
|
||||
MAX_TRANSCRIPT_LINES = 200
|
||||
|
||||
|
||||
def build_snapshot(
|
||||
heartbeat: "HeartbeatLoop | None",
|
||||
adapter_state: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Build a raw snapshot dict from live workspace state.
|
||||
|
||||
Args:
|
||||
heartbeat: HeartbeatLoop instance; provides current_task, session_id, etc.
|
||||
adapter_state: Arbitrary state dict from the adapter's pre_stop_state() hook.
|
||||
Keys are free-form; all string values in nested dicts/lists are
|
||||
scrubbed before writing.
|
||||
|
||||
Returns a raw (not yet scrubbed) snapshot dict.
|
||||
"""
|
||||
import time
|
||||
|
||||
raw: dict[str, Any] = {
|
||||
"workspace_id": os.environ.get("WORKSPACE_ID", "unknown"),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
# Defaults — heartbeat block below overwrites these when available:
|
||||
"current_task": "",
|
||||
"active_tasks": 0,
|
||||
}
|
||||
|
||||
if heartbeat is not None:
|
||||
raw["current_task"] = heartbeat.current_task or ""
|
||||
raw["active_tasks"] = heartbeat.active_tasks
|
||||
if hasattr(heartbeat, "start_time"):
|
||||
raw["uptime_seconds"] = int(time.time() - heartbeat.start_time)
|
||||
# session_id lives in the adapter but we also accept it via heartbeat
|
||||
# for convenience (avoids requiring every adapter to pass it separately).
|
||||
if not adapter_state.get("session_id"):
|
||||
raw["session_id"] = getattr(heartbeat, "_session_id", None) or ""
|
||||
|
||||
# Adapter-supplied state (conversation history, reasoning traces, etc.)
|
||||
raw["adapter"] = adapter_state
|
||||
|
||||
return raw
|
||||
|
||||
|
||||
def _scrub_value(value: Any) -> Any:
|
||||
"""Recursively scrub all secret patterns from a value.
|
||||
|
||||
- Strings: scrub_content() replaces patterns with [REDACTED:TYPE].
|
||||
- Dicts: return a new dict with all values scrubbed recursively.
|
||||
- Lists: drop entries that are sandbox content; scrub remaining items.
|
||||
- Other: pass through unchanged.
|
||||
"""
|
||||
from .snapshot_scrub import is_sandbox_content, scrub_content
|
||||
|
||||
if isinstance(value, str):
|
||||
return scrub_content(value)
|
||||
if isinstance(value, dict):
|
||||
return {k: _scrub_value(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
result = []
|
||||
for item in value:
|
||||
if isinstance(item, str) and is_sandbox_content(item):
|
||||
continue # Drop sandbox entries wholesale
|
||||
result.append(_scrub_value(item))
|
||||
return result
|
||||
return value
|
||||
|
||||
|
||||
def write_snapshot(
|
||||
snapshot: dict[str, Any],
|
||||
path: str | None = None,
|
||||
) -> bool:
|
||||
"""Scrub and write a snapshot to disk.
|
||||
|
||||
Args:
|
||||
snapshot: Raw snapshot dict from build_snapshot().
|
||||
path: Target file path (default: DEFAULT_SNAPSHOT_PATH).
|
||||
|
||||
Returns:
|
||||
True if the snapshot was written successfully; False on any error.
|
||||
Errors are logged but never raise — pre-stop serialization must be
|
||||
best-effort to avoid blocking shutdown.
|
||||
"""
|
||||
target = path or DEFAULT_SNAPSHOT_PATH
|
||||
|
||||
try:
|
||||
# Deep-scrub every string value in the snapshot to remove API keys,
|
||||
# tokens, and arbitrary sandbox output before writing to disk.
|
||||
scrubbed = _scrub_value(snapshot)
|
||||
|
||||
# Ensure parent directory exists.
|
||||
parent = os.path.dirname(target)
|
||||
if parent:
|
||||
os.makedirs(parent, exist_ok=True)
|
||||
|
||||
with open(target, "w") as f:
|
||||
json.dump(scrubbed, f, indent=2, default=str)
|
||||
|
||||
logger.info(
|
||||
"Pre-stop snapshot written: %s (workspace=%s, task=%r, lines=%d)",
|
||||
target,
|
||||
scrubbed.get("workspace_id", "?"),
|
||||
scrubbed.get("current_task", ""),
|
||||
len(scrubbed.get("adapter", {}).get("transcript_lines", [])),
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("Pre-stop snapshot write failed (%s): %s", target, exc)
|
||||
return False
|
||||
|
||||
|
||||
def read_snapshot(
|
||||
path: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Read and return a previously-written snapshot, or None if absent/invalid."""
|
||||
target = path or DEFAULT_SNAPSHOT_PATH
|
||||
|
||||
if not os.path.exists(target):
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(target) as f:
|
||||
return json.load(f)
|
||||
except Exception as exc:
|
||||
logger.debug("Snapshot read failed (%s): %s", target, exc)
|
||||
return None
|
||||
|
||||
|
||||
def delete_snapshot(path: str | None = None) -> None:
|
||||
"""Remove a snapshot file. Idempotent — no error if absent."""
|
||||
target = path or DEFAULT_SNAPSHOT_PATH
|
||||
try:
|
||||
os.remove(target)
|
||||
logger.debug("Snapshot deleted: %s", target)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.warning("Snapshot delete failed (%s): %s", target, exc)
|
||||
@ -124,6 +124,21 @@ async def main(): # pragma: no cover
|
||||
try:
|
||||
await adapter.setup(adapter_config)
|
||||
executor = await adapter.create_executor(adapter_config)
|
||||
|
||||
# 5b. Restore from pre-stop snapshot if one exists (GH#1391).
|
||||
# The snapshot is scrubbed before being written, so secrets are
|
||||
# already redacted — restore_state must not re-expose them.
|
||||
from lib.pre_stop import read_snapshot
|
||||
snapshot = read_snapshot()
|
||||
if snapshot:
|
||||
try:
|
||||
adapter.restore_state(snapshot)
|
||||
print(
|
||||
f"Pre-stop snapshot restored: task={snapshot.get('current_task', '')!r}, "
|
||||
f"uptime={snapshot.get('uptime_seconds', 0)}s"
|
||||
)
|
||||
except Exception as restore_err:
|
||||
print(f"Warning: snapshot restore failed (continuing): {restore_err}")
|
||||
except Exception:
|
||||
# heartbeat hasn't started yet but may have async tasks pending
|
||||
if hasattr(heartbeat, "stop"):
|
||||
@ -543,6 +558,18 @@ async def main(): # pragma: no cover
|
||||
try:
|
||||
await server.serve()
|
||||
finally:
|
||||
# 10d. Pre-stop serialization — GH#1391.
|
||||
# Capture in-memory state before the container exits so it survives
|
||||
# intentional pause and unplanned restart. All content is scrubbed
|
||||
# via lib.snapshot_scrub before being written to the config volume.
|
||||
try:
|
||||
from lib.pre_stop import build_snapshot, write_snapshot
|
||||
adapter_state = adapter.pre_stop_state() if adapter else {}
|
||||
snapshot = build_snapshot(heartbeat, adapter_state)
|
||||
write_snapshot(snapshot)
|
||||
except Exception as pre_stop_err:
|
||||
print(f"Warning: pre-stop serialization failed (continuing): {pre_stop_err}")
|
||||
|
||||
# Cancel initial prompt if still running
|
||||
if initial_prompt_task and not initial_prompt_task.done():
|
||||
initial_prompt_task.cancel()
|
||||
|
||||
270
workspace/tests/test_pre_stop.py
Normal file
270
workspace/tests/test_pre_stop.py
Normal file
@ -0,0 +1,270 @@
|
||||
"""Tests for lib.pre_stop — GH#1391 pre-stop serialization."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _MockHeartbeat:
|
||||
"""Minimal heartbeat for testing — matches heartbeat.HeartbeatLoop shape."""
|
||||
|
||||
def __init__(self):
|
||||
self.current_task = "Implementing feature X"
|
||||
self.active_tasks = 1
|
||||
self.start_time = 1000.0
|
||||
self._session_id = None
|
||||
|
||||
|
||||
class _MockAdapter:
|
||||
"""Minimal adapter that returns known pre_stop_state for testing."""
|
||||
|
||||
def pre_stop_state(self):
|
||||
return {
|
||||
"session_id": "sess_abc123xyz",
|
||||
"transcript_lines": [
|
||||
"User: hello",
|
||||
"Agent: Hi! How can I help?",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_build_snapshot_basic():
|
||||
"""build_snapshot returns workspace_id, timestamp, and heartbeat fields."""
|
||||
from lib.pre_stop import build_snapshot
|
||||
|
||||
hb = _MockHeartbeat()
|
||||
adapter_state = {"session_id": "sess_abc", "transcript_lines": ["line1"]}
|
||||
snap = build_snapshot(hb, adapter_state)
|
||||
|
||||
assert snap["workspace_id"] == os.environ.get("WORKSPACE_ID", "unknown")
|
||||
assert "timestamp" in snap
|
||||
assert snap["current_task"] == "Implementing feature X"
|
||||
assert snap["active_tasks"] == 1
|
||||
assert snap["adapter"] == adapter_state
|
||||
|
||||
|
||||
def test_build_snapshot_none_heartbeat():
|
||||
"""build_snapshot handles None heartbeat gracefully."""
|
||||
from lib.pre_stop import build_snapshot
|
||||
|
||||
snap = build_snapshot(None, {"session_id": "sess_xyz"})
|
||||
assert snap["current_task"] == ""
|
||||
assert snap["active_tasks"] == 0
|
||||
# session_id is NOT promoted to top-level when heartbeat is absent;
|
||||
# it stays nested inside adapter.
|
||||
assert "session_id" not in snap
|
||||
assert snap["adapter"]["session_id"] == "sess_xyz"
|
||||
|
||||
|
||||
def test_build_snapshot_scrubbed_secrets():
|
||||
"""Snapshot content with API keys is scrubbed by write_snapshot."""
|
||||
from lib.pre_stop import build_snapshot, write_snapshot
|
||||
|
||||
hb = _MockHeartbeat()
|
||||
adapter_state = {
|
||||
"session_id": "sess_secret",
|
||||
"transcript_lines": [
|
||||
"Authorization: Bearer abc123.def456.ghi789",
|
||||
"token_used: Bearer xyz.token.placeholder",
|
||||
],
|
||||
}
|
||||
snap = build_snapshot(hb, adapter_state)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as f:
|
||||
path = f.name
|
||||
|
||||
try:
|
||||
ok = write_snapshot(snap, path=path)
|
||||
assert ok, "write_snapshot should return True on success"
|
||||
|
||||
with open(path) as f:
|
||||
loaded = json.load(f)
|
||||
|
||||
lines = loaded["adapter"]["transcript_lines"]
|
||||
assert not any("Bearer abc" in l for l in lines), "Bearer token should be scrubbed"
|
||||
assert any("REDACTED" in l for l in lines), "Scrub markers should be present"
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
|
||||
def test_build_snapshot_scrub_drops_sandbox_content():
|
||||
"""Sandbox-sourced transcript lines are dropped entirely."""
|
||||
from lib.pre_stop import build_snapshot, write_snapshot
|
||||
|
||||
hb = _MockHeartbeat()
|
||||
adapter_state = {
|
||||
"session_lines": [
|
||||
"source=sandbox echo hello",
|
||||
"Normal message",
|
||||
],
|
||||
}
|
||||
snap = build_snapshot(hb, adapter_state)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as f:
|
||||
path = f.name
|
||||
|
||||
try:
|
||||
write_snapshot(snap, path=path)
|
||||
with open(path) as f:
|
||||
loaded = json.load(f)
|
||||
# scrub_snapshot drops sandbox entries from lists
|
||||
lines = loaded["adapter"].get("session_lines", [])
|
||||
assert not any("sandbox" in l for l in lines), "Sandbox lines should be dropped"
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
|
||||
def test_read_snapshot_missing_returns_none():
|
||||
"""read_snapshot returns None when the file doesn't exist."""
|
||||
from lib.pre_stop import read_snapshot
|
||||
|
||||
result = read_snapshot(path="/nonexistent/path/12345.json")
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_read_snapshot_returns_data():
|
||||
"""read_snapshot returns the parsed JSON when the file exists."""
|
||||
from lib.pre_stop import read_snapshot
|
||||
|
||||
data = {"workspace_id": "test-ws", "current_task": "test"}
|
||||
with tempfile.NamedTemporaryFile(suffix=".json", delete=False, mode="w") as f:
|
||||
json.dump(data, f)
|
||||
path = f.name
|
||||
|
||||
try:
|
||||
result = read_snapshot(path=path)
|
||||
assert result == data
|
||||
assert result["workspace_id"] == "test-ws"
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
|
||||
def test_delete_snapshot_removes_file():
|
||||
"""delete_snapshot removes the file and is idempotent on missing file."""
|
||||
from lib.pre_stop import delete_snapshot
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as f:
|
||||
path = f.name
|
||||
|
||||
delete_snapshot(path=path)
|
||||
assert not os.path.exists(path), "File should be removed"
|
||||
|
||||
# Idempotent: no error if already absent
|
||||
delete_snapshot(path=path)
|
||||
|
||||
|
||||
def test_write_snapshot_returns_false_on_error(monkeypatch):
|
||||
"""write_snapshot returns False on I/O errors and logs a warning."""
|
||||
from lib.pre_stop import build_snapshot, write_snapshot
|
||||
|
||||
hb = _MockHeartbeat()
|
||||
|
||||
# Make the parent dir unreadable to trigger an error.
|
||||
# We can't easily make /nonexistent readonly, so we mock open().
|
||||
import unittest.mock as mock
|
||||
|
||||
snap = build_snapshot(hb, {})
|
||||
|
||||
with mock.patch("builtins.open", side_effect=OSError("disk full")):
|
||||
ok = write_snapshot(snap, path="/tmp/fake.json")
|
||||
assert ok is False, "write_snapshot should return False on error"
|
||||
|
||||
|
||||
def test_restore_state_stores_on_adapter():
|
||||
"""restore_state stores snapshot fields as adapter attributes."""
|
||||
from adapter_base import BaseAdapter
|
||||
|
||||
class DummyAdapter(BaseAdapter):
|
||||
def name(self): return "dummy"
|
||||
def display_name(self): return "Dummy"
|
||||
def description(self): return "dummy"
|
||||
async def setup(self, cfg): pass
|
||||
async def create_executor(self, cfg): pass
|
||||
|
||||
adapter = DummyAdapter()
|
||||
snap = {
|
||||
"session_id": "sess_restored_123",
|
||||
"transcript_lines": ["line1", "line2"],
|
||||
"current_task": "Old task",
|
||||
}
|
||||
adapter.restore_state(snap)
|
||||
|
||||
assert adapter._snapshot_session_id == "sess_restored_123"
|
||||
assert adapter._snapshot_transcript == ["line1", "line2"]
|
||||
|
||||
|
||||
def test_pre_stop_state_default_returns_empty():
|
||||
"""Default pre_stop_state (BaseAdapter) returns an empty dict."""
|
||||
from adapter_base import BaseAdapter
|
||||
|
||||
class DummyAdapter(BaseAdapter):
|
||||
def name(self): return "dummy"
|
||||
def display_name(self): return "Dummy"
|
||||
def description(self): return "dummy"
|
||||
async def setup(self, cfg): pass
|
||||
async def create_executor(self, cfg): pass
|
||||
|
||||
adapter = DummyAdapter()
|
||||
state = adapter.pre_stop_state()
|
||||
assert state == {}
|
||||
|
||||
|
||||
def test_pre_stop_state_with_executor_session_id():
|
||||
"""pre_stop_state captures _executor._session_id when available."""
|
||||
from adapter_base import BaseAdapter
|
||||
|
||||
class DummyExecutor:
|
||||
pass
|
||||
|
||||
class DummyAdapter(BaseAdapter):
|
||||
def name(self): return "dummy"
|
||||
def display_name(self): return "Dummy"
|
||||
def description(self): return "dummy"
|
||||
async def setup(self, cfg): pass
|
||||
async def create_executor(self, cfg):
|
||||
# Simulate storing the executor so pre_stop_state can find it
|
||||
self._executor = DummyExecutor()
|
||||
self._executor._session_id = "sess_from_executor_456"
|
||||
return self._executor
|
||||
|
||||
adapter = DummyAdapter()
|
||||
# Simulate executor was already created
|
||||
adapter._executor = DummyExecutor()
|
||||
adapter._executor._session_id = "sess_from_executor_456"
|
||||
|
||||
state = adapter.pre_stop_state()
|
||||
assert state["session_id"] == "sess_from_executor_456"
|
||||
|
||||
|
||||
def test_pre_stop_state_transcript_included():
|
||||
"""pre_stop_state includes transcript_lines when transcript is supported."""
|
||||
from adapter_base import BaseAdapter
|
||||
|
||||
class DummyExecutor:
|
||||
pass
|
||||
|
||||
class DummyAdapter(BaseAdapter):
|
||||
def name(self): return "dummy"
|
||||
def display_name(self): return "Dummy"
|
||||
def description(self): return "dummy"
|
||||
async def setup(self, cfg): pass
|
||||
async def create_executor(self, cfg):
|
||||
self._executor = DummyExecutor()
|
||||
return self._executor
|
||||
|
||||
def transcript_lines(self, since=0, limit=100):
|
||||
return {
|
||||
"supported": True,
|
||||
"lines": ["User: test", "Agent: response"],
|
||||
"cursor": 2,
|
||||
"more": False,
|
||||
}
|
||||
|
||||
adapter = DummyAdapter()
|
||||
adapter._executor = DummyExecutor()
|
||||
state = adapter.pre_stop_state()
|
||||
|
||||
assert "transcript_lines" in state
|
||||
assert state["transcript_lines"] == ["User: test", "Agent: response"]
|
||||
Loading…
Reference in New Issue
Block a user