diff --git a/python/packages/core/agent_framework/declarative/__init__.py b/python/packages/core/agent_framework/declarative/__init__.py index ba88e6a0a9..b5e9c9ef9e 100644 --- a/python/packages/core/agent_framework/declarative/__init__.py +++ b/python/packages/core/agent_framework/declarative/__init__.py @@ -25,13 +25,20 @@ _IMPORTS = [ "DeclarativeLoaderError", "DeclarativeWorkflowError", "DefaultHttpRequestHandler", + "DefaultMCPToolHandler", "ExternalInputRequest", "ExternalInputResponse", "HttpRequestHandler", "HttpRequestInfo", "HttpRequestResult", + "MCPToolApprovalRequest", + "MCPToolHandler", + "MCPToolInvocation", + "MCPToolResult", "ProviderLookupError", "ProviderTypeMapping", + "ToolApprovalRequest", + "ToolApprovalResponse", "WorkflowFactory", "WorkflowState", ] diff --git a/python/packages/core/agent_framework/declarative/__init__.pyi b/python/packages/core/agent_framework/declarative/__init__.pyi index f18be22f50..c64e730441 100644 --- a/python/packages/core/agent_framework/declarative/__init__.pyi +++ b/python/packages/core/agent_framework/declarative/__init__.pyi @@ -8,13 +8,20 @@ from agent_framework_declarative import ( DeclarativeLoaderError, DeclarativeWorkflowError, DefaultHttpRequestHandler, + DefaultMCPToolHandler, ExternalInputRequest, ExternalInputResponse, HttpRequestHandler, HttpRequestInfo, HttpRequestResult, + MCPToolApprovalRequest, + MCPToolHandler, + MCPToolInvocation, + MCPToolResult, ProviderLookupError, ProviderTypeMapping, + ToolApprovalRequest, + ToolApprovalResponse, WorkflowFactory, WorkflowState, ) @@ -27,13 +34,20 @@ __all__ = [ "DeclarativeLoaderError", "DeclarativeWorkflowError", "DefaultHttpRequestHandler", + "DefaultMCPToolHandler", "ExternalInputRequest", "ExternalInputResponse", "HttpRequestHandler", "HttpRequestInfo", "HttpRequestResult", + "MCPToolApprovalRequest", + "MCPToolHandler", + "MCPToolInvocation", + "MCPToolResult", "ProviderLookupError", "ProviderTypeMapping", + "ToolApprovalRequest", + 
"ToolApprovalResponse", "WorkflowFactory", "WorkflowState", ] diff --git a/python/packages/declarative/AGENTS.md b/python/packages/declarative/AGENTS.md index 1add614601..3c9402fc4e 100644 --- a/python/packages/declarative/AGENTS.md +++ b/python/packages/declarative/AGENTS.md @@ -9,6 +9,7 @@ YAML/JSON-based declarative agent and workflow definitions. - **`WorkflowState`** - State management for declarative workflows - **`ProviderTypeMapping`** - Maps provider types to implementations - **`HttpRequestHandler`** / **`DefaultHttpRequestHandler`** - Pluggable HTTP transport for the `HttpRequestAction` declarative action (configured via `WorkflowFactory(http_request_handler=...)`) +- **`MCPToolHandler`** / **`DefaultMCPToolHandler`** - Pluggable MCP transport for the `InvokeMcpTool` declarative action (configured via `WorkflowFactory(mcp_tool_handler=...)`) - **`DeclarativeLoaderError`** / **`ProviderLookupError`** / **`DeclarativeWorkflowError`** / **`DeclarativeActionError`** - Error types ## External Input Handling diff --git a/python/packages/declarative/agent_framework_declarative/__init__.py b/python/packages/declarative/agent_framework_declarative/__init__.py index 6afcb3c791..84bc404d5d 100644 --- a/python/packages/declarative/agent_framework_declarative/__init__.py +++ b/python/packages/declarative/agent_framework_declarative/__init__.py @@ -9,11 +9,18 @@ from ._workflows import ( DeclarativeActionError, DeclarativeWorkflowError, DefaultHttpRequestHandler, + DefaultMCPToolHandler, ExternalInputRequest, ExternalInputResponse, HttpRequestHandler, HttpRequestInfo, HttpRequestResult, + MCPToolApprovalRequest, + MCPToolHandler, + MCPToolInvocation, + MCPToolResult, + ToolApprovalRequest, + ToolApprovalResponse, WorkflowFactory, WorkflowState, ) @@ -31,13 +38,20 @@ __all__ = [ "DeclarativeLoaderError", "DeclarativeWorkflowError", "DefaultHttpRequestHandler", + "DefaultMCPToolHandler", "ExternalInputRequest", "ExternalInputResponse", "HttpRequestHandler", 
"HttpRequestInfo", "HttpRequestResult", + "MCPToolApprovalRequest", + "MCPToolHandler", + "MCPToolInvocation", + "MCPToolResult", "ProviderLookupError", "ProviderTypeMapping", + "ToolApprovalRequest", + "ToolApprovalResponse", "WorkflowFactory", "WorkflowState", "__version__", diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/__init__.py b/python/packages/declarative/agent_framework_declarative/_workflows/__init__.py index c199e4551b..d06fdeba17 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/__init__.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/__init__.py @@ -72,6 +72,11 @@ from ._executors_http import ( HTTP_ACTION_EXECUTORS, HttpRequestActionExecutor, ) +from ._executors_mcp import ( + MCP_ACTION_EXECUTORS, + InvokeMcpToolActionExecutor, + MCPToolApprovalRequest, +) from ._executors_tools import ( FUNCTION_TOOL_REGISTRY_KEY, TOOL_ACTION_EXECUTORS, @@ -90,6 +95,12 @@ from ._http_handler import ( HttpRequestInfo, HttpRequestResult, ) +from ._mcp_handler import ( + DefaultMCPToolHandler, + MCPToolHandler, + MCPToolInvocation, + MCPToolResult, +) from ._state import WorkflowState __all__ = [ @@ -102,6 +113,7 @@ __all__ = [ "EXTERNAL_INPUT_EXECUTORS", "FUNCTION_TOOL_REGISTRY_KEY", "HTTP_ACTION_EXECUTORS", + "MCP_ACTION_EXECUTORS", "TOOL_ACTION_EXECUTORS", "TOOL_APPROVAL_STATE_KEY", "TOOL_REGISTRY_KEY", @@ -126,6 +138,7 @@ __all__ = [ "DeclarativeWorkflowError", "DeclarativeWorkflowState", "DefaultHttpRequestHandler", + "DefaultMCPToolHandler", "EmitEventExecutor", "EndConversationExecutor", "EndWorkflowExecutor", @@ -140,9 +153,14 @@ __all__ = [ "HttpRequestResult", "InvokeAzureAgentExecutor", "InvokeFunctionToolExecutor", + "InvokeMcpToolActionExecutor", "JoinExecutor", "LoopControl", "LoopIterationResult", + "MCPToolApprovalRequest", + "MCPToolHandler", + "MCPToolInvocation", + "MCPToolResult", "QuestionExecutor", "RequestExternalInputExecutor", "ResetVariableExecutor", 
diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_builder.py b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_builder.py index fb5dcb88f8..67b4a58273 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_builder.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_builder.py @@ -41,8 +41,10 @@ from ._executors_control_flow import ( ) from ._executors_external_input import EXTERNAL_INPUT_EXECUTORS from ._executors_http import HTTP_ACTION_EXECUTORS, HttpRequestActionExecutor +from ._executors_mcp import MCP_ACTION_EXECUTORS, InvokeMcpToolActionExecutor from ._executors_tools import TOOL_ACTION_EXECUTORS, InvokeFunctionToolExecutor from ._http_handler import HttpRequestHandler +from ._mcp_handler import MCPToolHandler logger = logging.getLogger(__name__) @@ -55,6 +57,7 @@ ALL_ACTION_EXECUTORS = { **EXTERNAL_INPUT_EXECUTORS, **TOOL_ACTION_EXECUTORS, **HTTP_ACTION_EXECUTORS, + **MCP_ACTION_EXECUTORS, } # Action kinds that terminate control flow (no fall-through to successor) @@ -90,6 +93,7 @@ ACTION_REQUIRED_FIELDS: dict[str, list[str]] = { "EmitEvent": ["event"], "InvokeFunctionTool": ["functionName"], "HttpRequestAction": ["url"], + "InvokeMcpTool": ["serverUrl", "toolName"], } # Alternate field names that satisfy required field requirements @@ -135,6 +139,7 @@ class DeclarativeWorkflowBuilder: validate: bool = True, max_iterations: int | None = None, http_request_handler: HttpRequestHandler | None = None, + mcp_tool_handler: MCPToolHandler | None = None, ): """Initialize the builder. @@ -150,6 +155,9 @@ class DeclarativeWorkflowBuilder: http_request_handler: Handler used to dispatch HttpRequestAction requests. Must be supplied when the workflow contains any HttpRequestAction; otherwise build raises ``DeclarativeWorkflowError``. + mcp_tool_handler: Handler used to dispatch InvokeMcpTool calls. 
+ Must be supplied when the workflow contains any InvokeMcpTool; + otherwise build raises ``DeclarativeWorkflowError``. """ self._yaml_def = yaml_definition self._workflow_id = workflow_id or yaml_definition.get("name", "declarative_workflow") @@ -162,6 +170,7 @@ class DeclarativeWorkflowBuilder: self._validate = validate self._seen_explicit_ids: set[str] = set() # Track explicit IDs for duplicate detection self._http_request_handler = http_request_handler + self._mcp_tool_handler = mcp_tool_handler # Resolve max_iterations: explicit arg > YAML maxTurns > core default resolved = max_iterations if max_iterations is not None else yaml_definition.get("maxTurns") if resolved is not None and (not isinstance(resolved, int) or resolved <= 0): @@ -481,6 +490,19 @@ class DeclarativeWorkflowBuilder: id=action_id, http_request_handler=self._http_request_handler, ) + elif kind == "InvokeMcpTool": + if self._mcp_tool_handler is None: + raise DeclarativeWorkflowError( + f"Workflow defines InvokeMcpTool '{action_id}' but no " + "mcp_tool_handler was supplied to WorkflowFactory. Pass " + "mcp_tool_handler=DefaultMCPToolHandler() (or a custom " + "implementation) to enable MCP tool invocations." + ) + executor = InvokeMcpToolActionExecutor( + action_def, + id=action_id, + mcp_tool_handler=self._mcp_tool_handler, + ) else: executor = executor_class(action_def, id=action_id) self._executors[action_id] = executor diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py new file mode 100644 index 0000000000..73b66341ea --- /dev/null +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py @@ -0,0 +1,614 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Executor for the ``InvokeMcpTool`` declarative action. 
+ +Mirrors the .NET ``InvokeMcpToolExecutor``: dispatches an MCP tool call through +the configured :class:`MCPToolHandler`, parses tool outputs, and routes +results to the configured ``output.{result, messages, autoSend}`` paths and +optional conversation history. Supports a human-in-loop approval flow via +``ctx.request_info()`` / :func:`@response_handler` for ``requireApproval=true``. + +Security notes: + +- The executor never echoes header VALUES (auth tokens, API keys) into the + approval request — only header NAMES are surfaced to the caller. This + matches the security posture of :mod:`._executors_http` (which never logs + request headers either) and prevents secrets from leaking through workflow + events that are typically observable to operators / UIs. +- ``_MCPToolApprovalState`` snapshots the EVALUATED values for non-secret + fields (server URL, tool name, arguments) at approval-request time so that + subsequent state mutations cannot make the executor "approve X then call + Y". Headers are stored as the raw expression strings (not evaluated values) + so secrets are not persisted in the workflow's checkpoint state. They are + re-evaluated on resume. +- Tool outputs flow back into agent conversations through ``conversationId`` + and through Tool-role messages emitted to ``output.messages``. They share + the same prompt-injection risk surface as ``HttpRequestAction``: workflow + authors must trust the MCP server they invoke. 
+""" + +import json +import logging +import uuid +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Any + +import httpx +from agent_framework import ( + Content, + Message, + WorkflowContext, + handler, + response_handler, +) +from agent_framework.exceptions import ToolExecutionException + +from ._declarative_base import ( + ActionComplete, + DeclarativeActionExecutor, + DeclarativeWorkflowState, +) +from ._executors_tools import ToolApprovalResponse +from ._mcp_handler import MCPToolHandler, MCPToolInvocation, MCPToolResult + +__all__ = [ + "MCP_ACTION_EXECUTORS", + "InvokeMcpToolActionExecutor", + "MCPToolApprovalRequest", +] + +logger = logging.getLogger(__name__) + +_MCP_APPROVAL_STATE_KEY = "_mcp_tool_approval_state" + + +# --------------------------------------------------------------------------- +# Request / state types +# --------------------------------------------------------------------------- + + +@dataclass +class MCPToolApprovalRequest: + """Approval request emitted before invoking an MCP tool. + + Mirrors :class:`agent_framework_declarative.ToolApprovalRequest` but for + MCP-style invocations. Only header NAMES are surfaced — header values are + intentionally omitted because they typically carry authentication + secrets. + + Attributes: + request_id: Unique identifier for this approval request. Matches the + id workflow event-emitters use. + tool_name: Evaluated name of the tool to be invoked. + server_url: Evaluated MCP server URL. + server_label: Optional human-readable label for diagnostics. + arguments: Evaluated arguments to be forwarded to the tool. + header_names: Sorted list of outbound header names (no values). Empty + when no headers are configured. 
+ """ + + request_id: str + tool_name: str + server_url: str + server_label: str | None + arguments: dict[str, Any] + header_names: list[str] = field(default_factory=lambda: []) + + +@dataclass +class _MCPToolApprovalState: + """Internal state saved during the approval yield for resumption. + + Stores **evaluated** values for non-secret fields to prevent + "approve X / execute Y" attacks. Stores the raw expression string for + ``headers`` so that secret values are NOT persisted in checkpoint state; + the expressions are re-evaluated against current state on resume. + """ + + server_url: str + tool_name: str + server_label: str | None + arguments: dict[str, Any] + connection_name: str | None + headers_def: Any + auto_send: bool + conversation_id_expr: str | None + output_messages_path: str | None + output_result_path: str | None + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _get_messages_path(state: DeclarativeWorkflowState, conversation_id_expr: str | None) -> str | None: + """Return the configured conversation messages path, if any. + + Returns ``System.conversations.{evaluated_id}.messages`` when a + ``conversation_id_expr`` is configured and evaluates to a non-empty value. + Returns ``None`` when no conversation id expression is configured or when + the expression evaluates to ``None`` or an empty string (mirrors .NET + ``GetConversationId`` behaviour). + """ + if not conversation_id_expr: + return None + evaluated = state.eval_if_expression(conversation_id_expr) + if evaluated is None or (isinstance(evaluated, str) and not evaluated): + return None + return f"System.conversations.{evaluated}.messages" + + +def _get_output_path(action_def: Mapping[str, Any], key: str) -> str | None: + """Extract a state path from ``output.{key}`` field. + + Supports two YAML shapes: + + - ``output: { result: Local.MyVar }`` — plain string. 
+ - ``output: { result: { path: Local.MyVar } }`` — object form. + """ + output: Any = action_def.get("output") + if not isinstance(output, Mapping): + return None + value: Any = output.get(key) # type: ignore[reportUnknownMemberType] + if isinstance(value, str): + return value or None + if isinstance(value, Mapping): + path: Any = value.get("path") # type: ignore[reportUnknownMemberType] + return path if isinstance(path, str) and path else None + return None + + +def _format_outputs_for_send(parsed_results: list[Any]) -> str: + """Render parsed MCP outputs to a string for ``ctx.yield_output(...)``. + + - Empty list → ``""``. + - All-string list → newline-joined. + - Single element (any type — scalar, dict, list) → JSON-dumped element. + This avoids surprising ``"[42]"`` / ``"[true]"`` / ``"[null]"`` when + an MCP tool returns a single scalar JSON value. + - Multi-element non-string list → JSON-dump the whole list. + """ + if not parsed_results: + return "" + if all(isinstance(item, str) for item in parsed_results): + return "\n".join(parsed_results) # type: ignore[arg-type] + if len(parsed_results) == 1: + return json.dumps(parsed_results[0], ensure_ascii=False) + return json.dumps(parsed_results, ensure_ascii=False) + + +# --------------------------------------------------------------------------- +# Executor +# --------------------------------------------------------------------------- + + +class InvokeMcpToolActionExecutor(DeclarativeActionExecutor): + """Executor for the ``InvokeMcpTool`` declarative action. + + Dispatches through the supplied :class:`MCPToolHandler` and: + + - Evaluates ``serverUrl`` / ``toolName`` / ``serverLabel`` / ``arguments`` + / ``headers`` / ``connection.name`` from the action definition. + - When ``requireApproval=true``: emits a :class:`MCPToolApprovalRequest` + via ``ctx.request_info()`` and yields. On resume, the response is + checked; on rejection, ``output.result`` is set to ``"Error: ..."`` and + no tool call is made. 
+ - On success: parses each :class:`agent_framework.Content` output (text → + JSON-first / data / uri → URI string) and assigns the parsed list to + ``output.result``. Builds a single Tool-role :class:`Message` + containing all output contents and assigns it to ``output.messages``. + When ``output.autoSend`` is true (default), emits the rendered string + via ``ctx.yield_output(...)``. When ``conversationId`` is configured, + appends an Assistant-role :class:`Message` with the same contents to + ``System.conversations.{id}.messages``. + - On error returned by the handler (``is_error=True``): assigns + ``"Error: "`` to ``output.result`` and completes normally + (parity with .NET ``AssignErrorAsync``). + + .. note:: + + ``output.messages`` receives a SINGLE Tool-role :class:`Message` + (containing the full tool output as ``contents``), unlike + :class:`agent_framework_declarative.InvokeFunctionToolExecutor` which + writes a list of two messages (assistant call + tool result). This + matches the .NET ``InvokeMcpToolExecutor`` output contract. + """ + + def __init__( + self, + action_def: dict[str, Any], + *, + id: str | None = None, + mcp_tool_handler: MCPToolHandler, + ) -> None: + """Create an MCP tool action executor. + + Args: + action_def: Parsed ``InvokeMcpTool`` YAML dict. + id: Optional executor id (defaults to action id or generated). + mcp_tool_handler: Handler used to dispatch MCP tool calls. + Required: the builder enforces presence at workflow-build + time. 
+ """ + super().__init__(action_def, id=id) + self._mcp_tool_handler = mcp_tool_handler + + # ----- Main handler -------------------------------------------------------- + + @handler + async def handle_action( + self, + trigger: Any, + ctx: WorkflowContext[ActionComplete, str], + ) -> None: + """Execute the MCP tool action.""" + state = await self._ensure_state_initialized(ctx, trigger) + + server_url = self._get_server_url(state) + tool_name = self._get_tool_name(state) + server_label = self._get_server_label(state) + arguments = self._get_arguments(state) + headers = self._get_headers(state) + connection_name = self._get_connection_name(state) + require_approval = self._get_require_approval(state) + auto_send = self._get_auto_send(state) + conversation_id_expr = self._action_def.get("conversationId") + output_messages_path = _get_output_path(self._action_def, "messages") + output_result_path = _get_output_path(self._action_def, "result") + + if require_approval: + request_id = str(uuid.uuid4()) + approval_state = _MCPToolApprovalState( + server_url=server_url, + tool_name=tool_name, + server_label=server_label, + arguments=arguments, + connection_name=connection_name, + headers_def=self._action_def.get("headers"), + auto_send=auto_send, + conversation_id_expr=conversation_id_expr if isinstance(conversation_id_expr, str) else None, + output_messages_path=output_messages_path, + output_result_path=output_result_path, + ) + ctx.state.set(self._approval_key(), approval_state) + + request = MCPToolApprovalRequest( + request_id=request_id, + tool_name=tool_name, + server_url=server_url, + server_label=server_label, + arguments=arguments, + header_names=sorted(headers.keys()), + ) + logger.info( + "%s: requesting approval for MCP tool '%s' on '%s'", + self.__class__.__name__, + tool_name, + server_url, + ) + await ctx.request_info(request, ToolApprovalResponse, request_id=request_id) + # Workflow yields here — resume in handle_approval_response. 
+ return + + # No approval required - invoke directly. + invocation = MCPToolInvocation( + server_url=server_url, + tool_name=tool_name, + server_label=server_label, + arguments=arguments, + headers=headers, + connection_name=connection_name, + ) + result = await self._invoke_with_narrow_catch(invocation) + await self._process_result( + ctx=ctx, + state=state, + result=result, + auto_send=auto_send, + conversation_id_expr=conversation_id_expr if isinstance(conversation_id_expr, str) else None, + output_messages_path=output_messages_path, + output_result_path=output_result_path, + ) + await ctx.send_message(ActionComplete()) + + # ----- Approval response handler ------------------------------------------ + + @response_handler + async def handle_approval_response( + self, + original_request: MCPToolApprovalRequest, + response: ToolApprovalResponse, + ctx: WorkflowContext[ActionComplete, str], + ) -> None: + """Resume after the workflow yielded for an approval request.""" + state = self._get_state(ctx.state) + approval_key = self._approval_key() + + try: + approval_state: _MCPToolApprovalState = ctx.state.get(approval_key) + except KeyError: + logger.error("%s: approval state missing for executor '%s'", self.__class__.__name__, self.id) + await ctx.send_message(ActionComplete()) + return + try: + ctx.state.delete(approval_key) + except KeyError: + logger.warning("%s: approval state already deleted for '%s'", self.__class__.__name__, self.id) + + if not response.approved: + logger.info( + "%s: MCP tool '%s' rejected: %s", + self.__class__.__name__, + approval_state.tool_name, + response.reason, + ) + self._assign_error( + state, approval_state.output_result_path, "MCP tool invocation was not approved by user." + ) + await ctx.send_message(ActionComplete()) + return + + # Approved — re-evaluate headers (not stored at approval time for security). 
+ headers = self._evaluate_headers(state, approval_state.headers_def) + + invocation = MCPToolInvocation( + server_url=approval_state.server_url, + tool_name=approval_state.tool_name, + server_label=approval_state.server_label, + arguments=approval_state.arguments, + headers=headers, + connection_name=approval_state.connection_name, + ) + result = await self._invoke_with_narrow_catch(invocation) + await self._process_result( + ctx=ctx, + state=state, + result=result, + auto_send=approval_state.auto_send, + conversation_id_expr=approval_state.conversation_id_expr, + output_messages_path=approval_state.output_messages_path, + output_result_path=approval_state.output_result_path, + ) + await ctx.send_message(ActionComplete()) + + # ----- Field resolution ---------------------------------------------------- + + def _get_server_url(self, state: DeclarativeWorkflowState) -> str: + raw = self._action_def.get("serverUrl") + if raw is None: + raise ValueError("InvokeMcpTool requires a 'serverUrl' field.") + evaluated = state.eval_if_expression(raw) + if not isinstance(evaluated, str) or not evaluated: + raise ValueError("InvokeMcpTool 'serverUrl' evaluated to an empty value.") + return evaluated + + def _get_tool_name(self, state: DeclarativeWorkflowState) -> str: + raw = self._action_def.get("toolName") + if raw is None: + raise ValueError("InvokeMcpTool requires a 'toolName' field.") + evaluated = state.eval_if_expression(raw) + if not isinstance(evaluated, str) or not evaluated: + raise ValueError("InvokeMcpTool 'toolName' evaluated to an empty value.") + return evaluated + + def _get_server_label(self, state: DeclarativeWorkflowState) -> str | None: + raw = self._action_def.get("serverLabel") + if raw is None: + return None + evaluated = state.eval_if_expression(raw) + if evaluated is None: + return None + text = str(evaluated) + return text or None + + def _get_arguments(self, state: DeclarativeWorkflowState) -> dict[str, Any]: + """Evaluate ``arguments`` map. 
Preserves ``None`` values (parity with .NET).""" + raw = self._action_def.get("arguments") + if raw is None: + return {} + if not isinstance(raw, Mapping) or not raw: + return {} + result: dict[str, Any] = {} + for key, value in raw.items(): # type: ignore[reportUnknownVariableType] + if not isinstance(key, str) or not key: + continue + result[key] = state.eval_if_expression(value) + return result + + def _get_headers(self, state: DeclarativeWorkflowState) -> dict[str, str]: + return self._evaluate_headers(state, self._action_def.get("headers")) + + @staticmethod + def _evaluate_headers(state: DeclarativeWorkflowState, headers_def: Any) -> dict[str, str]: + """Evaluate the ``headers`` map. Empty string values are skipped.""" + if not isinstance(headers_def, Mapping) or not headers_def: + return {} + result: dict[str, str] = {} + for key, value in headers_def.items(): # type: ignore[reportUnknownVariableType] + if not isinstance(key, str) or not key: + continue + evaluated = state.eval_if_expression(value) + if evaluated is None: + continue + text = str(evaluated) + if not text: + continue + result[key] = text + return result + + def _get_connection_name(self, state: DeclarativeWorkflowState) -> str | None: + connection = self._action_def.get("connection") + if not isinstance(connection, Mapping): + return None + name_expr: Any = connection.get("name") # type: ignore[reportUnknownMemberType] + if name_expr is None: + return None + evaluated = state.eval_if_expression(name_expr) + if evaluated is None: + return None + text = str(evaluated) + return text or None + + def _get_require_approval(self, state: DeclarativeWorkflowState) -> bool: + raw = self._action_def.get("requireApproval") + if raw is None: + return False + evaluated = state.eval_if_expression(raw) + if isinstance(evaluated, bool): + return evaluated + if isinstance(evaluated, str): + return evaluated.strip().lower() in {"true", "1", "yes"} + return bool(evaluated) + + def _get_auto_send(self, state: 
async def _invoke_with_narrow_catch(self, invocation: MCPToolInvocation) -> MCPToolResult:
    """Invoke the MCP handler, normalising only known failures to an error result.

    ``ToolExecutionException``, ``httpx.HTTPError`` and the MCP SDK's
    ``McpError`` are converted into an ``MCPToolResult`` with
    ``is_error=True``. Every other exception propagates unchanged so that
    programming errors fail loudly. This method has no ``BaseException``
    handler, so exceptions that do not derive from ``Exception`` are
    never caught here.

    Args:
        invocation: The fully evaluated tool call to dispatch.

    Returns:
        The handler's result on success, or an error ``MCPToolResult``
        whose single text output is ``"Error: "`` followed by the message
        and whose ``error_message`` carries the bare message.

    Raises:
        Exception: Any exception raised by the handler that is not one of
            the three recognised kinds is re-raised as-is.
    """

    def _error_result(message: str) -> MCPToolResult:
        # Single place that defines the shape of a normalised failure.
        return MCPToolResult(
            outputs=[Content.from_text(f"Error: {message}")],
            is_error=True,
            error_message=message,
        )

    try:
        return await self._mcp_tool_handler.invoke_tool(invocation)
    except ToolExecutionException as exc:
        return _error_result(str(exc) or type(exc).__name__)
    except httpx.HTTPError as exc:
        # Transport errors often have terse messages; prefix the class name.
        return _error_result(f"{type(exc).__name__}: {exc}" if str(exc) else type(exc).__name__)
    except Exception as exc:
        # McpError is imported lazily, so it cannot appear in an ``except``
        # clause; classify the exception by hand instead.
        try:
            from mcp.shared.exceptions import McpError
        except ImportError:  # pragma: no cover - mcp is a hard dep
            # Without the SDK the exception cannot be an McpError. Do NOT
            # ``raise`` from inside this handler: a bare ``raise`` here would
            # re-raise the ImportError and mask the handler's real failure.
            is_mcp_error = False
        else:
            is_mcp_error = isinstance(exc, McpError)
        if not is_mcp_error:
            # The inner handler has exited, so this re-raises ``exc`` itself.
            raise
        return _error_result(str(exc) or type(exc).__name__)
+ tool_message = Message(role="tool", contents=list(result.outputs)) + if output_messages_path is not None: + state.set(output_messages_path, tool_message) + + if auto_send and parsed_results: + await ctx.yield_output(_format_outputs_for_send(parsed_results)) + + if conversation_id_expr: + messages_path = _get_messages_path(state, conversation_id_expr) + if messages_path is not None: + # Mirrors .NET: conversation gets ASSISTANT-role message with + # the same outputs (so chat history reads it as the agent's + # contribution). + assistant_message = Message(role="assistant", contents=list(result.outputs)) + state.append(messages_path, assistant_message) + + @staticmethod + def _assign_error( + state: DeclarativeWorkflowState, + output_result_path: str | None, + error_message: str, + ) -> None: + """Mirror .NET ``AssignErrorAsync``: store ``"Error: "`` at the result path.""" + if output_result_path is None: + return + state.set(output_result_path, f"Error: {error_message}") + + def _approval_key(self) -> str: + return f"{_MCP_APPROVAL_STATE_KEY}_{self.id}" + + +def _parse_outputs(outputs: list[Content]) -> list[Any]: + """Parse :class:`Content` outputs into Python values for ``output.result``. + + Mirrors .NET ``AssignResultAsync``: + + - ``TextContent`` → JSON-parse text; on failure use the raw text. + - ``DataContent`` / ``UriContent`` → ``content.uri``. + - Other content kinds → ``str(content)``. 
+ """ + parsed: list[Any] = [] + for content in outputs: + kind = getattr(content, "type", None) + if kind == "text": + text_value = getattr(content, "text", None) + text_str = "" if text_value is None else str(text_value) + try: + parsed.append(json.loads(text_str)) + except (json.JSONDecodeError, ValueError): + parsed.append(text_str) + continue + if kind in ("data", "uri"): + uri_value = getattr(content, "uri", None) + parsed.append("" if uri_value is None else str(uri_value)) + continue + parsed.append(str(content)) + return parsed + + +MCP_ACTION_EXECUTORS: dict[str, type[DeclarativeActionExecutor]] = { + "InvokeMcpTool": InvokeMcpToolActionExecutor, +} diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_factory.py b/python/packages/declarative/agent_framework_declarative/_workflows/_factory.py index d1e21d76e9..221dfec3cc 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_factory.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_factory.py @@ -29,6 +29,7 @@ from .._loader import AgentFactory from ._declarative_builder import DeclarativeWorkflowBuilder from ._errors import DeclarativeWorkflowError from ._http_handler import HttpRequestHandler +from ._mcp_handler import MCPToolHandler logger = logging.getLogger("agent_framework.declarative") @@ -91,6 +92,7 @@ class WorkflowFactory: checkpoint_storage: CheckpointStorage | None = None, max_iterations: int | None = None, http_request_handler: HttpRequestHandler | None = None, + mcp_tool_handler: MCPToolHandler | None = None, ) -> None: """Initialize the workflow factory. @@ -110,6 +112,13 @@ class WorkflowFactory: otherwise. Use :class:`agent_framework.declarative.DefaultHttpRequestHandler` for a no-policy ``httpx``-based default, or supply your own implementation to enforce SSRF guards, allowlisting, or auth resolution. + mcp_tool_handler: Optional handler used to dispatch MCP tool calls for + ``InvokeMcpTool``. 
Required if the workflow contains any + ``InvokeMcpTool``; build will fail with :class:`DeclarativeWorkflowError` + otherwise. Use :class:`agent_framework.declarative.DefaultMCPToolHandler` + for a default backed by :class:`agent_framework.MCPStreamableHTTPTool`, + or supply your own implementation to enforce SSRF guards, allowlisting, + or auth/connection resolution. Examples: .. code-block:: python @@ -150,6 +159,7 @@ class WorkflowFactory: self._checkpoint_storage = checkpoint_storage self._max_iterations = max_iterations self._http_request_handler = http_request_handler + self._mcp_tool_handler = mcp_tool_handler def create_workflow_from_yaml_path( self, @@ -394,6 +404,7 @@ class WorkflowFactory: checkpoint_storage=self._checkpoint_storage, max_iterations=self._max_iterations, http_request_handler=self._http_request_handler, + mcp_tool_handler=self._mcp_tool_handler, ) workflow = graph_builder.build() except ValueError as e: diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_mcp_handler.py b/python/packages/declarative/agent_framework_declarative/_workflows/_mcp_handler.py new file mode 100644 index 0000000000..658ce42c23 --- /dev/null +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_mcp_handler.py @@ -0,0 +1,494 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""MCP tool handler abstraction for declarative workflows. + +Mirrors the .NET ``IMcpToolHandler`` / ``DefaultMcpToolHandler`` pair from +``Microsoft.Agents.AI.Workflows.Declarative.Mcp``. Provides: + +- :class:`MCPToolInvocation` — request input data passed from the executor. +- :class:`MCPToolResult` — response data returned to the executor. +- :class:`MCPToolHandler` — :class:`typing.Protocol` callers implement to plug + in custom transports (e.g. with allowlisting, Foundry connection resolution, + per-server auth, etc.). +- :class:`DefaultMCPToolHandler` — production-grade default backed by + :class:`agent_framework.MCPStreamableHTTPTool`. 
@dataclass
class MCPToolInvocation:
    """Description of an MCP tool call to be dispatched by a :class:`MCPToolHandler`.

    Mirrors the input parameters of the .NET ``IMcpToolHandler.InvokeToolAsync``
    method. Only ``server_url`` and ``tool_name`` are required; every other
    field has a default. Field semantics:

    - ``server_url``: Absolute URL of the MCP server. Already evaluated from
      the YAML expression.
    - ``server_label``: Optional human-readable label used for diagnostics
      and as the underlying ``MCPStreamableHTTPTool`` name.
    - ``tool_name``: Name of the tool to invoke on the MCP server.
    - ``arguments``: Tool arguments. Already evaluated; values may be any
      JSON-serialisable Python object (str, int, bool, dict, list, None).
    - ``headers``: Outbound HTTP headers (e.g. authentication). Empty values
      are skipped by the executor before construction.
    - ``connection_name``: Optional Foundry connection name forwarded for
      handlers that resolve auth/credentials by connection. The default
      handler does not consume this field.
    """

    # Required fields come first so they can be passed positionally.
    server_url: str
    tool_name: str
    server_label: str | None = None
    # ``default_factory`` gives every instance its own dict, so no mutable
    # default is shared between invocations.
    arguments: dict[str, Any] = field(default_factory=dict)  # type: ignore[reportUnknownVariableType]
    headers: dict[str, str] = field(default_factory=dict)  # type: ignore[reportUnknownVariableType]
    connection_name: str | None = None
class DefaultMCPToolHandler:
    """Default :class:`MCPToolHandler` backed by :class:`agent_framework.MCPStreamableHTTPTool`.

    Caches one :class:`agent_framework.MCPStreamableHTTPTool` instance per
    ``(server_url, server_label, connection_name, headers_hash)`` in a
    bounded LRU. The cache prevents re-establishing an MCP session for every
    invocation while ensuring different header sets (auth tokens) cannot
    share a session. ``server_label`` and ``connection_name`` participate in
    the key so that callers using ``client_provider`` to dispatch on those
    fields receive a fresh client per logical connection.
    Header *names* are lower-cased inside the hash payload only — the
    headers passed on the wire keep the caller's original casing — so two
    YAML actions that spell ``Authorization`` differently still share a
    cache entry.

    Construction modes:

    1. ``DefaultMCPToolHandler()`` — owns its own ``httpx.AsyncClient``
       instances created lazily per cache entry. Closed by :meth:`aclose`.
    2. ``DefaultMCPToolHandler(client_provider=cb)`` — per-server client
       lookup (parity with .NET ``httpClientProvider`` callback). The
       callback receives the full :class:`MCPToolInvocation` so it can
       dispatch on ``server_url`` / ``connection_name`` / ``server_label``.
       Returning ``None`` falls back to an internally-created client. Caller
       supplied clients are NOT closed by :meth:`aclose`.

    .. warning::

        This handler performs **no** URL filtering or SSRF protection. Wrap
        or replace it with a custom handler in production deployments.

    Args:
        client_provider: Optional per-server ``httpx.AsyncClient`` provider.
        cache_max_size: Maximum number of cached MCP clients. When exceeded,
            the least-recently-used entry is evicted and its client closed
            (only owned clients are closed; caller-supplied ones are not).
            Defaults to ``32``.

    Raises:
        ValueError: If ``cache_max_size`` is zero or negative.
    """

    def __init__(
        self,
        *,
        client_provider: ClientProvider | None = None,
        cache_max_size: int = _DEFAULT_CACHE_MAX_SIZE,
    ) -> None:
        if cache_max_size <= 0:
            raise ValueError(f"cache_max_size must be positive, got {cache_max_size}")
        self._client_provider = client_provider
        self._cache_max_size = cache_max_size
        self._cache: OrderedDict[tuple[str, str, str, str], _CacheEntry] = OrderedDict()
        # Outer lock guards the cache + in-flight-future map only — never
        # held across network I/O.
        self._cache_lock = asyncio.Lock()
        # Per-key in-flight futures: while one task is connecting, other
        # tasks awaiting the same key will await the same future and share
        # the resulting cache entry.
        self._inflight: dict[tuple[str, str, str, str], asyncio.Future[_CacheEntry]] = {}
        # Set by ``aclose`` to prevent post-close cache insertions and to
        # reject new ``invoke_tool`` calls. Once set, never cleared.
        self._closed = False

    async def invoke_tool(self, invocation: MCPToolInvocation) -> MCPToolResult:
        """Invoke ``invocation.tool_name`` on the cached MCP client for the server.

        Args:
            invocation: The tool call to dispatch.

        Returns:
            A result carrying the tool outputs, or a result with
            ``is_error=True`` when the client could not be obtained or the
            call raised ``ToolExecutionException``, ``httpx.HTTPError`` or
            ``McpError``.

        Raises:
            Exception: Any other exception raised by the tool call is
                propagated unchanged.
        """
        from agent_framework import Content
        from agent_framework.exceptions import ToolExecutionException

        try:
            entry = await self._get_or_create_entry(invocation)
        except Exception as exc:
            # Connect / cache lookup failures surface as tool errors so the
            # workflow can store them at output.result without crashing.
            logger.warning(
                "DefaultMCPToolHandler: failed to obtain MCP client for url=%s tool=%s: %s",
                invocation.server_url,
                invocation.tool_name,
                exc,
            )
            message = f"Failed to connect to MCP server: {type(exc).__name__}: {exc}".rstrip(": ")
            return self._error_result(message)

        try:
            raw = await entry.tool.call_tool(invocation.tool_name, **invocation.arguments)
        except ToolExecutionException as exc:
            logger.info(
                "DefaultMCPToolHandler: tool '%s' on '%s' raised ToolExecutionException",
                invocation.tool_name,
                invocation.server_url,
            )
            return self._error_result(str(exc) or type(exc).__name__)
        except httpx.HTTPError as exc:
            # Logged like the other tool-level failures so transport errors
            # are not silent.
            logger.info(
                "DefaultMCPToolHandler: tool '%s' on '%s' raised %s",
                invocation.tool_name,
                invocation.server_url,
                type(exc).__name__,
            )
            message = f"{type(exc).__name__}: {exc}" if str(exc) else type(exc).__name__
            return self._error_result(message)
        except Exception as exc:
            # Be defensive about MCP errors that may bubble up without being
            # wrapped in ToolExecutionException by custom parsers.
            mcp_error_type = self._mcp_error_type()
            if mcp_error_type is not None and isinstance(exc, mcp_error_type):
                return self._error_result(str(exc) or type(exc).__name__)
            # Not an MCP error (or ``mcp`` is unavailable): re-raise the
            # ORIGINAL exception. The bare ``raise`` sits directly in this
            # handler, so it cannot re-raise an ImportError by accident.
            raise

        # Defensive normalisation: call_tool is typed ``str | list[Content]``.
        # Default parser returns list, but custom parse_tool_results may return str.
        if isinstance(raw, str):
            outputs: list[Content] = [Content.from_text(raw)]
        else:
            outputs = list(raw)
        return MCPToolResult(outputs=outputs)

    async def aclose(self) -> None:
        """Close all cached MCP clients and the owned httpx clients.

        Caller-supplied :class:`httpx.AsyncClient` instances (returned by the
        ``client_provider`` callback) are NOT closed.

        Idempotent — a second call returns immediately. Drains any in-flight
        ``_create_entry`` tasks before returning so their resources are
        cleaned up; the in-flight tasks see ``self._closed`` in phase 3 of
        :meth:`_get_or_create_entry`, close their own entry, and resolve
        their future with ``RuntimeError("DefaultMCPToolHandler is closed")``.
        """
        async with self._cache_lock:
            if self._closed:
                return
            self._closed = True
            entries = list(self._cache.values())
            self._cache.clear()
            inflight_futures = list(self._inflight.values())

        # Wait for in-flight creations to finish their self-cleanup. Each
        # in-flight task self-closes its entry under the closed-flag branch
        # in phase 3 and resolves its future with ``RuntimeError``; we
        # swallow it here because the failure is expected at shutdown.
        for fut in inflight_futures:
            try:
                await fut
            except BaseException:
                logger.debug("DefaultMCPToolHandler: in-flight future raised during aclose", exc_info=True)
                continue

        for entry in entries:
            await self._close_entry(entry)

    async def __aenter__(self) -> DefaultMCPToolHandler:
        """Return the handler itself; nothing is opened eagerly."""
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        """Close the handler via :meth:`aclose` when the ``async with`` block exits."""
        await self.aclose()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _error_result(message: str) -> MCPToolResult:
        """Build the ``is_error=True`` result shared by all failure paths."""
        from agent_framework import Content

        return MCPToolResult(
            outputs=[Content.from_text(f"Error: {message}")],
            is_error=True,
            error_message=message,
        )

    @staticmethod
    def _mcp_error_type() -> type[BaseException] | None:
        """Return ``mcp.shared.exceptions.McpError``, or ``None`` if ``mcp`` cannot be imported."""
        try:
            from mcp.shared.exceptions import McpError
        except ImportError:  # pragma: no cover - mcp is a hard dep but stay defensive
            return None
        return cast("type[BaseException]", McpError)

    async def _get_or_create_entry(self, invocation: MCPToolInvocation) -> _CacheEntry:
        """Look up (or create) the cached MCP client for this invocation.

        Exactly one task per cache key performs the connect; concurrent
        callers for the same key wait on a shared future.

        Raises:
            RuntimeError: If the handler is (or becomes) closed, or if the
                creating task was interrupted by a non-``Exception``
                ``BaseException`` (waiters only; the creator re-raises the
                original).
            Exception: Whatever ``_create_entry`` raised.
        """
        key = self._cache_key(
            invocation.server_url,
            invocation.server_label,
            invocation.connection_name,
            invocation.headers,
        )

        # Phase 1: check the cache and either claim creation or wait for an
        # already in-flight creation.
        creating = False
        async with self._cache_lock:
            if self._closed:
                raise RuntimeError("DefaultMCPToolHandler is closed")
            existing = self._cache.get(key)
            if existing is not None:
                self._cache.move_to_end(key)
                return existing
            inflight = self._inflight.get(key)
            if inflight is None:
                inflight = asyncio.get_running_loop().create_future()
                self._inflight[key] = inflight
                creating = True

        if not creating:
            # Shield the shared future: awaiting it directly would let the
            # cancellation of ONE waiter task cancel the future itself and
            # thereby fail every other waiter with a spurious CancelledError.
            return await asyncio.shield(inflight)

        # Phase 2: we own creation. Build the entry outside the lock.
        try:
            entry = await self._create_entry(invocation)
        except BaseException as exc:
            async with self._cache_lock:
                self._inflight.pop(key, None)
            if not inflight.done():
                # Waiters must only ever see ordinary ``Exception`` failures.
                # If this task was cancelled/interrupted, publishing the raw
                # CancelledError would make unrelated waiters look cancelled,
                # so translate it into a RuntimeError for them.
                failure: Exception = (
                    exc
                    if isinstance(exc, Exception)
                    else RuntimeError(f"MCP client creation was interrupted: {type(exc).__name__}")
                )
                inflight.set_exception(failure)
                # Mark the exception retrieved to suppress noisy "Future exception
                # was never retrieved" warnings when there are no other awaiters
                # (other awaiters still see the exception through their ``await``).
                inflight.exception()
            raise

        # Phase 3: insert with LRU eviction; resolve the in-flight future.
        # If ``aclose`` ran while we were connecting, ``_closed`` is now
        # True; don't insert into the cache (it has been drained), close
        # the just-built entry, and surface the closed-handler error to
        # all awaiters of the future.
        evicted: _CacheEntry | None = None
        duplicate: _CacheEntry | None = None
        handler_closed = False
        async with self._cache_lock:
            self._inflight.pop(key, None)
            if self._closed:
                handler_closed = True
            else:
                existing = self._cache.get(key)
                if existing is not None:
                    # Another writer beat us; prefer the existing entry and
                    # discard ours after the lock is released.
                    self._cache.move_to_end(key)
                    duplicate = entry
                    entry = existing
                else:
                    self._cache[key] = entry
                    self._cache.move_to_end(key)
                    if len(self._cache) > self._cache_max_size:
                        _evicted_key, evicted = self._cache.popitem(last=False)
                if not inflight.done():
                    inflight.set_result(entry)

        if handler_closed:
            # Close our orphaned entry; resolve the future with a clear
            # error so the caller (and any other awaiters) surface a
            # consistent "handler is closed" failure rather than receiving
            # an entry we are about to close behind their back.
            await self._close_entry(entry)
            err = RuntimeError("DefaultMCPToolHandler is closed")
            if not inflight.done():
                inflight.set_exception(err)
                inflight.exception()
            raise err
        # Network-touching cleanup happens outside the lock.
        if duplicate is not None:
            await self._close_entry(duplicate)
        if evicted is not None:
            await self._close_entry(evicted)
        return entry

    async def _create_entry(self, invocation: MCPToolInvocation) -> _CacheEntry:
        """Construct (and connect) a fresh MCP client for ``invocation``.

        Raises:
            BaseException: Whatever ``client_provider`` or ``tool.connect()``
                raised; on a failed connect the half-built tool is closed
                first.
        """
        from agent_framework import MCPStreamableHTTPTool

        provided_client: httpx.AsyncClient | None = None
        if self._client_provider is not None:
            provided_client = await self._client_provider(invocation)
        # Capture headers for this cache entry so the header_provider closure
        # always returns the same set, regardless of the runtime kwargs.
        captured_headers = dict(invocation.headers)

        def _header_provider(_kwargs: dict[str, Any]) -> dict[str, str]:
            return captured_headers

        tool: Any = MCPStreamableHTTPTool(
            name=invocation.server_label or "McpClient",
            url=invocation.server_url,
            load_prompts=False,
            http_client=provided_client,
            header_provider=_header_provider if captured_headers else None,
        )
        try:
            await tool.connect()
        except BaseException:
            # Tear down through the normal close path so that an httpx client
            # the tool may already have allocated for us (see below) is
            # closed too instead of leaking. Close errors are logged and
            # swallowed by ``_close_entry`` (best effort).
            await self._close_entry(
                _CacheEntry(tool=tool, owned_httpx_client=self._owned_client(tool, provided_client))
            )
            raise

        return _CacheEntry(tool=tool, owned_httpx_client=self._owned_client(tool, provided_client))

    @staticmethod
    def _owned_client(tool: Any, provided_client: httpx.AsyncClient | None) -> httpx.AsyncClient | None:
        """Return the httpx client this handler must close for ``tool``, if any.

        ``MCPStreamableHTTPTool.get_mcp_client`` lazily creates an
        ``httpx.AsyncClient`` when no caller client was provided AND a
        ``header_provider`` was set. We treat any client allocated this
        way as owned (closed by the handler). When the caller supplies
        one, we never close it.
        """
        if provided_client is not None:
            return None
        return cast("httpx.AsyncClient | None", getattr(tool, "_httpx_client", None))

    async def _close_entry(self, entry: _CacheEntry) -> None:
        """Close the MCP tool and any owned httpx client (best effort, errors are logged)."""
        try:
            await entry.tool.close()
        except Exception:  # pragma: no cover - best effort
            logger.debug("DefaultMCPToolHandler: error closing MCP tool", exc_info=True)
        if entry.owned_httpx_client is not None:
            try:
                await entry.owned_httpx_client.aclose()
            except Exception:  # pragma: no cover - best effort
                logger.debug("DefaultMCPToolHandler: error closing owned httpx client", exc_info=True)

    @staticmethod
    def _cache_key(
        server_url: str,
        server_label: str | None,
        connection_name: str | None,
        headers: dict[str, str] | None,
    ) -> tuple[str, str, str, str]:
        """Build an order-independent cache key for the invocation identity.

        The key includes ``server_label`` and ``connection_name`` so that
        callers using ``client_provider`` to dispatch on those fields
        receive a fresh client per logical connection (matches the
        documented dispatch contract).

        Header *names* are lower-cased inside the hash payload only so
        that ``Authorization`` and ``authorization`` map to the same
        cache entry. Header values remain case-sensitive (per RFC 7235).

        Returns:
            ``(server_url, server_label or "", connection_name or "", headers_hash)``
            where ``headers_hash`` is ``"0"`` for empty/missing headers and
            otherwise the SHA-256 hex digest of the sorted, JSON-encoded
            ``(lower-cased name, value)`` pairs.
        """
        if not headers:
            headers_hash = "0"
        else:
            # Sorting makes the hash independent of dict insertion order.
            normalized = sorted((k.lower(), v) for k, v in headers.items())
            payload = json.dumps(normalized, ensure_ascii=False)
            headers_hash = hashlib.sha256(payload.encode("utf-8")).hexdigest()
        return (server_url, server_label or "", connection_name or "", headers_hash)
class TestConstruction:
    """Constructor argument validation for ``DefaultMCPToolHandler``."""

    def test_invalid_cache_size_raises(self) -> None:
        """Zero and negative ``cache_max_size`` values are each rejected with ``ValueError``."""
        for bad_size in (0, -3):
            with pytest.raises(ValueError):
                DefaultMCPToolHandler(cache_max_size=bad_size)
class TestCache:
    """Cache-key, LRU-eviction and in-flight-sharing behaviour of ``DefaultMCPToolHandler``.

    Every test swaps ``MCPStreamableHTTPTool`` for ``FakeTool`` via
    ``_patch_tool()`` and counts ``FakeTool.instances`` to tell how many
    clients the handler built. Instances are appended in construction order,
    so index ``0`` is always the first tool created in the test.
    """

    @pytest.mark.asyncio
    async def test_same_url_and_headers_hit_cache(self) -> None:
        """Two identical invocations reuse a single connected tool."""
        handler = DefaultMCPToolHandler()
        with _patch_tool():
            await handler.invoke_tool(_invocation(headers={"X": "1"}))
            await handler.invoke_tool(_invocation(headers={"X": "1"}))
        # One tool created, connect called once.
        assert len(FakeTool.instances) == 1
        assert FakeTool.instances[0].connect_count == 1

    @pytest.mark.asyncio
    async def test_different_headers_create_separate_entries(self) -> None:
        """Different header values for the same URL yield two tools."""
        handler = DefaultMCPToolHandler()
        with _patch_tool():
            await handler.invoke_tool(_invocation(headers={"Authorization": "tk-A"}))
            await handler.invoke_tool(_invocation(headers={"Authorization": "tk-B"}))
        assert len(FakeTool.instances) == 2

    @pytest.mark.asyncio
    async def test_different_urls_create_separate_entries(self) -> None:
        """Different server URLs yield two tools."""
        handler = DefaultMCPToolHandler()
        with _patch_tool():
            await handler.invoke_tool(_invocation(server_url="https://mcp.a/api"))
            await handler.invoke_tool(_invocation(server_url="https://mcp.b/api"))
        assert len(FakeTool.instances) == 2

    @pytest.mark.asyncio
    async def test_lru_eviction_closes_old_entry(self) -> None:
        """With ``cache_max_size=2`` a third distinct URL evicts and closes the first tool."""
        handler = DefaultMCPToolHandler(cache_max_size=2)
        with _patch_tool():
            await handler.invoke_tool(_invocation(server_url="https://a/"))
            await handler.invoke_tool(_invocation(server_url="https://b/"))
            # Inserting a third evicts the LRU entry (the first one).
            await handler.invoke_tool(_invocation(server_url="https://c/"))
        assert len(FakeTool.instances) == 3
        # First instance (https://a/) was evicted → close() called.
        assert FakeTool.instances[0].kwargs["url"] == "https://a/"
        assert FakeTool.instances[0].close_count == 1
        # Other two remain in cache → not closed.
        assert FakeTool.instances[1].close_count == 0
        assert FakeTool.instances[2].close_count == 0

    @pytest.mark.asyncio
    async def test_repeated_use_keeps_lru_alive(self) -> None:
        """Re-using an entry refreshes its recency, so the other entry is evicted instead."""
        handler = DefaultMCPToolHandler(cache_max_size=2)
        with _patch_tool():
            await handler.invoke_tool(_invocation(server_url="https://a/"))
            await handler.invoke_tool(_invocation(server_url="https://b/"))
            # Touch a → b becomes LRU.
            await handler.invoke_tool(_invocation(server_url="https://a/"))
            # Insert c → b is evicted.
            await handler.invoke_tool(_invocation(server_url="https://c/"))
        # b was evicted.
        b = FakeTool.instances[1]
        assert b.kwargs["url"] == "https://b/"
        assert b.close_count == 1
        # a survived.
        a = FakeTool.instances[0]
        assert a.kwargs["url"] == "https://a/"
        assert a.close_count == 0

    @pytest.mark.asyncio
    async def test_concurrent_connect_shares_one_entry(self) -> None:
        """Multiple concurrent invocations with the same key must share one tool."""
        handler = DefaultMCPToolHandler()

        # Slow down connect so concurrency window is observable.
        original_connect = FakeTool.connect

        async def slow_connect(self: FakeTool) -> None:
            # Setting the delay on the instance makes the original ``connect``
            # sleep before it counts the connection.
            self.connect_delay = 0.05
            await original_connect(self)

        with _patch_tool(), patch.object(FakeTool, "connect", slow_connect):
            results = await asyncio.gather(
                handler.invoke_tool(_invocation(headers={"X": "1"})),
                handler.invoke_tool(_invocation(headers={"X": "1"})),
                handler.invoke_tool(_invocation(headers={"X": "1"})),
                handler.invoke_tool(_invocation(headers={"X": "1"})),
            )
        assert all(not r.is_error for r in results)
        # Only one tool was created and connected, despite 4 concurrent calls.
        assert len(FakeTool.instances) == 1
        assert FakeTool.instances[0].connect_count == 1

    @pytest.mark.asyncio
    async def test_different_connection_names_create_separate_entries(self) -> None:
        """Same URL/headers but different ``connection_name`` must dispatch separately."""
        handler = DefaultMCPToolHandler()
        with _patch_tool():
            await handler.invoke_tool(_invocation(connection_name="conn-A"))
            await handler.invoke_tool(_invocation(connection_name="conn-B"))
        assert len(FakeTool.instances) == 2

    @pytest.mark.asyncio
    async def test_different_server_labels_create_separate_entries(self) -> None:
        """Same URL/headers but different ``server_label`` must dispatch separately."""
        handler = DefaultMCPToolHandler()
        with _patch_tool():
            await handler.invoke_tool(_invocation(server_label="LabelA"))
            await handler.invoke_tool(_invocation(server_label="LabelB"))
        assert len(FakeTool.instances) == 2

    @pytest.mark.asyncio
    async def test_full_identity_match_hits_cache(self) -> None:
        """All four identity components match → single cached entry."""
        handler = DefaultMCPToolHandler()
        with _patch_tool():
            await handler.invoke_tool(_invocation(server_label="Lbl", connection_name="C", headers={"X": "1"}))
            await handler.invoke_tool(_invocation(server_label="Lbl", connection_name="C", headers={"X": "1"}))
        assert len(FakeTool.instances) == 1
        assert FakeTool.instances[0].connect_count == 1

    @pytest.mark.asyncio
    async def test_header_name_case_collapses_to_one_cache_entry(self) -> None:
        """Header name spelling differences (case-only) must share a cache entry."""
        handler = DefaultMCPToolHandler()
        with _patch_tool():
            await handler.invoke_tool(_invocation(headers={"Authorization": "tk"}))
            await handler.invoke_tool(_invocation(headers={"authorization": "tk"}))
            await handler.invoke_tool(_invocation(headers={"AUTHORIZATION": "tk"}))
        assert len(FakeTool.instances) == 1
        assert FakeTool.instances[0].connect_count == 1

    @pytest.mark.asyncio
    async def test_header_value_case_does_not_collapse(self) -> None:
        """Header *values* remain case-sensitive (different tokens → different sessions)."""
        handler = DefaultMCPToolHandler()
        with _patch_tool():
            await handler.invoke_tool(_invocation(headers={"Authorization": "Bearer-A"}))
            await handler.invoke_tool(_invocation(headers={"Authorization": "bearer-a"}))
        assert len(FakeTool.instances) == 2
+ assert not caller_client.is_closed + finally: + await caller_client.aclose() + + @pytest.mark.asyncio + async def test_async_context_manager(self) -> None: + with _patch_tool(): + async with DefaultMCPToolHandler() as handler: + await handler.invoke_tool(_invocation()) + tool = FakeTool.instances[0] + assert tool.close_count == 1 + + @pytest.mark.asyncio + async def test_aclose_is_idempotent(self) -> None: + """A second ``aclose`` is a no-op (no exception, no double-close).""" + handler = DefaultMCPToolHandler() + with _patch_tool(): + await handler.invoke_tool(_invocation(headers={"X": "1"})) + await handler.aclose() + await handler.aclose() + assert FakeTool.instances[0].close_count == 1 + + @pytest.mark.asyncio + async def test_invoke_after_close_returns_error_result(self) -> None: + """Post-close ``invoke_tool`` surfaces a tool error rather than crashing.""" + handler = DefaultMCPToolHandler() + with _patch_tool(): + await handler.aclose() + result = await handler.invoke_tool(_invocation()) + assert result.is_error is True + assert "closed" in (result.error_message or "").lower() + + @pytest.mark.asyncio + async def test_aclose_drains_inflight_creation(self) -> None: + """An in-flight ``_create_entry`` must not leak when ``aclose`` races with it. + + Reproduces the race described in PR #5630 review-comment 3: + task A claims an inflight future and starts a slow connect; task B + runs ``aclose``; task A must self-clean (close its tool + httpx + client) and surface a closed-handler error rather than orphaning + the entry. 
+ """ + handler = DefaultMCPToolHandler() + connect_started = asyncio.Event() + release_connect = asyncio.Event() + original_connect = FakeTool.connect + + async def gated_connect(self: FakeTool) -> None: + connect_started.set() + await release_connect.wait() + await original_connect(self) + + with _patch_tool(), patch.object(FakeTool, "connect", gated_connect): + invoke_task = asyncio.create_task(handler.invoke_tool(_invocation(headers={"X": "1"}))) + # Wait until task A is mid-connect. + await connect_started.wait() + # Race: kick off aclose. It must wait for the in-flight task. + close_task = asyncio.create_task(handler.aclose()) + # Yield once to ensure aclose has set _closed and is awaiting. + await asyncio.sleep(0) + # Allow the connect to complete; phase 3 sees _closed and self-cleans. + release_connect.set() + result = await invoke_task + await close_task + + # Entry was created and then closed by the in-flight task itself. + assert len(FakeTool.instances) == 1 + assert FakeTool.instances[0].close_count == 1 + # The originating invocation surfaces a closed-handler error. + assert result.is_error is True + assert "closed" in (result.error_message or "").lower() + + +# ---------- Result normalisation ------------------------------------------ + + +class TestResultNormalisation: + @pytest.mark.asyncio + async def test_string_result_wrapped_in_text_content(self) -> None: + handler = DefaultMCPToolHandler() + with _patch_tool(): + inv = _invocation() + result = await handler.invoke_tool(inv) + # The fake's default already returns a list; replace handler for this test. 
+ FakeTool.instances[0].call_handler = lambda **_a: "raw string body" + result = await handler.invoke_tool(inv) + assert result.is_error is False + assert len(result.outputs) == 1 + assert result.outputs[0].text == "raw string body" # type: ignore[reportAttributeAccessIssue] + + @pytest.mark.asyncio + async def test_list_result_passed_through(self) -> None: + handler = DefaultMCPToolHandler() + custom = [Content.from_text("a"), Content.from_text("b")] + with _patch_tool(): + inv = _invocation() + await handler.invoke_tool(inv) + FakeTool.instances[0].call_handler = lambda **_a: custom + result = await handler.invoke_tool(inv) + assert result.is_error is False + assert len(result.outputs) == 2 + + +# ---------- Error mapping -------------------------------------------------- + + +class TestErrorMapping: + @pytest.mark.asyncio + async def test_tool_execution_exception_returns_error_result(self) -> None: + handler = DefaultMCPToolHandler() + + def boom(**_a: Any) -> Any: + raise ToolExecutionException("server says no") + + with _patch_tool(): + inv = _invocation() + await handler.invoke_tool(inv) + FakeTool.instances[0].call_handler = boom + result = await handler.invoke_tool(inv) + assert result.is_error is True + assert result.error_message == "server says no" + assert result.outputs[0].text.startswith("Error:") # type: ignore[reportAttributeAccessIssue] + + @pytest.mark.asyncio + async def test_httpx_error_returns_error_result(self) -> None: + handler = DefaultMCPToolHandler() + + def boom(**_a: Any) -> Any: + raise httpx.ConnectError("dns failure") + + with _patch_tool(): + inv = _invocation() + await handler.invoke_tool(inv) + FakeTool.instances[0].call_handler = boom + result = await handler.invoke_tool(inv) + assert result.is_error is True + assert "dns failure" in (result.error_message or "") + + @pytest.mark.asyncio + async def test_unexpected_exception_propagates(self) -> None: + """RuntimeError (not in the narrow catch list) must propagate.""" + handler = 
DefaultMCPToolHandler() + + def boom(**_a: Any) -> Any: + raise RuntimeError("programmer error") + + with _patch_tool(): + inv = _invocation() + await handler.invoke_tool(inv) + FakeTool.instances[0].call_handler = boom + with pytest.raises(RuntimeError, match="programmer error"): + await handler.invoke_tool(inv) + + @pytest.mark.asyncio + async def test_connect_failure_returns_error_result(self) -> None: + handler = DefaultMCPToolHandler() + with ( + _patch_tool(), + patch.object( + FakeTool, + "connect", + lambda self: (_ for _ in ()).throw(httpx.ConnectError("server down")), + ), + ): + result = await handler.invoke_tool(_invocation()) + assert result.is_error is True + assert result.outputs[0].text.startswith("Error:") # type: ignore[reportAttributeAccessIssue] + # Failed connect must clear in-flight + cache entries. + assert handler._inflight == {} + assert len(handler._cache) == 0 + + @pytest.mark.asyncio + async def test_cancelled_error_propagates(self) -> None: + """asyncio.CancelledError is BaseException, must NOT be swallowed.""" + handler = DefaultMCPToolHandler() + + def boom(**_a: Any) -> Any: + raise asyncio.CancelledError + + with _patch_tool(): + inv = _invocation() + await handler.invoke_tool(inv) + FakeTool.instances[0].call_handler = boom + with pytest.raises(asyncio.CancelledError): + await handler.invoke_tool(inv) + + +# ---------- Cache key isolation ------------------------------------------- + + +class TestCacheKey: + def test_key_order_independent(self) -> None: + k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"A": "1", "B": "2"}) + k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"B": "2", "A": "1"}) + assert k1 == k2 + + def test_key_distinguishes_values(self) -> None: + k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"A": "1"}) + k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"A": "2"}) + assert k1 != k2 + + def test_empty_headers_use_fixed_hash(self) -> None: + k1 = 
DefaultMCPToolHandler._cache_key("https://x/", None, None, None) + k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {}) + assert k1 == k2 + + def test_key_distinguishes_connection_name(self) -> None: + k1 = DefaultMCPToolHandler._cache_key("https://x/", None, "conn-A", None) + k2 = DefaultMCPToolHandler._cache_key("https://x/", None, "conn-B", None) + assert k1 != k2 + + def test_key_distinguishes_server_label(self) -> None: + k1 = DefaultMCPToolHandler._cache_key("https://x/", "Lbl-A", None, None) + k2 = DefaultMCPToolHandler._cache_key("https://x/", "Lbl-B", None, None) + assert k1 != k2 + + def test_key_collapses_header_name_case(self) -> None: + k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"Authorization": "tk"}) + k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"authorization": "tk"}) + assert k1 == k2 + + def test_key_keeps_header_value_case(self) -> None: + k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"X": "Bearer-A"}) + k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"X": "bearer-a"}) + assert k1 != k2 diff --git a/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py b/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py new file mode 100644 index 0000000000..fdee1f7df1 --- /dev/null +++ b/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py @@ -0,0 +1,664 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests for ``InvokeMcpToolActionExecutor``. + +Use a stub :class:`MCPToolHandler` that returns canned :class:`MCPToolResult`s. +No real MCP server or network is exercised. See +``test_default_mcp_tool_handler.py`` for tests that exercise the real +``DefaultMCPToolHandler`` against a mocked ``MCPStreamableHTTPTool``. 
+""" + +import sys +from typing import Any + +import httpx +import pytest + +try: + import powerfx # noqa: F401 + + _powerfx_available = True +except (ImportError, RuntimeError): + _powerfx_available = False + +pytestmark = pytest.mark.skipif( + not _powerfx_available or sys.version_info >= (3, 14), + reason="PowerFx engine not available (requires dotnet runtime)", +) + +from agent_framework import Content, Message # noqa: E402 +from agent_framework.exceptions import ToolExecutionException # noqa: E402 + +from agent_framework_declarative._workflows import ( # noqa: E402 + DECLARATIVE_STATE_KEY, + DeclarativeWorkflowError, + MCPToolHandler, + MCPToolInvocation, + MCPToolResult, + WorkflowFactory, +) + + +class StubMcpHandler: + """Test stub recording the last call and returning a canned result.""" + + def __init__( + self, + result: MCPToolResult | None = None, + *, + raise_exc: BaseException | None = None, + ) -> None: + self.result = result + self.raise_exc = raise_exc + self.last_invocation: MCPToolInvocation | None = None + self.invocations: list[MCPToolInvocation] = [] + self.call_count = 0 + + async def invoke_tool(self, invocation: MCPToolInvocation) -> MCPToolResult: + self.call_count += 1 + self.last_invocation = invocation + self.invocations.append(invocation) + if self.raise_exc is not None: + raise self.raise_exc + assert self.result is not None + return self.result + + +def _ok(outputs: list[Content] | None = None) -> MCPToolResult: + return MCPToolResult(outputs=outputs or [Content.from_text("hello")]) + + +def _err(message: str = "boom") -> MCPToolResult: + return MCPToolResult( + outputs=[Content.from_text(f"Error: {message}")], + is_error=True, + error_message=message, + ) + + +def _action( + *, + server_url: str = "https://mcp.example/api", + tool_name: str = "search", + server_label: str | None = None, + arguments: dict[str, Any] | None = None, + headers: dict[str, Any] | None = None, + require_approval: Any = None, + connection: dict[str, Any] | 
None = None, + conversation_id: str | None = None, + output: dict[str, Any] | None = None, +) -> dict[str, Any]: + action: dict[str, Any] = { + "kind": "InvokeMcpTool", + "id": "mcp_action", + "serverUrl": server_url, + "toolName": tool_name, + } + if server_label is not None: + action["serverLabel"] = server_label + if arguments is not None: + action["arguments"] = arguments + if headers is not None: + action["headers"] = headers + if require_approval is not None: + action["requireApproval"] = require_approval + if connection is not None: + action["connection"] = connection + if conversation_id is not None: + action["conversationId"] = conversation_id + if output is not None: + action["output"] = output + return action + + +def _yaml(action: dict[str, Any]) -> dict[str, Any]: + return {"name": "mcp_test", "actions": [action]} + + +# ---------- Builder enforcement -------------------------------------------- + + +class TestBuilderEnforcement: + def test_missing_handler_raises_at_build_time(self) -> None: + factory = WorkflowFactory() + with pytest.raises(DeclarativeWorkflowError) as excinfo: + factory.create_workflow_from_definition(_yaml(_action())) + assert "InvokeMcpTool" in str(excinfo.value) + assert "mcp_tool_handler" in str(excinfo.value) + + def test_missing_server_url_fails_validation(self) -> None: + handler = StubMcpHandler(_ok()) + factory = WorkflowFactory(mcp_tool_handler=handler) + action = _action() + del action["serverUrl"] + with pytest.raises(Exception) as excinfo: + factory.create_workflow_from_definition(_yaml(action)) + assert "serverUrl" in str(excinfo.value) + + def test_missing_tool_name_fails_validation(self) -> None: + handler = StubMcpHandler(_ok()) + factory = WorkflowFactory(mcp_tool_handler=handler) + action = _action() + del action["toolName"] + with pytest.raises(Exception) as excinfo: + factory.create_workflow_from_definition(_yaml(action)) + assert "toolName" in str(excinfo.value) + + +# ---------- Field forwarding 
---------------------------------------------- + + +class TestFieldForwarding: + @pytest.mark.asyncio + async def test_basic_invocation_forwards_required_fields(self) -> None: + handler = StubMcpHandler(_ok()) + factory = WorkflowFactory(mcp_tool_handler=handler) + workflow = factory.create_workflow_from_definition(_yaml(_action())) + await workflow.run({}) + assert handler.call_count == 1 + inv = handler.last_invocation + assert inv is not None + assert inv.server_url == "https://mcp.example/api" + assert inv.tool_name == "search" + assert inv.server_label is None + assert inv.headers == {} + assert inv.arguments == {} + assert inv.connection_name is None + + @pytest.mark.asyncio + async def test_arguments_evaluated_and_preserves_none(self) -> None: + handler = StubMcpHandler(_ok()) + factory = WorkflowFactory(mcp_tool_handler=handler) + workflow = factory.create_workflow_from_definition( + _yaml( + _action( + arguments={ + "query": "weather today", + "limit": 5, + "fresh": True, + "missing": None, + } + ) + ) + ) + await workflow.run({}) + inv = handler.last_invocation + assert inv is not None + # ``None`` is preserved (parity with .NET) — caller decides. 
+ assert inv.arguments == { + "query": "weather today", + "limit": 5, + "fresh": True, + "missing": None, + } + + @pytest.mark.asyncio + async def test_headers_drop_empty_values(self) -> None: + handler = StubMcpHandler(_ok()) + factory = WorkflowFactory(mcp_tool_handler=handler) + workflow = factory.create_workflow_from_definition( + _yaml( + _action( + headers={ + "Authorization": "Bearer token-123", + "X-Trace": "trace-id", + "X-Empty": "", + } + ) + ) + ) + await workflow.run({}) + inv = handler.last_invocation + assert inv is not None + assert inv.headers == { + "Authorization": "Bearer token-123", + "X-Trace": "trace-id", + } + + @pytest.mark.asyncio + async def test_server_label_and_connection_name_forwarded(self) -> None: + handler = StubMcpHandler(_ok()) + factory = WorkflowFactory(mcp_tool_handler=handler) + workflow = factory.create_workflow_from_definition( + _yaml( + _action( + server_label="docs-mcp", + connection={"name": "azure-conn"}, + ) + ) + ) + await workflow.run({}) + inv = handler.last_invocation + assert inv is not None + assert inv.server_label == "docs-mcp" + assert inv.connection_name == "azure-conn" + + +# ---------- Output handling ------------------------------------------------ + + +class TestOutput: + @pytest.mark.asyncio + async def test_output_result_parses_json_text(self) -> None: + handler = StubMcpHandler(_ok([Content.from_text('{"k":"v","n":1}')])) + factory = WorkflowFactory(mcp_tool_handler=handler) + workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"}))) + await workflow.run({}) + decl = workflow._state.get(DECLARATIVE_STATE_KEY) + assert decl["Local"]["Result"] == [{"k": "v", "n": 1}] + + @pytest.mark.asyncio + async def test_output_result_falls_back_to_raw_text(self) -> None: + handler = StubMcpHandler(_ok([Content.from_text("plain text not json")])) + factory = WorkflowFactory(mcp_tool_handler=handler) + workflow = 
factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"}))) + await workflow.run({}) + decl = workflow._state.get(DECLARATIVE_STATE_KEY) + assert decl["Local"]["Result"] == ["plain text not json"] + + @pytest.mark.asyncio + async def test_output_messages_writes_single_tool_role_message(self) -> None: + handler = StubMcpHandler(_ok([Content.from_text("hi"), Content.from_text("there")])) + factory = WorkflowFactory(mcp_tool_handler=handler) + workflow = factory.create_workflow_from_definition(_yaml(_action(output={"messages": "Local.Messages"}))) + await workflow.run({}) + decl = workflow._state.get(DECLARATIVE_STATE_KEY) + msg = decl["Local"]["Messages"] + # Single Tool-role message containing both contents (parity with .NET). + assert isinstance(msg, Message) + assert str(msg.role).lower() == "tool" + assert len(msg.contents) == 2 + + @pytest.mark.asyncio + async def test_uri_content_serialised_as_uri_string(self) -> None: + uri_content = Content.from_uri("https://example.com/file.txt", media_type="text/plain") + handler = StubMcpHandler(_ok([uri_content])) + factory = WorkflowFactory(mcp_tool_handler=handler) + workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"}))) + await workflow.run({}) + decl = workflow._state.get(DECLARATIVE_STATE_KEY) + assert decl["Local"]["Result"] == ["https://example.com/file.txt"] + + @pytest.mark.asyncio + async def test_output_path_object_form(self) -> None: + handler = StubMcpHandler(_ok([Content.from_text("ok")])) + factory = WorkflowFactory(mcp_tool_handler=handler) + workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": {"path": "Local.Result"}}))) + await workflow.run({}) + decl = workflow._state.get(DECLARATIVE_STATE_KEY) + assert decl["Local"]["Result"] == ["ok"] + + +# ---------- Conversation append -------------------------------------------- + + +class TestConversation: + @pytest.mark.asyncio + async def 
test_conversation_id_appends_assistant_message(self) -> None: + handler = StubMcpHandler(_ok([Content.from_text("answer")])) + factory = WorkflowFactory(mcp_tool_handler=handler) + workflow = factory.create_workflow_from_definition( + _yaml( + _action( + conversation_id="conv-42", + output={"result": "Local.Result"}, + ) + ) + ) + await workflow.run({}) + decl = workflow._state.get(DECLARATIVE_STATE_KEY) + conv = decl["System"]["conversations"]["conv-42"] + msgs = conv["messages"] if isinstance(conv, dict) else conv.messages + assert len(msgs) == 1 + appended = msgs[0] + assert str(appended.role).lower() == "assistant" + # Same contents as the tool output. + assert len(appended.contents) == 1 + + @pytest.mark.asyncio + async def test_empty_conversation_id_does_not_append(self) -> None: + handler = StubMcpHandler(_ok([Content.from_text("answer")])) + factory = WorkflowFactory(mcp_tool_handler=handler) + workflow = factory.create_workflow_from_definition( + _yaml( + _action( + conversation_id="", + output={"result": "Local.Result"}, + ) + ) + ) + await workflow.run({}) + decl = workflow._state.get(DECLARATIVE_STATE_KEY) + # Empty conversation id must not produce a `""` entry under System.conversations. 
+ conversations = decl.get("System", {}).get("conversations", {}) + assert "" not in conversations + + +# ---------- Approval flow -------------------------------------------------- + + +@pytest.fixture +def mock_state(): # type: ignore[no-untyped-def] + from unittest.mock import MagicMock + + state = MagicMock() + state._data = {} + + def _get(key: str, default: Any = None) -> Any: + if key not in state._data: + if default is not None: + return default + raise KeyError(key) + return state._data[key] + + def _set(key: str, value: Any) -> None: + state._data[key] = value + + def _delete(key: str) -> None: + if key in state._data: + del state._data[key] + else: + raise KeyError(key) + + state.get = MagicMock(side_effect=_get) + state.set = MagicMock(side_effect=_set) + state.delete = MagicMock(side_effect=_delete) + return state + + +@pytest.fixture +def mock_context(mock_state): # type: ignore[no-untyped-def] + from unittest.mock import AsyncMock, MagicMock + + ctx = MagicMock() + ctx.state = mock_state + ctx.send_message = AsyncMock() + ctx.yield_output = AsyncMock() + ctx.request_info = AsyncMock() + return ctx + + +def _seed_state(mock_state) -> None: # type: ignore[no-untyped-def] + """Pre-seed the declarative state container as the executors expect.""" + from agent_framework_declarative._workflows import DECLARATIVE_STATE_KEY + + mock_state._data[DECLARATIVE_STATE_KEY] = { + "Local": {}, + "Custom": {}, + "Workflow": {}, + "System": { + "ConversationId": "00000000-0000-0000-0000-000000000000", + "LastMessage": {"Id": "", "Text": ""}, + "LastMessageText": "", + "LastMessageId": "", + }, + "Agent": {}, + "Conversation": {"messages": [], "history": []}, + "Inputs": {}, + } + + +class TestApprovalFlow: + @pytest.mark.asyncio + async def test_approval_required_emits_request_and_yields(self, mock_state, mock_context) -> None: # type: ignore[no-untyped-def] + from agent_framework_declarative._workflows._declarative_base import ActionTrigger + from 
agent_framework_declarative._workflows._executors_mcp import ( + _MCP_APPROVAL_STATE_KEY, + InvokeMcpToolActionExecutor, + MCPToolApprovalRequest, + ) + + _seed_state(mock_state) + handler = StubMcpHandler(_ok()) + executor = InvokeMcpToolActionExecutor( + _action( + require_approval=True, + arguments={"q": "x"}, + headers={"Authorization": "Bearer SECRET"}, + output={"result": "Local.Result"}, + ), + mcp_tool_handler=handler, + ) + await executor.handle_action(ActionTrigger(), mock_context) + + # Approval request emitted. + mock_context.request_info.assert_called_once() + request = mock_context.request_info.call_args[0][0] + assert isinstance(request, MCPToolApprovalRequest) + assert request.tool_name == "search" + assert request.arguments == {"q": "x"} + assert request.header_names == ["Authorization"] + + # NEVER expose the actual auth token in any field of the approval payload. + for value in request.__dict__.values(): + assert "SECRET" not in str(value) + + # Workflow should yield (no ActionComplete sent yet). + mock_context.send_message.assert_not_called() + + # Handler not invoked yet. + assert handler.call_count == 0 + + # Approval state stored. + approval_key = f"{_MCP_APPROVAL_STATE_KEY}_mcp_action" + assert approval_key in mock_state._data + + @pytest.mark.asyncio + async def test_approval_response_approved_invokes_handler(self, mock_state, mock_context) -> None: # type: ignore[no-untyped-def] + from agent_framework_declarative._workflows import ActionComplete, ToolApprovalResponse + from agent_framework_declarative._workflows._executors_mcp import ( + _MCP_APPROVAL_STATE_KEY, + InvokeMcpToolActionExecutor, + MCPToolApprovalRequest, + _MCPToolApprovalState, + ) + + _seed_state(mock_state) + handler = StubMcpHandler(_ok([Content.from_text('{"ok":true}')])) + executor = InvokeMcpToolActionExecutor( + _action( + require_approval=True, + output={"result": "Local.Result"}, + ), + mcp_tool_handler=handler, + ) + # Pre-populate approval state. 
+ approval_key = f"{_MCP_APPROVAL_STATE_KEY}_mcp_action" + mock_state._data[approval_key] = _MCPToolApprovalState( + server_url="https://mcp.example/api", + tool_name="search", + server_label=None, + arguments={"q": "x"}, + connection_name=None, + headers_def={"Authorization": "Bearer tk"}, + auto_send=False, + conversation_id_expr=None, + output_messages_path=None, + output_result_path="Local.Result", + ) + await executor.handle_approval_response( + MCPToolApprovalRequest( + request_id="req-1", + tool_name="search", + server_url="https://mcp.example/api", + server_label=None, + arguments={"q": "x"}, + ), + ToolApprovalResponse(approved=True), + mock_context, + ) + + assert handler.call_count == 1 + inv = handler.last_invocation + assert inv is not None + # Headers are re-evaluated from headers_def. + assert inv.headers == {"Authorization": "Bearer tk"} + # Approval state was cleaned up. + assert approval_key not in mock_state._data + # ActionComplete was sent. + mock_context.send_message.assert_called_once() + sent = mock_context.send_message.call_args[0][0] + assert isinstance(sent, ActionComplete) + + @pytest.mark.asyncio + async def test_approval_response_rejected_assigns_error(self, mock_state, mock_context) -> None: # type: ignore[no-untyped-def] + from agent_framework_declarative._workflows import ToolApprovalResponse + from agent_framework_declarative._workflows._executors_mcp import ( + _MCP_APPROVAL_STATE_KEY, + InvokeMcpToolActionExecutor, + MCPToolApprovalRequest, + _MCPToolApprovalState, + ) + + _seed_state(mock_state) + handler = StubMcpHandler(_ok()) + executor = InvokeMcpToolActionExecutor( + _action( + require_approval=True, + output={"result": "Local.Result"}, + ), + mcp_tool_handler=handler, + ) + approval_key = f"{_MCP_APPROVAL_STATE_KEY}_mcp_action" + mock_state._data[approval_key] = _MCPToolApprovalState( + server_url="https://mcp.example/api", + tool_name="search", + server_label=None, + arguments={}, + connection_name=None, + 
headers_def=None, + auto_send=True, + conversation_id_expr=None, + output_messages_path=None, + output_result_path="Local.Result", + ) + await executor.handle_approval_response( + MCPToolApprovalRequest( + request_id="req-2", + tool_name="search", + server_url="https://mcp.example/api", + server_label=None, + arguments={}, + ), + ToolApprovalResponse(approved=False, reason="not authorized"), + mock_context, + ) + + assert handler.call_count == 0 + # Error string assigned at output.result. + from agent_framework_declarative._workflows import DECLARATIVE_STATE_KEY + + result = mock_state._data[DECLARATIVE_STATE_KEY]["Local"]["Result"] + assert result == "Error: MCP tool invocation was not approved by user." + + +# ---------- Error handling ------------------------------------------------- + + +class TestErrorHandling: + @pytest.mark.asyncio + async def test_handler_returns_error_result_assigns_error_string(self) -> None: + handler = StubMcpHandler(_err("server down")) + factory = WorkflowFactory(mcp_tool_handler=handler) + workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"}))) + await workflow.run({}) + decl = workflow._state.get(DECLARATIVE_STATE_KEY) + assert decl["Local"]["Result"] == "Error: server down" + + @pytest.mark.asyncio + async def test_tool_execution_exception_becomes_error_result(self) -> None: + handler = StubMcpHandler(raise_exc=ToolExecutionException("invalid arguments")) + factory = WorkflowFactory(mcp_tool_handler=handler) + workflow = factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"}))) + await workflow.run({}) + decl = workflow._state.get(DECLARATIVE_STATE_KEY) + assert decl["Local"]["Result"] == "Error: invalid arguments" + + @pytest.mark.asyncio + async def test_httpx_error_becomes_error_result(self) -> None: + handler = StubMcpHandler(raise_exc=httpx.ConnectError("dns fail")) + factory = WorkflowFactory(mcp_tool_handler=handler) + workflow = 
factory.create_workflow_from_definition(_yaml(_action(output={"result": "Local.Result"}))) + await workflow.run({}) + decl = workflow._state.get(DECLARATIVE_STATE_KEY) + result = decl["Local"]["Result"] + assert isinstance(result, str) + assert result.startswith("Error:") + assert "ConnectError" in result + + @pytest.mark.asyncio + async def test_unexpected_exception_propagates(self) -> None: + """Programmer bugs (TypeError etc.) must NOT be swallowed.""" + handler = StubMcpHandler(raise_exc=TypeError("bad type")) + factory = WorkflowFactory(mcp_tool_handler=handler) + workflow = factory.create_workflow_from_definition(_yaml(_action())) + with pytest.raises(Exception) as excinfo: + await workflow.run({}) + # Either the TypeError reaches us or it gets wrapped by the runner — + # either way the message must surface. + assert "bad type" in str(excinfo.value) + + +# ---------- autoSend ------------------------------------------------------- + + +class TestAutoSend: + @pytest.mark.asyncio + async def test_auto_send_default_true_yields_output(self) -> None: + handler = StubMcpHandler(_ok([Content.from_text("hello")])) + factory = WorkflowFactory(mcp_tool_handler=handler) + workflow = factory.create_workflow_from_definition(_yaml(_action())) + events = await workflow.run({}) + outputs = events.get_outputs() + assert len(outputs) == 1 + + @pytest.mark.asyncio + async def test_auto_send_false_suppresses_yield(self) -> None: + handler = StubMcpHandler(_ok([Content.from_text("hello")])) + factory = WorkflowFactory(mcp_tool_handler=handler) + workflow = factory.create_workflow_from_definition(_yaml(_action(output={"autoSend": False}))) + events = await workflow.run({}) + outputs = events.get_outputs() + assert outputs == [] + + +# ---------- Protocol structure -------------------------------------------- + + +class TestProtocol: + def test_stub_handler_satisfies_protocol(self) -> None: + handler = StubMcpHandler(_ok()) + assert isinstance(handler, MCPToolHandler) + + +# 
---------- _format_outputs_for_send -------------------------------------- + + +class TestFormatOutputsForSend: + """Direct tests for the auto-send rendering helper. + + Regression for PR #5630 review-comment 4: a single scalar JSON value + must render bare (e.g. ``"42"``) rather than wrapped (``"[42]"``). + """ + + @pytest.mark.parametrize( + ("parsed", "expected"), + [ + ([], ""), + (["hello"], "hello"), + (["a", "b"], "a\nb"), + ([42], "42"), + ([3.14], "3.14"), + ([True], "true"), + ([False], "false"), + ([None], "null"), + ([{"k": "v"}], '{"k": "v"}'), + ([[1, 2]], "[1, 2]"), + (["hello", 42], '["hello", 42]'), + ([{"a": 1}, {"b": 2}], '[{"a": 1}, {"b": 2}]'), + ], + ) + def test_format_outputs_for_send(self, parsed: list[Any], expected: str) -> None: + from agent_framework_declarative._workflows._executors_mcp import _format_outputs_for_send + + assert _format_outputs_for_send(parsed) == expected diff --git a/python/samples/03-workflows/declarative/invoke_mcp_tool/main.py b/python/samples/03-workflows/declarative/invoke_mcp_tool/main.py new file mode 100644 index 0000000000..85b513b562 --- /dev/null +++ b/python/samples/03-workflows/declarative/invoke_mcp_tool/main.py @@ -0,0 +1,201 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Invoke MCP Tool sample - demonstrates the InvokeMcpTool declarative action. + +This sample shows how to: + 1. Configure a ``WorkflowFactory`` with a ``MCPToolHandler`` so the YAML + ``InvokeMcpTool`` action can dispatch real MCP tool calls. + 2. Invoke a tool on a public unauthenticated MCP server (the Microsoft + Learn Docs MCP server at ``https://learn.microsoft.com/api/mcp``, + calling ``microsoft_docs_search``). + 3. Bind the parsed tool result to a workflow variable and mirror it into + the conversation via ``conversationId`` so a downstream Foundry agent + can answer questions using only that context. + 4. Optionally pause the MCP tool call for human approval. 
The YAML reads + ``requireApproval`` from ``Workflow.Inputs.requireApproval`` so the + host can flip the behaviour without editing the workflow definition. + Set the ``MCP_REQUIRE_APPROVAL`` environment variable (``1`` / ``true`` + / ``yes`` / ``on``) to enable the approval flow; leave it unset for the + "fire-and-forget" default. + +Security note: + ``DefaultMCPToolHandler`` connects to whatever MCP server URL the + workflow author specifies and performs **no** allowlisting or SSRF + guards. For production use, replace it with a custom handler that + enforces an allowlist and adds any required authentication headers + per server. MCP tool outputs flow back into agent conversations and + therefore share the same prompt-injection risk surface as + ``HttpRequestAction``: only invoke MCP servers you trust. + + The approval flow is also a defence-in-depth control: even with a + trusted server, requiring human approval lets a reviewer inspect + tool name, arguments, and outbound header NAMES (never values) + before any network call is made. + +Run with: + python samples/03-workflows/declarative/invoke_mcp_tool/main.py + +Run with approval prompts: + MCP_REQUIRE_APPROVAL=1 python samples/03-workflows/declarative/invoke_mcp_tool/main.py +""" + +import asyncio +import os +from pathlib import Path + +from agent_framework import Agent +from agent_framework.declarative import ( + DefaultMCPToolHandler, + MCPToolApprovalRequest, + ToolApprovalResponse, + WorkflowFactory, +) +from agent_framework.foundry import FoundryChatClient +from azure.identity import AzureCliCredential + +DOCS_AGENT_INSTRUCTIONS = """\ +You answer the user's question about Microsoft technology using ONLY the +search results already present in the conversation history. If the answer is +not contained in the conversation, say so plainly rather than guessing. Be +concise and cite the relevant document title or URL when possible.
# Values of MCP_REQUIRE_APPROVAL (after stripping and lower-casing) that turn
# the human-approval flow on. Anything else, including unset, leaves it off.
_TRUTHY = {"1", "true", "yes", "on"}


