fix(mcp): normalize nullable tool schemas
This commit is contained in:
parent
9cd02b1698
commit
02ae152222
@ -1117,6 +1117,49 @@ def _sanitize_tool_id(tool_id: str) -> str:
|
||||
return sanitized or "tool_0"
|
||||
|
||||
|
||||
def _normalize_tool_input_schema(schema: Any) -> Dict[str, Any]:
|
||||
"""Normalize tool schemas before sending them to Anthropic.
|
||||
|
||||
Anthropic's tool schema validator rejects nullable unions such as
|
||||
``anyOf: [{"type": "string"}, {"type": "null"}]`` that Pydantic/MCP
|
||||
commonly emits for optional fields. Tool optionality is represented by
|
||||
the parent ``required`` array, so collapse nullable unions to the non-null
|
||||
branch while preserving metadata like description/default.
|
||||
"""
|
||||
if not schema:
|
||||
return {"type": "object", "properties": {}}
|
||||
|
||||
def _strip_nullable_union(node: Any) -> Any:
|
||||
if isinstance(node, list):
|
||||
return [_strip_nullable_union(item) for item in node]
|
||||
if not isinstance(node, dict):
|
||||
return node
|
||||
|
||||
stripped = {k: _strip_nullable_union(v) for k, v in node.items()}
|
||||
for key in ("anyOf", "oneOf"):
|
||||
variants = stripped.get(key)
|
||||
if not isinstance(variants, list):
|
||||
continue
|
||||
non_null = [
|
||||
item for item in variants
|
||||
if not (isinstance(item, dict) and item.get("type") == "null")
|
||||
]
|
||||
if len(non_null) == 1 and len(non_null) != len(variants):
|
||||
replacement = dict(non_null[0]) if isinstance(non_null[0], dict) else {}
|
||||
for meta_key in ("title", "description", "default", "examples"):
|
||||
if meta_key in stripped and meta_key not in replacement:
|
||||
replacement[meta_key] = stripped[meta_key]
|
||||
return _strip_nullable_union(replacement)
|
||||
return stripped
|
||||
|
||||
normalized = _strip_nullable_union(schema)
|
||||
if not isinstance(normalized, dict):
|
||||
return {"type": "object", "properties": {}}
|
||||
if normalized.get("type") == "object" and not isinstance(normalized.get("properties"), dict):
|
||||
normalized = {**normalized, "properties": {}}
|
||||
return normalized
|
||||
|
||||
|
||||
def convert_tools_to_anthropic(tools: List[Dict]) -> List[Dict]:
|
||||
"""Convert OpenAI tool definitions to Anthropic format."""
|
||||
if not tools:
|
||||
@ -1127,7 +1170,9 @@ def convert_tools_to_anthropic(tools: List[Dict]) -> List[Dict]:
|
||||
result.append({
|
||||
"name": fn.get("name", ""),
|
||||
"description": fn.get("description", ""),
|
||||
"input_schema": fn.get("parameters", {"type": "object", "properties": {}}),
|
||||
"input_schema": _normalize_tool_input_schema(
|
||||
fn.get("parameters", {"type": "object", "properties": {}})
|
||||
),
|
||||
})
|
||||
return result
|
||||
|
||||
|
||||
@ -544,6 +544,36 @@ class TestConvertTools:
|
||||
assert convert_tools_to_anthropic([]) == []
|
||||
assert convert_tools_to_anthropic(None) == []
|
||||
|
||||
def test_strips_nullable_union_from_input_schema(self):
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "run",
|
||||
"description": "Run command",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {"type": "string"},
|
||||
"timeout": {
|
||||
"anyOf": [{"type": "integer"}, {"type": "null"}],
|
||||
"default": None,
|
||||
},
|
||||
},
|
||||
"required": ["command"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
result = convert_tools_to_anthropic(tools)
|
||||
|
||||
assert result[0]["input_schema"]["properties"]["timeout"] == {
|
||||
"type": "integer",
|
||||
"default": None,
|
||||
}
|
||||
assert result[0]["input_schema"]["required"] == ["command"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message conversion
|
||||
|
||||
@ -266,6 +266,56 @@ class TestSchemaConversion:
|
||||
|
||||
assert schema["properties"]["items"]["items"]["properties"] == {}
|
||||
|
||||
def test_optional_nullable_field_is_collapsed_to_non_null_schema(self):
|
||||
"""Anthropic rejects MCP/Pydantic anyOf-null optional parameter schemas."""
|
||||
from tools.mcp_tool import _normalize_mcp_input_schema
|
||||
|
||||
schema = _normalize_mcp_input_schema({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {"type": "string"},
|
||||
"workdir": {
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
"default": None,
|
||||
"description": "Optional working directory",
|
||||
},
|
||||
},
|
||||
"required": ["command"],
|
||||
})
|
||||
|
||||
assert schema["properties"]["workdir"] == {
|
||||
"type": "string",
|
||||
"default": None,
|
||||
"description": "Optional working directory",
|
||||
}
|
||||
assert schema["required"] == ["command"]
|
||||
|
||||
def test_nested_nullable_array_items_are_collapsed(self):
|
||||
from tools.mcp_tool import _normalize_mcp_input_schema
|
||||
|
||||
schema = _normalize_mcp_input_schema({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filters": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"field": {"type": "string"}},
|
||||
},
|
||||
{"type": "null"},
|
||||
]
|
||||
},
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
assert schema["properties"]["filters"]["items"] == {
|
||||
"type": "object",
|
||||
"properties": {"field": {"type": "string"}},
|
||||
}
|
||||
|
||||
def test_convert_mcp_schema_survives_missing_inputschema_attribute(self):
|
||||
"""A Tool object without .inputSchema must not crash registration."""
|
||||
import types
|
||||
@ -1910,15 +1960,38 @@ class TestUtilityToolRegistration:
|
||||
import math
|
||||
import time
|
||||
|
||||
from mcp.types import (
|
||||
CreateMessageResult,
|
||||
CreateMessageResultWithTools,
|
||||
ErrorData,
|
||||
SamplingCapability,
|
||||
SamplingToolsCapability,
|
||||
TextContent,
|
||||
ToolUseContent,
|
||||
)
|
||||
class _CompatType:
|
||||
def __init__(self, **kwargs):
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
|
||||
try:
|
||||
from mcp.types import (
|
||||
CreateMessageResult,
|
||||
ErrorData,
|
||||
SamplingCapability,
|
||||
TextContent,
|
||||
)
|
||||
except ImportError:
|
||||
CreateMessageResult = _CompatType
|
||||
ErrorData = _CompatType
|
||||
SamplingCapability = _CompatType
|
||||
TextContent = _CompatType
|
||||
|
||||
try:
|
||||
from mcp.types import CreateMessageResultWithTools
|
||||
except ImportError:
|
||||
CreateMessageResultWithTools = _CompatType
|
||||
|
||||
try:
|
||||
from mcp.types import SamplingToolsCapability
|
||||
except ImportError:
|
||||
SamplingToolsCapability = _CompatType
|
||||
|
||||
try:
|
||||
from mcp.types import ToolUseContent
|
||||
except ImportError:
|
||||
ToolUseContent = _CompatType
|
||||
|
||||
from tools.mcp_tool import SamplingHandler, _safe_numeric
|
||||
|
||||
|
||||
@ -868,6 +868,7 @@ class MCPServerTask:
|
||||
"_task", "_ready", "_shutdown_event", "_reconnect_event",
|
||||
"_tools", "_error", "_config",
|
||||
"_sampling", "_registered_tool_names", "_auth_type", "_refresh_lock",
|
||||
"_rpc_lock",
|
||||
)
|
||||
|
||||
def __init__(self, name: str):
|
||||
@ -890,6 +891,12 @@ class MCPServerTask:
|
||||
self._registered_tool_names: list[str] = []
|
||||
self._auth_type: str = ""
|
||||
self._refresh_lock = asyncio.Lock()
|
||||
# MCP stdio sessions are a single JSON-RPC stream. Some servers emit
|
||||
# list_changed notifications during startup; if the notification
|
||||
# handler calls list_tools while a normal tool call is in flight, the
|
||||
# stream can wedge and the user-visible tool call times out. Serialize
|
||||
# client-initiated RPCs per server.
|
||||
self._rpc_lock = asyncio.Lock()
|
||||
|
||||
def _is_http(self) -> bool:
|
||||
"""Check if this server uses HTTP transport."""
|
||||
@ -916,7 +923,16 @@ class MCPServerTask:
|
||||
"MCP server '%s': received tools/list_changed notification",
|
||||
self.name,
|
||||
)
|
||||
await self._refresh_tools()
|
||||
# Some servers (notably mongodb-mcp-server) emit
|
||||
# tools/list_changed immediately after initialize,
|
||||
# while the client may already be executing another
|
||||
# request. Refreshing synchronously inside the SDK
|
||||
# notification handler can race with that request
|
||||
# and wedge the stdio JSON-RPC stream, making all
|
||||
# subsequent tool calls time out. Do the refresh in
|
||||
# a separate task and let the handler return
|
||||
# promptly.
|
||||
asyncio.create_task(self._refresh_tools())
|
||||
case PromptListChangedNotification():
|
||||
logger.debug("MCP server '%s': prompts/list_changed (ignored)", self.name)
|
||||
case ResourceListChangedNotification():
|
||||
@ -942,12 +958,15 @@ class MCPServerTask:
|
||||
old_tool_names = set(self._registered_tool_names)
|
||||
|
||||
# 1. Fetch current tool list from server
|
||||
tools_result = await self.session.list_tools()
|
||||
async with self._rpc_lock:
|
||||
tools_result = await self.session.list_tools()
|
||||
new_mcp_tools = tools_result.tools if hasattr(tools_result, "tools") else []
|
||||
|
||||
# 2. Deregister old tools from the central registry
|
||||
for prefixed_name in self._registered_tool_names:
|
||||
registry.deregister(prefixed_name)
|
||||
# 2. Re-register with fresh tool list. Avoid deregistering first:
|
||||
# live agent turns already have tool-call IDs pointing at the
|
||||
# existing handler functions. Replacing entries in-place is enough
|
||||
# for unchanged names and avoids transient "tool not connected" /
|
||||
# stale-handler races during startup notifications.
|
||||
|
||||
# 3. Re-register with fresh tool list
|
||||
self._tools = new_mcp_tools
|
||||
@ -1204,7 +1223,8 @@ class MCPServerTask:
|
||||
"""Discover tools from the connected session."""
|
||||
if self.session is None:
|
||||
return
|
||||
tools_result = await self.session.list_tools()
|
||||
async with self._rpc_lock:
|
||||
tools_result = await self.session.list_tools()
|
||||
self._tools = (
|
||||
tools_result.tools
|
||||
if hasattr(tools_result, "tools")
|
||||
@ -1954,7 +1974,8 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
|
||||
}, ensure_ascii=False)
|
||||
|
||||
async def _call():
|
||||
result = await server.session.call_tool(tool_name, arguments=args)
|
||||
async with server._rpc_lock:
|
||||
result = await server.session.call_tool(tool_name, arguments=args)
|
||||
# MCP CallToolResult has .content (list of content blocks) and .isError
|
||||
if result.isError:
|
||||
error_text = ""
|
||||
@ -2052,7 +2073,8 @@ def _make_list_resources_handler(server_name: str, tool_timeout: float):
|
||||
}, ensure_ascii=False)
|
||||
|
||||
async def _call():
|
||||
result = await server.session.list_resources()
|
||||
async with server._rpc_lock:
|
||||
result = await server.session.list_resources()
|
||||
resources = []
|
||||
for r in (result.resources if hasattr(result, "resources") else []):
|
||||
entry = {}
|
||||
@ -2115,7 +2137,8 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
|
||||
return tool_error("Missing required parameter 'uri'")
|
||||
|
||||
async def _call():
|
||||
result = await server.session.read_resource(uri)
|
||||
async with server._rpc_lock:
|
||||
result = await server.session.read_resource(uri)
|
||||
# read_resource returns ReadResourceResult with .contents list
|
||||
parts: List[str] = []
|
||||
contents = result.contents if hasattr(result, "contents") else []
|
||||
@ -2168,7 +2191,8 @@ def _make_list_prompts_handler(server_name: str, tool_timeout: float):
|
||||
}, ensure_ascii=False)
|
||||
|
||||
async def _call():
|
||||
result = await server.session.list_prompts()
|
||||
async with server._rpc_lock:
|
||||
result = await server.session.list_prompts()
|
||||
prompts = []
|
||||
for p in (result.prompts if hasattr(result, "prompts") else []):
|
||||
entry = {}
|
||||
@ -2237,7 +2261,8 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
|
||||
arguments = args.get("arguments", {})
|
||||
|
||||
async def _call():
|
||||
result = await server.session.get_prompt(name, arguments=arguments)
|
||||
async with server._rpc_lock:
|
||||
result = await server.session.get_prompt(name, arguments=arguments)
|
||||
# GetPromptResult has .messages list
|
||||
messages = []
|
||||
for msg in (result.messages if hasattr(result, "messages") else []):
|
||||
@ -2321,6 +2346,11 @@ def _normalize_mcp_input_schema(schema: dict | None) -> dict:
|
||||
* ``required`` arrays are pruned to only names that exist in
|
||||
``properties``; otherwise Google AI Studio / Gemini 400s with
|
||||
``property is not defined``. See PR #4651.
|
||||
* MCP/Pydantic optional fields commonly arrive as
|
||||
``anyOf: [{...}, {"type": "null"}], default: null``. Anthropic rejects
|
||||
nullable branches in tool input schemas, so nullable unions are collapsed
|
||||
to the non-null branch and optionality remains represented solely by the
|
||||
parent object's ``required`` list.
|
||||
|
||||
All repairs are provider-agnostic and ideally produce a schema valid on
|
||||
OpenAI, Anthropic, Gemini, and Moonshot in one pass.
|
||||
@ -2342,6 +2372,30 @@ def _normalize_mcp_input_schema(schema: dict | None) -> dict:
|
||||
return [_rewrite_local_refs(item) for item in node]
|
||||
return node
|
||||
|
||||
def _strip_nullable_union(node):
|
||||
"""Collapse JSON Schema nullable unions to provider-safe non-null schemas."""
|
||||
if isinstance(node, list):
|
||||
return [_strip_nullable_union(item) for item in node]
|
||||
if not isinstance(node, dict):
|
||||
return node
|
||||
|
||||
stripped = {k: _strip_nullable_union(v) for k, v in node.items()}
|
||||
for key in ("anyOf", "oneOf"):
|
||||
variants = stripped.get(key)
|
||||
if not isinstance(variants, list):
|
||||
continue
|
||||
non_null = [
|
||||
item for item in variants
|
||||
if not (isinstance(item, dict) and item.get("type") == "null")
|
||||
]
|
||||
if len(non_null) == 1 and len(non_null) != len(variants):
|
||||
replacement = dict(non_null[0]) if isinstance(non_null[0], dict) else {}
|
||||
for meta_key in ("title", "description", "default", "examples"):
|
||||
if meta_key in stripped and meta_key not in replacement:
|
||||
replacement[meta_key] = stripped[meta_key]
|
||||
return _strip_nullable_union(replacement)
|
||||
return stripped
|
||||
|
||||
def _repair_object_shape(node):
|
||||
"""Recursively repair object-shaped nodes: fill type, prune required."""
|
||||
if isinstance(node, list):
|
||||
@ -2381,6 +2435,7 @@ def _normalize_mcp_input_schema(schema: dict | None) -> dict:
|
||||
return repaired
|
||||
|
||||
normalized = _rewrite_local_refs(schema)
|
||||
normalized = _strip_nullable_union(normalized)
|
||||
normalized = _repair_object_shape(normalized)
|
||||
|
||||
# Ensure top-level is a well-formed object schema
|
||||
|
||||
Loading…
Reference in New Issue
Block a user