From e8ff541ebf07339b12585a3d4984d78a13e7e1aa Mon Sep 17 00:00:00 2001 From: Eduard van Valkenburg Date: Fri, 29 May 2026 09:21:14 +0200 Subject: [PATCH] Python: consolidate MCP reliability fixes (#6145) * Python: consolidate MCP reliability fixes Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix MCP cleanup and metadata typing Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Satisfy MCP metadata mypy typing Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix Pyright metadata mapping type Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- python/packages/core/agent_framework/_mcp.py | 92 +++++++-- python/packages/core/tests/core/test_mcp.py | 206 +++++++++++++++++++ 2 files changed, 278 insertions(+), 20 deletions(-) diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index b2942de2a0..d872b2b92d 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -10,7 +10,7 @@ import logging import re import sys from abc import abstractmethod -from collections.abc import Callable, Collection, Coroutine, Sequence +from collections.abc import Callable, Collection, Coroutine, Mapping, Sequence from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore from datetime import timedelta from functools import partial @@ -142,6 +142,13 @@ def _inject_otel_into_mcp_meta(meta: dict[str, Any] | None = None) -> dict[str, return meta +def _url_origin(url: Any) -> tuple[str, str, int | None]: + port = url.port + if port is None: + port = 443 if url.scheme == "https" else 80 if url.scheme == "http" else None + return (url.scheme, url.host or "", port) + + def streamable_http_client(*args: Any, **kwargs: Any) -> _AsyncGeneratorContextManager[Any, None]: """Lazily import the MCP streamable HTTP transport.""" try: @@ -255,6 +262,7 @@ class MCPTool: self._exit_stack = AsyncExitStack() self._lifecycle_lock = asyncio.Lock() self._lifecycle_request_lock = asyncio.Lock() + self._function_load_lock = asyncio.Lock() self._lifecycle_queue: asyncio.Queue[tuple[str, bool, bool, asyncio.Future[None]]] | None = None self._lifecycle_owner_task: asyncio.Task[None] | None = None self.session = session @@ -655,6 +663,11 @@ class MCPTool: raise except asyncio.CancelledError: logger.warning("Could not cleanly close MCP exit stack because the lifecycle owner task was cancelled.") + except Exception as e: + if type(e).__name__ == "ExceptionGroup": + logger.warning("Could not cleanly close MCP exit stack due to cleanup error group. Error: %s", e) + else: + raise async def _close_and_check_cancelled(self, ex: BaseException) -> bool: """Close the exit stack and return True if *ex* is a genuine task cancellation. @@ -1018,6 +1031,10 @@ class MCPTool: Raises: ToolExecutionException: If the MCP server is not connected. """ + async with self._function_load_lock: + await self._load_prompts_locked() + + async def _load_prompts_locked(self) -> None: from anyio import ClosedResourceError from mcp import types @@ -1100,6 +1117,10 @@ class MCPTool: Raises: ToolExecutionException: If the MCP server is not connected. """ + async with self._function_load_lock: + await self._load_tools_locked() + + async def _load_tools_locked(self) -> None: from anyio import ClosedResourceError from mcp import types @@ -1109,7 +1130,7 @@ class MCPTool: # Track existing function names to prevent duplicates existing_names = {func.name for func in self._functions} - self._tool_call_meta_by_name.clear() + tool_call_meta_by_name: dict[str, dict[str, Any]] = {} params: types.PaginatedRequestParams | None = None while True: @@ -1145,7 +1166,7 @@ class MCPTool: for tool in tool_list.tools: if tool.meta is not None: - self._tool_call_meta_by_name[tool.name] = dict(tool.meta) + tool_call_meta_by_name[tool.name] = dict(tool.meta) normalized_name = _normalize_mcp_name(tool.name) local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix) @@ -1194,6 +1215,8 @@ class MCPTool: break params = types.PaginatedRequestParams(cursor=tool_list.nextCursor) + self._tool_call_meta_by_name = tool_call_meta_by_name + async def _close_on_owner(self) -> None: # Cancel any pending reload tasks before tearing down the session. tasks = list(self._pending_reload_tasks) @@ -1276,7 +1299,11 @@ class MCPTool: tool_name: The name of the tool to call. Keyword Args: - kwargs: Arguments to pass to the tool. + _meta: Optional ``dict[str, Any]`` of MCP request metadata. This reserved key is passed as the + ``meta`` parameter of the underlying ``session.call_tool`` call rather than as a tool argument. + User-supplied keys override metadata from ``tools/list``; OpenTelemetry propagation fills in + non-conflicting keys. + kwargs: Remaining arguments to pass to the tool. Returns: A list of Content items representing the tool output. The default @@ -1294,6 +1321,19 @@ class MCPTool: raise ToolExecutionException( "Tools are not loaded for this server, please set load_tools=True in the constructor." ) + + raw_user_meta: object | None = kwargs.get("_meta") + user_meta: dict[str, Any] | None = None + if raw_user_meta is not None and not isinstance(raw_user_meta, dict): + raise ToolExecutionException("MCP tool metadata provided via _meta must be a dict.") + if isinstance(raw_user_meta, dict): + raw_user_meta_dict = cast(Mapping[object, object], raw_user_meta) + user_meta = {} + for key, value in raw_user_meta_dict.items(): + if not isinstance(key, str): + raise ToolExecutionException("MCP tool metadata provided via _meta must use string keys.") + user_meta[key] = value + # Filter out framework kwargs that cannot be serialized by the MCP SDK. # These are internal objects passed through the function invocation pipeline # that should not be forwarded to external MCP servers. @@ -1313,12 +1353,16 @@ class MCPTool: "conversation_id", "options", "response_format", + "_meta", } } # Some MCP proxies require their tools/list metadata to be echoed on tools/call. tool_meta = self._tool_call_meta_by_name.get(tool_name) - meta = _inject_otel_into_mcp_meta(dict(tool_meta) if tool_meta is not None else None) + request_meta = dict(tool_meta) if tool_meta is not None else None + if user_meta is not None: + request_meta = {**(request_meta or {}), **user_meta} + meta = _inject_otel_into_mcp_meta(request_meta) parser = self.parse_tool_results or self._parse_tool_result_from_mcp # Try the operation, reconnecting once if the connection is closed @@ -1336,28 +1380,33 @@ class MCPTool: return parser(result) except ToolExecutionException: raise - except ClosedResourceError as cl_ex: + except (ClosedResourceError, McpError) as call_ex: + is_session_terminated = ( + isinstance(call_ex, McpError) and "session terminated" in call_ex.error.message.lower() + ) + is_connection_lost = isinstance(call_ex, ClosedResourceError) or is_session_terminated + if not is_connection_lost: + error_message = call_ex.error.message if isinstance(call_ex, McpError) else str(call_ex) + raise ToolExecutionException(error_message, inner_exception=call_ex) from call_ex + if attempt == 0: - # First attempt failed, try reconnecting - logger.info("MCP connection closed unexpectedly. Reconnecting...") + # First attempt failed, try reconnecting. + logger.info("MCP connection closed or terminated unexpectedly. Reconnecting...") try: await self.connect(reset=True) - continue # Retry the operation + continue except Exception as reconn_ex: raise ToolExecutionException( "Failed to reconnect to MCP server.", inner_exception=reconn_ex, ) from reconn_ex - else: - # Second attempt also failed, give up - logger.error(f"MCP connection closed unexpectedly after reconnection: {cl_ex}") - raise ToolExecutionException( - f"Failed to call tool '{tool_name}' - connection lost.", - inner_exception=cl_ex, - ) from cl_ex - except McpError as mcp_exc: - error_message = mcp_exc.error.message - raise ToolExecutionException(error_message, inner_exception=mcp_exc) from mcp_exc + + # Second attempt also failed, give up. + logger.error("MCP connection closed unexpectedly after reconnection: %s", call_ex) + raise ToolExecutionException( + f"Failed to call tool '{tool_name}' - connection lost.", + inner_exception=call_ex, + ) from call_ex except Exception as ex: raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex raise ToolExecutionException(f"Failed to call tool '{tool_name}' after retries.") @@ -1718,10 +1767,11 @@ class MCPStreamableHTTPTool(MCPTool): Returns: An async context manager for the streamable HTTP client transport. """ - from httpx import AsyncClient, Request, Timeout + from httpx import URL, AsyncClient, Request, Timeout http_client = self._httpx_client if self._header_provider is not None: + target_origin = _url_origin(URL(self.url)) if http_client is None: http_client = AsyncClient( follow_redirects=True, @@ -1732,6 +1782,8 @@ class MCPStreamableHTTPTool(MCPTool): if not hasattr(self, "_inject_headers_hook"): async def _inject_headers(request: Request) -> None: # noqa: RUF029 + if _url_origin(request.url) != target_origin: + return headers = _mcp_call_headers.get({}) for key, value in headers.items(): request.headers[key] = value diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 6273eb76e6..519d8e5db3 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -1161,6 +1161,43 @@ async def test_local_mcp_server_function_execution_error(): await func.invoke(param="test_value") +async def test_mcp_tool_reconnects_after_session_terminated_error(): + """Session termination errors should reconnect once and retry the tool call.""" + + class TestServer(MCPTool): + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.connect_count = 0 + self.sessions: list[Any] = [] + + async def connect(self, *, reset: bool = False) -> None: + self.connect_count += 1 + self.session = Mock(spec=ClientSession) + self.sessions.append(self.session) + if self.connect_count == 1: + self.session.call_tool = AsyncMock( + side_effect=McpError(types.ErrorData(code=-32000, message="Session terminated")) + ) + else: + self.session.call_tool = AsyncMock( + return_value=types.CallToolResult(content=[types.TextContent(type="text", text="recovered")]) + ) + self.is_connected = True + + def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: + return None + + server = TestServer(name="test_server") + await server.connect() + + result = await server.call_tool("test_tool", param="test_value") + + assert _mcp_result_to_text(result) == "recovered" + assert server.connect_count == 2 + assert server.sessions[0].call_tool.await_count == 1 + assert server.sessions[1].call_tool.await_count == 1 + + async def test_mcp_tool_call_tool_raises_on_is_error(): """Test that call_tool raises ToolExecutionException when MCP returns isError=True.""" @@ -3260,6 +3297,68 @@ async def test_load_prompts_pagination_with_duplicates(): assert [f.name for f in tool._functions] == ["prompt_1", "prompt_2"] +async def test_load_tools_concurrent_reload_does_not_duplicate_tools_and_preserves_meta(): + """Concurrent tool reloads should not duplicate functions or lose tools/list metadata.""" + tool = MCPTool(name="test_tool") + mock_session = AsyncMock() + tool.session = mock_session + tool.load_tools_flag = True + + page = Mock() + page.tools = [ + types.Tool( + name="tool_1", + description="First tool", + inputSchema={"type": "object", "properties": {"param": {"type": "string"}}}, + _meta={"echo": "tool_1"}, + ), + ] + page.nextCursor = None + + async def mock_list_tools(params: Any = None) -> Any: + assert params is None + await asyncio.sleep(0) + return page + + mock_session.list_tools = AsyncMock(side_effect=mock_list_tools) + + await asyncio.wait_for(asyncio.gather(tool.load_tools(), tool.load_tools()), timeout=1) + + assert mock_session.list_tools.call_count == 2 + assert [f.name for f in tool._functions] == ["tool_1"] + assert tool._tool_call_meta_by_name == {"tool_1": {"echo": "tool_1"}} + + +async def test_load_prompts_concurrent_reload_does_not_duplicate_prompts(): + """Concurrent prompt reloads should not duplicate functions.""" + tool = MCPTool(name="test_tool") + mock_session = AsyncMock() + tool.session = mock_session + tool.load_prompts_flag = True + + page = Mock() + page.prompts = [ + types.Prompt( + name="prompt_1", + description="First prompt", + arguments=[types.PromptArgument(name="arg1", description="Arg 1", required=True)], + ), + ] + page.nextCursor = None + + async def mock_list_prompts(params: Any = None) -> Any: + assert params is None + await asyncio.sleep(0) + return page + + mock_session.list_prompts = AsyncMock(side_effect=mock_list_prompts) + + await asyncio.wait_for(asyncio.gather(tool.load_prompts(), tool.load_prompts()), timeout=1) + + assert mock_session.list_prompts.call_count == 2 + assert [f.name for f in tool._functions] == ["prompt_1"] + + async def test_load_tools_pagination_exception_handling(): """Test that load_tools handles exceptions during pagination gracefully.""" from unittest.mock import AsyncMock @@ -3891,6 +3990,31 @@ async def test_mcp_tool_safe_close_handles_cancelled_error(): mock_exit_stack.aclose.assert_called_once() +async def test_mcp_tool_safe_close_handles_cleanup_exception_group(): + """Cleanup task groups should not hide the original connect failure.""" + import builtins + from contextlib import AsyncExitStack + + exception_group_type = getattr(builtins, "ExceptionGroup", None) + if exception_group_type is None: + pytest.skip("ExceptionGroup is not available on this Python version") + + tool = MCPStreamableHTTPTool( + name="test", + url="http://example.com/mcp", + load_tools=False, + load_prompts=False, + ) + + mock_exit_stack = AsyncMock(spec=AsyncExitStack) + mock_exit_stack.aclose = AsyncMock(side_effect=exception_group_type("cleanup failed", [RuntimeError("reader")])) + tool._exit_stack = mock_exit_stack + + await tool._safe_close_exit_stack() + + mock_exit_stack.aclose.assert_called_once() + + async def test_connect_sets_logging_level_when_logger_level_is_set(): """Test that connect() sets the MCP server logging level when the logger level is not NOTSET.""" @@ -4389,6 +4513,52 @@ async def test_mcp_tool_call_tool_forwards_tool_list_meta(): assert server.session.call_tool.call_args.kwargs["meta"] == tool_meta +async def test_mcp_tool_call_tool_user_meta_merges_with_tool_list_meta(): + """User-provided _meta should be sent as MCP request metadata, not tool arguments.""" + from opentelemetry import trace + + tool_meta = {"from_tool": "tool-value", "shared": "tool-value"} + user_meta = {"from_user": "user-value", "shared": "user-value"} + + class TestServer(MCPTool): + async def connect(self) -> None: + self.session = Mock(spec=ClientSession) + self.session.list_tools = AsyncMock( + return_value=types.ListToolsResult( + tools=[ + types.Tool( + name="test_tool", + description="Test tool", + inputSchema={"type": "object", "properties": {"param": {"type": "string"}}}, + _meta=tool_meta, + ) + ] + ) + ) + self.session.call_tool = AsyncMock( + return_value=types.CallToolResult(content=[types.TextContent(type="text", text="result")]) + ) + + def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: + return None + + server = TestServer(name="test_server") + async with server: + await server.load_tools() + + with trace.use_span(trace.NonRecordingSpan(trace.INVALID_SPAN_CONTEXT)): + await server.call_tool("test_tool", param="test_value", _meta=user_meta) + + call_kwargs = server.session.call_tool.call_args.kwargs + assert call_kwargs["arguments"] == {"param": "test_value"} + assert call_kwargs["meta"] == { + "from_tool": "tool-value", + "from_user": "user-value", + "shared": "user-value", + } + assert user_meta == {"from_user": "user-value", "shared": "user-value"} + + async def test_mcp_streamable_http_tool_hook_not_duplicated_on_repeated_get_mcp_client(): """Test that calling get_mcp_client multiple times does not accumulate duplicate hooks.""" tool = MCPStreamableHTTPTool( @@ -4641,6 +4811,42 @@ async def test_mcp_streamable_http_tool_header_provider_with_httpx_event_hook(): await tool._httpx_client.aclose() +async def test_mcp_streamable_http_tool_header_provider_skips_cross_origin_redirect(): + """The request hook must not re-add caller headers after a cross-origin redirect.""" + import httpx + + from agent_framework._mcp import _mcp_call_headers + + tool = MCPStreamableHTTPTool( + name="test", + url="http://example.com/mcp", + header_provider=lambda kw: {"Authorization": f"Bearer {kw.get('token', '')}"}, + ) + + try: + with patch("agent_framework._mcp.streamable_http_client"): + tool.get_mcp_client() + + assert tool._httpx_client is not None + hooks = tool._httpx_client.event_hooks.get("request", []) + assert len(hooks) == 1 + + token = _mcp_call_headers.set({"Authorization": "Bearer secret"}) + try: + same_origin = httpx.Request("POST", "http://example.com/redirected") + await hooks[0](same_origin) + assert same_origin.headers.get("Authorization") == "Bearer secret" + + cross_origin = httpx.Request("POST", "http://attacker.example/capture") + await hooks[0](cross_origin) + assert "Authorization" not in cross_origin.headers + finally: + _mcp_call_headers.reset(token) + finally: + if getattr(tool, "_httpx_client", None) is not None: + await tool._httpx_client.aclose() + + async def test_mcp_streamable_http_tool_header_provider_with_user_httpx_client(): """Test that header_provider works when the user provides their own httpx client.""" import httpx