Address PR feedback

This commit is contained in:
Peter Ibekwe
2026-06-04 11:03:11 -07:00
Unverified
parent 6e016b1cfb
commit 012c135efd
3 changed files with 706 additions and 90 deletions
+7 -4
View File
@@ -80,11 +80,14 @@ agent_framework/
- **`MCPTool`** - Base wrapper that owns the MCP `ClientSession` and exposes the remote server's tools as `FunctionTool`s.
- **`MCPStdioTool`** / **`MCPStreamableHTTPTool`** / **`MCPWebsocketTool`** - Transport-specific subclasses.
- **`MCPTaskOptions`** (experimental, `MCP_LONG_RUNNING_TASKS` feature) - Per-tool-instance options controlling the SEP-2663 long-running task lifecycle. When the server advertises a tool with `execution.taskSupport == "required"`, `MCPTool.call_tool` transparently routes through `call_tool_as_task`, which sends an augmented `tools/call`, polls `tasks/get` until terminal, and reinterprets `tasks/result` as a normal `CallToolResult`. Fields:
- **`MCPTaskOptions`** (experimental, `MCP_LONG_RUNNING_TASKS` feature, **frozen**) - Per-tool-instance options controlling the SEP-2663 long-running task lifecycle. When the server advertises a tool with `execution.taskSupport == "required"`, `MCPTool.call_tool` transparently routes through `call_tool_as_task`, which sends an augmented `tools/call`, polls `tasks/get` until terminal, and reinterprets `tasks/result` as a normal `CallToolResult`. Instances are immutable; replace via `MCPTool.task_options = MCPTaskOptions(...)`. Fields:
- `default_ttl: timedelta | None` — forwarded to the server as `params.task.ttl` (milliseconds). When `None`, the server's default applies.
- `cancel_remote_task_on_local_cancellation: bool = True` — on local `CancelledError`, spawn a best-effort `tasks/cancel` before re-raising.
- **Permissive fallback**: servers that ignore the augmentation (return `CallToolResult` directly) or reject the unknown `task` field with `METHOD_NOT_FOUND` / `INVALID_PARAMS` fall back to the plain `session.call_tool(...)` path so legacy servers keep working.
- **Phase-aware reconnect**: a dropped connection before a `task_id` is known raises `ToolExecutionException("connection lost; task state unknown")` without re-issuing the augmented `tools/call`, so a server that accepted the request but lost the response cannot be made to start the same operation twice; once a `task_id` exists, `tasks/get` / `tasks/result` reconnect once and retry against the same id, which is safe because the request references the existing operation.
- `cancel_remote_task_on_local_cancellation: bool = True` — only gates the `CancelledError` path. Abandonment paths (see below) always cancel.
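- `max_task_wait: timedelta | None` — client-side deadline for the whole post-create lifecycle (poll + result fetch). When exceeded, raises `ToolExecutionException` and fires a best-effort `tasks/cancel`. `None` (default) means no client-side bound; must be positive if set. The deadline covers poll sleeps, sends, and reconnects, and is enforced by the private `_await_with_deadline` helper (built on `asyncio.wait` rather than `asyncio.wait_for`, so a stray inner `TimeoutError` is not mistaken for deadline expiry).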
- **Permissive fallback**: servers that ignore the augmentation (return `CallToolResult` directly) or reject the unknown `task` field with `METHOD_NOT_FOUND` / `INVALID_PARAMS` fall back to the plain `session.call_tool(...)` path so legacy servers keep working. An unparseable success response (server accepted the augmented call but returned a payload that is neither `CreateTaskResult` nor `CallToolResult`) **does not** fall back — it raises `ToolExecutionException` to avoid double-executing a side-effecting tool.
- **Submit-vs-track reconnect policy**: a dropped connection before a `task_id` is known raises `ToolExecutionException("connection lost; task state unknown")` without re-issuing the augmented `tools/call`, so a server that accepted the request but lost the response cannot be made to start the same operation twice; once a `task_id` exists, `tasks/get` / `tasks/result` reconnect once and retry against the same id (a shared `_send_with_one_reconnect` helper).
- **Cancel-on-abandonment vs terminal failure**: any path where the remote task may still be running (max-wait exceeded, hard `McpError` in poll, malformed `tasks/get`, second connection loss in poll/fetch, reconnect failure) fires best-effort `tasks/cancel` before raising. Terminal failures (`failed`/`cancelled`/`input_required` server-side, `completed+isError`, malformed `tasks/result` after server completed) do **not** cancel — the server is already done. `_MCPTaskAbandoned` is the private marker distinguishing the two.
- **Transient poll retry**: a slow `tasks/get` that surfaces as `McpError(code=408 REQUEST_TIMEOUT)` is retried (bounded by `max_task_wait`). All other non-connection `McpError`s during poll are treated as abandonment. `tasks/result` does not get transient retry — the server has already completed, so a slow payload fetch is anomalous.
### File Access Harness (`_harness/_file_access.py`)
+198 -86
View File
@@ -160,8 +160,19 @@ _MCP_TASK_CANCEL_TIMEOUT = timedelta(seconds=5)
_MCP_TASK_TERMINAL_STATUSES: frozenset[str] = frozenset({"completed", "failed", "cancelled", "input_required"})
class _MCPTaskAbandoned(ToolExecutionException):
    """Signal that tracking was given up while the remote MCP task may still run.

    Catching this marker tells the task driver to send a best-effort
    ``tasks/cancel``. It derives from ``ToolExecutionException`` so that code
    outside the driver observes an ordinary tool failure.
    """
class _MCPDeadlineExpired(Exception):
"""Internal marker for ``max_task_wait`` expiry; distinct from inner TimeoutError."""
@experimental(feature_id=ExperimentalFeature.MCP_LONG_RUNNING_TASKS)
@dataclass
@dataclass(frozen=True)
class MCPTaskOptions:
"""Options controlling how MCPTool drives the MCP long-running task lifecycle.
@@ -169,6 +180,9 @@ class MCPTaskOptions:
the framework transparently drives the SEP-2663 ``tools/call`` → ``tasks/get``
(polled) → ``tasks/result`` lifecycle so the agent sees a normal tool result.
Instances are immutable; replace the whole object via
``MCPTool.task_options = MCPTaskOptions(...)`` to change behavior.
Attributes:
default_ttl: Optional default time-to-live forwarded to the server as
``params.task.ttl`` (milliseconds, integer). When ``None``, the server
@@ -176,14 +190,24 @@ class MCPTaskOptions:
cancel_remote_task_on_local_cancellation: If True (default), a local
cancellation of the awaiting coroutine triggers a best-effort
``tasks/cancel`` on the server before re-raising ``CancelledError``.
Only gates ``CancelledError``; abandonment paths (max-wait,
unrecoverable poll errors, lost connection after task_id is known)
always cancel regardless of this flag.
max_task_wait: Optional client-side deadline for the whole post-create
lifecycle (poll + result fetch). When exceeded, raises
``ToolExecutionException`` and fires a best-effort ``tasks/cancel``.
``None`` (default) means no client-side bound. Must be positive if set.
"""
default_ttl: timedelta | None = None
cancel_remote_task_on_local_cancellation: bool = True
max_task_wait: timedelta | None = None
def __post_init__(self) -> None:
if self.default_ttl is not None and self.default_ttl.total_seconds() < 0:
raise ValueError("MCPTaskOptions.default_ttl must be non-negative.")
if self.max_task_wait is not None and self.max_task_wait.total_seconds() <= 0:
raise ValueError("MCPTaskOptions.max_task_wait must be positive.")
def streamable_http_client(*args: Any, **kwargs: Any) -> _AsyncGeneratorContextManager[Any, None]:
@@ -1532,10 +1556,10 @@ class MCPTool:
filtered_kwargs, meta = self._prepare_call_kwargs(tool_name, kwargs)
parser = self.parse_tool_results or self._parse_tool_result_from_mcp
# Phase 1: issue augmented tools/call. Do NOT retry on connection loss here:
# Submit the task: issue augmented tools/call. Do NOT retry on connection loss here:
# the server may have accepted the request and created a task before the
# response was lost, so retrying could start the long-running operation twice.
# Reconnect-and-retry is only safe after task_id is known (phase 2).
# Reconnect-and-retry is only safe after the task_id is known.
try:
task_id, fallback_result = await self._call_tool_as_task_create(tool_name, filtered_kwargs, meta)
except (ClosedResourceError, McpError) as ex:
@@ -1565,15 +1589,40 @@ class MCPTool:
assert task_id is not None # noqa: S101 # nosec B101 - protected by the branch above
# Phase 2: poll until terminal status, then fetch payload. Never re-issue tools/call
# past this point; reconnect-and-retry only against the same task_id.
try:
# Track to completion: poll until terminal, then fetch payload. Never re-issue
# tools/call past this point; reconnect-and-retry only against the same task_id.
opts = self._effective_task_options()
max_wait_s = opts.max_task_wait.total_seconds() if opts.max_task_wait is not None else None
async def _await_task_completion() -> str | list[Content]:
terminal = await self._poll_task_until_terminal(task_id)
return await self._handle_terminal_task(tool_name, task_id, terminal, parser)
try:
if max_wait_s is not None:
try:
return await self._await_with_deadline(_await_task_completion(), max_wait_s)
except _MCPDeadlineExpired as ex:
self._spawn_best_effort_cancel(task_id)
raise ToolExecutionException(
f"MCP task '{task_id}' exceeded max_task_wait of {max_wait_s}s.",
inner_exception=ex,
) from ex
else:
return await _await_task_completion()
except asyncio.CancelledError:
if self._effective_task_options().cancel_remote_task_on_local_cancellation:
if opts.cancel_remote_task_on_local_cancellation:
self._spawn_best_effort_cancel(task_id)
raise
except _MCPTaskAbandoned:
# Pre-terminal abandonment (hard poll error, malformed get, second
# disconnect, reconnect failure): cancel + re-raise as plain
# ToolExecutionException to the function-calling loop.
self._spawn_best_effort_cancel(task_id)
raise
# Plain ToolExecutionException from terminal failures (failed/cancelled/
# input_required, completed+isError, malformed result post-completion)
# propagates without cancel — server is already done.
async def _call_tool_as_task_create(
self, tool_name: str, arguments: dict[str, Any], meta: dict[str, Any] | None
@@ -1636,54 +1685,50 @@ class MCPTool:
try:
legacy = types.CallToolResult.model_validate(raw)
except ValidationError:
logger.debug(
"Augmented tools/call for '%s' returned a non-CreateTaskResult/non-CallToolResult payload; "
"falling back to plain tools/call.",
tool_name,
)
fallback = await self.session.call_tool(tool_name, arguments=arguments, meta=meta) # type: ignore[union-attr]
return None, fallback
except ValidationError as ex:
# Augmented call succeeded server-side; re-issuing a plain tools/call
# could double-execute a side-effecting tool.
raise ToolExecutionException(
f"MCP server returned an unparseable response to augmented tools/call "
f"for '{tool_name}'; cannot safely retry (server may have started the operation).",
inner_exception=ex,
) from ex
return None, legacy
async def _poll_task_until_terminal(self, task_id: str) -> types.GetTaskResult:
"""Poll ``tasks/get`` until the task reaches a terminal status."""
from anyio import ClosedResourceError
import httpx
from mcp import types
from mcp.shared.exceptions import McpError
# SDK raises McpError(code=httpx.REQUEST_TIMEOUT=408) on session read timeout.
transient_codes: frozenset[int] = frozenset({int(httpx.codes.REQUEST_TIMEOUT)})
while True:
for attempt in range(2):
try:
request = types.ClientRequest(
types.GetTaskRequest(params=types.GetTaskRequestParams(taskId=task_id))
request = types.ClientRequest(
types.GetTaskRequest(params=types.GetTaskRequestParams(taskId=task_id))
)
try:
# GetTaskResult.ttl is required-but-Optional in the SDK; coerce below.
lenient = await self._send_with_one_reconnect(
request, types.Result, operation="tasks/get", task_id=task_id
)
except McpError as ex:
if ex.error.code in transient_codes:
logger.debug(
"Transient %s on tasks/get for '%s'; will retry.", ex.error.code, task_id
)
# Use lenient Result then coerce: GetTaskResult.ttl is required by
# schema but servers may legitimately omit it.
lenient = await self.session.send_request(request, types.Result) # type: ignore[union-attr]
snapshot = self._coerce_get_task_result(lenient, task_id)
break
except (ClosedResourceError, McpError) as ex:
if not self._is_connection_lost(ex):
error_message = ex.error.message if isinstance(ex, McpError) else str(ex)
raise ToolExecutionException(error_message, inner_exception=ex) from ex
if attempt == 0:
logger.info("MCP connection lost during tasks/get; reconnecting (task_id=%s).", task_id)
try:
await self.connect(reset=True)
continue
except Exception as reconn_ex:
raise ToolExecutionException(
"Failed to reconnect to MCP server.",
inner_exception=reconn_ex,
) from reconn_ex
raise ToolExecutionException(
f"MCP connection lost; task state unknown (task_id={task_id}).",
inner_exception=ex,
) from ex
else: # pragma: no cover - defensive
raise ToolExecutionException(f"Failed to poll task '{task_id}'.")
await asyncio.sleep(_MCP_TASK_MIN_POLL_INTERVAL.total_seconds())
continue
# Hard server error mid-poll: task may still be running.
raise _MCPTaskAbandoned(ex.error.message, inner_exception=ex) from ex
try:
snapshot = self._coerce_get_task_result(lenient, task_id)
except ToolExecutionException as ex:
# Malformed tasks/get response; task may still be running.
raise _MCPTaskAbandoned(str(ex), inner_exception=ex) from ex
if snapshot.status in _MCP_TASK_TERMINAL_STATUSES:
return snapshot
@@ -1750,40 +1795,22 @@ class MCPTool:
async def _fetch_task_result(self, task_id: str) -> types.CallToolResult:
"""Send ``tasks/result`` and reinterpret the open-typed payload as a CallToolResult."""
from anyio import ClosedResourceError
from mcp import types
from mcp.shared.exceptions import McpError
from pydantic import ValidationError
for attempt in range(2):
try:
request = types.ClientRequest(
types.GetTaskPayloadRequest(params=types.GetTaskPayloadRequestParams(taskId=task_id))
)
payload = await self.session.send_request( # type: ignore[union-attr]
request, types.GetTaskPayloadResult
)
break
except (ClosedResourceError, McpError) as ex:
if not self._is_connection_lost(ex):
error_message = ex.error.message if isinstance(ex, McpError) else str(ex)
raise ToolExecutionException(error_message, inner_exception=ex) from ex
if attempt == 0:
logger.info("MCP connection lost during tasks/result; reconnecting (task_id=%s).", task_id)
try:
await self.connect(reset=True)
continue
except Exception as reconn_ex:
raise ToolExecutionException(
"Failed to reconnect to MCP server.",
inner_exception=reconn_ex,
) from reconn_ex
raise ToolExecutionException(
f"MCP connection lost; task state unknown (task_id={task_id}).",
inner_exception=ex,
) from ex
else: # pragma: no cover - defensive
raise ToolExecutionException(f"Failed to fetch result for task '{task_id}'.")
request = types.ClientRequest(
types.GetTaskPayloadRequest(params=types.GetTaskPayloadRequestParams(taskId=task_id))
)
# Connection-loss retry only via the helper; no transient-code retry — server
# has already completed the task, so a slow payload fetch is anomalous.
try:
payload = await self._send_with_one_reconnect(
request, types.GetTaskPayloadResult, operation="tasks/result", task_id=task_id
)
except McpError as ex:
# Server reported completed; a hard fetch error is a plain failure (no cancel).
raise ToolExecutionException(ex.error.message, inner_exception=ex) from ex
# GetTaskPayloadResult carries the tool result via extra fields; reinterpret as CallToolResult.
payload_dict = payload.model_dump(by_alias=True, exclude_none=True)
@@ -1791,10 +1818,78 @@ class MCPTool:
try:
return types.CallToolResult.model_validate(payload_dict)
except ValidationError as ex:
# Server reported completed; malformed payload is a plain failure (no cancel needed).
raise ToolExecutionException(
f"MCP task '{task_id}' result payload could not be parsed as a CallToolResult."
f"MCP task '{task_id}' result payload could not be parsed as a CallToolResult.",
inner_exception=ex,
) from ex
async def _send_with_one_reconnect(
    self,
    request: types.ClientRequest,
    result_type: type[Any],
    *,
    operation: str,
    task_id: str,
) -> Any:
    """Issue ``request`` on the session, tolerating a single connection drop.

    The first connection loss triggers ``connect(reset=True)`` followed by one
    retry of the same request. A failed reconnect, or a second connection
    loss, raises ``_MCPTaskAbandoned``. Errors that ``_is_connection_lost``
    does not classify as a lost connection are re-raised untouched.

    Args:
        request: The client request to send.
        result_type: Result model handed to ``session.send_request``.
        operation: Human-readable operation name used in the reconnect log line.
        task_id: Id of the task the request refers to (for log/error text).

    Returns:
        Whatever ``session.send_request`` returns.
    """
    from anyio import ClosedResourceError
    from mcp.shared.exceptions import McpError

    reconnected = False
    while True:
        try:
            return await self.session.send_request(request, result_type)  # type: ignore[union-attr]
        except (ClosedResourceError, McpError) as send_error:
            if not self._is_connection_lost(send_error):
                raise
            if reconnected:
                # Second connection loss: task may still be running.
                raise _MCPTaskAbandoned(
                    f"MCP connection lost; task state unknown (task_id={task_id}).",
                    inner_exception=send_error,
                ) from send_error
            logger.info("MCP connection lost during %s; reconnecting (task_id=%s).", operation, task_id)
            try:
                await self.connect(reset=True)
            except Exception as reconnect_error:
                # Reconnect failure: task may still be running.
                raise _MCPTaskAbandoned(
                    "Failed to reconnect to MCP server.", inner_exception=reconnect_error
                ) from reconnect_error
            reconnected = True
@staticmethod
async def _await_with_deadline(coro: Coroutine[Any, Any, Any], timeout_s: float) -> Any:
    """Await ``coro`` with a deadline; raise ``_MCPDeadlineExpired`` only on deadline.

    Unlike ``asyncio.wait_for``, an ``asyncio.TimeoutError`` raised by ``coro``
    itself propagates unchanged so callers can distinguish their own deadline
    from a stray inner timeout.

    Args:
        coro: Coroutine to run; it is scheduled via ``asyncio.ensure_future``.
        timeout_s: Deadline in seconds.

    Returns:
        The value produced by ``coro`` when it finishes before the deadline.

    Raises:
        _MCPDeadlineExpired: ``timeout_s`` elapsed before ``coro`` finished; the
            inner task has been cancelled and drained.
        BaseException: Anything raised by ``coro``, or delivered to the caller
            while waiting (e.g. ``CancelledError``), propagates unchanged.
    """
    inner = asyncio.ensure_future(coro)

    async def _cancel_and_drain() -> None:
        # Drain with asyncio.wait() rather than ``await inner`` under a blanket
        # suppress: wait() never re-raises inner's outcome, so the only thing
        # that can escape here is a cancellation aimed at *this* caller. That
        # cancellation must not be swallowed and turned into a deadline error.
        inner.cancel()
        await asyncio.wait({inner})
        if not inner.cancelled():
            # Inner ignored the cancel and finished anyway; mark any exception
            # as retrieved so asyncio does not log "never retrieved".
            inner.exception()

    try:
        done, _pending = await asyncio.wait({inner}, timeout=timeout_s)
    except BaseException:
        # Outer caller cancelled (or another exception): cancel inner + drain.
        await _cancel_and_drain()
        raise
    if inner in done:
        return inner.result()
    # Deadline fired before inner finished.
    await _cancel_and_drain()
    raise _MCPDeadlineExpired
def _spawn_best_effort_cancel(self, task_id: str) -> None:
"""Fire-and-forget ``tasks/cancel`` so local cancellation propagates server-side."""
try:
@@ -1807,18 +1902,35 @@ class MCPTool:
cancel_task.add_done_callback(self._pending_reload_tasks.discard)
async def _try_cancel_task(self, task_id: str) -> None:
"""Send ``tasks/cancel`` swallowing every failure; bounded by ``_MCP_TASK_CANCEL_TIMEOUT``."""
"""Send ``tasks/cancel``; bounded by ``_MCP_TASK_CANCEL_TIMEOUT``.
Failures log at warning so unattributed orphan tasks are debuggable.
"""
from mcp import types
async def _send() -> None:
request = types.ClientRequest(types.CancelTaskRequest(params=types.CancelTaskRequestParams(taskId=task_id)))
try:
await self.session.send_request(request, types.CancelTaskResult) # type: ignore[union-attr]
except Exception:
logger.debug("Best-effort tasks/cancel for '%s' failed.", task_id, exc_info=True)
with contextlib.suppress(Exception, asyncio.CancelledError, asyncio.TimeoutError):
await asyncio.wait_for(_send(), timeout=_MCP_TASK_CANCEL_TIMEOUT.total_seconds())
request = types.ClientRequest(
types.CancelTaskRequest(params=types.CancelTaskRequestParams(taskId=task_id))
)
try:
await asyncio.wait_for(
self.session.send_request(request, types.CancelTaskResult), # type: ignore[union-attr]
timeout=_MCP_TASK_CANCEL_TIMEOUT.total_seconds(),
)
except asyncio.CancelledError:
raise
except asyncio.TimeoutError:
logger.warning(
"Best-effort tasks/cancel for '%s' timed out after %.1fs; "
"remote task may still be running.",
task_id,
_MCP_TASK_CANCEL_TIMEOUT.total_seconds(),
)
except Exception:
logger.warning(
"Best-effort tasks/cancel for '%s' failed; remote task may still be running.",
task_id,
exc_info=True,
)
@staticmethod
def _is_connection_lost(ex: BaseException) -> bool:
+501
View File
@@ -7,6 +7,7 @@ import logging
import os
import sys
from contextlib import _AsyncGeneratorContextManager # type: ignore
from datetime import timedelta
from typing import Any
from unittest.mock import AsyncMock, Mock, patch
@@ -5553,4 +5554,504 @@ async def test_call_tool_as_task_create_disconnect_does_not_retry() -> None:
reconnect_mock.assert_not_awaited()
async def test_fetch_task_result_reconnects_during_fetch() -> None:
    """A single connection drop during tasks/result reconnects once and retries the fetch."""
    from anyio import ClosedResourceError

    tool = _make_task_tool()
    seen_methods: list[str] = []
    reconnect_resets: list[bool] = []

    async def scripted_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        method = request.root.method
        seen_methods.append(method)
        if method == "tools/call":
            return _make_create_task_result(task_id="r1")
        if method == "tasks/get":
            return _make_task_snapshot(task_id="r1", status="completed")
        if method == "tasks/result":
            # Only the very first fetch attempt hits a closed connection.
            if seen_methods.count("tasks/result") == 1:
                raise ClosedResourceError
            return _make_payload("fetched after reconnect")
        raise AssertionError(method)

    async def scripted_connect(reset: bool = False) -> None:
        reconnect_resets.append(reset)
        assert reset is True

    tool.session.send_request = AsyncMock(side_effect=scripted_send)  # type: ignore[union-attr]

    with patch.object(MCPTool, "connect", side_effect=scripted_connect):
        result = await tool.call_tool("slow_op")

    assert _mcp_result_to_text(result) == "fetched after reconnect"
    assert len(reconnect_resets) == 1
    assert seen_methods.count("tasks/result") == 2
async def test_fetch_task_result_second_disconnect_raises_task_state_unknown_and_cancels() -> None:
    """Two connection drops during tasks/result abandon the task and fire tasks/cancel."""
    from anyio import ClosedResourceError

    tool = _make_task_tool()
    seen_methods: list[str] = []

    async def scripted_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        method = request.root.method
        seen_methods.append(method)
        if method == "tools/call":
            return _make_create_task_result(task_id="r2")
        if method == "tasks/get":
            return _make_task_snapshot(task_id="r2", status="completed")
        if method == "tasks/result":
            raise ClosedResourceError
        if method == "tasks/cancel":
            return types.CancelTaskResult()
        raise AssertionError(method)

    tool.session.send_request = AsyncMock(side_effect=scripted_send)  # type: ignore[union-attr]

    with patch.object(MCPTool, "connect", new=AsyncMock(return_value=None)):
        with pytest.raises(ToolExecutionException, match="task state unknown"):
            await tool.call_tool("slow_op")

    # Drain the fire-and-forget cancel so the assertion is deterministic.
    in_flight = [*tool._pending_reload_tasks]
    if in_flight:
        await asyncio.gather(*in_flight, return_exceptions=True)
    assert "tasks/cancel" in seen_methods
async def test_call_tool_as_task_create_unparseable_success_raises() -> None:
    """A success response that parses as neither result type must raise, not re-issue tools/call."""
    tool = _make_task_tool()
    # Payload with neither task.taskId nor a valid CallToolResult shape.
    junk_response = types.Result.model_validate({"foo": "bar"})
    tool.session.send_request = AsyncMock(return_value=junk_response)  # type: ignore[union-attr]
    tool.session.call_tool = AsyncMock(return_value=types.CallToolResult(content=[]))  # type: ignore[union-attr]

    with pytest.raises(ToolExecutionException, match="unparseable response"):
        await tool.call_tool("slow_op")

    # The plain tools/call path must stay untouched (it could double-execute the tool).
    tool.session.call_tool.assert_not_called()  # type: ignore[union-attr]
async def test_call_tool_as_task_max_wait_exceeded_raises_and_cancels(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A task that never leaves 'working' trips max_task_wait and triggers tasks/cancel."""
    from agent_framework import MCPTaskOptions
    from agent_framework import _mcp as _mcp_module

    monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", _mcp_module.timedelta(milliseconds=1))
    options = MCPTaskOptions(max_task_wait=timedelta(milliseconds=50))
    tool = _make_task_tool(task_options=options)
    seen_methods: list[str] = []

    async def scripted_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        method = request.root.method
        seen_methods.append(method)
        if method == "tools/call":
            return _make_create_task_result(task_id="mw")
        if method == "tasks/get":
            return _make_task_snapshot(task_id="mw", status="working")
        if method == "tasks/cancel":
            return types.CancelTaskResult()
        raise AssertionError(method)

    tool.session.send_request = AsyncMock(side_effect=scripted_send)  # type: ignore[union-attr]

    with pytest.raises(ToolExecutionException, match="exceeded max_task_wait"):
        await tool.call_tool("slow_op")

    # Let the background cancel finish before checking it happened.
    in_flight = [*tool._pending_reload_tasks]
    if in_flight:
        await asyncio.gather(*in_flight, return_exceptions=True)
    assert "tasks/cancel" in seen_methods
async def test_call_tool_as_task_max_wait_cancels_even_when_local_cancel_option_disabled(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Contract lock: hitting max_task_wait cancels remotely even with the local-cancel option off."""
    from agent_framework import MCPTaskOptions
    from agent_framework import _mcp as _mcp_module

    monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", _mcp_module.timedelta(milliseconds=1))
    options = MCPTaskOptions(
        cancel_remote_task_on_local_cancellation=False,
        max_task_wait=timedelta(milliseconds=50),
    )
    tool = _make_task_tool(task_options=options)
    seen_methods: list[str] = []

    async def scripted_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        method = request.root.method
        seen_methods.append(method)
        if method == "tools/call":
            return _make_create_task_result(task_id="mw2")
        if method == "tasks/get":
            return _make_task_snapshot(task_id="mw2", status="working")
        if method == "tasks/cancel":
            return types.CancelTaskResult()
        raise AssertionError(method)

    tool.session.send_request = AsyncMock(side_effect=scripted_send)  # type: ignore[union-attr]

    with pytest.raises(ToolExecutionException, match="exceeded max_task_wait"):
        await tool.call_tool("slow_op")

    # Let the background cancel finish before checking it happened.
    in_flight = [*tool._pending_reload_tasks]
    if in_flight:
        await asyncio.gather(*in_flight, return_exceptions=True)
    assert "tasks/cancel" in seen_methods
async def test_call_tool_as_task_poll_transient_request_timeout_keeps_polling(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A REQUEST_TIMEOUT McpError on tasks/get is retried and never triggers tasks/cancel."""
    import httpx

    from agent_framework import _mcp as _mcp_module

    monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", _mcp_module.timedelta(milliseconds=1))
    tool = _make_task_tool()
    seen_methods: list[str] = []

    async def scripted_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        method = request.root.method
        seen_methods.append(method)
        if method == "tools/call":
            return _make_create_task_result(task_id="t1")
        if method == "tasks/get":
            # First poll times out; every later poll reports completion.
            if seen_methods.count("tasks/get") == 1:
                raise McpError(types.ErrorData(code=int(httpx.codes.REQUEST_TIMEOUT), message="slow poll"))
            return _make_task_snapshot(task_id="t1", status="completed")
        if method == "tasks/result":
            return _make_payload("recovered after transient")
        if method == "tasks/cancel":
            return types.CancelTaskResult()
        raise AssertionError(method)

    tool.session.send_request = AsyncMock(side_effect=scripted_send)  # type: ignore[union-attr]

    result = await tool.call_tool("slow_op")

    assert _mcp_result_to_text(result) == "recovered after transient"
    assert seen_methods.count("tasks/get") == 2
    # Transient retry must not fire cancel.
    in_flight = [*tool._pending_reload_tasks]
    if in_flight:
        await asyncio.gather(*in_flight, return_exceptions=True)
    assert "tasks/cancel" not in seen_methods
async def test_call_tool_as_task_poll_hard_mcperror_cancels_and_raises(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A non-transient McpError on tasks/get surfaces as a tool failure and fires tasks/cancel."""
    from agent_framework import _mcp as _mcp_module

    monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", _mcp_module.timedelta(milliseconds=1))
    tool = _make_task_tool()
    seen_methods: list[str] = []

    async def scripted_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        method = request.root.method
        seen_methods.append(method)
        if method == "tools/call":
            return _make_create_task_result(task_id="h1")
        if method == "tasks/get":
            raise McpError(types.ErrorData(code=types.INVALID_PARAMS, message="bad task id"))
        if method == "tasks/cancel":
            return types.CancelTaskResult()
        raise AssertionError(method)

    tool.session.send_request = AsyncMock(side_effect=scripted_send)  # type: ignore[union-attr]

    with pytest.raises(ToolExecutionException, match="bad task id"):
        await tool.call_tool("slow_op")

    # Let the background cancel finish before checking it happened.
    in_flight = [*tool._pending_reload_tasks]
    if in_flight:
        await asyncio.gather(*in_flight, return_exceptions=True)
    assert "tasks/cancel" in seen_methods
async def test_call_tool_as_task_malformed_tasks_get_response_cancels_and_raises(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An unparseable tasks/get reply is treated as abandonment: raise and fire tasks/cancel."""
    from agent_framework import _mcp as _mcp_module

    monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", _mcp_module.timedelta(milliseconds=1))
    tool = _make_task_tool()
    # A Result lacking the GetTaskResult fields (no taskId/status/etc.).
    junk_snapshot = types.Result.model_validate({"some": "junk"})
    seen_methods: list[str] = []

    async def scripted_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        method = request.root.method
        seen_methods.append(method)
        if method == "tools/call":
            return _make_create_task_result(task_id="m1")
        if method == "tasks/get":
            return junk_snapshot
        if method == "tasks/cancel":
            return types.CancelTaskResult()
        raise AssertionError(method)

    tool.session.send_request = AsyncMock(side_effect=scripted_send)  # type: ignore[union-attr]

    with pytest.raises(ToolExecutionException, match="malformed tasks/get"):
        await tool.call_tool("slow_op")

    # Let the background cancel finish before checking it happened.
    in_flight = [*tool._pending_reload_tasks]
    if in_flight:
        await asyncio.gather(*in_flight, return_exceptions=True)
    assert "tasks/cancel" in seen_methods
async def test_call_tool_as_task_failed_terminal_does_not_cancel(monkeypatch: pytest.MonkeyPatch) -> None:
    """A server-reported 'failed' status raises without sending tasks/cancel."""
    from agent_framework import _mcp as _mcp_module

    monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", _mcp_module.timedelta(milliseconds=1))
    tool = _make_task_tool()
    seen_methods: list[str] = []

    async def scripted_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        method = request.root.method
        seen_methods.append(method)
        if method == "tools/call":
            return _make_create_task_result(task_id="f1")
        if method == "tasks/get":
            return _make_task_snapshot(task_id="f1", status="failed", status_message="boom")
        if method == "tasks/cancel":
            return types.CancelTaskResult()
        raise AssertionError(method)

    tool.session.send_request = AsyncMock(side_effect=scripted_send)  # type: ignore[union-attr]

    with pytest.raises(ToolExecutionException, match="task failed: boom"):
        await tool.call_tool("slow_op")

    # Give any (incorrect) background cancel a chance to run before asserting.
    await asyncio.sleep(0.02)
    assert "tasks/cancel" not in seen_methods
async def test_try_cancel_task_logs_warning_on_timeout(
    caplog: pytest.LogCaptureFixture, monkeypatch: pytest.MonkeyPatch
) -> None:
    """A tasks/cancel that outlives the cancel timeout is reported at WARNING with the task id."""
    from agent_framework import _mcp as _mcp_module

    # Keep the test quick by shrinking the cancel timeout.
    monkeypatch.setattr(_mcp_module, "_MCP_TASK_CANCEL_TIMEOUT", _mcp_module.timedelta(milliseconds=20))
    tool = _make_task_tool()

    async def never_answers(*_a: Any, **_kw: Any) -> Any:
        await asyncio.sleep(10.0)

    tool.session.send_request = AsyncMock(side_effect=never_answers)  # type: ignore[union-attr]

    with caplog.at_level(logging.WARNING, logger=_mcp_module.logger.name):
        await tool._try_cancel_task("hang-1")

    logged = [record.getMessage() for record in caplog.records]
    assert any("timed out" in message and "hang-1" in message for message in logged)
async def test_mcp_task_options_is_frozen() -> None:
    """Assigning to a field of an MCPTaskOptions instance raises FrozenInstanceError."""
    from dataclasses import FrozenInstanceError

    from agent_framework import MCPTaskOptions

    frozen_options = MCPTaskOptions()
    replacement_ttl = timedelta(seconds=5)
    with pytest.raises(FrozenInstanceError):
        frozen_options.default_ttl = replacement_ttl  # type: ignore[misc]
async def test_mcp_task_options_max_task_wait_rejects_non_positive() -> None:
    """Zero and negative max_task_wait values are rejected at construction time."""
    from agent_framework import MCPTaskOptions

    # Same order as before: zero first, then a negative duration.
    for rejected_wait in (timedelta(0), timedelta(seconds=-1)):
        with pytest.raises(ValueError, match="positive"):
            MCPTaskOptions(max_task_wait=rejected_wait)
async def test_fetch_task_result_hard_mcperror_raises_without_cancel() -> None:
    """A hard McpError from tasks/result becomes ToolExecutionException and sends no tasks/cancel."""
    tool = _make_task_tool()
    seen_methods: list[str] = []

    async def scripted_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        method = request.root.method
        seen_methods.append(method)
        if method == "tools/call":
            return _make_create_task_result(task_id="hf")
        if method == "tasks/get":
            return _make_task_snapshot(task_id="hf", status="completed")
        if method == "tasks/result":
            raise McpError(types.ErrorData(code=types.INTERNAL_ERROR, message="payload vanished"))
        if method == "tasks/cancel":
            return types.CancelTaskResult()
        raise AssertionError(method)

    tool.session.send_request = AsyncMock(side_effect=scripted_send)  # type: ignore[union-attr]

    with pytest.raises(ToolExecutionException, match="payload vanished"):
        await tool.call_tool("slow_op")

    # The raw McpError must not leak, and no cancel is sent: the server already reported completion.
    await asyncio.sleep(0.02)
    assert "tasks/cancel" not in seen_methods
async def test_completion_wait_timeout_without_max_wait_is_not_translated(monkeypatch: pytest.MonkeyPatch) -> None:
    """A stray ``asyncio.TimeoutError`` must surface unchanged when ``max_task_wait`` is not set.

    The tool's result parser is swapped for one that raises ``asyncio.TimeoutError``. The
    call must re-raise that error as-is instead of reporting a deadline expiry, and it must
    not send a spurious ``tasks/cancel``.
    """
    from agent_framework import _mcp as _mcp_module

    # Shrink the minimum poll interval to 1 ms so polling does not slow the test down.
    fast_poll = _mcp_module.timedelta(milliseconds=1)
    monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", fast_poll)
    tool = _make_task_tool()

    def boom_parser(_: Any) -> list[Content]:
        raise asyncio.TimeoutError

    tool.parse_tool_results = boom_parser
    cancel_requests: list[str] = []

    # Factories (not prebuilt objects) so every request gets a freshly built reply.
    canned_replies = {
        "tools/call": lambda: _make_create_task_result(task_id="t2"),
        "tasks/get": lambda: _make_task_snapshot(task_id="t2", status="completed"),
        "tasks/result": lambda: _make_payload("ok"),
    }

    async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        method = request.root.method
        if method == "tasks/cancel":
            cancel_requests.append(method)
            return types.CancelTaskResult()
        if method not in canned_replies:
            raise AssertionError(method)
        return canned_replies[method]()

    tool.session.send_request = AsyncMock(side_effect=fake_send)  # type: ignore[union-attr]

    with pytest.raises(asyncio.TimeoutError):
        await tool.call_tool("slow_op")

    # The error must not be treated as a max_task_wait expiry, so no cancel may follow.
    await asyncio.sleep(0.02)
    assert not cancel_requests
async def test_completion_wait_inner_timeout_with_max_wait_set_propagates(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An inner ``asyncio.TimeoutError`` must propagate unchanged even with ``max_task_wait`` set.

    The result parser raises ``asyncio.TimeoutError("inner parser timeout")`` while the
    configured ``max_task_wait`` (5 s) is nowhere near expiring. The original error must
    reach the caller with its message intact, and no ``tasks/cancel`` may be sent.
    """
    from agent_framework import MCPTaskOptions
    from agent_framework import _mcp as _mcp_module

    # Shrink the minimum poll interval to 1 ms so polling does not slow the test down.
    fast_poll = _mcp_module.timedelta(milliseconds=1)
    monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", fast_poll)

    # The deadline sits comfortably above the real run time, so it cannot be what fires.
    generous_options = MCPTaskOptions(max_task_wait=timedelta(seconds=5))
    tool = _make_task_tool(task_options=generous_options)

    def boom_parser(_: Any) -> list[Content]:
        raise asyncio.TimeoutError("inner parser timeout")

    tool.parse_tool_results = boom_parser
    cancel_requests: list[str] = []

    # Factories (not prebuilt objects) so every request gets a freshly built reply.
    canned_replies = {
        "tools/call": lambda: _make_create_task_result(task_id="t3"),
        "tasks/get": lambda: _make_task_snapshot(task_id="t3", status="completed"),
        "tasks/result": lambda: _make_payload("ok"),
    }

    async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        method = request.root.method
        if method == "tasks/cancel":
            cancel_requests.append(method)
            return types.CancelTaskResult()
        if method not in canned_replies:
            raise AssertionError(method)
        return canned_replies[method]()

    tool.session.send_request = AsyncMock(side_effect=fake_send)  # type: ignore[union-attr]

    with pytest.raises(asyncio.TimeoutError, match="inner parser timeout"):
        await tool.call_tool("slow_op")

    # The inner error must not become "exceeded max_task_wait", and no cancel may follow.
    await asyncio.sleep(0.02)
    assert not cancel_requests
async def test_max_wait_interrupts_long_poll_sleep(monkeypatch: pytest.MonkeyPatch) -> None:
    """Deadline must cancel through a long ``asyncio.sleep`` (clamped to MAX), not wait it out.

    The fake server keeps reporting ``working`` with a 5 s suggested poll interval while
    ``max_task_wait`` is only 100 ms. The call must fail with "exceeded max_task_wait" in
    well under a second.

    ``monkeypatch`` is not used by this test; it is kept so the signature stays unchanged.
    """
    from agent_framework import MCPTaskOptions

    tool = _make_task_tool(task_options=MCPTaskOptions(max_task_wait=timedelta(milliseconds=100)))

    async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any:
        method = request.root.method
        if method == "tools/call":
            return _make_create_task_result(task_id="ds")
        if method == "tasks/get":
            # Suggest a 5s poll interval (gets clamped to MAX=5s); wait_for must cut through it.
            return _make_task_snapshot(task_id="ds", status="working", poll_interval_ms=5000)
        if method == "tasks/cancel":
            return types.CancelTaskResult()
        raise AssertionError(method)

    tool.session.send_request = AsyncMock(side_effect=fake_send)  # type: ignore[union-attr]
    loop = asyncio.get_running_loop()
    started = loop.time()
    try:
        with pytest.raises(ToolExecutionException, match="exceeded max_task_wait"):
            await tool.call_tool("slow_op")
        elapsed = loop.time() - started
        # Should fire near the 100ms deadline, well below the 5s clamped sleep.
        assert elapsed < 1.0, f"deadline did not interrupt long sleep (elapsed={elapsed:.3f}s)"
    finally:
        # Drain background tasks even when an assertion above fails, so a failing run
        # does not leave pending tasks on the event loop for later tests.
        # NOTE(review): presumably this is where the best-effort cancel task is tracked — confirm.
        pending = list(tool._pending_reload_tasks)
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)
# endregion