mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Workflow Edge Groups (#393)
* Introducing edge groups * Add conditional and partitioning edge groups; next add samples and tests * Add unit tests * Add samples * Address comments 1 * Address comments 2 * Update conditional edge group to take in cases and default * Minor updates to sample * Collapsing Paritioning Edge group and Conditional Edge group to source edge group * Improve sample clarity * Name consolidation --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
19d91bb950
commit
ed86baa6cb
@@ -32,6 +32,8 @@ _IMPORTS = [
|
||||
"InMemoryCheckpointStorage",
|
||||
"CheckpointStorage",
|
||||
"WorkflowCheckpoint",
|
||||
"Case",
|
||||
"Default",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,9 @@ from agent_framework_workflow import (
|
||||
AgentExecutorResponse,
|
||||
AgentRunEvent,
|
||||
AgentRunStreamingEvent,
|
||||
Case,
|
||||
CheckpointStorage,
|
||||
Default,
|
||||
Executor,
|
||||
ExecutorCompletedEvent,
|
||||
ExecutorEvent,
|
||||
@@ -34,7 +36,9 @@ __all__ = [
|
||||
"AgentExecutorResponse",
|
||||
"AgentRunEvent",
|
||||
"AgentRunStreamingEvent",
|
||||
"Case",
|
||||
"CheckpointStorage",
|
||||
"Default",
|
||||
"Executor",
|
||||
"ExecutorCompletedEvent",
|
||||
"ExecutorEvent",
|
||||
|
||||
@@ -11,6 +11,7 @@ from ._checkpoint import (
|
||||
from ._const import (
|
||||
DEFAULT_MAX_ITERATIONS,
|
||||
)
|
||||
from ._edge import Case, Default
|
||||
from ._events import (
|
||||
AgentRunEvent,
|
||||
AgentRunStreamingEvent,
|
||||
@@ -60,7 +61,9 @@ __all__ = [
|
||||
"AgentExecutorResponse",
|
||||
"AgentRunEvent",
|
||||
"AgentRunStreamingEvent",
|
||||
"Case",
|
||||
"CheckpointStorage",
|
||||
"Default",
|
||||
"EdgeDuplicationError",
|
||||
"Executor",
|
||||
"ExecutorCompletedEvent",
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from ._executor import Executor
|
||||
@@ -9,6 +13,8 @@ from ._runner_context import Message, RunnerContext
|
||||
from ._shared_state import SharedState
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Edge:
|
||||
"""Represents a directed edge in a graph."""
|
||||
@@ -34,10 +40,6 @@ class Edge:
|
||||
self.target = target
|
||||
self._condition = condition
|
||||
|
||||
# Edge group is used to group edges that share the same target executor.
|
||||
# It allows for sending messages to the target executor only when all edges in the group have data.
|
||||
self._edge_group_ids: list[str] = []
|
||||
|
||||
@property
|
||||
def source_id(self) -> str:
|
||||
"""Get the source executor ID."""
|
||||
@@ -53,27 +55,6 @@ class Edge:
|
||||
"""Get the unique ID of the edge."""
|
||||
return f"{self.source_id}{self.ID_SEPARATOR}{self.target_id}"
|
||||
|
||||
def has_edge_group(self) -> bool:
|
||||
"""Check if the edge is part of an edge group."""
|
||||
return bool(self._edge_group_ids)
|
||||
|
||||
@classmethod
|
||||
def source_and_target_from_id(cls, edge_id: str) -> tuple[str, str]:
|
||||
"""Extract the source and target IDs from the edge ID.
|
||||
|
||||
Args:
|
||||
edge_id (str): The edge ID in the format "source_id->target_id".
|
||||
|
||||
Returns:
|
||||
tuple[str, str]: A tuple containing the source ID and target ID.
|
||||
"""
|
||||
if cls.ID_SEPARATOR not in edge_id:
|
||||
raise ValueError(f"Invalid edge ID format: {edge_id}")
|
||||
ids = edge_id.split(cls.ID_SEPARATOR)
|
||||
if len(ids) != 2:
|
||||
raise ValueError(f"Invalid edge ID format: {edge_id}")
|
||||
return ids[0], ids[1]
|
||||
|
||||
def can_handle(self, message_data: Any) -> bool:
|
||||
"""Check if the edge can handle the given data.
|
||||
|
||||
@@ -83,11 +64,14 @@ class Edge:
|
||||
Returns:
|
||||
bool: True if the edge can handle the data, False otherwise.
|
||||
"""
|
||||
if not self._edge_group_ids:
|
||||
return self.target.can_handle(message_data)
|
||||
return self.target.can_handle(message_data)
|
||||
|
||||
# If the edge is part of an edge group, the target should expect a list of the data type.
|
||||
return self.target.can_handle([message_data])
|
||||
def should_route(self, data: Any) -> bool:
|
||||
"""Determine if message should be routed through this edge based on the condition."""
|
||||
if self._condition is None:
|
||||
return True
|
||||
|
||||
return self._condition(data)
|
||||
|
||||
async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> None:
|
||||
"""Send a message along this edge.
|
||||
@@ -98,57 +82,338 @@ class Edge:
|
||||
ctx (RunnerContext): The context for the runner.
|
||||
"""
|
||||
if not self.can_handle(message.data):
|
||||
# Caller of this method should ensure that the edge can handle the data.
|
||||
raise RuntimeError(f"Edge {self.id} cannot handle data of type {type(message.data)}.")
|
||||
|
||||
if not self._edge_group_ids and self._should_route(message.data):
|
||||
if self.should_route(message.data):
|
||||
await self.target.execute(
|
||||
message.data, WorkflowContext(self.target.id, [self.source.id], shared_state, ctx)
|
||||
)
|
||||
elif self._edge_group_ids:
|
||||
# Logic:
|
||||
# 1. If not all edges in the edge group have data in the shared state,
|
||||
# add the data to the shared state.
|
||||
# 2. If all edges in the edge group have data in the shared state,
|
||||
# copy the data to a list and send it to the target executor.
|
||||
message_list: list[Message] = []
|
||||
async with shared_state.hold() as held_shared_state:
|
||||
has_data = await asyncio.gather(
|
||||
*(held_shared_state.has_within_hold(edge_id) for edge_id in self._edge_group_ids)
|
||||
)
|
||||
if not all(has_data):
|
||||
await held_shared_state.set_within_hold(self.id, message)
|
||||
else:
|
||||
message_list = [
|
||||
await held_shared_state.get_within_hold(edge_id) for edge_id in self._edge_group_ids
|
||||
] + [message]
|
||||
# Remove the data from the shared state after retrieving it
|
||||
await asyncio.gather(
|
||||
*(held_shared_state.delete_within_hold(edge_id) for edge_id in self._edge_group_ids)
|
||||
)
|
||||
|
||||
if message_list:
|
||||
data_list = [msg.data for msg in message_list]
|
||||
source_ids = [msg.source_id for msg in message_list]
|
||||
await self.target.execute(data_list, WorkflowContext(self.target.id, source_ids, shared_state, ctx))
|
||||
|
||||
def _should_route(self, data: Any) -> bool:
|
||||
"""Determine if message should be routed through this edge."""
|
||||
if self._condition is None:
|
||||
return True
|
||||
class EdgeGroup:
|
||||
"""Represents a group of edges that share some common properties and can be triggered together."""
|
||||
|
||||
return self._condition(data)
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the edge group."""
|
||||
self._id = f"{self.__class__.__name__}/{uuid.uuid4()}"
|
||||
|
||||
def set_edge_group(self, edge_group_ids: list[str]) -> None:
|
||||
"""Set the edge group IDs for this edge.
|
||||
async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool:
|
||||
"""Send a message through the edge group.
|
||||
|
||||
Args:
|
||||
edge_group_ids (list[str]): A list of edge IDs that belong to the same edge group.
|
||||
message (Message): The message to send.
|
||||
shared_state (SharedState): The shared state to use for holding data.
|
||||
ctx (RunnerContext): The context for the runner.
|
||||
|
||||
Returns:
|
||||
bool: True if the message was sent successfully, False if the target executor cannot handle the message.
|
||||
If a message can be delivered but rejected due to a condition, it will still return True.
|
||||
|
||||
Note:
|
||||
Exception will not be raised if the target executor cannot handle the message. This is because
|
||||
a source executor can be connected to multiple target executors, and not every target executor may
|
||||
be able to handle all the messages sent by the source executor.
|
||||
"""
|
||||
# Validate that the edges in the edge group contain the same target executor as this edge
|
||||
# TODO(@taochen): An edge cannot be part of multiple edge groups.
|
||||
# TODO(@taochen): Can an edge have both a condition and an edge group?
|
||||
if edge_group_ids:
|
||||
for edge_id in edge_group_ids:
|
||||
if Edge.source_and_target_from_id(edge_id)[1] != self.target.id:
|
||||
raise ValueError("All edges in the group must have the same target executor.")
|
||||
self._edge_group_ids = edge_group_ids
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
"""Get the unique ID of the edge group."""
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def source_executors(self) -> list[Executor]:
|
||||
"""Get the source executor IDs of the edges in the group."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def target_executors(self) -> list[Executor]:
|
||||
"""Get the target executor IDs of the edges in the group."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def edges(self) -> list[Edge]:
|
||||
"""Get the edges in the group."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SingleEdgeGroup(EdgeGroup):
|
||||
"""Represents a single edge group that contains only one edge.
|
||||
|
||||
A concrete implementation of EdgeGroup that represent a group containing exactly one edge.
|
||||
"""
|
||||
|
||||
def __init__(self, source: Executor, target: Executor, condition: Callable[[Any], bool] | None = None) -> None:
|
||||
"""Initialize the single edge group with an edge.
|
||||
|
||||
Args:
|
||||
source (Executor): The source executor.
|
||||
target (Executor): The target executor that the source executor can send messages to.
|
||||
condition (Callable[[Any], bool], optional): A condition function that determines
|
||||
if the edge will pass the data to the target executor. If None, the edge can
|
||||
will always pass the data to the target executor.
|
||||
"""
|
||||
self._edge = Edge(source=source, target=target, condition=condition)
|
||||
|
||||
async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool:
|
||||
"""Send a message through the single edge."""
|
||||
if message.target_id and message.target_id != self._edge.target_id:
|
||||
return False
|
||||
|
||||
if self._edge.can_handle(message.data):
|
||||
await self._edge.send_message(message, shared_state, ctx)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@property
|
||||
def source_executors(self) -> list[Executor]:
|
||||
"""Get the source executor of the edge."""
|
||||
return [self._edge.source]
|
||||
|
||||
@property
|
||||
def target_executors(self) -> list[Executor]:
|
||||
"""Get the target executor of the edge."""
|
||||
return [self._edge.target]
|
||||
|
||||
@property
|
||||
def edges(self) -> list[Edge]:
|
||||
"""Get the edges in the group."""
|
||||
return [self._edge]
|
||||
|
||||
|
||||
class FanOutEdgeGroup(EdgeGroup):
|
||||
"""Represents a group of edges that share the same source executor.
|
||||
|
||||
Assembles a Fan-out pattern where multiple edges share the same source executor
|
||||
and send messages to their respective target executors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source: Executor,
|
||||
targets: Sequence[Executor],
|
||||
selection_func: Callable[[Any, list[str]], list[str]] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the fan-out edge group with a list of edges.
|
||||
|
||||
Args:
|
||||
source (Executor): The source executor.
|
||||
targets (Sequence[Executor]): A list of target executors that the source executor can send messages to.
|
||||
selection_func (Callable[[Any, list[str]], list[str]], optional): A function that selects which target
|
||||
executors to send messages to. The function takes in the message data and a list of target executor
|
||||
IDs, and returns a list of selected target executor IDs.
|
||||
"""
|
||||
if len(targets) <= 1:
|
||||
raise ValueError("FanOutEdgeGroup must contain at least two targets.")
|
||||
self._edges = [Edge(source=source, target=target) for target in targets]
|
||||
self._target_ids = [edge.target_id for edge in self._edges]
|
||||
self._target_map = {edge.target_id: edge for edge in self._edges}
|
||||
self._selection_func = selection_func
|
||||
|
||||
async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool:
|
||||
"""Send a message through all edges in the fan-out edge group."""
|
||||
selection_results = (
|
||||
self._selection_func(message.data, self._target_ids) if self._selection_func else self._target_ids
|
||||
)
|
||||
if not self._validate_selection_result(selection_results):
|
||||
raise RuntimeError(
|
||||
f"Invalid selection result: {selection_results}. "
|
||||
f"Expected selections to be a subset of valid target executor IDs: {self._target_ids}."
|
||||
)
|
||||
|
||||
if message.target_id:
|
||||
# If the target ID is specified and the selection result contains it, send the message to that edge
|
||||
if message.target_id in selection_results:
|
||||
edge = next((edge for edge in self._edges if edge.target_id == message.target_id), None)
|
||||
if edge and edge.can_handle(message.data):
|
||||
await edge.send_message(message, shared_state, ctx)
|
||||
return True
|
||||
return False
|
||||
|
||||
# If no target ID, send the message to the selected targets
|
||||
async def send_to_edge(edge: Edge) -> bool:
|
||||
"""Send the message to the edge at the specified index."""
|
||||
if edge.can_handle(message.data):
|
||||
await edge.send_message(message, shared_state, ctx)
|
||||
return True
|
||||
return False
|
||||
|
||||
tasks = [send_to_edge(self._target_map[target_id]) for target_id in selection_results]
|
||||
results = await asyncio.gather(*tasks)
|
||||
return any(results)
|
||||
|
||||
@property
|
||||
def source_executors(self) -> list[Executor]:
|
||||
"""Get the source executor of the edges in the group."""
|
||||
return [self._edges[0].source]
|
||||
|
||||
@property
|
||||
def target_executors(self) -> list[Executor]:
|
||||
"""Get the target executors of the edges in the group."""
|
||||
return [edge.target for edge in self._edges]
|
||||
|
||||
@property
|
||||
def edges(self) -> list[Edge]:
|
||||
"""Get the edges in the group."""
|
||||
return self._edges
|
||||
|
||||
def _validate_selection_result(self, selection_results: list[str]) -> bool:
|
||||
"""Validate the selection results to ensure all IDs are valid target executor IDs."""
|
||||
return all(result in self._target_ids for result in selection_results)
|
||||
|
||||
|
||||
class FanInEdgeGroup(EdgeGroup):
|
||||
"""Represents a group of edges that share the same target executor.
|
||||
|
||||
Assembles a Fan-in pattern where multiple edges send messages to a single target executor.
|
||||
Messages are buffered until all edges in the group have data to send.
|
||||
"""
|
||||
|
||||
def __init__(self, sources: Sequence[Executor], target: Executor) -> None:
|
||||
"""Initialize the fan-in edge group with a list of edges.
|
||||
|
||||
Args:
|
||||
sources (Sequence[Executor]): A list of source executors that can send messages to the target executor.
|
||||
target (Executor): The target executor that receives a list of messages aggregated from all sources.
|
||||
"""
|
||||
if len(sources) <= 1:
|
||||
raise ValueError("FanInEdgeGroup must contain at least two sources.")
|
||||
self._edges = [Edge(source=source, target=target) for source in sources]
|
||||
# Buffer to hold messages before sending them to the target executor
|
||||
# Key is the source executor ID, value is a list of messages
|
||||
self._buffer: dict[str, list[Message]] = defaultdict(list)
|
||||
|
||||
async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool:
|
||||
"""Send a message through all edges in the fan-in edge group."""
|
||||
if message.target_id and message.target_id != self._edges[0].target_id:
|
||||
return False
|
||||
|
||||
if self._edges[0].can_handle([message.data]):
|
||||
# If the edge can handle the data, buffer the message
|
||||
self._buffer[message.source_id].append(message)
|
||||
else:
|
||||
# If the edge cannot handle the data, return False
|
||||
return False
|
||||
|
||||
if self._is_ready_to_send():
|
||||
# If all edges in the group have data, send the buffered messages to the target executor
|
||||
messages_to_send = [msg for edge in self._edges for msg in self._buffer[edge.source_id]]
|
||||
self._buffer.clear()
|
||||
# Only trigger one edge to send the messages to avoid duplicate sends
|
||||
await self._edges[0].send_message(
|
||||
Message([msg.data for msg in messages_to_send], self.__class__.__name__),
|
||||
shared_state,
|
||||
ctx,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def _is_ready_to_send(self) -> bool:
|
||||
"""Check if all edges in the group have data to send."""
|
||||
return all(self._buffer[edge.source_id] for edge in self._edges)
|
||||
|
||||
@property
|
||||
def source_executors(self) -> list[Executor]:
|
||||
"""Get the source executors of the edges in the group."""
|
||||
return [edge.source for edge in self._edges]
|
||||
|
||||
@property
|
||||
def target_executors(self) -> list[Executor]:
|
||||
"""Get the target executor of the edges in the group."""
|
||||
return [self._edges[0].target]
|
||||
|
||||
@property
|
||||
def edges(self) -> list[Edge]:
|
||||
"""Get the edges in the group."""
|
||||
return self._edges
|
||||
|
||||
|
||||
@dataclass
|
||||
class Case:
|
||||
"""Represents a single case in the conditional edge group.
|
||||
|
||||
Args:
|
||||
condition (Callable[[Any], bool]): The condition function for the case.
|
||||
target (Executor): The target executor for the case.
|
||||
"""
|
||||
|
||||
condition: Callable[[Any], bool]
|
||||
target: Executor
|
||||
|
||||
|
||||
@dataclass
|
||||
class Default:
|
||||
"""Represents the default case in the conditional edge group.
|
||||
|
||||
Args:
|
||||
target (Executor): The target executor for the default case.
|
||||
"""
|
||||
|
||||
target: Executor
|
||||
|
||||
|
||||
class SwitchCaseEdgeGroup(FanOutEdgeGroup):
|
||||
"""Represents a group of edges that assemble a conditional routing pattern.
|
||||
|
||||
This is similar to a switch-case construct:
|
||||
switch(data):
|
||||
case condition_1:
|
||||
edge_1
|
||||
break
|
||||
case condition_2:
|
||||
edge_2
|
||||
break
|
||||
default:
|
||||
edge_3
|
||||
break
|
||||
Or equivalently an if-elif-else construct:
|
||||
if condition_1:
|
||||
edge_1
|
||||
elif condition_2:
|
||||
edge_2
|
||||
else:
|
||||
edge_4
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source: Executor,
|
||||
cases: Sequence[Case | Default],
|
||||
) -> None:
|
||||
"""Initialize the conditional edge group with a list of edges.
|
||||
|
||||
Args:
|
||||
source (Executor): The source executor.
|
||||
cases (Sequence[Case | Default]): A list of cases for the conditional edge group.
|
||||
There should be exactly one default case.
|
||||
"""
|
||||
if len(cases) < 2:
|
||||
raise ValueError("SwitchCaseEdgeGroup must contain at least two cases (including the default case).")
|
||||
|
||||
default_case = [isinstance(case, Default) for case in cases]
|
||||
if sum(default_case) != 1:
|
||||
raise ValueError("SwitchCaseEdgeGroup must contain exactly one default case.")
|
||||
|
||||
if isinstance(cases[-1], Default):
|
||||
logger.warning(
|
||||
"Default case in the conditional edge group is not the last case. "
|
||||
"This will result in unexpected behavior."
|
||||
)
|
||||
|
||||
def selection_func(data: Any, targets: list[str]) -> list[str]:
|
||||
"""Select the target executor based on the conditions."""
|
||||
for index, case in enumerate(cases):
|
||||
if isinstance(case, Default):
|
||||
return [case.target.id]
|
||||
if isinstance(case, Case):
|
||||
try:
|
||||
if case.condition(data):
|
||||
return [case.target.id]
|
||||
except Exception as e:
|
||||
logger.warning(f"Error occurred while evaluating condition for case {index}: {e}")
|
||||
|
||||
raise RuntimeError("No matching case found in SwitchCaseEdgeGroup.")
|
||||
|
||||
super().__init__(source, [case.target for case in cases], selection_func=selection_func)
|
||||
|
||||
@@ -31,7 +31,7 @@ class Executor:
|
||||
Args:
|
||||
id: A unique identifier for the executor. If None, a new UUID will be generated.
|
||||
"""
|
||||
self._id = id or str(uuid.uuid4())
|
||||
self._id = id or f"{self.__class__.__name__}/{uuid.uuid4()}"
|
||||
|
||||
self._handlers: dict[type, Callable[[Any, WorkflowContext], Any]] = {}
|
||||
self._discover_handlers()
|
||||
|
||||
@@ -3,10 +3,10 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import AsyncIterable
|
||||
from collections.abc import AsyncIterable, Sequence
|
||||
from typing import Any
|
||||
|
||||
from ._edge import Edge
|
||||
from ._edge import EdgeGroup
|
||||
from ._events import WorkflowEvent
|
||||
from ._executor import Executor
|
||||
from ._runner_context import Message, RunnerContext
|
||||
@@ -20,7 +20,7 @@ class Runner:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
edges: list[Edge],
|
||||
edge_groups: Sequence[EdgeGroup],
|
||||
shared_state: SharedState,
|
||||
ctx: RunnerContext,
|
||||
max_iterations: int = 100,
|
||||
@@ -29,13 +29,13 @@ class Runner:
|
||||
"""Initialize the runner with edges, shared state, and context.
|
||||
|
||||
Args:
|
||||
edges: The edges of the workflow.
|
||||
edge_groups: The edge groups of the workflow.
|
||||
shared_state: The shared state for the workflow.
|
||||
ctx: The runner context for the workflow.
|
||||
max_iterations: The maximum number of iterations to run.
|
||||
workflow_id: The workflow ID for checkpointing.
|
||||
"""
|
||||
self._edge_map = self._parse_edges(edges)
|
||||
self._edge_group_map = self._parse_edge_groups(edge_groups)
|
||||
self._ctx = ctx
|
||||
self._iteration = 0
|
||||
self._max_iterations = max_iterations
|
||||
@@ -125,26 +125,25 @@ class Runner:
|
||||
|
||||
async def _run_iteration(self):
|
||||
async def _deliver_messages(source_executor_id: str, messages: list[Message]) -> None:
|
||||
async def _deliver_messages_inner(
|
||||
edge: Edge,
|
||||
messages: list[Message],
|
||||
) -> None:
|
||||
for message in messages:
|
||||
if message.target_id is not None and message.target_id != edge.target_id:
|
||||
continue
|
||||
if not edge.can_handle(message.data):
|
||||
continue
|
||||
await edge.send_message(message, self._shared_state, self._ctx)
|
||||
"""Outer loop to concurrently deliver messages from all sources to their targets."""
|
||||
|
||||
associated_edges = self._edge_map.get(source_executor_id, [])
|
||||
tasks = [asyncio.create_task(_deliver_messages_inner(edge, messages)) for edge in associated_edges]
|
||||
await asyncio.gather(*tasks)
|
||||
async def _deliver_message_inner(edge_group: EdgeGroup, message: Message) -> bool:
|
||||
"""Inner loop to deliver a single message through an edge group."""
|
||||
return await edge_group.send_message(message, self._shared_state, self._ctx)
|
||||
|
||||
associated_edge_groups = self._edge_group_map.get(source_executor_id, [])
|
||||
for message in messages:
|
||||
# Deliver a message through all edge groups associated with the source executor concurrently.
|
||||
tasks = [_deliver_message_inner(edge_group, message) for edge_group in associated_edge_groups]
|
||||
results = await asyncio.gather(*tasks)
|
||||
if not any(results):
|
||||
logger.warning(
|
||||
f"Message {message} could not be delivered. "
|
||||
"This may be due to type incompatibility or no matching targets."
|
||||
)
|
||||
|
||||
messages = await self._ctx.drain_messages()
|
||||
tasks = [
|
||||
asyncio.create_task(_deliver_messages(source_executor_id, messages))
|
||||
for source_executor_id, messages in messages.items()
|
||||
]
|
||||
tasks = [_deliver_messages(source_executor_id, messages) for source_executor_id, messages in messages.items()]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def _create_checkpoint_if_enabled(self, checkpoint_type: str) -> str | None:
|
||||
@@ -177,10 +176,11 @@ class Runner:
|
||||
Only JSON-serializable dicts should be provided by executors.
|
||||
"""
|
||||
executors: dict[str, Executor] = {}
|
||||
for edge_list in self._edge_map.values():
|
||||
for edge in edge_list:
|
||||
executors[edge.source.id] = edge.source
|
||||
executors[edge.target.id] = edge.target
|
||||
for edge_groups in self._edge_group_map.values():
|
||||
for edge_group in edge_groups:
|
||||
for edge in edge_group.edges:
|
||||
executors[edge.source.id] = edge.source
|
||||
executors[edge.target.id] = edge.target
|
||||
for exec_id, executor in executors.items():
|
||||
state_dict: dict[str, Any] | None = None
|
||||
snapshot = getattr(executor, "snapshot_state", None)
|
||||
@@ -268,16 +268,18 @@ class Runner:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to restore shared state from context: {e}")
|
||||
|
||||
def _parse_edges(self, edges: list[Edge]) -> dict[str, list[Edge]]:
|
||||
"""Parse the edges of the workflow into a more convenient format.
|
||||
def _parse_edge_groups(self, edge_groups: Sequence[EdgeGroup]) -> dict[str, list[EdgeGroup]]:
|
||||
"""Parse the edge groups of the workflow into a mapping where each source executor ID maps to its edge groups.
|
||||
|
||||
Args:
|
||||
edges: A list of edges in the workflow.
|
||||
edge_groups: A list of edge groups in the workflow.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping each source executor ID to a list of target executor IDs.
|
||||
A dictionary mapping each source executor ID to a list of edge groups.
|
||||
"""
|
||||
parsed: defaultdict[str, list[Edge]] = defaultdict(list)
|
||||
for edge in edges:
|
||||
parsed[edge.source_id].append(edge)
|
||||
parsed: defaultdict[str, list[EdgeGroup]] = defaultdict(list)
|
||||
for group in edge_groups:
|
||||
for source_executor in group.source_executors:
|
||||
parsed[source_executor.id].append(group)
|
||||
|
||||
return parsed
|
||||
|
||||
@@ -13,6 +13,10 @@ def is_instance_of(data: Any, target_type: type) -> bool:
|
||||
Returns:
|
||||
bool: True if data is an instance of target_type, False otherwise.
|
||||
"""
|
||||
# Case 0: target_type is Any - always return True
|
||||
if target_type is Any:
|
||||
return True
|
||||
|
||||
origin = get_origin(target_type)
|
||||
args = get_args(target_type)
|
||||
|
||||
|
||||
@@ -3,10 +3,11 @@
|
||||
import inspect
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from typing import Any, Union, get_args, get_origin
|
||||
|
||||
from ._edge import Edge
|
||||
from ._edge import Edge, EdgeGroup, FanInEdgeGroup
|
||||
from ._executor import Executor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -92,18 +93,19 @@ class WorkflowGraphValidator:
|
||||
self._executors: dict[str, Executor] = {}
|
||||
|
||||
# region Core Validation Methods
|
||||
def validate_workflow(self, edges: list[Edge], start_executor: Executor | str) -> None:
|
||||
def validate_workflow(self, edge_groups: Sequence[EdgeGroup], start_executor: Executor | str) -> None:
|
||||
"""Validate the entire workflow graph.
|
||||
|
||||
Args:
|
||||
edges: list of edges in the workflow
|
||||
edge_groups: list of edge groups in the workflow
|
||||
start_executor: The starting executor (can be instance or ID)
|
||||
|
||||
Raises:
|
||||
WorkflowValidationError: If any validation fails
|
||||
"""
|
||||
self._edges = edges
|
||||
self._executors = self._build_executor_map(edges)
|
||||
self._executors = self._build_executor_map(edge_groups)
|
||||
self._edges = [edge for group in edge_groups for edge in group.edges]
|
||||
self._edge_groups = edge_groups
|
||||
|
||||
# Validate that start_executor exists in the graph
|
||||
# It should because we check for it in the WorkflowBuilder
|
||||
@@ -121,12 +123,13 @@ class WorkflowGraphValidator:
|
||||
self._validate_dead_ends()
|
||||
self._validate_cycles()
|
||||
|
||||
def _build_executor_map(self, edges: list[Edge]) -> dict[str, Executor]:
|
||||
def _build_executor_map(self, edge_groups: Sequence[EdgeGroup]) -> dict[str, Executor]:
|
||||
"""Build a map of executor IDs to executor instances."""
|
||||
executors: dict[str, Executor] = {}
|
||||
for edge in edges:
|
||||
executors[edge.source_id] = edge.source
|
||||
executors[edge.target_id] = edge.target
|
||||
for group in edge_groups:
|
||||
for executor in group.source_executors + group.target_executors:
|
||||
executors[executor.id] = executor
|
||||
|
||||
return executors
|
||||
|
||||
# endregion
|
||||
@@ -155,64 +158,80 @@ class WorkflowGraphValidator:
|
||||
Raises:
|
||||
TypeCompatibilityError: If type incompatibility is detected
|
||||
"""
|
||||
for edge in self._edges:
|
||||
source_executor = edge.source
|
||||
target_executor = edge.target
|
||||
for edge_group in self._edge_groups:
|
||||
for edge in edge_group.edges:
|
||||
self._validate_edge_type_compatibility(edge, edge_group)
|
||||
|
||||
# Get output types from source executor
|
||||
source_output_types = self._get_executor_output_types(source_executor)
|
||||
def _validate_edge_type_compatibility(self, edge: Edge, edge_group: EdgeGroup) -> None:
|
||||
"""Validate type compatibility for a specific edge.
|
||||
|
||||
# Get input types from target executor
|
||||
target_input_types = self._get_executor_input_types(target_executor)
|
||||
This checks that the output types of the source executor are compatible
|
||||
with the input types expected by the target executor.
|
||||
|
||||
# If either executor has no type information, log warning and skip validation
|
||||
# This allows for dynamic typing scenarios but warns about reduced validation coverage
|
||||
if not source_output_types or not target_input_types:
|
||||
if not source_output_types:
|
||||
logger.warning(
|
||||
f"Executor '{source_executor.id}' has no output type annotations. "
|
||||
f"Type compatibility validation will be skipped for edges from this executor. "
|
||||
f"Consider adding output_types to @handler decorators for better validation."
|
||||
)
|
||||
if not target_input_types:
|
||||
logger.warning(
|
||||
f"Executor '{target_executor.id}' has no input type annotations. "
|
||||
f"Type compatibility validation will be skipped for edges to this executor. "
|
||||
f"Consider adding type annotations to message handler parameters for better validation."
|
||||
)
|
||||
continue
|
||||
Args:
|
||||
edge: The edge to validate
|
||||
edge_group: The edge group containing this edge
|
||||
|
||||
# Check if any source output type is compatible with any target input type
|
||||
compatible = False
|
||||
compatible_pairs: list[tuple[type[Any], type[Any]]] = []
|
||||
Raises:
|
||||
TypeCompatibilityError: If type incompatibility is detected
|
||||
"""
|
||||
source_executor = edge.source
|
||||
target_executor = edge.target
|
||||
|
||||
for source_type in source_output_types:
|
||||
for target_type in target_input_types:
|
||||
if edge.has_edge_group():
|
||||
# If the edge is part of an edge group, the target expects a list of data types
|
||||
if self._is_type_compatible(list[source_type], target_type):
|
||||
compatible = True
|
||||
compatible_pairs.append((list[source_type], target_type))
|
||||
else:
|
||||
if self._is_type_compatible(source_type, target_type):
|
||||
compatible = True
|
||||
compatible_pairs.append((source_type, target_type))
|
||||
# Get output types from source executor
|
||||
source_output_types = self._get_executor_output_types(source_executor)
|
||||
|
||||
# Log successful type compatibility for debugging
|
||||
if compatible:
|
||||
logger.debug(
|
||||
f"Type compatibility validated for edge '{source_executor.id}' -> '{target_executor.id}'. "
|
||||
f"Compatible type pairs: {[(str(s), str(t)) for s, t in compatible_pairs]}"
|
||||
# Get input types from target executor
|
||||
target_input_types = self._get_executor_input_types(target_executor)
|
||||
|
||||
# If either executor has no type information, log warning and skip validation
|
||||
# This allows for dynamic typing scenarios but warns about reduced validation coverage
|
||||
if not source_output_types or not target_input_types:
|
||||
if not source_output_types:
|
||||
logger.warning(
|
||||
f"Executor '{source_executor.id}' has no output type annotations. "
|
||||
f"Type compatibility validation will be skipped for edges from this executor. "
|
||||
f"Consider adding output_types to @handler decorators for better validation."
|
||||
)
|
||||
|
||||
if not compatible:
|
||||
# Enhanced error with more detailed information
|
||||
raise TypeCompatibilityError(
|
||||
source_executor.id,
|
||||
target_executor.id,
|
||||
source_output_types,
|
||||
target_input_types,
|
||||
if not target_input_types:
|
||||
logger.warning(
|
||||
f"Executor '{target_executor.id}' has no input type annotations. "
|
||||
f"Type compatibility validation will be skipped for edges to this executor. "
|
||||
f"Consider adding type annotations to message handler parameters for better validation."
|
||||
)
|
||||
return
|
||||
|
||||
# Check if any source output type is compatible with any target input type
|
||||
compatible = False
|
||||
compatible_pairs: list[tuple[type[Any], type[Any]]] = []
|
||||
|
||||
for source_type in source_output_types:
|
||||
for target_type in target_input_types:
|
||||
if isinstance(edge_group, FanInEdgeGroup):
|
||||
# If the edge is part of an edge group, the target expects a list of data types
|
||||
if self._is_type_compatible(list[source_type], target_type):
|
||||
compatible = True
|
||||
compatible_pairs.append((list[source_type], target_type))
|
||||
else:
|
||||
if self._is_type_compatible(source_type, target_type):
|
||||
compatible = True
|
||||
compatible_pairs.append((source_type, target_type))
|
||||
|
||||
# Log successful type compatibility for debugging
|
||||
if compatible:
|
||||
logger.debug(
|
||||
f"Type compatibility validated for edge '{source_executor.id}' -> '{target_executor.id}'. "
|
||||
f"Compatible type pairs: {[(str(s), str(t)) for s, t in compatible_pairs]}"
|
||||
)
|
||||
|
||||
if not compatible:
|
||||
# Enhanced error with more detailed information
|
||||
raise TypeCompatibilityError(
|
||||
source_executor.id,
|
||||
target_executor.id,
|
||||
source_output_types,
|
||||
target_input_types,
|
||||
)
|
||||
|
||||
def _get_executor_output_types(self, executor: Executor) -> list[type[Any]]:
|
||||
"""Extract output types from an executor's message handlers.
|
||||
@@ -479,15 +498,15 @@ class WorkflowGraphValidator:
|
||||
# endregion
|
||||
|
||||
|
||||
def validate_workflow_graph(edges: list[Edge], start_executor: Executor | str) -> None:
|
||||
def validate_workflow_graph(edge_groups: Sequence[EdgeGroup], start_executor: Executor | str) -> None:
|
||||
"""Convenience function to validate a workflow graph.
|
||||
|
||||
Args:
|
||||
edges: list of edges in the workflow
|
||||
edge_groups: list of edge groups in the workflow
|
||||
start_executor: The starting executor (can be instance or ID)
|
||||
|
||||
Raises:
|
||||
WorkflowValidationError: If any validation fails
|
||||
"""
|
||||
validator = WorkflowGraphValidator()
|
||||
validator.validate_workflow(edges, start_executor)
|
||||
validator.validate_workflow(edge_groups, start_executor)
|
||||
|
||||
@@ -9,7 +9,15 @@ from typing import Any
|
||||
|
||||
from ._checkpoint import CheckpointStorage
|
||||
from ._const import DEFAULT_MAX_ITERATIONS
|
||||
from ._edge import Edge
|
||||
from ._edge import (
|
||||
Case,
|
||||
Default,
|
||||
EdgeGroup,
|
||||
FanInEdgeGroup,
|
||||
FanOutEdgeGroup,
|
||||
SingleEdgeGroup,
|
||||
SwitchCaseEdgeGroup,
|
||||
)
|
||||
from ._events import RequestInfoEvent, WorkflowCompletedEvent, WorkflowEvent
|
||||
from ._executor import Executor, RequestInfoExecutor
|
||||
from ._runner import Runner
|
||||
@@ -67,7 +75,7 @@ class Workflow:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
edges: list[Edge],
|
||||
edge_groups: list[EdgeGroup],
|
||||
start_executor: Executor | str,
|
||||
runner_context: RunnerContext,
|
||||
max_iterations: int,
|
||||
@@ -75,28 +83,29 @@ class Workflow:
|
||||
"""Initialize the workflow with a list of edges.
|
||||
|
||||
Args:
|
||||
edges: A list of directed edges representing the connections between nodes in the workflow.
|
||||
edge_groups: A list of EdgeGroup instances that define the workflow edges.
|
||||
start_executor: The starting executor for the workflow, which can be an Executor instance or its ID.
|
||||
runner_context: The RunnerContext instance to be used during workflow execution.
|
||||
max_iterations: The maximum number of iterations the workflow will run for convergence.
|
||||
"""
|
||||
self._edges = edges
|
||||
self._edge_groups = edge_groups
|
||||
self._executors = self._build_executor_map(edge_groups)
|
||||
self._start_executor = start_executor
|
||||
self._executors = {edge.source_id: edge.source for edge in edges} | {
|
||||
edge.target_id: edge.target for edge in edges
|
||||
}
|
||||
|
||||
self._shared_state = SharedState()
|
||||
|
||||
workflow_id = str(uuid.uuid4())
|
||||
self._runner = Runner(
|
||||
self._edges, self._shared_state, runner_context, max_iterations=max_iterations, workflow_id=workflow_id
|
||||
self._edge_groups,
|
||||
self._shared_state,
|
||||
runner_context,
|
||||
max_iterations=max_iterations,
|
||||
workflow_id=workflow_id,
|
||||
)
|
||||
|
||||
@property
|
||||
def edges(self) -> list[Edge]:
|
||||
"""Get the list of edges in the workflow."""
|
||||
return self._edges
|
||||
def edge_groups(self) -> list[EdgeGroup]:
|
||||
"""Get the list of edge groups in the workflow."""
|
||||
return self._edge_groups
|
||||
|
||||
@property
|
||||
def start_executor(self) -> Executor:
|
||||
@@ -298,6 +307,22 @@ class Workflow:
|
||||
raise ValueError(f"Executor with ID {executor_id} not found.")
|
||||
return self._executors[executor_id]
|
||||
|
||||
def _build_executor_map(self, edge_groups: list[EdgeGroup]) -> dict[str, Executor]:
|
||||
"""Build the executor map from edge groups.
|
||||
|
||||
Args:
|
||||
edge_groups: A list of EdgeGroup instances.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping executor IDs to Executor instances.
|
||||
"""
|
||||
executors: dict[str, Executor] = {}
|
||||
for group in edge_groups:
|
||||
for executor in group.source_executors + group.target_executors:
|
||||
executors[executor.id] = executor
|
||||
|
||||
return executors
|
||||
|
||||
async def _restore_from_external_checkpoint(
|
||||
self, checkpoint_id: str, checkpoint_storage: CheckpointStorage
|
||||
) -> bool:
|
||||
@@ -405,7 +430,7 @@ class WorkflowBuilder:
|
||||
|
||||
def __init__(self, max_iterations: int = DEFAULT_MAX_ITERATIONS):
|
||||
"""Initialize the WorkflowBuilder with an empty list of edges and no starting executor."""
|
||||
self._edges: list[Edge] = []
|
||||
self._edge_groups: list[EdgeGroup] = []
|
||||
self._start_executor: Executor | str | None = None
|
||||
self._checkpoint_storage: CheckpointStorage | None = None
|
||||
self._max_iterations: int = max_iterations
|
||||
@@ -427,21 +452,66 @@ class WorkflowBuilder:
|
||||
should be traversed based on the message type.
|
||||
"""
|
||||
# TODO(@taochen): Support executor factories for lazy initialization
|
||||
self._edges.append(Edge(source, target, condition))
|
||||
self._edge_groups.append(SingleEdgeGroup(source, target, condition))
|
||||
return self
|
||||
|
||||
def add_fan_out_edges(self, source: Executor, targets: Sequence[Executor]) -> "Self":
|
||||
"""Add multiple edges to the workflow.
|
||||
"""Add multiple edges to the workflow where messages from the source will be sent to all target.
|
||||
|
||||
The output types of the source and the input types of the targets must be compatible.
|
||||
Messages from the source executor will be sent to all target executors.
|
||||
|
||||
Args:
|
||||
source: The source executor of the edges.
|
||||
targets: A list of target executors for the edges.
|
||||
"""
|
||||
for target in targets:
|
||||
self._edges.append(Edge(source, target))
|
||||
self._edge_groups.append(FanOutEdgeGroup(source, targets))
|
||||
|
||||
return self
|
||||
|
||||
def add_switch_case_edge_group(self, source: Executor, cases: Sequence[Case | Default]) -> "Self":
|
||||
"""Add an edge group that represents a switch-case statement.
|
||||
|
||||
The output types of the source and the input types of the targets must be compatible.
|
||||
Messages from the source executor will be sent to one of the target executors based on
|
||||
the provided conditions.
|
||||
|
||||
Think of this as a switch statement where each target executor corresponds to a case.
|
||||
Each condition function will be evaluated in order, and the first one that returns True
|
||||
will determine which target executor receives the message.
|
||||
|
||||
The last case (the default case) will receive messages that fall through all conditions
|
||||
(i.e., no condition matched).
|
||||
|
||||
Args:
|
||||
source: The source executor of the edges.
|
||||
cases: A list of case objects that determine the target executor for each message.
|
||||
"""
|
||||
self._edge_groups.append(SwitchCaseEdgeGroup(source, cases))
|
||||
|
||||
return self
|
||||
|
||||
def add_multi_selection_edge_group(
|
||||
self,
|
||||
source: Executor,
|
||||
targets: Sequence[Executor],
|
||||
selection_func: Callable[[Any, list[str]], list[str]],
|
||||
) -> "Self":
|
||||
"""Add an edge group that represents a multi-selection execution model.
|
||||
|
||||
The output types of the source and the input types of the targets must be compatible.
|
||||
Messages from the source executor will be sent to multiple target executors based on
|
||||
the provided selection function.
|
||||
|
||||
The selection function should take a message and the name of the target executors,
|
||||
and return a list of indices indicating which target executors should receive the message.
|
||||
|
||||
Args:
|
||||
source: The source executor of the edges.
|
||||
targets: A list of target executors for the edges.
|
||||
selection_func: A function that selects target executors for messages.
|
||||
"""
|
||||
self._edge_groups.append(FanOutEdgeGroup(source, targets, selection_func))
|
||||
|
||||
return self
|
||||
|
||||
def add_fan_in_edges(self, sources: Sequence[Executor], target: Executor) -> "Self":
|
||||
@@ -478,16 +548,7 @@ class WorkflowBuilder:
|
||||
sources: A list of source executors for the edges.
|
||||
target: The target executor for the edges.
|
||||
"""
|
||||
edges = [Edge(source, target) for source in sources]
|
||||
|
||||
# Set the edge groups for the edges to ensure they are processed together.
|
||||
for i, edge in enumerate(edges):
|
||||
group_ids: list[str] = []
|
||||
group_ids.extend([e.id for e in edges[0:i]])
|
||||
group_ids.extend([e.id for e in edges[i + 1 :]])
|
||||
edge.set_edge_group(group_ids)
|
||||
|
||||
self._edges.extend(edges)
|
||||
self._edge_groups.append(FanInEdgeGroup(sources, target))
|
||||
|
||||
return self
|
||||
|
||||
@@ -549,11 +610,11 @@ class WorkflowBuilder:
|
||||
if not self._start_executor:
|
||||
raise ValueError("Starting executor must be set before building the workflow.")
|
||||
|
||||
validate_workflow_graph(self._edges, self._start_executor)
|
||||
validate_workflow_graph(self._edge_groups, self._start_executor)
|
||||
|
||||
context = InProcRunnerContext(self._checkpoint_storage)
|
||||
|
||||
return Workflow(self._edges, self._start_executor, context, self._max_iterations)
|
||||
return Workflow(self._edge_groups, self._start_executor, context, self._max_iterations)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -2,10 +2,20 @@
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from agent_framework.workflow import Executor, WorkflowContext, handler
|
||||
|
||||
from agent_framework_workflow._edge import Edge
|
||||
from agent_framework_workflow._edge import (
|
||||
Case,
|
||||
Default,
|
||||
Edge,
|
||||
FanInEdgeGroup,
|
||||
FanOutEdgeGroup,
|
||||
SingleEdgeGroup,
|
||||
SwitchCaseEdgeGroup,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -15,6 +25,13 @@ class MockMessage:
|
||||
data: Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockMessageSecondary:
|
||||
"""A secondary mock message for testing purposes."""
|
||||
|
||||
data: Any
|
||||
|
||||
|
||||
class MockExecutor(Executor):
|
||||
"""A mock executor for testing purposes."""
|
||||
|
||||
@@ -24,6 +41,27 @@ class MockExecutor(Executor):
|
||||
pass
|
||||
|
||||
|
||||
class MockExecutorSecondary(Executor):
|
||||
"""A secondary mock executor for testing purposes."""
|
||||
|
||||
@handler
|
||||
async def mock_handler_secondary(self, message: MockMessageSecondary, ctx: WorkflowContext) -> None:
|
||||
"""A secondary mock handler that does nothing."""
|
||||
pass
|
||||
|
||||
|
||||
class MockAggregator(Executor):
|
||||
"""A mock aggregator for testing purposes."""
|
||||
|
||||
@handler
|
||||
async def mock_aggregator_handler(self, message: list[MockMessage], ctx: WorkflowContext) -> None:
|
||||
"""A mock aggregator handler that does nothing."""
|
||||
pass
|
||||
|
||||
|
||||
# region Edge
|
||||
|
||||
|
||||
def test_create_edge():
|
||||
"""Test creating an edge with a source and target executor."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
@@ -34,7 +72,6 @@ def test_create_edge():
|
||||
assert edge.source_id == "source_executor"
|
||||
assert edge.target_id == "target_executor"
|
||||
assert edge.id == f"{edge.source_id}{Edge.ID_SEPARATOR}{edge.target_id}"
|
||||
assert (edge.source_id, edge.target_id) == Edge.source_and_target_from_id(edge.id)
|
||||
|
||||
|
||||
def test_edge_can_handle():
|
||||
@@ -45,3 +82,750 @@ def test_edge_can_handle():
|
||||
edge = Edge(source=source, target=target)
|
||||
|
||||
assert edge.can_handle(MockMessage(data="test"))
|
||||
|
||||
|
||||
# endregion Edge
|
||||
|
||||
# region SingleEdgeGroup
|
||||
|
||||
|
||||
def test_single_edge_group():
|
||||
"""Test creating a single edge group."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target = MockExecutor(id="target_executor")
|
||||
|
||||
edge_group = SingleEdgeGroup(source=source, target=target)
|
||||
|
||||
assert edge_group.source_executors == [source]
|
||||
assert edge_group.target_executors == [target]
|
||||
assert edge_group.edges[0].source_id == "source_executor"
|
||||
assert edge_group.edges[0].target_id == "target_executor"
|
||||
|
||||
|
||||
def test_single_edge_group_with_condition():
|
||||
"""Test creating a single edge group with a condition."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target = MockExecutor(id="target_executor")
|
||||
|
||||
edge_group = SingleEdgeGroup(source=source, target=target, condition=lambda x: x.data == "test")
|
||||
|
||||
assert edge_group.source_executors == [source]
|
||||
assert edge_group.target_executors == [target]
|
||||
assert edge_group.edges[0].source_id == "source_executor"
|
||||
assert edge_group.edges[0].target_id == "target_executor"
|
||||
assert edge_group.edges[0]._condition is not None # type: ignore
|
||||
|
||||
|
||||
async def test_single_edge_group_send_message():
|
||||
"""Test sending a message through a single edge group."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target = MockExecutor(id="target_executor")
|
||||
|
||||
edge_group = SingleEdgeGroup(source=source, target=target)
|
||||
|
||||
from agent_framework_workflow._runner_context import InProcRunnerContext, Message
|
||||
from agent_framework_workflow._shared_state import SharedState
|
||||
|
||||
shared_state = SharedState()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
data = MockMessage(data="test")
|
||||
message = Message(data=data, source_id=source.id)
|
||||
|
||||
success = await edge_group.send_message(message, shared_state, ctx)
|
||||
assert success is True
|
||||
|
||||
|
||||
async def test_single_edge_group_send_message_with_target():
|
||||
"""Test sending a message through a single edge group."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target = MockExecutor(id="target_executor")
|
||||
|
||||
edge_group = SingleEdgeGroup(source=source, target=target)
|
||||
|
||||
from agent_framework_workflow._runner_context import InProcRunnerContext, Message
|
||||
from agent_framework_workflow._shared_state import SharedState
|
||||
|
||||
shared_state = SharedState()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
data = MockMessage(data="test")
|
||||
message = Message(data=data, source_id=source.id, target_id=target.id)
|
||||
|
||||
success = await edge_group.send_message(message, shared_state, ctx)
|
||||
assert success is True
|
||||
|
||||
|
||||
async def test_single_edge_group_send_message_with_invalid_target():
|
||||
"""Test sending a message through a single edge group."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target = MockExecutor(id="target_executor")
|
||||
|
||||
edge_group = SingleEdgeGroup(source=source, target=target)
|
||||
|
||||
from agent_framework_workflow._runner_context import InProcRunnerContext, Message
|
||||
from agent_framework_workflow._shared_state import SharedState
|
||||
|
||||
shared_state = SharedState()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
data = MockMessage(data="test")
|
||||
message = Message(data=data, source_id=source.id, target_id="invalid_target")
|
||||
|
||||
success = await edge_group.send_message(message, shared_state, ctx)
|
||||
assert success is False
|
||||
|
||||
|
||||
async def test_single_edge_group_send_message_with_invalid_data():
|
||||
"""Test sending a message through a single edge group."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target = MockExecutor(id="target_executor")
|
||||
|
||||
edge_group = SingleEdgeGroup(source=source, target=target)
|
||||
|
||||
from agent_framework_workflow._runner_context import InProcRunnerContext, Message
|
||||
from agent_framework_workflow._shared_state import SharedState
|
||||
|
||||
shared_state = SharedState()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
data = "invalid_data"
|
||||
message = Message(data=data, source_id=source.id)
|
||||
|
||||
success = await edge_group.send_message(message, shared_state, ctx)
|
||||
assert success is False
|
||||
|
||||
|
||||
# endregion SingleEdgeGroup
|
||||
|
||||
|
||||
# region FanOutEdgeGroup
|
||||
|
||||
|
||||
def test_source_edge_group():
|
||||
"""Test creating a fan-out group."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target1 = MockExecutor(id="target_executor_1")
|
||||
target2 = MockExecutor(id="target_executor_2")
|
||||
|
||||
edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2])
|
||||
|
||||
assert edge_group.source_executors == [source]
|
||||
assert edge_group.target_executors == [target1, target2]
|
||||
assert len(edge_group.edges) == 2
|
||||
assert edge_group.edges[0].source_id == "source_executor"
|
||||
assert edge_group.edges[0].target_id == "target_executor_1"
|
||||
assert edge_group.edges[1].source_id == "source_executor"
|
||||
assert edge_group.edges[1].target_id == "target_executor_2"
|
||||
|
||||
|
||||
def test_source_edge_group_invalid_number_of_targets():
|
||||
"""Test creating a fan-out group with an invalid number of targets."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target = MockExecutor(id="target_executor")
|
||||
|
||||
with pytest.raises(ValueError, match="FanOutEdgeGroup must contain at least two targets"):
|
||||
FanOutEdgeGroup(source=source, targets=[target])
|
||||
|
||||
|
||||
async def test_source_edge_group_send_message():
|
||||
"""Test sending a message through a fan-out group."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target1 = MockExecutor(id="target_executor_1")
|
||||
target2 = MockExecutor(id="target_executor_2")
|
||||
|
||||
edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2])
|
||||
|
||||
from agent_framework_workflow._runner_context import InProcRunnerContext, Message
|
||||
from agent_framework_workflow._shared_state import SharedState
|
||||
|
||||
shared_state = SharedState()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
data = MockMessage(data="test")
|
||||
message = Message(data=data, source_id=source.id)
|
||||
|
||||
with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send:
|
||||
success = await edge_group.send_message(message, shared_state, ctx)
|
||||
|
||||
assert success is True
|
||||
assert mock_send.call_count == 2
|
||||
|
||||
|
||||
async def test_source_edge_group_send_message_with_target():
|
||||
"""Test sending a message through a fan-out group with a target."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target1 = MockExecutor(id="target_executor_1")
|
||||
target2 = MockExecutor(id="target_executor_2")
|
||||
|
||||
edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2])
|
||||
|
||||
from agent_framework_workflow._runner_context import InProcRunnerContext, Message
|
||||
from agent_framework_workflow._shared_state import SharedState
|
||||
|
||||
shared_state = SharedState()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
data = MockMessage(data="test")
|
||||
message = Message(data=data, source_id=source.id, target_id=target1.id)
|
||||
|
||||
with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send:
|
||||
success = await edge_group.send_message(message, shared_state, ctx)
|
||||
|
||||
assert success is True
|
||||
assert mock_send.call_count == 1
|
||||
assert mock_send.call_args[0][0].target_id == target1.id
|
||||
|
||||
|
||||
async def test_source_edge_group_send_message_with_invalid_target():
|
||||
"""Test sending a message through a fan-out group with an invalid target."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target1 = MockExecutor(id="target_executor_1")
|
||||
target2 = MockExecutor(id="target_executor_2")
|
||||
|
||||
edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2])
|
||||
|
||||
from agent_framework_workflow._runner_context import InProcRunnerContext, Message
|
||||
from agent_framework_workflow._shared_state import SharedState
|
||||
|
||||
shared_state = SharedState()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
data = MockMessage(data="test")
|
||||
message = Message(data=data, source_id=source.id, target_id="invalid_target")
|
||||
|
||||
success = await edge_group.send_message(message, shared_state, ctx)
|
||||
assert success is False
|
||||
|
||||
|
||||
async def test_source_edge_group_send_message_with_invalid_data():
|
||||
"""Test sending a message through a fan-out group with invalid data."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target1 = MockExecutor(id="target_executor_1")
|
||||
target2 = MockExecutor(id="target_executor_2")
|
||||
|
||||
edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2])
|
||||
|
||||
from agent_framework_workflow._runner_context import InProcRunnerContext, Message
|
||||
from agent_framework_workflow._shared_state import SharedState
|
||||
|
||||
shared_state = SharedState()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
data = "invalid_data"
|
||||
message = Message(data=data, source_id=source.id)
|
||||
|
||||
success = await edge_group.send_message(message, shared_state, ctx)
|
||||
assert success is False
|
||||
|
||||
|
||||
async def test_source_edge_group_send_message_only_one_successful_send():
|
||||
"""Test sending a message through a fan-out group where only one edge can handle the message."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target1 = MockExecutor(id="target_executor_1")
|
||||
target2 = MockExecutorSecondary(id="target_executor_2")
|
||||
|
||||
edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2])
|
||||
|
||||
from agent_framework_workflow._runner_context import InProcRunnerContext, Message
|
||||
from agent_framework_workflow._shared_state import SharedState
|
||||
|
||||
shared_state = SharedState()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
data = MockMessage(data="test")
|
||||
message = Message(data=data, source_id=source.id)
|
||||
|
||||
with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send:
|
||||
success = await edge_group.send_message(message, shared_state, ctx)
|
||||
|
||||
assert success is True
|
||||
assert mock_send.call_count == 1
|
||||
|
||||
|
||||
def test_source_edge_group_with_selection_func():
|
||||
"""Test creating a partitioning edge group."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target1 = MockExecutor(id="target_executor_1")
|
||||
target2 = MockExecutor(id="target_executor_2")
|
||||
|
||||
edge_group = FanOutEdgeGroup(
|
||||
source=source,
|
||||
targets=[target1, target2],
|
||||
selection_func=lambda data, target_ids: [target1.id],
|
||||
)
|
||||
|
||||
assert edge_group.source_executors == [source]
|
||||
assert edge_group.target_executors == [target1, target2]
|
||||
assert len(edge_group.edges) == 2
|
||||
assert edge_group.edges[0].source_id == "source_executor"
|
||||
assert edge_group.edges[0].target_id == "target_executor_1"
|
||||
assert edge_group.edges[1].source_id == "source_executor"
|
||||
assert edge_group.edges[1].target_id == "target_executor_2"
|
||||
|
||||
|
||||
async def test_source_edge_group_with_selection_func_send_message():
|
||||
"""Test sending a message through a fan-out group with a selection function."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target1 = MockExecutor(id="target_executor_1")
|
||||
target2 = MockExecutor(id="target_executor_2")
|
||||
|
||||
edge_group = FanOutEdgeGroup(
|
||||
source=source,
|
||||
targets=[target1, target2],
|
||||
selection_func=lambda data, target_ids: [target1.id, target2.id],
|
||||
)
|
||||
|
||||
from agent_framework_workflow._runner_context import InProcRunnerContext, Message
|
||||
from agent_framework_workflow._shared_state import SharedState
|
||||
|
||||
shared_state = SharedState()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
data = MockMessage(data="test")
|
||||
message = Message(data=data, source_id=source.id)
|
||||
|
||||
with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send:
|
||||
success = await edge_group.send_message(message, shared_state, ctx)
|
||||
|
||||
assert success is True
|
||||
assert mock_send.call_count == 2
|
||||
|
||||
|
||||
async def test_source_edge_group_with_selection_func_send_message_with_invalid_selection_result():
|
||||
"""Test sending a message through a fan-out group with a selection func with an invalid selection result."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target1 = MockExecutor(id="target_executor_1")
|
||||
target2 = MockExecutor(id="target_executor_2")
|
||||
|
||||
edge_group = FanOutEdgeGroup(
|
||||
source=source,
|
||||
targets=[target1, target2],
|
||||
selection_func=lambda data, target_ids: [target1.id, "invalid_target"],
|
||||
)
|
||||
|
||||
from agent_framework_workflow._runner_context import InProcRunnerContext, Message
|
||||
from agent_framework_workflow._shared_state import SharedState
|
||||
|
||||
shared_state = SharedState()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
data = MockMessage(data="test")
|
||||
message = Message(data=data, source_id=source.id)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await edge_group.send_message(message, shared_state, ctx)
|
||||
|
||||
|
||||
async def test_source_edge_group_with_selection_func_send_message_with_target():
|
||||
"""Test sending a message through a fan-out group with a selection func with a target."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target1 = MockExecutor(id="target_executor_1")
|
||||
target2 = MockExecutor(id="target_executor_2")
|
||||
|
||||
edge_group = FanOutEdgeGroup(
|
||||
source=source,
|
||||
targets=[target1, target2],
|
||||
selection_func=lambda data, target_ids: [target1.id, target2.id],
|
||||
)
|
||||
|
||||
from agent_framework_workflow._runner_context import InProcRunnerContext, Message
|
||||
from agent_framework_workflow._shared_state import SharedState
|
||||
|
||||
shared_state = SharedState()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
data = MockMessage(data="test")
|
||||
message = Message(data=data, source_id=source.id, target_id=target1.id)
|
||||
|
||||
with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send:
|
||||
success = await edge_group.send_message(message, shared_state, ctx)
|
||||
|
||||
assert success is True
|
||||
assert mock_send.call_count == 1
|
||||
assert mock_send.call_args[0][0].target_id == target1.id
|
||||
|
||||
|
||||
async def test_source_edge_group_with_selection_func_send_message_with_target_not_in_selection():
|
||||
"""Test sending a message through a fan-out group with a selection func with a target not in the selection."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target1 = MockExecutor(id="target_executor_1")
|
||||
target2 = MockExecutor(id="target_executor_2")
|
||||
|
||||
edge_group = FanOutEdgeGroup(
|
||||
source=source,
|
||||
targets=[target1, target2],
|
||||
selection_func=lambda data, target_ids: [target1.id], # Only target1 will receive the message
|
||||
)
|
||||
|
||||
from agent_framework_workflow._runner_context import InProcRunnerContext, Message
|
||||
from agent_framework_workflow._shared_state import SharedState
|
||||
|
||||
shared_state = SharedState()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
data = MockMessage(data="test")
|
||||
message = Message(data=data, source_id=source.id, target_id=target2.id)
|
||||
|
||||
success = await edge_group.send_message(message, shared_state, ctx)
|
||||
assert success is False
|
||||
|
||||
|
||||
async def test_source_edge_group_with_selection_func_send_message_with_invalid_data():
|
||||
"""Test sending a message through a fan-out group with a selection func with invalid data."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target1 = MockExecutor(id="target_executor_1")
|
||||
target2 = MockExecutor(id="target_executor_2")
|
||||
|
||||
edge_group = FanOutEdgeGroup(
|
||||
source=source, targets=[target1, target2], selection_func=lambda data, target_ids: [target1.id, target2.id]
|
||||
)
|
||||
|
||||
from agent_framework_workflow._runner_context import InProcRunnerContext, Message
|
||||
from agent_framework_workflow._shared_state import SharedState
|
||||
|
||||
shared_state = SharedState()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
data = "invalid_data"
|
||||
message = Message(data=data, source_id=source.id)
|
||||
|
||||
success = await edge_group.send_message(message, shared_state, ctx)
|
||||
assert success is False
|
||||
|
||||
|
||||
async def test_source_edge_group_with_selection_func_send_message_with_target_invalid_data():
|
||||
"""Test sending a message through a fan-out group with a selection func with a target and invalid data."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target1 = MockExecutor(id="target_executor_1")
|
||||
target2 = MockExecutor(id="target_executor_2")
|
||||
|
||||
edge_group = FanOutEdgeGroup(
|
||||
source=source, targets=[target1, target2], selection_func=lambda data, target_ids: [target1.id, target2.id]
|
||||
)
|
||||
|
||||
from agent_framework_workflow._runner_context import InProcRunnerContext, Message
|
||||
from agent_framework_workflow._shared_state import SharedState
|
||||
|
||||
shared_state = SharedState()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
data = "invalid_data"
|
||||
message = Message(data=data, source_id=source.id, target_id=target1.id)
|
||||
|
||||
success = await edge_group.send_message(message, shared_state, ctx)
|
||||
assert success is False
|
||||
|
||||
|
||||
# endregion FanOutEdgeGroup
|
||||
|
||||
# region FanInEdgeGroup
|
||||
|
||||
|
||||
def test_target_edge_group():
|
||||
"""Test creating a fan-in edge group."""
|
||||
source1 = MockExecutor(id="source_executor_1")
|
||||
source2 = MockExecutor(id="source_executor_2")
|
||||
target = MockAggregator(id="target_executor")
|
||||
|
||||
edge_group = FanInEdgeGroup(sources=[source1, source2], target=target)
|
||||
|
||||
assert edge_group.source_executors == [source1, source2]
|
||||
assert edge_group.target_executors == [target]
|
||||
assert len(edge_group.edges) == 2
|
||||
assert edge_group.edges[0].source_id == "source_executor_1"
|
||||
assert edge_group.edges[0].target_id == "target_executor"
|
||||
assert edge_group.edges[1].source_id == "source_executor_2"
|
||||
assert edge_group.edges[1].target_id == "target_executor"
|
||||
|
||||
|
||||
def test_target_edge_group_invalid_number_of_sources():
|
||||
"""Test creating a fan-in edge group with an invalid number of sources."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target = MockAggregator(id="target_executor")
|
||||
|
||||
with pytest.raises(ValueError, match="FanInEdgeGroup must contain at least two sources"):
|
||||
FanInEdgeGroup(sources=[source], target=target)
|
||||
|
||||
|
||||
async def test_target_edge_group_send_message_buffer():
|
||||
"""Test sending a message through a fan-in edge group with buffering."""
|
||||
source1 = MockExecutor(id="source_executor_1")
|
||||
source2 = MockExecutor(id="source_executor_2")
|
||||
target = MockAggregator(id="target_executor")
|
||||
|
||||
edge_group = FanInEdgeGroup(sources=[source1, source2], target=target)
|
||||
|
||||
from agent_framework_workflow._runner_context import InProcRunnerContext, Message
|
||||
from agent_framework_workflow._shared_state import SharedState
|
||||
|
||||
shared_state = SharedState()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
data = MockMessage(data="test")
|
||||
|
||||
with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send:
|
||||
success = await edge_group.send_message(
|
||||
Message(data=data, source_id=source1.id),
|
||||
shared_state,
|
||||
ctx,
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert mock_send.call_count == 0 # The message should be buffered and wait for the second source
|
||||
assert len(edge_group._buffer[source1.id]) == 1 # type: ignore
|
||||
|
||||
success = await edge_group.send_message(
|
||||
Message(data=data, source_id=source2.id),
|
||||
shared_state,
|
||||
ctx,
|
||||
)
|
||||
assert success is True
|
||||
assert mock_send.call_count == 1 # The message should be sent now that both sources have sent their messages
|
||||
|
||||
# Buffer should be cleared after sending
|
||||
assert not edge_group._buffer # type: ignore
|
||||
|
||||
|
||||
async def test_target_edge_group_send_message_with_invalid_target():
|
||||
"""Test sending a message through a fan-in edge group with an invalid target."""
|
||||
source1 = MockExecutor(id="source_executor_1")
|
||||
source2 = MockExecutor(id="source_executor_2")
|
||||
target = MockAggregator(id="target_executor")
|
||||
|
||||
edge_group = FanInEdgeGroup(sources=[source1, source2], target=target)
|
||||
|
||||
from agent_framework_workflow._runner_context import InProcRunnerContext, Message
|
||||
from agent_framework_workflow._shared_state import SharedState
|
||||
|
||||
shared_state = SharedState()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
data = MockMessage(data="test")
|
||||
message = Message(data=data, source_id=source1.id, target_id="invalid_target")
|
||||
|
||||
success = await edge_group.send_message(message, shared_state, ctx)
|
||||
assert success is False
|
||||
|
||||
|
||||
async def test_target_edge_group_send_message_with_invalid_data():
|
||||
"""Test sending a message through a fan-in edge group with invalid data."""
|
||||
source1 = MockExecutor(id="source_executor_1")
|
||||
source2 = MockExecutor(id="source_executor_2")
|
||||
target = MockAggregator(id="target_executor")
|
||||
|
||||
edge_group = FanInEdgeGroup(sources=[source1, source2], target=target)
|
||||
|
||||
from agent_framework_workflow._runner_context import InProcRunnerContext, Message
|
||||
from agent_framework_workflow._shared_state import SharedState
|
||||
|
||||
shared_state = SharedState()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
data = "invalid_data"
|
||||
message = Message(data=data, source_id=source1.id)
|
||||
|
||||
success = await edge_group.send_message(message, shared_state, ctx)
|
||||
assert success is False
|
||||
|
||||
|
||||
# endregion FanInEdgeGroup
|
||||
|
||||
# region SwitchCaseEdgeGroup
|
||||
|
||||
|
||||
def test_switch_case_edge_group():
|
||||
"""Test creating a switch case edge group."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target1 = MockExecutor(id="target_executor_1")
|
||||
target2 = MockExecutor(id="target_executor_2")
|
||||
|
||||
edge_group = SwitchCaseEdgeGroup(
|
||||
source=source,
|
||||
cases=[
|
||||
Case(condition=lambda x: x.data < 0, target=target1),
|
||||
Default(target=target2),
|
||||
],
|
||||
)
|
||||
|
||||
assert edge_group.source_executors == [source]
|
||||
assert edge_group.target_executors == [target1, target2]
|
||||
assert len(edge_group.edges) == 2
|
||||
assert edge_group.edges[0].source_id == "source_executor"
|
||||
assert edge_group.edges[0].target_id == "target_executor_1"
|
||||
assert edge_group.edges[1].source_id == "source_executor"
|
||||
assert edge_group.edges[1].target_id == "target_executor_2"
|
||||
|
||||
assert edge_group._selection_func is not None # type: ignore
|
||||
assert edge_group._selection_func(MockMessage(data=-1), [target1.id, target2.id]) == [target1.id] # type: ignore
|
||||
assert edge_group._selection_func(MockMessage(data=1), [target1.id, target2.id]) == [target2.id] # type: ignore
|
||||
|
||||
|
||||
def test_switch_case_edge_group_invalid_number_of_cases():
|
||||
"""Test creating a switch case edge group with an invalid number of cases."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target = MockExecutor(id="target_executor")
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match=r"SwitchCaseEdgeGroup must contain at least two cases \(including the default case\)."
|
||||
):
|
||||
SwitchCaseEdgeGroup(
|
||||
source=source,
|
||||
cases=[
|
||||
Case(condition=lambda x: x.data < 0, target=target),
|
||||
],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="SwitchCaseEdgeGroup must contain exactly one default case."):
|
||||
SwitchCaseEdgeGroup(
|
||||
source=source,
|
||||
cases=[
|
||||
Case(condition=lambda x: x.data < 0, target=target),
|
||||
Case(condition=lambda x: x.data >= 0, target=target),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_switch_case_edge_group_invalid_number_of_default_cases():
|
||||
"""Test creating a switch case edge group with an invalid number of conditions."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target1 = MockExecutor(id="target_executor_1")
|
||||
target2 = MockExecutor(id="target_executor_2")
|
||||
|
||||
with pytest.raises(ValueError, match="SwitchCaseEdgeGroup must contain exactly one default case."):
|
||||
SwitchCaseEdgeGroup(
|
||||
source=source,
|
||||
cases=[
|
||||
Case(condition=lambda x: x.data < 0, target=target1),
|
||||
Default(target=target2),
|
||||
Default(target=target2),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def test_switch_case_edge_group_send_message():
|
||||
"""Test sending a message through a switch case edge group."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target1 = MockExecutor(id="target_executor_1")
|
||||
target2 = MockExecutor(id="target_executor_2")
|
||||
|
||||
edge_group = SwitchCaseEdgeGroup(
|
||||
source=source,
|
||||
cases=[
|
||||
Case(condition=lambda x: x.data < 0, target=target1),
|
||||
Default(target=target2),
|
||||
],
|
||||
)
|
||||
|
||||
from agent_framework_workflow._runner_context import InProcRunnerContext, Message
|
||||
from agent_framework_workflow._shared_state import SharedState
|
||||
|
||||
shared_state = SharedState()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
data = MockMessage(data=-1)
|
||||
message = Message(data=data, source_id=source.id)
|
||||
|
||||
with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send:
|
||||
success = await edge_group.send_message(message, shared_state, ctx)
|
||||
|
||||
assert success is True
|
||||
assert mock_send.call_count == 1
|
||||
|
||||
# Default condition should
|
||||
data = MockMessage(data=1)
|
||||
message = Message(data=data, source_id=source.id)
|
||||
with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send:
|
||||
success = await edge_group.send_message(message, shared_state, ctx)
|
||||
|
||||
assert success is True
|
||||
assert mock_send.call_count == 1
|
||||
|
||||
|
||||
async def test_switch_case_edge_group_send_message_with_invalid_target():
|
||||
"""Test sending a message through a switch case edge group with an invalid target."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target1 = MockExecutor(id="target_executor_1")
|
||||
target2 = MockExecutor(id="target_executor_2")
|
||||
|
||||
edge_group = SwitchCaseEdgeGroup(
|
||||
source=source,
|
||||
cases=[
|
||||
Case(condition=lambda x: x.data < 0, target=target1),
|
||||
Default(target=target2),
|
||||
],
|
||||
)
|
||||
|
||||
from agent_framework_workflow._runner_context import InProcRunnerContext, Message
|
||||
from agent_framework_workflow._shared_state import SharedState
|
||||
|
||||
shared_state = SharedState()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
data = MockMessage(data=-1)
|
||||
message = Message(data=data, source_id=source.id, target_id="invalid_target")
|
||||
|
||||
success = await edge_group.send_message(message, shared_state, ctx)
|
||||
assert success is False
|
||||
|
||||
|
||||
async def test_switch_case_edge_group_send_message_with_valid_target():
|
||||
"""Test sending a message through a switch case edge group with a target."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target1 = MockExecutor(id="target_executor_1")
|
||||
target2 = MockExecutor(id="target_executor_2")
|
||||
|
||||
edge_group = SwitchCaseEdgeGroup(
|
||||
source=source,
|
||||
cases=[
|
||||
Case(condition=lambda x: x.data < 0, target=target1),
|
||||
Default(target=target2),
|
||||
],
|
||||
)
|
||||
|
||||
from agent_framework_workflow._runner_context import InProcRunnerContext, Message
|
||||
from agent_framework_workflow._shared_state import SharedState
|
||||
|
||||
shared_state = SharedState()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
data = MockMessage(data=1) # Condition will fail
|
||||
message = Message(data=data, source_id=source.id, target_id=target1.id)
|
||||
|
||||
success = await edge_group.send_message(message, shared_state, ctx)
|
||||
assert success is False
|
||||
|
||||
data = MockMessage(data=-1) # Condition will pass
|
||||
message = Message(data=data, source_id=source.id, target_id=target1.id)
|
||||
success = await edge_group.send_message(message, shared_state, ctx)
|
||||
assert success is True
|
||||
|
||||
|
||||
async def test_switch_case_edge_group_send_message_with_invalid_data():
|
||||
"""Test sending a message through a switch case edge group with invalid data."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target1 = MockExecutor(id="target_executor_1")
|
||||
target2 = MockExecutor(id="target_executor_2")
|
||||
|
||||
edge_group = SwitchCaseEdgeGroup(
|
||||
source=source,
|
||||
cases=[
|
||||
Case(condition=lambda x: x.data < 0, target=target1),
|
||||
Default(target=target2),
|
||||
],
|
||||
)
|
||||
|
||||
from agent_framework_workflow._runner_context import InProcRunnerContext, Message
|
||||
from agent_framework_workflow._shared_state import SharedState
|
||||
|
||||
shared_state = SharedState()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
data = "invalid_data"
|
||||
message = Message(data=data, source_id=source.id)
|
||||
|
||||
success = await edge_group.send_message(message, shared_state, ctx)
|
||||
assert success is False
|
||||
|
||||
|
||||
# endregion SwitchCaseEdgeGroup
|
||||
|
||||
@@ -6,7 +6,7 @@ from dataclasses import dataclass
|
||||
import pytest
|
||||
from agent_framework.workflow import Executor, WorkflowCompletedEvent, WorkflowContext, WorkflowEvent, handler
|
||||
|
||||
from agent_framework_workflow._edge import Edge
|
||||
from agent_framework_workflow._edge import SingleEdgeGroup
|
||||
from agent_framework_workflow._runner import Runner
|
||||
from agent_framework_workflow._runner_context import InProcRunnerContext, RunnerContext
|
||||
from agent_framework_workflow._shared_state import SharedState
|
||||
@@ -36,12 +36,12 @@ def test_create_runner():
|
||||
executor_b = MockExecutor(id="executor_b")
|
||||
|
||||
# Create a loop
|
||||
edges = [
|
||||
Edge(source=executor_a, target=executor_b),
|
||||
Edge(source=executor_b, target=executor_a),
|
||||
edge_groups = [
|
||||
SingleEdgeGroup(executor_a, executor_b),
|
||||
SingleEdgeGroup(executor_b, executor_a),
|
||||
]
|
||||
|
||||
runner = Runner(edges, shared_state=SharedState(), ctx=InProcRunnerContext())
|
||||
runner = Runner(edge_groups, shared_state=SharedState(), ctx=InProcRunnerContext())
|
||||
|
||||
assert runner.context is not None and isinstance(runner.context, RunnerContext)
|
||||
|
||||
@@ -53,8 +53,8 @@ async def test_runner_run_until_convergence():
|
||||
|
||||
# Create a loop
|
||||
edges = [
|
||||
Edge(source=executor_a, target=executor_b),
|
||||
Edge(source=executor_b, target=executor_a),
|
||||
SingleEdgeGroup(executor_a, executor_b),
|
||||
SingleEdgeGroup(executor_b, executor_a),
|
||||
]
|
||||
|
||||
shared_state = SharedState()
|
||||
@@ -87,8 +87,8 @@ async def test_runner_run_until_convergence_not_completed():
|
||||
|
||||
# Create a loop
|
||||
edges = [
|
||||
Edge(source=executor_a, target=executor_b),
|
||||
Edge(source=executor_b, target=executor_a),
|
||||
SingleEdgeGroup(executor_a, executor_b),
|
||||
SingleEdgeGroup(executor_b, executor_a),
|
||||
]
|
||||
|
||||
shared_state = SharedState()
|
||||
@@ -117,8 +117,8 @@ async def test_runner_already_running():
|
||||
|
||||
# Create a loop
|
||||
edges = [
|
||||
Edge(source=executor_a, target=executor_b),
|
||||
Edge(source=executor_b, target=executor_a),
|
||||
SingleEdgeGroup(executor_a, executor_b),
|
||||
SingleEdgeGroup(executor_b, executor_a),
|
||||
]
|
||||
|
||||
shared_state = SharedState()
|
||||
|
||||
@@ -17,7 +17,7 @@ from agent_framework_workflow import (
|
||||
handler,
|
||||
validate_workflow_graph,
|
||||
)
|
||||
from agent_framework_workflow._edge import Edge
|
||||
from agent_framework_workflow._edge import SingleEdgeGroup
|
||||
|
||||
|
||||
class StringExecutor(Executor):
|
||||
@@ -159,10 +159,13 @@ def test_graph_connectivity_isolated_executors():
|
||||
executor3 = StringExecutor(id="executor3") # This will be isolated
|
||||
|
||||
# Create edges that include an isolated executor (self-loop that's not connected to main graph)
|
||||
edges = [Edge(executor1, executor2), Edge(executor3, executor3)] # Self-loop to include in graph
|
||||
edge_groups = [
|
||||
SingleEdgeGroup(executor1, executor2),
|
||||
SingleEdgeGroup(executor3, executor3),
|
||||
] # Self-loop to include in graph
|
||||
|
||||
with pytest.raises(GraphConnectivityError) as exc_info:
|
||||
validate_workflow_graph(edges, executor1)
|
||||
validate_workflow_graph(edge_groups, executor1)
|
||||
|
||||
assert "unreachable" in str(exc_info.value).lower()
|
||||
assert "executor3" in str(exc_info.value)
|
||||
@@ -239,15 +242,15 @@ def test_type_compatibility_inheritance():
|
||||
def test_direct_validation_function():
|
||||
executor1 = StringExecutor(id="executor1")
|
||||
executor2 = StringExecutor(id="executor2")
|
||||
edges = [Edge(executor1, executor2)]
|
||||
edge_groups = [SingleEdgeGroup(executor1, executor2)]
|
||||
|
||||
# This should not raise any exceptions
|
||||
validate_workflow_graph(edges, executor1)
|
||||
validate_workflow_graph(edge_groups, executor1)
|
||||
|
||||
# Test with invalid start executor
|
||||
executor3 = StringExecutor(id="executor3")
|
||||
with pytest.raises(GraphConnectivityError):
|
||||
validate_workflow_graph(edges, executor3)
|
||||
validate_workflow_graph(edge_groups, executor3)
|
||||
|
||||
|
||||
def test_fan_out_validation():
|
||||
|
||||
@@ -60,6 +60,6 @@ def test_workflow_builder_fluent_api():
|
||||
.build()
|
||||
)
|
||||
|
||||
assert len(workflow.edges) == 6
|
||||
assert len(workflow.edge_groups) == 4
|
||||
assert workflow.start_executor.id == executor_a.id
|
||||
assert len(workflow.executors) == 6
|
||||
|
||||
@@ -0,0 +1,291 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from agent_framework.workflow import Case, Default, Executor, WorkflowBuilder, WorkflowContext, handler
|
||||
|
||||
"""
|
||||
The following sample demonstrates the foundation patterns that the workflow framework supports.
|
||||
These patterns include:
|
||||
- Single connection
|
||||
- Single connection with condition
|
||||
- Fan-out and fan-in connections
|
||||
- Conditional fan-out connections
|
||||
- Partitioning fan-out connections
|
||||
|
||||
The samples here use numbers and simple arithmetic operations to demonstrate the patterns.
|
||||
"""
|
||||
|
||||
|
||||
class AddOneExecutor(Executor):
|
||||
"""An executor that processes a number by adding one."""
|
||||
|
||||
@handler(output_types=[int])
|
||||
async def add_one(self, number: int, ctx: WorkflowContext) -> None:
|
||||
"""Execute the task by adding one to the input number."""
|
||||
result = number + 1
|
||||
|
||||
# Send the result to the next executor in the workflow.
|
||||
await ctx.send_message(result)
|
||||
|
||||
print("Adding one to the number:", number, "Result:", result)
|
||||
|
||||
|
||||
class MultiplyByTwoExecutor(Executor):
|
||||
"""An executor that processes a number by multiplying it by two."""
|
||||
|
||||
@handler(output_types=[int])
|
||||
async def multiply_by_two(self, number: int, ctx: WorkflowContext) -> None:
|
||||
"""Execute the task by multiplying the input number by two."""
|
||||
result = number * 2
|
||||
|
||||
# Send the result to the next executor in the workflow.
|
||||
await ctx.send_message(result)
|
||||
|
||||
print("Multiplying the number by two:", number, "Result:", result)
|
||||
|
||||
|
||||
class DivideByTwoExecutor(Executor):
|
||||
"""An executor that processes a number by dividing it by two."""
|
||||
|
||||
@handler(output_types=[float])
|
||||
async def divide_by_two(self, number: int, ctx: WorkflowContext) -> None:
|
||||
"""Execute the task by dividing the input number by two."""
|
||||
result = number / 2
|
||||
|
||||
# Send the result with a workflow completion event.
|
||||
await ctx.send_message(result)
|
||||
|
||||
print("Dividing the number by two:", number, "Result:", result)
|
||||
|
||||
|
||||
class AggregateResultExecutor(Executor):
|
||||
"""An executor that receives results and prints them."""
|
||||
|
||||
@handler
|
||||
async def aggregate_results(self, results: Any, ctx: WorkflowContext) -> None:
|
||||
"""Print whatever results are received."""
|
||||
print("Aggregating results:", results)
|
||||
|
||||
|
||||
async def single_edge():
|
||||
"""A sample to demonstrate a single directed connection between two executors.
|
||||
|
||||
Three executors are connected in a sequence: AddOneExecutor -> AddOneExecutor -> AggregateResultExecutor.
|
||||
|
||||
Expected output:
|
||||
Adding one to the number: 1 Result: 2
|
||||
Adding one to the number: 2 Result: 3
|
||||
Aggregating results: 3
|
||||
"""
|
||||
add_one_executor_a = AddOneExecutor()
|
||||
add_one_executor_b = AddOneExecutor()
|
||||
aggregate_result_executor = AggregateResultExecutor()
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder()
|
||||
.add_edge(add_one_executor_a, add_one_executor_b)
|
||||
.add_edge(add_one_executor_b, aggregate_result_executor)
|
||||
.set_start_executor(add_one_executor_a)
|
||||
.build()
|
||||
)
|
||||
|
||||
await workflow.run(1)
|
||||
|
||||
|
||||
async def single_edge_with_condition():
|
||||
"""A sample to demonstrate a single directed connection with a condition.
|
||||
|
||||
Three executors are connected: AddOneExecutor -> AddOneExecutor, AggregateResultExecutor.
|
||||
The AddOneExecutor will loop back to itself until the number reaches 10, then it will start
|
||||
sending the result to AggregateResultExecutor when the number is greater than 8. The workflow
|
||||
stops when the number reaches 11.
|
||||
|
||||
Expected output:
|
||||
Adding one to the number: 1 Result: 2
|
||||
Adding one to the number: 2 Result: 3
|
||||
Adding one to the number: 3 Result: 4
|
||||
Adding one to the number: 4 Result: 5
|
||||
Adding one to the number: 5 Result: 6
|
||||
Adding one to the number: 6 Result: 7
|
||||
Adding one to the number: 7 Result: 8
|
||||
Adding one to the number: 8 Result: 9
|
||||
Adding one to the number: 9 Result: 10
|
||||
Aggregating results: 9
|
||||
Adding one to the number: 10 Result: 11
|
||||
Aggregating results: 10
|
||||
Aggregating results: 11
|
||||
"""
|
||||
add_one_executor_a = AddOneExecutor()
|
||||
aggregate_result_executor = AggregateResultExecutor()
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder()
|
||||
.add_edge(add_one_executor_a, add_one_executor_a, condition=lambda x: x < 11)
|
||||
.add_edge(add_one_executor_a, aggregate_result_executor, condition=lambda x: x > 8)
|
||||
.set_start_executor(add_one_executor_a)
|
||||
.build()
|
||||
)
|
||||
|
||||
await workflow.run(1)
|
||||
|
||||
|
||||
async def fan_out_fan_in_edge_group():
|
||||
"""A sample to demonstrate a fan-out and fan-in connection between executors.
|
||||
|
||||
Four executors are connected in a fan-out and fan-in pattern:
|
||||
AddOneExecutor -> MultiplyByTwoExecutor, DivideByTwoExecutor -> AggregateResultExecutor.
|
||||
The AddOneExecutor sends its output to both MultiplyByTwoExecutor and DivideByTwoExecutor,
|
||||
and both of these executors send their results to AggregateResultExecutor.
|
||||
|
||||
The target of the fan-in connection will wait for all the results from the sources before proceeding.
|
||||
|
||||
Expected output:
|
||||
Adding one to the number: 1 Result: 2
|
||||
Multiplying the number by two: 2 Result: 4
|
||||
Dividing the number by two: 2 Result: 1.0
|
||||
Aggregating results: [4, 1.0]
|
||||
"""
|
||||
add_one_executor = AddOneExecutor()
|
||||
multiply_by_two_executor = MultiplyByTwoExecutor()
|
||||
divide_by_two_executor = DivideByTwoExecutor()
|
||||
aggregate_result_executor = AggregateResultExecutor()
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder()
|
||||
.add_fan_out_edges(add_one_executor, [multiply_by_two_executor, divide_by_two_executor])
|
||||
.add_fan_in_edges([multiply_by_two_executor, divide_by_two_executor], aggregate_result_executor)
|
||||
.set_start_executor(add_one_executor)
|
||||
.build()
|
||||
)
|
||||
|
||||
await workflow.run(1)
|
||||
|
||||
|
||||
async def switch_case_edge_group():
|
||||
"""A sample to demonstrate a switch-case connection.
|
||||
|
||||
Four executors are connected in a switch-case pattern:
|
||||
AddOneExecutor -> AddOneExecutor, MultiplyByTwoExecutor, DivideByTwoExecutor -> AggregateResultExecutor.
|
||||
|
||||
The message from AddOneExecutor will be evaluated against the conditions one by one, and the first condition
|
||||
that evaluates to True will determine the target executors. If no conditions match, the message will be sent
|
||||
to the last targets.
|
||||
|
||||
This pattern resembles a switch-case statement with a default case where the first matching case is executed.
|
||||
|
||||
Expected output:
|
||||
Adding one to the number: 1 Result: 2
|
||||
Adding one to the number: 2 Result: 3
|
||||
Adding one to the number: 3 Result: 4
|
||||
Adding one to the number: 4 Result: 5
|
||||
Adding one to the number: 5 Result: 6
|
||||
Adding one to the number: 6 Result: 7
|
||||
Adding one to the number: 7 Result: 8
|
||||
Adding one to the number: 8 Result: 9
|
||||
Adding one to the number: 9 Result: 10
|
||||
Adding one to the number: 10 Result: 11
|
||||
Multiplying the number by two: 11 Result: 22
|
||||
"""
|
||||
add_one_executor = AddOneExecutor()
|
||||
multiply_by_two_executor = MultiplyByTwoExecutor()
|
||||
divide_by_two_executor = DivideByTwoExecutor()
|
||||
aggregate_result_executor = AggregateResultExecutor()
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(add_one_executor)
|
||||
.add_switch_case_edge_group(
|
||||
source=add_one_executor,
|
||||
cases=[
|
||||
# Loop back to the add_one_executor if the number is less than 11
|
||||
Case(condition=lambda x: x < 11, target=add_one_executor),
|
||||
# multiply_by_two_executor when the number is larger than or equal to 11 and even.
|
||||
Case(condition=lambda x: x % 2 == 0, target=multiply_by_two_executor),
|
||||
# Otherwise, send to the divide_by_two_executor.
|
||||
Default(target=divide_by_two_executor),
|
||||
],
|
||||
)
|
||||
.add_fan_in_edges([multiply_by_two_executor, divide_by_two_executor], aggregate_result_executor)
|
||||
.build()
|
||||
)
|
||||
|
||||
await workflow.run(1)
|
||||
|
||||
|
||||
async def multi_selection_edge_group():
|
||||
"""A sample to demonstrate a multi-selection edge connection.
|
||||
|
||||
Four executors are connected in a multi-selection edge pattern:
|
||||
AddOneExecutor -> AddOneExecutor, MultiplyByTwoExecutor, DivideByTwoExecutor -> AggregateResultExecutor.
|
||||
|
||||
The AddOneExecutor sends its output to one or more executors based on the partitioning function.
|
||||
|
||||
Expected output:
|
||||
Adding one to the number: 1 Result: 2
|
||||
Adding one to the number: 2 Result: 3
|
||||
Adding one to the number: 3 Result: 4
|
||||
Adding one to the number: 4 Result: 5
|
||||
Adding one to the number: 5 Result: 6
|
||||
Adding one to the number: 6 Result: 7
|
||||
Adding one to the number: 7 Result: 8
|
||||
Adding one to the number: 8 Result: 9
|
||||
Adding one to the number: 9 Result: 10
|
||||
Adding one to the number: 10 Result: 11
|
||||
Adding one to the number: 11 Result: 12
|
||||
Adding one to the number: 12 Result: 13
|
||||
Dividing the number by two: 12 Result: 6.0
|
||||
Multiplying the number by two: 13 Result: 26
|
||||
Aggregating results: [26, 6.0]
|
||||
"""
|
||||
add_one_executor = AddOneExecutor()
|
||||
multiply_by_two_executor = MultiplyByTwoExecutor()
|
||||
divide_by_two_executor = DivideByTwoExecutor()
|
||||
aggregate_result_executor = AggregateResultExecutor()
|
||||
|
||||
def selection_func(number: int, target_ids: list[str]) -> list[str]:
|
||||
"""Selection function to determine which executor to send the number to."""
|
||||
if number < 12:
|
||||
# Loop back to the add_one_executor if the number is less than 12
|
||||
return [add_one_executor.id]
|
||||
|
||||
if number % 2 == 0:
|
||||
# Send it to the add_one_executor to add one more time and the
|
||||
# divide_by_two_executor to divide the result by two.
|
||||
return [add_one_executor.id, divide_by_two_executor.id]
|
||||
|
||||
# Otherwise, send it to the multiply_by_two_executor to multiply the result by two.
|
||||
return [multiply_by_two_executor.id]
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(add_one_executor)
|
||||
.add_multi_selection_edge_group(
|
||||
add_one_executor,
|
||||
[add_one_executor, multiply_by_two_executor, divide_by_two_executor],
|
||||
selection_func=selection_func,
|
||||
)
|
||||
.add_fan_in_edges([multiply_by_two_executor, divide_by_two_executor], aggregate_result_executor)
|
||||
.build()
|
||||
)
|
||||
|
||||
await workflow.run(1)
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main function to run the workflows."""
|
||||
print("**Running single connection workflow**")
|
||||
await single_edge()
|
||||
print("**Running single connection with condition workflow**")
|
||||
await single_edge_with_condition()
|
||||
print("**Running fan-out and fan-in connection workflow**")
|
||||
await fan_out_fan_in_edge_group()
|
||||
print("**Running conditional fan-out connection workflow**")
|
||||
await switch_case_edge_group()
|
||||
print("**Running multi-selection edge group workflow**")
|
||||
await multi_selection_edge_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -3,7 +3,15 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agent_framework.workflow import Executor, WorkflowBuilder, WorkflowCompletedEvent, WorkflowContext, handler
|
||||
from agent_framework.workflow import (
|
||||
Case,
|
||||
Default,
|
||||
Executor,
|
||||
WorkflowBuilder,
|
||||
WorkflowCompletedEvent,
|
||||
WorkflowContext,
|
||||
handler,
|
||||
)
|
||||
|
||||
"""
|
||||
The following sample demonstrates a basic workflow with two executors
|
||||
@@ -91,15 +99,12 @@ async def main():
|
||||
workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(spam_detector)
|
||||
.add_edge(
|
||||
.add_switch_case_edge_group(
|
||||
spam_detector,
|
||||
send_response,
|
||||
condition=lambda x: x.is_spam is False,
|
||||
)
|
||||
.add_edge(
|
||||
spam_detector,
|
||||
remove_spam,
|
||||
condition=lambda x: x.is_spam is True,
|
||||
[
|
||||
Case(condition=lambda x: x.is_spam, target=remove_spam),
|
||||
Default(target=send_response),
|
||||
],
|
||||
)
|
||||
.build()
|
||||
)
|
||||
|
||||
@@ -120,8 +120,7 @@ async def main():
|
||||
workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(group_chat_manager)
|
||||
.add_edge(group_chat_manager, writer)
|
||||
.add_edge(group_chat_manager, reviewer)
|
||||
.add_fan_out_edges(group_chat_manager, [writer, reviewer])
|
||||
.add_edge(writer, group_chat_manager)
|
||||
.add_edge(reviewer, group_chat_manager)
|
||||
.build()
|
||||
|
||||
@@ -167,8 +167,7 @@ async def main():
|
||||
.set_start_executor(group_chat_manager)
|
||||
.add_edge(group_chat_manager, request_info_executor)
|
||||
.add_edge(request_info_executor, group_chat_manager)
|
||||
.add_edge(group_chat_manager, writer)
|
||||
.add_edge(group_chat_manager, reviewer)
|
||||
.add_fan_out_edges(group_chat_manager, [writer, reviewer])
|
||||
.add_edge(writer, group_chat_manager)
|
||||
.add_edge(reviewer, group_chat_manager)
|
||||
.build()
|
||||
|
||||
@@ -113,7 +113,6 @@ class Map(Executor):
|
||||
ctx: The execution context containing the shared state and other information.
|
||||
"""
|
||||
# Retrieve the data to be processed from the shared state.
|
||||
# Define a key for the shared state to store the data to be processed
|
||||
data_to_be_processed: list[str] = await ctx.get_shared_state(SHARED_STATE_DATA_KEY)
|
||||
chunk_start, chunk_end = await ctx.get_shared_state(self.id)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user