fix: keep mcp dynamic refresh tasks tracked

This commit is contained in:
Pony.Ma 2026-04-28 10:41:28 +08:00 committed by Teknium
parent 02ae152222
commit 1350d12b0b
2 changed files with 177 additions and 20 deletions

View File

@ -706,6 +706,106 @@ class TestMCPServerTask:
asyncio.run(_test())
def test_refresh_tools_deregisters_removed_tools(self):
"""Dynamic refresh removes stale registry entries for deleted tools."""
from tools.registry import ToolRegistry
from tools.mcp_tool import MCPServerTask
mock_registry = ToolRegistry()
server = MCPServerTask("srv")
server._config = {"command": "test"}
server._tools = [_make_mcp_tool("old"), _make_mcp_tool("keep")]
server._registered_tool_names = ["mcp_srv_old", "mcp_srv_keep"]
server.session = MagicMock()
server.session.list_tools = AsyncMock(
return_value=SimpleNamespace(tools=[_make_mcp_tool("keep"), _make_mcp_tool("new")])
)
with patch("tools.registry.registry", mock_registry):
mock_registry.register(
name="mcp_srv_old",
toolset="mcp-srv",
schema={"name": "mcp_srv_old", "description": "Old"},
handler=lambda *_args, **_kwargs: "{}",
)
mock_registry.register(
name="mcp_srv_keep",
toolset="mcp-srv",
schema={"name": "mcp_srv_keep", "description": "Keep"},
handler=lambda *_args, **_kwargs: "{}",
)
asyncio.run(server._refresh_tools())
names = mock_registry.get_all_tool_names()
assert "mcp_srv_old" not in names
assert "mcp_srv_keep" in names
assert "mcp_srv_new" in names
assert set(server._registered_tool_names) == {
"mcp_srv_keep",
"mcp_srv_new",
"mcp_srv_list_resources",
"mcp_srv_read_resource",
"mcp_srv_list_prompts",
"mcp_srv_get_prompt",
}
def test_schedule_tools_refresh_keeps_task_until_done(self):
"""Background refresh tasks are strongly referenced and then discarded."""
from tools.mcp_tool import MCPServerTask
async def _test():
started = asyncio.Event()
finish = asyncio.Event()
server = MCPServerTask("srv")
async def fake_refresh(_server):
started.set()
await finish.wait()
with patch.object(MCPServerTask, "_refresh_tools", new=fake_refresh):
server._schedule_tools_refresh()
await started.wait()
assert len(server._pending_refresh_tasks) == 1
task = next(iter(server._pending_refresh_tasks))
assert not task.done()
finish.set()
await task
await asyncio.sleep(0)
assert server._pending_refresh_tasks == set()
asyncio.run(_test())
def test_shutdown_cancels_pending_refresh_tasks(self):
"""shutdown() cancels in-flight background refresh tasks."""
from tools.mcp_tool import MCPServerTask
async def _test():
started = asyncio.Event()
cancelled = asyncio.Event()
server = MCPServerTask("srv")
async def fake_refresh(_server):
started.set()
try:
await asyncio.sleep(3600)
except asyncio.CancelledError:
cancelled.set()
raise
with patch.object(MCPServerTask, "_refresh_tools", new=fake_refresh):
server._schedule_tools_refresh()
await started.wait()
await server.shutdown()
assert cancelled.is_set()
assert server._pending_refresh_tasks == set()
asyncio.run(_test())
def test_empty_env_gets_safe_defaults(self):
"""Empty env dict gets safe default env vars (PATH, HOME, etc.)."""
from tools.mcp_tool import MCPServerTask
@ -1993,7 +2093,13 @@ try:
except ImportError:
ToolUseContent = _CompatType
from tools.mcp_tool import SamplingHandler, _safe_numeric
from tools.mcp_tool import (
CreateMessageResultWithTools,
SamplingHandler,
SamplingToolsCapability,
ToolUseContent,
_safe_numeric,
)
# ---------------------------------------------------------------------------

View File

