fix(mcp): normalize nullable tool schemas

This commit is contained in:
Pony.Ma 2026-04-28 09:57:58 +08:00 committed by Teknium
parent 9cd02b1698
commit 02ae152222
4 changed files with 224 additions and 21 deletions

View File

@ -1117,6 +1117,49 @@ def _sanitize_tool_id(tool_id: str) -> str:
return sanitized or "tool_0"
def _normalize_tool_input_schema(schema: Any) -> Dict[str, Any]:
"""Normalize tool schemas before sending them to Anthropic.
Anthropic's tool schema validator rejects nullable unions such as
``anyOf: [{"type": "string"}, {"type": "null"}]`` that Pydantic/MCP
commonly emits for optional fields. Tool optionality is represented by
the parent ``required`` array, so collapse nullable unions to the non-null
branch while preserving metadata like description/default.
"""
if not schema:
return {"type": "object", "properties": {}}
def _strip_nullable_union(node: Any) -> Any:
if isinstance(node, list):
return [_strip_nullable_union(item) for item in node]
if not isinstance(node, dict):
return node
stripped = {k: _strip_nullable_union(v) for k, v in node.items()}
for key in ("anyOf", "oneOf"):
variants = stripped.get(key)
if not isinstance(variants, list):
continue
non_null = [
item for item in variants
if not (isinstance(item, dict) and item.get("type") == "null")
]
if len(non_null) == 1 and len(non_null) != len(variants):
replacement = dict(non_null[0]) if isinstance(non_null[0], dict) else {}
for meta_key in ("title", "description", "default", "examples"):
if meta_key in stripped and meta_key not in replacement:
replacement[meta_key] = stripped[meta_key]
return _strip_nullable_union(replacement)
return stripped
normalized = _strip_nullable_union(schema)
if not isinstance(normalized, dict):
return {"type": "object", "properties": {}}
if normalized.get("type") == "object" and not isinstance(normalized.get("properties"), dict):
normalized = {**normalized, "properties": {}}
return normalized
def convert_tools_to_anthropic(tools: List[Dict]) -> List[Dict]:
"""Convert OpenAI tool definitions to Anthropic format."""
if not tools:
@ -1127,7 +1170,9 @@ def convert_tools_to_anthropic(tools: List[Dict]) -> List[Dict]:
result.append({
"name": fn.get("name", ""),
"description": fn.get("description", ""),
"input_schema": fn.get("parameters", {"type": "object", "properties": {}}),
"input_schema": _normalize_tool_input_schema(
fn.get("parameters", {"type": "object", "properties": {}})
),
})
return result

View File

@ -544,6 +544,36 @@ class TestConvertTools:
assert convert_tools_to_anthropic([]) == []
assert convert_tools_to_anthropic(None) == []
def test_strips_nullable_union_from_input_schema(self):
    """A nullable ``anyOf`` property is collapsed to its non-null branch."""
    # ``timeout`` is optional: an integer-or-null union with a null default.
    parameters = {
        "type": "object",
        "properties": {
            "command": {"type": "string"},
            "timeout": {
                "anyOf": [{"type": "integer"}, {"type": "null"}],
                "default": None,
            },
        },
        "required": ["command"],
    }
    tool = {
        "type": "function",
        "function": {
            "name": "run",
            "description": "Run command",
            "parameters": parameters,
        },
    }

    converted = convert_tools_to_anthropic([tool])
    input_schema = converted[0]["input_schema"]

    # The union is gone, the default survives, and ``required`` is untouched.
    expected_timeout = {"type": "integer", "default": None}
    assert input_schema["properties"]["timeout"] == expected_timeout
    assert input_schema["required"] == ["command"]
# ---------------------------------------------------------------------------
# Message conversion

View File

@ -266,6 +266,56 @@ class TestSchemaConversion:
assert schema["properties"]["items"]["items"]["properties"] == {}
def test_optional_nullable_field_is_collapsed_to_non_null_schema(self):
    """Anthropic rejects MCP/Pydantic anyOf-null optional parameter schemas."""
    from tools.mcp_tool import _normalize_mcp_input_schema

    # Optional field expressed as a string-or-null union that also carries
    # ``default`` and ``description`` metadata on the union node.
    nullable_workdir = {
        "anyOf": [{"type": "string"}, {"type": "null"}],
        "default": None,
        "description": "Optional working directory",
    }
    raw_schema = {
        "type": "object",
        "properties": {
            "command": {"type": "string"},
            "workdir": nullable_workdir,
        },
        "required": ["command"],
    }

    normalized = _normalize_mcp_input_schema(raw_schema)

    # Metadata must be carried over onto the surviving non-null branch.
    expected_workdir = {
        "type": "string",
        "default": None,
        "description": "Optional working directory",
    }
    assert normalized["properties"]["workdir"] == expected_workdir
    assert normalized["required"] == ["command"]
