mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: keep MCP cleanup on the owner task (#4687)
* Python: keep MCP cleanup on owner task * Avoid MCP owner task deadlocks Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix MCP owner-task timeout tests Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
2f4c4aa614
commit
1b7940c91e
@@ -480,6 +480,10 @@ class MCPTool:
|
||||
self.load_prompts_flag = load_prompts
|
||||
self.parse_prompt_results = parse_prompt_results
|
||||
self._exit_stack = AsyncExitStack()
|
||||
self._lifecycle_lock = asyncio.Lock()
|
||||
self._lifecycle_request_lock = asyncio.Lock()
|
||||
self._lifecycle_queue: asyncio.Queue[tuple[str, bool, asyncio.Future[None]]] | None = None
|
||||
self._lifecycle_owner_task: asyncio.Task[None] | None = None
|
||||
self.session = session
|
||||
self.request_timeout = request_timeout
|
||||
self.client = client
|
||||
@@ -510,39 +514,113 @@ class MCPTool:
|
||||
filtered_functions.append(func)
|
||||
return filtered_functions
|
||||
|
||||
async def _ensure_lifecycle_owner(self) -> None:
|
||||
async with self._lifecycle_lock:
|
||||
if self._lifecycle_owner_task is not None and not self._lifecycle_owner_task.done():
|
||||
return
|
||||
|
||||
self._lifecycle_queue = asyncio.Queue()
|
||||
self._lifecycle_owner_task = asyncio.create_task(
|
||||
self._run_lifecycle_owner(),
|
||||
name=f"mcp-lifecycle:{self.name}",
|
||||
)
|
||||
|
||||
async def _run_lifecycle_owner(self) -> None:
|
||||
queue = self._lifecycle_queue
|
||||
if queue is None:
|
||||
return
|
||||
|
||||
stop_error: BaseException | None = None
|
||||
try:
|
||||
while True:
|
||||
action, reset, future = await queue.get()
|
||||
|
||||
try:
|
||||
if action == "connect":
|
||||
await self._connect_on_owner(reset=reset)
|
||||
elif action == "close":
|
||||
await self._close_on_owner()
|
||||
else:
|
||||
raise RuntimeError(f"Unknown MCP lifecycle action: {action}")
|
||||
except asyncio.CancelledError as ex:
|
||||
stop_error = ex
|
||||
if not future.done():
|
||||
future.set_exception(ex)
|
||||
raise
|
||||
except Exception as ex:
|
||||
if not future.done():
|
||||
future.set_exception(ex)
|
||||
else:
|
||||
if not future.done():
|
||||
future.set_result(None)
|
||||
|
||||
if action == "close":
|
||||
return
|
||||
except asyncio.CancelledError as ex:
|
||||
stop_error = ex
|
||||
raise
|
||||
finally:
|
||||
while True:
|
||||
try:
|
||||
_, _, future = queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
if not future.done():
|
||||
future.set_exception(stop_error or RuntimeError("MCP lifecycle owner stopped unexpectedly."))
|
||||
|
||||
self._lifecycle_queue = None
|
||||
self._lifecycle_owner_task = None
|
||||
|
||||
def _is_lifecycle_owner_task(self) -> bool:
|
||||
owner_task = self._lifecycle_owner_task
|
||||
return owner_task is not None and asyncio.current_task() is owner_task
|
||||
|
||||
async def _run_on_lifecycle_owner(self, action: str, *, reset: bool = False) -> None:
|
||||
await self._ensure_lifecycle_owner()
|
||||
|
||||
if self._is_lifecycle_owner_task():
|
||||
if action == "connect":
|
||||
await self._connect_on_owner(reset=reset)
|
||||
elif action == "close":
|
||||
await self._close_on_owner()
|
||||
else:
|
||||
raise RuntimeError(f"Unknown MCP lifecycle action: {action}")
|
||||
return
|
||||
|
||||
queue = self._lifecycle_queue
|
||||
if queue is None:
|
||||
raise RuntimeError("MCP lifecycle owner is not available.")
|
||||
|
||||
future = asyncio.get_running_loop().create_future()
|
||||
await queue.put((action, reset, future))
|
||||
await future
|
||||
|
||||
async def _safe_close_exit_stack(self) -> None:
|
||||
"""Safely close the exit stack, handling cross-task boundary errors.
|
||||
|
||||
anyio's cancel scopes are bound to the task they were created in.
|
||||
If aclose() is called from a different task (e.g., during streaming reconnection),
|
||||
anyio will raise a RuntimeError or CancelledError. In this case, we log a warning
|
||||
and allow garbage collection to clean up the resources.
|
||||
|
||||
Known error variants:
|
||||
- "Attempted to exit cancel scope in a different task than it was entered in"
|
||||
- "Attempted to exit a cancel scope that isn't the current task's current cancel scope"
|
||||
- CancelledError from anyio cancel scope cleanup
|
||||
"""
|
||||
"""Safely close the exit stack, handling unexpected cleanup failures."""
|
||||
try:
|
||||
await self._exit_stack.aclose()
|
||||
except RuntimeError as e:
|
||||
error_msg = str(e).lower()
|
||||
# Check for anyio cancel scope errors (multiple variants exist)
|
||||
if "cancel scope" in error_msg:
|
||||
logger.warning(
|
||||
"Could not cleanly close MCP exit stack due to cancel scope error. "
|
||||
"Old resources will be garbage collected. Error: %s",
|
||||
"This indicates MCP lifecycle ownership was lost. Error: %s",
|
||||
e,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
# CancelledError can occur during cleanup when cancel scopes are involved
|
||||
logger.warning(
|
||||
"Could not cleanly close MCP exit stack due to cancellation. Old resources will be garbage collected."
|
||||
)
|
||||
logger.warning("Could not cleanly close MCP exit stack because the lifecycle owner task was cancelled.")
|
||||
|
||||
async def connect(self, *, reset: bool = False) -> None:
|
||||
if self._is_lifecycle_owner_task():
|
||||
await self._connect_on_owner(reset=reset)
|
||||
return
|
||||
|
||||
async with self._lifecycle_request_lock:
|
||||
await self._run_on_lifecycle_owner("connect", reset=reset)
|
||||
|
||||
async def _connect_on_owner(self, *, reset: bool = False) -> None:
|
||||
"""Connect to the MCP server.
|
||||
|
||||
Establishes a connection to the MCP server, initializes the session,
|
||||
@@ -844,14 +922,23 @@ class MCPTool:
|
||||
break
|
||||
params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)
|
||||
|
||||
async def _close_on_owner(self) -> None:
|
||||
await self._safe_close_exit_stack()
|
||||
self._exit_stack = AsyncExitStack()
|
||||
self.session = None
|
||||
self.is_connected = False
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Disconnect from the MCP server.
|
||||
|
||||
Closes the connection and cleans up resources.
|
||||
"""
|
||||
await self._safe_close_exit_stack()
|
||||
self.session = None
|
||||
self.is_connected = False
|
||||
if self._is_lifecycle_owner_task():
|
||||
await self._close_on_owner()
|
||||
return
|
||||
|
||||
async with self._lifecycle_request_lock:
|
||||
await self._run_on_lifecycle_owner("close")
|
||||
|
||||
@abstractmethod
|
||||
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
|
||||
@@ -1043,7 +1130,7 @@ class MCPTool:
|
||||
except ToolException:
|
||||
raise
|
||||
except Exception as ex:
|
||||
await self._safe_close_exit_stack()
|
||||
await self.close()
|
||||
raise ToolExecutionException("Failed to enter context manager.", inner_exception=ex) from ex
|
||||
|
||||
async def __aexit__(
|
||||
|
||||
@@ -2525,67 +2525,169 @@ async def test_mcp_tool_get_prompt_reconnection_on_closed_resource_error():
|
||||
assert "failed to reconnect" in str(exc_info.value).lower()
|
||||
|
||||
|
||||
async def test_mcp_tool_reconnection_handles_cross_task_cancel_scope_error():
|
||||
"""Test that reconnection gracefully handles anyio cancel scope errors.
|
||||
async def test_mcp_tool_close_cleans_up_in_original_task(caplog):
|
||||
"""Closing an MCP tool from another task should still unwind contexts in the owner task."""
|
||||
import asyncio
|
||||
|
||||
This tests the fix for the bug where calling connect(reset=True) from a
|
||||
different task than where the connection was originally established would
|
||||
cause: RuntimeError: Attempted to exit cancel scope in a different task
|
||||
than it was entered in
|
||||
class TaskBoundTransportContext:
|
||||
def __init__(self) -> None:
|
||||
self.enter_task = None
|
||||
self.exit_task = None
|
||||
self.closed_cleanly = False
|
||||
|
||||
This happens when using multiple MCP tools with AG-UI streaming - the first
|
||||
tool call succeeds, but when the connection closes, the second tool call
|
||||
triggers a reconnection from within the streaming loop (a different task).
|
||||
"""
|
||||
from contextlib import AsyncExitStack
|
||||
async def __aenter__(self):
|
||||
self.enter_task = asyncio.current_task()
|
||||
return (Mock(), Mock())
|
||||
|
||||
from agent_framework._mcp import MCPStdioTool
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
self.exit_task = asyncio.current_task()
|
||||
if self.exit_task is not self.enter_task:
|
||||
raise RuntimeError("Attempted to exit cancel scope in a different task than it was entered in")
|
||||
self.closed_cleanly = True
|
||||
return
|
||||
|
||||
# Use load_tools=False and load_prompts=False to avoid triggering them during connect()
|
||||
tool = MCPStdioTool(
|
||||
tool = MCPStreamableHTTPTool(
|
||||
name="test_server",
|
||||
command="test_command",
|
||||
args=["arg1"],
|
||||
url="https://example.com/mcp",
|
||||
load_tools=False,
|
||||
load_prompts=False,
|
||||
)
|
||||
|
||||
# Mock the exit stack to raise the cross-task cancel scope error
|
||||
mock_exit_stack = AsyncMock(spec=AsyncExitStack)
|
||||
mock_exit_stack.aclose = AsyncMock(
|
||||
side_effect=RuntimeError("Attempted to exit cancel scope in a different task than it was entered in")
|
||||
)
|
||||
tool._exit_stack = mock_exit_stack
|
||||
tool.session = Mock()
|
||||
tool.is_connected = True
|
||||
transport_context = TaskBoundTransportContext()
|
||||
mock_session = Mock()
|
||||
mock_session._request_id = 1
|
||||
mock_session.initialize = AsyncMock()
|
||||
|
||||
# Mock get_mcp_client to return a mock transport
|
||||
mock_transport = (Mock(), Mock())
|
||||
mock_context = AsyncMock()
|
||||
mock_context.__aenter__ = AsyncMock(return_value=mock_transport)
|
||||
mock_context.__aexit__ = AsyncMock()
|
||||
mock_session_context = AsyncMock()
|
||||
mock_session_context.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session_context.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with (
|
||||
patch.object(tool, "get_mcp_client", return_value=mock_context),
|
||||
patch("agent_framework._mcp.ClientSession") as mock_session_class,
|
||||
patch.object(tool, "get_mcp_client", return_value=transport_context),
|
||||
patch("agent_framework._mcp.ClientSession", return_value=mock_session_context),
|
||||
):
|
||||
mock_session = Mock()
|
||||
mock_session._request_id = 1
|
||||
mock_session.initialize = AsyncMock()
|
||||
mock_session.set_logging_level = AsyncMock()
|
||||
mock_session_context = AsyncMock()
|
||||
mock_session_context.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session_context.__aexit__ = AsyncMock()
|
||||
mock_session_class.return_value = mock_session_context
|
||||
await asyncio.create_task(tool.connect())
|
||||
|
||||
# This should NOT raise even though aclose() raised the cancel scope error
|
||||
# The _safe_close_exit_stack method should catch and log the error
|
||||
await tool.connect(reset=True)
|
||||
caplog.clear()
|
||||
with caplog.at_level(logging.WARNING, logger=logger.name):
|
||||
await tool.close()
|
||||
|
||||
# Verify a new exit stack was created (the old mock was replaced)
|
||||
assert tool._exit_stack is not mock_exit_stack
|
||||
assert tool.session is not None
|
||||
assert transport_context.closed_cleanly is True
|
||||
assert transport_context.exit_task is transport_context.enter_task
|
||||
assert not any("cancel scope" in record.getMessage().lower() for record in caplog.records)
|
||||
|
||||
|
||||
async def test_mcp_tool_connect_reset_cleans_up_in_original_task(caplog):
|
||||
"""Resetting an MCP tool from another task should unwind and reconnect on the owner task."""
|
||||
import asyncio
|
||||
|
||||
class TaskBoundTransportContext:
|
||||
def __init__(self) -> None:
|
||||
self.enter_task = None
|
||||
self.exit_task = None
|
||||
self.closed_cleanly = False
|
||||
|
||||
async def __aenter__(self):
|
||||
self.enter_task = asyncio.current_task()
|
||||
return (Mock(), Mock())
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
self.exit_task = asyncio.current_task()
|
||||
if self.exit_task is not self.enter_task:
|
||||
raise RuntimeError("Attempted to exit cancel scope in a different task than it was entered in")
|
||||
self.closed_cleanly = True
|
||||
return
|
||||
|
||||
tool = MCPStreamableHTTPTool(
|
||||
name="test_server",
|
||||
url="https://example.com/mcp",
|
||||
load_tools=False,
|
||||
load_prompts=False,
|
||||
)
|
||||
|
||||
transport_contexts = [TaskBoundTransportContext(), TaskBoundTransportContext()]
|
||||
sessions = []
|
||||
session_contexts = []
|
||||
for _ in range(2):
|
||||
session = Mock()
|
||||
session._request_id = 1
|
||||
session.initialize = AsyncMock()
|
||||
session.set_logging_level = AsyncMock()
|
||||
sessions.append(session)
|
||||
|
||||
session_context = AsyncMock()
|
||||
session_context.__aenter__ = AsyncMock(return_value=session)
|
||||
session_context.__aexit__ = AsyncMock(return_value=None)
|
||||
session_contexts.append(session_context)
|
||||
|
||||
with (
|
||||
patch.object(tool, "get_mcp_client", side_effect=transport_contexts),
|
||||
patch("agent_framework._mcp.ClientSession", side_effect=session_contexts),
|
||||
):
|
||||
await tool.connect()
|
||||
|
||||
caplog.clear()
|
||||
with caplog.at_level(logging.WARNING, logger=logger.name):
|
||||
await asyncio.create_task(tool.connect(reset=True))
|
||||
|
||||
assert transport_contexts[0].closed_cleanly is True
|
||||
assert transport_contexts[0].exit_task is transport_contexts[0].enter_task
|
||||
assert transport_contexts[1].enter_task is transport_contexts[0].enter_task
|
||||
assert tool.session is sessions[1]
|
||||
assert tool.is_connected is True
|
||||
assert not any("cancel scope" in record.getMessage().lower() for record in caplog.records)
|
||||
|
||||
await tool.close()
|
||||
|
||||
|
||||
async def test_mcp_tool_connect_from_lifecycle_owner_bypasses_request_lock() -> None:
|
||||
"""connect(reset=True) should bypass the request queue when already on the owner task."""
|
||||
import asyncio
|
||||
|
||||
tool = MCPStreamableHTTPTool(
|
||||
name="test_server",
|
||||
url="https://example.com/mcp",
|
||||
load_tools=False,
|
||||
load_prompts=False,
|
||||
)
|
||||
|
||||
async def connect_from_owner_task() -> None:
|
||||
tool._lifecycle_owner_task = asyncio.current_task()
|
||||
try:
|
||||
async with tool._lifecycle_request_lock:
|
||||
await tool.connect(reset=True)
|
||||
finally:
|
||||
tool._lifecycle_owner_task = None
|
||||
|
||||
with patch.object(tool, "_connect_on_owner", AsyncMock()) as mock_connect_on_owner:
|
||||
await asyncio.wait_for(connect_from_owner_task(), timeout=0.1)
|
||||
|
||||
mock_connect_on_owner.assert_awaited_once_with(reset=True)
|
||||
|
||||
|
||||
async def test_mcp_tool_close_from_lifecycle_owner_bypasses_request_lock() -> None:
|
||||
"""close() should bypass the request queue when already on the owner task."""
|
||||
import asyncio
|
||||
|
||||
tool = MCPStreamableHTTPTool(
|
||||
name="test_server",
|
||||
url="https://example.com/mcp",
|
||||
load_tools=False,
|
||||
load_prompts=False,
|
||||
)
|
||||
|
||||
async def close_from_owner_task() -> None:
|
||||
tool._lifecycle_owner_task = asyncio.current_task()
|
||||
try:
|
||||
async with tool._lifecycle_request_lock:
|
||||
await tool.close()
|
||||
finally:
|
||||
tool._lifecycle_owner_task = None
|
||||
|
||||
with patch.object(tool, "_close_on_owner", AsyncMock()) as mock_close_on_owner:
|
||||
await asyncio.wait_for(close_from_owner_task(), timeout=0.1)
|
||||
|
||||
mock_close_on_owner.assert_awaited_once_with()
|
||||
|
||||
|
||||
async def test_mcp_tool_safe_close_reraises_other_runtime_errors():
|
||||
|
||||
Reference in New Issue
Block a user