Python: Add Python parity for InvokeMcpTool in declarative workflow (#5630)

* Add Python parity for HttpRequestAction in declarative workflow

* Ran pyupgrade and pright to fix CI issues

* Fix conversation ID dot parsing for http executor

* Removed unnecessary export command

* Initial implementation of invoke mcp tool in python

* Update sample to support require approval to be toggled by environment variable.

* Fix cache and PR comments

* Update python/samples/03-workflows/declarative/invoke_mcp_tool/main.py

Co-authored-by: Eduard van Valkenburg <eavanvalkenburg@users.noreply.github.com>

---------

Co-authored-by: Eduard van Valkenburg <eavanvalkenburg@users.noreply.github.com>
This commit is contained in:
Peter Ibekwe
2026-05-05 13:16:03 -07:00
committed by GitHub
Unverified
parent f3f71f0fe8
commit f25e81701d
13 changed files with 2680 additions and 0 deletions
+1
View File
@@ -9,6 +9,7 @@ YAML/JSON-based declarative agent and workflow definitions.
- **`WorkflowState`** - State management for declarative workflows
- **`ProviderTypeMapping`** - Maps provider types to implementations
- **`HttpRequestHandler`** / **`DefaultHttpRequestHandler`** - Pluggable HTTP transport for the `HttpRequestAction` declarative action (configured via `WorkflowFactory(http_request_handler=...)`)
- **`MCPToolHandler`** / **`DefaultMCPToolHandler`** - Pluggable MCP transport for the `InvokeMcpTool` declarative action (configured via `WorkflowFactory(mcp_tool_handler=...)`)
- **`DeclarativeLoaderError`** / **`ProviderLookupError`** / **`DeclarativeWorkflowError`** / **`DeclarativeActionError`** - Error types
## External Input Handling
@@ -9,11 +9,18 @@ from ._workflows import (
DeclarativeActionError,
DeclarativeWorkflowError,
DefaultHttpRequestHandler,
DefaultMCPToolHandler,
ExternalInputRequest,
ExternalInputResponse,
HttpRequestHandler,
HttpRequestInfo,
HttpRequestResult,
MCPToolApprovalRequest,
MCPToolHandler,
MCPToolInvocation,
MCPToolResult,
ToolApprovalRequest,
ToolApprovalResponse,
WorkflowFactory,
WorkflowState,
)
@@ -31,13 +38,20 @@ __all__ = [
"DeclarativeLoaderError",
"DeclarativeWorkflowError",
"DefaultHttpRequestHandler",
"DefaultMCPToolHandler",
"ExternalInputRequest",
"ExternalInputResponse",
"HttpRequestHandler",
"HttpRequestInfo",
"HttpRequestResult",
"MCPToolApprovalRequest",
"MCPToolHandler",
"MCPToolInvocation",
"MCPToolResult",
"ProviderLookupError",
"ProviderTypeMapping",
"ToolApprovalRequest",
"ToolApprovalResponse",
"WorkflowFactory",
"WorkflowState",
"__version__",
@@ -72,6 +72,11 @@ from ._executors_http import (
HTTP_ACTION_EXECUTORS,
HttpRequestActionExecutor,
)
from ._executors_mcp import (
MCP_ACTION_EXECUTORS,
InvokeMcpToolActionExecutor,
MCPToolApprovalRequest,
)
from ._executors_tools import (
FUNCTION_TOOL_REGISTRY_KEY,
TOOL_ACTION_EXECUTORS,
@@ -90,6 +95,12 @@ from ._http_handler import (
HttpRequestInfo,
HttpRequestResult,
)
from ._mcp_handler import (
DefaultMCPToolHandler,
MCPToolHandler,
MCPToolInvocation,
MCPToolResult,
)
from ._state import WorkflowState
__all__ = [
@@ -102,6 +113,7 @@ __all__ = [
"EXTERNAL_INPUT_EXECUTORS",
"FUNCTION_TOOL_REGISTRY_KEY",
"HTTP_ACTION_EXECUTORS",
"MCP_ACTION_EXECUTORS",
"TOOL_ACTION_EXECUTORS",
"TOOL_APPROVAL_STATE_KEY",
"TOOL_REGISTRY_KEY",
@@ -126,6 +138,7 @@ __all__ = [
"DeclarativeWorkflowError",
"DeclarativeWorkflowState",
"DefaultHttpRequestHandler",
"DefaultMCPToolHandler",
"EmitEventExecutor",
"EndConversationExecutor",
"EndWorkflowExecutor",
@@ -140,9 +153,14 @@ __all__ = [
"HttpRequestResult",
"InvokeAzureAgentExecutor",
"InvokeFunctionToolExecutor",
"InvokeMcpToolActionExecutor",
"JoinExecutor",
"LoopControl",
"LoopIterationResult",
"MCPToolApprovalRequest",
"MCPToolHandler",
"MCPToolInvocation",
"MCPToolResult",
"QuestionExecutor",
"RequestExternalInputExecutor",
"ResetVariableExecutor",
@@ -41,8 +41,10 @@ from ._executors_control_flow import (
)
from ._executors_external_input import EXTERNAL_INPUT_EXECUTORS
from ._executors_http import HTTP_ACTION_EXECUTORS, HttpRequestActionExecutor
from ._executors_mcp import MCP_ACTION_EXECUTORS, InvokeMcpToolActionExecutor
from ._executors_tools import TOOL_ACTION_EXECUTORS, InvokeFunctionToolExecutor
from ._http_handler import HttpRequestHandler
from ._mcp_handler import MCPToolHandler
logger = logging.getLogger(__name__)
@@ -55,6 +57,7 @@ ALL_ACTION_EXECUTORS = {
**EXTERNAL_INPUT_EXECUTORS,
**TOOL_ACTION_EXECUTORS,
**HTTP_ACTION_EXECUTORS,
**MCP_ACTION_EXECUTORS,
}
# Action kinds that terminate control flow (no fall-through to successor)
@@ -90,6 +93,7 @@ ACTION_REQUIRED_FIELDS: dict[str, list[str]] = {
"EmitEvent": ["event"],
"InvokeFunctionTool": ["functionName"],
"HttpRequestAction": ["url"],
"InvokeMcpTool": ["serverUrl", "toolName"],
}
# Alternate field names that satisfy required field requirements
@@ -135,6 +139,7 @@ class DeclarativeWorkflowBuilder:
validate: bool = True,
max_iterations: int | None = None,
http_request_handler: HttpRequestHandler | None = None,
mcp_tool_handler: MCPToolHandler | None = None,
):
"""Initialize the builder.
@@ -150,6 +155,9 @@ class DeclarativeWorkflowBuilder:
http_request_handler: Handler used to dispatch HttpRequestAction requests.
Must be supplied when the workflow contains any HttpRequestAction;
otherwise build raises ``DeclarativeWorkflowError``.
mcp_tool_handler: Handler used to dispatch InvokeMcpTool calls.
Must be supplied when the workflow contains any InvokeMcpTool;
otherwise build raises ``DeclarativeWorkflowError``.
"""
self._yaml_def = yaml_definition
self._workflow_id = workflow_id or yaml_definition.get("name", "declarative_workflow")
@@ -162,6 +170,7 @@ class DeclarativeWorkflowBuilder:
self._validate = validate
self._seen_explicit_ids: set[str] = set() # Track explicit IDs for duplicate detection
self._http_request_handler = http_request_handler
self._mcp_tool_handler = mcp_tool_handler
# Resolve max_iterations: explicit arg > YAML maxTurns > core default
resolved = max_iterations if max_iterations is not None else yaml_definition.get("maxTurns")
if resolved is not None and (not isinstance(resolved, int) or resolved <= 0):
@@ -481,6 +490,19 @@ class DeclarativeWorkflowBuilder:
id=action_id,
http_request_handler=self._http_request_handler,
)
elif kind == "InvokeMcpTool":
if self._mcp_tool_handler is None:
raise DeclarativeWorkflowError(
f"Workflow defines InvokeMcpTool '{action_id}' but no "
"mcp_tool_handler was supplied to WorkflowFactory. Pass "
"mcp_tool_handler=DefaultMCPToolHandler() (or a custom "
"implementation) to enable MCP tool invocations."
)
executor = InvokeMcpToolActionExecutor(
action_def,
id=action_id,
mcp_tool_handler=self._mcp_tool_handler,
)
else:
executor = executor_class(action_def, id=action_id)
self._executors[action_id] = executor
@@ -0,0 +1,614 @@
# Copyright (c) Microsoft. All rights reserved.
"""Executor for the ``InvokeMcpTool`` declarative action.
Mirrors the .NET ``InvokeMcpToolExecutor``: dispatches an MCP tool call through
the configured :class:`MCPToolHandler`, parses tool outputs, and routes
results to the configured ``output.{result, messages, autoSend}`` paths and
optional conversation history. Supports a human-in-loop approval flow via
``ctx.request_info()`` / :func:`@response_handler` for ``requireApproval=true``.
Security notes:
- The executor never echoes header VALUES (auth tokens, API keys) into the
approval request — only header NAMES are surfaced to the caller. This
matches the security posture of :mod:`._executors_http` (which never logs
request headers either) and prevents secrets from leaking through workflow
events that are typically observable to operators / UIs.
- ``_MCPToolApprovalState`` snapshots the EVALUATED values for non-secret
fields (server URL, tool name, arguments) at approval-request time so that
subsequent state mutations cannot make the executor "approve X then call
Y". Headers are stored as the raw expression strings (not evaluated values)
so secrets are not persisted in the workflow's checkpoint state. They are
re-evaluated on resume.
- Tool outputs flow back into agent conversations through ``conversationId``
and through Tool-role messages emitted to ``output.messages``. They share
the same prompt-injection risk surface as ``HttpRequestAction``: workflow
authors must trust the MCP server they invoke.
"""
import json
import logging
import uuid
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any
import httpx
from agent_framework import (
Content,
Message,
WorkflowContext,
handler,
response_handler,
)
from agent_framework.exceptions import ToolExecutionException
from ._declarative_base import (
ActionComplete,
DeclarativeActionExecutor,
DeclarativeWorkflowState,
)
from ._executors_tools import ToolApprovalResponse
from ._mcp_handler import MCPToolHandler, MCPToolInvocation, MCPToolResult
__all__ = [
"MCP_ACTION_EXECUTORS",
"InvokeMcpToolActionExecutor",
"MCPToolApprovalRequest",
]
logger = logging.getLogger(__name__)
_MCP_APPROVAL_STATE_KEY = "_mcp_tool_approval_state"
# ---------------------------------------------------------------------------
# Request / state types
# ---------------------------------------------------------------------------
@dataclass
class MCPToolApprovalRequest:
"""Approval request emitted before invoking an MCP tool.
Mirrors :class:`agent_framework_declarative.ToolApprovalRequest` but for
MCP-style invocations. Only header NAMES are surfaced — header values are
intentionally omitted because they typically carry authentication
secrets.
Attributes:
request_id: Unique identifier for this approval request. Matches the
id workflow event-emitters use.
tool_name: Evaluated name of the tool to be invoked.
server_url: Evaluated MCP server URL.
server_label: Optional human-readable label for diagnostics.
arguments: Evaluated arguments to be forwarded to the tool.
header_names: Sorted list of outbound header names (no values). Empty
when no headers are configured.
"""
request_id: str
tool_name: str
server_url: str
server_label: str | None
arguments: dict[str, Any]
header_names: list[str] = field(default_factory=lambda: [])
@dataclass
class _MCPToolApprovalState:
"""Internal state saved during the approval yield for resumption.
Stores **evaluated** values for non-secret fields to prevent
"approve X / execute Y" attacks. Stores the raw expression string for
``headers`` so that secret values are NOT persisted in checkpoint state;
the expressions are re-evaluated against current state on resume.
"""
server_url: str
tool_name: str
server_label: str | None
arguments: dict[str, Any]
connection_name: str | None
headers_def: Any
auto_send: bool
conversation_id_expr: str | None
output_messages_path: str | None
output_result_path: str | None
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _get_messages_path(state: DeclarativeWorkflowState, conversation_id_expr: str | None) -> str | None:
"""Return the configured conversation messages path, if any.
Returns ``System.conversations.{evaluated_id}.messages`` when a
``conversation_id_expr`` is configured and evaluates to a non-empty value.
Returns ``None`` when no conversation id expression is configured or when
the expression evaluates to ``None`` or an empty string (mirrors .NET
``GetConversationId`` behaviour).
"""
if not conversation_id_expr:
return None
evaluated = state.eval_if_expression(conversation_id_expr)
if evaluated is None or (isinstance(evaluated, str) and not evaluated):
return None
return f"System.conversations.{evaluated}.messages"
def _get_output_path(action_def: Mapping[str, Any], key: str) -> str | None:
"""Extract a state path from ``output.{key}`` field.
Supports two YAML shapes:
- ``output: { result: Local.MyVar }`` — plain string.
- ``output: { result: { path: Local.MyVar } }`` — object form.
"""
output: Any = action_def.get("output")
if not isinstance(output, Mapping):
return None
value: Any = output.get(key) # type: ignore[reportUnknownMemberType]
if isinstance(value, str):
return value or None
if isinstance(value, Mapping):
path: Any = value.get("path") # type: ignore[reportUnknownMemberType]
return path if isinstance(path, str) and path else None
return None
def _format_outputs_for_send(parsed_results: list[Any]) -> str:
"""Render parsed MCP outputs to a string for ``ctx.yield_output(...)``.
- Empty list → ``""``.
- All-string list → newline-joined.
- Single element (any type — scalar, dict, list) → JSON-dumped element.
This avoids surprising ``"[42]"`` / ``"[true]"`` / ``"[null]"`` when
an MCP tool returns a single scalar JSON value.
- Multi-element non-string list → JSON-dump the whole list.
"""
if not parsed_results:
return ""
if all(isinstance(item, str) for item in parsed_results):
return "\n".join(parsed_results) # type: ignore[arg-type]
if len(parsed_results) == 1:
return json.dumps(parsed_results[0], ensure_ascii=False)
return json.dumps(parsed_results, ensure_ascii=False)
# ---------------------------------------------------------------------------
# Executor
# ---------------------------------------------------------------------------
class InvokeMcpToolActionExecutor(DeclarativeActionExecutor):
"""Executor for the ``InvokeMcpTool`` declarative action.
Dispatches through the supplied :class:`MCPToolHandler` and:
- Evaluates ``serverUrl`` / ``toolName`` / ``serverLabel`` / ``arguments``
/ ``headers`` / ``connection.name`` from the action definition.
- When ``requireApproval=true``: emits a :class:`MCPToolApprovalRequest`
via ``ctx.request_info()`` and yields. On resume, the response is
checked; on rejection, ``output.result`` is set to ``"Error: ..."`` and
no tool call is made.
- On success: parses each :class:`agent_framework.Content` output (text →
JSON-first / data / uri → URI string) and assigns the parsed list to
``output.result``. Builds a single Tool-role :class:`Message`
containing all output contents and assigns it to ``output.messages``.
When ``output.autoSend`` is true (default), emits the rendered string
via ``ctx.yield_output(...)``. When ``conversationId`` is configured,
appends an Assistant-role :class:`Message` with the same contents to
``System.conversations.{id}.messages``.
- On error returned by the handler (``is_error=True``): assigns
``"Error: <message>"`` to ``output.result`` and completes normally
(parity with .NET ``AssignErrorAsync``).
.. note::
``output.messages`` receives a SINGLE Tool-role :class:`Message`
(containing the full tool output as ``contents``), unlike
:class:`agent_framework_declarative.InvokeFunctionToolExecutor` which
writes a list of two messages (assistant call + tool result). This
matches the .NET ``InvokeMcpToolExecutor`` output contract.
"""
def __init__(
self,
action_def: dict[str, Any],
*,
id: str | None = None,
mcp_tool_handler: MCPToolHandler,
) -> None:
"""Create an MCP tool action executor.
Args:
action_def: Parsed ``InvokeMcpTool`` YAML dict.
id: Optional executor id (defaults to action id or generated).
mcp_tool_handler: Handler used to dispatch MCP tool calls.
Required: the builder enforces presence at workflow-build
time.
"""
super().__init__(action_def, id=id)
self._mcp_tool_handler = mcp_tool_handler
# ----- Main handler --------------------------------------------------------
@handler
async def handle_action(
self,
trigger: Any,
ctx: WorkflowContext[ActionComplete, str],
) -> None:
"""Execute the MCP tool action."""
state = await self._ensure_state_initialized(ctx, trigger)
server_url = self._get_server_url(state)
tool_name = self._get_tool_name(state)
server_label = self._get_server_label(state)
arguments = self._get_arguments(state)
headers = self._get_headers(state)
connection_name = self._get_connection_name(state)
require_approval = self._get_require_approval(state)
auto_send = self._get_auto_send(state)
conversation_id_expr = self._action_def.get("conversationId")
output_messages_path = _get_output_path(self._action_def, "messages")
output_result_path = _get_output_path(self._action_def, "result")
if require_approval:
request_id = str(uuid.uuid4())
approval_state = _MCPToolApprovalState(
server_url=server_url,
tool_name=tool_name,
server_label=server_label,
arguments=arguments,
connection_name=connection_name,
headers_def=self._action_def.get("headers"),
auto_send=auto_send,
conversation_id_expr=conversation_id_expr if isinstance(conversation_id_expr, str) else None,
output_messages_path=output_messages_path,
output_result_path=output_result_path,
)
ctx.state.set(self._approval_key(), approval_state)
request = MCPToolApprovalRequest(
request_id=request_id,
tool_name=tool_name,
server_url=server_url,
server_label=server_label,
arguments=arguments,
header_names=sorted(headers.keys()),
)
logger.info(
"%s: requesting approval for MCP tool '%s' on '%s'",
self.__class__.__name__,
tool_name,
server_url,
)
await ctx.request_info(request, ToolApprovalResponse, request_id=request_id)
# Workflow yields here — resume in handle_approval_response.
return
# No approval required - invoke directly.
invocation = MCPToolInvocation(
server_url=server_url,
tool_name=tool_name,
server_label=server_label,
arguments=arguments,
headers=headers,
connection_name=connection_name,
)
result = await self._invoke_with_narrow_catch(invocation)
await self._process_result(
ctx=ctx,
state=state,
result=result,
auto_send=auto_send,
conversation_id_expr=conversation_id_expr if isinstance(conversation_id_expr, str) else None,
output_messages_path=output_messages_path,
output_result_path=output_result_path,
)
await ctx.send_message(ActionComplete())
# ----- Approval response handler ------------------------------------------
@response_handler
async def handle_approval_response(
self,
original_request: MCPToolApprovalRequest,
response: ToolApprovalResponse,
ctx: WorkflowContext[ActionComplete, str],
) -> None:
"""Resume after the workflow yielded for an approval request."""
state = self._get_state(ctx.state)
approval_key = self._approval_key()
try:
approval_state: _MCPToolApprovalState = ctx.state.get(approval_key)
except KeyError:
logger.error("%s: approval state missing for executor '%s'", self.__class__.__name__, self.id)
await ctx.send_message(ActionComplete())
return
try:
ctx.state.delete(approval_key)
except KeyError:
logger.warning("%s: approval state already deleted for '%s'", self.__class__.__name__, self.id)
if not response.approved:
logger.info(
"%s: MCP tool '%s' rejected: %s",
self.__class__.__name__,
approval_state.tool_name,
response.reason,
)
self._assign_error(
state, approval_state.output_result_path, "MCP tool invocation was not approved by user."
)
await ctx.send_message(ActionComplete())
return
# Approved — re-evaluate headers (not stored at approval time for security).
headers = self._evaluate_headers(state, approval_state.headers_def)
invocation = MCPToolInvocation(
server_url=approval_state.server_url,
tool_name=approval_state.tool_name,
server_label=approval_state.server_label,
arguments=approval_state.arguments,
headers=headers,
connection_name=approval_state.connection_name,
)
result = await self._invoke_with_narrow_catch(invocation)
await self._process_result(
ctx=ctx,
state=state,
result=result,
auto_send=approval_state.auto_send,
conversation_id_expr=approval_state.conversation_id_expr,
output_messages_path=approval_state.output_messages_path,
output_result_path=approval_state.output_result_path,
)
await ctx.send_message(ActionComplete())
# ----- Field resolution ----------------------------------------------------
def _get_server_url(self, state: DeclarativeWorkflowState) -> str:
raw = self._action_def.get("serverUrl")
if raw is None:
raise ValueError("InvokeMcpTool requires a 'serverUrl' field.")
evaluated = state.eval_if_expression(raw)
if not isinstance(evaluated, str) or not evaluated:
raise ValueError("InvokeMcpTool 'serverUrl' evaluated to an empty value.")
return evaluated
def _get_tool_name(self, state: DeclarativeWorkflowState) -> str:
raw = self._action_def.get("toolName")
if raw is None:
raise ValueError("InvokeMcpTool requires a 'toolName' field.")
evaluated = state.eval_if_expression(raw)
if not isinstance(evaluated, str) or not evaluated:
raise ValueError("InvokeMcpTool 'toolName' evaluated to an empty value.")
return evaluated
def _get_server_label(self, state: DeclarativeWorkflowState) -> str | None:
raw = self._action_def.get("serverLabel")
if raw is None:
return None
evaluated = state.eval_if_expression(raw)
if evaluated is None:
return None
text = str(evaluated)
return text or None
def _get_arguments(self, state: DeclarativeWorkflowState) -> dict[str, Any]:
"""Evaluate ``arguments`` map. Preserves ``None`` values (parity with .NET)."""
raw = self._action_def.get("arguments")
if raw is None:
return {}
if not isinstance(raw, Mapping) or not raw:
return {}
result: dict[str, Any] = {}
for key, value in raw.items(): # type: ignore[reportUnknownVariableType]
if not isinstance(key, str) or not key:
continue
result[key] = state.eval_if_expression(value)
return result
def _get_headers(self, state: DeclarativeWorkflowState) -> dict[str, str]:
return self._evaluate_headers(state, self._action_def.get("headers"))
@staticmethod
def _evaluate_headers(state: DeclarativeWorkflowState, headers_def: Any) -> dict[str, str]:
"""Evaluate the ``headers`` map. Empty string values are skipped."""
if not isinstance(headers_def, Mapping) or not headers_def:
return {}
result: dict[str, str] = {}
for key, value in headers_def.items(): # type: ignore[reportUnknownVariableType]
if not isinstance(key, str) or not key:
continue
evaluated = state.eval_if_expression(value)
if evaluated is None:
continue
text = str(evaluated)
if not text:
continue
result[key] = text
return result
def _get_connection_name(self, state: DeclarativeWorkflowState) -> str | None:
connection = self._action_def.get("connection")
if not isinstance(connection, Mapping):
return None
name_expr: Any = connection.get("name") # type: ignore[reportUnknownMemberType]
if name_expr is None:
return None
evaluated = state.eval_if_expression(name_expr)
if evaluated is None:
return None
text = str(evaluated)
return text or None
def _get_require_approval(self, state: DeclarativeWorkflowState) -> bool:
raw = self._action_def.get("requireApproval")
if raw is None:
return False
evaluated = state.eval_if_expression(raw)
if isinstance(evaluated, bool):
return evaluated
if isinstance(evaluated, str):
return evaluated.strip().lower() in {"true", "1", "yes"}
return bool(evaluated)
def _get_auto_send(self, state: DeclarativeWorkflowState) -> bool:
output: Any = self._action_def.get("output")
if not isinstance(output, Mapping):
return True
raw: Any = output.get("autoSend") # type: ignore[reportUnknownMemberType]
if raw is None:
return True
evaluated = state.eval_if_expression(raw)
if isinstance(evaluated, bool):
return evaluated
if isinstance(evaluated, str):
return evaluated.strip().lower() in {"true", "1", "yes"}
return bool(evaluated)
# ----- Invocation + error handling ----------------------------------------
async def _invoke_with_narrow_catch(self, invocation: MCPToolInvocation) -> MCPToolResult:
"""Invoke the handler with a narrow exception catch.
Only known transport / tool exceptions are normalised to an error
result. Programmer bugs (TypeError, ValueError from misuse, etc.)
propagate so they fail loudly.
``asyncio.CancelledError`` is a ``BaseException``, not ``Exception``,
so it is not caught here and propagates unchanged for workflow
cancellation.
"""
try:
return await self._mcp_tool_handler.invoke_tool(invocation)
except ToolExecutionException as exc:
message = str(exc) or type(exc).__name__
return MCPToolResult(
outputs=[Content.from_text(f"Error: {message}")],
is_error=True,
error_message=message,
)
except httpx.HTTPError as exc:
message = f"{type(exc).__name__}: {exc}" if str(exc) else type(exc).__name__
return MCPToolResult(
outputs=[Content.from_text(f"Error: {message}")],
is_error=True,
error_message=message,
)
except Exception as exc:
try:
from mcp.shared.exceptions import McpError
except ImportError: # pragma: no cover - mcp is a hard dep
raise
if isinstance(exc, McpError):
message = str(exc) or type(exc).__name__
return MCPToolResult(
outputs=[Content.from_text(f"Error: {message}")],
is_error=True,
error_message=message,
)
raise
# ----- Result handling -----------------------------------------------------
async def _process_result(
self,
*,
ctx: WorkflowContext[ActionComplete, str],
state: DeclarativeWorkflowState,
result: MCPToolResult,
auto_send: bool,
conversation_id_expr: str | None,
output_messages_path: str | None,
output_result_path: str | None,
) -> None:
"""Apply ``result`` to workflow state per the configured output paths."""
if result.is_error:
# Error path mirrors .NET ``AssignErrorAsync`` — only the result
# path is touched; messages / autoSend / conversation are not.
self._assign_error(
state,
output_result_path,
result.error_message or "MCP tool invocation failed.",
)
return
parsed_results = _parse_outputs(result.outputs)
if output_result_path is not None and parsed_results:
state.set(output_result_path, parsed_results)
# Single Tool-role message (matches .NET line 178 contract). Differs
# from InvokeFunctionTool's two-message [assistant call, tool result]
# convention.
tool_message = Message(role="tool", contents=list(result.outputs))
if output_messages_path is not None:
state.set(output_messages_path, tool_message)
if auto_send and parsed_results:
await ctx.yield_output(_format_outputs_for_send(parsed_results))
if conversation_id_expr:
messages_path = _get_messages_path(state, conversation_id_expr)
if messages_path is not None:
# Mirrors .NET: conversation gets ASSISTANT-role message with
# the same outputs (so chat history reads it as the agent's
# contribution).
assistant_message = Message(role="assistant", contents=list(result.outputs))
state.append(messages_path, assistant_message)
@staticmethod
def _assign_error(
state: DeclarativeWorkflowState,
output_result_path: str | None,
error_message: str,
) -> None:
"""Mirror .NET ``AssignErrorAsync``: store ``"Error: <msg>"`` at the result path."""
if output_result_path is None:
return
state.set(output_result_path, f"Error: {error_message}")
def _approval_key(self) -> str:
return f"{_MCP_APPROVAL_STATE_KEY}_{self.id}"
def _parse_outputs(outputs: list[Content]) -> list[Any]:
"""Parse :class:`Content` outputs into Python values for ``output.result``.
Mirrors .NET ``AssignResultAsync``:
- ``TextContent`` → JSON-parse text; on failure use the raw text.
- ``DataContent`` / ``UriContent`` → ``content.uri``.
- Other content kinds → ``str(content)``.
"""
parsed: list[Any] = []
for content in outputs:
kind = getattr(content, "type", None)
if kind == "text":
text_value = getattr(content, "text", None)
text_str = "" if text_value is None else str(text_value)
try:
parsed.append(json.loads(text_str))
except (json.JSONDecodeError, ValueError):
parsed.append(text_str)
continue
if kind in ("data", "uri"):
uri_value = getattr(content, "uri", None)
parsed.append("" if uri_value is None else str(uri_value))
continue
parsed.append(str(content))
return parsed
MCP_ACTION_EXECUTORS: dict[str, type[DeclarativeActionExecutor]] = {
"InvokeMcpTool": InvokeMcpToolActionExecutor,
}
@@ -29,6 +29,7 @@ from .._loader import AgentFactory
from ._declarative_builder import DeclarativeWorkflowBuilder
from ._errors import DeclarativeWorkflowError
from ._http_handler import HttpRequestHandler
from ._mcp_handler import MCPToolHandler
logger = logging.getLogger("agent_framework.declarative")
@@ -91,6 +92,7 @@ class WorkflowFactory:
checkpoint_storage: CheckpointStorage | None = None,
max_iterations: int | None = None,
http_request_handler: HttpRequestHandler | None = None,
mcp_tool_handler: MCPToolHandler | None = None,
) -> None:
"""Initialize the workflow factory.
@@ -110,6 +112,13 @@ class WorkflowFactory:
otherwise. Use :class:`agent_framework.declarative.DefaultHttpRequestHandler`
for a no-policy ``httpx``-based default, or supply your own implementation
to enforce SSRF guards, allowlisting, or auth resolution.
mcp_tool_handler: Optional handler used to dispatch MCP tool calls for
``InvokeMcpTool``. Required if the workflow contains any
``InvokeMcpTool``; build will fail with :class:`DeclarativeWorkflowError`
otherwise. Use :class:`agent_framework.declarative.DefaultMCPToolHandler`
for a default backed by :class:`agent_framework.MCPStreamableHTTPTool`,
or supply your own implementation to enforce SSRF guards, allowlisting,
or auth/connection resolution.
Examples:
.. code-block:: python
@@ -150,6 +159,7 @@ class WorkflowFactory:
self._checkpoint_storage = checkpoint_storage
self._max_iterations = max_iterations
self._http_request_handler = http_request_handler
self._mcp_tool_handler = mcp_tool_handler
def create_workflow_from_yaml_path(
self,
@@ -394,6 +404,7 @@ class WorkflowFactory:
checkpoint_storage=self._checkpoint_storage,
max_iterations=self._max_iterations,
http_request_handler=self._http_request_handler,
mcp_tool_handler=self._mcp_tool_handler,
)
workflow = graph_builder.build()
except ValueError as e:
@@ -0,0 +1,494 @@
# Copyright (c) Microsoft. All rights reserved.
"""MCP tool handler abstraction for declarative workflows.
Mirrors the .NET ``IMcpToolHandler`` / ``DefaultMcpToolHandler`` pair from
``Microsoft.Agents.AI.Workflows.Declarative.Mcp``. Provides:
- :class:`MCPToolInvocation` — request input data passed from the executor.
- :class:`MCPToolResult` — response data returned to the executor.
- :class:`MCPToolHandler` — :class:`typing.Protocol` callers implement to plug
in custom transports (e.g. with allowlisting, Foundry connection resolution,
per-server auth, etc.).
- :class:`DefaultMCPToolHandler` — production-grade default backed by
:class:`agent_framework.MCPStreamableHTTPTool`.
Security note: :class:`DefaultMCPToolHandler` performs **no** URL filtering or
SSRF protection. Production deployments should supply a custom handler that
enforces an allowlist or DNS-rebinding-resistant policy. This split mirrors the
.NET design.
Prompt-injection note: MCP tool outputs flow back into agent conversations
(via ``conversationId`` and Tool-role messages emitted by the executor) so
they share the same risk surface as ``HttpRequestAction``. Workflow authors
must trust the MCP server they invoke.
"""
from __future__ import annotations
import asyncio
import hashlib
import json
import logging
from collections import OrderedDict
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Protocol, cast, runtime_checkable
import httpx
if TYPE_CHECKING:
from agent_framework import Content
__all__ = [
"ClientProvider",
"DefaultMCPToolHandler",
"MCPToolHandler",
"MCPToolInvocation",
"MCPToolResult",
]
logger = logging.getLogger(__name__)
_DEFAULT_CACHE_MAX_SIZE = 32
@dataclass
class MCPToolInvocation:
"""Description of an MCP tool call to be dispatched by a :class:`MCPToolHandler`.
Mirrors the input parameters of the .NET ``IMcpToolHandler.InvokeToolAsync``
method. Field semantics:
- ``server_url``: Absolute URL of the MCP server. Already evaluated from
the YAML expression.
- ``server_label``: Optional human-readable label used for diagnostics
and as the underlying ``MCPStreamableHTTPTool`` name.
- ``tool_name``: Name of the tool to invoke on the MCP server.
- ``arguments``: Tool arguments. Already evaluated; values may be any
JSON-serialisable Python object (str, int, bool, dict, list, None).
- ``headers``: Outbound HTTP headers (e.g. authentication). Empty values
are skipped by the executor before construction.
- ``connection_name``: Optional Foundry connection name forwarded for
handlers that resolve auth/credentials by connection. The default
handler does not consume this field.
"""
server_url: str
tool_name: str
server_label: str | None = None
arguments: dict[str, Any] = field(default_factory=dict) # type: ignore[reportUnknownVariableType]
headers: dict[str, str] = field(default_factory=dict) # type: ignore[reportUnknownVariableType]
connection_name: str | None = None
def _empty_outputs() -> list[Any]:
"""Default factory for ``MCPToolResult.outputs``.
Typed as ``list[Any]`` here to keep the dataclass field's runtime
factory simple; the public type on :class:`MCPToolResult` is
``list[Content]``.
"""
return []
@dataclass
class MCPToolResult:
"""Response returned by an :class:`MCPToolHandler`.
Mirrors the .NET ``McpServerToolResultContent`` shape. ``outputs`` is a
list of :class:`agent_framework.Content` items as parsed by the MCP
transport (TextContent / DataContent / UriContent / etc.).
On error, ``is_error`` is ``True``, ``error_message`` carries a human
readable description, and ``outputs`` typically contains a single
``Content.from_text("Error: ...")`` entry for downstream display.
"""
outputs: list[Content] = field(default_factory=_empty_outputs)
is_error: bool = False
error_message: str | None = None
@runtime_checkable
class MCPToolHandler(Protocol):
"""Protocol for MCP tool handlers used by ``InvokeMcpTool``.
Mirrors :class:`HttpRequestHandler` — declares ONLY the invocation method.
Lifecycle methods (``aclose`` / ``__aenter__`` / ``__aexit__``) are NOT
part of the Protocol; concrete implementations may add them as
appropriate.
Implementations must be safe to call concurrently from multiple workflow
runs. Implementations are responsible for any URL allowlisting, SSRF
guards, retry policies, auth resolution, and other policies the workflow
author wants applied.
"""
async def invoke_tool(self, invocation: MCPToolInvocation) -> MCPToolResult:
"""Dispatch ``invocation`` and return the result.
Args:
invocation: Description of the MCP tool call to perform.
Returns:
The :class:`MCPToolResult` carrying the parsed outputs (or an
error flag if the tool raised). Implementations SHOULD return a
result with ``is_error=True`` rather than raising for transport
or tool-level failures, so the workflow can store the message in
``output.result`` (matching .NET ``AssignErrorAsync`` behaviour).
They MAY raise on unexpected programming errors — these will be
propagated unchanged by the executor so they fail loudly.
"""
...
ClientProvider = Callable[[MCPToolInvocation], Awaitable["httpx.AsyncClient | None"]]
@dataclass
class _CacheEntry:
"""Internal record stored in the LRU cache."""
tool: Any # MCPStreamableHTTPTool — typed Any to avoid import at module load
owned_httpx_client: httpx.AsyncClient | None
class DefaultMCPToolHandler:
"""Default :class:`MCPToolHandler` backed by :class:`agent_framework.MCPStreamableHTTPTool`.
Caches one :class:`agent_framework.MCPStreamableHTTPTool` instance per
``(server_url, server_label, connection_name, headers_hash)`` in a
bounded LRU. The cache prevents re-establishing an MCP session for every
invocation while ensuring different header sets (auth tokens) cannot
share a session — matches the .NET design intent while bounding
cardinality. ``server_label`` and ``connection_name`` participate in
the key so that callers using ``client_provider`` to dispatch on those
fields receive a fresh client per logical connection (see below).
Header *names* are lower-cased inside the hash payload only — the
headers passed on the wire keep the caller's original casing — so two
YAML actions that spell ``Authorization`` differently still share a
cache entry.
Construction modes:
1. ``DefaultMCPToolHandler()`` — owns its own ``httpx.AsyncClient``
instances created lazily per cache entry. Closed by :meth:`aclose`.
2. ``DefaultMCPToolHandler(client_provider=cb)`` — per-server client
lookup (parity with .NET ``httpClientProvider`` callback). The
callback receives the full :class:`MCPToolInvocation` so it can
dispatch on ``server_url`` / ``connection_name`` / ``server_label``.
Returning ``None`` falls back to an internally-created client. Caller
supplied clients are NOT closed by :meth:`aclose`.
.. warning::
This handler performs **no** URL filtering or SSRF protection. Wrap
or replace it with a custom handler in production deployments.
Args:
client_provider: Optional per-server ``httpx.AsyncClient`` provider.
cache_max_size: Maximum number of cached MCP clients. When exceeded,
the least-recently-used entry is evicted and its client closed
(only owned clients are closed; caller-supplied ones are not).
Defaults to ``32``.
"""
def __init__(
self,
*,
client_provider: ClientProvider | None = None,
cache_max_size: int = _DEFAULT_CACHE_MAX_SIZE,
) -> None:
if cache_max_size <= 0:
raise ValueError(f"cache_max_size must be positive, got {cache_max_size}")
self._client_provider = client_provider
self._cache_max_size = cache_max_size
self._cache: OrderedDict[tuple[str, str, str, str], _CacheEntry] = OrderedDict()
# Outer lock guards the cache + in-flight-future map only — never
# held across network I/O.
self._cache_lock = asyncio.Lock()
# Per-key in-flight futures: while one task is connecting, other
# tasks awaiting the same key will await the same future and share
# the resulting cache entry.
self._inflight: dict[tuple[str, str, str, str], asyncio.Future[_CacheEntry]] = {}
# Set by ``aclose`` to prevent post-close cache insertions and to
# reject new ``invoke_tool`` calls. Once set, never cleared.
self._closed = False
async def invoke_tool(self, invocation: MCPToolInvocation) -> MCPToolResult:
"""Invoke ``invocation.tool_name`` on the cached MCP client for the server."""
from agent_framework import Content
from agent_framework.exceptions import ToolExecutionException
try:
entry = await self._get_or_create_entry(invocation)
except Exception as exc:
# Connect / cache lookup failures surface as tool errors so the
# workflow can store them at output.result without crashing.
logger.warning(
"DefaultMCPToolHandler: failed to obtain MCP client for url=%s tool=%s: %s",
invocation.server_url,
invocation.tool_name,
exc,
)
message = f"Failed to connect to MCP server: {type(exc).__name__}: {exc}".rstrip(": ")
return MCPToolResult(
outputs=[Content.from_text(f"Error: {message}")],
is_error=True,
error_message=message,
)
try:
raw = await entry.tool.call_tool(invocation.tool_name, **invocation.arguments)
except ToolExecutionException as exc:
logger.info(
"DefaultMCPToolHandler: tool '%s' on '%s' raised ToolExecutionException",
invocation.tool_name,
invocation.server_url,
)
message = str(exc) or type(exc).__name__
return MCPToolResult(
outputs=[Content.from_text(f"Error: {message}")],
is_error=True,
error_message=message,
)
except httpx.HTTPError as exc:
message = f"{type(exc).__name__}: {exc}" if str(exc) else type(exc).__name__
return MCPToolResult(
outputs=[Content.from_text(f"Error: {message}")],
is_error=True,
error_message=message,
)
except Exception as exc:
# Be defensive about MCP errors that may bubble up without being
# wrapped in ToolExecutionException by custom parsers.
try:
from mcp.shared.exceptions import McpError
except ImportError: # pragma: no cover - mcp is a hard dep but stay defensive
raise
if isinstance(exc, McpError):
message = str(exc) or type(exc).__name__
return MCPToolResult(
outputs=[Content.from_text(f"Error: {message}")],
is_error=True,
error_message=message,
)
raise
# Defensive normalisation: call_tool is typed ``str | list[Content]``.
# Default parser returns list, but custom parse_tool_results may return str.
if isinstance(raw, str):
outputs: list[Content] = [Content.from_text(raw)]
else:
outputs = list(raw)
return MCPToolResult(outputs=outputs)
async def aclose(self) -> None:
"""Close all cached MCP clients and the owned httpx clients.
Caller-supplied :class:`httpx.AsyncClient` instances (returned by the
``client_provider`` callback) are NOT closed.
Idempotent — a second call returns immediately. Drains any in-flight
``_create_entry`` tasks before returning so their resources are
cleaned up; the in-flight tasks see ``self._closed`` in phase 3 of
:meth:`_get_or_create_entry`, close their own entry, and resolve
their future with ``RuntimeError("DefaultMCPToolHandler is closed")``.
"""
async with self._cache_lock:
if self._closed:
return
self._closed = True
entries = list(self._cache.values())
self._cache.clear()
inflight_futures = list(self._inflight.values())
# Wait for in-flight creations to finish their self-cleanup. Each
# in-flight task self-closes its entry under the closed-flag branch
# in phase 3 and resolves its future with ``RuntimeError``; we
# swallow it here because the failure is expected at shutdown.
for fut in inflight_futures:
try:
await fut
except BaseException:
logger.debug("DefaultMCPToolHandler: in-flight future raised during aclose", exc_info=True)
continue
for entry in entries:
await self._close_entry(entry)
async def __aenter__(self) -> DefaultMCPToolHandler:
return self
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
await self.aclose()
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
async def _get_or_create_entry(self, invocation: MCPToolInvocation) -> _CacheEntry:
"""Look up (or create) the cached MCP client for this invocation."""
key = self._cache_key(
invocation.server_url,
invocation.server_label,
invocation.connection_name,
invocation.headers,
)
# Phase 1: check the cache and either claim creation or wait for an
# already in-flight creation.
creating = False
async with self._cache_lock:
if self._closed:
raise RuntimeError("DefaultMCPToolHandler is closed")
existing = self._cache.get(key)
if existing is not None:
self._cache.move_to_end(key)
return existing
inflight = self._inflight.get(key)
if inflight is None:
inflight = asyncio.get_running_loop().create_future()
self._inflight[key] = inflight
creating = True
if not creating:
return await inflight
# Phase 2: we own creation. Build the entry outside the lock.
try:
entry = await self._create_entry(invocation)
except BaseException as exc:
async with self._cache_lock:
self._inflight.pop(key, None)
if not inflight.done():
inflight.set_exception(exc if isinstance(exc, BaseException) else RuntimeError(str(exc)))
# Mark the exception retrieved to suppress noisy "Future exception
# was never retrieved" warnings when there are no other awaiters
# (other awaiters still see the exception through their ``await``).
inflight.exception()
raise
# Phase 3: insert with LRU eviction; resolve the in-flight future.
# If ``aclose`` ran while we were connecting, ``_closed`` is now
# True; don't insert into the cache (it has been drained), close
# the just-built entry, and surface the closed-handler error to
# all awaiters of the future.
evicted: _CacheEntry | None = None
duplicate: _CacheEntry | None = None
handler_closed = False
async with self._cache_lock:
self._inflight.pop(key, None)
if self._closed:
handler_closed = True
else:
existing = self._cache.get(key)
if existing is not None:
# Another writer beat us; prefer the existing entry and
# discard ours after the lock is released.
self._cache.move_to_end(key)
duplicate = entry
entry = existing
else:
self._cache[key] = entry
self._cache.move_to_end(key)
if len(self._cache) > self._cache_max_size:
_evicted_key, evicted = self._cache.popitem(last=False)
if not inflight.done():
inflight.set_result(entry)
if handler_closed:
# Close our orphaned entry; resolve the future with a clear
# error so the caller (and any other awaiters) surface a
# consistent "handler is closed" failure rather than receiving
# an entry we are about to close behind their back.
await self._close_entry(entry)
err = RuntimeError("DefaultMCPToolHandler is closed")
if not inflight.done():
inflight.set_exception(err)
inflight.exception()
raise err
if duplicate is not None:
await self._close_entry(duplicate)
if evicted is not None:
await self._close_entry(evicted)
return entry
async def _create_entry(self, invocation: MCPToolInvocation) -> _CacheEntry:
"""Construct (and connect) a fresh MCP client for ``invocation``."""
from agent_framework import MCPStreamableHTTPTool
provided_client: httpx.AsyncClient | None = None
if self._client_provider is not None:
provided_client = await self._client_provider(invocation)
# Capture headers for this cache entry so the header_provider closure
# always returns the same set, regardless of the runtime kwargs.
captured_headers = dict(invocation.headers)
def _header_provider(_kwargs: dict[str, Any]) -> dict[str, str]:
return captured_headers
tool: Any = MCPStreamableHTTPTool(
name=invocation.server_label or "McpClient",
url=invocation.server_url,
load_prompts=False,
http_client=provided_client,
header_provider=_header_provider if captured_headers else None,
)
try:
await tool.connect()
except BaseException:
try:
await tool.close()
except Exception: # pragma: no cover - best effort
logger.debug("DefaultMCPToolHandler: error closing tool after failed connect", exc_info=True)
raise
# ``MCPStreamableHTTPTool.get_mcp_client`` lazily creates an
# ``httpx.AsyncClient`` when no caller client was provided AND a
# ``header_provider`` was set. We treat any client allocated this
# way as owned (closed by the handler). When the caller supplies
# one, we never close it.
owned_client: httpx.AsyncClient | None = None
if provided_client is None:
owned_client = cast("httpx.AsyncClient | None", getattr(tool, "_httpx_client", None))
return _CacheEntry(tool=tool, owned_httpx_client=owned_client)
async def _close_entry(self, entry: _CacheEntry) -> None:
"""Close the MCP tool and any owned httpx client."""
try:
await entry.tool.close()
except Exception: # pragma: no cover - best effort
logger.debug("DefaultMCPToolHandler: error closing MCP tool", exc_info=True)
if entry.owned_httpx_client is not None:
try:
await entry.owned_httpx_client.aclose()
except Exception: # pragma: no cover - best effort
logger.debug("DefaultMCPToolHandler: error closing owned httpx client", exc_info=True)
@staticmethod
def _cache_key(
server_url: str,
server_label: str | None,
connection_name: str | None,
headers: dict[str, str] | None,
) -> tuple[str, str, str, str]:
"""Build an order-independent cache key for the invocation identity.
The key includes ``server_label`` and ``connection_name`` so that
callers using ``client_provider`` to dispatch on those fields
receive a fresh client per logical connection (matches the
documented dispatch contract).
Header *names* are lower-cased inside the hash payload only so
that ``Authorization`` and ``authorization`` map to the same
cache entry. Header values remain case-sensitive (per RFC 7235).
"""
if not headers:
headers_hash = "0"
else:
normalized = sorted((k.lower(), v) for k, v in headers.items())
payload = json.dumps(normalized, ensure_ascii=False)
headers_hash = hashlib.sha256(payload.encode("utf-8")).hexdigest()
return (server_url, server_label or "", connection_name or "", headers_hash)
@@ -0,0 +1,543 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tests for ``DefaultMCPToolHandler``.
These tests exercise the real handler against a fake ``MCPStreamableHTTPTool``
(no real MCP server, no real network) to cover the parts of the handler not
exercisable through the executor stub: cache hit/miss/eviction, concurrent
connect via in-flight futures, header isolation across cache keys,
string-result normalisation, ``load_prompts=False`` verification, and
owned-vs-caller httpx close semantics.
"""
from __future__ import annotations
import asyncio
import sys
from typing import Any
from unittest.mock import patch
import httpx
import pytest
from agent_framework import Content
from agent_framework.exceptions import ToolExecutionException
from agent_framework_declarative._workflows._mcp_handler import (
DefaultMCPToolHandler,
MCPToolInvocation,
)
pytestmark = pytest.mark.skipif(
sys.version_info >= (3, 14),
reason="Skipped on Python 3.14+ to keep parity with rest of declarative suite",
)
class FakeTool:
"""Stand-in for ``MCPStreamableHTTPTool``.
Records constructor kwargs, tracks connect/close lifecycle, and dispatches
``call_tool`` to a per-instance handler.
"""
instances: list[FakeTool] = []
def __init__(self, **kwargs: Any) -> None:
self.kwargs = kwargs
self.connect_count = 0
self.close_count = 0
self.connect_delay: float = 0.0
self.connect_error: BaseException | None = None
self.call_handler: Any = lambda **_a: [Content.from_text("ok")]
self._httpx_client: httpx.AsyncClient | None = None
# Mimic MCPStreamableHTTPTool: when no caller client AND header_provider
# is set, lazily allocate an owned httpx client during connect.
FakeTool.instances.append(self)
async def connect(self) -> None:
if self.connect_delay:
await asyncio.sleep(self.connect_delay)
if self.connect_error is not None:
raise self.connect_error
self.connect_count += 1
# Mimic lazy httpx allocation when no client provided AND header_provider set.
if self.kwargs.get("http_client") is None and self.kwargs.get("header_provider") is not None:
self._httpx_client = httpx.AsyncClient()
async def close(self) -> None:
self.close_count += 1
async def call_tool(self, tool_name: str, **arguments: Any) -> Any:
return self.call_handler(tool_name=tool_name, **arguments)
@pytest.fixture(autouse=True)
def _clear_fake_instances() -> None:
FakeTool.instances.clear()
def _patch_tool() -> Any:
"""Patch the lazy import inside ``_create_entry`` to substitute FakeTool."""
import agent_framework
return patch.object(agent_framework, "MCPStreamableHTTPTool", FakeTool)
def _invocation(
*, server_url: str = "https://mcp.example/api", tool_name: str = "search", **overrides: Any
) -> MCPToolInvocation:
return MCPToolInvocation(
server_url=server_url,
tool_name=tool_name,
**overrides,
)
# ---------- Construction ---------------------------------------------------
class TestConstruction:
def test_invalid_cache_size_raises(self) -> None:
with pytest.raises(ValueError):
DefaultMCPToolHandler(cache_max_size=0)
with pytest.raises(ValueError):
DefaultMCPToolHandler(cache_max_size=-3)
# ---------- Tool kwargs ----------------------------------------------------
class TestToolKwargs:
@pytest.mark.asyncio
async def test_load_prompts_false_passed_to_tool(self) -> None:
handler = DefaultMCPToolHandler()
with _patch_tool():
await handler.invoke_tool(_invocation())
assert len(FakeTool.instances) == 1
assert FakeTool.instances[0].kwargs["load_prompts"] is False
@pytest.mark.asyncio
async def test_server_label_used_as_tool_name(self) -> None:
handler = DefaultMCPToolHandler()
with _patch_tool():
await handler.invoke_tool(_invocation(server_label="MyMcp"))
assert FakeTool.instances[0].kwargs["name"] == "MyMcp"
@pytest.mark.asyncio
async def test_default_tool_name_when_no_label(self) -> None:
handler = DefaultMCPToolHandler()
with _patch_tool():
await handler.invoke_tool(_invocation(server_label=None))
assert FakeTool.instances[0].kwargs["name"] == "McpClient"
@pytest.mark.asyncio
async def test_no_header_provider_when_no_headers(self) -> None:
handler = DefaultMCPToolHandler()
with _patch_tool():
await handler.invoke_tool(_invocation(headers={}))
assert FakeTool.instances[0].kwargs["header_provider"] is None
@pytest.mark.asyncio
async def test_header_provider_returns_captured_headers(self) -> None:
handler = DefaultMCPToolHandler()
with _patch_tool():
await handler.invoke_tool(_invocation(headers={"Authorization": "Bearer T"}))
provider = FakeTool.instances[0].kwargs["header_provider"]
assert provider({}) == {"Authorization": "Bearer T"}
# Even if runtime kwargs change, captured headers stay the same.
assert provider({"foo": "bar"}) == {"Authorization": "Bearer T"}
# ---------- Cache behaviour ------------------------------------------------
class TestCache:
@pytest.mark.asyncio
async def test_same_url_and_headers_hit_cache(self) -> None:
handler = DefaultMCPToolHandler()
with _patch_tool():
await handler.invoke_tool(_invocation(headers={"X": "1"}))
await handler.invoke_tool(_invocation(headers={"X": "1"}))
# One tool created, connect called once.
assert len(FakeTool.instances) == 1
assert FakeTool.instances[0].connect_count == 1
@pytest.mark.asyncio
async def test_different_headers_create_separate_entries(self) -> None:
handler = DefaultMCPToolHandler()
with _patch_tool():
await handler.invoke_tool(_invocation(headers={"Authorization": "tk-A"}))
await handler.invoke_tool(_invocation(headers={"Authorization": "tk-B"}))
assert len(FakeTool.instances) == 2
@pytest.mark.asyncio
async def test_different_urls_create_separate_entries(self) -> None:
handler = DefaultMCPToolHandler()
with _patch_tool():
await handler.invoke_tool(_invocation(server_url="https://mcp.a/api"))
await handler.invoke_tool(_invocation(server_url="https://mcp.b/api"))
assert len(FakeTool.instances) == 2
@pytest.mark.asyncio
async def test_lru_eviction_closes_old_entry(self) -> None:
handler = DefaultMCPToolHandler(cache_max_size=2)
with _patch_tool():
await handler.invoke_tool(_invocation(server_url="https://a/"))
await handler.invoke_tool(_invocation(server_url="https://b/"))
# Inserting a third evicts the LRU entry (the first one).
await handler.invoke_tool(_invocation(server_url="https://c/"))
assert len(FakeTool.instances) == 3
# First instance (https://a/) was evicted → close() called.
assert FakeTool.instances[0].kwargs["url"] == "https://a/"
assert FakeTool.instances[0].close_count == 1
# Other two remain in cache → not closed.
assert FakeTool.instances[1].close_count == 0
assert FakeTool.instances[2].close_count == 0
@pytest.mark.asyncio
async def test_repeated_use_keeps_lru_alive(self) -> None:
handler = DefaultMCPToolHandler(cache_max_size=2)
with _patch_tool():
await handler.invoke_tool(_invocation(server_url="https://a/"))
await handler.invoke_tool(_invocation(server_url="https://b/"))
# Touch a → b becomes LRU.
await handler.invoke_tool(_invocation(server_url="https://a/"))
# Insert c → b is evicted.
await handler.invoke_tool(_invocation(server_url="https://c/"))
# b was evicted.
b = FakeTool.instances[1]
assert b.kwargs["url"] == "https://b/"
assert b.close_count == 1
# a survived.
a = FakeTool.instances[0]
assert a.kwargs["url"] == "https://a/"
assert a.close_count == 0
@pytest.mark.asyncio
async def test_concurrent_connect_shares_one_entry(self) -> None:
"""Multiple concurrent invocations with the same key must share one tool."""
handler = DefaultMCPToolHandler()
# Slow down connect so concurrency window is observable.
original_connect = FakeTool.connect
async def slow_connect(self: FakeTool) -> None:
self.connect_delay = 0.05
await original_connect(self)
with _patch_tool(), patch.object(FakeTool, "connect", slow_connect):
results = await asyncio.gather(
handler.invoke_tool(_invocation(headers={"X": "1"})),
handler.invoke_tool(_invocation(headers={"X": "1"})),
handler.invoke_tool(_invocation(headers={"X": "1"})),
handler.invoke_tool(_invocation(headers={"X": "1"})),
)
assert all(not r.is_error for r in results)
# Only one tool was created and connected, despite 4 concurrent calls.
assert len(FakeTool.instances) == 1
assert FakeTool.instances[0].connect_count == 1
@pytest.mark.asyncio
async def test_different_connection_names_create_separate_entries(self) -> None:
"""Same URL/headers but different ``connection_name`` must dispatch separately."""
handler = DefaultMCPToolHandler()
with _patch_tool():
await handler.invoke_tool(_invocation(connection_name="conn-A"))
await handler.invoke_tool(_invocation(connection_name="conn-B"))
assert len(FakeTool.instances) == 2
@pytest.mark.asyncio
async def test_different_server_labels_create_separate_entries(self) -> None:
"""Same URL/headers but different ``server_label`` must dispatch separately."""
handler = DefaultMCPToolHandler()
with _patch_tool():
await handler.invoke_tool(_invocation(server_label="LabelA"))
await handler.invoke_tool(_invocation(server_label="LabelB"))
assert len(FakeTool.instances) == 2
@pytest.mark.asyncio
async def test_full_identity_match_hits_cache(self) -> None:
"""All four identity components match → single cached entry."""
handler = DefaultMCPToolHandler()
with _patch_tool():
await handler.invoke_tool(_invocation(server_label="Lbl", connection_name="C", headers={"X": "1"}))
await handler.invoke_tool(_invocation(server_label="Lbl", connection_name="C", headers={"X": "1"}))
assert len(FakeTool.instances) == 1
assert FakeTool.instances[0].connect_count == 1
@pytest.mark.asyncio
async def test_header_name_case_collapses_to_one_cache_entry(self) -> None:
"""Header name spelling differences (case-only) must share a cache entry."""
handler = DefaultMCPToolHandler()
with _patch_tool():
await handler.invoke_tool(_invocation(headers={"Authorization": "tk"}))
await handler.invoke_tool(_invocation(headers={"authorization": "tk"}))
await handler.invoke_tool(_invocation(headers={"AUTHORIZATION": "tk"}))
assert len(FakeTool.instances) == 1
assert FakeTool.instances[0].connect_count == 1
@pytest.mark.asyncio
async def test_header_value_case_does_not_collapse(self) -> None:
"""Header *values* remain case-sensitive (different tokens → different sessions)."""
handler = DefaultMCPToolHandler()
with _patch_tool():
await handler.invoke_tool(_invocation(headers={"Authorization": "Bearer-A"}))
await handler.invoke_tool(_invocation(headers={"Authorization": "bearer-a"}))
assert len(FakeTool.instances) == 2
# ---------- Aclose semantics ----------------------------------------------
class TestAclose:
@pytest.mark.asyncio
async def test_aclose_closes_owned_clients(self) -> None:
handler = DefaultMCPToolHandler()
with _patch_tool():
await handler.invoke_tool(_invocation(headers={"X": "1"}))
tool = FakeTool.instances[0]
owned = tool._httpx_client
assert owned is not None
await handler.aclose()
assert tool.close_count == 1
assert owned.is_closed
@pytest.mark.asyncio
async def test_aclose_does_not_close_caller_supplied_client(self) -> None:
caller_client = httpx.AsyncClient()
async def provider(_inv: MCPToolInvocation) -> httpx.AsyncClient:
return caller_client
handler = DefaultMCPToolHandler(client_provider=provider)
try:
with _patch_tool():
await handler.invoke_tool(_invocation(headers={"X": "1"}))
await handler.aclose()
assert FakeTool.instances[0].close_count == 1
# Caller client must still be usable.
assert not caller_client.is_closed
finally:
await caller_client.aclose()
@pytest.mark.asyncio
async def test_async_context_manager(self) -> None:
with _patch_tool():
async with DefaultMCPToolHandler() as handler:
await handler.invoke_tool(_invocation())
tool = FakeTool.instances[0]
assert tool.close_count == 1
@pytest.mark.asyncio
async def test_aclose_is_idempotent(self) -> None:
"""A second ``aclose`` is a no-op (no exception, no double-close)."""
handler = DefaultMCPToolHandler()
with _patch_tool():
await handler.invoke_tool(_invocation(headers={"X": "1"}))
await handler.aclose()
await handler.aclose()
assert FakeTool.instances[0].close_count == 1
@pytest.mark.asyncio
async def test_invoke_after_close_returns_error_result(self) -> None:
"""Post-close ``invoke_tool`` surfaces a tool error rather than crashing."""
handler = DefaultMCPToolHandler()
with _patch_tool():
await handler.aclose()
result = await handler.invoke_tool(_invocation())
assert result.is_error is True
assert "closed" in (result.error_message or "").lower()
@pytest.mark.asyncio
async def test_aclose_drains_inflight_creation(self) -> None:
"""An in-flight ``_create_entry`` must not leak when ``aclose`` races with it.
Reproduces the race described in PR #5630 review-comment 3:
task A claims an inflight future and starts a slow connect; task B
runs ``aclose``; task A must self-clean (close its tool + httpx
client) and surface a closed-handler error rather than orphaning
the entry.
"""
handler = DefaultMCPToolHandler()
connect_started = asyncio.Event()
release_connect = asyncio.Event()
original_connect = FakeTool.connect
async def gated_connect(self: FakeTool) -> None:
connect_started.set()
await release_connect.wait()
await original_connect(self)
with _patch_tool(), patch.object(FakeTool, "connect", gated_connect):
invoke_task = asyncio.create_task(handler.invoke_tool(_invocation(headers={"X": "1"})))
# Wait until task A is mid-connect.
await connect_started.wait()
# Race: kick off aclose. It must wait for the in-flight task.
close_task = asyncio.create_task(handler.aclose())
# Yield once to ensure aclose has set _closed and is awaiting.
await asyncio.sleep(0)
# Allow the connect to complete; phase 3 sees _closed and self-cleans.
release_connect.set()
result = await invoke_task
await close_task
# Entry was created and then closed by the in-flight task itself.
assert len(FakeTool.instances) == 1
assert FakeTool.instances[0].close_count == 1
# The originating invocation surfaces a closed-handler error.
assert result.is_error is True
assert "closed" in (result.error_message or "").lower()
# ---------- Result normalisation ------------------------------------------
class TestResultNormalisation:
@pytest.mark.asyncio
async def test_string_result_wrapped_in_text_content(self) -> None:
handler = DefaultMCPToolHandler()
with _patch_tool():
inv = _invocation()
result = await handler.invoke_tool(inv)
# The fake's default already returns a list; replace handler for this test.
FakeTool.instances[0].call_handler = lambda **_a: "raw string body"
result = await handler.invoke_tool(inv)
assert result.is_error is False
assert len(result.outputs) == 1
assert result.outputs[0].text == "raw string body" # type: ignore[reportAttributeAccessIssue]
@pytest.mark.asyncio
async def test_list_result_passed_through(self) -> None:
handler = DefaultMCPToolHandler()
custom = [Content.from_text("a"), Content.from_text("b")]
with _patch_tool():
inv = _invocation()
await handler.invoke_tool(inv)
FakeTool.instances[0].call_handler = lambda **_a: custom
result = await handler.invoke_tool(inv)
assert result.is_error is False
assert len(result.outputs) == 2
# ---------- Error mapping --------------------------------------------------
class TestErrorMapping:
@pytest.mark.asyncio
async def test_tool_execution_exception_returns_error_result(self) -> None:
handler = DefaultMCPToolHandler()
def boom(**_a: Any) -> Any:
raise ToolExecutionException("server says no")
with _patch_tool():
inv = _invocation()
await handler.invoke_tool(inv)
FakeTool.instances[0].call_handler = boom
result = await handler.invoke_tool(inv)
assert result.is_error is True
assert result.error_message == "server says no"
assert result.outputs[0].text.startswith("Error:") # type: ignore[reportAttributeAccessIssue]
@pytest.mark.asyncio
async def test_httpx_error_returns_error_result(self) -> None:
handler = DefaultMCPToolHandler()
def boom(**_a: Any) -> Any:
raise httpx.ConnectError("dns failure")
with _patch_tool():
inv = _invocation()
await handler.invoke_tool(inv)
FakeTool.instances[0].call_handler = boom
result = await handler.invoke_tool(inv)
assert result.is_error is True
assert "dns failure" in (result.error_message or "")
@pytest.mark.asyncio
async def test_unexpected_exception_propagates(self) -> None:
"""RuntimeError (not in the narrow catch list) must propagate."""
handler = DefaultMCPToolHandler()
def boom(**_a: Any) -> Any:
raise RuntimeError("programmer error")
with _patch_tool():
inv = _invocation()
await handler.invoke_tool(inv)
FakeTool.instances[0].call_handler = boom
with pytest.raises(RuntimeError, match="programmer error"):
await handler.invoke_tool(inv)
@pytest.mark.asyncio
async def test_connect_failure_returns_error_result(self) -> None:
handler = DefaultMCPToolHandler()
with (
_patch_tool(),
patch.object(
FakeTool,
"connect",
lambda self: (_ for _ in ()).throw(httpx.ConnectError("server down")),
),
):
result = await handler.invoke_tool(_invocation())
assert result.is_error is True
assert result.outputs[0].text.startswith("Error:") # type: ignore[reportAttributeAccessIssue]
# Failed connect must clear in-flight + cache entries.
assert handler._inflight == {}
assert len(handler._cache) == 0
@pytest.mark.asyncio
async def test_cancelled_error_propagates(self) -> None:
"""asyncio.CancelledError is BaseException, must NOT be swallowed."""
handler = DefaultMCPToolHandler()
def boom(**_a: Any) -> Any:
raise asyncio.CancelledError
with _patch_tool():
inv = _invocation()
await handler.invoke_tool(inv)
FakeTool.instances[0].call_handler = boom
with pytest.raises(asyncio.CancelledError):
await handler.invoke_tool(inv)
# ---------- Cache key isolation -------------------------------------------
class TestCacheKey:
def test_key_order_independent(self) -> None:
k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"A": "1", "B": "2"})
k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"B": "2", "A": "1"})
assert k1 == k2
def test_key_distinguishes_values(self) -> None:
k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"A": "1"})
k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"A": "2"})
assert k1 != k2
def test_empty_headers_use_fixed_hash(self) -> None:
k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, None)
k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {})
assert k1 == k2
def test_key_distinguishes_connection_name(self) -> None:
k1 = DefaultMCPToolHandler._cache_key("https://x/", None, "conn-A", None)
k2 = DefaultMCPToolHandler._cache_key("https://x/", None, "conn-B", None)
assert k1 != k2
def test_key_distinguishes_server_label(self) -> None:
k1 = DefaultMCPToolHandler._cache_key("https://x/", "Lbl-A", None, None)
k2 = DefaultMCPToolHandler._cache_key("https://x/", "Lbl-B", None, None)
assert k1 != k2
def test_key_collapses_header_name_case(self) -> None:
k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"Authorization": "tk"})
k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"authorization": "tk"})
assert k1 == k2
def test_key_keeps_header_value_case(self) -> None:
k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"X": "Bearer-A"})
k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"X": "bearer-a"})
assert k1 != k2
@@ -0,0 +1,664 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tests for ``InvokeMcpToolActionExecutor``.
Use a stub :class:`MCPToolHandler` that returns canned :class:`MCPToolResult`s.
No real MCP server or network is exercised. See
``test_default_mcp_tool_handler.py`` for tests that exercise the real
``DefaultMCPToolHandler`` against a mocked ``MCPStreamableHTTPTool``.
"""
import sys
from typing import Any
import httpx
import pytest
try:
import powerfx # noqa: F401
_powerfx_available = True
except (ImportError, RuntimeError):
_powerfx_available = False
pytestmark = pytest.mark.skipif(
not _powerfx_available or sys.version_info >= (3, 14),
reason="PowerFx engine not available (requires dotnet runtime)",
)
from agent_framework import Content, Message # noqa: E402
from agent_framework.exceptions import ToolExecutionException # noqa: E402
from agent_framework_declarative._workflows import ( # noqa: E402
DECLARATIVE_STATE_KEY,
DeclarativeWorkflowError,
MCPToolHandler,
MCPToolInvocation,
MCPToolResult,
WorkflowFactory,
)
class StubMcpHandler:
"""Test stub recording the last call and returning a canned result."""
def __init__(
self,
result: MCPToolResult | None = None,
*,
raise_exc: BaseException | None = None,
) -> None:
self.result = result
self.raise_exc = raise_exc
self.last_invocation: MCPToolInvocation | None = None
self.invocations: list[MCPToolInvocation] = []
self.call_count = 0
async def invoke_tool(self, invocation: MCPToolInvocation) -> MCPToolResult:
self.call_count += 1
self.last_invocation = invocation
self.invocations.append(invocation)
if self.raise_exc is not None:
raise self.raise_exc
assert self.result is not None
return self.result
def _ok(outputs: list[Content] | None = None) -> MCPToolResult:
return MCPToolResult(outputs=outputs or [Content.from_text("hello")])
def _err(message: str = "boom") -> MCPToolResult:
return MCPToolResult(
outputs=[Content.from_text(f"Error: {message}")],
is_error=True,
error_message=message,
)
def _action(
*,
server_url: str = "https://mcp.example/api",
tool_name: str = "search",
server_label: str | None = None,
arguments: dict[str, Any] | None = None,
headers: dict[str, Any] | None = None,
require_approval: Any = None,
connection: dict[str, Any] | None = None,
conversation_id: str | None = None,
output: dict[str, Any] | None = None,
) -> dict[str, Any]:
action: dict[str, Any] = {
"kind": "InvokeMcpTool",
"id": "mcp_action",
"serverUrl": server_url,
"toolName": tool_name,
}
if server_label is not None:
action["serverLabel"] = server_label
if arguments is not None:
action["arguments"] = arguments
if headers is not None:
action["headers"] = headers
if require_approval is not None:
action["requireApproval"] = require_approval
if connection is not None:
action["connection"] = connection
if conversation_id is not None:
action["conversationId"] = conversation_id
if output is not None:
action["output"] = output
return action
def _yaml(action: dict[str, Any]) -> dict[str, Any]:
return {"name": "mcp_test", "actions": [action]}
# ---------- Builder enforcement --------------------------------------------
class TestBuilderEnforcement:
def test_missing_handler_raises_at_build_time(self) -> None:
factory = WorkflowFactory()
with pytest.raises(DeclarativeWorkflowError) as excinfo:
factory.create_workflow_from_definition(_yaml(_action()))
assert "InvokeMcpTool" in str(excinfo.value)
assert "mcp_tool_handler" in str(excinfo.value)
def test_missing_server_url_fails_validation(self) -> None:
handler = StubMcpHandler(_ok())
factory = WorkflowFactory(mcp_tool_handler=handler)
action = _action()
del action["serverUrl"]
with pytest.raises(Exception) as excinfo:
factory.create_workflow_from_definition(_yaml(action))
assert "serverUrl" in str(excinfo.value)
def test_missing_tool_name_fails_validation(self) -> None:
handler = StubMcpHandler(_ok())
factory = WorkflowFactory(mcp_tool_handler=handler)
action = _action()
del action["toolName"]
with pytest.raises(Exception) as excinfo:
factory.create_workflow_from_definition(_yaml(action))
assert "toolName" in str(excinfo.value)
# ---------- Field forwarding ----------------------------------------------
class TestFieldForwarding:
@pytest.mark.asyncio
async def test_basic_invocation_forwards_required_fields(self) -> None:
handler = StubMcpHandler(_ok())
factory = WorkflowFactory(mcp_tool_handler=handler)
workflow = factory.create_workflow_from_definition(_yaml(_action()))
await workflow.run({})
assert handler.call_count == 1
inv = handler.last_invocation
assert inv is not None
assert inv.server_url == "https://mcp.example/api"
assert inv.tool_name == "search"
assert inv.server_label is None
assert inv.headers == {}
assert inv.arguments == {}
assert inv.connection_name is None
@pytest.mark.asyncio
async def test_arguments_evaluated_and_preserves_none(self) -> None:
handler = StubMcpHandler(_ok())
factory = WorkflowFactory(mcp_tool_handler=handler)
workflow = factory.create_workflow_from_definition(
_yaml(
_action(
arguments={
"query": "weather today",
"limit": 5,
"fresh": True,
"missing": None,
}
)
)
)
await workflow.run({})
inv = handler.last_invocation
assert inv is not None
# ``None`` is preserved (parity with .NET) — caller decides.
assert inv.arguments == {
"query": "weather today",
"limit": 5,
"fresh": True,
"missing": None,
}
@pytest.mark.asyncio
async def test_headers_drop_empty_values(self) -> None:
handler = StubMcpHandler(_ok())
factory = WorkflowFactory(mcp_tool_handler=handler)
workflow = factory.create_workflow_from_definition(
_yaml(
_action(
headers={
"Authorization": "Bearer token-123",
"X-Trace": "trace-id",
"X-Empty": "",
}
)
)
)
await workflow.run({})
inv = handler.last_invocation
assert inv is not None
assert inv.headers == {
"Authorization": "Bearer token-123",
"X-Trace": "trace-id",
}
@pytest.mark.asyncio
async def test_server_label_and_connection_name_forwarded(self) -> None:
handler = StubMcpHandler(_ok())
factory = WorkflowFactory(mcp_tool_handler=handler)
workflow = factory.create_workflow_from_definition(
_yaml(
_action(
server_label="docs-mcp",
connection={"name": "azure-conn"},
)
)
)
await workflow.run({})
inv = handler.last_invocation
assert inv is not None
assert inv.server_label == "docs-mcp"
assert inv.connection_name == "azure-conn"
# ---------- Output handling ------------------------------------------------
class TestOutput:
@pytest.mark.asyncio
async def test_output_result_parses_json_text(self) -> None:
handler = StubMcpHandler(_ok([Content.from_text('{"k":"v","n":1}')]))
factory = WorkflowFactory(mcp_tool_handler=handler)
workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"})))
await workflow.run({})
decl = workflow._state.get(DECLARATIVE_STATE_KEY)
assert decl["Local"]["Result"] == [{"k": "v", "n": 1}]
@pytest.mark.asyncio
async def test_output_result_falls_back_to_raw_text(self) -> None:
handler = StubMcpHandler(_ok([Content.from_text("plain text not json")]))
factory = WorkflowFactory(mcp_tool_handler=handler)
workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"})))
await workflow.run({})
decl = workflow._state.get(DECLARATIVE_STATE_KEY)
assert decl["Local"]["Result"] == ["plain text not json"]
@pytest.mark.asyncio
async def test_output_messages_writes_single_tool_role_message(self) -> None:
handler = StubMcpHandler(_ok([Content.from_text("hi"), Content.from_text("there")]))
factory = WorkflowFactory(mcp_tool_handler=handler)
workflow = factory.create_workflow_from_definition(_yaml(_action(output={"messages": "Local.Messages"})))
await workflow.run({})
decl = workflow._state.get(DECLARATIVE_STATE_KEY)
msg = decl["Local"]["Messages"]
# Single Tool-role message containing both contents (parity with .NET).
assert isinstance(msg, Message)
assert str(msg.role).lower() == "tool"
assert len(msg.contents) == 2
@pytest.mark.asyncio
async def test_uri_content_serialised_as_uri_string(self) -> None:
uri_content = Content.from_uri("https://example.com/file.txt", media_type="text/plain")
handler = StubMcpHandler(_ok([uri_content]))
factory = WorkflowFactory(mcp_tool_handler=handler)
workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"})))
await workflow.run({})
decl = workflow._state.get(DECLARATIVE_STATE_KEY)
assert decl["Local"]["Result"] == ["https://example.com/file.txt"]
@pytest.mark.asyncio
async def test_output_path_object_form(self) -> None:
handler = StubMcpHandler(_ok([Content.from_text("ok")]))
factory = WorkflowFactory(mcp_tool_handler=handler)
workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": {"path": "Local.Result"}})))
await workflow.run({})
decl = workflow._state.get(DECLARATIVE_STATE_KEY)
assert decl["Local"]["Result"] == ["ok"]
# ---------- Conversation append --------------------------------------------
class TestConversation:
@pytest.mark.asyncio
async def test_conversation_id_appends_assistant_message(self) -> None:
handler = StubMcpHandler(_ok([Content.from_text("answer")]))
factory = WorkflowFactory(mcp_tool_handler=handler)
workflow = factory.create_workflow_from_definition(
_yaml(
_action(
conversation_id="conv-42",
output={"result": "Local.Result"},
)
)
)
await workflow.run({})
decl = workflow._state.get(DECLARATIVE_STATE_KEY)
conv = decl["System"]["conversations"]["conv-42"]
msgs = conv["messages"] if isinstance(conv, dict) else conv.messages
assert len(msgs) == 1
appended = msgs[0]
assert str(appended.role).lower() == "assistant"
# Same contents as the tool output.
assert len(appended.contents) == 1
@pytest.mark.asyncio
async def test_empty_conversation_id_does_not_append(self) -> None:
handler = StubMcpHandler(_ok([Content.from_text("answer")]))
factory = WorkflowFactory(mcp_tool_handler=handler)
workflow = factory.create_workflow_from_definition(
_yaml(
_action(
conversation_id="",
output={"result": "Local.Result"},
)
)
)
await workflow.run({})
decl = workflow._state.get(DECLARATIVE_STATE_KEY)
# Empty conversation id must not produce a `""` entry under System.conversations.
conversations = decl.get("System", {}).get("conversations", {})
assert "" not in conversations
# ---------- Approval flow --------------------------------------------------
@pytest.fixture
def mock_state(): # type: ignore[no-untyped-def]
from unittest.mock import MagicMock
state = MagicMock()
state._data = {}
def _get(key: str, default: Any = None) -> Any:
if key not in state._data:
if default is not None:
return default
raise KeyError(key)
return state._data[key]
def _set(key: str, value: Any) -> None:
state._data[key] = value
def _delete(key: str) -> None:
if key in state._data:
del state._data[key]
else:
raise KeyError(key)
state.get = MagicMock(side_effect=_get)
state.set = MagicMock(side_effect=_set)
state.delete = MagicMock(side_effect=_delete)
return state
@pytest.fixture
def mock_context(mock_state): # type: ignore[no-untyped-def]
from unittest.mock import AsyncMock, MagicMock
ctx = MagicMock()
ctx.state = mock_state
ctx.send_message = AsyncMock()
ctx.yield_output = AsyncMock()
ctx.request_info = AsyncMock()
return ctx
def _seed_state(mock_state) -> None: # type: ignore[no-untyped-def]
"""Pre-seed the declarative state container as the executors expect."""
from agent_framework_declarative._workflows import DECLARATIVE_STATE_KEY
mock_state._data[DECLARATIVE_STATE_KEY] = {
"Local": {},
"Custom": {},
"Workflow": {},
"System": {
"ConversationId": "00000000-0000-0000-0000-000000000000",
"LastMessage": {"Id": "", "Text": ""},
"LastMessageText": "",
"LastMessageId": "",
},
"Agent": {},
"Conversation": {"messages": [], "history": []},
"Inputs": {},
}
class TestApprovalFlow:
@pytest.mark.asyncio
async def test_approval_required_emits_request_and_yields(self, mock_state, mock_context) -> None: # type: ignore[no-untyped-def]
from agent_framework_declarative._workflows._declarative_base import ActionTrigger
from agent_framework_declarative._workflows._executors_mcp import (
_MCP_APPROVAL_STATE_KEY,
InvokeMcpToolActionExecutor,
MCPToolApprovalRequest,
)
_seed_state(mock_state)
handler = StubMcpHandler(_ok())
executor = InvokeMcpToolActionExecutor(
_action(
require_approval=True,
arguments={"q": "x"},
headers={"Authorization": "Bearer SECRET"},
output={"result": "Local.Result"},
),
mcp_tool_handler=handler,
)
await executor.handle_action(ActionTrigger(), mock_context)
# Approval request emitted.
mock_context.request_info.assert_called_once()
request = mock_context.request_info.call_args[0][0]
assert isinstance(request, MCPToolApprovalRequest)
assert request.tool_name == "search"
assert request.arguments == {"q": "x"}
assert request.header_names == ["Authorization"]
# NEVER expose the actual auth token in any field of the approval payload.
for value in request.__dict__.values():
assert "SECRET" not in str(value)
# Workflow should yield (no ActionComplete sent yet).
mock_context.send_message.assert_not_called()
# Handler not invoked yet.
assert handler.call_count == 0
# Approval state stored.
approval_key = f"{_MCP_APPROVAL_STATE_KEY}_mcp_action"
assert approval_key in mock_state._data
@pytest.mark.asyncio
async def test_approval_response_approved_invokes_handler(self, mock_state, mock_context) -> None: # type: ignore[no-untyped-def]
from agent_framework_declarative._workflows import ActionComplete, ToolApprovalResponse
from agent_framework_declarative._workflows._executors_mcp import (
_MCP_APPROVAL_STATE_KEY,
InvokeMcpToolActionExecutor,
MCPToolApprovalRequest,
_MCPToolApprovalState,
)
_seed_state(mock_state)
handler = StubMcpHandler(_ok([Content.from_text('{"ok":true}')]))
executor = InvokeMcpToolActionExecutor(
_action(
require_approval=True,
output={"result": "Local.Result"},
),
mcp_tool_handler=handler,
)
# Pre-populate approval state.
approval_key = f"{_MCP_APPROVAL_STATE_KEY}_mcp_action"
mock_state._data[approval_key] = _MCPToolApprovalState(
server_url="https://mcp.example/api",
tool_name="search",
server_label=None,
arguments={"q": "x"},
connection_name=None,
headers_def={"Authorization": "Bearer tk"},
auto_send=False,
conversation_id_expr=None,
output_messages_path=None,
output_result_path="Local.Result",
)
await executor.handle_approval_response(
MCPToolApprovalRequest(
request_id="req-1",
tool_name="search",
server_url="https://mcp.example/api",
server_label=None,
arguments={"q": "x"},
),
ToolApprovalResponse(approved=True),
mock_context,
)
assert handler.call_count == 1
inv = handler.last_invocation
assert inv is not None
# Headers are re-evaluated from headers_def.
assert inv.headers == {"Authorization": "Bearer tk"}
# Approval state was cleaned up.
assert approval_key not in mock_state._data
# ActionComplete was sent.
mock_context.send_message.assert_called_once()
sent = mock_context.send_message.call_args[0][0]
assert isinstance(sent, ActionComplete)
@pytest.mark.asyncio
async def test_approval_response_rejected_assigns_error(self, mock_state, mock_context) -> None: # type: ignore[no-untyped-def]
from agent_framework_declarative._workflows import ToolApprovalResponse
from agent_framework_declarative._workflows._executors_mcp import (
_MCP_APPROVAL_STATE_KEY,
InvokeMcpToolActionExecutor,
MCPToolApprovalRequest,
_MCPToolApprovalState,
)
_seed_state(mock_state)
handler = StubMcpHandler(_ok())
executor = InvokeMcpToolActionExecutor(
_action(
require_approval=True,
output={"result": "Local.Result"},
),
mcp_tool_handler=handler,
)
approval_key = f"{_MCP_APPROVAL_STATE_KEY}_mcp_action"
mock_state._data[approval_key] = _MCPToolApprovalState(
server_url="https://mcp.example/api",
tool_name="search",
server_label=None,
arguments={},
connection_name=None,
headers_def=None,
auto_send=True,
conversation_id_expr=None,
output_messages_path=None,
output_result_path="Local.Result",
)
await executor.handle_approval_response(
MCPToolApprovalRequest(
request_id="req-2",
tool_name="search",
server_url="https://mcp.example/api",
server_label=None,
arguments={},
),
ToolApprovalResponse(approved=False, reason="not authorized"),
mock_context,
)
assert handler.call_count == 0
# Error string assigned at output.result.
from agent_framework_declarative._workflows import DECLARATIVE_STATE_KEY
result = mock_state._data[DECLARATIVE_STATE_KEY]["Local"]["Result"]
assert result == "Error: MCP tool invocation was not approved by user."
# ---------- Error handling -------------------------------------------------
class TestErrorHandling:
@pytest.mark.asyncio
async def test_handler_returns_error_result_assigns_error_string(self) -> None:
handler = StubMcpHandler(_err("server down"))
factory = WorkflowFactory(mcp_tool_handler=handler)
workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"})))
await workflow.run({})
decl = workflow._state.get(DECLARATIVE_STATE_KEY)
assert decl["Local"]["Result"] == "Error: server down"
@pytest.mark.asyncio
async def test_tool_execution_exception_becomes_error_result(self) -> None:
handler = StubMcpHandler(raise_exc=ToolExecutionException("invalid arguments"))
factory = WorkflowFactory(mcp_tool_handler=handler)
workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"})))
await workflow.run({})
decl = workflow._state.get(DECLARATIVE_STATE_KEY)
assert decl["Local"]["Result"] == "Error: invalid arguments"
@pytest.mark.asyncio
async def test_httpx_error_becomes_error_result(self) -> None:
handler = StubMcpHandler(raise_exc=httpx.ConnectError("dns fail"))
factory = WorkflowFactory(mcp_tool_handler=handler)
workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"})))
await workflow.run({})
decl = workflow._state.get(DECLARATIVE_STATE_KEY)
result = decl["Local"]["Result"]
assert isinstance(result, str)
assert result.startswith("Error:")
assert "ConnectError" in result
@pytest.mark.asyncio
async def test_unexpected_exception_propagates(self) -> None:
"""Programmer bugs (TypeError etc.) must NOT be swallowed."""
handler = StubMcpHandler(raise_exc=TypeError("bad type"))
factory = WorkflowFactory(mcp_tool_handler=handler)
workflow = factory.create_workflow_from_definition(_yaml(_action()))
with pytest.raises(Exception) as excinfo:
await workflow.run({})
# Either the TypeError reaches us or it gets wrapped by the runner —
# either way the message must surface.
assert "bad type" in str(excinfo.value)
# ---------- autoSend -------------------------------------------------------
class TestAutoSend:
@pytest.mark.asyncio
async def test_auto_send_default_true_yields_output(self) -> None:
handler = StubMcpHandler(_ok([Content.from_text("hello")]))
factory = WorkflowFactory(mcp_tool_handler=handler)
workflow = factory.create_workflow_from_definition(_yaml(_action()))
events = await workflow.run({})
outputs = events.get_outputs()
assert len(outputs) == 1
@pytest.mark.asyncio
async def test_auto_send_false_suppresses_yield(self) -> None:
handler = StubMcpHandler(_ok([Content.from_text("hello")]))
factory = WorkflowFactory(mcp_tool_handler=handler)
workflow = factory.create_workflow_from_definition(_yaml(_action(output={"autoSend": False})))
events = await workflow.run({})
outputs = events.get_outputs()
assert outputs == []
# ---------- Protocol structure --------------------------------------------
class TestProtocol:
def test_stub_handler_satisfies_protocol(self) -> None:
handler = StubMcpHandler(_ok())
assert isinstance(handler, MCPToolHandler)
# ---------- _format_outputs_for_send --------------------------------------
class TestFormatOutputsForSend:
"""Direct tests for the auto-send rendering helper.
Regression for PR #5630 review-comment 4: a single scalar JSON value
must render bare (e.g. ``"42"``) rather than wrapped (``"[42]"``).
"""
@pytest.mark.parametrize(
("parsed", "expected"),
[
([], ""),
(["hello"], "hello"),
(["a", "b"], "a\nb"),
([42], "42"),
([3.14], "3.14"),
([True], "true"),
([False], "false"),
([None], "null"),
([{"k": "v"}], '{"k": "v"}'),
([[1, 2]], "[1, 2]"),
(["hello", 42], '["hello", 42]'),
([{"a": 1}, {"b": 2}], '[{"a": 1}, {"b": 2}]'),
],
)
def test_format_outputs_for_send(self, parsed: list[Any], expected: str) -> None:
from agent_framework_declarative._workflows._executors_mcp import _format_outputs_for_send
assert _format_outputs_for_send(parsed) == expected