Python: Fix context duplication in handoff workflows when restoring from checkpoint (#2867)

* Fix context duplication in handoff workflows when restoring from checkpoint

* Address Copilot PR review
This commit is contained in:
Evan Mattson
2025-12-16 18:52:59 +09:00
committed by GitHub
Unverified
parent 11d6dcfe80
commit 958a488f96
3 changed files with 493 additions and 24 deletions
@@ -130,19 +130,57 @@ def _clone_chat_agent(agent: ChatAgent) -> ChatAgent:
@dataclass
class HandoffUserInputRequest:
    """Request message emitted when the workflow needs fresh user input.

    Note: The conversation field is intentionally excluded from checkpoint
    serialization to prevent duplication. ``to_dict()`` omits it and
    ``from_dict()`` always yields an empty list, so a request that has been
    round-tripped through a checkpoint carries no history of its own.
    See issue #2667.
    """

    # In-memory conversation snapshot; never written by to_dict().
    conversation: list[ChatMessage]
    awaiting_agent_id: str
    prompt: str
    source_executor_id: str

    def to_dict(self) -> dict[str, Any]:
        """Serialize to dict, excluding conversation to prevent checkpoint duplication.

        Returns:
            A dict holding only ``awaiting_agent_id``, ``prompt`` and
            ``source_executor_id``.
        """
        return {
            "awaiting_agent_id": self.awaiting_agent_id,
            "prompt": self.prompt,
            "source_executor_id": self.source_executor_id,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "HandoffUserInputRequest":
        """Deserialize from dict, initializing conversation as empty.

        Args:
            data: Mapping produced by ``to_dict()``. Keys other than the three
                metadata fields are ignored.

        Returns:
            A new request whose ``conversation`` is ``[]``.

        Raises:
            KeyError: If any of the three metadata keys is missing.
        """
        return cls(
            conversation=[],
            awaiting_agent_id=data["awaiting_agent_id"],
            prompt=data["prompt"],
            source_executor_id=data["source_executor_id"],
        )
@dataclass
class _ConversationWithUserInput:
    """Internal message carrying full conversation + new user messages from gateway to coordinator.

    Attributes:
        full_conversation: The conversation messages to process.
        is_post_restore: If True, indicates this message was created after a checkpoint restore.
            The coordinator should append these messages to its existing conversation rather
            than replacing it. This prevents duplicate messages (see issue #2667).
    """

    # A factory gives every instance its own list instead of a shared default.
    full_conversation: list[ChatMessage] = field(default_factory=lambda: [])  # type: ignore[misc]
    is_post_restore: bool = False
@dataclass
@@ -439,9 +477,25 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
message: _ConversationWithUserInput,
ctx: WorkflowContext[AgentExecutorRequest, list[ChatMessage]],
) -> None:
"""Receive full conversation with new user input from gateway, update history, trim for agent."""
# Update authoritative conversation
self._conversation = list(message.full_conversation)
"""Receive user input from gateway, update history, and route to agent.
The message.full_conversation may contain:
- Full conversation history + new user messages (normal flow)
- Only new user messages (post-checkpoint-restore flow, see issue #2667)
The gateway sets message.is_post_restore=True when resuming after a checkpoint
restore. In that case, we append the new messages to the existing conversation
rather than replacing it.
"""
incoming = message.full_conversation
if message.is_post_restore and self._conversation:
# Post-restore: append new user messages to existing conversation
# The coordinator already has its conversation restored from checkpoint
self._conversation.extend(incoming)
else:
# Normal flow: replace with full conversation
self._conversation = list(incoming) if incoming else self._conversation
# Reset autonomous turn counter on new user input
self._autonomous_turns = 0
@@ -626,15 +680,24 @@ class _UserInputGateway(Executor):
response: object,
ctx: WorkflowContext[_ConversationWithUserInput],
) -> None:
"""Convert user input responses back into chat messages and resume the workflow."""
# Reconstruct full conversation with new user input
conversation = list(original_request.conversation)
user_messages = _as_user_messages(response)
conversation.extend(user_messages)
"""Convert user input responses back into chat messages and resume the workflow.
After checkpoint restore, original_request.conversation will be empty (not serialized
to prevent duplication - see issue #2667). In this case, we send only the new user
messages and let the coordinator append them to its already-restored conversation.
"""
user_messages = _as_user_messages(response)
if original_request.conversation:
# Normal flow: have conversation history from the original request
conversation = list(original_request.conversation)
conversation.extend(user_messages)
message = _ConversationWithUserInput(full_conversation=conversation, is_post_restore=False)
else:
# Post-restore flow: conversation was not serialized, send only new user messages
# The coordinator will append these to its already-restored conversation
message = _ConversationWithUserInput(full_conversation=user_messages, is_post_restore=True)
# Send full conversation back to coordinator (not trimmed)
# Coordinator will update its authoritative history and trim for agent
message = _ConversationWithUserInput(full_conversation=conversation)
await ctx.send_message(message, target_id="handoff-coordinator")
@@ -25,7 +25,12 @@ from agent_framework import (
from agent_framework._mcp import MCPTool
from agent_framework._workflows import AgentRunEvent
from agent_framework._workflows import _handoff as handoff_module # type: ignore
from agent_framework._workflows._handoff import _clone_chat_agent # type: ignore[reportPrivateUsage]
from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value
from agent_framework._workflows._handoff import (
_clone_chat_agent, # type: ignore[reportPrivateUsage]
_ConversationWithUserInput,
_UserInputGateway,
)
from agent_framework._workflows._workflow_builder import WorkflowBuilder
@@ -775,3 +780,402 @@ async def test_return_to_previous_state_serialization():
# Verify current_agent_id was restored
assert coordinator2._current_agent_id == "specialist_a", "Current agent should be restored from checkpoint" # type: ignore[reportPrivateUsage]
async def test_handoff_user_input_request_checkpoint_excludes_conversation():
    """Test that HandoffUserInputRequest serialization excludes conversation to prevent duplication.

    Issue #2667: When checkpointing a workflow with a pending HandoffUserInputRequest,
    the conversation field gets serialized twice: once in the RequestInfoEvent's data
    and once in the coordinator's conversation state. On restore, this causes duplicate
    messages.

    The fix is to exclude the conversation field during checkpoint serialization since
    the conversation is already preserved in the coordinator's state.
    """

    def _has_key(node: object, key: str) -> bool:
        """Return True if *key* is used as a dict key anywhere inside *node*."""
        if isinstance(node, dict):
            return key in node or any(_has_key(child, key) for child in node.values())
        if isinstance(node, (list, tuple)):
            return any(_has_key(child, key) for child in node)
        return False

    # Create a conversation history
    conversation = [
        ChatMessage(role=Role.USER, text="Hello"),
        ChatMessage(role=Role.ASSISTANT, text="Hi there!"),
        ChatMessage(role=Role.USER, text="Help me"),
    ]

    # Create a HandoffUserInputRequest with the conversation
    request = HandoffUserInputRequest(
        conversation=conversation,
        awaiting_agent_id="specialist_agent",
        prompt="Please provide your input",
        source_executor_id="gateway",
    )

    # Encode the request (simulating checkpoint save)
    encoded = encode_checkpoint_value(request)
    assert isinstance(encoded, dict)

    # Check the encoded payload unconditionally. from_dict() always produces an
    # empty conversation, so the decoded value alone cannot prove the field was
    # left out of the serialized form.
    assert not _has_key(encoded, "conversation"), "conversation should be excluded from checkpoint serialization"

    # Decode the request (simulating checkpoint restore)
    decoded = decode_checkpoint_value(encoded)

    # Verify the decoded request is a HandoffUserInputRequest
    assert isinstance(decoded, HandoffUserInputRequest)

    # Verify other fields are preserved
    assert decoded.awaiting_agent_id == "specialist_agent"
    assert decoded.prompt == "Please provide your input"
    assert decoded.source_executor_id == "gateway"

    # Conversation should be an empty list after deserialization
    assert decoded.conversation == []
async def test_handoff_user_input_request_roundtrip_preserves_metadata():
    """Check that every non-conversation field survives an encode/decode cycle."""
    original = HandoffUserInputRequest(
        conversation=[ChatMessage(role=Role.USER, text="test")],
        awaiting_agent_id="test_agent",
        prompt="Enter your response",
        source_executor_id="test_gateway",
    )

    # Encode and immediately decode, as a checkpoint save + restore would.
    restored = decode_checkpoint_value(encode_checkpoint_value(original))

    assert isinstance(restored, HandoffUserInputRequest)
    for attr in ("awaiting_agent_id", "prompt", "source_executor_id"):
        assert getattr(restored, attr) == getattr(original, attr)
async def test_request_info_event_with_handoff_user_input_request():
    """Test RequestInfoEvent serialization with HandoffUserInputRequest data."""

    def _has_key(node: object, key: str) -> bool:
        """Return True if *key* is used as a dict key anywhere inside *node*."""
        if isinstance(node, dict):
            return key in node or any(_has_key(child, key) for child in node.values())
        if isinstance(node, (list, tuple)):
            return any(_has_key(child, key) for child in node)
        return False

    conversation = [
        ChatMessage(role=Role.USER, text="Hello"),
        ChatMessage(role=Role.ASSISTANT, text="How can I help?"),
    ]

    request = HandoffUserInputRequest(
        conversation=conversation,
        awaiting_agent_id="specialist",
        prompt="Provide input",
        source_executor_id="gateway",
    )

    # Create a RequestInfoEvent wrapping the request
    event = RequestInfoEvent(
        request_id="test-request-123",
        source_executor_id="gateway",
        request_data=request,
        response_type=object,
    )

    # Serialize the event
    event_dict = event.to_dict()

    # Check the serialized data unconditionally. The restored request always
    # has an empty conversation, so only the serialized form can show whether
    # the history was written out.
    assert not _has_key(event_dict["data"], "conversation")

    # Deserialize and verify
    restored_event = RequestInfoEvent.from_dict(event_dict)

    assert isinstance(restored_event.data, HandoffUserInputRequest)
    assert restored_event.data.awaiting_agent_id == "specialist"
    assert restored_event.data.conversation == []
async def test_handoff_user_input_request_to_dict_excludes_conversation():
    """Check that to_dict() leaves the conversation out and keeps the metadata."""
    history = [
        ChatMessage(role=Role.USER, text="Hello"),
        ChatMessage(role=Role.ASSISTANT, text="Hi!"),
    ]
    request = HandoffUserInputRequest(
        conversation=history,
        awaiting_agent_id="agent1",
        prompt="Enter input",
        source_executor_id="gateway",
    )

    serialized = request.to_dict()

    # The history must not be part of the serialized payload.
    assert "conversation" not in serialized

    # Each metadata field is carried over unchanged.
    expected_fields = {
        "awaiting_agent_id": "agent1",
        "prompt": "Enter input",
        "source_executor_id": "gateway",
    }
    for key, expected in expected_fields.items():
        assert serialized[key] == expected
async def test_handoff_user_input_request_from_dict_creates_empty_conversation():
    """Check that from_dict() builds a request whose conversation is empty."""
    payload = {
        "awaiting_agent_id": "agent1",
        "prompt": "Enter input",
        "source_executor_id": "gateway",
    }

    rebuilt = HandoffUserInputRequest.from_dict(payload)

    assert rebuilt.conversation == []
    # Metadata is copied straight from the payload.
    for key, expected in payload.items():
        assert getattr(rebuilt, key) == expected
async def test_user_input_gateway_resume_handles_empty_conversation():
    """Test that _UserInputGateway.resume_from_user handles post-restore scenario.

    After checkpoint restore, the HandoffUserInputRequest will have an empty
    conversation. The gateway should handle this by sending only the new user
    messages to the coordinator, flagged as post-restore.
    """
    from unittest.mock import AsyncMock

    # Create a gateway
    gateway = _UserInputGateway(
        starting_agent_id="coordinator",
        prompt="Enter input",
        id="test-gateway",
    )

    # Simulate post-restore: request with empty conversation
    restored_request = HandoffUserInputRequest(
        conversation=[],  # Empty after restore
        awaiting_agent_id="specialist",
        prompt="Enter input",
        source_executor_id="test-gateway",
    )

    # Create mock context
    mock_ctx = MagicMock()
    mock_ctx.send_message = AsyncMock()

    # Call resume_from_user with a user response
    await gateway.resume_from_user(restored_request, "New user message", mock_ctx)

    # Verify send_message was called
    mock_ctx.send_message.assert_called_once()

    # Get the message that was sent
    call_args = mock_ctx.send_message.call_args
    sent_message = call_args[0][0]

    # Verify it's a _ConversationWithUserInput
    assert isinstance(sent_message, _ConversationWithUserInput)

    # The coordinator chooses append vs. replace from this flag, so it must be
    # set whenever the gateway forwards only the new messages.
    assert sent_message.is_post_restore is True

    # Verify it contains only the new user message (not any history)
    assert len(sent_message.full_conversation) == 1
    assert sent_message.full_conversation[0].role == Role.USER
    assert sent_message.full_conversation[0].text == "New user message"
async def test_user_input_gateway_resume_with_full_conversation():
    """Test that _UserInputGateway.resume_from_user handles normal flow correctly.

    In normal flow (no checkpoint restore), the HandoffUserInputRequest has
    the full conversation. The gateway should send the full conversation
    plus the new user messages, and must not flag the message as post-restore.
    """
    from unittest.mock import AsyncMock

    # Create a gateway
    gateway = _UserInputGateway(
        starting_agent_id="coordinator",
        prompt="Enter input",
        id="test-gateway",
    )

    # Normal flow: request with full conversation
    normal_request = HandoffUserInputRequest(
        conversation=[
            ChatMessage(role=Role.USER, text="Hello"),
            ChatMessage(role=Role.ASSISTANT, text="Hi!"),
        ],
        awaiting_agent_id="specialist",
        prompt="Enter input",
        source_executor_id="test-gateway",
    )

    # Create mock context
    mock_ctx = MagicMock()
    mock_ctx.send_message = AsyncMock()

    # Call resume_from_user with a user response
    await gateway.resume_from_user(normal_request, "Follow up message", mock_ctx)

    # Verify send_message was called
    mock_ctx.send_message.assert_called_once()

    # Get the message that was sent
    call_args = mock_ctx.send_message.call_args
    sent_message = call_args[0][0]

    # Verify it's a _ConversationWithUserInput
    assert isinstance(sent_message, _ConversationWithUserInput)

    # A full-history message must be treated as a replacement, not an append.
    assert sent_message.is_post_restore is False

    # Verify it contains the full conversation plus new user message
    assert len(sent_message.full_conversation) == 3
    assert sent_message.full_conversation[0].text == "Hello"
    assert sent_message.full_conversation[1].text == "Hi!"
    assert sent_message.full_conversation[2].text == "Follow up message"
async def test_coordinator_handle_user_input_post_restore():
    """Check that the coordinator appends post-restore input to its restored history.

    After a checkpoint restore the coordinator already holds its conversation and
    the gateway forwards only the new user messages, so handle_user_input must
    extend the existing history instead of replacing it.
    """
    from unittest.mock import AsyncMock

    from agent_framework._workflows._handoff import _HandoffCoordinator

    restored_texts = [
        (Role.USER, "Hello"),
        (Role.ASSISTANT, "Hi there!"),
        (Role.USER, "Help me"),
        (Role.ASSISTANT, "Sure, what do you need?"),
    ]

    coordinator = _HandoffCoordinator(
        starting_agent_id="triage",
        specialist_ids={"specialist_a": "specialist_a"},
        input_gateway_id="gateway",
        termination_condition=lambda conv: False,
        id="test-coordinator",
    )
    # Seed the coordinator as if its state had just been restored.
    coordinator._conversation = [ChatMessage(role=role, text=text) for role, text in restored_texts]

    ctx = MagicMock()
    ctx.send_message = AsyncMock()

    # Only the fresh user message arrives, explicitly flagged as post-restore.
    new_input = _ConversationWithUserInput(
        full_conversation=[ChatMessage(role=Role.USER, text="I need shipping help")],
        is_post_restore=True,
    )

    await coordinator.handle_user_input(new_input, ctx)

    # History is kept intact and the new message lands at the end.
    assert [msg.text for msg in coordinator._conversation] == [
        "Hello",
        "Hi there!",
        "Help me",
        "Sure, what do you need?",
        "I need shipping help",
    ]
async def test_coordinator_handle_user_input_normal_flow():
    """Check that the coordinator replaces its history in the normal flow.

    Without a restore, the gateway sends the complete conversation, so the
    coordinator's previous history must be discarded in favour of the incoming one.
    """
    from unittest.mock import AsyncMock

    from agent_framework._workflows._handoff import _HandoffCoordinator

    coordinator = _HandoffCoordinator(
        starting_agent_id="triage",
        specialist_ids={"specialist_a": "specialist_a"},
        input_gateway_id="gateway",
        termination_condition=lambda conv: False,
        id="test-coordinator",
    )
    # Pre-existing history that should be overwritten.
    coordinator._conversation = [ChatMessage(role=Role.USER, text="Old message")]

    ctx = MagicMock()
    ctx.send_message = AsyncMock()

    # Full history plus the new user message; the flag marks this as normal flow.
    full_history = [
        ChatMessage(role=Role.USER, text="Hello"),
        ChatMessage(role=Role.ASSISTANT, text="Hi!"),
        ChatMessage(role=Role.USER, text="New message"),
    ]
    new_input = _ConversationWithUserInput(full_conversation=full_history, is_post_restore=False)

    await coordinator.handle_user_input(new_input, ctx)

    # The old message is gone; only the incoming conversation remains.
    assert [msg.text for msg in coordinator._conversation] == ["Hello", "Hi!", "New message"]
async def test_coordinator_handle_user_input_multiple_consecutive_user_messages():
    """Regression test: consecutive USER messages in normal flow still replace history.

    A payload made only of USER messages looks like a post-restore payload. The
    explicit is_post_restore flag is what tells the two cases apart, so with the
    flag set to False the coordinator must replace rather than append.
    """
    from unittest.mock import AsyncMock

    from agent_framework._workflows._handoff import _HandoffCoordinator

    existing_history = [
        (Role.USER, "Original message 1"),
        (Role.ASSISTANT, "Response 1"),
        (Role.USER, "Original message 2"),
        (Role.ASSISTANT, "Response 2"),
    ]

    coordinator = _HandoffCoordinator(
        starting_agent_id="triage",
        specialist_ids={"specialist_a": "specialist_a"},
        input_gateway_id="gateway",
        termination_condition=lambda conv: False,
        id="test-coordinator",
    )
    # Four prior messages that must not survive the replacement.
    coordinator._conversation = [ChatMessage(role=role, text=text) for role, text in existing_history]

    ctx = MagicMock()
    ctx.send_message = AsyncMock()

    # Two USER messages in a row, explicitly marked as normal flow.
    new_input = _ConversationWithUserInput(
        full_conversation=[
            ChatMessage(role=Role.USER, text="New user message 1"),
            ChatMessage(role=Role.USER, text="New user message 2"),
        ],
        is_post_restore=False,
    )

    await coordinator.handle_user_input(new_input, ctx)

    # Replaced, not appended: only the two incoming messages remain.
    assert [msg.text for msg in coordinator._conversation] == [
        "New user message 1",
        "New user message 2",
    ]
@@ -122,11 +122,17 @@ def _print_handoff_request(request: HandoffUserInputRequest, request_id: str) ->
print(f"Awaiting agent: {request.awaiting_agent_id}")
print(f"Prompt: {request.prompt}")
print("\nConversation so far:")
for msg in request.conversation[-3:]:
author = msg.author_name or msg.role.value
snippet = msg.text[:120] + "..." if len(msg.text) > 120 else msg.text
print(f" {author}: {snippet}")
# Note: After checkpoint restore, conversation may be empty because it's not serialized
# to prevent duplication (the conversation is preserved in the coordinator's state).
# See issue #2667.
if request.conversation:
print("\nConversation so far:")
for msg in request.conversation[-3:]:
author = msg.author_name or msg.role.value
snippet = msg.text[:120] + "..." if len(msg.text) > 120 else msg.text
print(f" {author}: {snippet}")
else:
print("\n(Conversation restored from checkpoint - context preserved in workflow state)")
print(f"{'=' * 60}\n")
@@ -273,11 +279,7 @@ async def resume_with_responses(
elif isinstance(event, WorkflowOutputEvent):
print("\n[Workflow Output Event - Conversation Update]")
if (
event.data
and isinstance(event.data, list)
and all(isinstance(msg, ChatMessage) for msg in event.data)
):
if event.data and isinstance(event.data, list) and all(isinstance(msg, ChatMessage) for msg in event.data):
# Now safe to cast event.data to list[ChatMessage]
conversation = cast(list[ChatMessage], event.data)
for msg in conversation[-3:]: # Show last 3 messages