def test_nested_nullable_array_items_are_collapsed(self):
    """A nullable ``oneOf`` nested under array ``items`` is collapsed too."""
    from tools.mcp_tool import _normalize_mcp_input_schema

    # The union sits two levels down: properties -> filters -> items.
    nullable_items = {
        "oneOf": [
            {
                "type": "object",
                "properties": {"field": {"type": "string"}},
            },
            {"type": "null"},
        ]
    }
    raw_schema = {
        "type": "object",
        "properties": {
            "filters": {"type": "array", "items": nullable_items},
        },
    }

    normalized = _normalize_mcp_input_schema(raw_schema)
    collapsed_items = normalized["properties"]["filters"]["items"]

    assert collapsed_items == {
        "type": "object",
        "properties": {"field": {"type": "string"}},
    }
def test_convert_mcp_schema_survives_missing_inputschema_attribute(self):
"""A Tool object without .inputSchema must not crash registration."""
import types
@ -1910,15 +1960,38 @@ class TestUtilityToolRegistration:
import math
import time
from mcp.types import (
CreateMessageResult,
CreateMessageResultWithTools,
ErrorData,
SamplingCapability,
SamplingToolsCapability,
TextContent,
ToolUseContent,
)
class _CompatType:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
try:
from mcp.types import (
CreateMessageResult,
ErrorData,
SamplingCapability,
TextContent,
)
except ImportError:
CreateMessageResult = _CompatType
ErrorData = _CompatType
SamplingCapability = _CompatType
TextContent = _CompatType
try:
from mcp.types import CreateMessageResultWithTools
except ImportError:
CreateMessageResultWithTools = _CompatType
try:
from mcp.types import SamplingToolsCapability
except ImportError:
SamplingToolsCapability = _CompatType
try:
from mcp.types import ToolUseContent
except ImportError:
ToolUseContent = _CompatType
from tools.mcp_tool import SamplingHandler, _safe_numeric

View File

