mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: [BREAKING] Remove FunctionTool[Any] compatibility shim for schema passthrough (#3600) (#3907)
* Fix #3600: Pass JSON schemas through without Pydantic conversion This change optimizes FunctionTool and MCP flows by passing JSON schemas directly to providers without converting them to Pydantic models first. Key changes: - Store JSON schema as-is when supplied to FunctionTool - Skip Pydantic model_validate for schema-supplied tools in invoke() - Return MCP tool schemas directly without conversion - Add comprehensive tests for schema passthrough behavior Performance benefits: - Eliminates expensive Pydantic model creation for supplied schemas - Preserves exact schema structure (additionalProperties, custom fields, etc.) - Reduces memory overhead and initialization time Maintains backward compatibility: - Function signature inference still uses Pydantic models - Explicit Pydantic models passed as input_model work as before - All existing tests pass * Fix schema passthrough validation and remove helper * Simplify FunctionTool without generic model dependency * Fix FunctionTool typing fallout in 3600 * Remove FunctionTool[Any] compatibility shim * Use serializable kwargs in OTEL tool args
This commit is contained in:
committed by
GitHub
Unverified
parent
cd1e3110aa
commit
fc9c81b0b1
@@ -267,7 +267,7 @@ class AGUIChatClient(
|
||||
if any(getattr(tool, "name", None) == tool_name for tool in additional_tools):
|
||||
return
|
||||
|
||||
placeholder: FunctionTool[Any] = FunctionTool(
|
||||
placeholder: FunctionTool = FunctionTool(
|
||||
name=tool_name,
|
||||
description="Server-managed tool placeholder (AG-UI)",
|
||||
func=None,
|
||||
|
||||
@@ -162,7 +162,7 @@ def make_json_safe(obj: Any) -> Any: # noqa: ANN401
|
||||
|
||||
def convert_agui_tools_to_agent_framework(
|
||||
agui_tools: list[dict[str, Any]] | None,
|
||||
) -> list[FunctionTool[Any]] | None:
|
||||
) -> list[FunctionTool] | None:
|
||||
"""Convert AG-UI tool definitions to Agent Framework FunctionTool declarations.
|
||||
|
||||
Creates declaration-only FunctionTool instances (no executable implementation).
|
||||
@@ -181,13 +181,13 @@ def convert_agui_tools_to_agent_framework(
|
||||
if not agui_tools:
|
||||
return None
|
||||
|
||||
result: list[FunctionTool[Any]] = []
|
||||
result: list[FunctionTool] = []
|
||||
for tool_def in agui_tools:
|
||||
# Create declaration-only FunctionTool (func=None means no implementation)
|
||||
# When func=None, the declaration_only property returns True,
|
||||
# which tells the function invocation mixin to return the function call
|
||||
# without executing it (so it can be sent back to the client)
|
||||
func: FunctionTool[Any] = FunctionTool(
|
||||
func: FunctionTool = FunctionTool(
|
||||
name=tool_def.get("name", ""),
|
||||
description=tool_def.get("description", ""),
|
||||
func=None, # CRITICAL: Makes declaration_only=True
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
|
||||
from agent_framework import Agent, FunctionTool, SupportsChatGetResponse
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
||||
from agent_framework import ChatOptions
|
||||
|
||||
# Declaration-only tools (func=None) - actual rendering happens on the client side
|
||||
generate_haiku = FunctionTool[Any](
|
||||
generate_haiku = FunctionTool(
|
||||
name="generate_haiku",
|
||||
description="""Generate a haiku with image and gradient background (FRONTEND_RENDER).
|
||||
|
||||
@@ -71,7 +71,7 @@ generate_haiku = FunctionTool[Any](
|
||||
},
|
||||
)
|
||||
|
||||
create_chart = FunctionTool[Any](
|
||||
create_chart = FunctionTool(
|
||||
name="create_chart",
|
||||
description="""Create an interactive chart (FRONTEND_RENDER).
|
||||
|
||||
@@ -99,7 +99,7 @@ create_chart = FunctionTool[Any](
|
||||
},
|
||||
)
|
||||
|
||||
display_timeline = FunctionTool[Any](
|
||||
display_timeline = FunctionTool(
|
||||
name="display_timeline",
|
||||
description="""Display an interactive timeline (FRONTEND_RENDER).
|
||||
|
||||
@@ -127,7 +127,7 @@ display_timeline = FunctionTool[Any](
|
||||
},
|
||||
)
|
||||
|
||||
show_comparison_table = FunctionTool[Any](
|
||||
show_comparison_table = FunctionTool(
|
||||
name="show_comparison_table",
|
||||
description="""Show a comparison table (FRONTEND_RENDER).
|
||||
|
||||
|
||||
@@ -484,7 +484,7 @@ class ClaudeAgent(BaseAgent, Generic[OptionsT]):
|
||||
|
||||
return create_sdk_mcp_server(name=TOOLS_MCP_SERVER_NAME, tools=sdk_tools), tool_names
|
||||
|
||||
def _function_tool_to_sdk_mcp_tool(self, func_tool: FunctionTool[Any]) -> SdkMcpTool[Any]:
|
||||
def _function_tool_to_sdk_mcp_tool(self, func_tool: FunctionTool) -> SdkMcpTool[Any]:
|
||||
"""Convert a FunctionTool to an SDK MCP tool.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -439,7 +439,7 @@ class BaseAgent(SerializationMixin):
|
||||
stream_callback: Callable[[AgentResponseUpdate], None]
|
||||
| Callable[[AgentResponseUpdate], Awaitable[None]]
|
||||
| None = None,
|
||||
) -> FunctionTool[BaseModel]:
|
||||
) -> FunctionTool:
|
||||
"""Create a FunctionTool that wraps this agent.
|
||||
|
||||
Keyword Args:
|
||||
@@ -513,7 +513,7 @@ class BaseAgent(SerializationMixin):
|
||||
# Create final text from accumulated updates
|
||||
return AgentResponse.from_updates(response_updates).text
|
||||
|
||||
agent_tool: FunctionTool[BaseModel] = FunctionTool(
|
||||
agent_tool: FunctionTool = FunctionTool(
|
||||
name=tool_name,
|
||||
description=tool_description,
|
||||
func=agent_wrapper,
|
||||
@@ -1258,17 +1258,12 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
@server.list_tools() # type: ignore
|
||||
async def _list_tools() -> list[types.Tool]: # type: ignore
|
||||
"""List all tools in the agent."""
|
||||
# Get the JSON schema from the Pydantic model
|
||||
schema = agent_tool.input_model.model_json_schema()
|
||||
schema = agent_tool.parameters()
|
||||
|
||||
tool = types.Tool(
|
||||
name=agent_tool.name,
|
||||
description=agent_tool.description,
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": schema.get("properties", {}),
|
||||
"required": schema.get("required", []),
|
||||
},
|
||||
inputSchema=schema,
|
||||
)
|
||||
|
||||
await _log(level="debug", data=f"Agent tool: {agent_tool}")
|
||||
@@ -1291,7 +1286,9 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
|
||||
# Create an instance of the input model with the arguments
|
||||
try:
|
||||
args_instance = agent_tool.input_model(**arguments)
|
||||
args_instance: BaseModel | dict[str, Any] = (
|
||||
agent_tool.input_model(**arguments) if agent_tool.input_model is not None else arguments
|
||||
)
|
||||
result = await agent_tool.invoke(arguments=args_instance)
|
||||
except Exception as e:
|
||||
raise McpError(
|
||||
|
||||
@@ -24,11 +24,9 @@ from mcp.client.websocket import websocket_client
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.session import RequestResponder
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
from ._tools import (
|
||||
FunctionTool,
|
||||
_build_pydantic_model_from_json_schema,
|
||||
)
|
||||
from ._types import (
|
||||
Content,
|
||||
@@ -355,11 +353,14 @@ def _prepare_message_for_mcp(
|
||||
return messages
|
||||
|
||||
|
||||
def _get_input_model_from_mcp_prompt(prompt: types.Prompt) -> type[BaseModel]:
|
||||
"""Creates a Pydantic model from a prompt's parameters."""
|
||||
def _get_input_model_from_mcp_prompt(prompt: types.Prompt) -> dict[str, Any]:
|
||||
"""Get the input model from an MCP prompt.
|
||||
|
||||
Returns a JSON schema dictionary for prompt arguments.
|
||||
"""
|
||||
# Check if 'arguments' is missing or empty
|
||||
if not prompt.arguments:
|
||||
return create_model(f"{prompt.name}_input")
|
||||
return {"type": "object", "properties": {}}
|
||||
|
||||
# Convert prompt arguments to JSON schema format
|
||||
properties: dict[str, Any] = {}
|
||||
@@ -374,13 +375,10 @@ def _get_input_model_from_mcp_prompt(prompt: types.Prompt) -> type[BaseModel]:
|
||||
if prompt_argument.required:
|
||||
required.append(prompt_argument.name)
|
||||
|
||||
schema = {"properties": properties, "required": required}
|
||||
return _build_pydantic_model_from_json_schema(prompt.name, schema)
|
||||
|
||||
|
||||
def _get_input_model_from_mcp_tool(tool: types.Tool) -> type[BaseModel]:
|
||||
"""Creates a Pydantic model from a tools parameters."""
|
||||
return _build_pydantic_model_from_json_schema(tool.name, tool.inputSchema)
|
||||
schema: dict[str, Any] = {"type": "object", "properties": properties}
|
||||
if required:
|
||||
schema["required"] = required
|
||||
return schema
|
||||
|
||||
|
||||
def _normalize_mcp_name(name: str) -> str:
|
||||
@@ -467,7 +465,7 @@ class MCPTool:
|
||||
self.session = session
|
||||
self.request_timeout = request_timeout
|
||||
self.client = client
|
||||
self._functions: list[FunctionTool[Any]] = []
|
||||
self._functions: list[FunctionTool] = []
|
||||
self.is_connected: bool = False
|
||||
self._tools_loaded: bool = False
|
||||
self._prompts_loaded: bool = False
|
||||
@@ -476,7 +474,7 @@ class MCPTool:
|
||||
return f"MCPTool(name={self.name}, description={self.description})"
|
||||
|
||||
@property
|
||||
def functions(self) -> list[FunctionTool[Any]]:
|
||||
def functions(self) -> list[FunctionTool]:
|
||||
"""Get the list of functions that are allowed."""
|
||||
if not self.allowed_tools:
|
||||
return self._functions
|
||||
@@ -744,7 +742,7 @@ class MCPTool:
|
||||
|
||||
input_model = _get_input_model_from_mcp_prompt(prompt)
|
||||
approval_mode = self._determine_approval_mode(local_name)
|
||||
func: FunctionTool[BaseModel] = FunctionTool(
|
||||
func: FunctionTool = FunctionTool(
|
||||
func=partial(self.get_prompt, prompt.name),
|
||||
name=local_name,
|
||||
description=prompt.description or "",
|
||||
@@ -785,15 +783,14 @@ class MCPTool:
|
||||
if local_name in existing_names:
|
||||
continue
|
||||
|
||||
input_model = _get_input_model_from_mcp_tool(tool)
|
||||
approval_mode = self._determine_approval_mode(local_name)
|
||||
# Create FunctionTools out of each tool
|
||||
func: FunctionTool[BaseModel] = FunctionTool(
|
||||
func: FunctionTool = FunctionTool(
|
||||
func=partial(self.call_tool, tool.name),
|
||||
name=local_name,
|
||||
description=tool.description or "",
|
||||
approval_mode=approval_mode,
|
||||
input_model=input_model,
|
||||
input_model=tool.inputSchema,
|
||||
)
|
||||
self._functions.append(func)
|
||||
existing_names.add(local_name)
|
||||
|
||||
@@ -234,8 +234,8 @@ class FunctionInvocationContext:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
function: FunctionTool[Any],
|
||||
arguments: BaseModel,
|
||||
function: FunctionTool,
|
||||
arguments: BaseModel | Mapping[str, Any],
|
||||
metadata: Mapping[str, Any] | None = None,
|
||||
result: Any = None,
|
||||
kwargs: Mapping[str, Any] | None = None,
|
||||
|
||||
@@ -26,7 +26,6 @@ from typing import (
|
||||
Literal,
|
||||
TypedDict,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
overload,
|
||||
@@ -89,8 +88,6 @@ DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST: Final[int] = 3
|
||||
ChatClientT = TypeVar("ChatClientT", bound="SupportsChatGetResponse[Any]")
|
||||
# region Helpers
|
||||
|
||||
ArgsT = TypeVar("ArgsT", bound=BaseModel, default=BaseModel)
|
||||
|
||||
|
||||
def _parse_inputs(
|
||||
inputs: Content | dict[str, Any] | str | list[Content | dict[str, Any] | str] | None,
|
||||
@@ -183,11 +180,7 @@ def _default_histogram() -> Histogram:
|
||||
ClassT = TypeVar("ClassT", bound="SerializationMixin")
|
||||
|
||||
|
||||
class EmptyInputModel(BaseModel):
|
||||
"""An empty input model for functions with no parameters."""
|
||||
|
||||
|
||||
class FunctionTool(SerializationMixin, Generic[ArgsT]):
|
||||
class FunctionTool(SerializationMixin):
|
||||
"""A tool that wraps a Python function to make it callable by AI models.
|
||||
|
||||
This class wraps a Python function to make it callable by AI models with automatic
|
||||
@@ -240,6 +233,8 @@ class FunctionTool(SerializationMixin, Generic[ArgsT]):
|
||||
"input_model",
|
||||
"_invocation_duration_histogram",
|
||||
"_cached_parameters",
|
||||
"_input_schema",
|
||||
"_schema_supplied",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
@@ -252,7 +247,7 @@ class FunctionTool(SerializationMixin, Generic[ArgsT]):
|
||||
max_invocation_exceptions: int | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
func: Callable[..., Any] | None = None,
|
||||
input_model: type[ArgsT] | Mapping[str, Any] | None = None,
|
||||
input_model: type[BaseModel] | Mapping[str, Any] | None = None,
|
||||
result_parser: Callable[[Any], str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@@ -299,7 +294,16 @@ class FunctionTool(SerializationMixin, Generic[ArgsT]):
|
||||
# FunctionTool-specific attributes
|
||||
self.func = func
|
||||
self._instance = None # Store the instance for bound methods
|
||||
self.input_model = self._resolve_input_model(input_model)
|
||||
|
||||
# Track if schema was supplied as JSON dict (for optimization)
|
||||
if isinstance(input_model, Mapping):
|
||||
self._schema_supplied = True
|
||||
self._input_schema: dict[str, Any] = dict(input_model)
|
||||
self.input_model: type[BaseModel] | None = None
|
||||
else:
|
||||
self._schema_supplied = False
|
||||
self.input_model = self._resolve_input_model(input_model)
|
||||
self._input_schema = self.input_model.model_json_schema()
|
||||
self._cached_parameters: dict[str, Any] | None = None
|
||||
self.approval_mode = approval_mode or "never_require"
|
||||
if max_invocations is not None and max_invocations < 1:
|
||||
@@ -335,7 +339,7 @@ class FunctionTool(SerializationMixin, Generic[ArgsT]):
|
||||
return True
|
||||
return self.func is None
|
||||
|
||||
def __get__(self, obj: Any, objtype: type | None = None) -> FunctionTool[ArgsT]:
|
||||
def __get__(self, obj: Any, objtype: type | None = None) -> FunctionTool:
|
||||
"""Implement the descriptor protocol to support bound methods.
|
||||
|
||||
When a FunctionTool is accessed as an attribute of a class instance,
|
||||
@@ -366,17 +370,30 @@ class FunctionTool(SerializationMixin, Generic[ArgsT]):
|
||||
|
||||
return self
|
||||
|
||||
def _resolve_input_model(self, input_model: type[ArgsT] | Mapping[str, Any] | None) -> type[ArgsT]:
|
||||
def _resolve_input_model(self, input_model: type[BaseModel] | None) -> type[BaseModel]:
|
||||
"""Resolve the input model for the function."""
|
||||
if input_model is None:
|
||||
if self.func is None:
|
||||
return cast(type[ArgsT], EmptyInputModel)
|
||||
return cast(type[ArgsT], _create_input_model_from_func(func=self.func, name=self.name))
|
||||
if inspect.isclass(input_model) and issubclass(input_model, BaseModel):
|
||||
return input_model
|
||||
if isinstance(input_model, Mapping):
|
||||
return cast(type[ArgsT], _create_model_from_json_schema(self.name, input_model))
|
||||
raise TypeError("input_model must be a Pydantic BaseModel subclass or a JSON schema dict.")
|
||||
if input_model is not None:
|
||||
if inspect.isclass(input_model) and issubclass(input_model, BaseModel):
|
||||
return input_model
|
||||
raise TypeError("input_model must be a Pydantic BaseModel subclass or a JSON schema dict.")
|
||||
|
||||
if self.func is None:
|
||||
return create_model(f"{self.name}_input")
|
||||
|
||||
func = self.func.func if isinstance(self.func, FunctionTool) else self.func
|
||||
if func is None:
|
||||
return create_model(f"{self.name}_input")
|
||||
sig = inspect.signature(func)
|
||||
fields: dict[str, Any] = {
|
||||
pname: (
|
||||
_parse_annotation(param.annotation) if param.annotation is not inspect.Parameter.empty else str,
|
||||
param.default if param.default is not inspect.Parameter.empty else ...,
|
||||
)
|
||||
for pname, param in sig.parameters.items()
|
||||
if pname not in {"self", "cls"}
|
||||
and param.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
|
||||
}
|
||||
return create_model(f"{self.name}_input", **fields)
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Call the wrapped function with the provided arguments."""
|
||||
@@ -407,7 +424,7 @@ class FunctionTool(SerializationMixin, Generic[ArgsT]):
|
||||
async def invoke(
|
||||
self,
|
||||
*,
|
||||
arguments: ArgsT | None = None,
|
||||
arguments: BaseModel | Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run the AI function with the provided arguments as a Pydantic model.
|
||||
@@ -417,14 +434,14 @@ class FunctionTool(SerializationMixin, Generic[ArgsT]):
|
||||
``result_parser`` if one was provided.
|
||||
|
||||
Keyword Args:
|
||||
arguments: A Pydantic model instance containing the arguments for the function.
|
||||
arguments: A mapping or model instance containing the arguments for the function.
|
||||
kwargs: Keyword arguments to pass to the function, will not be used if ``arguments`` is provided.
|
||||
|
||||
Returns:
|
||||
The parsed result as a string — either plain text or serialized JSON.
|
||||
|
||||
Raises:
|
||||
TypeError: If arguments is not an instance of the expected input model.
|
||||
TypeError: If arguments is not mapping-like or fails schema checks.
|
||||
"""
|
||||
if self.declaration_only:
|
||||
raise ToolException(f"Function '{self.name}' is declaration only and cannot be invoked.")
|
||||
@@ -436,9 +453,32 @@ class FunctionTool(SerializationMixin, Generic[ArgsT]):
|
||||
original_kwargs = dict(kwargs)
|
||||
tool_call_id = original_kwargs.pop("tool_call_id", None)
|
||||
if arguments is not None:
|
||||
if not isinstance(arguments, self.input_model):
|
||||
raise TypeError(f"Expected {self.input_model.__name__}, got {type(arguments).__name__}")
|
||||
kwargs = arguments.model_dump(exclude_none=True)
|
||||
try:
|
||||
if isinstance(arguments, Mapping):
|
||||
parsed_arguments = dict(arguments)
|
||||
if self.input_model is not None and not self._schema_supplied:
|
||||
parsed_arguments = self.input_model.model_validate(parsed_arguments).model_dump(
|
||||
exclude_none=True
|
||||
)
|
||||
elif isinstance(arguments, BaseModel):
|
||||
if (
|
||||
self.input_model is not None
|
||||
and not self._schema_supplied
|
||||
and not isinstance(arguments, self.input_model)
|
||||
):
|
||||
raise TypeError(f"Expected {self.input_model.__name__}, got {type(arguments).__name__}")
|
||||
parsed_arguments = arguments.model_dump(exclude_none=True)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected mapping-like arguments for tool '{self.name}', got {type(arguments).__name__}"
|
||||
)
|
||||
except ValidationError as exc:
|
||||
raise TypeError(f"Invalid arguments for '{self.name}': {exc}") from exc
|
||||
kwargs = _validate_arguments_against_schema(
|
||||
arguments=parsed_arguments,
|
||||
schema=self.parameters(),
|
||||
tool_name=self.name,
|
||||
)
|
||||
if getattr(self, "_forward_runtime_kwargs", False) and original_kwargs:
|
||||
kwargs.update(original_kwargs)
|
||||
else:
|
||||
@@ -458,34 +498,34 @@ class FunctionTool(SerializationMixin, Generic[ArgsT]):
|
||||
return parsed
|
||||
|
||||
attributes = get_function_span_attributes(self, tool_call_id=tool_call_id)
|
||||
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: # type: ignore[name-defined]
|
||||
# Filter out framework kwargs that are not JSON serializable
|
||||
serializable_kwargs = {
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if k
|
||||
not in {
|
||||
"chat_options",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"session",
|
||||
"conversation_id",
|
||||
"options",
|
||||
"response_format",
|
||||
}
|
||||
# Filter out framework kwargs that are not JSON serializable.
|
||||
serializable_kwargs = {
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if k
|
||||
not in {
|
||||
"chat_options",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"session",
|
||||
"conversation_id",
|
||||
"options",
|
||||
"response_format",
|
||||
}
|
||||
}
|
||||
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: # type: ignore[name-defined]
|
||||
attributes.update({
|
||||
OtelAttr.TOOL_ARGUMENTS: arguments.model_dump_json(ensure_ascii=False)
|
||||
if arguments
|
||||
else json.dumps(serializable_kwargs, default=str, ensure_ascii=False)
|
||||
if serializable_kwargs
|
||||
else "None"
|
||||
OtelAttr.TOOL_ARGUMENTS: (
|
||||
json.dumps(serializable_kwargs, default=str, ensure_ascii=False)
|
||||
if serializable_kwargs
|
||||
else "None"
|
||||
)
|
||||
})
|
||||
with get_function_span(attributes=attributes) as span:
|
||||
attributes[OtelAttr.MEASUREMENT_FUNCTION_TAG_NAME] = self.name
|
||||
logger.info(f"Function name: {self.name}")
|
||||
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: # type: ignore[name-defined]
|
||||
logger.debug(f"Function arguments: {kwargs}")
|
||||
logger.debug(f"Function arguments: {serializable_kwargs}")
|
||||
start_time_stamp = perf_counter()
|
||||
end_time_stamp: float | None = None
|
||||
try:
|
||||
@@ -523,7 +563,7 @@ class FunctionTool(SerializationMixin, Generic[ArgsT]):
|
||||
The result is cached after the first call for performance.
|
||||
"""
|
||||
if self._cached_parameters is None:
|
||||
self._cached_parameters = self.input_model.model_json_schema()
|
||||
self._cached_parameters = self._input_schema
|
||||
return self._cached_parameters
|
||||
|
||||
@staticmethod
|
||||
@@ -677,23 +717,79 @@ def _parse_annotation(annotation: Any) -> Any:
|
||||
return annotation
|
||||
|
||||
|
||||
def _create_input_model_from_func(func: Callable[..., Any], name: str) -> type[BaseModel]:
|
||||
"""Create a Pydantic model from a function's signature."""
|
||||
# Unwrap FunctionTool objects to get the underlying function
|
||||
if isinstance(func, FunctionTool):
|
||||
func = func.func # type: ignore[assignment]
|
||||
def _matches_json_schema_type(value: Any, schema_type: str) -> bool:
|
||||
"""Check a value against a simple JSON schema primitive type."""
|
||||
match schema_type:
|
||||
case "string":
|
||||
return isinstance(value, str)
|
||||
case "integer":
|
||||
return isinstance(value, int) and not isinstance(value, bool)
|
||||
case "number":
|
||||
return (isinstance(value, int | float)) and not isinstance(value, bool)
|
||||
case "boolean":
|
||||
return isinstance(value, bool)
|
||||
case "array":
|
||||
return isinstance(value, list)
|
||||
case "object":
|
||||
return isinstance(value, dict)
|
||||
case "null":
|
||||
return value is None
|
||||
case _:
|
||||
return True
|
||||
|
||||
sig = inspect.signature(func)
|
||||
fields = {
|
||||
pname: (
|
||||
_parse_annotation(param.annotation) if param.annotation is not inspect.Parameter.empty else str,
|
||||
param.default if param.default is not inspect.Parameter.empty else ...,
|
||||
)
|
||||
for pname, param in sig.parameters.items()
|
||||
if pname not in {"self", "cls"}
|
||||
and param.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
|
||||
}
|
||||
return create_model(f"{name}_input", **fields) # type: ignore[call-overload, no-any-return]
|
||||
|
||||
def _validate_arguments_against_schema(
|
||||
*,
|
||||
arguments: Mapping[str, Any],
|
||||
schema: Mapping[str, Any],
|
||||
tool_name: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Run lightweight argument checks for schema-supplied tools."""
|
||||
parsed_arguments = dict(arguments)
|
||||
|
||||
required_raw = schema.get("required", [])
|
||||
required_fields = [field for field in required_raw if isinstance(field, str)]
|
||||
missing_fields = [field for field in required_fields if field not in parsed_arguments]
|
||||
if missing_fields:
|
||||
raise TypeError(f"Missing required argument(s) for '{tool_name}': {', '.join(sorted(missing_fields))}")
|
||||
|
||||
properties_raw = schema.get("properties")
|
||||
properties = properties_raw if isinstance(properties_raw, Mapping) else {}
|
||||
|
||||
if schema.get("additionalProperties") is False:
|
||||
unexpected_fields = sorted(field for field in parsed_arguments if field not in properties)
|
||||
if unexpected_fields:
|
||||
raise TypeError(f"Unexpected argument(s) for '{tool_name}': {', '.join(unexpected_fields)}")
|
||||
|
||||
for field_name, field_value in parsed_arguments.items():
|
||||
field_schema = properties.get(field_name)
|
||||
if not isinstance(field_schema, Mapping):
|
||||
continue
|
||||
|
||||
enum_values = field_schema.get("enum")
|
||||
if isinstance(enum_values, list) and enum_values and field_value not in enum_values:
|
||||
raise TypeError(
|
||||
f"Invalid value for '{field_name}' in '{tool_name}': {field_value!r} is not in {enum_values!r}"
|
||||
)
|
||||
|
||||
schema_type = field_schema.get("type")
|
||||
if isinstance(schema_type, str):
|
||||
if not _matches_json_schema_type(field_value, schema_type):
|
||||
raise TypeError(
|
||||
f"Invalid type for '{field_name}' in '{tool_name}': "
|
||||
f"expected {schema_type}, got {type(field_value).__name__}"
|
||||
)
|
||||
continue
|
||||
|
||||
if isinstance(schema_type, list):
|
||||
allowed_types = [item for item in schema_type if isinstance(item, str)]
|
||||
if allowed_types and not any(_matches_json_schema_type(field_value, item) for item in allowed_types):
|
||||
raise TypeError(
|
||||
f"Invalid type for '{field_name}' in '{tool_name}': expected one of "
|
||||
f"{allowed_types}, got {type(field_value).__name__}"
|
||||
)
|
||||
|
||||
return parsed_arguments
|
||||
|
||||
|
||||
# Map JSON Schema types to Pydantic types
|
||||
@@ -942,7 +1038,7 @@ def tool(
|
||||
max_invocation_exceptions: int | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
result_parser: Callable[[Any], str] | None = None,
|
||||
) -> FunctionTool[Any]: ...
|
||||
) -> FunctionTool: ...
|
||||
|
||||
|
||||
@overload
|
||||
@@ -957,7 +1053,7 @@ def tool(
|
||||
max_invocation_exceptions: int | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
result_parser: Callable[[Any], str] | None = None,
|
||||
) -> Callable[[Callable[..., Any]], FunctionTool[Any]]: ...
|
||||
) -> Callable[[Callable[..., Any]], FunctionTool]: ...
|
||||
|
||||
|
||||
def tool(
|
||||
@@ -971,7 +1067,7 @@ def tool(
|
||||
max_invocation_exceptions: int | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
result_parser: Callable[[Any], str] | None = None,
|
||||
) -> FunctionTool[Any] | Callable[[Callable[..., Any]], FunctionTool[Any]]:
|
||||
) -> FunctionTool | Callable[[Callable[..., Any]], FunctionTool]:
|
||||
"""Decorate a function to turn it into a FunctionTool that can be passed to models and executed automatically.
|
||||
|
||||
This decorator creates a Pydantic model from the function's signature,
|
||||
@@ -1095,12 +1191,12 @@ def tool(
|
||||
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., Any]) -> FunctionTool[Any]:
|
||||
def decorator(func: Callable[..., Any]) -> FunctionTool:
|
||||
@wraps(func)
|
||||
def wrapper(f: Callable[..., Any]) -> FunctionTool[Any]:
|
||||
def wrapper(f: Callable[..., Any]) -> FunctionTool:
|
||||
tool_name: str = name or getattr(f, "__name__", "unknown_function") # type: ignore[assignment]
|
||||
tool_desc: str = description or (f.__doc__ or "")
|
||||
return FunctionTool[Any](
|
||||
return FunctionTool(
|
||||
name=tool_name,
|
||||
description=tool_desc,
|
||||
approval_mode=approval_mode,
|
||||
@@ -1193,7 +1289,7 @@ async def _auto_invoke_function(
|
||||
custom_args: dict[str, Any] | None = None,
|
||||
*,
|
||||
config: FunctionInvocationConfiguration,
|
||||
tool_map: dict[str, FunctionTool[BaseModel]],
|
||||
tool_map: dict[str, FunctionTool],
|
||||
sequence_index: int | None = None,
|
||||
request_index: int | None = None,
|
||||
middleware_pipeline: FunctionMiddlewarePipeline | None = None, # Optional MiddlewarePipeline
|
||||
@@ -1225,7 +1321,7 @@ async def _auto_invoke_function(
|
||||
# this function is called. This function only handles the actual execution of approved,
|
||||
# non-declaration-only functions.
|
||||
|
||||
tool: FunctionTool[BaseModel] | None = None
|
||||
tool: FunctionTool | None = None
|
||||
if function_call_content.type == "function_call":
|
||||
tool = tool_map.get(function_call_content.name) # type: ignore[arg-type]
|
||||
# Tool should exist because _try_execute_function_calls validates this
|
||||
@@ -1258,8 +1354,16 @@ async def _auto_invoke_function(
|
||||
if key not in {"_function_middleware_pipeline", "middleware", "conversation_id"}
|
||||
}
|
||||
try:
|
||||
args = tool.input_model.model_validate(parsed_args)
|
||||
except ValidationError as exc:
|
||||
if not tool._schema_supplied and tool.input_model is not None:
|
||||
args = tool.input_model.model_validate(parsed_args).model_dump(exclude_none=True)
|
||||
else:
|
||||
args = dict(parsed_args)
|
||||
args = _validate_arguments_against_schema(
|
||||
arguments=args,
|
||||
schema=tool.parameters(),
|
||||
tool_name=tool.name,
|
||||
)
|
||||
except (TypeError, ValidationError) as exc:
|
||||
message = "Error: Argument parsing failed."
|
||||
if config["include_detailed_errors"]:
|
||||
message = f"{message} Exception: {exc}"
|
||||
@@ -1340,8 +1444,8 @@ def _get_tool_map(
|
||||
| Callable[..., Any]
|
||||
| MutableMapping[str, Any]
|
||||
| Sequence[FunctionTool | Callable[..., Any] | MutableMapping[str, Any]],
|
||||
) -> dict[str, FunctionTool[Any]]:
|
||||
tool_list: dict[str, FunctionTool[Any]] = {}
|
||||
) -> dict[str, FunctionTool]:
|
||||
tool_list: dict[str, FunctionTool] = {}
|
||||
for tool_item in tools if isinstance(tools, list) else [tools]:
|
||||
if isinstance(tool_item, FunctionTool):
|
||||
tool_list[tool_item.name] = tool_item
|
||||
|
||||
@@ -1448,7 +1448,7 @@ class AgentTelemetryLayer:
|
||||
# region Otel Helpers
|
||||
|
||||
|
||||
def get_function_span_attributes(function: FunctionTool[Any], tool_call_id: str | None = None) -> dict[str, str]:
|
||||
def get_function_span_attributes(function: FunctionTool, tool_call_id: str | None = None) -> dict[str, str]:
|
||||
"""Get the span attributes for the given function.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -10,7 +10,7 @@ import pytest
|
||||
from mcp import types
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.shared.exceptions import McpError
|
||||
from pydantic import AnyUrl, BaseModel, ValidationError
|
||||
from pydantic import AnyUrl, BaseModel
|
||||
|
||||
from agent_framework import (
|
||||
Content,
|
||||
@@ -22,7 +22,6 @@ from agent_framework import (
|
||||
from agent_framework._mcp import (
|
||||
MCPTool,
|
||||
_get_input_model_from_mcp_prompt,
|
||||
_get_input_model_from_mcp_tool,
|
||||
_normalize_mcp_name,
|
||||
_parse_content_from_mcp,
|
||||
_parse_message_from_mcp,
|
||||
@@ -276,363 +275,338 @@ def test_prepare_message_for_mcp():
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_id,input_schema,valid_data,expected_values,invalid_data,validation_check",
|
||||
"test_id,input_schema",
|
||||
[
|
||||
# Basic types with required/optional fields
|
||||
(
|
||||
"basic_types",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"param1": {"type": "string"}, "param2": {"type": "number"}},
|
||||
"required": ["param1"],
|
||||
},
|
||||
{"param1": "test", "param2": 42},
|
||||
{"param1": "test", "param2": 42},
|
||||
{"param2": 42}, # Missing required param1
|
||||
None,
|
||||
),
|
||||
# Nested object
|
||||
(
|
||||
"nested_object",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"params": {
|
||||
"type": "object",
|
||||
"properties": {"customer_id": {"type": "integer"}},
|
||||
"required": ["customer_id"],
|
||||
}
|
||||
(test_id, input_schema)
|
||||
for test_id, input_schema, _, _, _, _ in [
|
||||
# Basic types with required/optional fields
|
||||
(
|
||||
"basic_types",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"param1": {"type": "string"}, "param2": {"type": "number"}},
|
||||
"required": ["param1"],
|
||||
},
|
||||
"required": ["params"],
|
||||
},
|
||||
{"params": {"customer_id": 251}},
|
||||
{"params.customer_id": 251},
|
||||
{"params": {}}, # Missing required customer_id
|
||||
lambda instance: isinstance(instance.params, BaseModel),
|
||||
),
|
||||
# $ref resolution
|
||||
(
|
||||
"ref_schema",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"params": {"$ref": "#/$defs/CustomerIdParam"}},
|
||||
"required": ["params"],
|
||||
"$defs": {
|
||||
"CustomerIdParam": {
|
||||
"type": "object",
|
||||
"properties": {"customer_id": {"type": "integer"}},
|
||||
"required": ["customer_id"],
|
||||
}
|
||||
},
|
||||
},
|
||||
{"params": {"customer_id": 251}},
|
||||
{"params.customer_id": 251},
|
||||
{"params": {}}, # Missing required customer_id
|
||||
lambda instance: isinstance(instance.params, BaseModel),
|
||||
),
|
||||
# Array of strings (typed)
|
||||
(
|
||||
"array_of_strings",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tags": {
|
||||
"type": "array",
|
||||
"description": "List of tags",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
},
|
||||
"required": ["tags"],
|
||||
},
|
||||
{"tags": ["tag1", "tag2", "tag3"]},
|
||||
{"tags": ["tag1", "tag2", "tag3"]},
|
||||
None, # No validation error test for this case
|
||||
None,
|
||||
),
|
||||
# Array of integers (typed)
|
||||
(
|
||||
"array_of_integers",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"numbers": {
|
||||
"type": "array",
|
||||
"description": "List of integers",
|
||||
"items": {"type": "integer"},
|
||||
}
|
||||
},
|
||||
"required": ["numbers"],
|
||||
},
|
||||
{"numbers": [1, 2, 3]},
|
||||
{"numbers": [1, 2, 3]},
|
||||
None,
|
||||
None,
|
||||
),
|
||||
# Array of objects (complex nested)
|
||||
(
|
||||
"array_of_objects",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"users": {
|
||||
"type": "array",
|
||||
"description": "List of users",
|
||||
"items": {
|
||||
{"param1": "test", "param2": 42},
|
||||
{"param1": "test", "param2": 42},
|
||||
{"param2": 42}, # Missing required param1
|
||||
None,
|
||||
),
|
||||
# Nested object
|
||||
(
|
||||
"nested_object",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"params": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {"type": "integer", "description": "User ID"},
|
||||
"name": {"type": "string", "description": "User name"},
|
||||
},
|
||||
"required": ["id", "name"],
|
||||
},
|
||||
}
|
||||
"properties": {"customer_id": {"type": "integer"}},
|
||||
"required": ["customer_id"],
|
||||
}
|
||||
},
|
||||
"required": ["params"],
|
||||
},
|
||||
"required": ["users"],
|
||||
},
|
||||
{"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]},
|
||||
{"users[0].id": 1, "users[0].name": "Alice", "users[1].id": 2, "users[1].name": "Bob"},
|
||||
{"users": [{"id": 1}]}, # Missing required 'name'
|
||||
lambda instance: all(isinstance(user, BaseModel) for user in instance.users),
|
||||
),
|
||||
# Deeply nested objects (3+ levels)
|
||||
(
|
||||
"deeply_nested",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filters": {
|
||||
{"params": {"customer_id": 251}},
|
||||
{"params.customer_id": 251},
|
||||
{"params": {}}, # Missing required customer_id
|
||||
lambda instance: isinstance(instance.params, BaseModel),
|
||||
),
|
||||
# $ref resolution
|
||||
(
|
||||
"ref_schema",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"params": {"$ref": "#/$defs/CustomerIdParam"}},
|
||||
"required": ["params"],
|
||||
"$defs": {
|
||||
"CustomerIdParam": {
|
||||
"type": "object",
|
||||
"properties": {"customer_id": {"type": "integer"}},
|
||||
"required": ["customer_id"],
|
||||
}
|
||||
},
|
||||
},
|
||||
{"params": {"customer_id": 251}},
|
||||
{"params.customer_id": 251},
|
||||
{"params": {}}, # Missing required customer_id
|
||||
lambda instance: isinstance(instance.params, BaseModel),
|
||||
),
|
||||
# Array of strings (typed)
|
||||
(
|
||||
"array_of_strings",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tags": {
|
||||
"type": "array",
|
||||
"description": "List of tags",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
},
|
||||
"required": ["tags"],
|
||||
},
|
||||
{"tags": ["tag1", "tag2", "tag3"]},
|
||||
{"tags": ["tag1", "tag2", "tag3"]},
|
||||
None, # No validation error test for this case
|
||||
None,
|
||||
),
|
||||
# Array of integers (typed)
|
||||
(
|
||||
"array_of_integers",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"numbers": {
|
||||
"type": "array",
|
||||
"description": "List of integers",
|
||||
"items": {"type": "integer"},
|
||||
}
|
||||
},
|
||||
"required": ["numbers"],
|
||||
},
|
||||
{"numbers": [1, 2, 3]},
|
||||
{"numbers": [1, 2, 3]},
|
||||
None,
|
||||
None,
|
||||
),
|
||||
# Array of objects (complex nested)
|
||||
(
|
||||
"array_of_objects",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"users": {
|
||||
"type": "array",
|
||||
"description": "List of users",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"date_range": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"start": {"type": "string"},
|
||||
"end": {"type": "string"},
|
||||
},
|
||||
"required": ["start", "end"],
|
||||
},
|
||||
"categories": {"type": "array", "items": {"type": "string"}},
|
||||
"id": {"type": "integer", "description": "User ID"},
|
||||
"name": {"type": "string", "description": "User name"},
|
||||
},
|
||||
"required": ["date_range"],
|
||||
}
|
||||
},
|
||||
"required": ["filters"],
|
||||
"required": ["id", "name"],
|
||||
},
|
||||
}
|
||||
},
|
||||
"required": ["users"],
|
||||
},
|
||||
{"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]},
|
||||
{"users[0].id": 1, "users[0].name": "Alice", "users[1].id": 2, "users[1].name": "Bob"},
|
||||
{"users": [{"id": 1}]}, # Missing required 'name'
|
||||
lambda instance: all(isinstance(user, BaseModel) for user in instance.users),
|
||||
),
|
||||
# Deeply nested objects (3+ levels)
|
||||
(
|
||||
"deeply_nested",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"date_range": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"start": {"type": "string"},
|
||||
"end": {"type": "string"},
|
||||
},
|
||||
"required": ["start", "end"],
|
||||
},
|
||||
"categories": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
"required": ["date_range"],
|
||||
}
|
||||
},
|
||||
"required": ["filters"],
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
{
|
||||
"query": {
|
||||
"filters": {
|
||||
"date_range": {"start": "2024-01-01", "end": "2024-12-31"},
|
||||
"categories": ["tech", "science"],
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
{
|
||||
"query": {
|
||||
"filters": {
|
||||
"date_range": {"start": "2024-01-01", "end": "2024-12-31"},
|
||||
"categories": ["tech", "science"],
|
||||
{
|
||||
"query.filters.date_range.start": "2024-01-01",
|
||||
"query.filters.date_range.end": "2024-12-31",
|
||||
"query.filters.categories": ["tech", "science"],
|
||||
},
|
||||
{"query": {"filters": {"date_range": {}}}}, # Missing required start and end
|
||||
None,
|
||||
),
|
||||
# Complex $ref with nested structure
|
||||
(
|
||||
"ref_nested_structure",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"order": {"$ref": "#/$defs/OrderParams"}},
|
||||
"required": ["order"],
|
||||
"$defs": {
|
||||
"OrderParams": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"customer": {"$ref": "#/$defs/Customer"},
|
||||
"items": {"type": "array", "items": {"$ref": "#/$defs/OrderItem"}},
|
||||
},
|
||||
"required": ["customer", "items"],
|
||||
},
|
||||
"Customer": {
|
||||
"type": "object",
|
||||
"properties": {"id": {"type": "integer"}, "email": {"type": "string"}},
|
||||
"required": ["id", "email"],
|
||||
},
|
||||
"OrderItem": {
|
||||
"type": "object",
|
||||
"properties": {"product_id": {"type": "string"}, "quantity": {"type": "integer"}},
|
||||
"required": ["product_id", "quantity"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"order": {
|
||||
"customer": {"id": 123, "email": "test@example.com"},
|
||||
"items": [{"product_id": "prod1", "quantity": 2}],
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"query.filters.date_range.start": "2024-01-01",
|
||||
"query.filters.date_range.end": "2024-12-31",
|
||||
"query.filters.categories": ["tech", "science"],
|
||||
},
|
||||
{"query": {"filters": {"date_range": {}}}}, # Missing required start and end
|
||||
None,
|
||||
),
|
||||
# Complex $ref with nested structure
|
||||
(
|
||||
"ref_nested_structure",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"order": {"$ref": "#/$defs/OrderParams"}},
|
||||
"required": ["order"],
|
||||
"$defs": {
|
||||
"OrderParams": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"customer": {"$ref": "#/$defs/Customer"},
|
||||
"items": {"type": "array", "items": {"$ref": "#/$defs/OrderItem"}},
|
||||
},
|
||||
{
|
||||
"order.customer.id": 123,
|
||||
"order.customer.email": "test@example.com",
|
||||
"order.items[0].product_id": "prod1",
|
||||
"order.items[0].quantity": 2,
|
||||
},
|
||||
{"order": {"customer": {"id": 123}, "items": []}}, # Missing email
|
||||
lambda instance: isinstance(instance.order.customer, BaseModel),
|
||||
),
|
||||
# Mixed types (primitives, arrays, nested objects)
|
||||
(
|
||||
"mixed_types",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"simple_string": {"type": "string"},
|
||||
"simple_number": {"type": "integer"},
|
||||
"string_array": {"type": "array", "items": {"type": "string"}},
|
||||
"nested_config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"enabled": {"type": "boolean"},
|
||||
"options": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
"required": ["enabled"],
|
||||
},
|
||||
"required": ["customer", "items"],
|
||||
},
|
||||
"Customer": {
|
||||
"type": "object",
|
||||
"properties": {"id": {"type": "integer"}, "email": {"type": "string"}},
|
||||
"required": ["id", "email"],
|
||||
},
|
||||
"OrderItem": {
|
||||
"type": "object",
|
||||
"properties": {"product_id": {"type": "string"}, "quantity": {"type": "integer"}},
|
||||
"required": ["product_id", "quantity"],
|
||||
"required": ["simple_string", "nested_config"],
|
||||
},
|
||||
{
|
||||
"simple_string": "test",
|
||||
"simple_number": 42,
|
||||
"string_array": ["a", "b"],
|
||||
"nested_config": {"enabled": True, "options": ["opt1", "opt2"]},
|
||||
},
|
||||
{
|
||||
"simple_string": "test",
|
||||
"simple_number": 42,
|
||||
"string_array": ["a", "b"],
|
||||
"nested_config.enabled": True,
|
||||
"nested_config.options": ["opt1", "opt2"],
|
||||
},
|
||||
None,
|
||||
None,
|
||||
),
|
||||
# Empty schema (no properties)
|
||||
(
|
||||
"empty_schema",
|
||||
{"type": "object", "properties": {}},
|
||||
{},
|
||||
{},
|
||||
None,
|
||||
None,
|
||||
),
|
||||
# All primitive types
|
||||
(
|
||||
"all_primitives",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"string_field": {"type": "string"},
|
||||
"integer_field": {"type": "integer"},
|
||||
"number_field": {"type": "number"},
|
||||
"boolean_field": {"type": "boolean"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"order": {
|
||||
"customer": {"id": 123, "email": "test@example.com"},
|
||||
"items": [{"product_id": "prod1", "quantity": 2}],
|
||||
}
|
||||
},
|
||||
{
|
||||
"order.customer.id": 123,
|
||||
"order.customer.email": "test@example.com",
|
||||
"order.items[0].product_id": "prod1",
|
||||
"order.items[0].quantity": 2,
|
||||
},
|
||||
{"order": {"customer": {"id": 123}, "items": []}}, # Missing email
|
||||
lambda instance: isinstance(instance.order.customer, BaseModel),
|
||||
),
|
||||
# Mixed types (primitives, arrays, nested objects)
|
||||
(
|
||||
"mixed_types",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"simple_string": {"type": "string"},
|
||||
"simple_number": {"type": "integer"},
|
||||
"string_array": {"type": "array", "items": {"type": "string"}},
|
||||
"nested_config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"enabled": {"type": "boolean"},
|
||||
"options": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
"required": ["enabled"],
|
||||
},
|
||||
{"string_field": "test", "integer_field": 42, "number_field": 3.14, "boolean_field": True},
|
||||
{"string_field": "test", "integer_field": 42, "number_field": 3.14, "boolean_field": True},
|
||||
None,
|
||||
None,
|
||||
),
|
||||
# Edge case: unresolvable $ref (fallback to dict)
|
||||
(
|
||||
"unresolvable_ref",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"data": {"$ref": "#/$defs/NonExistent"}},
|
||||
"$defs": {},
|
||||
},
|
||||
"required": ["simple_string", "nested_config"],
|
||||
},
|
||||
{
|
||||
"simple_string": "test",
|
||||
"simple_number": 42,
|
||||
"string_array": ["a", "b"],
|
||||
"nested_config": {"enabled": True, "options": ["opt1", "opt2"]},
|
||||
},
|
||||
{
|
||||
"simple_string": "test",
|
||||
"simple_number": 42,
|
||||
"string_array": ["a", "b"],
|
||||
"nested_config.enabled": True,
|
||||
"nested_config.options": ["opt1", "opt2"],
|
||||
},
|
||||
None,
|
||||
None,
|
||||
),
|
||||
# Empty schema (no properties)
|
||||
(
|
||||
"empty_schema",
|
||||
{"type": "object", "properties": {}},
|
||||
{},
|
||||
{},
|
||||
None,
|
||||
None,
|
||||
),
|
||||
# All primitive types
|
||||
(
|
||||
"all_primitives",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"string_field": {"type": "string"},
|
||||
"integer_field": {"type": "integer"},
|
||||
"number_field": {"type": "number"},
|
||||
"boolean_field": {"type": "boolean"},
|
||||
{"data": {"key": "value"}},
|
||||
{"data": {"key": "value"}},
|
||||
None,
|
||||
None,
|
||||
),
|
||||
# Edge case: array without items schema (fallback to bare list)
|
||||
(
|
||||
"array_no_items",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"items": {"type": "array"}},
|
||||
},
|
||||
},
|
||||
{"string_field": "test", "integer_field": 42, "number_field": 3.14, "boolean_field": True},
|
||||
{"string_field": "test", "integer_field": 42, "number_field": 3.14, "boolean_field": True},
|
||||
None,
|
||||
None,
|
||||
),
|
||||
# Edge case: unresolvable $ref (fallback to dict)
|
||||
(
|
||||
"unresolvable_ref",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"data": {"$ref": "#/$defs/NonExistent"}},
|
||||
"$defs": {},
|
||||
},
|
||||
{"data": {"key": "value"}},
|
||||
{"data": {"key": "value"}},
|
||||
None,
|
||||
None,
|
||||
),
|
||||
# Edge case: array without items schema (fallback to bare list)
|
||||
(
|
||||
"array_no_items",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"items": {"type": "array"}},
|
||||
},
|
||||
{"items": [1, "two", 3.0]},
|
||||
{"items": [1, "two", 3.0]},
|
||||
None,
|
||||
None,
|
||||
),
|
||||
# Edge case: object without properties (fallback to dict)
|
||||
(
|
||||
"object_no_properties",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"config": {"type": "object"}},
|
||||
},
|
||||
{"config": {"arbitrary": "data", "nested": {"key": "value"}}},
|
||||
{"config": {"arbitrary": "data", "nested": {"key": "value"}}},
|
||||
None,
|
||||
None,
|
||||
),
|
||||
{"items": [1, "two", 3.0]},
|
||||
{"items": [1, "two", 3.0]},
|
||||
None,
|
||||
None,
|
||||
),
|
||||
# Edge case: object without properties (fallback to dict)
|
||||
(
|
||||
"object_no_properties",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"config": {"type": "object"}},
|
||||
},
|
||||
{"config": {"arbitrary": "data", "nested": {"key": "value"}}},
|
||||
{"config": {"arbitrary": "data", "nested": {"key": "value"}}},
|
||||
None,
|
||||
None,
|
||||
),
|
||||
]
|
||||
],
|
||||
)
|
||||
def test_get_input_model_from_mcp_tool_parametrized(
|
||||
test_id, input_schema, valid_data, expected_values, invalid_data, validation_check
|
||||
):
|
||||
"""Parametrized test for JSON schema to Pydantic model conversion.
|
||||
def test_get_input_model_from_mcp_tool_parametrized(test_id: str, input_schema: dict[str, Any]) -> None:
|
||||
"""Parametrized test for MCP tool input schema passthrough.
|
||||
|
||||
This test covers various edge cases including:
|
||||
- Basic types with required/optional fields
|
||||
- Nested objects
|
||||
- $ref resolution
|
||||
- Typed arrays (strings, integers, objects)
|
||||
- Deeply nested structures
|
||||
- Complex $ref with nested structures
|
||||
- Mixed types
|
||||
This test verifies that MCP tool schemas are passed through as-is
|
||||
without Pydantic conversion, which improves performance and preserves
|
||||
the original schema structure.
|
||||
|
||||
To add a new test case, add a tuple to the parametrize decorator with:
|
||||
- test_id: A descriptive name for the test case
|
||||
- input_schema: The JSON schema (inputSchema dict)
|
||||
- valid_data: Valid data to instantiate the model
|
||||
- expected_values: Dict of expected values (supports dot notation for nested access)
|
||||
- invalid_data: Invalid data to test validation errors (None to skip)
|
||||
- validation_check: Optional callable to perform additional validation checks
|
||||
"""
|
||||
tool = types.Tool(name="test_tool", description="A test tool", inputSchema=input_schema)
|
||||
model = _get_input_model_from_mcp_tool(tool)
|
||||
schema = tool.inputSchema
|
||||
|
||||
# Test valid data
|
||||
instance = model(**valid_data)
|
||||
|
||||
# Check expected values
|
||||
for field_path, expected_value in expected_values.items():
|
||||
# Support dot notation and array indexing for nested access
|
||||
current = instance
|
||||
parts = field_path.replace("]", "").replace("[", ".").split(".")
|
||||
for part in parts:
|
||||
current = current[int(part)] if part.isdigit() else getattr(current, part)
|
||||
assert current == expected_value, f"Field {field_path} = {current}, expected {expected_value}"
|
||||
|
||||
# Run additional validation checks if provided
|
||||
if validation_check:
|
||||
assert validation_check(instance), f"Validation check failed for {test_id}"
|
||||
|
||||
# Test invalid data if provided
|
||||
if invalid_data is not None:
|
||||
with pytest.raises(ValidationError):
|
||||
model(**invalid_data)
|
||||
# Verify schema is returned as-is (dict)
|
||||
assert isinstance(schema, dict), f"Expected dict, got {type(schema)}"
|
||||
assert schema == input_schema, "Schema should be passed through unchanged"
|
||||
|
||||
|
||||
def test_get_input_model_from_mcp_prompt():
|
||||
"""Test creation of input model from MCP prompt."""
|
||||
"""Test creation of input schema from MCP prompt."""
|
||||
prompt = types.Prompt(
|
||||
name="test_prompt",
|
||||
description="A test prompt",
|
||||
@@ -641,16 +615,24 @@ def test_get_input_model_from_mcp_prompt():
|
||||
types.PromptArgument(name="arg2", description="Second argument", required=False),
|
||||
],
|
||||
)
|
||||
model = _get_input_model_from_mcp_prompt(prompt)
|
||||
result = _get_input_model_from_mcp_prompt(prompt)
|
||||
|
||||
# Create an instance to verify the model works
|
||||
instance = model(arg1="test", arg2="optional")
|
||||
assert instance.arg1 == "test"
|
||||
assert instance.arg2 == "optional"
|
||||
# Should return a dict (schema)
|
||||
assert isinstance(result, dict), f"Expected dict, got {type(result)}"
|
||||
assert result["type"] == "object"
|
||||
assert "arg1" in result["properties"]
|
||||
assert "arg2" in result["properties"]
|
||||
assert "arg1" in result["required"]
|
||||
assert "arg2" not in result["required"]
|
||||
|
||||
# Test validation
|
||||
with pytest.raises(ValidationError): # Missing required arg1
|
||||
model(arg2="optional")
|
||||
|
||||
def test_get_input_model_from_mcp_prompt_without_arguments():
|
||||
"""Test prompt schema generation when no prompt arguments are defined."""
|
||||
prompt = types.Prompt(name="empty_prompt", description="No args prompt", arguments=[])
|
||||
result = _get_input_model_from_mcp_prompt(prompt)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result == {"type": "object", "properties": {}}
|
||||
|
||||
|
||||
# MCPTool tests
|
||||
|
||||
@@ -74,7 +74,7 @@ class TestAgentContext:
|
||||
class TestFunctionInvocationContext:
|
||||
"""Test cases for FunctionInvocationContext."""
|
||||
|
||||
def test_init_with_defaults(self, mock_function: FunctionTool[Any]) -> None:
|
||||
def test_init_with_defaults(self, mock_function: FunctionTool) -> None:
|
||||
"""Test FunctionInvocationContext initialization with default values."""
|
||||
arguments = FunctionTestArgs(name="test")
|
||||
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
|
||||
@@ -83,7 +83,7 @@ class TestFunctionInvocationContext:
|
||||
assert context.arguments == arguments
|
||||
assert context.metadata == {}
|
||||
|
||||
def test_init_with_custom_metadata(self, mock_function: FunctionTool[Any]) -> None:
|
||||
def test_init_with_custom_metadata(self, mock_function: FunctionTool) -> None:
|
||||
"""Test FunctionInvocationContext initialization with custom metadata."""
|
||||
arguments = FunctionTestArgs(name="test")
|
||||
metadata = {"key": "value"}
|
||||
@@ -420,7 +420,7 @@ class TestFunctionMiddlewarePipeline:
|
||||
await call_next()
|
||||
raise MiddlewareTermination
|
||||
|
||||
async def test_execute_with_pre_next_termination(self, mock_function: FunctionTool[Any]) -> None:
|
||||
async def test_execute_with_pre_next_termination(self, mock_function: FunctionTool) -> None:
|
||||
"""Test pipeline execution with termination before next() raises MiddlewareTermination."""
|
||||
middleware = self.PreNextTerminateFunctionMiddleware()
|
||||
pipeline = FunctionMiddlewarePipeline(middleware)
|
||||
@@ -439,7 +439,7 @@ class TestFunctionMiddlewarePipeline:
|
||||
# Handler should not be called when terminated before next()
|
||||
assert execution_order == []
|
||||
|
||||
async def test_execute_with_post_next_termination(self, mock_function: FunctionTool[Any]) -> None:
|
||||
async def test_execute_with_post_next_termination(self, mock_function: FunctionTool) -> None:
|
||||
"""Test pipeline execution with termination after next() raises MiddlewareTermination."""
|
||||
middleware = self.PostNextTerminateFunctionMiddleware()
|
||||
pipeline = FunctionMiddlewarePipeline(middleware)
|
||||
@@ -480,7 +480,7 @@ class TestFunctionMiddlewarePipeline:
|
||||
pipeline = FunctionMiddlewarePipeline(test_middleware)
|
||||
assert pipeline.has_middlewares
|
||||
|
||||
async def test_execute_no_middleware(self, mock_function: FunctionTool[Any]) -> None:
|
||||
async def test_execute_no_middleware(self, mock_function: FunctionTool) -> None:
|
||||
"""Test pipeline execution with no middleware."""
|
||||
pipeline = FunctionMiddlewarePipeline()
|
||||
arguments = FunctionTestArgs(name="test")
|
||||
@@ -494,7 +494,7 @@ class TestFunctionMiddlewarePipeline:
|
||||
result = await pipeline.execute(context, final_handler)
|
||||
assert result == expected_result
|
||||
|
||||
async def test_execute_with_middleware(self, mock_function: FunctionTool[Any]) -> None:
|
||||
async def test_execute_with_middleware(self, mock_function: FunctionTool) -> None:
|
||||
"""Test pipeline execution with middleware."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
@@ -787,7 +787,7 @@ class TestClassBasedMiddleware:
|
||||
assert context.metadata["after"] is True
|
||||
assert metadata_updates == ["before", "handler", "after"]
|
||||
|
||||
async def test_function_middleware_execution(self, mock_function: FunctionTool[Any]) -> None:
|
||||
async def test_function_middleware_execution(self, mock_function: FunctionTool) -> None:
|
||||
"""Test class-based function middleware execution."""
|
||||
metadata_updates: list[str] = []
|
||||
|
||||
@@ -847,7 +847,7 @@ class TestFunctionBasedMiddleware:
|
||||
assert context.metadata["function_middleware"] is True
|
||||
assert execution_order == ["function_before", "handler", "function_after"]
|
||||
|
||||
async def test_function_function_middleware(self, mock_function: FunctionTool[Any]) -> None:
|
||||
async def test_function_function_middleware(self, mock_function: FunctionTool) -> None:
|
||||
"""Test function-based function middleware."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
@@ -905,7 +905,7 @@ class TestMixedMiddleware:
|
||||
assert result is not None
|
||||
assert execution_order == ["class_before", "function_before", "handler", "function_after", "class_after"]
|
||||
|
||||
async def test_mixed_function_middleware(self, mock_function: FunctionTool[Any]) -> None:
|
||||
async def test_mixed_function_middleware(self, mock_function: FunctionTool) -> None:
|
||||
"""Test mixed class and function-based function middleware."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
@@ -1017,7 +1017,7 @@ class TestMultipleMiddlewareOrdering:
|
||||
]
|
||||
assert execution_order == expected_order
|
||||
|
||||
async def test_function_middleware_execution_order(self, mock_function: FunctionTool[Any]) -> None:
|
||||
async def test_function_middleware_execution_order(self, mock_function: FunctionTool) -> None:
|
||||
"""Test that multiple function middleware execute in registration order."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
@@ -1143,7 +1143,7 @@ class TestContextContentValidation:
|
||||
result = await pipeline.execute(context, final_handler)
|
||||
assert result is not None
|
||||
|
||||
async def test_function_context_validation(self, mock_function: FunctionTool[Any]) -> None:
|
||||
async def test_function_context_validation(self, mock_function: FunctionTool) -> None:
|
||||
"""Test that function context contains expected data."""
|
||||
|
||||
class ContextValidationMiddleware(FunctionMiddleware):
|
||||
@@ -1489,7 +1489,7 @@ class TestMiddlewareExecutionControl:
|
||||
assert not handler_called
|
||||
assert context.result is None
|
||||
|
||||
async def test_function_middleware_no_next_no_execution(self, mock_function: FunctionTool[Any]) -> None:
|
||||
async def test_function_middleware_no_next_no_execution(self, mock_function: FunctionTool) -> None:
|
||||
"""Test that when function middleware doesn't call next(), no execution happens."""
|
||||
|
||||
class FunctionTestArgs(BaseModel):
|
||||
@@ -1666,9 +1666,9 @@ def mock_agent() -> SupportsAgentRun:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_function() -> FunctionTool[Any]:
|
||||
def mock_function() -> FunctionTool:
|
||||
"""Mock function for testing."""
|
||||
function = MagicMock(spec=FunctionTool[Any])
|
||||
function = MagicMock(spec=FunctionTool)
|
||||
function.name = "test_function"
|
||||
return function
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@@ -103,7 +102,7 @@ class TestResultOverrideMiddleware:
|
||||
assert updates[0].text == "overridden"
|
||||
assert updates[1].text == " stream"
|
||||
|
||||
async def test_function_middleware_result_override(self, mock_function: FunctionTool[Any]) -> None:
|
||||
async def test_function_middleware_result_override(self, mock_function: FunctionTool) -> None:
|
||||
"""Test that function middleware can override result."""
|
||||
override_result = "overridden function result"
|
||||
|
||||
@@ -252,7 +251,7 @@ class TestResultOverrideMiddleware:
|
||||
assert execute_result.messages[0].text == "executed response"
|
||||
assert handler_called
|
||||
|
||||
async def test_function_middleware_conditional_no_next(self, mock_function: FunctionTool[Any]) -> None:
|
||||
async def test_function_middleware_conditional_no_next(self, mock_function: FunctionTool) -> None:
|
||||
"""Test that when function middleware conditionally doesn't call next(), no execution happens."""
|
||||
|
||||
class ConditionalNoNextFunctionMiddleware(FunctionMiddleware):
|
||||
@@ -335,7 +334,7 @@ class TestResultObservability:
|
||||
assert observed_responses[0].messages[0].text == "executed response"
|
||||
assert result == observed_responses[0]
|
||||
|
||||
async def test_function_middleware_result_observability(self, mock_function: FunctionTool[Any]) -> None:
|
||||
async def test_function_middleware_result_observability(self, mock_function: FunctionTool) -> None:
|
||||
"""Test that middleware can observe function result after execution."""
|
||||
observed_results: list[str] = []
|
||||
|
||||
@@ -402,7 +401,7 @@ class TestResultObservability:
|
||||
assert result is not None
|
||||
assert result.messages[0].text == "modified after execution"
|
||||
|
||||
async def test_function_middleware_post_execution_override(self, mock_function: FunctionTool[Any]) -> None:
|
||||
async def test_function_middleware_post_execution_override(self, mock_function: FunctionTool) -> None:
|
||||
"""Test that middleware can override function result after observing execution."""
|
||||
|
||||
class PostExecutionOverrideMiddleware(FunctionMiddleware):
|
||||
@@ -444,8 +443,8 @@ def mock_agent() -> SupportsAgentRun:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_function() -> FunctionTool[Any]:
|
||||
def mock_function() -> FunctionTool:
|
||||
"""Mock function for testing."""
|
||||
function = MagicMock(spec=FunctionTool[Any])
|
||||
function = MagicMock(spec=FunctionTool)
|
||||
function.name = "test_function"
|
||||
return function
|
||||
|
||||
@@ -108,6 +108,90 @@ def test_tool_decorator_with_json_schema_dict():
|
||||
assert search("hello") == "Searching for: hello (max 10)"
|
||||
|
||||
|
||||
async def test_tool_decorator_with_json_schema_invoke_uses_mapping():
|
||||
"""Test that schema-based tools can be invoked directly with mapping arguments."""
|
||||
|
||||
json_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"max_results": {"type": "integer"},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@tool(name="search", description="Search tool", schema=json_schema)
|
||||
def search(query: str, max_results: int = 10) -> str:
|
||||
return f"{query}:{max_results}"
|
||||
|
||||
result = await search.invoke(arguments={"query": "hello", "max_results": 3})
|
||||
assert result == "hello:3"
|
||||
|
||||
|
||||
async def test_tool_decorator_with_json_schema_invoke_missing_required():
|
||||
"""Test schema-required fields are checked for mapping arguments."""
|
||||
|
||||
json_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@tool(name="search", description="Search tool", schema=json_schema)
|
||||
def search(query: str) -> str:
|
||||
return query
|
||||
|
||||
with pytest.raises(TypeError, match="Missing required argument"):
|
||||
await search.invoke(arguments={})
|
||||
|
||||
|
||||
async def test_tool_decorator_with_json_schema_invoke_invalid_type():
|
||||
"""Test schema type checks run for mapping arguments."""
|
||||
|
||||
json_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"max_results": {"type": "integer"},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@tool(name="search", description="Search tool", schema=json_schema)
|
||||
def search(query: str, max_results: int = 10) -> str:
|
||||
return f"{query}:{max_results}"
|
||||
|
||||
with pytest.raises(TypeError, match="Invalid type for 'max_results'"):
|
||||
await search.invoke(arguments={"query": "hello", "max_results": "three"})
|
||||
|
||||
|
||||
def test_tool_decorator_with_json_schema_preserves_custom_properties():
|
||||
"""Test schema passthrough keeps custom JSON schema properties."""
|
||||
|
||||
json_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"priority": {
|
||||
"type": "string",
|
||||
"enum": ["low", "medium", "high"],
|
||||
"x-custom-field": "custom-value",
|
||||
},
|
||||
},
|
||||
"required": ["priority"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
@tool(name="process", description="Process tool", schema=json_schema)
|
||||
def process(priority: str) -> str:
|
||||
return priority
|
||||
|
||||
params = process.parameters()
|
||||
assert not params.get("additionalProperties")
|
||||
assert params["properties"]["priority"]["x-custom-field"] == "custom-value"
|
||||
|
||||
|
||||
def test_tool_decorator_schema_none_default():
|
||||
"""Test that schema=None (default) still infers from function signature."""
|
||||
|
||||
@@ -555,7 +639,7 @@ async def test_tool_invoke_telemetry_with_pydantic_args(span_exporter: InMemoryS
|
||||
assert span.attributes[OtelAttr.TOOL_CALL_ID] == "pydantic_call"
|
||||
assert span.attributes[OtelAttr.TOOL_TYPE] == "function"
|
||||
assert span.attributes[OtelAttr.TOOL_DESCRIPTION] == "A test tool with Pydantic args"
|
||||
assert span.attributes[OtelAttr.TOOL_ARGUMENTS] == '{"x":5,"y":10}'
|
||||
assert span.attributes[OtelAttr.TOOL_ARGUMENTS] == '{"x": 5, "y": 10}'
|
||||
|
||||
|
||||
async def test_tool_invoke_telemetry_with_exception(span_exporter: InMemorySpanExporter):
|
||||
|
||||
@@ -499,7 +499,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
|
||||
return copilot_tools
|
||||
|
||||
def _tool_to_copilot_tool(self, ai_func: FunctionTool[Any]) -> CopilotTool:
|
||||
def _tool_to_copilot_tool(self, ai_func: FunctionTool) -> CopilotTool:
|
||||
"""Convert an FunctionTool to a Copilot SDK tool."""
|
||||
|
||||
async def handler(invocation: ToolInvocation) -> ToolResult:
|
||||
|
||||
@@ -27,7 +27,7 @@ from tau2.environment.tool import Tool # type: ignore[import-untyped]
|
||||
_original_set_state = Environment.set_state
|
||||
|
||||
|
||||
def convert_tau2_tool_to_function_tool(tau2_tool: Tool) -> FunctionTool[Any]:
|
||||
def convert_tau2_tool_to_function_tool(tau2_tool: Tool) -> FunctionTool:
|
||||
"""Convert a tau2 Tool to a FunctionTool for agent framework compatibility.
|
||||
|
||||
Creates a wrapper that preserves the tool's interface while ensuring
|
||||
|
||||
@@ -324,7 +324,7 @@ class HandoffAgentExecutor(AgentExecutor):
|
||||
existing_tools = list(default_options.get("tools") or [])
|
||||
existing_names = {getattr(tool, "name", "") for tool in existing_tools if hasattr(tool, "name")}
|
||||
|
||||
new_tools: list[FunctionTool[Any]] = []
|
||||
new_tools: list[FunctionTool] = []
|
||||
for target in targets:
|
||||
handoff_tool = self._create_handoff_tool(target.target_id, target.description)
|
||||
if handoff_tool.name in existing_names:
|
||||
@@ -340,7 +340,7 @@ class HandoffAgentExecutor(AgentExecutor):
|
||||
else:
|
||||
default_options["tools"] = existing_tools
|
||||
|
||||
def _create_handoff_tool(self, target_id: str, description: str | None = None) -> FunctionTool[Any]:
|
||||
def _create_handoff_tool(self, target_id: str, description: str | None = None) -> FunctionTool:
|
||||
"""Construct the synthetic handoff tool that signals routing to `target_id`."""
|
||||
tool_name = get_handoff_tool_name(target_id)
|
||||
doc = description or f"Handoff to the {target_id} agent."
|
||||
|
||||
Reference in New Issue
Block a user