diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index 9dfb29932f..0d85b1699a 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -158,6 +158,22 @@ def streamable_http_client(*args: Any, **kwargs: Any) -> _AsyncGeneratorContextM return _streamable_http_client(*args, **kwargs) # type: ignore[return-value] +def _should_propagate_cancelled_error(ex: BaseException) -> bool: + """Return True if *ex* is a genuine task-cancellation that should propagate unchanged. + + On Python >= 3.11, ``task.cancelling() > 0`` distinguishes a real caller-driven + cancellation from a CancelledError raised internally by a library (e.g. via an + anyio cancel scope). On older Python versions the API is unavailable, so we + always return False and let callers wrap the error in ToolException instead. + """ + if not isinstance(ex, asyncio.CancelledError): + return False + if sys.version_info < (3, 11): + return False + task = asyncio.current_task() + return task is not None and task.cancelling() > 0 + + # region: MCP Plugin @@ -627,6 +643,17 @@ class MCPTool: except asyncio.CancelledError: logger.warning("Could not cleanly close MCP exit stack because the lifecycle owner task was cancelled.") + async def _close_and_check_cancelled(self, ex: BaseException) -> bool: + """Close the exit stack and return True if *ex* is a genuine task cancellation. + + Callers should immediately re-raise when this returns True:: + + if await self._close_and_check_cancelled(ex): + raise + """ + await self._safe_close_exit_stack() + return _should_propagate_cancelled_error(ex) + async def connect(self, *, reset: bool = False) -> None: if self._is_lifecycle_owner_task(): await self._connect_on_owner(reset=reset) @@ -655,14 +682,23 @@ class MCPTool: if not self.session: try: transport = await self._exit_stack.enter_async_context(self.get_mcp_client()) - except Exception as ex: - await self._safe_close_exit_stack() + except (Exception, asyncio.CancelledError) as ex: + # On Python >= 3.11, re-raise genuine task cancellation (task.cancelling() > 0) + # instead of wrapping it in ToolException. On Python < 3.11, task.cancelling() + # is unavailable so MCP-internal CancelledErrors cannot be distinguished from + # caller-driven cancellation; they are wrapped as ToolException in that case. + if await self._close_and_check_cancelled(ex): + raise command = getattr(self, "command", None) if command: error_msg = f"Failed to start MCP server '{command}': {ex}" else: error_msg = f"Failed to connect to MCP server: {ex}" - raise ToolException(error_msg, inner_exception=ex) from ex + # CancelledError is a BaseException (not Exception) on Python >= 3.8, so + # inner_exception=None and ToolException.__init__ won't log exc_info. + if isinstance(ex, asyncio.CancelledError): + logger.debug(error_msg, exc_info=True) + raise ToolException(error_msg, inner_exception=ex if isinstance(ex, Exception) else None) from ex try: try: from mcp import types @@ -692,16 +728,21 @@ class MCPTool: sampling_capabilities=sampling_capabilities, ) ) - except Exception as ex: - await self._safe_close_exit_stack() + except (Exception, asyncio.CancelledError) as ex: + if await self._close_and_check_cancelled(ex): + raise + session_error_msg = f"Failed to create MCP session: {ex}" + if isinstance(ex, asyncio.CancelledError): + logger.debug(session_error_msg, exc_info=True) raise ToolException( - message="Failed to create MCP session. Please check your configuration.", - inner_exception=ex, + message=session_error_msg, + inner_exception=ex if isinstance(ex, Exception) else None, ) from ex try: await session.initialize() - except Exception as ex: - await self._safe_close_exit_stack() + except (Exception, asyncio.CancelledError) as ex: + if await self._close_and_check_cancelled(ex): + raise # Provide context about initialization failure command = getattr(self, "command", None) if command: @@ -710,7 +751,9 @@ class MCPTool: error_msg = f"MCP server '{full_command}' failed to initialize: {ex}" else: error_msg = f"MCP server failed to initialize: {ex}" - raise ToolException(error_msg, inner_exception=ex) from ex + if isinstance(ex, asyncio.CancelledError): + logger.debug(error_msg, exc_info=True) + raise ToolException(error_msg, inner_exception=ex if isinstance(ex, Exception) else None) from ex self.session = session elif self.session._request_id == 0: # type: ignore[attr-defined] # If the session is not initialized, we need to reinitialize it diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 01cf1717bd..487331e3f0 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -1,8 +1,10 @@ # Copyright (c) Microsoft. All rights reserved. # type: ignore[reportPrivateUsage] +import asyncio import json import logging import os +import sys from contextlib import _AsyncGeneratorContextManager # type: ignore from typing import Any from unittest.mock import AsyncMock, Mock, patch @@ -27,6 +29,7 @@ from agent_framework._mcp import ( _build_prefixed_mcp_name, _get_input_model_from_mcp_prompt, _normalize_mcp_name, + _should_propagate_cancelled_error, logger, ) from agent_framework._middleware import FunctionMiddlewarePipeline @@ -2176,6 +2179,7 @@ async def test_connect_session_creation_failure(): await tool.connect() assert "Failed to create MCP session" in str(exc_info.value) + assert "Session creation failed" in str(exc_info.value) # exception text is now part of the message assert "Session creation failed" in str(exc_info.value.__cause__) @@ -2264,6 +2268,282 @@ async def test_connect_cleanup_on_initialization_failure(): tool._exit_stack.aclose.assert_called_once() +async def test_connect_cancelled_error_during_transport_creation_raises_tool_exception(): + """Test that CancelledError from transport creation is wrapped in ToolException.""" + tool = MCPStreamableHTTPTool(name="test", url="http://example.com") + tool._exit_stack.aclose = AsyncMock() + tool.get_mcp_client = Mock(side_effect=asyncio.CancelledError("cancel scope")) + + with pytest.raises(ToolException, match="Failed to connect to MCP server"): + await tool.connect() + + tool._exit_stack.aclose.assert_called_once() + + +async def test_connect_cancelled_error_during_transport_creation_stdio_raises_tool_exception(): + """Test that CancelledError from transport creation uses the command-specific message for MCPStdioTool.""" + tool = MCPStdioTool(name="test", command="my-server") + tool._exit_stack.aclose = AsyncMock() + tool.get_mcp_client = Mock(side_effect=asyncio.CancelledError("cancel scope")) + + with pytest.raises(ToolException, match="Failed to start MCP server 'my-server'"): + await tool.connect() + + tool._exit_stack.aclose.assert_called_once() + + +async def test_connect_cancelled_error_during_session_creation_raises_tool_exception(): + """Test that CancelledError from session creation is wrapped in ToolException.""" + tool = MCPStreamableHTTPTool(name="test", url="http://example.com") + + mock_transport = (Mock(), Mock()) + mock_context_manager = Mock() + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + tool.get_mcp_client = Mock(return_value=mock_context_manager) + + with patch("mcp.client.session.ClientSession") as mock_session_class: + mock_session_class.return_value.__aenter__ = AsyncMock(side_effect=asyncio.CancelledError("cancel scope")) + mock_session_class.return_value.__aexit__ = AsyncMock(return_value=None) + + with pytest.raises(ToolException, match="Failed to create MCP session"): + await tool.connect() + + +async def test_connect_cancelled_error_during_initialize_raises_tool_exception(): + """Test that CancelledError from session.initialize() is wrapped in ToolException. + + This is the primary regression test for the bug: when an MCP server is unreachable, + the MCP library raises asyncio.CancelledError internally, which previously escaped + all except Exception handlers and could not be caught by user code. + """ + tool = MCPStreamableHTTPTool(name="test", url="http://example.com") + + mock_transport = (Mock(), Mock()) + mock_context_manager = Mock() + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + tool.get_mcp_client = Mock(return_value=mock_context_manager) + + mock_session = Mock() + mock_session.initialize = AsyncMock(side_effect=asyncio.CancelledError("Cancelled via cancel scope")) + + with patch("mcp.client.session.ClientSession") as mock_session_class: + mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_session_class.return_value.__aexit__ = AsyncMock(return_value=None) + + with pytest.raises(ToolException, match="MCP server failed to initialize"): + await tool.connect() + + +async def test_connect_cancelled_error_during_initialize_stdio_raises_tool_exception(): + """Test that CancelledError from session.initialize() uses the command-specific message for MCPStdioTool.""" + tool = MCPStdioTool(name="test", command="my-server", args=["--port", "8080"]) + + mock_transport = (Mock(), Mock()) + mock_context_manager = Mock() + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + tool.get_mcp_client = Mock(return_value=mock_context_manager) + + mock_session = Mock() + mock_session.initialize = AsyncMock(side_effect=asyncio.CancelledError("Cancelled via cancel scope")) + + with patch("mcp.client.session.ClientSession") as mock_session_class: + mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_session_class.return_value.__aexit__ = AsyncMock(return_value=None) + + with pytest.raises(ToolException, match="MCP server 'my-server --port 8080' failed to initialize"): + await tool.connect() + + +@pytest.mark.skipif(sys.version_info < (3, 11), reason="task.cancelling() requires Python >= 3.11") +async def test_connect_genuine_cancellation_during_transport_creation_propagates(): + """Test that genuine task cancellation (task.cancelling() > 0) propagates as CancelledError.""" + tool = MCPStreamableHTTPTool(name="test", url="http://example.com") + tool._exit_stack.aclose = AsyncMock() + + mock_cancelled_task = Mock() + mock_cancelled_task.cancelling.return_value = 1 + + with patch("asyncio.current_task", return_value=mock_cancelled_task): + tool.get_mcp_client = Mock(side_effect=asyncio.CancelledError("task cancelled")) + with pytest.raises(asyncio.CancelledError): + await tool.connect() + + tool._exit_stack.aclose.assert_called_once() + + +@pytest.mark.skipif(sys.version_info < (3, 11), reason="task.cancelling() requires Python >= 3.11") +async def test_connect_genuine_cancellation_during_initialize_propagates(): + """Test that genuine task cancellation during initialize() propagates as CancelledError.""" + tool = MCPStreamableHTTPTool(name="test", url="http://example.com") + tool._exit_stack.aclose = AsyncMock() + + mock_transport = (Mock(), Mock()) + mock_context_manager = Mock() + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + tool.get_mcp_client = Mock(return_value=mock_context_manager) + + mock_session = Mock() + mock_session.initialize = AsyncMock(side_effect=asyncio.CancelledError("task cancelled")) + + mock_cancelled_task = Mock() + mock_cancelled_task.cancelling.return_value = 1 + + with ( + patch("asyncio.current_task", return_value=mock_cancelled_task), + patch("mcp.client.session.ClientSession") as mock_session_class, + ): + mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_session_class.return_value.__aexit__ = AsyncMock(return_value=None) + + with pytest.raises(asyncio.CancelledError): + await tool.connect() + + tool._exit_stack.aclose.assert_called_once() + + +@pytest.mark.skipif(sys.version_info < (3, 11), reason="task.cancelling() requires Python >= 3.11") +async def test_connect_genuine_cancellation_during_session_creation_propagates(): + """Test that genuine task cancellation during session creation propagates as CancelledError.""" + tool = MCPStreamableHTTPTool(name="test", url="http://example.com") + tool._exit_stack.aclose = AsyncMock() + + mock_transport = (Mock(), Mock()) + mock_context_manager = Mock() + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + tool.get_mcp_client = Mock(return_value=mock_context_manager) + + mock_cancelled_task = Mock() + mock_cancelled_task.cancelling.return_value = 1 + + with ( + patch("asyncio.current_task", return_value=mock_cancelled_task), + patch("mcp.client.session.ClientSession") as mock_session_class, + ): + mock_session_class.return_value.__aenter__ = AsyncMock(side_effect=asyncio.CancelledError("task cancelled")) + mock_session_class.return_value.__aexit__ = AsyncMock(return_value=None) + + with pytest.raises(asyncio.CancelledError): + await tool.connect() + + tool._exit_stack.aclose.assert_called_once() + + +async def test_aenter_cancelled_error_during_connect_is_catchable_as_exception(): + """Test that CancelledError during __aenter__ is catchable as Exception. + + Verifies the end-to-end fix: async with MCPStreamableHTTPTool(...) raises an + exception that can be caught by a normal `except Exception` block. + """ + tool = MCPStreamableHTTPTool(name="test", url="http://example.com") + + mock_session = Mock() + mock_session.initialize = AsyncMock(side_effect=asyncio.CancelledError("Cancelled via cancel scope")) + + mock_transport = (Mock(), Mock()) + mock_context_manager = Mock() + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + tool.get_mcp_client = Mock(return_value=mock_context_manager) + + with patch("mcp.client.session.ClientSession") as mock_session_class: + mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_session_class.return_value.__aexit__ = AsyncMock(return_value=None) + + caught = None + try: + async with tool: + pass + except Exception as e: + caught = e + + assert caught is not None, "Expected an exception to be caught by except Exception" + assert isinstance(caught, ToolException) + + +# Tests for _should_propagate_cancelled_error helper + + +def test_should_propagate_cancelled_error_returns_false_for_non_cancelled_error(): + assert _should_propagate_cancelled_error(RuntimeError("boom")) is False + + +def test_should_propagate_cancelled_error_returns_false_when_no_current_task(): + with patch("asyncio.current_task", return_value=None): + assert _should_propagate_cancelled_error(asyncio.CancelledError()) is False + + +@pytest.mark.skipif(sys.version_info < (3, 11), reason="task.cancelling() requires Python >= 3.11") +def test_should_propagate_cancelled_error_returns_true_when_task_is_cancelling(): + mock_task = Mock() + mock_task.cancelling.return_value = 1 + with patch("asyncio.current_task", return_value=mock_task): + assert _should_propagate_cancelled_error(asyncio.CancelledError()) is True + + +@pytest.mark.skipif(sys.version_info < (3, 11), reason="task.cancelling() requires Python >= 3.11") +def test_should_propagate_cancelled_error_returns_false_when_task_not_cancelling(): + mock_task = Mock() + mock_task.cancelling.return_value = 0 + with patch("asyncio.current_task", return_value=mock_task): + assert _should_propagate_cancelled_error(asyncio.CancelledError()) is False + + +async def test_connect_cancelled_error_during_session_creation_includes_exception_in_message(): + """Test that CancelledError from session creation includes exception details in ToolException message.""" + tool = MCPStreamableHTTPTool(name="test", url="http://example.com") + + mock_transport = (Mock(), Mock()) + mock_context_manager = Mock() + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + tool.get_mcp_client = Mock(return_value=mock_context_manager) + + with patch("mcp.client.session.ClientSession") as mock_session_class: + mock_session_class.return_value.__aenter__ = AsyncMock( + side_effect=asyncio.CancelledError("cancel scope detail") + ) + mock_session_class.return_value.__aexit__ = AsyncMock(return_value=None) + + with pytest.raises(ToolException) as exc_info: + await tool.connect() + + assert "Failed to create MCP session" in str(exc_info.value) + assert "cancel scope detail" in str(exc_info.value) + + +async def test_connect_cancelled_error_during_session_creation_logs_with_exc_info(): + """Test that CancelledError from session creation is logged with exc_info=True.""" + tool = MCPStreamableHTTPTool(name="test", url="http://example.com") + + mock_transport = (Mock(), Mock()) + mock_context_manager = Mock() + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + tool.get_mcp_client = Mock(return_value=mock_context_manager) + + with patch("mcp.client.session.ClientSession") as mock_session_class: + mock_session_class.return_value.__aenter__ = AsyncMock(side_effect=asyncio.CancelledError("cancel scope")) + mock_session_class.return_value.__aexit__ = AsyncMock(return_value=None) + + from agent_framework._mcp import logger as mcp_logger + + with patch.object(mcp_logger, "debug") as mock_debug: + with pytest.raises(ToolException): + await tool.connect() + + # Verify logger.debug was called with exc_info=True (not an exception instance) + debug_calls = mock_debug.call_args_list + cancel_calls = [c for c in debug_calls if "Failed to create MCP session" in str(c)] + assert cancel_calls, "Expected a debug log for the cancelled session creation" + _, kwargs = cancel_calls[0] + assert kwargs.get("exc_info") is True + + def test_mcp_stdio_tool_get_mcp_client_with_env_and_kwargs(): """Test MCPStdioTool.get_mcp_client() with environment variables and client kwargs.""" env_vars = {"PATH": "/usr/bin", "DEBUG": "1"}