fix: keep mcp dynamic refresh tasks tracked
This commit is contained in:
parent
02ae152222
commit
1350d12b0b
@ -706,6 +706,106 @@ class TestMCPServerTask:
|
||||
|
||||
asyncio.run(_test())
|
||||
|
||||
def test_refresh_tools_deregisters_removed_tools(self):
    """A dynamic refresh drops registry entries for tools the server stopped listing.

    The server starts out with ``old`` and ``keep`` registered, and the
    mocked session then reports ``keep`` and ``new``.  After the refresh,
    ``old`` must be gone from the registry, ``keep`` and ``new`` must be
    present, and the server's own list of registered names must match the
    expected set exactly.
    """
    from tools.registry import ToolRegistry
    from tools.mcp_tool import MCPServerTask

    # Isolated registry that stands in for the module-level singleton.
    isolated_registry = ToolRegistry()

    # What the (mocked) MCP session reports on the next list_tools() call.
    fresh_listing = SimpleNamespace(
        tools=[_make_mcp_tool("keep"), _make_mcp_tool("new")]
    )
    fake_session = MagicMock()
    fake_session.list_tools = AsyncMock(return_value=fresh_listing)

    # Server state as it was before the refresh.
    srv = MCPServerTask("srv")
    srv._config = {"command": "test"}
    srv._tools = [_make_mcp_tool("old"), _make_mcp_tool("keep")]
    srv._registered_tool_names = ["mcp_srv_old", "mcp_srv_keep"]
    srv.session = fake_session

    preexisting = (
        ("mcp_srv_old", "Old"),
        ("mcp_srv_keep", "Keep"),
    )

    with patch("tools.registry.registry", isolated_registry):
        # Seed the registry with the two tools known before the refresh.
        for tool_name, description in preexisting:
            isolated_registry.register(
                name=tool_name,
                toolset="mcp-srv",
                schema={"name": tool_name, "description": description},
                handler=lambda *_args, **_kwargs: "{}",
            )

        asyncio.run(srv._refresh_tools())

    registered = isolated_registry.get_all_tool_names()
    assert "mcp_srv_old" not in registered
    assert "mcp_srv_keep" in registered
    assert "mcp_srv_new" in registered

    expected_names = {
        "mcp_srv_keep",
        "mcp_srv_new",
        "mcp_srv_list_resources",
        "mcp_srv_read_resource",
        "mcp_srv_list_prompts",
        "mcp_srv_get_prompt",
    }
    assert set(srv._registered_tool_names) == expected_names
|
||||
|
||||
def test_schedule_tools_refresh_keeps_task_until_done(self):
    """A scheduled refresh task stays referenced while running and is dropped once done."""
    from tools.mcp_tool import MCPServerTask

    async def scenario():
        entered = asyncio.Event()   # set once the fake refresh starts running
        release = asyncio.Event()   # set by the test to let the fake refresh finish
        srv = MCPServerTask("srv")

        async def blocking_refresh(_server):
            # Signal that we are running, then park until released.
            entered.set()
            await release.wait()

        with patch.object(MCPServerTask, "_refresh_tools", new=blocking_refresh):
            srv._schedule_tools_refresh()

            # While the refresh is in flight, exactly one unfinished task is tracked.
            await entered.wait()
            assert len(srv._pending_refresh_tasks) == 1
            in_flight = next(iter(srv._pending_refresh_tasks))
            assert not in_flight.done()

            # Let it complete, then yield once so done-callbacks can run.
            release.set()
            await in_flight
            await asyncio.sleep(0)
            assert srv._pending_refresh_tasks == set()

    asyncio.run(scenario())
|
||||
|
||||
def test_shutdown_cancels_pending_refresh_tasks(self):
    """``shutdown()`` cancels a background refresh that is still in flight."""
    from tools.mcp_tool import MCPServerTask

    async def scenario():
        entered = asyncio.Event()        # fake refresh has started
        saw_cancel = asyncio.Event()     # fake refresh observed CancelledError
        srv = MCPServerTask("srv")

        async def never_ending_refresh(_server):
            entered.set()
            try:
                # Long enough that only a cancellation can end it.
                await asyncio.sleep(3600)
            except asyncio.CancelledError:
                saw_cancel.set()
                raise

        with patch.object(MCPServerTask, "_refresh_tools", new=never_ending_refresh):
            srv._schedule_tools_refresh()
            await entered.wait()

            await srv.shutdown()

        assert saw_cancel.is_set()
        assert srv._pending_refresh_tasks == set()

    asyncio.run(scenario())
