mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add Python parity for InvokeMcpTool in declarative workflow (#5630)
* Add Python parity for HttpRequestAction in declarative workflow
* Ran pyupgrade and pyright to fix CI issues
* Fix conversation ID dot parsing for http executor
* Removed unnecessary export command
* Initial implementation of invoke mcp tool in python
* Update sample to support require approval to be toggled by environment variable.
* Fix cache and PR comments
* Update python/samples/03-workflows/declarative/invoke_mcp_tool/main.py

Co-authored-by: Eduard van Valkenburg <eavanvalkenburg@users.noreply.github.com>

---------

Co-authored-by: Eduard van Valkenburg <eavanvalkenburg@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
f3f71f0fe8
commit
f25e81701d
@@ -9,6 +9,7 @@ YAML/JSON-based declarative agent and workflow definitions.
|
||||
- **`WorkflowState`** - State management for declarative workflows
|
||||
- **`ProviderTypeMapping`** - Maps provider types to implementations
|
||||
- **`HttpRequestHandler`** / **`DefaultHttpRequestHandler`** - Pluggable HTTP transport for the `HttpRequestAction` declarative action (configured via `WorkflowFactory(http_request_handler=...)`)
|
||||
- **`MCPToolHandler`** / **`DefaultMCPToolHandler`** - Pluggable MCP transport for the `InvokeMcpTool` declarative action (configured via `WorkflowFactory(mcp_tool_handler=...)`)
|
||||
- **`DeclarativeLoaderError`** / **`ProviderLookupError`** / **`DeclarativeWorkflowError`** / **`DeclarativeActionError`** - Error types
|
||||
|
||||
## External Input Handling
|
||||
|
||||
@@ -9,11 +9,18 @@ from ._workflows import (
|
||||
DeclarativeActionError,
|
||||
DeclarativeWorkflowError,
|
||||
DefaultHttpRequestHandler,
|
||||
DefaultMCPToolHandler,
|
||||
ExternalInputRequest,
|
||||
ExternalInputResponse,
|
||||
HttpRequestHandler,
|
||||
HttpRequestInfo,
|
||||
HttpRequestResult,
|
||||
MCPToolApprovalRequest,
|
||||
MCPToolHandler,
|
||||
MCPToolInvocation,
|
||||
MCPToolResult,
|
||||
ToolApprovalRequest,
|
||||
ToolApprovalResponse,
|
||||
WorkflowFactory,
|
||||
WorkflowState,
|
||||
)
|
||||
@@ -31,13 +38,20 @@ __all__ = [
|
||||
"DeclarativeLoaderError",
|
||||
"DeclarativeWorkflowError",
|
||||
"DefaultHttpRequestHandler",
|
||||
"DefaultMCPToolHandler",
|
||||
"ExternalInputRequest",
|
||||
"ExternalInputResponse",
|
||||
"HttpRequestHandler",
|
||||
"HttpRequestInfo",
|
||||
"HttpRequestResult",
|
||||
"MCPToolApprovalRequest",
|
||||
"MCPToolHandler",
|
||||
"MCPToolInvocation",
|
||||
"MCPToolResult",
|
||||
"ProviderLookupError",
|
||||
"ProviderTypeMapping",
|
||||
"ToolApprovalRequest",
|
||||
"ToolApprovalResponse",
|
||||
"WorkflowFactory",
|
||||
"WorkflowState",
|
||||
"__version__",
|
||||
|
||||
@@ -72,6 +72,11 @@ from ._executors_http import (
|
||||
HTTP_ACTION_EXECUTORS,
|
||||
HttpRequestActionExecutor,
|
||||
)
|
||||
from ._executors_mcp import (
|
||||
MCP_ACTION_EXECUTORS,
|
||||
InvokeMcpToolActionExecutor,
|
||||
MCPToolApprovalRequest,
|
||||
)
|
||||
from ._executors_tools import (
|
||||
FUNCTION_TOOL_REGISTRY_KEY,
|
||||
TOOL_ACTION_EXECUTORS,
|
||||
@@ -90,6 +95,12 @@ from ._http_handler import (
|
||||
HttpRequestInfo,
|
||||
HttpRequestResult,
|
||||
)
|
||||
from ._mcp_handler import (
|
||||
DefaultMCPToolHandler,
|
||||
MCPToolHandler,
|
||||
MCPToolInvocation,
|
||||
MCPToolResult,
|
||||
)
|
||||
from ._state import WorkflowState
|
||||
|
||||
__all__ = [
|
||||
@@ -102,6 +113,7 @@ __all__ = [
|
||||
"EXTERNAL_INPUT_EXECUTORS",
|
||||
"FUNCTION_TOOL_REGISTRY_KEY",
|
||||
"HTTP_ACTION_EXECUTORS",
|
||||
"MCP_ACTION_EXECUTORS",
|
||||
"TOOL_ACTION_EXECUTORS",
|
||||
"TOOL_APPROVAL_STATE_KEY",
|
||||
"TOOL_REGISTRY_KEY",
|
||||
@@ -126,6 +138,7 @@ __all__ = [
|
||||
"DeclarativeWorkflowError",
|
||||
"DeclarativeWorkflowState",
|
||||
"DefaultHttpRequestHandler",
|
||||
"DefaultMCPToolHandler",
|
||||
"EmitEventExecutor",
|
||||
"EndConversationExecutor",
|
||||
"EndWorkflowExecutor",
|
||||
@@ -140,9 +153,14 @@ __all__ = [
|
||||
"HttpRequestResult",
|
||||
"InvokeAzureAgentExecutor",
|
||||
"InvokeFunctionToolExecutor",
|
||||
"InvokeMcpToolActionExecutor",
|
||||
"JoinExecutor",
|
||||
"LoopControl",
|
||||
"LoopIterationResult",
|
||||
"MCPToolApprovalRequest",
|
||||
"MCPToolHandler",
|
||||
"MCPToolInvocation",
|
||||
"MCPToolResult",
|
||||
"QuestionExecutor",
|
||||
"RequestExternalInputExecutor",
|
||||
"ResetVariableExecutor",
|
||||
|
||||
+22
@@ -41,8 +41,10 @@ from ._executors_control_flow import (
|
||||
)
|
||||
from ._executors_external_input import EXTERNAL_INPUT_EXECUTORS
|
||||
from ._executors_http import HTTP_ACTION_EXECUTORS, HttpRequestActionExecutor
|
||||
from ._executors_mcp import MCP_ACTION_EXECUTORS, InvokeMcpToolActionExecutor
|
||||
from ._executors_tools import TOOL_ACTION_EXECUTORS, InvokeFunctionToolExecutor
|
||||
from ._http_handler import HttpRequestHandler
|
||||
from ._mcp_handler import MCPToolHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -55,6 +57,7 @@ ALL_ACTION_EXECUTORS = {
|
||||
**EXTERNAL_INPUT_EXECUTORS,
|
||||
**TOOL_ACTION_EXECUTORS,
|
||||
**HTTP_ACTION_EXECUTORS,
|
||||
**MCP_ACTION_EXECUTORS,
|
||||
}
|
||||
|
||||
# Action kinds that terminate control flow (no fall-through to successor)
|
||||
@@ -90,6 +93,7 @@ ACTION_REQUIRED_FIELDS: dict[str, list[str]] = {
|
||||
"EmitEvent": ["event"],
|
||||
"InvokeFunctionTool": ["functionName"],
|
||||
"HttpRequestAction": ["url"],
|
||||
"InvokeMcpTool": ["serverUrl", "toolName"],
|
||||
}
|
||||
|
||||
# Alternate field names that satisfy required field requirements
|
||||
@@ -135,6 +139,7 @@ class DeclarativeWorkflowBuilder:
|
||||
validate: bool = True,
|
||||
max_iterations: int | None = None,
|
||||
http_request_handler: HttpRequestHandler | None = None,
|
||||
mcp_tool_handler: MCPToolHandler | None = None,
|
||||
):
|
||||
"""Initialize the builder.
|
||||
|
||||
@@ -150,6 +155,9 @@ class DeclarativeWorkflowBuilder:
|
||||
http_request_handler: Handler used to dispatch HttpRequestAction requests.
|
||||
Must be supplied when the workflow contains any HttpRequestAction;
|
||||
otherwise build raises ``DeclarativeWorkflowError``.
|
||||
mcp_tool_handler: Handler used to dispatch InvokeMcpTool calls.
|
||||
Must be supplied when the workflow contains any InvokeMcpTool;
|
||||
otherwise build raises ``DeclarativeWorkflowError``.
|
||||
"""
|
||||
self._yaml_def = yaml_definition
|
||||
self._workflow_id = workflow_id or yaml_definition.get("name", "declarative_workflow")
|
||||
@@ -162,6 +170,7 @@ class DeclarativeWorkflowBuilder:
|
||||
self._validate = validate
|
||||
self._seen_explicit_ids: set[str] = set() # Track explicit IDs for duplicate detection
|
||||
self._http_request_handler = http_request_handler
|
||||
self._mcp_tool_handler = mcp_tool_handler
|
||||
# Resolve max_iterations: explicit arg > YAML maxTurns > core default
|
||||
resolved = max_iterations if max_iterations is not None else yaml_definition.get("maxTurns")
|
||||
if resolved is not None and (not isinstance(resolved, int) or resolved <= 0):
|
||||
@@ -481,6 +490,19 @@ class DeclarativeWorkflowBuilder:
|
||||
id=action_id,
|
||||
http_request_handler=self._http_request_handler,
|
||||
)
|
||||
elif kind == "InvokeMcpTool":
|
||||
if self._mcp_tool_handler is None:
|
||||
raise DeclarativeWorkflowError(
|
||||
f"Workflow defines InvokeMcpTool '{action_id}' but no "
|
||||
"mcp_tool_handler was supplied to WorkflowFactory. Pass "
|
||||
"mcp_tool_handler=DefaultMCPToolHandler() (or a custom "
|
||||
"implementation) to enable MCP tool invocations."
|
||||
)
|
||||
executor = InvokeMcpToolActionExecutor(
|
||||
action_def,
|
||||
id=action_id,
|
||||
mcp_tool_handler=self._mcp_tool_handler,
|
||||
)
|
||||
else:
|
||||
executor = executor_class(action_def, id=action_id)
|
||||
self._executors[action_id] = executor
|
||||
|
||||
@@ -0,0 +1,614 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Executor for the ``InvokeMcpTool`` declarative action.
|
||||
|
||||
Mirrors the .NET ``InvokeMcpToolExecutor``: dispatches an MCP tool call through
|
||||
the configured :class:`MCPToolHandler`, parses tool outputs, and routes
|
||||
results to the configured ``output.{result, messages, autoSend}`` paths and
|
||||
optional conversation history. Supports a human-in-loop approval flow via
|
||||
``ctx.request_info()`` / :func:`@response_handler` for ``requireApproval=true``.
|
||||
|
||||
Security notes:
|
||||
|
||||
- The executor never echoes header VALUES (auth tokens, API keys) into the
|
||||
approval request — only header NAMES are surfaced to the caller. This
|
||||
matches the security posture of :mod:`._executors_http` (which never logs
|
||||
request headers either) and prevents secrets from leaking through workflow
|
||||
events that are typically observable to operators / UIs.
|
||||
- ``_MCPToolApprovalState`` snapshots the EVALUATED values for non-secret
|
||||
fields (server URL, tool name, arguments) at approval-request time so that
|
||||
subsequent state mutations cannot make the executor "approve X then call
|
||||
Y". Headers are stored as the raw expression strings (not evaluated values)
|
||||
so secrets are not persisted in the workflow's checkpoint state. They are
|
||||
re-evaluated on resume.
|
||||
- Tool outputs flow back into agent conversations through ``conversationId``
|
||||
and through Tool-role messages emitted to ``output.messages``. They share
|
||||
the same prompt-injection risk surface as ``HttpRequestAction``: workflow
|
||||
authors must trust the MCP server they invoke.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from agent_framework import (
|
||||
Content,
|
||||
Message,
|
||||
WorkflowContext,
|
||||
handler,
|
||||
response_handler,
|
||||
)
|
||||
from agent_framework.exceptions import ToolExecutionException
|
||||
|
||||
from ._declarative_base import (
|
||||
ActionComplete,
|
||||
DeclarativeActionExecutor,
|
||||
DeclarativeWorkflowState,
|
||||
)
|
||||
from ._executors_tools import ToolApprovalResponse
|
||||
from ._mcp_handler import MCPToolHandler, MCPToolInvocation, MCPToolResult
|
||||
|
||||
__all__ = [
|
||||
"MCP_ACTION_EXECUTORS",
|
||||
"InvokeMcpToolActionExecutor",
|
||||
"MCPToolApprovalRequest",
|
||||
]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MCP_APPROVAL_STATE_KEY = "_mcp_tool_approval_state"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request / state types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPToolApprovalRequest:
|
||||
"""Approval request emitted before invoking an MCP tool.
|
||||
|
||||
Mirrors :class:`agent_framework_declarative.ToolApprovalRequest` but for
|
||||
MCP-style invocations. Only header NAMES are surfaced — header values are
|
||||
intentionally omitted because they typically carry authentication
|
||||
secrets.
|
||||
|
||||
Attributes:
|
||||
request_id: Unique identifier for this approval request. Matches the
|
||||
id workflow event-emitters use.
|
||||
tool_name: Evaluated name of the tool to be invoked.
|
||||
server_url: Evaluated MCP server URL.
|
||||
server_label: Optional human-readable label for diagnostics.
|
||||
arguments: Evaluated arguments to be forwarded to the tool.
|
||||
header_names: Sorted list of outbound header names (no values). Empty
|
||||
when no headers are configured.
|
||||
"""
|
||||
|
||||
request_id: str
|
||||
tool_name: str
|
||||
server_url: str
|
||||
server_label: str | None
|
||||
arguments: dict[str, Any]
|
||||
header_names: list[str] = field(default_factory=lambda: [])
|
||||
|
||||
|
||||
@dataclass
|
||||
class _MCPToolApprovalState:
|
||||
"""Internal state saved during the approval yield for resumption.
|
||||
|
||||
Stores **evaluated** values for non-secret fields to prevent
|
||||
"approve X / execute Y" attacks. Stores the raw expression string for
|
||||
``headers`` so that secret values are NOT persisted in checkpoint state;
|
||||
the expressions are re-evaluated against current state on resume.
|
||||
"""
|
||||
|
||||
server_url: str
|
||||
tool_name: str
|
||||
server_label: str | None
|
||||
arguments: dict[str, Any]
|
||||
connection_name: str | None
|
||||
headers_def: Any
|
||||
auto_send: bool
|
||||
conversation_id_expr: str | None
|
||||
output_messages_path: str | None
|
||||
output_result_path: str | None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_messages_path(state: DeclarativeWorkflowState, conversation_id_expr: str | None) -> str | None:
|
||||
"""Return the configured conversation messages path, if any.
|
||||
|
||||
Returns ``System.conversations.{evaluated_id}.messages`` when a
|
||||
``conversation_id_expr`` is configured and evaluates to a non-empty value.
|
||||
Returns ``None`` when no conversation id expression is configured or when
|
||||
the expression evaluates to ``None`` or an empty string (mirrors .NET
|
||||
``GetConversationId`` behaviour).
|
||||
"""
|
||||
if not conversation_id_expr:
|
||||
return None
|
||||
evaluated = state.eval_if_expression(conversation_id_expr)
|
||||
if evaluated is None or (isinstance(evaluated, str) and not evaluated):
|
||||
return None
|
||||
return f"System.conversations.{evaluated}.messages"
|
||||
|
||||
|
||||
def _get_output_path(action_def: Mapping[str, Any], key: str) -> str | None:
|
||||
"""Extract a state path from ``output.{key}`` field.
|
||||
|
||||
Supports two YAML shapes:
|
||||
|
||||
- ``output: { result: Local.MyVar }`` — plain string.
|
||||
- ``output: { result: { path: Local.MyVar } }`` — object form.
|
||||
"""
|
||||
output: Any = action_def.get("output")
|
||||
if not isinstance(output, Mapping):
|
||||
return None
|
||||
value: Any = output.get(key) # type: ignore[reportUnknownMemberType]
|
||||
if isinstance(value, str):
|
||||
return value or None
|
||||
if isinstance(value, Mapping):
|
||||
path: Any = value.get("path") # type: ignore[reportUnknownMemberType]
|
||||
return path if isinstance(path, str) and path else None
|
||||
return None
|
||||
|
||||
|
||||
def _format_outputs_for_send(parsed_results: list[Any]) -> str:
    """Turn parsed MCP outputs into the string passed to ``ctx.yield_output(...)``.

    - Empty list → ``""``.
    - List made only of strings → the strings joined with newlines.
    - Exactly one element (scalar, dict or list) → that element JSON-dumped,
      so a single scalar JSON value does not render as ``"[42]"`` /
      ``"[true]"`` / ``"[null]"``.
    - Several elements with at least one non-string → the whole list
      JSON-dumped.
    """
    if not parsed_results:
        return ""

    has_non_string = any(not isinstance(item, str) for item in parsed_results)
    if not has_non_string:
        return "\n".join(parsed_results)  # type: ignore[arg-type]

    # A lone element is dumped bare; otherwise dump the list itself.
    payload = parsed_results[0] if len(parsed_results) == 1 else parsed_results
    return json.dumps(payload, ensure_ascii=False)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Executor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class InvokeMcpToolActionExecutor(DeclarativeActionExecutor):
    """Executor for the ``InvokeMcpTool`` declarative action.

    Dispatches through the supplied :class:`MCPToolHandler` and:

    - Evaluates ``serverUrl`` / ``toolName`` / ``serverLabel`` / ``arguments``
      / ``headers`` / ``connection.name`` from the action definition.
    - When ``requireApproval=true``: emits a :class:`MCPToolApprovalRequest`
      via ``ctx.request_info()`` and yields. On resume, the response is
      checked; on rejection, ``output.result`` is set to ``"Error: ..."`` and
      no tool call is made.
    - On success: parses each :class:`agent_framework.Content` output (text →
      JSON-first / data / uri → URI string) and assigns the parsed list to
      ``output.result``. Builds a single Tool-role :class:`Message`
      containing all output contents and assigns it to ``output.messages``.
      When ``output.autoSend`` is true (default), emits the rendered string
      via ``ctx.yield_output(...)``. When ``conversationId`` is configured,
      appends an Assistant-role :class:`Message` with the same contents to
      ``System.conversations.{id}.messages``.
    - On error returned by the handler (``is_error=True``): assigns
      ``"Error: <message>"`` to ``output.result`` and completes normally
      (parity with .NET ``AssignErrorAsync``).

    .. note::

        ``output.messages`` receives a SINGLE Tool-role :class:`Message`
        (containing the full tool output as ``contents``), unlike
        :class:`agent_framework_declarative.InvokeFunctionToolExecutor` which
        writes a list of two messages (assistant call + tool result). This
        matches the .NET ``InvokeMcpToolExecutor`` output contract.
    """

    def __init__(
        self,
        action_def: dict[str, Any],
        *,
        id: str | None = None,
        mcp_tool_handler: MCPToolHandler,
    ) -> None:
        """Create an MCP tool action executor.

        Args:
            action_def: Parsed ``InvokeMcpTool`` YAML dict.
            id: Optional executor id (defaults to action id or generated).
            mcp_tool_handler: Handler used to dispatch MCP tool calls.
                Required: the builder enforces presence at workflow-build
                time.
        """
        super().__init__(action_def, id=id)
        self._mcp_tool_handler = mcp_tool_handler

    # ----- Main handler --------------------------------------------------------

    @handler
    async def handle_action(
        self,
        trigger: Any,
        ctx: WorkflowContext[ActionComplete, str],
    ) -> None:
        """Execute the MCP tool action.

        Evaluates every field of the action, then either pauses for approval
        (``requireApproval`` true) or invokes the tool immediately and routes
        the result to the configured outputs.
        """
        state = await self._ensure_state_initialized(ctx, trigger)

        server_url = self._get_server_url(state)
        tool_name = self._get_tool_name(state)
        server_label = self._get_server_label(state)
        arguments = self._get_arguments(state)
        headers = self._get_headers(state)
        connection_name = self._get_connection_name(state)
        require_approval = self._get_require_approval(state)
        auto_send = self._get_auto_send(state)
        # Only a string ``conversationId`` is usable as an expression.
        raw_conversation_id = self._action_def.get("conversationId")
        conversation_id_expr = raw_conversation_id if isinstance(raw_conversation_id, str) else None
        output_messages_path = _get_output_path(self._action_def, "messages")
        output_result_path = _get_output_path(self._action_def, "result")

        if require_approval:
            request_id = str(uuid.uuid4())
            # Snapshot evaluated values; headers stay as raw expressions so
            # secret values are not written into workflow state.
            approval_state = _MCPToolApprovalState(
                server_url=server_url,
                tool_name=tool_name,
                server_label=server_label,
                arguments=arguments,
                connection_name=connection_name,
                headers_def=self._action_def.get("headers"),
                auto_send=auto_send,
                conversation_id_expr=conversation_id_expr,
                output_messages_path=output_messages_path,
                output_result_path=output_result_path,
            )
            ctx.state.set(self._approval_key(), approval_state)

            request = MCPToolApprovalRequest(
                request_id=request_id,
                tool_name=tool_name,
                server_url=server_url,
                server_label=server_label,
                arguments=arguments,
                # Names only: header values never leave the executor.
                header_names=sorted(headers.keys()),
            )
            logger.info(
                "%s: requesting approval for MCP tool '%s' on '%s'",
                self.__class__.__name__,
                tool_name,
                server_url,
            )
            await ctx.request_info(request, ToolApprovalResponse, request_id=request_id)
            # Workflow yields here — resume in handle_approval_response.
            return

        # No approval required - invoke directly.
        invocation = MCPToolInvocation(
            server_url=server_url,
            tool_name=tool_name,
            server_label=server_label,
            arguments=arguments,
            headers=headers,
            connection_name=connection_name,
        )
        result = await self._invoke_with_narrow_catch(invocation)
        await self._process_result(
            ctx=ctx,
            state=state,
            result=result,
            auto_send=auto_send,
            conversation_id_expr=conversation_id_expr,
            output_messages_path=output_messages_path,
            output_result_path=output_result_path,
        )
        await ctx.send_message(ActionComplete())

    # ----- Approval response handler ------------------------------------------

    @response_handler
    async def handle_approval_response(
        self,
        original_request: MCPToolApprovalRequest,
        response: ToolApprovalResponse,
        ctx: WorkflowContext[ActionComplete, str],
    ) -> None:
        """Resume after the workflow yielded for an approval request.

        Loads and removes the stored approval snapshot, then either records a
        rejection error or invokes the tool with the snapshotted values and
        freshly evaluated headers.
        """
        state = self._get_state(ctx.state)
        approval_key = self._approval_key()

        approval_state: _MCPToolApprovalState | None
        try:
            approval_state = ctx.state.get(approval_key)
        except KeyError:
            approval_state = None
        # NOTE(review): depending on the state implementation, ``get`` may
        # return ``None`` for a missing key instead of raising ``KeyError``;
        # both outcomes are treated as "snapshot missing" so the rejection /
        # invocation code below never dereferences ``None``.
        if approval_state is None:
            logger.error("%s: approval state missing for executor '%s'", self.__class__.__name__, self.id)
            await ctx.send_message(ActionComplete())
            return
        try:
            ctx.state.delete(approval_key)
        except KeyError:
            logger.warning("%s: approval state already deleted for '%s'", self.__class__.__name__, self.id)

        if not response.approved:
            logger.info(
                "%s: MCP tool '%s' rejected: %s",
                self.__class__.__name__,
                approval_state.tool_name,
                response.reason,
            )
            self._assign_error(
                state, approval_state.output_result_path, "MCP tool invocation was not approved by user."
            )
            await ctx.send_message(ActionComplete())
            return

        # Approved — re-evaluate headers (not stored at approval time for security).
        headers = self._evaluate_headers(state, approval_state.headers_def)

        invocation = MCPToolInvocation(
            server_url=approval_state.server_url,
            tool_name=approval_state.tool_name,
            server_label=approval_state.server_label,
            arguments=approval_state.arguments,
            headers=headers,
            connection_name=approval_state.connection_name,
        )
        result = await self._invoke_with_narrow_catch(invocation)
        await self._process_result(
            ctx=ctx,
            state=state,
            result=result,
            auto_send=approval_state.auto_send,
            conversation_id_expr=approval_state.conversation_id_expr,
            output_messages_path=approval_state.output_messages_path,
            output_result_path=approval_state.output_result_path,
        )
        await ctx.send_message(ActionComplete())

    # ----- Field resolution ----------------------------------------------------

    def _require_text_field(self, state: DeclarativeWorkflowState, field_name: str) -> str:
        """Evaluate a mandatory action field that must yield a non-empty string.

        Raises:
            ValueError: If the field is absent or does not evaluate to a
                non-empty string.
        """
        raw = self._action_def.get(field_name)
        if raw is None:
            raise ValueError(f"InvokeMcpTool requires a '{field_name}' field.")
        evaluated = state.eval_if_expression(raw)
        if not isinstance(evaluated, str) or not evaluated:
            raise ValueError(f"InvokeMcpTool '{field_name}' evaluated to an empty value.")
        return evaluated

    @staticmethod
    def _optional_text(state: DeclarativeWorkflowState, raw: Any) -> str | None:
        """Evaluate an optional field to text; ``None`` when absent or empty."""
        if raw is None:
            return None
        evaluated = state.eval_if_expression(raw)
        if evaluated is None:
            return None
        return str(evaluated) or None

    @staticmethod
    def _coerce_bool(evaluated: Any) -> bool:
        """Interpret an evaluated flag: bools as-is, ``true``/``1``/``yes`` strings, else truthiness."""
        if isinstance(evaluated, bool):
            return evaluated
        if isinstance(evaluated, str):
            return evaluated.strip().lower() in {"true", "1", "yes"}
        return bool(evaluated)

    def _get_server_url(self, state: DeclarativeWorkflowState) -> str:
        """Return the evaluated ``serverUrl``; raises ``ValueError`` if missing or empty."""
        return self._require_text_field(state, "serverUrl")

    def _get_tool_name(self, state: DeclarativeWorkflowState) -> str:
        """Return the evaluated ``toolName``; raises ``ValueError`` if missing or empty."""
        return self._require_text_field(state, "toolName")

    def _get_server_label(self, state: DeclarativeWorkflowState) -> str | None:
        """Return the evaluated ``serverLabel`` as text, or ``None`` when absent or empty."""
        return self._optional_text(state, self._action_def.get("serverLabel"))

    def _get_arguments(self, state: DeclarativeWorkflowState) -> dict[str, Any]:
        """Evaluate ``arguments`` map. Preserves ``None`` values (parity with .NET).

        Entries whose key is not a non-empty string are dropped. A missing,
        empty or non-mapping ``arguments`` yields ``{}``.
        """
        raw = self._action_def.get("arguments")
        # ``None`` is not a Mapping, so this also covers the absent case.
        if not isinstance(raw, Mapping) or not raw:
            return {}
        return {
            key: state.eval_if_expression(value)
            for key, value in raw.items()  # type: ignore[reportUnknownVariableType]
            if isinstance(key, str) and key
        }

    def _get_headers(self, state: DeclarativeWorkflowState) -> dict[str, str]:
        """Evaluate the action's own ``headers`` definition."""
        return self._evaluate_headers(state, self._action_def.get("headers"))

    @staticmethod
    def _evaluate_headers(state: DeclarativeWorkflowState, headers_def: Any) -> dict[str, str]:
        """Evaluate the ``headers`` map.

        Values are stringified. Entries with a non-string or empty key, or a
        value that evaluates to ``None`` or an empty string, are skipped.
        """
        if not isinstance(headers_def, Mapping) or not headers_def:
            return {}
        result: dict[str, str] = {}
        for key, value in headers_def.items():  # type: ignore[reportUnknownVariableType]
            if not isinstance(key, str) or not key:
                continue
            evaluated = state.eval_if_expression(value)
            if evaluated is None:
                continue
            text = str(evaluated)
            if text:
                result[key] = text
        return result

    def _get_connection_name(self, state: DeclarativeWorkflowState) -> str | None:
        """Return the evaluated ``connection.name`` as text, or ``None`` when absent or empty."""
        connection = self._action_def.get("connection")
        if not isinstance(connection, Mapping):
            return None
        name_expr: Any = connection.get("name")  # type: ignore[reportUnknownMemberType]
        return self._optional_text(state, name_expr)

    def _get_require_approval(self, state: DeclarativeWorkflowState) -> bool:
        """Return the evaluated ``requireApproval`` flag (defaults to ``False``)."""
        raw = self._action_def.get("requireApproval")
        if raw is None:
            return False
        return self._coerce_bool(state.eval_if_expression(raw))

    def _get_auto_send(self, state: DeclarativeWorkflowState) -> bool:
        """Return the evaluated ``output.autoSend`` flag (defaults to ``True``)."""
        output: Any = self._action_def.get("output")
        if not isinstance(output, Mapping):
            return True
        raw: Any = output.get("autoSend")  # type: ignore[reportUnknownMemberType]
        if raw is None:
            return True
        return self._coerce_bool(state.eval_if_expression(raw))

    # ----- Invocation + error handling ----------------------------------------

    @staticmethod
    def _error_result(message: str) -> MCPToolResult:
        """Build the normalised error result carrying ``message``."""
        return MCPToolResult(
            outputs=[Content.from_text(f"Error: {message}")],
            is_error=True,
            error_message=message,
        )

    async def _invoke_with_narrow_catch(self, invocation: MCPToolInvocation) -> MCPToolResult:
        """Invoke the handler with a narrow exception catch.

        Only known transport / tool exceptions are normalised to an error
        result. Programmer bugs (TypeError, ValueError from misuse, etc.)
        propagate so they fail loudly.

        ``asyncio.CancelledError`` is a ``BaseException``, not ``Exception``,
        so it is not caught here and propagates unchanged for workflow
        cancellation.
        """
        try:
            return await self._mcp_tool_handler.invoke_tool(invocation)
        except ToolExecutionException as exc:
            return self._error_result(str(exc) or type(exc).__name__)
        except httpx.HTTPError as exc:
            message = f"{type(exc).__name__}: {exc}" if str(exc) else type(exc).__name__
            return self._error_result(message)
        except Exception as exc:
            # ``McpError`` is imported lazily. If the import fails we must
            # still surface the ORIGINAL exception: a bare ``raise`` inside
            # an ``except ImportError`` block would re-raise the ImportError
            # and mask the real failure, so the fallback only records that
            # the type is unavailable.
            mcp_error_type: type[BaseException] | None
            try:
                from mcp.shared.exceptions import McpError

                mcp_error_type = McpError
            except ImportError:  # pragma: no cover - mcp is a hard dep
                mcp_error_type = None
            if mcp_error_type is not None and isinstance(exc, mcp_error_type):
                return self._error_result(str(exc) or type(exc).__name__)
            raise

    # ----- Result handling -----------------------------------------------------

    async def _process_result(
        self,
        *,
        ctx: WorkflowContext[ActionComplete, str],
        state: DeclarativeWorkflowState,
        result: MCPToolResult,
        auto_send: bool,
        conversation_id_expr: str | None,
        output_messages_path: str | None,
        output_result_path: str | None,
    ) -> None:
        """Apply ``result`` to workflow state per the configured output paths."""
        if result.is_error:
            # Error path mirrors .NET ``AssignErrorAsync`` — only the result
            # path is touched; messages / autoSend / conversation are not.
            self._assign_error(
                state,
                output_result_path,
                result.error_message or "MCP tool invocation failed.",
            )
            return

        parsed_results = _parse_outputs(result.outputs)
        if output_result_path is not None and parsed_results:
            state.set(output_result_path, parsed_results)

        # Single Tool-role message (matches the .NET contract). Differs
        # from InvokeFunctionTool's two-message [assistant call, tool result]
        # convention.
        tool_message = Message(role="tool", contents=list(result.outputs))
        if output_messages_path is not None:
            state.set(output_messages_path, tool_message)

        if auto_send and parsed_results:
            await ctx.yield_output(_format_outputs_for_send(parsed_results))

        # ``_get_messages_path`` returns ``None`` when no conversation id is
        # configured or it evaluates to nothing, so no extra guard is needed.
        messages_path = _get_messages_path(state, conversation_id_expr)
        if messages_path is not None:
            # Mirrors .NET: conversation gets ASSISTANT-role message with
            # the same outputs (so chat history reads it as the agent's
            # contribution).
            assistant_message = Message(role="assistant", contents=list(result.outputs))
            state.append(messages_path, assistant_message)

    @staticmethod
    def _assign_error(
        state: DeclarativeWorkflowState,
        output_result_path: str | None,
        error_message: str,
    ) -> None:
        """Mirror .NET ``AssignErrorAsync``: store ``"Error: <msg>"`` at the result path."""
        if output_result_path is None:
            return
        state.set(output_result_path, f"Error: {error_message}")

    def _approval_key(self) -> str:
        """Return the per-executor state key under which the approval snapshot is stored."""
        return f"{_MCP_APPROVAL_STATE_KEY}_{self.id}"
|
||||
|
||||
|
||||
def _parse_outputs(outputs: list[Content]) -> list[Any]:
|
||||
"""Parse :class:`Content` outputs into Python values for ``output.result``.
|
||||
|
||||
Mirrors .NET ``AssignResultAsync``:
|
||||
|
||||
- ``TextContent`` → JSON-parse text; on failure use the raw text.
|
||||
- ``DataContent`` / ``UriContent`` → ``content.uri``.
|
||||
- Other content kinds → ``str(content)``.
|
||||
"""
|
||||
parsed: list[Any] = []
|
||||
for content in outputs:
|
||||
kind = getattr(content, "type", None)
|
||||
if kind == "text":
|
||||
text_value = getattr(content, "text", None)
|
||||
text_str = "" if text_value is None else str(text_value)
|
||||
try:
|
||||
parsed.append(json.loads(text_str))
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
parsed.append(text_str)
|
||||
continue
|
||||
if kind in ("data", "uri"):
|
||||
uri_value = getattr(content, "uri", None)
|
||||
parsed.append("" if uri_value is None else str(uri_value))
|
||||
continue
|
||||
parsed.append(str(content))
|
||||
return parsed
|
||||
|
||||
|
||||
# Registry mapping a declarative action kind name to the executor class that handles it.
MCP_ACTION_EXECUTORS: dict[str, type[DeclarativeActionExecutor]] = {
|
||||
"InvokeMcpTool": InvokeMcpToolActionExecutor,
|
||||
}
|
||||
@@ -29,6 +29,7 @@ from .._loader import AgentFactory
|
||||
from ._declarative_builder import DeclarativeWorkflowBuilder
|
||||
from ._errors import DeclarativeWorkflowError
|
||||
from ._http_handler import HttpRequestHandler
|
||||
from ._mcp_handler import MCPToolHandler
|
||||
|
||||
logger = logging.getLogger("agent_framework.declarative")
|
||||
|
||||
@@ -91,6 +92,7 @@ class WorkflowFactory:
|
||||
checkpoint_storage: CheckpointStorage | None = None,
|
||||
max_iterations: int | None = None,
|
||||
http_request_handler: HttpRequestHandler | None = None,
|
||||
mcp_tool_handler: MCPToolHandler | None = None,
|
||||
) -> None:
|
||||
"""Initialize the workflow factory.
|
||||
|
||||
@@ -110,6 +112,13 @@ class WorkflowFactory:
|
||||
otherwise. Use :class:`agent_framework.declarative.DefaultHttpRequestHandler`
|
||||
for a no-policy ``httpx``-based default, or supply your own implementation
|
||||
to enforce SSRF guards, allowlisting, or auth resolution.
|
||||
mcp_tool_handler: Optional handler used to dispatch MCP tool calls for
|
||||
``InvokeMcpTool``. Required if the workflow contains any
|
||||
``InvokeMcpTool``; build will fail with :class:`DeclarativeWorkflowError`
|
||||
otherwise. Use :class:`agent_framework.declarative.DefaultMCPToolHandler`
|
||||
for a default backed by :class:`agent_framework.MCPStreamableHTTPTool`,
|
||||
or supply your own implementation to enforce SSRF guards, allowlisting,
|
||||
or auth/connection resolution.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
@@ -150,6 +159,7 @@ class WorkflowFactory:
|
||||
self._checkpoint_storage = checkpoint_storage
|
||||
self._max_iterations = max_iterations
|
||||
self._http_request_handler = http_request_handler
|
||||
self._mcp_tool_handler = mcp_tool_handler
|
||||
|
||||
def create_workflow_from_yaml_path(
|
||||
self,
|
||||
@@ -394,6 +404,7 @@ class WorkflowFactory:
|
||||
checkpoint_storage=self._checkpoint_storage,
|
||||
max_iterations=self._max_iterations,
|
||||
http_request_handler=self._http_request_handler,
|
||||
mcp_tool_handler=self._mcp_tool_handler,
|
||||
)
|
||||
workflow = graph_builder.build()
|
||||
except ValueError as e:
|
||||
|
||||
@@ -0,0 +1,494 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""MCP tool handler abstraction for declarative workflows.
|
||||
|
||||
Mirrors the .NET ``IMcpToolHandler`` / ``DefaultMcpToolHandler`` pair from
|
||||
``Microsoft.Agents.AI.Workflows.Declarative.Mcp``. Provides:
|
||||
|
||||
- :class:`MCPToolInvocation` — request input data passed from the executor.
|
||||
- :class:`MCPToolResult` — response data returned to the executor.
|
||||
- :class:`MCPToolHandler` — :class:`typing.Protocol` callers implement to plug
|
||||
in custom transports (e.g. with allowlisting, Foundry connection resolution,
|
||||
per-server auth, etc.).
|
||||
- :class:`DefaultMCPToolHandler` — production-grade default backed by
|
||||
:class:`agent_framework.MCPStreamableHTTPTool`.
|
||||
|
||||
Security note: :class:`DefaultMCPToolHandler` performs **no** URL filtering or
|
||||
SSRF protection. Production deployments should supply a custom handler that
|
||||
enforces an allowlist or DNS-rebinding-resistant policy. This split mirrors the
|
||||
.NET design.
|
||||
|
||||
Prompt-injection note: MCP tool outputs flow back into agent conversations
|
||||
(via ``conversationId`` and Tool-role messages emitted by the executor) so
|
||||
they share the same risk surface as ``HttpRequestAction``. Workflow authors
|
||||
must trust the MCP server they invoke.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Protocol, cast, runtime_checkable
|
||||
|
||||
import httpx
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agent_framework import Content
|
||||
|
||||
__all__ = [
|
||||
"ClientProvider",
|
||||
"DefaultMCPToolHandler",
|
||||
"MCPToolHandler",
|
||||
"MCPToolInvocation",
|
||||
"MCPToolResult",
|
||||
]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_CACHE_MAX_SIZE = 32
|
||||
|
||||
|
||||
@dataclass
class MCPToolInvocation:
    """A single MCP tool call, handed to an :class:`MCPToolHandler` for dispatch.

    The fields correspond to the input parameters of the .NET
    ``IMcpToolHandler.InvokeToolAsync`` method.

    Attributes:
        server_url: Absolute URL of the MCP server, already evaluated from the
            YAML expression.
        tool_name: Name of the tool to invoke on the MCP server.
        server_label: Optional human-readable label used for diagnostics and
            as the name of the underlying ``MCPStreamableHTTPTool``.
        arguments: Tool arguments, already evaluated. Values may be any
            JSON-serialisable Python object (str, int, bool, dict, list, None).
        headers: Outbound HTTP headers (e.g. authentication). The executor
            drops empty values before building the invocation.
        connection_name: Optional Foundry connection name, forwarded for
            handlers that resolve auth/credentials by connection. The default
            handler does not consume this field.
    """

    # Required identity of the call.
    server_url: str
    tool_name: str
    # Optional metadata and payload; mutable defaults are created per instance.
    server_label: str | None = None
    arguments: dict[str, Any] = field(default_factory=dict)  # type: ignore[reportUnknownVariableType]
    headers: dict[str, str] = field(default_factory=dict)  # type: ignore[reportUnknownVariableType]
    connection_name: str | None = None
|
||||
|
||||
|
||||
def _empty_outputs() -> list[Any]:
|
||||
"""Default factory for ``MCPToolResult.outputs``.
|
||||
|
||||
Typed as ``list[Any]`` here to keep the dataclass field's runtime
|
||||
factory simple; the public type on :class:`MCPToolResult` is
|
||||
``list[Content]``.
|
||||
"""
|
||||
return []
|
||||
|
||||
|
||||
@dataclass
class MCPToolResult:
    """Outcome of an MCP tool call, as returned by an :class:`MCPToolHandler`.

    Mirrors the shape of .NET ``McpServerToolResultContent``.

    Attributes:
        outputs: :class:`agent_framework.Content` items as parsed by the MCP
            transport (TextContent / DataContent / UriContent / etc.). On
            error this typically holds a single
            ``Content.from_text("Error: ...")`` entry for downstream display.
        is_error: ``True`` when the call failed.
        error_message: Human-readable description of the failure, set when
            ``is_error`` is ``True``.
    """

    # Each instance gets its own list via the factory.
    outputs: list[Content] = field(default_factory=_empty_outputs)
    is_error: bool = False
    error_message: str | None = None
|
||||
|
||||
|
||||
@runtime_checkable
class MCPToolHandler(Protocol):
    """Structural interface for the MCP transport behind ``InvokeMcpTool``.

    Like :class:`HttpRequestHandler`, the Protocol declares ONLY the
    invocation method. Lifecycle methods (``aclose`` / ``__aenter__`` /
    ``__aexit__``) are deliberately NOT part of it; concrete implementations
    may add them as appropriate.

    Implementations must be safe to call concurrently from multiple workflow
    runs, and they own every policy the workflow author wants applied: URL
    allowlisting, SSRF guards, retry policies, auth resolution, and so on.
    """

    async def invoke_tool(self, invocation: MCPToolInvocation) -> MCPToolResult:
        """Perform the MCP tool call described by ``invocation``.

        Args:
            invocation: Description of the MCP tool call to perform.

        Returns:
            An :class:`MCPToolResult` carrying the parsed outputs, or an error
            flag if the tool raised. For transport or tool-level failures,
            implementations SHOULD return a result with ``is_error=True``
            instead of raising, so the workflow can store the message in
            ``output.result`` (matching .NET ``AssignErrorAsync`` behaviour).

        Raises:
            Exception: Implementations MAY raise on unexpected programming
                errors; the executor propagates these unchanged so they fail
                loudly.
        """
        ...
|
||||
|
||||
|
||||
# Async callback type: receives the invocation and resolves to the ``httpx.AsyncClient`` to use, or ``None``.
ClientProvider = Callable[[MCPToolInvocation], Awaitable["httpx.AsyncClient | None"]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class _CacheEntry:
|
||||
"""Internal record stored in the LRU cache."""
|
||||
|
||||
tool: Any # MCPStreamableHTTPTool — typed Any to avoid import at module load
|
||||
owned_httpx_client: httpx.AsyncClient | None
|
||||
|
||||
|
||||
class DefaultMCPToolHandler:
    """Default :class:`MCPToolHandler` backed by :class:`agent_framework.MCPStreamableHTTPTool`.

    Caches one :class:`agent_framework.MCPStreamableHTTPTool` instance per
    ``(server_url, server_label, connection_name, headers_hash)`` in a
    bounded LRU. The cache prevents re-establishing an MCP session for every
    invocation while ensuring different header sets (auth tokens) cannot
    share a session — matches the .NET design intent while bounding
    cardinality. ``server_label`` and ``connection_name`` participate in
    the key so that callers using ``client_provider`` to dispatch on those
    fields receive a fresh client per logical connection (see below).
    Header *names* are lower-cased inside the hash payload only — the
    headers passed on the wire keep the caller's original casing — so two
    YAML actions that spell ``Authorization`` differently still share a
    cache entry.

    Construction modes:

    1. ``DefaultMCPToolHandler()`` — owns its own ``httpx.AsyncClient``
       instances created lazily per cache entry. Closed by :meth:`aclose`.
    2. ``DefaultMCPToolHandler(client_provider=cb)`` — per-server client
       lookup (parity with .NET ``httpClientProvider`` callback). The
       callback receives the full :class:`MCPToolInvocation` so it can
       dispatch on ``server_url`` / ``connection_name`` / ``server_label``.
       Returning ``None`` falls back to an internally-created client. Caller
       supplied clients are NOT closed by :meth:`aclose`.

    .. warning::

        This handler performs **no** URL filtering or SSRF protection. Wrap
        or replace it with a custom handler in production deployments.

    Args:
        client_provider: Optional per-server ``httpx.AsyncClient`` provider.
        cache_max_size: Maximum number of cached MCP clients. When exceeded,
            the least-recently-used entry is evicted and its client closed
            (only owned clients are closed; caller-supplied ones are not).
            Defaults to ``32``.

    Raises:
        ValueError: If ``cache_max_size`` is not positive.
    """

    def __init__(
        self,
        *,
        client_provider: ClientProvider | None = None,
        cache_max_size: int = _DEFAULT_CACHE_MAX_SIZE,
    ) -> None:
        if cache_max_size <= 0:
            raise ValueError(f"cache_max_size must be positive, got {cache_max_size}")
        self._client_provider = client_provider
        self._cache_max_size = cache_max_size
        self._cache: OrderedDict[tuple[str, str, str, str], _CacheEntry] = OrderedDict()
        # Outer lock guards the cache + in-flight-future map only — never
        # held across network I/O.
        self._cache_lock = asyncio.Lock()
        # Per-key in-flight futures: while one task is connecting, other
        # tasks awaiting the same key will await the same future and share
        # the resulting cache entry.
        self._inflight: dict[tuple[str, str, str, str], asyncio.Future[_CacheEntry]] = {}
        # Set by ``aclose`` to prevent post-close cache insertions and to
        # reject new ``invoke_tool`` calls. Once set, never cleared.
        self._closed = False

    async def invoke_tool(self, invocation: MCPToolInvocation) -> MCPToolResult:
        """Invoke ``invocation.tool_name`` on the cached MCP client for the server.

        Connection failures, ``ToolExecutionException``, ``httpx.HTTPError``
        and ``McpError`` are converted into an ``is_error=True`` result; any
        other exception raised by the tool call propagates unchanged.

        Args:
            invocation: Description of the MCP tool call to perform.

        Returns:
            The tool outputs, or an error result describing the failure.
        """
        from agent_framework import Content
        from agent_framework.exceptions import ToolExecutionException

        try:
            entry = await self._get_or_create_entry(invocation)
        except Exception as exc:
            # Connect / cache lookup failures surface as tool errors so the
            # workflow can store them at output.result without crashing.
            logger.warning(
                "DefaultMCPToolHandler: failed to obtain MCP client for url=%s tool=%s: %s",
                invocation.server_url,
                invocation.tool_name,
                exc,
            )
            # Omit the trailing ": " when the exception carries no text,
            # without trimming characters that belong to the message itself.
            detail = f"{type(exc).__name__}: {exc}" if str(exc) else type(exc).__name__
            return self._error_result(f"Failed to connect to MCP server: {detail}")

        try:
            raw = await entry.tool.call_tool(invocation.tool_name, **invocation.arguments)
        except ToolExecutionException as exc:
            logger.info(
                "DefaultMCPToolHandler: tool '%s' on '%s' raised ToolExecutionException",
                invocation.tool_name,
                invocation.server_url,
            )
            return self._error_result(str(exc) or type(exc).__name__)
        except httpx.HTTPError as exc:
            logger.info(
                "DefaultMCPToolHandler: tool '%s' on '%s' raised %s",
                invocation.tool_name,
                invocation.server_url,
                type(exc).__name__,
            )
            return self._error_result(f"{type(exc).__name__}: {exc}" if str(exc) else type(exc).__name__)
        except Exception as exc:
            # Be defensive about MCP errors that may bubble up without being
            # wrapped in ToolExecutionException by custom parsers.
            mcp_error_type: Any = None
            try:
                from mcp.shared.exceptions import McpError

                mcp_error_type = McpError
            except ImportError:  # pragma: no cover - mcp is a hard dep but stay defensive
                # Fall through to the bare ``raise`` below so the ORIGINAL
                # exception propagates, not the ImportError.
                pass
            if mcp_error_type is not None and isinstance(exc, mcp_error_type):
                return self._error_result(str(exc) or type(exc).__name__)
            raise

        # Defensive normalisation: call_tool is typed ``str | list[Content]``.
        # Default parser returns list, but custom parse_tool_results may return str.
        if isinstance(raw, str):
            outputs: list[Content] = [Content.from_text(raw)]
        else:
            outputs = list(raw)
        return MCPToolResult(outputs=outputs)

    async def aclose(self) -> None:
        """Close all cached MCP clients and the owned httpx clients.

        Caller-supplied :class:`httpx.AsyncClient` instances (returned by the
        ``client_provider`` callback) are NOT closed.

        Idempotent — a second call returns immediately. Drains any in-flight
        ``_create_entry`` tasks before returning so their resources are
        cleaned up; the in-flight tasks see ``self._closed`` in phase 3 of
        :meth:`_get_or_create_entry`, close their own entry, and resolve
        their future with ``RuntimeError("DefaultMCPToolHandler is closed")``.
        """
        async with self._cache_lock:
            if self._closed:
                return
            self._closed = True
            entries = list(self._cache.values())
            self._cache.clear()
            inflight_futures = list(self._inflight.values())

        # Wait for in-flight creations to finish their self-cleanup. Each
        # in-flight task self-closes its entry under the closed-flag branch
        # in phase 3 and resolves its future with ``RuntimeError``; we
        # swallow it here because the failure is expected at shutdown.
        for fut in inflight_futures:
            try:
                await fut
            except BaseException:
                logger.debug("DefaultMCPToolHandler: in-flight future raised during aclose", exc_info=True)

        for entry in entries:
            await self._close_entry(entry)

    async def __aenter__(self) -> DefaultMCPToolHandler:
        """Return the handler itself; no resources are acquired on entry."""
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        """Close the handler via :meth:`aclose` when the ``async with`` block exits."""
        await self.aclose()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _error_result(message: str) -> MCPToolResult:
        """Build the ``is_error=True`` result carrying ``message`` and an ``"Error: ..."`` text output."""
        from agent_framework import Content

        return MCPToolResult(
            outputs=[Content.from_text(f"Error: {message}")],
            is_error=True,
            error_message=message,
        )

    async def _get_or_create_entry(self, invocation: MCPToolInvocation) -> _CacheEntry:
        """Look up (or create) the cached MCP client for this invocation.

        Raises:
            RuntimeError: If the handler is closed, or if the task creating
                the shared entry was interrupted by a non-``Exception``
                error (e.g. cancellation) while this caller was waiting.
        """
        key = self._cache_key(
            invocation.server_url,
            invocation.server_label,
            invocation.connection_name,
            invocation.headers,
        )

        # Phase 1: check the cache and either claim creation or wait for an
        # already in-flight creation.
        creating = False
        async with self._cache_lock:
            if self._closed:
                raise RuntimeError("DefaultMCPToolHandler is closed")
            existing = self._cache.get(key)
            if existing is not None:
                self._cache.move_to_end(key)
                return existing
            inflight = self._inflight.get(key)
            if inflight is None:
                inflight = asyncio.get_running_loop().create_future()
                self._inflight[key] = inflight
                creating = True

        if not creating:
            return await inflight

        # Phase 2: we own creation. Build the entry outside the lock.
        try:
            entry = await self._create_entry(invocation)
        except BaseException as exc:
            async with self._cache_lock:
                self._inflight.pop(key, None)
            if not inflight.done():
                if isinstance(exc, Exception):
                    waiter_exc: Exception = exc
                else:
                    # The creator was cancelled / interrupted. Handing the
                    # same CancelledError to the waiters would spuriously
                    # cancel their tasks, so give them an ordinary error.
                    waiter_exc = RuntimeError(
                        f"MCP client creation was interrupted: {str(exc) or type(exc).__name__}"
                    )
                    waiter_exc.__cause__ = exc
                inflight.set_exception(waiter_exc)
                # Mark the exception retrieved to suppress noisy "Future exception
                # was never retrieved" warnings when there are no other awaiters
                # (other awaiters still see the exception through their ``await``).
                inflight.exception()
            raise

        # Phase 3: insert with LRU eviction; resolve the in-flight future.
        # If ``aclose`` ran while we were connecting, ``_closed`` is now
        # True; don't insert into the cache (it has been drained), close
        # the just-built entry, and surface the closed-handler error to
        # all awaiters of the future.
        evicted: _CacheEntry | None = None
        duplicate: _CacheEntry | None = None
        handler_closed = False
        async with self._cache_lock:
            self._inflight.pop(key, None)
            if self._closed:
                handler_closed = True
            else:
                existing = self._cache.get(key)
                if existing is not None:
                    # Another writer beat us; prefer the existing entry and
                    # discard ours after the lock is released.
                    self._cache.move_to_end(key)
                    duplicate = entry
                    entry = existing
                else:
                    self._cache[key] = entry
                    self._cache.move_to_end(key)
                    if len(self._cache) > self._cache_max_size:
                        _evicted_key, evicted = self._cache.popitem(last=False)
                if not inflight.done():
                    inflight.set_result(entry)

        if handler_closed:
            # Close our orphaned entry; resolve the future with a clear
            # error so the caller (and any other awaiters) surface a
            # consistent "handler is closed" failure rather than receiving
            # an entry we are about to close behind their back.
            await self._close_entry(entry)
            err = RuntimeError("DefaultMCPToolHandler is closed")
            if not inflight.done():
                inflight.set_exception(err)
                inflight.exception()
            raise err
        if duplicate is not None:
            await self._close_entry(duplicate)
        if evicted is not None:
            await self._close_entry(evicted)
        return entry

    async def _create_entry(self, invocation: MCPToolInvocation) -> _CacheEntry:
        """Construct (and connect) a fresh MCP client for ``invocation``.

        If connecting fails, the tool and any httpx client it allocated for
        itself are closed before the original error is re-raised.
        """
        from agent_framework import MCPStreamableHTTPTool

        provided_client: httpx.AsyncClient | None = None
        if self._client_provider is not None:
            provided_client = await self._client_provider(invocation)
        # Capture headers for this cache entry so the header_provider closure
        # always returns the same set, regardless of the runtime kwargs.
        captured_headers = dict(invocation.headers)

        def _header_provider(_kwargs: dict[str, Any]) -> dict[str, str]:
            return captured_headers

        tool: Any = MCPStreamableHTTPTool(
            name=invocation.server_label or "McpClient",
            url=invocation.server_url,
            load_prompts=False,
            http_client=provided_client,
            header_provider=_header_provider if captured_headers else None,
        )
        try:
            await tool.connect()
        except BaseException:
            try:
                await tool.close()
            except Exception:  # pragma: no cover - best effort
                logger.debug("DefaultMCPToolHandler: error closing tool after failed connect", exc_info=True)
            if provided_client is None:
                # The tool may already have allocated its own httpx client
                # before failing; close it here too (as ``_close_entry``
                # does) so a failed connect does not leak the connection
                # pool. Caller-supplied clients are never touched.
                leaked = cast("httpx.AsyncClient | None", getattr(tool, "_httpx_client", None))
                if leaked is not None:
                    try:
                        await leaked.aclose()
                    except Exception:  # pragma: no cover - best effort
                        logger.debug("DefaultMCPToolHandler: error closing owned httpx client", exc_info=True)
            raise

        # ``MCPStreamableHTTPTool.get_mcp_client`` lazily creates an
        # ``httpx.AsyncClient`` when no caller client was provided AND a
        # ``header_provider`` was set. We treat any client allocated this
        # way as owned (closed by the handler). When the caller supplies
        # one, we never close it.
        owned_client: httpx.AsyncClient | None = None
        if provided_client is None:
            owned_client = cast("httpx.AsyncClient | None", getattr(tool, "_httpx_client", None))
        return _CacheEntry(tool=tool, owned_httpx_client=owned_client)

    async def _close_entry(self, entry: _CacheEntry) -> None:
        """Close the MCP tool and any owned httpx client (best effort; errors are logged at debug level)."""
        try:
            await entry.tool.close()
        except Exception:  # pragma: no cover - best effort
            logger.debug("DefaultMCPToolHandler: error closing MCP tool", exc_info=True)
        if entry.owned_httpx_client is not None:
            try:
                await entry.owned_httpx_client.aclose()
            except Exception:  # pragma: no cover - best effort
                logger.debug("DefaultMCPToolHandler: error closing owned httpx client", exc_info=True)

    @staticmethod
    def _cache_key(
        server_url: str,
        server_label: str | None,
        connection_name: str | None,
        headers: dict[str, str] | None,
    ) -> tuple[str, str, str, str]:
        """Build an order-independent cache key for the invocation identity.

        The key includes ``server_label`` and ``connection_name`` so that
        callers using ``client_provider`` to dispatch on those fields
        receive a fresh client per logical connection (matches the
        documented dispatch contract).

        Header *names* are lower-cased inside the hash payload only so
        that ``Authorization`` and ``authorization`` map to the same
        cache entry. Header values remain case-sensitive (per RFC 7235).

        Returns:
            ``(server_url, server_label, connection_name, headers_hash)``
            with ``None`` labels replaced by ``""`` and ``"0"`` as the hash
            when there are no headers.
        """
        if not headers:
            headers_hash = "0"
        else:
            normalized = sorted((k.lower(), v) for k, v in headers.items())
            payload = json.dumps(normalized, ensure_ascii=False)
            headers_hash = hashlib.sha256(payload.encode("utf-8")).hexdigest()
        return (server_url, server_label or "", connection_name or "", headers_hash)
|
||||
@@ -0,0 +1,543 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for ``DefaultMCPToolHandler``.
|
||||
|
||||
These tests exercise the real handler against a fake ``MCPStreamableHTTPTool``
|
||||
(no real MCP server, no real network) to cover the parts of the handler not
|
||||
exercisable through the executor stub: cache hit/miss/eviction, concurrent
|
||||
connect via in-flight futures, header isolation across cache keys,
|
||||
string-result normalisation, ``load_prompts=False`` verification, and
|
||||
owned-vs-caller httpx close semantics.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from agent_framework import Content
|
||||
from agent_framework.exceptions import ToolExecutionException
|
||||
|
||||
from agent_framework_declarative._workflows._mcp_handler import (
|
||||
DefaultMCPToolHandler,
|
||||
MCPToolInvocation,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
sys.version_info >= (3, 14),
|
||||
reason="Skipped on Python 3.14+ to keep parity with rest of declarative suite",
|
||||
)
|
||||
|
||||
|
||||
class FakeTool:
    """Test double used in place of ``MCPStreamableHTTPTool``.

    Keeps the constructor kwargs, counts ``connect`` / ``close`` calls, and
    forwards ``call_tool`` to a handler that each test can replace. Every
    instance created is appended to the class-level ``instances`` list.
    """

    instances: list[FakeTool] = []

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs
        # Lifecycle counters inspected by the tests.
        self.connect_count = 0
        self.close_count = 0
        # Knobs a test may set before the handler calls ``connect``.
        self.connect_delay: float = 0.0
        self.connect_error: BaseException | None = None
        self.call_handler: Any = lambda **_ignored: [Content.from_text("ok")]
        # Filled in by ``connect`` when the fake allocates its own client.
        self._httpx_client: httpx.AsyncClient | None = None
        FakeTool.instances.append(self)

    async def connect(self) -> None:
        """Optionally sleep, then raise ``connect_error`` if set, otherwise count the connect."""
        if self.connect_delay:
            await asyncio.sleep(self.connect_delay)
        failure = self.connect_error
        if failure is not None:
            raise failure
        self.connect_count += 1
        # Mimic the lazy httpx allocation: only when no client was provided
        # AND a header_provider was set.
        caller_client = self.kwargs.get("http_client")
        if caller_client is None and self.kwargs.get("header_provider") is not None:
            self._httpx_client = httpx.AsyncClient()

    async def close(self) -> None:
        """Count the close call."""
        self.close_count += 1

    async def call_tool(self, tool_name: str, **arguments: Any) -> Any:
        """Forward the call to ``call_handler`` as keyword arguments and return its result."""
        handler = self.call_handler
        return handler(tool_name=tool_name, **arguments)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _clear_fake_instances() -> None:
    """Empty ``FakeTool.instances`` before each test (autouse)."""
    recorded = FakeTool.instances
    recorded.clear()
|
||||
|
||||
|
||||
def _patch_tool() -> Any:
    """Return a patcher that swaps ``agent_framework.MCPStreamableHTTPTool`` for ``FakeTool``.

    The handler imports the tool class lazily inside ``_create_entry``, so
    replacing the attribute on the package is enough for it to pick up the fake.
    """
    import agent_framework as framework_module

    patcher = patch.object(framework_module, "MCPStreamableHTTPTool", FakeTool)
    return patcher
|
||||
|
||||
|
||||
def _invocation(
    *, server_url: str = "https://mcp.example/api", tool_name: str = "search", **overrides: Any
) -> MCPToolInvocation:
    """Build an ``MCPToolInvocation`` with test defaults; extra keywords set the remaining fields."""
    fields: dict[str, Any] = {"server_url": server_url, "tool_name": tool_name}
    return MCPToolInvocation(**fields, **overrides)
|
||||
|
||||
|
||||
# ---------- Construction ---------------------------------------------------
|
||||
|
||||
|
||||
class TestConstruction:
    """Constructor argument validation."""

    def test_invalid_cache_size_raises(self) -> None:
        # Zero and negative sizes must both be rejected.
        for bad_size in (0, -3):
            with pytest.raises(ValueError):
                DefaultMCPToolHandler(cache_max_size=bad_size)
|
||||
|
||||
|
||||
# ---------- Tool kwargs ----------------------------------------------------
|
||||
|
||||
|
||||
class TestToolKwargs:
    """Checks the keyword arguments the handler passes when constructing the tool."""

    @pytest.mark.asyncio
    async def test_load_prompts_false_passed_to_tool(self) -> None:
        with _patch_tool():
            await DefaultMCPToolHandler().invoke_tool(_invocation())
        created = FakeTool.instances
        assert len(created) == 1
        assert created[0].kwargs["load_prompts"] is False

    @pytest.mark.asyncio
    async def test_server_label_used_as_tool_name(self) -> None:
        request = _invocation(server_label="MyMcp")
        with _patch_tool():
            await DefaultMCPToolHandler().invoke_tool(request)
        tool_kwargs = FakeTool.instances[0].kwargs
        assert tool_kwargs["name"] == "MyMcp"

    @pytest.mark.asyncio
    async def test_default_tool_name_when_no_label(self) -> None:
        request = _invocation(server_label=None)
        with _patch_tool():
            await DefaultMCPToolHandler().invoke_tool(request)
        tool_kwargs = FakeTool.instances[0].kwargs
        assert tool_kwargs["name"] == "McpClient"

    @pytest.mark.asyncio
    async def test_no_header_provider_when_no_headers(self) -> None:
        request = _invocation(headers={})
        with _patch_tool():
            await DefaultMCPToolHandler().invoke_tool(request)
        tool_kwargs = FakeTool.instances[0].kwargs
        assert tool_kwargs["header_provider"] is None

    @pytest.mark.asyncio
    async def test_header_provider_returns_captured_headers(self) -> None:
        expected = {"Authorization": "Bearer T"}
        request = _invocation(headers={"Authorization": "Bearer T"})
        with _patch_tool():
            await DefaultMCPToolHandler().invoke_tool(request)
        header_provider = FakeTool.instances[0].kwargs["header_provider"]
        assert header_provider({}) == expected
        # Even if runtime kwargs change, captured headers stay the same.
        assert header_provider({"foo": "bar"}) == expected
|
||||
|
||||
|
||||
# ---------- Cache behaviour ------------------------------------------------
|
||||
|
||||
|
||||
class TestCache:
    """Cache keying, LRU eviction and connect sharing of ``DefaultMCPToolHandler``."""

    @pytest.mark.asyncio
    async def test_same_url_and_headers_hit_cache(self) -> None:
        """Repeating an identical invocation creates and connects only one tool."""
        mcp_handler = DefaultMCPToolHandler()
        with _patch_tool():
            for _ in range(2):
                await mcp_handler.invoke_tool(_invocation(headers={"X": "1"}))
            # One tool created, connect called once.
            created_tools = FakeTool.instances
            assert len(created_tools) == 1
            assert created_tools[0].connect_count == 1

    @pytest.mark.asyncio
    async def test_different_headers_create_separate_entries(self) -> None:
        """Two different ``Authorization`` values produce two tools."""
        mcp_handler = DefaultMCPToolHandler()
        with _patch_tool():
            for token in ("tk-A", "tk-B"):
                await mcp_handler.invoke_tool(_invocation(headers={"Authorization": token}))
            assert len(FakeTool.instances) == 2

    @pytest.mark.asyncio
    async def test_different_urls_create_separate_entries(self) -> None:
        """Two different server URLs produce two tools."""
        mcp_handler = DefaultMCPToolHandler()
        with _patch_tool():
            for url in ("https://mcp.a/api", "https://mcp.b/api"):
                await mcp_handler.invoke_tool(_invocation(server_url=url))
            assert len(FakeTool.instances) == 2

    @pytest.mark.asyncio
    async def test_lru_eviction_closes_old_entry(self) -> None:
        """With ``cache_max_size=2`` a third distinct URL closes the first tool."""
        mcp_handler = DefaultMCPToolHandler(cache_max_size=2)
        with _patch_tool():
            # Inserting a third evicts the LRU entry (the first one).
            for url in ("https://a/", "https://b/", "https://c/"):
                await mcp_handler.invoke_tool(_invocation(server_url=url))
            assert len(FakeTool.instances) == 3
            evicted, *survivors = FakeTool.instances
            # First instance (https://a/) was evicted → close() called.
            assert evicted.kwargs["url"] == "https://a/"
            assert evicted.close_count == 1
            # Other two remain in cache → not closed.
            for tool in survivors:
                assert tool.close_count == 0

    @pytest.mark.asyncio
    async def test_repeated_use_keeps_lru_alive(self) -> None:
        """Re-using an entry before inserting a new one makes the *other* entry the eviction victim."""
        mcp_handler = DefaultMCPToolHandler(cache_max_size=2)
        # a, b fill the cache; touching a makes b the LRU entry; inserting c then evicts b.
        call_order = ("https://a/", "https://b/", "https://a/", "https://c/")
        with _patch_tool():
            for url in call_order:
                await mcp_handler.invoke_tool(_invocation(server_url=url))
            # b was evicted.
            tool_b = FakeTool.instances[1]
            assert tool_b.kwargs["url"] == "https://b/"
            assert tool_b.close_count == 1
            # a survived.
            tool_a = FakeTool.instances[0]
            assert tool_a.kwargs["url"] == "https://a/"
            assert tool_a.close_count == 0

    @pytest.mark.asyncio
    async def test_concurrent_connect_shares_one_entry(self) -> None:
        """Multiple concurrent invocations with the same key must share one tool."""
        mcp_handler = DefaultMCPToolHandler()

        # Slow down connect so concurrency window is observable.
        real_connect = FakeTool.connect

        async def slow_connect(self: FakeTool) -> None:
            self.connect_delay = 0.05
            await real_connect(self)

        with _patch_tool(), patch.object(FakeTool, "connect", slow_connect):
            pending = [mcp_handler.invoke_tool(_invocation(headers={"X": "1"})) for _ in range(4)]
            results = await asyncio.gather(*pending)
            assert all(not outcome.is_error for outcome in results)
            # Only one tool was created and connected, despite 4 concurrent calls.
            assert len(FakeTool.instances) == 1
            assert FakeTool.instances[0].connect_count == 1

    @pytest.mark.asyncio
    async def test_different_connection_names_create_separate_entries(self) -> None:
        """Same URL/headers but different ``connection_name`` must dispatch separately."""
        mcp_handler = DefaultMCPToolHandler()
        with _patch_tool():
            for connection_name in ("conn-A", "conn-B"):
                await mcp_handler.invoke_tool(_invocation(connection_name=connection_name))
            assert len(FakeTool.instances) == 2

    @pytest.mark.asyncio
    async def test_different_server_labels_create_separate_entries(self) -> None:
        """Same URL/headers but different ``server_label`` must dispatch separately."""
        mcp_handler = DefaultMCPToolHandler()
        with _patch_tool():
            for label in ("LabelA", "LabelB"):
                await mcp_handler.invoke_tool(_invocation(server_label=label))
            assert len(FakeTool.instances) == 2

    @pytest.mark.asyncio
    async def test_full_identity_match_hits_cache(self) -> None:
        """All four identity components match → single cached entry."""
        mcp_handler = DefaultMCPToolHandler()
        with _patch_tool():
            for _ in range(2):
                await mcp_handler.invoke_tool(
                    _invocation(server_label="Lbl", connection_name="C", headers={"X": "1"})
                )
            created_tools = FakeTool.instances
            assert len(created_tools) == 1
            assert created_tools[0].connect_count == 1

    @pytest.mark.asyncio
    async def test_header_name_case_collapses_to_one_cache_entry(self) -> None:
        """Header name spelling differences (case-only) must share a cache entry."""
        mcp_handler = DefaultMCPToolHandler()
        with _patch_tool():
            for header_name in ("Authorization", "authorization", "AUTHORIZATION"):
                await mcp_handler.invoke_tool(_invocation(headers={header_name: "tk"}))
            created_tools = FakeTool.instances
            assert len(created_tools) == 1
            assert created_tools[0].connect_count == 1

    @pytest.mark.asyncio
    async def test_header_value_case_does_not_collapse(self) -> None:
        """Header *values* remain case-sensitive (different tokens → different sessions)."""
        mcp_handler = DefaultMCPToolHandler()
        with _patch_tool():
            for token in ("Bearer-A", "bearer-a"):
                await mcp_handler.invoke_tool(_invocation(headers={"Authorization": token}))
            assert len(FakeTool.instances) == 2
|
||||
|
||||
|
||||
# ---------- Aclose semantics ----------------------------------------------
|
||||
|
||||
|
||||
class TestAclose:
    """Shutdown semantics of ``DefaultMCPToolHandler.aclose`` and its async context manager."""

    @pytest.mark.asyncio
    async def test_aclose_closes_owned_clients(self) -> None:
        """``aclose`` closes the cached tool and the httpx client attached to it."""
        mcp_handler = DefaultMCPToolHandler()
        with _patch_tool():
            await mcp_handler.invoke_tool(_invocation(headers={"X": "1"}))
            cached_tool = FakeTool.instances[0]
            owned_client = cached_tool._httpx_client
            assert owned_client is not None
            await mcp_handler.aclose()
            assert cached_tool.close_count == 1
            assert owned_client.is_closed

    @pytest.mark.asyncio
    async def test_aclose_does_not_close_caller_supplied_client(self) -> None:
        """A client obtained through ``client_provider`` stays open after ``aclose``."""
        caller_client = httpx.AsyncClient()

        async def supply_client(_inv: MCPToolInvocation) -> httpx.AsyncClient:
            return caller_client

        mcp_handler = DefaultMCPToolHandler(client_provider=supply_client)
        try:
            with _patch_tool():
                await mcp_handler.invoke_tool(_invocation(headers={"X": "1"}))
                await mcp_handler.aclose()
                assert FakeTool.instances[0].close_count == 1
                # Caller client must still be usable.
                assert not caller_client.is_closed
        finally:
            # The test created the client, so the test disposes of it.
            await caller_client.aclose()

    @pytest.mark.asyncio
    async def test_async_context_manager(self) -> None:
        """Leaving ``async with`` closes the tool exactly once."""
        with _patch_tool():
            async with DefaultMCPToolHandler() as mcp_handler:
                await mcp_handler.invoke_tool(_invocation())
                cached_tool = FakeTool.instances[0]
            assert cached_tool.close_count == 1

    @pytest.mark.asyncio
    async def test_aclose_is_idempotent(self) -> None:
        """A second ``aclose`` is a no-op (no exception, no double-close)."""
        mcp_handler = DefaultMCPToolHandler()
        with _patch_tool():
            await mcp_handler.invoke_tool(_invocation(headers={"X": "1"}))
            for _ in range(2):
                await mcp_handler.aclose()
            assert FakeTool.instances[0].close_count == 1

    @pytest.mark.asyncio
    async def test_invoke_after_close_returns_error_result(self) -> None:
        """Post-close ``invoke_tool`` surfaces a tool error rather than crashing."""
        mcp_handler = DefaultMCPToolHandler()
        with _patch_tool():
            await mcp_handler.aclose()
            outcome = await mcp_handler.invoke_tool(_invocation())
            assert outcome.is_error is True
            error_text = (outcome.error_message or "").lower()
            assert "closed" in error_text

    @pytest.mark.asyncio
    async def test_aclose_drains_inflight_creation(self) -> None:
        """An in-flight ``_create_entry`` must not leak when ``aclose`` races with it.

        Reproduces the race described in PR #5630 review-comment 3:
        task A claims an inflight future and starts a slow connect; task B
        runs ``aclose``; task A must self-clean (close its tool + httpx
        client) and surface a closed-handler error rather than orphaning
        the entry.
        """
        mcp_handler = DefaultMCPToolHandler()
        connect_started = asyncio.Event()
        release_connect = asyncio.Event()
        real_connect = FakeTool.connect

        async def gated_connect(self: FakeTool) -> None:
            # Announce that connect has begun, then block until the test releases it.
            connect_started.set()
            await release_connect.wait()
            await real_connect(self)

        with _patch_tool(), patch.object(FakeTool, "connect", gated_connect):
            invoke_task = asyncio.create_task(mcp_handler.invoke_tool(_invocation(headers={"X": "1"})))
            # Wait until task A is mid-connect.
            await connect_started.wait()
            # Race: kick off aclose. It must wait for the in-flight task.
            close_task = asyncio.create_task(mcp_handler.aclose())
            # Yield once to ensure aclose has set _closed and is awaiting.
            await asyncio.sleep(0)
            # Allow the connect to complete; phase 3 sees _closed and self-cleans.
            release_connect.set()
            outcome = await invoke_task
            await close_task

            # Entry was created and then closed by the in-flight task itself.
            created_tools = FakeTool.instances
            assert len(created_tools) == 1
            assert created_tools[0].close_count == 1
            # The originating invocation surfaces a closed-handler error.
            assert outcome.is_error is True
            error_text = (outcome.error_message or "").lower()
            assert "closed" in error_text
|
||||
|
||||
|
||||
# ---------- Result normalisation ------------------------------------------
|
||||
|
||||
|
||||
class TestResultNormalisation:
    """How ``invoke_tool`` shapes the raw value returned by the underlying tool call."""

    @pytest.mark.asyncio
    async def test_string_result_wrapped_in_text_content(self) -> None:
        """A bare string returned by the tool is surfaced as one text output."""
        handler = DefaultMCPToolHandler()
        with _patch_tool():
            inv = _invocation()
            # Warm-up call: it makes ``FakeTool.instances[0]`` exist so its
            # ``call_handler`` can be replaced. The result is not needed, so it is not bound.
            await handler.invoke_tool(inv)
            # The fake's default already returns a list; replace handler for this test.
            FakeTool.instances[0].call_handler = lambda **_a: "raw string body"
            result = await handler.invoke_tool(inv)
            assert result.is_error is False
            assert len(result.outputs) == 1
            assert result.outputs[0].text == "raw string body"  # type: ignore[reportAttributeAccessIssue]

    @pytest.mark.asyncio
    async def test_list_result_passed_through(self) -> None:
        """A list of two contents returned by the tool yields two outputs."""
        handler = DefaultMCPToolHandler()
        custom = [Content.from_text("a"), Content.from_text("b")]
        with _patch_tool():
            inv = _invocation()
            # Warm-up call so the fake tool instance exists before its handler is swapped.
            await handler.invoke_tool(inv)
            FakeTool.instances[0].call_handler = lambda **_a: custom
            result = await handler.invoke_tool(inv)
            assert result.is_error is False
            assert len(result.outputs) == 2
|
||||
|
||||
|
||||
# ---------- Error mapping --------------------------------------------------
|
||||
|
||||
|
||||
class TestErrorMapping:
    """Which exceptions ``invoke_tool`` converts to an error result and which it lets propagate."""

    @pytest.mark.asyncio
    async def test_tool_execution_exception_returns_error_result(self) -> None:
        """``ToolExecutionException`` from the tool call becomes an error result carrying its message."""
        mcp_handler = DefaultMCPToolHandler()

        def raise_tool_error(**_a: Any) -> Any:
            raise ToolExecutionException("server says no")

        with _patch_tool():
            invocation = _invocation()
            # First call creates the fake tool; only then can its call handler be replaced.
            await mcp_handler.invoke_tool(invocation)
            FakeTool.instances[0].call_handler = raise_tool_error
            outcome = await mcp_handler.invoke_tool(invocation)
            assert outcome.is_error is True
            assert outcome.error_message == "server says no"
            assert outcome.outputs[0].text.startswith("Error:")  # type: ignore[reportAttributeAccessIssue]

    @pytest.mark.asyncio
    async def test_httpx_error_returns_error_result(self) -> None:
        """``httpx.ConnectError`` from the tool call becomes an error result mentioning the cause."""
        mcp_handler = DefaultMCPToolHandler()

        def raise_connect_error(**_a: Any) -> Any:
            raise httpx.ConnectError("dns failure")

        with _patch_tool():
            invocation = _invocation()
            await mcp_handler.invoke_tool(invocation)
            FakeTool.instances[0].call_handler = raise_connect_error
            outcome = await mcp_handler.invoke_tool(invocation)
            assert outcome.is_error is True
            assert "dns failure" in (outcome.error_message or "")

    @pytest.mark.asyncio
    async def test_unexpected_exception_propagates(self) -> None:
        """RuntimeError (not in the narrow catch list) must propagate."""
        mcp_handler = DefaultMCPToolHandler()

        def raise_runtime_error(**_a: Any) -> Any:
            raise RuntimeError("programmer error")

        with _patch_tool():
            invocation = _invocation()
            await mcp_handler.invoke_tool(invocation)
            FakeTool.instances[0].call_handler = raise_runtime_error
            with pytest.raises(RuntimeError, match="programmer error"):
                await mcp_handler.invoke_tool(invocation)

    @pytest.mark.asyncio
    async def test_connect_failure_returns_error_result(self) -> None:
        """A failing ``connect`` yields an error result and leaves no in-flight or cached entry."""
        mcp_handler = DefaultMCPToolHandler()

        def failing_connect(self: Any) -> Any:
            # Plain (non-async) replacement: the error is raised as soon as ``connect`` is called.
            raise httpx.ConnectError("server down")

        with _patch_tool(), patch.object(FakeTool, "connect", failing_connect):
            outcome = await mcp_handler.invoke_tool(_invocation())
            assert outcome.is_error is True
            assert outcome.outputs[0].text.startswith("Error:")  # type: ignore[reportAttributeAccessIssue]
            # Failed connect must clear in-flight + cache entries.
            assert mcp_handler._inflight == {}
            assert len(mcp_handler._cache) == 0

    @pytest.mark.asyncio
    async def test_cancelled_error_propagates(self) -> None:
        """asyncio.CancelledError is BaseException, must NOT be swallowed."""
        mcp_handler = DefaultMCPToolHandler()

        def raise_cancelled(**_a: Any) -> Any:
            raise asyncio.CancelledError

        with _patch_tool():
            invocation = _invocation()
            await mcp_handler.invoke_tool(invocation)
            FakeTool.instances[0].call_handler = raise_cancelled
            with pytest.raises(asyncio.CancelledError):
                await mcp_handler.invoke_tool(invocation)
|
||||
|
||||
|
||||
# ---------- Cache key isolation -------------------------------------------
|
||||
|
||||
|
||||
class TestCacheKey:
    """Equality properties of ``DefaultMCPToolHandler._cache_key``.

    The helper is called positionally as ``(url, server_label, connection_name, headers)``.
    """

    def test_key_order_independent(self) -> None:
        """Header insertion order does not affect the key."""
        make_key = DefaultMCPToolHandler._cache_key
        forward = make_key("https://x/", None, None, {"A": "1", "B": "2"})
        reverse = make_key("https://x/", None, None, {"B": "2", "A": "1"})
        assert forward == reverse

    def test_key_distinguishes_values(self) -> None:
        """Different values for the same header give different keys."""
        make_key = DefaultMCPToolHandler._cache_key
        first = make_key("https://x/", None, None, {"A": "1"})
        second = make_key("https://x/", None, None, {"A": "2"})
        assert first != second

    def test_empty_headers_use_fixed_hash(self) -> None:
        """``None`` headers and an empty mapping give the same key."""
        make_key = DefaultMCPToolHandler._cache_key
        without_headers = make_key("https://x/", None, None, None)
        empty_headers = make_key("https://x/", None, None, {})
        assert without_headers == empty_headers

    def test_key_distinguishes_connection_name(self) -> None:
        """Different connection names give different keys."""
        make_key = DefaultMCPToolHandler._cache_key
        first = make_key("https://x/", None, "conn-A", None)
        second = make_key("https://x/", None, "conn-B", None)
        assert first != second

    def test_key_distinguishes_server_label(self) -> None:
        """Different server labels give different keys."""
        make_key = DefaultMCPToolHandler._cache_key
        first = make_key("https://x/", "Lbl-A", None, None)
        second = make_key("https://x/", "Lbl-B", None, None)
        assert first != second

    def test_key_collapses_header_name_case(self) -> None:
        """Header names that differ only by case give the same key."""
        make_key = DefaultMCPToolHandler._cache_key
        capitalised = make_key("https://x/", None, None, {"Authorization": "tk"})
        lowercase = make_key("https://x/", None, None, {"authorization": "tk"})
        assert capitalised == lowercase

    def test_key_keeps_header_value_case(self) -> None:
        """Header values that differ only by case give different keys."""
        make_key = DefaultMCPToolHandler._cache_key
        mixed_case = make_key("https://x/", None, None, {"X": "Bearer-A"})
        lower_case = make_key("https://x/", None, None, {"X": "bearer-a"})
        assert mixed_case != lower_case
|
||||
@@ -0,0 +1,664 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for ``InvokeMcpToolActionExecutor``.
|
||||
|
||||
Use a stub :class:`MCPToolHandler` that returns canned :class:`MCPToolResult`s.
|
||||
No real MCP server or network is exercised. See
|
||||
``test_default_mcp_tool_handler.py`` for tests that exercise the real
|
||||
``DefaultMCPToolHandler`` against a mocked ``MCPStreamableHTTPTool``.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
try:
|
||||
import powerfx # noqa: F401
|
||||
|
||||
_powerfx_available = True
|
||||
except (ImportError, RuntimeError):
|
||||
_powerfx_available = False
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not _powerfx_available or sys.version_info >= (3, 14),
|
||||
reason="PowerFx engine not available (requires dotnet runtime)",
|
||||
)
|
||||
|
||||
from agent_framework import Content, Message # noqa: E402
|
||||
from agent_framework.exceptions import ToolExecutionException # noqa: E402
|
||||
|
||||
from agent_framework_declarative._workflows import ( # noqa: E402
|
||||
DECLARATIVE_STATE_KEY,
|
||||
DeclarativeWorkflowError,
|
||||
MCPToolHandler,
|
||||
MCPToolInvocation,
|
||||
MCPToolResult,
|
||||
WorkflowFactory,
|
||||
)
|
||||
|
||||
|
||||
class StubMcpHandler:
    """Test stub recording the last call and returning a canned result.

    Attributes set by ``__init__``:
        result: value returned by ``invoke_tool`` when no exception is configured.
        raise_exc: exception raised by ``invoke_tool`` instead of returning, when not ``None``.
        last_invocation: most recent invocation received (``None`` before the first call).
        invocations: every invocation received, in call order.
        call_count: number of ``invoke_tool`` calls so far.
    """

    def __init__(
        self,
        result: MCPToolResult | None = None,
        *,
        raise_exc: BaseException | None = None,
    ) -> None:
        # Call bookkeeping starts empty.
        self.call_count = 0
        self.invocations: list[MCPToolInvocation] = []
        self.last_invocation: MCPToolInvocation | None = None
        # Canned behaviour.
        self.result = result
        self.raise_exc = raise_exc

    async def invoke_tool(self, invocation: MCPToolInvocation) -> MCPToolResult:
        """Record ``invocation``, then raise the configured exception or return the canned result."""
        # Recording happens before raising, so failed calls are counted as well.
        self.invocations.append(invocation)
        self.last_invocation = invocation
        self.call_count += 1

        configured_failure = self.raise_exc
        if configured_failure is not None:
            raise configured_failure

        canned = self.result
        assert canned is not None
        return canned
|
||||
|
||||
|
||||
def _ok(outputs: list[Content] | None = None) -> MCPToolResult:
    """Build a successful ``MCPToolResult``.

    Args:
        outputs: Contents to return. When omitted (``None``) a single text
            content ``"hello"`` is used. An explicitly passed empty list is
            kept as-is so a stub can model a tool that returns no content.

    Returns:
        An ``MCPToolResult`` constructed with only ``outputs`` set.
    """
    # ``is None`` rather than truthiness: ``[]`` is a meaningful value, not "use the default".
    if outputs is None:
        outputs = [Content.from_text("hello")]
    return MCPToolResult(outputs=outputs)
|
||||
|
||||
|
||||
def _err(message: str = "boom") -> MCPToolResult:
    """Build a failed ``MCPToolResult`` whose single text output is ``"Error: <message>"``."""
    error_output = Content.from_text(f"Error: {message}")
    return MCPToolResult(outputs=[error_output], is_error=True, error_message=message)
|
||||
|
||||
|
||||
def _action(
|
||||
*,
|
||||
server_url: str = "https://mcp.example/api",
|
||||
tool_name: str = "search",
|
||||
server_label: str | None = None,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
headers: dict[str, Any] | None = None,
|
||||
require_approval: Any = None,
|
||||
connection: dict[str, Any] | None = None,
|
||||
conversation_id: str | None = None,
|
||||
output: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
action: dict[str, Any] = {
|
||||
"kind": "InvokeMcpTool",
|
||||
"id": "mcp_action",
|
||||
"serverUrl": server_url,
|
||||
"toolName": tool_name,
|
||||
}
|
||||
if server_label is not None:
|
||||
action["serverLabel"] = server_label
|
||||
if arguments is not None:
|
||||
action["arguments"] = arguments
|
||||
if headers is not None:
|
||||
action["headers"] = headers
|
||||
if require_approval is not None:
|
||||
action["requireApproval"] = require_approval
|
||||
if connection is not None:
|
||||
action["connection"] = connection
|
||||
if conversation_id is not None:
|
||||
action["conversationId"] = conversation_id
|
||||
if output is not None:
|
||||
action["output"] = output
|
||||
return action
|
||||
|
||||
|
||||
def _yaml(action: dict[str, Any]) -> dict[str, Any]:
|
||||
return {"name": "mcp_test", "actions": [action]}
|
||||
|
||||
|
||||
# ---------- Builder enforcement --------------------------------------------
|
||||
|
||||
|
||||
class TestBuilderEnforcement:
    """Errors raised while building a workflow that contains an ``InvokeMcpTool`` action."""

    def test_missing_handler_raises_at_build_time(self) -> None:
        """Without ``mcp_tool_handler`` the build fails and the message names the action and the option."""
        factory = WorkflowFactory()
        with pytest.raises(DeclarativeWorkflowError) as excinfo:
            factory.create_workflow_from_definition(_yaml(_action()))
        message = str(excinfo.value)
        assert "InvokeMcpTool" in message
        assert "mcp_tool_handler" in message

    def test_missing_server_url_fails_validation(self) -> None:
        """Dropping ``serverUrl`` makes the build fail with a message that mentions it."""
        factory = WorkflowFactory(mcp_tool_handler=StubMcpHandler(_ok()))
        incomplete_action = _action()
        incomplete_action.pop("serverUrl")
        with pytest.raises(Exception) as excinfo:
            factory.create_workflow_from_definition(_yaml(incomplete_action))
        assert "serverUrl" in str(excinfo.value)

    def test_missing_tool_name_fails_validation(self) -> None:
        """Dropping ``toolName`` makes the build fail with a message that mentions it."""
        factory = WorkflowFactory(mcp_tool_handler=StubMcpHandler(_ok()))
        incomplete_action = _action()
        incomplete_action.pop("toolName")
        with pytest.raises(Exception) as excinfo:
            factory.create_workflow_from_definition(_yaml(incomplete_action))
        assert "toolName" in str(excinfo.value)
|
||||
|
||||
|
||||
# ---------- Field forwarding ----------------------------------------------
|
||||
|
||||
|
||||
class TestFieldForwarding:
    """Fields of the action definition as they arrive in the ``MCPToolInvocation``."""

    @pytest.mark.asyncio
    async def test_basic_invocation_forwards_required_fields(self) -> None:
        """A minimal action forwards URL and tool name; optional fields arrive empty or ``None``."""
        stub = StubMcpHandler(_ok())
        definition = _yaml(_action())
        workflow = WorkflowFactory(mcp_tool_handler=stub).create_workflow_from_definition(definition)
        await workflow.run({})
        assert stub.call_count == 1
        sent = stub.last_invocation
        assert sent is not None
        assert sent.server_url == "https://mcp.example/api"
        assert sent.tool_name == "search"
        assert sent.server_label is None
        assert sent.headers == {}
        assert sent.arguments == {}
        assert sent.connection_name is None

    @pytest.mark.asyncio
    async def test_arguments_evaluated_and_preserves_none(self) -> None:
        """String, int, bool and ``None`` argument values all reach the handler unchanged."""
        stub = StubMcpHandler(_ok())
        declared_arguments = {
            "query": "weather today",
            "limit": 5,
            "fresh": True,
            "missing": None,
        }
        definition = _yaml(_action(arguments=declared_arguments))
        workflow = WorkflowFactory(mcp_tool_handler=stub).create_workflow_from_definition(definition)
        await workflow.run({})
        sent = stub.last_invocation
        assert sent is not None
        # ``None`` is preserved (parity with .NET) — caller decides.
        assert sent.arguments == {
            "query": "weather today",
            "limit": 5,
            "fresh": True,
            "missing": None,
        }

    @pytest.mark.asyncio
    async def test_headers_drop_empty_values(self) -> None:
        """A header whose value is the empty string is not forwarded."""
        stub = StubMcpHandler(_ok())
        declared_headers = {
            "Authorization": "Bearer token-123",
            "X-Trace": "trace-id",
            "X-Empty": "",
        }
        definition = _yaml(_action(headers=declared_headers))
        workflow = WorkflowFactory(mcp_tool_handler=stub).create_workflow_from_definition(definition)
        await workflow.run({})
        sent = stub.last_invocation
        assert sent is not None
        assert sent.headers == {
            "Authorization": "Bearer token-123",
            "X-Trace": "trace-id",
        }

    @pytest.mark.asyncio
    async def test_server_label_and_connection_name_forwarded(self) -> None:
        """``serverLabel`` and ``connection.name`` arrive as ``server_label`` and ``connection_name``."""
        stub = StubMcpHandler(_ok())
        definition = _yaml(_action(server_label="docs-mcp", connection={"name": "azure-conn"}))
        workflow = WorkflowFactory(mcp_tool_handler=stub).create_workflow_from_definition(definition)
        await workflow.run({})
        sent = stub.last_invocation
        assert sent is not None
        assert sent.server_label == "docs-mcp"
        assert sent.connection_name == "azure-conn"
|
||||
|
||||
|
||||
# ---------- Output handling ------------------------------------------------
|
||||
|
||||
|
||||
class TestOutput:
    """Values written to the declarative state through the action's ``output`` bindings."""

    @pytest.mark.asyncio
    async def test_output_result_parses_json_text(self) -> None:
        """Text output that is valid JSON is stored as the parsed object, wrapped in a list."""
        stub = StubMcpHandler(_ok([Content.from_text('{"k":"v","n":1}')]))
        definition = _yaml(_action(output={"result": "Local.Result"}))
        workflow = WorkflowFactory(mcp_tool_handler=stub).create_workflow_from_definition(definition)
        await workflow.run({})
        declarative_state = workflow._state.get(DECLARATIVE_STATE_KEY)
        assert declarative_state["Local"]["Result"] == [{"k": "v", "n": 1}]

    @pytest.mark.asyncio
    async def test_output_result_falls_back_to_raw_text(self) -> None:
        """Text output that is not JSON is stored verbatim, wrapped in a list."""
        stub = StubMcpHandler(_ok([Content.from_text("plain text not json")]))
        definition = _yaml(_action(output={"result": "Local.Result"}))
        workflow = WorkflowFactory(mcp_tool_handler=stub).create_workflow_from_definition(definition)
        await workflow.run({})
        declarative_state = workflow._state.get(DECLARATIVE_STATE_KEY)
        assert declarative_state["Local"]["Result"] == ["plain text not json"]

    @pytest.mark.asyncio
    async def test_output_messages_writes_single_tool_role_message(self) -> None:
        """``output.messages`` receives one tool-role ``Message`` holding every content."""
        stub = StubMcpHandler(_ok([Content.from_text("hi"), Content.from_text("there")]))
        definition = _yaml(_action(output={"messages": "Local.Messages"}))
        workflow = WorkflowFactory(mcp_tool_handler=stub).create_workflow_from_definition(definition)
        await workflow.run({})
        declarative_state = workflow._state.get(DECLARATIVE_STATE_KEY)
        stored_message = declarative_state["Local"]["Messages"]
        # Single Tool-role message containing both contents (parity with .NET).
        assert isinstance(stored_message, Message)
        assert str(stored_message.role).lower() == "tool"
        assert len(stored_message.contents) == 2

    @pytest.mark.asyncio
    async def test_uri_content_serialised_as_uri_string(self) -> None:
        """URI content is stored as its URI string."""
        uri_content = Content.from_uri("https://example.com/file.txt", media_type="text/plain")
        stub = StubMcpHandler(_ok([uri_content]))
        definition = _yaml(_action(output={"result": "Local.Result"}))
        workflow = WorkflowFactory(mcp_tool_handler=stub).create_workflow_from_definition(definition)
        await workflow.run({})
        declarative_state = workflow._state.get(DECLARATIVE_STATE_KEY)
        assert declarative_state["Local"]["Result"] == ["https://example.com/file.txt"]

    @pytest.mark.asyncio
    async def test_output_path_object_form(self) -> None:
        """The ``{"path": ...}`` object form of an output binding is accepted."""
        stub = StubMcpHandler(_ok([Content.from_text("ok")]))
        definition = _yaml(_action(output={"result": {"path": "Local.Result"}}))
        workflow = WorkflowFactory(mcp_tool_handler=stub).create_workflow_from_definition(definition)
        await workflow.run({})
        declarative_state = workflow._state.get(DECLARATIVE_STATE_KEY)
        assert declarative_state["Local"]["Result"] == ["ok"]
|
||||
|
||||
|
||||
# ---------- Conversation append --------------------------------------------
|
||||
|
||||
|
||||
class TestConversation:
    """Effect of the action's ``conversationId`` on ``System.conversations``."""

    @pytest.mark.asyncio
    async def test_conversation_id_appends_assistant_message(self) -> None:
        """With a conversation id, one assistant message is appended to that conversation."""
        stub = StubMcpHandler(_ok([Content.from_text("answer")]))
        definition = _yaml(_action(conversation_id="conv-42", output={"result": "Local.Result"}))
        workflow = WorkflowFactory(mcp_tool_handler=stub).create_workflow_from_definition(definition)
        await workflow.run({})
        declarative_state = workflow._state.get(DECLARATIVE_STATE_KEY)
        conversation = declarative_state["System"]["conversations"]["conv-42"]
        # The stored conversation may be a plain dict or an object exposing ``messages``.
        if isinstance(conversation, dict):
            messages = conversation["messages"]
        else:
            messages = conversation.messages
        assert len(messages) == 1
        appended = messages[0]
        assert str(appended.role).lower() == "assistant"
        # Same contents as the tool output.
        assert len(appended.contents) == 1

    @pytest.mark.asyncio
    async def test_empty_conversation_id_does_not_append(self) -> None:
        """An empty conversation id leaves no ``""`` entry behind."""
        stub = StubMcpHandler(_ok([Content.from_text("answer")]))
        definition = _yaml(_action(conversation_id="", output={"result": "Local.Result"}))
        workflow = WorkflowFactory(mcp_tool_handler=stub).create_workflow_from_definition(definition)
        await workflow.run({})
        declarative_state = workflow._state.get(DECLARATIVE_STATE_KEY)
        # Empty conversation id must not produce a `""` entry under System.conversations.
        system_scope = declarative_state.get("System", {})
        conversations = system_scope.get("conversations", {})
        assert "" not in conversations
|
||||
|
||||
|
||||
# ---------- Approval flow --------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
def mock_state():  # type: ignore[no-untyped-def]
    """Provide a ``MagicMock`` state whose ``get``/``set``/``delete`` operate on the dict ``state._data``."""
    from unittest.mock import MagicMock

    state = MagicMock()
    state._data = {}

    def _get(key: str, default: Any = None) -> Any:
        # A stored value always wins; a missing key falls back only to a non-None default.
        if key in state._data:
            return state._data[key]
        if default is None:
            raise KeyError(key)
        return default

    def _set(key: str, value: Any) -> None:
        state._data[key] = value

    def _delete(key: str) -> None:
        # Deleting a key that was never stored is an error.
        if key not in state._data:
            raise KeyError(key)
        del state._data[key]

    # Wrapping in MagicMock keeps the calls inspectable while delegating to the dict-backed functions.
    state.get = MagicMock(side_effect=_get)
    state.set = MagicMock(side_effect=_set)
    state.delete = MagicMock(side_effect=_delete)
    return state
|
||||
|
||||
|
||||
@pytest.fixture
def mock_context(mock_state):  # type: ignore[no-untyped-def]
    """Provide a ``MagicMock`` context bound to ``mock_state`` with async mocks for its three calls."""
    from unittest.mock import AsyncMock, MagicMock

    context = MagicMock()
    context.state = mock_state
    # Each of these is awaited by the code under test, so it needs an AsyncMock of its own.
    for coroutine_name in ("send_message", "yield_output", "request_info"):
        setattr(context, coroutine_name, AsyncMock())
    return context
|
||||
|
||||
|
||||
def _seed_state(mock_state) -> None:  # type: ignore[no-untyped-def]
    """Pre-seed the declarative state container as the executors expect."""
    from agent_framework_declarative._workflows import DECLARATIVE_STATE_KEY

    system_scope = {
        "ConversationId": "00000000-0000-0000-0000-000000000000",
        "LastMessage": {"Id": "", "Text": ""},
        "LastMessageText": "",
        "LastMessageId": "",
    }
    conversation_scope = {"messages": [], "history": []}
    # Written straight into the backing dict, bypassing the mocked ``set``.
    mock_state._data[DECLARATIVE_STATE_KEY] = {
        "Local": {},
        "Custom": {},
        "Workflow": {},
        "System": system_scope,
        "Agent": {},
        "Conversation": conversation_scope,
        "Inputs": {},
    }
|
||||
|
||||
|
||||
class TestApprovalFlow:
|
||||
@pytest.mark.asyncio
async def test_approval_required_emits_request_and_yields(self, mock_state, mock_context) -> None:  # type: ignore[no-untyped-def]
    """With ``requireApproval`` set, the executor asks for approval and does not call the handler."""
    from agent_framework_declarative._workflows._declarative_base import ActionTrigger
    from agent_framework_declarative._workflows._executors_mcp import (
        _MCP_APPROVAL_STATE_KEY,
        InvokeMcpToolActionExecutor,
        MCPToolApprovalRequest,
    )

    _seed_state(mock_state)
    stub = StubMcpHandler(_ok())
    action_definition = _action(
        require_approval=True,
        arguments={"q": "x"},
        headers={"Authorization": "Bearer SECRET"},
        output={"result": "Local.Result"},
    )
    executor = InvokeMcpToolActionExecutor(action_definition, mcp_tool_handler=stub)
    await executor.handle_action(ActionTrigger(), mock_context)

    # Approval request emitted.
    mock_context.request_info.assert_called_once()
    approval_request = mock_context.request_info.call_args.args[0]
    assert isinstance(approval_request, MCPToolApprovalRequest)
    assert approval_request.tool_name == "search"
    assert approval_request.arguments == {"q": "x"}
    assert approval_request.header_names == ["Authorization"]

    # NEVER expose the actual auth token in any field of the approval payload.
    for field_value in vars(approval_request).values():
        assert "SECRET" not in str(field_value)

    # Workflow should yield (no ActionComplete sent yet).
    mock_context.send_message.assert_not_called()

    # Handler not invoked yet.
    assert stub.call_count == 0

    # Approval state stored.
    approval_key = f"{_MCP_APPROVAL_STATE_KEY}_mcp_action"
    assert approval_key in mock_state._data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approval_response_approved_invokes_handler(self, mock_state, mock_context) -> None: # type: ignore[no-untyped-def]
|
||||
from agent_framework_declarative._workflows import ActionComplete, ToolApprovalResponse
|
||||
from agent_framework_declarative._workflows._executors_mcp import (
|
||||
_MCP_APPROVAL_STATE_KEY,
|
||||
InvokeMcpToolActionExecutor,
|
||||
MCPToolApprovalRequest,
|
||||
_MCPToolApprovalState,
|
||||
)
|
||||
|
||||
_seed_state(mock_state)
|
||||
handler = StubMcpHandler(_ok([Content.from_text('{"ok":true}')]))
|
||||
executor = InvokeMcpToolActionExecutor(
|
||||
_action(
|
||||
require_approval=True,
|
||||
output={"result": "Local.Result"},
|
||||
),
|
||||
mcp_tool_handler=handler,
|
||||
)
|
||||
# Pre-populate approval state.
|
||||
approval_key = f"{_MCP_APPROVAL_STATE_KEY}_mcp_action"
|
||||
mock_state._data[approval_key] = _MCPToolApprovalState(
|
||||
server_url="https://mcp.example/api",
|
||||
tool_name="search",
|
||||
server_label=None,
|
||||
arguments={"q": "x"},
|
||||
connection_name=None,
|
||||
headers_def={"Authorization": "Bearer tk"},
|
||||
auto_send=False,
|
||||
conversation_id_expr=None,
|
||||
output_messages_path=None,
|
||||
output_result_path="Local.Result",
|
||||
)
|
||||
await executor.handle_approval_response(
|
||||
MCPToolApprovalRequest(
|
||||
request_id="req-1",
|
||||
tool_name="search",
|
||||
server_url="https://mcp.example/api",
|
||||
server_label=None,
|
||||
arguments={"q": "x"},
|
||||
),
|
||||
ToolApprovalResponse(approved=True),
|
||||
mock_context,
|
||||
)
|
||||
|
||||
assert handler.call_count == 1
|
||||
inv = handler.last_invocation
|
||||
assert inv is not None
|
||||
# Headers are re-evaluated from headers_def.
|
||||
assert inv.headers == {"Authorization": "Bearer tk"}
|
||||
# Approval state was cleaned up.
|
||||
assert approval_key not in mock_state._data
|
||||
# ActionComplete was sent.
|
||||
mock_context.send_message.assert_called_once()
|
||||
sent = mock_context.send_message.call_args[0][0]
|
||||
assert isinstance(sent, ActionComplete)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approval_response_rejected_assigns_error(self, mock_state, mock_context) -> None: # type: ignore[no-untyped-def]
|
||||
from agent_framework_declarative._workflows import ToolApprovalResponse
|
||||
from agent_framework_declarative._workflows._executors_mcp import (
|
||||
_MCP_APPROVAL_STATE_KEY,
|
||||
InvokeMcpToolActionExecutor,
|
||||
MCPToolApprovalRequest,
|
||||
_MCPToolApprovalState,
|
||||
)
|
||||
|
||||
_seed_state(mock_state)
|
||||
handler = StubMcpHandler(_ok())
|
||||
executor = InvokeMcpToolActionExecutor(
|
||||
_action(
|
||||
require_approval=True,
|
||||
output={"result": "Local.Result"},
|
||||
),
|
||||
mcp_tool_handler=handler,
|
||||
)
|
||||
approval_key = f"{_MCP_APPROVAL_STATE_KEY}_mcp_action"
|
||||
mock_state._data[approval_key] = _MCPToolApprovalState(
|
||||
server_url="https://mcp.example/api",
|
||||
tool_name="search",
|
||||
server_label=None,
|
||||
arguments={},
|
||||
connection_name=None,
|
||||
headers_def=None,
|
||||
auto_send=True,
|
||||
conversation_id_expr=None,
|
||||
output_messages_path=None,
|
||||
output_result_path="Local.Result",
|
||||
)
|
||||
await executor.handle_approval_response(
|
||||
MCPToolApprovalRequest(
|
||||
request_id="req-2",
|
||||
tool_name="search",
|
||||
server_url="https://mcp.example/api",
|
||||
server_label=None,
|
||||
arguments={},
|
||||
),
|
||||
ToolApprovalResponse(approved=False, reason="not authorized"),
|
||||
mock_context,
|
||||
)
|
||||
|
||||
assert handler.call_count == 0
|
||||
# Error string assigned at output.result.
|
||||
from agent_framework_declarative._workflows import DECLARATIVE_STATE_KEY
|
||||
|
||||
result = mock_state._data[DECLARATIVE_STATE_KEY]["Local"]["Result"]
|
||||
assert result == "Error: MCP tool invocation was not approved by user."
|
||||
|
||||
|
||||
# ---------- Error handling -------------------------------------------------
|
||||
|
||||
|
||||
class TestErrorHandling:
    """Check how handler failures surface in workflow state or to the caller."""

    @pytest.mark.asyncio
    async def test_handler_returns_error_result_assigns_error_string(self) -> None:
        """An error result from the handler is stored as ``"Error: <message>"``."""
        stub = StubMcpHandler(_err("server down"))
        workflow = WorkflowFactory(mcp_tool_handler=stub).create_workflow_from_definition(
            _yaml(_action(output={"result": "Local.Result"}))
        )

        await workflow.run({})

        declarative_state = workflow._state.get(DECLARATIVE_STATE_KEY)
        assert declarative_state["Local"]["Result"] == "Error: server down"

    @pytest.mark.asyncio
    async def test_tool_execution_exception_becomes_error_result(self) -> None:
        """A ``ToolExecutionException`` raised by the handler is stored as an error string."""
        stub = StubMcpHandler(raise_exc=ToolExecutionException("invalid arguments"))
        workflow = WorkflowFactory(mcp_tool_handler=stub).create_workflow_from_definition(
            _yaml(_action(output={"result": "Local.Result"}))
        )

        await workflow.run({})

        declarative_state = workflow._state.get(DECLARATIVE_STATE_KEY)
        assert declarative_state["Local"]["Result"] == "Error: invalid arguments"

    @pytest.mark.asyncio
    async def test_httpx_error_becomes_error_result(self) -> None:
        """An ``httpx`` transport error is stored as an error string naming the exception type."""
        stub = StubMcpHandler(raise_exc=httpx.ConnectError("dns fail"))
        workflow = WorkflowFactory(mcp_tool_handler=stub).create_workflow_from_definition(
            _yaml(_action(output={"result": "Local.Result"}))
        )

        await workflow.run({})

        declarative_state = workflow._state.get(DECLARATIVE_STATE_KEY)
        stored = declarative_state["Local"]["Result"]
        assert isinstance(stored, str)
        assert stored.startswith("Error:")
        assert "ConnectError" in stored

    @pytest.mark.asyncio
    async def test_unexpected_exception_propagates(self) -> None:
        """Programmer bugs (TypeError etc.) must NOT be swallowed."""
        stub = StubMcpHandler(raise_exc=TypeError("bad type"))
        workflow = WorkflowFactory(mcp_tool_handler=stub).create_workflow_from_definition(_yaml(_action()))

        # Deliberately broad: the TypeError may arrive as-is or wrapped by the
        # runner, so only the surfaced message is pinned below.
        with pytest.raises(Exception) as raised:
            await workflow.run({})

        assert "bad type" in str(raised.value)
|
||||
|
||||
|
||||
# ---------- autoSend -------------------------------------------------------
|
||||
|
||||
|
||||
class TestAutoSend:
    """Check that the ``autoSend`` output option controls whether the workflow yields output."""

    @pytest.mark.asyncio
    async def test_auto_send_default_true_yields_output(self) -> None:
        """Without an explicit ``autoSend`` setting, one output is produced."""
        stub = StubMcpHandler(_ok([Content.from_text("hello")]))
        workflow = WorkflowFactory(mcp_tool_handler=stub).create_workflow_from_definition(_yaml(_action()))

        run_result = await workflow.run({})

        yielded = run_result.get_outputs()
        assert len(yielded) == 1

    @pytest.mark.asyncio
    async def test_auto_send_false_suppresses_yield(self) -> None:
        """With ``autoSend: False`` the run produces no outputs."""
        stub = StubMcpHandler(_ok([Content.from_text("hello")]))
        workflow = WorkflowFactory(mcp_tool_handler=stub).create_workflow_from_definition(
            _yaml(_action(output={"autoSend": False}))
        )

        run_result = await workflow.run({})

        yielded = run_result.get_outputs()
        assert yielded == []
|
||||
|
||||
|
||||
# ---------- Protocol structure --------------------------------------------
|
||||
|
||||
|
||||
class TestProtocol:
    """Structural check of the test stub against the ``MCPToolHandler`` protocol."""

    def test_stub_handler_satisfies_protocol(self) -> None:
        """``StubMcpHandler`` instances pass an ``isinstance`` check against ``MCPToolHandler``."""
        stub = StubMcpHandler(_ok())
        is_handler = isinstance(stub, MCPToolHandler)
        assert is_handler
|
||||
|
||||
|
||||
# ---------- _format_outputs_for_send --------------------------------------
|
||||
|
||||
|
||||
class TestFormatOutputsForSend:
    """Table-driven tests for the ``_format_outputs_for_send`` rendering helper.

    Guards the regression from PR #5630 review-comment 4: a lone scalar JSON
    value has to be rendered bare (e.g. ``"42"``), not wrapped in a list
    (``"[42]"``).
    """

    @pytest.mark.parametrize(
        ("parsed", "expected"),
        [
            # Empty input and plain strings.
            ([], ""),
            (["hello"], "hello"),
            (["a", "b"], "a\nb"),
            # A single scalar is rendered bare.
            ([42], "42"),
            ([3.14], "3.14"),
            ([True], "true"),
            ([False], "false"),
            ([None], "null"),
            # A single container is rendered as its own JSON text.
            ([{"k": "v"}], '{"k": "v"}'),
            ([[1, 2]], "[1, 2]"),
            # Mixed or multiple non-string items are rendered as a JSON array.
            (["hello", 42], '["hello", 42]'),
            ([{"a": 1}, {"b": 2}], '[{"a": 1}, {"b": 2}]'),
        ],
    )
    def test_format_outputs_for_send(self, parsed: list[Any], expected: str) -> None:
        """The helper renders ``parsed`` to exactly the ``expected`` text."""
        from agent_framework_declarative._workflows._executors_mcp import _format_outputs_for_send

        rendered = _format_outputs_for_send(parsed)
        assert rendered == expected
|
||||
Reference in New Issue
Block a user