Python: Fix: MCP tool calls flatten nested JSON arguments (handle $ref schemas) (#990)

* Initial plan

* Fix nested JSON argument flattening in MCP tools - handle $ref schemas

Co-authored-by: dmytrostruk <13853051+dmytrostruk@users.noreply.github.com>

* Apply code formatting fixes from ruff linter

Co-authored-by: dmytrostruk <13853051+dmytrostruk@users.noreply.github.com>

* Fixed formatting

* Refactor JSON type mapping to use match-case statement

Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: dmytrostruk <13853051+dmytrostruk@users.noreply.github.com>
Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com>
This commit is contained in:
Copilot
2025-09-30 17:22:46 -07:00
committed by GitHub
Unverified
parent fa9f5c1aed
commit bdfa62c834
2 changed files with 152 additions and 13 deletions
+34 -13
View File
@@ -181,27 +181,48 @@ def _get_input_model_from_mcp_tool(tool: types.Tool) -> type[BaseModel]:
"""Creates a Pydantic model from a tools parameters."""
properties = tool.inputSchema.get("properties", None)
required = tool.inputSchema.get("required", [])
definitions = tool.inputSchema.get("$defs", {})
# Check if 'properties' is missing or not a dictionary
if not properties:
return create_model(f"{tool.name}_input")
def resolve_type(prop_details: dict[str, Any]) -> type:
"""Resolve JSON Schema type to Python type, handling $ref."""
# Handle $ref by resolving the reference
if "$ref" in prop_details:
ref = prop_details["$ref"]
# Extract the reference path (e.g., "#/$defs/CustomerIdParam" -> "CustomerIdParam")
if ref.startswith("#/$defs/"):
def_name = ref.split("/")[-1]
if def_name in definitions:
# Resolve the reference and use its type
resolved = definitions[def_name]
return resolve_type(resolved)
# If we can't resolve the ref, default to dict for safety
return dict
# Map JSON Schema types to Python types
json_type = prop_details.get("type", "string")
match json_type:
case "integer":
return int
case "number":
return float
case "boolean":
return bool
case "array":
return list
case "object":
return dict
case _:
return str # default
field_definitions: dict[str, Any] = {}
for prop_name, prop_details in properties.items():
prop_details = json.loads(prop_details) if isinstance(prop_details, str) else prop_details
# Map JSON Schema types to Python types
json_type = prop_details.get("type", "string")
python_type: type = str # default
if json_type == "integer":
python_type = int
elif json_type == "number":
python_type = float
elif json_type == "boolean":
python_type = bool
elif json_type == "array":
python_type = list
elif json_type == "object":
python_type = dict
python_type = resolve_type(prop_details)
# Create field definition for create_model
if prop_name in required:
+118
View File
@@ -252,6 +252,71 @@ def test_get_input_model_from_mcp_tool():
model(param2=42)
def test_get_input_model_from_mcp_tool_with_nested_object():
"""Test creation of input model from MCP tool with nested object property."""
tool = types.Tool(
name="get_customer_detail",
description="Get customer details",
inputSchema={
"type": "object",
"properties": {
"params": {
"type": "object",
"properties": {"customer_id": {"type": "integer"}},
"required": ["customer_id"],
}
},
"required": ["params"],
},
)
model = _get_input_model_from_mcp_tool(tool)
# Create an instance to verify the model works with nested objects
instance = model(params={"customer_id": 251})
assert instance.params == {"customer_id": 251}
assert isinstance(instance.params, dict)
# Verify model_dump produces the correct nested structure
dumped = instance.model_dump()
assert dumped == {"params": {"customer_id": 251}}
def test_get_input_model_from_mcp_tool_with_ref_schema():
"""Test creation of input model from MCP tool with $ref schema.
This simulates a FastMCP tool that uses Pydantic models with $ref in the schema.
The schema should be resolved and nested objects should be preserved.
"""
# This is similar to what FastMCP generates when you have:
# async def get_customer_detail(params: CustomerIdParam) -> CustomerDetail
tool = types.Tool(
name="get_customer_detail",
description="Get customer details",
inputSchema={
"type": "object",
"properties": {"params": {"$ref": "#/$defs/CustomerIdParam"}},
"required": ["params"],
"$defs": {
"CustomerIdParam": {
"type": "object",
"properties": {"customer_id": {"type": "integer"}},
"required": ["customer_id"],
}
},
},
)
model = _get_input_model_from_mcp_tool(tool)
# Create an instance to verify the model works with $ref schemas
instance = model(params={"customer_id": 251})
assert instance.params == {"customer_id": 251}
assert isinstance(instance.params, dict)
# Verify model_dump produces the correct nested structure
dumped = instance.model_dump()
assert dumped == {"params": {"customer_id": 251}}
def test_get_input_model_from_mcp_prompt():
"""Test creation of input model from MCP prompt."""
prompt = types.Prompt(
@@ -406,6 +471,59 @@ async def test_local_mcp_server_function_execution():
assert result[0].text == "Tool executed successfully"
async def test_local_mcp_server_function_execution_with_nested_object():
"""Test function execution through MCP server with nested object arguments."""
class TestServer(MCPTool):
async def connect(self):
self.session = Mock(spec=ClientSession)
self.session.list_tools = AsyncMock(
return_value=types.ListToolsResult(
tools=[
types.Tool(
name="get_customer_detail",
description="Get customer details",
inputSchema={
"type": "object",
"properties": {
"params": {
"type": "object",
"properties": {"customer_id": {"type": "integer"}},
"required": ["customer_id"],
}
},
"required": ["params"],
},
)
]
)
)
self.session.call_tool = AsyncMock(
return_value=types.CallToolResult(
content=[types.TextContent(type="text", text='{"name": "John Doe", "id": 251}')]
)
)
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
return None
server = TestServer(name="test_server")
async with server:
await server.load_tools()
func = server.functions[0]
# Call with nested object
result = await func.invoke(params={"customer_id": 251})
assert len(result) == 1
assert isinstance(result[0], TextContent)
# Verify the session.call_tool was called with the correct nested structure
server.session.call_tool.assert_called_once()
call_args = server.session.call_tool.call_args
assert call_args.kwargs["arguments"] == {"params": {"customer_id": 251}}
async def test_local_mcp_server_function_execution_error():
"""Test function execution error handling."""