mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Introduce add_agent functionality and added output_response to AgentExecutor; agent streaming behavior to follow workflow invocation (#1184)
* refactor AgentExecutor, add output_response flag for switching on or off workflow output for each agent. * introduce add_agent * make default agent's streaming to false * address comments * fix test * add is_streaming to RunnerContext and WorkflowContext * fix add_agent return * fix tests * address comments * resolve conflict * update to address comments * fix
This commit is contained in:
committed by
GitHub
Unverified
parent
959e5842c2
commit
d01d9afb92
@@ -1,6 +1,11 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from ._agent import WorkflowAgent
|
||||
from ._agent_executor import (
|
||||
AgentExecutor,
|
||||
AgentExecutorRequest,
|
||||
AgentExecutorResponse,
|
||||
)
|
||||
from ._checkpoint import (
|
||||
CheckpointStorage,
|
||||
FileCheckpointStorage,
|
||||
@@ -42,9 +47,6 @@ from ._events import (
|
||||
WorkflowStatusEvent,
|
||||
)
|
||||
from ._executor import (
|
||||
AgentExecutor,
|
||||
AgentExecutorRequest,
|
||||
AgentExecutorResponse,
|
||||
Executor,
|
||||
RequestInfoExecutor,
|
||||
RequestInfoMessage,
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from ._agent import WorkflowAgent
|
||||
from ._agent_executor import (
|
||||
AgentExecutor,
|
||||
AgentExecutorRequest,
|
||||
AgentExecutorResponse,
|
||||
)
|
||||
from ._checkpoint import (
|
||||
CheckpointStorage,
|
||||
FileCheckpointStorage,
|
||||
@@ -40,9 +45,6 @@ from ._events import (
|
||||
WorkflowStatusEvent,
|
||||
)
|
||||
from ._executor import (
|
||||
AgentExecutor,
|
||||
AgentExecutorRequest,
|
||||
AgentExecutorResponse,
|
||||
Executor,
|
||||
RequestInfoExecutor,
|
||||
RequestInfoMessage,
|
||||
|
||||
@@ -0,0 +1,194 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from .._agents import AgentProtocol, ChatAgent
|
||||
from .._threads import AgentThread
|
||||
from .._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage
|
||||
from ._events import (
|
||||
AgentRunEvent,
|
||||
AgentRunUpdateEvent, # type: ignore[reportPrivateUsage]
|
||||
)
|
||||
from ._executor import Executor, handler
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentExecutorRequest:
|
||||
"""A request to an agent executor.
|
||||
|
||||
Attributes:
|
||||
messages: A list of chat messages to be processed by the agent.
|
||||
should_respond: A flag indicating whether the agent should respond to the messages.
|
||||
If False, the messages will be saved to the executor's cache but not sent to the agent.
|
||||
"""
|
||||
|
||||
messages: list[ChatMessage]
|
||||
should_respond: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentExecutorResponse:
|
||||
"""A response from an agent executor.
|
||||
|
||||
Attributes:
|
||||
executor_id: The ID of the executor that generated the response.
|
||||
agent_run_response: The underlying agent run response (unaltered from client).
|
||||
full_conversation: The full conversation context (prior inputs + all assistant/tool outputs) that
|
||||
should be used when chaining to another AgentExecutor. This prevents downstream agents losing
|
||||
user prompts while keeping the emitted AgentRunEvent text faithful to the raw agent output.
|
||||
"""
|
||||
|
||||
executor_id: str
|
||||
agent_run_response: AgentRunResponse
|
||||
full_conversation: list[ChatMessage] | None = None
|
||||
|
||||
|
||||
class AgentExecutor(Executor):
|
||||
"""built-in executor that wraps an agent for handling messages.
|
||||
|
||||
AgentExecutor adapts its behavior based on the workflow execution mode:
|
||||
- run_stream(): Emits incremental AgentRunUpdateEvent events as the agent produces tokens
|
||||
- run(): Emits a single AgentRunEvent containing the complete response
|
||||
|
||||
The executor automatically detects the mode via WorkflowContext.is_streaming().
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent: AgentProtocol,
|
||||
*,
|
||||
agent_thread: AgentThread | None = None,
|
||||
output_response: bool = False,
|
||||
id: str | None = None,
|
||||
):
|
||||
"""Initialize the executor with a unique identifier.
|
||||
|
||||
Args:
|
||||
agent: The agent to be wrapped by this executor.
|
||||
agent_thread: The thread to use for running the agent. If None, a new thread will be created.
|
||||
output_response: Whether to yield an AgentRunResponse as a workflow output when the agent completes.
|
||||
id: A unique identifier for the executor. If None, the agent's name will be used if available.
|
||||
"""
|
||||
# Prefer provided id; else use agent.name if present; else generate deterministic prefix
|
||||
exec_id = id or agent.name
|
||||
if not exec_id:
|
||||
raise ValueError("Agent must have a name or an explicit id must be provided.")
|
||||
super().__init__(exec_id)
|
||||
self._agent = agent
|
||||
self._agent_thread = agent_thread or self._agent.get_new_thread()
|
||||
self._output_response = output_response
|
||||
self._cache: list[ChatMessage] = []
|
||||
|
||||
@property
|
||||
def workflow_output_types(self) -> list[type[Any]]:
|
||||
# Override to declare AgentRunResponse as a possible output type only if enabled.
|
||||
if self._output_response:
|
||||
return [AgentRunResponse]
|
||||
return []
|
||||
|
||||
async def _run_agent_and_emit(self, ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse]) -> None:
|
||||
"""Execute the underlying agent, emit events, and enqueue response.
|
||||
|
||||
Checks ctx.is_streaming() to determine whether to emit incremental AgentRunUpdateEvent
|
||||
events (streaming mode) or a single AgentRunEvent (non-streaming mode).
|
||||
"""
|
||||
if ctx.is_streaming():
|
||||
# Streaming mode: emit incremental updates
|
||||
updates: list[AgentRunResponseUpdate] = []
|
||||
async for update in self._agent.run_stream(
|
||||
self._cache,
|
||||
thread=self._agent_thread,
|
||||
):
|
||||
if not update.text:
|
||||
# Skip empty updates (no textual or structural content)
|
||||
continue
|
||||
updates.append(update)
|
||||
await ctx.add_event(AgentRunUpdateEvent(self.id, update))
|
||||
|
||||
if isinstance(self._agent, ChatAgent):
|
||||
response_format = self._agent.chat_options.response_format
|
||||
response = AgentRunResponse.from_agent_run_response_updates(
|
||||
updates,
|
||||
output_format_type=response_format,
|
||||
)
|
||||
else:
|
||||
response = AgentRunResponse.from_agent_run_response_updates(updates)
|
||||
else:
|
||||
# Non-streaming mode: use run() and emit single event
|
||||
response = await self._agent.run(
|
||||
self._cache,
|
||||
thread=self._agent_thread,
|
||||
)
|
||||
await ctx.add_event(AgentRunEvent(self.id, response))
|
||||
|
||||
if self._output_response:
|
||||
await ctx.yield_output(response)
|
||||
|
||||
# Always construct a full conversation snapshot from inputs (cache)
|
||||
# plus agent outputs (agent_run_response.messages). Do not mutate
|
||||
# response.messages so AgentRunEvent remains faithful to the raw output.
|
||||
full_conversation: list[ChatMessage] = list(self._cache) + list(response.messages)
|
||||
|
||||
agent_response = AgentExecutorResponse(self.id, response, full_conversation=full_conversation)
|
||||
await ctx.send_message(agent_response)
|
||||
self._cache.clear()
|
||||
|
||||
@handler
|
||||
async def run(
|
||||
self, request: AgentExecutorRequest, ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse]
|
||||
) -> None:
|
||||
"""Handle an AgentExecutorRequest (canonical input).
|
||||
|
||||
This is the standard path: extend cache with provided messages; if should_respond
|
||||
run the agent and emit an AgentExecutorResponse downstream.
|
||||
"""
|
||||
self._cache.extend(request.messages)
|
||||
if request.should_respond:
|
||||
await self._run_agent_and_emit(ctx)
|
||||
|
||||
@handler
|
||||
async def from_response(
|
||||
self, prior: AgentExecutorResponse, ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse]
|
||||
) -> None:
|
||||
"""Enable seamless chaining: accept a prior AgentExecutorResponse as input.
|
||||
|
||||
Strategy: treat the prior response's messages as the conversation state and
|
||||
immediately run the agent to produce a new response.
|
||||
"""
|
||||
# Replace cache with full conversation if available, else fall back to agent_run_response messages.
|
||||
if prior.full_conversation is not None:
|
||||
self._cache = list(prior.full_conversation)
|
||||
else:
|
||||
self._cache = list(prior.agent_run_response.messages)
|
||||
await self._run_agent_and_emit(ctx)
|
||||
|
||||
@handler
|
||||
async def from_str(self, text: str, ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse]) -> None:
|
||||
"""Accept a raw user prompt string and run the agent (one-shot)."""
|
||||
self._cache = [ChatMessage(role="user", text=text)] # type: ignore[arg-type]
|
||||
await self._run_agent_and_emit(ctx)
|
||||
|
||||
@handler
|
||||
async def from_message(
|
||||
self,
|
||||
message: ChatMessage,
|
||||
ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse],
|
||||
) -> None:
|
||||
"""Accept a single ChatMessage as input."""
|
||||
self._cache = [message]
|
||||
await self._run_agent_and_emit(ctx)
|
||||
|
||||
@handler
|
||||
async def from_messages(
|
||||
self,
|
||||
messages: list[ChatMessage],
|
||||
ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse],
|
||||
) -> None:
|
||||
"""Accept a list of ChatMessage objects as conversation context."""
|
||||
self._cache = list(messages)
|
||||
await self._run_agent_and_emit(ctx)
|
||||
@@ -10,8 +10,9 @@ from typing_extensions import Never
|
||||
|
||||
from agent_framework import AgentProtocol, ChatMessage, Role
|
||||
|
||||
from ._agent_executor import AgentExecutorRequest, AgentExecutorResponse
|
||||
from ._checkpoint import CheckpointStorage
|
||||
from ._executor import AgentExecutorRequest, AgentExecutorResponse, Executor, handler
|
||||
from ._executor import Executor, handler
|
||||
from ._workflow import Workflow, WorkflowBuilder
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
|
||||
@@ -12,14 +12,9 @@ from dataclasses import asdict, dataclass, field, fields, is_dataclass
|
||||
from textwrap import shorten
|
||||
from typing import Any, ClassVar, Generic, TypeVar, cast
|
||||
|
||||
from .._agents import AgentProtocol
|
||||
from .._threads import AgentThread
|
||||
from .._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage
|
||||
from ..observability import create_processing_span
|
||||
from ._checkpoint import WorkflowCheckpoint
|
||||
from ._events import (
|
||||
AgentRunEvent,
|
||||
AgentRunUpdateEvent,
|
||||
ExecutorCompletedEvent,
|
||||
ExecutorFailedEvent,
|
||||
ExecutorInvokedEvent,
|
||||
@@ -1346,165 +1341,3 @@ class RequestInfoExecutor(Executor):
|
||||
|
||||
|
||||
# endregion: Request Info Executor
|
||||
|
||||
# region Agent Executor
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentExecutorRequest:
|
||||
"""A request to an agent executor.
|
||||
|
||||
Attributes:
|
||||
messages: A list of chat messages to be processed by the agent.
|
||||
should_respond: A flag indicating whether the agent should respond to the messages.
|
||||
If False, the messages will be saved to the executor's cache but not sent to the agent.
|
||||
"""
|
||||
|
||||
messages: list[ChatMessage]
|
||||
should_respond: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentExecutorResponse:
|
||||
"""A response from an agent executor.
|
||||
|
||||
Attributes:
|
||||
executor_id: The ID of the executor that generated the response.
|
||||
agent_run_response: The underlying agent run response (unaltered from client).
|
||||
full_conversation: The full conversation context (prior inputs + all assistant/tool outputs) that
|
||||
should be used when chaining to another AgentExecutor. This prevents downstream agents losing
|
||||
user prompts while keeping the emitted AgentRunEvent text faithful to the raw agent output.
|
||||
"""
|
||||
|
||||
executor_id: str
|
||||
agent_run_response: AgentRunResponse
|
||||
full_conversation: list[ChatMessage] | None = None
|
||||
|
||||
|
||||
class AgentExecutor(Executor):
|
||||
"""built-in executor that wraps an agent for handling messages."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent: AgentProtocol,
|
||||
*,
|
||||
agent_thread: AgentThread | None = None,
|
||||
streaming: bool = False,
|
||||
id: str | None = None,
|
||||
):
|
||||
"""Initialize the executor with a unique identifier.
|
||||
|
||||
Args:
|
||||
agent: The agent to be wrapped by this executor.
|
||||
|
||||
Keyword Args:
|
||||
agent_thread: The thread to use for running the agent. If None, a new thread will be created.
|
||||
streaming: Enable streaming (emits incremental AgentRunUpdateEvent events) vs single response.
|
||||
id: A unique identifier for the executor. If None, a new UUID will be generated.
|
||||
"""
|
||||
# Prefer provided id; else use agent.name if present; else generate deterministic prefix
|
||||
if id is not None:
|
||||
exec_id = id
|
||||
else:
|
||||
agent_name = agent.name
|
||||
if agent_name:
|
||||
exec_id = str(agent_name)
|
||||
else:
|
||||
logger.warning("Agent has no name, using fallback ID 'executor_unnamed'")
|
||||
exec_id = "executor_unnamed"
|
||||
super().__init__(exec_id)
|
||||
self._agent = agent
|
||||
self._agent_thread = agent_thread or self._agent.get_new_thread()
|
||||
self._streaming = streaming
|
||||
self._cache: list[ChatMessage] = []
|
||||
|
||||
async def _run_agent_and_emit(self, ctx: WorkflowContext[AgentExecutorResponse]) -> None:
|
||||
"""Execute the underlying agent, emit events, and enqueue response.
|
||||
|
||||
Terminal detection is handled centrally in Runner.
|
||||
This method only produces AgentRunEvent/AgentRunUpdateEvent plus enqueues an
|
||||
AgentExecutorResponse message for routing.
|
||||
"""
|
||||
if self._streaming:
|
||||
updates: list[AgentRunResponseUpdate] = []
|
||||
async for update in self._agent.run_stream(
|
||||
self._cache,
|
||||
thread=self._agent_thread,
|
||||
):
|
||||
# Skip empty updates (no textual or structural content)
|
||||
if not update:
|
||||
continue
|
||||
contents = getattr(update, "contents", None)
|
||||
text_val = getattr(update, "text", "")
|
||||
has_text_content = False
|
||||
if contents:
|
||||
for c in contents:
|
||||
if getattr(c, "text", None):
|
||||
has_text_content = True
|
||||
break
|
||||
if not (text_val or has_text_content):
|
||||
continue
|
||||
updates.append(update)
|
||||
await ctx.add_event(AgentRunUpdateEvent(self.id, update))
|
||||
response = AgentRunResponse.from_agent_run_response_updates(updates)
|
||||
else:
|
||||
response = await self._agent.run(
|
||||
self._cache,
|
||||
thread=self._agent_thread,
|
||||
)
|
||||
await ctx.add_event(AgentRunEvent(self.id, response))
|
||||
|
||||
# Always construct a full conversation snapshot from inputs (cache)
|
||||
# plus agent outputs (agent_run_response.messages). Do not mutate
|
||||
# response.messages so AgentRunEvent remains faithful to the raw output.
|
||||
full_conversation: list[ChatMessage] = list(self._cache) + list(response.messages)
|
||||
|
||||
agent_response = AgentExecutorResponse(self.id, response, full_conversation=full_conversation)
|
||||
await ctx.send_message(agent_response)
|
||||
self._cache.clear()
|
||||
|
||||
@handler
|
||||
async def run(self, request: AgentExecutorRequest, ctx: WorkflowContext[AgentExecutorResponse]) -> None:
|
||||
"""Handle an AgentExecutorRequest (canonical input).
|
||||
|
||||
This is the standard path: extend cache with provided messages; if should_respond
|
||||
run the agent and emit an AgentExecutorResponse downstream.
|
||||
"""
|
||||
self._cache.extend(request.messages)
|
||||
if request.should_respond:
|
||||
await self._run_agent_and_emit(ctx)
|
||||
|
||||
@handler
|
||||
async def from_response(self, prior: AgentExecutorResponse, ctx: WorkflowContext[AgentExecutorResponse]) -> None:
|
||||
"""Enable seamless chaining: accept a prior AgentExecutorResponse as input.
|
||||
|
||||
Strategy: treat the prior response's messages as the conversation state and
|
||||
immediately run the agent to produce a new response.
|
||||
"""
|
||||
# Replace cache with full conversation if available, else fall back to agent_run_response messages.
|
||||
if prior.full_conversation is not None:
|
||||
self._cache = list(prior.full_conversation)
|
||||
else:
|
||||
self._cache = list(prior.agent_run_response.messages)
|
||||
await self._run_agent_and_emit(ctx)
|
||||
|
||||
@handler
|
||||
async def from_str(self, text: str, ctx: WorkflowContext[AgentExecutorResponse]) -> None:
|
||||
"""Accept a raw user prompt string and run the agent (one-shot)."""
|
||||
self._cache = [ChatMessage(role="user", text=text)] # type: ignore[arg-type]
|
||||
await self._run_agent_and_emit(ctx)
|
||||
|
||||
@handler
|
||||
async def from_message(self, message: ChatMessage, ctx: WorkflowContext[AgentExecutorResponse]) -> None: # type: ignore[name-defined]
|
||||
"""Accept a single ChatMessage as input."""
|
||||
self._cache = [message]
|
||||
await self._run_agent_and_emit(ctx)
|
||||
|
||||
@handler
|
||||
async def from_messages(self, messages: list[ChatMessage], ctx: WorkflowContext[AgentExecutorResponse]) -> None: # type: ignore[name-defined]
|
||||
"""Accept a list of ChatMessage objects as conversation context."""
|
||||
self._cache = list(messages)
|
||||
await self._run_agent_and_emit(ctx)
|
||||
|
||||
|
||||
# endregion: Agent Executor
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any
|
||||
from ._checkpoint import CheckpointStorage, WorkflowCheckpoint
|
||||
from ._edge import EdgeGroup
|
||||
from ._edge_runner import EdgeRunner, create_edge_runner
|
||||
from ._events import WorkflowEvent, WorkflowOutputEvent, _framework_event_origin
|
||||
from ._events import WorkflowEvent
|
||||
from ._executor import Executor
|
||||
from ._runner_context import (
|
||||
_DATACLASS_MARKER, # type: ignore
|
||||
@@ -183,54 +183,7 @@ class Runner:
|
||||
_normalize_message_payload(message)
|
||||
# Deliver a message through all edge runners associated with the source executor concurrently.
|
||||
tasks = [_deliver_message_inner(edge_runner, message) for edge_runner in associated_edge_runners]
|
||||
if not tasks:
|
||||
# No outgoing edges. If this is an AgentExecutorResponse, treat it as an
|
||||
# intentional terminal emission and emit a WorkflowOutputEvent here.
|
||||
# (Previously this relied on the executor to emit, but AgentExecutor only
|
||||
# sends an AgentExecutorResponse message; centralized completion keeps the
|
||||
# contract consistent with other executors.)
|
||||
try: # Local import to avoid circular dependencies at module import time.
|
||||
from ._executor import AgentExecutorResponse # type: ignore
|
||||
|
||||
if isinstance(message.data, AgentExecutorResponse):
|
||||
final_messages = message.data.agent_run_response.messages
|
||||
final_text = final_messages[-1].text if final_messages else "(no content)"
|
||||
with _framework_event_origin():
|
||||
# TODO(moonbox3): does user expect this event to contain the final text?
|
||||
output_event = WorkflowOutputEvent(data=final_text, source_executor_id="<Runner>")
|
||||
await self._ctx.add_event(output_event)
|
||||
continue # Terminal handled
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.debug("Suppressed exception during terminal message type check: %s", exc)
|
||||
# Otherwise keep prior behavior (emit warning for unexpected undelivered message).
|
||||
logger.warning(
|
||||
f"Message {message} could not be delivered (no outgoing edges). "
|
||||
"Add a downstream executor or remove the send if this is unexpected."
|
||||
)
|
||||
continue
|
||||
results = await asyncio.gather(*tasks)
|
||||
if not any(results):
|
||||
# Outgoing edges exist but none accepted the message. If this is an
|
||||
# AgentExecutorResponse, treat as natural terminal and emit completion.
|
||||
try:
|
||||
from ._executor import AgentExecutorResponse # type: ignore
|
||||
|
||||
if isinstance(message.data, AgentExecutorResponse):
|
||||
# Emit a single completion event with final text (best-effort extraction)
|
||||
final_messages = message.data.agent_run_response.messages
|
||||
final_text = final_messages[-1].text if final_messages else "(no content)"
|
||||
with _framework_event_origin():
|
||||
# TODO(moonbox3): does user expect this event to contain the final text?
|
||||
output_event = WorkflowOutputEvent(data=final_text, source_executor_id="<Runner>")
|
||||
await self._ctx.add_event(output_event)
|
||||
continue
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.debug("Terminal completion emission failed: %s", exc)
|
||||
|
||||
logger.warning(
|
||||
f"Message {message} could not be delivered. "
|
||||
"This may be due to type incompatibility or no matching targets."
|
||||
)
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
messages = await self._ctx.drain_messages()
|
||||
tasks = [_deliver_messages(source_executor_id, messages) for source_executor_id, messages in messages.items()]
|
||||
|
||||
@@ -393,6 +393,22 @@ class RunnerContext(Protocol):
|
||||
"""Reset the context for a new workflow run."""
|
||||
...
|
||||
|
||||
def set_streaming(self, streaming: bool) -> None:
|
||||
"""Set whether agents should stream incremental updates.
|
||||
|
||||
Args:
|
||||
streaming: True for streaming mode (run_stream), False for non-streaming (run).
|
||||
"""
|
||||
...
|
||||
|
||||
def is_streaming(self) -> bool:
|
||||
"""Check if the workflow is in streaming mode.
|
||||
|
||||
Returns:
|
||||
True if streaming mode is enabled, False otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
async def create_checkpoint(self, metadata: dict[str, Any] | None = None) -> str:
|
||||
"""Create a checkpoint of the current workflow state.
|
||||
|
||||
@@ -450,6 +466,9 @@ class InProcRunnerContext:
|
||||
self._iteration_count: int = 0
|
||||
self._max_iterations: int = 100
|
||||
|
||||
# Streaming flag - set by workflow's run_stream() vs run()
|
||||
self._streaming: bool = False
|
||||
|
||||
async def send_message(self, message: Message) -> None:
|
||||
self._messages.setdefault(message.source_id, [])
|
||||
self._messages[message.source_id].append(message)
|
||||
@@ -524,6 +543,22 @@ class InProcRunnerContext:
|
||||
def set_workflow_id(self, workflow_id: str) -> None:
|
||||
self._workflow_id = workflow_id
|
||||
|
||||
def set_streaming(self, streaming: bool) -> None:
|
||||
"""Set whether agents should stream incremental updates.
|
||||
|
||||
Args:
|
||||
streaming: True for streaming mode (run_stream), False for non-streaming (run).
|
||||
"""
|
||||
self._streaming = streaming
|
||||
|
||||
def is_streaming(self) -> bool:
|
||||
"""Check if the workflow is in streaming mode.
|
||||
|
||||
Returns:
|
||||
True if streaming mode is enabled, False otherwise.
|
||||
"""
|
||||
return self._streaming
|
||||
|
||||
def reset_for_new_run(self, workflow_shared_state: SharedState | None = None) -> None:
|
||||
self._messages.clear()
|
||||
# Clear any pending events (best-effort) by recreating the queue
|
||||
@@ -531,6 +566,7 @@ class InProcRunnerContext:
|
||||
self._shared_state.clear()
|
||||
self._executor_states.clear()
|
||||
self._iteration_count = 0
|
||||
self._streaming = False # Reset streaming flag
|
||||
if workflow_shared_state is not None and hasattr(workflow_shared_state, "_state"):
|
||||
workflow_shared_state._state.clear() # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@@ -42,10 +42,12 @@ from typing import Any
|
||||
|
||||
from agent_framework import AgentProtocol, ChatMessage, Role
|
||||
|
||||
from ._checkpoint import CheckpointStorage
|
||||
from ._executor import (
|
||||
from ._agent_executor import (
|
||||
AgentExecutor,
|
||||
AgentExecutorResponse,
|
||||
)
|
||||
from ._checkpoint import CheckpointStorage
|
||||
from ._executor import (
|
||||
Executor,
|
||||
handler,
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import Any
|
||||
from .._agents import AgentProtocol
|
||||
from ..observability import OtelAttr, capture_exception, create_workflow_span
|
||||
from ._agent import WorkflowAgent
|
||||
from ._agent_executor import AgentExecutor
|
||||
from ._checkpoint import CheckpointStorage
|
||||
from ._const import DEFAULT_MAX_ITERATIONS
|
||||
from ._edge import (
|
||||
@@ -36,7 +37,7 @@ from ._events import (
|
||||
WorkflowStatusEvent,
|
||||
_framework_event_origin, # type: ignore
|
||||
)
|
||||
from ._executor import AgentExecutor, Executor, RequestInfoExecutor
|
||||
from ._executor import Executor, RequestInfoExecutor
|
||||
from ._model_utils import DictConvertible
|
||||
from ._runner import Runner
|
||||
from ._runner_context import InProcRunnerContext, RunnerContext
|
||||
@@ -281,7 +282,10 @@ class Workflow(DictConvertible):
|
||||
return list(self.executors.values())
|
||||
|
||||
async def _run_workflow_with_tracing(
|
||||
self, initial_executor_fn: Callable[[], Awaitable[None]] | None = None, reset_context: bool = True
|
||||
self,
|
||||
initial_executor_fn: Callable[[], Awaitable[None]] | None = None,
|
||||
reset_context: bool = True,
|
||||
streaming: bool = False,
|
||||
) -> AsyncIterable[WorkflowEvent]:
|
||||
"""Private method to run workflow with proper tracing.
|
||||
|
||||
@@ -291,6 +295,7 @@ class Workflow(DictConvertible):
|
||||
Args:
|
||||
initial_executor_fn: Optional function to execute initial executor
|
||||
reset_context: Whether to reset the context for a new run
|
||||
streaming: Whether to enable streaming mode for agents
|
||||
|
||||
Yields:
|
||||
WorkflowEvent: The events generated during the workflow execution.
|
||||
@@ -323,6 +328,9 @@ class Workflow(DictConvertible):
|
||||
if reset_context:
|
||||
self._runner.context.reset_for_new_run(self._shared_state)
|
||||
|
||||
# Set streaming mode after reset
|
||||
self._runner_context.set_streaming(streaming)
|
||||
|
||||
# Execute initial setup if provided
|
||||
if initial_executor_fn:
|
||||
await initial_executor_fn()
|
||||
@@ -394,7 +402,7 @@ class Workflow(DictConvertible):
|
||||
)
|
||||
|
||||
async for event in self._run_workflow_with_tracing(
|
||||
initial_executor_fn=initial_execution, reset_context=True
|
||||
initial_executor_fn=initial_execution, reset_context=True, streaming=True
|
||||
):
|
||||
yield event
|
||||
finally:
|
||||
@@ -476,6 +484,7 @@ class Workflow(DictConvertible):
|
||||
async for event in self._run_workflow_with_tracing(
|
||||
initial_executor_fn=checkpoint_restoration,
|
||||
reset_context=False, # Don't reset context when resuming from checkpoint
|
||||
streaming=True,
|
||||
):
|
||||
yield event
|
||||
finally:
|
||||
@@ -521,6 +530,7 @@ class Workflow(DictConvertible):
|
||||
async for event in self._run_workflow_with_tracing(
|
||||
initial_executor_fn=send_responses,
|
||||
reset_context=False, # Don't reset context when sending responses
|
||||
streaming=True,
|
||||
):
|
||||
yield event
|
||||
finally:
|
||||
@@ -538,9 +548,6 @@ class Workflow(DictConvertible):
|
||||
"""
|
||||
self._ensure_not_running()
|
||||
try:
|
||||
from agent_framework import AgentRunResponse, AgentRunResponseUpdate
|
||||
|
||||
from ._events import AgentRunEvent, AgentRunUpdateEvent # Local import to avoid cycles
|
||||
|
||||
async def initial_execution() -> None:
|
||||
executor = self.get_start_executor()
|
||||
@@ -556,43 +563,18 @@ class Workflow(DictConvertible):
|
||||
raw_events = [
|
||||
event
|
||||
async for event in self._run_workflow_with_tracing(
|
||||
initial_executor_fn=initial_execution, reset_context=True
|
||||
initial_executor_fn=initial_execution,
|
||||
reset_context=True,
|
||||
)
|
||||
]
|
||||
finally:
|
||||
self._reset_running_flag()
|
||||
|
||||
# Coalesce streaming update events into a single AgentRunEvent per executor sequence.
|
||||
coalesced: list[WorkflowEvent] = []
|
||||
pending_updates: list[AgentRunResponseUpdate] = []
|
||||
pending_executor: str | None = None
|
||||
# Filter events for non-streaming mode
|
||||
filtered: list[WorkflowEvent] = []
|
||||
status_events: list[WorkflowStatusEvent] = []
|
||||
|
||||
def _flush_pending() -> None:
|
||||
nonlocal pending_updates, pending_executor
|
||||
if pending_executor is None or not pending_updates:
|
||||
return
|
||||
# Aggregate updates into a final AgentRunResponse using existing helper
|
||||
aggregated = AgentRunResponse.from_agent_run_response_updates(pending_updates)
|
||||
coalesced.append(AgentRunEvent(pending_executor, aggregated))
|
||||
pending_updates = []
|
||||
pending_executor = None
|
||||
|
||||
for ev in raw_events:
|
||||
if isinstance(ev, AgentRunUpdateEvent):
|
||||
# Start new grouping or continue existing if same executor
|
||||
if pending_executor is None:
|
||||
pending_executor = ev.executor_id
|
||||
if ev.executor_id != pending_executor:
|
||||
# Different executor encountered; flush previous first
|
||||
_flush_pending()
|
||||
pending_executor = ev.executor_id
|
||||
if ev.data is not None:
|
||||
pending_updates.append(ev.data)
|
||||
# Do NOT append update event itself (non-streaming contract)
|
||||
continue
|
||||
# Flush before adding any non-update event
|
||||
_flush_pending()
|
||||
# Omit WorkflowStartedEvent from non-streaming (telemetry-only)
|
||||
if isinstance(ev, WorkflowStartedEvent):
|
||||
continue
|
||||
@@ -600,15 +582,11 @@ class Workflow(DictConvertible):
|
||||
if isinstance(ev, WorkflowStatusEvent):
|
||||
status_events.append(ev)
|
||||
if include_status_events:
|
||||
coalesced.append(ev)
|
||||
filtered.append(ev)
|
||||
continue
|
||||
coalesced.append(ev)
|
||||
filtered.append(ev)
|
||||
|
||||
# Flush any trailing updates
|
||||
_flush_pending()
|
||||
|
||||
# coalesced already excludes start events; includes status events only if opted in
|
||||
return WorkflowRunResult(coalesced, status_events)
|
||||
return WorkflowRunResult(filtered, status_events)
|
||||
|
||||
async def run_from_checkpoint(
|
||||
self,
|
||||
@@ -928,11 +906,23 @@ class WorkflowBuilder:
|
||||
self._executors[executor.id] = executor
|
||||
return executor.id
|
||||
|
||||
def _maybe_wrap_agent(self, candidate: Executor | AgentProtocol) -> Executor:
|
||||
def _maybe_wrap_agent(
|
||||
self,
|
||||
candidate: Executor | AgentProtocol,
|
||||
agent_thread: Any | None = None,
|
||||
output_response: bool = False,
|
||||
executor_id: str | None = None,
|
||||
) -> Executor:
|
||||
"""If the provided object implements AgentProtocol, wrap it in an AgentExecutor.
|
||||
|
||||
This allows fluent builder APIs to directly accept agents instead of
|
||||
requiring callers to manually instantiate AgentExecutor.
|
||||
|
||||
Args:
|
||||
candidate: The executor or agent to wrap.
|
||||
agent_thread: The thread to use for running the agent. If None, a new thread will be created.
|
||||
output_response: Whether to yield an AgentRunResponse as a workflow output when the agent completes.
|
||||
executor_id: A unique identifier for the executor. If None, the agent's name will be used if available.
|
||||
"""
|
||||
try: # Local import to avoid hard dependency at import time
|
||||
from agent_framework import AgentProtocol # type: ignore
|
||||
@@ -943,26 +933,67 @@ class WorkflowBuilder:
|
||||
return candidate
|
||||
if isinstance(candidate, AgentProtocol): # type: ignore[arg-type]
|
||||
# Reuse existing wrapper for the same agent instance if present
|
||||
existing = self._agent_wrappers.get(id(candidate))
|
||||
agent_instance_id = id(candidate)
|
||||
existing = self._agent_wrappers.get(agent_instance_id)
|
||||
if existing is not None:
|
||||
return existing
|
||||
# Use agent name if available and unique among current executors
|
||||
name = getattr(candidate, "name", None)
|
||||
proposed_id: str | None = None
|
||||
if name:
|
||||
proposed_id: str | None = executor_id
|
||||
if proposed_id is None and name:
|
||||
proposed_id = str(name)
|
||||
if proposed_id in self._executors:
|
||||
raise ValueError(
|
||||
f"Duplicate executor ID '{proposed_id}' from agent name. "
|
||||
"Agent names must be unique within a workflow."
|
||||
)
|
||||
wrapper = AgentExecutor(candidate, id=proposed_id, streaming=True)
|
||||
self._agent_wrappers[id(candidate)] = wrapper
|
||||
wrapper = AgentExecutor(
|
||||
candidate,
|
||||
agent_thread=agent_thread,
|
||||
output_response=output_response,
|
||||
id=proposed_id,
|
||||
)
|
||||
self._agent_wrappers[agent_instance_id] = wrapper
|
||||
return wrapper
|
||||
raise TypeError(
|
||||
f"WorkflowBuilder expected an Executor or AgentProtocol instance; got {type(candidate).__name__}."
|
||||
)
|
||||
|
||||
def add_agent(
|
||||
self,
|
||||
agent: AgentProtocol,
|
||||
agent_thread: Any | None = None,
|
||||
output_response: bool = False,
|
||||
id: str | None = None,
|
||||
) -> Self:
|
||||
"""Add an agent to the workflow by wrapping it in an AgentExecutor.
|
||||
|
||||
This method creates an AgentExecutor that wraps the agent with the given parameters
|
||||
and ensures that subsequent uses of the same agent instance in other builder methods
|
||||
(like add_edge, set_start_executor, etc.) will reuse the same wrapped executor.
|
||||
|
||||
Note: Agents adapt their behavior based on how the workflow is executed:
|
||||
- run_stream(): Agents emit incremental AgentRunUpdateEvent events as tokens are produced
|
||||
- run(): Agents emit a single AgentRunEvent containing the complete response
|
||||
|
||||
Args:
|
||||
agent: The agent to add to the workflow.
|
||||
agent_thread: The thread to use for running the agent. If None, a new thread will be created.
|
||||
output_response: Whether to yield an AgentRunResponse as a workflow output when the agent completes.
|
||||
id: A unique identifier for the executor. If None, the agent's name will be used if available.
|
||||
|
||||
Returns:
|
||||
The WorkflowBuilder instance (for method chaining).
|
||||
|
||||
Raises:
|
||||
ValueError: If the provided id or agent name conflicts with an existing executor.
|
||||
"""
|
||||
executor = self._maybe_wrap_agent(
|
||||
agent, agent_thread=agent_thread, output_response=output_response, executor_id=id
|
||||
)
|
||||
self._add_executor(executor)
|
||||
return self
|
||||
|
||||
def add_edge(
|
||||
self,
|
||||
source: Executor | AgentProtocol,
|
||||
|
||||
@@ -449,3 +449,11 @@ class WorkflowContext(Generic[T_Out, T_W_Out]):
|
||||
if hasattr(self._runner_context, "get_state"):
|
||||
return await self._runner_context.get_state(self._executor_id) # type: ignore[return-value]
|
||||
return None
|
||||
|
||||
def is_streaming(self) -> bool:
|
||||
"""Check if the workflow is running in streaming mode.
|
||||
|
||||
Returns:
|
||||
True if the workflow was started with run_stream(), False if started with run().
|
||||
"""
|
||||
return self._runner_context.is_streaming()
|
||||
|
||||
@@ -8,22 +8,22 @@ from typing_extensions import Never
|
||||
|
||||
from agent_framework import (
|
||||
AgentExecutor,
|
||||
AgentExecutorResponse,
|
||||
AgentRunResponse,
|
||||
AgentRunResponseUpdate,
|
||||
AgentThread,
|
||||
BaseAgent,
|
||||
ChatMessage,
|
||||
Executor,
|
||||
Role,
|
||||
SequentialBuilder,
|
||||
TextContent,
|
||||
WorkflowBuilder,
|
||||
WorkflowOutputEvent,
|
||||
WorkflowContext,
|
||||
WorkflowRunState,
|
||||
WorkflowStatusEvent,
|
||||
handler,
|
||||
)
|
||||
from agent_framework._workflows._executor import AgentExecutorResponse, Executor
|
||||
from agent_framework._workflows._workflow_context import WorkflowContext
|
||||
|
||||
|
||||
class _SimpleAgent(BaseAgent):
|
||||
@@ -71,28 +71,22 @@ class _CaptureFullConversation(Executor):
|
||||
|
||||
|
||||
async def test_agent_executor_populates_full_conversation_non_streaming() -> None:
|
||||
# Arrange: non-streaming AgentExecutor for deterministic response composition
|
||||
# Arrange: AgentExecutor will be non-streaming when using workflow.run()
|
||||
agent = _SimpleAgent(id="agent1", name="A", reply_text="agent-reply")
|
||||
agent_exec = AgentExecutor(agent, streaming=False, id="agent1-exec")
|
||||
agent_exec = AgentExecutor(agent, id="agent1-exec")
|
||||
capturer = _CaptureFullConversation(id="capture")
|
||||
|
||||
wf = WorkflowBuilder().set_start_executor(agent_exec).add_edge(agent_exec, capturer).build()
|
||||
|
||||
# Act: run with a simple user prompt
|
||||
completed = False
|
||||
output: dict | None = None
|
||||
async for ev in wf.run_stream("hello world"):
|
||||
if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE:
|
||||
completed = True
|
||||
elif isinstance(ev, WorkflowOutputEvent):
|
||||
output = ev.data # type: ignore[assignment]
|
||||
if completed and output is not None:
|
||||
break
|
||||
# Act: use run() instead of run_stream() to test non-streaming mode
|
||||
result = await wf.run("hello world")
|
||||
|
||||
# Extract output from run result
|
||||
outputs = result.get_outputs()
|
||||
assert len(outputs) == 1
|
||||
payload = outputs[0]
|
||||
|
||||
# Assert: full_conversation contains [user("hello world"), assistant("agent-reply")]
|
||||
assert completed
|
||||
assert output is not None
|
||||
payload = output
|
||||
assert isinstance(payload, dict)
|
||||
assert payload["length"] == 2
|
||||
assert payload["roles"][0] == Role.USER and "hello world" in (payload["texts"][0] or "")
|
||||
|
||||
@@ -2,12 +2,21 @@
|
||||
|
||||
import asyncio
|
||||
import tempfile
|
||||
from collections.abc import AsyncIterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework import (
|
||||
AgentExecutor,
|
||||
AgentRunEvent,
|
||||
AgentRunResponse,
|
||||
AgentRunResponseUpdate,
|
||||
AgentRunUpdateEvent,
|
||||
AgentThread,
|
||||
BaseAgent,
|
||||
ChatMessage,
|
||||
Executor,
|
||||
FileCheckpointStorage,
|
||||
Message,
|
||||
@@ -15,6 +24,8 @@ from agent_framework import (
|
||||
RequestInfoExecutor,
|
||||
RequestInfoMessage,
|
||||
RequestResponse,
|
||||
Role,
|
||||
TextContent,
|
||||
WorkflowBuilder,
|
||||
WorkflowContext,
|
||||
WorkflowEvent,
|
||||
@@ -789,3 +800,76 @@ async def test_workflow_concurrent_execution_prevention_mixed_methods():
|
||||
# Now all methods should work again
|
||||
result = await workflow.run(NumberMessage(data=0))
|
||||
assert result.get_final_state() == WorkflowRunState.IDLE
|
||||
|
||||
|
||||
class _StreamingTestAgent(BaseAgent):
|
||||
"""Test agent that supports both streaming and non-streaming modes."""
|
||||
|
||||
def __init__(self, *, reply_text: str, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self._reply_text = reply_text
|
||||
|
||||
async def run(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentRunResponse:
|
||||
"""Non-streaming run - returns complete response."""
|
||||
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=self._reply_text)])
|
||||
|
||||
async def run_stream(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[AgentRunResponseUpdate]:
|
||||
"""Streaming run - yields incremental updates."""
|
||||
# Simulate streaming by yielding character by character
|
||||
for char in self._reply_text:
|
||||
yield AgentRunResponseUpdate(contents=[TextContent(text=char)])
|
||||
|
||||
|
||||
async def test_agent_streaming_vs_non_streaming() -> None:
|
||||
"""Test that run() emits AgentRunEvent while run_stream() emits AgentRunUpdateEvent."""
|
||||
agent = _StreamingTestAgent(id="test_agent", name="TestAgent", reply_text="Hello World")
|
||||
agent_exec = AgentExecutor(agent, id="agent_exec")
|
||||
|
||||
workflow = WorkflowBuilder().set_start_executor(agent_exec).build()
|
||||
|
||||
# Test non-streaming mode with run()
|
||||
result = await workflow.run("test message")
|
||||
|
||||
# Filter for agent events (result is a list of events)
|
||||
agent_run_events = [e for e in result if isinstance(e, AgentRunEvent)]
|
||||
agent_update_events = [e for e in result if isinstance(e, AgentRunUpdateEvent)]
|
||||
|
||||
# In non-streaming mode, should have AgentRunEvent, no AgentRunUpdateEvent
|
||||
assert len(agent_run_events) == 1, "Expected exactly one AgentRunEvent in non-streaming mode"
|
||||
assert len(agent_update_events) == 0, "Expected no AgentRunUpdateEvent in non-streaming mode"
|
||||
assert agent_run_events[0].executor_id == "agent_exec"
|
||||
assert agent_run_events[0].data.messages[0].text == "Hello World"
|
||||
|
||||
# Test streaming mode with run_stream()
|
||||
stream_events: list[WorkflowEvent] = []
|
||||
async for event in workflow.run_stream("test message"):
|
||||
stream_events.append(event)
|
||||
|
||||
# Filter for agent events
|
||||
stream_agent_run_events = [e for e in stream_events if isinstance(e, AgentRunEvent)]
|
||||
stream_agent_update_events = [e for e in stream_events if isinstance(e, AgentRunUpdateEvent)]
|
||||
|
||||
# In streaming mode, should have AgentRunUpdateEvent, no AgentRunEvent
|
||||
assert len(stream_agent_run_events) == 0, "Expected no AgentRunEvent in streaming mode"
|
||||
assert len(stream_agent_update_events) > 0, "Expected AgentRunUpdateEvent events in streaming mode"
|
||||
|
||||
# Verify we got incremental updates (one per character in "Hello World")
|
||||
assert len(stream_agent_update_events) == len("Hello World"), "Expected one update per character"
|
||||
|
||||
# Verify the updates build up to the full message
|
||||
accumulated_text = "".join(
|
||||
e.data.contents[0].text for e in stream_agent_update_events if e.data.contents and e.data.contents[0].text
|
||||
)
|
||||
assert accumulated_text == "Hello World", f"Expected 'Hello World', got '{accumulated_text}'"
|
||||
|
||||
@@ -47,14 +47,6 @@ def test_builder_accepts_agents_directly():
|
||||
assert any(isinstance(e, AgentExecutor) and e.id in {"writer", "reviewer"} for e in wf.executors.values())
|
||||
|
||||
|
||||
def test_builder_agents_always_stream():
|
||||
agent = DummyAgent(id="agentX", name="streamer")
|
||||
wf = WorkflowBuilder().set_start_executor(agent).build()
|
||||
exec_obj = wf.get_start_executor()
|
||||
assert isinstance(exec_obj, AgentExecutor)
|
||||
assert getattr(exec_obj, "_streaming", False) is True
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockMessage:
|
||||
"""A mock message for testing purposes."""
|
||||
@@ -111,3 +103,108 @@ def test_workflow_builder_fluent_api():
|
||||
assert len(workflow.edge_groups) == 4
|
||||
assert workflow.start_executor_id == executor_a.id
|
||||
assert len(workflow.executors) == 6
|
||||
|
||||
|
||||
def test_add_agent_with_custom_parameters():
|
||||
"""Test adding an agent with custom parameters."""
|
||||
agent = DummyAgent(id="agent_custom", name="custom_agent")
|
||||
builder = WorkflowBuilder()
|
||||
|
||||
# Add agent with custom parameters
|
||||
result = builder.add_agent(agent, output_response=True, id="my_custom_id")
|
||||
|
||||
# Verify that add_agent returns the builder for chaining
|
||||
assert result is builder
|
||||
|
||||
# Build workflow and verify executor is present
|
||||
workflow = builder.set_start_executor(agent).build()
|
||||
assert "my_custom_id" in workflow.executors
|
||||
|
||||
# Verify the executor was created with correct parameters
|
||||
executor = workflow.executors["my_custom_id"]
|
||||
assert isinstance(executor, AgentExecutor)
|
||||
assert executor.id == "my_custom_id"
|
||||
assert getattr(executor, "_output_response", False) is True
|
||||
|
||||
|
||||
def test_add_agent_reuses_same_wrapper():
|
||||
"""Test that using the same agent instance multiple times reuses the same wrapper."""
|
||||
agent = DummyAgent(id="agent_reuse", name="reuse_agent")
|
||||
builder = WorkflowBuilder()
|
||||
|
||||
# Add agent with specific parameters
|
||||
builder.add_agent(agent, output_response=True, id="agent_exec")
|
||||
|
||||
# Use the same agent instance in add_edge - should reuse the same wrapper
|
||||
builder.set_start_executor(agent)
|
||||
|
||||
workflow = builder.build()
|
||||
|
||||
# Verify only one executor exists for this agent
|
||||
assert workflow.start_executor_id == "agent_exec"
|
||||
assert "agent_exec" in workflow.executors
|
||||
assert len([e for e in workflow.executors.values() if isinstance(e, AgentExecutor)]) == 1
|
||||
|
||||
# Verify the executor has the parameters from add_agent
|
||||
start_executor = workflow.get_start_executor()
|
||||
assert isinstance(start_executor, AgentExecutor)
|
||||
assert getattr(start_executor, "_output_response", False) is True
|
||||
|
||||
|
||||
def test_add_agent_then_use_in_edges():
|
||||
"""Test that an agent added via add_agent can be used in edge definitions."""
|
||||
agent1 = DummyAgent(id="agent1", name="first")
|
||||
agent2 = DummyAgent(id="agent2", name="second")
|
||||
builder = WorkflowBuilder()
|
||||
|
||||
# Add agents with specific settings
|
||||
builder.add_agent(agent1, output_response=False, id="exec1")
|
||||
builder.add_agent(agent2, output_response=True, id="exec2")
|
||||
|
||||
# Use the same agent instances to create edges
|
||||
workflow = builder.set_start_executor(agent1).add_edge(agent1, agent2).build()
|
||||
|
||||
# Verify the executors maintain their settings
|
||||
assert workflow.start_executor_id == "exec1"
|
||||
assert "exec1" in workflow.executors
|
||||
assert "exec2" in workflow.executors
|
||||
|
||||
e1 = workflow.executors["exec1"]
|
||||
e2 = workflow.executors["exec2"]
|
||||
|
||||
assert isinstance(e1, AgentExecutor)
|
||||
assert isinstance(e2, AgentExecutor)
|
||||
assert getattr(e1, "_output_response", True) is False
|
||||
assert getattr(e2, "_output_response", False) is True
|
||||
|
||||
|
||||
def test_add_agent_without_explicit_id_uses_agent_name():
|
||||
"""Test that add_agent uses agent name as id when no explicit id is provided."""
|
||||
agent = DummyAgent(id="agent_x", name="named_agent")
|
||||
builder = WorkflowBuilder()
|
||||
|
||||
result = builder.add_agent(agent)
|
||||
|
||||
# Verify that add_agent returns the builder for chaining
|
||||
assert result is builder
|
||||
|
||||
workflow = builder.set_start_executor(agent).build()
|
||||
assert "named_agent" in workflow.executors
|
||||
|
||||
# Verify the executor id matches the agent name
|
||||
executor = workflow.executors["named_agent"]
|
||||
assert executor.id == "named_agent"
|
||||
|
||||
|
||||
def test_add_agent_duplicate_id_raises_error():
|
||||
"""Test that adding agents with duplicate IDs raises an error."""
|
||||
agent1 = DummyAgent(id="agent1", name="first")
|
||||
agent2 = DummyAgent(id="agent2", name="first") # Same name as agent1
|
||||
builder = WorkflowBuilder()
|
||||
|
||||
# Add first agent
|
||||
builder.add_agent(agent1)
|
||||
|
||||
# Adding second agent with same name should raise ValueError
|
||||
with pytest.raises(ValueError, match="Duplicate executor ID"):
|
||||
builder.add_agent(agent2)
|
||||
|
||||
@@ -16,12 +16,13 @@ A Writer agent generates content, then a Reviewer agent critiques it.
|
||||
The workflow uses streaming so you can observe incremental AgentRunUpdateEvent chunks as each agent produces tokens.
|
||||
|
||||
Purpose:
|
||||
Show how to wire chat agents directly into a WorkflowBuilder pipeline where agents are auto wrapped as executors.
|
||||
Show how to wire chat agents into a WorkflowBuilder pipeline using add_agent
|
||||
with settings for streaming and workflow outputs.
|
||||
|
||||
Demonstrate:
|
||||
- Automatic streaming of agent deltas via AgentRunUpdateEvent.
|
||||
- A simple console aggregator that groups updates by executor id and prints them as they arrive.
|
||||
- The workflow completes when idle and outputs are available in events.get_outputs().
|
||||
- Automatic streaming of agent deltas via AgentRunUpdateEvent when using run_stream().
|
||||
- Add an agent via WorkflowBuilder.add_agent() with output_response=True to emit final AgentRunResponse.
|
||||
- Agents adapt to workflow mode: run_stream() emits incremental updates, run() emits complete responses.
|
||||
|
||||
Prerequisites:
|
||||
- Azure AI Agent Service configured, along with the required environment variables.
|
||||
@@ -66,8 +67,17 @@ async def main() -> None:
|
||||
"Provide the feedback in the most concise manner possible."
|
||||
),
|
||||
)
|
||||
|
||||
workflow = WorkflowBuilder().set_start_executor(writer).add_edge(writer, reviewer).build()
|
||||
# Add agents to workflow with custom settings using add_agent.
|
||||
# Agents adapt to workflow mode: run_stream() for incremental updates, run() for complete responses.
|
||||
# Reviewer agent emits final AgentRunResponse as a workflow output.
|
||||
workflow = (
|
||||
WorkflowBuilder()
|
||||
.add_agent(writer, id="Writer")
|
||||
.add_agent(reviewer, id="Reviewer", output_response=True)
|
||||
.set_start_executor(writer)
|
||||
.add_edge(writer, reviewer)
|
||||
.build()
|
||||
)
|
||||
|
||||
last_executor_id: str | None = None
|
||||
|
||||
|
||||
@@ -13,12 +13,13 @@ A Writer agent generates content, then a Reviewer agent critiques it.
|
||||
The workflow uses streaming so you can observe incremental AgentRunUpdateEvent chunks as each agent produces tokens.
|
||||
|
||||
Purpose:
|
||||
Show how to wire chat agents directly into a WorkflowBuilder pipeline where agents are auto wrapped as executors.
|
||||
Show how to wire chat agents into a WorkflowBuilder pipeline using add_agent
|
||||
with settings for streaming and workflow outputs.
|
||||
|
||||
Demonstrate:
|
||||
- Automatic streaming of agent deltas via AgentRunUpdateEvent.
|
||||
- A simple console aggregator that groups updates by executor id and prints them as they arrive.
|
||||
- The workflow completes when idle and outputs are available in events.get_outputs().
|
||||
- Automatic streaming of agent deltas via AgentRunUpdateEvent when using run_stream().
|
||||
- Add an agent via WorkflowBuilder.add_agent() with output_response=True to emit final AgentRunResponse.
|
||||
- Agents adapt to workflow mode: run_stream() emits incremental updates, run() emits complete responses.
|
||||
|
||||
Prerequisites:
|
||||
- Azure OpenAI configured for AzureOpenAIChatClient with required environment variables.
|
||||
@@ -32,7 +33,7 @@ async def main():
|
||||
# Create the Azure chat client. AzureCliCredential uses your current az login.
|
||||
chat_client = AzureOpenAIChatClient(credential=AzureCliCredential())
|
||||
|
||||
# Define two domain specific chat agents. The builder will wrap these as executors.
|
||||
# Define two domain specific chat agents.
|
||||
writer_agent = chat_client.create_agent(
|
||||
instructions=(
|
||||
"You are an excellent content writer. You create new content and edit contents based on the feedback."
|
||||
@@ -50,11 +51,21 @@ async def main():
|
||||
)
|
||||
|
||||
# Build the workflow using the fluent builder.
|
||||
# Add agents to workflow with custom settings using add_agent.
|
||||
# Agents adapt to workflow mode: run_stream() for incremental updates, run() for complete responses.
|
||||
# Reviewer agent emits final AgentRunResponse as a workflow output.
|
||||
# Set the start node and connect an edge from writer to reviewer.
|
||||
workflow = WorkflowBuilder().set_start_executor(writer_agent).add_edge(writer_agent, reviewer_agent).build()
|
||||
workflow = (
|
||||
WorkflowBuilder()
|
||||
.add_agent(writer_agent, id="Writer")
|
||||
.add_agent(reviewer_agent, id="Reviewer", output_response=True)
|
||||
.set_start_executor(writer_agent)
|
||||
.add_edge(writer_agent, reviewer_agent)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Stream events from the workflow. We aggregate partial token updates per executor for readable output.
|
||||
last_executor_id = None
|
||||
last_executor_id: str | None = None
|
||||
|
||||
events = workflow.run_stream("Create a slogan for a new electric SUV that is affordable and fun to drive.")
|
||||
async for event in events:
|
||||
@@ -62,14 +73,14 @@ async def main():
|
||||
# AgentRunUpdateEvent contains incremental text deltas from the underlying agent.
|
||||
# Print a prefix when the executor changes, then append updates on the same line.
|
||||
eid = event.executor_id
|
||||
if eid != last_executor_id: # type: ignore[reportUnnecessaryComparison]
|
||||
if eid != last_executor_id:
|
||||
if last_executor_id is not None:
|
||||
print()
|
||||
print(f"{eid}:", end=" ", flush=True)
|
||||
last_executor_id = eid
|
||||
print(event.data, end="", flush=True)
|
||||
elif isinstance(event, WorkflowOutputEvent):
|
||||
print("===== Final Output =====")
|
||||
print("\n===== Final output =====")
|
||||
print(event.data)
|
||||
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user