diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py
index 81227c7e73..5901e34dd9 100644
--- a/python/packages/core/agent_framework/_mcp.py
+++ b/python/packages/core/agent_framework/_mcp.py
@@ -480,6 +480,12 @@ class MCPTool:
         self.load_prompts_flag = load_prompts
         self.parse_prompt_results = parse_prompt_results
         self._exit_stack = AsyncExitStack()
+        # Lifecycle plumbing: connect/close are executed on one dedicated owner task so the
+        # exit stack is always entered and unwound from the same task.
+        self._lifecycle_lock = asyncio.Lock()
+        self._lifecycle_request_lock = asyncio.Lock()
+        self._lifecycle_queue: asyncio.Queue[tuple[str, bool, asyncio.Future[None]]] | None = None
+        self._lifecycle_owner_task: asyncio.Task[None] | None = None
         self.session = session
         self.request_timeout = request_timeout
         self.client = client
@@ -510,39 +516,144 @@ class MCPTool:
                 filtered_functions.append(func)
         return filtered_functions
 
+    async def _ensure_lifecycle_owner(self) -> None:
+        """Start the lifecycle owner task if none is currently running.
+
+        A fresh request queue is created together with each new owner task.
+        """
+        async with self._lifecycle_lock:
+            if self._lifecycle_owner_task is not None and not self._lifecycle_owner_task.done():
+                return
+
+            self._lifecycle_queue = asyncio.Queue()
+            self._lifecycle_owner_task = asyncio.create_task(
+                self._run_lifecycle_owner(),
+                name=f"mcp-lifecycle:{self.name}",
+            )
+
+    async def _run_lifecycle_owner(self) -> None:
+        """Serve queued connect/close requests until a close completes or the task stops.
+
+        Each request's future receives the outcome of its action. When the task stops for
+        any reason, the in-flight request and everything still queued are failed so that
+        no caller is left waiting on a future that will never complete.
+        """
+        queue = self._lifecycle_queue
+        if queue is None:
+            return
+
+        stop_error: BaseException | None = None
+        pending: asyncio.Future[None] | None = None
+        try:
+            while True:
+                action, reset, pending = await queue.get()
+
+                try:
+                    if action == "connect":
+                        await self._connect_on_owner(reset=reset)
+                    elif action == "close":
+                        await self._close_on_owner()
+                    else:
+                        raise RuntimeError(f"Unknown MCP lifecycle action: {action}")
+                except Exception as ex:
+                    # Ordinary failures belong to the caller; the owner keeps serving requests.
+                    if not pending.done():
+                        pending.set_exception(ex)
+                else:
+                    if not pending.done():
+                        pending.set_result(None)
+
+                if action == "close":
+                    return
+        except BaseException as ex:
+            # Cancellation and other non-Exception errors (e.g. a BaseExceptionGroup raised
+            # by a task group) terminate the owner; remember why for the waiting callers.
+            stop_error = ex
+            raise
+        finally:
+            failure = stop_error or RuntimeError("MCP lifecycle owner stopped unexpectedly.")
+            # The in-flight request was already dequeued, so the drain loop below cannot see it.
+            if pending is not None and not pending.done():
+                pending.set_exception(failure)
+            while True:
+                try:
+                    _, _, pending = queue.get_nowait()
+                except asyncio.QueueEmpty:
+                    break
+                if not pending.done():
+                    pending.set_exception(failure)
+
+            self._lifecycle_queue = None
+            self._lifecycle_owner_task = None
+
+    def _is_lifecycle_owner_task(self) -> bool:
+        """Return True when the calling code is running inside the lifecycle owner task."""
+        owner_task = self._lifecycle_owner_task
+        return owner_task is not None and asyncio.current_task() is owner_task
+
+    async def _run_on_lifecycle_owner(self, action: str, *, reset: bool = False) -> None:
+        """Run a lifecycle action on the owner task and wait for its outcome.
+
+        Args:
+            action: Either "connect" or "close".
+            reset: Forwarded to the connect action; not used by close.
+
+        Raises:
+            RuntimeError: If the action is unknown or no owner queue is available.
+        """
+        await self._ensure_lifecycle_owner()
+
+        if self._is_lifecycle_owner_task():
+            if action == "connect":
+                await self._connect_on_owner(reset=reset)
+            elif action == "close":
+                await self._close_on_owner()
+            else:
+                raise RuntimeError(f"Unknown MCP lifecycle action: {action}")
+            return
+
+        queue = self._lifecycle_queue
+        if queue is None:
+            raise RuntimeError("MCP lifecycle owner is not available.")
+
+        future = asyncio.get_running_loop().create_future()
+        await queue.put((action, reset, future))
+        await future
+
     async def _safe_close_exit_stack(self) -> None:
