diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 07f737db77..0c05abbb69 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -7,6 +7,7 @@ from typing import Any from .._agents import AgentProtocol, ChatAgent from .._threads import AgentThread from .._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage +from ._conversation_state import encode_chat_messages from ._events import ( AgentRunEvent, AgentRunUpdateEvent, # type: ignore[reportPrivateUsage] @@ -191,19 +192,43 @@ class AgentExecutor(Executor): self._cache = normalize_messages_input(messages) await self._run_agent_and_emit(ctx) - def snapshot_state(self) -> dict[str, Any]: + async def snapshot_state(self) -> dict[str, Any]: """Capture current executor state for checkpointing. + NOTE: if the thread storage is on the server side, the full thread state + may not be serialized locally. Therefore, we are relying on the server-side + to ensure the thread state is preserved and immutable across checkpoints. + This is not the case for AzureAI Agents, but works for the Responses API. + Returns: - Dict containing serialized cache state + Dict containing serialized cache and thread state """ - from ._conversation_state import encode_chat_messages + # Check if using AzureAIAgentClient with server-side thread and warn about checkpointing limitations + if isinstance(self._agent, ChatAgent) and self._agent_thread.service_thread_id is not None: + client_class_name = self._agent.chat_client.__class__.__name__ + client_module = self._agent.chat_client.__class__.__module__ + + if client_class_name == "AzureAIAgentClient" and "azure_ai" in client_module: + # TODO(TaoChenOSU): update this warning when we surface the hooks for + # custom executor checkpointing. + # https://github.com/microsoft/agent-framework/issues/1816 + logger.warning( + "Checkpointing an AgentExecutor with AzureAIAgentClient that uses server-side threads. " + "Currently, checkpointing does not capture messages from server-side threads " + "(service_thread_id: %s). The thread state in checkpoints is not immutable and can be " + "modified by subsequent runs. If you need reliable checkpointing with Azure AI agents, " + "consider implementing a custom executor and managing the thread state yourself.", + self._agent_thread.service_thread_id, + ) + + serialized_thread = await self._agent_thread.serialize() return { "cache": encode_chat_messages(self._cache), + "agent_thread": serialized_thread, } - def restore_state(self, state: dict[str, Any]) -> None: + async def restore_state(self, state: dict[str, Any]) -> None: """Restore executor state from checkpoint. Args: @@ -221,6 +246,18 @@ class AgentExecutor(Executor): else: self._cache = [] + thread_payload = state.get("agent_thread") + if thread_payload: + try: + # Deserialize the thread state directly + self._agent_thread = await AgentThread.deserialize(thread_payload) + + except Exception as exc: + logger.warning("Failed to restore agent thread: %s", exc) + self._agent_thread = self._agent.get_new_thread() + else: + self._agent_thread = self._agent.get_new_thread() + def reset(self) -> None: """Reset the internal cache of the executor.""" logger.debug("AgentExecutor %s: Resetting cache", self.id) diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py new file mode 100644 index 0000000000..59ffbf8436 --- /dev/null +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -0,0 +1,220 @@ +# Copyright (c) Microsoft. All rights reserved. + +from collections.abc import AsyncIterable +from typing import Any + +from agent_framework import ( + AgentExecutor, + AgentRunResponse, + AgentRunResponseUpdate, + AgentThread, + BaseAgent, + ChatMessage, + ChatMessageStore, + Role, + SequentialBuilder, + TextContent, + WorkflowOutputEvent, + WorkflowRunState, + WorkflowStatusEvent, +) +from agent_framework._workflows._agent_executor import AgentExecutorResponse +from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage + + +class _CountingAgent(BaseAgent): + """Agent that echoes messages with a counter to verify thread state persistence.""" + + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + self.call_count = 0 + + async def run( # type: ignore[override] + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AgentRunResponse: + self.call_count += 1 + return AgentRunResponse( + messages=[ChatMessage(role=Role.ASSISTANT, text=f"Response #{self.call_count}: {self.display_name}")] + ) + + async def run_stream( # type: ignore[override] + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AsyncIterable[AgentRunResponseUpdate]: + self.call_count += 1 + yield AgentRunResponseUpdate(contents=[TextContent(text=f"Response #{self.call_count}: {self.display_name}")]) + + +async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: + """Test that workflow checkpoint stores AgentExecutor's cache and thread states and restores them correctly.""" + storage = InMemoryCheckpointStorage() + + # Create initial agent with a custom thread that has a message store + initial_agent = _CountingAgent(id="test_agent", name="TestAgent") + initial_thread = AgentThread(message_store=ChatMessageStore()) + + # Add some initial messages to the thread to verify thread state persistence + initial_messages = [ + ChatMessage(role=Role.USER, text="Initial message 1"), + ChatMessage(role=Role.ASSISTANT, text="Initial response 1"), + ] + await initial_thread.on_new_messages(initial_messages) + + # Create AgentExecutor with the thread + executor = AgentExecutor(initial_agent, agent_thread=initial_thread) + + # Build workflow with checkpointing enabled + wf = SequentialBuilder().participants([executor]).with_checkpointing(storage).build() + + # Run the workflow with a user message + first_run_output: AgentExecutorResponse | None = None + async for ev in wf.run_stream("First workflow run"): + if isinstance(ev, WorkflowOutputEvent): + first_run_output = ev.data # type: ignore[assignment] + if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE: + break + + assert first_run_output is not None + assert initial_agent.call_count == 1 + + # Verify checkpoint was created + checkpoints = await storage.list_checkpoints() + assert len(checkpoints) > 0 + + # Find a suitable checkpoint to restore (prefer superstep checkpoint) + checkpoints.sort(key=lambda cp: cp.timestamp) + restore_checkpoint = next( + (cp for cp in checkpoints if (cp.metadata or {}).get("checkpoint_type") == "superstep"), + checkpoints[-1], + ) + + # Verify checkpoint contains executor state with both cache and thread + assert "_executor_state" in restore_checkpoint.shared_state + executor_states = restore_checkpoint.shared_state["_executor_state"] + assert isinstance(executor_states, dict) + assert executor.id in executor_states + + executor_state = executor_states[executor.id] # type: ignore[index] + assert "cache" in executor_state, "Checkpoint should store executor cache state" + assert "agent_thread" in executor_state, "Checkpoint should store executor thread state" + + # Verify thread state includes message store + thread_state = executor_state["agent_thread"] # type: ignore[index] + assert "chat_message_store_state" in thread_state, "Thread state should include message store" + chat_store_state = thread_state["chat_message_store_state"] # type: ignore[index] + assert "messages" in chat_store_state, "Message store state should include messages" + + # Create a new agent and executor for restoration + # This simulates starting from a fresh state and restoring from checkpoint + restored_agent = _CountingAgent(id="test_agent", name="TestAgent") + restored_thread = AgentThread(message_store=ChatMessageStore()) + restored_executor = AgentExecutor(restored_agent, agent_thread=restored_thread) + + # Verify the restored agent starts with a fresh state + assert restored_agent.call_count == 0 + + # Build new workflow with the restored executor + wf_resume = SequentialBuilder().participants([restored_executor]).with_checkpointing(storage).build() + + # Resume from checkpoint + resumed_output: AgentExecutorResponse | None = None + async for ev in wf_resume.run_stream_from_checkpoint(restore_checkpoint.checkpoint_id): + if isinstance(ev, WorkflowOutputEvent): + resumed_output = ev.data # type: ignore[assignment] + if isinstance(ev, WorkflowStatusEvent) and ev.state in ( + WorkflowRunState.IDLE, + WorkflowRunState.IDLE_WITH_PENDING_REQUESTS, + ): + break + + assert resumed_output is not None + + # Verify the restored executor's state matches the original + # The cache should be restored (though it may be cleared after processing) + # The thread should have all messages including those from the initial state + message_store = restored_executor._agent_thread.message_store # type: ignore[reportPrivateUsage] + assert message_store is not None + thread_messages = await message_store.list_messages() + + # Thread should contain: + # 1. Initial messages from before the checkpoint (2 messages) + # 2. User message from first run (1 message) + # 3. Assistant response from first run (1 message) + assert len(thread_messages) >= 2, "Thread should preserve initial messages from before checkpoint" + + # Verify initial messages are preserved + assert thread_messages[0].text == "Initial message 1" + assert thread_messages[1].text == "Initial response 1" + + +async def test_agent_executor_snapshot_and_restore_state_directly() -> None: + """Test AgentExecutor's snapshot_state and restore_state methods directly.""" + # Create agent with thread containing messages + agent = _CountingAgent(id="direct_test_agent", name="DirectTestAgent") + thread = AgentThread(message_store=ChatMessageStore()) + + # Add messages to thread + thread_messages = [ + ChatMessage(role=Role.USER, text="Message in thread 1"), + ChatMessage(role=Role.ASSISTANT, text="Thread response 1"), + ChatMessage(role=Role.USER, text="Message in thread 2"), + ] + await thread.on_new_messages(thread_messages) + + executor = AgentExecutor(agent, agent_thread=thread) + + # Add messages to executor cache + cache_messages = [ + ChatMessage(role=Role.USER, text="Cached user message"), + ChatMessage(role=Role.ASSISTANT, text="Cached assistant response"), + ] + executor._cache = list(cache_messages) # type: ignore[reportPrivateUsage] + + # Snapshot the state + state = await executor.snapshot_state() # type: ignore[reportUnknownMemberType] + + # Verify snapshot contains both cache and thread + assert "cache" in state + assert "agent_thread" in state + + # Verify thread state structure + thread_state = state["agent_thread"] # type: ignore[index] + assert "chat_message_store_state" in thread_state + assert "messages" in thread_state["chat_message_store_state"] + + # Create new executor to restore into + new_agent = _CountingAgent(id="direct_test_agent", name="DirectTestAgent") + new_thread = AgentThread(message_store=ChatMessageStore()) + new_executor = AgentExecutor(new_agent, agent_thread=new_thread) + + # Verify new executor starts empty + assert len(new_executor._cache) == 0 # type: ignore[reportPrivateUsage] + initial_message_store = new_thread.message_store + assert initial_message_store is not None + initial_thread_msgs = await initial_message_store.list_messages() + assert len(initial_thread_msgs) == 0 + + # Restore state + await new_executor.restore_state(state) # type: ignore[reportUnknownMemberType] + + # Verify cache is restored + restored_cache = new_executor._cache # type: ignore[reportPrivateUsage] + assert len(restored_cache) == len(cache_messages) + assert restored_cache[0].text == "Cached user message" + assert restored_cache[1].text == "Cached assistant response" + + # Verify thread messages are restored + restored_message_store = new_executor._agent_thread.message_store # type: ignore[reportPrivateUsage] + assert restored_message_store is not None + restored_thread_msgs = await restored_message_store.list_messages() + assert len(restored_thread_msgs) == len(thread_messages) + assert restored_thread_msgs[0].text == "Message in thread 1" + assert restored_thread_msgs[1].text == "Thread response 1" + assert restored_thread_msgs[2].text == "Message in thread 2"