Python: Fix WorkflowAgent event handling and kwargs forwarding (#2946)

* Fix kwargs propagation through workflow.as_agent()

* Fix WorkflowAgent to respect AgentExecutor output_response setting
This commit is contained in:
Evan Mattson
2025-12-19 04:35:07 +09:00
committed by GitHub
Unverified
parent a841bdd1cc
commit b0a7a1fcb8
6 changed files with 484 additions and 17 deletions
@@ -26,6 +26,7 @@ from agent_framework import (
)
from ..exceptions import AgentExecutionException
from ._agent_executor import AgentExecutor
from ._checkpoint import CheckpointStorage
from ._events import (
AgentRunUpdateEvent,
@@ -141,7 +142,8 @@ class WorkflowAgent(BaseAgent):
checkpoint_storage: Runtime checkpoint storage. When provided with checkpoint_id,
used to load and restore the checkpoint. When provided without checkpoint_id,
enables checkpointing for this run.
**kwargs: Additional keyword arguments.
**kwargs: Additional keyword arguments passed through to underlying workflow
and ai_function tools.
Returns:
The final workflow response as an AgentRunResponse.
@@ -153,7 +155,7 @@ class WorkflowAgent(BaseAgent):
response_id = str(uuid.uuid4())
async for update in self._run_stream_impl(
input_messages, response_id, thread, checkpoint_id, checkpoint_storage
input_messages, response_id, thread, checkpoint_id, checkpoint_storage, **kwargs
):
response_updates.append(update)
@@ -187,7 +189,8 @@ class WorkflowAgent(BaseAgent):
checkpoint_storage: Runtime checkpoint storage. When provided with checkpoint_id,
used to load and restore the checkpoint. When provided without checkpoint_id,
enables checkpointing for this run.
**kwargs: Additional keyword arguments.
**kwargs: Additional keyword arguments passed through to underlying workflow
and ai_function tools.
Yields:
AgentRunResponseUpdate objects representing the workflow execution progress.
@@ -198,7 +201,7 @@ class WorkflowAgent(BaseAgent):
response_id = str(uuid.uuid4())
async for update in self._run_stream_impl(
input_messages, response_id, thread, checkpoint_id, checkpoint_storage
input_messages, response_id, thread, checkpoint_id, checkpoint_storage, **kwargs
):
response_updates.append(update)
yield update
@@ -216,6 +219,7 @@ class WorkflowAgent(BaseAgent):
thread: AgentThread,
checkpoint_id: str | None = None,
checkpoint_storage: CheckpointStorage | None = None,
**kwargs: Any,
) -> AsyncIterable[AgentRunResponseUpdate]:
"""Internal implementation of streaming execution.
@@ -225,6 +229,8 @@ class WorkflowAgent(BaseAgent):
thread: The conversation thread containing message history.
checkpoint_id: ID of checkpoint to restore from.
checkpoint_storage: Runtime checkpoint storage.
**kwargs: Additional keyword arguments passed through to the underlying
workflow and ai_function tools.
Yields:
AgentRunResponseUpdate objects representing the workflow execution progress.
@@ -255,6 +261,7 @@ class WorkflowAgent(BaseAgent):
message=None,
checkpoint_id=checkpoint_id,
checkpoint_storage=checkpoint_storage,
**kwargs,
)
else:
# Execute workflow with streaming (initial run or no function responses)
@@ -268,6 +275,7 @@ class WorkflowAgent(BaseAgent):
event_stream = self.workflow.run_stream(
message=conversation_messages,
checkpoint_storage=checkpoint_storage,
**kwargs,
)
# Process events from the stream
@@ -286,10 +294,20 @@ class WorkflowAgent(BaseAgent):
AgentRunUpdateEvent, RequestInfoEvent, and WorkflowOutputEvent are processed.
Other workflow events are ignored as they are workflow-internal.
For AgentRunUpdateEvent from AgentExecutor instances, only events from executors
with output_response=True are converted to agent updates. This suppresses agent
responses from executors that were not explicitly marked to surface their output.
Non-AgentExecutor executors that emit AgentRunUpdateEvent directly are allowed
through since they explicitly chose to emit the event.
"""
match event:
case AgentRunUpdateEvent(data=update):
# Direct pass-through of update in an agent streaming event
case AgentRunUpdateEvent(data=update, executor_id=executor_id):
# For AgentExecutor instances, only pass through if output_response=True.
# Non-AgentExecutor executors that emit AgentRunUpdateEvent are allowed through.
executor = self.workflow.executors.get(executor_id)
if isinstance(executor, AgentExecutor) and not executor.output_response:
return None
if update:
return update
return None
@@ -297,11 +315,17 @@ class WorkflowAgent(BaseAgent):
case WorkflowOutputEvent(data=data, source_executor_id=source_executor_id):
# Convert workflow output to an agent response update.
# Handle different data types appropriately.
# Skip AgentRunResponse from AgentExecutor with output_response=True
# since streaming events already surfaced the content.
if isinstance(data, AgentRunResponse):
executor = self.workflow.executors.get(source_executor_id)
if isinstance(executor, AgentExecutor) and executor.output_response:
return None
if isinstance(data, AgentRunResponseUpdate):
# Already an update, pass through
return data
if isinstance(data, ChatMessage):
# Convert ChatMessage to update
return AgentRunResponseUpdate(
contents=list(data.contents),
role=data.role,
@@ -311,15 +335,9 @@ class WorkflowAgent(BaseAgent):
created_at=datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
raw_representation=data,
)
# Determine contents based on data type
if isinstance(data, BaseContent):
# Already a content type (TextContent, ImageContent, etc.)
contents: list[Contents] = [cast(Contents, data)]
elif isinstance(data, str):
contents = [TextContent(text=data)]
else:
# Fallback: convert to string representation
contents = [TextContent(text=str(data))]
contents = self._extract_contents(data)
if not contents:
return None
return AgentRunResponseUpdate(
contents=contents,
role=Role.ASSISTANT,
@@ -405,6 +423,18 @@ class WorkflowAgent(BaseAgent):
raise AgentExecutionException("Unexpected content type while awaiting request info responses.")
return function_responses
def _extract_contents(self, data: Any) -> list[Contents]:
"""Recursively extract Contents from workflow output data."""
if isinstance(data, ChatMessage):
return list(data.contents)
if isinstance(data, list):
return [c for item in data for c in self._extract_contents(item)]
if isinstance(data, BaseContent):
return [cast(Contents, data)]
if isinstance(data, str):
return [TextContent(text=data)]
return [TextContent(text=str(data))]
class _ResponseState(TypedDict):
"""State for grouping response updates by message_id."""
@@ -99,6 +99,11 @@ class AgentExecutor(Executor):
self._output_response = output_response
self._cache: list[ChatMessage] = []
@property
def output_response(self) -> bool:
    """Tell whether this executor yields AgentRunResponse as workflow output when complete."""
    enabled: bool = self._output_response
    return enabled
@property
def workflow_output_types(self) -> list[type[Any]]:
# Override to declare AgentRunResponse as a possible output type only if enabled.
@@ -1,11 +1,13 @@
# Copyright (c) Microsoft. All rights reserved.
import uuid
from collections.abc import AsyncIterable
from typing import Any
import pytest
from agent_framework import (
AgentProtocol,
AgentRunResponse,
AgentRunResponseUpdate,
AgentRunUpdateEvent,
@@ -422,6 +424,48 @@ class TestWorkflowAgent:
assert isinstance(updates[2].raw_representation, CustomData)
assert updates[2].raw_representation.value == 42
async def test_workflow_as_agent_yield_output_with_list_of_chat_messages(self) -> None:
    """Verify that a yielded list[ChatMessage] is flattened into a single update.

    The streaming path is expected to surface every content item of every message.
    run() is expected to return one message whose TextContent items have been merged
    into a single TextContent (the coalescing is attributed to _finalize_response).
    """

    @executor
    async def list_yielding_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None:
        # Emit several ChatMessages in one output, the way SequentialBuilder does.
        user_msg = ChatMessage(role=Role.USER, contents=[TextContent(text="first message")])
        assistant_msg = ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="second message")])
        multi_content_msg = ChatMessage(
            role=Role.ASSISTANT,
            contents=[TextContent(text="third"), TextContent(text="fourth")],
        )
        await ctx.yield_output([user_msg, assistant_msg, multi_content_msg])

    workflow = WorkflowBuilder().set_start_executor(list_yielding_executor).build()
    agent = workflow.as_agent("list-msg-agent")

    # Streaming: a single update carrying all four contents, not yet coalesced.
    streamed: list[AgentRunResponseUpdate] = [update async for update in agent.run_stream("test")]
    assert len(streamed) == 1
    only_update = streamed[0]
    assert len(only_update.contents) == 4
    texts = [content.text for content in only_update.contents if isinstance(content, TextContent)]
    assert texts == ["first message", "second message", "third", "fourth"]

    # Non-streaming: the text contents are merged into one.
    result = await agent.run("test")
    assert isinstance(result, AgentRunResponse)
    assert len(result.messages) == 1
    merged_message = result.messages[0]
    assert len(merged_message.contents) == 1
    assert merged_message.text == "first messagesecond messagethirdfourth"
