mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add CreateConversationExecutor, fix input routing, remove unused handler layer (#4159)
* Fixed declarative deep research sample * Small fix * Resolved comment * Add CreateConversationExecutor, fix input routing, remove unused handler layer * Address Copilot feedback * Fix System.ConversationId --------- Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
bb4fe48c9a
commit
de612c47f5
@@ -40,6 +40,7 @@ from ._executors_basic import (
|
||||
BASIC_ACTION_EXECUTORS,
|
||||
AppendValueExecutor,
|
||||
ClearAllVariablesExecutor,
|
||||
CreateConversationExecutor,
|
||||
EmitEventExecutor,
|
||||
ResetVariableExecutor,
|
||||
SendActivityExecutor,
|
||||
@@ -68,13 +69,6 @@ from ._executors_external_input import (
|
||||
WaitForInputExecutor,
|
||||
)
|
||||
from ._factory import DeclarativeWorkflowError, WorkflowFactory
|
||||
from ._handlers import ActionHandler, action_handler, get_action_handler
|
||||
from ._human_input import (
|
||||
ExternalLoopEvent,
|
||||
QuestionRequest,
|
||||
process_external_loop,
|
||||
validate_input_response,
|
||||
)
|
||||
from ._state import WorkflowState
|
||||
|
||||
__all__ = [
|
||||
@@ -87,7 +81,6 @@ __all__ = [
|
||||
"EXTERNAL_INPUT_EXECUTORS",
|
||||
"TOOL_REGISTRY_KEY",
|
||||
"ActionComplete",
|
||||
"ActionHandler",
|
||||
"ActionTrigger",
|
||||
"AgentExternalInputRequest",
|
||||
"AgentExternalInputResponse",
|
||||
@@ -98,6 +91,7 @@ __all__ = [
|
||||
"ConfirmationExecutor",
|
||||
"ContinueLoopExecutor",
|
||||
"ConversationData",
|
||||
"CreateConversationExecutor",
|
||||
"DeclarativeActionExecutor",
|
||||
"DeclarativeMessage",
|
||||
"DeclarativeStateData",
|
||||
@@ -109,7 +103,6 @@ __all__ = [
|
||||
"EndWorkflowExecutor",
|
||||
"ExternalInputRequest",
|
||||
"ExternalInputResponse",
|
||||
"ExternalLoopEvent",
|
||||
"ExternalLoopState",
|
||||
"ForeachInitExecutor",
|
||||
"ForeachNextExecutor",
|
||||
@@ -119,7 +112,6 @@ __all__ = [
|
||||
"LoopControl",
|
||||
"LoopIterationResult",
|
||||
"QuestionExecutor",
|
||||
"QuestionRequest",
|
||||
"RequestExternalInputExecutor",
|
||||
"ResetVariableExecutor",
|
||||
"SendActivityExecutor",
|
||||
@@ -130,8 +122,4 @@ __all__ = [
|
||||
"WaitForInputExecutor",
|
||||
"WorkflowFactory",
|
||||
"WorkflowState",
|
||||
"action_handler",
|
||||
"get_action_handler",
|
||||
"process_external_loop",
|
||||
"validate_input_response",
|
||||
]
|
||||
|
||||
-652
@@ -1,652 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Agent invocation action handlers for declarative workflows.
|
||||
|
||||
This module implements handlers for:
|
||||
- InvokeAzureAgent: Invoke a hosted Azure AI agent
|
||||
- InvokePromptAgent: Invoke a local prompt-based agent
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, cast
|
||||
|
||||
from agent_framework._types import AgentResponse, Message
|
||||
|
||||
from ._handlers import (
|
||||
ActionContext,
|
||||
AgentResponseEvent,
|
||||
AgentStreamingChunkEvent,
|
||||
WorkflowEvent,
|
||||
action_handler,
|
||||
)
|
||||
from ._human_input import ExternalLoopEvent, QuestionRequest
|
||||
|
||||
logger = logging.getLogger("agent_framework.declarative")
|
||||
|
||||
|
||||
def _extract_json_from_response(text: str) -> Any:
|
||||
r"""Extract and parse JSON from an agent response.
|
||||
|
||||
Agents often return JSON wrapped in markdown code blocks or with
|
||||
explanatory text. This function attempts to extract and parse the
|
||||
JSON content from various formats:
|
||||
|
||||
1. Pure JSON: {"key": "value"}
|
||||
2. Markdown code block: ```json\n{"key": "value"}\n```
|
||||
3. Markdown code block (no language): ```\n{"key": "value"}\n```
|
||||
4. JSON with leading/trailing text: Here's the result: {"key": "value"}
|
||||
5. Multiple JSON objects: Returns the LAST valid JSON object
|
||||
|
||||
When multiple JSON objects are present (e.g., streaming agent responses
|
||||
that emit partial then final results), this returns the last complete
|
||||
JSON object, which is typically the final/complete result.
|
||||
|
||||
Args:
|
||||
text: The raw text response from an agent
|
||||
|
||||
Returns:
|
||||
Parsed JSON as a Python dict/list, or None if parsing fails
|
||||
|
||||
Raises:
|
||||
json.JSONDecodeError: If no valid JSON can be extracted
|
||||
"""
|
||||
import re
|
||||
|
||||
if not text:
|
||||
return None
|
||||
|
||||
text = text.strip()
|
||||
|
||||
if not text:
|
||||
return None
|
||||
|
||||
# Try parsing as pure JSON first
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try extracting from markdown code blocks: ```json ... ``` or ``` ... ```
|
||||
# Use the last code block if there are multiple
|
||||
code_block_patterns = [
|
||||
r"```json\s*\n?(.*?)\n?```", # ```json ... ```
|
||||
r"```\s*\n?(.*?)\n?```", # ``` ... ```
|
||||
]
|
||||
for pattern in code_block_patterns:
|
||||
matches = list(re.finditer(pattern, text, re.DOTALL))
|
||||
if matches:
|
||||
# Try the last match first (most likely to be the final result)
|
||||
for match in reversed(matches):
|
||||
try:
|
||||
return json.loads(match.group(1).strip())
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# Find ALL JSON objects {...} or arrays [...] in the text and return the last valid one
|
||||
# This handles cases where agents stream multiple JSON objects (partial, then final)
|
||||
all_json_objects: list[Any] = []
|
||||
|
||||
pos = 0
|
||||
while pos < len(text):
|
||||
# Find next { or [
|
||||
json_start = -1
|
||||
bracket_char = None
|
||||
for i in range(pos, len(text)):
|
||||
if text[i] == "{":
|
||||
json_start = i
|
||||
bracket_char = "{"
|
||||
break
|
||||
if text[i] == "[":
|
||||
json_start = i
|
||||
bracket_char = "["
|
||||
break
|
||||
|
||||
if json_start < 0:
|
||||
break # No more JSON objects
|
||||
|
||||
# Find matching closing bracket
|
||||
open_bracket = bracket_char
|
||||
close_bracket = "}" if open_bracket == "{" else "]"
|
||||
depth = 0
|
||||
in_string = False
|
||||
escape_next = False
|
||||
found_end = False
|
||||
|
||||
for i in range(json_start, len(text)):
|
||||
char = text[i]
|
||||
|
||||
if escape_next:
|
||||
escape_next = False
|
||||
continue
|
||||
|
||||
if char == "\\":
|
||||
escape_next = True
|
||||
continue
|
||||
|
||||
if char == '"' and not escape_next:
|
||||
in_string = not in_string
|
||||
continue
|
||||
|
||||
if in_string:
|
||||
continue
|
||||
|
||||
if char == open_bracket:
|
||||
depth += 1
|
||||
elif char == close_bracket:
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
# Found the end
|
||||
potential_json = text[json_start : i + 1]
|
||||
try:
|
||||
parsed = json.loads(potential_json)
|
||||
all_json_objects.append(parsed)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
pos = i + 1
|
||||
found_end = True
|
||||
break
|
||||
|
||||
if not found_end:
|
||||
# Malformed JSON, move past the start character
|
||||
pos = json_start + 1
|
||||
|
||||
# Return the last valid JSON object (most likely to be the final/complete result)
|
||||
if all_json_objects:
|
||||
return all_json_objects[-1]
|
||||
|
||||
# Unable to extract JSON
|
||||
raise json.JSONDecodeError("No valid JSON found in response", text, 0)
|
||||
|
||||
|
||||
def _build_messages_from_state(ctx: ActionContext) -> list[Message]:
|
||||
"""Build the message list to send to an agent.
|
||||
|
||||
This collects messages from:
|
||||
1. Conversation history
|
||||
2. Current input (if first agent call)
|
||||
3. Additional context from instructions
|
||||
|
||||
Args:
|
||||
ctx: The action context
|
||||
|
||||
Returns:
|
||||
List of Message objects to send to the agent
|
||||
"""
|
||||
messages: list[Message] = []
|
||||
|
||||
# Get conversation history
|
||||
history = ctx.state.get("conversation.messages", [])
|
||||
if history:
|
||||
messages.extend(history)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
@action_handler("InvokeAzureAgent")
|
||||
async def handle_invoke_azure_agent(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent]:
|
||||
"""Invoke a hosted Azure AI agent.
|
||||
|
||||
Supports both Python-style and .NET-style YAML schemas:
|
||||
|
||||
Python-style schema:
|
||||
kind: InvokeAzureAgent
|
||||
agent: agentName
|
||||
input: =expression or literal input
|
||||
outputPath: Local.response
|
||||
|
||||
.NET-style schema:
|
||||
kind: InvokeAzureAgent
|
||||
agent:
|
||||
name: AgentName
|
||||
conversationId: =System.ConversationId
|
||||
input:
|
||||
arguments:
|
||||
param1: value1
|
||||
messages: =expression
|
||||
output:
|
||||
messages: Local.Response
|
||||
responseObject: Local.StructuredResponse
|
||||
"""
|
||||
# Get agent name - support both formats
|
||||
agent_config: dict[str, Any] | str | None = ctx.action.get("agent")
|
||||
agent_name: str | None = None
|
||||
if isinstance(agent_config, dict):
|
||||
agent_name = str(agent_config.get("name")) if agent_config.get("name") else None
|
||||
# Support dynamic agent name from expression
|
||||
if agent_name and isinstance(agent_name, str) and agent_name.startswith("="):
|
||||
evaluated = ctx.state.eval_if_expression(agent_name)
|
||||
agent_name = str(evaluated) if evaluated is not None else None
|
||||
elif isinstance(agent_config, str):
|
||||
agent_name = agent_config
|
||||
|
||||
if not agent_name:
|
||||
logger.warning("InvokeAzureAgent action missing 'agent' or 'agent.name' property")
|
||||
return
|
||||
|
||||
# Get input configuration
|
||||
input_config: dict[str, Any] | Any = ctx.action.get("input", {})
|
||||
input_arguments: dict[str, Any] = {}
|
||||
input_messages: Any = None
|
||||
external_loop_when: str | None = None
|
||||
if isinstance(input_config, dict):
|
||||
input_config_typed = cast(dict[str, Any], input_config)
|
||||
input_arguments = cast(dict[str, Any], input_config_typed.get("arguments") or {})
|
||||
input_messages = input_config_typed.get("messages")
|
||||
# Extract external loop configuration
|
||||
external_loop = input_config_typed.get("externalLoop")
|
||||
if isinstance(external_loop, dict):
|
||||
external_loop_typed = cast(dict[str, Any], external_loop)
|
||||
external_loop_when = str(external_loop_typed.get("when")) if external_loop_typed.get("when") else None
|
||||
else:
|
||||
input_messages = input_config # Treat as message directly
|
||||
|
||||
# Get output configuration (.NET style)
|
||||
output_config: dict[str, Any] | Any = ctx.action.get("output", {})
|
||||
output_messages_var: str | None = None
|
||||
output_response_obj_var: str | None = None
|
||||
if isinstance(output_config, dict):
|
||||
output_config_typed = cast(dict[str, Any], output_config)
|
||||
output_messages_var = str(output_config_typed.get("messages")) if output_config_typed.get("messages") else None
|
||||
output_response_obj_var = (
|
||||
str(output_config_typed.get("responseObject")) if output_config_typed.get("responseObject") else None
|
||||
)
|
||||
# auto_send is defined but not used currently
|
||||
_auto_send: bool = bool(output_config_typed.get("autoSend", True))
|
||||
|
||||
# Legacy Python style output path
|
||||
output_path = ctx.action.get("outputPath")
|
||||
|
||||
# Other properties
|
||||
conversation_id = ctx.action.get("conversationId")
|
||||
instructions = ctx.action.get("instructions")
|
||||
tools_config: list[dict[str, Any]] = ctx.action.get("tools", [])
|
||||
|
||||
# Get the agent from registry
|
||||
agent = ctx.agents.get(agent_name)
|
||||
if agent is None:
|
||||
logger.error(f"InvokeAzureAgent: agent '{agent_name}' not found in registry")
|
||||
return
|
||||
|
||||
# Evaluate conversation ID
|
||||
if conversation_id:
|
||||
evaluated_conv_id = ctx.state.eval_if_expression(conversation_id)
|
||||
ctx.state.set("System.ConversationId", evaluated_conv_id)
|
||||
|
||||
# Evaluate instructions (unused currently but may be used for prompting)
|
||||
_ = ctx.state.eval_if_expression(instructions) if instructions else None
|
||||
|
||||
# Build messages
|
||||
messages = _build_messages_from_state(ctx)
|
||||
|
||||
# Handle input messages from .NET style
|
||||
if input_messages:
|
||||
evaluated_input = ctx.state.eval_if_expression(input_messages)
|
||||
if evaluated_input:
|
||||
if isinstance(evaluated_input, str):
|
||||
messages.append(Message(role="user", text=evaluated_input))
|
||||
elif isinstance(evaluated_input, list):
|
||||
for msg_item in evaluated_input: # type: ignore
|
||||
if isinstance(msg_item, str):
|
||||
messages.append(Message(role="user", text=msg_item))
|
||||
elif isinstance(msg_item, Message):
|
||||
messages.append(msg_item)
|
||||
elif isinstance(msg_item, dict) and "content" in msg_item:
|
||||
item_dict = cast(dict[str, Any], msg_item)
|
||||
role: str = str(item_dict.get("role", "user"))
|
||||
content: str = str(item_dict.get("content", ""))
|
||||
if role == "user":
|
||||
messages.append(Message(role="user", text=content))
|
||||
elif role == "assistant":
|
||||
messages.append(Message(role="assistant", text=content))
|
||||
elif role == "system":
|
||||
messages.append(Message(role="system", text=content))
|
||||
|
||||
# Evaluate and include input arguments
|
||||
evaluated_args: dict[str, Any] = {}
|
||||
for arg_key, arg_value in input_arguments.items():
|
||||
evaluated_args[arg_key] = ctx.state.eval_if_expression(arg_value)
|
||||
|
||||
# Prepare tool bindings
|
||||
tool_bindings: dict[str, dict[str, Any]] = {}
|
||||
for tool_config in tools_config:
|
||||
tool_name: str | None = str(tool_config.get("name")) if tool_config.get("name") else None
|
||||
bindings: list[dict[str, Any]] = list(tool_config.get("bindings", [])) # type: ignore[arg-type]
|
||||
if tool_name and bindings:
|
||||
tool_bindings[tool_name] = {
|
||||
str(b.get("name")): ctx.state.eval_if_expression(b.get("input")) for b in bindings if b.get("name")
|
||||
}
|
||||
|
||||
logger.debug(f"InvokeAzureAgent: calling '{agent_name}' with {len(messages)} messages")
|
||||
|
||||
# External loop iteration counter
|
||||
iteration = 0
|
||||
max_iterations = 100 # Safety limit
|
||||
|
||||
# Start external loop if configured
|
||||
# Build options for kwargs propagation to agent tools
|
||||
run_kwargs = ctx.run_kwargs
|
||||
options: dict[str, Any] | None = None
|
||||
if run_kwargs:
|
||||
# Merge caller-provided options to avoid duplicate keyword argument
|
||||
options = dict(run_kwargs.get("options") or {})
|
||||
options["additional_function_arguments"] = run_kwargs
|
||||
# Exclude 'options' from splat to avoid TypeError on duplicate keyword
|
||||
run_kwargs = {k: v for k, v in run_kwargs.items() if k != "options"}
|
||||
|
||||
while True:
|
||||
# Invoke the agent
|
||||
try:
|
||||
# Agents use run() with stream parameter
|
||||
if hasattr(agent, "run"):
|
||||
# Try streaming first
|
||||
try:
|
||||
updates: list[Any] = []
|
||||
tool_calls: list[Any] = []
|
||||
|
||||
async for chunk in agent.run(messages, stream=True, options=options, **run_kwargs):
|
||||
updates.append(chunk)
|
||||
|
||||
# Yield streaming events for text chunks
|
||||
if hasattr(chunk, "text") and chunk.text:
|
||||
yield AgentStreamingChunkEvent(
|
||||
agent_name=str(agent_name),
|
||||
chunk=chunk.text,
|
||||
)
|
||||
|
||||
# Collect tool calls
|
||||
if hasattr(chunk, "tool_calls"):
|
||||
tool_calls.extend(chunk.tool_calls)
|
||||
|
||||
# Build consolidated response from updates
|
||||
response = AgentResponse.from_updates(updates)
|
||||
text = response.text
|
||||
response_messages = response.messages
|
||||
|
||||
# Update state with result
|
||||
ctx.state.set_agent_result(
|
||||
text=text,
|
||||
messages=response_messages,
|
||||
tool_calls=tool_calls if tool_calls else None,
|
||||
)
|
||||
|
||||
# Add to conversation history
|
||||
if text:
|
||||
ctx.state.add_conversation_message(Message(role="assistant", text=text))
|
||||
|
||||
# Store in output variables (.NET style)
|
||||
if output_messages_var:
|
||||
output_path_mapped = _normalize_variable_path(output_messages_var)
|
||||
ctx.state.set(output_path_mapped, response_messages if response_messages else text)
|
||||
|
||||
if output_response_obj_var:
|
||||
output_path_mapped = _normalize_variable_path(output_response_obj_var)
|
||||
# Try to extract and parse JSON from the response
|
||||
try:
|
||||
parsed = _extract_json_from_response(text) if text else None
|
||||
logger.debug(
|
||||
f"InvokeAzureAgent (streaming): parsed responseObject for "
|
||||
f"'{output_path_mapped}': type={type(parsed).__name__}, "
|
||||
f"value_preview={str(parsed)[:100] if parsed else None}"
|
||||
)
|
||||
ctx.state.set(output_path_mapped, parsed)
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
logger.warning(
|
||||
f"InvokeAzureAgent (streaming): failed to parse JSON for "
|
||||
f"'{output_path_mapped}': {e}, text_preview={text[:100] if text else None}"
|
||||
)
|
||||
ctx.state.set(output_path_mapped, text)
|
||||
|
||||
# Store in output path (Python style)
|
||||
if output_path:
|
||||
ctx.state.set(output_path, text)
|
||||
|
||||
yield AgentResponseEvent(
|
||||
agent_name=str(agent_name),
|
||||
text=text,
|
||||
messages=response_messages,
|
||||
tool_calls=tool_calls if tool_calls else None,
|
||||
)
|
||||
|
||||
except TypeError:
|
||||
# Agent doesn't support streaming, fall back to non-streaming
|
||||
response = await agent.run(messages, options=options, **run_kwargs)
|
||||
|
||||
text = response.text
|
||||
response_messages = response.messages
|
||||
response_tool_calls: list[Any] | None = getattr(response, "tool_calls", None)
|
||||
|
||||
# Update state with result
|
||||
ctx.state.set_agent_result(
|
||||
text=text,
|
||||
messages=response_messages,
|
||||
tool_calls=response_tool_calls,
|
||||
)
|
||||
|
||||
# Add to conversation history
|
||||
if text:
|
||||
ctx.state.add_conversation_message(Message(role="assistant", text=text))
|
||||
|
||||
# Store in output variables (.NET style)
|
||||
if output_messages_var:
|
||||
output_path_mapped = _normalize_variable_path(output_messages_var)
|
||||
ctx.state.set(output_path_mapped, response_messages if response_messages else text)
|
||||
|
||||
if output_response_obj_var:
|
||||
output_path_mapped = _normalize_variable_path(output_response_obj_var)
|
||||
try:
|
||||
parsed = _extract_json_from_response(text) if text else None
|
||||
logger.debug(
|
||||
f"InvokeAzureAgent (non-streaming): parsed responseObject for "
|
||||
f"'{output_path_mapped}': type={type(parsed).__name__}, "
|
||||
f"value_preview={str(parsed)[:100] if parsed else None}"
|
||||
)
|
||||
ctx.state.set(output_path_mapped, parsed)
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
logger.warning(
|
||||
f"InvokeAzureAgent (non-streaming): failed to parse JSON for "
|
||||
f"'{output_path_mapped}': {e}, text_preview={text[:100] if text else None}"
|
||||
)
|
||||
ctx.state.set(output_path_mapped, text)
|
||||
|
||||
# Store in output path (Python style)
|
||||
if output_path:
|
||||
ctx.state.set(output_path, text)
|
||||
|
||||
yield AgentResponseEvent(
|
||||
agent_name=str(agent_name),
|
||||
text=text,
|
||||
messages=response_messages,
|
||||
tool_calls=response_tool_calls,
|
||||
)
|
||||
else:
|
||||
logger.error(f"InvokeAzureAgent: agent '{agent_name}' has no run method")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"InvokeAzureAgent: error invoking agent '{agent_name}': {e}")
|
||||
raise
|
||||
|
||||
# Check external loop condition
|
||||
if external_loop_when:
|
||||
# Evaluate the loop condition
|
||||
should_continue = ctx.state.eval(external_loop_when)
|
||||
should_continue = bool(should_continue) if should_continue is not None else False
|
||||
|
||||
logger.debug(
|
||||
f"InvokeAzureAgent: external loop condition '{str(external_loop_when)[:50]}' = "
|
||||
f"{should_continue} (iteration {iteration})"
|
||||
)
|
||||
|
||||
if should_continue and iteration < max_iterations:
|
||||
# Emit event to signal waiting for external input
|
||||
action_id: str = str(ctx.action.get("id", f"agent_{agent_name}"))
|
||||
yield ExternalLoopEvent(
|
||||
action_id=action_id,
|
||||
iteration=iteration,
|
||||
condition_expression=str(external_loop_when),
|
||||
)
|
||||
|
||||
# The workflow executor should:
|
||||
# 1. Pause execution
|
||||
# 2. Wait for external input
|
||||
# 3. Update state with input
|
||||
# 4. Resume this generator
|
||||
|
||||
# For now, we request input via QuestionRequest
|
||||
yield QuestionRequest(
|
||||
request_id=f"{action_id}_input_{iteration}",
|
||||
prompt="Waiting for user input...",
|
||||
variable="Local.userInput",
|
||||
)
|
||||
|
||||
iteration += 1
|
||||
|
||||
# Clear messages for next iteration (start fresh with conversation)
|
||||
messages = _build_messages_from_state(ctx)
|
||||
continue
|
||||
elif iteration >= max_iterations:
|
||||
logger.warning(f"InvokeAzureAgent: external loop exceeded max iterations ({max_iterations})")
|
||||
|
||||
# No external loop or condition is false - exit
|
||||
break
|
||||
|
||||
|
||||
def _normalize_variable_path(variable: str) -> str:
|
||||
"""Normalize variable names to ensure they have a scope prefix.
|
||||
|
||||
Args:
|
||||
variable: Variable name like 'Local.X' or 'System.ConversationId'
|
||||
|
||||
Returns:
|
||||
The variable path with a scope prefix (defaults to Local if none provided)
|
||||
"""
|
||||
if variable.startswith(("Local.", "System.", "Workflow.", "Agent.", "Conversation.")):
|
||||
# Already has a proper namespace
|
||||
return variable
|
||||
if "." in variable:
|
||||
# Has some namespace, use as-is
|
||||
return variable
|
||||
# Default to Local scope
|
||||
return "Local." + variable
|
||||
|
||||
|
||||
@action_handler("InvokePromptAgent")
|
||||
async def handle_invoke_prompt_agent(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent]:
|
||||
"""Invoke a local prompt-based agent (similar to InvokeAzureAgent but for local agents).
|
||||
|
||||
Action schema:
|
||||
kind: InvokePromptAgent
|
||||
agent: agentName # name of the agent in the agents registry
|
||||
input: =expression or literal input
|
||||
instructions: =expression or literal prompt/instructions
|
||||
outputPath: Local.response # optional path to store result
|
||||
"""
|
||||
# Implementation is similar to InvokeAzureAgent
|
||||
# The difference is primarily in how the agent is configured
|
||||
agent_name_raw = ctx.action.get("agent")
|
||||
if not isinstance(agent_name_raw, str):
|
||||
logger.warning("InvokePromptAgent action missing 'agent' property")
|
||||
return
|
||||
agent_name: str = agent_name_raw
|
||||
input_expr = ctx.action.get("input")
|
||||
instructions = ctx.action.get("instructions")
|
||||
output_path = ctx.action.get("outputPath")
|
||||
|
||||
# Get the agent from registry
|
||||
agent = ctx.agents.get(agent_name)
|
||||
if agent is None:
|
||||
logger.error(f"InvokePromptAgent: agent '{agent_name}' not found in registry")
|
||||
return
|
||||
|
||||
# Evaluate input
|
||||
input_value = ctx.state.eval_if_expression(input_expr) if input_expr else None
|
||||
|
||||
# Evaluate instructions (unused currently but may be used for prompting)
|
||||
_ = ctx.state.eval_if_expression(instructions) if instructions else None
|
||||
|
||||
# Build messages
|
||||
messages = _build_messages_from_state(ctx)
|
||||
|
||||
# Add input as user message if provided
|
||||
if input_value:
|
||||
if isinstance(input_value, str):
|
||||
messages.append(Message(role="user", text=input_value))
|
||||
elif isinstance(input_value, Message):
|
||||
messages.append(input_value)
|
||||
|
||||
logger.debug(f"InvokePromptAgent: calling '{agent_name}' with {len(messages)} messages")
|
||||
|
||||
# Build options for kwargs propagation to agent tools
|
||||
prompt_run_kwargs = ctx.run_kwargs
|
||||
prompt_options: dict[str, Any] | None = None
|
||||
if prompt_run_kwargs:
|
||||
# Merge caller-provided options to avoid duplicate keyword argument
|
||||
prompt_options = dict(prompt_run_kwargs.get("options") or {})
|
||||
prompt_options["additional_function_arguments"] = prompt_run_kwargs
|
||||
# Exclude 'options' from splat to avoid TypeError on duplicate keyword
|
||||
prompt_run_kwargs = {k: v for k, v in prompt_run_kwargs.items() if k != "options"}
|
||||
|
||||
# Invoke the agent
|
||||
try:
|
||||
if hasattr(agent, "run"):
|
||||
# Try streaming first
|
||||
try:
|
||||
updates: list[Any] = []
|
||||
|
||||
async for chunk in agent.run(messages, stream=True, options=prompt_options, **prompt_run_kwargs):
|
||||
updates.append(chunk)
|
||||
|
||||
if hasattr(chunk, "text") and chunk.text:
|
||||
yield AgentStreamingChunkEvent(
|
||||
agent_name=agent_name,
|
||||
chunk=chunk.text,
|
||||
)
|
||||
|
||||
# Build consolidated response from updates
|
||||
response = AgentResponse.from_updates(updates)
|
||||
text = response.text
|
||||
response_messages = response.messages
|
||||
|
||||
ctx.state.set_agent_result(text=text, messages=response_messages)
|
||||
|
||||
if text:
|
||||
ctx.state.add_conversation_message(Message(role="assistant", text=text))
|
||||
|
||||
if output_path:
|
||||
ctx.state.set(output_path, text)
|
||||
|
||||
yield AgentResponseEvent(
|
||||
agent_name=agent_name,
|
||||
text=text,
|
||||
messages=response_messages,
|
||||
)
|
||||
|
||||
except TypeError:
|
||||
# Agent doesn't support streaming, fall back to non-streaming
|
||||
response = await agent.run(messages, options=prompt_options, **prompt_run_kwargs)
|
||||
text = response.text
|
||||
response_messages = response.messages
|
||||
|
||||
ctx.state.set_agent_result(text=text, messages=response_messages)
|
||||
|
||||
if text:
|
||||
ctx.state.add_conversation_message(Message(role="assistant", text=text))
|
||||
|
||||
if output_path:
|
||||
ctx.state.set(output_path, text)
|
||||
|
||||
yield AgentResponseEvent(
|
||||
agent_name=agent_name,
|
||||
text=text,
|
||||
messages=response_messages,
|
||||
)
|
||||
else:
|
||||
logger.error(f"InvokePromptAgent: agent '{agent_name}' has no run method")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"InvokePromptAgent: error invoking agent '{agent_name}': {e}")
|
||||
raise
|
||||
@@ -1,572 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Basic action handlers for variable manipulation and output.
|
||||
|
||||
This module implements handlers for:
|
||||
- SetValue: Set a variable in the workflow state
|
||||
- AppendValue: Append a value to a list variable
|
||||
- SendActivity: Send text or attachments to the user
|
||||
- EmitEvent: Emit a custom workflow event
|
||||
|
||||
Note: All handlers are defined as async generators (AsyncGenerator[WorkflowEvent, None])
|
||||
for consistency with the ActionHandler protocol, even when they don't perform async
|
||||
operations. This uniform interface allows the workflow executor to consume all handlers
|
||||
the same way, and some handlers (like InvokeAzureAgent) genuinely require async for
|
||||
network calls. The `return; yield` pattern makes a function an async generator without
|
||||
actually yielding any events.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from ._handlers import (
|
||||
ActionContext,
|
||||
AttachmentOutputEvent,
|
||||
CustomEvent,
|
||||
TextOutputEvent,
|
||||
WorkflowEvent,
|
||||
action_handler,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._state import WorkflowState
|
||||
|
||||
logger = logging.getLogger("agent_framework.declarative")
|
||||
|
||||
|
||||
@action_handler("SetValue")
|
||||
async def handle_set_value(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent]: # noqa: RUF029
|
||||
"""Set a value in the workflow state.
|
||||
|
||||
Action schema:
|
||||
kind: SetValue
|
||||
path: Local.variableName # or Workflow.Outputs.result
|
||||
value: =expression or literal value
|
||||
"""
|
||||
path = ctx.action.get("path")
|
||||
value = ctx.action.get("value")
|
||||
|
||||
if not path:
|
||||
logger.warning("SetValue action missing 'path' property")
|
||||
return
|
||||
|
||||
# Evaluate the value if it's an expression
|
||||
evaluated_value = ctx.state.eval_if_expression(value)
|
||||
|
||||
logger.debug(f"SetValue: {path} = {evaluated_value}")
|
||||
ctx.state.set(path, evaluated_value)
|
||||
|
||||
return
|
||||
yield # Make it a generator
|
||||
|
||||
|
||||
@action_handler("SetVariable")
|
||||
async def handle_set_variable(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent]: # noqa: RUF029
|
||||
"""Set a variable in the workflow state (.NET workflow format).
|
||||
|
||||
This is an alias for SetValue with 'variable' instead of 'path'.
|
||||
|
||||
Action schema:
|
||||
kind: SetVariable
|
||||
variable: Local.variableName
|
||||
value: =expression or literal value
|
||||
"""
|
||||
variable = ctx.action.get("variable")
|
||||
value = ctx.action.get("value")
|
||||
|
||||
if not variable:
|
||||
logger.warning("SetVariable action missing 'variable' property")
|
||||
return
|
||||
|
||||
# Evaluate the value if it's an expression
|
||||
evaluated_value = ctx.state.eval_if_expression(value)
|
||||
|
||||
# Use .NET-style variable names directly (Local.X, System.X, Workflow.X)
|
||||
path = _normalize_variable_path(variable)
|
||||
|
||||
logger.debug(f"SetVariable: {variable} ({path}) = {evaluated_value}")
|
||||
ctx.state.set(path, evaluated_value)
|
||||
|
||||
return
|
||||
yield # Make it a generator
|
||||
|
||||
|
||||
def _normalize_variable_path(variable: str) -> str:
|
||||
"""Normalize variable names to ensure they have a scope prefix.
|
||||
|
||||
Args:
|
||||
variable: Variable name like 'Local.X' or 'System.ConversationId'
|
||||
|
||||
Returns:
|
||||
The variable path with a scope prefix (defaults to Local if none provided)
|
||||
"""
|
||||
if variable.startswith(("Local.", "System.", "Workflow.", "Agent.", "Conversation.")):
|
||||
# Already has a proper namespace
|
||||
return variable
|
||||
if "." in variable:
|
||||
# Has some namespace, use as-is
|
||||
return variable
|
||||
# Default to Local scope
|
||||
return "Local." + variable
|
||||
|
||||
|
||||
@action_handler("AppendValue")
|
||||
async def handle_append_value(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent]: # noqa: RUF029
|
||||
"""Append a value to a list in the workflow state.
|
||||
|
||||
Action schema:
|
||||
kind: AppendValue
|
||||
path: Local.results
|
||||
value: =expression or literal value
|
||||
"""
|
||||
path = ctx.action.get("path")
|
||||
value = ctx.action.get("value")
|
||||
|
||||
if not path:
|
||||
logger.warning("AppendValue action missing 'path' property")
|
||||
return
|
||||
|
||||
# Evaluate the value if it's an expression
|
||||
evaluated_value = ctx.state.eval_if_expression(value)
|
||||
|
||||
logger.debug(f"AppendValue: {path} += {evaluated_value}")
|
||||
ctx.state.append(path, evaluated_value)
|
||||
|
||||
return
|
||||
yield # Make it a generator
|
||||
|
||||
|
||||
@action_handler("SendActivity")
|
||||
async def handle_send_activity(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent]: # noqa: RUF029
|
||||
"""Send text or attachments to the user.
|
||||
|
||||
Action schema (object form):
|
||||
kind: SendActivity
|
||||
activity:
|
||||
text: =expression or literal text
|
||||
attachments:
|
||||
- content: ...
|
||||
contentType: text/plain
|
||||
|
||||
Action schema (simple form):
|
||||
kind: SendActivity
|
||||
activity: =expression or literal text
|
||||
"""
|
||||
activity = ctx.action.get("activity", {})
|
||||
|
||||
# Handle simple string form
|
||||
if isinstance(activity, str):
|
||||
evaluated_text = ctx.state.eval_if_expression(activity)
|
||||
if evaluated_text:
|
||||
logger.debug(
|
||||
"SendActivity: text = %s", evaluated_text[:100] if len(str(evaluated_text)) > 100 else evaluated_text
|
||||
)
|
||||
yield TextOutputEvent(text=str(evaluated_text))
|
||||
return
|
||||
|
||||
# Handle object form - text output
|
||||
text = activity.get("text")
|
||||
if text:
|
||||
evaluated_text = ctx.state.eval_if_expression(text)
|
||||
if evaluated_text:
|
||||
logger.debug(
|
||||
"SendActivity: text = %s", evaluated_text[:100] if len(str(evaluated_text)) > 100 else evaluated_text
|
||||
)
|
||||
yield TextOutputEvent(text=str(evaluated_text))
|
||||
|
||||
# Handle attachments
|
||||
attachments = activity.get("attachments", [])
|
||||
for attachment in attachments:
|
||||
content = attachment.get("content")
|
||||
content_type = attachment.get("contentType", "application/octet-stream")
|
||||
|
||||
if content:
|
||||
evaluated_content = ctx.state.eval_if_expression(content)
|
||||
logger.debug(f"SendActivity: attachment type={content_type}")
|
||||
yield AttachmentOutputEvent(content=evaluated_content, content_type=content_type)
|
||||
|
||||
|
||||
@action_handler("EmitEvent")
|
||||
async def handle_emit_event(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent]: # noqa: RUF029
|
||||
"""Emit a custom workflow event.
|
||||
|
||||
Action schema:
|
||||
kind: EmitEvent
|
||||
event:
|
||||
name: eventName
|
||||
data: =expression or literal data
|
||||
"""
|
||||
event_def = ctx.action.get("event", {})
|
||||
name = event_def.get("name")
|
||||
data = event_def.get("data")
|
||||
|
||||
if not name:
|
||||
logger.warning("EmitEvent action missing 'event.name' property")
|
||||
return
|
||||
|
||||
# Evaluate data if it's an expression
|
||||
evaluated_data = ctx.state.eval_if_expression(data)
|
||||
|
||||
logger.debug(f"EmitEvent: {name} = {evaluated_data}")
|
||||
yield CustomEvent(name=name, data=evaluated_data)
|
||||
|
||||
|
||||
def _evaluate_dict_values(d: dict[str, Any], state: WorkflowState) -> dict[str, Any]:
|
||||
"""Recursively evaluate PowerFx expressions in a dictionary.
|
||||
|
||||
Args:
|
||||
d: Dictionary that may contain expression values
|
||||
state: The workflow state for expression evaluation
|
||||
|
||||
Returns:
|
||||
Dictionary with all expressions evaluated
|
||||
"""
|
||||
result: dict[str, Any] = {}
|
||||
for key, value in d.items():
|
||||
if isinstance(value, str):
|
||||
result[key] = state.eval_if_expression(value)
|
||||
elif isinstance(value, dict):
|
||||
result[key] = _evaluate_dict_values(cast(dict[str, Any], value), state)
|
||||
elif isinstance(value, list):
|
||||
evaluated_list: list[Any] = []
|
||||
for list_item in value:
|
||||
if isinstance(list_item, dict):
|
||||
evaluated_list.append(_evaluate_dict_values(cast(dict[str, Any], list_item), state))
|
||||
elif isinstance(list_item, str):
|
||||
evaluated_list.append(state.eval_if_expression(list_item))
|
||||
else:
|
||||
evaluated_list.append(list_item)
|
||||
result[key] = evaluated_list
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
|
||||
@action_handler("SetTextVariable")
|
||||
async def handle_set_text_variable(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent]: # noqa: RUF029
|
||||
"""Set a text variable with string interpolation support.
|
||||
|
||||
This is similar to SetVariable but supports multi-line text with
|
||||
{Local.Variable} style interpolation.
|
||||
|
||||
Action schema:
|
||||
kind: SetTextVariable
|
||||
variable: Local.myText
|
||||
value: |-
|
||||
Multi-line text with {Local.Variable} interpolation
|
||||
and more content here.
|
||||
"""
|
||||
variable = ctx.action.get("variable")
|
||||
value = ctx.action.get("value")
|
||||
|
||||
if not variable:
|
||||
logger.warning("SetTextVariable action missing 'variable' property")
|
||||
return
|
||||
|
||||
# Evaluate the value - handle string interpolation
|
||||
if isinstance(value, str):
|
||||
evaluated_value = _interpolate_string(value, ctx.state)
|
||||
else:
|
||||
evaluated_value = ctx.state.eval_if_expression(value)
|
||||
|
||||
path = _normalize_variable_path(variable)
|
||||
|
||||
logger.debug(f"SetTextVariable: {variable} ({path}) = {str(evaluated_value)[:100]}")
|
||||
ctx.state.set(path, evaluated_value)
|
||||
|
||||
return
|
||||
yield # Make it a generator
|
||||
|
||||
|
||||
@action_handler("SetMultipleVariables")
|
||||
async def handle_set_multiple_variables(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent]: # noqa: RUF029
|
||||
"""Set multiple variables at once.
|
||||
|
||||
Action schema:
|
||||
kind: SetMultipleVariables
|
||||
variables:
|
||||
- variable: Local.var1
|
||||
value: value1
|
||||
- variable: Local.var2
|
||||
value: =expression
|
||||
"""
|
||||
variables = ctx.action.get("variables", [])
|
||||
|
||||
for var_def in variables:
|
||||
variable = var_def.get("variable")
|
||||
value = var_def.get("value")
|
||||
|
||||
if not variable:
|
||||
logger.warning("SetMultipleVariables: variable entry missing 'variable' property")
|
||||
continue
|
||||
|
||||
evaluated_value = ctx.state.eval_if_expression(value)
|
||||
path = _normalize_variable_path(variable)
|
||||
|
||||
logger.debug(f"SetMultipleVariables: {variable} ({path}) = {evaluated_value}")
|
||||
ctx.state.set(path, evaluated_value)
|
||||
|
||||
return
|
||||
yield # Make it a generator
|
||||
|
||||
|
||||
@action_handler("ResetVariable")
|
||||
async def handle_reset_variable(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent]: # noqa: RUF029
|
||||
"""Reset a variable to its default/blank state.
|
||||
|
||||
Action schema:
|
||||
kind: ResetVariable
|
||||
variable: Local.variableName
|
||||
"""
|
||||
variable = ctx.action.get("variable")
|
||||
|
||||
if not variable:
|
||||
logger.warning("ResetVariable action missing 'variable' property")
|
||||
return
|
||||
|
||||
path = _normalize_variable_path(variable)
|
||||
|
||||
logger.debug(f"ResetVariable: {variable} ({path}) = None")
|
||||
ctx.state.set(path, None)
|
||||
|
||||
return
|
||||
yield # Make it a generator
|
||||
|
||||
|
||||
@action_handler("ClearAllVariables")
|
||||
async def handle_clear_all_variables(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent]: # noqa: RUF029
|
||||
"""Clear all turn-scoped variables.
|
||||
|
||||
Action schema:
|
||||
kind: ClearAllVariables
|
||||
"""
|
||||
logger.debug("ClearAllVariables: clearing turn scope")
|
||||
ctx.state.reset_local()
|
||||
|
||||
return
|
||||
yield # Make it a generator
|
||||
|
||||
|
||||
@action_handler("CreateConversation")
|
||||
async def handle_create_conversation(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent]: # noqa: RUF029
|
||||
"""Create a new conversation context.
|
||||
|
||||
Action schema (.NET style):
|
||||
kind: CreateConversation
|
||||
conversationId: Local.myConversationId # Variable to store the generated ID
|
||||
|
||||
The conversationId parameter is the OUTPUT variable where the generated
|
||||
conversation ID will be stored. This matches .NET behavior where:
|
||||
- A unique conversation ID is always auto-generated
|
||||
- The conversationId parameter specifies where to store it
|
||||
"""
|
||||
import uuid
|
||||
|
||||
conversation_id_var = ctx.action.get("conversationId")
|
||||
|
||||
# Always generate a unique ID (.NET behavior)
|
||||
generated_id = str(uuid.uuid4())
|
||||
|
||||
# Store conversation in state
|
||||
conversations: dict[str, Any] = ctx.state.get("System.conversations") or {}
|
||||
conversations[generated_id] = {
|
||||
"id": generated_id,
|
||||
"messages": [],
|
||||
"created_at": None, # Could add timestamp
|
||||
}
|
||||
ctx.state.set("System.conversations", conversations)
|
||||
|
||||
logger.debug(f"CreateConversation: created {generated_id}")
|
||||
|
||||
# Store the generated ID in the specified variable (.NET style output binding)
|
||||
if conversation_id_var:
|
||||
output_path = _normalize_variable_path(conversation_id_var)
|
||||
ctx.state.set(output_path, generated_id)
|
||||
logger.debug(f"CreateConversation: bound to {output_path} = {generated_id}")
|
||||
|
||||
# Also handle legacy output binding for backwards compatibility
|
||||
output = ctx.action.get("output", {})
|
||||
output_var = output.get("conversationId")
|
||||
if output_var:
|
||||
output_path = _normalize_variable_path(output_var)
|
||||
ctx.state.set(output_path, generated_id)
|
||||
logger.debug(f"CreateConversation: legacy output bound to {output_path}")
|
||||
|
||||
return
|
||||
yield # Make it a generator
|
||||
|
||||
|
||||
@action_handler("AddConversationMessage")
|
||||
async def handle_add_conversation_message(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent]: # noqa: RUF029
|
||||
"""Add a message to a conversation.
|
||||
|
||||
Action schema:
|
||||
kind: AddConversationMessage
|
||||
conversationId: =expression or variable reference
|
||||
message:
|
||||
role: user | assistant | system
|
||||
content: =expression or literal text
|
||||
"""
|
||||
conversation_id = ctx.action.get("conversationId")
|
||||
message_def = ctx.action.get("message", {})
|
||||
|
||||
if not conversation_id:
|
||||
logger.warning("AddConversationMessage missing 'conversationId' property")
|
||||
return
|
||||
|
||||
# Evaluate conversation ID
|
||||
evaluated_id = ctx.state.eval_if_expression(conversation_id)
|
||||
|
||||
# Evaluate message content
|
||||
role = message_def.get("role", "user")
|
||||
content = message_def.get("content", "")
|
||||
|
||||
evaluated_content = ctx.state.eval_if_expression(content)
|
||||
if isinstance(evaluated_content, str):
|
||||
evaluated_content = _interpolate_string(evaluated_content, ctx.state)
|
||||
|
||||
# Get or create conversation
|
||||
conversations: dict[str, Any] = ctx.state.get("System.conversations") or {}
|
||||
if evaluated_id not in conversations:
|
||||
conversations[evaluated_id] = {"id": evaluated_id, "messages": []}
|
||||
|
||||
# Add message
|
||||
message: dict[str, Any] = {"role": role, "content": evaluated_content}
|
||||
conv_entry: dict[str, Any] = dict(conversations[evaluated_id])
|
||||
messages_list: list[Any] = list(conv_entry.get("messages", []))
|
||||
messages_list.append(message)
|
||||
conv_entry["messages"] = messages_list
|
||||
conversations[evaluated_id] = conv_entry
|
||||
ctx.state.set("System.conversations", conversations)
|
||||
|
||||
# Also add to global conversation state
|
||||
ctx.state.add_conversation_message(message)
|
||||
|
||||
logger.debug(f"AddConversationMessage: added {role} message to {evaluated_id}")
|
||||
|
||||
return
|
||||
yield # Make it a generator
|
||||
|
||||
|
||||
@action_handler("CopyConversationMessages")
|
||||
async def handle_copy_conversation_messages(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent]: # noqa: RUF029
|
||||
"""Copy messages from one conversation to another.
|
||||
|
||||
Action schema:
|
||||
kind: CopyConversationMessages
|
||||
sourceConversationId: =expression
|
||||
targetConversationId: =expression
|
||||
count: 10 # optional, number of messages to copy
|
||||
"""
|
||||
source_id = ctx.action.get("sourceConversationId")
|
||||
target_id = ctx.action.get("targetConversationId")
|
||||
count = ctx.action.get("count")
|
||||
|
||||
if not source_id or not target_id:
|
||||
logger.warning("CopyConversationMessages missing source or target conversation ID")
|
||||
return
|
||||
|
||||
# Evaluate IDs
|
||||
evaluated_source = ctx.state.eval_if_expression(source_id)
|
||||
evaluated_target = ctx.state.eval_if_expression(target_id)
|
||||
|
||||
# Get conversations
|
||||
conversations: dict[str, Any] = ctx.state.get("System.conversations") or {}
|
||||
|
||||
source_conv: dict[str, Any] = conversations.get(evaluated_source, {})
|
||||
source_messages: list[Any] = source_conv.get("messages", [])
|
||||
|
||||
# Limit messages if count specified
|
||||
if count is not None:
|
||||
source_messages = source_messages[-count:]
|
||||
|
||||
# Get or create target conversation
|
||||
if evaluated_target not in conversations:
|
||||
conversations[evaluated_target] = {"id": evaluated_target, "messages": []}
|
||||
|
||||
# Copy messages
|
||||
target_entry: dict[str, Any] = dict(conversations[evaluated_target])
|
||||
target_messages: list[Any] = list(target_entry.get("messages", []))
|
||||
target_messages.extend(source_messages)
|
||||
target_entry["messages"] = target_messages
|
||||
conversations[evaluated_target] = target_entry
|
||||
ctx.state.set("System.conversations", conversations)
|
||||
|
||||
logger.debug(
|
||||
"CopyConversationMessages: copied %d messages from %s to %s",
|
||||
len(source_messages),
|
||||
evaluated_source,
|
||||
evaluated_target,
|
||||
)
|
||||
|
||||
return
|
||||
yield # Make it a generator
|
||||
|
||||
|
||||
@action_handler("RetrieveConversationMessages")
|
||||
async def handle_retrieve_conversation_messages(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent]: # noqa: RUF029
|
||||
"""Retrieve messages from a conversation and store in a variable.
|
||||
|
||||
Action schema:
|
||||
kind: RetrieveConversationMessages
|
||||
conversationId: =expression
|
||||
output:
|
||||
messages: Local.myMessages
|
||||
count: 10 # optional
|
||||
"""
|
||||
conversation_id = ctx.action.get("conversationId")
|
||||
output = ctx.action.get("output", {})
|
||||
count = ctx.action.get("count")
|
||||
|
||||
if not conversation_id:
|
||||
logger.warning("RetrieveConversationMessages missing 'conversationId' property")
|
||||
return
|
||||
|
||||
# Evaluate conversation ID
|
||||
evaluated_id = ctx.state.eval_if_expression(conversation_id)
|
||||
|
||||
# Get messages
|
||||
conversations: dict[str, Any] = ctx.state.get("System.conversations") or {}
|
||||
conv: dict[str, Any] = conversations.get(evaluated_id, {})
|
||||
messages: list[Any] = conv.get("messages", [])
|
||||
|
||||
# Limit messages if count specified
|
||||
if count is not None:
|
||||
messages = messages[-count:]
|
||||
|
||||
# Handle output binding
|
||||
output_var = output.get("messages")
|
||||
if output_var:
|
||||
output_path = _normalize_variable_path(output_var)
|
||||
ctx.state.set(output_path, messages)
|
||||
logger.debug(f"RetrieveConversationMessages: bound {len(messages)} messages to {output_path}")
|
||||
|
||||
return
|
||||
yield # Make it a generator
|
||||
|
||||
|
||||
def _interpolate_string(text: str, state: WorkflowState) -> str:
|
||||
"""Interpolate {Variable.Path} references in a string.
|
||||
|
||||
Args:
|
||||
text: Text that may contain {Variable.Path} references
|
||||
state: The workflow state for variable lookup
|
||||
|
||||
Returns:
|
||||
Text with variables interpolated
|
||||
"""
|
||||
import re
|
||||
|
||||
def replace_var(match: re.Match[str]) -> str:
|
||||
var_path: str = match.group(1)
|
||||
# Map .NET style to Python style
|
||||
path = _normalize_variable_path(var_path)
|
||||
value = state.get(path)
|
||||
return str(value) if value is not None else ""
|
||||
|
||||
# Match {Variable.Path} patterns
|
||||
pattern = r"\{([A-Za-z][A-Za-z0-9_.]*)\}"
|
||||
return re.sub(pattern, replace_var, text)
|
||||
-396
@@ -1,396 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Control flow action handlers for declarative workflows.
|
||||
|
||||
This module implements handlers for:
|
||||
- Foreach: Iterate over a collection and execute nested actions
|
||||
- If: Conditional branching
|
||||
- Switch: Multi-way branching based on value matching
|
||||
- RepeatUntil: Loop until a condition is met
|
||||
- BreakLoop: Exit the current loop
|
||||
- ContinueLoop: Skip to the next iteration
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from ._handlers import (
|
||||
ActionContext,
|
||||
LoopControlSignal,
|
||||
WorkflowEvent,
|
||||
action_handler,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("agent_framework.declarative")
|
||||
|
||||
|
||||
@action_handler("Foreach")
|
||||
async def handle_foreach(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]:
|
||||
"""Iterate over a collection and execute nested actions for each item.
|
||||
|
||||
Action schema:
|
||||
kind: Foreach
|
||||
source: =expression returning a collection
|
||||
itemName: itemVariable # optional, defaults to 'item'
|
||||
indexName: indexVariable # optional, defaults to 'index'
|
||||
actions:
|
||||
- kind: ...
|
||||
"""
|
||||
source_expr = ctx.action.get("source")
|
||||
item_name = ctx.action.get("itemName", "item")
|
||||
index_name = ctx.action.get("indexName", "index")
|
||||
actions = ctx.action.get("actions", [])
|
||||
|
||||
if not source_expr:
|
||||
logger.warning("Foreach action missing 'source' property")
|
||||
return
|
||||
|
||||
# Evaluate the source collection
|
||||
collection = ctx.state.eval_if_expression(source_expr)
|
||||
|
||||
if collection is None:
|
||||
logger.debug("Foreach: source evaluated to None, skipping")
|
||||
return
|
||||
|
||||
if not hasattr(collection, "__iter__"):
|
||||
logger.warning(f"Foreach: source is not iterable: {type(collection).__name__}")
|
||||
return
|
||||
|
||||
collection_len = len(list(collection)) if hasattr(collection, "__len__") else "?"
|
||||
logger.debug(f"Foreach: iterating over {collection_len} items")
|
||||
|
||||
# Iterate over the collection
|
||||
for index, item in enumerate(collection):
|
||||
# Set loop variables in the Local scope
|
||||
ctx.state.set(f"Local.{item_name}", item)
|
||||
ctx.state.set(f"Local.{index_name}", index)
|
||||
|
||||
# Execute nested actions
|
||||
try:
|
||||
async for event in ctx.execute_actions(actions, ctx.state):
|
||||
# Check for loop control signals
|
||||
if isinstance(event, LoopControlSignal):
|
||||
if event.signal_type == "break":
|
||||
logger.debug(f"Foreach: break signal received at index {index}")
|
||||
return
|
||||
elif event.signal_type == "continue":
|
||||
logger.debug(f"Foreach: continue signal received at index {index}")
|
||||
break # Break inner loop to continue outer
|
||||
else:
|
||||
yield event
|
||||
except StopIteration:
|
||||
# Continue signal was raised
|
||||
continue
|
||||
|
||||
|
||||
@action_handler("If")
|
||||
async def handle_if(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]:
|
||||
"""Conditional branching based on a condition expression.
|
||||
|
||||
Action schema:
|
||||
kind: If
|
||||
condition: =boolean expression
|
||||
then:
|
||||
- kind: ... # actions if condition is true
|
||||
else:
|
||||
- kind: ... # actions if condition is false (optional)
|
||||
"""
|
||||
condition_expr = ctx.action.get("condition")
|
||||
then_actions = ctx.action.get("then", [])
|
||||
else_actions = ctx.action.get("else", [])
|
||||
|
||||
if condition_expr is None:
|
||||
logger.warning("If action missing 'condition' property")
|
||||
return
|
||||
|
||||
# Evaluate the condition
|
||||
condition_result = ctx.state.eval_if_expression(condition_expr)
|
||||
|
||||
# Coerce to boolean
|
||||
is_truthy = bool(condition_result)
|
||||
|
||||
logger.debug(
|
||||
"If: condition '%s' evaluated to %s",
|
||||
condition_expr[:50] if len(str(condition_expr)) > 50 else condition_expr,
|
||||
is_truthy,
|
||||
)
|
||||
|
||||
# Execute the appropriate branch
|
||||
actions_to_execute = then_actions if is_truthy else else_actions
|
||||
|
||||
async for event in ctx.execute_actions(actions_to_execute, ctx.state):
|
||||
yield event
|
||||
|
||||
|
||||
@action_handler("Switch")
|
||||
async def handle_switch(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]:
|
||||
"""Multi-way branching based on value matching.
|
||||
|
||||
Action schema:
|
||||
kind: Switch
|
||||
value: =expression to match
|
||||
cases:
|
||||
- match: value1
|
||||
actions:
|
||||
- kind: ...
|
||||
- match: value2
|
||||
actions:
|
||||
- kind: ...
|
||||
default:
|
||||
- kind: ... # optional default actions
|
||||
"""
|
||||
value_expr = ctx.action.get("value")
|
||||
cases = ctx.action.get("cases", [])
|
||||
default_actions = ctx.action.get("default", [])
|
||||
|
||||
if not value_expr:
|
||||
logger.warning("Switch action missing 'value' property")
|
||||
return
|
||||
|
||||
# Evaluate the switch value
|
||||
switch_value = ctx.state.eval_if_expression(value_expr)
|
||||
|
||||
logger.debug(f"Switch: value = {switch_value}")
|
||||
|
||||
# Find matching case
|
||||
matched_actions = None
|
||||
for case in cases:
|
||||
match_value = ctx.state.eval_if_expression(case.get("match"))
|
||||
if switch_value == match_value:
|
||||
matched_actions = case.get("actions", [])
|
||||
logger.debug(f"Switch: matched case '{match_value}'")
|
||||
break
|
||||
|
||||
# Use default if no match found
|
||||
if matched_actions is None:
|
||||
matched_actions = default_actions
|
||||
logger.debug("Switch: using default case")
|
||||
|
||||
# Execute matched actions
|
||||
async for event in ctx.execute_actions(matched_actions, ctx.state):
|
||||
yield event
|
||||
|
||||
|
||||
@action_handler("RepeatUntil")
|
||||
async def handle_repeat_until(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]:
|
||||
"""Loop until a condition becomes true.
|
||||
|
||||
Action schema:
|
||||
kind: RepeatUntil
|
||||
condition: =boolean expression (loop exits when true)
|
||||
maxIterations: 100 # optional safety limit
|
||||
actions:
|
||||
- kind: ...
|
||||
"""
|
||||
condition_expr = ctx.action.get("condition")
|
||||
max_iterations = ctx.action.get("maxIterations", 100)
|
||||
actions = ctx.action.get("actions", [])
|
||||
|
||||
if condition_expr is None:
|
||||
logger.warning("RepeatUntil action missing 'condition' property")
|
||||
return
|
||||
|
||||
iteration = 0
|
||||
while iteration < max_iterations:
|
||||
iteration += 1
|
||||
ctx.state.set("Local.iteration", iteration)
|
||||
|
||||
logger.debug(f"RepeatUntil: iteration {iteration}")
|
||||
|
||||
# Execute loop body
|
||||
should_break = False
|
||||
async for event in ctx.execute_actions(actions, ctx.state):
|
||||
if isinstance(event, LoopControlSignal):
|
||||
if event.signal_type == "break":
|
||||
logger.debug(f"RepeatUntil: break signal received at iteration {iteration}")
|
||||
should_break = True
|
||||
break
|
||||
elif event.signal_type == "continue":
|
||||
logger.debug(f"RepeatUntil: continue signal received at iteration {iteration}")
|
||||
break
|
||||
else:
|
||||
yield event
|
||||
|
||||
if should_break:
|
||||
break
|
||||
|
||||
# Check exit condition
|
||||
condition_result = ctx.state.eval_if_expression(condition_expr)
|
||||
if bool(condition_result):
|
||||
logger.debug(f"RepeatUntil: condition met after {iteration} iterations")
|
||||
break
|
||||
|
||||
if iteration >= max_iterations:
|
||||
logger.warning(f"RepeatUntil: reached max iterations ({max_iterations})")
|
||||
|
||||
|
||||
@action_handler("BreakLoop")
|
||||
async def handle_break_loop(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029
|
||||
"""Signal to break out of the current loop.
|
||||
|
||||
Action schema:
|
||||
kind: BreakLoop
|
||||
"""
|
||||
logger.debug("BreakLoop: signaling break")
|
||||
yield LoopControlSignal(signal_type="break")
|
||||
|
||||
|
||||
@action_handler("ContinueLoop")
|
||||
async def handle_continue_loop(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029
|
||||
"""Signal to continue to the next iteration of the current loop.
|
||||
|
||||
Action schema:
|
||||
kind: ContinueLoop
|
||||
"""
|
||||
logger.debug("ContinueLoop: signaling continue")
|
||||
yield LoopControlSignal(signal_type="continue")
|
||||
|
||||
|
||||
@action_handler("ConditionGroup")
|
||||
async def handle_condition_group(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]:
|
||||
"""Multi-condition branching (like else-if chains).
|
||||
|
||||
Evaluates conditions in order and executes the first matching condition's actions.
|
||||
If no conditions match and elseActions is provided, executes those.
|
||||
|
||||
Action schema:
|
||||
kind: ConditionGroup
|
||||
conditions:
|
||||
- condition: =boolean expression
|
||||
actions:
|
||||
- kind: ...
|
||||
- condition: =another expression
|
||||
actions:
|
||||
- kind: ...
|
||||
elseActions:
|
||||
- kind: ... # optional, executed if no conditions match
|
||||
"""
|
||||
conditions = ctx.action.get("conditions", [])
|
||||
else_actions = ctx.action.get("elseActions", [])
|
||||
|
||||
matched = False
|
||||
for condition_def in conditions:
|
||||
condition_expr = condition_def.get("condition")
|
||||
actions = condition_def.get("actions", [])
|
||||
|
||||
if condition_expr is None:
|
||||
logger.warning("ConditionGroup condition missing 'condition' property")
|
||||
continue
|
||||
|
||||
# Evaluate the condition
|
||||
condition_result = ctx.state.eval_if_expression(condition_expr)
|
||||
is_truthy = bool(condition_result)
|
||||
|
||||
logger.debug(
|
||||
"ConditionGroup: condition '%s' evaluated to %s",
|
||||
str(condition_expr)[:50] if len(str(condition_expr)) > 50 else condition_expr,
|
||||
is_truthy,
|
||||
)
|
||||
|
||||
if is_truthy:
|
||||
matched = True
|
||||
# Execute this condition's actions
|
||||
async for event in ctx.execute_actions(actions, ctx.state):
|
||||
yield event
|
||||
# Only execute the first matching condition
|
||||
break
|
||||
|
||||
# Execute elseActions if no condition matched
|
||||
if not matched and else_actions:
|
||||
logger.debug("ConditionGroup: no conditions matched, executing elseActions")
|
||||
async for event in ctx.execute_actions(else_actions, ctx.state):
|
||||
yield event
|
||||
|
||||
|
||||
@action_handler("GotoAction")
|
||||
async def handle_goto_action(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029
|
||||
"""Jump to another action by ID (triggers re-execution from that action).
|
||||
|
||||
Note: GotoAction in the .NET implementation creates a loop by restarting
|
||||
execution from a specific action. In Python, we emit a GotoSignal that
|
||||
the top-level executor should handle.
|
||||
|
||||
Action schema:
|
||||
kind: GotoAction
|
||||
actionId: target_action_id
|
||||
"""
|
||||
action_id = ctx.action.get("actionId")
|
||||
|
||||
if not action_id:
|
||||
logger.warning("GotoAction missing 'actionId' property")
|
||||
return
|
||||
|
||||
logger.debug(f"GotoAction: jumping to action '{action_id}'")
|
||||
|
||||
# Emit a goto signal that the executor should handle
|
||||
yield GotoSignal(target_action_id=action_id)
|
||||
|
||||
|
||||
class GotoSignal(WorkflowEvent):
|
||||
"""Signal to jump to a specific action by ID.
|
||||
|
||||
This signal is used by GotoAction to implement control flow jumps.
|
||||
The top-level executor should handle this signal appropriately.
|
||||
"""
|
||||
|
||||
def __init__(self, target_action_id: str) -> None:
|
||||
self.target_action_id = target_action_id
|
||||
|
||||
|
||||
class EndWorkflowSignal(WorkflowEvent):
|
||||
"""Signal to end the workflow execution.
|
||||
|
||||
This signal causes the workflow to terminate gracefully.
|
||||
"""
|
||||
|
||||
def __init__(self, reason: str | None = None) -> None:
|
||||
self.reason = reason
|
||||
|
||||
|
||||
class EndConversationSignal(WorkflowEvent):
|
||||
"""Signal to end the current conversation.
|
||||
|
||||
This signal causes the conversation to terminate while the workflow may continue.
|
||||
"""
|
||||
|
||||
def __init__(self, conversation_id: str | None = None, reason: str | None = None) -> None:
|
||||
self.conversation_id = conversation_id
|
||||
self.reason = reason
|
||||
|
||||
|
||||
@action_handler("EndWorkflow")
|
||||
async def handle_end_workflow(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029
|
||||
"""End the workflow execution.
|
||||
|
||||
Action schema:
|
||||
kind: EndWorkflow
|
||||
reason: Optional reason for ending (for logging)
|
||||
"""
|
||||
reason = ctx.action.get("reason")
|
||||
|
||||
logger.debug(f"EndWorkflow: ending workflow{f' (reason: {reason})' if reason else ''}")
|
||||
|
||||
yield EndWorkflowSignal(reason=reason)
|
||||
|
||||
|
||||
@action_handler("EndConversation")
|
||||
async def handle_end_conversation(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]: # noqa: RUF029
|
||||
"""End the current conversation.
|
||||
|
||||
Action schema:
|
||||
kind: EndConversation
|
||||
conversationId: Optional specific conversation to end
|
||||
reason: Optional reason for ending
|
||||
"""
|
||||
conversation_id = ctx.action.get("conversationId")
|
||||
reason = ctx.action.get("reason")
|
||||
|
||||
# Evaluate conversation ID if provided
|
||||
if conversation_id:
|
||||
evaluated_id = ctx.state.eval_if_expression(conversation_id)
|
||||
else:
|
||||
evaluated_id = ctx.state.get("System.ConversationId")
|
||||
|
||||
logger.debug(f"EndConversation: ending conversation {evaluated_id}{f' (reason: {reason})' if reason else ''}")
|
||||
|
||||
yield EndConversationSignal(conversation_id=evaluated_id, reason=reason)
|
||||
@@ -1,133 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Error handling action handlers for declarative workflows.
|
||||
|
||||
This module implements handlers for:
|
||||
- ThrowException: Raise an error that can be caught by TryCatch
|
||||
- TryCatch: Try-catch-finally error handling
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agent_framework.exceptions import WorkflowException
|
||||
|
||||
from ._handlers import (
|
||||
ActionContext,
|
||||
WorkflowEvent,
|
||||
action_handler,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("agent_framework.declarative")
|
||||
|
||||
|
||||
class WorkflowActionError(WorkflowException):
|
||||
"""Exception raised by ThrowException action."""
|
||||
|
||||
def __init__(self, message: str, code: str | None = None):
|
||||
super().__init__(message)
|
||||
self.code = code
|
||||
|
||||
|
||||
@dataclass
|
||||
class ErrorEvent(WorkflowEvent):
|
||||
"""Event emitted when an error occurs."""
|
||||
|
||||
message: str
|
||||
"""The error message."""
|
||||
|
||||
code: str | None = None
|
||||
"""Optional error code."""
|
||||
|
||||
source_action: str | None = None
|
||||
"""The action that caused the error."""
|
||||
|
||||
|
||||
@action_handler("ThrowException")
|
||||
async def handle_throw_exception(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent]: # noqa: RUF029
|
||||
"""Raise an exception that can be caught by TryCatch.
|
||||
|
||||
Action schema:
|
||||
kind: ThrowException
|
||||
message: =expression or literal error message
|
||||
code: ERROR_CODE # optional error code
|
||||
"""
|
||||
message_expr = ctx.action.get("message", "An error occurred")
|
||||
code = ctx.action.get("code")
|
||||
|
||||
# Evaluate the message if it's an expression
|
||||
message = ctx.state.eval_if_expression(message_expr)
|
||||
|
||||
logger.debug(f"ThrowException: {message} (code={code})")
|
||||
|
||||
raise WorkflowActionError(str(message), code)
|
||||
|
||||
# This yield is never reached but makes it a generator
|
||||
yield ErrorEvent(message=str(message), code=code) # type: ignore[unreachable]
|
||||
|
||||
|
||||
@action_handler("TryCatch")
|
||||
async def handle_try_catch(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent]:
|
||||
"""Try-catch-finally error handling.
|
||||
|
||||
Action schema:
|
||||
kind: TryCatch
|
||||
try:
|
||||
- kind: ... # actions to try
|
||||
catch:
|
||||
- kind: ... # actions to execute on error (optional)
|
||||
finally:
|
||||
- kind: ... # actions to always execute (optional)
|
||||
|
||||
In the catch block, the following variables are available:
|
||||
Local.error.message: The error message
|
||||
Local.error.code: The error code (if provided)
|
||||
Local.error.type: The error type name
|
||||
"""
|
||||
try_actions = ctx.action.get("try", [])
|
||||
catch_actions = ctx.action.get("catch", [])
|
||||
finally_actions = ctx.action.get("finally", [])
|
||||
|
||||
error_occurred = False
|
||||
error_info = None
|
||||
|
||||
# Execute try block
|
||||
try:
|
||||
async for event in ctx.execute_actions(try_actions, ctx.state):
|
||||
yield event
|
||||
except WorkflowActionError as e:
|
||||
error_occurred = True
|
||||
error_info = {
|
||||
"message": str(e),
|
||||
"code": e.code,
|
||||
"type": "WorkflowActionError",
|
||||
}
|
||||
logger.debug(f"TryCatch: caught WorkflowActionError: {e}")
|
||||
except Exception as e:
|
||||
error_occurred = True
|
||||
error_info = {
|
||||
"message": str(e),
|
||||
"code": None,
|
||||
"type": type(e).__name__,
|
||||
}
|
||||
logger.debug(f"TryCatch: caught {type(e).__name__}: {e}")
|
||||
|
||||
# Execute catch block if error occurred
|
||||
if error_occurred and catch_actions:
|
||||
# Set error info in Local scope
|
||||
ctx.state.set("Local.error", error_info)
|
||||
|
||||
try:
|
||||
async for event in ctx.execute_actions(catch_actions, ctx.state):
|
||||
yield event
|
||||
finally:
|
||||
# Clean up error info (but don't interfere with finally block)
|
||||
pass
|
||||
|
||||
# Execute finally block
|
||||
if finally_actions:
|
||||
async for event in ctx.execute_actions(finally_actions, ctx.state):
|
||||
yield event
|
||||
+14
-3
@@ -27,6 +27,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal as _Decimal
|
||||
@@ -162,15 +163,19 @@ class DeclarativeWorkflowState:
|
||||
Args:
|
||||
inputs: Initial workflow inputs (become Workflow.Inputs.*)
|
||||
"""
|
||||
conversation_id = str(uuid.uuid4())
|
||||
state_data: DeclarativeStateData = {
|
||||
"Inputs": dict(inputs) if inputs else {},
|
||||
"Outputs": {},
|
||||
"Local": {},
|
||||
"System": {
|
||||
"ConversationId": "default",
|
||||
"ConversationId": conversation_id,
|
||||
"LastMessage": {"Text": "", "Id": ""},
|
||||
"LastMessageText": "",
|
||||
"LastMessageId": "",
|
||||
"conversations": {
|
||||
conversation_id: {"id": conversation_id, "messages": []},
|
||||
},
|
||||
},
|
||||
"Agent": {},
|
||||
"Conversation": {"messages": [], "history": []},
|
||||
@@ -854,12 +859,18 @@ class DeclarativeActionExecutor(Executor):
|
||||
# Structured inputs - use directly
|
||||
state.initialize(trigger) # type: ignore
|
||||
elif isinstance(trigger, str):
|
||||
# String input - wrap in dict
|
||||
# String input - wrap in dict and populate System.LastMessage.Text
|
||||
# so YAML expressions like =System.LastMessage.Text see the user input
|
||||
state.initialize({"input": trigger})
|
||||
state.set("System.LastMessage", {"Text": trigger, "Id": ""})
|
||||
state.set("System.LastMessageText", trigger)
|
||||
elif not isinstance(
|
||||
trigger, (ActionTrigger, ActionComplete, ConditionResult, LoopIterationResult, LoopControl)
|
||||
):
|
||||
# Any other type - convert to string like .NET's DefaultTransform
|
||||
state.initialize({"input": str(trigger)})
|
||||
input_str = str(trigger)
|
||||
state.initialize({"input": input_str})
|
||||
state.set("System.LastMessage", {"Text": input_str, "Id": ""})
|
||||
state.set("System.LastMessageText", input_str)
|
||||
|
||||
return state
|
||||
|
||||
+32
-6
@@ -49,7 +49,17 @@ ALL_ACTION_EXECUTORS = {
|
||||
|
||||
# Action kinds that terminate control flow (no fall-through to successor)
|
||||
# These actions transfer control elsewhere and should not have sequential edges to the next action
|
||||
TERMINATOR_ACTIONS = frozenset({"Goto", "GotoAction", "BreakLoop", "ContinueLoop", "EndWorkflow", "EndDialog"})
|
||||
TERMINATOR_ACTIONS = frozenset({
|
||||
"Goto",
|
||||
"GotoAction",
|
||||
"BreakLoop",
|
||||
"ContinueLoop",
|
||||
"EndWorkflow",
|
||||
"EndDialog",
|
||||
"EndConversation",
|
||||
"CancelDialog",
|
||||
"CancelAllDialogs",
|
||||
})
|
||||
|
||||
# Required fields for specific action kinds (schema validation)
|
||||
# Each action needs at least one of the listed fields (checked with alternates)
|
||||
@@ -110,6 +120,7 @@ class DeclarativeWorkflowBuilder:
|
||||
agents: dict[str, Any] | None = None,
|
||||
checkpoint_storage: Any | None = None,
|
||||
validate: bool = True,
|
||||
max_iterations: int | None = None,
|
||||
):
|
||||
"""Initialize the builder.
|
||||
|
||||
@@ -119,6 +130,8 @@ class DeclarativeWorkflowBuilder:
|
||||
agents: Registry of agent instances by name (for InvokeAzureAgent actions)
|
||||
checkpoint_storage: Optional checkpoint storage for pause/resume support
|
||||
validate: Whether to validate the workflow definition before building (default: True)
|
||||
max_iterations: Maximum runner supersteps. Falls back to the YAML ``maxTurns``
|
||||
field, then to the core default (100).
|
||||
"""
|
||||
self._yaml_def = yaml_definition
|
||||
self._workflow_id = workflow_id or yaml_definition.get("name", "declarative_workflow")
|
||||
@@ -129,6 +142,11 @@ class DeclarativeWorkflowBuilder:
|
||||
self._pending_gotos: list[tuple[Any, str]] = [] # (goto_executor, target_id)
|
||||
self._validate = validate
|
||||
self._seen_explicit_ids: set[str] = set() # Track explicit IDs for duplicate detection
|
||||
# Resolve max_iterations: explicit arg > YAML maxTurns > core default
|
||||
resolved = max_iterations if max_iterations is not None else yaml_definition.get("maxTurns")
|
||||
if resolved is not None and (not isinstance(resolved, int) or resolved <= 0):
|
||||
raise ValueError(f"Invalid max_iterations/maxTurns value: {resolved!r}. Expected a positive integer.")
|
||||
self._max_iterations: int | None = resolved
|
||||
|
||||
def build(self) -> Workflow:
|
||||
"""Build the workflow graph.
|
||||
@@ -153,11 +171,14 @@ class DeclarativeWorkflowBuilder:
|
||||
# _create_executors_for_actions runs (which itself needs the builder to add edges).
|
||||
entry_node = JoinExecutor({"kind": "Entry"}, id="_workflow_entry")
|
||||
self._executors[entry_node.id] = entry_node
|
||||
builder = WorkflowBuilder(
|
||||
start_executor=entry_node,
|
||||
name=self._workflow_id,
|
||||
checkpoint_storage=self._checkpoint_storage,
|
||||
)
|
||||
builder_kwargs: dict[str, Any] = {
|
||||
"start_executor": entry_node,
|
||||
"name": self._workflow_id,
|
||||
"checkpoint_storage": self._checkpoint_storage,
|
||||
}
|
||||
if self._max_iterations is not None:
|
||||
builder_kwargs["max_iterations"] = self._max_iterations
|
||||
builder = WorkflowBuilder(**builder_kwargs)
|
||||
|
||||
# Create all executors and wire sequential edges
|
||||
first_executor = self._create_executors_for_actions(actions, builder)
|
||||
@@ -944,6 +965,11 @@ class DeclarativeWorkflowBuilder:
|
||||
|
||||
last_executor = chain[-1]
|
||||
|
||||
# Skip terminators — they handle their own control flow
|
||||
action_def = getattr(last_executor, "_action_def", {})
|
||||
if isinstance(action_def, dict) and action_def.get("kind", "") in TERMINATOR_ACTIONS:
|
||||
return None
|
||||
|
||||
# Check if last executor is a structure with branch_exits
|
||||
# In that case, we return the structure so its exits can be collected
|
||||
if hasattr(last_executor, "branch_exits"):
|
||||
|
||||
+12
-2
@@ -430,17 +430,27 @@ class InvokeAzureAgentExecutor(DeclarativeActionExecutor):
|
||||
agent_config = self._action_def.get("agent")
|
||||
|
||||
if isinstance(agent_config, str):
|
||||
if agent_config.startswith("="):
|
||||
evaluated = state.eval_if_expression(agent_config)
|
||||
return str(evaluated) if evaluated is not None else None
|
||||
return agent_config
|
||||
|
||||
if isinstance(agent_config, dict):
|
||||
agent_dict = cast(dict[str, Any], agent_config)
|
||||
name = agent_dict.get("name")
|
||||
if name is not None and isinstance(name, str):
|
||||
# Support dynamic agent name from expression (would need async eval)
|
||||
if name.startswith("="):
|
||||
evaluated = state.eval_if_expression(name)
|
||||
return str(evaluated) if evaluated is not None else None
|
||||
return str(name)
|
||||
|
||||
agent_name = self._action_def.get("agentName")
|
||||
return str(agent_name) if isinstance(agent_name, str) else None
|
||||
if isinstance(agent_name, str):
|
||||
if agent_name.startswith("="):
|
||||
evaluated = state.eval_if_expression(agent_name)
|
||||
return str(evaluated) if evaluated is not None else None
|
||||
return agent_name
|
||||
return None
|
||||
|
||||
def _get_input_config(self) -> tuple[dict[str, Any], Any, str | None, int]:
|
||||
"""Parse input configuration.
|
||||
|
||||
+37
@@ -6,6 +6,7 @@ These executors handle simple actions like SetValue, SendActivity, etc.
|
||||
Each action becomes a node in the workflow graph.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import (
|
||||
@@ -80,6 +81,41 @@ class SetVariableExecutor(DeclarativeActionExecutor):
|
||||
await ctx.send_message(ActionComplete())
|
||||
|
||||
|
||||
class CreateConversationExecutor(DeclarativeActionExecutor):
|
||||
"""Executor for the CreateConversation action.
|
||||
|
||||
Generates a unique conversation ID and initialises a conversation entry
|
||||
in ``System.conversations``. The generated ID is stored at the state
|
||||
path specified by the ``conversationId`` parameter (if provided).
|
||||
"""
|
||||
|
||||
@handler
|
||||
async def handle_action(
|
||||
self,
|
||||
trigger: Any,
|
||||
ctx: WorkflowContext[ActionComplete],
|
||||
) -> None:
|
||||
"""Handle the CreateConversation action."""
|
||||
state = await self._ensure_state_initialized(ctx, trigger)
|
||||
|
||||
generated_id = str(uuid.uuid4())
|
||||
|
||||
# Store the generated ID at the requested path (e.g. "Local.myConvId")
|
||||
conversation_id_path = _get_variable_path(self._action_def, "conversationId")
|
||||
if conversation_id_path:
|
||||
state.set(conversation_id_path, generated_id)
|
||||
|
||||
# Initialise the conversation entry in System.conversations
|
||||
conversations: dict[str, Any] = state.get("System.conversations") or {}
|
||||
conversations[generated_id] = {
|
||||
"id": generated_id,
|
||||
"messages": [],
|
||||
}
|
||||
state.set("System.conversations", conversations)
|
||||
|
||||
await ctx.send_message(ActionComplete())
|
||||
|
||||
|
||||
class SetTextVariableExecutor(DeclarativeActionExecutor):
|
||||
"""Executor for the SetTextVariable action."""
|
||||
|
||||
@@ -560,6 +596,7 @@ class ParseValueExecutor(DeclarativeActionExecutor):
|
||||
|
||||
# Mapping of action kinds to executor classes
|
||||
BASIC_ACTION_EXECUTORS: dict[str, type[DeclarativeActionExecutor]] = {
|
||||
"CreateConversation": CreateConversationExecutor,
|
||||
"SetValue": SetValueExecutor,
|
||||
"SetVariable": SetVariableExecutor,
|
||||
"SetTextVariable": SetTextVariableExecutor,
|
||||
|
||||
+1
@@ -496,6 +496,7 @@ class JoinExecutor(DeclarativeActionExecutor):
|
||||
ctx: WorkflowContext[ActionComplete],
|
||||
) -> None:
|
||||
"""Simply pass through to continue the workflow."""
|
||||
await self._ensure_state_initialized(ctx, trigger)
|
||||
await ctx.send_message(ActionComplete())
|
||||
|
||||
|
||||
|
||||
@@ -91,6 +91,7 @@ class WorkflowFactory:
|
||||
bindings: Mapping[str, Any] | None = None,
|
||||
env_file: str | None = None,
|
||||
checkpoint_storage: CheckpointStorage | None = None,
|
||||
max_iterations: int | None = None,
|
||||
) -> None:
|
||||
"""Initialize the workflow factory.
|
||||
|
||||
@@ -101,6 +102,9 @@ class WorkflowFactory:
|
||||
bindings: Optional function bindings for tool calls within workflow actions.
|
||||
env_file: Optional path to .env file for environment variables used in agent creation.
|
||||
checkpoint_storage: Optional checkpoint storage enabling pause/resume functionality.
|
||||
max_iterations: Optional maximum runner supersteps. Overrides the YAML ``maxTurns``
|
||||
field and the core default (100). Workflows with ``GotoAction`` loops (e.g.
|
||||
DeepResearch) typically need a higher value.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
@@ -138,6 +142,7 @@ class WorkflowFactory:
|
||||
self._agents: dict[str, SupportsAgentRun | AgentExecutor] = dict(agents) if agents else {}
|
||||
self._bindings: dict[str, Any] = dict(bindings) if bindings else {}
|
||||
self._checkpoint_storage = checkpoint_storage
|
||||
self._max_iterations = max_iterations
|
||||
|
||||
def create_workflow_from_yaml_path(
|
||||
self,
|
||||
@@ -379,6 +384,7 @@ class WorkflowFactory:
|
||||
workflow_id=name,
|
||||
agents=agents,
|
||||
checkpoint_storage=self._checkpoint_storage,
|
||||
max_iterations=self._max_iterations,
|
||||
)
|
||||
workflow = graph_builder.build()
|
||||
except ValueError as e:
|
||||
|
||||
@@ -1,215 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Action handlers for declarative workflow execution.
|
||||
|
||||
This module provides the ActionHandler protocol and registry for executing
|
||||
workflow actions defined in YAML. Each action type (InvokeAzureAgent, Foreach, etc.)
|
||||
has a corresponding handler registered via the @action_handler decorator.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._state import WorkflowState
|
||||
|
||||
logger = logging.getLogger("agent_framework.declarative")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionContext:
|
||||
"""Context passed to action handlers during execution.
|
||||
|
||||
Provides access to workflow state, the action definition, and methods
|
||||
for executing nested actions (for control flow constructs like Foreach).
|
||||
"""
|
||||
|
||||
state: WorkflowState
|
||||
"""The current workflow state with variables and agent results."""
|
||||
|
||||
action: dict[str, Any]
|
||||
"""The action definition from the YAML."""
|
||||
|
||||
execute_actions: ExecuteActionsFn
|
||||
"""Function to execute a list of nested actions (for Foreach, If, etc.)."""
|
||||
|
||||
agents: dict[str, Any]
|
||||
"""Registry of agent instances by name."""
|
||||
|
||||
bindings: dict[str, Any]
|
||||
"""Function bindings for tool calls."""
|
||||
|
||||
run_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
"""Kwargs from workflow.run() to forward to agent invocations."""
|
||||
|
||||
@property
|
||||
def action_id(self) -> str | None:
|
||||
"""Get the action's unique identifier."""
|
||||
return self.action.get("id")
|
||||
|
||||
@property
|
||||
def display_name(self) -> str | None:
|
||||
"""Get the action's human-readable display name for debugging/logging."""
|
||||
return self.action.get("displayName")
|
||||
|
||||
@property
|
||||
def action_kind(self) -> str | None:
|
||||
"""Get the action's type/kind."""
|
||||
return self.action.get("kind")
|
||||
|
||||
|
||||
# Type alias for the nested action executor function
|
||||
ExecuteActionsFn = Callable[
|
||||
[list[dict[str, Any]], "WorkflowState"],
|
||||
AsyncGenerator["WorkflowEvent", None],
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkflowEvent:
|
||||
"""Base class for events emitted during workflow execution."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextOutputEvent(WorkflowEvent):
|
||||
"""Event emitted when text should be sent to the user."""
|
||||
|
||||
text: str
|
||||
"""The text content to output."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttachmentOutputEvent(WorkflowEvent):
|
||||
"""Event emitted when an attachment should be sent to the user."""
|
||||
|
||||
content: Any
|
||||
"""The attachment content."""
|
||||
|
||||
content_type: str = "application/octet-stream"
|
||||
"""The MIME type of the attachment."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResponseEvent(WorkflowEvent):
|
||||
"""Event emitted when an agent produces a response."""
|
||||
|
||||
agent_name: str
|
||||
"""The name of the agent that produced the response."""
|
||||
|
||||
text: str | None
|
||||
"""The text content of the response, if any."""
|
||||
|
||||
messages: list[Any]
|
||||
"""The messages from the agent response."""
|
||||
|
||||
tool_calls: list[Any] | None = None
|
||||
"""Any tool calls made by the agent."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentStreamingChunkEvent(WorkflowEvent):
|
||||
"""Event emitted for streaming chunks from an agent."""
|
||||
|
||||
agent_name: str
|
||||
"""The name of the agent producing the chunk."""
|
||||
|
||||
chunk: str
|
||||
"""The streaming chunk content."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomEvent(WorkflowEvent):
|
||||
"""Custom event emitted via EmitEvent action."""
|
||||
|
||||
name: str
|
||||
"""The event name."""
|
||||
|
||||
data: Any
|
||||
"""The event data."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoopControlSignal(WorkflowEvent):
|
||||
"""Signal for loop control (break/continue)."""
|
||||
|
||||
signal_type: str
|
||||
"""Either 'break' or 'continue'."""
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ActionHandler(Protocol):
|
||||
"""Protocol for action handlers.
|
||||
|
||||
Action handlers are async generators that execute a single action type
|
||||
and yield events as they process. They receive an ActionContext with
|
||||
the current state, action definition, and utilities for nested execution.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
ctx: ActionContext,
|
||||
) -> AsyncGenerator[WorkflowEvent]:
|
||||
"""Execute the action and yield events.
|
||||
|
||||
Args:
|
||||
ctx: The action context containing state, action definition, and utilities
|
||||
|
||||
Yields:
|
||||
WorkflowEvent instances as the action executes
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
# Global registry of action handlers
|
||||
_ACTION_HANDLERS: dict[str, ActionHandler] = {}
|
||||
|
||||
|
||||
def action_handler(action_kind: str) -> Callable[[ActionHandler], ActionHandler]:
|
||||
"""Decorator to register an action handler for a specific action type.
|
||||
|
||||
Args:
|
||||
action_kind: The action type this handler processes (e.g., 'InvokeAzureAgent')
|
||||
|
||||
Example:
|
||||
@action_handler("SetValue")
|
||||
async def handle_set_value(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent, None]:
|
||||
path = ctx.action.get("path")
|
||||
value = ctx.state.eval_if_expression(ctx.action.get("value"))
|
||||
ctx.state.set(path, value)
|
||||
return
|
||||
yield # Make it a generator
|
||||
"""
|
||||
|
||||
def decorator(func: ActionHandler) -> ActionHandler:
|
||||
_ACTION_HANDLERS[action_kind] = func
|
||||
logger.debug(f"Registered action handler for '{action_kind}'")
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def get_action_handler(action_kind: str) -> ActionHandler | None:
|
||||
"""Get the registered handler for an action type.
|
||||
|
||||
Args:
|
||||
action_kind: The action type to look up
|
||||
|
||||
Returns:
|
||||
The registered ActionHandler, or None if not found
|
||||
"""
|
||||
return _ACTION_HANDLERS.get(action_kind)
|
||||
|
||||
|
||||
def list_action_handlers() -> list[str]:
|
||||
"""List all registered action handler types.
|
||||
|
||||
Returns:
|
||||
A list of registered action type names
|
||||
"""
|
||||
return list(_ACTION_HANDLERS.keys())
|
||||
@@ -1,321 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Human-in-the-loop action handlers for declarative workflows.
|
||||
|
||||
This module implements handlers for human input patterns:
|
||||
- Question: Request human input with validation
|
||||
- RequestExternalInput: Request input from external system
|
||||
- ExternalLoop processing: Loop while waiting for external input
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from ._handlers import (
|
||||
ActionContext,
|
||||
WorkflowEvent,
|
||||
action_handler,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._state import WorkflowState
|
||||
|
||||
logger = logging.getLogger("agent_framework.declarative")
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuestionRequest(WorkflowEvent):
|
||||
"""Event emitted when the workflow needs user input via Question action.
|
||||
|
||||
When this event is yielded, the workflow execution should pause
|
||||
and wait for user input to be provided via workflow.send_response().
|
||||
|
||||
This is used by the Question, RequestExternalInput, and WaitForInput
|
||||
action handlers in the non-graph workflow path.
|
||||
"""
|
||||
|
||||
request_id: str
|
||||
"""Unique identifier for this request."""
|
||||
|
||||
prompt: str | None
|
||||
"""The prompt/question to display to the user."""
|
||||
|
||||
variable: str
|
||||
"""The variable where the response should be stored."""
|
||||
|
||||
validation: dict[str, Any] | None = None
|
||||
"""Optional validation rules for the input."""
|
||||
|
||||
choices: list[str] | None = None
|
||||
"""Optional list of valid choices."""
|
||||
|
||||
default_value: Any = None
|
||||
"""Default value if no input is provided."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExternalLoopEvent(WorkflowEvent):
|
||||
"""Event emitted when entering an external input loop.
|
||||
|
||||
This event signals that the action is waiting for external input
|
||||
in a loop pattern (e.g., input.externalLoop.when condition).
|
||||
"""
|
||||
|
||||
action_id: str
|
||||
"""The ID of the action that requires external input."""
|
||||
|
||||
iteration: int
|
||||
"""The current iteration number (0-based)."""
|
||||
|
||||
condition_expression: str
|
||||
"""The PowerFx condition that must become false to exit the loop."""
|
||||
|
||||
|
||||
@action_handler("Question")
|
||||
async def handle_question(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent]: # noqa: RUF029
|
||||
"""Handle Question action - request human input with optional validation.
|
||||
|
||||
Action schema:
|
||||
kind: Question
|
||||
id: ask_name
|
||||
variable: Local.userName
|
||||
prompt: What is your name?
|
||||
validation:
|
||||
required: true
|
||||
minLength: 1
|
||||
maxLength: 100
|
||||
choices: # optional - present as multiple choice
|
||||
- Option A
|
||||
- Option B
|
||||
default: Option A # optional default value
|
||||
|
||||
The handler emits a QuestionRequest and expects the workflow runner
|
||||
to capture and provide the response before continuing.
|
||||
"""
|
||||
question_id = ctx.action.get("id", "question")
|
||||
variable = ctx.action.get("variable")
|
||||
prompt = ctx.action.get("prompt")
|
||||
question: dict[str, Any] | Any = ctx.action.get("question", {})
|
||||
validation = ctx.action.get("validation", {})
|
||||
choices = ctx.action.get("choices")
|
||||
default_value = ctx.action.get("default")
|
||||
|
||||
if not variable:
|
||||
logger.warning("Question action missing 'variable' property")
|
||||
return
|
||||
|
||||
# Evaluate prompt if it's an expression (support both 'prompt' and 'question.text')
|
||||
prompt_text: Any | None = None
|
||||
if isinstance(question, dict):
|
||||
question_dict: dict[str, Any] = cast(dict[str, Any], question)
|
||||
prompt_text = prompt or question_dict.get("text")
|
||||
else:
|
||||
prompt_text = prompt
|
||||
evaluated_prompt = ctx.state.eval_if_expression(prompt_text) if prompt_text else None
|
||||
|
||||
# Evaluate choices if they're expressions
|
||||
evaluated_choices = None
|
||||
if choices:
|
||||
evaluated_choices = [ctx.state.eval_if_expression(c) if isinstance(c, str) else c for c in choices]
|
||||
|
||||
logger.debug(f"Question: requesting input for {variable}")
|
||||
|
||||
# Emit the request event
|
||||
yield QuestionRequest(
|
||||
request_id=question_id,
|
||||
prompt=str(evaluated_prompt) if evaluated_prompt else None,
|
||||
variable=variable,
|
||||
validation=validation,
|
||||
choices=evaluated_choices,
|
||||
default_value=default_value,
|
||||
)
|
||||
|
||||
# Apply default value if specified (for non-interactive scenarios)
|
||||
if default_value is not None:
|
||||
evaluated_default = ctx.state.eval_if_expression(default_value)
|
||||
ctx.state.set(variable, evaluated_default)
|
||||
|
||||
|
||||
@action_handler("RequestExternalInput")
|
||||
async def handle_request_external_input(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent]: # noqa: RUF029
|
||||
"""Handle RequestExternalInput action - request input from external system.
|
||||
|
||||
Action schema:
|
||||
kind: RequestExternalInput
|
||||
id: get_approval
|
||||
variable: Local.approval
|
||||
prompt: Please approve or reject the request
|
||||
timeout: 300 # seconds
|
||||
default: "No feedback provided" # optional default value
|
||||
output:
|
||||
response: Local.approvalResponse
|
||||
timestamp: Local.approvalTime
|
||||
|
||||
Similar to Question but designed for external system integration
|
||||
rather than direct human input.
|
||||
"""
|
||||
request_id = ctx.action.get("id", "external_input")
|
||||
variable = ctx.action.get("variable")
|
||||
prompt = ctx.action.get("prompt")
|
||||
timeout = ctx.action.get("timeout") # seconds
|
||||
default_value = ctx.action.get("default")
|
||||
_output = ctx.action.get("output", {}) # Reserved for future use
|
||||
|
||||
if not variable:
|
||||
logger.warning("RequestExternalInput action missing 'variable' property")
|
||||
return
|
||||
|
||||
# Extract prompt text (support both 'prompt' string and 'prompt.text' object)
|
||||
prompt_text: Any | None = None
|
||||
if isinstance(prompt, dict):
|
||||
prompt_dict: dict[str, Any] = cast(dict[str, Any], prompt)
|
||||
prompt_text = prompt_dict.get("text")
|
||||
else:
|
||||
prompt_text = prompt
|
||||
|
||||
# Evaluate prompt if it's an expression
|
||||
evaluated_prompt = ctx.state.eval_if_expression(prompt_text) if prompt_text else None
|
||||
|
||||
logger.debug(f"RequestExternalInput: requesting input for {variable}")
|
||||
|
||||
# Emit the request event
|
||||
yield QuestionRequest(
|
||||
request_id=request_id,
|
||||
prompt=str(evaluated_prompt) if evaluated_prompt else None,
|
||||
variable=variable,
|
||||
validation={"timeout": timeout} if timeout else None,
|
||||
default_value=default_value,
|
||||
)
|
||||
|
||||
# Apply default value if specified (for non-interactive scenarios)
|
||||
if default_value is not None:
|
||||
evaluated_default = ctx.state.eval_if_expression(default_value)
|
||||
ctx.state.set(variable, evaluated_default)
|
||||
|
||||
|
||||
@action_handler("WaitForInput")
|
||||
async def handle_wait_for_input(ctx: ActionContext) -> AsyncGenerator[WorkflowEvent]: # noqa: RUF029
|
||||
"""Handle WaitForInput action - pause and wait for external input.
|
||||
|
||||
Action schema:
|
||||
kind: WaitForInput
|
||||
id: wait_for_response
|
||||
variable: Local.response
|
||||
message: Waiting for user response...
|
||||
|
||||
This is a simpler form of RequestExternalInput that just pauses
|
||||
execution until input is provided.
|
||||
"""
|
||||
wait_id = ctx.action.get("id", "wait")
|
||||
variable = ctx.action.get("variable")
|
||||
message = ctx.action.get("message")
|
||||
|
||||
if not variable:
|
||||
logger.warning("WaitForInput action missing 'variable' property")
|
||||
return
|
||||
|
||||
# Evaluate message if it's an expression
|
||||
evaluated_message = ctx.state.eval_if_expression(message) if message else None
|
||||
|
||||
logger.debug(f"WaitForInput: waiting for {variable}")
|
||||
|
||||
yield QuestionRequest(
|
||||
request_id=wait_id,
|
||||
prompt=str(evaluated_message) if evaluated_message else None,
|
||||
variable=variable,
|
||||
)
|
||||
|
||||
|
||||
def process_external_loop(
|
||||
input_config: dict[str, Any],
|
||||
state: WorkflowState,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Process the externalLoop.when pattern from action input.
|
||||
|
||||
This function evaluates the externalLoop.when condition to determine
|
||||
if the action should continue looping for external input.
|
||||
|
||||
Args:
|
||||
input_config: The input configuration containing externalLoop
|
||||
state: The workflow state for expression evaluation
|
||||
|
||||
Returns:
|
||||
Tuple of (should_continue_loop, condition_expression)
|
||||
- should_continue_loop: True if the loop should continue
|
||||
- condition_expression: The original condition expression for diagnostics
|
||||
"""
|
||||
external_loop = input_config.get("externalLoop", {})
|
||||
when_condition = external_loop.get("when")
|
||||
|
||||
if not when_condition:
|
||||
return (False, None)
|
||||
|
||||
# Evaluate the condition
|
||||
result = state.eval(when_condition)
|
||||
|
||||
# The loop continues while the condition is True
|
||||
should_continue = bool(result) if result is not None else False
|
||||
|
||||
logger.debug(f"ExternalLoop condition '{when_condition[:50]}' evaluated to {should_continue}")
|
||||
|
||||
return (should_continue, when_condition)
|
||||
|
||||
|
||||
def validate_input_response(
|
||||
value: Any,
|
||||
validation: dict[str, Any] | None,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Validate input response against validation rules.
|
||||
|
||||
Args:
|
||||
value: The input value to validate
|
||||
validation: Validation rules from the Question action
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
if not validation:
|
||||
return (True, None)
|
||||
|
||||
# Check required
|
||||
if validation.get("required") and (value is None or value == ""):
|
||||
return (False, "This field is required")
|
||||
|
||||
if value is None:
|
||||
return (True, None)
|
||||
|
||||
# Check string length
|
||||
if isinstance(value, str):
|
||||
min_length = validation.get("minLength")
|
||||
max_length = validation.get("maxLength")
|
||||
|
||||
if min_length is not None and len(value) < min_length:
|
||||
return (False, f"Minimum length is {min_length}")
|
||||
|
||||
if max_length is not None and len(value) > max_length:
|
||||
return (False, f"Maximum length is {max_length}")
|
||||
|
||||
# Check numeric range
|
||||
if isinstance(value, (int, float)):
|
||||
min_value = validation.get("min")
|
||||
max_value = validation.get("max")
|
||||
|
||||
if min_value is not None and value < min_value:
|
||||
return (False, f"Minimum value is {min_value}")
|
||||
|
||||
if max_value is not None and value > max_value:
|
||||
return (False, f"Maximum value is {max_value}")
|
||||
|
||||
# Check pattern (regex)
|
||||
pattern = validation.get("pattern")
|
||||
if pattern and isinstance(value, str):
|
||||
import re
|
||||
|
||||
if not re.match(pattern, value):
|
||||
return (False, f"Value does not match pattern: {pattern}")
|
||||
|
||||
return (True, None)
|
||||
@@ -12,6 +12,7 @@ This module provides state management for declarative workflows, handling:
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -107,11 +108,15 @@ class WorkflowState:
|
||||
self._inputs: dict[str, Any] = dict(inputs) if inputs else {}
|
||||
self._local: dict[str, Any] = {}
|
||||
self._outputs: dict[str, Any] = {}
|
||||
conversation_id = str(uuid.uuid4())
|
||||
self._system: dict[str, Any] = {
|
||||
"ConversationId": "default",
|
||||
"ConversationId": conversation_id,
|
||||
"LastMessage": {"Text": "", "Id": ""},
|
||||
"LastMessageText": "",
|
||||
"LastMessageId": "",
|
||||
"conversations": {
|
||||
conversation_id: {"id": conversation_id, "messages": []},
|
||||
},
|
||||
}
|
||||
self._agent: dict[str, Any] = {}
|
||||
self._conversation: dict[str, Any] = {
|
||||
|
||||
@@ -1,348 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for additional action handlers (conversation, variables, etc.)."""
|
||||
|
||||
import pytest
|
||||
|
||||
import agent_framework_declarative._workflows._actions_basic # noqa: F401
|
||||
import agent_framework_declarative._workflows._actions_control_flow # noqa: F401
|
||||
from agent_framework_declarative._workflows._handlers import get_action_handler
|
||||
from agent_framework_declarative._workflows._state import WorkflowState
|
||||
|
||||
|
||||
def create_action_context(action: dict, state: WorkflowState | None = None):
|
||||
"""Create a minimal action context for testing."""
|
||||
from agent_framework_declarative._workflows._handlers import ActionContext
|
||||
|
||||
if state is None:
|
||||
state = WorkflowState()
|
||||
|
||||
async def execute_actions(actions, state):
|
||||
for act in actions:
|
||||
handler = get_action_handler(act.get("kind"))
|
||||
if handler:
|
||||
async for event in handler(
|
||||
ActionContext(
|
||||
state=state,
|
||||
action=act,
|
||||
execute_actions=execute_actions,
|
||||
agents={},
|
||||
bindings={},
|
||||
)
|
||||
):
|
||||
yield event
|
||||
|
||||
return ActionContext(
|
||||
state=state,
|
||||
action=action,
|
||||
execute_actions=execute_actions,
|
||||
agents={},
|
||||
bindings={},
|
||||
)
|
||||
|
||||
|
||||
class TestSetTextVariableHandler:
|
||||
"""Tests for SetTextVariable action handler."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_text_variable_simple(self):
|
||||
"""Test setting a simple text variable."""
|
||||
ctx = create_action_context({
|
||||
"kind": "SetTextVariable",
|
||||
"variable": "Local.greeting",
|
||||
"value": "Hello, World!",
|
||||
})
|
||||
|
||||
handler = get_action_handler("SetTextVariable")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
assert ctx.state.get("Local.greeting") == "Hello, World!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_text_variable_with_interpolation(self):
|
||||
"""Test setting text with variable interpolation."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.name", "Alice")
|
||||
|
||||
ctx = create_action_context(
|
||||
{
|
||||
"kind": "SetTextVariable",
|
||||
"variable": "Local.message",
|
||||
"value": "Hello, {Local.name}!",
|
||||
},
|
||||
state=state,
|
||||
)
|
||||
|
||||
handler = get_action_handler("SetTextVariable")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
assert ctx.state.get("Local.message") == "Hello, Alice!"
|
||||
|
||||
|
||||
class TestResetVariableHandler:
|
||||
"""Tests for ResetVariable action handler."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_variable(self):
|
||||
"""Test resetting a variable to None."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.counter", 5)
|
||||
|
||||
ctx = create_action_context(
|
||||
{
|
||||
"kind": "ResetVariable",
|
||||
"variable": "Local.counter",
|
||||
},
|
||||
state=state,
|
||||
)
|
||||
|
||||
handler = get_action_handler("ResetVariable")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
assert ctx.state.get("Local.counter") is None
|
||||
|
||||
|
||||
class TestSetMultipleVariablesHandler:
|
||||
"""Tests for SetMultipleVariables action handler."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_multiple_variables(self):
|
||||
"""Test setting multiple variables at once."""
|
||||
ctx = create_action_context({
|
||||
"kind": "SetMultipleVariables",
|
||||
"variables": [
|
||||
{"variable": "Local.a", "value": 1},
|
||||
{"variable": "Local.b", "value": 2},
|
||||
{"variable": "Local.c", "value": "three"},
|
||||
],
|
||||
})
|
||||
|
||||
handler = get_action_handler("SetMultipleVariables")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
assert ctx.state.get("Local.a") == 1
|
||||
assert ctx.state.get("Local.b") == 2
|
||||
assert ctx.state.get("Local.c") == "three"
|
||||
|
||||
|
||||
class TestClearAllVariablesHandler:
|
||||
"""Tests for ClearAllVariables action handler."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_all_variables(self):
|
||||
"""Test clearing all turn-scoped variables."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.a", 1)
|
||||
state.set("Local.b", 2)
|
||||
state.set("Workflow.Outputs.result", "kept")
|
||||
|
||||
ctx = create_action_context(
|
||||
{
|
||||
"kind": "ClearAllVariables",
|
||||
},
|
||||
state=state,
|
||||
)
|
||||
|
||||
handler = get_action_handler("ClearAllVariables")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
assert ctx.state.get("Local.a") is None
|
||||
assert ctx.state.get("Local.b") is None
|
||||
# Workflow outputs should be preserved
|
||||
assert ctx.state.get("Workflow.Outputs.result") == "kept"
|
||||
|
||||
|
||||
class TestCreateConversationHandler:
|
||||
"""Tests for CreateConversation action handler."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_conversation_with_output_binding(self):
|
||||
"""Test creating a new conversation with output variable binding.
|
||||
|
||||
The conversationId field specifies the OUTPUT variable where the
|
||||
auto-generated conversation ID is stored.
|
||||
"""
|
||||
ctx = create_action_context({
|
||||
"kind": "CreateConversation",
|
||||
"conversationId": "Local.myConvId", # Output variable
|
||||
})
|
||||
|
||||
handler = get_action_handler("CreateConversation")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
# Check conversation was created with auto-generated ID
|
||||
conversations = ctx.state.get("System.conversations")
|
||||
assert conversations is not None
|
||||
assert len(conversations) == 1
|
||||
|
||||
# Get the generated ID
|
||||
generated_id = list(conversations.keys())[0]
|
||||
assert conversations[generated_id]["messages"] == []
|
||||
|
||||
# Check output binding - the ID should be stored in the specified variable
|
||||
assert ctx.state.get("Local.myConvId") == generated_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_conversation_legacy_output(self):
|
||||
"""Test creating a conversation with legacy output binding."""
|
||||
ctx = create_action_context({
|
||||
"kind": "CreateConversation",
|
||||
"output": {
|
||||
"conversationId": "Local.myConvId",
|
||||
},
|
||||
})
|
||||
|
||||
handler = get_action_handler("CreateConversation")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
# Check conversation was created
|
||||
conversations = ctx.state.get("System.conversations")
|
||||
assert conversations is not None
|
||||
assert len(conversations) == 1
|
||||
|
||||
# Get the generated ID
|
||||
generated_id = list(conversations.keys())[0]
|
||||
|
||||
# Check legacy output binding
|
||||
assert ctx.state.get("Local.myConvId") == generated_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_conversation_auto_id(self):
|
||||
"""Test creating a conversation with auto-generated ID."""
|
||||
ctx = create_action_context({
|
||||
"kind": "CreateConversation",
|
||||
})
|
||||
|
||||
handler = get_action_handler("CreateConversation")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
# Check conversation was created with some ID
|
||||
conversations = ctx.state.get("System.conversations")
|
||||
assert conversations is not None
|
||||
assert len(conversations) == 1
|
||||
|
||||
|
||||
class TestAddConversationMessageHandler:
|
||||
"""Tests for AddConversationMessage action handler."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_conversation_message(self):
|
||||
"""Test adding a message to a conversation."""
|
||||
state = WorkflowState()
|
||||
state.set(
|
||||
"System.conversations",
|
||||
{
|
||||
"conv-123": {"id": "conv-123", "messages": []},
|
||||
},
|
||||
)
|
||||
|
||||
ctx = create_action_context(
|
||||
{
|
||||
"kind": "AddConversationMessage",
|
||||
"conversationId": "conv-123",
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": "Hello!",
|
||||
},
|
||||
},
|
||||
state=state,
|
||||
)
|
||||
|
||||
handler = get_action_handler("AddConversationMessage")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
conversations = ctx.state.get("System.conversations")
|
||||
assert len(conversations["conv-123"]["messages"]) == 1
|
||||
assert conversations["conv-123"]["messages"][0]["content"] == "Hello!"
|
||||
|
||||
|
||||
class TestEndWorkflowHandler:
|
||||
"""Tests for EndWorkflow action handler."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_workflow_signal(self):
|
||||
"""Test that EndWorkflow emits correct signal."""
|
||||
from agent_framework_declarative._workflows._actions_control_flow import EndWorkflowSignal
|
||||
|
||||
ctx = create_action_context({
|
||||
"kind": "EndWorkflow",
|
||||
"reason": "Completed successfully",
|
||||
})
|
||||
|
||||
handler = get_action_handler("EndWorkflow")
|
||||
events = [e async for e in handler(ctx)]
|
||||
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], EndWorkflowSignal)
|
||||
assert events[0].reason == "Completed successfully"
|
||||
|
||||
|
||||
class TestEndConversationHandler:
|
||||
"""Tests for EndConversation action handler."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_conversation_signal(self):
|
||||
"""Test that EndConversation emits correct signal."""
|
||||
from agent_framework_declarative._workflows._actions_control_flow import EndConversationSignal
|
||||
|
||||
ctx = create_action_context({
|
||||
"kind": "EndConversation",
|
||||
"conversationId": "conv-123",
|
||||
})
|
||||
|
||||
handler = get_action_handler("EndConversation")
|
||||
events = [e async for e in handler(ctx)]
|
||||
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], EndConversationSignal)
|
||||
assert events[0].conversation_id == "conv-123"
|
||||
|
||||
|
||||
class TestConditionGroupWithElseActions:
|
||||
"""Tests for ConditionGroup with elseActions."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_condition_group_else_actions(self):
|
||||
"""Test that elseActions execute when no condition matches."""
|
||||
ctx = create_action_context({
|
||||
"kind": "ConditionGroup",
|
||||
"conditions": [
|
||||
{
|
||||
"condition": False,
|
||||
"actions": [
|
||||
{"kind": "SetValue", "path": "Local.result", "value": "matched"},
|
||||
],
|
||||
},
|
||||
],
|
||||
"elseActions": [
|
||||
{"kind": "SetValue", "path": "Local.result", "value": "else"},
|
||||
],
|
||||
})
|
||||
|
||||
handler = get_action_handler("ConditionGroup")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
assert ctx.state.get("Local.result") == "else"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_condition_group_match_skips_else(self):
|
||||
"""Test that elseActions don't execute when a condition matches."""
|
||||
ctx = create_action_context({
|
||||
"kind": "ConditionGroup",
|
||||
"conditions": [
|
||||
{
|
||||
"condition": True,
|
||||
"actions": [
|
||||
{"kind": "SetValue", "path": "Local.result", "value": "matched"},
|
||||
],
|
||||
},
|
||||
],
|
||||
"elseActions": [
|
||||
{"kind": "SetValue", "path": "Local.result", "value": "else"},
|
||||
],
|
||||
})
|
||||
|
||||
handler = get_action_handler("ConditionGroup")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
assert ctx.state.get("Local.result") == "matched"
|
||||
@@ -1,286 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for human-in-the-loop action handlers."""
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework_declarative._workflows._handlers import ActionContext, get_action_handler
|
||||
from agent_framework_declarative._workflows._human_input import (
|
||||
QuestionRequest,
|
||||
process_external_loop,
|
||||
validate_input_response,
|
||||
)
|
||||
from agent_framework_declarative._workflows._state import WorkflowState
|
||||
|
||||
|
||||
def create_action_context(action: dict, state: WorkflowState | None = None):
|
||||
"""Create a minimal action context for testing."""
|
||||
if state is None:
|
||||
state = WorkflowState()
|
||||
|
||||
async def execute_actions(actions, state):
|
||||
for act in actions:
|
||||
handler = get_action_handler(act.get("kind"))
|
||||
if handler:
|
||||
async for event in handler(
|
||||
ActionContext(
|
||||
state=state,
|
||||
action=act,
|
||||
execute_actions=execute_actions,
|
||||
agents={},
|
||||
bindings={},
|
||||
)
|
||||
):
|
||||
yield event
|
||||
|
||||
return ActionContext(
|
||||
state=state,
|
||||
action=action,
|
||||
execute_actions=execute_actions,
|
||||
agents={},
|
||||
bindings={},
|
||||
)
|
||||
|
||||
|
||||
class TestQuestionHandler:
|
||||
"""Tests for Question action handler."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_question_emits_request_info_event(self):
|
||||
"""Test that Question handler emits QuestionRequest."""
|
||||
ctx = create_action_context({
|
||||
"kind": "Question",
|
||||
"id": "ask_name",
|
||||
"variable": "Local.userName",
|
||||
"prompt": "What is your name?",
|
||||
})
|
||||
|
||||
handler = get_action_handler("Question")
|
||||
events = [e async for e in handler(ctx)]
|
||||
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], QuestionRequest)
|
||||
assert events[0].request_id == "ask_name"
|
||||
assert events[0].prompt == "What is your name?"
|
||||
assert events[0].variable == "Local.userName"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_question_with_choices(self):
|
||||
"""Test Question with multiple choice options."""
|
||||
ctx = create_action_context({
|
||||
"kind": "Question",
|
||||
"id": "ask_choice",
|
||||
"variable": "Local.selection",
|
||||
"prompt": "Select an option:",
|
||||
"choices": ["Option A", "Option B", "Option C"],
|
||||
"default": "Option A",
|
||||
})
|
||||
|
||||
handler = get_action_handler("Question")
|
||||
events = [e async for e in handler(ctx)]
|
||||
|
||||
assert len(events) == 1
|
||||
event = events[0]
|
||||
assert isinstance(event, QuestionRequest)
|
||||
assert event.choices == ["Option A", "Option B", "Option C"]
|
||||
assert event.default_value == "Option A"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_question_with_validation(self):
|
||||
"""Test Question with validation rules."""
|
||||
ctx = create_action_context({
|
||||
"kind": "Question",
|
||||
"id": "ask_email",
|
||||
"variable": "Local.email",
|
||||
"prompt": "Enter your email:",
|
||||
"validation": {
|
||||
"required": True,
|
||||
"pattern": r"^[\w\.-]+@[\w\.-]+\.\w+$",
|
||||
},
|
||||
})
|
||||
|
||||
handler = get_action_handler("Question")
|
||||
events = [e async for e in handler(ctx)]
|
||||
|
||||
assert len(events) == 1
|
||||
event = events[0]
|
||||
assert event.validation == {
|
||||
"required": True,
|
||||
"pattern": r"^[\w\.-]+@[\w\.-]+\.\w+$",
|
||||
}
|
||||
|
||||
|
||||
class TestRequestExternalInputHandler:
|
||||
"""Tests for RequestExternalInput action handler."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_external_input(self):
|
||||
"""Test RequestExternalInput handler emits event."""
|
||||
ctx = create_action_context({
|
||||
"kind": "RequestExternalInput",
|
||||
"id": "get_approval",
|
||||
"variable": "Local.approval",
|
||||
"prompt": "Please approve or reject",
|
||||
"timeout": 300,
|
||||
})
|
||||
|
||||
handler = get_action_handler("RequestExternalInput")
|
||||
events = [e async for e in handler(ctx)]
|
||||
|
||||
assert len(events) == 1
|
||||
event = events[0]
|
||||
assert isinstance(event, QuestionRequest)
|
||||
assert event.request_id == "get_approval"
|
||||
assert event.variable == "Local.approval"
|
||||
assert event.validation == {"timeout": 300}
|
||||
|
||||
|
||||
class TestWaitForInputHandler:
|
||||
"""Tests for WaitForInput action handler."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_for_input(self):
|
||||
"""Test WaitForInput handler."""
|
||||
ctx = create_action_context({
|
||||
"kind": "WaitForInput",
|
||||
"id": "wait",
|
||||
"variable": "Local.response",
|
||||
"message": "Waiting...",
|
||||
})
|
||||
|
||||
handler = get_action_handler("WaitForInput")
|
||||
events = [e async for e in handler(ctx)]
|
||||
|
||||
assert len(events) == 1
|
||||
event = events[0]
|
||||
assert isinstance(event, QuestionRequest)
|
||||
assert event.request_id == "wait"
|
||||
assert event.prompt == "Waiting..."
|
||||
|
||||
|
||||
class TestProcessExternalLoop:
|
||||
"""Tests for process_external_loop helper function."""
|
||||
|
||||
def test_no_external_loop(self):
|
||||
"""Test when no external loop is configured."""
|
||||
state = WorkflowState()
|
||||
result, expr = process_external_loop({}, state)
|
||||
|
||||
assert result is False
|
||||
assert expr is None
|
||||
|
||||
def test_external_loop_true_condition(self):
|
||||
"""Test when external loop condition evaluates to true."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.isComplete", False)
|
||||
|
||||
input_config = {
|
||||
"externalLoop": {
|
||||
"when": "=!Local.isComplete",
|
||||
},
|
||||
}
|
||||
|
||||
result, expr = process_external_loop(input_config, state)
|
||||
|
||||
# !False = True, so loop should continue
|
||||
assert result is True
|
||||
assert expr == "=!Local.isComplete"
|
||||
|
||||
def test_external_loop_false_condition(self):
|
||||
"""Test when external loop condition evaluates to false."""
|
||||
state = WorkflowState()
|
||||
state.set("Local.isComplete", True)
|
||||
|
||||
input_config = {
|
||||
"externalLoop": {
|
||||
"when": "=!Local.isComplete",
|
||||
},
|
||||
}
|
||||
|
||||
result, expr = process_external_loop(input_config, state)
|
||||
|
||||
# !True = False, so loop should stop
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestValidateInputResponse:
|
||||
"""Tests for validate_input_response helper function."""
|
||||
|
||||
def test_no_validation(self):
|
||||
"""Test with no validation rules."""
|
||||
is_valid, error = validate_input_response("any value", None)
|
||||
assert is_valid is True
|
||||
assert error is None
|
||||
|
||||
def test_required_valid(self):
|
||||
"""Test required validation with valid value."""
|
||||
is_valid, error = validate_input_response("value", {"required": True})
|
||||
assert is_valid is True
|
||||
assert error is None
|
||||
|
||||
def test_required_empty_string(self):
|
||||
"""Test required validation with empty string."""
|
||||
is_valid, error = validate_input_response("", {"required": True})
|
||||
assert is_valid is False
|
||||
assert "required" in error.lower()
|
||||
|
||||
def test_required_none(self):
|
||||
"""Test required validation with None."""
|
||||
is_valid, error = validate_input_response(None, {"required": True})
|
||||
assert is_valid is False
|
||||
assert "required" in error.lower()
|
||||
|
||||
def test_min_length_valid(self):
|
||||
"""Test minLength validation with valid value."""
|
||||
is_valid, error = validate_input_response("hello", {"minLength": 3})
|
||||
assert is_valid is True
|
||||
|
||||
def test_min_length_invalid(self):
|
||||
"""Test minLength validation with too short value."""
|
||||
is_valid, error = validate_input_response("hi", {"minLength": 3})
|
||||
assert is_valid is False
|
||||
assert "minimum length" in error.lower()
|
||||
|
||||
def test_max_length_valid(self):
|
||||
"""Test maxLength validation with valid value."""
|
||||
is_valid, error = validate_input_response("hello", {"maxLength": 10})
|
||||
assert is_valid is True
|
||||
|
||||
def test_max_length_invalid(self):
|
||||
"""Test maxLength validation with too long value."""
|
||||
is_valid, error = validate_input_response("hello world", {"maxLength": 5})
|
||||
assert is_valid is False
|
||||
assert "maximum length" in error.lower()
|
||||
|
||||
def test_min_value_valid(self):
|
||||
"""Test min validation for numbers."""
|
||||
is_valid, error = validate_input_response(10, {"min": 5})
|
||||
assert is_valid is True
|
||||
|
||||
def test_min_value_invalid(self):
|
||||
"""Test min validation with too small number."""
|
||||
is_valid, error = validate_input_response(3, {"min": 5})
|
||||
assert is_valid is False
|
||||
assert "minimum value" in error.lower()
|
||||
|
||||
def test_max_value_valid(self):
|
||||
"""Test max validation for numbers."""
|
||||
is_valid, error = validate_input_response(5, {"max": 10})
|
||||
assert is_valid is True
|
||||
|
||||
def test_max_value_invalid(self):
|
||||
"""Test max validation with too large number."""
|
||||
is_valid, error = validate_input_response(15, {"max": 10})
|
||||
assert is_valid is False
|
||||
assert "maximum value" in error.lower()
|
||||
|
||||
def test_pattern_valid(self):
|
||||
"""Test pattern validation with matching value."""
|
||||
is_valid, error = validate_input_response("test@example.com", {"pattern": r"^[\w\.-]+@[\w\.-]+\.\w+$"})
|
||||
assert is_valid is True
|
||||
|
||||
def test_pattern_invalid(self):
|
||||
"""Test pattern validation with non-matching value."""
|
||||
is_valid, error = validate_input_response("not-an-email", {"pattern": r"^[\w\.-]+@[\w\.-]+\.\w+$"})
|
||||
assert is_valid is False
|
||||
assert "pattern" in error.lower()
|
||||
@@ -740,6 +740,90 @@ class TestAgentExecutorsCoverage:
|
||||
name = executor._get_agent_name(state)
|
||||
assert name == "LegacyAgent"
|
||||
|
||||
async def test_agent_executor_get_agent_name_string_expression(self, mock_context, mock_state):
|
||||
"""Test agent name extraction from simple string expression."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent_framework_declarative._workflows._executors_agents import (
|
||||
InvokeAzureAgentExecutor,
|
||||
)
|
||||
|
||||
action_def = {
|
||||
"kind": "InvokeAzureAgent",
|
||||
"agent": "=Local.SelectedAgent",
|
||||
}
|
||||
executor = InvokeAzureAgentExecutor(action_def)
|
||||
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
state.initialize()
|
||||
|
||||
with patch.object(state, "eval_if_expression", return_value="DynamicAgent"):
|
||||
name = executor._get_agent_name(state)
|
||||
assert name == "DynamicAgent"
|
||||
|
||||
async def test_agent_executor_get_agent_name_dict_expression(self, mock_context, mock_state):
|
||||
"""Test agent name extraction from nested dict with expression."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent_framework_declarative._workflows._executors_agents import (
|
||||
InvokeAzureAgentExecutor,
|
||||
)
|
||||
|
||||
action_def = {
|
||||
"kind": "InvokeAzureAgent",
|
||||
"agent": {"name": "=Local.ManagerResult.next_speaker.answer"},
|
||||
}
|
||||
executor = InvokeAzureAgentExecutor(action_def)
|
||||
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
state.initialize()
|
||||
|
||||
with patch.object(state, "eval_if_expression", return_value="WeatherAgent"):
|
||||
name = executor._get_agent_name(state)
|
||||
assert name == "WeatherAgent"
|
||||
|
||||
async def test_agent_executor_get_agent_name_legacy_expression(self, mock_context, mock_state):
|
||||
"""Test agent name extraction from legacy agentName with expression."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent_framework_declarative._workflows._executors_agents import (
|
||||
InvokeAzureAgentExecutor,
|
||||
)
|
||||
|
||||
action_def = {
|
||||
"kind": "InvokeAzureAgent",
|
||||
"agentName": "=Local.NextAgent",
|
||||
}
|
||||
executor = InvokeAzureAgentExecutor(action_def)
|
||||
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
state.initialize()
|
||||
|
||||
with patch.object(state, "eval_if_expression", return_value="ResolvedAgent"):
|
||||
name = executor._get_agent_name(state)
|
||||
assert name == "ResolvedAgent"
|
||||
|
||||
async def test_agent_executor_get_agent_name_expression_returns_none(self, mock_context, mock_state):
|
||||
"""Test agent name returns None when expression evaluates to None."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent_framework_declarative._workflows._executors_agents import (
|
||||
InvokeAzureAgentExecutor,
|
||||
)
|
||||
|
||||
action_def = {
|
||||
"kind": "InvokeAzureAgent",
|
||||
"agent": {"name": "=Local.UndefinedVar"},
|
||||
}
|
||||
executor = InvokeAzureAgentExecutor(action_def)
|
||||
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
state.initialize()
|
||||
|
||||
with patch.object(state, "eval_if_expression", return_value=None):
|
||||
name = executor._get_agent_name(state)
|
||||
assert name is None
|
||||
|
||||
async def test_agent_executor_get_input_config_simple(self, mock_context, mock_state):
|
||||
"""Test input config parsing with simple non-dict input."""
|
||||
from agent_framework_declarative._workflows._executors_agents import (
|
||||
@@ -2337,6 +2421,89 @@ class TestBuilderEdgeWiring:
|
||||
exit_exec = graph_builder._get_branch_exit(None)
|
||||
assert exit_exec is None
|
||||
|
||||
def test_get_branch_exit_returns_none_for_goto_terminator(self):
|
||||
"""Test that _get_branch_exit returns None when branch ends with GotoAction.
|
||||
|
||||
GotoAction is a terminator that handles its own control flow (jumping to
|
||||
the target action). It should NOT be returned as a branch exit, because
|
||||
that would cause the parent ConditionGroup to wire it to the next
|
||||
sequential action, creating a dual-edge where both the goto target and
|
||||
the next action receive messages.
|
||||
"""
|
||||
from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder
|
||||
from agent_framework_declarative._workflows._executors_control_flow import JoinExecutor
|
||||
|
||||
yaml_def = {"name": "test_workflow", "actions": []}
|
||||
graph_builder = DeclarativeWorkflowBuilder(yaml_def)
|
||||
|
||||
# GotoAction executor is a JoinExecutor with a GotoAction action_def
|
||||
goto_executor = JoinExecutor(
|
||||
{"kind": "GotoAction", "id": "goto_summary", "actionId": "invoke_summary"},
|
||||
id="goto_summary",
|
||||
)
|
||||
|
||||
# Simulate a single-action branch chain
|
||||
goto_executor._chain_executors = [goto_executor] # type: ignore[attr-defined]
|
||||
|
||||
exit_exec = graph_builder._get_branch_exit(goto_executor)
|
||||
assert exit_exec is None
|
||||
|
||||
def test_get_branch_exit_returns_none_for_end_workflow_terminator(self):
|
||||
"""Test that _get_branch_exit returns None when branch ends with EndWorkflow."""
|
||||
from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder
|
||||
from agent_framework_declarative._workflows._executors_control_flow import JoinExecutor
|
||||
|
||||
yaml_def = {"name": "test_workflow", "actions": []}
|
||||
graph_builder = DeclarativeWorkflowBuilder(yaml_def)
|
||||
|
||||
end_executor = JoinExecutor(
|
||||
{"kind": "EndWorkflow", "id": "end"},
|
||||
id="end",
|
||||
)
|
||||
end_executor._chain_executors = [end_executor] # type: ignore[attr-defined]
|
||||
|
||||
exit_exec = graph_builder._get_branch_exit(end_executor)
|
||||
assert exit_exec is None
|
||||
|
||||
def test_get_branch_exit_returns_none_for_goto_in_chain(self):
|
||||
"""Test that _get_branch_exit returns None when chain ends with GotoAction.
|
||||
|
||||
Even when a branch has multiple actions before the GotoAction,
|
||||
the branch exit should be None because the last action is a terminator.
|
||||
"""
|
||||
from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder
|
||||
from agent_framework_declarative._workflows._executors_basic import SendActivityExecutor
|
||||
from agent_framework_declarative._workflows._executors_control_flow import JoinExecutor
|
||||
|
||||
yaml_def = {"name": "test_workflow", "actions": []}
|
||||
graph_builder = DeclarativeWorkflowBuilder(yaml_def)
|
||||
|
||||
# A branch with: SendActivity -> GotoAction
|
||||
activity = SendActivityExecutor({"kind": "SendActivity", "activity": {"text": "msg"}}, id="msg")
|
||||
goto = JoinExecutor(
|
||||
{"kind": "GotoAction", "id": "goto_target", "actionId": "some_target"},
|
||||
id="goto_target",
|
||||
)
|
||||
activity._chain_executors = [activity, goto] # type: ignore[attr-defined]
|
||||
|
||||
exit_exec = graph_builder._get_branch_exit(activity)
|
||||
assert exit_exec is None
|
||||
|
||||
def test_get_branch_exit_returns_executor_for_non_terminator(self):
|
||||
"""Test that _get_branch_exit still returns the exit for non-terminator branches."""
|
||||
from agent_framework_declarative._workflows._declarative_builder import DeclarativeWorkflowBuilder
|
||||
from agent_framework_declarative._workflows._executors_basic import SendActivityExecutor
|
||||
|
||||
yaml_def = {"name": "test_workflow", "actions": []}
|
||||
graph_builder = DeclarativeWorkflowBuilder(yaml_def)
|
||||
|
||||
exec1 = SendActivityExecutor({"kind": "SendActivity", "activity": {"text": "1"}}, id="e1")
|
||||
exec2 = SendActivityExecutor({"kind": "SendActivity", "activity": {"text": "2"}}, id="e2")
|
||||
exec1._chain_executors = [exec1, exec2] # type: ignore[attr-defined]
|
||||
|
||||
exit_exec = graph_builder._get_branch_exit(exec1)
|
||||
assert exit_exec == exec2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent executor external loop response handler tests
|
||||
@@ -2702,3 +2869,133 @@ class TestLongMessageTextHandling:
|
||||
|
||||
result = state.eval('=!IsBlank(Find("CONGRATULATIONS", Upper(MessageText(Local.Messages))))')
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestCreateConversationExecutor:
|
||||
"""Tests for CreateConversationExecutor."""
|
||||
|
||||
async def test_basic_creation(self, mock_context, mock_state):
|
||||
"""Test that a UUID is generated, stored at conversationId path, and conversation entry created."""
|
||||
from agent_framework_declarative._workflows._executors_basic import (
|
||||
CreateConversationExecutor,
|
||||
)
|
||||
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
state.initialize()
|
||||
|
||||
action_def = {
|
||||
"kind": "CreateConversation",
|
||||
"conversationId": "Local.myConvId",
|
||||
}
|
||||
executor = CreateConversationExecutor(action_def)
|
||||
await executor.handle_action(ActionTrigger(), mock_context)
|
||||
|
||||
# A UUID should be stored at the requested path
|
||||
conv_id = state.get("Local.myConvId")
|
||||
assert conv_id is not None
|
||||
assert isinstance(conv_id, str)
|
||||
assert len(conv_id) == 36 # UUID format
|
||||
|
||||
# Conversation entry should exist in System.conversations
|
||||
conversations = state.get("System.conversations")
|
||||
assert conversations is not None
|
||||
assert conv_id in conversations
|
||||
assert conversations[conv_id]["id"] == conv_id
|
||||
assert conversations[conv_id]["messages"] == []
|
||||
|
||||
async def test_no_conversation_id_param(self, mock_context, mock_state):
|
||||
"""Test that conversation is still created even without a conversationId param."""
|
||||
from agent_framework_declarative._workflows._executors_basic import (
|
||||
CreateConversationExecutor,
|
||||
)
|
||||
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
state.initialize()
|
||||
|
||||
action_def = {
|
||||
"kind": "CreateConversation",
|
||||
}
|
||||
executor = CreateConversationExecutor(action_def)
|
||||
await executor.handle_action(ActionTrigger(), mock_context)
|
||||
|
||||
# Conversation entry should still exist in System.conversations
|
||||
# (initialize() seeds one default conversation, plus the one just created)
|
||||
conversations = state.get("System.conversations")
|
||||
assert conversations is not None
|
||||
assert len(conversations) == 2
|
||||
|
||||
async def test_multiple_conversations(self, mock_context, mock_state):
|
||||
"""Test creating multiple conversations produces distinct IDs."""
|
||||
from agent_framework_declarative._workflows._executors_basic import (
|
||||
CreateConversationExecutor,
|
||||
)
|
||||
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
state.initialize()
|
||||
|
||||
action_def1 = {
|
||||
"kind": "CreateConversation",
|
||||
"conversationId": "Local.conv1",
|
||||
}
|
||||
action_def2 = {
|
||||
"kind": "CreateConversation",
|
||||
"conversationId": "Local.conv2",
|
||||
}
|
||||
|
||||
executor1 = CreateConversationExecutor(action_def1)
|
||||
await executor1.handle_action(ActionTrigger(), mock_context)
|
||||
|
||||
executor2 = CreateConversationExecutor(action_def2)
|
||||
await executor2.handle_action(ActionTrigger(), mock_context)
|
||||
|
||||
conv1 = state.get("Local.conv1")
|
||||
conv2 = state.get("Local.conv2")
|
||||
|
||||
assert conv1 != conv2
|
||||
|
||||
# initialize() seeds one default conversation, plus the two just created
|
||||
conversations = state.get("System.conversations")
|
||||
assert len(conversations) == 3
|
||||
assert conv1 in conversations
|
||||
assert conv2 in conversations
|
||||
|
||||
|
||||
class TestDeclarativeWorkflowStateConversationIdInit:
|
||||
"""Tests that DeclarativeWorkflowState.initialize() generates a real UUID for ConversationId."""
|
||||
|
||||
async def test_conversation_id_is_not_default(self, mock_state):
|
||||
"""System.ConversationId should be a UUID, not 'default'."""
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
state.initialize()
|
||||
|
||||
conv_id = state.get("System.ConversationId")
|
||||
assert conv_id is not None
|
||||
assert conv_id != "default"
|
||||
# Validate it looks like a UUID
|
||||
import uuid
|
||||
|
||||
uuid.UUID(conv_id) # Raises ValueError if not a valid UUID
|
||||
|
||||
async def test_conversations_dict_initialized(self, mock_state):
|
||||
"""System.conversations should contain an entry matching ConversationId."""
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
state.initialize()
|
||||
|
||||
conv_id = state.get("System.ConversationId")
|
||||
conversations = state.get("System.conversations")
|
||||
assert conversations is not None
|
||||
assert conv_id in conversations
|
||||
assert conversations[conv_id]["id"] == conv_id
|
||||
assert conversations[conv_id]["messages"] == []
|
||||
|
||||
async def test_each_initialize_generates_unique_id(self, mock_state):
|
||||
"""Each call to initialize() should produce a different ConversationId."""
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
|
||||
state.initialize()
|
||||
id1 = state.get("System.ConversationId")
|
||||
|
||||
state.initialize()
|
||||
id2 = state.get("System.ConversationId")
|
||||
|
||||
assert id1 != id2
|
||||
|
||||
@@ -231,49 +231,3 @@ actions:
|
||||
|
||||
# Should execute successfully with displayName metadata
|
||||
assert len(outputs) >= 1
|
||||
|
||||
def test_action_context_display_name_property(self):
|
||||
"""Test that ActionContext provides displayName property."""
|
||||
from agent_framework_declarative._workflows._handlers import ActionContext
|
||||
from agent_framework_declarative._workflows._state import WorkflowState
|
||||
|
||||
state = WorkflowState()
|
||||
ctx = ActionContext(
|
||||
state=state,
|
||||
action={
|
||||
"kind": "SetValue",
|
||||
"id": "test_action",
|
||||
"displayName": "Test Action Display Name",
|
||||
"path": "Local.value",
|
||||
"value": "test",
|
||||
},
|
||||
execute_actions=lambda a, s: None,
|
||||
agents={},
|
||||
bindings={},
|
||||
)
|
||||
|
||||
assert ctx.action_id == "test_action"
|
||||
assert ctx.display_name == "Test Action Display Name"
|
||||
assert ctx.action_kind == "SetValue"
|
||||
|
||||
def test_action_context_without_display_name(self):
|
||||
"""Test ActionContext when displayName is not provided."""
|
||||
from agent_framework_declarative._workflows._handlers import ActionContext
|
||||
from agent_framework_declarative._workflows._state import WorkflowState
|
||||
|
||||
state = WorkflowState()
|
||||
ctx = ActionContext(
|
||||
state=state,
|
||||
action={
|
||||
"kind": "SetValue",
|
||||
"path": "Local.value",
|
||||
"value": "test",
|
||||
},
|
||||
execute_actions=lambda a, s: None,
|
||||
agents={},
|
||||
bindings={},
|
||||
)
|
||||
|
||||
assert ctx.action_id is None
|
||||
assert ctx.display_name is None
|
||||
assert ctx.action_kind == "SetValue"
|
||||
|
||||
@@ -1,553 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Unit tests for action handlers."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
# Import handlers to register them
|
||||
from agent_framework_declarative._workflows import (
|
||||
_actions_basic, # noqa: F401
|
||||
_actions_control_flow, # noqa: F401
|
||||
_actions_error, # noqa: F401
|
||||
)
|
||||
from agent_framework_declarative._workflows._handlers import (
|
||||
ActionContext,
|
||||
CustomEvent,
|
||||
TextOutputEvent,
|
||||
WorkflowEvent,
|
||||
get_action_handler,
|
||||
list_action_handlers,
|
||||
)
|
||||
from agent_framework_declarative._workflows._state import WorkflowState
|
||||
|
||||
|
||||
def create_action_context(
|
||||
action: dict[str, Any],
|
||||
inputs: dict[str, Any] | None = None,
|
||||
agents: dict[str, Any] | None = None,
|
||||
bindings: dict[str, Any] | None = None,
|
||||
run_kwargs: dict[str, Any] | None = None,
|
||||
) -> ActionContext:
|
||||
"""Helper to create an ActionContext for testing."""
|
||||
state = WorkflowState(inputs=inputs or {})
|
||||
|
||||
async def execute_actions(
|
||||
actions: list[dict[str, Any]], state: WorkflowState
|
||||
) -> AsyncGenerator[WorkflowEvent, None]:
|
||||
"""Mock execute_actions that runs handlers for nested actions."""
|
||||
for nested_action in actions:
|
||||
action_kind = nested_action.get("kind")
|
||||
handler = get_action_handler(action_kind)
|
||||
if handler:
|
||||
ctx = ActionContext(
|
||||
state=state,
|
||||
action=nested_action,
|
||||
execute_actions=execute_actions,
|
||||
agents=agents or {},
|
||||
bindings=bindings or {},
|
||||
run_kwargs=run_kwargs or {},
|
||||
)
|
||||
async for event in handler(ctx):
|
||||
yield event
|
||||
|
||||
return ActionContext(
|
||||
state=state,
|
||||
action=action,
|
||||
execute_actions=execute_actions,
|
||||
agents=agents or {},
|
||||
bindings=bindings or {},
|
||||
run_kwargs=run_kwargs or {},
|
||||
)
|
||||
|
||||
|
||||
class TestActionHandlerRegistry:
|
||||
"""Tests for action handler registration."""
|
||||
|
||||
def test_basic_handlers_registered(self):
|
||||
"""Test that basic handlers are registered."""
|
||||
handlers = list_action_handlers()
|
||||
assert "SetValue" in handlers
|
||||
assert "AppendValue" in handlers
|
||||
assert "SendActivity" in handlers
|
||||
assert "EmitEvent" in handlers
|
||||
|
||||
def test_control_flow_handlers_registered(self):
|
||||
"""Test that control flow handlers are registered."""
|
||||
handlers = list_action_handlers()
|
||||
assert "Foreach" in handlers
|
||||
assert "If" in handlers
|
||||
assert "Switch" in handlers
|
||||
assert "RepeatUntil" in handlers
|
||||
assert "BreakLoop" in handlers
|
||||
assert "ContinueLoop" in handlers
|
||||
|
||||
def test_error_handlers_registered(self):
|
||||
"""Test that error handlers are registered."""
|
||||
handlers = list_action_handlers()
|
||||
assert "ThrowException" in handlers
|
||||
assert "TryCatch" in handlers
|
||||
|
||||
def test_get_unknown_handler_returns_none(self):
|
||||
"""Test that getting an unknown handler returns None."""
|
||||
assert get_action_handler("UnknownAction") is None
|
||||
|
||||
|
||||
class TestSetValueHandler:
|
||||
"""Tests for SetValue action handler."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_simple_value(self):
|
||||
"""Test setting a simple value."""
|
||||
ctx = create_action_context({
|
||||
"kind": "SetValue",
|
||||
"path": "Local.result",
|
||||
"value": "test value",
|
||||
})
|
||||
|
||||
handler = get_action_handler("SetValue")
|
||||
events = [e async for e in handler(ctx)]
|
||||
|
||||
assert len(events) == 0 # SetValue doesn't emit events
|
||||
assert ctx.state.get("Local.result") == "test value"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_value_from_input(self):
|
||||
"""Test setting a value from workflow inputs."""
|
||||
ctx = create_action_context(
|
||||
{
|
||||
"kind": "SetValue",
|
||||
"path": "Local.copy",
|
||||
"value": "literal",
|
||||
},
|
||||
inputs={"original": "from input"},
|
||||
)
|
||||
|
||||
handler = get_action_handler("SetValue")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
assert ctx.state.get("Local.copy") == "literal"
|
||||
|
||||
|
||||
class TestAppendValueHandler:
|
||||
"""Tests for AppendValue action handler."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_to_new_list(self):
|
||||
"""Test appending to a non-existent list creates it."""
|
||||
ctx = create_action_context({
|
||||
"kind": "AppendValue",
|
||||
"path": "Local.results",
|
||||
"value": "item1",
|
||||
})
|
||||
|
||||
handler = get_action_handler("AppendValue")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
assert ctx.state.get("Local.results") == ["item1"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_to_existing_list(self):
|
||||
"""Test appending to an existing list."""
|
||||
ctx = create_action_context({
|
||||
"kind": "AppendValue",
|
||||
"path": "Local.results",
|
||||
"value": "item2",
|
||||
})
|
||||
ctx.state.set("Local.results", ["item1"])
|
||||
|
||||
handler = get_action_handler("AppendValue")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
assert ctx.state.get("Local.results") == ["item1", "item2"]
|
||||
|
||||
|
||||
class TestSendActivityHandler:
|
||||
"""Tests for SendActivity action handler."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_text_activity(self):
|
||||
"""Test sending a text activity."""
|
||||
ctx = create_action_context({
|
||||
"kind": "SendActivity",
|
||||
"activity": {
|
||||
"text": "Hello, world!",
|
||||
},
|
||||
})
|
||||
|
||||
handler = get_action_handler("SendActivity")
|
||||
events = [e async for e in handler(ctx)]
|
||||
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], TextOutputEvent)
|
||||
assert events[0].text == "Hello, world!"
|
||||
|
||||
|
||||
class TestEmitEventHandler:
|
||||
"""Tests for EmitEvent action handler."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_custom_event(self):
|
||||
"""Test emitting a custom event."""
|
||||
ctx = create_action_context({
|
||||
"kind": "EmitEvent",
|
||||
"event": {
|
||||
"name": "myEvent",
|
||||
"data": {"key": "value"},
|
||||
},
|
||||
})
|
||||
|
||||
handler = get_action_handler("EmitEvent")
|
||||
events = [e async for e in handler(ctx)]
|
||||
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], CustomEvent)
|
||||
assert events[0].name == "myEvent"
|
||||
assert events[0].data == {"key": "value"}
|
||||
|
||||
|
||||
class TestForeachHandler:
|
||||
"""Tests for Foreach action handler."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_foreach_basic_iteration(self):
|
||||
"""Test basic foreach iteration."""
|
||||
ctx = create_action_context({
|
||||
"kind": "Foreach",
|
||||
"source": ["a", "b", "c"],
|
||||
"itemName": "letter",
|
||||
"actions": [
|
||||
{
|
||||
"kind": "AppendValue",
|
||||
"path": "Local.results",
|
||||
"value": "processed",
|
||||
}
|
||||
],
|
||||
})
|
||||
|
||||
handler = get_action_handler("Foreach")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
assert ctx.state.get("Local.results") == ["processed", "processed", "processed"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_foreach_sets_item_and_index(self):
|
||||
"""Test that foreach sets item and index variables."""
|
||||
ctx = create_action_context({
|
||||
"kind": "Foreach",
|
||||
"source": ["x", "y"],
|
||||
"itemName": "item",
|
||||
"indexName": "idx",
|
||||
"actions": [],
|
||||
})
|
||||
|
||||
# We'll check the last values after iteration
|
||||
handler = get_action_handler("Foreach")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
# After iteration, the last item/index should be set
|
||||
assert ctx.state.get("Local.item") == "y"
|
||||
assert ctx.state.get("Local.idx") == 1
|
||||
|
||||
|
||||
class TestIfHandler:
|
||||
"""Tests for If action handler."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_if_true_branch(self):
|
||||
"""Test that the 'then' branch executes when condition is true."""
|
||||
ctx = create_action_context({
|
||||
"kind": "If",
|
||||
"condition": True,
|
||||
"then": [
|
||||
{"kind": "SetValue", "path": "Local.branch", "value": "then"},
|
||||
],
|
||||
"else": [
|
||||
{"kind": "SetValue", "path": "Local.branch", "value": "else"},
|
||||
],
|
||||
})
|
||||
|
||||
handler = get_action_handler("If")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
assert ctx.state.get("Local.branch") == "then"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_if_false_branch(self):
|
||||
"""Test that the 'else' branch executes when condition is false."""
|
||||
ctx = create_action_context({
|
||||
"kind": "If",
|
||||
"condition": False,
|
||||
"then": [
|
||||
{"kind": "SetValue", "path": "Local.branch", "value": "then"},
|
||||
],
|
||||
"else": [
|
||||
{"kind": "SetValue", "path": "Local.branch", "value": "else"},
|
||||
],
|
||||
})
|
||||
|
||||
handler = get_action_handler("If")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
assert ctx.state.get("Local.branch") == "else"
|
||||
|
||||
|
||||
class TestSwitchHandler:
|
||||
"""Tests for Switch action handler."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_switch_matching_case(self):
|
||||
"""Test switch with a matching case."""
|
||||
ctx = create_action_context({
|
||||
"kind": "Switch",
|
||||
"value": "option2",
|
||||
"cases": [
|
||||
{
|
||||
"match": "option1",
|
||||
"actions": [{"kind": "SetValue", "path": "Local.result", "value": "one"}],
|
||||
},
|
||||
{
|
||||
"match": "option2",
|
||||
"actions": [{"kind": "SetValue", "path": "Local.result", "value": "two"}],
|
||||
},
|
||||
],
|
||||
"default": [{"kind": "SetValue", "path": "Local.result", "value": "default"}],
|
||||
})
|
||||
|
||||
handler = get_action_handler("Switch")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
assert ctx.state.get("Local.result") == "two"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_switch_default_case(self):
|
||||
"""Test switch falls through to default."""
|
||||
ctx = create_action_context({
|
||||
"kind": "Switch",
|
||||
"value": "unknown",
|
||||
"cases": [
|
||||
{
|
||||
"match": "option1",
|
||||
"actions": [{"kind": "SetValue", "path": "Local.result", "value": "one"}],
|
||||
},
|
||||
],
|
||||
"default": [{"kind": "SetValue", "path": "Local.result", "value": "default"}],
|
||||
})
|
||||
|
||||
handler = get_action_handler("Switch")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
assert ctx.state.get("Local.result") == "default"
|
||||
|
||||
|
||||
class TestRepeatUntilHandler:
|
||||
"""Tests for RepeatUntil action handler."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_repeat_until_condition_met(self):
|
||||
"""Test repeat until condition becomes true."""
|
||||
ctx = create_action_context({
|
||||
"kind": "RepeatUntil",
|
||||
"condition": False, # Will be evaluated each iteration
|
||||
"maxIterations": 3,
|
||||
"actions": [
|
||||
{"kind": "SetValue", "path": "Local.count", "value": 1},
|
||||
],
|
||||
})
|
||||
# Set up a counter that will cause the loop to exit
|
||||
ctx.state.set("Local.count", 0)
|
||||
|
||||
handler = get_action_handler("RepeatUntil")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
# With condition=False (literal), it will run maxIterations times
|
||||
assert ctx.state.get("Local.iteration") == 3
|
||||
|
||||
|
||||
class TestTryCatchHandler:
|
||||
"""Tests for TryCatch action handler."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_try_without_error(self):
|
||||
"""Test try block without errors."""
|
||||
ctx = create_action_context({
|
||||
"kind": "TryCatch",
|
||||
"try": [
|
||||
{"kind": "SetValue", "path": "Local.result", "value": "success"},
|
||||
],
|
||||
"catch": [
|
||||
{"kind": "SetValue", "path": "Local.result", "value": "caught"},
|
||||
],
|
||||
})
|
||||
|
||||
handler = get_action_handler("TryCatch")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
assert ctx.state.get("Local.result") == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_try_with_throw_exception(self):
|
||||
"""Test catching a thrown exception."""
|
||||
ctx = create_action_context({
|
||||
"kind": "TryCatch",
|
||||
"try": [
|
||||
{"kind": "ThrowException", "message": "Test error", "code": "ERR001"},
|
||||
],
|
||||
"catch": [
|
||||
{"kind": "SetValue", "path": "Local.result", "value": "caught"},
|
||||
],
|
||||
})
|
||||
|
||||
handler = get_action_handler("TryCatch")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
assert ctx.state.get("Local.result") == "caught"
|
||||
assert ctx.state.get("Local.error.message") == "Test error"
|
||||
assert ctx.state.get("Local.error.code") == "ERR001"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finally_always_executes(self):
|
||||
"""Test that finally block always executes."""
|
||||
ctx = create_action_context({
|
||||
"kind": "TryCatch",
|
||||
"try": [
|
||||
{"kind": "SetValue", "path": "Local.try", "value": "ran"},
|
||||
],
|
||||
"finally": [
|
||||
{"kind": "SetValue", "path": "Local.finally", "value": "ran"},
|
||||
],
|
||||
})
|
||||
|
||||
handler = get_action_handler("TryCatch")
|
||||
_events = [e async for e in handler(ctx)] # noqa: F841
|
||||
|
||||
assert ctx.state.get("Local.try") == "ran"
|
||||
assert ctx.state.get("Local.finally") == "ran"
|
||||
|
||||
|
||||
class TestActionContextKwargs:
|
||||
"""ActionContext should carry and forward run_kwargs to agent invocations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_context_carries_run_kwargs(self):
|
||||
"""ActionContext should store and expose run_kwargs."""
|
||||
ctx = create_action_context(
|
||||
{"kind": "SetValue", "path": "Local.x", "value": "1"},
|
||||
run_kwargs={"user_token": "test123"},
|
||||
)
|
||||
assert ctx.run_kwargs == {"user_token": "test123"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_context_defaults_to_empty_kwargs(self):
|
||||
"""ActionContext.run_kwargs should default to empty dict."""
|
||||
ctx = create_action_context(
|
||||
{"kind": "SetValue", "path": "Local.x", "value": "1"},
|
||||
)
|
||||
assert ctx.run_kwargs == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_agent_handler_forwards_kwargs(self):
|
||||
"""handle_invoke_azure_agent should forward ctx.run_kwargs to agent.run()."""
|
||||
import agent_framework_declarative._workflows._actions_agents # noqa: F401
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "response"
|
||||
mock_response.messages = []
|
||||
mock_response.tool_calls = []
|
||||
|
||||
async def non_streaming_run(*args, **kwargs):
|
||||
if kwargs.get("stream"):
|
||||
raise TypeError("no streaming")
|
||||
return mock_response
|
||||
|
||||
mock_agent = AsyncMock()
|
||||
mock_agent.run = AsyncMock(side_effect=non_streaming_run)
|
||||
|
||||
test_kwargs = {"user_token": "secret", "api_key": "key123"}
|
||||
|
||||
state = WorkflowState()
|
||||
state.add_conversation_message(MagicMock(role="user", text="hello"))
|
||||
|
||||
ctx = create_action_context(
|
||||
action={
|
||||
"kind": "InvokeAzureAgent",
|
||||
"agent": "my_agent",
|
||||
},
|
||||
agents={"my_agent": mock_agent},
|
||||
run_kwargs=test_kwargs,
|
||||
)
|
||||
|
||||
handler = get_action_handler("InvokeAzureAgent")
|
||||
_ = [e async for e in handler(ctx)]
|
||||
|
||||
assert mock_agent.run.call_count >= 1
|
||||
|
||||
# Find the non-streaming fallback call
|
||||
for call in mock_agent.run.call_args_list:
|
||||
call_kw = call.kwargs
|
||||
if not call_kw.get("stream"):
|
||||
assert call_kw.get("user_token") == "secret"
|
||||
assert call_kw.get("api_key") == "key123"
|
||||
assert call_kw.get("options") == {"additional_function_arguments": test_kwargs}
|
||||
break
|
||||
else:
|
||||
# All calls were streaming — check the streaming call
|
||||
call_kw = mock_agent.run.call_args_list[0].kwargs
|
||||
assert call_kw.get("user_token") == "secret"
|
||||
assert call_kw.get("api_key") == "key123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_agent_handler_merges_caller_options(self):
|
||||
"""Caller-provided options in run_kwargs should be merged, not cause TypeError."""
|
||||
import agent_framework_declarative._workflows._actions_agents # noqa: F401
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "response"
|
||||
mock_response.messages = []
|
||||
mock_response.tool_calls = []
|
||||
|
||||
async def non_streaming_run(*args, **kwargs):
|
||||
if kwargs.get("stream"):
|
||||
raise TypeError("no streaming")
|
||||
return mock_response
|
||||
|
||||
mock_agent = AsyncMock()
|
||||
mock_agent.run = AsyncMock(side_effect=non_streaming_run)
|
||||
|
||||
# Include 'options' in run_kwargs to test merge behavior
|
||||
test_kwargs = {"user_token": "secret", "options": {"temperature": 0.7}}
|
||||
|
||||
state = WorkflowState()
|
||||
state.add_conversation_message(MagicMock(role="user", text="hello"))
|
||||
|
||||
ctx = create_action_context(
|
||||
action={
|
||||
"kind": "InvokeAzureAgent",
|
||||
"agent": "my_agent",
|
||||
},
|
||||
agents={"my_agent": mock_agent},
|
||||
run_kwargs=test_kwargs,
|
||||
)
|
||||
|
||||
handler = get_action_handler("InvokeAzureAgent")
|
||||
_ = [e async for e in handler(ctx)]
|
||||
|
||||
assert mock_agent.run.call_count >= 1
|
||||
|
||||
# Find the non-streaming fallback call
|
||||
for call in mock_agent.run.call_args_list:
|
||||
call_kw = call.kwargs
|
||||
if not call_kw.get("stream"):
|
||||
# Caller options should be merged with additional_function_arguments
|
||||
assert call_kw["options"]["temperature"] == 0.7
|
||||
assert "additional_function_arguments" in call_kw["options"]
|
||||
# Direct kwargs should not include 'options' (no duplicate keyword)
|
||||
assert call_kw.get("user_token") == "secret"
|
||||
break
|
||||
else:
|
||||
call_kw = mock_agent.run.call_args_list[0].kwargs
|
||||
assert call_kw["options"]["temperature"] == 0.7
|
||||
assert "additional_function_arguments" in call_kw["options"]
|
||||
@@ -216,53 +216,22 @@ class TestHandlerCoverage:
|
||||
|
||||
return action_kinds
|
||||
|
||||
def test_handlers_exist_for_sample_actions(self, all_action_kinds):
|
||||
"""Test that handlers exist for all action kinds in samples."""
|
||||
from agent_framework_declarative._workflows._handlers import list_action_handlers
|
||||
def test_executors_exist_for_sample_actions(self, all_action_kinds):
|
||||
"""Test that executors exist for all action kinds used in samples."""
|
||||
from agent_framework_declarative._workflows._declarative_builder import ALL_ACTION_EXECUTORS
|
||||
|
||||
registered_handlers = set(list_action_handlers())
|
||||
registered_executors = set(ALL_ACTION_EXECUTORS.keys())
|
||||
|
||||
# Handlers we expect but may not be in samples
|
||||
expected_handlers = {
|
||||
"SetValue",
|
||||
"SetVariable",
|
||||
"SetTextVariable",
|
||||
"SetMultipleVariables",
|
||||
"ResetVariable",
|
||||
"ClearAllVariables",
|
||||
"AppendValue",
|
||||
"SendActivity",
|
||||
"EmitEvent",
|
||||
"Foreach",
|
||||
"If",
|
||||
"Switch",
|
||||
"ConditionGroup",
|
||||
"GotoAction",
|
||||
"BreakLoop",
|
||||
"ContinueLoop",
|
||||
"RepeatUntil",
|
||||
"TryCatch",
|
||||
"ThrowException",
|
||||
"EndWorkflow",
|
||||
"EndConversation",
|
||||
"InvokeAzureAgent",
|
||||
"InvokePromptAgent",
|
||||
"CreateConversation",
|
||||
"AddConversationMessage",
|
||||
"CopyConversationMessages",
|
||||
"RetrieveConversationMessages",
|
||||
"Question",
|
||||
"RequestExternalInput",
|
||||
"WaitForInput",
|
||||
# Kinds handled structurally by the builder (not registered as executors)
|
||||
structural_kinds = {
|
||||
"OnConversationStart", # Trigger kind, not an action
|
||||
"ConditionGroup", # Decomposed into evaluator/join nodes
|
||||
"GotoAction", # Resolved as graph edges, not executor nodes
|
||||
"Goto", # Alias for GotoAction
|
||||
}
|
||||
|
||||
# Check that sample action kinds have handlers
|
||||
missing_handlers = all_action_kinds - registered_handlers - {"OnConversationStart"} # Trigger kind, not action
|
||||
missing_executors = all_action_kinds - registered_executors - structural_kinds
|
||||
|
||||
if missing_handlers:
|
||||
# Informational, not a failure, as some actions may be future work
|
||||
pass
|
||||
|
||||
# Check that we have handlers for the expected core set
|
||||
core_handlers = registered_handlers & expected_handlers
|
||||
assert len(core_handlers) > 10, "Expected more core handlers to be registered"
|
||||
assert not missing_executors, (
|
||||
f"Missing executors for action kinds used in workflow samples: {sorted(missing_executors)}"
|
||||
)
|
||||
|
||||
@@ -223,3 +223,33 @@ class TestWorkflowStateResetTurn:
|
||||
|
||||
assert state.get("Workflow.Inputs.query") == "test"
|
||||
assert state.get("Workflow.Outputs.result") == "done"
|
||||
|
||||
|
||||
class TestWorkflowStateConversationIdInit:
|
||||
"""Tests that WorkflowState generates a real UUID for System.ConversationId."""
|
||||
|
||||
def test_conversation_id_is_not_default(self):
|
||||
"""System.ConversationId should be a UUID, not 'default'."""
|
||||
import uuid
|
||||
|
||||
state = WorkflowState()
|
||||
conv_id = state.get("System.ConversationId")
|
||||
assert conv_id is not None
|
||||
assert conv_id != "default"
|
||||
uuid.UUID(conv_id) # Raises ValueError if not a valid UUID
|
||||
|
||||
def test_conversations_dict_initialized(self):
|
||||
"""System.conversations should contain an entry matching ConversationId."""
|
||||
state = WorkflowState()
|
||||
conv_id = state.get("System.ConversationId")
|
||||
conversations = state.get("System.conversations")
|
||||
assert conversations is not None
|
||||
assert conv_id in conversations
|
||||
assert conversations[conv_id]["id"] == conv_id
|
||||
assert conversations[conv_id]["messages"] == []
|
||||
|
||||
def test_each_instance_generates_unique_id(self):
|
||||
"""Each WorkflowState instance should have a different ConversationId."""
|
||||
state1 = WorkflowState()
|
||||
state2 = WorkflowState()
|
||||
assert state1.get("System.ConversationId") != state2.get("System.ConversationId")
|
||||
|
||||
Reference in New Issue
Block a user