mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
[BREAKING] Python: Add InvokeFunctionTool action for declarative workflows (#3716)
* add(declarative): Declarative workflow InvokeFunctionTool feature * Cleanup * Address PR feedback * Remove InvokeTool kind, consolidate to InvokeFunctionTool * Fix sample locations * pin azure-ai-projects to 2.0.0b3 due to breaking changes
This commit is contained in:
committed by
GitHub
Unverified
parent
f77f40b987
commit
40d2fac29c
@@ -34,7 +34,6 @@ from ._executors_agents import (
|
||||
AgentResult,
|
||||
ExternalLoopState,
|
||||
InvokeAzureAgentExecutor,
|
||||
InvokeToolExecutor,
|
||||
)
|
||||
from ._executors_basic import (
|
||||
BASIC_ACTION_EXECUTORS,
|
||||
@@ -68,6 +67,17 @@ from ._executors_external_input import (
|
||||
RequestExternalInputExecutor,
|
||||
WaitForInputExecutor,
|
||||
)
|
||||
from ._executors_tools import (
|
||||
FUNCTION_TOOL_REGISTRY_KEY,
|
||||
TOOL_ACTION_EXECUTORS,
|
||||
TOOL_APPROVAL_STATE_KEY,
|
||||
BaseToolExecutor,
|
||||
InvokeFunctionToolExecutor,
|
||||
ToolApprovalRequest,
|
||||
ToolApprovalResponse,
|
||||
ToolApprovalState,
|
||||
ToolInvocationResult,
|
||||
)
|
||||
from ._factory import DeclarativeWorkflowError, WorkflowFactory
|
||||
from ._state import WorkflowState
|
||||
|
||||
@@ -79,6 +89,9 @@ __all__ = [
|
||||
"CONTROL_FLOW_EXECUTORS",
|
||||
"DECLARATIVE_STATE_KEY",
|
||||
"EXTERNAL_INPUT_EXECUTORS",
|
||||
"FUNCTION_TOOL_REGISTRY_KEY",
|
||||
"TOOL_ACTION_EXECUTORS",
|
||||
"TOOL_APPROVAL_STATE_KEY",
|
||||
"TOOL_REGISTRY_KEY",
|
||||
"ActionComplete",
|
||||
"ActionTrigger",
|
||||
@@ -86,6 +99,7 @@ __all__ = [
|
||||
"AgentExternalInputResponse",
|
||||
"AgentResult",
|
||||
"AppendValueExecutor",
|
||||
"BaseToolExecutor",
|
||||
"BreakLoopExecutor",
|
||||
"ClearAllVariablesExecutor",
|
||||
"ConfirmationExecutor",
|
||||
@@ -107,7 +121,7 @@ __all__ = [
|
||||
"ForeachInitExecutor",
|
||||
"ForeachNextExecutor",
|
||||
"InvokeAzureAgentExecutor",
|
||||
"InvokeToolExecutor",
|
||||
"InvokeFunctionToolExecutor",
|
||||
"JoinExecutor",
|
||||
"LoopControl",
|
||||
"LoopIterationResult",
|
||||
@@ -119,6 +133,10 @@ __all__ = [
|
||||
"SetTextVariableExecutor",
|
||||
"SetValueExecutor",
|
||||
"SetVariableExecutor",
|
||||
"ToolApprovalRequest",
|
||||
"ToolApprovalResponse",
|
||||
"ToolApprovalState",
|
||||
"ToolInvocationResult",
|
||||
"WaitForInputExecutor",
|
||||
"WorkflowFactory",
|
||||
"WorkflowState",
|
||||
|
||||
+20
-3
@@ -13,6 +13,7 @@ action definitions and creates a proper workflow graph with:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import (
|
||||
@@ -38,6 +39,10 @@ from ._executors_control_flow import (
|
||||
SwitchEvaluatorExecutor,
|
||||
)
|
||||
from ._executors_external_input import EXTERNAL_INPUT_EXECUTORS
|
||||
from ._executors_tools import TOOL_ACTION_EXECUTORS, InvokeFunctionToolExecutor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Combined mapping of all action kinds to executor classes
|
||||
ALL_ACTION_EXECUTORS = {
|
||||
@@ -45,6 +50,7 @@ ALL_ACTION_EXECUTORS = {
|
||||
**CONTROL_FLOW_EXECUTORS,
|
||||
**AGENT_ACTION_EXECUTORS,
|
||||
**EXTERNAL_INPUT_EXECUTORS,
|
||||
**TOOL_ACTION_EXECUTORS,
|
||||
}
|
||||
|
||||
# Action kinds that terminate control flow (no fall-through to successor)
|
||||
@@ -78,6 +84,7 @@ ACTION_REQUIRED_FIELDS: dict[str, list[str]] = {
|
||||
"RequestHumanInput": ["variable"],
|
||||
"WaitForHumanInput": ["variable"],
|
||||
"EmitEvent": ["event"],
|
||||
"InvokeFunctionTool": ["functionName"],
|
||||
}
|
||||
|
||||
# Alternate field names that satisfy required field requirements
|
||||
@@ -118,6 +125,7 @@ class DeclarativeWorkflowBuilder:
|
||||
yaml_definition: dict[str, Any],
|
||||
workflow_id: str | None = None,
|
||||
agents: dict[str, Any] | None = None,
|
||||
tools: dict[str, Any] | None = None,
|
||||
checkpoint_storage: Any | None = None,
|
||||
validate: bool = True,
|
||||
max_iterations: int | None = None,
|
||||
@@ -128,6 +136,7 @@ class DeclarativeWorkflowBuilder:
|
||||
yaml_definition: The parsed YAML workflow definition
|
||||
workflow_id: Optional ID for the workflow (defaults to name from YAML)
|
||||
agents: Registry of agent instances by name (for InvokeAzureAgent actions)
|
||||
tools: Registry of tool/function instances by name (for InvokeFunctionTool actions)
|
||||
checkpoint_storage: Optional checkpoint storage for pause/resume support
|
||||
validate: Whether to validate the workflow definition before building (default: True)
|
||||
max_iterations: Maximum runner supersteps. Falls back to the YAML ``maxTurns``
|
||||
@@ -138,6 +147,7 @@ class DeclarativeWorkflowBuilder:
|
||||
self._executors: dict[str, Any] = {} # id -> executor
|
||||
self._action_index = 0 # Counter for generating unique IDs
|
||||
self._agents = agents or {} # Agent registry for agent executors
|
||||
self._tools = tools or {} # Tool registry for tool executors
|
||||
self._checkpoint_storage = checkpoint_storage
|
||||
self._pending_gotos: list[tuple[Any, str]] = [] # (goto_executor, target_id)
|
||||
self._validate = validate
|
||||
@@ -423,8 +433,13 @@ class DeclarativeWorkflowBuilder:
|
||||
executor_class = ALL_ACTION_EXECUTORS.get(kind)
|
||||
|
||||
if executor_class is None:
|
||||
# Unknown action type - skip with warning
|
||||
# In production, might want to log this
|
||||
# Unknown action type - log warning and skip
|
||||
logger.warning(
|
||||
"Unknown action kind '%s' encountered at index %d - action will be skipped. Available action kinds: %s",
|
||||
kind,
|
||||
self._action_index,
|
||||
list(ALL_ACTION_EXECUTORS.keys()),
|
||||
)
|
||||
return None
|
||||
|
||||
# Create the executor with ID
|
||||
@@ -437,10 +452,12 @@ class DeclarativeWorkflowBuilder:
|
||||
action_id = f"{parent_id}_{kind}_{self._action_index}" if parent_id else f"{kind}_{self._action_index}"
|
||||
self._action_index += 1
|
||||
|
||||
# Pass agents to agent-related executors
|
||||
# Pass agents/tools to specialized executors
|
||||
executor: Any
|
||||
if kind in ("InvokeAzureAgent",):
|
||||
executor = InvokeAzureAgentExecutor(action_def, id=action_id, agents=self._agents)
|
||||
elif kind == "InvokeFunctionTool":
|
||||
executor = InvokeFunctionToolExecutor(action_def, id=action_id, tools=self._tools)
|
||||
else:
|
||||
executor = executor_class(action_def, id=action_id)
|
||||
self._executors[action_id] = executor
|
||||
|
||||
-68
@@ -1019,75 +1019,7 @@ class InvokeAzureAgentExecutor(DeclarativeActionExecutor):
|
||||
await ctx.send_message(ActionComplete())
|
||||
|
||||
|
||||
class InvokeToolExecutor(DeclarativeActionExecutor):
|
||||
"""Executor that invokes a registered tool/function.
|
||||
|
||||
Tools are simpler than agents - they take input, perform an action,
|
||||
and return a result synchronously (or with a simple async call).
|
||||
"""
|
||||
|
||||
@handler
|
||||
async def handle_action(
|
||||
self,
|
||||
trigger: Any,
|
||||
ctx: WorkflowContext[ActionComplete],
|
||||
) -> None:
|
||||
"""Handle the tool invocation."""
|
||||
state = await self._ensure_state_initialized(ctx, trigger)
|
||||
|
||||
tool_name = self._action_def.get("tool") or self._action_def.get("toolName", "")
|
||||
input_expr = self._action_def.get("input")
|
||||
output_property = self._action_def.get("output", {}).get("property") or self._action_def.get("resultProperty")
|
||||
parameters = self._action_def.get("parameters", {})
|
||||
|
||||
# Get tools registry
|
||||
try:
|
||||
tool_registry: dict[str, Any] | None = ctx.state.get(TOOL_REGISTRY_KEY)
|
||||
except KeyError:
|
||||
tool_registry = {}
|
||||
|
||||
tool: Any = tool_registry.get(tool_name) if tool_registry else None
|
||||
|
||||
if tool is None:
|
||||
error_msg = f"Tool '{tool_name}' not found in registry"
|
||||
if output_property:
|
||||
state.set(output_property, {"error": error_msg})
|
||||
await ctx.send_message(ActionComplete())
|
||||
return
|
||||
|
||||
# Build parameters
|
||||
params: dict[str, Any] = {}
|
||||
for param_name, param_expression in parameters.items():
|
||||
params[param_name] = state.eval_if_expression(param_expression)
|
||||
|
||||
# Add main input if specified
|
||||
if input_expr:
|
||||
params["input"] = state.eval_if_expression(input_expr)
|
||||
|
||||
try:
|
||||
# Invoke the tool
|
||||
if callable(tool):
|
||||
from inspect import isawaitable
|
||||
|
||||
result = tool(**params)
|
||||
if isawaitable(result):
|
||||
result = await result
|
||||
|
||||
# Store result
|
||||
if output_property:
|
||||
state.set(output_property, result)
|
||||
|
||||
except Exception as e:
|
||||
if output_property:
|
||||
state.set(output_property, {"error": str(e)})
|
||||
await ctx.send_message(ActionComplete())
|
||||
return
|
||||
|
||||
await ctx.send_message(ActionComplete())
|
||||
|
||||
|
||||
# Mapping of agent action kinds to executor classes
|
||||
AGENT_ACTION_EXECUTORS: dict[str, type[DeclarativeActionExecutor]] = {
|
||||
"InvokeAzureAgent": InvokeAzureAgentExecutor,
|
||||
"InvokeTool": InvokeToolExecutor,
|
||||
}
|
||||
|
||||
+711
@@ -0,0 +1,711 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tool invocation executors for declarative workflows.
|
||||
|
||||
Provides base abstractions and concrete executors for invoking various tool types
|
||||
(functions, APIs, MCP servers, etc.) with support for approval flows and structured output.
|
||||
|
||||
This module is designed for extensibility:
|
||||
- BaseToolExecutor provides common patterns (registry lookup, approval flow, output formatting)
|
||||
- Concrete executors (InvokeFunctionToolExecutor) implement tool-specific invocation logic
|
||||
- New tool types can be added by subclassing BaseToolExecutor
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from inspect import isawaitable
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import (
|
||||
Content,
|
||||
Message,
|
||||
WorkflowContext,
|
||||
handler,
|
||||
response_handler,
|
||||
)
|
||||
|
||||
from ._declarative_base import (
|
||||
ActionComplete,
|
||||
DeclarativeActionExecutor,
|
||||
DeclarativeWorkflowState,
|
||||
)
|
||||
from ._executors_agents import TOOL_REGISTRY_KEY
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Registry key for function tools in State - reuse existing key so functions registered
|
||||
# at runtime are discoverable by both agent-based and function-based tool executors.
|
||||
FUNCTION_TOOL_REGISTRY_KEY = TOOL_REGISTRY_KEY
|
||||
|
||||
# State key prefix for storing approval state during yield/resume.
|
||||
# The executor's ID is appended to create a per-executor key.
|
||||
TOOL_APPROVAL_STATE_KEY = "_tool_approval_state"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Request/Response Types for Approval Flow
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolApprovalRequest:
|
||||
"""Request for approval before invoking a tool.
|
||||
|
||||
Emitted when requireApproval=true, signaling that the workflow should yield
|
||||
and wait for user approval before invoking the tool.
|
||||
|
||||
This follows the same pattern as AgentExternalInputRequest from _executors_agents.py,
|
||||
allowing consistent handling of human-in-loop scenarios across agents and tools.
|
||||
|
||||
Attributes:
|
||||
request_id: Unique identifier for this approval request.
|
||||
function_name: Evaluated function name to be invoked.
|
||||
arguments: Evaluated arguments to be passed to the function.
|
||||
"""
|
||||
|
||||
request_id: str
|
||||
function_name: str
|
||||
arguments: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolApprovalResponse:
|
||||
"""Response to a ToolApprovalRequest.
|
||||
|
||||
Provided by the caller to approve or reject tool invocation.
|
||||
|
||||
Attributes:
|
||||
approved: Whether the tool invocation was approved.
|
||||
reason: Optional reason for rejection.
|
||||
"""
|
||||
|
||||
approved: bool
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# State Types for Approval Flow
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolApprovalState:
|
||||
"""State saved during approval yield for resumption.
|
||||
|
||||
Stored in State under a per-executor key when requireApproval=true.
|
||||
Retrieved by handle_approval_response() to continue execution.
|
||||
"""
|
||||
|
||||
function_name: str
|
||||
arguments: dict[str, Any]
|
||||
output_messages_var: str | None
|
||||
output_result_var: str | None
|
||||
auto_send: bool
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Result Types
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolInvocationResult:
|
||||
"""Result from a tool invocation.
|
||||
|
||||
Attributes:
|
||||
success: Whether the invocation succeeded.
|
||||
result: The return value from the tool (if successful).
|
||||
error: Error message (if failed).
|
||||
messages: Message list format for conversation history.
|
||||
rejected: Whether the invocation was rejected during approval.
|
||||
rejection_reason: Reason for rejection.
|
||||
"""
|
||||
|
||||
success: bool
|
||||
result: Any = None
|
||||
error: str | None = None
|
||||
messages: list[Message] = field(default_factory=list)
|
||||
rejected: bool = False
|
||||
rejection_reason: str | None = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Helper Functions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _normalize_variable_path(variable: str) -> str:
|
||||
"""Normalize variable names to ensure they have a scope prefix.
|
||||
|
||||
Args:
|
||||
variable: Variable name like 'Local.X' or 'weatherResult'
|
||||
|
||||
Returns:
|
||||
The variable path with a scope prefix (defaults to Local if none provided)
|
||||
"""
|
||||
if variable.startswith(("Local.", "System.", "Workflow.", "Agent.", "Conversation.")):
|
||||
return variable
|
||||
if "." in variable:
|
||||
return variable
|
||||
return "Local." + variable
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Base Tool Executor (Abstract)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class BaseToolExecutor(DeclarativeActionExecutor):
|
||||
"""Base class for tool invocation executors.
|
||||
|
||||
Provides common functionality for all tool-like executors:
|
||||
- Tool registry lookup (State + WorkflowFactory registration)
|
||||
- Approval flow (request_info pattern with yield/resume)
|
||||
- Output formatting (messages as Message list + result variable)
|
||||
- Error handling (stores error in output, doesn't raise)
|
||||
|
||||
Subclasses must implement:
|
||||
- _invoke_tool(): Perform the actual tool invocation
|
||||
|
||||
YAML Schema (common fields):
|
||||
kind: <ToolKind>
|
||||
id: unique_id
|
||||
functionName: function_to_call # required, supports =expression syntax
|
||||
requireApproval: true # optional, default=false
|
||||
arguments: # optional dictionary
|
||||
param1: value1
|
||||
param2: =Local.dynamicValue
|
||||
output:
|
||||
messages: Local.toolCallMessages # Message list
|
||||
result: Local.toolResult
|
||||
autoSend: true # optional, default=true
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
action_def: dict[str, Any],
|
||||
*,
|
||||
id: str | None = None,
|
||||
tools: dict[str, Any] | None = None,
|
||||
):
|
||||
"""Initialize the tool executor.
|
||||
|
||||
Args:
|
||||
action_def: The action definition from YAML
|
||||
id: Optional executor ID
|
||||
tools: Registry of tool instances by name (from WorkflowFactory)
|
||||
"""
|
||||
super().__init__(action_def, id=id)
|
||||
self._tools = tools or {}
|
||||
|
||||
@abstractmethod
|
||||
async def _invoke_tool(
|
||||
self,
|
||||
tool: Any,
|
||||
function_name: str,
|
||||
arguments: dict[str, Any],
|
||||
state: DeclarativeWorkflowState,
|
||||
) -> Any:
|
||||
"""Invoke the tool with the given arguments.
|
||||
|
||||
Args:
|
||||
tool: The tool instance to invoke
|
||||
function_name: Function/method name to call
|
||||
arguments: Arguments to pass
|
||||
state: Workflow state
|
||||
|
||||
Returns:
|
||||
The result from the tool invocation
|
||||
|
||||
Raises:
|
||||
Any exception from the tool invocation
|
||||
"""
|
||||
pass
|
||||
|
||||
def _get_tool(
|
||||
self,
|
||||
function_name: str,
|
||||
ctx: WorkflowContext[Any, Any],
|
||||
) -> Any | None:
|
||||
"""Get tool from registry.
|
||||
|
||||
Checks both WorkflowFactory registry (self._tools) and State registry.
|
||||
|
||||
Args:
|
||||
function_name: Name of the function
|
||||
ctx: Workflow context
|
||||
|
||||
Returns:
|
||||
The tool/function, or None if not found
|
||||
"""
|
||||
# Check WorkflowFactory registry first (passed in constructor)
|
||||
tool = self._tools.get(function_name)
|
||||
if tool is not None:
|
||||
return tool
|
||||
|
||||
# Check State registry (for runtime registration)
|
||||
try:
|
||||
tool_registry: dict[str, Any] | None = ctx.state.get(FUNCTION_TOOL_REGISTRY_KEY)
|
||||
if tool_registry:
|
||||
return tool_registry.get(function_name)
|
||||
except KeyError:
|
||||
logger.debug(
|
||||
"%s: tool registry key '%s' not found in state "
|
||||
"(this is normal if tools are only registered via WorkflowFactory)",
|
||||
self.__class__.__name__,
|
||||
FUNCTION_TOOL_REGISTRY_KEY,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _get_output_config(self) -> tuple[str | None, str | None, bool]:
|
||||
"""Parse output configuration from action definition.
|
||||
|
||||
Returns:
|
||||
Tuple of (messages_var, result_var, auto_send)
|
||||
"""
|
||||
output_config = self._action_def.get("output", {})
|
||||
|
||||
if not isinstance(output_config, dict):
|
||||
return None, None, True
|
||||
|
||||
messages_var = output_config.get("messages")
|
||||
result_var = output_config.get("result")
|
||||
auto_send = bool(output_config.get("autoSend", True))
|
||||
|
||||
return (
|
||||
str(messages_var) if messages_var else None,
|
||||
str(result_var) if result_var else None,
|
||||
auto_send,
|
||||
)
|
||||
|
||||
def _store_result(
|
||||
self,
|
||||
result: ToolInvocationResult,
|
||||
state: DeclarativeWorkflowState,
|
||||
messages_var: str | None,
|
||||
result_var: str | None,
|
||||
) -> None:
|
||||
"""Store tool invocation result in workflow state.
|
||||
|
||||
Args:
|
||||
result: The tool invocation result
|
||||
state: Workflow state
|
||||
messages_var: Variable path for messages output
|
||||
result_var: Variable path for result output
|
||||
"""
|
||||
# Store messages if variable specified
|
||||
if messages_var:
|
||||
path = _normalize_variable_path(messages_var)
|
||||
state.set(path, result.messages)
|
||||
|
||||
# Store result if variable specified
|
||||
if result_var:
|
||||
path = _normalize_variable_path(result_var)
|
||||
if result.rejected:
|
||||
state.set(
|
||||
path,
|
||||
{
|
||||
"approved": False,
|
||||
"rejected": True,
|
||||
"reason": result.rejection_reason,
|
||||
},
|
||||
)
|
||||
elif result.success:
|
||||
state.set(path, result.result)
|
||||
else:
|
||||
state.set(
|
||||
path,
|
||||
{
|
||||
"error": result.error,
|
||||
},
|
||||
)
|
||||
|
||||
async def _format_messages(
|
||||
self,
|
||||
function_name: str,
|
||||
arguments: dict[str, Any],
|
||||
result: Any,
|
||||
) -> list[Message]:
|
||||
"""Format tool invocation as Message list.
|
||||
|
||||
Creates tool call + tool result message pair for conversation history,
|
||||
following the same format as agent tool calls.
|
||||
|
||||
Args:
|
||||
function_name: Function name invoked
|
||||
arguments: Arguments passed
|
||||
result: Result from invocation
|
||||
|
||||
Returns:
|
||||
List of Message objects [tool_call_message, tool_result_message]
|
||||
"""
|
||||
call_id = str(uuid.uuid4())
|
||||
|
||||
# Safely serialize arguments to JSON
|
||||
try:
|
||||
arguments_str = json.dumps(arguments) if isinstance(arguments, dict) else str(arguments)
|
||||
except (TypeError, ValueError) as e:
|
||||
logger.warning(f"Failed to serialize arguments to JSON: {e}")
|
||||
arguments_str = str(arguments)
|
||||
|
||||
# Tool call message (from assistant)
|
||||
tool_call_content = Content.from_function_call(
|
||||
call_id=call_id,
|
||||
name=function_name,
|
||||
arguments=arguments_str,
|
||||
)
|
||||
tool_call_message = Message(
|
||||
role="assistant",
|
||||
contents=[tool_call_content],
|
||||
)
|
||||
|
||||
# Safely serialize result to JSON
|
||||
try:
|
||||
result_str = json.dumps(result) if not isinstance(result, str) else result
|
||||
except (TypeError, ValueError) as e:
|
||||
logger.warning(f"Failed to serialize result to JSON: {e}")
|
||||
result_str = str(result)
|
||||
|
||||
tool_result_content = Content.from_function_result(
|
||||
call_id=call_id,
|
||||
result=result_str,
|
||||
)
|
||||
tool_result_message = Message(
|
||||
role="tool",
|
||||
contents=[tool_result_content],
|
||||
)
|
||||
|
||||
return [tool_call_message, tool_result_message]
|
||||
|
||||
async def _execute_tool_invocation(
|
||||
self,
|
||||
function_name: str,
|
||||
arguments: dict[str, Any],
|
||||
state: DeclarativeWorkflowState,
|
||||
ctx: WorkflowContext[Any, Any],
|
||||
) -> ToolInvocationResult:
|
||||
"""Execute the tool invocation.
|
||||
|
||||
Args:
|
||||
function_name: Function to invoke
|
||||
arguments: Arguments to pass
|
||||
state: Workflow state
|
||||
ctx: Workflow context
|
||||
|
||||
Returns:
|
||||
ToolInvocationResult with outcome
|
||||
"""
|
||||
# Get tool from registry
|
||||
tool = self._get_tool(function_name, ctx)
|
||||
if tool is None:
|
||||
error_msg = f"Function '{function_name}' not found in registry"
|
||||
logger.error(f"{self.__class__.__name__}: {error_msg}")
|
||||
return ToolInvocationResult(
|
||||
success=False,
|
||||
error=error_msg,
|
||||
)
|
||||
|
||||
try:
|
||||
# Invoke the tool (subclass implements this)
|
||||
result_value = await self._invoke_tool(
|
||||
tool=tool,
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
state=state,
|
||||
)
|
||||
|
||||
# Format as messages for conversation history
|
||||
messages = await self._format_messages(
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
result=result_value,
|
||||
)
|
||||
|
||||
return ToolInvocationResult(
|
||||
success=True,
|
||||
result=result_value,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"%s: error invoking function '%s': %s: %s",
|
||||
self.__class__.__name__,
|
||||
function_name,
|
||||
type(e).__name__,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
return ToolInvocationResult(
|
||||
success=False,
|
||||
error=f"{type(e).__name__}: {e}",
|
||||
)
|
||||
|
||||
@handler
|
||||
async def handle_action(
|
||||
self,
|
||||
trigger: Any,
|
||||
ctx: WorkflowContext[ActionComplete, str],
|
||||
) -> None:
|
||||
"""Handle the tool invocation with optional approval flow.
|
||||
|
||||
When requireApproval=true:
|
||||
1. Saves invocation state to State (keyed by executor ID)
|
||||
2. Emits ToolApprovalRequest via ctx.request_info()
|
||||
3. Workflow yields (returns without ActionComplete)
|
||||
4. Resumes in handle_approval_response() when user responds
|
||||
"""
|
||||
state = await self._ensure_state_initialized(ctx, trigger)
|
||||
|
||||
# Parse output configuration early so we can store errors
|
||||
messages_var, result_var, auto_send = self._get_output_config()
|
||||
|
||||
# Get and evaluate function name (required)
|
||||
function_name_expr = self._action_def.get("functionName")
|
||||
if not function_name_expr:
|
||||
error_msg = f"Action '{self.id}' is missing required 'functionName' field"
|
||||
logger.error(f"{self.__class__.__name__}: {error_msg}")
|
||||
if result_var:
|
||||
state.set(_normalize_variable_path(result_var), {"error": error_msg})
|
||||
await ctx.send_message(ActionComplete())
|
||||
return
|
||||
|
||||
function_name = state.eval_if_expression(function_name_expr)
|
||||
if not function_name:
|
||||
error_msg = f"Action '{self.id}': functionName expression evaluated to empty"
|
||||
logger.error(f"{self.__class__.__name__}: {error_msg}")
|
||||
if result_var:
|
||||
state.set(_normalize_variable_path(result_var), {"error": error_msg})
|
||||
await ctx.send_message(ActionComplete())
|
||||
return
|
||||
function_name = str(function_name)
|
||||
|
||||
# Evaluate arguments
|
||||
arguments_def = self._action_def.get("arguments", {})
|
||||
arguments: dict[str, Any] = {}
|
||||
if arguments_def is not None and not isinstance(arguments_def, dict):
|
||||
logger.warning(
|
||||
"%s: 'arguments' must be a dictionary, got %s - ignoring",
|
||||
self.__class__.__name__,
|
||||
type(arguments_def).__name__,
|
||||
)
|
||||
elif isinstance(arguments_def, dict):
|
||||
for key, value in arguments_def.items():
|
||||
arguments[key] = state.eval_if_expression(value)
|
||||
|
||||
# Check if approval is required
|
||||
require_approval = self._action_def.get("requireApproval", False)
|
||||
|
||||
if require_approval:
|
||||
# Save state for resumption (keyed by executor ID to avoid collisions)
|
||||
approval_state = ToolApprovalState(
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
output_messages_var=messages_var,
|
||||
output_result_var=result_var,
|
||||
auto_send=auto_send,
|
||||
)
|
||||
approval_key = f"{TOOL_APPROVAL_STATE_KEY}_{self.id}"
|
||||
ctx.state.set(approval_key, approval_state)
|
||||
|
||||
# Emit approval request - workflow yields here
|
||||
request = ToolApprovalRequest(
|
||||
request_id=str(uuid.uuid4()),
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
)
|
||||
logger.info(f"{self.__class__.__name__}: requesting approval for '{function_name}'")
|
||||
await ctx.request_info(request, ToolApprovalResponse)
|
||||
# Workflow yields - will resume in handle_approval_response
|
||||
return
|
||||
|
||||
# No approval required - invoke directly
|
||||
result = await self._execute_tool_invocation(
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
state=state,
|
||||
ctx=ctx,
|
||||
)
|
||||
|
||||
self._store_result(result, state, messages_var, result_var)
|
||||
if auto_send and result.success and result.result is not None:
|
||||
await ctx.yield_output(str(result.result))
|
||||
await ctx.send_message(ActionComplete())
|
||||
|
||||
@response_handler
|
||||
async def handle_approval_response(
|
||||
self,
|
||||
original_request: ToolApprovalRequest,
|
||||
response: ToolApprovalResponse,
|
||||
ctx: WorkflowContext[ActionComplete, str],
|
||||
) -> None:
|
||||
"""Handle response to a ToolApprovalRequest.
|
||||
|
||||
Called when the workflow resumes after yielding for approval.
|
||||
Either executes the tool (if approved) or stores rejection status.
|
||||
"""
|
||||
state = self._get_state(ctx.state)
|
||||
approval_key = f"{TOOL_APPROVAL_STATE_KEY}_{self.id}"
|
||||
|
||||
# Retrieve saved invocation state
|
||||
try:
|
||||
approval_state: ToolApprovalState = ctx.state.get(approval_key)
|
||||
except KeyError:
|
||||
error_msg = "Approval state not found, cannot resume tool invocation"
|
||||
logger.error(f"{self.__class__.__name__}: {error_msg}")
|
||||
# Try to store error - get output config from action def as fallback
|
||||
_, result_var, _ = self._get_output_config()
|
||||
if result_var and state:
|
||||
state.set(_normalize_variable_path(result_var), {"error": error_msg})
|
||||
await ctx.send_message(ActionComplete())
|
||||
return
|
||||
|
||||
# Clean up approval state
|
||||
try:
|
||||
ctx.state.delete(approval_key)
|
||||
except KeyError:
|
||||
logger.warning(f"{self.__class__.__name__}: approval state already deleted")
|
||||
|
||||
function_name = approval_state.function_name
|
||||
arguments = approval_state.arguments
|
||||
messages_var = approval_state.output_messages_var
|
||||
result_var = approval_state.output_result_var
|
||||
auto_send = approval_state.auto_send
|
||||
|
||||
# Check if approved
|
||||
if not response.approved:
|
||||
logger.info(f"{self.__class__.__name__}: tool invocation rejected: {response.reason}")
|
||||
|
||||
# Store rejection status (don't raise error)
|
||||
result = ToolInvocationResult(
|
||||
success=False,
|
||||
rejected=True,
|
||||
rejection_reason=response.reason,
|
||||
messages=[
|
||||
Message(
|
||||
role="assistant",
|
||||
text=f"Function '{function_name}' was rejected: {response.reason or 'No reason provided'}",
|
||||
)
|
||||
],
|
||||
)
|
||||
self._store_result(result, state, messages_var, result_var)
|
||||
await ctx.send_message(ActionComplete())
|
||||
return
|
||||
|
||||
# Approved - execute the invocation
|
||||
result = await self._execute_tool_invocation(
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
state=state,
|
||||
ctx=ctx,
|
||||
)
|
||||
|
||||
self._store_result(result, state, messages_var, result_var)
|
||||
if auto_send and result.success and result.result is not None:
|
||||
await ctx.yield_output(str(result.result))
|
||||
await ctx.send_message(ActionComplete())
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Function Tool Executor (Concrete)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class InvokeFunctionToolExecutor(BaseToolExecutor):
|
||||
"""Executor that invokes a Python function as a tool.
|
||||
|
||||
This executor supports invoking registered Python functions with:
|
||||
- Expression evaluation for functionName and arguments
|
||||
- Optional approval flow (yield/resume pattern)
|
||||
- Async function support
|
||||
- Message list output for conversation history
|
||||
|
||||
YAML Schema:
|
||||
kind: InvokeFunctionTool
|
||||
id: invoke_function_example
|
||||
functionName: get_weather # required, supports =expression syntax
|
||||
requireApproval: true # optional, default=false
|
||||
arguments: # optional dictionary
|
||||
location: =Local.location
|
||||
unit: F
|
||||
output:
|
||||
messages: Local.weatherToolCallItems # Message list
|
||||
result: Local.WeatherInfo
|
||||
autoSend: true # optional, default=true
|
||||
|
||||
Tool Registration:
|
||||
Tools can be registered via:
|
||||
1. WorkflowFactory.register_tool("name", func) - preferred
|
||||
2. Setting FUNCTION_TOOL_REGISTRY_KEY in State at runtime
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
from agent_framework_declarative import WorkflowFactory
|
||||
|
||||
|
||||
def get_weather(location: str, unit: str = "F") -> dict:
|
||||
return {"temp": 72, "unit": unit, "location": location}
|
||||
|
||||
|
||||
async def fetch_data(url: str) -> dict:
|
||||
# async function example
|
||||
return {"data": "..."}
|
||||
|
||||
|
||||
factory = (
|
||||
WorkflowFactory().register_tool("get_weather", get_weather).register_tool("fetch_data", fetch_data)
|
||||
)
|
||||
|
||||
workflow = factory.create_workflow_from_yaml_path("workflow.yaml")
|
||||
"""
|
||||
|
||||
async def _invoke_tool(
|
||||
self,
|
||||
tool: Any,
|
||||
function_name: str,
|
||||
arguments: dict[str, Any],
|
||||
state: DeclarativeWorkflowState,
|
||||
) -> Any:
|
||||
"""Invoke the function tool.
|
||||
|
||||
Supports:
|
||||
- Direct callable functions
|
||||
- Async functions (via inspect.isawaitable)
|
||||
|
||||
Args:
|
||||
tool: The tool/function to invoke
|
||||
function_name: Name of the function (for error messages)
|
||||
arguments: Arguments to pass to the function
|
||||
state: Workflow state (not used for function tools)
|
||||
|
||||
Returns:
|
||||
The result from the function invocation
|
||||
|
||||
Raises:
|
||||
ValueError: If the tool is not callable
|
||||
"""
|
||||
if not callable(tool):
|
||||
raise ValueError(f"Function '{function_name}' is not callable")
|
||||
|
||||
# Invoke the function
|
||||
result = tool(**arguments)
|
||||
|
||||
# Handle async functions
|
||||
if isawaitable(result):
|
||||
result = await result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Executor Registry Export
|
||||
# ============================================================================
|
||||
|
||||
TOOL_ACTION_EXECUTORS: dict[str, type[DeclarativeActionExecutor]] = {
|
||||
"InvokeFunctionTool": InvokeFunctionToolExecutor,
|
||||
}
|
||||
@@ -141,6 +141,7 @@ class WorkflowFactory:
|
||||
self._agent_factory = agent_factory or AgentFactory(env_file_path=env_file)
|
||||
self._agents: dict[str, SupportsAgentRun | AgentExecutor] = dict(agents) if agents else {}
|
||||
self._bindings: dict[str, Any] = dict(bindings) if bindings else {}
|
||||
self._tools: dict[str, Any] = {} # Tool registry for InvokeFunctionTool actions
|
||||
self._checkpoint_storage = checkpoint_storage
|
||||
self._max_iterations = max_iterations
|
||||
|
||||
@@ -377,12 +378,13 @@ class WorkflowFactory:
|
||||
if description:
|
||||
normalized_def["description"] = description
|
||||
|
||||
# Build the graph-based workflow, passing agents for InvokeAzureAgent executors
|
||||
# Build the graph-based workflow, passing agents and tools for specialized executors
|
||||
try:
|
||||
graph_builder = DeclarativeWorkflowBuilder(
|
||||
normalized_def,
|
||||
workflow_id=name,
|
||||
agents=agents,
|
||||
tools=self._tools,
|
||||
checkpoint_storage=self._checkpoint_storage,
|
||||
max_iterations=self._max_iterations,
|
||||
)
|
||||
@@ -390,9 +392,10 @@ class WorkflowFactory:
|
||||
except ValueError as e:
|
||||
raise DeclarativeWorkflowError(f"Failed to build graph-based workflow: {e}") from e
|
||||
|
||||
# Store agents and bindings for reference (executors already have them)
|
||||
# Store agents, bindings, and tools for reference (executors already have them)
|
||||
workflow._declarative_agents = agents # type: ignore[attr-defined]
|
||||
workflow._declarative_bindings = self._bindings # type: ignore[attr-defined]
|
||||
workflow._declarative_tools = self._tools # type: ignore[attr-defined]
|
||||
|
||||
# Store input schema if defined in workflow definition
|
||||
# This allows DevUI to generate proper input forms
|
||||
@@ -598,9 +601,66 @@ class WorkflowFactory:
|
||||
|
||||
workflow = factory.create_workflow_from_yaml_path("workflow.yaml")
|
||||
"""
|
||||
if not callable(func):
|
||||
raise TypeError(f"Expected a callable for binding '{name}', got {type(func).__name__}")
|
||||
self._bindings[name] = func
|
||||
return self
|
||||
|
||||
def register_tool(self, name: str, func: Any) -> WorkflowFactory:
|
||||
"""Register a function with the factory for use in InvokeFunctionTool actions.
|
||||
|
||||
Registered functions are available to InvokeFunctionTool actions by name via the functionName field.
|
||||
This method supports fluent chaining.
|
||||
|
||||
Args:
|
||||
name: The name to register the function under. Must match the functionName
|
||||
referenced in InvokeFunctionTool actions.
|
||||
func: The function to register (can be sync or async).
|
||||
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
from agent_framework_declarative import WorkflowFactory
|
||||
|
||||
|
||||
def get_weather(location: str, unit: str = "F") -> dict:
|
||||
return {"temp": 72, "unit": unit, "location": location}
|
||||
|
||||
|
||||
async def fetch_data(url: str) -> dict:
|
||||
# Async function example
|
||||
return {"data": "..."}
|
||||
|
||||
|
||||
# Register functions for use in InvokeFunctionTool workflow actions
|
||||
factory = (
|
||||
WorkflowFactory().register_tool("get_weather", get_weather).register_tool("fetch_data", fetch_data)
|
||||
)
|
||||
|
||||
workflow = factory.create_workflow_from_yaml_path("workflow.yaml")
|
||||
|
||||
The workflow YAML can then reference these tools:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
actions:
|
||||
- kind: InvokeFunctionTool
|
||||
id: call_weather
|
||||
functionName: get_weather
|
||||
arguments:
|
||||
location: =Local.city
|
||||
unit: F
|
||||
output:
|
||||
result: Local.weatherData
|
||||
"""
|
||||
if not callable(func):
|
||||
raise TypeError(f"Expected a callable for tool '{name}', got {type(func).__name__}")
|
||||
self._tools[name] = func
|
||||
return self
|
||||
|
||||
def _convert_inputs_to_json_schema(self, inputs_def: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Convert a declarative inputs definition to JSON Schema.
|
||||
|
||||
|
||||
@@ -961,6 +961,423 @@ tools:
|
||||
assert mcp_tool.get("project_connection_id") == "my-oauth-connection"
|
||||
|
||||
|
||||
class TestAgentFactoryFilePath:
|
||||
"""Tests for AgentFactory file path operations."""
|
||||
|
||||
def test_create_agent_from_yaml_path_file_not_found(self, tmp_path):
|
||||
"""Test that nonexistent file raises DeclarativeLoaderError."""
|
||||
from agent_framework_declarative import AgentFactory
|
||||
from agent_framework_declarative._loader import DeclarativeLoaderError
|
||||
|
||||
factory = AgentFactory()
|
||||
with pytest.raises(DeclarativeLoaderError, match="YAML file not found"):
|
||||
factory.create_agent_from_yaml_path(tmp_path / "nonexistent.yaml")
|
||||
|
||||
def test_create_agent_from_yaml_path_with_string_path(self, tmp_path):
|
||||
"""Test create_agent_from_yaml_path accepts string path."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from agent_framework_declarative import AgentFactory
|
||||
|
||||
yaml_file = tmp_path / "agent.yaml"
|
||||
yaml_file.write_text("""
|
||||
kind: Prompt
|
||||
name: FileAgent
|
||||
instructions: Test agent from file
|
||||
""")
|
||||
|
||||
mock_client = MagicMock()
|
||||
factory = AgentFactory(client=mock_client)
|
||||
agent = factory.create_agent_from_yaml_path(str(yaml_file))
|
||||
|
||||
assert agent.name == "FileAgent"
|
||||
|
||||
def test_create_agent_from_yaml_path_with_path_object(self, tmp_path):
|
||||
"""Test create_agent_from_yaml_path accepts Path object."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from agent_framework_declarative import AgentFactory
|
||||
|
||||
yaml_file = tmp_path / "agent.yaml"
|
||||
yaml_file.write_text("""
|
||||
kind: Prompt
|
||||
name: PathAgent
|
||||
instructions: Test agent from Path
|
||||
""")
|
||||
|
||||
mock_client = MagicMock()
|
||||
factory = AgentFactory(client=mock_client)
|
||||
agent = factory.create_agent_from_yaml_path(yaml_file)
|
||||
|
||||
assert agent.name == "PathAgent"
|
||||
|
||||
|
||||
class TestAgentFactoryAsyncMethods:
|
||||
"""Tests for AgentFactory async methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_from_yaml_path_async_file_not_found(self, tmp_path):
|
||||
"""Test async version raises DeclarativeLoaderError for nonexistent file."""
|
||||
from agent_framework_declarative import AgentFactory
|
||||
from agent_framework_declarative._loader import DeclarativeLoaderError
|
||||
|
||||
factory = AgentFactory()
|
||||
with pytest.raises(DeclarativeLoaderError, match="YAML file not found"):
|
||||
await factory.create_agent_from_yaml_path_async(tmp_path / "nonexistent.yaml")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_from_yaml_async_with_client(self):
|
||||
"""Test async creation with pre-configured client."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from agent_framework_declarative import AgentFactory
|
||||
|
||||
yaml_content = """
|
||||
kind: Prompt
|
||||
name: AsyncAgent
|
||||
instructions: Test async agent
|
||||
"""
|
||||
|
||||
mock_client = MagicMock()
|
||||
factory = AgentFactory(client=mock_client)
|
||||
agent = await factory.create_agent_from_yaml_async(yaml_content)
|
||||
|
||||
assert agent.name == "AsyncAgent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_from_dict_async_with_client(self):
|
||||
"""Test async dict creation with pre-configured client."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from agent_framework_declarative import AgentFactory
|
||||
|
||||
agent_def = {
|
||||
"kind": "Prompt",
|
||||
"name": "AsyncDictAgent",
|
||||
"instructions": "Test async dict agent",
|
||||
}
|
||||
|
||||
mock_client = MagicMock()
|
||||
factory = AgentFactory(client=mock_client)
|
||||
agent = await factory.create_agent_from_dict_async(agent_def)
|
||||
|
||||
assert agent.name == "AsyncDictAgent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_from_dict_async_invalid_kind_raises(self):
|
||||
"""Test that async version also raises for non-PromptAgent."""
|
||||
from agent_framework_declarative import AgentFactory
|
||||
from agent_framework_declarative._loader import DeclarativeLoaderError
|
||||
|
||||
agent_def = {
|
||||
"kind": "Resource",
|
||||
"name": "NotAnAgent",
|
||||
}
|
||||
|
||||
factory = AgentFactory()
|
||||
with pytest.raises(DeclarativeLoaderError, match="Only definitions for a PromptAgent are supported"):
|
||||
await factory.create_agent_from_dict_async(agent_def)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_from_yaml_path_async_with_string_path(self, tmp_path):
|
||||
"""Test async version accepts string path."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from agent_framework_declarative import AgentFactory
|
||||
|
||||
yaml_file = tmp_path / "async_agent.yaml"
|
||||
yaml_file.write_text("""
|
||||
kind: Prompt
|
||||
name: AsyncPathAgent
|
||||
instructions: Test async path agent
|
||||
""")
|
||||
|
||||
mock_client = MagicMock()
|
||||
factory = AgentFactory(client=mock_client)
|
||||
agent = await factory.create_agent_from_yaml_path_async(str(yaml_file))
|
||||
|
||||
assert agent.name == "AsyncPathAgent"
|
||||
|
||||
|
||||
class TestAgentFactoryProviderLookup:
|
||||
"""Tests for provider configuration lookup."""
|
||||
|
||||
def test_provider_lookup_error_for_unknown_provider(self):
|
||||
"""Test that unknown provider raises ProviderLookupError."""
|
||||
|
||||
from agent_framework_declarative import AgentFactory
|
||||
from agent_framework_declarative._loader import ProviderLookupError
|
||||
|
||||
yaml_content = """
|
||||
kind: Prompt
|
||||
name: TestAgent
|
||||
instructions: Test agent
|
||||
model:
|
||||
id: test-model
|
||||
provider: UnknownProvider
|
||||
apiType: UnknownApiType
|
||||
"""
|
||||
|
||||
factory = AgentFactory()
|
||||
with pytest.raises(ProviderLookupError, match="Unsupported provider type"):
|
||||
factory.create_agent_from_yaml(yaml_content)
|
||||
|
||||
def test_additional_mappings_override_default(self):
|
||||
"""Test that additional_mappings can extend provider configurations."""
|
||||
from agent_framework_declarative import AgentFactory
|
||||
|
||||
# Define a custom provider mapping
|
||||
custom_mappings = {
|
||||
"CustomProvider.Chat": {
|
||||
"package": "agent_framework.openai",
|
||||
"name": "OpenAIChatClient",
|
||||
"model_id_field": "model_id",
|
||||
},
|
||||
}
|
||||
|
||||
factory = AgentFactory(additional_mappings=custom_mappings)
|
||||
|
||||
# The custom mapping should be available
|
||||
assert "CustomProvider.Chat" in factory.additional_mappings
|
||||
|
||||
|
||||
class TestAgentFactoryConnectionHandling:
|
||||
"""Tests for connection handling in AgentFactory."""
|
||||
|
||||
def test_reference_connection_requires_connections_dict(self):
|
||||
"""Test that ReferenceConnection without connections dict raises."""
|
||||
from agent_framework_declarative import AgentFactory
|
||||
|
||||
yaml_content = """
|
||||
kind: Prompt
|
||||
name: TestAgent
|
||||
instructions: Test agent
|
||||
model:
|
||||
id: gpt-4
|
||||
provider: OpenAI
|
||||
apiType: Chat
|
||||
connection:
|
||||
kind: reference
|
||||
name: my-connection
|
||||
"""
|
||||
|
||||
factory = AgentFactory() # No connections provided
|
||||
with pytest.raises(ValueError, match="Connections must be provided to resolve ReferenceConnection"):
|
||||
factory.create_agent_from_yaml(yaml_content)
|
||||
|
||||
def test_reference_connection_not_found_raises(self):
|
||||
"""Test that missing ReferenceConnection raises."""
|
||||
from agent_framework_declarative import AgentFactory
|
||||
|
||||
yaml_content = """
|
||||
kind: Prompt
|
||||
name: TestAgent
|
||||
instructions: Test agent
|
||||
model:
|
||||
id: gpt-4
|
||||
provider: OpenAI
|
||||
apiType: Chat
|
||||
connection:
|
||||
kind: reference
|
||||
name: missing-connection
|
||||
"""
|
||||
|
||||
factory = AgentFactory(connections={"other-connection": "value"})
|
||||
with pytest.raises(ValueError, match="not found in provided connections"):
|
||||
factory.create_agent_from_yaml(yaml_content)
|
||||
|
||||
def test_model_without_id_uses_provided_client(self):
|
||||
"""Test that model without id uses the provided chat_client."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from agent_framework_declarative import AgentFactory
|
||||
|
||||
yaml_content = """
|
||||
kind: Prompt
|
||||
name: TestAgent
|
||||
instructions: Test agent
|
||||
model:
|
||||
provider: OpenAI
|
||||
"""
|
||||
|
||||
mock_client = MagicMock()
|
||||
factory = AgentFactory(client=mock_client)
|
||||
agent = factory.create_agent_from_yaml(yaml_content)
|
||||
|
||||
assert agent is not None
|
||||
|
||||
def test_model_without_id_and_no_client_raises(self):
|
||||
"""Test that model without id and no client raises."""
|
||||
from agent_framework_declarative import AgentFactory
|
||||
from agent_framework_declarative._loader import DeclarativeLoaderError
|
||||
|
||||
yaml_content = """
|
||||
kind: Prompt
|
||||
name: TestAgent
|
||||
instructions: Test agent
|
||||
model:
|
||||
provider: OpenAI
|
||||
"""
|
||||
|
||||
factory = AgentFactory() # No chat_client
|
||||
with pytest.raises(DeclarativeLoaderError, match="ChatClient must be provided"):
|
||||
factory.create_agent_from_yaml(yaml_content)
|
||||
|
||||
|
||||
class TestAgentFactoryChatOptions:
|
||||
"""Tests for chat options parsing."""
|
||||
|
||||
def test_parse_chat_options_with_all_fields(self):
|
||||
"""Test parsing all ModelOptions fields into chat options dict."""
|
||||
from agent_framework_declarative._loader import AgentFactory
|
||||
from agent_framework_declarative._models import Model, ModelOptions
|
||||
|
||||
factory = AgentFactory()
|
||||
|
||||
# Create a Model with all options set
|
||||
options = ModelOptions(
|
||||
temperature=0.7,
|
||||
maxOutputTokens=1000,
|
||||
topP=0.9,
|
||||
frequencyPenalty=0.5,
|
||||
presencePenalty=0.3,
|
||||
seed=42,
|
||||
stopSequences=["STOP", "END"],
|
||||
allowMultipleToolCalls=True,
|
||||
)
|
||||
options.additionalProperties["chatToolMode"] = "auto"
|
||||
|
||||
model = Model(id="gpt-4", options=options)
|
||||
|
||||
# Parse the options
|
||||
chat_options = factory._parse_chat_options(model)
|
||||
|
||||
# Verify all options are parsed correctly
|
||||
assert chat_options.get("temperature") == 0.7
|
||||
assert chat_options.get("max_tokens") == 1000
|
||||
assert chat_options.get("top_p") == 0.9
|
||||
assert chat_options.get("frequency_penalty") == 0.5
|
||||
assert chat_options.get("presence_penalty") == 0.3
|
||||
assert chat_options.get("seed") == 42
|
||||
assert chat_options.get("stop") == ["STOP", "END"]
|
||||
assert chat_options.get("allow_multiple_tool_calls") is True
|
||||
assert chat_options.get("tool_choice") == "auto"
|
||||
|
||||
def test_parse_chat_options_empty_model(self):
|
||||
"""Test that missing model options returns empty dict."""
|
||||
from agent_framework_declarative._loader import AgentFactory
|
||||
|
||||
factory = AgentFactory()
|
||||
result = factory._parse_chat_options(None)
|
||||
assert result == {}
|
||||
|
||||
def test_parse_chat_options_with_additional_properties(self):
|
||||
"""Test that additional properties are passed through."""
|
||||
from agent_framework_declarative._loader import AgentFactory
|
||||
from agent_framework_declarative._models import Model, ModelOptions
|
||||
|
||||
factory = AgentFactory()
|
||||
|
||||
# Create a Model with additional properties
|
||||
options = ModelOptions(temperature=0.5)
|
||||
options.additionalProperties["customOption"] = "customValue"
|
||||
|
||||
model = Model(id="gpt-4", options=options)
|
||||
|
||||
# Parse the options
|
||||
chat_options = factory._parse_chat_options(model)
|
||||
|
||||
# Verify additional properties are preserved
|
||||
assert "additional_chat_options" in chat_options
|
||||
assert chat_options["additional_chat_options"].get("customOption") == "customValue"
|
||||
|
||||
|
||||
class TestAgentFactoryToolParsing:
|
||||
"""Tests for tool parsing edge cases."""
|
||||
|
||||
def test_parse_tools_returns_none_for_empty_list(self):
|
||||
"""Test that empty tools list returns None."""
|
||||
from agent_framework_declarative._loader import AgentFactory
|
||||
|
||||
factory = AgentFactory()
|
||||
result = factory._parse_tools(None)
|
||||
assert result is None
|
||||
|
||||
result = factory._parse_tools([])
|
||||
assert result is None
|
||||
|
||||
def test_parse_function_tool_with_bindings(self):
|
||||
"""Test parsing FunctionTool with bindings."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from agent_framework_declarative import AgentFactory
|
||||
|
||||
yaml_content = """
|
||||
kind: Prompt
|
||||
name: TestAgent
|
||||
instructions: Test agent
|
||||
tools:
|
||||
- kind: function
|
||||
name: my_function
|
||||
description: A test function
|
||||
bindings:
|
||||
- name: my_binding
|
||||
"""
|
||||
|
||||
def my_function():
|
||||
return "result"
|
||||
|
||||
mock_client = MagicMock()
|
||||
factory = AgentFactory(client=mock_client, bindings={"my_binding": my_function})
|
||||
agent = factory.create_agent_from_yaml(yaml_content)
|
||||
|
||||
# Should have parsed the tool with binding
|
||||
tools = agent.default_options.get("tools", [])
|
||||
assert len(tools) == 1
|
||||
|
||||
def test_parse_file_search_tool_with_all_options(self):
|
||||
"""Test parsing FileSearchTool with ranker and filters."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from agent_framework_declarative import AgentFactory
|
||||
|
||||
yaml_content = """
|
||||
kind: Prompt
|
||||
name: TestAgent
|
||||
instructions: Test agent
|
||||
tools:
|
||||
- kind: file_search
|
||||
name: search
|
||||
description: Search files
|
||||
vectorStoreIds:
|
||||
- vs_123
|
||||
ranker: semantic
|
||||
scoreThreshold: 0.8
|
||||
maximumResultCount: 10
|
||||
filters:
|
||||
type: document
|
||||
"""
|
||||
|
||||
mock_client = MagicMock()
|
||||
factory = AgentFactory(client=mock_client)
|
||||
agent = factory.create_agent_from_yaml(yaml_content)
|
||||
|
||||
# Verify a file search tool was parsed
|
||||
tools = agent.default_options.get("tools", [])
|
||||
assert len(tools) == 1
|
||||
|
||||
def test_parse_unsupported_tool_kind_raises(self):
|
||||
"""Test that unsupported tool kind raises ValueError."""
|
||||
from agent_framework_declarative._loader import AgentFactory
|
||||
from agent_framework_declarative._models import CustomTool
|
||||
|
||||
factory = AgentFactory()
|
||||
custom_tool = CustomTool(kind="custom", name="test")
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported tool kind"):
|
||||
factory._parse_tool(custom_tool)
|
||||
|
||||
|
||||
class TestProviderResponseFormat:
|
||||
"""response_format from outputSchema must be passed inside default_options."""
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -21,6 +21,15 @@ from agent_framework_declarative._workflows._declarative_base import (
|
||||
LoopIterationResult,
|
||||
)
|
||||
|
||||
try:
|
||||
import powerfx # noqa: F401
|
||||
|
||||
_powerfx_available = True
|
||||
except (ImportError, RuntimeError):
|
||||
_powerfx_available = False
|
||||
|
||||
_requires_powerfx = pytest.mark.skipif(not _powerfx_available, reason="PowerFx engine not available")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -250,6 +259,7 @@ class TestDeclarativeWorkflowStateExtended:
|
||||
result = state.eval([1, 2, 3]) # type: ignore[arg-type]
|
||||
assert result == [1, 2, 3]
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_eval_simple_and_operator(self, mock_state):
|
||||
"""Test simple And operator evaluation."""
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
@@ -264,6 +274,7 @@ class TestDeclarativeWorkflowStateExtended:
|
||||
result = state.eval("=Local.a And Local.b")
|
||||
assert result is True
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_eval_simple_or_operator(self, mock_state):
|
||||
"""Test simple Or operator evaluation."""
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
@@ -278,6 +289,7 @@ class TestDeclarativeWorkflowStateExtended:
|
||||
result = state.eval("=Local.a Or Local.b")
|
||||
assert result is False
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_eval_negation(self, mock_state):
|
||||
"""Test negation (!) evaluation."""
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
@@ -287,6 +299,7 @@ class TestDeclarativeWorkflowStateExtended:
|
||||
result = state.eval("=!Local.flag")
|
||||
assert result is False
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_eval_not_function(self, mock_state):
|
||||
"""Test Not() function evaluation."""
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
@@ -296,6 +309,7 @@ class TestDeclarativeWorkflowStateExtended:
|
||||
result = state.eval("=Not(Local.flag)")
|
||||
assert result is False
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_eval_comparison_operators(self, mock_state):
|
||||
"""Test comparison operators."""
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
@@ -310,6 +324,7 @@ class TestDeclarativeWorkflowStateExtended:
|
||||
assert state.eval("=Local.x <> Local.y") is True
|
||||
assert state.eval("=Local.x = 5") is True
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_eval_arithmetic_operators(self, mock_state):
|
||||
"""Test arithmetic operators."""
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
@@ -322,6 +337,7 @@ class TestDeclarativeWorkflowStateExtended:
|
||||
assert state.eval("=Local.x * Local.y") == 30
|
||||
assert state.eval("=Local.x / Local.y") == pytest.approx(3.333, rel=0.01)
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_eval_string_literal(self, mock_state):
|
||||
"""Test string literal evaluation."""
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
@@ -330,6 +346,7 @@ class TestDeclarativeWorkflowStateExtended:
|
||||
result = state.eval('="hello world"')
|
||||
assert result == "hello world"
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_eval_float_literal(self, mock_state):
|
||||
"""Test float literal evaluation."""
|
||||
from decimal import Decimal
|
||||
@@ -341,6 +358,7 @@ class TestDeclarativeWorkflowStateExtended:
|
||||
# Accepts both float (Python fallback) and Decimal (pythonnet/PowerFx)
|
||||
assert result == 3.14 or result == Decimal("3.14")
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_eval_variable_reference_with_namespace_mappings(self, mock_state):
|
||||
"""Test variable reference with PowerFx symbols."""
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
@@ -355,6 +373,7 @@ class TestDeclarativeWorkflowStateExtended:
|
||||
result = state.eval("=Workflow.Inputs.query")
|
||||
assert result == "test"
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_eval_if_expression_with_dict(self, mock_state):
|
||||
"""Test eval_if_expression recursively evaluates dicts."""
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
@@ -364,6 +383,7 @@ class TestDeclarativeWorkflowStateExtended:
|
||||
result = state.eval_if_expression({"greeting": "=Local.name", "static": "hello"})
|
||||
assert result == {"greeting": "Alice", "static": "hello"}
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_eval_if_expression_with_list(self, mock_state):
|
||||
"""Test eval_if_expression recursively evaluates lists."""
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
@@ -449,6 +469,7 @@ class TestBasicExecutorsCoverage:
|
||||
result = state.get("Local.nested")
|
||||
assert result == 42
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_set_text_variable_executor(self, mock_context, mock_state):
|
||||
"""Test SetTextVariableExecutor."""
|
||||
from agent_framework_declarative._workflows._executors_basic import (
|
||||
@@ -591,6 +612,7 @@ class TestBasicExecutorsCoverage:
|
||||
|
||||
mock_context.yield_output.assert_called_once_with("Plain text message")
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_send_activity_with_expression(self, mock_context, mock_state):
|
||||
"""Test SendActivityExecutor evaluates expressions."""
|
||||
from agent_framework_declarative._workflows._executors_basic import (
|
||||
@@ -909,6 +931,7 @@ class TestAgentExecutorsCoverage:
|
||||
assert result_prop == "Local.result"
|
||||
assert auto_send is False
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_agent_executor_build_input_text_from_string_messages(self, mock_context, mock_state):
|
||||
"""Test _build_input_text with string messages expression."""
|
||||
from agent_framework_declarative._workflows._executors_agents import (
|
||||
@@ -925,6 +948,7 @@ class TestAgentExecutorsCoverage:
|
||||
input_text = await executor._build_input_text(state, {}, "=Local.userInput")
|
||||
assert input_text == "Hello agent!"
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_agent_executor_build_input_text_from_message_list(self, mock_context, mock_state):
|
||||
"""Test _build_input_text extracts text from message list."""
|
||||
from agent_framework_declarative._workflows._executors_agents import (
|
||||
@@ -948,6 +972,7 @@ class TestAgentExecutorsCoverage:
|
||||
input_text = await executor._build_input_text(state, {}, "=Conversation.messages")
|
||||
assert input_text == "Last message"
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_agent_executor_build_input_text_from_message_with_text_attr(self, mock_context, mock_state):
|
||||
"""Test _build_input_text extracts text from message with text attribute."""
|
||||
from agent_framework_declarative._workflows._executors_agents import (
|
||||
@@ -1120,83 +1145,6 @@ class TestAgentExecutorsCoverage:
|
||||
parsed = state.get("Local.Parsed")
|
||||
assert parsed == {"status": "ok", "count": 42}
|
||||
|
||||
async def test_invoke_tool_executor_not_found(self, mock_context, mock_state):
|
||||
"""Test InvokeToolExecutor when tool not found."""
|
||||
from agent_framework_declarative._workflows._executors_agents import (
|
||||
InvokeToolExecutor,
|
||||
)
|
||||
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
state.initialize()
|
||||
|
||||
action_def = {
|
||||
"kind": "InvokeTool",
|
||||
"tool": "MissingTool",
|
||||
"resultProperty": "Local.result",
|
||||
}
|
||||
executor = InvokeToolExecutor(action_def)
|
||||
|
||||
await executor.handle_action(ActionTrigger(), mock_context)
|
||||
|
||||
result = state.get("Local.result")
|
||||
assert result == {"error": "Tool 'MissingTool' not found in registry"}
|
||||
|
||||
async def test_invoke_tool_executor_sync_tool(self, mock_context, mock_state):
|
||||
"""Test InvokeToolExecutor with synchronous tool."""
|
||||
from agent_framework_declarative._workflows._executors_agents import (
|
||||
TOOL_REGISTRY_KEY,
|
||||
InvokeToolExecutor,
|
||||
)
|
||||
|
||||
def my_tool(x: int, y: int) -> int:
|
||||
return x + y
|
||||
|
||||
mock_state._data[TOOL_REGISTRY_KEY] = {"add": my_tool}
|
||||
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
state.initialize()
|
||||
|
||||
action_def = {
|
||||
"kind": "InvokeTool",
|
||||
"tool": "add",
|
||||
"parameters": {"x": 5, "y": 3},
|
||||
"resultProperty": "Local.result",
|
||||
}
|
||||
executor = InvokeToolExecutor(action_def)
|
||||
|
||||
await executor.handle_action(ActionTrigger(), mock_context)
|
||||
|
||||
result = state.get("Local.result")
|
||||
assert result == 8
|
||||
|
||||
async def test_invoke_tool_executor_async_tool(self, mock_context, mock_state):
|
||||
"""Test InvokeToolExecutor with asynchronous tool."""
|
||||
from agent_framework_declarative._workflows._executors_agents import (
|
||||
TOOL_REGISTRY_KEY,
|
||||
InvokeToolExecutor,
|
||||
)
|
||||
|
||||
async def my_async_tool(input: str) -> str:
|
||||
return f"Processed: {input}"
|
||||
|
||||
mock_state._data[TOOL_REGISTRY_KEY] = {"process": my_async_tool}
|
||||
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
state.initialize()
|
||||
|
||||
action_def = {
|
||||
"kind": "InvokeTool",
|
||||
"tool": "process",
|
||||
"input": "test data",
|
||||
"resultProperty": "Local.result",
|
||||
}
|
||||
executor = InvokeToolExecutor(action_def)
|
||||
|
||||
await executor.handle_action(ActionTrigger(), mock_context)
|
||||
|
||||
result = state.get("Local.result")
|
||||
assert result == "Processed: test data"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Control Flow Executors Tests - Additional coverage
|
||||
@@ -1206,6 +1154,7 @@ class TestAgentExecutorsCoverage:
|
||||
class TestControlFlowCoverage:
|
||||
"""Tests for control flow executors covering uncovered code paths."""
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_foreach_with_source_alias(self, mock_context, mock_state):
|
||||
"""Test ForeachInitExecutor with 'source' alias (interpreter mode)."""
|
||||
from agent_framework_declarative._workflows._executors_control_flow import (
|
||||
@@ -1268,6 +1217,7 @@ class TestControlFlowCoverage:
|
||||
assert msg.current_index == 1
|
||||
assert msg.current_item == "b"
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_switch_evaluator_with_value_cases(self, mock_context, mock_state):
|
||||
"""Test SwitchEvaluatorExecutor with value/cases schema."""
|
||||
from agent_framework_declarative._workflows._executors_control_flow import (
|
||||
@@ -1295,6 +1245,7 @@ class TestControlFlowCoverage:
|
||||
assert msg.matched is True
|
||||
assert msg.branch_index == 1 # Second case matched
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_switch_evaluator_default_case(self, mock_context, mock_state):
|
||||
"""Test SwitchEvaluatorExecutor falls through to default."""
|
||||
from agent_framework_declarative._workflows._executors_control_flow import (
|
||||
@@ -1554,6 +1505,7 @@ class TestControlFlowCoverage:
|
||||
# Should NOT send any message
|
||||
mock_context.send_message.assert_not_called()
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_condition_group_evaluator_first_match(self, mock_context, mock_state):
|
||||
"""Test ConditionGroupEvaluatorExecutor returns first match."""
|
||||
from agent_framework_declarative._workflows._executors_control_flow import (
|
||||
@@ -1579,6 +1531,7 @@ class TestControlFlowCoverage:
|
||||
assert msg.matched is True
|
||||
assert msg.branch_index == 1 # Second condition (x > 5) is first match
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_condition_group_evaluator_no_match(self, mock_context, mock_state):
|
||||
"""Test ConditionGroupEvaluatorExecutor with no matches."""
|
||||
from agent_framework_declarative._workflows._executors_control_flow import (
|
||||
@@ -1603,6 +1556,7 @@ class TestControlFlowCoverage:
|
||||
assert msg.matched is False
|
||||
assert msg.branch_index == -1
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_condition_group_evaluator_boolean_true_condition(self, mock_context, mock_state):
|
||||
"""Test ConditionGroupEvaluatorExecutor with boolean True condition."""
|
||||
from agent_framework_declarative._workflows._executors_control_flow import (
|
||||
@@ -1626,6 +1580,7 @@ class TestControlFlowCoverage:
|
||||
assert msg.matched is True
|
||||
assert msg.branch_index == 1
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_if_condition_evaluator_true(self, mock_context, mock_state):
|
||||
"""Test IfConditionEvaluatorExecutor with true condition."""
|
||||
from agent_framework_declarative._workflows._executors_control_flow import (
|
||||
@@ -1646,6 +1601,7 @@ class TestControlFlowCoverage:
|
||||
assert msg.matched is True
|
||||
assert msg.branch_index == 0 # Then branch
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_if_condition_evaluator_false(self, mock_context, mock_state):
|
||||
"""Test IfConditionEvaluatorExecutor with false condition."""
|
||||
from agent_framework_declarative._workflows._executors_control_flow import (
|
||||
@@ -1894,6 +1850,7 @@ class TestHumanInputExecutorsCoverage:
|
||||
class TestAgentExternalLoopCoverage:
|
||||
"""Tests for agent executor external loop handling."""
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_agent_executor_with_external_loop(self, mock_context, mock_state):
|
||||
"""Test agent executor with external loop that triggers."""
|
||||
from unittest.mock import patch
|
||||
@@ -1996,39 +1953,13 @@ class TestAgentExternalLoopCoverage:
|
||||
result = state.get("Local.result")
|
||||
assert result == "Direct string response"
|
||||
|
||||
async def test_invoke_tool_with_error(self, mock_context, mock_state):
|
||||
"""Test InvokeToolExecutor handles tool errors."""
|
||||
from agent_framework_declarative._workflows._executors_agents import (
|
||||
TOOL_REGISTRY_KEY,
|
||||
InvokeToolExecutor,
|
||||
)
|
||||
|
||||
def failing_tool(**kwargs):
|
||||
raise ValueError("Tool error")
|
||||
|
||||
mock_state._data[TOOL_REGISTRY_KEY] = {"bad_tool": failing_tool}
|
||||
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
state.initialize()
|
||||
|
||||
action_def = {
|
||||
"kind": "InvokeTool",
|
||||
"tool": "bad_tool",
|
||||
"resultProperty": "Local.result",
|
||||
}
|
||||
executor = InvokeToolExecutor(action_def)
|
||||
|
||||
await executor.handle_action(ActionTrigger(), mock_context)
|
||||
|
||||
result = state.get("Local.result")
|
||||
assert result == {"error": "Tool error"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PowerFx Functions Coverage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@_requires_powerfx
|
||||
class TestPowerFxFunctionsCoverage:
|
||||
"""Tests for PowerFx function evaluation coverage."""
|
||||
|
||||
@@ -2784,6 +2715,7 @@ class TestBuilderValidation:
|
||||
assert "activity" in str(exc_info.value)
|
||||
|
||||
|
||||
@_requires_powerfx
|
||||
class TestExpressionEdgeCases:
|
||||
"""Tests for expression evaluation edge cases."""
|
||||
|
||||
@@ -2808,6 +2740,7 @@ class TestExpressionEdgeCases:
|
||||
assert result == 42
|
||||
|
||||
|
||||
@_requires_powerfx
|
||||
class TestLongMessageTextHandling:
|
||||
"""Tests for handling long MessageText results that exceed PowerFx limits."""
|
||||
|
||||
|
||||
@@ -7,7 +7,16 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework_declarative._workflows import (
|
||||
try:
|
||||
import powerfx # noqa: F401
|
||||
|
||||
_powerfx_available = True
|
||||
except (ImportError, RuntimeError):
|
||||
_powerfx_available = False
|
||||
|
||||
_requires_powerfx = pytest.mark.skipif(not _powerfx_available, reason="PowerFx engine not available")
|
||||
|
||||
from agent_framework_declarative._workflows import ( # noqa: E402
|
||||
ALL_ACTION_EXECUTORS,
|
||||
DECLARATIVE_STATE_KEY,
|
||||
ActionComplete,
|
||||
@@ -99,6 +108,7 @@ class TestDeclarativeWorkflowState:
|
||||
result = state.get("Local.items")
|
||||
assert result == ["first", "second"]
|
||||
|
||||
@_requires_powerfx
|
||||
@pytest.mark.asyncio
|
||||
async def test_eval_expression(self, mock_state):
|
||||
"""Test evaluating expressions."""
|
||||
@@ -196,6 +206,7 @@ class TestDeclarativeActionExecutor:
|
||||
|
||||
# Note: ConditionEvaluatorExecutor tests removed - conditions are now evaluated on edges
|
||||
|
||||
@_requires_powerfx
|
||||
async def test_foreach_init_with_items(self, mock_context, mock_state):
|
||||
"""Test ForeachInitExecutor with items."""
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
@@ -529,6 +540,7 @@ class TestHumanInputExecutors:
|
||||
assert "continue" in request.message.lower()
|
||||
|
||||
|
||||
@_requires_powerfx
|
||||
class TestParseValueExecutor:
|
||||
"""Tests for the ParseValue action executor."""
|
||||
|
||||
|
||||
@@ -9,13 +9,27 @@ These tests verify:
|
||||
- Pause/resume capabilities
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework_declarative._workflows import (
|
||||
try:
|
||||
import powerfx # noqa: F401
|
||||
|
||||
_powerfx_available = True
|
||||
except (ImportError, RuntimeError):
|
||||
_powerfx_available = False
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not _powerfx_available or sys.version_info >= (3, 14),
|
||||
reason="PowerFx engine not available (requires dotnet runtime)",
|
||||
)
|
||||
|
||||
from agent_framework_declarative._workflows import ( # noqa: E402
|
||||
ActionTrigger,
|
||||
DeclarativeWorkflowBuilder,
|
||||
)
|
||||
from agent_framework_declarative._workflows._factory import WorkflowFactory
|
||||
from agent_framework_declarative._workflows._factory import WorkflowFactory # noqa: E402
|
||||
|
||||
|
||||
class TestGraphBasedWorkflowExecution:
|
||||
|
||||
@@ -237,6 +237,444 @@ class TestCustomFunctionsRegistry:
|
||||
"Lower",
|
||||
"Concat",
|
||||
"Search",
|
||||
"If",
|
||||
"Or",
|
||||
"And",
|
||||
"Not",
|
||||
"AgentMessage",
|
||||
"ForAll",
|
||||
]
|
||||
for name in expected:
|
||||
assert name in CUSTOM_FUNCTIONS
|
||||
|
||||
|
||||
class TestMessageTextEdgeCases:
|
||||
"""Additional tests for message_text edge cases."""
|
||||
|
||||
def test_message_text_dict_with_text_attr_content(self):
|
||||
"""Test message with content that has text attribute."""
|
||||
|
||||
class ContentWithText: # noqa: B903
|
||||
def __init__(self, text: str):
|
||||
self.text = text
|
||||
|
||||
msg = {"role": "assistant", "content": ContentWithText("Hello from text attr")}
|
||||
assert message_text(msg) == "Hello from text attr"
|
||||
|
||||
def test_message_text_dict_content_non_string(self):
|
||||
"""Test message with non-string content."""
|
||||
msg = {"role": "assistant", "content": 42}
|
||||
assert message_text(msg) == "42"
|
||||
|
||||
def test_message_text_list_with_string_items(self):
|
||||
"""Test message_text with list of strings."""
|
||||
result = message_text(["Hello", "World"])
|
||||
assert result == "Hello World"
|
||||
|
||||
def test_message_text_list_with_content_objects(self):
|
||||
"""Test message_text with list items having content attribute."""
|
||||
|
||||
class MessageObj: # noqa: B903
|
||||
def __init__(self, content: str):
|
||||
self.content = content
|
||||
|
||||
msgs = [MessageObj("Hello"), MessageObj("World")]
|
||||
result = message_text(msgs)
|
||||
assert result == "Hello World"
|
||||
|
||||
def test_message_text_list_with_content_text_attr(self):
|
||||
"""Test message_text with content having text attribute."""
|
||||
|
||||
class ContentWithText: # noqa: B903
|
||||
def __init__(self, text: str):
|
||||
self.text = text
|
||||
|
||||
class MessageObj:
|
||||
def __init__(self, content):
|
||||
self.content = content
|
||||
|
||||
msgs = [MessageObj(ContentWithText("Part1")), MessageObj(ContentWithText("Part2"))]
|
||||
result = message_text(msgs)
|
||||
assert result == "Part1 Part2"
|
||||
|
||||
def test_message_text_list_with_non_string_content(self):
|
||||
"""Test message_text with non-string content in dicts."""
|
||||
msgs = [{"content": 123}, {"content": 456}]
|
||||
result = message_text(msgs)
|
||||
assert result == "123 456"
|
||||
|
||||
def test_message_text_object_with_text_attr(self):
|
||||
"""Test message_text with object having text attribute."""
|
||||
|
||||
class ObjWithText:
|
||||
text = "Direct text"
|
||||
|
||||
result = message_text(ObjWithText())
|
||||
assert result == "Direct text"
|
||||
|
||||
def test_message_text_object_with_content_attr(self):
|
||||
"""Test message_text with object having content attribute."""
|
||||
|
||||
class ObjWithContent:
|
||||
content = "Direct content"
|
||||
|
||||
result = message_text(ObjWithContent())
|
||||
assert result == "Direct content"
|
||||
|
||||
def test_message_text_object_with_non_string_content(self):
|
||||
"""Test message_text with object having non-string content."""
|
||||
|
||||
class ObjWithContent:
|
||||
content = None
|
||||
|
||||
result = message_text(ObjWithContent())
|
||||
assert result == ""
|
||||
|
||||
def test_message_text_list_with_empty_content_object(self):
|
||||
"""Test message with content object that evaluates to empty."""
|
||||
|
||||
class MessageObj:
|
||||
content = None
|
||||
|
||||
result = message_text([MessageObj()])
|
||||
assert result == ""
|
||||
|
||||
|
||||
class TestAgentMessage:
|
||||
"""Tests for agent_message function."""
|
||||
|
||||
def test_agent_message_creates_dict(self):
|
||||
"""Test that AgentMessage creates correct dict."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import agent_message
|
||||
|
||||
msg = agent_message("Hello")
|
||||
assert msg == {"role": "assistant", "content": "Hello"}
|
||||
|
||||
def test_agent_message_with_none(self):
|
||||
"""Test AgentMessage with None."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import agent_message
|
||||
|
||||
msg = agent_message(None)
|
||||
assert msg == {"role": "assistant", "content": ""}
|
||||
|
||||
|
||||
class TestIfFunc:
|
||||
"""Tests for if_func conditional function."""
|
||||
|
||||
def test_if_true_condition(self):
|
||||
"""Test If with true condition."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import if_func
|
||||
|
||||
assert if_func(True, "yes", "no") == "yes"
|
||||
|
||||
def test_if_false_condition(self):
|
||||
"""Test If with false condition."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import if_func
|
||||
|
||||
assert if_func(False, "yes", "no") == "no"
|
||||
|
||||
def test_if_truthy_value(self):
|
||||
"""Test If with truthy value."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import if_func
|
||||
|
||||
assert if_func(1, "yes", "no") == "yes"
|
||||
assert if_func("non-empty", "yes", "no") == "yes"
|
||||
|
||||
def test_if_falsy_value(self):
|
||||
"""Test If with falsy value."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import if_func
|
||||
|
||||
assert if_func(0, "yes", "no") == "no"
|
||||
assert if_func("", "yes", "no") == "no"
|
||||
assert if_func(None, "yes", "no") == "no"
|
||||
|
||||
def test_if_no_false_value(self):
|
||||
"""Test If with no false value defaults to None."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import if_func
|
||||
|
||||
assert if_func(False, "yes") is None
|
||||
|
||||
|
||||
class TestOrFunc:
|
||||
"""Tests for or_func function."""
|
||||
|
||||
def test_or_all_false(self):
|
||||
"""Test Or with all false values."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import or_func
|
||||
|
||||
assert or_func(False, False, False) is False
|
||||
|
||||
def test_or_one_true(self):
|
||||
"""Test Or with one true value."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import or_func
|
||||
|
||||
assert or_func(False, True, False) is True
|
||||
|
||||
def test_or_all_true(self):
|
||||
"""Test Or with all true values."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import or_func
|
||||
|
||||
assert or_func(True, True, True) is True
|
||||
|
||||
def test_or_empty(self):
|
||||
"""Test Or with no arguments."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import or_func
|
||||
|
||||
assert or_func() is False
|
||||
|
||||
|
||||
class TestAndFunc:
|
||||
"""Tests for and_func function."""
|
||||
|
||||
def test_and_all_true(self):
|
||||
"""Test And with all true values."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import and_func
|
||||
|
||||
assert and_func(True, True, True) is True
|
||||
|
||||
def test_and_one_false(self):
|
||||
"""Test And with one false value."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import and_func
|
||||
|
||||
assert and_func(True, False, True) is False
|
||||
|
||||
def test_and_all_false(self):
|
||||
"""Test And with all false values."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import and_func
|
||||
|
||||
assert and_func(False, False, False) is False
|
||||
|
||||
def test_and_empty(self):
|
||||
"""Test And with no arguments."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import and_func
|
||||
|
||||
assert and_func() is True
|
||||
|
||||
|
||||
class TestNotFunc:
|
||||
"""Tests for not_func function."""
|
||||
|
||||
def test_not_true(self):
|
||||
"""Test Not with true."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import not_func
|
||||
|
||||
assert not_func(True) is False
|
||||
|
||||
def test_not_false(self):
|
||||
"""Test Not with false."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import not_func
|
||||
|
||||
assert not_func(False) is True
|
||||
|
||||
def test_not_truthy(self):
|
||||
"""Test Not with truthy values."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import not_func
|
||||
|
||||
assert not_func(1) is False
|
||||
assert not_func("text") is False
|
||||
|
||||
def test_not_falsy(self):
|
||||
"""Test Not with falsy values."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import not_func
|
||||
|
||||
assert not_func(0) is True
|
||||
assert not_func("") is True
|
||||
assert not_func(None) is True
|
||||
|
||||
|
||||
class TestIsBlankEdgeCases:
|
||||
"""Additional tests for is_blank edge cases."""
|
||||
|
||||
def test_is_blank_empty_dict(self):
|
||||
"""Test that empty dict is blank."""
|
||||
assert is_blank({}) is True
|
||||
|
||||
def test_is_blank_non_empty_dict(self):
|
||||
"""Test that non-empty dict is not blank."""
|
||||
assert is_blank({"key": "value"}) is False
|
||||
|
||||
|
||||
class TestCountRowsEdgeCases:
|
||||
"""Additional tests for count_rows edge cases."""
|
||||
|
||||
def test_count_rows_dict(self):
|
||||
"""Test counting dict items."""
|
||||
assert count_rows({"a": 1, "b": 2, "c": 3}) == 3
|
||||
|
||||
def test_count_rows_tuple(self):
|
||||
"""Test counting tuple items."""
|
||||
assert count_rows((1, 2, 3, 4)) == 4
|
||||
|
||||
def test_count_rows_non_iterable(self):
|
||||
"""Test counting non-iterable returns 0."""
|
||||
assert count_rows(42) == 0
|
||||
assert count_rows("string") == 0
|
||||
|
||||
|
||||
class TestFirstLastEdgeCases:
|
||||
"""Additional tests for first/last edge cases."""
|
||||
|
||||
def test_first_none(self):
|
||||
"""Test first with None."""
|
||||
assert first(None) is None
|
||||
|
||||
def test_last_none(self):
|
||||
"""Test last with None."""
|
||||
assert last(None) is None
|
||||
|
||||
def test_first_tuple(self):
|
||||
"""Test first with tuple."""
|
||||
assert first((1, 2, 3)) == 1
|
||||
|
||||
def test_last_tuple(self):
|
||||
"""Test last with tuple."""
|
||||
assert last((1, 2, 3)) == 3
|
||||
|
||||
|
||||
class TestFindEdgeCases:
|
||||
"""Additional tests for find edge cases."""
|
||||
|
||||
def test_find_none_substring(self):
|
||||
"""Test find with None substring."""
|
||||
assert find(None, "text") is None
|
||||
|
||||
def test_find_none_text(self):
|
||||
"""Test find with None text."""
|
||||
assert find("sub", None) is None
|
||||
|
||||
def test_find_both_none(self):
|
||||
"""Test find with both None."""
|
||||
assert find(None, None) is None
|
||||
|
||||
|
||||
class TestLowerEdgeCases:
|
||||
"""Additional tests for lower edge cases."""
|
||||
|
||||
def test_lower_none(self):
|
||||
"""Test lower with None."""
|
||||
assert lower(None) == ""
|
||||
|
||||
|
||||
class TestConcatStrings:
|
||||
"""Tests for concat_strings function."""
|
||||
|
||||
def test_concat_strings_basic(self):
|
||||
"""Test basic string concatenation."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import concat_strings
|
||||
|
||||
assert concat_strings("Hello", " ", "World") == "Hello World"
|
||||
|
||||
def test_concat_strings_with_none(self):
|
||||
"""Test concat with None values."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import concat_strings
|
||||
|
||||
assert concat_strings("Hello", None, "World") == "HelloWorld"
|
||||
|
||||
def test_concat_strings_empty(self):
|
||||
"""Test concat with no arguments."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import concat_strings
|
||||
|
||||
assert concat_strings() == ""
|
||||
|
||||
|
||||
class TestConcatTextEdgeCases:
|
||||
"""Additional tests for concat_text edge cases."""
|
||||
|
||||
def test_concat_text_none(self):
|
||||
"""Test concat_text with None."""
|
||||
assert concat_text(None) == ""
|
||||
|
||||
def test_concat_text_non_list(self):
|
||||
"""Test concat_text with non-list."""
|
||||
assert concat_text("single value") == "single value"
|
||||
|
||||
def test_concat_text_with_field_attr(self):
|
||||
"""Test concat_text with field as object attribute."""
|
||||
|
||||
class Item: # noqa: B903
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
items = [Item("Alice"), Item("Bob")]
|
||||
assert concat_text(items, field="name", separator=", ") == "Alice, Bob"
|
||||
|
||||
def test_concat_text_with_none_values(self):
|
||||
"""Test concat_text with None values in list."""
|
||||
items = [{"name": "Alice"}, {"name": None}, {"name": "Bob"}]
|
||||
result = concat_text(items, field="name", separator=", ")
|
||||
assert result == "Alice, , Bob"
|
||||
|
||||
|
||||
class TestForAll:
|
||||
"""Tests for for_all function."""
|
||||
|
||||
def test_for_all_with_list_of_dicts(self):
|
||||
"""Test ForAll with list of dictionaries."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import for_all
|
||||
|
||||
items = [{"name": "Alice"}, {"name": "Bob"}]
|
||||
result = for_all(items, "expression")
|
||||
assert result == items
|
||||
|
||||
def test_for_all_with_non_dict_items(self):
|
||||
"""Test ForAll with non-dict items."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import for_all
|
||||
|
||||
items = [1, 2, 3]
|
||||
result = for_all(items, "expression")
|
||||
assert result == [1, 2, 3]
|
||||
|
||||
def test_for_all_with_none(self):
|
||||
"""Test ForAll with None."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import for_all
|
||||
|
||||
assert for_all(None, "expression") == []
|
||||
|
||||
def test_for_all_with_non_list(self):
|
||||
"""Test ForAll with non-list."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import for_all
|
||||
|
||||
assert for_all("not a list", "expression") == []
|
||||
|
||||
def test_for_all_empty_list(self):
|
||||
"""Test ForAll with empty list."""
|
||||
from agent_framework_declarative._workflows._powerfx_functions import for_all
|
||||
|
||||
assert for_all([], "expression") == []
|
||||
|
||||
|
||||
class TestSearchTableEdgeCases:
|
||||
"""Additional tests for search_table edge cases."""
|
||||
|
||||
def test_search_table_none(self):
|
||||
"""Test search_table with None."""
|
||||
assert search_table(None, "value", "column") == []
|
||||
|
||||
def test_search_table_non_list(self):
|
||||
"""Test search_table with non-list."""
|
||||
assert search_table("not a list", "value", "column") == []
|
||||
|
||||
def test_search_table_with_object_attr(self):
|
||||
"""Test search_table with object attributes."""
|
||||
|
||||
class Item: # noqa: B903
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
items = [Item("Alice"), Item("Bob"), Item("Charlie")]
|
||||
result = search_table(items, "Bob", "name")
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "Bob"
|
||||
|
||||
def test_search_table_no_matching_column(self):
|
||||
"""Test search_table when items don't have the column."""
|
||||
items = [{"other": "value"}]
|
||||
result = search_table(items, "value", "name")
|
||||
assert result == []
|
||||
|
||||
def test_search_table_empty_value(self):
|
||||
"""Test search_table with empty search value."""
|
||||
items = [{"name": "Alice"}, {"name": "Bob"}]
|
||||
result = search_table(items, "", "name")
|
||||
# Empty string matches everything
|
||||
assert len(result) == 2
|
||||
|
||||
@@ -20,7 +20,16 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework_declarative._workflows._declarative_base import (
|
||||
try:
|
||||
import powerfx # noqa: F401
|
||||
|
||||
_powerfx_available = True
|
||||
except (ImportError, RuntimeError):
|
||||
_powerfx_available = False
|
||||
|
||||
pytestmark = pytest.mark.skipif(not _powerfx_available, reason="PowerFx engine not available")
|
||||
|
||||
from agent_framework_declarative._workflows._declarative_base import ( # noqa: E402
|
||||
DeclarativeWorkflowState,
|
||||
)
|
||||
|
||||
|
||||
@@ -9,6 +9,15 @@ from agent_framework_declarative._workflows._factory import (
|
||||
WorkflowFactory,
|
||||
)
|
||||
|
||||
try:
|
||||
import powerfx # noqa: F401
|
||||
|
||||
_powerfx_available = True
|
||||
except (ImportError, RuntimeError):
|
||||
_powerfx_available = False
|
||||
|
||||
_requires_powerfx = pytest.mark.skipif(not _powerfx_available, reason="PowerFx engine not available")
|
||||
|
||||
|
||||
class TestWorkflowFactoryValidation:
|
||||
"""Tests for workflow definition validation."""
|
||||
@@ -58,6 +67,7 @@ actions:
|
||||
assert workflow.name == "minimal-workflow"
|
||||
|
||||
|
||||
@_requires_powerfx
|
||||
class TestWorkflowFactoryExecution:
|
||||
"""Tests for workflow execution."""
|
||||
|
||||
@@ -204,6 +214,7 @@ actions:
|
||||
assert workflow.name == "file-workflow"
|
||||
|
||||
|
||||
@_requires_powerfx
|
||||
class TestDisplayNameMetadata:
|
||||
"""Tests for displayName metadata support."""
|
||||
|
||||
@@ -231,3 +242,662 @@ actions:
|
||||
|
||||
# Should execute successfully with displayName metadata
|
||||
assert len(outputs) >= 1
|
||||
|
||||
|
||||
class TestWorkflowFactoryToolRegistration:
|
||||
"""Tests for tool registration."""
|
||||
|
||||
def test_register_tool_basic(self):
|
||||
"""Test registering a tool."""
|
||||
|
||||
def my_tool(x: int) -> int:
|
||||
return x * 2
|
||||
|
||||
factory = WorkflowFactory()
|
||||
result = factory.register_tool("my_tool", my_tool)
|
||||
|
||||
# Should return self for fluent chaining
|
||||
assert result is factory
|
||||
assert "my_tool" in factory._tools
|
||||
assert factory._tools["my_tool"](5) == 10
|
||||
|
||||
def test_register_multiple_tools(self):
|
||||
"""Test registering multiple tools with fluent chaining."""
|
||||
|
||||
def add(a: int, b: int) -> int:
|
||||
return a + b
|
||||
|
||||
def multiply(a: int, b: int) -> int:
|
||||
return a * b
|
||||
|
||||
factory = WorkflowFactory().register_tool("add", add).register_tool("multiply", multiply)
|
||||
|
||||
assert "add" in factory._tools
|
||||
assert "multiply" in factory._tools
|
||||
assert factory._tools["add"](2, 3) == 5
|
||||
assert factory._tools["multiply"](2, 3) == 6
|
||||
|
||||
def test_register_tool_non_callable_raises(self):
|
||||
"""Test that register_tool raises TypeError for non-callable."""
|
||||
factory = WorkflowFactory()
|
||||
|
||||
with pytest.raises(TypeError, match="Expected a callable for tool"):
|
||||
factory.register_tool("bad_tool", "not_a_function")
|
||||
|
||||
def test_register_binding_non_callable_raises(self):
|
||||
"""Test that register_binding raises TypeError for non-callable."""
|
||||
factory = WorkflowFactory()
|
||||
|
||||
with pytest.raises(TypeError, match="Expected a callable for binding"):
|
||||
factory.register_binding("bad_binding", 42)
|
||||
|
||||
|
||||
class TestWorkflowFactoryEdgeCases:
|
||||
"""Tests for edge cases in workflow factory."""
|
||||
|
||||
def test_empty_actions_list(self):
|
||||
"""Test workflow with empty actions list."""
|
||||
factory = WorkflowFactory()
|
||||
with pytest.raises(DeclarativeWorkflowError, match="actions"):
|
||||
factory.create_workflow_from_yaml("""
|
||||
name: empty-actions
|
||||
actions: []
|
||||
""")
|
||||
|
||||
def test_unknown_action_kind(self):
|
||||
"""Test workflow with unknown action kind."""
|
||||
factory = WorkflowFactory()
|
||||
with pytest.raises((DeclarativeWorkflowError, ValueError)):
|
||||
factory.create_workflow_from_yaml("""
|
||||
name: unknown-action
|
||||
actions:
|
||||
- kind: UnknownActionType
|
||||
value: test
|
||||
""")
|
||||
|
||||
def test_workflow_with_description(self):
|
||||
"""Test workflow with description field."""
|
||||
factory = WorkflowFactory()
|
||||
workflow = factory.create_workflow_from_yaml("""
|
||||
name: described-workflow
|
||||
description: This is a test workflow
|
||||
actions:
|
||||
- kind: SetValue
|
||||
path: Local.x
|
||||
value: 1
|
||||
""")
|
||||
|
||||
assert workflow is not None
|
||||
assert workflow.name == "described-workflow"
|
||||
|
||||
@_requires_powerfx
|
||||
@pytest.mark.asyncio
|
||||
async def test_workflow_with_expression_value(self):
|
||||
"""Test workflow with expression-based value."""
|
||||
factory = WorkflowFactory()
|
||||
workflow = factory.create_workflow_from_yaml("""
|
||||
name: expression-test
|
||||
actions:
|
||||
- kind: SetValue
|
||||
path: Local.x
|
||||
value: 5
|
||||
- kind: SetValue
|
||||
path: Local.y
|
||||
value: =Local.x
|
||||
- kind: SendActivity
|
||||
activity:
|
||||
text: =Local.y
|
||||
""")
|
||||
|
||||
result = await workflow.run({})
|
||||
outputs = result.get_outputs()
|
||||
|
||||
assert any("5" in str(o) for o in outputs)
|
||||
|
||||
@_requires_powerfx
|
||||
@pytest.mark.asyncio
|
||||
async def test_workflow_with_nested_if(self):
|
||||
"""Test workflow with nested If statements."""
|
||||
factory = WorkflowFactory()
|
||||
workflow = factory.create_workflow_from_yaml("""
|
||||
name: nested-if-test
|
||||
actions:
|
||||
- kind: SetValue
|
||||
path: Local.level
|
||||
value: 2
|
||||
- kind: If
|
||||
condition: true
|
||||
then:
|
||||
- kind: If
|
||||
condition: true
|
||||
then:
|
||||
- kind: SendActivity
|
||||
activity:
|
||||
text: Nested condition passed
|
||||
""")
|
||||
|
||||
result = await workflow.run({})
|
||||
outputs = result.get_outputs()
|
||||
|
||||
assert any("Nested condition passed" in str(o) for o in outputs)
|
||||
|
||||
def test_load_from_string_path(self, tmp_path):
|
||||
"""Test loading a workflow from a string file path."""
|
||||
workflow_file = tmp_path / "workflow.yaml"
|
||||
workflow_file.write_text("""
|
||||
name: string-path-workflow
|
||||
actions:
|
||||
- kind: SetValue
|
||||
path: Local.loaded
|
||||
value: true
|
||||
""")
|
||||
|
||||
factory = WorkflowFactory()
|
||||
# Pass as string instead of Path object
|
||||
workflow = factory.create_workflow_from_yaml_path(str(workflow_file))
|
||||
|
||||
assert workflow is not None
|
||||
assert workflow.name == "string-path-workflow"
|
||||
|
||||
|
||||
@_requires_powerfx
|
||||
class TestWorkflowFactorySwitch:
|
||||
"""Tests for Switch/Case action."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_switch_with_matching_case(self):
|
||||
"""Test Switch with a matching case."""
|
||||
factory = WorkflowFactory()
|
||||
workflow = factory.create_workflow_from_yaml("""
|
||||
name: switch-test
|
||||
actions:
|
||||
- kind: SetValue
|
||||
path: Local.color
|
||||
value: red
|
||||
- kind: Switch
|
||||
value: =Local.color
|
||||
cases:
|
||||
- match: red
|
||||
actions:
|
||||
- kind: SendActivity
|
||||
activity:
|
||||
text: Color is red
|
||||
- match: blue
|
||||
actions:
|
||||
- kind: SendActivity
|
||||
activity:
|
||||
text: Color is blue
|
||||
""")
|
||||
|
||||
result = await workflow.run({})
|
||||
outputs = result.get_outputs()
|
||||
|
||||
assert any("Color is red" in str(o) for o in outputs)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_switch_with_default(self):
|
||||
"""Test Switch falling through to default."""
|
||||
factory = WorkflowFactory()
|
||||
workflow = factory.create_workflow_from_yaml("""
|
||||
name: switch-default-test
|
||||
actions:
|
||||
- kind: SetValue
|
||||
path: Local.color
|
||||
value: green
|
||||
- kind: Switch
|
||||
value: =Local.color
|
||||
cases:
|
||||
- match: red
|
||||
actions:
|
||||
- kind: SendActivity
|
||||
activity:
|
||||
text: Red
|
||||
- match: blue
|
||||
actions:
|
||||
- kind: SendActivity
|
||||
activity:
|
||||
text: Blue
|
||||
default:
|
||||
- kind: SendActivity
|
||||
activity:
|
||||
text: Unknown color
|
||||
""")
|
||||
|
||||
result = await workflow.run({})
|
||||
outputs = result.get_outputs()
|
||||
|
||||
assert any("Unknown color" in str(o) for o in outputs)
|
||||
|
||||
|
||||
@_requires_powerfx
|
||||
class TestWorkflowFactoryMultipleActionTypes:
|
||||
"""Tests for workflows with multiple action types."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_multiple_variables(self):
|
||||
"""Test SetMultipleVariables action."""
|
||||
factory = WorkflowFactory()
|
||||
workflow = factory.create_workflow_from_yaml("""
|
||||
name: multi-set-test
|
||||
actions:
|
||||
- kind: SetMultipleVariables
|
||||
variables:
|
||||
- path: Local.a
|
||||
value: 1
|
||||
- path: Local.b
|
||||
value: 2
|
||||
- path: Local.c
|
||||
value: 3
|
||||
- kind: SendActivity
|
||||
activity:
|
||||
text: Done
|
||||
""")
|
||||
|
||||
result = await workflow.run({})
|
||||
outputs = result.get_outputs()
|
||||
|
||||
assert any("Done" in str(o) for o in outputs)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_value(self):
|
||||
"""Test AppendValue action."""
|
||||
factory = WorkflowFactory()
|
||||
workflow = factory.create_workflow_from_yaml("""
|
||||
name: append-test
|
||||
actions:
|
||||
- kind: SetValue
|
||||
path: Local.list
|
||||
value: []
|
||||
- kind: AppendValue
|
||||
path: Local.list
|
||||
value: first
|
||||
- kind: AppendValue
|
||||
path: Local.list
|
||||
value: second
|
||||
- kind: SendActivity
|
||||
activity:
|
||||
text: Done
|
||||
""")
|
||||
|
||||
result = await workflow.run({})
|
||||
outputs = result.get_outputs()
|
||||
|
||||
assert any("Done" in str(o) for o in outputs)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_event(self):
|
||||
"""Test EmitEvent action."""
|
||||
factory = WorkflowFactory()
|
||||
workflow = factory.create_workflow_from_yaml("""
|
||||
name: emit-event-test
|
||||
actions:
|
||||
- kind: EmitEvent
|
||||
event:
|
||||
name: test_event
|
||||
data:
|
||||
message: Hello
|
||||
- kind: SendActivity
|
||||
activity:
|
||||
text: Event emitted
|
||||
""")
|
||||
|
||||
result = await workflow.run({})
|
||||
outputs = result.get_outputs()
|
||||
|
||||
# Workflow should complete
|
||||
assert any("Event emitted" in str(o) for o in outputs)
|
||||
|
||||
|
||||
class TestWorkflowFactoryYamlErrors:
|
||||
"""Tests for YAML parsing error handling."""
|
||||
|
||||
def test_invalid_yaml_raises(self):
|
||||
"""Test that invalid YAML raises DeclarativeWorkflowError."""
|
||||
factory = WorkflowFactory()
|
||||
with pytest.raises(DeclarativeWorkflowError, match="Invalid YAML"):
|
||||
factory.create_workflow_from_yaml("""
|
||||
name: broken-yaml
|
||||
actions:
|
||||
- kind: SetValue
|
||||
path: Local.x
|
||||
value: [unclosed bracket
|
||||
""")
|
||||
|
||||
def test_non_dict_workflow_raises(self):
|
||||
"""Test that non-dict workflow definition raises error."""
|
||||
factory = WorkflowFactory()
|
||||
with pytest.raises(DeclarativeWorkflowError, match="must be a dictionary"):
|
||||
factory.create_workflow_from_yaml("- just a list item")
|
||||
|
||||
|
||||
class TestWorkflowFactoryTriggerFormat:
|
||||
"""Tests for trigger-based workflow format."""
|
||||
|
||||
def test_trigger_based_workflow(self):
|
||||
"""Test workflow with trigger-based format."""
|
||||
factory = WorkflowFactory()
|
||||
workflow = factory.create_workflow_from_yaml("""
|
||||
kind: Workflow
|
||||
trigger:
|
||||
kind: OnConversationStart
|
||||
id: my_trigger
|
||||
actions:
|
||||
- kind: SetValue
|
||||
path: Local.x
|
||||
value: 1
|
||||
""")
|
||||
|
||||
assert workflow is not None
|
||||
assert workflow.name == "my_trigger"
|
||||
|
||||
def test_trigger_workflow_without_id(self):
|
||||
"""Test trigger workflow without id uses default name."""
|
||||
factory = WorkflowFactory()
|
||||
workflow = factory.create_workflow_from_yaml("""
|
||||
kind: Workflow
|
||||
trigger:
|
||||
kind: OnConversationStart
|
||||
actions:
|
||||
- kind: SetValue
|
||||
path: Local.x
|
||||
value: 1
|
||||
""")
|
||||
|
||||
assert workflow is not None
|
||||
assert workflow.name == "declarative_workflow"
|
||||
|
||||
|
||||
class TestWorkflowFactoryAgentCreation:
|
||||
"""Tests for agent creation from definitions."""
|
||||
|
||||
def test_agent_creation_with_file_reference(self, tmp_path):
|
||||
"""Test creating agent from file reference."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from agent_framework_declarative import AgentFactory
|
||||
|
||||
# Create a minimal agent YAML file (using Prompt kind)
|
||||
agent_file = tmp_path / "test_agent.yaml"
|
||||
agent_file.write_text("""
|
||||
kind: Prompt
|
||||
name: TestAgent
|
||||
description: A test agent
|
||||
instructions: You are a test agent.
|
||||
""")
|
||||
|
||||
# Create a mock client and agent factory
|
||||
mock_client = MagicMock()
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.name = "TestAgent"
|
||||
mock_client.create_agent.return_value = mock_agent
|
||||
|
||||
agent_factory = AgentFactory(client=mock_client)
|
||||
|
||||
# Create workflow that references the agent
|
||||
workflow_file = tmp_path / "workflow.yaml"
|
||||
workflow_file.write_text(f"""
|
||||
kind: Workflow
|
||||
agents:
|
||||
TestAgent:
|
||||
file: {agent_file.name}
|
||||
actions:
|
||||
- kind: SetValue
|
||||
path: Local.x
|
||||
value: 1
|
||||
""")
|
||||
|
||||
factory = WorkflowFactory(agent_factory=agent_factory)
|
||||
workflow = factory.create_workflow_from_yaml_path(workflow_file)
|
||||
|
||||
assert workflow is not None
|
||||
assert "TestAgent" in workflow._declarative_agents
|
||||
|
||||
def test_agent_connection_definition_raises(self):
|
||||
"""Test that connection-based agent definition raises error."""
|
||||
factory = WorkflowFactory()
|
||||
with pytest.raises(DeclarativeWorkflowError, match="Connection-based agents"):
|
||||
factory.create_workflow_from_yaml("""
|
||||
kind: Workflow
|
||||
agents:
|
||||
MyAgent:
|
||||
connection: azure-connection
|
||||
actions:
|
||||
- kind: SetValue
|
||||
path: Local.x
|
||||
value: 1
|
||||
""")
|
||||
|
||||
def test_invalid_agent_definition_raises(self):
|
||||
"""Test that invalid agent definition raises error."""
|
||||
factory = WorkflowFactory()
|
||||
with pytest.raises(DeclarativeWorkflowError, match="Invalid agent definition"):
|
||||
factory.create_workflow_from_yaml("""
|
||||
kind: Workflow
|
||||
agents:
|
||||
MyAgent:
|
||||
unknown_field: value
|
||||
actions:
|
||||
- kind: SetValue
|
||||
path: Local.x
|
||||
value: 1
|
||||
""")
|
||||
|
||||
def test_preregistered_agent_not_overwritten(self):
|
||||
"""Test that pre-registered agents are not overwritten by definitions."""
|
||||
|
||||
class MockAgent:
|
||||
name = "PreregisteredAgent"
|
||||
|
||||
factory = WorkflowFactory(agents={"TestAgent": MockAgent()})
|
||||
workflow = factory.create_workflow_from_yaml("""
|
||||
kind: Workflow
|
||||
agents:
|
||||
TestAgent:
|
||||
kind: Agent
|
||||
name: OverrideAttempt
|
||||
actions:
|
||||
- kind: SetValue
|
||||
path: Local.x
|
||||
value: 1
|
||||
""")
|
||||
|
||||
assert workflow._declarative_agents["TestAgent"].name == "PreregisteredAgent"
|
||||
|
||||
|
||||
class TestWorkflowFactoryInputSchema:
|
||||
"""Tests for input schema conversion."""
|
||||
|
||||
def test_inputs_to_json_schema_basic(self):
|
||||
"""Test basic input schema conversion."""
|
||||
factory = WorkflowFactory()
|
||||
workflow = factory.create_workflow_from_yaml("""
|
||||
name: input-schema-test
|
||||
inputs:
|
||||
name:
|
||||
type: string
|
||||
description: The user's name
|
||||
age:
|
||||
type: integer
|
||||
description: The user's age
|
||||
actions:
|
||||
- kind: SetValue
|
||||
path: Local.x
|
||||
value: 1
|
||||
""")
|
||||
|
||||
schema = workflow.input_schema
|
||||
assert schema["type"] == "object"
|
||||
assert "name" in schema["properties"]
|
||||
assert "age" in schema["properties"]
|
||||
assert schema["properties"]["name"]["type"] == "string"
|
||||
assert schema["properties"]["age"]["type"] == "integer"
|
||||
assert "name" in schema["required"]
|
||||
assert "age" in schema["required"]
|
||||
|
||||
def test_inputs_schema_with_optional_field(self):
|
||||
"""Test input schema with optional field."""
|
||||
factory = WorkflowFactory()
|
||||
workflow = factory.create_workflow_from_yaml("""
|
||||
name: optional-input-test
|
||||
inputs:
|
||||
required_field:
|
||||
type: string
|
||||
required: true
|
||||
optional_field:
|
||||
type: string
|
||||
required: false
|
||||
actions:
|
||||
- kind: SetValue
|
||||
path: Local.x
|
||||
value: 1
|
||||
""")
|
||||
|
||||
schema = workflow.input_schema
|
||||
assert "required_field" in schema["required"]
|
||||
assert "optional_field" not in schema["required"]
|
||||
|
||||
def test_inputs_schema_with_default_value(self):
|
||||
"""Test input schema with default value."""
|
||||
factory = WorkflowFactory()
|
||||
workflow = factory.create_workflow_from_yaml("""
|
||||
name: default-input-test
|
||||
inputs:
|
||||
greeting:
|
||||
type: string
|
||||
default: Hello
|
||||
actions:
|
||||
- kind: SetValue
|
||||
path: Local.x
|
||||
value: 1
|
||||
""")
|
||||
|
||||
schema = workflow.input_schema
|
||||
assert schema["properties"]["greeting"]["default"] == "Hello"
|
||||
|
||||
def test_inputs_schema_with_enum(self):
|
||||
"""Test input schema with enum values."""
|
||||
factory = WorkflowFactory()
|
||||
workflow = factory.create_workflow_from_yaml("""
|
||||
name: enum-input-test
|
||||
inputs:
|
||||
color:
|
||||
type: string
|
||||
enum:
|
||||
- red
|
||||
- green
|
||||
- blue
|
||||
actions:
|
||||
- kind: SetValue
|
||||
path: Local.x
|
||||
value: 1
|
||||
""")
|
||||
|
||||
schema = workflow.input_schema
|
||||
assert schema["properties"]["color"]["enum"] == ["red", "green", "blue"]
|
||||
|
||||
def test_inputs_schema_type_mappings(self):
|
||||
"""Test various type mappings in input schema."""
|
||||
factory = WorkflowFactory()
|
||||
workflow = factory.create_workflow_from_yaml("""
|
||||
name: type-mapping-test
|
||||
inputs:
|
||||
str_field:
|
||||
type: str
|
||||
int_field:
|
||||
type: int
|
||||
float_field:
|
||||
type: float
|
||||
bool_field:
|
||||
type: bool
|
||||
list_field:
|
||||
type: list
|
||||
dict_field:
|
||||
type: dict
|
||||
actions:
|
||||
- kind: SetValue
|
||||
path: Local.x
|
||||
value: 1
|
||||
""")
|
||||
|
||||
schema = workflow.input_schema
|
||||
assert schema["properties"]["str_field"]["type"] == "string"
|
||||
assert schema["properties"]["int_field"]["type"] == "integer"
|
||||
assert schema["properties"]["float_field"]["type"] == "number"
|
||||
assert schema["properties"]["bool_field"]["type"] == "boolean"
|
||||
assert schema["properties"]["list_field"]["type"] == "array"
|
||||
assert schema["properties"]["dict_field"]["type"] == "object"
|
||||
|
||||
def test_inputs_schema_simple_format(self):
|
||||
"""Test simple input format (field: type)."""
|
||||
factory = WorkflowFactory()
|
||||
workflow = factory.create_workflow_from_yaml("""
|
||||
name: simple-input-test
|
||||
inputs:
|
||||
name: string
|
||||
count: integer
|
||||
actions:
|
||||
- kind: SetValue
|
||||
path: Local.x
|
||||
value: 1
|
||||
""")
|
||||
|
||||
schema = workflow.input_schema
|
||||
assert schema["properties"]["name"]["type"] == "string"
|
||||
assert schema["properties"]["count"]["type"] == "integer"
|
||||
assert "name" in schema["required"]
|
||||
assert "count" in schema["required"]
|
||||
|
||||
|
||||
class TestWorkflowFactoryChaining:
|
||||
"""Tests for fluent method chaining."""
|
||||
|
||||
def test_fluent_agent_registration(self):
|
||||
"""Test fluent agent registration."""
|
||||
|
||||
class MockAgent1:
|
||||
name = "Agent1"
|
||||
|
||||
class MockAgent2:
|
||||
name = "Agent2"
|
||||
|
||||
factory = WorkflowFactory().register_agent("agent1", MockAgent1()).register_agent("agent2", MockAgent2())
|
||||
|
||||
assert "agent1" in factory._agents
|
||||
assert "agent2" in factory._agents
|
||||
|
||||
def test_fluent_binding_registration(self):
|
||||
"""Test fluent binding registration."""
|
||||
|
||||
def func1():
|
||||
return 1
|
||||
|
||||
def func2():
|
||||
return 2
|
||||
|
||||
factory = WorkflowFactory().register_binding("func1", func1).register_binding("func2", func2)
|
||||
|
||||
assert "func1" in factory._bindings
|
||||
assert "func2" in factory._bindings
|
||||
|
||||
def test_fluent_mixed_registration(self):
|
||||
"""Test mixed fluent registration."""
|
||||
|
||||
class MockAgent:
|
||||
name = "Agent"
|
||||
|
||||
def my_tool():
|
||||
return "tool"
|
||||
|
||||
def my_binding():
|
||||
return "binding"
|
||||
|
||||
factory = (
|
||||
WorkflowFactory()
|
||||
.register_agent("agent", MockAgent())
|
||||
.register_tool("tool", my_tool)
|
||||
.register_binding("binding", my_binding)
|
||||
)
|
||||
|
||||
assert "agent" in factory._agents
|
||||
assert "tool" in factory._tools
|
||||
assert "binding" in factory._bindings
|
||||
|
||||
@@ -225,6 +225,335 @@ class TestWorkflowStateResetTurn:
|
||||
assert state.get("Workflow.Outputs.result") == "done"
|
||||
|
||||
|
||||
class TestWorkflowStateEvalSimple:
|
||||
"""Tests for _eval_simple fallback PowerFx evaluation."""
|
||||
|
||||
def test_negation_prefix(self):
|
||||
"""Test negation with ! prefix."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.value", True)
|
||||
assert state._eval_simple("!Local.value") is False
|
||||
state.set("Local.value", False)
|
||||
assert state._eval_simple("!Local.value") is True
|
||||
|
||||
def test_not_function(self):
|
||||
"""Test Not() function."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.flag", True)
|
||||
assert state._eval_simple("Not(Local.flag)") is False
|
||||
state.set("Local.flag", False)
|
||||
assert state._eval_simple("Not(Local.flag)") is True
|
||||
|
||||
def test_and_operator(self):
|
||||
"""Test And operator."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.a", True)
|
||||
state.set("Local.b", True)
|
||||
assert state._eval_simple("Local.a And Local.b") is True
|
||||
state.set("Local.b", False)
|
||||
assert state._eval_simple("Local.a And Local.b") is False
|
||||
|
||||
def test_or_operator(self):
|
||||
"""Test Or operator."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.a", False)
|
||||
state.set("Local.b", False)
|
||||
assert state._eval_simple("Local.a Or Local.b") is False
|
||||
state.set("Local.b", True)
|
||||
assert state._eval_simple("Local.a Or Local.b") is True
|
||||
|
||||
def test_or_operator_double_pipe(self):
|
||||
"""Test || operator."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.x", False)
|
||||
state.set("Local.y", True)
|
||||
assert state._eval_simple("Local.x || Local.y") is True
|
||||
|
||||
def test_less_than(self):
|
||||
"""Test < comparison."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.num", 5)
|
||||
assert state._eval_simple("Local.num < 10") is True
|
||||
assert state._eval_simple("Local.num < 3") is False
|
||||
|
||||
def test_greater_than(self):
|
||||
"""Test > comparison."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.num", 5)
|
||||
assert state._eval_simple("Local.num > 3") is True
|
||||
assert state._eval_simple("Local.num > 10") is False
|
||||
|
||||
def test_less_than_or_equal(self):
|
||||
"""Test <= comparison."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.num", 5)
|
||||
assert state._eval_simple("Local.num <= 5") is True
|
||||
assert state._eval_simple("Local.num <= 4") is False
|
||||
|
||||
def test_greater_than_or_equal(self):
|
||||
"""Test >= comparison."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.num", 5)
|
||||
assert state._eval_simple("Local.num >= 5") is True
|
||||
assert state._eval_simple("Local.num >= 6") is False
|
||||
|
||||
def test_not_equal(self):
|
||||
"""Test <> comparison."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.val", "hello")
|
||||
assert state._eval_simple('Local.val <> "world"') is True
|
||||
assert state._eval_simple('Local.val <> "hello"') is False
|
||||
|
||||
def test_equal(self):
|
||||
"""Test = comparison."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.val", "test")
|
||||
assert state._eval_simple('Local.val = "test"') is True
|
||||
assert state._eval_simple('Local.val = "other"') is False
|
||||
|
||||
def test_addition_numeric(self):
|
||||
"""Test + operator with numbers."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.a", 3)
|
||||
state.set("Local.b", 4)
|
||||
assert state._eval_simple("Local.a + Local.b") == 7.0
|
||||
|
||||
def test_addition_string_concat(self):
|
||||
"""Test + operator falls back to string concat."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.a", "hello")
|
||||
state.set("Local.b", "world")
|
||||
assert state._eval_simple("Local.a + Local.b") == "helloworld"
|
||||
|
||||
def test_addition_with_none(self):
|
||||
"""Test + treats None as 0."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.a", 5)
|
||||
# Local.b doesn't exist, so it's None
|
||||
assert state._eval_simple("Local.a + Local.b") == 5.0
|
||||
|
||||
def test_subtraction(self):
|
||||
"""Test - operator."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.a", 10)
|
||||
state.set("Local.b", 3)
|
||||
assert state._eval_simple("Local.a - Local.b") == 7.0
|
||||
|
||||
def test_subtraction_with_none(self):
|
||||
"""Test - treats None as 0."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.a", 5)
|
||||
assert state._eval_simple("Local.a - Local.missing") == 5.0
|
||||
|
||||
def test_multiplication(self):
|
||||
"""Test * operator."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.a", 4)
|
||||
state.set("Local.b", 5)
|
||||
assert state._eval_simple("Local.a * Local.b") == 20.0
|
||||
|
||||
def test_multiplication_with_none(self):
|
||||
"""Test * treats None as 0."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.a", 5)
|
||||
assert state._eval_simple("Local.a * Local.missing") == 0.0
|
||||
|
||||
def test_division(self):
|
||||
"""Test / operator."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.a", 20)
|
||||
state.set("Local.b", 4)
|
||||
assert state._eval_simple("Local.a / Local.b") == 5.0
|
||||
|
||||
def test_division_by_zero(self):
|
||||
"""Test / by zero returns None."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.a", 10)
|
||||
state.set("Local.b", 0)
|
||||
assert state._eval_simple("Local.a / Local.b") is None
|
||||
|
||||
def test_string_literal_double_quotes(self):
|
||||
"""Test string literal with double quotes."""
|
||||
state = WorkflowState()
|
||||
assert state._eval_simple('"hello world"') == "hello world"
|
||||
|
||||
def test_string_literal_single_quotes(self):
|
||||
"""Test string literal with single quotes."""
|
||||
state = WorkflowState()
|
||||
assert state._eval_simple("'hello world'") == "hello world"
|
||||
|
||||
def test_integer_literal(self):
|
||||
"""Test integer literal."""
|
||||
state = WorkflowState()
|
||||
assert state._eval_simple("42") == 42
|
||||
|
||||
def test_float_literal(self):
|
||||
"""Test float literal."""
|
||||
state = WorkflowState()
|
||||
assert state._eval_simple("3.14") == 3.14
|
||||
|
||||
def test_boolean_true_literal(self):
|
||||
"""Test true literal (case insensitive)."""
|
||||
state = WorkflowState()
|
||||
assert state._eval_simple("true") is True
|
||||
assert state._eval_simple("True") is True
|
||||
assert state._eval_simple("TRUE") is True
|
||||
|
||||
def test_boolean_false_literal(self):
|
||||
"""Test false literal (case insensitive)."""
|
||||
state = WorkflowState()
|
||||
assert state._eval_simple("false") is False
|
||||
assert state._eval_simple("False") is False
|
||||
assert state._eval_simple("FALSE") is False
|
||||
|
||||
def test_variable_reference(self):
|
||||
"""Test simple variable reference."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.myvar", "myvalue")
|
||||
assert state._eval_simple("Local.myvar") == "myvalue"
|
||||
|
||||
def test_unknown_expression_returned_as_is(self):
|
||||
"""Test that unknown expressions are returned as-is."""
|
||||
state = WorkflowState()
|
||||
result = state._eval_simple("unknown_identifier")
|
||||
assert result == "unknown_identifier"
|
||||
|
||||
def test_agent_namespace_reference(self):
|
||||
"""Test Agent namespace variable reference."""
|
||||
state = WorkflowState()
|
||||
state.set_agent_result(text="agent response")
|
||||
assert state._eval_simple("Agent.text") == "agent response"
|
||||
|
||||
def test_conversation_namespace_reference(self):
|
||||
"""Test Conversation namespace variable reference."""
|
||||
state = WorkflowState()
|
||||
state.add_conversation_message({"role": "user", "content": "hello"})
|
||||
result = state._eval_simple("Conversation.messages")
|
||||
assert len(result) == 1
|
||||
|
||||
def test_workflow_inputs_reference(self):
|
||||
"""Test Workflow.Inputs reference."""
|
||||
state = WorkflowState(inputs={"name": "test"})
|
||||
assert state._eval_simple("Workflow.Inputs.name") == "test"
|
||||
|
||||
|
||||
class TestWorkflowStateParseFunctionArgs:
|
||||
"""Tests for _parse_function_args helper."""
|
||||
|
||||
def test_simple_args(self):
|
||||
"""Test parsing simple comma-separated args."""
|
||||
state = WorkflowState()
|
||||
args = state._parse_function_args("1, 2, 3")
|
||||
assert args == ["1", "2", "3"]
|
||||
|
||||
def test_string_args_with_commas(self):
|
||||
"""Test parsing string args containing commas."""
|
||||
state = WorkflowState()
|
||||
args = state._parse_function_args('"hello, world", "another"')
|
||||
assert args == ['"hello, world"', '"another"']
|
||||
|
||||
def test_nested_function_args(self):
|
||||
"""Test parsing nested function calls."""
|
||||
state = WorkflowState()
|
||||
args = state._parse_function_args("Concat(a, b), c")
|
||||
assert args == ["Concat(a, b)", "c"]
|
||||
|
||||
def test_empty_args(self):
|
||||
"""Test parsing empty args string."""
|
||||
state = WorkflowState()
|
||||
args = state._parse_function_args("")
|
||||
assert args == []
|
||||
|
||||
def test_single_arg(self):
|
||||
"""Test parsing single argument."""
|
||||
state = WorkflowState()
|
||||
args = state._parse_function_args("single")
|
||||
assert args == ["single"]
|
||||
|
||||
def test_deeply_nested_parens(self):
|
||||
"""Test parsing deeply nested parentheses."""
|
||||
state = WorkflowState()
|
||||
args = state._parse_function_args("Func1(Func2(a, b)), c")
|
||||
assert args == ["Func1(Func2(a, b))", "c"]
|
||||
|
||||
|
||||
class TestWorkflowStateEvalIfExpression:
|
||||
"""Tests for eval_if_expression method."""
|
||||
|
||||
def test_dict_values_evaluated(self):
|
||||
"""Test that dict values are recursively evaluated."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.name", "World")
|
||||
result = state.eval_if_expression({"greeting": "=Local.name", "static": "value"})
|
||||
assert result == {"greeting": "World", "static": "value"}
|
||||
|
||||
def test_list_values_evaluated(self):
|
||||
"""Test that list values are recursively evaluated."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.val", 42)
|
||||
result = state.eval_if_expression(["=Local.val", "static"])
|
||||
assert result == [42, "static"]
|
||||
|
||||
def test_nested_dict_in_list(self):
|
||||
"""Test nested dict in list is evaluated."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.x", 10)
|
||||
result = state.eval_if_expression([{"key": "=Local.x"}])
|
||||
assert result == [{"key": 10}]
|
||||
|
||||
|
||||
class TestWorkflowStateSetErrors:
|
||||
"""Tests for set() error handling."""
|
||||
|
||||
def test_set_workflow_directly_raises(self):
|
||||
"""Test that setting Workflow directly raises error."""
|
||||
state = WorkflowState()
|
||||
with pytest.raises(ValueError, match="Cannot set 'Workflow' directly"):
|
||||
state.set("Workflow", "value")
|
||||
|
||||
def test_set_unknown_workflow_namespace_raises(self):
|
||||
"""Test that setting unknown Workflow sub-namespace raises."""
|
||||
state = WorkflowState()
|
||||
with pytest.raises(ValueError, match="Unknown Workflow namespace"):
|
||||
state.set("Workflow.Unknown.path", "value")
|
||||
|
||||
def test_set_namespace_root_raises(self):
|
||||
"""Test that setting namespace root raises error."""
|
||||
state = WorkflowState()
|
||||
with pytest.raises(ValueError, match="Cannot replace entire namespace"):
|
||||
state.set("Local", "value")
|
||||
|
||||
|
||||
class TestWorkflowStateGetEdgeCases:
|
||||
"""Tests for get() edge cases."""
|
||||
|
||||
def test_get_empty_path(self):
|
||||
"""Test get with empty path returns default."""
|
||||
state = WorkflowState()
|
||||
assert state.get("", "default") == "default"
|
||||
|
||||
def test_get_unknown_namespace(self):
|
||||
"""Test get from unknown namespace returns default."""
|
||||
state = WorkflowState()
|
||||
assert state.get("Unknown.path") is None
|
||||
assert state.get("Unknown.path", "fallback") == "fallback"
|
||||
|
||||
def test_get_with_object_attribute(self):
|
||||
"""Test get navigates object attributes."""
|
||||
state = WorkflowState()
|
||||
|
||||
class MockObj:
|
||||
attr = "attribute_value"
|
||||
|
||||
state.set("Local.obj", MockObj())
|
||||
assert state.get("Local.obj.attr") == "attribute_value"
|
||||
|
||||
def test_get_unknown_workflow_subspace(self):
|
||||
"""Test get from unknown Workflow sub-namespace."""
|
||||
state = WorkflowState()
|
||||
assert state.get("Workflow.Unknown.path") is None
|
||||
|
||||
|
||||
class TestWorkflowStateConversationIdInit:
|
||||
"""Tests that WorkflowState generates a real UUID for System.ConversationId."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user