mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Fix checkpoint ancestry bug
This commit is contained in:
@@ -62,7 +62,10 @@ class Runner:
|
||||
self._iteration = 0
|
||||
self._max_iterations = max_iterations
|
||||
self._state = state
|
||||
self._resumed_from_checkpoint = False # Track whether we resumed
|
||||
|
||||
# Checkpointing related attributes
|
||||
self._resumed_from_checkpoint = False
|
||||
self._previous_checkpoint_id: CheckpointID | None = None
|
||||
|
||||
@property
|
||||
def context(self) -> RunnerContext:
|
||||
@@ -104,8 +107,6 @@ class Runner:
|
||||
|
||||
async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]:
|
||||
"""Run the workflow until no more messages are sent."""
|
||||
previous_checkpoint_id: CheckpointID | None = None
|
||||
|
||||
# Emit any events already produced prior to entering loop
|
||||
if await self._ctx.has_events():
|
||||
logger.info("Yielding pre-loop events")
|
||||
@@ -117,7 +118,7 @@ class Runner:
|
||||
# states after which the start executor has run. Note that we execute the start executor outside of the
|
||||
# main iteration loop.
|
||||
if await self._ctx.has_messages() and not self._resumed_from_checkpoint:
|
||||
previous_checkpoint_id = await self._create_checkpoint_if_enabled(previous_checkpoint_id)
|
||||
await self._create_checkpoint_if_enabled()
|
||||
|
||||
while self._iteration < self._max_iterations:
|
||||
logger.info(f"Starting superstep {self._iteration + 1}")
|
||||
@@ -164,7 +165,7 @@ class Runner:
|
||||
self._state.commit()
|
||||
|
||||
# Create checkpoint after each superstep iteration
|
||||
previous_checkpoint_id = await self._create_checkpoint_if_enabled(previous_checkpoint_id)
|
||||
await self._create_checkpoint_if_enabled()
|
||||
|
||||
yield WorkflowEvent.superstep_completed(iteration=self._iteration)
|
||||
|
||||
@@ -230,10 +231,10 @@ class Runner:
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def _create_checkpoint_if_enabled(self, previous_checkpoint_id: CheckpointID | None) -> CheckpointID | None:
|
||||
async def _create_checkpoint_if_enabled(self) -> None:
|
||||
"""Create a checkpoint if checkpointing is enabled and attach a label and metadata."""
|
||||
if not self._ctx.has_checkpointing():
|
||||
return None
|
||||
return
|
||||
|
||||
try:
|
||||
# Save executor states into the shared state before creating the checkpoint,
|
||||
@@ -248,15 +249,26 @@ class Runner:
|
||||
self._workflow_name,
|
||||
self._graph_signature_hash,
|
||||
self._state,
|
||||
previous_checkpoint_id,
|
||||
self._previous_checkpoint_id,
|
||||
self._iteration,
|
||||
)
|
||||
|
||||
logger.info(f"Created checkpoint: {checkpoint_id}")
|
||||
return checkpoint_id
|
||||
logger.info(
|
||||
"Created checkpoint: %s with parent checkpoint at iteration %d: %s",
|
||||
checkpoint_id,
|
||||
self._iteration,
|
||||
self._previous_checkpoint_id,
|
||||
)
|
||||
self._previous_checkpoint_id = checkpoint_id
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create checkpoint: {e}")
|
||||
return None
|
||||
logger.warning(
|
||||
"Failed to create checkpoint at iteration %d: %s. "
|
||||
"Note that this does not fail the workflow run. "
|
||||
"The next successfully-created checkpoint will be parented to the last successful checkpoint: %s",
|
||||
self._iteration,
|
||||
e,
|
||||
self._previous_checkpoint_id,
|
||||
)
|
||||
|
||||
async def restore_from_checkpoint(
|
||||
self,
|
||||
@@ -311,7 +323,7 @@ class Runner:
|
||||
# Apply the checkpoint to the context
|
||||
await self._ctx.apply_checkpoint(checkpoint)
|
||||
# Mark the runner as resumed
|
||||
self._mark_resumed(checkpoint.iteration_count)
|
||||
self._mark_resumed(checkpoint)
|
||||
|
||||
logger.info(f"Successfully restored workflow from checkpoint: {checkpoint_id}")
|
||||
except WorkflowCheckpointException:
|
||||
@@ -377,13 +389,14 @@ class Runner:
|
||||
|
||||
return parsed
|
||||
|
||||
def _mark_resumed(self, iteration: int) -> None:
|
||||
def _mark_resumed(self, checkpoint: WorkflowCheckpoint) -> None:
|
||||
"""Mark the runner as having resumed from a checkpoint.
|
||||
|
||||
Optionally set the current iteration and max iterations.
|
||||
"""
|
||||
self._resumed_from_checkpoint = True
|
||||
self._iteration = iteration
|
||||
self._iteration = checkpoint.iteration_count
|
||||
self._previous_checkpoint_id = checkpoint.checkpoint_id
|
||||
|
||||
async def _set_executor_state(self, executor_id: str, state: dict[str, Any]) -> None:
|
||||
"""Store executor state in state under a reserved key.
|
||||
|
||||
@@ -336,6 +336,97 @@ async def test_workflow_checkpoint_chaining_via_previous_checkpoint_id():
|
||||
)
|
||||
|
||||
|
||||
async def test_workflow_checkpoint_ancestry_preserved_after_resume():
|
||||
"""Resuming from a checkpoint must preserve ancestry: future checkpoints chain back to the resumed one."""
|
||||
from typing_extensions import Never
|
||||
|
||||
from agent_framework import WorkflowBuilder, WorkflowContext, handler
|
||||
from agent_framework._workflows._executor import Executor
|
||||
|
||||
class StartExecutor(Executor):
|
||||
@handler
|
||||
async def run(self, message: str, ctx: WorkflowContext[str]) -> None:
|
||||
await ctx.send_message(message, target_id="middle")
|
||||
|
||||
class MiddleExecutor(Executor):
|
||||
@handler
|
||||
async def process(self, message: str, ctx: WorkflowContext[str]) -> None:
|
||||
await ctx.send_message(message + "-processed", target_id="finish")
|
||||
|
||||
class FinishExecutor(Executor):
|
||||
@handler
|
||||
async def finish(self, message: str, ctx: WorkflowContext[Never, str]) -> None:
|
||||
await ctx.yield_output(message + "-done")
|
||||
|
||||
storage = InMemoryCheckpointStorage()
|
||||
|
||||
def _build_workflow() -> Any:
|
||||
start = StartExecutor(id="start")
|
||||
middle = MiddleExecutor(id="middle")
|
||||
finish = FinishExecutor(id="finish")
|
||||
return (
|
||||
WorkflowBuilder(
|
||||
name="resume-ancestry-test",
|
||||
max_iterations=10,
|
||||
start_executor=start,
|
||||
checkpoint_storage=storage,
|
||||
)
|
||||
.add_edge(start, middle)
|
||||
.add_edge(middle, finish)
|
||||
.build()
|
||||
)
|
||||
|
||||
# First run: produce an initial chain of checkpoints
|
||||
workflow = _build_workflow()
|
||||
workflow_name = workflow.name
|
||||
_ = [event async for event in workflow.run("hello", stream=True)]
|
||||
|
||||
initial_checkpoints = sorted(await storage.list_checkpoints(workflow_name=workflow_name), key=lambda c: c.timestamp)
|
||||
assert len(initial_checkpoints) >= 3, (
|
||||
f"Need at least 3 initial checkpoints to pick a middle one, got {len(initial_checkpoints)}"
|
||||
)
|
||||
initial_ids = {cp.checkpoint_id for cp in initial_checkpoints}
|
||||
|
||||
# Pick an intermediate checkpoint to resume from (not the first, not the last)
|
||||
resume_from = initial_checkpoints[len(initial_checkpoints) // 2]
|
||||
|
||||
# Resume on a fresh workflow instance (same graph signature) and run to completion
|
||||
resumed_workflow = _build_workflow()
|
||||
assert resumed_workflow.name == workflow_name
|
||||
_ = [event async for event in resumed_workflow.run(checkpoint_id=resume_from.checkpoint_id, stream=True)]
|
||||
|
||||
# Inspect new checkpoints created after resuming
|
||||
all_checkpoints = sorted(await storage.list_checkpoints(workflow_name=workflow_name), key=lambda c: c.timestamp)
|
||||
new_checkpoints = [cp for cp in all_checkpoints if cp.checkpoint_id not in initial_ids]
|
||||
assert new_checkpoints, "Resuming from an intermediate checkpoint should produce new checkpoints"
|
||||
|
||||
# The very first checkpoint created after resuming must chain back to the resumed checkpoint
|
||||
assert new_checkpoints[0].previous_checkpoint_id == resume_from.checkpoint_id, (
|
||||
"First post-resume checkpoint must chain to the checkpoint that was resumed from; "
|
||||
f"got previous_checkpoint_id={new_checkpoints[0].previous_checkpoint_id!r}, "
|
||||
f"expected {resume_from.checkpoint_id!r}"
|
||||
)
|
||||
|
||||
# Subsequent post-resume checkpoints must continue chaining
|
||||
for i in range(1, len(new_checkpoints)):
|
||||
assert new_checkpoints[i].previous_checkpoint_id == new_checkpoints[i - 1].checkpoint_id, (
|
||||
f"Post-resume checkpoint {i} should chain to checkpoint {i - 1}"
|
||||
)
|
||||
|
||||
# Walking the chain backwards from the most recent checkpoint must reach the original root
|
||||
# without breaks (i.e. the full ancestry across the resume boundary is intact).
|
||||
checkpoints_by_id = {cp.checkpoint_id: cp for cp in all_checkpoints}
|
||||
chain: list[str] = []
|
||||
cursor: str | None = new_checkpoints[-1].checkpoint_id
|
||||
while cursor is not None:
|
||||
chain.append(cursor)
|
||||
cursor = checkpoints_by_id[cursor].previous_checkpoint_id
|
||||
# Chain must include the resumed-from checkpoint and terminate at the original root
|
||||
assert resume_from.checkpoint_id in chain
|
||||
assert chain[-1] == initial_checkpoints[0].checkpoint_id
|
||||
assert checkpoints_by_id[chain[-1]].previous_checkpoint_id is None
|
||||
|
||||
|
||||
async def test_memory_checkpoint_storage_roundtrip_json_native_types():
|
||||
"""Test that JSON-native types (str, int, float, bool, None) roundtrip correctly."""
|
||||
storage = InMemoryCheckpointStorage()
|
||||
|
||||
@@ -883,7 +883,13 @@ async def test_runner_checkpoint_with_resumed_flag():
|
||||
state = State()
|
||||
|
||||
runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash")
|
||||
runner._mark_resumed(5) # pyright: ignore[reportPrivateUsage]
|
||||
resumed_checkpoint = WorkflowCheckpoint(
|
||||
checkpoint_id="resumed-cp",
|
||||
workflow_name="test_name",
|
||||
graph_signature_hash="test_hash",
|
||||
iteration_count=5,
|
||||
)
|
||||
runner._mark_resumed(resumed_checkpoint) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Add a message to trigger the checkpoint creation path
|
||||
await ctx.send_message(WorkflowMessage(data=MockMessage(data=8), source_id="START"))
|
||||
@@ -903,6 +909,86 @@ async def test_runner_checkpoint_with_resumed_flag():
|
||||
assert runner._resumed_from_checkpoint is False # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_runner_mark_resumed_sets_previous_checkpoint_id():
|
||||
"""_mark_resumed must populate _previous_checkpoint_id so future checkpoints chain back to the resume point."""
|
||||
runner = Runner(
|
||||
[],
|
||||
{},
|
||||
State(),
|
||||
InProcRunnerContext(),
|
||||
"test_name",
|
||||
graph_signature_hash="test_hash",
|
||||
)
|
||||
|
||||
# Pre-condition: nothing to chain back to
|
||||
assert runner._previous_checkpoint_id is None # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
resumed_checkpoint = WorkflowCheckpoint(
|
||||
checkpoint_id="resumed-cp-id",
|
||||
workflow_name="test_name",
|
||||
graph_signature_hash="test_hash",
|
||||
iteration_count=3,
|
||||
)
|
||||
runner._mark_resumed(resumed_checkpoint) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
assert runner._resumed_from_checkpoint is True # pyright: ignore[reportPrivateUsage]
|
||||
assert runner._iteration == 3 # pyright: ignore[reportPrivateUsage]
|
||||
assert runner._previous_checkpoint_id == "resumed-cp-id" # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_runner_post_resume_checkpoint_chains_to_resumed_checkpoint():
|
||||
"""After resuming, the next checkpoint created must reference the resumed checkpoint as its parent."""
|
||||
storage = InMemoryCheckpointStorage()
|
||||
ctx = CheckpointingContext(storage)
|
||||
executor_a = MockExecutor(id="executor_a")
|
||||
executor_b = MockExecutor(id="executor_b")
|
||||
|
||||
edges = [
|
||||
SingleEdgeGroup(executor_a.id, executor_b.id),
|
||||
SingleEdgeGroup(executor_b.id, executor_a.id),
|
||||
]
|
||||
|
||||
executors: dict[str, Executor] = {
|
||||
executor_a.id: executor_a,
|
||||
executor_b.id: executor_b,
|
||||
}
|
||||
state = State()
|
||||
|
||||
runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash")
|
||||
|
||||
# Simulate having resumed from a prior checkpoint
|
||||
resumed_checkpoint = WorkflowCheckpoint(
|
||||
checkpoint_id="parent-checkpoint-id",
|
||||
workflow_name="test_name",
|
||||
graph_signature_hash="test_hash",
|
||||
iteration_count=1,
|
||||
)
|
||||
runner._mark_resumed(resumed_checkpoint) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Seed a message so the runner has work to do (and creates checkpoints at superstep boundaries)
|
||||
await ctx.send_message(WorkflowMessage(data=MockMessage(data=8), source_id=executor_a.id))
|
||||
|
||||
async for _ in runner.run_until_convergence():
|
||||
pass
|
||||
|
||||
# Find the first checkpoint created after the resume point (across all workflows tracked by storage)
|
||||
new_checkpoints = sorted(
|
||||
await storage.list_checkpoints(workflow_name="test_name"),
|
||||
key=lambda c: c.timestamp,
|
||||
)
|
||||
assert new_checkpoints, "Resuming and running should produce at least one new checkpoint"
|
||||
|
||||
# The first new checkpoint must chain to the resumed-from checkpoint, not to None
|
||||
assert new_checkpoints[0].previous_checkpoint_id == "parent-checkpoint-id", (
|
||||
"First post-resume checkpoint must chain to the resumed checkpoint id; "
|
||||
f"got {new_checkpoints[0].previous_checkpoint_id!r}"
|
||||
)
|
||||
|
||||
# Subsequent post-resume checkpoints continue the chain
|
||||
for i in range(1, len(new_checkpoints)):
|
||||
assert new_checkpoints[i].previous_checkpoint_id == new_checkpoints[i - 1].checkpoint_id
|
||||
|
||||
|
||||
class ExecutorThatFailsWithEvents(Executor):
|
||||
"""An executor that emits events and then raises an exception after receiving messages."""
|
||||
|
||||
@@ -1127,7 +1213,13 @@ async def test_runner_reset_for_new_run_clears_shared_state():
|
||||
async def test_runner_reset_for_new_run_clears_resumed_from_checkpoint_flag():
|
||||
"""reset_for_new_run clears the flag set by restore_from_checkpoint."""
|
||||
runner = _make_runner()
|
||||
runner._mark_resumed(iteration=5) # pyright: ignore[reportPrivateUsage]
|
||||
resumed_checkpoint = WorkflowCheckpoint(
|
||||
checkpoint_id="resumed-cp",
|
||||
workflow_name="test_name",
|
||||
graph_signature_hash="test_hash",
|
||||
iteration_count=5,
|
||||
)
|
||||
runner._mark_resumed(resumed_checkpoint) # pyright: ignore[reportPrivateUsage]
|
||||
assert runner._resumed_from_checkpoint is True # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
await runner.reset_for_new_run()
|
||||
|
||||
Reference in New Issue
Block a user