feat(sdk): add stop_event parameter to run_heartbeat_loop and run_agent_loop

Resolves KI-009. Both loops now accept a threading.Event that, when set,
causes a clean exit with return value "stopped" at the start of the next
iteration. The check is ordered before max_iterations so a signal always wins.

New tests:
- test_run_loop_exits_on_stop_event: event set before loop, 0 iterations
- test_run_loop_respects_stop_event_between_iterations: event set mid-run
- test_run_agent_loop_exits_on_stop_event: same for run_agent_loop

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
parent 6c94ceaeee
commit 6a306f310d
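
The check ordering described in the commit message can be illustrated with a stand-in loop. This is a minimal sketch, not the SDK's implementation: `run_loop`, its `tick` callback, and its `interval` parameter are hypothetical names introduced here. Only the two exit checks and their return strings mirror the diff.

```python
import threading
import time
from typing import Callable, Optional


def run_loop(
    tick: Callable[[], None],
    max_iterations: Optional[int] = None,
    stop_event: Optional[threading.Event] = None,
    interval: float = 0.0,
) -> str:
    """Hypothetical stand-in for the SDK loops; mirrors only the exit checks."""
    i = 0
    while True:
        # The stop check comes first, so a set event wins over max_iterations.
        if stop_event is not None and stop_event.is_set():
            return "stopped"
        if max_iterations is not None and i >= max_iterations:
            return "max_iterations"
        i += 1
        tick()  # stands in for one heartbeat + state poll
        time.sleep(interval)


calls = []
stop = threading.Event()
stop.set()  # set before the loop starts

# Both exit conditions hold at once (event set, limit of 0 already reached).
# The stop check runs first, so tick is never called.
result_both = run_loop(lambda: calls.append(1), max_iterations=0, stop_event=stop)

# Without an event, the loop falls back to the max_iterations exit.
result_limit = run_loop(lambda: calls.append(1), max_iterations=3)
```

Because the event is only checked at the top of each iteration, a signal that arrives during the sleep is honoured once the sleep returns, not in the middle of it.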
@@ -258,42 +258,36 @@ def _is_hex(value: str) -> bool:
 
 ## KI-009 — `run_heartbeat_loop()` does not honour external stop signals
 
-**File:** `molecule_agent/client.py` (`RemoteAgentClient.run_heartbeat_loop`)
-**Status:** Identified
+**File:** `molecule_agent/client.py` (`RemoteAgentClient.run_heartbeat_loop`,
+`RemoteAgentClient.run_agent_loop`)
+**Status:** ✅ Resolved (PR: `feat/ki-009-stop-event`)
 **Severity:** Low
 
-### Symptom
-`run_heartbeat_loop()` runs an unbounded `while True` loop with `sleep(heartbeat_interval)`
-between iterations. There is no mechanism for an external caller to signal the loop
-to exit cleanly. If the MCP client that launched the remote agent disconnects (e.g. via
-SSE stream close), the heartbeat loop continues indefinitely until `max_iterations` is
-reached or the process is killed externally.
+### Resolution
+Added `stop_event: threading.Event | None = None` parameter to both
+`run_heartbeat_loop()` and `run_agent_loop()`. When set, the event is checked
+at the start of each loop iteration (before `max_iterations`). When the event
+is set, the loop exits immediately with return value `"stopped"`. The check
+is ordered before `max_iterations` so a signal always wins.
 
-### Impact
-Orphaned heartbeat processes continue consuming platform API quota after the controlling
-MCP client has disconnected. Each iteration sends a `POST /registry/heartbeat` and a
-`GET /workspaces/:id/state` call. Over time this accumulates unnecessary API calls.
-
-### Suggested fix
-Add a `stop_event` parameter to `run_heartbeat_loop()` — a `threading.Event` or
-`asyncio.Event` that, when set, causes the loop to exit cleanly with a `stopped`
-return value:
+Callers achieve graceful shutdown by setting the event from a SIGTERM handler:
 
 ```python
-def run_heartbeat_loop(
-    self,
-    max_iterations: int | None = None,
-    task_supplier: "callable | None" = None,
-    stop_event: threading.Event | None = None,
-) -> str:
-    i = 0
-    while True:
-        if stop_event is not None and stop_event.is_set():
-            return "stopped"
-        if max_iterations is not None and i >= max_iterations:
-            return "max_iterations"
-        # ... rest of loop
+import signal, threading
+from molecule_agent import RemoteAgentClient
+
+stop = threading.Event()
+client = RemoteAgentClient(...)
+
+def sigterm_handler(signum, frame):
+    stop.set()
+
+signal.signal(signal.SIGTERM, sigterm_handler)
+terminal = client.run_heartbeat_loop(stop_event=stop)
+# terminal == "stopped" when killed cleanly
 ```
 
-Callers (MCP client wrappers, shell scripts) can then call `stop_event.set()` on
-SIGTERM/SIGINT to achieve clean shutdown.
+Tests added: `test_run_loop_exits_on_stop_event`,
+`test_run_loop_respects_stop_event_between_iterations` in
+`tests/test_remote_agent.py`; `test_run_agent_loop_exits_on_stop_event`
+in `tests/test_inbound.py`.
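
The documentation change above sets the event from a SIGTERM handler. The same event can also be set from another thread, for example one that notices the controlling client has disconnected. The following is a minimal sketch of that pattern under stated assumptions: `run_loop` is a hypothetical stand-in for `client.run_heartbeat_loop(stop_event=...)`, and the 10 ms interval is arbitrary.

```python
import threading
import time
from typing import Dict, Optional


def run_loop(stop_event: Optional[threading.Event] = None, interval: float = 0.01) -> str:
    """Hypothetical stand-in for the SDK heartbeat loop."""
    while True:
        if stop_event is not None and stop_event.is_set():
            return "stopped"
        # The real loop would send a heartbeat and poll state here.
        time.sleep(interval)


results: Dict[str, str] = {}
stop = threading.Event()


def worker() -> None:
    # Run the loop off the main thread and record its terminal status.
    results["terminal"] = run_loop(stop_event=stop)


t = threading.Thread(target=worker, daemon=True)
t.start()

# Elsewhere, e.g. when the controlling client goes away:
stop.set()
t.join(timeout=5)  # the loop exits at its next iteration boundary
```

`threading.Event.set()` is safe to call from any thread, so the code that detects the disconnect does not need a reference to the loop itself, only to the shared event.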
@@ -27,6 +27,7 @@ import os
 import random
 import subprocess
 import tarfile
+import threading
 import time
 import uuid
 from dataclasses import dataclass, field
@@ -921,6 +922,7 @@ class RemoteAgentClient:
         delivery: "InboundDelivery | None" = None,
         max_iterations: int | None = None,
         task_supplier: "callable | None" = None,
+        stop_event: threading.Event | None = None,
     ) -> str:
         """Combined heartbeat + state-poll + inbound-delivery loop.
 
@@ -946,10 +948,14 @@ class RemoteAgentClient:
             task_supplier: Optional zero-arg callable returning a dict
                 ``{"current_task": str, "active_tasks": int}`` reported on
                 each heartbeat (same contract as :py:meth:`run_heartbeat_loop`).
+            stop_event: Optional :py:class:`threading.Event` that, when set,
+                causes the loop to exit cleanly with return value ``"stopped"``.
+                Call ``stop_event.set()`` from a SIGTERM handler to achieve
+                graceful shutdown. Ignored when ``None``.
 
         Returns:
-            The terminal status: ``"paused"``, ``"removed"``, or
-            ``"max_iterations"``.
+            The terminal status: ``"paused"``, ``"removed"``,
+            ``"max_iterations"``, or ``"stopped"``.
 
         Errors from the activity poll, heartbeat, or state poll are
         logged and the loop continues — a transient platform hiccup
@@ -964,6 +970,8 @@ class RemoteAgentClient:
         i = 0
         try:
             while True:
+                if stop_event is not None and stop_event.is_set():
+                    return "stopped"
                 if max_iterations is not None and i >= max_iterations:
                     return "max_iterations"
                 i += 1
@@ -1224,10 +1232,11 @@ class RemoteAgentClient:
         self,
         max_iterations: int | None = None,
         task_supplier: "callable | None" = None,
+        stop_event: threading.Event | None = None,
     ) -> str:
         """Drive heartbeat + state-poll on a timer. Returns the terminal
-        status when the loop exits (``"paused"``, ``"removed"``, or
-        ``"max_iterations"``).
+        status when the loop exits (``"paused"``, ``"removed"``,
+        ``"max_iterations"``, or ``"stopped"``).
 
         Args:
             max_iterations: Stop after N loop iterations. None = run until
@@ -1236,6 +1245,10 @@ class RemoteAgentClient:
             task_supplier: Optional zero-arg callable returning a dict
                 ``{"current_task": str, "active_tasks": int}`` fetched
                 each iteration. Lets the agent report what it's doing.
+            stop_event: Optional :py:class:`threading.Event` that, when set,
+                causes the loop to exit cleanly with return value ``"stopped"``.
+                Call ``stop_event.set()`` from a SIGTERM handler to achieve
+                graceful shutdown. Ignored when ``None``.
 
         The loop sends one heartbeat + one state poll per iteration; the
         next iteration sleeps for ``heartbeat_interval`` seconds. Errors
@@ -1245,6 +1258,8 @@ class RemoteAgentClient:
         """
         i = 0
         while True:
+            if stop_event is not None and stop_event.is_set():
+                return "stopped"
             if max_iterations is not None and i >= max_iterations:
                 return "max_iterations"
             i += 1
@@ -770,3 +770,29 @@ def test_run_agent_loop_swallows_task_supplier_exception(
     hb_kwargs = client.heartbeat.call_args.kwargs
     assert hb_kwargs["current_task"] == ""
     assert hb_kwargs["active_tasks"] == 0
+
+
+# ---------------------------------------------------------------------------
+# run_agent_loop — stop_event (KI-009)
+# ---------------------------------------------------------------------------
+
+
+def test_run_agent_loop_exits_on_stop_event(client: RemoteAgentClient, monkeypatch):
+    """stop_event.set() before calling the loop causes immediate 'stopped' exit."""
+    import threading
+    import molecule_agent.client as mod
+    monkeypatch.setattr(mod.time, "sleep", lambda s: None)
+
+    client.save_token("t")
+    client.heartbeat = MagicMock()  # avoid actual HTTP calls
+    client.poll_state = MagicMock(return_value=None)
+
+    stop = threading.Event()
+    stop.set()  # signal stop BEFORE entering the loop
+    terminal = client.run_agent_loop(
+        lambda *_: None, max_iterations=999, stop_event=stop
+    )
+
+    assert terminal == "stopped"
+    # No heartbeat attempted — stop_event fired before the first iteration
+    assert client.heartbeat.call_count == 0
@@ -334,6 +334,66 @@ def test_run_loop_task_supplier_reported(client: RemoteAgentClient, monkeypatch)
     assert body["active_tasks"] == 1
 
 
+# ---------------------------------------------------------------------------
+# run_heartbeat_loop — stop_event (KI-009)
+# ---------------------------------------------------------------------------
+
+
+def test_run_loop_exits_on_stop_event(client: RemoteAgentClient, monkeypatch):
+    """stop_event.set() before calling the loop causes immediate 'stopped' exit,
+    before the first heartbeat is sent."""
+    import threading
+    import molecule_agent.client as mod
+    monkeypatch.setattr(mod.time, "sleep", lambda s: None)
+
+    client.save_token("t")
+    client._session.post.return_value = FakeResponse(200, {"status": "ok"})
+    client._session.get.return_value = FakeResponse(
+        200, {"status": "online", "paused": False, "deleted": False}
+    )
+
+    stop = threading.Event()
+    stop.set()  # signal stop BEFORE entering the loop
+    terminal = client.run_heartbeat_loop(max_iterations=999, stop_event=stop)
+
+    assert terminal == "stopped"
+    # Zero heartbeats sent — stop_event fired before the first iteration body
+    assert client._session.post.call_count == 0
+
+
+def test_run_loop_respects_stop_event_between_iterations(
+    client: RemoteAgentClient, monkeypatch
+):
+    """stop_event.set() mid-run causes exit after the current iteration finishes."""
+    import threading
+    import molecule_agent.client as mod
+
+    # Don't stub sleep — we need the event to fire *between* iterations
+    call_count = [0]
+
+    def fake_sleep(s):
+        call_count[0] += 1
+        if call_count[0] == 2:
+            stop.set()  # signal stop after the second iteration
+        # otherwise no-op so the test doesn't wait
+
+    monkeypatch.setattr(mod.time, "sleep", fake_sleep)
+
+    client.save_token("t")
+    client._session.post.return_value = FakeResponse(200, {"status": "ok"})
+    client._session.get.return_value = FakeResponse(
+        200, {"status": "online", "paused": False, "deleted": False}
+    )
+
+    stop = threading.Event()
+    terminal = client.run_heartbeat_loop(max_iterations=999, stop_event=stop)
+
+    assert terminal == "stopped"
+    # Two full iterations completed before stop was honoured
+    assert client._session.post.call_count == 2
+    assert client._session.get.call_count == 2
+
+
 # ---------------------------------------------------------------------------
 # WorkspaceState dataclass
 # ---------------------------------------------------------------------------
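
The mid-run test in this diff relies on a replacement for `time.sleep` that sets the event on its second call. That pattern can be tried in isolation with a stand-in loop. In this sketch, `heartbeat_loop` and its injectable `sleep` parameter are assumptions made for illustration; the tests in the diff instead patch `mod.time.sleep` with `monkeypatch`.

```python
import threading
from typing import Callable, List


def heartbeat_loop(
    sleep: Callable[[float], None],
    stop_event: threading.Event,
    sent: List[int],
) -> str:
    """Hypothetical loop with an injectable sleep; not the SDK's code."""
    i = 0
    while True:
        if stop_event.is_set():
            return "stopped"
        i += 1
        sent.append(i)  # stands in for one heartbeat + state poll
        sleep(30.0)     # interval value is arbitrary; fake_sleep never waits


stop = threading.Event()
sleep_calls = [0]


def fake_sleep(_seconds: float) -> None:
    # Count calls and fire the stop signal during the second sleep.
    sleep_calls[0] += 1
    if sleep_calls[0] == 2:
        stop.set()


sent: List[int] = []
terminal = heartbeat_loop(fake_sleep, stop, sent)
```

Setting the event from inside the fake sleep keeps the test single-threaded and deterministic: the second iteration completes in full, and the loop exits at the next check.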