mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: [Breaking] Additional bug fix for declarative workflows (#6489)
* Fix declarative object parsing bug * Remove unnecessary comment * Address PR comments * Address PR comments. * Fix CI failures. * declarative action approval bugfix * Address PR comments * Inlined single use variables.
This commit is contained in:
committed by
GitHub
Unverified
parent
0f483fa968
commit
ed4ff188fc
@@ -76,12 +76,10 @@ from ._executors_mcp import (
|
||||
from ._executors_tools import (
|
||||
FUNCTION_TOOL_REGISTRY_KEY,
|
||||
TOOL_ACTION_EXECUTORS,
|
||||
TOOL_APPROVAL_STATE_KEY,
|
||||
BaseToolExecutor,
|
||||
InvokeFunctionToolExecutor,
|
||||
ToolApprovalRequest,
|
||||
ToolApprovalResponse,
|
||||
ToolApprovalState,
|
||||
ToolInvocationResult,
|
||||
)
|
||||
from ._factory import WorkflowFactory
|
||||
@@ -111,7 +109,6 @@ __all__ = [
|
||||
"HTTP_ACTION_EXECUTORS",
|
||||
"MCP_ACTION_EXECUTORS",
|
||||
"TOOL_ACTION_EXECUTORS",
|
||||
"TOOL_APPROVAL_STATE_KEY",
|
||||
"TOOL_REGISTRY_KEY",
|
||||
"ActionComplete",
|
||||
"ActionTrigger",
|
||||
@@ -164,7 +161,6 @@ __all__ = [
|
||||
"SetVariableExecutor",
|
||||
"ToolApprovalRequest",
|
||||
"ToolApprovalResponse",
|
||||
"ToolApprovalState",
|
||||
"ToolInvocationResult",
|
||||
"WorkflowFactory",
|
||||
"WorkflowState",
|
||||
|
||||
+51
-116
@@ -10,17 +10,11 @@ optional conversation history. Supports a human-in-loop approval flow via
|
||||
|
||||
Security notes:
|
||||
|
||||
- The executor never echoes header VALUES (auth tokens, API keys) into the
|
||||
approval request — only header NAMES are surfaced to the caller. This
|
||||
matches the security posture of :mod:`._executors_http` (which never logs
|
||||
request headers either) and prevents secrets from leaking through workflow
|
||||
events that are typically observable to operators / UIs.
|
||||
- ``_MCPToolApprovalState`` snapshots the EVALUATED values for non-secret
|
||||
fields (server URL, tool name, arguments) at approval-request time so that
|
||||
subsequent state mutations cannot make the executor "approve X then call
|
||||
Y". Headers are stored as the raw expression strings (not evaluated values)
|
||||
so secrets are not persisted in the workflow's checkpoint state. They are
|
||||
re-evaluated on resume.
|
||||
- Approval requests surface header NAMES only; header values are not echoed,
|
||||
matching the posture of :mod:`._executors_http`.
|
||||
- :class:`MCPToolApprovalRequest` carries the values the resume handler will
|
||||
use; header values are re-evaluated on resume to keep secrets out of
|
||||
checkpoint state.
|
||||
- Tool outputs flow back into agent conversations through ``conversationId``
|
||||
and through Tool-role messages emitted to ``output.messages``. They share
|
||||
the same prompt-injection risk surface as ``HttpRequestAction``: workflow
|
||||
@@ -60,8 +54,6 @@ __all__ = [
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MCP_APPROVAL_STATE_KEY = "_mcp_tool_approval_state"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request / state types
|
||||
@@ -72,20 +64,16 @@ _MCP_APPROVAL_STATE_KEY = "_mcp_tool_approval_state"
|
||||
class MCPToolApprovalRequest:
|
||||
"""Approval request emitted before invoking an MCP tool.
|
||||
|
||||
Mirrors :class:`agent_framework_declarative.ToolApprovalRequest` but for
|
||||
MCP-style invocations. Only header NAMES are surfaced — header values are
|
||||
intentionally omitted because they typically carry authentication
|
||||
secrets.
|
||||
|
||||
Attributes:
|
||||
request_id: Unique identifier for this approval request. Matches the
|
||||
id workflow event-emitters use.
|
||||
tool_name: Evaluated name of the tool to be invoked.
|
||||
request_id: Identifier matching the framework's pending-request key.
|
||||
tool_name: Evaluated tool name.
|
||||
server_url: Evaluated MCP server URL.
|
||||
server_label: Optional human-readable label for diagnostics.
|
||||
arguments: Evaluated arguments to be forwarded to the tool.
|
||||
header_names: Sorted list of outbound header names (no values). Empty
|
||||
when no headers are configured.
|
||||
server_label: Optional human-readable label.
|
||||
arguments: Evaluated tool arguments.
|
||||
header_names: Outbound header names (values withheld).
|
||||
connection_name: Connection identifier the invocation will use.
|
||||
metadata: Internal routing data pinned at approval-request time
|
||||
(e.g. ``conversation_id``) for use by the resume handler.
|
||||
"""
|
||||
|
||||
request_id: str
|
||||
@@ -94,28 +82,8 @@ class MCPToolApprovalRequest:
|
||||
server_label: str | None
|
||||
arguments: dict[str, Any]
|
||||
header_names: list[str] = field(default_factory=lambda: [])
|
||||
|
||||
|
||||
@dataclass
|
||||
class _MCPToolApprovalState:
|
||||
"""Internal state saved during the approval yield for resumption.
|
||||
|
||||
Stores **evaluated** values for non-secret fields to prevent
|
||||
"approve X / execute Y" attacks. Stores the raw expression string for
|
||||
``headers`` so that secret values are NOT persisted in checkpoint state;
|
||||
the expressions are re-evaluated against current state on resume.
|
||||
"""
|
||||
|
||||
server_url: str
|
||||
tool_name: str
|
||||
server_label: str | None
|
||||
arguments: dict[str, Any]
|
||||
connection_name: str | None
|
||||
headers_def: Any
|
||||
auto_send: bool
|
||||
conversation_id_expr: str | None
|
||||
output_messages_path: str | None
|
||||
output_result_path: str | None
|
||||
connection_name: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=lambda: {})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -123,21 +91,15 @@ class _MCPToolApprovalState:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_messages_path(state: DeclarativeWorkflowState, conversation_id_expr: str | None) -> str | None:
|
||||
"""Return the configured conversation messages path, if any.
|
||||
|
||||
Returns ``System.conversations.{evaluated_id}.messages`` when a
|
||||
``conversation_id_expr`` is configured and evaluates to a non-empty value.
|
||||
Returns ``None`` when no conversation id expression is configured or when
|
||||
the expression evaluates to ``None`` or an empty string (mirrors .NET
|
||||
``GetConversationId`` behaviour).
|
||||
"""
|
||||
if not conversation_id_expr:
|
||||
def _evaluate_conversation_id(state: DeclarativeWorkflowState, conversation_id_expr: Any) -> str | None:
|
||||
"""Return the evaluated ``conversationId`` string, or None when empty/unset."""
|
||||
if not isinstance(conversation_id_expr, str) or not conversation_id_expr:
|
||||
return None
|
||||
evaluated = state.eval_if_expression(conversation_id_expr)
|
||||
if evaluated is None or (isinstance(evaluated, str) and not evaluated):
|
||||
if evaluated is None:
|
||||
return None
|
||||
return f"System.conversations.{evaluated}.messages"
|
||||
text = str(evaluated)
|
||||
return text or None
|
||||
|
||||
|
||||
def _get_output_path(action_def: Mapping[str, Any], key: str) -> str | None:
|
||||
@@ -260,20 +222,7 @@ class InvokeMcpToolActionExecutor(DeclarativeActionExecutor):
|
||||
|
||||
if require_approval:
|
||||
request_id = str(uuid.uuid4())
|
||||
approval_state = _MCPToolApprovalState(
|
||||
server_url=server_url,
|
||||
tool_name=tool_name,
|
||||
server_label=server_label,
|
||||
arguments=arguments,
|
||||
connection_name=connection_name,
|
||||
headers_def=self._action_def.get("headers"),
|
||||
auto_send=auto_send,
|
||||
conversation_id_expr=conversation_id_expr if isinstance(conversation_id_expr, str) else None,
|
||||
output_messages_path=output_messages_path,
|
||||
output_result_path=output_result_path,
|
||||
)
|
||||
ctx.state.set(self._approval_key(), approval_state)
|
||||
|
||||
conversation_id = _evaluate_conversation_id(state, conversation_id_expr)
|
||||
request = MCPToolApprovalRequest(
|
||||
request_id=request_id,
|
||||
tool_name=tool_name,
|
||||
@@ -281,6 +230,8 @@ class InvokeMcpToolActionExecutor(DeclarativeActionExecutor):
|
||||
server_label=server_label,
|
||||
arguments=arguments,
|
||||
header_names=sorted(headers.keys()),
|
||||
connection_name=connection_name,
|
||||
metadata={"conversation_id": conversation_id},
|
||||
)
|
||||
logger.info(
|
||||
"%s: requesting approval for MCP tool '%s' on '%s'",
|
||||
@@ -289,7 +240,6 @@ class InvokeMcpToolActionExecutor(DeclarativeActionExecutor):
|
||||
server_url,
|
||||
)
|
||||
await ctx.request_info(request, ToolApprovalResponse, request_id=request_id)
|
||||
# Workflow yields here — resume in handle_approval_response.
|
||||
return
|
||||
|
||||
# No approval required - invoke directly.
|
||||
@@ -307,7 +257,7 @@ class InvokeMcpToolActionExecutor(DeclarativeActionExecutor):
|
||||
state=state,
|
||||
result=result,
|
||||
auto_send=auto_send,
|
||||
conversation_id_expr=conversation_id_expr if isinstance(conversation_id_expr, str) else None,
|
||||
conversation_id=_evaluate_conversation_id(state, conversation_id_expr),
|
||||
output_messages_path=output_messages_path,
|
||||
output_result_path=output_result_path,
|
||||
)
|
||||
@@ -322,54 +272,46 @@ class InvokeMcpToolActionExecutor(DeclarativeActionExecutor):
|
||||
response: ToolApprovalResponse,
|
||||
ctx: WorkflowContext[ActionComplete, str],
|
||||
) -> None:
|
||||
"""Resume after the workflow yielded for an approval request."""
|
||||
"""Resume the invocation using the values pinned on ``original_request``."""
|
||||
state = self._get_state(ctx.state)
|
||||
approval_key = self._approval_key()
|
||||
|
||||
try:
|
||||
approval_state: _MCPToolApprovalState = ctx.state.get(approval_key)
|
||||
except KeyError:
|
||||
logger.error("%s: approval state missing for executor '%s'", self.__class__.__name__, self.id)
|
||||
await ctx.send_message(ActionComplete())
|
||||
return
|
||||
try:
|
||||
ctx.state.delete(approval_key)
|
||||
except KeyError:
|
||||
logger.warning("%s: approval state already deleted for '%s'", self.__class__.__name__, self.id)
|
||||
tool_name = original_request.tool_name
|
||||
metadata: dict[str, Any] = getattr(original_request, "metadata", None) or {}
|
||||
raw_conversation_id = metadata.get("conversation_id")
|
||||
conversation_id = raw_conversation_id if isinstance(raw_conversation_id, str) and raw_conversation_id else None
|
||||
|
||||
auto_send = self._get_auto_send(state)
|
||||
output_messages_path = _get_output_path(self._action_def, "messages")
|
||||
output_result_path = _get_output_path(self._action_def, "result")
|
||||
|
||||
if not response.approved:
|
||||
logger.info(
|
||||
"%s: MCP tool '%s' rejected: %s",
|
||||
self.__class__.__name__,
|
||||
approval_state.tool_name,
|
||||
tool_name,
|
||||
response.reason,
|
||||
)
|
||||
self._assign_error(
|
||||
state, approval_state.output_result_path, "MCP tool invocation was not approved by user."
|
||||
)
|
||||
self._assign_error(state, output_result_path, "MCP tool invocation was not approved by user.")
|
||||
await ctx.send_message(ActionComplete())
|
||||
return
|
||||
|
||||
# Approved — re-evaluate headers (not stored at approval time for security).
|
||||
headers = self._evaluate_headers(state, approval_state.headers_def)
|
||||
|
||||
invocation = MCPToolInvocation(
|
||||
server_url=approval_state.server_url,
|
||||
tool_name=approval_state.tool_name,
|
||||
server_label=approval_state.server_label,
|
||||
arguments=approval_state.arguments,
|
||||
headers=headers,
|
||||
connection_name=approval_state.connection_name,
|
||||
server_url=original_request.server_url,
|
||||
tool_name=tool_name,
|
||||
server_label=original_request.server_label,
|
||||
arguments=original_request.arguments,
|
||||
headers=self._evaluate_headers(state, self._action_def.get("headers")),
|
||||
connection_name=getattr(original_request, "connection_name", None),
|
||||
)
|
||||
result = await self._invoke_with_narrow_catch(invocation)
|
||||
await self._process_result(
|
||||
ctx=ctx,
|
||||
state=state,
|
||||
result=result,
|
||||
auto_send=approval_state.auto_send,
|
||||
conversation_id_expr=approval_state.conversation_id_expr,
|
||||
output_messages_path=approval_state.output_messages_path,
|
||||
output_result_path=approval_state.output_result_path,
|
||||
auto_send=auto_send,
|
||||
conversation_id=conversation_id,
|
||||
output_messages_path=output_messages_path,
|
||||
output_result_path=output_result_path,
|
||||
)
|
||||
await ctx.send_message(ActionComplete())
|
||||
|
||||
@@ -528,7 +470,7 @@ class InvokeMcpToolActionExecutor(DeclarativeActionExecutor):
|
||||
state: DeclarativeWorkflowState,
|
||||
result: MCPToolResult,
|
||||
auto_send: bool,
|
||||
conversation_id_expr: str | None,
|
||||
conversation_id: str | None,
|
||||
output_messages_path: str | None,
|
||||
output_result_path: str | None,
|
||||
) -> None:
|
||||
@@ -557,14 +499,10 @@ class InvokeMcpToolActionExecutor(DeclarativeActionExecutor):
|
||||
if auto_send and parsed_results:
|
||||
await ctx.yield_output(_format_outputs_for_send(parsed_results))
|
||||
|
||||
if conversation_id_expr:
|
||||
messages_path = _get_messages_path(state, conversation_id_expr)
|
||||
if messages_path is not None:
|
||||
# Mirrors .NET: conversation gets ASSISTANT-role message with
|
||||
# the same outputs (so chat history reads it as the agent's
|
||||
# contribution).
|
||||
assistant_message = Message(role="assistant", contents=list(result.outputs))
|
||||
state.append(messages_path, assistant_message)
|
||||
if conversation_id:
|
||||
messages_path = f"System.conversations.{conversation_id}.messages"
|
||||
assistant_message = Message(role="assistant", contents=list(result.outputs))
|
||||
state.append(messages_path, assistant_message)
|
||||
|
||||
@staticmethod
|
||||
def _assign_error(
|
||||
@@ -577,9 +515,6 @@ class InvokeMcpToolActionExecutor(DeclarativeActionExecutor):
|
||||
return
|
||||
state.set(output_result_path, f"Error: {error_message}")
|
||||
|
||||
def _approval_key(self) -> str:
|
||||
return f"{_MCP_APPROVAL_STATE_KEY}_{self.id}"
|
||||
|
||||
|
||||
def _parse_outputs(outputs: list[Content]) -> list[Any]:
|
||||
"""Parse :class:`Content` outputs into Python values for ``output.result``.
|
||||
|
||||
+12
-65
@@ -41,10 +41,6 @@ logger = logging.getLogger(__name__)
|
||||
# at runtime are discoverable by both agent-based and function-based tool executors.
|
||||
FUNCTION_TOOL_REGISTRY_KEY = TOOL_REGISTRY_KEY
|
||||
|
||||
# State key prefix for storing approval state during yield/resume.
|
||||
# The executor's ID is appended to create a per-executor key.
|
||||
TOOL_APPROVAL_STATE_KEY = "_tool_approval_state"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Request/Response Types for Approval Flow
|
||||
@@ -87,26 +83,6 @@ class ToolApprovalResponse:
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# State Types for Approval Flow
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolApprovalState:
|
||||
"""State saved during approval yield for resumption.
|
||||
|
||||
Stored in State under a per-executor key when requireApproval=true.
|
||||
Retrieved by handle_approval_response() to continue execution.
|
||||
"""
|
||||
|
||||
function_name: str
|
||||
arguments: dict[str, Any]
|
||||
output_messages_var: str | None
|
||||
output_result_var: str | None
|
||||
auto_send: bool
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Result Types
|
||||
# ============================================================================
|
||||
@@ -501,25 +477,16 @@ class BaseToolExecutor(DeclarativeActionExecutor):
|
||||
require_approval = self._action_def.get("requireApproval", False)
|
||||
|
||||
if require_approval:
|
||||
# Save state for resumption (keyed by executor ID to avoid collisions)
|
||||
approval_state = ToolApprovalState(
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
output_messages_var=messages_var,
|
||||
output_result_var=result_var,
|
||||
auto_send=auto_send,
|
||||
)
|
||||
approval_key = f"{TOOL_APPROVAL_STATE_KEY}_{self.id}"
|
||||
ctx.state.set(approval_key, approval_state)
|
||||
|
||||
# Emit approval request - workflow yields here
|
||||
# Emit approval request - the request payload is the source of
|
||||
# truth for resumed invocation; no side-channel state is written.
|
||||
request_id = str(uuid.uuid4())
|
||||
request = ToolApprovalRequest(
|
||||
request_id=str(uuid.uuid4()),
|
||||
request_id=request_id,
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
)
|
||||
logger.info(f"{self.__class__.__name__}: requesting approval for '{function_name}'")
|
||||
await ctx.request_info(request, ToolApprovalResponse)
|
||||
await ctx.request_info(request, ToolApprovalResponse, request_id=request_id)
|
||||
# Workflow yields - will resume in handle_approval_response
|
||||
return
|
||||
|
||||
@@ -545,36 +512,16 @@ class BaseToolExecutor(DeclarativeActionExecutor):
|
||||
) -> None:
|
||||
"""Handle response to a ToolApprovalRequest.
|
||||
|
||||
Called when the workflow resumes after yielding for approval.
|
||||
Either executes the tool (if approved) or stores rejection status.
|
||||
Resumes after the workflow yielded for approval. The invocation
|
||||
``function_name`` and ``arguments`` are sourced from
|
||||
``original_request`` (the payload the reviewer approved); output
|
||||
configuration is re-derived from the executor's action definition.
|
||||
"""
|
||||
state = self._get_state(ctx.state)
|
||||
approval_key = f"{TOOL_APPROVAL_STATE_KEY}_{self.id}"
|
||||
|
||||
# Retrieve saved invocation state
|
||||
try:
|
||||
approval_state: ToolApprovalState = ctx.state.get(approval_key)
|
||||
except KeyError:
|
||||
error_msg = "Approval state not found, cannot resume tool invocation"
|
||||
logger.error(f"{self.__class__.__name__}: {error_msg}")
|
||||
# Try to store error - get output config from action def as fallback
|
||||
_, result_var, _ = self._get_output_config()
|
||||
if result_var and state:
|
||||
state.set(_normalize_variable_path(result_var), {"error": error_msg})
|
||||
await ctx.send_message(ActionComplete())
|
||||
return
|
||||
|
||||
# Clean up approval state
|
||||
try:
|
||||
ctx.state.delete(approval_key)
|
||||
except KeyError:
|
||||
logger.warning(f"{self.__class__.__name__}: approval state already deleted")
|
||||
|
||||
function_name = approval_state.function_name
|
||||
arguments = approval_state.arguments
|
||||
messages_var = approval_state.output_messages_var
|
||||
result_var = approval_state.output_result_var
|
||||
auto_send = approval_state.auto_send
|
||||
function_name = original_request.function_name
|
||||
arguments = original_request.arguments
|
||||
messages_var, result_var, auto_send = self._get_output_config()
|
||||
|
||||
# Check if approved
|
||||
if not response.approved:
|
||||
|
||||
@@ -0,0 +1,528 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
# pyright: reportUnknownParameterType=false, reportUnknownArgumentType=false
|
||||
# pyright: reportMissingParameterType=false, reportUnknownMemberType=false
|
||||
# pyright: reportPrivateUsage=false, reportUnknownVariableType=false
|
||||
# pyright: reportGeneralTypeIssues=false
|
||||
|
||||
"""Regression tests pinning the approval-flow binding contract.
|
||||
|
||||
The resumed invocation MUST come from the framework-delivered
|
||||
``original_request`` payload (the data the reviewer approved) for both
|
||||
``InvokeFunctionTool`` and ``InvokeMcpTool``. These tests verify that:
|
||||
|
||||
* Invocation parameters come from ``original_request``, not from any prior
|
||||
side-channel state.
|
||||
* Concurrent pending approvals on the same executor do not swap.
|
||||
* Pre-existing state at old approval keys is ignored entirely.
|
||||
* Resume works on a freshly constructed executor (checkpoint-restore
|
||||
simulation), without any prior ``ctx.state`` write.
|
||||
* For MCP, ``connection_name`` is sourced from the approval payload and
|
||||
``headers`` are re-evaluated from the action definition on resume.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
try:
|
||||
import powerfx # noqa: F401
|
||||
|
||||
_powerfx_available = True
|
||||
except (ImportError, RuntimeError):
|
||||
_powerfx_available = False
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not _powerfx_available or sys.version_info >= (3, 14),
|
||||
reason="PowerFx engine not available (requires dotnet runtime)",
|
||||
)
|
||||
|
||||
from agent_framework import Content # noqa: E402
|
||||
|
||||
from agent_framework_declarative._workflows import ( # noqa: E402
|
||||
DECLARATIVE_STATE_KEY,
|
||||
ActionComplete,
|
||||
InvokeFunctionToolExecutor,
|
||||
MCPToolApprovalRequest,
|
||||
MCPToolHandler,
|
||||
MCPToolInvocation,
|
||||
MCPToolResult,
|
||||
ToolApprovalRequest,
|
||||
ToolApprovalResponse,
|
||||
)
|
||||
from agent_framework_declarative._workflows._declarative_base import DeclarativeWorkflowState # noqa: E402
|
||||
from agent_framework_declarative._workflows._executors_mcp import ( # noqa: E402
|
||||
InvokeMcpToolActionExecutor,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_state() -> MagicMock:
|
||||
"""In-memory mock of the underlying State."""
|
||||
state = MagicMock()
|
||||
state._data = {}
|
||||
|
||||
def _get(key: str, default: Any = None) -> Any:
|
||||
return state._data.get(key, default)
|
||||
|
||||
def _set(key: str, value: Any) -> None:
|
||||
state._data[key] = value
|
||||
|
||||
def _has(key: str) -> bool:
|
||||
return key in state._data
|
||||
|
||||
def _delete(key: str) -> None:
|
||||
state._data.pop(key, None)
|
||||
|
||||
state.get = MagicMock(side_effect=_get)
|
||||
state.set = MagicMock(side_effect=_set)
|
||||
state.has = MagicMock(side_effect=_has)
|
||||
state.delete = MagicMock(side_effect=_delete)
|
||||
return state
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context(mock_state: MagicMock) -> MagicMock:
|
||||
ctx = MagicMock()
|
||||
ctx.state = mock_state
|
||||
ctx.send_message = AsyncMock()
|
||||
ctx.yield_output = AsyncMock()
|
||||
ctx.request_info = AsyncMock()
|
||||
return ctx
|
||||
|
||||
|
||||
def _seed_state(mock_state: MagicMock) -> None:
|
||||
mock_state._data[DECLARATIVE_STATE_KEY] = {
|
||||
"Inputs": {},
|
||||
"Outputs": {},
|
||||
"Local": {},
|
||||
"Custom": {},
|
||||
"System": {
|
||||
"ConversationId": "00000000-0000-0000-0000-000000000000",
|
||||
"LastMessage": {"Text": "", "Id": ""},
|
||||
"LastMessageText": "",
|
||||
"LastMessageId": "",
|
||||
},
|
||||
"Agent": {},
|
||||
"Conversation": {"messages": [], "history": []},
|
||||
}
|
||||
|
||||
|
||||
class _RecordingMcpHandler(MCPToolHandler):
|
||||
def __init__(self, result: MCPToolResult | None = None) -> None:
|
||||
self.result = result or MCPToolResult(outputs=[Content.from_text("ok")])
|
||||
self.invocations: list[MCPToolInvocation] = []
|
||||
|
||||
@property
|
||||
def call_count(self) -> int:
|
||||
return len(self.invocations)
|
||||
|
||||
@property
|
||||
def last(self) -> MCPToolInvocation | None:
|
||||
return self.invocations[-1] if self.invocations else None
|
||||
|
||||
async def invoke_tool(self, invocation: MCPToolInvocation) -> MCPToolResult:
|
||||
self.invocations.append(invocation)
|
||||
return self.result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# InvokeFunctionTool: approval-binding regression
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFunctionToolApprovalBinding:
|
||||
def _action(self, *, fn_name: str = "my_tool") -> dict[str, Any]:
|
||||
return {
|
||||
"kind": "InvokeFunctionTool",
|
||||
"id": "fn_action",
|
||||
"functionName": fn_name,
|
||||
"requireApproval": True,
|
||||
"output": {"result": "Local.result"},
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_id_matches_framework_pending_key(self, mock_state, mock_context) -> None:
|
||||
"""The id on the emitted ToolApprovalRequest must match the framework's pending-request key."""
|
||||
from agent_framework_declarative._workflows._declarative_base import ActionTrigger
|
||||
|
||||
_seed_state(mock_state)
|
||||
|
||||
def my_tool(x: int) -> int:
|
||||
return x
|
||||
|
||||
executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool})
|
||||
await executor.handle_action(ActionTrigger(), mock_context)
|
||||
|
||||
mock_context.request_info.assert_called_once()
|
||||
emitted_request = mock_context.request_info.call_args[0][0]
|
||||
framework_request_id = mock_context.request_info.call_args.kwargs["request_id"]
|
||||
assert isinstance(emitted_request, ToolApprovalRequest)
|
||||
assert emitted_request.request_id == framework_request_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_uses_request_payload_arguments(self, mock_state, mock_context) -> None:
|
||||
_seed_state(mock_state)
|
||||
call_log: list[int] = []
|
||||
|
||||
def my_tool(x: int) -> int:
|
||||
call_log.append(x)
|
||||
return x
|
||||
|
||||
executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool})
|
||||
|
||||
request = ToolApprovalRequest(request_id="r-1", function_name="my_tool", arguments={"x": 1})
|
||||
await executor.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context)
|
||||
|
||||
assert call_log == [1]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_pending_approvals_do_not_swap(self, mock_state, mock_context) -> None:
|
||||
"""Two pending approvals, responses delivered out of order — each invocation uses its own payload."""
|
||||
_seed_state(mock_state)
|
||||
call_log: list[int] = []
|
||||
|
||||
def my_tool(x: int) -> int:
|
||||
call_log.append(x)
|
||||
return x
|
||||
|
||||
executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool})
|
||||
|
||||
request_a = ToolApprovalRequest(request_id="r-A", function_name="my_tool", arguments={"x": 1})
|
||||
request_b = ToolApprovalRequest(request_id="r-B", function_name="my_tool", arguments={"x": 999})
|
||||
|
||||
# Deliver response for B first, then for A. Each invocation must use its own payload.
|
||||
await executor.handle_approval_response(request_b, ToolApprovalResponse(approved=True), mock_context)
|
||||
await executor.handle_approval_response(request_a, ToolApprovalResponse(approved=True), mock_context)
|
||||
|
||||
assert call_log == [999, 1]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_ignores_stale_state_at_old_approval_key(self, mock_state, mock_context) -> None:
|
||||
"""Pre-existing state at the OLD approval key is ignored — payload wins."""
|
||||
_seed_state(mock_state)
|
||||
call_log: list[int] = []
|
||||
|
||||
def my_tool(x: int) -> int:
|
||||
call_log.append(x)
|
||||
return x
|
||||
|
||||
executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool})
|
||||
|
||||
# Poison the old key shape (no longer read by the executor).
|
||||
mock_state._data["_tool_approval_state_fn_action"] = {"function_name": "other", "arguments": {"x": 999}}
|
||||
|
||||
request = ToolApprovalRequest(request_id="r-3", function_name="my_tool", arguments={"x": 7})
|
||||
await executor.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context)
|
||||
|
||||
assert call_log == [7]
|
||||
# The poison was never read or deleted by the executor.
|
||||
assert "_tool_approval_state_fn_action" in mock_state._data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fresh_executor_resume_works(self, mock_state, mock_context) -> None:
|
||||
"""Simulates checkpoint restore: a brand-new executor instance handles the approval response."""
|
||||
_seed_state(mock_state)
|
||||
call_log: list[int] = []
|
||||
|
||||
def my_tool(x: int) -> int:
|
||||
call_log.append(x)
|
||||
return x
|
||||
|
||||
# Pretend the executor that emitted the request is gone; a fresh one handles the response.
|
||||
fresh = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool})
|
||||
|
||||
request = ToolApprovalRequest(request_id="r-4", function_name="my_tool", arguments={"x": 42})
|
||||
await fresh.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context)
|
||||
|
||||
assert call_log == [42]
|
||||
mock_context.send_message.assert_called_once()
|
||||
sent = mock_context.send_message.call_args[0][0]
|
||||
assert isinstance(sent, ActionComplete)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejection_uses_request_payload_function_name(self, mock_state, mock_context) -> None:
|
||||
_seed_state(mock_state)
|
||||
|
||||
def my_tool(x: int) -> int:
|
||||
raise AssertionError("should not be called when rejected")
|
||||
|
||||
executor = InvokeFunctionToolExecutor(self._action(), tools={"my_tool": my_tool})
|
||||
|
||||
request = ToolApprovalRequest(request_id="r-5", function_name="my_tool", arguments={"x": 3})
|
||||
await executor.handle_approval_response(
|
||||
request, ToolApprovalResponse(approved=False, reason="not authorized"), mock_context
|
||||
)
|
||||
|
||||
# The rejection message references the function name from the request payload.
|
||||
local = mock_state._data[DECLARATIVE_STATE_KEY]["Local"]
|
||||
assert local["result"]["rejected"] is True
|
||||
assert local["result"]["reason"] == "not authorized"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# InvokeMcpTool: approval-binding regression
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMcpToolApprovalBinding:
|
||||
def _action(self, *, headers: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
action: dict[str, Any] = {
|
||||
"kind": "InvokeMcpTool",
|
||||
"id": "mcp_action",
|
||||
"serverUrl": "https://mcp.example/api",
|
||||
"toolName": "search",
|
||||
"requireApproval": True,
|
||||
"output": {"result": "Local.Result"},
|
||||
}
|
||||
if headers is not None:
|
||||
action["headers"] = headers
|
||||
return action
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_id_matches_framework_pending_key(self, mock_state, mock_context) -> None:
|
||||
"""The id on the emitted MCPToolApprovalRequest must match the framework's pending-request key."""
|
||||
from agent_framework_declarative._workflows._declarative_base import ActionTrigger
|
||||
|
||||
_seed_state(mock_state)
|
||||
executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=_RecordingMcpHandler())
|
||||
await executor.handle_action(ActionTrigger(), mock_context)
|
||||
|
||||
mock_context.request_info.assert_called_once()
|
||||
emitted_request = mock_context.request_info.call_args[0][0]
|
||||
framework_request_id = mock_context.request_info.call_args.kwargs["request_id"]
|
||||
assert isinstance(emitted_request, MCPToolApprovalRequest)
|
||||
assert emitted_request.request_id == framework_request_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_uses_request_payload_fields(self, mock_state, mock_context) -> None:
|
||||
_seed_state(mock_state)
|
||||
handler = _RecordingMcpHandler()
|
||||
executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=handler)
|
||||
|
||||
request = MCPToolApprovalRequest(
|
||||
request_id="r-1",
|
||||
tool_name="search",
|
||||
server_url="https://mcp.example/api",
|
||||
server_label="prod",
|
||||
arguments={"q": "x"},
|
||||
connection_name="conn-A",
|
||||
)
|
||||
await executor.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context)
|
||||
|
||||
assert handler.call_count == 1
|
||||
inv = handler.last
|
||||
assert inv is not None
|
||||
assert inv.tool_name == "search"
|
||||
assert inv.server_url == "https://mcp.example/api"
|
||||
assert inv.server_label == "prod"
|
||||
assert inv.arguments == {"q": "x"}
|
||||
assert inv.connection_name == "conn-A"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_pending_mcp_approvals_do_not_swap(self, mock_state, mock_context) -> None:
|
||||
_seed_state(mock_state)
|
||||
handler = _RecordingMcpHandler()
|
||||
executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=handler)
|
||||
|
||||
request_a = MCPToolApprovalRequest(
|
||||
request_id="r-A",
|
||||
tool_name="search",
|
||||
server_url="https://mcp.example/api",
|
||||
server_label=None,
|
||||
arguments={"q": "alpha"},
|
||||
connection_name="conn-A",
|
||||
)
|
||||
request_b = MCPToolApprovalRequest(
|
||||
request_id="r-B",
|
||||
tool_name="search",
|
||||
server_url="https://mcp.example/api",
|
||||
server_label=None,
|
||||
arguments={"q": "beta"},
|
||||
connection_name="conn-B",
|
||||
)
|
||||
|
||||
await executor.handle_approval_response(request_b, ToolApprovalResponse(approved=True), mock_context)
|
||||
await executor.handle_approval_response(request_a, ToolApprovalResponse(approved=True), mock_context)
|
||||
|
||||
assert handler.call_count == 2
|
||||
assert handler.invocations[0].arguments == {"q": "beta"}
|
||||
assert handler.invocations[0].connection_name == "conn-B"
|
||||
assert handler.invocations[1].arguments == {"q": "alpha"}
|
||||
assert handler.invocations[1].connection_name == "conn-A"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_headers_reevaluated_from_action_def_on_resume(self, mock_state, mock_context) -> None:
|
||||
"""Headers come from the action definition (re-evaluated) so secrets are not in the payload."""
|
||||
_seed_state(mock_state)
|
||||
handler = _RecordingMcpHandler()
|
||||
executor = InvokeMcpToolActionExecutor(
|
||||
self._action(headers={"Authorization": "Bearer tk"}),
|
||||
mcp_tool_handler=handler,
|
||||
)
|
||||
|
||||
request = MCPToolApprovalRequest(
|
||||
request_id="r-1",
|
||||
tool_name="search",
|
||||
server_url="https://mcp.example/api",
|
||||
server_label=None,
|
||||
arguments={"q": "x"},
|
||||
connection_name=None,
|
||||
)
|
||||
await executor.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context)
|
||||
|
||||
assert handler.last is not None
|
||||
assert handler.last.headers == {"Authorization": "Bearer tk"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_resume_ignores_stale_state_at_old_approval_key(self, mock_state, mock_context) -> None:
|
||||
_seed_state(mock_state)
|
||||
handler = _RecordingMcpHandler()
|
||||
executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=handler)
|
||||
|
||||
mock_state._data["_mcp_tool_approval_state_mcp_action"] = {"poison": True}
|
||||
|
||||
request = MCPToolApprovalRequest(
|
||||
request_id="r-1",
|
||||
tool_name="search",
|
||||
server_url="https://mcp.example/api",
|
||||
server_label=None,
|
||||
arguments={"q": "real"},
|
||||
connection_name=None,
|
||||
)
|
||||
await executor.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context)
|
||||
|
||||
assert handler.call_count == 1
|
||||
assert handler.last is not None
|
||||
assert handler.last.arguments == {"q": "real"}
|
||||
# The poison was never read or deleted by the executor.
|
||||
assert "_mcp_tool_approval_state_mcp_action" in mock_state._data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fresh_mcp_executor_resume_works(self, mock_state, mock_context) -> None:
|
||||
"""Checkpoint-restore simulation: fresh executor handles the response."""
|
||||
_seed_state(mock_state)
|
||||
handler = _RecordingMcpHandler()
|
||||
fresh = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=handler)
|
||||
|
||||
request = MCPToolApprovalRequest(
|
||||
request_id="r-1",
|
||||
tool_name="search",
|
||||
server_url="https://mcp.example/api",
|
||||
server_label=None,
|
||||
arguments={"q": "fresh"},
|
||||
connection_name=None,
|
||||
)
|
||||
await fresh.handle_approval_response(request, ToolApprovalResponse(approved=True), mock_context)
|
||||
|
||||
assert handler.call_count == 1
|
||||
assert handler.last is not None
|
||||
assert handler.last.arguments == {"q": "fresh"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_payload_carries_connection_name(self, mock_state, mock_context) -> None:
|
||||
"""When emitting the approval request, connection_name flows into MCPToolApprovalRequest."""
|
||||
from agent_framework_declarative._workflows._declarative_base import ActionTrigger
|
||||
|
||||
_seed_state(mock_state)
|
||||
action = self._action()
|
||||
action["connection"] = {"name": "conn-from-action"}
|
||||
executor = InvokeMcpToolActionExecutor(action, mcp_tool_handler=_RecordingMcpHandler())
|
||||
|
||||
await executor.handle_action(ActionTrigger(), mock_context)
|
||||
|
||||
mock_context.request_info.assert_called_once()
|
||||
request = mock_context.request_info.call_args[0][0]
|
||||
assert isinstance(request, MCPToolApprovalRequest)
|
||||
assert request.connection_name == "conn-from-action"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_payload_pins_conversation_id(self, mock_state, mock_context) -> None:
|
||||
"""Evaluated ``conversationId`` is pinned in ``metadata`` at request-emit time."""
|
||||
from agent_framework_declarative._workflows._declarative_base import ActionTrigger
|
||||
|
||||
_seed_state(mock_state)
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
state.set("Local.targetConversation", "conv-original")
|
||||
action = self._action()
|
||||
action["conversationId"] = "=Local.targetConversation"
|
||||
executor = InvokeMcpToolActionExecutor(action, mcp_tool_handler=_RecordingMcpHandler())
|
||||
|
||||
await executor.handle_action(ActionTrigger(), mock_context)
|
||||
|
||||
mock_context.request_info.assert_called_once()
|
||||
request = mock_context.request_info.call_args[0][0]
|
||||
assert isinstance(request, MCPToolApprovalRequest)
|
||||
assert request.metadata.get("conversation_id") == "conv-original"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_routes_output_to_pinned_conversation_not_mutated_state(
|
||||
self, mock_state, mock_context
|
||||
) -> None:
|
||||
"""Output appends to the conversation pinned on ``original_request``, not the
|
||||
current state evaluation."""
|
||||
_seed_state(mock_state)
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
state.set("System.conversations.conv-original.messages", [])
|
||||
state.set("System.conversations.conv-mutated.messages", [])
|
||||
state.set("Local.targetConversation", "conv-mutated")
|
||||
|
||||
handler = _RecordingMcpHandler(MCPToolResult(outputs=[Content.from_text("approved-output")]))
|
||||
action = self._action()
|
||||
action["conversationId"] = "=Local.targetConversation"
|
||||
executor = InvokeMcpToolActionExecutor(action, mcp_tool_handler=handler)
|
||||
|
||||
original_request = MCPToolApprovalRequest(
|
||||
request_id="r-1",
|
||||
tool_name="search",
|
||||
server_url="https://mcp.example/api",
|
||||
server_label=None,
|
||||
arguments={"q": "x"},
|
||||
connection_name=None,
|
||||
metadata={"conversation_id": "conv-original"},
|
||||
)
|
||||
await executor.handle_approval_response(original_request, ToolApprovalResponse(approved=True), mock_context)
|
||||
|
||||
assert len(state.get("System.conversations.conv-original.messages") or []) == 1
|
||||
assert state.get("System.conversations.conv-mutated.messages") == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_handles_legacy_request_without_new_fields(self, mock_state, mock_context) -> None:
|
||||
"""Resume tolerates payloads lacking ``connection_name`` / ``metadata`` (legacy pickle shape)."""
|
||||
|
||||
@dataclass
|
||||
class _LegacyMCPApprovalRequest:
|
||||
request_id: str
|
||||
tool_name: str
|
||||
server_url: str
|
||||
server_label: str | None
|
||||
arguments: dict[str, Any]
|
||||
header_names: list[str]
|
||||
|
||||
_seed_state(mock_state)
|
||||
handler = _RecordingMcpHandler()
|
||||
executor = InvokeMcpToolActionExecutor(self._action(), mcp_tool_handler=handler)
|
||||
|
||||
legacy_request = _LegacyMCPApprovalRequest(
|
||||
request_id="r-1",
|
||||
tool_name="search",
|
||||
server_url="https://mcp.example/api",
|
||||
server_label=None,
|
||||
arguments={"q": "x"},
|
||||
header_names=[],
|
||||
)
|
||||
await executor.handle_approval_response(
|
||||
legacy_request, # type: ignore[arg-type]
|
||||
ToolApprovalResponse(approved=True),
|
||||
mock_context,
|
||||
)
|
||||
|
||||
assert handler.call_count == 1
|
||||
assert handler.last is not None
|
||||
assert handler.last.connection_name is None
|
||||
@@ -35,14 +35,12 @@ pytestmark = pytest.mark.skipif(
|
||||
from agent_framework_declarative._workflows import ( # noqa: E402
|
||||
DECLARATIVE_STATE_KEY,
|
||||
FUNCTION_TOOL_REGISTRY_KEY,
|
||||
TOOL_APPROVAL_STATE_KEY,
|
||||
ActionComplete,
|
||||
ActionTrigger,
|
||||
DeclarativeWorkflowBuilder,
|
||||
InvokeFunctionToolExecutor,
|
||||
ToolApprovalRequest,
|
||||
ToolApprovalResponse,
|
||||
ToolApprovalState,
|
||||
ToolInvocationResult,
|
||||
WorkflowFactory,
|
||||
)
|
||||
@@ -393,21 +391,6 @@ class TestToolApprovalTypes:
|
||||
assert response.approved is False
|
||||
assert response.reason == "Not authorized"
|
||||
|
||||
def test_approval_state(self):
|
||||
"""Test creating approval state for yield/resume."""
|
||||
state = ToolApprovalState(
|
||||
function_name="delete_user",
|
||||
arguments={"user_id": "123"},
|
||||
output_messages_var="Local.messages",
|
||||
output_result_var="Local.result",
|
||||
auto_send=True,
|
||||
)
|
||||
assert state.function_name == "delete_user"
|
||||
assert state.arguments == {"user_id": "123"}
|
||||
assert state.output_messages_var == "Local.messages"
|
||||
assert state.output_result_var == "Local.result"
|
||||
assert state.auto_send is True
|
||||
|
||||
|
||||
class TestInvokeFunctionToolEdgeCases:
|
||||
"""Tests for edge cases and error handling."""
|
||||
@@ -1075,13 +1058,6 @@ class TestApprovalFlow:
|
||||
# Should NOT have sent ActionComplete (workflow yields)
|
||||
mock_context.send_message.assert_not_called()
|
||||
|
||||
# Approval state should be saved in state
|
||||
approval_key = f"{TOOL_APPROVAL_STATE_KEY}_approval_test"
|
||||
saved_state = mock_state._data[approval_key]
|
||||
assert isinstance(saved_state, ToolApprovalState)
|
||||
assert saved_state.function_name == "my_tool"
|
||||
assert saved_state.arguments == {"x": 5}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approval_response_approved(self, mock_state, mock_context):
|
||||
"""When approval response is approved, the tool should be invoked."""
|
||||
@@ -1104,17 +1080,7 @@ class TestApprovalFlow:
|
||||
|
||||
executor = InvokeFunctionToolExecutor(action_def, tools={"my_tool": my_tool})
|
||||
|
||||
# Pre-populate approval state (simulating what handle_action stores)
|
||||
approval_key = f"{TOOL_APPROVAL_STATE_KEY}_approval_approved"
|
||||
mock_state._data[approval_key] = ToolApprovalState(
|
||||
function_name="my_tool",
|
||||
arguments={"x": 7},
|
||||
output_messages_var=None,
|
||||
output_result_var="Local.result",
|
||||
auto_send=True,
|
||||
)
|
||||
|
||||
# Simulate the response
|
||||
# Simulate the response — invocation params come from original_request
|
||||
original_request = ToolApprovalRequest(
|
||||
request_id="req-123",
|
||||
function_name="my_tool",
|
||||
@@ -1124,7 +1090,7 @@ class TestApprovalFlow:
|
||||
|
||||
await executor.handle_approval_response(original_request, response, mock_context)
|
||||
|
||||
# Tool should have been called
|
||||
# Tool should have been called with the approved arguments
|
||||
assert call_log == [7]
|
||||
|
||||
# ActionComplete should have been sent
|
||||
@@ -1132,9 +1098,6 @@ class TestApprovalFlow:
|
||||
sent = mock_context.send_message.call_args[0][0]
|
||||
assert isinstance(sent, ActionComplete)
|
||||
|
||||
# Approval state should be cleaned up
|
||||
assert approval_key not in mock_state._data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approval_response_rejected(self, mock_state, mock_context):
|
||||
"""When approval response is rejected, rejection status should be stored."""
|
||||
@@ -1154,16 +1117,6 @@ class TestApprovalFlow:
|
||||
|
||||
executor = InvokeFunctionToolExecutor(action_def, tools={"my_tool": my_tool})
|
||||
|
||||
# Pre-populate approval state
|
||||
approval_key = f"{TOOL_APPROVAL_STATE_KEY}_approval_rejected"
|
||||
mock_state._data[approval_key] = ToolApprovalState(
|
||||
function_name="my_tool",
|
||||
arguments={"x": 5},
|
||||
output_messages_var=None,
|
||||
output_result_var="Local.result",
|
||||
auto_send=True,
|
||||
)
|
||||
|
||||
original_request = ToolApprovalRequest(
|
||||
request_id="req-456",
|
||||
function_name="my_tool",
|
||||
@@ -1185,36 +1138,6 @@ class TestApprovalFlow:
|
||||
assert result["reason"] == "Not authorized"
|
||||
assert result["approved"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approval_response_missing_state(self, mock_state, mock_context):
|
||||
"""When approval state is missing on resume, should log error and complete."""
|
||||
self._init_state(mock_state)
|
||||
|
||||
action_def = {
|
||||
"kind": "InvokeFunctionTool",
|
||||
"id": "missing_state_test",
|
||||
"functionName": "my_tool",
|
||||
"requireApproval": True,
|
||||
"output": {"result": "Local.result"},
|
||||
}
|
||||
|
||||
executor = InvokeFunctionToolExecutor(action_def, tools={})
|
||||
|
||||
# Don't populate approval state - simulate missing state
|
||||
original_request = ToolApprovalRequest(
|
||||
request_id="req-789",
|
||||
function_name="my_tool",
|
||||
arguments={},
|
||||
)
|
||||
response = ToolApprovalResponse(approved=True)
|
||||
|
||||
await executor.handle_approval_response(original_request, response, mock_context)
|
||||
|
||||
# Should still send ActionComplete
|
||||
mock_context.send_message.assert_called_once()
|
||||
sent = mock_context.send_message.call_args[0][0]
|
||||
assert isinstance(sent, ActionComplete)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# State registry tool lookup (lines 255-257)
|
||||
|
||||
@@ -403,7 +403,6 @@ class TestApprovalFlow:
|
||||
async def test_approval_required_emits_request_and_yields(self, mock_state, mock_context) -> None: # type: ignore[no-untyped-def]
|
||||
from agent_framework_declarative._workflows._declarative_base import ActionTrigger
|
||||
from agent_framework_declarative._workflows._executors_mcp import (
|
||||
_MCP_APPROVAL_STATE_KEY,
|
||||
InvokeMcpToolActionExecutor,
|
||||
MCPToolApprovalRequest,
|
||||
)
|
||||
@@ -439,18 +438,12 @@ class TestApprovalFlow:
|
||||
# Handler not invoked yet.
|
||||
assert handler.call_count == 0
|
||||
|
||||
# Approval state stored.
|
||||
approval_key = f"{_MCP_APPROVAL_STATE_KEY}_mcp_action"
|
||||
assert approval_key in mock_state._data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approval_response_approved_invokes_handler(self, mock_state, mock_context) -> None: # type: ignore[no-untyped-def]
|
||||
from agent_framework_declarative._workflows import ActionComplete, ToolApprovalResponse
|
||||
from agent_framework_declarative._workflows._executors_mcp import (
|
||||
_MCP_APPROVAL_STATE_KEY,
|
||||
InvokeMcpToolActionExecutor,
|
||||
MCPToolApprovalRequest,
|
||||
_MCPToolApprovalState,
|
||||
)
|
||||
|
||||
_seed_state(mock_state)
|
||||
@@ -458,24 +451,11 @@ class TestApprovalFlow:
|
||||
executor = InvokeMcpToolActionExecutor(
|
||||
_action(
|
||||
require_approval=True,
|
||||
headers={"Authorization": "Bearer tk"},
|
||||
output={"result": "Local.Result"},
|
||||
),
|
||||
mcp_tool_handler=handler,
|
||||
)
|
||||
# Pre-populate approval state.
|
||||
approval_key = f"{_MCP_APPROVAL_STATE_KEY}_mcp_action"
|
||||
mock_state._data[approval_key] = _MCPToolApprovalState(
|
||||
server_url="https://mcp.example/api",
|
||||
tool_name="search",
|
||||
server_label=None,
|
||||
arguments={"q": "x"},
|
||||
connection_name=None,
|
||||
headers_def={"Authorization": "Bearer tk"},
|
||||
auto_send=False,
|
||||
conversation_id_expr=None,
|
||||
output_messages_path=None,
|
||||
output_result_path="Local.Result",
|
||||
)
|
||||
await executor.handle_approval_response(
|
||||
MCPToolApprovalRequest(
|
||||
request_id="req-1",
|
||||
@@ -491,10 +471,12 @@ class TestApprovalFlow:
|
||||
assert handler.call_count == 1
|
||||
inv = handler.last_invocation
|
||||
assert inv is not None
|
||||
# Headers are re-evaluated from headers_def.
|
||||
# Invocation fields source from the approval request payload.
|
||||
assert inv.tool_name == "search"
|
||||
assert inv.server_url == "https://mcp.example/api"
|
||||
assert inv.arguments == {"q": "x"}
|
||||
# Headers are re-evaluated from the action definition on resume.
|
||||
assert inv.headers == {"Authorization": "Bearer tk"}
|
||||
# Approval state was cleaned up.
|
||||
assert approval_key not in mock_state._data
|
||||
# ActionComplete was sent.
|
||||
mock_context.send_message.assert_called_once()
|
||||
sent = mock_context.send_message.call_args[0][0]
|
||||
@@ -504,10 +486,8 @@ class TestApprovalFlow:
|
||||
async def test_approval_response_rejected_assigns_error(self, mock_state, mock_context) -> None: # type: ignore[no-untyped-def]
|
||||
from agent_framework_declarative._workflows import ToolApprovalResponse
|
||||
from agent_framework_declarative._workflows._executors_mcp import (
|
||||
_MCP_APPROVAL_STATE_KEY,
|
||||
InvokeMcpToolActionExecutor,
|
||||
MCPToolApprovalRequest,
|
||||
_MCPToolApprovalState,
|
||||
)
|
||||
|
||||
_seed_state(mock_state)
|
||||
@@ -519,19 +499,6 @@ class TestApprovalFlow:
|
||||
),
|
||||
mcp_tool_handler=handler,
|
||||
)
|
||||
approval_key = f"{_MCP_APPROVAL_STATE_KEY}_mcp_action"
|
||||
mock_state._data[approval_key] = _MCPToolApprovalState(
|
||||
server_url="https://mcp.example/api",
|
||||
tool_name="search",
|
||||
server_label=None,
|
||||
arguments={},
|
||||
connection_name=None,
|
||||
headers_def=None,
|
||||
auto_send=True,
|
||||
conversation_id_expr=None,
|
||||
output_messages_path=None,
|
||||
output_result_path="Local.Result",
|
||||
)
|
||||
await executor.handle_approval_response(
|
||||
MCPToolApprovalRequest(
|
||||
request_id="req-2",
|
||||
|
||||
@@ -87,6 +87,8 @@ def _prompt_for_approval(request: MCPToolApprovalRequest) -> ToolApprovalRespons
|
||||
print(f" outbound header names: {', '.join(request.header_names)}")
|
||||
else:
|
||||
print(" outbound header names: (none)")
|
||||
if request.connection_name:
|
||||
print(f" connection: {request.connection_name}")
|
||||
print("-" * 60)
|
||||
|
||||
while True:
|
||||
|
||||
Reference in New Issue
Block a user