mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
[BREAKING] Python: Add context mode to AgentExecutor (#4668)
* Add context mode to AgentExecutor * Fix unit tests * Address comments * Address comments * REvise context mode and add tests * Add chain config to sequential builder * Add sample * Fix pipeline * Address comments * Address comments
This commit is contained in:
committed by
GitHub
Unverified
parent
88ea9d08c7
commit
51828abed4
@@ -232,6 +232,7 @@ class TestSerializationRoundtrip:
|
||||
original = AgentExecutorResponse(
|
||||
executor_id="test_exec",
|
||||
agent_response=AgentResponse(messages=[Message(role="assistant", text="Reply")]),
|
||||
full_conversation=[Message(role="assistant", text="Reply")],
|
||||
)
|
||||
encoded = serialize_value(original)
|
||||
decoded = deserialize_value(encoded)
|
||||
|
||||
@@ -212,6 +212,7 @@ class TestExtractMessageContent:
|
||||
response = AgentExecutorResponse(
|
||||
executor_id="exec",
|
||||
agent_response=AgentResponse(messages=[Message(role="assistant", text="Response text")]),
|
||||
full_conversation=[Message(role="assistant", text="Response text")],
|
||||
)
|
||||
|
||||
result = _extract_message_content(response)
|
||||
@@ -228,6 +229,10 @@ class TestExtractMessageContent:
|
||||
Message(role="assistant", text="Last message"),
|
||||
]
|
||||
),
|
||||
full_conversation=[
|
||||
Message(role="user", text="First"),
|
||||
Message(role="assistant", text="Last message"),
|
||||
],
|
||||
)
|
||||
|
||||
result = _extract_message_content(response)
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable, Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from typing_extensions import Never
|
||||
|
||||
@@ -57,7 +57,7 @@ class AgentExecutorResponse:
|
||||
|
||||
executor_id: str
|
||||
agent_response: AgentResponse
|
||||
full_conversation: list[Message] | None = None
|
||||
full_conversation: list[Message]
|
||||
|
||||
|
||||
class AgentExecutor(Executor):
|
||||
@@ -83,6 +83,8 @@ class AgentExecutor(Executor):
|
||||
*,
|
||||
session: AgentSession | None = None,
|
||||
id: str | None = None,
|
||||
context_mode: Literal["full", "last_agent", "custom"] | None = None,
|
||||
context_filter: Callable[[list[Message]], list[Message]] | None = None,
|
||||
):
|
||||
"""Initialize the executor with a unique identifier.
|
||||
|
||||
@@ -90,6 +92,16 @@ class AgentExecutor(Executor):
|
||||
agent: The agent to be wrapped by this executor.
|
||||
session: The session to use for running the agent. If None, a new session will be created.
|
||||
id: A unique identifier for the executor. If None, the agent's name will be used if available.
|
||||
context_mode: Configuration for how the executor should manage conversation context upon
|
||||
receiving an AgentExecutorResponse as input. Options:
|
||||
- "full": append the full conversation (all prior messages + latest agent response) to the
|
||||
cache for the agent run. This is the default mode.
|
||||
- "last_agent": provide only the messages from the latest agent response as context for
|
||||
the agent run.
|
||||
- "custom": use the provided context_filter function to determine which messages to include
|
||||
as context for the agent run.
|
||||
context_filter: An optional function for filtering conversation context when context_mode is set
|
||||
to "custom".
|
||||
"""
|
||||
# Prefer provided id; else use agent.name if present; else generate deterministic prefix
|
||||
exec_id = id or resolve_agent_id(agent)
|
||||
@@ -107,6 +119,14 @@ class AgentExecutor(Executor):
|
||||
# This tracks the full conversation after each run
|
||||
self._full_conversation: list[Message] = []
|
||||
|
||||
# Context mode validation
|
||||
self._context_mode = context_mode or "full"
|
||||
self._context_filter = context_filter
|
||||
if self._context_mode not in {"full", "last_agent", "custom"}:
|
||||
raise ValueError("context_mode must be one of 'full', 'last_agent', or 'custom'.")
|
||||
if self._context_mode == "custom" and not self._context_filter:
|
||||
raise ValueError("context_filter must be provided when context_mode is set to 'custom'.")
|
||||
|
||||
@property
|
||||
def agent(self) -> SupportsAgentRun:
|
||||
"""Get the underlying agent wrapped by this executor."""
|
||||
@@ -129,6 +149,7 @@ class AgentExecutor(Executor):
|
||||
run the agent and emit an AgentExecutorResponse downstream.
|
||||
"""
|
||||
self._cache.extend(request.messages)
|
||||
|
||||
if request.should_respond:
|
||||
await self._run_agent_and_emit(ctx)
|
||||
|
||||
@@ -143,19 +164,27 @@ class AgentExecutor(Executor):
|
||||
Strategy: treat the prior response's messages as the conversation state and
|
||||
immediately run the agent to produce a new response.
|
||||
"""
|
||||
# Replace cache with full conversation if available, else fall back to agent_response messages.
|
||||
source_messages = (
|
||||
prior.full_conversation if prior.full_conversation is not None else prior.agent_response.messages
|
||||
)
|
||||
self._cache = list(source_messages)
|
||||
if self._context_mode == "full":
|
||||
self._cache.extend(prior.full_conversation)
|
||||
elif self._context_mode == "last_agent":
|
||||
self._cache.extend(prior.agent_response.messages)
|
||||
else:
|
||||
if not self._context_filter:
|
||||
# This should never happen due to validation in __init__, but mypy doesn't track that well
|
||||
raise ValueError("context_filter function must be provided for 'custom' context_mode.")
|
||||
self._cache.extend(self._context_filter(prior.full_conversation))
|
||||
|
||||
await self._run_agent_and_emit(ctx)
|
||||
|
||||
@handler
|
||||
async def from_str(
|
||||
self, text: str, ctx: WorkflowContext[AgentExecutorResponse, AgentResponse | AgentResponseUpdate]
|
||||
) -> None:
|
||||
"""Accept a raw user prompt string and run the agent (one-shot)."""
|
||||
self._cache = normalize_messages_input(text)
|
||||
"""Accept a raw user prompt string and run the agent.
|
||||
|
||||
The new string input will be added to the cache which is used as the conversation context for the agent run.
|
||||
"""
|
||||
self._cache.extend(normalize_messages_input(text))
|
||||
await self._run_agent_and_emit(ctx)
|
||||
|
||||
@handler
|
||||
@@ -164,8 +193,11 @@ class AgentExecutor(Executor):
|
||||
message: Message,
|
||||
ctx: WorkflowContext[AgentExecutorResponse, AgentResponse | AgentResponseUpdate],
|
||||
) -> None:
|
||||
"""Accept a single Message as input."""
|
||||
self._cache = normalize_messages_input(message)
|
||||
"""Accept a single Message as input.
|
||||
|
||||
The new message will be added to the cache which is used as the conversation context for the agent run.
|
||||
"""
|
||||
self._cache.extend(normalize_messages_input(message))
|
||||
await self._run_agent_and_emit(ctx)
|
||||
|
||||
@handler
|
||||
@@ -174,8 +206,11 @@ class AgentExecutor(Executor):
|
||||
messages: list[str | Message],
|
||||
ctx: WorkflowContext[AgentExecutorResponse, AgentResponse | AgentResponseUpdate],
|
||||
) -> None:
|
||||
"""Accept a list of chat inputs (strings or Message) as conversation context."""
|
||||
self._cache = normalize_messages_input(messages)
|
||||
"""Accept a list of chat inputs (strings or Message) as conversation context.
|
||||
|
||||
The new messages will be added to the cache which is used as the conversation context for the agent run.
|
||||
"""
|
||||
self._cache.extend(normalize_messages_input(messages))
|
||||
await self._run_agent_and_emit(ctx)
|
||||
|
||||
@response_handler
|
||||
@@ -249,24 +284,10 @@ class AgentExecutor(Executor):
|
||||
state: Checkpoint data dict
|
||||
"""
|
||||
cache_payload = state.get("cache")
|
||||
if cache_payload:
|
||||
try:
|
||||
self._cache = cache_payload
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to restore cache: %s", exc)
|
||||
self._cache = []
|
||||
else:
|
||||
self._cache = []
|
||||
self._cache = cache_payload or []
|
||||
|
||||
full_conversation_payload = state.get("full_conversation")
|
||||
if full_conversation_payload:
|
||||
try:
|
||||
self._full_conversation = full_conversation_payload
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to restore full conversation: %s", exc)
|
||||
self._full_conversation = []
|
||||
else:
|
||||
self._full_conversation = []
|
||||
self._full_conversation = full_conversation_payload or []
|
||||
|
||||
session_payload = state.get("agent_session")
|
||||
if session_payload:
|
||||
@@ -279,12 +300,10 @@ class AgentExecutor(Executor):
|
||||
self._session = self._agent.create_session()
|
||||
|
||||
pending_requests_payload = state.get("pending_agent_requests")
|
||||
if pending_requests_payload:
|
||||
self._pending_agent_requests = pending_requests_payload
|
||||
self._pending_agent_requests = pending_requests_payload or {}
|
||||
|
||||
pending_responses_payload = state.get("pending_responses_to_agent")
|
||||
if pending_responses_payload:
|
||||
self._pending_responses_to_agent = pending_responses_payload
|
||||
self._pending_responses_to_agent = pending_responses_payload or []
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the internal cache of the executor."""
|
||||
|
||||
@@ -16,12 +16,12 @@ from agent_framework import (
|
||||
Content,
|
||||
Message,
|
||||
ResponseStream,
|
||||
WorkflowBuilder,
|
||||
WorkflowEvent,
|
||||
WorkflowRunState,
|
||||
)
|
||||
from agent_framework._workflows._agent_executor import AgentExecutorResponse
|
||||
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage
|
||||
from agent_framework.orchestrations import SequentialBuilder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _pytest.logging import LogCaptureFixture
|
||||
@@ -139,7 +139,7 @@ async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks()
|
||||
"""AgentExecutor should call get_final_response() so stream result hooks execute."""
|
||||
agent = _StreamingHookAgent(id="hook_agent", name="HookAgent")
|
||||
executor = AgentExecutor(agent, id="hook_exec")
|
||||
workflow = SequentialBuilder(participants=[executor]).build()
|
||||
workflow = WorkflowBuilder(start_executor=executor).build()
|
||||
|
||||
output_events: list[Any] = []
|
||||
async for event in workflow.run("run hook test", stream=True):
|
||||
@@ -154,8 +154,9 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
|
||||
"""Test that workflow checkpoint stores AgentExecutor's cache and session states and restores them correctly."""
|
||||
storage = InMemoryCheckpointStorage()
|
||||
|
||||
# Create initial agent with a custom session
|
||||
initial_agent = _CountingAgent(id="test_agent", name="TestAgent")
|
||||
# Create two agents to form a two-step workflow
|
||||
initial_agent_a = _CountingAgent(id="agent_a", name="AgentA")
|
||||
initial_agent_b = _CountingAgent(id="agent_b", name="AgentB")
|
||||
initial_session = AgentSession()
|
||||
|
||||
# Add some initial messages to the session state to verify session state persistence
|
||||
@@ -165,11 +166,12 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
|
||||
]
|
||||
initial_session.state["history"] = {"messages": initial_messages}
|
||||
|
||||
# Create AgentExecutor with the session
|
||||
executor = AgentExecutor(initial_agent, session=initial_session)
|
||||
# Create AgentExecutors — first executor gets the custom session
|
||||
exec_a = AgentExecutor(initial_agent_a, id="exec_a", session=initial_session)
|
||||
exec_b = AgentExecutor(initial_agent_b, id="exec_b")
|
||||
|
||||
# Build workflow with checkpointing enabled
|
||||
wf = SequentialBuilder(participants=[executor], checkpoint_storage=storage).build()
|
||||
# Build two-executor workflow with checkpointing enabled
|
||||
wf = WorkflowBuilder(start_executor=exec_a, checkpoint_storage=storage).add_edge(exec_a, exec_b).build()
|
||||
|
||||
# Run the workflow with a user message
|
||||
first_run_output: AgentExecutorResponse | None = None
|
||||
@@ -180,27 +182,25 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
|
||||
break
|
||||
|
||||
assert first_run_output is not None
|
||||
assert initial_agent.call_count == 1
|
||||
assert initial_agent_a.call_count == 1
|
||||
|
||||
# Verify checkpoint was created
|
||||
checkpoints = await storage.list_checkpoints(workflow_name=wf.name)
|
||||
assert len(checkpoints) >= 2, (
|
||||
"Expected at least 2 checkpoints. The first one is after the start executor, "
|
||||
"and the second one is after the agent execution."
|
||||
assert len(checkpoints) >= 2, "Expected at least 2 checkpoints: one after exec_a and one after exec_b."
|
||||
|
||||
# Get the first checkpoint that contains exec_a's state (taken after exec_a completes,
|
||||
# before exec_b runs)
|
||||
checkpoints.sort(key=lambda cp: cp.timestamp)
|
||||
restore_checkpoint = next(
|
||||
cp for cp in checkpoints if "_executor_state" in cp.state and "exec_a" in cp.state["_executor_state"]
|
||||
)
|
||||
|
||||
# Get the second checkpoint which should contain the state after processing
|
||||
# the first message by the start executor in the sequential workflow
|
||||
checkpoints.sort(key=lambda cp: cp.timestamp)
|
||||
restore_checkpoint = checkpoints[1]
|
||||
|
||||
# Verify checkpoint contains executor state with both cache and session
|
||||
assert "_executor_state" in restore_checkpoint.state
|
||||
executor_states = restore_checkpoint.state["_executor_state"]
|
||||
assert isinstance(executor_states, dict)
|
||||
assert executor.id in executor_states
|
||||
assert exec_a.id in executor_states
|
||||
|
||||
executor_state = executor_states[executor.id] # type: ignore[index]
|
||||
executor_state = executor_states[exec_a.id] # type: ignore[index]
|
||||
assert "cache" in executor_state, "Checkpoint should store executor cache state"
|
||||
assert "agent_session" in executor_state, "Checkpoint should store executor session state"
|
||||
|
||||
@@ -213,19 +213,26 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
|
||||
assert "pending_agent_requests" in executor_state
|
||||
assert "pending_responses_to_agent" in executor_state
|
||||
|
||||
# Create a new agent and executor for restoration
|
||||
# Create new agents and executors for restoration
|
||||
# This simulates starting from a fresh state and restoring from checkpoint
|
||||
restored_agent = _CountingAgent(id="test_agent", name="TestAgent")
|
||||
restored_agent_a = _CountingAgent(id="agent_a", name="AgentA")
|
||||
restored_agent_b = _CountingAgent(id="agent_b", name="AgentB")
|
||||
restored_session = AgentSession()
|
||||
restored_executor = AgentExecutor(restored_agent, session=restored_session)
|
||||
restored_exec_a = AgentExecutor(restored_agent_a, id="exec_a", session=restored_session)
|
||||
restored_exec_b = AgentExecutor(restored_agent_b, id="exec_b")
|
||||
|
||||
# Verify the restored agent starts with a fresh state
|
||||
assert restored_agent.call_count == 0
|
||||
# Verify the restored agents start with a fresh state
|
||||
assert restored_agent_a.call_count == 0
|
||||
assert restored_agent_b.call_count == 0
|
||||
|
||||
# Build new workflow with the restored executor
|
||||
wf_resume = SequentialBuilder(participants=[restored_executor], checkpoint_storage=storage).build()
|
||||
# Build new workflow with the restored executors
|
||||
wf_resume = (
|
||||
WorkflowBuilder(start_executor=restored_exec_a, checkpoint_storage=storage)
|
||||
.add_edge(restored_exec_a, restored_exec_b)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Resume from checkpoint
|
||||
# Resume from checkpoint — exec_a already ran, so exec_b should run and produce output
|
||||
resumed_output: AgentExecutorResponse | None = None
|
||||
async for ev in wf_resume.run(checkpoint_id=restore_checkpoint.checkpoint_id, stream=True):
|
||||
if ev.type == "output":
|
||||
@@ -239,7 +246,7 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
|
||||
assert resumed_output is not None
|
||||
|
||||
# Verify the restored executor's session state was restored
|
||||
restored_session_obj = restored_executor._session # type: ignore[reportPrivateUsage]
|
||||
restored_session_obj = restored_exec_a._session # type: ignore[reportPrivateUsage]
|
||||
assert restored_session_obj is not None
|
||||
assert restored_session_obj.session_id == initial_session.session_id
|
||||
|
||||
@@ -306,7 +313,7 @@ async def test_agent_executor_run_with_session_kwarg_does_not_raise() -> None:
|
||||
"""Passing session= via workflow.run() should not cause a duplicate-keyword TypeError (#4295)."""
|
||||
agent = _CountingAgent(id="session_kwarg_agent", name="SessionKwargAgent")
|
||||
executor = AgentExecutor(agent, id="session_kwarg_exec")
|
||||
workflow = SequentialBuilder(participants=[executor]).build()
|
||||
workflow = WorkflowBuilder(start_executor=executor).build()
|
||||
|
||||
# This previously raised: TypeError: run() got multiple values for keyword argument 'session'
|
||||
result = await workflow.run("hello", session="user-supplied-value")
|
||||
@@ -318,7 +325,7 @@ async def test_agent_executor_run_streaming_with_stream_kwarg_does_not_raise() -
|
||||
"""Passing stream= via workflow.run() kwargs should not cause a duplicate-keyword TypeError."""
|
||||
agent = _CountingAgent(id="stream_kwarg_agent", name="StreamKwargAgent")
|
||||
executor = AgentExecutor(agent, id="stream_kwarg_exec")
|
||||
workflow = SequentialBuilder(participants=[executor]).build()
|
||||
workflow = WorkflowBuilder(start_executor=executor).build()
|
||||
|
||||
# stream=True at workflow level triggers streaming mode (returns async iterable)
|
||||
events: list[WorkflowEvent] = []
|
||||
@@ -378,7 +385,7 @@ async def test_agent_executor_run_with_messages_kwarg_does_not_raise() -> None:
|
||||
"""Passing messages= via workflow.run() kwargs should not cause a duplicate-keyword TypeError."""
|
||||
agent = _CountingAgent(id="messages_kwarg_agent", name="MessagesKwargAgent")
|
||||
executor = AgentExecutor(agent, id="messages_kwarg_exec")
|
||||
workflow = SequentialBuilder(participants=[executor]).build()
|
||||
workflow = WorkflowBuilder(start_executor=executor).build()
|
||||
|
||||
result = await workflow.run("hello", messages=["stale"])
|
||||
assert result is not None
|
||||
@@ -426,7 +433,7 @@ async def test_agent_executor_workflow_with_non_copyable_raw_representation() ->
|
||||
exec_a = AgentExecutor(agent_a, id="exec_a")
|
||||
exec_b = AgentExecutor(agent_b, id="exec_b")
|
||||
|
||||
workflow = SequentialBuilder(participants=[exec_a, exec_b]).build()
|
||||
workflow = WorkflowBuilder(start_executor=exec_a).add_edge(exec_a, exec_b).build()
|
||||
events = await workflow.run("hello")
|
||||
|
||||
completed = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_completed"]
|
||||
@@ -440,3 +447,194 @@ async def test_agent_executor_workflow_with_non_copyable_raw_representation() ->
|
||||
assert len(agent_responses) > 0
|
||||
assert agent_responses[0].text == "reply from AgentA"
|
||||
assert agent_responses[0].raw_representation is raw
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Context mode tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _MessageCapturingAgent(BaseAgent):
|
||||
"""Agent that records the messages it received and returns a configurable reply."""
|
||||
|
||||
def __init__(self, *, reply_text: str = "reply", **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
self.reply_text = reply_text
|
||||
self.last_messages: list[Message] = []
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
captured: list[Message] = []
|
||||
if messages:
|
||||
for m in messages: # type: ignore[union-attr]
|
||||
if isinstance(m, Message):
|
||||
captured.append(m)
|
||||
elif isinstance(m, str):
|
||||
captured.append(Message("user", [m]))
|
||||
self.last_messages = captured
|
||||
|
||||
if stream:
|
||||
|
||||
async def _stream() -> AsyncIterable[AgentResponseUpdate]:
|
||||
yield AgentResponseUpdate(contents=[Content.from_text(text=self.reply_text)])
|
||||
|
||||
return ResponseStream(_stream(), finalizer=AgentResponse.from_updates)
|
||||
|
||||
async def _run() -> AgentResponse:
|
||||
return AgentResponse(messages=[Message("assistant", [self.reply_text])])
|
||||
|
||||
return _run()
|
||||
|
||||
|
||||
def test_context_mode_custom_requires_context_filter() -> None:
|
||||
"""context_mode='custom' without context_filter must raise ValueError."""
|
||||
agent = _CountingAgent(id="a", name="A")
|
||||
with pytest.raises(ValueError, match="context_filter must be provided"):
|
||||
AgentExecutor(agent, context_mode="custom")
|
||||
|
||||
|
||||
def test_context_mode_custom_with_filter_succeeds() -> None:
|
||||
"""context_mode='custom' with a context_filter should not raise."""
|
||||
agent = _CountingAgent(id="a", name="A")
|
||||
executor = AgentExecutor(agent, context_mode="custom", context_filter=lambda msgs: msgs[-1:])
|
||||
assert executor._context_mode == "custom" # pyright: ignore[reportPrivateUsage]
|
||||
assert executor._context_filter is not None # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
def test_context_mode_defaults_to_full() -> None:
|
||||
"""Default context_mode should be 'full'."""
|
||||
agent = _CountingAgent(id="a", name="A")
|
||||
executor = AgentExecutor(agent)
|
||||
assert executor._context_mode == "full" # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
def test_context_mode_invalid_value_raises() -> None:
|
||||
"""Invalid context_mode value should raise ValueError."""
|
||||
agent = _CountingAgent(id="a", name="A")
|
||||
with pytest.raises(ValueError, match="context_mode must be one of"):
|
||||
AgentExecutor(agent, context_mode="invalid_mode") # type: ignore
|
||||
|
||||
|
||||
async def test_from_response_context_mode_full_passes_full_conversation() -> None:
|
||||
"""context_mode='full' (default) should pass full_conversation to the second agent."""
|
||||
first = _MessageCapturingAgent(id="first", name="First", reply_text="first reply")
|
||||
second = _MessageCapturingAgent(id="second", name="Second", reply_text="second reply")
|
||||
|
||||
exec_a = AgentExecutor(first, id="exec_a")
|
||||
exec_b = AgentExecutor(second, id="exec_b", context_mode="full")
|
||||
|
||||
wf = WorkflowBuilder(start_executor=exec_a).add_edge(exec_a, exec_b).build()
|
||||
|
||||
async for ev in wf.run("hello", stream=True):
|
||||
if ev.type == "status" and ev.state == WorkflowRunState.IDLE:
|
||||
break
|
||||
|
||||
# Second agent should see full conversation: [user("hello"), assistant("first reply")]
|
||||
seen = second.last_messages
|
||||
assert len(seen) == 2
|
||||
assert seen[0].role == "user" and "hello" in (seen[0].text or "")
|
||||
assert seen[1].role == "assistant" and "first reply" in (seen[1].text or "")
|
||||
|
||||
|
||||
async def test_from_response_context_mode_last_agent_passes_only_agent_messages() -> None:
|
||||
"""context_mode='last_agent' should pass only the previous agent's response messages."""
|
||||
first = _MessageCapturingAgent(id="first", name="First", reply_text="first reply")
|
||||
second = _MessageCapturingAgent(id="second", name="Second", reply_text="second reply")
|
||||
|
||||
exec_a = AgentExecutor(first, id="exec_a")
|
||||
exec_b = AgentExecutor(second, id="exec_b", context_mode="last_agent")
|
||||
|
||||
wf = WorkflowBuilder(start_executor=exec_a).add_edge(exec_a, exec_b).build()
|
||||
|
||||
async for ev in wf.run("hello", stream=True):
|
||||
if ev.type == "status" and ev.state == WorkflowRunState.IDLE:
|
||||
break
|
||||
|
||||
# Second agent should see only the assistant message from first: [assistant("first reply")]
|
||||
seen = second.last_messages
|
||||
assert len(seen) == 1
|
||||
assert seen[0].role == "assistant" and "first reply" in (seen[0].text or "")
|
||||
|
||||
|
||||
async def test_from_response_context_mode_custom_uses_filter() -> None:
|
||||
"""context_mode='custom' should invoke context_filter on full_conversation."""
|
||||
first = _MessageCapturingAgent(id="first", name="First", reply_text="first reply")
|
||||
second = _MessageCapturingAgent(id="second", name="Second", reply_text="second reply")
|
||||
|
||||
# Custom filter: keep only user messages
|
||||
def only_user_messages(msgs: list[Message]) -> list[Message]:
|
||||
return [m for m in msgs if m.role == "user"]
|
||||
|
||||
exec_a = AgentExecutor(first, id="exec_a")
|
||||
exec_b = AgentExecutor(second, id="exec_b", context_mode="custom", context_filter=only_user_messages)
|
||||
|
||||
wf = WorkflowBuilder(start_executor=exec_a).add_edge(exec_a, exec_b).build()
|
||||
|
||||
async for ev in wf.run("hello", stream=True):
|
||||
if ev.type == "status" and ev.state == WorkflowRunState.IDLE:
|
||||
break
|
||||
|
||||
# Second agent should see only user messages: [user("hello")]
|
||||
seen = second.last_messages
|
||||
assert len(seen) == 1
|
||||
assert seen[0].role == "user" and "hello" in (seen[0].text or "")
|
||||
|
||||
|
||||
async def test_checkpoint_save_does_not_include_context_mode() -> None:
|
||||
"""on_checkpoint_save should not include context_mode in the saved state."""
|
||||
agent = _CountingAgent(id="a", name="A")
|
||||
executor = AgentExecutor(agent, context_mode="last_agent")
|
||||
|
||||
state = await executor.on_checkpoint_save()
|
||||
|
||||
assert "context_mode" not in state
|
||||
assert "cache" in state
|
||||
assert "agent_session" in state
|
||||
|
||||
|
||||
async def test_checkpoint_restore_works_without_context_mode_in_state() -> None:
|
||||
"""on_checkpoint_restore should succeed when state does not contain context_mode."""
|
||||
agent = _CountingAgent(id="a", name="A")
|
||||
executor = AgentExecutor(agent, context_mode="last_agent")
|
||||
|
||||
# Simulate a checkpoint state without context_mode (as saved by the new code)
|
||||
state: dict[str, Any] = {
|
||||
"cache": [Message(role="user", text="cached msg")],
|
||||
"full_conversation": [],
|
||||
"agent_session": AgentSession().to_dict(),
|
||||
"pending_agent_requests": {},
|
||||
"pending_responses_to_agent": [],
|
||||
}
|
||||
|
||||
await executor.on_checkpoint_restore(state)
|
||||
|
||||
cache = executor._cache # pyright: ignore[reportPrivateUsage]
|
||||
assert len(cache) == 1
|
||||
assert cache[0].text == "cached msg"
|
||||
# context_mode should remain as configured in the constructor, not changed by restore
|
||||
assert executor._context_mode == "last_agent" # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
@@ -341,7 +341,7 @@ async def test_runner_emits_runner_completion_for_agent_response_without_targets
|
||||
|
||||
await ctx.send_message(
|
||||
WorkflowMessage(
|
||||
data=AgentExecutorResponse("agent", AgentResponse()),
|
||||
data=AgentExecutorResponse("agent", AgentResponse(), []),
|
||||
source_id="agent",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -100,24 +100,21 @@ class _AggregateAgentConversations(Executor):
|
||||
assistant_replies: list[Message] = []
|
||||
|
||||
for r in results:
|
||||
resp_messages = list(getattr(r.agent_response, "messages", []) or [])
|
||||
conv = r.full_conversation if r.full_conversation is not None else resp_messages
|
||||
resp_messages = list(r.agent_response.messages)
|
||||
|
||||
logger.debug(
|
||||
f"Aggregating executor {getattr(r, 'executor_id', '<unknown>')}: "
|
||||
f"{len(resp_messages)} response msgs, {len(conv)} conversation msgs"
|
||||
f"{len(resp_messages)} response msgs, {len(r.full_conversation)} conversation msgs"
|
||||
)
|
||||
|
||||
# Capture a single user prompt (first encountered across any conversation)
|
||||
if prompt_message is None:
|
||||
found_user = next((m for m in conv if _is_role(m, "user")), None)
|
||||
if found_user is not None:
|
||||
prompt_message = found_user
|
||||
prompt_message = next((m for m in r.full_conversation if _is_role(m, "user")), None)
|
||||
|
||||
# Pick the final assistant message from the response; fallback to conversation search
|
||||
final_assistant = next((m for m in reversed(resp_messages) if _is_role(m, "assistant")), None)
|
||||
if final_assistant is None:
|
||||
final_assistant = next((m for m in reversed(conv) if _is_role(m, "assistant")), None)
|
||||
final_assistant = next((m for m in reversed(r.full_conversation) if _is_role(m, "assistant")), None)
|
||||
|
||||
if final_assistant is not None:
|
||||
assistant_replies.append(final_assistant)
|
||||
|
||||
+11
-3
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from agent_framework._agents import SupportsAgentRun
|
||||
from agent_framework._types import Message
|
||||
@@ -117,18 +118,25 @@ class AgentApprovalExecutor(WorkflowExecutor):
|
||||
agent's output or send the final response to down stream executors in the orchestration.
|
||||
"""
|
||||
|
||||
def __init__(self, agent: SupportsAgentRun) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
agent: SupportsAgentRun,
|
||||
context_mode: Literal["full", "last_agent", "custom"] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the AgentApprovalExecutor.
|
||||
|
||||
Args:
|
||||
agent: The agent protocol to use for generating responses.
|
||||
context_mode: The mode for providing context to the agent.
|
||||
"""
|
||||
super().__init__(workflow=self._build_workflow(agent), id=resolve_agent_id(agent), propagate_request=True)
|
||||
self._context_mode: Literal["full", "last_agent", "custom"] | None = context_mode
|
||||
self._description = agent.description
|
||||
|
||||
super().__init__(workflow=self._build_workflow(agent), id=resolve_agent_id(agent), propagate_request=True)
|
||||
|
||||
def _build_workflow(self, agent: SupportsAgentRun) -> Workflow:
|
||||
"""Build the internal workflow for the AgentApprovalExecutor."""
|
||||
agent_executor = AgentExecutor(agent)
|
||||
agent_executor = AgentExecutor(agent, context_mode=self._context_mode)
|
||||
request_info_executor = AgentRequestInfoExecutor(id="agent_request_info_executor")
|
||||
|
||||
return (
|
||||
|
||||
@@ -38,7 +38,7 @@ confusion and to mirror how the concurrent builder uses explicit dispatcher/aggr
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
from agent_framework import Message, SupportsAgentRun
|
||||
from agent_framework._workflows._agent_executor import (
|
||||
@@ -143,6 +143,7 @@ class SequentialBuilder:
|
||||
*,
|
||||
participants: Sequence[SupportsAgentRun | Executor],
|
||||
checkpoint_storage: CheckpointStorage | None = None,
|
||||
chain_only_agent_responses: bool = False,
|
||||
intermediate_outputs: bool = False,
|
||||
) -> None:
|
||||
"""Initialize the SequentialBuilder.
|
||||
@@ -150,10 +151,14 @@ class SequentialBuilder:
|
||||
Args:
|
||||
participants: Sequence of agent or executor instances to run sequentially.
|
||||
checkpoint_storage: Optional checkpoint storage for enabling workflow state persistence.
|
||||
chain_only_agent_responses: If True, only agent responses are chained between agents.
|
||||
By default, the full conversation context is passed to the next agent. This also applies
|
||||
to Executor -> Agent transitions if the executor sends `AgentExecutorResponse`.
|
||||
intermediate_outputs: If True, enables intermediate outputs from agent participants.
|
||||
"""
|
||||
self._participants: list[SupportsAgentRun | Executor] = []
|
||||
self._checkpoint_storage: CheckpointStorage | None = checkpoint_storage
|
||||
self._chain_only_agent_responses: bool = chain_only_agent_responses
|
||||
self._request_info_enabled: bool = False
|
||||
self._request_info_filter: set[str] | None = None
|
||||
self._intermediate_outputs: bool = intermediate_outputs
|
||||
@@ -225,6 +230,10 @@ class SequentialBuilder:
|
||||
|
||||
participants: list[Executor | SupportsAgentRun] = self._participants
|
||||
|
||||
context_mode: Literal["full", "last_agent", "custom"] | None = (
|
||||
"last_agent" if self._chain_only_agent_responses else None
|
||||
)
|
||||
|
||||
executors: list[Executor] = []
|
||||
for p in participants:
|
||||
if isinstance(p, Executor):
|
||||
@@ -234,9 +243,9 @@ class SequentialBuilder:
|
||||
not self._request_info_filter or resolve_agent_id(p) in self._request_info_filter
|
||||
):
|
||||
# Handle request info enabled agents
|
||||
executors.append(AgentApprovalExecutor(p))
|
||||
executors.append(AgentApprovalExecutor(p, context_mode=context_mode))
|
||||
else:
|
||||
executors.append(AgentExecutor(p))
|
||||
executors.append(AgentExecutor(p, context_mode=context_mode))
|
||||
else:
|
||||
raise TypeError(f"Participants must be SupportsAgentRun or Executor instances. Got {type(p).__name__}.")
|
||||
|
||||
|
||||
@@ -117,6 +117,7 @@ class TestAgentRequestInfoExecutor:
|
||||
agent_response = AgentExecutorResponse(
|
||||
executor_id="test_agent",
|
||||
agent_response=agent_response,
|
||||
full_conversation=agent_response.messages,
|
||||
)
|
||||
|
||||
ctx = MagicMock(spec=WorkflowContext)
|
||||
@@ -135,6 +136,7 @@ class TestAgentRequestInfoExecutor:
|
||||
original_request = AgentExecutorResponse(
|
||||
executor_id="test_agent",
|
||||
agent_response=agent_response,
|
||||
full_conversation=agent_response.messages,
|
||||
)
|
||||
|
||||
response = AgentRequestInfoResponse.from_strings(["Additional input"])
|
||||
@@ -161,6 +163,7 @@ class TestAgentRequestInfoExecutor:
|
||||
original_request = AgentExecutorResponse(
|
||||
executor_id="test_agent",
|
||||
agent_response=agent_response,
|
||||
full_conversation=agent_response.messages,
|
||||
)
|
||||
|
||||
response = AgentRequestInfoResponse.approve()
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import AsyncIterable, Awaitable
|
||||
from typing import Any
|
||||
from typing import Any, Literal, overload
|
||||
|
||||
import pytest
|
||||
from agent_framework import (
|
||||
AgentExecutorResponse,
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentRunInputs,
|
||||
AgentSession,
|
||||
BaseAgent,
|
||||
Content,
|
||||
Executor,
|
||||
Message,
|
||||
ResponseStream,
|
||||
TypeCompatibilityError,
|
||||
WorkflowContext,
|
||||
WorkflowRunState,
|
||||
@@ -25,26 +27,45 @@ from agent_framework.orchestrations import SequentialBuilder
|
||||
class _EchoAgent(BaseAgent):
|
||||
"""Simple agent that appends a single assistant message with its name."""
|
||||
|
||||
def run( # type: ignore[override]
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | list[str] | list[Message] | None = None,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse] | AsyncIterable[AgentResponseUpdate]:
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
if stream:
|
||||
return self._run_stream()
|
||||
|
||||
async def _stream() -> AsyncIterable[AgentResponseUpdate]:
|
||||
yield AgentResponseUpdate(contents=[Content.from_text(text=f"{self.name} reply")])
|
||||
|
||||
return ResponseStream(_stream(), finalizer=AgentResponse.from_updates)
|
||||
|
||||
async def _run() -> AgentResponse:
|
||||
return AgentResponse(messages=[Message("assistant", [f"{self.name} reply"])])
|
||||
|
||||
return _run()
|
||||
|
||||
async def _run_stream(self) -> AsyncIterable[AgentResponseUpdate]:
|
||||
# Minimal async generator with one assistant update
|
||||
yield AgentResponseUpdate(contents=[Content.from_text(text=f"{self.name} reply")])
|
||||
|
||||
|
||||
class _SummarizerExec(Executor):
|
||||
"""Custom executor that summarizes by appending a short assistant message."""
|
||||
@@ -251,3 +272,121 @@ async def test_sequential_builder_reusable_after_build_with_participants() -> No
|
||||
|
||||
assert builder._participants[0] is a1 # type: ignore
|
||||
assert builder._participants[1] is a2 # type: ignore
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# chain_only_agent_responses tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _CapturingAgent(BaseAgent):
|
||||
"""Agent that records the messages it received and returns a configurable reply."""
|
||||
|
||||
def __init__(self, *, reply_text: str = "reply", **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
self.reply_text = reply_text
|
||||
self.last_messages: list[Message] = []
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
captured: list[Message] = []
|
||||
if messages:
|
||||
for m in messages: # type: ignore[union-attr]
|
||||
if isinstance(m, Message):
|
||||
captured.append(m)
|
||||
elif isinstance(m, str):
|
||||
captured.append(Message("user", [m]))
|
||||
self.last_messages = captured
|
||||
|
||||
if stream:
|
||||
|
||||
async def _stream() -> AsyncIterable[AgentResponseUpdate]:
|
||||
yield AgentResponseUpdate(contents=[Content.from_text(text=self.reply_text)])
|
||||
|
||||
return ResponseStream(_stream(), finalizer=AgentResponse.from_updates)
|
||||
|
||||
async def _run() -> AgentResponse:
|
||||
return AgentResponse(messages=[Message("assistant", [self.reply_text])])
|
||||
|
||||
return _run()
|
||||
|
||||
|
||||
async def test_chain_only_agent_responses_false_passes_full_conversation() -> None:
|
||||
"""Default (chain_only_agent_responses=False) passes full conversation to the second agent."""
|
||||
a1 = _CapturingAgent(id="agent1", name="A1", reply_text="A1 reply")
|
||||
a2 = _CapturingAgent(id="agent2", name="A2", reply_text="A2 reply")
|
||||
|
||||
wf = SequentialBuilder(participants=[a1, a2], chain_only_agent_responses=False).build()
|
||||
|
||||
async for ev in wf.run("hello", stream=True):
|
||||
if ev.type == "status" and ev.state == WorkflowRunState.IDLE:
|
||||
break
|
||||
|
||||
# Second agent should see full conversation: [user("hello"), assistant("A1 reply")]
|
||||
seen = a2.last_messages
|
||||
assert len(seen) == 2
|
||||
assert seen[0].role == "user" and "hello" in (seen[0].text or "")
|
||||
assert seen[1].role == "assistant" and "A1 reply" in (seen[1].text or "")
|
||||
|
||||
|
||||
async def test_chain_only_agent_responses_true_passes_only_agent_messages() -> None:
|
||||
"""chain_only_agent_responses=True passes only the previous agent's response messages."""
|
||||
a1 = _CapturingAgent(id="agent1", name="A1", reply_text="A1 reply")
|
||||
a2 = _CapturingAgent(id="agent2", name="A2", reply_text="A2 reply")
|
||||
|
||||
wf = SequentialBuilder(participants=[a1, a2], chain_only_agent_responses=True).build()
|
||||
|
||||
async for ev in wf.run("hello", stream=True):
|
||||
if ev.type == "status" and ev.state == WorkflowRunState.IDLE:
|
||||
break
|
||||
|
||||
# Second agent should see only the assistant message: [assistant("A1 reply")]
|
||||
seen = a2.last_messages
|
||||
assert len(seen) == 1
|
||||
assert seen[0].role == "assistant" and "A1 reply" in (seen[0].text or "")
|
||||
|
||||
|
||||
async def test_chain_only_agent_responses_three_agents() -> None:
|
||||
"""chain_only_agent_responses=True with three agents: each sees only the prior agent's reply."""
|
||||
a1 = _CapturingAgent(id="agent1", name="A1", reply_text="A1 reply")
|
||||
a2 = _CapturingAgent(id="agent2", name="A2", reply_text="A2 reply")
|
||||
a3 = _CapturingAgent(id="agent3", name="A3", reply_text="A3 reply")
|
||||
|
||||
wf = SequentialBuilder(participants=[a1, a2, a3], chain_only_agent_responses=True).build()
|
||||
|
||||
async for ev in wf.run("hello", stream=True):
|
||||
if ev.type == "status" and ev.state == WorkflowRunState.IDLE:
|
||||
break
|
||||
|
||||
# a2 should see only A1's reply
|
||||
assert len(a2.last_messages) == 1
|
||||
assert a2.last_messages[0].role == "assistant" and "A1 reply" in (a2.last_messages[0].text or "")
|
||||
|
||||
# a3 should see only A2's reply
|
||||
assert len(a3.last_messages) == 1
|
||||
assert a3.last_messages[0].role == "assistant" and "A2 reply" in (a3.last_messages[0].text or "")
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from agent_framework import AgentResponseUpdate
|
||||
from agent_framework.azure import AzureOpenAIResponsesClient
|
||||
from agent_framework.orchestrations import SequentialBuilder
|
||||
from azure.identity import AzureCliCredential
|
||||
from dotenv import load_dotenv
|
||||
|
||||
"""
|
||||
Sample: Sequential workflow with chain_only_agent_responses=True
|
||||
|
||||
Demonstrates SequentialBuilder with `chain_only_agent_responses=True`, which passes
|
||||
only the previous agent's response (not the full conversation history) to the next
|
||||
agent. This is useful when each agent should focus solely on refining or transforming
|
||||
the prior agent's output without being influenced by earlier turns.
|
||||
|
||||
In this sample, a writer agent produces a draft tagline, a translator agent translates
|
||||
it into French (seeing only the writer's output, not the original user prompt), and a
|
||||
reviewer agent evaluates the translation (seeing only the translator's output).
|
||||
|
||||
Compare with `sequential_agents.py`, which uses the default behavior where the full
|
||||
conversation context is passed to each agent.
|
||||
|
||||
Prerequisites:
|
||||
- AZURE_AI_PROJECT_ENDPOINT must be your Azure AI Foundry Agent Service (V2) project endpoint.
|
||||
- Azure OpenAI configured for AzureOpenAIResponsesClient with required environment variables.
|
||||
- Authentication via azure-identity. Use AzureCliCredential and run az login before executing the sample.
|
||||
"""
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# 1) Create agents
|
||||
client = AzureOpenAIResponsesClient(
|
||||
project_endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"],
|
||||
deployment_name=os.environ["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
|
||||
credential=AzureCliCredential(),
|
||||
)
|
||||
|
||||
writer = client.as_agent(
|
||||
instructions="You are a concise copywriter. Provide a single, punchy marketing sentence based on the prompt.",
|
||||
name="writer",
|
||||
)
|
||||
|
||||
translator = client.as_agent(
|
||||
instructions="You are a translator. Translate the given text into French. Output only the translation.",
|
||||
name="translator",
|
||||
)
|
||||
|
||||
reviewer = client.as_agent(
|
||||
instructions="You are a reviewer. Evaluate the quality of the marketing tagline.",
|
||||
name="reviewer",
|
||||
)
|
||||
|
||||
# 2) Build sequential workflow: writer -> translator -> reviewer
|
||||
# chain_only_agent_responses=True means each agent sees only the previous agent's reply,
|
||||
# not the full conversation history.
|
||||
workflow = SequentialBuilder(
|
||||
participants=[writer, translator, reviewer],
|
||||
chain_only_agent_responses=True,
|
||||
intermediate_outputs=True,
|
||||
).build()
|
||||
|
||||
# 3) Run and collect outputs
|
||||
last_agent: str | None = None
|
||||
async for event in workflow.run("Write a tagline for a budget-friendly eBike.", stream=True):
|
||||
if event.type == "output" and isinstance(event.data, AgentResponseUpdate):
|
||||
if event.data.author_name != last_agent:
|
||||
last_agent = event.data.author_name
|
||||
print()
|
||||
print(f"{last_agent}: ", end="", flush=True)
|
||||
print(event.data.text, end="", flush=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user