Python: consolidate MCP reliability fixes (#6145)

* Python: consolidate MCP reliability fixes

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix MCP cleanup and metadata typing

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Satisfy MCP metadata mypy typing

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix Pyright metadata mapping type

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Eduard van Valkenburg
2026-05-29 09:21:14 +02:00
committed by GitHub
Unverified
parent d2d5384f28
commit e8ff541ebf
2 changed files with 278 additions and 20 deletions
+72 -20
View File
@@ -10,7 +10,7 @@ import logging
import re
import sys
from abc import abstractmethod
from collections.abc import Callable, Collection, Coroutine, Sequence
from collections.abc import Callable, Collection, Coroutine, Mapping, Sequence
from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore
from datetime import timedelta
from functools import partial
@@ -142,6 +142,13 @@ def _inject_otel_into_mcp_meta(meta: dict[str, Any] | None = None) -> dict[str,
return meta
def _url_origin(url: Any) -> tuple[str, str, int | None]:
port = url.port
if port is None:
port = 443 if url.scheme == "https" else 80 if url.scheme == "http" else None
return (url.scheme, url.host or "", port)
def streamable_http_client(*args: Any, **kwargs: Any) -> _AsyncGeneratorContextManager[Any, None]:
"""Lazily import the MCP streamable HTTP transport."""
try:
@@ -255,6 +262,7 @@ class MCPTool:
self._exit_stack = AsyncExitStack()
self._lifecycle_lock = asyncio.Lock()
self._lifecycle_request_lock = asyncio.Lock()
self._function_load_lock = asyncio.Lock()
self._lifecycle_queue: asyncio.Queue[tuple[str, bool, bool, asyncio.Future[None]]] | None = None
self._lifecycle_owner_task: asyncio.Task[None] | None = None
self.session = session
@@ -655,6 +663,11 @@ class MCPTool:
raise
except asyncio.CancelledError:
logger.warning("Could not cleanly close MCP exit stack because the lifecycle owner task was cancelled.")
except Exception as e:
if type(e).__name__ == "ExceptionGroup":
logger.warning("Could not cleanly close MCP exit stack due to cleanup error group. Error: %s", e)
else:
raise
async def _close_and_check_cancelled(self, ex: BaseException) -> bool:
"""Close the exit stack and return True if *ex* is a genuine task cancellation.
@@ -1018,6 +1031,10 @@ class MCPTool:
Raises:
ToolExecutionException: If the MCP server is not connected.
"""
async with self._function_load_lock:
await self._load_prompts_locked()
async def _load_prompts_locked(self) -> None:
from anyio import ClosedResourceError
from mcp import types
@@ -1100,6 +1117,10 @@ class MCPTool:
Raises:
ToolExecutionException: If the MCP server is not connected.
"""
async with self._function_load_lock:
await self._load_tools_locked()
async def _load_tools_locked(self) -> None:
from anyio import ClosedResourceError
from mcp import types
@@ -1109,7 +1130,7 @@ class MCPTool:
# Track existing function names to prevent duplicates
existing_names = {func.name for func in self._functions}
self._tool_call_meta_by_name.clear()
tool_call_meta_by_name: dict[str, dict[str, Any]] = {}
params: types.PaginatedRequestParams | None = None
while True:
@@ -1145,7 +1166,7 @@ class MCPTool:
for tool in tool_list.tools:
if tool.meta is not None:
self._tool_call_meta_by_name[tool.name] = dict(tool.meta)
tool_call_meta_by_name[tool.name] = dict(tool.meta)
normalized_name = _normalize_mcp_name(tool.name)
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
@@ -1194,6 +1215,8 @@ class MCPTool:
break
params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)
self._tool_call_meta_by_name = tool_call_meta_by_name
async def _close_on_owner(self) -> None:
# Cancel any pending reload tasks before tearing down the session.
tasks = list(self._pending_reload_tasks)
@@ -1276,7 +1299,11 @@ class MCPTool:
tool_name: The name of the tool to call.
Keyword Args:
kwargs: Arguments to pass to the tool.
_meta: Optional ``dict[str, Any]`` of MCP request metadata. This reserved key is passed as the
``meta`` parameter of the underlying ``session.call_tool`` call rather than as a tool argument.
User-supplied keys override metadata from ``tools/list``; OpenTelemetry propagation fills in
non-conflicting keys.
kwargs: Remaining arguments to pass to the tool.
Returns:
A list of Content items representing the tool output. The default
@@ -1294,6 +1321,19 @@ class MCPTool:
raise ToolExecutionException(
"Tools are not loaded for this server, please set load_tools=True in the constructor."
)
raw_user_meta: object | None = kwargs.get("_meta")
user_meta: dict[str, Any] | None = None
if raw_user_meta is not None and not isinstance(raw_user_meta, dict):
raise ToolExecutionException("MCP tool metadata provided via _meta must be a dict.")
if isinstance(raw_user_meta, dict):
raw_user_meta_dict = cast(Mapping[object, object], raw_user_meta)
user_meta = {}
for key, value in raw_user_meta_dict.items():
if not isinstance(key, str):
raise ToolExecutionException("MCP tool metadata provided via _meta must use string keys.")
user_meta[key] = value
# Filter out framework kwargs that cannot be serialized by the MCP SDK.
# These are internal objects passed through the function invocation pipeline
# that should not be forwarded to external MCP servers.
@@ -1313,12 +1353,16 @@ class MCPTool:
"conversation_id",
"options",
"response_format",
"_meta",
}
}
# Some MCP proxies require their tools/list metadata to be echoed on tools/call.
tool_meta = self._tool_call_meta_by_name.get(tool_name)
meta = _inject_otel_into_mcp_meta(dict(tool_meta) if tool_meta is not None else None)
request_meta = dict(tool_meta) if tool_meta is not None else None
if user_meta is not None:
request_meta = {**(request_meta or {}), **user_meta}
meta = _inject_otel_into_mcp_meta(request_meta)
parser = self.parse_tool_results or self._parse_tool_result_from_mcp
# Try the operation, reconnecting once if the connection is closed
@@ -1336,28 +1380,33 @@ class MCPTool:
return parser(result)
except ToolExecutionException:
raise
except ClosedResourceError as cl_ex:
except (ClosedResourceError, McpError) as call_ex:
is_session_terminated = (
isinstance(call_ex, McpError) and "session terminated" in call_ex.error.message.lower()
)
is_connection_lost = isinstance(call_ex, ClosedResourceError) or is_session_terminated
if not is_connection_lost:
error_message = call_ex.error.message if isinstance(call_ex, McpError) else str(call_ex)
raise ToolExecutionException(error_message, inner_exception=call_ex) from call_ex
if attempt == 0:
# First attempt failed, try reconnecting
logger.info("MCP connection closed unexpectedly. Reconnecting...")
# First attempt failed, try reconnecting.
logger.info("MCP connection closed or terminated unexpectedly. Reconnecting...")
try:
await self.connect(reset=True)
continue # Retry the operation
continue
except Exception as reconn_ex:
raise ToolExecutionException(
"Failed to reconnect to MCP server.",
inner_exception=reconn_ex,
) from reconn_ex
else:
# Second attempt also failed, give up
logger.error(f"MCP connection closed unexpectedly after reconnection: {cl_ex}")
raise ToolExecutionException(
f"Failed to call tool '{tool_name}' - connection lost.",
inner_exception=cl_ex,
) from cl_ex
except McpError as mcp_exc:
error_message = mcp_exc.error.message
raise ToolExecutionException(error_message, inner_exception=mcp_exc) from mcp_exc
# Second attempt also failed, give up.
logger.error("MCP connection closed unexpectedly after reconnection: %s", call_ex)
raise ToolExecutionException(
f"Failed to call tool '{tool_name}' - connection lost.",
inner_exception=call_ex,
) from call_ex
except Exception as ex:
raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex
raise ToolExecutionException(f"Failed to call tool '{tool_name}' after retries.")
@@ -1718,10 +1767,11 @@ class MCPStreamableHTTPTool(MCPTool):
Returns:
An async context manager for the streamable HTTP client transport.
"""
from httpx import AsyncClient, Request, Timeout
from httpx import URL, AsyncClient, Request, Timeout
http_client = self._httpx_client
if self._header_provider is not None:
target_origin = _url_origin(URL(self.url))
if http_client is None:
http_client = AsyncClient(
follow_redirects=True,
@@ -1732,6 +1782,8 @@ class MCPStreamableHTTPTool(MCPTool):
if not hasattr(self, "_inject_headers_hook"):
async def _inject_headers(request: Request) -> None: # noqa: RUF029
if _url_origin(request.url) != target_origin:
return
headers = _mcp_call_headers.get({})
for key, value in headers.items():
request.headers[key] = value
+206
View File
@@ -1161,6 +1161,43 @@ async def test_local_mcp_server_function_execution_error():
await func.invoke(param="test_value")
async def test_mcp_tool_reconnects_after_session_terminated_error():
"""Session termination errors should reconnect once and retry the tool call."""
class TestServer(MCPTool):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.connect_count = 0
self.sessions: list[Any] = []
async def connect(self, *, reset: bool = False) -> None:
self.connect_count += 1
self.session = Mock(spec=ClientSession)
self.sessions.append(self.session)
if self.connect_count == 1:
self.session.call_tool = AsyncMock(
side_effect=McpError(types.ErrorData(code=-32000, message="Session terminated"))
)
else:
self.session.call_tool = AsyncMock(
return_value=types.CallToolResult(content=[types.TextContent(type="text", text="recovered")])
)
self.is_connected = True
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
return None
server = TestServer(name="test_server")
await server.connect()
result = await server.call_tool("test_tool", param="test_value")
assert _mcp_result_to_text(result) == "recovered"
assert server.connect_count == 2
assert server.sessions[0].call_tool.await_count == 1
assert server.sessions[1].call_tool.await_count == 1
async def test_mcp_tool_call_tool_raises_on_is_error():
"""Test that call_tool raises ToolExecutionException when MCP returns isError=True."""
@@ -3260,6 +3297,68 @@ async def test_load_prompts_pagination_with_duplicates():
assert [f.name for f in tool._functions] == ["prompt_1", "prompt_2"]
async def test_load_tools_concurrent_reload_does_not_duplicate_tools_and_preserves_meta():
"""Concurrent tool reloads should not duplicate functions or lose tools/list metadata."""
tool = MCPTool(name="test_tool")
mock_session = AsyncMock()
tool.session = mock_session
tool.load_tools_flag = True
page = Mock()
page.tools = [
types.Tool(
name="tool_1",
description="First tool",
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
_meta={"echo": "tool_1"},
),
]
page.nextCursor = None
async def mock_list_tools(params: Any = None) -> Any:
assert params is None
await asyncio.sleep(0)
return page
mock_session.list_tools = AsyncMock(side_effect=mock_list_tools)
await asyncio.wait_for(asyncio.gather(tool.load_tools(), tool.load_tools()), timeout=1)
assert mock_session.list_tools.call_count == 2
assert [f.name for f in tool._functions] == ["tool_1"]
assert tool._tool_call_meta_by_name == {"tool_1": {"echo": "tool_1"}}
async def test_load_prompts_concurrent_reload_does_not_duplicate_prompts():
"""Concurrent prompt reloads should not duplicate functions."""
tool = MCPTool(name="test_tool")
mock_session = AsyncMock()
tool.session = mock_session
tool.load_prompts_flag = True
page = Mock()
page.prompts = [
types.Prompt(
name="prompt_1",
description="First prompt",
arguments=[types.PromptArgument(name="arg1", description="Arg 1", required=True)],
),
]
page.nextCursor = None
async def mock_list_prompts(params: Any = None) -> Any:
assert params is None
await asyncio.sleep(0)
return page
mock_session.list_prompts = AsyncMock(side_effect=mock_list_prompts)
await asyncio.wait_for(asyncio.gather(tool.load_prompts(), tool.load_prompts()), timeout=1)
assert mock_session.list_prompts.call_count == 2
assert [f.name for f in tool._functions] == ["prompt_1"]
async def test_load_tools_pagination_exception_handling():
"""Test that load_tools handles exceptions during pagination gracefully."""
from unittest.mock import AsyncMock
@@ -3891,6 +3990,31 @@ async def test_mcp_tool_safe_close_handles_cancelled_error():
mock_exit_stack.aclose.assert_called_once()
async def test_mcp_tool_safe_close_handles_cleanup_exception_group():
"""Cleanup task groups should not hide the original connect failure."""
import builtins
from contextlib import AsyncExitStack
exception_group_type = getattr(builtins, "ExceptionGroup", None)
if exception_group_type is None:
pytest.skip("ExceptionGroup is not available on this Python version")
tool = MCPStreamableHTTPTool(
name="test",
url="http://example.com/mcp",
load_tools=False,
load_prompts=False,
)
mock_exit_stack = AsyncMock(spec=AsyncExitStack)
mock_exit_stack.aclose = AsyncMock(side_effect=exception_group_type("cleanup failed", [RuntimeError("reader")]))
tool._exit_stack = mock_exit_stack
await tool._safe_close_exit_stack()
mock_exit_stack.aclose.assert_called_once()
async def test_connect_sets_logging_level_when_logger_level_is_set():
"""Test that connect() sets the MCP server logging level when the logger level is not NOTSET."""
@@ -4389,6 +4513,52 @@ async def test_mcp_tool_call_tool_forwards_tool_list_meta():
assert server.session.call_tool.call_args.kwargs["meta"] == tool_meta
async def test_mcp_tool_call_tool_user_meta_merges_with_tool_list_meta():
"""User-provided _meta should be sent as MCP request metadata, not tool arguments."""
from opentelemetry import trace
tool_meta = {"from_tool": "tool-value", "shared": "tool-value"}
user_meta = {"from_user": "user-value", "shared": "user-value"}
class TestServer(MCPTool):
async def connect(self) -> None:
self.session = Mock(spec=ClientSession)
self.session.list_tools = AsyncMock(
return_value=types.ListToolsResult(
tools=[
types.Tool(
name="test_tool",
description="Test tool",
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
_meta=tool_meta,
)
]
)
)
self.session.call_tool = AsyncMock(
return_value=types.CallToolResult(content=[types.TextContent(type="text", text="result")])
)
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
return None
server = TestServer(name="test_server")
async with server:
await server.load_tools()
with trace.use_span(trace.NonRecordingSpan(trace.INVALID_SPAN_CONTEXT)):
await server.call_tool("test_tool", param="test_value", _meta=user_meta)
call_kwargs = server.session.call_tool.call_args.kwargs
assert call_kwargs["arguments"] == {"param": "test_value"}
assert call_kwargs["meta"] == {
"from_tool": "tool-value",
"from_user": "user-value",
"shared": "user-value",
}
assert user_meta == {"from_user": "user-value", "shared": "user-value"}
async def test_mcp_streamable_http_tool_hook_not_duplicated_on_repeated_get_mcp_client():
"""Test that calling get_mcp_client multiple times does not accumulate duplicate hooks."""
tool = MCPStreamableHTTPTool(
@@ -4641,6 +4811,42 @@ async def test_mcp_streamable_http_tool_header_provider_with_httpx_event_hook():
await tool._httpx_client.aclose()
async def test_mcp_streamable_http_tool_header_provider_skips_cross_origin_redirect():
"""The request hook must not re-add caller headers after a cross-origin redirect."""
import httpx
from agent_framework._mcp import _mcp_call_headers
tool = MCPStreamableHTTPTool(
name="test",
url="http://example.com/mcp",
header_provider=lambda kw: {"Authorization": f"Bearer {kw.get('token', '')}"},
)
try:
with patch("agent_framework._mcp.streamable_http_client"):
tool.get_mcp_client()
assert tool._httpx_client is not None
hooks = tool._httpx_client.event_hooks.get("request", [])
assert len(hooks) == 1
token = _mcp_call_headers.set({"Authorization": "Bearer secret"})
try:
same_origin = httpx.Request("POST", "http://example.com/redirected")
await hooks[0](same_origin)
assert same_origin.headers.get("Authorization") == "Bearer secret"
cross_origin = httpx.Request("POST", "http://attacker.example/capture")
await hooks[0](cross_origin)
assert "Authorization" not in cross_origin.headers
finally:
_mcp_call_headers.reset(token)
finally:
if getattr(tool, "_httpx_client", None) is not None:
await tool._httpx_client.aclose()
async def test_mcp_streamable_http_tool_header_provider_with_user_httpx_client():
"""Test that header_provider works when the user provides their own httpx client."""
import httpx