Fix checkpoint ancestry bug

This commit is contained in:
Tao Chen
2026-06-10 16:36:15 -07:00
Unverified
parent 47012f1dcf
commit 0e3831192a
3 changed files with 213 additions and 17 deletions
@@ -62,7 +62,10 @@ class Runner:
self._iteration = 0
self._max_iterations = max_iterations
self._state = state
self._resumed_from_checkpoint = False # Track whether we resumed
# Checkpointing related attributes
self._resumed_from_checkpoint = False
self._previous_checkpoint_id: CheckpointID | None = None
@property
def context(self) -> RunnerContext:
@@ -104,8 +107,6 @@ class Runner:
async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]:
"""Run the workflow until no more messages are sent."""
previous_checkpoint_id: CheckpointID | None = None
# Emit any events already produced prior to entering loop
if await self._ctx.has_events():
logger.info("Yielding pre-loop events")
@@ -117,7 +118,7 @@ class Runner:
# states after which the start executor has run. Note that we execute the start executor outside of the
# main iteration loop.
if await self._ctx.has_messages() and not self._resumed_from_checkpoint:
previous_checkpoint_id = await self._create_checkpoint_if_enabled(previous_checkpoint_id)
await self._create_checkpoint_if_enabled()
while self._iteration < self._max_iterations:
logger.info(f"Starting superstep {self._iteration + 1}")
@@ -164,7 +165,7 @@ class Runner:
self._state.commit()
# Create checkpoint after each superstep iteration
previous_checkpoint_id = await self._create_checkpoint_if_enabled(previous_checkpoint_id)
await self._create_checkpoint_if_enabled()
yield WorkflowEvent.superstep_completed(iteration=self._iteration)
@@ -230,10 +231,10 @@ class Runner:
]
await asyncio.gather(*tasks)
async def _create_checkpoint_if_enabled(self, previous_checkpoint_id: CheckpointID | None) -> CheckpointID | None:
async def _create_checkpoint_if_enabled(self) -> None:
"""Create a checkpoint if checkpointing is enabled and attach a label and metadata."""
if not self._ctx.has_checkpointing():
return None
return
try:
# Save executor states into the shared state before creating the checkpoint,
@@ -248,15 +249,26 @@ class Runner:
self._workflow_name,
self._graph_signature_hash,
self._state,
previous_checkpoint_id,
self._previous_checkpoint_id,
self._iteration,
)
logger.info(f"Created checkpoint: {checkpoint_id}")
return checkpoint_id
logger.info(
"Created checkpoint: %s with parent checkpoint at iteration %d: %s",
checkpoint_id,
self._iteration,
self._previous_checkpoint_id,
)
self._previous_checkpoint_id = checkpoint_id
except Exception as e:
logger.warning(f"Failed to create checkpoint: {e}")
return None
logger.warning(
"Failed to create checkpoint at iteration %d: %s. "
"Note that this does not fail the workflow run. "
"The next successfully-created checkpoint will be parented to the last successful checkpoint: %s",
self._iteration,
e,
self._previous_checkpoint_id,
)
async def restore_from_checkpoint(
self,
@@ -311,7 +323,7 @@ class Runner:
# Apply the checkpoint to the context
await self._ctx.apply_checkpoint(checkpoint)
# Mark the runner as resumed
self._mark_resumed(checkpoint.iteration_count)
self._mark_resumed(checkpoint)
logger.info(f"Successfully restored workflow from checkpoint: {checkpoint_id}")
except WorkflowCheckpointException:
@@ -377,13 +389,14 @@ class Runner:
return parsed
def _mark_resumed(self, iteration: int) -> None:
def _mark_resumed(self, checkpoint: WorkflowCheckpoint) -> None:
"""Mark the runner as having resumed from a checkpoint.
Optionally set the current iteration and max iterations.
"""
self._resumed_from_checkpoint = True
self._iteration = iteration
self._iteration = checkpoint.iteration_count
self._previous_checkpoint_id = checkpoint.checkpoint_id
async def _set_executor_state(self, executor_id: str, state: dict[str, Any]) -> None:
"""Store executor state in state under a reserved key.
@@ -336,6 +336,97 @@ async def test_workflow_checkpoint_chaining_via_previous_checkpoint_id():
)
async def test_workflow_checkpoint_ancestry_preserved_after_resume():
"""Resuming from a checkpoint must preserve ancestry: future checkpoints chain back to the resumed one."""
from typing_extensions import Never
from agent_framework import WorkflowBuilder, WorkflowContext, handler
from agent_framework._workflows._executor import Executor
class StartExecutor(Executor):
@handler
async def run(self, message: str, ctx: WorkflowContext[str]) -> None:
await ctx.send_message(message, target_id="middle")
class MiddleExecutor(Executor):
@handler
async def process(self, message: str, ctx: WorkflowContext[str]) -> None:
await ctx.send_message(message + "-processed", target_id="finish")
class FinishExecutor(Executor):
@handler
async def finish(self, message: str, ctx: WorkflowContext[Never, str]) -> None:
await ctx.yield_output(message + "-done")
storage = InMemoryCheckpointStorage()
def _build_workflow() -> Any:
start = StartExecutor(id="start")
middle = MiddleExecutor(id="middle")
finish = FinishExecutor(id="finish")
return (
WorkflowBuilder(
name="resume-ancestry-test",
max_iterations=10,
start_executor=start,
checkpoint_storage=storage,
)
.add_edge(start, middle)
.add_edge(middle, finish)
.build()
)
# First run: produce an initial chain of checkpoints
workflow = _build_workflow()
workflow_name = workflow.name
_ = [event async for event in workflow.run("hello", stream=True)]
initial_checkpoints = sorted(await storage.list_checkpoints(workflow_name=workflow_name), key=lambda c: c.timestamp)
assert len(initial_checkpoints) >= 3, (
f"Need at least 3 initial checkpoints to pick a middle one, got {len(initial_checkpoints)}"
)
initial_ids = {cp.checkpoint_id for cp in initial_checkpoints}
# Pick an intermediate checkpoint to resume from (not the first, not the last)
resume_from = initial_checkpoints[len(initial_checkpoints) // 2]
# Resume on a fresh workflow instance (same graph signature) and run to completion
resumed_workflow = _build_workflow()
assert resumed_workflow.name == workflow_name
_ = [event async for event in resumed_workflow.run(checkpoint_id=resume_from.checkpoint_id, stream=True)]
# Inspect new checkpoints created after resuming
all_checkpoints = sorted(await storage.list_checkpoints(workflow_name=workflow_name), key=lambda c: c.timestamp)
new_checkpoints = [cp for cp in all_checkpoints if cp.checkpoint_id not in initial_ids]
assert new_checkpoints, "Resuming from an intermediate checkpoint should produce new checkpoints"
# The very first checkpoint created after resuming must chain back to the resumed checkpoint
assert new_checkpoints[0].previous_checkpoint_id == resume_from.checkpoint_id, (
"First post-resume checkpoint must chain to the checkpoint that was resumed from; "
f"got previous_checkpoint_id={new_checkpoints[0].previous_checkpoint_id!r}, "
f"expected {resume_from.checkpoint_id!r}"
)
# Subsequent post-resume checkpoints must continue chaining
for i in range(1, len(new_checkpoints)):
assert new_checkpoints[i].previous_checkpoint_id == new_checkpoints[i - 1].checkpoint_id, (
f"Post-resume checkpoint {i} should chain to checkpoint {i - 1}"
)
# Walking the chain backwards from the most recent checkpoint must reach the original root
# without breaks (i.e. the full ancestry across the resume boundary is intact).
checkpoints_by_id = {cp.checkpoint_id: cp for cp in all_checkpoints}
chain: list[str] = []
cursor: str | None = new_checkpoints[-1].checkpoint_id
while cursor is not None:
chain.append(cursor)
cursor = checkpoints_by_id[cursor].previous_checkpoint_id
# Chain must include the resumed-from checkpoint and terminate at the original root
assert resume_from.checkpoint_id in chain
assert chain[-1] == initial_checkpoints[0].checkpoint_id
assert checkpoints_by_id[chain[-1]].previous_checkpoint_id is None
async def test_memory_checkpoint_storage_roundtrip_json_native_types():
"""Test that JSON-native types (str, int, float, bool, None) roundtrip correctly."""
storage = InMemoryCheckpointStorage()
@@ -883,7 +883,13 @@ async def test_runner_checkpoint_with_resumed_flag():
state = State()
runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash")
runner._mark_resumed(5) # pyright: ignore[reportPrivateUsage]
resumed_checkpoint = WorkflowCheckpoint(
checkpoint_id="resumed-cp",
workflow_name="test_name",
graph_signature_hash="test_hash",
iteration_count=5,
)
runner._mark_resumed(resumed_checkpoint) # pyright: ignore[reportPrivateUsage]
# Add a message to trigger the checkpoint creation path
await ctx.send_message(WorkflowMessage(data=MockMessage(data=8), source_id="START"))
@@ -903,6 +909,86 @@ async def test_runner_checkpoint_with_resumed_flag():
assert runner._resumed_from_checkpoint is False # pyright: ignore[reportPrivateUsage]
async def test_runner_mark_resumed_sets_previous_checkpoint_id():
"""_mark_resumed must populate _previous_checkpoint_id so future checkpoints chain back to the resume point."""
runner = Runner(
[],
{},
State(),
InProcRunnerContext(),
"test_name",
graph_signature_hash="test_hash",
)
# Pre-condition: nothing to chain back to
assert runner._previous_checkpoint_id is None # pyright: ignore[reportPrivateUsage]
resumed_checkpoint = WorkflowCheckpoint(
checkpoint_id="resumed-cp-id",
workflow_name="test_name",
graph_signature_hash="test_hash",
iteration_count=3,
)
runner._mark_resumed(resumed_checkpoint) # pyright: ignore[reportPrivateUsage]
assert runner._resumed_from_checkpoint is True # pyright: ignore[reportPrivateUsage]
assert runner._iteration == 3 # pyright: ignore[reportPrivateUsage]
assert runner._previous_checkpoint_id == "resumed-cp-id" # pyright: ignore[reportPrivateUsage]
async def test_runner_post_resume_checkpoint_chains_to_resumed_checkpoint():
"""After resuming, the next checkpoint created must reference the resumed checkpoint as its parent."""
storage = InMemoryCheckpointStorage()
ctx = CheckpointingContext(storage)
executor_a = MockExecutor(id="executor_a")
executor_b = MockExecutor(id="executor_b")
edges = [
SingleEdgeGroup(executor_a.id, executor_b.id),
SingleEdgeGroup(executor_b.id, executor_a.id),
]
executors: dict[str, Executor] = {
executor_a.id: executor_a,
executor_b.id: executor_b,
}
state = State()
runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash")
# Simulate having resumed from a prior checkpoint
resumed_checkpoint = WorkflowCheckpoint(
checkpoint_id="parent-checkpoint-id",
workflow_name="test_name",
graph_signature_hash="test_hash",
iteration_count=1,
)
runner._mark_resumed(resumed_checkpoint) # pyright: ignore[reportPrivateUsage]
# Seed a message so the runner has work to do (and creates checkpoints at superstep boundaries)
await ctx.send_message(WorkflowMessage(data=MockMessage(data=8), source_id=executor_a.id))
async for _ in runner.run_until_convergence():
pass
# Find the first checkpoint created after the resume point (across all workflows tracked by storage)
new_checkpoints = sorted(
await storage.list_checkpoints(workflow_name="test_name"),
key=lambda c: c.timestamp,
)
assert new_checkpoints, "Resuming and running should produce at least one new checkpoint"
# The first new checkpoint must chain to the resumed-from checkpoint, not to None
assert new_checkpoints[0].previous_checkpoint_id == "parent-checkpoint-id", (
"First post-resume checkpoint must chain to the resumed checkpoint id; "
f"got {new_checkpoints[0].previous_checkpoint_id!r}"
)
# Subsequent post-resume checkpoints continue the chain
for i in range(1, len(new_checkpoints)):
assert new_checkpoints[i].previous_checkpoint_id == new_checkpoints[i - 1].checkpoint_id
class ExecutorThatFailsWithEvents(Executor):
"""An executor that emits events and then raises an exception after receiving messages."""
@@ -1127,7 +1213,13 @@ async def test_runner_reset_for_new_run_clears_shared_state():
async def test_runner_reset_for_new_run_clears_resumed_from_checkpoint_flag():
"""reset_for_new_run clears the flag set by restore_from_checkpoint."""
runner = _make_runner()
runner._mark_resumed(iteration=5) # pyright: ignore[reportPrivateUsage]
resumed_checkpoint = WorkflowCheckpoint(
checkpoint_id="resumed-cp",
workflow_name="test_name",
graph_signature_hash="test_hash",
iteration_count=5,
)
runner._mark_resumed(resumed_checkpoint) # pyright: ignore[reportPrivateUsage]
assert runner._resumed_from_checkpoint is True # pyright: ignore[reportPrivateUsage]
await runner.reset_for_new_run()