Fix handoff workflow context management and improve AG-UI demo (#5136)

This commit is contained in:
Evan Mattson
2026-04-08 13:08:24 +09:00
committed by GitHub
Unverified
parent f94a75daa5
commit e10d448ae2
19 changed files with 601 additions and 252 deletions
@@ -38,17 +38,15 @@ from copy import deepcopy
from dataclasses import dataclass
from typing import Any
from agent_framework import Agent, SupportsAgentRun
from agent_framework import Agent, AgentResponse, Message, SupportsAgentRun
from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware, MiddlewareTermination
from agent_framework._sessions import AgentSession
from agent_framework._tools import FunctionTool, tool
from agent_framework._types import AgentResponse, Content, Message
from agent_framework._workflows._agent_executor import AgentExecutor, AgentExecutorRequest
from agent_framework._workflows._agent_utils import resolve_agent_id
from agent_framework._workflows._checkpoint import CheckpointStorage
from agent_framework._workflows._events import WorkflowEvent
from agent_framework._workflows._request_info_mixin import response_handler
from agent_framework._workflows._typing_utils import is_chat_agent
from agent_framework._workflows._workflow import Workflow
from agent_framework._workflows._workflow_builder import WorkflowBuilder
from agent_framework._workflows._workflow_context import WorkflowContext
@@ -263,88 +261,6 @@ class HandoffAgentExecutor(AgentExecutor):
return cloned_agent
def _persist_pending_approval_function_calls(self) -> None:
"""Persist pending approval function calls for stateless provider resumes.
Handoff workflows force ``store=False`` and replay conversation state from ``_full_conversation``.
When a run pauses on function approval, ``AgentExecutor`` returns ``None`` and the assistant
function-call message is not returned as an ``AgentResponse``. Without persisting that call, the
next turn may submit only a function result, which responses-style APIs reject.
"""
pending_calls: list[Content] = []
for request in self._pending_agent_requests.values():
if request.type != "function_approval_request":
continue
function_call = getattr(request, "function_call", None)
if isinstance(function_call, Content) and function_call.type == "function_call":
pending_calls.append(function_call)
if not pending_calls:
return
self._full_conversation.append(
Message(
role="assistant",
contents=pending_calls,
author_name=self._agent.name,
)
)
def _persist_missing_approved_function_results(
self,
*,
runtime_tool_messages: list[Message],
response_messages: list[Message],
) -> None:
"""Persist fallback function_result entries for approved calls when missing.
In approval resumes, function invocation can execute approved tools without
always surfacing those tool outputs in the returned ``AgentResponse.messages``.
For stateless handoff replays, we must keep call/output pairs balanced.
"""
candidate_results: dict[str, Content] = {}
for message in runtime_tool_messages:
for content in message.contents:
if content.type == "function_result":
call_id = getattr(content, "call_id", None)
if isinstance(call_id, str) and call_id:
candidate_results[call_id] = content
continue
if content.type != "function_approval_response" or not content.approved:
continue
function_call = getattr(content, "function_call", None)
call_id = getattr(function_call, "call_id", None) or getattr(content, "id", None)
if isinstance(call_id, str) and call_id and call_id not in candidate_results:
# Fallback content for approved calls when runtime messages do not include
# a concrete function_result payload.
candidate_results[call_id] = Content.from_function_result(
call_id=call_id,
result='{"status":"approved"}',
)
if not candidate_results:
return
observed_result_call_ids: set[str] = set()
for message in [*self._full_conversation, *response_messages]:
for content in message.contents:
if content.type == "function_result" and isinstance(content.call_id, str) and content.call_id:
observed_result_call_ids.add(content.call_id)
missing_call_ids = sorted(set(candidate_results.keys()) - observed_result_call_ids)
if not missing_call_ids:
return
self._full_conversation.append(
Message(
role="tool",
contents=[candidate_results[call_id] for call_id in missing_call_ids],
author_name=self._agent.name,
)
)
def _clone_chat_agent(self, agent: Agent[Any]) -> Agent[Any]:
"""Produce a deep copy of the Agent while preserving runtime configuration."""
options = agent.default_options
@@ -360,7 +276,6 @@ class HandoffAgentExecutor(AgentExecutor):
cloned_options = deepcopy(options)
# Disable parallel tool calls to prevent the agent from invoking multiple handoff tools at once.
cloned_options["allow_multiple_tool_calls"] = False
cloned_options["store"] = False
cloned_options["tools"] = new_tools
# restore the original tools, in case they are shared between agents
@@ -426,45 +341,15 @@ class HandoffAgentExecutor(AgentExecutor):
@override
async def _run_agent_and_emit(self, ctx: WorkflowContext[Any, Any]) -> None:
"""Override to support handoff."""
incoming_messages = list(self._cache)
cleaned_incoming_messages = clean_conversation_for_handoff(incoming_messages)
runtime_tool_messages = [
message
for message in incoming_messages
if any(
content.type
in {
"function_result",
"function_approval_response",
}
for content in message.contents
)
or message.role == "tool"
]
# When the full conversation is empty, it means this is the first run.
# Broadcast the initial cache to all other agents. Subsequent runs won't
# need this since responses are broadcast after each agent run and user input.
if self._is_start_agent and not self._full_conversation:
await self._broadcast_messages(cleaned_incoming_messages, ctx)
await self._broadcast_messages(self._cache.copy(), ctx)
# Persist only cleaned chat history between turns to avoid replaying stale tool calls.
self._full_conversation.extend(cleaned_incoming_messages)
# Always run with full conversation context for request_info resumes.
# Keep runtime tool-control messages for this run only (e.g., approval responses).
self._cache = list(self._full_conversation)
self._cache.extend(runtime_tool_messages)
# Handoff workflows are orchestrator-stateful and provider-stateless by design.
# If an existing session still has a service conversation id, clear it to avoid
# replaying stale unresolved tool calls across resumed turns.
if (
is_chat_agent(self._agent)
and self._agent.default_options.get("store") is False
and self._session.service_session_id is not None
):
self._session.service_session_id = None
# Full conversation maintains the chat history between agents across handoffs,
# excluding internal agent messages such as tool calls and results.
self._full_conversation.extend(self._cache.copy())
# Check termination condition before running the agent
if await self._check_terminate_and_yield(ctx):
@@ -483,36 +368,35 @@ class HandoffAgentExecutor(AgentExecutor):
# A function approval request is issued by the base AgentExecutor
if response is None:
if is_chat_agent(self._agent) and self._agent.default_options.get("store") is False:
self._persist_pending_approval_function_calls()
# Agent did not complete (e.g., waiting for user input); do not emit response
logger.debug("AgentExecutor %s: Agent did not complete, awaiting user input", self.id)
return
# Remove function call related content from the agent response for broadcast.
# This prevents replaying stale tool artifacts to other agents.
# Remove function call related content from the agent response for full conversation history
cleaned_response = clean_conversation_for_handoff(response.messages)
# For internal tracking, preserve the full response (including function_calls)
# in _full_conversation so that Azure OpenAI can match function_calls with
# function_results when the workflow resumes after user approvals.
self._full_conversation.extend(response.messages)
self._persist_missing_approved_function_results(
runtime_tool_messages=runtime_tool_messages,
response_messages=response.messages,
)
# Append the agent response to the full conversation history. This list removes
# function call related content such that the result stays consistent regardless
# of which agent yields the final output.
self._full_conversation.extend(cleaned_response)
# Broadcast only the cleaned response to other agents (without function_calls/results)
await self._broadcast_messages(cleaned_response, ctx)
# Check if a handoff was requested
if handoff_target := self._is_handoff_requested(response):
if is_handoff_requested := self._is_handoff_requested(response):
handoff_target, handoff_message = is_handoff_requested
if handoff_target not in self._handoff_targets:
raise ValueError(
f"Agent '{resolve_agent_id(self._agent)}' attempted to handoff to unknown "
f"target '{handoff_target}'. Valid targets are: {', '.join(self._handoff_targets)}"
)
# Add the handoff message to the cache so that the next invocation of the agent includes
# the tool call result. This is necessary because each tool call must have a corresponding
# tool result.
self._cache.append(handoff_message)
await ctx.send_message(
AgentExecutorRequest(messages=[], should_respond=True),
target_id=handoff_target,
@@ -589,12 +473,25 @@ class HandoffAgentExecutor(AgentExecutor):
# Since all agents are connected via fan-out, we can directly send the message
await ctx.send_message(agent_executor_request)
def _is_handoff_requested(self, response: AgentResponse) -> str | None:
def _is_handoff_requested(self, response: AgentResponse) -> tuple[str, Message] | None:
"""Determine if the agent response includes a handoff request.
If a handoff tool is invoked, the middleware will short-circuit execution
and provide a synthetic result that includes the target agent ID. The message
that contains the function result will be the last message in the response.
Args:
response: The AgentResponse to inspect for handoff requests
Returns:
A tuple of (target_agent_id, message) if a handoff is requested, or None if no handoff is requested
Note:
The returned message is the full message that contains the handoff function result content. This is
needed to complete the agent's chat history due to the `_AutoHandoffMiddleware` short-circuiting
behavior, which prevents the handoff tool call and result from being included in the agent response
messages. By returning the full message, we can ensure the agent's chat history remains valid with
a function result for the handoff tool call.
"""
if not response.messages:
return None
@@ -617,7 +514,7 @@ class HandoffAgentExecutor(AgentExecutor):
if parsed_payload:
handoff_target = parsed_payload.get(HANDOFF_FUNCTION_RESULT_KEY)
if isinstance(handoff_target, str):
return handoff_target
return handoff_target, last_message
else:
continue
@@ -1034,6 +931,25 @@ class HandoffBuilder:
# Resolve agents (either from instances or factories)
# The returned map keys are either executor IDs or factory names, which is need to resolve handoff configs
resolved_agents = self._resolve_agents()
# Validate that all agents have require_per_service_call_history_persistence enabled.
# Handoff workflows use middleware that short-circuits tool calls (MiddlewareTermination),
# which means the service never sees those tool results. Without per-service-call
# history persistence, local history providers would persist tool results that
# the service has no record of, causing call/result mismatches on subsequent turns.
agents_missing_flag = [
resolve_agent_id(agent)
for agent in resolved_agents.values()
if not agent.require_per_service_call_history_persistence
]
if agents_missing_flag:
raise ValueError(
f"Handoff workflows require all participant agents to have "
f"'require_per_service_call_history_persistence=True'. "
f"The following agents are missing this setting: {', '.join(agents_missing_flag)}. "
f"Set this flag when constructing each Agent to ensure local history stays "
f"consistent with the service across handoff tool-call short-circuits."
)
# Resolve handoff configurations to use agent display names
# The returned map keys are executor IDs
resolved_handoffs = self._resolve_handoffs(resolved_agents)
@@ -24,6 +24,11 @@ def clean_conversation_for_handoff(conversation: list[Message]) -> list[Message]
- Drops all non-text content from every message.
- Drops messages with no remaining text content.
- Preserves original roles and author names for retained text messages.
Args:
conversation: Full conversation history, including tool-control content
Returns:
Cleaned conversation history with only text content, suitable for handoff routing
"""
cleaned: list[Message] = []
for msg in conversation:
@@ -31,6 +36,8 @@ def clean_conversation_for_handoff(conversation: list[Message]) -> list[Message]
# (function_call/function_result/approval payloads) is runtime-only and
# must not be replayed in future model turns.
text_parts = [content.text for content in msg.contents if content.type == "text" and content.text]
# TODO(@taochen): This is a simplified check that considers any non-text content as a tool call.
# We need to enhance this logic to specifically identify tool related contents.
if not text_parts:
continue
@@ -1,8 +1,9 @@
# Copyright (c) Microsoft. All rights reserved.
import os
import re
from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence
from typing import Any, cast
from typing import Annotated, Any, cast
from unittest.mock import AsyncMock, MagicMock
import pytest
@@ -16,13 +17,15 @@ from agent_framework import (
Message,
ResponseStream,
WorkflowEvent,
WorkflowRunState,
resolve_agent_id,
tool,
)
from agent_framework._clients import BaseChatClient
from agent_framework._middleware import ChatMiddlewareLayer, FunctionInvocationContext, MiddlewareTermination
from agent_framework._tools import FunctionInvocationLayer, FunctionTool
from agent_framework.orchestrations import HandoffAgentUserRequest, HandoffBuilder
from agent_framework.orchestrations import HandoffAgentUserRequest, HandoffBuilder, HandoffSentEvent
from pytest import param
from agent_framework_orchestrations._handoff import (
HANDOFF_FUNCTION_RESULT_KEY,
@@ -34,6 +37,7 @@ from agent_framework_orchestrations._handoff import (
from agent_framework_orchestrations._orchestrator_helpers import clean_conversation_for_handoff
# region unit tests
class MockChatClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
"""Mock chat client for testing handoff workflows."""
@@ -132,7 +136,12 @@ class MockHandoffAgent(Agent):
handoff_to: The name of the agent to hand off to, or None for no handoff.
This is hardcoded for testing purposes so that the agent always attempts to hand off.
"""
super().__init__(client=MockChatClient(name=name, handoff_to=handoff_to), name=name, id=name)
super().__init__(
client=MockChatClient(name=name, handoff_to=handoff_to),
name=name,
id=name,
require_per_service_call_history_persistence=True,
)
class ContextAwareRefundClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
@@ -255,6 +264,7 @@ async def test_resume_keeps_prior_user_context_for_same_agent() -> None:
id="refund_agent",
name="refund_agent",
client=ContextAwareRefundClient(),
require_per_service_call_history_persistence=True,
)
workflow = (
HandoffBuilder(participants=[refund_agent], termination_condition=lambda _: False)
@@ -352,6 +362,7 @@ async def test_tool_approval_responses_are_not_replayed_from_history() -> None:
name="refund_agent",
client=ApprovalReplayClient(),
tools=[submit_refund_counted],
require_per_service_call_history_persistence=True,
)
workflow = (
HandoffBuilder(participants=[agent], termination_condition=lambda _: False).with_start_agent(agent).build()
@@ -455,6 +466,7 @@ async def test_handoff_resume_preserves_approval_function_call_for_stateless_run
name="refund_agent",
client=client,
tools=[submit_refund],
require_per_service_call_history_persistence=True,
)
workflow = (
HandoffBuilder(participants=[agent], termination_condition=lambda _: False).with_start_agent(agent).build()
@@ -524,11 +536,13 @@ async def test_handoff_replay_serializes_handoff_function_results() -> None:
id="triage",
name="triage",
client=ReplaySafeHandoffClient(name="triage", handoff_sequence=["specialist", None]),
require_per_service_call_history_persistence=True,
)
specialist = Agent(
id="specialist",
name="specialist",
client=ReplaySafeHandoffClient(name="specialist", handoff_sequence=["triage"]),
require_per_service_call_history_persistence=True,
)
workflow = (
@@ -652,11 +666,13 @@ async def test_handoff_resume_preserves_approved_tool_output_for_stateless_runs(
name="refund_agent",
client=refund_client,
tools=[submit_refund],
require_per_service_call_history_persistence=True,
)
order_agent = Agent(
id="order_agent",
name="order_agent",
client=OrderReplayClient(),
require_per_service_call_history_persistence=True,
)
workflow = (
HandoffBuilder(participants=[refund_agent, order_agent], termination_condition=lambda _: False)
@@ -686,16 +702,6 @@ async def test_handoff_resume_preserves_approved_tool_output_for_stateless_runs(
assert refund_client.resume_validated is True
def test_handoff_clone_disables_provider_side_storage() -> None:
"""Handoff executors should force store=False to avoid stale provider call state."""
triage = MockHandoffAgent(name="triage")
workflow = HandoffBuilder(participants=[triage]).with_start_agent(triage).build()
executor = workflow.executors[resolve_agent_id(triage)]
assert isinstance(executor, HandoffAgentExecutor)
assert executor._agent.default_options.get("store") is False
async def test_handoff_clone_preserves_per_service_call_history_persistence() -> None:
"""Handoff clones should keep per-service-call history persistence active for auto-handoff termination."""
triage_history = InMemoryHistoryProvider()
@@ -711,6 +717,7 @@ async def test_handoff_clone_preserves_per_service_call_history_persistence() ->
name="specialist",
client=MockChatClient(name="specialist"),
default_options={"tool_choice": "none"},
require_per_service_call_history_persistence=True,
)
workflow = (
@@ -738,21 +745,6 @@ async def test_handoff_clone_preserves_per_service_call_history_persistence() ->
assert all(message.role != "tool" for message in stored_messages)
async def test_handoff_clears_stale_service_session_id_before_run() -> None:
"""Stale service session IDs must be dropped before each handoff agent turn."""
triage = MockHandoffAgent(name="triage", handoff_to="specialist")
specialist = MockHandoffAgent(name="specialist")
workflow = HandoffBuilder(participants=[triage, specialist]).with_start_agent(triage).build()
triage_executor = workflow.executors[resolve_agent_id(triage)]
assert isinstance(triage_executor, HandoffAgentExecutor)
triage_executor._session.service_session_id = "resp_stale_value"
await _drain(workflow.run("My order is damaged", stream=True))
assert triage_executor._session.service_session_id is None
def test_clean_conversation_for_handoff_keeps_text_only_history() -> None:
"""Tool-control messages must be excluded from persisted handoff history."""
function_call = Content.from_function_call(
@@ -791,52 +783,6 @@ def test_clean_conversation_for_handoff_keeps_text_only_history() -> None:
]
def test_persist_missing_approved_function_results_handles_runtime_and_fallback_outputs() -> None:
"""Persisted history should retain approved call outputs across runtime shapes."""
agent = MockHandoffAgent(name="triage")
executor = HandoffAgentExecutor(agent, handoffs=[])
call_with_runtime_result = "call-runtime-result"
call_with_approval_only = "call-approval-only"
executor._full_conversation = [
Message(
role="assistant",
contents=[
Content.from_function_call(call_id=call_with_runtime_result, name="submit_refund", arguments={}),
Content.from_function_call(call_id=call_with_approval_only, name="submit_refund", arguments={}),
],
)
]
approval_response = Content.from_function_approval_response(
approved=True,
id=call_with_approval_only,
function_call=Content.from_function_call(call_id=call_with_approval_only, name="submit_refund", arguments={}),
)
runtime_messages = [
Message(
role="tool",
contents=[Content.from_function_result(call_id=call_with_runtime_result, result='{"submitted":true}')],
),
Message(role="user", contents=[approval_response]),
]
executor._persist_missing_approved_function_results(runtime_tool_messages=runtime_messages, response_messages=[])
persisted_tool_messages = [message for message in executor._full_conversation if message.role == "tool"]
assert persisted_tool_messages
persisted_results = [
content
for message in persisted_tool_messages
for content in message.contents
if content.type == "function_result" and content.call_id
]
result_by_call_id = {content.call_id: content.result for content in persisted_results}
assert result_by_call_id[call_with_runtime_result] == '{"submitted":true}'
assert result_by_call_id[call_with_approval_only] == '{"status":"approved"}'
async def test_autonomous_mode_yields_output_without_user_request():
"""Ensure autonomous interaction mode yields output without requesting user input."""
triage = MockHandoffAgent(name="triage", handoff_to="specialist")
@@ -979,7 +925,12 @@ async def test_handoff_terminates_without_request_info_when_latest_response_meet
return _get()
agent = Agent(id="order_agent", name="order_agent", client=FinalizingClient())
agent = Agent(
id="order_agent",
name="order_agent",
client=FinalizingClient(),
require_per_service_call_history_persistence=True,
)
workflow = (
HandoffBuilder(
participants=[agent],
@@ -1061,6 +1012,7 @@ async def test_context_provider_preserved_during_handoff():
name="test_agent",
id="test_agent",
context_providers=[context_provider],
require_per_service_call_history_persistence=True,
)
# Verify the original agent has the context provider
@@ -1104,8 +1056,8 @@ async def test_auto_handoff_middleware_intercepts_handoff_tool_call() -> None:
middleware = _AutoHandoffMiddleware([HandoffConfiguration(target=target_id)])
@tool(name=get_handoff_tool_name(target_id), approval_mode="never_require")
def handoff_tool() -> str:
return "unreachable"
def handoff_tool() -> None:
pass
context = FunctionInvocationContext(function=handoff_tool, arguments={})
call_next = AsyncMock()
@@ -1136,6 +1088,20 @@ async def test_auto_handoff_middleware_calls_next_for_non_handoff_tool() -> None
assert context.result is None
def test_handoff_builder_rejects_agents_without_per_service_call_history_persistence() -> None:
"""HandoffBuilder.build() should reject agents missing require_per_service_call_history_persistence."""
agent_without_flag = Agent(
client=MockChatClient(name="no_flag"),
name="no_flag",
id="no_flag",
# require_per_service_call_history_persistence defaults to False
)
agent_with_flag = MockHandoffAgent(name="has_flag") # MockHandoffAgent sets flag to True
with pytest.raises(ValueError, match="require_per_service_call_history_persistence"):
HandoffBuilder(participants=[agent_without_flag, agent_with_flag]).with_start_agent(agent_with_flag).build()
def test_handoff_builder_rejects_non_agent_supports_agent_run():
"""Verify that participants() rejects SupportsAgentRun implementations that are not Agent instances."""
from agent_framework import AgentResponse, AgentSession, SupportsAgentRun
@@ -1160,3 +1126,246 @@ def test_handoff_builder_rejects_non_agent_supports_agent_run():
with pytest.raises(TypeError, match="Participants must be Agent instances"):
HandoffBuilder().participants([fake])
# endregion
# region integration tests
try:
from agent_framework.foundry import FoundryChatClient
from azure.identity import AzureCliCredential
_has_foundry_deps = True
except ImportError:
_has_foundry_deps = False
skip_if_foundry_integration_tests_disabled = pytest.mark.skipif(
not _has_foundry_deps or os.getenv("FOUNDRY_PROJECT_ENDPOINT", "") == "" or os.getenv("FOUNDRY_MODEL", "") == "",
reason="No real FOUNDRY_PROJECT_ENDPOINT or FOUNDRY_MODEL provided; skipping integration tests.",
)
@pytest.mark.integration
@skip_if_foundry_integration_tests_disabled
@pytest.mark.parametrize("store", [param(False, id="store=False"), param(True, id="store=True")])
async def test_simple_handoff_workflow(store: bool) -> None:
"""Test a simple handoff workflow with two agents."""
client = FoundryChatClient(
project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"],
model=os.environ["FOUNDRY_MODEL"],
credential=AzureCliCredential(),
)
triage_agent = Agent(
client=client,
instructions=(
"You are frontline support triage. Route customer issues to the appropriate specialist agents "
"based on the problem described."
),
name="triage_agent",
default_options={"store": store},
require_per_service_call_history_persistence=True,
)
refund_agent = Agent(
client=client,
instructions="You process refund requests. Ask user the ID of the order they want refunded.",
name="refund_agent",
default_options={"store": store},
require_per_service_call_history_persistence=True,
)
workflow = (
HandoffBuilder(
participants=[triage_agent, refund_agent],
termination_condition=lambda conversation: (
# We terminate after triage hands off to refund to test handoff works
len(conversation) > 0 and conversation[-1].author_name == refund_agent.name
),
)
.with_start_agent(triage_agent)
.build()
)
workflow_result = await workflow.run("I want to get a refund")
# The workflow should end in IDLE state rather than IDLE_WITH_PENDING_REQUESTS
# because the termination condition is met right after the refund agent's response.
assert workflow_result.get_final_state() == WorkflowRunState.IDLE
# Output should contain responses from both agents and a final full conversation from between them.
assert len(workflow_result.get_outputs()) == 3
# There will be exactly one handoff request
handoff_event = [event for event in workflow_result if event.type == "handoff_sent"]
assert len(handoff_event) == 1
assert isinstance(handoff_event[0].data, HandoffSentEvent)
assert handoff_event[0].data.source == triage_agent.name
assert handoff_event[0].data.target == refund_agent.name
@pytest.mark.integration
@skip_if_foundry_integration_tests_disabled
@pytest.mark.parametrize("store", [param(False, id="store=False"), param(True, id="store=True")])
async def test_simple_handoff_workflow_with_request_and_response(store: bool) -> None:
"""Test a simple handoff workflow with two agents where the second agent makes a request after handoff."""
client = FoundryChatClient(
project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"],
model=os.environ["FOUNDRY_MODEL"],
credential=AzureCliCredential(),
)
triage_agent = Agent(
client=client,
instructions=(
"You are frontline support triage. Route customer issues to the appropriate specialist agents "
"based on the problem described."
),
name="triage_agent",
default_options={"store": store},
require_per_service_call_history_persistence=True,
)
refund_agent = Agent(
client=client,
instructions="You process refund requests. Ask user the ID of the order they want refunded.",
name="refund_agent",
default_options={"store": store},
require_per_service_call_history_persistence=True,
)
workflow = (
HandoffBuilder(
participants=[triage_agent, refund_agent],
termination_condition=lambda conversation: (
# We terminate after the refund agent request user input and the user provides
# a response. There will be two user messages in the conversation at that point
# - the original user message and the follow-up message in response to the refund
# agent's request.
len([message for message in conversation if message.role == "user"]) == 2
),
)
.with_start_agent(triage_agent)
.build()
)
workflow_result = await workflow.run("I want to get a refund")
# The workflow should end in IDLE_WITH_PENDING_REQUESTS state rather than IDLE
# because the user has not yet responded to the refund agent's request yet.
assert workflow_result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS
# There will be exactly one handoff request
handoff_event = [event for event in workflow_result if event.type == "handoff_sent"]
assert len(handoff_event) == 1
assert isinstance(handoff_event[0].data, HandoffSentEvent)
assert handoff_event[0].data.source == triage_agent.name
assert handoff_event[0].data.target == refund_agent.name
# There should be exactly one request for information from the refund agent after handoff
request_events = [event for event in workflow_result if event.type == "request_info"]
assert len(request_events) == 1
assert isinstance(request_events[0].data, HandoffAgentUserRequest)
# Provide the user's response to the refund agent's request to allow the workflow to complete.
workflow_result = await workflow.run(
responses={
request_events[0].request_id: HandoffAgentUserRequest.create_response("My order number is 12345"),
},
)
# The workflow should now end in IDLE state since the termination condition
# is met after the user's response to the refund agent's request.
assert workflow_result.get_final_state() == WorkflowRunState.IDLE
@tool(approval_mode="always_require")
def process_refund(order_number: Annotated[str, "Order number to process refund for"]) -> str:
"""Simulated function to process a refund for a given order number."""
return f"Refund processed successfully for order {order_number}."
@pytest.mark.integration
@skip_if_foundry_integration_tests_disabled
@pytest.mark.parametrize("store", [param(False, id="store=False"), param(True, id="store=True")])
async def test_simple_handoff_workflow_with_approval_request(store: bool) -> None:
"""Test a simple handoff workflow with two agents where the second agent makes a request after handoff."""
client = FoundryChatClient(
project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"],
model=os.environ["FOUNDRY_MODEL"],
credential=AzureCliCredential(),
)
triage_agent = Agent(
client=client,
instructions=(
"You are frontline support triage. Route customer issues to the appropriate specialist agents "
"based on the problem described."
),
name="triage_agent",
default_options={"store": store},
require_per_service_call_history_persistence=True,
)
refund_agent = Agent(
client=client,
instructions="You process refund requests. Ask user the ID of the order they want refunded.",
name="refund_agent",
default_options={"store": store},
tools=[process_refund],
require_per_service_call_history_persistence=True,
)
# This workflow will be terminated manually
workflow = (
HandoffBuilder(
participants=[triage_agent, refund_agent],
)
.with_start_agent(triage_agent)
.build()
)
workflow_result = await workflow.run("I want to get a refund")
# The workflow should end in IDLE_WITH_PENDING_REQUESTS state rather than IDLE
# because the user has not yet responded to the refund agent's request yet.
assert workflow_result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS
# There will be exactly one handoff request
handoff_event = [event for event in workflow_result if event.type == "handoff_sent"]
assert len(handoff_event) == 1
assert isinstance(handoff_event[0].data, HandoffSentEvent)
assert handoff_event[0].data.source == triage_agent.name
assert handoff_event[0].data.target == refund_agent.name
# There should be exactly one request for information from the refund agent after handoff
request_events = [event for event in workflow_result if event.type == "request_info"]
assert len(request_events) == 1
assert isinstance(request_events[0].data, HandoffAgentUserRequest)
# Provide the user's response to the refund agent's request to allow the workflow to complete.
workflow_result = await workflow.run(
responses={
request_events[0].request_id: HandoffAgentUserRequest.create_response("My order number is 12345"),
},
)
# The workflow should now end in IDLE_WITH_PENDING_REQUESTS state since the refund agent will ask for
# approval to process the refund after receiving the user's response.
assert workflow_result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS
# There should be exactly one request for tool approval from the refund agent.
request_events = [event for event in workflow_result if event.type == "request_info"]
assert len(request_events) == 1
assert isinstance(request_events[0].data, Content) and request_events[0].data.type == "function_approval_request"
# Provide the user's response to the refund agent's request to allow the workflow to complete.
workflow_result = await workflow.run(
responses={request_events[0].request_id: request_events[0].data.to_function_approval_response(approved=True)}
)
# The refund agent will process the refund after receiving approval, but since there is no termination condition,
# the workflow will end in IDLE_WITH_PENDING_REQUESTS state waiting for further user input.
assert workflow_result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS
# There should be exactly one request for information from the refund agent after processing the refund,
# which is the follow-up question asking if there is anything else they can help with.
request_events = [event for event in workflow_result if event.type == "request_info"]
assert len(request_events) == 1
assert isinstance(request_events[0].data, HandoffAgentUserRequest)
workflow_result = await workflow.run(responses={request_events[0].request_id: HandoffAgentUserRequest.terminate()})
assert workflow_result.get_final_state() == WorkflowRunState.IDLE
# endregion