mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fix PydanticSchemaGenerationError when using from __future__ import annotations with @tool (#4822)
* Fix PydanticSchemaGenerationError with PEP 563 annotations in @tool _resolve_input_model used raw param.annotation from inspect.signature(), which returns string annotations when 'from __future__ import annotations' is active (PEP 563). This caused Pydantic's create_model to fail for complex types like Optional[int] or FunctionInvocationContext. Use typing.get_type_hints() to resolve annotations to actual types before passing them to create_model, matching the approach already used by _discover_injected_parameters. Fixes #4809 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Apply pre-commit auto-fixes * Remove reproduction report and unused test imports Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix(tests): strengthen PEP 563 regression tests per review feedback (#4809) - Verify type correctness in schema assertions (not just key presence) - Fix ctx annotation to FunctionInvocationContext | None for type consistency - Add test for Optional[CustomType] pattern (original bug trigger) - Add test for get_type_hints() fallback with unresolvable forward refs Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Apply pre-commit auto-fixes * Address review feedback for #4809: Python: [Bug]: PydanticSchemaGenerationError in FunctionInvocationContext --------- Co-authored-by: Copilot <copilot@github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
5070c67d0e
commit
cb96347c95
@@ -466,9 +466,15 @@ class FunctionTool(SerializationMixin):
|
||||
if func is None:
|
||||
return create_model(f"{self.name}_input")
|
||||
sig = inspect.signature(func)
|
||||
try:
|
||||
type_hints = typing.get_type_hints(func, include_extras=True)
|
||||
except Exception:
|
||||
type_hints = {}
|
||||
fields: dict[str, Any] = {
|
||||
pname: (
|
||||
_parse_annotation(param.annotation) if param.annotation is not inspect.Parameter.empty else str,
|
||||
_parse_annotation(type_hints.get(pname, param.annotation))
|
||||
if type_hints.get(pname, param.annotation) is not inspect.Parameter.empty
|
||||
else str,
|
||||
param.default if param.default is not inspect.Parameter.empty else ...,
|
||||
)
|
||||
for pname, param in sig.parameters.items()
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for @tool with PEP 563 (from __future__ import annotations).
|
||||
|
||||
When ``from __future__ import annotations`` is active, all annotations
|
||||
become strings. _resolve_input_model must resolve them via
|
||||
typing.get_type_hints() before passing them to Pydantic's create_model.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agent_framework import tool
|
||||
from agent_framework._middleware import FunctionInvocationContext
|
||||
|
||||
|
||||
class SearchConfig(BaseModel):
|
||||
max_results: int = 10
|
||||
|
||||
|
||||
def test_tool_with_context_parameter():
|
||||
"""FunctionInvocationContext parameter is excluded from schema under PEP 563."""
|
||||
|
||||
@tool
|
||||
def get_weather(location: str, ctx: FunctionInvocationContext) -> str:
|
||||
"""Get the weather for a given location."""
|
||||
return f"Weather in {location}"
|
||||
|
||||
params = get_weather.parameters()
|
||||
assert "ctx" not in params.get("properties", {})
|
||||
assert "location" in params["properties"]
|
||||
|
||||
|
||||
def test_tool_with_context_parameter_first():
|
||||
"""FunctionInvocationContext as the first parameter is excluded under PEP 563."""
|
||||
|
||||
@tool
|
||||
def get_weather(ctx: FunctionInvocationContext, location: str) -> str:
|
||||
"""Get the weather for a given location."""
|
||||
return f"Weather in {location}"
|
||||
|
||||
params = get_weather.parameters()
|
||||
assert "ctx" not in params.get("properties", {})
|
||||
assert "location" in params["properties"]
|
||||
|
||||
|
||||
def test_tool_with_optional_param():
|
||||
"""Optional[int] is resolved to the actual type, not left as a string."""
|
||||
|
||||
@tool
|
||||
def search(query: str, limit: int | None = None) -> str:
|
||||
"""Search for something."""
|
||||
return query
|
||||
|
||||
params = search.parameters()
|
||||
assert params["properties"]["query"]["type"] == "string"
|
||||
limit_schema = params["properties"]["limit"]
|
||||
limit_types = {t["type"] for t in limit_schema["anyOf"]}
|
||||
assert limit_types == {"integer", "null"}
|
||||
|
||||
|
||||
def test_tool_with_optional_param_and_context():
|
||||
"""Optional param + FunctionInvocationContext both work under PEP 563."""
|
||||
|
||||
@tool
|
||||
def search(query: str, limit: int | None = None, ctx: FunctionInvocationContext | None = None) -> str:
|
||||
"""Search for something."""
|
||||
return query
|
||||
|
||||
params = search.parameters()
|
||||
assert params["properties"]["query"]["type"] == "string"
|
||||
limit_schema = params["properties"]["limit"]
|
||||
limit_types = {t["type"] for t in limit_schema["anyOf"]}
|
||||
assert limit_types == {"integer", "null"}
|
||||
assert "ctx" not in params.get("properties", {})
|
||||
|
||||
|
||||
def test_tool_with_optional_custom_type():
|
||||
"""Optional[CustomType] is resolved under PEP 563 (original bug pattern)."""
|
||||
|
||||
@tool
|
||||
def search(query: str, config: SearchConfig | None = None) -> str:
|
||||
"""Search for something."""
|
||||
return query
|
||||
|
||||
params = search.parameters()
|
||||
assert params["properties"]["query"]["type"] == "string"
|
||||
config_schema = params["properties"]["config"]
|
||||
config_types = [t.get("type") for t in config_schema["anyOf"]]
|
||||
assert "null" in config_types
|
||||
|
||||
|
||||
def test_tool_with_unresolvable_forward_ref():
|
||||
"""Fallback to raw annotations when get_type_hints() fails."""
|
||||
import types
|
||||
|
||||
# Build a function in an isolated namespace so get_type_hints() cannot resolve
|
||||
# the forward reference, exercising the except-branch fallback.
|
||||
ns: dict = {}
|
||||
exec(
|
||||
"def greet(name: str = 'world') -> str:\n '''Greet someone.'''\n return f'Hello {name}'\n",
|
||||
ns,
|
||||
)
|
||||
func = ns["greet"]
|
||||
# Place the function in a throwaway module so get_type_hints() will fail on
|
||||
# any non-builtin forward ref while still having a valid __module__.
|
||||
mod = types.ModuleType("_phantom")
|
||||
func.__module__ = mod.__name__
|
||||
|
||||
t = tool(func)
|
||||
params = t.parameters()
|
||||
assert params["properties"]["name"]["type"] == "string"
|
||||
|
||||
|
||||
async def test_tool_invoke_with_context():
|
||||
"""Full invocation with FunctionInvocationContext under PEP 563."""
|
||||
|
||||
@tool
|
||||
def get_weather(location: str, ctx: FunctionInvocationContext) -> str:
|
||||
"""Get the weather for a given location."""
|
||||
user = ctx.kwargs.get("user", "anon")
|
||||
return f"Weather in {location} for {user}"
|
||||
|
||||
params = get_weather.parameters()
|
||||
assert "ctx" not in params.get("properties", {})
|
||||
|
||||
context = FunctionInvocationContext(
|
||||
function=get_weather,
|
||||
arguments=get_weather.input_model(location="Seattle"),
|
||||
kwargs={"user": "test_user"},
|
||||
)
|
||||
result = await get_weather.invoke(context=context)
|
||||
assert result[0].text == "Weather in Seattle for test_user"
|
||||
Reference in New Issue
Block a user