mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: MCP Improvements: improved connection loss behavior, pagination for loading and a param to control representation (#3154)
* pagination support (#2848) added a parse_tool_result param and connection loss (#2884) * fix #3153 * improved connection handling * improved logic
This commit is contained in:
committed by
GitHub
Unverified
parent
203fb7b1c4
commit
b2893fbc00
@@ -4,13 +4,14 @@ import logging
|
||||
import re
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Collection, Sequence
|
||||
from collections.abc import Callable, Collection, Sequence
|
||||
from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore
|
||||
from datetime import timedelta
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import httpx
|
||||
from anyio import ClosedResourceError
|
||||
from mcp import types
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.stdio import StdioServerParameters, stdio_client
|
||||
@@ -21,7 +22,11 @@ from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.session import RequestResponder
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
from ._tools import AIFunction, HostedMCPSpecificApproval, _build_pydantic_model_from_json_schema
|
||||
from ._tools import (
|
||||
AIFunction,
|
||||
HostedMCPSpecificApproval,
|
||||
_build_pydantic_model_from_json_schema,
|
||||
)
|
||||
from ._types import (
|
||||
ChatMessage,
|
||||
Contents,
|
||||
@@ -329,7 +334,9 @@ class MCPTool:
|
||||
approval_mode: (Literal["always_require", "never_require"] | HostedMCPSpecificApproval | None) = None,
|
||||
allowed_tools: Collection[str] | None = None,
|
||||
load_tools: bool = True,
|
||||
parse_tool_results: Literal[True] | Callable[[types.CallToolResult], Any] | None = True,
|
||||
load_prompts: bool = True,
|
||||
parse_prompt_results: Literal[True] | Callable[[types.GetPromptResult], Any] | None = True,
|
||||
session: ClientSession | None = None,
|
||||
request_timeout: int | None = None,
|
||||
chat_client: "ChatClientProtocol | None" = None,
|
||||
@@ -347,7 +354,9 @@ class MCPTool:
|
||||
self.allowed_tools = allowed_tools
|
||||
self.additional_properties = additional_properties
|
||||
self.load_tools_flag = load_tools
|
||||
self.parse_tool_results = parse_tool_results
|
||||
self.load_prompts_flag = load_prompts
|
||||
self.parse_prompt_results = parse_prompt_results
|
||||
self._exit_stack = AsyncExitStack()
|
||||
self.session = session
|
||||
self.request_timeout = request_timeout
|
||||
@@ -367,15 +376,23 @@ class MCPTool:
|
||||
return self._functions
|
||||
return [func for func in self._functions if func.name in self.allowed_tools]
|
||||
|
||||
async def connect(self) -> None:
|
||||
async def connect(self, *, reset: bool = False) -> None:
|
||||
"""Connect to the MCP server.
|
||||
|
||||
Establishes a connection to the MCP server, initializes the session,
|
||||
and loads tools and prompts if configured to do so.
|
||||
|
||||
Keyword Args:
|
||||
reset: If True, forces a reconnection even if already connected.
|
||||
|
||||
Raises:
|
||||
ToolException: If connection or session initialization fails.
|
||||
"""
|
||||
if reset:
|
||||
await self._exit_stack.aclose()
|
||||
self.session = None
|
||||
self.is_connected = False
|
||||
self._exit_stack = AsyncExitStack()
|
||||
if not self.session:
|
||||
try:
|
||||
transport = await self._exit_stack.enter_async_context(self.get_mcp_client())
|
||||
@@ -565,86 +582,88 @@ class MCPTool:
|
||||
"""Load prompts from the MCP server.
|
||||
|
||||
Retrieves available prompts from the connected MCP server and converts
|
||||
them into AIFunction instances.
|
||||
them into AIFunction instances. Handles pagination automatically.
|
||||
|
||||
Raises:
|
||||
ToolExecutionException: If the MCP server is not connected.
|
||||
"""
|
||||
if not self.session:
|
||||
raise ToolExecutionException("MCP server not connected, please call connect() before using this method.")
|
||||
try:
|
||||
prompt_list = await self.session.list_prompts()
|
||||
except Exception as exc:
|
||||
logger.info(
|
||||
"Prompt could not be loaded, you can exclude trying to load, by setting: load_prompts=False",
|
||||
exc_info=exc,
|
||||
)
|
||||
prompt_list = None
|
||||
|
||||
# Track existing function names to prevent duplicates
|
||||
existing_names = {func.name for func in self._functions}
|
||||
|
||||
for prompt in prompt_list.prompts if prompt_list else []:
|
||||
local_name = _normalize_mcp_name(prompt.name)
|
||||
params: types.PaginatedRequestParams | None = None
|
||||
while True:
|
||||
# Ensure connection is still valid before each page request
|
||||
await self._ensure_connected()
|
||||
|
||||
# Skip if already loaded
|
||||
if local_name in existing_names:
|
||||
continue
|
||||
prompt_list = await self.session.list_prompts(params=params) # type: ignore[union-attr]
|
||||
|
||||
input_model = _get_input_model_from_mcp_prompt(prompt)
|
||||
approval_mode = self._determine_approval_mode(local_name)
|
||||
func: AIFunction[BaseModel, list[ChatMessage]] = AIFunction(
|
||||
func=partial(self.get_prompt, prompt.name),
|
||||
name=local_name,
|
||||
description=prompt.description or "",
|
||||
approval_mode=approval_mode,
|
||||
input_model=input_model,
|
||||
)
|
||||
self._functions.append(func)
|
||||
existing_names.add(local_name)
|
||||
for prompt in prompt_list.prompts:
|
||||
local_name = _normalize_mcp_name(prompt.name)
|
||||
|
||||
# Skip if already loaded
|
||||
if local_name in existing_names:
|
||||
continue
|
||||
|
||||
input_model = _get_input_model_from_mcp_prompt(prompt)
|
||||
approval_mode = self._determine_approval_mode(local_name)
|
||||
func: AIFunction[BaseModel, list[ChatMessage] | Any | types.GetPromptResult] = AIFunction(
|
||||
func=partial(self.get_prompt, prompt.name),
|
||||
name=local_name,
|
||||
description=prompt.description or "",
|
||||
approval_mode=approval_mode,
|
||||
input_model=input_model,
|
||||
)
|
||||
self._functions.append(func)
|
||||
existing_names.add(local_name)
|
||||
|
||||
# Check if there are more pages
|
||||
if not prompt_list or not prompt_list.nextCursor:
|
||||
break
|
||||
params = types.PaginatedRequestParams(cursor=prompt_list.nextCursor)
|
||||
|
||||
async def load_tools(self) -> None:
|
||||
"""Load tools from the MCP server.
|
||||
|
||||
Retrieves available tools from the connected MCP server and converts
|
||||
them into AIFunction instances.
|
||||
them into AIFunction instances. Handles pagination automatically.
|
||||
|
||||
Raises:
|
||||
ToolExecutionException: If the MCP server is not connected.
|
||||
"""
|
||||
if not self.session:
|
||||
raise ToolExecutionException("MCP server not connected, please call connect() before using this method.")
|
||||
try:
|
||||
tool_list = await self.session.list_tools()
|
||||
except Exception as exc:
|
||||
logger.info(
|
||||
"Tools could not be loaded, you can exclude trying to load, by setting: load_tools=False",
|
||||
exc_info=exc,
|
||||
)
|
||||
tool_list = None
|
||||
|
||||
# Track existing function names to prevent duplicates
|
||||
existing_names = {func.name for func in self._functions}
|
||||
|
||||
for tool in tool_list.tools if tool_list else []:
|
||||
local_name = _normalize_mcp_name(tool.name)
|
||||
params: types.PaginatedRequestParams | None = None
|
||||
while True:
|
||||
# Ensure connection is still valid before each page request
|
||||
await self._ensure_connected()
|
||||
|
||||
# Skip if already loaded
|
||||
if local_name in existing_names:
|
||||
continue
|
||||
tool_list = await self.session.list_tools(params=params) # type: ignore[union-attr]
|
||||
|
||||
input_model = _get_input_model_from_mcp_tool(tool)
|
||||
approval_mode = self._determine_approval_mode(local_name)
|
||||
# Create AIFunctions out of each tool
|
||||
func: AIFunction[BaseModel, list[Contents]] = AIFunction(
|
||||
func=partial(self.call_tool, tool.name),
|
||||
name=local_name,
|
||||
description=tool.description or "",
|
||||
approval_mode=approval_mode,
|
||||
input_model=input_model,
|
||||
)
|
||||
self._functions.append(func)
|
||||
existing_names.add(local_name)
|
||||
for tool in tool_list.tools:
|
||||
local_name = _normalize_mcp_name(tool.name)
|
||||
|
||||
# Skip if already loaded
|
||||
if local_name in existing_names:
|
||||
continue
|
||||
|
||||
input_model = _get_input_model_from_mcp_tool(tool)
|
||||
approval_mode = self._determine_approval_mode(local_name)
|
||||
# Create AIFunctions out of each tool
|
||||
func: AIFunction[BaseModel, list[Contents] | Any | types.CallToolResult] = AIFunction(
|
||||
func=partial(self.call_tool, tool.name),
|
||||
name=local_name,
|
||||
description=tool.description or "",
|
||||
approval_mode=approval_mode,
|
||||
input_model=input_model,
|
||||
)
|
||||
self._functions.append(func)
|
||||
existing_names.add(local_name)
|
||||
|
||||
# Check if there are more pages
|
||||
if not tool_list or not tool_list.nextCursor:
|
||||
break
|
||||
params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Disconnect from the MCP server.
|
||||
@@ -664,7 +683,28 @@ class MCPTool:
|
||||
"""
|
||||
pass
|
||||
|
||||
async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Contents]:
|
||||
async def _ensure_connected(self) -> None:
|
||||
"""Ensure the connection is valid, reconnecting if necessary.
|
||||
|
||||
This method proactively checks if the connection is valid and
|
||||
reconnects if it's not, avoiding the need to catch ClosedResourceError.
|
||||
|
||||
Raises:
|
||||
ToolExecutionException: If reconnection fails.
|
||||
"""
|
||||
try:
|
||||
await self.session.send_ping() # type: ignore[union-attr]
|
||||
except Exception:
|
||||
logger.info("MCP connection invalid or closed. Reconnecting...")
|
||||
try:
|
||||
await self.connect(reset=True)
|
||||
except Exception as ex:
|
||||
raise ToolExecutionException(
|
||||
"Failed to establish MCP connection.",
|
||||
inner_exception=ex,
|
||||
) from ex
|
||||
|
||||
async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Contents] | Any | types.CallToolResult:
|
||||
"""Call a tool with the given arguments.
|
||||
|
||||
Args:
|
||||
@@ -680,8 +720,6 @@ class MCPTool:
|
||||
ToolExecutionException: If the MCP server is not connected, tools are not loaded,
|
||||
or the tool call fails.
|
||||
"""
|
||||
if not self.session:
|
||||
raise ToolExecutionException("MCP server not connected, please call connect() before using this method.")
|
||||
if not self.load_tools_flag:
|
||||
raise ToolExecutionException(
|
||||
"Tools are not loaded for this server, please set load_tools=True in the constructor."
|
||||
@@ -692,16 +730,44 @@ class MCPTool:
|
||||
filtered_kwargs = {
|
||||
k: v for k, v in kwargs.items() if k not in {"chat_options", "tools", "tool_choice", "thread"}
|
||||
}
|
||||
try:
|
||||
return _parse_contents_from_mcp_tool_result(
|
||||
await self.session.call_tool(tool_name, arguments=filtered_kwargs)
|
||||
)
|
||||
except McpError as mcp_exc:
|
||||
raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc
|
||||
except Exception as ex:
|
||||
raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex
|
||||
|
||||
async def get_prompt(self, prompt_name: str, **kwargs: Any) -> list[ChatMessage]:
|
||||
# Try the operation, reconnecting once if the connection is closed
|
||||
for attempt in range(2):
|
||||
try:
|
||||
result = await self.session.call_tool(tool_name, arguments=filtered_kwargs) # type: ignore
|
||||
if self.parse_tool_results is None:
|
||||
return result
|
||||
if self.parse_tool_results is True:
|
||||
return _parse_contents_from_mcp_tool_result(result)
|
||||
if callable(self.parse_tool_results):
|
||||
return self.parse_tool_results(result)
|
||||
return result
|
||||
except ClosedResourceError as cl_ex:
|
||||
if attempt == 0:
|
||||
# First attempt failed, try reconnecting
|
||||
logger.info("MCP connection closed unexpectedly. Reconnecting...")
|
||||
try:
|
||||
await self.connect(reset=True)
|
||||
continue # Retry the operation
|
||||
except Exception as reconn_ex:
|
||||
raise ToolExecutionException(
|
||||
"Failed to reconnect to MCP server.",
|
||||
inner_exception=reconn_ex,
|
||||
) from reconn_ex
|
||||
else:
|
||||
# Second attempt also failed, give up
|
||||
logger.error(f"MCP connection closed unexpectedly after reconnection: {cl_ex}")
|
||||
raise ToolExecutionException(
|
||||
f"Failed to call tool '{tool_name}' - connection lost.",
|
||||
inner_exception=cl_ex,
|
||||
) from cl_ex
|
||||
except McpError as mcp_exc:
|
||||
raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc
|
||||
except Exception as ex:
|
||||
raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex
|
||||
raise ToolExecutionException(f"Failed to call tool '{tool_name}' after retries.")
|
||||
|
||||
async def get_prompt(self, prompt_name: str, **kwargs: Any) -> list[ChatMessage] | Any | types.GetPromptResult:
|
||||
"""Call a prompt with the given arguments.
|
||||
|
||||
Args:
|
||||
@@ -717,19 +783,46 @@ class MCPTool:
|
||||
ToolExecutionException: If the MCP server is not connected, prompts are not loaded,
|
||||
or the prompt call fails.
|
||||
"""
|
||||
if not self.session:
|
||||
raise ToolExecutionException("MCP server not connected, please call connect() before using this method.")
|
||||
if not self.load_prompts_flag:
|
||||
raise ToolExecutionException(
|
||||
"Prompts are not loaded for this server, please set load_prompts=True in the constructor."
|
||||
)
|
||||
try:
|
||||
prompt_result = await self.session.get_prompt(prompt_name, arguments=kwargs)
|
||||
return [_parse_message_from_mcp(message) for message in prompt_result.messages]
|
||||
except McpError as mcp_exc:
|
||||
raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc
|
||||
except Exception as ex:
|
||||
raise ToolExecutionException(f"Failed to call prompt '{prompt_name}'.", inner_exception=ex) from ex
|
||||
|
||||
# Try the operation, reconnecting once if the connection is closed
|
||||
for attempt in range(2):
|
||||
try:
|
||||
prompt_result = await self.session.get_prompt(prompt_name, arguments=kwargs) # type: ignore
|
||||
if self.parse_prompt_results is None:
|
||||
return prompt_result
|
||||
if self.parse_prompt_results is True:
|
||||
return [_parse_message_from_mcp(message) for message in prompt_result.messages]
|
||||
if callable(self.parse_prompt_results):
|
||||
return self.parse_prompt_results(prompt_result)
|
||||
return prompt_result
|
||||
except ClosedResourceError as cl_ex:
|
||||
if attempt == 0:
|
||||
# First attempt failed, try reconnecting
|
||||
logger.info("MCP connection closed unexpectedly. Reconnecting...")
|
||||
try:
|
||||
await self.connect(reset=True)
|
||||
continue # Retry the operation
|
||||
except Exception as reconn_ex:
|
||||
raise ToolExecutionException(
|
||||
"Failed to reconnect to MCP server.",
|
||||
inner_exception=reconn_ex,
|
||||
) from reconn_ex
|
||||
else:
|
||||
# Second attempt also failed, give up
|
||||
logger.error(f"MCP connection closed unexpectedly after reconnection: {cl_ex}")
|
||||
raise ToolExecutionException(
|
||||
f"Failed to call prompt '{prompt_name}' - connection lost.",
|
||||
inner_exception=cl_ex,
|
||||
) from cl_ex
|
||||
except McpError as mcp_exc:
|
||||
raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc
|
||||
except Exception as ex:
|
||||
raise ToolExecutionException(f"Failed to call prompt '{prompt_name}'.", inner_exception=ex) from ex
|
||||
raise ToolExecutionException(f"Failed to get prompt '{prompt_name}' after retries.")
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
"""Enter the async context manager.
|
||||
@@ -804,7 +897,9 @@ class MCPStdioTool(MCPTool):
|
||||
command: str,
|
||||
*,
|
||||
load_tools: bool = True,
|
||||
parse_tool_results: Literal[True] | Callable[[types.CallToolResult], Any] | None = True,
|
||||
load_prompts: bool = True,
|
||||
parse_prompt_results: Literal[True] | Callable[[types.GetPromptResult], Any] | None = True,
|
||||
request_timeout: int | None = None,
|
||||
session: ClientSession | None = None,
|
||||
description: str | None = None,
|
||||
@@ -830,7 +925,15 @@ class MCPStdioTool(MCPTool):
|
||||
|
||||
Keyword Args:
|
||||
load_tools: Whether to load tools from the MCP server.
|
||||
parse_tool_results: How to parse tool results from the MCP server.
|
||||
Set to True, to use the default parser that converts to Agent Framework types.
|
||||
Set to a callable to use a custom parser function.
|
||||
Set to None to return the raw MCP tool result.
|
||||
load_prompts: Whether to load prompts from the MCP server.
|
||||
parse_prompt_results: How to parse prompt results from the MCP server.
|
||||
Set to True, to use the default parser that converts to Agent Framework types.
|
||||
Set to a callable to use a custom parser function.
|
||||
Set to None to return the raw MCP prompt result.
|
||||
request_timeout: The default timeout in seconds for all requests.
|
||||
session: The session to use for the MCP connection.
|
||||
description: The description of the tool.
|
||||
@@ -857,7 +960,9 @@ class MCPStdioTool(MCPTool):
|
||||
session=session,
|
||||
chat_client=chat_client,
|
||||
load_tools=load_tools,
|
||||
parse_tool_results=parse_tool_results,
|
||||
load_prompts=load_prompts,
|
||||
parse_prompt_results=parse_prompt_results,
|
||||
request_timeout=request_timeout,
|
||||
)
|
||||
self.command = command
|
||||
@@ -913,7 +1018,9 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
url: str,
|
||||
*,
|
||||
load_tools: bool = True,
|
||||
parse_tool_results: Literal[True] | Callable[[types.CallToolResult], Any] | None = True,
|
||||
load_prompts: bool = True,
|
||||
parse_prompt_results: Literal[True] | Callable[[types.GetPromptResult], Any] | None = True,
|
||||
request_timeout: int | None = None,
|
||||
session: ClientSession | None = None,
|
||||
description: str | None = None,
|
||||
@@ -939,7 +1046,15 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
|
||||
Keyword Args:
|
||||
load_tools: Whether to load tools from the MCP server.
|
||||
parse_tool_results: How to parse tool results from the MCP server.
|
||||
Set to True, to use the default parser that converts to Agent Framework types.
|
||||
Set to a callable to use a custom parser function.
|
||||
Set to None to return the raw MCP tool result.
|
||||
load_prompts: Whether to load prompts from the MCP server.
|
||||
parse_prompt_results: How to parse prompt results from the MCP server.
|
||||
Set to True, to use the default parser that converts to Agent Framework types.
|
||||
Set to a callable to use a custom parser function.
|
||||
Set to None to return the raw MCP prompt result.
|
||||
request_timeout: The default timeout in seconds for all requests.
|
||||
session: The session to use for the MCP connection.
|
||||
description: The description of the tool.
|
||||
@@ -968,7 +1083,9 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
session=session,
|
||||
chat_client=chat_client,
|
||||
load_tools=load_tools,
|
||||
parse_tool_results=parse_tool_results,
|
||||
load_prompts=load_prompts,
|
||||
parse_prompt_results=parse_prompt_results,
|
||||
request_timeout=request_timeout,
|
||||
)
|
||||
self.url = url
|
||||
@@ -1016,7 +1133,9 @@ class MCPWebsocketTool(MCPTool):
|
||||
url: str,
|
||||
*,
|
||||
load_tools: bool = True,
|
||||
parse_tool_results: Literal[True] | Callable[[types.CallToolResult], Any] | None = True,
|
||||
load_prompts: bool = True,
|
||||
parse_prompt_results: Literal[True] | Callable[[types.GetPromptResult], Any] | None = True,
|
||||
request_timeout: int | None = None,
|
||||
session: ClientSession | None = None,
|
||||
description: str | None = None,
|
||||
@@ -1040,7 +1159,15 @@ class MCPWebsocketTool(MCPTool):
|
||||
|
||||
Keyword Args:
|
||||
load_tools: Whether to load tools from the MCP server.
|
||||
parse_tool_results: How to parse tool results from the MCP server.
|
||||
Set to True, to use the default parser that converts to Agent Framework types.
|
||||
Set to a callable to use a custom parser function.
|
||||
Set to None to return the raw MCP tool result.
|
||||
load_prompts: Whether to load prompts from the MCP server.
|
||||
parse_prompt_results: How to parse prompt results from the MCP server.
|
||||
Set to True, to use the default parser that converts to Agent Framework types.
|
||||
Set to a callable to use a custom parser function.
|
||||
Set to None to return the raw MCP prompt result.
|
||||
request_timeout: The default timeout in seconds for all requests.
|
||||
session: The session to use for the MCP connection.
|
||||
description: The description of the tool.
|
||||
@@ -1064,7 +1191,9 @@ class MCPWebsocketTool(MCPTool):
|
||||
session=session,
|
||||
chat_client=chat_client,
|
||||
load_tools=load_tools,
|
||||
parse_tool_results=parse_tool_results,
|
||||
load_prompts=load_prompts,
|
||||
parse_prompt_results=parse_prompt_results,
|
||||
request_timeout=request_timeout,
|
||||
)
|
||||
self.url = url
|
||||
|
||||
@@ -4,7 +4,15 @@ import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import sys
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, Collection, Mapping, MutableMapping, Sequence
|
||||
from collections.abc import (
|
||||
AsyncIterable,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Sequence,
|
||||
)
|
||||
from functools import wraps
|
||||
from time import perf_counter, time_ns
|
||||
from typing import (
|
||||
@@ -18,6 +26,7 @@ from typing import (
|
||||
Protocol,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
@@ -121,7 +130,13 @@ def _parse_inputs(
|
||||
if inputs is None:
|
||||
return []
|
||||
|
||||
from ._types import BaseContent, DataContent, HostedFileContent, HostedVectorStoreContent, UriContent
|
||||
from ._types import (
|
||||
BaseContent,
|
||||
DataContent,
|
||||
HostedFileContent,
|
||||
HostedVectorStoreContent,
|
||||
UriContent,
|
||||
)
|
||||
|
||||
parsed_inputs: list["Contents"] = []
|
||||
if not isinstance(inputs, list):
|
||||
@@ -1010,6 +1025,27 @@ def _build_pydantic_model_from_json_schema(
|
||||
if not properties:
|
||||
return create_model(f"{model_name}_input")
|
||||
|
||||
def _resolve_literal_type(prop_details: dict[str, Any]) -> type | None:
|
||||
"""Check if property should be a Literal type (const or enum).
|
||||
|
||||
Args:
|
||||
prop_details: The JSON Schema property details
|
||||
|
||||
Returns:
|
||||
Literal type if const or enum is present, None otherwise
|
||||
"""
|
||||
# const → Literal["value"]
|
||||
if "const" in prop_details:
|
||||
return Literal[prop_details["const"]] # type: ignore
|
||||
|
||||
# enum → Literal["a", "b", ...]
|
||||
if "enum" in prop_details and isinstance(prop_details["enum"], list):
|
||||
enum_values = prop_details["enum"]
|
||||
if enum_values:
|
||||
return Literal[tuple(enum_values)] # type: ignore
|
||||
|
||||
return None
|
||||
|
||||
def _resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type:
|
||||
"""Resolve JSON Schema type to Python type, handling $ref, nested objects, and typed arrays.
|
||||
|
||||
@@ -1020,6 +1056,31 @@ def _build_pydantic_model_from_json_schema(
|
||||
Returns:
|
||||
Python type annotation (could be int, str, list[str], or a nested Pydantic model)
|
||||
"""
|
||||
# Handle oneOf + discriminator (polymorphic objects)
|
||||
if "oneOf" in prop_details and "discriminator" in prop_details:
|
||||
discriminator = prop_details["discriminator"]
|
||||
disc_field = discriminator.get("propertyName")
|
||||
|
||||
variants = []
|
||||
for variant in prop_details["oneOf"]:
|
||||
if "$ref" in variant:
|
||||
ref = variant["$ref"]
|
||||
if ref.startswith("#/$defs/"):
|
||||
def_name = ref.split("/")[-1]
|
||||
resolved = definitions.get(def_name)
|
||||
if resolved:
|
||||
variant_model = _resolve_type(
|
||||
resolved,
|
||||
parent_name=f"{parent_name}_{def_name}",
|
||||
)
|
||||
variants.append(variant_model)
|
||||
|
||||
if variants and disc_field:
|
||||
return Annotated[
|
||||
Union[tuple(variants)], # type: ignore
|
||||
Field(discriminator=disc_field),
|
||||
]
|
||||
|
||||
# Handle $ref by resolving the reference
|
||||
if "$ref" in prop_details:
|
||||
ref = prop_details["$ref"]
|
||||
@@ -1070,9 +1131,15 @@ def _build_pydantic_model_from_json_schema(
|
||||
else nested_prop_details
|
||||
)
|
||||
|
||||
nested_python_type = _resolve_type(
|
||||
nested_prop_details, f"{nested_model_name}_{nested_prop_name}"
|
||||
)
|
||||
# Check for Literal types first (const/enum)
|
||||
literal_type = _resolve_literal_type(nested_prop_details)
|
||||
if literal_type is not None:
|
||||
nested_python_type = literal_type
|
||||
else:
|
||||
nested_python_type = _resolve_type(
|
||||
nested_prop_details,
|
||||
f"{nested_model_name}_{nested_prop_name}",
|
||||
)
|
||||
nested_description = nested_prop_details.get("description", "")
|
||||
|
||||
# Build field kwargs for nested property
|
||||
@@ -1109,7 +1176,12 @@ def _build_pydantic_model_from_json_schema(
|
||||
for prop_name, prop_details in properties.items():
|
||||
prop_details = json.loads(prop_details) if isinstance(prop_details, str) else prop_details
|
||||
|
||||
python_type = _resolve_type(prop_details, f"{model_name}_{prop_name}")
|
||||
# Check for Literal types first (const/enum)
|
||||
literal_type = _resolve_literal_type(prop_details)
|
||||
if literal_type is not None:
|
||||
python_type = literal_type
|
||||
else:
|
||||
python_type = _resolve_type(prop_details, f"{model_name}_{prop_name}")
|
||||
description = prop_details.get("description", "")
|
||||
|
||||
# Build field kwargs (description, etc.)
|
||||
|
||||
@@ -1248,6 +1248,75 @@ async def test_streamable_http_integration():
|
||||
assert result[0].text is not None
|
||||
|
||||
|
||||
@pytest.mark.flaky
|
||||
@skip_if_mcp_integration_tests_disabled
|
||||
async def test_mcp_connection_reset_integration():
|
||||
"""Test that connection reset works correctly with a real MCP server.
|
||||
|
||||
This integration test verifies:
|
||||
1. Initial connection and tool execution works
|
||||
2. Simulating connection failure triggers automatic reconnection
|
||||
3. Tool execution works after reconnection
|
||||
4. Exit stack cleanup happens properly during reconnection
|
||||
"""
|
||||
url = os.environ.get("LOCAL_MCP_URL")
|
||||
|
||||
tool = MCPStreamableHTTPTool(name="integration_test", url=url)
|
||||
|
||||
async with tool:
|
||||
# Verify initial connection
|
||||
assert tool.session is not None
|
||||
assert tool.is_connected is True
|
||||
assert len(tool.functions) > 0, "The MCP server should have at least one function."
|
||||
|
||||
# Get the first function and invoke it
|
||||
func = tool.functions[0]
|
||||
first_result = await func.invoke(query="What is Agent Framework?")
|
||||
assert first_result is not None
|
||||
assert len(first_result) > 0
|
||||
|
||||
# Store the original session and exit stack for comparison
|
||||
original_session = tool.session
|
||||
original_exit_stack = tool._exit_stack
|
||||
original_call_tool = tool.session.call_tool
|
||||
|
||||
# Simulate connection failure by making call_tool raise ClosedResourceError once
|
||||
call_count = 0
|
||||
|
||||
async def call_tool_with_error(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
# First call fails with connection error
|
||||
from anyio.streams.memory import ClosedResourceError
|
||||
|
||||
raise ClosedResourceError
|
||||
# After reconnection, delegate to the original method
|
||||
return await original_call_tool(*args, **kwargs)
|
||||
|
||||
tool.session.call_tool = call_tool_with_error
|
||||
|
||||
# Invoke the function again - this should trigger automatic reconnection on ClosedResourceError
|
||||
second_result = await func.invoke(query="What is Agent Framework?")
|
||||
assert second_result is not None
|
||||
assert len(second_result) > 0
|
||||
|
||||
# Verify we have a new session and exit stack after reconnection
|
||||
assert tool.session is not None
|
||||
assert tool.session is not original_session, "Session should be replaced after reconnection"
|
||||
assert tool._exit_stack is not original_exit_stack, "Exit stack should be replaced after reconnection"
|
||||
assert tool.is_connected is True
|
||||
|
||||
# Verify tools are still available after reconnection
|
||||
assert len(tool.functions) > 0
|
||||
|
||||
# Both results should be valid (we don't compare content as it may vary)
|
||||
if hasattr(first_result[0], "text"):
|
||||
assert first_result[0].text is not None
|
||||
if hasattr(second_result[0], "text"):
|
||||
assert second_result[0].text is not None
|
||||
|
||||
|
||||
async def test_mcp_tool_message_handler_notification():
|
||||
"""Test that message_handler correctly processes tools/list_changed and prompts/list_changed
|
||||
notifications."""
|
||||
@@ -1549,7 +1618,6 @@ def test_mcp_websocket_tool_get_mcp_client_with_kwargs():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_tool_deduplication():
|
||||
"""Test that MCP tools are not duplicated in MCPTool"""
|
||||
from agent_framework._mcp import MCPTool
|
||||
@@ -1611,7 +1679,6 @@ async def test_mcp_tool_deduplication():
|
||||
assert added_count == 1 # Only 1 new function added
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_tools_prevents_multiple_calls():
|
||||
"""Test that connect() prevents calling load_tools() multiple times"""
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
@@ -1627,6 +1694,7 @@ async def test_load_tools_prevents_multiple_calls():
|
||||
mock_session = AsyncMock()
|
||||
mock_tool_list = MagicMock()
|
||||
mock_tool_list.tools = []
|
||||
mock_tool_list.nextCursor = None # No pagination
|
||||
mock_session.list_tools = AsyncMock(return_value=mock_tool_list)
|
||||
mock_session.initialize = AsyncMock()
|
||||
|
||||
@@ -1650,7 +1718,6 @@ async def test_load_tools_prevents_multiple_calls():
|
||||
assert mock_session.list_tools.call_count == 1 # Still 1, not incremented
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_prompts_prevents_multiple_calls():
|
||||
"""Test that connect() prevents calling load_prompts() multiple times"""
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
@@ -1666,6 +1733,7 @@ async def test_load_prompts_prevents_multiple_calls():
|
||||
mock_session = AsyncMock()
|
||||
mock_prompt_list = MagicMock()
|
||||
mock_prompt_list.prompts = []
|
||||
mock_prompt_list.nextCursor = None # No pagination
|
||||
mock_session.list_prompts = AsyncMock(return_value=mock_prompt_list)
|
||||
|
||||
tool.session = mock_session
|
||||
@@ -1688,7 +1756,6 @@ async def test_load_prompts_prevents_multiple_calls():
|
||||
assert mock_session.list_prompts.call_count == 1 # Still 1, not incremented
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_streamable_http_tool_httpx_client_cleanup():
|
||||
"""Test that MCPStreamableHTTPTool properly passes through httpx clients."""
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
@@ -1744,3 +1811,556 @@ async def test_mcp_streamable_http_tool_httpx_client_cleanup():
|
||||
# Get the last call (should be from tool2.connect())
|
||||
call_args = mock_client.call_args
|
||||
assert call_args.kwargs["http_client"] is user_client, "User's client should be passed through"
|
||||
|
||||
|
||||
async def test_load_tools_with_pagination():
|
||||
"""Test that load_tools handles pagination correctly."""
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from agent_framework._mcp import MCPTool
|
||||
|
||||
tool = MCPTool(name="test_tool")
|
||||
|
||||
# Mock the session
|
||||
mock_session = AsyncMock()
|
||||
tool.session = mock_session
|
||||
tool.load_tools_flag = True
|
||||
|
||||
# Create paginated responses
|
||||
page1 = MagicMock()
|
||||
page1.tools = [
|
||||
types.Tool(
|
||||
name="tool_1",
|
||||
description="First tool",
|
||||
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
|
||||
),
|
||||
types.Tool(
|
||||
name="tool_2",
|
||||
description="Second tool",
|
||||
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
|
||||
),
|
||||
]
|
||||
page1.nextCursor = "cursor_page2"
|
||||
|
||||
page2 = MagicMock()
|
||||
page2.tools = [
|
||||
types.Tool(
|
||||
name="tool_3",
|
||||
description="Third tool",
|
||||
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
|
||||
),
|
||||
]
|
||||
page2.nextCursor = "cursor_page3"
|
||||
|
||||
page3 = MagicMock()
|
||||
page3.tools = [
|
||||
types.Tool(
|
||||
name="tool_4",
|
||||
description="Fourth tool",
|
||||
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
|
||||
),
|
||||
]
|
||||
page3.nextCursor = None # No more pages
|
||||
|
||||
# Mock list_tools to return different pages based on params
|
||||
async def mock_list_tools(params=None):
|
||||
if params is None:
|
||||
return page1
|
||||
if params.cursor == "cursor_page2":
|
||||
return page2
|
||||
if params.cursor == "cursor_page3":
|
||||
return page3
|
||||
raise ValueError("Unexpected cursor value")
|
||||
|
||||
mock_session.list_tools = AsyncMock(side_effect=mock_list_tools)
|
||||
|
||||
# Load tools with pagination
|
||||
await tool.load_tools()
|
||||
|
||||
# Verify all pages were fetched
|
||||
assert mock_session.list_tools.call_count == 3
|
||||
assert len(tool._functions) == 4
|
||||
assert [f.name for f in tool._functions] == ["tool_1", "tool_2", "tool_3", "tool_4"]
|
||||
|
||||
|
||||
async def test_load_prompts_with_pagination():
|
||||
"""Test that load_prompts handles pagination correctly."""
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from agent_framework._mcp import MCPTool
|
||||
|
||||
tool = MCPTool(name="test_tool")
|
||||
|
||||
# Mock the session
|
||||
mock_session = AsyncMock()
|
||||
tool.session = mock_session
|
||||
tool.load_prompts_flag = True
|
||||
|
||||
# Create paginated responses
|
||||
page1 = MagicMock()
|
||||
page1.prompts = [
|
||||
types.Prompt(
|
||||
name="prompt_1",
|
||||
description="First prompt",
|
||||
arguments=[types.PromptArgument(name="arg1", description="Arg 1", required=True)],
|
||||
),
|
||||
types.Prompt(
|
||||
name="prompt_2",
|
||||
description="Second prompt",
|
||||
arguments=[types.PromptArgument(name="arg2", description="Arg 2", required=True)],
|
||||
),
|
||||
]
|
||||
page1.nextCursor = "cursor_page2"
|
||||
|
||||
page2 = MagicMock()
|
||||
page2.prompts = [
|
||||
types.Prompt(
|
||||
name="prompt_3",
|
||||
description="Third prompt",
|
||||
arguments=[types.PromptArgument(name="arg3", description="Arg 3", required=False)],
|
||||
),
|
||||
]
|
||||
page2.nextCursor = None # No more pages
|
||||
|
||||
# Mock list_prompts to return different pages based on params
|
||||
async def mock_list_prompts(params=None):
|
||||
if params is None:
|
||||
return page1
|
||||
if params.cursor == "cursor_page2":
|
||||
return page2
|
||||
raise ValueError("Unexpected cursor value")
|
||||
|
||||
mock_session.list_prompts = AsyncMock(side_effect=mock_list_prompts)
|
||||
|
||||
# Load prompts with pagination
|
||||
await tool.load_prompts()
|
||||
|
||||
# Verify all pages were fetched
|
||||
assert mock_session.list_prompts.call_count == 2
|
||||
assert len(tool._functions) == 3
|
||||
assert [f.name for f in tool._functions] == ["prompt_1", "prompt_2", "prompt_3"]
|
||||
|
||||
|
||||
async def test_load_tools_pagination_with_duplicates():
|
||||
"""Test that load_tools prevents duplicates across paginated results."""
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from agent_framework._mcp import MCPTool
|
||||
|
||||
tool = MCPTool(name="test_tool")
|
||||
|
||||
# Mock the session
|
||||
mock_session = AsyncMock()
|
||||
tool.session = mock_session
|
||||
tool.load_tools_flag = True
|
||||
|
||||
# Create paginated responses with duplicate tool names
|
||||
page1 = MagicMock()
|
||||
page1.tools = [
|
||||
types.Tool(
|
||||
name="tool_1",
|
||||
description="First tool",
|
||||
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
|
||||
),
|
||||
types.Tool(
|
||||
name="tool_2",
|
||||
description="Second tool",
|
||||
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
|
||||
),
|
||||
]
|
||||
page1.nextCursor = "cursor_page2"
|
||||
|
||||
page2 = MagicMock()
|
||||
page2.tools = [
|
||||
types.Tool(
|
||||
name="tool_1", # Duplicate from page1
|
||||
description="Duplicate tool",
|
||||
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
|
||||
),
|
||||
types.Tool(
|
||||
name="tool_3",
|
||||
description="Third tool",
|
||||
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
|
||||
),
|
||||
]
|
||||
page2.nextCursor = None
|
||||
|
||||
# Mock list_tools to return different pages
|
||||
async def mock_list_tools(params=None):
|
||||
if params is None:
|
||||
return page1
|
||||
if params.cursor == "cursor_page2":
|
||||
return page2
|
||||
raise ValueError("Unexpected cursor value")
|
||||
|
||||
mock_session.list_tools = AsyncMock(side_effect=mock_list_tools)
|
||||
|
||||
# Load tools with pagination
|
||||
await tool.load_tools()
|
||||
|
||||
# Verify duplicates were skipped
|
||||
assert mock_session.list_tools.call_count == 2
|
||||
assert len(tool._functions) == 3
|
||||
assert [f.name for f in tool._functions] == ["tool_1", "tool_2", "tool_3"]
|
||||
|
||||
|
||||
async def test_load_prompts_pagination_with_duplicates():
|
||||
"""Test that load_prompts prevents duplicates across paginated results."""
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from agent_framework._mcp import MCPTool
|
||||
|
||||
tool = MCPTool(name="test_tool")
|
||||
|
||||
# Mock the session
|
||||
mock_session = AsyncMock()
|
||||
tool.session = mock_session
|
||||
tool.load_prompts_flag = True
|
||||
|
||||
# Create paginated responses with duplicate prompt names
|
||||
page1 = MagicMock()
|
||||
page1.prompts = [
|
||||
types.Prompt(
|
||||
name="prompt_1",
|
||||
description="First prompt",
|
||||
arguments=[types.PromptArgument(name="arg1", description="Arg 1", required=True)],
|
||||
),
|
||||
]
|
||||
page1.nextCursor = "cursor_page2"
|
||||
|
||||
page2 = MagicMock()
|
||||
page2.prompts = [
|
||||
types.Prompt(
|
||||
name="prompt_1", # Duplicate from page1
|
||||
description="Duplicate prompt",
|
||||
arguments=[types.PromptArgument(name="arg2", description="Arg 2", required=False)],
|
||||
),
|
||||
types.Prompt(
|
||||
name="prompt_2",
|
||||
description="Second prompt",
|
||||
arguments=[types.PromptArgument(name="arg3", description="Arg 3", required=True)],
|
||||
),
|
||||
]
|
||||
page2.nextCursor = None
|
||||
|
||||
# Mock list_prompts to return different pages
|
||||
async def mock_list_prompts(params=None):
|
||||
if params is None:
|
||||
return page1
|
||||
if params.cursor == "cursor_page2":
|
||||
return page2
|
||||
raise ValueError("Unexpected cursor value")
|
||||
|
||||
mock_session.list_prompts = AsyncMock(side_effect=mock_list_prompts)
|
||||
|
||||
# Load prompts with pagination
|
||||
await tool.load_prompts()
|
||||
|
||||
# Verify duplicates were skipped
|
||||
assert mock_session.list_prompts.call_count == 2
|
||||
assert len(tool._functions) == 2
|
||||
assert [f.name for f in tool._functions] == ["prompt_1", "prompt_2"]
|
||||
|
||||
|
||||
async def test_load_tools_pagination_exception_handling():
|
||||
"""Test that load_tools handles exceptions during pagination gracefully."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from agent_framework._mcp import MCPTool
|
||||
|
||||
tool = MCPTool(name="test_tool")
|
||||
|
||||
# Mock the session
|
||||
mock_session = AsyncMock()
|
||||
tool.session = mock_session
|
||||
tool.load_tools_flag = True
|
||||
|
||||
# Mock list_tools to raise an exception on first call
|
||||
mock_session.list_tools = AsyncMock(side_effect=RuntimeError("Connection error"))
|
||||
|
||||
# Load tools should raise the exception (not handled gracefully)
|
||||
with pytest.raises(RuntimeError, match="Connection error"):
|
||||
await tool.load_tools()
|
||||
|
||||
# Verify exception was raised on first call
|
||||
assert mock_session.list_tools.call_count == 1
|
||||
assert len(tool._functions) == 0
|
||||
|
||||
|
||||
async def test_load_prompts_pagination_exception_handling():
|
||||
"""Test that load_prompts handles exceptions during pagination gracefully."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from agent_framework._mcp import MCPTool
|
||||
|
||||
tool = MCPTool(name="test_tool")
|
||||
|
||||
# Mock the session
|
||||
mock_session = AsyncMock()
|
||||
tool.session = mock_session
|
||||
tool.load_prompts_flag = True
|
||||
|
||||
# Mock list_prompts to raise an exception on first call
|
||||
mock_session.list_prompts = AsyncMock(side_effect=RuntimeError("Connection error"))
|
||||
|
||||
# Load prompts should raise the exception (not handled gracefully)
|
||||
with pytest.raises(RuntimeError, match="Connection error"):
|
||||
await tool.load_prompts()
|
||||
|
||||
# Verify exception was raised on first call
|
||||
assert mock_session.list_prompts.call_count == 1
|
||||
assert len(tool._functions) == 0
|
||||
|
||||
|
||||
async def test_load_tools_empty_pagination():
|
||||
"""Test that load_tools handles empty paginated results."""
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from agent_framework._mcp import MCPTool
|
||||
|
||||
tool = MCPTool(name="test_tool")
|
||||
|
||||
# Mock the session
|
||||
mock_session = AsyncMock()
|
||||
tool.session = mock_session
|
||||
tool.load_tools_flag = True
|
||||
|
||||
# Create empty response
|
||||
page1 = MagicMock()
|
||||
page1.tools = []
|
||||
page1.nextCursor = None
|
||||
|
||||
mock_session.list_tools = AsyncMock(return_value=page1)
|
||||
|
||||
# Load tools
|
||||
await tool.load_tools()
|
||||
|
||||
# Verify
|
||||
assert mock_session.list_tools.call_count == 1
|
||||
assert len(tool._functions) == 0
|
||||
|
||||
|
||||
async def test_load_prompts_empty_pagination():
|
||||
"""Test that load_prompts handles empty paginated results."""
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from agent_framework._mcp import MCPTool
|
||||
|
||||
tool = MCPTool(name="test_tool")
|
||||
|
||||
# Mock the session
|
||||
mock_session = AsyncMock()
|
||||
tool.session = mock_session
|
||||
tool.load_prompts_flag = True
|
||||
|
||||
# Create empty response
|
||||
page1 = MagicMock()
|
||||
page1.prompts = []
|
||||
page1.nextCursor = None
|
||||
|
||||
mock_session.list_prompts = AsyncMock(return_value=page1)
|
||||
|
||||
# Load prompts
|
||||
await tool.load_prompts()
|
||||
|
||||
# Verify
|
||||
assert mock_session.list_prompts.call_count == 1
|
||||
assert len(tool._functions) == 0
|
||||
|
||||
|
||||
async def test_mcp_tool_connection_properly_invalidated_after_closed_resource_error():
|
||||
"""Test that verifies reconnection on ClosedResourceError for issue #2884.
|
||||
|
||||
This test verifies the fix for issue #2884: the tool tries operations optimistically
|
||||
and only reconnects when ClosedResourceError is encountered, avoiding extra latency.
|
||||
"""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from anyio.streams.memory import ClosedResourceError
|
||||
|
||||
from agent_framework._mcp import MCPStdioTool
|
||||
from agent_framework.exceptions import ToolExecutionException
|
||||
|
||||
# Create a mock MCP tool
|
||||
tool = MCPStdioTool(
|
||||
name="test_server",
|
||||
command="test_command",
|
||||
args=["arg1"],
|
||||
load_tools=True,
|
||||
)
|
||||
|
||||
# Mock the session
|
||||
mock_session = MagicMock()
|
||||
mock_session._request_id = 1
|
||||
mock_session.call_tool = AsyncMock()
|
||||
|
||||
# Mock _exit_stack.aclose to track cleanup calls
|
||||
original_exit_stack = tool._exit_stack
|
||||
tool._exit_stack.aclose = AsyncMock()
|
||||
|
||||
# Mock connect() to avoid trying to start actual process
|
||||
with patch.object(tool, "connect", new_callable=AsyncMock) as mock_connect:
|
||||
|
||||
async def restore_session(*, reset=False):
|
||||
if reset:
|
||||
await original_exit_stack.aclose()
|
||||
tool.session = mock_session
|
||||
tool.is_connected = True
|
||||
tool._tools_loaded = True
|
||||
|
||||
mock_connect.side_effect = restore_session
|
||||
|
||||
# Simulate initial connection
|
||||
tool.session = mock_session
|
||||
tool.is_connected = True
|
||||
tool._tools_loaded = True
|
||||
|
||||
# First call should work - connection is valid
|
||||
mock_session.call_tool.return_value = MagicMock(content=[])
|
||||
result = await tool.call_tool("test_tool", arg1="value1")
|
||||
assert result is not None
|
||||
|
||||
# Test Case 1: Connection closed unexpectedly, should reconnect and retry
|
||||
# Simulate ClosedResourceError on first call, then succeed
|
||||
call_count = 0
|
||||
|
||||
async def call_tool_with_error(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise ClosedResourceError
|
||||
return MagicMock(content=[])
|
||||
|
||||
mock_session.call_tool = call_tool_with_error
|
||||
|
||||
# This call should trigger reconnection after ClosedResourceError
|
||||
result = await tool.call_tool("test_tool", arg1="value2")
|
||||
assert result is not None
|
||||
# Verify reconnect was attempted with reset=True
|
||||
assert mock_connect.call_count >= 1
|
||||
mock_connect.assert_called_with(reset=True)
|
||||
# Verify _exit_stack.aclose was called during reconnection
|
||||
original_exit_stack.aclose.assert_called()
|
||||
|
||||
# Test Case 2: Reconnection failure
|
||||
# Reset counters
|
||||
call_count = 0
|
||||
mock_connect.reset_mock()
|
||||
original_exit_stack.aclose.reset_mock()
|
||||
|
||||
# Make call_tool always raise ClosedResourceError
|
||||
async def always_fail(*args, **kwargs):
|
||||
raise ClosedResourceError
|
||||
|
||||
mock_session.call_tool = always_fail
|
||||
|
||||
# Change mock_connect to simulate failed reconnection
|
||||
mock_connect.side_effect = Exception("Failed to reconnect")
|
||||
|
||||
# This should raise ToolExecutionException when reconnection fails
|
||||
with pytest.raises(ToolExecutionException) as exc_info:
|
||||
await tool.call_tool("test_tool", arg1="value3")
|
||||
|
||||
# Verify reconnection was attempted
|
||||
assert mock_connect.call_count >= 1
|
||||
# Verify error message indicates reconnection failure
|
||||
assert "failed to reconnect" in str(exc_info.value).lower()
|
||||
|
||||
|
||||
async def test_mcp_tool_get_prompt_reconnection_on_closed_resource_error():
|
||||
"""Test that get_prompt also reconnects on ClosedResourceError.
|
||||
|
||||
This verifies that the fix for issue #2884 applies to get_prompt as well,
|
||||
and that _exit_stack.aclose() is properly called during reconnection.
|
||||
"""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from anyio.streams.memory import ClosedResourceError
|
||||
|
||||
from agent_framework._mcp import MCPStdioTool
|
||||
from agent_framework.exceptions import ToolExecutionException
|
||||
|
||||
# Create a mock MCP tool
|
||||
tool = MCPStdioTool(
|
||||
name="test_server",
|
||||
command="test_command",
|
||||
args=["arg1"],
|
||||
load_prompts=True,
|
||||
)
|
||||
|
||||
# Mock the session
|
||||
mock_session = MagicMock()
|
||||
mock_session._request_id = 1
|
||||
mock_session.get_prompt = AsyncMock()
|
||||
|
||||
# Mock _exit_stack.aclose to track cleanup calls
|
||||
original_exit_stack = tool._exit_stack
|
||||
tool._exit_stack.aclose = AsyncMock()
|
||||
|
||||
# Mock connect() to avoid trying to start actual process
|
||||
with patch.object(tool, "connect", new_callable=AsyncMock) as mock_connect:
|
||||
|
||||
async def restore_session(*, reset=False):
|
||||
if reset:
|
||||
await original_exit_stack.aclose()
|
||||
tool.session = mock_session
|
||||
tool.is_connected = True
|
||||
tool._prompts_loaded = True
|
||||
|
||||
mock_connect.side_effect = restore_session
|
||||
|
||||
# Simulate initial connection
|
||||
tool.session = mock_session
|
||||
tool.is_connected = True
|
||||
tool._prompts_loaded = True
|
||||
|
||||
# First call should work - connection is valid
|
||||
mock_session.get_prompt.return_value = MagicMock(messages=[])
|
||||
result = await tool.get_prompt("test_prompt", arg1="value1")
|
||||
assert result is not None
|
||||
|
||||
# Test Case 1: Connection closed unexpectedly, should reconnect and retry
|
||||
# Simulate ClosedResourceError on first call, then succeed
|
||||
call_count = 0
|
||||
|
||||
async def get_prompt_with_error(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise ClosedResourceError
|
||||
return MagicMock(messages=[])
|
||||
|
||||
mock_session.get_prompt = get_prompt_with_error
|
||||
|
||||
# This call should trigger reconnection after ClosedResourceError
|
||||
result = await tool.get_prompt("test_prompt", arg1="value2")
|
||||
assert result is not None
|
||||
# Verify reconnect was attempted with reset=True
|
||||
assert mock_connect.call_count >= 1
|
||||
mock_connect.assert_called_with(reset=True)
|
||||
# Verify _exit_stack.aclose was called during reconnection
|
||||
original_exit_stack.aclose.assert_called()
|
||||
|
||||
# Test Case 2: Reconnection failure
|
||||
# Reset counters
|
||||
call_count = 0
|
||||
mock_connect.reset_mock()
|
||||
original_exit_stack.aclose.reset_mock()
|
||||
|
||||
# Make get_prompt always raise ClosedResourceError
|
||||
async def always_fail(*args, **kwargs):
|
||||
raise ClosedResourceError
|
||||
|
||||
mock_session.get_prompt = always_fail
|
||||
|
||||
# Change mock_connect to simulate failed reconnection
|
||||
mock_connect.side_effect = Exception("Failed to reconnect")
|
||||
|
||||
# This should raise ToolExecutionException when reconnection fails
|
||||
with pytest.raises(ToolExecutionException) as exc_info:
|
||||
await tool.get_prompt("test_prompt", arg1="value3")
|
||||
|
||||
# Verify reconnection was attempted
|
||||
assert mock_connect.call_count >= 1
|
||||
# Verify error message indicates reconnection failure
|
||||
assert "failed to reconnect" in str(exc_info.value).lower()
|
||||
|
||||
@@ -5,7 +5,7 @@ from unittest.mock import Mock
|
||||
import pytest
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from agent_framework import (
|
||||
AIFunction,
|
||||
@@ -15,7 +15,11 @@ from agent_framework import (
|
||||
ToolProtocol,
|
||||
ai_function,
|
||||
)
|
||||
from agent_framework._tools import _parse_annotation, _parse_inputs
|
||||
from agent_framework._tools import (
|
||||
_build_pydantic_model_from_json_schema,
|
||||
_parse_annotation,
|
||||
_parse_inputs,
|
||||
)
|
||||
from agent_framework.exceptions import ToolException
|
||||
from agent_framework.observability import OtelAttr
|
||||
|
||||
@@ -1548,4 +1552,467 @@ def test_parse_annotation_with_annotated_and_literal():
|
||||
assert get_args(literal_type) == ("A", "B", "C")
|
||||
|
||||
|
||||
def test_build_pydantic_model_from_json_schema_array_of_objects_issue():
|
||||
"""Test for Tools with complex input schema (array of objects).
|
||||
|
||||
This test verifies that JSON schemas with array properties containing nested objects
|
||||
are properly parsed, ensuring that the nested object schema is preserved
|
||||
and not reduced to a bare dict.
|
||||
|
||||
Example from issue:
|
||||
```
|
||||
const SalesOrderItemSchema = z.object({
|
||||
customerMaterialNumber: z.string().optional(),
|
||||
quantity: z.number(),
|
||||
unitOfMeasure: z.string()
|
||||
});
|
||||
|
||||
const CreateSalesOrderInputSchema = z.object({
|
||||
contract: z.string(),
|
||||
items: z.array(SalesOrderItemSchema)
|
||||
});
|
||||
```
|
||||
|
||||
The issue was that agents only saw:
|
||||
```
|
||||
{"contract": "str", "items": "list[dict]"}
|
||||
```
|
||||
|
||||
Instead of the proper nested schema with all fields.
|
||||
"""
|
||||
# Schema matching the issue description
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"contract": {"type": "string", "description": "Reference contract number"},
|
||||
"items": {
|
||||
"type": "array",
|
||||
"description": "Sales order line items",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"customerMaterialNumber": {
|
||||
"type": "string",
|
||||
"description": "Customer's material number",
|
||||
},
|
||||
"quantity": {"type": "number", "description": "Order quantity"},
|
||||
"unitOfMeasure": {
|
||||
"type": "string",
|
||||
"description": "Unit of measure (e.g., 'ST', 'KG', 'TO')",
|
||||
},
|
||||
},
|
||||
"required": ["quantity", "unitOfMeasure"],
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": ["contract", "items"],
|
||||
}
|
||||
|
||||
model = _build_pydantic_model_from_json_schema("create_sales_order", schema)
|
||||
|
||||
# Test valid data
|
||||
valid_data = {
|
||||
"contract": "CONTRACT-123",
|
||||
"items": [
|
||||
{
|
||||
"customerMaterialNumber": "MAT-001",
|
||||
"quantity": 10,
|
||||
"unitOfMeasure": "ST",
|
||||
},
|
||||
{"quantity": 5.5, "unitOfMeasure": "KG"},
|
||||
],
|
||||
}
|
||||
|
||||
instance = model(**valid_data)
|
||||
|
||||
# Verify the data was parsed correctly
|
||||
assert instance.contract == "CONTRACT-123"
|
||||
assert len(instance.items) == 2
|
||||
|
||||
# Verify first item
|
||||
assert instance.items[0].customerMaterialNumber == "MAT-001"
|
||||
assert instance.items[0].quantity == 10
|
||||
assert instance.items[0].unitOfMeasure == "ST"
|
||||
|
||||
# Verify second item (optional field not provided)
|
||||
assert instance.items[1].quantity == 5.5
|
||||
assert instance.items[1].unitOfMeasure == "KG"
|
||||
|
||||
# Verify that items are proper BaseModel instances, not bare dicts
|
||||
assert isinstance(instance.items[0], BaseModel)
|
||||
assert isinstance(instance.items[1], BaseModel)
|
||||
|
||||
# Verify that the nested object has the expected fields
|
||||
assert hasattr(instance.items[0], "customerMaterialNumber")
|
||||
assert hasattr(instance.items[0], "quantity")
|
||||
assert hasattr(instance.items[0], "unitOfMeasure")
|
||||
|
||||
# CRITICAL: Validate using the same methods that actual chat clients use
|
||||
# This is what would actually be sent to the LLM
|
||||
|
||||
# Create an AIFunction wrapper to access the client-facing APIs
|
||||
def dummy_func(**kwargs):
|
||||
return kwargs
|
||||
|
||||
test_func = AIFunction(
|
||||
func=dummy_func,
|
||||
name="create_sales_order",
|
||||
description="Create a sales order",
|
||||
input_model=model,
|
||||
)
|
||||
|
||||
# Test 1: Anthropic client uses tool.parameters() directly
|
||||
anthropic_schema = test_func.parameters()
|
||||
|
||||
# Verify contract property
|
||||
assert "contract" in anthropic_schema["properties"]
|
||||
assert anthropic_schema["properties"]["contract"]["type"] == "string"
|
||||
|
||||
# Verify items array property exists
|
||||
assert "items" in anthropic_schema["properties"]
|
||||
items_prop = anthropic_schema["properties"]["items"]
|
||||
assert items_prop["type"] == "array"
|
||||
|
||||
# THE KEY TEST for Anthropic: array items must have proper object schema
|
||||
assert "items" in items_prop, "Array should have 'items' schema definition"
|
||||
array_items_schema = items_prop["items"]
|
||||
|
||||
# Resolve schema if using $ref
|
||||
if "$ref" in array_items_schema:
|
||||
ref_path = array_items_schema["$ref"]
|
||||
assert ref_path.startswith("#/$defs/") or ref_path.startswith("#/definitions/")
|
||||
ref_name = ref_path.split("/")[-1]
|
||||
defs = anthropic_schema.get("$defs", anthropic_schema.get("definitions", {}))
|
||||
assert ref_name in defs, f"Referenced schema '{ref_name}' should exist"
|
||||
item_schema = defs[ref_name]
|
||||
else:
|
||||
item_schema = array_items_schema
|
||||
|
||||
# Verify the nested object has all properties defined
|
||||
assert "properties" in item_schema, "Array items should have properties (not bare dict)"
|
||||
item_properties = item_schema["properties"]
|
||||
|
||||
# All three fields must be present in schema sent to LLM
|
||||
assert "customerMaterialNumber" in item_properties, "customerMaterialNumber missing from LLM schema"
|
||||
assert "quantity" in item_properties, "quantity missing from LLM schema"
|
||||
assert "unitOfMeasure" in item_properties, "unitOfMeasure missing from LLM schema"
|
||||
|
||||
# Verify types are correct
|
||||
assert item_properties["customerMaterialNumber"]["type"] == "string"
|
||||
assert item_properties["quantity"]["type"] in ["number", "integer"]
|
||||
assert item_properties["unitOfMeasure"]["type"] == "string"
|
||||
|
||||
# Test 2: OpenAI client uses tool.to_json_schema_spec()
|
||||
openai_spec = test_func.to_json_schema_spec()
|
||||
|
||||
assert openai_spec["type"] == "function"
|
||||
assert "function" in openai_spec
|
||||
openai_schema = openai_spec["function"]["parameters"]
|
||||
|
||||
# Verify the same structure is present in OpenAI format
|
||||
assert "items" in openai_schema["properties"]
|
||||
openai_items_prop = openai_schema["properties"]["items"]
|
||||
assert openai_items_prop["type"] == "array"
|
||||
assert "items" in openai_items_prop
|
||||
|
||||
openai_array_items = openai_items_prop["items"]
|
||||
if "$ref" in openai_array_items:
|
||||
ref_path = openai_array_items["$ref"]
|
||||
ref_name = ref_path.split("/")[-1]
|
||||
defs = openai_schema.get("$defs", openai_schema.get("definitions", {}))
|
||||
openai_item_schema = defs[ref_name]
|
||||
else:
|
||||
openai_item_schema = openai_array_items
|
||||
|
||||
assert "properties" in openai_item_schema
|
||||
openai_props = openai_item_schema["properties"]
|
||||
assert "customerMaterialNumber" in openai_props
|
||||
assert "quantity" in openai_props
|
||||
assert "unitOfMeasure" in openai_props
|
||||
|
||||
# Test validation - missing required quantity
|
||||
with pytest.raises(ValidationError):
|
||||
model(
|
||||
contract="CONTRACT-456",
|
||||
items=[
|
||||
{
|
||||
"customerMaterialNumber": "MAT-002",
|
||||
"unitOfMeasure": "TO",
|
||||
# Missing required 'quantity'
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
# Test validation - missing required unitOfMeasure
|
||||
with pytest.raises(ValidationError):
|
||||
model(
|
||||
contract="CONTRACT-789",
|
||||
items=[
|
||||
{
|
||||
"quantity": 20
|
||||
# Missing required 'unitOfMeasure'
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_one_of_discriminator_polymorphism():
|
||||
"""Test that oneOf with discriminator creates proper polymorphic union types.
|
||||
|
||||
Tests that oneOf + discriminator patterns are properly converted to Pydantic discriminated unions.
|
||||
"""
|
||||
schema = {
|
||||
"$defs": {
|
||||
"CreateProject": {
|
||||
"description": "Action: Create an Azure DevOps project.",
|
||||
"properties": {
|
||||
"name": {
|
||||
"const": "create_project",
|
||||
"default": "create_project",
|
||||
"type": "string",
|
||||
},
|
||||
"params": {"$ref": "#/$defs/CreateProjectParams"},
|
||||
},
|
||||
"required": ["params"],
|
||||
"type": "object",
|
||||
},
|
||||
"CreateProjectParams": {
|
||||
"description": "Parameters for the create_project action.",
|
||||
"properties": {
|
||||
"orgUrl": {"minLength": 1, "type": "string"},
|
||||
"projectName": {"minLength": 1, "type": "string"},
|
||||
"description": {"default": "", "type": "string"},
|
||||
"template": {"default": "Agile", "type": "string"},
|
||||
"sourceControl": {
|
||||
"default": "Git",
|
||||
"enum": ["Git", "Tfvc"],
|
||||
"type": "string",
|
||||
},
|
||||
"visibility": {"default": "private", "type": "string"},
|
||||
},
|
||||
"required": ["orgUrl", "projectName"],
|
||||
"type": "object",
|
||||
},
|
||||
"DeployRequest": {
|
||||
"description": "Request to deploy Azure DevOps resources.",
|
||||
"properties": {
|
||||
"projectName": {"minLength": 1, "type": "string"},
|
||||
"organization": {"minLength": 1, "type": "string"},
|
||||
"actions": {
|
||||
"items": {
|
||||
"discriminator": {
|
||||
"mapping": {
|
||||
"create_project": "#/$defs/CreateProject",
|
||||
"hello_world": "#/$defs/HelloWorld",
|
||||
},
|
||||
"propertyName": "name",
|
||||
},
|
||||
"oneOf": [
|
||||
{"$ref": "#/$defs/HelloWorld"},
|
||||
{"$ref": "#/$defs/CreateProject"},
|
||||
],
|
||||
},
|
||||
"type": "array",
|
||||
},
|
||||
},
|
||||
"required": ["projectName", "organization"],
|
||||
"type": "object",
|
||||
},
|
||||
"HelloWorld": {
|
||||
"description": "Action: Prints a greeting message.",
|
||||
"properties": {
|
||||
"name": {
|
||||
"const": "hello_world",
|
||||
"default": "hello_world",
|
||||
"type": "string",
|
||||
},
|
||||
"params": {"$ref": "#/$defs/HelloWorldParams"},
|
||||
},
|
||||
"required": ["params"],
|
||||
"type": "object",
|
||||
},
|
||||
"HelloWorldParams": {
|
||||
"description": "Parameters for the hello_world action.",
|
||||
"properties": {
|
||||
"name": {
|
||||
"description": "Name to greet",
|
||||
"minLength": 1,
|
||||
"type": "string",
|
||||
}
|
||||
},
|
||||
"required": ["name"],
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
"properties": {"params": {"$ref": "#/$defs/DeployRequest"}},
|
||||
"required": ["params"],
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
# Build the model
|
||||
model = _build_pydantic_model_from_json_schema("deploy_tool", schema)
|
||||
|
||||
# Verify the model structure
|
||||
assert model is not None
|
||||
assert issubclass(model, BaseModel)
|
||||
|
||||
# Test with HelloWorld action
|
||||
hello_world_data = {
|
||||
"params": {
|
||||
"projectName": "MyProject",
|
||||
"organization": "MyOrg",
|
||||
"actions": [
|
||||
{
|
||||
"name": "hello_world",
|
||||
"params": {"name": "Alice"},
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
instance = model(**hello_world_data)
|
||||
assert instance.params.projectName == "MyProject"
|
||||
assert instance.params.organization == "MyOrg"
|
||||
assert len(instance.params.actions) == 1
|
||||
assert instance.params.actions[0].name == "hello_world"
|
||||
assert instance.params.actions[0].params.name == "Alice"
|
||||
|
||||
# Test with CreateProject action
|
||||
create_project_data = {
|
||||
"params": {
|
||||
"projectName": "MyProject",
|
||||
"organization": "MyOrg",
|
||||
"actions": [
|
||||
{
|
||||
"name": "create_project",
|
||||
"params": {
|
||||
"orgUrl": "https://dev.azure.com/myorg",
|
||||
"projectName": "NewProject",
|
||||
"sourceControl": "Git",
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
instance2 = model(**create_project_data)
|
||||
assert instance2.params.actions[0].name == "create_project"
|
||||
assert instance2.params.actions[0].params.projectName == "NewProject"
|
||||
assert instance2.params.actions[0].params.sourceControl == "Git"
|
||||
|
||||
# Test with mixed actions
|
||||
mixed_data = {
|
||||
"params": {
|
||||
"projectName": "MyProject",
|
||||
"organization": "MyOrg",
|
||||
"actions": [
|
||||
{"name": "hello_world", "params": {"name": "Bob"}},
|
||||
{
|
||||
"name": "create_project",
|
||||
"params": {
|
||||
"orgUrl": "https://dev.azure.com/myorg",
|
||||
"projectName": "AnotherProject",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
instance3 = model(**mixed_data)
|
||||
assert len(instance3.params.actions) == 2
|
||||
assert instance3.params.actions[0].name == "hello_world"
|
||||
assert instance3.params.actions[1].name == "create_project"
|
||||
|
||||
|
||||
def test_const_creates_literal():
|
||||
"""Test that const in JSON Schema creates Literal type."""
|
||||
schema = {
|
||||
"properties": {
|
||||
"action": {
|
||||
"const": "create",
|
||||
"type": "string",
|
||||
"description": "Action type",
|
||||
},
|
||||
"value": {"type": "integer"},
|
||||
},
|
||||
"required": ["action", "value"],
|
||||
}
|
||||
|
||||
model = _build_pydantic_model_from_json_schema("test_const", schema)
|
||||
|
||||
# Verify valid const value works
|
||||
instance = model(action="create", value=42)
|
||||
assert instance.action == "create"
|
||||
assert instance.value == 42
|
||||
|
||||
# Verify incorrect const value fails
|
||||
with pytest.raises(ValidationError):
|
||||
model(action="delete", value=42)
|
||||
|
||||
|
||||
def test_enum_creates_literal():
|
||||
"""Test that enum in JSON Schema creates Literal type."""
|
||||
schema = {
|
||||
"properties": {
|
||||
"status": {
|
||||
"enum": ["pending", "approved", "rejected"],
|
||||
"type": "string",
|
||||
"description": "Status",
|
||||
},
|
||||
"priority": {"enum": [1, 2, 3], "type": "integer"},
|
||||
},
|
||||
"required": ["status"],
|
||||
}
|
||||
|
||||
model = _build_pydantic_model_from_json_schema("test_enum", schema)
|
||||
|
||||
# Verify valid enum values work
|
||||
instance = model(status="approved", priority=2)
|
||||
assert instance.status == "approved"
|
||||
assert instance.priority == 2
|
||||
|
||||
# Verify invalid enum value fails
|
||||
with pytest.raises(ValidationError):
|
||||
model(status="unknown")
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
model(status="pending", priority=5)
|
||||
|
||||
|
||||
def test_nested_object_with_const_and_enum():
|
||||
"""Test that const and enum work in nested objects."""
|
||||
schema = {
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"const": "production",
|
||||
"default": "production",
|
||||
"type": "string",
|
||||
},
|
||||
"level": {"enum": ["low", "medium", "high"], "type": "string"},
|
||||
},
|
||||
"required": ["level"],
|
||||
}
|
||||
},
|
||||
"required": ["config"],
|
||||
}
|
||||
|
||||
model = _build_pydantic_model_from_json_schema("test_nested", schema)
|
||||
|
||||
# Valid data
|
||||
instance = model(config={"type": "production", "level": "high"})
|
||||
assert instance.config.type == "production"
|
||||
assert instance.config.level == "high"
|
||||
|
||||
# Invalid const in nested object
|
||||
with pytest.raises(ValidationError):
|
||||
model(config={"type": "development", "level": "low"})
|
||||
|
||||
# Invalid enum in nested object
|
||||
with pytest.raises(ValidationError):
|
||||
model(config={"type": "production", "level": "critical"})
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
Reference in New Issue
Block a user