mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fix: MCP tool calls flatten nested JSON arguments (handle $ref schemas) (#990)
* Initial plan * Fix nested JSON argument flattening in MCP tools - handle $ref schemas Co-authored-by: dmytrostruk <13853051+dmytrostruk@users.noreply.github.com> * Apply code formatting fixes from ruff linter Co-authored-by: dmytrostruk <13853051+dmytrostruk@users.noreply.github.com> * Fixed formatting * Refactor JSON type mapping to use match-case statement Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: dmytrostruk <13853051+dmytrostruk@users.noreply.github.com> Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
fa9f5c1aed
commit
bdfa62c834
@@ -181,27 +181,48 @@ def _get_input_model_from_mcp_tool(tool: types.Tool) -> type[BaseModel]:
|
||||
"""Creates a Pydantic model from a tools parameters."""
|
||||
properties = tool.inputSchema.get("properties", None)
|
||||
required = tool.inputSchema.get("required", [])
|
||||
definitions = tool.inputSchema.get("$defs", {})
|
||||
|
||||
# Check if 'properties' is missing or not a dictionary
|
||||
if not properties:
|
||||
return create_model(f"{tool.name}_input")
|
||||
|
||||
def resolve_type(prop_details: dict[str, Any]) -> type:
|
||||
"""Resolve JSON Schema type to Python type, handling $ref."""
|
||||
# Handle $ref by resolving the reference
|
||||
if "$ref" in prop_details:
|
||||
ref = prop_details["$ref"]
|
||||
# Extract the reference path (e.g., "#/$defs/CustomerIdParam" -> "CustomerIdParam")
|
||||
if ref.startswith("#/$defs/"):
|
||||
def_name = ref.split("/")[-1]
|
||||
if def_name in definitions:
|
||||
# Resolve the reference and use its type
|
||||
resolved = definitions[def_name]
|
||||
return resolve_type(resolved)
|
||||
# If we can't resolve the ref, default to dict for safety
|
||||
return dict
|
||||
|
||||
# Map JSON Schema types to Python types
|
||||
json_type = prop_details.get("type", "string")
|
||||
match json_type:
|
||||
case "integer":
|
||||
return int
|
||||
case "number":
|
||||
return float
|
||||
case "boolean":
|
||||
return bool
|
||||
case "array":
|
||||
return list
|
||||
case "object":
|
||||
return dict
|
||||
case _:
|
||||
return str # default
|
||||
|
||||
field_definitions: dict[str, Any] = {}
|
||||
for prop_name, prop_details in properties.items():
|
||||
prop_details = json.loads(prop_details) if isinstance(prop_details, str) else prop_details
|
||||
|
||||
# Map JSON Schema types to Python types
|
||||
json_type = prop_details.get("type", "string")
|
||||
python_type: type = str # default
|
||||
if json_type == "integer":
|
||||
python_type = int
|
||||
elif json_type == "number":
|
||||
python_type = float
|
||||
elif json_type == "boolean":
|
||||
python_type = bool
|
||||
elif json_type == "array":
|
||||
python_type = list
|
||||
elif json_type == "object":
|
||||
python_type = dict
|
||||
python_type = resolve_type(prop_details)
|
||||
|
||||
# Create field definition for create_model
|
||||
if prop_name in required:
|
||||
|
||||
@@ -252,6 +252,71 @@ def test_get_input_model_from_mcp_tool():
|
||||
model(param2=42)
|
||||
|
||||
|
||||
def test_get_input_model_from_mcp_tool_with_nested_object():
|
||||
"""Test creation of input model from MCP tool with nested object property."""
|
||||
tool = types.Tool(
|
||||
name="get_customer_detail",
|
||||
description="Get customer details",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"params": {
|
||||
"type": "object",
|
||||
"properties": {"customer_id": {"type": "integer"}},
|
||||
"required": ["customer_id"],
|
||||
}
|
||||
},
|
||||
"required": ["params"],
|
||||
},
|
||||
)
|
||||
model = _get_input_model_from_mcp_tool(tool)
|
||||
|
||||
# Create an instance to verify the model works with nested objects
|
||||
instance = model(params={"customer_id": 251})
|
||||
assert instance.params == {"customer_id": 251}
|
||||
assert isinstance(instance.params, dict)
|
||||
|
||||
# Verify model_dump produces the correct nested structure
|
||||
dumped = instance.model_dump()
|
||||
assert dumped == {"params": {"customer_id": 251}}
|
||||
|
||||
|
||||
def test_get_input_model_from_mcp_tool_with_ref_schema():
|
||||
"""Test creation of input model from MCP tool with $ref schema.
|
||||
|
||||
This simulates a FastMCP tool that uses Pydantic models with $ref in the schema.
|
||||
The schema should be resolved and nested objects should be preserved.
|
||||
"""
|
||||
# This is similar to what FastMCP generates when you have:
|
||||
# async def get_customer_detail(params: CustomerIdParam) -> CustomerDetail
|
||||
tool = types.Tool(
|
||||
name="get_customer_detail",
|
||||
description="Get customer details",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {"params": {"$ref": "#/$defs/CustomerIdParam"}},
|
||||
"required": ["params"],
|
||||
"$defs": {
|
||||
"CustomerIdParam": {
|
||||
"type": "object",
|
||||
"properties": {"customer_id": {"type": "integer"}},
|
||||
"required": ["customer_id"],
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
model = _get_input_model_from_mcp_tool(tool)
|
||||
|
||||
# Create an instance to verify the model works with $ref schemas
|
||||
instance = model(params={"customer_id": 251})
|
||||
assert instance.params == {"customer_id": 251}
|
||||
assert isinstance(instance.params, dict)
|
||||
|
||||
# Verify model_dump produces the correct nested structure
|
||||
dumped = instance.model_dump()
|
||||
assert dumped == {"params": {"customer_id": 251}}
|
||||
|
||||
|
||||
def test_get_input_model_from_mcp_prompt():
|
||||
"""Test creation of input model from MCP prompt."""
|
||||
prompt = types.Prompt(
|
||||
@@ -406,6 +471,59 @@ async def test_local_mcp_server_function_execution():
|
||||
assert result[0].text == "Tool executed successfully"
|
||||
|
||||
|
||||
async def test_local_mcp_server_function_execution_with_nested_object():
|
||||
"""Test function execution through MCP server with nested object arguments."""
|
||||
|
||||
class TestServer(MCPTool):
|
||||
async def connect(self):
|
||||
self.session = Mock(spec=ClientSession)
|
||||
self.session.list_tools = AsyncMock(
|
||||
return_value=types.ListToolsResult(
|
||||
tools=[
|
||||
types.Tool(
|
||||
name="get_customer_detail",
|
||||
description="Get customer details",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"params": {
|
||||
"type": "object",
|
||||
"properties": {"customer_id": {"type": "integer"}},
|
||||
"required": ["customer_id"],
|
||||
}
|
||||
},
|
||||
"required": ["params"],
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
self.session.call_tool = AsyncMock(
|
||||
return_value=types.CallToolResult(
|
||||
content=[types.TextContent(type="text", text='{"name": "John Doe", "id": 251}')]
|
||||
)
|
||||
)
|
||||
|
||||
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
|
||||
return None
|
||||
|
||||
server = TestServer(name="test_server")
|
||||
async with server:
|
||||
await server.load_tools()
|
||||
func = server.functions[0]
|
||||
|
||||
# Call with nested object
|
||||
result = await func.invoke(params={"customer_id": 251})
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], TextContent)
|
||||
|
||||
# Verify the session.call_tool was called with the correct nested structure
|
||||
server.session.call_tool.assert_called_once()
|
||||
call_args = server.session.call_tool.call_args
|
||||
assert call_args.kwargs["arguments"] == {"params": {"customer_id": 251}}
|
||||
|
||||
|
||||
async def test_local_mcp_server_function_execution_error():
|
||||
"""Test function execution error handling."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user