Python: Skip MCP prompt loading when unsupported (#5370)

* Python: Skip MCP prompt loading when unsupported

* Fix MCP pagination pyright checks

* Simplify MCP support flag checks
This commit is contained in:
Baidar
2026-05-20 13:50:26 +02:00
committed by GitHub
Unverified
parent dd1e615dad
commit 0ba552b84c
2 changed files with 299 additions and 27 deletions
+150 -24
View File
@@ -255,7 +255,7 @@ class MCPTool:
self._exit_stack = AsyncExitStack()
self._lifecycle_lock = asyncio.Lock()
self._lifecycle_request_lock = asyncio.Lock()
self._lifecycle_queue: asyncio.Queue[tuple[str, bool, asyncio.Future[None]]] | None = None
self._lifecycle_queue: asyncio.Queue[tuple[str, bool, bool, asyncio.Future[None]]] | None = None
self._lifecycle_owner_task: asyncio.Task[None] | None = None
self.session = session
self.request_timeout = request_timeout
@@ -265,6 +265,11 @@ class MCPTool:
self.is_connected: bool = False
self._tools_loaded: bool = False
self._prompts_loaded: bool = False
self._server_capabilities: types.ServerCapabilities | None = None
self._supports_tools: bool = True
self._supports_prompts: bool = True
self._supports_logging: bool | None = None
self._ping_available: bool = True
self._pending_reload_tasks: set[asyncio.Task[None]] = set()
def __str__(self) -> str:
@@ -566,11 +571,11 @@ class MCPTool:
stop_error: BaseException | None = None
try:
while True:
action, reset, future = await queue.get()
action, reset, load_configured, future = await queue.get()
try:
if action == "connect":
await self._connect_on_owner(reset=reset)
await self._connect_on_owner(reset=reset, load_configured=load_configured)
elif action == "close":
await self._close_on_owner()
else:
@@ -595,7 +600,7 @@ class MCPTool:
finally:
while True:
try:
_, _, future = queue.get_nowait()
_, _, _, future = queue.get_nowait()
except asyncio.QueueEmpty:
break
if not future.done():
@@ -608,12 +613,18 @@ class MCPTool:
owner_task = self._lifecycle_owner_task
return owner_task is not None and asyncio.current_task() is owner_task
async def _run_on_lifecycle_owner(self, action: str, *, reset: bool = False) -> None:
async def _run_on_lifecycle_owner(
self,
action: str,
*,
reset: bool = False,
load_configured: bool = True,
) -> None:
await self._ensure_lifecycle_owner()
if self._is_lifecycle_owner_task():
if action == "connect":
await self._connect_on_owner(reset=reset)
await self._connect_on_owner(reset=reset, load_configured=load_configured)
elif action == "close":
await self._close_on_owner()
else:
@@ -625,7 +636,7 @@ class MCPTool:
raise RuntimeError("MCP lifecycle owner is not available.")
future = asyncio.get_running_loop().create_future()
await queue.put((action, reset, future))
await queue.put((action, reset, load_configured, future))
await future
async def _safe_close_exit_stack(self) -> None:
@@ -656,6 +667,32 @@ class MCPTool:
await self._safe_close_exit_stack()
return _should_propagate_cancelled_error(ex)
def _reset_session_state(self) -> None:
self._server_capabilities = None
self._supports_tools = True
self._supports_prompts = True
self._supports_logging = None
self._ping_available = True
def _set_server_capabilities(self, capabilities: types.ServerCapabilities | None) -> None:
self._server_capabilities = capabilities
if capabilities is None:
self._supports_tools = False
self._supports_prompts = False
self._supports_logging = False
return
self._supports_tools = getattr(capabilities, "tools", None) is not None
self._supports_prompts = getattr(capabilities, "prompts", None) is not None
self._supports_logging = getattr(capabilities, "logging", None) is not None
async def _reconnect_without_loading(self) -> None:
if self._is_lifecycle_owner_task():
await self._connect_on_owner(reset=True, load_configured=False)
return
await self._run_on_lifecycle_owner("connect", reset=True, load_configured=False)
async def connect(self, *, reset: bool = False) -> None:
if self._is_lifecycle_owner_task():
await self._connect_on_owner(reset=reset)
@@ -664,7 +701,7 @@ class MCPTool:
async with self._lifecycle_request_lock:
await self._run_on_lifecycle_owner("connect", reset=reset)
async def _connect_on_owner(self, *, reset: bool = False) -> None:
async def _connect_on_owner(self, *, reset: bool = False, load_configured: bool = True) -> None:
"""Connect to the MCP server.
Establishes a connection to the MCP server, initializes the session,
@@ -672,6 +709,7 @@ class MCPTool:
Keyword Args:
reset: If True, forces a reconnection even if already connected.
load_configured: If True, loads tools and prompts according to the constructor flags.
Raises:
ToolException: If connection or session initialization fails.
@@ -680,6 +718,7 @@ class MCPTool:
await self._safe_close_exit_stack()
self.session = None
self.is_connected = False
self._reset_session_state()
self._exit_stack = AsyncExitStack()
if not self.session:
try:
@@ -741,7 +780,8 @@ class MCPTool:
inner_exception=ex if isinstance(ex, Exception) else None,
) from ex
try:
await session.initialize()
initialize_result = await session.initialize()
self._set_server_capabilities(getattr(initialize_result, "capabilities", None))
except (Exception, asyncio.CancelledError) as ex:
if await self._close_and_check_cancelled(ex):
raise
@@ -759,17 +799,22 @@ class MCPTool:
self.session = session
elif self.session._request_id == 0: # type: ignore[attr-defined]
# If the session is not initialized, we need to reinitialize it
await self.session.initialize()
initialize_result = await self.session.initialize()
self._set_server_capabilities(getattr(initialize_result, "capabilities", None))
elif self._server_capabilities is None:
self._set_server_capabilities(getattr(self.session, "_server_capabilities", None))
logger.debug("Connected to MCP server: %s", self.session)
self.is_connected = True
if self.load_tools_flag:
await self.load_tools()
if load_configured and self.load_tools_flag:
if self._supports_tools:
await self.load_tools()
self._tools_loaded = True
if self.load_prompts_flag:
await self.load_prompts()
if load_configured and self.load_prompts_flag:
if self._supports_prompts:
await self.load_prompts()
self._prompts_loaded = True
if logger.level != logging.NOTSET:
if logger.level != logging.NOTSET and self._supports_logging is not False:
try:
level_name = cast(
Any, next(level for level, value in LOG_LEVEL_MAPPING.items() if value == logger.level)
@@ -973,17 +1018,49 @@ class MCPTool:
Raises:
ToolExecutionException: If the MCP server is not connected.
"""
from anyio import ClosedResourceError
from mcp import types
if not self._supports_prompts:
logger.debug("Skipping MCP prompt loading because the server did not advertise prompts support.")
return
# Track existing function names to prevent duplicates
existing_names = {func.name for func in self._functions}
params: types.PaginatedRequestParams | None = None
while True:
# Ensure connection is still valid before each page request
await self._ensure_connected()
prompt_list: types.ListPromptsResult | None = None
for attempt in range(2):
try:
# Ensure connection is still valid before each page request
await self._ensure_connected()
if not self._supports_prompts:
logger.debug(
"Skipping MCP prompt loading because the server did not advertise prompts support."
)
return
prompt_list = await self.session.list_prompts(params=params) # type: ignore[union-attr]
break
except ClosedResourceError as cl_ex:
if attempt == 0:
logger.info("MCP connection closed unexpectedly while loading prompts. Reconnecting...")
try:
await self._reconnect_without_loading()
except Exception as reconn_ex:
raise ToolExecutionException(
"Failed to reconnect to MCP server.",
inner_exception=reconn_ex,
) from reconn_ex
continue
logger.error("MCP connection closed unexpectedly after reconnection: %s", cl_ex)
raise ToolExecutionException(
"Failed to load prompts - connection lost.",
inner_exception=cl_ex,
) from cl_ex
prompt_list = await self.session.list_prompts(params=params) # type: ignore[union-attr]
if prompt_list is None:
raise ToolExecutionException("Failed to load prompts.")
for prompt in prompt_list.prompts:
normalized_name = _normalize_mcp_name(prompt.name)
@@ -1010,7 +1087,7 @@ class MCPTool:
existing_names.add(local_name)
# Check if there are more pages
if not prompt_list or not prompt_list.nextCursor:
if not prompt_list.nextCursor:
break
params = types.PaginatedRequestParams(cursor=prompt_list.nextCursor)
@@ -1023,18 +1100,48 @@ class MCPTool:
Raises:
ToolExecutionException: If the MCP server is not connected.
"""
from anyio import ClosedResourceError
from mcp import types
if not self._supports_tools:
logger.debug("Skipping MCP tool loading because the server did not advertise tools support.")
return
# Track existing function names to prevent duplicates
existing_names = {func.name for func in self._functions}
self._tool_call_meta_by_name.clear()
params: types.PaginatedRequestParams | None = None
while True:
# Ensure connection is still valid before each page request
await self._ensure_connected()
tool_list: types.ListToolsResult | None = None
for attempt in range(2):
try:
# Ensure connection is still valid before each page request
await self._ensure_connected()
if not self._supports_tools:
logger.debug("Skipping MCP tool loading because the server did not advertise tools support.")
return
tool_list = await self.session.list_tools(params=params) # type: ignore[union-attr]
break
except ClosedResourceError as cl_ex:
if attempt == 0:
logger.info("MCP connection closed unexpectedly while loading tools. Reconnecting...")
try:
await self._reconnect_without_loading()
except Exception as reconn_ex:
raise ToolExecutionException(
"Failed to reconnect to MCP server.",
inner_exception=reconn_ex,
) from reconn_ex
continue
logger.error("MCP connection closed unexpectedly after reconnection: %s", cl_ex)
raise ToolExecutionException(
"Failed to load tools - connection lost.",
inner_exception=cl_ex,
) from cl_ex
tool_list = await self.session.list_tools(params=params) # type: ignore[union-attr]
if tool_list is None:
raise ToolExecutionException("Failed to load tools.")
for tool in tool_list.tools:
if tool.meta is not None:
@@ -1083,7 +1190,7 @@ class MCPTool:
existing_names.add(local_name)
# Check if there are more pages
if not tool_list or not tool_list.nextCursor:
if not tool_list.nextCursor:
break
params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)
@@ -1100,6 +1207,7 @@ class MCPTool:
self._exit_stack = AsyncExitStack()
self.session = None
self.is_connected = False
self._reset_session_state()
async def close(self) -> None:
"""Disconnect from the MCP server.
@@ -1131,12 +1239,30 @@ class MCPTool:
Raises:
ToolExecutionException: If reconnection fails.
"""
from mcp.shared.exceptions import McpError
if not self._ping_available:
return
try:
await self.session.send_ping() # type: ignore[union-attr]
except McpError as mcp_exc:
if mcp_exc.error.code == -32601:
self._ping_available = False
logger.debug("Skipping future MCP pings because the server does not support ping.")
return
logger.info("MCP connection invalid or closed. Reconnecting...")
try:
await self._reconnect_without_loading()
except Exception as ex:
raise ToolExecutionException(
"Failed to establish MCP connection.",
inner_exception=ex,
) from ex
except Exception:
logger.info("MCP connection invalid or closed. Reconnecting...")
try:
await self.connect(reset=True)
await self._reconnect_without_loading()
except Exception as ex:
raise ToolExecutionException(
"Failed to establish MCP connection.",
+149 -3
View File
@@ -4031,14 +4031,102 @@ async def test_connect_reinitializes_existing_session_and_loads_tools_and_prompt
assert tool._prompts_loaded is True
async def test_connect_skips_tools_and_prompts_when_server_does_not_advertise_capabilities() -> None:
tool = MCPTool(name="test_tool", load_tools=True, load_prompts=True)
tool.is_connected = True
tool.session = Mock()
tool.session._request_id = 0
tool.session.initialize = AsyncMock(
return_value=types.InitializeResult(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=types.ServerCapabilities(),
serverInfo=types.Implementation(name="test", version="1.0"),
)
)
tool.session.list_tools = AsyncMock()
tool.session.list_prompts = AsyncMock()
tool.session.set_logging_level = AsyncMock()
with patch.object(logger, "level", logging.INFO):
await tool._connect_on_owner()
tool.session.initialize.assert_awaited_once()
tool.session.list_tools.assert_not_called()
tool.session.list_prompts.assert_not_called()
tool.session.set_logging_level.assert_not_called()
assert tool.is_connected is True
assert tool._supports_tools is False
assert tool._supports_prompts is False
assert tool._supports_logging is False
assert tool._tools_loaded is True
assert tool._prompts_loaded is True
async def test_connect_treats_missing_capabilities_as_unsupported() -> None:
tool = MCPTool(name="test_tool", load_tools=True, load_prompts=True)
tool.is_connected = True
tool.session = Mock()
tool.session._request_id = 0
tool.session.initialize = AsyncMock(return_value=Mock(capabilities=None))
tool.session.list_tools = AsyncMock()
tool.session.list_prompts = AsyncMock()
with patch.object(logger, "level", logging.NOTSET):
await tool._connect_on_owner()
tool.session.list_tools.assert_not_called()
tool.session.list_prompts.assert_not_called()
assert tool._supports_tools is False
assert tool._supports_prompts is False
assert tool._supports_logging is False
async def test_connect_sets_logging_level_when_server_advertises_logging() -> None:
tool = MCPTool(name="test_tool", load_tools=False, load_prompts=False)
tool.is_connected = True
tool.session = Mock()
tool.session._request_id = 0
tool.session.initialize = AsyncMock(
return_value=types.InitializeResult(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=types.ServerCapabilities(logging=types.LoggingCapability()),
serverInfo=types.Implementation(name="test", version="1.0"),
)
)
tool.session.set_logging_level = AsyncMock()
with patch.object(logger, "level", logging.INFO):
await tool._connect_on_owner()
tool.session.set_logging_level.assert_awaited_once_with("info")
assert tool._supports_logging is True
async def test_ensure_connected_skips_future_pings_when_ping_is_not_available() -> None:
tool = MCPTool(name="test_tool")
tool.session = Mock(
send_ping=AsyncMock(
side_effect=McpError(types.ErrorData(code=-32601, message="Method 'ping' is not available."))
)
)
with patch.object(tool, "_reconnect_without_loading", AsyncMock()) as mock_reconnect:
await tool._ensure_connected()
await tool._ensure_connected()
tool.session.send_ping.assert_awaited_once()
mock_reconnect.assert_not_awaited()
assert tool._ping_available is False
async def test_ensure_connected_reconnects_on_failed_ping() -> None:
tool = MCPTool(name="test_tool")
tool.session = Mock(send_ping=AsyncMock(side_effect=RuntimeError("closed")))
with patch.object(tool, "connect", AsyncMock()) as mock_connect:
with patch.object(tool, "_reconnect_without_loading", AsyncMock()) as mock_reconnect:
await tool._ensure_connected()
mock_connect.assert_awaited_once_with(reset=True)
mock_reconnect.assert_awaited_once_with()
async def test_ensure_connected_wraps_reconnect_failure() -> None:
@@ -4046,12 +4134,70 @@ async def test_ensure_connected_wraps_reconnect_failure() -> None:
tool.session = Mock(send_ping=AsyncMock(side_effect=RuntimeError("closed")))
with (
patch.object(tool, "connect", AsyncMock(side_effect=RuntimeError("still closed"))),
patch.object(tool, "_reconnect_without_loading", AsyncMock(side_effect=RuntimeError("still closed"))),
pytest.raises(ToolExecutionException, match="Failed to establish MCP connection"),
):
await tool._ensure_connected()
async def test_load_tools_reconnects_on_closed_resource_when_ping_is_unavailable() -> None:
from anyio import ClosedResourceError
tool = MCPTool(name="test_tool", load_tools=True)
tool._ping_available = False
first_session = Mock()
first_session.list_tools = AsyncMock(side_effect=ClosedResourceError())
tool.session = first_session
page = Mock()
page.tools = []
page.nextCursor = None
second_session = Mock()
second_session.list_tools = AsyncMock(return_value=page)
async def reconnect() -> None:
tool.session = second_session
tool._supports_tools = True
with patch.object(tool, "_reconnect_without_loading", AsyncMock(side_effect=reconnect)) as mock_reconnect:
await tool.load_tools()
first_session.list_tools.assert_awaited_once()
mock_reconnect.assert_awaited_once_with()
second_session.list_tools.assert_awaited_once()
async def test_load_prompts_reconnects_on_closed_resource_when_ping_is_unavailable() -> None:
from anyio import ClosedResourceError
tool = MCPTool(name="test_tool", load_prompts=True)
tool._ping_available = False
first_session = Mock()
first_session.list_prompts = AsyncMock(side_effect=ClosedResourceError())
tool.session = first_session
page = Mock()
page.prompts = []
page.nextCursor = None
second_session = Mock()
second_session.list_prompts = AsyncMock(return_value=page)
async def reconnect() -> None:
tool.session = second_session
tool._supports_prompts = True
with patch.object(tool, "_reconnect_without_loading", AsyncMock(side_effect=reconnect)) as mock_reconnect:
await tool.load_prompts()
first_session.list_prompts.assert_awaited_once()
mock_reconnect.assert_awaited_once_with()
second_session.list_prompts.assert_awaited_once()
async def test_mcp_tool_filters_framework_kwargs():
"""Test that call_tool filters out framework-specific kwargs before calling MCP session.