Python: Add checkpoint save and restore hooks to executor (#2097)

* Add checkpoint hooks

* Deprecate get_executor_state and set_executor_state

* Fix tests and samples

* Add doc strings

* Add sample

* Fix import

* Address comments and fix tests

* Address comments

* conditional import
This commit is contained in:
Tao Chen
2025-11-17 10:19:01 -08:00
committed by GitHub
Unverified
parent 132597957a
commit c361ad8d33
22 changed files with 508 additions and 723 deletions
@@ -37,6 +37,8 @@ from ._events import (
ExecutorFailedEvent,
ExecutorInvokedEvent,
RequestInfoEvent,
SuperStepCompletedEvent,
SuperStepStartedEvent,
WorkflowErrorDetails,
WorkflowEvent,
WorkflowEventSource,
@@ -152,6 +154,8 @@ __all__ = [
"StandardMagenticManager",
"SubWorkflowRequestMessage",
"SubWorkflowResponseMessage",
"SuperStepCompletedEvent",
"SuperStepStartedEvent",
"SwitchCaseEdgeGroup",
"SwitchCaseEdgeGroupCase",
"SwitchCaseEdgeGroupDefault",
@@ -35,6 +35,8 @@ from ._events import (
ExecutorFailedEvent,
ExecutorInvokedEvent,
RequestInfoEvent,
SuperStepCompletedEvent,
SuperStepStartedEvent,
WorkflowErrorDetails,
WorkflowEvent,
WorkflowEventSource,
@@ -148,6 +150,8 @@ __all__ = [
"StandardMagenticManager",
"SubWorkflowRequestMessage",
"SubWorkflowResponseMessage",
"SuperStepCompletedEvent",
"SuperStepStartedEvent",
"SwitchCaseEdgeGroup",
"SwitchCaseEdgeGroupCase",
"SwitchCaseEdgeGroupDefault",
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
import logging
import sys
from dataclasses import dataclass
from typing import Any, cast
@@ -20,6 +21,11 @@ from ._message_utils import normalize_messages_input
from ._request_info_mixin import response_handler
from ._workflow_context import WorkflowContext
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
logger = logging.getLogger(__name__)
@@ -179,7 +185,8 @@ class AgentExecutor(Executor):
self._pending_responses_to_agent.clear()
await self._run_agent_and_emit(ctx)
async def snapshot_state(self) -> dict[str, Any]:
@override
async def on_checkpoint_save(self) -> dict[str, Any]:
"""Capture current executor state for checkpointing.
NOTE: if the thread storage is on the server side, the full thread state
@@ -196,9 +203,6 @@ class AgentExecutor(Executor):
client_module = self._agent.chat_client.__class__.__module__
if client_class_name == "AzureAIAgentClient" and "azure_ai" in client_module:
# TODO(TaoChenOSU): update this warning when we surface the hooks for
# custom executor checkpointing.
# https://github.com/microsoft/agent-framework/issues/1816
logger.warning(
"Checkpointing an AgentExecutor with AzureAIAgentClient that uses server-side threads. "
"Currently, checkpointing does not capture messages from server-side threads "
@@ -217,7 +221,8 @@ class AgentExecutor(Executor):
"pending_responses_to_agent": encode_checkpoint_value(self._pending_responses_to_agent),
}
async def restore_state(self, state: dict[str, Any]) -> None:
@override
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
"""Restore executor state from checkpoint.
Args:
@@ -4,6 +4,7 @@
import inspect
import logging
import sys
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable, Sequence
from typing import Any
@@ -13,6 +14,12 @@ from ._executor import Executor
from ._orchestrator_helpers import ParticipantRegistry
from ._workflow_context import WorkflowContext
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
logger = logging.getLogger(__name__)
@@ -210,11 +217,12 @@ class BaseGroupChatOrchestrator(Executor, ABC):
# State persistence (shared across all patterns)
def snapshot_state(self) -> dict[str, Any]:
@override
async def on_checkpoint_save(self) -> dict[str, Any]:
"""Capture current orchestrator state for checkpointing.
Default implementation uses OrchestrationState to serialize common state.
Subclasses should override _snapshot_pattern_metadata() to add pattern-specific data.
Subclasses can override this method or _snapshot_pattern_metadata() to add pattern-specific data.
Returns:
Serialized state dict
@@ -238,11 +246,12 @@ class BaseGroupChatOrchestrator(Executor, ABC):
"""
return {}
def restore_state(self, state: dict[str, Any]) -> None:
@override
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
"""Restore orchestrator state from checkpoint.
Default implementation uses OrchestrationState to deserialize common state.
Subclasses should override _restore_pattern_metadata() to restore pattern-specific data.
Subclasses can override this method or _restore_pattern_metadata() to restore pattern-specific data.
Args:
state: Serialized state dict
@@ -6,9 +6,7 @@ These utilities operate on standard `list[ChatMessage]` collections and simple
dictionary snapshots so orchestrators can share logic without new mixins.
"""
import json
from collections.abc import Mapping, Sequence
from typing import Any
from collections.abc import Sequence
from .._types import ChatMessage
@@ -26,25 +24,3 @@ def ensure_author(message: ChatMessage, fallback: str) -> ChatMessage:
"""Attach `fallback` author if message is missing `author_name`."""
message.author_name = message.author_name or fallback
return message
def snapshot_state(conversation: Sequence[ChatMessage]) -> dict[str, Any]:
"""Build an immutable snapshot for checkpoint storage."""
if hasattr(conversation, "to_dict"):
result = conversation.to_dict() # type: ignore[attr-defined]
if isinstance(result, dict):
return result # type: ignore[return-value]
if isinstance(result, Mapping):
return dict(result) # type: ignore[arg-type]
serialisable: list[dict[str, Any]] = []
for message in conversation:
if hasattr(message, "to_dict") and callable(message.to_dict): # type: ignore[attr-defined]
msg_dict = message.to_dict() # type: ignore[attr-defined]
serialisable.append(dict(msg_dict) if isinstance(msg_dict, Mapping) else msg_dict) # type: ignore[arg-type]
elif hasattr(message, "to_json") and callable(message.to_json): # type: ignore[attr-defined]
json_payload = message.to_json() # type: ignore[attr-defined]
parsed = json.loads(json_payload) if isinstance(json_payload, str) else json_payload
serialisable.append(dict(parsed) if isinstance(parsed, Mapping) else parsed) # type: ignore[arg-type]
else:
serialisable.append(dict(getattr(message, "__dict__", {}))) # type: ignore[arg-type]
return {"messages": serialisable}
@@ -294,6 +294,36 @@ class WorkflowOutputEvent(WorkflowEvent):
return f"{self.__class__.__name__}(data={self.data}, source_executor_id={self.source_executor_id})"
class SuperStepEvent(WorkflowEvent):
"""Event triggered when a superstep starts or ends."""
def __init__(self, iteration: int, data: Any | None = None):
"""Initialize the superstep event.
Args:
iteration: The number of the superstep (1-based index).
data: Optional data associated with the superstep event.
"""
super().__init__(data)
self.iteration = iteration
def __repr__(self) -> str:
"""Return a string representation of the superstep event."""
return f"{self.__class__.__name__}(iteration={self.iteration}, data={self.data})"
class SuperStepStartedEvent(SuperStepEvent):
"""Event triggered when a superstep starts."""
...
class SuperStepCompletedEvent(SuperStepEvent):
"""Event triggered when a superstep ends."""
...
class ExecutorEvent(WorkflowEvent):
"""Base class for executor events."""
@@ -310,17 +340,13 @@ class ExecutorEvent(WorkflowEvent):
class ExecutorInvokedEvent(ExecutorEvent):
"""Event triggered when an executor handler is invoked."""
def __repr__(self) -> str:
"""Return a string representation of the executor handler invoke event."""
return f"{self.__class__.__name__}(executor_id={self.executor_id}, data={self.data})"
...
class ExecutorCompletedEvent(ExecutorEvent):
"""Event triggered when an executor handler is completed."""
def __repr__(self) -> str:
"""Return a string representation of the executor handler complete event."""
return f"{self.__class__.__name__}(executor_id={self.executor_id}, data={self.data})"
...
class ExecutorFailedEvent(ExecutorEvent):
@@ -155,6 +155,11 @@ class Executor(RequestInfoMixin, DictConvertible):
that parent workflows can intercept. See WorkflowExecutor documentation for details on
workflow composition patterns and request/response handling.
## State Management
Executors can contain states that persist across workflow runs and checkpoints. Override the
`on_checkpoint_save` and `on_checkpoint_restore` methods to implement custom state
serialization and restoration logic.
## Implementation Notes
- Do not call `execute()` directly - it's invoked by the workflow engine
- Do not override `execute()` - define handlers using decorators instead
@@ -460,6 +465,32 @@ class Executor(RequestInfoMixin, DictConvertible):
return self._handlers[message_type]
raise RuntimeError(f"Executor {self.__class__.__name__} cannot handle message of type {type(message)}.")
async def on_checkpoint_save(self) -> dict[str, Any]:
"""Hook called when the workflow is being saved to a checkpoint.
Override this method in subclasses to implement custom logic that should
return state to be saved in the checkpoint.
The returned state dictionary will be passed to `on_checkpoint_restore`
when the workflow is restored from the checkpoint. The dictionary should
only contain JSON-serializable data.
Returns:
A state dictionary to be saved during checkpointing.
"""
return {}
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
"""Hook called when the workflow is restored from a checkpoint.
Override this method in subclasses to implement custom logic that should
run when the workflow is restored from a checkpoint.
Args:
state: The state dictionary that was saved during checkpointing.
"""
...
# endregion: Executor
@@ -16,6 +16,7 @@ Key properties:
import logging
import re
import sys
from collections.abc import Awaitable, Callable, Mapping, Sequence
from dataclasses import dataclass, field
from typing import Any
@@ -50,6 +51,12 @@ from ._workflow import Workflow
from ._workflow_builder import WorkflowBuilder
from ._workflow_context import WorkflowContext
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
logger = logging.getLogger(__name__)
@@ -307,15 +314,6 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
ctx: WorkflowContext[AgentExecutorRequest | list[ChatMessage], list[ChatMessage] | _ConversationForUserInput],
) -> None:
"""Process an agent's response and determine whether to route, request input, or terminate."""
# Hydrate coordinator state (and detect new run) using checkpointable executor state
state = await ctx.get_executor_state()
if not state:
self._clear_conversation()
elif not self._get_conversation():
restored = self._restore_conversation_from_state(state)
if restored:
self._conversation = list(restored)
source = ctx.get_source_executor_id()
is_starting_agent = source == self._starting_agent_id
@@ -343,7 +341,7 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
# Update current agent when handoff occurs
self._current_agent_id = target
logger.info(f"Handoff detected: {source} -> {target}. Routing control to specialist '{target}'.")
await self._persist_state(ctx)
# Clean tool-related content before sending to next agent
cleaned = clean_conversation_for_handoff(conversation)
request = AgentExecutorRequest(messages=cleaned, should_respond=True)
@@ -360,7 +358,6 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
f"Agent '{source}' responded without handoff. "
f"Requesting user input. Return-to-previous: {self._return_to_previous}"
)
await self._persist_state(ctx)
if await self._check_termination():
# Clean the output conversation for display
@@ -388,7 +385,6 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
"""Receive full conversation with new user input from gateway, update history, trim for agent."""
# Update authoritative conversation
self._conversation = list(message.full_conversation)
await self._persist_state(ctx)
# Check termination before sending to agent
if await self._check_termination():
@@ -473,11 +469,7 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
)
return list(conversation)
async def _persist_state(self, ctx: WorkflowContext[Any, Any]) -> None:
"""Store authoritative conversation snapshot without losing rich metadata."""
state_payload = self.snapshot_state()
await ctx.set_executor_state(state_payload)
@override
def _snapshot_pattern_metadata(self) -> dict[str, Any]:
"""Serialize pattern-specific state.
@@ -492,6 +484,7 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
}
return {}
@override
def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None:
"""Restore pattern-specific state.
@@ -503,17 +496,6 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
if self._return_to_previous and "current_agent_id" in metadata:
self._current_agent_id = metadata["current_agent_id"]
def _restore_conversation_from_state(self, state: Mapping[str, Any]) -> list[ChatMessage]:
"""Rehydrate the coordinator's conversation history from checkpointed state.
DEPRECATED: Use restore_state() instead. Kept for backward compatibility.
"""
from ._orchestration_state import OrchestrationState
orch_state_dict = {"conversation": state.get("full_conversation", state.get("conversation", []))}
temp_state = OrchestrationState.from_dict(orch_state_dict)
return list(temp_state.conversation)
def _apply_response_metadata(self, conversation: list[ChatMessage], agent_response: AgentRunResponse) -> None:
"""Merge top-level response metadata into the latest assistant message."""
if not agent_response.additional_properties:
@@ -45,9 +45,15 @@ from ._workflow import Workflow, WorkflowRunResult
from ._workflow_context import WorkflowContext
if sys.version_info >= (3, 11):
from typing import Self # pragma: no cover
from typing import Self
else:
from typing_extensions import Self # pragma: no cover
from typing_extensions import Self
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
logger = logging.getLogger(__name__)
@@ -673,11 +679,11 @@ class MagenticManagerBase(ABC):
"""Prepare the final answer."""
...
def snapshot_state(self) -> dict[str, Any]:
def on_checkpoint_save(self) -> dict[str, Any]:
"""Serialize runtime state for checkpointing."""
return {}
def restore_state(self, state: dict[str, Any]) -> None:
def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
"""Restore runtime state from checkpoint data."""
return
@@ -695,22 +701,6 @@ class StandardMagenticManager(MagenticManagerBase):
task_ledger: _MagenticTaskLedger | None
def snapshot_state(self) -> dict[str, Any]:
state = super().snapshot_state()
if self.task_ledger is not None:
state = dict(state)
state["task_ledger"] = self.task_ledger.to_dict()
return state
def restore_state(self, state: dict[str, Any]) -> None:
super().restore_state(state)
ledger = state.get("task_ledger")
if ledger is not None:
try:
self.task_ledger = _MagenticTaskLedger.from_dict(ledger)
except Exception: # pragma: no cover - defensive
logger.warning("Failed to restore manager task ledger from checkpoint state")
def __init__(
self,
chat_client: ChatClientProtocol,
@@ -940,6 +930,22 @@ class StandardMagenticManager(MagenticManagerBase):
author_name=response.author_name or MAGENTIC_MANAGER_NAME,
)
@override
def on_checkpoint_save(self) -> dict[str, Any]:
state: dict[str, Any] = {}
if self.task_ledger is not None:
state["task_ledger"] = self.task_ledger.to_dict()
return state
@override
def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
ledger = state.get("task_ledger")
if ledger is not None:
try:
self.task_ledger = _MagenticTaskLedger.from_dict(ledger)
except Exception: # pragma: no cover - defensive
logger.warning("Failed to restore manager task ledger from checkpoint state")
# endregion Magentic Manager
@@ -997,7 +1003,6 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
# Terminal state marker to stop further processing after completion/limits
self._terminated = False
# Tracks whether checkpoint state has been applied for this run
self._state_restored = False
def _get_author_name(self) -> str:
"""Get the magentic manager name for orchestrator-generated messages."""
@@ -1036,7 +1041,8 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
)
await ctx.add_event(event)
def snapshot_state(self) -> dict[str, Any]:
@override
async def on_checkpoint_save(self) -> dict[str, Any]:
"""Capture current orchestrator state for checkpointing.
Uses OrchestrationState for structure but maintains Magentic's complex metadata
@@ -1055,14 +1061,16 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
state["magentic_context"] = self._context.to_dict()
if self._task_ledger is not None:
state["task_ledger"] = _message_to_payload(self._task_ledger)
manager_state: dict[str, Any] | None = None
with contextlib.suppress(Exception):
manager_state = self._manager.snapshot_state()
if manager_state:
state["manager_state"] = manager_state
try:
state["manager_state"] = self._manager.on_checkpoint_save()
except Exception as exc:
logger.warning("Failed to save manager state for checkpoint: %s\nSkipping...", exc)
return state
def restore_state(self, state: dict[str, Any]) -> None:
@override
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
"""Restore orchestrator state from checkpoint.
Maintains backward compatibility with existing Magentic checkpoints
@@ -1112,7 +1120,7 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
manager_state = state.get("manager_state")
if manager_state is not None:
try:
self._manager.restore_state(manager_state)
self._manager.on_checkpoint_restore(manager_state)
except Exception as exc: # pragma: no cover
logger.warning("Failed to restore manager state: %s", exc)
@@ -1142,49 +1150,6 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
for name, description in expected.items():
restored[name] = description
def _snapshot_pattern_metadata(self) -> dict[str, Any]:
"""Serialize pattern-specific state.
Magentic uses custom snapshot_state() instead of base class hooks.
This method exists to satisfy the base class contract.
Returns:
Empty dict (Magentic manages its own state)
"""
return {}
def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None:
"""Restore pattern-specific state.
Magentic uses custom restore_state() instead of base class hooks.
This method exists to satisfy the base class contract.
Args:
metadata: Pattern-specific state dict (ignored)
"""
pass
async def _ensure_state_restored(
self,
context: WorkflowContext[Any, Any],
) -> None:
if self._state_restored and self._context is not None:
return
state = await context.get_executor_state()
if not state:
self._state_restored = True
return
if not isinstance(state, dict):
self._state_restored = True
return
try:
self.restore_state(state)
except Exception as exc: # pragma: no cover
logger.warning("Magentic Orchestrator: Failed to apply checkpoint state: %s", exc, exc_info=True)
raise
else:
self._state_restored = True
@handler
async def handle_start_message(
self,
@@ -1204,7 +1169,7 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
)
if message.messages:
self._context.chat_history.extend(message.messages)
self._state_restored = True
# Non-streaming callback for the orchestrator receipt of the task
await self._emit_orchestrator_message(context, message.task, ORCH_MSG_KIND_USER_TASK)
@@ -1269,7 +1234,7 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
"""Handle responses from agents."""
if getattr(self, "_terminated", False):
return
await self._ensure_state_restored(context)
if self._context is None:
raise RuntimeError("Magentic Orchestrator: Received response but not initialized")
@@ -1301,7 +1266,7 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
) -> None:
if getattr(self, "_terminated", False):
return
await self._ensure_state_restored(context)
if self._context is None:
return
@@ -1636,9 +1601,9 @@ class MagenticAgentExecutor(Executor):
self._agent = agent
self._agent_id = agent_id
self._chat_history: list[ChatMessage] = []
self._state_restored = False
def snapshot_state(self) -> dict[str, Any]:
@override
async def on_checkpoint_save(self) -> dict[str, Any]:
"""Capture current executor state for checkpointing.
Returns:
@@ -1650,7 +1615,8 @@ class MagenticAgentExecutor(Executor):
"chat_history": encode_chat_messages(self._chat_history),
}
def restore_state(self, state: dict[str, Any]) -> None:
@override
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
"""Restore executor state from checkpoint.
Args:
@@ -1668,24 +1634,6 @@ class MagenticAgentExecutor(Executor):
else:
self._chat_history = []
async def _ensure_state_restored(self, context: WorkflowContext[Any, Any]) -> None:
if self._state_restored and self._chat_history:
return
state = await context.get_executor_state()
if not state:
self._state_restored = True
return
if not isinstance(state, dict):
self._state_restored = True
return
try:
self.restore_state(state)
except Exception as exc: # pragma: no cover
logger.warning("Agent %s: Failed to apply checkpoint state: %s", self._agent_id, exc, exc_info=True)
raise
else:
self._state_restored = True
@handler
async def handle_response_message(
self, message: _MagenticResponseMessage, context: WorkflowContext[_MagenticResponseMessage]
@@ -1693,8 +1641,6 @@ class MagenticAgentExecutor(Executor):
"""Handle response message (task ledger broadcast)."""
logger.debug("Agent %s: Received response message", self._agent_id)
await self._ensure_state_restored(context)
# Check if this message is intended for this agent
if message.target_agent is not None and message.target_agent != self._agent_id and not message.broadcast:
# Message is targeted to a different agent, ignore it
@@ -1735,8 +1681,6 @@ class MagenticAgentExecutor(Executor):
logger.info("Agent %s: Received request to respond", self._agent_id)
await self._ensure_state_restored(context)
# Add persona adoption message with appropriate role
persona_role = self._get_persona_adoption_role()
persona_msg = ChatMessage(
@@ -1783,7 +1727,6 @@ class MagenticAgentExecutor(Executor):
"""Reset the internal chat history of the agent (internal operation)."""
logger.debug("Agent %s: Resetting chat history", self._agent_id)
self._chat_history.clear()
self._state_restored = True
async def _emit_agent_delta_event(
self,
@@ -11,7 +11,7 @@ from ._checkpoint_encoding import DATACLASS_MARKER, MODEL_MARKER, decode_checkpo
from ._const import EXECUTOR_STATE_KEY
from ._edge import EdgeGroup
from ._edge_runner import EdgeRunner, create_edge_runner
from ._events import WorkflowEvent
from ._events import SuperStepCompletedEvent, SuperStepStartedEvent, WorkflowEvent
from ._executor import Executor
from ._runner_context import (
Message,
@@ -92,6 +92,7 @@ class Runner:
while self._iteration < self._max_iterations:
logger.info(f"Starting superstep {self._iteration + 1}")
yield SuperStepStartedEvent(iteration=self._iteration + 1)
# Run iteration concurrently with live event streaming: we poll
# for new events while the iteration coroutine progresses.
@@ -126,6 +127,9 @@ class Runner:
# Create checkpoint after each superstep iteration
await self._create_checkpoint_if_enabled(f"superstep_{self._iteration}")
yield SuperStepCompletedEvent(iteration=self._iteration)
# Check for convergence: no more messages to process
if not await self._ctx.has_messages():
break
@@ -183,8 +187,8 @@ class Runner:
return None
try:
# Auto-snapshot executor states
await self._auto_snapshot_executor_states()
# Snapshot executor states
await self._save_executor_states()
checkpoint_category = "initial" if checkpoint_type == "after_initial_execution" else "superstep"
metadata = {
"superstep": self._iteration,
@@ -203,41 +207,6 @@ class Runner:
logger.warning(f"Failed to create {checkpoint_type} checkpoint: {e}")
return None
async def _auto_snapshot_executor_states(self) -> None:
"""Populate executor state by calling snapshot hooks on executors if available.
TODO(@taochen#1614): this method is potentially problematic if executors also call
set_executor_state on the context directly. We should clarify the intended usage
pattern for executor state management.
Convention:
- If an executor defines an async or sync method `snapshot_state(self) -> dict`, use it.
- Else if it has a plain attribute `state` that is a dict, use that.
Only JSON-serializable dicts should be provided by executors.
"""
for exec_id, executor in self._executors.items():
state_dict: dict[str, Any] | None = None
snapshot = getattr(executor, "snapshot_state", None)
try:
if callable(snapshot):
maybe = snapshot()
if asyncio.iscoroutine(maybe): # type: ignore[arg-type]
maybe = await maybe # type: ignore[assignment]
if isinstance(maybe, dict):
state_dict = maybe # type: ignore[assignment]
else:
state_attr = getattr(executor, "state", None)
if isinstance(state_attr, dict):
state_dict = state_attr # type: ignore[assignment]
except Exception as ex: # pragma: no cover
logger.debug(f"Executor {exec_id} snapshot_state failed: {ex}")
if state_dict is not None:
try:
await self._set_executor_state(exec_id, state_dict)
except Exception as ex: # pragma: no cover
logger.debug(f"Failed to persist state for executor {exec_id}: {ex}")
async def restore_from_checkpoint(
self,
checkpoint_id: str,
@@ -300,7 +269,65 @@ class Runner:
logger.error(f"Failed to restore from checkpoint {checkpoint_id}: {e}")
return False
async def _save_executor_states(self) -> None:
"""Populate executor state by calling checkpoint hooks on executors.
Backward compatibility behavior:
- If an executor defines an async or sync method `snapshot_state(self) -> dict`, use it.
- Else if it has a plain attribute `state` that is a dict, use that.
Updated behavior:
- Executors should implement `on_checkpoint_save(self) -> dict` to provide state.
This method will try the backward compatibility behavior first; if that does not yield state,
it falls back to the updated behavior.
Only JSON-serializable dicts should be provided by executors.
"""
for exec_id, executor in self._executors.items():
state_dict: dict[str, Any] | None = None
# Try backward compatibility behavior first
# TODO(@taochen): Remove backward compatibility
snapshot = getattr(executor, "snapshot_state", None)
try:
if callable(snapshot):
maybe = snapshot()
if asyncio.iscoroutine(maybe): # type: ignore[arg-type]
maybe = await maybe # type: ignore[assignment]
if isinstance(maybe, dict):
state_dict = maybe # type: ignore[assignment]
else:
state_attr = getattr(executor, "state", None)
if isinstance(state_attr, dict):
state_dict = state_attr # type: ignore[assignment]
except Exception as ex: # pragma: no cover
logger.debug(f"Executor {exec_id} snapshot_state failed: {ex}")
if state_dict is None:
# Try the updated behavior only if backward compatibility did not yield state
try:
state_dict = await executor.on_checkpoint_save()
except Exception as ex: # pragma: no cover
raise ValueError(f"Executor {exec_id} on_checkpoint_save failed: {ex}") from ex
try:
await self._set_executor_state(exec_id, state_dict)
except Exception as ex: # pragma: no cover
logger.debug(f"Failed to persist state for executor {exec_id}: {ex}")
async def _restore_executor_states(self) -> None:
"""Restore executor state by calling restore hooks on executors.
Backward compatibility behavior:
- If an executor defines an async or sync method `restore_state(self, state: dict)`, use it.
- Else, skip restoration for that executor.
Updated behavior:
- Executors should implement `on_checkpoint_restore(self, state: dict)` to restore state.
This method will try the backward compatibility behavior first; if that does not restore state,
it falls back to the updated behavior.
"""
has_executor_states = await self._shared_state.has(EXECUTOR_STATE_KEY)
if not has_executor_states:
return
@@ -309,16 +336,18 @@ class Runner:
if not isinstance(executor_states, dict):
raise ValueError("Executor states in shared state is not a dictionary. Unable to restore.")
for executor_id, state in executor_states.items():
for executor_id, state in executor_states.items(): # pyright: ignore[reportUnknownVariableType]
if not isinstance(executor_id, str):
raise ValueError("Executor ID in executor states is not a string. Unable to restore.")
if not isinstance(state, dict):
raise ValueError(f"Executor state for {executor_id} is not a dictionary. Unable to restore.")
if not isinstance(state, dict) or not all(isinstance(k, str) for k in state): # pyright: ignore[reportUnknownVariableType]
raise ValueError(f"Executor state for {executor_id} is not a dict[str, Any]. Unable to restore.")
executor = self._executors.get(executor_id)
if not executor:
raise ValueError(f"Executor {executor_id} not found during state restoration.")
# Try backward compatibility behavior first
# TODO(@taochen): Remove backward compatibility
restored = False
restore_method = getattr(executor, "restore_state", None)
try:
@@ -330,6 +359,14 @@ class Runner:
except Exception as ex: # pragma: no cover - defensive
raise ValueError(f"Executor {executor_id} restore_state failed: {ex}") from ex
if not restored:
# Try the updated behavior only if backward compatibility did not restore
try:
await executor.on_checkpoint_restore(state) # pyright: ignore[reportUnknownArgumentType]
restored = True
except Exception as ex: # pragma: no cover - defensive
raise ValueError(f"Executor {executor_id} on_checkpoint_restore failed: {ex}") from ex
if not restored:
logger.debug(f"Executor {executor_id} does not support state restoration; skipping.")
@@ -109,9 +109,9 @@ class Workflow(DictConvertible):
"""A graph-based execution engine that orchestrates connected executors.
## Overview
A workflow executes a directed graph of executors connected via edge groups using a Pregel-like model,
running in supersteps until the graph becomes idle. Workflows are created using the
WorkflowBuilder class - do not instantiate this class directly.
A workflow executes a directed graph of executors connected via edge groups using a
Pregel-like model, running in supersteps until the graph becomes idle. Workflows
are created using the WorkflowBuilder class - do not instantiate this class directly.
## Execution Model
Executors run in synchronized supersteps where each executor:
@@ -142,6 +142,10 @@ class Workflow(DictConvertible):
- HIL continuation: Provide `responses` to continue after RequestInfoExecutor requests
- Runtime checkpointing: Provide `checkpoint_storage` to enable/override checkpointing for this run
## State Management
Workflow instances contain states and states are preserved across calls to `run` and `run_stream`.
To execute multiple independent runs, create separate Workflow instances via WorkflowBuilder.
## External Input Requests
Executors within a workflow can request external input using `ctx.request_info()`:
1. Executor calls `ctx.request_info()` to request input
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Generic, Union, cast, get_args, get_origi
from opentelemetry.propagate import inject
from opentelemetry.trace import SpanKind
from typing_extensions import Never, TypeVar
from typing_extensions import Never, TypeVar, deprecated
from ..observability import OtelAttr, create_workflow_span
from ._const import EXECUTOR_STATE_KEY
@@ -410,6 +410,11 @@ class WorkflowContext(Generic[T_Out, T_W_Out]):
"""Get the shared state."""
return self._shared_state
@deprecated(
"Override `on_checkpoint_save()` methods instead. "
"For cross-executor state sharing, use set_shared_state() instead. "
"This API will be removed after 12/01/2025."
)
async def set_executor_state(self, state: dict[str, Any]) -> None:
"""Store executor state in shared state under a reserved key.
@@ -428,6 +433,11 @@ class WorkflowContext(Generic[T_Out, T_W_Out]):
existing_states[self._executor_id] = state
await self._shared_state.set(EXECUTOR_STATE_KEY, existing_states)
@deprecated(
"Override `on_checkpoint_restore()` methods instead. "
"For cross-executor state sharing, use get_shared_state() instead. "
"This API will be removed after 12/01/2025."
)
async def get_executor_state(self) -> dict[str, Any] | None:
"""Retrieve previously persisted state for this executor, if any."""
has_existing_states = await self._shared_state.has(EXECUTOR_STATE_KEY)
@@ -1,8 +1,8 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import contextlib
import logging
import sys
import uuid
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
@@ -26,6 +26,12 @@ from ._typing_utils import is_instance_of
from ._workflow import WorkflowRunResult
from ._workflow_context import WorkflowContext
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
logger = logging.getLogger(__name__)
@@ -181,8 +187,7 @@ class WorkflowExecutor(Executor):
# Includes all sub-workflow output types
# Plus SubWorkflowRequestMessage if sub-workflow can make requests
output_types = workflow.output_types + [SubWorkflowRequestMessage] # if applicable
```
output_types = workflow.output_types + [SubWorkflowRequestMessage] # if applicable
## Error Handling
WorkflowExecutor propagates sub-workflow failures:
@@ -221,23 +226,10 @@ class WorkflowExecutor(Executor):
### Important Considerations
**Shared Workflow Instance**: All concurrent executions use the same underlying workflow instance.
For proper isolation, ensure that:
- The wrapped workflow and its executors are stateless
- Executors use WorkflowContext state management instead of instance variables
- Any shared state is managed through WorkflowContext.get_shared_state/set_shared_state
For proper isolation, ensure that the wrapped workflow and its executors are stateless.
.. code-block:: python
# Good: Stateless executor using context state
class StatelessExecutor(Executor):
@handler
async def process(self, data: str, ctx: WorkflowContext[str]) -> None:
# Use context state instead of instance variables
state = await ctx.get_executor_state() or {}
state["processed"] = data
await ctx.set_executor_state(state)
# Avoid: Stateful executor with instance variables
class StatefulExecutor(Executor):
def __init__(self):
@@ -246,23 +238,23 @@ class WorkflowExecutor(Executor):
## Integration with Parent Workflows
Parent workflows can intercept sub-workflow requests:
```python
class ParentExecutor(Executor):
@handler
async def handle_subworkflow_request(
self,
request: SubWorkflowRequestMessage,
ctx: WorkflowContext[SubWorkflowResponseMessage],
) -> None:
# Handle request locally or forward to external source
if self.can_handle_locally(request):
# Send response back to sub-workflow
response = request.create_response(data="local response data")
await ctx.send_message(response, target_id=request.source_executor_id)
else:
# Forward to external handler
await ctx.request_info(request.source_event, response_type=request.source_event.response_type)
```
.. code-block:: python
class ParentExecutor(Executor):
@handler
async def handle_subworkflow_request(
self,
request: SubWorkflowRequestMessage,
ctx: WorkflowContext[SubWorkflowResponseMessage],
) -> None:
# Handle request locally or forward to external source
if self.can_handle_locally(request):
# Send response back to sub-workflow
response = request.create_response(data="local response data")
await ctx.send_message(response, target_id=request.source_executor_id)
else:
# Forward to external handler
await ctx.request_info(request.source_event, response_type=request.source_event.response_type)
## Implementation Notes
- Sub-workflows run to completion before processing their results
@@ -296,7 +288,6 @@ class WorkflowExecutor(Executor):
self._execution_contexts: dict[str, ExecutionContext] = {} # execution_id -> ExecutionContext
# Map request_id to execution_id for response routing
self._request_to_execution: dict[str, str] = {} # request_id -> execution_id
self._state_loaded: bool = False
@property
def input_types(self) -> list[type[Any]]:
@@ -362,8 +353,6 @@ class WorkflowExecutor(Executor):
input_data: The input data to send to the sub-workflow.
ctx: The workflow context from the parent.
"""
await self._ensure_state_loaded(ctx)
# Create execution context for this sub-workflow run
execution_id = str(uuid.uuid4())
execution_context = ExecutionContext(
@@ -405,8 +394,6 @@ class WorkflowExecutor(Executor):
response: The response to a previous request.
ctx: The workflow context.
"""
await self._ensure_state_loaded(ctx)
# Find the execution context for this request
original_request = response.source_event
execution_id = self._request_to_execution.get(original_request.request_id)
@@ -434,8 +421,6 @@ class WorkflowExecutor(Executor):
# Accumulate the response in this execution's context
execution_context.collected_responses[original_request.request_id] = response.data
await self._persist_execution_state(ctx)
# Check if we have all expected responses for this execution
if len(execution_context.collected_responses) < execution_context.expected_response_count:
logger.debug(
@@ -459,25 +444,20 @@ class WorkflowExecutor(Executor):
if not execution_context.pending_requests:
del self._execution_contexts[execution_id]
async def _ensure_state_loaded(self, ctx: WorkflowContext[Any]) -> None:
if self._state_loaded:
return
@override
async def on_checkpoint_save(self) -> dict[str, Any]:
"""Get the current state of the WorkflowExecutor for checkpointing purposes."""
return {
"execution_contexts": {
execution_id: encode_checkpoint_value(execution_context)
for execution_id, execution_context in self._execution_contexts.items()
},
"request_to_execution": dict(self._request_to_execution),
}
state: dict[str, Any] | None = None
try:
state = await ctx.get_executor_state()
except Exception:
state = None
if isinstance(state, dict) and state:
with contextlib.suppress(Exception):
await self.restore_state(state)
self._state_loaded = True
else:
self._state_loaded = True
async def restore_state(self, state: dict[str, Any]) -> None:
"""Restore pending request bookkeeping from a checkpoint snapshot."""
@override
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
"""Restore the WorkflowExecutor state from a checkpoint snapshot."""
# Validate the state contains the right keys
if "execution_contexts" not in state:
raise KeyError("Missing 'execution_contexts' in WorkflowExecutor state.")
@@ -529,23 +509,6 @@ class WorkflowExecutor(Executor):
for event in request_info_events
])
self._state_loaded = True
async def _persist_execution_state(self, ctx: WorkflowContext) -> None:
"""Persist the state of the WorkflowExecutor for checkpointing purposes."""
state = {
"execution_contexts": {
execution_id: encode_checkpoint_value(execution_context)
for execution_id, execution_context in self._execution_contexts.items()
},
"request_to_execution": dict(self._request_to_execution),
}
try:
await ctx.set_executor_state(state)
except Exception as exc: # pragma: no cover - transport specific
logger.warning(f"WorkflowExecutor {self.id} failed to persist state: {exc}")
async def _process_workflow_result(
self,
result: WorkflowRunResult,
@@ -635,5 +598,3 @@ class WorkflowExecutor(Executor):
)
else:
raise RuntimeError(f"Unexpected workflow run state: {workflow_run_state}")
await self._persist_execution_state(ctx)
@@ -158,8 +158,8 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
assert thread_messages[1].text == "Initial response 1"
async def test_agent_executor_snapshot_and_restore_state_directly() -> None:
"""Test AgentExecutor's snapshot_state and restore_state methods directly."""
async def test_agent_executor_save_and_restore_state_directly() -> None:
"""Test AgentExecutor's on_checkpoint_save and on_checkpoint_restore methods directly."""
# Create agent with thread containing messages
agent = _CountingAgent(id="direct_test_agent", name="DirectTestAgent")
thread = AgentThread(message_store=ChatMessageStore())
@@ -182,7 +182,7 @@ async def test_agent_executor_snapshot_and_restore_state_directly() -> None:
executor._cache = list(cache_messages) # type: ignore[reportPrivateUsage]
# Snapshot the state
state = await executor.snapshot_state() # type: ignore[reportUnknownMemberType]
state = await executor.on_checkpoint_save()
# Verify snapshot contains both cache and thread
assert "cache" in state
@@ -206,7 +206,7 @@ async def test_agent_executor_snapshot_and_restore_state_directly() -> None:
assert len(initial_thread_msgs) == 0
# Restore state
await new_executor.restore_state(state) # type: ignore[reportUnknownMemberType]
await new_executor.on_checkpoint_restore(state)
# Verify cache is restored
restored_cache = new_executor._cache # type: ignore[reportPrivateUsage]
@@ -288,57 +288,6 @@ def test_build_fails_without_participants():
HandoffBuilder().build()
async def test_multiple_runs_dont_leak_conversation():
"""Verify that running the same workflow multiple times doesn't leak conversation history."""
triage = _RecordingAgent(name="triage", handoff_to="specialist")
specialist = _RecordingAgent(name="specialist")
workflow = (
HandoffBuilder(participants=[triage, specialist])
.set_coordinator("triage")
.with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 2)
.build()
)
# First run
events = await _drain(workflow.run_stream("First run message"))
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
assert requests
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Second message"}))
outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
assert outputs, "First run should emit output"
first_run_conversation = outputs[-1].data
assert isinstance(first_run_conversation, list)
first_run_conv_list = cast(list[ChatMessage], first_run_conversation)
first_run_user_messages = [msg for msg in first_run_conv_list if msg.role == Role.USER]
assert len(first_run_user_messages) == 2
assert any("First run message" in msg.text for msg in first_run_user_messages if msg.text)
# Second run - should start fresh, not include first run's messages
triage.calls.clear()
specialist.calls.clear()
events = await _drain(workflow.run_stream("Second run different message"))
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
assert requests
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Another message"}))
outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
assert outputs, "Second run should emit output"
second_run_conversation = outputs[-1].data
assert isinstance(second_run_conversation, list)
second_run_conv_list = cast(list[ChatMessage], second_run_conversation)
second_run_user_messages = [msg for msg in second_run_conv_list if msg.role == Role.USER]
assert len(second_run_user_messages) == 2, (
"Second run should have exactly 2 user messages, not accumulate first run"
)
assert any("Second run different message" in msg.text for msg in second_run_user_messages if msg.text)
assert not any("First run message" in msg.text for msg in second_run_user_messages if msg.text), (
"Second run should NOT contain first run's messages"
)
async def test_handoff_async_termination_condition() -> None:
"""Test that async termination conditions work correctly."""
termination_call_count = 0
@@ -585,7 +534,7 @@ async def test_return_to_previous_state_serialization():
coordinator._current_agent_id = "specialist_a" # type: ignore[reportPrivateUsage]
# Snapshot the state
state = coordinator.snapshot_state()
state = await coordinator.on_checkpoint_save()
# Verify pattern metadata includes current_agent_id
assert "metadata" in state
@@ -603,7 +552,7 @@ async def test_return_to_previous_state_serialization():
)
# Restore state
coordinator2.restore_state(state)
await coordinator2.on_checkpoint_restore(state)
# Verify current_agent_id was restored
assert coordinator2._current_agent_id == "specialist_a", "Current agent should be restored from checkpoint" # type: ignore[reportPrivateUsage]
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.
import sys
from collections.abc import AsyncIterable
from dataclasses import dataclass
from typing import Any, cast
@@ -42,6 +43,11 @@ from agent_framework._workflows._magentic import ( # type: ignore[reportPrivate
_MagenticStartMessage, # type: ignore
)
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
def test_magentic_start_message_from_string():
msg = _MagenticStartMessage.from_string("Do the thing")
@@ -101,8 +107,9 @@ class FakeManager(MagenticManagerBase):
next_speaker_name: str = "agentA"
instruction_text: str = "Proceed with step 1"
def snapshot_state(self) -> dict[str, Any]:
state = super().snapshot_state()
@override
def on_checkpoint_save(self) -> dict[str, Any]:
state = super().on_checkpoint_save()
if self.task_ledger is not None:
state = dict(state)
state["task_ledger"] = {
@@ -111,8 +118,9 @@ class FakeManager(MagenticManagerBase):
}
return state
def restore_state(self, state: dict[str, Any]) -> None:
super().restore_state(state)
@override
def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
super().on_checkpoint_restore(state)
ledger_state = state.get("task_ledger")
if isinstance(ledger_state, dict):
ledger_dict = cast(dict[str, Any], ledger_state)
@@ -185,7 +193,6 @@ async def test_standard_manager_progress_ledger_and_fallback():
assert ledger2.is_request_satisfied.answer is False
@pytest.mark.skip(reason="Response handling refactored - responses no longer passed to run_stream()")
async def test_magentic_workflow_plan_review_approval_to_completion():
manager = FakeManager(max_round_count=10)
wf = (
@@ -204,7 +211,7 @@ async def test_magentic_workflow_plan_review_approval_to_completion():
completed = False
output: ChatMessage | None = None
async for ev in wf.run_stream(
async for ev in wf.send_responses_streaming(
responses={req_event.request_id: MagenticPlanReviewReply(decision=MagenticPlanReviewDecision.APPROVE)}
):
if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE:
@@ -218,7 +225,6 @@ async def test_magentic_workflow_plan_review_approval_to_completion():
assert isinstance(output, ChatMessage)
@pytest.mark.skip(reason="Response handling refactored - responses no longer passed to run_stream()")
async def test_magentic_plan_review_approve_with_comments_replans_and_proceeds():
class CountingManager(FakeManager):
# Declare as a model field so assignment is allowed under Pydantic
@@ -250,7 +256,7 @@ async def test_magentic_plan_review_approve_with_comments_replans_and_proceeds()
# Reply APPROVE with comments (no edited text). Expect one replan and no second review round.
saw_second_review = False
completed = False
async for ev in wf.run_stream(
async for ev in wf.send_responses_streaming(
responses={
req_event.request_id: MagenticPlanReviewReply(
decision=MagenticPlanReviewDecision.APPROVE,
@@ -298,7 +304,6 @@ async def test_magentic_orchestrator_round_limit_produces_partial_result():
assert data.role == Role.ASSISTANT
@pytest.mark.skip(reason="Response handling refactored - send_responses_streaming no longer exists")
async def test_magentic_checkpoint_resume_round_trip():
storage = InMemoryCheckpointStorage()
@@ -369,7 +374,7 @@ class _DummyExec(Executor):
pass
def test_magentic_agent_executor_snapshot_roundtrip():
async def test_magentic_agent_executor_on_checkpoint_save_and_restore_roundtrip():
backing_executor = _DummyExec("backing")
agent_exec = MagenticAgentExecutor(backing_executor, "agentA")
agent_exec._chat_history.extend([ # type: ignore[reportPrivateUsage]
@@ -377,10 +382,10 @@ def test_magentic_agent_executor_snapshot_roundtrip():
ChatMessage(role=Role.ASSISTANT, text="world", author_name="agentA"),
])
state = agent_exec.snapshot_state()
state = await agent_exec.on_checkpoint_save()
restored_executor = MagenticAgentExecutor(_DummyExec("backing2"), "agentA")
restored_executor.restore_state(state)
await restored_executor.on_checkpoint_restore(state)
assert len(restored_executor._chat_history) == 2 # type: ignore[reportPrivateUsage]
assert restored_executor._chat_history[0].text == "hello" # type: ignore[reportPrivateUsage]
@@ -199,7 +199,10 @@ async def test_fan_out():
# Each executor will emit two events: ExecutorInvokedEvent and ExecutorCompletedEvent
# executor_b will also emit a WorkflowOutputEvent (no WorkflowCompletedEvent anymore)
assert len(events) == 7
# Each superstep will emit also emit a WorkflowStartedEvent and WorkflowCompletedEvent
# This workflow will converge in 2 supersteps because executor_c will send one more message
# after executor_b completes
assert len(events) == 11
assert events.get_final_state() == WorkflowRunState.IDLE
outputs = events.get_outputs()
@@ -220,7 +223,9 @@ async def test_fan_out_multiple_completed_events():
# Each executor will emit two events: ExecutorInvokedEvent and ExecutorCompletedEvent
# executor_b and executor_c will also emit a WorkflowOutputEvent (no WorkflowCompletedEvent anymore)
assert len(events) == 8
# Each superstep will emit also emit a WorkflowStartedEvent and WorkflowCompletedEvent
# This workflow will converge in 1 superstep because executor_a and executor_b will not send further messages
assert len(events) == 10
# Multiple outputs are expected from both executors
outputs = events.get_outputs()
@@ -246,7 +251,8 @@ async def test_fan_in():
# Each executor will emit two events: ExecutorInvokedEvent and ExecutorCompletedEvent
# aggregator will also emit a WorkflowOutputEvent (no WorkflowCompletedEvent anymore)
assert len(events) == 9
# Each superstep will emit also emit a WorkflowStartedEvent and WorkflowCompletedEvent
assert len(events) == 13
assert events.get_final_state() == WorkflowRunState.IDLE
outputs = events.get_outputs()
@@ -37,10 +37,14 @@ class WorkflowHILRequest:
class WorkflowTestExecutor(Executor):
"""Test executor with HIL."""
def __init__(self, id: str) -> None:
super().__init__(id=id)
self._data_value: str | None = None
@handler
async def process(self, data: WorkflowTestData, ctx: WorkflowContext) -> None:
"""Process data and request approval."""
await ctx.set_executor_state({"data_value": data.value})
self._data_value = data.value
# Request HIL (checkpoint created here)
await ctx.request_info(request_data=WorkflowHILRequest(question=f"Approve {data.value}?"), response_type=str)
@@ -50,8 +54,7 @@ class WorkflowTestExecutor(Executor):
self, original_request: WorkflowHILRequest, response: str, ctx: WorkflowContext[str]
) -> None:
"""Handle HIL response."""
state = await ctx.get_executor_state() or {}
value = state.get("data_value", "")
value = self._data_value or ""
await ctx.send_message(f"{value}_approved" if response.lower() == "yes" else f"{value}_rejected")
@@ -17,7 +17,7 @@ Workflow Steps:
import asyncio
import logging
from dataclasses import dataclass
from typing import Literal, Annotated
from typing import Literal
from agent_framework import (
Case,
@@ -31,9 +31,11 @@ from agent_framework import (
from pydantic import BaseModel, Field
from typing_extensions import Never
# Define response model with clear user guidance
class SpamDecision(BaseModel):
"""User's decision on whether the email is spam."""
decision: Literal["spam", "not spam"] = Field(
description="Enter 'spam' to mark as spam, or 'not spam' to mark as legitimate"
)
@@ -71,10 +73,11 @@ class SpamDetectorResponse:
class SpamApprovalRequest:
"""Human-in-the-loop approval request for spam classification."""
email_message: str = ""
detected_as_spam: bool = False
confidence: float = 0.0
reasons: str = ""
email_message: str
detected_as_spam: bool
confidence: float
reasons: list[str]
full_email_content: EmailContent
@dataclass
@@ -128,8 +131,6 @@ class EmailPreprocessor(Executor):
await ctx.send_message(result)
class SpamDetector(Executor):
"""Step 2: An executor that analyzes content and determines if a message is spam."""
@@ -139,7 +140,9 @@ class SpamDetector(Executor):
self._spam_keywords = spam_keywords
@handler
async def handle_email_content(self, email_content: EmailContent, ctx: WorkflowContext[SpamApprovalRequest]) -> None:
async def handle_email_content(
self, email_content: EmailContent, ctx: WorkflowContext[SpamApprovalRequest]
) -> None:
"""Analyze email content and determine if the message is spam, then request human approval."""
await asyncio.sleep(2.0) # Simulate analysis and detection time
@@ -186,25 +189,13 @@ class SpamDetector(Executor):
is_spam = spam_score >= 0.5
# Store detection result in executor state for later use
# Store minimal data needed (not complex objects that don't serialize well)
await ctx.set_executor_state({
"original_message": email_content.original_message,
"cleaned_message": email_content.cleaned_message,
"word_count": email_content.word_count,
"has_suspicious_patterns": email_content.has_suspicious_patterns,
"is_spam": is_spam,
"ai_original_classification": is_spam, # Store original AI decision
"confidence_score": spam_score,
"spam_reasons": spam_reasons
})
# Request human approval before proceeding using new API
approval_request = SpamApprovalRequest(
email_message=email_text[:200], # First 200 chars
detected_as_spam=is_spam,
confidence=spam_score,
reasons=", ".join(spam_reasons) if spam_reasons else "no specific reasons"
reasons=spam_reasons,
full_email_content=email_content,
)
await ctx.request_info(
@@ -214,20 +205,15 @@ class SpamDetector(Executor):
@response_handler
async def handle_human_response(
self,
original_request: SpamApprovalRequest,
response: SpamDecision,
ctx: WorkflowContext[SpamDetectorResponse]
self, original_request: SpamApprovalRequest, response: SpamDecision, ctx: WorkflowContext[SpamDetectorResponse]
) -> None:
"""Process human approval response and continue workflow."""
print(f"[SpamDetector] handle_human_response called with response: {response}")
# Get stored detection result
state = await ctx.get_executor_state() or {}
print(f"[SpamDetector] Retrieved state: {state}")
ai_original = state.get("ai_original_classification", False)
confidence_score = state.get("confidence_score", 0.0)
spam_reasons = state.get("spam_reasons", [])
ai_original = original_request.detected_as_spam
confidence_score = original_request.confidence
spam_reasons = original_request.reasons
# Parse human decision from the response model
human_decision = response.decision.strip().lower()
@@ -241,27 +227,21 @@ class SpamDetector(Executor):
# Default to AI decision if unclear
is_spam = ai_original
# Reconstruct EmailContent from stored primitives
email_content = EmailContent(
original_message=state.get("original_message", ""),
cleaned_message=state.get("cleaned_message", ""),
word_count=state.get("word_count", 0),
has_suspicious_patterns=state.get("has_suspicious_patterns", False)
)
result = SpamDetectorResponse(
email_content=email_content,
email_content=original_request.full_email_content,
is_spam=is_spam,
confidence_score=confidence_score,
spam_reasons=spam_reasons,
human_reviewed=True,
human_decision=response.decision,
ai_original_classification=ai_original
ai_original_classification=ai_original,
)
print(f"[SpamDetector] Sending SpamDetectorResponse: is_spam={is_spam}, confidence={confidence_score}, human_reviewed=True")
print(
f"[SpamDetector] Sending SpamDetectorResponse: is_spam={is_spam}, confidence={confidence_score}, human_reviewed=True"
)
await ctx.send_message(result)
print(f"[SpamDetector] Message sent successfully")
print("[SpamDetector] Message sent successfully")
class SpamHandler(Executor):
@@ -427,7 +407,9 @@ workflow = (
spam_detector,
[
Case(condition=lambda x: isinstance(x, SpamDetectorResponse) and x.is_spam, target=spam_handler),
Default(target=legitimate_message_handler), # Default handles non-spam and non-SpamDetectorResponse messages
Default(
target=legitimate_message_handler
), # Default handles non-spam and non-SpamDetectorResponse messages
],
)
.add_edge(spam_handler, final_processor)
@@ -3,6 +3,7 @@
import asyncio
from dataclasses import dataclass
from pathlib import Path
from typing import Any, override
# NOTE: the Azure client imports above are real dependencies. When running this
# sample outside of Azure-enabled environments you may wish to swap in the
@@ -116,19 +117,19 @@ class ReviewGateway(Executor):
def __init__(self, id: str, writer_id: str) -> None:
super().__init__(id=id)
self._writer_id = writer_id
self._iteration = 0
@handler
async def on_agent_response(self, response: AgentExecutorResponse, ctx: WorkflowContext) -> None:
# Capture the agent output so we can surface it to the reviewer and persist iterations.
draft = response.agent_run_response.text or ""
iteration = int((await ctx.get_executor_state() or {}).get("iteration", 0)) + 1
await ctx.set_executor_state({"iteration": iteration, "last_draft": draft})
self._iteration += 1
# Emit a human approval request.
await ctx.request_info(
request_data=HumanApprovalRequest(
prompt="Review the draft. Reply 'approve' or provide edit instructions.",
draft=draft,
iteration=iteration,
draft=response.agent_run_response.text,
iteration=self._iteration,
),
response_type=str,
)
@@ -142,28 +143,33 @@ class ReviewGateway(Executor):
) -> None:
# The `original_request` is the request we sent earlier that is now being answered.
reply = feedback.strip()
state = await ctx.get_executor_state() or {}
draft = state.get("last_draft") or (original_request.draft or "")
if reply.lower() == "approve":
if len(reply) == 0 or reply.lower() == "approve":
# Workflow is completed when the human approves.
await ctx.yield_output(draft)
await ctx.yield_output(original_request.draft)
return
# Any other response loops us back to the writer with fresh guidance.
guidance = reply or "Tighten the copy and emphasise customer benefit."
iteration = int(state.get("iteration", 1)) + 1
await ctx.set_executor_state({"iteration": iteration, "last_draft": draft})
prompt = (
"Revise the launch note. Respond with the new copy only.\n\n"
f"Previous draft:\n{draft}\n\n"
f"Human guidance: {guidance}"
f"Previous draft:\n{original_request.draft}\n\n"
f"Human guidance: {reply}"
)
await ctx.send_message(
AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=prompt)], should_respond=True),
target_id=self._writer_id,
)
@override
async def on_checkpoint_save(self) -> dict[str, Any]:
# Save the current iteration count in executor state for checkpointing.
return {"iteration": self._iteration}
@override
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
# Restore the iteration count from executor state during checkpoint recovery.
self._iteration = state.get("iteration", 0)
def create_workflow(checkpoint_storage: FileCheckpointStorage) -> Workflow:
"""Assemble the workflow graph used by both the initial run and resume."""
@@ -247,10 +253,10 @@ async def run_interactive_session(
else:
if initial_message:
print(f"\nStarting workflow with brief: {initial_message}\n")
event_stream = workflow.run_stream(initial_message)
event_stream = workflow.run_stream(message=initial_message)
elif checkpoint_id:
print("\nStarting workflow from checkpoint...\n")
event_stream = workflow.run_stream(checkpoint_id)
event_stream = workflow.run_stream(checkpoint_id=checkpoint_id)
else:
raise ValueError("Either initial_message or checkpoint_id must be provided")
@@ -1,322 +1,157 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any
from agent_framework import (
AgentExecutor,
AgentExecutorRequest,
AgentExecutorResponse,
ChatMessage,
Executor,
FileCheckpointStorage,
Role,
WorkflowBuilder,
WorkflowContext,
get_checkpoint_summary,
handler,
)
from agent_framework.azure import AzureOpenAIChatClient
from azure.identity import AzureCliCredential
if TYPE_CHECKING:
from agent_framework import Workflow
from agent_framework._workflows._checkpoint import WorkflowCheckpoint
"""
Sample: Checkpointing and Resuming a Workflow (with an Agent stage)
Sample: Checkpointing and Resuming a Workflow
Purpose:
This sample shows how to enable checkpointing at superstep boundaries, persist both
executor-local state and shared workflow state, and then resume execution from a specific
checkpoint. The workflow demonstrates a simple text-processing pipeline that includes
an LLM-backed AgentExecutor stage.
Pipeline:
1) UpperCaseExecutor converts input to uppercase and records state.
2) ReverseTextExecutor reverses the string.
3) SubmitToLowerAgent prepares an AgentExecutorRequest for the lowercasing agent.
4) lower_agent (AgentExecutor) converts text to lowercase via Azure OpenAI.
5) FinalizeFromAgent yields the final result.
This sample shows how to enable checkpointing for a long-running workflow
that can be paused and resumed.
What you learn:
- How to persist executor state using ctx.get_executor_state and ctx.set_executor_state.
- How to persist shared workflow state using ctx.set_shared_state for cross-executor visibility.
- How to configure FileCheckpointStorage and call with_checkpointing on WorkflowBuilder.
- How to list and inspect checkpoints programmatically.
- How to interactively choose a checkpoint to resume from (instead of always resuming
from the most recent or a hard-coded one) using run_stream.
- How workflows complete by yielding outputs when idle, not via explicit completion events.
- How to configure checkpointing storage (InMemoryCheckpointStorage for testing)
- How to resume a workflow from a checkpoint after interruption
- How to implement executor state management with checkpoint hooks
- How to handle workflow interruptions and automatic recovery
Pipeline:
This sample shows a workflow that computes factor pairs for numbers up to a given limit:
1) A start executor that receives the upper limit and creates the initial task
2) A worker executor that processes each number to find its factor pairs
3) The worker uses checkpoint hooks to save/restore its internal state
Prerequisites:
- Azure AI or Azure OpenAI available for AzureOpenAIChatClient.
- Authentication with azure-identity via AzureCliCredential. Run az login locally.
- Filesystem access for writing JSON checkpoint files in a temp directory.
- Basic understanding of workflow concepts, including executors, edges, events, etc.
"""
# Define the temporary directory for storing checkpoints.
# These files allow the workflow to be resumed later.
DIR = os.path.dirname(__file__)
TEMP_DIR = os.path.join(DIR, "tmp", "checkpoints")
os.makedirs(TEMP_DIR, exist_ok=True)
import asyncio
from dataclasses import dataclass
from random import random
from typing import Any, override
from agent_framework import (
Executor,
InMemoryCheckpointStorage,
SuperStepCompletedEvent,
WorkflowBuilder,
WorkflowCheckpoint,
WorkflowContext,
WorkflowOutputEvent,
handler,
)
class UpperCaseExecutor(Executor):
"""Uppercases the input text and persists both local and shared state."""
@dataclass
class ComputeTask:
"""Task containing the list of numbers remaining to be processed."""
remaining_numbers: list[int]
class StartExecutor(Executor):
"""Initiates the workflow by providing the upper limit for factor pair computation."""
@handler
async def to_upper_case(self, text: str, ctx: WorkflowContext[str]) -> None:
result = text.upper()
print(f"UpperCaseExecutor: '{text}' -> '{result}'")
# Persist executor-local state so it is captured in checkpoints
# and available after resume for observability or logic.
prev = await ctx.get_executor_state() or {}
count = int(prev.get("count", 0)) + 1
await ctx.set_executor_state({
"count": count,
"last_input": text,
"last_output": result,
})
# Write to shared_state so downstream executors and any resumed runs can read it.
await ctx.set_shared_state("original_input", text)
await ctx.set_shared_state("upper_output", result)
# Send transformed text to the next executor.
await ctx.send_message(result)
async def start(self, upper_limit: int, ctx: WorkflowContext[ComputeTask]) -> None:
"""Start the workflow with a list of numbers to process."""
print(f"StartExecutor: Starting factor pair computation up to {upper_limit}")
await ctx.send_message(ComputeTask(remaining_numbers=list(range(1, upper_limit + 1))))
class SubmitToLowerAgent(Executor):
"""Builds an AgentExecutorRequest to send to the lowercasing agent while keeping shared-state visibility."""
class WorkerExecutor(Executor):
"""Processes numbers to compute their factor pairs and manages executor state for checkpointing."""
def __init__(self, id: str, agent_id: str):
def __init__(self, id: str) -> None:
super().__init__(id=id)
self._agent_id = agent_id
self._composite_number_pairs: dict[int, list[tuple[int, int]]] = {}
@handler
async def submit(self, text: str, ctx: WorkflowContext[AgentExecutorRequest]) -> None:
# Demonstrate reading shared_state written by UpperCaseExecutor.
# Shared state survives across checkpoints and is visible to all executors.
orig = await ctx.get_shared_state("original_input")
upper = await ctx.get_shared_state("upper_output")
print(f"LowerAgent (shared_state): original_input='{orig}', upper_output='{upper}'")
async def compute(
self,
task: ComputeTask,
ctx: WorkflowContext[ComputeTask, dict[int, list[tuple[int, int]]]],
) -> None:
"""Process the next number in the task, computing its factor pairs."""
next_number = task.remaining_numbers.pop(0)
# Build a minimal, deterministic prompt for the AgentExecutor.
prompt = f"Convert the following text to lowercase. Return ONLY the transformed text.\n\nText: {text}"
print(f"WorkerExecutor: Computing factor pairs for {next_number}")
pairs: list[tuple[int, int]] = []
for i in range(1, next_number):
if next_number % i == 0:
pairs.append((i, next_number // i))
self._composite_number_pairs[next_number] = pairs
# Send to the AgentExecutor. should_respond=True instructs the agent to produce a reply.
await ctx.send_message(
AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=prompt)], should_respond=True),
target_id=self._agent_id,
)
if not task.remaining_numbers:
# All numbers processed - output the results
await ctx.yield_output(self._composite_number_pairs)
else:
# More numbers to process - continue with remaining task
await ctx.send_message(task)
@override
async def on_checkpoint_save(self) -> dict[str, Any]:
"""Save the executor's internal state for checkpointing."""
return {"composite_number_pairs": self._composite_number_pairs}
class FinalizeFromAgent(Executor):
"""Consumes the AgentExecutorResponse and yields the final result."""
@handler
async def finalize(self, response: AgentExecutorResponse, ctx: WorkflowContext[Any, str]) -> None:
result = response.agent_run_response.text or ""
# Persist executor-local state for auditability when inspecting checkpoints.
prev = await ctx.get_executor_state() or {}
count = int(prev.get("count", 0)) + 1
await ctx.set_executor_state({
"count": count,
"last_output": result,
"final": True,
})
# Yield the final result so external consumers see the final value.
await ctx.yield_output(result)
class ReverseTextExecutor(Executor):
"""Reverses the input text and persists local state."""
@handler
async def reverse_text(self, text: str, ctx: WorkflowContext[str]) -> None:
result = text[::-1]
print(f"ReverseTextExecutor: '{text}' -> '{result}'")
# Persist executor-local state so checkpoint inspection can reveal progress.
prev = await ctx.get_executor_state() or {}
count = int(prev.get("count", 0)) + 1
await ctx.set_executor_state({
"count": count,
"last_input": text,
"last_output": result,
})
# Forward the reversed string to the next stage.
await ctx.send_message(result)
def create_workflow(checkpoint_storage: FileCheckpointStorage) -> "Workflow":
# Instantiate the pipeline executors.
upper_case_executor = UpperCaseExecutor(id="upper-case")
reverse_text_executor = ReverseTextExecutor(id="reverse-text")
# Configure the agent stage that lowercases the text.
chat_client = AzureOpenAIChatClient(credential=AzureCliCredential())
lower_agent = AgentExecutor(
chat_client.create_agent(
instructions=("You transform text to lowercase. Reply with ONLY the transformed text.")
),
id="lower_agent",
)
# Bridge to the agent and terminalization stage.
submit_lower = SubmitToLowerAgent(id="submit_lower", agent_id=lower_agent.id)
finalize = FinalizeFromAgent(id="finalize")
# Build the workflow with checkpointing enabled.
return (
WorkflowBuilder(max_iterations=5)
.add_edge(upper_case_executor, reverse_text_executor) # Uppercase -> Reverse
.add_edge(reverse_text_executor, submit_lower) # Reverse -> Build Agent request
.add_edge(submit_lower, lower_agent) # Submit to AgentExecutor
.add_edge(lower_agent, finalize) # Agent output -> Finalize
.set_start_executor(upper_case_executor) # Entry point
.with_checkpointing(checkpoint_storage=checkpoint_storage) # Enable persistence
.build()
)
def _render_checkpoint_summary(checkpoints: list["WorkflowCheckpoint"]) -> None:
"""Display human-friendly checkpoint metadata using framework summaries."""
if not checkpoints:
return
print("\nCheckpoint summary:")
for cp in sorted(checkpoints, key=lambda c: c.timestamp):
summary = get_checkpoint_summary(cp)
msg_count = sum(len(v) for v in cp.messages.values())
state_keys = sorted(summary.executor_ids)
orig = cp.shared_state.get("original_input")
upper = cp.shared_state.get("upper_output")
line = (
f"- {summary.checkpoint_id} | iter={summary.iteration_count} | messages={msg_count} | states={state_keys}"
)
if summary.status:
line += f" | status={summary.status}"
line += f" | shared_state: original_input='{orig}', upper_output='{upper}'"
print(line)
@override
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
"""Restore the executor's internal state from a checkpoint."""
self._composite_number_pairs = state.get("composite_number_pairs", {})
async def main():
# Clear existing checkpoints in this sample directory for a clean run.
checkpoint_dir = Path(TEMP_DIR)
for file in checkpoint_dir.glob("*.json"): # noqa: ASYNC240
file.unlink()
# Create workflow executors
start_executor = StartExecutor(id="start")
worker_executor = WorkerExecutor(id="worker")
# Backing store for checkpoints written by with_checkpointing.
checkpoint_storage = FileCheckpointStorage(storage_path=TEMP_DIR)
# Build workflow with checkpointing enabled
workflow_builder = (
WorkflowBuilder()
.set_start_executor(start_executor)
.add_edge(start_executor, worker_executor)
.add_edge(worker_executor, worker_executor) # Self-loop for iterative processing
)
checkpoint_storage = InMemoryCheckpointStorage()
workflow_builder = workflow_builder.with_checkpointing(checkpoint_storage=checkpoint_storage)
workflow = create_workflow(checkpoint_storage=checkpoint_storage)
# Run workflow with automatic checkpoint recovery
latest_checkpoint: WorkflowCheckpoint | None = None
while True:
workflow = workflow_builder.build()
# Run the full workflow once and observe events as they stream.
print("Running workflow with initial message...")
async for event in workflow.run_stream(message="hello world"):
print(f"Event: {event}")
# Start from checkpoint or fresh execution
print(f"\n** Workflow {workflow.id} started **")
event_stream = (
workflow.run_stream(message=10)
if latest_checkpoint is None
else workflow.run_stream(checkpoint_id=latest_checkpoint.checkpoint_id)
)
# Inspect checkpoints written during the run.
all_checkpoints = await checkpoint_storage.list_checkpoints()
if not all_checkpoints:
print("No checkpoints found!")
return
# All checkpoints created by this run share the same workflow_id.
workflow_id = all_checkpoints[0].workflow_id
_render_checkpoint_summary(all_checkpoints)
# Offer an interactive selection of checkpoints to resume from.
sorted_cps = sorted([cp for cp in all_checkpoints if cp.workflow_id == workflow_id], key=lambda c: c.timestamp)
print("\nAvailable checkpoints to resume from:")
for idx, cp in enumerate(sorted_cps):
summary = get_checkpoint_summary(cp)
line = f" [{idx}] id={summary.checkpoint_id} iter={summary.iteration_count}"
if summary.status:
line += f" status={summary.status}"
msg_count = sum(len(v) for v in cp.messages.values())
line += f" messages={msg_count}"
print(line)
user_input = input( # noqa: ASYNC250
"\nEnter checkpoint index (or paste checkpoint id) to resume from, or press Enter to skip resume: "
).strip()
if not user_input:
print("No checkpoint selected. Exiting without resuming.")
return
chosen_cp_id: str | None = None
# Try as index first
if user_input.isdigit():
idx = int(user_input)
if 0 <= idx < len(sorted_cps):
chosen_cp_id = sorted_cps[idx].checkpoint_id
# Fall back to direct id match
if chosen_cp_id is None:
for cp in sorted_cps:
if cp.checkpoint_id.startswith(user_input): # allow prefix match for convenience
chosen_cp_id = cp.checkpoint_id
output: str | None = None
async for event in event_stream:
if isinstance(event, WorkflowOutputEvent):
output = event.data
break
if isinstance(event, SuperStepCompletedEvent) and random() < 0.5:
# Randomly simulate system interruptions
# The `SuperStepCompletedEvent` ensures we only interrupt after
# the current super-step is fully complete and checkpointed.
# If we interrupt mid-step, the workflow may resume from an earlier point.
print("\n** Simulating workflow interruption. Stopping execution. **")
break
if chosen_cp_id is None:
print("Input did not match any checkpoint. Exiting without resuming.")
return
# Find the latest checkpoint to resume from
all_checkpoints = await checkpoint_storage.list_checkpoints()
if not all_checkpoints:
raise RuntimeError("No checkpoints available to resume from.")
latest_checkpoint = all_checkpoints[-1]
print(
f"Checkpoint {latest_checkpoint.checkpoint_id}: "
f"(iter={latest_checkpoint.iteration_count}, messages={latest_checkpoint.messages})"
)
# You can reuse the same workflow graph definition and resume from a prior checkpoint.
# This second workflow instance does not enable checkpointing to show that resumption
# reads from stored state but need not write new checkpoints.
new_workflow = create_workflow(checkpoint_storage=checkpoint_storage)
print(f"\nResuming from checkpoint: {chosen_cp_id}")
async for event in new_workflow.run_stream(checkpoint_id=chosen_cp_id, checkpoint_storage=checkpoint_storage):
print(f"Resumed Event: {event}")
"""
Sample Output:
Running workflow with initial message...
UpperCaseExecutor: 'hello world' -> 'HELLO WORLD'
Event: ExecutorInvokeEvent(executor_id=upper_case_executor)
Event: ExecutorCompletedEvent(executor_id=upper_case_executor)
ReverseTextExecutor: 'HELLO WORLD' -> 'DLROW OLLEH'
Event: ExecutorInvokeEvent(executor_id=reverse_text_executor)
Event: ExecutorCompletedEvent(executor_id=reverse_text_executor)
LowerAgent (shared_state): original_input='hello world', upper_output='HELLO WORLD'
Event: ExecutorInvokeEvent(executor_id=submit_lower)
Event: ExecutorInvokeEvent(executor_id=lower_agent)
Event: ExecutorInvokeEvent(executor_id=finalize)
Checkpoint summary:
- dfc63e72-8e8d-454f-9b6d-0d740b9062e6 | label='after_initial_execution' | iter=0 | messages=1 | states=['upper_case_executor'] | shared_state: original_input='hello world', upper_output='HELLO WORLD'
- a78c345a-e5d9-45ba-82c0-cb725452d91b | label='superstep_1' | iter=1 | messages=1 | states=['reverse_text_executor', 'upper_case_executor'] | shared_state: original_input='hello world', upper_output='HELLO WORLD'
- 637c1dbd-a525-4404-9583-da03980537a2 | label='superstep_2' | iter=2 | messages=0 | states=['finalize', 'lower_agent', 'reverse_text_executor', 'submit_lower', 'upper_case_executor'] | shared_state: original_input='hello world', upper_output='HELLO WORLD'
Available checkpoints to resume from:
[0] id=dfc63e72-... iter=0 messages=1 label='after_initial_execution'
[1] id=a78c345a-... iter=1 messages=1 label='superstep_1'
[2] id=637c1dbd-... iter=2 messages=0 label='superstep_2'
Enter checkpoint index (or paste checkpoint id) to resume from, or press Enter to skip resume: 1
Resuming from checkpoint: a78c345a-e5d9-45ba-82c0-cb725452d91b
LowerAgent (shared_state): original_input='hello world', upper_output='HELLO WORLD'
Resumed Event: ExecutorInvokeEvent(executor_id=submit_lower)
Resumed Event: ExecutorInvokeEvent(executor_id=lower_agent)
Resumed Event: ExecutorInvokeEvent(executor_id=finalize)
""" # noqa: E501
if output is not None:
print(f"\nWorkflow completed successfully with output: {output}")
break
if __name__ == "__main__":
@@ -7,6 +7,7 @@ import uuid
from dataclasses import dataclass, field, replace
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, override
from agent_framework import (
Executor,
@@ -205,6 +206,8 @@ class LaunchCoordinator(Executor):
def __init__(self) -> None:
super().__init__(id="launch_coordinator")
# Track pending requests to match responses
self._pending_requests: dict[str, SubWorkflowRequestMessage] = {}
@handler
async def kick_off(self, topic: str, ctx: WorkflowContext[DraftTask]) -> None:
@@ -244,11 +247,9 @@ class LaunchCoordinator(Executor):
if not isinstance(request.source_event.data, ReviewRequest):
raise TypeError(f"Expected 'ReviewRequest', got {type(request.source_event.data)}")
# Record the request to response matching
# Record the request for response matching
review_request = request.source_event.data
executor_state = await ctx.get_executor_state() or {}
executor_state[review_request.id] = request
await ctx.set_executor_state(executor_state)
self._pending_requests[review_request.id] = request
# Send the request without modification
await ctx.request_info(request_data=review_request, response_type=str)
@@ -265,17 +266,25 @@ class LaunchCoordinator(Executor):
Note that the response must be sent back using SubWorkflowResponseMessage to route
the response back to the sub-workflow.
"""
executor_state = await ctx.get_executor_state() or {}
request_message = executor_state.pop(original_request.id, None)
# Save the executor state back to the context
await ctx.set_executor_state(executor_state)
request_message = self._pending_requests.pop(original_request.id, None)
if request_message is None:
raise ValueError("No matching pending request found for the resource response")
await ctx.send_message(request_message.create_response(response))
@override
async def on_checkpoint_save(self) -> dict[str, Any]:
"""Capture any additional state needed for checkpointing."""
return {
"pending_requests": self._pending_requests,
}
@override
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
"""Restore any additional state needed from checkpointing."""
self._pending_requests = state.get("pending_requests", {})
# ---------------------------------------------------------------------------
# Workflow construction helpers
@@ -356,9 +365,7 @@ async def main() -> None:
workflow2 = build_parent_workflow(storage)
request_info_event: RequestInfoEvent | None = None
async for event in workflow2.run_stream(
resume_checkpoint.checkpoint_id,
):
async for event in workflow2.run_stream(checkpoint_id=resume_checkpoint.checkpoint_id):
if isinstance(event, RequestInfoEvent):
request_info_event = event