Files
agent-framework/python/packages/declarative/tests/test_declarative_approval_binding.py
T
2026-06-11 22:02:58 -07:00

529 lines
21 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
# pyright: reportUnknownParameterType=false, reportUnknownArgumentType=false
# pyright: reportMissingParameterType=false, reportUnknownMemberType=false
# pyright: reportPrivateUsage=false, reportUnknownVariableType=false
# pyright: reportGeneralTypeIssues=false
"""Regression tests pinning the approval-flow binding contract.
The resumed invocation MUST come from the framework-delivered
``original_request`` payload (the data the reviewer approved) for both
``InvokeFunctionTool`` and ``InvokeMcpTool``. These tests verify that:
* Invocation parameters come from ``original_request``, not from any prior
side-channel state.
* Concurrent pending approvals on the same executor do not swap.
* Pre-existing state at old approval keys is ignored entirely.
* Resume works on a freshly constructed executor (checkpoint-restore
simulation), without any prior ``ctx.state`` write.
* For MCP, ``connection_name`` is sourced from the approval payload and
``headers`` are re-evaluated from the action definition on resume.
"""
import sys
from dataclasses import dataclass
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
try:
    import powerfx  # noqa: F401

    _powerfx_available = True
except (ImportError, RuntimeError):
    # Importing ``powerfx`` may raise RuntimeError as well as ImportError; the
    # skip reason below attributes this to a missing dotnet runtime.
    _powerfx_available = False

