mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add checkpoint save and restore hooks to executor (#2097)
* Add checkpoint hooks * Deprecate get_executor_state and set_executor_state * Fix tests and samples * Add doc strings * Add sample * Fix import * Address comments and fix tests * Address comments * conditional import
This commit is contained in:
committed by
GitHub
Unverified
parent
132597957a
commit
c361ad8d33
+22
-16
@@ -3,6 +3,7 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, override
|
||||
|
||||
# NOTE: the Azure client imports above are real dependencies. When running this
|
||||
# sample outside of Azure-enabled environments you may wish to swap in the
|
||||
@@ -116,19 +117,19 @@ class ReviewGateway(Executor):
|
||||
def __init__(self, id: str, writer_id: str) -> None:
|
||||
super().__init__(id=id)
|
||||
self._writer_id = writer_id
|
||||
self._iteration = 0
|
||||
|
||||
@handler
|
||||
async def on_agent_response(self, response: AgentExecutorResponse, ctx: WorkflowContext) -> None:
|
||||
# Capture the agent output so we can surface it to the reviewer and persist iterations.
|
||||
draft = response.agent_run_response.text or ""
|
||||
iteration = int((await ctx.get_executor_state() or {}).get("iteration", 0)) + 1
|
||||
await ctx.set_executor_state({"iteration": iteration, "last_draft": draft})
|
||||
self._iteration += 1
|
||||
|
||||
# Emit a human approval request.
|
||||
await ctx.request_info(
|
||||
request_data=HumanApprovalRequest(
|
||||
prompt="Review the draft. Reply 'approve' or provide edit instructions.",
|
||||
draft=draft,
|
||||
iteration=iteration,
|
||||
draft=response.agent_run_response.text,
|
||||
iteration=self._iteration,
|
||||
),
|
||||
response_type=str,
|
||||
)
|
||||
@@ -142,28 +143,33 @@ class ReviewGateway(Executor):
|
||||
) -> None:
|
||||
# The `original_request` is the request we sent earlier that is now being answered.
|
||||
reply = feedback.strip()
|
||||
state = await ctx.get_executor_state() or {}
|
||||
draft = state.get("last_draft") or (original_request.draft or "")
|
||||
|
||||
if reply.lower() == "approve":
|
||||
if len(reply) == 0 or reply.lower() == "approve":
|
||||
# Workflow is completed when the human approves.
|
||||
await ctx.yield_output(draft)
|
||||
await ctx.yield_output(original_request.draft)
|
||||
return
|
||||
|
||||
# Any other response loops us back to the writer with fresh guidance.
|
||||
guidance = reply or "Tighten the copy and emphasise customer benefit."
|
||||
iteration = int(state.get("iteration", 1)) + 1
|
||||
await ctx.set_executor_state({"iteration": iteration, "last_draft": draft})
|
||||
prompt = (
|
||||
"Revise the launch note. Respond with the new copy only.\n\n"
|
||||
f"Previous draft:\n{draft}\n\n"
|
||||
f"Human guidance: {guidance}"
|
||||
f"Previous draft:\n{original_request.draft}\n\n"
|
||||
f"Human guidance: {reply}"
|
||||
)
|
||||
await ctx.send_message(
|
||||
AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=prompt)], should_respond=True),
|
||||
target_id=self._writer_id,
|
||||
)
|
||||
|
||||
@override
|
||||
async def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
# Save the current iteration count in executor state for checkpointing.
|
||||
return {"iteration": self._iteration}
|
||||
|
||||
@override
|
||||
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
# Restore the iteration count from executor state during checkpoint recovery.
|
||||
self._iteration = state.get("iteration", 0)
|
||||
|
||||
|
||||
def create_workflow(checkpoint_storage: FileCheckpointStorage) -> Workflow:
|
||||
"""Assemble the workflow graph used by both the initial run and resume."""
|
||||
@@ -247,10 +253,10 @@ async def run_interactive_session(
|
||||
else:
|
||||
if initial_message:
|
||||
print(f"\nStarting workflow with brief: {initial_message}\n")
|
||||
event_stream = workflow.run_stream(initial_message)
|
||||
event_stream = workflow.run_stream(message=initial_message)
|
||||
elif checkpoint_id:
|
||||
print("\nStarting workflow from checkpoint...\n")
|
||||
event_stream = workflow.run_stream(checkpoint_id)
|
||||
event_stream = workflow.run_stream(checkpoint_id=checkpoint_id)
|
||||
else:
|
||||
raise ValueError("Either initial_message or checkpoint_id must be provided")
|
||||
|
||||
|
||||
@@ -1,322 +1,157 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from agent_framework import (
|
||||
AgentExecutor,
|
||||
AgentExecutorRequest,
|
||||
AgentExecutorResponse,
|
||||
ChatMessage,
|
||||
Executor,
|
||||
FileCheckpointStorage,
|
||||
Role,
|
||||
WorkflowBuilder,
|
||||
WorkflowContext,
|
||||
get_checkpoint_summary,
|
||||
handler,
|
||||
)
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
from azure.identity import AzureCliCredential
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agent_framework import Workflow
|
||||
from agent_framework._workflows._checkpoint import WorkflowCheckpoint
|
||||
|
||||
"""
|
||||
Sample: Checkpointing and Resuming a Workflow (with an Agent stage)
|
||||
Sample: Checkpointing and Resuming a Workflow
|
||||
|
||||
Purpose:
|
||||
This sample shows how to enable checkpointing at superstep boundaries, persist both
|
||||
executor-local state and shared workflow state, and then resume execution from a specific
|
||||
checkpoint. The workflow demonstrates a simple text-processing pipeline that includes
|
||||
an LLM-backed AgentExecutor stage.
|
||||
|
||||
Pipeline:
|
||||
1) UpperCaseExecutor converts input to uppercase and records state.
|
||||
2) ReverseTextExecutor reverses the string.
|
||||
3) SubmitToLowerAgent prepares an AgentExecutorRequest for the lowercasing agent.
|
||||
4) lower_agent (AgentExecutor) converts text to lowercase via Azure OpenAI.
|
||||
5) FinalizeFromAgent yields the final result.
|
||||
This sample shows how to enable checkpointing for a long-running workflow
|
||||
that can be paused and resumed.
|
||||
|
||||
What you learn:
|
||||
- How to persist executor state using ctx.get_executor_state and ctx.set_executor_state.
|
||||
- How to persist shared workflow state using ctx.set_shared_state for cross-executor visibility.
|
||||
- How to configure FileCheckpointStorage and call with_checkpointing on WorkflowBuilder.
|
||||
- How to list and inspect checkpoints programmatically.
|
||||
- How to interactively choose a checkpoint to resume from (instead of always resuming
|
||||
from the most recent or a hard-coded one) using run_stream.
|
||||
- How workflows complete by yielding outputs when idle, not via explicit completion events.
|
||||
- How to configure checkpointing storage (InMemoryCheckpointStorage for testing)
|
||||
- How to resume a workflow from a checkpoint after interruption
|
||||
- How to implement executor state management with checkpoint hooks
|
||||
- How to handle workflow interruptions and automatic recovery
|
||||
|
||||
Pipeline:
|
||||
This sample shows a workflow that computes factor pairs for numbers up to a given limit:
|
||||
1) A start executor that receives the upper limit and creates the initial task
|
||||
2) A worker executor that processes each number to find its factor pairs
|
||||
3) The worker uses checkpoint hooks to save/restore its internal state
|
||||
|
||||
Prerequisites:
|
||||
- Azure AI or Azure OpenAI available for AzureOpenAIChatClient.
|
||||
- Authentication with azure-identity via AzureCliCredential. Run az login locally.
|
||||
- Filesystem access for writing JSON checkpoint files in a temp directory.
|
||||
- Basic understanding of workflow concepts, including executors, edges, events, etc.
|
||||
"""
|
||||
|
||||
# Define the temporary directory for storing checkpoints.
|
||||
# These files allow the workflow to be resumed later.
|
||||
DIR = os.path.dirname(__file__)
|
||||
TEMP_DIR = os.path.join(DIR, "tmp", "checkpoints")
|
||||
os.makedirs(TEMP_DIR, exist_ok=True)
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from random import random
|
||||
from typing import Any, override
|
||||
|
||||
from agent_framework import (
|
||||
Executor,
|
||||
InMemoryCheckpointStorage,
|
||||
SuperStepCompletedEvent,
|
||||
WorkflowBuilder,
|
||||
WorkflowCheckpoint,
|
||||
WorkflowContext,
|
||||
WorkflowOutputEvent,
|
||||
handler,
|
||||
)
|
||||
|
||||
|
||||
class UpperCaseExecutor(Executor):
|
||||
"""Uppercases the input text and persists both local and shared state."""
|
||||
@dataclass
|
||||
class ComputeTask:
|
||||
"""Task containing the list of numbers remaining to be processed."""
|
||||
|
||||
remaining_numbers: list[int]
|
||||
|
||||
|
||||
class StartExecutor(Executor):
|
||||
"""Initiates the workflow by providing the upper limit for factor pair computation."""
|
||||
|
||||
@handler
|
||||
async def to_upper_case(self, text: str, ctx: WorkflowContext[str]) -> None:
|
||||
result = text.upper()
|
||||
print(f"UpperCaseExecutor: '{text}' -> '{result}'")
|
||||
|
||||
# Persist executor-local state so it is captured in checkpoints
|
||||
# and available after resume for observability or logic.
|
||||
prev = await ctx.get_executor_state() or {}
|
||||
count = int(prev.get("count", 0)) + 1
|
||||
await ctx.set_executor_state({
|
||||
"count": count,
|
||||
"last_input": text,
|
||||
"last_output": result,
|
||||
})
|
||||
|
||||
# Write to shared_state so downstream executors and any resumed runs can read it.
|
||||
await ctx.set_shared_state("original_input", text)
|
||||
await ctx.set_shared_state("upper_output", result)
|
||||
|
||||
# Send transformed text to the next executor.
|
||||
await ctx.send_message(result)
|
||||
async def start(self, upper_limit: int, ctx: WorkflowContext[ComputeTask]) -> None:
|
||||
"""Start the workflow with a list of numbers to process."""
|
||||
print(f"StartExecutor: Starting factor pair computation up to {upper_limit}")
|
||||
await ctx.send_message(ComputeTask(remaining_numbers=list(range(1, upper_limit + 1))))
|
||||
|
||||
|
||||
class SubmitToLowerAgent(Executor):
|
||||
"""Builds an AgentExecutorRequest to send to the lowercasing agent while keeping shared-state visibility."""
|
||||
class WorkerExecutor(Executor):
|
||||
"""Processes numbers to compute their factor pairs and manages executor state for checkpointing."""
|
||||
|
||||
def __init__(self, id: str, agent_id: str):
|
||||
def __init__(self, id: str) -> None:
|
||||
super().__init__(id=id)
|
||||
self._agent_id = agent_id
|
||||
self._composite_number_pairs: dict[int, list[tuple[int, int]]] = {}
|
||||
|
||||
@handler
|
||||
async def submit(self, text: str, ctx: WorkflowContext[AgentExecutorRequest]) -> None:
|
||||
# Demonstrate reading shared_state written by UpperCaseExecutor.
|
||||
# Shared state survives across checkpoints and is visible to all executors.
|
||||
orig = await ctx.get_shared_state("original_input")
|
||||
upper = await ctx.get_shared_state("upper_output")
|
||||
print(f"LowerAgent (shared_state): original_input='{orig}', upper_output='{upper}'")
|
||||
async def compute(
|
||||
self,
|
||||
task: ComputeTask,
|
||||
ctx: WorkflowContext[ComputeTask, dict[int, list[tuple[int, int]]]],
|
||||
) -> None:
|
||||
"""Process the next number in the task, computing its factor pairs."""
|
||||
next_number = task.remaining_numbers.pop(0)
|
||||
|
||||
# Build a minimal, deterministic prompt for the AgentExecutor.
|
||||
prompt = f"Convert the following text to lowercase. Return ONLY the transformed text.\n\nText: {text}"
|
||||
print(f"WorkerExecutor: Computing factor pairs for {next_number}")
|
||||
pairs: list[tuple[int, int]] = []
|
||||
for i in range(1, next_number):
|
||||
if next_number % i == 0:
|
||||
pairs.append((i, next_number // i))
|
||||
self._composite_number_pairs[next_number] = pairs
|
||||
|
||||
# Send to the AgentExecutor. should_respond=True instructs the agent to produce a reply.
|
||||
await ctx.send_message(
|
||||
AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=prompt)], should_respond=True),
|
||||
target_id=self._agent_id,
|
||||
)
|
||||
if not task.remaining_numbers:
|
||||
# All numbers processed - output the results
|
||||
await ctx.yield_output(self._composite_number_pairs)
|
||||
else:
|
||||
# More numbers to process - continue with remaining task
|
||||
await ctx.send_message(task)
|
||||
|
||||
@override
|
||||
async def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
"""Save the executor's internal state for checkpointing."""
|
||||
return {"composite_number_pairs": self._composite_number_pairs}
|
||||
|
||||
class FinalizeFromAgent(Executor):
|
||||
"""Consumes the AgentExecutorResponse and yields the final result."""
|
||||
|
||||
@handler
|
||||
async def finalize(self, response: AgentExecutorResponse, ctx: WorkflowContext[Any, str]) -> None:
|
||||
result = response.agent_run_response.text or ""
|
||||
|
||||
# Persist executor-local state for auditability when inspecting checkpoints.
|
||||
prev = await ctx.get_executor_state() or {}
|
||||
count = int(prev.get("count", 0)) + 1
|
||||
await ctx.set_executor_state({
|
||||
"count": count,
|
||||
"last_output": result,
|
||||
"final": True,
|
||||
})
|
||||
|
||||
# Yield the final result so external consumers see the final value.
|
||||
await ctx.yield_output(result)
|
||||
|
||||
|
||||
class ReverseTextExecutor(Executor):
|
||||
"""Reverses the input text and persists local state."""
|
||||
|
||||
@handler
|
||||
async def reverse_text(self, text: str, ctx: WorkflowContext[str]) -> None:
|
||||
result = text[::-1]
|
||||
print(f"ReverseTextExecutor: '{text}' -> '{result}'")
|
||||
|
||||
# Persist executor-local state so checkpoint inspection can reveal progress.
|
||||
prev = await ctx.get_executor_state() or {}
|
||||
count = int(prev.get("count", 0)) + 1
|
||||
await ctx.set_executor_state({
|
||||
"count": count,
|
||||
"last_input": text,
|
||||
"last_output": result,
|
||||
})
|
||||
|
||||
# Forward the reversed string to the next stage.
|
||||
await ctx.send_message(result)
|
||||
|
||||
|
||||
def create_workflow(checkpoint_storage: FileCheckpointStorage) -> "Workflow":
|
||||
# Instantiate the pipeline executors.
|
||||
upper_case_executor = UpperCaseExecutor(id="upper-case")
|
||||
reverse_text_executor = ReverseTextExecutor(id="reverse-text")
|
||||
|
||||
# Configure the agent stage that lowercases the text.
|
||||
chat_client = AzureOpenAIChatClient(credential=AzureCliCredential())
|
||||
lower_agent = AgentExecutor(
|
||||
chat_client.create_agent(
|
||||
instructions=("You transform text to lowercase. Reply with ONLY the transformed text.")
|
||||
),
|
||||
id="lower_agent",
|
||||
)
|
||||
|
||||
# Bridge to the agent and terminalization stage.
|
||||
submit_lower = SubmitToLowerAgent(id="submit_lower", agent_id=lower_agent.id)
|
||||
finalize = FinalizeFromAgent(id="finalize")
|
||||
|
||||
# Build the workflow with checkpointing enabled.
|
||||
return (
|
||||
WorkflowBuilder(max_iterations=5)
|
||||
.add_edge(upper_case_executor, reverse_text_executor) # Uppercase -> Reverse
|
||||
.add_edge(reverse_text_executor, submit_lower) # Reverse -> Build Agent request
|
||||
.add_edge(submit_lower, lower_agent) # Submit to AgentExecutor
|
||||
.add_edge(lower_agent, finalize) # Agent output -> Finalize
|
||||
.set_start_executor(upper_case_executor) # Entry point
|
||||
.with_checkpointing(checkpoint_storage=checkpoint_storage) # Enable persistence
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
def _render_checkpoint_summary(checkpoints: list["WorkflowCheckpoint"]) -> None:
|
||||
"""Display human-friendly checkpoint metadata using framework summaries."""
|
||||
|
||||
if not checkpoints:
|
||||
return
|
||||
|
||||
print("\nCheckpoint summary:")
|
||||
for cp in sorted(checkpoints, key=lambda c: c.timestamp):
|
||||
summary = get_checkpoint_summary(cp)
|
||||
msg_count = sum(len(v) for v in cp.messages.values())
|
||||
state_keys = sorted(summary.executor_ids)
|
||||
orig = cp.shared_state.get("original_input")
|
||||
upper = cp.shared_state.get("upper_output")
|
||||
|
||||
line = (
|
||||
f"- {summary.checkpoint_id} | iter={summary.iteration_count} | messages={msg_count} | states={state_keys}"
|
||||
)
|
||||
if summary.status:
|
||||
line += f" | status={summary.status}"
|
||||
line += f" | shared_state: original_input='{orig}', upper_output='{upper}'"
|
||||
print(line)
|
||||
@override
|
||||
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
"""Restore the executor's internal state from a checkpoint."""
|
||||
self._composite_number_pairs = state.get("composite_number_pairs", {})
|
||||
|
||||
|
||||
async def main():
|
||||
# Clear existing checkpoints in this sample directory for a clean run.
|
||||
checkpoint_dir = Path(TEMP_DIR)
|
||||
for file in checkpoint_dir.glob("*.json"): # noqa: ASYNC240
|
||||
file.unlink()
|
||||
# Create workflow executors
|
||||
start_executor = StartExecutor(id="start")
|
||||
worker_executor = WorkerExecutor(id="worker")
|
||||
|
||||
# Backing store for checkpoints written by with_checkpointing.
|
||||
checkpoint_storage = FileCheckpointStorage(storage_path=TEMP_DIR)
|
||||
# Build workflow with checkpointing enabled
|
||||
workflow_builder = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(start_executor)
|
||||
.add_edge(start_executor, worker_executor)
|
||||
.add_edge(worker_executor, worker_executor) # Self-loop for iterative processing
|
||||
)
|
||||
checkpoint_storage = InMemoryCheckpointStorage()
|
||||
workflow_builder = workflow_builder.with_checkpointing(checkpoint_storage=checkpoint_storage)
|
||||
|
||||
workflow = create_workflow(checkpoint_storage=checkpoint_storage)
|
||||
# Run workflow with automatic checkpoint recovery
|
||||
latest_checkpoint: WorkflowCheckpoint | None = None
|
||||
while True:
|
||||
workflow = workflow_builder.build()
|
||||
|
||||
# Run the full workflow once and observe events as they stream.
|
||||
print("Running workflow with initial message...")
|
||||
async for event in workflow.run_stream(message="hello world"):
|
||||
print(f"Event: {event}")
|
||||
# Start from checkpoint or fresh execution
|
||||
print(f"\n** Workflow {workflow.id} started **")
|
||||
event_stream = (
|
||||
workflow.run_stream(message=10)
|
||||
if latest_checkpoint is None
|
||||
else workflow.run_stream(checkpoint_id=latest_checkpoint.checkpoint_id)
|
||||
)
|
||||
|
||||
# Inspect checkpoints written during the run.
|
||||
all_checkpoints = await checkpoint_storage.list_checkpoints()
|
||||
if not all_checkpoints:
|
||||
print("No checkpoints found!")
|
||||
return
|
||||
|
||||
# All checkpoints created by this run share the same workflow_id.
|
||||
workflow_id = all_checkpoints[0].workflow_id
|
||||
|
||||
_render_checkpoint_summary(all_checkpoints)
|
||||
|
||||
# Offer an interactive selection of checkpoints to resume from.
|
||||
sorted_cps = sorted([cp for cp in all_checkpoints if cp.workflow_id == workflow_id], key=lambda c: c.timestamp)
|
||||
|
||||
print("\nAvailable checkpoints to resume from:")
|
||||
for idx, cp in enumerate(sorted_cps):
|
||||
summary = get_checkpoint_summary(cp)
|
||||
line = f" [{idx}] id={summary.checkpoint_id} iter={summary.iteration_count}"
|
||||
if summary.status:
|
||||
line += f" status={summary.status}"
|
||||
msg_count = sum(len(v) for v in cp.messages.values())
|
||||
line += f" messages={msg_count}"
|
||||
print(line)
|
||||
|
||||
user_input = input( # noqa: ASYNC250
|
||||
"\nEnter checkpoint index (or paste checkpoint id) to resume from, or press Enter to skip resume: "
|
||||
).strip()
|
||||
|
||||
if not user_input:
|
||||
print("No checkpoint selected. Exiting without resuming.")
|
||||
return
|
||||
|
||||
chosen_cp_id: str | None = None
|
||||
|
||||
# Try as index first
|
||||
if user_input.isdigit():
|
||||
idx = int(user_input)
|
||||
if 0 <= idx < len(sorted_cps):
|
||||
chosen_cp_id = sorted_cps[idx].checkpoint_id
|
||||
# Fall back to direct id match
|
||||
if chosen_cp_id is None:
|
||||
for cp in sorted_cps:
|
||||
if cp.checkpoint_id.startswith(user_input): # allow prefix match for convenience
|
||||
chosen_cp_id = cp.checkpoint_id
|
||||
output: str | None = None
|
||||
async for event in event_stream:
|
||||
if isinstance(event, WorkflowOutputEvent):
|
||||
output = event.data
|
||||
break
|
||||
if isinstance(event, SuperStepCompletedEvent) and random() < 0.5:
|
||||
# Randomly simulate system interruptions
|
||||
# The `SuperStepCompletedEvent` ensures we only interrupt after
|
||||
# the current super-step is fully complete and checkpointed.
|
||||
# If we interrupt mid-step, the workflow may resume from an earlier point.
|
||||
print("\n** Simulating workflow interruption. Stopping execution. **")
|
||||
break
|
||||
|
||||
if chosen_cp_id is None:
|
||||
print("Input did not match any checkpoint. Exiting without resuming.")
|
||||
return
|
||||
# Find the latest checkpoint to resume from
|
||||
all_checkpoints = await checkpoint_storage.list_checkpoints()
|
||||
if not all_checkpoints:
|
||||
raise RuntimeError("No checkpoints available to resume from.")
|
||||
latest_checkpoint = all_checkpoints[-1]
|
||||
print(
|
||||
f"Checkpoint {latest_checkpoint.checkpoint_id}: "
|
||||
f"(iter={latest_checkpoint.iteration_count}, messages={latest_checkpoint.messages})"
|
||||
)
|
||||
|
||||
# You can reuse the same workflow graph definition and resume from a prior checkpoint.
|
||||
# This second workflow instance does not enable checkpointing to show that resumption
|
||||
# reads from stored state but need not write new checkpoints.
|
||||
new_workflow = create_workflow(checkpoint_storage=checkpoint_storage)
|
||||
|
||||
print(f"\nResuming from checkpoint: {chosen_cp_id}")
|
||||
async for event in new_workflow.run_stream(checkpoint_id=chosen_cp_id, checkpoint_storage=checkpoint_storage):
|
||||
print(f"Resumed Event: {event}")
|
||||
|
||||
"""
|
||||
Sample Output:
|
||||
|
||||
Running workflow with initial message...
|
||||
UpperCaseExecutor: 'hello world' -> 'HELLO WORLD'
|
||||
Event: ExecutorInvokeEvent(executor_id=upper_case_executor)
|
||||
Event: ExecutorCompletedEvent(executor_id=upper_case_executor)
|
||||
ReverseTextExecutor: 'HELLO WORLD' -> 'DLROW OLLEH'
|
||||
Event: ExecutorInvokeEvent(executor_id=reverse_text_executor)
|
||||
Event: ExecutorCompletedEvent(executor_id=reverse_text_executor)
|
||||
LowerAgent (shared_state): original_input='hello world', upper_output='HELLO WORLD'
|
||||
Event: ExecutorInvokeEvent(executor_id=submit_lower)
|
||||
Event: ExecutorInvokeEvent(executor_id=lower_agent)
|
||||
Event: ExecutorInvokeEvent(executor_id=finalize)
|
||||
|
||||
Checkpoint summary:
|
||||
- dfc63e72-8e8d-454f-9b6d-0d740b9062e6 | label='after_initial_execution' | iter=0 | messages=1 | states=['upper_case_executor'] | shared_state: original_input='hello world', upper_output='HELLO WORLD'
|
||||
- a78c345a-e5d9-45ba-82c0-cb725452d91b | label='superstep_1' | iter=1 | messages=1 | states=['reverse_text_executor', 'upper_case_executor'] | shared_state: original_input='hello world', upper_output='HELLO WORLD'
|
||||
- 637c1dbd-a525-4404-9583-da03980537a2 | label='superstep_2' | iter=2 | messages=0 | states=['finalize', 'lower_agent', 'reverse_text_executor', 'submit_lower', 'upper_case_executor'] | shared_state: original_input='hello world', upper_output='HELLO WORLD'
|
||||
|
||||
Available checkpoints to resume from:
|
||||
[0] id=dfc63e72-... iter=0 messages=1 label='after_initial_execution'
|
||||
[1] id=a78c345a-... iter=1 messages=1 label='superstep_1'
|
||||
[2] id=637c1dbd-... iter=2 messages=0 label='superstep_2'
|
||||
|
||||
Enter checkpoint index (or paste checkpoint id) to resume from, or press Enter to skip resume: 1
|
||||
|
||||
Resuming from checkpoint: a78c345a-e5d9-45ba-82c0-cb725452d91b
|
||||
LowerAgent (shared_state): original_input='hello world', upper_output='HELLO WORLD'
|
||||
Resumed Event: ExecutorInvokeEvent(executor_id=submit_lower)
|
||||
Resumed Event: ExecutorInvokeEvent(executor_id=lower_agent)
|
||||
Resumed Event: ExecutorInvokeEvent(executor_id=finalize)
|
||||
""" # noqa: E501
|
||||
if output is not None:
|
||||
print(f"\nWorkflow completed successfully with output: {output}")
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -7,6 +7,7 @@ import uuid
|
||||
from dataclasses import dataclass, field, replace
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, override
|
||||
|
||||
from agent_framework import (
|
||||
Executor,
|
||||
@@ -205,6 +206,8 @@ class LaunchCoordinator(Executor):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(id="launch_coordinator")
|
||||
# Track pending requests to match responses
|
||||
self._pending_requests: dict[str, SubWorkflowRequestMessage] = {}
|
||||
|
||||
@handler
|
||||
async def kick_off(self, topic: str, ctx: WorkflowContext[DraftTask]) -> None:
|
||||
@@ -244,11 +247,9 @@ class LaunchCoordinator(Executor):
|
||||
if not isinstance(request.source_event.data, ReviewRequest):
|
||||
raise TypeError(f"Expected 'ReviewRequest', got {type(request.source_event.data)}")
|
||||
|
||||
# Record the request to response matching
|
||||
# Record the request for response matching
|
||||
review_request = request.source_event.data
|
||||
executor_state = await ctx.get_executor_state() or {}
|
||||
executor_state[review_request.id] = request
|
||||
await ctx.set_executor_state(executor_state)
|
||||
self._pending_requests[review_request.id] = request
|
||||
|
||||
# Send the request without modification
|
||||
await ctx.request_info(request_data=review_request, response_type=str)
|
||||
@@ -265,17 +266,25 @@ class LaunchCoordinator(Executor):
|
||||
Note that the response must be sent back using SubWorkflowResponseMessage to route
|
||||
the response back to the sub-workflow.
|
||||
"""
|
||||
executor_state = await ctx.get_executor_state() or {}
|
||||
request_message = executor_state.pop(original_request.id, None)
|
||||
|
||||
# Save the executor state back to the context
|
||||
await ctx.set_executor_state(executor_state)
|
||||
request_message = self._pending_requests.pop(original_request.id, None)
|
||||
|
||||
if request_message is None:
|
||||
raise ValueError("No matching pending request found for the resource response")
|
||||
|
||||
await ctx.send_message(request_message.create_response(response))
|
||||
|
||||
@override
|
||||
async def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
"""Capture any additional state needed for checkpointing."""
|
||||
return {
|
||||
"pending_requests": self._pending_requests,
|
||||
}
|
||||
|
||||
@override
|
||||
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
"""Restore any additional state needed from checkpointing."""
|
||||
self._pending_requests = state.get("pending_requests", {})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workflow construction helpers
|
||||
@@ -356,9 +365,7 @@ async def main() -> None:
|
||||
workflow2 = build_parent_workflow(storage)
|
||||
|
||||
request_info_event: RequestInfoEvent | None = None
|
||||
async for event in workflow2.run_stream(
|
||||
resume_checkpoint.checkpoint_id,
|
||||
):
|
||||
async for event in workflow2.run_stream(checkpoint_id=resume_checkpoint.checkpoint_id):
|
||||
if isinstance(event, RequestInfoEvent):
|
||||
request_info_event = event
|
||||
|
||||
|
||||
Reference in New Issue
Block a user