|
||||
|
||||
def test_empty_env_gets_safe_defaults(self):
|
||||
"""Empty env dict gets safe default env vars (PATH, HOME, etc.)."""
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
@ -1993,7 +2093,13 @@ try:
|
||||
except ImportError:
|
||||
ToolUseContent = _CompatType
|
||||
|
||||
from tools.mcp_tool import SamplingHandler, _safe_numeric
|
||||
from tools.mcp_tool import (
|
||||
CreateMessageResultWithTools,
|
||||
SamplingHandler,
|
||||
SamplingToolsCapability,
|
||||
ToolUseContent,
|
||||
_safe_numeric,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -167,10 +167,22 @@ _MCP_HTTP_AVAILABLE = False
|
||||
_MCP_SAMPLING_TYPES = False
|
||||
_MCP_NOTIFICATION_TYPES = False
|
||||
_MCP_MESSAGE_HANDLER_SUPPORTED = False
|
||||
_MCP_NEW_HTTP = False
|
||||
streamablehttp_client = None
|
||||
streamable_http_client = None
|
||||
# Conservative fallback for SDK builds that don't export LATEST_PROTOCOL_VERSION.
|
||||
# Streamable HTTP was introduced in protocol revision 2025-03-26, so this remains valid for the
|
||||
# HTTP transport path even on older-but-supported SDK versions.
|
||||
LATEST_PROTOCOL_VERSION = "2025-03-26"
|
||||
|
||||
|
||||
class _CompatType:
    """Minimal attribute bag for MCP SDK types missing in older/newer builds.

    Each keyword argument passed to the constructor is stored as an instance
    attribute of the same name, so code that builds one of these objects can
    read the fields back with plain attribute access.
    """

    def __init__(self, **kwargs):
        """Store every keyword argument as an attribute on the instance."""
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __repr__(self):
        """Return ``ClassName(field=value, ...)`` so stand-ins are readable in logs."""
        fields = ", ".join(
            f"{key}={value!r}" for key, value in vars(self).items()
        )
        return f"{type(self).__name__}({fields})"
|
||||
|
||||
try:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
@ -191,20 +203,28 @@ try:
|
||||
from mcp.types import LATEST_PROTOCOL_VERSION
|
||||
except ImportError:
|
||||
logger.debug("mcp.types.LATEST_PROTOCOL_VERSION not available -- using fallback protocol version")
|
||||
# Sampling types -- separated so older SDK versions don't break MCP support
|
||||
# Sampling types -- import individually because SDK names changed across releases.
|
||||
try:
|
||||
from mcp.types import (
|
||||
CreateMessageResult,
|
||||
CreateMessageResultWithTools,
|
||||
ErrorData,
|
||||
SamplingCapability,
|
||||
SamplingToolsCapability,
|
||||
TextContent,
|
||||
ToolUseContent,
|
||||
)
|
||||
from mcp.types import CreateMessageResult, ErrorData, SamplingCapability, TextContent
|
||||
|
||||
try:
|
||||
from mcp.types import CreateMessageResultWithTools
|
||||
except ImportError:
|
||||
CreateMessageResultWithTools = _CompatType
|
||||
|
||||
try:
|
||||
from mcp.types import SamplingToolsCapability
|
||||
except ImportError:
|
||||
SamplingToolsCapability = _CompatType
|
||||
|
||||
try:
|
||||
from mcp.types import ToolUseContent
|
||||
except ImportError:
|
||||
ToolUseContent = _CompatType
|
||||
|
||||
_MCP_SAMPLING_TYPES = True
|
||||
except ImportError:
|
||||
logger.debug("MCP sampling types not available -- sampling disabled")
|
||||
logger.debug("MCP sampling base types not available -- sampling disabled")
|
||||
# Notification types for dynamic tool discovery (tools/list_changed)
|
||||
try:
|
||||
from mcp.types import (
|
||||
@ -868,7 +888,7 @@ class MCPServerTask:
|
||||
"_task", "_ready", "_shutdown_event", "_reconnect_event",
|
||||
"_tools", "_error", "_config",
|
||||
"_sampling", "_registered_tool_names", "_auth_type", "_refresh_lock",
|
||||
"_rpc_lock",
|
||||
"_rpc_lock", "_pending_refresh_tasks",
|
||||
)
|
||||
|
||||
def __init__(self, name: str):
|
||||
@ -895,8 +915,10 @@ class MCPServerTask:
|
||||
# list_changed notifications during startup; if the notification
|
||||
# handler calls list_tools while a normal tool call is in flight, the
|
||||
# stream can wedge and the user-visible tool call times out. Serialize
|
||||
# client-initiated RPCs per server.
|
||||
# client-initiated RPCs per server. The lock is also applied to HTTP
|
||||
# transports for conservative per-server ordering.
|
||||
self._rpc_lock = asyncio.Lock()
|
||||
self._pending_refresh_tasks: set[asyncio.Task] = set()
|
||||
|
||||
def _is_http(self) -> bool:
|
||||
"""Check if this server uses HTTP transport."""
|
||||
@ -904,6 +926,21 @@ class MCPServerTask:
|
||||
|
||||
# ----- Dynamic tool discovery (notifications/tools/list_changed) -----
|
||||
|
||||
async def _refresh_tools_task(self):
    """Background-task wrapper around ``_refresh_tools``.

    Awaits ``self._refresh_tools()``.  Cancellation is re-raised untouched;
    any other exception is logged with its traceback and swallowed, so the
    wrapper itself completes without raising.
    """
    try:
        await self._refresh_tools()
    except asyncio.CancelledError:
        # Cancellation must reach whoever cancelled the task.
        raise
    except Exception:
        # Log here: this runs as a background task, so the failure is
        # handled inside the wrapper rather than raised out of it.
        logger.exception(
            "MCP server '%s': dynamic tool refresh failed",
            self.name,
        )
|
||||
|
||||
def _schedule_tools_refresh(self) -> None:
    """Start a background tool refresh and hold a strong reference to its task.

    The new task is added to ``self._pending_refresh_tasks`` and removes
    itself from that set through a done-callback when it finishes.
    """
    refresh_task = asyncio.create_task(self._refresh_tools_task())
    pending = self._pending_refresh_tasks
    # Keep the task in the set until it is done, so it stays referenced.
    pending.add(refresh_task)
    # The done-callback receives the task itself, which is exactly what
    # set.discard expects.
    refresh_task.add_done_callback(pending.discard)
|
||||
|
||||
def _make_message_handler(self):
|
||||
"""Build a ``message_handler`` callback for ``ClientSession``.
|
||||
|
||||
@ -932,7 +969,7 @@ class MCPServerTask:
|
||||
# subsequent tool calls time out. Do the refresh in
|
||||
# a separate task and let the handler return
|
||||
# promptly.
|
||||
asyncio.create_task(self._refresh_tools())
|
||||
self._schedule_tools_refresh()
|
||||
case PromptListChangedNotification():
|
||||
logger.debug("MCP server '%s': prompts/list_changed (ignored)", self.name)
|
||||
case ResourceListChangedNotification():
|
||||
@ -962,11 +999,20 @@ class MCPServerTask:
|
||||
tools_result = await self.session.list_tools()
|
||||
new_mcp_tools = tools_result.tools if hasattr(tools_result, "tools") else []
|
||||
|
||||
# 2. Re-register with fresh tool list. Avoid deregistering first:
|
||||
# live agent turns already have tool-call IDs pointing at the
|
||||
# existing handler functions. Replacing entries in-place is enough
|
||||
# for unchanged names and avoids transient "tool not connected" /
|
||||
# stale-handler races during startup notifications.
|
||||
# 2. Re-register with fresh tool list. Avoid nuke-and-repave for
|
||||
# all names: live agent turns may already have tool-call IDs
|
||||
# pointing at existing handler functions. Replacing entries
|
||||
# in-place is enough for unchanged names and avoids transient
|
||||
# "tool not connected" / stale-handler races during startup
|
||||
# notifications. Tools absent from the fresh list are no longer
|
||||
# callable, so remove only those stale registry entries first.
|
||||
stale_tool_names = old_tool_names - {
|
||||
f"mcp_{sanitize_mcp_name_component(self.name)}_"
|
||||
f"{sanitize_mcp_name_component(tool.name)}"
|
||||
for tool in new_mcp_tools
|
||||
}
|
||||
for tool_name in stale_tool_names:
|
||||
registry.deregister(tool_name)
|
||||
|
||||
# 3. Re-register with fresh tool list
|
||||
self._tools = new_mcp_tools
|
||||
@ -1383,6 +1429,11 @@ class MCPServerTask:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if self._pending_refresh_tasks:
|
||||
for task in list(self._pending_refresh_tasks):
|
||||
task.cancel()
|
||||
await asyncio.gather(*self._pending_refresh_tasks, return_exceptions=True)
|
||||
self._pending_refresh_tasks.clear()
|
||||
for tool_name in list(getattr(self, "_registered_tool_names", [])):
|
||||
registry.deregister(tool_name)
|
||||
self._registered_tool_names = []
|
||||
|
||||
Loading…
Reference in New Issue
Block a user