# Skip the whole module when PowerFx cannot be loaded, and on Python 3.14+.
# Two separate marks (instead of one OR-ed condition) so the reported skip
# reason always matches the condition that actually triggered the skip.
pytestmark = [
    pytest.mark.skipif(
        not _powerfx_available,
        reason="PowerFx engine not available (requires dotnet runtime)",
    ),
    pytest.mark.skipif(
        sys.version_info >= (3, 14),
        reason="PowerFx-dependent tests are skipped on Python 3.14+",
    ),
]
from agent_framework import Content # noqa: E402
from agent_framework_declarative._workflows import ( # noqa: E402
DECLARATIVE_STATE_KEY,
ActionComplete,
InvokeFunctionToolExecutor,
MCPToolApprovalRequest,
MCPToolHandler,
MCPToolInvocation,
MCPToolResult,
ToolApprovalRequest,
ToolApprovalResponse,
)
from agent_framework_declarative._workflows._declarative_base import DeclarativeWorkflowState # noqa: E402
from agent_framework_declarative._workflows._executors_mcp import ( # noqa: E402
InvokeMcpToolActionExecutor,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def mock_state() -> MagicMock:
    """Provide a MagicMock standing in for the underlying State, backed by a dict.

    The backing dict is exposed as ``_data`` so tests can seed or inspect it
    directly. ``get``/``set``/``has``/``delete`` are MagicMocks (so their calls
    are recorded) whose side effects read and write ``_data``.
    """
    state = MagicMock()
    state._data = {}

    def read(key: str, default: Any = None) -> Any:
        return state._data.get(key, default)

    def write(key: str, value: Any) -> None:
        state._data[key] = value

    def contains(key: str) -> bool:
        return key in state._data

    def remove(key: str) -> None:
        # Missing keys are ignored; nothing is returned to the caller.
        state._data.pop(key, None)

    for method_name, impl in (("get", read), ("set", write), ("has", contains), ("delete", remove)):
        setattr(state, method_name, MagicMock(side_effect=impl))
    return state
@pytest.fixture
def mock_context(mock_state: MagicMock) -> MagicMock:
    """Provide a MagicMock workflow context whose ``state`` is ``mock_state``.

    ``send_message``, ``yield_output`` and ``request_info`` are AsyncMocks so
    they can be awaited and their calls inspected by the tests.
    """
    context = MagicMock()
    context.state = mock_state
    for async_method in ("send_message", "yield_output", "request_info"):
        setattr(context, async_method, AsyncMock())
    return context
def _seed_state(mock_state: MagicMock) -> None:
    """Store a minimal declarative-workflow state document under ``DECLARATIVE_STATE_KEY``.

    All scopes start empty except ``System`` (all-zero ConversationId and blank
    last-message fields) and ``Conversation`` (empty ``messages``/``history``).
    """
    system_scope = {
        "ConversationId": "00000000-0000-0000-0000-000000000000",
        "LastMessage": {"Text": "", "Id": ""},
        "LastMessageText": "",
        "LastMessageId": "",
    }
    # A comprehension creates a distinct empty dict for each scope.
    document: dict[str, Any] = {scope: {} for scope in ("Inputs", "Outputs", "Local", "Custom")}
    document["System"] = system_scope
    document["Agent"] = {}
    document["Conversation"] = {"messages": [], "history": []}
    mock_state._data[DECLARATIVE_STATE_KEY] = document
class _RecordingMcpHandler(MCPToolHandler):
    """MCP tool handler stub that records each invocation and returns a canned result.

    Args:
        result: Value returned from every ``invoke_tool`` call. When omitted
            (or falsy), a result with a single ``"ok"`` text output is used.
    """

    def __init__(self, result: MCPToolResult | None = None) -> None:
        self.result = result if result else MCPToolResult(outputs=[Content.from_text("ok")])
        self.invocations: list[MCPToolInvocation] = []

    @property
    def call_count(self) -> int:
        """Number of ``invoke_tool`` calls recorded so far."""
        return len(self.invocations)

    @property
    def last(self) -> MCPToolInvocation | None:
        """The most recent recorded invocation, or ``None`` if none were made."""
        if not self.invocations:
            return None
        return self.invocations[-1]

    async def invoke_tool(self, invocation: MCPToolInvocation) -> MCPToolResult:
        """Record ``invocation`` and hand back the configured result."""
        self.invocations.append(invocation)
        return self.result
# ---------------------------------------------------------------------------
# InvokeFunctionTool: approval-binding regression
# ---------------------------------------------------------------------------
class TestFunctionToolApprovalBinding:
    """Approval-binding regression tests for ``InvokeFunctionToolExecutor``."""

    def _action(self, *, fn_name: str = "my_tool") -> dict[str, Any]:
        """Return an approval-gated ``InvokeFunctionTool`` action bound to ``Local.result``."""
        return {
            "kind": "InvokeFunctionTool",
            "id": "fn_action",
            "functionName": fn_name,
            "requireApproval": True,
            "output": {"result": "Local.result"},
        }

    @pytest.mark.asyncio
    async def test_request_id_matches_framework_pending_key(self, mock_state, mock_context) -> None:
        """The id on the emitted ToolApprovalRequest must match the framework's pending-request key."""
        from agent_framework_declarative._workflows._declarative_base import ActionTrigger

        _seed_state(mock_state)

        def my_tool(x: int) -> int:
            return x

        executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool})
        await executor.handle_action(ActionTrigger(), mock_context)

        mock_context.request_info.assert_called_once()
        emitted_request = mock_context.request_info.call_args[0][0]
        framework_request_id = mock_context.request_info.call_args.kwargs["request_id"]
        assert isinstance(emitted_request, ToolApprovalRequest)
        assert emitted_request.request_id == framework_request_id

    @pytest.mark.asyncio
    async def test_resume_uses_request_payload_arguments(self, mock_state, mock_context) -> None:
        """Resume invokes the tool with the arguments carried on the approved request."""
        _seed_state(mock_state)
        call_log: list[int] = []

        def my_tool(x: int) -> int:
            call_log.append(x)
            return x

        executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool})
        request = ToolApprovalRequest(request_id="r-1", function_name="my_tool", arguments={"x": 1})
        await executor.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context)

        assert call_log == [1]

    @pytest.mark.asyncio
    async def test_concurrent_pending_approvals_do_not_swap(self, mock_state, mock_context) -> None:
        """Two pending approvals, responses delivered out of order — each invocation uses its own payload."""
        _seed_state(mock_state)
        call_log: list[int] = []

        def my_tool(x: int) -> int:
            call_log.append(x)
            return x

        executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool})
        request_a = ToolApprovalRequest(request_id="r-A", function_name="my_tool", arguments={"x": 1})
        request_b = ToolApprovalRequest(request_id="r-B", function_name="my_tool", arguments={"x": 999})

        # Deliver response for B first, then for A. Each invocation must use its own payload.
        await executor.handle_approval_response(request_b, ToolApprovalResponse(approved=True), mock_context)
        await executor.handle_approval_response(request_a, ToolApprovalResponse(approved=True), mock_context)

        assert call_log == [999, 1]

    @pytest.mark.asyncio
    async def test_resume_ignores_stale_state_at_old_approval_key(self, mock_state, mock_context) -> None:
        """Pre-existing state at the OLD approval key is ignored — payload wins."""
        _seed_state(mock_state)
        call_log: list[int] = []

        def my_tool(x: int) -> int:
            call_log.append(x)
            return x

        executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool})
        # Poison the old key shape (no longer read by the executor).
        mock_state._data["_tool_approval_state_fn_action"] = {"function_name": "other", "arguments": {"x": 999}}
        request = ToolApprovalRequest(request_id="r-3", function_name="my_tool", arguments={"x": 7})
        await executor.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context)

        assert call_log == [7]
        # The poison did not influence the invocation and was not deleted by the executor.
        assert "_tool_approval_state_fn_action" in mock_state._data

    @pytest.mark.asyncio
    async def test_fresh_executor_resume_works(self, mock_state, mock_context) -> None:
        """Simulates checkpoint restore: a brand-new executor instance handles the approval response."""
        _seed_state(mock_state)
        call_log: list[int] = []

        def my_tool(x: int) -> int:
            call_log.append(x)
            return x

        # Pretend the executor that emitted the request is gone; a fresh one handles the response.
        fresh = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool})
        request = ToolApprovalRequest(request_id="r-4", function_name="my_tool", arguments={"x": 42})
        await fresh.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context)

        assert call_log == [42]
        mock_context.send_message.assert_called_once()
        sent = mock_context.send_message.call_args[0][0]
        assert isinstance(sent, ActionComplete)

    @pytest.mark.asyncio
    async def test_rejection_uses_request_payload_function_name(self, mock_state, mock_context) -> None:
        """A rejected approval never invokes the tool and records the reviewer's reason."""
        _seed_state(mock_state)
        call_log: list[int] = []

        def my_tool(x: int) -> int:
            # Record instead of raising: if the executor caught tool exceptions,
            # a raise here could be swallowed and the "not called" check would
            # pass silently.
            call_log.append(x)
            return x

        executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool})
        request = ToolApprovalRequest(request_id="r-5", function_name="my_tool", arguments={"x": 3})
        await executor.handle_approval_response(
            request, ToolApprovalResponse(approved=False, reason="not authorized"), mock_context
        )

        assert call_log == [], "tool must not be invoked when the approval is rejected"
        local = mock_state._data[DECLARATIVE_STATE_KEY]["Local"]
        assert local["result"]["rejected"] is True
        assert local["result"]["reason"] == "not authorized"
        # NOTE(review): despite the test name, nothing here asserts that the
        # rejection output uses the function name from ``request``; the output
        # shape beyond ``rejected``/``reason`` is not visible in this file.
