Python: Add checkpoint save and restore hooks to executor (#2097)

* Add checkpoint hooks

* Deprecate get_executor_state and set_executor_state

* Fix tests and samples

* Add doc strings

* Add sample

* Fix import

* Address comments and fix tests

* Address comments

* conditional import
This commit is contained in:
Tao Chen
2025-11-17 10:19:01 -08:00
committed by GitHub
Unverified
parent 132597957a
commit c361ad8d33
22 changed files with 508 additions and 723 deletions
@@ -3,6 +3,7 @@
import asyncio
from dataclasses import dataclass
from pathlib import Path
from typing import Any, override
# NOTE: the Azure client imports above are real dependencies. When running this
# sample outside of Azure-enabled environments you may wish to swap in the
@@ -116,19 +117,19 @@ class ReviewGateway(Executor):
def __init__(self, id: str, writer_id: str) -> None:
super().__init__(id=id)
self._writer_id = writer_id
self._iteration = 0
@handler
async def on_agent_response(self, response: AgentExecutorResponse, ctx: WorkflowContext) -> None:
# Capture the agent output so we can surface it to the reviewer and persist iterations.
draft = response.agent_run_response.text or ""
iteration = int((await ctx.get_executor_state() or {}).get("iteration", 0)) + 1
await ctx.set_executor_state({"iteration": iteration, "last_draft": draft})
self._iteration += 1
# Emit a human approval request.
await ctx.request_info(
request_data=HumanApprovalRequest(
prompt="Review the draft. Reply 'approve' or provide edit instructions.",
draft=draft,
iteration=iteration,
draft=response.agent_run_response.text,
iteration=self._iteration,
),
response_type=str,
)
@@ -142,28 +143,33 @@ class ReviewGateway(Executor):
) -> None:
# The `original_request` is the request we sent earlier that is now being answered.
reply = feedback.strip()
state = await ctx.get_executor_state() or {}
draft = state.get("last_draft") or (original_request.draft or "")
if reply.lower() == "approve":
if len(reply) == 0 or reply.lower() == "approve":
# Workflow is completed when the human approves.
await ctx.yield_output(draft)
await ctx.yield_output(original_request.draft)
return
# Any other response loops us back to the writer with fresh guidance.
guidance = reply or "Tighten the copy and emphasise customer benefit."
iteration = int(state.get("iteration", 1)) + 1
await ctx.set_executor_state({"iteration": iteration, "last_draft": draft})
prompt = (
"Revise the launch note. Respond with the new copy only.\n\n"
f"Previous draft:\n{draft}\n\n"
f"Human guidance: {guidance}"
f"Previous draft:\n{original_request.draft}\n\n"
f"Human guidance: {reply}"
)
await ctx.send_message(
AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=prompt)], should_respond=True),
target_id=self._writer_id,
)
@override
async def on_checkpoint_save(self) -> dict[str, Any]:
# Save the current iteration count in executor state for checkpointing.
return {"iteration": self._iteration}
@override
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
# Restore the iteration count from executor state during checkpoint recovery.
self._iteration = state.get("iteration", 0)
def create_workflow(checkpoint_storage: FileCheckpointStorage) -> Workflow:
"""Assemble the workflow graph used by both the initial run and resume."""
@@ -247,10 +253,10 @@ async def run_interactive_session(
else:
if initial_message:
print(f"\nStarting workflow with brief: {initial_message}\n")
event_stream = workflow.run_stream(initial_message)
event_stream = workflow.run_stream(message=initial_message)
elif checkpoint_id:
print("\nStarting workflow from checkpoint...\n")
event_stream = workflow.run_stream(checkpoint_id)
event_stream = workflow.run_stream(checkpoint_id=checkpoint_id)
else:
raise ValueError("Either initial_message or checkpoint_id must be provided")
@@ -1,322 +1,157 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any
from agent_framework import (
AgentExecutor,
AgentExecutorRequest,
AgentExecutorResponse,
ChatMessage,
Executor,
FileCheckpointStorage,
Role,
WorkflowBuilder,
WorkflowContext,
get_checkpoint_summary,
handler,
)
from agent_framework.azure import AzureOpenAIChatClient
from azure.identity import AzureCliCredential
if TYPE_CHECKING:
from agent_framework import Workflow
from agent_framework._workflows._checkpoint import WorkflowCheckpoint
"""
Sample: Checkpointing and Resuming a Workflow (with an Agent stage)
Sample: Checkpointing and Resuming a Workflow
Purpose:
This sample shows how to enable checkpointing at superstep boundaries, persist both
executor-local state and shared workflow state, and then resume execution from a specific
checkpoint. The workflow demonstrates a simple text-processing pipeline that includes
an LLM-backed AgentExecutor stage.
Pipeline:
1) UpperCaseExecutor converts input to uppercase and records state.
2) ReverseTextExecutor reverses the string.
3) SubmitToLowerAgent prepares an AgentExecutorRequest for the lowercasing agent.
4) lower_agent (AgentExecutor) converts text to lowercase via Azure OpenAI.
5) FinalizeFromAgent yields the final result.
This sample shows how to enable checkpointing for a long-running workflow
that can be paused and resumed.
What you learn:
- How to persist executor state using ctx.get_executor_state and ctx.set_executor_state.
- How to persist shared workflow state using ctx.set_shared_state for cross-executor visibility.
- How to configure FileCheckpointStorage and call with_checkpointing on WorkflowBuilder.
- How to list and inspect checkpoints programmatically.
- How to interactively choose a checkpoint to resume from (instead of always resuming
from the most recent or a hard-coded one) using run_stream.
- How workflows complete by yielding outputs when idle, not via explicit completion events.
- How to configure checkpointing storage (InMemoryCheckpointStorage for testing)
- How to resume a workflow from a checkpoint after interruption
- How to implement executor state management with checkpoint hooks
- How to handle workflow interruptions and automatic recovery
Pipeline:
This sample shows a workflow that computes factor pairs for numbers up to a given limit:
1) A start executor that receives the upper limit and creates the initial task
2) A worker executor that processes each number to find its factor pairs
3) The worker uses checkpoint hooks to save/restore its internal state
Prerequisites:
- Azure AI or Azure OpenAI available for AzureOpenAIChatClient.
- Authentication with azure-identity via AzureCliCredential. Run az login locally.
- Filesystem access for writing JSON checkpoint files in a temp directory.
- Basic understanding of workflow concepts, including executors, edges, events, etc.
"""
# Define the temporary directory for storing checkpoints.
# These files allow the workflow to be resumed later.
DIR = os.path.dirname(__file__)
TEMP_DIR = os.path.join(DIR, "tmp", "checkpoints")
os.makedirs(TEMP_DIR, exist_ok=True)
import asyncio
from dataclasses import dataclass
from random import random
from typing import Any, override
from agent_framework import (
Executor,
InMemoryCheckpointStorage,
SuperStepCompletedEvent,
WorkflowBuilder,
WorkflowCheckpoint,
WorkflowContext,
WorkflowOutputEvent,
handler,
)
class UpperCaseExecutor(Executor):
"""Uppercases the input text and persists both local and shared state."""
@dataclass
class ComputeTask:
"""Task containing the list of numbers remaining to be processed."""
remaining_numbers: list[int]
class StartExecutor(Executor):
"""Initiates the workflow by providing the upper limit for factor pair computation."""
@handler
async def to_upper_case(self, text: str, ctx: WorkflowContext[str]) -> None:
result = text.upper()
print(f"UpperCaseExecutor: '{text}' -> '{result}'")
# Persist executor-local state so it is captured in checkpoints
# and available after resume for observability or logic.
prev = await ctx.get_executor_state() or {}
count = int(prev.get("count", 0)) + 1
await ctx.set_executor_state({
"count": count,
"last_input": text,
"last_output": result,
})
# Write to shared_state so downstream executors and any resumed runs can read it.
await ctx.set_shared_state("original_input", text)
await ctx.set_shared_state("upper_output", result)
# Send transformed text to the next executor.
await ctx.send_message(result)
async def start(self, upper_limit: int, ctx: WorkflowContext[ComputeTask]) -> None:
"""Start the workflow with a list of numbers to process."""
print(f"StartExecutor: Starting factor pair computation up to {upper_limit}")
await ctx.send_message(ComputeTask(remaining_numbers=list(range(1, upper_limit + 1))))
class SubmitToLowerAgent(Executor):
"""Builds an AgentExecutorRequest to send to the lowercasing agent while keeping shared-state visibility."""
class WorkerExecutor(Executor):
"""Processes numbers to compute their factor pairs and manages executor state for checkpointing."""
def __init__(self, id: str, agent_id: str):
def __init__(self, id: str) -> None:
super().__init__(id=id)
self._agent_id = agent_id
self._composite_number_pairs: dict[int, list[tuple[int, int]]] = {}
@handler
async def submit(self, text: str, ctx: WorkflowContext[AgentExecutorRequest]) -> None:
# Demonstrate reading shared_state written by UpperCaseExecutor.
# Shared state survives across checkpoints and is visible to all executors.
orig = await ctx.get_shared_state("original_input")
upper = await ctx.get_shared_state("upper_output")
print(f"LowerAgent (shared_state): original_input='{orig}', upper_output='{upper}'")
async def compute(
self,
task: ComputeTask,
ctx: WorkflowContext[ComputeTask, dict[int, list[tuple[int, int]]]],
) -> None:
"""Process the next number in the task, computing its factor pairs."""
next_number = task.remaining_numbers.pop(0)
# Build a minimal, deterministic prompt for the AgentExecutor.
prompt = f"Convert the following text to lowercase. Return ONLY the transformed text.\n\nText: {text}"
print(f"WorkerExecutor: Computing factor pairs for {next_number}")
pairs: list[tuple[int, int]] = []
for i in range(1, next_number):
if next_number % i == 0:
pairs.append((i, next_number // i))
self._composite_number_pairs[next_number] = pairs
# Send to the AgentExecutor. should_respond=True instructs the agent to produce a reply.
await ctx.send_message(
AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=prompt)], should_respond=True),
target_id=self._agent_id,
)
if not task.remaining_numbers:
# All numbers processed - output the results
await ctx.yield_output(self._composite_number_pairs)
else:
# More numbers to process - continue with remaining task
await ctx.send_message(task)
@override
async def on_checkpoint_save(self) -> dict[str, Any]:
"""Save the executor's internal state for checkpointing."""
return {"composite_number_pairs": self._composite_number_pairs}
class FinalizeFromAgent(Executor):
"""Consumes the AgentExecutorResponse and yields the final result."""
@handler
async def finalize(self, response: AgentExecutorResponse, ctx: WorkflowContext[Any, str]) -> None:
result = response.agent_run_response.text or ""
# Persist executor-local state for auditability when inspecting checkpoints.
prev = await ctx.get_executor_state() or {}
count = int(prev.get("count", 0)) + 1
await ctx.set_executor_state({
"count": count,
"last_output": result,
"final": True,
})
# Yield the final result so external consumers see the final value.
await ctx.yield_output(result)
class ReverseTextExecutor(Executor):
"""Reverses the input text and persists local state."""
@handler
async def reverse_text(self, text: str, ctx: WorkflowContext[str]) -> None:
result = text[::-1]
print(f"ReverseTextExecutor: '{text}' -> '{result}'")
# Persist executor-local state so checkpoint inspection can reveal progress.
prev = await ctx.get_executor_state() or {}
count = int(prev.get("count", 0)) + 1
await ctx.set_executor_state({
"count": count,
"last_input": text,
"last_output": result,
})
# Forward the reversed string to the next stage.
await ctx.send_message(result)
def create_workflow(checkpoint_storage: FileCheckpointStorage) -> "Workflow":
# Instantiate the pipeline executors.
upper_case_executor = UpperCaseExecutor(id="upper-case")
reverse_text_executor = ReverseTextExecutor(id="reverse-text")
# Configure the agent stage that lowercases the text.
chat_client = AzureOpenAIChatClient(credential=AzureCliCredential())
lower_agent = AgentExecutor(
chat_client.create_agent(
instructions=("You transform text to lowercase. Reply with ONLY the transformed text.")
),
id="lower_agent",
)
# Bridge to the agent and terminalization stage.
submit_lower = SubmitToLowerAgent(id="submit_lower", agent_id=lower_agent.id)
finalize = FinalizeFromAgent(id="finalize")
# Build the workflow with checkpointing enabled.
return (
WorkflowBuilder(max_iterations=5)
.add_edge(upper_case_executor, reverse_text_executor) # Uppercase -> Reverse
.add_edge(reverse_text_executor, submit_lower) # Reverse -> Build Agent request
.add_edge(submit_lower, lower_agent) # Submit to AgentExecutor
.add_edge(lower_agent, finalize) # Agent output -> Finalize
.set_start_executor(upper_case_executor) # Entry point
.with_checkpointing(checkpoint_storage=checkpoint_storage) # Enable persistence
.build()
)
def _render_checkpoint_summary(checkpoints: list["WorkflowCheckpoint"]) -> None:
"""Display human-friendly checkpoint metadata using framework summaries."""
if not checkpoints:
return
print("\nCheckpoint summary:")
for cp in sorted(checkpoints, key=lambda c: c.timestamp):
summary = get_checkpoint_summary(cp)
msg_count = sum(len(v) for v in cp.messages.values())
state_keys = sorted(summary.executor_ids)
orig = cp.shared_state.get("original_input")
upper = cp.shared_state.get("upper_output")
line = (
f"- {summary.checkpoint_id} | iter={summary.iteration_count} | messages={msg_count} | states={state_keys}"
)
if summary.status:
line += f" | status={summary.status}"
line += f" | shared_state: original_input='{orig}', upper_output='{upper}'"
print(line)
@override
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
"""Restore the executor's internal state from a checkpoint."""
self._composite_number_pairs = state.get("composite_number_pairs", {})
async def main():
# Clear existing checkpoints in this sample directory for a clean run.
checkpoint_dir = Path(TEMP_DIR)
for file in checkpoint_dir.glob("*.json"): # noqa: ASYNC240
file.unlink()
# Create workflow executors
start_executor = StartExecutor(id="start")
worker_executor = WorkerExecutor(id="worker")
# Backing store for checkpoints written by with_checkpointing.
checkpoint_storage = FileCheckpointStorage(storage_path=TEMP_DIR)
# Build workflow with checkpointing enabled
workflow_builder = (
WorkflowBuilder()
.set_start_executor(start_executor)
.add_edge(start_executor, worker_executor)
.add_edge(worker_executor, worker_executor) # Self-loop for iterative processing
)
checkpoint_storage = InMemoryCheckpointStorage()
workflow_builder = workflow_builder.with_checkpointing(checkpoint_storage=checkpoint_storage)
workflow = create_workflow(checkpoint_storage=checkpoint_storage)
# Run workflow with automatic checkpoint recovery
latest_checkpoint: WorkflowCheckpoint | None = None
while True:
workflow = workflow_builder.build()
# Run the full workflow once and observe events as they stream.
print("Running workflow with initial message...")
async for event in workflow.run_stream(message="hello world"):
print(f"Event: {event}")
# Start from checkpoint or fresh execution
print(f"\n** Workflow {workflow.id} started **")
event_stream = (
workflow.run_stream(message=10)
if latest_checkpoint is None
else workflow.run_stream(checkpoint_id=latest_checkpoint.checkpoint_id)
)
# Inspect checkpoints written during the run.
all_checkpoints = await checkpoint_storage.list_checkpoints()
if not all_checkpoints:
print("No checkpoints found!")
return
# All checkpoints created by this run share the same workflow_id.
workflow_id = all_checkpoints[0].workflow_id
_render_checkpoint_summary(all_checkpoints)
# Offer an interactive selection of checkpoints to resume from.
sorted_cps = sorted([cp for cp in all_checkpoints if cp.workflow_id == workflow_id], key=lambda c: c.timestamp)
print("\nAvailable checkpoints to resume from:")
for idx, cp in enumerate(sorted_cps):
summary = get_checkpoint_summary(cp)
line = f" [{idx}] id={summary.checkpoint_id} iter={summary.iteration_count}"
if summary.status:
line += f" status={summary.status}"
msg_count = sum(len(v) for v in cp.messages.values())
line += f" messages={msg_count}"
print(line)
user_input = input( # noqa: ASYNC250
"\nEnter checkpoint index (or paste checkpoint id) to resume from, or press Enter to skip resume: "
).strip()
if not user_input:
print("No checkpoint selected. Exiting without resuming.")
return
chosen_cp_id: str | None = None
# Try as index first
if user_input.isdigit():
idx = int(user_input)
if 0 <= idx < len(sorted_cps):
chosen_cp_id = sorted_cps[idx].checkpoint_id
# Fall back to direct id match
if chosen_cp_id is None:
for cp in sorted_cps:
if cp.checkpoint_id.startswith(user_input): # allow prefix match for convenience
chosen_cp_id = cp.checkpoint_id
output: str | None = None
async for event in event_stream:
if isinstance(event, WorkflowOutputEvent):
output = event.data
break
if isinstance(event, SuperStepCompletedEvent) and random() < 0.5:
# Randomly simulate system interruptions
# The `SuperStepCompletedEvent` ensures we only interrupt after
# the current super-step is fully complete and checkpointed.
# If we interrupt mid-step, the workflow may resume from an earlier point.
print("\n** Simulating workflow interruption. Stopping execution. **")
break
if chosen_cp_id is None:
print("Input did not match any checkpoint. Exiting without resuming.")
return
# Find the latest checkpoint to resume from
all_checkpoints = await checkpoint_storage.list_checkpoints()
if not all_checkpoints:
raise RuntimeError("No checkpoints available to resume from.")
latest_checkpoint = all_checkpoints[-1]
print(
f"Checkpoint {latest_checkpoint.checkpoint_id}: "
f"(iter={latest_checkpoint.iteration_count}, messages={latest_checkpoint.messages})"
)
# You can reuse the same workflow graph definition and resume from a prior checkpoint.
# This second workflow instance does not enable checkpointing to show that resumption
# reads from stored state but need not write new checkpoints.
new_workflow = create_workflow(checkpoint_storage=checkpoint_storage)
print(f"\nResuming from checkpoint: {chosen_cp_id}")
async for event in new_workflow.run_stream(checkpoint_id=chosen_cp_id, checkpoint_storage=checkpoint_storage):
print(f"Resumed Event: {event}")
"""
Sample Output:
Running workflow with initial message...
UpperCaseExecutor: 'hello world' -> 'HELLO WORLD'
Event: ExecutorInvokeEvent(executor_id=upper_case_executor)
Event: ExecutorCompletedEvent(executor_id=upper_case_executor)
ReverseTextExecutor: 'HELLO WORLD' -> 'DLROW OLLEH'
Event: ExecutorInvokeEvent(executor_id=reverse_text_executor)
Event: ExecutorCompletedEvent(executor_id=reverse_text_executor)
LowerAgent (shared_state): original_input='hello world', upper_output='HELLO WORLD'
Event: ExecutorInvokeEvent(executor_id=submit_lower)
Event: ExecutorInvokeEvent(executor_id=lower_agent)
Event: ExecutorInvokeEvent(executor_id=finalize)
Checkpoint summary:
- dfc63e72-8e8d-454f-9b6d-0d740b9062e6 | label='after_initial_execution' | iter=0 | messages=1 | states=['upper_case_executor'] | shared_state: original_input='hello world', upper_output='HELLO WORLD'
- a78c345a-e5d9-45ba-82c0-cb725452d91b | label='superstep_1' | iter=1 | messages=1 | states=['reverse_text_executor', 'upper_case_executor'] | shared_state: original_input='hello world', upper_output='HELLO WORLD'
- 637c1dbd-a525-4404-9583-da03980537a2 | label='superstep_2' | iter=2 | messages=0 | states=['finalize', 'lower_agent', 'reverse_text_executor', 'submit_lower', 'upper_case_executor'] | shared_state: original_input='hello world', upper_output='HELLO WORLD'
Available checkpoints to resume from:
[0] id=dfc63e72-... iter=0 messages=1 label='after_initial_execution'
[1] id=a78c345a-... iter=1 messages=1 label='superstep_1'
[2] id=637c1dbd-... iter=2 messages=0 label='superstep_2'
Enter checkpoint index (or paste checkpoint id) to resume from, or press Enter to skip resume: 1
Resuming from checkpoint: a78c345a-e5d9-45ba-82c0-cb725452d91b
LowerAgent (shared_state): original_input='hello world', upper_output='HELLO WORLD'
Resumed Event: ExecutorInvokeEvent(executor_id=submit_lower)
Resumed Event: ExecutorInvokeEvent(executor_id=lower_agent)
Resumed Event: ExecutorInvokeEvent(executor_id=finalize)
""" # noqa: E501
if output is not None:
print(f"\nWorkflow completed successfully with output: {output}")
break
if __name__ == "__main__":
@@ -7,6 +7,7 @@ import uuid
from dataclasses import dataclass, field, replace
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, override
from agent_framework import (
Executor,
@@ -205,6 +206,8 @@ class LaunchCoordinator(Executor):
def __init__(self) -> None:
super().__init__(id="launch_coordinator")
# Track pending requests to match responses
self._pending_requests: dict[str, SubWorkflowRequestMessage] = {}
@handler
async def kick_off(self, topic: str, ctx: WorkflowContext[DraftTask]) -> None:
@@ -244,11 +247,9 @@ class LaunchCoordinator(Executor):
if not isinstance(request.source_event.data, ReviewRequest):
raise TypeError(f"Expected 'ReviewRequest', got {type(request.source_event.data)}")
# Record the request to response matching
# Record the request for response matching
review_request = request.source_event.data
executor_state = await ctx.get_executor_state() or {}
executor_state[review_request.id] = request
await ctx.set_executor_state(executor_state)
self._pending_requests[review_request.id] = request
# Send the request without modification
await ctx.request_info(request_data=review_request, response_type=str)
@@ -265,17 +266,25 @@ class LaunchCoordinator(Executor):
Note that the response must be sent back using SubWorkflowResponseMessage to route
the response back to the sub-workflow.
"""
executor_state = await ctx.get_executor_state() or {}
request_message = executor_state.pop(original_request.id, None)
# Save the executor state back to the context
await ctx.set_executor_state(executor_state)
request_message = self._pending_requests.pop(original_request.id, None)
if request_message is None:
raise ValueError("No matching pending request found for the resource response")
await ctx.send_message(request_message.create_response(response))
@override
async def on_checkpoint_save(self) -> dict[str, Any]:
"""Capture any additional state needed for checkpointing."""
return {
"pending_requests": self._pending_requests,
}
@override
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
"""Restore any additional state needed from checkpointing."""
self._pending_requests = state.get("pending_requests", {})
# ---------------------------------------------------------------------------
# Workflow construction helpers
@@ -356,9 +365,7 @@ async def main() -> None:
workflow2 = build_parent_workflow(storage)
request_info_event: RequestInfoEvent | None = None
async for event in workflow2.run_stream(
resume_checkpoint.checkpoint_id,
):
async for event in workflow2.run_stream(checkpoint_id=resume_checkpoint.checkpoint_id):
if isinstance(event, RequestInfoEvent):
request_info_event = event