-        """Safely close the exit stack, handling cross-task boundary errors.
-
-        anyio's cancel scopes are bound to the task they were created in.
-        If aclose() is called from a different task (e.g., during streaming reconnection),
-        anyio will raise a RuntimeError or CancelledError. In this case, we log a warning
-        and allow garbage collection to clean up the resources.
-
-        Known error variants:
-        - "Attempted to exit cancel scope in a different task than it was entered in"
-        - "Attempted to exit a cancel scope that isn't the current task's current cancel scope"
-        - CancelledError from anyio cancel scope cleanup
-        """
+        """Safely close the exit stack, handling unexpected cleanup failures."""
         try:
             await self._exit_stack.aclose()
         except RuntimeError as e:
             error_msg = str(e).lower()
-            # Check for anyio cancel scope errors (multiple variants exist)
             if "cancel scope" in error_msg:
                 logger.warning(
                     "Could not cleanly close MCP exit stack due to cancel scope error. "
-                    "Old resources will be garbage collected. Error: %s",
+                    "This indicates MCP lifecycle ownership was lost. Error: %s",
                     e,
                 )
             else:
                 raise
         except asyncio.CancelledError:
-            # CancelledError can occur during cleanup when cancel scopes are involved
-            logger.warning(
-                "Could not cleanly close MCP exit stack due to cancellation. Old resources will be garbage collected."
-            )
+            logger.warning("Could not cleanly close MCP exit stack because the lifecycle owner task was cancelled.")
 
     async def connect(self, *, reset: bool = False) -> None:
+        """Connect to the MCP server from any task.
+
+        The connection work is delegated to the lifecycle owner task; when already
+        running on that task it is executed directly.
+
+        Args:
+            reset: Forwarded to the owner-side connect routine.
+        """
+        if self._is_lifecycle_owner_task():
+            await self._connect_on_owner(reset=reset)
+            return
+
+        async with self._lifecycle_request_lock:
+            await self._run_on_lifecycle_owner("connect", reset=reset)
+
+    async def _connect_on_owner(self, *, reset: bool = False) -> None:
         """Connect to the MCP server.
 
         Establishes a connection to the MCP server, initializes the session,
@@ -844,14 +955,30 @@ class MCPTool:
                 break
             params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)
 
+    async def _close_on_owner(self) -> None:
+        """Unwind the exit stack and mark the tool as disconnected.
+
+        Intended to run on the lifecycle owner task, where the stack's contexts were entered.
+        """
+        try:
+            await self._safe_close_exit_stack()
+        finally:
+            # Reset even when cleanup raises so a stale session is never reported as connected.
+            self._exit_stack = AsyncExitStack()
+            self.session = None
+            self.is_connected = False
+
     async def close(self) -> None:
         """Disconnect from the MCP server.
 
         Closes the connection and cleans up resources.
         """
-        await self._safe_close_exit_stack()
-        self.session = None
-        self.is_connected = False
+        if self._is_lifecycle_owner_task():
+            await self._close_on_owner()
+            return
+
+        async with self._lifecycle_request_lock:
+            await self._run_on_lifecycle_owner("close")
 
     @abstractmethod
     def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
@@ -1043,7 +1170,7 @@ class MCPTool:
         except ToolException:
             raise
         except Exception as ex:
-            await self._safe_close_exit_stack()
+            await self.close()
             raise ToolExecutionException("Failed to enter context manager.", inner_exception=ex) from ex
 
     async def __aexit__(
diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py
index 70aff972fe..b29ec1a794 100644
--- a/python/packages/core/tests/core/test_mcp.py
+++ b/python/packages/core/tests/core/test_mcp.py
@@ -2525,67 +2525,193 @@ async def test_mcp_tool_get_prompt_reconnection_on_closed_resource_error():
     assert "failed to reconnect" in str(exc_info.value).lower()
 
 
-async def test_mcp_tool_reconnection_handles_cross_task_cancel_scope_error():
-    """Test that reconnection gracefully handles anyio cancel scope errors.
+async def test_mcp_tool_close_cleans_up_in_original_task(caplog):
+    """Closing an MCP tool from another task should still unwind contexts in the owner task."""
+    import asyncio
 
