mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
866a325b48
* Fix orchestration outputs so as_agent() returns the final answer only.
  Align other orchestration outputs
* Fix orchestration output issues from review comments
  1. Sample cleanup: Remove commented-out FoundryChatClient block and update prerequisites to reference OPENAI_CHAT_MODEL_ID instead of FOUNDRY_* vars.
  2. Sequential approval output: Change _EndWithConversation.end_with_agent_executor_response from a no-op sink to yield response.agent_response. When the last participant is AgentApprovalExecutor (via with_request_info), _EndWithConversation is the output executor so the yield produces the terminal answer. When the last participant is a regular AgentExecutor, _EndWithConversation is not in output_executors so the yield is silently filtered out.
  3. Forward data events through WorkflowExecutor: _process_workflow_result now also forwards 'data' events from sub-workflows so that emit_intermediate_data=True on AgentExecutor works correctly when wrapped in AgentApprovalExecutor.
  4. Concurrent docstring: Update _AggregateAgentConversations docstring to say 'deterministic participant order' instead of 'completion order'.
  5. Add test_concurrent_intermediate_outputs_emits_data_events verifying that ConcurrentBuilder(intermediate_outputs=True) emits per-participant data events alongside the single aggregated output event.
  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
* Add tests for sequential workflow with_request_info and intermediate_outputs (#5301)
  Address PR review comments 2, 3, and 5:
  - Add test_sequential_request_info_last_participant_emits_output: Verifies that when the last participant is wrapped via with_request_info() (AgentApprovalExecutor), the workflow still emits a terminal output after approval, exercising the _EndWithConversation.end_with_agent_executor_response fallback path.
  - Add test_sequential_request_info_with_intermediate_outputs_emits_data_events: Verifies that emit_intermediate_data=True works correctly through AgentApprovalExecutor wrapping; WorkflowExecutor._process_result already forwards data events from sub-workflows, so intermediate agent responses surface as data events in the parent workflow.
  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
* Fix pyright type errors from AgentResponse output refactor (#5301)
  Update cast() calls in _group_chat.py and _magentic.py to use WorkflowContext[Never, AgentResponse] instead of the old WorkflowContext[Never, list[Message]], matching the updated method signatures in _base_group_chat_orchestrator.py. Fix _sequential.py _EndWithConversation.end_with_agent_executor_response to declare WorkflowContext[Any, AgentResponse] so yield_output accepts AgentResponse[None]. Fix _workflow_executor.py data event forwarding to handle nullable executor_id.
  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
* Fix pyright reportUnknownVariableType in _agent.py (#5301)
  Extract event.data into a typed local variable before the isinstance check to avoid pyright narrowing it to AgentResponse[Unknown].
  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
* Fix pyright reportMissingImports for orjson in file history samples (#5301)
  Add pyright: ignore[reportMissingImports] to orjson imports that are already guarded by try/except ImportError, matching the existing pattern used elsewhere in the samples.
  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
* Address review feedback for #5301: review comment fixes
* Address review feedback for #5301: review comment fixes
* Revert sequential_workflow_as_agent sample to FoundryChatClient
  Reverts the mistaken switch from FoundryChatClient to OpenAIChatClient in the sequential workflow as agent sample.
  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
* Address ultrareview feedback: emit_data_events rename + WorkflowAgent reasoning conversion
  Layered on top of the prior review-feedback work in this branch.
  Renames:
  - AgentExecutor.emit_intermediate_data -> emit_data_events (mechanical rename; orchestration semantics live at the orchestration layer, not the general-purpose executor). Forwarded through MagenticAgentExecutor, AgentApprovalExecutor, and all orchestration call sites.
  - HandoffAgentExecutor._check_terminate_and_yield -> _should_terminate (pure predicate; no longer yields anything). HandoffBuilder docstring rewritten to describe the new per-agent AgentResponse output contract.
  WorkflowAgent reasoning-content conversion:
  - Add _rewrite_text_to_reasoning(contents) and _msg_as_reasoning(msg) helpers; the as_agent() path now reframes text content from data events as text_reasoning Content blocks before merging into the AgentResponse.
  - Consumers iterate msg.contents and branch on content.type, the same path they already use for Claude thinking and OpenAI reasoning. No new field on Message/AgentResponse/WorkflowEvent.
  - Streaming branch constructs fresh AgentResponseUpdate instances instead of mutating shared payloads (regression test added).
  - Helper _msg_maybe_reasoning consolidates the conditional rewrite at three call sites in the non-streaming conversion.
  Tests:
  - TestWorkflowAgentReasoningHelpers + TestWorkflowAgentDataEventReasoningConversion add 9 new tests covering helpers, non-streaming, streaming, mixed content, already-reasoning passthrough, and mutation-safety regression.
  - Updated test_sequential_as_agent_with_intermediate_outputs_includes_chain to assert text_reasoning content for intermediate agents.
* Fix pyright: widen event.data to Any to avoid partial-unknown narrowing
  The streaming conversion path narrowed event.data via isinstance against generic AgentResponse, producing AgentResponse[Unknown] and tripping reportUnknownVariableType/reportUnknownMemberType. Binding data: Any before the check keeps runtime behavior identical while restoring a fully known type for downstream access.
* Clean up design
* Scope to agent output semantics only
* yield AgentResponseUpdate streaming, AgentResponse non-streaming
* Fix mypy/pyright: widen cast types at GroupChat callsites
  Eight callsites in _group_chat.py still cast to WorkflowContext[Never, AgentResponse] but the base orchestrator methods now accept the wider WorkflowContext[Never, AgentResponse | AgentResponseUpdate] (mode-aware yields). W_OutT is invariant, so the narrower cast is not assignable. Magentic was widened in the same commit; this catches the GroupChat callsites that were missed.
* Python: skip flaky Foundry / Foundry Hosting integration tests (#5553)
  These two integration tests have been failing in the merge queue across multiple unrelated PRs (5301, 5531). Both are marked `@pytest.mark.flaky` with 3 retries, but all attempts fail back-to-back. Skipping both with a reason pointing to #5553 so they can be fixed properly without continuing to block unrelated merges.
  - packages/foundry_hosting/tests/test_responses_int.py::TestOptions::test_temperature_and_max_tokens
  - packages/foundry/tests/foundry/test_foundry_embedding_client.py::TestFoundryEmbeddingIntegration::test_text_embedding_live
  Also includes a one-line uv.lock specifier-ordering normalization auto-applied by the poe-check pre-commit hook.

---------
Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
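The sequential approval output change described above relies on output filtering. Any executor may yield a value, but only yields from executors listed in `output_executors` become workflow outputs. Everything else is dropped without an error.

The sketch below illustrates that rule and makes some assumptions:

* It is library-free and is not taken from agent_framework.
* `YieldedOutput`, `surface_outputs`, and the executor ids `"writer"` and `"end"` are hypothetical stand-ins. `"end"` plays the role of `_EndWithConversation`.
* The real filtering logic may be structured differently.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class YieldedOutput:
    # Hypothetical stand-in for a value an executor passes to yield_output().
    executor_id: str
    data: Any


def surface_outputs(yields: list[YieldedOutput], output_executors: set[str]) -> list[Any]:
    # Only yields from designated output executors become workflow outputs;
    # yields from any other executor are dropped silently (no error raised).
    return [y.data for y in yields if y.executor_id in output_executors]


# Case 1 (assumed): the last participant is a regular agent executor, so it is the
# output executor. The terminal "end" executor also yields, but that yield is filtered out.
regular = [YieldedOutput("writer", "final answer"), YieldedOutput("end", "final answer")]
print(surface_outputs(regular, output_executors={"writer"}))

# Case 2 (assumed): the last participant is wrapped for approval, so the terminal
# "end" executor is the output executor and its yield carries the final answer.
approval = [YieldedOutput("end", "approved answer")]
print(surface_outputs(approval, output_executors={"end"}))
```

In both cases exactly one terminal answer surfaces. That matches the commit's intent that `as_agent()` returns only the final answer, whichever executor ends up producing it.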
1417 lines
56 KiB
Python
# Copyright (c) Microsoft. All rights reserved.

import os
import re
from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence
from typing import Annotated, Any, cast
from unittest.mock import AsyncMock, MagicMock

import pytest
from agent_framework import (
    Agent,
    AgentResponse,
    AgentResponseUpdate,
    ChatResponse,
    ChatResponseUpdate,
    Content,
    ContextProvider,
    InMemoryHistoryProvider,
    Message,
    ResponseStream,
    WorkflowEvent,
    WorkflowRunState,
    function_middleware,
    resolve_agent_id,
    tool,
)
from agent_framework._clients import BaseChatClient
from agent_framework._middleware import (
    ChatMiddlewareLayer,
    FunctionInvocationContext,
    MiddlewareTermination,
)
from agent_framework._tools import FunctionInvocationLayer, FunctionTool
from agent_framework.orchestrations import HandoffAgentUserRequest, HandoffBuilder, HandoffSentEvent
from pytest import param

from agent_framework_orchestrations._handoff import (
    HANDOFF_FUNCTION_RESULT_KEY,
    HandoffAgentExecutor,
    HandoffConfiguration,
    _AutoHandoffMiddleware,  # pyright: ignore[reportPrivateUsage]
    get_handoff_tool_name,
)
from agent_framework_orchestrations._orchestrator_helpers import clean_conversation_for_handoff


# region unit tests
class MockChatClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
    """Mock chat client for testing handoff workflows."""

    def __init__(
        self,
        *,
        name: str = "",
        handoff_to: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the mock chat client.

        Args:
            name: The name of the agent using this chat client.
            handoff_to: The name of the agent to hand off to, or None for no handoff.
                This is hardcoded for testing purposes so that the agent always attempts to hand off.
        """
        ChatMiddlewareLayer.__init__(self)
        FunctionInvocationLayer.__init__(self)
        BaseChatClient.__init__(self)
        self._name = name
        self._handoff_to = handoff_to
        self._call_index = 0

    def _inner_get_response(
        self,
        *,
        messages: Sequence[Message],
        stream: bool,
        options: Mapping[str, Any],
        **kwargs: Any,
    ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
        if stream:
            return self._build_streaming_response(options=dict(options))

        async def _get() -> ChatResponse:
            contents = _build_reply_contents(self._name, self._handoff_to, self._next_call_id())
            reply = Message(
                role="assistant",
                contents=contents,
            )
            return ChatResponse(messages=reply, response_id="mock_response")

        return _get()

    def _build_streaming_response(self, *, options: dict[str, Any]) -> ResponseStream[ChatResponseUpdate, ChatResponse]:
        async def _stream() -> AsyncIterable[ChatResponseUpdate]:
            contents = _build_reply_contents(self._name, self._handoff_to, self._next_call_id())
            yield ChatResponseUpdate(contents=contents, role="assistant", finish_reason="stop")

        def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse:
            response_format = options.get("response_format")
            output_format_type = response_format if isinstance(response_format, type) else None
            return ChatResponse.from_updates(updates, output_format_type=output_format_type)

        return ResponseStream(_stream(), finalizer=_finalize)

    def _next_call_id(self) -> str | None:
        if not self._handoff_to:
            return None
        call_id = f"{self._name}-handoff-{self._call_index}"
        self._call_index += 1
        return call_id


def _build_reply_contents(
    agent_name: str,
    handoff_to: str | None,
    call_id: str | None,
) -> list[Content]:
    contents: list[Content] = []
    if handoff_to and call_id:
        contents.append(
            Content.from_function_call(
                call_id=call_id, name=f"handoff_to_{handoff_to}", arguments={"handoff_to": handoff_to}
            )
        )
    text = f"{agent_name} reply"
    contents.append(Content.from_text(text=text))
    return contents


class MockHandoffAgent(Agent):
    """Mock agent that can hand off to another agent."""

    def __init__(
        self,
        *,
        name: str,
        handoff_to: str | None = None,
    ) -> None:
        """Initialize the mock handoff agent.

        Args:
            name: The name of the agent.
            handoff_to: The name of the agent to hand off to, or None for no handoff.
                This is hardcoded for testing purposes so that the agent always attempts to hand off.
        """
        super().__init__(
            client=MockChatClient(name=name, handoff_to=handoff_to),
            name=name,
            id=name,
            require_per_service_call_history_persistence=True,
        )


class ContextAwareRefundClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
    """Mock client that expects prior user context to remain available on resume."""

    def __init__(self) -> None:
        ChatMiddlewareLayer.__init__(self)
        FunctionInvocationLayer.__init__(self)
        BaseChatClient.__init__(self)
        self._call_index = 0

    def _inner_get_response(
        self,
        *,
        messages: Sequence[Message],
        stream: bool,
        options: Mapping[str, Any],
        **kwargs: Any,
    ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
        del kwargs
        del options

        contents = self._next_contents(messages)
        if stream:
            return self._build_streaming_response(contents)

        async def _get() -> ChatResponse:
            return ChatResponse(messages=[Message(role="assistant", contents=contents)], response_id="context-aware")

        return _get()

    def _build_streaming_response(self, contents: list[Content]) -> ResponseStream[ChatResponseUpdate, ChatResponse]:
        async def _stream() -> AsyncIterable[ChatResponseUpdate]:
            yield ChatResponseUpdate(contents=contents, role="assistant", finish_reason="stop")

        def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse:
            return ChatResponse.from_updates(updates)

        return ResponseStream(_stream(), finalizer=_finalize)

    def _next_contents(self, messages: Sequence[Message]) -> list[Content]:
        user_text = " ".join(message.text or "" for message in messages if message.role == "user")
        order_match = re.search(r"\b(\d{4,12})\b", user_text)
        order_id = order_match.group(1) if order_match else None
        asks_refund = any(token in user_text.lower() for token in ("broken", "damaged", "refund", "cracked"))

        if self._call_index == 0:
            reply = "Refund Agent: Please share your order number."
        elif self._call_index == 1:
            if order_id:
                reply = f"Refund Agent: Thanks, I found order {order_id}. Why do you need the refund?"
            else:
                reply = "Refund Agent: I still need your order number."
        else:
            if order_id and asks_refund:
                reply = f"Refund Agent: Got it for order {order_id}. I can proceed with your refund."
            else:
                reply = "Refund Agent: I still need your order number."

        self._call_index += 1
        return [Content.from_text(text=reply)]


async def _drain(stream: AsyncIterable[WorkflowEvent]) -> list[WorkflowEvent]:
    return [event async for event in stream]


async def test_handoff():
    """Test that agents can hand off to each other."""

    # `triage` hands off to `specialist`, who then hands off to `escalation`.
    # `escalation` has no handoff, so the workflow should request user input to continue.
    triage = MockHandoffAgent(name="triage", handoff_to="specialist")
    specialist = MockHandoffAgent(name="specialist", handoff_to="escalation")
    escalation = MockHandoffAgent(name="escalation")

    # Without explicitly defining handoffs, the builder will create connections
    # between all agents.
    workflow = (
        HandoffBuilder(
            participants=[triage, specialist, escalation],
            termination_condition=lambda conv: sum(1 for m in conv if m.role == "user") >= 2,
        )
        .with_start_agent(triage)
        .build()
    )

    # Start conversation - triage hands off to specialist then escalation
    # escalation won't trigger a handoff, so the response from it will become
    # a request for user input because autonomous mode is not enabled by default.
    events = await _drain(workflow.run("Need technical support", stream=True))
    requests = [ev for ev in events if ev.type == "request_info"]

    assert requests
    assert len(requests) == 1

    request = requests[0]
    assert isinstance(request.data, HandoffAgentUserRequest)
    assert request.source_executor_id == escalation.name


def _latest_request_info_event(events: list[WorkflowEvent]) -> WorkflowEvent[Any]:
    request_events = [event for event in events if event.type == "request_info"]
    assert request_events
    request_event = request_events[-1]
    assert isinstance(request_event.data, HandoffAgentUserRequest)
    return request_event


def _request_text(event: WorkflowEvent[Any]) -> str:
    request_payload = cast(HandoffAgentUserRequest, event.data)
    messages = request_payload.agent_response.messages
    assert messages
    return messages[-1].text or ""


async def test_resume_keeps_prior_user_context_for_same_agent() -> None:
    """Ensure same-agent request_info resumes retain prior turn context."""
    refund_agent = Agent(
        id="refund_agent",
        name="refund_agent",
        client=ContextAwareRefundClient(),
        require_per_service_call_history_persistence=True,
    )
    workflow = (
        HandoffBuilder(participants=[refund_agent], termination_condition=lambda _: False)
        .with_start_agent(refund_agent)
        .build()
    )

    first_events = await _drain(workflow.run("My order arrived damaged.", stream=True))
    first_request = _latest_request_info_event(first_events)
    assert "order number" in _request_text(first_request).lower()

    second_events = await _drain(
        workflow.run(
            stream=True,
            responses={first_request.request_id: [Message(role="user", contents=["Order 2939393"])]},
        )
    )
    second_request = _latest_request_info_event(second_events)
    second_text = _request_text(second_request).lower()
    assert "order 2939393" in second_text
    assert "order number" not in second_text

    third_events = await _drain(
        workflow.run(
            stream=True,
            responses={second_request.request_id: [Message(role="user", contents=["It arrived broken and unusable."])]},
        )
    )
    third_request = _latest_request_info_event(third_events)
    third_text = _request_text(third_request).lower()
    assert "order 2939393" in third_text
    assert "order number" not in third_text


async def test_tool_approval_responses_are_not_replayed_from_history() -> None:
    """Ensure persisted history does not re-execute previously approved tool calls."""
    execution_count = 0

    @tool(name="submit_refund_counted", approval_mode="always_require")
    def submit_refund_counted() -> str:
        nonlocal execution_count
        execution_count += 1
        return "ok"

    class ApprovalReplayClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
        def __init__(self) -> None:
            ChatMiddlewareLayer.__init__(self)
            FunctionInvocationLayer.__init__(self)
            BaseChatClient.__init__(self)
            self._call_index = 0

        def _inner_get_response(
            self,
            *,
            messages: Sequence[Message],
            stream: bool,
            options: Mapping[str, Any],
            **kwargs: Any,
        ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
            del messages
            del options
            del kwargs

            if self._call_index == 0:
                contents = [
                    Content.from_function_call(
                        call_id="refund-call-1",
                        name="submit_refund_counted",
                        arguments={},
                    )
                ]
            elif self._call_index == 1:
                contents = [Content.from_text(text="Refund approved and recorded.")]
            else:
                contents = [Content.from_text(text="No additional tool work needed.")]
            self._call_index += 1

            if stream:

                async def _stream() -> AsyncIterable[ChatResponseUpdate]:
                    yield ChatResponseUpdate(contents=contents, role="assistant", finish_reason="stop")

                return ResponseStream(_stream(), finalizer=lambda updates: ChatResponse.from_updates(updates))

            async def _get() -> ChatResponse:
                return ChatResponse(
                    messages=[Message(role="assistant", contents=contents)],
                    response_id="approval-replay",
                )

            return _get()

    agent = Agent(
        id="refund_agent",
        name="refund_agent",
        client=ApprovalReplayClient(),
        tools=[submit_refund_counted],
        require_per_service_call_history_persistence=True,
    )
    workflow = (
        HandoffBuilder(participants=[agent], termination_condition=lambda _: False).with_start_agent(agent).build()
    )

    first_events = await _drain(workflow.run("start", stream=True))
    first_requests = [event for event in first_events if event.type == "request_info"]
    assert first_requests
    first_request = first_requests[-1]
    assert isinstance(first_request.data, Content)
    approval_response = first_request.data.to_function_approval_response(approved=True)

    second_events = await _drain(workflow.run(stream=True, responses={first_request.request_id: approval_response}))
    second_request = _latest_request_info_event(second_events)

    await _drain(
        workflow.run(
            stream=True,
            responses={second_request.request_id: [Message(role="user", contents=["Thanks, what's next?"])]},
        )
    )

    assert execution_count == 1


async def test_handoff_resume_preserves_approval_function_call_for_stateless_runs() -> None:
    """Approval resume turns must replay matching function calls when store=False."""

    @tool(name="submit_refund", approval_mode="always_require")
    def submit_refund() -> str:
        return "ok"

    class StrictStatelessApprovalClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
        def __init__(self) -> None:
            ChatMiddlewareLayer.__init__(self)
            FunctionInvocationLayer.__init__(self)
            BaseChatClient.__init__(self)
            self._call_index = 0
            self.resume_validated = False

        def _inner_get_response(
            self,
            *,
            messages: Sequence[Message],
            stream: bool,
            options: Mapping[str, Any],
            **kwargs: Any,
        ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
            del options
            del kwargs

            if self._call_index == 0:
                contents = [
                    Content.from_function_call(
                        call_id="refund-call-1",
                        name="submit_refund",
                        arguments={},
                    )
                ]
            else:
                function_call_ids = {
                    content.call_id
                    for message in messages
                    for content in message.contents
                    if content.type == "function_call" and content.call_id
                }
                function_result_ids = {
                    content.call_id
                    for message in messages
                    for content in message.contents
                    if content.type == "function_result" and content.call_id
                }
                missing_call_ids = sorted(function_result_ids - function_call_ids)
                if missing_call_ids:
                    raise AssertionError(
                        f"No tool call found for function call output with call_id {missing_call_ids[0]}."
                    )
                self.resume_validated = True
                contents = [Content.from_text(text="Refund submitted.")]

            self._call_index += 1

            if stream:

                async def _stream() -> AsyncIterable[ChatResponseUpdate]:
                    yield ChatResponseUpdate(contents=contents, role="assistant", finish_reason="stop")

                return ResponseStream(_stream(), finalizer=lambda updates: ChatResponse.from_updates(updates))

            async def _get() -> ChatResponse:
                return ChatResponse(
                    messages=[Message(role="assistant", contents=contents)],
                    response_id="strict-stateless",
                )

            return _get()

    client = StrictStatelessApprovalClient()
    agent = Agent(
        id="refund_agent",
        name="refund_agent",
        client=client,
        tools=[submit_refund],
        require_per_service_call_history_persistence=True,
    )
    workflow = (
        HandoffBuilder(participants=[agent], termination_condition=lambda _: False).with_start_agent(agent).build()
    )

    first_events = await _drain(workflow.run("start", stream=True))
    approval_requests = [
        event for event in first_events if event.type == "request_info" and isinstance(event.data, Content)
    ]
    assert approval_requests
    first_request = approval_requests[0]

    approval_response = first_request.data.to_function_approval_response(True)
    await _drain(workflow.run(stream=True, responses={first_request.request_id: approval_response}))

    assert client.resume_validated is True


async def test_handoff_replay_serializes_handoff_function_results() -> None:
    """Returning to the same agent must not replay dict tool outputs."""

    class ReplaySafeHandoffClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
        def __init__(self, name: str, handoff_sequence: list[str | None]) -> None:
            ChatMiddlewareLayer.__init__(self)
            FunctionInvocationLayer.__init__(self)
            BaseChatClient.__init__(self)
            self._name = name
            self._handoff_sequence = handoff_sequence
            self._call_index = 0

        def _inner_get_response(
            self,
            *,
            messages: Sequence[Message],
            stream: bool,
            options: Mapping[str, Any],
            **kwargs: Any,
        ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
            del options
            del kwargs

            for message in messages:
                for content in message.contents:
                    if content.type == "function_result" and isinstance(content.result, dict):
                        raise AssertionError("Expected replayed function_result payloads to be JSON strings.")

            handoff_to = (
                self._handoff_sequence[self._call_index] if self._call_index < len(self._handoff_sequence) else None
            )
            call_id = f"{self._name}-handoff-{self._call_index}" if handoff_to else None
            contents = _build_reply_contents(self._name, handoff_to, call_id)
            self._call_index += 1

            if stream:

                async def _stream() -> AsyncIterable[ChatResponseUpdate]:
                    yield ChatResponseUpdate(contents=contents, role="assistant", finish_reason="stop")

                return ResponseStream(_stream(), finalizer=lambda updates: ChatResponse.from_updates(updates))

            async def _get() -> ChatResponse:
                return ChatResponse(messages=[Message(role="assistant", contents=contents)], response_id="replay-safe")

            return _get()

    triage = Agent(
        id="triage",
        name="triage",
        client=ReplaySafeHandoffClient(name="triage", handoff_sequence=["specialist", None]),
        require_per_service_call_history_persistence=True,
    )
    specialist = Agent(
        id="specialist",
        name="specialist",
        client=ReplaySafeHandoffClient(name="specialist", handoff_sequence=["triage"]),
        require_per_service_call_history_persistence=True,
    )

    workflow = (
        HandoffBuilder(participants=[triage, specialist], termination_condition=lambda _: False)
        .with_start_agent(triage)
        .build()
    )

    events = await _drain(workflow.run("start", stream=True))
    requests = [event for event in events if event.type == "request_info"]
    assert requests
    assert requests[-1].source_executor_id == triage.name


async def test_handoff_resume_preserves_approved_tool_output_for_stateless_runs() -> None:
    """Approved calls must keep function_call/function_result pairs for later replays."""
    submit_call_id = "call_submit_refund_approved"

    @tool(name="submit_refund", approval_mode="always_require")
    def submit_refund() -> str:
        return "submitted"

    class RefundReplayClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
        def __init__(self) -> None:
            ChatMiddlewareLayer.__init__(self)
            FunctionInvocationLayer.__init__(self)
            BaseChatClient.__init__(self)
            self._call_index = 0
            self.resume_validated = False

        def _inner_get_response(
            self,
            *,
            messages: Sequence[Message],
            stream: bool,
            options: Mapping[str, Any],
            **kwargs: Any,
        ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
            del options
            del kwargs

            if self._call_index == 0:
                contents = [Content.from_function_call(call_id=submit_call_id, name="submit_refund", arguments={})]
            elif self._call_index == 1:
                contents = _build_reply_contents("refund_agent", "order_agent", "refund-order-handoff-1")
            else:
                function_call_ids = {
                    content.call_id
                    for message in messages
                    for content in message.contents
                    if content.type == "function_call" and content.call_id
                }
                function_result_ids = {
                    content.call_id
                    for message in messages
                    for content in message.contents
                    if content.type == "function_result" and content.call_id
                }
                if submit_call_id in function_call_ids and submit_call_id not in function_result_ids:
                    raise AssertionError(f"No tool output found for function call {submit_call_id}.")
                self.resume_validated = True
                contents = [Content.from_text(text="Refund agent resumed.")]

            self._call_index += 1

            if stream:

                async def _stream() -> AsyncIterable[ChatResponseUpdate]:
                    yield ChatResponseUpdate(contents=contents, role="assistant", finish_reason="stop")

                return ResponseStream(_stream(), finalizer=lambda updates: ChatResponse.from_updates(updates))

            async def _get() -> ChatResponse:
                return ChatResponse(
                    messages=[Message(role="assistant", contents=contents)],
                    response_id="refund-replay",
                )

            return _get()

    class OrderReplayClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
        def __init__(self) -> None:
            ChatMiddlewareLayer.__init__(self)
            FunctionInvocationLayer.__init__(self)
            BaseChatClient.__init__(self)
            self._call_index = 0

        def _inner_get_response(
            self,
            *,
            messages: Sequence[Message],
            stream: bool,
            options: Mapping[str, Any],
            **kwargs: Any,
        ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
            del messages
            del options
            del kwargs

            if self._call_index == 0:
                contents = [Content.from_text(text="Would you like a replacement or a refund?")]
            else:
                contents = _build_reply_contents("order_agent", "refund_agent", "order-refund-handoff-1")
            self._call_index += 1

            if stream:

                async def _stream() -> AsyncIterable[ChatResponseUpdate]:
                    yield ChatResponseUpdate(contents=contents, role="assistant", finish_reason="stop")

                return ResponseStream(_stream(), finalizer=lambda updates: ChatResponse.from_updates(updates))

            async def _get() -> ChatResponse:
                return ChatResponse(messages=[Message(role="assistant", contents=contents)], response_id="order-replay")

            return _get()

    refund_client = RefundReplayClient()
    refund_agent = Agent(
        id="refund_agent",
        name="refund_agent",
        client=refund_client,
        tools=[submit_refund],
        require_per_service_call_history_persistence=True,
    )
    order_agent = Agent(
        id="order_agent",
        name="order_agent",
        client=OrderReplayClient(),
        require_per_service_call_history_persistence=True,
    )
    workflow = (
        HandoffBuilder(participants=[refund_agent, order_agent], termination_condition=lambda _: False)
        .with_start_agent(refund_agent)
        .build()
    )

    first_events = await _drain(workflow.run("start", stream=True))
    approval_requests = [
        event for event in first_events if event.type == "request_info" and isinstance(event.data, Content)
    ]
    assert approval_requests
    approval_request = approval_requests[-1]
    approval_response = approval_request.data.to_function_approval_response(True)

    second_events = await _drain(workflow.run(stream=True, responses={approval_request.request_id: approval_response}))
    order_request = _latest_request_info_event(second_events)
    assert order_request.source_executor_id == order_agent.name

    await _drain(
        workflow.run(
            stream=True,
            responses={order_request.request_id: [Message(role="user", contents=["Please continue with refund."])]},
        )
    )

    assert refund_client.resume_validated is True


async def test_handoff_clone_preserves_per_service_call_history_persistence() -> None:
    """Handoff clones should keep per-service-call history persistence active for auto-handoff termination."""
    triage_history = InMemoryHistoryProvider()
    triage = Agent(
        id="triage",
        name="triage",
        client=MockChatClient(name="triage", handoff_to="specialist"),
        context_providers=[triage_history],
        require_per_service_call_history_persistence=True,
    )
    specialist = Agent(
        id="specialist",
        name="specialist",
        client=MockChatClient(name="specialist"),
        default_options={"tool_choice": "none"},
        require_per_service_call_history_persistence=True,
    )

    workflow = (
        HandoffBuilder(participants=[triage, specialist], termination_condition=lambda _: False)
        .with_start_agent(triage)
        .add_handoff(triage, [specialist])
        .add_handoff(specialist, [triage])
        .build()
    )

    await _drain(workflow.run("start", stream=True))

    executor = workflow.executors[resolve_agent_id(triage)]
    assert isinstance(executor, HandoffAgentExecutor)
    assert executor._agent.require_per_service_call_history_persistence is True

    provider_state = executor._session.state[triage_history.source_id]
    stored_messages = await triage_history.get_messages(
        executor._session.session_id,
        state=provider_state,
    )

    assert [message.role for message in stored_messages] == ["user", "assistant"]
    assert any(content.type == "function_call" for content in stored_messages[-1].contents)
    assert all(message.role != "tool" for message in stored_messages)


async def test_handoff_clone_preserves_all_middleware_types() -> None:
    """Handoff clones should preserve function and agent middleware from the original agent."""

    @function_middleware
    async def tracking_middleware(context: FunctionInvocationContext, call_next):
        await call_next()

    agent_a = Agent(
        id="agent_a",
        name="agent_a",
        client=MockChatClient(name="agent_a", handoff_to="agent_b"),
        middleware=[tracking_middleware],
        require_per_service_call_history_persistence=True,
    )
    agent_b = Agent(
        id="agent_b",
        name="agent_b",
        client=MockChatClient(name="agent_b"),
        default_options={"tool_choice": "none"},
        require_per_service_call_history_persistence=True,
    )

    workflow = (
        HandoffBuilder(participants=[agent_a, agent_b], termination_condition=lambda _: False)
        .with_start_agent(agent_a)
        .add_handoff(agent_a, [agent_b])
        .add_handoff(agent_b, [agent_a])
        .build()
    )

    executor = workflow.executors[resolve_agent_id(agent_a)]
    assert isinstance(executor, HandoffAgentExecutor)
    cloned_middleware = executor._agent.middleware or []
    assert tracking_middleware in cloned_middleware, "User function middleware should be preserved on cloned agent"


def test_clean_conversation_for_handoff_keeps_text_only_history() -> None:
    """Tool-control messages must be excluded from persisted handoff history."""
    function_call = Content.from_function_call(
        call_id="handoff-call-1",
        name="handoff_to_refund_agent",
        arguments={"context": "route to refund"},
    )
    approval_response = Content.from_function_approval_response(
        approved=True,
        id="approval-1",
        function_call=function_call,
    )

    conversation = [
        Message(role="user", contents=["My order arrived damaged."]),
        Message(
            role="assistant",
            contents=[
                function_call,
                Content.from_text(text="Triage Agent: Routing you to Refund."),
            ],
        ),
        Message(role="tool", contents=[Content.from_function_result(call_id="handoff-call-1", result="ok")]),
        Message(role="user", contents=[approval_response]),
        Message(
            role="assistant",
            contents=[Content.from_function_call(call_id="handoff-call-2", name="handoff_to_order_agent")],
        ),
    ]

    cleaned = clean_conversation_for_handoff(conversation)
    assert [message.role for message in cleaned] == ["user", "assistant"]
    assert [message.text for message in cleaned] == [
        "My order arrived damaged.",
        "Triage Agent: Routing you to Refund.",
    ]


async def test_autonomous_mode_yields_output_without_user_request():
    """Ensure autonomous interaction mode yields output without requesting user input."""
    triage = MockHandoffAgent(name="triage", handoff_to="specialist")
    specialist = MockHandoffAgent(name="specialist")

    workflow = (
        HandoffBuilder(
            participants=[triage, specialist],
            # This termination condition ensures the workflow runs through both agents.
            # The first message is the user message to triage, the second is triage's response
            # (a handoff to specialist), and the third is the specialist's response, which should
            # not request user input because of autonomous mode. The fourth message comes from the
            # specialist again and triggers termination.
            termination_condition=lambda conv: len(conv) >= 4,
        )
        .with_start_agent(triage)
        # Since specialist has no handoff, it generates normal responses.
        # With autonomous mode, this continues until the termination condition is met.
        .with_autonomous_mode(
            agents=[specialist],
            turn_limits={resolve_agent_id(specialist): 1},
        )
        .build()
    )

    events = await _drain(workflow.run("Package arrived broken", stream=True))
    requests = [ev for ev in events if ev.type == "request_info"]
    assert not requests, "Autonomous mode should not request additional user input"

    outputs = [ev for ev in events if ev.type == "output"]
    assert outputs, "Autonomous mode should yield a workflow output"

    # Per-agent activity surfaces as `output` events from each HandoffAgentExecutor as the
    # agents speak. Handoff has no orchestrator that produces a separate "answer"; the
    # conversation IS the result. In streaming mode the payloads are AgentResponseUpdate, and
    # the combined text should contain the specialist's reply.
    payloads = [ev.data for ev in outputs if isinstance(ev.data, (AgentResponse, AgentResponseUpdate))]
    combined = " ".join(
        getattr(p, "text", None) or " ".join(m.text for m in getattr(p, "messages", [])) for p in payloads
    )
    assert "specialist reply" in combined


async def test_autonomous_mode_resumes_user_input_on_turn_limit():
    """Autonomous mode should fall back to requesting user input once the turn limit is reached."""
    triage = MockHandoffAgent(name="triage", handoff_to="worker")
    worker = MockHandoffAgent(name="worker")

    workflow = (
        HandoffBuilder(participants=[triage, worker], termination_condition=lambda conv: False)
        .with_start_agent(triage)
        .with_autonomous_mode(agents=[worker], turn_limits={resolve_agent_id(worker): 2})
        .build()
    )

    events = await _drain(workflow.run("Start", stream=True))
    requests = [ev for ev in events if ev.type == "request_info"]
    assert requests and len(requests) == 1, "Turn limit should force a user input request"
    assert requests[0].source_executor_id == worker.name


def test_build_fails_without_start_agent():
    """Verify that build() raises ValueError when with_start_agent() was not called."""
    triage = MockHandoffAgent(name="triage")
    specialist = MockHandoffAgent(name="specialist")

    with pytest.raises(ValueError, match=r"Must call with_start_agent\(...\) before building the workflow."):
        HandoffBuilder(participants=[triage, specialist]).build()


def test_build_fails_without_participants():
    """Verify that build() raises ValueError when no participants are provided."""
    with pytest.raises(ValueError):
        HandoffBuilder(participants=[]).build()


async def test_handoff_async_termination_condition() -> None:
    """Test that async termination conditions work correctly."""
    termination_call_count = 0

    async def async_termination(conv: list[Message]) -> bool:
        nonlocal termination_call_count
        termination_call_count += 1
        user_count = sum(1 for msg in conv if msg.role == "user")
        return user_count >= 2

    coordinator = MockHandoffAgent(name="coordinator", handoff_to="worker")
    worker = MockHandoffAgent(name="worker")

    workflow = (
        HandoffBuilder(participants=[coordinator, worker], termination_condition=async_termination)
        .with_start_agent(coordinator)
        .build()
    )

    events = await _drain(workflow.run("First user message", stream=True))
    requests = [ev for ev in events if ev.type == "request_info"]
    assert requests

    events = await _drain(
        workflow.run(
            stream=True, responses={requests[-1].request_id: [Message(role="user", contents=["Second user message"])]}
        )
    )
    # Resume run terminates without further agent activity once the second user message
    # satisfies the termination condition. The workflow returns to idle cleanly.
    idle_states = [ev for ev in events if ev.type == "status" and ev.state == WorkflowRunState.IDLE]
    assert idle_states, "Workflow should become idle after termination"
    assert termination_call_count > 0


async def test_handoff_terminates_without_request_info_when_latest_response_meets_condition() -> None:
    """Termination triggered by the latest assistant response should not emit request_info."""

    class FinalizingClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
        def __init__(self) -> None:
            ChatMiddlewareLayer.__init__(self)
            FunctionInvocationLayer.__init__(self)
            BaseChatClient.__init__(self)

        def _inner_get_response(
            self,
            *,
            messages: Sequence[Message],
            stream: bool,
            options: Mapping[str, Any],
            **kwargs: Any,
        ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
            del messages, options, kwargs
            contents = [Content.from_text(text="Replacement request submitted. Case complete.")]

            if stream:

                async def _stream() -> AsyncIterable[ChatResponseUpdate]:
                    yield ChatResponseUpdate(contents=contents, role="assistant", finish_reason="stop")

                return ResponseStream(_stream(), finalizer=lambda updates: ChatResponse.from_updates(updates))

            async def _get() -> ChatResponse:
                return ChatResponse(messages=[Message(role="assistant", contents=contents)], response_id="finalizing")

            return _get()

    agent = Agent(
        id="order_agent",
        name="order_agent",
        client=FinalizingClient(),
        require_per_service_call_history_persistence=True,
    )
    workflow = (
        HandoffBuilder(
            participants=[agent],
            termination_condition=lambda conv: any(
                message.role == "assistant" and "case complete." in (message.text or "").lower() for message in conv
            ),
        )
        .with_start_agent(agent)
        .build()
    )

    events = await _drain(workflow.run("ship replacement", stream=True))

    requests = [event for event in events if event.type == "request_info"]
    assert not requests

    outputs = [event for event in events if event.type == "output"]
    assert outputs
    # Per-agent activity surfaces as output events (AgentResponseUpdate in streaming mode).
    agent_payloads = [event for event in outputs if isinstance(event.data, (AgentResponse, AgentResponseUpdate))]
    assert len(agent_payloads) >= 1


async def test_tool_choice_preserved_from_agent_config():
    """Verify that agent-level tool_choice configuration is preserved and not overridden."""
    # Create a mock chat client that records the tool_choice used
    recorded_tool_choices: list[Any] = []

    async def mock_get_response(messages: Any, options: dict[str, Any] | None = None, **kwargs: Any) -> ChatResponse:
        if options:
            recorded_tool_choices.append(options.get("tool_choice"))
        return ChatResponse(
            messages=[Message(role="assistant", contents=["Response"])],
            response_id="test_response",
        )

    mock_client = MagicMock()
    mock_client.get_response = AsyncMock(side_effect=mock_get_response)

    # Create agent with specific tool_choice configuration via default_options
    agent = Agent(
        client=mock_client,
        name="test_agent",
        default_options={"tool_choice": {"mode": "required"}},  # type: ignore
    )

    # Run the agent
    await agent.run("Test message")

    # Verify tool_choice was preserved
    assert len(recorded_tool_choices) > 0, "No tool_choice recorded"
    last_tool_choice = recorded_tool_choices[-1]
    assert last_tool_choice is not None, "tool_choice should not be None"
    assert last_tool_choice == {"mode": "required"}, f"Expected {{'mode': 'required'}}, got {last_tool_choice}"


async def test_context_provider_preserved_during_handoff():
    """Verify that context_providers are preserved when cloning agents in handoff workflows."""
    # Track whether context provider methods were called
    provider_calls: list[str] = []

    class TestContextProvider(ContextProvider):
        """A test context provider that tracks its invocations."""

        def __init__(self) -> None:
            super().__init__("test")

        async def before_run(self, **kwargs: Any) -> None:
            provider_calls.append("before_run")

    # Create context provider
    context_provider = TestContextProvider()

    # Create a mock chat client
    mock_client = MockChatClient(name="test_agent")

    # Create agent with context provider using proper constructor
    agent = Agent(
        client=mock_client,
        name="test_agent",
        id="test_agent",
        context_providers=[context_provider],
        require_per_service_call_history_persistence=True,
    )

    # Verify the original agent has the context provider
    assert context_provider in agent.context_providers, "Original agent should have context provider"

    # Build handoff workflow - this should clone the agent and preserve context_providers
    workflow = HandoffBuilder(participants=[agent]).with_start_agent(agent).build()

    # Run workflow with a simple message to trigger context provider
    await _drain(workflow.run("Test message", stream=True))

    # Verify context provider was invoked during the workflow execution
    assert len(provider_calls) > 0, (
        "Context provider should be called during workflow execution, "
        "indicating it was properly preserved during agent cloning"
    )


def test_handoff_builder_accepts_all_instances_in_add_handoff():
    """Test that add_handoff accepts agent instances when participants are also given as instances."""
    triage = MockHandoffAgent(name="triage", handoff_to="specialist_a")
    specialist_a = MockHandoffAgent(name="specialist_a")
    specialist_b = MockHandoffAgent(name="specialist_b")

    # Passing agent instances to both participants and add_handoff should succeed
    builder = (
        HandoffBuilder(participants=[triage, specialist_a, specialist_b])
        .with_start_agent(triage)
        .add_handoff(triage, [specialist_a, specialist_b])
    )

    workflow = builder.build()
    assert "triage" in workflow.executors
    assert "specialist_a" in workflow.executors
    assert "specialist_b" in workflow.executors


async def test_auto_handoff_middleware_intercepts_handoff_tool_call() -> None:
    """Middleware should short-circuit matching handoff tool calls with a synthetic result."""
    target_id = "specialist"
    middleware = _AutoHandoffMiddleware([HandoffConfiguration(target=target_id)])

    @tool(name=get_handoff_tool_name(target_id), approval_mode="never_require")
    def handoff_tool() -> None:
        pass

    context = FunctionInvocationContext(function=handoff_tool, arguments={})
    call_next = AsyncMock()

    with pytest.raises(MiddlewareTermination) as exc_info:
        await middleware.process(context, call_next)

    call_next.assert_not_awaited()
    expected_result = FunctionTool.parse_result({HANDOFF_FUNCTION_RESULT_KEY: target_id})
    assert context.result == expected_result
    assert exc_info.value.result == expected_result


async def test_auto_handoff_middleware_calls_next_for_non_handoff_tool() -> None:
    """Middleware should pass through when the function name is not a configured handoff tool."""
    middleware = _AutoHandoffMiddleware([HandoffConfiguration(target="specialist")])

    @tool(name="regular_tool", approval_mode="never_require")
    def regular_tool() -> str:
        return "ok"

    context = FunctionInvocationContext(function=regular_tool, arguments={})
    call_next = AsyncMock()

    await middleware.process(context, call_next)

    call_next.assert_awaited_once()
    assert context.result is None


def test_handoff_builder_rejects_agents_without_per_service_call_history_persistence() -> None:
    """HandoffBuilder.build() should reject agents missing require_per_service_call_history_persistence."""
    agent_without_flag = Agent(
        client=MockChatClient(name="no_flag"),
        name="no_flag",
        id="no_flag",
        # require_per_service_call_history_persistence defaults to False
    )
    agent_with_flag = MockHandoffAgent(name="has_flag")  # MockHandoffAgent sets flag to True

    with pytest.raises(ValueError, match="require_per_service_call_history_persistence"):
        HandoffBuilder(participants=[agent_without_flag, agent_with_flag]).with_start_agent(agent_with_flag).build()


def test_handoff_builder_rejects_non_agent_supports_agent_run():
    """Verify that participants() rejects SupportsAgentRun implementations that are not Agent instances."""
    from agent_framework import AgentResponse, AgentSession, SupportsAgentRun

    class FakeAgentRun:
        def __init__(self, id, name):
            self.id = id
            self.name = name
            self.description = "d"

        async def run(self, messages=None, *, stream=False, session=None, **kwargs):
            return AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("ok")])])

        def create_session(self, **kwargs):
            return AgentSession()

        def get_session(self, *, service_session_id, **kwargs):
            return AgentSession(service_session_id=service_session_id)

    fake = FakeAgentRun("a", "A")
    assert isinstance(fake, SupportsAgentRun)

    with pytest.raises(TypeError, match="Participants must be Agent instances"):
        HandoffBuilder().participants([fake])


# endregion

# region integration tests


try:
    from agent_framework.foundry import FoundryChatClient
    from azure.identity import AzureCliCredential

    _has_foundry_deps = True
except ImportError:
    _has_foundry_deps = False

skip_if_foundry_integration_tests_disabled = pytest.mark.skipif(
    not _has_foundry_deps or os.getenv("FOUNDRY_PROJECT_ENDPOINT", "") == "" or os.getenv("FOUNDRY_MODEL", "") == "",
    reason="No real FOUNDRY_PROJECT_ENDPOINT or FOUNDRY_MODEL provided; skipping integration tests.",
)


@pytest.mark.integration
@skip_if_foundry_integration_tests_disabled
@pytest.mark.parametrize("store", [param(False, id="store=False"), param(True, id="store=True")])
async def test_simple_handoff_workflow(store: bool) -> None:
    """Test a simple handoff workflow with two agents."""
    client = FoundryChatClient(
        project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"],
        model=os.environ["FOUNDRY_MODEL"],
        credential=AzureCliCredential(),
    )

    triage_agent = Agent(
        client=client,
        instructions=(
            "You are frontline support triage. Route customer issues to the appropriate specialist agents "
            "based on the problem described."
        ),
        name="triage_agent",
        default_options={"store": store},
        require_per_service_call_history_persistence=True,
    )

    refund_agent = Agent(
        client=client,
        instructions="You process refund requests. Ask user the ID of the order they want refunded.",
        name="refund_agent",
        default_options={"store": store},
        require_per_service_call_history_persistence=True,
    )

    workflow = (
        HandoffBuilder(
            participants=[triage_agent, refund_agent],
            termination_condition=lambda conversation: (
                # Terminate as soon as the refund agent responds, which confirms that triage handed off to refund
                len(conversation) > 0 and conversation[-1].author_name == refund_agent.name
            ),
        )
        .with_start_agent(triage_agent)
        .build()
    )

    workflow_result = await workflow.run("I want to get a refund")
    # The workflow should end in IDLE state rather than IDLE_WITH_PENDING_REQUESTS
    # because the termination condition is met right after the refund agent's response.
    assert workflow_result.get_final_state() == WorkflowRunState.IDLE
    # Outputs should contain the responses from both agents plus a final output with the full conversation.
    assert len(workflow_result.get_outputs()) == 3
    # There will be exactly one handoff_sent event
    handoff_event = [event for event in workflow_result if event.type == "handoff_sent"]
    assert len(handoff_event) == 1
    assert isinstance(handoff_event[0].data, HandoffSentEvent)
    assert handoff_event[0].data.source == triage_agent.name
    assert handoff_event[0].data.target == refund_agent.name


@pytest.mark.integration
@skip_if_foundry_integration_tests_disabled
@pytest.mark.parametrize("store", [param(False, id="store=False"), param(True, id="store=True")])
async def test_simple_handoff_workflow_with_request_and_response(store: bool) -> None:
    """Test a simple handoff workflow with two agents where the second agent makes a request after handoff."""
    client = FoundryChatClient(
        project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"],
        model=os.environ["FOUNDRY_MODEL"],
        credential=AzureCliCredential(),
    )

    triage_agent = Agent(
        client=client,
        instructions=(
            "You are frontline support triage. Route customer issues to the appropriate specialist agents "
            "based on the problem described."
        ),
        name="triage_agent",
        default_options={"store": store},
        require_per_service_call_history_persistence=True,
    )

    refund_agent = Agent(
        client=client,
        instructions="You process refund requests. Ask user the ID of the order they want refunded.",
        name="refund_agent",
        default_options={"store": store},
        require_per_service_call_history_persistence=True,
    )

    workflow = (
        HandoffBuilder(
            participants=[triage_agent, refund_agent],
            termination_condition=lambda conversation: (
                # We terminate after the refund agent requests user input and the user provides
                # a response. At that point there are two user messages in the conversation:
                # the original user message and the follow-up reply to the refund agent's request.
                len([message for message in conversation if message.role == "user"]) == 2
            ),
        )
        .with_start_agent(triage_agent)
        .build()
    )

    workflow_result = await workflow.run("I want to get a refund")
    # The workflow should end in IDLE_WITH_PENDING_REQUESTS state rather than IDLE
    # because the user has not yet responded to the refund agent's request.
    assert workflow_result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS
    # There will be exactly one handoff_sent event
    handoff_event = [event for event in workflow_result if event.type == "handoff_sent"]
    assert len(handoff_event) == 1
    assert isinstance(handoff_event[0].data, HandoffSentEvent)
    assert handoff_event[0].data.source == triage_agent.name
    assert handoff_event[0].data.target == refund_agent.name
    # There should be exactly one request for information from the refund agent after handoff
    request_events = [event for event in workflow_result if event.type == "request_info"]
    assert len(request_events) == 1
    assert isinstance(request_events[0].data, HandoffAgentUserRequest)
    # Provide the user's response to the refund agent's request to allow the workflow to complete.
    workflow_result = await workflow.run(
        responses={
            request_events[0].request_id: HandoffAgentUserRequest.create_response("My order number is 12345"),
        },
    )

    # The workflow should now end in IDLE state since the termination condition
    # is met after the user's response to the refund agent's request.
    assert workflow_result.get_final_state() == WorkflowRunState.IDLE


@tool(approval_mode="always_require")
def process_refund(order_number: Annotated[str, "Order number to process refund for"]) -> str:
    """Simulated function to process a refund for a given order number."""
    return f"Refund processed successfully for order {order_number}."


@pytest.mark.integration
@skip_if_foundry_integration_tests_disabled
@pytest.mark.parametrize("store", [param(False, id="store=False"), param(True, id="store=True")])
async def test_simple_handoff_workflow_with_approval_request(store: bool) -> None:
    """Test a handoff workflow where the second agent requests user input and then tool approval after handoff."""
    client = FoundryChatClient(
        project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"],
        model=os.environ["FOUNDRY_MODEL"],
        credential=AzureCliCredential(),
    )

    triage_agent = Agent(
        client=client,
        instructions=(
            "You are frontline support triage. Route customer issues to the appropriate specialist agents "
            "based on the problem described."
        ),
        name="triage_agent",
        default_options={"store": store},
        require_per_service_call_history_persistence=True,
    )

    refund_agent = Agent(
        client=client,
        instructions="You process refund requests. Ask user the ID of the order they want refunded.",
        name="refund_agent",
        default_options={"store": store},
        tools=[process_refund],
        require_per_service_call_history_persistence=True,
    )

    # This workflow will be terminated manually
    workflow = (
        HandoffBuilder(
            participants=[triage_agent, refund_agent],
        )
        .with_start_agent(triage_agent)
        .build()
    )

    workflow_result = await workflow.run("I want to get a refund")
    # The workflow should end in IDLE_WITH_PENDING_REQUESTS state rather than IDLE
    # because the user has not yet responded to the refund agent's request.
    assert workflow_result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS
    # There will be exactly one handoff_sent event
    handoff_event = [event for event in workflow_result if event.type == "handoff_sent"]
    assert len(handoff_event) == 1
    assert isinstance(handoff_event[0].data, HandoffSentEvent)
    assert handoff_event[0].data.source == triage_agent.name
    assert handoff_event[0].data.target == refund_agent.name
    # There should be exactly one request for information from the refund agent after handoff
    request_events = [event for event in workflow_result if event.type == "request_info"]
    assert len(request_events) == 1
    assert isinstance(request_events[0].data, HandoffAgentUserRequest)
    # Provide the user's response to the refund agent's request so the workflow can continue.
    workflow_result = await workflow.run(
        responses={
            request_events[0].request_id: HandoffAgentUserRequest.create_response("My order number is 12345"),
        },
    )

    # The workflow should now end in IDLE_WITH_PENDING_REQUESTS state since the refund agent will ask for
    # approval to process the refund after receiving the user's response.
    assert workflow_result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS

    # There should be exactly one request for tool approval from the refund agent.
    request_events = [event for event in workflow_result if event.type == "request_info"]
    assert len(request_events) == 1
    assert isinstance(request_events[0].data, Content) and request_events[0].data.type == "function_approval_request"

    # Approve the refund tool call so the refund agent can process the refund.
    workflow_result = await workflow.run(
        responses={request_events[0].request_id: request_events[0].data.to_function_approval_response(approved=True)}
    )

    # The refund agent will process the refund after receiving approval, but since there is no termination condition,
    # the workflow will end in IDLE_WITH_PENDING_REQUESTS state waiting for further user input.
    assert workflow_result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS
    # There should be exactly one request for information from the refund agent after processing the refund,
    # which is the follow-up question asking if there is anything else they can help with.
    request_events = [event for event in workflow_result if event.type == "request_info"]
    assert len(request_events) == 1
    assert isinstance(request_events[0].data, HandoffAgentUserRequest)
    workflow_result = await workflow.run(responses={request_events[0].request_id: HandoffAgentUserRequest.terminate()})

    assert workflow_result.get_final_state() == WorkflowRunState.IDLE


# endregion