@ -167,10 +167,22 @@ _MCP_HTTP_AVAILABLE = False
_MCP_SAMPLING_TYPES = False
_MCP_NOTIFICATION_TYPES = False
_MCP_MESSAGE_HANDLER_SUPPORTED = False
_MCP_NEW_HTTP = False
streamablehttp_client = None
streamable_http_client = None
# Conservative fallback for SDK builds that don't export LATEST_PROTOCOL_VERSION.
# Streamable HTTP was introduced by 2025-03-26, so this remains valid for the
# HTTP transport path even on older-but-supported SDK versions.
LATEST_PROTOCOL_VERSION = "2025-03-26"
class _CompatType:
"""Minimal attribute bag for MCP SDK types missing in older/newer builds."""
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
try:
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
@ -191,20 +203,28 @@ try:
from mcp.types import LATEST_PROTOCOL_VERSION
except ImportError:
logger.debug("mcp.types.LATEST_PROTOCOL_VERSION not available -- using fallback protocol version")
# Sampling types -- separated so older SDK versions don't break MCP support
# Sampling types -- import individually because SDK names changed across releases.
try:
from mcp.types import (
CreateMessageResult,
CreateMessageResultWithTools,
ErrorData,
SamplingCapability,
SamplingToolsCapability,
TextContent,
ToolUseContent,
)
from mcp.types import CreateMessageResult, ErrorData, SamplingCapability, TextContent
try:
from mcp.types import CreateMessageResultWithTools
except ImportError:
CreateMessageResultWithTools = _CompatType
try:
from mcp.types import SamplingToolsCapability
except ImportError:
SamplingToolsCapability = _CompatType
try:
from mcp.types import ToolUseContent
except ImportError:
ToolUseContent = _CompatType
_MCP_SAMPLING_TYPES = True
except ImportError:
logger.debug("MCP sampling types not available -- sampling disabled")
logger.debug("MCP sampling base types not available -- sampling disabled")
# Notification types for dynamic tool discovery (tools/list_changed)
try:
from mcp.types import (
@ -868,7 +888,7 @@ class MCPServerTask:
"_task", "_ready", "_shutdown_event", "_reconnect_event",
"_tools", "_error", "_config",
"_sampling", "_registered_tool_names", "_auth_type", "_refresh_lock",
"_rpc_lock",
"_rpc_lock", "_pending_refresh_tasks",
)
def __init__(self, name: str):
@ -895,8 +915,10 @@ class MCPServerTask:
# list_changed notifications during startup; if the notification
# handler calls list_tools while a normal tool call is in flight, the
# stream can wedge and the user-visible tool call times out. Serialize
# client-initiated RPCs per server.
# client-initiated RPCs per server. The lock is also applied to HTTP
# transports for conservative per-server ordering.
self._rpc_lock = asyncio.Lock()
self._pending_refresh_tasks: set[asyncio.Task] = set()
def _is_http(self) -> bool:
"""Check if this server uses HTTP transport."""
@ -904,6 +926,21 @@ class MCPServerTask:
# ----- Dynamic tool discovery (notifications/tools/list_changed) -----
async def _refresh_tools_task(self):
"""Run a dynamic tool refresh and log failures from background tasks."""
try:
await self._refresh_tools()
except asyncio.CancelledError:
raise
except Exception:
logger.exception("MCP server '%s': dynamic tool refresh failed", self.name)
def _schedule_tools_refresh(self) -> None:
"""Schedule a background tool refresh and keep it strongly referenced."""
task = asyncio.create_task(self._refresh_tools_task())
self._pending_refresh_tasks.add(task)
task.add_done_callback(self._pending_refresh_tasks.discard)
def _make_message_handler(self):
"""Build a ``message_handler`` callback for ``ClientSession``.
@ -932,7 +969,7 @@ class MCPServerTask:
# subsequent tool calls time out. Do the refresh in
# a separate task and let the handler return
# promptly.
asyncio.create_task(self._refresh_tools())
self._schedule_tools_refresh()
case PromptListChangedNotification():
logger.debug("MCP server '%s': prompts/list_changed (ignored)", self.name)
case ResourceListChangedNotification():
@ -962,11 +999,20 @@ class MCPServerTask:
tools_result = await self.session.list_tools()
new_mcp_tools = tools_result.tools if hasattr(tools_result, "tools") else []
# 2. Re-register with fresh tool list. Avoid deregistering first:
# live agent turns already have tool-call IDs pointing at the
# existing handler functions. Replacing entries in-place is enough
# for unchanged names and avoids transient "tool not connected" /
# stale-handler races during startup notifications.
# 2. Re-register with fresh tool list. Avoid nuke-and-repave for
# all names: live agent turns may already have tool-call IDs
# pointing at existing handler functions. Replacing entries
# in-place is enough for unchanged names and avoids transient
# "tool not connected" / stale-handler races during startup
# notifications. Tools absent from the fresh list are no longer
# callable, so remove only those stale registry entries first.
stale_tool_names = old_tool_names - {
f"mcp_{sanitize_mcp_name_component(self.name)}_"
f"{sanitize_mcp_name_component(tool.name)}"
for tool in new_mcp_tools
}
for tool_name in stale_tool_names:
registry.deregister(tool_name)
# 3. Re-register with fresh tool list
self._tools = new_mcp_tools
@ -1383,6 +1429,11 @@ class MCPServerTask:
await self._task
except asyncio.CancelledError:
pass
if self._pending_refresh_tasks:
for task in list(self._pending_refresh_tasks):
task.cancel()
await asyncio.gather(*self._pending_refresh_tasks, return_exceptions=True)
self._pending_refresh_tasks.clear()
for tool_name in list(getattr(self, "_registered_tool_names", [])):
registry.deregister(tool_name)
self._registered_tool_names = []