Python: MCP long-running task support in Python (#6319)

* MCP long-running task support in Python

* Fix pyupgrade and AGENTS.md reconnect description

- pyupgrade: drop forward-reference string annotations in _mcp.py (Python 3.10+ resolves them natively now that MCPTaskOptions is defined before use).

- AGENTS.md: align reconnect description with current behavior. Phase 1 (initial tools/call) does NOT retry on connection loss; raises 'connection lost; task state unknown' instead, so a server that accepted the request but lost the response cannot start the operation twice. Phase 2 (tasks/get / tasks/result) still reconnects once against the same task_id.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix bandit nosec marker for CI pipeline

* Address PR feedback

* Clarified comments and addressed more PR feedback.

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Peter Ibekwe
2026-06-04 17:04:55 -07:00
committed by GitHub
Unverified
parent 01fc518b29
commit bf4ad48cf2
7 changed files with 1894 additions and 44 deletions
+13
@@ -76,6 +76,19 @@ agent_framework/
- **`SkillScriptRunner`** - Protocol for file-based script execution. Any callable matching `(skill, script, args) -> Any` satisfies it. Code-defined scripts do not use a runner.
- **`SkillsProvider`** - Context provider (extends `ContextProvider`) that discovers file-based skills from `SKILL.md` files and/or accepts code-defined `Skill` instances. Follows progressive disclosure: advertise → load → read resources / run scripts.
### Model Context Protocol (`_mcp.py`)
- **`MCPTool`** - Base wrapper that owns the MCP `ClientSession` and exposes the remote server's tools as `FunctionTool`s.
- **`MCPStdioTool`** / **`MCPStreamableHTTPTool`** / **`MCPWebsocketTool`** - Transport-specific subclasses.
- **`MCPTaskOptions`** (experimental, `MCP_LONG_RUNNING_TASKS` feature, **frozen**) - Per-tool-instance options controlling the SEP-2663 long-running task lifecycle. When the server advertises a tool with `execution.taskSupport == "required"`, `MCPTool.call_tool` transparently routes through `call_tool_as_task`, which sends an augmented `tools/call`, polls `tasks/get` until terminal, and reinterprets `tasks/result` as a normal `CallToolResult`. Instances are immutable; replace via `MCPTool.task_options = MCPTaskOptions(...)`. Fields:
  - `default_ttl: timedelta | None`: forwarded to the server as `params.task.ttl` (milliseconds). When `None`, the server's default applies.
  - `cancel_remote_task_on_local_cancellation: bool = True`: only gates the `CancelledError` path. Abandonment paths (see below) always cancel.
  - `max_task_wait: timedelta | None`: client-side deadline for the whole post-create lifecycle (poll + result fetch). When exceeded, raises `ToolExecutionException` and fires a best-effort `tasks/cancel`. `None` (default) means no client-side bound. Bounds sleeps, sends, and reconnects via the internal `_await_with_deadline` helper (built on `asyncio.wait` rather than `asyncio.wait_for`, so an inner `TimeoutError` is not mistaken for the deadline).
- **Permissive fallback**: servers that ignore the augmentation (return `CallToolResult` directly) or reject the unknown `task` field with `METHOD_NOT_FOUND` / `INVALID_PARAMS` fall back to the plain `session.call_tool(...)` path so legacy servers keep working. An unparseable success response (server accepted the augmented call but returned a payload that is neither `CreateTaskResult` nor `CallToolResult`) **does not** fall back — it raises `ToolExecutionException` to avoid double-executing a side-effecting tool.
- **Submit-vs-track reconnect policy**: a dropped connection before a `task_id` is known raises `ToolExecutionException("connection lost; task state unknown")` without re-issuing the augmented `tools/call`, so a server that accepted the request but lost the response cannot be made to start the same operation twice; once a `task_id` exists, `tasks/get` / `tasks/result` reconnect once and retry against the same id (a shared `_send_with_one_reconnect` helper).
- **Cancel-on-abandonment vs terminal failure**: any path where the remote task may still be running (max-wait exceeded, hard `McpError` in poll, malformed `tasks/get`, second connection loss in poll/fetch, reconnect failure) fires best-effort `tasks/cancel` before raising. Terminal failures (`failed`/`cancelled`/`input_required` server-side, `completed+isError`, malformed `tasks/result` after server completed) do **not** cancel — the server is already done. `_MCPTaskAbandoned` is the private marker distinguishing the two.
- **Transient poll retry**: a slow `tasks/get` that surfaces as `McpError(code=408 REQUEST_TIMEOUT)` is retried (bounded by `max_task_wait`). All other non-connection `McpError`s during poll are treated as abandonment. `tasks/result` does not get transient retry — the server has already completed, so a slow payload fetch is anomalous.
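
The submit-vs-track policy above can be illustrated with a small standalone sketch. It does not use the real `MCPTool` or MCP SDK: `FlakySession`, `ConnectionLost`, `submit`, and `track` are hypothetical names invented for this example, and the fake session only simulates a scripted number of dropped connections per method.

```python
import asyncio


class ConnectionLost(Exception):
    """Hypothetical stand-in for a transport-level disconnect."""


class FlakySession:
    """Hypothetical fake session that drops a scripted number of sends per method."""

    def __init__(self, failures: dict[str, int]) -> None:
        self.failures = dict(failures)
        self.calls: list[str] = []
        self.reconnects = 0

    async def send(self, method: str) -> str:
        self.calls.append(method)
        if self.failures.get(method, 0) > 0:
            self.failures[method] -= 1
            raise ConnectionLost(method)
        return f"{method}: ok"

    async def reconnect(self) -> None:
        self.reconnects += 1


async def submit(session: FlakySession) -> str:
    # Submit phase: never retry, because the server may already have created the task.
    try:
        return await session.send("tools/call")
    except ConnectionLost as ex:
        raise RuntimeError("connection lost; task state unknown") from ex


async def track(session: FlakySession, method: str) -> str:
    # Track phase: the task id is known, so one reconnect-and-retry is safe.
    for attempt in range(2):
        try:
            return await session.send(method)
        except ConnectionLost:
            if attempt == 0:
                await session.reconnect()
                continue
            raise
    raise AssertionError("unreachable")


async def demo() -> tuple[list[str], int, str]:
    lost_submit = FlakySession({"tools/call": 1})
    try:
        await submit(lost_submit)
        outcome = "submitted"
    except RuntimeError as ex:
        outcome = str(ex)
    lost_poll = FlakySession({"tasks/get": 1})
    await track(lost_poll, "tasks/get")
    return lost_submit.calls + lost_poll.calls, lost_poll.reconnects, outcome


calls, reconnects, outcome = asyncio.run(demo())
```

In this sketch the dropped `tools/call` is sent exactly once and surfaces as an error, while the dropped `tasks/get` is sent twice with a single reconnect in between. The real helper additionally converts a second loss into an abandonment error that triggers a best-effort `tasks/cancel`.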
### File Access Harness (`_harness/_file_access.py`)
- **`AgentFileStore`** - Abstract async store backing the file-access harness. Implementations expose `write_file`, `read_file`, `delete_file`, `list_files`, `file_exists`, `search_files`, and `create_directory` over forward-slash relative paths.
@@ -124,7 +124,7 @@ from ._harness._todo import (
     TodoSessionStore,
     TodoStore,
 )
-from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPWebsocketTool
+from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPTaskOptions, MCPWebsocketTool
 from ._middleware import (
     AgentContext,
     AgentMiddleware,
@@ -444,12 +444,13 @@ __all__ = [
     "InlineSkillResource",
     "InlineSkillScript",
     "LocalEvaluator",
-    "MCPStdioTool",
-    "MCPStreamableHTTPTool",
-    "MCPWebsocketTool",
     "MCPSkill",
     "MCPSkillResource",
     "MCPSkillsSource",
+    "MCPStdioTool",
+    "MCPStreamableHTTPTool",
+    "MCPTaskOptions",
+    "MCPWebsocketTool",
     "MemoryContextProvider",
     "MemoryFileStore",
     "MemoryIndexEntry",
@@ -58,6 +58,7 @@ class ExperimentalFeature(str, Enum):
    FOUNDRY_PREVIEW_TOOLS = "FOUNDRY_PREVIEW_TOOLS"
    FUNCTIONAL_WORKFLOWS = "FUNCTIONAL_WORKFLOWS"
    HARNESS = "HARNESS"
    MCP_LONG_RUNNING_TASKS = "MCP_LONG_RUNNING_TASKS"
    MCP_SKILLS = "MCP_SKILLS"
    PROGRESSIVE_TOOLS = "PROGRESSIVE_TOOLS"
    SKILLS = "SKILLS"
+598 -40
@@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
import base64
import contextlib
import contextvars
import json
import logging
@@ -12,12 +13,14 @@ import sys
from abc import abstractmethod
from collections.abc import Callable, Collection, Coroutine, Mapping, Sequence
from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore
from dataclasses import dataclass
from datetime import timedelta
from functools import partial
from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast
from opentelemetry import propagate
from ._feature_stage import ExperimentalFeature, experimental
from ._tools import FunctionTool
from ._types import (
ChatOptions,
@@ -149,6 +152,73 @@ def _url_origin(url: Any) -> tuple[str, str, int | None]:
    return (url.scheme, url.host or "", port)


# Internal polling bounds for MCP long-running tasks. Not user-tunable today;
# promote to MCPTaskOptions if a concrete need arises.
_MCP_TASK_MIN_POLL_INTERVAL = timedelta(milliseconds=500)
_MCP_TASK_MAX_POLL_INTERVAL = timedelta(seconds=5)
_MCP_TASK_CANCEL_TIMEOUT = timedelta(seconds=5)
_MCP_TASK_TERMINAL_STATUSES: frozenset[str] = frozenset({"completed", "failed", "cancelled", "input_required"})

# Total send attempts for a Phase 2 request (initial try + one reconnect-and-retry).
# A single transient disconnect should not abort a long-running task; sustained outages
# surface as ``_MCPTaskAbandoned`` after the second failure.
_MCP_RECONNECT_ATTEMPTS = 2


class _MCPTaskAbandoned(ToolExecutionException):
    """Raised when the remote MCP task may still be running and must be cancelled.

    Subclass of ToolExecutionException so callers see a normal tool failure.
    """


class _MCPDeadlineExpired(Exception):
    """Internal marker for ``max_task_wait`` expiry; distinct from inner TimeoutError."""


@experimental(feature_id=ExperimentalFeature.MCP_LONG_RUNNING_TASKS)
@dataclass(frozen=True)
class MCPTaskOptions:
    """Options controlling how MCPTool drives the MCP long-running task lifecycle.

    When an MCP server advertises a tool with ``execution.taskSupport == "required"``,
    the framework transparently drives the SEP-2663 ``tools/call`` → ``tasks/get``
    (polled) → ``tasks/result`` lifecycle so the agent sees a normal tool result.

    Instances are immutable; replace the whole object via
    ``MCPTool.task_options = MCPTaskOptions(...)`` to change behavior.

    Attributes:
        default_ttl: Optional task-record retention time forwarded to the server as
            ``params.task.ttl`` (milliseconds, integer). The server keeps the task
            record around this long after the task reaches a terminal status so the
            client can still call ``tasks/get`` / ``tasks/result``; it does not
            cancel a running task. When ``None``, the server applies its own default.
            Must be positive if set (zero would expire the record before any client
            could read it).
        cancel_remote_task_on_local_cancellation: If True (default), a local
            cancellation of the awaiting coroutine triggers a best-effort
            ``tasks/cancel`` on the server before re-raising ``CancelledError``.
            Only gates ``CancelledError``; abandonment paths (max-wait,
            unrecoverable poll errors, lost connection after task_id is known)
            always cancel regardless of this flag.
        max_task_wait: Optional client-side deadline for the whole post-create
            lifecycle (poll + result fetch). When exceeded, raises
            ``ToolExecutionException`` and fires a best-effort ``tasks/cancel``.
            ``None`` (default) means no client-side bound. Must be positive if set.
    """

    default_ttl: timedelta | None = None
    cancel_remote_task_on_local_cancellation: bool = True
    max_task_wait: timedelta | None = None

    def __post_init__(self) -> None:
        if self.default_ttl is not None and self.default_ttl.total_seconds() <= 0:
            raise ValueError("MCPTaskOptions.default_ttl must be positive.")
        if self.max_task_wait is not None and self.max_task_wait.total_seconds() <= 0:
            raise ValueError("MCPTaskOptions.max_task_wait must be positive.")


def streamable_http_client(*args: Any, **kwargs: Any) -> _AsyncGeneratorContextManager[Any, None]:
    """Lazily import the MCP streamable HTTP transport."""
    try:
@@ -217,6 +287,7 @@ class MCPTool:
        request_timeout: int | None = None,
        client: SupportsChatGetResponse | None = None,
        additional_properties: dict[str, Any] | None = None,
        task_options: MCPTaskOptions | None = None,
    ) -> None:
        """Initialize the MCP Tool base.
@@ -248,6 +319,9 @@ class MCPTool:
            request_timeout: Timeout in seconds for MCP requests.
            client: A chat client for sampling callbacks.
            additional_properties: Additional properties for the tool.
            task_options: Options controlling how long-running MCP tasks are driven for
                tools that advertise ``execution.taskSupport == "required"``. When ``None``,
                the defaults from :class:`MCPTaskOptions` are used.
        """
        self.name = name
        self.description = description or ""
@@ -259,6 +333,10 @@ class MCPTool:
        self.parse_tool_results = parse_tool_results
        self.load_prompts_flag = load_prompts
        self.parse_prompt_results = parse_prompt_results
        # Defer constructing the default MCPTaskOptions so the experimental warning
        # only fires when LRO is actually engaged (lazy-resolved by _effective_task_options).
        self._task_options_explicit: MCPTaskOptions | None = task_options
        self._task_options_default: MCPTaskOptions | None = None
        self._exit_stack = AsyncExitStack()
        self._lifecycle_lock = asyncio.Lock()
        self._lifecycle_request_lock = asyncio.Lock()
@@ -270,6 +348,7 @@ class MCPTool:
        self.client = client
        self._functions: list[FunctionTool] = []
        self._tool_call_meta_by_name: dict[str, dict[str, Any]] = {}
        self._tool_task_support_by_name: dict[str, str] = {}
        self.is_connected: bool = False
        self._tools_loaded: bool = False
        self._prompts_loaded: bool = False
@@ -1131,6 +1210,7 @@ class MCPTool:
# Track existing function names to prevent duplicates
existing_names = {func.name for func in self._functions}
tool_call_meta_by_name: dict[str, dict[str, Any]] = {}
tool_task_support_by_name: dict[str, str] = {}
params: types.PaginatedRequestParams | None = None
while True:
@@ -1168,6 +1248,10 @@ class MCPTool:
if tool.meta is not None:
tool_call_meta_by_name[tool.name] = dict(tool.meta)
task_support = getattr(getattr(tool, "execution", None), "taskSupport", None)
if task_support is not None:
tool_task_support_by_name[tool.name] = task_support
normalized_name = _normalize_mcp_name(tool.name)
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
@@ -1216,6 +1300,7 @@ class MCPTool:
params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)
self._tool_call_meta_by_name = tool_call_meta_by_name
self._tool_task_support_by_name = tool_task_support_by_name
async def _close_on_owner(self) -> None:
# Cancel any pending reload tasks before tearing down the session.
@@ -1292,6 +1377,29 @@ class MCPTool:
inner_exception=ex,
) from ex
    def _effective_task_options(self) -> MCPTaskOptions:
        """Return the effective MCPTaskOptions, lazily constructing defaults on first use.

        Defers the implicit ``MCPTaskOptions()`` so the experimental warning only
        fires when LRO is actually engaged (server advertises ``taskSupport=required``).
        """
        explicit = self._task_options_explicit
        if explicit is not None:
            return explicit
        if self._task_options_default is None:
            self._task_options_default = MCPTaskOptions()
        return self._task_options_default

    @property
    def task_options(self) -> MCPTaskOptions:
        """The effective MCPTaskOptions for this tool (lazy defaults)."""
        return self._effective_task_options()

    @task_options.setter
    def task_options(self, value: MCPTaskOptions | None) -> None:
        self._task_options_explicit = value
        self._task_options_default = None

    async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]:
        """Call a tool with the given arguments.
@@ -1322,47 +1430,12 @@ class MCPTool:
                 "Tools are not loaded for this server, please set load_tools=True in the constructor."
             )
-        raw_user_meta: object | None = kwargs.get("_meta")
-        user_meta: dict[str, Any] | None = None
-        if raw_user_meta is not None and not isinstance(raw_user_meta, dict):
-            raise ToolExecutionException("MCP tool metadata provided via _meta must be a dict.")
-        if isinstance(raw_user_meta, dict):
-            raw_user_meta_dict = cast(Mapping[object, object], raw_user_meta)
-            user_meta = {}
-            for key, value in raw_user_meta_dict.items():
-                if not isinstance(key, str):
-                    raise ToolExecutionException("MCP tool metadata provided via _meta must use string keys.")
-                user_meta[key] = value
+        # Tools advertising taskSupport == "required" cannot complete via plain tools/call;
+        # route through the long-running task lifecycle transparently.
+        if self._tool_task_support_by_name.get(tool_name) == "required":
+            return await self.call_tool_as_task(tool_name, **kwargs)
-        # Filter out framework kwargs that cannot be serialized by the MCP SDK.
-        # These are internal objects passed through the function invocation pipeline
-        # that should not be forwarded to external MCP servers.
-        # conversation_id is an internal tracking ID used by services like Azure AI.
-        # options contains metadata/store used by AG-UI for Azure AI client requirements.
-        # response_format is a Pydantic model class used for structured output (not serializable).
-        filtered_kwargs = {
-            k: v
-            for k, v in kwargs.items()
-            if k
-            not in {
-                "chat_options",
-                "tools",
-                "tool_choice",
-                "session",
-                "thread",
-                "conversation_id",
-                "options",
-                "response_format",
-                "_meta",
-            }
-        }
-        # Some MCP proxies require their tools/list metadata to be echoed on tools/call.
-        tool_meta = self._tool_call_meta_by_name.get(tool_name)
-        request_meta = dict(tool_meta) if tool_meta is not None else None
-        if user_meta is not None:
-            request_meta = {**(request_meta or {}), **user_meta}
-        meta = _inject_otel_into_mcp_meta(request_meta)
+        filtered_kwargs, meta = self._prepare_call_kwargs(tool_name, kwargs)
         parser = self.parse_tool_results or self._parse_tool_result_from_mcp
         # Try the operation, reconnecting once if the connection is closed
@@ -1411,6 +1484,479 @@ class MCPTool:
raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex
raise ToolExecutionException(f"Failed to call tool '{tool_name}' after retries.")
    def _prepare_call_kwargs(
        self, tool_name: str, kwargs: dict[str, Any]
    ) -> tuple[dict[str, Any], dict[str, Any] | None]:
        """Filter framework-only kwargs and build the merged MCP request metadata."""
        raw_user_meta: object | None = kwargs.get("_meta")
        user_meta: dict[str, Any] | None = None
        if raw_user_meta is not None and not isinstance(raw_user_meta, dict):
            raise ToolExecutionException("MCP tool metadata provided via _meta must be a dict.")
        if isinstance(raw_user_meta, dict):
            raw_user_meta_dict = cast(Mapping[object, object], raw_user_meta)
            user_meta = {}
            for key, value in raw_user_meta_dict.items():
                if not isinstance(key, str):
                    raise ToolExecutionException("MCP tool metadata provided via _meta must use string keys.")
                user_meta[key] = value

        # Filter out framework kwargs that cannot be serialized by the MCP SDK.
        # These are internal objects passed through the function invocation pipeline
        # that should not be forwarded to external MCP servers.
        # conversation_id is an internal tracking ID used by services like Azure AI.
        # options contains metadata/store used by AG-UI for Azure AI client requirements.
        # response_format is a Pydantic model class used for structured output (not serializable).
        filtered_kwargs = {
            k: v
            for k, v in kwargs.items()
            if k
            not in {
                "chat_options",
                "tools",
                "tool_choice",
                "session",
                "thread",
                "conversation_id",
                "options",
                "response_format",
                "_meta",
            }
        }

        # Some MCP proxies require their tools/list metadata to be echoed on tools/call.
        tool_meta = self._tool_call_meta_by_name.get(tool_name)
        request_meta = dict(tool_meta) if tool_meta is not None else None
        if user_meta is not None:
            request_meta = {**(request_meta or {}), **user_meta}
        meta = _inject_otel_into_mcp_meta(request_meta)
        return filtered_kwargs, meta

    async def call_tool_as_task(self, tool_name: str, **kwargs: Any) -> str | list[Content]:
        """Call an MCP tool via the long-running task lifecycle (SEP-2663).

        Issues an augmented ``tools/call`` with ``params.task`` set from
        ``self.task_options``, then polls ``tasks/get`` until the server reports a
        terminal status. On ``completed`` the payload is fetched via ``tasks/result``,
        validated as a ``CallToolResult`` and parsed identically to :meth:`call_tool`.

        Local cancellation triggers a best-effort ``tasks/cancel`` (controlled by
        :attr:`MCPTaskOptions.cancel_remote_task_on_local_cancellation`) before
        ``asyncio.CancelledError`` is re-raised.

        Args:
            tool_name: The remote MCP tool name.

        Keyword Args:
            kwargs: Arguments forwarded to the tool. See :meth:`call_tool` for the
                framework kwargs that are filtered out.

        Returns:
            A list of Content items (or a string when a custom ``parse_tool_results``
            callback is configured).
        """
        from anyio import ClosedResourceError
        from mcp.shared.exceptions import McpError

        if not self.load_tools_flag:
            raise ToolExecutionException(
                "Tools are not loaded for this server, please set load_tools=True in the constructor."
            )

        filtered_kwargs, meta = self._prepare_call_kwargs(tool_name, kwargs)
        parser = self.parse_tool_results or self._parse_tool_result_from_mcp

        # Submit the task: issue augmented tools/call. Do NOT retry on connection loss here:
        # the server may have accepted the request and created a task before the
        # response was lost, so retrying could start the long-running operation twice.
        # Reconnect-and-retry is only safe after the task_id is known.
        try:
            task_id, fallback_result = await self._call_tool_as_task_create(tool_name, filtered_kwargs, meta)
        except (ClosedResourceError, McpError) as ex:
            if not self._is_connection_lost(ex):
                error_message = ex.error.message if isinstance(ex, McpError) else str(ex)
                raise ToolExecutionException(error_message, inner_exception=ex) from ex
            raise ToolExecutionException(
                f"Failed to call tool '{tool_name}' - connection lost; task state unknown.",
                inner_exception=ex,
            ) from ex
        except ToolExecutionException:
            raise
        except Exception as ex:
            raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex

        # Server returned a CallToolResult (no task created) or fell back to plain tools/call.
        if fallback_result is not None:
            if fallback_result.isError:
                parsed = parser(fallback_result)
                text = (
                    "\n".join(c.text for c in parsed if c.type == "text" and c.text)
                    if isinstance(parsed, list)
                    else str(parsed)
                )
                raise ToolExecutionException(text or str(parsed))
            return parser(fallback_result)

        if task_id is None:
            raise ToolExecutionException(
                f"MCP server did not return a task_id or fallback result for '{tool_name}'."
            )

        # Track to completion: poll until terminal, then fetch payload. Never re-issue
        # tools/call past this point; reconnect-and-retry only against the same task_id.
        opts = self._effective_task_options()
        max_wait_s = opts.max_task_wait.total_seconds() if opts.max_task_wait is not None else None

        async def _await_task_completion() -> str | list[Content]:
            terminal = await self._poll_task_until_terminal(task_id)
            return await self._handle_terminal_task(tool_name, task_id, terminal, parser)

        try:
            if max_wait_s is not None:
                try:
                    result = await self._await_with_deadline(_await_task_completion(), max_wait_s)
                    return cast("str | list[Content]", result)
                except _MCPDeadlineExpired as ex:
                    self._spawn_best_effort_cancel(task_id)
                    raise ToolExecutionException(
                        f"MCP task '{task_id}' exceeded max_task_wait of {max_wait_s}s.",
                        inner_exception=ex,
                    ) from ex
            else:
                return await _await_task_completion()
        except asyncio.CancelledError:
            if opts.cancel_remote_task_on_local_cancellation:
                self._spawn_best_effort_cancel(task_id)
            raise
        except _MCPTaskAbandoned:
            # Pre-terminal abandonment (hard poll error, malformed get, second
            # disconnect, reconnect failure): cancel + re-raise as plain
            # ToolExecutionException to the function-calling loop.
            self._spawn_best_effort_cancel(task_id)
            raise
        # Plain ToolExecutionException from terminal failures (failed/cancelled/
        # input_required, completed+isError, malformed result post-completion)
        # propagates without cancel; the server is already done.

    async def _call_tool_as_task_create(
        self, tool_name: str, arguments: dict[str, Any], meta: dict[str, Any] | None
    ) -> tuple[str | None, types.CallToolResult | None]:
        """Send the augmented tools/call.

        Returns ``(task_id, None)`` when the server created a task,
        ``(None, CallToolResult)`` when it returned a non-task result, falling back
        to plain ``tools/call`` if the server rejects the ``task`` field outright.
        """
        from mcp import types
        from mcp.shared.exceptions import McpError
        from pydantic import ValidationError

        opts = self._effective_task_options()
        ttl_ms: int | None = None
        if opts.default_ttl is not None:
            ttl_ms = int(opts.default_ttl.total_seconds() * 1000)

        # Always send TaskMetadata to mark the call as task-augmented; ttl may be omitted.
        task_metadata = types.TaskMetadata(ttl=ttl_ms)
        request_meta = types.RequestParams.Meta(**meta) if meta else None
        params = types.CallToolRequestParams(
            name=tool_name,
            arguments=arguments,
            task=task_metadata,
            _meta=request_meta,  # type: ignore[call-arg]
        )
        request = types.ClientRequest(types.CallToolRequest(params=params))

        # Use the lenient Result type so we can extract the task_id even when
        # the strict CreateTaskResult schema rejects the payload (the MCP Python
        # SDK requires Task.ttl, but servers may legitimately omit it).
        try:
            lenient = await self.session.send_request(  # type: ignore[union-attr]
                request,
                types.Result,
            )
        except McpError as ex:
            if ex.error.code not in (types.METHOD_NOT_FOUND, types.INVALID_PARAMS):
                raise
            logger.debug(
                "Server rejected augmented tools/call for '%s' (code=%s); falling back.",
                tool_name,
                ex.error.code,
            )
            fallback = await self.session.call_tool(tool_name, arguments=arguments, meta=meta)  # type: ignore[union-attr]
            return None, fallback

        # Inspect the raw payload: a CreateTaskResult carries `task.taskId`;
        # a legacy CallToolResult carries `content` and/or `isError`.
        raw: dict[str, Any] = lenient.model_dump(by_alias=True, exclude_none=True)
        raw.pop("_meta", None)
        task_field = raw.get("task")
        if isinstance(task_field, dict):
            task_id_val = cast(dict[str, Any], task_field).get("taskId")
            if isinstance(task_id_val, str):
                return task_id_val, None

        try:
            legacy = types.CallToolResult.model_validate(raw)
        except ValidationError as ex:
            # Augmented call succeeded server-side; re-issuing a plain tools/call
            # could double-execute a side-effecting tool.
            raise ToolExecutionException(
                f"MCP server returned an unparseable response to augmented tools/call "
                f"for '{tool_name}'; cannot safely retry (server may have started the operation).",
                inner_exception=ex,
            ) from ex
        return None, legacy

    async def _poll_task_until_terminal(self, task_id: str) -> types.GetTaskResult:
        """Poll ``tasks/get`` until the task reaches a terminal status."""
        import httpx
        from mcp import types
        from mcp.shared.exceptions import McpError

        # SDK raises McpError(code=httpx.REQUEST_TIMEOUT=408) on session read timeout.
        transient_codes: frozenset[int] = frozenset({int(httpx.codes.REQUEST_TIMEOUT)})

        while True:
            request = types.ClientRequest(
                types.GetTaskRequest(params=types.GetTaskRequestParams(taskId=task_id))
            )
            try:
                # GetTaskResult.ttl is required-but-Optional in the SDK; coerce below.
                lenient = await self._send_with_one_reconnect(
                    request, types.Result, operation="tasks/get", task_id=task_id
                )
            except McpError as ex:
                if ex.error.code in transient_codes:
                    logger.debug(
                        "Transient %s on tasks/get for '%s'; will retry.", ex.error.code, task_id
                    )
                    await asyncio.sleep(_MCP_TASK_MIN_POLL_INTERVAL.total_seconds())
                    continue
                # Hard server error mid-poll: task may still be running.
                raise _MCPTaskAbandoned(ex.error.message, inner_exception=ex) from ex

            try:
                snapshot = self._coerce_get_task_result(lenient, task_id)
            except ToolExecutionException as ex:
                # Malformed tasks/get response; task may still be running.
                raise _MCPTaskAbandoned(str(ex), inner_exception=ex) from ex

            if snapshot.status in _MCP_TASK_TERMINAL_STATUSES:
                return snapshot
            await asyncio.sleep(self._compute_poll_delay(snapshot.pollInterval).total_seconds())

    @staticmethod
    def _coerce_get_task_result(lenient: types.Result, task_id: str) -> types.GetTaskResult:
        """Coerce a lenient Result into GetTaskResult, defaulting ``ttl`` when absent."""
        from mcp import types

        raw = lenient.model_dump(by_alias=True, exclude_none=True)
        raw.pop("_meta", None)
        raw.setdefault("ttl", None)
        try:
            return types.GetTaskResult.model_validate(raw)
        except Exception as ex:
            raise ToolExecutionException(
                f"MCP server returned a malformed tasks/get response for task '{task_id}'.",
                inner_exception=ex,
            ) from ex

    @staticmethod
    def _compute_poll_delay(server_interval_ms: int | None) -> timedelta:
        """Clamp the server-suggested poll interval to ``[min, max]``."""
        if server_interval_ms is None or server_interval_ms <= 0:
            return _MCP_TASK_MIN_POLL_INTERVAL
        suggested = timedelta(milliseconds=server_interval_ms)
        if suggested < _MCP_TASK_MIN_POLL_INTERVAL:
            return _MCP_TASK_MIN_POLL_INTERVAL
        if suggested > _MCP_TASK_MAX_POLL_INTERVAL:
            return _MCP_TASK_MAX_POLL_INTERVAL
        return suggested

    async def _handle_terminal_task(
        self,
        tool_name: str,
        task_id: str,
        snapshot: types.GetTaskResult,
        parser: Callable[[types.CallToolResult], str | list[Content]],
    ) -> str | list[Content]:
        """Map a terminal task snapshot to either a parsed result or an exception."""
        status = snapshot.status
        if status == "completed":
            payload = await self._fetch_task_result(task_id)
            if payload.isError:
                parsed = parser(payload)
                text = (
                    "\n".join(c.text for c in parsed if c.type == "text" and c.text)
                    if isinstance(parsed, list)
                    else str(parsed)
                )
                raise ToolExecutionException(text or str(parsed))
            return parser(payload)

        # Non-completed terminal statuses surface as ToolExecutionException so the
        # function-calling loop sees a normal failure for tool_name.
        message = snapshot.statusMessage or f"MCP task ended with status '{status}'."
        if status == "input_required":
            # Spec-non-terminal; treated as terminal here because the framework does
            # not implement the interactive input flow.
            message = snapshot.statusMessage or "MCP task requires additional input and cannot continue."
        raise ToolExecutionException(f"Tool '{tool_name}' task {status}: {message}")

    async def _fetch_task_result(self, task_id: str) -> types.CallToolResult:
        """Send ``tasks/result`` and reinterpret the open-typed payload as a CallToolResult."""
        from mcp import types
        from mcp.shared.exceptions import McpError
        from pydantic import ValidationError

        request = types.ClientRequest(
            types.GetTaskPayloadRequest(params=types.GetTaskPayloadRequestParams(taskId=task_id))
        )
        # Connection-loss retry only via the helper; no transient-code retry because the
        # server has already completed the task, so a slow payload fetch is anomalous.
        try:
            payload = await self._send_with_one_reconnect(
                request, types.GetTaskPayloadResult, operation="tasks/result", task_id=task_id
            )
        except McpError as ex:
            # Server reported completed; a hard fetch error is a plain failure (no cancel).
            raise ToolExecutionException(ex.error.message, inner_exception=ex) from ex

        # GetTaskPayloadResult carries the tool result via extra fields; reinterpret as CallToolResult.
        payload_dict = payload.model_dump(by_alias=True, exclude_none=True)
        payload_dict.pop("_meta", None)
        try:
            return types.CallToolResult.model_validate(payload_dict)
        except ValidationError as ex:
            # Server reported completed; malformed payload is a plain failure (no cancel needed).
            raise ToolExecutionException(
                f"MCP task '{task_id}' result payload could not be parsed as a CallToolResult.",
                inner_exception=ex,
            ) from ex

    async def _send_with_one_reconnect(
        self,
        request: types.ClientRequest,
        result_type: type[Any],
        *,
        operation: str,
        task_id: str,
    ) -> Any:
        """Send ``request`` with one reconnect-and-retry on connection loss.

        After a second loss (or reconnect failure), raise ``_MCPTaskAbandoned``.
        Non-connection errors propagate unchanged.
        """
        from anyio import ClosedResourceError
        from mcp.shared.exceptions import McpError

        for attempt in range(_MCP_RECONNECT_ATTEMPTS):
            try:
                return await self.session.send_request(request, result_type)  # type: ignore[union-attr]
            except (ClosedResourceError, McpError) as ex:
                if not self._is_connection_lost(ex):
                    raise
                if attempt < _MCP_RECONNECT_ATTEMPTS - 1:
                    logger.info(
                        "MCP connection lost during %s; reconnecting (task_id=%s).", operation, task_id
                    )
                    try:
                        await self.connect(reset=True)
                    except Exception as reconn_ex:
                        # Reconnect failure: task may still be running.
                        raise _MCPTaskAbandoned(
                            "Failed to reconnect to MCP server.", inner_exception=reconn_ex
                        ) from reconn_ex
                    continue
                # Final attempt also lost the connection: task may still be running.
                raise _MCPTaskAbandoned(
                    f"MCP connection lost; task state unknown (task_id={task_id}).",
                    inner_exception=ex,
                ) from ex
        raise AssertionError(f"unreachable: {operation} for {task_id}")  # pragma: no cover

    @staticmethod
    async def _await_with_deadline(coro: Coroutine[Any, Any, Any], timeout_s: float) -> Any:
        """Await ``coro`` with a deadline; raise ``_MCPDeadlineExpired`` only on deadline.

        Unlike ``asyncio.wait_for``, an ``asyncio.TimeoutError`` raised by ``coro``
        itself propagates unchanged so callers can distinguish their own deadline
        from a stray inner timeout.
        """
        inner = asyncio.ensure_future(coro)
        try:
            done, _pending = await asyncio.wait({inner}, timeout=timeout_s)
        except BaseException:
            # Outer caller cancelled (or another exception): cancel inner + drain.
            inner.cancel()
            with contextlib.suppress(BaseException):
                await inner
            raise
        if inner in done:
            return inner.result()
        # Deadline fired before inner finished.
        inner.cancel()
        with contextlib.suppress(BaseException):
            await inner
        raise _MCPDeadlineExpired

    def _spawn_best_effort_cancel(self, task_id: str) -> None:
        """Fire-and-forget ``tasks/cancel`` so local cancellation propagates server-side."""
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return
        cancel_task = loop.create_task(self._try_cancel_task(task_id))
        # Reuse pending-reload bookkeeping so close-on-owner waits/cancels these too.
        self._pending_reload_tasks.add(cancel_task)
        cancel_task.add_done_callback(self._pending_reload_tasks.discard)

    async def _try_cancel_task(self, task_id: str) -> None:
        """Send ``tasks/cancel``; bounded by ``_MCP_TASK_CANCEL_TIMEOUT``.

        Failures log at warning so unattributed orphan tasks are debuggable.
        """
        from mcp import types

        request = types.ClientRequest(
            types.CancelTaskRequest(params=types.CancelTaskRequestParams(taskId=task_id))
        )
        try:
            await asyncio.wait_for(
                self.session.send_request(request, types.CancelTaskResult),  # type: ignore[union-attr]
                timeout=_MCP_TASK_CANCEL_TIMEOUT.total_seconds(),
            )
        except asyncio.CancelledError:
            raise
        except asyncio.TimeoutError:
            logger.warning(
                "Best-effort tasks/cancel for '%s' timed out after %.1fs; "
                "remote task may still be running.",
                task_id,
                _MCP_TASK_CANCEL_TIMEOUT.total_seconds(),
            )
        except Exception:
            logger.warning(
                "Best-effort tasks/cancel for '%s' failed; remote task may still be running.",
                task_id,
                exc_info=True,
            )
    @staticmethod
    def _is_connection_lost(ex: BaseException) -> bool:
        """Return True if *ex* indicates the MCP transport was torn down."""
        from anyio import ClosedResourceError
        from mcp.shared.exceptions import McpError

        if isinstance(ex, ClosedResourceError):
            return True
        if isinstance(ex, McpError):
            return "session terminated" in ex.error.message.lower()
        return False

    async def get_prompt(self, prompt_name: str, **kwargs: Any) -> str:
        """Call a prompt with the given arguments.
@@ -1554,6 +2100,7 @@ class MCPStdioTool(MCPTool):
        encoding: str | None = None,
        client: SupportsChatGetResponse | None = None,
        additional_properties: dict[str, Any] | None = None,
        task_options: MCPTaskOptions | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the MCP stdio tool.
@@ -1598,6 +2145,8 @@ class MCPStdioTool(MCPTool):
            env: The environment variables to set for the command.
            encoding: The encoding to use for the command output.
            client: The chat client to use for sampling.
            task_options: Options for tools that advertise
                ``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`.
            kwargs: Any extra arguments to pass to the stdio client.
        """
        super().__init__(
@@ -1614,6 +2163,7 @@ class MCPStdioTool(MCPTool):
            load_prompts=load_prompts,
            parse_prompt_results=parse_prompt_results,
            request_timeout=request_timeout,
            task_options=task_options,
        )
        self.command = command
        self.args = args or []
@@ -1687,6 +2237,7 @@ class MCPStreamableHTTPTool(MCPTool):
        additional_properties: dict[str, Any] | None = None,
        http_client: AsyncClient | None = None,
        header_provider: Callable[[dict[str, Any]], dict[str, str]] | None = None,
        task_options: MCPTaskOptions | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the MCP streamable HTTP tool.
@@ -1739,6 +2290,8 @@ class MCPStreamableHTTPTool(MCPTool):
                of HTTP headers to inject into every outbound request to the MCP server.
                Use this to forward per-request context (e.g. authentication tokens set in
                agent middleware) without creating a separate ``httpx.AsyncClient``.
            task_options: Options for tools that advertise
                ``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`.
            kwargs: Additional keyword arguments (accepted for backward compatibility but not used).
        """
        super().__init__(
@@ -1755,6 +2308,7 @@ class MCPStreamableHTTPTool(MCPTool):
            load_prompts=load_prompts,
            parse_prompt_results=parse_prompt_results,
            request_timeout=request_timeout,
            task_options=task_options,
        )
        self.url = url
        self.terminate_on_close = terminate_on_close
@@ -1862,6 +2416,7 @@ class MCPWebsocketTool(MCPTool):
        allowed_tools: Collection[str] | None = None,
        client: SupportsChatGetResponse | None = None,
        additional_properties: dict[str, Any] | None = None,
        task_options: MCPTaskOptions | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the MCP WebSocket tool.
@@ -1904,6 +2459,8 @@ class MCPWebsocketTool(MCPTool):
            allowed_tools: A collection of tool names that are allowed to be used from this tool.
            additional_properties: Additional properties.
            client: The chat client to use for sampling.
            task_options: Options for tools that advertise
                ``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`.
            kwargs: Any extra arguments to pass to the WebSocket client.
        """
        super().__init__(
@@ -1920,6 +2477,7 @@ class MCPWebsocketTool(MCPTool):
            load_prompts=load_prompts,
            parse_prompt_results=parse_prompt_results,
            request_timeout=request_timeout,
            task_options=task_options,
        )
        self.url = url
        self._client_kwargs = kwargs