fix(api-server): collapse tool start/lifecycle into a single SSE event
Address Copilot review on PR #16666:

1. **Duplicate event on every tool start** — both ``tool_progress_callback`` and ``tool_start_callback`` fire side-by-side in ``run_agent.py``, so wiring both into chat completions emitted *two* ``hermes.tool.progress`` events per real tool call. Drop the legacy ``_on_tool_progress`` emit entirely; ``_on_tool_start`` now produces a single unified event that carries the legacy ``tool``/``emoji``/``label`` fields plus the new ``toolCallId``/``status`` correlation fields. The label is computed inline via ``build_tool_preview`` so callers do not need to pre-format it.

2. **Weak per-event correlation in the regression test** — the previous assertion checked only that a ``toolCallId`` appeared *somewhere* in the aggregate output, which would have passed even if the ``running`` event lacked the id. Collect ``(status, toolCallId)`` per event and assert that each event carries the correct pair, and that exactly two events appear on the wire (guarding against a silent duplication regression).

The two existing chat-completions tool-progress tests are updated to fire ``tool_start_callback`` instead of ``tool_progress_callback``, matching production reality, where ``run_agent`` always pairs them.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
13c238327e
commit
e0a03f3f40
@ -981,39 +981,62 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
if delta is not None:
|
||||
_stream_q.put(delta)
|
||||
|
||||
def _on_tool_progress(event_type, name, preview, args, **kwargs):
|
||||
"""Send tool progress as a separate SSE event.
|
||||
# Track which tool_call_ids we've emitted a "running" lifecycle
|
||||
# event for, so a "completed" event without a matching "running"
|
||||
# (e.g. internal/filtered tools) is silently dropped instead of
|
||||
# producing an orphaned event clients can't correlate.
|
||||
_started_tool_call_ids: set[str] = set()
|
||||
|
||||
Previously, progress markers like ``⏰ list`` were injected
|
||||
directly into ``delta.content``. OpenAI-compatible frontends
|
||||
(Open WebUI, LobeChat, …) store ``delta.content`` verbatim as
|
||||
the assistant message and send it back on subsequent requests.
|
||||
After enough turns the model learns to *emit* the markers as
|
||||
plain text instead of issuing real tool calls — silently
|
||||
hallucinating tool results. See #6972.
|
||||
def _on_tool_start(tool_call_id, function_name, function_args):
|
||||
"""Emit ``hermes.tool.progress`` with ``status: running``.
|
||||
|
||||
The fix: push a tagged tuple ``("__tool_progress__", payload)``
|
||||
onto the stream queue. The SSE writer emits it as a custom
|
||||
``event: hermes.tool.progress`` line that compliant frontends
|
||||
can render for UX but will *not* persist into conversation
|
||||
history. Clients that don't understand the custom event type
|
||||
silently ignore it per the SSE specification.
|
||||
Replaces the old ``tool_progress_callback("tool.started",
|
||||
...)`` emit so SSE consumers receive a single event per
|
||||
tool start, carrying both the legacy ``tool``/``emoji``/
|
||||
``label`` payload (for #6972 frontends) and the new
|
||||
``toolCallId``/``status`` correlation fields (#16588).
|
||||
|
||||
Skips tools whose names start with ``_`` so internal
|
||||
events (``_thinking``, …) stay off the wire — matching
|
||||
the prior ``_on_tool_progress`` filter exactly.
|
||||
"""
|
||||
if event_type != "tool.started":
|
||||
if not tool_call_id or function_name.startswith("_"):
|
||||
return
|
||||
if name.startswith("_"):
|
||||
return
|
||||
from agent.display import get_tool_emoji
|
||||
emoji = get_tool_emoji(name)
|
||||
label = preview or name
|
||||
_started_tool_call_ids.add(tool_call_id)
|
||||
from agent.display import build_tool_preview, get_tool_emoji
|
||||
label = build_tool_preview(function_name, function_args) or function_name
|
||||
_stream_q.put(("__tool_progress__", {
|
||||
"tool": name,
|
||||
"emoji": emoji,
|
||||
"tool": function_name,
|
||||
"emoji": get_tool_emoji(function_name),
|
||||
"label": label,
|
||||
"toolCallId": tool_call_id,
|
||||
"status": "running",
|
||||
}))
|
||||
|
||||
def _on_tool_complete(tool_call_id, function_name, function_args, function_result):
    """Queue the ``status: completed`` half of a tool lifecycle pair.

    Nothing is emitted unless ``tool_call_id`` is truthy and is currently
    recorded in ``_started_tool_call_ids``. This keeps a ``completed``
    event from reaching clients when no ``running`` event with the same
    id was recorded first.

    ``function_args`` and ``function_result`` are accepted for the
    callback signature but are not included in the payload.
    """
    was_started = bool(tool_call_id) and tool_call_id in _started_tool_call_ids
    if not was_started:
        return

    # Forget the id so a repeated completion for the same call is dropped.
    _started_tool_call_ids.discard(tool_call_id)

    completed_payload = {
        "tool": function_name,
        "toolCallId": tool_call_id,
        "status": "completed",
    }
    _stream_q.put(("__tool_progress__", completed_payload))
|
||||
|
||||
# Start agent in background. agent_ref is a mutable container
|
||||
# so the SSE writer can interrupt the agent on client disconnect.
|
||||
#
|
||||
# ``tool_progress_callback`` is intentionally not wired here:
|
||||
# it would duplicate every emit because ``run_agent`` fires it
|
||||
# side-by-side with ``tool_start_callback``/``tool_complete_callback``.
|
||||
# The structured callbacks are strictly richer (they carry the
|
||||
# tool_call id), so they own the chat-completions SSE channel.
|
||||
agent_ref = [None]
|
||||
agent_task = asyncio.ensure_future(self._run_agent(
|
||||
user_message=user_message,
|
||||
@ -1021,7 +1044,8 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
ephemeral_system_prompt=system_prompt,
|
||||
session_id=session_id,
|
||||
stream_delta_callback=_on_delta,
|
||||
tool_progress_callback=_on_tool_progress,
|
||||
tool_start_callback=_on_tool_start,
|
||||
tool_complete_callback=_on_tool_complete,
|
||||
agent_ref=agent_ref,
|
||||
))
|
||||
|
||||
@ -1136,7 +1160,8 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
Tagged tuples ``("__tool_progress__", payload)`` are sent
|
||||
as a custom ``event: hermes.tool.progress`` SSE event so
|
||||
frontends can display them without storing the markers in
|
||||
conversation history. See #6972.
|
||||
conversation history. See #6972 for the original event,
|
||||
#16588 for the ``toolCallId``/``status`` lifecycle fields.
|
||||
"""
|
||||
if isinstance(item, tuple) and len(item) == 2 and item[0] == "__tool_progress__":
|
||||
event_data = json.dumps(item[1])
|
||||
|
||||
@ -688,17 +688,17 @@ class TestChatCompletionsEndpoint:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_includes_tool_progress(self, adapter):
|
||||
"""tool_progress_callback fires → progress appears as custom SSE event, not in delta.content."""
|
||||
"""tool_start_callback fires → progress appears as custom SSE event, not in delta.content."""
|
||||
import asyncio
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
async def _mock_run_agent(**kwargs):
|
||||
cb = kwargs.get("stream_delta_callback")
|
||||
tp_cb = kwargs.get("tool_progress_callback")
|
||||
# Simulate tool progress before streaming content
|
||||
if tp_cb:
|
||||
tp_cb("tool.started", "terminal", "ls -la", {"command": "ls -la"})
|
||||
ts_cb = kwargs.get("tool_start_callback")
|
||||
# Simulate the structured tool start the gateway now consumes.
|
||||
if ts_cb:
|
||||
ts_cb("call_terminal_1", "terminal", {"command": "ls -la"})
|
||||
if cb:
|
||||
await asyncio.sleep(0.05)
|
||||
cb("Here are the files.")
|
||||
@ -724,7 +724,10 @@ class TestChatCompletionsEndpoint:
|
||||
# markers instead of calling tools (#6972).
|
||||
assert "event: hermes.tool.progress" in body
|
||||
assert '"tool": "terminal"' in body
|
||||
assert '"label": "ls -la"' in body
|
||||
# ``label`` is now derived by ``build_tool_preview`` from the
|
||||
# tool args rather than passed by the caller, so we assert
|
||||
# only that *some* label exists rather than a literal value.
|
||||
assert '"label":' in body
|
||||
# The progress marker must NOT appear inside any
|
||||
# chat.completion.chunk delta.content field.
|
||||
import json as _json
|
||||
@ -744,17 +747,17 @@ class TestChatCompletionsEndpoint:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_tool_progress_skips_internal_events(self, adapter):
|
||||
"""Internal events (name starting with _) are not streamed."""
|
||||
"""Internal tool calls (name starting with ``_``) are not streamed."""
|
||||
import asyncio
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
async def _mock_run_agent(**kwargs):
|
||||
cb = kwargs.get("stream_delta_callback")
|
||||
tp_cb = kwargs.get("tool_progress_callback")
|
||||
if tp_cb:
|
||||
tp_cb("tool.started", "_thinking", "some internal state", {})
|
||||
tp_cb("tool.started", "web_search", "Python docs", {"query": "Python docs"})
|
||||
ts_cb = kwargs.get("tool_start_callback")
|
||||
if ts_cb:
|
||||
ts_cb("call_internal_1", "_thinking", {"text": "some internal state"})
|
||||
ts_cb("call_search_1", "web_search", {"query": "Python docs"})
|
||||
if cb:
|
||||
await asyncio.sleep(0.05)
|
||||
cb("Found it.")
|
||||
@ -776,10 +779,142 @@ class TestChatCompletionsEndpoint:
|
||||
body = await resp.text()
|
||||
# Internal _thinking event should NOT appear anywhere
|
||||
assert "some internal state" not in body
|
||||
assert "call_internal_1" not in body
|
||||
# Real tool progress should appear as custom SSE event
|
||||
assert "event: hermes.tool.progress" in body
|
||||
assert '"tool": "web_search"' in body
|
||||
assert '"label": "Python docs"' in body
|
||||
# Label is derived from the args dict by build_tool_preview;
|
||||
# asserting on the structural fact (label exists, call id
|
||||
# is correlated) rather than a literal preview string keeps
|
||||
# the test robust against preview-formatter tweaks.
|
||||
assert '"label":' in body
|
||||
assert '"toolCallId": "call_search_1"' in body
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_emits_tool_lifecycle_with_call_id(self, adapter):
    """Regression test for #16588: streamed tool lifecycle carries a call id.

    Streaming ``/v1/chat/completions`` used to emit only a
    ``tool.started``-style ``hermes.tool.progress`` event, with no
    ``status: completed`` counterpart and no ``toolCallId``, so clients
    could not mark a tool as finished or correlate events.

    This test drives ``tool_start_callback`` and ``tool_complete_callback``
    once each for the same call id and checks that exactly two
    ``hermes.tool.progress`` events are streamed, ``running`` then
    ``completed``, each carrying that id.
    """
    import asyncio
    import json as _json

    async def _fake_agent(**kwargs):
        on_delta = kwargs.get("stream_delta_callback")
        on_start = kwargs.get("tool_start_callback")
        on_complete = kwargs.get("tool_complete_callback")
        # Only the structured callbacks are exercised here;
        # ``tool_progress_callback`` is deliberately left untouched so a
        # single tool start should yield exactly one event.
        if on_start:
            on_start("call_terminal_1", "terminal", {"command": "ls -la"})
        if on_complete:
            on_complete("call_terminal_1", "terminal", {"command": "ls -la"}, "ok")
        if on_delta:
            await asyncio.sleep(0.05)
            on_delta("done.")
        result = {"final_response": "done.", "messages": [], "api_calls": 1}
        usage = {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}
        return (result, usage)

    request_json = {
        "model": "test",
        "messages": [{"role": "user", "content": "list"}],
        "stream": True,
    }

    app = _create_app(adapter)
    async with TestClient(TestServer(app)) as cli:
        with patch.object(adapter, "_run_agent", side_effect=_fake_agent):
            resp = await cli.post("/v1/chat/completions", json=request_json)
            assert resp.status == 200
            body = await resp.text()

        # Build one (status, toolCallId) pair per progress event so the
        # checks below are per-event: an event that lacks the id cannot
        # be masked by another event that happens to carry it.
        data_prefix = "data: "
        sse_lines = body.splitlines()
        pairs: list[tuple[str | None, str | None]] = []
        for idx, text in enumerate(sse_lines):
            if text.strip() != "event: hermes.tool.progress":
                continue
            # The payload is the first ``data:`` line among the next three.
            data_line = next(
                (c for c in sse_lines[idx + 1: idx + 4] if c.startswith(data_prefix)),
                None,
            )
            if data_line is None:
                continue
            try:
                decoded = _json.loads(data_line[len(data_prefix):])
            except _json.JSONDecodeError:
                # An unparsable payload contributes no pair; the count
                # assertion below then reports the mismatch.
                continue
            pairs.append((decoded.get("status"), decoded.get("toolCallId")))

        # One event per tool start (no duplicated legacy + new emit), and
        # the same toolCallId on both halves of the lifecycle.
        assert len(pairs) == 2, f"expected 2 events (running+completed), got {pairs}"
        assert pairs[0] == ("running", "call_terminal_1"), pairs
        assert pairs[1] == ("completed", "call_terminal_1"), pairs
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_tool_lifecycle_skips_internal_and_orphan_completes(self, adapter):
    """No lifecycle events for internal tools or unmatched completions.

    A tool whose name starts with ``_`` and a ``completed`` callback with
    no earlier ``running`` for the same id must both leave the stream free
    of lifecycle payloads; otherwise clients would receive
    ``status: completed`` updates they cannot correlate.
    """
    import asyncio

    async def _fake_agent(**kwargs):
        on_delta = kwargs.get("stream_delta_callback")
        on_start = kwargs.get("tool_start_callback")
        on_complete = kwargs.get("tool_complete_callback")
        # Internal tool start: expected to be filtered.
        if on_start:
            on_start("call_internal_1", "_thinking", {})
        if on_complete:
            # Completion of the filtered internal tool.
            on_complete("call_internal_1", "_thinking", {}, "")
            # Completion that never had a start: an orphan.
            on_complete("call_orphan_1", "web_search", {}, "ok")
        if on_delta:
            await asyncio.sleep(0.05)
            on_delta("ok.")
        result = {"final_response": "ok.", "messages": [], "api_calls": 1}
        usage = {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}
        return (result, usage)

    request_json = {
        "model": "test",
        "messages": [{"role": "user", "content": "ok"}],
        "stream": True,
    }

    app = _create_app(adapter)
    async with TestClient(TestServer(app)) as cli:
        with patch.object(adapter, "_run_agent", side_effect=_fake_agent):
            resp = await cli.post("/v1/chat/completions", json=request_json)
            assert resp.status == 200
            body = await resp.text()

        # Neither call id, nor any lifecycle status at all, may appear
        # anywhere in the streamed body.
        assert "call_internal_1" not in body
        assert "call_orphan_1" not in body
        assert '"status": "running"' not in body
        assert '"status": "completed"' not in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_user_message_returns_400(self, adapter):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user