Python: Ensure agent thread is part of checkpoint (#1756)

* ensure agent thread is part of checkpoint

* Update python/packages/core/agent_framework/_workflows/_agent_executor.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* remove data copying for server side thread.

* refine warning check

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Eric Zhu
2025-10-31 08:48:43 -07:00
committed by GitHub
Unverified
parent 68b6a55757
commit 889f45a7ef
2 changed files with 261 additions and 4 deletions
@@ -7,6 +7,7 @@ from typing import Any
from .._agents import AgentProtocol, ChatAgent
from .._threads import AgentThread
from .._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage
from ._conversation_state import encode_chat_messages
from ._events import (
AgentRunEvent,
AgentRunUpdateEvent, # type: ignore[reportPrivateUsage]
@@ -191,19 +192,43 @@ class AgentExecutor(Executor):
self._cache = normalize_messages_input(messages)
await self._run_agent_and_emit(ctx)
def snapshot_state(self) -> dict[str, Any]:
async def snapshot_state(self) -> dict[str, Any]:
"""Capture current executor state for checkpointing.
NOTE: if the thread storage is on the server side, the full thread state
may not be serialized locally. Therefore, we are relying on the server-side
to ensure the thread state is preserved and immutable across checkpoints.
This is not the case for AzureAI Agents, but works for the Responses API.
Returns:
Dict containing serialized cache state
Dict containing serialized cache and thread state
"""
from ._conversation_state import encode_chat_messages
# Check if using AzureAIAgentClient with server-side thread and warn about checkpointing limitations
if isinstance(self._agent, ChatAgent) and self._agent_thread.service_thread_id is not None:
client_class_name = self._agent.chat_client.__class__.__name__
client_module = self._agent.chat_client.__class__.__module__
if client_class_name == "AzureAIAgentClient" and "azure_ai" in client_module:
# TODO(TaoChenOSU): update this warning when we surface the hooks for
# custom executor checkpointing.
# https://github.com/microsoft/agent-framework/issues/1816
logger.warning(
"Checkpointing an AgentExecutor with AzureAIAgentClient that uses server-side threads. "
"Currently, checkpointing does not capture messages from server-side threads "
"(service_thread_id: %s). The thread state in checkpoints is not immutable and can be "
"modified by subsequent runs. If you need reliable checkpointing with Azure AI agents, "
"consider implementing a custom executor and managing the thread state yourself.",
self._agent_thread.service_thread_id,
)
serialized_thread = await self._agent_thread.serialize()
return {
"cache": encode_chat_messages(self._cache),
"agent_thread": serialized_thread,
}
def restore_state(self, state: dict[str, Any]) -> None:
async def restore_state(self, state: dict[str, Any]) -> None:
"""Restore executor state from checkpoint.
Args:
@@ -221,6 +246,18 @@ class AgentExecutor(Executor):
else:
self._cache = []
thread_payload = state.get("agent_thread")
if thread_payload:
try:
# Deserialize the thread state directly
self._agent_thread = await AgentThread.deserialize(thread_payload)
except Exception as exc:
logger.warning("Failed to restore agent thread: %s", exc)
self._agent_thread = self._agent.get_new_thread()
else:
self._agent_thread = self._agent.get_new_thread()
def reset(self) -> None:
"""Reset the internal cache of the executor."""
logger.debug("AgentExecutor %s: Resetting cache", self.id)
@@ -0,0 +1,220 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import AsyncIterable
from typing import Any
from agent_framework import (
AgentExecutor,
AgentRunResponse,
AgentRunResponseUpdate,
AgentThread,
BaseAgent,
ChatMessage,
ChatMessageStore,
Role,
SequentialBuilder,
TextContent,
WorkflowOutputEvent,
WorkflowRunState,
WorkflowStatusEvent,
)
from agent_framework._workflows._agent_executor import AgentExecutorResponse
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage
class _CountingAgent(BaseAgent):
"""Agent that echoes messages with a counter to verify thread state persistence."""
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self.call_count = 0
async def run( # type: ignore[override]
self,
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
*,
thread: AgentThread | None = None,
**kwargs: Any,
) -> AgentRunResponse:
self.call_count += 1
return AgentRunResponse(
messages=[ChatMessage(role=Role.ASSISTANT, text=f"Response #{self.call_count}: {self.display_name}")]
)
async def run_stream( # type: ignore[override]
self,
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
*,
thread: AgentThread | None = None,
**kwargs: Any,
) -> AsyncIterable[AgentRunResponseUpdate]:
self.call_count += 1
yield AgentRunResponseUpdate(contents=[TextContent(text=f"Response #{self.call_count}: {self.display_name}")])
async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
"""Test that workflow checkpoint stores AgentExecutor's cache and thread states and restores them correctly."""
storage = InMemoryCheckpointStorage()
# Create initial agent with a custom thread that has a message store
initial_agent = _CountingAgent(id="test_agent", name="TestAgent")
initial_thread = AgentThread(message_store=ChatMessageStore())
# Add some initial messages to the thread to verify thread state persistence
initial_messages = [
ChatMessage(role=Role.USER, text="Initial message 1"),
ChatMessage(role=Role.ASSISTANT, text="Initial response 1"),
]
await initial_thread.on_new_messages(initial_messages)
# Create AgentExecutor with the thread
executor = AgentExecutor(initial_agent, agent_thread=initial_thread)
# Build workflow with checkpointing enabled
wf = SequentialBuilder().participants([executor]).with_checkpointing(storage).build()
# Run the workflow with a user message
first_run_output: AgentExecutorResponse | None = None
async for ev in wf.run_stream("First workflow run"):
if isinstance(ev, WorkflowOutputEvent):
first_run_output = ev.data # type: ignore[assignment]
if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE:
break
assert first_run_output is not None
assert initial_agent.call_count == 1
# Verify checkpoint was created
checkpoints = await storage.list_checkpoints()
assert len(checkpoints) > 0
# Find a suitable checkpoint to restore (prefer superstep checkpoint)
checkpoints.sort(key=lambda cp: cp.timestamp)
restore_checkpoint = next(
(cp for cp in checkpoints if (cp.metadata or {}).get("checkpoint_type") == "superstep"),
checkpoints[-1],
)
# Verify checkpoint contains executor state with both cache and thread
assert "_executor_state" in restore_checkpoint.shared_state
executor_states = restore_checkpoint.shared_state["_executor_state"]
assert isinstance(executor_states, dict)
assert executor.id in executor_states
executor_state = executor_states[executor.id] # type: ignore[index]
assert "cache" in executor_state, "Checkpoint should store executor cache state"
assert "agent_thread" in executor_state, "Checkpoint should store executor thread state"
# Verify thread state includes message store
thread_state = executor_state["agent_thread"] # type: ignore[index]
assert "chat_message_store_state" in thread_state, "Thread state should include message store"
chat_store_state = thread_state["chat_message_store_state"] # type: ignore[index]
assert "messages" in chat_store_state, "Message store state should include messages"
# Create a new agent and executor for restoration
# This simulates starting from a fresh state and restoring from checkpoint
restored_agent = _CountingAgent(id="test_agent", name="TestAgent")
restored_thread = AgentThread(message_store=ChatMessageStore())
restored_executor = AgentExecutor(restored_agent, agent_thread=restored_thread)
# Verify the restored agent starts with a fresh state
assert restored_agent.call_count == 0
# Build new workflow with the restored executor
wf_resume = SequentialBuilder().participants([restored_executor]).with_checkpointing(storage).build()
# Resume from checkpoint
resumed_output: AgentExecutorResponse | None = None
async for ev in wf_resume.run_stream_from_checkpoint(restore_checkpoint.checkpoint_id):
if isinstance(ev, WorkflowOutputEvent):
resumed_output = ev.data # type: ignore[assignment]
if isinstance(ev, WorkflowStatusEvent) and ev.state in (
WorkflowRunState.IDLE,
WorkflowRunState.IDLE_WITH_PENDING_REQUESTS,
):
break
assert resumed_output is not None
# Verify the restored executor's state matches the original
# The cache should be restored (though it may be cleared after processing)
# The thread should have all messages including those from the initial state
message_store = restored_executor._agent_thread.message_store # type: ignore[reportPrivateUsage]
assert message_store is not None
thread_messages = await message_store.list_messages()
# Thread should contain:
# 1. Initial messages from before the checkpoint (2 messages)
# 2. User message from first run (1 message)
# 3. Assistant response from first run (1 message)
assert len(thread_messages) >= 2, "Thread should preserve initial messages from before checkpoint"
# Verify initial messages are preserved
assert thread_messages[0].text == "Initial message 1"
assert thread_messages[1].text == "Initial response 1"
async def test_agent_executor_snapshot_and_restore_state_directly() -> None:
"""Test AgentExecutor's snapshot_state and restore_state methods directly."""
# Create agent with thread containing messages
agent = _CountingAgent(id="direct_test_agent", name="DirectTestAgent")
thread = AgentThread(message_store=ChatMessageStore())
# Add messages to thread
thread_messages = [
ChatMessage(role=Role.USER, text="Message in thread 1"),
ChatMessage(role=Role.ASSISTANT, text="Thread response 1"),
ChatMessage(role=Role.USER, text="Message in thread 2"),
]
await thread.on_new_messages(thread_messages)
executor = AgentExecutor(agent, agent_thread=thread)
# Add messages to executor cache
cache_messages = [
ChatMessage(role=Role.USER, text="Cached user message"),
ChatMessage(role=Role.ASSISTANT, text="Cached assistant response"),
]
executor._cache = list(cache_messages) # type: ignore[reportPrivateUsage]
# Snapshot the state
state = await executor.snapshot_state() # type: ignore[reportUnknownMemberType]
# Verify snapshot contains both cache and thread
assert "cache" in state
assert "agent_thread" in state
# Verify thread state structure
thread_state = state["agent_thread"] # type: ignore[index]
assert "chat_message_store_state" in thread_state
assert "messages" in thread_state["chat_message_store_state"]
# Create new executor to restore into
new_agent = _CountingAgent(id="direct_test_agent", name="DirectTestAgent")
new_thread = AgentThread(message_store=ChatMessageStore())
new_executor = AgentExecutor(new_agent, agent_thread=new_thread)
# Verify new executor starts empty
assert len(new_executor._cache) == 0 # type: ignore[reportPrivateUsage]
initial_message_store = new_thread.message_store
assert initial_message_store is not None
initial_thread_msgs = await initial_message_store.list_messages()
assert len(initial_thread_msgs) == 0
# Restore state
await new_executor.restore_state(state) # type: ignore[reportUnknownMemberType]
# Verify cache is restored
restored_cache = new_executor._cache # type: ignore[reportPrivateUsage]
assert len(restored_cache) == len(cache_messages)
assert restored_cache[0].text == "Cached user message"
assert restored_cache[1].text == "Cached assistant response"
# Verify thread messages are restored
restored_message_store = new_executor._agent_thread.message_store # type: ignore[reportPrivateUsage]
assert restored_message_store is not None
restored_thread_msgs = await restored_message_store.list_messages()
assert len(restored_thread_msgs) == len(thread_messages)
assert restored_thread_msgs[0].text == "Message in thread 1"
assert restored_thread_msgs[1].text == "Thread response 1"
assert restored_thread_msgs[2].text == "Message in thread 2"