mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add factory pattern to concurrent orchestration builder (#2738)
* Add factory pattern to concurrent orchestration builder * Update readme * Address AI comments * Fix unit tests * Fix import * Prevent multiple calls to set participants or factories * Add comments * Mitigate warnings * Fix mypy * Address comments * Address Copilot comments * Fix tests
This commit is contained in:
committed by
GitHub
Unverified
parent
638fbb5f03
commit
191779ce80
@@ -3,6 +3,7 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
|
||||
@@ -189,8 +190,11 @@ class ConcurrentBuilder:
|
||||
r"""High-level builder for concurrent agent workflows.
|
||||
|
||||
- `participants([...])` accepts a list of AgentProtocol (recommended) or Executor.
|
||||
- `register_participants([...])` accepts a list of factories for AgentProtocol (recommended)
|
||||
or Executor factories
|
||||
- `build()` wires: dispatcher -> fan-out -> participants -> fan-in -> aggregator.
|
||||
- `with_custom_aggregator(...)` overrides the default aggregator with an Executor or callback.
|
||||
- `with_aggregator(...)` overrides the default aggregator with an Executor or callback.
|
||||
- `register_aggregator(...)` accepts a factory for an Executor as custom aggregator.
|
||||
|
||||
Usage:
|
||||
|
||||
@@ -201,14 +205,33 @@ class ConcurrentBuilder:
|
||||
# Minimal: use default aggregator (returns list[ChatMessage])
|
||||
workflow = ConcurrentBuilder().participants([agent1, agent2, agent3]).build()
|
||||
|
||||
# With agent factories
|
||||
workflow = ConcurrentBuilder().register_participants([create_agent1, create_agent2, create_agent3]).build()
|
||||
|
||||
|
||||
# Custom aggregator via callback (sync or async). The callback receives
|
||||
# list[AgentExecutorResponse] and its return value becomes the workflow's output.
|
||||
def summarize(results):
|
||||
def summarize(results: list[AgentExecutorResponse]) -> str:
|
||||
return " | ".join(r.agent_run_response.messages[-1].text for r in results)
|
||||
|
||||
|
||||
workflow = ConcurrentBuilder().participants([agent1, agent2, agent3]).with_custom_aggregator(summarize).build()
|
||||
workflow = ConcurrentBuilder().participants([agent1, agent2, agent3]).with_aggregator(summarize).build()
|
||||
|
||||
|
||||
# Custom aggregator via a factory
|
||||
class MyAggregator(Executor):
|
||||
@handler
|
||||
async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None:
|
||||
await ctx.yield_output(" | ".join(r.agent_run_response.messages[-1].text for r in results))
|
||||
|
||||
|
||||
workflow = (
|
||||
ConcurrentBuilder()
|
||||
.register_participants([create_agent1, create_agent2, create_agent3])
|
||||
.register_aggregator(lambda: MyAggregator(id="my_aggregator"))
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
# Enable checkpoint persistence so runs can resume
|
||||
workflow = ConcurrentBuilder().participants([agent1, agent2, agent3]).with_checkpointing(storage).build()
|
||||
@@ -219,10 +242,67 @@ class ConcurrentBuilder:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._participants: list[AgentProtocol | Executor] = []
|
||||
self._participant_factories: list[Callable[[], AgentProtocol | Executor]] = []
|
||||
self._aggregator: Executor | None = None
|
||||
self._aggregator_factory: Callable[[], Executor] | None = None
|
||||
self._checkpoint_storage: CheckpointStorage | None = None
|
||||
self._request_info_enabled: bool = False
|
||||
|
||||
def register_participants(
|
||||
self,
|
||||
participant_factories: Sequence[Callable[[], AgentProtocol | Executor]],
|
||||
) -> "ConcurrentBuilder":
|
||||
r"""Define the parallel participants for this concurrent workflow.
|
||||
|
||||
Accepts factories (callables) that return AgentProtocol instances (e.g., created
|
||||
by a chat client) or Executor instances. Each participant created by a factory
|
||||
is wired as a parallel branch using fan-out edges from an internal dispatcher.
|
||||
|
||||
Args:
|
||||
participant_factories: Sequence of callables returning AgentProtocol or Executor instances
|
||||
|
||||
Raises:
|
||||
ValueError: if `participant_factories` is empty or `.participants()`
|
||||
or `.register_participants()` were already called
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def create_researcher() -> ChatAgent:
|
||||
return ...
|
||||
|
||||
|
||||
def create_marketer() -> ChatAgent:
|
||||
return ...
|
||||
|
||||
|
||||
def create_legal() -> ChatAgent:
|
||||
return ...
|
||||
|
||||
|
||||
class MyCustomExecutor(Executor): ...
|
||||
|
||||
|
||||
wf = ConcurrentBuilder().register_participants([create_researcher, create_marketer, create_legal]).build()
|
||||
|
||||
# Mixing agent(s) and executor(s) is supported
|
||||
wf2 = ConcurrentBuilder().register_participants([create_researcher, MyCustomExecutor]).build()
|
||||
"""
|
||||
if self._participants:
|
||||
raise ValueError(
|
||||
"Cannot mix .participants([...]) and .register_participants() in the same builder instance."
|
||||
)
|
||||
|
||||
if self._participant_factories:
|
||||
raise ValueError("register_participants() has already been called on this builder instance.")
|
||||
|
||||
if not participant_factories:
|
||||
raise ValueError("participant_factories cannot be empty")
|
||||
|
||||
self._participant_factories = list(participant_factories)
|
||||
return self
|
||||
|
||||
def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "ConcurrentBuilder":
|
||||
r"""Define the parallel participants for this concurrent workflow.
|
||||
|
||||
@@ -230,8 +310,12 @@ class ConcurrentBuilder:
|
||||
instances. Each participant is wired as a parallel branch using fan-out edges
|
||||
from an internal dispatcher.
|
||||
|
||||
Args:
|
||||
participants: Sequence of AgentProtocol or Executor instances
|
||||
|
||||
Raises:
|
||||
ValueError: if `participants` is empty or contains duplicates
|
||||
ValueError: if `participants` is empty, contains duplicates, or `.register_participants()`
|
||||
or `.participants()` were already called
|
||||
TypeError: if any entry is not AgentProtocol or Executor
|
||||
|
||||
Example:
|
||||
@@ -243,6 +327,14 @@ class ConcurrentBuilder:
|
||||
# Mixing agent(s) and executor(s) is supported
|
||||
wf2 = ConcurrentBuilder().participants([researcher_agent, my_custom_executor]).build()
|
||||
"""
|
||||
if self._participant_factories:
|
||||
raise ValueError(
|
||||
"Cannot mix .participants([...]) and .register_participants() in the same builder instance."
|
||||
)
|
||||
|
||||
if self._participants:
|
||||
raise ValueError("participants() has already been called on this builder instance.")
|
||||
|
||||
if not participants:
|
||||
raise ValueError("participants cannot be empty")
|
||||
|
||||
@@ -265,38 +357,107 @@ class ConcurrentBuilder:
|
||||
self._participants = list(participants)
|
||||
return self
|
||||
|
||||
def with_aggregator(self, aggregator: Executor | Callable[..., Any]) -> "ConcurrentBuilder":
|
||||
r"""Override the default aggregator with an Executor or a callback.
|
||||
def register_aggregator(self, aggregator_factory: Callable[[], Executor]) -> "ConcurrentBuilder":
|
||||
r"""Define a custom aggregator for this concurrent workflow.
|
||||
|
||||
- Executor: must handle `list[AgentExecutorResponse]` and
|
||||
yield output using `ctx.yield_output(...)` and add a
|
||||
output and the workflow becomes idle.
|
||||
Accepts a factory (callable) that returns an Executor instance. The executor
|
||||
should handle `list[AgentExecutorResponse]` and yield output using `ctx.yield_output(...)`.
|
||||
|
||||
Args:
|
||||
aggregator_factory: Callable that returns an Executor instance
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
class MyCustomExecutor(Executor): ...
|
||||
|
||||
|
||||
wf = (
|
||||
ConcurrentBuilder()
|
||||
.register_participants([create_researcher, create_marketer, create_legal])
|
||||
.register_aggregator(lambda: MyCustomExecutor(id="my_aggregator"))
|
||||
.build()
|
||||
)
|
||||
"""
|
||||
if self._aggregator is not None:
|
||||
raise ValueError(
|
||||
"Cannot mix .with_aggregator(...) and .register_aggregator(...) in the same builder instance."
|
||||
)
|
||||
|
||||
if self._aggregator_factory is not None:
|
||||
raise ValueError("register_aggregator() has already been called on this builder instance.")
|
||||
|
||||
self._aggregator_factory = aggregator_factory
|
||||
return self
|
||||
|
||||
def with_aggregator(
|
||||
self,
|
||||
aggregator: Executor
|
||||
| Callable[[list[AgentExecutorResponse]], Any]
|
||||
| Callable[[list[AgentExecutorResponse], WorkflowContext[Never, Any]], Any],
|
||||
) -> "ConcurrentBuilder":
|
||||
r"""Override the default aggregator with an executor, an executor factory, or a callback.
|
||||
|
||||
- Executor: must handle `list[AgentExecutorResponse]` and yield output using `ctx.yield_output(...)`
|
||||
- Callback: sync or async callable with one of the signatures:
|
||||
`(results: list[AgentExecutorResponse]) -> Any | None` or
|
||||
`(results: list[AgentExecutorResponse], ctx: WorkflowContext) -> Any | None`.
|
||||
If the callback returns a non-None value, it becomes the workflow's output.
|
||||
|
||||
Args:
|
||||
aggregator: Executor instance, or callback function
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
# Executor-based aggregator
|
||||
class CustomAggregator(Executor):
|
||||
@handler
|
||||
async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext) -> None:
|
||||
await ctx.yield_output(" | ".join(r.agent_run_response.messages[-1].text for r in results))
|
||||
|
||||
|
||||
wf = ConcurrentBuilder().participants([a1, a2, a3]).with_aggregator(CustomAggregator()).build()
|
||||
|
||||
|
||||
# Callback-based aggregator (string result)
|
||||
async def summarize(results):
|
||||
async def summarize(results: list[AgentExecutorResponse]) -> str:
|
||||
return " | ".join(r.agent_run_response.messages[-1].text for r in results)
|
||||
|
||||
|
||||
wf = ConcurrentBuilder().participants([a1, a2, a3]).with_custom_aggregator(summarize).build()
|
||||
wf = ConcurrentBuilder().participants([a1, a2, a3]).with_aggregator(summarize).build()
|
||||
|
||||
|
||||
# Callback-based aggregator (yield result)
|
||||
async def summarize(results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None:
|
||||
await ctx.yield_output(" | ".join(r.agent_run_response.messages[-1].text for r in results))
|
||||
|
||||
|
||||
wf = ConcurrentBuilder().participants([a1, a2, a3]).with_aggregator(summarize).build()
|
||||
"""
|
||||
if self._aggregator_factory is not None:
|
||||
raise ValueError(
|
||||
"Cannot mix .with_aggregator(...) and .register_aggregator(...) in the same builder instance."
|
||||
)
|
||||
|
||||
if self._aggregator is not None:
|
||||
raise ValueError("with_aggregator() has already been called on this builder instance.")
|
||||
|
||||
if isinstance(aggregator, Executor):
|
||||
self._aggregator = aggregator
|
||||
elif callable(aggregator):
|
||||
self._aggregator = _CallbackAggregator(aggregator)
|
||||
else:
|
||||
raise TypeError("aggregator must be an Executor or a callable")
|
||||
|
||||
return self
|
||||
|
||||
def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> "ConcurrentBuilder":
|
||||
"""Enable checkpoint persistence using the provided storage backend."""
|
||||
"""Enable checkpoint persistence using the provided storage backend.
|
||||
|
||||
Args:
|
||||
checkpoint_storage: CheckpointStorage instance for persisting workflow state
|
||||
"""
|
||||
self._checkpoint_storage = checkpoint_storage
|
||||
return self
|
||||
|
||||
@@ -329,7 +490,7 @@ class ConcurrentBuilder:
|
||||
before sending the outputs to the aggregator
|
||||
- Aggregator yields output and the workflow becomes idle. The output is either:
|
||||
- list[ChatMessage] (default aggregator: one user + one assistant per agent)
|
||||
- custom payload from the provided callback/executor
|
||||
- custom payload from the provided aggregator
|
||||
|
||||
Returns:
|
||||
Workflow: a ready-to-run workflow instance
|
||||
@@ -343,26 +504,69 @@ class ConcurrentBuilder:
|
||||
|
||||
workflow = ConcurrentBuilder().participants([agent1, agent2]).build()
|
||||
"""
|
||||
if not self._participants:
|
||||
raise ValueError("No participants provided. Call .participants([...]) first.")
|
||||
if not self._participants and not self._participant_factories:
|
||||
raise ValueError(
|
||||
"No participants provided. Call .participants([...]) or .register_participants([...]) first."
|
||||
)
|
||||
|
||||
# Internal nodes
|
||||
dispatcher = _DispatchToAllParticipants(id="dispatcher")
|
||||
aggregator = self._aggregator or _AggregateAgentConversations(id="aggregator")
|
||||
aggregator = (
|
||||
self._aggregator
|
||||
if self._aggregator is not None
|
||||
else (
|
||||
self._aggregator_factory()
|
||||
if self._aggregator_factory is not None
|
||||
else _AggregateAgentConversations(id="aggregator")
|
||||
)
|
||||
)
|
||||
|
||||
builder = WorkflowBuilder()
|
||||
builder.set_start_executor(dispatcher)
|
||||
builder.add_fan_out_edges(dispatcher, list(self._participants))
|
||||
if self._participant_factories:
|
||||
# Register executors/agents to avoid warnings from the workflow builder
|
||||
# if factories are provided instead of direct instances. This doesn't
|
||||
# break the factory pattern since the concurrent builder still creates
|
||||
# new instances per workflow build.
|
||||
factory_names: list[str] = []
|
||||
for factory in self._participant_factories:
|
||||
factory_name = uuid.uuid4().hex
|
||||
factory_names.append(factory_name)
|
||||
instance = factory()
|
||||
if isinstance(instance, Executor):
|
||||
builder.register_executor(lambda executor=instance: executor, name=factory_name) # type: ignore[misc]
|
||||
else:
|
||||
builder.register_agent(lambda agent=instance: agent, name=factory_name) # type: ignore[misc]
|
||||
# Register the dispatcher and the aggregator
|
||||
builder.register_executor(lambda: dispatcher, name="dispatcher")
|
||||
builder.register_executor(lambda: aggregator, name="aggregator")
|
||||
|
||||
if self._request_info_enabled:
|
||||
# Insert interceptor between fan-in and aggregator
|
||||
# participants -> fan-in -> interceptor -> aggregator
|
||||
request_info_interceptor = RequestInfoInterceptor(executor_id="request_info")
|
||||
builder.add_fan_in_edges(list(self._participants), request_info_interceptor)
|
||||
builder.add_edge(request_info_interceptor, aggregator)
|
||||
builder.set_start_executor("dispatcher")
|
||||
builder.add_fan_out_edges("dispatcher", factory_names)
|
||||
if self._request_info_enabled:
|
||||
# Insert interceptor between fan-in and aggregator
|
||||
# participants -> fan-in -> interceptor -> aggregator
|
||||
builder.register_executor(
|
||||
lambda: RequestInfoInterceptor(executor_id="request_info"),
|
||||
name="request_info_interceptor",
|
||||
)
|
||||
builder.add_fan_in_edges(factory_names, "request_info_interceptor")
|
||||
builder.add_edge("request_info_interceptor", "aggregator")
|
||||
else:
|
||||
# Direct fan-in to aggregator
|
||||
builder.add_fan_in_edges(factory_names, "aggregator")
|
||||
else:
|
||||
# Direct fan-in to aggregator
|
||||
builder.add_fan_in_edges(list(self._participants), aggregator)
|
||||
builder.set_start_executor(dispatcher)
|
||||
builder.add_fan_out_edges(dispatcher, self._participants)
|
||||
|
||||
if self._request_info_enabled:
|
||||
# Insert interceptor between fan-in and aggregator
|
||||
# participants -> fan-in -> interceptor -> aggregator
|
||||
request_info_interceptor = RequestInfoInterceptor(executor_id="request_info")
|
||||
builder.add_fan_in_edges(self._participants, request_info_interceptor)
|
||||
builder.add_edge(request_info_interceptor, aggregator)
|
||||
else:
|
||||
# Direct fan-in to aggregator
|
||||
builder.add_fan_in_edges(self._participants, aggregator)
|
||||
if self._checkpoint_storage is not None:
|
||||
builder = builder.with_checkpointing(self._checkpoint_storage)
|
||||
|
||||
|
||||
@@ -374,7 +374,7 @@ class WorkflowBuilder:
|
||||
)
|
||||
"""
|
||||
if name in self._executor_registry:
|
||||
raise ValueError(f"An executor factory with the name '{name}' is already registered.")
|
||||
raise ValueError(f"An agent factory with the name '{name}' is already registered.")
|
||||
|
||||
def wrapped_factory() -> AgentExecutor:
|
||||
agent = factory_func()
|
||||
@@ -1148,21 +1148,29 @@ class WorkflowBuilder:
|
||||
if isinstance(self._start_executor, Executor):
|
||||
start_executor = self._start_executor
|
||||
|
||||
executors: dict[str, Executor] = {}
|
||||
# Maps registered factory names to created executor instances for edge resolution
|
||||
factory_name_to_instance: dict[str, Executor] = {}
|
||||
# Maps executor IDs to created executor instances to prevent duplicates
|
||||
executor_id_to_instance: dict[str, Executor] = {}
|
||||
deferred_edge_groups: list[EdgeGroup] = []
|
||||
for name, exec_factory in self._executor_registry.items():
|
||||
instance = exec_factory()
|
||||
if instance.id in executor_id_to_instance:
|
||||
raise ValueError(f"Executor with ID '{instance.id}' has already been created.")
|
||||
executor_id_to_instance[instance.id] = instance
|
||||
|
||||
if isinstance(self._start_executor, str) and name == self._start_executor:
|
||||
start_executor = instance
|
||||
|
||||
# All executors will get their own internal edge group for receiving system messages
|
||||
deferred_edge_groups.append(InternalEdgeGroup(instance.id)) # type: ignore[call-arg]
|
||||
executors[name] = instance
|
||||
factory_name_to_instance[name] = instance
|
||||
|
||||
def _get_executor(name: str) -> Executor:
|
||||
"""Helper to get executor by the registered name. Raises if not found."""
|
||||
if name not in executors:
|
||||
raise ValueError(f"Executor with name '{name}' has not been registered.")
|
||||
return executors[name]
|
||||
if name not in factory_name_to_instance:
|
||||
raise ValueError(f"Factory '{name}' has not been registered.")
|
||||
return factory_name_to_instance[name]
|
||||
|
||||
for registration in self._edge_registry:
|
||||
match registration:
|
||||
@@ -1179,7 +1187,7 @@ class WorkflowBuilder:
|
||||
cases_converted: list[SwitchCaseEdgeGroupCase | SwitchCaseEdgeGroupDefault] = []
|
||||
for case in cases:
|
||||
if not isinstance(case.target, str):
|
||||
raise ValueError("Switch case target must be a registered executor name (str) if deferred.")
|
||||
raise ValueError("Switch case target must be a registered factory name (str) if deferred.")
|
||||
target_exec = _get_executor(case.target)
|
||||
if isinstance(case, Default):
|
||||
cases_converted.append(SwitchCaseEdgeGroupDefault(target_id=target_exec.id))
|
||||
@@ -1201,7 +1209,7 @@ class WorkflowBuilder:
|
||||
if start_executor is None:
|
||||
raise ValueError("Failed to resolve starting executor from registered factories.")
|
||||
|
||||
return start_executor, list(executors.values()), deferred_edge_groups
|
||||
return start_executor, list(executor_id_to_instance.values()), deferred_edge_groups
|
||||
|
||||
def build(self) -> Workflow:
|
||||
"""Build and return the constructed workflow.
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
from typing_extensions import Never
|
||||
|
||||
from agent_framework import (
|
||||
AgentExecutorRequest,
|
||||
@@ -52,6 +53,55 @@ def test_concurrent_builder_rejects_duplicate_executors() -> None:
|
||||
ConcurrentBuilder().participants([a, b])
|
||||
|
||||
|
||||
def test_concurrent_builder_rejects_duplicate_executors_from_factories() -> None:
|
||||
"""Test that duplicate executor IDs from factories are detected at build time."""
|
||||
|
||||
def create_dup1() -> Executor:
|
||||
return _FakeAgentExec("dup", "A")
|
||||
|
||||
def create_dup2() -> Executor:
|
||||
return _FakeAgentExec("dup", "B") # same executor id
|
||||
|
||||
builder = ConcurrentBuilder().register_participants([create_dup1, create_dup2])
|
||||
with pytest.raises(ValueError, match="Executor with ID 'dup' has already been created."):
|
||||
builder.build()
|
||||
|
||||
|
||||
def test_concurrent_builder_rejects_mixed_participants_and_factories() -> None:
|
||||
"""Test that mixing .participants() and .register_participants() raises an error."""
|
||||
# Case 1: participants first, then register_participants
|
||||
with pytest.raises(ValueError, match="Cannot mix .participants"):
|
||||
(
|
||||
ConcurrentBuilder()
|
||||
.participants([_FakeAgentExec("a", "A")])
|
||||
.register_participants([lambda: _FakeAgentExec("b", "B")])
|
||||
)
|
||||
|
||||
# Case 2: register_participants first, then participants
|
||||
with pytest.raises(ValueError, match="Cannot mix .participants"):
|
||||
(
|
||||
ConcurrentBuilder()
|
||||
.register_participants([lambda: _FakeAgentExec("a", "A")])
|
||||
.participants([_FakeAgentExec("b", "B")])
|
||||
)
|
||||
|
||||
|
||||
def test_concurrent_builder_rejects_multiple_calls_to_participants() -> None:
|
||||
"""Test that multiple calls to .participants() raises an error."""
|
||||
with pytest.raises(ValueError, match=r"participants\(\) has already been called"):
|
||||
(ConcurrentBuilder().participants([_FakeAgentExec("a", "A")]).participants([_FakeAgentExec("b", "B")]))
|
||||
|
||||
|
||||
def test_concurrent_builder_rejects_multiple_calls_to_register_participants() -> None:
|
||||
"""Test that multiple calls to .register_participants() raises an error."""
|
||||
with pytest.raises(ValueError, match=r"register_participants\(\) has already been called"):
|
||||
(
|
||||
ConcurrentBuilder()
|
||||
.register_participants([lambda: _FakeAgentExec("a", "A")])
|
||||
.register_participants([lambda: _FakeAgentExec("b", "B")])
|
||||
)
|
||||
|
||||
|
||||
async def test_concurrent_default_aggregator_emits_single_user_and_assistants() -> None:
|
||||
# Three synthetic agent executors
|
||||
e1 = _FakeAgentExec("agentA", "Alpha")
|
||||
@@ -159,6 +209,138 @@ def test_concurrent_custom_aggregator_uses_callback_name_for_id() -> None:
|
||||
assert aggregator.id == "summarize"
|
||||
|
||||
|
||||
async def test_concurrent_with_aggregator_executor_instance() -> None:
|
||||
"""Test with_aggregator using an Executor instance (not factory)."""
|
||||
|
||||
class CustomAggregator(Executor):
|
||||
@handler
|
||||
async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None:
|
||||
texts: list[str] = []
|
||||
for r in results:
|
||||
msgs: list[ChatMessage] = r.agent_run_response.messages
|
||||
texts.append(msgs[-1].text if msgs else "")
|
||||
await ctx.yield_output(" & ".join(sorted(texts)))
|
||||
|
||||
e1 = _FakeAgentExec("agentA", "One")
|
||||
e2 = _FakeAgentExec("agentB", "Two")
|
||||
|
||||
aggregator_instance = CustomAggregator(id="instance_aggregator")
|
||||
wf = ConcurrentBuilder().participants([e1, e2]).with_aggregator(aggregator_instance).build()
|
||||
|
||||
completed = False
|
||||
output: str | None = None
|
||||
async for ev in wf.run_stream("prompt: instance test"):
|
||||
if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE:
|
||||
completed = True
|
||||
elif isinstance(ev, WorkflowOutputEvent):
|
||||
output = cast(str, ev.data)
|
||||
if completed and output is not None:
|
||||
break
|
||||
|
||||
assert completed
|
||||
assert output is not None
|
||||
assert isinstance(output, str)
|
||||
assert output == "One & Two"
|
||||
|
||||
|
||||
async def test_concurrent_with_aggregator_executor_factory() -> None:
|
||||
"""Test with_aggregator using an Executor factory."""
|
||||
|
||||
class CustomAggregator(Executor):
|
||||
@handler
|
||||
async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None:
|
||||
texts: list[str] = []
|
||||
for r in results:
|
||||
msgs: list[ChatMessage] = r.agent_run_response.messages
|
||||
texts.append(msgs[-1].text if msgs else "")
|
||||
await ctx.yield_output(" | ".join(sorted(texts)))
|
||||
|
||||
e1 = _FakeAgentExec("agentA", "One")
|
||||
e2 = _FakeAgentExec("agentB", "Two")
|
||||
|
||||
wf = (
|
||||
ConcurrentBuilder()
|
||||
.participants([e1, e2])
|
||||
.register_aggregator(lambda: CustomAggregator(id="custom_aggregator"))
|
||||
.build()
|
||||
)
|
||||
|
||||
completed = False
|
||||
output: str | None = None
|
||||
async for ev in wf.run_stream("prompt: factory test"):
|
||||
if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE:
|
||||
completed = True
|
||||
elif isinstance(ev, WorkflowOutputEvent):
|
||||
output = cast(str, ev.data)
|
||||
if completed and output is not None:
|
||||
break
|
||||
|
||||
assert completed
|
||||
assert output is not None
|
||||
assert isinstance(output, str)
|
||||
assert output == "One | Two"
|
||||
|
||||
|
||||
async def test_concurrent_with_aggregator_executor_factory_with_default_id() -> None:
|
||||
"""Test with_aggregator using an Executor class directly as factory (with default __init__ parameters)."""
|
||||
|
||||
class CustomAggregator(Executor):
|
||||
def __init__(self, id: str = "default_aggregator") -> None:
|
||||
super().__init__(id)
|
||||
|
||||
@handler
|
||||
async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None:
|
||||
texts: list[str] = []
|
||||
for r in results:
|
||||
msgs: list[ChatMessage] = r.agent_run_response.messages
|
||||
texts.append(msgs[-1].text if msgs else "")
|
||||
await ctx.yield_output(" | ".join(sorted(texts)))
|
||||
|
||||
e1 = _FakeAgentExec("agentA", "One")
|
||||
e2 = _FakeAgentExec("agentB", "Two")
|
||||
|
||||
wf = ConcurrentBuilder().participants([e1, e2]).register_aggregator(CustomAggregator).build()
|
||||
|
||||
completed = False
|
||||
output: str | None = None
|
||||
async for ev in wf.run_stream("prompt: factory test"):
|
||||
if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE:
|
||||
completed = True
|
||||
elif isinstance(ev, WorkflowOutputEvent):
|
||||
output = cast(str, ev.data)
|
||||
if completed and output is not None:
|
||||
break
|
||||
|
||||
assert completed
|
||||
assert output is not None
|
||||
assert isinstance(output, str)
|
||||
assert output == "One | Two"
|
||||
|
||||
|
||||
def test_concurrent_builder_rejects_multiple_calls_to_with_aggregator() -> None:
|
||||
"""Test that multiple calls to .with_aggregator() raises an error."""
|
||||
|
||||
def summarize(results: list[AgentExecutorResponse]) -> str: # type: ignore[override]
|
||||
return str(len(results))
|
||||
|
||||
with pytest.raises(ValueError, match=r"with_aggregator\(\) has already been called"):
|
||||
(ConcurrentBuilder().with_aggregator(summarize).with_aggregator(summarize))
|
||||
|
||||
|
||||
def test_concurrent_builder_rejects_multiple_calls_to_register_aggregator() -> None:
|
||||
"""Test that multiple calls to .register_aggregator() raises an error."""
|
||||
|
||||
class CustomAggregator(Executor):
|
||||
pass
|
||||
|
||||
with pytest.raises(ValueError, match=r"register_aggregator\(\) has already been called"):
|
||||
(
|
||||
ConcurrentBuilder()
|
||||
.register_aggregator(lambda: CustomAggregator(id="agg1"))
|
||||
.register_aggregator(lambda: CustomAggregator(id="agg2"))
|
||||
)
|
||||
|
||||
|
||||
async def test_concurrent_checkpoint_resume_round_trip() -> None:
|
||||
storage = InMemoryCheckpointStorage()
|
||||
|
||||
@@ -278,3 +460,92 @@ async def test_concurrent_checkpoint_runtime_overrides_buildtime() -> None:
|
||||
|
||||
assert len(runtime_checkpoints) > 0, "Runtime storage should have checkpoints"
|
||||
assert len(buildtime_checkpoints) == 0, "Build-time storage should have no checkpoints when overridden"
|
||||
|
||||
|
||||
def test_concurrent_builder_rejects_empty_participant_factories() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
ConcurrentBuilder().register_participants([])
|
||||
|
||||
|
||||
async def test_concurrent_builder_reusable_after_build_with_participants() -> None:
|
||||
"""Test that the builder can be reused to build multiple identical workflows with participants()."""
|
||||
e1 = _FakeAgentExec("agentA", "One")
|
||||
e2 = _FakeAgentExec("agentB", "Two")
|
||||
|
||||
builder = ConcurrentBuilder().participants([e1, e2])
|
||||
|
||||
builder.build()
|
||||
|
||||
assert builder._participants[0] is e1 # type: ignore
|
||||
assert builder._participants[1] is e2 # type: ignore
|
||||
assert builder._participant_factories == [] # type: ignore
|
||||
|
||||
|
||||
async def test_concurrent_builder_reusable_after_build_with_factories() -> None:
|
||||
"""Test that the builder can be reused to build multiple workflows with register_participants()."""
|
||||
call_count = 0
|
||||
|
||||
def create_agent_executor_a() -> Executor:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return _FakeAgentExec("agentA", "One")
|
||||
|
||||
def create_agent_executor_b() -> Executor:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return _FakeAgentExec("agentB", "Two")
|
||||
|
||||
builder = ConcurrentBuilder().register_participants([create_agent_executor_a, create_agent_executor_b])
|
||||
|
||||
# Build the first workflow
|
||||
wf1 = builder.build()
|
||||
|
||||
assert builder._participants == [] # type: ignore
|
||||
assert len(builder._participant_factories) == 2 # type: ignore
|
||||
assert call_count == 2
|
||||
|
||||
# Build the second workflow
|
||||
wf2 = builder.build()
|
||||
assert call_count == 4
|
||||
|
||||
# Verify that the two workflows have different executor instances
|
||||
assert wf1.executors["agentA"] is not wf2.executors["agentA"]
|
||||
assert wf1.executors["agentB"] is not wf2.executors["agentB"]
|
||||
|
||||
|
||||
async def test_concurrent_with_register_participants() -> None:
|
||||
"""Test workflow creation using register_participants with factories."""
|
||||
|
||||
def create_agent1() -> Executor:
|
||||
return _FakeAgentExec("agentA", "Alpha")
|
||||
|
||||
def create_agent2() -> Executor:
|
||||
return _FakeAgentExec("agentB", "Beta")
|
||||
|
||||
def create_agent3() -> Executor:
|
||||
return _FakeAgentExec("agentC", "Gamma")
|
||||
|
||||
wf = ConcurrentBuilder().register_participants([create_agent1, create_agent2, create_agent3]).build()
|
||||
|
||||
completed = False
|
||||
output: list[ChatMessage] | None = None
|
||||
async for ev in wf.run_stream("test prompt"):
|
||||
if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE:
|
||||
completed = True
|
||||
elif isinstance(ev, WorkflowOutputEvent):
|
||||
output = cast(list[ChatMessage], ev.data)
|
||||
if completed and output is not None:
|
||||
break
|
||||
|
||||
assert completed
|
||||
assert output is not None
|
||||
messages: list[ChatMessage] = output
|
||||
|
||||
# Expect one user message + one assistant message per participant
|
||||
assert len(messages) == 1 + 3
|
||||
assert messages[0].role == Role.USER
|
||||
assert "test prompt" in messages[0].text
|
||||
|
||||
assistant_texts = {m.text for m in messages[1:]}
|
||||
assert assistant_texts == {"Alpha", "Beta", "Gamma"}
|
||||
assert all(m.role == Role.ASSISTANT for m in messages[1:])
|
||||
|
||||
@@ -293,6 +293,20 @@ def test_register_duplicate_name_raises_error():
|
||||
builder.register_executor(lambda: MockExecutor(id="executor_2"), name="MyExecutor")
|
||||
|
||||
|
||||
def test_register_duplicate_id_raises_error():
|
||||
"""Test that registering duplicate id raises an error."""
|
||||
builder = WorkflowBuilder()
|
||||
|
||||
# Register first executor
|
||||
builder.register_executor(lambda: MockExecutor(id="executor"), name="MyExecutor1")
|
||||
builder.register_executor(lambda: MockExecutor(id="executor"), name="MyExecutor2")
|
||||
builder.set_start_executor("MyExecutor1")
|
||||
|
||||
# Registering second executor with same ID should raise ValueError
|
||||
with pytest.raises(ValueError, match="Executor with ID 'executor' has already been created."):
|
||||
builder.build()
|
||||
|
||||
|
||||
def test_register_agent_basic():
|
||||
"""Test basic agent registration with lazy initialization."""
|
||||
builder = WorkflowBuilder()
|
||||
|
||||
@@ -110,6 +110,7 @@ For additional observability samples in Agent Framework, see the [observability
|
||||
| Concurrent Orchestration (Default Aggregator) | [orchestration/concurrent_agents.py](./orchestration/concurrent_agents.py) | Fan-out to multiple agents; fan-in with default aggregator returning combined ChatMessages |
|
||||
| Concurrent Orchestration (Custom Aggregator) | [orchestration/concurrent_custom_aggregator.py](./orchestration/concurrent_custom_aggregator.py) | Override aggregator via callback; summarize results with an LLM |
|
||||
| Concurrent Orchestration (Custom Agent Executors) | [orchestration/concurrent_custom_agent_executors.py](./orchestration/concurrent_custom_agent_executors.py) | Child executors own ChatAgents; concurrent fan-out/fan-in via ConcurrentBuilder |
|
||||
| Concurrent Orchestration (Participant Factory) | [orchestration/concurrent_participant_factory.py](./orchestration/concurrent_participant_factory.py) | Use participant factories for state isolation between workflow instances |
|
||||
| Group Chat with Agent Manager | [orchestration/group_chat_agent_manager.py](./orchestration/group_chat_agent_manager.py) | Agent-based manager using `set_manager()` to select next speaker |
|
||||
| Group Chat Philosophical Debate | [orchestration/group_chat_philosophical_debate.py](./orchestration/group_chat_philosophical_debate.py) | Agent manager moderates long-form, multi-round debate across diverse participants |
|
||||
| Group Chat with Simple Function Selector | [orchestration/group_chat_simple_selector.py](./orchestration/group_chat_simple_selector.py) | Group chat with a simple function selector for next speaker |
|
||||
|
||||
+1
-1
@@ -17,7 +17,7 @@ to synthesize a concise, consolidated summary from the experts' outputs.
|
||||
The workflow completes when all participants become idle.
|
||||
|
||||
Demonstrates:
|
||||
- ConcurrentBuilder().participants([...]).with_custom_aggregator(callback)
|
||||
- ConcurrentBuilder().participants([...]).with_aggregator(callback)
|
||||
- Fan-out to agents and fan-in at an aggregator
|
||||
- Aggregation implemented via an LLM call (chat_client.get_response)
|
||||
- Workflow output yielded with the synthesized summary string
|
||||
|
||||
+169
@@ -0,0 +1,169 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Never
|
||||
|
||||
from agent_framework import (
|
||||
ChatAgent,
|
||||
ChatMessage,
|
||||
ConcurrentBuilder,
|
||||
Executor,
|
||||
Role,
|
||||
Workflow,
|
||||
WorkflowContext,
|
||||
handler,
|
||||
)
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
from azure.identity import AzureCliCredential
|
||||
|
||||
"""
|
||||
Sample: Concurrent Orchestration with participant factories and Custom Aggregator
|
||||
|
||||
Build a concurrent workflow with ConcurrentBuilder that fans out one prompt to
|
||||
multiple domain agents and fans in their responses.
|
||||
|
||||
Override the default aggregator with a custom Executor class that uses
|
||||
AzureOpenAIChatClient.get_response() to synthesize a concise, consolidated summary
|
||||
from the experts' outputs.
|
||||
|
||||
All participants and the aggregator are created via factory functions that return
|
||||
their respective ChatAgent or Executor instances.
|
||||
|
||||
Using participant factories allows you to set up proper state isolation between workflow
|
||||
instances created by the same builder. This is particularly useful when you need to handle
|
||||
requests or tasks in parallel with stateful participants.
|
||||
|
||||
Demonstrates:
|
||||
- ConcurrentBuilder().register_participants([...]).with_aggregator(callback)
|
||||
- Fan-out to agents and fan-in at an aggregator
|
||||
- Aggregation implemented via an LLM call (chat_client.get_response)
|
||||
- Workflow output yielded with the synthesized summary string
|
||||
|
||||
Prerequisites:
|
||||
- Azure OpenAI configured for AzureOpenAIChatClient (az login + required env vars)
|
||||
"""
|
||||
|
||||
|
||||
def create_researcher() -> ChatAgent:
|
||||
"""Factory function to create a researcher agent instance."""
|
||||
return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent(
|
||||
instructions=(
|
||||
"You're an expert market and product researcher. Given a prompt, provide concise, factual insights,"
|
||||
" opportunities, and risks."
|
||||
),
|
||||
name="researcher",
|
||||
)
|
||||
|
||||
|
||||
def create_marketer() -> ChatAgent:
|
||||
"""Factory function to create a marketer agent instance."""
|
||||
return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent(
|
||||
instructions=(
|
||||
"You're a creative marketing strategist. Craft compelling value propositions and target messaging"
|
||||
" aligned to the prompt."
|
||||
),
|
||||
name="marketer",
|
||||
)
|
||||
|
||||
|
||||
def create_legal() -> ChatAgent:
|
||||
"""Factory function to create a legal/compliance agent instance."""
|
||||
return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent(
|
||||
instructions=(
|
||||
"You're a cautious legal/compliance reviewer. Highlight constraints, disclaimers, and policy concerns"
|
||||
" based on the prompt."
|
||||
),
|
||||
name="legal",
|
||||
)
|
||||
|
||||
|
||||
class SummarizationExecutor(Executor):
|
||||
"""Custom aggregator executor that synthesizes expert outputs into a concise summary."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(id="summarization_executor")
|
||||
self.chat_client = AzureOpenAIChatClient(credential=AzureCliCredential())
|
||||
|
||||
@handler
|
||||
async def summarize_results(self, results: list[Any], ctx: WorkflowContext[Never, str]) -> None:
|
||||
expert_sections: list[str] = []
|
||||
for r in results:
|
||||
try:
|
||||
messages = getattr(r.agent_run_response, "messages", [])
|
||||
final_text = messages[-1].text if messages and hasattr(messages[-1], "text") else "(no content)"
|
||||
expert_sections.append(f"{getattr(r, 'executor_id', 'expert')}:\n{final_text}")
|
||||
except Exception as e:
|
||||
expert_sections.append(f"{getattr(r, 'executor_id', 'expert')}: (error: {type(e).__name__}: {e})")
|
||||
|
||||
# Ask the model to synthesize a concise summary of the experts' outputs
|
||||
system_msg = ChatMessage(
|
||||
Role.SYSTEM,
|
||||
text=(
|
||||
"You are a helpful assistant that consolidates multiple domain expert outputs "
|
||||
"into one cohesive, concise summary with clear takeaways. Keep it under 200 words."
|
||||
),
|
||||
)
|
||||
user_msg = ChatMessage(Role.USER, text="\n\n".join(expert_sections))
|
||||
|
||||
response = await self.chat_client.get_response([system_msg, user_msg])
|
||||
|
||||
await ctx.yield_output(response.messages[-1].text if response.messages else "")
|
||||
|
||||
|
||||
async def run_workflow(workflow: Workflow, query: str) -> None:
|
||||
events = await workflow.run(query)
|
||||
outputs = events.get_outputs()
|
||||
|
||||
if outputs:
|
||||
print(outputs[0]) # Get the first (and typically only) output
|
||||
else:
|
||||
raise RuntimeError("No outputs received from the workflow.")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Create a concurrent builder with participant factories and a custom aggregator
|
||||
# - register_participants([...]) accepts factory functions that return
|
||||
# AgentProtocol (agents) or Executor instances.
|
||||
# - register_aggregator(...) takes a factory function that returns an Executor instance.
|
||||
concurrent_builder = (
|
||||
ConcurrentBuilder()
|
||||
.register_participants([create_researcher, create_marketer, create_legal])
|
||||
.register_aggregator(SummarizationExecutor)
|
||||
)
|
||||
|
||||
# Build workflow_a
|
||||
workflow_a = concurrent_builder.build()
|
||||
|
||||
# Run workflow_a
|
||||
# Context is maintained across runs
|
||||
print("=== First Run on workflow_a ===")
|
||||
await run_workflow(workflow_a, "We are launching a new budget-friendly electric bike for urban commuters.")
|
||||
print("\n=== Second Run on workflow_a ===")
|
||||
await run_workflow(workflow_a, "Refine your response to focus on the California market.")
|
||||
|
||||
# Build workflow_b
|
||||
# This will create new instances of all participants and the aggregator
|
||||
# The agents will also get new threads
|
||||
workflow_b = concurrent_builder.build()
|
||||
# Run workflow_b
|
||||
# Context is not maintained across instances
|
||||
# Should not expect mentions of electric bikes in the results
|
||||
print("\n=== First Run on workflow_b ===")
|
||||
await run_workflow(workflow_b, "Refine your response to focus on the California market.")
|
||||
|
||||
"""
|
||||
Sample Output:
|
||||
|
||||
=== First Run on workflow_a ===
|
||||
The budget-friendly electric bike market is poised for significant growth, driven by urbanization, ...
|
||||
|
||||
=== Second Run on workflow_a ===
|
||||
Launching a budget-friendly electric bike in California presents significant opportunities, driven ...
|
||||
|
||||
=== First Run on workflow_b ===
|
||||
To successfully penetrate the California market, consider these tailored strategies focused on ...
|
||||
"""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user