mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
b065a4ce51
* Rename provider base APIs Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Allow provider-added chat and function middleware Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Simulate service-stored history per model call Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix typing regressions in CI Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix response ID suppression review feedback Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Rename per-service-call history persistence APIs Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address context persistence review feedback Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Stabilize markdown sample docs Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Persist service continuation state per call Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1163 lines
45 KiB
Python
1163 lines
45 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
import re
|
|
from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence
|
|
from typing import Any, cast
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
from agent_framework import (
|
|
Agent,
|
|
ChatResponse,
|
|
ChatResponseUpdate,
|
|
Content,
|
|
ContextProvider,
|
|
InMemoryHistoryProvider,
|
|
Message,
|
|
ResponseStream,
|
|
WorkflowEvent,
|
|
resolve_agent_id,
|
|
tool,
|
|
)
|
|
from agent_framework._clients import BaseChatClient
|
|
from agent_framework._middleware import ChatMiddlewareLayer, FunctionInvocationContext, MiddlewareTermination
|
|
from agent_framework._tools import FunctionInvocationLayer, FunctionTool
|
|
from agent_framework.orchestrations import HandoffAgentUserRequest, HandoffBuilder
|
|
|
|
from agent_framework_orchestrations._handoff import (
|
|
HANDOFF_FUNCTION_RESULT_KEY,
|
|
HandoffAgentExecutor,
|
|
HandoffConfiguration,
|
|
_AutoHandoffMiddleware, # pyright: ignore[reportPrivateUsage]
|
|
get_handoff_tool_name,
|
|
)
|
|
from agent_framework_orchestrations._orchestrator_helpers import clean_conversation_for_handoff
|
|
|
|
|
|
class MockChatClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
    """Scripted chat client used by the handoff workflow tests.

    Every call produces one assistant turn ending with the text ``"<name> reply"``. When a
    ``handoff_to`` target is configured, that turn is preceded by a ``handoff_to_<target>``
    function call carrying a fresh call ID, so the owning agent always attempts a handoff.
    """

    def __init__(
        self,
        *,
        name: str = "",
        handoff_to: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Set up the mock chat client.

        Args:
            name: Name of the agent that owns this client; used in reply text and call IDs.
            handoff_to: Agent to hand off to on every call, or None to never hand off.
                This is hardcoded for testing purposes so that the agent always attempts to hand off.
            **kwargs: Accepted for signature compatibility; ignored.
        """
        # Initialize each layer explicitly, in middleware -> function-invocation -> base order.
        ChatMiddlewareLayer.__init__(self)
        FunctionInvocationLayer.__init__(self)
        BaseChatClient.__init__(self)
        self._name = name
        self._handoff_to = handoff_to
        self._call_index = 0

    def _inner_get_response(
        self,
        *,
        messages: Sequence[Message],
        stream: bool,
        options: Mapping[str, Any],
        **kwargs: Any,
    ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
        if stream:
            return self._build_streaming_response(options=dict(options))

        async def _respond() -> ChatResponse:
            # The call ID is only allocated once the coroutine is actually awaited.
            call_id = self._next_call_id()
            reply_message = Message(
                role="assistant",
                contents=_build_reply_contents(self._name, self._handoff_to, call_id),
            )
            return ChatResponse(messages=reply_message, response_id="mock_response")

        return _respond()

    def _build_streaming_response(self, *, options: dict[str, Any]) -> ResponseStream[ChatResponseUpdate, ChatResponse]:
        """Return a single-update stream whose finalizer honors a class-typed ``response_format``."""

        async def _updates() -> AsyncIterable[ChatResponseUpdate]:
            call_id = self._next_call_id()
            yield ChatResponseUpdate(
                contents=_build_reply_contents(self._name, self._handoff_to, call_id),
                role="assistant",
                finish_reason="stop",
            )

        def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse:
            requested_format = options.get("response_format")
            format_type = requested_format if isinstance(requested_format, type) else None
            return ChatResponse.from_updates(updates, output_format_type=format_type)

        return ResponseStream(_updates(), finalizer=_finalize)

    def _next_call_id(self) -> str | None:
        """Allocate the next handoff call ID, or return None when this client never hands off."""
        if self._handoff_to:
            call_id = f"{self._name}-handoff-{self._call_index}"
            self._call_index += 1
            return call_id
        return None
|
|
|
|
|
|
def _build_reply_contents(
    agent_name: str,
    handoff_to: str | None,
    call_id: str | None,
) -> list[Content]:
    """Build the contents of a scripted assistant reply.

    Args:
        agent_name: Name used in the trailing ``"<agent_name> reply"`` text.
        handoff_to: Target agent. A ``handoff_to_<target>`` function call is prepended only
            when both this and ``call_id`` are truthy.
        call_id: Call ID for the prepended handoff function call.

    Returns:
        The optional handoff function call followed by the reply text.
    """
    if handoff_to and call_id:
        leading: list[Content] = [
            Content.from_function_call(
                call_id=call_id, name=f"handoff_to_{handoff_to}", arguments={"handoff_to": handoff_to}
            )
        ]
    else:
        leading = []
    return [*leading, Content.from_text(text=f"{agent_name} reply")]
|
|
|
|
|
|
class MockHandoffAgent(Agent):
    """Agent backed by a ``MockChatClient`` that can be scripted to always hand off."""

    def __init__(self, *, name: str, handoff_to: str | None = None) -> None:
        """Create the mock agent.

        Args:
            name: Agent name; also used as the agent ID.
            handoff_to: The name of the agent to hand off to, or None for no handoff.
                This is hardcoded for testing purposes so that the agent always attempts to hand off.
        """
        scripted_client = MockChatClient(name=name, handoff_to=handoff_to)
        super().__init__(client=scripted_client, name=name, id=name)
|
|
|
|
|
|
class ContextAwareRefundClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
    """Mock client that expects prior user context to remain available on resume.

    Replies are derived from the text of *all* user messages in the request. A reply can
    therefore only mention the order number when the earlier user turn containing it is
    still present in ``messages``.
    """

    def __init__(self) -> None:
        ChatMiddlewareLayer.__init__(self)
        FunctionInvocationLayer.__init__(self)
        BaseChatClient.__init__(self)
        self._call_index = 0

    def _inner_get_response(
        self,
        *,
        messages: Sequence[Message],
        stream: bool,
        options: Mapping[str, Any],
        **kwargs: Any,
    ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
        del kwargs
        del options

        # Contents are chosen eagerly, at call time, so the call index advances per request.
        contents = self._next_contents(messages)
        if stream:
            return self._build_streaming_response(contents)

        async def _get() -> ChatResponse:
            return ChatResponse(messages=[Message(role="assistant", contents=contents)], response_id="context-aware")

        return _get()

    def _build_streaming_response(self, contents: list[Content]) -> ResponseStream[ChatResponseUpdate, ChatResponse]:
        """Wrap precomputed ``contents`` in a single-update response stream."""

        async def _stream() -> AsyncIterable[ChatResponseUpdate]:
            yield ChatResponseUpdate(contents=contents, role="assistant", finish_reason="stop")

        def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse:
            return ChatResponse.from_updates(updates)

        return ResponseStream(_stream(), finalizer=_finalize)

    def _next_contents(self, messages: Sequence[Message]) -> list[Content]:
        """Pick the scripted reply for the current call and advance the call index.

        Call 0 always asks for the order number. Call 1 acknowledges an order ID if one is
        visible in any user message. Later calls confirm the refund only when both an order
        ID and a refund reason keyword are visible.
        """
        user_text = " ".join(message.text or "" for message in messages if message.role == "user")
        # Lower-case once instead of once per keyword inside the any(...) generator.
        lowered_user_text = user_text.lower()
        order_match = re.search(r"\b(\d{4,12})\b", user_text)
        order_id = order_match.group(1) if order_match else None
        asks_refund = any(token in lowered_user_text for token in ("broken", "damaged", "refund", "cracked"))

        if self._call_index == 0:
            reply = "Refund Agent: Please share your order number."
        elif self._call_index == 1:
            if order_id:
                reply = f"Refund Agent: Thanks, I found order {order_id}. Why do you need the refund?"
            else:
                reply = "Refund Agent: I still need your order number."
        elif order_id and asks_refund:
            reply = f"Refund Agent: Got it for order {order_id}. I can proceed with your refund."
        else:
            reply = "Refund Agent: I still need your order number."

        self._call_index += 1
        return [Content.from_text(text=reply)]
|
|
|
|
|
|
async def _drain(stream: AsyncIterable[WorkflowEvent]) -> list[WorkflowEvent]:
    """Consume a workflow event stream to completion and return every event in arrival order."""
    collected: list[WorkflowEvent] = []
    async for event in stream:
        collected.append(event)
    return collected
|
|
|
|
|
|
async def test_handoff():
    """Agents hand off along a chain, and the last agent in the chain asks the user for input."""
    # Chain: triage -> specialist -> escalation. Escalation never hands off, so its reply is
    # turned into a user-input request (autonomous mode is off by default).
    triage = MockHandoffAgent(name="triage", handoff_to="specialist")
    specialist = MockHandoffAgent(name="specialist", handoff_to="escalation")
    escalation = MockHandoffAgent(name="escalation")

    def _reached_second_user_turn(conversation: list[Message]) -> bool:
        return sum(message.role == "user" for message in conversation) >= 2

    # No explicit handoffs: the builder connects every participant to every other one.
    builder = HandoffBuilder(
        participants=[triage, specialist, escalation],
        termination_condition=_reached_second_user_turn,
    )
    workflow = builder.with_start_agent(triage).build()

    events = await _drain(workflow.run("Need technical support", stream=True))
    pending_requests = [event for event in events if event.type == "request_info"]

    assert len(pending_requests) == 1
    (request,) = pending_requests
    assert isinstance(request.data, HandoffAgentUserRequest)
    assert request.source_executor_id == escalation.name
|
|
|
|
|
|
def _latest_request_info_event(events: list[WorkflowEvent]) -> WorkflowEvent[Any]:
    """Return the most recent ``request_info`` event, asserting it carries a ``HandoffAgentUserRequest``."""
    latest = next((event for event in reversed(events) if event.type == "request_info"), None)
    assert latest is not None
    assert isinstance(latest.data, HandoffAgentUserRequest)
    return latest
|
|
|
|
|
|
def _request_text(event: WorkflowEvent[Any]) -> str:
    """Return the text of the final agent message attached to a handoff user request ("" if none)."""
    payload = cast(HandoffAgentUserRequest, event.data)
    agent_messages = payload.agent_response.messages
    assert agent_messages
    last_text = agent_messages[-1].text
    return last_text if last_text else ""
|
|
|
|
|
|
async def test_resume_keeps_prior_user_context_for_same_agent() -> None:
    """Resuming the same agent after a request_info must keep earlier user turns in context."""
    refund_agent = Agent(
        id="refund_agent",
        name="refund_agent",
        client=ContextAwareRefundClient(),
    )
    workflow = (
        HandoffBuilder(participants=[refund_agent], termination_condition=lambda _: False)
        .with_start_agent(refund_agent)
        .build()
    )

    async def _answer(request: WorkflowEvent[Any], text: str) -> WorkflowEvent[Any]:
        # Reply to a pending user request and return the next pending user request.
        resumed_events = await _drain(
            workflow.run(
                stream=True,
                responses={request.request_id: [Message(role="user", text=text)]},
            )
        )
        return _latest_request_info_event(resumed_events)

    opening_events = await _drain(workflow.run("My order arrived damaged.", stream=True))
    first_request = _latest_request_info_event(opening_events)
    assert "order number" in _request_text(first_request).lower()

    # Turn 2: the order number is supplied and must be acknowledged.
    second_request = await _answer(first_request, "Order 2939393")
    second_text = _request_text(second_request).lower()
    assert "order 2939393" in second_text
    assert "order number" not in second_text

    # Turn 3: the order number only appears in an earlier turn, so it must survive the resume.
    third_request = await _answer(second_request, "It arrived broken and unusable.")
    third_text = _request_text(third_request).lower()
    assert "order 2939393" in third_text
    assert "order number" not in third_text
|
|
|
|
|
|
async def test_tool_approval_responses_are_not_replayed_from_history() -> None:
    """Ensure persisted history does not re-execute previously approved tool calls."""
    execution_count = 0

    # Counts real executions; the test expects exactly one across all turns.
    @tool(name="submit_refund_counted", approval_mode="always_require")
    def submit_refund_counted() -> str:
        nonlocal execution_count
        execution_count += 1
        return "ok"

    class ApprovalReplayClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
        # Call 0 requests the tool; call 1 acknowledges it; every later call is a plain reply.
        def __init__(self) -> None:
            ChatMiddlewareLayer.__init__(self)
            FunctionInvocationLayer.__init__(self)
            BaseChatClient.__init__(self)
            self._call_index = 0

        def _inner_get_response(
            self,
            *,
            messages: Sequence[Message],
            stream: bool,
            options: Mapping[str, Any],
            **kwargs: Any,
        ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
            del messages, options, kwargs

            contents = self._scripted_contents()
            if not stream:

                async def _get() -> ChatResponse:
                    return ChatResponse(
                        messages=[Message(role="assistant", contents=contents)],
                        response_id="approval-replay",
                    )

                return _get()

            async def _stream() -> AsyncIterable[ChatResponseUpdate]:
                yield ChatResponseUpdate(contents=contents, role="assistant", finish_reason="stop")

            return ResponseStream(_stream(), finalizer=lambda updates: ChatResponse.from_updates(updates))

        def _scripted_contents(self) -> list[Content]:
            call_number = self._call_index
            self._call_index += 1
            if call_number == 0:
                return [
                    Content.from_function_call(
                        call_id="refund-call-1",
                        name="submit_refund_counted",
                        arguments={},
                    )
                ]
            if call_number == 1:
                return [Content.from_text(text="Refund approved and recorded.")]
            return [Content.from_text(text="No additional tool work needed.")]

    agent = Agent(
        id="refund_agent",
        name="refund_agent",
        client=ApprovalReplayClient(),
        tools=[submit_refund_counted],
    )
    workflow = (
        HandoffBuilder(participants=[agent], termination_condition=lambda _: False).with_start_agent(agent).build()
    )

    # Turn 1: the model requests the tool, which needs approval.
    first_events = await _drain(workflow.run("start", stream=True))
    first_requests = [event for event in first_events if event.type == "request_info"]
    assert first_requests
    approval_request = first_requests[-1]
    assert isinstance(approval_request.data, Content)
    approval_response = approval_request.data.to_function_approval_response(approved=True)

    # Turn 2: approving runs the tool; the agent then asks the user for input.
    second_events = await _drain(workflow.run(stream=True, responses={approval_request.request_id: approval_response}))
    user_request = _latest_request_info_event(second_events)

    # Turn 3: a plain user reply must not cause the earlier approval to be replayed.
    follow_up = [Message(role="user", text="Thanks, what's next?")]
    await _drain(workflow.run(stream=True, responses={user_request.request_id: follow_up}))

    assert execution_count == 1
|
|
|
|
|
|
async def test_handoff_resume_preserves_approval_function_call_for_stateless_runs() -> None:
    """Approval resume turns must replay matching function calls when store=False."""

    @tool(name="submit_refund", approval_mode="always_require")
    def submit_refund() -> str:
        return "ok"

    def _first_orphaned_result_id(messages: Sequence[Message]) -> str | None:
        # Smallest call_id that has a function_result but no matching function_call in `messages`.
        call_ids = {
            content.call_id
            for message in messages
            for content in message.contents
            if content.type == "function_call" and content.call_id
        }
        result_ids = {
            content.call_id
            for message in messages
            for content in message.contents
            if content.type == "function_result" and content.call_id
        }
        orphaned = result_ids - call_ids
        return min(orphaned) if orphaned else None

    class StrictStatelessApprovalClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
        # Call 0 requests the tool. Later calls reject a request in which any function_result
        # lacks its function_call, mimicking a strict stateless service.
        def __init__(self) -> None:
            ChatMiddlewareLayer.__init__(self)
            FunctionInvocationLayer.__init__(self)
            BaseChatClient.__init__(self)
            self._call_index = 0
            self.resume_validated = False

        def _inner_get_response(
            self,
            *,
            messages: Sequence[Message],
            stream: bool,
            options: Mapping[str, Any],
            **kwargs: Any,
        ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
            del options, kwargs

            if self._call_index == 0:
                contents = [
                    Content.from_function_call(
                        call_id="refund-call-1",
                        name="submit_refund",
                        arguments={},
                    )
                ]
            else:
                # Validate eagerly, before the call index advances.
                orphan_id = _first_orphaned_result_id(messages)
                if orphan_id is not None:
                    raise AssertionError(f"No tool call found for function call output with call_id {orphan_id}.")
                self.resume_validated = True
                contents = [Content.from_text(text="Refund submitted.")]

            self._call_index += 1

            if not stream:

                async def _get() -> ChatResponse:
                    return ChatResponse(
                        messages=[Message(role="assistant", contents=contents)],
                        response_id="strict-stateless",
                    )

                return _get()

            async def _stream() -> AsyncIterable[ChatResponseUpdate]:
                yield ChatResponseUpdate(contents=contents, role="assistant", finish_reason="stop")

            return ResponseStream(_stream(), finalizer=lambda updates: ChatResponse.from_updates(updates))

    client = StrictStatelessApprovalClient()
    agent = Agent(
        id="refund_agent",
        name="refund_agent",
        client=client,
        tools=[submit_refund],
    )
    workflow = (
        HandoffBuilder(participants=[agent], termination_condition=lambda _: False).with_start_agent(agent).build()
    )

    first_events = await _drain(workflow.run("start", stream=True))
    approval_request = next(
        (event for event in first_events if event.type == "request_info" and isinstance(event.data, Content)),
        None,
    )
    assert approval_request is not None

    approval_response = approval_request.data.to_function_approval_response(True)
    await _drain(workflow.run(stream=True, responses={approval_request.request_id: approval_response}))

    assert client.resume_validated is True
|
|
|
|
|
|
async def test_handoff_replay_serializes_handoff_function_results() -> None:
    """Returning to the same agent must not replay dict tool outputs."""

    class ReplaySafeHandoffClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
        # Follows `handoff_sequence` call by call (None, or past the end, means a plain reply)
        # and rejects any request containing a function_result whose payload is still a dict.
        def __init__(self, name: str, handoff_sequence: list[str | None]) -> None:
            ChatMiddlewareLayer.__init__(self)
            FunctionInvocationLayer.__init__(self)
            BaseChatClient.__init__(self)
            self._name = name
            self._handoff_sequence = handoff_sequence
            self._call_index = 0

        def _inner_get_response(
            self,
            *,
            messages: Sequence[Message],
            stream: bool,
            options: Mapping[str, Any],
            **kwargs: Any,
        ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
            del options, kwargs

            leaked_dict_result = any(
                content.type == "function_result" and isinstance(content.result, dict)
                for message in messages
                for content in message.contents
            )
            if leaked_dict_result:
                raise AssertionError("Expected replayed function_result payloads to be JSON strings.")

            index = self._call_index
            handoff_to = self._handoff_sequence[index] if index < len(self._handoff_sequence) else None
            call_id = f"{self._name}-handoff-{index}" if handoff_to else None
            contents = _build_reply_contents(self._name, handoff_to, call_id)
            self._call_index = index + 1

            if not stream:

                async def _get() -> ChatResponse:
                    return ChatResponse(messages=[Message(role="assistant", contents=contents)], response_id="replay-safe")

                return _get()

            async def _stream() -> AsyncIterable[ChatResponseUpdate]:
                yield ChatResponseUpdate(contents=contents, role="assistant", finish_reason="stop")

            return ResponseStream(_stream(), finalizer=lambda updates: ChatResponse.from_updates(updates))

    # Flow: triage -> specialist -> back to triage, whose second call is a plain reply.
    triage = Agent(
        id="triage",
        name="triage",
        client=ReplaySafeHandoffClient(name="triage", handoff_sequence=["specialist", None]),
    )
    specialist = Agent(
        id="specialist",
        name="specialist",
        client=ReplaySafeHandoffClient(name="specialist", handoff_sequence=["triage"]),
    )

    workflow = (
        HandoffBuilder(participants=[triage, specialist], termination_condition=lambda _: False)
        .with_start_agent(triage)
        .build()
    )

    events = await _drain(workflow.run("start", stream=True))
    user_requests = [event for event in events if event.type == "request_info"]
    assert user_requests
    assert user_requests[-1].source_executor_id == triage.name
|
|
|
|
|
|
async def test_handoff_resume_preserves_approved_tool_output_for_stateless_runs() -> None:
    """Approved calls must keep function_call/function_result pairs for later replays."""
    submit_call_id = "call_submit_refund_approved"

    @tool(name="submit_refund", approval_mode="always_require")
    def submit_refund() -> str:
        return "submitted"

    def _scripted_response(
        contents: list[Content],
        *,
        stream: bool,
        response_id: str,
    ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
        # Emit `contents` as one assistant turn, either streamed or as a single response.
        if stream:

            async def _stream() -> AsyncIterable[ChatResponseUpdate]:
                yield ChatResponseUpdate(contents=contents, role="assistant", finish_reason="stop")

            return ResponseStream(_stream(), finalizer=lambda updates: ChatResponse.from_updates(updates))

        async def _get() -> ChatResponse:
            return ChatResponse(messages=[Message(role="assistant", contents=contents)], response_id=response_id)

        return _get()

    class RefundReplayClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
        # Call 0 requests the tool, call 1 hands off to order_agent, and later calls check that a
        # replayed submit_refund call still has its function_result.
        def __init__(self) -> None:
            ChatMiddlewareLayer.__init__(self)
            FunctionInvocationLayer.__init__(self)
            BaseChatClient.__init__(self)
            self._call_index = 0
            self.resume_validated = False

        def _inner_get_response(
            self,
            *,
            messages: Sequence[Message],
            stream: bool,
            options: Mapping[str, Any],
            **kwargs: Any,
        ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
            del options, kwargs

            if self._call_index == 0:
                contents = [Content.from_function_call(call_id=submit_call_id, name="submit_refund", arguments={})]
            elif self._call_index == 1:
                contents = _build_reply_contents("refund_agent", "order_agent", "refund-order-handoff-1")
            else:
                replayed = [content for message in messages for content in message.contents]
                call_replayed = any(
                    content.type == "function_call" and content.call_id == submit_call_id for content in replayed
                )
                result_replayed = any(
                    content.type == "function_result" and content.call_id == submit_call_id for content in replayed
                )
                if call_replayed and not result_replayed:
                    raise AssertionError(f"No tool output found for function call {submit_call_id}.")
                self.resume_validated = True
                contents = [Content.from_text(text="Refund agent resumed.")]

            self._call_index += 1
            return _scripted_response(contents, stream=stream, response_id="refund-replay")

    class OrderReplayClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
        # Call 0 asks the user a question; every later call hands back to refund_agent.
        def __init__(self) -> None:
            ChatMiddlewareLayer.__init__(self)
            FunctionInvocationLayer.__init__(self)
            BaseChatClient.__init__(self)
            self._call_index = 0

        def _inner_get_response(
            self,
            *,
            messages: Sequence[Message],
            stream: bool,
            options: Mapping[str, Any],
            **kwargs: Any,
        ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
            del messages, options, kwargs

            if self._call_index == 0:
                contents = [Content.from_text(text="Would you like a replacement or a refund?")]
            else:
                contents = _build_reply_contents("order_agent", "refund_agent", "order-refund-handoff-1")
            self._call_index += 1
            return _scripted_response(contents, stream=stream, response_id="order-replay")

    refund_client = RefundReplayClient()
    refund_agent = Agent(
        id="refund_agent",
        name="refund_agent",
        client=refund_client,
        tools=[submit_refund],
    )
    order_agent = Agent(
        id="order_agent",
        name="order_agent",
        client=OrderReplayClient(),
    )
    workflow = (
        HandoffBuilder(participants=[refund_agent, order_agent], termination_condition=lambda _: False)
        .with_start_agent(refund_agent)
        .build()
    )

    # Turn 1: refund_agent requests the tool, which needs approval.
    first_events = await _drain(workflow.run("start", stream=True))
    approval_requests = [
        event for event in first_events if event.type == "request_info" and isinstance(event.data, Content)
    ]
    assert approval_requests
    approval_request = approval_requests[-1]
    approval_response = approval_request.data.to_function_approval_response(True)

    # Turn 2: after approval, refund_agent hands off and order_agent asks the user a question.
    second_events = await _drain(workflow.run(stream=True, responses={approval_request.request_id: approval_response}))
    order_request = _latest_request_info_event(second_events)
    assert order_request.source_executor_id == order_agent.name

    # Turn 3: order_agent hands back to refund_agent, whose replayed history is validated.
    user_reply = [Message(role="user", text="Please continue with refund.")]
    await _drain(workflow.run(stream=True, responses={order_request.request_id: user_reply}))

    assert refund_client.resume_validated is True
|
|
|
|
|
|
def test_handoff_clone_disables_provider_side_storage() -> None:
    """Handoff executors should force store=False to avoid stale provider call state."""
    triage = MockHandoffAgent(name="triage")
    workflow = HandoffBuilder(participants=[triage]).with_start_agent(triage).build()

    triage_executor = workflow.executors[resolve_agent_id(triage)]
    assert isinstance(triage_executor, HandoffAgentExecutor)

    cloned_options = triage_executor._agent.default_options
    assert cloned_options.get("store") is False
|
|
|
|
|
|
async def test_handoff_clone_preserves_per_service_call_history_persistence() -> None:
    """Handoff clones should keep per-service-call history persistence active for auto-handoff termination."""
    triage_history = InMemoryHistoryProvider()
    triage = Agent(
        id="triage",
        name="triage",
        client=MockChatClient(name="triage", handoff_to="specialist"),
        context_providers=[triage_history],
        require_per_service_call_history_persistence=True,
    )
    specialist = Agent(
        id="specialist",
        name="specialist",
        client=MockChatClient(name="specialist"),
        default_options={"tool_choice": "none"},
    )

    workflow = (
        HandoffBuilder(participants=[triage, specialist], termination_condition=lambda _: False)
        .with_start_agent(triage)
        .add_handoff(triage, [specialist])
        .add_handoff(specialist, [triage])
        .build()
    )

    await _drain(workflow.run("start", stream=True))

    triage_executor = workflow.executors[resolve_agent_id(triage)]
    assert isinstance(triage_executor, HandoffAgentExecutor)
    assert triage_executor._agent.require_per_service_call_history_persistence is True

    session = triage_executor._session
    stored_messages = await triage_history.get_messages(
        session.session_id,
        state=session.state[triage_history.source_id],
    )

    # Expect the user turn plus triage's handoff function call, and no tool-role messages.
    stored_roles = [message.role for message in stored_messages]
    assert stored_roles == ["user", "assistant"]
    assert any(content.type == "function_call" for content in stored_messages[-1].contents)
    assert all(role != "tool" for role in stored_roles)
|
|
|
|
|
|
async def test_handoff_clears_stale_service_session_id_before_run() -> None:
    """Stale service session IDs must be dropped before each handoff agent turn."""
    triage = MockHandoffAgent(name="triage", handoff_to="specialist")
    specialist = MockHandoffAgent(name="specialist")
    workflow = HandoffBuilder(participants=[triage, specialist]).with_start_agent(triage).build()

    triage_executor = workflow.executors[resolve_agent_id(triage)]
    assert isinstance(triage_executor, HandoffAgentExecutor)
    # Plant a stale service session ID before the run starts.
    triage_executor._session.service_session_id = "resp_stale_value"

    await _drain(workflow.run("My order is damaged", stream=True))

    assert triage_executor._session.service_session_id is None
|
|
|
|
|
|
def test_clean_conversation_for_handoff_keeps_text_only_history() -> None:
    """Tool-control messages must be excluded from persisted handoff history."""
    routing_call = Content.from_function_call(
        call_id="handoff-call-1",
        name="handoff_to_refund_agent",
        arguments={"context": "route to refund"},
    )
    approval = Content.from_function_approval_response(
        approved=True,
        id="approval-1",
        function_call=routing_call,
    )

    user_turn = Message(role="user", text="My order arrived damaged.")
    # Mixed assistant turn: only its text part is expected to survive cleaning.
    routing_turn = Message(
        role="assistant",
        contents=[routing_call, Content.from_text(text="Triage Agent: Routing you to Refund.")],
    )
    tool_turn = Message(role="tool", contents=[Content.from_function_result(call_id="handoff-call-1", result="ok")])
    approval_turn = Message(role="user", contents=[approval])
    # Assistant turn holding nothing but a function call: expected to be dropped entirely.
    call_only_turn = Message(
        role="assistant",
        contents=[Content.from_function_call(call_id="handoff-call-2", name="handoff_to_order_agent")],
    )

    cleaned = clean_conversation_for_handoff([user_turn, routing_turn, tool_turn, approval_turn, call_only_turn])

    assert [message.role for message in cleaned] == ["user", "assistant"]
    assert [message.text for message in cleaned] == [
        "My order arrived damaged.",
        "Triage Agent: Routing you to Refund.",
    ]
|
|
|
|
|
|
def test_persist_missing_approved_function_results_handles_runtime_and_fallback_outputs() -> None:
    """Persisted history should retain approved call outputs across runtime shapes."""
    executor = HandoffAgentExecutor(MockHandoffAgent(name="triage"), handoffs=[])

    runtime_call_id = "call-runtime-result"
    approval_only_call_id = "call-approval-only"

    # History holds both calls; only one of them gets a runtime function_result below.
    executor._full_conversation = [
        Message(
            role="assistant",
            contents=[
                Content.from_function_call(call_id=runtime_call_id, name="submit_refund", arguments={}),
                Content.from_function_call(call_id=approval_only_call_id, name="submit_refund", arguments={}),
            ],
        )
    ]

    approval = Content.from_function_approval_response(
        approved=True,
        id=approval_only_call_id,
        function_call=Content.from_function_call(call_id=approval_only_call_id, name="submit_refund", arguments={}),
    )
    runtime_result = Content.from_function_result(call_id=runtime_call_id, result='{"submitted":true}')
    runtime_messages = [
        Message(role="tool", contents=[runtime_result]),
        Message(role="user", contents=[approval]),
    ]

    executor._persist_missing_approved_function_results(runtime_tool_messages=runtime_messages, response_messages=[])

    tool_messages = [message for message in executor._full_conversation if message.role == "tool"]
    assert tool_messages
    result_by_call_id = {
        content.call_id: content.result
        for message in tool_messages
        for content in message.contents
        if content.type == "function_result" and content.call_id
    }
    # The runtime output is kept verbatim; the approval-only call gets the fallback payload.
    assert result_by_call_id[runtime_call_id] == '{"submitted":true}'
    assert result_by_call_id[approval_only_call_id] == '{"status":"approved"}'
|
|
|
|
|
|
async def test_autonomous_mode_yields_output_without_user_request():
    """With autonomous mode enabled the workflow yields output and never pauses for user input."""
    triage = MockHandoffAgent(name="triage", handoff_to="specialist")
    specialist = MockHandoffAgent(name="specialist")

    # Stop once four messages exist: the user's message, triage's handoff reply, and
    # specialist replies. The specialist reply after the handoff must not produce a
    # user-input request, because autonomous mode keeps the specialist running.
    builder = HandoffBuilder(
        participants=[triage, specialist],
        termination_condition=lambda messages: len(messages) >= 4,
    ).with_start_agent(triage)

    # The specialist has no handoff configured, so it only produces ordinary replies.
    # Autonomous mode keeps it going until the termination condition is met.
    builder = builder.with_autonomous_mode(
        agents=[specialist],
        turn_limits={resolve_agent_id(specialist): 1},
    )
    workflow = builder.build()

    events = await _drain(workflow.run("Package arrived broken", stream=True))

    assert not any(ev.type == "request_info" for ev in events), (
        "Autonomous mode should not request additional user input"
    )

    output_events = [ev for ev in events if ev.type == "output"]
    assert output_events, "Autonomous mode should yield a workflow output"

    last_output = output_events[-1].data
    assert isinstance(last_output, list)
    messages = cast(list[Message], last_output)
    assert any(m.role == "assistant" and (m.text or "").startswith("specialist reply") for m in messages)
|
|
|
|
|
|
async def test_autonomous_mode_resumes_user_input_on_turn_limit():
    """Once the autonomous agent exhausts its turn limit, the workflow asks the user for input again."""
    triage = MockHandoffAgent(name="triage", handoff_to="worker")
    worker = MockHandoffAgent(name="worker")

    # The termination condition never fires, so only the worker's turn limit can pause the run.
    builder = HandoffBuilder(participants=[triage, worker], termination_condition=lambda conv: False)
    builder = builder.with_start_agent(triage)
    builder = builder.with_autonomous_mode(agents=[worker], turn_limits={resolve_agent_id(worker): 2})
    workflow = builder.build()

    events = await _drain(workflow.run("Start", stream=True))

    user_requests = [ev for ev in events if ev.type == "request_info"]
    assert len(user_requests) == 1, "Turn limit should force a user input request"
    assert user_requests[0].source_executor_id == worker.name
|
|
|
|
|
|
def test_build_fails_without_start_agent():
    """Verify that build() raises ValueError when with_start_agent() was not called."""
    triage = MockHandoffAgent(name="triage")
    specialist = MockHandoffAgent(name="specialist")

    # pytest.raises(match=...) uses re.search, so escape the literal message.
    # Otherwise "(...)" and the trailing "." act as regex wildcards and the check is looser than intended.
    expected_message = re.escape("Must call with_start_agent(...) before building the workflow.")
    with pytest.raises(ValueError, match=expected_message):
        HandoffBuilder(participants=[triage, specialist]).build()
|
|
|
|
|
|
def test_build_fails_without_participants():
    """An empty participant list makes building the handoff workflow raise ValueError."""
    no_participants: list[Any] = []

    # Construction and build() are both inside the context: either step may raise.
    with pytest.raises(ValueError):
        HandoffBuilder(participants=no_participants).build()
|
|
|
|
|
|
async def test_handoff_async_termination_condition() -> None:
    """An async termination callback is awaited and ends the workflow after the second user turn."""
    # One entry is recorded per invocation of the termination callback.
    termination_checks: list[int] = []

    async def stop_after_two_user_turns(conv: list[Message]) -> bool:
        termination_checks.append(len(conv))
        return sum(msg.role == "user" for msg in conv) >= 2

    coordinator = MockHandoffAgent(name="coordinator", handoff_to="worker")
    worker = MockHandoffAgent(name="worker")

    builder = HandoffBuilder(participants=[coordinator, worker], termination_condition=stop_after_two_user_turns)
    workflow = builder.with_start_agent(coordinator).build()

    # First turn: only one user message exists, so the workflow must ask for more input.
    first_events = await _drain(workflow.run("First user message", stream=True))
    pending = [ev for ev in first_events if ev.type == "request_info"]
    assert pending

    # Second turn: answering the latest request adds the second user message and triggers termination.
    reply = {pending[-1].request_id: [Message(role="user", text="Second user message")]}
    second_events = await _drain(workflow.run(stream=True, responses=reply))
    output_events = [ev for ev in second_events if ev.type == "output"]
    assert len(output_events) == 1

    conversation = output_events[0].data
    assert isinstance(conversation, list)
    user_turns = [m for m in cast(list[Message], conversation) if m.role == "user"]
    assert len(user_turns) == 2
    assert len(termination_checks) > 0
|
|
|
|
|
|
async def test_handoff_terminates_without_request_info_when_latest_response_meets_condition() -> None:
    """No request_info is emitted when the latest assistant reply itself satisfies termination."""

    class FinalizingClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
        """Chat client that always answers with a fixed closing message."""

        def __init__(self) -> None:
            ChatMiddlewareLayer.__init__(self)
            FunctionInvocationLayer.__init__(self)
            BaseChatClient.__init__(self)

        def _inner_get_response(
            self,
            *,
            messages: Sequence[Message],
            stream: bool,
            options: Mapping[str, Any],
            **kwargs: Any,
        ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
            # The inputs are ignored: every call yields the same closing reply.
            del messages, options, kwargs
            reply = [Content.from_text(text="Replacement request submitted. Case complete.")]

            if not stream:

                async def _respond() -> ChatResponse:
                    return ChatResponse(messages=[Message(role="assistant", contents=reply)], response_id="finalizing")

                return _respond()

            async def _updates() -> AsyncIterable[ChatResponseUpdate]:
                yield ChatResponseUpdate(contents=reply, role="assistant", finish_reason="stop")

            return ResponseStream(_updates(), finalizer=lambda updates: ChatResponse.from_updates(updates))

    def is_case_complete(conv: list[Message]) -> bool:
        # Terminate as soon as any assistant message announces the case is complete.
        return any(message.role == "assistant" and "case complete." in (message.text or "").lower() for message in conv)

    agent = Agent(id="order_agent", name="order_agent", client=FinalizingClient())
    builder = HandoffBuilder(participants=[agent], termination_condition=is_case_complete)
    workflow = builder.with_start_agent(agent).build()

    events = await _drain(workflow.run("ship replacement", stream=True))

    assert not [event for event in events if event.type == "request_info"]

    outputs = [event for event in events if event.type == "output"]
    assert outputs
    # Exactly one output carries the final conversation list.
    assert sum(isinstance(event.data, list) for event in outputs) == 1
|
|
|
|
|
|
async def test_tool_choice_preserved_from_agent_config():
    """Verify that agent-level tool_choice configuration is preserved and not overridden.

    The agent is run directly (not through a handoff workflow). A mock client records the
    ``tool_choice`` from the options of each ``get_response`` call, and the last recorded
    value must equal the one configured in ``default_options``.
    """
    # tool_choice values seen by the mock client, one per call that received non-empty options.
    recorded_tool_choices: list[Any] = []

    async def mock_get_response(messages: Any, options: dict[str, Any] | None = None, **kwargs: Any) -> ChatResponse:
        if options:
            recorded_tool_choices.append(options.get("tool_choice"))
        return ChatResponse(
            messages=[Message(role="assistant", text="Response")],
            response_id="test_response",
        )

    mock_client = MagicMock()
    mock_client.get_response = AsyncMock(side_effect=mock_get_response)

    expected_tool_choice = {"mode": "required"}

    # Configure tool_choice on the agent via default_options. A separate literal is used so the
    # comparison below cannot pass vacuously if the agent mutates its options in place.
    agent = Agent(
        client=mock_client,
        name="test_agent",
        default_options={"tool_choice": {"mode": "required"}},  # type: ignore
    )

    await agent.run("Test message")

    assert len(recorded_tool_choices) > 0, "No tool_choice recorded"
    last_tool_choice = recorded_tool_choices[-1]
    assert last_tool_choice is not None, "tool_choice should not be None"
    # The failure message reports the full expected value, not just the mode string.
    assert last_tool_choice == expected_tool_choice, (
        f"Expected {expected_tool_choice!r}, got {last_tool_choice!r}"
    )
|
|
|
|
|
|
async def test_context_provider_preserved_during_handoff():
    """Context providers must survive the agent cloning that HandoffBuilder performs."""
    invocations: list[str] = []

    class RecordingContextProvider(ContextProvider):
        """Context provider that records every before_run invocation."""

        def __init__(self) -> None:
            super().__init__("test")

        async def before_run(self, **kwargs: Any) -> None:
            invocations.append("before_run")

    provider = RecordingContextProvider()
    agent = Agent(
        client=MockChatClient(name="test_agent"),
        name="test_agent",
        id="test_agent",
        context_providers=[provider],
    )
    assert provider in agent.context_providers, "Original agent should have context provider"

    # Building the workflow clones the agent; the clone must still carry the provider.
    workflow = HandoffBuilder(participants=[agent]).with_start_agent(agent).build()
    await _drain(workflow.run("Test message", stream=True))

    assert invocations, (
        "Context provider should be called during workflow execution, "
        "indicating it was properly preserved during agent cloning"
    )
|
|
|
|
|
|
def test_handoff_builder_accepts_all_instances_in_add_handoff():
    """add_handoff accepts agent instances for source and targets when participants are instances."""
    triage = MockHandoffAgent(name="triage", handoff_to="specialist_a")
    specialist_a = MockHandoffAgent(name="specialist_a")
    specialist_b = MockHandoffAgent(name="specialist_b")

    builder = HandoffBuilder(participants=[triage, specialist_a, specialist_b])
    builder = builder.with_start_agent(triage)
    # Every argument passed to add_handoff is an agent instance.
    builder = builder.add_handoff(triage, [specialist_a, specialist_b])

    workflow = builder.build()
    for executor_id in ("triage", "specialist_a", "specialist_b"):
        assert executor_id in workflow.executors
|
|
|
|
|
|
async def test_auto_handoff_middleware_intercepts_handoff_tool_call() -> None:
    """A configured handoff tool call is answered by the middleware without invoking the next handler.

    The middleware must set a synthetic handoff result on the context and raise
    MiddlewareTermination carrying that same result.
    """
    target_id = "specialist"
    middleware = _AutoHandoffMiddleware([HandoffConfiguration(target=target_id)])

    @tool(name=get_handoff_tool_name(target_id), approval_mode="never_require")
    def handoff_tool() -> str:
        return "unreachable"

    invocation = FunctionInvocationContext(function=handoff_tool, arguments={})
    next_handler = AsyncMock()

    with pytest.raises(MiddlewareTermination) as raised:
        await middleware.process(invocation, next_handler)

    synthetic_result = FunctionTool.parse_result({HANDOFF_FUNCTION_RESULT_KEY: target_id})
    next_handler.assert_not_awaited()
    assert invocation.result == synthetic_result
    assert raised.value.result == synthetic_result
|
|
|
|
|
|
async def test_auto_handoff_middleware_calls_next_for_non_handoff_tool() -> None:
    """A tool that is not a configured handoff passes straight through to the next handler."""
    middleware = _AutoHandoffMiddleware([HandoffConfiguration(target="specialist")])

    @tool(name="regular_tool", approval_mode="never_require")
    def regular_tool() -> str:
        return "ok"

    invocation = FunctionInvocationContext(function=regular_tool, arguments={})
    next_handler = AsyncMock()

    await middleware.process(invocation, next_handler)

    # The mocked next handler sets nothing, so the result must remain unset.
    next_handler.assert_awaited_once()
    assert invocation.result is None
|
|
|
|
|
|
def test_handoff_builder_rejects_non_agent_supports_agent_run():
    """participants() raises TypeError for a SupportsAgentRun implementation that is not an Agent."""
    from agent_framework import AgentResponse, AgentSession, SupportsAgentRun

    class FakeAgentRun:
        # Structurally satisfies SupportsAgentRun without subclassing Agent.
        def __init__(self, agent_id, agent_name):
            self.id = agent_id
            self.name = agent_name
            self.description = "d"

        async def run(self, messages=None, *, stream=False, session=None, **kwargs):
            return AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("ok")])])

        def create_session(self, **kwargs):
            return AgentSession()

        def get_session(self, *, service_session_id, **kwargs):
            return AgentSession(service_session_id=service_session_id)

    impostor = FakeAgentRun("a", "A")
    # Precondition: the fake really passes the protocol check, so the rejection below is meaningful.
    assert isinstance(impostor, SupportsAgentRun)

    with pytest.raises(TypeError, match="Participants must be Agent instances"):
        HandoffBuilder().participants([impostor])
|