def _read_require_approval_flag() -> bool:
    """Return True when the MCP_REQUIRE_APPROVAL env var requests approval.

    The value is compared case-insensitively, ignoring surrounding
    whitespace, against ``_TRUTHY``. An unset variable reads as False.
    """
    raw_flag = os.environ.get("MCP_REQUIRE_APPROVAL", "")
    return raw_flag.strip().lower() in _TRUTHY


def _prompt_for_approval(request: MCPToolApprovalRequest) -> ToolApprovalResponse:
    """Render the pending MCP call to stdout and read approve/reject from the user.

    Args:
        request: The approval request emitted by the workflow. Its tool name,
            server label/URL, arguments and outbound header names are shown.

    Returns:
        An approving response for ``y`` / ``yes``; a rejecting response (with
        an optional free-text reason) for an empty answer, ``n`` or ``no``.
        If stdin is closed before an answer is read, the call is rejected
        rather than crashing, so the prompt fails closed.
    """
    rule = "-" * 60
    print()
    print(rule)
    print("MCP tool approval required")
    print(rule)
    print(f" tool: {request.tool_name}")
    print(f" server label: {request.server_label or '(unset)'}")
    print(f" server url: {request.server_url}")
    if request.arguments:
        print(" arguments:")
        for key, value in request.arguments.items():
            print(f" {key}: {value!r}")
    if request.header_names:
        # Only NAMES are surfaced; values are intentionally withheld because
        # they typically carry authentication secrets.
        print(f" outbound header names: {', '.join(request.header_names)}")
    else:
        print(" outbound header names: (none)")
    print(rule)

    while True:
        try:
            answer = input("Approve this MCP call? [y/N] ").strip().lower()  # noqa: ASYNC250
        except EOFError:
            # No interactive stdin (piped / closed): an approval gate must not
            # silently approve, and a traceback is unhelpful. Reject instead.
            print()
            return ToolApprovalResponse(
                approved=False,
                reason="No interactive input available; rejected by default.",
            )
        if answer in {"y", "yes"}:
            return ToolApprovalResponse(approved=True)
        if answer in {"", "n", "no"}:
            try:
                reason = input("Reason for rejection (optional): ").strip()  # noqa: ASYNC250
            except EOFError:
                reason = ""
            return ToolApprovalResponse(approved=False, reason=reason or None)
        print("Please answer 'y' or 'n'.")


