mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fix underlying tool choice bug and all for return to previous Handoff subagent (#2037)
* Fix tool_choice override bug and add enable_return_to_previous support * Add unit test for handoff checkpointing * Handle tools when we have them
This commit is contained in:
committed by
GitHub
Unverified
parent
45dc0ff073
commit
548e0f028e
@@ -224,7 +224,7 @@ def _merge_chat_options(
|
||||
stop: str | Sequence[str] | None = None,
|
||||
store: bool | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
|
||||
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None,
|
||||
tools: list[ToolProtocol | dict[str, Any] | Callable[..., Any]] | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
@@ -496,7 +496,7 @@ class BaseChatClient(SerializationMixin, ABC):
|
||||
stop: str | Sequence[str] | None = None,
|
||||
store: bool | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
|
||||
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None,
|
||||
tools: ToolProtocol
|
||||
| Callable[..., Any]
|
||||
| MutableMapping[str, Any]
|
||||
@@ -595,7 +595,7 @@ class BaseChatClient(SerializationMixin, ABC):
|
||||
stop: str | Sequence[str] | None = None,
|
||||
store: bool | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
|
||||
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None,
|
||||
tools: ToolProtocol
|
||||
| Callable[..., Any]
|
||||
| MutableMapping[str, Any]
|
||||
|
||||
@@ -1525,6 +1525,12 @@ def _handle_function_calls_response(
|
||||
prepped_messages = prepare_messages(messages)
|
||||
response: "ChatResponse | None" = None
|
||||
fcc_messages: "list[ChatMessage]" = []
|
||||
|
||||
# If tools are provided but tool_choice is not set, default to "auto" for function invocation
|
||||
tools = _extract_tools(kwargs)
|
||||
if tools and kwargs.get("tool_choice") is None:
|
||||
kwargs["tool_choice"] = "auto"
|
||||
|
||||
for attempt_idx in range(config.max_iterations if config.enabled else 0):
|
||||
fcc_todo = _collect_approval_responses(prepped_messages)
|
||||
if fcc_todo:
|
||||
|
||||
@@ -85,8 +85,8 @@ def _clone_chat_agent(agent: ChatAgent) -> ChatAgent:
|
||||
# so we need to recombine them here to pass the complete tools list to the constructor.
|
||||
# This makes sure MCP tools are preserved when cloning agents for handoff workflows.
|
||||
all_tools = list(options.tools) if options.tools else []
|
||||
if agent._local_mcp_tools:
|
||||
all_tools.extend(agent._local_mcp_tools)
|
||||
if agent._local_mcp_tools: # type: ignore
|
||||
all_tools.extend(agent._local_mcp_tools) # type: ignore
|
||||
|
||||
return ChatAgent(
|
||||
chat_client=agent.chat_client,
|
||||
@@ -133,6 +133,14 @@ class _ConversationWithUserInput:
|
||||
full_conversation: list[ChatMessage] = field(default_factory=lambda: []) # type: ignore[misc]
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ConversationForUserInput:
|
||||
"""Internal message from coordinator to gateway specifying which agent will receive the response."""
|
||||
|
||||
conversation: list[ChatMessage]
|
||||
next_agent_id: str
|
||||
|
||||
|
||||
class _AutoHandoffMiddleware(FunctionMiddleware):
|
||||
"""Intercept handoff tool invocations and short-circuit execution with synthetic results."""
|
||||
|
||||
@@ -275,6 +283,7 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]],
|
||||
id: str,
|
||||
handoff_tool_targets: Mapping[str, str] | None = None,
|
||||
return_to_previous: bool = False,
|
||||
) -> None:
|
||||
"""Create a coordinator that manages routing between specialists and the user."""
|
||||
super().__init__(id)
|
||||
@@ -284,6 +293,8 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
self._input_gateway_id = input_gateway_id
|
||||
self._termination_condition = termination_condition
|
||||
self._handoff_tool_targets = {k.lower(): v for k, v in (handoff_tool_targets or {}).items()}
|
||||
self._return_to_previous = return_to_previous
|
||||
self._current_agent_id: str | None = None # Track the current agent handling conversation
|
||||
|
||||
def _get_author_name(self) -> str:
|
||||
"""Get the coordinator name for orchestrator-generated messages."""
|
||||
@@ -293,7 +304,7 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
async def handle_agent_response(
|
||||
self,
|
||||
response: AgentExecutorResponse,
|
||||
ctx: WorkflowContext[AgentExecutorRequest | list[ChatMessage], list[ChatMessage]],
|
||||
ctx: WorkflowContext[AgentExecutorRequest | list[ChatMessage], list[ChatMessage] | _ConversationForUserInput],
|
||||
) -> None:
|
||||
"""Process an agent's response and determine whether to route, request input, or terminate."""
|
||||
# Hydrate coordinator state (and detect new run) using checkpointable executor state
|
||||
@@ -329,6 +340,9 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
# Check for handoff from ANY agent (starting agent or specialist)
|
||||
target = self._resolve_specialist(response.agent_run_response, conversation)
|
||||
if target is not None:
|
||||
# Update current agent when handoff occurs
|
||||
self._current_agent_id = target
|
||||
logger.info(f"Handoff detected: {source} -> {target}. Routing control to specialist '{target}'.")
|
||||
await self._persist_state(ctx)
|
||||
# Clean tool-related content before sending to next agent
|
||||
cleaned = clean_conversation_for_handoff(conversation)
|
||||
@@ -340,10 +354,15 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
if not is_starting_agent and source not in self._specialist_ids:
|
||||
raise RuntimeError(f"HandoffCoordinator received response from unknown executor '{source}'.")
|
||||
|
||||
# Update current agent when they respond without handoff
|
||||
self._current_agent_id = source
|
||||
logger.info(
|
||||
f"Agent '{source}' responded without handoff. "
|
||||
f"Requesting user input. Return-to-previous: {self._return_to_previous}"
|
||||
)
|
||||
await self._persist_state(ctx)
|
||||
|
||||
if await self._check_termination():
|
||||
logger.info("Handoff workflow termination condition met. Ending conversation.")
|
||||
# Clean the output conversation for display
|
||||
cleaned_output = clean_conversation_for_handoff(conversation)
|
||||
await ctx.yield_output(cleaned_output)
|
||||
@@ -352,7 +371,13 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
# Clean conversation before sending to gateway for user input request
|
||||
# This removes tool messages that shouldn't be shown to users
|
||||
cleaned_for_display = clean_conversation_for_handoff(conversation)
|
||||
await ctx.send_message(cleaned_for_display, target_id=self._input_gateway_id)
|
||||
|
||||
# The awaiting_agent_id is the agent that just responded and is awaiting user input
|
||||
# This is the source of the current response
|
||||
next_agent_id = source
|
||||
|
||||
message_to_gateway = _ConversationForUserInput(conversation=cleaned_for_display, next_agent_id=next_agent_id)
|
||||
await ctx.send_message(message_to_gateway, target_id=self._input_gateway_id) # type: ignore[arg-type]
|
||||
|
||||
@handler
|
||||
async def handle_user_input(
|
||||
@@ -367,14 +392,26 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
|
||||
# Check termination before sending to agent
|
||||
if await self._check_termination():
|
||||
logger.info("Handoff workflow termination condition met. Ending conversation.")
|
||||
await ctx.yield_output(list(self._conversation))
|
||||
return
|
||||
|
||||
# Clean before sending to starting agent
|
||||
# Determine routing target based on return-to-previous setting
|
||||
target_agent_id = self._starting_agent_id
|
||||
if self._return_to_previous and self._current_agent_id:
|
||||
# Route back to the current agent that's handling the conversation
|
||||
target_agent_id = self._current_agent_id
|
||||
logger.info(
|
||||
f"Return-to-previous enabled: routing user input to current agent '{target_agent_id}' "
|
||||
f"(bypassing coordinator '{self._starting_agent_id}')"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Routing user input to coordinator '{target_agent_id}'")
|
||||
# Note: Stack is only used for specialist-to-specialist handoffs, not user input routing
|
||||
|
||||
# Clean before sending to target agent
|
||||
cleaned = clean_conversation_for_handoff(self._conversation)
|
||||
request = AgentExecutorRequest(messages=cleaned, should_respond=True)
|
||||
await ctx.send_message(request, target_id=self._starting_agent_id)
|
||||
await ctx.send_message(request, target_id=target_agent_id)
|
||||
|
||||
def _resolve_specialist(self, agent_response: AgentRunResponse, conversation: list[ChatMessage]) -> str | None:
|
||||
"""Resolve the specialist executor id requested by the agent response, if any."""
|
||||
@@ -444,22 +481,27 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
def _snapshot_pattern_metadata(self) -> dict[str, Any]:
|
||||
"""Serialize pattern-specific state.
|
||||
|
||||
Handoff has no additional metadata beyond base conversation state.
|
||||
Includes the current agent for return-to-previous routing.
|
||||
|
||||
Returns:
|
||||
Empty dict (no pattern-specific state)
|
||||
Dict containing current agent if return-to-previous is enabled
|
||||
"""
|
||||
if self._return_to_previous:
|
||||
return {
|
||||
"current_agent_id": self._current_agent_id,
|
||||
}
|
||||
return {}
|
||||
|
||||
def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None:
|
||||
"""Restore pattern-specific state.
|
||||
|
||||
Handoff has no additional metadata beyond base conversation state.
|
||||
Restores the current agent for return-to-previous routing.
|
||||
|
||||
Args:
|
||||
metadata: Pattern-specific state dict (ignored)
|
||||
metadata: Pattern-specific state dict
|
||||
"""
|
||||
pass
|
||||
if self._return_to_previous and "current_agent_id" in metadata:
|
||||
self._current_agent_id = metadata["current_agent_id"]
|
||||
|
||||
def _restore_conversation_from_state(self, state: Mapping[str, Any]) -> list[ChatMessage]:
|
||||
"""Rehydrate the coordinator's conversation history from checkpointed state.
|
||||
@@ -507,8 +549,21 @@ class _UserInputGateway(Executor):
|
||||
self._prompt = prompt or "Provide your next input for the conversation."
|
||||
|
||||
@handler
|
||||
async def request_input(self, conversation: list[ChatMessage], ctx: WorkflowContext) -> None:
|
||||
async def request_input(self, message: _ConversationForUserInput, ctx: WorkflowContext) -> None:
|
||||
"""Emit a `HandoffUserInputRequest` capturing the conversation snapshot."""
|
||||
if not message.conversation:
|
||||
raise ValueError("Handoff workflow requires non-empty conversation before requesting user input.")
|
||||
request = HandoffUserInputRequest(
|
||||
conversation=list(message.conversation),
|
||||
awaiting_agent_id=message.next_agent_id,
|
||||
prompt=self._prompt,
|
||||
source_executor_id=self.id,
|
||||
)
|
||||
await ctx.request_info(request, object)
|
||||
|
||||
@handler
|
||||
async def request_input_legacy(self, conversation: list[ChatMessage], ctx: WorkflowContext) -> None:
|
||||
"""Legacy handler for backward compatibility - emit user input request with starting agent."""
|
||||
if not conversation:
|
||||
raise ValueError("Handoff workflow requires non-empty conversation before requesting user input.")
|
||||
request = HandoffUserInputRequest(
|
||||
@@ -558,7 +613,7 @@ def _as_user_messages(payload: Any) -> list[ChatMessage]:
|
||||
|
||||
|
||||
def _default_termination_condition(conversation: list[ChatMessage]) -> bool:
|
||||
"""Default termination: stop after 10 user messages to prevent infinite loops."""
|
||||
"""Default termination: stop after 10 user messages."""
|
||||
user_message_count = sum(1 for msg in conversation if msg.role == Role.USER)
|
||||
return user_message_count >= 10
|
||||
|
||||
@@ -743,6 +798,7 @@ class HandoffBuilder:
|
||||
)
|
||||
self._auto_register_handoff_tools: bool = True
|
||||
self._handoff_config: dict[str, list[str]] = {} # Maps agent_id -> [target_agent_ids]
|
||||
self._return_to_previous: bool = False
|
||||
|
||||
if participants:
|
||||
self.participants(participants)
|
||||
@@ -1198,6 +1254,77 @@ class HandoffBuilder:
|
||||
self._termination_condition = condition
|
||||
return self
|
||||
|
||||
def enable_return_to_previous(self, enabled: bool = True) -> "HandoffBuilder":
|
||||
"""Enable direct return to the current agent after user input, bypassing the coordinator.
|
||||
|
||||
When enabled, after a specialist responds without requesting another handoff, user input
|
||||
routes directly back to that same specialist instead of always routing back to the
|
||||
coordinator agent for re-evaluation.
|
||||
|
||||
This is useful when a specialist needs multiple turns with the user to gather information
|
||||
or resolve an issue, avoiding unnecessary coordinator involvement while maintaining context.
|
||||
|
||||
Flow Comparison:
|
||||
|
||||
**Default (disabled):**
|
||||
User -> Coordinator -> Specialist -> User -> Coordinator -> Specialist -> ...
|
||||
|
||||
**With return_to_previous (enabled):**
|
||||
User -> Coordinator -> Specialist -> User -> Specialist -> ...
|
||||
|
||||
Args:
|
||||
enabled: Whether to enable return-to-previous routing. Default is True.
|
||||
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
workflow = (
|
||||
HandoffBuilder(participants=[triage, technical_support, billing])
|
||||
.set_coordinator("triage")
|
||||
.add_handoff(triage, [technical_support, billing])
|
||||
.enable_return_to_previous() # Enable direct return routing
|
||||
.build()
|
||||
)
|
||||
|
||||
# Flow: User asks question
|
||||
# -> Triage routes to Technical Support
|
||||
# -> Technical Support asks clarifying question
|
||||
# -> User provides more info
|
||||
# -> Routes back to Technical Support (not Triage)
|
||||
# -> Technical Support continues helping
|
||||
|
||||
Multi-tier handoff example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
workflow = (
|
||||
HandoffBuilder(participants=[triage, specialist_a, specialist_b])
|
||||
.set_coordinator("triage")
|
||||
.add_handoff(triage, [specialist_a, specialist_b])
|
||||
.add_handoff(specialist_a, specialist_b)
|
||||
.enable_return_to_previous()
|
||||
.build()
|
||||
)
|
||||
|
||||
# Flow: User asks question
|
||||
# -> Triage routes to Specialist A
|
||||
# -> Specialist A hands off to Specialist B
|
||||
# -> Specialist B asks clarifying question
|
||||
# -> User provides more info
|
||||
# -> Routes back to Specialist B (who is currently handling the conversation)
|
||||
|
||||
Note:
|
||||
This feature routes to whichever agent most recently responded, whether that's
|
||||
the coordinator or a specialist. The conversation continues with that agent until
|
||||
they either hand off to another agent or the termination condition is met.
|
||||
"""
|
||||
self._return_to_previous = enabled
|
||||
return self
|
||||
|
||||
def build(self) -> Workflow:
|
||||
"""Construct the final Workflow instance from the configured builder.
|
||||
|
||||
@@ -1326,6 +1453,7 @@ class HandoffBuilder:
|
||||
termination_condition=self._termination_condition,
|
||||
id="handoff-coordinator",
|
||||
handoff_tool_targets=handoff_tool_targets,
|
||||
return_to_previous=self._return_to_previous,
|
||||
)
|
||||
|
||||
wiring = _GroupChatConfig(
|
||||
|
||||
@@ -23,7 +23,7 @@ from agent_framework import (
|
||||
WorkflowOutputEvent,
|
||||
)
|
||||
from agent_framework._mcp import MCPTool
|
||||
from agent_framework._workflows._handoff import _clone_chat_agent
|
||||
from agent_framework._workflows._handoff import _clone_chat_agent # type: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -392,12 +392,218 @@ async def test_clone_chat_agent_preserves_mcp_tools() -> None:
|
||||
)
|
||||
|
||||
assert hasattr(original_agent, "_local_mcp_tools")
|
||||
assert len(original_agent._local_mcp_tools) == 1
|
||||
assert original_agent._local_mcp_tools[0] == mock_mcp_tool
|
||||
assert len(original_agent._local_mcp_tools) == 1 # type: ignore[reportPrivateUsage]
|
||||
assert original_agent._local_mcp_tools[0] == mock_mcp_tool # type: ignore[reportPrivateUsage]
|
||||
|
||||
cloned_agent = _clone_chat_agent(original_agent)
|
||||
|
||||
assert hasattr(cloned_agent, "_local_mcp_tools")
|
||||
assert len(cloned_agent._local_mcp_tools) == 1
|
||||
assert cloned_agent._local_mcp_tools[0] == mock_mcp_tool
|
||||
assert len(cloned_agent._local_mcp_tools) == 1 # type: ignore[reportPrivateUsage]
|
||||
assert cloned_agent._local_mcp_tools[0] == mock_mcp_tool # type: ignore[reportPrivateUsage]
|
||||
assert cloned_agent.chat_options.tools is not None
|
||||
assert len(cloned_agent.chat_options.tools) == 1
|
||||
|
||||
|
||||
async def test_return_to_previous_routing():
|
||||
"""Test that return-to-previous routes back to the current specialist handling the conversation."""
|
||||
triage = _RecordingAgent(name="triage", handoff_to="specialist_a")
|
||||
specialist_a = _RecordingAgent(name="specialist_a", handoff_to="specialist_b")
|
||||
specialist_b = _RecordingAgent(name="specialist_b")
|
||||
|
||||
workflow = (
|
||||
HandoffBuilder(participants=[triage, specialist_a, specialist_b])
|
||||
.set_coordinator(triage)
|
||||
.add_handoff(triage, [specialist_a, specialist_b])
|
||||
.add_handoff(specialist_a, specialist_b)
|
||||
.enable_return_to_previous(True)
|
||||
.with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 4)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Start conversation - triage hands off to specialist_a
|
||||
events = await _drain(workflow.run_stream("Initial request"))
|
||||
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
|
||||
assert requests
|
||||
assert len(specialist_a.calls) > 0
|
||||
|
||||
# Specialist_a should have been called with initial request
|
||||
initial_specialist_a_calls = len(specialist_a.calls)
|
||||
|
||||
# Second user message - specialist_a hands off to specialist_b
|
||||
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Need more help"}))
|
||||
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
|
||||
assert requests
|
||||
|
||||
# Specialist_b should have been called
|
||||
assert len(specialist_b.calls) > 0
|
||||
initial_specialist_b_calls = len(specialist_b.calls)
|
||||
|
||||
# Third user message - with return_to_previous, should route back to specialist_b (current agent)
|
||||
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Follow up question"}))
|
||||
third_requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
|
||||
|
||||
# Specialist_b should have been called again (return-to-previous routes to current agent)
|
||||
assert len(specialist_b.calls) > initial_specialist_b_calls, (
|
||||
"Specialist B should be called again due to return-to-previous routing to current agent"
|
||||
)
|
||||
|
||||
# Specialist_a should NOT be called again (it's no longer the current agent)
|
||||
assert len(specialist_a.calls) == initial_specialist_a_calls, (
|
||||
"Specialist A should not be called again - specialist_b is the current agent"
|
||||
)
|
||||
|
||||
# Triage should only have been called once at the start
|
||||
assert len(triage.calls) == 1, "Triage should only be called once (initial routing)"
|
||||
|
||||
# Verify awaiting_agent_id is set to specialist_b (the agent that just responded)
|
||||
if third_requests:
|
||||
user_input_req = third_requests[-1].data
|
||||
assert isinstance(user_input_req, HandoffUserInputRequest)
|
||||
assert user_input_req.awaiting_agent_id == "specialist_b", (
|
||||
f"Expected awaiting_agent_id 'specialist_b' but got '{user_input_req.awaiting_agent_id}'"
|
||||
)
|
||||
|
||||
|
||||
async def test_return_to_previous_disabled_routes_to_coordinator():
|
||||
"""Test that with return-to-previous disabled, routing goes back to coordinator."""
|
||||
triage = _RecordingAgent(name="triage", handoff_to="specialist_a")
|
||||
specialist_a = _RecordingAgent(name="specialist_a", handoff_to="specialist_b")
|
||||
specialist_b = _RecordingAgent(name="specialist_b")
|
||||
|
||||
workflow = (
|
||||
HandoffBuilder(participants=[triage, specialist_a, specialist_b])
|
||||
.set_coordinator(triage)
|
||||
.add_handoff(triage, [specialist_a, specialist_b])
|
||||
.add_handoff(specialist_a, specialist_b)
|
||||
.enable_return_to_previous(False)
|
||||
.with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 3)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Start conversation - triage hands off to specialist_a
|
||||
events = await _drain(workflow.run_stream("Initial request"))
|
||||
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
|
||||
assert requests
|
||||
assert len(triage.calls) == 1
|
||||
|
||||
# Second user message - specialist_a hands off to specialist_b
|
||||
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Need more help"}))
|
||||
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
|
||||
assert requests
|
||||
|
||||
# Third user message - without return_to_previous, should route back to triage
|
||||
await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Follow up question"}))
|
||||
|
||||
# Triage should have been called twice total: initial + after specialist_b responds
|
||||
assert len(triage.calls) == 2, "Triage should be called twice (initial + default routing to coordinator)"
|
||||
|
||||
|
||||
async def test_return_to_previous_enabled():
|
||||
"""Verify that enable_return_to_previous() keeps control with the current specialist."""
|
||||
triage = _RecordingAgent(name="triage", handoff_to="specialist_a")
|
||||
specialist_a = _RecordingAgent(name="specialist_a")
|
||||
specialist_b = _RecordingAgent(name="specialist_b")
|
||||
|
||||
workflow = (
|
||||
HandoffBuilder(participants=[triage, specialist_a, specialist_b])
|
||||
.set_coordinator("triage")
|
||||
.enable_return_to_previous(True)
|
||||
.with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 3)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Start conversation - triage hands off to specialist_a
|
||||
events = await _drain(workflow.run_stream("Initial request"))
|
||||
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
|
||||
assert requests
|
||||
assert len(triage.calls) == 1
|
||||
assert len(specialist_a.calls) == 1
|
||||
|
||||
# Second user message - with return_to_previous, should route to specialist_a (not triage)
|
||||
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Follow up question"}))
|
||||
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
|
||||
assert requests
|
||||
|
||||
# Triage should only have been called once (initial) - specialist_a handles follow-up
|
||||
assert len(triage.calls) == 1, "Triage should only be called once (initial)"
|
||||
assert len(specialist_a.calls) == 2, "Specialist A should handle follow-up with return_to_previous enabled"
|
||||
|
||||
|
||||
async def test_tool_choice_preserved_from_agent_config():
|
||||
"""Verify that agent-level tool_choice configuration is preserved and not overridden."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from agent_framework import ChatResponse, ToolMode
|
||||
|
||||
# Create a mock chat client that records the tool_choice used
|
||||
recorded_tool_choices: list[Any] = []
|
||||
|
||||
async def mock_get_response(messages: Any, **kwargs: Any) -> ChatResponse:
|
||||
chat_options = kwargs.get("chat_options")
|
||||
if chat_options:
|
||||
recorded_tool_choices.append(chat_options.tool_choice)
|
||||
return ChatResponse(
|
||||
messages=[ChatMessage(role=Role.ASSISTANT, text="Response")],
|
||||
response_id="test_response",
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_response = AsyncMock(side_effect=mock_get_response)
|
||||
|
||||
# Create agent with specific tool_choice configuration
|
||||
agent = ChatAgent(
|
||||
chat_client=mock_client,
|
||||
name="test_agent",
|
||||
tool_choice=ToolMode(mode="required"), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Run the agent
|
||||
await agent.run("Test message")
|
||||
|
||||
# Verify tool_choice was preserved
|
||||
assert len(recorded_tool_choices) > 0, "No tool_choice recorded"
|
||||
last_tool_choice = recorded_tool_choices[-1]
|
||||
assert last_tool_choice is not None, "tool_choice should not be None"
|
||||
assert str(last_tool_choice) == "required", f"Expected 'required', got {last_tool_choice}"
|
||||
|
||||
|
||||
async def test_return_to_previous_state_serialization():
|
||||
"""Test that return_to_previous state is properly serialized/deserialized for checkpointing."""
|
||||
from agent_framework._workflows._handoff import _HandoffCoordinator # type: ignore[reportPrivateUsage]
|
||||
|
||||
# Create a coordinator with return_to_previous enabled
|
||||
coordinator = _HandoffCoordinator(
|
||||
starting_agent_id="triage",
|
||||
specialist_ids={"specialist_a": "specialist_a", "specialist_b": "specialist_b"},
|
||||
input_gateway_id="gateway",
|
||||
termination_condition=lambda conv: False,
|
||||
id="test-coordinator",
|
||||
return_to_previous=True,
|
||||
)
|
||||
|
||||
# Set the current agent (simulating a handoff scenario)
|
||||
coordinator._current_agent_id = "specialist_a" # type: ignore[reportPrivateUsage]
|
||||
|
||||
# Snapshot the state
|
||||
state = coordinator.snapshot_state()
|
||||
|
||||
# Verify pattern metadata includes current_agent_id
|
||||
assert "metadata" in state
|
||||
assert "current_agent_id" in state["metadata"]
|
||||
assert state["metadata"]["current_agent_id"] == "specialist_a"
|
||||
|
||||
# Create a new coordinator and restore state
|
||||
coordinator2 = _HandoffCoordinator(
|
||||
starting_agent_id="triage",
|
||||
specialist_ids={"specialist_a": "specialist_a", "specialist_b": "specialist_b"},
|
||||
input_gateway_id="gateway",
|
||||
termination_condition=lambda conv: False,
|
||||
id="test-coordinator",
|
||||
return_to_previous=True,
|
||||
)
|
||||
|
||||
# Restore state
|
||||
coordinator2.restore_state(state)
|
||||
|
||||
# Verify current_agent_id was restored
|
||||
assert coordinator2._current_agent_id == "specialist_a", "Current agent should be restored from checkpoint" # type: ignore[reportPrivateUsage]
|
||||
|
||||
@@ -299,6 +299,7 @@ This directory contains samples demonstrating the capabilities of Microsoft Agen
|
||||
| [`getting_started/workflows/orchestration/group_chat_simple_selector.py`](./getting_started/workflows/orchestration/group_chat_simple_selector.py) | Sample: Group Chat Orchestration with function-based speaker selector |
|
||||
| [`getting_started/workflows/orchestration/handoff_simple.py`](./getting_started/workflows/orchestration/handoff_simple.py) | Sample: Handoff Orchestration with simple agent handoff pattern |
|
||||
| [`getting_started/workflows/orchestration/handoff_specialist_to_specialist.py`](./getting_started/workflows/orchestration/handoff_specialist_to_specialist.py) | Sample: Handoff Orchestration with specialist-to-specialist routing |
|
||||
| [`getting_started/workflows/orchestration/handoff_return_to_previous`](./getting_started/workflows/orchestration/handoff_return_to_previous.py) | Return-to-previous routing: after user input, routes back to the previous specialist instead of coordinator using `.enable_return_to_previous()` |
|
||||
| [`getting_started/workflows/orchestration/magentic.py`](./getting_started/workflows/orchestration/magentic.py) | Sample: Magentic Orchestration (agentic task planning with multi-agent execution) |
|
||||
| [`getting_started/workflows/orchestration/magentic_checkpoint.py`](./getting_started/workflows/orchestration/magentic_checkpoint.py) | Sample: Magentic Orchestration with Checkpointing |
|
||||
| [`getting_started/workflows/orchestration/magentic_human_plan_update.py`](./getting_started/workflows/orchestration/magentic_human_plan_update.py) | Sample: Magentic Orchestration with Human Plan Review |
|
||||
|
||||
@@ -96,6 +96,7 @@ Once comfortable with these, explore the rest of the samples below.
|
||||
| Group Chat with Simple Function Selector | [orchestration/group_chat_simple_selector.py](./orchestration/group_chat_simple_selector.py) | Group chat with a simple function selector for next speaker |
|
||||
| Handoff (Simple) | [orchestration/handoff_simple.py](./orchestration/handoff_simple.py) | Single-tier routing: triage agent routes to specialists, control returns to user after each specialist response |
|
||||
| Handoff (Specialist-to-Specialist) | [orchestration/handoff_specialist_to_specialist.py](./orchestration/handoff_specialist_to_specialist.py) | Multi-tier routing: specialists can hand off to other specialists using `.add_handoff()` fluent API |
|
||||
| Handoff (Return-to-Previous) | [orchestration/handoff_return_to_previous.py](./orchestration/handoff_return_to_previous.py) | Return-to-previous routing: after user input, routes back to the previous specialist instead of coordinator using `.enable_return_to_previous()` |
|
||||
| Magentic Workflow (Multi-Agent) | [orchestration/magentic.py](./orchestration/magentic.py) | Orchestrate multiple agents with Magentic manager and streaming |
|
||||
| Magentic + Human Plan Review | [orchestration/magentic_human_plan_update.py](./orchestration/magentic_human_plan_update.py) | Human reviews/updates the plan before execution |
|
||||
| Magentic + Checkpoint Resume | [orchestration/magentic_checkpoint.py](./orchestration/magentic_checkpoint.py) | Resume Magentic orchestration from saved checkpoints |
|
||||
|
||||
@@ -0,0 +1,294 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
from typing import cast
|
||||
|
||||
from agent_framework import (
|
||||
ChatAgent,
|
||||
HandoffBuilder,
|
||||
HandoffUserInputRequest,
|
||||
RequestInfoEvent,
|
||||
WorkflowEvent,
|
||||
WorkflowOutputEvent,
|
||||
)
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
from azure.identity import AzureCliCredential
|
||||
|
||||
"""Sample: Handoff workflow with return-to-previous routing enabled.
|
||||
|
||||
This interactive sample demonstrates the return-to-previous feature where user inputs
|
||||
route directly back to the specialist currently handling their request, rather than
|
||||
always going through the coordinator for re-evaluation.
|
||||
|
||||
Routing Pattern (with return-to-previous enabled):
|
||||
User -> Coordinator -> Technical Support -> User -> Technical Support -> ...
|
||||
|
||||
Routing Pattern (default, without return-to-previous):
|
||||
User -> Coordinator -> Technical Support -> User -> Coordinator -> Technical Support -> ...
|
||||
|
||||
This is useful when a specialist needs multiple turns with the user to gather
|
||||
information or resolve an issue, avoiding unnecessary coordinator involvement.
|
||||
|
||||
Specialist-to-Specialist Handoff:
|
||||
When a user's request changes to a topic outside the current specialist's domain,
|
||||
the specialist can hand off DIRECTLY to another specialist without going back through
|
||||
the coordinator:
|
||||
|
||||
User -> Coordinator -> Technical Support -> User -> Technical Support (billing question)
|
||||
-> Billing -> User -> Billing ...
|
||||
|
||||
Example Interaction:
|
||||
1. User reports a technical issue
|
||||
2. Coordinator routes to technical support specialist
|
||||
3. Technical support asks clarifying questions
|
||||
4. User provides details (routes directly back to technical support)
|
||||
5. Technical support continues troubleshooting with full context
|
||||
6. Issue resolved, user asks about billing
|
||||
7. Technical support hands off DIRECTLY to billing specialist
|
||||
8. Billing specialist helps with payment
|
||||
9. User continues with billing (routes directly to billing)
|
||||
|
||||
Prerequisites:
|
||||
- `az login` (Azure CLI authentication)
|
||||
- Environment variables configured for AzureOpenAIChatClient (AZURE_OPENAI_ENDPOINT, etc.)
|
||||
|
||||
Usage:
|
||||
Run the script and interact with the support workflow by typing your requests.
|
||||
Type 'exit' or 'quit' to end the conversation.
|
||||
|
||||
Key Concepts:
|
||||
- Return-to-previous: Direct routing to current agent handling the conversation
|
||||
- Current agent tracking: Framework remembers which agent is actively helping the user
|
||||
- Context preservation: Specialist maintains full conversation context
|
||||
- Domain switching: Specialists can hand back to coordinator when topic changes
|
||||
"""
|
||||
|
||||
|
||||
def create_agents(chat_client: AzureOpenAIChatClient) -> tuple[ChatAgent, ChatAgent, ChatAgent, ChatAgent]:
|
||||
"""Create and configure the coordinator and specialist agents.
|
||||
|
||||
Returns:
|
||||
Tuple of (coordinator, technical_support, account_specialist, billing_agent)
|
||||
"""
|
||||
coordinator = chat_client.create_agent(
|
||||
instructions=(
|
||||
"You are a customer support coordinator. Analyze the user's request and route to "
|
||||
"the appropriate specialist:\n"
|
||||
"- technical_support for technical issues, troubleshooting, repairs, hardware/software problems\n"
|
||||
"- account_specialist for account changes, profile updates, settings, login issues\n"
|
||||
"- billing_agent for payments, invoices, refunds, charges, billing questions\n"
|
||||
"\n"
|
||||
"When you receive a request, immediately call the matching handoff tool without explaining. "
|
||||
"Read the most recent user message to determine the correct specialist."
|
||||
),
|
||||
name="coordinator",
|
||||
)
|
||||
|
||||
technical_support = chat_client.create_agent(
|
||||
instructions=(
|
||||
"You provide technical support. Help users troubleshoot technical issues, "
|
||||
"arrange repairs, and answer technical questions. "
|
||||
"Gather information through conversation. "
|
||||
"If the user asks about billing, payments, invoices, or refunds, hand off to billing_agent. "
|
||||
"If the user asks about account settings or profile changes, hand off to account_specialist."
|
||||
),
|
||||
name="technical_support",
|
||||
)
|
||||
|
||||
account_specialist = chat_client.create_agent(
|
||||
instructions=(
|
||||
"You handle account management. Help with profile updates, account settings, "
|
||||
"and preferences. Gather information through conversation. "
|
||||
"If the user asks about technical issues or troubleshooting, hand off to technical_support. "
|
||||
"If the user asks about billing, payments, invoices, or refunds, hand off to billing_agent."
|
||||
),
|
||||
name="account_specialist",
|
||||
)
|
||||
|
||||
billing_agent = chat_client.create_agent(
|
||||
instructions=(
|
||||
"You handle billing only. Process payments, explain invoices, handle refunds. "
|
||||
"If the user asks about technical issues or troubleshooting, hand off to technical_support. "
|
||||
"If the user asks about account settings or profile changes, hand off to account_specialist."
|
||||
),
|
||||
name="billing_agent",
|
||||
)
|
||||
|
||||
return coordinator, technical_support, account_specialist, billing_agent
|
||||
|
||||
|
||||
def handle_events(events: list[WorkflowEvent]) -> list[RequestInfoEvent]:
|
||||
"""Process events and return pending input requests."""
|
||||
pending_requests: list[RequestInfoEvent] = []
|
||||
for event in events:
|
||||
if isinstance(event, RequestInfoEvent):
|
||||
pending_requests.append(event)
|
||||
request_data = cast(HandoffUserInputRequest, event.data)
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"AWAITING INPUT FROM: {request_data.awaiting_agent_id.upper()}")
|
||||
print(f"{'=' * 60}")
|
||||
for msg in request_data.conversation[-3:]:
|
||||
author = msg.author_name or msg.role.value
|
||||
prefix = ">>> " if author == request_data.awaiting_agent_id else " "
|
||||
print(f"{prefix}[{author}]: {msg.text}")
|
||||
elif isinstance(event, WorkflowOutputEvent):
|
||||
print(f"\n{'=' * 60}")
|
||||
print("[WORKFLOW COMPLETE]")
|
||||
print(f"{'=' * 60}")
|
||||
return pending_requests
|
||||
|
||||
|
||||
async def _drain(stream: AsyncIterable[WorkflowEvent]) -> list[WorkflowEvent]:
|
||||
"""Drain an async iterable into a list."""
|
||||
events: list[WorkflowEvent] = []
|
||||
async for event in stream:
|
||||
events.append(event)
|
||||
return events
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Demonstrate return-to-previous routing in a handoff workflow."""
|
||||
chat_client = AzureOpenAIChatClient(credential=AzureCliCredential())
|
||||
coordinator, technical, account, billing = create_agents(chat_client)
|
||||
|
||||
print("Handoff Workflow with Return-to-Previous Routing")
|
||||
print("=" * 60)
|
||||
print("\nThis interactive demo shows how user inputs route directly")
|
||||
print("to the specialist handling your request, avoiding unnecessary")
|
||||
print("coordinator re-evaluation on each turn.")
|
||||
print("\nSpecialists can hand off directly to other specialists when")
|
||||
print("your request changes topics (e.g., from technical to billing).")
|
||||
print("\nType 'exit' or 'quit' to end the conversation.\n")
|
||||
|
||||
# Configure handoffs with return-to-previous enabled
|
||||
# Specialists can hand off directly to other specialists when topic changes
|
||||
workflow = (
|
||||
HandoffBuilder(
|
||||
name="return_to_previous_demo",
|
||||
participants=[coordinator, technical, account, billing],
|
||||
)
|
||||
.set_coordinator(coordinator)
|
||||
.add_handoff(coordinator, [technical, account, billing]) # Coordinator routes to all specialists
|
||||
.add_handoff(technical, [billing, account]) # Technical can route to billing or account
|
||||
.add_handoff(account, [technical, billing]) # Account can route to technical or billing
|
||||
.add_handoff(billing, [technical, account]) # Billing can route to technical or account
|
||||
.enable_return_to_previous(True) # Enable the `return to previous handoff` feature
|
||||
.with_termination_condition(lambda conv: sum(1 for msg in conv if msg.role.value == "user") >= 10)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Get initial user request
|
||||
initial_request = input("You: ").strip() # noqa: ASYNC250
|
||||
if not initial_request or initial_request.lower() in ["exit", "quit"]:
|
||||
print("Goodbye!")
|
||||
return
|
||||
|
||||
# Start workflow with initial message
|
||||
events = await _drain(workflow.run_stream(initial_request))
|
||||
pending_requests = handle_events(events)
|
||||
|
||||
# Interactive loop: keep prompting for user input
|
||||
while pending_requests:
|
||||
user_input = input("\nYou: ").strip() # noqa: ASYNC250
|
||||
|
||||
if not user_input or user_input.lower() in ["exit", "quit"]:
|
||||
print("\nEnding conversation. Goodbye!")
|
||||
break
|
||||
|
||||
responses = {req.request_id: user_input for req in pending_requests}
|
||||
events = await _drain(workflow.send_responses_streaming(responses))
|
||||
pending_requests = handle_events(events)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Conversation ended.")
|
||||
|
||||
"""
|
||||
Sample Output:
|
||||
|
||||
Handoff Workflow with Return-to-Previous Routing
|
||||
============================================================
|
||||
|
||||
This interactive demo shows how user inputs route directly
|
||||
to the specialist handling your request, avoiding unnecessary
|
||||
coordinator re-evaluation on each turn.
|
||||
|
||||
Specialists can hand off directly to other specialists when
|
||||
your request changes topics (e.g., from technical to billing).
|
||||
|
||||
Type 'exit' or 'quit' to end the conversation.
|
||||
|
||||
You: I need help with my bill, I was charged twice by mistake.
|
||||
|
||||
============================================================
|
||||
AWAITING INPUT FROM: BILLING_AGENT
|
||||
============================================================
|
||||
[user]: I need help with my bill, I was charged twice by mistake.
|
||||
[coordinator]: You will be connected to a billing agent who can assist you with the double charge on your bill.
|
||||
>>> [billing_agent]: I'm here to help with billing concerns! I'm sorry you were charged twice. Could you
|
||||
please provide the invoice number or your account email so I can look into this and begin processing a refund?
|
||||
|
||||
You: Invoice 1234
|
||||
|
||||
============================================================
|
||||
AWAITING INPUT FROM: BILLING_AGENT
|
||||
============================================================
|
||||
>>> [billing_agent]: I'm here to help with billing concerns! I'm sorry you were charged twice.
|
||||
Could you please provide the invoice number or your account email so I can look into this and begin
|
||||
processing a refund?
|
||||
[user]: Invoice 1234
|
||||
>>> [billing_agent]: Thank you for providing the invoice number (1234). I will review the details and work
|
||||
on processing a refund for the duplicate charge.
|
||||
|
||||
Can you confirm which payment method you used for this bill (e.g., credit card, PayPal)?
|
||||
This helps ensure your refund is processed to the correct account.
|
||||
|
||||
You: I used my credit card, which is on autopay.
|
||||
|
||||
============================================================
|
||||
AWAITING INPUT FROM: BILLING_AGENT
|
||||
============================================================
|
||||
>>> [billing_agent]: Thank you for providing the invoice number (1234). I will review the details and work on
|
||||
processing a refund for the duplicate charge.
|
||||
|
||||
Can you confirm which payment method you used for this bill (e.g., credit card, PayPal)? This helps ensure
|
||||
your refund is processed to the correct account.
|
||||
[user]: I used my credit card, which is on autopay.
|
||||
>>> [billing_agent]: Thank you for confirming your payment method. I will look into invoice 1234 and
|
||||
process a refund for the duplicate charge to your credit card.
|
||||
|
||||
You will receive a notification once the refund is completed. If you have any further questions about your billing
|
||||
or need an update, please let me know!
|
||||
|
||||
You: Actually I also can't turn on my modem. It reset and now won't turn on.
|
||||
|
||||
============================================================
|
||||
AWAITING INPUT FROM: TECHNICAL_SUPPORT
|
||||
============================================================
|
||||
[user]: Actually I also can't turn on my modem. It reset and now won't turn on.
|
||||
[billing_agent]: I'm connecting you with technical support for assistance with your modem not turning on after
|
||||
the reset. They'll be able to help troubleshoot and resolve this issue.
|
||||
|
||||
At the same time, technical support will also handle your refund request for the duplicate charge on invoice 1234
|
||||
to your credit card on autopay.
|
||||
|
||||
You will receive updates from the appropriate teams shortly.
|
||||
>>> [technical_support]: Thanks for letting me know about your modem issue! To help you further, could you tell me:
|
||||
|
||||
1. Is there any light showing on the modem at all, or is it completely off?
|
||||
2. Have you tried unplugging the modem from power and plugging it back in?
|
||||
3. Do you hear or feel anything (like a slight hum or vibration) when the modem is plugged in?
|
||||
|
||||
Let me know, and I'll guide you through troubleshooting or arrange a repair if needed.
|
||||
|
||||
You: exit
|
||||
|
||||
Ending conversation. Goodbye!
|
||||
|
||||
============================================================
|
||||
Conversation ended.
|
||||
"""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user