mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Fix handoff workflow context management and improve AG-UI demo (#5136)
This commit is contained in:
committed by
GitHub
Unverified
parent
f94a75daa5
commit
e10d448ae2
@@ -38,17 +38,15 @@ from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import Agent, SupportsAgentRun
|
||||
from agent_framework import Agent, AgentResponse, Message, SupportsAgentRun
|
||||
from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware, MiddlewareTermination
|
||||
from agent_framework._sessions import AgentSession
|
||||
from agent_framework._tools import FunctionTool, tool
|
||||
from agent_framework._types import AgentResponse, Content, Message
|
||||
from agent_framework._workflows._agent_executor import AgentExecutor, AgentExecutorRequest
|
||||
from agent_framework._workflows._agent_utils import resolve_agent_id
|
||||
from agent_framework._workflows._checkpoint import CheckpointStorage
|
||||
from agent_framework._workflows._events import WorkflowEvent
|
||||
from agent_framework._workflows._request_info_mixin import response_handler
|
||||
from agent_framework._workflows._typing_utils import is_chat_agent
|
||||
from agent_framework._workflows._workflow import Workflow
|
||||
from agent_framework._workflows._workflow_builder import WorkflowBuilder
|
||||
from agent_framework._workflows._workflow_context import WorkflowContext
|
||||
@@ -263,88 +261,6 @@ class HandoffAgentExecutor(AgentExecutor):
|
||||
|
||||
return cloned_agent
|
||||
|
||||
def _persist_pending_approval_function_calls(self) -> None:
    """Append still-pending approval function calls to the full conversation.

    Handoff workflows force ``store=False`` and replay conversation state from
    ``_full_conversation``. When a run pauses on function approval, ``AgentExecutor``
    returns ``None``, so the assistant's function-call message never arrives as an
    ``AgentResponse``. Recording the call here keeps the next turn from submitting a
    function result without its matching call, which responses-style APIs reject.

    Only pending requests of type ``function_approval_request`` whose ``function_call``
    is a ``Content`` of type ``function_call`` are recorded. When none qualify, the
    conversation is left untouched.
    """
    approval_requests = (
        pending for pending in self._pending_agent_requests.values() if pending.type == "function_approval_request"
    )
    requested_calls = (getattr(pending, "function_call", None) for pending in approval_requests)
    awaiting_approval: list[Content] = [
        call for call in requested_calls if isinstance(call, Content) and call.type == "function_call"
    ]

    # All qualifying calls go into a single assistant message attributed to this agent.
    if awaiting_approval:
        self._full_conversation.append(
            Message(
                role="assistant",
                contents=awaiting_approval,
                author_name=self._agent.name,
            )
        )
|
||||
|
||||
def _persist_missing_approved_function_results(
    self,
    *,
    runtime_tool_messages: list[Message],
    response_messages: list[Message],
) -> None:
    """Add function_result entries for approved calls that have no recorded result.

    In approval resumes, function invocation can execute approved tools without
    always surfacing those tool outputs in the returned ``AgentResponse.messages``.
    Stateless handoff replays need every call to be paired with an output.

    Args:
        runtime_tool_messages: Messages from the current run that may carry
            ``function_result`` or ``function_approval_response`` contents.
        response_messages: Messages returned by the agent for this run; results
            already present here are not persisted again.
    """
    fallbacks: dict[str, Content] = {}
    for runtime_message in runtime_tool_messages:
        for item in runtime_message.contents:
            kind = item.type
            if kind == "function_result":
                # A concrete result always wins, replacing any earlier entry for the same call.
                result_id = getattr(item, "call_id", None)
                if isinstance(result_id, str) and result_id:
                    fallbacks[result_id] = item
            elif kind == "function_approval_response" and item.approved:
                approved_call = getattr(item, "function_call", None)
                approved_id = getattr(approved_call, "call_id", None) or getattr(item, "id", None)
                if isinstance(approved_id, str) and approved_id and approved_id not in fallbacks:
                    # Placeholder payload for approved calls when the runtime messages
                    # carry no concrete function_result for them.
                    fallbacks[approved_id] = Content.from_function_result(
                        call_id=approved_id,
                        result='{"status":"approved"}',
                    )

    if not fallbacks:
        return

    # Call ids that already have a result in the persisted history or in this response.
    already_recorded: set[str] = {
        item.call_id
        for recorded in (*self._full_conversation, *response_messages)
        for item in recorded.contents
        if item.type == "function_result" and isinstance(item.call_id, str) and item.call_id
    }

    unpaired_ids = sorted(fallbacks.keys() - already_recorded)
    if unpaired_ids:
        self._full_conversation.append(
            Message(
                role="tool",
                contents=[fallbacks[call_id] for call_id in unpaired_ids],
                author_name=self._agent.name,
            )
        )
|
||||
|
||||
def _clone_chat_agent(self, agent: Agent[Any]) -> Agent[Any]:
|
||||
"""Produce a deep copy of the Agent while preserving runtime configuration."""
|
||||
options = agent.default_options
|
||||
@@ -360,7 +276,6 @@ class HandoffAgentExecutor(AgentExecutor):
|
||||
cloned_options = deepcopy(options)
|
||||
# Disable parallel tool calls to prevent the agent from invoking multiple handoff tools at once.
|
||||
cloned_options["allow_multiple_tool_calls"] = False
|
||||
cloned_options["store"] = False
|
||||
cloned_options["tools"] = new_tools
|
||||
|
||||
# restore the original tools, in case they are shared between agents
|
||||
@@ -426,45 +341,15 @@ class HandoffAgentExecutor(AgentExecutor):
|
||||
@override
|
||||
async def _run_agent_and_emit(self, ctx: WorkflowContext[Any, Any]) -> None:
|
||||
"""Override to support handoff."""
|
||||
incoming_messages = list(self._cache)
|
||||
cleaned_incoming_messages = clean_conversation_for_handoff(incoming_messages)
|
||||
runtime_tool_messages = [
|
||||
message
|
||||
for message in incoming_messages
|
||||
if any(
|
||||
content.type
|
||||
in {
|
||||
"function_result",
|
||||
"function_approval_response",
|
||||
}
|
||||
for content in message.contents
|
||||
)
|
||||
or message.role == "tool"
|
||||
]
|
||||
|
||||
# When the full conversation is empty, it means this is the first run.
|
||||
# Broadcast the initial cache to all other agents. Subsequent runs won't
|
||||
# need this since responses are broadcast after each agent run and user input.
|
||||
if self._is_start_agent and not self._full_conversation:
|
||||
await self._broadcast_messages(cleaned_incoming_messages, ctx)
|
||||
await self._broadcast_messages(self._cache.copy(), ctx)
|
||||
|
||||
# Persist only cleaned chat history between turns to avoid replaying stale tool calls.
|
||||
self._full_conversation.extend(cleaned_incoming_messages)
|
||||
|
||||
# Always run with full conversation context for request_info resumes.
|
||||
# Keep runtime tool-control messages for this run only (e.g., approval responses).
|
||||
self._cache = list(self._full_conversation)
|
||||
self._cache.extend(runtime_tool_messages)
|
||||
|
||||
# Handoff workflows are orchestrator-stateful and provider-stateless by design.
|
||||
# If an existing session still has a service conversation id, clear it to avoid
|
||||
# replaying stale unresolved tool calls across resumed turns.
|
||||
if (
|
||||
is_chat_agent(self._agent)
|
||||
and self._agent.default_options.get("store") is False
|
||||
and self._session.service_session_id is not None
|
||||
):
|
||||
self._session.service_session_id = None
|
||||
# Full conversation maintains the chat history between agents across handoffs,
|
||||
# excluding internal agent messages such as tool calls and results.
|
||||
self._full_conversation.extend(self._cache.copy())
|
||||
|
||||
# Check termination condition before running the agent
|
||||
if await self._check_terminate_and_yield(ctx):
|
||||
@@ -483,36 +368,35 @@ class HandoffAgentExecutor(AgentExecutor):
|
||||
|
||||
# A function approval request is issued by the base AgentExecutor
|
||||
if response is None:
|
||||
if is_chat_agent(self._agent) and self._agent.default_options.get("store") is False:
|
||||
self._persist_pending_approval_function_calls()
|
||||
# Agent did not complete (e.g., waiting for user input); do not emit response
|
||||
logger.debug("AgentExecutor %s: Agent did not complete, awaiting user input", self.id)
|
||||
return
|
||||
|
||||
# Remove function call related content from the agent response for broadcast.
|
||||
# This prevents replaying stale tool artifacts to other agents.
|
||||
# Remove function call related content from the agent response for full conversation history
|
||||
cleaned_response = clean_conversation_for_handoff(response.messages)
|
||||
|
||||
# For internal tracking, preserve the full response (including function_calls)
|
||||
# in _full_conversation so that Azure OpenAI can match function_calls with
|
||||
# function_results when the workflow resumes after user approvals.
|
||||
self._full_conversation.extend(response.messages)
|
||||
self._persist_missing_approved_function_results(
|
||||
runtime_tool_messages=runtime_tool_messages,
|
||||
response_messages=response.messages,
|
||||
)
|
||||
# Append the agent response to the full conversation history. This list removes
|
||||
# function call related content such that the result stays consistent regardless
|
||||
# of which agent yields the final output.
|
||||
self._full_conversation.extend(cleaned_response)
|
||||
|
||||
# Broadcast only the cleaned response to other agents (without function_calls/results)
|
||||
await self._broadcast_messages(cleaned_response, ctx)
|
||||
|
||||
# Check if a handoff was requested
|
||||
if handoff_target := self._is_handoff_requested(response):
|
||||
if is_handoff_requested := self._is_handoff_requested(response):
|
||||
handoff_target, handoff_message = is_handoff_requested
|
||||
if handoff_target not in self._handoff_targets:
|
||||
raise ValueError(
|
||||
f"Agent '{resolve_agent_id(self._agent)}' attempted to handoff to unknown "
|
||||
f"target '{handoff_target}'. Valid targets are: {', '.join(self._handoff_targets)}"
|
||||
)
|
||||
|
||||
# Add the handoff message to the cache so that the next invocation of the agent includes
|
||||
# the tool call result. This is necessary because each tool call must have a corresponding
|
||||
# tool result.
|
||||
self._cache.append(handoff_message)
|
||||
|
||||
await ctx.send_message(
|
||||
AgentExecutorRequest(messages=[], should_respond=True),
|
||||
target_id=handoff_target,
|
||||
@@ -589,12 +473,25 @@ class HandoffAgentExecutor(AgentExecutor):
|
||||
# Since all agents are connected via fan-out, we can directly send the message
|
||||
await ctx.send_message(agent_executor_request)
|
||||
|
||||
def _is_handoff_requested(self, response: AgentResponse) -> str | None:
|
||||
def _is_handoff_requested(self, response: AgentResponse) -> tuple[str, Message] | None:
|
||||
"""Determine if the agent response includes a handoff request.
|
||||
|
||||
If a handoff tool is invoked, the middleware will short-circuit execution
|
||||
and provide a synthetic result that includes the target agent ID. The message
|
||||
that contains the function result will be the last message in the response.
|
||||
|
||||
Args:
|
||||
response: The AgentResponse to inspect for handoff requests
|
||||
|
||||
Returns:
|
||||
A tuple of (target_agent_id, message) if a handoff is requested, or None if no handoff is requested
|
||||
|
||||
Note:
|
||||
The returned message is the full message that contains the handoff function result content. This is
|
||||
needed to complete the agent's chat history due to the `_AutoHandoffMiddleware` short-circuiting
|
||||
behavior, which prevents the handoff tool call and result from being included in the agent response
|
||||
messages. By returning the full message, we can ensure the agent's chat history remains valid with
|
||||
a function result for the handoff tool call.
|
||||
"""
|
||||
if not response.messages:
|
||||
return None
|
||||
@@ -617,7 +514,7 @@ class HandoffAgentExecutor(AgentExecutor):
|
||||
if parsed_payload:
|
||||
handoff_target = parsed_payload.get(HANDOFF_FUNCTION_RESULT_KEY)
|
||||
if isinstance(handoff_target, str):
|
||||
return handoff_target
|
||||
return handoff_target, last_message
|
||||
else:
|
||||
continue
|
||||
|
||||
@@ -1034,6 +931,25 @@ class HandoffBuilder:
|
||||
# Resolve agents (either from instances or factories)
|
||||
# The returned map keys are either executor IDs or factory names, which are needed to resolve handoff configs
|
||||
resolved_agents = self._resolve_agents()
|
||||
|
||||
# Validate that all agents have require_per_service_call_history_persistence enabled.
|
||||
# Handoff workflows use middleware that short-circuits tool calls (MiddlewareTermination),
|
||||
# which means the service never sees those tool results. Without per-service-call
|
||||
# history persistence, local history providers would persist tool results that
|
||||
# the service has no record of, causing call/result mismatches on subsequent turns.
|
||||
agents_missing_flag = [
|
||||
resolve_agent_id(agent)
|
||||
for agent in resolved_agents.values()
|
||||
if not agent.require_per_service_call_history_persistence
|
||||
]
|
||||
if agents_missing_flag:
|
||||
raise ValueError(
|
||||
f"Handoff workflows require all participant agents to have "
|
||||
f"'require_per_service_call_history_persistence=True'. "
|
||||
f"The following agents are missing this setting: {', '.join(agents_missing_flag)}. "
|
||||
f"Set this flag when constructing each Agent to ensure local history stays "
|
||||
f"consistent with the service across handoff tool-call short-circuits."
|
||||
)
|
||||
# Resolve handoff configurations to use agent display names
|
||||
# The returned map keys are executor IDs
|
||||
resolved_handoffs = self._resolve_handoffs(resolved_agents)
|
||||
|
||||
@@ -24,6 +24,11 @@ def clean_conversation_for_handoff(conversation: list[Message]) -> list[Message]
|
||||
- Drops all non-text content from every message.
|
||||
- Drops messages with no remaining text content.
|
||||
- Preserves original roles and author names for retained text messages.
|
||||
|
||||
Args:
|
||||
conversation: Full conversation history, including tool-control content
|
||||
Returns:
|
||||
Cleaned conversation history with only text content, suitable for handoff routing
|
||||
"""
|
||||
cleaned: list[Message] = []
|
||||
for msg in conversation:
|
||||
@@ -31,6 +36,8 @@ def clean_conversation_for_handoff(conversation: list[Message]) -> list[Message]
|
||||
# (function_call/function_result/approval payloads) is runtime-only and
|
||||
# must not be replayed in future model turns.
|
||||
text_parts = [content.text for content in msg.contents if content.type == "text" and content.text]
|
||||
# TODO(@taochen): This is a simplified check that considers any non-text content as a tool call.
|
||||
# We need to enhance this logic to specifically identify tool related contents.
|
||||
if not text_parts:
|
||||
continue
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import os
|
||||
import re
|
||||
from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
from typing import Annotated, Any, cast
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
@@ -16,13 +17,15 @@ from agent_framework import (
|
||||
Message,
|
||||
ResponseStream,
|
||||
WorkflowEvent,
|
||||
WorkflowRunState,
|
||||
resolve_agent_id,
|
||||
tool,
|
||||
)
|
||||
from agent_framework._clients import BaseChatClient
|
||||
from agent_framework._middleware import ChatMiddlewareLayer, FunctionInvocationContext, MiddlewareTermination
|
||||
from agent_framework._tools import FunctionInvocationLayer, FunctionTool
|
||||
from agent_framework.orchestrations import HandoffAgentUserRequest, HandoffBuilder
|
||||
from agent_framework.orchestrations import HandoffAgentUserRequest, HandoffBuilder, HandoffSentEvent
|
||||
from pytest import param
|
||||
|
||||
from agent_framework_orchestrations._handoff import (
|
||||
HANDOFF_FUNCTION_RESULT_KEY,
|
||||
@@ -34,6 +37,7 @@ from agent_framework_orchestrations._handoff import (
|
||||
from agent_framework_orchestrations._orchestrator_helpers import clean_conversation_for_handoff
|
||||
|
||||
|
||||
# region unit tests
|
||||
class MockChatClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
|
||||
"""Mock chat client for testing handoff workflows."""
|
||||
|
||||
@@ -132,7 +136,12 @@ class MockHandoffAgent(Agent):
|
||||
handoff_to: The name of the agent to hand off to, or None for no handoff.
|
||||
This is hardcoded for testing purposes so that the agent always attempts to hand off.
|
||||
"""
|
||||
super().__init__(client=MockChatClient(name=name, handoff_to=handoff_to), name=name, id=name)
|
||||
super().__init__(
|
||||
client=MockChatClient(name=name, handoff_to=handoff_to),
|
||||
name=name,
|
||||
id=name,
|
||||
require_per_service_call_history_persistence=True,
|
||||
)
|
||||
|
||||
|
||||
class ContextAwareRefundClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
|
||||
@@ -255,6 +264,7 @@ async def test_resume_keeps_prior_user_context_for_same_agent() -> None:
|
||||
id="refund_agent",
|
||||
name="refund_agent",
|
||||
client=ContextAwareRefundClient(),
|
||||
require_per_service_call_history_persistence=True,
|
||||
)
|
||||
workflow = (
|
||||
HandoffBuilder(participants=[refund_agent], termination_condition=lambda _: False)
|
||||
@@ -352,6 +362,7 @@ async def test_tool_approval_responses_are_not_replayed_from_history() -> None:
|
||||
name="refund_agent",
|
||||
client=ApprovalReplayClient(),
|
||||
tools=[submit_refund_counted],
|
||||
require_per_service_call_history_persistence=True,
|
||||
)
|
||||
workflow = (
|
||||
HandoffBuilder(participants=[agent], termination_condition=lambda _: False).with_start_agent(agent).build()
|
||||
@@ -455,6 +466,7 @@ async def test_handoff_resume_preserves_approval_function_call_for_stateless_run
|
||||
name="refund_agent",
|
||||
client=client,
|
||||
tools=[submit_refund],
|
||||
require_per_service_call_history_persistence=True,
|
||||
)
|
||||
workflow = (
|
||||
HandoffBuilder(participants=[agent], termination_condition=lambda _: False).with_start_agent(agent).build()
|
||||
@@ -524,11 +536,13 @@ async def test_handoff_replay_serializes_handoff_function_results() -> None:
|
||||
id="triage",
|
||||
name="triage",
|
||||
client=ReplaySafeHandoffClient(name="triage", handoff_sequence=["specialist", None]),
|
||||
require_per_service_call_history_persistence=True,
|
||||
)
|
||||
specialist = Agent(
|
||||
id="specialist",
|
||||
name="specialist",
|
||||
client=ReplaySafeHandoffClient(name="specialist", handoff_sequence=["triage"]),
|
||||
require_per_service_call_history_persistence=True,
|
||||
)
|
||||
|
||||
workflow = (
|
||||
@@ -652,11 +666,13 @@ async def test_handoff_resume_preserves_approved_tool_output_for_stateless_runs(
|
||||
name="refund_agent",
|
||||
client=refund_client,
|
||||
tools=[submit_refund],
|
||||
require_per_service_call_history_persistence=True,
|
||||
)
|
||||
order_agent = Agent(
|
||||
id="order_agent",
|
||||
name="order_agent",
|
||||
client=OrderReplayClient(),
|
||||
require_per_service_call_history_persistence=True,
|
||||
)
|
||||
workflow = (
|
||||
HandoffBuilder(participants=[refund_agent, order_agent], termination_condition=lambda _: False)
|
||||
@@ -686,16 +702,6 @@ async def test_handoff_resume_preserves_approved_tool_output_for_stateless_runs(
|
||||
assert refund_client.resume_validated is True
|
||||
|
||||
|
||||
def test_handoff_clone_disables_provider_side_storage() -> None:
    """The executor built for a handoff participant must run its agent with store=False."""
    triage = MockHandoffAgent(name="triage")
    builder = HandoffBuilder(participants=[triage])
    workflow = builder.with_start_agent(triage).build()

    # Executors are keyed by the resolved agent id.
    triage_executor = workflow.executors[resolve_agent_id(triage)]
    assert isinstance(triage_executor, HandoffAgentExecutor)

    stored_option = triage_executor._agent.default_options.get("store")
    assert stored_option is False
|
||||
|
||||
|
||||
async def test_handoff_clone_preserves_per_service_call_history_persistence() -> None:
|
||||
"""Handoff clones should keep per-service-call history persistence active for auto-handoff termination."""
|
||||
triage_history = InMemoryHistoryProvider()
|
||||
@@ -711,6 +717,7 @@ async def test_handoff_clone_preserves_per_service_call_history_persistence() ->
|
||||
name="specialist",
|
||||
client=MockChatClient(name="specialist"),
|
||||
default_options={"tool_choice": "none"},
|
||||
require_per_service_call_history_persistence=True,
|
||||
)
|
||||
|
||||
workflow = (
|
||||
@@ -738,21 +745,6 @@ async def test_handoff_clone_preserves_per_service_call_history_persistence() ->
|
||||
assert all(message.role != "tool" for message in stored_messages)
|
||||
|
||||
|
||||
async def test_handoff_clears_stale_service_session_id_before_run() -> None:
    """A service session id left on the executor's session must be gone after a run."""
    triage = MockHandoffAgent(name="triage", handoff_to="specialist")
    specialist = MockHandoffAgent(name="specialist")
    builder = HandoffBuilder(participants=[triage, specialist])
    workflow = builder.with_start_agent(triage).build()

    executor = workflow.executors[resolve_agent_id(triage)]
    assert isinstance(executor, HandoffAgentExecutor)

    # Seed a stale id so the run has something to clear.
    executor._session.service_session_id = "resp_stale_value"

    event_stream = workflow.run("My order is damaged", stream=True)
    await _drain(event_stream)

    assert executor._session.service_session_id is None
|
||||
|
||||
|
||||
def test_clean_conversation_for_handoff_keeps_text_only_history() -> None:
|
||||
"""Tool-control messages must be excluded from persisted handoff history."""
|
||||
function_call = Content.from_function_call(
|
||||
@@ -791,52 +783,6 @@ def test_clean_conversation_for_handoff_keeps_text_only_history() -> None:
|
||||
]
|
||||
|
||||
|
||||
def test_persist_missing_approved_function_results_handles_runtime_and_fallback_outputs() -> None:
    """Both a concrete runtime result and an approval-only call must end up with a persisted result."""
    executed_call_id = "call-runtime-result"
    approved_only_call_id = "call-approval-only"

    def refund_call(call_id: str) -> Content:
        # Each use gets its own function_call content object.
        return Content.from_function_call(call_id=call_id, name="submit_refund", arguments={})

    executor = HandoffAgentExecutor(MockHandoffAgent(name="triage"), handoffs=[])
    executor._full_conversation = [
        Message(
            role="assistant",
            contents=[refund_call(executed_call_id), refund_call(approved_only_call_id)],
        )
    ]

    # One call already has a concrete tool output; the other has only an approval.
    tool_output_message = Message(
        role="tool",
        contents=[Content.from_function_result(call_id=executed_call_id, result='{"submitted":true}')],
    )
    approval_message = Message(
        role="user",
        contents=[
            Content.from_function_approval_response(
                approved=True,
                id=approved_only_call_id,
                function_call=refund_call(approved_only_call_id),
            )
        ],
    )

    executor._persist_missing_approved_function_results(
        runtime_tool_messages=[tool_output_message, approval_message],
        response_messages=[],
    )

    tool_messages = [entry for entry in executor._full_conversation if entry.role == "tool"]
    assert tool_messages

    persisted_by_call_id = {
        item.call_id: item.result
        for entry in tool_messages
        for item in entry.contents
        if item.type == "function_result" and item.call_id
    }
    assert persisted_by_call_id[executed_call_id] == '{"submitted":true}'
    assert persisted_by_call_id[approved_only_call_id] == '{"status":"approved"}'
|
||||
|
||||
|
||||
async def test_autonomous_mode_yields_output_without_user_request():
|
||||
"""Ensure autonomous interaction mode yields output without requesting user input."""
|
||||
triage = MockHandoffAgent(name="triage", handoff_to="specialist")
|
||||
@@ -979,7 +925,12 @@ async def test_handoff_terminates_without_request_info_when_latest_response_meet
|
||||
|
||||
return _get()
|
||||
|
||||
agent = Agent(id="order_agent", name="order_agent", client=FinalizingClient())
|
||||
agent = Agent(
|
||||
id="order_agent",
|
||||
name="order_agent",
|
||||
client=FinalizingClient(),
|
||||
require_per_service_call_history_persistence=True,
|
||||
)
|
||||
workflow = (
|
||||
HandoffBuilder(
|
||||
participants=[agent],
|
||||
@@ -1061,6 +1012,7 @@ async def test_context_provider_preserved_during_handoff():
|
||||
name="test_agent",
|
||||
id="test_agent",
|
||||
context_providers=[context_provider],
|
||||
require_per_service_call_history_persistence=True,
|
||||
)
|
||||
|
||||
# Verify the original agent has the context provider
|
||||
@@ -1104,8 +1056,8 @@ async def test_auto_handoff_middleware_intercepts_handoff_tool_call() -> None:
|
||||
middleware = _AutoHandoffMiddleware([HandoffConfiguration(target=target_id)])
|
||||
|
||||
@tool(name=get_handoff_tool_name(target_id), approval_mode="never_require")
|
||||
def handoff_tool() -> str:
|
||||
return "unreachable"
|
||||
def handoff_tool() -> None:
|
||||
pass
|
||||
|
||||
context = FunctionInvocationContext(function=handoff_tool, arguments={})
|
||||
call_next = AsyncMock()
|
||||
@@ -1136,6 +1088,20 @@ async def test_auto_handoff_middleware_calls_next_for_non_handoff_tool() -> None
|
||||
assert context.result is None
|
||||
|
||||
|
||||
def test_handoff_builder_rejects_agents_without_per_service_call_history_persistence() -> None:
    """build() must raise ValueError when a participant lacks require_per_service_call_history_persistence."""
    # MockHandoffAgent turns the flag on; a plain Agent leaves it at its default.
    compliant_agent = MockHandoffAgent(name="has_flag")
    non_compliant_agent = Agent(
        client=MockChatClient(name="no_flag"),
        name="no_flag",
        id="no_flag",
    )

    with pytest.raises(ValueError, match="require_per_service_call_history_persistence"):
        builder = HandoffBuilder(participants=[non_compliant_agent, compliant_agent])
        builder.with_start_agent(compliant_agent).build()
|
||||
|
||||
|
||||
def test_handoff_builder_rejects_non_agent_supports_agent_run():
|
||||
"""Verify that participants() rejects SupportsAgentRun implementations that are not Agent instances."""
|
||||
from agent_framework import AgentResponse, AgentSession, SupportsAgentRun
|
||||
@@ -1160,3 +1126,246 @@ def test_handoff_builder_rejects_non_agent_supports_agent_run():
|
||||
|
||||
with pytest.raises(TypeError, match="Participants must be Agent instances"):
|
||||
HandoffBuilder().participants([fake])
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region integration tests
|
||||
|
||||
|
||||
# The Foundry client and Azure credential are optional; integration tests are
# skipped when either import is unavailable.
try:
    from agent_framework.foundry import FoundryChatClient
    from azure.identity import AzureCliCredential
except ImportError:
    _has_foundry_deps = False
else:
    _has_foundry_deps = True

# Integration tests also need both environment variables set to a non-empty value.
skip_if_foundry_integration_tests_disabled = pytest.mark.skipif(
    not (_has_foundry_deps and os.getenv("FOUNDRY_PROJECT_ENDPOINT") and os.getenv("FOUNDRY_MODEL")),
    reason="No real FOUNDRY_PROJECT_ENDPOINT or FOUNDRY_MODEL provided; skipping integration tests.",
)
|
||||
|
||||
|
||||
@pytest.mark.integration
@skip_if_foundry_integration_tests_disabled
@pytest.mark.parametrize("store", [param(False, id="store=False"), param(True, id="store=True")])
async def test_simple_handoff_workflow(store: bool) -> None:
    """Run a two-agent handoff end to end: triage hands off to the refund agent, then the run terminates."""
    foundry_client = FoundryChatClient(
        project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"],
        model=os.environ["FOUNDRY_MODEL"],
        credential=AzureCliCredential(),
    )

    triage_instructions = (
        "You are frontline support triage. Route customer issues to the appropriate specialist agents "
        "based on the problem described."
    )
    triage_agent = Agent(
        client=foundry_client,
        instructions=triage_instructions,
        name="triage_agent",
        default_options={"store": store},
        require_per_service_call_history_persistence=True,
    )
    refund_agent = Agent(
        client=foundry_client,
        instructions="You process refund requests. Ask user the ID of the order they want refunded.",
        name="refund_agent",
        default_options={"store": store},
        require_per_service_call_history_persistence=True,
    )

    def refund_agent_spoke_last(conversation: list[Message]) -> bool:
        # Terminate right after triage hands off to refund, to prove the handoff works.
        if len(conversation) > 0:
            return conversation[-1].author_name == refund_agent.name
        return False

    builder = HandoffBuilder(
        participants=[triage_agent, refund_agent],
        termination_condition=refund_agent_spoke_last,
    )
    workflow = builder.with_start_agent(triage_agent).build()

    run_result = await workflow.run("I want to get a refund")

    # The run ends IDLE rather than IDLE_WITH_PENDING_REQUESTS because the
    # termination condition is met right after the refund agent's response.
    assert run_result.get_final_state() == WorkflowRunState.IDLE
    # Outputs: one response per agent plus the final full conversation between them.
    assert len(run_result.get_outputs()) == 3

    # Exactly one handoff is sent, from triage to refund.
    handoffs_sent = [evt for evt in run_result if evt.type == "handoff_sent"]
    assert len(handoffs_sent) == 1
    sent = handoffs_sent[0].data
    assert isinstance(sent, HandoffSentEvent)
    assert sent.source == triage_agent.name
    assert sent.target == refund_agent.name
|
||||
|
||||
|
||||
@pytest.mark.integration
@skip_if_foundry_integration_tests_disabled
@pytest.mark.parametrize(
    "store",
    [param(False, id="store=False"), param(True, id="store=True")],
)
async def test_simple_handoff_workflow_with_request_and_response(store: bool) -> None:
    """Verify a triage-to-refund handoff that pauses for user input and finishes once the user replies.

    The first run is expected to stop with a pending ``HandoffAgentUserRequest`` from the
    refund agent; the second run supplies the user's answer, which satisfies the termination
    condition (two user messages in the conversation).

    Args:
        store: Value placed under the ``"store"`` key of each agent's ``default_options``.
    """

    def _user_replied_once(conversation) -> bool:
        # Two user messages means the original request plus the follow-up that answers
        # the refund agent's question, which is the point at which this test stops.
        return sum(1 for entry in conversation if entry.role == "user") == 2

    def _events_of_type(run_result, event_type: str) -> list:
        # Collect the events of a single type emitted during one workflow run.
        return [item for item in run_result if item.type == event_type]

    chat_client = FoundryChatClient(
        project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"],
        model=os.environ["FOUNDRY_MODEL"],
        credential=AzureCliCredential(),
    )

    triage_agent = Agent(
        name="triage_agent",
        client=chat_client,
        instructions=(
            "You are frontline support triage. Route customer issues to the appropriate specialist agents "
            "based on the problem described."
        ),
        default_options={"store": store},
        require_per_service_call_history_persistence=True,
    )
    refund_agent = Agent(
        name="refund_agent",
        client=chat_client,
        instructions="You process refund requests. Ask user the ID of the order they want refunded.",
        default_options={"store": store},
        require_per_service_call_history_persistence=True,
    )

    builder = HandoffBuilder(
        participants=[triage_agent, refund_agent],
        termination_condition=_user_replied_once,
    )
    workflow = builder.with_start_agent(triage_agent).build()

    # --- First run: triage hands off, then the refund agent asks the user a question. ---
    first_run = await workflow.run("I want to get a refund")

    # The user has not answered the refund agent yet, so requests are still pending.
    assert first_run.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS

    # Exactly one handoff, from triage to refund.
    handoffs = _events_of_type(first_run, "handoff_sent")
    assert len(handoffs) == 1
    assert isinstance(handoffs[0].data, HandoffSentEvent)
    assert handoffs[0].data.source == triage_agent.name
    assert handoffs[0].data.target == refund_agent.name

    # Exactly one request for user input, raised after the handoff.
    pending = _events_of_type(first_run, "request_info")
    assert len(pending) == 1
    assert isinstance(pending[0].data, HandoffAgentUserRequest)

    # --- Second run: answer the pending request so the termination condition is met. ---
    user_reply = HandoffAgentUserRequest.create_response("My order number is 12345")
    second_run = await workflow.run(responses={pending[0].request_id: user_reply})

    # With the second user message in the conversation the workflow ends idle.
    assert second_run.get_final_state() == WorkflowRunState.IDLE
|
||||
|
||||
|
||||
@tool(approval_mode="always_require")
def process_refund(order_number: Annotated[str, "Order number to process refund for"]) -> str:
    """Simulated function to process a refund for a given order number."""
    # NOTE(review): the docstring and the Annotated metadata above are presumably surfaced
    # through the tool decorator, so both are kept verbatim - confirm before rewording them.
    # No real refund is issued; the function only reports success for the given order.
    confirmation = f"Refund processed successfully for order {order_number}."
    return confirmation
|
||||
|
||||
|
||||
@pytest.mark.integration
@skip_if_foundry_integration_tests_disabled
@pytest.mark.parametrize(
    "store",
    [param(False, id="store=False"), param(True, id="store=True")],
)
async def test_simple_handoff_workflow_with_approval_request(store: bool) -> None:
    """Verify a handoff workflow in which the refund agent pauses for a function approval.

    The workflow has no termination condition, so the test drives it through four runs:

    1. Initial message: triage hands off and the refund agent asks the user for input.
    2. The user answers: the workflow pauses on a ``function_approval_request``.
    3. The approval is granted: the refund agent asks for further user input.
    4. The user request is answered with ``HandoffAgentUserRequest.terminate()``: the workflow ends idle.

    Args:
        store: Value placed under the ``"store"`` key of each agent's ``default_options``.
    """

    def _events_of_type(run_result, event_type: str) -> list:
        # Collect the events of a single type emitted during one workflow run.
        return [item for item in run_result if item.type == event_type]

    chat_client = FoundryChatClient(
        project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"],
        model=os.environ["FOUNDRY_MODEL"],
        credential=AzureCliCredential(),
    )

    triage_agent = Agent(
        name="triage_agent",
        client=chat_client,
        instructions=(
            "You are frontline support triage. Route customer issues to the appropriate specialist agents "
            "based on the problem described."
        ),
        default_options={"store": store},
        require_per_service_call_history_persistence=True,
    )
    # The refund agent owns the approval-gated tool exercised by this test.
    refund_agent = Agent(
        name="refund_agent",
        client=chat_client,
        instructions="You process refund requests. Ask user the ID of the order they want refunded.",
        default_options={"store": store},
        tools=[process_refund],
        require_per_service_call_history_persistence=True,
    )

    # No termination condition is configured: the test ends the workflow manually.
    builder = HandoffBuilder(participants=[triage_agent, refund_agent])
    workflow = builder.with_start_agent(triage_agent).build()

    # --- Run 1: triage hands off, then the refund agent asks the user a question. ---
    run_result = await workflow.run("I want to get a refund")

    # The user has not answered the refund agent yet, so requests are still pending.
    assert run_result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS

    # Exactly one handoff, from triage to refund.
    handoffs = _events_of_type(run_result, "handoff_sent")
    assert len(handoffs) == 1
    assert isinstance(handoffs[0].data, HandoffSentEvent)
    assert handoffs[0].data.source == triage_agent.name
    assert handoffs[0].data.target == refund_agent.name

    # Exactly one request for user input, raised after the handoff.
    input_requests = _events_of_type(run_result, "request_info")
    assert len(input_requests) == 1
    assert isinstance(input_requests[0].data, HandoffAgentUserRequest)

    # --- Run 2: supply the order number; the agent should now ask for tool approval. ---
    user_reply = HandoffAgentUserRequest.create_response("My order number is 12345")
    run_result = await workflow.run(responses={input_requests[0].request_id: user_reply})

    # Still pending: the refund agent is expected to request approval before processing the refund.
    assert run_result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS

    # Exactly one request this time, and it must be a function approval request.
    approval_requests = _events_of_type(run_result, "request_info")
    assert len(approval_requests) == 1
    approval_payload = approval_requests[0].data
    assert isinstance(approval_payload, Content)
    assert approval_payload.type == "function_approval_request"

    # --- Run 3: grant the approval. ---
    approval_reply = approval_payload.to_function_approval_response(approved=True)
    run_result = await workflow.run(responses={approval_requests[0].request_id: approval_reply})

    # Without a termination condition the workflow pauses again, waiting for further user input.
    assert run_result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS

    # Exactly one follow-up request for user input after the refund has been handled.
    follow_ups = _events_of_type(run_result, "request_info")
    assert len(follow_ups) == 1
    assert isinstance(follow_ups[0].data, HandoffAgentUserRequest)

    # --- Run 4: terminate manually instead of answering the follow-up. ---
    stop_signal = HandoffAgentUserRequest.terminate()
    run_result = await workflow.run(responses={follow_ups[0].request_id: stop_signal})

    assert run_result.get_final_state() == WorkflowRunState.IDLE
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
Reference in New Issue
Block a user