mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: fix: prevent MCP message_handler deadlock on notification reload (#4866)
* fix(python): prevent MCP message_handler deadlock on notification reload When an MCP server sends a notifications/tools/list_changed or notifications/prompts/list_changed notification, the message_handler previously awaited load_tools()/load_prompts() directly. Since the handler runs on the MCP SDK's single-threaded receive loop, this caused a deadlock: load_tools() sends a list_tools request and waits for its response, but the receive loop cannot deliver that response while blocked in the handler. This manifested as a timeout in call_tool(), which then surfaced as "Error: Function failed." to the model instead of the real tool output. The MATLAB MCP server reliably triggers this because it sends a tools/list_changed notification during tool execution. Fix: schedule reloads as background asyncio.Tasks via a new _schedule_reload() helper, freeing the receive loop immediately. Fixes #4828 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address PR review feedback: fix exc_info, coalesce reloads, shutdown cleanup, tests - Fix exc_info=exc -> exc_info=True in _schedule_reload and message_handler - Tighten _schedule_reload param type from Any to Coroutine[Any, Any, None] - Coalesce reloads: cancel-and-replace per reload kind to prevent unbounded growth - Cancel pending reload tasks in _close_on_owner before tearing down session - Re-raise CancelledError in _safe_reload to respect task cancellation - Replace flaky asyncio.sleep(0) with asyncio.wait_for/gather in tests - Add caplog assertions to verify reload failure is actually logged - Assert _pending_reload_tasks cleanup on error path Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix: address review comments on MCP reload handling - Fix exc_info=True -> exc_info=message in message_handler error logging, since the handler is not called from an except block - Await cancelled reload tasks in _close_on_owner before tearing down the session to avoid 'Task was destroyed but pending' warnings - Add cancel-and-replace test verifying duplicate notifications cancel the first reload task and only keep one in flight Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix: remove Task.cancelling() call for Python 3.10 compat Task.cancelling() was added in Python 3.11. Replace with awaiting the task and checking cancelled() instead. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Add debug log when cancelling superseded reload task Log at DEBUG level when a new notification cancels an in-flight reload task, improving observability of the cancel-and-replace behavior. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <copilot@github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
574631671d
commit
7d23582e2b
@@ -10,7 +10,7 @@ import logging
|
||||
import re
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable, Collection, Sequence
|
||||
from collections.abc import Callable, Collection, Coroutine, Sequence
|
||||
from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore
|
||||
from datetime import timedelta
|
||||
from functools import partial
|
||||
@@ -264,6 +264,7 @@ class MCPTool:
|
||||
self.is_connected: bool = False
|
||||
self._tools_loaded: bool = False
|
||||
self._prompts_loaded: bool = False
|
||||
self._pending_reload_tasks: set[asyncio.Task[None]] = set()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"MCPTool(name={self.name}, description={self.description})"
|
||||
@@ -905,12 +906,47 @@ class MCPTool:
|
||||
if isinstance(message, types.ServerNotification):
|
||||
match message.root.method:
|
||||
case "notifications/tools/list_changed":
|
||||
await self.load_tools()
|
||||
self._schedule_reload(self.load_tools())
|
||||
case "notifications/prompts/list_changed":
|
||||
await self.load_prompts()
|
||||
self._schedule_reload(self.load_prompts())
|
||||
case _:
|
||||
logger.debug("Unhandled notification: %s", message.root.method)
|
||||
|
||||
def _schedule_reload(self, coro: Coroutine[Any, Any, None]) -> None:
|
||||
"""Schedule a reload coroutine as a background task.
|
||||
|
||||
Reloads (load_tools / load_prompts) triggered by MCP server
|
||||
notifications must NOT be awaited inside the message handler because
|
||||
the handler runs on the MCP SDK's single-threaded receive loop.
|
||||
Awaiting a session request (e.g. ``list_tools``) from within that loop
|
||||
deadlocks: the receive loop cannot read the response while it is
|
||||
blocked waiting for the handler to return.
|
||||
|
||||
Instead we fire the reload as an independent ``asyncio.Task`` and keep
|
||||
a strong reference in ``_pending_reload_tasks`` so it is not garbage-
|
||||
collected before completion. Only one reload per kind (tools / prompts)
|
||||
is kept in flight; a new notification cancels the previous pending task
|
||||
for the same coroutine name to avoid unbounded growth.
|
||||
"""
|
||||
# Cancel-and-replace: only one reload per kind should be in flight.
|
||||
reload_name = f"mcp-reload:{self.name}:{coro.__qualname__}"
|
||||
for existing in list(self._pending_reload_tasks):
|
||||
if existing.get_name() == reload_name and not existing.done():
|
||||
logger.debug("Cancelling in-flight reload %s; superseded by new notification", reload_name)
|
||||
existing.cancel()
|
||||
|
||||
async def _safe_reload() -> None:
|
||||
try:
|
||||
await coro
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.warning("Background MCP reload failed", exc_info=True)
|
||||
|
||||
task = asyncio.create_task(_safe_reload(), name=reload_name)
|
||||
self._pending_reload_tasks.add(task)
|
||||
task.add_done_callback(self._pending_reload_tasks.discard)
|
||||
|
||||
def _determine_approval_mode(
|
||||
self,
|
||||
*candidate_names: str,
|
||||
@@ -1047,6 +1083,14 @@ class MCPTool:
|
||||
params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)
|
||||
|
||||
async def _close_on_owner(self) -> None:
|
||||
# Cancel any pending reload tasks before tearing down the session.
|
||||
tasks = list(self._pending_reload_tasks)
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
self._pending_reload_tasks.clear()
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
await self._safe_close_exit_stack()
|
||||
self._exit_stack = AsyncExitStack()
|
||||
self.session = None
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
# type: ignore[reportPrivateUsage]
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -1615,7 +1616,7 @@ async def test_mcp_connection_reset_integration():
|
||||
|
||||
async def test_mcp_tool_message_handler_notification():
|
||||
"""Test that message_handler correctly processes tools/list_changed and prompts/list_changed
|
||||
notifications."""
|
||||
notifications by scheduling reloads as background tasks."""
|
||||
tool = MCPStdioTool(name="test_tool", command="python")
|
||||
|
||||
# Mock the load_tools and load_prompts methods
|
||||
@@ -1629,6 +1630,8 @@ async def test_mcp_tool_message_handler_notification():
|
||||
|
||||
result = await tool.message_handler(tools_notification)
|
||||
assert result is None
|
||||
# The reload is scheduled as a background task; let it run.
|
||||
await asyncio.sleep(0)
|
||||
tool.load_tools.assert_called_once()
|
||||
|
||||
# Reset mock
|
||||
@@ -1641,6 +1644,7 @@ async def test_mcp_tool_message_handler_notification():
|
||||
|
||||
result = await tool.message_handler(prompts_notification)
|
||||
assert result is None
|
||||
await asyncio.sleep(0)
|
||||
tool.load_prompts.assert_called_once()
|
||||
|
||||
# Test unhandled notification
|
||||
@@ -1664,6 +1668,112 @@ async def test_mcp_tool_message_handler_error():
|
||||
assert result is None
|
||||
|
||||
|
||||
async def test_mcp_tool_message_handler_does_not_block_receive_loop():
|
||||
"""Test that message_handler does not deadlock the MCP receive loop.
|
||||
|
||||
Regression test for https://github.com/microsoft/agent-framework/issues/4828.
|
||||
When the MCP server sends a ``notifications/tools/list_changed``
|
||||
notification, the handler must NOT await ``load_tools()`` synchronously
|
||||
because that would block the single-threaded MCP receive loop, preventing
|
||||
it from delivering the ``list_tools`` response — a classic deadlock.
|
||||
"""
|
||||
tool = MCPStdioTool(name="test_tool", command="python")
|
||||
|
||||
# Use an event to make load_tools block until we release it.
|
||||
# This simulates load_tools waiting for a session response that the
|
||||
# receive loop would need to deliver.
|
||||
release = asyncio.Event()
|
||||
|
||||
async def slow_load_tools():
|
||||
await release.wait()
|
||||
|
||||
tool.load_tools = slow_load_tools # type: ignore[assignment]
|
||||
|
||||
tools_notification = Mock(spec=types.ServerNotification)
|
||||
tools_notification.root = Mock()
|
||||
tools_notification.root.method = "notifications/tools/list_changed"
|
||||
|
||||
# message_handler must return immediately even though load_tools blocks.
|
||||
await tool.message_handler(tools_notification)
|
||||
|
||||
# If the handler had awaited load_tools synchronously, we would never
|
||||
# reach this line (deadlock). Verify the reload task is pending.
|
||||
assert len(tool._pending_reload_tasks) == 1
|
||||
|
||||
# Unblock the reload so the background task finishes cleanly.
|
||||
release.set()
|
||||
# Wait for the pending reload task(s) to complete so their done-callbacks
|
||||
# have a chance to remove them from _pending_reload_tasks.
|
||||
await asyncio.wait_for(asyncio.gather(*tool._pending_reload_tasks), timeout=1)
|
||||
assert len(tool._pending_reload_tasks) == 0
|
||||
|
||||
|
||||
async def test_mcp_tool_message_handler_reload_failure_is_logged(caplog: pytest.LogCaptureFixture):
|
||||
"""Background reload errors are logged, not raised into the receive loop."""
|
||||
tool = MCPStdioTool(name="test_tool", command="python")
|
||||
tool.load_tools = AsyncMock(side_effect=RuntimeError("connection lost"))
|
||||
|
||||
tools_notification = Mock(spec=types.ServerNotification)
|
||||
tools_notification.root = Mock()
|
||||
tools_notification.root.method = "notifications/tools/list_changed"
|
||||
|
||||
await tool.message_handler(tools_notification)
|
||||
# Let the background task run — it should not propagate the exception.
|
||||
# Snapshot tasks and await them to ensure done-callbacks fire.
|
||||
pending = list(tool._pending_reload_tasks)
|
||||
if pending:
|
||||
await asyncio.wait_for(asyncio.gather(*pending, return_exceptions=True), timeout=1)
|
||||
tool.load_tools.assert_called_once()
|
||||
assert len(tool._pending_reload_tasks) == 0
|
||||
|
||||
# Verify the warning was actually logged with exception info.
|
||||
reload_warnings = [r for r in caplog.records if "Background MCP reload failed" in r.message]
|
||||
assert len(reload_warnings) == 1
|
||||
assert reload_warnings[0].levelname == "WARNING"
|
||||
assert reload_warnings[0].exc_info is not None
|
||||
|
||||
|
||||
async def test_mcp_tool_message_handler_cancel_and_replace():
|
||||
"""Sending two notifications in quick succession cancels the first reload task."""
|
||||
tool = MCPStdioTool(name="test_tool", command="python")
|
||||
|
||||
release = asyncio.Event()
|
||||
call_count = 0
|
||||
|
||||
async def blocking_load_tools():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await release.wait()
|
||||
|
||||
tool.load_tools = blocking_load_tools # type: ignore[assignment]
|
||||
|
||||
notification = Mock(spec=types.ServerNotification)
|
||||
notification.root = Mock()
|
||||
notification.root.method = "notifications/tools/list_changed"
|
||||
|
||||
# First notification — starts a blocking reload task.
|
||||
await tool.message_handler(notification)
|
||||
assert len(tool._pending_reload_tasks) == 1
|
||||
first_task = next(iter(tool._pending_reload_tasks))
|
||||
|
||||
# Second notification — should cancel the first and replace it.
|
||||
await tool.message_handler(notification)
|
||||
# Yield to the event loop so the cancellation is processed.
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await first_task
|
||||
|
||||
assert first_task.cancelled()
|
||||
|
||||
assert len(tool._pending_reload_tasks) == 1
|
||||
second_task = next(iter(tool._pending_reload_tasks))
|
||||
assert second_task is not first_task
|
||||
|
||||
# Unblock and let the second task finish.
|
||||
release.set()
|
||||
await asyncio.wait_for(asyncio.gather(*tool._pending_reload_tasks), timeout=1)
|
||||
assert len(tool._pending_reload_tasks) == 0
|
||||
|
||||
|
||||
async def test_mcp_tool_sampling_callback_no_client():
|
||||
"""Test sampling callback error path when no chat client is available."""
|
||||
tool = MCPStdioTool(name="test_tool", command="python")
|
||||
|
||||
Reference in New Issue
Block a user