async def main() -> None:
    """Run the invoke MCP tool workflow.

    Builds the Foundry-backed ``DocsAgent``, loads ``workflow.yaml`` from this
    directory, reads one question from stdin, and streams the workflow's
    output to stdout. Approval requests raised by the workflow are answered
    interactively via ``_prompt_for_approval``.

    Raises:
        KeyError: If ``FOUNDRY_PROJECT_ENDPOINT`` or ``FOUNDRY_MODEL`` is not
            set in the environment.
    """
    chat_client = FoundryChatClient(
        project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"],
        model=os.environ["FOUNDRY_MODEL"],
        credential=AzureCliCredential(),
    )

    # The agent has no tools — it answers using only the search results that
    # ``InvokeMcpTool`` adds to the conversation.
    docs_agent = Agent(
        client=chat_client,
        name="DocsAgent",
        instructions=DOCS_AGENT_INSTRUCTIONS,
    )

    agents = {"DocsAgent": docs_agent}

    require_approval = _read_require_approval_flag()

    # The default MCPToolHandler is sufficient for this sample because the
    # Microsoft Learn Docs MCP server is public and unauthenticated. For
    # authenticated servers, supply a ``client_provider`` callback to route
    # requests through a pre-configured ``httpx.AsyncClient`` carrying the
    # appropriate credentials, or wrap the handler with one that injects
    # headers per call.
    async with DefaultMCPToolHandler() as mcp_handler:
        factory = WorkflowFactory(
            agents=agents,
            mcp_tool_handler=mcp_handler,
        )

        workflow_path = Path(__file__).parent / "workflow.yaml"
        workflow = factory.create_workflow_from_yaml_path(workflow_path)

        print("=" * 60)
        print("Invoke MCP Tool Workflow Demo")
        if require_approval:
            print("(MCP_REQUIRE_APPROVAL is set — you will be prompted before the tool runs)")
        else:
            print("(set MCP_REQUIRE_APPROVAL=1 to enable the human-approval flow)")
        print("=" * 60)
        print()
        print("Ask one question that can be answered from the Microsoft Learn docs or provide a keyword to search.")
        print()

        try:
            user_input = input("You: ").strip()  # noqa: ASYNC250
        except EOFError:
            # Non-interactive run: fall through to the default question below.
            user_input = ""
        if not user_input:
            user_input = "What is the Agent Framework declarative workflow runtime?"

        # Drive the workflow via dict-shaped inputs so the YAML can read
        # both the user's question (``Workflow.Inputs.text``) and the
        # approval toggle (``Workflow.Inputs.requireApproval``) without
        # any Python-side mutation of the workflow definition.
        workflow_inputs: dict[str, object] = {
            "text": user_input,
            "requireApproval": require_approval,
        }

        # The request_info loop below handles the MCP approval flow when
        # the YAML requests it. When ``requireApproval`` is false the
        # workflow never emits an ``MCPToolApprovalRequest`` event, so
        # the loop runs exactly once and exits cleanly — both modes share
        # the same code path.
        #
        # Pending requests are keyed by request id so that every approval
        # request seen in one stream is answered on the next run; keeping
        # only the most recent one would leave earlier requests unanswered.
        pending: dict[str, MCPToolApprovalRequest] = {}
        printed_output = False

        while True:
            if pending:
                responses = {
                    request_id: _prompt_for_approval(approval_request)
                    for request_id, approval_request in pending.items()
                }
                pending = {}
                stream = workflow.run(stream=True, responses=responses)
            else:
                # Only reached on the first pass: the loop exits as soon as
                # a stream finishes with nothing pending.
                stream = workflow.run(workflow_inputs, stream=True)

            async for event in stream:
                if event.type == "output" and isinstance(event.data, str):
                    if not printed_output:
                        print("\nAgent: ", end="", flush=True)
                        printed_output = True
                    print(event.data, end="", flush=True)
                elif event.type == "request_info" and isinstance(event.data, MCPToolApprovalRequest):
                    pending[event.request_id] = event.data

            if pending:
                continue

            if printed_output:
                print()
            else:
                # Workflow finished without producing any agent output
                # (e.g. the user rejected the MCP tool call and the
                # downstream agent had nothing to summarise).
                print("\n(no response produced)")
            break