async def test_thread_conversation_history_included_in_workflow_run(self) -> None:
"""Test that conversation history from thread is included when running WorkflowAgent.
@@ -521,6 +565,142 @@ class TestWorkflowAgent:
checkpoints = await checkpoint_storage.list_checkpoints(workflow.id)
assert len(checkpoints) > 0, "Checkpoints should have been created when checkpoint_storage is provided"
async def test_agent_executor_output_response_false_filters_streaming_events(self):
    """Check that only agents registered with output_response=True reach the response.

    The workflow chains a plain start executor, an agent registered without
    output_response, and an agent registered with output_response=True. The response
    is expected to contain the start executor's yielded output and the second agent's
    text, and nothing from the first agent.
    """

    class MockAgent(AgentProtocol):
        """Agent stub that always replies with a fixed text."""

        def __init__(self, name: str, response_text: str) -> None:
            self._name = name
            self._response_text = response_text
            self._description: str | None = None

        @property
        def name(self) -> str | None:
            return self._name

        @property
        def description(self) -> str | None:
            return self._description

        def get_new_thread(self) -> AgentThread:
            return AgentThread()

        async def run(self, messages: Any, *, thread: AgentThread | None = None, **kwargs: Any) -> AgentRunResponse:
            reply = ChatMessage(role=Role.ASSISTANT, text=self._response_text)
            return AgentRunResponse(messages=[reply], text=self._response_text)

        async def run_stream(
            self, messages: Any, *, thread: AgentThread | None = None, **kwargs: Any
        ) -> AsyncIterable[AgentRunResponseUpdate]:
            # Stream the reply one whitespace-separated word at a time.
            for word in self._response_text.split():
                chunk = TextContent(text=word + " ")
                yield AgentRunResponseUpdate(contents=[chunk], role=Role.ASSISTANT, author_name=self._name)

    @executor
    async def start_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None:
        from agent_framework import AgentExecutorRequest

        await ctx.yield_output("Start output")
        request = AgentExecutorRequest(messages=messages, should_respond=True)
        await ctx.send_message(request)

    # Topology: start -> agent1 (no output) -> agent2 (output_response=True)
    workflow = (
        WorkflowBuilder()
        .register_executor(lambda: start_executor, "start")
        .register_agent(lambda: MockAgent("agent1", "Agent1 output - should NOT appear"), "agent1")
        .register_agent(
            lambda: MockAgent("agent2", "Agent2 output - SHOULD appear"), "agent2", output_response=True
        )
        .set_start_executor("start")
        .add_edge("start", "agent1")
        .add_edge("agent1", "agent2")
        .build()
    )

    result = await WorkflowAgent(workflow=workflow, name="Test Agent").run("Test input")

    # Gather the non-empty text of every response message.
    texts = [message.text for message in result.messages if message.text]

    # The yield_output from the start executor is surfaced.
    assert any("Start output" in text for text in texts), "Start output should appear"

    # agent1 was registered without output_response, so it stays hidden.
    assert all("Agent1" not in text for text in texts), "Agent1 output should NOT appear"

    # agent2 was registered with output_response=True, so it is surfaced.
    assert any("Agent2" in text for text in texts), "Agent2 output should appear"
async def test_agent_executor_output_response_no_duplicate_from_workflow_output_event(self):
    """Check that an agent registered with output_response=True is surfaced exactly once.

    The agent's text could arrive both through streaming updates and through the
    final workflow output event; the response must contain it only one time.
    """

    class MockAgent(AgentProtocol):
        """Agent stub that always replies with a fixed text."""

        def __init__(self, name: str, response_text: str) -> None:
            self._name = name
            self._response_text = response_text
            self._description: str | None = None

        @property
        def name(self) -> str | None:
            return self._name

        @property
        def description(self) -> str | None:
            return self._description

        def get_new_thread(self) -> AgentThread:
            return AgentThread()

        async def run(self, messages: Any, *, thread: AgentThread | None = None, **kwargs: Any) -> AgentRunResponse:
            reply = ChatMessage(role=Role.ASSISTANT, text=self._response_text)
            return AgentRunResponse(messages=[reply], text=self._response_text)

        async def run_stream(
            self, messages: Any, *, thread: AgentThread | None = None, **kwargs: Any
        ) -> AsyncIterable[AgentRunResponseUpdate]:
            # The whole reply is streamed as one update.
            whole_reply = TextContent(text=self._response_text)
            yield AgentRunResponseUpdate(contents=[whole_reply], role=Role.ASSISTANT, author_name=self._name)

    @executor
    async def start_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None:
        from agent_framework import AgentExecutorRequest

        request = AgentExecutorRequest(messages=messages, should_respond=True)
        await ctx.send_message(request)

    # Topology: start -> agent (output_response=True)
    workflow = (
        WorkflowBuilder()
        .register_executor(lambda: start_executor, "start")
        .register_agent(lambda: MockAgent("agent", "Unique response text"), "agent", output_response=True)
        .set_start_executor("start")
        .add_edge("start", "agent")
        .build()
    )

    result = await WorkflowAgent(workflow=workflow, name="Test Agent").run("Test input")

    # Count the response messages that carry the unique text.
    matching_messages = [
        message for message in result.messages if message.text and "Unique response text" in message.text
    ]
    unique_text_count = len(matching_messages)

    # Exactly one occurrence: not one from streaming plus one from WorkflowOutputEvent.
    assert unique_text_count == 1, f"Response should appear exactly once, but appeared {unique_text_count} times"
class TestWorkflowAgentMergeUpdates:
"""Test cases specifically for the WorkflowAgent.merge_updates static method."""
@@ -492,6 +492,117 @@ async def test_magentic_kwargs_stored_in_shared_state() -> None:
# endregion
# region WorkflowAgent (as_agent) kwargs Tests
async def test_workflow_as_agent_run_propagates_kwargs_to_underlying_agent() -> None:
    """Verify that keyword arguments given to workflow_agent.run() reach the inner agent."""
    inner_agent = _KwargsCapturingAgent(name="inner_agent")
    workflow_agent = SequentialBuilder().participants([inner_agent]).build().as_agent(name="TestWorkflowAgent")

    custom_data = {"endpoint": "https://api.example.com", "version": "v1"}
    user_token = {"user_name": "alice", "access_level": "admin"}

    await workflow_agent.run("test message", custom_data=custom_data, user_token=user_token)

    # The first recorded invocation must carry both kwargs unchanged.
    assert len(inner_agent.captured_kwargs) >= 1, "Inner agent should have been invoked at least once"
    first_call = inner_agent.captured_kwargs[0]
    assert "custom_data" in first_call, "Inner agent should receive custom_data kwarg"
    assert "user_token" in first_call, "Inner agent should receive user_token kwarg"
    assert first_call["custom_data"] == custom_data
    assert first_call["user_token"] == user_token
async def test_workflow_as_agent_run_stream_propagates_kwargs_to_underlying_agent() -> None:
    """Verify that keyword arguments given to workflow_agent.run_stream() reach the inner agent."""
    inner_agent = _KwargsCapturingAgent(name="inner_agent")
    workflow_agent = SequentialBuilder().participants([inner_agent]).build().as_agent(name="TestWorkflowAgent")

    custom_data = {"session_id": "xyz123"}
    api_token = "secret-token"

    # Drain the stream; the updates themselves are irrelevant to this test.
    stream = workflow_agent.run_stream("test message", custom_data=custom_data, api_token=api_token)
    async for _ in stream:
        continue

    # The first recorded invocation must carry both kwargs unchanged.
    assert len(inner_agent.captured_kwargs) >= 1, "Inner agent should have been invoked at least once"
    first_call = inner_agent.captured_kwargs[0]
    assert "custom_data" in first_call, "Inner agent should receive custom_data kwarg"
    assert "api_token" in first_call, "Inner agent should receive api_token kwarg"
    assert first_call["custom_data"] == custom_data
    assert first_call["api_token"] == api_token
async def test_workflow_as_agent_propagates_kwargs_to_multiple_agents() -> None:
    """Verify that every participant receives the kwargs when run via workflow.as_agent()."""
    agent1 = _KwargsCapturingAgent(name="agent1")
    agent2 = _KwargsCapturingAgent(name="agent2")
    sequential_workflow = SequentialBuilder().participants([agent1, agent2]).build()
    workflow_agent = sequential_workflow.as_agent(name="MultiAgentWorkflow")

    custom_data = {"batch_id": "batch-001"}
    await workflow_agent.run("test message", custom_data=custom_data)

    # Each participant must have been invoked ...
    assert len(agent1.captured_kwargs) >= 1, "First agent should be invoked"
    assert len(agent2.captured_kwargs) >= 1, "Second agent should be invoked"

    # ... and must have seen the same custom_data on its first invocation.
    first_call_agent1 = agent1.captured_kwargs[0]
    first_call_agent2 = agent2.captured_kwargs[0]
    assert first_call_agent1.get("custom_data") == custom_data
    assert first_call_agent2.get("custom_data") == custom_data
async def test_workflow_as_agent_kwargs_with_none_values() -> None:
    """Verify that a kwarg whose value is None is forwarded, not dropped, via as_agent()."""
    capturing_agent = _KwargsCapturingAgent(name="none_test_agent")
    workflow_agent = SequentialBuilder().participants([capturing_agent]).build().as_agent(name="NoneTestWorkflow")

    await workflow_agent.run("test", optional_param=None, other_param="value")

    assert len(capturing_agent.captured_kwargs) >= 1
    first_call = capturing_agent.captured_kwargs[0]
    # The key must be present with an explicit None value.
    assert "optional_param" in first_call
    assert first_call["optional_param"] is None
    assert first_call["other_param"] == "value"
async def test_workflow_as_agent_kwargs_with_complex_nested_data() -> None:
    """Verify that a deeply nested kwarg value arrives at the inner agent unchanged via as_agent()."""
    capturing_agent = _KwargsCapturingAgent(name="nested_agent")
    workflow_agent = SequentialBuilder().participants([capturing_agent]).build().as_agent(name="NestedDataWorkflow")

    # Built inside-out; the resulting structure mixes dicts, lists and scalars.
    level2 = {"level3": ["a", "b", "c"], "number": 42}
    level1 = {"level2": level2, "list": [1, 2, {"nested": True}]}
    complex_data = {"level1": level1}

    await workflow_agent.run("test", complex_data=complex_data)

    assert len(capturing_agent.captured_kwargs) >= 1
    first_call = capturing_agent.captured_kwargs[0]
    assert first_call.get("complex_data") == complex_data
# endregion
# region SubWorkflow (WorkflowExecutor) Tests
@@ -45,6 +45,7 @@ Once comfortable with these, explore the rest of the samples below.
| Workflow as Agent (Reflection Pattern) | [agents/workflow_as_agent_reflection_pattern.py](./agents/workflow_as_agent_reflection_pattern.py) | Wrap a workflow so it can behave like an agent (reflection pattern) |
| Workflow as Agent + HITL | [agents/workflow_as_agent_human_in_the_loop.py](./agents/workflow_as_agent_human_in_the_loop.py) | Extend workflow-as-agent with human-in-the-loop capability |
| Workflow as Agent with Thread | [agents/workflow_as_agent_with_thread.py](./agents/workflow_as_agent_with_thread.py) | Use AgentThread to maintain conversation history across workflow-as-agent invocations |
| Workflow as Agent kwargs | [agents/workflow_as_agent_kwargs.py](./agents/workflow_as_agent_kwargs.py) | Pass custom context (data, user tokens) via kwargs through workflow.as_agent() to @ai_function tools |
| Handoff Workflow as Agent | [agents/handoff_workflow_as_agent.py](./agents/handoff_workflow_as_agent.py) | Use a HandoffBuilder workflow as an agent with HITL via FunctionCallContent/FunctionResultContent |
### checkpoint
@@ -0,0 +1,140 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import json
from typing import Annotated, Any
from agent_framework import SequentialBuilder, ai_function
from agent_framework.openai import OpenAIChatClient
from pydantic import Field
"""
Sample: Workflow as Agent with kwargs Propagation to @ai_function Tools
This sample demonstrates how to flow custom context (skill data, user tokens, etc.)
through a workflow exposed via .as_agent() to @ai_function tools using the **kwargs pattern.
Key Concepts:
- Build a workflow using SequentialBuilder (or any builder pattern)
- Expose the workflow as a reusable agent via workflow.as_agent()
- Pass custom context as kwargs when invoking workflow_agent.run() or run_stream()
- kwargs are stored in SharedState and propagated to all agent invocations
- @ai_function tools receive kwargs via **kwargs parameter
When to use workflow.as_agent():
- To treat an entire workflow orchestration as a single agent
- To compose workflows into higher-level orchestrations
- To maintain a consistent agent interface for callers
Prerequisites:
- OpenAI environment variables configured
"""
# Define tools that accept custom context via **kwargs
# NOTE(review): the docstring below is kept verbatim because the @ai_function
# decorator may expose it to the model as the tool description — confirm.
@ai_function
def get_user_data(
    query: Annotated[str, Field(description="What user data to retrieve")],
    **kwargs: Any,
) -> str:
    """Retrieve user-specific data based on the authenticated context."""
    # `or {}` (instead of a .get default) also covers an explicit user_token=None,
    # which would otherwise raise AttributeError on the lookups below.
    user_token = kwargs.get("user_token") or {}
    user_name = user_token.get("user_name", "anonymous")
    access_level = user_token.get("access_level", "none")

    # Log what arrived so the sample output shows the kwargs reaching the tool.
    print(f"\n[get_user_data] Received kwargs keys: {list(kwargs.keys())}")
    print(f"[get_user_data] User: {user_name}")
    print(f"[get_user_data] Access level: {access_level}")

    return f"Retrieved data for user {user_name} with {access_level} access: {query}"
# NOTE(review): the docstring below is kept verbatim because the @ai_function
# decorator may expose it to the model as the tool description — confirm.
@ai_function
def call_api(
    endpoint_name: Annotated[str, Field(description="Name of the API endpoint to call")],
    **kwargs: Any,
) -> str:
    """Call an API using the configured endpoints from custom_data."""
    # `or {}` (instead of a .get default) also covers keys that are present but set
    # to None, which would otherwise raise AttributeError/TypeError further down.
    custom_data = kwargs.get("custom_data") or {}
    api_config = custom_data.get("api_config") or {}
    base_url = api_config.get("base_url", "unknown")
    endpoints = api_config.get("endpoints") or {}

    # Log what arrived so the sample output shows the kwargs reaching the tool.
    print(f"\n[call_api] Received kwargs keys: {list(kwargs.keys())}")
    print(f"[call_api] Base URL: {base_url}")
    print(f"[call_api] Available endpoints: {list(endpoints.keys())}")

    if endpoint_name in endpoints:
        return f"Called {base_url}{endpoints[endpoint_name]} successfully"
    return f"Endpoint '{endpoint_name}' not found in configuration"
async def main() -> None:
    """Run the kwargs-propagation demo end to end.

    Builds a single-agent sequential workflow, wraps it with as_agent(), and streams
    one request while passing custom_data and user_token as keyword arguments so the
    tool functions can print what they received.
    """
    heavy_rule = "=" * 70
    light_rule = "-" * 70

    print(heavy_rule)
    print("Workflow as Agent kwargs Flow Demo")
    print(heavy_rule)

    # Chat client and the tool-equipped agent built from it.
    chat_client = OpenAIChatClient()
    instructions = (
        "You are a helpful assistant. Use the available tools to help users. "
        "When asked about user data, use get_user_data. "
        "When asked to call an API, use call_api."
    )
    agent = chat_client.create_agent(
        name="assistant",
        instructions=instructions,
        tools=[get_user_data, call_api],
    )

    # Sequential workflow with a single participant, exposed through .as_agent().
    workflow = SequentialBuilder().participants([agent]).build()
    workflow_agent = workflow.as_agent(name="WorkflowAgent")

    # Custom context handed to the ai_functions through kwargs.
    endpoints = {
        "users": "/v1/users",
        "orders": "/v1/orders",
        "products": "/v1/products",
    }
    custom_data = {
        "api_config": {
            "base_url": "https://api.example.com",
            "endpoints": endpoints,
        },
    }
    user_token = {
        "user_name": "bob@contoso.com",
        "access_level": "admin",
    }

    print("\nCustom Data being passed:")
    print(json.dumps(custom_data, indent=2))
    print(f"\nUser: {user_token['user_name']}")

    print("\n" + light_rule)
    print("Workflow Agent Execution (watch for [tool_name] logs showing kwargs received):")
    print(light_rule)

    # kwargs are given to workflow_agent.run_stream() the same way as to workflow.run_stream().
    print("\n===== Streaming Response =====")
    prompt = "Please get my user data and then call the users API endpoint."
    async for update in workflow_agent.run_stream(prompt, custom_data=custom_data, user_token=user_token):
        chunk = update.text
        if chunk:
            print(chunk, end="", flush=True)
    print()

    print("\n" + heavy_rule)
    print("Sample Complete")
    print(heavy_rule)


if __name__ == "__main__":
    asyncio.run(main())