From bdfa62c83409bc75cd20dd6474e063db3af00cb3 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Tue, 30 Sep 2025 17:22:46 -0700 Subject: [PATCH] Python: Fix: MCP tool calls flatten nested JSON arguments (handle $ref schemas) (#990) * Initial plan * Fix nested JSON argument flattening in MCP tools - handle $ref schemas Co-authored-by: dmytrostruk <13853051+dmytrostruk@users.noreply.github.com> * Apply code formatting fixes from ruff linter Co-authored-by: dmytrostruk <13853051+dmytrostruk@users.noreply.github.com> * Fixed formatting * Refactor JSON type mapping to use match-case statement Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: dmytrostruk <13853051+dmytrostruk@users.noreply.github.com> Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com> --- python/packages/core/agent_framework/_mcp.py | 47 ++++++-- python/packages/core/tests/core/test_mcp.py | 118 +++++++++++++++++++ 2 files changed, 152 insertions(+), 13 deletions(-) diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index af48f137d3..6de34cbe0f 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -181,27 +181,48 @@ def _get_input_model_from_mcp_tool(tool: types.Tool) -> type[BaseModel]: """Creates a Pydantic model from a tools parameters.""" properties = tool.inputSchema.get("properties", None) required = tool.inputSchema.get("required", []) + definitions = tool.inputSchema.get("$defs", {}) + # Check if 'properties' is missing or not a dictionary if not properties: return create_model(f"{tool.name}_input") + def resolve_type(prop_details: dict[str, Any]) -> type: + """Resolve JSON Schema type to Python type, handling $ref.""" + # Handle $ref by resolving the reference + if "$ref" in prop_details: + ref = prop_details["$ref"] + # Extract the reference path (e.g., "#/$defs/CustomerIdParam" -> "CustomerIdParam") + if ref.startswith("#/$defs/"): + def_name = ref.split("/")[-1] + if def_name in definitions: + # Resolve the reference and use its type + resolved = definitions[def_name] + return resolve_type(resolved) + # If we can't resolve the ref, default to dict for safety + return dict + + # Map JSON Schema types to Python types + json_type = prop_details.get("type", "string") + match json_type: + case "integer": + return int + case "number": + return float + case "boolean": + return bool + case "array": + return list + case "object": + return dict + case _: + return str # default + field_definitions: dict[str, Any] = {} for prop_name, prop_details in properties.items(): prop_details = json.loads(prop_details) if isinstance(prop_details, str) else prop_details - # Map JSON Schema types to Python types - json_type = prop_details.get("type", "string") - python_type: type = str # default - if json_type == "integer": - python_type = int - elif json_type == "number": - python_type = float - elif json_type == "boolean": - python_type = bool - elif json_type == "array": - python_type = list - elif json_type == "object": - python_type = dict + python_type = resolve_type(prop_details) # Create field definition for create_model if prop_name in required: diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index df768fee8d..ae230c7239 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -252,6 +252,71 @@ def test_get_input_model_from_mcp_tool(): model(param2=42) +def test_get_input_model_from_mcp_tool_with_nested_object(): + """Test creation of input model from MCP tool with nested object property.""" + tool = types.Tool( + name="get_customer_detail", + description="Get customer details", + inputSchema={ + "type": "object", + "properties": { + "params": { + "type": "object", + "properties": {"customer_id": {"type": "integer"}}, + "required": ["customer_id"], + } + }, + "required": ["params"], + }, + ) + model = _get_input_model_from_mcp_tool(tool) + + # Create an instance to verify the model works with nested objects + instance = model(params={"customer_id": 251}) + assert instance.params == {"customer_id": 251} + assert isinstance(instance.params, dict) + + # Verify model_dump produces the correct nested structure + dumped = instance.model_dump() + assert dumped == {"params": {"customer_id": 251}} + + +def test_get_input_model_from_mcp_tool_with_ref_schema(): + """Test creation of input model from MCP tool with $ref schema. + + This simulates a FastMCP tool that uses Pydantic models with $ref in the schema. + The schema should be resolved and nested objects should be preserved. + """ + # This is similar to what FastMCP generates when you have: + # async def get_customer_detail(params: CustomerIdParam) -> CustomerDetail + tool = types.Tool( + name="get_customer_detail", + description="Get customer details", + inputSchema={ + "type": "object", + "properties": {"params": {"$ref": "#/$defs/CustomerIdParam"}}, + "required": ["params"], + "$defs": { + "CustomerIdParam": { + "type": "object", + "properties": {"customer_id": {"type": "integer"}}, + "required": ["customer_id"], + } + }, + }, + ) + model = _get_input_model_from_mcp_tool(tool) + + # Create an instance to verify the model works with $ref schemas + instance = model(params={"customer_id": 251}) + assert instance.params == {"customer_id": 251} + assert isinstance(instance.params, dict) + + # Verify model_dump produces the correct nested structure + dumped = instance.model_dump() + assert dumped == {"params": {"customer_id": 251}} + + def test_get_input_model_from_mcp_prompt(): """Test creation of input model from MCP prompt.""" prompt = types.Prompt( @@ -406,6 +471,59 @@ async def test_local_mcp_server_function_execution(): assert result[0].text == "Tool executed successfully" +async def test_local_mcp_server_function_execution_with_nested_object(): + """Test function execution through MCP server with nested object arguments.""" + + class TestServer(MCPTool): + async def connect(self): + self.session = Mock(spec=ClientSession) + self.session.list_tools = AsyncMock( + return_value=types.ListToolsResult( + tools=[ + types.Tool( + name="get_customer_detail", + description="Get customer details", + inputSchema={ + "type": "object", + "properties": { + "params": { + "type": "object", + "properties": {"customer_id": {"type": "integer"}}, + "required": ["customer_id"], + } + }, + "required": ["params"], + }, + ) + ] + ) + ) + self.session.call_tool = AsyncMock( + return_value=types.CallToolResult( + content=[types.TextContent(type="text", text='{"name": "John Doe", "id": 251}')] + ) + ) + + def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: + return None + + server = TestServer(name="test_server") + async with server: + await server.load_tools() + func = server.functions[0] + + # Call with nested object + result = await func.invoke(params={"customer_id": 251}) + + assert len(result) == 1 + assert isinstance(result[0], TextContent) + + # Verify the session.call_tool was called with the correct nested structure + server.session.call_tool.assert_called_once() + call_args = server.session.call_tool.call_args + assert call_args.kwargs["arguments"] == {"params": {"customer_id": 251}} + + async def test_local_mcp_server_function_execution_error(): """Test function execution error handling."""