mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fix WorkflowAgent event handling and kwargs forwarding (#2946)
* Fix kwargs propagation through workflow.as_agent() * Fix WorkflowAgent to respect AgentExecutor output_response setting
This commit is contained in:
committed by
GitHub
Unverified
parent
a841bdd1cc
commit
b0a7a1fcb8
@@ -26,6 +26,7 @@ from agent_framework import (
|
||||
)
|
||||
|
||||
from ..exceptions import AgentExecutionException
|
||||
from ._agent_executor import AgentExecutor
|
||||
from ._checkpoint import CheckpointStorage
|
||||
from ._events import (
|
||||
AgentRunUpdateEvent,
|
||||
@@ -141,7 +142,8 @@ class WorkflowAgent(BaseAgent):
|
||||
checkpoint_storage: Runtime checkpoint storage. When provided with checkpoint_id,
|
||||
used to load and restore the checkpoint. When provided without checkpoint_id,
|
||||
enables checkpointing for this run.
|
||||
**kwargs: Additional keyword arguments.
|
||||
**kwargs: Additional keyword arguments passed through to underlying workflow
|
||||
and ai_function tools.
|
||||
|
||||
Returns:
|
||||
The final workflow response as an AgentRunResponse.
|
||||
@@ -153,7 +155,7 @@ class WorkflowAgent(BaseAgent):
|
||||
response_id = str(uuid.uuid4())
|
||||
|
||||
async for update in self._run_stream_impl(
|
||||
input_messages, response_id, thread, checkpoint_id, checkpoint_storage
|
||||
input_messages, response_id, thread, checkpoint_id, checkpoint_storage, **kwargs
|
||||
):
|
||||
response_updates.append(update)
|
||||
|
||||
@@ -187,7 +189,8 @@ class WorkflowAgent(BaseAgent):
|
||||
checkpoint_storage: Runtime checkpoint storage. When provided with checkpoint_id,
|
||||
used to load and restore the checkpoint. When provided without checkpoint_id,
|
||||
enables checkpointing for this run.
|
||||
**kwargs: Additional keyword arguments.
|
||||
**kwargs: Additional keyword arguments passed through to underlying workflow
|
||||
and ai_function tools.
|
||||
|
||||
Yields:
|
||||
AgentRunResponseUpdate objects representing the workflow execution progress.
|
||||
@@ -198,7 +201,7 @@ class WorkflowAgent(BaseAgent):
|
||||
response_id = str(uuid.uuid4())
|
||||
|
||||
async for update in self._run_stream_impl(
|
||||
input_messages, response_id, thread, checkpoint_id, checkpoint_storage
|
||||
input_messages, response_id, thread, checkpoint_id, checkpoint_storage, **kwargs
|
||||
):
|
||||
response_updates.append(update)
|
||||
yield update
|
||||
@@ -216,6 +219,7 @@ class WorkflowAgent(BaseAgent):
|
||||
thread: AgentThread,
|
||||
checkpoint_id: str | None = None,
|
||||
checkpoint_storage: CheckpointStorage | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[AgentRunResponseUpdate]:
|
||||
"""Internal implementation of streaming execution.
|
||||
|
||||
@@ -225,6 +229,8 @@ class WorkflowAgent(BaseAgent):
|
||||
thread: The conversation thread containing message history.
|
||||
checkpoint_id: ID of checkpoint to restore from.
|
||||
checkpoint_storage: Runtime checkpoint storage.
|
||||
**kwargs: Additional keyword arguments passed through to the underlying
|
||||
workflow and ai_function tools.
|
||||
|
||||
Yields:
|
||||
AgentRunResponseUpdate objects representing the workflow execution progress.
|
||||
@@ -255,6 +261,7 @@ class WorkflowAgent(BaseAgent):
|
||||
message=None,
|
||||
checkpoint_id=checkpoint_id,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
# Execute workflow with streaming (initial run or no function responses)
|
||||
@@ -268,6 +275,7 @@ class WorkflowAgent(BaseAgent):
|
||||
event_stream = self.workflow.run_stream(
|
||||
message=conversation_messages,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Process events from the stream
|
||||
@@ -286,10 +294,20 @@ class WorkflowAgent(BaseAgent):
|
||||
|
||||
AgentRunUpdateEvent, RequestInfoEvent, and WorkflowOutputEvent are processed.
|
||||
Other workflow events are ignored as they are workflow-internal.
|
||||
|
||||
For AgentRunUpdateEvent from AgentExecutor instances, only events from executors
|
||||
with output_response=True are converted to agent updates. This prevents agent
|
||||
responses from executors that were not explicitly marked to surface their output.
|
||||
Non-AgentExecutor executors that emit AgentRunUpdateEvent directly are allowed
|
||||
through since they explicitly chose to emit the event.
|
||||
"""
|
||||
match event:
|
||||
case AgentRunUpdateEvent(data=update):
|
||||
# Direct pass-through of update in an agent streaming event
|
||||
case AgentRunUpdateEvent(data=update, executor_id=executor_id):
|
||||
# For AgentExecutor instances, only pass through if output_response=True.
|
||||
# Non-AgentExecutor executors that emit AgentRunUpdateEvent are allowed through.
|
||||
executor = self.workflow.executors.get(executor_id)
|
||||
if isinstance(executor, AgentExecutor) and not executor.output_response:
|
||||
return None
|
||||
if update:
|
||||
return update
|
||||
return None
|
||||
@@ -297,11 +315,17 @@ class WorkflowAgent(BaseAgent):
|
||||
case WorkflowOutputEvent(data=data, source_executor_id=source_executor_id):
|
||||
# Convert workflow output to an agent response update.
|
||||
# Handle different data types appropriately.
|
||||
|
||||
# Skip AgentRunResponse from AgentExecutor with output_response=True
|
||||
# since streaming events already surfaced the content.
|
||||
if isinstance(data, AgentRunResponse):
|
||||
executor = self.workflow.executors.get(source_executor_id)
|
||||
if isinstance(executor, AgentExecutor) and executor.output_response:
|
||||
return None
|
||||
|
||||
if isinstance(data, AgentRunResponseUpdate):
|
||||
# Already an update, pass through
|
||||
return data
|
||||
if isinstance(data, ChatMessage):
|
||||
# Convert ChatMessage to update
|
||||
return AgentRunResponseUpdate(
|
||||
contents=list(data.contents),
|
||||
role=data.role,
|
||||
@@ -311,15 +335,9 @@ class WorkflowAgent(BaseAgent):
|
||||
created_at=datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
|
||||
raw_representation=data,
|
||||
)
|
||||
# Determine contents based on data type
|
||||
if isinstance(data, BaseContent):
|
||||
# Already a content type (TextContent, ImageContent, etc.)
|
||||
contents: list[Contents] = [cast(Contents, data)]
|
||||
elif isinstance(data, str):
|
||||
contents = [TextContent(text=data)]
|
||||
else:
|
||||
# Fallback: convert to string representation
|
||||
contents = [TextContent(text=str(data))]
|
||||
contents = self._extract_contents(data)
|
||||
if not contents:
|
||||
return None
|
||||
return AgentRunResponseUpdate(
|
||||
contents=contents,
|
||||
role=Role.ASSISTANT,
|
||||
@@ -405,6 +423,18 @@ class WorkflowAgent(BaseAgent):
|
||||
raise AgentExecutionException("Unexpected content type while awaiting request info responses.")
|
||||
return function_responses
|
||||
|
||||
def _extract_contents(self, data: Any) -> list[Contents]:
|
||||
"""Recursively extract Contents from workflow output data."""
|
||||
if isinstance(data, ChatMessage):
|
||||
return list(data.contents)
|
||||
if isinstance(data, list):
|
||||
return [c for item in data for c in self._extract_contents(item)]
|
||||
if isinstance(data, BaseContent):
|
||||
return [cast(Contents, data)]
|
||||
if isinstance(data, str):
|
||||
return [TextContent(text=data)]
|
||||
return [TextContent(text=str(data))]
|
||||
|
||||
class _ResponseState(TypedDict):
|
||||
"""State for grouping response updates by message_id."""
|
||||
|
||||
|
||||
@@ -99,6 +99,11 @@ class AgentExecutor(Executor):
|
||||
self._output_response = output_response
|
||||
self._cache: list[ChatMessage] = []
|
||||
|
||||
@property
|
||||
def output_response(self) -> bool:
|
||||
"""Whether this executor yields AgentRunResponse as workflow output when complete."""
|
||||
return self._output_response
|
||||
|
||||
@property
|
||||
def workflow_output_types(self) -> list[type[Any]]:
|
||||
# Override to declare AgentRunResponse as a possible output type only if enabled.
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import uuid
|
||||
from collections.abc import AsyncIterable
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework import (
|
||||
AgentProtocol,
|
||||
AgentRunResponse,
|
||||
AgentRunResponseUpdate,
|
||||
AgentRunUpdateEvent,
|
||||
@@ -422,6 +424,48 @@ class TestWorkflowAgent:
|
||||
assert isinstance(updates[2].raw_representation, CustomData)
|
||||
assert updates[2].raw_representation.value == 42
|
||||
|
||||
async def test_workflow_as_agent_yield_output_with_list_of_chat_messages(self) -> None:
|
||||
"""Test that yield_output with list[ChatMessage] extracts contents from all messages.
|
||||
|
||||
Note: TextContent items are coalesced by _finalize_response, so multiple text contents
|
||||
become a single merged TextContent in the final response.
|
||||
"""
|
||||
|
||||
@executor
|
||||
async def list_yielding_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None:
|
||||
# Yield a list of ChatMessages (as SequentialBuilder does)
|
||||
msg_list = [
|
||||
ChatMessage(role=Role.USER, contents=[TextContent(text="first message")]),
|
||||
ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="second message")]),
|
||||
ChatMessage(
|
||||
role=Role.ASSISTANT,
|
||||
contents=[TextContent(text="third"), TextContent(text="fourth")],
|
||||
),
|
||||
]
|
||||
await ctx.yield_output(msg_list)
|
||||
|
||||
workflow = WorkflowBuilder().set_start_executor(list_yielding_executor).build()
|
||||
agent = workflow.as_agent("list-msg-agent")
|
||||
|
||||
# Verify streaming returns the update with all 4 contents before coalescing
|
||||
updates: list[AgentRunResponseUpdate] = []
|
||||
async for update in agent.run_stream("test"):
|
||||
updates.append(update)
|
||||
|
||||
assert len(updates) == 1
|
||||
assert len(updates[0].contents) == 4
|
||||
texts = [c.text for c in updates[0].contents if isinstance(c, TextContent)]
|
||||
assert texts == ["first message", "second message", "third", "fourth"]
|
||||
|
||||
# Verify run() coalesces text contents (expected behavior)
|
||||
result = await agent.run("test")
|
||||
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert len(result.messages) == 1
|
||||
# TextContent items are coalesced into one
|
||||
assert len(result.messages[0].contents) == 1
|
||||
assert result.messages[0].text == "first messagesecond messagethirdfourth"
|
||||
|
||||
async def test_thread_conversation_history_included_in_workflow_run(self) -> None:
|
||||
"""Test that conversation history from thread is included when running WorkflowAgent.
|
||||
|
||||
@@ -521,6 +565,142 @@ class TestWorkflowAgent:
|
||||
checkpoints = await checkpoint_storage.list_checkpoints(workflow.id)
|
||||
assert len(checkpoints) > 0, "Checkpoints should have been created when checkpoint_storage is provided"
|
||||
|
||||
async def test_agent_executor_output_response_false_filters_streaming_events(self):
|
||||
"""Test that AgentExecutor with output_response=False does not surface streaming events."""
|
||||
|
||||
class MockAgent(AgentProtocol):
|
||||
"""Mock agent for testing."""
|
||||
|
||||
def __init__(self, name: str, response_text: str) -> None:
|
||||
self._name = name
|
||||
self._response_text = response_text
|
||||
self._description: str | None = None
|
||||
|
||||
@property
|
||||
def name(self) -> str | None:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def description(self) -> str | None:
|
||||
return self._description
|
||||
|
||||
def get_new_thread(self) -> AgentThread:
|
||||
return AgentThread()
|
||||
|
||||
async def run(self, messages: Any, *, thread: AgentThread | None = None, **kwargs: Any) -> AgentRunResponse:
|
||||
return AgentRunResponse(
|
||||
messages=[ChatMessage(role=Role.ASSISTANT, text=self._response_text)],
|
||||
text=self._response_text,
|
||||
)
|
||||
|
||||
async def run_stream(
|
||||
self, messages: Any, *, thread: AgentThread | None = None, **kwargs: Any
|
||||
) -> AsyncIterable[AgentRunResponseUpdate]:
|
||||
for word in self._response_text.split():
|
||||
yield AgentRunResponseUpdate(
|
||||
contents=[TextContent(text=word + " ")],
|
||||
role=Role.ASSISTANT,
|
||||
author_name=self._name,
|
||||
)
|
||||
|
||||
@executor
|
||||
async def start_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None:
|
||||
from agent_framework import AgentExecutorRequest
|
||||
|
||||
await ctx.yield_output("Start output")
|
||||
await ctx.send_message(AgentExecutorRequest(messages=messages, should_respond=True))
|
||||
|
||||
# Build workflow: start -> agent1 (no output) -> agent2 (output_response=True)
|
||||
workflow = (
|
||||
WorkflowBuilder()
|
||||
.register_executor(lambda: start_executor, "start")
|
||||
.register_agent(lambda: MockAgent("agent1", "Agent1 output - should NOT appear"), "agent1")
|
||||
.register_agent(
|
||||
lambda: MockAgent("agent2", "Agent2 output - SHOULD appear"), "agent2", output_response=True
|
||||
)
|
||||
.set_start_executor("start")
|
||||
.add_edge("start", "agent1")
|
||||
.add_edge("agent1", "agent2")
|
||||
.build()
|
||||
)
|
||||
|
||||
agent = WorkflowAgent(workflow=workflow, name="Test Agent")
|
||||
result = await agent.run("Test input")
|
||||
|
||||
# Collect all message texts
|
||||
texts = [msg.text for msg in result.messages if msg.text]
|
||||
|
||||
# Start output should appear (from yield_output)
|
||||
assert any("Start output" in t for t in texts), "Start output should appear"
|
||||
|
||||
# Agent1 output should NOT appear (output_response=False)
|
||||
assert not any("Agent1" in t for t in texts), "Agent1 output should NOT appear"
|
||||
|
||||
# Agent2 output should appear (output_response=True)
|
||||
assert any("Agent2" in t for t in texts), "Agent2 output should appear"
|
||||
|
||||
async def test_agent_executor_output_response_no_duplicate_from_workflow_output_event(self):
|
||||
"""Test that AgentExecutor with output_response=True does not duplicate content."""
|
||||
|
||||
class MockAgent(AgentProtocol):
|
||||
"""Mock agent for testing."""
|
||||
|
||||
def __init__(self, name: str, response_text: str) -> None:
|
||||
self._name = name
|
||||
self._response_text = response_text
|
||||
self._description: str | None = None
|
||||
|
||||
@property
|
||||
def name(self) -> str | None:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def description(self) -> str | None:
|
||||
return self._description
|
||||
|
||||
def get_new_thread(self) -> AgentThread:
|
||||
return AgentThread()
|
||||
|
||||
async def run(self, messages: Any, *, thread: AgentThread | None = None, **kwargs: Any) -> AgentRunResponse:
|
||||
return AgentRunResponse(
|
||||
messages=[ChatMessage(role=Role.ASSISTANT, text=self._response_text)],
|
||||
text=self._response_text,
|
||||
)
|
||||
|
||||
async def run_stream(
|
||||
self, messages: Any, *, thread: AgentThread | None = None, **kwargs: Any
|
||||
) -> AsyncIterable[AgentRunResponseUpdate]:
|
||||
yield AgentRunResponseUpdate(
|
||||
contents=[TextContent(text=self._response_text)],
|
||||
role=Role.ASSISTANT,
|
||||
author_name=self._name,
|
||||
)
|
||||
|
||||
@executor
|
||||
async def start_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None:
|
||||
from agent_framework import AgentExecutorRequest
|
||||
|
||||
await ctx.send_message(AgentExecutorRequest(messages=messages, should_respond=True))
|
||||
|
||||
# Build workflow with single agent that has output_response=True
|
||||
workflow = (
|
||||
WorkflowBuilder()
|
||||
.register_executor(lambda: start_executor, "start")
|
||||
.register_agent(lambda: MockAgent("agent", "Unique response text"), "agent", output_response=True)
|
||||
.set_start_executor("start")
|
||||
.add_edge("start", "agent")
|
||||
.build()
|
||||
)
|
||||
|
||||
agent = WorkflowAgent(workflow=workflow, name="Test Agent")
|
||||
result = await agent.run("Test input")
|
||||
|
||||
# Count occurrences of the unique response text
|
||||
unique_text_count = sum(1 for msg in result.messages if msg.text and "Unique response text" in msg.text)
|
||||
|
||||
# Should appear exactly once (not duplicated from both streaming and WorkflowOutputEvent)
|
||||
assert unique_text_count == 1, f"Response should appear exactly once, but appeared {unique_text_count} times"
|
||||
|
||||
|
||||
class TestWorkflowAgentMergeUpdates:
|
||||
"""Test cases specifically for the WorkflowAgent.merge_updates static method."""
|
||||
|
||||
@@ -492,6 +492,117 @@ async def test_magentic_kwargs_stored_in_shared_state() -> None:
|
||||
# endregion
|
||||
|
||||
|
||||
# region WorkflowAgent (as_agent) kwargs Tests
|
||||
|
||||
|
||||
async def test_workflow_as_agent_run_propagates_kwargs_to_underlying_agent() -> None:
|
||||
"""Test that kwargs passed to workflow_agent.run() flow through to the underlying agents."""
|
||||
agent = _KwargsCapturingAgent(name="inner_agent")
|
||||
workflow = SequentialBuilder().participants([agent]).build()
|
||||
workflow_agent = workflow.as_agent(name="TestWorkflowAgent")
|
||||
|
||||
custom_data = {"endpoint": "https://api.example.com", "version": "v1"}
|
||||
user_token = {"user_name": "alice", "access_level": "admin"}
|
||||
|
||||
_ = await workflow_agent.run(
|
||||
"test message",
|
||||
custom_data=custom_data,
|
||||
user_token=user_token,
|
||||
)
|
||||
|
||||
# Verify inner agent received kwargs
|
||||
assert len(agent.captured_kwargs) >= 1, "Inner agent should have been invoked at least once"
|
||||
received = agent.captured_kwargs[0]
|
||||
assert "custom_data" in received, "Inner agent should receive custom_data kwarg"
|
||||
assert "user_token" in received, "Inner agent should receive user_token kwarg"
|
||||
assert received["custom_data"] == custom_data
|
||||
assert received["user_token"] == user_token
|
||||
|
||||
|
||||
async def test_workflow_as_agent_run_stream_propagates_kwargs_to_underlying_agent() -> None:
|
||||
"""Test that kwargs passed to workflow_agent.run_stream() flow through to the underlying agents."""
|
||||
agent = _KwargsCapturingAgent(name="inner_agent")
|
||||
workflow = SequentialBuilder().participants([agent]).build()
|
||||
workflow_agent = workflow.as_agent(name="TestWorkflowAgent")
|
||||
|
||||
custom_data = {"session_id": "xyz123"}
|
||||
api_token = "secret-token"
|
||||
|
||||
async for _ in workflow_agent.run_stream(
|
||||
"test message",
|
||||
custom_data=custom_data,
|
||||
api_token=api_token,
|
||||
):
|
||||
pass
|
||||
|
||||
# Verify inner agent received kwargs
|
||||
assert len(agent.captured_kwargs) >= 1, "Inner agent should have been invoked at least once"
|
||||
received = agent.captured_kwargs[0]
|
||||
assert "custom_data" in received, "Inner agent should receive custom_data kwarg"
|
||||
assert "api_token" in received, "Inner agent should receive api_token kwarg"
|
||||
assert received["custom_data"] == custom_data
|
||||
assert received["api_token"] == api_token
|
||||
|
||||
|
||||
async def test_workflow_as_agent_propagates_kwargs_to_multiple_agents() -> None:
|
||||
"""Test that kwargs flow to all agents when using workflow.as_agent()."""
|
||||
agent1 = _KwargsCapturingAgent(name="agent1")
|
||||
agent2 = _KwargsCapturingAgent(name="agent2")
|
||||
workflow = SequentialBuilder().participants([agent1, agent2]).build()
|
||||
workflow_agent = workflow.as_agent(name="MultiAgentWorkflow")
|
||||
|
||||
custom_data = {"batch_id": "batch-001"}
|
||||
|
||||
_ = await workflow_agent.run("test message", custom_data=custom_data)
|
||||
|
||||
# Both agents should have received kwargs
|
||||
assert len(agent1.captured_kwargs) >= 1, "First agent should be invoked"
|
||||
assert len(agent2.captured_kwargs) >= 1, "Second agent should be invoked"
|
||||
assert agent1.captured_kwargs[0].get("custom_data") == custom_data
|
||||
assert agent2.captured_kwargs[0].get("custom_data") == custom_data
|
||||
|
||||
|
||||
async def test_workflow_as_agent_kwargs_with_none_values() -> None:
|
||||
"""Test that kwargs with None values are passed through correctly via as_agent()."""
|
||||
agent = _KwargsCapturingAgent(name="none_test_agent")
|
||||
workflow = SequentialBuilder().participants([agent]).build()
|
||||
workflow_agent = workflow.as_agent(name="NoneTestWorkflow")
|
||||
|
||||
_ = await workflow_agent.run("test", optional_param=None, other_param="value")
|
||||
|
||||
assert len(agent.captured_kwargs) >= 1
|
||||
received = agent.captured_kwargs[0]
|
||||
assert "optional_param" in received
|
||||
assert received["optional_param"] is None
|
||||
assert received["other_param"] == "value"
|
||||
|
||||
|
||||
async def test_workflow_as_agent_kwargs_with_complex_nested_data() -> None:
|
||||
"""Test that complex nested data structures flow through correctly via as_agent()."""
|
||||
agent = _KwargsCapturingAgent(name="nested_agent")
|
||||
workflow = SequentialBuilder().participants([agent]).build()
|
||||
workflow_agent = workflow.as_agent(name="NestedDataWorkflow")
|
||||
|
||||
complex_data = {
|
||||
"level1": {
|
||||
"level2": {
|
||||
"level3": ["a", "b", "c"],
|
||||
"number": 42,
|
||||
},
|
||||
"list": [1, 2, {"nested": True}],
|
||||
},
|
||||
}
|
||||
|
||||
_ = await workflow_agent.run("test", complex_data=complex_data)
|
||||
|
||||
assert len(agent.captured_kwargs) >= 1
|
||||
received = agent.captured_kwargs[0]
|
||||
assert received.get("complex_data") == complex_data
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region SubWorkflow (WorkflowExecutor) Tests
|
||||
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ Once comfortable with these, explore the rest of the samples below.
|
||||
| Workflow as Agent (Reflection Pattern) | [agents/workflow_as_agent_reflection_pattern.py](./agents/workflow_as_agent_reflection_pattern.py) | Wrap a workflow so it can behave like an agent (reflection pattern) |
|
||||
| Workflow as Agent + HITL | [agents/workflow_as_agent_human_in_the_loop.py](./agents/workflow_as_agent_human_in_the_loop.py) | Extend workflow-as-agent with human-in-the-loop capability |
|
||||
| Workflow as Agent with Thread | [agents/workflow_as_agent_with_thread.py](./agents/workflow_as_agent_with_thread.py) | Use AgentThread to maintain conversation history across workflow-as-agent invocations |
|
||||
| Workflow as Agent kwargs | [agents/workflow_as_agent_kwargs.py](./agents/workflow_as_agent_kwargs.py) | Pass custom context (data, user tokens) via kwargs through workflow.as_agent() to @ai_function tools |
|
||||
| Handoff Workflow as Agent | [agents/handoff_workflow_as_agent.py](./agents/handoff_workflow_as_agent.py) | Use a HandoffBuilder workflow as an agent with HITL via FunctionCallContent/FunctionResultContent |
|
||||
|
||||
### checkpoint
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Annotated, Any
|
||||
|
||||
from agent_framework import SequentialBuilder, ai_function
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from pydantic import Field
|
||||
|
||||
"""
|
||||
Sample: Workflow as Agent with kwargs Propagation to @ai_function Tools
|
||||
|
||||
This sample demonstrates how to flow custom context (skill data, user tokens, etc.)
|
||||
through a workflow exposed via .as_agent() to @ai_function tools using the **kwargs pattern.
|
||||
|
||||
Key Concepts:
|
||||
- Build a workflow using SequentialBuilder (or any builder pattern)
|
||||
- Expose the workflow as a reusable agent via workflow.as_agent()
|
||||
- Pass custom context as kwargs when invoking workflow_agent.run() or run_stream()
|
||||
- kwargs are stored in SharedState and propagated to all agent invocations
|
||||
- @ai_function tools receive kwargs via **kwargs parameter
|
||||
|
||||
When to use workflow.as_agent():
|
||||
- To treat an entire workflow orchestration as a single agent
|
||||
- To compose workflows into higher-level orchestrations
|
||||
- To maintain a consistent agent interface for callers
|
||||
|
||||
Prerequisites:
|
||||
- OpenAI environment variables configured
|
||||
"""
|
||||
|
||||
|
||||
# Define tools that accept custom context via **kwargs
|
||||
@ai_function
|
||||
def get_user_data(
|
||||
query: Annotated[str, Field(description="What user data to retrieve")],
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Retrieve user-specific data based on the authenticated context."""
|
||||
user_token = kwargs.get("user_token", {})
|
||||
user_name = user_token.get("user_name", "anonymous")
|
||||
access_level = user_token.get("access_level", "none")
|
||||
|
||||
print(f"\n[get_user_data] Received kwargs keys: {list(kwargs.keys())}")
|
||||
print(f"[get_user_data] User: {user_name}")
|
||||
print(f"[get_user_data] Access level: {access_level}")
|
||||
|
||||
return f"Retrieved data for user {user_name} with {access_level} access: {query}"
|
||||
|
||||
|
||||
@ai_function
|
||||
def call_api(
|
||||
endpoint_name: Annotated[str, Field(description="Name of the API endpoint to call")],
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call an API using the configured endpoints from custom_data."""
|
||||
custom_data = kwargs.get("custom_data", {})
|
||||
api_config = custom_data.get("api_config", {})
|
||||
|
||||
base_url = api_config.get("base_url", "unknown")
|
||||
endpoints = api_config.get("endpoints", {})
|
||||
|
||||
print(f"\n[call_api] Received kwargs keys: {list(kwargs.keys())}")
|
||||
print(f"[call_api] Base URL: {base_url}")
|
||||
print(f"[call_api] Available endpoints: {list(endpoints.keys())}")
|
||||
|
||||
if endpoint_name in endpoints:
|
||||
return f"Called {base_url}{endpoints[endpoint_name]} successfully"
|
||||
return f"Endpoint '{endpoint_name}' not found in configuration"
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
print("=" * 70)
|
||||
print("Workflow as Agent kwargs Flow Demo")
|
||||
print("=" * 70)
|
||||
|
||||
# Create chat client
|
||||
chat_client = OpenAIChatClient()
|
||||
|
||||
# Create agent with tools that use kwargs
|
||||
agent = chat_client.create_agent(
|
||||
name="assistant",
|
||||
instructions=(
|
||||
"You are a helpful assistant. Use the available tools to help users. "
|
||||
"When asked about user data, use get_user_data. "
|
||||
"When asked to call an API, use call_api."
|
||||
),
|
||||
tools=[get_user_data, call_api],
|
||||
)
|
||||
|
||||
# Build a sequential workflow
|
||||
workflow = SequentialBuilder().participants([agent]).build()
|
||||
|
||||
# Expose the workflow as an agent using .as_agent()
|
||||
workflow_agent = workflow.as_agent(name="WorkflowAgent")
|
||||
|
||||
# Define custom context that will flow to ai_functions via kwargs
|
||||
custom_data = {
|
||||
"api_config": {
|
||||
"base_url": "https://api.example.com",
|
||||
"endpoints": {
|
||||
"users": "/v1/users",
|
||||
"orders": "/v1/orders",
|
||||
"products": "/v1/products",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
user_token = {
|
||||
"user_name": "bob@contoso.com",
|
||||
"access_level": "admin",
|
||||
}
|
||||
|
||||
print("\nCustom Data being passed:")
|
||||
print(json.dumps(custom_data, indent=2))
|
||||
print(f"\nUser: {user_token['user_name']}")
|
||||
print("\n" + "-" * 70)
|
||||
print("Workflow Agent Execution (watch for [tool_name] logs showing kwargs received):")
|
||||
print("-" * 70)
|
||||
|
||||
# Run workflow agent with kwargs - these will flow through to ai_functions
|
||||
# Note: kwargs are passed to workflow_agent.run_stream() just like workflow.run_stream()
|
||||
print("\n===== Streaming Response =====")
|
||||
async for update in workflow_agent.run_stream(
|
||||
"Please get my user data and then call the users API endpoint.",
|
||||
custom_data=custom_data,
|
||||
user_token=user_token,
|
||||
):
|
||||
if update.text:
|
||||
print(update.text, end="", flush=True)
|
||||
print()
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("Sample Complete")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user