diff --git a/gateway/run.py b/gateway/run.py index 1bef295c..cfb4af82 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -573,6 +573,7 @@ class GatewayRunner: self._running_agents: Dict[str, Any] = {} self._running_agents_ts: Dict[str, float] = {} # start timestamp per session self._pending_messages: Dict[str, str] = {} # Queued messages during interrupt + self._busy_ack_ts: Dict[str, float] = {} # last busy-ack timestamp per session (debounce) # Cache AIAgent instances per session to preserve prompt caching. # Without this, a new AIAgent is created per message, rebuilding the @@ -1329,26 +1330,100 @@ class GatewayRunner: merge_pending_message_event(adapter._pending_messages, session_key, event) async def _handle_active_session_busy_message(self, event: MessageEvent, session_key: str) -> bool: - if not self._draining: - return False + # --- Draining case (gateway restarting/stopping) --- + if self._draining: + adapter = self.adapters.get(event.source.platform) + if not adapter: + return True + + thread_meta = {"thread_id": event.source.thread_id} if event.source.thread_id else None + if self._queue_during_drain_enabled(): + self._queue_or_replace_pending_event(session_key, event) + message = f"⏳ Gateway {self._status_action_gerund()} — queued for the next turn after it comes back." + else: + message = f"⏳ Gateway is {self._status_action_gerund()} and is not accepting another turn right now." + + await adapter._send_with_retry( + chat_id=event.source.chat_id, + content=message, + reply_to=event.message_id, + metadata=thread_meta, + ) + return True + + # --- Normal busy case (agent actively running a task) --- + # The user sent a message while the agent is working. Interrupt the + # agent immediately so it stops the current tool-calling loop and + # processes the new message. The pending message is stored in the + # adapter so the base adapter picks it up once the interrupted run + # returns. A brief ack tells the user what's happening (debounced + # to avoid spam when they fire multiple messages quickly). adapter = self.adapters.get(event.source.platform) if not adapter: - return True + return False # let default path handle it + + # Store the message so it's processed as the next turn after the + # interrupt causes the current run to exit. + from gateway.platforms.base import merge_pending_message_event + merge_pending_message_event(adapter._pending_messages, session_key, event) + + # Interrupt the running agent — this aborts in-flight tool calls and + # causes the agent loop to exit at the next check point. + running_agent = self._running_agents.get(session_key) + if running_agent and running_agent is not _AGENT_PENDING_SENTINEL: + try: + running_agent.interrupt(event.text) + except Exception: + pass # don't let interrupt failure block the ack + + # Debounce: only send an acknowledgment once every 30 seconds per session + # to avoid spamming the user when they send multiple messages quickly + _BUSY_ACK_COOLDOWN = 30 + now = time.time() + last_ack = self._busy_ack_ts.get(session_key, 0) + if now - last_ack < _BUSY_ACK_COOLDOWN: + return True # interrupt sent, ack already delivered recently + + self._busy_ack_ts[session_key] = now + + # Build a status-rich acknowledgment + status_parts = [] + if running_agent and running_agent is not _AGENT_PENDING_SENTINEL: + try: + summary = running_agent.get_activity_summary() + iteration = summary.get("api_call_count", 0) + max_iter = summary.get("max_iterations", 0) + current_tool = summary.get("current_tool") + start_ts = self._running_agents_ts.get(session_key, 0) + if start_ts: + elapsed_min = int((now - start_ts) / 60) + if elapsed_min > 0: + status_parts.append(f"{elapsed_min} min elapsed") + if max_iter: + status_parts.append(f"iteration {iteration}/{max_iter}") + if current_tool: + status_parts.append(f"running: {current_tool}") + except Exception: + pass + + status_detail = f" ({', '.join(status_parts)})" if status_parts else "" + message = ( + f"⚡ Interrupting current task{status_detail}. " + f"I'll respond to your message shortly." + ) thread_meta = {"thread_id": event.source.thread_id} if event.source.thread_id else None - if self._queue_during_drain_enabled(): - self._queue_or_replace_pending_event(session_key, event) - message = f"⏳ Gateway {self._status_action_gerund()} — queued for the next turn after it comes back." - else: - message = f"⏳ Gateway is {self._status_action_gerund()} and is not accepting another turn right now." + try: + await adapter._send_with_retry( + chat_id=event.source.chat_id, + content=message, + reply_to=event.message_id, + metadata=thread_meta, + ) + except Exception as e: + logger.debug("Failed to send busy-ack: %s", e) - await adapter._send_with_retry( - chat_id=event.source.chat_id, - content=message, - reply_to=event.message_id, - metadata=thread_meta, - ) return True async def _drain_active_agents(self, timeout: float) -> tuple[Dict[str, Any], bool]: @@ -2237,6 +2312,8 @@ class GatewayRunner: self._running_agents.clear() self._pending_messages.clear() self._pending_approvals.clear() + if hasattr(self, '_busy_ack_ts'): + self._busy_ack_ts.clear() self._shutdown_event.set() # Global cleanup: kill any remaining tool subprocesses not tied @@ -2721,6 +2798,7 @@ class GatewayRunner: ) del self._running_agents[_quick_key] self._running_agents_ts.pop(_quick_key, None) + self._busy_ack_ts.pop(_quick_key, None) if _quick_key in self._running_agents: if event.get_command() == "status": diff --git a/tests/gateway/test_busy_session_ack.py b/tests/gateway/test_busy_session_ack.py new file mode 100644 index 00000000..07fe5fa2 --- /dev/null +++ b/tests/gateway/test_busy_session_ack.py @@ -0,0 +1,293 @@ +"""Tests for busy-session acknowledgment when user sends messages during active agent runs. + +Verifies that users get an immediate status response instead of total silence +when the agent is working on a task. See PR fix for the @Lonely__MH report. +""" +import asyncio +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# --------------------------------------------------------------------------- +# Minimal stubs so we can import gateway code without heavy deps +# --------------------------------------------------------------------------- +import sys, types + +_tg = types.ModuleType("telegram") +_tg.constants = types.ModuleType("telegram.constants") +_ct = MagicMock() +_ct.SUPERGROUP = "supergroup" +_ct.GROUP = "group" +_ct.PRIVATE = "private" +_tg.constants.ChatType = _ct +sys.modules.setdefault("telegram", _tg) +sys.modules.setdefault("telegram.constants", _tg.constants) +sys.modules.setdefault("telegram.ext", types.ModuleType("telegram.ext")) + +from gateway.platforms.base import ( + BasePlatformAdapter, + MessageEvent, + MessageType, + SessionSource, + build_session_key, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_event(text="hello", chat_id="123", platform_val="telegram"): + """Build a minimal MessageEvent.""" + source = SessionSource( + platform=MagicMock(value=platform_val), + chat_id=chat_id, + chat_type="private", + user_id="user1", + ) + evt = MessageEvent( + text=text, + message_type=MessageType.TEXT, + source=source, + message_id="msg1", + ) + return evt + + +def _make_runner(): + """Build a minimal GatewayRunner-like object for testing.""" + from gateway.run import GatewayRunner, _AGENT_PENDING_SENTINEL + + runner = object.__new__(GatewayRunner) + runner._running_agents = {} + runner._running_agents_ts = {} + runner._pending_messages = {} + runner._busy_ack_ts = {} + runner._draining = False + runner.adapters = {} + runner.config = MagicMock() + runner.session_store = None + runner.hooks = MagicMock() + runner.hooks.emit = AsyncMock() + return runner, _AGENT_PENDING_SENTINEL + + +def _make_adapter(platform_val="telegram"): + """Build a minimal adapter mock.""" + adapter = MagicMock() + adapter._pending_messages = {} + adapter._send_with_retry = AsyncMock() + adapter.config = MagicMock() + adapter.config.extra = {} + adapter.platform = MagicMock(value=platform_val) + return adapter + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestBusySessionAck: + """User sends a message while agent is running — should get acknowledgment.""" + + @pytest.mark.asyncio + async def test_sends_ack_when_agent_running(self): + """First message during busy session should get a status ack.""" + runner, sentinel = _make_runner() + adapter = _make_adapter() + + event = _make_event(text="Are you working?") + sk = build_session_key(event.source) + + # Simulate running agent + agent = MagicMock() + agent.get_activity_summary.return_value = { + "api_call_count": 21, + "max_iterations": 60, + "current_tool": "terminal", + "last_activity_ts": time.time(), + "last_activity_desc": "terminal", + "seconds_since_activity": 1.0, + } + runner._running_agents[sk] = agent + runner._running_agents_ts[sk] = time.time() - 600 # 10 min ago + runner.adapters[event.source.platform] = adapter + + result = await runner._handle_active_session_busy_message(event, sk) + + assert result is True # handled + # Verify ack was sent + adapter._send_with_retry.assert_called_once() + call_kwargs = adapter._send_with_retry.call_args + content = call_kwargs.kwargs.get("content") or call_kwargs[1].get("content", "") + if not content and call_kwargs.args: + # positional args + content = str(call_kwargs) + assert "Interrupting" in content or "respond" in content + assert "/stop" not in content # no need — we ARE interrupting + + # Verify message was queued in adapter pending + assert sk in adapter._pending_messages + + # Verify agent interrupt was called + agent.interrupt.assert_called_once_with("Are you working?") + + @pytest.mark.asyncio + async def test_debounce_suppresses_rapid_acks(self): + """Second message within 30s should NOT send another ack.""" + runner, sentinel = _make_runner() + adapter = _make_adapter() + + event1 = _make_event(text="hello?") + # Reuse the same source so platform mock matches + event2 = MessageEvent( + text="still there?", + message_type=MessageType.TEXT, + source=event1.source, + message_id="msg2", + ) + sk = build_session_key(event1.source) + + agent = MagicMock() + agent.get_activity_summary.return_value = { + "api_call_count": 5, + "max_iterations": 60, + "current_tool": None, + "last_activity_ts": time.time(), + "last_activity_desc": "api_call", + "seconds_since_activity": 0.5, + } + runner._running_agents[sk] = agent + runner._running_agents_ts[sk] = time.time() - 60 + runner.adapters[event1.source.platform] = adapter + + # First message — should get ack + result1 = await runner._handle_active_session_busy_message(event1, sk) + assert result1 is True + assert adapter._send_with_retry.call_count == 1 + + # Second message within cooldown — should be queued but no ack + result2 = await runner._handle_active_session_busy_message(event2, sk) + assert result2 is True + assert adapter._send_with_retry.call_count == 1 # still 1, no new ack + + # But interrupt should still be called for both + assert agent.interrupt.call_count == 2 + + @pytest.mark.asyncio + async def test_ack_after_cooldown_expires(self): + """After 30s cooldown, a new message should send a fresh ack.""" + runner, sentinel = _make_runner() + adapter = _make_adapter() + + event = _make_event(text="hello?") + sk = build_session_key(event.source) + + agent = MagicMock() + agent.get_activity_summary.return_value = { + "api_call_count": 10, + "max_iterations": 60, + "current_tool": "web_search", + "last_activity_ts": time.time(), + "last_activity_desc": "tool", + "seconds_since_activity": 0.5, + } + runner._running_agents[sk] = agent + runner._running_agents_ts[sk] = time.time() - 120 + runner.adapters[event.source.platform] = adapter + + # First ack + await runner._handle_active_session_busy_message(event, sk) + assert adapter._send_with_retry.call_count == 1 + + # Fake that cooldown expired + runner._busy_ack_ts[sk] = time.time() - 31 + + # Second ack should go through + await runner._handle_active_session_busy_message(event, sk) + assert adapter._send_with_retry.call_count == 2 + + @pytest.mark.asyncio + async def test_includes_status_detail(self): + """Ack message should include iteration and tool info when available.""" + runner, sentinel = _make_runner() + adapter = _make_adapter() + + event = _make_event(text="yo") + sk = build_session_key(event.source) + + agent = MagicMock() + agent.get_activity_summary.return_value = { + "api_call_count": 21, + "max_iterations": 60, + "current_tool": "terminal", + "last_activity_ts": time.time(), + "last_activity_desc": "terminal", + "seconds_since_activity": 0.5, + } + runner._running_agents[sk] = agent + runner._running_agents_ts[sk] = time.time() - 600 # 10 min + runner.adapters[event.source.platform] = adapter + + await runner._handle_active_session_busy_message(event, sk) + + call_kwargs = adapter._send_with_retry.call_args + content = call_kwargs.kwargs.get("content", "") + assert "21/60" in content # iteration + assert "terminal" in content # current tool + assert "10 min" in content # elapsed + + @pytest.mark.asyncio + async def test_draining_still_works(self): + """Draining case should still produce the drain-specific message.""" + runner, sentinel = _make_runner() + runner._draining = True + adapter = _make_adapter() + + event = _make_event(text="hello") + sk = build_session_key(event.source) + runner.adapters[event.source.platform] = adapter + + # Mock the drain-specific methods + runner._queue_during_drain_enabled = lambda: False + runner._status_action_gerund = lambda: "restarting" + + result = await runner._handle_active_session_busy_message(event, sk) + assert result is True + + call_kwargs = adapter._send_with_retry.call_args + content = call_kwargs.kwargs.get("content", "") + assert "restarting" in content + + @pytest.mark.asyncio + async def test_pending_sentinel_no_interrupt(self): + """When agent is PENDING_SENTINEL, don't call interrupt (it has no method).""" + runner, sentinel = _make_runner() + adapter = _make_adapter() + + event = _make_event(text="hey") + sk = build_session_key(event.source) + + runner._running_agents[sk] = sentinel + runner._running_agents_ts[sk] = time.time() + runner.adapters[event.source.platform] = adapter + + result = await runner._handle_active_session_busy_message(event, sk) + assert result is True + # Should still send ack + adapter._send_with_retry.assert_called_once() + + @pytest.mark.asyncio + async def test_no_adapter_falls_through(self): + """If adapter is missing, return False so default path handles it.""" + runner, sentinel = _make_runner() + + event = _make_event(text="hello") + sk = build_session_key(event.source) + + # No adapter registered + runner._running_agents[sk] = MagicMock() + + result = await runner._handle_active_session_busy_message(event, sk) + assert result is False # not handled, let default path try