diff --git a/python/packages/core/AGENTS.md b/python/packages/core/AGENTS.md index fbbc171ae1..aadda788b0 100644 --- a/python/packages/core/AGENTS.md +++ b/python/packages/core/AGENTS.md @@ -80,11 +80,14 @@ agent_framework/ - **`MCPTool`** - Base wrapper that owns the MCP `ClientSession` and exposes the remote server's tools as `FunctionTool`s. - **`MCPStdioTool`** / **`MCPStreamableHTTPTool`** / **`MCPWebsocketTool`** - Transport-specific subclasses. -- **`MCPTaskOptions`** (experimental, `MCP_LONG_RUNNING_TASKS` feature) - Per-tool-instance options controlling the SEP-2663 long-running task lifecycle. When the server advertises a tool with `execution.taskSupport == "required"`, `MCPTool.call_tool` transparently routes through `call_tool_as_task`, which sends an augmented `tools/call`, polls `tasks/get` until terminal, and reinterprets `tasks/result` as a normal `CallToolResult`. Fields: +- **`MCPTaskOptions`** (experimental, `MCP_LONG_RUNNING_TASKS` feature, **frozen**) - Per-tool-instance options controlling the SEP-2663 long-running task lifecycle. When the server advertises a tool with `execution.taskSupport == "required"`, `MCPTool.call_tool` transparently routes through `call_tool_as_task`, which sends an augmented `tools/call`, polls `tasks/get` until terminal, and reinterprets `tasks/result` as a normal `CallToolResult`. Instances are immutable; replace via `MCPTool.task_options = MCPTaskOptions(...)`. Fields: - `default_ttl: timedelta | None` — forwarded to the server as `params.task.ttl` (milliseconds). When `None`, the server's default applies. - - `cancel_remote_task_on_local_cancellation: bool = True` — on local `CancelledError`, spawn a best-effort `tasks/cancel` before re-raising. -- **Permissive fallback**: servers that ignore the augmentation (return `CallToolResult` directly) or reject the unknown `task` field with `METHOD_NOT_FOUND` / `INVALID_PARAMS` fall back to the plain `session.call_tool(...)` path so legacy servers keep working. 
-- **Phase-aware reconnect**: a dropped connection before a `task_id` is known raises `ToolExecutionException("connection lost; task state unknown")` without re-issuing the augmented `tools/call`, so a server that accepted the request but lost the response cannot be made to start the same operation twice; once a `task_id` exists, `tasks/get` / `tasks/result` reconnect once and retry against the same id, which is safe because the request references the existing operation. + - `cancel_remote_task_on_local_cancellation: bool = True` — only gates the `CancelledError` path. Abandonment paths (see below) always cancel. + - `max_task_wait: timedelta | None` — client-side deadline for the whole post-create lifecycle (poll + result fetch). When exceeded, raises `ToolExecutionException` and fires a best-effort `tasks/cancel`. `None` (default) means no client-side bound. Bounds sleeps, sends, AND reconnects via the private `_await_with_deadline` helper (built on `asyncio.wait`, not `asyncio.wait_for`, so a `TimeoutError` raised inside the lifecycle is not mistaken for deadline expiry). +- **Permissive fallback**: servers that ignore the augmentation (return `CallToolResult` directly) or reject the unknown `task` field with `METHOD_NOT_FOUND` / `INVALID_PARAMS` fall back to the plain `session.call_tool(...)` path so legacy servers keep working. An unparseable success response (server accepted the augmented call but returned a payload that is neither `CreateTaskResult` nor `CallToolResult`) **does not** fall back — it raises `ToolExecutionException` to avoid double-executing a side-effecting tool. +- **Submit-vs-track reconnect policy**: a dropped connection before a `task_id` is known raises `ToolExecutionException("connection lost; task state unknown")` without re-issuing the augmented `tools/call`, so a server that accepted the request but lost the response cannot be made to start the same operation twice; once a `task_id` exists, `tasks/get` / `tasks/result` reconnect once and retry against the same id (a shared `_send_with_one_reconnect` helper). 
+- **Cancel-on-abandonment vs terminal failure**: any path where the remote task may still be running (max-wait exceeded, hard `McpError` in poll, malformed `tasks/get`, second connection loss in poll/fetch, reconnect failure) fires best-effort `tasks/cancel` before raising. Terminal failures (`failed`/`cancelled`/`input_required` server-side, `completed+isError`, malformed `tasks/result` after server completed) do **not** cancel — the server is already done. `_MCPTaskAbandoned` is the private marker distinguishing the two. +- **Transient poll retry**: a slow `tasks/get` that surfaces as `McpError(code=408 REQUEST_TIMEOUT)` is retried (bounded by `max_task_wait`). All other non-connection `McpError`s during poll are treated as abandonment. `tasks/result` does not get transient retry — the server has already completed, so a slow payload fetch is anomalous. ### File Access Harness (`_harness/_file_access.py`) diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index c9fbd564e4..dd1d1c8db8 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -160,8 +160,19 @@ _MCP_TASK_CANCEL_TIMEOUT = timedelta(seconds=5) _MCP_TASK_TERMINAL_STATUSES: frozenset[str] = frozenset({"completed", "failed", "cancelled", "input_required"}) +class _MCPTaskAbandoned(ToolExecutionException): + """Raised when the remote MCP task may still be running and must be cancelled. + + Subclass of ToolExecutionException so callers see a normal tool failure. + """ + + +class _MCPDeadlineExpired(Exception): + """Internal marker for ``max_task_wait`` expiry; distinct from inner TimeoutError.""" + + @experimental(feature_id=ExperimentalFeature.MCP_LONG_RUNNING_TASKS) -@dataclass +@dataclass(frozen=True) class MCPTaskOptions: """Options controlling how MCPTool drives the MCP long-running task lifecycle. 
@@ -169,6 +180,9 @@ class MCPTaskOptions: the framework transparently drives the SEP-2663 ``tools/call`` → ``tasks/get`` (polled) → ``tasks/result`` lifecycle so the agent sees a normal tool result. + Instances are immutable; replace the whole object via + ``MCPTool.task_options = MCPTaskOptions(...)`` to change behavior. + Attributes: default_ttl: Optional default time-to-live forwarded to the server as ``params.task.ttl`` (milliseconds, integer). When ``None``, the server @@ -176,14 +190,24 @@ class MCPTaskOptions: cancel_remote_task_on_local_cancellation: If True (default), a local cancellation of the awaiting coroutine triggers a best-effort ``tasks/cancel`` on the server before re-raising ``CancelledError``. + Only gates ``CancelledError``; abandonment paths (max-wait, + unrecoverable poll errors, lost connection after task_id is known) + always cancel regardless of this flag. + max_task_wait: Optional client-side deadline for the whole post-create + lifecycle (poll + result fetch). When exceeded, raises + ``ToolExecutionException`` and fires a best-effort ``tasks/cancel``. + ``None`` (default) means no client-side bound. Must be positive if set. """ default_ttl: timedelta | None = None cancel_remote_task_on_local_cancellation: bool = True + max_task_wait: timedelta | None = None def __post_init__(self) -> None: if self.default_ttl is not None and self.default_ttl.total_seconds() < 0: raise ValueError("MCPTaskOptions.default_ttl must be non-negative.") + if self.max_task_wait is not None and self.max_task_wait.total_seconds() <= 0: + raise ValueError("MCPTaskOptions.max_task_wait must be positive.") def streamable_http_client(*args: Any, **kwargs: Any) -> _AsyncGeneratorContextManager[Any, None]: @@ -1532,10 +1556,10 @@ class MCPTool: filtered_kwargs, meta = self._prepare_call_kwargs(tool_name, kwargs) parser = self.parse_tool_results or self._parse_tool_result_from_mcp - # Phase 1: issue augmented tools/call. 
Do NOT retry on connection loss here: + # Submit the task: issue augmented tools/call. Do NOT retry on connection loss here: # the server may have accepted the request and created a task before the # response was lost, so retrying could start the long-running operation twice. - # Reconnect-and-retry is only safe after task_id is known (phase 2). + # Reconnect-and-retry is only safe after the task_id is known. try: task_id, fallback_result = await self._call_tool_as_task_create(tool_name, filtered_kwargs, meta) except (ClosedResourceError, McpError) as ex: @@ -1565,15 +1589,40 @@ class MCPTool: assert task_id is not None # noqa: S101 # nosec B101 - protected by the branch above - # Phase 2: poll until terminal status, then fetch payload. Never re-issue tools/call - # past this point; reconnect-and-retry only against the same task_id. - try: + # Track to completion: poll until terminal, then fetch payload. Never re-issue + # tools/call past this point; reconnect-and-retry only against the same task_id. 
+ opts = self._effective_task_options() + max_wait_s = opts.max_task_wait.total_seconds() if opts.max_task_wait is not None else None + + async def _await_task_completion() -> str | list[Content]: terminal = await self._poll_task_until_terminal(task_id) return await self._handle_terminal_task(tool_name, task_id, terminal, parser) + + try: + if max_wait_s is not None: + try: + return await self._await_with_deadline(_await_task_completion(), max_wait_s) + except _MCPDeadlineExpired as ex: + self._spawn_best_effort_cancel(task_id) + raise ToolExecutionException( + f"MCP task '{task_id}' exceeded max_task_wait of {max_wait_s}s.", + inner_exception=ex, + ) from ex + else: + return await _await_task_completion() except asyncio.CancelledError: - if self._effective_task_options().cancel_remote_task_on_local_cancellation: + if opts.cancel_remote_task_on_local_cancellation: self._spawn_best_effort_cancel(task_id) raise + except _MCPTaskAbandoned: + # Pre-terminal abandonment (hard poll error, malformed get, second + # disconnect, reconnect failure): cancel + re-raise as plain + # ToolExecutionException to the function-calling loop. + self._spawn_best_effort_cancel(task_id) + raise + # Plain ToolExecutionException from terminal failures (failed/cancelled/ + # input_required, completed+isError, malformed result post-completion) + # propagates without cancel — server is already done. 
async def _call_tool_as_task_create( self, tool_name: str, arguments: dict[str, Any], meta: dict[str, Any] | None @@ -1636,54 +1685,50 @@ class MCPTool: try: legacy = types.CallToolResult.model_validate(raw) - except ValidationError: - logger.debug( - "Augmented tools/call for '%s' returned a non-CreateTaskResult/non-CallToolResult payload; " - "falling back to plain tools/call.", - tool_name, - ) - fallback = await self.session.call_tool(tool_name, arguments=arguments, meta=meta) # type: ignore[union-attr] - return None, fallback + except ValidationError as ex: + # Augmented call succeeded server-side; re-issuing a plain tools/call + # could double-execute a side-effecting tool. + raise ToolExecutionException( + f"MCP server returned an unparseable response to augmented tools/call " + f"for '{tool_name}'; cannot safely retry (server may have started the operation).", + inner_exception=ex, + ) from ex return None, legacy async def _poll_task_until_terminal(self, task_id: str) -> types.GetTaskResult: """Poll ``tasks/get`` until the task reaches a terminal status.""" - from anyio import ClosedResourceError + import httpx from mcp import types from mcp.shared.exceptions import McpError + # SDK raises McpError(code=httpx.REQUEST_TIMEOUT=408) on session read timeout. + transient_codes: frozenset[int] = frozenset({int(httpx.codes.REQUEST_TIMEOUT)}) + while True: - for attempt in range(2): - try: - request = types.ClientRequest( - types.GetTaskRequest(params=types.GetTaskRequestParams(taskId=task_id)) + request = types.ClientRequest( + types.GetTaskRequest(params=types.GetTaskRequestParams(taskId=task_id)) + ) + try: + # GetTaskResult.ttl is required-but-Optional in the SDK; coerce below. 
+ lenient = await self._send_with_one_reconnect( + request, types.Result, operation="tasks/get", task_id=task_id + ) + except McpError as ex: + if ex.error.code in transient_codes: + logger.debug( + "Transient %s on tasks/get for '%s'; will retry.", ex.error.code, task_id ) - # Use lenient Result then coerce: GetTaskResult.ttl is required by - # schema but servers may legitimately omit it. - lenient = await self.session.send_request(request, types.Result) # type: ignore[union-attr] - snapshot = self._coerce_get_task_result(lenient, task_id) - break - except (ClosedResourceError, McpError) as ex: - if not self._is_connection_lost(ex): - error_message = ex.error.message if isinstance(ex, McpError) else str(ex) - raise ToolExecutionException(error_message, inner_exception=ex) from ex - if attempt == 0: - logger.info("MCP connection lost during tasks/get; reconnecting (task_id=%s).", task_id) - try: - await self.connect(reset=True) - continue - except Exception as reconn_ex: - raise ToolExecutionException( - "Failed to reconnect to MCP server.", - inner_exception=reconn_ex, - ) from reconn_ex - raise ToolExecutionException( - f"MCP connection lost; task state unknown (task_id={task_id}).", - inner_exception=ex, - ) from ex - else: # pragma: no cover - defensive - raise ToolExecutionException(f"Failed to poll task '{task_id}'.") + await asyncio.sleep(_MCP_TASK_MIN_POLL_INTERVAL.total_seconds()) + continue + # Hard server error mid-poll: task may still be running. + raise _MCPTaskAbandoned(ex.error.message, inner_exception=ex) from ex + + try: + snapshot = self._coerce_get_task_result(lenient, task_id) + except ToolExecutionException as ex: + # Malformed tasks/get response; task may still be running. 
+ raise _MCPTaskAbandoned(str(ex), inner_exception=ex) from ex if snapshot.status in _MCP_TASK_TERMINAL_STATUSES: return snapshot @@ -1750,40 +1795,22 @@ class MCPTool: async def _fetch_task_result(self, task_id: str) -> types.CallToolResult: """Send ``tasks/result`` and reinterpret the open-typed payload as a CallToolResult.""" - from anyio import ClosedResourceError from mcp import types from mcp.shared.exceptions import McpError from pydantic import ValidationError - for attempt in range(2): - try: - request = types.ClientRequest( - types.GetTaskPayloadRequest(params=types.GetTaskPayloadRequestParams(taskId=task_id)) - ) - payload = await self.session.send_request( # type: ignore[union-attr] - request, types.GetTaskPayloadResult - ) - break - except (ClosedResourceError, McpError) as ex: - if not self._is_connection_lost(ex): - error_message = ex.error.message if isinstance(ex, McpError) else str(ex) - raise ToolExecutionException(error_message, inner_exception=ex) from ex - if attempt == 0: - logger.info("MCP connection lost during tasks/result; reconnecting (task_id=%s).", task_id) - try: - await self.connect(reset=True) - continue - except Exception as reconn_ex: - raise ToolExecutionException( - "Failed to reconnect to MCP server.", - inner_exception=reconn_ex, - ) from reconn_ex - raise ToolExecutionException( - f"MCP connection lost; task state unknown (task_id={task_id}).", - inner_exception=ex, - ) from ex - else: # pragma: no cover - defensive - raise ToolExecutionException(f"Failed to fetch result for task '{task_id}'.") + request = types.ClientRequest( + types.GetTaskPayloadRequest(params=types.GetTaskPayloadRequestParams(taskId=task_id)) + ) + # Connection-loss retry only via the helper; no transient-code retry — server + # has already completed the task, so a slow payload fetch is anomalous. 
+ try: + payload = await self._send_with_one_reconnect( + request, types.GetTaskPayloadResult, operation="tasks/result", task_id=task_id + ) + except McpError as ex: + # Server reported completed; a hard fetch error is a plain failure (no cancel). + raise ToolExecutionException(ex.error.message, inner_exception=ex) from ex # GetTaskPayloadResult carries the tool result via extra fields; reinterpret as CallToolResult. payload_dict = payload.model_dump(by_alias=True, exclude_none=True) @@ -1791,10 +1818,78 @@ class MCPTool: try: return types.CallToolResult.model_validate(payload_dict) except ValidationError as ex: + # Server reported completed; malformed payload is a plain failure (no cancel needed). raise ToolExecutionException( - f"MCP task '{task_id}' result payload could not be parsed as a CallToolResult." + f"MCP task '{task_id}' result payload could not be parsed as a CallToolResult.", + inner_exception=ex, ) from ex + async def _send_with_one_reconnect( + self, + request: types.ClientRequest, + result_type: type[Any], + *, + operation: str, + task_id: str, + ) -> Any: + """Send ``request`` with one reconnect-and-retry on connection loss. + + After a second loss (or reconnect failure), raise ``_MCPTaskAbandoned``. + Non-connection errors propagate unchanged. + """ + from anyio import ClosedResourceError + from mcp.shared.exceptions import McpError + + for attempt in range(2): + try: + return await self.session.send_request(request, result_type) # type: ignore[union-attr] + except (ClosedResourceError, McpError) as ex: + if not self._is_connection_lost(ex): + raise + if attempt == 0: + logger.info( + "MCP connection lost during %s; reconnecting (task_id=%s).", operation, task_id + ) + try: + await self.connect(reset=True) + except Exception as reconn_ex: + # Reconnect failure: task may still be running. 
+ raise _MCPTaskAbandoned( + "Failed to reconnect to MCP server.", inner_exception=reconn_ex + ) from reconn_ex + continue + # Second connection loss: task may still be running. + raise _MCPTaskAbandoned( + f"MCP connection lost; task state unknown (task_id={task_id}).", + inner_exception=ex, + ) from ex + raise AssertionError(f"unreachable: {operation} for {task_id}") # pragma: no cover + + @staticmethod + async def _await_with_deadline(coro: Coroutine[Any, Any, Any], timeout_s: float) -> Any: + """Await ``coro`` with a deadline; raise ``_MCPDeadlineExpired`` only on deadline. + + Unlike ``asyncio.wait_for``, an ``asyncio.TimeoutError`` raised by ``coro`` + itself propagates unchanged so callers can distinguish their own deadline + from a stray inner timeout. + """ + inner = asyncio.ensure_future(coro) + try: + done, _pending = await asyncio.wait({inner}, timeout=timeout_s) + except BaseException: + # Outer caller cancelled (or another exception): cancel inner + drain. + inner.cancel() + with contextlib.suppress(BaseException): + await inner + raise + if inner in done: + return inner.result() + # Deadline fired before inner finished. + inner.cancel() + with contextlib.suppress(BaseException): + await inner + raise _MCPDeadlineExpired + def _spawn_best_effort_cancel(self, task_id: str) -> None: """Fire-and-forget ``tasks/cancel`` so local cancellation propagates server-side.""" try: @@ -1807,18 +1902,35 @@ class MCPTool: cancel_task.add_done_callback(self._pending_reload_tasks.discard) async def _try_cancel_task(self, task_id: str) -> None: - """Send ``tasks/cancel`` swallowing every failure; bounded by ``_MCP_TASK_CANCEL_TIMEOUT``.""" + """Send ``tasks/cancel``; bounded by ``_MCP_TASK_CANCEL_TIMEOUT``. + + Failures log at warning so unattributed orphan tasks are debuggable. 
+ """ from mcp import types - async def _send() -> None: - request = types.ClientRequest(types.CancelTaskRequest(params=types.CancelTaskRequestParams(taskId=task_id))) - try: - await self.session.send_request(request, types.CancelTaskResult) # type: ignore[union-attr] - except Exception: - logger.debug("Best-effort tasks/cancel for '%s' failed.", task_id, exc_info=True) - - with contextlib.suppress(Exception, asyncio.CancelledError, asyncio.TimeoutError): - await asyncio.wait_for(_send(), timeout=_MCP_TASK_CANCEL_TIMEOUT.total_seconds()) + request = types.ClientRequest( + types.CancelTaskRequest(params=types.CancelTaskRequestParams(taskId=task_id)) + ) + try: + await asyncio.wait_for( + self.session.send_request(request, types.CancelTaskResult), # type: ignore[union-attr] + timeout=_MCP_TASK_CANCEL_TIMEOUT.total_seconds(), + ) + except asyncio.CancelledError: + raise + except asyncio.TimeoutError: + logger.warning( + "Best-effort tasks/cancel for '%s' timed out after %.1fs; " + "remote task may still be running.", + task_id, + _MCP_TASK_CANCEL_TIMEOUT.total_seconds(), + ) + except Exception: + logger.warning( + "Best-effort tasks/cancel for '%s' failed; remote task may still be running.", + task_id, + exc_info=True, + ) @staticmethod def _is_connection_lost(ex: BaseException) -> bool: diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index d592e68d6b..1a90554c71 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -7,6 +7,7 @@ import logging import os import sys from contextlib import _AsyncGeneratorContextManager # type: ignore +from datetime import timedelta from typing import Any from unittest.mock import AsyncMock, Mock, patch @@ -5553,4 +5554,504 @@ async def test_call_tool_as_task_create_disconnect_does_not_retry() -> None: reconnect_mock.assert_not_awaited() +async def test_fetch_task_result_reconnects_during_fetch() -> None: + from anyio import 
ClosedResourceError + + tool = _make_task_tool() + + fetch_calls = 0 + + async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any: + nonlocal fetch_calls + method = request.root.method + if method == "tools/call": + return _make_create_task_result(task_id="r1") + if method == "tasks/get": + return _make_task_snapshot(task_id="r1", status="completed") + if method == "tasks/result": + fetch_calls += 1 + if fetch_calls == 1: + raise ClosedResourceError + return _make_payload("fetched after reconnect") + raise AssertionError(method) + + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + + reconnect_calls = 0 + + async def fake_connect(reset: bool = False) -> None: + nonlocal reconnect_calls + reconnect_calls += 1 + assert reset is True + + with patch.object(MCPTool, "connect", side_effect=fake_connect): + result = await tool.call_tool("slow_op") + + assert _mcp_result_to_text(result) == "fetched after reconnect" + assert reconnect_calls == 1 + assert fetch_calls == 2 + + +async def test_fetch_task_result_second_disconnect_raises_task_state_unknown_and_cancels() -> None: + from anyio import ClosedResourceError + + tool = _make_task_tool() + + cancel_called = False + + async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any: + nonlocal cancel_called + method = request.root.method + if method == "tools/call": + return _make_create_task_result(task_id="r2") + if method == "tasks/get": + return _make_task_snapshot(task_id="r2", status="completed") + if method == "tasks/result": + raise ClosedResourceError + if method == "tasks/cancel": + cancel_called = True + return types.CancelTaskResult() + raise AssertionError(method) + + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + + with ( + patch.object(MCPTool, "connect", new=AsyncMock(return_value=None)), + pytest.raises(ToolExecutionException, match="task state unknown"), + ): + await 
tool.call_tool("slow_op") + + # Drain the fire-and-forget cancel so the assertion is deterministic. + pending = list(tool._pending_reload_tasks) + if pending: + await asyncio.gather(*pending, return_exceptions=True) + assert cancel_called is True + + +async def test_call_tool_as_task_create_unparseable_success_raises() -> None: + """An unparseable success-shaped response must NOT silently retry tools/call.""" + # Result with neither task.taskId nor a valid CallToolResult shape. + unparseable = types.Result.model_validate({"foo": "bar"}) + + tool = _make_task_tool() + tool.session.send_request = AsyncMock(return_value=unparseable) # type: ignore[union-attr] + tool.session.call_tool = AsyncMock(return_value=types.CallToolResult(content=[])) # type: ignore[union-attr] + + with pytest.raises(ToolExecutionException, match="unparseable response"): + await tool.call_tool("slow_op") + + # Critically: no plain tools/call fallback (would risk double execution). + tool.session.call_tool.assert_not_called() # type: ignore[union-attr] + + +async def test_call_tool_as_task_max_wait_exceeded_raises_and_cancels(monkeypatch: pytest.MonkeyPatch) -> None: + from agent_framework import MCPTaskOptions + from agent_framework import _mcp as _mcp_module + + monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", _mcp_module.timedelta(milliseconds=1)) + + tool = _make_task_tool(task_options=MCPTaskOptions(max_task_wait=timedelta(milliseconds=50))) + + cancel_called = False + + async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any: + nonlocal cancel_called + method = request.root.method + if method == "tools/call": + return _make_create_task_result(task_id="mw") + if method == "tasks/get": + return _make_task_snapshot(task_id="mw", status="working") + if method == "tasks/cancel": + cancel_called = True + return types.CancelTaskResult() + raise AssertionError(method) + + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: 
ignore[union-attr] + + with pytest.raises(ToolExecutionException, match="exceeded max_task_wait"): + await tool.call_tool("slow_op") + + pending = list(tool._pending_reload_tasks) + if pending: + await asyncio.gather(*pending, return_exceptions=True) + assert cancel_called is True + + +async def test_call_tool_as_task_max_wait_cancels_even_when_local_cancel_option_disabled( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Locks contract: max_task_wait abandonment ignores the local-cancel option.""" + from agent_framework import MCPTaskOptions + from agent_framework import _mcp as _mcp_module + + monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", _mcp_module.timedelta(milliseconds=1)) + + tool = _make_task_tool( + task_options=MCPTaskOptions( + cancel_remote_task_on_local_cancellation=False, + max_task_wait=timedelta(milliseconds=50), + ), + ) + + cancel_called = False + + async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any: + nonlocal cancel_called + method = request.root.method + if method == "tools/call": + return _make_create_task_result(task_id="mw2") + if method == "tasks/get": + return _make_task_snapshot(task_id="mw2", status="working") + if method == "tasks/cancel": + cancel_called = True + return types.CancelTaskResult() + raise AssertionError(method) + + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + + with pytest.raises(ToolExecutionException, match="exceeded max_task_wait"): + await tool.call_tool("slow_op") + + pending = list(tool._pending_reload_tasks) + if pending: + await asyncio.gather(*pending, return_exceptions=True) + assert cancel_called is True + + +async def test_call_tool_as_task_poll_transient_request_timeout_keeps_polling( + monkeypatch: pytest.MonkeyPatch, +) -> None: + import httpx + + from agent_framework import _mcp as _mcp_module + + monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", _mcp_module.timedelta(milliseconds=1)) + + tool = 
_make_task_tool() + + poll_calls = 0 + cancel_called = False + + async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any: + nonlocal poll_calls, cancel_called + method = request.root.method + if method == "tools/call": + return _make_create_task_result(task_id="t1") + if method == "tasks/get": + poll_calls += 1 + if poll_calls == 1: + raise McpError(types.ErrorData(code=int(httpx.codes.REQUEST_TIMEOUT), message="slow poll")) + return _make_task_snapshot(task_id="t1", status="completed") + if method == "tasks/result": + return _make_payload("recovered after transient") + if method == "tasks/cancel": + cancel_called = True + return types.CancelTaskResult() + raise AssertionError(method) + + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + + result = await tool.call_tool("slow_op") + assert _mcp_result_to_text(result) == "recovered after transient" + assert poll_calls == 2 + # Transient retry must not fire cancel. + pending = list(tool._pending_reload_tasks) + if pending: + await asyncio.gather(*pending, return_exceptions=True) + assert cancel_called is False + + +async def test_call_tool_as_task_poll_hard_mcperror_cancels_and_raises( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from agent_framework import _mcp as _mcp_module + + monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", _mcp_module.timedelta(milliseconds=1)) + + tool = _make_task_tool() + + cancel_called = False + + async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any: + nonlocal cancel_called + method = request.root.method + if method == "tools/call": + return _make_create_task_result(task_id="h1") + if method == "tasks/get": + raise McpError(types.ErrorData(code=types.INVALID_PARAMS, message="bad task id")) + if method == "tasks/cancel": + cancel_called = True + return types.CancelTaskResult() + raise AssertionError(method) + + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: 
ignore[union-attr] + + with pytest.raises(ToolExecutionException, match="bad task id"): + await tool.call_tool("slow_op") + + pending = list(tool._pending_reload_tasks) + if pending: + await asyncio.gather(*pending, return_exceptions=True) + assert cancel_called is True + + +async def test_call_tool_as_task_malformed_tasks_get_response_cancels_and_raises( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Malformed tasks/get response counts as abandonment (task may still be running).""" + from agent_framework import _mcp as _mcp_module + + monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", _mcp_module.timedelta(milliseconds=1)) + + tool = _make_task_tool() + + # Result without a valid GetTaskResult shape (no taskId/status/etc.). + malformed = types.Result.model_validate({"some": "junk"}) + + cancel_called = False + + async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any: + nonlocal cancel_called + method = request.root.method + if method == "tools/call": + return _make_create_task_result(task_id="m1") + if method == "tasks/get": + return malformed + if method == "tasks/cancel": + cancel_called = True + return types.CancelTaskResult() + raise AssertionError(method) + + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + + with pytest.raises(ToolExecutionException, match="malformed tasks/get"): + await tool.call_tool("slow_op") + + pending = list(tool._pending_reload_tasks) + if pending: + await asyncio.gather(*pending, return_exceptions=True) + assert cancel_called is True + + +async def test_call_tool_as_task_failed_terminal_does_not_cancel(monkeypatch: pytest.MonkeyPatch) -> None: + """Terminal failures (server already done) must NOT fire tasks/cancel.""" + from agent_framework import _mcp as _mcp_module + + monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", _mcp_module.timedelta(milliseconds=1)) + + tool = _make_task_tool() + + cancel_called = False + + async def 
fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any: + nonlocal cancel_called + method = request.root.method + if method == "tools/call": + return _make_create_task_result(task_id="f1") + if method == "tasks/get": + return _make_task_snapshot(task_id="f1", status="failed", status_message="boom") + if method == "tasks/cancel": + cancel_called = True + return types.CancelTaskResult() + raise AssertionError(method) + + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + + with pytest.raises(ToolExecutionException, match="task failed: boom"): + await tool.call_tool("slow_op") + + # Let any (incorrect) background work settle, then verify no cancel. + await asyncio.sleep(0.02) + assert cancel_called is False + + +async def test_try_cancel_task_logs_warning_on_timeout( + caplog: pytest.LogCaptureFixture, monkeypatch: pytest.MonkeyPatch +) -> None: + from agent_framework import _mcp as _mcp_module + + # Shorten cancel timeout so the test is fast. 
+ monkeypatch.setattr(_mcp_module, "_MCP_TASK_CANCEL_TIMEOUT", _mcp_module.timedelta(milliseconds=20)) + + tool = _make_task_tool() + + async def hang(*_a: Any, **_kw: Any) -> Any: + await asyncio.sleep(10.0) + + tool.session.send_request = AsyncMock(side_effect=hang) # type: ignore[union-attr] + + with caplog.at_level(logging.WARNING, logger=_mcp_module.logger.name): + await tool._try_cancel_task("hang-1") + + assert any("timed out" in r.getMessage() and "hang-1" in r.getMessage() for r in caplog.records) + + +async def test_mcp_task_options_is_frozen() -> None: + from dataclasses import FrozenInstanceError + + from agent_framework import MCPTaskOptions + + opts = MCPTaskOptions() + with pytest.raises(FrozenInstanceError): + opts.default_ttl = timedelta(seconds=5) # type: ignore[misc] + + +async def test_mcp_task_options_max_task_wait_rejects_non_positive() -> None: + from agent_framework import MCPTaskOptions + + with pytest.raises(ValueError, match="positive"): + MCPTaskOptions(max_task_wait=timedelta(0)) + with pytest.raises(ValueError, match="positive"): + MCPTaskOptions(max_task_wait=timedelta(seconds=-1)) + + +async def test_fetch_task_result_hard_mcperror_raises_without_cancel() -> None: + """tasks/result hard McpError must wrap as ToolExecutionException without cancel (server done).""" + tool = _make_task_tool() + + cancel_called = False + + async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any: + nonlocal cancel_called + method = request.root.method + if method == "tools/call": + return _make_create_task_result(task_id="hf") + if method == "tasks/get": + return _make_task_snapshot(task_id="hf", status="completed") + if method == "tasks/result": + raise McpError(types.ErrorData(code=types.INTERNAL_ERROR, message="payload vanished")) + if method == "tasks/cancel": + cancel_called = True + return types.CancelTaskResult() + raise AssertionError(method) + + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: 
ignore[union-attr] + + with pytest.raises(ToolExecutionException, match="payload vanished"): + await tool.call_tool("slow_op") + + # No raw McpError leak and no cancel — server already reported the task as done. + await asyncio.sleep(0.02) + assert cancel_called is False + + +async def test_completion_wait_timeout_without_max_wait_is_not_translated(monkeypatch: pytest.MonkeyPatch) -> None: + """Stray asyncio.TimeoutError during the completion wait must not pretend the deadline + expired when max_task_wait is None (and must not fire a spurious tasks/cancel). + """ + from agent_framework import _mcp as _mcp_module + + monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", _mcp_module.timedelta(milliseconds=1)) + + tool = _make_task_tool() + + def boom_parser(_: Any) -> list[Content]: + raise asyncio.TimeoutError + + tool.parse_tool_results = boom_parser + + cancel_called = False + + async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any: + nonlocal cancel_called + method = request.root.method + if method == "tools/call": + return _make_create_task_result(task_id="t2") + if method == "tasks/get": + return _make_task_snapshot(task_id="t2", status="completed") + if method == "tasks/result": + return _make_payload("ok") + if method == "tasks/cancel": + cancel_called = True + return types.CancelTaskResult() + raise AssertionError(method) + + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + + with pytest.raises(asyncio.TimeoutError): + await tool.call_tool("slow_op") + + # Must NOT translate to max_task_wait expiry and must NOT cancel. + await asyncio.sleep(0.02) + assert cancel_called is False + + +async def test_completion_wait_inner_timeout_with_max_wait_set_propagates( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """An asyncio.TimeoutError raised by the completion wait itself must propagate + unchanged even when max_task_wait IS set, and must NOT fire a spurious cancel. 
+ """ + from agent_framework import MCPTaskOptions + from agent_framework import _mcp as _mcp_module + + monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", _mcp_module.timedelta(milliseconds=1)) + + # Deadline set comfortably above the actual test run time. + tool = _make_task_tool(task_options=MCPTaskOptions(max_task_wait=timedelta(seconds=5))) + + def boom_parser(_: Any) -> list[Content]: + raise asyncio.TimeoutError("inner parser timeout") + + tool.parse_tool_results = boom_parser + + cancel_called = False + + async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any: + nonlocal cancel_called + method = request.root.method + if method == "tools/call": + return _make_create_task_result(task_id="t3") + if method == "tasks/get": + return _make_task_snapshot(task_id="t3", status="completed") + if method == "tasks/result": + return _make_payload("ok") + if method == "tasks/cancel": + cancel_called = True + return types.CancelTaskResult() + raise AssertionError(method) + + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + + with pytest.raises(asyncio.TimeoutError, match="inner parser timeout"): + await tool.call_tool("slow_op") + + # Inner TimeoutError must NOT be translated into "exceeded max_task_wait" and must NOT cancel. 
+ await asyncio.sleep(0.02) + assert cancel_called is False + + +async def test_max_wait_interrupts_long_poll_sleep(monkeypatch: pytest.MonkeyPatch) -> None: + """Deadline must cancel through a long ``asyncio.sleep`` (clamped to MAX), not wait it out.""" + from agent_framework import MCPTaskOptions + + tool = _make_task_tool(task_options=MCPTaskOptions(max_task_wait=timedelta(milliseconds=100))) + + async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> Any: + method = request.root.method + if method == "tools/call": + return _make_create_task_result(task_id="ds") + if method == "tasks/get": + # Suggest a 5s poll interval (gets clamped to MAX=5s); wait_for must cut through it. + return _make_task_snapshot(task_id="ds", status="working", poll_interval_ms=5000) + if method == "tasks/cancel": + return types.CancelTaskResult() + raise AssertionError(method) + + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + + loop = asyncio.get_running_loop() + started = loop.time() + with pytest.raises(ToolExecutionException, match="exceeded max_task_wait"): + await tool.call_tool("slow_op") + elapsed = loop.time() - started + + # Should fire near the 100ms deadline, well below the 5s clamped sleep. + assert elapsed < 1.0, f"deadline did not interrupt long sleep (elapsed={elapsed:.3f}s)" + + pending = list(tool._pending_reload_tasks) + if pending: + await asyncio.gather(*pending, return_exceptions=True) + + # endregion