diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py index 442138649a..585bcb5c3e 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py @@ -8,6 +8,7 @@ import logging from typing import TYPE_CHECKING, Any from agent_framework import BaseChatClient +from agent_framework._tools import _append_unique_tools # pyright: ignore[reportPrivateUsage] if TYPE_CHECKING: from agent_framework import SupportsAgentRun @@ -22,7 +23,7 @@ def _collect_mcp_tool_functions(mcp_tools: list[Any]) -> list[Any]: mcp_tools: List of MCP tool instances. Returns: - List of functions from connected MCP tools. + Functions from connected MCP tools. """ functions: list[Any] = [] for mcp_tool in mcp_tools: @@ -56,7 +57,11 @@ def collect_server_tools(agent: SupportsAgentRun) -> list[Any]: # Include functions from connected MCP tools (only available on Agent) mcp_tools = getattr(agent, "mcp_tools", None) if mcp_tools: - server_tools.extend(_collect_mcp_tool_functions(mcp_tools)) + _append_unique_tools( + server_tools, + _collect_mcp_tool_functions(mcp_tools), + duplicate_error_message="Tool names must be unique. Consider setting `tool_name_prefix` on the MCPTool.", + ) logger.info(f"[TOOLS] Agent has {len(server_tools)} configured tools") for tool in server_tools: @@ -109,26 +114,13 @@ def merge_tools(server_tools: list[Any], client_tools: list[Any] | None) -> list logger.info("[TOOLS] No client tools - not passing tools= parameter (using agent's configured tools)") return None - server_tool_names = {getattr(tool, "name", None) for tool in server_tools} - unique_client_tools = [tool for tool in client_tools if getattr(tool, "name", None) not in server_tool_names] - - if not unique_client_tools: - # Same check: must pass server tools if any require approval - if server_tools and _has_approval_tools(server_tools): - logger.info( - f"[TOOLS] Client tools duplicate server but server has approval tools - " - f"passing {len(server_tools)} server tools for approval mode" - ) - return server_tools - logger.info("[TOOLS] All client tools duplicate server tools - not passing tools= parameter") - return None - - combined_tools: list[Any] = [] - if server_tools: - combined_tools.extend(server_tools) - combined_tools.extend(unique_client_tools) + combined_tools = _append_unique_tools( + list(server_tools), + client_tools, + duplicate_error_message="Tool names must be unique.", + ) logger.info( f"[TOOLS] Passing tools= parameter with {len(combined_tools)} tools " - f"({len(server_tools)} server + {len(unique_client_tools)} unique client)" + f"({len(server_tools)} server + {len(client_tools)} client)" ) return combined_tools diff --git a/python/packages/ag-ui/tests/ag_ui/test_tooling.py b/python/packages/ag-ui/tests/ag_ui/test_tooling.py index e8567a586d..890ae44541 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_tooling.py +++ b/python/packages/ag-ui/tests/ag_ui/test_tooling.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock +import pytest from agent_framework import Agent, tool from agent_framework_ag_ui._orchestration._tooling import ( @@ -20,7 +21,8 @@ class DummyTool: class MockMCPTool: """Mock MCP tool that simulates connected MCP tool with functions.""" - def __init__(self, functions: list[DummyTool], is_connected: bool = True) -> None: + def __init__(self, functions: list[DummyTool], is_connected: bool = True, name: str = "mock-mcp") -> None: + self.name = name self.functions = functions self.is_connected = is_connected @@ -45,11 +47,8 @@ def test_merge_tools_filters_duplicates() -> None: server = [DummyTool("a"), DummyTool("b")] client = [DummyTool("b"), DummyTool("c")] - merged = merge_tools(server, client) - - assert merged is not None - names = [getattr(t, "name", None) for t in merged] - assert names == ["a", "b", "c"] + with pytest.raises(ValueError, match="Duplicate tool name 'b'"): + merge_tools(server, client) def test_register_additional_client_tools_assigns_when_configured() -> None: @@ -131,6 +130,17 @@ def test_collect_server_tools_with_mcp_tools_via_public_property() -> None: assert len(tools) == 2 +def test_collect_server_tools_raises_on_duplicate_agent_and_mcp_tool_names() -> None: + duplicate_tool = DummyTool("regular_tool") + mock_mcp = MockMCPTool([duplicate_tool], is_connected=True, name="docs-mcp") + + agent = _create_chat_agent_with_tool("regular_tool") + agent.mcp_tools = [mock_mcp] + + with pytest.raises(ValueError, match="Duplicate tool name 'regular_tool'"): + collect_server_tools(agent) + + # Additional tests for tooling coverage @@ -176,11 +186,11 @@ def test_merge_tools_no_client_tools() -> None: def test_merge_tools_all_duplicates() -> None: - """merge_tools returns None when all client tools duplicate server tools.""" + """merge_tools raises when client and server tools share a name.""" server = [DummyTool("a"), DummyTool("b")] client = [DummyTool("a"), DummyTool("b")] - result = merge_tools(server, client) - assert result is None + with pytest.raises(ValueError, match="Duplicate tool name 'a'"): + merge_tools(server, client) def test_merge_tools_empty_server() -> None: @@ -208,7 +218,7 @@ def test_merge_tools_with_approval_tools_no_client() -> None: def test_merge_tools_with_approval_tools_all_duplicates() -> None: - """merge_tools returns server tools with approval mode even when client duplicates.""" + """merge_tools raises even when a client tool duplicates an approval-gated server tool.""" class ApprovalTool: def __init__(self, name: str): @@ -217,7 +227,5 @@ def test_merge_tools_with_approval_tools_all_duplicates() -> None: server = [ApprovalTool("write_doc")] client = [DummyTool("write_doc")] # Same name as server - result = merge_tools(server, client) - assert result is not None - assert len(result) == 1 - assert result[0].approval_mode == "always_require" + with pytest.raises(ValueError, match="Duplicate tool name 'write_doc'"): + merge_tools(server, client) diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 8f4002e52e..2e6cca7dba 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -29,6 +29,7 @@ from mcp.server.lowlevel import Server from mcp.shared.exceptions import McpError from pydantic import BaseModel, Field, create_model +from . import _tools as _tool_utils # pyright: ignore[reportPrivateUsage] from ._clients import BaseChatClient, SupportsChatGetResponse from ._mcp import LOG_LEVEL_MAPPING, MCPTool from ._middleware import AgentMiddlewareLayer, MiddlewareTypes @@ -40,12 +41,7 @@ from ._sessions import ( InMemoryHistoryProvider, SessionContext, ) -from ._tools import ( - FunctionInvocationLayer, - FunctionTool, - ToolTypes, - normalize_tools, -) +from ._tools import FunctionInvocationLayer, FunctionTool, ToolTypes, normalize_tools from ._types import ( AgentResponse, AgentResponseUpdate, @@ -79,6 +75,9 @@ if TYPE_CHECKING: logger = logging.getLogger("agent_framework") +_append_unique_tools = _tool_utils._append_unique_tools # pyright: ignore[reportPrivateUsage] +_get_tool_name = _tool_utils._get_tool_name # pyright: ignore[reportPrivateUsage] + ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel) OptionsCoT = TypeVar( "OptionsCoT", @@ -88,19 +87,6 @@ OptionsCoT = TypeVar( ) -def _get_tool_name(tool: Any) -> str | None: - """Extract a tool's name from either an object with a .name attribute or a dict tool definition.""" - if isinstance(tool, Mapping): - tool_mapping = cast(Mapping[str, Any], tool) - func = tool_mapping.get("function") - if isinstance(func, Mapping): - func_mapping = cast(Mapping[str, Any], func) - name = func_mapping.get("name") - return name if isinstance(name, str) else None - return None - return getattr(tool, "name", None) - - def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: """Merge two options dicts, with override values taking precedence. @@ -115,11 +101,14 @@ def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str, for key, value in override.items(): if value is None: continue - if key == "tools" and result.get("tools"): - # Combine tool lists, avoiding duplicates by name - existing_names = {_get_tool_name(t) for t in result["tools"]} - {None} - unique_new = [t for t in value if _get_tool_name(t) not in existing_names] - result["tools"] = list(result["tools"]) + unique_new + if key == "tools" and (result.get("tools") or value): + base_tools = normalize_tools(result.get("tools")) + override_tools = normalize_tools(value) + result["tools"] = _append_unique_tools( + list(base_tools), + override_tools, + duplicate_error_message="Tool names must be unique.", + ) elif key == "logit_bias" and result.get("logit_bias"): # Merge logit_bias dicts result["logit_bias"] = {**result["logit_bias"], **value} @@ -1117,25 +1106,34 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc] ) agent_name = self._get_agent_name() + base_tools = normalize_tools(chat_options.pop("tools", None)) + mcp_duplicate_message = "Tool names must be unique. Consider setting `tool_name_prefix` on the MCPTool." # Normalize tools normalized_tools = normalize_tools(tools_) - # Resolve final tool list (runtime provided tools + local MCP server tools) - final_tools: list[FunctionTool | Callable[..., Any] | dict[str, Any] | Any] = [] + # Resolve final tool list (configured tools + runtime provided tools + local MCP server tools) + final_tools = list(base_tools) for tool in normalized_tools: if isinstance(tool, MCPTool): if not tool.is_connected: await self._async_exit_stack.enter_async_context(tool) - final_tools.extend(tool.functions) # type: ignore + _append_unique_tools( + final_tools, + tool.functions, + duplicate_error_message=mcp_duplicate_message, + ) else: - final_tools.append(tool) # type: ignore + _append_unique_tools(final_tools, [tool]) # type: ignore[list-item] - existing_names = {name for t in final_tools if (name := _get_tool_name(t)) is not None} for mcp_server in self.mcp_tools: if not mcp_server.is_connected: await self._async_exit_stack.enter_async_context(mcp_server) - final_tools.extend(f for f in mcp_server.functions if f.name not in existing_names) + _append_unique_tools( + final_tools, + mcp_server.functions, + duplicate_error_message=mcp_duplicate_message, + ) # Merge runtime kwargs into additional_function_arguments so they're available # in function middleware context and tool invocation. @@ -1164,7 +1162,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc] "store": opts.pop("store", None), "temperature": opts.pop("temperature", None), "tool_choice": opts.pop("tool_choice", None), - "tools": final_tools, + "tools": final_tools or None, "top_p": opts.pop("top_p", None), "user": opts.pop("user", None), **opts, # Remaining options are provider-specific diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index 83d896738d..28c5f6db6a 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -26,9 +26,7 @@ from mcp.shared.exceptions import McpError from mcp.shared.session import RequestResponder from opentelemetry import propagate -from ._tools import ( - FunctionTool, -) +from ._tools import FunctionTool from ._types import ( Content, Message, @@ -59,6 +57,8 @@ class MCPSpecificApproval(TypedDict, total=False): logger = logging.getLogger(__name__) +_MCP_REMOTE_NAME_KEY = "_mcp_remote_name" +_MCP_NORMALIZED_NAME_KEY = "_mcp_normalized_name" # region: Helpers @@ -372,6 +372,20 @@ def _normalize_mcp_name(name: str) -> str: return re.sub(r"[^A-Za-z0-9_.-]", "-", name) +def _build_prefixed_mcp_name( + normalized_name: str, + tool_name_prefix: str | None, +) -> str: + """Build the exposed MCP function name from a normalized name and optional prefix.""" + if not tool_name_prefix: + return normalized_name + normalized_prefix = _normalize_mcp_name(tool_name_prefix).rstrip("_.-") + if not normalized_prefix: + return normalized_name + trimmed_name = normalized_name.lstrip("_.-") + return f"{normalized_prefix}_{trimmed_name}" if trimmed_name else normalized_prefix + + def _inject_otel_into_mcp_meta(meta: dict[str, Any] | None = None) -> dict[str, Any] | None: """Inject OpenTelemetry trace context into MCP request _meta via the global propagator(s).""" carrier: dict[str, str] = {} @@ -415,6 +429,7 @@ class MCPTool: description: str | None = None, approval_mode: (Literal["always_require", "never_require"] | MCPSpecificApproval | None) = None, allowed_tools: Collection[str] | None = None, + tool_name_prefix: str | None = None, load_tools: bool = True, parse_tool_results: Callable[[types.CallToolResult], str | list[Content]] | None = None, load_prompts: bool = True, @@ -435,6 +450,7 @@ class MCPTool: description: A description of the MCP tool. approval_mode: Whether approval is required to run tools. allowed_tools: A collection of tool names to allow. + tool_name_prefix: Optional prefix to prepend to exposed MCP function names. load_tools: Whether to load tools from the MCP server. parse_tool_results: An optional callable with signature ``Callable[[types.CallToolResult], str]`` that overrides the default result @@ -458,6 +474,7 @@ class MCPTool: self.description = description or "" self.approval_mode = approval_mode self.allowed_tools = allowed_tools + self.tool_name_prefix = _normalize_mcp_name(tool_name_prefix).rstrip("_.-") if tool_name_prefix else None self.additional_properties = additional_properties self.load_tools_flag = load_tools self.parse_tool_results = parse_tool_results @@ -480,7 +497,19 @@ class MCPTool: """Get the list of functions that are allowed.""" if not self.allowed_tools: return self._functions - return [func for func in self._functions if func.name in self.allowed_tools] + allowed_names = set(self.allowed_tools) + filtered_functions: list[FunctionTool] = [] + for func in self._functions: + additional_properties = func.additional_properties or {} + normalized_name = additional_properties.get(_MCP_NORMALIZED_NAME_KEY) + remote_name = additional_properties.get(_MCP_REMOTE_NAME_KEY) + if ( + func.name in allowed_names + or (isinstance(normalized_name, str) and normalized_name in allowed_names) + or (isinstance(remote_name, str) and remote_name in allowed_names) + ): + filtered_functions.append(func) + return filtered_functions async def _safe_close_exit_stack(self) -> None: """Safely close the exit stack, handling cross-task boundary errors. @@ -706,12 +735,16 @@ class MCPTool: def _determine_approval_mode( self, - local_name: str, + *candidate_names: str, ) -> Literal["always_require", "never_require"] | None: if isinstance(self.approval_mode, dict): - if (always_require := self.approval_mode.get("always_require_approval")) and local_name in always_require: + if (always_require := self.approval_mode.get("always_require_approval")) and any( + name in always_require for name in candidate_names + ): return "always_require" - if (never_require := self.approval_mode.get("never_require_approval")) and local_name in never_require: + if (never_require := self.approval_mode.get("never_require_approval")) and any( + name in never_require for name in candidate_names + ): return "never_require" return None return self.approval_mode # type: ignore[reportReturnType] @@ -736,20 +769,25 @@ class MCPTool: prompt_list = await self.session.list_prompts(params=params) # type: ignore[union-attr] for prompt in prompt_list.prompts: - local_name = _normalize_mcp_name(prompt.name) + normalized_name = _normalize_mcp_name(prompt.name) + local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix) # Skip if already loaded if local_name in existing_names: continue input_model = _get_input_model_from_mcp_prompt(prompt) - approval_mode = self._determine_approval_mode(local_name) + approval_mode = self._determine_approval_mode(local_name, normalized_name, prompt.name) func: FunctionTool = FunctionTool( func=partial(self.get_prompt, prompt.name), name=local_name, description=prompt.description or "", approval_mode=approval_mode, input_model=input_model, + additional_properties={ + _MCP_REMOTE_NAME_KEY: prompt.name, + _MCP_NORMALIZED_NAME_KEY: normalized_name, + }, ) self._functions.append(func) existing_names.add(local_name) @@ -779,13 +817,14 @@ class MCPTool: tool_list = await self.session.list_tools(params=params) # type: ignore[union-attr] for tool in tool_list.tools: - local_name = _normalize_mcp_name(tool.name) + normalized_name = _normalize_mcp_name(tool.name) + local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix) # Skip if already loaded if local_name in existing_names: continue - approval_mode = self._determine_approval_mode(local_name) + approval_mode = self._determine_approval_mode(local_name, normalized_name, tool.name) # Create FunctionTools out of each tool func: FunctionTool = FunctionTool( func=partial(self.call_tool, tool.name), @@ -793,6 +832,10 @@ class MCPTool: description=tool.description or "", approval_mode=approval_mode, input_model=tool.inputSchema, + additional_properties={ + _MCP_REMOTE_NAME_KEY: tool.name, + _MCP_NORMALIZED_NAME_KEY: normalized_name, + }, ) self._functions.append(func) existing_names.add(local_name) @@ -1055,6 +1098,7 @@ class MCPStdioTool(MCPTool): name: str, command: str, *, + tool_name_prefix: str | None = None, load_tools: bool = True, parse_tool_results: Callable[[types.CallToolResult], str | list[Content]] | None = None, load_prompts: bool = True, @@ -1083,6 +1127,7 @@ class MCPStdioTool(MCPTool): command: The command to run the MCP server. Keyword Args: + tool_name_prefix: Optional prefix to prepend to exposed MCP function names. load_tools: Whether to load tools from the MCP server. parse_tool_results: An optional callable with signature ``Callable[[types.CallToolResult], str]`` that overrides the default result @@ -1119,6 +1164,7 @@ class MCPStdioTool(MCPTool): description=description, approval_mode=approval_mode, allowed_tools=allowed_tools, + tool_name_prefix=tool_name_prefix, additional_properties=additional_properties, session=session, client=client, @@ -1180,6 +1226,7 @@ class MCPStreamableHTTPTool(MCPTool): name: str, url: str, *, + tool_name_prefix: str | None = None, load_tools: bool = True, parse_tool_results: Callable[[types.CallToolResult], str | list[Content]] | None = None, load_prompts: bool = True, @@ -1208,6 +1255,7 @@ class MCPStreamableHTTPTool(MCPTool): url: The URL of the MCP server. Keyword Args: + tool_name_prefix: Optional prefix to prepend to exposed MCP function names. load_tools: Whether to load tools from the MCP server. parse_tool_results: An optional callable with signature ``Callable[[types.CallToolResult], str]`` that overrides the default result @@ -1246,6 +1294,7 @@ class MCPStreamableHTTPTool(MCPTool): description=description, approval_mode=approval_mode, allowed_tools=allowed_tools, + tool_name_prefix=tool_name_prefix, additional_properties=additional_properties, session=session, client=client, @@ -1299,6 +1348,7 @@ class MCPWebsocketTool(MCPTool): name: str, url: str, *, + tool_name_prefix: str | None = None, load_tools: bool = True, parse_tool_results: Callable[[types.CallToolResult], str | list[Content]] | None = None, load_prompts: bool = True, @@ -1325,6 +1375,7 @@ class MCPWebsocketTool(MCPTool): url: The URL of the MCP server. Keyword Args: + tool_name_prefix: Optional prefix to prepend to exposed MCP function names. load_tools: Whether to load tools from the MCP server. parse_tool_results: An optional callable with signature ``Callable[[types.CallToolResult], str]`` that overrides the default result @@ -1358,6 +1409,7 @@ class MCPWebsocketTool(MCPTool): description=description, approval_mode=approval_mode, allowed_tools=allowed_tools, + tool_name_prefix=tool_name_prefix, additional_properties=additional_properties, session=session, client=client, diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 090f382f1b..bfb2c7d2cb 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -71,7 +71,6 @@ if TYPE_CHECKING: ResponseStream, ) - ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel) else: MCPTool = Any # type: ignore[assignment,misc] @@ -83,9 +82,23 @@ DEFAULT_MAX_ITERATIONS: Final[int] = 40 DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST: Final[int] = 3 SHELL_TOOL_KIND_VALUE: Final[str] = "shell" ChatClientT = TypeVar("ChatClientT", bound="SupportsChatGetResponse[Any]") +ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel) + # region Helpers +def _get_tool_name(tool: Any) -> str | None: + """Extract a tool name from a tool object or dict tool definition.""" + if isinstance(tool, Mapping): + func = tool.get("function", None) # type: ignore + if func and isinstance(func, Mapping): + name = func.get("name") # type: ignore + return name if isinstance(name, str) else None + return None + name = getattr(tool, "name", None) + return name if isinstance(name, str) else None + + def _parse_inputs( # pyright: ignore[reportUnusedFunction] inputs: Content | dict[str, Any] | str | list[Content | dict[str, Any] | str] | None, ) -> list[Content]: @@ -701,6 +714,51 @@ class FunctionTool(SerializationMixin): ToolTypes: TypeAlias = FunctionTool | MCPTool | Mapping[str, Any] | object +def _raise_duplicate_tool_name(tool_name: str, duplicate_error_message: str | None = None) -> None: + message = duplicate_error_message or "Tool names must be unique." + raise ValueError(f"Duplicate tool name '{tool_name}'. {message}") + + +def _append_unique_tools( + existing_tools: list[ToolTypes], + new_tools: Sequence[ToolTypes], + *, + duplicate_error_message: str | None = None, +) -> list[ToolTypes]: + seen_by_name: dict[str, ToolTypes] = {} + for tool_item in existing_tools: + if tool_name := _get_tool_name(tool_item): + seen_by_name[tool_name] = tool_item + + for tool_item in new_tools: + tool_name = _get_tool_name(tool_item) + if tool_name is None: + existing_tools.append(tool_item) + continue + + existing_tool = seen_by_name.get(tool_name) + if existing_tool is None: + seen_by_name[tool_name] = tool_item + existing_tools.append(tool_item) + continue + + if existing_tool is tool_item: + continue + + _raise_duplicate_tool_name(tool_name, duplicate_error_message) + + return existing_tools + + +def _ensure_unique_tool_names( + tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]], + *, + duplicate_error_message: str | None = None, +) -> list[ToolTypes]: + normalized_tools = normalize_tools(tools) + return _append_unique_tools([], normalized_tools, duplicate_error_message=duplicate_error_message) + + def normalize_tools( tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None, ) -> list[ToolTypes]: @@ -1320,7 +1378,7 @@ def _get_tool_map( tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]], ) -> dict[str, FunctionTool]: tool_list: dict[str, FunctionTool] = {} - for tool_item in normalize_tools(tools): + for tool_item in _ensure_unique_tool_names(tools): if isinstance(tool_item, FunctionTool): tool_list[tool_item.name] = tool_item return tool_list diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index e666e374eb..32c098e51c 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -30,7 +30,7 @@ from agent_framework import ( tool, ) from agent_framework._agents import _get_tool_name, _merge_options, _sanitize_agent_name -from agent_framework._mcp import MCPTool +from agent_framework._mcp import MCPTool, _build_prefixed_mcp_name, _normalize_mcp_name class _FixedTokenizer: @@ -41,6 +41,30 @@ class _FixedTokenizer: return self.token_count +class _ConnectedMCPTool(MCPTool): + def __init__(self, name: str, function_names: list[str], *, tool_name_prefix: str | None = None) -> None: + super().__init__(name=name, tool_name_prefix=tool_name_prefix) + self.is_connected = True + self._functions = [] + for function_name in function_names: + normalized_name = _normalize_mcp_name(function_name) + exposed_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix) + self._functions.append( + FunctionTool( + func=lambda value=function_name: value, + name=exposed_name, + description=f"{function_name} from {name}", + additional_properties={ + "_mcp_remote_name": function_name, + "_mcp_normalized_name": normalized_name, + }, + ) + ) + + def get_mcp_client(self) -> contextlib.AbstractAsyncContextManager[Any]: + raise NotImplementedError + + def test_agent_session_type(agent_session: AgentSession) -> None: assert isinstance(agent_session, AgentSession) @@ -953,6 +977,7 @@ async def test_chat_agent_run_with_mcp_tools(client: SupportsChatGetResponse) -> # Create a mock MCP tool mock_mcp_tool = MagicMock(spec=MCPTool) + mock_mcp_tool.name = "mock-mcp" mock_mcp_tool.is_connected = False mock_mcp_tool.functions = [MagicMock()] @@ -970,6 +995,7 @@ async def test_chat_agent_with_local_mcp_tools(client: SupportsChatGetResponse) """Test agent initialization with local MCP tools.""" # Create a mock MCP tool mock_mcp_tool = MagicMock(spec=MCPTool) + mock_mcp_tool.name = "mock-mcp" mock_mcp_tool.is_connected = False mock_mcp_tool.__aenter__ = AsyncMock(return_value=mock_mcp_tool) mock_mcp_tool.__aexit__ = AsyncMock(return_value=None) @@ -1009,6 +1035,7 @@ async def test_mcp_tools_not_duplicated_when_passed_as_runtime_tools( # Create a mock MCP tool that is already connected (simulates turn 2) mock_mcp_tool = MagicMock(spec=MCPTool) + mock_mcp_tool.name = "mock-mcp" mock_mcp_tool.is_connected = True mock_mcp_tool.functions = [mcp_func_a, mcp_func_b] mock_mcp_tool.__aenter__ = AsyncMock(return_value=mock_mcp_tool) @@ -1032,6 +1059,77 @@ async def test_mcp_tools_not_duplicated_when_passed_as_runtime_tools( assert len(tool_names) == 3 +async def test_agent_run_raises_on_local_and_agent_mcp_name_conflict(chat_client_base: Any) -> None: + local_tool = FunctionTool( + func=lambda: "local", + name="delete_all_data", + description="Local protected tool", + approval_mode="always_require", + ) + agent = Agent( + client=chat_client_base, + name="TestAgent", + tools=[_ConnectedMCPTool(name="dangerous-mcp", function_names=["delete_all_data"])], + ) + + with raises(ValueError, match="tool_name_prefix"): + await agent.run("hello", tools=[local_tool]) + + +async def test_agent_run_raises_on_runtime_local_and_runtime_mcp_name_conflict(chat_client_base: Any) -> None: + local_tool = FunctionTool( + func=lambda: "local", + name="delete_all_data", + description="Local protected tool", + approval_mode="always_require", + ) + runtime_mcp = _ConnectedMCPTool(name="dangerous-mcp", function_names=["delete_all_data"]) + agent = Agent(client=chat_client_base, name="TestAgent") + + with raises(ValueError, match="tool_name_prefix"): + await agent.run("hello", tools=[local_tool, runtime_mcp]) + + +async def test_agent_run_raises_on_duplicate_agent_mcp_names(chat_client_base: Any) -> None: + agent = Agent( + client=chat_client_base, + name="TestAgent", + tools=[ + _ConnectedMCPTool(name="docs-mcp", function_names=["search"]), + _ConnectedMCPTool(name="github-mcp", function_names=["search"]), + ], + ) + + with raises(ValueError, match="tool_name_prefix"): + await agent.run("hello") + + +async def test_agent_run_accepts_prefixed_mcp_tools(chat_client_base: Any) -> None: + captured_options: list[dict[str, Any]] = [] + + original_inner = chat_client_base._inner_get_response + + async def capturing_inner( + *, messages: MutableSequence[Message], options: dict[str, Any], **kwargs: Any + ) -> ChatResponse: + captured_options.append(dict(options)) + return await original_inner(messages=messages, options=options, **kwargs) + + chat_client_base._inner_get_response = capturing_inner + + local_tool = FunctionTool(func=lambda: "local", name="search", description="Local search tool") + agent = Agent( + client=chat_client_base, + name="TestAgent", + tools=[_ConnectedMCPTool(name="docs-mcp", function_names=["search"], tool_name_prefix="docs")], + ) + + await agent.run("hello", tools=[local_tool]) + + tool_names = [tool.name for tool in captured_options[0]["tools"]] + assert tool_names == ["search", "docs_search"] + + async def test_agent_tool_receives_session_in_kwargs(chat_client_base: Any) -> None: """Verify tool execution receives 'session' inside **kwargs when function is called by client.""" @@ -1291,7 +1389,7 @@ def test_merge_options_none_values_ignored(): def test_merge_options_tools_combined(): - """Test _merge_options combines tool lists without duplicates.""" + """Test _merge_options raises when distinct tools share the same name.""" class MockTool: def __init__(self, name): @@ -1304,13 +1402,8 @@ def test_merge_options_tools_combined(): base = {"tools": [tool1]} override = {"tools": [tool2, tool3]} - result = _merge_options(base, override) - - # Should have tool1 and tool2, but not duplicate tool3 - assert len(result["tools"]) == 2 - tool_names = [t.name for t in result["tools"]] - assert "tool1" in tool_names - assert "tool2" in tool_names + with raises(ValueError, match="Duplicate tool name 'tool1'"): + _merge_options(base, override) def test_merge_options_dict_tools_combined(): @@ -1335,7 +1428,7 @@ def test_merge_options_dict_tools_combined(): def test_merge_options_dict_tools_deduplicates(): - """Test _merge_options deduplicates dict-defined tools by function name.""" + """Test _merge_options raises on duplicate dict-defined tool names.""" base = { "tools": [ {"type": "function", "function": {"name": "tool_a"}}, @@ -1348,12 +1441,8 @@ def test_merge_options_dict_tools_deduplicates(): ] } - result = _merge_options(base, override) - - assert len(result["tools"]) == 2 - names = [_get_tool_name(t) for t in result["tools"]] - assert names.count("tool_a") == 1 - assert "tool_b" in names + with raises(ValueError, match="Duplicate tool name 'tool_a'"): + _merge_options(base, override) def test_merge_options_mixed_tools_combined(): @@ -1379,7 +1468,7 @@ def test_merge_options_mixed_tools_combined(): def test_merge_options_mixed_tools_deduplicates(): - """Test _merge_options deduplicates when a dict tool and object tool share the same name.""" + """Test _merge_options raises when a dict tool and object tool share the same name.""" class MockTool: def __init__(self, name): @@ -1392,10 +1481,8 @@ def test_merge_options_mixed_tools_deduplicates(): ] } - result = _merge_options(base, override) - - assert len(result["tools"]) == 1 - assert _get_tool_name(result["tools"][0]) == "tool_a" + with raises(ValueError, match="Duplicate tool name 'tool_a'"): + _merge_options(base, override) def test_merge_options_nameless_tools_not_deduplicated(): @@ -1417,6 +1504,20 @@ def test_merge_options_nameless_tools_not_deduplicated(): assert len(result["tools"]) == 2 +def test_merge_options_same_tool_object_kept_once(): + """Test _merge_options silently keeps a repeated reference to the same tool object once.""" + + class MockTool: + def __init__(self, name): + self.name = name + + tool_a = MockTool("tool_a") + + result = _merge_options({"tools": [tool_a]}, {"tools": [tool_a]}) + + assert result["tools"] == [tool_a] + + def test_get_tool_name_dict_no_function_key(): """_get_tool_name returns None for a dict without a 'function' key.""" assert _get_tool_name({"type": "function"}) is None diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 139b860e21..df3187673a 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -53,6 +53,81 @@ def test_normalize_mcp_name(): assert _normalize_mcp_name("name/with\\slashes") == "name-with-slashes" +def test_mcp_transport_subclasses_accept_tool_name_prefix() -> None: + assert MCPStdioTool(name="stdio", command="python", tool_name_prefix="stdio").tool_name_prefix == "stdio" + assert ( + MCPStreamableHTTPTool( + name="http", + url="https://example.com/mcp", + tool_name_prefix="http", + ).tool_name_prefix + == "http" + ) + assert ( + MCPWebsocketTool( + name="ws", + url="wss://example.com/mcp", + tool_name_prefix="ws", + ).tool_name_prefix + == "ws" + ) + + +async def test_load_tools_with_tool_name_prefix_preserves_matching_configuration(): + """Prefixed MCP tool names should still honor unprefixed allow/approval configuration.""" + tool = MCPTool( + name="docs", + tool_name_prefix="docs", + allowed_tools=["search_docs"], + approval_mode={"always_require_approval": ["search_docs"]}, + ) + + mock_session = AsyncMock() + tool.session = mock_session + tool.load_tools_flag = True + + page = Mock() + page.tools = [ + types.Tool( + name="search_docs", + description="Search docs", + inputSchema={"type": "object", "properties": {"query": {"type": "string"}}}, + ), + ] + page.nextCursor = None + mock_session.list_tools = AsyncMock(return_value=page) + + await tool.load_tools() + + assert [function.name for function in tool._functions] == ["docs_search_docs"] + assert [function.name for function in tool.functions] == ["docs_search_docs"] + assert tool.functions[0].approval_mode == "always_require" + + +async def test_load_prompts_with_tool_name_prefix() -> None: + """Prefixed MCP prompt names should be exposed with the configured prefix.""" + tool = MCPTool(name="docs", tool_name_prefix="docs") + + mock_session = AsyncMock() + tool.session = mock_session + tool.load_prompts_flag = True + + page = Mock() + page.prompts = [ + types.Prompt( + name="summarize docs", + description="Summarize docs", + arguments=[types.PromptArgument(name="topic", description="Topic", required=True)], + ), + ] + page.nextCursor = None + mock_session.list_prompts = AsyncMock(return_value=page) + + await tool.load_prompts() + + assert [function.name for function in tool._functions] == ["docs_summarize-docs"] + + def test_mcp_prompt_message_to_ai_content(): """Test conversion from MCP prompt message to AI content.""" mcp_message = types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hello, world!")) diff --git a/python/samples/02-agents/skills/script_approval/script_approval.py b/python/samples/02-agents/skills/script_approval/script_approval.py index 701d88de06..b1613ef28f 100644 --- a/python/samples/02-agents/skills/script_approval/script_approval.py +++ b/python/samples/02-agents/skills/script_approval/script_approval.py @@ -90,7 +90,7 @@ async def main() -> None: # maintained automatically — just send the approval response) while result.user_input_requests: for request in result.user_input_requests: - print(f"\nApproval needed:") + print("\nApproval needed:") print(f" Function: {request.function_call.name}") # type: ignore[union-attr] print(f" Arguments: {request.function_call.arguments}") # type: ignore[union-attr]