Python: Add checkpoint save and restore hooks to executor (#2097)

* Add checkpoint hooks

* Deprecate get_executor_state and set_executor_state

* Fix tests and samples

* Add doc strings

* Add sample

* Fix import

* Address comments and fix tests

* Address comments

* conditional import
This commit is contained in:
Tao Chen
2025-11-17 10:19:01 -08:00
committed by GitHub
Unverified
parent 132597957a
commit c361ad8d33
22 changed files with 508 additions and 723 deletions
@@ -37,6 +37,8 @@ from ._events import (
ExecutorFailedEvent,
ExecutorInvokedEvent,
RequestInfoEvent,
SuperStepCompletedEvent,
SuperStepStartedEvent,
WorkflowErrorDetails,
WorkflowEvent,
WorkflowEventSource,
@@ -152,6 +154,8 @@ __all__ = [
"StandardMagenticManager",
"SubWorkflowRequestMessage",
"SubWorkflowResponseMessage",
"SuperStepCompletedEvent",
"SuperStepStartedEvent",
"SwitchCaseEdgeGroup",
"SwitchCaseEdgeGroupCase",
"SwitchCaseEdgeGroupDefault",
@@ -35,6 +35,8 @@ from ._events import (
ExecutorFailedEvent,
ExecutorInvokedEvent,
RequestInfoEvent,
SuperStepCompletedEvent,
SuperStepStartedEvent,
WorkflowErrorDetails,
WorkflowEvent,
WorkflowEventSource,
@@ -148,6 +150,8 @@ __all__ = [
"StandardMagenticManager",
"SubWorkflowRequestMessage",
"SubWorkflowResponseMessage",
"SuperStepCompletedEvent",
"SuperStepStartedEvent",
"SwitchCaseEdgeGroup",
"SwitchCaseEdgeGroupCase",
"SwitchCaseEdgeGroupDefault",
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
import logging
import sys
from dataclasses import dataclass
from typing import Any, cast
@@ -20,6 +21,11 @@ from ._message_utils import normalize_messages_input
from ._request_info_mixin import response_handler
from ._workflow_context import WorkflowContext
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
logger = logging.getLogger(__name__)
@@ -179,7 +185,8 @@ class AgentExecutor(Executor):
self._pending_responses_to_agent.clear()
await self._run_agent_and_emit(ctx)
async def snapshot_state(self) -> dict[str, Any]:
@override
async def on_checkpoint_save(self) -> dict[str, Any]:
"""Capture current executor state for checkpointing.
NOTE: if the thread storage is on the server side, the full thread state
@@ -196,9 +203,6 @@ class AgentExecutor(Executor):
client_module = self._agent.chat_client.__class__.__module__
if client_class_name == "AzureAIAgentClient" and "azure_ai" in client_module:
# TODO(TaoChenOSU): update this warning when we surface the hooks for
# custom executor checkpointing.
# https://github.com/microsoft/agent-framework/issues/1816
logger.warning(
"Checkpointing an AgentExecutor with AzureAIAgentClient that uses server-side threads. "
"Currently, checkpointing does not capture messages from server-side threads "
@@ -217,7 +221,8 @@ class AgentExecutor(Executor):
"pending_responses_to_agent": encode_checkpoint_value(self._pending_responses_to_agent),
}
async def restore_state(self, state: dict[str, Any]) -> None:
@override
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
"""Restore executor state from checkpoint.
Args:
@@ -4,6 +4,7 @@
import inspect
import logging
import sys
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable, Sequence
from typing import Any
@@ -13,6 +14,12 @@ from ._executor import Executor
from ._orchestrator_helpers import ParticipantRegistry
from ._workflow_context import WorkflowContext
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
logger = logging.getLogger(__name__)
@@ -210,11 +217,12 @@ class BaseGroupChatOrchestrator(Executor, ABC):
# State persistence (shared across all patterns)
def snapshot_state(self) -> dict[str, Any]:
@override
async def on_checkpoint_save(self) -> dict[str, Any]:
"""Capture current orchestrator state for checkpointing.
Default implementation uses OrchestrationState to serialize common state.
Subclasses should override _snapshot_pattern_metadata() to add pattern-specific data.
Subclasses can override this method or _snapshot_pattern_metadata() to add pattern-specific data.
Returns:
Serialized state dict
@@ -238,11 +246,12 @@ class BaseGroupChatOrchestrator(Executor, ABC):
"""
return {}
def restore_state(self, state: dict[str, Any]) -> None:
@override
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
"""Restore orchestrator state from checkpoint.
Default implementation uses OrchestrationState to deserialize common state.
Subclasses should override _restore_pattern_metadata() to restore pattern-specific data.
Subclasses can override this method or _restore_pattern_metadata() to restore pattern-specific data.
Args:
state: Serialized state dict
@@ -6,9 +6,7 @@ These utilities operate on standard `list[ChatMessage]` collections and simple
dictionary snapshots so orchestrators can share logic without new mixins.
"""
import json
from collections.abc import Mapping, Sequence
from typing import Any
from collections.abc import Sequence
from .._types import ChatMessage
@@ -26,25 +24,3 @@ def ensure_author(message: ChatMessage, fallback: str) -> ChatMessage:
"""Attach `fallback` author if message is missing `author_name`."""
message.author_name = message.author_name or fallback
return message
def snapshot_state(conversation: Sequence[ChatMessage]) -> dict[str, Any]:
"""Build an immutable snapshot for checkpoint storage."""
if hasattr(conversation, "to_dict"):
result = conversation.to_dict() # type: ignore[attr-defined]
if isinstance(result, dict):
return result # type: ignore[return-value]
if isinstance(result, Mapping):
return dict(result) # type: ignore[arg-type]
serialisable: list[dict[str, Any]] = []
for message in conversation:
if hasattr(message, "to_dict") and callable(message.to_dict): # type: ignore[attr-defined]
msg_dict = message.to_dict() # type: ignore[attr-defined]
serialisable.append(dict(msg_dict) if isinstance(msg_dict, Mapping) else msg_dict) # type: ignore[arg-type]
elif hasattr(message, "to_json") and callable(message.to_json): # type: ignore[attr-defined]
json_payload = message.to_json() # type: ignore[attr-defined]
parsed = json.loads(json_payload) if isinstance(json_payload, str) else json_payload
serialisable.append(dict(parsed) if isinstance(parsed, Mapping) else parsed) # type: ignore[arg-type]
else:
serialisable.append(dict(getattr(message, "__dict__", {}))) # type: ignore[arg-type]
return {"messages": serialisable}
@@ -294,6 +294,36 @@ class WorkflowOutputEvent(WorkflowEvent):
return f"{self.__class__.__name__}(data={self.data}, source_executor_id={self.source_executor_id})"
class SuperStepEvent(WorkflowEvent):
    """Base event marking a superstep boundary (start or end) of a workflow run."""

    def __init__(self, iteration: int, data: Any | None = None):
        """Create a superstep event.

        Args:
            iteration: 1-based number of the superstep this event refers to.
            data: Optional payload handed to the base event initializer.
        """
        super().__init__(data)
        self.iteration = iteration

    def __repr__(self) -> str:
        """Return a debug representation with the iteration number and payload."""
        event_name = type(self).__name__
        return f"{event_name}(iteration={self.iteration}, data={self.data})"
class SuperStepStartedEvent(SuperStepEvent):
    """Superstep event emitted at the beginning of a superstep.

    Adds no behavior of its own; the concrete type is what distinguishes a
    start notification from a completion notification.
    """
class SuperStepCompletedEvent(SuperStepEvent):
    """Superstep event emitted at the end of a superstep.

    Adds no behavior of its own; the concrete type is what distinguishes a
    completion notification from a start notification.
    """
class ExecutorEvent(WorkflowEvent):
"""Base class for executor events."""
@@ -310,17 +340,13 @@ class ExecutorEvent(WorkflowEvent):
class ExecutorInvokedEvent(ExecutorEvent):
"""Event triggered when an executor handler is invoked."""
def __repr__(self) -> str:
"""Return a string representation of the executor handler invoke event."""
return f"{self.__class__.__name__}(executor_id={self.executor_id}, data={self.data})"
...
class ExecutorCompletedEvent(ExecutorEvent):
"""Event triggered when an executor handler is completed."""
def __repr__(self) -> str:
"""Return a string representation of the executor handler complete event."""
return f"{self.__class__.__name__}(executor_id={self.executor_id}, data={self.data})"
...
class ExecutorFailedEvent(ExecutorEvent):
@@ -155,6 +155,11 @@ class Executor(RequestInfoMixin, DictConvertible):
that parent workflows can intercept. See WorkflowExecutor documentation for details on
workflow composition patterns and request/response handling.
## State Management
Executors can contain states that persist across workflow runs and checkpoints. Override the
`on_checkpoint_save` and `on_checkpoint_restore` methods to implement custom state
serialization and restoration logic.
## Implementation Notes
- Do not call `execute()` directly - it's invoked by the workflow engine
- Do not override `execute()` - define handlers using decorators instead
@@ -460,6 +465,32 @@ class Executor(RequestInfoMixin, DictConvertible):
return self._handlers[message_type]
raise RuntimeError(f"Executor {self.__class__.__name__} cannot handle message of type {type(message)}.")
async def on_checkpoint_save(self) -> dict[str, Any]:
    """Produce the executor state that should be written into a checkpoint.

    The default implementation has nothing to persist and returns an empty
    dictionary. Subclasses override this hook to return their own state.

    Whatever is returned here is later handed back to `on_checkpoint_restore`
    when the workflow is restored from the checkpoint, so the dictionary
    should only contain JSON-serializable data.

    Returns:
        A state dictionary to be saved during checkpointing.
    """
    empty_state: dict[str, Any] = {}
    return empty_state
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
    """Apply previously checkpointed state to this executor.

    The default implementation ignores `state` and does nothing. Subclasses
    override this hook to run custom logic when the workflow is restored
    from a checkpoint.

    Args:
        state: The state dictionary that was saved during checkpointing.
    """
    return None
# endregion: Executor
@@ -16,6 +16,7 @@ Key properties:
import logging
import re
import sys
from collections.abc import Awaitable, Callable, Mapping, Sequence
from dataclasses import dataclass, field
from typing import Any
@@ -50,6 +51,12 @@ from ._workflow import Workflow
from ._workflow_builder import WorkflowBuilder
from ._workflow_context import WorkflowContext
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
logger = logging.getLogger(__name__)
@@ -307,15 +314,6 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
ctx: WorkflowContext[AgentExecutorRequest | list[ChatMessage], list[ChatMessage] | _ConversationForUserInput],
) -> None:
"""Process an agent's response and determine whether to route, request input, or terminate."""
# Hydrate coordinator state (and detect new run) using checkpointable executor state
state = await ctx.get_executor_state()
if not state:
self._clear_conversation()
elif not self._get_conversation():
restored = self._restore_conversation_from_state(state)
if restored:
self._conversation = list(restored)
source = ctx.get_source_executor_id()
is_starting_agent = source == self._starting_agent_id
@@ -343,7 +341,7 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
# Update current agent when handoff occurs
self._current_agent_id = target
logger.info(f"Handoff detected: {source} -> {target}. Routing control to specialist '{target}'.")
await self._persist_state(ctx)
# Clean tool-related content before sending to next agent
cleaned = clean_conversation_for_handoff(conversation)
request = AgentExecutorRequest(messages=cleaned, should_respond=True)
@@ -360,7 +358,6 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
f"Agent '{source}' responded without handoff. "
f"Requesting user input. Return-to-previous: {self._return_to_previous}"
)
await self._persist_state(ctx)
if await self._check_termination():
# Clean the output conversation for display
@@ -388,7 +385,6 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
"""Receive full conversation with new user input from gateway, update history, trim for agent."""
# Update authoritative conversation
self._conversation = list(message.full_conversation)
await self._persist_state(ctx)
# Check termination before sending to agent
if await self._check_termination():
@@ -473,11 +469,7 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
)
return list(conversation)
async def _persist_state(self, ctx: WorkflowContext[Any, Any]) -> None:
"""Store authoritative conversation snapshot without losing rich metadata."""
state_payload = self.snapshot_state()
await ctx.set_executor_state(state_payload)
@override
def _snapshot_pattern_metadata(self) -> dict[str, Any]:
"""Serialize pattern-specific state.
@@ -492,6 +484,7 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
}
return {}
@override
def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None:
"""Restore pattern-specific state.
@@ -503,17 +496,6 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
if self._return_to_previous and "current_agent_id" in metadata:
self._current_agent_id = metadata["current_agent_id"]
def _restore_conversation_from_state(self, state: Mapping[str, Any]) -> list[ChatMessage]:
"""Rehydrate the coordinator's conversation history from checkpointed state.
DEPRECATED: Use restore_state() instead. Kept for backward compatibility.
"""
from ._orchestration_state import OrchestrationState
orch_state_dict = {"conversation": state.get("full_conversation", state.get("conversation", []))}
temp_state = OrchestrationState.from_dict(orch_state_dict)
return list(temp_state.conversation)
def _apply_response_metadata(self, conversation: list[ChatMessage], agent_response: AgentRunResponse) -> None:
"""Merge top-level response metadata into the latest assistant message."""
if not agent_response.additional_properties:
@@ -45,9 +45,15 @@ from ._workflow import Workflow, WorkflowRunResult
from ._workflow_context import WorkflowContext
if sys.version_info >= (3, 11):
from typing import Self # pragma: no cover
from typing import Self
else:
from typing_extensions import Self # pragma: no cover
from typing_extensions import Self
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
logger = logging.getLogger(__name__)
@@ -673,11 +679,11 @@ class MagenticManagerBase(ABC):
"""Prepare the final answer."""
...
def snapshot_state(self) -> dict[str, Any]:
def on_checkpoint_save(self) -> dict[str, Any]:
"""Serialize runtime state for checkpointing."""
return {}
def restore_state(self, state: dict[str, Any]) -> None:
def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
"""Restore runtime state from checkpoint data."""
return
@@ -695,22 +701,6 @@ class StandardMagenticManager(MagenticManagerBase):
task_ledger: _MagenticTaskLedger | None
def snapshot_state(self) -> dict[str, Any]:
state = super().snapshot_state()
if self.task_ledger is not None:
state = dict(state)
state["task_ledger"] = self.task_ledger.to_dict()
return state
def restore_state(self, state: dict[str, Any]) -> None:
super().restore_state(state)
ledger = state.get("task_ledger")
if ledger is not None:
try:
self.task_ledger = _MagenticTaskLedger.from_dict(ledger)
except Exception: # pragma: no cover - defensive
logger.warning("Failed to restore manager task ledger from checkpoint state")
def __init__(
self,
chat_client: ChatClientProtocol,
@@ -940,6 +930,22 @@ class StandardMagenticManager(MagenticManagerBase):
author_name=response.author_name or MAGENTIC_MANAGER_NAME,
)
@override
def on_checkpoint_save(self) -> dict[str, Any]:
    """Serialize the manager's runtime state for checkpointing.

    Returns:
        A dict with a ``"task_ledger"`` entry holding the ledger's ``to_dict()``
        output when a task ledger is present, otherwise an empty dict.
    """
    if self.task_ledger is None:
        return {}
    return {"task_ledger": self.task_ledger.to_dict()}
@override
def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
    """Restore the manager's runtime state from checkpoint data.

    Restoration is best-effort: if the ``"task_ledger"`` entry cannot be
    rebuilt, the failure is logged (including its cause) and the current
    ``task_ledger`` is left unchanged rather than aborting the restore.

    Args:
        state: The dictionary previously produced by ``on_checkpoint_save``.
            A missing or ``None`` ``"task_ledger"`` entry is ignored.
    """
    ledger = state.get("task_ledger")
    if ledger is None:
        return
    try:
        self.task_ledger = _MagenticTaskLedger.from_dict(ledger)
    except Exception as exc:  # pragma: no cover - defensive
        # Surface the cause so a corrupt ledger payload can be diagnosed.
        logger.warning("Failed to restore manager task ledger from checkpoint state: %s", exc)
# endregion Magentic Manager
@@ -997,7 +1003,6 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
# Terminal state marker to stop further processing after completion/limits
self._terminated = False
# Tracks whether checkpoint state has been applied for this run
self._state_restored = False
def _get_author_name(self) -> str:
"""Get the magentic manager name for orchestrator-generated messages."""
@@ -1036,7 +1041,8 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
)
await ctx.add_event(event)
def snapshot_state(self) -> dict[str, Any]:
@override
async def on_checkpoint_save(self) -> dict[str, Any]:
"""Capture current orchestrator state for checkpointing.
Uses OrchestrationState for structure but maintains Magentic's complex metadata
@@ -1055,14 +1061,16 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
state["magentic_context"] = self._context.to_dict()
if self._task_ledger is not None:
state["task_ledger"] = _message_to_payload(self._task_ledger)
manager_state: dict[str, Any] | None = None
with contextlib.suppress(Exception):
manager_state = self._manager.snapshot_state()
if manager_state:
state["manager_state"] = manager_state
try:
state["manager_state"] = self._manager.on_checkpoint_save()
except Exception as exc:
logger.warning("Failed to save manager state for checkpoint: %s\nSkipping...", exc)
return state
def restore_state(self, state: dict[str, Any]) -> None:
@override
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
"""Restore orchestrator state from checkpoint.
Maintains backward compatibility with existing Magentic checkpoints
@@ -1112,7 +1120,7 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
manager_state = state.get("manager_state")
if manager_state is not None:
try:
self._manager.restore_state(manager_state)
self._manager.on_checkpoint_restore(manager_state)
except Exception as exc: # pragma: no cover
logger.warning("Failed to restore manager state: %s", exc)
@@ -1142,49 +1150,6 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
for name, description in expected.items():
restored[name] = description
def _snapshot_pattern_metadata(self) -> dict[str, Any]:
"""Serialize pattern-specific state.
Magentic uses custom snapshot_state() instead of base class hooks.
This method exists to satisfy the base class contract.
Returns:
Empty dict (Magentic manages its own state)
"""
return {}
def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None:
"""Restore pattern-specific state.
Magentic uses custom restore_state() instead of base class hooks.
This method exists to satisfy the base class contract.
Args:
metadata: Pattern-specific state dict (ignored)
"""
pass
async def _ensure_state_restored(
self,
context: WorkflowContext[Any, Any],
) -> None:
if self._state_restored and self._context is not None:
return
state = await context.get_executor_state()
if not state:
self._state_restored = True
return
if not isinstance(state, dict):
self._state_restored = True
return
try:
self.restore_state(state)
except Exception as exc: # pragma: no cover
logger.warning("Magentic Orchestrator: Failed to apply checkpoint state: %s", exc, exc_info=True)
raise
else:
self._state_restored = True
@handler
async def handle_start_message(
self,
@@ -1204,7 +1169,7 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
)
if message.messages:
self._context.chat_history.extend(message.messages)
self._state_restored = True
# Non-streaming callback for the orchestrator receipt of the task
await self._emit_orchestrator_message(context, message.task, ORCH_MSG_KIND_USER_TASK)
@@ -1269,7 +1234,7 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
"""Handle responses from agents."""
if getattr(self, "_terminated", False):
return
await self._ensure_state_restored(context)
if self._context is None:
raise RuntimeError("Magentic Orchestrator: Received response but not initialized")
@@ -1301,7 +1266,7 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
) -> None:
if getattr(self, "_terminated", False):
return
await self._ensure_state_restored(context)
if self._context is None:
return
@@ -1636,9 +1601,9 @@ class MagenticAgentExecutor(Executor):
self._agent = agent
self._agent_id = agent_id
self._chat_history: list[ChatMessage] = []
self._state_restored = False
def snapshot_state(self) -> dict[str, Any]:
@override
async def on_checkpoint_save(self) -> dict[str, Any]:
"""Capture current executor state for checkpointing.
Returns:
@@ -1650,7 +1615,8 @@ class MagenticAgentExecutor(Executor):
"chat_history": encode_chat_messages(self._chat_history),
}
def restore_state(self, state: dict[str, Any]) -> None:
@override
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
"""Restore executor state from checkpoint.
Args:
@@ -1668,24 +1634,6 @@ class MagenticAgentExecutor(Executor):
else:
self._chat_history = []
async def _ensure_state_restored(self, context: WorkflowContext[Any, Any]) -> None:
if self._state_restored and self._chat_history:
return
state = await context.get_executor_state()
if not state:
self._state_restored = True
return
if not isinstance(state, dict):
self._state_restored = True
return
try:
self.restore_state(state)
except Exception as exc: # pragma: no cover
logger.warning("Agent %s: Failed to apply checkpoint state: %s", self._agent_id, exc, exc_info=True)
raise
else:
self._state_restored = True
@handler
async def handle_response_message(
self, message: _MagenticResponseMessage, context: WorkflowContext[_MagenticResponseMessage]
@@ -1693,8 +1641,6 @@ class MagenticAgentExecutor(Executor):
"""Handle response message (task ledger broadcast)."""
logger.debug("Agent %s: Received response message", self._agent_id)
await self._ensure_state_restored(context)
# Check if this message is intended for this agent
if message.target_agent is not None and message.target_agent != self._agent_id and not message.broadcast:
# Message is targeted to a different agent, ignore it
@@ -1735,8 +1681,6 @@ class MagenticAgentExecutor(Executor):
logger.info("Agent %s: Received request to respond", self._agent_id)
await self._ensure_state_restored(context)
# Add persona adoption message with appropriate role
persona_role = self._get_persona_adoption_role()
persona_msg = ChatMessage(
@@ -1783,7 +1727,6 @@ class MagenticAgentExecutor(Executor):
"""Reset the internal chat history of the agent (internal operation)."""
logger.debug("Agent %s: Resetting chat history", self._agent_id)
self._chat_history.clear()
self._state_restored = True
async def _emit_agent_delta_event(
self,
@@ -11,7 +11,7 @@ from ._checkpoint_encoding import DATACLASS_MARKER, MODEL_MARKER, decode_checkpo
from ._const import EXECUTOR_STATE_KEY
from ._edge import EdgeGroup
from ._edge_runner import EdgeRunner, create_edge_runner
from ._events import WorkflowEvent
from ._events import SuperStepCompletedEvent, SuperStepStartedEvent, WorkflowEvent
from ._executor import Executor
from ._runner_context import (
Message,
@@ -92,6 +92,7 @@ class Runner:
while self._iteration < self._max_iterations:
logger.info(f"Starting superstep {self._iteration + 1}")
yield SuperStepStartedEvent(iteration=self._iteration + 1)
# Run iteration concurrently with live event streaming: we poll
# for new events while the iteration coroutine progresses.
@@ -126,6 +127,9 @@ class Runner:
# Create checkpoint after each superstep iteration
await self._create_checkpoint_if_enabled(f"superstep_{self._iteration}")
yield SuperStepCompletedEvent(iteration=self._iteration)
# Check for convergence: no more messages to process
if not await self._ctx.has_messages():
break
@@ -183,8 +187,8 @@ class Runner:
return None
try:
# Auto-snapshot executor states
await self._auto_snapshot_executor_states()
# Snapshot executor states
await self._save_executor_states()
checkpoint_category = "initial" if checkpoint_type == "after_initial_execution" else "superstep"
metadata = {
"superstep": self._iteration,
@@ -203,41 +207,6 @@ class Runner:
logger.warning(f"Failed to create {checkpoint_type} checkpoint: {e}")
return None
async def _auto_snapshot_executor_states(self) -> None:
"""Populate executor state by calling snapshot hooks on executors if available.
TODO(@taochen#1614): this method is potentially problematic if executors also call
set_executor_state on the context directly. We should clarify the intended usage
pattern for executor state management.
Convention:
- If an executor defines an async or sync method `snapshot_state(self) -> dict`, use it.
- Else if it has a plain attribute `state` that is a dict, use that.
Only JSON-serializable dicts should be provided by executors.
"""
for exec_id, executor in self._executors.items():
state_dict: dict[str, Any] | None = None
snapshot = getattr(executor, "snapshot_state", None)
try:
if callable(snapshot):
maybe = snapshot()
if asyncio.iscoroutine(maybe): # type: ignore[arg-type]
maybe = await maybe # type: ignore[assignment]
if isinstance(maybe, dict):
state_dict = maybe # type: ignore[assignment]
else:
state_attr = getattr(executor, "state", None)
if isinstance(state_attr, dict):
state_dict = state_attr # type: ignore[assignment]
except Exception as ex: # pragma: no cover
logger.debug(f"Executor {exec_id} snapshot_state failed: {ex}")
if state_dict is not None:
try:
await self._set_executor_state(exec_id, state_dict)
except Exception as ex: # pragma: no cover
logger.debug(f"Failed to persist state for executor {exec_id}: {ex}")
async def restore_from_checkpoint(
self,
checkpoint_id: str,
@@ -300,7 +269,65 @@ class Runner:
logger.error(f"Failed to restore from checkpoint {checkpoint_id}: {e}")
return False
async def _save_executor_states(self) -> None:
    """Collect state from every executor and store it for checkpointing.

    For each executor the state is resolved in this order:

    1. Backward compatibility: if the executor exposes a callable
       `snapshot_state` (sync or async) and it returns a dict, use that.
       If there is no callable `snapshot_state` but the executor has a plain
       `state` attribute that is a dict, use that instead. Failures on this
       path are logged at debug level and do not abort the save.
    2. Updated behavior: if step 1 yielded nothing, call the executor's
       `on_checkpoint_save(self) -> dict` hook. A failure here is re-raised
       as `ValueError`.

    The resolved state is then persisted via `_set_executor_state`; a failure
    to persist is logged at debug level only.

    Only JSON-serializable dicts should be provided by executors.

    Raises:
        ValueError: If an executor's `on_checkpoint_save` hook raises.
    """
    for exec_id, executor in self._executors.items():
        state_dict: dict[str, Any] | None = None

        # Legacy path first.
        # TODO(@taochen): Remove backward compatibility
        legacy_hook = getattr(executor, "snapshot_state", None)
        try:
            if not callable(legacy_hook):
                legacy_attr = getattr(executor, "state", None)
                if isinstance(legacy_attr, dict):
                    state_dict = legacy_attr  # type: ignore[assignment]
            else:
                outcome = legacy_hook()
                if asyncio.iscoroutine(outcome):  # type: ignore[arg-type]
                    outcome = await outcome  # type: ignore[assignment]
                if isinstance(outcome, dict):
                    state_dict = outcome  # type: ignore[assignment]
        except Exception as ex:  # pragma: no cover
            logger.debug(f"Executor {exec_id} snapshot_state failed: {ex}")

        if state_dict is None:
            # The legacy path produced nothing, so use the checkpoint hook.
            try:
                state_dict = await executor.on_checkpoint_save()
            except Exception as ex:  # pragma: no cover
                raise ValueError(f"Executor {exec_id} on_checkpoint_save failed: {ex}") from ex

        try:
            await self._set_executor_state(exec_id, state_dict)
        except Exception as ex:  # pragma: no cover
            logger.debug(f"Failed to persist state for executor {exec_id}: {ex}")
async def _restore_executor_states(self) -> None:
"""Restore executor state by calling restore hooks on executors.
Backward compatibility behavior:
- If an executor defines an async or sync method `restore_state(self, state: dict)`, use it.
- Else, skip restoration for that executor.
Updated behavior:
- Executors should implement `on_checkpoint_restore(self, state: dict)` to restore state.
This method will try the backward compatibility behavior first; if that does not restore state,
it falls back to the updated behavior.
"""
has_executor_states = await self._shared_state.has(EXECUTOR_STATE_KEY)
if not has_executor_states:
return
@@ -309,16 +336,18 @@ class Runner:
if not isinstance(executor_states, dict):
raise ValueError("Executor states in shared state is not a dictionary. Unable to restore.")
for executor_id, state in executor_states.items():
for executor_id, state in executor_states.items(): # pyright: ignore[reportUnknownVariableType]
if not isinstance(executor_id, str):
raise ValueError("Executor ID in executor states is not a string. Unable to restore.")
if not isinstance(state, dict):
raise ValueError(f"Executor state for {executor_id} is not a dictionary. Unable to restore.")
if not isinstance(state, dict) or not all(isinstance(k, str) for k in state): # pyright: ignore[reportUnknownVariableType]
raise ValueError(f"Executor state for {executor_id} is not a dict[str, Any]. Unable to restore.")
executor = self._executors.get(executor_id)
if not executor:
raise ValueError(f"Executor {executor_id} not found during state restoration.")
# Try backward compatibility behavior first
# TODO(@taochen): Remove backward compatibility
restored = False
restore_method = getattr(executor, "restore_state", None)
try:
@@ -330,6 +359,14 @@ class Runner:
except Exception as ex: # pragma: no cover - defensive
raise ValueError(f"Executor {executor_id} restore_state failed: {ex}") from ex
if not restored:
# Try the updated behavior only if backward compatibility did not restore
try:
await executor.on_checkpoint_restore(state) # pyright: ignore[reportUnknownArgumentType]
restored = True
except Exception as ex: # pragma: no cover - defensive
raise ValueError(f"Executor {executor_id} on_checkpoint_restore failed: {ex}") from ex
if not restored:
logger.debug(f"Executor {executor_id} does not support state restoration; skipping.")
@@ -109,9 +109,9 @@ class Workflow(DictConvertible):
"""A graph-based execution engine that orchestrates connected executors.
## Overview
A workflow executes a directed graph of executors connected via edge groups using a Pregel-like model,
running in supersteps until the graph becomes idle. Workflows are created using the
WorkflowBuilder class - do not instantiate this class directly.
A workflow executes a directed graph of executors connected via edge groups using a
Pregel-like model, running in supersteps until the graph becomes idle. Workflows
are created using the WorkflowBuilder class - do not instantiate this class directly.
## Execution Model
Executors run in synchronized supersteps where each executor:
@@ -142,6 +142,10 @@ class Workflow(DictConvertible):
- HIL continuation: Provide `responses` to continue after RequestInfoExecutor requests
- Runtime checkpointing: Provide `checkpoint_storage` to enable/override checkpointing for this run
## State Management
Workflow instances contain states and states are preserved across calls to `run` and `run_stream`.
To execute multiple independent runs, create separate Workflow instances via WorkflowBuilder.
## External Input Requests
Executors within a workflow can request external input using `ctx.request_info()`:
1. Executor calls `ctx.request_info()` to request input
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Generic, Union, cast, get_args, get_origi
from opentelemetry.propagate import inject
from opentelemetry.trace import SpanKind
from typing_extensions import Never, TypeVar
from typing_extensions import Never, TypeVar, deprecated
from ..observability import OtelAttr, create_workflow_span
from ._const import EXECUTOR_STATE_KEY
@@ -410,6 +410,11 @@ class WorkflowContext(Generic[T_Out, T_W_Out]):
"""Get the shared state."""
return self._shared_state
@deprecated(
"Override `on_checkpoint_save()` methods instead. "
"For cross-executor state sharing, use set_shared_state() instead. "
"This API will be removed after 12/01/2025."
)
async def set_executor_state(self, state: dict[str, Any]) -> None:
"""Store executor state in shared state under a reserved key.
@@ -428,6 +433,11 @@ class WorkflowContext(Generic[T_Out, T_W_Out]):
existing_states[self._executor_id] = state
await self._shared_state.set(EXECUTOR_STATE_KEY, existing_states)
@deprecated(
"Override `on_checkpoint_restore()` methods instead. "
"For cross-executor state sharing, use get_shared_state() instead. "
"This API will be removed after 12/01/2025."
)
async def get_executor_state(self) -> dict[str, Any] | None:
"""Retrieve previously persisted state for this executor, if any."""
has_existing_states = await self._shared_state.has(EXECUTOR_STATE_KEY)
@@ -1,8 +1,8 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import contextlib
import logging
import sys
import uuid
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
@@ -26,6 +26,12 @@ from ._typing_utils import is_instance_of
from ._workflow import WorkflowRunResult
from ._workflow_context import WorkflowContext
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
logger = logging.getLogger(__name__)
@@ -181,8 +187,7 @@ class WorkflowExecutor(Executor):
# Includes all sub-workflow output types
# Plus SubWorkflowRequestMessage if sub-workflow can make requests
output_types = workflow.output_types + [SubWorkflowRequestMessage] # if applicable
```
output_types = workflow.output_types + [SubWorkflowRequestMessage] # if applicable
## Error Handling
WorkflowExecutor propagates sub-workflow failures:
@@ -221,23 +226,10 @@ class WorkflowExecutor(Executor):
### Important Considerations
**Shared Workflow Instance**: All concurrent executions use the same underlying workflow instance.
For proper isolation, ensure that:
- The wrapped workflow and its executors are stateless
- Executors use WorkflowContext state management instead of instance variables
- Any shared state is managed through WorkflowContext.get_shared_state/set_shared_state
For proper isolation, ensure that the wrapped workflow and its executors are stateless.
.. code-block:: python
# Good: Stateless executor using context state
class StatelessExecutor(Executor):
@handler
async def process(self, data: str, ctx: WorkflowContext[str]) -> None:
# Use context state instead of instance variables
state = await ctx.get_executor_state() or {}
state["processed"] = data
await ctx.set_executor_state(state)
# Avoid: Stateful executor with instance variables
class StatefulExecutor(Executor):
def __init__(self):
@@ -246,23 +238,23 @@ class WorkflowExecutor(Executor):
## Integration with Parent Workflows
Parent workflows can intercept sub-workflow requests:
```python
class ParentExecutor(Executor):
@handler
async def handle_subworkflow_request(
self,
request: SubWorkflowRequestMessage,
ctx: WorkflowContext[SubWorkflowResponseMessage],
) -> None:
# Handle request locally or forward to external source
if self.can_handle_locally(request):
# Send response back to sub-workflow
response = request.create_response(data="local response data")
await ctx.send_message(response, target_id=request.source_executor_id)
else:
# Forward to external handler
await ctx.request_info(request.source_event, response_type=request.source_event.response_type)
```
.. code-block:: python
class ParentExecutor(Executor):
@handler
async def handle_subworkflow_request(
self,
request: SubWorkflowRequestMessage,
ctx: WorkflowContext[SubWorkflowResponseMessage],
) -> None:
# Handle request locally or forward to external source
if self.can_handle_locally(request):
# Send response back to sub-workflow
response = request.create_response(data="local response data")
await ctx.send_message(response, target_id=request.source_executor_id)
else:
# Forward to external handler
await ctx.request_info(request.source_event, response_type=request.source_event.response_type)
## Implementation Notes
- Sub-workflows run to completion before processing their results
@@ -296,7 +288,6 @@ class WorkflowExecutor(Executor):
self._execution_contexts: dict[str, ExecutionContext] = {} # execution_id -> ExecutionContext
# Map request_id to execution_id for response routing
self._request_to_execution: dict[str, str] = {} # request_id -> execution_id
self._state_loaded: bool = False
@property
def input_types(self) -> list[type[Any]]:
@@ -362,8 +353,6 @@ class WorkflowExecutor(Executor):
input_data: The input data to send to the sub-workflow.
ctx: The workflow context from the parent.
"""
await self._ensure_state_loaded(ctx)
# Create execution context for this sub-workflow run
execution_id = str(uuid.uuid4())
execution_context = ExecutionContext(
@@ -405,8 +394,6 @@ class WorkflowExecutor(Executor):
response: The response to a previous request.
ctx: The workflow context.
"""
await self._ensure_state_loaded(ctx)
# Find the execution context for this request
original_request = response.source_event
execution_id = self._request_to_execution.get(original_request.request_id)
@@ -434,8 +421,6 @@ class WorkflowExecutor(Executor):
# Accumulate the response in this execution's context
execution_context.collected_responses[original_request.request_id] = response.data
await self._persist_execution_state(ctx)
# Check if we have all expected responses for this execution
if len(execution_context.collected_responses) < execution_context.expected_response_count:
logger.debug(
@@ -459,25 +444,20 @@ class WorkflowExecutor(Executor):
if not execution_context.pending_requests:
del self._execution_contexts[execution_id]
async def _ensure_state_loaded(self, ctx: WorkflowContext[Any]) -> None:
if self._state_loaded:
return
@override
async def on_checkpoint_save(self) -> dict[str, Any]:
    """Build the checkpoint snapshot for this WorkflowExecutor.

    Returns:
        A dict with two entries:
        - "execution_contexts": each tracked execution id mapped to its
          execution context passed through `encode_checkpoint_value`.
        - "request_to_execution": a shallow copy of the request-id to
          execution-id routing table.
    """
    encoded_contexts: dict[str, Any] = {}
    for execution_id, execution_context in self._execution_contexts.items():
        encoded_contexts[execution_id] = encode_checkpoint_value(execution_context)

    # Copy the routing table so the snapshot is decoupled from later mutations.
    routing_snapshot = dict(self._request_to_execution)

    return {
        "execution_contexts": encoded_contexts,
        "request_to_execution": routing_snapshot,
    }
state: dict[str, Any] | None = None
try:
state = await ctx.get_executor_state()
except Exception:
state = None
if isinstance(state, dict) and state:
with contextlib.suppress(Exception):
await self.restore_state(state)
self._state_loaded = True
else:
self._state_loaded = True
async def restore_state(self, state: dict[str, Any]) -> None:
"""Restore pending request bookkeeping from a checkpoint snapshot."""
@override
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
"""Restore the WorkflowExecutor state from a checkpoint snapshot."""
# Validate the state contains the right keys
if "execution_contexts" not in state:
raise KeyError("Missing 'execution_contexts' in WorkflowExecutor state.")
@@ -529,23 +509,6 @@ class WorkflowExecutor(Executor):
for event in request_info_events
])
self._state_loaded = True
async def _persist_execution_state(self, ctx: WorkflowContext) -> None:
"""Persist the state of the WorkflowExecutor for checkpointing purposes."""
state = {
"execution_contexts": {
execution_id: encode_checkpoint_value(execution_context)
for execution_id, execution_context in self._execution_contexts.items()
},
"request_to_execution": dict(self._request_to_execution),
}
try:
await ctx.set_executor_state(state)
except Exception as exc: # pragma: no cover - transport specific
logger.warning(f"WorkflowExecutor {self.id} failed to persist state: {exc}")
async def _process_workflow_result(
self,
result: WorkflowRunResult,
@@ -635,5 +598,3 @@ class WorkflowExecutor(Executor):
)
else:
raise RuntimeError(f"Unexpected workflow run state: {workflow_run_state}")
await self._persist_execution_state(ctx)
@@ -158,8 +158,8 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
assert thread_messages[1].text == "Initial response 1"
async def test_agent_executor_snapshot_and_restore_state_directly() -> None:
"""Test AgentExecutor's snapshot_state and restore_state methods directly."""
async def test_agent_executor_save_and_restore_state_directly() -> None:
"""Test AgentExecutor's on_checkpoint_save and on_checkpoint_restore methods directly."""
# Create agent with thread containing messages
agent = _CountingAgent(id="direct_test_agent", name="DirectTestAgent")
thread = AgentThread(message_store=ChatMessageStore())
@@ -182,7 +182,7 @@ async def test_agent_executor_snapshot_and_restore_state_directly() -> None:
executor._cache = list(cache_messages) # type: ignore[reportPrivateUsage]
# Snapshot the state
state = await executor.snapshot_state() # type: ignore[reportUnknownMemberType]
state = await executor.on_checkpoint_save()
# Verify snapshot contains both cache and thread
assert "cache" in state
@@ -206,7 +206,7 @@ async def test_agent_executor_snapshot_and_restore_state_directly() -> None:
assert len(initial_thread_msgs) == 0
# Restore state
await new_executor.restore_state(state) # type: ignore[reportUnknownMemberType]
await new_executor.on_checkpoint_restore(state)
# Verify cache is restored
restored_cache = new_executor._cache # type: ignore[reportPrivateUsage]
@@ -288,57 +288,6 @@ def test_build_fails_without_participants():
HandoffBuilder().build()
async def test_multiple_runs_dont_leak_conversation():
"""Verify that running the same workflow multiple times doesn't leak conversation history."""
triage = _RecordingAgent(name="triage", handoff_to="specialist")
specialist = _RecordingAgent(name="specialist")
workflow = (
HandoffBuilder(participants=[triage, specialist])
.set_coordinator("triage")
.with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 2)
.build()
)
# First run
events = await _drain(workflow.run_stream("First run message"))
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
assert requests
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Second message"}))
outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
assert outputs, "First run should emit output"
first_run_conversation = outputs[-1].data
assert isinstance(first_run_conversation, list)
first_run_conv_list = cast(list[ChatMessage], first_run_conversation)
first_run_user_messages = [msg for msg in first_run_conv_list if msg.role == Role.USER]
assert len(first_run_user_messages) == 2
assert any("First run message" in msg.text for msg in first_run_user_messages if msg.text)
# Second run - should start fresh, not include first run's messages
triage.calls.clear()
specialist.calls.clear()
events = await _drain(workflow.run_stream("Second run different message"))
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
assert requests
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Another message"}))
outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
assert outputs, "Second run should emit output"
second_run_conversation = outputs[-1].data
assert isinstance(second_run_conversation, list)
second_run_conv_list = cast(list[ChatMessage], second_run_conversation)
second_run_user_messages = [msg for msg in second_run_conv_list if msg.role == Role.USER]
assert len(second_run_user_messages) == 2, (
"Second run should have exactly 2 user messages, not accumulate first run"
)
assert any("Second run different message" in msg.text for msg in second_run_user_messages if msg.text)
assert not any("First run message" in msg.text for msg in second_run_user_messages if msg.text), (
"Second run should NOT contain first run's messages"
)
async def test_handoff_async_termination_condition() -> None:
"""Test that async termination conditions work correctly."""
termination_call_count = 0
@@ -585,7 +534,7 @@ async def test_return_to_previous_state_serialization():
coordinator._current_agent_id = "specialist_a" # type: ignore[reportPrivateUsage]
# Snapshot the state
state = coordinator.snapshot_state()
state = await coordinator.on_checkpoint_save()
# Verify pattern metadata includes current_agent_id
assert "metadata" in state
@@ -603,7 +552,7 @@ async def test_return_to_previous_state_serialization():
)
# Restore state
coordinator2.restore_state(state)
await coordinator2.on_checkpoint_restore(state)
# Verify current_agent_id was restored
assert coordinator2._current_agent_id == "specialist_a", "Current agent should be restored from checkpoint" # type: ignore[reportPrivateUsage]
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.
import sys
from collections.abc import AsyncIterable
from dataclasses import dataclass
from typing import Any, cast
@@ -42,6 +43,11 @@ from agent_framework._workflows._magentic import ( # type: ignore[reportPrivate
_MagenticStartMessage, # type: ignore
)
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
def test_magentic_start_message_from_string():
msg = _MagenticStartMessage.from_string("Do the thing")
@@ -101,8 +107,9 @@ class FakeManager(MagenticManagerBase):
next_speaker_name: str = "agentA"
instruction_text: str = "Proceed with step 1"
def snapshot_state(self) -> dict[str, Any]:
state = super().snapshot_state()
@override
def on_checkpoint_save(self) -> dict[str, Any]:
state = super().on_checkpoint_save()
if self.task_ledger is not None:
state = dict(state)
state["task_ledger"] = {
@@ -111,8 +118,9 @@ class FakeManager(MagenticManagerBase):
}
return state
def restore_state(self, state: dict[str, Any]) -> None:
super().restore_state(state)
@override
def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
super().on_checkpoint_restore(state)
ledger_state = state.get("task_ledger")
if isinstance(ledger_state, dict):
ledger_dict = cast(dict[str, Any], ledger_state)
@@ -185,7 +193,6 @@ async def test_standard_manager_progress_ledger_and_fallback():
assert ledger2.is_request_satisfied.answer is False
@pytest.mark.skip(reason="Response handling refactored - responses no longer passed to run_stream()")
async def test_magentic_workflow_plan_review_approval_to_completion():
manager = FakeManager(max_round_count=10)
wf = (
@@ -204,7 +211,7 @@ async def test_magentic_workflow_plan_review_approval_to_completion():
completed = False
output: ChatMessage | None = None
async for ev in wf.run_stream(
async for ev in wf.send_responses_streaming(
responses={req_event.request_id: MagenticPlanReviewReply(decision=MagenticPlanReviewDecision.APPROVE)}
):
if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE:
@@ -218,7 +225,6 @@ async def test_magentic_workflow_plan_review_approval_to_completion():
assert isinstance(output, ChatMessage)
@pytest.mark.skip(reason="Response handling refactored - responses no longer passed to run_stream()")
async def test_magentic_plan_review_approve_with_comments_replans_and_proceeds():
class CountingManager(FakeManager):
# Declare as a model field so assignment is allowed under Pydantic
@@ -250,7 +256,7 @@ async def test_magentic_plan_review_approve_with_comments_replans_and_proceeds()
# Reply APPROVE with comments (no edited text). Expect one replan and no second review round.
saw_second_review = False
completed = False
async for ev in wf.run_stream(
async for ev in wf.send_responses_streaming(
responses={
req_event.request_id: MagenticPlanReviewReply(
decision=MagenticPlanReviewDecision.APPROVE,
@@ -298,7 +304,6 @@ async def test_magentic_orchestrator_round_limit_produces_partial_result():
assert data.role == Role.ASSISTANT
@pytest.mark.skip(reason="Response handling refactored - send_responses_streaming no longer exists")
async def test_magentic_checkpoint_resume_round_trip():
storage = InMemoryCheckpointStorage()
@@ -369,7 +374,7 @@ class _DummyExec(Executor):
pass
def test_magentic_agent_executor_snapshot_roundtrip():
async def test_magentic_agent_executor_on_checkpoint_save_and_restore_roundtrip():
backing_executor = _DummyExec("backing")
agent_exec = MagenticAgentExecutor(backing_executor, "agentA")
agent_exec._chat_history.extend([ # type: ignore[reportPrivateUsage]
@@ -377,10 +382,10 @@ def test_magentic_agent_executor_snapshot_roundtrip():
ChatMessage(role=Role.ASSISTANT, text="world", author_name="agentA"),
])
state = agent_exec.snapshot_state()
state = await agent_exec.on_checkpoint_save()
restored_executor = MagenticAgentExecutor(_DummyExec("backing2"), "agentA")
restored_executor.restore_state(state)
await restored_executor.on_checkpoint_restore(state)
assert len(restored_executor._chat_history) == 2 # type: ignore[reportPrivateUsage]
assert restored_executor._chat_history[0].text == "hello" # type: ignore[reportPrivateUsage]
@@ -199,7 +199,10 @@ async def test_fan_out():
# Each executor will emit two events: ExecutorInvokedEvent and ExecutorCompletedEvent
# executor_b will also emit a WorkflowOutputEvent (no WorkflowCompletedEvent anymore)
assert len(events) == 7
# Each superstep will also emit a SuperStepStartedEvent and a SuperStepCompletedEvent
# This workflow will converge in 2 supersteps because executor_c will send one more message
# after executor_b completes
assert len(events) == 11
assert events.get_final_state() == WorkflowRunState.IDLE
outputs = events.get_outputs()
@@ -220,7 +223,9 @@ async def test_fan_out_multiple_completed_events():
# Each executor will emit two events: ExecutorInvokedEvent and ExecutorCompletedEvent
# executor_b and executor_c will also emit a WorkflowOutputEvent (no WorkflowCompletedEvent anymore)
assert len(events) == 8
# Each superstep will also emit a SuperStepStartedEvent and a SuperStepCompletedEvent
# This workflow will converge in 1 superstep because executor_b and executor_c will not send further messages
assert len(events) == 10
# Multiple outputs are expected from both executors
outputs = events.get_outputs()
@@ -246,7 +251,8 @@ async def test_fan_in():
# Each executor will emit two events: ExecutorInvokedEvent and ExecutorCompletedEvent
# aggregator will also emit a WorkflowOutputEvent (no WorkflowCompletedEvent anymore)
assert len(events) == 9
# Each superstep will also emit a SuperStepStartedEvent and a SuperStepCompletedEvent
assert len(events) == 13
assert events.get_final_state() == WorkflowRunState.IDLE
outputs = events.get_outputs()
@@ -37,10 +37,14 @@ class WorkflowHILRequest:
class WorkflowTestExecutor(Executor):
"""Test executor with HIL."""
def __init__(self, id: str) -> None:
super().__init__(id=id)
self._data_value: str | None = None
@handler
async def process(self, data: WorkflowTestData, ctx: WorkflowContext) -> None:
"""Process data and request approval."""
await ctx.set_executor_state({"data_value": data.value})
self._data_value = data.value
# Request HIL (checkpoint created here)
await ctx.request_info(request_data=WorkflowHILRequest(question=f"Approve {data.value}?"), response_type=str)
@@ -50,8 +54,7 @@ class WorkflowTestExecutor(Executor):
self, original_request: WorkflowHILRequest, response: str, ctx: WorkflowContext[str]
) -> None:
"""Handle HIL response."""
state = await ctx.get_executor_state() or {}
value = state.get("data_value", "")
value = self._data_value or ""
await ctx.send_message(f"{value}_approved" if response.lower() == "yes" else f"{value}_rejected")