Python: MCP Improvements: improved connection loss behavior, pagination for loading and a param to control representation (#3154)

* pagination support (#2848) added a parse_tool_result param and connection loss (#2884)

* fix #3153

* improved connection handling

* improved logic
This commit is contained in:
Eduard van Valkenburg
2026-01-13 05:09:33 +01:00
committed by GitHub
Unverified
parent 203fb7b1c4
commit b2893fbc00
4 changed files with 1381 additions and 93 deletions
+210 -81
View File
@@ -4,13 +4,14 @@ import logging
import re
import sys
from abc import abstractmethod
from collections.abc import Collection, Sequence
from collections.abc import Callable, Collection, Sequence
from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore
from datetime import timedelta
from functools import partial
from typing import TYPE_CHECKING, Any, Literal
import httpx
from anyio import ClosedResourceError
from mcp import types
from mcp.client.session import ClientSession
from mcp.client.stdio import StdioServerParameters, stdio_client
@@ -21,7 +22,11 @@ from mcp.shared.exceptions import McpError
from mcp.shared.session import RequestResponder
from pydantic import BaseModel, create_model
from ._tools import AIFunction, HostedMCPSpecificApproval, _build_pydantic_model_from_json_schema
from ._tools import (
AIFunction,
HostedMCPSpecificApproval,
_build_pydantic_model_from_json_schema,
)
from ._types import (
ChatMessage,
Contents,
@@ -329,7 +334,9 @@ class MCPTool:
approval_mode: (Literal["always_require", "never_require"] | HostedMCPSpecificApproval | None) = None,
allowed_tools: Collection[str] | None = None,
load_tools: bool = True,
parse_tool_results: Literal[True] | Callable[[types.CallToolResult], Any] | None = True,
load_prompts: bool = True,
parse_prompt_results: Literal[True] | Callable[[types.GetPromptResult], Any] | None = True,
session: ClientSession | None = None,
request_timeout: int | None = None,
chat_client: "ChatClientProtocol | None" = None,
@@ -347,7 +354,9 @@ class MCPTool:
self.allowed_tools = allowed_tools
self.additional_properties = additional_properties
self.load_tools_flag = load_tools
self.parse_tool_results = parse_tool_results
self.load_prompts_flag = load_prompts
self.parse_prompt_results = parse_prompt_results
self._exit_stack = AsyncExitStack()
self.session = session
self.request_timeout = request_timeout
@@ -367,15 +376,23 @@ class MCPTool:
return self._functions
return [func for func in self._functions if func.name in self.allowed_tools]
async def connect(self) -> None:
async def connect(self, *, reset: bool = False) -> None:
"""Connect to the MCP server.
Establishes a connection to the MCP server, initializes the session,
and loads tools and prompts if configured to do so.
Keyword Args:
reset: If True, forces a reconnection even if already connected.
Raises:
ToolException: If connection or session initialization fails.
"""
if reset:
await self._exit_stack.aclose()
self.session = None
self.is_connected = False
self._exit_stack = AsyncExitStack()
if not self.session:
try:
transport = await self._exit_stack.enter_async_context(self.get_mcp_client())
@@ -565,86 +582,88 @@ class MCPTool:
"""Load prompts from the MCP server.
Retrieves available prompts from the connected MCP server and converts
them into AIFunction instances.
them into AIFunction instances. Handles pagination automatically.
Raises:
ToolExecutionException: If the MCP server is not connected.
"""
if not self.session:
raise ToolExecutionException("MCP server not connected, please call connect() before using this method.")
try:
prompt_list = await self.session.list_prompts()
except Exception as exc:
logger.info(
"Prompt could not be loaded, you can exclude trying to load, by setting: load_prompts=False",
exc_info=exc,
)
prompt_list = None
# Track existing function names to prevent duplicates
existing_names = {func.name for func in self._functions}
for prompt in prompt_list.prompts if prompt_list else []:
local_name = _normalize_mcp_name(prompt.name)
params: types.PaginatedRequestParams | None = None
while True:
# Ensure connection is still valid before each page request
await self._ensure_connected()
# Skip if already loaded
if local_name in existing_names:
continue
prompt_list = await self.session.list_prompts(params=params) # type: ignore[union-attr]
input_model = _get_input_model_from_mcp_prompt(prompt)
approval_mode = self._determine_approval_mode(local_name)
func: AIFunction[BaseModel, list[ChatMessage]] = AIFunction(
func=partial(self.get_prompt, prompt.name),
name=local_name,
description=prompt.description or "",
approval_mode=approval_mode,
input_model=input_model,
)
self._functions.append(func)
existing_names.add(local_name)
for prompt in prompt_list.prompts:
local_name = _normalize_mcp_name(prompt.name)
# Skip if already loaded
if local_name in existing_names:
continue
input_model = _get_input_model_from_mcp_prompt(prompt)
approval_mode = self._determine_approval_mode(local_name)
func: AIFunction[BaseModel, list[ChatMessage] | Any | types.GetPromptResult] = AIFunction(
func=partial(self.get_prompt, prompt.name),
name=local_name,
description=prompt.description or "",
approval_mode=approval_mode,
input_model=input_model,
)
self._functions.append(func)
existing_names.add(local_name)
# Check if there are more pages
if not prompt_list or not prompt_list.nextCursor:
break
params = types.PaginatedRequestParams(cursor=prompt_list.nextCursor)
async def load_tools(self) -> None:
"""Load tools from the MCP server.
Retrieves available tools from the connected MCP server and converts
them into AIFunction instances.
them into AIFunction instances. Handles pagination automatically.
Raises:
ToolExecutionException: If the MCP server is not connected.
"""
if not self.session:
raise ToolExecutionException("MCP server not connected, please call connect() before using this method.")
try:
tool_list = await self.session.list_tools()
except Exception as exc:
logger.info(
"Tools could not be loaded, you can exclude trying to load, by setting: load_tools=False",
exc_info=exc,
)
tool_list = None
# Track existing function names to prevent duplicates
existing_names = {func.name for func in self._functions}
for tool in tool_list.tools if tool_list else []:
local_name = _normalize_mcp_name(tool.name)
params: types.PaginatedRequestParams | None = None
while True:
# Ensure connection is still valid before each page request
await self._ensure_connected()
# Skip if already loaded
if local_name in existing_names:
continue
tool_list = await self.session.list_tools(params=params) # type: ignore[union-attr]
input_model = _get_input_model_from_mcp_tool(tool)
approval_mode = self._determine_approval_mode(local_name)
# Create AIFunctions out of each tool
func: AIFunction[BaseModel, list[Contents]] = AIFunction(
func=partial(self.call_tool, tool.name),
name=local_name,
description=tool.description or "",
approval_mode=approval_mode,
input_model=input_model,
)
self._functions.append(func)
existing_names.add(local_name)
for tool in tool_list.tools:
local_name = _normalize_mcp_name(tool.name)
# Skip if already loaded
if local_name in existing_names:
continue
input_model = _get_input_model_from_mcp_tool(tool)
approval_mode = self._determine_approval_mode(local_name)
# Create AIFunctions out of each tool
func: AIFunction[BaseModel, list[Contents] | Any | types.CallToolResult] = AIFunction(
func=partial(self.call_tool, tool.name),
name=local_name,
description=tool.description or "",
approval_mode=approval_mode,
input_model=input_model,
)
self._functions.append(func)
existing_names.add(local_name)
# Check if there are more pages
if not tool_list or not tool_list.nextCursor:
break
params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)
async def close(self) -> None:
"""Disconnect from the MCP server.
@@ -664,7 +683,28 @@ class MCPTool:
"""
pass
async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Contents]:
async def _ensure_connected(self) -> None:
"""Ensure the connection is valid, reconnecting if necessary.
This method proactively checks if the connection is valid and
reconnects if it's not, avoiding the need to catch ClosedResourceError.
Raises:
ToolExecutionException: If reconnection fails.
"""
try:
await self.session.send_ping() # type: ignore[union-attr]
except Exception:
logger.info("MCP connection invalid or closed. Reconnecting...")
try:
await self.connect(reset=True)
except Exception as ex:
raise ToolExecutionException(
"Failed to establish MCP connection.",
inner_exception=ex,
) from ex
async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Contents] | Any | types.CallToolResult:
"""Call a tool with the given arguments.
Args:
@@ -680,8 +720,6 @@ class MCPTool:
ToolExecutionException: If the MCP server is not connected, tools are not loaded,
or the tool call fails.
"""
if not self.session:
raise ToolExecutionException("MCP server not connected, please call connect() before using this method.")
if not self.load_tools_flag:
raise ToolExecutionException(
"Tools are not loaded for this server, please set load_tools=True in the constructor."
@@ -692,16 +730,44 @@ class MCPTool:
filtered_kwargs = {
k: v for k, v in kwargs.items() if k not in {"chat_options", "tools", "tool_choice", "thread"}
}
try:
return _parse_contents_from_mcp_tool_result(
await self.session.call_tool(tool_name, arguments=filtered_kwargs)
)
except McpError as mcp_exc:
raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc
except Exception as ex:
raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex
async def get_prompt(self, prompt_name: str, **kwargs: Any) -> list[ChatMessage]:
# Try the operation, reconnecting once if the connection is closed
for attempt in range(2):
try:
result = await self.session.call_tool(tool_name, arguments=filtered_kwargs) # type: ignore
if self.parse_tool_results is None:
return result
if self.parse_tool_results is True:
return _parse_contents_from_mcp_tool_result(result)
if callable(self.parse_tool_results):
return self.parse_tool_results(result)
return result
except ClosedResourceError as cl_ex:
if attempt == 0:
# First attempt failed, try reconnecting
logger.info("MCP connection closed unexpectedly. Reconnecting...")
try:
await self.connect(reset=True)
continue # Retry the operation
except Exception as reconn_ex:
raise ToolExecutionException(
"Failed to reconnect to MCP server.",
inner_exception=reconn_ex,
) from reconn_ex
else:
# Second attempt also failed, give up
logger.error(f"MCP connection closed unexpectedly after reconnection: {cl_ex}")
raise ToolExecutionException(
f"Failed to call tool '{tool_name}' - connection lost.",
inner_exception=cl_ex,
) from cl_ex
except McpError as mcp_exc:
raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc
except Exception as ex:
raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex
raise ToolExecutionException(f"Failed to call tool '{tool_name}' after retries.")
async def get_prompt(self, prompt_name: str, **kwargs: Any) -> list[ChatMessage] | Any | types.GetPromptResult:
"""Call a prompt with the given arguments.
Args:
@@ -717,19 +783,46 @@ class MCPTool:
ToolExecutionException: If the MCP server is not connected, prompts are not loaded,
or the prompt call fails.
"""
if not self.session:
raise ToolExecutionException("MCP server not connected, please call connect() before using this method.")
if not self.load_prompts_flag:
raise ToolExecutionException(
"Prompts are not loaded for this server, please set load_prompts=True in the constructor."
)
try:
prompt_result = await self.session.get_prompt(prompt_name, arguments=kwargs)
return [_parse_message_from_mcp(message) for message in prompt_result.messages]
except McpError as mcp_exc:
raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc
except Exception as ex:
raise ToolExecutionException(f"Failed to call prompt '{prompt_name}'.", inner_exception=ex) from ex
# Try the operation, reconnecting once if the connection is closed
for attempt in range(2):
try:
prompt_result = await self.session.get_prompt(prompt_name, arguments=kwargs) # type: ignore
if self.parse_prompt_results is None:
return prompt_result
if self.parse_prompt_results is True:
return [_parse_message_from_mcp(message) for message in prompt_result.messages]
if callable(self.parse_prompt_results):
return self.parse_prompt_results(prompt_result)
return prompt_result
except ClosedResourceError as cl_ex:
if attempt == 0:
# First attempt failed, try reconnecting
logger.info("MCP connection closed unexpectedly. Reconnecting...")
try:
await self.connect(reset=True)
continue # Retry the operation
except Exception as reconn_ex:
raise ToolExecutionException(
"Failed to reconnect to MCP server.",
inner_exception=reconn_ex,
) from reconn_ex
else:
# Second attempt also failed, give up
logger.error(f"MCP connection closed unexpectedly after reconnection: {cl_ex}")
raise ToolExecutionException(
f"Failed to call prompt '{prompt_name}' - connection lost.",
inner_exception=cl_ex,
) from cl_ex
except McpError as mcp_exc:
raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc
except Exception as ex:
raise ToolExecutionException(f"Failed to call prompt '{prompt_name}'.", inner_exception=ex) from ex
raise ToolExecutionException(f"Failed to get prompt '{prompt_name}' after retries.")
async def __aenter__(self) -> Self:
"""Enter the async context manager.
@@ -804,7 +897,9 @@ class MCPStdioTool(MCPTool):
command: str,
*,
load_tools: bool = True,
parse_tool_results: Literal[True] | Callable[[types.CallToolResult], Any] | None = True,
load_prompts: bool = True,
parse_prompt_results: Literal[True] | Callable[[types.GetPromptResult], Any] | None = True,
request_timeout: int | None = None,
session: ClientSession | None = None,
description: str | None = None,
@@ -830,7 +925,15 @@ class MCPStdioTool(MCPTool):
Keyword Args:
load_tools: Whether to load tools from the MCP server.
parse_tool_results: How to parse tool results from the MCP server.
Set to True, to use the default parser that converts to Agent Framework types.
Set to a callable to use a custom parser function.
Set to None to return the raw MCP tool result.
load_prompts: Whether to load prompts from the MCP server.
parse_prompt_results: How to parse prompt results from the MCP server.
Set to True, to use the default parser that converts to Agent Framework types.
Set to a callable to use a custom parser function.
Set to None to return the raw MCP prompt result.
request_timeout: The default timeout in seconds for all requests.
session: The session to use for the MCP connection.
description: The description of the tool.
@@ -857,7 +960,9 @@ class MCPStdioTool(MCPTool):
session=session,
chat_client=chat_client,
load_tools=load_tools,
parse_tool_results=parse_tool_results,
load_prompts=load_prompts,
parse_prompt_results=parse_prompt_results,
request_timeout=request_timeout,
)
self.command = command
@@ -913,7 +1018,9 @@ class MCPStreamableHTTPTool(MCPTool):
url: str,
*,
load_tools: bool = True,
parse_tool_results: Literal[True] | Callable[[types.CallToolResult], Any] | None = True,
load_prompts: bool = True,
parse_prompt_results: Literal[True] | Callable[[types.GetPromptResult], Any] | None = True,
request_timeout: int | None = None,
session: ClientSession | None = None,
description: str | None = None,
@@ -939,7 +1046,15 @@ class MCPStreamableHTTPTool(MCPTool):
Keyword Args:
load_tools: Whether to load tools from the MCP server.
parse_tool_results: How to parse tool results from the MCP server.
Set to True, to use the default parser that converts to Agent Framework types.
Set to a callable to use a custom parser function.
Set to None to return the raw MCP tool result.
load_prompts: Whether to load prompts from the MCP server.
parse_prompt_results: How to parse prompt results from the MCP server.
Set to True, to use the default parser that converts to Agent Framework types.
Set to a callable to use a custom parser function.
Set to None to return the raw MCP prompt result.
request_timeout: The default timeout in seconds for all requests.
session: The session to use for the MCP connection.
description: The description of the tool.
@@ -968,7 +1083,9 @@ class MCPStreamableHTTPTool(MCPTool):
session=session,
chat_client=chat_client,
load_tools=load_tools,
parse_tool_results=parse_tool_results,
load_prompts=load_prompts,
parse_prompt_results=parse_prompt_results,
request_timeout=request_timeout,
)
self.url = url
@@ -1016,7 +1133,9 @@ class MCPWebsocketTool(MCPTool):
url: str,
*,
load_tools: bool = True,
parse_tool_results: Literal[True] | Callable[[types.CallToolResult], Any] | None = True,
load_prompts: bool = True,
parse_prompt_results: Literal[True] | Callable[[types.GetPromptResult], Any] | None = True,
request_timeout: int | None = None,
session: ClientSession | None = None,
description: str | None = None,
@@ -1040,7 +1159,15 @@ class MCPWebsocketTool(MCPTool):
Keyword Args:
load_tools: Whether to load tools from the MCP server.
parse_tool_results: How to parse tool results from the MCP server.
Set to True, to use the default parser that converts to Agent Framework types.
Set to a callable to use a custom parser function.
Set to None to return the raw MCP tool result.
load_prompts: Whether to load prompts from the MCP server.
parse_prompt_results: How to parse prompt results from the MCP server.
Set to True, to use the default parser that converts to Agent Framework types.
Set to a callable to use a custom parser function.
Set to None to return the raw MCP prompt result.
request_timeout: The default timeout in seconds for all requests.
session: The session to use for the MCP connection.
description: The description of the tool.
@@ -1064,7 +1191,9 @@ class MCPWebsocketTool(MCPTool):
session=session,
chat_client=chat_client,
load_tools=load_tools,
parse_tool_results=parse_tool_results,
load_prompts=load_prompts,
parse_prompt_results=parse_prompt_results,
request_timeout=request_timeout,
)
self.url = url
+78 -6
View File
@@ -4,7 +4,15 @@ import asyncio
import inspect
import json
import sys
from collections.abc import AsyncIterable, Awaitable, Callable, Collection, Mapping, MutableMapping, Sequence
from collections.abc import (
AsyncIterable,
Awaitable,
Callable,
Collection,
Mapping,
MutableMapping,
Sequence,
)
from functools import wraps
from time import perf_counter, time_ns
from typing import (
@@ -18,6 +26,7 @@ from typing import (
Protocol,
TypedDict,
TypeVar,
Union,
cast,
get_args,
get_origin,
@@ -121,7 +130,13 @@ def _parse_inputs(
if inputs is None:
return []
from ._types import BaseContent, DataContent, HostedFileContent, HostedVectorStoreContent, UriContent
from ._types import (
BaseContent,
DataContent,
HostedFileContent,
HostedVectorStoreContent,
UriContent,
)
parsed_inputs: list["Contents"] = []
if not isinstance(inputs, list):
@@ -1010,6 +1025,27 @@ def _build_pydantic_model_from_json_schema(
if not properties:
return create_model(f"{model_name}_input")
def _resolve_literal_type(prop_details: dict[str, Any]) -> type | None:
"""Check if property should be a Literal type (const or enum).
Args:
prop_details: The JSON Schema property details
Returns:
Literal type if const or enum is present, None otherwise
"""
# const → Literal["value"]
if "const" in prop_details:
return Literal[prop_details["const"]] # type: ignore
# enum → Literal["a", "b", ...]
if "enum" in prop_details and isinstance(prop_details["enum"], list):
enum_values = prop_details["enum"]
if enum_values:
return Literal[tuple(enum_values)] # type: ignore
return None
def _resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type:
"""Resolve JSON Schema type to Python type, handling $ref, nested objects, and typed arrays.
@@ -1020,6 +1056,31 @@ def _build_pydantic_model_from_json_schema(
Returns:
Python type annotation (could be int, str, list[str], or a nested Pydantic model)
"""
# Handle oneOf + discriminator (polymorphic objects)
if "oneOf" in prop_details and "discriminator" in prop_details:
discriminator = prop_details["discriminator"]
disc_field = discriminator.get("propertyName")
variants = []
for variant in prop_details["oneOf"]:
if "$ref" in variant:
ref = variant["$ref"]
if ref.startswith("#/$defs/"):
def_name = ref.split("/")[-1]
resolved = definitions.get(def_name)
if resolved:
variant_model = _resolve_type(
resolved,
parent_name=f"{parent_name}_{def_name}",
)
variants.append(variant_model)
if variants and disc_field:
return Annotated[
Union[tuple(variants)], # type: ignore
Field(discriminator=disc_field),
]
# Handle $ref by resolving the reference
if "$ref" in prop_details:
ref = prop_details["$ref"]
@@ -1070,9 +1131,15 @@ def _build_pydantic_model_from_json_schema(
else nested_prop_details
)
nested_python_type = _resolve_type(
nested_prop_details, f"{nested_model_name}_{nested_prop_name}"
)
# Check for Literal types first (const/enum)
literal_type = _resolve_literal_type(nested_prop_details)
if literal_type is not None:
nested_python_type = literal_type
else:
nested_python_type = _resolve_type(
nested_prop_details,
f"{nested_model_name}_{nested_prop_name}",
)
nested_description = nested_prop_details.get("description", "")
# Build field kwargs for nested property
@@ -1109,7 +1176,12 @@ def _build_pydantic_model_from_json_schema(
for prop_name, prop_details in properties.items():
prop_details = json.loads(prop_details) if isinstance(prop_details, str) else prop_details
python_type = _resolve_type(prop_details, f"{model_name}_{prop_name}")
# Check for Literal types first (const/enum)
literal_type = _resolve_literal_type(prop_details)
if literal_type is not None:
python_type = literal_type
else:
python_type = _resolve_type(prop_details, f"{model_name}_{prop_name}")
description = prop_details.get("description", "")
# Build field kwargs (description, etc.)
+624 -4
View File
@@ -1248,6 +1248,75 @@ async def test_streamable_http_integration():
assert result[0].text is not None
@pytest.mark.flaky
@skip_if_mcp_integration_tests_disabled
async def test_mcp_connection_reset_integration():
"""Test that connection reset works correctly with a real MCP server.
This integration test verifies:
1. Initial connection and tool execution works
2. Simulating connection failure triggers automatic reconnection
3. Tool execution works after reconnection
4. Exit stack cleanup happens properly during reconnection
"""
url = os.environ.get("LOCAL_MCP_URL")
tool = MCPStreamableHTTPTool(name="integration_test", url=url)
async with tool:
# Verify initial connection
assert tool.session is not None
assert tool.is_connected is True
assert len(tool.functions) > 0, "The MCP server should have at least one function."
# Get the first function and invoke it
func = tool.functions[0]
first_result = await func.invoke(query="What is Agent Framework?")
assert first_result is not None
assert len(first_result) > 0
# Store the original session and exit stack for comparison
original_session = tool.session
original_exit_stack = tool._exit_stack
original_call_tool = tool.session.call_tool
# Simulate connection failure by making call_tool raise ClosedResourceError once
call_count = 0
async def call_tool_with_error(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
# First call fails with connection error
from anyio.streams.memory import ClosedResourceError
raise ClosedResourceError
# After reconnection, delegate to the original method
return await original_call_tool(*args, **kwargs)
tool.session.call_tool = call_tool_with_error
# Invoke the function again - this should trigger automatic reconnection on ClosedResourceError
second_result = await func.invoke(query="What is Agent Framework?")
assert second_result is not None
assert len(second_result) > 0
# Verify we have a new session and exit stack after reconnection
assert tool.session is not None
assert tool.session is not original_session, "Session should be replaced after reconnection"
assert tool._exit_stack is not original_exit_stack, "Exit stack should be replaced after reconnection"
assert tool.is_connected is True
# Verify tools are still available after reconnection
assert len(tool.functions) > 0
# Both results should be valid (we don't compare content as it may vary)
if hasattr(first_result[0], "text"):
assert first_result[0].text is not None
if hasattr(second_result[0], "text"):
assert second_result[0].text is not None
async def test_mcp_tool_message_handler_notification():
"""Test that message_handler correctly processes tools/list_changed and prompts/list_changed
notifications."""
@@ -1549,7 +1618,6 @@ def test_mcp_websocket_tool_get_mcp_client_with_kwargs():
)
@pytest.mark.asyncio
async def test_mcp_tool_deduplication():
"""Test that MCP tools are not duplicated in MCPTool"""
from agent_framework._mcp import MCPTool
@@ -1611,7 +1679,6 @@ async def test_mcp_tool_deduplication():
assert added_count == 1 # Only 1 new function added
@pytest.mark.asyncio
async def test_load_tools_prevents_multiple_calls():
"""Test that connect() prevents calling load_tools() multiple times"""
from unittest.mock import AsyncMock, MagicMock
@@ -1627,6 +1694,7 @@ async def test_load_tools_prevents_multiple_calls():
mock_session = AsyncMock()
mock_tool_list = MagicMock()
mock_tool_list.tools = []
mock_tool_list.nextCursor = None # No pagination
mock_session.list_tools = AsyncMock(return_value=mock_tool_list)
mock_session.initialize = AsyncMock()
@@ -1650,7 +1718,6 @@ async def test_load_tools_prevents_multiple_calls():
assert mock_session.list_tools.call_count == 1 # Still 1, not incremented
@pytest.mark.asyncio
async def test_load_prompts_prevents_multiple_calls():
"""Test that connect() prevents calling load_prompts() multiple times"""
from unittest.mock import AsyncMock, MagicMock
@@ -1666,6 +1733,7 @@ async def test_load_prompts_prevents_multiple_calls():
mock_session = AsyncMock()
mock_prompt_list = MagicMock()
mock_prompt_list.prompts = []
mock_prompt_list.nextCursor = None # No pagination
mock_session.list_prompts = AsyncMock(return_value=mock_prompt_list)
tool.session = mock_session
@@ -1688,7 +1756,6 @@ async def test_load_prompts_prevents_multiple_calls():
assert mock_session.list_prompts.call_count == 1 # Still 1, not incremented
@pytest.mark.asyncio
async def test_mcp_streamable_http_tool_httpx_client_cleanup():
"""Test that MCPStreamableHTTPTool properly passes through httpx clients."""
from unittest.mock import AsyncMock, Mock, patch
@@ -1744,3 +1811,556 @@ async def test_mcp_streamable_http_tool_httpx_client_cleanup():
# Get the last call (should be from tool2.connect())
call_args = mock_client.call_args
assert call_args.kwargs["http_client"] is user_client, "User's client should be passed through"
async def test_load_tools_with_pagination():
"""Test that load_tools handles pagination correctly."""
from unittest.mock import AsyncMock, MagicMock
from agent_framework._mcp import MCPTool
tool = MCPTool(name="test_tool")
# Mock the session
mock_session = AsyncMock()
tool.session = mock_session
tool.load_tools_flag = True
# Create paginated responses
page1 = MagicMock()
page1.tools = [
types.Tool(
name="tool_1",
description="First tool",
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
),
types.Tool(
name="tool_2",
description="Second tool",
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
),
]
page1.nextCursor = "cursor_page2"
page2 = MagicMock()
page2.tools = [
types.Tool(
name="tool_3",
description="Third tool",
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
),
]
page2.nextCursor = "cursor_page3"
page3 = MagicMock()
page3.tools = [
types.Tool(
name="tool_4",
description="Fourth tool",
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
),
]
page3.nextCursor = None # No more pages
# Mock list_tools to return different pages based on params
async def mock_list_tools(params=None):
if params is None:
return page1
if params.cursor == "cursor_page2":
return page2
if params.cursor == "cursor_page3":
return page3
raise ValueError("Unexpected cursor value")
mock_session.list_tools = AsyncMock(side_effect=mock_list_tools)
# Load tools with pagination
await tool.load_tools()
# Verify all pages were fetched
assert mock_session.list_tools.call_count == 3
assert len(tool._functions) == 4
assert [f.name for f in tool._functions] == ["tool_1", "tool_2", "tool_3", "tool_4"]
async def test_load_prompts_with_pagination():
"""Test that load_prompts handles pagination correctly."""
from unittest.mock import AsyncMock, MagicMock
from agent_framework._mcp import MCPTool
tool = MCPTool(name="test_tool")
# Mock the session
mock_session = AsyncMock()
tool.session = mock_session
tool.load_prompts_flag = True
# Create paginated responses
page1 = MagicMock()
page1.prompts = [
types.Prompt(
name="prompt_1",
description="First prompt",
arguments=[types.PromptArgument(name="arg1", description="Arg 1", required=True)],
),
types.Prompt(
name="prompt_2",
description="Second prompt",
arguments=[types.PromptArgument(name="arg2", description="Arg 2", required=True)],
),
]
page1.nextCursor = "cursor_page2"
page2 = MagicMock()
page2.prompts = [
types.Prompt(
name="prompt_3",
description="Third prompt",
arguments=[types.PromptArgument(name="arg3", description="Arg 3", required=False)],
),
]
page2.nextCursor = None # No more pages
# Mock list_prompts to return different pages based on params
async def mock_list_prompts(params=None):
if params is None:
return page1
if params.cursor == "cursor_page2":
return page2
raise ValueError("Unexpected cursor value")
mock_session.list_prompts = AsyncMock(side_effect=mock_list_prompts)
# Load prompts with pagination
await tool.load_prompts()
# Verify all pages were fetched
assert mock_session.list_prompts.call_count == 2
assert len(tool._functions) == 3
assert [f.name for f in tool._functions] == ["prompt_1", "prompt_2", "prompt_3"]
async def test_load_tools_pagination_with_duplicates():
"""Test that load_tools prevents duplicates across paginated results."""
from unittest.mock import AsyncMock, MagicMock
from agent_framework._mcp import MCPTool
tool = MCPTool(name="test_tool")
# Mock the session
mock_session = AsyncMock()
tool.session = mock_session
tool.load_tools_flag = True
# Create paginated responses with duplicate tool names
page1 = MagicMock()
page1.tools = [
types.Tool(
name="tool_1",
description="First tool",
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
),
types.Tool(
name="tool_2",
description="Second tool",
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
),
]
page1.nextCursor = "cursor_page2"
page2 = MagicMock()
page2.tools = [
types.Tool(
name="tool_1", # Duplicate from page1
description="Duplicate tool",
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
),
types.Tool(
name="tool_3",
description="Third tool",
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
),
]
page2.nextCursor = None
# Mock list_tools to return different pages
async def mock_list_tools(params=None):
if params is None:
return page1
if params.cursor == "cursor_page2":
return page2
raise ValueError("Unexpected cursor value")
mock_session.list_tools = AsyncMock(side_effect=mock_list_tools)
# Load tools with pagination
await tool.load_tools()
# Verify duplicates were skipped
assert mock_session.list_tools.call_count == 2
assert len(tool._functions) == 3
assert [f.name for f in tool._functions] == ["tool_1", "tool_2", "tool_3"]
async def test_load_prompts_pagination_with_duplicates():
"""Test that load_prompts prevents duplicates across paginated results."""
from unittest.mock import AsyncMock, MagicMock
from agent_framework._mcp import MCPTool
tool = MCPTool(name="test_tool")
# Mock the session
mock_session = AsyncMock()
tool.session = mock_session
tool.load_prompts_flag = True
# Create paginated responses with duplicate prompt names
page1 = MagicMock()
page1.prompts = [
types.Prompt(
name="prompt_1",
description="First prompt",
arguments=[types.PromptArgument(name="arg1", description="Arg 1", required=True)],
),
]
page1.nextCursor = "cursor_page2"
page2 = MagicMock()
page2.prompts = [
types.Prompt(
name="prompt_1", # Duplicate from page1
description="Duplicate prompt",
arguments=[types.PromptArgument(name="arg2", description="Arg 2", required=False)],
),
types.Prompt(
name="prompt_2",
description="Second prompt",
arguments=[types.PromptArgument(name="arg3", description="Arg 3", required=True)],
),
]
page2.nextCursor = None
# Mock list_prompts to return different pages
async def mock_list_prompts(params=None):
if params is None:
return page1
if params.cursor == "cursor_page2":
return page2
raise ValueError("Unexpected cursor value")
mock_session.list_prompts = AsyncMock(side_effect=mock_list_prompts)
# Load prompts with pagination
await tool.load_prompts()
# Verify duplicates were skipped
assert mock_session.list_prompts.call_count == 2
assert len(tool._functions) == 2
assert [f.name for f in tool._functions] == ["prompt_1", "prompt_2"]
async def test_load_tools_pagination_exception_handling():
"""Test that load_tools handles exceptions during pagination gracefully."""
from unittest.mock import AsyncMock
from agent_framework._mcp import MCPTool
tool = MCPTool(name="test_tool")
# Mock the session
mock_session = AsyncMock()
tool.session = mock_session
tool.load_tools_flag = True
# Mock list_tools to raise an exception on first call
mock_session.list_tools = AsyncMock(side_effect=RuntimeError("Connection error"))
# Load tools should raise the exception (not handled gracefully)
with pytest.raises(RuntimeError, match="Connection error"):
await tool.load_tools()
# Verify exception was raised on first call
assert mock_session.list_tools.call_count == 1
assert len(tool._functions) == 0
async def test_load_prompts_pagination_exception_handling():
"""Test that load_prompts handles exceptions during pagination gracefully."""
from unittest.mock import AsyncMock
from agent_framework._mcp import MCPTool
tool = MCPTool(name="test_tool")
# Mock the session
mock_session = AsyncMock()
tool.session = mock_session
tool.load_prompts_flag = True
# Mock list_prompts to raise an exception on first call
mock_session.list_prompts = AsyncMock(side_effect=RuntimeError("Connection error"))
# Load prompts should raise the exception (not handled gracefully)
with pytest.raises(RuntimeError, match="Connection error"):
await tool.load_prompts()
# Verify exception was raised on first call
assert mock_session.list_prompts.call_count == 1
assert len(tool._functions) == 0
async def test_load_tools_empty_pagination():
"""Test that load_tools handles empty paginated results."""
from unittest.mock import AsyncMock, MagicMock
from agent_framework._mcp import MCPTool
tool = MCPTool(name="test_tool")
# Mock the session
mock_session = AsyncMock()
tool.session = mock_session
tool.load_tools_flag = True
# Create empty response
page1 = MagicMock()
page1.tools = []
page1.nextCursor = None
mock_session.list_tools = AsyncMock(return_value=page1)
# Load tools
await tool.load_tools()
# Verify
assert mock_session.list_tools.call_count == 1
assert len(tool._functions) == 0
async def test_load_prompts_empty_pagination():
"""Test that load_prompts handles empty paginated results."""
from unittest.mock import AsyncMock, MagicMock
from agent_framework._mcp import MCPTool
tool = MCPTool(name="test_tool")
# Mock the session
mock_session = AsyncMock()
tool.session = mock_session
tool.load_prompts_flag = True
# Create empty response
page1 = MagicMock()
page1.prompts = []
page1.nextCursor = None
mock_session.list_prompts = AsyncMock(return_value=page1)
# Load prompts
await tool.load_prompts()
# Verify
assert mock_session.list_prompts.call_count == 1
assert len(tool._functions) == 0
async def test_mcp_tool_connection_properly_invalidated_after_closed_resource_error():
"""Test that verifies reconnection on ClosedResourceError for issue #2884.
This test verifies the fix for issue #2884: the tool tries operations optimistically
and only reconnects when ClosedResourceError is encountered, avoiding extra latency.
"""
from unittest.mock import AsyncMock, MagicMock, patch
from anyio.streams.memory import ClosedResourceError
from agent_framework._mcp import MCPStdioTool
from agent_framework.exceptions import ToolExecutionException
# Create a mock MCP tool
tool = MCPStdioTool(
name="test_server",
command="test_command",
args=["arg1"],
load_tools=True,
)
# Mock the session
mock_session = MagicMock()
mock_session._request_id = 1
mock_session.call_tool = AsyncMock()
# Mock _exit_stack.aclose to track cleanup calls
original_exit_stack = tool._exit_stack
tool._exit_stack.aclose = AsyncMock()
# Mock connect() to avoid trying to start actual process
with patch.object(tool, "connect", new_callable=AsyncMock) as mock_connect:
async def restore_session(*, reset=False):
if reset:
await original_exit_stack.aclose()
tool.session = mock_session
tool.is_connected = True
tool._tools_loaded = True
mock_connect.side_effect = restore_session
# Simulate initial connection
tool.session = mock_session
tool.is_connected = True
tool._tools_loaded = True
# First call should work - connection is valid
mock_session.call_tool.return_value = MagicMock(content=[])
result = await tool.call_tool("test_tool", arg1="value1")
assert result is not None
# Test Case 1: Connection closed unexpectedly, should reconnect and retry
# Simulate ClosedResourceError on first call, then succeed
call_count = 0
async def call_tool_with_error(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
raise ClosedResourceError
return MagicMock(content=[])
mock_session.call_tool = call_tool_with_error
# This call should trigger reconnection after ClosedResourceError
result = await tool.call_tool("test_tool", arg1="value2")
assert result is not None
# Verify reconnect was attempted with reset=True
assert mock_connect.call_count >= 1
mock_connect.assert_called_with(reset=True)
# Verify _exit_stack.aclose was called during reconnection
original_exit_stack.aclose.assert_called()
# Test Case 2: Reconnection failure
# Reset counters
call_count = 0
mock_connect.reset_mock()
original_exit_stack.aclose.reset_mock()
# Make call_tool always raise ClosedResourceError
async def always_fail(*args, **kwargs):
raise ClosedResourceError
mock_session.call_tool = always_fail
# Change mock_connect to simulate failed reconnection
mock_connect.side_effect = Exception("Failed to reconnect")
# This should raise ToolExecutionException when reconnection fails
with pytest.raises(ToolExecutionException) as exc_info:
await tool.call_tool("test_tool", arg1="value3")
# Verify reconnection was attempted
assert mock_connect.call_count >= 1
# Verify error message indicates reconnection failure
assert "failed to reconnect" in str(exc_info.value).lower()
async def test_mcp_tool_get_prompt_reconnection_on_closed_resource_error():
"""Test that get_prompt also reconnects on ClosedResourceError.
This verifies that the fix for issue #2884 applies to get_prompt as well,
and that _exit_stack.aclose() is properly called during reconnection.
"""
from unittest.mock import AsyncMock, MagicMock, patch
from anyio.streams.memory import ClosedResourceError
from agent_framework._mcp import MCPStdioTool
from agent_framework.exceptions import ToolExecutionException
# Create a mock MCP tool
tool = MCPStdioTool(
name="test_server",
command="test_command",
args=["arg1"],
load_prompts=True,
)
# Mock the session
mock_session = MagicMock()
mock_session._request_id = 1
mock_session.get_prompt = AsyncMock()
# Mock _exit_stack.aclose to track cleanup calls
original_exit_stack = tool._exit_stack
tool._exit_stack.aclose = AsyncMock()
# Mock connect() to avoid trying to start actual process
with patch.object(tool, "connect", new_callable=AsyncMock) as mock_connect:
async def restore_session(*, reset=False):
if reset:
await original_exit_stack.aclose()
tool.session = mock_session
tool.is_connected = True
tool._prompts_loaded = True
mock_connect.side_effect = restore_session
# Simulate initial connection
tool.session = mock_session
tool.is_connected = True
tool._prompts_loaded = True
# First call should work - connection is valid
mock_session.get_prompt.return_value = MagicMock(messages=[])
result = await tool.get_prompt("test_prompt", arg1="value1")
assert result is not None
# Test Case 1: Connection closed unexpectedly, should reconnect and retry
# Simulate ClosedResourceError on first call, then succeed
call_count = 0
async def get_prompt_with_error(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
raise ClosedResourceError
return MagicMock(messages=[])
mock_session.get_prompt = get_prompt_with_error
# This call should trigger reconnection after ClosedResourceError
result = await tool.get_prompt("test_prompt", arg1="value2")
assert result is not None
# Verify reconnect was attempted with reset=True
assert mock_connect.call_count >= 1
mock_connect.assert_called_with(reset=True)
# Verify _exit_stack.aclose was called during reconnection
original_exit_stack.aclose.assert_called()
# Test Case 2: Reconnection failure
# Reset counters
call_count = 0
mock_connect.reset_mock()
original_exit_stack.aclose.reset_mock()
# Make get_prompt always raise ClosedResourceError
async def always_fail(*args, **kwargs):
raise ClosedResourceError
mock_session.get_prompt = always_fail
# Change mock_connect to simulate failed reconnection
mock_connect.side_effect = Exception("Failed to reconnect")
# This should raise ToolExecutionException when reconnection fails
with pytest.raises(ToolExecutionException) as exc_info:
await tool.get_prompt("test_prompt", arg1="value3")
# Verify reconnection was attempted
assert mock_connect.call_count >= 1
# Verify error message indicates reconnection failure
assert "failed to reconnect" in str(exc_info.value).lower()
+469 -2
View File
@@ -5,7 +5,7 @@ from unittest.mock import Mock
import pytest
from opentelemetry import trace
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError
from agent_framework import (
AIFunction,
@@ -15,7 +15,11 @@ from agent_framework import (
ToolProtocol,
ai_function,
)
from agent_framework._tools import _parse_annotation, _parse_inputs
from agent_framework._tools import (
_build_pydantic_model_from_json_schema,
_parse_annotation,
_parse_inputs,
)
from agent_framework.exceptions import ToolException
from agent_framework.observability import OtelAttr
@@ -1548,4 +1552,467 @@ def test_parse_annotation_with_annotated_and_literal():
assert get_args(literal_type) == ("A", "B", "C")
def test_build_pydantic_model_from_json_schema_array_of_objects_issue():
"""Test for Tools with complex input schema (array of objects).
This test verifies that JSON schemas with array properties containing nested objects
are properly parsed, ensuring that the nested object schema is preserved
and not reduced to a bare dict.
Example from issue:
```
const SalesOrderItemSchema = z.object({
customerMaterialNumber: z.string().optional(),
quantity: z.number(),
unitOfMeasure: z.string()
});
const CreateSalesOrderInputSchema = z.object({
contract: z.string(),
items: z.array(SalesOrderItemSchema)
});
```
The issue was that agents only saw:
```
{"contract": "str", "items": "list[dict]"}
```
Instead of the proper nested schema with all fields.
"""
# Schema matching the issue description
schema = {
"type": "object",
"properties": {
"contract": {"type": "string", "description": "Reference contract number"},
"items": {
"type": "array",
"description": "Sales order line items",
"items": {
"type": "object",
"properties": {
"customerMaterialNumber": {
"type": "string",
"description": "Customer's material number",
},
"quantity": {"type": "number", "description": "Order quantity"},
"unitOfMeasure": {
"type": "string",
"description": "Unit of measure (e.g., 'ST', 'KG', 'TO')",
},
},
"required": ["quantity", "unitOfMeasure"],
},
},
},
"required": ["contract", "items"],
}
model = _build_pydantic_model_from_json_schema("create_sales_order", schema)
# Test valid data
valid_data = {
"contract": "CONTRACT-123",
"items": [
{
"customerMaterialNumber": "MAT-001",
"quantity": 10,
"unitOfMeasure": "ST",
},
{"quantity": 5.5, "unitOfMeasure": "KG"},
],
}
instance = model(**valid_data)
# Verify the data was parsed correctly
assert instance.contract == "CONTRACT-123"
assert len(instance.items) == 2
# Verify first item
assert instance.items[0].customerMaterialNumber == "MAT-001"
assert instance.items[0].quantity == 10
assert instance.items[0].unitOfMeasure == "ST"
# Verify second item (optional field not provided)
assert instance.items[1].quantity == 5.5
assert instance.items[1].unitOfMeasure == "KG"
# Verify that items are proper BaseModel instances, not bare dicts
assert isinstance(instance.items[0], BaseModel)
assert isinstance(instance.items[1], BaseModel)
# Verify that the nested object has the expected fields
assert hasattr(instance.items[0], "customerMaterialNumber")
assert hasattr(instance.items[0], "quantity")
assert hasattr(instance.items[0], "unitOfMeasure")
# CRITICAL: Validate using the same methods that actual chat clients use
# This is what would actually be sent to the LLM
# Create an AIFunction wrapper to access the client-facing APIs
def dummy_func(**kwargs):
return kwargs
test_func = AIFunction(
func=dummy_func,
name="create_sales_order",
description="Create a sales order",
input_model=model,
)
# Test 1: Anthropic client uses tool.parameters() directly
anthropic_schema = test_func.parameters()
# Verify contract property
assert "contract" in anthropic_schema["properties"]
assert anthropic_schema["properties"]["contract"]["type"] == "string"
# Verify items array property exists
assert "items" in anthropic_schema["properties"]
items_prop = anthropic_schema["properties"]["items"]
assert items_prop["type"] == "array"
# THE KEY TEST for Anthropic: array items must have proper object schema
assert "items" in items_prop, "Array should have 'items' schema definition"
array_items_schema = items_prop["items"]
# Resolve schema if using $ref
if "$ref" in array_items_schema:
ref_path = array_items_schema["$ref"]
assert ref_path.startswith("#/$defs/") or ref_path.startswith("#/definitions/")
ref_name = ref_path.split("/")[-1]
defs = anthropic_schema.get("$defs", anthropic_schema.get("definitions", {}))
assert ref_name in defs, f"Referenced schema '{ref_name}' should exist"
item_schema = defs[ref_name]
else:
item_schema = array_items_schema
# Verify the nested object has all properties defined
assert "properties" in item_schema, "Array items should have properties (not bare dict)"
item_properties = item_schema["properties"]
# All three fields must be present in schema sent to LLM
assert "customerMaterialNumber" in item_properties, "customerMaterialNumber missing from LLM schema"
assert "quantity" in item_properties, "quantity missing from LLM schema"
assert "unitOfMeasure" in item_properties, "unitOfMeasure missing from LLM schema"
# Verify types are correct
assert item_properties["customerMaterialNumber"]["type"] == "string"
assert item_properties["quantity"]["type"] in ["number", "integer"]
assert item_properties["unitOfMeasure"]["type"] == "string"
# Test 2: OpenAI client uses tool.to_json_schema_spec()
openai_spec = test_func.to_json_schema_spec()
assert openai_spec["type"] == "function"
assert "function" in openai_spec
openai_schema = openai_spec["function"]["parameters"]
# Verify the same structure is present in OpenAI format
assert "items" in openai_schema["properties"]
openai_items_prop = openai_schema["properties"]["items"]
assert openai_items_prop["type"] == "array"
assert "items" in openai_items_prop
openai_array_items = openai_items_prop["items"]
if "$ref" in openai_array_items:
ref_path = openai_array_items["$ref"]
ref_name = ref_path.split("/")[-1]
defs = openai_schema.get("$defs", openai_schema.get("definitions", {}))
openai_item_schema = defs[ref_name]
else:
openai_item_schema = openai_array_items
assert "properties" in openai_item_schema
openai_props = openai_item_schema["properties"]
assert "customerMaterialNumber" in openai_props
assert "quantity" in openai_props
assert "unitOfMeasure" in openai_props
# Test validation - missing required quantity
with pytest.raises(ValidationError):
model(
contract="CONTRACT-456",
items=[
{
"customerMaterialNumber": "MAT-002",
"unitOfMeasure": "TO",
# Missing required 'quantity'
}
],
)
# Test validation - missing required unitOfMeasure
with pytest.raises(ValidationError):
model(
contract="CONTRACT-789",
items=[
{
"quantity": 20
# Missing required 'unitOfMeasure'
}
],
)
def test_one_of_discriminator_polymorphism():
"""Test that oneOf with discriminator creates proper polymorphic union types.
Tests that oneOf + discriminator patterns are properly converted to Pydantic discriminated unions.
"""
schema = {
"$defs": {
"CreateProject": {
"description": "Action: Create an Azure DevOps project.",
"properties": {
"name": {
"const": "create_project",
"default": "create_project",
"type": "string",
},
"params": {"$ref": "#/$defs/CreateProjectParams"},
},
"required": ["params"],
"type": "object",
},
"CreateProjectParams": {
"description": "Parameters for the create_project action.",
"properties": {
"orgUrl": {"minLength": 1, "type": "string"},
"projectName": {"minLength": 1, "type": "string"},
"description": {"default": "", "type": "string"},
"template": {"default": "Agile", "type": "string"},
"sourceControl": {
"default": "Git",
"enum": ["Git", "Tfvc"],
"type": "string",
},
"visibility": {"default": "private", "type": "string"},
},
"required": ["orgUrl", "projectName"],
"type": "object",
},
"DeployRequest": {
"description": "Request to deploy Azure DevOps resources.",
"properties": {
"projectName": {"minLength": 1, "type": "string"},
"organization": {"minLength": 1, "type": "string"},
"actions": {
"items": {
"discriminator": {
"mapping": {
"create_project": "#/$defs/CreateProject",
"hello_world": "#/$defs/HelloWorld",
},
"propertyName": "name",
},
"oneOf": [
{"$ref": "#/$defs/HelloWorld"},
{"$ref": "#/$defs/CreateProject"},
],
},
"type": "array",
},
},
"required": ["projectName", "organization"],
"type": "object",
},
"HelloWorld": {
"description": "Action: Prints a greeting message.",
"properties": {
"name": {
"const": "hello_world",
"default": "hello_world",
"type": "string",
},
"params": {"$ref": "#/$defs/HelloWorldParams"},
},
"required": ["params"],
"type": "object",
},
"HelloWorldParams": {
"description": "Parameters for the hello_world action.",
"properties": {
"name": {
"description": "Name to greet",
"minLength": 1,
"type": "string",
}
},
"required": ["name"],
"type": "object",
},
},
"properties": {"params": {"$ref": "#/$defs/DeployRequest"}},
"required": ["params"],
"type": "object",
}
# Build the model
model = _build_pydantic_model_from_json_schema("deploy_tool", schema)
# Verify the model structure
assert model is not None
assert issubclass(model, BaseModel)
# Test with HelloWorld action
hello_world_data = {
"params": {
"projectName": "MyProject",
"organization": "MyOrg",
"actions": [
{
"name": "hello_world",
"params": {"name": "Alice"},
}
],
}
}
instance = model(**hello_world_data)
assert instance.params.projectName == "MyProject"
assert instance.params.organization == "MyOrg"
assert len(instance.params.actions) == 1
assert instance.params.actions[0].name == "hello_world"
assert instance.params.actions[0].params.name == "Alice"
# Test with CreateProject action
create_project_data = {
"params": {
"projectName": "MyProject",
"organization": "MyOrg",
"actions": [
{
"name": "create_project",
"params": {
"orgUrl": "https://dev.azure.com/myorg",
"projectName": "NewProject",
"sourceControl": "Git",
},
}
],
}
}
instance2 = model(**create_project_data)
assert instance2.params.actions[0].name == "create_project"
assert instance2.params.actions[0].params.projectName == "NewProject"
assert instance2.params.actions[0].params.sourceControl == "Git"
# Test with mixed actions
mixed_data = {
"params": {
"projectName": "MyProject",
"organization": "MyOrg",
"actions": [
{"name": "hello_world", "params": {"name": "Bob"}},
{
"name": "create_project",
"params": {
"orgUrl": "https://dev.azure.com/myorg",
"projectName": "AnotherProject",
},
},
],
}
}
instance3 = model(**mixed_data)
assert len(instance3.params.actions) == 2
assert instance3.params.actions[0].name == "hello_world"
assert instance3.params.actions[1].name == "create_project"
def test_const_creates_literal():
"""Test that const in JSON Schema creates Literal type."""
schema = {
"properties": {
"action": {
"const": "create",
"type": "string",
"description": "Action type",
},
"value": {"type": "integer"},
},
"required": ["action", "value"],
}
model = _build_pydantic_model_from_json_schema("test_const", schema)
# Verify valid const value works
instance = model(action="create", value=42)
assert instance.action == "create"
assert instance.value == 42
# Verify incorrect const value fails
with pytest.raises(ValidationError):
model(action="delete", value=42)
def test_enum_creates_literal():
"""Test that enum in JSON Schema creates Literal type."""
schema = {
"properties": {
"status": {
"enum": ["pending", "approved", "rejected"],
"type": "string",
"description": "Status",
},
"priority": {"enum": [1, 2, 3], "type": "integer"},
},
"required": ["status"],
}
model = _build_pydantic_model_from_json_schema("test_enum", schema)
# Verify valid enum values work
instance = model(status="approved", priority=2)
assert instance.status == "approved"
assert instance.priority == 2
# Verify invalid enum value fails
with pytest.raises(ValidationError):
model(status="unknown")
with pytest.raises(ValidationError):
model(status="pending", priority=5)
def test_nested_object_with_const_and_enum():
"""Test that const and enum work in nested objects."""
schema = {
"properties": {
"config": {
"type": "object",
"properties": {
"type": {
"const": "production",
"default": "production",
"type": "string",
},
"level": {"enum": ["low", "medium", "high"], "type": "string"},
},
"required": ["level"],
}
},
"required": ["config"],
}
model = _build_pydantic_model_from_json_schema("test_nested", schema)
# Valid data
instance = model(config={"type": "production", "level": "high"})
assert instance.config.type == "production"
assert instance.config.level == "high"
# Invalid const in nested object
with pytest.raises(ValidationError):
model(config={"type": "development", "level": "low"})
# Invalid enum in nested object
with pytest.raises(ValidationError):
model(config={"type": "production", "level": "critical"})
# endregion