mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
fix duplicate names between supplied tools and mcp servers (#4649)
This commit is contained in:
committed by
GitHub
Unverified
parent
84bae0f42a
commit
b7990908fe
@@ -8,6 +8,7 @@ import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from agent_framework import BaseChatClient
|
||||
from agent_framework._tools import _append_unique_tools # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agent_framework import SupportsAgentRun
|
||||
@@ -22,7 +23,7 @@ def _collect_mcp_tool_functions(mcp_tools: list[Any]) -> list[Any]:
|
||||
mcp_tools: List of MCP tool instances.
|
||||
|
||||
Returns:
|
||||
List of functions from connected MCP tools.
|
||||
Functions from connected MCP tools.
|
||||
"""
|
||||
functions: list[Any] = []
|
||||
for mcp_tool in mcp_tools:
|
||||
@@ -56,7 +57,11 @@ def collect_server_tools(agent: SupportsAgentRun) -> list[Any]:
|
||||
# Include functions from connected MCP tools (only available on Agent)
|
||||
mcp_tools = getattr(agent, "mcp_tools", None)
|
||||
if mcp_tools:
|
||||
server_tools.extend(_collect_mcp_tool_functions(mcp_tools))
|
||||
_append_unique_tools(
|
||||
server_tools,
|
||||
_collect_mcp_tool_functions(mcp_tools),
|
||||
duplicate_error_message="Tool names must be unique. Consider setting `tool_name_prefix` on the MCPTool.",
|
||||
)
|
||||
|
||||
logger.info(f"[TOOLS] Agent has {len(server_tools)} configured tools")
|
||||
for tool in server_tools:
|
||||
@@ -109,26 +114,13 @@ def merge_tools(server_tools: list[Any], client_tools: list[Any] | None) -> list
|
||||
logger.info("[TOOLS] No client tools - not passing tools= parameter (using agent's configured tools)")
|
||||
return None
|
||||
|
||||
server_tool_names = {getattr(tool, "name", None) for tool in server_tools}
|
||||
unique_client_tools = [tool for tool in client_tools if getattr(tool, "name", None) not in server_tool_names]
|
||||
|
||||
if not unique_client_tools:
|
||||
# Same check: must pass server tools if any require approval
|
||||
if server_tools and _has_approval_tools(server_tools):
|
||||
logger.info(
|
||||
f"[TOOLS] Client tools duplicate server but server has approval tools - "
|
||||
f"passing {len(server_tools)} server tools for approval mode"
|
||||
)
|
||||
return server_tools
|
||||
logger.info("[TOOLS] All client tools duplicate server tools - not passing tools= parameter")
|
||||
return None
|
||||
|
||||
combined_tools: list[Any] = []
|
||||
if server_tools:
|
||||
combined_tools.extend(server_tools)
|
||||
combined_tools.extend(unique_client_tools)
|
||||
combined_tools = _append_unique_tools(
|
||||
list(server_tools),
|
||||
client_tools,
|
||||
duplicate_error_message="Tool names must be unique.",
|
||||
)
|
||||
logger.info(
|
||||
f"[TOOLS] Passing tools= parameter with {len(combined_tools)} tools "
|
||||
f"({len(server_tools)} server + {len(unique_client_tools)} unique client)"
|
||||
f"({len(server_tools)} server + {len(client_tools)} client)"
|
||||
)
|
||||
return combined_tools
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from agent_framework import Agent, tool
|
||||
|
||||
from agent_framework_ag_ui._orchestration._tooling import (
|
||||
@@ -20,7 +21,8 @@ class DummyTool:
|
||||
class MockMCPTool:
|
||||
"""Mock MCP tool that simulates connected MCP tool with functions."""
|
||||
|
||||
def __init__(self, functions: list[DummyTool], is_connected: bool = True) -> None:
|
||||
def __init__(self, functions: list[DummyTool], is_connected: bool = True, name: str = "mock-mcp") -> None:
|
||||
self.name = name
|
||||
self.functions = functions
|
||||
self.is_connected = is_connected
|
||||
|
||||
@@ -45,11 +47,8 @@ def test_merge_tools_filters_duplicates() -> None:
|
||||
server = [DummyTool("a"), DummyTool("b")]
|
||||
client = [DummyTool("b"), DummyTool("c")]
|
||||
|
||||
merged = merge_tools(server, client)
|
||||
|
||||
assert merged is not None
|
||||
names = [getattr(t, "name", None) for t in merged]
|
||||
assert names == ["a", "b", "c"]
|
||||
with pytest.raises(ValueError, match="Duplicate tool name 'b'"):
|
||||
merge_tools(server, client)
|
||||
|
||||
|
||||
def test_register_additional_client_tools_assigns_when_configured() -> None:
|
||||
@@ -131,6 +130,17 @@ def test_collect_server_tools_with_mcp_tools_via_public_property() -> None:
|
||||
assert len(tools) == 2
|
||||
|
||||
|
||||
def test_collect_server_tools_raises_on_duplicate_agent_and_mcp_tool_names() -> None:
|
||||
duplicate_tool = DummyTool("regular_tool")
|
||||
mock_mcp = MockMCPTool([duplicate_tool], is_connected=True, name="docs-mcp")
|
||||
|
||||
agent = _create_chat_agent_with_tool("regular_tool")
|
||||
agent.mcp_tools = [mock_mcp]
|
||||
|
||||
with pytest.raises(ValueError, match="Duplicate tool name 'regular_tool'"):
|
||||
collect_server_tools(agent)
|
||||
|
||||
|
||||
# Additional tests for tooling coverage
|
||||
|
||||
|
||||
@@ -176,11 +186,11 @@ def test_merge_tools_no_client_tools() -> None:
|
||||
|
||||
|
||||
def test_merge_tools_all_duplicates() -> None:
|
||||
"""merge_tools returns None when all client tools duplicate server tools."""
|
||||
"""merge_tools raises when client and server tools share a name."""
|
||||
server = [DummyTool("a"), DummyTool("b")]
|
||||
client = [DummyTool("a"), DummyTool("b")]
|
||||
result = merge_tools(server, client)
|
||||
assert result is None
|
||||
with pytest.raises(ValueError, match="Duplicate tool name 'a'"):
|
||||
merge_tools(server, client)
|
||||
|
||||
|
||||
def test_merge_tools_empty_server() -> None:
|
||||
@@ -208,7 +218,7 @@ def test_merge_tools_with_approval_tools_no_client() -> None:
|
||||
|
||||
|
||||
def test_merge_tools_with_approval_tools_all_duplicates() -> None:
|
||||
"""merge_tools returns server tools with approval mode even when client duplicates."""
|
||||
"""merge_tools raises even when a client tool duplicates an approval-gated server tool."""
|
||||
|
||||
class ApprovalTool:
|
||||
def __init__(self, name: str):
|
||||
@@ -217,7 +227,5 @@ def test_merge_tools_with_approval_tools_all_duplicates() -> None:
|
||||
|
||||
server = [ApprovalTool("write_doc")]
|
||||
client = [DummyTool("write_doc")] # Same name as server
|
||||
result = merge_tools(server, client)
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
assert result[0].approval_mode == "always_require"
|
||||
with pytest.raises(ValueError, match="Duplicate tool name 'write_doc'"):
|
||||
merge_tools(server, client)
|
||||
|
||||
@@ -29,6 +29,7 @@ from mcp.server.lowlevel import Server
|
||||
from mcp.shared.exceptions import McpError
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from . import _tools as _tool_utils # pyright: ignore[reportPrivateUsage]
|
||||
from ._clients import BaseChatClient, SupportsChatGetResponse
|
||||
from ._mcp import LOG_LEVEL_MAPPING, MCPTool
|
||||
from ._middleware import AgentMiddlewareLayer, MiddlewareTypes
|
||||
@@ -40,12 +41,7 @@ from ._sessions import (
|
||||
InMemoryHistoryProvider,
|
||||
SessionContext,
|
||||
)
|
||||
from ._tools import (
|
||||
FunctionInvocationLayer,
|
||||
FunctionTool,
|
||||
ToolTypes,
|
||||
normalize_tools,
|
||||
)
|
||||
from ._tools import FunctionInvocationLayer, FunctionTool, ToolTypes, normalize_tools
|
||||
from ._types import (
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
@@ -79,6 +75,9 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger("agent_framework")
|
||||
|
||||
_append_unique_tools = _tool_utils._append_unique_tools # pyright: ignore[reportPrivateUsage]
|
||||
_get_tool_name = _tool_utils._get_tool_name # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
|
||||
OptionsCoT = TypeVar(
|
||||
"OptionsCoT",
|
||||
@@ -88,19 +87,6 @@ OptionsCoT = TypeVar(
|
||||
)
|
||||
|
||||
|
||||
def _get_tool_name(tool: Any) -> str | None:
|
||||
"""Extract a tool's name from either an object with a .name attribute or a dict tool definition."""
|
||||
if isinstance(tool, Mapping):
|
||||
tool_mapping = cast(Mapping[str, Any], tool)
|
||||
func = tool_mapping.get("function")
|
||||
if isinstance(func, Mapping):
|
||||
func_mapping = cast(Mapping[str, Any], func)
|
||||
name = func_mapping.get("name")
|
||||
return name if isinstance(name, str) else None
|
||||
return None
|
||||
return getattr(tool, "name", None)
|
||||
|
||||
|
||||
def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Merge two options dicts, with override values taking precedence.
|
||||
|
||||
@@ -115,11 +101,14 @@ def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str,
|
||||
for key, value in override.items():
|
||||
if value is None:
|
||||
continue
|
||||
if key == "tools" and result.get("tools"):
|
||||
# Combine tool lists, avoiding duplicates by name
|
||||
existing_names = {_get_tool_name(t) for t in result["tools"]} - {None}
|
||||
unique_new = [t for t in value if _get_tool_name(t) not in existing_names]
|
||||
result["tools"] = list(result["tools"]) + unique_new
|
||||
if key == "tools" and (result.get("tools") or value):
|
||||
base_tools = normalize_tools(result.get("tools"))
|
||||
override_tools = normalize_tools(value)
|
||||
result["tools"] = _append_unique_tools(
|
||||
list(base_tools),
|
||||
override_tools,
|
||||
duplicate_error_message="Tool names must be unique.",
|
||||
)
|
||||
elif key == "logit_bias" and result.get("logit_bias"):
|
||||
# Merge logit_bias dicts
|
||||
result["logit_bias"] = {**result["logit_bias"], **value}
|
||||
@@ -1117,25 +1106,34 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
)
|
||||
|
||||
agent_name = self._get_agent_name()
|
||||
base_tools = normalize_tools(chat_options.pop("tools", None))
|
||||
mcp_duplicate_message = "Tool names must be unique. Consider setting `tool_name_prefix` on the MCPTool."
|
||||
|
||||
# Normalize tools
|
||||
normalized_tools = normalize_tools(tools_)
|
||||
|
||||
# Resolve final tool list (runtime provided tools + local MCP server tools)
|
||||
final_tools: list[FunctionTool | Callable[..., Any] | dict[str, Any] | Any] = []
|
||||
# Resolve final tool list (configured tools + runtime provided tools + local MCP server tools)
|
||||
final_tools = list(base_tools)
|
||||
for tool in normalized_tools:
|
||||
if isinstance(tool, MCPTool):
|
||||
if not tool.is_connected:
|
||||
await self._async_exit_stack.enter_async_context(tool)
|
||||
final_tools.extend(tool.functions) # type: ignore
|
||||
_append_unique_tools(
|
||||
final_tools,
|
||||
tool.functions,
|
||||
duplicate_error_message=mcp_duplicate_message,
|
||||
)
|
||||
else:
|
||||
final_tools.append(tool) # type: ignore
|
||||
_append_unique_tools(final_tools, [tool]) # type: ignore[list-item]
|
||||
|
||||
existing_names = {name for t in final_tools if (name := _get_tool_name(t)) is not None}
|
||||
for mcp_server in self.mcp_tools:
|
||||
if not mcp_server.is_connected:
|
||||
await self._async_exit_stack.enter_async_context(mcp_server)
|
||||
final_tools.extend(f for f in mcp_server.functions if f.name not in existing_names)
|
||||
_append_unique_tools(
|
||||
final_tools,
|
||||
mcp_server.functions,
|
||||
duplicate_error_message=mcp_duplicate_message,
|
||||
)
|
||||
|
||||
# Merge runtime kwargs into additional_function_arguments so they're available
|
||||
# in function middleware context and tool invocation.
|
||||
@@ -1164,7 +1162,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
"store": opts.pop("store", None),
|
||||
"temperature": opts.pop("temperature", None),
|
||||
"tool_choice": opts.pop("tool_choice", None),
|
||||
"tools": final_tools,
|
||||
"tools": final_tools or None,
|
||||
"top_p": opts.pop("top_p", None),
|
||||
"user": opts.pop("user", None),
|
||||
**opts, # Remaining options are provider-specific
|
||||
|
||||
@@ -26,9 +26,7 @@ from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.session import RequestResponder
|
||||
from opentelemetry import propagate
|
||||
|
||||
from ._tools import (
|
||||
FunctionTool,
|
||||
)
|
||||
from ._tools import FunctionTool
|
||||
from ._types import (
|
||||
Content,
|
||||
Message,
|
||||
@@ -59,6 +57,8 @@ class MCPSpecificApproval(TypedDict, total=False):
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_MCP_REMOTE_NAME_KEY = "_mcp_remote_name"
|
||||
_MCP_NORMALIZED_NAME_KEY = "_mcp_normalized_name"
|
||||
|
||||
# region: Helpers
|
||||
|
||||
@@ -372,6 +372,20 @@ def _normalize_mcp_name(name: str) -> str:
|
||||
return re.sub(r"[^A-Za-z0-9_.-]", "-", name)
|
||||
|
||||
|
||||
def _build_prefixed_mcp_name(
|
||||
normalized_name: str,
|
||||
tool_name_prefix: str | None,
|
||||
) -> str:
|
||||
"""Build the exposed MCP function name from a normalized name and optional prefix."""
|
||||
if not tool_name_prefix:
|
||||
return normalized_name
|
||||
normalized_prefix = _normalize_mcp_name(tool_name_prefix).rstrip("_.-")
|
||||
if not normalized_prefix:
|
||||
return normalized_name
|
||||
trimmed_name = normalized_name.lstrip("_.-")
|
||||
return f"{normalized_prefix}_{trimmed_name}" if trimmed_name else normalized_prefix
|
||||
|
||||
|
||||
def _inject_otel_into_mcp_meta(meta: dict[str, Any] | None = None) -> dict[str, Any] | None:
|
||||
"""Inject OpenTelemetry trace context into MCP request _meta via the global propagator(s)."""
|
||||
carrier: dict[str, str] = {}
|
||||
@@ -415,6 +429,7 @@ class MCPTool:
|
||||
description: str | None = None,
|
||||
approval_mode: (Literal["always_require", "never_require"] | MCPSpecificApproval | None) = None,
|
||||
allowed_tools: Collection[str] | None = None,
|
||||
tool_name_prefix: str | None = None,
|
||||
load_tools: bool = True,
|
||||
parse_tool_results: Callable[[types.CallToolResult], str | list[Content]] | None = None,
|
||||
load_prompts: bool = True,
|
||||
@@ -435,6 +450,7 @@ class MCPTool:
|
||||
description: A description of the MCP tool.
|
||||
approval_mode: Whether approval is required to run tools.
|
||||
allowed_tools: A collection of tool names to allow.
|
||||
tool_name_prefix: Optional prefix to prepend to exposed MCP function names.
|
||||
load_tools: Whether to load tools from the MCP server.
|
||||
parse_tool_results: An optional callable with signature
|
||||
``Callable[[types.CallToolResult], str]`` that overrides the default result
|
||||
@@ -458,6 +474,7 @@ class MCPTool:
|
||||
self.description = description or ""
|
||||
self.approval_mode = approval_mode
|
||||
self.allowed_tools = allowed_tools
|
||||
self.tool_name_prefix = _normalize_mcp_name(tool_name_prefix).rstrip("_.-") if tool_name_prefix else None
|
||||
self.additional_properties = additional_properties
|
||||
self.load_tools_flag = load_tools
|
||||
self.parse_tool_results = parse_tool_results
|
||||
@@ -480,7 +497,19 @@ class MCPTool:
|
||||
"""Get the list of functions that are allowed."""
|
||||
if not self.allowed_tools:
|
||||
return self._functions
|
||||
return [func for func in self._functions if func.name in self.allowed_tools]
|
||||
allowed_names = set(self.allowed_tools)
|
||||
filtered_functions: list[FunctionTool] = []
|
||||
for func in self._functions:
|
||||
additional_properties = func.additional_properties or {}
|
||||
normalized_name = additional_properties.get(_MCP_NORMALIZED_NAME_KEY)
|
||||
remote_name = additional_properties.get(_MCP_REMOTE_NAME_KEY)
|
||||
if (
|
||||
func.name in allowed_names
|
||||
or (isinstance(normalized_name, str) and normalized_name in allowed_names)
|
||||
or (isinstance(remote_name, str) and remote_name in allowed_names)
|
||||
):
|
||||
filtered_functions.append(func)
|
||||
return filtered_functions
|
||||
|
||||
async def _safe_close_exit_stack(self) -> None:
|
||||
"""Safely close the exit stack, handling cross-task boundary errors.
|
||||
@@ -706,12 +735,16 @@ class MCPTool:
|
||||
|
||||
def _determine_approval_mode(
|
||||
self,
|
||||
local_name: str,
|
||||
*candidate_names: str,
|
||||
) -> Literal["always_require", "never_require"] | None:
|
||||
if isinstance(self.approval_mode, dict):
|
||||
if (always_require := self.approval_mode.get("always_require_approval")) and local_name in always_require:
|
||||
if (always_require := self.approval_mode.get("always_require_approval")) and any(
|
||||
name in always_require for name in candidate_names
|
||||
):
|
||||
return "always_require"
|
||||
if (never_require := self.approval_mode.get("never_require_approval")) and local_name in never_require:
|
||||
if (never_require := self.approval_mode.get("never_require_approval")) and any(
|
||||
name in never_require for name in candidate_names
|
||||
):
|
||||
return "never_require"
|
||||
return None
|
||||
return self.approval_mode # type: ignore[reportReturnType]
|
||||
@@ -736,20 +769,25 @@ class MCPTool:
|
||||
prompt_list = await self.session.list_prompts(params=params) # type: ignore[union-attr]
|
||||
|
||||
for prompt in prompt_list.prompts:
|
||||
local_name = _normalize_mcp_name(prompt.name)
|
||||
normalized_name = _normalize_mcp_name(prompt.name)
|
||||
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
|
||||
|
||||
# Skip if already loaded
|
||||
if local_name in existing_names:
|
||||
continue
|
||||
|
||||
input_model = _get_input_model_from_mcp_prompt(prompt)
|
||||
approval_mode = self._determine_approval_mode(local_name)
|
||||
approval_mode = self._determine_approval_mode(local_name, normalized_name, prompt.name)
|
||||
func: FunctionTool = FunctionTool(
|
||||
func=partial(self.get_prompt, prompt.name),
|
||||
name=local_name,
|
||||
description=prompt.description or "",
|
||||
approval_mode=approval_mode,
|
||||
input_model=input_model,
|
||||
additional_properties={
|
||||
_MCP_REMOTE_NAME_KEY: prompt.name,
|
||||
_MCP_NORMALIZED_NAME_KEY: normalized_name,
|
||||
},
|
||||
)
|
||||
self._functions.append(func)
|
||||
existing_names.add(local_name)
|
||||
@@ -779,13 +817,14 @@ class MCPTool:
|
||||
tool_list = await self.session.list_tools(params=params) # type: ignore[union-attr]
|
||||
|
||||
for tool in tool_list.tools:
|
||||
local_name = _normalize_mcp_name(tool.name)
|
||||
normalized_name = _normalize_mcp_name(tool.name)
|
||||
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
|
||||
|
||||
# Skip if already loaded
|
||||
if local_name in existing_names:
|
||||
continue
|
||||
|
||||
approval_mode = self._determine_approval_mode(local_name)
|
||||
approval_mode = self._determine_approval_mode(local_name, normalized_name, tool.name)
|
||||
# Create FunctionTools out of each tool
|
||||
func: FunctionTool = FunctionTool(
|
||||
func=partial(self.call_tool, tool.name),
|
||||
@@ -793,6 +832,10 @@ class MCPTool:
|
||||
description=tool.description or "",
|
||||
approval_mode=approval_mode,
|
||||
input_model=tool.inputSchema,
|
||||
additional_properties={
|
||||
_MCP_REMOTE_NAME_KEY: tool.name,
|
||||
_MCP_NORMALIZED_NAME_KEY: normalized_name,
|
||||
},
|
||||
)
|
||||
self._functions.append(func)
|
||||
existing_names.add(local_name)
|
||||
@@ -1055,6 +1098,7 @@ class MCPStdioTool(MCPTool):
|
||||
name: str,
|
||||
command: str,
|
||||
*,
|
||||
tool_name_prefix: str | None = None,
|
||||
load_tools: bool = True,
|
||||
parse_tool_results: Callable[[types.CallToolResult], str | list[Content]] | None = None,
|
||||
load_prompts: bool = True,
|
||||
@@ -1083,6 +1127,7 @@ class MCPStdioTool(MCPTool):
|
||||
command: The command to run the MCP server.
|
||||
|
||||
Keyword Args:
|
||||
tool_name_prefix: Optional prefix to prepend to exposed MCP function names.
|
||||
load_tools: Whether to load tools from the MCP server.
|
||||
parse_tool_results: An optional callable with signature
|
||||
``Callable[[types.CallToolResult], str]`` that overrides the default result
|
||||
@@ -1119,6 +1164,7 @@ class MCPStdioTool(MCPTool):
|
||||
description=description,
|
||||
approval_mode=approval_mode,
|
||||
allowed_tools=allowed_tools,
|
||||
tool_name_prefix=tool_name_prefix,
|
||||
additional_properties=additional_properties,
|
||||
session=session,
|
||||
client=client,
|
||||
@@ -1180,6 +1226,7 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
name: str,
|
||||
url: str,
|
||||
*,
|
||||
tool_name_prefix: str | None = None,
|
||||
load_tools: bool = True,
|
||||
parse_tool_results: Callable[[types.CallToolResult], str | list[Content]] | None = None,
|
||||
load_prompts: bool = True,
|
||||
@@ -1208,6 +1255,7 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
url: The URL of the MCP server.
|
||||
|
||||
Keyword Args:
|
||||
tool_name_prefix: Optional prefix to prepend to exposed MCP function names.
|
||||
load_tools: Whether to load tools from the MCP server.
|
||||
parse_tool_results: An optional callable with signature
|
||||
``Callable[[types.CallToolResult], str]`` that overrides the default result
|
||||
@@ -1246,6 +1294,7 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
description=description,
|
||||
approval_mode=approval_mode,
|
||||
allowed_tools=allowed_tools,
|
||||
tool_name_prefix=tool_name_prefix,
|
||||
additional_properties=additional_properties,
|
||||
session=session,
|
||||
client=client,
|
||||
@@ -1299,6 +1348,7 @@ class MCPWebsocketTool(MCPTool):
|
||||
name: str,
|
||||
url: str,
|
||||
*,
|
||||
tool_name_prefix: str | None = None,
|
||||
load_tools: bool = True,
|
||||
parse_tool_results: Callable[[types.CallToolResult], str | list[Content]] | None = None,
|
||||
load_prompts: bool = True,
|
||||
@@ -1325,6 +1375,7 @@ class MCPWebsocketTool(MCPTool):
|
||||
url: The URL of the MCP server.
|
||||
|
||||
Keyword Args:
|
||||
tool_name_prefix: Optional prefix to prepend to exposed MCP function names.
|
||||
load_tools: Whether to load tools from the MCP server.
|
||||
parse_tool_results: An optional callable with signature
|
||||
``Callable[[types.CallToolResult], str]`` that overrides the default result
|
||||
@@ -1358,6 +1409,7 @@ class MCPWebsocketTool(MCPTool):
|
||||
description=description,
|
||||
approval_mode=approval_mode,
|
||||
allowed_tools=allowed_tools,
|
||||
tool_name_prefix=tool_name_prefix,
|
||||
additional_properties=additional_properties,
|
||||
session=session,
|
||||
client=client,
|
||||
|
||||
@@ -71,7 +71,6 @@ if TYPE_CHECKING:
|
||||
ResponseStream,
|
||||
)
|
||||
|
||||
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
|
||||
else:
|
||||
MCPTool = Any # type: ignore[assignment,misc]
|
||||
|
||||
@@ -83,9 +82,23 @@ DEFAULT_MAX_ITERATIONS: Final[int] = 40
|
||||
DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST: Final[int] = 3
|
||||
SHELL_TOOL_KIND_VALUE: Final[str] = "shell"
|
||||
ChatClientT = TypeVar("ChatClientT", bound="SupportsChatGetResponse[Any]")
|
||||
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
|
||||
|
||||
# region Helpers
|
||||
|
||||
|
||||
def _get_tool_name(tool: Any) -> str | None:
|
||||
"""Extract a tool name from a tool object or dict tool definition."""
|
||||
if isinstance(tool, Mapping):
|
||||
func = tool.get("function", None) # type: ignore
|
||||
if func and isinstance(func, Mapping):
|
||||
name = func.get("name") # type: ignore
|
||||
return name if isinstance(name, str) else None
|
||||
return None
|
||||
name = getattr(tool, "name", None)
|
||||
return name if isinstance(name, str) else None
|
||||
|
||||
|
||||
def _parse_inputs( # pyright: ignore[reportUnusedFunction]
|
||||
inputs: Content | dict[str, Any] | str | list[Content | dict[str, Any] | str] | None,
|
||||
) -> list[Content]:
|
||||
@@ -701,6 +714,51 @@ class FunctionTool(SerializationMixin):
|
||||
ToolTypes: TypeAlias = FunctionTool | MCPTool | Mapping[str, Any] | object
|
||||
|
||||
|
||||
def _raise_duplicate_tool_name(tool_name: str, duplicate_error_message: str | None = None) -> None:
|
||||
message = duplicate_error_message or "Tool names must be unique."
|
||||
raise ValueError(f"Duplicate tool name '{tool_name}'. {message}")
|
||||
|
||||
|
||||
def _append_unique_tools(
|
||||
existing_tools: list[ToolTypes],
|
||||
new_tools: Sequence[ToolTypes],
|
||||
*,
|
||||
duplicate_error_message: str | None = None,
|
||||
) -> list[ToolTypes]:
|
||||
seen_by_name: dict[str, ToolTypes] = {}
|
||||
for tool_item in existing_tools:
|
||||
if tool_name := _get_tool_name(tool_item):
|
||||
seen_by_name[tool_name] = tool_item
|
||||
|
||||
for tool_item in new_tools:
|
||||
tool_name = _get_tool_name(tool_item)
|
||||
if tool_name is None:
|
||||
existing_tools.append(tool_item)
|
||||
continue
|
||||
|
||||
existing_tool = seen_by_name.get(tool_name)
|
||||
if existing_tool is None:
|
||||
seen_by_name[tool_name] = tool_item
|
||||
existing_tools.append(tool_item)
|
||||
continue
|
||||
|
||||
if existing_tool is tool_item:
|
||||
continue
|
||||
|
||||
_raise_duplicate_tool_name(tool_name, duplicate_error_message)
|
||||
|
||||
return existing_tools
|
||||
|
||||
|
||||
def _ensure_unique_tool_names(
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]],
|
||||
*,
|
||||
duplicate_error_message: str | None = None,
|
||||
) -> list[ToolTypes]:
|
||||
normalized_tools = normalize_tools(tools)
|
||||
return _append_unique_tools([], normalized_tools, duplicate_error_message=duplicate_error_message)
|
||||
|
||||
|
||||
def normalize_tools(
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None,
|
||||
) -> list[ToolTypes]:
|
||||
@@ -1320,7 +1378,7 @@ def _get_tool_map(
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]],
|
||||
) -> dict[str, FunctionTool]:
|
||||
tool_list: dict[str, FunctionTool] = {}
|
||||
for tool_item in normalize_tools(tools):
|
||||
for tool_item in _ensure_unique_tool_names(tools):
|
||||
if isinstance(tool_item, FunctionTool):
|
||||
tool_list[tool_item.name] = tool_item
|
||||
return tool_list
|
||||
|
||||
@@ -30,7 +30,7 @@ from agent_framework import (
|
||||
tool,
|
||||
)
|
||||
from agent_framework._agents import _get_tool_name, _merge_options, _sanitize_agent_name
|
||||
from agent_framework._mcp import MCPTool
|
||||
from agent_framework._mcp import MCPTool, _build_prefixed_mcp_name, _normalize_mcp_name
|
||||
|
||||
|
||||
class _FixedTokenizer:
|
||||
@@ -41,6 +41,30 @@ class _FixedTokenizer:
|
||||
return self.token_count
|
||||
|
||||
|
||||
class _ConnectedMCPTool(MCPTool):
|
||||
def __init__(self, name: str, function_names: list[str], *, tool_name_prefix: str | None = None) -> None:
|
||||
super().__init__(name=name, tool_name_prefix=tool_name_prefix)
|
||||
self.is_connected = True
|
||||
self._functions = []
|
||||
for function_name in function_names:
|
||||
normalized_name = _normalize_mcp_name(function_name)
|
||||
exposed_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
|
||||
self._functions.append(
|
||||
FunctionTool(
|
||||
func=lambda value=function_name: value,
|
||||
name=exposed_name,
|
||||
description=f"{function_name} from {name}",
|
||||
additional_properties={
|
||||
"_mcp_remote_name": function_name,
|
||||
"_mcp_normalized_name": normalized_name,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
def get_mcp_client(self) -> contextlib.AbstractAsyncContextManager[Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test_agent_session_type(agent_session: AgentSession) -> None:
|
||||
assert isinstance(agent_session, AgentSession)
|
||||
|
||||
@@ -953,6 +977,7 @@ async def test_chat_agent_run_with_mcp_tools(client: SupportsChatGetResponse) ->
|
||||
|
||||
# Create a mock MCP tool
|
||||
mock_mcp_tool = MagicMock(spec=MCPTool)
|
||||
mock_mcp_tool.name = "mock-mcp"
|
||||
mock_mcp_tool.is_connected = False
|
||||
mock_mcp_tool.functions = [MagicMock()]
|
||||
|
||||
@@ -970,6 +995,7 @@ async def test_chat_agent_with_local_mcp_tools(client: SupportsChatGetResponse)
|
||||
"""Test agent initialization with local MCP tools."""
|
||||
# Create a mock MCP tool
|
||||
mock_mcp_tool = MagicMock(spec=MCPTool)
|
||||
mock_mcp_tool.name = "mock-mcp"
|
||||
mock_mcp_tool.is_connected = False
|
||||
mock_mcp_tool.__aenter__ = AsyncMock(return_value=mock_mcp_tool)
|
||||
mock_mcp_tool.__aexit__ = AsyncMock(return_value=None)
|
||||
@@ -1009,6 +1035,7 @@ async def test_mcp_tools_not_duplicated_when_passed_as_runtime_tools(
|
||||
|
||||
# Create a mock MCP tool that is already connected (simulates turn 2)
|
||||
mock_mcp_tool = MagicMock(spec=MCPTool)
|
||||
mock_mcp_tool.name = "mock-mcp"
|
||||
mock_mcp_tool.is_connected = True
|
||||
mock_mcp_tool.functions = [mcp_func_a, mcp_func_b]
|
||||
mock_mcp_tool.__aenter__ = AsyncMock(return_value=mock_mcp_tool)
|
||||
@@ -1032,6 +1059,77 @@ async def test_mcp_tools_not_duplicated_when_passed_as_runtime_tools(
|
||||
assert len(tool_names) == 3
|
||||
|
||||
|
||||
async def test_agent_run_raises_on_local_and_agent_mcp_name_conflict(chat_client_base: Any) -> None:
|
||||
local_tool = FunctionTool(
|
||||
func=lambda: "local",
|
||||
name="delete_all_data",
|
||||
description="Local protected tool",
|
||||
approval_mode="always_require",
|
||||
)
|
||||
agent = Agent(
|
||||
client=chat_client_base,
|
||||
name="TestAgent",
|
||||
tools=[_ConnectedMCPTool(name="dangerous-mcp", function_names=["delete_all_data"])],
|
||||
)
|
||||
|
||||
with raises(ValueError, match="tool_name_prefix"):
|
||||
await agent.run("hello", tools=[local_tool])
|
||||
|
||||
|
||||
async def test_agent_run_raises_on_runtime_local_and_runtime_mcp_name_conflict(chat_client_base: Any) -> None:
|
||||
local_tool = FunctionTool(
|
||||
func=lambda: "local",
|
||||
name="delete_all_data",
|
||||
description="Local protected tool",
|
||||
approval_mode="always_require",
|
||||
)
|
||||
runtime_mcp = _ConnectedMCPTool(name="dangerous-mcp", function_names=["delete_all_data"])
|
||||
agent = Agent(client=chat_client_base, name="TestAgent")
|
||||
|
||||
with raises(ValueError, match="tool_name_prefix"):
|
||||
await agent.run("hello", tools=[local_tool, runtime_mcp])
|
||||
|
||||
|
||||
async def test_agent_run_raises_on_duplicate_agent_mcp_names(chat_client_base: Any) -> None:
|
||||
agent = Agent(
|
||||
client=chat_client_base,
|
||||
name="TestAgent",
|
||||
tools=[
|
||||
_ConnectedMCPTool(name="docs-mcp", function_names=["search"]),
|
||||
_ConnectedMCPTool(name="github-mcp", function_names=["search"]),
|
||||
],
|
||||
)
|
||||
|
||||
with raises(ValueError, match="tool_name_prefix"):
|
||||
await agent.run("hello")
|
||||
|
||||
|
||||
async def test_agent_run_accepts_prefixed_mcp_tools(chat_client_base: Any) -> None:
|
||||
captured_options: list[dict[str, Any]] = []
|
||||
|
||||
original_inner = chat_client_base._inner_get_response
|
||||
|
||||
async def capturing_inner(
|
||||
*, messages: MutableSequence[Message], options: dict[str, Any], **kwargs: Any
|
||||
) -> ChatResponse:
|
||||
captured_options.append(dict(options))
|
||||
return await original_inner(messages=messages, options=options, **kwargs)
|
||||
|
||||
chat_client_base._inner_get_response = capturing_inner
|
||||
|
||||
local_tool = FunctionTool(func=lambda: "local", name="search", description="Local search tool")
|
||||
agent = Agent(
|
||||
client=chat_client_base,
|
||||
name="TestAgent",
|
||||
tools=[_ConnectedMCPTool(name="docs-mcp", function_names=["search"], tool_name_prefix="docs")],
|
||||
)
|
||||
|
||||
await agent.run("hello", tools=[local_tool])
|
||||
|
||||
tool_names = [tool.name for tool in captured_options[0]["tools"]]
|
||||
assert tool_names == ["search", "docs_search"]
|
||||
|
||||
|
||||
async def test_agent_tool_receives_session_in_kwargs(chat_client_base: Any) -> None:
|
||||
"""Verify tool execution receives 'session' inside **kwargs when function is called by client."""
|
||||
|
||||
@@ -1291,7 +1389,7 @@ def test_merge_options_none_values_ignored():
|
||||
|
||||
|
||||
def test_merge_options_tools_combined():
|
||||
"""Test _merge_options combines tool lists without duplicates."""
|
||||
"""Test _merge_options raises when distinct tools share the same name."""
|
||||
|
||||
class MockTool:
|
||||
def __init__(self, name):
|
||||
@@ -1304,13 +1402,8 @@ def test_merge_options_tools_combined():
|
||||
base = {"tools": [tool1]}
|
||||
override = {"tools": [tool2, tool3]}
|
||||
|
||||
result = _merge_options(base, override)
|
||||
|
||||
# Should have tool1 and tool2, but not duplicate tool3
|
||||
assert len(result["tools"]) == 2
|
||||
tool_names = [t.name for t in result["tools"]]
|
||||
assert "tool1" in tool_names
|
||||
assert "tool2" in tool_names
|
||||
with raises(ValueError, match="Duplicate tool name 'tool1'"):
|
||||
_merge_options(base, override)
|
||||
|
||||
|
||||
def test_merge_options_dict_tools_combined():
|
||||
@@ -1335,7 +1428,7 @@ def test_merge_options_dict_tools_combined():
|
||||
|
||||
|
||||
def test_merge_options_dict_tools_deduplicates():
|
||||
"""Test _merge_options deduplicates dict-defined tools by function name."""
|
||||
"""Test _merge_options raises on duplicate dict-defined tool names."""
|
||||
base = {
|
||||
"tools": [
|
||||
{"type": "function", "function": {"name": "tool_a"}},
|
||||
@@ -1348,12 +1441,8 @@ def test_merge_options_dict_tools_deduplicates():
|
||||
]
|
||||
}
|
||||
|
||||
result = _merge_options(base, override)
|
||||
|
||||
assert len(result["tools"]) == 2
|
||||
names = [_get_tool_name(t) for t in result["tools"]]
|
||||
assert names.count("tool_a") == 1
|
||||
assert "tool_b" in names
|
||||
with raises(ValueError, match="Duplicate tool name 'tool_a'"):
|
||||
_merge_options(base, override)
|
||||
|
||||
|
||||
def test_merge_options_mixed_tools_combined():
|
||||
@@ -1379,7 +1468,7 @@ def test_merge_options_mixed_tools_combined():
|
||||
|
||||
|
||||
def test_merge_options_mixed_tools_deduplicates():
|
||||
"""Test _merge_options deduplicates when a dict tool and object tool share the same name."""
|
||||
"""Test _merge_options raises when a dict tool and object tool share the same name."""
|
||||
|
||||
class MockTool:
|
||||
def __init__(self, name):
|
||||
@@ -1392,10 +1481,8 @@ def test_merge_options_mixed_tools_deduplicates():
|
||||
]
|
||||
}
|
||||
|
||||
result = _merge_options(base, override)
|
||||
|
||||
assert len(result["tools"]) == 1
|
||||
assert _get_tool_name(result["tools"][0]) == "tool_a"
|
||||
with raises(ValueError, match="Duplicate tool name 'tool_a'"):
|
||||
_merge_options(base, override)
|
||||
|
||||
|
||||
def test_merge_options_nameless_tools_not_deduplicated():
|
||||
@@ -1417,6 +1504,20 @@ def test_merge_options_nameless_tools_not_deduplicated():
|
||||
assert len(result["tools"]) == 2
|
||||
|
||||
|
||||
def test_merge_options_same_tool_object_kept_once():
|
||||
"""Test _merge_options silently keeps a repeated reference to the same tool object once."""
|
||||
|
||||
class MockTool:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
tool_a = MockTool("tool_a")
|
||||
|
||||
result = _merge_options({"tools": [tool_a]}, {"tools": [tool_a]})
|
||||
|
||||
assert result["tools"] == [tool_a]
|
||||
|
||||
|
||||
def test_get_tool_name_dict_no_function_key():
|
||||
"""_get_tool_name returns None for a dict without a 'function' key."""
|
||||
assert _get_tool_name({"type": "function"}) is None
|
||||
|
||||
@@ -53,6 +53,81 @@ def test_normalize_mcp_name():
|
||||
assert _normalize_mcp_name("name/with\\slashes") == "name-with-slashes"
|
||||
|
||||
|
||||
def test_mcp_transport_subclasses_accept_tool_name_prefix() -> None:
|
||||
assert MCPStdioTool(name="stdio", command="python", tool_name_prefix="stdio").tool_name_prefix == "stdio"
|
||||
assert (
|
||||
MCPStreamableHTTPTool(
|
||||
name="http",
|
||||
url="https://example.com/mcp",
|
||||
tool_name_prefix="http",
|
||||
).tool_name_prefix
|
||||
== "http"
|
||||
)
|
||||
assert (
|
||||
MCPWebsocketTool(
|
||||
name="ws",
|
||||
url="wss://example.com/mcp",
|
||||
tool_name_prefix="ws",
|
||||
).tool_name_prefix
|
||||
== "ws"
|
||||
)
|
||||
|
||||
|
||||
async def test_load_tools_with_tool_name_prefix_preserves_matching_configuration():
|
||||
"""Prefixed MCP tool names should still honor unprefixed allow/approval configuration."""
|
||||
tool = MCPTool(
|
||||
name="docs",
|
||||
tool_name_prefix="docs",
|
||||
allowed_tools=["search_docs"],
|
||||
approval_mode={"always_require_approval": ["search_docs"]},
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
tool.session = mock_session
|
||||
tool.load_tools_flag = True
|
||||
|
||||
page = Mock()
|
||||
page.tools = [
|
||||
types.Tool(
|
||||
name="search_docs",
|
||||
description="Search docs",
|
||||
inputSchema={"type": "object", "properties": {"query": {"type": "string"}}},
|
||||
),
|
||||
]
|
||||
page.nextCursor = None
|
||||
mock_session.list_tools = AsyncMock(return_value=page)
|
||||
|
||||
await tool.load_tools()
|
||||
|
||||
assert [function.name for function in tool._functions] == ["docs_search_docs"]
|
||||
assert [function.name for function in tool.functions] == ["docs_search_docs"]
|
||||
assert tool.functions[0].approval_mode == "always_require"
|
||||
|
||||
|
||||
async def test_load_prompts_with_tool_name_prefix() -> None:
|
||||
"""Prefixed MCP prompt names should be exposed with the configured prefix."""
|
||||
tool = MCPTool(name="docs", tool_name_prefix="docs")
|
||||
|
||||
mock_session = AsyncMock()
|
||||
tool.session = mock_session
|
||||
tool.load_prompts_flag = True
|
||||
|
||||
page = Mock()
|
||||
page.prompts = [
|
||||
types.Prompt(
|
||||
name="summarize docs",
|
||||
description="Summarize docs",
|
||||
arguments=[types.PromptArgument(name="topic", description="Topic", required=True)],
|
||||
),
|
||||
]
|
||||
page.nextCursor = None
|
||||
mock_session.list_prompts = AsyncMock(return_value=page)
|
||||
|
||||
await tool.load_prompts()
|
||||
|
||||
assert [function.name for function in tool._functions] == ["docs_summarize-docs"]
|
||||
|
||||
|
||||
def test_mcp_prompt_message_to_ai_content():
|
||||
"""Test conversion from MCP prompt message to AI content."""
|
||||
mcp_message = types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hello, world!"))
|
||||
|
||||
Reference in New Issue
Block a user