mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
[BREAKING] Python: Remove workflow register factory methods. Update tests and samples (#3781)
* Remove workflow register factory methods. Update tests and samples * Address Copilot feedback
This commit is contained in:
committed by
GitHub
Unverified
parent
f407f726a7
commit
a4c9e43afb
@@ -3,11 +3,9 @@
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from .._agents import SupportsAgentRun
|
||||
from .._threads import AgentThread
|
||||
from ..observability import OtelAttr, capture_exception, create_workflow_span
|
||||
from ._agent_executor import AgentExecutor
|
||||
from ._agent_utils import resolve_agent_id
|
||||
@@ -40,76 +38,6 @@ else:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _EdgeRegistration:
|
||||
"""A data class representing an edge registration in the workflow builder.
|
||||
|
||||
Args:
|
||||
source: The registered source name.
|
||||
target: The registered target name.
|
||||
condition: An optional condition function `(data) -> bool | Awaitable[bool]`.
|
||||
"""
|
||||
|
||||
source: str
|
||||
target: str
|
||||
condition: EdgeCondition | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FanOutEdgeRegistration:
|
||||
"""A data class representing a fan-out edge registration in the workflow builder.
|
||||
|
||||
Args:
|
||||
source: The registered source name.
|
||||
targets: A list of registered target names.
|
||||
"""
|
||||
|
||||
source: str
|
||||
targets: list[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FanInEdgeRegistration:
|
||||
"""A data class representing a fan-in edge registration in the workflow builder.
|
||||
|
||||
Args:
|
||||
sources: A list of registered source names.
|
||||
target: The registered target name.
|
||||
"""
|
||||
|
||||
sources: list[str]
|
||||
target: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SwitchCaseEdgeGroupRegistration:
|
||||
"""A data class representing a switch-case edge group registration in the workflow builder.
|
||||
|
||||
Args:
|
||||
source: The registered source name.
|
||||
cases: A list of case objects that determine the target executor for each message.
|
||||
"""
|
||||
|
||||
source: str
|
||||
cases: list[Case | Default]
|
||||
|
||||
|
||||
@dataclass
|
||||
class _MultiSelectionEdgeGroupRegistration:
|
||||
"""A data class representing a multi-selection edge group registration in the workflow builder.
|
||||
|
||||
Args:
|
||||
source: The registered source name.
|
||||
targets: A list of registered target names.
|
||||
selection_func: A function that selects target executors for messages.
|
||||
Takes (message, list[registered target names]) and returns list[registered target names].
|
||||
"""
|
||||
|
||||
source: str
|
||||
targets: list[str]
|
||||
selection_func: Callable[[Any, list[str]], list[str]]
|
||||
|
||||
|
||||
class WorkflowBuilder:
|
||||
"""A builder class for constructing workflows.
|
||||
|
||||
@@ -136,14 +64,10 @@ class WorkflowBuilder:
|
||||
await ctx.yield_output(text[::-1])
|
||||
|
||||
|
||||
# Build a workflow
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="UpperCase")
|
||||
.register_executor(lambda: UpperCaseExecutor(id="upper"), name="UpperCase")
|
||||
.register_executor(lambda: ReverseExecutor(id="reverse"), name="Reverse")
|
||||
.add_edge("UpperCase", "Reverse")
|
||||
.build()
|
||||
)
|
||||
upper = UpperCaseExecutor(id="upper")
|
||||
reverse = ReverseExecutor(id="reverse")
|
||||
|
||||
workflow = WorkflowBuilder(start_executor=upper).add_edge(upper, reverse).build()
|
||||
|
||||
# Run the workflow
|
||||
events = await workflow.run("hello")
|
||||
@@ -156,9 +80,9 @@ class WorkflowBuilder:
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
*,
|
||||
start_executor: Executor | SupportsAgentRun | str,
|
||||
start_executor: Executor | SupportsAgentRun,
|
||||
checkpoint_storage: CheckpointStorage | None = None,
|
||||
output_executors: list[Executor | SupportsAgentRun | str] | None = None,
|
||||
output_executors: list[Executor | SupportsAgentRun] | None = None,
|
||||
):
|
||||
"""Initialize the WorkflowBuilder.
|
||||
|
||||
@@ -166,15 +90,15 @@ class WorkflowBuilder:
|
||||
max_iterations: Maximum number of iterations for workflow convergence. Default is 100.
|
||||
name: Optional human-readable name for the workflow.
|
||||
description: Optional description of what the workflow does.
|
||||
start_executor: The starting executor for the workflow. Can be an Executor instance,
|
||||
SupportsAgentRun instance, or the name of a registered executor factory.
|
||||
start_executor: The starting executor for the workflow. Can be an Executor instance
|
||||
or SupportsAgentRun instance.
|
||||
checkpoint_storage: Optional checkpoint storage for enabling workflow state persistence.
|
||||
output_executors: Optional list of executors whose outputs should be collected.
|
||||
If not provided, outputs from all executors are collected.
|
||||
"""
|
||||
self._edge_groups: list[EdgeGroup] = []
|
||||
self._executors: dict[str, Executor] = {}
|
||||
self._start_executor: Executor | str | None = None
|
||||
self._start_executor: Executor | None = None
|
||||
self._checkpoint_storage: CheckpointStorage | None = checkpoint_storage
|
||||
self._max_iterations: int = max_iterations
|
||||
self._name: str | None = name
|
||||
@@ -184,18 +108,8 @@ class WorkflowBuilder:
|
||||
# being created for the same agent.
|
||||
self._agent_wrappers: dict[str, Executor] = {}
|
||||
|
||||
# Registrations for lazy initialization of executors
|
||||
self._edge_registry: list[
|
||||
_EdgeRegistration
|
||||
| _FanOutEdgeRegistration
|
||||
| _SwitchCaseEdgeGroupRegistration
|
||||
| _MultiSelectionEdgeGroupRegistration
|
||||
| _FanInEdgeRegistration
|
||||
] = []
|
||||
self._executor_registry: dict[str, Callable[[], Executor]] = {}
|
||||
|
||||
# Output executors filter; if set, only outputs from these executors are yielded
|
||||
self._output_executors: list[Executor | SupportsAgentRun | str] = output_executors if output_executors else []
|
||||
self._output_executors: list[Executor | SupportsAgentRun] = output_executors if output_executors else []
|
||||
|
||||
# Set the start executor
|
||||
self._set_start_executor(start_executor)
|
||||
@@ -258,133 +172,10 @@ class WorkflowBuilder:
|
||||
f"WorkflowBuilder expected an Executor or SupportsAgentRun instance; got {type(candidate).__name__}."
|
||||
)
|
||||
|
||||
def register_executor(self, factory_func: Callable[[], Executor], name: str | list[str]) -> Self:
|
||||
"""Register an executor factory function for lazy initialization.
|
||||
|
||||
This method allows you to register a factory function that creates an executor.
|
||||
The executor will be instantiated only when the workflow is built, enabling
|
||||
deferred initialization and potentially reducing startup time.
|
||||
|
||||
Args:
|
||||
factory_func: A callable that returns an Executor instance when called.
|
||||
name: The name(s) of the registered executor factory. This doesn't have to match
|
||||
the executor's ID, but it must be unique within the workflow.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
from typing_extensions import Never
|
||||
from agent_framework import Executor, WorkflowBuilder, WorkflowContext, handler
|
||||
|
||||
|
||||
class UpperCaseExecutor(Executor):
|
||||
@handler
|
||||
async def process(self, text: str, ctx: WorkflowContext[str]) -> None:
|
||||
await ctx.send_message(text.upper())
|
||||
|
||||
|
||||
class ReverseExecutor(Executor):
|
||||
@handler
|
||||
async def process(self, text: str, ctx: WorkflowContext[Never, str]) -> None:
|
||||
await ctx.yield_output(text[::-1])
|
||||
|
||||
|
||||
# Build a workflow
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="UpperCase")
|
||||
.register_executor(lambda: UpperCaseExecutor(id="upper"), name="UpperCase")
|
||||
.register_executor(lambda: ReverseExecutor(id="reverse"), name="Reverse")
|
||||
.add_edge("UpperCase", "Reverse")
|
||||
.build()
|
||||
)
|
||||
|
||||
If multiple names are provided, the same factory function will be registered under each name.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agent_framework import WorkflowBuilder, Executor, WorkflowContext, handler
|
||||
|
||||
|
||||
class LoggerExecutor(Executor):
|
||||
@handler
|
||||
async def log(self, message: str, ctx: WorkflowContext) -> None:
|
||||
print(f"Log: {message}")
|
||||
|
||||
|
||||
# Register the same executor factory under multiple names
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="ExecutorA")
|
||||
.register_executor(lambda: LoggerExecutor(id="logger"), name=["ExecutorA", "ExecutorB"])
|
||||
.add_edge("ExecutorA", "ExecutorB")
|
||||
.build()
|
||||
"""
|
||||
names = [name] if isinstance(name, str) else name
|
||||
|
||||
for n in names:
|
||||
if n in self._executor_registry:
|
||||
raise ValueError(f"An executor factory with the name '{n}' is already registered.")
|
||||
|
||||
for n in names:
|
||||
self._executor_registry[n] = factory_func
|
||||
|
||||
return self
|
||||
|
||||
def register_agent(
|
||||
self,
|
||||
factory_func: Callable[[], SupportsAgentRun],
|
||||
name: str,
|
||||
agent_thread: AgentThread | None = None,
|
||||
) -> Self:
|
||||
"""Register an agent factory function for lazy initialization.
|
||||
|
||||
This method allows you to register a factory function that creates an agent.
|
||||
The agent will be instantiated and wrapped in an AgentExecutor only when the workflow is built,
|
||||
enabling deferred initialization and potentially reducing startup time.
|
||||
|
||||
Args:
|
||||
factory_func: A callable that returns an SupportsAgentRun instance when called.
|
||||
name: The name of the registered agent factory. This doesn't have to match
|
||||
the agent's internal name. But it must be unique within the workflow.
|
||||
agent_thread: The thread to use for running the agent. If None, a new thread will be created when
|
||||
the agent is instantiated.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from agent_framework import WorkflowBuilder
|
||||
from agent_framework_anthropic import AnthropicAgent
|
||||
|
||||
|
||||
# Build a workflow
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="SomeOtherExecutor")
|
||||
.register_executor(lambda: ..., name="SomeOtherExecutor")
|
||||
.register_agent(
|
||||
lambda: AnthropicAgent(name="writer", model="claude-3-5-sonnet-20241022"),
|
||||
name="WriterAgent",
|
||||
output_response=True,
|
||||
)
|
||||
.add_edge("SomeOtherExecutor", "WriterAgent")
|
||||
.build()
|
||||
)
|
||||
"""
|
||||
if name in self._executor_registry:
|
||||
raise ValueError(f"An agent factory with the name '{name}' is already registered.")
|
||||
|
||||
def wrapped_factory() -> AgentExecutor:
|
||||
agent = factory_func()
|
||||
return AgentExecutor(
|
||||
agent,
|
||||
agent_thread=agent_thread,
|
||||
)
|
||||
|
||||
self._executor_registry[name] = wrapped_factory
|
||||
|
||||
return self
|
||||
|
||||
def add_edge(
|
||||
self,
|
||||
source: Executor | SupportsAgentRun | str,
|
||||
target: Executor | SupportsAgentRun | str,
|
||||
source: Executor | SupportsAgentRun,
|
||||
target: Executor | SupportsAgentRun,
|
||||
condition: EdgeCondition | None = None,
|
||||
) -> Self:
|
||||
"""Add a directed edge between two executors.
|
||||
@@ -393,17 +184,12 @@ class WorkflowBuilder:
|
||||
Messages sent by the source executor will be routed to the target executor.
|
||||
|
||||
Args:
|
||||
source: The source executor or registered name of the source factory for the edge.
|
||||
target: The target executor or registered name of the target factory for the edge.
|
||||
source: The source executor or agent for the edge.
|
||||
target: The target executor or agent for the edge.
|
||||
condition: An optional condition function `(data) -> bool | Awaitable[bool]`
|
||||
that determines whether the edge should be traversed.
|
||||
Example: `lambda data: data["ready"]`.
|
||||
|
||||
Note: If instances are provided for both source and target, they will be shared across
|
||||
all workflow instances created from the built Workflow. To avoid this, consider
|
||||
registering the executors and agents using `register_executor` and `register_agent`
|
||||
and referencing them by factory name for lazy initialization instead.
|
||||
|
||||
Returns:
|
||||
Self: The WorkflowBuilder instance for method chaining.
|
||||
|
||||
@@ -426,39 +212,13 @@ class WorkflowBuilder:
|
||||
await ctx.yield_output(f"Processed {count} characters")
|
||||
|
||||
|
||||
# Connect executors with an edge
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="ProcessorA")
|
||||
.register_executor(lambda: ProcessorA(id="a"), name="ProcessorA")
|
||||
.register_executor(lambda: ProcessorB(id="b"), name="ProcessorB")
|
||||
.add_edge("ProcessorA", "ProcessorB")
|
||||
.build()
|
||||
)
|
||||
a = ProcessorA(id="a")
|
||||
b = ProcessorB(id="b")
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="ProcessorA")
|
||||
.register_executor(lambda: ProcessorA(id="a"), name="ProcessorA")
|
||||
.register_executor(lambda: ProcessorB(id="b"), name="ProcessorB")
|
||||
.add_edge("ProcessorA", "ProcessorB", condition=only_large_numbers)
|
||||
.build()
|
||||
)
|
||||
workflow = WorkflowBuilder(start_executor=a).add_edge(a, b).build()
|
||||
"""
|
||||
if (isinstance(source, str) and not isinstance(target, str)) or (
|
||||
not isinstance(source, str) and isinstance(target, str)
|
||||
):
|
||||
raise ValueError(
|
||||
"Both source and target must be either registered factory names (str) or "
|
||||
"Executor/SupportsAgentRun instances."
|
||||
)
|
||||
|
||||
if isinstance(source, str) and isinstance(target, str):
|
||||
# Both are names; defer resolution to build time
|
||||
self._edge_registry.append(_EdgeRegistration(source=source, target=target, condition=condition))
|
||||
return self
|
||||
|
||||
# Both are Executor/SupportsAgentRun instances; wrap and add now
|
||||
source_exec = self._maybe_wrap_agent(source) # type: ignore[arg-type]
|
||||
target_exec = self._maybe_wrap_agent(target) # type: ignore[arg-type]
|
||||
source_exec = self._maybe_wrap_agent(source)
|
||||
target_exec = self._maybe_wrap_agent(target)
|
||||
source_id = self._add_executor(source_exec)
|
||||
target_id = self._add_executor(target_exec)
|
||||
self._edge_groups.append(SingleEdgeGroup(source_id, target_id, condition))
|
||||
@@ -466,8 +226,8 @@ class WorkflowBuilder:
|
||||
|
||||
def add_fan_out_edges(
|
||||
self,
|
||||
source: Executor | SupportsAgentRun | str,
|
||||
targets: Sequence[Executor | SupportsAgentRun | str],
|
||||
source: Executor | SupportsAgentRun,
|
||||
targets: Sequence[Executor | SupportsAgentRun],
|
||||
) -> Self:
|
||||
"""Add multiple edges to the workflow where messages from the source will be sent to all targets.
|
||||
|
||||
@@ -475,17 +235,12 @@ class WorkflowBuilder:
|
||||
Messages from the source will be broadcast to all target executors concurrently.
|
||||
|
||||
Args:
|
||||
source: The source executor or registered name of the source factory for the edges.
|
||||
targets: A list of target executors or registered names of the target factories for the edges.
|
||||
source: The source executor or agent for the edges.
|
||||
targets: A list of target executors or agents for the edges.
|
||||
|
||||
Returns:
|
||||
Self: The WorkflowBuilder instance for method chaining.
|
||||
|
||||
Note: If instances are provided for source and targets, they will be shared across
|
||||
all workflow instances created from the built Workflow. To avoid this, consider
|
||||
registering the executors and agents using `register_executor` and `register_agent`
|
||||
and referencing them by factory name for lazy initialization instead.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
@@ -511,32 +266,14 @@ class WorkflowBuilder:
|
||||
print(f"ValidatorB: {data}")
|
||||
|
||||
|
||||
# Broadcast to multiple validators
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="DataSource")
|
||||
.register_executor(lambda: DataSource(id="source"), name="DataSource")
|
||||
.register_executor(lambda: ValidatorA(id="val_a"), name="ValidatorA")
|
||||
.register_executor(lambda: ValidatorB(id="val_b"), name="ValidatorB")
|
||||
.add_fan_out_edges("DataSource", ["ValidatorA", "ValidatorB"])
|
||||
.build()
|
||||
)
|
||||
source = DataSource(id="source")
|
||||
val_a = ValidatorA(id="val_a")
|
||||
val_b = ValidatorB(id="val_b")
|
||||
|
||||
workflow = WorkflowBuilder(start_executor=source).add_fan_out_edges(source, [val_a, val_b]).build()
|
||||
"""
|
||||
if (isinstance(source, str) and not all(isinstance(t, str) for t in targets)) or (
|
||||
not isinstance(source, str) and any(isinstance(t, str) for t in targets)
|
||||
):
|
||||
raise ValueError(
|
||||
"Both source and targets must be either registered factory names (str) or "
|
||||
"Executor/SupportsAgentRun instances."
|
||||
)
|
||||
|
||||
if isinstance(source, str) and all(isinstance(t, str) for t in targets):
|
||||
# Both are names; defer resolution to build time
|
||||
self._edge_registry.append(_FanOutEdgeRegistration(source=source, targets=list(targets))) # type: ignore
|
||||
return self
|
||||
|
||||
# Both are Executor/SupportsAgentRun instances; wrap and add now
|
||||
source_exec = self._maybe_wrap_agent(source) # type: ignore[arg-type]
|
||||
target_execs = [self._maybe_wrap_agent(t) for t in targets] # type: ignore[arg-type]
|
||||
source_exec = self._maybe_wrap_agent(source)
|
||||
target_execs = [self._maybe_wrap_agent(t) for t in targets]
|
||||
source_id = self._add_executor(source_exec)
|
||||
target_ids = [self._add_executor(t) for t in target_execs]
|
||||
self._edge_groups.append(FanOutEdgeGroup(source_id, target_ids)) # type: ignore[call-arg]
|
||||
@@ -545,7 +282,7 @@ class WorkflowBuilder:
|
||||
|
||||
def add_switch_case_edge_group(
|
||||
self,
|
||||
source: Executor | SupportsAgentRun | str,
|
||||
source: Executor | SupportsAgentRun,
|
||||
cases: Sequence[Case | Default],
|
||||
) -> Self:
|
||||
"""Add an edge group that represents a switch-case statement.
|
||||
@@ -562,17 +299,12 @@ class WorkflowBuilder:
|
||||
(i.e., no condition matched).
|
||||
|
||||
Args:
|
||||
source: The source executor or registered name of the source factory for the edge group.
|
||||
source: The source executor or agent for the edge group.
|
||||
cases: A list of case objects that determine the target executor for each message.
|
||||
|
||||
Returns:
|
||||
Self: The WorkflowBuilder instance for method chaining.
|
||||
|
||||
Note: If instances are provided for source and case targets, they will be shared across
|
||||
all workflow instances created from the built Workflow. To avoid this, consider
|
||||
registering the executors and agents using `register_executor` and `register_agent`
|
||||
and referencing them by factory name for lazy initialization instead.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
@@ -603,37 +335,23 @@ class WorkflowBuilder:
|
||||
print(f"Low score: {result.score}")
|
||||
|
||||
|
||||
# Route based on score value
|
||||
evaluator = Evaluator(id="eval")
|
||||
high = HighScoreHandler(id="high")
|
||||
low = LowScoreHandler(id="low")
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="Evaluator")
|
||||
.register_executor(lambda: Evaluator(id="eval"), name="Evaluator")
|
||||
.register_executor(lambda: HighScoreHandler(id="high"), name="HighScoreHandler")
|
||||
.register_executor(lambda: LowScoreHandler(id="low"), name="LowScoreHandler")
|
||||
WorkflowBuilder(start_executor=evaluator)
|
||||
.add_switch_case_edge_group(
|
||||
"Evaluator",
|
||||
evaluator,
|
||||
[
|
||||
Case(condition=lambda r: r.score > 10, target="HighScoreHandler"),
|
||||
Default(target="LowScoreHandler"),
|
||||
Case(condition=lambda r: r.score > 10, target=high),
|
||||
Default(target=low),
|
||||
],
|
||||
)
|
||||
.build()
|
||||
)
|
||||
"""
|
||||
if (isinstance(source, str) and not all(isinstance(case.target, str) for case in cases)) or (
|
||||
not isinstance(source, str) and any(isinstance(case.target, str) for case in cases)
|
||||
):
|
||||
raise ValueError(
|
||||
"Both source and case targets must be either registered factory names (str) "
|
||||
"or Executor/SupportsAgentRun instances."
|
||||
)
|
||||
|
||||
if isinstance(source, str) and all(isinstance(case.target, str) for case in cases):
|
||||
# Source is a name; defer resolution to build time
|
||||
self._edge_registry.append(_SwitchCaseEdgeGroupRegistration(source=source, cases=list(cases))) # type: ignore
|
||||
return self
|
||||
|
||||
# Source is an Executor/SupportsAgentRun instance; wrap and add now
|
||||
source_exec = self._maybe_wrap_agent(source) # type: ignore[arg-type]
|
||||
source_exec = self._maybe_wrap_agent(source)
|
||||
source_id = self._add_executor(source_exec)
|
||||
# Convert case data types to internal types that only uses target_id.
|
||||
internal_cases: list[SwitchCaseEdgeGroupCase | SwitchCaseEdgeGroupDefault] = []
|
||||
@@ -651,8 +369,8 @@ class WorkflowBuilder:
|
||||
|
||||
def add_multi_selection_edge_group(
|
||||
self,
|
||||
source: Executor | SupportsAgentRun | str,
|
||||
targets: Sequence[Executor | SupportsAgentRun | str],
|
||||
source: Executor | SupportsAgentRun,
|
||||
targets: Sequence[Executor | SupportsAgentRun],
|
||||
selection_func: Callable[[Any, list[str]], list[str]],
|
||||
) -> Self:
|
||||
"""Add an edge group that represents a multi-selection execution model.
|
||||
@@ -665,19 +383,14 @@ class WorkflowBuilder:
|
||||
and return a list of executor IDs indicating which target executors should receive the message.
|
||||
|
||||
Args:
|
||||
source: The source executor or registered name of the source factory for the edge group.
|
||||
targets: A list of target executors or registered names of the target factories for the edges.
|
||||
source: The source executor or agent for the edge group.
|
||||
targets: A list of target executors or agents for the edges.
|
||||
selection_func: A function that selects target executors for messages.
|
||||
Takes (message, list[executor_id]) and returns list[executor_id].
|
||||
|
||||
Returns:
|
||||
Self: The WorkflowBuilder instance for method chaining.
|
||||
|
||||
Note: If instances are provided for source and targets, they will be shared across
|
||||
all workflow instances created from the built Workflow. To avoid this, consider
|
||||
registering the executors and agents using `register_executor` and `register_agent`
|
||||
and referencing them by factory name for lazy initialization instead.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
@@ -710,6 +423,11 @@ class WorkflowBuilder:
|
||||
print(f"WorkerB processing: {task.data}")
|
||||
|
||||
|
||||
dispatcher = TaskDispatcher(id="dispatcher")
|
||||
worker_a = WorkerA(id="worker_a")
|
||||
worker_b = WorkerB(id="worker_b")
|
||||
|
||||
|
||||
# Select workers based on task priority
|
||||
def select_workers(task: Task, available: list[str]) -> list[str]:
|
||||
if task.priority == "high":
|
||||
@@ -718,40 +436,17 @@ class WorkflowBuilder:
|
||||
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="TaskDispatcher")
|
||||
.register_executor(lambda: TaskDispatcher(id="dispatcher"), name="TaskDispatcher")
|
||||
.register_executor(lambda: WorkerA(id="worker_a"), name="WorkerA")
|
||||
.register_executor(lambda: WorkerB(id="worker_b"), name="WorkerB")
|
||||
WorkflowBuilder(start_executor=dispatcher)
|
||||
.add_multi_selection_edge_group(
|
||||
"TaskDispatcher",
|
||||
["WorkerA", "WorkerB"],
|
||||
dispatcher,
|
||||
[worker_a, worker_b],
|
||||
selection_func=select_workers,
|
||||
)
|
||||
.build()
|
||||
)
|
||||
"""
|
||||
if (isinstance(source, str) and not all(isinstance(t, str) for t in targets)) or (
|
||||
not isinstance(source, str) and any(isinstance(t, str) for t in targets)
|
||||
):
|
||||
raise ValueError(
|
||||
"Both source and targets must be either registered factory names (str) or "
|
||||
"Executor/SupportsAgentRun instances."
|
||||
)
|
||||
|
||||
if isinstance(source, str) and all(isinstance(t, str) for t in targets):
|
||||
# Both are names; defer resolution to build time
|
||||
self._edge_registry.append(
|
||||
_MultiSelectionEdgeGroupRegistration(
|
||||
source=source,
|
||||
targets=list(targets), # type: ignore
|
||||
selection_func=selection_func,
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
# Both are Executor/SupportsAgentRun instances; wrap and add now
|
||||
source_exec = self._maybe_wrap_agent(source) # type: ignore
|
||||
target_execs = [self._maybe_wrap_agent(t) for t in targets] # type: ignore
|
||||
source_exec = self._maybe_wrap_agent(source)
|
||||
target_execs = [self._maybe_wrap_agent(t) for t in targets]
|
||||
source_id = self._add_executor(source_exec)
|
||||
target_ids = [self._add_executor(t) for t in target_execs]
|
||||
self._edge_groups.append(FanOutEdgeGroup(source_id, target_ids, selection_func)) # type: ignore[call-arg]
|
||||
@@ -760,8 +455,8 @@ class WorkflowBuilder:
|
||||
|
||||
def add_fan_in_edges(
|
||||
self,
|
||||
sources: Sequence[Executor | SupportsAgentRun | str],
|
||||
target: Executor | SupportsAgentRun | str,
|
||||
sources: Sequence[Executor | SupportsAgentRun],
|
||||
target: Executor | SupportsAgentRun,
|
||||
) -> Self:
|
||||
"""Add multiple edges from sources to a single target executor.
|
||||
|
||||
@@ -773,17 +468,12 @@ class WorkflowBuilder:
|
||||
types of the source executors.
|
||||
|
||||
Args:
|
||||
sources: A list of source executors or registered names of the source factories for the edges.
|
||||
target: The target executor or registered name of the target factory for the edges.
|
||||
sources: A list of source executors or agents for the edges.
|
||||
target: The target executor or agent for the edges.
|
||||
|
||||
Returns:
|
||||
Self: The WorkflowBuilder instance for method chaining.
|
||||
|
||||
Note: If instances are provided for sources and target, they will be shared across
|
||||
all workflow instances created from the built Workflow. To avoid this, consider
|
||||
registering the executors and agents using `register_executor` and `register_agent`
|
||||
and referencing them by factory name for lazy initialization instead.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
@@ -804,39 +494,21 @@ class WorkflowBuilder:
|
||||
await ctx.yield_output(f"Combined: {combined}")
|
||||
|
||||
|
||||
# Collect results from multiple producers
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="Producer1")
|
||||
.register_executor(lambda: Producer(id="prod_1"), name="Producer1")
|
||||
.register_executor(lambda: Producer(id="prod_2"), name="Producer2")
|
||||
.register_executor(lambda: Aggregator(id="agg"), name="Aggregator")
|
||||
.add_fan_in_edges(["Producer1", "Producer2"], "Aggregator")
|
||||
.build()
|
||||
)
|
||||
prod_1 = Producer(id="prod_1")
|
||||
prod_2 = Producer(id="prod_2")
|
||||
agg = Aggregator(id="agg")
|
||||
|
||||
workflow = WorkflowBuilder(start_executor=prod_1).add_fan_in_edges([prod_1, prod_2], agg).build()
|
||||
"""
|
||||
if (all(isinstance(s, str) for s in sources) and not isinstance(target, str)) or (
|
||||
not all(isinstance(s, str) for s in sources) and isinstance(target, str)
|
||||
):
|
||||
raise ValueError(
|
||||
"Both sources and target must be either registered factory names (str) or "
|
||||
"Executor/SupportsAgentRun instances."
|
||||
)
|
||||
|
||||
if all(isinstance(s, str) for s in sources) and isinstance(target, str):
|
||||
# Both are names; defer resolution to build time
|
||||
self._edge_registry.append(_FanInEdgeRegistration(sources=list(sources), target=target)) # type: ignore
|
||||
return self
|
||||
|
||||
# Both are Executor/SupportsAgentRun instances; wrap and add now
|
||||
source_execs = [self._maybe_wrap_agent(s) for s in sources] # type: ignore
|
||||
target_exec = self._maybe_wrap_agent(target) # type: ignore
|
||||
source_execs = [self._maybe_wrap_agent(s) for s in sources]
|
||||
target_exec = self._maybe_wrap_agent(target)
|
||||
source_ids = [self._add_executor(s) for s in source_execs]
|
||||
target_id = self._add_executor(target_exec)
|
||||
self._edge_groups.append(FanInEdgeGroup(source_ids, target_id)) # type: ignore[call-arg]
|
||||
|
||||
return self
|
||||
|
||||
def add_chain(self, executors: Sequence[Executor | SupportsAgentRun | str]) -> Self:
|
||||
def add_chain(self, executors: Sequence[Executor | SupportsAgentRun]) -> Self:
|
||||
"""Add a chain of executors to the workflow.
|
||||
|
||||
The output of each executor in the chain will be sent to the next executor in the chain.
|
||||
@@ -845,16 +517,11 @@ class WorkflowBuilder:
|
||||
Cycles in the chain are not allowed, meaning an executor cannot appear more than once in the chain.
|
||||
|
||||
Args:
|
||||
executors: A list of executors or registered names of the executor factories to chain together.
|
||||
executors: A list of executors or agents to chain together.
|
||||
|
||||
Returns:
|
||||
Self: The WorkflowBuilder instance for method chaining.
|
||||
|
||||
Note: If executor instances are provided, they will be shared across all workflow instances created
|
||||
from the built Workflow. To avoid this, consider registering the executors and agents using
|
||||
`register_executor` and `register_agent` and referencing them by factory name for lazy
|
||||
initialization instead.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
@@ -880,148 +547,37 @@ class WorkflowBuilder:
|
||||
await ctx.yield_output(f"Final: {text}")
|
||||
|
||||
|
||||
# Chain executors in sequence
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="step1")
|
||||
.register_executor(lambda: Step1(id="step1"), name="step1")
|
||||
.register_executor(lambda: Step2(id="step2"), name="step2")
|
||||
.register_executor(lambda: Step3(id="step3"), name="step3")
|
||||
.add_chain(["step1", "step2", "step3"])
|
||||
.build()
|
||||
)
|
||||
step1 = Step1(id="step1")
|
||||
step2 = Step2(id="step2")
|
||||
step3 = Step3(id="step3")
|
||||
|
||||
workflow = WorkflowBuilder(start_executor=step1).add_chain([step1, step2, step3]).build()
|
||||
"""
|
||||
if len(executors) < 2:
|
||||
raise ValueError("At least two executors are required to form a chain.")
|
||||
|
||||
if not all(isinstance(e, str) for e in executors) and any(isinstance(e, str) for e in executors):
|
||||
raise ValueError(
|
||||
"All executors in the chain must be either registered factory names (str) "
|
||||
"or Executor/SupportsAgentRun instances."
|
||||
)
|
||||
|
||||
if all(isinstance(e, str) for e in executors):
|
||||
# All are names; defer resolution to build time
|
||||
for i in range(len(executors) - 1):
|
||||
self.add_edge(executors[i], executors[i + 1])
|
||||
return self
|
||||
|
||||
# All are Executor/SupportsAgentRun instances; wrap and add now
|
||||
# Wrap each candidate first to ensure stable IDs before adding edges
|
||||
wrapped: list[Executor] = [self._maybe_wrap_agent(e) for e in executors] # type: ignore[arg-type]
|
||||
wrapped: list[Executor] = [self._maybe_wrap_agent(e) for e in executors]
|
||||
for i in range(len(wrapped) - 1):
|
||||
self.add_edge(wrapped[i], wrapped[i + 1])
|
||||
return self
|
||||
|
||||
def _set_start_executor(self, executor: Executor | SupportsAgentRun | str) -> None:
|
||||
def _set_start_executor(self, executor: Executor | SupportsAgentRun) -> None:
|
||||
"""Set the starting executor for the workflow (internal method).
|
||||
|
||||
Args:
|
||||
executor: The starting executor, which can be an Executor instance, SupportsAgentRun instance,
|
||||
or the name of a registered executor factory.
|
||||
executor: The starting executor, which can be an Executor instance or SupportsAgentRun instance.
|
||||
"""
|
||||
if self._start_executor is not None:
|
||||
start_id = self._start_executor if isinstance(self._start_executor, str) else self._start_executor.id
|
||||
logger.warning(f"Overwriting existing start executor: {start_id} for the workflow.")
|
||||
logger.warning(f"Overwriting existing start executor: {self._start_executor.id} for the workflow.")
|
||||
|
||||
if isinstance(executor, str):
|
||||
self._start_executor = executor
|
||||
else:
|
||||
wrapped = self._maybe_wrap_agent(executor) # type: ignore[arg-type]
|
||||
self._start_executor = wrapped
|
||||
# Ensure the start executor is present in the executor map so validation succeeds
|
||||
# even if no edges are added yet, or before edges wrap the same agent again.
|
||||
existing = self._executors.get(wrapped.id)
|
||||
if existing is not wrapped:
|
||||
self._add_executor(wrapped)
|
||||
|
||||
# Removed explicit set_agent_streaming() API; agents always stream updates.
|
||||
|
||||
def _resolve_edge_registry(self) -> tuple[Executor, dict[str, Executor], list[EdgeGroup]]:
|
||||
"""Resolve deferred edge registrations into executors and edge groups.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing:
|
||||
- The starting Executor instance.
|
||||
- A dictionary mapping registered factory names to resolved Executor instances.
|
||||
- A list of EdgeGroup instances representing the workflow edges composed of resolved executors.
|
||||
|
||||
Notes:
|
||||
Non-factory executors (i.e., those added directly) are not included in the returned list,
|
||||
as they are already part of the workflow builder's internal state.
|
||||
"""
|
||||
if not self._start_executor:
|
||||
raise ValueError(
|
||||
"Starting executor must be set via the start_executor constructor parameter before building."
|
||||
)
|
||||
|
||||
start_executor: Executor | None = None
|
||||
if isinstance(self._start_executor, Executor):
|
||||
start_executor = self._start_executor
|
||||
|
||||
# Maps registered factory names to created executor instances for edge resolution
|
||||
factory_name_to_instance: dict[str, Executor] = {}
|
||||
# Maps executor IDs to created executor instances to prevent duplicates
|
||||
executor_id_to_instance: dict[str, Executor] = {}
|
||||
deferred_edge_groups: list[EdgeGroup] = []
|
||||
for name, exec_factory in self._executor_registry.items():
|
||||
instance = exec_factory()
|
||||
if instance.id in executor_id_to_instance:
|
||||
raise ValueError(f"Executor with ID '{instance.id}' has already been registered.")
|
||||
if instance.id in self._executors:
|
||||
raise ValueError(f"Executor ID collision: An executor with ID '{instance.id}' already exists.")
|
||||
executor_id_to_instance[instance.id] = instance
|
||||
|
||||
if isinstance(self._start_executor, str) and name == self._start_executor:
|
||||
start_executor = instance
|
||||
|
||||
# All executors will get their own internal edge group for receiving system messages
|
||||
deferred_edge_groups.append(InternalEdgeGroup(instance.id)) # type: ignore[call-arg]
|
||||
factory_name_to_instance[name] = instance
|
||||
|
||||
def _get_executor(name: str) -> Executor:
|
||||
"""Helper to get executor by the registered name. Raises if not found."""
|
||||
if name not in factory_name_to_instance:
|
||||
raise ValueError(f"Factory '{name}' has not been registered.")
|
||||
return factory_name_to_instance[name]
|
||||
|
||||
for registration in self._edge_registry:
|
||||
match registration:
|
||||
case _EdgeRegistration(source, target, condition):
|
||||
source_exec: Executor = _get_executor(source)
|
||||
target_exec: Executor = _get_executor(target)
|
||||
deferred_edge_groups.append(SingleEdgeGroup(source_exec.id, target_exec.id, condition)) # type: ignore[call-arg]
|
||||
case _FanOutEdgeRegistration(source, targets):
|
||||
source_exec = _get_executor(source)
|
||||
target_execs = [_get_executor(t) for t in targets]
|
||||
deferred_edge_groups.append(FanOutEdgeGroup(source_exec.id, [t.id for t in target_execs])) # type: ignore[call-arg]
|
||||
case _SwitchCaseEdgeGroupRegistration(source, cases):
|
||||
source_exec = _get_executor(source)
|
||||
cases_converted: list[SwitchCaseEdgeGroupCase | SwitchCaseEdgeGroupDefault] = []
|
||||
for case in cases:
|
||||
if not isinstance(case.target, str):
|
||||
raise ValueError("Switch case target must be a registered factory name (str) if deferred.")
|
||||
target_exec = _get_executor(case.target)
|
||||
if isinstance(case, Default):
|
||||
cases_converted.append(SwitchCaseEdgeGroupDefault(target_id=target_exec.id))
|
||||
else:
|
||||
cases_converted.append(
|
||||
SwitchCaseEdgeGroupCase(condition=case.condition, target_id=target_exec.id)
|
||||
)
|
||||
deferred_edge_groups.append(SwitchCaseEdgeGroup(source_exec.id, cases_converted)) # type: ignore[call-arg]
|
||||
case _MultiSelectionEdgeGroupRegistration(source, targets, selection_func):
|
||||
source_exec = _get_executor(source)
|
||||
target_execs = [_get_executor(t) for t in targets]
|
||||
deferred_edge_groups.append(
|
||||
FanOutEdgeGroup(source_exec.id, [t.id for t in target_execs], selection_func) # type: ignore[call-arg]
|
||||
)
|
||||
case _FanInEdgeRegistration(sources, target):
|
||||
source_execs = [_get_executor(s) for s in sources]
|
||||
target_exec = _get_executor(target)
|
||||
deferred_edge_groups.append(FanInEdgeGroup([s.id for s in source_execs], target_exec.id)) # type: ignore[call-arg]
|
||||
if start_executor is None:
|
||||
raise ValueError("Failed to resolve starting executor from registered factories.")
|
||||
|
||||
return (start_executor, factory_name_to_instance, deferred_edge_groups)
|
||||
wrapped = self._maybe_wrap_agent(executor)
|
||||
self._start_executor = wrapped
|
||||
# Ensure the start executor is present in the executor map so validation succeeds
|
||||
# even if no edges are added yet, or before edges wrap the same agent again.
|
||||
existing = self._executors.get(wrapped.id)
|
||||
if existing is not wrapped:
|
||||
self._add_executor(wrapped)
|
||||
|
||||
def build(self) -> Workflow:
|
||||
"""Build and return the constructed workflow.
|
||||
@@ -1053,12 +609,9 @@ class WorkflowBuilder:
|
||||
await ctx.yield_output(text.upper())
|
||||
|
||||
|
||||
# Build and execute a workflow
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="MyExecutor")
|
||||
.register_executor(lambda: MyExecutor(id="executor"), name="MyExecutor")
|
||||
.build()
|
||||
)
|
||||
executor = MyExecutor(id="executor")
|
||||
|
||||
workflow = WorkflowBuilder(start_executor=executor).build()
|
||||
|
||||
# The workflow is now immutable and ready to run
|
||||
events = await workflow.run("hello")
|
||||
@@ -1074,23 +627,17 @@ class WorkflowBuilder:
|
||||
# Add workflow build started event
|
||||
span.add_event(OtelAttr.BUILD_STARTED)
|
||||
|
||||
# Resolve lazy edge registrations
|
||||
start_executor, deferred_executors, deferred_edge_groups = self._resolve_edge_registry()
|
||||
executors = self._executors | {exe.id: exe for exe in deferred_executors.values()}
|
||||
edge_groups = self._edge_groups + deferred_edge_groups
|
||||
output_executors = (
|
||||
[
|
||||
deferred_executors[factory_name].id
|
||||
for factory_name in self._output_executors
|
||||
if isinstance(factory_name, str)
|
||||
]
|
||||
+ [ex.id for ex in self._output_executors if isinstance(ex, Executor)]
|
||||
+ [
|
||||
resolve_agent_id(agent)
|
||||
for agent in self._output_executors
|
||||
if isinstance(agent, SupportsAgentRun)
|
||||
]
|
||||
)
|
||||
if not self._start_executor:
|
||||
raise ValueError(
|
||||
"Starting executor must be set via the start_executor constructor parameter before building."
|
||||
)
|
||||
|
||||
start_executor = self._start_executor
|
||||
executors = self._executors
|
||||
edge_groups = self._edge_groups
|
||||
output_executors = [ex.id for ex in self._output_executors if isinstance(ex, Executor)] + [
|
||||
resolve_agent_id(agent) for agent in self._output_executors if isinstance(agent, SupportsAgentRun)
|
||||
]
|
||||
|
||||
# Perform validation before creating the workflow
|
||||
validate_workflow_graph(
|
||||
|
||||
@@ -670,16 +670,20 @@ class TestWorkflowAgent:
|
||||
return ResponseStream(_iter(), finalizer=AgentResponse.from_updates)
|
||||
|
||||
@executor
|
||||
async def start_executor(messages: list[ChatMessage], ctx: WorkflowContext[AgentExecutorRequest, str]) -> None:
|
||||
async def start_exec(messages: list[ChatMessage], ctx: WorkflowContext[AgentExecutorRequest, str]) -> None:
|
||||
await ctx.yield_output("Start output")
|
||||
await ctx.send_message(AgentExecutorRequest(messages=messages, should_respond=True))
|
||||
|
||||
# Build workflow: start -> agent1 (no output) -> agent2 (output_response=True)
|
||||
builder = WorkflowBuilder(start_executor="start", output_executors=["start", "agent2"])
|
||||
builder.register_executor(lambda: start_executor, "start")
|
||||
builder.register_agent(lambda: MockAgent("agent1", "Agent1 output - should NOT appear"), "agent1")
|
||||
builder.register_agent(lambda: MockAgent("agent2", "Agent2 output - SHOULD appear"), "agent2")
|
||||
workflow = builder.add_edge("start", "agent1").add_edge("agent1", "agent2").build()
|
||||
agent1 = MockAgent("agent1", "Agent1 output - should NOT appear")
|
||||
agent2 = MockAgent("agent2", "Agent2 output - SHOULD appear")
|
||||
|
||||
# Build workflow: start -> agent1 (no output) -> agent2 (output visible)
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor=start_exec, output_executors=[start_exec, agent2])
|
||||
.add_edge(start_exec, agent1)
|
||||
.add_edge(agent1, agent2)
|
||||
.build()
|
||||
)
|
||||
|
||||
agent = WorkflowAgent(workflow=workflow, name="Test Agent")
|
||||
result = await agent.run("Test input")
|
||||
@@ -754,17 +758,13 @@ class TestWorkflowAgent:
|
||||
return ResponseStream(_iter(), finalizer=AgentResponse.from_updates)
|
||||
|
||||
@executor
|
||||
async def start_executor(messages: list[ChatMessage], ctx: WorkflowContext[AgentExecutorRequest]) -> None:
|
||||
async def start_exec(messages: list[ChatMessage], ctx: WorkflowContext[AgentExecutorRequest]) -> None:
|
||||
await ctx.send_message(AgentExecutorRequest(messages=messages, should_respond=True))
|
||||
|
||||
mock_agent = MockAgent("agent", "Unique response text")
|
||||
|
||||
# Build workflow with single agent
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="start")
|
||||
.register_executor(lambda: start_executor, "start")
|
||||
.register_agent(lambda: MockAgent("agent", "Unique response text"), "agent")
|
||||
.add_edge("start", "agent")
|
||||
.build()
|
||||
)
|
||||
workflow = WorkflowBuilder(start_executor=start_exec).add_edge(start_exec, mock_agent).build()
|
||||
|
||||
agent = WorkflowAgent(workflow=workflow, name="Test Agent")
|
||||
result = await agent.run("Test input")
|
||||
|
||||
@@ -134,322 +134,65 @@ def test_add_agent_duplicate_id_raises_error():
|
||||
builder.add_edge(agent1, agent2).build()
|
||||
|
||||
|
||||
# Tests for new executor registration patterns
|
||||
def test_fan_out_edges_with_direct_instances():
|
||||
"""Test fan-out edges with direct executor instances."""
|
||||
source = MockExecutor(id="Source")
|
||||
target1 = MockExecutor(id="Target1")
|
||||
target2 = MockExecutor(id="Target2")
|
||||
|
||||
workflow = WorkflowBuilder(start_executor=source).add_fan_out_edges(source, [target1, target2]).build()
|
||||
|
||||
def test_register_executor_basic():
|
||||
"""Test basic executor registration with lazy initialization."""
|
||||
builder = WorkflowBuilder(start_executor="TestExecutor")
|
||||
|
||||
# Register an executor factory - ID must match the registered name
|
||||
result = builder.register_executor(lambda: MockExecutor(id="TestExecutor"), name="TestExecutor")
|
||||
|
||||
# Verify that register returns the builder for chaining
|
||||
assert result is builder
|
||||
|
||||
# Build workflow and verify executor is instantiated
|
||||
workflow = builder.build()
|
||||
assert "TestExecutor" in workflow.executors
|
||||
assert isinstance(workflow.executors["TestExecutor"], MockExecutor)
|
||||
|
||||
|
||||
def test_register_multiple_executors():
|
||||
"""Test registering multiple executors and connecting them with edges."""
|
||||
builder = WorkflowBuilder(start_executor="ExecutorA")
|
||||
|
||||
# Register multiple executors - IDs must match registered names
|
||||
builder.register_executor(lambda: MockExecutor(id="ExecutorA"), name="ExecutorA")
|
||||
builder.register_executor(lambda: MockExecutor(id="ExecutorB"), name="ExecutorB")
|
||||
builder.register_executor(lambda: MockExecutor(id="ExecutorC"), name="ExecutorC")
|
||||
|
||||
# Build workflow with edges using registered names
|
||||
workflow = builder.add_edge("ExecutorA", "ExecutorB").add_edge("ExecutorB", "ExecutorC").build()
|
||||
|
||||
# Verify all executors are present
|
||||
assert "ExecutorA" in workflow.executors
|
||||
assert "ExecutorB" in workflow.executors
|
||||
assert "ExecutorC" in workflow.executors
|
||||
assert workflow.start_executor_id == "ExecutorA"
|
||||
|
||||
|
||||
def test_register_with_multiple_names():
|
||||
"""Test registering the same factory function under multiple names."""
|
||||
builder = WorkflowBuilder(start_executor="ExecutorA")
|
||||
|
||||
# Register same executor factory under multiple names
|
||||
# Note: Each call creates a new instance, so IDs won't conflict
|
||||
counter = {"val": 0}
|
||||
|
||||
def make_executor():
|
||||
counter["val"] += 1
|
||||
return MockExecutor(id="ExecutorA" if counter["val"] == 1 else "ExecutorB")
|
||||
|
||||
builder.register_executor(make_executor, name=["ExecutorA", "ExecutorB"])
|
||||
|
||||
# Set up workflow
|
||||
workflow = builder.add_edge("ExecutorA", "ExecutorB").build()
|
||||
|
||||
# Verify both executors are present
|
||||
assert "ExecutorA" in workflow.executors
|
||||
assert "ExecutorB" in workflow.executors
|
||||
assert workflow.start_executor_id == "ExecutorA"
|
||||
|
||||
|
||||
def test_register_duplicate_name_raises_error():
|
||||
"""Test that registering duplicate names raises an error."""
|
||||
builder = WorkflowBuilder(start_executor="MyExecutor")
|
||||
|
||||
# Register first executor
|
||||
builder.register_executor(lambda: MockExecutor(id="executor_1"), name="MyExecutor")
|
||||
|
||||
# Registering second executor with same name should raise ValueError
|
||||
with pytest.raises(ValueError, match="already registered"):
|
||||
builder.register_executor(lambda: MockExecutor(id="executor_2"), name="MyExecutor")
|
||||
|
||||
|
||||
def test_register_duplicate_id_raises_error():
|
||||
"""Test that registering duplicate id raises an error."""
|
||||
builder = WorkflowBuilder(start_executor="MyExecutor1")
|
||||
|
||||
# Register first executor
|
||||
builder.register_executor(lambda: MockExecutor(id="executor"), name="MyExecutor1")
|
||||
builder.register_executor(lambda: MockExecutor(id="executor"), name="MyExecutor2")
|
||||
|
||||
# Registering second executor with same ID should raise ValueError
|
||||
with pytest.raises(ValueError, match="Executor with ID 'executor' has already been registered."):
|
||||
builder.build()
|
||||
|
||||
|
||||
def test_register_agent_basic():
|
||||
"""Test basic agent registration with lazy initialization."""
|
||||
builder = WorkflowBuilder(start_executor="TestAgent")
|
||||
|
||||
# Register an agent factory
|
||||
result = builder.register_agent(lambda: DummyAgent(id="agent_test", name="test_agent"), name="TestAgent")
|
||||
|
||||
# Verify that register_agent returns the builder for chaining
|
||||
assert result is builder
|
||||
|
||||
# Build workflow and verify agent is wrapped in AgentExecutor
|
||||
workflow = builder.build()
|
||||
assert "test_agent" in workflow.executors
|
||||
assert isinstance(workflow.executors["test_agent"], AgentExecutor)
|
||||
|
||||
|
||||
def test_register_agent_with_thread():
|
||||
"""Test registering an agent with a custom thread."""
|
||||
builder = WorkflowBuilder(start_executor="ThreadedAgent")
|
||||
custom_thread = AgentThread()
|
||||
|
||||
# Register agent with custom thread
|
||||
builder.register_agent(
|
||||
lambda: DummyAgent(id="agent_with_thread", name="threaded_agent"),
|
||||
name="ThreadedAgent",
|
||||
agent_thread=custom_thread,
|
||||
)
|
||||
|
||||
# Build workflow and verify agent executor configuration
|
||||
workflow = builder.build()
|
||||
executor = workflow.executors["threaded_agent"]
|
||||
|
||||
assert isinstance(executor, AgentExecutor)
|
||||
assert executor.id == "threaded_agent"
|
||||
assert executor._agent_thread is custom_thread # type: ignore
|
||||
|
||||
|
||||
def test_register_agent_duplicate_name_raises_error():
|
||||
"""Test that registering agents with duplicate names raises an error."""
|
||||
builder = WorkflowBuilder(start_executor="MyAgent")
|
||||
|
||||
# Register first agent
|
||||
builder.register_agent(lambda: DummyAgent(id="agent1", name="first"), name="MyAgent")
|
||||
|
||||
# Registering second agent with same name should raise ValueError
|
||||
with pytest.raises(ValueError, match="already registered"):
|
||||
builder.register_agent(lambda: DummyAgent(id="agent2", name="second"), name="MyAgent")
|
||||
|
||||
|
||||
def test_register_and_add_edge_with_strings():
|
||||
"""Test that registered executors can be connected using string names."""
|
||||
builder = WorkflowBuilder(start_executor="Source")
|
||||
|
||||
# Register executors
|
||||
builder.register_executor(lambda: MockExecutor(id="source"), name="Source")
|
||||
builder.register_executor(lambda: MockExecutor(id="target"), name="Target")
|
||||
|
||||
# Add edge using string names
|
||||
workflow = builder.add_edge("Source", "Target").build()
|
||||
|
||||
# Verify edge is created correctly
|
||||
assert workflow.start_executor_id == "source"
|
||||
assert "source" in workflow.executors
|
||||
assert "target" in workflow.executors
|
||||
|
||||
|
||||
def test_register_agent_and_add_edge_with_strings():
|
||||
"""Test that registered agents can be connected using string names."""
|
||||
builder = WorkflowBuilder(start_executor="Writer")
|
||||
|
||||
# Register agents
|
||||
builder.register_agent(lambda: DummyAgent(id="writer_id", name="writer"), name="Writer")
|
||||
builder.register_agent(lambda: DummyAgent(id="reviewer_id", name="reviewer"), name="Reviewer")
|
||||
|
||||
# Add edge using string names
|
||||
workflow = builder.add_edge("Writer", "Reviewer").build()
|
||||
|
||||
# Verify edge is created correctly
|
||||
assert workflow.start_executor_id == "writer"
|
||||
assert "writer" in workflow.executors
|
||||
assert "reviewer" in workflow.executors
|
||||
assert all(isinstance(e, AgentExecutor) for e in workflow.executors.values())
|
||||
|
||||
|
||||
def test_register_with_fan_out_edges():
|
||||
"""Test using registered names with fan-out edge groups."""
|
||||
builder = WorkflowBuilder(start_executor="Source")
|
||||
|
||||
# Register executors - IDs must match registered names
|
||||
builder.register_executor(lambda: MockExecutor(id="Source"), name="Source")
|
||||
builder.register_executor(lambda: MockExecutor(id="Target1"), name="Target1")
|
||||
builder.register_executor(lambda: MockExecutor(id="Target2"), name="Target2")
|
||||
|
||||
# Add fan-out edges using registered names
|
||||
workflow = builder.add_fan_out_edges("Source", ["Target1", "Target2"]).build()
|
||||
|
||||
# Verify all executors are present
|
||||
assert "Source" in workflow.executors
|
||||
assert "Target1" in workflow.executors
|
||||
assert "Target2" in workflow.executors
|
||||
|
||||
|
||||
def test_register_with_fan_in_edges():
|
||||
"""Test using registered names with fan-in edge groups."""
|
||||
builder = WorkflowBuilder(start_executor="Source1")
|
||||
def test_fan_in_edges_with_direct_instances():
|
||||
"""Test fan-in edges with direct executor instances."""
|
||||
source1 = MockExecutor(id="Source1")
|
||||
source2 = MockExecutor(id="Source2")
|
||||
aggregator = MockAggregator(id="Aggregator")
|
||||
|
||||
# Register executors - IDs must match registered names
|
||||
builder.register_executor(lambda: MockExecutor(id="Source1"), name="Source1")
|
||||
builder.register_executor(lambda: MockExecutor(id="Source2"), name="Source2")
|
||||
builder.register_executor(lambda: MockAggregator(id="Aggregator"), name="Aggregator")
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor=source1)
|
||||
.add_edge(source1, source2)
|
||||
.add_fan_in_edges([source1, source2], aggregator)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Add fan-in edges using registered names
|
||||
# Both Source1 and Source2 need to be reachable, so connect Source1 to Source2
|
||||
workflow = builder.add_edge("Source1", "Source2").add_fan_in_edges(["Source1", "Source2"], "Aggregator").build()
|
||||
|
||||
# Verify all executors are present
|
||||
assert "Source1" in workflow.executors
|
||||
assert "Source2" in workflow.executors
|
||||
assert "Aggregator" in workflow.executors
|
||||
|
||||
|
||||
def test_register_with_chain():
|
||||
"""Test using registered names with add_chain."""
|
||||
builder = WorkflowBuilder(start_executor="Step1")
|
||||
def test_chain_with_direct_instances():
|
||||
"""Test add_chain with direct executor instances."""
|
||||
step1 = MockExecutor(id="Step1")
|
||||
step2 = MockExecutor(id="Step2")
|
||||
step3 = MockExecutor(id="Step3")
|
||||
|
||||
# Register executors - IDs must match registered names
|
||||
builder.register_executor(lambda: MockExecutor(id="Step1"), name="Step1")
|
||||
builder.register_executor(lambda: MockExecutor(id="Step2"), name="Step2")
|
||||
builder.register_executor(lambda: MockExecutor(id="Step3"), name="Step3")
|
||||
workflow = WorkflowBuilder(start_executor=step1).add_chain([step1, step2, step3]).build()
|
||||
|
||||
# Add chain using registered names
|
||||
workflow = builder.add_chain(["Step1", "Step2", "Step3"]).build()
|
||||
|
||||
# Verify all executors are present
|
||||
assert "Step1" in workflow.executors
|
||||
assert "Step2" in workflow.executors
|
||||
assert "Step3" in workflow.executors
|
||||
assert workflow.start_executor_id == "Step1"
|
||||
|
||||
|
||||
def test_register_factory_called_only_once():
|
||||
"""Test that registered factory functions are called only during build."""
|
||||
call_count = 0
|
||||
|
||||
def factory():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return MockExecutor(id="Test")
|
||||
|
||||
builder = WorkflowBuilder(start_executor="Test")
|
||||
builder.register_executor(factory, name="Test")
|
||||
|
||||
# Factory should not be called yet
|
||||
assert call_count == 0
|
||||
|
||||
# Factory should still not be called
|
||||
assert call_count == 0
|
||||
|
||||
# Build workflow
|
||||
workflow = builder.build()
|
||||
|
||||
# Factory should now be called exactly once
|
||||
assert call_count == 1
|
||||
assert "Test" in workflow.executors
|
||||
|
||||
|
||||
def test_mixing_eager_and_lazy_initialization_error():
|
||||
"""Test that mixing eager executor instances with lazy string names raises appropriate error."""
|
||||
builder = WorkflowBuilder(start_executor="Lazy")
|
||||
|
||||
# Create an eager executor instance
|
||||
eager_executor = MockExecutor(id="eager")
|
||||
|
||||
# Register a lazy executor
|
||||
builder.register_executor(lambda: MockExecutor(id="Lazy"), name="Lazy")
|
||||
|
||||
# Mixing eager and lazy should raise an error during add_edge
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
r"Both source and target must be either registered factory names \(str\) "
|
||||
r"or Executor/SupportsAgentRun instances\."
|
||||
),
|
||||
):
|
||||
builder.add_edge(eager_executor, "Lazy")
|
||||
|
||||
|
||||
def test_register_with_condition():
|
||||
"""Test adding edges with conditions using registered names."""
|
||||
builder = WorkflowBuilder(start_executor="Source")
|
||||
def test_add_edge_with_condition():
|
||||
"""Test adding edges with conditions using direct executor instances."""
|
||||
source = MockExecutor(id="Source")
|
||||
target = MockExecutor(id="Target")
|
||||
|
||||
def condition_func(msg: MockMessage) -> bool:
|
||||
return msg.data > 0
|
||||
|
||||
# Register executors - IDs must match registered names
|
||||
builder.register_executor(lambda: MockExecutor(id="Source"), name="Source")
|
||||
builder.register_executor(lambda: MockExecutor(id="Target"), name="Target")
|
||||
workflow = WorkflowBuilder(start_executor=source).add_edge(source, target, condition=condition_func).build()
|
||||
|
||||
# Add edge with condition
|
||||
workflow = builder.add_edge("Source", "Target", condition=condition_func).build()
|
||||
|
||||
# Verify workflow is built correctly
|
||||
assert "Source" in workflow.executors
|
||||
assert "Target" in workflow.executors
|
||||
|
||||
|
||||
def test_register_agent_creates_unique_instances():
|
||||
"""Test that registered agent factories create new instances on each build."""
|
||||
instance_ids: list[int] = []
|
||||
|
||||
def agent_factory() -> DummyAgent:
|
||||
agent = DummyAgent(id=f"agent_{len(instance_ids)}", name="test")
|
||||
instance_ids.append(id(agent))
|
||||
return agent
|
||||
|
||||
# Build first workflow
|
||||
builder1 = WorkflowBuilder(start_executor="Agent")
|
||||
builder1.register_agent(agent_factory, name="Agent")
|
||||
_ = builder1.build()
|
||||
|
||||
# Build second workflow
|
||||
builder2 = WorkflowBuilder(start_executor="Agent")
|
||||
builder2.register_agent(agent_factory, name="Agent")
|
||||
_ = builder2.build()
|
||||
|
||||
# Verify that two different agent instances were created
|
||||
assert len(instance_ids) == 2
|
||||
assert instance_ids[0] != instance_ids[1]
|
||||
|
||||
|
||||
# region with_output_from tests
|
||||
|
||||
|
||||
@@ -488,14 +231,17 @@ def test_with_output_from_with_agent_instances():
|
||||
assert workflow._output_executors == ["reviewer"] # type: ignore
|
||||
|
||||
|
||||
def test_with_output_from_with_registered_names():
|
||||
"""Test with_output_from with registered factory names (strings)."""
|
||||
builder = WorkflowBuilder(start_executor="ExecutorAFactory", output_executors=["ExecutorBFactory"])
|
||||
builder.register_executor(lambda: MockExecutor(id="ExecutorA"), name="ExecutorAFactory")
|
||||
builder.register_executor(lambda: MockExecutor(id="ExecutorB"), name="ExecutorBFactory")
|
||||
workflow = builder.add_edge("ExecutorAFactory", "ExecutorBFactory").build()
|
||||
def test_with_output_from_with_executor_instances_by_id():
|
||||
"""Test with_output_from with direct executor instances resolves to executor IDs."""
|
||||
executor_a = MockExecutor(id="ExecutorA")
|
||||
executor_b = MockExecutor(id="ExecutorB")
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor=executor_a, output_executors=[executor_b])
|
||||
.add_edge(executor_a, executor_b)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Verify that the workflow was built with the correct output executors
|
||||
assert workflow._output_executors == ["ExecutorB"] # type: ignore
|
||||
|
||||
|
||||
@@ -531,14 +277,17 @@ def test_with_output_from_can_be_set_to_different_value():
|
||||
assert workflow._output_executors == ["executor_b"] # type: ignore
|
||||
|
||||
|
||||
def test_with_output_from_with_registered_agents():
|
||||
"""Test with_output_from with registered agent factory names."""
|
||||
builder = WorkflowBuilder(start_executor="WriterAgent", output_executors=["ReviewerAgent"])
|
||||
builder.register_agent(lambda: DummyAgent(id="agent1", name="writer"), name="WriterAgent")
|
||||
builder.register_agent(lambda: DummyAgent(id="agent2", name="reviewer"), name="ReviewerAgent")
|
||||
workflow = builder.add_edge("WriterAgent", "ReviewerAgent").build()
|
||||
def test_with_output_from_with_agent_instances_resolves_name():
|
||||
"""Test with_output_from with agent instances resolves to agent names."""
|
||||
agent_writer = DummyAgent(id="agent1", name="writer")
|
||||
agent_reviewer = DummyAgent(id="agent2", name="reviewer")
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor=agent_writer, output_executors=[agent_reviewer])
|
||||
.add_edge(agent_writer, agent_reviewer)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Verify that the workflow was built with the agent's resolved name
|
||||
assert workflow._output_executors == ["reviewer"] # type: ignore
|
||||
|
||||
|
||||
|
||||
@@ -474,8 +474,9 @@ async def test_message_trace_context_serialization(span_exporter: InMemorySpanEx
|
||||
async def test_workflow_build_error_tracing(span_exporter: InMemorySpanExporter) -> None:
|
||||
"""Test that build errors are properly recorded in build spans."""
|
||||
|
||||
# Test validation error by referencing a non-existent start executor
|
||||
builder = WorkflowBuilder(start_executor="NonExistent")
|
||||
# Create a valid builder, then clear the start executor to trigger a build-time ValueError
|
||||
builder = WorkflowBuilder(start_executor=MockExecutor(id="mock"))
|
||||
builder._start_executor = None # type: ignore[assignment]
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
builder.build()
|
||||
|
||||
+15
-20
@@ -148,33 +148,26 @@ class DeclarativeWorkflowBuilder:
|
||||
if self._validate:
|
||||
self._validate_workflow(actions)
|
||||
|
||||
# Use a placeholder for start_executor; it will be overwritten below via _set_start_executor
|
||||
# Create a stable entry node as the start executor, then wire it to the first action.
|
||||
# This avoids needing a placeholder since the entry executor isn't known until after
|
||||
# _create_executors_for_actions runs (which itself needs the builder to add edges).
|
||||
entry_node = JoinExecutor({"kind": "Entry"}, id="_workflow_entry")
|
||||
self._executors[entry_node.id] = entry_node
|
||||
builder = WorkflowBuilder(
|
||||
start_executor="_declarative_placeholder",
|
||||
start_executor=entry_node,
|
||||
name=self._workflow_id,
|
||||
checkpoint_storage=self._checkpoint_storage,
|
||||
)
|
||||
|
||||
# First pass: create all executors
|
||||
entry_executor = self._create_executors_for_actions(actions, builder)
|
||||
# Create all executors and wire sequential edges
|
||||
first_executor = self._create_executors_for_actions(actions, builder)
|
||||
|
||||
# Set the entry point
|
||||
if entry_executor:
|
||||
# Check if entry is a control flow structure (If/Switch)
|
||||
if getattr(entry_executor, "_is_if_structure", False) or getattr(
|
||||
entry_executor, "_is_switch_structure", False
|
||||
):
|
||||
# Create an entry passthrough node and wire to the structure's branches
|
||||
entry_node = JoinExecutor({"kind": "Entry"}, id="_workflow_entry")
|
||||
self._executors[entry_node.id] = entry_node
|
||||
builder._set_start_executor(entry_node)
|
||||
# Use _add_sequential_edge which knows how to wire to structures
|
||||
self._add_sequential_edge(builder, entry_node, entry_executor)
|
||||
else:
|
||||
builder._set_start_executor(entry_executor)
|
||||
else:
|
||||
if not first_executor:
|
||||
raise ValueError("Failed to create any executors from actions.")
|
||||
|
||||
# Wire entry node to the first action (handles both regular and control flow targets)
|
||||
self._add_sequential_edge(builder, entry_node, first_executor)
|
||||
|
||||
# Resolve pending gotos (back-edges for loops, forward-edges for jumps)
|
||||
self._resolve_pending_gotos(builder)
|
||||
|
||||
@@ -223,9 +216,11 @@ class DeclarativeWorkflowBuilder:
|
||||
for action_def in actions:
|
||||
kind = action_def.get("kind", "")
|
||||
|
||||
# Check for duplicate explicit IDs
|
||||
# Check for duplicate or reserved explicit IDs
|
||||
explicit_id = action_def.get("id")
|
||||
if explicit_id:
|
||||
if explicit_id == "_workflow_entry":
|
||||
raise ValueError(f"Action ID '{explicit_id}' is reserved for internal use. Choose a different ID.")
|
||||
if explicit_id in seen_ids:
|
||||
raise ValueError(f"Duplicate action ID '{explicit_id}'. Action IDs must be unique.")
|
||||
seen_ids.add(explicit_id)
|
||||
|
||||
@@ -2012,7 +2012,9 @@ class TestBuilderControlFlowCreation:
|
||||
# Create builder with minimal yaml definition
|
||||
yaml_def = {"name": "test_workflow", "actions": []}
|
||||
graph_builder = DeclarativeWorkflowBuilder(yaml_def)
|
||||
wb = WorkflowBuilder(start_executor="dummy")
|
||||
from agent_framework_declarative._workflows._executors_control_flow import JoinExecutor
|
||||
|
||||
wb = WorkflowBuilder(start_executor=JoinExecutor({"kind": "Dummy"}, id="dummy"))
|
||||
|
||||
action_def = {
|
||||
"kind": "GotoAction",
|
||||
@@ -2036,7 +2038,9 @@ class TestBuilderControlFlowCreation:
|
||||
|
||||
yaml_def = {"name": "test_workflow", "actions": []}
|
||||
graph_builder = DeclarativeWorkflowBuilder(yaml_def)
|
||||
wb = WorkflowBuilder(start_executor="dummy")
|
||||
from agent_framework_declarative._workflows._executors_control_flow import JoinExecutor
|
||||
|
||||
wb = WorkflowBuilder(start_executor=JoinExecutor({"kind": "Dummy"}, id="dummy"))
|
||||
|
||||
action_def = {
|
||||
"kind": "GotoAction",
|
||||
@@ -2056,7 +2060,9 @@ class TestBuilderControlFlowCreation:
|
||||
|
||||
yaml_def = {"name": "test_workflow", "actions": []}
|
||||
graph_builder = DeclarativeWorkflowBuilder(yaml_def)
|
||||
wb = WorkflowBuilder(start_executor="dummy")
|
||||
from agent_framework_declarative._workflows._executors_control_flow import JoinExecutor
|
||||
|
||||
wb = WorkflowBuilder(start_executor=JoinExecutor({"kind": "Dummy"}, id="dummy"))
|
||||
|
||||
action_def = {
|
||||
"kind": "GotoAction",
|
||||
@@ -2094,7 +2100,9 @@ class TestBuilderControlFlowCreation:
|
||||
|
||||
yaml_def = {"name": "test_workflow", "actions": []}
|
||||
graph_builder = DeclarativeWorkflowBuilder(yaml_def)
|
||||
wb = WorkflowBuilder(start_executor="dummy")
|
||||
from agent_framework_declarative._workflows._executors_control_flow import JoinExecutor
|
||||
|
||||
wb = WorkflowBuilder(start_executor=JoinExecutor({"kind": "Dummy"}, id="dummy"))
|
||||
|
||||
# Create a mock loop_next executor
|
||||
loop_next = ForeachNextExecutor(
|
||||
@@ -2124,7 +2132,9 @@ class TestBuilderControlFlowCreation:
|
||||
|
||||
yaml_def = {"name": "test_workflow", "actions": []}
|
||||
graph_builder = DeclarativeWorkflowBuilder(yaml_def)
|
||||
wb = WorkflowBuilder(start_executor="dummy")
|
||||
from agent_framework_declarative._workflows._executors_control_flow import JoinExecutor
|
||||
|
||||
wb = WorkflowBuilder(start_executor=JoinExecutor({"kind": "Dummy"}, id="dummy"))
|
||||
|
||||
action_def = {
|
||||
"kind": "BreakLoop",
|
||||
@@ -2149,7 +2159,9 @@ class TestBuilderControlFlowCreation:
|
||||
|
||||
yaml_def = {"name": "test_workflow", "actions": []}
|
||||
graph_builder = DeclarativeWorkflowBuilder(yaml_def)
|
||||
wb = WorkflowBuilder(start_executor="dummy")
|
||||
from agent_framework_declarative._workflows._executors_control_flow import JoinExecutor
|
||||
|
||||
wb = WorkflowBuilder(start_executor=JoinExecutor({"kind": "Dummy"}, id="dummy"))
|
||||
|
||||
# Create a mock loop_next executor
|
||||
loop_next = ForeachNextExecutor(
|
||||
@@ -2179,7 +2191,9 @@ class TestBuilderControlFlowCreation:
|
||||
|
||||
yaml_def = {"name": "test_workflow", "actions": []}
|
||||
graph_builder = DeclarativeWorkflowBuilder(yaml_def)
|
||||
wb = WorkflowBuilder(start_executor="dummy")
|
||||
from agent_framework_declarative._workflows._executors_control_flow import JoinExecutor
|
||||
|
||||
wb = WorkflowBuilder(start_executor=JoinExecutor({"kind": "Dummy"}, id="dummy"))
|
||||
|
||||
action_def = {
|
||||
"kind": "ContinueLoop",
|
||||
@@ -2203,7 +2217,9 @@ class TestBuilderEdgeWiring:
|
||||
|
||||
yaml_def = {"name": "test_workflow", "actions": []}
|
||||
graph_builder = DeclarativeWorkflowBuilder(yaml_def)
|
||||
wb = WorkflowBuilder(start_executor="dummy")
|
||||
from agent_framework_declarative._workflows._executors_control_flow import JoinExecutor
|
||||
|
||||
wb = WorkflowBuilder(start_executor=JoinExecutor({"kind": "Dummy"}, id="dummy"))
|
||||
|
||||
# Create a mock source executor
|
||||
source = SendActivityExecutor({"kind": "SendActivity", "activity": {"text": "test"}}, id="source")
|
||||
@@ -2236,7 +2252,9 @@ class TestBuilderEdgeWiring:
|
||||
|
||||
yaml_def = {"name": "test_workflow", "actions": []}
|
||||
graph_builder = DeclarativeWorkflowBuilder(yaml_def)
|
||||
wb = WorkflowBuilder(start_executor="dummy")
|
||||
from agent_framework_declarative._workflows._executors_control_flow import JoinExecutor
|
||||
|
||||
wb = WorkflowBuilder(start_executor=JoinExecutor({"kind": "Dummy"}, id="dummy"))
|
||||
|
||||
source = SendActivityExecutor({"kind": "SendActivity", "activity": {"text": "source"}}, id="source")
|
||||
target = SendActivityExecutor({"kind": "SendActivity", "activity": {"text": "target"}}, id="target")
|
||||
|
||||
@@ -223,11 +223,11 @@ class TestGraphWorkflowCheckpointing:
|
||||
builder = DeclarativeWorkflowBuilder(yaml_def)
|
||||
_workflow = builder.build() # noqa: F841
|
||||
|
||||
# Verify multiple executors were created
|
||||
# Verify multiple executors were created (+ _workflow_entry node)
|
||||
assert "step1" in builder._executors
|
||||
assert "step2" in builder._executors
|
||||
assert "step3" in builder._executors
|
||||
assert len(builder._executors) == 3
|
||||
assert len(builder._executors) == 4
|
||||
|
||||
def test_workflow_executor_connectivity(self):
|
||||
"""Test that executors are properly connected in sequence."""
|
||||
@@ -243,8 +243,8 @@ class TestGraphWorkflowCheckpointing:
|
||||
builder = DeclarativeWorkflowBuilder(yaml_def)
|
||||
workflow = builder.build()
|
||||
|
||||
# Verify all executors exist
|
||||
assert len(builder._executors) == 3
|
||||
# Verify all executors exist (+ _workflow_entry node)
|
||||
assert len(builder._executors) == 4
|
||||
|
||||
# Verify the workflow can be inspected
|
||||
assert workflow is not None
|
||||
|
||||
@@ -29,8 +29,7 @@ parallel workflow with:
|
||||
- a default aggregator that combines all agent conversations and completes the workflow
|
||||
|
||||
Notes:
|
||||
- Participants can be provided as SupportsAgentRun or Executor instances via `participants=[...]`,
|
||||
or as factories returning SupportsAgentRun or Executor via `participant_factories=[...]`.
|
||||
- Participants can be provided as SupportsAgentRun or Executor instances via `participants=[...]`.
|
||||
- A custom aggregator can be provided as:
|
||||
- an Executor instance (it should handle list[AgentExecutorResponse],
|
||||
yield output), or
|
||||
@@ -187,11 +186,8 @@ class ConcurrentBuilder:
|
||||
r"""High-level builder for concurrent agent workflows.
|
||||
|
||||
- `participants=[...]` accepts a list of SupportsAgentRun (recommended) or Executor.
|
||||
- `participant_factories=[...]` accepts a list of factories for SupportsAgentRun (recommended)
|
||||
or Executor factories
|
||||
- `build()` wires: dispatcher -> fan-out -> participants -> fan-in -> aggregator.
|
||||
- `with_aggregator(...)` overrides the default aggregator with an Executor or callback.
|
||||
- `register_aggregator(...)` accepts a factory for an Executor as custom aggregator.
|
||||
|
||||
Usage:
|
||||
|
||||
@@ -202,9 +198,6 @@ class ConcurrentBuilder:
|
||||
# Minimal: use default aggregator (returns list[ChatMessage])
|
||||
workflow = ConcurrentBuilder(participants=[agent1, agent2, agent3]).build()
|
||||
|
||||
# With agent factories
|
||||
workflow = ConcurrentBuilder(participant_factories=[create_agent1, create_agent2, create_agent3]).build()
|
||||
|
||||
|
||||
# Custom aggregator via callback (sync or async). The callback receives
|
||||
# list[AgentExecutorResponse] and its return value becomes the workflow's output.
|
||||
@@ -215,20 +208,6 @@ class ConcurrentBuilder:
|
||||
workflow = ConcurrentBuilder(participants=[agent1, agent2, agent3]).with_aggregator(summarize).build()
|
||||
|
||||
|
||||
# Custom aggregator via a factory
|
||||
class MyAggregator(Executor):
|
||||
@handler
|
||||
async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None:
|
||||
await ctx.yield_output(" | ".join(r.agent_response.messages[-1].text for r in results))
|
||||
|
||||
|
||||
workflow = (
|
||||
ConcurrentBuilder(participant_factories=[create_agent1, create_agent2, create_agent3])
|
||||
.register_aggregator(lambda: MyAggregator(id="my_aggregator"))
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
# Enable checkpoint persistence so runs can resume
|
||||
workflow = ConcurrentBuilder(participants=[agent1, agent2, agent3], checkpoint_storage=storage).build()
|
||||
|
||||
@@ -239,58 +218,29 @@ class ConcurrentBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
participants: Sequence[SupportsAgentRun | Executor] | None = None,
|
||||
participant_factories: Sequence[Callable[[], SupportsAgentRun | Executor]] | None = None,
|
||||
participants: Sequence[SupportsAgentRun | Executor],
|
||||
checkpoint_storage: CheckpointStorage | None = None,
|
||||
intermediate_outputs: bool = False,
|
||||
) -> None:
|
||||
"""Initialize the ConcurrentBuilder.
|
||||
|
||||
Args:
|
||||
participants: Optional sequence of agent or executor instances to run in parallel.
|
||||
participant_factories: Optional sequence of callables returning agent or executor instances.
|
||||
participants: Sequence of agent or executor instances to run in parallel.
|
||||
checkpoint_storage: Optional checkpoint storage for enabling workflow state persistence.
|
||||
intermediate_outputs: If True, enables intermediate outputs from agent participants
|
||||
before aggregation.
|
||||
"""
|
||||
self._participants: list[SupportsAgentRun | Executor] = []
|
||||
self._participant_factories: list[Callable[[], SupportsAgentRun | Executor]] = []
|
||||
self._aggregator: Executor | None = None
|
||||
self._aggregator_factory: Callable[[], Executor] | None = None
|
||||
self._checkpoint_storage: CheckpointStorage | None = checkpoint_storage
|
||||
self._request_info_enabled: bool = False
|
||||
self._request_info_filter: set[str] | None = None
|
||||
self._intermediate_outputs: bool = intermediate_outputs
|
||||
|
||||
if participants is None and participant_factories is None:
|
||||
raise ValueError("Either participants or participant_factories must be provided.")
|
||||
|
||||
if participant_factories is not None:
|
||||
self._set_participant_factories(participant_factories)
|
||||
if participants is not None:
|
||||
self._set_participants(participants)
|
||||
|
||||
def _set_participant_factories(
|
||||
self,
|
||||
participant_factories: Sequence[Callable[[], SupportsAgentRun | Executor]],
|
||||
) -> None:
|
||||
"""Set participant factories (internal)."""
|
||||
if self._participants:
|
||||
raise ValueError("Cannot provide both participants and participant_factories.")
|
||||
|
||||
if self._participant_factories:
|
||||
raise ValueError("participant_factories already set.")
|
||||
|
||||
if not participant_factories:
|
||||
raise ValueError("participant_factories cannot be empty")
|
||||
|
||||
self._participant_factories = list(participant_factories)
|
||||
self._set_participants(participants)
|
||||
|
||||
def _set_participants(self, participants: Sequence[SupportsAgentRun | Executor]) -> None:
|
||||
"""Set participants (internal)."""
|
||||
if self._participant_factories:
|
||||
raise ValueError("Cannot provide both participants and participant_factories.")
|
||||
|
||||
if self._participants:
|
||||
raise ValueError("participants already set.")
|
||||
|
||||
@@ -315,39 +265,6 @@ class ConcurrentBuilder:
|
||||
|
||||
self._participants = list(participants)
|
||||
|
||||
def register_aggregator(self, aggregator_factory: Callable[[], Executor]) -> "ConcurrentBuilder":
|
||||
r"""Define a custom aggregator for this concurrent workflow.
|
||||
|
||||
Accepts a factory (callable) that returns an Executor instance. The executor
|
||||
should handle `list[AgentExecutorResponse]` and yield output using `ctx.yield_output(...)`.
|
||||
|
||||
Args:
|
||||
aggregator_factory: Callable that returns an Executor instance
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
class MyCustomExecutor(Executor): ...
|
||||
|
||||
|
||||
wf = (
|
||||
ConcurrentBuilder()
|
||||
.register_participants([create_researcher, create_marketer, create_legal])
|
||||
.register_aggregator(lambda: MyCustomExecutor(id="my_aggregator"))
|
||||
.build()
|
||||
)
|
||||
"""
|
||||
if self._aggregator is not None:
|
||||
raise ValueError(
|
||||
"Cannot mix .with_aggregator(...) and .register_aggregator(...) in the same builder instance."
|
||||
)
|
||||
|
||||
if self._aggregator_factory is not None:
|
||||
raise ValueError("register_aggregator() has already been called on this builder instance.")
|
||||
|
||||
self._aggregator_factory = aggregator_factory
|
||||
return self
|
||||
|
||||
def with_aggregator(
|
||||
self,
|
||||
aggregator: Executor
|
||||
@@ -393,11 +310,6 @@ class ConcurrentBuilder:
|
||||
|
||||
wf = ConcurrentBuilder(participants=[a1, a2, a3]).with_aggregator(summarize).build()
|
||||
"""
|
||||
if self._aggregator_factory is not None:
|
||||
raise ValueError(
|
||||
"Cannot mix .with_aggregator(...) and .register_aggregator(...) in the same builder instance."
|
||||
)
|
||||
|
||||
if self._aggregator is not None:
|
||||
raise ValueError("with_aggregator() has already been called on this builder instance.")
|
||||
|
||||
@@ -445,19 +357,10 @@ class ConcurrentBuilder:
|
||||
|
||||
def _resolve_participants(self) -> list[Executor]:
|
||||
"""Resolve participant instances into Executor objects."""
|
||||
if not self._participants and not self._participant_factories:
|
||||
raise ValueError("No participants provided. Pass participants or participant_factories to the constructor.")
|
||||
# We don't need to check if both are set since that is handled in the respective methods
|
||||
if not self._participants:
|
||||
raise ValueError("No participants provided. Pass participants to the constructor.")
|
||||
|
||||
participants: list[Executor | SupportsAgentRun] = []
|
||||
if self._participant_factories:
|
||||
# Resolve the participant factories now. This doesn't break the factory pattern
|
||||
# since the Sequential builder still creates new instances per workflow build.
|
||||
for factory in self._participant_factories:
|
||||
p = factory()
|
||||
participants.append(p)
|
||||
else:
|
||||
participants = self._participants
|
||||
participants: list[Executor | SupportsAgentRun] = self._participants
|
||||
|
||||
executors: list[Executor] = []
|
||||
for p in participants:
|
||||
@@ -502,15 +405,7 @@ class ConcurrentBuilder:
|
||||
"""
|
||||
# Internal nodes
|
||||
dispatcher = _DispatchToAllParticipants(id="dispatcher")
|
||||
aggregator = (
|
||||
self._aggregator
|
||||
if self._aggregator is not None
|
||||
else (
|
||||
self._aggregator_factory()
|
||||
if self._aggregator_factory is not None
|
||||
else _AggregateAgentConversations(id="aggregator")
|
||||
)
|
||||
)
|
||||
aggregator = self._aggregator if self._aggregator is not None else _AggregateAgentConversations(id="aggregator")
|
||||
|
||||
# Resolve participants and participant factories to executors
|
||||
participants: list[Executor] = self._resolve_participants()
|
||||
|
||||
@@ -526,8 +526,7 @@ class GroupChatBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
participants: Sequence[SupportsAgentRun | Executor] | None = None,
|
||||
participant_factories: Sequence[Callable[[], SupportsAgentRun | Executor]] | None = None,
|
||||
participants: Sequence[SupportsAgentRun | Executor],
|
||||
# Orchestrator config (exactly one required)
|
||||
orchestrator_agent: ChatAgent | Callable[[], ChatAgent] | None = None,
|
||||
orchestrator: BaseGroupChatOrchestrator | Callable[[], BaseGroupChatOrchestrator] | None = None,
|
||||
@@ -542,8 +541,7 @@ class GroupChatBuilder:
|
||||
"""Initialize the GroupChatBuilder.
|
||||
|
||||
Args:
|
||||
participants: Optional sequence of agent or executor instances for the group chat.
|
||||
participant_factories: Optional sequence of callables returning agent or executor instances.
|
||||
participants: Sequence of agent or executor instances for the group chat.
|
||||
orchestrator_agent: An instance of ChatAgent or a callable that produces one to manage the group chat.
|
||||
orchestrator: An instance of BaseGroupChatOrchestrator or a callable that produces one to manage the
|
||||
group chat.
|
||||
@@ -557,7 +555,6 @@ class GroupChatBuilder:
|
||||
intermediate_outputs: If True, enables intermediate outputs from agent participants.
|
||||
"""
|
||||
self._participants: dict[str, SupportsAgentRun | Executor] = {}
|
||||
self._participant_factories: list[Callable[[], SupportsAgentRun | Executor]] = []
|
||||
|
||||
# Orchestrator related members
|
||||
self._orchestrator: BaseGroupChatOrchestrator | None = None
|
||||
@@ -578,13 +575,7 @@ class GroupChatBuilder:
|
||||
# Intermediate outputs
|
||||
self._intermediate_outputs = intermediate_outputs
|
||||
|
||||
if participants is None and participant_factories is None:
|
||||
raise ValueError("Either participants or participant_factories must be provided.")
|
||||
|
||||
if participant_factories is not None:
|
||||
self._set_participant_factories(participant_factories)
|
||||
if participants is not None:
|
||||
self._set_participants(participants)
|
||||
self._set_participants(participants)
|
||||
|
||||
# Set orchestrator if provided
|
||||
if any(x is not None for x in [orchestrator_agent, orchestrator, selection_func]):
|
||||
@@ -645,27 +636,8 @@ class GroupChatBuilder:
|
||||
else:
|
||||
self._orchestrator_factory = orchestrator_agent or orchestrator
|
||||
|
||||
def _set_participant_factories(
|
||||
self,
|
||||
participant_factories: Sequence[Callable[[], SupportsAgentRun | Executor]],
|
||||
) -> None:
|
||||
"""Set participant factories (internal)."""
|
||||
if self._participants:
|
||||
raise ValueError("Cannot provide both participants and participant_factories.")
|
||||
|
||||
if self._participant_factories:
|
||||
raise ValueError("participant_factories already set.")
|
||||
|
||||
if not participant_factories:
|
||||
raise ValueError("participant_factories cannot be empty")
|
||||
|
||||
self._participant_factories = list(participant_factories)
|
||||
|
||||
def _set_participants(self, participants: Sequence[SupportsAgentRun | Executor]) -> None:
|
||||
"""Set participants (internal)."""
|
||||
if self._participant_factories:
|
||||
raise ValueError("Cannot provide both participants and participant_factories.")
|
||||
|
||||
if self._participants:
|
||||
raise ValueError("participants already set.")
|
||||
|
||||
@@ -874,17 +846,10 @@ class GroupChatBuilder:
|
||||
|
||||
def _resolve_participants(self) -> list[Executor]:
|
||||
"""Resolve participant instances into Executor objects."""
|
||||
if not self._participants and not self._participant_factories:
|
||||
raise ValueError("No participants provided. Pass participants or participant_factories to the constructor.")
|
||||
# We don't need to check if both are set since that is handled in the respective methods
|
||||
if not self._participants:
|
||||
raise ValueError("No participants provided. Pass participants to the constructor.")
|
||||
|
||||
participants: list[Executor | SupportsAgentRun] = []
|
||||
if self._participant_factories:
|
||||
for factory in self._participant_factories:
|
||||
participant = factory()
|
||||
participants.append(participant)
|
||||
else:
|
||||
participants = list(self._participants.values())
|
||||
participants: list[Executor | SupportsAgentRun] = list(self._participants.values())
|
||||
|
||||
executors: list[Executor] = []
|
||||
for participant in participants:
|
||||
|
||||
@@ -32,7 +32,7 @@ Key properties:
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable, Mapping, Sequence
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -575,7 +575,6 @@ class HandoffBuilder:
|
||||
*,
|
||||
name: str | None = None,
|
||||
participants: Sequence[SupportsAgentRun] | None = None,
|
||||
participant_factories: Mapping[str, Callable[[], SupportsAgentRun]] | None = None,
|
||||
description: str | None = None,
|
||||
checkpoint_storage: CheckpointStorage | None = None,
|
||||
termination_condition: TerminationCondition | None = None,
|
||||
@@ -584,8 +583,7 @@ class HandoffBuilder:
|
||||
|
||||
The builder starts in an unconfigured state and requires you to call:
|
||||
1. `.participants([...])` - Register agents
|
||||
2. or `.participant_factories({...})` - Register agent factories
|
||||
3. `.build()` - Construct the final Workflow
|
||||
2. `.build()` - Construct the final Workflow
|
||||
|
||||
Optional configuration methods allow you to customize context management,
|
||||
termination logic, and persistence.
|
||||
@@ -596,9 +594,6 @@ class HandoffBuilder:
|
||||
participants: Optional list of agents that will participate in the handoff workflow.
|
||||
You can also call `.participants([...])` later. Each participant must have a
|
||||
unique identifier (`.name` is preferred if set, otherwise `.id` is used).
|
||||
participant_factories: Optional mapping of factory names to callables that produce agents when invoked.
|
||||
This allows for lazy instantiation and state isolation per workflow instance
|
||||
created by this builder.
|
||||
description: Optional human-readable description explaining the workflow's
|
||||
purpose. Useful for documentation and observability.
|
||||
checkpoint_storage: Optional checkpoint storage for enabling workflow state persistence.
|
||||
@@ -610,10 +605,7 @@ class HandoffBuilder:
|
||||
|
||||
# Participant related members
|
||||
self._participants: dict[str, SupportsAgentRun] = {}
|
||||
self._participant_factories: dict[str, Callable[[], SupportsAgentRun]] = {}
|
||||
self._start_id: str | None = None
|
||||
if participant_factories:
|
||||
self.register_participants(participant_factories)
|
||||
|
||||
if participants:
|
||||
self.participants(participants)
|
||||
@@ -635,68 +627,6 @@ class HandoffBuilder:
|
||||
termination_condition
|
||||
)
|
||||
|
||||
def register_participants(
|
||||
self, participant_factories: Mapping[str, Callable[[], SupportsAgentRun]]
|
||||
) -> "HandoffBuilder":
|
||||
"""Register factories that produce agents for the handoff workflow.
|
||||
|
||||
Each factory is a callable that returns an SupportsAgentRun instance.
|
||||
Factories are invoked when building the workflow, allowing for lazy instantiation
|
||||
and state isolation per workflow instance.
|
||||
|
||||
Args:
|
||||
participant_factories: Mapping of factory names to callables that return SupportsAgentRun
|
||||
instances. Each produced participant must have a unique identifier
|
||||
(`.name` is preferred if set, otherwise `.id` is used).
|
||||
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
|
||||
Raises:
|
||||
ValueError: If participant_factories is empty or `.participants(...)` or `.register_participants(...)`
|
||||
has already been called.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from agent_framework import ChatAgent
|
||||
from agent_framework_orchestrations import HandoffBuilder
|
||||
|
||||
|
||||
def create_triage() -> ChatAgent:
|
||||
return ...
|
||||
|
||||
|
||||
def create_refund_agent() -> ChatAgent:
|
||||
return ...
|
||||
|
||||
|
||||
def create_billing_agent() -> ChatAgent:
|
||||
return ...
|
||||
|
||||
|
||||
factories = {
|
||||
"triage": create_triage,
|
||||
"refund": create_refund_agent,
|
||||
"billing": create_billing_agent,
|
||||
}
|
||||
|
||||
# Handoff will be created automatically unless specified otherwise
|
||||
# The default creates a mesh topology where all agents can handoff to all others
|
||||
builder = HandoffBuilder().register_participants(factories)
|
||||
builder.with_start_agent("triage")
|
||||
"""
|
||||
if self._participants:
|
||||
raise ValueError("Cannot mix .participants() and .register_participants() in the same builder instance.")
|
||||
|
||||
if self._participant_factories:
|
||||
raise ValueError("register_participants() has already been called on this builder instance.")
|
||||
if not participant_factories:
|
||||
raise ValueError("participant_factories cannot be empty")
|
||||
|
||||
self._participant_factories = dict(participant_factories)
|
||||
return self
|
||||
|
||||
def participants(self, participants: Sequence[SupportsAgentRun]) -> "HandoffBuilder":
|
||||
"""Register the agents that will participate in the handoff workflow.
|
||||
|
||||
@@ -708,8 +638,8 @@ class HandoffBuilder:
|
||||
Self for method chaining.
|
||||
|
||||
Raises:
|
||||
ValueError: If participants is empty, contains duplicates, or `.participants()` or
|
||||
`.register_participants()` has already been called.
|
||||
ValueError: If participants is empty, contains duplicates, or `.participants()`
|
||||
has already been called.
|
||||
TypeError: If participants are not SupportsAgentRun instances.
|
||||
|
||||
Example:
|
||||
@@ -727,9 +657,6 @@ class HandoffBuilder:
|
||||
builder = HandoffBuilder().participants([triage, refund, billing])
|
||||
builder.with_start_agent(triage)
|
||||
"""
|
||||
if self._participant_factories:
|
||||
raise ValueError("Cannot mix .participants() and .register_participants() in the same builder instance.")
|
||||
|
||||
if self._participants:
|
||||
raise ValueError("participants have already been assigned")
|
||||
|
||||
@@ -755,8 +682,8 @@ class HandoffBuilder:
|
||||
|
||||
def add_handoff(
|
||||
self,
|
||||
source: str | SupportsAgentRun,
|
||||
targets: Sequence[str] | Sequence[SupportsAgentRun],
|
||||
source: SupportsAgentRun,
|
||||
targets: Sequence[SupportsAgentRun],
|
||||
*,
|
||||
description: str | None = None,
|
||||
) -> "HandoffBuilder":
|
||||
@@ -768,16 +695,8 @@ class HandoffBuilder:
|
||||
to all others by default (mesh topology).
|
||||
|
||||
Args:
|
||||
source: The agent that can initiate the handoff. Can be:
|
||||
- Factory name (str): If using participant factories
|
||||
- SupportsAgentRun instance: The actual agent object
|
||||
- Cannot mix factory names and instances across source and targets
|
||||
targets: One or more target agents that the source can hand off to. Can be:
|
||||
- Factory name (str): If using participant factories
|
||||
- SupportsAgentRun instance: The actual agent object
|
||||
- Single target: ["billing_agent"] or [agent_instance]
|
||||
- Multiple targets: ["billing_agent", "support_agent"] or [agent1, agent2]
|
||||
- Cannot mix factory names and instances across source and targets
|
||||
source: The agent that can initiate the handoff.
|
||||
targets: One or more target agents that the source can hand off to.
|
||||
description: Optional custom description for the handoff. If not provided, the description
|
||||
of the target agent(s) will be used. If the target agent has no description,
|
||||
no description will be set for the handoff tool, which is not recommended.
|
||||
@@ -789,25 +708,10 @@ class HandoffBuilder:
|
||||
Self for method chaining.
|
||||
|
||||
Raises:
|
||||
ValueError: 1) If source or targets are not in the participants list, or if
|
||||
participants(...) hasn't been called yet.
|
||||
2) If source or targets are factory names (str) but participant_factories(...)
|
||||
hasn't been called yet, or if they are not in the participant_factories list.
|
||||
TypeError: If mixing factory names (str) and SupportsAgentRun/Executor instances
|
||||
ValueError: If source or targets are not in the participants list, or if
|
||||
participants(...) hasn't been called yet.
|
||||
|
||||
Examples:
|
||||
Single target (using factory name):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
builder.add_handoff("triage_agent", "billing_agent")
|
||||
|
||||
Multiple targets (using factory names):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
builder.add_handoff("triage_agent", ["billing_agent", "support_agent", "escalation_agent"])
|
||||
|
||||
Multiple targets (using agent instances):
|
||||
|
||||
.. code-block:: python
|
||||
@@ -830,96 +734,54 @@ class HandoffBuilder:
|
||||
- Handoff tools are automatically registered for each source agent
|
||||
- If a source agent is configured multiple times via add_handoff, targets are merged
|
||||
"""
|
||||
if isinstance(source, str) and all(isinstance(t, str) for t in targets):
|
||||
# Both source and targets are factory names
|
||||
if not self._participant_factories:
|
||||
raise ValueError("Call participant_factories(...) before add_handoff(...)")
|
||||
if not self._participants:
|
||||
raise ValueError("Call participants(...) before add_handoff(...)")
|
||||
|
||||
if source not in self._participant_factories:
|
||||
raise ValueError(f"Source factory name '{source}' is not in the participant_factories list")
|
||||
# Resolve source agent ID
|
||||
source_id = self._resolve_to_id(source)
|
||||
if source_id not in self._participants:
|
||||
raise ValueError(f"Source agent '{source}' is not in the participants list")
|
||||
|
||||
for target in targets:
|
||||
if target not in self._participant_factories:
|
||||
raise ValueError(f"Target factory name '{target}' is not in the participant_factories list")
|
||||
# Resolve all target IDs
|
||||
target_ids: list[str] = []
|
||||
for target in targets:
|
||||
target_id = self._resolve_to_id(target)
|
||||
if target_id not in self._participants:
|
||||
raise ValueError(f"Target agent '{target}' is not in the participants list")
|
||||
target_ids.append(target_id)
|
||||
|
||||
# Merge with existing handoff configuration for this source
|
||||
if source in self._handoff_config:
|
||||
# Add new targets to existing list, avoiding duplicates
|
||||
for t in targets:
|
||||
if t in self._handoff_config[source]:
|
||||
logger.warning(f"Handoff from '{source}' to '{t}' is already configured; overwriting.")
|
||||
self._handoff_config[source].add(HandoffConfiguration(target=t, description=description))
|
||||
else:
|
||||
self._handoff_config[source] = set()
|
||||
for t in targets:
|
||||
self._handoff_config[source].add(HandoffConfiguration(target=t, description=description))
|
||||
return self
|
||||
# Merge with existing handoff configuration for this source
|
||||
if source_id not in self._handoff_config:
|
||||
self._handoff_config[source_id] = set()
|
||||
|
||||
if isinstance(source, (SupportsAgentRun)) and all(isinstance(t, SupportsAgentRun) for t in targets):
|
||||
# Both source and targets are instances
|
||||
if not self._participants:
|
||||
raise ValueError("Call participants(...) before add_handoff(...)")
|
||||
for t in target_ids:
|
||||
config = HandoffConfiguration(target=t, description=description)
|
||||
if config in self._handoff_config[source_id]:
|
||||
logger.warning(f"Handoff from '{source_id}' to '{t}' is already configured; overwriting.")
|
||||
# Remove old config so the new one (with updated description) takes effect
|
||||
self._handoff_config[source_id].discard(config)
|
||||
self._handoff_config[source_id].add(config)
|
||||
|
||||
# Resolve source agent ID
|
||||
source_id = self._resolve_to_id(source)
|
||||
if source_id not in self._participants:
|
||||
raise ValueError(f"Source agent '{source}' is not in the participants list")
|
||||
return self
|
||||
|
||||
# Resolve all target IDs
|
||||
target_ids: list[str] = []
|
||||
for target in targets:
|
||||
target_id = self._resolve_to_id(target)
|
||||
if target_id not in self._participants:
|
||||
raise ValueError(f"Target agent '{target}' is not in the participants list")
|
||||
target_ids.append(target_id)
|
||||
|
||||
# Merge with existing handoff configuration for this source
|
||||
if source_id in self._handoff_config:
|
||||
# Add new targets to existing list, avoiding duplicates
|
||||
for t in target_ids:
|
||||
if t in self._handoff_config[source_id]:
|
||||
logger.warning(f"Handoff from '{source_id}' to '{t}' is already configured; overwriting.")
|
||||
self._handoff_config[source_id].add(HandoffConfiguration(target=t, description=description))
|
||||
else:
|
||||
self._handoff_config[source_id] = set()
|
||||
for t in target_ids:
|
||||
self._handoff_config[source_id].add(HandoffConfiguration(target=t, description=description))
|
||||
|
||||
return self
|
||||
|
||||
raise TypeError(
|
||||
"Cannot mix factory names (str) and SupportsAgentRun instances across source and targets in add_handoff()"
|
||||
)
|
||||
|
||||
def with_start_agent(self, agent: str | SupportsAgentRun) -> "HandoffBuilder":
|
||||
def with_start_agent(self, agent: SupportsAgentRun) -> "HandoffBuilder":
|
||||
"""Set the agent that will initiate the handoff workflow.
|
||||
|
||||
If not specified, the first registered participant will be used as the starting agent.
|
||||
|
||||
Args:
|
||||
agent: The agent that will start the workflow. Can be:
|
||||
- Factory name (str): If using participant factories
|
||||
- SupportsAgentRun instance: The actual agent object
|
||||
agent: The agent that will start the workflow.
|
||||
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
"""
|
||||
if isinstance(agent, str):
|
||||
if self._participant_factories:
|
||||
if agent not in self._participant_factories:
|
||||
raise ValueError(f"Start agent factory name '{agent}' is not in the participant_factories list")
|
||||
else:
|
||||
raise ValueError("Call register_participants(...) before with_start_agent(...)")
|
||||
self._start_id = agent
|
||||
elif isinstance(agent, SupportsAgentRun):
|
||||
resolved_id = self._resolve_to_id(agent)
|
||||
if self._participants:
|
||||
if resolved_id not in self._participants:
|
||||
raise ValueError(f"Start agent '{resolved_id}' is not in the participants list")
|
||||
else:
|
||||
raise ValueError("Call participants(...) before with_start_agent(...)")
|
||||
self._start_id = resolved_id
|
||||
resolved_id = self._resolve_to_id(agent)
|
||||
if self._participants:
|
||||
if resolved_id not in self._participants:
|
||||
raise ValueError(f"Start agent '{resolved_id}' is not in the participants list")
|
||||
else:
|
||||
raise TypeError("Start agent must be a factory name (str) or an SupportsAgentRun instance")
|
||||
raise ValueError("Call participants(...) before with_start_agent(...)")
|
||||
self._start_id = resolved_id
|
||||
|
||||
return self
|
||||
|
||||
@@ -1090,48 +952,21 @@ class HandoffBuilder:
|
||||
# region Internal Helper Methods
|
||||
|
||||
def _resolve_agents(self) -> dict[str, SupportsAgentRun]:
|
||||
"""Resolve participant factories into agent instances.
|
||||
|
||||
If agent instances were provided directly via participants(...), those are
|
||||
returned as-is. If participant factories were provided via participant_factories(...),
|
||||
those are invoked to create the agent instances.
|
||||
"""Resolve participant instances into agent instances.
|
||||
|
||||
Returns:
|
||||
Map of executor IDs or factory names to `SupportsAgentRun` instances
|
||||
Map of executor IDs to `SupportsAgentRun` instances
|
||||
"""
|
||||
if not self._participants and not self._participant_factories:
|
||||
raise ValueError("No participants provided. Call .participants() or .register_participants() first.")
|
||||
# We don't need to check if both are set since that is handled in the respective methods
|
||||
if not self._participants:
|
||||
raise ValueError("No participants provided. Call .participants() first.")
|
||||
|
||||
if self._participants:
|
||||
return self._participants
|
||||
return self._participants
|
||||
|
||||
if self._participant_factories:
|
||||
# Invoke each factory to create participant instances
|
||||
factory_names_to_agents: dict[str, SupportsAgentRun] = {}
|
||||
for factory_name, factory in self._participant_factories.items():
|
||||
instance = factory()
|
||||
if isinstance(instance, SupportsAgentRun):
|
||||
resolved_id = self._resolve_to_id(instance)
|
||||
else:
|
||||
raise TypeError(f"Participants must be SupportsAgentRun instances. Got {type(instance).__name__}.")
|
||||
|
||||
if resolved_id in factory_names_to_agents:
|
||||
raise ValueError(f"Duplicate participant name '{resolved_id}' detected")
|
||||
|
||||
# Map executors by factory name (not executor.id) because handoff configs reference factory names
|
||||
# This allows users to configure handoffs using the factory names they provided
|
||||
factory_names_to_agents[factory_name] = instance
|
||||
|
||||
return factory_names_to_agents
|
||||
|
||||
raise ValueError("No executors or participant_factories have been configured")
|
||||
|
||||
def _resolve_handoffs(self, agents: Mapping[str, SupportsAgentRun]) -> dict[str, list[HandoffConfiguration]]:
|
||||
"""Handoffs may be specified using factory names or instances; resolve to executor IDs.
|
||||
def _resolve_handoffs(self, agents: dict[str, SupportsAgentRun]) -> dict[str, list[HandoffConfiguration]]:
|
||||
"""Resolve handoff configurations to executor IDs.
|
||||
|
||||
Args:
|
||||
agents: Map of agent IDs or factory names to `SupportsAgentRun` instances
|
||||
agents: Map of agent IDs to `SupportsAgentRun` instances
|
||||
|
||||
Returns:
|
||||
Map of executor IDs to list of HandoffConfiguration instances
|
||||
@@ -1145,14 +980,14 @@ class HandoffBuilder:
|
||||
if not source_agent:
|
||||
raise ValueError(
|
||||
f"Handoff source agent '{source_id}' not found. "
|
||||
"Please make sure source has been added as either a participant or participant_factory."
|
||||
"Please make sure source has been added as a participant."
|
||||
)
|
||||
for handoff_config in handoff_configurations:
|
||||
target_agent = agents.get(handoff_config.target_id)
|
||||
if not target_agent:
|
||||
raise ValueError(
|
||||
f"Handoff target agent '{handoff_config.target_id}' not found for source '{source_id}'. "
|
||||
"Please make sure target has been added as either a participant or participant_factory."
|
||||
"Please make sure target has been added as a participant."
|
||||
)
|
||||
|
||||
updated_handoff_configurations.setdefault(self._resolve_to_id(source_agent), []).append(
|
||||
@@ -1184,7 +1019,7 @@ class HandoffBuilder:
|
||||
"""Resolve agents into HandoffAgentExecutors.
|
||||
|
||||
Args:
|
||||
agents: Map of agent IDs or factory names to `SupportsAgentRun` instances
|
||||
agents: Map of agent IDs to `SupportsAgentRun` instances
|
||||
handoffs: Map of executor IDs to list of HandoffConfiguration instances
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -1374,8 +1374,7 @@ class MagenticBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
participants: Sequence[SupportsAgentRun | Executor] | None = None,
|
||||
participant_factories: Sequence[Callable[[], SupportsAgentRun | Executor]] | None = None,
|
||||
participants: Sequence[SupportsAgentRun | Executor],
|
||||
# Manager config (exactly one required)
|
||||
manager: MagenticManagerBase | None = None,
|
||||
manager_factory: Callable[[], MagenticManagerBase] | None = None,
|
||||
@@ -1401,8 +1400,7 @@ class MagenticBuilder:
|
||||
"""Initialize the Magentic workflow builder.
|
||||
|
||||
Args:
|
||||
participants: Optional sequence of agent or executor instances for the workflow.
|
||||
participant_factories: Optional sequence of callables returning agent or executor instances.
|
||||
participants: Sequence of agent or executor instances for the workflow.
|
||||
manager: Pre-configured manager instance (subclass of MagenticManagerBase).
|
||||
manager_factory: Callable that returns a new MagenticManagerBase instance.
|
||||
manager_agent: Agent instance for creating a StandardMagenticManager.
|
||||
@@ -1423,7 +1421,6 @@ class MagenticBuilder:
|
||||
intermediate_outputs: If True, enables intermediate outputs from agent participants.
|
||||
"""
|
||||
self._participants: dict[str, SupportsAgentRun | Executor] = {}
|
||||
self._participant_factories: list[Callable[[], SupportsAgentRun | Executor]] = []
|
||||
|
||||
# Manager related members
|
||||
self._manager: MagenticManagerBase | None = None
|
||||
@@ -1437,13 +1434,7 @@ class MagenticBuilder:
|
||||
# Intermediate outputs
|
||||
self._intermediate_outputs = intermediate_outputs
|
||||
|
||||
if participants is None and participant_factories is None:
|
||||
raise ValueError("Either participants or participant_factories must be provided.")
|
||||
|
||||
if participant_factories is not None:
|
||||
self._set_participant_factories(participant_factories)
|
||||
if participants is not None:
|
||||
self._set_participants(participants)
|
||||
self._set_participants(participants)
|
||||
|
||||
# Set manager if provided
|
||||
if any(x is not None for x in [manager, manager_factory, manager_agent, manager_agent_factory]):
|
||||
@@ -1465,27 +1456,8 @@ class MagenticBuilder:
|
||||
max_round_count=max_round_count,
|
||||
)
|
||||
|
||||
def _set_participant_factories(
|
||||
self,
|
||||
participant_factories: Sequence[Callable[[], SupportsAgentRun | Executor]],
|
||||
) -> None:
|
||||
"""Set participant factories (internal)."""
|
||||
if self._participants:
|
||||
raise ValueError("Cannot provide both participants and participant_factories.")
|
||||
|
||||
if self._participant_factories:
|
||||
raise ValueError("participant_factories already set.")
|
||||
|
||||
if not participant_factories:
|
||||
raise ValueError("participant_factories cannot be empty")
|
||||
|
||||
self._participant_factories = list(participant_factories)
|
||||
|
||||
def _set_participants(self, participants: Sequence[SupportsAgentRun | Executor]) -> None:
|
||||
"""Set participants (internal)."""
|
||||
if self._participant_factories:
|
||||
raise ValueError("Cannot provide both participants and participant_factories.")
|
||||
|
||||
if self._participants:
|
||||
raise ValueError("participants already set.")
|
||||
|
||||
@@ -1750,17 +1722,10 @@ class MagenticBuilder:
|
||||
|
||||
def _resolve_participants(self) -> list[Executor]:
|
||||
"""Resolve participant instances into Executor objects."""
|
||||
if not self._participants and not self._participant_factories:
|
||||
raise ValueError("No participants provided. Pass participants or participant_factories to the constructor.")
|
||||
# We don't need to check if both are set since that is handled in the respective methods
|
||||
if not self._participants:
|
||||
raise ValueError("No participants provided. Pass participants to the constructor.")
|
||||
|
||||
participants: list[Executor | SupportsAgentRun] = []
|
||||
if self._participant_factories:
|
||||
for factory in self._participant_factories:
|
||||
participant = factory()
|
||||
participants.append(participant)
|
||||
else:
|
||||
participants = list(self._participants.values())
|
||||
participants: list[Executor | SupportsAgentRun] = list(self._participants.values())
|
||||
|
||||
executors: list[Executor] = []
|
||||
for participant in participants:
|
||||
|
||||
@@ -4,8 +4,7 @@
|
||||
|
||||
This module provides a high-level, agent-focused API to assemble a sequential
|
||||
workflow where:
|
||||
- Participants can be provided as SupportsAgentRun or Executor instances via `participants=[...]`,
|
||||
or as factories returning SupportsAgentRun or Executor via `participant_factories=[...]`
|
||||
- Participants are provided as SupportsAgentRun or Executor instances via `participants=[...]`
|
||||
- A shared conversation context (list[ChatMessage]) is passed along the chain
|
||||
- Agents append their assistant messages to the context
|
||||
- Custom executors can transform or summarize and return a refined context
|
||||
@@ -38,7 +37,7 @@ confusion and to mirror how the concurrent builder uses explicit dispatcher/aggr
|
||||
""" # noqa: E501
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable, Sequence
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import ChatMessage, SupportsAgentRun
|
||||
@@ -110,8 +109,6 @@ class SequentialBuilder:
|
||||
r"""High-level builder for sequential agent/executor workflows with shared context.
|
||||
|
||||
- `participants=[...]` accepts a list of SupportsAgentRun (recommended) or Executor instances
|
||||
- `participant_factories=[...]` accepts a list of factories for SupportsAgentRun (recommended)
|
||||
or Executor factories
|
||||
- Executors must define a handler that consumes list[ChatMessage] and sends out a list[ChatMessage]
|
||||
- The workflow wires participants in order, passing a list[ChatMessage] down the chain
|
||||
- Agents append their assistant messages to the conversation
|
||||
@@ -127,11 +124,6 @@ class SequentialBuilder:
|
||||
# With agent instances
|
||||
workflow = SequentialBuilder(participants=[agent1, agent2, summarizer_exec]).build()
|
||||
|
||||
# With agent factories
|
||||
workflow = SequentialBuilder(
|
||||
participant_factories=[create_agent1, create_agent2, create_summarizer_exec]
|
||||
).build()
|
||||
|
||||
# Enable checkpoint persistence
|
||||
workflow = SequentialBuilder(participants=[agent1, agent2], checkpoint_storage=storage).build()
|
||||
|
||||
@@ -149,55 +141,27 @@ class SequentialBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
participants: Sequence[SupportsAgentRun | Executor] | None = None,
|
||||
participant_factories: Sequence[Callable[[], SupportsAgentRun | Executor]] | None = None,
|
||||
participants: Sequence[SupportsAgentRun | Executor],
|
||||
checkpoint_storage: CheckpointStorage | None = None,
|
||||
intermediate_outputs: bool = False,
|
||||
) -> None:
|
||||
"""Initialize the SequentialBuilder.
|
||||
|
||||
Args:
|
||||
participants: Optional sequence of agent or executor instances to run sequentially.
|
||||
participant_factories: Optional sequence of callables returning agent or executor instances.
|
||||
participants: Sequence of agent or executor instances to run sequentially.
|
||||
checkpoint_storage: Optional checkpoint storage for enabling workflow state persistence.
|
||||
intermediate_outputs: If True, enables intermediate outputs from agent participants.
|
||||
"""
|
||||
self._participants: list[SupportsAgentRun | Executor] = []
|
||||
self._participant_factories: list[Callable[[], SupportsAgentRun | Executor]] = []
|
||||
self._checkpoint_storage: CheckpointStorage | None = checkpoint_storage
|
||||
self._request_info_enabled: bool = False
|
||||
self._request_info_filter: set[str] | None = None
|
||||
self._intermediate_outputs: bool = intermediate_outputs
|
||||
|
||||
if participants is None and participant_factories is None:
|
||||
raise ValueError("Either participants or participant_factories must be provided.")
|
||||
|
||||
if participant_factories is not None:
|
||||
self._set_participant_factories(participant_factories)
|
||||
if participants is not None:
|
||||
self._set_participants(participants)
|
||||
|
||||
def _set_participant_factories(
|
||||
self,
|
||||
participant_factories: Sequence[Callable[[], SupportsAgentRun | Executor]],
|
||||
) -> None:
|
||||
"""Set participant factories (internal)."""
|
||||
if self._participants:
|
||||
raise ValueError("Cannot provide both participants and participant_factories.")
|
||||
|
||||
if self._participant_factories:
|
||||
raise ValueError("participant_factories already set.")
|
||||
|
||||
if not participant_factories:
|
||||
raise ValueError("participant_factories cannot be empty")
|
||||
|
||||
self._participant_factories = list(participant_factories)
|
||||
self._set_participants(participants)
|
||||
|
||||
def _set_participants(self, participants: Sequence[SupportsAgentRun | Executor]) -> None:
|
||||
"""Set participants (internal)."""
|
||||
if self._participant_factories:
|
||||
raise ValueError("Cannot provide both participants and participant_factories.")
|
||||
|
||||
if self._participants:
|
||||
raise ValueError("participants already set.")
|
||||
|
||||
@@ -256,19 +220,10 @@ class SequentialBuilder:
|
||||
|
||||
def _resolve_participants(self) -> list[Executor]:
|
||||
"""Resolve participant instances into Executor objects."""
|
||||
if not self._participants and not self._participant_factories:
|
||||
raise ValueError("No participants provided. Pass participants or participant_factories to the constructor.")
|
||||
# We don't need to check if both are set since that is handled in the respective methods
|
||||
if not self._participants:
|
||||
raise ValueError("No participants provided. Pass participants to the constructor.")
|
||||
|
||||
participants: list[Executor | SupportsAgentRun] = []
|
||||
if self._participant_factories:
|
||||
# Resolve the participant factories now. This doesn't break the factory pattern
|
||||
# since the Sequential builder still creates new instances per workflow build.
|
||||
for factory in self._participant_factories:
|
||||
p = factory()
|
||||
participants.append(p)
|
||||
else:
|
||||
participants = self._participants
|
||||
participants: list[Executor | SupportsAgentRun] = self._participants
|
||||
|
||||
executors: list[Executor] = []
|
||||
for p in participants:
|
||||
|
||||
@@ -49,47 +49,6 @@ def test_concurrent_builder_rejects_duplicate_executors() -> None:
|
||||
ConcurrentBuilder(participants=[a, b])
|
||||
|
||||
|
||||
def test_concurrent_builder_rejects_duplicate_executors_from_factories() -> None:
|
||||
"""Test that duplicate executor IDs from factories are detected at build time."""
|
||||
|
||||
def create_dup1() -> Executor:
|
||||
return _FakeAgentExec("dup", "A")
|
||||
|
||||
def create_dup2() -> Executor:
|
||||
return _FakeAgentExec("dup", "B") # same executor id
|
||||
|
||||
builder = ConcurrentBuilder(participant_factories=[create_dup1, create_dup2])
|
||||
with pytest.raises(ValueError, match="Duplicate executor ID 'dup' detected in workflow."):
|
||||
builder.build()
|
||||
|
||||
|
||||
def test_concurrent_builder_rejects_mixed_participants_and_factories() -> None:
|
||||
"""Test that passing both participants and participant_factories to the constructor raises an error."""
|
||||
with pytest.raises(ValueError, match="Cannot provide both participants and participant_factories"):
|
||||
ConcurrentBuilder(
|
||||
participants=[_FakeAgentExec("a", "A")],
|
||||
participant_factories=[lambda: _FakeAgentExec("b", "B")],
|
||||
)
|
||||
|
||||
|
||||
def test_concurrent_builder_rejects_both_participants_and_factories() -> None:
|
||||
"""Test that passing both participants and participant_factories raises an error."""
|
||||
with pytest.raises(ValueError, match="Cannot provide both participants and participant_factories"):
|
||||
ConcurrentBuilder(
|
||||
participants=[_FakeAgentExec("a", "A")],
|
||||
participant_factories=[lambda: _FakeAgentExec("b", "B")],
|
||||
)
|
||||
|
||||
|
||||
def test_concurrent_builder_rejects_both_factories_and_participants() -> None:
|
||||
"""Test that passing both participant_factories and participants raises an error."""
|
||||
with pytest.raises(ValueError, match="Cannot provide both participants and participant_factories"):
|
||||
ConcurrentBuilder(
|
||||
participant_factories=[lambda: _FakeAgentExec("a", "A")],
|
||||
participants=[_FakeAgentExec("b", "B")],
|
||||
)
|
||||
|
||||
|
||||
async def test_concurrent_default_aggregator_emits_single_user_and_assistants() -> None:
|
||||
# Three synthetic agent executors
|
||||
e1 = _FakeAgentExec("agentA", "Alpha")
|
||||
@@ -231,79 +190,6 @@ async def test_concurrent_with_aggregator_executor_instance() -> None:
|
||||
assert output == "One & Two"
|
||||
|
||||
|
||||
async def test_concurrent_with_aggregator_executor_factory() -> None:
|
||||
"""Test with_aggregator using an Executor factory."""
|
||||
|
||||
class CustomAggregator(Executor):
|
||||
@handler
|
||||
async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None:
|
||||
texts: list[str] = []
|
||||
for r in results:
|
||||
msgs: list[ChatMessage] = r.agent_response.messages
|
||||
texts.append(msgs[-1].text if msgs else "")
|
||||
await ctx.yield_output(" | ".join(sorted(texts)))
|
||||
|
||||
e1 = _FakeAgentExec("agentA", "One")
|
||||
e2 = _FakeAgentExec("agentB", "Two")
|
||||
|
||||
wf = (
|
||||
ConcurrentBuilder(participants=[e1, e2])
|
||||
.register_aggregator(lambda: CustomAggregator(id="custom_aggregator"))
|
||||
.build()
|
||||
)
|
||||
|
||||
completed = False
|
||||
output: str | None = None
|
||||
async for ev in wf.run("prompt: factory test", stream=True):
|
||||
if ev.type == "status" and ev.state == WorkflowRunState.IDLE:
|
||||
completed = True
|
||||
elif ev.type == "output":
|
||||
output = cast(str, ev.data)
|
||||
if completed and output is not None:
|
||||
break
|
||||
|
||||
assert completed
|
||||
assert output is not None
|
||||
assert isinstance(output, str)
|
||||
assert output == "One | Two"
|
||||
|
||||
|
||||
async def test_concurrent_with_aggregator_executor_factory_with_default_id() -> None:
|
||||
"""Test with_aggregator using an Executor class directly as factory (with default __init__ parameters)."""
|
||||
|
||||
class CustomAggregator(Executor):
|
||||
def __init__(self, id: str = "default_aggregator") -> None:
|
||||
super().__init__(id)
|
||||
|
||||
@handler
|
||||
async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None:
|
||||
texts: list[str] = []
|
||||
for r in results:
|
||||
msgs: list[ChatMessage] = r.agent_response.messages
|
||||
texts.append(msgs[-1].text if msgs else "")
|
||||
await ctx.yield_output(" | ".join(sorted(texts)))
|
||||
|
||||
e1 = _FakeAgentExec("agentA", "One")
|
||||
e2 = _FakeAgentExec("agentB", "Two")
|
||||
|
||||
wf = ConcurrentBuilder(participants=[e1, e2]).register_aggregator(CustomAggregator).build()
|
||||
|
||||
completed = False
|
||||
output: str | None = None
|
||||
async for ev in wf.run("prompt: factory test", stream=True):
|
||||
if ev.type == "status" and ev.state == WorkflowRunState.IDLE:
|
||||
completed = True
|
||||
elif ev.type == "output":
|
||||
output = cast(str, ev.data)
|
||||
if completed and output is not None:
|
||||
break
|
||||
|
||||
assert completed
|
||||
assert output is not None
|
||||
assert isinstance(output, str)
|
||||
assert output == "One | Two"
|
||||
|
||||
|
||||
def test_concurrent_builder_rejects_multiple_calls_to_with_aggregator() -> None:
|
||||
"""Test that multiple calls to .with_aggregator() raises an error."""
|
||||
|
||||
@@ -318,20 +204,6 @@ def test_concurrent_builder_rejects_multiple_calls_to_with_aggregator() -> None:
|
||||
)
|
||||
|
||||
|
||||
def test_concurrent_builder_rejects_multiple_calls_to_register_aggregator() -> None:
|
||||
"""Test that multiple calls to .register_aggregator() raises an error."""
|
||||
|
||||
class CustomAggregator(Executor):
|
||||
pass
|
||||
|
||||
with pytest.raises(ValueError, match=r"register_aggregator\(\) has already been called"):
|
||||
(
|
||||
ConcurrentBuilder(participants=[_FakeAgentExec("a", "A")])
|
||||
.register_aggregator(lambda: CustomAggregator(id="agg1"))
|
||||
.register_aggregator(lambda: CustomAggregator(id="agg2"))
|
||||
)
|
||||
|
||||
|
||||
async def test_concurrent_checkpoint_resume_round_trip() -> None:
|
||||
storage = InMemoryCheckpointStorage()
|
||||
|
||||
@@ -455,11 +327,6 @@ async def test_concurrent_checkpoint_runtime_overrides_buildtime() -> None:
|
||||
assert len(buildtime_checkpoints) == 0, "Build-time storage should have no checkpoints when overridden"
|
||||
|
||||
|
||||
def test_concurrent_builder_rejects_empty_participant_factories() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
ConcurrentBuilder(participant_factories=[])
|
||||
|
||||
|
||||
async def test_concurrent_builder_reusable_after_build_with_participants() -> None:
|
||||
"""Test that the builder can be reused to build multiple identical workflows with participants()."""
|
||||
e1 = _FakeAgentExec("agentA", "One")
|
||||
@@ -471,74 +338,3 @@ async def test_concurrent_builder_reusable_after_build_with_participants() -> No
|
||||
|
||||
assert builder._participants[0] is e1 # type: ignore
|
||||
assert builder._participants[1] is e2 # type: ignore
|
||||
assert builder._participant_factories == [] # type: ignore
|
||||
|
||||
|
||||
async def test_concurrent_builder_reusable_after_build_with_factories() -> None:
|
||||
"""Test that the builder can be reused to build multiple workflows with register_participants()."""
|
||||
call_count = 0
|
||||
|
||||
def create_agent_executor_a() -> Executor:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return _FakeAgentExec("agentA", "One")
|
||||
|
||||
def create_agent_executor_b() -> Executor:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return _FakeAgentExec("agentB", "Two")
|
||||
|
||||
builder = ConcurrentBuilder(participant_factories=[create_agent_executor_a, create_agent_executor_b])
|
||||
|
||||
# Build the first workflow
|
||||
wf1 = builder.build()
|
||||
|
||||
assert builder._participants == [] # type: ignore
|
||||
assert len(builder._participant_factories) == 2 # type: ignore
|
||||
assert call_count == 2
|
||||
|
||||
# Build the second workflow
|
||||
wf2 = builder.build()
|
||||
assert call_count == 4
|
||||
|
||||
# Verify that the two workflows have different executor instances
|
||||
assert wf1.executors["agentA"] is not wf2.executors["agentA"]
|
||||
assert wf1.executors["agentB"] is not wf2.executors["agentB"]
|
||||
|
||||
|
||||
async def test_concurrent_with_register_participants() -> None:
|
||||
"""Test workflow creation using register_participants with factories."""
|
||||
|
||||
def create_agent1() -> Executor:
|
||||
return _FakeAgentExec("agentA", "Alpha")
|
||||
|
||||
def create_agent2() -> Executor:
|
||||
return _FakeAgentExec("agentB", "Beta")
|
||||
|
||||
def create_agent3() -> Executor:
|
||||
return _FakeAgentExec("agentC", "Gamma")
|
||||
|
||||
wf = ConcurrentBuilder(participant_factories=[create_agent1, create_agent2, create_agent3]).build()
|
||||
|
||||
completed = False
|
||||
output: list[ChatMessage] | None = None
|
||||
async for ev in wf.run("test prompt", stream=True):
|
||||
if ev.type == "status" and ev.state == WorkflowRunState.IDLE:
|
||||
completed = True
|
||||
elif ev.type == "output":
|
||||
output = cast(list[ChatMessage], ev.data)
|
||||
if completed and output is not None:
|
||||
break
|
||||
|
||||
assert completed
|
||||
assert output is not None
|
||||
messages: list[ChatMessage] = output
|
||||
|
||||
# Expect one user message + one assistant message per participant
|
||||
assert len(messages) == 1 + 3
|
||||
assert messages[0].role == "user"
|
||||
assert "test prompt" in messages[0].text
|
||||
|
||||
assistant_texts = {m.text for m in messages[1:]}
|
||||
assert assistant_texts == {"Alpha", "Beta", "Gamma"}
|
||||
assert all(m.role == "assistant" for m in messages[1:])
|
||||
|
||||
@@ -240,12 +240,9 @@ class TestGroupChatBuilder:
|
||||
builder.build()
|
||||
|
||||
def test_build_without_participants_raises_error(self) -> None:
|
||||
"""Test that constructing without participants raises ValueError."""
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"Either participants or participant_factories must be provided\.",
|
||||
):
|
||||
GroupChatBuilder()
|
||||
"""Test that constructing with empty participants raises ValueError."""
|
||||
with pytest.raises(ValueError):
|
||||
GroupChatBuilder(participants=[])
|
||||
|
||||
def test_duplicate_manager_configuration_raises_error(self) -> None:
|
||||
"""Test that configuring multiple orchestrator options raises ValueError."""
|
||||
@@ -775,150 +772,6 @@ def test_group_chat_builder_with_request_info_returns_self():
|
||||
assert result2 is builder2
|
||||
|
||||
|
||||
# region Participant Factory Tests
|
||||
|
||||
|
||||
def test_group_chat_builder_rejects_empty_participant_factories():
|
||||
"""Test that GroupChatBuilder rejects empty participant_factories list."""
|
||||
|
||||
def selector(state: GroupChatState) -> str:
|
||||
return list(state.participants.keys())[0]
|
||||
|
||||
with pytest.raises(ValueError, match=r"participant_factories cannot be empty"):
|
||||
GroupChatBuilder(participant_factories=[])
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"Either participants or participant_factories must be provided\.",
|
||||
):
|
||||
GroupChatBuilder()
|
||||
|
||||
|
||||
def test_group_chat_builder_rejects_mixing_participants_and_factories():
|
||||
"""Test that passing both participants and participant_factories to the constructor raises an error."""
|
||||
alpha = StubAgent("alpha", "reply from alpha")
|
||||
|
||||
with pytest.raises(ValueError, match="Cannot provide both participants and participant_factories"):
|
||||
GroupChatBuilder(
|
||||
participants=[alpha],
|
||||
participant_factories=[lambda: StubAgent("beta", "reply from beta")],
|
||||
)
|
||||
|
||||
|
||||
def test_group_chat_builder_rejects_both_factories_and_participants():
|
||||
"""Test that passing both participant_factories and participants raises an error."""
|
||||
with pytest.raises(ValueError, match="Cannot provide both participants and participant_factories"):
|
||||
GroupChatBuilder(
|
||||
participant_factories=[lambda: StubAgent("alpha", "reply from alpha")],
|
||||
participants=[StubAgent("beta", "reply from beta")],
|
||||
)
|
||||
|
||||
|
||||
def test_group_chat_builder_rejects_both_participants_and_factories():
|
||||
"""Test that passing both participants and participant_factories raises an error."""
|
||||
with pytest.raises(ValueError, match="Cannot provide both participants and participant_factories"):
|
||||
GroupChatBuilder(
|
||||
participants=[StubAgent("alpha", "reply from alpha")],
|
||||
participant_factories=[lambda: StubAgent("beta", "reply from beta")],
|
||||
)
|
||||
|
||||
|
||||
async def test_group_chat_with_participant_factories():
|
||||
"""Test workflow creation using participant_factories."""
|
||||
call_count = 0
|
||||
|
||||
def create_alpha() -> StubAgent:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return StubAgent("alpha", "reply from alpha")
|
||||
|
||||
def create_beta() -> StubAgent:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return StubAgent("beta", "reply from beta")
|
||||
|
||||
selector = make_sequence_selector()
|
||||
|
||||
workflow = GroupChatBuilder(
|
||||
participant_factories=[create_alpha, create_beta],
|
||||
max_rounds=2,
|
||||
selection_func=selector,
|
||||
).build()
|
||||
|
||||
# Factories should be called during build
|
||||
assert call_count == 2
|
||||
|
||||
outputs: list[WorkflowEvent] = []
|
||||
async for event in workflow.run("coordinate task", stream=True):
|
||||
if event.type == "output":
|
||||
outputs.append(event)
|
||||
|
||||
assert len(outputs) == 1
|
||||
|
||||
|
||||
async def test_group_chat_participant_factories_reusable_builder():
|
||||
"""Test that the builder can be reused to build multiple workflows with factories."""
|
||||
call_count = 0
|
||||
|
||||
def create_alpha() -> StubAgent:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return StubAgent("alpha", "reply from alpha")
|
||||
|
||||
def create_beta() -> StubAgent:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return StubAgent("beta", "reply from beta")
|
||||
|
||||
selector = make_sequence_selector()
|
||||
|
||||
builder = GroupChatBuilder(participant_factories=[create_alpha, create_beta], max_rounds=2, selection_func=selector)
|
||||
|
||||
# Build first workflow
|
||||
wf1 = builder.build()
|
||||
assert call_count == 2
|
||||
|
||||
# Build second workflow
|
||||
wf2 = builder.build()
|
||||
assert call_count == 4
|
||||
|
||||
# Verify that the two workflows have different agent instances
|
||||
assert wf1.executors["alpha"] is not wf2.executors["alpha"]
|
||||
assert wf1.executors["beta"] is not wf2.executors["beta"]
|
||||
|
||||
|
||||
async def test_group_chat_participant_factories_with_checkpointing():
|
||||
"""Test checkpointing with participant_factories."""
|
||||
storage = InMemoryCheckpointStorage()
|
||||
|
||||
def create_alpha() -> StubAgent:
|
||||
return StubAgent("alpha", "reply from alpha")
|
||||
|
||||
def create_beta() -> StubAgent:
|
||||
return StubAgent("beta", "reply from beta")
|
||||
|
||||
selector = make_sequence_selector()
|
||||
|
||||
workflow = GroupChatBuilder(
|
||||
participant_factories=[create_alpha, create_beta],
|
||||
checkpoint_storage=storage,
|
||||
max_rounds=2,
|
||||
selection_func=selector,
|
||||
).build()
|
||||
|
||||
outputs: list[WorkflowEvent] = []
|
||||
async for event in workflow.run("checkpoint test", stream=True):
|
||||
if event.type == "output":
|
||||
outputs.append(event)
|
||||
|
||||
assert outputs, "Should have workflow output"
|
||||
|
||||
checkpoints = await storage.list_checkpoints()
|
||||
assert checkpoints, "Checkpoints should be created during workflow execution"
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Orchestrator Factory Tests
|
||||
|
||||
|
||||
@@ -1129,77 +982,4 @@ def test_group_chat_orchestrator_factory_invalid_return_type():
|
||||
GroupChatBuilder(participants=[alpha], orchestrator_agent=invalid_factory).build()
|
||||
|
||||
|
||||
def test_group_chat_with_both_participant_and_orchestrator_factories():
|
||||
"""Test workflow creation using both participant_factories and orchestrator_factory."""
|
||||
participant_factory_call_count = 0
|
||||
agent_factory_call_count = 0
|
||||
|
||||
def create_alpha() -> StubAgent:
|
||||
nonlocal participant_factory_call_count
|
||||
participant_factory_call_count += 1
|
||||
return StubAgent("alpha", "reply from alpha")
|
||||
|
||||
def create_beta() -> StubAgent:
|
||||
nonlocal participant_factory_call_count
|
||||
participant_factory_call_count += 1
|
||||
return StubAgent("beta", "reply from beta")
|
||||
|
||||
def agent_factory() -> ChatAgent:
|
||||
nonlocal agent_factory_call_count
|
||||
agent_factory_call_count += 1
|
||||
return cast(ChatAgent, StubManagerAgent())
|
||||
|
||||
workflow = GroupChatBuilder(
|
||||
participant_factories=[create_alpha, create_beta],
|
||||
orchestrator_agent=agent_factory,
|
||||
).build()
|
||||
|
||||
# All factories should be called during build
|
||||
assert participant_factory_call_count == 2
|
||||
assert agent_factory_call_count == 1
|
||||
|
||||
# Verify all executors are present in the workflow
|
||||
assert "alpha" in workflow.executors
|
||||
assert "beta" in workflow.executors
|
||||
assert "manager_agent" in workflow.executors
|
||||
|
||||
|
||||
async def test_group_chat_factories_reusable_for_multiple_workflows():
|
||||
"""Test that both factories are reused correctly for multiple workflow builds."""
|
||||
participant_factory_call_count = 0
|
||||
agent_factory_call_count = 0
|
||||
|
||||
def create_alpha() -> StubAgent:
|
||||
nonlocal participant_factory_call_count
|
||||
participant_factory_call_count += 1
|
||||
return StubAgent("alpha", "reply from alpha")
|
||||
|
||||
def create_beta() -> StubAgent:
|
||||
nonlocal participant_factory_call_count
|
||||
participant_factory_call_count += 1
|
||||
return StubAgent("beta", "reply from beta")
|
||||
|
||||
def agent_factory() -> ChatAgent:
|
||||
nonlocal agent_factory_call_count
|
||||
agent_factory_call_count += 1
|
||||
return cast(ChatAgent, StubManagerAgent())
|
||||
|
||||
builder = GroupChatBuilder(participant_factories=[create_alpha, create_beta], orchestrator_agent=agent_factory)
|
||||
|
||||
# Build first workflow
|
||||
wf1 = builder.build()
|
||||
assert participant_factory_call_count == 2
|
||||
assert agent_factory_call_count == 1
|
||||
|
||||
# Build second workflow
|
||||
wf2 = builder.build()
|
||||
assert participant_factory_call_count == 4
|
||||
assert agent_factory_call_count == 2
|
||||
|
||||
# Verify that the workflows have different agent and orchestrator instances
|
||||
assert wf1.executors["alpha"] is not wf2.executors["alpha"]
|
||||
assert wf1.executors["beta"] is not wf2.executors["beta"]
|
||||
assert wf1.executors["manager_agent"] is not wf2.executors["manager_agent"]
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -229,10 +229,8 @@ def test_build_fails_without_start_agent():
|
||||
|
||||
def test_build_fails_without_participants():
|
||||
"""Verify that build() raises ValueError when no participants are provided."""
|
||||
with pytest.raises(
|
||||
ValueError, match=r"No participants provided\. Call \.participants\(\) or \.register_participants\(\) first."
|
||||
):
|
||||
HandoffBuilder().build()
|
||||
with pytest.raises(ValueError):
|
||||
HandoffBuilder(participants=[]).build()
|
||||
|
||||
|
||||
async def test_handoff_async_termination_condition() -> None:
|
||||
@@ -349,162 +347,6 @@ async def test_context_provider_preserved_during_handoff():
|
||||
)
|
||||
|
||||
|
||||
# region Participant Factory Tests
|
||||
|
||||
|
||||
def test_handoff_builder_rejects_empty_participant_factories():
|
||||
"""Test that HandoffBuilder rejects empty participant_factories dictionary."""
|
||||
# Empty factories are rejected immediately when calling participant_factories()
|
||||
with pytest.raises(ValueError, match=r"participant_factories cannot be empty"):
|
||||
HandoffBuilder().register_participants({})
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match=r"No participants provided\. Call \.participants\(\) or \.register_participants\(\) first\."
|
||||
):
|
||||
HandoffBuilder(participant_factories={}).build()
|
||||
|
||||
|
||||
def test_handoff_builder_rejects_mixing_participants_and_factories():
|
||||
"""Test that mixing participants and participant_factories in __init__ raises an error."""
|
||||
triage = MockHandoffAgent(name="triage")
|
||||
with pytest.raises(ValueError, match="Cannot mix .participants"):
|
||||
HandoffBuilder(participants=[triage], participant_factories={"triage": lambda: triage})
|
||||
|
||||
|
||||
def test_handoff_builder_rejects_mixing_participants_and_participant_factories_methods():
|
||||
"""Test that mixing .participants() and .participant_factories() raises an error."""
|
||||
triage = MockHandoffAgent(name="triage")
|
||||
|
||||
# Case 1: participants first, then participant_factories
|
||||
with pytest.raises(ValueError, match="Cannot mix .participants"):
|
||||
HandoffBuilder(participants=[triage]).register_participants({
|
||||
"specialist": lambda: MockHandoffAgent(name="specialist")
|
||||
})
|
||||
|
||||
# Case 2: participant_factories first, then participants
|
||||
with pytest.raises(ValueError, match="Cannot mix .participants"):
|
||||
HandoffBuilder(participant_factories={"triage": lambda: triage}).participants([
|
||||
MockHandoffAgent(name="specialist")
|
||||
])
|
||||
|
||||
# Case 3: participants(), then participant_factories()
|
||||
with pytest.raises(ValueError, match="Cannot mix .participants"):
|
||||
HandoffBuilder().participants([triage]).register_participants({
|
||||
"specialist": lambda: MockHandoffAgent(name="specialist")
|
||||
})
|
||||
|
||||
# Case 4: participant_factories(), then participants()
|
||||
with pytest.raises(ValueError, match="Cannot mix .participants"):
|
||||
HandoffBuilder().register_participants({"triage": lambda: triage}).participants([
|
||||
MockHandoffAgent(name="specialist")
|
||||
])
|
||||
|
||||
# Case 5: mix during initialization
|
||||
with pytest.raises(ValueError, match="Cannot mix .participants"):
|
||||
HandoffBuilder(
|
||||
participants=[triage], participant_factories={"specialist": lambda: MockHandoffAgent(name="specialist")}
|
||||
)
|
||||
|
||||
|
||||
def test_handoff_builder_rejects_multiple_calls_to_participant_factories():
|
||||
"""Test that multiple calls to .participant_factories() raises an error."""
|
||||
with pytest.raises(
|
||||
ValueError, match=r"register_participants\(\) has already been called on this builder instance."
|
||||
):
|
||||
(
|
||||
HandoffBuilder()
|
||||
.register_participants({"agent1": lambda: MockHandoffAgent(name="agent1")})
|
||||
.register_participants({"agent2": lambda: MockHandoffAgent(name="agent2")})
|
||||
)
|
||||
|
||||
|
||||
def test_handoff_builder_rejects_multiple_calls_to_participants():
|
||||
"""Test that multiple calls to .participants() raises an error."""
|
||||
with pytest.raises(ValueError, match="participants have already been assigned"):
|
||||
(
|
||||
HandoffBuilder()
|
||||
.participants([MockHandoffAgent(name="agent1")])
|
||||
.participants([MockHandoffAgent(name="agent2")])
|
||||
)
|
||||
|
||||
|
||||
def test_handoff_builder_rejects_instance_coordinator_with_factories():
|
||||
"""Test that using an agent instance for set_coordinator when using factories raises an error."""
|
||||
|
||||
def create_triage() -> MockHandoffAgent:
|
||||
return MockHandoffAgent(name="triage")
|
||||
|
||||
def create_specialist() -> MockHandoffAgent:
|
||||
return MockHandoffAgent(name="specialist")
|
||||
|
||||
# Create an agent instance
|
||||
coordinator_instance = MockHandoffAgent(name="coordinator")
|
||||
|
||||
with pytest.raises(ValueError, match=r"Call participants\(\.\.\.\) before with_start_agent\(\.\.\.\)"):
|
||||
(
|
||||
HandoffBuilder(
|
||||
participant_factories={"triage": create_triage, "specialist": create_specialist}
|
||||
).with_start_agent(coordinator_instance) # Instance, not factory name
|
||||
)
|
||||
|
||||
|
||||
def test_handoff_builder_rejects_factory_name_coordinator_with_instances():
|
||||
"""Test that using a factory name for set_coordinator when using instances raises an error."""
|
||||
triage = MockHandoffAgent(name="triage")
|
||||
specialist = MockHandoffAgent(name="specialist")
|
||||
|
||||
with pytest.raises(ValueError, match=r"Call register_participants\(...\) before with_start_agent\(...\)"):
|
||||
(
|
||||
HandoffBuilder(participants=[triage, specialist]).with_start_agent(
|
||||
"triage"
|
||||
) # String factory name, not instance
|
||||
)
|
||||
|
||||
|
||||
def test_handoff_builder_rejects_mixed_types_in_add_handoff_source():
|
||||
"""Test that add_handoff rejects factory name source with instance-based participants."""
|
||||
triage = MockHandoffAgent(name="triage")
|
||||
specialist = MockHandoffAgent(name="specialist")
|
||||
|
||||
with pytest.raises(TypeError, match="Cannot mix factory names \\(str\\) and SupportsAgentRun.*instances"):
|
||||
(
|
||||
HandoffBuilder(participants=[triage, specialist])
|
||||
.with_start_agent(triage)
|
||||
.add_handoff("triage", [specialist]) # String source with instance participants
|
||||
)
|
||||
|
||||
|
||||
def test_handoff_builder_accepts_all_factory_names_in_add_handoff():
|
||||
"""Test that add_handoff accepts all factory names when using participant_factories."""
|
||||
|
||||
def create_triage() -> MockHandoffAgent:
|
||||
return MockHandoffAgent(name="triage")
|
||||
|
||||
def create_specialist_a() -> MockHandoffAgent:
|
||||
return MockHandoffAgent(name="specialist_a")
|
||||
|
||||
def create_specialist_b() -> MockHandoffAgent:
|
||||
return MockHandoffAgent(name="specialist_b")
|
||||
|
||||
# This should work - all strings with participant_factories
|
||||
builder = (
|
||||
HandoffBuilder(
|
||||
participant_factories={
|
||||
"triage": create_triage,
|
||||
"specialist_a": create_specialist_a,
|
||||
"specialist_b": create_specialist_b,
|
||||
}
|
||||
)
|
||||
.with_start_agent("triage")
|
||||
.add_handoff("triage", ["specialist_a", "specialist_b"])
|
||||
)
|
||||
|
||||
workflow = builder.build()
|
||||
assert "triage" in workflow.executors
|
||||
assert "specialist_a" in workflow.executors
|
||||
assert "specialist_b" in workflow.executors
|
||||
|
||||
|
||||
def test_handoff_builder_accepts_all_instances_in_add_handoff():
|
||||
"""Test that add_handoff accepts all instances when using participants."""
|
||||
triage = MockHandoffAgent(name="triage", handoff_to="specialist_a")
|
||||
@@ -522,260 +364,3 @@ def test_handoff_builder_accepts_all_instances_in_add_handoff():
|
||||
assert "triage" in workflow.executors
|
||||
assert "specialist_a" in workflow.executors
|
||||
assert "specialist_b" in workflow.executors
|
||||
|
||||
|
||||
async def test_handoff_with_participant_factories():
|
||||
"""Test workflow creation using participant_factories."""
|
||||
call_count = 0
|
||||
|
||||
def create_triage() -> MockHandoffAgent:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return MockHandoffAgent(name="triage", handoff_to="specialist")
|
||||
|
||||
def create_specialist() -> MockHandoffAgent:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return MockHandoffAgent(name="specialist")
|
||||
|
||||
workflow = (
|
||||
HandoffBuilder(
|
||||
participant_factories={"triage": create_triage, "specialist": create_specialist},
|
||||
termination_condition=lambda conv: sum(1 for m in conv if m.role == "user") >= 2,
|
||||
)
|
||||
.with_start_agent("triage")
|
||||
.build()
|
||||
)
|
||||
|
||||
# Factories should be called during build
|
||||
assert call_count == 2
|
||||
|
||||
events = await _drain(workflow.run("Need help", stream=True))
|
||||
requests = [ev for ev in events if ev.type == "request_info"]
|
||||
assert requests
|
||||
|
||||
# Follow-up message
|
||||
events = await _drain(
|
||||
workflow.run(stream=True, responses={requests[-1].request_id: [ChatMessage(role="user", text="More details")]})
|
||||
)
|
||||
outputs = [ev for ev in events if ev.type == "output"]
|
||||
assert outputs
|
||||
|
||||
|
||||
async def test_handoff_participant_factories_reusable_builder():
|
||||
"""Test that the builder can be reused to build multiple workflows with factories."""
|
||||
call_count = 0
|
||||
|
||||
def create_triage() -> MockHandoffAgent:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return MockHandoffAgent(name="triage", handoff_to="specialist")
|
||||
|
||||
def create_specialist() -> MockHandoffAgent:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return MockHandoffAgent(name="specialist")
|
||||
|
||||
builder = HandoffBuilder(
|
||||
participant_factories={"triage": create_triage, "specialist": create_specialist}
|
||||
).with_start_agent("triage")
|
||||
|
||||
# Build first workflow
|
||||
wf1 = builder.build()
|
||||
assert call_count == 2
|
||||
|
||||
# Build second workflow
|
||||
wf2 = builder.build()
|
||||
assert call_count == 4
|
||||
|
||||
# Verify that the two workflows have different agent instances
|
||||
assert wf1.executors["triage"] is not wf2.executors["triage"]
|
||||
assert wf1.executors["specialist"] is not wf2.executors["specialist"]
|
||||
|
||||
|
||||
async def test_handoff_with_participant_factories_and_add_handoff():
|
||||
"""Test that .add_handoff() works correctly with participant_factories."""
|
||||
|
||||
def create_triage() -> MockHandoffAgent:
|
||||
return MockHandoffAgent(name="triage", handoff_to="specialist_a")
|
||||
|
||||
def create_specialist_a() -> MockHandoffAgent:
|
||||
return MockHandoffAgent(name="specialist_a", handoff_to="specialist_b")
|
||||
|
||||
def create_specialist_b() -> MockHandoffAgent:
|
||||
return MockHandoffAgent(name="specialist_b")
|
||||
|
||||
workflow = (
|
||||
HandoffBuilder(
|
||||
participant_factories={
|
||||
"triage": create_triage,
|
||||
"specialist_a": create_specialist_a,
|
||||
"specialist_b": create_specialist_b,
|
||||
},
|
||||
termination_condition=lambda conv: sum(1 for m in conv if m.role == "user") >= 3,
|
||||
)
|
||||
.with_start_agent("triage")
|
||||
.add_handoff("triage", ["specialist_a", "specialist_b"])
|
||||
.add_handoff("specialist_a", ["specialist_b"])
|
||||
.build()
|
||||
)
|
||||
|
||||
# Start conversation - triage hands off to specialist_a
|
||||
events = await _drain(workflow.run("Initial request", stream=True))
|
||||
requests = [ev for ev in events if ev.type == "request_info"]
|
||||
assert requests
|
||||
|
||||
# Verify specialist_a executor exists and was called
|
||||
assert "specialist_a" in workflow.executors
|
||||
|
||||
# Second user message - specialist_a hands off to specialist_b
|
||||
events = await _drain(
|
||||
workflow.run(
|
||||
stream=True, responses={requests[-1].request_id: [ChatMessage(role="user", text="Need escalation")]}
|
||||
)
|
||||
)
|
||||
requests = [ev for ev in events if ev.type == "request_info"]
|
||||
assert requests
|
||||
|
||||
# Verify specialist_b executor exists
|
||||
assert "specialist_b" in workflow.executors
|
||||
|
||||
|
||||
async def test_handoff_participant_factories_with_checkpointing():
|
||||
"""Test checkpointing with participant_factories."""
|
||||
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage
|
||||
|
||||
storage = InMemoryCheckpointStorage()
|
||||
|
||||
def create_triage() -> MockHandoffAgent:
|
||||
return MockHandoffAgent(name="triage", handoff_to="specialist")
|
||||
|
||||
def create_specialist() -> MockHandoffAgent:
|
||||
return MockHandoffAgent(name="specialist")
|
||||
|
||||
workflow = (
|
||||
HandoffBuilder(
|
||||
participant_factories={"triage": create_triage, "specialist": create_specialist},
|
||||
checkpoint_storage=storage,
|
||||
termination_condition=lambda conv: sum(1 for m in conv if m.role == "user") >= 2,
|
||||
)
|
||||
.with_start_agent("triage")
|
||||
.build()
|
||||
)
|
||||
|
||||
# Run workflow and capture output
|
||||
events = await _drain(workflow.run("checkpoint test", stream=True))
|
||||
requests = [ev for ev in events if ev.type == "request_info"]
|
||||
assert requests
|
||||
|
||||
events = await _drain(
|
||||
workflow.run(stream=True, responses={requests[-1].request_id: [ChatMessage(role="user", text="follow up")]})
|
||||
)
|
||||
outputs = [ev for ev in events if ev.type == "output"]
|
||||
assert outputs, "Should have workflow output after termination condition is met"
|
||||
|
||||
# List checkpoints - just verify they were created
|
||||
checkpoints = await storage.list_checkpoints()
|
||||
assert checkpoints, "Checkpoints should be created during workflow execution"
|
||||
|
||||
|
||||
def test_handoff_set_coordinator_with_factory_name():
|
||||
"""Test that set_coordinator accepts factory name as string."""
|
||||
|
||||
def create_triage() -> MockHandoffAgent:
|
||||
return MockHandoffAgent(name="triage")
|
||||
|
||||
def create_specialist() -> MockHandoffAgent:
|
||||
return MockHandoffAgent(name="specialist")
|
||||
|
||||
builder = HandoffBuilder(
|
||||
participant_factories={"triage": create_triage, "specialist": create_specialist}
|
||||
).with_start_agent("triage")
|
||||
|
||||
workflow = builder.build()
|
||||
assert "triage" in workflow.executors
|
||||
|
||||
|
||||
def test_handoff_add_handoff_with_factory_names():
|
||||
"""Test that add_handoff accepts factory names as strings."""
|
||||
|
||||
def create_triage() -> MockHandoffAgent:
|
||||
return MockHandoffAgent(name="triage", handoff_to="specialist_a")
|
||||
|
||||
def create_specialist_a() -> MockHandoffAgent:
|
||||
return MockHandoffAgent(name="specialist_a")
|
||||
|
||||
def create_specialist_b() -> MockHandoffAgent:
|
||||
return MockHandoffAgent(name="specialist_b")
|
||||
|
||||
builder = (
|
||||
HandoffBuilder(
|
||||
participant_factories={
|
||||
"triage": create_triage,
|
||||
"specialist_a": create_specialist_a,
|
||||
"specialist_b": create_specialist_b,
|
||||
}
|
||||
)
|
||||
.with_start_agent("triage")
|
||||
.add_handoff("triage", ["specialist_a", "specialist_b"])
|
||||
)
|
||||
|
||||
workflow = builder.build()
|
||||
assert "triage" in workflow.executors
|
||||
assert "specialist_a" in workflow.executors
|
||||
assert "specialist_b" in workflow.executors
|
||||
|
||||
|
||||
async def test_handoff_participant_factories_autonomous_mode():
|
||||
"""Test autonomous mode with participant_factories."""
|
||||
|
||||
def create_triage() -> MockHandoffAgent:
|
||||
return MockHandoffAgent(name="triage", handoff_to="specialist")
|
||||
|
||||
def create_specialist() -> MockHandoffAgent:
|
||||
return MockHandoffAgent(name="specialist")
|
||||
|
||||
workflow = (
|
||||
HandoffBuilder(participant_factories={"triage": create_triage, "specialist": create_specialist})
|
||||
.with_start_agent("triage")
|
||||
.with_autonomous_mode(agents=["specialist"], turn_limits={"specialist": 1})
|
||||
.build()
|
||||
)
|
||||
|
||||
events = await _drain(workflow.run("Issue", stream=True))
|
||||
requests = [ev for ev in events if ev.type == "request_info"]
|
||||
assert requests and len(requests) == 1
|
||||
assert requests[0].source_executor_id == "specialist"
|
||||
|
||||
|
||||
def test_handoff_participant_factories_invalid_coordinator_name():
|
||||
"""Test that set_coordinator raises error for non-existent factory name."""
|
||||
|
||||
def create_triage() -> MockHandoffAgent:
|
||||
return MockHandoffAgent(name="triage")
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Start agent factory name 'nonexistent' is not in the participant_factories list"
|
||||
):
|
||||
(HandoffBuilder(participant_factories={"triage": create_triage}).with_start_agent("nonexistent").build())
|
||||
|
||||
|
||||
def test_handoff_participant_factories_invalid_handoff_target():
|
||||
"""Test that add_handoff raises error for non-existent target factory name."""
|
||||
|
||||
def create_triage() -> MockHandoffAgent:
|
||||
return MockHandoffAgent(name="triage")
|
||||
|
||||
def create_specialist() -> MockHandoffAgent:
|
||||
return MockHandoffAgent(name="specialist")
|
||||
|
||||
with pytest.raises(ValueError, match="Target factory name 'nonexistent' is not in the participant_factories list"):
|
||||
(
|
||||
HandoffBuilder(participant_factories={"triage": create_triage, "specialist": create_specialist})
|
||||
.with_start_agent("triage")
|
||||
.add_handoff("triage", ["nonexistent"])
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
# endregion Participant Factory Tests
|
||||
|
||||
@@ -890,121 +890,6 @@ async def test_magentic_checkpoint_restore_no_duplicate_history():
|
||||
)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Participant Factory Tests
|
||||
|
||||
|
||||
def test_magentic_builder_rejects_empty_participant_factories():
|
||||
"""Test that MagenticBuilder rejects empty participant_factories list."""
|
||||
with pytest.raises(ValueError, match=r"participant_factories cannot be empty"):
|
||||
MagenticBuilder(participant_factories=[])
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"Either participants or participant_factories must be provided\.",
|
||||
):
|
||||
MagenticBuilder()
|
||||
|
||||
|
||||
def test_magentic_builder_rejects_mixing_participants_and_factories():
|
||||
"""Test that passing both participants and participant_factories to the constructor raises an error."""
|
||||
agent = StubAgent("agentA", "reply from agentA")
|
||||
|
||||
with pytest.raises(ValueError, match="Cannot provide both participants and participant_factories"):
|
||||
MagenticBuilder(
|
||||
participants=[agent],
|
||||
participant_factories=[lambda: StubAgent("agentB", "reply")],
|
||||
)
|
||||
|
||||
|
||||
def test_magentic_builder_rejects_both_factories_and_participants():
|
||||
"""Test that passing both participant_factories and participants raises an error."""
|
||||
with pytest.raises(ValueError, match="Cannot provide both participants and participant_factories"):
|
||||
MagenticBuilder(
|
||||
participant_factories=[lambda: StubAgent("agentA", "reply from agentA")],
|
||||
participants=[StubAgent("agentB", "reply from agentB")],
|
||||
)
|
||||
|
||||
|
||||
def test_magentic_builder_rejects_both_participants_and_factories():
|
||||
"""Test that passing both participants and participant_factories raises an error."""
|
||||
with pytest.raises(ValueError, match="Cannot provide both participants and participant_factories"):
|
||||
MagenticBuilder(
|
||||
participants=[StubAgent("agentA", "reply from agentA")],
|
||||
participant_factories=[lambda: StubAgent("agentB", "reply from agentB")],
|
||||
)
|
||||
|
||||
|
||||
async def test_magentic_with_participant_factories():
|
||||
"""Test workflow creation using participant_factories."""
|
||||
call_count = 0
|
||||
|
||||
def create_agent() -> StubAgent:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return StubAgent("agentA", "reply from agentA")
|
||||
|
||||
manager = FakeManager()
|
||||
workflow = MagenticBuilder(participant_factories=[create_agent], manager=manager).build()
|
||||
|
||||
# Factory should be called during build
|
||||
assert call_count == 1
|
||||
|
||||
outputs: list[WorkflowEvent] = []
|
||||
async for event in workflow.run("test task", stream=True):
|
||||
if event.type == "output":
|
||||
outputs.append(event)
|
||||
|
||||
assert len(outputs) == 1
|
||||
|
||||
|
||||
async def test_magentic_participant_factories_reusable_builder():
|
||||
"""Test that the builder can be reused to build multiple workflows with factories."""
|
||||
call_count = 0
|
||||
|
||||
def create_agent() -> StubAgent:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return StubAgent("agentA", "reply from agentA")
|
||||
|
||||
builder = MagenticBuilder(participant_factories=[create_agent], manager=FakeManager())
|
||||
|
||||
# Build first workflow
|
||||
wf1 = builder.build()
|
||||
assert call_count == 1
|
||||
|
||||
# Build second workflow
|
||||
wf2 = builder.build()
|
||||
assert call_count == 2
|
||||
|
||||
# Verify that the two workflows have different agent instances
|
||||
assert wf1.executors["agentA"] is not wf2.executors["agentA"]
|
||||
|
||||
|
||||
async def test_magentic_participant_factories_with_checkpointing():
|
||||
"""Test checkpointing with participant_factories."""
|
||||
storage = InMemoryCheckpointStorage()
|
||||
|
||||
def create_agent() -> StubAgent:
|
||||
return StubAgent("agentA", "reply from agentA")
|
||||
|
||||
manager = FakeManager()
|
||||
workflow = MagenticBuilder(
|
||||
participant_factories=[create_agent], checkpoint_storage=storage, manager=manager
|
||||
).build()
|
||||
|
||||
outputs: list[WorkflowEvent] = []
|
||||
async for event in workflow.run("checkpoint test", stream=True):
|
||||
if event.type == "output":
|
||||
outputs.append(event)
|
||||
|
||||
assert outputs, "Should have workflow output"
|
||||
|
||||
checkpoints = await storage.list_checkpoints()
|
||||
assert checkpoints, "Checkpoints should be created during workflow execution"
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Manager Factory Tests
|
||||
@@ -1112,66 +997,6 @@ async def test_magentic_manager_factory_reusable_builder():
|
||||
assert orchestrator1 is not orchestrator2
|
||||
|
||||
|
||||
def test_magentic_with_both_participant_and_manager_factories():
|
||||
"""Test workflow creation using both participant_factories and manager_factory."""
|
||||
participant_factory_call_count = 0
|
||||
manager_factory_call_count = 0
|
||||
|
||||
def create_agent() -> StubAgent:
|
||||
nonlocal participant_factory_call_count
|
||||
participant_factory_call_count += 1
|
||||
return StubAgent("agentA", "reply from agentA")
|
||||
|
||||
def manager_factory() -> MagenticManagerBase:
|
||||
nonlocal manager_factory_call_count
|
||||
manager_factory_call_count += 1
|
||||
return FakeManager()
|
||||
|
||||
workflow = MagenticBuilder(participant_factories=[create_agent], manager_factory=manager_factory).build()
|
||||
|
||||
# All factories should be called during build
|
||||
assert participant_factory_call_count == 1
|
||||
assert manager_factory_call_count == 1
|
||||
|
||||
# Verify executor is present in the workflow
|
||||
assert "agentA" in workflow.executors
|
||||
|
||||
|
||||
async def test_magentic_factories_reusable_for_multiple_workflows():
|
||||
"""Test that both factories are reused correctly for multiple workflow builds."""
|
||||
participant_factory_call_count = 0
|
||||
manager_factory_call_count = 0
|
||||
|
||||
def create_agent() -> StubAgent:
|
||||
nonlocal participant_factory_call_count
|
||||
participant_factory_call_count += 1
|
||||
return StubAgent("agentA", "reply from agentA")
|
||||
|
||||
def manager_factory() -> MagenticManagerBase:
|
||||
nonlocal manager_factory_call_count
|
||||
manager_factory_call_count += 1
|
||||
return FakeManager()
|
||||
|
||||
builder = MagenticBuilder(participant_factories=[create_agent], manager_factory=manager_factory)
|
||||
|
||||
# Build first workflow
|
||||
wf1 = builder.build()
|
||||
assert participant_factory_call_count == 1
|
||||
assert manager_factory_call_count == 1
|
||||
|
||||
# Build second workflow
|
||||
wf2 = builder.build()
|
||||
assert participant_factory_call_count == 2
|
||||
assert manager_factory_call_count == 2
|
||||
|
||||
# Verify that the workflows have different agent and orchestrator instances
|
||||
assert wf1.executors["agentA"] is not wf2.executors["agentA"]
|
||||
|
||||
orchestrator1 = next(e for e in wf1.executors.values() if isinstance(e, MagenticOrchestrator))
|
||||
orchestrator2 = next(e for e in wf2.executors.values() if isinstance(e, MagenticOrchestrator))
|
||||
assert orchestrator1 is not orchestrator2
|
||||
|
||||
|
||||
def test_magentic_agent_factory_with_standard_manager_options():
|
||||
"""Test that agent_factory properly passes through standard manager options."""
|
||||
factory_call_count = 0
|
||||
|
||||
@@ -71,22 +71,6 @@ def test_sequential_builder_rejects_empty_participants() -> None:
|
||||
SequentialBuilder(participants=[])
|
||||
|
||||
|
||||
def test_sequential_builder_rejects_empty_participant_factories() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
SequentialBuilder(participant_factories=[])
|
||||
|
||||
|
||||
def test_sequential_builder_rejects_mixing_participants_and_factories() -> None:
|
||||
"""Test that passing both participants and participant_factories to the constructor raises an error."""
|
||||
a1 = _EchoAgent(id="agent1", name="A1")
|
||||
|
||||
with pytest.raises(ValueError, match="Cannot provide both participants and participant_factories"):
|
||||
SequentialBuilder(
|
||||
participants=[a1],
|
||||
participant_factories=[lambda: _EchoAgent(id="agent2", name="A2")],
|
||||
)
|
||||
|
||||
|
||||
def test_sequential_builder_validation_rejects_invalid_executor() -> None:
|
||||
"""Test that adding an invalid executor to the builder raises an error."""
|
||||
with pytest.raises(TypeCompatibilityError):
|
||||
@@ -121,37 +105,6 @@ async def test_sequential_agents_append_to_context() -> None:
|
||||
assert "A2 reply" in msgs[2].text
|
||||
|
||||
|
||||
async def test_sequential_register_participants_with_agent_factories() -> None:
|
||||
"""Test that register_participants works with agent factories."""
|
||||
|
||||
def create_agent1() -> _EchoAgent:
|
||||
return _EchoAgent(id="agent1", name="A1")
|
||||
|
||||
def create_agent2() -> _EchoAgent:
|
||||
return _EchoAgent(id="agent2", name="A2")
|
||||
|
||||
wf = SequentialBuilder(participant_factories=[create_agent1, create_agent2]).build()
|
||||
|
||||
completed = False
|
||||
output: list[ChatMessage] | None = None
|
||||
async for ev in wf.run("hello factories", stream=True):
|
||||
if ev.type == "status" and ev.state == WorkflowRunState.IDLE:
|
||||
completed = True
|
||||
elif ev.type == "output":
|
||||
output = ev.data
|
||||
if completed and output is not None:
|
||||
break
|
||||
|
||||
assert completed
|
||||
assert output is not None
|
||||
assert isinstance(output, list)
|
||||
msgs: list[ChatMessage] = output
|
||||
assert len(msgs) == 3
|
||||
assert msgs[0].role == "user" and "hello factories" in msgs[0].text
|
||||
assert msgs[1].role == "assistant" and "A1 reply" in msgs[1].text
|
||||
assert msgs[2].role == "assistant" and "A2 reply" in msgs[2].text
|
||||
|
||||
|
||||
async def test_sequential_with_custom_executor_summary() -> None:
|
||||
a1 = _EchoAgent(id="agent1", name="A1")
|
||||
summarizer = _SummarizerExec(id="summarizer")
|
||||
@@ -178,37 +131,6 @@ async def test_sequential_with_custom_executor_summary() -> None:
|
||||
assert msgs[2].role == "assistant" and msgs[2].text.startswith("Summary of users:")
|
||||
|
||||
|
||||
async def test_sequential_register_participants_mixed_agents_and_executors() -> None:
|
||||
"""Test register_participants with both agent and executor factories."""
|
||||
|
||||
def create_agent() -> _EchoAgent:
|
||||
return _EchoAgent(id="agent1", name="A1")
|
||||
|
||||
def create_summarizer() -> _SummarizerExec:
|
||||
return _SummarizerExec(id="summarizer")
|
||||
|
||||
wf = SequentialBuilder(participant_factories=[create_agent, create_summarizer]).build()
|
||||
|
||||
completed = False
|
||||
output: list[ChatMessage] | None = None
|
||||
async for ev in wf.run("topic Y", stream=True):
|
||||
if ev.type == "status" and ev.state == WorkflowRunState.IDLE:
|
||||
completed = True
|
||||
elif ev.type == "output":
|
||||
output = ev.data
|
||||
if completed and output is not None:
|
||||
break
|
||||
|
||||
assert completed
|
||||
assert output is not None
|
||||
msgs: list[ChatMessage] = output
|
||||
# Expect: [user, A1 reply, summary]
|
||||
assert len(msgs) == 3
|
||||
assert msgs[0].role == "user" and "topic Y" in msgs[0].text
|
||||
assert msgs[1].role == "assistant" and "A1 reply" in msgs[1].text
|
||||
assert msgs[2].role == "assistant" and msgs[2].text.startswith("Summary of users:")
|
||||
|
||||
|
||||
async def test_sequential_checkpoint_resume_round_trip() -> None:
|
||||
storage = InMemoryCheckpointStorage()
|
||||
|
||||
@@ -325,92 +247,6 @@ async def test_sequential_checkpoint_runtime_overrides_buildtime() -> None:
|
||||
assert len(buildtime_checkpoints) == 0, "Build-time storage should have no checkpoints when overridden"
|
||||
|
||||
|
||||
async def test_sequential_register_participants_with_checkpointing() -> None:
|
||||
"""Test that checkpointing works with register_participants."""
|
||||
storage = InMemoryCheckpointStorage()
|
||||
|
||||
def create_agent1() -> _EchoAgent:
|
||||
return _EchoAgent(id="agent1", name="A1")
|
||||
|
||||
def create_agent2() -> _EchoAgent:
|
||||
return _EchoAgent(id="agent2", name="A2")
|
||||
|
||||
wf = SequentialBuilder(participant_factories=[create_agent1, create_agent2], checkpoint_storage=storage).build()
|
||||
|
||||
baseline_output: list[ChatMessage] | None = None
|
||||
async for ev in wf.run("checkpoint with factories", stream=True):
|
||||
if ev.type == "output":
|
||||
baseline_output = ev.data
|
||||
if ev.type == "status" and ev.state == WorkflowRunState.IDLE:
|
||||
break
|
||||
|
||||
assert baseline_output is not None
|
||||
|
||||
checkpoints = await storage.list_checkpoints()
|
||||
assert checkpoints
|
||||
checkpoints.sort(key=lambda cp: cp.timestamp)
|
||||
|
||||
resume_checkpoint = next(
|
||||
(cp for cp in checkpoints if (cp.metadata or {}).get("checkpoint_type") == "superstep"),
|
||||
checkpoints[-1],
|
||||
)
|
||||
|
||||
wf_resume = SequentialBuilder(
|
||||
participant_factories=[create_agent1, create_agent2], checkpoint_storage=storage
|
||||
).build()
|
||||
|
||||
resumed_output: list[ChatMessage] | None = None
|
||||
async for ev in wf_resume.run(checkpoint_id=resume_checkpoint.checkpoint_id, stream=True):
|
||||
if ev.type == "output":
|
||||
resumed_output = ev.data
|
||||
if ev.type == "status" and ev.state in (
|
||||
WorkflowRunState.IDLE,
|
||||
WorkflowRunState.IDLE_WITH_PENDING_REQUESTS,
|
||||
):
|
||||
break
|
||||
|
||||
assert resumed_output is not None
|
||||
assert [m.role for m in resumed_output] == [m.role for m in baseline_output]
|
||||
assert [m.text for m in resumed_output] == [m.text for m in baseline_output]
|
||||
|
||||
|
||||
async def test_sequential_register_participants_factories_called_on_build() -> None:
|
||||
"""Test that factories are called during build(), not during register_participants()."""
|
||||
call_count = 0
|
||||
|
||||
def create_agent() -> _EchoAgent:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return _EchoAgent(id=f"agent{call_count}", name=f"A{call_count}")
|
||||
|
||||
builder = SequentialBuilder(participant_factories=[create_agent, create_agent])
|
||||
|
||||
# Factories should not be called yet
|
||||
assert call_count == 0
|
||||
|
||||
wf = builder.build()
|
||||
|
||||
# Now factories should have been called
|
||||
assert call_count == 2
|
||||
|
||||
# Run the workflow to ensure it works
|
||||
completed = False
|
||||
output: list[ChatMessage] | None = None
|
||||
async for ev in wf.run("test factories timing", stream=True):
|
||||
if ev.type == "status" and ev.state == WorkflowRunState.IDLE:
|
||||
completed = True
|
||||
elif ev.type == "output":
|
||||
output = ev.data # type: ignore[assignment]
|
||||
if completed and output is not None:
|
||||
break
|
||||
|
||||
assert completed
|
||||
assert output is not None
|
||||
msgs: list[ChatMessage] = output
|
||||
# Should have user message + 2 agent replies
|
||||
assert len(msgs) == 3
|
||||
|
||||
|
||||
async def test_sequential_builder_reusable_after_build_with_participants() -> None:
|
||||
"""Test that the builder can be reused to build multiple identical workflows with participants()."""
|
||||
a1 = _EchoAgent(id="agent1", name="A1")
|
||||
@@ -423,30 +259,3 @@ async def test_sequential_builder_reusable_after_build_with_participants() -> No
|
||||
|
||||
assert builder._participants[0] is a1 # type: ignore
|
||||
assert builder._participants[1] is a2 # type: ignore
|
||||
assert builder._participant_factories == [] # type: ignore
|
||||
|
||||
|
||||
async def test_sequential_builder_reusable_after_build_with_factories() -> None:
|
||||
"""Test that the builder can be reused to build multiple workflows with register_participants()."""
|
||||
call_count = 0
|
||||
|
||||
def create_agent1() -> _EchoAgent:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return _EchoAgent(id="agent1", name="A1")
|
||||
|
||||
def create_agent2() -> _EchoAgent:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return _EchoAgent(id="agent2", name="A2")
|
||||
|
||||
builder = SequentialBuilder(participant_factories=[create_agent1, create_agent2])
|
||||
|
||||
# Build first workflow - factories should be called
|
||||
builder.build()
|
||||
|
||||
assert call_count == 2
|
||||
assert builder._participants == [] # type: ignore
|
||||
assert len(builder._participant_factories) == 2 # type: ignore
|
||||
assert builder._participant_factories[0] is create_agent1 # type: ignore
|
||||
assert builder._participant_factories[1] is create_agent2 # type: ignore
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import (
|
||||
ChatAgent,
|
||||
ChatMessage,
|
||||
Executor,
|
||||
Workflow,
|
||||
WorkflowContext,
|
||||
handler,
|
||||
)
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
from agent_framework.orchestrations import ConcurrentBuilder
|
||||
from azure.identity import AzureCliCredential
|
||||
from typing_extensions import Never
|
||||
|
||||
"""
|
||||
Sample: Concurrent Orchestration with participant factories and Custom Aggregator
|
||||
|
||||
Build a concurrent workflow with ConcurrentBuilder that fans out one prompt to
|
||||
multiple domain agents and fans in their responses.
|
||||
|
||||
Override the default aggregator with a custom Executor class that uses
|
||||
AzureOpenAIChatClient.get_response() to synthesize a concise, consolidated summary
|
||||
from the experts' outputs.
|
||||
|
||||
All participants and the aggregator are created via factory functions that return
|
||||
their respective ChatAgent or Executor instances.
|
||||
|
||||
Using participant factories allows you to set up proper state isolation between workflow
|
||||
instances created by the same builder. This is particularly useful when you need to handle
|
||||
requests or tasks in parallel with stateful participants.
|
||||
|
||||
Demonstrates:
|
||||
- ConcurrentBuilder(participant_factories=[...]).with_aggregator(callback)
|
||||
- Fan-out to agents and fan-in at an aggregator
|
||||
- Aggregation implemented via an LLM call (chat_client.get_response)
|
||||
- Workflow output yielded with the synthesized summary string
|
||||
|
||||
Prerequisites:
|
||||
- Azure OpenAI configured for AzureOpenAIChatClient (az login + required env vars)
|
||||
"""
|
||||
|
||||
|
||||
def create_researcher() -> ChatAgent:
|
||||
"""Factory function to create a researcher agent instance."""
|
||||
return AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
|
||||
instructions=(
|
||||
"You're an expert market and product researcher. Given a prompt, provide concise, factual insights,"
|
||||
" opportunities, and risks."
|
||||
),
|
||||
name="researcher",
|
||||
)
|
||||
|
||||
|
||||
def create_marketer() -> ChatAgent:
|
||||
"""Factory function to create a marketer agent instance."""
|
||||
return AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
|
||||
instructions=(
|
||||
"You're a creative marketing strategist. Craft compelling value propositions and target messaging"
|
||||
" aligned to the prompt."
|
||||
),
|
||||
name="marketer",
|
||||
)
|
||||
|
||||
|
||||
def create_legal() -> ChatAgent:
|
||||
"""Factory function to create a legal/compliance agent instance."""
|
||||
return AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
|
||||
instructions=(
|
||||
"You're a cautious legal/compliance reviewer. Highlight constraints, disclaimers, and policy concerns"
|
||||
" based on the prompt."
|
||||
),
|
||||
name="legal",
|
||||
)
|
||||
|
||||
|
||||
class SummarizationExecutor(Executor):
|
||||
"""Custom aggregator executor that synthesizes expert outputs into a concise summary."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(id="summarization_executor")
|
||||
self.chat_client = AzureOpenAIChatClient(credential=AzureCliCredential())
|
||||
|
||||
@handler
|
||||
async def summarize_results(self, results: list[Any], ctx: WorkflowContext[Never, str]) -> None:
|
||||
expert_sections: list[str] = []
|
||||
for r in results:
|
||||
try:
|
||||
messages = getattr(r.agent_response, "messages", [])
|
||||
final_text = messages[-1].text if messages and hasattr(messages[-1], "text") else "(no content)"
|
||||
expert_sections.append(f"{getattr(r, 'executor_id', 'expert')}:\n{final_text}")
|
||||
except Exception as e:
|
||||
expert_sections.append(f"{getattr(r, 'executor_id', 'expert')}: (error: {type(e).__name__}: {e})")
|
||||
|
||||
# Ask the model to synthesize a concise summary of the experts' outputs
|
||||
system_msg = ChatMessage(
|
||||
"system",
|
||||
text=(
|
||||
"You are a helpful assistant that consolidates multiple domain expert outputs "
|
||||
"into one cohesive, concise summary with clear takeaways. Keep it under 200 words."
|
||||
),
|
||||
)
|
||||
user_msg = ChatMessage("user", text="\n\n".join(expert_sections))
|
||||
|
||||
response = await self.chat_client.get_response([system_msg, user_msg])
|
||||
|
||||
await ctx.yield_output(response.messages[-1].text if response.messages else "")
|
||||
|
||||
|
||||
async def run_workflow(workflow: Workflow, query: str) -> None:
|
||||
events = await workflow.run(query)
|
||||
outputs = events.get_outputs()
|
||||
|
||||
if outputs:
|
||||
print(outputs[0]) # Get the first (and typically only) output
|
||||
else:
|
||||
raise RuntimeError("No outputs received from the workflow.")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Create a concurrent builder with participant factories and a custom aggregator
|
||||
# - register_participants([...]) accepts factory functions that return
|
||||
# SupportsAgentRun (agents) or Executor instances.
|
||||
# - register_aggregator(...) takes a factory function that returns an Executor instance.
|
||||
concurrent_builder = (
|
||||
ConcurrentBuilder(participant_factories=[create_researcher, create_marketer, create_legal])
|
||||
.register_aggregator(SummarizationExecutor)
|
||||
)
|
||||
|
||||
# Build workflow_a
|
||||
workflow_a = concurrent_builder.build()
|
||||
|
||||
# Run workflow_a
|
||||
# Context is maintained across runs
|
||||
print("=== First Run on workflow_a ===")
|
||||
await run_workflow(workflow_a, "We are launching a new budget-friendly electric bike for urban commuters.")
|
||||
print("\n=== Second Run on workflow_a ===")
|
||||
await run_workflow(workflow_a, "Refine your response to focus on the California market.")
|
||||
|
||||
# Build workflow_b
|
||||
# This will create new instances of all participants and the aggregator
|
||||
# The agents will also get new threads
|
||||
workflow_b = concurrent_builder.build()
|
||||
# Run workflow_b
|
||||
# Context is not maintained across instances
|
||||
# Should not expect mentions of electric bikes in the results
|
||||
print("\n=== First Run on workflow_b ===")
|
||||
await run_workflow(workflow_b, "Refine your response to focus on the California market.")
|
||||
|
||||
"""
|
||||
Sample Output:
|
||||
|
||||
=== First Run on workflow_a ===
|
||||
The budget-friendly electric bike market is poised for significant growth, driven by urbanization, ...
|
||||
|
||||
=== Second Run on workflow_a ===
|
||||
Launching a budget-friendly electric bike in California presents significant opportunities, driven ...
|
||||
|
||||
=== First Run on workflow_b ===
|
||||
To successfully penetrate the California market, consider these tailored strategies focused on ...
|
||||
"""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,271 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Annotated, cast
|
||||
|
||||
from agent_framework import (
|
||||
AgentResponse,
|
||||
ChatAgent,
|
||||
ChatMessage,
|
||||
Workflow,
|
||||
WorkflowEvent,
|
||||
WorkflowRunState,
|
||||
tool,
|
||||
)
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
from agent_framework.orchestrations import HandoffAgentUserRequest, HandoffBuilder
|
||||
from azure.identity import AzureCliCredential
|
||||
|
||||
logging.basicConfig(level=logging.ERROR)
|
||||
|
||||
"""Sample: Handoff workflow with participant factories for state isolation.
|
||||
|
||||
This sample demonstrates how to use participant factories in HandoffBuilder to create
|
||||
agents dynamically.
|
||||
|
||||
Using participant factories allows you to set up proper state isolation between workflow
|
||||
instances created by the same builder. This is particularly useful when you need to handle
|
||||
requests or tasks in parallel with stateful participants.
|
||||
|
||||
Routing Pattern:
|
||||
User -> Triage Agent -> Specialist (Refund/Order Status/Return) -> User
|
||||
|
||||
Prerequisites:
|
||||
- `az login` (Azure CLI authentication)
|
||||
- Environment variables for AzureOpenAIChatClient (AZURE_OPENAI_ENDPOINT, etc.)
|
||||
|
||||
Key Concepts:
|
||||
- Participant factories: create agents via factory functions for isolation
|
||||
- State isolation: each workflow instance gets its own agent instances
|
||||
"""
|
||||
|
||||
|
||||
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production;
|
||||
# See:
|
||||
# samples/getting_started/tools/function_tool_with_approval.py
|
||||
# samples/getting_started/tools/function_tool_with_approval_and_threads.py.
|
||||
@tool(approval_mode="never_require")
|
||||
def process_refund(order_number: Annotated[str, "Order number to process refund for"]) -> str:
|
||||
"""Simulated function to process a refund for a given order number."""
|
||||
return f"Refund processed successfully for order {order_number}."
|
||||
|
||||
|
||||
@tool(approval_mode="never_require")
|
||||
def check_order_status(order_number: Annotated[str, "Order number to check status for"]) -> str:
|
||||
"""Simulated function to check the status of a given order number."""
|
||||
return f"Order {order_number} is currently being processed and will ship in 2 business days."
|
||||
|
||||
|
||||
@tool(approval_mode="never_require")
|
||||
def process_return(order_number: Annotated[str, "Order number to process return for"]) -> str:
|
||||
"""Simulated function to process a return for a given order number."""
|
||||
return f"Return initiated successfully for order {order_number}. You will receive return instructions via email."
|
||||
|
||||
|
||||
def create_triage_agent() -> ChatAgent:
|
||||
"""Factory function to create a triage agent instance."""
|
||||
return AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
|
||||
instructions=(
|
||||
"You are frontline support triage. Route customer issues to the appropriate specialist agents "
|
||||
"based on the problem described."
|
||||
),
|
||||
name="triage_agent",
|
||||
)
|
||||
|
||||
|
||||
def create_refund_agent() -> ChatAgent:
|
||||
"""Factory function to create a refund agent instance."""
|
||||
return AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
|
||||
instructions="You process refund requests.",
|
||||
name="refund_agent",
|
||||
# In a real application, an agent can have multiple tools; here we keep it simple
|
||||
tools=[process_refund],
|
||||
)
|
||||
|
||||
|
||||
def create_order_status_agent() -> ChatAgent:
|
||||
"""Factory function to create an order status agent instance."""
|
||||
return AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
|
||||
instructions="You handle order and shipping inquiries.",
|
||||
name="order_agent",
|
||||
# In a real application, an agent can have multiple tools; here we keep it simple
|
||||
tools=[check_order_status],
|
||||
)
|
||||
|
||||
|
||||
def create_return_agent() -> ChatAgent:
|
||||
"""Factory function to create a return agent instance."""
|
||||
return AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
|
||||
instructions="You manage product return requests.",
|
||||
name="return_agent",
|
||||
# In a real application, an agent can have multiple tools; here we keep it simple
|
||||
tools=[process_return],
|
||||
)
|
||||
|
||||
|
||||
def _handle_events(events: list[WorkflowEvent]) -> list[WorkflowEvent[HandoffAgentUserRequest]]:
|
||||
"""Process workflow events and extract any pending user input requests.
|
||||
|
||||
This function inspects each event type and:
|
||||
- Prints workflow status changes (IDLE, IDLE_WITH_PENDING_REQUESTS, etc.)
|
||||
- Displays final conversation snapshots when workflow completes
|
||||
- Prints user input request prompts
|
||||
- Collects all request_info events for response handling
|
||||
|
||||
Args:
|
||||
events: List of WorkflowEvent to process
|
||||
|
||||
Returns:
|
||||
List of WorkflowEvent[HandoffAgentUserRequest] representing pending user input requests
|
||||
"""
|
||||
requests: list[WorkflowEvent[HandoffAgentUserRequest]] = []
|
||||
|
||||
for event in events:
|
||||
if event.type == "handoff_sent":
|
||||
# handoff_sent event: Indicates a handoff has been initiated
|
||||
print(f"\n[Handoff from {event.data.source} to {event.data.target} initiated.]")
|
||||
elif event.type == "status" and event.state in {
|
||||
WorkflowRunState.IDLE,
|
||||
WorkflowRunState.IDLE_WITH_PENDING_REQUESTS,
|
||||
}:
|
||||
# Status event: Indicates workflow state changes
|
||||
print(f"\n[Workflow Status] {event.state.name}")
|
||||
elif event.type == "output":
|
||||
# Output event: Contains contents generated by the workflow
|
||||
data = event.data
|
||||
if isinstance(data, AgentResponse):
|
||||
for message in data.messages:
|
||||
if not message.text:
|
||||
# Skip messages without text (e.g., tool calls)
|
||||
continue
|
||||
speaker = message.author_name or message.role
|
||||
print(f"- {speaker}: {message.text}")
|
||||
elif event.type == "output":
|
||||
# The output of the handoff workflow is a collection of chat messages from all participants
|
||||
conversation = cast(list[ChatMessage], event.data)
|
||||
if isinstance(conversation, list):
|
||||
print("\n=== Final Conversation Snapshot ===")
|
||||
for message in conversation:
|
||||
speaker = message.author_name or message.role
|
||||
print(f"- {speaker}: {message.text or [content.type for content in message.contents]}")
|
||||
print("===================================")
|
||||
elif event.type == "request_info" and isinstance(event.data, HandoffAgentUserRequest):
|
||||
# Request info event: Workflow is requesting user input
|
||||
_print_handoff_agent_user_request(event.data.agent_response)
|
||||
requests.append(cast(WorkflowEvent[HandoffAgentUserRequest], event))
|
||||
|
||||
return requests
|
||||
|
||||
|
||||
def _print_handoff_agent_user_request(response: AgentResponse) -> None:
|
||||
"""Display the agent's response messages when requesting user input.
|
||||
|
||||
This will happen when an agent generates a response that doesn't trigger
|
||||
a handoff, i.e., the agent is asking the user for more information.
|
||||
|
||||
Args:
|
||||
response: The AgentResponse from the agent requesting user input
|
||||
"""
|
||||
if not response.messages:
|
||||
raise RuntimeError("Cannot print agent responses: response has no messages.")
|
||||
|
||||
print("\n[Agent is requesting your input...]")
|
||||
|
||||
# Print agent responses
|
||||
for message in response.messages:
|
||||
if not message.text:
|
||||
# Skip messages without text (e.g., tool calls)
|
||||
continue
|
||||
speaker = message.author_name or message.role
|
||||
print(f"- {speaker}: {message.text}")
|
||||
|
||||
|
||||
async def _run_workflow(workflow: Workflow, user_inputs: list[str]) -> None:
|
||||
"""Run the workflow with the given user input and display events."""
|
||||
print(f"- User: {user_inputs[0]}")
|
||||
workflow_result = await workflow.run(user_inputs[0])
|
||||
pending_requests = _handle_events(workflow_result)
|
||||
|
||||
# Process the request/response cycle
|
||||
# The workflow will continue requesting input until:
|
||||
# 1. The termination condition is met (4 user messages in this case), OR
|
||||
# 2. We run out of scripted responses
|
||||
while pending_requests:
|
||||
if user_inputs[1:]:
|
||||
# Get the next scripted response
|
||||
user_response = user_inputs.pop(1)
|
||||
print(f"\n- User: {user_response}")
|
||||
|
||||
# Send response(s) to all pending requests
|
||||
# In this demo, there's typically one request per cycle, but the API supports multiple
|
||||
responses = {
|
||||
req.request_id: HandoffAgentUserRequest.create_response(user_response) for req in pending_requests
|
||||
}
|
||||
else:
|
||||
# No more scripted responses; terminate the workflow
|
||||
responses = {req.request_id: HandoffAgentUserRequest.terminate() for req in pending_requests}
|
||||
|
||||
# Send responses and get new events
|
||||
# We use run(responses=...) to get events, allowing us to
|
||||
# display agent responses and handle new requests as they arrive
|
||||
workflow_result = await workflow.run(responses=responses)
|
||||
pending_requests = _handle_events(workflow_result)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Run the autonomous handoff workflow with participant factories."""
|
||||
# Build the handoff workflow using participant factories
|
||||
# termination_condition: Custom termination that checks if the triage agent has provided a closing message.
|
||||
# This looks for the last message being from triage_agent and containing "welcome",
|
||||
# which indicates the conversation has concluded naturally.
|
||||
workflow_builder = (
|
||||
HandoffBuilder(
|
||||
name="Autonomous Handoff with Participant Factories",
|
||||
participant_factories={
|
||||
"triage": create_triage_agent,
|
||||
"refund": create_refund_agent,
|
||||
"order_status": create_order_status_agent,
|
||||
"return": create_return_agent,
|
||||
},
|
||||
termination_condition=lambda conversation: (
|
||||
len(conversation) > 0
|
||||
and conversation[-1].author_name == "triage_agent"
|
||||
and "welcome" in conversation[-1].text.lower()
|
||||
),
|
||||
)
|
||||
.with_start_agent("triage")
|
||||
)
|
||||
|
||||
# Scripted user responses for reproducible demo
|
||||
# In a console application, replace this with:
|
||||
# user_input = input("Your response: ")
|
||||
# or integrate with a UI/chat interface
|
||||
user_inputs = [
|
||||
"Hello, I need assistance with my recent purchase.",
|
||||
"My order 1234 arrived damaged and the packaging was destroyed. I'd like to return it.",
|
||||
"Is my return being processed?",
|
||||
"Thanks for resolving this.",
|
||||
]
|
||||
|
||||
workflow_a = workflow_builder.build()
|
||||
print("=== Running workflow_a ===")
|
||||
await _run_workflow(workflow_a, list(user_inputs))
|
||||
|
||||
workflow_b = workflow_builder.build()
|
||||
print("=== Running workflow_b ===")
|
||||
# Only provide the last two inputs to workflow_b to demonstrate state isolation
|
||||
# The agents in this workflow have no prior context thus should not have knowledge of
|
||||
# order 1234 or previous interactions.
|
||||
await _run_workflow(workflow_b, user_inputs[2:])
|
||||
"""
|
||||
Expected behavior:
|
||||
- workflow_a and workflow_b maintain separate states for their participants.
|
||||
- Each workflow processes its requests independently without interference.
|
||||
- workflow_a will answer the follow-up request based on its own conversation history,
|
||||
while workflow_b will provide a general answer without prior context.
|
||||
"""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,126 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
|
||||
from agent_framework import (
|
||||
ChatAgent,
|
||||
ChatMessage,
|
||||
Executor,
|
||||
Workflow,
|
||||
WorkflowContext,
|
||||
handler,
|
||||
)
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
from agent_framework.orchestrations import SequentialBuilder
|
||||
from azure.identity import AzureCliCredential
|
||||
|
||||
"""
|
||||
Sample: Sequential workflow with participant factories
|
||||
|
||||
This sample demonstrates how to create a sequential workflow with participant factories.
|
||||
|
||||
Using participant factories allows you to set up proper state isolation between workflow
|
||||
instances created by the same builder. This is particularly useful when you need to handle
|
||||
requests or tasks in parallel with stateful participants.
|
||||
|
||||
In this example, we create a sequential workflow with two participants: an accumulator
|
||||
and a content producer. The accumulator is stateful and maintains a list of all messages it has
|
||||
received. Context is maintained across runs of the same workflow instance but not across different
|
||||
workflow instances.
|
||||
"""
|
||||
|
||||
|
||||
class Accumulate(Executor):
|
||||
"""Simple accumulator.
|
||||
|
||||
Accumulates all messages from the conversation and prints them out.
|
||||
"""
|
||||
|
||||
def __init__(self, id: str):
|
||||
super().__init__(id)
|
||||
# Some internal state to accumulate messages
|
||||
self._accumulated: list[str] = []
|
||||
|
||||
@handler
|
||||
async def accumulate(self, conversation: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None:
|
||||
self._accumulated.extend([msg.text for msg in conversation])
|
||||
print(f"Number of queries received so far: {len(self._accumulated)}")
|
||||
await ctx.send_message(conversation)
|
||||
|
||||
|
||||
def create_agent() -> ChatAgent:
|
||||
return AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
|
||||
instructions="Produce a concise paragraph answering the user's request.",
|
||||
name="ContentProducer",
|
||||
)
|
||||
|
||||
|
||||
async def run_workflow(workflow: Workflow, query: str) -> None:
|
||||
events = await workflow.run(query)
|
||||
outputs = events.get_outputs()
|
||||
|
||||
if outputs:
|
||||
messages: list[ChatMessage] = outputs[0]
|
||||
for message in messages:
|
||||
name = message.author_name or ("assistant" if message.role == "assistant" else "user")
|
||||
print(f"{name}: {message.text}")
|
||||
else:
|
||||
raise RuntimeError("No outputs received from the workflow.")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# 1) Create a builder with participant factories
|
||||
builder = SequentialBuilder(participant_factories=[
|
||||
lambda: Accumulate("accumulator"),
|
||||
create_agent,
|
||||
])
|
||||
# 2) Build workflow_a
|
||||
workflow_a = builder.build()
|
||||
|
||||
# 3) Run workflow_a
|
||||
# Context is maintained across runs
|
||||
print("=== First Run on workflow_a ===")
|
||||
await run_workflow(workflow_a, "Why is the sky blue?")
|
||||
print("\n=== Second Run on workflow_a ===")
|
||||
await run_workflow(workflow_a, "Repeat my previous question.")
|
||||
|
||||
# 4) Build workflow_b
|
||||
# This will create a new instance of the accumulator and content producer
|
||||
# using the same workflow builder
|
||||
workflow_b = builder.build()
|
||||
|
||||
# 5) Run workflow_b
|
||||
# Context is not maintained across instances
|
||||
print("\n=== First Run on workflow_b ===")
|
||||
await run_workflow(workflow_b, "Repeat my previous question.")
|
||||
|
||||
"""
|
||||
Sample Output:
|
||||
|
||||
=== First Run on workflow_a ===
|
||||
Number of queries received so far: 1
|
||||
user: Why is the sky blue?
|
||||
ContentProducer: The sky appears blue due to a phenomenon called Rayleigh scattering.
|
||||
When sunlight enters the Earth's atmosphere, it collides with gases
|
||||
and particles, scattering shorter wavelengths of light (blue and violet)
|
||||
more than the longer wavelengths (red and yellow). Although violet light
|
||||
is scattered even more than blue, our eyes are more sensitive to blue
|
||||
light, and some violet light is absorbed by the ozone layer. As a result,
|
||||
we perceive the sky as predominantly blue during the day.
|
||||
|
||||
=== Second Run on workflow_a ===
|
||||
Number of queries received so far: 2
|
||||
user: Repeat my previous question.
|
||||
ContentProducer: Why is the sky blue?
|
||||
|
||||
=== First Run on workflow_b ===
|
||||
Number of queries received so far: 1
|
||||
user: Repeat my previous question.
|
||||
ContentProducer: I'm sorry, but I can't repeat your previous question as I don't have
|
||||
access to your past queries. However, feel free to ask anything again,
|
||||
and I'll be happy to help!
|
||||
"""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -4,6 +4,7 @@ import asyncio
|
||||
|
||||
from agent_framework import (
|
||||
Executor,
|
||||
Workflow,
|
||||
WorkflowBuilder,
|
||||
WorkflowContext,
|
||||
executor,
|
||||
@@ -48,6 +49,11 @@ What this example shows
|
||||
- Fluent WorkflowBuilder API:
|
||||
add_edge(A, B) to connect nodes, set_start_executor(A), then build() -> Workflow.
|
||||
|
||||
- State isolation via helper functions:
|
||||
Wrapping executor instantiation and workflow building inside a function
|
||||
(e.g., create_workflow()) ensures each call produces fresh, independent
|
||||
instances. This is the recommended pattern for reuse.
|
||||
|
||||
- Running and results:
|
||||
workflow.run(initial_input) executes the graph. Terminal nodes yield
|
||||
outputs using ctx.yield_output(). The workflow runs until idle.
|
||||
@@ -152,18 +158,28 @@ class ExclamationAdder(Executor):
|
||||
await ctx.send_message(result) # type: ignore
|
||||
|
||||
|
||||
def create_workflow() -> Workflow:
|
||||
"""Create a fresh workflow with isolated state.
|
||||
|
||||
Wrapping workflow construction in a helper function ensures each call
|
||||
produces independent executor instances. This is the recommended pattern
|
||||
for reuse — call create_workflow() each time you need a new workflow so
|
||||
that no state leaks between runs.
|
||||
"""
|
||||
upper_case = UpperCase(id="upper_case_executor")
|
||||
|
||||
return WorkflowBuilder(start_executor=upper_case).add_edge(upper_case, reverse_text).build()
|
||||
|
||||
|
||||
async def main():
|
||||
"""Build and run workflows using the fluent builder API."""
|
||||
|
||||
# Workflow 1: Using introspection-based type detection
|
||||
# -----------------------------------------------------
|
||||
upper_case = UpperCase(id="upper_case_executor")
|
||||
|
||||
# Build the workflow using a fluent pattern:
|
||||
# 1) start_executor=... in constructor declares the entry point
|
||||
# 2) add_edge(from_node, to_node) defines a directed edge upper_case -> reverse_text
|
||||
# 3) build() finalizes and returns an immutable Workflow object
|
||||
workflow1 = WorkflowBuilder(start_executor=upper_case).add_edge(upper_case, reverse_text).build()
|
||||
# Workflow 1: Using the helper function pattern for state isolation
|
||||
# ------------------------------------------------------------------
|
||||
# Each call to create_workflow() returns a workflow with fresh executor
|
||||
# instances. This is the recommended pattern when you need to run the
|
||||
# same workflow topology multiple times with clean state.
|
||||
workflow1 = create_workflow()
|
||||
|
||||
# Run the workflow by sending the initial message to the start node.
|
||||
# The run(...) call returns an event collection; its get_outputs() method
|
||||
@@ -175,6 +191,7 @@ async def main():
|
||||
|
||||
# Workflow 2: Using explicit type parameters on @handler
|
||||
# -------------------------------------------------------
|
||||
upper_case = UpperCase(id="upper_case_executor")
|
||||
exclamation_adder = ExclamationAdder(id="exclamation_adder")
|
||||
|
||||
# This workflow demonstrates the explicit input/output feature:
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
|
||||
from agent_framework import (
|
||||
AgentResponseUpdate,
|
||||
ChatAgent,
|
||||
Executor,
|
||||
WorkflowBuilder,
|
||||
WorkflowContext,
|
||||
executor,
|
||||
handler,
|
||||
)
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
from azure.identity import AzureCliCredential
|
||||
|
||||
"""
|
||||
Step 4: Using Factories to Define Executors and Agents
|
||||
|
||||
What this example shows
|
||||
- Defining custom executors using both class-based and function-based approaches.
|
||||
- Registering executor and agent factories with WorkflowBuilder for lazy instantiation.
|
||||
- Building a simple workflow that transforms input text through multiple steps.
|
||||
|
||||
Benefits of using factories
|
||||
- Decouples executor and agent creation from workflow definition.
|
||||
- Isolated instances are created for workflow builder build, allowing for cleaner state management
|
||||
and handling parallel workflow runs.
|
||||
|
||||
It is recommended to use factories when defining executors and agents for production workflows.
|
||||
|
||||
Prerequisites
|
||||
- No external services required.
|
||||
"""
|
||||
|
||||
|
||||
class UpperCase(Executor):
|
||||
def __init__(self, id: str):
|
||||
super().__init__(id=id)
|
||||
|
||||
@handler
|
||||
async def to_upper_case(self, text: str, ctx: WorkflowContext[str]) -> None:
|
||||
"""Convert the input to uppercase and forward it to the next node."""
|
||||
result = text.upper()
|
||||
|
||||
# Send the result to the next executor in the workflow.
|
||||
await ctx.send_message(result)
|
||||
|
||||
|
||||
@executor(id="reverse_text_executor")
|
||||
async def reverse_text(text: str, ctx: WorkflowContext[str]) -> None:
|
||||
"""Reverse the input string and send it downstream."""
|
||||
result = text[::-1]
|
||||
|
||||
# Send the result to the next executor in the workflow.
|
||||
await ctx.send_message(result)
|
||||
|
||||
|
||||
def create_agent() -> ChatAgent:
|
||||
"""Factory function to create a Writer agent."""
|
||||
return AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
|
||||
instructions=("You decode messages. Try to reconstruct the original message."),
|
||||
name="decoder",
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
"""Build and run a simple 2-step workflow using the fluent builder API."""
|
||||
# Build the workflow using a fluent pattern:
|
||||
# 1) register_executor(factory, name) registers an executor factory
|
||||
# 2) register_agent(factory, name) registers an agent factory
|
||||
# 3) add_chain([node_names]) adds a sequence of nodes to the workflow
|
||||
# 4) set_start_executor(node) declares the entry point
|
||||
# 5) build() finalizes and returns an immutable Workflow object
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="UpperCase")
|
||||
.register_executor(lambda: UpperCase(id="upper_case_executor"), name="UpperCase")
|
||||
.register_executor(lambda: reverse_text, name="ReverseText")
|
||||
.register_agent(create_agent, name="DecoderAgent")
|
||||
.add_chain(["UpperCase", "ReverseText", "DecoderAgent"])
|
||||
.build()
|
||||
)
|
||||
|
||||
first_update = True
|
||||
async for event in workflow.run("hello world", stream=True):
|
||||
# The outputs of the workflow are whatever the agents produce. So the events are expected to
|
||||
# contain `AgentResponseUpdate` from the agents in the workflow.
|
||||
if event.type == "output" and isinstance(event.data, AgentResponseUpdate):
|
||||
update = event.data
|
||||
if first_update:
|
||||
print(f"{update.author_name}: {update.text}", end="", flush=True)
|
||||
first_update = False
|
||||
else:
|
||||
print(update.text, end="", flush=True)
|
||||
|
||||
"""
|
||||
Sample Output:
|
||||
|
||||
decoder: HELLO WORLD
|
||||
"""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
+6
-8
@@ -3,6 +3,7 @@
|
||||
import asyncio
|
||||
|
||||
from agent_framework import (
|
||||
AgentExecutor,
|
||||
AgentExecutorRequest,
|
||||
AgentExecutorResponse,
|
||||
ChatMessageStore,
|
||||
@@ -70,15 +71,12 @@ async def main() -> None:
|
||||
# Set the message store to store messages in memory.
|
||||
shared_thread.message_store = ChatMessageStore()
|
||||
|
||||
writer_executor = AgentExecutor(writer, agent_thread=shared_thread)
|
||||
reviewer_executor = AgentExecutor(reviewer, agent_thread=shared_thread)
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="writer")
|
||||
.register_agent(factory_func=lambda: writer, name="writer", agent_thread=shared_thread)
|
||||
.register_agent(factory_func=lambda: reviewer, name="reviewer", agent_thread=shared_thread)
|
||||
.register_executor(
|
||||
factory_func=lambda: intercept_agent_response,
|
||||
name="intercept_agent_response",
|
||||
)
|
||||
.add_chain(["writer", "intercept_agent_response", "reviewer"])
|
||||
WorkflowBuilder(start_executor=writer_executor)
|
||||
.add_chain([writer_executor, intercept_agent_response, reviewer_executor])
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
+14
-15
@@ -6,6 +6,7 @@ from dataclasses import dataclass, field
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import (
|
||||
AgentExecutor,
|
||||
AgentExecutorRequest,
|
||||
AgentExecutorResponse,
|
||||
AgentResponse,
|
||||
@@ -239,22 +240,20 @@ async def main() -> None:
|
||||
"""Run the workflow and bridge human feedback between two agents."""
|
||||
|
||||
# Build the workflow.
|
||||
writer_agent = AgentExecutor(create_writer_agent())
|
||||
final_editor_agent = AgentExecutor(create_final_editor_agent())
|
||||
coordinator = Coordinator(
|
||||
id="coordinator",
|
||||
writer_id="writer_agent",
|
||||
final_editor_id="final_editor_agent",
|
||||
)
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="writer_agent")
|
||||
.register_agent(create_writer_agent, name="writer_agent")
|
||||
.register_agent(create_final_editor_agent, name="final_editor_agent")
|
||||
.register_executor(
|
||||
lambda: Coordinator(
|
||||
id="coordinator",
|
||||
writer_id="writer_agent",
|
||||
final_editor_id="final_editor_agent",
|
||||
),
|
||||
name="coordinator",
|
||||
)
|
||||
.add_edge("writer_agent", "coordinator")
|
||||
.add_edge("coordinator", "writer_agent")
|
||||
.add_edge("final_editor_agent", "coordinator")
|
||||
.add_edge("coordinator", "final_editor_agent")
|
||||
WorkflowBuilder(start_executor=writer_agent)
|
||||
.add_edge(writer_agent, coordinator)
|
||||
.add_edge(coordinator, writer_agent)
|
||||
.add_edge(final_editor_agent, coordinator)
|
||||
.add_edge(coordinator, final_editor_agent)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
+9
-14
@@ -98,21 +98,16 @@ async def main() -> None:
|
||||
print("Building workflow with Worker-Reviewer cycle...")
|
||||
# Build a workflow with bidirectional communication between Worker and Reviewer,
|
||||
# and escalation paths for human review.
|
||||
worker = Worker(
|
||||
id="worker",
|
||||
chat_client=AzureOpenAIChatClient(credential=AzureCliCredential()),
|
||||
)
|
||||
reviewer = ReviewerWithHumanInTheLoop(worker_id="worker")
|
||||
|
||||
agent = (
|
||||
WorkflowBuilder(start_executor="worker")
|
||||
.register_executor(
|
||||
lambda: Worker(
|
||||
id="sub-worker",
|
||||
chat_client=AzureOpenAIChatClient(credential=AzureCliCredential()),
|
||||
),
|
||||
name="worker",
|
||||
)
|
||||
.register_executor(
|
||||
lambda: ReviewerWithHumanInTheLoop(worker_id="sub-worker"),
|
||||
name="reviewer",
|
||||
)
|
||||
.add_edge("worker", "reviewer") # Worker sends requests to Reviewer
|
||||
.add_edge("reviewer", "worker") # Reviewer sends feedback to Worker
|
||||
WorkflowBuilder(start_executor=worker)
|
||||
.add_edge(worker, reviewer) # Worker sends requests to Reviewer
|
||||
.add_edge(reviewer, worker) # Reviewer sends feedback to Worker
|
||||
.build()
|
||||
.as_agent() # Convert workflow into an agent interface
|
||||
)
|
||||
|
||||
+6
-11
@@ -186,18 +186,13 @@ async def main() -> None:
|
||||
print("=" * 50)
|
||||
|
||||
print("Building workflow with Worker ↔ Reviewer cycle...")
|
||||
worker = Worker(id="worker", chat_client=OpenAIChatClient(model_id="gpt-4.1-nano"))
|
||||
reviewer = Reviewer(id="reviewer", chat_client=OpenAIChatClient(model_id="gpt-4.1"))
|
||||
|
||||
agent = (
|
||||
WorkflowBuilder(start_executor="worker")
|
||||
.register_executor(
|
||||
lambda: Worker(id="worker", chat_client=OpenAIChatClient(model_id="gpt-4.1-nano")),
|
||||
name="worker",
|
||||
)
|
||||
.register_executor(
|
||||
lambda: Reviewer(id="reviewer", chat_client=OpenAIChatClient(model_id="gpt-4.1")),
|
||||
name="reviewer",
|
||||
)
|
||||
.add_edge("worker", "reviewer") # Worker sends responses to Reviewer
|
||||
.add_edge("reviewer", "worker") # Reviewer provides feedback to Worker
|
||||
WorkflowBuilder(start_executor=worker)
|
||||
.add_edge(worker, reviewer) # Worker sends responses to Reviewer
|
||||
.add_edge(reviewer, worker) # Reviewer provides feedback to Worker
|
||||
.build()
|
||||
.as_agent() # Wrap workflow as an agent
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
|
||||
from agent_framework import AgentThread, ChatAgent, ChatMessageStore
|
||||
from agent_framework import AgentThread, ChatMessageStore
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from agent_framework.orchestrations import SequentialBuilder
|
||||
|
||||
@@ -39,27 +39,24 @@ async def main() -> None:
|
||||
# Create a chat client
|
||||
chat_client = OpenAIChatClient()
|
||||
|
||||
# Define factory functions for workflow participants
|
||||
def create_assistant() -> ChatAgent:
|
||||
return chat_client.as_agent(
|
||||
name="assistant",
|
||||
instructions=(
|
||||
"You are a helpful assistant. Answer questions based on the conversation "
|
||||
"history. If the user asks about something mentioned earlier, reference it."
|
||||
),
|
||||
)
|
||||
assistant = chat_client.as_agent(
|
||||
name="assistant",
|
||||
instructions=(
|
||||
"You are a helpful assistant. Answer questions based on the conversation "
|
||||
"history. If the user asks about something mentioned earlier, reference it."
|
||||
),
|
||||
)
|
||||
|
||||
def create_summarizer() -> ChatAgent:
|
||||
return chat_client.as_agent(
|
||||
name="summarizer",
|
||||
instructions=(
|
||||
"You are a summarizer. After the assistant responds, provide a brief "
|
||||
"one-sentence summary of the key point from the conversation so far."
|
||||
),
|
||||
)
|
||||
summarizer = chat_client.as_agent(
|
||||
name="summarizer",
|
||||
instructions=(
|
||||
"You are a summarizer. After the assistant responds, provide a brief "
|
||||
"one-sentence summary of the key point from the conversation so far."
|
||||
),
|
||||
)
|
||||
|
||||
# Build a sequential workflow: assistant -> summarizer
|
||||
workflow = SequentialBuilder(participant_factories=[create_assistant, create_summarizer]).build()
|
||||
workflow = SequentialBuilder(participants=[assistant, summarizer]).build()
|
||||
|
||||
# Wrap the workflow as an agent
|
||||
agent = workflow.as_agent(name="ConversationalWorkflowAgent")
|
||||
@@ -124,13 +121,12 @@ async def demonstrate_thread_serialization() -> None:
|
||||
"""
|
||||
chat_client = OpenAIChatClient()
|
||||
|
||||
def create_assistant() -> ChatAgent:
|
||||
return chat_client.as_agent(
|
||||
name="memory_assistant",
|
||||
instructions="You are a helpful assistant with good memory. Remember details from our conversation.",
|
||||
)
|
||||
memory_assistant = chat_client.as_agent(
|
||||
name="memory_assistant",
|
||||
instructions="You are a helpful assistant with good memory. Remember details from our conversation.",
|
||||
)
|
||||
|
||||
workflow = SequentialBuilder(participant_factories=[create_assistant]).build()
|
||||
workflow = SequentialBuilder(participants=[memory_assistant]).build()
|
||||
agent = workflow.as_agent(name="MemoryWorkflowAgent")
|
||||
|
||||
# Create initial thread and have a conversation
|
||||
|
||||
+13
-14
@@ -17,6 +17,7 @@ else:
|
||||
# `agent_framework.builtin` chat client or mock the writer executor. We keep the
|
||||
# concrete import here so readers can see an end-to-end configuration.
|
||||
from agent_framework import (
|
||||
AgentExecutor,
|
||||
AgentExecutorRequest,
|
||||
AgentExecutorResponse,
|
||||
ChatMessage,
|
||||
@@ -178,23 +179,21 @@ def create_workflow(checkpoint_storage: FileCheckpointStorage) -> Workflow:
|
||||
# Wire the workflow DAG. Edges mirror the numbered steps described in the
|
||||
# module docstring. Because `WorkflowBuilder` is declarative, reading these
|
||||
# edges is often the quickest way to understand execution order.
|
||||
writer_agent = AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
|
||||
instructions="Write concise, warm release notes that sound human and helpful.",
|
||||
name="writer",
|
||||
)
|
||||
writer = AgentExecutor(writer_agent)
|
||||
review_gateway = ReviewGateway(id="review_gateway", writer_id="writer")
|
||||
prepare_brief = BriefPreparer(id="prepare_brief", agent_id="writer")
|
||||
|
||||
workflow_builder = (
|
||||
WorkflowBuilder(
|
||||
max_iterations=6, start_executor="prepare_brief", checkpoint_storage=checkpoint_storage
|
||||
max_iterations=6, start_executor=prepare_brief, checkpoint_storage=checkpoint_storage
|
||||
)
|
||||
.register_agent(
|
||||
lambda: AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
|
||||
instructions="Write concise, warm release notes that sound human and helpful.",
|
||||
# The agent name is stable across runs which keeps checkpoints deterministic.
|
||||
name="writer",
|
||||
),
|
||||
name="writer",
|
||||
)
|
||||
.register_executor(lambda: ReviewGateway(id="review_gateway", writer_id="writer"), name="review_gateway")
|
||||
.register_executor(lambda: BriefPreparer(id="prepare_brief", agent_id="writer"), name="prepare_brief")
|
||||
.add_edge("prepare_brief", "writer")
|
||||
.add_edge("writer", "review_gateway")
|
||||
.add_edge("review_gateway", "writer") # revisions loop
|
||||
.add_edge(prepare_brief, writer)
|
||||
.add_edge(writer, review_gateway)
|
||||
.add_edge(review_gateway, writer) # revisions loop
|
||||
)
|
||||
|
||||
return workflow_builder.build()
|
||||
|
||||
@@ -105,12 +105,12 @@ class WorkerExecutor(Executor):
|
||||
async def main():
|
||||
# Build workflow with checkpointing enabled
|
||||
checkpoint_storage = InMemoryCheckpointStorage()
|
||||
start = StartExecutor(id="start")
|
||||
worker = WorkerExecutor(id="worker")
|
||||
workflow_builder = (
|
||||
WorkflowBuilder(start_executor="start", checkpoint_storage=checkpoint_storage)
|
||||
.register_executor(lambda: StartExecutor(id="start"), name="start")
|
||||
.register_executor(lambda: WorkerExecutor(id="worker"), name="worker")
|
||||
.add_edge("start", "worker")
|
||||
.add_edge("worker", "worker") # Self-loop for iterative processing
|
||||
WorkflowBuilder(start_executor=start, checkpoint_storage=checkpoint_storage)
|
||||
.add_edge(start, worker)
|
||||
.add_edge(worker, worker) # Self-loop for iterative processing
|
||||
)
|
||||
|
||||
# Run workflow with automatic checkpoint recovery
|
||||
|
||||
@@ -297,14 +297,14 @@ class LaunchCoordinator(Executor):
|
||||
|
||||
def build_sub_workflow() -> WorkflowExecutor:
|
||||
"""Assemble the sub-workflow used by the parent workflow executor."""
|
||||
writer = DraftWriter()
|
||||
router = DraftReviewRouter()
|
||||
finaliser = DraftFinaliser()
|
||||
sub_workflow = (
|
||||
WorkflowBuilder(start_executor="writer")
|
||||
.register_executor(DraftWriter, name="writer")
|
||||
.register_executor(DraftReviewRouter, name="router")
|
||||
.register_executor(DraftFinaliser, name="finaliser")
|
||||
.add_edge("writer", "router")
|
||||
.add_edge("router", "finaliser")
|
||||
.add_edge("finaliser", "writer") # permits revision loops
|
||||
WorkflowBuilder(start_executor=writer)
|
||||
.add_edge(writer, router)
|
||||
.add_edge(router, finaliser)
|
||||
.add_edge(finaliser, writer) # permits revision loops
|
||||
.build()
|
||||
)
|
||||
|
||||
@@ -313,12 +313,12 @@ def build_sub_workflow() -> WorkflowExecutor:
|
||||
|
||||
def build_parent_workflow(storage: FileCheckpointStorage) -> Workflow:
|
||||
"""Assemble the parent workflow that embeds the sub-workflow."""
|
||||
coordinator = LaunchCoordinator()
|
||||
sub_executor = build_sub_workflow()
|
||||
return (
|
||||
WorkflowBuilder(start_executor="coordinator", checkpoint_storage=storage)
|
||||
.register_executor(LaunchCoordinator, name="coordinator")
|
||||
.register_executor(build_sub_workflow, name="sub_executor")
|
||||
.add_edge("coordinator", "sub_executor")
|
||||
.add_edge("sub_executor", "coordinator")
|
||||
WorkflowBuilder(start_executor=coordinator, checkpoint_storage=storage)
|
||||
.add_edge(coordinator, sub_executor)
|
||||
.add_edge(sub_executor, coordinator)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
+19
-25
@@ -27,7 +27,6 @@ import asyncio
|
||||
|
||||
from agent_framework import (
|
||||
AgentThread,
|
||||
ChatAgent,
|
||||
ChatMessageStore,
|
||||
InMemoryCheckpointStorage,
|
||||
)
|
||||
@@ -43,20 +42,17 @@ async def basic_checkpointing() -> None:
|
||||
|
||||
chat_client = OpenAIChatClient()
|
||||
|
||||
def create_assistant() -> ChatAgent:
|
||||
return chat_client.as_agent(
|
||||
name="assistant",
|
||||
instructions="You are a helpful assistant. Keep responses brief.",
|
||||
)
|
||||
assistant = chat_client.as_agent(
|
||||
name="assistant",
|
||||
instructions="You are a helpful assistant. Keep responses brief.",
|
||||
)
|
||||
|
||||
def create_reviewer() -> ChatAgent:
|
||||
return chat_client.as_agent(
|
||||
name="reviewer",
|
||||
instructions="You are a reviewer. Provide a one-sentence summary of the assistant's response.",
|
||||
)
|
||||
reviewer = chat_client.as_agent(
|
||||
name="reviewer",
|
||||
instructions="You are a reviewer. Provide a one-sentence summary of the assistant's response.",
|
||||
)
|
||||
|
||||
# Build sequential workflow with participant factories
|
||||
workflow = SequentialBuilder(participant_factories=[create_assistant, create_reviewer]).build()
|
||||
workflow = SequentialBuilder(participants=[assistant, reviewer]).build()
|
||||
agent = workflow.as_agent(name="CheckpointedAgent")
|
||||
|
||||
# Create checkpoint storage
|
||||
@@ -87,13 +83,12 @@ async def checkpointing_with_thread() -> None:
|
||||
|
||||
chat_client = OpenAIChatClient()
|
||||
|
||||
def create_assistant() -> ChatAgent:
|
||||
return chat_client.as_agent(
|
||||
name="memory_assistant",
|
||||
instructions="You are a helpful assistant with good memory. Reference previous conversation when relevant.",
|
||||
)
|
||||
assistant = chat_client.as_agent(
|
||||
name="memory_assistant",
|
||||
instructions="You are a helpful assistant with good memory. Reference previous conversation when relevant.",
|
||||
)
|
||||
|
||||
workflow = SequentialBuilder(participant_factories=[create_assistant]).build()
|
||||
workflow = SequentialBuilder(participants=[assistant]).build()
|
||||
agent = workflow.as_agent(name="MemoryAgent")
|
||||
|
||||
# Create both thread (for conversation) and checkpoint storage (for workflow state)
|
||||
@@ -131,13 +126,12 @@ async def streaming_with_checkpoints() -> None:
|
||||
|
||||
chat_client = OpenAIChatClient()
|
||||
|
||||
def create_assistant() -> ChatAgent:
|
||||
return chat_client.as_agent(
|
||||
name="streaming_assistant",
|
||||
instructions="You are a helpful assistant.",
|
||||
)
|
||||
assistant = chat_client.as_agent(
|
||||
name="streaming_assistant",
|
||||
instructions="You are a helpful assistant.",
|
||||
)
|
||||
|
||||
workflow = SequentialBuilder(participant_factories=[create_assistant]).build()
|
||||
workflow = SequentialBuilder(participants=[assistant]).build()
|
||||
agent = workflow.as_agent(name="StreamingCheckpointAgent")
|
||||
|
||||
checkpoint_storage = InMemoryCheckpointStorage()
|
||||
|
||||
@@ -140,9 +140,9 @@ def create_sub_workflow() -> WorkflowExecutor:
|
||||
"""Create the text processing sub-workflow."""
|
||||
print("🚀 Setting up sub-workflow...")
|
||||
|
||||
text_processor = TextProcessor()
|
||||
processing_workflow = (
|
||||
WorkflowBuilder(start_executor="text_processor")
|
||||
.register_executor(TextProcessor, name="text_processor")
|
||||
WorkflowBuilder(start_executor=text_processor)
|
||||
.build()
|
||||
)
|
||||
|
||||
@@ -153,12 +153,12 @@ async def main():
|
||||
"""Main function to run the basic sub-workflow example."""
|
||||
print("🔧 Setting up parent workflow...")
|
||||
# Step 1: Create the parent workflow
|
||||
orchestrator = TextProcessingOrchestrator()
|
||||
sub_workflow_executor = create_sub_workflow()
|
||||
main_workflow = (
|
||||
WorkflowBuilder(start_executor="text_orchestrator")
|
||||
.register_executor(TextProcessingOrchestrator, name="text_orchestrator")
|
||||
.register_executor(create_sub_workflow, name="text_processor_workflow")
|
||||
.add_edge("text_orchestrator", "text_processor_workflow")
|
||||
.add_edge("text_processor_workflow", "text_orchestrator")
|
||||
WorkflowBuilder(start_executor=orchestrator)
|
||||
.add_edge(orchestrator, sub_workflow_executor)
|
||||
.add_edge(sub_workflow_executor, orchestrator)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
+26
-28
@@ -169,17 +169,18 @@ def build_resource_request_distribution_workflow() -> Workflow:
|
||||
elif len(self._responses) > self._request_count:
|
||||
raise ValueError("Received more responses than expected")
|
||||
|
||||
orchestrator = RequestDistribution("orchestrator")
|
||||
resource_requester = ResourceRequester("resource_requester")
|
||||
policy_checker = PolicyChecker("policy_checker")
|
||||
result_collector = ResultCollector("result_collector")
|
||||
|
||||
return (
|
||||
WorkflowBuilder(start_executor="orchestrator")
|
||||
.register_executor(lambda: RequestDistribution("orchestrator"), name="orchestrator")
|
||||
.register_executor(lambda: ResourceRequester("resource_requester"), name="resource_requester")
|
||||
.register_executor(lambda: PolicyChecker("policy_checker"), name="policy_checker")
|
||||
.register_executor(lambda: ResultCollector("result_collector"), name="result_collector")
|
||||
.add_edge("orchestrator", "resource_requester")
|
||||
.add_edge("orchestrator", "policy_checker")
|
||||
.add_edge("resource_requester", "result_collector")
|
||||
.add_edge("policy_checker", "result_collector")
|
||||
.add_edge("orchestrator", "result_collector") # For request count
|
||||
WorkflowBuilder(start_executor=orchestrator)
|
||||
.add_edge(orchestrator, resource_requester)
|
||||
.add_edge(orchestrator, policy_checker)
|
||||
.add_edge(resource_requester, result_collector)
|
||||
.add_edge(policy_checker, result_collector)
|
||||
.add_edge(orchestrator, result_collector) # For request count
|
||||
.build()
|
||||
)
|
||||
|
||||
@@ -287,25 +288,22 @@ class PolicyEngine(Executor):
|
||||
|
||||
async def main() -> None:
|
||||
# Build the main workflow
|
||||
resource_allocator = ResourceAllocator("resource_allocator")
|
||||
policy_engine = PolicyEngine("policy_engine")
|
||||
sub_workflow_executor = WorkflowExecutor(
|
||||
build_resource_request_distribution_workflow(),
|
||||
"sub_workflow_executor",
|
||||
# Setting allow_direct_output=True to let the sub-workflow output directly.
|
||||
# This is because the sub-workflow is the both the entry point and the exit
|
||||
# point of the main workflow.
|
||||
allow_direct_output=True,
|
||||
)
|
||||
main_workflow = (
|
||||
WorkflowBuilder(start_executor="sub_workflow_executor")
|
||||
.register_executor(lambda: ResourceAllocator("resource_allocator"), name="resource_allocator")
|
||||
.register_executor(lambda: PolicyEngine("policy_engine"), name="policy_engine")
|
||||
.register_executor(
|
||||
lambda: WorkflowExecutor(
|
||||
build_resource_request_distribution_workflow(),
|
||||
"sub_workflow_executor",
|
||||
# Setting allow_direct_output=True to let the sub-workflow output directly.
|
||||
# This is because the sub-workflow is the both the entry point and the exit
|
||||
# point of the main workflow.
|
||||
allow_direct_output=True,
|
||||
),
|
||||
name="sub_workflow_executor",
|
||||
)
|
||||
.add_edge("sub_workflow_executor", "resource_allocator")
|
||||
.add_edge("resource_allocator", "sub_workflow_executor")
|
||||
.add_edge("sub_workflow_executor", "policy_engine")
|
||||
.add_edge("policy_engine", "sub_workflow_executor")
|
||||
WorkflowBuilder(start_executor=sub_workflow_executor)
|
||||
.add_edge(sub_workflow_executor, resource_allocator)
|
||||
.add_edge(resource_allocator, sub_workflow_executor)
|
||||
.add_edge(sub_workflow_executor, policy_engine)
|
||||
.add_edge(policy_engine, sub_workflow_executor)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
+15
-19
@@ -153,13 +153,14 @@ def build_email_address_validation_workflow() -> Workflow:
|
||||
)
|
||||
|
||||
# Build the workflow
|
||||
email_sanitizer = EmailSanitizer(id="email_sanitizer")
|
||||
email_format_validator = EmailFormatValidator(id="email_format_validator")
|
||||
domain_validator = DomainValidator(id="domain_validator")
|
||||
|
||||
return (
|
||||
WorkflowBuilder(start_executor="email_sanitizer")
|
||||
.register_executor(lambda: EmailSanitizer(id="email_sanitizer"), name="email_sanitizer")
|
||||
.register_executor(lambda: EmailFormatValidator(id="email_format_validator"), name="email_format_validator")
|
||||
.register_executor(lambda: DomainValidator(id="domain_validator"), name="domain_validator")
|
||||
.add_edge("email_sanitizer", "email_format_validator")
|
||||
.add_edge("email_format_validator", "domain_validator")
|
||||
WorkflowBuilder(start_executor=email_sanitizer)
|
||||
.add_edge(email_sanitizer, email_format_validator)
|
||||
.add_edge(email_format_validator, domain_validator)
|
||||
.build()
|
||||
)
|
||||
|
||||
@@ -268,20 +269,15 @@ async def main() -> None:
|
||||
approved_domains = {"example.com", "company.com"}
|
||||
|
||||
# Build the main workflow
|
||||
smart_email_orchestrator = SmartEmailOrchestrator(id="smart_email_orchestrator", approved_domains=approved_domains)
|
||||
email_delivery = EmailDelivery(id="email_delivery")
|
||||
email_validation_workflow = WorkflowExecutor(build_email_address_validation_workflow(), id="email_validation_workflow")
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="smart_email_orchestrator")
|
||||
.register_executor(
|
||||
lambda: SmartEmailOrchestrator(id="smart_email_orchestrator", approved_domains=approved_domains),
|
||||
name="smart_email_orchestrator",
|
||||
)
|
||||
.register_executor(lambda: EmailDelivery(id="email_delivery"), name="email_delivery")
|
||||
.register_executor(
|
||||
lambda: WorkflowExecutor(build_email_address_validation_workflow(), id="email_validation_workflow"),
|
||||
name="email_validation_workflow",
|
||||
)
|
||||
.add_edge("smart_email_orchestrator", "email_validation_workflow")
|
||||
.add_edge("email_validation_workflow", "smart_email_orchestrator")
|
||||
.add_edge("smart_email_orchestrator", "email_delivery")
|
||||
WorkflowBuilder(start_executor=smart_email_orchestrator)
|
||||
.add_edge(smart_email_orchestrator, email_validation_workflow)
|
||||
.add_edge(email_validation_workflow, smart_email_orchestrator)
|
||||
.add_edge(smart_email_orchestrator, email_delivery)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import os
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import ( # Core chat primitives used to build requests
|
||||
AgentExecutor,
|
||||
AgentExecutorRequest, # Input message bundle for an AgentExecutor
|
||||
AgentExecutorResponse,
|
||||
ChatAgent, # Output from an AgentExecutor
|
||||
@@ -161,19 +162,17 @@ async def main() -> None:
|
||||
# If not spam, hop to a transformer that creates a new AgentExecutorRequest,
|
||||
# then call the email assistant, then finalize.
|
||||
# If spam, go directly to the spam handler and finalize.
|
||||
spam_detection_agent = AgentExecutor(create_spam_detector_agent())
|
||||
email_assistant_agent = AgentExecutor(create_email_assistant_agent())
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="spam_detection_agent")
|
||||
.register_agent(create_spam_detector_agent, name="spam_detection_agent")
|
||||
.register_agent(create_email_assistant_agent, name="email_assistant_agent")
|
||||
.register_executor(lambda: to_email_assistant_request, name="to_email_assistant_request")
|
||||
.register_executor(lambda: handle_email_response, name="send_email")
|
||||
.register_executor(lambda: handle_spam_classifier_response, name="handle_spam")
|
||||
WorkflowBuilder(start_executor=spam_detection_agent)
|
||||
# Not spam path: transform response -> request for assistant -> assistant -> send email
|
||||
.add_edge("spam_detection_agent", "to_email_assistant_request", condition=get_condition(False))
|
||||
.add_edge("to_email_assistant_request", "email_assistant_agent")
|
||||
.add_edge("email_assistant_agent", "send_email")
|
||||
.add_edge(spam_detection_agent, to_email_assistant_request, condition=get_condition(False))
|
||||
.add_edge(to_email_assistant_request, email_assistant_agent)
|
||||
.add_edge(email_assistant_agent, handle_email_response)
|
||||
# Spam path: send to spam handler
|
||||
.add_edge("spam_detection_agent", "handle_spam", condition=get_condition(True))
|
||||
.add_edge(spam_detection_agent, handle_spam_classifier_response, condition=get_condition(True))
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
+16
-27
@@ -9,6 +9,7 @@ from typing import Literal
|
||||
from uuid import uuid4
|
||||
|
||||
from agent_framework import (
|
||||
AgentExecutor,
|
||||
AgentExecutorRequest,
|
||||
AgentExecutorResponse,
|
||||
ChatAgent,
|
||||
@@ -212,6 +213,10 @@ def create_email_summary_agent() -> ChatAgent:
|
||||
|
||||
async def main() -> None:
|
||||
# Build the workflow
|
||||
email_analysis_agent = AgentExecutor(create_email_analysis_agent())
|
||||
email_assistant_agent = AgentExecutor(create_email_assistant_agent())
|
||||
email_summary_agent = AgentExecutor(create_email_summary_agent())
|
||||
|
||||
def select_targets(analysis: AnalysisResult, target_ids: list[str]) -> list[str]:
|
||||
# Order: [handle_spam, submit_to_email_assistant, summarize_email, handle_uncertain]
|
||||
handle_spam_id, submit_to_email_assistant_id, summarize_email_id, handle_uncertain_id = target_ids
|
||||
@@ -224,39 +229,23 @@ async def main() -> None:
|
||||
return targets
|
||||
return [handle_uncertain_id]
|
||||
|
||||
workflow_builder = (
|
||||
WorkflowBuilder(start_executor="store_email")
|
||||
.register_agent(create_email_analysis_agent, name="email_analysis_agent")
|
||||
.register_agent(create_email_assistant_agent, name="email_assistant_agent")
|
||||
.register_agent(create_email_summary_agent, name="email_summary_agent")
|
||||
.register_executor(lambda: store_email, name="store_email")
|
||||
.register_executor(lambda: to_analysis_result, name="to_analysis_result")
|
||||
.register_executor(lambda: submit_to_email_assistant, name="submit_to_email_assistant")
|
||||
.register_executor(lambda: finalize_and_send, name="finalize_and_send")
|
||||
.register_executor(lambda: summarize_email, name="summarize_email")
|
||||
.register_executor(lambda: merge_summary, name="merge_summary")
|
||||
.register_executor(lambda: handle_spam, name="handle_spam")
|
||||
.register_executor(lambda: handle_uncertain, name="handle_uncertain")
|
||||
.register_executor(lambda: database_access, name="database_access")
|
||||
)
|
||||
|
||||
workflow = (
|
||||
workflow_builder
|
||||
.add_edge("store_email", "email_analysis_agent")
|
||||
.add_edge("email_analysis_agent", "to_analysis_result")
|
||||
WorkflowBuilder(start_executor=store_email)
|
||||
.add_edge(store_email, email_analysis_agent)
|
||||
.add_edge(email_analysis_agent, to_analysis_result)
|
||||
.add_multi_selection_edge_group(
|
||||
"to_analysis_result",
|
||||
["handle_spam", "submit_to_email_assistant", "summarize_email", "handle_uncertain"],
|
||||
to_analysis_result,
|
||||
[handle_spam, submit_to_email_assistant, summarize_email, handle_uncertain],
|
||||
selection_func=select_targets,
|
||||
)
|
||||
.add_edge("submit_to_email_assistant", "email_assistant_agent")
|
||||
.add_edge("email_assistant_agent", "finalize_and_send")
|
||||
.add_edge("summarize_email", "email_summary_agent")
|
||||
.add_edge("email_summary_agent", "merge_summary")
|
||||
.add_edge(submit_to_email_assistant, email_assistant_agent)
|
||||
.add_edge(email_assistant_agent, finalize_and_send)
|
||||
.add_edge(summarize_email, email_summary_agent)
|
||||
.add_edge(email_summary_agent, merge_summary)
|
||||
# Save to DB if short (no summary path)
|
||||
.add_edge("to_analysis_result", "database_access", condition=lambda r: r.email_length <= LONG_EMAIL_THRESHOLD)
|
||||
.add_edge(to_analysis_result, database_access, condition=lambda r: r.email_length <= LONG_EMAIL_THRESHOLD)
|
||||
# Save to DB with summary when long
|
||||
.add_edge("merge_summary", "database_access")
|
||||
.add_edge(merge_summary, database_access)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
@@ -62,11 +62,12 @@ async def main() -> None:
|
||||
"""Build a two step sequential workflow and run it with streaming to observe events."""
|
||||
# Step 1: Build the workflow graph.
|
||||
# Order matters. We connect upper_case_executor -> reverse_text_executor and set the start.
|
||||
upper_case_executor = UpperCaseExecutor(id="upper_case_executor")
|
||||
reverse_text_executor = ReverseTextExecutor(id="reverse_text_executor")
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="upper_case_executor")
|
||||
.register_executor(lambda: UpperCaseExecutor(id="upper_case_executor"), name="upper_case_executor")
|
||||
.register_executor(lambda: ReverseTextExecutor(id="reverse_text_executor"), name="reverse_text_executor")
|
||||
.add_edge("upper_case_executor", "reverse_text_executor")
|
||||
WorkflowBuilder(start_executor=upper_case_executor)
|
||||
.add_edge(upper_case_executor, reverse_text_executor)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
@@ -56,10 +56,8 @@ async def main():
|
||||
# Step 1: Build the workflow with the defined edges.
|
||||
# Order matters. upper_case_executor runs first, then reverse_text_executor.
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="upper_case_executor")
|
||||
.register_executor(lambda: to_upper_case, name="upper_case_executor")
|
||||
.register_executor(lambda: reverse_text, name="reverse_text_executor")
|
||||
.add_edge("upper_case_executor", "reverse_text_executor")
|
||||
WorkflowBuilder(start_executor=to_upper_case)
|
||||
.add_edge(to_upper_case, reverse_text)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import asyncio
|
||||
from enum import Enum
|
||||
|
||||
from agent_framework import (
|
||||
AgentExecutor,
|
||||
AgentExecutorRequest,
|
||||
AgentExecutorResponse,
|
||||
ChatAgent,
|
||||
@@ -125,16 +126,17 @@ async def main():
|
||||
"""Main function to run the workflow."""
|
||||
# Step 1: Build the workflow with the defined edges.
|
||||
# This time we are creating a loop in the workflow.
|
||||
guess_number = GuessNumberExecutor((1, 100), "guess_number")
|
||||
judge_agent = AgentExecutor(create_judge_agent())
|
||||
submit_judge = SubmitToJudgeAgent(judge_agent_id="judge_agent", target=30)
|
||||
parse_judge = ParseJudgeResponse(id="parse_judge")
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="guess_number")
|
||||
.register_executor(lambda: GuessNumberExecutor((1, 100), "guess_number"), name="guess_number")
|
||||
.register_agent(create_judge_agent, name="judge_agent")
|
||||
.register_executor(lambda: SubmitToJudgeAgent(judge_agent_id="judge_agent", target=30), name="submit_judge")
|
||||
.register_executor(lambda: ParseJudgeResponse(id="parse_judge"), name="parse_judge")
|
||||
.add_edge("guess_number", "submit_judge")
|
||||
.add_edge("submit_judge", "judge_agent")
|
||||
.add_edge("judge_agent", "parse_judge")
|
||||
.add_edge("parse_judge", "guess_number")
|
||||
WorkflowBuilder(start_executor=guess_number)
|
||||
.add_edge(guess_number, submit_judge)
|
||||
.add_edge(submit_judge, judge_agent)
|
||||
.add_edge(judge_agent, parse_judge)
|
||||
.add_edge(parse_judge, guess_number)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Any, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
from agent_framework import ( # Core chat primitives used to form LLM requests
|
||||
AgentExecutor,
|
||||
AgentExecutorRequest, # Message bundle sent to an AgentExecutor
|
||||
AgentExecutorResponse, # Result returned by an AgentExecutor
|
||||
Case,
|
||||
@@ -178,28 +179,23 @@ async def main():
|
||||
"""Main function to run the workflow."""
|
||||
# Build workflow: store -> detection agent -> to_detection_result -> switch (NotSpam or Spam or Default).
|
||||
# The switch-case group evaluates cases in order, then falls back to Default when none match.
|
||||
spam_detection_agent = AgentExecutor(create_spam_detection_agent())
|
||||
email_assistant_agent = AgentExecutor(create_email_assistant_agent())
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="store_email")
|
||||
.register_agent(create_spam_detection_agent, name="spam_detection_agent")
|
||||
.register_agent(create_email_assistant_agent, name="email_assistant_agent")
|
||||
.register_executor(lambda: store_email, name="store_email")
|
||||
.register_executor(lambda: to_detection_result, name="to_detection_result")
|
||||
.register_executor(lambda: submit_to_email_assistant, name="submit_to_email_assistant")
|
||||
.register_executor(lambda: finalize_and_send, name="finalize_and_send")
|
||||
.register_executor(lambda: handle_spam, name="handle_spam")
|
||||
.register_executor(lambda: handle_uncertain, name="handle_uncertain")
|
||||
.add_edge("store_email", "spam_detection_agent")
|
||||
.add_edge("spam_detection_agent", "to_detection_result")
|
||||
WorkflowBuilder(start_executor=store_email)
|
||||
.add_edge(store_email, spam_detection_agent)
|
||||
.add_edge(spam_detection_agent, to_detection_result)
|
||||
.add_switch_case_edge_group(
|
||||
"to_detection_result",
|
||||
to_detection_result,
|
||||
[
|
||||
Case(condition=get_case("NotSpam"), target="submit_to_email_assistant"),
|
||||
Case(condition=get_case("Spam"), target="handle_spam"),
|
||||
Default(target="handle_uncertain"),
|
||||
Case(condition=get_case("NotSpam"), target=submit_to_email_assistant),
|
||||
Case(condition=get_case("Spam"), target=handle_spam),
|
||||
Default(target=handle_uncertain),
|
||||
],
|
||||
)
|
||||
.add_edge("submit_to_email_assistant", "email_assistant_agent")
|
||||
.add_edge("email_assistant_agent", "finalize_and_send")
|
||||
.add_edge(submit_to_email_assistant, email_assistant_agent)
|
||||
.add_edge(email_assistant_agent, finalize_and_send)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
@@ -51,12 +51,9 @@ async def step3(text: str, ctx: WorkflowContext[Never, str]) -> None:
|
||||
def build_workflow():
|
||||
"""Build a simple 3-step sequential workflow (~6 seconds total)."""
|
||||
return (
|
||||
WorkflowBuilder(start_executor="step1")
|
||||
.register_executor(lambda: step1, name="step1")
|
||||
.register_executor(lambda: step2, name="step2")
|
||||
.register_executor(lambda: step3, name="step3")
|
||||
.add_edge("step1", "step2")
|
||||
.add_edge("step2", "step3")
|
||||
WorkflowBuilder(start_executor=step1)
|
||||
.add_edge(step1, step2)
|
||||
.add_edge(step2, step3)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
+8
-7
@@ -72,14 +72,15 @@ class Aggregator(Executor):
|
||||
|
||||
async def main() -> None:
|
||||
# 1) Build a simple fan out and fan in workflow
|
||||
dispatcher = Dispatcher(id="dispatcher")
|
||||
average = Average(id="average")
|
||||
summation = Sum(id="summation")
|
||||
aggregator = Aggregator(id="aggregator")
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="dispatcher")
|
||||
.register_executor(lambda: Dispatcher(id="dispatcher"), name="dispatcher")
|
||||
.register_executor(lambda: Average(id="average"), name="average")
|
||||
.register_executor(lambda: Sum(id="summation"), name="summation")
|
||||
.register_executor(lambda: Aggregator(id="aggregator"), name="aggregator")
|
||||
.add_fan_out_edges("dispatcher", ["average", "summation"])
|
||||
.add_fan_in_edges(["average", "summation"], "aggregator")
|
||||
WorkflowBuilder(start_executor=dispatcher)
|
||||
.add_fan_out_edges(dispatcher, [average, summation])
|
||||
.add_fan_in_edges([average, summation], aggregator)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
@@ -4,9 +4,9 @@ import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agent_framework import (
|
||||
AgentExecutor, # Wraps a ChatAgent as an Executor for use in workflows
|
||||
AgentExecutorRequest, # The message bundle sent to an AgentExecutor
|
||||
AgentExecutorResponse, # The structured result returned by an AgentExecutor
|
||||
ChatAgent, # Tracing event for agent execution steps
|
||||
ChatMessage, # Chat message structure
|
||||
Executor, # Base class for custom Python executors
|
||||
WorkflowBuilder, # Fluent builder for wiring the workflow graph
|
||||
@@ -87,50 +87,44 @@ class AggregateInsights(Executor):
|
||||
await ctx.yield_output(consolidated)
|
||||
|
||||
|
||||
def create_researcher_agent() -> ChatAgent:
|
||||
"""Creates a research domain expert agent."""
|
||||
return AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
|
||||
instructions=(
|
||||
"You're an expert market and product researcher. Given a prompt, provide concise, factual insights,"
|
||||
" opportunities, and risks."
|
||||
),
|
||||
name="researcher",
|
||||
)
|
||||
|
||||
|
||||
def create_marketer_agent() -> ChatAgent:
|
||||
"""Creates a marketing domain expert agent."""
|
||||
return AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
|
||||
instructions=(
|
||||
"You're a creative marketing strategist. Craft compelling value propositions and target messaging"
|
||||
" aligned to the prompt."
|
||||
),
|
||||
name="marketer",
|
||||
)
|
||||
|
||||
|
||||
def create_legal_agent() -> ChatAgent:
|
||||
"""Creates a legal/compliance domain expert agent."""
|
||||
return AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
|
||||
instructions=(
|
||||
"You're a cautious legal/compliance reviewer. Highlight constraints, disclaimers, and policy concerns"
|
||||
" based on the prompt."
|
||||
),
|
||||
name="legal",
|
||||
)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# 1) Build a simple fan out and fan in workflow
|
||||
# 1) Create executor and agent instances
|
||||
dispatcher = DispatchToExperts(id="dispatcher")
|
||||
aggregator = AggregateInsights(id="aggregator")
|
||||
|
||||
researcher = AgentExecutor(
|
||||
AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
|
||||
instructions=(
|
||||
"You're an expert market and product researcher. Given a prompt, provide concise, factual insights,"
|
||||
" opportunities, and risks."
|
||||
),
|
||||
name="researcher",
|
||||
)
|
||||
)
|
||||
marketer = AgentExecutor(
|
||||
AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
|
||||
instructions=(
|
||||
"You're a creative marketing strategist. Craft compelling value propositions and target messaging"
|
||||
" aligned to the prompt."
|
||||
),
|
||||
name="marketer",
|
||||
)
|
||||
)
|
||||
legal = AgentExecutor(
|
||||
AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
|
||||
instructions=(
|
||||
"You're a cautious legal/compliance reviewer. Highlight constraints, disclaimers, and policy concerns"
|
||||
" based on the prompt."
|
||||
),
|
||||
name="legal",
|
||||
)
|
||||
)
|
||||
|
||||
# 2) Build a simple fan out and fan in workflow
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="dispatcher")
|
||||
.register_agent(create_researcher_agent, name="researcher")
|
||||
.register_agent(create_marketer_agent, name="marketer")
|
||||
.register_agent(create_legal_agent, name="legal")
|
||||
.register_executor(lambda: DispatchToExperts(id="dispatcher"), name="dispatcher")
|
||||
.register_executor(lambda: AggregateInsights(id="aggregator"), name="aggregator")
|
||||
.add_fan_out_edges("dispatcher", ["researcher", "marketer", "legal"]) # Parallel branches
|
||||
.add_fan_in_edges(["researcher", "marketer", "legal"], "aggregator") # Join at the aggregator
|
||||
WorkflowBuilder(start_executor=dispatcher)
|
||||
.add_fan_out_edges(dispatcher, [researcher, marketer, legal]) # Parallel branches
|
||||
.add_fan_in_edges([researcher, marketer, legal], aggregator) # Join at the aggregator
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
+21
-39
@@ -257,49 +257,31 @@ class CompletionExecutor(Executor):
|
||||
async def main():
|
||||
"""Construct the map reduce workflow, visualize it, then run it over a sample file."""
|
||||
|
||||
# Step 1: Create the workflow builder and register executors.
|
||||
workflow_builder = (
|
||||
WorkflowBuilder(start_executor="split_data_executor")
|
||||
.register_executor(lambda: Map(id="map_executor_0"), name="map_executor_0")
|
||||
.register_executor(lambda: Map(id="map_executor_1"), name="map_executor_1")
|
||||
.register_executor(lambda: Map(id="map_executor_2"), name="map_executor_2")
|
||||
.register_executor(
|
||||
lambda: Split(["map_executor_0", "map_executor_1", "map_executor_2"], id="split_data_executor"),
|
||||
name="split_data_executor",
|
||||
)
|
||||
.register_executor(lambda: Reduce(id="reduce_executor_0"), name="reduce_executor_0")
|
||||
.register_executor(lambda: Reduce(id="reduce_executor_1"), name="reduce_executor_1")
|
||||
.register_executor(lambda: Reduce(id="reduce_executor_2"), name="reduce_executor_2")
|
||||
.register_executor(lambda: Reduce(id="reduce_executor_3"), name="reduce_executor_3")
|
||||
.register_executor(
|
||||
lambda: Shuffle(
|
||||
["reduce_executor_0", "reduce_executor_1", "reduce_executor_2", "reduce_executor_3"],
|
||||
id="shuffle_executor",
|
||||
),
|
||||
name="shuffle_executor",
|
||||
)
|
||||
.register_executor(lambda: CompletionExecutor(id="completion_executor"), name="completion_executor")
|
||||
# Step 1: Create executor instances.
|
||||
map_executor_0 = Map(id="map_executor_0")
|
||||
map_executor_1 = Map(id="map_executor_1")
|
||||
map_executor_2 = Map(id="map_executor_2")
|
||||
split_data_executor = Split(["map_executor_0", "map_executor_1", "map_executor_2"], id="split_data_executor")
|
||||
reduce_executor_0 = Reduce(id="reduce_executor_0")
|
||||
reduce_executor_1 = Reduce(id="reduce_executor_1")
|
||||
reduce_executor_2 = Reduce(id="reduce_executor_2")
|
||||
reduce_executor_3 = Reduce(id="reduce_executor_3")
|
||||
shuffle_executor = Shuffle(
|
||||
["reduce_executor_0", "reduce_executor_1", "reduce_executor_2", "reduce_executor_3"],
|
||||
id="shuffle_executor",
|
||||
)
|
||||
completion_executor = CompletionExecutor(id="completion_executor")
|
||||
|
||||
mappers = [map_executor_0, map_executor_1, map_executor_2]
|
||||
reducers = [reduce_executor_0, reduce_executor_1, reduce_executor_2, reduce_executor_3]
|
||||
|
||||
# Step 2: Build the workflow graph using fan out and fan in edges.
|
||||
workflow = (
|
||||
workflow_builder
|
||||
.add_fan_out_edges(
|
||||
"split_data_executor",
|
||||
["map_executor_0", "map_executor_1", "map_executor_2"],
|
||||
) # Split -> many mappers
|
||||
.add_fan_in_edges(
|
||||
["map_executor_0", "map_executor_1", "map_executor_2"],
|
||||
"shuffle_executor",
|
||||
) # All mappers -> shuffle
|
||||
.add_fan_out_edges(
|
||||
"shuffle_executor",
|
||||
["reduce_executor_0", "reduce_executor_1", "reduce_executor_2", "reduce_executor_3"],
|
||||
) # Shuffle -> many reducers
|
||||
.add_fan_in_edges(
|
||||
["reduce_executor_0", "reduce_executor_1", "reduce_executor_2", "reduce_executor_3"],
|
||||
"completion_executor",
|
||||
) # All reducers -> completion
|
||||
WorkflowBuilder(start_executor=split_data_executor)
|
||||
.add_fan_out_edges(split_data_executor, mappers) # Split -> many mappers
|
||||
.add_fan_in_edges(mappers, shuffle_executor) # All mappers -> shuffle
|
||||
.add_fan_out_edges(shuffle_executor, reducers) # Shuffle -> many reducers
|
||||
.add_fan_in_edges(reducers, completion_executor) # All reducers -> completion
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
@@ -188,21 +188,17 @@ async def main() -> None:
|
||||
# store_email -> spam_detection_agent -> to_detection_result -> branch:
|
||||
# False -> submit_to_email_assistant -> email_assistant_agent -> finalize_and_send
|
||||
# True -> handle_spam
|
||||
spam_detection_agent = create_spam_detection_agent()
|
||||
email_assistant_agent = create_email_assistant_agent()
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="store_email")
|
||||
.register_agent(create_spam_detection_agent, name="spam_detection_agent")
|
||||
.register_agent(create_email_assistant_agent, name="email_assistant_agent")
|
||||
.register_executor(lambda: store_email, name="store_email")
|
||||
.register_executor(lambda: to_detection_result, name="to_detection_result")
|
||||
.register_executor(lambda: submit_to_email_assistant, name="submit_to_email_assistant")
|
||||
.register_executor(lambda: finalize_and_send, name="finalize_and_send")
|
||||
.register_executor(lambda: handle_spam, name="handle_spam")
|
||||
.add_edge("store_email", "spam_detection_agent")
|
||||
.add_edge("spam_detection_agent", "to_detection_result")
|
||||
.add_edge("to_detection_result", "submit_to_email_assistant", condition=get_condition(False))
|
||||
.add_edge("to_detection_result", "handle_spam", condition=get_condition(True))
|
||||
.add_edge("submit_to_email_assistant", "email_assistant_agent")
|
||||
.add_edge("email_assistant_agent", "finalize_and_send")
|
||||
WorkflowBuilder(start_executor=store_email)
|
||||
.add_edge(store_email, spam_detection_agent)
|
||||
.add_edge(spam_detection_agent, to_detection_result)
|
||||
.add_edge(to_detection_result, submit_to_email_assistant, condition=get_condition(False))
|
||||
.add_edge(to_detection_result, handle_spam, condition=get_condition(True))
|
||||
.add_edge(submit_to_email_assistant, email_assistant_agent)
|
||||
.add_edge(email_assistant_agent, finalize_and_send)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
+39
-42
@@ -4,9 +4,9 @@ import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agent_framework import (
|
||||
AgentExecutor,
|
||||
AgentExecutorRequest,
|
||||
AgentExecutorResponse,
|
||||
ChatAgent,
|
||||
ChatMessage,
|
||||
Executor,
|
||||
WorkflowBuilder,
|
||||
@@ -85,52 +85,49 @@ class AggregateInsights(Executor):
|
||||
await ctx.yield_output(consolidated)
|
||||
|
||||
|
||||
def create_researcher_agent() -> ChatAgent:
|
||||
"""Creates a research domain expert agent."""
|
||||
return AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
|
||||
instructions=(
|
||||
"You're an expert market and product researcher. Given a prompt, provide concise, factual insights,"
|
||||
" opportunities, and risks."
|
||||
),
|
||||
name="researcher",
|
||||
)
|
||||
|
||||
|
||||
def create_marketer_agent() -> ChatAgent:
|
||||
"""Creates a marketing domain expert agent."""
|
||||
return AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
|
||||
instructions=(
|
||||
"You're a creative marketing strategist. Craft compelling value propositions and target messaging"
|
||||
" aligned to the prompt."
|
||||
),
|
||||
name="marketer",
|
||||
)
|
||||
|
||||
|
||||
def create_legal_agent() -> ChatAgent:
|
||||
"""Creates a legal domain expert agent."""
|
||||
return AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
|
||||
instructions=(
|
||||
"You're a cautious legal/compliance reviewer. Highlight constraints, disclaimers, and policy concerns"
|
||||
" based on the prompt."
|
||||
),
|
||||
name="legal",
|
||||
)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Build and run the concurrent workflow with visualization."""
|
||||
|
||||
# Create agent instances
|
||||
researcher = AgentExecutor(
|
||||
AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
|
||||
instructions=(
|
||||
"You're an expert market and product researcher. Given a prompt, provide concise, factual insights,"
|
||||
" opportunities, and risks."
|
||||
),
|
||||
name="researcher",
|
||||
)
|
||||
)
|
||||
|
||||
marketer = AgentExecutor(
|
||||
AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
|
||||
instructions=(
|
||||
"You're a creative marketing strategist. Craft compelling value propositions and target messaging"
|
||||
" aligned to the prompt."
|
||||
),
|
||||
name="marketer",
|
||||
)
|
||||
)
|
||||
|
||||
legal = AgentExecutor(
|
||||
AzureOpenAIChatClient(credential=AzureCliCredential()).as_agent(
|
||||
instructions=(
|
||||
"You're a cautious legal/compliance reviewer. Highlight constraints, disclaimers, and policy concerns"
|
||||
" based on the prompt."
|
||||
),
|
||||
name="legal",
|
||||
)
|
||||
)
|
||||
|
||||
# Create executor instances
|
||||
dispatcher = DispatchToExperts(id="dispatcher")
|
||||
aggregator = AggregateInsights(id="aggregator")
|
||||
|
||||
# Build a simple fan-out/fan-in workflow
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor="dispatcher")
|
||||
.register_agent(create_researcher_agent, name="researcher")
|
||||
.register_agent(create_marketer_agent, name="marketer")
|
||||
.register_agent(create_legal_agent, name="legal")
|
||||
.register_executor(lambda: DispatchToExperts(id="dispatcher"), name="dispatcher")
|
||||
.register_executor(lambda: AggregateInsights(id="aggregator"), name="aggregator")
|
||||
.add_fan_out_edges("dispatcher", ["researcher", "marketer", "legal"])
|
||||
.add_fan_in_edges(["researcher", "marketer", "legal"], "aggregator")
|
||||
WorkflowBuilder(start_executor=dispatcher)
|
||||
.add_fan_out_edges(dispatcher, [researcher, marketer, legal])
|
||||
.add_fan_in_edges([researcher, marketer, legal], aggregator)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user