diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/__init__.py b/python/packages/declarative/agent_framework_declarative/_workflows/__init__.py index 579cabfafa..31ee990cb6 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/__init__.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/__init__.py @@ -76,12 +76,10 @@ from ._executors_mcp import ( from ._executors_tools import ( FUNCTION_TOOL_REGISTRY_KEY, TOOL_ACTION_EXECUTORS, - TOOL_APPROVAL_STATE_KEY, BaseToolExecutor, InvokeFunctionToolExecutor, ToolApprovalRequest, ToolApprovalResponse, - ToolApprovalState, ToolInvocationResult, ) from ._factory import WorkflowFactory @@ -111,7 +109,6 @@ __all__ = [ "HTTP_ACTION_EXECUTORS", "MCP_ACTION_EXECUTORS", "TOOL_ACTION_EXECUTORS", - "TOOL_APPROVAL_STATE_KEY", "TOOL_REGISTRY_KEY", "ActionComplete", "ActionTrigger", @@ -164,7 +161,6 @@ __all__ = [ "SetVariableExecutor", "ToolApprovalRequest", "ToolApprovalResponse", - "ToolApprovalState", "ToolInvocationResult", "WorkflowFactory", "WorkflowState", diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py index 73b66341ea..9a7d918704 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py @@ -15,12 +15,11 @@ Security notes: matches the security posture of :mod:`._executors_http` (which never logs request headers either) and prevents secrets from leaking through workflow events that are typically observable to operators / UIs. -- ``_MCPToolApprovalState`` snapshots the EVALUATED values for non-secret - fields (server URL, tool name, arguments) at approval-request time so that - subsequent state mutations cannot make the executor "approve X then call - Y". Headers are stored as the raw expression strings (not evaluated values) - so secrets are not persisted in the workflow's checkpoint state. They are - re-evaluated on resume. +- The :class:`MCPToolApprovalRequest` payload is the source of truth for the + resumed invocation: ``tool_name``, ``server_url``, ``server_label``, + ``arguments``, and ``connection_name`` come from the request the reviewer + approved. Headers are re-evaluated from the action definition on resume so + that secret values are not persisted in the workflow's checkpoint state. - Tool outputs flow back into agent conversations through ``conversationId`` and through Tool-role messages emitted to ``output.messages``. They share the same prompt-injection risk surface as ``HttpRequestAction``: workflow @@ -60,8 +59,6 @@ __all__ = [ logger = logging.getLogger(__name__) -_MCP_APPROVAL_STATE_KEY = "_mcp_tool_approval_state" - # --------------------------------------------------------------------------- # Request / state types @@ -86,6 +83,9 @@ class MCPToolApprovalRequest: arguments: Evaluated arguments to be forwarded to the tool. header_names: Sorted list of outbound header names (no values). Empty when no headers are configured. + connection_name: Optional connection identifier the invocation will + use. Surfaced so the reviewer can see which connection is bound + to the approved call. """ request_id: str @@ -94,28 +94,7 @@ class MCPToolApprovalRequest: server_label: str | None arguments: dict[str, Any] header_names: list[str] = field(default_factory=lambda: []) - - -@dataclass -class _MCPToolApprovalState: - """Internal state saved during the approval yield for resumption. - - Stores **evaluated** values for non-secret fields to prevent - "approve X / execute Y" attacks. Stores the raw expression string for - ``headers`` so that secret values are NOT persisted in checkpoint state; - the expressions are re-evaluated against current state on resume. - """ - - server_url: str - tool_name: str - server_label: str | None - arguments: dict[str, Any] - connection_name: str | None - headers_def: Any - auto_send: bool - conversation_id_expr: str | None - output_messages_path: str | None - output_result_path: str | None + connection_name: str | None = None # --------------------------------------------------------------------------- @@ -260,20 +239,6 @@ class InvokeMcpToolActionExecutor(DeclarativeActionExecutor): if require_approval: request_id = str(uuid.uuid4()) - approval_state = _MCPToolApprovalState( - server_url=server_url, - tool_name=tool_name, - server_label=server_label, - arguments=arguments, - connection_name=connection_name, - headers_def=self._action_def.get("headers"), - auto_send=auto_send, - conversation_id_expr=conversation_id_expr if isinstance(conversation_id_expr, str) else None, - output_messages_path=output_messages_path, - output_result_path=output_result_path, - ) - ctx.state.set(self._approval_key(), approval_state) - request = MCPToolApprovalRequest( request_id=request_id, tool_name=tool_name, @@ -281,6 +246,7 @@ class InvokeMcpToolActionExecutor(DeclarativeActionExecutor): server_label=server_label, arguments=arguments, header_names=sorted(headers.keys()), + connection_name=connection_name, ) logger.info( "%s: requesting approval for MCP tool '%s' on '%s'", @@ -322,54 +288,59 @@ class InvokeMcpToolActionExecutor(DeclarativeActionExecutor): response: ToolApprovalResponse, ctx: WorkflowContext[ActionComplete, str], ) -> None: - """Resume after the workflow yielded for an approval request.""" - state = self._get_state(ctx.state) - approval_key = self._approval_key() + """Resume after the workflow yielded for an approval request. - try: - approval_state: _MCPToolApprovalState = ctx.state.get(approval_key) - except KeyError: - logger.error("%s: approval state missing for executor '%s'", self.__class__.__name__, self.id) - await ctx.send_message(ActionComplete()) - return - try: - ctx.state.delete(approval_key) - except KeyError: - logger.warning("%s: approval state already deleted for '%s'", self.__class__.__name__, self.id) + Invocation fields (``tool_name``, ``server_url``, ``server_label``, + ``arguments``, ``connection_name``) are sourced from + ``original_request``. Output configuration is re-derived from the + action definition; header values are re-evaluated from the action + definition so secrets remain out of checkpoint state. + """ + state = self._get_state(ctx.state) + + tool_name = original_request.tool_name + server_url = original_request.server_url + server_label = original_request.server_label + arguments = original_request.arguments + connection_name = original_request.connection_name + + auto_send = self._get_auto_send(state) + conversation_id_value = self._action_def.get("conversationId") + conversation_id_expr = conversation_id_value if isinstance(conversation_id_value, str) else None + output_messages_path = _get_output_path(self._action_def, "messages") + output_result_path = _get_output_path(self._action_def, "result") if not response.approved: logger.info( "%s: MCP tool '%s' rejected: %s", self.__class__.__name__, - approval_state.tool_name, + tool_name, response.reason, ) - self._assign_error( - state, approval_state.output_result_path, "MCP tool invocation was not approved by user." - ) + self._assign_error(state, output_result_path, "MCP tool invocation was not approved by user.") await ctx.send_message(ActionComplete()) return - # Approved — re-evaluate headers (not stored at approval time for security). - headers = self._evaluate_headers(state, approval_state.headers_def) + # Approved — re-evaluate headers (not surfaced at approval time for security). + headers = self._evaluate_headers(state, self._action_def.get("headers")) invocation = MCPToolInvocation( - server_url=approval_state.server_url, - tool_name=approval_state.tool_name, - server_label=approval_state.server_label, - arguments=approval_state.arguments, + server_url=server_url, + tool_name=tool_name, + server_label=server_label, + arguments=arguments, headers=headers, - connection_name=approval_state.connection_name, + connection_name=connection_name, ) result = await self._invoke_with_narrow_catch(invocation) await self._process_result( ctx=ctx, state=state, result=result, - auto_send=approval_state.auto_send, - conversation_id_expr=approval_state.conversation_id_expr, - output_messages_path=approval_state.output_messages_path, - output_result_path=approval_state.output_result_path, + auto_send=auto_send, + conversation_id_expr=conversation_id_expr, + output_messages_path=output_messages_path, + output_result_path=output_result_path, ) await ctx.send_message(ActionComplete()) @@ -577,9 +548,6 @@ class InvokeMcpToolActionExecutor(DeclarativeActionExecutor): return state.set(output_result_path, f"Error: {error_message}") - def _approval_key(self) -> str: - return f"{_MCP_APPROVAL_STATE_KEY}_{self.id}" - def _parse_outputs(outputs: list[Content]) -> list[Any]: """Parse :class:`Content` outputs into Python values for ``output.result``. diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py index b2c046a69b..d522cf5664 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py @@ -41,10 +41,6 @@ logger = logging.getLogger(__name__) # at runtime are discoverable by both agent-based and function-based tool executors. FUNCTION_TOOL_REGISTRY_KEY = TOOL_REGISTRY_KEY -# State key prefix for storing approval state during yield/resume. -# The executor's ID is appended to create a per-executor key. -TOOL_APPROVAL_STATE_KEY = "_tool_approval_state" - # ============================================================================ # Request/Response Types for Approval Flow @@ -87,26 +83,6 @@ class ToolApprovalResponse: reason: str | None = None -# ============================================================================ -# State Types for Approval Flow -# ============================================================================ - - -@dataclass -class ToolApprovalState: - """State saved during approval yield for resumption. - - Stored in State under a per-executor key when requireApproval=true. - Retrieved by handle_approval_response() to continue execution. - """ - - function_name: str - arguments: dict[str, Any] - output_messages_var: str | None - output_result_var: str | None - auto_send: bool - - # ============================================================================ # Result Types # ============================================================================ @@ -501,25 +477,16 @@ class BaseToolExecutor(DeclarativeActionExecutor): require_approval = self._action_def.get("requireApproval", False) if require_approval: - # Save state for resumption (keyed by executor ID to avoid collisions) - approval_state = ToolApprovalState( - function_name=function_name, - arguments=arguments, - output_messages_var=messages_var, - output_result_var=result_var, - auto_send=auto_send, - ) - approval_key = f"{TOOL_APPROVAL_STATE_KEY}_{self.id}" - ctx.state.set(approval_key, approval_state) - - # Emit approval request - workflow yields here + # Emit approval request - the request payload is the source of + # truth for resumed invocation; no side-channel state is written. + request_id = str(uuid.uuid4()) request = ToolApprovalRequest( - request_id=str(uuid.uuid4()), + request_id=request_id, function_name=function_name, arguments=arguments, ) logger.info(f"{self.__class__.__name__}: requesting approval for '{function_name}'") - await ctx.request_info(request, ToolApprovalResponse) + await ctx.request_info(request, ToolApprovalResponse, request_id=request_id) # Workflow yields - will resume in handle_approval_response return @@ -545,36 +512,16 @@ class BaseToolExecutor(DeclarativeActionExecutor): ) -> None: """Handle response to a ToolApprovalRequest. - Called when the workflow resumes after yielding for approval. - Either executes the tool (if approved) or stores rejection status. + Resumes after the workflow yielded for approval. The invocation + ``function_name`` and ``arguments`` are sourced from + ``original_request`` (the payload the reviewer approved); output + configuration is re-derived from the executor's action definition. """ state = self._get_state(ctx.state) - approval_key = f"{TOOL_APPROVAL_STATE_KEY}_{self.id}" - # Retrieve saved invocation state - try: - approval_state: ToolApprovalState = ctx.state.get(approval_key) - except KeyError: - error_msg = "Approval state not found, cannot resume tool invocation" - logger.error(f"{self.__class__.__name__}: {error_msg}") - # Try to store error - get output config from action def as fallback - _, result_var, _ = self._get_output_config() - if result_var and state: - state.set(_normalize_variable_path(result_var), {"error": error_msg}) - await ctx.send_message(ActionComplete()) - return - - # Clean up approval state - try: - ctx.state.delete(approval_key) - except KeyError: - logger.warning(f"{self.__class__.__name__}: approval state already deleted") - - function_name = approval_state.function_name - arguments = approval_state.arguments - messages_var = approval_state.output_messages_var - result_var = approval_state.output_result_var - auto_send = approval_state.auto_send + function_name = original_request.function_name + arguments = original_request.arguments + messages_var, result_var, auto_send = self._get_output_config() # Check if approved if not response.approved: diff --git a/python/packages/declarative/tests/test_declarative_approval_binding.py b/python/packages/declarative/tests/test_declarative_approval_binding.py new file mode 100644 index 0000000000..f96b16ee8a --- /dev/null +++ b/python/packages/declarative/tests/test_declarative_approval_binding.py @@ -0,0 +1,441 @@ +# Copyright (c) Microsoft. All rights reserved. +# pyright: reportUnknownParameterType=false, reportUnknownArgumentType=false +# pyright: reportMissingParameterType=false, reportUnknownMemberType=false +# pyright: reportPrivateUsage=false, reportUnknownVariableType=false +# pyright: reportGeneralTypeIssues=false + +"""Regression tests pinning the approval-flow binding contract. + +The resumed invocation MUST come from the framework-delivered +``original_request`` payload (the data the reviewer approved) for both +``InvokeFunctionTool`` and ``InvokeMcpTool``. These tests verify that: + +* Invocation parameters come from ``original_request``, not from any prior + side-channel state. +* Concurrent pending approvals on the same executor do not swap. +* Pre-existing state at old approval keys is ignored entirely. +* Resume works on a freshly constructed executor (checkpoint-restore + simulation), without any prior ``ctx.state`` write. +* For MCP, ``connection_name`` is sourced from the approval payload and + ``headers`` are re-evaluated from the action definition on resume. +""" + +import sys +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +try: + import powerfx # noqa: F401 + + _powerfx_available = True +except (ImportError, RuntimeError): + _powerfx_available = False + +pytestmark = pytest.mark.skipif( + not _powerfx_available or sys.version_info >= (3, 14), + reason="PowerFx engine not available (requires dotnet runtime)", +) + +from agent_framework import Content # noqa: E402 + +from agent_framework_declarative._workflows import ( # noqa: E402 + DECLARATIVE_STATE_KEY, + ActionComplete, + InvokeFunctionToolExecutor, + MCPToolApprovalRequest, + MCPToolHandler, + MCPToolInvocation, + MCPToolResult, + ToolApprovalRequest, + ToolApprovalResponse, +) +from agent_framework_declarative._workflows._executors_mcp import ( # noqa: E402 + InvokeMcpToolActionExecutor, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_state() -> MagicMock: + """In-memory mock of the underlying State.""" + state = MagicMock() + state._data = {} + + def _get(key: str, default: Any = None) -> Any: + return state._data.get(key, default) + + def _set(key: str, value: Any) -> None: + state._data[key] = value + + def _has(key: str) -> bool: + return key in state._data + + def _delete(key: str) -> None: + state._data.pop(key, None) + + state.get = MagicMock(side_effect=_get) + state.set = MagicMock(side_effect=_set) + state.has = MagicMock(side_effect=_has) + state.delete = MagicMock(side_effect=_delete) + return state + + +@pytest.fixture +def mock_context(mock_state: MagicMock) -> MagicMock: + ctx = MagicMock() + ctx.state = mock_state + ctx.send_message = AsyncMock() + ctx.yield_output = AsyncMock() + ctx.request_info = AsyncMock() + return ctx + + +def _seed_state(mock_state: MagicMock) -> None: + mock_state._data[DECLARATIVE_STATE_KEY] = { + "Inputs": {}, + "Outputs": {}, + "Local": {}, + "Custom": {}, + "System": { + "ConversationId": "00000000-0000-0000-0000-000000000000", + "LastMessage": {"Text": "", "Id": ""}, + "LastMessageText": "", + "LastMessageId": "", + }, + "Agent": {}, + "Conversation": {"messages": [], "history": []}, + } + + +class _RecordingMcpHandler(MCPToolHandler): + def __init__(self, result: MCPToolResult | None = None) -> None: + self.result = result or MCPToolResult(outputs=[Content.from_text("ok")]) + self.invocations: list[MCPToolInvocation] = [] + + @property + def call_count(self) -> int: + return len(self.invocations) + + @property + def last(self) -> MCPToolInvocation | None: + return self.invocations[-1] if self.invocations else None + + async def invoke_tool(self, invocation: MCPToolInvocation) -> MCPToolResult: + self.invocations.append(invocation) + return self.result + + +# --------------------------------------------------------------------------- +# InvokeFunctionTool: approval-binding regression +# --------------------------------------------------------------------------- + + +class TestFunctionToolApprovalBinding: + def _action(self, *, fn_name: str = "my_tool") -> dict[str, Any]: + return { + "kind": "InvokeFunctionTool", + "id": "fn_action", + "functionName": fn_name, + "requireApproval": True, + "output": {"result": "Local.result"}, + } + + @pytest.mark.asyncio + async def test_request_id_matches_framework_pending_key(self, mock_state, mock_context) -> None: + """The id on the emitted ToolApprovalRequest must match the framework's pending-request key.""" + from agent_framework_declarative._workflows._declarative_base import ActionTrigger + + _seed_state(mock_state) + + def my_tool(x: int) -> int: + return x + + executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool}) + await executor.handle_action(ActionTrigger(), mock_context) + + mock_context.request_info.assert_called_once() + emitted_request = mock_context.request_info.call_args[0][0] + framework_request_id = mock_context.request_info.call_args.kwargs["request_id"] + assert isinstance(emitted_request, ToolApprovalRequest) + assert emitted_request.request_id == framework_request_id + + @pytest.mark.asyncio + async def test_resume_uses_request_payload_arguments(self, mock_state, mock_context) -> None: + _seed_state(mock_state) + call_log: list[int] = [] + + def my_tool(x: int) -> int: + call_log.append(x) + return x + + executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool}) + + request = ToolApprovalRequest(request_id="r-1", function_name="my_tool", arguments={"x": 1}) + await executor.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context) + + assert call_log == [1] + + @pytest.mark.asyncio + async def test_concurrent_pending_approvals_do_not_swap(self, mock_state, mock_context) -> None: + """Two pending approvals, responses delivered out of order — each invocation uses its own payload.""" + _seed_state(mock_state) + call_log: list[int] = [] + + def my_tool(x: int) -> int: + call_log.append(x) + return x + + executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool}) + + request_a = ToolApprovalRequest(request_id="r-A", function_name="my_tool", arguments={"x": 1}) + request_b = ToolApprovalRequest(request_id="r-B", function_name="my_tool", arguments={"x": 999}) + + # Deliver response for B first, then for A. Each invocation must use its own payload. + await executor.handle_approval_response(request_b, ToolApprovalResponse(approved=True), mock_context) + await executor.handle_approval_response(request_a, ToolApprovalResponse(approved=True), mock_context) + + assert call_log == [999, 1] + + @pytest.mark.asyncio + async def test_resume_ignores_stale_state_at_old_approval_key(self, mock_state, mock_context) -> None: + """Pre-existing state at the OLD approval key is ignored — payload wins.""" + _seed_state(mock_state) + call_log: list[int] = [] + + def my_tool(x: int) -> int: + call_log.append(x) + return x + + executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool}) + + # Poison the old key shape (no longer read by the executor). + mock_state._data["_tool_approval_state_fn_action"] = {"function_name": "other", "arguments": {"x": 999}} + + request = ToolApprovalRequest(request_id="r-3", function_name="my_tool", arguments={"x": 7}) + await executor.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context) + + assert call_log == [7] + # The poison was never read or deleted by the executor. + assert "_tool_approval_state_fn_action" in mock_state._data + + @pytest.mark.asyncio + async def test_fresh_executor_resume_works(self, mock_state, mock_context) -> None: + """Simulates checkpoint restore: a brand-new executor instance handles the approval response.""" + _seed_state(mock_state) + call_log: list[int] = [] + + def my_tool(x: int) -> int: + call_log.append(x) + return x + + # Pretend the executor that emitted the request is gone; a fresh one handles the response. + fresh = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool}) + + request = ToolApprovalRequest(request_id="r-4", function_name="my_tool", arguments={"x": 42}) + await fresh.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context) + + assert call_log == [42] + mock_context.send_message.assert_called_once() + sent = mock_context.send_message.call_args[0][0] + assert isinstance(sent, ActionComplete) + + @pytest.mark.asyncio + async def test_rejection_uses_request_payload_function_name(self, mock_state, mock_context) -> None: + _seed_state(mock_state) + + def my_tool(x: int) -> int: + raise AssertionError("should not be called when rejected") + + executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool}) + + request = ToolApprovalRequest(request_id="r-5", function_name="my_tool", arguments={"x": 3}) + await executor.handle_approval_response( + request, ToolApprovalResponse(approved=False, reason="not authorized"), mock_context + ) + + # The rejection message references the function name from the request payload. + local = mock_state._data[DECLARATIVE_STATE_KEY]["Local"] + assert local["result"]["rejected"] is True + assert local["result"]["reason"] == "not authorized" + + +# --------------------------------------------------------------------------- +# InvokeMcpTool: approval-binding regression +# --------------------------------------------------------------------------- + + +class TestMcpToolApprovalBinding: + def _action(self, *, headers: dict[str, Any] | None = None) -> dict[str, Any]: + action: dict[str, Any] = { + "kind": "InvokeMcpTool", + "id": "mcp_action", + "serverUrl": "https://mcp.example/api", + "toolName": "search", + "requireApproval": True, + "output": {"result": "Local.Result"}, + } + if headers is not None: + action["headers"] = headers + return action + + @pytest.mark.asyncio + async def test_request_id_matches_framework_pending_key(self, mock_state, mock_context) -> None: + """The id on the emitted MCPToolApprovalRequest must match the framework's pending-request key.""" + from agent_framework_declarative._workflows._declarative_base import ActionTrigger + + _seed_state(mock_state) + executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=_RecordingMcpHandler()) + await executor.handle_action(ActionTrigger(), mock_context) + + mock_context.request_info.assert_called_once() + emitted_request = mock_context.request_info.call_args[0][0] + framework_request_id = mock_context.request_info.call_args.kwargs["request_id"] + assert isinstance(emitted_request, MCPToolApprovalRequest) + assert emitted_request.request_id == framework_request_id + + @pytest.mark.asyncio + async def test_resume_uses_request_payload_fields(self, mock_state, mock_context) -> None: + _seed_state(mock_state) + handler = _RecordingMcpHandler() + executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=handler) + + request = MCPToolApprovalRequest( + request_id="r-1", + tool_name="search", + server_url="https://mcp.example/api", + server_label="prod", + arguments={"q": "x"}, + connection_name="conn-A", + ) + await executor.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context) + + assert handler.call_count == 1 + inv = handler.last + assert inv is not None + assert inv.tool_name == "search" + assert inv.server_url == "https://mcp.example/api" + assert inv.server_label == "prod" + assert inv.arguments == {"q": "x"} + assert inv.connection_name == "conn-A" + + @pytest.mark.asyncio + async def test_concurrent_pending_mcp_approvals_do_not_swap(self, mock_state, mock_context) -> None: + _seed_state(mock_state) + handler = _RecordingMcpHandler() + executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=handler) + + request_a = MCPToolApprovalRequest( + request_id="r-A", + tool_name="search", + server_url="https://mcp.example/api", + server_label=None, + arguments={"q": "alpha"}, + connection_name="conn-A", + ) + request_b = MCPToolApprovalRequest( + request_id="r-B", + tool_name="search", + server_url="https://mcp.example/api", + server_label=None, + arguments={"q": "beta"}, + connection_name="conn-B", + ) + + await executor.handle_approval_response(request_b, ToolApprovalResponse(approved=True), mock_context) + await executor.handle_approval_response(request_a, ToolApprovalResponse(approved=True), mock_context) + + assert handler.call_count == 2 + assert handler.invocations[0].arguments == {"q": "beta"} + assert handler.invocations[0].connection_name == "conn-B" + assert handler.invocations[1].arguments == {"q": "alpha"} + assert handler.invocations[1].connection_name == "conn-A" + + @pytest.mark.asyncio + async def test_headers_reevaluated_from_action_def_on_resume(self, mock_state, mock_context) -> None: + """Headers come from the action definition (re-evaluated) so secrets are not in the payload.""" + _seed_state(mock_state) + handler = _RecordingMcpHandler() + executor = InvokeMcpToolActionExecutor( + self._action(headers={"Authorization": "Bearer tk"}), + mcp_tool_handler=handler, + ) + + request = MCPToolApprovalRequest( + request_id="r-1", + tool_name="search", + server_url="https://mcp.example/api", + server_label=None, + arguments={"q": "x"}, + connection_name=None, + ) + await executor.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context) + + assert handler.last is not None + assert handler.last.headers == {"Authorization": "Bearer tk"} + + @pytest.mark.asyncio + async def test_mcp_resume_ignores_stale_state_at_old_approval_key(self, mock_state, mock_context) -> None: + _seed_state(mock_state) + handler = _RecordingMcpHandler() + executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=handler) + + mock_state._data["_mcp_tool_approval_state_mcp_action"] = {"poison": True} + + request = MCPToolApprovalRequest( + request_id="r-1", + tool_name="search", + server_url="https://mcp.example/api", + server_label=None, + arguments={"q": "real"}, + connection_name=None, + ) + await executor.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context) + + assert handler.call_count == 1 + assert handler.last is not None + assert handler.last.arguments == {"q": "real"} + # The poison was never read or deleted by the executor. + assert "_mcp_tool_approval_state_mcp_action" in mock_state._data + + @pytest.mark.asyncio + async def test_fresh_mcp_executor_resume_works(self, mock_state, mock_context) -> None: + """Checkpoint-restore simulation: fresh executor handles the response.""" + _seed_state(mock_state) + handler = _RecordingMcpHandler() + fresh = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=handler) + + request = MCPToolApprovalRequest( + request_id="r-1", + tool_name="search", + server_url="https://mcp.example/api", + server_label=None, + arguments={"q": "fresh"}, + connection_name=None, + ) + await fresh.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context) + + assert handler.call_count == 1 + assert handler.last is not None + assert handler.last.arguments == {"q": "fresh"} + + @pytest.mark.asyncio + async def test_request_payload_carries_connection_name(self, mock_state, mock_context) -> None: + """When emitting the approval request, connection_name flows into MCPToolApprovalRequest.""" + from agent_framework_declarative._workflows._declarative_base import ActionTrigger + + _seed_state(mock_state) + action = self._action() + action["connection"] = {"name": "conn-from-action"} + executor = InvokeMcpToolActionExecutor(action, mcp_tool_handler=_RecordingMcpHandler()) + + await executor.handle_action(ActionTrigger(), mock_context) + + mock_context.request_info.assert_called_once() + request = mock_context.request_info.call_args[0][0] + assert isinstance(request, MCPToolApprovalRequest) + assert request.connection_name == "conn-from-action" diff --git a/python/packages/declarative/tests/test_function_tool_executor.py b/python/packages/declarative/tests/test_function_tool_executor.py index f11b356865..bcf04bd21d 100644 --- a/python/packages/declarative/tests/test_function_tool_executor.py +++ b/python/packages/declarative/tests/test_function_tool_executor.py @@ -35,14 +35,12 @@ pytestmark = pytest.mark.skipif( from agent_framework_declarative._workflows import ( # noqa: E402 DECLARATIVE_STATE_KEY, FUNCTION_TOOL_REGISTRY_KEY, - TOOL_APPROVAL_STATE_KEY, ActionComplete, ActionTrigger, DeclarativeWorkflowBuilder, InvokeFunctionToolExecutor, ToolApprovalRequest, ToolApprovalResponse, - ToolApprovalState, ToolInvocationResult, WorkflowFactory, ) @@ -393,21 +391,6 @@ class TestToolApprovalTypes: assert response.approved is False assert response.reason == "Not authorized" - def test_approval_state(self): - """Test creating approval state for yield/resume.""" - state = ToolApprovalState( - function_name="delete_user", - arguments={"user_id": "123"}, - output_messages_var="Local.messages", - output_result_var="Local.result", - auto_send=True, - ) - assert state.function_name == "delete_user" - assert state.arguments == {"user_id": "123"} - assert state.output_messages_var == "Local.messages" - assert state.output_result_var == "Local.result" - assert state.auto_send is True - class TestInvokeFunctionToolEdgeCases: """Tests for edge cases and error handling.""" @@ -1075,13 +1058,6 @@ class TestApprovalFlow: # Should NOT have sent ActionComplete (workflow yields) mock_context.send_message.assert_not_called() - # Approval state should be saved in state - approval_key = f"{TOOL_APPROVAL_STATE_KEY}_approval_test" - saved_state = mock_state._data[approval_key] - assert isinstance(saved_state, ToolApprovalState) - assert saved_state.function_name == "my_tool" - assert saved_state.arguments == {"x": 5} - @pytest.mark.asyncio async def test_approval_response_approved(self, mock_state, mock_context): """When approval response is approved, the tool should be invoked.""" @@ -1104,17 +1080,7 @@ class TestApprovalFlow: executor = InvokeFunctionToolExecutor(action_def, tools={"my_tool": my_tool}) - # Pre-populate approval state (simulating what handle_action stores) - approval_key = f"{TOOL_APPROVAL_STATE_KEY}_approval_approved" - mock_state._data[approval_key] = ToolApprovalState( - function_name="my_tool", - arguments={"x": 7}, - output_messages_var=None, - output_result_var="Local.result", - auto_send=True, - ) - - # Simulate the response + # Simulate the response — invocation params come from original_request original_request = ToolApprovalRequest( request_id="req-123", function_name="my_tool", @@ -1124,7 +1090,7 @@ class TestApprovalFlow: await executor.handle_approval_response(original_request, response, mock_context) - # Tool should have been called + # Tool should have been called with the approved arguments assert call_log == [7] # ActionComplete should have been sent @@ -1132,9 +1098,6 @@ class TestApprovalFlow: sent = mock_context.send_message.call_args[0][0] assert isinstance(sent, ActionComplete) - # Approval state should be cleaned up - assert approval_key not in mock_state._data - @pytest.mark.asyncio async def test_approval_response_rejected(self, mock_state, mock_context): """When approval response is rejected, rejection status should be stored.""" @@ -1154,16 +1117,6 @@ class TestApprovalFlow: executor = InvokeFunctionToolExecutor(action_def, tools={"my_tool": my_tool}) - # Pre-populate approval state - approval_key = f"{TOOL_APPROVAL_STATE_KEY}_approval_rejected" - mock_state._data[approval_key] = ToolApprovalState( - function_name="my_tool", - arguments={"x": 5}, - output_messages_var=None, - output_result_var="Local.result", - auto_send=True, - ) - original_request = ToolApprovalRequest( request_id="req-456", function_name="my_tool", @@ -1185,36 +1138,6 @@ class TestApprovalFlow: assert result["reason"] == "Not authorized" assert result["approved"] is False - @pytest.mark.asyncio - async def test_approval_response_missing_state(self, mock_state, mock_context): - """When approval state is missing on resume, should log error and complete.""" - self._init_state(mock_state) - - action_def = { - "kind": "InvokeFunctionTool", - "id": "missing_state_test", - "functionName": "my_tool", - "requireApproval": True, - "output": {"result": "Local.result"}, - } - - executor = InvokeFunctionToolExecutor(action_def, tools={}) - - # Don't populate approval state - simulate missing state - original_request = ToolApprovalRequest( - request_id="req-789", - function_name="my_tool", - arguments={}, - ) - response = ToolApprovalResponse(approved=True) - - await executor.handle_approval_response(original_request, response, mock_context) - - # Should still send ActionComplete - mock_context.send_message.assert_called_once() - sent = mock_context.send_message.call_args[0][0] - assert isinstance(sent, ActionComplete) - # ============================================================================ # State registry tool lookup (lines 255-257) diff --git a/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py b/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py index fdee1f7df1..549cdd30a7 100644 --- a/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py +++ b/python/packages/declarative/tests/test_invoke_mcp_tool_executor.py @@ -403,7 +403,6 @@ class TestApprovalFlow: async def test_approval_required_emits_request_and_yields(self, mock_state, mock_context) -> None: # type: ignore[no-untyped-def] from agent_framework_declarative._workflows._declarative_base import ActionTrigger from agent_framework_declarative._workflows._executors_mcp import ( - _MCP_APPROVAL_STATE_KEY, InvokeMcpToolActionExecutor, MCPToolApprovalRequest, ) @@ -439,18 +438,12 @@ class TestApprovalFlow: # Handler not invoked yet. assert handler.call_count == 0 - # Approval state stored. - approval_key = f"{_MCP_APPROVAL_STATE_KEY}_mcp_action" - assert approval_key in mock_state._data - @pytest.mark.asyncio async def test_approval_response_approved_invokes_handler(self, mock_state, mock_context) -> None: # type: ignore[no-untyped-def] from agent_framework_declarative._workflows import ActionComplete, ToolApprovalResponse from agent_framework_declarative._workflows._executors_mcp import ( - _MCP_APPROVAL_STATE_KEY, InvokeMcpToolActionExecutor, MCPToolApprovalRequest, - _MCPToolApprovalState, ) _seed_state(mock_state) @@ -458,24 +451,11 @@ class TestApprovalFlow: executor = InvokeMcpToolActionExecutor( _action( require_approval=True, + headers={"Authorization": "Bearer tk"}, output={"result": "Local.Result"}, ), mcp_tool_handler=handler, ) - # Pre-populate approval state. - approval_key = f"{_MCP_APPROVAL_STATE_KEY}_mcp_action" - mock_state._data[approval_key] = _MCPToolApprovalState( - server_url="https://mcp.example/api", - tool_name="search", - server_label=None, - arguments={"q": "x"}, - connection_name=None, - headers_def={"Authorization": "Bearer tk"}, - auto_send=False, - conversation_id_expr=None, - output_messages_path=None, - output_result_path="Local.Result", - ) await executor.handle_approval_response( MCPToolApprovalRequest( request_id="req-1", @@ -491,10 +471,12 @@ class TestApprovalFlow: assert handler.call_count == 1 inv = handler.last_invocation assert inv is not None - # Headers are re-evaluated from headers_def. + # Invocation fields source from the approval request payload. + assert inv.tool_name == "search" + assert inv.server_url == "https://mcp.example/api" + assert inv.arguments == {"q": "x"} + # Headers are re-evaluated from the action definition on resume. assert inv.headers == {"Authorization": "Bearer tk"} - # Approval state was cleaned up. - assert approval_key not in mock_state._data # ActionComplete was sent. mock_context.send_message.assert_called_once() sent = mock_context.send_message.call_args[0][0] @@ -504,10 +486,8 @@ class TestApprovalFlow: async def test_approval_response_rejected_assigns_error(self, mock_state, mock_context) -> None: # type: ignore[no-untyped-def] from agent_framework_declarative._workflows import ToolApprovalResponse from agent_framework_declarative._workflows._executors_mcp import ( - _MCP_APPROVAL_STATE_KEY, InvokeMcpToolActionExecutor, MCPToolApprovalRequest, - _MCPToolApprovalState, ) _seed_state(mock_state) @@ -519,19 +499,6 @@ class TestApprovalFlow: ), mcp_tool_handler=handler, ) - approval_key = f"{_MCP_APPROVAL_STATE_KEY}_mcp_action" - mock_state._data[approval_key] = _MCPToolApprovalState( - server_url="https://mcp.example/api", - tool_name="search", - server_label=None, - arguments={}, - connection_name=None, - headers_def=None, - auto_send=True, - conversation_id_expr=None, - output_messages_path=None, - output_result_path="Local.Result", - ) await executor.handle_approval_response( MCPToolApprovalRequest( request_id="req-2", diff --git a/python/samples/03-workflows/declarative/invoke_mcp_tool/main.py b/python/samples/03-workflows/declarative/invoke_mcp_tool/main.py index 85b513b562..358ee91904 100644 --- a/python/samples/03-workflows/declarative/invoke_mcp_tool/main.py +++ b/python/samples/03-workflows/declarative/invoke_mcp_tool/main.py @@ -87,6 +87,8 @@ def _prompt_for_approval(request: MCPToolApprovalRequest) -> ToolApprovalRespons print(f" outbound header names: {', '.join(request.header_names)}") else: print(" outbound header names: (none)") + if request.connection_name: + print(f" connection: {request.connection_name}") print("-" * 60) while True: