mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: consolidate MCP reliability fixes (#6145)
* Python: consolidate MCP reliability fixes
  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
* Fix MCP cleanup and metadata typing
  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
* Satisfy MCP metadata mypy typing
  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
* Fix Pyright metadata mapping type
  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
---------
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
d2d5384f28
commit
e8ff541ebf
@@ -10,7 +10,7 @@ import logging
|
||||
import re
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable, Collection, Coroutine, Sequence
|
||||
from collections.abc import Callable, Collection, Coroutine, Mapping, Sequence
|
||||
from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore
|
||||
from datetime import timedelta
|
||||
from functools import partial
|
||||
@@ -142,6 +142,13 @@ def _inject_otel_into_mcp_meta(meta: dict[str, Any] | None = None) -> dict[str,
|
||||
return meta
|
||||
|
||||
|
||||
def _url_origin(url: Any) -> tuple[str, str, int | None]:
    """Return the ``(scheme, host, port)`` origin triple of *url*.

    Two URLs share an origin exactly when their triples compare equal, which is
    how the header-injection hook decides whether a request still targets the
    configured MCP server.

    Args:
        url: A URL object exposing ``scheme``, ``host`` and ``port`` attributes
            (the callers in this module pass ``httpx.URL`` instances).

    Returns:
        A tuple of the lower-cased scheme, the lower-cased host (``""`` when the
        URL has no host) and the port. When the URL carries no explicit port,
        the scheme default is used (443 for ``https``, 80 for ``http``); for
        any other scheme the port stays ``None``.
    """
    default_ports = {"https": 443, "http": 80}

    # Scheme and host are case-insensitive parts of an origin. The values are
    # presumably already normalized by httpx; folding here is defensive because
    # the parameter is typed ``Any``.
    scheme = (url.scheme or "").lower()
    host = (url.host or "").lower()

    port = url.port
    if port is None:
        # An omitted port and the explicit scheme default must compare equal.
        port = default_ports.get(scheme)

    return (scheme, host, port)
|
||||
|
||||
|
||||
def streamable_http_client(*args: Any, **kwargs: Any) -> _AsyncGeneratorContextManager[Any, None]:
|
||||
"""Lazily import the MCP streamable HTTP transport."""
|
||||
try:
|
||||
@@ -255,6 +262,7 @@ class MCPTool:
|
||||
self._exit_stack = AsyncExitStack()
|
||||
self._lifecycle_lock = asyncio.Lock()
|
||||
self._lifecycle_request_lock = asyncio.Lock()
|
||||
self._function_load_lock = asyncio.Lock()
|
||||
self._lifecycle_queue: asyncio.Queue[tuple[str, bool, bool, asyncio.Future[None]]] | None = None
|
||||
self._lifecycle_owner_task: asyncio.Task[None] | None = None
|
||||
self.session = session
|
||||
@@ -655,6 +663,11 @@ class MCPTool:
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("Could not cleanly close MCP exit stack because the lifecycle owner task was cancelled.")
|
||||
except Exception as e:
|
||||
if type(e).__name__ == "ExceptionGroup":
|
||||
logger.warning("Could not cleanly close MCP exit stack due to cleanup error group. Error: %s", e)
|
||||
else:
|
||||
raise
|
||||
|
||||
async def _close_and_check_cancelled(self, ex: BaseException) -> bool:
|
||||
"""Close the exit stack and return True if *ex* is a genuine task cancellation.
|
||||
@@ -1018,6 +1031,10 @@ class MCPTool:
|
||||
Raises:
|
||||
ToolExecutionException: If the MCP server is not connected.
|
||||
"""
|
||||
async with self._function_load_lock:
|
||||
await self._load_prompts_locked()
|
||||
|
||||
async def _load_prompts_locked(self) -> None:
|
||||
from anyio import ClosedResourceError
|
||||
from mcp import types
|
||||
|
||||
@@ -1100,6 +1117,10 @@ class MCPTool:
|
||||
Raises:
|
||||
ToolExecutionException: If the MCP server is not connected.
|
||||
"""
|
||||
async with self._function_load_lock:
|
||||
await self._load_tools_locked()
|
||||
|
||||
async def _load_tools_locked(self) -> None:
|
||||
from anyio import ClosedResourceError
|
||||
from mcp import types
|
||||
|
||||
@@ -1109,7 +1130,7 @@ class MCPTool:
|
||||
|
||||
# Track existing function names to prevent duplicates
|
||||
existing_names = {func.name for func in self._functions}
|
||||
self._tool_call_meta_by_name.clear()
|
||||
tool_call_meta_by_name: dict[str, dict[str, Any]] = {}
|
||||
|
||||
params: types.PaginatedRequestParams | None = None
|
||||
while True:
|
||||
@@ -1145,7 +1166,7 @@ class MCPTool:
|
||||
|
||||
for tool in tool_list.tools:
|
||||
if tool.meta is not None:
|
||||
self._tool_call_meta_by_name[tool.name] = dict(tool.meta)
|
||||
tool_call_meta_by_name[tool.name] = dict(tool.meta)
|
||||
|
||||
normalized_name = _normalize_mcp_name(tool.name)
|
||||
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
|
||||
@@ -1194,6 +1215,8 @@ class MCPTool:
|
||||
break
|
||||
params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)
|
||||
|
||||
self._tool_call_meta_by_name = tool_call_meta_by_name
|
||||
|
||||
async def _close_on_owner(self) -> None:
|
||||
# Cancel any pending reload tasks before tearing down the session.
|
||||
tasks = list(self._pending_reload_tasks)
|
||||
@@ -1276,7 +1299,11 @@ class MCPTool:
|
||||
tool_name: The name of the tool to call.
|
||||
|
||||
Keyword Args:
|
||||
kwargs: Arguments to pass to the tool.
|
||||
_meta: Optional ``dict[str, Any]`` of MCP request metadata. This reserved key is passed as the
|
||||
``meta`` parameter of the underlying ``session.call_tool`` call rather than as a tool argument.
|
||||
User-supplied keys override metadata from ``tools/list``; OpenTelemetry propagation fills in
|
||||
non-conflicting keys.
|
||||
kwargs: Remaining arguments to pass to the tool.
|
||||
|
||||
Returns:
|
||||
A list of Content items representing the tool output. The default
|
||||
@@ -1294,6 +1321,19 @@ class MCPTool:
|
||||
raise ToolExecutionException(
|
||||
"Tools are not loaded for this server, please set load_tools=True in the constructor."
|
||||
)
|
||||
|
||||
raw_user_meta: object | None = kwargs.get("_meta")
|
||||
user_meta: dict[str, Any] | None = None
|
||||
if raw_user_meta is not None and not isinstance(raw_user_meta, dict):
|
||||
raise ToolExecutionException("MCP tool metadata provided via _meta must be a dict.")
|
||||
if isinstance(raw_user_meta, dict):
|
||||
raw_user_meta_dict = cast(Mapping[object, object], raw_user_meta)
|
||||
user_meta = {}
|
||||
for key, value in raw_user_meta_dict.items():
|
||||
if not isinstance(key, str):
|
||||
raise ToolExecutionException("MCP tool metadata provided via _meta must use string keys.")
|
||||
user_meta[key] = value
|
||||
|
||||
# Filter out framework kwargs that cannot be serialized by the MCP SDK.
|
||||
# These are internal objects passed through the function invocation pipeline
|
||||
# that should not be forwarded to external MCP servers.
|
||||
@@ -1313,12 +1353,16 @@ class MCPTool:
|
||||
"conversation_id",
|
||||
"options",
|
||||
"response_format",
|
||||
"_meta",
|
||||
}
|
||||
}
|
||||
|
||||
# Some MCP proxies require their tools/list metadata to be echoed on tools/call.
|
||||
tool_meta = self._tool_call_meta_by_name.get(tool_name)
|
||||
meta = _inject_otel_into_mcp_meta(dict(tool_meta) if tool_meta is not None else None)
|
||||
request_meta = dict(tool_meta) if tool_meta is not None else None
|
||||
if user_meta is not None:
|
||||
request_meta = {**(request_meta or {}), **user_meta}
|
||||
meta = _inject_otel_into_mcp_meta(request_meta)
|
||||
|
||||
parser = self.parse_tool_results or self._parse_tool_result_from_mcp
|
||||
# Try the operation, reconnecting once if the connection is closed
|
||||
@@ -1336,28 +1380,33 @@ class MCPTool:
|
||||
return parser(result)
|
||||
except ToolExecutionException:
|
||||
raise
|
||||
except ClosedResourceError as cl_ex:
|
||||
except (ClosedResourceError, McpError) as call_ex:
|
||||
is_session_terminated = (
|
||||
isinstance(call_ex, McpError) and "session terminated" in call_ex.error.message.lower()
|
||||
)
|
||||
is_connection_lost = isinstance(call_ex, ClosedResourceError) or is_session_terminated
|
||||
if not is_connection_lost:
|
||||
error_message = call_ex.error.message if isinstance(call_ex, McpError) else str(call_ex)
|
||||
raise ToolExecutionException(error_message, inner_exception=call_ex) from call_ex
|
||||
|
||||
if attempt == 0:
|
||||
# First attempt failed, try reconnecting
|
||||
logger.info("MCP connection closed unexpectedly. Reconnecting...")
|
||||
# First attempt failed, try reconnecting.
|
||||
logger.info("MCP connection closed or terminated unexpectedly. Reconnecting...")
|
||||
try:
|
||||
await self.connect(reset=True)
|
||||
continue # Retry the operation
|
||||
continue
|
||||
except Exception as reconn_ex:
|
||||
raise ToolExecutionException(
|
||||
"Failed to reconnect to MCP server.",
|
||||
inner_exception=reconn_ex,
|
||||
) from reconn_ex
|
||||
else:
|
||||
# Second attempt also failed, give up
|
||||
logger.error(f"MCP connection closed unexpectedly after reconnection: {cl_ex}")
|
||||
raise ToolExecutionException(
|
||||
f"Failed to call tool '{tool_name}' - connection lost.",
|
||||
inner_exception=cl_ex,
|
||||
) from cl_ex
|
||||
except McpError as mcp_exc:
|
||||
error_message = mcp_exc.error.message
|
||||
raise ToolExecutionException(error_message, inner_exception=mcp_exc) from mcp_exc
|
||||
|
||||
# Second attempt also failed, give up.
|
||||
logger.error("MCP connection closed unexpectedly after reconnection: %s", call_ex)
|
||||
raise ToolExecutionException(
|
||||
f"Failed to call tool '{tool_name}' - connection lost.",
|
||||
inner_exception=call_ex,
|
||||
) from call_ex
|
||||
except Exception as ex:
|
||||
raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex
|
||||
raise ToolExecutionException(f"Failed to call tool '{tool_name}' after retries.")
|
||||
@@ -1718,10 +1767,11 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
Returns:
|
||||
An async context manager for the streamable HTTP client transport.
|
||||
"""
|
||||
from httpx import AsyncClient, Request, Timeout
|
||||
from httpx import URL, AsyncClient, Request, Timeout
|
||||
|
||||
http_client = self._httpx_client
|
||||
if self._header_provider is not None:
|
||||
target_origin = _url_origin(URL(self.url))
|
||||
if http_client is None:
|
||||
http_client = AsyncClient(
|
||||
follow_redirects=True,
|
||||
@@ -1732,6 +1782,8 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
if not hasattr(self, "_inject_headers_hook"):
|
||||
|
||||
async def _inject_headers(request: Request) -> None: # noqa: RUF029
|
||||
if _url_origin(request.url) != target_origin:
|
||||
return
|
||||
headers = _mcp_call_headers.get({})
|
||||
for key, value in headers.items():
|
||||
request.headers[key] = value
|
||||
|
||||
@@ -1161,6 +1161,43 @@ async def test_local_mcp_server_function_execution_error():
|
||||
await func.invoke(param="test_value")
|
||||
|
||||
|
||||
async def test_mcp_tool_reconnects_after_session_terminated_error():
    """A "Session terminated" McpError on the first call leads to one reconnect and one retry."""

    class TestServer(MCPTool):
        """MCPTool stub: the first session's call_tool fails, every later session's succeeds."""

        def __init__(self, **kwargs: Any) -> None:
            super().__init__(**kwargs)
            self.connect_count = 0
            self.sessions: list[Any] = []

        async def connect(self, *, reset: bool = False) -> None:
            self.connect_count += 1

            # Only the very first session reports the terminated-session error.
            if self.connect_count == 1:
                call_tool = AsyncMock(
                    side_effect=McpError(types.ErrorData(code=-32000, message="Session terminated"))
                )
            else:
                call_tool = AsyncMock(
                    return_value=types.CallToolResult(content=[types.TextContent(type="text", text="recovered")])
                )

            session = Mock(spec=ClientSession)
            session.call_tool = call_tool
            self.session = session
            self.sessions.append(session)
            self.is_connected = True

        def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
            return None

    server = TestServer(name="test_server")
    await server.connect()

    result = await server.call_tool("test_tool", param="test_value")

    assert _mcp_result_to_text(result) == "recovered"
    assert server.connect_count == 2
    # Each session must have been awaited exactly once: the failure, then the retry.
    for session in (server.sessions[0], server.sessions[1]):
        assert session.call_tool.await_count == 1
|
||||
|
||||
|
||||
async def test_mcp_tool_call_tool_raises_on_is_error():
|
||||
"""Test that call_tool raises ToolExecutionException when MCP returns isError=True."""
|
||||
|
||||
@@ -3260,6 +3297,68 @@ async def test_load_prompts_pagination_with_duplicates():
|
||||
assert [f.name for f in tool._functions] == ["prompt_1", "prompt_2"]
|
||||
|
||||
|
||||
async def test_load_tools_concurrent_reload_does_not_duplicate_tools_and_preserves_meta():
    """Two overlapping load_tools() calls yield one function per tool and keep its tools/list metadata."""
    session = AsyncMock()

    tool = MCPTool(name="test_tool")
    tool.session = session
    tool.load_tools_flag = True

    # A single, final page (no cursor) containing one tool that carries metadata.
    only_tool = types.Tool(
        name="tool_1",
        description="First tool",
        inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
        _meta={"echo": "tool_1"},
    )
    single_page = Mock(tools=[only_tool], nextCursor=None)

    async def mock_list_tools(params: Any = None) -> Any:
        assert params is None
        # Yield to the event loop so the two reloads get a chance to interleave.
        await asyncio.sleep(0)
        return single_page

    session.list_tools = AsyncMock(side_effect=mock_list_tools)

    reloads = [tool.load_tools() for _ in range(2)]
    await asyncio.wait_for(asyncio.gather(*reloads), timeout=1)

    assert session.list_tools.call_count == 2
    loaded_names = [func.name for func in tool._functions]
    assert loaded_names == ["tool_1"]
    assert tool._tool_call_meta_by_name == {"tool_1": {"echo": "tool_1"}}
|
||||
|
||||
|
||||
async def test_load_prompts_concurrent_reload_does_not_duplicate_prompts():
    """Two overlapping load_prompts() calls yield a single function per prompt."""
    session = AsyncMock()

    tool = MCPTool(name="test_tool")
    tool.session = session
    tool.load_prompts_flag = True

    # A single, final page (no cursor) containing one prompt.
    only_prompt = types.Prompt(
        name="prompt_1",
        description="First prompt",
        arguments=[types.PromptArgument(name="arg1", description="Arg 1", required=True)],
    )
    single_page = Mock(prompts=[only_prompt], nextCursor=None)

    async def mock_list_prompts(params: Any = None) -> Any:
        assert params is None
        # Yield to the event loop so the two reloads get a chance to interleave.
        await asyncio.sleep(0)
        return single_page

    session.list_prompts = AsyncMock(side_effect=mock_list_prompts)

    reloads = [tool.load_prompts() for _ in range(2)]
    await asyncio.wait_for(asyncio.gather(*reloads), timeout=1)

    assert session.list_prompts.call_count == 2
    loaded_names = [func.name for func in tool._functions]
    assert loaded_names == ["prompt_1"]
|
||||
|
||||
|
||||
async def test_load_tools_pagination_exception_handling():
|
||||
"""Test that load_tools handles exceptions during pagination gracefully."""
|
||||
from unittest.mock import AsyncMock
|
||||
@@ -3891,6 +3990,31 @@ async def test_mcp_tool_safe_close_handles_cancelled_error():
|
||||
mock_exit_stack.aclose.assert_called_once()
|
||||
|
||||
|
||||
async def test_mcp_tool_safe_close_handles_cleanup_exception_group():
    """_safe_close_exit_stack() must not propagate an ExceptionGroup raised by the exit stack's aclose()."""
    import builtins
    from contextlib import AsyncExitStack

    # ExceptionGroup is looked up dynamically so the module still imports on
    # interpreters that do not provide it.
    if getattr(builtins, "ExceptionGroup", None) is None:
        pytest.skip("ExceptionGroup is not available on this Python version")
    cleanup_error = builtins.ExceptionGroup("cleanup failed", [RuntimeError("reader")])

    failing_stack = AsyncMock(spec=AsyncExitStack)
    failing_stack.aclose = AsyncMock(side_effect=cleanup_error)

    tool = MCPStreamableHTTPTool(
        name="test",
        url="http://example.com/mcp",
        load_tools=False,
        load_prompts=False,
    )
    tool._exit_stack = failing_stack

    # Must return normally even though aclose() raises the group.
    await tool._safe_close_exit_stack()

    failing_stack.aclose.assert_called_once()
|
||||
|
||||
|
||||
async def test_connect_sets_logging_level_when_logger_level_is_set():
|
||||
"""Test that connect() sets the MCP server logging level when the logger level is not NOTSET."""
|
||||
|
||||
@@ -4389,6 +4513,52 @@ async def test_mcp_tool_call_tool_forwards_tool_list_meta():
|
||||
assert server.session.call_tool.call_args.kwargs["meta"] == tool_meta
|
||||
|
||||
|
||||
async def test_mcp_tool_call_tool_user_meta_merges_with_tool_list_meta():
    """The reserved ``_meta`` kwarg is merged over tools/list metadata and kept out of the tool arguments."""
    from opentelemetry import trace

    tool_meta = {"from_tool": "tool-value", "shared": "tool-value"}
    user_meta = {"from_user": "user-value", "shared": "user-value"}

    class TestServer(MCPTool):
        """MCPTool stub exposing one tool with tools/list metadata and a succeeding call_tool."""

        async def connect(self) -> None:
            listed_tool = types.Tool(
                name="test_tool",
                description="Test tool",
                inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
                _meta=tool_meta,
            )
            call_result = types.CallToolResult(content=[types.TextContent(type="text", text="result")])

            self.session = Mock(spec=ClientSession)
            self.session.list_tools = AsyncMock(return_value=types.ListToolsResult(tools=[listed_tool]))
            self.session.call_tool = AsyncMock(return_value=call_result)

        def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
            return None

    server = TestServer(name="test_server")
    async with server:
        await server.load_tools()

        # Run under an invalid, non-recording span so the asserted metadata
        # holds only the tool and user keys.
        invalid_span = trace.NonRecordingSpan(trace.INVALID_SPAN_CONTEXT)
        with trace.use_span(invalid_span):
            await server.call_tool("test_tool", param="test_value", _meta=user_meta)

        sent = server.session.call_tool.call_args.kwargs
        assert sent["arguments"] == {"param": "test_value"}
        # On the conflicting "shared" key the user-supplied value wins.
        assert sent["meta"] == {
            "from_tool": "tool-value",
            "from_user": "user-value",
            "shared": "user-value",
        }
        # The caller's dict must not have been mutated by the merge.
        assert user_meta == {"from_user": "user-value", "shared": "user-value"}
|
||||
|
||||
|
||||
async def test_mcp_streamable_http_tool_hook_not_duplicated_on_repeated_get_mcp_client():
|
||||
"""Test that calling get_mcp_client multiple times does not accumulate duplicate hooks."""
|
||||
tool = MCPStreamableHTTPTool(
|
||||
@@ -4641,6 +4811,42 @@ async def test_mcp_streamable_http_tool_header_provider_with_httpx_event_hook():
|
||||
await tool._httpx_client.aclose()
|
||||
|
||||
|
||||
async def test_mcp_streamable_http_tool_header_provider_skips_cross_origin_redirect():
    """The request hook adds caller headers to same-origin requests only, never to another origin."""
    import httpx

    from agent_framework._mcp import _mcp_call_headers

    tool = MCPStreamableHTTPTool(
        name="test",
        url="http://example.com/mcp",
        header_provider=lambda kw: {"Authorization": f"Bearer {kw.get('token', '')}"},
    )

    try:
        # The transport factory is patched out; only the httpx client setup is exercised.
        with patch("agent_framework._mcp.streamable_http_client"):
            tool.get_mcp_client()

        assert tool._httpx_client is not None
        request_hooks = tool._httpx_client.event_hooks.get("request", [])
        assert len(request_hooks) == 1
        inject_headers = request_hooks[0]

        reset_token = _mcp_call_headers.set({"Authorization": "Bearer secret"})
        try:
            # Same scheme, host and port as the configured URL: headers are applied.
            same_origin = httpx.Request("POST", "http://example.com/redirected")
            await inject_headers(same_origin)
            assert same_origin.headers.get("Authorization") == "Bearer secret"

            # Different host: the hook must leave the request untouched.
            cross_origin = httpx.Request("POST", "http://attacker.example/capture")
            await inject_headers(cross_origin)
            assert "Authorization" not in cross_origin.headers
        finally:
            _mcp_call_headers.reset(reset_token)
    finally:
        if getattr(tool, "_httpx_client", None) is not None:
            await tool._httpx_client.aclose()
|
||||
|
||||
|
||||
async def test_mcp_streamable_http_tool_header_provider_with_user_httpx_client():
|
||||
"""Test that header_provider works when the user provides their own httpx client."""
|
||||
import httpx
|
||||
|
||||
Reference in New Issue
Block a user