@ -868,6 +868,7 @@ class MCPServerTask:
"_task", "_ready", "_shutdown_event", "_reconnect_event",
"_tools", "_error", "_config",
"_sampling", "_registered_tool_names", "_auth_type", "_refresh_lock",
"_rpc_lock",
)
def __init__(self, name: str):
@ -890,6 +891,12 @@ class MCPServerTask:
self._registered_tool_names: list[str] = []
self._auth_type: str = ""
self._refresh_lock = asyncio.Lock()
# MCP stdio sessions are a single JSON-RPC stream. Some servers emit
# list_changed notifications during startup; if the notification
# handler calls list_tools while a normal tool call is in flight, the
# stream can wedge and the user-visible tool call times out. Serialize
# client-initiated RPCs per server.
self._rpc_lock = asyncio.Lock()
def _is_http(self) -> bool:
"""Check if this server uses HTTP transport."""
@ -916,7 +923,16 @@ class MCPServerTask:
"MCP server '%s': received tools/list_changed notification",
self.name,
)
await self._refresh_tools()
# Some servers (notably mongodb-mcp-server) emit
# tools/list_changed immediately after initialize,
# while the client may already be executing another
# request. Refreshing synchronously inside the SDK
# notification handler can race with that request
# and wedge the stdio JSON-RPC stream, making all
# subsequent tool calls time out. Do the refresh in
# a separate task and let the handler return
# promptly.
asyncio.create_task(self._refresh_tools())
case PromptListChangedNotification():
logger.debug("MCP server '%s': prompts/list_changed (ignored)", self.name)
case ResourceListChangedNotification():
@ -942,12 +958,15 @@ class MCPServerTask:
old_tool_names = set(self._registered_tool_names)
# 1. Fetch current tool list from server
tools_result = await self.session.list_tools()
async with self._rpc_lock:
tools_result = await self.session.list_tools()
new_mcp_tools = tools_result.tools if hasattr(tools_result, "tools") else []
# 2. Deregister old tools from the central registry
for prefixed_name in self._registered_tool_names:
registry.deregister(prefixed_name)
# 2. Deliberately do NOT deregister the old tools first: live agent
# turns already have tool-call IDs pointing at the existing handler
# functions. Replacing entries in place (step 3) is enough for
# unchanged names and avoids transient "tool not connected" /
# stale-handler races during startup notifications.
# 3. Re-register with fresh tool list
self._tools = new_mcp_tools
@ -1204,7 +1223,8 @@ class MCPServerTask:
"""Discover tools from the connected session."""
if self.session is None:
return
tools_result = await self.session.list_tools()
async with self._rpc_lock:
tools_result = await self.session.list_tools()
self._tools = (
tools_result.tools
if hasattr(tools_result, "tools")
@ -1954,7 +1974,8 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
}, ensure_ascii=False)
async def _call():
result = await server.session.call_tool(tool_name, arguments=args)
async with server._rpc_lock:
result = await server.session.call_tool(tool_name, arguments=args)
# MCP CallToolResult has .content (list of content blocks) and .isError
if result.isError:
error_text = ""
@ -2052,7 +2073,8 @@ def _make_list_resources_handler(server_name: str, tool_timeout: float):
}, ensure_ascii=False)
async def _call():
result = await server.session.list_resources()
async with server._rpc_lock:
result = await server.session.list_resources()
resources = []
for r in (result.resources if hasattr(result, "resources") else []):
entry = {}
@ -2115,7 +2137,8 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
return tool_error("Missing required parameter 'uri'")
async def _call():
result = await server.session.read_resource(uri)
async with server._rpc_lock:
result = await server.session.read_resource(uri)
# read_resource returns ReadResourceResult with .contents list
parts: List[str] = []
contents = result.contents if hasattr(result, "contents") else []
@ -2168,7 +2191,8 @@ def _make_list_prompts_handler(server_name: str, tool_timeout: float):
}, ensure_ascii=False)
async def _call():
result = await server.session.list_prompts()
async with server._rpc_lock:
result = await server.session.list_prompts()
prompts = []
for p in (result.prompts if hasattr(result, "prompts") else []):
entry = {}
@ -2237,7 +2261,8 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
arguments = args.get("arguments", {})
async def _call():
result = await server.session.get_prompt(name, arguments=arguments)
async with server._rpc_lock:
result = await server.session.get_prompt(name, arguments=arguments)
# GetPromptResult has .messages list
messages = []
for msg in (result.messages if hasattr(result, "messages") else []):
@ -2321,6 +2346,11 @@ def _normalize_mcp_input_schema(schema: dict | None) -> dict:
* ``required`` arrays are pruned to only names that exist in
``properties``; otherwise Google AI Studio / Gemini 400s with
``property is not defined``. See PR #4651.
* MCP/Pydantic optional fields commonly arrive as
``anyOf: [{...}, {"type": "null"}], default: null``. Anthropic rejects
nullable branches in tool input schemas, so nullable unions are collapsed
to the non-null branch and optionality remains represented solely by the
parent object's ``required`` list.
All repairs are provider-agnostic and ideally produce a schema valid on
OpenAI, Anthropic, Gemini, and Moonshot in one pass.
@ -2342,6 +2372,30 @@ def _normalize_mcp_input_schema(schema: dict | None) -> dict:
return [_rewrite_local_refs(item) for item in node]
return node
def _strip_nullable_union(node):
"""Collapse JSON Schema nullable unions to provider-safe non-null schemas."""
if isinstance(node, list):
return [_strip_nullable_union(item) for item in node]
if not isinstance(node, dict):
return node
stripped = {k: _strip_nullable_union(v) for k, v in node.items()}
for key in ("anyOf", "oneOf"):
variants = stripped.get(key)
if not isinstance(variants, list):
continue
non_null = [
item for item in variants
if not (isinstance(item, dict) and item.get("type") == "null")
]
if len(non_null) == 1 and len(non_null) != len(variants):
replacement = dict(non_null[0]) if isinstance(non_null[0], dict) else {}
for meta_key in ("title", "description", "default", "examples"):
if meta_key in stripped and meta_key not in replacement:
replacement[meta_key] = stripped[meta_key]
return _strip_nullable_union(replacement)
return stripped
def _repair_object_shape(node):
"""Recursively repair object-shaped nodes: fill type, prune required."""
if isinstance(node, list):
@ -2381,6 +2435,7 @@ def _normalize_mcp_input_schema(schema: dict | None) -> dict:
return repaired
normalized = _rewrite_local_refs(schema)
normalized = _strip_nullable_union(normalized)
normalized = _repair_object_shape(normalized)
# Ensure top-level is a well-formed object schema