-    This tests the fix for the bug where calling connect(reset=True) from a
-    different task than where the connection was originally established would
-    cause: RuntimeError: Attempted to exit cancel scope in a different task
-    than it was entered in
+    class TaskBoundTransportContext:
+        def __init__(self) -> None:
+            self.enter_task = None
+            self.exit_task = None
+            self.closed_cleanly = False
 
-    This happens when using multiple MCP tools with AG-UI streaming - the first
-    tool call succeeds, but when the connection closes, the second tool call
-    triggers a reconnection from within the streaming loop (a different task).
-    """
-    from contextlib import AsyncExitStack
+        async def __aenter__(self):
+            self.enter_task = asyncio.current_task()
+            return (Mock(), Mock())
 
-    from agent_framework._mcp import MCPStdioTool
+        async def __aexit__(self, exc_type, exc, tb):
+            self.exit_task = asyncio.current_task()
+            if self.exit_task is not self.enter_task:
+                raise RuntimeError("Attempted to exit cancel scope in a different task than it was entered in")
+            self.closed_cleanly = True
+            return
 
-    # Use load_tools=False and load_prompts=False to avoid triggering them during connect()
-    tool = MCPStdioTool(
+    tool = MCPStreamableHTTPTool(
         name="test_server",
-        command="test_command",
-        args=["arg1"],
+        url="https://example.com/mcp",
         load_tools=False,
         load_prompts=False,
     )
 
-    # Mock the exit stack to raise the cross-task cancel scope error
-    mock_exit_stack = AsyncMock(spec=AsyncExitStack)
-    mock_exit_stack.aclose = AsyncMock(
-        side_effect=RuntimeError("Attempted to exit cancel scope in a different task than it was entered in")
-    )
-    tool._exit_stack = mock_exit_stack
-    tool.session = Mock()
-    tool.is_connected = True
+    transport_context = TaskBoundTransportContext()
+    mock_session = Mock()
+    mock_session._request_id = 1
+    mock_session.initialize = AsyncMock()
 
-    # Mock get_mcp_client to return a mock transport
-    mock_transport = (Mock(), Mock())
-    mock_context = AsyncMock()
-    mock_context.__aenter__ = AsyncMock(return_value=mock_transport)
-    mock_context.__aexit__ = AsyncMock()
+    mock_session_context = AsyncMock()
+    mock_session_context.__aenter__ = AsyncMock(return_value=mock_session)
+    mock_session_context.__aexit__ = AsyncMock(return_value=None)
 
     with (
-        patch.object(tool, "get_mcp_client", return_value=mock_context),
-        patch("agent_framework._mcp.ClientSession") as mock_session_class,
+        patch.object(tool, "get_mcp_client", return_value=transport_context),
+        patch("agent_framework._mcp.ClientSession", return_value=mock_session_context),
     ):
-        mock_session = Mock()
-        mock_session._request_id = 1
-        mock_session.initialize = AsyncMock()
-        mock_session.set_logging_level = AsyncMock()
-        mock_session_context = AsyncMock()
-        mock_session_context.__aenter__ = AsyncMock(return_value=mock_session)
-        mock_session_context.__aexit__ = AsyncMock()
-        mock_session_class.return_value = mock_session_context
+        await asyncio.create_task(tool.connect())
 
-        # This should NOT raise even though aclose() raised the cancel scope error
-        # The _safe_close_exit_stack method should catch and log the error
-        await tool.connect(reset=True)
+        caplog.clear()
+        with caplog.at_level(logging.WARNING, logger=logger.name):
+            await tool.close()
 
-        # Verify a new exit stack was created (the old mock was replaced)
-        assert tool._exit_stack is not mock_exit_stack
-        assert tool.session is not None
+        assert transport_context.closed_cleanly is True
+        assert transport_context.exit_task is transport_context.enter_task
+        assert not any("cancel scope" in record.getMessage().lower() for record in caplog.records)
+
+
+async def test_mcp_tool_connect_reset_cleans_up_in_original_task(caplog):
+    """Resetting an MCP tool from another task should unwind and reconnect on the owner task."""
+    import asyncio
+
+    class TaskBoundTransportContext:
+        def __init__(self) -> None:
+            self.enter_task = None
+            self.exit_task = None
+            self.closed_cleanly = False
+
+        async def __aenter__(self):
+            self.enter_task = asyncio.current_task()
+            return (Mock(), Mock())
+
+        async def __aexit__(self, exc_type, exc, tb):
+            self.exit_task = asyncio.current_task()
+            if self.exit_task is not self.enter_task:
+                raise RuntimeError("Attempted to exit cancel scope in a different task than it was entered in")
+            self.closed_cleanly = True
+            return
+
+    tool = MCPStreamableHTTPTool(
+        name="test_server",
+        url="https://example.com/mcp",
+        load_tools=False,
+        load_prompts=False,
+    )
+
+    transport_contexts = [TaskBoundTransportContext(), TaskBoundTransportContext()]
+    sessions = []
+    session_contexts = []
+    for _ in range(2):
+        session = Mock()
+        session._request_id = 1
+        session.initialize = AsyncMock()
+        session.set_logging_level = AsyncMock()
+        sessions.append(session)
+
+        session_context = AsyncMock()
+        session_context.__aenter__ = AsyncMock(return_value=session)
+        session_context.__aexit__ = AsyncMock(return_value=None)
+        session_contexts.append(session_context)
+
+    with (
+        patch.object(tool, "get_mcp_client", side_effect=transport_contexts),
+        patch("agent_framework._mcp.ClientSession", side_effect=session_contexts),
+    ):
+        await tool.connect()
+
+        caplog.clear()
+        with caplog.at_level(logging.WARNING, logger=logger.name):
+            await asyncio.create_task(tool.connect(reset=True))
+
+        assert transport_contexts[0].closed_cleanly is True
+        assert transport_contexts[0].exit_task is transport_contexts[0].enter_task
+        assert transport_contexts[1].enter_task is transport_contexts[0].enter_task
+        assert tool.session is sessions[1]
         assert tool.is_connected is True
+        assert not any("cancel scope" in record.getMessage().lower() for record in caplog.records)
+
+        await tool.close()
+
+
+async def test_mcp_tool_connect_from_lifecycle_owner_bypasses_request_lock() -> None:
+    """connect(reset=True) should bypass the request queue when already on the owner task."""
+    import asyncio
+
+    tool = MCPStreamableHTTPTool(
+        name="test_server",
+        url="https://example.com/mcp",
+        load_tools=False,
+        load_prompts=False,
+    )
+
+    async def connect_from_owner_task() -> None:
+        tool._lifecycle_owner_task = asyncio.current_task()
+        try:
+            async with tool._lifecycle_request_lock:
+                await tool.connect(reset=True)
+        finally:
+            tool._lifecycle_owner_task = None
+
+    with patch.object(tool, "_connect_on_owner", AsyncMock()) as mock_connect_on_owner:
+        await asyncio.wait_for(connect_from_owner_task(), timeout=0.1)
+
+    mock_connect_on_owner.assert_awaited_once_with(reset=True)
+
+
+async def test_mcp_tool_close_from_lifecycle_owner_bypasses_request_lock() -> None:
+    """close() should bypass the request queue when already on the owner task."""
+    import asyncio
+
+    tool = MCPStreamableHTTPTool(
+        name="test_server",
+        url="https://example.com/mcp",
+        load_tools=False,
+        load_prompts=False,
+    )
+
+    async def close_from_owner_task() -> None:
+        tool._lifecycle_owner_task = asyncio.current_task()
+        try:
+            async with tool._lifecycle_request_lock:
+                await tool.close()
+        finally:
+            tool._lifecycle_owner_task = None
+
+    with patch.object(tool, "_close_on_owner", AsyncMock()) as mock_close_on_owner:
+        await asyncio.wait_for(close_from_owner_task(), timeout=0.1)
+
+    mock_close_on_owner.assert_awaited_once_with()
+
+
+async def test_mcp_tool_lifecycle_owner_base_exception_fails_pending_request() -> None:
+    """A non-Exception error on the owner task should reach the caller instead of hanging it."""
+    import asyncio
+
+    class FatalLifecycleError(BaseException):
+        """Stand-in for errors that do not derive from Exception."""
+
+    tool = MCPStreamableHTTPTool(
+        name="test_server",
+        url="https://example.com/mcp",
+        load_tools=False,
+        load_prompts=False,
+    )
+
+    with (
+        patch.object(tool, "_connect_on_owner", AsyncMock(side_effect=FatalLifecycleError("boom"))),
+        pytest.raises(FatalLifecycleError),
+    ):
+        await asyncio.wait_for(tool.connect(), timeout=1)
+
+    assert tool._lifecycle_owner_task is None
+    assert tool._lifecycle_queue is None
 
 
 async def test_mcp_tool_safe_close_reraises_other_runtime_errors():