[BREAKING] Python: Add context mode to AgentExecutor (#4668)

* Add context mode to AgentExecutor

* Fix unit tests

* Address comments

* Address comments

* REvise context mode and add tests

* Add chain config to sequential builder

* Add sample

* Fix pipeline

* Address comments

* Address comments
This commit is contained in:
Tao Chen
2026-03-20 11:27:02 -07:00
committed by GitHub
Unverified
parent 88ea9d08c7
commit 51828abed4
11 changed files with 549 additions and 89 deletions
@@ -232,6 +232,7 @@ class TestSerializationRoundtrip:
original = AgentExecutorResponse(
executor_id="test_exec",
agent_response=AgentResponse(messages=[Message(role="assistant", text="Reply")]),
full_conversation=[Message(role="assistant", text="Reply")],
)
encoded = serialize_value(original)
decoded = deserialize_value(encoded)
@@ -212,6 +212,7 @@ class TestExtractMessageContent:
response = AgentExecutorResponse(
executor_id="exec",
agent_response=AgentResponse(messages=[Message(role="assistant", text="Response text")]),
full_conversation=[Message(role="assistant", text="Response text")],
)
result = _extract_message_content(response)
@@ -228,6 +229,10 @@ class TestExtractMessageContent:
Message(role="assistant", text="Last message"),
]
),
full_conversation=[
Message(role="user", text="First"),
Message(role="assistant", text="Last message"),
],
)
result = _extract_message_content(response)
@@ -4,7 +4,7 @@ import logging
import sys
from collections.abc import Awaitable, Callable, Mapping
from dataclasses import dataclass
from typing import Any, cast
from typing import Any, Literal, cast
from typing_extensions import Never
@@ -57,7 +57,7 @@ class AgentExecutorResponse:
executor_id: str
agent_response: AgentResponse
full_conversation: list[Message] | None = None
full_conversation: list[Message]
class AgentExecutor(Executor):
@@ -83,6 +83,8 @@ class AgentExecutor(Executor):
*,
session: AgentSession | None = None,
id: str | None = None,
context_mode: Literal["full", "last_agent", "custom"] | None = None,
context_filter: Callable[[list[Message]], list[Message]] | None = None,
):
"""Initialize the executor with a unique identifier.
@@ -90,6 +92,16 @@ class AgentExecutor(Executor):
agent: The agent to be wrapped by this executor.
session: The session to use for running the agent. If None, a new session will be created.
id: A unique identifier for the executor. If None, the agent's name will be used if available.
context_mode: Configuration for how the executor should manage conversation context upon
receiving an AgentExecutorResponse as input. Options:
- "full": append the full conversation (all prior messages + latest agent response) to the
cache for the agent run. This is the default mode.
- "last_agent": provide only the messages from the latest agent response as context for
the agent run.
- "custom": use the provided context_filter function to determine which messages to include
as context for the agent run.
context_filter: An optional function for filtering conversation context when context_mode is set
to "custom".
"""
# Prefer provided id; else use agent.name if present; else generate deterministic prefix
exec_id = id or resolve_agent_id(agent)
@@ -107,6 +119,14 @@ class AgentExecutor(Executor):
# This tracks the full conversation after each run
self._full_conversation: list[Message] = []
# Context mode validation
self._context_mode = context_mode or "full"
self._context_filter = context_filter
if self._context_mode not in {"full", "last_agent", "custom"}:
raise ValueError("context_mode must be one of 'full', 'last_agent', or 'custom'.")
if self._context_mode == "custom" and not self._context_filter:
raise ValueError("context_filter must be provided when context_mode is set to 'custom'.")
@property
def agent(self) -> SupportsAgentRun:
"""Get the underlying agent wrapped by this executor."""
@@ -129,6 +149,7 @@ class AgentExecutor(Executor):
run the agent and emit an AgentExecutorResponse downstream.
"""
self._cache.extend(request.messages)
if request.should_respond:
await self._run_agent_and_emit(ctx)
@@ -143,19 +164,27 @@ class AgentExecutor(Executor):
Strategy: treat the prior response's messages as the conversation state and
immediately run the agent to produce a new response.
"""
# Replace cache with full conversation if available, else fall back to agent_response messages.
source_messages = (
prior.full_conversation if prior.full_conversation is not None else prior.agent_response.messages
)
self._cache = list(source_messages)
if self._context_mode == "full":
self._cache.extend(prior.full_conversation)
elif self._context_mode == "last_agent":
self._cache.extend(prior.agent_response.messages)
else:
if not self._context_filter:
# This should never happen due to validation in __init__, but mypy doesn't track that well
raise ValueError("context_filter function must be provided for 'custom' context_mode.")
self._cache.extend(self._context_filter(prior.full_conversation))
await self._run_agent_and_emit(ctx)
@handler
async def from_str(
self, text: str, ctx: WorkflowContext[AgentExecutorResponse, AgentResponse | AgentResponseUpdate]
) -> None:
"""Accept a raw user prompt string and run the agent (one-shot)."""
self._cache = normalize_messages_input(text)
"""Accept a raw user prompt string and run the agent.
The new string input will be added to the cache which is used as the conversation context for the agent run.
"""
self._cache.extend(normalize_messages_input(text))
await self._run_agent_and_emit(ctx)
@handler
@@ -164,8 +193,11 @@ class AgentExecutor(Executor):
message: Message,
ctx: WorkflowContext[AgentExecutorResponse, AgentResponse | AgentResponseUpdate],
) -> None:
"""Accept a single Message as input."""
self._cache = normalize_messages_input(message)
"""Accept a single Message as input.
The new message will be added to the cache which is used as the conversation context for the agent run.
"""
self._cache.extend(normalize_messages_input(message))
await self._run_agent_and_emit(ctx)
@handler
@@ -174,8 +206,11 @@ class AgentExecutor(Executor):
messages: list[str | Message],
ctx: WorkflowContext[AgentExecutorResponse, AgentResponse | AgentResponseUpdate],
) -> None:
"""Accept a list of chat inputs (strings or Message) as conversation context."""
self._cache = normalize_messages_input(messages)
"""Accept a list of chat inputs (strings or Message) as conversation context.
The new messages will be added to the cache which is used as the conversation context for the agent run.
"""
self._cache.extend(normalize_messages_input(messages))
await self._run_agent_and_emit(ctx)
@response_handler
@@ -249,24 +284,10 @@ class AgentExecutor(Executor):
state: Checkpoint data dict
"""
cache_payload = state.get("cache")
if cache_payload:
try:
self._cache = cache_payload
except Exception as exc:
logger.warning("Failed to restore cache: %s", exc)
self._cache = []
else:
self._cache = []
self._cache = cache_payload or []
full_conversation_payload = state.get("full_conversation")
if full_conversation_payload:
try:
self._full_conversation = full_conversation_payload
except Exception as exc:
logger.warning("Failed to restore full conversation: %s", exc)
self._full_conversation = []
else:
self._full_conversation = []
self._full_conversation = full_conversation_payload or []
session_payload = state.get("agent_session")
if session_payload:
@@ -279,12 +300,10 @@ class AgentExecutor(Executor):
self._session = self._agent.create_session()
pending_requests_payload = state.get("pending_agent_requests")
if pending_requests_payload:
self._pending_agent_requests = pending_requests_payload
self._pending_agent_requests = pending_requests_payload or {}
pending_responses_payload = state.get("pending_responses_to_agent")
if pending_responses_payload:
self._pending_responses_to_agent = pending_responses_payload
self._pending_responses_to_agent = pending_responses_payload or []
def reset(self) -> None:
"""Reset the internal cache of the executor."""
@@ -16,12 +16,12 @@ from agent_framework import (
Content,
Message,
ResponseStream,
WorkflowBuilder,
WorkflowEvent,
WorkflowRunState,
)
from agent_framework._workflows._agent_executor import AgentExecutorResponse
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage
from agent_framework.orchestrations import SequentialBuilder
if TYPE_CHECKING:
from _pytest.logging import LogCaptureFixture
@@ -139,7 +139,7 @@ async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks()
"""AgentExecutor should call get_final_response() so stream result hooks execute."""
agent = _StreamingHookAgent(id="hook_agent", name="HookAgent")
executor = AgentExecutor(agent, id="hook_exec")
workflow = SequentialBuilder(participants=[executor]).build()
workflow = WorkflowBuilder(start_executor=executor).build()
output_events: list[Any] = []
async for event in workflow.run("run hook test", stream=True):
@@ -154,8 +154,9 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
"""Test that workflow checkpoint stores AgentExecutor's cache and session states and restores them correctly."""
storage = InMemoryCheckpointStorage()
# Create initial agent with a custom session
initial_agent = _CountingAgent(id="test_agent", name="TestAgent")
# Create two agents to form a two-step workflow
initial_agent_a = _CountingAgent(id="agent_a", name="AgentA")
initial_agent_b = _CountingAgent(id="agent_b", name="AgentB")
initial_session = AgentSession()
# Add some initial messages to the session state to verify session state persistence
@@ -165,11 +166,12 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
]
initial_session.state["history"] = {"messages": initial_messages}
# Create AgentExecutor with the session
executor = AgentExecutor(initial_agent, session=initial_session)
# Create AgentExecutors — first executor gets the custom session
exec_a = AgentExecutor(initial_agent_a, id="exec_a", session=initial_session)
exec_b = AgentExecutor(initial_agent_b, id="exec_b")
# Build workflow with checkpointing enabled
wf = SequentialBuilder(participants=[executor], checkpoint_storage=storage).build()
# Build two-executor workflow with checkpointing enabled
wf = WorkflowBuilder(start_executor=exec_a, checkpoint_storage=storage).add_edge(exec_a, exec_b).build()
# Run the workflow with a user message
first_run_output: AgentExecutorResponse | None = None
@@ -180,27 +182,25 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
break
assert first_run_output is not None
assert initial_agent.call_count == 1
assert initial_agent_a.call_count == 1
# Verify checkpoint was created
checkpoints = await storage.list_checkpoints(workflow_name=wf.name)
assert len(checkpoints) >= 2, (
"Expected at least 2 checkpoints. The first one is after the start executor, "
"and the second one is after the agent execution."
assert len(checkpoints) >= 2, "Expected at least 2 checkpoints: one after exec_a and one after exec_b."
# Get the first checkpoint that contains exec_a's state (taken after exec_a completes,
# before exec_b runs)
checkpoints.sort(key=lambda cp: cp.timestamp)
restore_checkpoint = next(
cp for cp in checkpoints if "_executor_state" in cp.state and "exec_a" in cp.state["_executor_state"]
)
# Get the second checkpoint which should contain the state after processing
# the first message by the start executor in the sequential workflow
checkpoints.sort(key=lambda cp: cp.timestamp)
restore_checkpoint = checkpoints[1]
# Verify checkpoint contains executor state with both cache and session
assert "_executor_state" in restore_checkpoint.state
executor_states = restore_checkpoint.state["_executor_state"]
assert isinstance(executor_states, dict)
assert executor.id in executor_states
assert exec_a.id in executor_states
executor_state = executor_states[executor.id] # type: ignore[index]
executor_state = executor_states[exec_a.id] # type: ignore[index]
assert "cache" in executor_state, "Checkpoint should store executor cache state"
assert "agent_session" in executor_state, "Checkpoint should store executor session state"
@@ -213,19 +213,26 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
assert "pending_agent_requests" in executor_state
assert "pending_responses_to_agent" in executor_state
# Create a new agent and executor for restoration
# Create new agents and executors for restoration
# This simulates starting from a fresh state and restoring from checkpoint
restored_agent = _CountingAgent(id="test_agent", name="TestAgent")
restored_agent_a = _CountingAgent(id="agent_a", name="AgentA")
restored_agent_b = _CountingAgent(id="agent_b", name="AgentB")
restored_session = AgentSession()
restored_executor = AgentExecutor(restored_agent, session=restored_session)
restored_exec_a = AgentExecutor(restored_agent_a, id="exec_a", session=restored_session)
restored_exec_b = AgentExecutor(restored_agent_b, id="exec_b")
# Verify the restored agent starts with a fresh state
assert restored_agent.call_count == 0
# Verify the restored agents start with a fresh state
assert restored_agent_a.call_count == 0
assert restored_agent_b.call_count == 0
# Build new workflow with the restored executor
wf_resume = SequentialBuilder(participants=[restored_executor], checkpoint_storage=storage).build()
# Build new workflow with the restored executors
wf_resume = (
WorkflowBuilder(start_executor=restored_exec_a, checkpoint_storage=storage)
.add_edge(restored_exec_a, restored_exec_b)
.build()
)
# Resume from checkpoint
# Resume from checkpoint — exec_a already ran, so exec_b should run and produce output
resumed_output: AgentExecutorResponse | None = None
async for ev in wf_resume.run(checkpoint_id=restore_checkpoint.checkpoint_id, stream=True):
if ev.type == "output":
@@ -239,7 +246,7 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
assert resumed_output is not None
# Verify the restored executor's session state was restored
restored_session_obj = restored_executor._session # type: ignore[reportPrivateUsage]
restored_session_obj = restored_exec_a._session # type: ignore[reportPrivateUsage]
assert restored_session_obj is not None
assert restored_session_obj.session_id == initial_session.session_id
@@ -306,7 +313,7 @@ async def test_agent_executor_run_with_session_kwarg_does_not_raise() -> None:
"""Passing session= via workflow.run() should not cause a duplicate-keyword TypeError (#4295)."""
agent = _CountingAgent(id="session_kwarg_agent", name="SessionKwargAgent")
executor = AgentExecutor(agent, id="session_kwarg_exec")
workflow = SequentialBuilder(participants=[executor]).build()
workflow = WorkflowBuilder(start_executor=executor).build()
# This previously raised: TypeError: run() got multiple values for keyword argument 'session'
result = await workflow.run("hello", session="user-supplied-value")
@@ -318,7 +325,7 @@ async def test_agent_executor_run_streaming_with_stream_kwarg_does_not_raise() -
"""Passing stream= via workflow.run() kwargs should not cause a duplicate-keyword TypeError."""
agent = _CountingAgent(id="stream_kwarg_agent", name="StreamKwargAgent")
executor = AgentExecutor(agent, id="stream_kwarg_exec")
workflow = SequentialBuilder(participants=[executor]).build()
workflow = WorkflowBuilder(start_executor=executor).build()
# stream=True at workflow level triggers streaming mode (returns async iterable)
events: list[WorkflowEvent] = []
@@ -378,7 +385,7 @@ async def test_agent_executor_run_with_messages_kwarg_does_not_raise() -> None:
"""Passing messages= via workflow.run() kwargs should not cause a duplicate-keyword TypeError."""
agent = _CountingAgent(id="messages_kwarg_agent", name="MessagesKwargAgent")
executor = AgentExecutor(agent, id="messages_kwarg_exec")
workflow = SequentialBuilder(participants=[executor]).build()
workflow = WorkflowBuilder(start_executor=executor).build()
result = await workflow.run("hello", messages=["stale"])
assert result is not None
@@ -426,7 +433,7 @@ async def test_agent_executor_workflow_with_non_copyable_raw_representation() ->
exec_a = AgentExecutor(agent_a, id="exec_a")
exec_b = AgentExecutor(agent_b, id="exec_b")
workflow = SequentialBuilder(participants=[exec_a, exec_b]).build()
workflow = WorkflowBuilder(start_executor=exec_a).add_edge(exec_a, exec_b).build()
events = await workflow.run("hello")
completed = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_completed"]
@@ -440,3 +447,194 @@ async def test_agent_executor_workflow_with_non_copyable_raw_representation() ->
assert len(agent_responses) > 0
assert agent_responses[0].text == "reply from AgentA"
assert agent_responses[0].raw_representation is raw
# ---------------------------------------------------------------------------
# Context mode tests
# ---------------------------------------------------------------------------
class _MessageCapturingAgent(BaseAgent):
"""Agent that records the messages it received and returns a configurable reply."""
def __init__(self, *, reply_text: str = "reply", **kwargs: Any):
super().__init__(**kwargs)
self.reply_text = reply_text
self.last_messages: list[Message] = []
@overload
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[False] = ...,
session: AgentSession | None = ...,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[True],
session: AgentSession | None = ...,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
captured: list[Message] = []
if messages:
for m in messages: # type: ignore[union-attr]
if isinstance(m, Message):
captured.append(m)
elif isinstance(m, str):
captured.append(Message("user", [m]))
self.last_messages = captured
if stream:
async def _stream() -> AsyncIterable[AgentResponseUpdate]:
yield AgentResponseUpdate(contents=[Content.from_text(text=self.reply_text)])
return ResponseStream(_stream(), finalizer=AgentResponse.from_updates)
async def _run() -> AgentResponse:
return AgentResponse(messages=[Message("assistant", [self.reply_text])])
return _run()
def test_context_mode_custom_requires_context_filter() -> None:
"""context_mode='custom' without context_filter must raise ValueError."""
agent = _CountingAgent(id="a", name="A")
with pytest.raises(ValueError, match="context_filter must be provided"):
AgentExecutor(agent, context_mode="custom")
def test_context_mode_custom_with_filter_succeeds() -> None:
"""context_mode='custom' with a context_filter should not raise."""
agent = _CountingAgent(id="a", name="A")
executor = AgentExecutor(agent, context_mode="custom", context_filter=lambda msgs: msgs[-1:])
assert executor._context_mode == "custom" # pyright: ignore[reportPrivateUsage]
assert executor._context_filter is not None # pyright: ignore[reportPrivateUsage]
def test_context_mode_defaults_to_full() -> None:
"""Default context_mode should be 'full'."""
agent = _CountingAgent(id="a", name="A")
executor = AgentExecutor(agent)
assert executor._context_mode == "full" # pyright: ignore[reportPrivateUsage]
def test_context_mode_invalid_value_raises() -> None:
"""Invalid context_mode value should raise ValueError."""
agent = _CountingAgent(id="a", name="A")
with pytest.raises(ValueError, match="context_mode must be one of"):
AgentExecutor(agent, context_mode="invalid_mode") # type: ignore
async def test_from_response_context_mode_full_passes_full_conversation() -> None:
"""context_mode='full' (default) should pass full_conversation to the second agent."""
first = _MessageCapturingAgent(id="first", name="First", reply_text="first reply")
second = _MessageCapturingAgent(id="second", name="Second", reply_text="second reply")
exec_a = AgentExecutor(first, id="exec_a")
exec_b = AgentExecutor(second, id="exec_b", context_mode="full")
wf = WorkflowBuilder(start_executor=exec_a).add_edge(exec_a, exec_b).build()
async for ev in wf.run("hello", stream=True):
if ev.type == "status" and ev.state == WorkflowRunState.IDLE:
break
# Second agent should see full conversation: [user("hello"), assistant("first reply")]
seen = second.last_messages
assert len(seen) == 2
assert seen[0].role == "user" and "hello" in (seen[0].text or "")
assert seen[1].role == "assistant" and "first reply" in (seen[1].text or "")
async def test_from_response_context_mode_last_agent_passes_only_agent_messages() -> None:
"""context_mode='last_agent' should pass only the previous agent's response messages."""
first = _MessageCapturingAgent(id="first", name="First", reply_text="first reply")
second = _MessageCapturingAgent(id="second", name="Second", reply_text="second reply")
exec_a = AgentExecutor(first, id="exec_a")
exec_b = AgentExecutor(second, id="exec_b", context_mode="last_agent")
wf = WorkflowBuilder(start_executor=exec_a).add_edge(exec_a, exec_b).build()
async for ev in wf.run("hello", stream=True):
if ev.type == "status" and ev.state == WorkflowRunState.IDLE:
break
# Second agent should see only the assistant message from first: [assistant("first reply")]
seen = second.last_messages
assert len(seen) == 1
assert seen[0].role == "assistant" and "first reply" in (seen[0].text or "")
async def test_from_response_context_mode_custom_uses_filter() -> None:
"""context_mode='custom' should invoke context_filter on full_conversation."""
first = _MessageCapturingAgent(id="first", name="First", reply_text="first reply")
second = _MessageCapturingAgent(id="second", name="Second", reply_text="second reply")
# Custom filter: keep only user messages
def only_user_messages(msgs: list[Message]) -> list[Message]:
return [m for m in msgs if m.role == "user"]
exec_a = AgentExecutor(first, id="exec_a")
exec_b = AgentExecutor(second, id="exec_b", context_mode="custom", context_filter=only_user_messages)
wf = WorkflowBuilder(start_executor=exec_a).add_edge(exec_a, exec_b).build()
async for ev in wf.run("hello", stream=True):
if ev.type == "status" and ev.state == WorkflowRunState.IDLE:
break
# Second agent should see only user messages: [user("hello")]
seen = second.last_messages
assert len(seen) == 1
assert seen[0].role == "user" and "hello" in (seen[0].text or "")
async def test_checkpoint_save_does_not_include_context_mode() -> None:
"""on_checkpoint_save should not include context_mode in the saved state."""
agent = _CountingAgent(id="a", name="A")
executor = AgentExecutor(agent, context_mode="last_agent")
state = await executor.on_checkpoint_save()
assert "context_mode" not in state
assert "cache" in state
assert "agent_session" in state
async def test_checkpoint_restore_works_without_context_mode_in_state() -> None:
"""on_checkpoint_restore should succeed when state does not contain context_mode."""
agent = _CountingAgent(id="a", name="A")
executor = AgentExecutor(agent, context_mode="last_agent")
# Simulate a checkpoint state without context_mode (as saved by the new code)
state: dict[str, Any] = {
"cache": [Message(role="user", text="cached msg")],
"full_conversation": [],
"agent_session": AgentSession().to_dict(),
"pending_agent_requests": {},
"pending_responses_to_agent": [],
}
await executor.on_checkpoint_restore(state)
cache = executor._cache # pyright: ignore[reportPrivateUsage]
assert len(cache) == 1
assert cache[0].text == "cached msg"
# context_mode should remain as configured in the constructor, not changed by restore
assert executor._context_mode == "last_agent" # pyright: ignore[reportPrivateUsage]
@@ -341,7 +341,7 @@ async def test_runner_emits_runner_completion_for_agent_response_without_targets
await ctx.send_message(
WorkflowMessage(
data=AgentExecutorResponse("agent", AgentResponse()),
data=AgentExecutorResponse("agent", AgentResponse(), []),
source_id="agent",
)
)
@@ -100,24 +100,21 @@ class _AggregateAgentConversations(Executor):
assistant_replies: list[Message] = []
for r in results:
resp_messages = list(getattr(r.agent_response, "messages", []) or [])
conv = r.full_conversation if r.full_conversation is not None else resp_messages
resp_messages = list(r.agent_response.messages)
logger.debug(
f"Aggregating executor {getattr(r, 'executor_id', '<unknown>')}: "
f"{len(resp_messages)} response msgs, {len(conv)} conversation msgs"
f"{len(resp_messages)} response msgs, {len(r.full_conversation)} conversation msgs"
)
# Capture a single user prompt (first encountered across any conversation)
if prompt_message is None:
found_user = next((m for m in conv if _is_role(m, "user")), None)
if found_user is not None:
prompt_message = found_user
prompt_message = next((m for m in r.full_conversation if _is_role(m, "user")), None)
# Pick the final assistant message from the response; fallback to conversation search
final_assistant = next((m for m in reversed(resp_messages) if _is_role(m, "assistant")), None)
if final_assistant is None:
final_assistant = next((m for m in reversed(conv) if _is_role(m, "assistant")), None)
final_assistant = next((m for m in reversed(r.full_conversation) if _is_role(m, "assistant")), None)
if final_assistant is not None:
assistant_replies.append(final_assistant)
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
from dataclasses import dataclass
from typing import Literal
from agent_framework._agents import SupportsAgentRun
from agent_framework._types import Message
@@ -117,18 +118,25 @@ class AgentApprovalExecutor(WorkflowExecutor):
agent's output or send the final response to down stream executors in the orchestration.
"""
def __init__(self, agent: SupportsAgentRun) -> None:
def __init__(
self,
agent: SupportsAgentRun,
context_mode: Literal["full", "last_agent", "custom"] | None = None,
) -> None:
"""Initialize the AgentApprovalExecutor.
Args:
agent: The agent protocol to use for generating responses.
context_mode: The mode for providing context to the agent.
"""
super().__init__(workflow=self._build_workflow(agent), id=resolve_agent_id(agent), propagate_request=True)
self._context_mode: Literal["full", "last_agent", "custom"] | None = context_mode
self._description = agent.description
super().__init__(workflow=self._build_workflow(agent), id=resolve_agent_id(agent), propagate_request=True)
def _build_workflow(self, agent: SupportsAgentRun) -> Workflow:
"""Build the internal workflow for the AgentApprovalExecutor."""
agent_executor = AgentExecutor(agent)
agent_executor = AgentExecutor(agent, context_mode=self._context_mode)
request_info_executor = AgentRequestInfoExecutor(id="agent_request_info_executor")
return (
@@ -38,7 +38,7 @@ confusion and to mirror how the concurrent builder uses explicit dispatcher/aggr
import logging
from collections.abc import Sequence
from typing import Any
from typing import Any, Literal
from agent_framework import Message, SupportsAgentRun
from agent_framework._workflows._agent_executor import (
@@ -143,6 +143,7 @@ class SequentialBuilder:
*,
participants: Sequence[SupportsAgentRun | Executor],
checkpoint_storage: CheckpointStorage | None = None,
chain_only_agent_responses: bool = False,
intermediate_outputs: bool = False,
) -> None:
"""Initialize the SequentialBuilder.
@@ -150,10 +151,14 @@ class SequentialBuilder:
Args:
participants: Sequence of agent or executor instances to run sequentially.
checkpoint_storage: Optional checkpoint storage for enabling workflow state persistence.
chain_only_agent_responses: If True, only agent responses are chained between agents.
By default, the full conversation context is passed to the next agent. This also applies
to Executor -> Agent transitions if the executor sends `AgentExecutorResponse`.
intermediate_outputs: If True, enables intermediate outputs from agent participants.
"""
self._participants: list[SupportsAgentRun | Executor] = []
self._checkpoint_storage: CheckpointStorage | None = checkpoint_storage
self._chain_only_agent_responses: bool = chain_only_agent_responses
self._request_info_enabled: bool = False
self._request_info_filter: set[str] | None = None
self._intermediate_outputs: bool = intermediate_outputs
@@ -225,6 +230,10 @@ class SequentialBuilder:
participants: list[Executor | SupportsAgentRun] = self._participants
context_mode: Literal["full", "last_agent", "custom"] | None = (
"last_agent" if self._chain_only_agent_responses else None
)
executors: list[Executor] = []
for p in participants:
if isinstance(p, Executor):
@@ -234,9 +243,9 @@ class SequentialBuilder:
not self._request_info_filter or resolve_agent_id(p) in self._request_info_filter
):
# Handle request info enabled agents
executors.append(AgentApprovalExecutor(p))
executors.append(AgentApprovalExecutor(p, context_mode=context_mode))
else:
executors.append(AgentExecutor(p))
executors.append(AgentExecutor(p, context_mode=context_mode))
else:
raise TypeError(f"Participants must be SupportsAgentRun or Executor instances. Got {type(p).__name__}.")
@@ -117,6 +117,7 @@ class TestAgentRequestInfoExecutor:
agent_response = AgentExecutorResponse(
executor_id="test_agent",
agent_response=agent_response,
full_conversation=agent_response.messages,
)
ctx = MagicMock(spec=WorkflowContext)
@@ -135,6 +136,7 @@ class TestAgentRequestInfoExecutor:
original_request = AgentExecutorResponse(
executor_id="test_agent",
agent_response=agent_response,
full_conversation=agent_response.messages,
)
response = AgentRequestInfoResponse.from_strings(["Additional input"])
@@ -161,6 +163,7 @@ class TestAgentRequestInfoExecutor:
original_request = AgentExecutorResponse(
executor_id="test_agent",
agent_response=agent_response,
full_conversation=agent_response.messages,
)
response = AgentRequestInfoResponse.approve()
@@ -1,18 +1,20 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import AsyncIterable, Awaitable
from typing import Any
from typing import Any, Literal, overload
import pytest
from agent_framework import (
AgentExecutorResponse,
AgentResponse,
AgentResponseUpdate,
AgentRunInputs,
AgentSession,
BaseAgent,
Content,
Executor,
Message,
ResponseStream,
TypeCompatibilityError,
WorkflowContext,
WorkflowRunState,
@@ -25,26 +27,45 @@ from agent_framework.orchestrations import SequentialBuilder
class _EchoAgent(BaseAgent):
"""Simple agent that appends a single assistant message with its name."""
def run( # type: ignore[override]
@overload
def run(
self,
messages: str | Message | list[str] | list[Message] | None = None,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[False] = ...,
session: AgentSession | None = ...,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[True],
session: AgentSession | None = ...,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | AsyncIterable[AgentResponseUpdate]:
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
if stream:
return self._run_stream()
async def _stream() -> AsyncIterable[AgentResponseUpdate]:
yield AgentResponseUpdate(contents=[Content.from_text(text=f"{self.name} reply")])
return ResponseStream(_stream(), finalizer=AgentResponse.from_updates)
async def _run() -> AgentResponse:
return AgentResponse(messages=[Message("assistant", [f"{self.name} reply"])])
return _run()
async def _run_stream(self) -> AsyncIterable[AgentResponseUpdate]:
# Minimal async generator with one assistant update
yield AgentResponseUpdate(contents=[Content.from_text(text=f"{self.name} reply")])
class _SummarizerExec(Executor):
"""Custom executor that summarizes by appending a short assistant message."""
@@ -251,3 +272,121 @@ async def test_sequential_builder_reusable_after_build_with_participants() -> No
assert builder._participants[0] is a1 # type: ignore
assert builder._participants[1] is a2 # type: ignore
# ---------------------------------------------------------------------------
# chain_only_agent_responses tests
# ---------------------------------------------------------------------------
class _CapturingAgent(BaseAgent):
"""Agent that records the messages it received and returns a configurable reply."""
def __init__(self, *, reply_text: str = "reply", **kwargs: Any):
super().__init__(**kwargs)
self.reply_text = reply_text
self.last_messages: list[Message] = []
@overload
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[False] = ...,
session: AgentSession | None = ...,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[True],
session: AgentSession | None = ...,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
captured: list[Message] = []
if messages:
for m in messages: # type: ignore[union-attr]
if isinstance(m, Message):
captured.append(m)
elif isinstance(m, str):
captured.append(Message("user", [m]))
self.last_messages = captured
if stream:
async def _stream() -> AsyncIterable[AgentResponseUpdate]:
yield AgentResponseUpdate(contents=[Content.from_text(text=self.reply_text)])
return ResponseStream(_stream(), finalizer=AgentResponse.from_updates)
async def _run() -> AgentResponse:
return AgentResponse(messages=[Message("assistant", [self.reply_text])])
return _run()
async def test_chain_only_agent_responses_false_passes_full_conversation() -> None:
"""Default (chain_only_agent_responses=False) passes full conversation to the second agent."""
a1 = _CapturingAgent(id="agent1", name="A1", reply_text="A1 reply")
a2 = _CapturingAgent(id="agent2", name="A2", reply_text="A2 reply")
wf = SequentialBuilder(participants=[a1, a2], chain_only_agent_responses=False).build()
async for ev in wf.run("hello", stream=True):
if ev.type == "status" and ev.state == WorkflowRunState.IDLE:
break
# Second agent should see full conversation: [user("hello"), assistant("A1 reply")]
seen = a2.last_messages
assert len(seen) == 2
assert seen[0].role == "user" and "hello" in (seen[0].text or "")
assert seen[1].role == "assistant" and "A1 reply" in (seen[1].text or "")
async def test_chain_only_agent_responses_true_passes_only_agent_messages() -> None:
"""chain_only_agent_responses=True passes only the previous agent's response messages."""
a1 = _CapturingAgent(id="agent1", name="A1", reply_text="A1 reply")
a2 = _CapturingAgent(id="agent2", name="A2", reply_text="A2 reply")
wf = SequentialBuilder(participants=[a1, a2], chain_only_agent_responses=True).build()
async for ev in wf.run("hello", stream=True):
if ev.type == "status" and ev.state == WorkflowRunState.IDLE:
break
# Second agent should see only the assistant message: [assistant("A1 reply")]
seen = a2.last_messages
assert len(seen) == 1
assert seen[0].role == "assistant" and "A1 reply" in (seen[0].text or "")
async def test_chain_only_agent_responses_three_agents() -> None:
"""chain_only_agent_responses=True with three agents: each sees only the prior agent's reply."""
a1 = _CapturingAgent(id="agent1", name="A1", reply_text="A1 reply")
a2 = _CapturingAgent(id="agent2", name="A2", reply_text="A2 reply")
a3 = _CapturingAgent(id="agent3", name="A3", reply_text="A3 reply")
wf = SequentialBuilder(participants=[a1, a2, a3], chain_only_agent_responses=True).build()
async for ev in wf.run("hello", stream=True):
if ev.type == "status" and ev.state == WorkflowRunState.IDLE:
break
# a2 should see only A1's reply
assert len(a2.last_messages) == 1
assert a2.last_messages[0].role == "assistant" and "A1 reply" in (a2.last_messages[0].text or "")
# a3 should see only A2's reply
assert len(a3.last_messages) == 1
assert a3.last_messages[0].role == "assistant" and "A2 reply" in (a3.last_messages[0].text or "")
@@ -0,0 +1,81 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import os
from agent_framework import AgentResponseUpdate
from agent_framework.azure import AzureOpenAIResponsesClient
from agent_framework.orchestrations import SequentialBuilder
from azure.identity import AzureCliCredential
from dotenv import load_dotenv
"""
Sample: Sequential workflow with chain_only_agent_responses=True
Demonstrates SequentialBuilder with `chain_only_agent_responses=True`, which passes
only the previous agent's response (not the full conversation history) to the next
agent. This is useful when each agent should focus solely on refining or transforming
the prior agent's output without being influenced by earlier turns.
In this sample, a writer agent produces a draft tagline, a translator agent translates
it into French (seeing only the writer's output, not the original user prompt), and a
reviewer agent evaluates the translation (seeing only the translator's output).
Compare with `sequential_agents.py`, which uses the default behavior where the full
conversation context is passed to each agent.
Prerequisites:
- AZURE_AI_PROJECT_ENDPOINT must be your Azure AI Foundry Agent Service (V2) project endpoint.
- Azure OpenAI configured for AzureOpenAIResponsesClient with required environment variables.
- Authentication via azure-identity. Use AzureCliCredential and run az login before executing the sample.
"""
# Load environment variables from .env file
load_dotenv()
async def main() -> None:
# 1) Create agents
client = AzureOpenAIResponsesClient(
project_endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"],
deployment_name=os.environ["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
credential=AzureCliCredential(),
)
writer = client.as_agent(
instructions="You are a concise copywriter. Provide a single, punchy marketing sentence based on the prompt.",
name="writer",
)
translator = client.as_agent(
instructions="You are a translator. Translate the given text into French. Output only the translation.",
name="translator",
)
reviewer = client.as_agent(
instructions="You are a reviewer. Evaluate the quality of the marketing tagline.",
name="reviewer",
)
# 2) Build sequential workflow: writer -> translator -> reviewer
# chain_only_agent_responses=True means each agent sees only the previous agent's reply,
# not the full conversation history.
workflow = SequentialBuilder(
participants=[writer, translator, reviewer],
chain_only_agent_responses=True,
intermediate_outputs=True,
).build()
# 3) Run and collect outputs
last_agent: str | None = None
async for event in workflow.run("Write a tagline for a budget-friendly eBike.", stream=True):
if event.type == "output" and isinstance(event.data, AgentResponseUpdate):
if event.data.author_name != last_agent:
last_agent = event.data.author_name
print()
print(f"{last_agent}: ", end="", flush=True)
print(event.data.text, end="", flush=True)
if __name__ == "__main__":
asyncio.run(main())