Mirror of https://github.com/microsoft/agent-framework.git (synced 2026-06-16 21:04:09 +08:00)
Python: MCP long-running task support in Python (#6319)
* MCP long-running task support in Python

* Fix pyupgrade and AGENTS.md reconnect description

  - pyupgrade: drop forward-reference string annotations in _mcp.py (Python 3.10+ resolves them natively now that MCPTaskOptions is defined before use).
  - AGENTS.md: align reconnect description with current behavior. Phase 1 (initial tools/call) does NOT retry on connection loss; it raises 'connection lost; task state unknown' instead, so a server that accepted the request but lost the response cannot start the operation twice. Phase 2 (tasks/get / tasks/result) still reconnects once against the same task_id.

  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix bandit nosec marker for CI pipeline

* Address PR feedback

* Clarified comments and addressed more PR feedback.

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
committed by GitHub (unverified)
parent 01fc518b29
commit bf4ad48cf2
@@ -76,6 +76,19 @@ agent_framework/
|
||||
- **`SkillScriptRunner`** - Protocol for file-based script execution. Any callable matching `(skill, script, args) -> Any` satisfies it. Code-defined scripts do not use a runner.
|
||||
- **`SkillsProvider`** - Context provider (extends `ContextProvider`) that discovers file-based skills from `SKILL.md` files and/or accepts code-defined `Skill` instances. Follows progressive disclosure: advertise → load → read resources / run scripts.
|
||||
|
||||
### Model Context Protocol (`_mcp.py`)
|
||||
|
||||
- **`MCPTool`** - Base wrapper that owns the MCP `ClientSession` and exposes the remote server's tools as `FunctionTool`s.
|
||||
- **`MCPStdioTool`** / **`MCPStreamableHTTPTool`** / **`MCPWebsocketTool`** - Transport-specific subclasses.
|
||||
- **`MCPTaskOptions`** (experimental, `MCP_LONG_RUNNING_TASKS` feature, **frozen**) - Per-tool-instance options controlling the SEP-2663 long-running task lifecycle. When the server advertises a tool with `execution.taskSupport == "required"`, `MCPTool.call_tool` transparently routes through `call_tool_as_task`, which sends an augmented `tools/call`, polls `tasks/get` until terminal, and reinterprets `tasks/result` as a normal `CallToolResult`. Instances are immutable; replace via `MCPTool.task_options = MCPTaskOptions(...)`. Fields:
|
||||
- `default_ttl: timedelta | None` — forwarded to the server as `params.task.ttl` (milliseconds). When `None`, the server's default applies.
|
||||
- `cancel_remote_task_on_local_cancellation: bool = True` — only gates the `CancelledError` path. Abandonment paths (see below) always cancel.
|
||||
- `max_task_wait: timedelta | None` — client-side deadline for the whole post-create lifecycle (poll + result fetch). When exceeded, raises `ToolExecutionException` and fires a best-effort `tasks/cancel`. `None` (default) means no client-side bound. Bounds sleeps, sends, AND reconnects via `asyncio.wait_for`.
|
||||
- **Permissive fallback**: servers that ignore the augmentation (return `CallToolResult` directly) or reject the unknown `task` field with `METHOD_NOT_FOUND` / `INVALID_PARAMS` fall back to the plain `session.call_tool(...)` path so legacy servers keep working. An unparseable success response (server accepted the augmented call but returned a payload that is neither `CreateTaskResult` nor `CallToolResult`) **does not** fall back — it raises `ToolExecutionException` to avoid double-executing a side-effecting tool.
|
||||
- **Submit-vs-track reconnect policy**: a dropped connection before a `task_id` is known raises `ToolExecutionException("connection lost; task state unknown")` without re-issuing the augmented `tools/call`, so a server that accepted the request but lost the response cannot be made to start the same operation twice; once a `task_id` exists, `tasks/get` / `tasks/result` reconnect once and retry against the same id (a shared `_send_with_one_reconnect` helper).
|
||||
- **Cancel-on-abandonment vs terminal failure**: any path where the remote task may still be running (max-wait exceeded, hard `McpError` in poll, malformed `tasks/get`, second connection loss in poll/fetch, reconnect failure) fires best-effort `tasks/cancel` before raising. Terminal failures (`failed`/`cancelled`/`input_required` server-side, `completed+isError`, malformed `tasks/result` after server completed) do **not** cancel — the server is already done. `_MCPTaskAbandoned` is the private marker distinguishing the two.
|
||||
- **Transient poll retry**: a slow `tasks/get` that surfaces as `McpError(code=408 REQUEST_TIMEOUT)` is retried (bounded by `max_task_wait`). All other non-connection `McpError`s during poll are treated as abandonment. `tasks/result` does not get transient retry — the server has already completed, so a slow payload fetch is anomalous.
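The submit-vs-track reconnect policy described in the bullets above can be illustrated with a small standalone sketch. It does not use the framework or the MCP SDK: `ConnectionLost`, `ToolFailure`, `submit`, `track`, and `FlakySession` are hypothetical names introduced only for illustration, and the real helper (`_send_with_one_reconnect`) is not reproduced here.

```python
import asyncio


class ConnectionLost(Exception):
    """Hypothetical stand-in for a dropped MCP transport."""


class ToolFailure(Exception):
    """Hypothetical stand-in for ToolExecutionException."""


async def submit(send):
    """Phase 1: send the augmented tools/call exactly once."""
    try:
        return await send("tools/call")
    except ConnectionLost as ex:
        # The server may already have created the task, so never resend.
        raise ToolFailure("connection lost; task state unknown") from ex


async def track(send, reconnect, task_id, attempts=2):
    """Phase 2: retry against the same task_id after one reconnect."""
    for attempt in range(attempts):
        try:
            return await send(f"tasks/get {task_id}")
        except ConnectionLost as ex:
            if attempt == attempts - 1:
                raise ToolFailure("connection lost while tracking task") from ex
            await reconnect()


class FlakySession:
    """Fake session whose first `fail_first` sends drop the connection."""

    def __init__(self, fail_first):
        self.fail_first = fail_first
        self.sent = []
        self.reconnects = 0

    async def send(self, request):
        self.sent.append(request)
        if len(self.sent) <= self.fail_first:
            raise ConnectionLost(request)
        return "ok"

    async def reconnect(self):
        self.reconnects += 1
```

The asymmetry is the point of the design: resending `tools/call` could start a second operation, while resending `tasks/get` for a known id is idempotent.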
|
||||
|
||||
### File Access Harness (`_harness/_file_access.py`)
|
||||
|
||||
- **`AgentFileStore`** - Abstract async store backing the file-access harness. Implementations expose `write_file`, `read_file`, `delete_file`, `list_files`, `file_exists`, `search_files`, and `create_directory` over forward-slash relative paths.
|
||||
|
||||
@@ -124,7 +124,7 @@ from ._harness._todo import (
     TodoSessionStore,
     TodoStore,
 )
-from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPWebsocketTool
+from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPTaskOptions, MCPWebsocketTool
 from ._middleware import (
     AgentContext,
     AgentMiddleware,
@@ -444,12 +444,13 @@ __all__ = [
     "InlineSkillResource",
     "InlineSkillScript",
     "LocalEvaluator",
-    "MCPStdioTool",
-    "MCPStreamableHTTPTool",
-    "MCPWebsocketTool",
     "MCPSkill",
     "MCPSkillResource",
     "MCPSkillsSource",
+    "MCPStdioTool",
+    "MCPStreamableHTTPTool",
+    "MCPTaskOptions",
+    "MCPWebsocketTool",
     "MemoryContextProvider",
     "MemoryFileStore",
     "MemoryIndexEntry",
@@ -58,6 +58,7 @@ class ExperimentalFeature(str, Enum):
     FOUNDRY_PREVIEW_TOOLS = "FOUNDRY_PREVIEW_TOOLS"
     FUNCTIONAL_WORKFLOWS = "FUNCTIONAL_WORKFLOWS"
     HARNESS = "HARNESS"
+    MCP_LONG_RUNNING_TASKS = "MCP_LONG_RUNNING_TASKS"
     MCP_SKILLS = "MCP_SKILLS"
     PROGRESSIVE_TOOLS = "PROGRESSIVE_TOOLS"
     SKILLS = "SKILLS"
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import contextlib
|
||||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
@@ -12,12 +13,14 @@ import sys
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable, Collection, Coroutine, Mapping, Sequence
|
||||
from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast
|
||||
|
||||
from opentelemetry import propagate
|
||||
|
||||
from ._feature_stage import ExperimentalFeature, experimental
|
||||
from ._tools import FunctionTool
|
||||
from ._types import (
|
||||
ChatOptions,
|
||||
@@ -149,6 +152,73 @@ def _url_origin(url: Any) -> tuple[str, str, int | None]:
|
||||
return (url.scheme, url.host or "", port)
|
||||
|
||||
|
||||
# Internal polling bounds for MCP long-running tasks. Not user-tunable today;
|
||||
# promote to MCPTaskOptions if a concrete need arises.
|
||||
_MCP_TASK_MIN_POLL_INTERVAL = timedelta(milliseconds=500)
|
||||
_MCP_TASK_MAX_POLL_INTERVAL = timedelta(seconds=5)
|
||||
_MCP_TASK_CANCEL_TIMEOUT = timedelta(seconds=5)
|
||||
_MCP_TASK_TERMINAL_STATUSES: frozenset[str] = frozenset({"completed", "failed", "cancelled", "input_required"})
|
||||
|
||||
# Total send attempts for a Phase 2 request (initial try + one reconnect-and-retry).
|
||||
# A single transient disconnect should not abort a long-running task; sustained outages
|
||||
# surface as ``_MCPTaskAbandoned`` after the second failure.
|
||||
_MCP_RECONNECT_ATTEMPTS = 2
|
||||
|
||||
|
||||
class _MCPTaskAbandoned(ToolExecutionException):
|
||||
"""Raised when the remote MCP task may still be running and must be cancelled.
|
||||
|
||||
Subclass of ToolExecutionException so callers see a normal tool failure.
|
||||
"""
|
||||
|
||||
|
||||
class _MCPDeadlineExpired(Exception):
|
||||
"""Internal marker for ``max_task_wait`` expiry; distinct from inner TimeoutError."""
|
||||
|
||||
|
||||
@experimental(feature_id=ExperimentalFeature.MCP_LONG_RUNNING_TASKS)
|
||||
@dataclass(frozen=True)
|
||||
class MCPTaskOptions:
|
||||
"""Options controlling how MCPTool drives the MCP long-running task lifecycle.
|
||||
|
||||
When an MCP server advertises a tool with ``execution.taskSupport == "required"``,
|
||||
the framework transparently drives the SEP-2663 ``tools/call`` → ``tasks/get``
|
||||
(polled) → ``tasks/result`` lifecycle so the agent sees a normal tool result.
|
||||
|
||||
Instances are immutable; replace the whole object via
|
||||
``MCPTool.task_options = MCPTaskOptions(...)`` to change behavior.
|
||||
|
||||
Attributes:
|
||||
default_ttl: Optional task-record retention time forwarded to the server as
|
||||
``params.task.ttl`` (milliseconds, integer). The server keeps the task
|
||||
record around this long after the task reaches a terminal status so the
|
||||
client can still call ``tasks/get`` / ``tasks/result``; it does not
|
||||
cancel a running task. When ``None``, the server applies its own default.
|
||||
Must be positive if set (zero would expire the record before any client
|
||||
could read it).
|
||||
cancel_remote_task_on_local_cancellation: If True (default), a local
|
||||
cancellation of the awaiting coroutine triggers a best-effort
|
||||
``tasks/cancel`` on the server before re-raising ``CancelledError``.
|
||||
Only gates ``CancelledError``; abandonment paths (max-wait,
|
||||
unrecoverable poll errors, lost connection after task_id is known)
|
||||
always cancel regardless of this flag.
|
||||
max_task_wait: Optional client-side deadline for the whole post-create
|
||||
lifecycle (poll + result fetch). When exceeded, raises
|
||||
``ToolExecutionException`` and fires a best-effort ``tasks/cancel``.
|
||||
``None`` (default) means no client-side bound. Must be positive if set.
|
||||
"""
|
||||
|
||||
default_ttl: timedelta | None = None
|
||||
cancel_remote_task_on_local_cancellation: bool = True
|
||||
max_task_wait: timedelta | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.default_ttl is not None and self.default_ttl.total_seconds() <= 0:
|
||||
raise ValueError("MCPTaskOptions.default_ttl must be positive.")
|
||||
if self.max_task_wait is not None and self.max_task_wait.total_seconds() <= 0:
|
||||
raise ValueError("MCPTaskOptions.max_task_wait must be positive.")
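As a reading aid for the frozen options class above, here is a minimal standalone sketch of the same pattern. `TaskOptionsSketch` and `ttl_milliseconds` are hypothetical stand-ins that mirror the fields and the millisecond conversion described in this diff; they are not the framework's classes, and using `dataclasses.replace` to derive a modified copy is a suggestion rather than something the commit prescribes.

```python
from __future__ import annotations

from dataclasses import FrozenInstanceError, dataclass, replace
from datetime import timedelta


@dataclass(frozen=True)
class TaskOptionsSketch:
    """Hypothetical stand-in mirroring the fields of MCPTaskOptions."""

    default_ttl: timedelta | None = None
    cancel_remote_task_on_local_cancellation: bool = True
    max_task_wait: timedelta | None = None

    def __post_init__(self) -> None:
        # Reject zero or negative durations, as the real class does.
        for name in ("default_ttl", "max_task_wait"):
            value = getattr(self, name)
            if value is not None and value.total_seconds() <= 0:
                raise ValueError(f"{name} must be positive.")


def ttl_milliseconds(options: TaskOptionsSketch) -> int | None:
    """Convert default_ttl to the integer milliseconds sent as params.task.ttl."""
    if options.default_ttl is None:
        return None
    return int(options.default_ttl.total_seconds() * 1000)


base = TaskOptionsSketch(default_ttl=timedelta(minutes=5))
# Frozen instances cannot be mutated; build a modified copy instead.
bounded = replace(base, max_task_wait=timedelta(seconds=30))
```

Because `replace` constructs a new instance, the `__post_init__` validation runs again on the copy.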
|
||||
|
||||
|
||||
def streamable_http_client(*args: Any, **kwargs: Any) -> _AsyncGeneratorContextManager[Any, None]:
|
||||
"""Lazily import the MCP streamable HTTP transport."""
|
||||
try:
|
||||
@@ -217,6 +287,7 @@ class MCPTool:
|
||||
request_timeout: int | None = None,
|
||||
client: SupportsChatGetResponse | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
task_options: MCPTaskOptions | None = None,
|
||||
) -> None:
|
||||
"""Initialize the MCP Tool base.
|
||||
|
||||
@@ -248,6 +319,9 @@ class MCPTool:
|
||||
request_timeout: Timeout in seconds for MCP requests.
|
||||
client: A chat client for sampling callbacks.
|
||||
additional_properties: Additional properties for the tool.
|
||||
task_options: Options controlling how long-running MCP tasks are driven for
|
||||
tools that advertise ``execution.taskSupport == "required"``. When ``None``,
|
||||
the defaults from :class:`MCPTaskOptions` are used.
|
||||
"""
|
||||
self.name = name
|
||||
self.description = description or ""
|
||||
@@ -259,6 +333,10 @@ class MCPTool:
|
||||
self.parse_tool_results = parse_tool_results
|
||||
self.load_prompts_flag = load_prompts
|
||||
self.parse_prompt_results = parse_prompt_results
|
||||
# Defer constructing the default MCPTaskOptions so the experimental warning only fires
# when the long-running task path is actually engaged (lazily resolved by _effective_task_options).
|
||||
self._task_options_explicit: MCPTaskOptions | None = task_options
|
||||
self._task_options_default: MCPTaskOptions | None = None
|
||||
self._exit_stack = AsyncExitStack()
|
||||
self._lifecycle_lock = asyncio.Lock()
|
||||
self._lifecycle_request_lock = asyncio.Lock()
|
||||
@@ -270,6 +348,7 @@ class MCPTool:
|
||||
self.client = client
|
||||
self._functions: list[FunctionTool] = []
|
||||
self._tool_call_meta_by_name: dict[str, dict[str, Any]] = {}
|
||||
self._tool_task_support_by_name: dict[str, str] = {}
|
||||
self.is_connected: bool = False
|
||||
self._tools_loaded: bool = False
|
||||
self._prompts_loaded: bool = False
|
||||
@@ -1131,6 +1210,7 @@ class MCPTool:
|
||||
# Track existing function names to prevent duplicates
|
||||
existing_names = {func.name for func in self._functions}
|
||||
tool_call_meta_by_name: dict[str, dict[str, Any]] = {}
|
||||
tool_task_support_by_name: dict[str, str] = {}
|
||||
|
||||
params: types.PaginatedRequestParams | None = None
|
||||
while True:
|
||||
@@ -1168,6 +1248,10 @@ class MCPTool:
|
||||
if tool.meta is not None:
|
||||
tool_call_meta_by_name[tool.name] = dict(tool.meta)
|
||||
|
||||
task_support = getattr(getattr(tool, "execution", None), "taskSupport", None)
|
||||
if task_support is not None:
|
||||
tool_task_support_by_name[tool.name] = task_support
|
||||
|
||||
normalized_name = _normalize_mcp_name(tool.name)
|
||||
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
|
||||
|
||||
@@ -1216,6 +1300,7 @@ class MCPTool:
|
||||
params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)
|
||||
|
||||
self._tool_call_meta_by_name = tool_call_meta_by_name
|
||||
self._tool_task_support_by_name = tool_task_support_by_name
|
||||
|
||||
async def _close_on_owner(self) -> None:
|
||||
# Cancel any pending reload tasks before tearing down the session.
|
||||
@@ -1292,6 +1377,29 @@ class MCPTool:
|
||||
inner_exception=ex,
|
||||
) from ex
|
||||
|
||||
def _effective_task_options(self) -> MCPTaskOptions:
|
||||
"""Return the effective MCPTaskOptions, lazily constructing defaults on first use.
|
||||
|
||||
Defers the implicit ``MCPTaskOptions()`` so the experimental warning only
fires when the long-running task path is actually engaged (server advertises ``taskSupport=required``).
|
||||
"""
|
||||
explicit = self._task_options_explicit
|
||||
if explicit is not None:
|
||||
return explicit
|
||||
if self._task_options_default is None:
|
||||
self._task_options_default = MCPTaskOptions()
|
||||
return self._task_options_default
|
||||
|
||||
@property
|
||||
def task_options(self) -> MCPTaskOptions:
|
||||
"""The effective MCPTaskOptions for this tool (lazy defaults)."""
|
||||
return self._effective_task_options()
|
||||
|
||||
@task_options.setter
|
||||
def task_options(self, value: MCPTaskOptions | None) -> None:
|
||||
self._task_options_explicit = value
|
||||
self._task_options_default = None
|
||||
|
||||
async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]:
|
||||
"""Call a tool with the given arguments.
|
||||
|
||||
@@ -1322,47 +1430,12 @@ class MCPTool:
|
||||
"Tools are not loaded for this server, please set load_tools=True in the constructor."
|
||||
)
|
||||
|
||||
raw_user_meta: object | None = kwargs.get("_meta")
|
||||
user_meta: dict[str, Any] | None = None
|
||||
if raw_user_meta is not None and not isinstance(raw_user_meta, dict):
|
||||
raise ToolExecutionException("MCP tool metadata provided via _meta must be a dict.")
|
||||
if isinstance(raw_user_meta, dict):
|
||||
raw_user_meta_dict = cast(Mapping[object, object], raw_user_meta)
|
||||
user_meta = {}
|
||||
for key, value in raw_user_meta_dict.items():
|
||||
if not isinstance(key, str):
|
||||
raise ToolExecutionException("MCP tool metadata provided via _meta must use string keys.")
|
||||
user_meta[key] = value
|
||||
# Tools advertising taskSupport == "required" cannot complete via plain tools/call;
|
||||
# route through the long-running task lifecycle transparently.
|
||||
if self._tool_task_support_by_name.get(tool_name) == "required":
|
||||
return await self.call_tool_as_task(tool_name, **kwargs)
|
||||
|
||||
# Filter out framework kwargs that cannot be serialized by the MCP SDK.
|
||||
# These are internal objects passed through the function invocation pipeline
|
||||
# that should not be forwarded to external MCP servers.
|
||||
# conversation_id is an internal tracking ID used by services like Azure AI.
|
||||
# options contains metadata/store used by AG-UI for Azure AI client requirements.
|
||||
# response_format is a Pydantic model class used for structured output (not serializable).
|
||||
filtered_kwargs = {
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if k
|
||||
not in {
|
||||
"chat_options",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"session",
|
||||
"thread",
|
||||
"conversation_id",
|
||||
"options",
|
||||
"response_format",
|
||||
"_meta",
|
||||
}
|
||||
}
|
||||
|
||||
# Some MCP proxies require their tools/list metadata to be echoed on tools/call.
|
||||
tool_meta = self._tool_call_meta_by_name.get(tool_name)
|
||||
request_meta = dict(tool_meta) if tool_meta is not None else None
|
||||
if user_meta is not None:
|
||||
request_meta = {**(request_meta or {}), **user_meta}
|
||||
meta = _inject_otel_into_mcp_meta(request_meta)
|
||||
filtered_kwargs, meta = self._prepare_call_kwargs(tool_name, kwargs)
|
||||
|
||||
parser = self.parse_tool_results or self._parse_tool_result_from_mcp
|
||||
# Try the operation, reconnecting once if the connection is closed
|
||||
@@ -1411,6 +1484,479 @@ class MCPTool:
|
||||
raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex
|
||||
raise ToolExecutionException(f"Failed to call tool '{tool_name}' after retries.")
|
||||
|
||||
def _prepare_call_kwargs(
|
||||
self, tool_name: str, kwargs: dict[str, Any]
|
||||
) -> tuple[dict[str, Any], dict[str, Any] | None]:
|
||||
"""Filter framework-only kwargs and build the merged MCP request metadata."""
|
||||
raw_user_meta: object | None = kwargs.get("_meta")
|
||||
user_meta: dict[str, Any] | None = None
|
||||
if raw_user_meta is not None and not isinstance(raw_user_meta, dict):
|
||||
raise ToolExecutionException("MCP tool metadata provided via _meta must be a dict.")
|
||||
if isinstance(raw_user_meta, dict):
|
||||
raw_user_meta_dict = cast(Mapping[object, object], raw_user_meta)
|
||||
user_meta = {}
|
||||
for key, value in raw_user_meta_dict.items():
|
||||
if not isinstance(key, str):
|
||||
raise ToolExecutionException("MCP tool metadata provided via _meta must use string keys.")
|
||||
user_meta[key] = value
|
||||
|
||||
# Filter out framework kwargs that cannot be serialized by the MCP SDK.
|
||||
# These are internal objects passed through the function invocation pipeline
|
||||
# that should not be forwarded to external MCP servers.
|
||||
# conversation_id is an internal tracking ID used by services like Azure AI.
|
||||
# options contains metadata/store used by AG-UI for Azure AI client requirements.
|
||||
# response_format is a Pydantic model class used for structured output (not serializable).
|
||||
filtered_kwargs = {
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if k
|
||||
not in {
|
||||
"chat_options",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"session",
|
||||
"thread",
|
||||
"conversation_id",
|
||||
"options",
|
||||
"response_format",
|
||||
"_meta",
|
||||
}
|
||||
}
|
||||
|
||||
# Some MCP proxies require their tools/list metadata to be echoed on tools/call.
|
||||
tool_meta = self._tool_call_meta_by_name.get(tool_name)
|
||||
request_meta = dict(tool_meta) if tool_meta is not None else None
|
||||
if user_meta is not None:
|
||||
request_meta = {**(request_meta or {}), **user_meta}
|
||||
meta = _inject_otel_into_mcp_meta(request_meta)
|
||||
return filtered_kwargs, meta
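The filtering and metadata precedence in `_prepare_call_kwargs` can be summarized with a standalone sketch. `split_call_kwargs` and `FRAMEWORK_ONLY` are hypothetical names; the sketch omits the OpenTelemetry injection step and the string-key check, and raises `TypeError` where the real method raises `ToolExecutionException`.

```python
FRAMEWORK_ONLY = {
    "chat_options",
    "tools",
    "tool_choice",
    "session",
    "thread",
    "conversation_id",
    "options",
    "response_format",
    "_meta",
}


def split_call_kwargs(kwargs, tool_meta=None):
    """Return (arguments, merged_meta) for an outgoing tools/call."""
    user_meta = kwargs.get("_meta")
    if user_meta is not None and not isinstance(user_meta, dict):
        raise TypeError("_meta must be a dict")
    # Drop framework-internal kwargs that must not reach the server.
    arguments = {k: v for k, v in kwargs.items() if k not in FRAMEWORK_ONLY}
    # Start from the tools/list metadata, then let caller-supplied keys win.
    merged = dict(tool_meta) if tool_meta is not None else None
    if user_meta is not None:
        merged = {**(merged or {}), **user_meta}
    return arguments, merged
```

Caller-supplied `_meta` keys override the values echoed from `tools/list`, and the result stays `None` when neither source provides metadata.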
|
||||
|
||||
async def call_tool_as_task(self, tool_name: str, **kwargs: Any) -> str | list[Content]:
|
||||
"""Call an MCP tool via the long-running task lifecycle (SEP-2663).
|
||||
|
||||
Issues an augmented ``tools/call`` with ``params.task`` set from
|
||||
``self.task_options``, then polls ``tasks/get`` until the server reports a
|
||||
terminal status. On ``completed`` the payload is fetched via ``tasks/result``,
|
||||
validated as a ``CallToolResult`` and parsed identically to :meth:`call_tool`.
|
||||
|
||||
Local cancellation triggers a best-effort ``tasks/cancel`` (controlled by
|
||||
:attr:`MCPTaskOptions.cancel_remote_task_on_local_cancellation`) before
|
||||
``asyncio.CancelledError`` is re-raised.
|
||||
|
||||
Args:
|
||||
tool_name: The remote MCP tool name.
|
||||
|
||||
Keyword Args:
|
||||
kwargs: Arguments forwarded to the tool. See :meth:`call_tool` for the
|
||||
framework kwargs that are filtered out.
|
||||
|
||||
Returns:
|
||||
A list of Content items (or a string when a custom ``parse_tool_results``
|
||||
callback is configured).
|
||||
"""
|
||||
from anyio import ClosedResourceError
|
||||
from mcp.shared.exceptions import McpError
|
||||
|
||||
if not self.load_tools_flag:
|
||||
raise ToolExecutionException(
|
||||
"Tools are not loaded for this server, please set load_tools=True in the constructor."
|
||||
)
|
||||
|
||||
filtered_kwargs, meta = self._prepare_call_kwargs(tool_name, kwargs)
|
||||
parser = self.parse_tool_results or self._parse_tool_result_from_mcp
|
||||
|
||||
# Submit the task: issue augmented tools/call. Do NOT retry on connection loss here:
|
||||
# the server may have accepted the request and created a task before the
|
||||
# response was lost, so retrying could start the long-running operation twice.
|
||||
# Reconnect-and-retry is only safe after the task_id is known.
|
||||
try:
|
||||
task_id, fallback_result = await self._call_tool_as_task_create(tool_name, filtered_kwargs, meta)
|
||||
except (ClosedResourceError, McpError) as ex:
|
||||
if not self._is_connection_lost(ex):
|
||||
error_message = ex.error.message if isinstance(ex, McpError) else str(ex)
|
||||
raise ToolExecutionException(error_message, inner_exception=ex) from ex
|
||||
raise ToolExecutionException(
|
||||
f"Failed to call tool '{tool_name}' - connection lost; task state unknown.",
|
||||
inner_exception=ex,
|
||||
) from ex
|
||||
except ToolExecutionException:
|
||||
raise
|
||||
except Exception as ex:
|
||||
raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex
|
||||
|
||||
# Server returned a CallToolResult (no task created) or fell back to plain tools/call.
|
||||
if fallback_result is not None:
|
||||
if fallback_result.isError:
|
||||
parsed = parser(fallback_result)
|
||||
text = (
|
||||
"\n".join(c.text for c in parsed if c.type == "text" and c.text)
|
||||
if isinstance(parsed, list)
|
||||
else str(parsed)
|
||||
)
|
||||
raise ToolExecutionException(text or str(parsed))
|
||||
return parser(fallback_result)
|
||||
|
||||
if task_id is None:
|
||||
raise ToolExecutionException(
|
||||
f"MCP server did not return a task_id or fallback result for '{tool_name}'."
|
||||
)
|
||||
|
||||
# Track to completion: poll until terminal, then fetch payload. Never re-issue
|
||||
# tools/call past this point; reconnect-and-retry only against the same task_id.
|
||||
opts = self._effective_task_options()
|
||||
max_wait_s = opts.max_task_wait.total_seconds() if opts.max_task_wait is not None else None
|
||||
|
||||
async def _await_task_completion() -> str | list[Content]:
|
||||
terminal = await self._poll_task_until_terminal(task_id)
|
||||
return await self._handle_terminal_task(tool_name, task_id, terminal, parser)
|
||||
|
||||
try:
|
||||
if max_wait_s is not None:
|
||||
try:
|
||||
result = await self._await_with_deadline(_await_task_completion(), max_wait_s)
|
||||
return cast("str | list[Content]", result)
|
||||
except _MCPDeadlineExpired as ex:
|
||||
self._spawn_best_effort_cancel(task_id)
|
||||
raise ToolExecutionException(
|
||||
f"MCP task '{task_id}' exceeded max_task_wait of {max_wait_s}s.",
|
||||
inner_exception=ex,
|
||||
) from ex
|
||||
else:
|
||||
return await _await_task_completion()
|
||||
except asyncio.CancelledError:
|
||||
if opts.cancel_remote_task_on_local_cancellation:
|
||||
self._spawn_best_effort_cancel(task_id)
|
||||
raise
|
||||
except _MCPTaskAbandoned:
|
||||
# Pre-terminal abandonment (hard poll error, malformed get, second
|
||||
# disconnect, reconnect failure): cancel + re-raise as plain
|
||||
# ToolExecutionException to the function-calling loop.
|
||||
self._spawn_best_effort_cancel(task_id)
|
||||
raise
|
||||
# Plain ToolExecutionException from terminal failures (failed/cancelled/
|
||||
# input_required, completed+isError, malformed result post-completion)
|
||||
# propagates without cancel — server is already done.
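The body of `_await_with_deadline` is not part of this excerpt. The sketch below shows one way to keep an outer deadline distinct from a `TimeoutError` raised inside the awaited work, which is the distinction `_MCPDeadlineExpired` exists for. It is an assumption-laden illustration: `DeadlineExpired` and `await_with_deadline` are hypothetical names, and it uses `asyncio.wait` rather than the `asyncio.wait_for` mentioned in AGENTS.md.

```python
import asyncio


class DeadlineExpired(Exception):
    """Hypothetical marker for the outer client-side deadline."""


async def await_with_deadline(coro, seconds):
    """Run coro, raising DeadlineExpired only when the outer deadline passes."""
    task = asyncio.ensure_future(coro)
    try:
        done, _pending = await asyncio.wait({task}, timeout=seconds)
    except asyncio.CancelledError:
        # Local cancellation: stop the inner work and propagate.
        task.cancel()
        raise
    if not done:
        task.cancel()
        # Let the cancelled task finish unwinding before reporting expiry.
        await asyncio.gather(task, return_exceptions=True)
        raise DeadlineExpired(f"exceeded {seconds}s")
    # A TimeoutError raised by the inner work surfaces here unchanged.
    return task.result()
```

With this shape, a caller can fire the best-effort cancel only on `DeadlineExpired` and treat an inner timeout like any other failure of the awaited work.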
|
||||
|
||||
async def _call_tool_as_task_create(
|
||||
self, tool_name: str, arguments: dict[str, Any], meta: dict[str, Any] | None
|
||||
) -> tuple[str | None, types.CallToolResult | None]:
|
||||
"""Send the augmented tools/call.
|
||||
|
||||
Returns ``(task_id, None)`` when the server created a task,
|
||||
``(None, CallToolResult)`` when it returned a non-task result, falling back
|
||||
to plain ``tools/call`` if the server rejects the ``task`` field outright.
|
||||
"""
|
||||
from mcp import types
|
||||
from mcp.shared.exceptions import McpError
|
||||
from pydantic import ValidationError
|
||||
|
||||
opts = self._effective_task_options()
|
||||
ttl_ms: int | None = None
|
||||
if opts.default_ttl is not None:
|
||||
ttl_ms = int(opts.default_ttl.total_seconds() * 1000)
|
||||
# Always send TaskMetadata to mark the call as task-augmented; ttl may be omitted.
|
||||
task_metadata = types.TaskMetadata(ttl=ttl_ms)
|
||||
|
||||
request_meta = types.RequestParams.Meta(**meta) if meta else None
|
||||
params = types.CallToolRequestParams(
|
||||
name=tool_name,
|
||||
arguments=arguments,
|
||||
task=task_metadata,
|
||||
_meta=request_meta, # type: ignore[call-arg]
|
||||
)
|
||||
request = types.ClientRequest(types.CallToolRequest(params=params))
|
||||
|
||||
# Use the lenient Result type so we can extract the task_id even when
|
||||
# the strict CreateTaskResult schema rejects the payload (the MCP Python
|
||||
# SDK requires Task.ttl, but servers may legitimately omit it).
|
||||
try:
|
||||
lenient = await self.session.send_request( # type: ignore[union-attr]
|
||||
request,
|
||||
types.Result,
|
||||
)
|
||||
except McpError as ex:
|
||||
if ex.error.code not in (types.METHOD_NOT_FOUND, types.INVALID_PARAMS):
|
||||
raise
|
||||
logger.debug(
|
||||
"Server rejected augmented tools/call for '%s' (code=%s); falling back.",
|
||||
tool_name,
|
||||
ex.error.code,
|
||||
)
|
||||
fallback = await self.session.call_tool(tool_name, arguments=arguments, meta=meta) # type: ignore[union-attr]
|
||||
return None, fallback
|
||||
|
||||
# Inspect the raw payload: a CreateTaskResult carries `task.taskId`;
|
||||
# a legacy CallToolResult carries `content` and/or `isError`.
|
||||
raw: dict[str, Any] = lenient.model_dump(by_alias=True, exclude_none=True)
|
||||
raw.pop("_meta", None)
|
||||
|
||||
task_field = raw.get("task")
|
||||
if isinstance(task_field, dict):
|
||||
task_id_val = cast(dict[str, Any], task_field).get("taskId")
|
||||
if isinstance(task_id_val, str):
|
||||
return task_id_val, None
|
||||
|
||||
try:
|
||||
legacy = types.CallToolResult.model_validate(raw)
|
||||
except ValidationError as ex:
|
||||
# Augmented call succeeded server-side; re-issuing a plain tools/call
|
||||
# could double-execute a side-effecting tool.
|
||||
raise ToolExecutionException(
|
||||
f"MCP server returned an unparseable response to augmented tools/call "
|
||||
f"for '{tool_name}'; cannot safely retry (server may have started the operation).",
|
||||
inner_exception=ex,
|
||||
) from ex
|
||||
|
||||
return None, legacy
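The three outcomes of `_call_tool_as_task_create` for a successful response (task created, plain result, unparseable) can be sketched without the SDK. `classify_create_response` is a hypothetical helper, and checking for a `content` list is a deliberate simplification of the `CallToolResult.model_validate` call used above.

```python
def classify_create_response(raw):
    """Classify the raw payload returned for an augmented tools/call."""
    task = raw.get("task")
    if isinstance(task, dict) and isinstance(task.get("taskId"), str):
        # CreateTaskResult shape: track this id from now on.
        return ("task", task["taskId"])
    if isinstance(raw.get("content"), list):
        # Server ignored the augmentation and answered directly.
        return ("result", raw)
    # Neither shape: the server accepted the call, so re-issuing it
    # could run a side-effecting tool twice. Fail instead of retrying.
    raise ValueError("unparseable response to augmented tools/call")
```

Only a rejection of the request itself (`METHOD_NOT_FOUND` or `INVALID_PARAMS`) leads to a plain `tools/call` retry; an accepted but unreadable response never does.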
|
||||
|
||||
async def _poll_task_until_terminal(self, task_id: str) -> types.GetTaskResult:
|
||||
"""Poll ``tasks/get`` until the task reaches a terminal status."""
|
||||
import httpx
|
||||
from mcp import types
|
||||
from mcp.shared.exceptions import McpError
|
||||
|
||||
# The SDK raises McpError with code httpx.codes.REQUEST_TIMEOUT (408) on a session read timeout.
|
||||
transient_codes: frozenset[int] = frozenset({int(httpx.codes.REQUEST_TIMEOUT)})
|
||||
|
||||
while True:
|
||||
request = types.ClientRequest(
|
||||
types.GetTaskRequest(params=types.GetTaskRequestParams(taskId=task_id))
|
||||
)
|
||||
try:
|
||||
# GetTaskResult.ttl is required-but-Optional in the SDK; coerce below.
|
||||
lenient = await self._send_with_one_reconnect(
|
||||
request, types.Result, operation="tasks/get", task_id=task_id
|
||||
)
|
||||
except McpError as ex:
|
||||
if ex.error.code in transient_codes:
|
||||
logger.debug(
|
||||
"Transient %s on tasks/get for '%s'; will retry.", ex.error.code, task_id
|
||||
)
|
||||
await asyncio.sleep(_MCP_TASK_MIN_POLL_INTERVAL.total_seconds())
|
||||
continue
|
||||
# Hard server error mid-poll: task may still be running.
|
||||
raise _MCPTaskAbandoned(ex.error.message, inner_exception=ex) from ex
|
||||
|
||||
try:
|
||||
snapshot = self._coerce_get_task_result(lenient, task_id)
|
||||
except ToolExecutionException as ex:
|
||||
# Malformed tasks/get response; task may still be running.
|
||||
raise _MCPTaskAbandoned(str(ex), inner_exception=ex) from ex
|
||||
|
||||
if snapshot.status in _MCP_TASK_TERMINAL_STATUSES:
|
||||
return snapshot
|
||||
|
||||
await asyncio.sleep(self._compute_poll_delay(snapshot.pollInterval).total_seconds())
|
||||
|
||||
@staticmethod
|
||||
def _coerce_get_task_result(lenient: types.Result, task_id: str) -> types.GetTaskResult:
|
||||
"""Coerce a lenient Result into GetTaskResult, defaulting ``ttl`` when absent."""
|
||||
from mcp import types
|
||||
|
||||
raw = lenient.model_dump(by_alias=True, exclude_none=True)
|
||||
raw.pop("_meta", None)
|
||||
raw.setdefault("ttl", None)
|
||||
try:
|
||||
return types.GetTaskResult.model_validate(raw)
|
||||
except Exception as ex:
|
||||
raise ToolExecutionException(
|
||||
f"MCP server returned a malformed tasks/get response for task '{task_id}'.",
|
||||
inner_exception=ex,
|
||||
) from ex
|
||||
|
||||
@staticmethod
|
||||
def _compute_poll_delay(server_interval_ms: int | None) -> timedelta:
|
||||
"""Clamp the server-suggested poll interval to ``[min, max]``."""
|
||||
if server_interval_ms is None or server_interval_ms <= 0:
|
||||
return _MCP_TASK_MIN_POLL_INTERVAL
|
||||
suggested = timedelta(milliseconds=server_interval_ms)
|
||||
if suggested < _MCP_TASK_MIN_POLL_INTERVAL:
|
||||
return _MCP_TASK_MIN_POLL_INTERVAL
|
||||
if suggested > _MCP_TASK_MAX_POLL_INTERVAL:
|
||||
return _MCP_TASK_MAX_POLL_INTERVAL
|
||||
return suggested
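The clamping rule in `_compute_poll_delay` is easy to check in isolation. The sketch below restates it with `min`/`max`, using the 500 ms and 5 s bounds defined earlier in this diff; `clamp_poll_delay`, `MIN_POLL`, and `MAX_POLL` are names chosen for illustration only.

```python
from datetime import timedelta

MIN_POLL = timedelta(milliseconds=500)
MAX_POLL = timedelta(seconds=5)


def clamp_poll_delay(server_interval_ms):
    """Clamp the server-suggested poll interval to [MIN_POLL, MAX_POLL]."""
    if server_interval_ms is None or server_interval_ms <= 0:
        # No usable hint from the server: poll at the minimum interval.
        return MIN_POLL
    suggested = timedelta(milliseconds=server_interval_ms)
    return min(max(suggested, MIN_POLL), MAX_POLL)
```

A server hint of 100 ms is raised to 500 ms, 2000 ms is used as given, and 60000 ms is capped at 5 s.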

    async def _handle_terminal_task(
        self,
        tool_name: str,
        task_id: str,
        snapshot: types.GetTaskResult,
        parser: Callable[[types.CallToolResult], str | list[Content]],
    ) -> str | list[Content]:
        """Map a terminal task snapshot to either a parsed result or an exception."""
        status = snapshot.status
        if status == "completed":
            payload = await self._fetch_task_result(task_id)
            if payload.isError:
                parsed = parser(payload)
                text = (
                    "\n".join(c.text for c in parsed if c.type == "text" and c.text)
                    if isinstance(parsed, list)
                    else str(parsed)
                )
                raise ToolExecutionException(text or str(parsed))
            return parser(payload)

        # Non-completed terminal statuses surface as ToolExecutionException so the
        # function-calling loop sees a normal failure for tool_name.
        message = snapshot.statusMessage or f"MCP task ended with status '{status}'."
        if status == "input_required":
            # Spec-non-terminal; treated as terminal here because the framework does
            # not implement the interactive input flow.
            message = snapshot.statusMessage or "MCP task requires additional input and cannot continue."
        raise ToolExecutionException(f"Tool '{tool_name}' task {status}: {message}")

    async def _fetch_task_result(self, task_id: str) -> types.CallToolResult:
        """Send ``tasks/result`` and reinterpret the open-typed payload as a CallToolResult."""
        from mcp import types
        from mcp.shared.exceptions import McpError
        from pydantic import ValidationError

        request = types.ClientRequest(
            types.GetTaskPayloadRequest(params=types.GetTaskPayloadRequestParams(taskId=task_id))
        )
        # Connection-loss retry only via the helper; no transient-code retry, because the
        # server has already completed the task, so a slow payload fetch is anomalous.
        try:
            payload = await self._send_with_one_reconnect(
                request, types.GetTaskPayloadResult, operation="tasks/result", task_id=task_id
            )
        except McpError as ex:
            # Server reported completed; a hard fetch error is a plain failure (no cancel).
            raise ToolExecutionException(ex.error.message, inner_exception=ex) from ex

        # GetTaskPayloadResult carries the tool result via extra fields; reinterpret as CallToolResult.
        payload_dict = payload.model_dump(by_alias=True, exclude_none=True)
        payload_dict.pop("_meta", None)
        try:
            return types.CallToolResult.model_validate(payload_dict)
        except ValidationError as ex:
            # Server reported completed; malformed payload is a plain failure (no cancel needed).
            raise ToolExecutionException(
                f"MCP task '{task_id}' result payload could not be parsed as a CallToolResult.",
                inner_exception=ex,
            ) from ex

    async def _send_with_one_reconnect(
        self,
        request: types.ClientRequest,
        result_type: type[Any],
        *,
        operation: str,
        task_id: str,
    ) -> Any:
        """Send ``request`` with one reconnect-and-retry on connection loss.

        After a second loss (or reconnect failure), raise ``_MCPTaskAbandoned``.
        Non-connection errors propagate unchanged.
        """
        from anyio import ClosedResourceError
        from mcp.shared.exceptions import McpError

        for attempt in range(_MCP_RECONNECT_ATTEMPTS):
            try:
                return await self.session.send_request(request, result_type)  # type: ignore[union-attr]
            except (ClosedResourceError, McpError) as ex:
                if not self._is_connection_lost(ex):
                    raise
                if attempt < _MCP_RECONNECT_ATTEMPTS - 1:
                    logger.info(
                        "MCP connection lost during %s; reconnecting (task_id=%s).", operation, task_id
                    )
                    try:
                        await self.connect(reset=True)
                    except Exception as reconn_ex:
                        # Reconnect failure: task may still be running.
                        raise _MCPTaskAbandoned(
                            "Failed to reconnect to MCP server.", inner_exception=reconn_ex
                        ) from reconn_ex
                    continue
                # Final attempt also lost the connection: task may still be running.
                raise _MCPTaskAbandoned(
                    f"MCP connection lost; task state unknown (task_id={task_id}).",
                    inner_exception=ex,
                ) from ex
        raise AssertionError(f"unreachable: {operation} for {task_id}")  # pragma: no cover
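
The control flow of `_send_with_one_reconnect` can be exercised without an MCP server using the standalone sketch below. `ConnectionLost` and `TaskAbandoned` are hypothetical stand-ins for the transport errors and `_MCPTaskAbandoned`. The default of two attempts is an assumption inferred from the "one reconnect-and-retry" docstring, since the value of `_MCP_RECONNECT_ATTEMPTS` is not shown in this hunk.

```python
import asyncio


class ConnectionLost(Exception):
    """Hypothetical stand-in for a transport-level failure."""


class TaskAbandoned(Exception):
    """Hypothetical stand-in for _MCPTaskAbandoned."""


async def send_with_one_reconnect(send, reconnect, attempts=2):
    """Call ``send``; on connection loss, reconnect once and retry."""
    for attempt in range(attempts):
        try:
            return await send()
        except ConnectionLost as ex:
            if attempt < attempts - 1:
                try:
                    await reconnect()
                except Exception as reconn_ex:
                    # Reconnect failed: the remote task may still be running.
                    raise TaskAbandoned("reconnect failed") from reconn_ex
                continue
            # The retry also lost the connection: task state is unknown.
            raise TaskAbandoned("connection lost; task state unknown") from ex
    raise AssertionError("unreachable")


calls = []


async def flaky_send():
    """Fail on the first call, succeed on the second."""
    calls.append("send")
    if calls.count("send") == 1:
        raise ConnectionLost
    return "ok"


async def reconnect():
    calls.append("reconnect")


result = asyncio.run(send_with_one_reconnect(flaky_send, reconnect))
```

The retry reuses the same request, which matches the commit message's note that phase 2 (`tasks/get` / `tasks/result`) reconnects once against the same `task_id`.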

    @staticmethod
    async def _await_with_deadline(coro: Coroutine[Any, Any, Any], timeout_s: float) -> Any:
        """Await ``coro`` with a deadline; raise ``_MCPDeadlineExpired`` only on deadline.

        Unlike ``asyncio.wait_for``, an ``asyncio.TimeoutError`` raised by ``coro``
        itself propagates unchanged so callers can distinguish their own deadline
        from a stray inner timeout.
        """
        inner = asyncio.ensure_future(coro)
        try:
            done, _pending = await asyncio.wait({inner}, timeout=timeout_s)
        except BaseException:
            # Outer caller cancelled (or another exception): cancel inner + drain.
            inner.cancel()
            with contextlib.suppress(BaseException):
                await inner
            raise
        if inner in done:
            return inner.result()
        # Deadline fired before inner finished.
        inner.cancel()
        with contextlib.suppress(BaseException):
            await inner
        raise _MCPDeadlineExpired
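
The distinction drawn in the `_await_with_deadline` docstring can be seen in a reduced sketch. `DeadlineExpired` is a hypothetical stand-in for `_MCPDeadlineExpired`, and the outer-cancellation branch of the original is omitted for brevity.

```python
import asyncio
import contextlib


class DeadlineExpired(Exception):
    """Hypothetical stand-in for _MCPDeadlineExpired."""


async def await_with_deadline(coro, timeout_s):
    """Await ``coro``; raise DeadlineExpired only when our own deadline fires."""
    inner = asyncio.ensure_future(coro)
    done, _pending = await asyncio.wait({inner}, timeout=timeout_s)
    if inner in done:
        # Re-raises whatever the coroutine raised, including its own TimeoutError.
        return inner.result()
    inner.cancel()
    with contextlib.suppress(BaseException):
        await inner
    raise DeadlineExpired


async def raises_inner_timeout():
    raise asyncio.TimeoutError("inner")


async def demo():
    outcomes = []
    for make_coro in (raises_inner_timeout, lambda: asyncio.sleep(1)):
        try:
            await await_with_deadline(make_coro(), timeout_s=0.05)
        except asyncio.TimeoutError:
            outcomes.append("inner timeout")
        except DeadlineExpired:
            outcomes.append("deadline")
    return outcomes


outcomes = asyncio.run(demo())
```

With `asyncio.wait_for`, both cases would surface as the same `TimeoutError`; here the caller can tell them apart.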

    def _spawn_best_effort_cancel(self, task_id: str) -> None:
        """Fire-and-forget ``tasks/cancel`` so local cancellation propagates server-side."""
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return
        cancel_task = loop.create_task(self._try_cancel_task(task_id))
        # Reuse pending-reload bookkeeping so close-on-owner waits/cancels these too.
        self._pending_reload_tasks.add(cancel_task)
        cancel_task.add_done_callback(self._pending_reload_tasks.discard)
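
The bookkeeping in `_spawn_best_effort_cancel` follows the usual asyncio pattern for fire-and-forget tasks: hold a strong reference in a set and drop it from a done-callback. A minimal sketch of that pattern follows, where the module-level `pending` set and `spawn_tracked` helper are hypothetical names standing in for `_pending_reload_tasks` and the method itself.

```python
import asyncio

# Hypothetical stand-in for self._pending_reload_tasks.
pending = set()


def spawn_tracked(coro):
    """Schedule ``coro`` and keep a strong reference until it finishes."""
    task = asyncio.get_running_loop().create_task(coro)
    pending.add(task)
    # set.discard receives the finished task and removes it from the set.
    task.add_done_callback(pending.discard)
    return task


async def demo():
    task = spawn_tracked(asyncio.sleep(0))
    tracked_while_running = task in pending
    await task
    await asyncio.sleep(0)  # give done-callbacks a chance to run
    return tracked_while_running, len(pending)


state = asyncio.run(demo())
```

Keeping the reference prevents the event loop's weak reference from being the only thing holding the task alive, and lets an owner await or cancel whatever is still pending at shutdown.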

    async def _try_cancel_task(self, task_id: str) -> None:
        """Send ``tasks/cancel``; bounded by ``_MCP_TASK_CANCEL_TIMEOUT``.

        Failures log at warning so unattributed orphan tasks are debuggable.
        """
        from mcp import types

        request = types.ClientRequest(
            types.CancelTaskRequest(params=types.CancelTaskRequestParams(taskId=task_id))
        )
        try:
            await asyncio.wait_for(
                self.session.send_request(request, types.CancelTaskResult),  # type: ignore[union-attr]
                timeout=_MCP_TASK_CANCEL_TIMEOUT.total_seconds(),
            )
        except asyncio.CancelledError:
            raise
        except asyncio.TimeoutError:
            logger.warning(
                "Best-effort tasks/cancel for '%s' timed out after %.1fs; "
                "remote task may still be running.",
                task_id,
                _MCP_TASK_CANCEL_TIMEOUT.total_seconds(),
            )
        except Exception:
            logger.warning(
                "Best-effort tasks/cancel for '%s' failed; remote task may still be running.",
                task_id,
                exc_info=True,
            )

    @staticmethod
    def _is_connection_lost(ex: BaseException) -> bool:
        """Return True if *ex* indicates the MCP transport was torn down."""
        from anyio import ClosedResourceError
        from mcp.shared.exceptions import McpError

        if isinstance(ex, ClosedResourceError):
            return True
        if isinstance(ex, McpError):
            return "session terminated" in ex.error.message.lower()
        return False

    async def get_prompt(self, prompt_name: str, **kwargs: Any) -> str:
        """Call a prompt with the given arguments.

@@ -1554,6 +2100,7 @@ class MCPStdioTool(MCPTool):
        encoding: str | None = None,
        client: SupportsChatGetResponse | None = None,
        additional_properties: dict[str, Any] | None = None,
        task_options: MCPTaskOptions | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the MCP stdio tool.
@@ -1598,6 +2145,8 @@ class MCPStdioTool(MCPTool):
            env: The environment variables to set for the command.
            encoding: The encoding to use for the command output.
            client: The chat client to use for sampling.
            task_options: Options for tools that advertise
                ``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`.
            kwargs: Any extra arguments to pass to the stdio client.
        """
        super().__init__(
@@ -1614,6 +2163,7 @@ class MCPStdioTool(MCPTool):
            load_prompts=load_prompts,
            parse_prompt_results=parse_prompt_results,
            request_timeout=request_timeout,
            task_options=task_options,
        )
        self.command = command
        self.args = args or []
@@ -1687,6 +2237,7 @@ class MCPStreamableHTTPTool(MCPTool):
        additional_properties: dict[str, Any] | None = None,
        http_client: AsyncClient | None = None,
        header_provider: Callable[[dict[str, Any]], dict[str, str]] | None = None,
        task_options: MCPTaskOptions | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the MCP streamable HTTP tool.
@@ -1739,6 +2290,8 @@ class MCPStreamableHTTPTool(MCPTool):
                of HTTP headers to inject into every outbound request to the MCP server.
                Use this to forward per-request context (e.g. authentication tokens set in
                agent middleware) without creating a separate ``httpx.AsyncClient``.
            task_options: Options for tools that advertise
                ``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`.
            kwargs: Additional keyword arguments (accepted for backward compatibility but not used).
        """
        super().__init__(
@@ -1755,6 +2308,7 @@ class MCPStreamableHTTPTool(MCPTool):
            load_prompts=load_prompts,
            parse_prompt_results=parse_prompt_results,
            request_timeout=request_timeout,
            task_options=task_options,
        )
        self.url = url
        self.terminate_on_close = terminate_on_close
@@ -1862,6 +2416,7 @@ class MCPWebsocketTool(MCPTool):
        allowed_tools: Collection[str] | None = None,
        client: SupportsChatGetResponse | None = None,
        additional_properties: dict[str, Any] | None = None,
        task_options: MCPTaskOptions | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the MCP WebSocket tool.
@@ -1904,6 +2459,8 @@ class MCPWebsocketTool(MCPTool):
            allowed_tools: A list of tools that are allowed to use this tool.
            additional_properties: Additional properties.
            client: The chat client to use for sampling.
            task_options: Options for tools that advertise
                ``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`.
            kwargs: Any extra arguments to pass to the WebSocket client.
        """
        super().__init__(
@@ -1920,6 +2477,7 @@ class MCPWebsocketTool(MCPTool):
            load_prompts=load_prompts,
            parse_prompt_results=parse_prompt_results,
            request_timeout=request_timeout,
            task_options=task_options,
        )
        self.url = url
        self._client_kwargs = kwargs