mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Skip MCP prompt loading when unsupported (#5370)
* Python: Skip MCP prompt loading when unsupported * Fix MCP pagination pyright checks * Simplify MCP support flag checks
This commit is contained in:
committed by
GitHub
Unverified
parent
dd1e615dad
commit
0ba552b84c
@@ -255,7 +255,7 @@ class MCPTool:
|
||||
self._exit_stack = AsyncExitStack()
|
||||
self._lifecycle_lock = asyncio.Lock()
|
||||
self._lifecycle_request_lock = asyncio.Lock()
|
||||
self._lifecycle_queue: asyncio.Queue[tuple[str, bool, asyncio.Future[None]]] | None = None
|
||||
self._lifecycle_queue: asyncio.Queue[tuple[str, bool, bool, asyncio.Future[None]]] | None = None
|
||||
self._lifecycle_owner_task: asyncio.Task[None] | None = None
|
||||
self.session = session
|
||||
self.request_timeout = request_timeout
|
||||
@@ -265,6 +265,11 @@ class MCPTool:
|
||||
self.is_connected: bool = False
|
||||
self._tools_loaded: bool = False
|
||||
self._prompts_loaded: bool = False
|
||||
self._server_capabilities: types.ServerCapabilities | None = None
|
||||
self._supports_tools: bool = True
|
||||
self._supports_prompts: bool = True
|
||||
self._supports_logging: bool | None = None
|
||||
self._ping_available: bool = True
|
||||
self._pending_reload_tasks: set[asyncio.Task[None]] = set()
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -566,11 +571,11 @@ class MCPTool:
|
||||
stop_error: BaseException | None = None
|
||||
try:
|
||||
while True:
|
||||
action, reset, future = await queue.get()
|
||||
action, reset, load_configured, future = await queue.get()
|
||||
|
||||
try:
|
||||
if action == "connect":
|
||||
await self._connect_on_owner(reset=reset)
|
||||
await self._connect_on_owner(reset=reset, load_configured=load_configured)
|
||||
elif action == "close":
|
||||
await self._close_on_owner()
|
||||
else:
|
||||
@@ -595,7 +600,7 @@ class MCPTool:
|
||||
finally:
|
||||
while True:
|
||||
try:
|
||||
_, _, future = queue.get_nowait()
|
||||
_, _, _, future = queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
if not future.done():
|
||||
@@ -608,12 +613,18 @@ class MCPTool:
|
||||
owner_task = self._lifecycle_owner_task
|
||||
return owner_task is not None and asyncio.current_task() is owner_task
|
||||
|
||||
async def _run_on_lifecycle_owner(self, action: str, *, reset: bool = False) -> None:
|
||||
async def _run_on_lifecycle_owner(
|
||||
self,
|
||||
action: str,
|
||||
*,
|
||||
reset: bool = False,
|
||||
load_configured: bool = True,
|
||||
) -> None:
|
||||
await self._ensure_lifecycle_owner()
|
||||
|
||||
if self._is_lifecycle_owner_task():
|
||||
if action == "connect":
|
||||
await self._connect_on_owner(reset=reset)
|
||||
await self._connect_on_owner(reset=reset, load_configured=load_configured)
|
||||
elif action == "close":
|
||||
await self._close_on_owner()
|
||||
else:
|
||||
@@ -625,7 +636,7 @@ class MCPTool:
|
||||
raise RuntimeError("MCP lifecycle owner is not available.")
|
||||
|
||||
future = asyncio.get_running_loop().create_future()
|
||||
await queue.put((action, reset, future))
|
||||
await queue.put((action, reset, load_configured, future))
|
||||
await future
|
||||
|
||||
async def _safe_close_exit_stack(self) -> None:
|
||||
@@ -656,6 +667,32 @@ class MCPTool:
|
||||
await self._safe_close_exit_stack()
|
||||
return _should_propagate_cancelled_error(ex)
|
||||
|
||||
def _reset_session_state(self) -> None:
|
||||
self._server_capabilities = None
|
||||
self._supports_tools = True
|
||||
self._supports_prompts = True
|
||||
self._supports_logging = None
|
||||
self._ping_available = True
|
||||
|
||||
def _set_server_capabilities(self, capabilities: types.ServerCapabilities | None) -> None:
|
||||
self._server_capabilities = capabilities
|
||||
if capabilities is None:
|
||||
self._supports_tools = False
|
||||
self._supports_prompts = False
|
||||
self._supports_logging = False
|
||||
return
|
||||
|
||||
self._supports_tools = getattr(capabilities, "tools", None) is not None
|
||||
self._supports_prompts = getattr(capabilities, "prompts", None) is not None
|
||||
self._supports_logging = getattr(capabilities, "logging", None) is not None
|
||||
|
||||
async def _reconnect_without_loading(self) -> None:
|
||||
if self._is_lifecycle_owner_task():
|
||||
await self._connect_on_owner(reset=True, load_configured=False)
|
||||
return
|
||||
|
||||
await self._run_on_lifecycle_owner("connect", reset=True, load_configured=False)
|
||||
|
||||
async def connect(self, *, reset: bool = False) -> None:
|
||||
if self._is_lifecycle_owner_task():
|
||||
await self._connect_on_owner(reset=reset)
|
||||
@@ -664,7 +701,7 @@ class MCPTool:
|
||||
async with self._lifecycle_request_lock:
|
||||
await self._run_on_lifecycle_owner("connect", reset=reset)
|
||||
|
||||
async def _connect_on_owner(self, *, reset: bool = False) -> None:
|
||||
async def _connect_on_owner(self, *, reset: bool = False, load_configured: bool = True) -> None:
|
||||
"""Connect to the MCP server.
|
||||
|
||||
Establishes a connection to the MCP server, initializes the session,
|
||||
@@ -672,6 +709,7 @@ class MCPTool:
|
||||
|
||||
Keyword Args:
|
||||
reset: If True, forces a reconnection even if already connected.
|
||||
load_configured: If True, loads tools and prompts according to the constructor flags.
|
||||
|
||||
Raises:
|
||||
ToolException: If connection or session initialization fails.
|
||||
@@ -680,6 +718,7 @@ class MCPTool:
|
||||
await self._safe_close_exit_stack()
|
||||
self.session = None
|
||||
self.is_connected = False
|
||||
self._reset_session_state()
|
||||
self._exit_stack = AsyncExitStack()
|
||||
if not self.session:
|
||||
try:
|
||||
@@ -741,7 +780,8 @@ class MCPTool:
|
||||
inner_exception=ex if isinstance(ex, Exception) else None,
|
||||
) from ex
|
||||
try:
|
||||
await session.initialize()
|
||||
initialize_result = await session.initialize()
|
||||
self._set_server_capabilities(getattr(initialize_result, "capabilities", None))
|
||||
except (Exception, asyncio.CancelledError) as ex:
|
||||
if await self._close_and_check_cancelled(ex):
|
||||
raise
|
||||
@@ -759,17 +799,22 @@ class MCPTool:
|
||||
self.session = session
|
||||
elif self.session._request_id == 0: # type: ignore[attr-defined]
|
||||
# If the session is not initialized, we need to reinitialize it
|
||||
await self.session.initialize()
|
||||
initialize_result = await self.session.initialize()
|
||||
self._set_server_capabilities(getattr(initialize_result, "capabilities", None))
|
||||
elif self._server_capabilities is None:
|
||||
self._set_server_capabilities(getattr(self.session, "_server_capabilities", None))
|
||||
logger.debug("Connected to MCP server: %s", self.session)
|
||||
self.is_connected = True
|
||||
if self.load_tools_flag:
|
||||
await self.load_tools()
|
||||
if load_configured and self.load_tools_flag:
|
||||
if self._supports_tools:
|
||||
await self.load_tools()
|
||||
self._tools_loaded = True
|
||||
if self.load_prompts_flag:
|
||||
await self.load_prompts()
|
||||
if load_configured and self.load_prompts_flag:
|
||||
if self._supports_prompts:
|
||||
await self.load_prompts()
|
||||
self._prompts_loaded = True
|
||||
|
||||
if logger.level != logging.NOTSET:
|
||||
if logger.level != logging.NOTSET and self._supports_logging is not False:
|
||||
try:
|
||||
level_name = cast(
|
||||
Any, next(level for level, value in LOG_LEVEL_MAPPING.items() if value == logger.level)
|
||||
@@ -973,17 +1018,49 @@ class MCPTool:
|
||||
Raises:
|
||||
ToolExecutionException: If the MCP server is not connected.
|
||||
"""
|
||||
from anyio import ClosedResourceError
|
||||
from mcp import types
|
||||
|
||||
if not self._supports_prompts:
|
||||
logger.debug("Skipping MCP prompt loading because the server did not advertise prompts support.")
|
||||
return
|
||||
|
||||
# Track existing function names to prevent duplicates
|
||||
existing_names = {func.name for func in self._functions}
|
||||
|
||||
params: types.PaginatedRequestParams | None = None
|
||||
while True:
|
||||
# Ensure connection is still valid before each page request
|
||||
await self._ensure_connected()
|
||||
prompt_list: types.ListPromptsResult | None = None
|
||||
for attempt in range(2):
|
||||
try:
|
||||
# Ensure connection is still valid before each page request
|
||||
await self._ensure_connected()
|
||||
if not self._supports_prompts:
|
||||
logger.debug(
|
||||
"Skipping MCP prompt loading because the server did not advertise prompts support."
|
||||
)
|
||||
return
|
||||
prompt_list = await self.session.list_prompts(params=params) # type: ignore[union-attr]
|
||||
break
|
||||
except ClosedResourceError as cl_ex:
|
||||
if attempt == 0:
|
||||
logger.info("MCP connection closed unexpectedly while loading prompts. Reconnecting...")
|
||||
try:
|
||||
await self._reconnect_without_loading()
|
||||
except Exception as reconn_ex:
|
||||
raise ToolExecutionException(
|
||||
"Failed to reconnect to MCP server.",
|
||||
inner_exception=reconn_ex,
|
||||
) from reconn_ex
|
||||
continue
|
||||
logger.error("MCP connection closed unexpectedly after reconnection: %s", cl_ex)
|
||||
raise ToolExecutionException(
|
||||
"Failed to load prompts - connection lost.",
|
||||
inner_exception=cl_ex,
|
||||
) from cl_ex
|
||||
|
||||
prompt_list = await self.session.list_prompts(params=params) # type: ignore[union-attr]
|
||||
if prompt_list is None:
|
||||
raise ToolExecutionException("Failed to load prompts.")
|
||||
|
||||
for prompt in prompt_list.prompts:
|
||||
normalized_name = _normalize_mcp_name(prompt.name)
|
||||
@@ -1010,7 +1087,7 @@ class MCPTool:
|
||||
existing_names.add(local_name)
|
||||
|
||||
# Check if there are more pages
|
||||
if not prompt_list or not prompt_list.nextCursor:
|
||||
if not prompt_list.nextCursor:
|
||||
break
|
||||
params = types.PaginatedRequestParams(cursor=prompt_list.nextCursor)
|
||||
|
||||
@@ -1023,18 +1100,48 @@ class MCPTool:
|
||||
Raises:
|
||||
ToolExecutionException: If the MCP server is not connected.
|
||||
"""
|
||||
from anyio import ClosedResourceError
|
||||
from mcp import types
|
||||
|
||||
if not self._supports_tools:
|
||||
logger.debug("Skipping MCP tool loading because the server did not advertise tools support.")
|
||||
return
|
||||
|
||||
# Track existing function names to prevent duplicates
|
||||
existing_names = {func.name for func in self._functions}
|
||||
self._tool_call_meta_by_name.clear()
|
||||
|
||||
params: types.PaginatedRequestParams | None = None
|
||||
while True:
|
||||
# Ensure connection is still valid before each page request
|
||||
await self._ensure_connected()
|
||||
tool_list: types.ListToolsResult | None = None
|
||||
for attempt in range(2):
|
||||
try:
|
||||
# Ensure connection is still valid before each page request
|
||||
await self._ensure_connected()
|
||||
if not self._supports_tools:
|
||||
logger.debug("Skipping MCP tool loading because the server did not advertise tools support.")
|
||||
return
|
||||
tool_list = await self.session.list_tools(params=params) # type: ignore[union-attr]
|
||||
break
|
||||
except ClosedResourceError as cl_ex:
|
||||
if attempt == 0:
|
||||
logger.info("MCP connection closed unexpectedly while loading tools. Reconnecting...")
|
||||
try:
|
||||
await self._reconnect_without_loading()
|
||||
except Exception as reconn_ex:
|
||||
raise ToolExecutionException(
|
||||
"Failed to reconnect to MCP server.",
|
||||
inner_exception=reconn_ex,
|
||||
) from reconn_ex
|
||||
continue
|
||||
logger.error("MCP connection closed unexpectedly after reconnection: %s", cl_ex)
|
||||
raise ToolExecutionException(
|
||||
"Failed to load tools - connection lost.",
|
||||
inner_exception=cl_ex,
|
||||
) from cl_ex
|
||||
|
||||
tool_list = await self.session.list_tools(params=params) # type: ignore[union-attr]
|
||||
if tool_list is None:
|
||||
raise ToolExecutionException("Failed to load tools.")
|
||||
|
||||
for tool in tool_list.tools:
|
||||
if tool.meta is not None:
|
||||
@@ -1083,7 +1190,7 @@ class MCPTool:
|
||||
existing_names.add(local_name)
|
||||
|
||||
# Check if there are more pages
|
||||
if not tool_list or not tool_list.nextCursor:
|
||||
if not tool_list.nextCursor:
|
||||
break
|
||||
params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)
|
||||
|
||||
@@ -1100,6 +1207,7 @@ class MCPTool:
|
||||
self._exit_stack = AsyncExitStack()
|
||||
self.session = None
|
||||
self.is_connected = False
|
||||
self._reset_session_state()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Disconnect from the MCP server.
|
||||
@@ -1131,12 +1239,30 @@ class MCPTool:
|
||||
Raises:
|
||||
ToolExecutionException: If reconnection fails.
|
||||
"""
|
||||
from mcp.shared.exceptions import McpError
|
||||
|
||||
if not self._ping_available:
|
||||
return
|
||||
|
||||
try:
|
||||
await self.session.send_ping() # type: ignore[union-attr]
|
||||
except McpError as mcp_exc:
|
||||
if mcp_exc.error.code == -32601:
|
||||
self._ping_available = False
|
||||
logger.debug("Skipping future MCP pings because the server does not support ping.")
|
||||
return
|
||||
logger.info("MCP connection invalid or closed. Reconnecting...")
|
||||
try:
|
||||
await self._reconnect_without_loading()
|
||||
except Exception as ex:
|
||||
raise ToolExecutionException(
|
||||
"Failed to establish MCP connection.",
|
||||
inner_exception=ex,
|
||||
) from ex
|
||||
except Exception:
|
||||
logger.info("MCP connection invalid or closed. Reconnecting...")
|
||||
try:
|
||||
await self.connect(reset=True)
|
||||
await self._reconnect_without_loading()
|
||||
except Exception as ex:
|
||||
raise ToolExecutionException(
|
||||
"Failed to establish MCP connection.",
|
||||
|
||||
@@ -4031,14 +4031,102 @@ async def test_connect_reinitializes_existing_session_and_loads_tools_and_prompt
|
||||
assert tool._prompts_loaded is True
|
||||
|
||||
|
||||
async def test_connect_skips_tools_and_prompts_when_server_does_not_advertise_capabilities() -> None:
|
||||
tool = MCPTool(name="test_tool", load_tools=True, load_prompts=True)
|
||||
tool.is_connected = True
|
||||
tool.session = Mock()
|
||||
tool.session._request_id = 0
|
||||
tool.session.initialize = AsyncMock(
|
||||
return_value=types.InitializeResult(
|
||||
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||
capabilities=types.ServerCapabilities(),
|
||||
serverInfo=types.Implementation(name="test", version="1.0"),
|
||||
)
|
||||
)
|
||||
tool.session.list_tools = AsyncMock()
|
||||
tool.session.list_prompts = AsyncMock()
|
||||
tool.session.set_logging_level = AsyncMock()
|
||||
|
||||
with patch.object(logger, "level", logging.INFO):
|
||||
await tool._connect_on_owner()
|
||||
|
||||
tool.session.initialize.assert_awaited_once()
|
||||
tool.session.list_tools.assert_not_called()
|
||||
tool.session.list_prompts.assert_not_called()
|
||||
tool.session.set_logging_level.assert_not_called()
|
||||
assert tool.is_connected is True
|
||||
assert tool._supports_tools is False
|
||||
assert tool._supports_prompts is False
|
||||
assert tool._supports_logging is False
|
||||
assert tool._tools_loaded is True
|
||||
assert tool._prompts_loaded is True
|
||||
|
||||
|
||||
async def test_connect_treats_missing_capabilities_as_unsupported() -> None:
|
||||
tool = MCPTool(name="test_tool", load_tools=True, load_prompts=True)
|
||||
tool.is_connected = True
|
||||
tool.session = Mock()
|
||||
tool.session._request_id = 0
|
||||
tool.session.initialize = AsyncMock(return_value=Mock(capabilities=None))
|
||||
tool.session.list_tools = AsyncMock()
|
||||
tool.session.list_prompts = AsyncMock()
|
||||
|
||||
with patch.object(logger, "level", logging.NOTSET):
|
||||
await tool._connect_on_owner()
|
||||
|
||||
tool.session.list_tools.assert_not_called()
|
||||
tool.session.list_prompts.assert_not_called()
|
||||
assert tool._supports_tools is False
|
||||
assert tool._supports_prompts is False
|
||||
assert tool._supports_logging is False
|
||||
|
||||
|
||||
async def test_connect_sets_logging_level_when_server_advertises_logging() -> None:
|
||||
tool = MCPTool(name="test_tool", load_tools=False, load_prompts=False)
|
||||
tool.is_connected = True
|
||||
tool.session = Mock()
|
||||
tool.session._request_id = 0
|
||||
tool.session.initialize = AsyncMock(
|
||||
return_value=types.InitializeResult(
|
||||
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||
capabilities=types.ServerCapabilities(logging=types.LoggingCapability()),
|
||||
serverInfo=types.Implementation(name="test", version="1.0"),
|
||||
)
|
||||
)
|
||||
tool.session.set_logging_level = AsyncMock()
|
||||
|
||||
with patch.object(logger, "level", logging.INFO):
|
||||
await tool._connect_on_owner()
|
||||
|
||||
tool.session.set_logging_level.assert_awaited_once_with("info")
|
||||
assert tool._supports_logging is True
|
||||
|
||||
|
||||
async def test_ensure_connected_skips_future_pings_when_ping_is_not_available() -> None:
|
||||
tool = MCPTool(name="test_tool")
|
||||
tool.session = Mock(
|
||||
send_ping=AsyncMock(
|
||||
side_effect=McpError(types.ErrorData(code=-32601, message="Method 'ping' is not available."))
|
||||
)
|
||||
)
|
||||
|
||||
with patch.object(tool, "_reconnect_without_loading", AsyncMock()) as mock_reconnect:
|
||||
await tool._ensure_connected()
|
||||
await tool._ensure_connected()
|
||||
|
||||
tool.session.send_ping.assert_awaited_once()
|
||||
mock_reconnect.assert_not_awaited()
|
||||
assert tool._ping_available is False
|
||||
|
||||
|
||||
async def test_ensure_connected_reconnects_on_failed_ping() -> None:
|
||||
tool = MCPTool(name="test_tool")
|
||||
tool.session = Mock(send_ping=AsyncMock(side_effect=RuntimeError("closed")))
|
||||
|
||||
with patch.object(tool, "connect", AsyncMock()) as mock_connect:
|
||||
with patch.object(tool, "_reconnect_without_loading", AsyncMock()) as mock_reconnect:
|
||||
await tool._ensure_connected()
|
||||
|
||||
mock_connect.assert_awaited_once_with(reset=True)
|
||||
mock_reconnect.assert_awaited_once_with()
|
||||
|
||||
|
||||
async def test_ensure_connected_wraps_reconnect_failure() -> None:
|
||||
@@ -4046,12 +4134,70 @@ async def test_ensure_connected_wraps_reconnect_failure() -> None:
|
||||
tool.session = Mock(send_ping=AsyncMock(side_effect=RuntimeError("closed")))
|
||||
|
||||
with (
|
||||
patch.object(tool, "connect", AsyncMock(side_effect=RuntimeError("still closed"))),
|
||||
patch.object(tool, "_reconnect_without_loading", AsyncMock(side_effect=RuntimeError("still closed"))),
|
||||
pytest.raises(ToolExecutionException, match="Failed to establish MCP connection"),
|
||||
):
|
||||
await tool._ensure_connected()
|
||||
|
||||
|
||||
async def test_load_tools_reconnects_on_closed_resource_when_ping_is_unavailable() -> None:
|
||||
from anyio import ClosedResourceError
|
||||
|
||||
tool = MCPTool(name="test_tool", load_tools=True)
|
||||
tool._ping_available = False
|
||||
|
||||
first_session = Mock()
|
||||
first_session.list_tools = AsyncMock(side_effect=ClosedResourceError())
|
||||
tool.session = first_session
|
||||
|
||||
page = Mock()
|
||||
page.tools = []
|
||||
page.nextCursor = None
|
||||
|
||||
second_session = Mock()
|
||||
second_session.list_tools = AsyncMock(return_value=page)
|
||||
|
||||
async def reconnect() -> None:
|
||||
tool.session = second_session
|
||||
tool._supports_tools = True
|
||||
|
||||
with patch.object(tool, "_reconnect_without_loading", AsyncMock(side_effect=reconnect)) as mock_reconnect:
|
||||
await tool.load_tools()
|
||||
|
||||
first_session.list_tools.assert_awaited_once()
|
||||
mock_reconnect.assert_awaited_once_with()
|
||||
second_session.list_tools.assert_awaited_once()
|
||||
|
||||
|
||||
async def test_load_prompts_reconnects_on_closed_resource_when_ping_is_unavailable() -> None:
|
||||
from anyio import ClosedResourceError
|
||||
|
||||
tool = MCPTool(name="test_tool", load_prompts=True)
|
||||
tool._ping_available = False
|
||||
|
||||
first_session = Mock()
|
||||
first_session.list_prompts = AsyncMock(side_effect=ClosedResourceError())
|
||||
tool.session = first_session
|
||||
|
||||
page = Mock()
|
||||
page.prompts = []
|
||||
page.nextCursor = None
|
||||
|
||||
second_session = Mock()
|
||||
second_session.list_prompts = AsyncMock(return_value=page)
|
||||
|
||||
async def reconnect() -> None:
|
||||
tool.session = second_session
|
||||
tool._supports_prompts = True
|
||||
|
||||
with patch.object(tool, "_reconnect_without_loading", AsyncMock(side_effect=reconnect)) as mock_reconnect:
|
||||
await tool.load_prompts()
|
||||
|
||||
first_session.list_prompts.assert_awaited_once()
|
||||
mock_reconnect.assert_awaited_once_with()
|
||||
second_session.list_prompts.assert_awaited_once()
|
||||
|
||||
|
||||
async def test_mcp_tool_filters_framework_kwargs():
|
||||
"""Test that call_tool filters out framework-specific kwargs before calling MCP session.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user