mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Address PR feedback
This commit is contained in:
@@ -80,11 +80,14 @@ agent_framework/
|
||||
|
||||
- **`MCPTool`** - Base wrapper that owns the MCP `ClientSession` and exposes the remote server's tools as `FunctionTool`s.
|
||||
- **`MCPStdioTool`** / **`MCPStreamableHTTPTool`** / **`MCPWebsocketTool`** - Transport-specific subclasses.
|
||||
- **`MCPTaskOptions`** (experimental, `MCP_LONG_RUNNING_TASKS` feature) - Per-tool-instance options controlling the SEP-2663 long-running task lifecycle. When the server advertises a tool with `execution.taskSupport == "required"`, `MCPTool.call_tool` transparently routes through `call_tool_as_task`, which sends an augmented `tools/call`, polls `tasks/get` until terminal, and reinterprets `tasks/result` as a normal `CallToolResult`. Fields:
|
||||
- **`MCPTaskOptions`** (experimental, `MCP_LONG_RUNNING_TASKS` feature, **frozen**) - Per-tool-instance options controlling the SEP-2663 long-running task lifecycle. When the server advertises a tool with `execution.taskSupport == "required"`, `MCPTool.call_tool` transparently routes through `call_tool_as_task`, which sends an augmented `tools/call`, polls `tasks/get` until terminal, and reinterprets `tasks/result` as a normal `CallToolResult`. Instances are immutable; replace via `MCPTool.task_options = MCPTaskOptions(...)`. Fields:
|
||||
- `default_ttl: timedelta | None` — forwarded to the server as `params.task.ttl` (milliseconds). When `None`, the server's default applies.
|
||||
- `cancel_remote_task_on_local_cancellation: bool = True` — on local `CancelledError`, spawn a best-effort `tasks/cancel` before re-raising.
|
||||
- **Permissive fallback**: servers that ignore the augmentation (return `CallToolResult` directly) or reject the unknown `task` field with `METHOD_NOT_FOUND` / `INVALID_PARAMS` fall back to the plain `session.call_tool(...)` path so legacy servers keep working.
|
||||
- **Phase-aware reconnect**: a dropped connection before a `task_id` is known raises `ToolExecutionException("connection lost; task state unknown")` without re-issuing the augmented `tools/call`, so a server that accepted the request but lost the response cannot be made to start the same operation twice; once a `task_id` exists, `tasks/get` / `tasks/result` reconnect once and retry against the same id, which is safe because the request references the existing operation.
|
||||
- `cancel_remote_task_on_local_cancellation: bool = True` — only gates the `CancelledError` path. Abandonment paths (see below) always cancel.
|
||||
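- `max_task_wait: timedelta | None` — client-side deadline for the whole post-create lifecycle (poll + result fetch). When exceeded, raises `ToolExecutionException` and fires a best-effort `tasks/cancel`. `None` (default) means no client-side bound; must be positive if set. The deadline bounds sleeps, sends, and reconnects. It is enforced by the private `_await_with_deadline` helper (built on `asyncio.wait` rather than `asyncio.wait_for`), so a `TimeoutError` raised inside the lifecycle is not mistaken for deadline expiry.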
|
||||
- **Permissive fallback**: servers that ignore the augmentation (return `CallToolResult` directly) or reject the unknown `task` field with `METHOD_NOT_FOUND` / `INVALID_PARAMS` fall back to the plain `session.call_tool(...)` path so legacy servers keep working. An unparseable success response (server accepted the augmented call but returned a payload that is neither `CreateTaskResult` nor `CallToolResult`) **does not** fall back — it raises `ToolExecutionException` to avoid double-executing a side-effecting tool.
|
||||
- **Submit-vs-track reconnect policy**: a dropped connection before a `task_id` is known raises `ToolExecutionException("connection lost; task state unknown")` without re-issuing the augmented `tools/call`, so a server that accepted the request but lost the response cannot be made to start the same operation twice; once a `task_id` exists, `tasks/get` / `tasks/result` reconnect once and retry against the same id (a shared `_send_with_one_reconnect` helper).
|
||||
- **Cancel-on-abandonment vs terminal failure**: any path where the remote task may still be running (max-wait exceeded, hard `McpError` in poll, malformed `tasks/get`, second connection loss in poll/fetch, reconnect failure) fires best-effort `tasks/cancel` before raising. Terminal failures (`failed`/`cancelled`/`input_required` server-side, `completed+isError`, malformed `tasks/result` after server completed) do **not** cancel — the server is already done. `_MCPTaskAbandoned` is the private marker distinguishing the two.
|
||||
- **Transient poll retry**: a slow `tasks/get` that surfaces as `McpError(code=408 REQUEST_TIMEOUT)` is retried (bounded by `max_task_wait`). All other non-connection `McpError`s during poll are treated as abandonment. `tasks/result` does not get transient retry — the server has already completed, so a slow payload fetch is anomalous.
|
||||
|
||||
### File Access Harness (`_harness/_file_access.py`)
|
||||
|
||||
|
||||
@@ -160,8 +160,19 @@ _MCP_TASK_CANCEL_TIMEOUT = timedelta(seconds=5)
|
||||
_MCP_TASK_TERMINAL_STATUSES: frozenset[str] = frozenset({"completed", "failed", "cancelled", "input_required"})
|
||||
|
||||
|
||||
class _MCPTaskAbandoned(ToolExecutionException):
    """Signal that the remote MCP task may still be running and must be cancelled.

    Derives from ToolExecutionException, so callers observe it as an ordinary
    tool failure.
    """
|
||||
|
||||
|
||||
class _MCPDeadlineExpired(Exception):
|
||||
"""Internal marker for ``max_task_wait`` expiry; distinct from inner TimeoutError."""
|
||||
|
||||
|
||||
@experimental(feature_id=ExperimentalFeature.MCP_LONG_RUNNING_TASKS)
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class MCPTaskOptions:
|
||||
"""Options controlling how MCPTool drives the MCP long-running task lifecycle.
|
||||
|
||||
@@ -169,6 +180,9 @@ class MCPTaskOptions:
|
||||
the framework transparently drives the SEP-2663 ``tools/call`` → ``tasks/get``
|
||||
(polled) → ``tasks/result`` lifecycle so the agent sees a normal tool result.
|
||||
|
||||
Instances are immutable; replace the whole object via
|
||||
``MCPTool.task_options = MCPTaskOptions(...)`` to change behavior.
|
||||
|
||||
Attributes:
|
||||
default_ttl: Optional default time-to-live forwarded to the server as
|
||||
``params.task.ttl`` (milliseconds, integer). When ``None``, the server
|
||||
@@ -176,14 +190,24 @@ class MCPTaskOptions:
|
||||
cancel_remote_task_on_local_cancellation: If True (default), a local
|
||||
cancellation of the awaiting coroutine triggers a best-effort
|
||||
``tasks/cancel`` on the server before re-raising ``CancelledError``.
|
||||
Only gates ``CancelledError``; abandonment paths (max-wait,
|
||||
unrecoverable poll errors, lost connection after task_id is known)
|
||||
always cancel regardless of this flag.
|
||||
max_task_wait: Optional client-side deadline for the whole post-create
|
||||
lifecycle (poll + result fetch). When exceeded, raises
|
||||
``ToolExecutionException`` and fires a best-effort ``tasks/cancel``.
|
||||
``None`` (default) means no client-side bound. Must be positive if set.
|
||||
"""
|
||||
|
||||
default_ttl: timedelta | None = None
|
||||
cancel_remote_task_on_local_cancellation: bool = True
|
||||
max_task_wait: timedelta | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.default_ttl is not None and self.default_ttl.total_seconds() < 0:
|
||||
raise ValueError("MCPTaskOptions.default_ttl must be non-negative.")
|
||||
if self.max_task_wait is not None and self.max_task_wait.total_seconds() <= 0:
|
||||
raise ValueError("MCPTaskOptions.max_task_wait must be positive.")
|
||||
|
||||
|
||||
def streamable_http_client(*args: Any, **kwargs: Any) -> _AsyncGeneratorContextManager[Any, None]:
|
||||
@@ -1532,10 +1556,10 @@ class MCPTool:
|
||||
filtered_kwargs, meta = self._prepare_call_kwargs(tool_name, kwargs)
|
||||
parser = self.parse_tool_results or self._parse_tool_result_from_mcp
|
||||
|
||||
# Phase 1: issue augmented tools/call. Do NOT retry on connection loss here:
|
||||
# Submit the task: issue augmented tools/call. Do NOT retry on connection loss here:
|
||||
# the server may have accepted the request and created a task before the
|
||||
# response was lost, so retrying could start the long-running operation twice.
|
||||
# Reconnect-and-retry is only safe after task_id is known (phase 2).
|
||||
# Reconnect-and-retry is only safe after the task_id is known.
|
||||
try:
|
||||
task_id, fallback_result = await self._call_tool_as_task_create(tool_name, filtered_kwargs, meta)
|
||||
except (ClosedResourceError, McpError) as ex:
|
||||
@@ -1565,15 +1589,40 @@ class MCPTool:
|
||||
|
||||
assert task_id is not None # noqa: S101 # nosec B101 - protected by the branch above
|
||||
|
||||
# Phase 2: poll until terminal status, then fetch payload. Never re-issue tools/call
|
||||
# past this point; reconnect-and-retry only against the same task_id.
|
||||
try:
|
||||
# Track to completion: poll until terminal, then fetch payload. Never re-issue
|
||||
# tools/call past this point; reconnect-and-retry only against the same task_id.
|
||||
opts = self._effective_task_options()
|
||||
max_wait_s = opts.max_task_wait.total_seconds() if opts.max_task_wait is not None else None
|
||||
|
||||
async def _await_task_completion() -> str | list[Content]:
|
||||
terminal = await self._poll_task_until_terminal(task_id)
|
||||
return await self._handle_terminal_task(tool_name, task_id, terminal, parser)
|
||||
|
||||
try:
|
||||
if max_wait_s is not None:
|
||||
try:
|
||||
return await self._await_with_deadline(_await_task_completion(), max_wait_s)
|
||||
except _MCPDeadlineExpired as ex:
|
||||
self._spawn_best_effort_cancel(task_id)
|
||||
raise ToolExecutionException(
|
||||
f"MCP task '{task_id}' exceeded max_task_wait of {max_wait_s}s.",
|
||||
inner_exception=ex,
|
||||
) from ex
|
||||
else:
|
||||
return await _await_task_completion()
|
||||
except asyncio.CancelledError:
|
||||
if self._effective_task_options().cancel_remote_task_on_local_cancellation:
|
||||
if opts.cancel_remote_task_on_local_cancellation:
|
||||
self._spawn_best_effort_cancel(task_id)
|
||||
raise
|
||||
except _MCPTaskAbandoned:
|
||||
# Pre-terminal abandonment (hard poll error, malformed get, second
|
||||
# disconnect, reconnect failure): cancel + re-raise as plain
|
||||
# ToolExecutionException to the function-calling loop.
|
||||
self._spawn_best_effort_cancel(task_id)
|
||||
raise
|
||||
# Plain ToolExecutionException from terminal failures (failed/cancelled/
|
||||
# input_required, completed+isError, malformed result post-completion)
|
||||
# propagates without cancel — server is already done.
|
||||
|
||||
async def _call_tool_as_task_create(
|
||||
self, tool_name: str, arguments: dict[str, Any], meta: dict[str, Any] | None
|
||||
@@ -1636,54 +1685,50 @@ class MCPTool:
|
||||
|
||||
try:
|
||||
legacy = types.CallToolResult.model_validate(raw)
|
||||
except ValidationError:
|
||||
logger.debug(
|
||||
"Augmented tools/call for '%s' returned a non-CreateTaskResult/non-CallToolResult payload; "
|
||||
"falling back to plain tools/call.",
|
||||
tool_name,
|
||||
)
|
||||
fallback = await self.session.call_tool(tool_name, arguments=arguments, meta=meta) # type: ignore[union-attr]
|
||||
return None, fallback
|
||||
except ValidationError as ex:
|
||||
# Augmented call succeeded server-side; re-issuing a plain tools/call
|
||||
# could double-execute a side-effecting tool.
|
||||
raise ToolExecutionException(
|
||||
f"MCP server returned an unparseable response to augmented tools/call "
|
||||
f"for '{tool_name}'; cannot safely retry (server may have started the operation).",
|
||||
inner_exception=ex,
|
||||
) from ex
|
||||
|
||||
return None, legacy
|
||||
|
||||
async def _poll_task_until_terminal(self, task_id: str) -> types.GetTaskResult:
|
||||
"""Poll ``tasks/get`` until the task reaches a terminal status."""
|
||||
from anyio import ClosedResourceError
|
||||
import httpx
|
||||
from mcp import types
|
||||
from mcp.shared.exceptions import McpError
|
||||
|
||||
# SDK raises McpError(code=httpx.REQUEST_TIMEOUT=408) on session read timeout.
|
||||
transient_codes: frozenset[int] = frozenset({int(httpx.codes.REQUEST_TIMEOUT)})
|
||||
|
||||
while True:
|
||||
for attempt in range(2):
|
||||
try:
|
||||
request = types.ClientRequest(
|
||||
types.GetTaskRequest(params=types.GetTaskRequestParams(taskId=task_id))
|
||||
request = types.ClientRequest(
|
||||
types.GetTaskRequest(params=types.GetTaskRequestParams(taskId=task_id))
|
||||
)
|
||||
try:
|
||||
# GetTaskResult.ttl is required-but-Optional in the SDK; coerce below.
|
||||
lenient = await self._send_with_one_reconnect(
|
||||
request, types.Result, operation="tasks/get", task_id=task_id
|
||||
)
|
||||
except McpError as ex:
|
||||
if ex.error.code in transient_codes:
|
||||
logger.debug(
|
||||
"Transient %s on tasks/get for '%s'; will retry.", ex.error.code, task_id
|
||||
)
|
||||
# Use lenient Result then coerce: GetTaskResult.ttl is required by
|
||||
# schema but servers may legitimately omit it.
|
||||
lenient = await self.session.send_request(request, types.Result) # type: ignore[union-attr]
|
||||
snapshot = self._coerce_get_task_result(lenient, task_id)
|
||||
break
|
||||
except (ClosedResourceError, McpError) as ex:
|
||||
if not self._is_connection_lost(ex):
|
||||
error_message = ex.error.message if isinstance(ex, McpError) else str(ex)
|
||||
raise ToolExecutionException(error_message, inner_exception=ex) from ex
|
||||
if attempt == 0:
|
||||
logger.info("MCP connection lost during tasks/get; reconnecting (task_id=%s).", task_id)
|
||||
try:
|
||||
await self.connect(reset=True)
|
||||
continue
|
||||
except Exception as reconn_ex:
|
||||
raise ToolExecutionException(
|
||||
"Failed to reconnect to MCP server.",
|
||||
inner_exception=reconn_ex,
|
||||
) from reconn_ex
|
||||
raise ToolExecutionException(
|
||||
f"MCP connection lost; task state unknown (task_id={task_id}).",
|
||||
inner_exception=ex,
|
||||
) from ex
|
||||
else: # pragma: no cover - defensive
|
||||
raise ToolExecutionException(f"Failed to poll task '{task_id}'.")
|
||||
await asyncio.sleep(_MCP_TASK_MIN_POLL_INTERVAL.total_seconds())
|
||||
continue
|
||||
# Hard server error mid-poll: task may still be running.
|
||||
raise _MCPTaskAbandoned(ex.error.message, inner_exception=ex) from ex
|
||||
|
||||
try:
|
||||
snapshot = self._coerce_get_task_result(lenient, task_id)
|
||||
except ToolExecutionException as ex:
|
||||
# Malformed tasks/get response; task may still be running.
|
||||
raise _MCPTaskAbandoned(str(ex), inner_exception=ex) from ex
|
||||
|
||||
if snapshot.status in _MCP_TASK_TERMINAL_STATUSES:
|
||||
return snapshot
|
||||
@@ -1750,40 +1795,22 @@ class MCPTool:
|
||||
|
||||
async def _fetch_task_result(self, task_id: str) -> types.CallToolResult:
|
||||
"""Send ``tasks/result`` and reinterpret the open-typed payload as a CallToolResult."""
|
||||
from anyio import ClosedResourceError
|
||||
from mcp import types
|
||||
from mcp.shared.exceptions import McpError
|
||||
from pydantic import ValidationError
|
||||
|
||||
for attempt in range(2):
|
||||
try:
|
||||
request = types.ClientRequest(
|
||||
types.GetTaskPayloadRequest(params=types.GetTaskPayloadRequestParams(taskId=task_id))
|
||||
)
|
||||
payload = await self.session.send_request( # type: ignore[union-attr]
|
||||
request, types.GetTaskPayloadResult
|
||||
)
|
||||
break
|
||||
except (ClosedResourceError, McpError) as ex:
|
||||
if not self._is_connection_lost(ex):
|
||||
error_message = ex.error.message if isinstance(ex, McpError) else str(ex)
|
||||
raise ToolExecutionException(error_message, inner_exception=ex) from ex
|
||||
if attempt == 0:
|
||||
logger.info("MCP connection lost during tasks/result; reconnecting (task_id=%s).", task_id)
|
||||
try:
|
||||
await self.connect(reset=True)
|
||||
continue
|
||||
except Exception as reconn_ex:
|
||||
raise ToolExecutionException(
|
||||
"Failed to reconnect to MCP server.",
|
||||
inner_exception=reconn_ex,
|
||||
) from reconn_ex
|
||||
raise ToolExecutionException(
|
||||
f"MCP connection lost; task state unknown (task_id={task_id}).",
|
||||
inner_exception=ex,
|
||||
) from ex
|
||||
else: # pragma: no cover - defensive
|
||||
raise ToolExecutionException(f"Failed to fetch result for task '{task_id}'.")
|
||||
request = types.ClientRequest(
|
||||
types.GetTaskPayloadRequest(params=types.GetTaskPayloadRequestParams(taskId=task_id))
|
||||
)
|
||||
# Connection-loss retry only via the helper; no transient-code retry — server
|
||||
# has already completed the task, so a slow payload fetch is anomalous.
|
||||
try:
|
||||
payload = await self._send_with_one_reconnect(
|
||||
request, types.GetTaskPayloadResult, operation="tasks/result", task_id=task_id
|
||||
)
|
||||
except McpError as ex:
|
||||
# Server reported completed; a hard fetch error is a plain failure (no cancel).
|
||||
raise ToolExecutionException(ex.error.message, inner_exception=ex) from ex
|
||||
|
||||
# GetTaskPayloadResult carries the tool result via extra fields; reinterpret as CallToolResult.
|
||||
payload_dict = payload.model_dump(by_alias=True, exclude_none=True)
|
||||
@@ -1791,10 +1818,78 @@ class MCPTool:
|
||||
try:
|
||||
return types.CallToolResult.model_validate(payload_dict)
|
||||
except ValidationError as ex:
|
||||
# Server reported completed; malformed payload is a plain failure (no cancel needed).
|
||||
raise ToolExecutionException(
|
||||
f"MCP task '{task_id}' result payload could not be parsed as a CallToolResult."
|
||||
f"MCP task '{task_id}' result payload could not be parsed as a CallToolResult.",
|
||||
inner_exception=ex,
|
||||
) from ex
|
||||
|
||||
async def _send_with_one_reconnect(
    self,
    request: types.ClientRequest,
    result_type: type[Any],
    *,
    operation: str,
    task_id: str,
) -> Any:
    """Send ``request``, reconnecting and retrying at most once on connection loss.

    Args:
        request: The client request to send on the current session.
        result_type: The result model handed to ``session.send_request``.
        operation: Label of the operation, used only in log/assert messages.
        task_id: Identifier of the task the request refers to (for messages).

    Returns:
        Whatever ``session.send_request`` returns for ``request``.

    Raises:
        _MCPTaskAbandoned: If reconnecting fails, or the connection is lost
            again on the retry.
        Exception: Errors that ``_is_connection_lost`` does not classify as a
            lost connection are re-raised unchanged.
    """
    from anyio import ClosedResourceError
    from mcp.shared.exceptions import McpError

    # First pass may reconnect; the second pass is the single permitted retry.
    for may_reconnect in (True, False):
        try:
            return await self.session.send_request(request, result_type)  # type: ignore[union-attr]
        except (ClosedResourceError, McpError) as ex:
            if not self._is_connection_lost(ex):
                raise
            if not may_reconnect:
                # Second connection loss: task may still be running.
                raise _MCPTaskAbandoned(
                    f"MCP connection lost; task state unknown (task_id={task_id}).",
                    inner_exception=ex,
                ) from ex
            logger.info(
                "MCP connection lost during %s; reconnecting (task_id=%s).", operation, task_id
            )
            try:
                await self.connect(reset=True)
            except Exception as reconn_ex:
                # Reconnect failure: task may still be running.
                raise _MCPTaskAbandoned(
                    "Failed to reconnect to MCP server.", inner_exception=reconn_ex
                ) from reconn_ex
    raise AssertionError(f"unreachable: {operation} for {task_id}")  # pragma: no cover
|
||||
|
||||
@staticmethod
|
||||
async def _await_with_deadline(coro: Coroutine[Any, Any, Any], timeout_s: float) -> Any:
|
||||
"""Await ``coro`` with a deadline; raise ``_MCPDeadlineExpired`` only on deadline.
|
||||
|
||||
Unlike ``asyncio.wait_for``, an ``asyncio.TimeoutError`` raised by ``coro``
|
||||
itself propagates unchanged so callers can distinguish their own deadline
|
||||
from a stray inner timeout.
|
||||
"""
|
||||
inner = asyncio.ensure_future(coro)
|
||||
try:
|
||||
done, _pending = await asyncio.wait({inner}, timeout=timeout_s)
|
||||
except BaseException:
|
||||
# Outer caller cancelled (or another exception): cancel inner + drain.
|
||||
inner.cancel()
|
||||
with contextlib.suppress(BaseException):
|
||||
await inner
|
||||
raise
|
||||
if inner in done:
|
||||
return inner.result()
|
||||
# Deadline fired before inner finished.
|
||||
inner.cancel()
|
||||
with contextlib.suppress(BaseException):
|
||||
await inner
|
||||
raise _MCPDeadlineExpired
|
||||
|
||||
def _spawn_best_effort_cancel(self, task_id: str) -> None:
|
||||
"""Fire-and-forget ``tasks/cancel`` so local cancellation propagates server-side."""
|
||||
try:
|
||||
@@ -1807,18 +1902,35 @@ class MCPTool:
|
||||
cancel_task.add_done_callback(self._pending_reload_tasks.discard)
|
||||
|
||||
async def _try_cancel_task(self, task_id: str) -> None:
|
||||
"""Send ``tasks/cancel`` swallowing every failure; bounded by ``_MCP_TASK_CANCEL_TIMEOUT``."""
|
||||
"""Send ``tasks/cancel``; bounded by ``_MCP_TASK_CANCEL_TIMEOUT``.
|
||||
|
||||
Failures log at warning so unattributed orphan tasks are debuggable.
|
||||
"""
|
||||
from mcp import types
|
||||
|
||||
async def _send() -> None:
|
||||
request = types.ClientRequest(types.CancelTaskRequest(params=types.CancelTaskRequestParams(taskId=task_id)))
|
||||
try:
|
||||
await self.session.send_request(request, types.CancelTaskResult) # type: ignore[union-attr]
|
||||
except Exception:
|
||||
logger.debug("Best-effort tasks/cancel for '%s' failed.", task_id, exc_info=True)
|
||||
|
||||
with contextlib.suppress(Exception, asyncio.CancelledError, asyncio.TimeoutError):
|
||||
await asyncio.wait_for(_send(), timeout=_MCP_TASK_CANCEL_TIMEOUT.total_seconds())
|
||||
request = types.ClientRequest(
|
||||
types.CancelTaskRequest(params=types.CancelTaskRequestParams(taskId=task_id))
|
||||
)
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self.session.send_request(request, types.CancelTaskResult), # type: ignore[union-attr]
|
||||
timeout=_MCP_TASK_CANCEL_TIMEOUT.total_seconds(),
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"Best-effort tasks/cancel for '%s' timed out after %.1fs; "
|
||||
"remote task may still be running.",
|
||||
task_id,
|
||||
_MCP_TASK_CANCEL_TIMEOUT.total_seconds(),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Best-effort tasks/cancel for '%s' failed; remote task may still be running.",
|
||||
task_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_connection_lost(ex: BaseException) -> bool:
|
||||
|
||||
@@ -7,6 +7,7 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
from contextlib import _AsyncGeneratorContextManager # type: ignore
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
@@ -5553,4 +5554,504 @@ async def test_call_tool_as_task_create_disconnect_does_not_retry() -> None:
|
||||
reconnect_mock.assert_not_awaited()
|
||||
|
||||
|
||||
async def test_fetch_task_result_reconnects_during_fetch() -> None:
    """One connection loss on ``tasks/result`` leads to one reconnect and a successful retry."""
    from anyio import ClosedResourceError

    tool = _make_task_tool()
    counts = {"fetch": 0, "reconnect": 0}

    async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        method = request.root.method
        if method == "tools/call":
            return _make_create_task_result(task_id="r1")
        if method == "tasks/get":
            return _make_task_snapshot(task_id="r1", status="completed")
        if method != "tasks/result":
            raise AssertionError(method)
        # First fetch drops the connection; the retry succeeds.
        counts["fetch"] += 1
        if counts["fetch"] == 1:
            raise ClosedResourceError
        return _make_payload("fetched after reconnect")

    tool.session.send_request = AsyncMock(side_effect=fake_send)  # type: ignore[union-attr]

    async def fake_connect(reset: bool = False) -> None:
        counts["reconnect"] += 1
        assert reset is True

    with patch.object(MCPTool, "connect", side_effect=fake_connect):
        result = await tool.call_tool("slow_op")

    assert _mcp_result_to_text(result) == "fetched after reconnect"
    assert counts["reconnect"] == 1
    assert counts["fetch"] == 2
|
||||
|
||||
|
||||
async def test_fetch_task_result_second_disconnect_raises_task_state_unknown_and_cancels() -> None:
    """A repeated connection loss on ``tasks/result`` raises and fires ``tasks/cancel``."""
    from anyio import ClosedResourceError

    tool = _make_task_tool()
    seen = {"cancel": False}

    async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        method = request.root.method
        if method == "tools/call":
            return _make_create_task_result(task_id="r2")
        if method == "tasks/get":
            return _make_task_snapshot(task_id="r2", status="completed")
        if method == "tasks/result":
            raise ClosedResourceError
        if method == "tasks/cancel":
            seen["cancel"] = True
            return types.CancelTaskResult()
        raise AssertionError(method)

    tool.session.send_request = AsyncMock(side_effect=fake_send)  # type: ignore[union-attr]

    with (
        patch.object(MCPTool, "connect", new=AsyncMock(return_value=None)),
        pytest.raises(ToolExecutionException, match="task state unknown"),
    ):
        await tool.call_tool("slow_op")

    # Await the fire-and-forget cancel so the assertion below is deterministic.
    outstanding = list(tool._pending_reload_tasks)
    if outstanding:
        await asyncio.gather(*outstanding, return_exceptions=True)
    assert seen["cancel"] is True
|
||||
|
||||
|
||||
async def test_call_tool_as_task_create_unparseable_success_raises() -> None:
    """A success-shaped but unparseable response must raise, never re-issue tools/call."""
    tool = _make_task_tool()

    # Payload that carries neither task.taskId nor a valid CallToolResult shape.
    junk_response = types.Result.model_validate({"foo": "bar"})
    tool.session.send_request = AsyncMock(return_value=junk_response)  # type: ignore[union-attr]
    tool.session.call_tool = AsyncMock(return_value=types.CallToolResult(content=[]))  # type: ignore[union-attr]

    with pytest.raises(ToolExecutionException, match="unparseable response"):
        await tool.call_tool("slow_op")

    # The plain tools/call fallback must stay untouched (it could double-execute the tool).
    tool.session.call_tool.assert_not_called()  # type: ignore[union-attr]
|
||||
|
||||
|
||||
async def test_call_tool_as_task_max_wait_exceeded_raises_and_cancels(monkeypatch: pytest.MonkeyPatch) -> None:
    """A task that never leaves ``working`` trips ``max_task_wait`` and is cancelled remotely."""
    from agent_framework import MCPTaskOptions
    from agent_framework import _mcp as _mcp_module

    # Poll quickly so the 50 ms deadline is reached after several polls.
    monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", _mcp_module.timedelta(milliseconds=1))

    options = MCPTaskOptions(max_task_wait=timedelta(milliseconds=50))
    tool = _make_task_tool(task_options=options)
    seen = {"cancel": False}

    async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        method = request.root.method
        if method == "tools/call":
            return _make_create_task_result(task_id="mw")
        if method == "tasks/get":
            return _make_task_snapshot(task_id="mw", status="working")
        if method == "tasks/cancel":
            seen["cancel"] = True
            return types.CancelTaskResult()
        raise AssertionError(method)

    tool.session.send_request = AsyncMock(side_effect=fake_send)  # type: ignore[union-attr]

    with pytest.raises(ToolExecutionException, match="exceeded max_task_wait"):
        await tool.call_tool("slow_op")

    outstanding = list(tool._pending_reload_tasks)
    if outstanding:
        await asyncio.gather(*outstanding, return_exceptions=True)
    assert seen["cancel"] is True
|
||||
|
||||
|
||||
async def test_call_tool_as_task_max_wait_cancels_even_when_local_cancel_option_disabled(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Locks contract: max_task_wait abandonment ignores the local-cancel option."""
    from agent_framework import MCPTaskOptions
    from agent_framework import _mcp as _mcp_module

    monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", _mcp_module.timedelta(milliseconds=1))

    options = MCPTaskOptions(
        cancel_remote_task_on_local_cancellation=False,
        max_task_wait=timedelta(milliseconds=50),
    )
    tool = _make_task_tool(task_options=options)
    seen = {"cancel": False}

    async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        method = request.root.method
        if method == "tools/call":
            return _make_create_task_result(task_id="mw2")
        if method == "tasks/get":
            return _make_task_snapshot(task_id="mw2", status="working")
        if method == "tasks/cancel":
            seen["cancel"] = True
            return types.CancelTaskResult()
        raise AssertionError(method)

    tool.session.send_request = AsyncMock(side_effect=fake_send)  # type: ignore[union-attr]

    with pytest.raises(ToolExecutionException, match="exceeded max_task_wait"):
        await tool.call_tool("slow_op")

    outstanding = list(tool._pending_reload_tasks)
    if outstanding:
        await asyncio.gather(*outstanding, return_exceptions=True)
    assert seen["cancel"] is True
|
||||
|
||||
|
||||
async def test_call_tool_as_task_poll_transient_request_timeout_keeps_polling(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A REQUEST_TIMEOUT ``McpError`` on ``tasks/get`` is retried and does not trigger a cancel."""
    import httpx

    from agent_framework import _mcp as _mcp_module

    monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", _mcp_module.timedelta(milliseconds=1))

    tool = _make_task_tool()
    state = {"polls": 0, "cancel": False}

    async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        method = request.root.method
        if method == "tools/call":
            return _make_create_task_result(task_id="t1")
        if method == "tasks/get":
            # First poll times out; the second reports completion.
            state["polls"] += 1
            if state["polls"] == 1:
                raise McpError(types.ErrorData(code=int(httpx.codes.REQUEST_TIMEOUT), message="slow poll"))
            return _make_task_snapshot(task_id="t1", status="completed")
        if method == "tasks/result":
            return _make_payload("recovered after transient")
        if method == "tasks/cancel":
            state["cancel"] = True
            return types.CancelTaskResult()
        raise AssertionError(method)

    tool.session.send_request = AsyncMock(side_effect=fake_send)  # type: ignore[union-attr]

    result = await tool.call_tool("slow_op")
    assert _mcp_result_to_text(result) == "recovered after transient"
    assert state["polls"] == 2

    # Transient retry must not fire cancel.
    outstanding = list(tool._pending_reload_tasks)
    if outstanding:
        await asyncio.gather(*outstanding, return_exceptions=True)
    assert state["cancel"] is False
|
||||
|
||||
|
||||
async def test_call_tool_as_task_poll_hard_mcperror_cancels_and_raises(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A non-transient ``McpError`` on ``tasks/get`` raises and fires ``tasks/cancel``."""
    from agent_framework import _mcp as _mcp_module

    monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", _mcp_module.timedelta(milliseconds=1))

    tool = _make_task_tool()
    seen = {"cancel": False}

    async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        method = request.root.method
        if method == "tools/call":
            return _make_create_task_result(task_id="h1")
        if method == "tasks/get":
            raise McpError(types.ErrorData(code=types.INVALID_PARAMS, message="bad task id"))
        if method == "tasks/cancel":
            seen["cancel"] = True
            return types.CancelTaskResult()
        raise AssertionError(method)

    tool.session.send_request = AsyncMock(side_effect=fake_send)  # type: ignore[union-attr]

    with pytest.raises(ToolExecutionException, match="bad task id"):
        await tool.call_tool("slow_op")

    outstanding = list(tool._pending_reload_tasks)
    if outstanding:
        await asyncio.gather(*outstanding, return_exceptions=True)
    assert seen["cancel"] is True
|
||||
|
||||
|
||||
async def test_call_tool_as_task_malformed_tasks_get_response_cancels_and_raises(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A tasks/get reply that is not a usable task snapshot is treated as abandonment:
    the call raises ToolExecutionException and tasks/cancel is sent (the task may still run).
    """
    from agent_framework import _mcp as _mcp_module

    # Patch the module-level minimum poll interval down to 1 ms for this test.
    one_ms = _mcp_module.timedelta(milliseconds=1)
    monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", one_ms)

    tool = _make_task_tool()

    # A generic Result carrying none of the GetTaskResult fields (no taskId/status/etc.).
    junk_reply = types.Result.model_validate({"some": "junk"})

    cancel_seen = False

    async def scripted_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        nonlocal cancel_seen
        rpc = request.root.method
        if rpc == "tasks/get":
            return junk_reply
        if rpc == "tasks/cancel":
            cancel_seen = True
            return types.CancelTaskResult()
        if rpc != "tools/call":
            raise AssertionError(rpc)
        return _make_create_task_result(task_id="m1")

    tool.session.send_request = AsyncMock(side_effect=scripted_send)  # type: ignore[union-attr]

    with pytest.raises(ToolExecutionException, match="malformed tasks/get"):
        await tool.call_tool("slow_op")

    # Wait for anything tracked in _pending_reload_tasks before checking the cancel flag.
    if background := list(tool._pending_reload_tasks):
        await asyncio.gather(*background, return_exceptions=True)
    assert cancel_seen is True
|
||||
|
||||
|
||||
async def test_call_tool_as_task_failed_terminal_does_not_cancel(monkeypatch: pytest.MonkeyPatch) -> None:
    """When tasks/get reports the terminal status "failed", the call raises
    ToolExecutionException and no tasks/cancel request is sent.
    """
    from agent_framework import _mcp as _mcp_module

    # Patch the module-level minimum poll interval down to 1 ms for this test.
    one_ms = _mcp_module.timedelta(milliseconds=1)
    monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", one_ms)

    tool = _make_task_tool()
    cancel_seen = False

    async def scripted_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        nonlocal cancel_seen
        rpc = request.root.method
        if rpc == "tasks/get":
            # The server says the task already finished unsuccessfully.
            return _make_task_snapshot(task_id="f1", status="failed", status_message="boom")
        if rpc == "tasks/cancel":
            cancel_seen = True
            return types.CancelTaskResult()
        if rpc != "tools/call":
            raise AssertionError(rpc)
        return _make_create_task_result(task_id="f1")

    tool.session.send_request = AsyncMock(side_effect=scripted_send)  # type: ignore[union-attr]

    with pytest.raises(ToolExecutionException, match="task failed: boom"):
        await tool.call_tool("slow_op")

    # Give any (incorrectly scheduled) background work a moment to run before
    # verifying that no cancel was issued.
    await asyncio.sleep(0.02)
    assert cancel_seen is False
|
||||
|
||||
|
||||
async def test_try_cancel_task_logs_warning_on_timeout(
    caplog: pytest.LogCaptureFixture, monkeypatch: pytest.MonkeyPatch
) -> None:
    """_try_cancel_task must return and log a "timed out" message naming the task id
    when the underlying send_request never completes.
    """
    from agent_framework import _mcp as _mcp_module

    # Use a 20 ms cancel timeout so the test finishes quickly.
    short_timeout = _mcp_module.timedelta(milliseconds=20)
    monkeypatch.setattr(_mcp_module, "_MCP_TASK_CANCEL_TIMEOUT", short_timeout)

    tool = _make_task_tool()

    async def never_answers(*_a: Any, **_kw: Any) -> Any:
        # Far longer than the patched timeout above.
        await asyncio.sleep(10.0)

    tool.session.send_request = AsyncMock(side_effect=never_answers)  # type: ignore[union-attr]

    with caplog.at_level(logging.WARNING, logger=_mcp_module.logger.name):
        await tool._try_cancel_task("hang-1")

    logged = [record.getMessage() for record in caplog.records]
    assert any("timed out" in text and "hang-1" in text for text in logged)
|
||||
|
||||
|
||||
async def test_mcp_task_options_is_frozen() -> None:
    """Assigning to a field of an MCPTaskOptions instance raises FrozenInstanceError."""
    from dataclasses import FrozenInstanceError

    from agent_framework import MCPTaskOptions

    options = MCPTaskOptions()
    replacement_ttl = timedelta(seconds=5)

    # Only the attribute assignment lives inside the raises-block.
    with pytest.raises(FrozenInstanceError):
        options.default_ttl = replacement_ttl  # type: ignore[misc]
|
||||
|
||||
|
||||
async def test_mcp_task_options_max_task_wait_rejects_non_positive() -> None:
    """Constructing MCPTaskOptions with a zero or negative max_task_wait raises ValueError."""
    from agent_framework import MCPTaskOptions

    # Zero first, then a negative duration -- both must be rejected.
    for bad_wait in (timedelta(0), timedelta(seconds=-1)):
        with pytest.raises(ValueError, match="positive"):
            MCPTaskOptions(max_task_wait=bad_wait)
|
||||
|
||||
|
||||
async def test_fetch_task_result_hard_mcperror_raises_without_cancel() -> None:
    """A hard McpError from tasks/result (after tasks/get reported "completed") must be
    wrapped as ToolExecutionException, and no tasks/cancel request may be sent.
    """
    tool = _make_task_tool()
    cancel_seen = False

    async def scripted_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        nonlocal cancel_seen
        rpc = request.root.method
        if rpc == "tasks/result":
            # Fetching the payload fails even though the task was reported as completed.
            raise McpError(types.ErrorData(code=types.INTERNAL_ERROR, message="payload vanished"))
        if rpc == "tasks/cancel":
            cancel_seen = True
            return types.CancelTaskResult()
        if rpc == "tasks/get":
            return _make_task_snapshot(task_id="hf", status="completed")
        if rpc != "tools/call":
            raise AssertionError(rpc)
        return _make_create_task_result(task_id="hf")

    tool.session.send_request = AsyncMock(side_effect=scripted_send)  # type: ignore[union-attr]

    # pytest.raises also guarantees that the raw McpError does not leak out.
    with pytest.raises(ToolExecutionException, match="payload vanished"):
        await tool.call_tool("slow_op")

    # The server already reported the task as done, so after a short settle
    # period no cancel must have been issued.
    await asyncio.sleep(0.02)
    assert cancel_seen is False
|
||||
|
||||
|
||||
async def test_completion_wait_timeout_without_max_wait_is_not_translated(monkeypatch: pytest.MonkeyPatch) -> None:
    """With the default task options, an asyncio.TimeoutError raised from inside the call
    (here: by the patched result parser) must reach the caller as-is, and no
    tasks/cancel request may be sent.
    """
    from agent_framework import _mcp as _mcp_module

    # Patch the module-level minimum poll interval down to 1 ms for this test.
    one_ms = _mcp_module.timedelta(milliseconds=1)
    monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", one_ms)

    tool = _make_task_tool()

    def raising_parser(_: Any) -> list[Content]:
        raise asyncio.TimeoutError

    tool.parse_tool_results = raising_parser

    cancel_seen = False

    # Canned replies for the happy-path requests; built lazily, once per request.
    replies = {
        "tools/call": lambda: _make_create_task_result(task_id="t2"),
        "tasks/get": lambda: _make_task_snapshot(task_id="t2", status="completed"),
        "tasks/result": lambda: _make_payload("ok"),
    }

    async def scripted_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        nonlocal cancel_seen
        rpc = request.root.method
        if rpc == "tasks/cancel":
            cancel_seen = True
            return types.CancelTaskResult()
        if rpc not in replies:
            raise AssertionError(rpc)
        return replies[rpc]()

    tool.session.send_request = AsyncMock(side_effect=scripted_send)  # type: ignore[union-attr]

    with pytest.raises(asyncio.TimeoutError):
        await tool.call_tool("slow_op")

    # The error must not be reported as a max_task_wait expiry, and after a
    # short settle period no cancel must have been issued.
    await asyncio.sleep(0.02)
    assert cancel_seen is False
|
||||
|
||||
|
||||
async def test_completion_wait_inner_timeout_with_max_wait_set_propagates(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Even with max_task_wait configured, an asyncio.TimeoutError raised from inside the
    call (here: by the patched result parser) must propagate with its own message
    and must not trigger a tasks/cancel request.
    """
    from agent_framework import MCPTaskOptions
    from agent_framework import _mcp as _mcp_module

    # Patch the module-level minimum poll interval down to 1 ms for this test.
    one_ms = _mcp_module.timedelta(milliseconds=1)
    monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", one_ms)

    # The 5 s deadline is far above the time this test actually takes.
    generous_options = MCPTaskOptions(max_task_wait=timedelta(seconds=5))
    tool = _make_task_tool(task_options=generous_options)

    def raising_parser(_: Any) -> list[Content]:
        raise asyncio.TimeoutError("inner parser timeout")

    tool.parse_tool_results = raising_parser

    cancel_seen = False

    # Canned replies for the happy-path requests; built lazily, once per request.
    replies = {
        "tools/call": lambda: _make_create_task_result(task_id="t3"),
        "tasks/get": lambda: _make_task_snapshot(task_id="t3", status="completed"),
        "tasks/result": lambda: _make_payload("ok"),
    }

    async def scripted_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        nonlocal cancel_seen
        rpc = request.root.method
        if rpc == "tasks/cancel":
            cancel_seen = True
            return types.CancelTaskResult()
        if rpc not in replies:
            raise AssertionError(rpc)
        return replies[rpc]()

    tool.session.send_request = AsyncMock(side_effect=scripted_send)  # type: ignore[union-attr]

    with pytest.raises(asyncio.TimeoutError, match="inner parser timeout"):
        await tool.call_tool("slow_op")

    # The inner TimeoutError must not be rewritten into "exceeded max_task_wait",
    # and after a short settle period no cancel must have been issued.
    await asyncio.sleep(0.02)
    assert cancel_seen is False
|
||||
|
||||
|
||||
async def test_max_wait_interrupts_long_poll_sleep(monkeypatch: pytest.MonkeyPatch) -> None:
    """A 100 ms max_task_wait must end the call with "exceeded max_task_wait" in well
    under a second, even though the server keeps suggesting a 5 s poll interval.
    """
    from agent_framework import MCPTaskOptions

    tight_options = MCPTaskOptions(max_task_wait=timedelta(milliseconds=100))
    tool = _make_task_tool(task_options=tight_options)

    async def scripted_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        rpc = request.root.method
        if rpc == "tasks/get":
            # Never terminal, and asks for a 5 s poll interval (clamped to MAX=5s);
            # the deadline has to cut through that sleep.
            return _make_task_snapshot(task_id="ds", status="working", poll_interval_ms=5000)
        if rpc == "tasks/cancel":
            return types.CancelTaskResult()
        if rpc != "tools/call":
            raise AssertionError(rpc)
        return _make_create_task_result(task_id="ds")

    tool.session.send_request = AsyncMock(side_effect=scripted_send)  # type: ignore[union-attr]

    # Measure with the event loop's own clock.
    clock = asyncio.get_running_loop().time
    began = clock()
    with pytest.raises(ToolExecutionException, match="exceeded max_task_wait"):
        await tool.call_tool("slow_op")
    elapsed = clock() - began

    # Expected to fire close to the 100 ms deadline, far below the 5 s sleep.
    assert elapsed < 1.0, f"deadline did not interrupt long sleep (elapsed={elapsed:.3f}s)"

    # Let anything still tracked in _pending_reload_tasks finish before the test returns.
    if background := list(tool._pending_reload_tasks):
        await asyncio.gather(*background, return_exceptions=True)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
Reference in New Issue
Block a user