From ed86baa6cb312e07db552a5a5a4abc2345fcf1d6 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Fri, 15 Aug 2025 11:11:35 -0700 Subject: [PATCH] Python: Workflow Edge Groups (#393) * Introducing edge groups * Add conditional and partitioning edge groups; next add samples and tests * Add unit tests * Add samples * Address comments 1 * Address comments 2 * Update conditional edge group to take in cases and default * Minor updates to sample * Collapsing Paritioning Edge group and Conditional Edge group to source edge group * Improve sample clarity * Name consolidation --------- Co-authored-by: Eric Zhu --- .../main/agent_framework/workflow/__init__.py | 2 + .../agent_framework/workflow/__init__.pyi | 4 + .../agent_framework_workflow/__init__.py | 3 + .../agent_framework_workflow/_edge.py | 409 +++++++-- .../agent_framework_workflow/_executor.py | 2 +- .../agent_framework_workflow/_runner.py | 68 +- .../agent_framework_workflow/_typing_utils.py | 4 + .../agent_framework_workflow/_validation.py | 143 ++-- .../agent_framework_workflow/_workflow.py | 121 ++- python/packages/workflow/tests/test_edge.py | 788 +++++++++++++++++- python/packages/workflow/tests/test_runner.py | 22 +- .../workflow/tests/test_validation.py | 15 +- .../workflow/tests/test_workflow_builder.py | 2 +- .../workflow/step_00_foundation_patterns.py | 291 +++++++ .../step_02_simple_workflow_condition.py | 23 +- .../workflow/step_04_simple_group_chat.py | 3 +- .../step_05_simple_group_chat_with_hil.py | 3 +- .../workflow/step_06_map_reduce.py | 1 - 18 files changed, 1672 insertions(+), 232 deletions(-) create mode 100644 python/samples/getting_started/workflow/step_00_foundation_patterns.py diff --git a/python/packages/main/agent_framework/workflow/__init__.py b/python/packages/main/agent_framework/workflow/__init__.py index bbfd05c2a9..971b7fb004 100644 --- a/python/packages/main/agent_framework/workflow/__init__.py +++ b/python/packages/main/agent_framework/workflow/__init__.py @@ -32,6 +32,8 @@ _IMPORTS = [ "InMemoryCheckpointStorage", "CheckpointStorage", "WorkflowCheckpoint", + "Case", + "Default", ] diff --git a/python/packages/main/agent_framework/workflow/__init__.pyi b/python/packages/main/agent_framework/workflow/__init__.pyi index 6506d4a936..e36c6ee461 100644 --- a/python/packages/main/agent_framework/workflow/__init__.pyi +++ b/python/packages/main/agent_framework/workflow/__init__.pyi @@ -6,7 +6,9 @@ from agent_framework_workflow import ( AgentExecutorResponse, AgentRunEvent, AgentRunStreamingEvent, + Case, CheckpointStorage, + Default, Executor, ExecutorCompletedEvent, ExecutorEvent, @@ -34,7 +36,9 @@ __all__ = [ "AgentExecutorResponse", "AgentRunEvent", "AgentRunStreamingEvent", + "Case", "CheckpointStorage", + "Default", "Executor", "ExecutorCompletedEvent", "ExecutorEvent", diff --git a/python/packages/workflow/agent_framework_workflow/__init__.py b/python/packages/workflow/agent_framework_workflow/__init__.py index 1d59918525..8196d5f767 100644 --- a/python/packages/workflow/agent_framework_workflow/__init__.py +++ b/python/packages/workflow/agent_framework_workflow/__init__.py @@ -11,6 +11,7 @@ from ._checkpoint import ( from ._const import ( DEFAULT_MAX_ITERATIONS, ) +from ._edge import Case, Default from ._events import ( AgentRunEvent, AgentRunStreamingEvent, @@ -60,7 +61,9 @@ __all__ = [ "AgentExecutorResponse", "AgentRunEvent", "AgentRunStreamingEvent", + "Case", "CheckpointStorage", + "Default", "EdgeDuplicationError", "Executor", "ExecutorCompletedEvent", diff --git a/python/packages/workflow/agent_framework_workflow/_edge.py b/python/packages/workflow/agent_framework_workflow/_edge.py index 0d7ca8bb18..fef5e3376d 100644 --- a/python/packages/workflow/agent_framework_workflow/_edge.py +++ b/python/packages/workflow/agent_framework_workflow/_edge.py @@ -1,7 +1,11 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio -from collections.abc import Callable +import logging +import uuid +from collections import defaultdict +from collections.abc import Callable, Sequence +from dataclasses import dataclass from typing import Any, ClassVar from ._executor import Executor @@ -9,6 +13,8 @@ from ._runner_context import Message, RunnerContext from ._shared_state import SharedState from ._workflow_context import WorkflowContext +logger = logging.getLogger(__name__) + class Edge: """Represents a directed edge in a graph.""" @@ -34,10 +40,6 @@ class Edge: self.target = target self._condition = condition - # Edge group is used to group edges that share the same target executor. - # It allows for sending messages to the target executor only when all edges in the group have data. - self._edge_group_ids: list[str] = [] - @property def source_id(self) -> str: """Get the source executor ID.""" @@ -53,27 +55,6 @@ class Edge: """Get the unique ID of the edge.""" return f"{self.source_id}{self.ID_SEPARATOR}{self.target_id}" - def has_edge_group(self) -> bool: - """Check if the edge is part of an edge group.""" - return bool(self._edge_group_ids) - - @classmethod - def source_and_target_from_id(cls, edge_id: str) -> tuple[str, str]: - """Extract the source and target IDs from the edge ID. - - Args: - edge_id (str): The edge ID in the format "source_id->target_id". - - Returns: - tuple[str, str]: A tuple containing the source ID and target ID. - """ - if cls.ID_SEPARATOR not in edge_id: - raise ValueError(f"Invalid edge ID format: {edge_id}") - ids = edge_id.split(cls.ID_SEPARATOR) - if len(ids) != 2: - raise ValueError(f"Invalid edge ID format: {edge_id}") - return ids[0], ids[1] - def can_handle(self, message_data: Any) -> bool: """Check if the edge can handle the given data. @@ -83,11 +64,14 @@ class Edge: Returns: bool: True if the edge can handle the data, False otherwise. """ - if not self._edge_group_ids: - return self.target.can_handle(message_data) + return self.target.can_handle(message_data) - # If the edge is part of an edge group, the target should expect a list of the data type. - return self.target.can_handle([message_data]) + def should_route(self, data: Any) -> bool: + """Determine if message should be routed through this edge based on the condition.""" + if self._condition is None: + return True + + return self._condition(data) async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> None: """Send a message along this edge. @@ -98,57 +82,338 @@ class Edge: ctx (RunnerContext): The context for the runner. """ if not self.can_handle(message.data): + # Caller of this method should ensure that the edge can handle the data. raise RuntimeError(f"Edge {self.id} cannot handle data of type {type(message.data)}.") - if not self._edge_group_ids and self._should_route(message.data): + if self.should_route(message.data): await self.target.execute( message.data, WorkflowContext(self.target.id, [self.source.id], shared_state, ctx) ) - elif self._edge_group_ids: - # Logic: - # 1. If not all edges in the edge group have data in the shared state, - # add the data to the shared state. - # 2. If all edges in the edge group have data in the shared state, - # copy the data to a list and send it to the target executor. - message_list: list[Message] = [] - async with shared_state.hold() as held_shared_state: - has_data = await asyncio.gather( - *(held_shared_state.has_within_hold(edge_id) for edge_id in self._edge_group_ids) - ) - if not all(has_data): - await held_shared_state.set_within_hold(self.id, message) - else: - message_list = [ - await held_shared_state.get_within_hold(edge_id) for edge_id in self._edge_group_ids - ] + [message] - # Remove the data from the shared state after retrieving it - await asyncio.gather( - *(held_shared_state.delete_within_hold(edge_id) for edge_id in self._edge_group_ids) - ) - if message_list: - data_list = [msg.data for msg in message_list] - source_ids = [msg.source_id for msg in message_list] - await self.target.execute(data_list, WorkflowContext(self.target.id, source_ids, shared_state, ctx)) - def _should_route(self, data: Any) -> bool: - """Determine if message should be routed through this edge.""" - if self._condition is None: - return True +class EdgeGroup: + """Represents a group of edges that share some common properties and can be triggered together.""" - return self._condition(data) + def __init__(self) -> None: + """Initialize the edge group.""" + self._id = f"{self.__class__.__name__}/{uuid.uuid4()}" - def set_edge_group(self, edge_group_ids: list[str]) -> None: - """Set the edge group IDs for this edge. + async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: + """Send a message through the edge group. Args: - edge_group_ids (list[str]): A list of edge IDs that belong to the same edge group. + message (Message): The message to send. + shared_state (SharedState): The shared state to use for holding data. + ctx (RunnerContext): The context for the runner. + + Returns: + bool: True if the message was sent successfully, False if the target executor cannot handle the message. + If a message can be delivered but rejected due to a condition, it will still return True. + + Note: + Exception will not be raised if the target executor cannot handle the message. This is because + a source executor can be connected to multiple target executors, and not every target executor may + be able to handle all the messages sent by the source executor. """ - # Validate that the edges in the edge group contain the same target executor as this edge - # TODO(@taochen): An edge cannot be part of multiple edge groups. - # TODO(@taochen): Can an edge have both a condition and an edge group? - if edge_group_ids: - for edge_id in edge_group_ids: - if Edge.source_and_target_from_id(edge_id)[1] != self.target.id: - raise ValueError("All edges in the group must have the same target executor.") - self._edge_group_ids = edge_group_ids + raise NotImplementedError + + @property + def id(self) -> str: + """Get the unique ID of the edge group.""" + return self._id + + @property + def source_executors(self) -> list[Executor]: + """Get the source executor IDs of the edges in the group.""" + raise NotImplementedError + + @property + def target_executors(self) -> list[Executor]: + """Get the target executor IDs of the edges in the group.""" + raise NotImplementedError + + @property + def edges(self) -> list[Edge]: + """Get the edges in the group.""" + raise NotImplementedError + + +class SingleEdgeGroup(EdgeGroup): + """Represents a single edge group that contains only one edge. + + A concrete implementation of EdgeGroup that represent a group containing exactly one edge. + """ + + def __init__(self, source: Executor, target: Executor, condition: Callable[[Any], bool] | None = None) -> None: + """Initialize the single edge group with an edge. + + Args: + source (Executor): The source executor. + target (Executor): The target executor that the source executor can send messages to. + condition (Callable[[Any], bool], optional): A condition function that determines + if the edge will pass the data to the target executor. If None, the edge can + will always pass the data to the target executor. + """ + self._edge = Edge(source=source, target=target, condition=condition) + + async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: + """Send a message through the single edge.""" + if message.target_id and message.target_id != self._edge.target_id: + return False + + if self._edge.can_handle(message.data): + await self._edge.send_message(message, shared_state, ctx) + return True + + return False + + @property + def source_executors(self) -> list[Executor]: + """Get the source executor of the edge.""" + return [self._edge.source] + + @property + def target_executors(self) -> list[Executor]: + """Get the target executor of the edge.""" + return [self._edge.target] + + @property + def edges(self) -> list[Edge]: + """Get the edges in the group.""" + return [self._edge] + + +class FanOutEdgeGroup(EdgeGroup): + """Represents a group of edges that share the same source executor. + + Assembles a Fan-out pattern where multiple edges share the same source executor + and send messages to their respective target executors. + """ + + def __init__( + self, + source: Executor, + targets: Sequence[Executor], + selection_func: Callable[[Any, list[str]], list[str]] | None = None, + ) -> None: + """Initialize the fan-out edge group with a list of edges. + + Args: + source (Executor): The source executor. + targets (Sequence[Executor]): A list of target executors that the source executor can send messages to. + selection_func (Callable[[Any, list[str]], list[str]], optional): A function that selects which target + executors to send messages to. The function takes in the message data and a list of target executor + IDs, and returns a list of selected target executor IDs. + """ + if len(targets) <= 1: + raise ValueError("FanOutEdgeGroup must contain at least two targets.") + self._edges = [Edge(source=source, target=target) for target in targets] + self._target_ids = [edge.target_id for edge in self._edges] + self._target_map = {edge.target_id: edge for edge in self._edges} + self._selection_func = selection_func + + async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: + """Send a message through all edges in the fan-out edge group.""" + selection_results = ( + self._selection_func(message.data, self._target_ids) if self._selection_func else self._target_ids + ) + if not self._validate_selection_result(selection_results): + raise RuntimeError( + f"Invalid selection result: {selection_results}. " + f"Expected selections to be a subset of valid target executor IDs: {self._target_ids}." + ) + + if message.target_id: + # If the target ID is specified and the selection result contains it, send the message to that edge + if message.target_id in selection_results: + edge = next((edge for edge in self._edges if edge.target_id == message.target_id), None) + if edge and edge.can_handle(message.data): + await edge.send_message(message, shared_state, ctx) + return True + return False + + # If no target ID, send the message to the selected targets + async def send_to_edge(edge: Edge) -> bool: + """Send the message to the edge at the specified index.""" + if edge.can_handle(message.data): + await edge.send_message(message, shared_state, ctx) + return True + return False + + tasks = [send_to_edge(self._target_map[target_id]) for target_id in selection_results] + results = await asyncio.gather(*tasks) + return any(results) + + @property + def source_executors(self) -> list[Executor]: + """Get the source executor of the edges in the group.""" + return [self._edges[0].source] + + @property + def target_executors(self) -> list[Executor]: + """Get the target executors of the edges in the group.""" + return [edge.target for edge in self._edges] + + @property + def edges(self) -> list[Edge]: + """Get the edges in the group.""" + return self._edges + + def _validate_selection_result(self, selection_results: list[str]) -> bool: + """Validate the selection results to ensure all IDs are valid target executor IDs.""" + return all(result in self._target_ids for result in selection_results) + + +class FanInEdgeGroup(EdgeGroup): + """Represents a group of edges that share the same target executor. + + Assembles a Fan-in pattern where multiple edges send messages to a single target executor. + Messages are buffered until all edges in the group have data to send. + """ + + def __init__(self, sources: Sequence[Executor], target: Executor) -> None: + """Initialize the fan-in edge group with a list of edges. + + Args: + sources (Sequence[Executor]): A list of source executors that can send messages to the target executor. + target (Executor): The target executor that receives a list of messages aggregated from all sources. + """ + if len(sources) <= 1: + raise ValueError("FanInEdgeGroup must contain at least two sources.") + self._edges = [Edge(source=source, target=target) for source in sources] + # Buffer to hold messages before sending them to the target executor + # Key is the source executor ID, value is a list of messages + self._buffer: dict[str, list[Message]] = defaultdict(list) + + async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: + """Send a message through all edges in the fan-in edge group.""" + if message.target_id and message.target_id != self._edges[0].target_id: + return False + + if self._edges[0].can_handle([message.data]): + # If the edge can handle the data, buffer the message + self._buffer[message.source_id].append(message) + else: + # If the edge cannot handle the data, return False + return False + + if self._is_ready_to_send(): + # If all edges in the group have data, send the buffered messages to the target executor + messages_to_send = [msg for edge in self._edges for msg in self._buffer[edge.source_id]] + self._buffer.clear() + # Only trigger one edge to send the messages to avoid duplicate sends + await self._edges[0].send_message( + Message([msg.data for msg in messages_to_send], self.__class__.__name__), + shared_state, + ctx, + ) + + return True + + def _is_ready_to_send(self) -> bool: + """Check if all edges in the group have data to send.""" + return all(self._buffer[edge.source_id] for edge in self._edges) + + @property + def source_executors(self) -> list[Executor]: + """Get the source executors of the edges in the group.""" + return [edge.source for edge in self._edges] + + @property + def target_executors(self) -> list[Executor]: + """Get the target executor of the edges in the group.""" + return [self._edges[0].target] + + @property + def edges(self) -> list[Edge]: + """Get the edges in the group.""" + return self._edges + + +@dataclass +class Case: + """Represents a single case in the conditional edge group. + + Args: + condition (Callable[[Any], bool]): The condition function for the case. + target (Executor): The target executor for the case. + """ + + condition: Callable[[Any], bool] + target: Executor + + +@dataclass +class Default: + """Represents the default case in the conditional edge group. + + Args: + target (Executor): The target executor for the default case. + """ + + target: Executor + + +class SwitchCaseEdgeGroup(FanOutEdgeGroup): + """Represents a group of edges that assemble a conditional routing pattern. + + This is similar to a switch-case construct: + switch(data): + case condition_1: + edge_1 + break + case condition_2: + edge_2 + break + default: + edge_3 + break + Or equivalently an if-elif-else construct: + if condition_1: + edge_1 + elif condition_2: + edge_2 + else: + edge_4 + """ + + def __init__( + self, + source: Executor, + cases: Sequence[Case | Default], + ) -> None: + """Initialize the conditional edge group with a list of edges. + + Args: + source (Executor): The source executor. + cases (Sequence[Case | Default]): A list of cases for the conditional edge group. + There should be exactly one default case. + """ + if len(cases) < 2: + raise ValueError("SwitchCaseEdgeGroup must contain at least two cases (including the default case).") + + default_case = [isinstance(case, Default) for case in cases] + if sum(default_case) != 1: + raise ValueError("SwitchCaseEdgeGroup must contain exactly one default case.") + + if isinstance(cases[-1], Default): + logger.warning( + "Default case in the conditional edge group is not the last case. " + "This will result in unexpected behavior." + ) + + def selection_func(data: Any, targets: list[str]) -> list[str]: + """Select the target executor based on the conditions.""" + for index, case in enumerate(cases): + if isinstance(case, Default): + return [case.target.id] + if isinstance(case, Case): + try: + if case.condition(data): + return [case.target.id] + except Exception as e: + logger.warning(f"Error occurred while evaluating condition for case {index}: {e}") + + raise RuntimeError("No matching case found in SwitchCaseEdgeGroup.") + + super().__init__(source, [case.target for case in cases], selection_func=selection_func) diff --git a/python/packages/workflow/agent_framework_workflow/_executor.py b/python/packages/workflow/agent_framework_workflow/_executor.py index ea1630eab4..b1ddad61b6 100644 --- a/python/packages/workflow/agent_framework_workflow/_executor.py +++ b/python/packages/workflow/agent_framework_workflow/_executor.py @@ -31,7 +31,7 @@ class Executor: Args: id: A unique identifier for the executor. If None, a new UUID will be generated. """ - self._id = id or str(uuid.uuid4()) + self._id = id or f"{self.__class__.__name__}/{uuid.uuid4()}" self._handlers: dict[type, Callable[[Any, WorkflowContext], Any]] = {} self._discover_handlers() diff --git a/python/packages/workflow/agent_framework_workflow/_runner.py b/python/packages/workflow/agent_framework_workflow/_runner.py index 1688a2f5fa..ac6ca41fb1 100644 --- a/python/packages/workflow/agent_framework_workflow/_runner.py +++ b/python/packages/workflow/agent_framework_workflow/_runner.py @@ -3,10 +3,10 @@ import asyncio import logging from collections import defaultdict -from collections.abc import AsyncIterable +from collections.abc import AsyncIterable, Sequence from typing import Any -from ._edge import Edge +from ._edge import EdgeGroup from ._events import WorkflowEvent from ._executor import Executor from ._runner_context import Message, RunnerContext @@ -20,7 +20,7 @@ class Runner: def __init__( self, - edges: list[Edge], + edge_groups: Sequence[EdgeGroup], shared_state: SharedState, ctx: RunnerContext, max_iterations: int = 100, @@ -29,13 +29,13 @@ class Runner: """Initialize the runner with edges, shared state, and context. Args: - edges: The edges of the workflow. + edge_groups: The edge groups of the workflow. shared_state: The shared state for the workflow. ctx: The runner context for the workflow. max_iterations: The maximum number of iterations to run. workflow_id: The workflow ID for checkpointing. """ - self._edge_map = self._parse_edges(edges) + self._edge_group_map = self._parse_edge_groups(edge_groups) self._ctx = ctx self._iteration = 0 self._max_iterations = max_iterations @@ -125,26 +125,25 @@ class Runner: async def _run_iteration(self): async def _deliver_messages(source_executor_id: str, messages: list[Message]) -> None: - async def _deliver_messages_inner( - edge: Edge, - messages: list[Message], - ) -> None: - for message in messages: - if message.target_id is not None and message.target_id != edge.target_id: - continue - if not edge.can_handle(message.data): - continue - await edge.send_message(message, self._shared_state, self._ctx) + """Outer loop to concurrently deliver messages from all sources to their targets.""" - associated_edges = self._edge_map.get(source_executor_id, []) - tasks = [asyncio.create_task(_deliver_messages_inner(edge, messages)) for edge in associated_edges] - await asyncio.gather(*tasks) + async def _deliver_message_inner(edge_group: EdgeGroup, message: Message) -> bool: + """Inner loop to deliver a single message through an edge group.""" + return await edge_group.send_message(message, self._shared_state, self._ctx) + + associated_edge_groups = self._edge_group_map.get(source_executor_id, []) + for message in messages: + # Deliver a message through all edge groups associated with the source executor concurrently. + tasks = [_deliver_message_inner(edge_group, message) for edge_group in associated_edge_groups] + results = await asyncio.gather(*tasks) + if not any(results): + logger.warning( + f"Message {message} could not be delivered. " + "This may be due to type incompatibility or no matching targets." + ) messages = await self._ctx.drain_messages() - tasks = [ - asyncio.create_task(_deliver_messages(source_executor_id, messages)) - for source_executor_id, messages in messages.items() - ] + tasks = [_deliver_messages(source_executor_id, messages) for source_executor_id, messages in messages.items()] await asyncio.gather(*tasks) async def _create_checkpoint_if_enabled(self, checkpoint_type: str) -> str | None: @@ -177,10 +176,11 @@ class Runner: Only JSON-serializable dicts should be provided by executors. """ executors: dict[str, Executor] = {} - for edge_list in self._edge_map.values(): - for edge in edge_list: - executors[edge.source.id] = edge.source - executors[edge.target.id] = edge.target + for edge_groups in self._edge_group_map.values(): + for edge_group in edge_groups: + for edge in edge_group.edges: + executors[edge.source.id] = edge.source + executors[edge.target.id] = edge.target for exec_id, executor in executors.items(): state_dict: dict[str, Any] | None = None snapshot = getattr(executor, "snapshot_state", None) @@ -268,16 +268,18 @@ class Runner: except Exception as e: logger.warning(f"Failed to restore shared state from context: {e}") - def _parse_edges(self, edges: list[Edge]) -> dict[str, list[Edge]]: - """Parse the edges of the workflow into a more convenient format. + def _parse_edge_groups(self, edge_groups: Sequence[EdgeGroup]) -> dict[str, list[EdgeGroup]]: + """Parse the edge groups of the workflow into a mapping where each source executor ID maps to its edge groups. Args: - edges: A list of edges in the workflow. + edge_groups: A list of edge groups in the workflow. Returns: - A dictionary mapping each source executor ID to a list of target executor IDs. + A dictionary mapping each source executor ID to a list of edge groups. """ - parsed: defaultdict[str, list[Edge]] = defaultdict(list) - for edge in edges: - parsed[edge.source_id].append(edge) + parsed: defaultdict[str, list[EdgeGroup]] = defaultdict(list) + for group in edge_groups: + for source_executor in group.source_executors: + parsed[source_executor.id].append(group) + return parsed diff --git a/python/packages/workflow/agent_framework_workflow/_typing_utils.py b/python/packages/workflow/agent_framework_workflow/_typing_utils.py index f8547d886e..46500bdf52 100644 --- a/python/packages/workflow/agent_framework_workflow/_typing_utils.py +++ b/python/packages/workflow/agent_framework_workflow/_typing_utils.py @@ -13,6 +13,10 @@ def is_instance_of(data: Any, target_type: type) -> bool: Returns: bool: True if data is an instance of target_type, False otherwise. """ + # Case 0: target_type is Any - always return True + if target_type is Any: + return True + origin = get_origin(target_type) args = get_args(target_type) diff --git a/python/packages/workflow/agent_framework_workflow/_validation.py b/python/packages/workflow/agent_framework_workflow/_validation.py index 0ce9cd2e76..3c4d8c12fd 100644 --- a/python/packages/workflow/agent_framework_workflow/_validation.py +++ b/python/packages/workflow/agent_framework_workflow/_validation.py @@ -3,10 +3,11 @@ import inspect import logging from collections import defaultdict +from collections.abc import Sequence from enum import Enum from typing import Any, Union, get_args, get_origin -from ._edge import Edge +from ._edge import Edge, EdgeGroup, FanInEdgeGroup from ._executor import Executor logger = logging.getLogger(__name__) @@ -92,18 +93,19 @@ class WorkflowGraphValidator: self._executors: dict[str, Executor] = {} # region Core Validation Methods - def validate_workflow(self, edges: list[Edge], start_executor: Executor | str) -> None: + def validate_workflow(self, edge_groups: Sequence[EdgeGroup], start_executor: Executor | str) -> None: """Validate the entire workflow graph. Args: - edges: list of edges in the workflow + edge_groups: list of edge groups in the workflow start_executor: The starting executor (can be instance or ID) Raises: WorkflowValidationError: If any validation fails """ - self._edges = edges - self._executors = self._build_executor_map(edges) + self._executors = self._build_executor_map(edge_groups) + self._edges = [edge for group in edge_groups for edge in group.edges] + self._edge_groups = edge_groups # Validate that start_executor exists in the graph # It should because we check for it in the WorkflowBuilder @@ -121,12 +123,13 @@ class WorkflowGraphValidator: self._validate_dead_ends() self._validate_cycles() - def _build_executor_map(self, edges: list[Edge]) -> dict[str, Executor]: + def _build_executor_map(self, edge_groups: Sequence[EdgeGroup]) -> dict[str, Executor]: """Build a map of executor IDs to executor instances.""" executors: dict[str, Executor] = {} - for edge in edges: - executors[edge.source_id] = edge.source - executors[edge.target_id] = edge.target + for group in edge_groups: + for executor in group.source_executors + group.target_executors: + executors[executor.id] = executor + return executors # endregion @@ -155,64 +158,80 @@ class WorkflowGraphValidator: Raises: TypeCompatibilityError: If type incompatibility is detected """ - for edge in self._edges: - source_executor = edge.source - target_executor = edge.target + for edge_group in self._edge_groups: + for edge in edge_group.edges: + self._validate_edge_type_compatibility(edge, edge_group) - # Get output types from source executor - source_output_types = self._get_executor_output_types(source_executor) + def _validate_edge_type_compatibility(self, edge: Edge, edge_group: EdgeGroup) -> None: + """Validate type compatibility for a specific edge. - # Get input types from target executor - target_input_types = self._get_executor_input_types(target_executor) + This checks that the output types of the source executor are compatible + with the input types expected by the target executor. - # If either executor has no type information, log warning and skip validation - # This allows for dynamic typing scenarios but warns about reduced validation coverage - if not source_output_types or not target_input_types: - if not source_output_types: - logger.warning( - f"Executor '{source_executor.id}' has no output type annotations. " - f"Type compatibility validation will be skipped for edges from this executor. " - f"Consider adding output_types to @handler decorators for better validation." - ) - if not target_input_types: - logger.warning( - f"Executor '{target_executor.id}' has no input type annotations. " - f"Type compatibility validation will be skipped for edges to this executor. " - f"Consider adding type annotations to message handler parameters for better validation." - ) - continue + Args: + edge: The edge to validate + edge_group: The edge group containing this edge - # Check if any source output type is compatible with any target input type - compatible = False - compatible_pairs: list[tuple[type[Any], type[Any]]] = [] + Raises: + TypeCompatibilityError: If type incompatibility is detected + """ + source_executor = edge.source + target_executor = edge.target - for source_type in source_output_types: - for target_type in target_input_types: - if edge.has_edge_group(): - # If the edge is part of an edge group, the target expects a list of data types - if self._is_type_compatible(list[source_type], target_type): - compatible = True - compatible_pairs.append((list[source_type], target_type)) - else: - if self._is_type_compatible(source_type, target_type): - compatible = True - compatible_pairs.append((source_type, target_type)) + # Get output types from source executor + source_output_types = self._get_executor_output_types(source_executor) - # Log successful type compatibility for debugging - if compatible: - logger.debug( - f"Type compatibility validated for edge '{source_executor.id}' -> '{target_executor.id}'. " - f"Compatible type pairs: {[(str(s), str(t)) for s, t in compatible_pairs]}" + # Get input types from target executor + target_input_types = self._get_executor_input_types(target_executor) + + # If either executor has no type information, log warning and skip validation + # This allows for dynamic typing scenarios but warns about reduced validation coverage + if not source_output_types or not target_input_types: + if not source_output_types: + logger.warning( + f"Executor '{source_executor.id}' has no output type annotations. " + f"Type compatibility validation will be skipped for edges from this executor. " + f"Consider adding output_types to @handler decorators for better validation." ) - - if not compatible: - # Enhanced error with more detailed information - raise TypeCompatibilityError( - source_executor.id, - target_executor.id, - source_output_types, - target_input_types, + if not target_input_types: + logger.warning( + f"Executor '{target_executor.id}' has no input type annotations. " + f"Type compatibility validation will be skipped for edges to this executor. " + f"Consider adding type annotations to message handler parameters for better validation." ) + return + + # Check if any source output type is compatible with any target input type + compatible = False + compatible_pairs: list[tuple[type[Any], type[Any]]] = [] + + for source_type in source_output_types: + for target_type in target_input_types: + if isinstance(edge_group, FanInEdgeGroup): + # If the edge is part of an edge group, the target expects a list of data types + if self._is_type_compatible(list[source_type], target_type): + compatible = True + compatible_pairs.append((list[source_type], target_type)) + else: + if self._is_type_compatible(source_type, target_type): + compatible = True + compatible_pairs.append((source_type, target_type)) + + # Log successful type compatibility for debugging + if compatible: + logger.debug( + f"Type compatibility validated for edge '{source_executor.id}' -> '{target_executor.id}'. " + f"Compatible type pairs: {[(str(s), str(t)) for s, t in compatible_pairs]}" + ) + + if not compatible: + # Enhanced error with more detailed information + raise TypeCompatibilityError( + source_executor.id, + target_executor.id, + source_output_types, + target_input_types, + ) def _get_executor_output_types(self, executor: Executor) -> list[type[Any]]: """Extract output types from an executor's message handlers. @@ -479,15 +498,15 @@ class WorkflowGraphValidator: # endregion -def validate_workflow_graph(edges: list[Edge], start_executor: Executor | str) -> None: +def validate_workflow_graph(edge_groups: Sequence[EdgeGroup], start_executor: Executor | str) -> None: """Convenience function to validate a workflow graph. Args: - edges: list of edges in the workflow + edge_groups: list of edge groups in the workflow start_executor: The starting executor (can be instance or ID) Raises: WorkflowValidationError: If any validation fails """ validator = WorkflowGraphValidator() - validator.validate_workflow(edges, start_executor) + validator.validate_workflow(edge_groups, start_executor) diff --git a/python/packages/workflow/agent_framework_workflow/_workflow.py b/python/packages/workflow/agent_framework_workflow/_workflow.py index bc3ae00f19..f2b80ce953 100644 --- a/python/packages/workflow/agent_framework_workflow/_workflow.py +++ b/python/packages/workflow/agent_framework_workflow/_workflow.py @@ -9,7 +9,15 @@ from typing import Any from ._checkpoint import CheckpointStorage from ._const import DEFAULT_MAX_ITERATIONS -from ._edge import Edge +from ._edge import ( + Case, + Default, + EdgeGroup, + FanInEdgeGroup, + FanOutEdgeGroup, + SingleEdgeGroup, + SwitchCaseEdgeGroup, +) from ._events import RequestInfoEvent, WorkflowCompletedEvent, WorkflowEvent from ._executor import Executor, RequestInfoExecutor from ._runner import Runner @@ -67,7 +75,7 @@ class Workflow: def __init__( self, - edges: list[Edge], + edge_groups: list[EdgeGroup], start_executor: Executor | str, runner_context: RunnerContext, max_iterations: int, @@ -75,28 +83,29 @@ class Workflow: """Initialize the workflow with a list of edges. Args: - edges: A list of directed edges representing the connections between nodes in the workflow. + edge_groups: A list of EdgeGroup instances that define the workflow edges. start_executor: The starting executor for the workflow, which can be an Executor instance or its ID. runner_context: The RunnerContext instance to be used during workflow execution. max_iterations: The maximum number of iterations the workflow will run for convergence. """ - self._edges = edges + self._edge_groups = edge_groups + self._executors = self._build_executor_map(edge_groups) self._start_executor = start_executor - self._executors = {edge.source_id: edge.source for edge in edges} | { - edge.target_id: edge.target for edge in edges - } self._shared_state = SharedState() - workflow_id = str(uuid.uuid4()) self._runner = Runner( - self._edges, self._shared_state, runner_context, max_iterations=max_iterations, workflow_id=workflow_id + self._edge_groups, + self._shared_state, + runner_context, + max_iterations=max_iterations, + workflow_id=workflow_id, ) @property - def edges(self) -> list[Edge]: - """Get the list of edges in the workflow.""" - return self._edges + def edge_groups(self) -> list[EdgeGroup]: + """Get the list of edge groups in the workflow.""" + return self._edge_groups @property def start_executor(self) -> Executor: @@ -298,6 +307,22 @@ class Workflow: raise ValueError(f"Executor with ID {executor_id} not found.") return self._executors[executor_id] + def _build_executor_map(self, edge_groups: list[EdgeGroup]) -> dict[str, Executor]: + """Build the executor map from edge groups. + + Args: + edge_groups: A list of EdgeGroup instances. + + Returns: + A dictionary mapping executor IDs to Executor instances. + """ + executors: dict[str, Executor] = {} + for group in edge_groups: + for executor in group.source_executors + group.target_executors: + executors[executor.id] = executor + + return executors + async def _restore_from_external_checkpoint( self, checkpoint_id: str, checkpoint_storage: CheckpointStorage ) -> bool: @@ -405,7 +430,7 @@ class WorkflowBuilder: def __init__(self, max_iterations: int = DEFAULT_MAX_ITERATIONS): """Initialize the WorkflowBuilder with an empty list of edges and no starting executor.""" - self._edges: list[Edge] = [] + self._edge_groups: list[EdgeGroup] = [] self._start_executor: Executor | str | None = None self._checkpoint_storage: CheckpointStorage | None = None self._max_iterations: int = max_iterations @@ -427,21 +452,66 @@ class WorkflowBuilder: should be traversed based on the message type. """ # TODO(@taochen): Support executor factories for lazy initialization - self._edges.append(Edge(source, target, condition)) + self._edge_groups.append(SingleEdgeGroup(source, target, condition)) return self def add_fan_out_edges(self, source: Executor, targets: Sequence[Executor]) -> "Self": - """Add multiple edges to the workflow. + """Add multiple edges to the workflow where messages from the source will be sent to all target. The output types of the source and the input types of the targets must be compatible. - Messages from the source executor will be sent to all target executors. Args: source: The source executor of the edges. targets: A list of target executors for the edges. """ - for target in targets: - self._edges.append(Edge(source, target)) + self._edge_groups.append(FanOutEdgeGroup(source, targets)) + + return self + + def add_switch_case_edge_group(self, source: Executor, cases: Sequence[Case | Default]) -> "Self": + """Add an edge group that represents a switch-case statement. + + The output types of the source and the input types of the targets must be compatible. + Messages from the source executor will be sent to one of the target executors based on + the provided conditions. + + Think of this as a switch statement where each target executor corresponds to a case. + Each condition function will be evaluated in order, and the first one that returns True + will determine which target executor receives the message. + + The last case (the default case) will receive messages that fall through all conditions + (i.e., no condition matched). + + Args: + source: The source executor of the edges. + cases: A list of case objects that determine the target executor for each message. + """ + self._edge_groups.append(SwitchCaseEdgeGroup(source, cases)) + + return self + + def add_multi_selection_edge_group( + self, + source: Executor, + targets: Sequence[Executor], + selection_func: Callable[[Any, list[str]], list[str]], + ) -> "Self": + """Add an edge group that represents a multi-selection execution model. + + The output types of the source and the input types of the targets must be compatible. + Messages from the source executor will be sent to multiple target executors based on + the provided selection function. + + The selection function should take a message and the name of the target executors, + and return a list of indices indicating which target executors should receive the message. + + Args: + source: The source executor of the edges. + targets: A list of target executors for the edges. + selection_func: A function that selects target executors for messages. + """ + self._edge_groups.append(FanOutEdgeGroup(source, targets, selection_func)) + return self def add_fan_in_edges(self, sources: Sequence[Executor], target: Executor) -> "Self": @@ -478,16 +548,7 @@ class WorkflowBuilder: sources: A list of source executors for the edges. target: The target executor for the edges. """ - edges = [Edge(source, target) for source in sources] - - # Set the edge groups for the edges to ensure they are processed together. - for i, edge in enumerate(edges): - group_ids: list[str] = [] - group_ids.extend([e.id for e in edges[0:i]]) - group_ids.extend([e.id for e in edges[i + 1 :]]) - edge.set_edge_group(group_ids) - - self._edges.extend(edges) + self._edge_groups.append(FanInEdgeGroup(sources, target)) return self @@ -549,11 +610,11 @@ class WorkflowBuilder: if not self._start_executor: raise ValueError("Starting executor must be set before building the workflow.") - validate_workflow_graph(self._edges, self._start_executor) + validate_workflow_graph(self._edge_groups, self._start_executor) context = InProcRunnerContext(self._checkpoint_storage) - return Workflow(self._edges, self._start_executor, context, self._max_iterations) + return Workflow(self._edge_groups, self._start_executor, context, self._max_iterations) # endregion diff --git a/python/packages/workflow/tests/test_edge.py b/python/packages/workflow/tests/test_edge.py index b1c41c4470..ddef2cd649 100644 --- a/python/packages/workflow/tests/test_edge.py +++ b/python/packages/workflow/tests/test_edge.py @@ -2,10 +2,20 @@ from dataclasses import dataclass from typing import Any +from unittest.mock import patch +import pytest from agent_framework.workflow import Executor, WorkflowContext, handler -from agent_framework_workflow._edge import Edge +from agent_framework_workflow._edge import ( + Case, + Default, + Edge, + FanInEdgeGroup, + FanOutEdgeGroup, + SingleEdgeGroup, + SwitchCaseEdgeGroup, +) @dataclass @@ -15,6 +25,13 @@ class MockMessage: data: Any +@dataclass +class MockMessageSecondary: + """A secondary mock message for testing purposes.""" + + data: Any + + class MockExecutor(Executor): """A mock executor for testing purposes.""" @@ -24,6 +41,27 @@ class MockExecutor(Executor): pass +class MockExecutorSecondary(Executor): + """A secondary mock executor for testing purposes.""" + + @handler + async def mock_handler_secondary(self, message: MockMessageSecondary, ctx: WorkflowContext) -> None: + """A secondary mock handler that does nothing.""" + pass + + +class MockAggregator(Executor): + """A mock aggregator for testing purposes.""" + + @handler + async def mock_aggregator_handler(self, message: list[MockMessage], ctx: WorkflowContext) -> None: + """A mock aggregator handler that does nothing.""" + pass + + +# region Edge + + def test_create_edge(): """Test creating an edge with a source and target executor.""" source = MockExecutor(id="source_executor") @@ -34,7 +72,6 @@ def test_create_edge(): assert edge.source_id == "source_executor" assert edge.target_id == "target_executor" assert edge.id == f"{edge.source_id}{Edge.ID_SEPARATOR}{edge.target_id}" - assert (edge.source_id, edge.target_id) == Edge.source_and_target_from_id(edge.id) def test_edge_can_handle(): @@ -45,3 +82,750 @@ def test_edge_can_handle(): edge = Edge(source=source, target=target) assert edge.can_handle(MockMessage(data="test")) + + +# endregion Edge + +# region SingleEdgeGroup + + +def test_single_edge_group(): + """Test creating a single edge group.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + edge_group = SingleEdgeGroup(source=source, target=target) + + assert edge_group.source_executors == [source] + assert edge_group.target_executors == [target] + assert edge_group.edges[0].source_id == "source_executor" + assert edge_group.edges[0].target_id == "target_executor" + + +def test_single_edge_group_with_condition(): + """Test creating a single edge group with a condition.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + edge_group = SingleEdgeGroup(source=source, target=target, condition=lambda x: x.data == "test") + + assert edge_group.source_executors == [source] + assert edge_group.target_executors == [target] + assert edge_group.edges[0].source_id == "source_executor" + assert edge_group.edges[0].target_id == "target_executor" + assert edge_group.edges[0]._condition is not None # type: ignore + + +async def test_single_edge_group_send_message(): + """Test sending a message through a single edge group.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + edge_group = SingleEdgeGroup(source=source, target=target) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is True + + +async def test_single_edge_group_send_message_with_target(): + """Test sending a message through a single edge group.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + edge_group = SingleEdgeGroup(source=source, target=target) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id, target_id=target.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is True + + +async def test_single_edge_group_send_message_with_invalid_target(): + """Test sending a message through a single edge group.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + edge_group = SingleEdgeGroup(source=source, target=target) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id, target_id="invalid_target") + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +async def test_single_edge_group_send_message_with_invalid_data(): + """Test sending a message through a single edge group.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + edge_group = SingleEdgeGroup(source=source, target=target) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = "invalid_data" + message = Message(data=data, source_id=source.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +# endregion SingleEdgeGroup + + +# region FanOutEdgeGroup + + +def test_source_edge_group(): + """Test creating a fan-out group.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2]) + + assert edge_group.source_executors == [source] + assert edge_group.target_executors == [target1, target2] + assert len(edge_group.edges) == 2 + assert edge_group.edges[0].source_id == "source_executor" + assert edge_group.edges[0].target_id == "target_executor_1" + assert edge_group.edges[1].source_id == "source_executor" + assert edge_group.edges[1].target_id == "target_executor_2" + + +def test_source_edge_group_invalid_number_of_targets(): + """Test creating a fan-out group with an invalid number of targets.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + with pytest.raises(ValueError, match="FanOutEdgeGroup must contain at least two targets"): + FanOutEdgeGroup(source=source, targets=[target]) + + +async def test_source_edge_group_send_message(): + """Test sending a message through a fan-out group.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2]) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id) + + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message(message, shared_state, ctx) + + assert success is True + assert mock_send.call_count == 2 + + +async def test_source_edge_group_send_message_with_target(): + """Test sending a message through a fan-out group with a target.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2]) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id, target_id=target1.id) + + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message(message, shared_state, ctx) + + assert success is True + assert mock_send.call_count == 1 + assert mock_send.call_args[0][0].target_id == target1.id + + +async def test_source_edge_group_send_message_with_invalid_target(): + """Test sending a message through a fan-out group with an invalid target.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2]) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id, target_id="invalid_target") + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +async def test_source_edge_group_send_message_with_invalid_data(): + """Test sending a message through a fan-out group with invalid data.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2]) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = "invalid_data" + message = Message(data=data, source_id=source.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +async def test_source_edge_group_send_message_only_one_successful_send(): + """Test sending a message through a fan-out group where only one edge can handle the message.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutorSecondary(id="target_executor_2") + + edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2]) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id) + + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message(message, shared_state, ctx) + + assert success is True + assert mock_send.call_count == 1 + + +def test_source_edge_group_with_selection_func(): + """Test creating a partitioning edge group.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup( + source=source, + targets=[target1, target2], + selection_func=lambda data, target_ids: [target1.id], + ) + + assert edge_group.source_executors == [source] + assert edge_group.target_executors == [target1, target2] + assert len(edge_group.edges) == 2 + assert edge_group.edges[0].source_id == "source_executor" + assert edge_group.edges[0].target_id == "target_executor_1" + assert edge_group.edges[1].source_id == "source_executor" + assert edge_group.edges[1].target_id == "target_executor_2" + + +async def test_source_edge_group_with_selection_func_send_message(): + """Test sending a message through a fan-out group with a selection function.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup( + source=source, + targets=[target1, target2], + selection_func=lambda data, target_ids: [target1.id, target2.id], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id) + + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message(message, shared_state, ctx) + + assert success is True + assert mock_send.call_count == 2 + + +async def test_source_edge_group_with_selection_func_send_message_with_invalid_selection_result(): + """Test sending a message through a fan-out group with a selection func with an invalid selection result.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup( + source=source, + targets=[target1, target2], + selection_func=lambda data, target_ids: [target1.id, "invalid_target"], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id) + + with pytest.raises(RuntimeError): + await edge_group.send_message(message, shared_state, ctx) + + +async def test_source_edge_group_with_selection_func_send_message_with_target(): + """Test sending a message through a fan-out group with a selection func with a target.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup( + source=source, + targets=[target1, target2], + selection_func=lambda data, target_ids: [target1.id, target2.id], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id, target_id=target1.id) + + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message(message, shared_state, ctx) + + assert success is True + assert mock_send.call_count == 1 + assert mock_send.call_args[0][0].target_id == target1.id + + +async def test_source_edge_group_with_selection_func_send_message_with_target_not_in_selection(): + """Test sending a message through a fan-out group with a selection func with a target not in the selection.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup( + source=source, + targets=[target1, target2], + selection_func=lambda data, target_ids: [target1.id], # Only target1 will receive the message + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id, target_id=target2.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +async def test_source_edge_group_with_selection_func_send_message_with_invalid_data(): + """Test sending a message through a fan-out group with a selection func with invalid data.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup( + source=source, targets=[target1, target2], selection_func=lambda data, target_ids: [target1.id, target2.id] + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = "invalid_data" + message = Message(data=data, source_id=source.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +async def test_source_edge_group_with_selection_func_send_message_with_target_invalid_data(): + """Test sending a message through a fan-out group with a selection func with a target and invalid data.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup( + source=source, targets=[target1, target2], selection_func=lambda data, target_ids: [target1.id, target2.id] + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = "invalid_data" + message = Message(data=data, source_id=source.id, target_id=target1.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +# endregion FanOutEdgeGroup + +# region FanInEdgeGroup + + +def test_target_edge_group(): + """Test creating a fan-in edge group.""" + source1 = MockExecutor(id="source_executor_1") + source2 = MockExecutor(id="source_executor_2") + target = MockAggregator(id="target_executor") + + edge_group = FanInEdgeGroup(sources=[source1, source2], target=target) + + assert edge_group.source_executors == [source1, source2] + assert edge_group.target_executors == [target] + assert len(edge_group.edges) == 2 + assert edge_group.edges[0].source_id == "source_executor_1" + assert edge_group.edges[0].target_id == "target_executor" + assert edge_group.edges[1].source_id == "source_executor_2" + assert edge_group.edges[1].target_id == "target_executor" + + +def test_target_edge_group_invalid_number_of_sources(): + """Test creating a fan-in edge group with an invalid number of sources.""" + source = MockExecutor(id="source_executor") + target = MockAggregator(id="target_executor") + + with pytest.raises(ValueError, match="FanInEdgeGroup must contain at least two sources"): + FanInEdgeGroup(sources=[source], target=target) + + +async def test_target_edge_group_send_message_buffer(): + """Test sending a message through a fan-in edge group with buffering.""" + source1 = MockExecutor(id="source_executor_1") + source2 = MockExecutor(id="source_executor_2") + target = MockAggregator(id="target_executor") + + edge_group = FanInEdgeGroup(sources=[source1, source2], target=target) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message( + Message(data=data, source_id=source1.id), + shared_state, + ctx, + ) + + assert success is True + assert mock_send.call_count == 0 # The message should be buffered and wait for the second source + assert len(edge_group._buffer[source1.id]) == 1 # type: ignore + + success = await edge_group.send_message( + Message(data=data, source_id=source2.id), + shared_state, + ctx, + ) + assert success is True + assert mock_send.call_count == 1 # The message should be sent now that both sources have sent their messages + + # Buffer should be cleared after sending + assert not edge_group._buffer # type: ignore + + +async def test_target_edge_group_send_message_with_invalid_target(): + """Test sending a message through a fan-in edge group with an invalid target.""" + source1 = MockExecutor(id="source_executor_1") + source2 = MockExecutor(id="source_executor_2") + target = MockAggregator(id="target_executor") + + edge_group = FanInEdgeGroup(sources=[source1, source2], target=target) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source1.id, target_id="invalid_target") + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +async def test_target_edge_group_send_message_with_invalid_data(): + """Test sending a message through a fan-in edge group with invalid data.""" + source1 = MockExecutor(id="source_executor_1") + source2 = MockExecutor(id="source_executor_2") + target = MockAggregator(id="target_executor") + + edge_group = FanInEdgeGroup(sources=[source1, source2], target=target) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = "invalid_data" + message = Message(data=data, source_id=source1.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +# endregion FanInEdgeGroup + +# region SwitchCaseEdgeGroup + + +def test_switch_case_edge_group(): + """Test creating a switch case edge group.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = SwitchCaseEdgeGroup( + source=source, + cases=[ + Case(condition=lambda x: x.data < 0, target=target1), + Default(target=target2), + ], + ) + + assert edge_group.source_executors == [source] + assert edge_group.target_executors == [target1, target2] + assert len(edge_group.edges) == 2 + assert edge_group.edges[0].source_id == "source_executor" + assert edge_group.edges[0].target_id == "target_executor_1" + assert edge_group.edges[1].source_id == "source_executor" + assert edge_group.edges[1].target_id == "target_executor_2" + + assert edge_group._selection_func is not None # type: ignore + assert edge_group._selection_func(MockMessage(data=-1), [target1.id, target2.id]) == [target1.id] # type: ignore + assert edge_group._selection_func(MockMessage(data=1), [target1.id, target2.id]) == [target2.id] # type: ignore + + +def test_switch_case_edge_group_invalid_number_of_cases(): + """Test creating a switch case edge group with an invalid number of cases.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + with pytest.raises( + ValueError, match=r"SwitchCaseEdgeGroup must contain at least two cases \(including the default case\)." + ): + SwitchCaseEdgeGroup( + source=source, + cases=[ + Case(condition=lambda x: x.data < 0, target=target), + ], + ) + + with pytest.raises(ValueError, match="SwitchCaseEdgeGroup must contain exactly one default case."): + SwitchCaseEdgeGroup( + source=source, + cases=[ + Case(condition=lambda x: x.data < 0, target=target), + Case(condition=lambda x: x.data >= 0, target=target), + ], + ) + + +def test_switch_case_edge_group_invalid_number_of_default_cases(): + """Test creating a switch case edge group with an invalid number of conditions.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + with pytest.raises(ValueError, match="SwitchCaseEdgeGroup must contain exactly one default case."): + SwitchCaseEdgeGroup( + source=source, + cases=[ + Case(condition=lambda x: x.data < 0, target=target1), + Default(target=target2), + Default(target=target2), + ], + ) + + +async def test_switch_case_edge_group_send_message(): + """Test sending a message through a switch case edge group.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = SwitchCaseEdgeGroup( + source=source, + cases=[ + Case(condition=lambda x: x.data < 0, target=target1), + Default(target=target2), + ], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data=-1) + message = Message(data=data, source_id=source.id) + + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message(message, shared_state, ctx) + + assert success is True + assert mock_send.call_count == 1 + + # Default condition should + data = MockMessage(data=1) + message = Message(data=data, source_id=source.id) + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message(message, shared_state, ctx) + + assert success is True + assert mock_send.call_count == 1 + + +async def test_switch_case_edge_group_send_message_with_invalid_target(): + """Test sending a message through a switch case edge group with an invalid target.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = SwitchCaseEdgeGroup( + source=source, + cases=[ + Case(condition=lambda x: x.data < 0, target=target1), + Default(target=target2), + ], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data=-1) + message = Message(data=data, source_id=source.id, target_id="invalid_target") + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +async def test_switch_case_edge_group_send_message_with_valid_target(): + """Test sending a message through a switch case edge group with a target.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = SwitchCaseEdgeGroup( + source=source, + cases=[ + Case(condition=lambda x: x.data < 0, target=target1), + Default(target=target2), + ], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data=1) # Condition will fail + message = Message(data=data, source_id=source.id, target_id=target1.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + data = MockMessage(data=-1) # Condition will pass + message = Message(data=data, source_id=source.id, target_id=target1.id) + success = await edge_group.send_message(message, shared_state, ctx) + assert success is True + + +async def test_switch_case_edge_group_send_message_with_invalid_data(): + """Test sending a message through a switch case edge group with invalid data.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = SwitchCaseEdgeGroup( + source=source, + cases=[ + Case(condition=lambda x: x.data < 0, target=target1), + Default(target=target2), + ], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = "invalid_data" + message = Message(data=data, source_id=source.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +# endregion SwitchCaseEdgeGroup diff --git a/python/packages/workflow/tests/test_runner.py b/python/packages/workflow/tests/test_runner.py index a4a1abb43c..d4e4ef79ad 100644 --- a/python/packages/workflow/tests/test_runner.py +++ b/python/packages/workflow/tests/test_runner.py @@ -6,7 +6,7 @@ from dataclasses import dataclass import pytest from agent_framework.workflow import Executor, WorkflowCompletedEvent, WorkflowContext, WorkflowEvent, handler -from agent_framework_workflow._edge import Edge +from agent_framework_workflow._edge import SingleEdgeGroup from agent_framework_workflow._runner import Runner from agent_framework_workflow._runner_context import InProcRunnerContext, RunnerContext from agent_framework_workflow._shared_state import SharedState @@ -36,12 +36,12 @@ def test_create_runner(): executor_b = MockExecutor(id="executor_b") # Create a loop - edges = [ - Edge(source=executor_a, target=executor_b), - Edge(source=executor_b, target=executor_a), + edge_groups = [ + SingleEdgeGroup(executor_a, executor_b), + SingleEdgeGroup(executor_b, executor_a), ] - runner = Runner(edges, shared_state=SharedState(), ctx=InProcRunnerContext()) + runner = Runner(edge_groups, shared_state=SharedState(), ctx=InProcRunnerContext()) assert runner.context is not None and isinstance(runner.context, RunnerContext) @@ -53,8 +53,8 @@ async def test_runner_run_until_convergence(): # Create a loop edges = [ - Edge(source=executor_a, target=executor_b), - Edge(source=executor_b, target=executor_a), + SingleEdgeGroup(executor_a, executor_b), + SingleEdgeGroup(executor_b, executor_a), ] shared_state = SharedState() @@ -87,8 +87,8 @@ async def test_runner_run_until_convergence_not_completed(): # Create a loop edges = [ - Edge(source=executor_a, target=executor_b), - Edge(source=executor_b, target=executor_a), + SingleEdgeGroup(executor_a, executor_b), + SingleEdgeGroup(executor_b, executor_a), ] shared_state = SharedState() @@ -117,8 +117,8 @@ async def test_runner_already_running(): # Create a loop edges = [ - Edge(source=executor_a, target=executor_b), - Edge(source=executor_b, target=executor_a), + SingleEdgeGroup(executor_a, executor_b), + SingleEdgeGroup(executor_b, executor_a), ] shared_state = SharedState() diff --git a/python/packages/workflow/tests/test_validation.py b/python/packages/workflow/tests/test_validation.py index 23a14a4490..8cd6e78c2f 100644 --- a/python/packages/workflow/tests/test_validation.py +++ b/python/packages/workflow/tests/test_validation.py @@ -17,7 +17,7 @@ from agent_framework_workflow import ( handler, validate_workflow_graph, ) -from agent_framework_workflow._edge import Edge +from agent_framework_workflow._edge import SingleEdgeGroup class StringExecutor(Executor): @@ -159,10 +159,13 @@ def test_graph_connectivity_isolated_executors(): executor3 = StringExecutor(id="executor3") # This will be isolated # Create edges that include an isolated executor (self-loop that's not connected to main graph) - edges = [Edge(executor1, executor2), Edge(executor3, executor3)] # Self-loop to include in graph + edge_groups = [ + SingleEdgeGroup(executor1, executor2), + SingleEdgeGroup(executor3, executor3), + ] # Self-loop to include in graph with pytest.raises(GraphConnectivityError) as exc_info: - validate_workflow_graph(edges, executor1) + validate_workflow_graph(edge_groups, executor1) assert "unreachable" in str(exc_info.value).lower() assert "executor3" in str(exc_info.value) @@ -239,15 +242,15 @@ def test_type_compatibility_inheritance(): def test_direct_validation_function(): executor1 = StringExecutor(id="executor1") executor2 = StringExecutor(id="executor2") - edges = [Edge(executor1, executor2)] + edge_groups = [SingleEdgeGroup(executor1, executor2)] # This should not raise any exceptions - validate_workflow_graph(edges, executor1) + validate_workflow_graph(edge_groups, executor1) # Test with invalid start executor executor3 = StringExecutor(id="executor3") with pytest.raises(GraphConnectivityError): - validate_workflow_graph(edges, executor3) + validate_workflow_graph(edge_groups, executor3) def test_fan_out_validation(): diff --git a/python/packages/workflow/tests/test_workflow_builder.py b/python/packages/workflow/tests/test_workflow_builder.py index 5135314485..55b3d32ec1 100644 --- a/python/packages/workflow/tests/test_workflow_builder.py +++ b/python/packages/workflow/tests/test_workflow_builder.py @@ -60,6 +60,6 @@ def test_workflow_builder_fluent_api(): .build() ) - assert len(workflow.edges) == 6 + assert len(workflow.edge_groups) == 4 assert workflow.start_executor.id == executor_a.id assert len(workflow.executors) == 6 diff --git a/python/samples/getting_started/workflow/step_00_foundation_patterns.py b/python/samples/getting_started/workflow/step_00_foundation_patterns.py new file mode 100644 index 0000000000..2827d84fab --- /dev/null +++ b/python/samples/getting_started/workflow/step_00_foundation_patterns.py @@ -0,0 +1,291 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +from typing import Any + +from agent_framework.workflow import Case, Default, Executor, WorkflowBuilder, WorkflowContext, handler + +""" +The following sample demonstrates the foundation patterns that the workflow framework supports. +These patterns include: +- Single connection +- Single connection with condition +- Fan-out and fan-in connections +- Conditional fan-out connections +- Partitioning fan-out connections + +The samples here use numbers and simple arithmetic operations to demonstrate the patterns. +""" + + +class AddOneExecutor(Executor): + """An executor that processes a number by adding one.""" + + @handler(output_types=[int]) + async def add_one(self, number: int, ctx: WorkflowContext) -> None: + """Execute the task by adding one to the input number.""" + result = number + 1 + + # Send the result to the next executor in the workflow. + await ctx.send_message(result) + + print("Adding one to the number:", number, "Result:", result) + + +class MultiplyByTwoExecutor(Executor): + """An executor that processes a number by multiplying it by two.""" + + @handler(output_types=[int]) + async def multiply_by_two(self, number: int, ctx: WorkflowContext) -> None: + """Execute the task by multiplying the input number by two.""" + result = number * 2 + + # Send the result to the next executor in the workflow. + await ctx.send_message(result) + + print("Multiplying the number by two:", number, "Result:", result) + + +class DivideByTwoExecutor(Executor): + """An executor that processes a number by dividing it by two.""" + + @handler(output_types=[float]) + async def divide_by_two(self, number: int, ctx: WorkflowContext) -> None: + """Execute the task by dividing the input number by two.""" + result = number / 2 + + # Send the result with a workflow completion event. + await ctx.send_message(result) + + print("Dividing the number by two:", number, "Result:", result) + + +class AggregateResultExecutor(Executor): + """An executor that receives results and prints them.""" + + @handler + async def aggregate_results(self, results: Any, ctx: WorkflowContext) -> None: + """Print whatever results are received.""" + print("Aggregating results:", results) + + +async def single_edge(): + """A sample to demonstrate a single directed connection between two executors. + + Three executors are connected in a sequence: AddOneExecutor -> AddOneExecutor -> AggregateResultExecutor. + + Expected output: + Adding one to the number: 1 Result: 2 + Adding one to the number: 2 Result: 3 + Aggregating results: 3 + """ + add_one_executor_a = AddOneExecutor() + add_one_executor_b = AddOneExecutor() + aggregate_result_executor = AggregateResultExecutor() + + workflow = ( + WorkflowBuilder() + .add_edge(add_one_executor_a, add_one_executor_b) + .add_edge(add_one_executor_b, aggregate_result_executor) + .set_start_executor(add_one_executor_a) + .build() + ) + + await workflow.run(1) + + +async def single_edge_with_condition(): + """A sample to demonstrate a single directed connection with a condition. + + Three executors are connected: AddOneExecutor -> AddOneExecutor, AggregateResultExecutor. + The AddOneExecutor will loop back to itself until the number reaches 10, then it will start + sending the result to AggregateResultExecutor when the number is greater than 8. The workflow + stops when the number reaches 11. + + Expected output: + Adding one to the number: 1 Result: 2 + Adding one to the number: 2 Result: 3 + Adding one to the number: 3 Result: 4 + Adding one to the number: 4 Result: 5 + Adding one to the number: 5 Result: 6 + Adding one to the number: 6 Result: 7 + Adding one to the number: 7 Result: 8 + Adding one to the number: 8 Result: 9 + Adding one to the number: 9 Result: 10 + Aggregating results: 9 + Adding one to the number: 10 Result: 11 + Aggregating results: 10 + Aggregating results: 11 + """ + add_one_executor_a = AddOneExecutor() + aggregate_result_executor = AggregateResultExecutor() + + workflow = ( + WorkflowBuilder() + .add_edge(add_one_executor_a, add_one_executor_a, condition=lambda x: x < 11) + .add_edge(add_one_executor_a, aggregate_result_executor, condition=lambda x: x > 8) + .set_start_executor(add_one_executor_a) + .build() + ) + + await workflow.run(1) + + +async def fan_out_fan_in_edge_group(): + """A sample to demonstrate a fan-out and fan-in connection between executors. + + Four executors are connected in a fan-out and fan-in pattern: + AddOneExecutor -> MultiplyByTwoExecutor, DivideByTwoExecutor -> AggregateResultExecutor. + The AddOneExecutor sends its output to both MultiplyByTwoExecutor and DivideByTwoExecutor, + and both of these executors send their results to AggregateResultExecutor. + + The target of the fan-in connection will wait for all the results from the sources before proceeding. + + Expected output: + Adding one to the number: 1 Result: 2 + Multiplying the number by two: 2 Result: 4 + Dividing the number by two: 2 Result: 1.0 + Aggregating results: [4, 1.0] + """ + add_one_executor = AddOneExecutor() + multiply_by_two_executor = MultiplyByTwoExecutor() + divide_by_two_executor = DivideByTwoExecutor() + aggregate_result_executor = AggregateResultExecutor() + + workflow = ( + WorkflowBuilder() + .add_fan_out_edges(add_one_executor, [multiply_by_two_executor, divide_by_two_executor]) + .add_fan_in_edges([multiply_by_two_executor, divide_by_two_executor], aggregate_result_executor) + .set_start_executor(add_one_executor) + .build() + ) + + await workflow.run(1) + + +async def switch_case_edge_group(): + """A sample to demonstrate a switch-case connection. + + Four executors are connected in a switch-case pattern: + AddOneExecutor -> AddOneExecutor, MultiplyByTwoExecutor, DivideByTwoExecutor -> AggregateResultExecutor. + + The message from AddOneExecutor will be evaluated against the conditions one by one, and the first condition + that evaluates to True will determine the target executors. If no conditions match, the message will be sent + to the last targets. + + This pattern resembles a switch-case statement with a default case where the first matching case is executed. + + Expected output: + Adding one to the number: 1 Result: 2 + Adding one to the number: 2 Result: 3 + Adding one to the number: 3 Result: 4 + Adding one to the number: 4 Result: 5 + Adding one to the number: 5 Result: 6 + Adding one to the number: 6 Result: 7 + Adding one to the number: 7 Result: 8 + Adding one to the number: 8 Result: 9 + Adding one to the number: 9 Result: 10 + Adding one to the number: 10 Result: 11 + Multiplying the number by two: 11 Result: 22 + """ + add_one_executor = AddOneExecutor() + multiply_by_two_executor = MultiplyByTwoExecutor() + divide_by_two_executor = DivideByTwoExecutor() + aggregate_result_executor = AggregateResultExecutor() + + workflow = ( + WorkflowBuilder() + .set_start_executor(add_one_executor) + .add_switch_case_edge_group( + source=add_one_executor, + cases=[ + # Loop back to the add_one_executor if the number is less than 11 + Case(condition=lambda x: x < 11, target=add_one_executor), + # multiply_by_two_executor when the number is larger than or equal to 11 and even. + Case(condition=lambda x: x % 2 == 0, target=multiply_by_two_executor), + # Otherwise, send to the divide_by_two_executor. + Default(target=divide_by_two_executor), + ], + ) + .add_fan_in_edges([multiply_by_two_executor, divide_by_two_executor], aggregate_result_executor) + .build() + ) + + await workflow.run(1) + + +async def multi_selection_edge_group(): + """A sample to demonstrate a multi-selection edge connection. + + Four executors are connected in a multi-selection edge pattern: + AddOneExecutor -> AddOneExecutor, MultiplyByTwoExecutor, DivideByTwoExecutor -> AggregateResultExecutor. + + The AddOneExecutor sends its output to one or more executors based on the partitioning function. + + Expected output: + Adding one to the number: 1 Result: 2 + Adding one to the number: 2 Result: 3 + Adding one to the number: 3 Result: 4 + Adding one to the number: 4 Result: 5 + Adding one to the number: 5 Result: 6 + Adding one to the number: 6 Result: 7 + Adding one to the number: 7 Result: 8 + Adding one to the number: 8 Result: 9 + Adding one to the number: 9 Result: 10 + Adding one to the number: 10 Result: 11 + Adding one to the number: 11 Result: 12 + Adding one to the number: 12 Result: 13 + Dividing the number by two: 12 Result: 6.0 + Multiplying the number by two: 13 Result: 26 + Aggregating results: [26, 6.0] + """ + add_one_executor = AddOneExecutor() + multiply_by_two_executor = MultiplyByTwoExecutor() + divide_by_two_executor = DivideByTwoExecutor() + aggregate_result_executor = AggregateResultExecutor() + + def selection_func(number: int, target_ids: list[str]) -> list[str]: + """Selection function to determine which executor to send the number to.""" + if number < 12: + # Loop back to the add_one_executor if the number is less than 12 + return [add_one_executor.id] + + if number % 2 == 0: + # Send it to the add_one_executor to add one more time and the + # divide_by_two_executor to divide the result by two. + return [add_one_executor.id, divide_by_two_executor.id] + + # Otherwise, send it to the multiply_by_two_executor to multiply the result by two. + return [multiply_by_two_executor.id] + + workflow = ( + WorkflowBuilder() + .set_start_executor(add_one_executor) + .add_multi_selection_edge_group( + add_one_executor, + [add_one_executor, multiply_by_two_executor, divide_by_two_executor], + selection_func=selection_func, + ) + .add_fan_in_edges([multiply_by_two_executor, divide_by_two_executor], aggregate_result_executor) + .build() + ) + + await workflow.run(1) + + +async def main(): + """Main function to run the workflows.""" + print("**Running single connection workflow**") + await single_edge() + print("**Running single connection with condition workflow**") + await single_edge_with_condition() + print("**Running fan-out and fan-in connection workflow**") + await fan_out_fan_in_edge_group() + print("**Running conditional fan-out connection workflow**") + await switch_case_edge_group() + print("**Running multi-selection edge group workflow**") + await multi_selection_edge_group() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/workflow/step_02_simple_workflow_condition.py b/python/samples/getting_started/workflow/step_02_simple_workflow_condition.py index c6f4f21f60..3a526ce7b7 100644 --- a/python/samples/getting_started/workflow/step_02_simple_workflow_condition.py +++ b/python/samples/getting_started/workflow/step_02_simple_workflow_condition.py @@ -3,7 +3,15 @@ import asyncio from dataclasses import dataclass -from agent_framework.workflow import Executor, WorkflowBuilder, WorkflowCompletedEvent, WorkflowContext, handler +from agent_framework.workflow import ( + Case, + Default, + Executor, + WorkflowBuilder, + WorkflowCompletedEvent, + WorkflowContext, + handler, +) """ The following sample demonstrates a basic workflow with two executors @@ -91,15 +99,12 @@ async def main(): workflow = ( WorkflowBuilder() .set_start_executor(spam_detector) - .add_edge( + .add_switch_case_edge_group( spam_detector, - send_response, - condition=lambda x: x.is_spam is False, - ) - .add_edge( - spam_detector, - remove_spam, - condition=lambda x: x.is_spam is True, + [ + Case(condition=lambda x: x.is_spam, target=remove_spam), + Default(target=send_response), + ], ) .build() ) diff --git a/python/samples/getting_started/workflow/step_04_simple_group_chat.py b/python/samples/getting_started/workflow/step_04_simple_group_chat.py index 0e0d6eda8d..496f506b6b 100644 --- a/python/samples/getting_started/workflow/step_04_simple_group_chat.py +++ b/python/samples/getting_started/workflow/step_04_simple_group_chat.py @@ -120,8 +120,7 @@ async def main(): workflow = ( WorkflowBuilder() .set_start_executor(group_chat_manager) - .add_edge(group_chat_manager, writer) - .add_edge(group_chat_manager, reviewer) + .add_fan_out_edges(group_chat_manager, [writer, reviewer]) .add_edge(writer, group_chat_manager) .add_edge(reviewer, group_chat_manager) .build() diff --git a/python/samples/getting_started/workflow/step_05_simple_group_chat_with_hil.py b/python/samples/getting_started/workflow/step_05_simple_group_chat_with_hil.py index 1c30404426..f10eb41d47 100644 --- a/python/samples/getting_started/workflow/step_05_simple_group_chat_with_hil.py +++ b/python/samples/getting_started/workflow/step_05_simple_group_chat_with_hil.py @@ -167,8 +167,7 @@ async def main(): .set_start_executor(group_chat_manager) .add_edge(group_chat_manager, request_info_executor) .add_edge(request_info_executor, group_chat_manager) - .add_edge(group_chat_manager, writer) - .add_edge(group_chat_manager, reviewer) + .add_fan_out_edges(group_chat_manager, [writer, reviewer]) .add_edge(writer, group_chat_manager) .add_edge(reviewer, group_chat_manager) .build() diff --git a/python/samples/getting_started/workflow/step_06_map_reduce.py b/python/samples/getting_started/workflow/step_06_map_reduce.py index e7665f67ed..929cb5f4c7 100644 --- a/python/samples/getting_started/workflow/step_06_map_reduce.py +++ b/python/samples/getting_started/workflow/step_06_map_reduce.py @@ -113,7 +113,6 @@ class Map(Executor): ctx: The execution context containing the shared state and other information. """ # Retrieve the data to be processed from the shared state. - # Define a key for the shared state to store the data to be processed data_to_be_processed: list[str] = await ctx.get_shared_state(SHARED_STATE_DATA_KEY) chunk_start, chunk_end = await ctx.get_shared_state(self.id)