Python: keep MCP cleanup on the owner task (#4687)

* Python: keep MCP cleanup on owner task

* Avoid MCP owner task deadlocks

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix MCP owner-task timeout tests

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Eduard van Valkenburg
2026-03-14 14:54:05 +01:00
committed by GitHub
Unverified
parent 2f4c4aa614
commit 1b7940c91e
2 changed files with 256 additions and 67 deletions
+109 -22
View File
@@ -480,6 +480,10 @@ class MCPTool:
self.load_prompts_flag = load_prompts
self.parse_prompt_results = parse_prompt_results
self._exit_stack = AsyncExitStack()
self._lifecycle_lock = asyncio.Lock()
self._lifecycle_request_lock = asyncio.Lock()
self._lifecycle_queue: asyncio.Queue[tuple[str, bool, asyncio.Future[None]]] | None = None
self._lifecycle_owner_task: asyncio.Task[None] | None = None
self.session = session
self.request_timeout = request_timeout
self.client = client
@@ -510,39 +514,113 @@ class MCPTool:
filtered_functions.append(func)
return filtered_functions
async def _ensure_lifecycle_owner(self) -> None:
async with self._lifecycle_lock:
if self._lifecycle_owner_task is not None and not self._lifecycle_owner_task.done():
return
self._lifecycle_queue = asyncio.Queue()
self._lifecycle_owner_task = asyncio.create_task(
self._run_lifecycle_owner(),
name=f"mcp-lifecycle:{self.name}",
)
async def _run_lifecycle_owner(self) -> None:
queue = self._lifecycle_queue
if queue is None:
return
stop_error: BaseException | None = None
try:
while True:
action, reset, future = await queue.get()
try:
if action == "connect":
await self._connect_on_owner(reset=reset)
elif action == "close":
await self._close_on_owner()
else:
raise RuntimeError(f"Unknown MCP lifecycle action: {action}")
except asyncio.CancelledError as ex:
stop_error = ex
if not future.done():
future.set_exception(ex)
raise
except Exception as ex:
if not future.done():
future.set_exception(ex)
else:
if not future.done():
future.set_result(None)
if action == "close":
return
except asyncio.CancelledError as ex:
stop_error = ex
raise
finally:
while True:
try:
_, _, future = queue.get_nowait()
except asyncio.QueueEmpty:
break
if not future.done():
future.set_exception(stop_error or RuntimeError("MCP lifecycle owner stopped unexpectedly."))
self._lifecycle_queue = None
self._lifecycle_owner_task = None
def _is_lifecycle_owner_task(self) -> bool:
owner_task = self._lifecycle_owner_task
return owner_task is not None and asyncio.current_task() is owner_task
async def _run_on_lifecycle_owner(self, action: str, *, reset: bool = False) -> None:
await self._ensure_lifecycle_owner()
if self._is_lifecycle_owner_task():
if action == "connect":
await self._connect_on_owner(reset=reset)
elif action == "close":
await self._close_on_owner()
else:
raise RuntimeError(f"Unknown MCP lifecycle action: {action}")
return
queue = self._lifecycle_queue
if queue is None:
raise RuntimeError("MCP lifecycle owner is not available.")
future = asyncio.get_running_loop().create_future()
await queue.put((action, reset, future))
await future
async def _safe_close_exit_stack(self) -> None:
"""Safely close the exit stack, handling cross-task boundary errors.
anyio's cancel scopes are bound to the task they were created in.
If aclose() is called from a different task (e.g., during streaming reconnection),
anyio will raise a RuntimeError or CancelledError. In this case, we log a warning
and allow garbage collection to clean up the resources.
Known error variants:
- "Attempted to exit cancel scope in a different task than it was entered in"
- "Attempted to exit a cancel scope that isn't the current task's current cancel scope"
- CancelledError from anyio cancel scope cleanup
"""
"""Safely close the exit stack, handling unexpected cleanup failures."""
try:
await self._exit_stack.aclose()
except RuntimeError as e:
error_msg = str(e).lower()
# Check for anyio cancel scope errors (multiple variants exist)
if "cancel scope" in error_msg:
logger.warning(
"Could not cleanly close MCP exit stack due to cancel scope error. "
"Old resources will be garbage collected. Error: %s",
"This indicates MCP lifecycle ownership was lost. Error: %s",
e,
)
else:
raise
except asyncio.CancelledError:
# CancelledError can occur during cleanup when cancel scopes are involved
logger.warning(
"Could not cleanly close MCP exit stack due to cancellation. Old resources will be garbage collected."
)
logger.warning("Could not cleanly close MCP exit stack because the lifecycle owner task was cancelled.")
async def connect(self, *, reset: bool = False) -> None:
if self._is_lifecycle_owner_task():
await self._connect_on_owner(reset=reset)
return
async with self._lifecycle_request_lock:
await self._run_on_lifecycle_owner("connect", reset=reset)
async def _connect_on_owner(self, *, reset: bool = False) -> None:
"""Connect to the MCP server.
Establishes a connection to the MCP server, initializes the session,
@@ -844,14 +922,23 @@ class MCPTool:
break
params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)
async def _close_on_owner(self) -> None:
await self._safe_close_exit_stack()
self._exit_stack = AsyncExitStack()
self.session = None
self.is_connected = False
async def close(self) -> None:
"""Disconnect from the MCP server.
Closes the connection and cleans up resources.
"""
await self._safe_close_exit_stack()
self.session = None
self.is_connected = False
if self._is_lifecycle_owner_task():
await self._close_on_owner()
return
async with self._lifecycle_request_lock:
await self._run_on_lifecycle_owner("close")
@abstractmethod
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
@@ -1043,7 +1130,7 @@ class MCPTool:
except ToolException:
raise
except Exception as ex:
await self._safe_close_exit_stack()
await self.close()
raise ToolExecutionException("Failed to enter context manager.", inner_exception=ex) from ex
async def __aexit__(
+147 -45
View File
@@ -2525,67 +2525,169 @@ async def test_mcp_tool_get_prompt_reconnection_on_closed_resource_error():
assert "failed to reconnect" in str(exc_info.value).lower()
async def test_mcp_tool_reconnection_handles_cross_task_cancel_scope_error():
"""Test that reconnection gracefully handles anyio cancel scope errors.
async def test_mcp_tool_close_cleans_up_in_original_task(caplog):
"""Closing an MCP tool from another task should still unwind contexts in the owner task."""
import asyncio
This tests the fix for the bug where calling connect(reset=True) from a
different task than where the connection was originally established would
cause: RuntimeError: Attempted to exit cancel scope in a different task
than it was entered in
class TaskBoundTransportContext:
def __init__(self) -> None:
self.enter_task = None
self.exit_task = None
self.closed_cleanly = False
This happens when using multiple MCP tools with AG-UI streaming - the first
tool call succeeds, but when the connection closes, the second tool call
triggers a reconnection from within the streaming loop (a different task).
"""
from contextlib import AsyncExitStack
async def __aenter__(self):
self.enter_task = asyncio.current_task()
return (Mock(), Mock())
from agent_framework._mcp import MCPStdioTool
async def __aexit__(self, exc_type, exc, tb):
self.exit_task = asyncio.current_task()
if self.exit_task is not self.enter_task:
raise RuntimeError("Attempted to exit cancel scope in a different task than it was entered in")
self.closed_cleanly = True
return
# Use load_tools=False and load_prompts=False to avoid triggering them during connect()
tool = MCPStdioTool(
tool = MCPStreamableHTTPTool(
name="test_server",
command="test_command",
args=["arg1"],
url="https://example.com/mcp",
load_tools=False,
load_prompts=False,
)
# Mock the exit stack to raise the cross-task cancel scope error
mock_exit_stack = AsyncMock(spec=AsyncExitStack)
mock_exit_stack.aclose = AsyncMock(
side_effect=RuntimeError("Attempted to exit cancel scope in a different task than it was entered in")
)
tool._exit_stack = mock_exit_stack
tool.session = Mock()
tool.is_connected = True
transport_context = TaskBoundTransportContext()
mock_session = Mock()
mock_session._request_id = 1
mock_session.initialize = AsyncMock()
# Mock get_mcp_client to return a mock transport
mock_transport = (Mock(), Mock())
mock_context = AsyncMock()
mock_context.__aenter__ = AsyncMock(return_value=mock_transport)
mock_context.__aexit__ = AsyncMock()
mock_session_context = AsyncMock()
mock_session_context.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_context.__aexit__ = AsyncMock(return_value=None)
with (
patch.object(tool, "get_mcp_client", return_value=mock_context),
patch("agent_framework._mcp.ClientSession") as mock_session_class,
patch.object(tool, "get_mcp_client", return_value=transport_context),
patch("agent_framework._mcp.ClientSession", return_value=mock_session_context),
):
mock_session = Mock()
mock_session._request_id = 1
mock_session.initialize = AsyncMock()
mock_session.set_logging_level = AsyncMock()
mock_session_context = AsyncMock()
mock_session_context.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_context.__aexit__ = AsyncMock()
mock_session_class.return_value = mock_session_context
await asyncio.create_task(tool.connect())
# This should NOT raise even though aclose() raised the cancel scope error
# The _safe_close_exit_stack method should catch and log the error
await tool.connect(reset=True)
caplog.clear()
with caplog.at_level(logging.WARNING, logger=logger.name):
await tool.close()
# Verify a new exit stack was created (the old mock was replaced)
assert tool._exit_stack is not mock_exit_stack
assert tool.session is not None
assert transport_context.closed_cleanly is True
assert transport_context.exit_task is transport_context.enter_task
assert not any("cancel scope" in record.getMessage().lower() for record in caplog.records)
async def test_mcp_tool_connect_reset_cleans_up_in_original_task(caplog):
"""Resetting an MCP tool from another task should unwind and reconnect on the owner task."""
import asyncio
class TaskBoundTransportContext:
def __init__(self) -> None:
self.enter_task = None
self.exit_task = None
self.closed_cleanly = False
async def __aenter__(self):
self.enter_task = asyncio.current_task()
return (Mock(), Mock())
async def __aexit__(self, exc_type, exc, tb):
self.exit_task = asyncio.current_task()
if self.exit_task is not self.enter_task:
raise RuntimeError("Attempted to exit cancel scope in a different task than it was entered in")
self.closed_cleanly = True
return
tool = MCPStreamableHTTPTool(
name="test_server",
url="https://example.com/mcp",
load_tools=False,
load_prompts=False,
)
transport_contexts = [TaskBoundTransportContext(), TaskBoundTransportContext()]
sessions = []
session_contexts = []
for _ in range(2):
session = Mock()
session._request_id = 1
session.initialize = AsyncMock()
session.set_logging_level = AsyncMock()
sessions.append(session)
session_context = AsyncMock()
session_context.__aenter__ = AsyncMock(return_value=session)
session_context.__aexit__ = AsyncMock(return_value=None)
session_contexts.append(session_context)
with (
patch.object(tool, "get_mcp_client", side_effect=transport_contexts),
patch("agent_framework._mcp.ClientSession", side_effect=session_contexts),
):
await tool.connect()
caplog.clear()
with caplog.at_level(logging.WARNING, logger=logger.name):
await asyncio.create_task(tool.connect(reset=True))
assert transport_contexts[0].closed_cleanly is True
assert transport_contexts[0].exit_task is transport_contexts[0].enter_task
assert transport_contexts[1].enter_task is transport_contexts[0].enter_task
assert tool.session is sessions[1]
assert tool.is_connected is True
assert not any("cancel scope" in record.getMessage().lower() for record in caplog.records)
await tool.close()
async def test_mcp_tool_connect_from_lifecycle_owner_bypasses_request_lock() -> None:
"""connect(reset=True) should bypass the request queue when already on the owner task."""
import asyncio
tool = MCPStreamableHTTPTool(
name="test_server",
url="https://example.com/mcp",
load_tools=False,
load_prompts=False,
)
async def connect_from_owner_task() -> None:
tool._lifecycle_owner_task = asyncio.current_task()
try:
async with tool._lifecycle_request_lock:
await tool.connect(reset=True)
finally:
tool._lifecycle_owner_task = None
with patch.object(tool, "_connect_on_owner", AsyncMock()) as mock_connect_on_owner:
await asyncio.wait_for(connect_from_owner_task(), timeout=0.1)
mock_connect_on_owner.assert_awaited_once_with(reset=True)
async def test_mcp_tool_close_from_lifecycle_owner_bypasses_request_lock() -> None:
"""close() should bypass the request queue when already on the owner task."""
import asyncio
tool = MCPStreamableHTTPTool(
name="test_server",
url="https://example.com/mcp",
load_tools=False,
load_prompts=False,
)
async def close_from_owner_task() -> None:
tool._lifecycle_owner_task = asyncio.current_task()
try:
async with tool._lifecycle_request_lock:
await tool.close()
finally:
tool._lifecycle_owner_task = None
with patch.object(tool, "_close_on_owner", AsyncMock()) as mock_close_on_owner:
await asyncio.wait_for(close_from_owner_task(), timeout=0.1)
mock_close_on_owner.assert_awaited_once_with()
async def test_mcp_tool_safe_close_reraises_other_runtime_errors():