fix duplicate names between supplied tools and mcp servers (#4649)

This commit is contained in:
Eduard van Valkenburg
2026-03-13 09:22:56 +01:00
committed by GitHub
Unverified
parent 84bae0f42a
commit b7990908fe
8 changed files with 385 additions and 101 deletions
@@ -8,6 +8,7 @@ import logging
from typing import TYPE_CHECKING, Any
from agent_framework import BaseChatClient
from agent_framework._tools import _append_unique_tools # pyright: ignore[reportPrivateUsage]
if TYPE_CHECKING:
from agent_framework import SupportsAgentRun
@@ -22,7 +23,7 @@ def _collect_mcp_tool_functions(mcp_tools: list[Any]) -> list[Any]:
mcp_tools: List of MCP tool instances.
Returns:
List of functions from connected MCP tools.
Functions from connected MCP tools.
"""
functions: list[Any] = []
for mcp_tool in mcp_tools:
@@ -56,7 +57,11 @@ def collect_server_tools(agent: SupportsAgentRun) -> list[Any]:
# Include functions from connected MCP tools (only available on Agent)
mcp_tools = getattr(agent, "mcp_tools", None)
if mcp_tools:
server_tools.extend(_collect_mcp_tool_functions(mcp_tools))
_append_unique_tools(
server_tools,
_collect_mcp_tool_functions(mcp_tools),
duplicate_error_message="Tool names must be unique. Consider setting `tool_name_prefix` on the MCPTool.",
)
logger.info(f"[TOOLS] Agent has {len(server_tools)} configured tools")
for tool in server_tools:
@@ -109,26 +114,13 @@ def merge_tools(server_tools: list[Any], client_tools: list[Any] | None) -> list
logger.info("[TOOLS] No client tools - not passing tools= parameter (using agent's configured tools)")
return None
server_tool_names = {getattr(tool, "name", None) for tool in server_tools}
unique_client_tools = [tool for tool in client_tools if getattr(tool, "name", None) not in server_tool_names]
if not unique_client_tools:
# Same check: must pass server tools if any require approval
if server_tools and _has_approval_tools(server_tools):
logger.info(
f"[TOOLS] Client tools duplicate server but server has approval tools - "
f"passing {len(server_tools)} server tools for approval mode"
)
return server_tools
logger.info("[TOOLS] All client tools duplicate server tools - not passing tools= parameter")
return None
combined_tools: list[Any] = []
if server_tools:
combined_tools.extend(server_tools)
combined_tools.extend(unique_client_tools)
combined_tools = _append_unique_tools(
list(server_tools),
client_tools,
duplicate_error_message="Tool names must be unique.",
)
logger.info(
f"[TOOLS] Passing tools= parameter with {len(combined_tools)} tools "
f"({len(server_tools)} server + {len(unique_client_tools)} unique client)"
f"({len(server_tools)} server + {len(client_tools)} client)"
)
return combined_tools
@@ -2,6 +2,7 @@
from unittest.mock import MagicMock
import pytest
from agent_framework import Agent, tool
from agent_framework_ag_ui._orchestration._tooling import (
@@ -20,7 +21,8 @@ class DummyTool:
class MockMCPTool:
"""Mock MCP tool that simulates connected MCP tool with functions."""
def __init__(self, functions: list[DummyTool], is_connected: bool = True) -> None:
def __init__(self, functions: list[DummyTool], is_connected: bool = True, name: str = "mock-mcp") -> None:
self.name = name
self.functions = functions
self.is_connected = is_connected
@@ -45,11 +47,8 @@ def test_merge_tools_filters_duplicates() -> None:
server = [DummyTool("a"), DummyTool("b")]
client = [DummyTool("b"), DummyTool("c")]
merged = merge_tools(server, client)
assert merged is not None
names = [getattr(t, "name", None) for t in merged]
assert names == ["a", "b", "c"]
with pytest.raises(ValueError, match="Duplicate tool name 'b'"):
merge_tools(server, client)
def test_register_additional_client_tools_assigns_when_configured() -> None:
@@ -131,6 +130,17 @@ def test_collect_server_tools_with_mcp_tools_via_public_property() -> None:
assert len(tools) == 2
def test_collect_server_tools_raises_on_duplicate_agent_and_mcp_tool_names() -> None:
duplicate_tool = DummyTool("regular_tool")
mock_mcp = MockMCPTool([duplicate_tool], is_connected=True, name="docs-mcp")
agent = _create_chat_agent_with_tool("regular_tool")
agent.mcp_tools = [mock_mcp]
with pytest.raises(ValueError, match="Duplicate tool name 'regular_tool'"):
collect_server_tools(agent)
# Additional tests for tooling coverage
@@ -176,11 +186,11 @@ def test_merge_tools_no_client_tools() -> None:
def test_merge_tools_all_duplicates() -> None:
"""merge_tools returns None when all client tools duplicate server tools."""
"""merge_tools raises when client and server tools share a name."""
server = [DummyTool("a"), DummyTool("b")]
client = [DummyTool("a"), DummyTool("b")]
result = merge_tools(server, client)
assert result is None
with pytest.raises(ValueError, match="Duplicate tool name 'a'"):
merge_tools(server, client)
def test_merge_tools_empty_server() -> None:
@@ -208,7 +218,7 @@ def test_merge_tools_with_approval_tools_no_client() -> None:
def test_merge_tools_with_approval_tools_all_duplicates() -> None:
"""merge_tools returns server tools with approval mode even when client duplicates."""
"""merge_tools raises even when a client tool duplicates an approval-gated server tool."""
class ApprovalTool:
def __init__(self, name: str):
@@ -217,7 +227,5 @@ def test_merge_tools_with_approval_tools_all_duplicates() -> None:
server = [ApprovalTool("write_doc")]
client = [DummyTool("write_doc")] # Same name as server
result = merge_tools(server, client)
assert result is not None
assert len(result) == 1
assert result[0].approval_mode == "always_require"
with pytest.raises(ValueError, match="Duplicate tool name 'write_doc'"):
merge_tools(server, client)
+29 -31
View File
@@ -29,6 +29,7 @@ from mcp.server.lowlevel import Server
from mcp.shared.exceptions import McpError
from pydantic import BaseModel, Field, create_model
from . import _tools as _tool_utils # pyright: ignore[reportPrivateUsage]
from ._clients import BaseChatClient, SupportsChatGetResponse
from ._mcp import LOG_LEVEL_MAPPING, MCPTool
from ._middleware import AgentMiddlewareLayer, MiddlewareTypes
@@ -40,12 +41,7 @@ from ._sessions import (
InMemoryHistoryProvider,
SessionContext,
)
from ._tools import (
FunctionInvocationLayer,
FunctionTool,
ToolTypes,
normalize_tools,
)
from ._tools import FunctionInvocationLayer, FunctionTool, ToolTypes, normalize_tools
from ._types import (
AgentResponse,
AgentResponseUpdate,
@@ -79,6 +75,9 @@ if TYPE_CHECKING:
logger = logging.getLogger("agent_framework")
_append_unique_tools = _tool_utils._append_unique_tools # pyright: ignore[reportPrivateUsage]
_get_tool_name = _tool_utils._get_tool_name # pyright: ignore[reportPrivateUsage]
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
OptionsCoT = TypeVar(
"OptionsCoT",
@@ -88,19 +87,6 @@ OptionsCoT = TypeVar(
)
def _get_tool_name(tool: Any) -> str | None:
"""Extract a tool's name from either an object with a .name attribute or a dict tool definition."""
if isinstance(tool, Mapping):
tool_mapping = cast(Mapping[str, Any], tool)
func = tool_mapping.get("function")
if isinstance(func, Mapping):
func_mapping = cast(Mapping[str, Any], func)
name = func_mapping.get("name")
return name if isinstance(name, str) else None
return None
return getattr(tool, "name", None)
def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
"""Merge two options dicts, with override values taking precedence.
@@ -115,11 +101,14 @@ def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str,
for key, value in override.items():
if value is None:
continue
if key == "tools" and result.get("tools"):
# Combine tool lists, avoiding duplicates by name
existing_names = {_get_tool_name(t) for t in result["tools"]} - {None}
unique_new = [t for t in value if _get_tool_name(t) not in existing_names]
result["tools"] = list(result["tools"]) + unique_new
if key == "tools" and (result.get("tools") or value):
base_tools = normalize_tools(result.get("tools"))
override_tools = normalize_tools(value)
result["tools"] = _append_unique_tools(
list(base_tools),
override_tools,
duplicate_error_message="Tool names must be unique.",
)
elif key == "logit_bias" and result.get("logit_bias"):
# Merge logit_bias dicts
result["logit_bias"] = {**result["logit_bias"], **value}
@@ -1117,25 +1106,34 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
)
agent_name = self._get_agent_name()
base_tools = normalize_tools(chat_options.pop("tools", None))
mcp_duplicate_message = "Tool names must be unique. Consider setting `tool_name_prefix` on the MCPTool."
# Normalize tools
normalized_tools = normalize_tools(tools_)
# Resolve final tool list (runtime provided tools + local MCP server tools)
final_tools: list[FunctionTool | Callable[..., Any] | dict[str, Any] | Any] = []
# Resolve final tool list (configured tools + runtime provided tools + local MCP server tools)
final_tools = list(base_tools)
for tool in normalized_tools:
if isinstance(tool, MCPTool):
if not tool.is_connected:
await self._async_exit_stack.enter_async_context(tool)
final_tools.extend(tool.functions) # type: ignore
_append_unique_tools(
final_tools,
tool.functions,
duplicate_error_message=mcp_duplicate_message,
)
else:
final_tools.append(tool) # type: ignore
_append_unique_tools(final_tools, [tool]) # type: ignore[list-item]
existing_names = {name for t in final_tools if (name := _get_tool_name(t)) is not None}
for mcp_server in self.mcp_tools:
if not mcp_server.is_connected:
await self._async_exit_stack.enter_async_context(mcp_server)
final_tools.extend(f for f in mcp_server.functions if f.name not in existing_names)
_append_unique_tools(
final_tools,
mcp_server.functions,
duplicate_error_message=mcp_duplicate_message,
)
# Merge runtime kwargs into additional_function_arguments so they're available
# in function middleware context and tool invocation.
@@ -1164,7 +1162,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
"store": opts.pop("store", None),
"temperature": opts.pop("temperature", None),
"tool_choice": opts.pop("tool_choice", None),
"tools": final_tools,
"tools": final_tools or None,
"top_p": opts.pop("top_p", None),
"user": opts.pop("user", None),
**opts, # Remaining options are provider-specific
+63 -11
View File
@@ -26,9 +26,7 @@ from mcp.shared.exceptions import McpError
from mcp.shared.session import RequestResponder
from opentelemetry import propagate
from ._tools import (
FunctionTool,
)
from ._tools import FunctionTool
from ._types import (
Content,
Message,
@@ -59,6 +57,8 @@ class MCPSpecificApproval(TypedDict, total=False):
logger = logging.getLogger(__name__)
_MCP_REMOTE_NAME_KEY = "_mcp_remote_name"
_MCP_NORMALIZED_NAME_KEY = "_mcp_normalized_name"
# region: Helpers
@@ -372,6 +372,20 @@ def _normalize_mcp_name(name: str) -> str:
return re.sub(r"[^A-Za-z0-9_.-]", "-", name)
def _build_prefixed_mcp_name(
normalized_name: str,
tool_name_prefix: str | None,
) -> str:
"""Build the exposed MCP function name from a normalized name and optional prefix."""
if not tool_name_prefix:
return normalized_name
normalized_prefix = _normalize_mcp_name(tool_name_prefix).rstrip("_.-")
if not normalized_prefix:
return normalized_name
trimmed_name = normalized_name.lstrip("_.-")
return f"{normalized_prefix}_{trimmed_name}" if trimmed_name else normalized_prefix
def _inject_otel_into_mcp_meta(meta: dict[str, Any] | None = None) -> dict[str, Any] | None:
"""Inject OpenTelemetry trace context into MCP request _meta via the global propagator(s)."""
carrier: dict[str, str] = {}
@@ -415,6 +429,7 @@ class MCPTool:
description: str | None = None,
approval_mode: (Literal["always_require", "never_require"] | MCPSpecificApproval | None) = None,
allowed_tools: Collection[str] | None = None,
tool_name_prefix: str | None = None,
load_tools: bool = True,
parse_tool_results: Callable[[types.CallToolResult], str | list[Content]] | None = None,
load_prompts: bool = True,
@@ -435,6 +450,7 @@ class MCPTool:
description: A description of the MCP tool.
approval_mode: Whether approval is required to run tools.
allowed_tools: A collection of tool names to allow.
tool_name_prefix: Optional prefix to prepend to exposed MCP function names.
load_tools: Whether to load tools from the MCP server.
parse_tool_results: An optional callable with signature
``Callable[[types.CallToolResult], str]`` that overrides the default result
@@ -458,6 +474,7 @@ class MCPTool:
self.description = description or ""
self.approval_mode = approval_mode
self.allowed_tools = allowed_tools
self.tool_name_prefix = _normalize_mcp_name(tool_name_prefix).rstrip("_.-") if tool_name_prefix else None
self.additional_properties = additional_properties
self.load_tools_flag = load_tools
self.parse_tool_results = parse_tool_results
@@ -480,7 +497,19 @@ class MCPTool:
"""Get the list of functions that are allowed."""
if not self.allowed_tools:
return self._functions
return [func for func in self._functions if func.name in self.allowed_tools]
allowed_names = set(self.allowed_tools)
filtered_functions: list[FunctionTool] = []
for func in self._functions:
additional_properties = func.additional_properties or {}
normalized_name = additional_properties.get(_MCP_NORMALIZED_NAME_KEY)
remote_name = additional_properties.get(_MCP_REMOTE_NAME_KEY)
if (
func.name in allowed_names
or (isinstance(normalized_name, str) and normalized_name in allowed_names)
or (isinstance(remote_name, str) and remote_name in allowed_names)
):
filtered_functions.append(func)
return filtered_functions
async def _safe_close_exit_stack(self) -> None:
"""Safely close the exit stack, handling cross-task boundary errors.
@@ -706,12 +735,16 @@ class MCPTool:
def _determine_approval_mode(
self,
local_name: str,
*candidate_names: str,
) -> Literal["always_require", "never_require"] | None:
if isinstance(self.approval_mode, dict):
if (always_require := self.approval_mode.get("always_require_approval")) and local_name in always_require:
if (always_require := self.approval_mode.get("always_require_approval")) and any(
name in always_require for name in candidate_names
):
return "always_require"
if (never_require := self.approval_mode.get("never_require_approval")) and local_name in never_require:
if (never_require := self.approval_mode.get("never_require_approval")) and any(
name in never_require for name in candidate_names
):
return "never_require"
return None
return self.approval_mode # type: ignore[reportReturnType]
@@ -736,20 +769,25 @@ class MCPTool:
prompt_list = await self.session.list_prompts(params=params) # type: ignore[union-attr]
for prompt in prompt_list.prompts:
local_name = _normalize_mcp_name(prompt.name)
normalized_name = _normalize_mcp_name(prompt.name)
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
# Skip if already loaded
if local_name in existing_names:
continue
input_model = _get_input_model_from_mcp_prompt(prompt)
approval_mode = self._determine_approval_mode(local_name)
approval_mode = self._determine_approval_mode(local_name, normalized_name, prompt.name)
func: FunctionTool = FunctionTool(
func=partial(self.get_prompt, prompt.name),
name=local_name,
description=prompt.description or "",
approval_mode=approval_mode,
input_model=input_model,
additional_properties={
_MCP_REMOTE_NAME_KEY: prompt.name,
_MCP_NORMALIZED_NAME_KEY: normalized_name,
},
)
self._functions.append(func)
existing_names.add(local_name)
@@ -779,13 +817,14 @@ class MCPTool:
tool_list = await self.session.list_tools(params=params) # type: ignore[union-attr]
for tool in tool_list.tools:
local_name = _normalize_mcp_name(tool.name)
normalized_name = _normalize_mcp_name(tool.name)
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
# Skip if already loaded
if local_name in existing_names:
continue
approval_mode = self._determine_approval_mode(local_name)
approval_mode = self._determine_approval_mode(local_name, normalized_name, tool.name)
# Create FunctionTools out of each tool
func: FunctionTool = FunctionTool(
func=partial(self.call_tool, tool.name),
@@ -793,6 +832,10 @@ class MCPTool:
description=tool.description or "",
approval_mode=approval_mode,
input_model=tool.inputSchema,
additional_properties={
_MCP_REMOTE_NAME_KEY: tool.name,
_MCP_NORMALIZED_NAME_KEY: normalized_name,
},
)
self._functions.append(func)
existing_names.add(local_name)
@@ -1055,6 +1098,7 @@ class MCPStdioTool(MCPTool):
name: str,
command: str,
*,
tool_name_prefix: str | None = None,
load_tools: bool = True,
parse_tool_results: Callable[[types.CallToolResult], str | list[Content]] | None = None,
load_prompts: bool = True,
@@ -1083,6 +1127,7 @@ class MCPStdioTool(MCPTool):
command: The command to run the MCP server.
Keyword Args:
tool_name_prefix: Optional prefix to prepend to exposed MCP function names.
load_tools: Whether to load tools from the MCP server.
parse_tool_results: An optional callable with signature
``Callable[[types.CallToolResult], str]`` that overrides the default result
@@ -1119,6 +1164,7 @@ class MCPStdioTool(MCPTool):
description=description,
approval_mode=approval_mode,
allowed_tools=allowed_tools,
tool_name_prefix=tool_name_prefix,
additional_properties=additional_properties,
session=session,
client=client,
@@ -1180,6 +1226,7 @@ class MCPStreamableHTTPTool(MCPTool):
name: str,
url: str,
*,
tool_name_prefix: str | None = None,
load_tools: bool = True,
parse_tool_results: Callable[[types.CallToolResult], str | list[Content]] | None = None,
load_prompts: bool = True,
@@ -1208,6 +1255,7 @@ class MCPStreamableHTTPTool(MCPTool):
url: The URL of the MCP server.
Keyword Args:
tool_name_prefix: Optional prefix to prepend to exposed MCP function names.
load_tools: Whether to load tools from the MCP server.
parse_tool_results: An optional callable with signature
``Callable[[types.CallToolResult], str]`` that overrides the default result
@@ -1246,6 +1294,7 @@ class MCPStreamableHTTPTool(MCPTool):
description=description,
approval_mode=approval_mode,
allowed_tools=allowed_tools,
tool_name_prefix=tool_name_prefix,
additional_properties=additional_properties,
session=session,
client=client,
@@ -1299,6 +1348,7 @@ class MCPWebsocketTool(MCPTool):
name: str,
url: str,
*,
tool_name_prefix: str | None = None,
load_tools: bool = True,
parse_tool_results: Callable[[types.CallToolResult], str | list[Content]] | None = None,
load_prompts: bool = True,
@@ -1325,6 +1375,7 @@ class MCPWebsocketTool(MCPTool):
url: The URL of the MCP server.
Keyword Args:
tool_name_prefix: Optional prefix to prepend to exposed MCP function names.
load_tools: Whether to load tools from the MCP server.
parse_tool_results: An optional callable with signature
``Callable[[types.CallToolResult], str]`` that overrides the default result
@@ -1358,6 +1409,7 @@ class MCPWebsocketTool(MCPTool):
description=description,
approval_mode=approval_mode,
allowed_tools=allowed_tools,
tool_name_prefix=tool_name_prefix,
additional_properties=additional_properties,
session=session,
client=client,
+60 -2
View File
@@ -71,7 +71,6 @@ if TYPE_CHECKING:
ResponseStream,
)
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
else:
MCPTool = Any # type: ignore[assignment,misc]
@@ -83,9 +82,23 @@ DEFAULT_MAX_ITERATIONS: Final[int] = 40
DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST: Final[int] = 3
SHELL_TOOL_KIND_VALUE: Final[str] = "shell"
ChatClientT = TypeVar("ChatClientT", bound="SupportsChatGetResponse[Any]")
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
# region Helpers
def _get_tool_name(tool: Any) -> str | None:
"""Extract a tool name from a tool object or dict tool definition."""
if isinstance(tool, Mapping):
func = tool.get("function", None) # type: ignore
if func and isinstance(func, Mapping):
name = func.get("name") # type: ignore
return name if isinstance(name, str) else None
return None
name = getattr(tool, "name", None)
return name if isinstance(name, str) else None
def _parse_inputs( # pyright: ignore[reportUnusedFunction]
inputs: Content | dict[str, Any] | str | list[Content | dict[str, Any] | str] | None,
) -> list[Content]:
@@ -701,6 +714,51 @@ class FunctionTool(SerializationMixin):
ToolTypes: TypeAlias = FunctionTool | MCPTool | Mapping[str, Any] | object
def _raise_duplicate_tool_name(tool_name: str, duplicate_error_message: str | None = None) -> None:
message = duplicate_error_message or "Tool names must be unique."
raise ValueError(f"Duplicate tool name '{tool_name}'. {message}")
def _append_unique_tools(
existing_tools: list[ToolTypes],
new_tools: Sequence[ToolTypes],
*,
duplicate_error_message: str | None = None,
) -> list[ToolTypes]:
seen_by_name: dict[str, ToolTypes] = {}
for tool_item in existing_tools:
if tool_name := _get_tool_name(tool_item):
seen_by_name[tool_name] = tool_item
for tool_item in new_tools:
tool_name = _get_tool_name(tool_item)
if tool_name is None:
existing_tools.append(tool_item)
continue
existing_tool = seen_by_name.get(tool_name)
if existing_tool is None:
seen_by_name[tool_name] = tool_item
existing_tools.append(tool_item)
continue
if existing_tool is tool_item:
continue
_raise_duplicate_tool_name(tool_name, duplicate_error_message)
return existing_tools
def _ensure_unique_tool_names(
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]],
*,
duplicate_error_message: str | None = None,
) -> list[ToolTypes]:
normalized_tools = normalize_tools(tools)
return _append_unique_tools([], normalized_tools, duplicate_error_message=duplicate_error_message)
def normalize_tools(
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None,
) -> list[ToolTypes]:
@@ -1320,7 +1378,7 @@ def _get_tool_map(
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]],
) -> dict[str, FunctionTool]:
tool_list: dict[str, FunctionTool] = {}
for tool_item in normalize_tools(tools):
for tool_item in _ensure_unique_tool_names(tools):
if isinstance(tool_item, FunctionTool):
tool_list[tool_item.name] = tool_item
return tool_list
+122 -21
View File
@@ -30,7 +30,7 @@ from agent_framework import (
tool,
)
from agent_framework._agents import _get_tool_name, _merge_options, _sanitize_agent_name
from agent_framework._mcp import MCPTool
from agent_framework._mcp import MCPTool, _build_prefixed_mcp_name, _normalize_mcp_name
class _FixedTokenizer:
@@ -41,6 +41,30 @@ class _FixedTokenizer:
return self.token_count
class _ConnectedMCPTool(MCPTool):
def __init__(self, name: str, function_names: list[str], *, tool_name_prefix: str | None = None) -> None:
super().__init__(name=name, tool_name_prefix=tool_name_prefix)
self.is_connected = True
self._functions = []
for function_name in function_names:
normalized_name = _normalize_mcp_name(function_name)
exposed_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
self._functions.append(
FunctionTool(
func=lambda value=function_name: value,
name=exposed_name,
description=f"{function_name} from {name}",
additional_properties={
"_mcp_remote_name": function_name,
"_mcp_normalized_name": normalized_name,
},
)
)
def get_mcp_client(self) -> contextlib.AbstractAsyncContextManager[Any]:
raise NotImplementedError
def test_agent_session_type(agent_session: AgentSession) -> None:
assert isinstance(agent_session, AgentSession)
@@ -953,6 +977,7 @@ async def test_chat_agent_run_with_mcp_tools(client: SupportsChatGetResponse) ->
# Create a mock MCP tool
mock_mcp_tool = MagicMock(spec=MCPTool)
mock_mcp_tool.name = "mock-mcp"
mock_mcp_tool.is_connected = False
mock_mcp_tool.functions = [MagicMock()]
@@ -970,6 +995,7 @@ async def test_chat_agent_with_local_mcp_tools(client: SupportsChatGetResponse)
"""Test agent initialization with local MCP tools."""
# Create a mock MCP tool
mock_mcp_tool = MagicMock(spec=MCPTool)
mock_mcp_tool.name = "mock-mcp"
mock_mcp_tool.is_connected = False
mock_mcp_tool.__aenter__ = AsyncMock(return_value=mock_mcp_tool)
mock_mcp_tool.__aexit__ = AsyncMock(return_value=None)
@@ -1009,6 +1035,7 @@ async def test_mcp_tools_not_duplicated_when_passed_as_runtime_tools(
# Create a mock MCP tool that is already connected (simulates turn 2)
mock_mcp_tool = MagicMock(spec=MCPTool)
mock_mcp_tool.name = "mock-mcp"
mock_mcp_tool.is_connected = True
mock_mcp_tool.functions = [mcp_func_a, mcp_func_b]
mock_mcp_tool.__aenter__ = AsyncMock(return_value=mock_mcp_tool)
@@ -1032,6 +1059,77 @@ async def test_mcp_tools_not_duplicated_when_passed_as_runtime_tools(
assert len(tool_names) == 3
async def test_agent_run_raises_on_local_and_agent_mcp_name_conflict(chat_client_base: Any) -> None:
local_tool = FunctionTool(
func=lambda: "local",
name="delete_all_data",
description="Local protected tool",
approval_mode="always_require",
)
agent = Agent(
client=chat_client_base,
name="TestAgent",
tools=[_ConnectedMCPTool(name="dangerous-mcp", function_names=["delete_all_data"])],
)
with raises(ValueError, match="tool_name_prefix"):
await agent.run("hello", tools=[local_tool])
async def test_agent_run_raises_on_runtime_local_and_runtime_mcp_name_conflict(chat_client_base: Any) -> None:
local_tool = FunctionTool(
func=lambda: "local",
name="delete_all_data",
description="Local protected tool",
approval_mode="always_require",
)
runtime_mcp = _ConnectedMCPTool(name="dangerous-mcp", function_names=["delete_all_data"])
agent = Agent(client=chat_client_base, name="TestAgent")
with raises(ValueError, match="tool_name_prefix"):
await agent.run("hello", tools=[local_tool, runtime_mcp])
async def test_agent_run_raises_on_duplicate_agent_mcp_names(chat_client_base: Any) -> None:
agent = Agent(
client=chat_client_base,
name="TestAgent",
tools=[
_ConnectedMCPTool(name="docs-mcp", function_names=["search"]),
_ConnectedMCPTool(name="github-mcp", function_names=["search"]),
],
)
with raises(ValueError, match="tool_name_prefix"):
await agent.run("hello")
async def test_agent_run_accepts_prefixed_mcp_tools(chat_client_base: Any) -> None:
captured_options: list[dict[str, Any]] = []
original_inner = chat_client_base._inner_get_response
async def capturing_inner(
*, messages: MutableSequence[Message], options: dict[str, Any], **kwargs: Any
) -> ChatResponse:
captured_options.append(dict(options))
return await original_inner(messages=messages, options=options, **kwargs)
chat_client_base._inner_get_response = capturing_inner
local_tool = FunctionTool(func=lambda: "local", name="search", description="Local search tool")
agent = Agent(
client=chat_client_base,
name="TestAgent",
tools=[_ConnectedMCPTool(name="docs-mcp", function_names=["search"], tool_name_prefix="docs")],
)
await agent.run("hello", tools=[local_tool])
tool_names = [tool.name for tool in captured_options[0]["tools"]]
assert tool_names == ["search", "docs_search"]
async def test_agent_tool_receives_session_in_kwargs(chat_client_base: Any) -> None:
"""Verify tool execution receives 'session' inside **kwargs when function is called by client."""
@@ -1291,7 +1389,7 @@ def test_merge_options_none_values_ignored():
def test_merge_options_tools_combined():
"""Test _merge_options combines tool lists without duplicates."""
"""Test _merge_options raises when distinct tools share the same name."""
class MockTool:
def __init__(self, name):
@@ -1304,13 +1402,8 @@ def test_merge_options_tools_combined():
base = {"tools": [tool1]}
override = {"tools": [tool2, tool3]}
result = _merge_options(base, override)
# Should have tool1 and tool2, but not duplicate tool3
assert len(result["tools"]) == 2
tool_names = [t.name for t in result["tools"]]
assert "tool1" in tool_names
assert "tool2" in tool_names
with raises(ValueError, match="Duplicate tool name 'tool1'"):
_merge_options(base, override)
def test_merge_options_dict_tools_combined():
@@ -1335,7 +1428,7 @@ def test_merge_options_dict_tools_combined():
def test_merge_options_dict_tools_deduplicates():
"""Test _merge_options deduplicates dict-defined tools by function name."""
"""Test _merge_options raises on duplicate dict-defined tool names."""
base = {
"tools": [
{"type": "function", "function": {"name": "tool_a"}},
@@ -1348,12 +1441,8 @@ def test_merge_options_dict_tools_deduplicates():
]
}
result = _merge_options(base, override)
assert len(result["tools"]) == 2
names = [_get_tool_name(t) for t in result["tools"]]
assert names.count("tool_a") == 1
assert "tool_b" in names
with raises(ValueError, match="Duplicate tool name 'tool_a'"):
_merge_options(base, override)
def test_merge_options_mixed_tools_combined():
@@ -1379,7 +1468,7 @@ def test_merge_options_mixed_tools_combined():
def test_merge_options_mixed_tools_deduplicates():
"""Test _merge_options deduplicates when a dict tool and object tool share the same name."""
"""Test _merge_options raises when a dict tool and object tool share the same name."""
class MockTool:
def __init__(self, name):
@@ -1392,10 +1481,8 @@ def test_merge_options_mixed_tools_deduplicates():
]
}
result = _merge_options(base, override)
assert len(result["tools"]) == 1
assert _get_tool_name(result["tools"][0]) == "tool_a"
with raises(ValueError, match="Duplicate tool name 'tool_a'"):
_merge_options(base, override)
def test_merge_options_nameless_tools_not_deduplicated():
@@ -1417,6 +1504,20 @@ def test_merge_options_nameless_tools_not_deduplicated():
assert len(result["tools"]) == 2
def test_merge_options_same_tool_object_kept_once():
"""Test _merge_options silently keeps a repeated reference to the same tool object once."""
class MockTool:
def __init__(self, name):
self.name = name
tool_a = MockTool("tool_a")
result = _merge_options({"tools": [tool_a]}, {"tools": [tool_a]})
assert result["tools"] == [tool_a]
def test_get_tool_name_dict_no_function_key():
"""_get_tool_name returns None for a dict without a 'function' key."""
assert _get_tool_name({"type": "function"}) is None
@@ -53,6 +53,81 @@ def test_normalize_mcp_name():
assert _normalize_mcp_name("name/with\\slashes") == "name-with-slashes"
def test_mcp_transport_subclasses_accept_tool_name_prefix() -> None:
assert MCPStdioTool(name="stdio", command="python", tool_name_prefix="stdio").tool_name_prefix == "stdio"
assert (
MCPStreamableHTTPTool(
name="http",
url="https://example.com/mcp",
tool_name_prefix="http",
).tool_name_prefix
== "http"
)
assert (
MCPWebsocketTool(
name="ws",
url="wss://example.com/mcp",
tool_name_prefix="ws",
).tool_name_prefix
== "ws"
)
async def test_load_tools_with_tool_name_prefix_preserves_matching_configuration():
"""Prefixed MCP tool names should still honor unprefixed allow/approval configuration."""
tool = MCPTool(
name="docs",
tool_name_prefix="docs",
allowed_tools=["search_docs"],
approval_mode={"always_require_approval": ["search_docs"]},
)
mock_session = AsyncMock()
tool.session = mock_session
tool.load_tools_flag = True
page = Mock()
page.tools = [
types.Tool(
name="search_docs",
description="Search docs",
inputSchema={"type": "object", "properties": {"query": {"type": "string"}}},
),
]
page.nextCursor = None
mock_session.list_tools = AsyncMock(return_value=page)
await tool.load_tools()
assert [function.name for function in tool._functions] == ["docs_search_docs"]
assert [function.name for function in tool.functions] == ["docs_search_docs"]
assert tool.functions[0].approval_mode == "always_require"
async def test_load_prompts_with_tool_name_prefix() -> None:
"""Prefixed MCP prompt names should be exposed with the configured prefix."""
tool = MCPTool(name="docs", tool_name_prefix="docs")
mock_session = AsyncMock()
tool.session = mock_session
tool.load_prompts_flag = True
page = Mock()
page.prompts = [
types.Prompt(
name="summarize docs",
description="Summarize docs",
arguments=[types.PromptArgument(name="topic", description="Topic", required=True)],
),
]
page.nextCursor = None
mock_session.list_prompts = AsyncMock(return_value=page)
await tool.load_prompts()
assert [function.name for function in tool._functions] == ["docs_summarize-docs"]
def test_mcp_prompt_message_to_ai_content():
"""Test conversion from MCP prompt message to AI content."""
mcp_message = types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hello, world!"))