Python: [BREAKING] Remove FunctionTool[Any] compatibility shim for schema passthrough (#3600) (#3907)

* Fix #3600: Pass JSON schemas through without Pydantic conversion

This change optimizes FunctionTool and MCP flows by passing JSON schemas
directly to providers without converting them to Pydantic models first.

Key changes:
- Store JSON schema as-is when supplied to FunctionTool
- Skip Pydantic model_validate for schema-supplied tools in invoke()
- Return MCP tool schemas directly without conversion
- Add comprehensive tests for schema passthrough behavior

Performance benefits:
- Eliminates expensive Pydantic model creation for supplied schemas
- Preserves exact schema structure (additionalProperties, custom fields, etc.)
- Reduces memory overhead and initialization time

Maintains backward compatibility:
- Function signature inference still uses Pydantic models
- Explicit Pydantic models passed as input_model work as before
- All existing tests pass

* Fix schema passthrough validation and remove helper

* Simplify FunctionTool without generic model dependency

* Fix FunctionTool typing fallout in 3600

* Remove FunctionTool[Any] compatibility shim

* Use serializable kwargs in OTEL tool args
This commit is contained in:
Eduard van Valkenburg
2026-02-14 11:12:21 +01:00
committed by GitHub
Unverified
parent cd1e3110aa
commit fc9c81b0b1
16 changed files with 645 additions and 482 deletions
@@ -267,7 +267,7 @@ class AGUIChatClient(
if any(getattr(tool, "name", None) == tool_name for tool in additional_tools):
return
placeholder: FunctionTool[Any] = FunctionTool(
placeholder: FunctionTool = FunctionTool(
name=tool_name,
description="Server-managed tool placeholder (AG-UI)",
func=None,
@@ -162,7 +162,7 @@ def make_json_safe(obj: Any) -> Any: # noqa: ANN401
def convert_agui_tools_to_agent_framework(
agui_tools: list[dict[str, Any]] | None,
) -> list[FunctionTool[Any]] | None:
) -> list[FunctionTool] | None:
"""Convert AG-UI tool definitions to Agent Framework FunctionTool declarations.
Creates declaration-only FunctionTool instances (no executable implementation).
@@ -181,13 +181,13 @@ def convert_agui_tools_to_agent_framework(
if not agui_tools:
return None
result: list[FunctionTool[Any]] = []
result: list[FunctionTool] = []
for tool_def in agui_tools:
# Create declaration-only FunctionTool (func=None means no implementation)
# When func=None, the declaration_only property returns True,
# which tells the function invocation mixin to return the function call
# without executing it (so it can be sent back to the client)
func: FunctionTool[Any] = FunctionTool(
func: FunctionTool = FunctionTool(
name=tool_def.get("name", ""),
description=tool_def.get("description", ""),
func=None, # CRITICAL: Makes declaration_only=True
@@ -5,7 +5,7 @@
from __future__ import annotations
import sys
from typing import TYPE_CHECKING, Any, TypedDict
from typing import TYPE_CHECKING, TypedDict
from agent_framework import Agent, FunctionTool, SupportsChatGetResponse
from agent_framework.ag_ui import AgentFrameworkAgent
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
from agent_framework import ChatOptions
# Declaration-only tools (func=None) - actual rendering happens on the client side
generate_haiku = FunctionTool[Any](
generate_haiku = FunctionTool(
name="generate_haiku",
description="""Generate a haiku with image and gradient background (FRONTEND_RENDER).
@@ -71,7 +71,7 @@ generate_haiku = FunctionTool[Any](
},
)
create_chart = FunctionTool[Any](
create_chart = FunctionTool(
name="create_chart",
description="""Create an interactive chart (FRONTEND_RENDER).
@@ -99,7 +99,7 @@ create_chart = FunctionTool[Any](
},
)
display_timeline = FunctionTool[Any](
display_timeline = FunctionTool(
name="display_timeline",
description="""Display an interactive timeline (FRONTEND_RENDER).
@@ -127,7 +127,7 @@ display_timeline = FunctionTool[Any](
},
)
show_comparison_table = FunctionTool[Any](
show_comparison_table = FunctionTool(
name="show_comparison_table",
description="""Show a comparison table (FRONTEND_RENDER).
@@ -484,7 +484,7 @@ class ClaudeAgent(BaseAgent, Generic[OptionsT]):
return create_sdk_mcp_server(name=TOOLS_MCP_SERVER_NAME, tools=sdk_tools), tool_names
def _function_tool_to_sdk_mcp_tool(self, func_tool: FunctionTool[Any]) -> SdkMcpTool[Any]:
def _function_tool_to_sdk_mcp_tool(self, func_tool: FunctionTool) -> SdkMcpTool[Any]:
"""Convert a FunctionTool to an SDK MCP tool.
Args:
@@ -439,7 +439,7 @@ class BaseAgent(SerializationMixin):
stream_callback: Callable[[AgentResponseUpdate], None]
| Callable[[AgentResponseUpdate], Awaitable[None]]
| None = None,
) -> FunctionTool[BaseModel]:
) -> FunctionTool:
"""Create a FunctionTool that wraps this agent.
Keyword Args:
@@ -513,7 +513,7 @@ class BaseAgent(SerializationMixin):
# Create final text from accumulated updates
return AgentResponse.from_updates(response_updates).text
agent_tool: FunctionTool[BaseModel] = FunctionTool(
agent_tool: FunctionTool = FunctionTool(
name=tool_name,
description=tool_description,
func=agent_wrapper,
@@ -1258,17 +1258,12 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
@server.list_tools() # type: ignore
async def _list_tools() -> list[types.Tool]: # type: ignore
"""List all tools in the agent."""
# Get the JSON schema from the Pydantic model
schema = agent_tool.input_model.model_json_schema()
schema = agent_tool.parameters()
tool = types.Tool(
name=agent_tool.name,
description=agent_tool.description,
inputSchema={
"type": "object",
"properties": schema.get("properties", {}),
"required": schema.get("required", []),
},
inputSchema=schema,
)
await _log(level="debug", data=f"Agent tool: {agent_tool}")
@@ -1291,7 +1286,9 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
# Create an instance of the input model with the arguments
try:
args_instance = agent_tool.input_model(**arguments)
args_instance: BaseModel | dict[str, Any] = (
agent_tool.input_model(**arguments) if agent_tool.input_model is not None else arguments
)
result = await agent_tool.invoke(arguments=args_instance)
except Exception as e:
raise McpError(
+15 -18
View File
@@ -24,11 +24,9 @@ from mcp.client.websocket import websocket_client
from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError
from mcp.shared.session import RequestResponder
from pydantic import BaseModel, create_model
from ._tools import (
FunctionTool,
_build_pydantic_model_from_json_schema,
)
from ._types import (
Content,
@@ -355,11 +353,14 @@ def _prepare_message_for_mcp(
return messages
def _get_input_model_from_mcp_prompt(prompt: types.Prompt) -> type[BaseModel]:
"""Creates a Pydantic model from a prompt's parameters."""
def _get_input_model_from_mcp_prompt(prompt: types.Prompt) -> dict[str, Any]:
"""Get the input model from an MCP prompt.
Returns a JSON schema dictionary for prompt arguments.
"""
# Check if 'arguments' is missing or empty
if not prompt.arguments:
return create_model(f"{prompt.name}_input")
return {"type": "object", "properties": {}}
# Convert prompt arguments to JSON schema format
properties: dict[str, Any] = {}
@@ -374,13 +375,10 @@ def _get_input_model_from_mcp_prompt(prompt: types.Prompt) -> type[BaseModel]:
if prompt_argument.required:
required.append(prompt_argument.name)
schema = {"properties": properties, "required": required}
return _build_pydantic_model_from_json_schema(prompt.name, schema)
def _get_input_model_from_mcp_tool(tool: types.Tool) -> type[BaseModel]:
"""Creates a Pydantic model from a tools parameters."""
return _build_pydantic_model_from_json_schema(tool.name, tool.inputSchema)
schema: dict[str, Any] = {"type": "object", "properties": properties}
if required:
schema["required"] = required
return schema
def _normalize_mcp_name(name: str) -> str:
@@ -467,7 +465,7 @@ class MCPTool:
self.session = session
self.request_timeout = request_timeout
self.client = client
self._functions: list[FunctionTool[Any]] = []
self._functions: list[FunctionTool] = []
self.is_connected: bool = False
self._tools_loaded: bool = False
self._prompts_loaded: bool = False
@@ -476,7 +474,7 @@ class MCPTool:
return f"MCPTool(name={self.name}, description={self.description})"
@property
def functions(self) -> list[FunctionTool[Any]]:
def functions(self) -> list[FunctionTool]:
"""Get the list of functions that are allowed."""
if not self.allowed_tools:
return self._functions
@@ -744,7 +742,7 @@ class MCPTool:
input_model = _get_input_model_from_mcp_prompt(prompt)
approval_mode = self._determine_approval_mode(local_name)
func: FunctionTool[BaseModel] = FunctionTool(
func: FunctionTool = FunctionTool(
func=partial(self.get_prompt, prompt.name),
name=local_name,
description=prompt.description or "",
@@ -785,15 +783,14 @@ class MCPTool:
if local_name in existing_names:
continue
input_model = _get_input_model_from_mcp_tool(tool)
approval_mode = self._determine_approval_mode(local_name)
# Create FunctionTools out of each tool
func: FunctionTool[BaseModel] = FunctionTool(
func: FunctionTool = FunctionTool(
func=partial(self.call_tool, tool.name),
name=local_name,
description=tool.description or "",
approval_mode=approval_mode,
input_model=input_model,
input_model=tool.inputSchema,
)
self._functions.append(func)
existing_names.add(local_name)
@@ -234,8 +234,8 @@ class FunctionInvocationContext:
def __init__(
self,
function: FunctionTool[Any],
arguments: BaseModel,
function: FunctionTool,
arguments: BaseModel | Mapping[str, Any],
metadata: Mapping[str, Any] | None = None,
result: Any = None,
kwargs: Mapping[str, Any] | None = None,
+181 -77
View File
@@ -26,7 +26,6 @@ from typing import (
Literal,
TypedDict,
Union,
cast,
get_args,
get_origin,
overload,
@@ -89,8 +88,6 @@ DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST: Final[int] = 3
ChatClientT = TypeVar("ChatClientT", bound="SupportsChatGetResponse[Any]")
# region Helpers
ArgsT = TypeVar("ArgsT", bound=BaseModel, default=BaseModel)
def _parse_inputs(
inputs: Content | dict[str, Any] | str | list[Content | dict[str, Any] | str] | None,
@@ -183,11 +180,7 @@ def _default_histogram() -> Histogram:
ClassT = TypeVar("ClassT", bound="SerializationMixin")
class EmptyInputModel(BaseModel):
"""An empty input model for functions with no parameters."""
class FunctionTool(SerializationMixin, Generic[ArgsT]):
class FunctionTool(SerializationMixin):
"""A tool that wraps a Python function to make it callable by AI models.
This class wraps a Python function to make it callable by AI models with automatic
@@ -240,6 +233,8 @@ class FunctionTool(SerializationMixin, Generic[ArgsT]):
"input_model",
"_invocation_duration_histogram",
"_cached_parameters",
"_input_schema",
"_schema_supplied",
}
def __init__(
@@ -252,7 +247,7 @@ class FunctionTool(SerializationMixin, Generic[ArgsT]):
max_invocation_exceptions: int | None = None,
additional_properties: dict[str, Any] | None = None,
func: Callable[..., Any] | None = None,
input_model: type[ArgsT] | Mapping[str, Any] | None = None,
input_model: type[BaseModel] | Mapping[str, Any] | None = None,
result_parser: Callable[[Any], str] | None = None,
**kwargs: Any,
) -> None:
@@ -299,7 +294,16 @@ class FunctionTool(SerializationMixin, Generic[ArgsT]):
# FunctionTool-specific attributes
self.func = func
self._instance = None # Store the instance for bound methods
self.input_model = self._resolve_input_model(input_model)
# Track if schema was supplied as JSON dict (for optimization)
if isinstance(input_model, Mapping):
self._schema_supplied = True
self._input_schema: dict[str, Any] = dict(input_model)
self.input_model: type[BaseModel] | None = None
else:
self._schema_supplied = False
self.input_model = self._resolve_input_model(input_model)
self._input_schema = self.input_model.model_json_schema()
self._cached_parameters: dict[str, Any] | None = None
self.approval_mode = approval_mode or "never_require"
if max_invocations is not None and max_invocations < 1:
@@ -335,7 +339,7 @@ class FunctionTool(SerializationMixin, Generic[ArgsT]):
return True
return self.func is None
def __get__(self, obj: Any, objtype: type | None = None) -> FunctionTool[ArgsT]:
def __get__(self, obj: Any, objtype: type | None = None) -> FunctionTool:
"""Implement the descriptor protocol to support bound methods.
When a FunctionTool is accessed as an attribute of a class instance,
@@ -366,17 +370,30 @@ class FunctionTool(SerializationMixin, Generic[ArgsT]):
return self
def _resolve_input_model(self, input_model: type[ArgsT] | Mapping[str, Any] | None) -> type[ArgsT]:
def _resolve_input_model(self, input_model: type[BaseModel] | None) -> type[BaseModel]:
"""Resolve the input model for the function."""
if input_model is None:
if self.func is None:
return cast(type[ArgsT], EmptyInputModel)
return cast(type[ArgsT], _create_input_model_from_func(func=self.func, name=self.name))
if inspect.isclass(input_model) and issubclass(input_model, BaseModel):
return input_model
if isinstance(input_model, Mapping):
return cast(type[ArgsT], _create_model_from_json_schema(self.name, input_model))
raise TypeError("input_model must be a Pydantic BaseModel subclass or a JSON schema dict.")
if input_model is not None:
if inspect.isclass(input_model) and issubclass(input_model, BaseModel):
return input_model
raise TypeError("input_model must be a Pydantic BaseModel subclass or a JSON schema dict.")
if self.func is None:
return create_model(f"{self.name}_input")
func = self.func.func if isinstance(self.func, FunctionTool) else self.func
if func is None:
return create_model(f"{self.name}_input")
sig = inspect.signature(func)
fields: dict[str, Any] = {
pname: (
_parse_annotation(param.annotation) if param.annotation is not inspect.Parameter.empty else str,
param.default if param.default is not inspect.Parameter.empty else ...,
)
for pname, param in sig.parameters.items()
if pname not in {"self", "cls"}
and param.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
}
return create_model(f"{self.name}_input", **fields)
def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""Call the wrapped function with the provided arguments."""
@@ -407,7 +424,7 @@ class FunctionTool(SerializationMixin, Generic[ArgsT]):
async def invoke(
self,
*,
arguments: ArgsT | None = None,
arguments: BaseModel | Mapping[str, Any] | None = None,
**kwargs: Any,
) -> str:
"""Run the AI function with the provided arguments as a Pydantic model.
@@ -417,14 +434,14 @@ class FunctionTool(SerializationMixin, Generic[ArgsT]):
``result_parser`` if one was provided.
Keyword Args:
arguments: A Pydantic model instance containing the arguments for the function.
arguments: A mapping or model instance containing the arguments for the function.
kwargs: Keyword arguments to pass to the function, will not be used if ``arguments`` is provided.
Returns:
The parsed result as a string — either plain text or serialized JSON.
Raises:
TypeError: If arguments is not an instance of the expected input model.
TypeError: If arguments is not mapping-like or fails schema checks.
"""
if self.declaration_only:
raise ToolException(f"Function '{self.name}' is declaration only and cannot be invoked.")
@@ -436,9 +453,32 @@ class FunctionTool(SerializationMixin, Generic[ArgsT]):
original_kwargs = dict(kwargs)
tool_call_id = original_kwargs.pop("tool_call_id", None)
if arguments is not None:
if not isinstance(arguments, self.input_model):
raise TypeError(f"Expected {self.input_model.__name__}, got {type(arguments).__name__}")
kwargs = arguments.model_dump(exclude_none=True)
try:
if isinstance(arguments, Mapping):
parsed_arguments = dict(arguments)
if self.input_model is not None and not self._schema_supplied:
parsed_arguments = self.input_model.model_validate(parsed_arguments).model_dump(
exclude_none=True
)
elif isinstance(arguments, BaseModel):
if (
self.input_model is not None
and not self._schema_supplied
and not isinstance(arguments, self.input_model)
):
raise TypeError(f"Expected {self.input_model.__name__}, got {type(arguments).__name__}")
parsed_arguments = arguments.model_dump(exclude_none=True)
else:
raise TypeError(
f"Expected mapping-like arguments for tool '{self.name}', got {type(arguments).__name__}"
)
except ValidationError as exc:
raise TypeError(f"Invalid arguments for '{self.name}': {exc}") from exc
kwargs = _validate_arguments_against_schema(
arguments=parsed_arguments,
schema=self.parameters(),
tool_name=self.name,
)
if getattr(self, "_forward_runtime_kwargs", False) and original_kwargs:
kwargs.update(original_kwargs)
else:
@@ -458,34 +498,34 @@ class FunctionTool(SerializationMixin, Generic[ArgsT]):
return parsed
attributes = get_function_span_attributes(self, tool_call_id=tool_call_id)
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: # type: ignore[name-defined]
# Filter out framework kwargs that are not JSON serializable
serializable_kwargs = {
k: v
for k, v in kwargs.items()
if k
not in {
"chat_options",
"tools",
"tool_choice",
"session",
"conversation_id",
"options",
"response_format",
}
# Filter out framework kwargs that are not JSON serializable.
serializable_kwargs = {
k: v
for k, v in kwargs.items()
if k
not in {
"chat_options",
"tools",
"tool_choice",
"session",
"conversation_id",
"options",
"response_format",
}
}
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: # type: ignore[name-defined]
attributes.update({
OtelAttr.TOOL_ARGUMENTS: arguments.model_dump_json(ensure_ascii=False)
if arguments
else json.dumps(serializable_kwargs, default=str, ensure_ascii=False)
if serializable_kwargs
else "None"
OtelAttr.TOOL_ARGUMENTS: (
json.dumps(serializable_kwargs, default=str, ensure_ascii=False)
if serializable_kwargs
else "None"
)
})
with get_function_span(attributes=attributes) as span:
attributes[OtelAttr.MEASUREMENT_FUNCTION_TAG_NAME] = self.name
logger.info(f"Function name: {self.name}")
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: # type: ignore[name-defined]
logger.debug(f"Function arguments: {kwargs}")
logger.debug(f"Function arguments: {serializable_kwargs}")
start_time_stamp = perf_counter()
end_time_stamp: float | None = None
try:
@@ -523,7 +563,7 @@ class FunctionTool(SerializationMixin, Generic[ArgsT]):
The result is cached after the first call for performance.
"""
if self._cached_parameters is None:
self._cached_parameters = self.input_model.model_json_schema()
self._cached_parameters = self._input_schema
return self._cached_parameters
@staticmethod
@@ -677,23 +717,79 @@ def _parse_annotation(annotation: Any) -> Any:
return annotation
def _create_input_model_from_func(func: Callable[..., Any], name: str) -> type[BaseModel]:
"""Create a Pydantic model from a function's signature."""
# Unwrap FunctionTool objects to get the underlying function
if isinstance(func, FunctionTool):
func = func.func # type: ignore[assignment]
def _matches_json_schema_type(value: Any, schema_type: str) -> bool:
"""Check a value against a simple JSON schema primitive type."""
match schema_type:
case "string":
return isinstance(value, str)
case "integer":
return isinstance(value, int) and not isinstance(value, bool)
case "number":
return (isinstance(value, int | float)) and not isinstance(value, bool)
case "boolean":
return isinstance(value, bool)
case "array":
return isinstance(value, list)
case "object":
return isinstance(value, dict)
case "null":
return value is None
case _:
return True
sig = inspect.signature(func)
fields = {
pname: (
_parse_annotation(param.annotation) if param.annotation is not inspect.Parameter.empty else str,
param.default if param.default is not inspect.Parameter.empty else ...,
)
for pname, param in sig.parameters.items()
if pname not in {"self", "cls"}
and param.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
}
return create_model(f"{name}_input", **fields) # type: ignore[call-overload, no-any-return]
def _validate_arguments_against_schema(
*,
arguments: Mapping[str, Any],
schema: Mapping[str, Any],
tool_name: str,
) -> dict[str, Any]:
"""Run lightweight argument checks for schema-supplied tools."""
parsed_arguments = dict(arguments)
required_raw = schema.get("required", [])
required_fields = [field for field in required_raw if isinstance(field, str)]
missing_fields = [field for field in required_fields if field not in parsed_arguments]
if missing_fields:
raise TypeError(f"Missing required argument(s) for '{tool_name}': {', '.join(sorted(missing_fields))}")
properties_raw = schema.get("properties")
properties = properties_raw if isinstance(properties_raw, Mapping) else {}
if schema.get("additionalProperties") is False:
unexpected_fields = sorted(field for field in parsed_arguments if field not in properties)
if unexpected_fields:
raise TypeError(f"Unexpected argument(s) for '{tool_name}': {', '.join(unexpected_fields)}")
for field_name, field_value in parsed_arguments.items():
field_schema = properties.get(field_name)
if not isinstance(field_schema, Mapping):
continue
enum_values = field_schema.get("enum")
if isinstance(enum_values, list) and enum_values and field_value not in enum_values:
raise TypeError(
f"Invalid value for '{field_name}' in '{tool_name}': {field_value!r} is not in {enum_values!r}"
)
schema_type = field_schema.get("type")
if isinstance(schema_type, str):
if not _matches_json_schema_type(field_value, schema_type):
raise TypeError(
f"Invalid type for '{field_name}' in '{tool_name}': "
f"expected {schema_type}, got {type(field_value).__name__}"
)
continue
if isinstance(schema_type, list):
allowed_types = [item for item in schema_type if isinstance(item, str)]
if allowed_types and not any(_matches_json_schema_type(field_value, item) for item in allowed_types):
raise TypeError(
f"Invalid type for '{field_name}' in '{tool_name}': expected one of "
f"{allowed_types}, got {type(field_value).__name__}"
)
return parsed_arguments
# Map JSON Schema types to Pydantic types
@@ -942,7 +1038,7 @@ def tool(
max_invocation_exceptions: int | None = None,
additional_properties: dict[str, Any] | None = None,
result_parser: Callable[[Any], str] | None = None,
) -> FunctionTool[Any]: ...
) -> FunctionTool: ...
@overload
@@ -957,7 +1053,7 @@ def tool(
max_invocation_exceptions: int | None = None,
additional_properties: dict[str, Any] | None = None,
result_parser: Callable[[Any], str] | None = None,
) -> Callable[[Callable[..., Any]], FunctionTool[Any]]: ...
) -> Callable[[Callable[..., Any]], FunctionTool]: ...
def tool(
@@ -971,7 +1067,7 @@ def tool(
max_invocation_exceptions: int | None = None,
additional_properties: dict[str, Any] | None = None,
result_parser: Callable[[Any], str] | None = None,
) -> FunctionTool[Any] | Callable[[Callable[..., Any]], FunctionTool[Any]]:
) -> FunctionTool | Callable[[Callable[..., Any]], FunctionTool]:
"""Decorate a function to turn it into a FunctionTool that can be passed to models and executed automatically.
This decorator creates a Pydantic model from the function's signature,
@@ -1095,12 +1191,12 @@ def tool(
"""
def decorator(func: Callable[..., Any]) -> FunctionTool[Any]:
def decorator(func: Callable[..., Any]) -> FunctionTool:
@wraps(func)
def wrapper(f: Callable[..., Any]) -> FunctionTool[Any]:
def wrapper(f: Callable[..., Any]) -> FunctionTool:
tool_name: str = name or getattr(f, "__name__", "unknown_function") # type: ignore[assignment]
tool_desc: str = description or (f.__doc__ or "")
return FunctionTool[Any](
return FunctionTool(
name=tool_name,
description=tool_desc,
approval_mode=approval_mode,
@@ -1193,7 +1289,7 @@ async def _auto_invoke_function(
custom_args: dict[str, Any] | None = None,
*,
config: FunctionInvocationConfiguration,
tool_map: dict[str, FunctionTool[BaseModel]],
tool_map: dict[str, FunctionTool],
sequence_index: int | None = None,
request_index: int | None = None,
middleware_pipeline: FunctionMiddlewarePipeline | None = None, # Optional MiddlewarePipeline
@@ -1225,7 +1321,7 @@ async def _auto_invoke_function(
# this function is called. This function only handles the actual execution of approved,
# non-declaration-only functions.
tool: FunctionTool[BaseModel] | None = None
tool: FunctionTool | None = None
if function_call_content.type == "function_call":
tool = tool_map.get(function_call_content.name) # type: ignore[arg-type]
# Tool should exist because _try_execute_function_calls validates this
@@ -1258,8 +1354,16 @@ async def _auto_invoke_function(
if key not in {"_function_middleware_pipeline", "middleware", "conversation_id"}
}
try:
args = tool.input_model.model_validate(parsed_args)
except ValidationError as exc:
if not tool._schema_supplied and tool.input_model is not None:
args = tool.input_model.model_validate(parsed_args).model_dump(exclude_none=True)
else:
args = dict(parsed_args)
args = _validate_arguments_against_schema(
arguments=args,
schema=tool.parameters(),
tool_name=tool.name,
)
except (TypeError, ValidationError) as exc:
message = "Error: Argument parsing failed."
if config["include_detailed_errors"]:
message = f"{message} Exception: {exc}"
@@ -1340,8 +1444,8 @@ def _get_tool_map(
| Callable[..., Any]
| MutableMapping[str, Any]
| Sequence[FunctionTool | Callable[..., Any] | MutableMapping[str, Any]],
) -> dict[str, FunctionTool[Any]]:
tool_list: dict[str, FunctionTool[Any]] = {}
) -> dict[str, FunctionTool]:
tool_list: dict[str, FunctionTool] = {}
for tool_item in tools if isinstance(tools, list) else [tools]:
if isinstance(tool_item, FunctionTool):
tool_list[tool_item.name] = tool_item
@@ -1448,7 +1448,7 @@ class AgentTelemetryLayer:
# region Otel Helpers
def get_function_span_attributes(function: FunctionTool[Any], tool_call_id: str | None = None) -> dict[str, str]:
def get_function_span_attributes(function: FunctionTool, tool_call_id: str | None = None) -> dict[str, str]:
"""Get the span attributes for the given function.
Args:
+320 -338
View File
@@ -10,7 +10,7 @@ import pytest
from mcp import types
from mcp.client.session import ClientSession
from mcp.shared.exceptions import McpError
from pydantic import AnyUrl, BaseModel, ValidationError
from pydantic import AnyUrl, BaseModel
from agent_framework import (
Content,
@@ -22,7 +22,6 @@ from agent_framework import (
from agent_framework._mcp import (
MCPTool,
_get_input_model_from_mcp_prompt,
_get_input_model_from_mcp_tool,
_normalize_mcp_name,
_parse_content_from_mcp,
_parse_message_from_mcp,
@@ -276,363 +275,338 @@ def test_prepare_message_for_mcp():
@pytest.mark.parametrize(
"test_id,input_schema,valid_data,expected_values,invalid_data,validation_check",
"test_id,input_schema",
[
# Basic types with required/optional fields
(
"basic_types",
{
"type": "object",
"properties": {"param1": {"type": "string"}, "param2": {"type": "number"}},
"required": ["param1"],
},
{"param1": "test", "param2": 42},
{"param1": "test", "param2": 42},
{"param2": 42}, # Missing required param1
None,
),
# Nested object
(
"nested_object",
{
"type": "object",
"properties": {
"params": {
"type": "object",
"properties": {"customer_id": {"type": "integer"}},
"required": ["customer_id"],
}
(test_id, input_schema)
for test_id, input_schema, _, _, _, _ in [
# Basic types with required/optional fields
(
"basic_types",
{
"type": "object",
"properties": {"param1": {"type": "string"}, "param2": {"type": "number"}},
"required": ["param1"],
},
"required": ["params"],
},
{"params": {"customer_id": 251}},
{"params.customer_id": 251},
{"params": {}}, # Missing required customer_id
lambda instance: isinstance(instance.params, BaseModel),
),
# $ref resolution
(
"ref_schema",
{
"type": "object",
"properties": {"params": {"$ref": "#/$defs/CustomerIdParam"}},
"required": ["params"],
"$defs": {
"CustomerIdParam": {
"type": "object",
"properties": {"customer_id": {"type": "integer"}},
"required": ["customer_id"],
}
},
},
{"params": {"customer_id": 251}},
{"params.customer_id": 251},
{"params": {}}, # Missing required customer_id
lambda instance: isinstance(instance.params, BaseModel),
),
# Array of strings (typed)
(
"array_of_strings",
{
"type": "object",
"properties": {
"tags": {
"type": "array",
"description": "List of tags",
"items": {"type": "string"},
}
},
"required": ["tags"],
},
{"tags": ["tag1", "tag2", "tag3"]},
{"tags": ["tag1", "tag2", "tag3"]},
None, # No validation error test for this case
None,
),
# Array of integers (typed)
(
"array_of_integers",
{
"type": "object",
"properties": {
"numbers": {
"type": "array",
"description": "List of integers",
"items": {"type": "integer"},
}
},
"required": ["numbers"],
},
{"numbers": [1, 2, 3]},
{"numbers": [1, 2, 3]},
None,
None,
),
# Array of objects (complex nested)
(
"array_of_objects",
{
"type": "object",
"properties": {
"users": {
"type": "array",
"description": "List of users",
"items": {
{"param1": "test", "param2": 42},
{"param1": "test", "param2": 42},
{"param2": 42}, # Missing required param1
None,
),
# Nested object
(
"nested_object",
{
"type": "object",
"properties": {
"params": {
"type": "object",
"properties": {
"id": {"type": "integer", "description": "User ID"},
"name": {"type": "string", "description": "User name"},
},
"required": ["id", "name"],
},
}
"properties": {"customer_id": {"type": "integer"}},
"required": ["customer_id"],
}
},
"required": ["params"],
},
"required": ["users"],
},
{"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]},
{"users[0].id": 1, "users[0].name": "Alice", "users[1].id": 2, "users[1].name": "Bob"},
{"users": [{"id": 1}]}, # Missing required 'name'
lambda instance: all(isinstance(user, BaseModel) for user in instance.users),
),
# Deeply nested objects (3+ levels)
(
"deeply_nested",
{
"type": "object",
"properties": {
"query": {
"type": "object",
"properties": {
"filters": {
{"params": {"customer_id": 251}},
{"params.customer_id": 251},
{"params": {}}, # Missing required customer_id
lambda instance: isinstance(instance.params, BaseModel),
),
# $ref resolution
(
"ref_schema",
{
"type": "object",
"properties": {"params": {"$ref": "#/$defs/CustomerIdParam"}},
"required": ["params"],
"$defs": {
"CustomerIdParam": {
"type": "object",
"properties": {"customer_id": {"type": "integer"}},
"required": ["customer_id"],
}
},
},
{"params": {"customer_id": 251}},
{"params.customer_id": 251},
{"params": {}}, # Missing required customer_id
lambda instance: isinstance(instance.params, BaseModel),
),
# Array of strings (typed)
(
"array_of_strings",
{
"type": "object",
"properties": {
"tags": {
"type": "array",
"description": "List of tags",
"items": {"type": "string"},
}
},
"required": ["tags"],
},
{"tags": ["tag1", "tag2", "tag3"]},
{"tags": ["tag1", "tag2", "tag3"]},
None, # No validation error test for this case
None,
),
# Array of integers (typed)
(
"array_of_integers",
{
"type": "object",
"properties": {
"numbers": {
"type": "array",
"description": "List of integers",
"items": {"type": "integer"},
}
},
"required": ["numbers"],
},
{"numbers": [1, 2, 3]},
{"numbers": [1, 2, 3]},
None,
None,
),
# Array of objects (complex nested)
(
"array_of_objects",
{
"type": "object",
"properties": {
"users": {
"type": "array",
"description": "List of users",
"items": {
"type": "object",
"properties": {
"date_range": {
"type": "object",
"properties": {
"start": {"type": "string"},
"end": {"type": "string"},
},
"required": ["start", "end"],
},
"categories": {"type": "array", "items": {"type": "string"}},
"id": {"type": "integer", "description": "User ID"},
"name": {"type": "string", "description": "User name"},
},
"required": ["date_range"],
}
},
"required": ["filters"],
"required": ["id", "name"],
},
}
},
"required": ["users"],
},
{"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]},
{"users[0].id": 1, "users[0].name": "Alice", "users[1].id": 2, "users[1].name": "Bob"},
{"users": [{"id": 1}]}, # Missing required 'name'
lambda instance: all(isinstance(user, BaseModel) for user in instance.users),
),
# Deeply nested objects (3+ levels)
(
"deeply_nested",
{
"type": "object",
"properties": {
"query": {
"type": "object",
"properties": {
"filters": {
"type": "object",
"properties": {
"date_range": {
"type": "object",
"properties": {
"start": {"type": "string"},
"end": {"type": "string"},
},
"required": ["start", "end"],
},
"categories": {"type": "array", "items": {"type": "string"}},
},
"required": ["date_range"],
}
},
"required": ["filters"],
}
},
"required": ["query"],
},
{
"query": {
"filters": {
"date_range": {"start": "2024-01-01", "end": "2024-12-31"},
"categories": ["tech", "science"],
}
}
},
"required": ["query"],
},
{
"query": {
"filters": {
"date_range": {"start": "2024-01-01", "end": "2024-12-31"},
"categories": ["tech", "science"],
{
"query.filters.date_range.start": "2024-01-01",
"query.filters.date_range.end": "2024-12-31",
"query.filters.categories": ["tech", "science"],
},
{"query": {"filters": {"date_range": {}}}}, # Missing required start and end
None,
),
# Complex $ref with nested structure
(
"ref_nested_structure",
{
"type": "object",
"properties": {"order": {"$ref": "#/$defs/OrderParams"}},
"required": ["order"],
"$defs": {
"OrderParams": {
"type": "object",
"properties": {
"customer": {"$ref": "#/$defs/Customer"},
"items": {"type": "array", "items": {"$ref": "#/$defs/OrderItem"}},
},
"required": ["customer", "items"],
},
"Customer": {
"type": "object",
"properties": {"id": {"type": "integer"}, "email": {"type": "string"}},
"required": ["id", "email"],
},
"OrderItem": {
"type": "object",
"properties": {"product_id": {"type": "string"}, "quantity": {"type": "integer"}},
"required": ["product_id", "quantity"],
},
},
},
{
"order": {
"customer": {"id": 123, "email": "test@example.com"},
"items": [{"product_id": "prod1", "quantity": 2}],
}
}
},
{
"query.filters.date_range.start": "2024-01-01",
"query.filters.date_range.end": "2024-12-31",
"query.filters.categories": ["tech", "science"],
},
{"query": {"filters": {"date_range": {}}}}, # Missing required start and end
None,
),
# Complex $ref with nested structure
(
"ref_nested_structure",
{
"type": "object",
"properties": {"order": {"$ref": "#/$defs/OrderParams"}},
"required": ["order"],
"$defs": {
"OrderParams": {
"type": "object",
"properties": {
"customer": {"$ref": "#/$defs/Customer"},
"items": {"type": "array", "items": {"$ref": "#/$defs/OrderItem"}},
},
{
"order.customer.id": 123,
"order.customer.email": "test@example.com",
"order.items[0].product_id": "prod1",
"order.items[0].quantity": 2,
},
{"order": {"customer": {"id": 123}, "items": []}}, # Missing email
lambda instance: isinstance(instance.order.customer, BaseModel),
),
# Mixed types (primitives, arrays, nested objects)
(
"mixed_types",
{
"type": "object",
"properties": {
"simple_string": {"type": "string"},
"simple_number": {"type": "integer"},
"string_array": {"type": "array", "items": {"type": "string"}},
"nested_config": {
"type": "object",
"properties": {
"enabled": {"type": "boolean"},
"options": {"type": "array", "items": {"type": "string"}},
},
"required": ["enabled"],
},
"required": ["customer", "items"],
},
"Customer": {
"type": "object",
"properties": {"id": {"type": "integer"}, "email": {"type": "string"}},
"required": ["id", "email"],
},
"OrderItem": {
"type": "object",
"properties": {"product_id": {"type": "string"}, "quantity": {"type": "integer"}},
"required": ["product_id", "quantity"],
"required": ["simple_string", "nested_config"],
},
{
"simple_string": "test",
"simple_number": 42,
"string_array": ["a", "b"],
"nested_config": {"enabled": True, "options": ["opt1", "opt2"]},
},
{
"simple_string": "test",
"simple_number": 42,
"string_array": ["a", "b"],
"nested_config.enabled": True,
"nested_config.options": ["opt1", "opt2"],
},
None,
None,
),
# Empty schema (no properties)
(
"empty_schema",
{"type": "object", "properties": {}},
{},
{},
None,
None,
),
# All primitive types
(
"all_primitives",
{
"type": "object",
"properties": {
"string_field": {"type": "string"},
"integer_field": {"type": "integer"},
"number_field": {"type": "number"},
"boolean_field": {"type": "boolean"},
},
},
},
{
"order": {
"customer": {"id": 123, "email": "test@example.com"},
"items": [{"product_id": "prod1", "quantity": 2}],
}
},
{
"order.customer.id": 123,
"order.customer.email": "test@example.com",
"order.items[0].product_id": "prod1",
"order.items[0].quantity": 2,
},
{"order": {"customer": {"id": 123}, "items": []}}, # Missing email
lambda instance: isinstance(instance.order.customer, BaseModel),
),
# Mixed types (primitives, arrays, nested objects)
(
"mixed_types",
{
"type": "object",
"properties": {
"simple_string": {"type": "string"},
"simple_number": {"type": "integer"},
"string_array": {"type": "array", "items": {"type": "string"}},
"nested_config": {
"type": "object",
"properties": {
"enabled": {"type": "boolean"},
"options": {"type": "array", "items": {"type": "string"}},
},
"required": ["enabled"],
},
{"string_field": "test", "integer_field": 42, "number_field": 3.14, "boolean_field": True},
{"string_field": "test", "integer_field": 42, "number_field": 3.14, "boolean_field": True},
None,
None,
),
# Edge case: unresolvable $ref (fallback to dict)
(
"unresolvable_ref",
{
"type": "object",
"properties": {"data": {"$ref": "#/$defs/NonExistent"}},
"$defs": {},
},
"required": ["simple_string", "nested_config"],
},
{
"simple_string": "test",
"simple_number": 42,
"string_array": ["a", "b"],
"nested_config": {"enabled": True, "options": ["opt1", "opt2"]},
},
{
"simple_string": "test",
"simple_number": 42,
"string_array": ["a", "b"],
"nested_config.enabled": True,
"nested_config.options": ["opt1", "opt2"],
},
None,
None,
),
# Empty schema (no properties)
(
"empty_schema",
{"type": "object", "properties": {}},
{},
{},
None,
None,
),
# All primitive types
(
"all_primitives",
{
"type": "object",
"properties": {
"string_field": {"type": "string"},
"integer_field": {"type": "integer"},
"number_field": {"type": "number"},
"boolean_field": {"type": "boolean"},
{"data": {"key": "value"}},
{"data": {"key": "value"}},
None,
None,
),
# Edge case: array without items schema (fallback to bare list)
(
"array_no_items",
{
"type": "object",
"properties": {"items": {"type": "array"}},
},
},
{"string_field": "test", "integer_field": 42, "number_field": 3.14, "boolean_field": True},
{"string_field": "test", "integer_field": 42, "number_field": 3.14, "boolean_field": True},
None,
None,
),
# Edge case: unresolvable $ref (fallback to dict)
(
"unresolvable_ref",
{
"type": "object",
"properties": {"data": {"$ref": "#/$defs/NonExistent"}},
"$defs": {},
},
{"data": {"key": "value"}},
{"data": {"key": "value"}},
None,
None,
),
# Edge case: array without items schema (fallback to bare list)
(
"array_no_items",
{
"type": "object",
"properties": {"items": {"type": "array"}},
},
{"items": [1, "two", 3.0]},
{"items": [1, "two", 3.0]},
None,
None,
),
# Edge case: object without properties (fallback to dict)
(
"object_no_properties",
{
"type": "object",
"properties": {"config": {"type": "object"}},
},
{"config": {"arbitrary": "data", "nested": {"key": "value"}}},
{"config": {"arbitrary": "data", "nested": {"key": "value"}}},
None,
None,
),
{"items": [1, "two", 3.0]},
{"items": [1, "two", 3.0]},
None,
None,
),
# Edge case: object without properties (fallback to dict)
(
"object_no_properties",
{
"type": "object",
"properties": {"config": {"type": "object"}},
},
{"config": {"arbitrary": "data", "nested": {"key": "value"}}},
{"config": {"arbitrary": "data", "nested": {"key": "value"}}},
None,
None,
),
]
],
)
def test_get_input_model_from_mcp_tool_parametrized(
test_id, input_schema, valid_data, expected_values, invalid_data, validation_check
):
"""Parametrized test for JSON schema to Pydantic model conversion.
def test_get_input_model_from_mcp_tool_parametrized(test_id: str, input_schema: dict[str, Any]) -> None:
"""Parametrized test for MCP tool input schema passthrough.
This test covers various edge cases including:
- Basic types with required/optional fields
- Nested objects
- $ref resolution
- Typed arrays (strings, integers, objects)
- Deeply nested structures
- Complex $ref with nested structures
- Mixed types
This test verifies that MCP tool schemas are passed through as-is
without Pydantic conversion, which improves performance and preserves
the original schema structure.
To add a new test case, add a tuple to the parametrize decorator with:
- test_id: A descriptive name for the test case
- input_schema: The JSON schema (inputSchema dict)
- valid_data: Valid data to instantiate the model
- expected_values: Dict of expected values (supports dot notation for nested access)
- invalid_data: Invalid data to test validation errors (None to skip)
- validation_check: Optional callable to perform additional validation checks
"""
tool = types.Tool(name="test_tool", description="A test tool", inputSchema=input_schema)
model = _get_input_model_from_mcp_tool(tool)
schema = tool.inputSchema
# Test valid data
instance = model(**valid_data)
# Check expected values
for field_path, expected_value in expected_values.items():
# Support dot notation and array indexing for nested access
current = instance
parts = field_path.replace("]", "").replace("[", ".").split(".")
for part in parts:
current = current[int(part)] if part.isdigit() else getattr(current, part)
assert current == expected_value, f"Field {field_path} = {current}, expected {expected_value}"
# Run additional validation checks if provided
if validation_check:
assert validation_check(instance), f"Validation check failed for {test_id}"
# Test invalid data if provided
if invalid_data is not None:
with pytest.raises(ValidationError):
model(**invalid_data)
# Verify schema is returned as-is (dict)
assert isinstance(schema, dict), f"Expected dict, got {type(schema)}"
assert schema == input_schema, "Schema should be passed through unchanged"
def test_get_input_model_from_mcp_prompt():
"""Test creation of input model from MCP prompt."""
"""Test creation of input schema from MCP prompt."""
prompt = types.Prompt(
name="test_prompt",
description="A test prompt",
@@ -641,16 +615,24 @@ def test_get_input_model_from_mcp_prompt():
types.PromptArgument(name="arg2", description="Second argument", required=False),
],
)
model = _get_input_model_from_mcp_prompt(prompt)
result = _get_input_model_from_mcp_prompt(prompt)
# Create an instance to verify the model works
instance = model(arg1="test", arg2="optional")
assert instance.arg1 == "test"
assert instance.arg2 == "optional"
# Should return a dict (schema)
assert isinstance(result, dict), f"Expected dict, got {type(result)}"
assert result["type"] == "object"
assert "arg1" in result["properties"]
assert "arg2" in result["properties"]
assert "arg1" in result["required"]
assert "arg2" not in result["required"]
# Test validation
with pytest.raises(ValidationError): # Missing required arg1
model(arg2="optional")
def test_get_input_model_from_mcp_prompt_without_arguments():
"""Test prompt schema generation when no prompt arguments are defined."""
prompt = types.Prompt(name="empty_prompt", description="No args prompt", arguments=[])
result = _get_input_model_from_mcp_prompt(prompt)
assert isinstance(result, dict)
assert result == {"type": "object", "properties": {}}
# MCPTool tests
@@ -74,7 +74,7 @@ class TestAgentContext:
class TestFunctionInvocationContext:
"""Test cases for FunctionInvocationContext."""
def test_init_with_defaults(self, mock_function: FunctionTool[Any]) -> None:
def test_init_with_defaults(self, mock_function: FunctionTool) -> None:
"""Test FunctionInvocationContext initialization with default values."""
arguments = FunctionTestArgs(name="test")
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
@@ -83,7 +83,7 @@ class TestFunctionInvocationContext:
assert context.arguments == arguments
assert context.metadata == {}
def test_init_with_custom_metadata(self, mock_function: FunctionTool[Any]) -> None:
def test_init_with_custom_metadata(self, mock_function: FunctionTool) -> None:
"""Test FunctionInvocationContext initialization with custom metadata."""
arguments = FunctionTestArgs(name="test")
metadata = {"key": "value"}
@@ -420,7 +420,7 @@ class TestFunctionMiddlewarePipeline:
await call_next()
raise MiddlewareTermination
async def test_execute_with_pre_next_termination(self, mock_function: FunctionTool[Any]) -> None:
async def test_execute_with_pre_next_termination(self, mock_function: FunctionTool) -> None:
"""Test pipeline execution with termination before next() raises MiddlewareTermination."""
middleware = self.PreNextTerminateFunctionMiddleware()
pipeline = FunctionMiddlewarePipeline(middleware)
@@ -439,7 +439,7 @@ class TestFunctionMiddlewarePipeline:
# Handler should not be called when terminated before next()
assert execution_order == []
async def test_execute_with_post_next_termination(self, mock_function: FunctionTool[Any]) -> None:
async def test_execute_with_post_next_termination(self, mock_function: FunctionTool) -> None:
"""Test pipeline execution with termination after next() raises MiddlewareTermination."""
middleware = self.PostNextTerminateFunctionMiddleware()
pipeline = FunctionMiddlewarePipeline(middleware)
@@ -480,7 +480,7 @@ class TestFunctionMiddlewarePipeline:
pipeline = FunctionMiddlewarePipeline(test_middleware)
assert pipeline.has_middlewares
async def test_execute_no_middleware(self, mock_function: FunctionTool[Any]) -> None:
async def test_execute_no_middleware(self, mock_function: FunctionTool) -> None:
"""Test pipeline execution with no middleware."""
pipeline = FunctionMiddlewarePipeline()
arguments = FunctionTestArgs(name="test")
@@ -494,7 +494,7 @@ class TestFunctionMiddlewarePipeline:
result = await pipeline.execute(context, final_handler)
assert result == expected_result
async def test_execute_with_middleware(self, mock_function: FunctionTool[Any]) -> None:
async def test_execute_with_middleware(self, mock_function: FunctionTool) -> None:
"""Test pipeline execution with middleware."""
execution_order: list[str] = []
@@ -787,7 +787,7 @@ class TestClassBasedMiddleware:
assert context.metadata["after"] is True
assert metadata_updates == ["before", "handler", "after"]
async def test_function_middleware_execution(self, mock_function: FunctionTool[Any]) -> None:
async def test_function_middleware_execution(self, mock_function: FunctionTool) -> None:
"""Test class-based function middleware execution."""
metadata_updates: list[str] = []
@@ -847,7 +847,7 @@ class TestFunctionBasedMiddleware:
assert context.metadata["function_middleware"] is True
assert execution_order == ["function_before", "handler", "function_after"]
async def test_function_function_middleware(self, mock_function: FunctionTool[Any]) -> None:
async def test_function_function_middleware(self, mock_function: FunctionTool) -> None:
"""Test function-based function middleware."""
execution_order: list[str] = []
@@ -905,7 +905,7 @@ class TestMixedMiddleware:
assert result is not None
assert execution_order == ["class_before", "function_before", "handler", "function_after", "class_after"]
async def test_mixed_function_middleware(self, mock_function: FunctionTool[Any]) -> None:
async def test_mixed_function_middleware(self, mock_function: FunctionTool) -> None:
"""Test mixed class and function-based function middleware."""
execution_order: list[str] = []
@@ -1017,7 +1017,7 @@ class TestMultipleMiddlewareOrdering:
]
assert execution_order == expected_order
async def test_function_middleware_execution_order(self, mock_function: FunctionTool[Any]) -> None:
async def test_function_middleware_execution_order(self, mock_function: FunctionTool) -> None:
"""Test that multiple function middleware execute in registration order."""
execution_order: list[str] = []
@@ -1143,7 +1143,7 @@ class TestContextContentValidation:
result = await pipeline.execute(context, final_handler)
assert result is not None
async def test_function_context_validation(self, mock_function: FunctionTool[Any]) -> None:
async def test_function_context_validation(self, mock_function: FunctionTool) -> None:
"""Test that function context contains expected data."""
class ContextValidationMiddleware(FunctionMiddleware):
@@ -1489,7 +1489,7 @@ class TestMiddlewareExecutionControl:
assert not handler_called
assert context.result is None
async def test_function_middleware_no_next_no_execution(self, mock_function: FunctionTool[Any]) -> None:
async def test_function_middleware_no_next_no_execution(self, mock_function: FunctionTool) -> None:
"""Test that when function middleware doesn't call next(), no execution happens."""
class FunctionTestArgs(BaseModel):
@@ -1666,9 +1666,9 @@ def mock_agent() -> SupportsAgentRun:
@pytest.fixture
def mock_function() -> FunctionTool[Any]:
def mock_function() -> FunctionTool:
"""Mock function for testing."""
function = MagicMock(spec=FunctionTool[Any])
function = MagicMock(spec=FunctionTool)
function.name = "test_function"
return function
@@ -1,7 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import AsyncIterable, Awaitable, Callable
from typing import Any
from unittest.mock import MagicMock
import pytest
@@ -103,7 +102,7 @@ class TestResultOverrideMiddleware:
assert updates[0].text == "overridden"
assert updates[1].text == " stream"
async def test_function_middleware_result_override(self, mock_function: FunctionTool[Any]) -> None:
async def test_function_middleware_result_override(self, mock_function: FunctionTool) -> None:
"""Test that function middleware can override result."""
override_result = "overridden function result"
@@ -252,7 +251,7 @@ class TestResultOverrideMiddleware:
assert execute_result.messages[0].text == "executed response"
assert handler_called
async def test_function_middleware_conditional_no_next(self, mock_function: FunctionTool[Any]) -> None:
async def test_function_middleware_conditional_no_next(self, mock_function: FunctionTool) -> None:
"""Test that when function middleware conditionally doesn't call next(), no execution happens."""
class ConditionalNoNextFunctionMiddleware(FunctionMiddleware):
@@ -335,7 +334,7 @@ class TestResultObservability:
assert observed_responses[0].messages[0].text == "executed response"
assert result == observed_responses[0]
async def test_function_middleware_result_observability(self, mock_function: FunctionTool[Any]) -> None:
async def test_function_middleware_result_observability(self, mock_function: FunctionTool) -> None:
"""Test that middleware can observe function result after execution."""
observed_results: list[str] = []
@@ -402,7 +401,7 @@ class TestResultObservability:
assert result is not None
assert result.messages[0].text == "modified after execution"
async def test_function_middleware_post_execution_override(self, mock_function: FunctionTool[Any]) -> None:
async def test_function_middleware_post_execution_override(self, mock_function: FunctionTool) -> None:
"""Test that middleware can override function result after observing execution."""
class PostExecutionOverrideMiddleware(FunctionMiddleware):
@@ -444,8 +443,8 @@ def mock_agent() -> SupportsAgentRun:
@pytest.fixture
def mock_function() -> FunctionTool[Any]:
def mock_function() -> FunctionTool:
"""Mock function for testing."""
function = MagicMock(spec=FunctionTool[Any])
function = MagicMock(spec=FunctionTool)
function.name = "test_function"
return function
+85 -1
View File
@@ -108,6 +108,90 @@ def test_tool_decorator_with_json_schema_dict():
assert search("hello") == "Searching for: hello (max 10)"
async def test_tool_decorator_with_json_schema_invoke_uses_mapping():
"""Test that schema-based tools can be invoked directly with mapping arguments."""
json_schema = {
"type": "object",
"properties": {
"query": {"type": "string"},
"max_results": {"type": "integer"},
},
"required": ["query"],
}
@tool(name="search", description="Search tool", schema=json_schema)
def search(query: str, max_results: int = 10) -> str:
return f"{query}:{max_results}"
result = await search.invoke(arguments={"query": "hello", "max_results": 3})
assert result == "hello:3"
async def test_tool_decorator_with_json_schema_invoke_missing_required():
"""Test schema-required fields are checked for mapping arguments."""
json_schema = {
"type": "object",
"properties": {
"query": {"type": "string"},
},
"required": ["query"],
}
@tool(name="search", description="Search tool", schema=json_schema)
def search(query: str) -> str:
return query
with pytest.raises(TypeError, match="Missing required argument"):
await search.invoke(arguments={})
async def test_tool_decorator_with_json_schema_invoke_invalid_type():
"""Test schema type checks run for mapping arguments."""
json_schema = {
"type": "object",
"properties": {
"query": {"type": "string"},
"max_results": {"type": "integer"},
},
"required": ["query"],
}
@tool(name="search", description="Search tool", schema=json_schema)
def search(query: str, max_results: int = 10) -> str:
return f"{query}:{max_results}"
with pytest.raises(TypeError, match="Invalid type for 'max_results'"):
await search.invoke(arguments={"query": "hello", "max_results": "three"})
def test_tool_decorator_with_json_schema_preserves_custom_properties():
"""Test schema passthrough keeps custom JSON schema properties."""
json_schema = {
"type": "object",
"properties": {
"priority": {
"type": "string",
"enum": ["low", "medium", "high"],
"x-custom-field": "custom-value",
},
},
"required": ["priority"],
"additionalProperties": False,
}
@tool(name="process", description="Process tool", schema=json_schema)
def process(priority: str) -> str:
return priority
params = process.parameters()
assert not params.get("additionalProperties")
assert params["properties"]["priority"]["x-custom-field"] == "custom-value"
def test_tool_decorator_schema_none_default():
"""Test that schema=None (default) still infers from function signature."""
@@ -555,7 +639,7 @@ async def test_tool_invoke_telemetry_with_pydantic_args(span_exporter: InMemoryS
assert span.attributes[OtelAttr.TOOL_CALL_ID] == "pydantic_call"
assert span.attributes[OtelAttr.TOOL_TYPE] == "function"
assert span.attributes[OtelAttr.TOOL_DESCRIPTION] == "A test tool with Pydantic args"
assert span.attributes[OtelAttr.TOOL_ARGUMENTS] == '{"x":5,"y":10}'
assert span.attributes[OtelAttr.TOOL_ARGUMENTS] == '{"x": 5, "y": 10}'
async def test_tool_invoke_telemetry_with_exception(span_exporter: InMemorySpanExporter):
@@ -499,7 +499,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
return copilot_tools
def _tool_to_copilot_tool(self, ai_func: FunctionTool[Any]) -> CopilotTool:
def _tool_to_copilot_tool(self, ai_func: FunctionTool) -> CopilotTool:
"""Convert an FunctionTool to a Copilot SDK tool."""
async def handler(invocation: ToolInvocation) -> ToolResult:
@@ -27,7 +27,7 @@ from tau2.environment.tool import Tool # type: ignore[import-untyped]
_original_set_state = Environment.set_state
def convert_tau2_tool_to_function_tool(tau2_tool: Tool) -> FunctionTool[Any]:
def convert_tau2_tool_to_function_tool(tau2_tool: Tool) -> FunctionTool:
"""Convert a tau2 Tool to a FunctionTool for agent framework compatibility.
Creates a wrapper that preserves the tool's interface while ensuring
@@ -324,7 +324,7 @@ class HandoffAgentExecutor(AgentExecutor):
existing_tools = list(default_options.get("tools") or [])
existing_names = {getattr(tool, "name", "") for tool in existing_tools if hasattr(tool, "name")}
new_tools: list[FunctionTool[Any]] = []
new_tools: list[FunctionTool] = []
for target in targets:
handoff_tool = self._create_handoff_tool(target.target_id, target.description)
if handoff_tool.name in existing_names:
@@ -340,7 +340,7 @@ class HandoffAgentExecutor(AgentExecutor):
else:
default_options["tools"] = existing_tools
def _create_handoff_tool(self, target_id: str, description: str | None = None) -> FunctionTool[Any]:
def _create_handoff_tool(self, target_id: str, description: str | None = None) -> FunctionTool:
"""Construct the synthetic handoff tool that signals routing to `target_id`."""
tool_name = get_handoff_tool_name(target_id)
doc = description or f"Handoff to the {target_id} agent."