diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index 645a642b..1f26ed85 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -1401,7 +1401,13 @@ class BasePlatformAdapter(ABC): return paths, cleaned - async def _keep_typing(self, chat_id: str, interval: float = 2.0, metadata=None) -> None: + async def _keep_typing( + self, + chat_id: str, + interval: float = 2.0, + metadata=None, + stop_event: asyncio.Event | None = None, + ) -> None: """ Continuously send typing indicator until cancelled. @@ -1415,9 +1421,18 @@ class BasePlatformAdapter(ABC): """ try: while True: + if stop_event is not None and stop_event.is_set(): + return if chat_id not in self._typing_paused: await self.send_typing(chat_id, metadata=metadata) - await asyncio.sleep(interval) + if stop_event is None: + await asyncio.sleep(interval) + continue + try: + await asyncio.wait_for(stop_event.wait(), timeout=interval) + except asyncio.TimeoutError: + continue + return except asyncio.CancelledError: pass # Normal cancellation when handler completes finally: @@ -1444,6 +1459,17 @@ class BasePlatformAdapter(ABC): """Resume typing indicator for a chat after approval resolves.""" self._typing_paused.discard(chat_id) + async def interrupt_session_activity(self, session_key: str, chat_id: str) -> None: + """Signal the active session loop to stop and clear typing immediately.""" + if session_key: + interrupt_event = self._active_sessions.get(session_key) + if interrupt_event is not None: + interrupt_event.set() + try: + await self.stop_typing(chat_id) + except Exception: + pass + # ── Processing lifecycle hooks ────────────────────────────────────────── # Subclasses override these to react to message processing events # (e.g. Discord adds 👀/✅/❌ reactions). @@ -1717,7 +1743,13 @@ class BasePlatformAdapter(ABC): # Start continuous typing indicator (refreshes every 2 seconds) _thread_metadata = {"thread_id": event.source.thread_id} if event.source.thread_id else None - typing_task = asyncio.create_task(self._keep_typing(event.source.chat_id, metadata=_thread_metadata)) + typing_task = asyncio.create_task( + self._keep_typing( + event.source.chat_id, + metadata=_thread_metadata, + stop_event=interrupt_event, + ) + ) try: await self._run_processing_hook("on_processing_start", event) diff --git a/gateway/run.py b/gateway/run.py index 37b27232..ed3b6b5e 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -402,6 +402,26 @@ def _dequeue_pending_event(adapter, session_key: str) -> MessageEvent | None: return adapter.get_pending_message(session_key) +_CONTROL_INTERRUPT_MESSAGES = frozenset( + { + "stop requested", + "session reset requested", + "execution timed out (inactivity)", + "sse client disconnected", + "gateway shutting down", + "gateway restarting", + } +) + + +def _is_control_interrupt_message(message: Optional[str]) -> bool: + """Return True when an interrupt message is internal control flow.""" + if not message: + return False + normalized = " ".join(str(message).strip().split()).lower() + return normalized in _CONTROL_INTERRUPT_MESSAGES + + def _check_unavailable_skill(command_name: str) -> str | None: """Check if a command matches a known-but-inactive skill. @@ -630,6 +650,7 @@ class GatewayRunner: self._running_agents_ts: Dict[str, float] = {} # start timestamp per session self._pending_messages: Dict[str, str] = {} # Queued messages during interrupt self._busy_ack_ts: Dict[str, float] = {} # last busy-ack timestamp per session (debounce) + self._session_run_generation: Dict[str, int] = {} # Cache AIAgent instances per session to preserve prompt caching. # Without this, a new AIAgent is created per message, rebuilding the @@ -3064,6 +3085,10 @@ class GatewayRunner: _quick_key[:30], _stale_age, _stale_idle, _raw_stale_timeout, _stale_detail, ) + self._invalidate_session_run_generation( + _quick_key, + reason="stale_running_agent_eviction", + ) self._release_running_agent_state(_quick_key) if _quick_key in self._running_agents: @@ -3091,7 +3116,13 @@ class GatewayRunner: if running_agent and running_agent is not _AGENT_PENDING_SENTINEL: running_agent.interrupt("Stop requested") # Force-clean: remove the session lock regardless of agent state + self._invalidate_session_run_generation( + _quick_key, + reason="stop_command", + ) adapter = self.adapters.get(source.platform) + if adapter and hasattr(adapter, "interrupt_session_activity"): + await adapter.interrupt_session_activity(_quick_key, source.chat_id) if adapter and hasattr(adapter, 'get_pending_message'): adapter.get_pending_message(_quick_key) # consume and discard self._pending_messages.pop(_quick_key, None) @@ -3111,7 +3142,13 @@ class GatewayRunner: if running_agent and running_agent is not _AGENT_PENDING_SENTINEL: running_agent.interrupt("Session reset requested") # Clear any pending messages so the old text doesn't replay + self._invalidate_session_run_generation( + _quick_key, + reason="new_command", + ) adapter = self.adapters.get(source.platform) + if adapter and hasattr(adapter, "interrupt_session_activity"): + await adapter.interrupt_session_activity(_quick_key, source.chat_id) if adapter and hasattr(adapter, 'get_pending_message'): adapter.get_pending_message(_quick_key) # consume and discard self._pending_messages.pop(_quick_key, None) @@ -3598,9 +3635,10 @@ class GatewayRunner: # same session — corrupting the transcript. self._running_agents[_quick_key] = _AGENT_PENDING_SENTINEL self._running_agents_ts[_quick_key] = time.time() + _run_generation = self._begin_session_run_generation(_quick_key) try: - return await self._handle_message_with_agent(event, source, _quick_key) + return await self._handle_message_with_agent(event, source, _quick_key, _run_generation) finally: # If _run_agent replaced the sentinel with a real agent and # then cleaned it up, this is a no-op. If we exited early @@ -3771,7 +3809,7 @@ class GatewayRunner: return message_text - async def _handle_message_with_agent(self, event, source, _quick_key: str): + async def _handle_message_with_agent(self, event, source, _quick_key: str, run_generation: int): """Inner handler that runs under the _running_agents sentinel guard.""" _msg_start_time = time.time() _platform_name = source.platform.value if hasattr(source.platform, "value") else str(source.platform) @@ -4246,6 +4284,7 @@ class GatewayRunner: source=source, session_id=session_entry.session_id, session_key=session_key, + run_generation=run_generation, event_message_id=event.message_id, channel_prompt=event.channel_prompt, ) @@ -4258,6 +4297,17 @@ class GatewayRunner: except Exception: pass + if not self._is_session_run_current(_quick_key, run_generation): + logger.info( + "Discarding stale agent result for %s — generation %d is no longer current", + _quick_key[:20] if _quick_key else "?", + run_generation, + ) + _stale_adapter = self.adapters.get(source.platform) + if _stale_adapter and hasattr(_stale_adapter, "_post_delivery_callbacks"): + _stale_adapter._post_delivery_callbacks.pop(_quick_key, None) + return None + response = agent_result.get("final_response") or "" # Convert the agent's internal "(empty)" sentinel into a @@ -4672,6 +4722,7 @@ class GatewayRunner: # Get existing session key session_key = self._session_key_for_source(source) + self._invalidate_session_run_generation(session_key, reason="session_reset") # Flush memories in the background (fire-and-forget) so the user # gets the "Session reset!" response immediately. @@ -4931,6 +4982,10 @@ class GatewayRunner: agent = self._running_agents.get(session_key) if agent is _AGENT_PENDING_SENTINEL: # Force-clean the sentinel so the session is unlocked. + self._invalidate_session_run_generation(session_key, reason="stop_command_pending") + adapter = self.adapters.get(source.platform) + if adapter and hasattr(adapter, "interrupt_session_activity"): + await adapter.interrupt_session_activity(session_key, source.chat_id) self._release_running_agent_state(session_key) logger.info("STOP (pending) for session %s — sentinel cleared", session_key[:20]) return "⚡ Stopped. The agent hadn't started yet — you can continue this session." @@ -4938,6 +4993,10 @@ class GatewayRunner: agent.interrupt("Stop requested") # Force-clean the session lock so a truly hung agent doesn't # keep it locked forever. + self._invalidate_session_run_generation(session_key, reason="stop_command_handler") + adapter = self.adapters.get(source.platform) + if adapter and hasattr(adapter, "interrupt_session_activity"): + await adapter.interrupt_session_activity(session_key, source.chat_id) self._release_running_agent_state(session_key) return "⚡ Stopped. You can continue this session." else: @@ -8385,6 +8444,43 @@ class GatewayRunner: if hasattr(self, "_busy_ack_ts"): self._busy_ack_ts.pop(session_key, None) + def _begin_session_run_generation(self, session_key: str) -> int: + """Claim a fresh run generation token for ``session_key``. + + Every top-level gateway turn gets a monotonically increasing token. + If a later command like /stop or /new invalidates that token while the + old worker is still unwinding, the late result can be recognized and + dropped instead of bleeding into the fresh session. + """ + if not session_key: + return 0 + generations = self.__dict__.get("_session_run_generation") + if generations is None: + generations = {} + self._session_run_generation = generations + next_generation = int(generations.get(session_key, 0)) + 1 + generations[session_key] = next_generation + return next_generation + + def _invalidate_session_run_generation(self, session_key: str, *, reason: str = "") -> int: + """Invalidate any in-flight run token for ``session_key``.""" + generation = self._begin_session_run_generation(session_key) + if reason: + logger.info( + "Invalidated run generation for %s → %d (%s)", + session_key[:20], + generation, + reason, + ) + return generation + + def _is_session_run_current(self, session_key: str, generation: int) -> bool: + """Return True when ``generation`` is still current for ``session_key``.""" + if not session_key: + return True + generations = self.__dict__.get("_session_run_generation") or {} + return int(generations.get(session_key, 0)) == int(generation) + def _evict_cached_agent(self, session_key: str) -> None: """Remove a cached agent for a session (called on /new, /model, etc).""" _lock = getattr(self, "_agent_cache_lock", None) @@ -8807,6 +8903,7 @@ class GatewayRunner: source: SessionSource, session_id: str, session_key: str = None, + run_generation: Optional[int] = None, _interrupt_depth: int = 0, event_message_id: Optional[str] = None, channel_prompt: Optional[str] = None, @@ -8837,6 +8934,11 @@ class GatewayRunner: from run_agent import AIAgent import queue + + def _run_still_current() -> bool: + if run_generation is None or not session_key: + return True + return self._is_session_run_current(session_key, run_generation) user_config = _load_gateway_config() platform_key = _platform_config_key(source.platform) @@ -8891,7 +8993,7 @@ class GatewayRunner: def progress_callback(event_type: str, tool_name: str = None, preview: str = None, args: dict = None, **kwargs): """Callback invoked by agent on tool lifecycle events.""" - if not progress_queue: + if not progress_queue or not _run_still_current(): return # Only act on tool.started events (ignore tool.completed, reasoning.available, etc.) @@ -8996,6 +9098,14 @@ class GatewayRunner: while True: try: + if not _run_still_current(): + while not progress_queue.empty(): + try: + progress_queue.get_nowait() + except Exception: + break + return + raw = progress_queue.get_nowait() # Handle dedup messages: update last line with repeat counter @@ -9021,6 +9131,9 @@ class GatewayRunner: await asyncio.sleep(_remaining) continue + if not _run_still_current(): + return + if can_edit and progress_msg_id is not None: # Try to edit the existing progress message full_text = "\n".join(progress_lines) @@ -9056,7 +9169,8 @@ class GatewayRunner: # Restore typing indicator await asyncio.sleep(0.3) - await adapter.send_typing(source.chat_id, metadata=_progress_metadata) + if _run_still_current(): + await adapter.send_typing(source.chat_id, metadata=_progress_metadata) except queue.Empty: await asyncio.sleep(0.3) @@ -9100,6 +9214,8 @@ class GatewayRunner: _hooks_ref = self.hooks def _step_callback_sync(iteration: int, prev_tools: list) -> None: + if not _run_still_current(): + return try: # prev_tools may be list[str] or list[dict] with "name"/"result" # keys. Normalise to keep "tool_names" backward-compatible for @@ -9130,7 +9246,7 @@ class GatewayRunner: _status_thread_metadata = {"thread_id": _progress_thread_id} if _progress_thread_id else None def _status_callback_sync(event_type: str, message: str) -> None: - if not _status_adapter: + if not _status_adapter or not _run_still_current(): return try: asyncio.run_coroutine_threadsafe( @@ -9261,12 +9377,16 @@ class GatewayRunner: metadata={"thread_id": _progress_thread_id} if _progress_thread_id else None, ) if _want_stream_deltas: - _stream_delta_cb = _stream_consumer.on_delta + def _stream_delta_cb(text: str) -> None: + if _run_still_current(): + _stream_consumer.on_delta(text) stream_consumer_holder[0] = _stream_consumer except Exception as _sc_err: logger.debug("Could not set up stream consumer: %s", _sc_err) def _interim_assistant_cb(text: str, *, already_streamed: bool = False) -> None: + if not _run_still_current(): + return if _stream_consumer is not None: if already_streamed: _stream_consumer.on_segment_break() @@ -9370,7 +9490,7 @@ class GatewayRunner: _bg_review_pending_lock = threading.Lock() def _deliver_bg_review_message(message: str) -> None: - if not _status_adapter: + if not _status_adapter or not _run_still_current(): return try: asyncio.run_coroutine_threadsafe( @@ -9394,7 +9514,7 @@ class GatewayRunner: # Background review delivery — send "💾 Memory updated" etc. to user def _bg_review_send(message: str) -> None: - if not _status_adapter: + if not _status_adapter or not _run_still_current(): return if not _bg_review_release.is_set(): with _bg_review_pending_lock: @@ -10076,7 +10196,15 @@ class GatewayRunner: if result and adapter and session_key: pending_event = _dequeue_pending_event(adapter, session_key) if result.get("interrupted") and not pending_event and result.get("interrupt_message"): - pending = result.get("interrupt_message") + interrupt_message = result.get("interrupt_message") + if _is_control_interrupt_message(interrupt_message): + logger.info( + "Ignoring control interrupt message for session %s: %s", + session_key[:20] if session_key else "?", + interrupt_message, + ) + else: + pending = interrupt_message elif pending_event: pending = pending_event.text or _build_media_placeholder(pending_event) logger.debug("Processing queued message after agent completion: '%s...'", pending[:40]) @@ -10229,6 +10357,7 @@ class GatewayRunner: source=next_source, session_id=session_id, session_key=session_key, + run_generation=run_generation, _interrupt_depth=_interrupt_depth + 1, event_message_id=next_message_id, channel_prompt=next_channel_prompt, diff --git a/tests/gateway/test_pending_event_none.py b/tests/gateway/test_pending_event_none.py index b2e1356f..e717c882 100644 --- a/tests/gateway/test_pending_event_none.py +++ b/tests/gateway/test_pending_event_none.py @@ -1,13 +1,18 @@ -"""Tests for the pending_event None guard in recursive _run_agent calls. +"""Tests for pending follow-up extraction in recursive _run_agent calls. When pending_event is None (Path B: pending comes from interrupt_message), accessing pending_event.channel_prompt previously raised AttributeError. This verifies the fix: channel_prompt is captured inside the `if pending_event is not None:` block and falls back to None otherwise. + +Also verifies that internal control interrupt reasons like "Stop requested" +do not get recycled into the pending-user-message follow-up path. """ from types import SimpleNamespace +from gateway.run import _is_control_interrupt_message + def _extract_channel_prompt(pending_event): """Reproduce the fixed logic from gateway/run.py. @@ -21,6 +26,15 @@ def _extract_channel_prompt(pending_event): return next_channel_prompt +def _extract_pending_text(interrupted, pending_event, interrupt_message): + """Reproduce the fixed pending-text selection from gateway/run.py.""" + if interrupted and pending_event is None and interrupt_message: + if _is_control_interrupt_message(interrupt_message): + return None + return interrupt_message + return None + + class TestPendingEventNoneChannelPrompt: """Guard against AttributeError when pending_event is None.""" @@ -40,3 +54,19 @@ class TestPendingEventNoneChannelPrompt: event = SimpleNamespace() result = _extract_channel_prompt(event) assert result is None + + +class TestControlInterruptMessages: + """Control interrupt reasons must not become follow-up user input.""" + + def test_stop_requested_is_not_treated_as_pending_user_message(self): + result = _extract_pending_text(True, None, "Stop requested") + assert result is None + + def test_session_reset_requested_is_not_treated_as_pending_user_message(self): + result = _extract_pending_text(True, None, "Session reset requested") + assert result is None + + def test_real_user_interrupt_message_still_requeues(self): + result = _extract_pending_text(True, None, "actually use postgres instead") + assert result == "actually use postgres instead" diff --git a/tests/gateway/test_run_progress_topics.py b/tests/gateway/test_run_progress_topics.py index 4878f2fa..59e9fa04 100644 --- a/tests/gateway/test_run_progress_topics.py +++ b/tests/gateway/test_run_progress_topics.py @@ -51,6 +51,9 @@ class ProgressCaptureAdapter(BasePlatformAdapter): async def send_typing(self, chat_id, metadata=None) -> None: self.typing.append({"chat_id": chat_id, "metadata": metadata}) + async def stop_typing(self, chat_id) -> None: + self.typing.append({"chat_id": chat_id, "metadata": {"stopped": True}}) + async def get_chat_info(self, chat_id: str): return {"id": chat_id} @@ -90,6 +93,40 @@ class LongPreviewAgent: } +class DelayedProgressAgent: + def __init__(self, **kwargs): + self.tool_progress_callback = kwargs.get("tool_progress_callback") + self.tools = [] + + def run_conversation(self, message, conversation_history=None, task_id=None): + self.tool_progress_callback("tool.started", "terminal", "first command", {}) + time.sleep(0.45) + self.tool_progress_callback("tool.started", "terminal", "second command", {}) + time.sleep(0.1) + return { + "final_response": "done", + "messages": [], + "api_calls": 1, + } + + +class DelayedInterimAgent: + def __init__(self, **kwargs): + self.interim_assistant_callback = kwargs.get("interim_assistant_callback") + self.tools = [] + + def run_conversation(self, message, conversation_history=None, task_id=None): + self.interim_assistant_callback("first interim") + time.sleep(0.45) + self.interim_assistant_callback("second interim") + time.sleep(0.1) + return { + "final_response": "done", + "messages": [], + "api_calls": 1, + } + + def _make_runner(adapter): gateway_run = importlib.import_module("gateway.run") GatewayRunner = gateway_run.GatewayRunner @@ -104,6 +141,7 @@ def _make_runner(adapter): runner._fallback_model = None runner._session_db = None runner._running_agents = {} + runner._session_run_generation = {} runner.hooks = SimpleNamespace(loaded_hooks=False) runner.config = SimpleNamespace( thread_sessions_per_user=False, @@ -744,6 +782,154 @@ async def test_base_processing_releases_post_delivery_callback_after_main_send() assert released == [True] +@pytest.mark.asyncio +async def test_run_agent_drops_tool_progress_after_generation_invalidation(monkeypatch, tmp_path): + import yaml + + (tmp_path / "config.yaml").write_text( + yaml.dump({"display": {"tool_progress": "all"}}), + encoding="utf-8", + ) + + fake_dotenv = types.ModuleType("dotenv") + fake_dotenv.load_dotenv = lambda *args, **kwargs: None + monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv) + + fake_run_agent = types.ModuleType("run_agent") + fake_run_agent.AIAgent = DelayedProgressAgent + monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent) + import tools.terminal_tool # noqa: F401 - register terminal tool metadata + + adapter = ProgressCaptureAdapter(platform=Platform.DISCORD) + runner = _make_runner(adapter) + gateway_run = importlib.import_module("gateway.run") + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"}) + + source = SessionSource( + platform=Platform.DISCORD, + chat_id="dm-1", + chat_type="dm", + thread_id=None, + ) + session_key = "agent:main:discord:dm:dm-1" + runner._session_run_generation[session_key] = 1 + + original_send = adapter.send + invalidated = {"done": False} + + async def send_and_invalidate(chat_id, content, reply_to=None, metadata=None): + result = await original_send(chat_id, content, reply_to=reply_to, metadata=metadata) + if "first command" in content and not invalidated["done"]: + invalidated["done"] = True + runner._invalidate_session_run_generation(session_key, reason="test_stop") + return result + + adapter.send = send_and_invalidate + + result = await runner._run_agent( + message="hello", + context_prompt="", + history=[], + source=source, + session_id="sess-progress-stop", + session_key=session_key, + run_generation=1, + ) + + all_progress_text = " ".join(call["content"] for call in adapter.sent) + all_progress_text += " ".join(call["content"] for call in adapter.edits) + assert result["final_response"] == "done" + assert 'first command' in all_progress_text + assert 'second command' not in all_progress_text + + +@pytest.mark.asyncio +async def test_run_agent_drops_interim_commentary_after_generation_invalidation(monkeypatch, tmp_path): + import yaml + + (tmp_path / "config.yaml").write_text( + yaml.dump({"display": {"tool_progress": "off", "interim_assistant_messages": True}}), + encoding="utf-8", + ) + + fake_dotenv = types.ModuleType("dotenv") + fake_dotenv.load_dotenv = lambda *args, **kwargs: None + monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv) + + fake_run_agent = types.ModuleType("run_agent") + fake_run_agent.AIAgent = DelayedInterimAgent + monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent) + + adapter = ProgressCaptureAdapter(platform=Platform.DISCORD) + runner = _make_runner(adapter) + gateway_run = importlib.import_module("gateway.run") + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"}) + + source = SessionSource( + platform=Platform.DISCORD, + chat_id="dm-2", + chat_type="dm", + thread_id=None, + ) + session_key = "agent:main:discord:dm:dm-2" + runner._session_run_generation[session_key] = 1 + + original_send = adapter.send + invalidated = {"done": False} + + async def send_and_invalidate(chat_id, content, reply_to=None, metadata=None): + result = await original_send(chat_id, content, reply_to=reply_to, metadata=metadata) + if content == "first interim" and not invalidated["done"]: + invalidated["done"] = True + runner._invalidate_session_run_generation(session_key, reason="test_stop") + return result + + adapter.send = send_and_invalidate + + result = await runner._run_agent( + message="hello", + context_prompt="", + history=[], + source=source, + session_id="sess-commentary-stop", + session_key=session_key, + run_generation=1, + ) + + sent_texts = [call["content"] for call in adapter.sent] + assert result["final_response"] == "done" + assert "first interim" in sent_texts + assert "second interim" not in sent_texts + + +@pytest.mark.asyncio +async def test_keep_typing_stops_immediately_when_interrupt_event_is_set(): + adapter = ProgressCaptureAdapter(platform=Platform.DISCORD) + stop_event = asyncio.Event() + + task = asyncio.create_task( + adapter._keep_typing( + "dm-typing-stop", + interval=30.0, + stop_event=stop_event, + ) + ) + await asyncio.sleep(0.05) + stop_event.set() + await asyncio.wait_for(task, timeout=0.5) + + normal_typing_calls = [ + call for call in adapter.typing if call.get("metadata") != {"stopped": True} + ] + stopped_calls = [ + call for call in adapter.typing if call.get("metadata") == {"stopped": True} + ] + assert len(normal_typing_calls) == 1 + assert len(stopped_calls) == 1 + + @pytest.mark.asyncio async def test_verbose_mode_does_not_truncate_args_by_default(monkeypatch, tmp_path): """Verbose mode with default tool_preview_length (0) should NOT truncate args. diff --git a/tests/gateway/test_session_race_guard.py b/tests/gateway/test_session_race_guard.py index 8c26abec..fe1ef011 100644 --- a/tests/gateway/test_session_race_guard.py +++ b/tests/gateway/test_session_race_guard.py @@ -24,10 +24,18 @@ class _FakeAdapter: def __init__(self): self._pending_messages = {} + self._active_sessions = {} + self.interrupted_sessions = [] async def send(self, chat_id, text, **kwargs): pass + async def interrupt_session_activity(self, session_key, chat_id): + self.interrupted_sessions.append((session_key, chat_id)) + event = self._active_sessions.get(session_key) + if event is not None: + event.set() + def _make_runner(): runner = object.__new__(GatewayRunner) @@ -37,6 +45,7 @@ def _make_runner(): runner.adapters = {Platform.TELEGRAM: _FakeAdapter()} runner._running_agents = {} runner._running_agents_ts = {} + runner._session_run_generation = {} runner._pending_messages = {} runner._pending_approvals = {} runner._voice_mode = {} @@ -81,7 +90,7 @@ async def test_sentinel_placed_before_agent_setup(): # Patch _handle_message_with_agent to capture state at entry sentinel_was_set = False - async def mock_inner(self_inner, ev, src, qk): + async def mock_inner(self_inner, ev, src, qk, generation): nonlocal sentinel_was_set sentinel_was_set = runner._running_agents.get(qk) is _AGENT_PENDING_SENTINEL return "ok" @@ -105,7 +114,7 @@ async def test_sentinel_cleaned_up_after_handler_returns(): event = _make_event() session_key = build_session_key(event.source) - async def mock_inner(self_inner, ev, src, qk): + async def mock_inner(self_inner, ev, src, qk, generation): return "ok" with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner): @@ -127,7 +136,7 @@ async def test_sentinel_cleaned_up_on_exception(): event = _make_event() session_key = build_session_key(event.source) - async def mock_inner(self_inner, ev, src, qk): + async def mock_inner(self_inner, ev, src, qk, generation): raise RuntimeError("boom") with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner): @@ -154,7 +163,7 @@ async def test_second_message_during_sentinel_queued_not_duplicate(): barrier = asyncio.Event() - async def slow_inner(self_inner, ev, src, qk): + async def slow_inner(self_inner, ev, src, qk, generation): # Simulate slow setup — wait until test tells us to proceed await barrier.wait() return "ok" @@ -333,7 +342,7 @@ async def test_stop_during_sentinel_force_cleans_session(): barrier = asyncio.Event() - async def slow_inner(self_inner, ev, src, qk): + async def slow_inner(self_inner, ev, src, qk, generation): await barrier.wait() return "ok" @@ -381,6 +390,7 @@ async def test_stop_hard_kills_running_agent(): fake_agent = MagicMock() fake_agent.get_activity_summary.return_value = {"seconds_since_activity": 0} runner._running_agents[session_key] = fake_agent + runner.adapters[Platform.TELEGRAM]._active_sessions[session_key] = asyncio.Event() # Send /stop stop_event = _make_event(text="/stop") @@ -393,6 +403,10 @@ async def test_stop_hard_kills_running_agent(): assert session_key not in runner._running_agents, ( "/stop must remove the agent from _running_agents so the session is unlocked" ) + assert runner.adapters[Platform.TELEGRAM].interrupted_sessions == [ + (session_key, "12345") + ] + assert runner.adapters[Platform.TELEGRAM]._active_sessions[session_key].is_set() # Must return a confirmation assert result is not None diff --git a/tests/gateway/test_status_command.py b/tests/gateway/test_status_command.py index c4a64f30..3cdf637d 100644 --- a/tests/gateway/test_status_command.py +++ b/tests/gateway/test_status_command.py @@ -50,6 +50,7 @@ def _make_runner(session_entry: SessionEntry): runner.session_store.rewrite_transcript = MagicMock() runner.session_store.update_session = MagicMock() runner._running_agents = {} + runner._session_run_generation = {} runner._pending_messages = {} runner._pending_approvals = {} runner._session_db = MagicMock() @@ -223,6 +224,52 @@ async def test_handle_message_persists_agent_token_counts(monkeypatch): ) +@pytest.mark.asyncio +async def test_handle_message_discards_stale_result_after_session_invalidation(monkeypatch): + import gateway.run as gateway_run + + session_entry = SessionEntry( + session_key=build_session_key(_make_source()), + session_id="sess-1", + created_at=datetime.now(), + updated_at=datetime.now(), + platform=Platform.TELEGRAM, + chat_type="dm", + ) + runner = _make_runner(session_entry) + runner.session_store.load_transcript.return_value = [{"role": "user", "content": "earlier"}] + session_key = session_entry.session_key + runner.adapters[Platform.TELEGRAM]._post_delivery_callbacks = {session_key: object()} + + async def _stale_result(**kwargs): + runner._invalidate_session_run_generation(kwargs["session_key"], reason="test_stale_result") + return { + "final_response": "late reply", + "messages": [], + "tools": [], + "history_offset": 0, + "last_prompt_tokens": 80, + "input_tokens": 120, + "output_tokens": 45, + "model": "openai/test-model", + } + + runner._run_agent = AsyncMock(side_effect=_stale_result) + + monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"}) + monkeypatch.setattr( + "agent.model_metadata.get_model_context_length", + lambda *_args, **_kwargs: 100000, + ) + + result = await runner._handle_message(_make_event("hello")) + + assert result is None + runner.session_store.append_to_transcript.assert_not_called() + runner.session_store.update_session.assert_not_called() + assert session_key not in runner.adapters[Platform.TELEGRAM]._post_delivery_callbacks + + @pytest.mark.asyncio async def test_status_command_bypasses_active_session_guard():