if __name__ == "__main__":
    asyncio.run(main())
# It can:
#
#   - dispatch a tool call against an MCP server (with optional auth headers),
#   - store the parsed tool result in a workflow variable, and
#   - add the result to the conversation so a downstream agent can answer
#     questions based on it.
#
# This sample calls ``microsoft_docs_search`` on the public Microsoft Learn
# Docs MCP server (no authentication required) and uses a Foundry agent to
# answer a single question about Microsoft technology using the search
# results.
#
# Example inputs (choose one or provide your own):
#   How do I configure logging in the Agent Framework?
#   Gpt-5.4-mini
#
# Workflow inputs (set by the host via ``workflow.run({...})``):
#   text:            The user's question (required).
#   requireApproval: Optional bool. When true, the MCP tool call pauses for
#                    human approval before contacting the server. Defaults
#                    to false when omitted.
#
kind: Workflow
trigger:

  kind: OnConversationStart
  id: workflow_invoke_mcp_tool_demo
  actions:

    # Capture the user's question into a local variable so the MCP tool call
    # can pass it as an argument.
    - kind: SetVariable
      id: capture_query
      variable: Local.SearchQuery
      value: =Workflow.Inputs.text

    # Invoke microsoft_docs_search on the Microsoft Learn Docs MCP server.
    # The result is parsed into Local.SearchResults and also added to the
    # conversation (via conversationId) so the agent below can answer the
    # user's question based on it.
    #
    # ``requireApproval`` reads from Workflow.Inputs so the host can toggle
    # the human-approval flow without editing this YAML. When the input is
    # absent or evaluates to a falsy value, the tool runs without pausing.
    #
    # ``autoSend: false`` keeps the raw tool result out of the workflow's
    # own outputs; only the agent's answer below is sent.
    - kind: InvokeMcpTool
      id: search_docs
      conversationId: =System.ConversationId
      serverUrl: https://learn.microsoft.com/api/mcp
      serverLabel: MicrosoftLearnDocs
      toolName: microsoft_docs_search
      requireApproval: =Workflow.Inputs.requireApproval
      arguments:
        query: =Local.SearchQuery
      output:
        autoSend: false
        result: Local.SearchResults

    # Use the agent to answer the user's question using the conversation
    # context (which now contains the MCP search results). The user's
    # question is supplied via ``input.messages`` (sourced from the workflow
    # inputs), and the prior conversation history is bound via
    # ``conversationId``.
    - kind: InvokeAzureAgent
      id: answer_question
      conversationId: =System.ConversationId
      agent:
        name: DocsAgent
      input:
        messages: =Workflow.Inputs.text
      output:
        autoSend: true
        messages: Local.AgentResponse