# ---------------------------------------------------------------------------
# InvokeMcpTool: approval-binding regression
# ---------------------------------------------------------------------------
class TestMcpToolApprovalBinding:
    """Approval-binding regression tests for ``InvokeMcpToolActionExecutor``."""

    def _action(self, *, headers: dict[str, Any] | None = None) -> dict[str, Any]:
        """Return an approval-gated ``InvokeMcpTool`` action for the ``search`` tool.

        ``headers`` is added under the ``"headers"`` key only when provided.
        """
        optional: dict[str, Any] = {} if headers is None else {"headers": headers}
        return {
            "kind": "InvokeMcpTool",
            "id": "mcp_action",
            "serverUrl": "https://mcp.example/api",
            "toolName": "search",
            "requireApproval": True,
            "output": {"result": "Local.Result"},
            **optional,
        }

    @staticmethod
    def _approval_request(**overrides: Any) -> MCPToolApprovalRequest:
        """Build an ``MCPToolApprovalRequest`` for ``search``; ``overrides`` replace the defaults."""
        fields: dict[str, Any] = {
            "request_id": "r-1",
            "tool_name": "search",
            "server_url": "https://mcp.example/api",
            "server_label": None,
            "arguments": {"q": "x"},
            "connection_name": None,
        }
        fields.update(overrides)
        return MCPToolApprovalRequest(**fields)

    @pytest.mark.asyncio
    async def test_request_id_matches_framework_pending_key(self, mock_state, mock_context) -> None:
        """The id on the emitted MCPToolApprovalRequest must match the framework's pending-request key."""
        from agent_framework_declarative._workflows._declarative_base import ActionTrigger

        _seed_state(mock_state)
        executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=_RecordingMcpHandler())

        await executor.handle_action(ActionTrigger(), mock_context)

        mock_context.request_info.assert_called_once()
        call_args = mock_context.request_info.call_args
        emitted_request, framework_request_id = call_args[0][0], call_args.kwargs["request_id"]
        assert isinstance(emitted_request, MCPToolApprovalRequest)
        assert emitted_request.request_id == framework_request_id

    @pytest.mark.asyncio
    async def test_resume_uses_request_payload_fields(self, mock_state, mock_context) -> None:
        """The resumed invocation carries the fields of the approved request payload."""
        _seed_state(mock_state)
        recorder = _RecordingMcpHandler()
        executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=recorder)
        request = self._approval_request(server_label="prod", connection_name="conn-A")

        await executor.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context)

        assert recorder.call_count == 1
        invocation = recorder.last
        assert invocation is not None
        assert (invocation.tool_name, invocation.server_url, invocation.server_label) == (
            "search",
            "https://mcp.example/api",
            "prod",
        )
        assert invocation.arguments == {"q": "x"}
        assert invocation.connection_name == "conn-A"

    @pytest.mark.asyncio
    async def test_concurrent_pending_mcp_approvals_do_not_swap(self, mock_state, mock_context) -> None:
        """Out-of-order responses to two pending MCP approvals each use their own payload."""
        _seed_state(mock_state)
        recorder = _RecordingMcpHandler()
        executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=recorder)
        request_a = self._approval_request(request_id="r-A", arguments={"q": "alpha"}, connection_name="conn-A")
        request_b = self._approval_request(request_id="r-B", arguments={"q": "beta"}, connection_name="conn-B")

        # Answer B before A.
        for pending in (request_b, request_a):
            await executor.handle_approval_response(pending, ToolApprovalResponse(approved=True), mock_context)

        assert recorder.call_count == 2
        observed = [(inv.arguments, inv.connection_name) for inv in recorder.invocations]
        assert observed == [({"q": "beta"}, "conn-B"), ({"q": "alpha"}, "conn-A")]

    @pytest.mark.asyncio
    async def test_headers_reevaluated_from_action_def_on_resume(self, mock_state, mock_context) -> None:
        """Headers come from the action definition (re-evaluated) so secrets are not in the payload."""
        _seed_state(mock_state)
        recorder = _RecordingMcpHandler()
        executor = InvokeMcpToolActionExecutor(
            self._action(headers={"Authorization": "Bearer tk"}),
            mcp_tool_handler=recorder,
        )

        await executor.handle_approval_response(
            self._approval_request(), ToolApprovalResponse(approved=True), mock_context
        )

        assert recorder.last is not None
        assert recorder.last.headers == {"Authorization": "Bearer tk"}

    @pytest.mark.asyncio
    async def test_mcp_resume_ignores_stale_state_at_old_approval_key(self, mock_state, mock_context) -> None:
        """Pre-existing state at the old MCP approval key is ignored (payload wins) and left in place."""
        _seed_state(mock_state)
        recorder = _RecordingMcpHandler()
        executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=recorder)
        stale_key = "_mcp_tool_approval_state_mcp_action"
        mock_state._data[stale_key] = {"poison": True}

        await executor.handle_approval_response(
            self._approval_request(arguments={"q": "real"}), ToolApprovalResponse(approved=True), mock_context
        )

        assert recorder.call_count == 1
        assert recorder.last is not None
        assert recorder.last.arguments == {"q": "real"}
        # The poison was not deleted by the executor.
        assert stale_key in mock_state._data

    @pytest.mark.asyncio
    async def test_fresh_mcp_executor_resume_works(self, mock_state, mock_context) -> None:
        """Checkpoint-restore simulation: fresh executor handles the response."""
        _seed_state(mock_state)
        recorder = _RecordingMcpHandler()
        fresh = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=recorder)

        await fresh.handle_approval_response(
            self._approval_request(arguments={"q": "fresh"}), ToolApprovalResponse(approved=True), mock_context
        )

        assert recorder.call_count == 1
        assert recorder.last is not None
        assert recorder.last.arguments == {"q": "fresh"}

    @pytest.mark.asyncio
    async def test_request_payload_carries_connection_name(self, mock_state, mock_context) -> None:
        """When emitting the approval request, connection_name flows into MCPToolApprovalRequest."""
        from agent_framework_declarative._workflows._declarative_base import ActionTrigger

        _seed_state(mock_state)
        action = {**self._action(), "connection": {"name": "conn-from-action"}}
        executor = InvokeMcpToolActionExecutor(action, mcp_tool_handler=_RecordingMcpHandler())

        await executor.handle_action(ActionTrigger(), mock_context)

        mock_context.request_info.assert_called_once()
        emitted = mock_context.request_info.call_args[0][0]
        assert isinstance(emitted, MCPToolApprovalRequest)
        assert emitted.connection_name == "conn-from-action"

    @pytest.mark.asyncio
    async def test_request_payload_pins_conversation_id(self, mock_state, mock_context) -> None:
        """Evaluated ``conversationId`` is pinned in ``metadata`` at request-emit time."""
        from agent_framework_declarative._workflows._declarative_base import ActionTrigger

        _seed_state(mock_state)
        DeclarativeWorkflowState(mock_state).set("Local.targetConversation", "conv-original")
        action = {**self._action(), "conversationId": "=Local.targetConversation"}
        executor = InvokeMcpToolActionExecutor(action, mcp_tool_handler=_RecordingMcpHandler())

        await executor.handle_action(ActionTrigger(), mock_context)

        mock_context.request_info.assert_called_once()
        emitted = mock_context.request_info.call_args[0][0]
        assert isinstance(emitted, MCPToolApprovalRequest)
        assert emitted.metadata.get("conversation_id") == "conv-original"

    @pytest.mark.asyncio
    async def test_resume_routes_output_to_pinned_conversation_not_mutated_state(
        self, mock_state, mock_context
    ) -> None:
        """Output appends to the conversation pinned on ``original_request``, not the
        current state evaluation."""
        _seed_state(mock_state)
        state = DeclarativeWorkflowState(mock_state)
        for conversation in ("conv-original", "conv-mutated"):
            state.set(f"System.conversations.{conversation}.messages", [])
        # The action's conversationId expression now points elsewhere than the pinned id.
        state.set("Local.targetConversation", "conv-mutated")
        recorder = _RecordingMcpHandler(MCPToolResult(outputs=[Content.from_text("approved-output")]))
        action = {**self._action(), "conversationId": "=Local.targetConversation"}
        executor = InvokeMcpToolActionExecutor(action, mcp_tool_handler=recorder)
        original_request = self._approval_request(metadata={"conversation_id": "conv-original"})

        await executor.handle_approval_response(original_request, ToolApprovalResponse(approved=True), mock_context)

        assert len(state.get("System.conversations.conv-original.messages") or []) == 1
        assert state.get("System.conversations.conv-mutated.messages") == []

    @pytest.mark.asyncio
    async def test_resume_handles_legacy_request_without_new_fields(self, mock_state, mock_context) -> None:
        """Resume tolerates payloads lacking ``connection_name`` / ``metadata`` (legacy pickle shape)."""

        @dataclass
        class _LegacyMCPApprovalRequest:
            request_id: str
            tool_name: str
            server_url: str
            server_label: str | None
            arguments: dict[str, Any]
            header_names: list[str]

        _seed_state(mock_state)
        recorder = _RecordingMcpHandler()
        executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=recorder)
        legacy_request = _LegacyMCPApprovalRequest(
            request_id="r-1",
            tool_name="search",
            server_url="https://mcp.example/api",
            server_label=None,
            arguments={"q": "x"},
            header_names=[],
        )

        await executor.handle_approval_response(
            legacy_request,  # type: ignore[arg-type]
            ToolApprovalResponse(approved=True),
            mock_context,
        )

        assert recorder.call_count == 1
        assert recorder.last is not None
        assert recorder.last.connection_name is None