mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add checkpoint save and restore hooks to executor (#2097)
* Add checkpoint hooks * Deprecate get_executor_state and set_executor_state * Fix tests and samples * Add doc strings * Add sample * Fix import * Address comments and fix tests * Address comments * conditional import
This commit is contained in:
committed by
GitHub
Unverified
parent
132597957a
commit
c361ad8d33
@@ -37,6 +37,8 @@ from ._events import (
|
||||
ExecutorFailedEvent,
|
||||
ExecutorInvokedEvent,
|
||||
RequestInfoEvent,
|
||||
SuperStepCompletedEvent,
|
||||
SuperStepStartedEvent,
|
||||
WorkflowErrorDetails,
|
||||
WorkflowEvent,
|
||||
WorkflowEventSource,
|
||||
@@ -152,6 +154,8 @@ __all__ = [
|
||||
"StandardMagenticManager",
|
||||
"SubWorkflowRequestMessage",
|
||||
"SubWorkflowResponseMessage",
|
||||
"SuperStepCompletedEvent",
|
||||
"SuperStepStartedEvent",
|
||||
"SwitchCaseEdgeGroup",
|
||||
"SwitchCaseEdgeGroupCase",
|
||||
"SwitchCaseEdgeGroupDefault",
|
||||
|
||||
@@ -35,6 +35,8 @@ from ._events import (
|
||||
ExecutorFailedEvent,
|
||||
ExecutorInvokedEvent,
|
||||
RequestInfoEvent,
|
||||
SuperStepCompletedEvent,
|
||||
SuperStepStartedEvent,
|
||||
WorkflowErrorDetails,
|
||||
WorkflowEvent,
|
||||
WorkflowEventSource,
|
||||
@@ -148,6 +150,8 @@ __all__ = [
|
||||
"StandardMagenticManager",
|
||||
"SubWorkflowRequestMessage",
|
||||
"SubWorkflowResponseMessage",
|
||||
"SuperStepCompletedEvent",
|
||||
"SuperStepStartedEvent",
|
||||
"SwitchCaseEdgeGroup",
|
||||
"SwitchCaseEdgeGroupCase",
|
||||
"SwitchCaseEdgeGroupDefault",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -20,6 +21,11 @@ from ._message_utils import normalize_messages_input
|
||||
from ._request_info_mixin import response_handler
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -179,7 +185,8 @@ class AgentExecutor(Executor):
|
||||
self._pending_responses_to_agent.clear()
|
||||
await self._run_agent_and_emit(ctx)
|
||||
|
||||
async def snapshot_state(self) -> dict[str, Any]:
|
||||
@override
|
||||
async def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
"""Capture current executor state for checkpointing.
|
||||
|
||||
NOTE: if the thread storage is on the server side, the full thread state
|
||||
@@ -196,9 +203,6 @@ class AgentExecutor(Executor):
|
||||
client_module = self._agent.chat_client.__class__.__module__
|
||||
|
||||
if client_class_name == "AzureAIAgentClient" and "azure_ai" in client_module:
|
||||
# TODO(TaoChenOSU): update this warning when we surface the hooks for
|
||||
# custom executor checkpointing.
|
||||
# https://github.com/microsoft/agent-framework/issues/1816
|
||||
logger.warning(
|
||||
"Checkpointing an AgentExecutor with AzureAIAgentClient that uses server-side threads. "
|
||||
"Currently, checkpointing does not capture messages from server-side threads "
|
||||
@@ -217,7 +221,8 @@ class AgentExecutor(Executor):
|
||||
"pending_responses_to_agent": encode_checkpoint_value(self._pending_responses_to_agent),
|
||||
}
|
||||
|
||||
async def restore_state(self, state: dict[str, Any]) -> None:
|
||||
@override
|
||||
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
"""Restore executor state from checkpoint.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from typing import Any
|
||||
@@ -13,6 +14,12 @@ from ._executor import Executor
|
||||
from ._orchestrator_helpers import ParticipantRegistry
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -210,11 +217,12 @@ class BaseGroupChatOrchestrator(Executor, ABC):
|
||||
|
||||
# State persistence (shared across all patterns)
|
||||
|
||||
def snapshot_state(self) -> dict[str, Any]:
|
||||
@override
|
||||
async def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
"""Capture current orchestrator state for checkpointing.
|
||||
|
||||
Default implementation uses OrchestrationState to serialize common state.
|
||||
Subclasses should override _snapshot_pattern_metadata() to add pattern-specific data.
|
||||
Subclasses can override this method or _snapshot_pattern_metadata() to add pattern-specific data.
|
||||
|
||||
Returns:
|
||||
Serialized state dict
|
||||
@@ -238,11 +246,12 @@ class BaseGroupChatOrchestrator(Executor, ABC):
|
||||
"""
|
||||
return {}
|
||||
|
||||
def restore_state(self, state: dict[str, Any]) -> None:
|
||||
@override
|
||||
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
"""Restore orchestrator state from checkpoint.
|
||||
|
||||
Default implementation uses OrchestrationState to deserialize common state.
|
||||
Subclasses should override _restore_pattern_metadata() to restore pattern-specific data.
|
||||
Subclasses can override this method or _restore_pattern_metadata() to restore pattern-specific data.
|
||||
|
||||
Args:
|
||||
state: Serialized state dict
|
||||
|
||||
@@ -6,9 +6,7 @@ These utilities operate on standard `list[ChatMessage]` collections and simple
|
||||
dictionary snapshots so orchestrators can share logic without new mixins.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
from collections.abc import Sequence
|
||||
|
||||
from .._types import ChatMessage
|
||||
|
||||
@@ -26,25 +24,3 @@ def ensure_author(message: ChatMessage, fallback: str) -> ChatMessage:
|
||||
"""Attach `fallback` author if message is missing `author_name`."""
|
||||
message.author_name = message.author_name or fallback
|
||||
return message
|
||||
|
||||
|
||||
def snapshot_state(conversation: Sequence[ChatMessage]) -> dict[str, Any]:
|
||||
"""Build an immutable snapshot for checkpoint storage."""
|
||||
if hasattr(conversation, "to_dict"):
|
||||
result = conversation.to_dict() # type: ignore[attr-defined]
|
||||
if isinstance(result, dict):
|
||||
return result # type: ignore[return-value]
|
||||
if isinstance(result, Mapping):
|
||||
return dict(result) # type: ignore[arg-type]
|
||||
serialisable: list[dict[str, Any]] = []
|
||||
for message in conversation:
|
||||
if hasattr(message, "to_dict") and callable(message.to_dict): # type: ignore[attr-defined]
|
||||
msg_dict = message.to_dict() # type: ignore[attr-defined]
|
||||
serialisable.append(dict(msg_dict) if isinstance(msg_dict, Mapping) else msg_dict) # type: ignore[arg-type]
|
||||
elif hasattr(message, "to_json") and callable(message.to_json): # type: ignore[attr-defined]
|
||||
json_payload = message.to_json() # type: ignore[attr-defined]
|
||||
parsed = json.loads(json_payload) if isinstance(json_payload, str) else json_payload
|
||||
serialisable.append(dict(parsed) if isinstance(parsed, Mapping) else parsed) # type: ignore[arg-type]
|
||||
else:
|
||||
serialisable.append(dict(getattr(message, "__dict__", {}))) # type: ignore[arg-type]
|
||||
return {"messages": serialisable}
|
||||
|
||||
@@ -294,6 +294,36 @@ class WorkflowOutputEvent(WorkflowEvent):
|
||||
return f"{self.__class__.__name__}(data={self.data}, source_executor_id={self.source_executor_id})"
|
||||
|
||||
|
||||
class SuperStepEvent(WorkflowEvent):
|
||||
"""Event triggered when a superstep starts or ends."""
|
||||
|
||||
def __init__(self, iteration: int, data: Any | None = None):
|
||||
"""Initialize the superstep event.
|
||||
|
||||
Args:
|
||||
iteration: The number of the superstep (1-based index).
|
||||
data: Optional data associated with the superstep event.
|
||||
"""
|
||||
super().__init__(data)
|
||||
self.iteration = iteration
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return a string representation of the superstep event."""
|
||||
return f"{self.__class__.__name__}(iteration={self.iteration}, data={self.data})"
|
||||
|
||||
|
||||
class SuperStepStartedEvent(SuperStepEvent):
|
||||
"""Event triggered when a superstep starts."""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class SuperStepCompletedEvent(SuperStepEvent):
|
||||
"""Event triggered when a superstep ends."""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class ExecutorEvent(WorkflowEvent):
|
||||
"""Base class for executor events."""
|
||||
|
||||
@@ -310,17 +340,13 @@ class ExecutorEvent(WorkflowEvent):
|
||||
class ExecutorInvokedEvent(ExecutorEvent):
|
||||
"""Event triggered when an executor handler is invoked."""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return a string representation of the executor handler invoke event."""
|
||||
return f"{self.__class__.__name__}(executor_id={self.executor_id}, data={self.data})"
|
||||
...
|
||||
|
||||
|
||||
class ExecutorCompletedEvent(ExecutorEvent):
|
||||
"""Event triggered when an executor handler is completed."""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return a string representation of the executor handler complete event."""
|
||||
return f"{self.__class__.__name__}(executor_id={self.executor_id}, data={self.data})"
|
||||
...
|
||||
|
||||
|
||||
class ExecutorFailedEvent(ExecutorEvent):
|
||||
|
||||
@@ -155,6 +155,11 @@ class Executor(RequestInfoMixin, DictConvertible):
|
||||
that parent workflows can intercept. See WorkflowExecutor documentation for details on
|
||||
workflow composition patterns and request/response handling.
|
||||
|
||||
## State Management
|
||||
Executors can contain states that persist across workflow runs and checkpoints. Override the
|
||||
`on_checkpoint_save` and `on_checkpoint_restore` methods to implement custom state
|
||||
serialization and restoration logic.
|
||||
|
||||
## Implementation Notes
|
||||
- Do not call `execute()` directly - it's invoked by the workflow engine
|
||||
- Do not override `execute()` - define handlers using decorators instead
|
||||
@@ -460,6 +465,32 @@ class Executor(RequestInfoMixin, DictConvertible):
|
||||
return self._handlers[message_type]
|
||||
raise RuntimeError(f"Executor {self.__class__.__name__} cannot handle message of type {type(message)}.")
|
||||
|
||||
async def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
"""Hook called when the workflow is being saved to a checkpoint.
|
||||
|
||||
Override this method in subclasses to implement custom logic that should
|
||||
return state to be saved in the checkpoint.
|
||||
|
||||
The returned state dictionary will be passed to `on_checkpoint_restore`
|
||||
when the workflow is restored from the checkpoint. The dictionary should
|
||||
only contain JSON-serializable data.
|
||||
|
||||
Returns:
|
||||
A state dictionary to be saved during checkpointing.
|
||||
"""
|
||||
return {}
|
||||
|
||||
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
"""Hook called when the workflow is restored from a checkpoint.
|
||||
|
||||
Override this method in subclasses to implement custom logic that should
|
||||
run when the workflow is restored from a checkpoint.
|
||||
|
||||
Args:
|
||||
state: The state dictionary that was saved during checkpointing.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
# endregion: Executor
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ Key properties:
|
||||
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable, Mapping, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
@@ -50,6 +51,12 @@ from ._workflow import Workflow
|
||||
from ._workflow_builder import WorkflowBuilder
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -307,15 +314,6 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
ctx: WorkflowContext[AgentExecutorRequest | list[ChatMessage], list[ChatMessage] | _ConversationForUserInput],
|
||||
) -> None:
|
||||
"""Process an agent's response and determine whether to route, request input, or terminate."""
|
||||
# Hydrate coordinator state (and detect new run) using checkpointable executor state
|
||||
state = await ctx.get_executor_state()
|
||||
if not state:
|
||||
self._clear_conversation()
|
||||
elif not self._get_conversation():
|
||||
restored = self._restore_conversation_from_state(state)
|
||||
if restored:
|
||||
self._conversation = list(restored)
|
||||
|
||||
source = ctx.get_source_executor_id()
|
||||
is_starting_agent = source == self._starting_agent_id
|
||||
|
||||
@@ -343,7 +341,7 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
# Update current agent when handoff occurs
|
||||
self._current_agent_id = target
|
||||
logger.info(f"Handoff detected: {source} -> {target}. Routing control to specialist '{target}'.")
|
||||
await self._persist_state(ctx)
|
||||
|
||||
# Clean tool-related content before sending to next agent
|
||||
cleaned = clean_conversation_for_handoff(conversation)
|
||||
request = AgentExecutorRequest(messages=cleaned, should_respond=True)
|
||||
@@ -360,7 +358,6 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
f"Agent '{source}' responded without handoff. "
|
||||
f"Requesting user input. Return-to-previous: {self._return_to_previous}"
|
||||
)
|
||||
await self._persist_state(ctx)
|
||||
|
||||
if await self._check_termination():
|
||||
# Clean the output conversation for display
|
||||
@@ -388,7 +385,6 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
"""Receive full conversation with new user input from gateway, update history, trim for agent."""
|
||||
# Update authoritative conversation
|
||||
self._conversation = list(message.full_conversation)
|
||||
await self._persist_state(ctx)
|
||||
|
||||
# Check termination before sending to agent
|
||||
if await self._check_termination():
|
||||
@@ -473,11 +469,7 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
)
|
||||
return list(conversation)
|
||||
|
||||
async def _persist_state(self, ctx: WorkflowContext[Any, Any]) -> None:
|
||||
"""Store authoritative conversation snapshot without losing rich metadata."""
|
||||
state_payload = self.snapshot_state()
|
||||
await ctx.set_executor_state(state_payload)
|
||||
|
||||
@override
|
||||
def _snapshot_pattern_metadata(self) -> dict[str, Any]:
|
||||
"""Serialize pattern-specific state.
|
||||
|
||||
@@ -492,6 +484,7 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
}
|
||||
return {}
|
||||
|
||||
@override
|
||||
def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None:
|
||||
"""Restore pattern-specific state.
|
||||
|
||||
@@ -503,17 +496,6 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
if self._return_to_previous and "current_agent_id" in metadata:
|
||||
self._current_agent_id = metadata["current_agent_id"]
|
||||
|
||||
def _restore_conversation_from_state(self, state: Mapping[str, Any]) -> list[ChatMessage]:
|
||||
"""Rehydrate the coordinator's conversation history from checkpointed state.
|
||||
|
||||
DEPRECATED: Use restore_state() instead. Kept for backward compatibility.
|
||||
"""
|
||||
from ._orchestration_state import OrchestrationState
|
||||
|
||||
orch_state_dict = {"conversation": state.get("full_conversation", state.get("conversation", []))}
|
||||
temp_state = OrchestrationState.from_dict(orch_state_dict)
|
||||
return list(temp_state.conversation)
|
||||
|
||||
def _apply_response_metadata(self, conversation: list[ChatMessage], agent_response: AgentRunResponse) -> None:
|
||||
"""Merge top-level response metadata into the latest assistant message."""
|
||||
if not agent_response.additional_properties:
|
||||
|
||||
@@ -45,9 +45,15 @@ from ._workflow import Workflow, WorkflowRunResult
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self # pragma: no cover
|
||||
from typing import Self
|
||||
else:
|
||||
from typing_extensions import Self # pragma: no cover
|
||||
from typing_extensions import Self
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -673,11 +679,11 @@ class MagenticManagerBase(ABC):
|
||||
"""Prepare the final answer."""
|
||||
...
|
||||
|
||||
def snapshot_state(self) -> dict[str, Any]:
|
||||
def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
"""Serialize runtime state for checkpointing."""
|
||||
return {}
|
||||
|
||||
def restore_state(self, state: dict[str, Any]) -> None:
|
||||
def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
"""Restore runtime state from checkpoint data."""
|
||||
return
|
||||
|
||||
@@ -695,22 +701,6 @@ class StandardMagenticManager(MagenticManagerBase):
|
||||
|
||||
task_ledger: _MagenticTaskLedger | None
|
||||
|
||||
def snapshot_state(self) -> dict[str, Any]:
|
||||
state = super().snapshot_state()
|
||||
if self.task_ledger is not None:
|
||||
state = dict(state)
|
||||
state["task_ledger"] = self.task_ledger.to_dict()
|
||||
return state
|
||||
|
||||
def restore_state(self, state: dict[str, Any]) -> None:
|
||||
super().restore_state(state)
|
||||
ledger = state.get("task_ledger")
|
||||
if ledger is not None:
|
||||
try:
|
||||
self.task_ledger = _MagenticTaskLedger.from_dict(ledger)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
logger.warning("Failed to restore manager task ledger from checkpoint state")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chat_client: ChatClientProtocol,
|
||||
@@ -940,6 +930,22 @@ class StandardMagenticManager(MagenticManagerBase):
|
||||
author_name=response.author_name or MAGENTIC_MANAGER_NAME,
|
||||
)
|
||||
|
||||
@override
|
||||
def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
state: dict[str, Any] = {}
|
||||
if self.task_ledger is not None:
|
||||
state["task_ledger"] = self.task_ledger.to_dict()
|
||||
return state
|
||||
|
||||
@override
|
||||
def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
ledger = state.get("task_ledger")
|
||||
if ledger is not None:
|
||||
try:
|
||||
self.task_ledger = _MagenticTaskLedger.from_dict(ledger)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
logger.warning("Failed to restore manager task ledger from checkpoint state")
|
||||
|
||||
|
||||
# endregion Magentic Manager
|
||||
|
||||
@@ -997,7 +1003,6 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
|
||||
# Terminal state marker to stop further processing after completion/limits
|
||||
self._terminated = False
|
||||
# Tracks whether checkpoint state has been applied for this run
|
||||
self._state_restored = False
|
||||
|
||||
def _get_author_name(self) -> str:
|
||||
"""Get the magentic manager name for orchestrator-generated messages."""
|
||||
@@ -1036,7 +1041,8 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
|
||||
)
|
||||
await ctx.add_event(event)
|
||||
|
||||
def snapshot_state(self) -> dict[str, Any]:
|
||||
@override
|
||||
async def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
"""Capture current orchestrator state for checkpointing.
|
||||
|
||||
Uses OrchestrationState for structure but maintains Magentic's complex metadata
|
||||
@@ -1055,14 +1061,16 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
|
||||
state["magentic_context"] = self._context.to_dict()
|
||||
if self._task_ledger is not None:
|
||||
state["task_ledger"] = _message_to_payload(self._task_ledger)
|
||||
manager_state: dict[str, Any] | None = None
|
||||
with contextlib.suppress(Exception):
|
||||
manager_state = self._manager.snapshot_state()
|
||||
if manager_state:
|
||||
state["manager_state"] = manager_state
|
||||
|
||||
try:
|
||||
state["manager_state"] = self._manager.on_checkpoint_save()
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to save manager state for checkpoint: %s\nSkipping...", exc)
|
||||
|
||||
return state
|
||||
|
||||
def restore_state(self, state: dict[str, Any]) -> None:
|
||||
@override
|
||||
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
"""Restore orchestrator state from checkpoint.
|
||||
|
||||
Maintains backward compatibility with existing Magentic checkpoints
|
||||
@@ -1112,7 +1120,7 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
|
||||
manager_state = state.get("manager_state")
|
||||
if manager_state is not None:
|
||||
try:
|
||||
self._manager.restore_state(manager_state)
|
||||
self._manager.on_checkpoint_restore(manager_state)
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.warning("Failed to restore manager state: %s", exc)
|
||||
|
||||
@@ -1142,49 +1150,6 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
|
||||
for name, description in expected.items():
|
||||
restored[name] = description
|
||||
|
||||
def _snapshot_pattern_metadata(self) -> dict[str, Any]:
|
||||
"""Serialize pattern-specific state.
|
||||
|
||||
Magentic uses custom snapshot_state() instead of base class hooks.
|
||||
This method exists to satisfy the base class contract.
|
||||
|
||||
Returns:
|
||||
Empty dict (Magentic manages its own state)
|
||||
"""
|
||||
return {}
|
||||
|
||||
def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None:
|
||||
"""Restore pattern-specific state.
|
||||
|
||||
Magentic uses custom restore_state() instead of base class hooks.
|
||||
This method exists to satisfy the base class contract.
|
||||
|
||||
Args:
|
||||
metadata: Pattern-specific state dict (ignored)
|
||||
"""
|
||||
pass
|
||||
|
||||
async def _ensure_state_restored(
|
||||
self,
|
||||
context: WorkflowContext[Any, Any],
|
||||
) -> None:
|
||||
if self._state_restored and self._context is not None:
|
||||
return
|
||||
state = await context.get_executor_state()
|
||||
if not state:
|
||||
self._state_restored = True
|
||||
return
|
||||
if not isinstance(state, dict):
|
||||
self._state_restored = True
|
||||
return
|
||||
try:
|
||||
self.restore_state(state)
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.warning("Magentic Orchestrator: Failed to apply checkpoint state: %s", exc, exc_info=True)
|
||||
raise
|
||||
else:
|
||||
self._state_restored = True
|
||||
|
||||
@handler
|
||||
async def handle_start_message(
|
||||
self,
|
||||
@@ -1204,7 +1169,7 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
|
||||
)
|
||||
if message.messages:
|
||||
self._context.chat_history.extend(message.messages)
|
||||
self._state_restored = True
|
||||
|
||||
# Non-streaming callback for the orchestrator receipt of the task
|
||||
await self._emit_orchestrator_message(context, message.task, ORCH_MSG_KIND_USER_TASK)
|
||||
|
||||
@@ -1269,7 +1234,7 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
|
||||
"""Handle responses from agents."""
|
||||
if getattr(self, "_terminated", False):
|
||||
return
|
||||
await self._ensure_state_restored(context)
|
||||
|
||||
if self._context is None:
|
||||
raise RuntimeError("Magentic Orchestrator: Received response but not initialized")
|
||||
|
||||
@@ -1301,7 +1266,7 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
|
||||
) -> None:
|
||||
if getattr(self, "_terminated", False):
|
||||
return
|
||||
await self._ensure_state_restored(context)
|
||||
|
||||
if self._context is None:
|
||||
return
|
||||
|
||||
@@ -1636,9 +1601,9 @@ class MagenticAgentExecutor(Executor):
|
||||
self._agent = agent
|
||||
self._agent_id = agent_id
|
||||
self._chat_history: list[ChatMessage] = []
|
||||
self._state_restored = False
|
||||
|
||||
def snapshot_state(self) -> dict[str, Any]:
|
||||
@override
|
||||
async def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
"""Capture current executor state for checkpointing.
|
||||
|
||||
Returns:
|
||||
@@ -1650,7 +1615,8 @@ class MagenticAgentExecutor(Executor):
|
||||
"chat_history": encode_chat_messages(self._chat_history),
|
||||
}
|
||||
|
||||
def restore_state(self, state: dict[str, Any]) -> None:
|
||||
@override
|
||||
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
"""Restore executor state from checkpoint.
|
||||
|
||||
Args:
|
||||
@@ -1668,24 +1634,6 @@ class MagenticAgentExecutor(Executor):
|
||||
else:
|
||||
self._chat_history = []
|
||||
|
||||
async def _ensure_state_restored(self, context: WorkflowContext[Any, Any]) -> None:
|
||||
if self._state_restored and self._chat_history:
|
||||
return
|
||||
state = await context.get_executor_state()
|
||||
if not state:
|
||||
self._state_restored = True
|
||||
return
|
||||
if not isinstance(state, dict):
|
||||
self._state_restored = True
|
||||
return
|
||||
try:
|
||||
self.restore_state(state)
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.warning("Agent %s: Failed to apply checkpoint state: %s", self._agent_id, exc, exc_info=True)
|
||||
raise
|
||||
else:
|
||||
self._state_restored = True
|
||||
|
||||
@handler
|
||||
async def handle_response_message(
|
||||
self, message: _MagenticResponseMessage, context: WorkflowContext[_MagenticResponseMessage]
|
||||
@@ -1693,8 +1641,6 @@ class MagenticAgentExecutor(Executor):
|
||||
"""Handle response message (task ledger broadcast)."""
|
||||
logger.debug("Agent %s: Received response message", self._agent_id)
|
||||
|
||||
await self._ensure_state_restored(context)
|
||||
|
||||
# Check if this message is intended for this agent
|
||||
if message.target_agent is not None and message.target_agent != self._agent_id and not message.broadcast:
|
||||
# Message is targeted to a different agent, ignore it
|
||||
@@ -1735,8 +1681,6 @@ class MagenticAgentExecutor(Executor):
|
||||
|
||||
logger.info("Agent %s: Received request to respond", self._agent_id)
|
||||
|
||||
await self._ensure_state_restored(context)
|
||||
|
||||
# Add persona adoption message with appropriate role
|
||||
persona_role = self._get_persona_adoption_role()
|
||||
persona_msg = ChatMessage(
|
||||
@@ -1783,7 +1727,6 @@ class MagenticAgentExecutor(Executor):
|
||||
"""Reset the internal chat history of the agent (internal operation)."""
|
||||
logger.debug("Agent %s: Resetting chat history", self._agent_id)
|
||||
self._chat_history.clear()
|
||||
self._state_restored = True
|
||||
|
||||
async def _emit_agent_delta_event(
|
||||
self,
|
||||
|
||||
@@ -11,7 +11,7 @@ from ._checkpoint_encoding import DATACLASS_MARKER, MODEL_MARKER, decode_checkpo
|
||||
from ._const import EXECUTOR_STATE_KEY
|
||||
from ._edge import EdgeGroup
|
||||
from ._edge_runner import EdgeRunner, create_edge_runner
|
||||
from ._events import WorkflowEvent
|
||||
from ._events import SuperStepCompletedEvent, SuperStepStartedEvent, WorkflowEvent
|
||||
from ._executor import Executor
|
||||
from ._runner_context import (
|
||||
Message,
|
||||
@@ -92,6 +92,7 @@ class Runner:
|
||||
|
||||
while self._iteration < self._max_iterations:
|
||||
logger.info(f"Starting superstep {self._iteration + 1}")
|
||||
yield SuperStepStartedEvent(iteration=self._iteration + 1)
|
||||
|
||||
# Run iteration concurrently with live event streaming: we poll
|
||||
# for new events while the iteration coroutine progresses.
|
||||
@@ -126,6 +127,9 @@ class Runner:
|
||||
# Create checkpoint after each superstep iteration
|
||||
await self._create_checkpoint_if_enabled(f"superstep_{self._iteration}")
|
||||
|
||||
yield SuperStepCompletedEvent(iteration=self._iteration)
|
||||
|
||||
# Check for convergence: no more messages to process
|
||||
if not await self._ctx.has_messages():
|
||||
break
|
||||
|
||||
@@ -183,8 +187,8 @@ class Runner:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Auto-snapshot executor states
|
||||
await self._auto_snapshot_executor_states()
|
||||
# Snapshot executor states
|
||||
await self._save_executor_states()
|
||||
checkpoint_category = "initial" if checkpoint_type == "after_initial_execution" else "superstep"
|
||||
metadata = {
|
||||
"superstep": self._iteration,
|
||||
@@ -203,41 +207,6 @@ class Runner:
|
||||
logger.warning(f"Failed to create {checkpoint_type} checkpoint: {e}")
|
||||
return None
|
||||
|
||||
async def _auto_snapshot_executor_states(self) -> None:
|
||||
"""Populate executor state by calling snapshot hooks on executors if available.
|
||||
|
||||
TODO(@taochen#1614): this method is potentially problematic if executors also call
|
||||
set_executor_state on the context directly. We should clarify the intended usage
|
||||
pattern for executor state management.
|
||||
|
||||
Convention:
|
||||
- If an executor defines an async or sync method `snapshot_state(self) -> dict`, use it.
|
||||
- Else if it has a plain attribute `state` that is a dict, use that.
|
||||
Only JSON-serializable dicts should be provided by executors.
|
||||
"""
|
||||
for exec_id, executor in self._executors.items():
|
||||
state_dict: dict[str, Any] | None = None
|
||||
snapshot = getattr(executor, "snapshot_state", None)
|
||||
try:
|
||||
if callable(snapshot):
|
||||
maybe = snapshot()
|
||||
if asyncio.iscoroutine(maybe): # type: ignore[arg-type]
|
||||
maybe = await maybe # type: ignore[assignment]
|
||||
if isinstance(maybe, dict):
|
||||
state_dict = maybe # type: ignore[assignment]
|
||||
else:
|
||||
state_attr = getattr(executor, "state", None)
|
||||
if isinstance(state_attr, dict):
|
||||
state_dict = state_attr # type: ignore[assignment]
|
||||
except Exception as ex: # pragma: no cover
|
||||
logger.debug(f"Executor {exec_id} snapshot_state failed: {ex}")
|
||||
|
||||
if state_dict is not None:
|
||||
try:
|
||||
await self._set_executor_state(exec_id, state_dict)
|
||||
except Exception as ex: # pragma: no cover
|
||||
logger.debug(f"Failed to persist state for executor {exec_id}: {ex}")
|
||||
|
||||
async def restore_from_checkpoint(
|
||||
self,
|
||||
checkpoint_id: str,
|
||||
@@ -300,7 +269,65 @@ class Runner:
|
||||
logger.error(f"Failed to restore from checkpoint {checkpoint_id}: {e}")
|
||||
return False
|
||||
|
||||
async def _save_executor_states(self) -> None:
|
||||
"""Populate executor state by calling checkpoint hooks on executors.
|
||||
|
||||
Backward compatibility behavior:
|
||||
- If an executor defines an async or sync method `snapshot_state(self) -> dict`, use it.
|
||||
- Else if it has a plain attribute `state` that is a dict, use that.
|
||||
|
||||
Updated behavior:
|
||||
- Executors should implement `on_checkpoint_save(self) -> dict` to provide state.
|
||||
|
||||
This method will try the backward compatibility behavior first; if that does not yield state,
|
||||
it falls back to the updated behavior.
|
||||
|
||||
Only JSON-serializable dicts should be provided by executors.
|
||||
"""
|
||||
for exec_id, executor in self._executors.items():
|
||||
state_dict: dict[str, Any] | None = None
|
||||
# Try backward compatibility behavior first
|
||||
# TODO(@taochen): Remove backward compatibility
|
||||
snapshot = getattr(executor, "snapshot_state", None)
|
||||
try:
|
||||
if callable(snapshot):
|
||||
maybe = snapshot()
|
||||
if asyncio.iscoroutine(maybe): # type: ignore[arg-type]
|
||||
maybe = await maybe # type: ignore[assignment]
|
||||
if isinstance(maybe, dict):
|
||||
state_dict = maybe # type: ignore[assignment]
|
||||
else:
|
||||
state_attr = getattr(executor, "state", None)
|
||||
if isinstance(state_attr, dict):
|
||||
state_dict = state_attr # type: ignore[assignment]
|
||||
except Exception as ex: # pragma: no cover
|
||||
logger.debug(f"Executor {exec_id} snapshot_state failed: {ex}")
|
||||
|
||||
if state_dict is None:
|
||||
# Try the updated behavior only if backward compatibility did not yield state
|
||||
try:
|
||||
state_dict = await executor.on_checkpoint_save()
|
||||
except Exception as ex: # pragma: no cover
|
||||
raise ValueError(f"Executor {exec_id} on_checkpoint_save failed: {ex}") from ex
|
||||
|
||||
try:
|
||||
await self._set_executor_state(exec_id, state_dict)
|
||||
except Exception as ex: # pragma: no cover
|
||||
logger.debug(f"Failed to persist state for executor {exec_id}: {ex}")
|
||||
|
||||
async def _restore_executor_states(self) -> None:
|
||||
"""Restore executor state by calling restore hooks on executors.
|
||||
|
||||
Backward compatibility behavior:
|
||||
- If an executor defines an async or sync method `restore_state(self, state: dict)`, use it.
|
||||
- Else, skip restoration for that executor.
|
||||
|
||||
Updated behavior:
|
||||
- Executors should implement `on_checkpoint_restore(self, state: dict)` to restore state.
|
||||
|
||||
This method will try the backward compatibility behavior first; if that does not restore state,
|
||||
it falls back to the updated behavior.
|
||||
"""
|
||||
has_executor_states = await self._shared_state.has(EXECUTOR_STATE_KEY)
|
||||
if not has_executor_states:
|
||||
return
|
||||
@@ -309,16 +336,18 @@ class Runner:
|
||||
if not isinstance(executor_states, dict):
|
||||
raise ValueError("Executor states in shared state is not a dictionary. Unable to restore.")
|
||||
|
||||
for executor_id, state in executor_states.items():
|
||||
for executor_id, state in executor_states.items(): # pyright: ignore[reportUnknownVariableType]
|
||||
if not isinstance(executor_id, str):
|
||||
raise ValueError("Executor ID in executor states is not a string. Unable to restore.")
|
||||
if not isinstance(state, dict):
|
||||
raise ValueError(f"Executor state for {executor_id} is not a dictionary. Unable to restore.")
|
||||
if not isinstance(state, dict) or not all(isinstance(k, str) for k in state): # pyright: ignore[reportUnknownVariableType]
|
||||
raise ValueError(f"Executor state for {executor_id} is not a dict[str, Any]. Unable to restore.")
|
||||
|
||||
executor = self._executors.get(executor_id)
|
||||
if not executor:
|
||||
raise ValueError(f"Executor {executor_id} not found during state restoration.")
|
||||
|
||||
# Try backward compatibility behavior first
|
||||
# TODO(@taochen): Remove backward compatibility
|
||||
restored = False
|
||||
restore_method = getattr(executor, "restore_state", None)
|
||||
try:
|
||||
@@ -330,6 +359,14 @@ class Runner:
|
||||
except Exception as ex: # pragma: no cover - defensive
|
||||
raise ValueError(f"Executor {executor_id} restore_state failed: {ex}") from ex
|
||||
|
||||
if not restored:
|
||||
# Try the updated behavior only if backward compatibility did not restore
|
||||
try:
|
||||
await executor.on_checkpoint_restore(state) # pyright: ignore[reportUnknownArgumentType]
|
||||
restored = True
|
||||
except Exception as ex: # pragma: no cover - defensive
|
||||
raise ValueError(f"Executor {executor_id} on_checkpoint_restore failed: {ex}") from ex
|
||||
|
||||
if not restored:
|
||||
logger.debug(f"Executor {executor_id} does not support state restoration; skipping.")
|
||||
|
||||
|
||||
@@ -109,9 +109,9 @@ class Workflow(DictConvertible):
|
||||
"""A graph-based execution engine that orchestrates connected executors.
|
||||
|
||||
## Overview
|
||||
A workflow executes a directed graph of executors connected via edge groups using a Pregel-like model,
|
||||
running in supersteps until the graph becomes idle. Workflows are created using the
|
||||
WorkflowBuilder class - do not instantiate this class directly.
|
||||
A workflow executes a directed graph of executors connected via edge groups using a
|
||||
Pregel-like model, running in supersteps until the graph becomes idle. Workflows
|
||||
are created using the WorkflowBuilder class - do not instantiate this class directly.
|
||||
|
||||
## Execution Model
|
||||
Executors run in synchronized supersteps where each executor:
|
||||
@@ -142,6 +142,10 @@ class Workflow(DictConvertible):
|
||||
- HIL continuation: Provide `responses` to continue after RequestInfoExecutor requests
|
||||
- Runtime checkpointing: Provide `checkpoint_storage` to enable/override checkpointing for this run
|
||||
|
||||
## State Management
|
||||
Workflow instances contain states and states are preserved across calls to `run` and `run_stream`.
|
||||
To execute multiple independent runs, create separate Workflow instances via WorkflowBuilder.
|
||||
|
||||
## External Input Requests
|
||||
Executors within a workflow can request external input using `ctx.request_info()`:
|
||||
1. Executor calls `ctx.request_info()` to request input
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Generic, Union, cast, get_args, get_origi
|
||||
|
||||
from opentelemetry.propagate import inject
|
||||
from opentelemetry.trace import SpanKind
|
||||
from typing_extensions import Never, TypeVar
|
||||
from typing_extensions import Never, TypeVar, deprecated
|
||||
|
||||
from ..observability import OtelAttr, create_workflow_span
|
||||
from ._const import EXECUTOR_STATE_KEY
|
||||
@@ -410,6 +410,11 @@ class WorkflowContext(Generic[T_Out, T_W_Out]):
|
||||
"""Get the shared state."""
|
||||
return self._shared_state
|
||||
|
||||
@deprecated(
|
||||
"Override `on_checkpoint_save()` methods instead. "
|
||||
"For cross-executor state sharing, use set_shared_state() instead. "
|
||||
"This API will be removed after 12/01/2025."
|
||||
)
|
||||
async def set_executor_state(self, state: dict[str, Any]) -> None:
|
||||
"""Store executor state in shared state under a reserved key.
|
||||
|
||||
@@ -428,6 +433,11 @@ class WorkflowContext(Generic[T_Out, T_W_Out]):
|
||||
existing_states[self._executor_id] = state
|
||||
await self._shared_state.set(EXECUTOR_STATE_KEY, existing_states)
|
||||
|
||||
@deprecated(
|
||||
"Override `on_checkpoint_restore()` methods instead. "
|
||||
"For cross-executor state sharing, use get_shared_state() instead. "
|
||||
"This API will be removed after 12/01/2025."
|
||||
)
|
||||
async def get_executor_state(self) -> dict[str, Any] | None:
|
||||
"""Retrieve previously persisted state for this executor, if any."""
|
||||
has_existing_states = await self._shared_state.has(EXECUTOR_STATE_KEY)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import sys
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@@ -26,6 +26,12 @@ from ._typing_utils import is_instance_of
|
||||
from ._workflow import WorkflowRunResult
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -181,8 +187,7 @@ class WorkflowExecutor(Executor):
|
||||
|
||||
# Includes all sub-workflow output types
|
||||
# Plus SubWorkflowRequestMessage if sub-workflow can make requests
|
||||
output_types = workflow.output_types + [SubWorkflowRequestMessage] # if applicable
|
||||
```
|
||||
output_types = workflow.output_types + [SubWorkflowRequestMessage] # if applicable
|
||||
|
||||
## Error Handling
|
||||
WorkflowExecutor propagates sub-workflow failures:
|
||||
@@ -221,23 +226,10 @@ class WorkflowExecutor(Executor):
|
||||
|
||||
### Important Considerations
|
||||
**Shared Workflow Instance**: All concurrent executions use the same underlying workflow instance.
|
||||
For proper isolation, ensure that:
|
||||
- The wrapped workflow and its executors are stateless
|
||||
- Executors use WorkflowContext state management instead of instance variables
|
||||
- Any shared state is managed through WorkflowContext.get_shared_state/set_shared_state
|
||||
For proper isolation, ensure that the wrapped workflow and its executors are stateless.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Good: Stateless executor using context state
|
||||
class StatelessExecutor(Executor):
|
||||
@handler
|
||||
async def process(self, data: str, ctx: WorkflowContext[str]) -> None:
|
||||
# Use context state instead of instance variables
|
||||
state = await ctx.get_executor_state() or {}
|
||||
state["processed"] = data
|
||||
await ctx.set_executor_state(state)
|
||||
|
||||
|
||||
# Avoid: Stateful executor with instance variables
|
||||
class StatefulExecutor(Executor):
|
||||
def __init__(self):
|
||||
@@ -246,23 +238,23 @@ class WorkflowExecutor(Executor):
|
||||
|
||||
## Integration with Parent Workflows
|
||||
Parent workflows can intercept sub-workflow requests:
|
||||
```python
|
||||
class ParentExecutor(Executor):
|
||||
@handler
|
||||
async def handle_subworkflow_request(
|
||||
self,
|
||||
request: SubWorkflowRequestMessage,
|
||||
ctx: WorkflowContext[SubWorkflowResponseMessage],
|
||||
) -> None:
|
||||
# Handle request locally or forward to external source
|
||||
if self.can_handle_locally(request):
|
||||
# Send response back to sub-workflow
|
||||
response = request.create_response(data="local response data")
|
||||
await ctx.send_message(response, target_id=request.source_executor_id)
|
||||
else:
|
||||
# Forward to external handler
|
||||
await ctx.request_info(request.source_event, response_type=request.source_event.response_type)
|
||||
```
|
||||
|
||||
.. code-block:: python
|
||||
class ParentExecutor(Executor):
|
||||
@handler
|
||||
async def handle_subworkflow_request(
|
||||
self,
|
||||
request: SubWorkflowRequestMessage,
|
||||
ctx: WorkflowContext[SubWorkflowResponseMessage],
|
||||
) -> None:
|
||||
# Handle request locally or forward to external source
|
||||
if self.can_handle_locally(request):
|
||||
# Send response back to sub-workflow
|
||||
response = request.create_response(data="local response data")
|
||||
await ctx.send_message(response, target_id=request.source_executor_id)
|
||||
else:
|
||||
# Forward to external handler
|
||||
await ctx.request_info(request.source_event, response_type=request.source_event.response_type)
|
||||
|
||||
## Implementation Notes
|
||||
- Sub-workflows run to completion before processing their results
|
||||
@@ -296,7 +288,6 @@ class WorkflowExecutor(Executor):
|
||||
self._execution_contexts: dict[str, ExecutionContext] = {} # execution_id -> ExecutionContext
|
||||
# Map request_id to execution_id for response routing
|
||||
self._request_to_execution: dict[str, str] = {} # request_id -> execution_id
|
||||
self._state_loaded: bool = False
|
||||
|
||||
@property
|
||||
def input_types(self) -> list[type[Any]]:
|
||||
@@ -362,8 +353,6 @@ class WorkflowExecutor(Executor):
|
||||
input_data: The input data to send to the sub-workflow.
|
||||
ctx: The workflow context from the parent.
|
||||
"""
|
||||
await self._ensure_state_loaded(ctx)
|
||||
|
||||
# Create execution context for this sub-workflow run
|
||||
execution_id = str(uuid.uuid4())
|
||||
execution_context = ExecutionContext(
|
||||
@@ -405,8 +394,6 @@ class WorkflowExecutor(Executor):
|
||||
response: The response to a previous request.
|
||||
ctx: The workflow context.
|
||||
"""
|
||||
await self._ensure_state_loaded(ctx)
|
||||
|
||||
# Find the execution context for this request
|
||||
original_request = response.source_event
|
||||
execution_id = self._request_to_execution.get(original_request.request_id)
|
||||
@@ -434,8 +421,6 @@ class WorkflowExecutor(Executor):
|
||||
# Accumulate the response in this execution's context
|
||||
execution_context.collected_responses[original_request.request_id] = response.data
|
||||
|
||||
await self._persist_execution_state(ctx)
|
||||
|
||||
# Check if we have all expected responses for this execution
|
||||
if len(execution_context.collected_responses) < execution_context.expected_response_count:
|
||||
logger.debug(
|
||||
@@ -459,25 +444,20 @@ class WorkflowExecutor(Executor):
|
||||
if not execution_context.pending_requests:
|
||||
del self._execution_contexts[execution_id]
|
||||
|
||||
async def _ensure_state_loaded(self, ctx: WorkflowContext[Any]) -> None:
|
||||
if self._state_loaded:
|
||||
return
|
||||
@override
|
||||
async def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
"""Get the current state of the WorkflowExecutor for checkpointing purposes."""
|
||||
return {
|
||||
"execution_contexts": {
|
||||
execution_id: encode_checkpoint_value(execution_context)
|
||||
for execution_id, execution_context in self._execution_contexts.items()
|
||||
},
|
||||
"request_to_execution": dict(self._request_to_execution),
|
||||
}
|
||||
|
||||
state: dict[str, Any] | None = None
|
||||
try:
|
||||
state = await ctx.get_executor_state()
|
||||
except Exception:
|
||||
state = None
|
||||
|
||||
if isinstance(state, dict) and state:
|
||||
with contextlib.suppress(Exception):
|
||||
await self.restore_state(state)
|
||||
self._state_loaded = True
|
||||
else:
|
||||
self._state_loaded = True
|
||||
|
||||
async def restore_state(self, state: dict[str, Any]) -> None:
|
||||
"""Restore pending request bookkeeping from a checkpoint snapshot."""
|
||||
@override
|
||||
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
"""Restore the WorkflowExecutor state from a checkpoint snapshot."""
|
||||
# Validate the state contains the right keys
|
||||
if "execution_contexts" not in state:
|
||||
raise KeyError("Missing 'execution_contexts' in WorkflowExecutor state.")
|
||||
@@ -529,23 +509,6 @@ class WorkflowExecutor(Executor):
|
||||
for event in request_info_events
|
||||
])
|
||||
|
||||
self._state_loaded = True
|
||||
|
||||
async def _persist_execution_state(self, ctx: WorkflowContext) -> None:
|
||||
"""Persist the state of the WorkflowExecutor for checkpointing purposes."""
|
||||
state = {
|
||||
"execution_contexts": {
|
||||
execution_id: encode_checkpoint_value(execution_context)
|
||||
for execution_id, execution_context in self._execution_contexts.items()
|
||||
},
|
||||
"request_to_execution": dict(self._request_to_execution),
|
||||
}
|
||||
|
||||
try:
|
||||
await ctx.set_executor_state(state)
|
||||
except Exception as exc: # pragma: no cover - transport specific
|
||||
logger.warning(f"WorkflowExecutor {self.id} failed to persist state: {exc}")
|
||||
|
||||
async def _process_workflow_result(
|
||||
self,
|
||||
result: WorkflowRunResult,
|
||||
@@ -635,5 +598,3 @@ class WorkflowExecutor(Executor):
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Unexpected workflow run state: {workflow_run_state}")
|
||||
|
||||
await self._persist_execution_state(ctx)
|
||||
|
||||
@@ -158,8 +158,8 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
|
||||
assert thread_messages[1].text == "Initial response 1"
|
||||
|
||||
|
||||
async def test_agent_executor_snapshot_and_restore_state_directly() -> None:
|
||||
"""Test AgentExecutor's snapshot_state and restore_state methods directly."""
|
||||
async def test_agent_executor_save_and_restore_state_directly() -> None:
|
||||
"""Test AgentExecutor's on_checkpoint_save and on_checkpoint_restore methods directly."""
|
||||
# Create agent with thread containing messages
|
||||
agent = _CountingAgent(id="direct_test_agent", name="DirectTestAgent")
|
||||
thread = AgentThread(message_store=ChatMessageStore())
|
||||
@@ -182,7 +182,7 @@ async def test_agent_executor_snapshot_and_restore_state_directly() -> None:
|
||||
executor._cache = list(cache_messages) # type: ignore[reportPrivateUsage]
|
||||
|
||||
# Snapshot the state
|
||||
state = await executor.snapshot_state() # type: ignore[reportUnknownMemberType]
|
||||
state = await executor.on_checkpoint_save()
|
||||
|
||||
# Verify snapshot contains both cache and thread
|
||||
assert "cache" in state
|
||||
@@ -206,7 +206,7 @@ async def test_agent_executor_snapshot_and_restore_state_directly() -> None:
|
||||
assert len(initial_thread_msgs) == 0
|
||||
|
||||
# Restore state
|
||||
await new_executor.restore_state(state) # type: ignore[reportUnknownMemberType]
|
||||
await new_executor.on_checkpoint_restore(state)
|
||||
|
||||
# Verify cache is restored
|
||||
restored_cache = new_executor._cache # type: ignore[reportPrivateUsage]
|
||||
|
||||
@@ -288,57 +288,6 @@ def test_build_fails_without_participants():
|
||||
HandoffBuilder().build()
|
||||
|
||||
|
||||
async def test_multiple_runs_dont_leak_conversation():
|
||||
"""Verify that running the same workflow multiple times doesn't leak conversation history."""
|
||||
triage = _RecordingAgent(name="triage", handoff_to="specialist")
|
||||
specialist = _RecordingAgent(name="specialist")
|
||||
|
||||
workflow = (
|
||||
HandoffBuilder(participants=[triage, specialist])
|
||||
.set_coordinator("triage")
|
||||
.with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 2)
|
||||
.build()
|
||||
)
|
||||
|
||||
# First run
|
||||
events = await _drain(workflow.run_stream("First run message"))
|
||||
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
|
||||
assert requests
|
||||
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Second message"}))
|
||||
outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
|
||||
assert outputs, "First run should emit output"
|
||||
|
||||
first_run_conversation = outputs[-1].data
|
||||
assert isinstance(first_run_conversation, list)
|
||||
first_run_conv_list = cast(list[ChatMessage], first_run_conversation)
|
||||
first_run_user_messages = [msg for msg in first_run_conv_list if msg.role == Role.USER]
|
||||
assert len(first_run_user_messages) == 2
|
||||
assert any("First run message" in msg.text for msg in first_run_user_messages if msg.text)
|
||||
|
||||
# Second run - should start fresh, not include first run's messages
|
||||
triage.calls.clear()
|
||||
specialist.calls.clear()
|
||||
|
||||
events = await _drain(workflow.run_stream("Second run different message"))
|
||||
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
|
||||
assert requests
|
||||
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Another message"}))
|
||||
outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
|
||||
assert outputs, "Second run should emit output"
|
||||
|
||||
second_run_conversation = outputs[-1].data
|
||||
assert isinstance(second_run_conversation, list)
|
||||
second_run_conv_list = cast(list[ChatMessage], second_run_conversation)
|
||||
second_run_user_messages = [msg for msg in second_run_conv_list if msg.role == Role.USER]
|
||||
assert len(second_run_user_messages) == 2, (
|
||||
"Second run should have exactly 2 user messages, not accumulate first run"
|
||||
)
|
||||
assert any("Second run different message" in msg.text for msg in second_run_user_messages if msg.text)
|
||||
assert not any("First run message" in msg.text for msg in second_run_user_messages if msg.text), (
|
||||
"Second run should NOT contain first run's messages"
|
||||
)
|
||||
|
||||
|
||||
async def test_handoff_async_termination_condition() -> None:
|
||||
"""Test that async termination conditions work correctly."""
|
||||
termination_call_count = 0
|
||||
@@ -585,7 +534,7 @@ async def test_return_to_previous_state_serialization():
|
||||
coordinator._current_agent_id = "specialist_a" # type: ignore[reportPrivateUsage]
|
||||
|
||||
# Snapshot the state
|
||||
state = coordinator.snapshot_state()
|
||||
state = await coordinator.on_checkpoint_save()
|
||||
|
||||
# Verify pattern metadata includes current_agent_id
|
||||
assert "metadata" in state
|
||||
@@ -603,7 +552,7 @@ async def test_return_to_previous_state_serialization():
|
||||
)
|
||||
|
||||
# Restore state
|
||||
coordinator2.restore_state(state)
|
||||
await coordinator2.on_checkpoint_restore(state)
|
||||
|
||||
# Verify current_agent_id was restored
|
||||
assert coordinator2._current_agent_id == "specialist_a", "Current agent should be restored from checkpoint" # type: ignore[reportPrivateUsage]
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import sys
|
||||
from collections.abc import AsyncIterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
@@ -42,6 +43,11 @@ from agent_framework._workflows._magentic import ( # type: ignore[reportPrivate
|
||||
_MagenticStartMessage, # type: ignore
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
def test_magentic_start_message_from_string():
|
||||
msg = _MagenticStartMessage.from_string("Do the thing")
|
||||
@@ -101,8 +107,9 @@ class FakeManager(MagenticManagerBase):
|
||||
next_speaker_name: str = "agentA"
|
||||
instruction_text: str = "Proceed with step 1"
|
||||
|
||||
def snapshot_state(self) -> dict[str, Any]:
|
||||
state = super().snapshot_state()
|
||||
@override
|
||||
def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
state = super().on_checkpoint_save()
|
||||
if self.task_ledger is not None:
|
||||
state = dict(state)
|
||||
state["task_ledger"] = {
|
||||
@@ -111,8 +118,9 @@ class FakeManager(MagenticManagerBase):
|
||||
}
|
||||
return state
|
||||
|
||||
def restore_state(self, state: dict[str, Any]) -> None:
|
||||
super().restore_state(state)
|
||||
@override
|
||||
def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
super().on_checkpoint_restore(state)
|
||||
ledger_state = state.get("task_ledger")
|
||||
if isinstance(ledger_state, dict):
|
||||
ledger_dict = cast(dict[str, Any], ledger_state)
|
||||
@@ -185,7 +193,6 @@ async def test_standard_manager_progress_ledger_and_fallback():
|
||||
assert ledger2.is_request_satisfied.answer is False
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Response handling refactored - responses no longer passed to run_stream()")
|
||||
async def test_magentic_workflow_plan_review_approval_to_completion():
|
||||
manager = FakeManager(max_round_count=10)
|
||||
wf = (
|
||||
@@ -204,7 +211,7 @@ async def test_magentic_workflow_plan_review_approval_to_completion():
|
||||
|
||||
completed = False
|
||||
output: ChatMessage | None = None
|
||||
async for ev in wf.run_stream(
|
||||
async for ev in wf.send_responses_streaming(
|
||||
responses={req_event.request_id: MagenticPlanReviewReply(decision=MagenticPlanReviewDecision.APPROVE)}
|
||||
):
|
||||
if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE:
|
||||
@@ -218,7 +225,6 @@ async def test_magentic_workflow_plan_review_approval_to_completion():
|
||||
assert isinstance(output, ChatMessage)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Response handling refactored - responses no longer passed to run_stream()")
|
||||
async def test_magentic_plan_review_approve_with_comments_replans_and_proceeds():
|
||||
class CountingManager(FakeManager):
|
||||
# Declare as a model field so assignment is allowed under Pydantic
|
||||
@@ -250,7 +256,7 @@ async def test_magentic_plan_review_approve_with_comments_replans_and_proceeds()
|
||||
# Reply APPROVE with comments (no edited text). Expect one replan and no second review round.
|
||||
saw_second_review = False
|
||||
completed = False
|
||||
async for ev in wf.run_stream(
|
||||
async for ev in wf.send_responses_streaming(
|
||||
responses={
|
||||
req_event.request_id: MagenticPlanReviewReply(
|
||||
decision=MagenticPlanReviewDecision.APPROVE,
|
||||
@@ -298,7 +304,6 @@ async def test_magentic_orchestrator_round_limit_produces_partial_result():
|
||||
assert data.role == Role.ASSISTANT
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Response handling refactored - send_responses_streaming no longer exists")
|
||||
async def test_magentic_checkpoint_resume_round_trip():
|
||||
storage = InMemoryCheckpointStorage()
|
||||
|
||||
@@ -369,7 +374,7 @@ class _DummyExec(Executor):
|
||||
pass
|
||||
|
||||
|
||||
def test_magentic_agent_executor_snapshot_roundtrip():
|
||||
async def test_magentic_agent_executor_on_checkpoint_save_and_restore_roundtrip():
|
||||
backing_executor = _DummyExec("backing")
|
||||
agent_exec = MagenticAgentExecutor(backing_executor, "agentA")
|
||||
agent_exec._chat_history.extend([ # type: ignore[reportPrivateUsage]
|
||||
@@ -377,10 +382,10 @@ def test_magentic_agent_executor_snapshot_roundtrip():
|
||||
ChatMessage(role=Role.ASSISTANT, text="world", author_name="agentA"),
|
||||
])
|
||||
|
||||
state = agent_exec.snapshot_state()
|
||||
state = await agent_exec.on_checkpoint_save()
|
||||
|
||||
restored_executor = MagenticAgentExecutor(_DummyExec("backing2"), "agentA")
|
||||
restored_executor.restore_state(state)
|
||||
await restored_executor.on_checkpoint_restore(state)
|
||||
|
||||
assert len(restored_executor._chat_history) == 2 # type: ignore[reportPrivateUsage]
|
||||
assert restored_executor._chat_history[0].text == "hello" # type: ignore[reportPrivateUsage]
|
||||
|
||||
@@ -199,7 +199,10 @@ async def test_fan_out():
|
||||
|
||||
# Each executor will emit two events: ExecutorInvokedEvent and ExecutorCompletedEvent
|
||||
# executor_b will also emit a WorkflowOutputEvent (no WorkflowCompletedEvent anymore)
|
||||
assert len(events) == 7
|
||||
# Each superstep will emit also emit a WorkflowStartedEvent and WorkflowCompletedEvent
|
||||
# This workflow will converge in 2 supersteps because executor_c will send one more message
|
||||
# after executor_b completes
|
||||
assert len(events) == 11
|
||||
|
||||
assert events.get_final_state() == WorkflowRunState.IDLE
|
||||
outputs = events.get_outputs()
|
||||
@@ -220,7 +223,9 @@ async def test_fan_out_multiple_completed_events():
|
||||
|
||||
# Each executor will emit two events: ExecutorInvokedEvent and ExecutorCompletedEvent
|
||||
# executor_b and executor_c will also emit a WorkflowOutputEvent (no WorkflowCompletedEvent anymore)
|
||||
assert len(events) == 8
|
||||
# Each superstep will emit also emit a WorkflowStartedEvent and WorkflowCompletedEvent
|
||||
# This workflow will converge in 1 superstep because executor_a and executor_b will not send further messages
|
||||
assert len(events) == 10
|
||||
|
||||
# Multiple outputs are expected from both executors
|
||||
outputs = events.get_outputs()
|
||||
@@ -246,7 +251,8 @@ async def test_fan_in():
|
||||
|
||||
# Each executor will emit two events: ExecutorInvokedEvent and ExecutorCompletedEvent
|
||||
# aggregator will also emit a WorkflowOutputEvent (no WorkflowCompletedEvent anymore)
|
||||
assert len(events) == 9
|
||||
# Each superstep will emit also emit a WorkflowStartedEvent and WorkflowCompletedEvent
|
||||
assert len(events) == 13
|
||||
|
||||
assert events.get_final_state() == WorkflowRunState.IDLE
|
||||
outputs = events.get_outputs()
|
||||
|
||||
@@ -37,10 +37,14 @@ class WorkflowHILRequest:
|
||||
class WorkflowTestExecutor(Executor):
|
||||
"""Test executor with HIL."""
|
||||
|
||||
def __init__(self, id: str) -> None:
|
||||
super().__init__(id=id)
|
||||
self._data_value: str | None = None
|
||||
|
||||
@handler
|
||||
async def process(self, data: WorkflowTestData, ctx: WorkflowContext) -> None:
|
||||
"""Process data and request approval."""
|
||||
await ctx.set_executor_state({"data_value": data.value})
|
||||
self._data_value = data.value
|
||||
|
||||
# Request HIL (checkpoint created here)
|
||||
await ctx.request_info(request_data=WorkflowHILRequest(question=f"Approve {data.value}?"), response_type=str)
|
||||
@@ -50,8 +54,7 @@ class WorkflowTestExecutor(Executor):
|
||||
self, original_request: WorkflowHILRequest, response: str, ctx: WorkflowContext[str]
|
||||
) -> None:
|
||||
"""Handle HIL response."""
|
||||
state = await ctx.get_executor_state() or {}
|
||||
value = state.get("data_value", "")
|
||||
value = self._data_value or ""
|
||||
await ctx.send_message(f"{value}_approved" if response.lower() == "yes" else f"{value}_rejected")
|
||||
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ Workflow Steps:
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Annotated
|
||||
from typing import Literal
|
||||
|
||||
from agent_framework import (
|
||||
Case,
|
||||
@@ -31,9 +31,11 @@ from agent_framework import (
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Never
|
||||
|
||||
|
||||
# Define response model with clear user guidance
|
||||
class SpamDecision(BaseModel):
|
||||
"""User's decision on whether the email is spam."""
|
||||
|
||||
decision: Literal["spam", "not spam"] = Field(
|
||||
description="Enter 'spam' to mark as spam, or 'not spam' to mark as legitimate"
|
||||
)
|
||||
@@ -71,10 +73,11 @@ class SpamDetectorResponse:
|
||||
class SpamApprovalRequest:
|
||||
"""Human-in-the-loop approval request for spam classification."""
|
||||
|
||||
email_message: str = ""
|
||||
detected_as_spam: bool = False
|
||||
confidence: float = 0.0
|
||||
reasons: str = ""
|
||||
email_message: str
|
||||
detected_as_spam: bool
|
||||
confidence: float
|
||||
reasons: list[str]
|
||||
full_email_content: EmailContent
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -128,8 +131,6 @@ class EmailPreprocessor(Executor):
|
||||
await ctx.send_message(result)
|
||||
|
||||
|
||||
|
||||
|
||||
class SpamDetector(Executor):
|
||||
"""Step 2: An executor that analyzes content and determines if a message is spam."""
|
||||
|
||||
@@ -139,7 +140,9 @@ class SpamDetector(Executor):
|
||||
self._spam_keywords = spam_keywords
|
||||
|
||||
@handler
|
||||
async def handle_email_content(self, email_content: EmailContent, ctx: WorkflowContext[SpamApprovalRequest]) -> None:
|
||||
async def handle_email_content(
|
||||
self, email_content: EmailContent, ctx: WorkflowContext[SpamApprovalRequest]
|
||||
) -> None:
|
||||
"""Analyze email content and determine if the message is spam, then request human approval."""
|
||||
await asyncio.sleep(2.0) # Simulate analysis and detection time
|
||||
|
||||
@@ -186,25 +189,13 @@ class SpamDetector(Executor):
|
||||
|
||||
is_spam = spam_score >= 0.5
|
||||
|
||||
# Store detection result in executor state for later use
|
||||
# Store minimal data needed (not complex objects that don't serialize well)
|
||||
await ctx.set_executor_state({
|
||||
"original_message": email_content.original_message,
|
||||
"cleaned_message": email_content.cleaned_message,
|
||||
"word_count": email_content.word_count,
|
||||
"has_suspicious_patterns": email_content.has_suspicious_patterns,
|
||||
"is_spam": is_spam,
|
||||
"ai_original_classification": is_spam, # Store original AI decision
|
||||
"confidence_score": spam_score,
|
||||
"spam_reasons": spam_reasons
|
||||
})
|
||||
|
||||
# Request human approval before proceeding using new API
|
||||
approval_request = SpamApprovalRequest(
|
||||
email_message=email_text[:200], # First 200 chars
|
||||
detected_as_spam=is_spam,
|
||||
confidence=spam_score,
|
||||
reasons=", ".join(spam_reasons) if spam_reasons else "no specific reasons"
|
||||
reasons=spam_reasons,
|
||||
full_email_content=email_content,
|
||||
)
|
||||
|
||||
await ctx.request_info(
|
||||
@@ -214,20 +205,15 @@ class SpamDetector(Executor):
|
||||
|
||||
@response_handler
|
||||
async def handle_human_response(
|
||||
self,
|
||||
original_request: SpamApprovalRequest,
|
||||
response: SpamDecision,
|
||||
ctx: WorkflowContext[SpamDetectorResponse]
|
||||
self, original_request: SpamApprovalRequest, response: SpamDecision, ctx: WorkflowContext[SpamDetectorResponse]
|
||||
) -> None:
|
||||
"""Process human approval response and continue workflow."""
|
||||
print(f"[SpamDetector] handle_human_response called with response: {response}")
|
||||
|
||||
# Get stored detection result
|
||||
state = await ctx.get_executor_state() or {}
|
||||
print(f"[SpamDetector] Retrieved state: {state}")
|
||||
ai_original = state.get("ai_original_classification", False)
|
||||
confidence_score = state.get("confidence_score", 0.0)
|
||||
spam_reasons = state.get("spam_reasons", [])
|
||||
ai_original = original_request.detected_as_spam
|
||||
confidence_score = original_request.confidence
|
||||
spam_reasons = original_request.reasons
|
||||
|
||||
# Parse human decision from the response model
|
||||
human_decision = response.decision.strip().lower()
|
||||
@@ -241,27 +227,21 @@ class SpamDetector(Executor):
|
||||
# Default to AI decision if unclear
|
||||
is_spam = ai_original
|
||||
|
||||
# Reconstruct EmailContent from stored primitives
|
||||
email_content = EmailContent(
|
||||
original_message=state.get("original_message", ""),
|
||||
cleaned_message=state.get("cleaned_message", ""),
|
||||
word_count=state.get("word_count", 0),
|
||||
has_suspicious_patterns=state.get("has_suspicious_patterns", False)
|
||||
)
|
||||
|
||||
result = SpamDetectorResponse(
|
||||
email_content=email_content,
|
||||
email_content=original_request.full_email_content,
|
||||
is_spam=is_spam,
|
||||
confidence_score=confidence_score,
|
||||
spam_reasons=spam_reasons,
|
||||
human_reviewed=True,
|
||||
human_decision=response.decision,
|
||||
ai_original_classification=ai_original
|
||||
ai_original_classification=ai_original,
|
||||
)
|
||||
|
||||
print(f"[SpamDetector] Sending SpamDetectorResponse: is_spam={is_spam}, confidence={confidence_score}, human_reviewed=True")
|
||||
print(
|
||||
f"[SpamDetector] Sending SpamDetectorResponse: is_spam={is_spam}, confidence={confidence_score}, human_reviewed=True"
|
||||
)
|
||||
await ctx.send_message(result)
|
||||
print(f"[SpamDetector] Message sent successfully")
|
||||
print("[SpamDetector] Message sent successfully")
|
||||
|
||||
|
||||
class SpamHandler(Executor):
|
||||
@@ -427,7 +407,9 @@ workflow = (
|
||||
spam_detector,
|
||||
[
|
||||
Case(condition=lambda x: isinstance(x, SpamDetectorResponse) and x.is_spam, target=spam_handler),
|
||||
Default(target=legitimate_message_handler), # Default handles non-spam and non-SpamDetectorResponse messages
|
||||
Default(
|
||||
target=legitimate_message_handler
|
||||
), # Default handles non-spam and non-SpamDetectorResponse messages
|
||||
],
|
||||
)
|
||||
.add_edge(spam_handler, final_processor)
|
||||
|
||||
+22
-16
@@ -3,6 +3,7 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, override
|
||||
|
||||
# NOTE: the Azure client imports above are real dependencies. When running this
|
||||
# sample outside of Azure-enabled environments you may wish to swap in the
|
||||
@@ -116,19 +117,19 @@ class ReviewGateway(Executor):
|
||||
def __init__(self, id: str, writer_id: str) -> None:
|
||||
super().__init__(id=id)
|
||||
self._writer_id = writer_id
|
||||
self._iteration = 0
|
||||
|
||||
@handler
|
||||
async def on_agent_response(self, response: AgentExecutorResponse, ctx: WorkflowContext) -> None:
|
||||
# Capture the agent output so we can surface it to the reviewer and persist iterations.
|
||||
draft = response.agent_run_response.text or ""
|
||||
iteration = int((await ctx.get_executor_state() or {}).get("iteration", 0)) + 1
|
||||
await ctx.set_executor_state({"iteration": iteration, "last_draft": draft})
|
||||
self._iteration += 1
|
||||
|
||||
# Emit a human approval request.
|
||||
await ctx.request_info(
|
||||
request_data=HumanApprovalRequest(
|
||||
prompt="Review the draft. Reply 'approve' or provide edit instructions.",
|
||||
draft=draft,
|
||||
iteration=iteration,
|
||||
draft=response.agent_run_response.text,
|
||||
iteration=self._iteration,
|
||||
),
|
||||
response_type=str,
|
||||
)
|
||||
@@ -142,28 +143,33 @@ class ReviewGateway(Executor):
|
||||
) -> None:
|
||||
# The `original_request` is the request we sent earlier that is now being answered.
|
||||
reply = feedback.strip()
|
||||
state = await ctx.get_executor_state() or {}
|
||||
draft = state.get("last_draft") or (original_request.draft or "")
|
||||
|
||||
if reply.lower() == "approve":
|
||||
if len(reply) == 0 or reply.lower() == "approve":
|
||||
# Workflow is completed when the human approves.
|
||||
await ctx.yield_output(draft)
|
||||
await ctx.yield_output(original_request.draft)
|
||||
return
|
||||
|
||||
# Any other response loops us back to the writer with fresh guidance.
|
||||
guidance = reply or "Tighten the copy and emphasise customer benefit."
|
||||
iteration = int(state.get("iteration", 1)) + 1
|
||||
await ctx.set_executor_state({"iteration": iteration, "last_draft": draft})
|
||||
prompt = (
|
||||
"Revise the launch note. Respond with the new copy only.\n\n"
|
||||
f"Previous draft:\n{draft}\n\n"
|
||||
f"Human guidance: {guidance}"
|
||||
f"Previous draft:\n{original_request.draft}\n\n"
|
||||
f"Human guidance: {reply}"
|
||||
)
|
||||
await ctx.send_message(
|
||||
AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=prompt)], should_respond=True),
|
||||
target_id=self._writer_id,
|
||||
)
|
||||
|
||||
@override
|
||||
async def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
# Save the current iteration count in executor state for checkpointing.
|
||||
return {"iteration": self._iteration}
|
||||
|
||||
@override
|
||||
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
# Restore the iteration count from executor state during checkpoint recovery.
|
||||
self._iteration = state.get("iteration", 0)
|
||||
|
||||
|
||||
def create_workflow(checkpoint_storage: FileCheckpointStorage) -> Workflow:
|
||||
"""Assemble the workflow graph used by both the initial run and resume."""
|
||||
@@ -247,10 +253,10 @@ async def run_interactive_session(
|
||||
else:
|
||||
if initial_message:
|
||||
print(f"\nStarting workflow with brief: {initial_message}\n")
|
||||
event_stream = workflow.run_stream(initial_message)
|
||||
event_stream = workflow.run_stream(message=initial_message)
|
||||
elif checkpoint_id:
|
||||
print("\nStarting workflow from checkpoint...\n")
|
||||
event_stream = workflow.run_stream(checkpoint_id)
|
||||
event_stream = workflow.run_stream(checkpoint_id=checkpoint_id)
|
||||
else:
|
||||
raise ValueError("Either initial_message or checkpoint_id must be provided")
|
||||
|
||||
|
||||
@@ -1,322 +1,157 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from agent_framework import (
|
||||
AgentExecutor,
|
||||
AgentExecutorRequest,
|
||||
AgentExecutorResponse,
|
||||
ChatMessage,
|
||||
Executor,
|
||||
FileCheckpointStorage,
|
||||
Role,
|
||||
WorkflowBuilder,
|
||||
WorkflowContext,
|
||||
get_checkpoint_summary,
|
||||
handler,
|
||||
)
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
from azure.identity import AzureCliCredential
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agent_framework import Workflow
|
||||
from agent_framework._workflows._checkpoint import WorkflowCheckpoint
|
||||
|
||||
"""
|
||||
Sample: Checkpointing and Resuming a Workflow (with an Agent stage)
|
||||
Sample: Checkpointing and Resuming a Workflow
|
||||
|
||||
Purpose:
|
||||
This sample shows how to enable checkpointing at superstep boundaries, persist both
|
||||
executor-local state and shared workflow state, and then resume execution from a specific
|
||||
checkpoint. The workflow demonstrates a simple text-processing pipeline that includes
|
||||
an LLM-backed AgentExecutor stage.
|
||||
|
||||
Pipeline:
|
||||
1) UpperCaseExecutor converts input to uppercase and records state.
|
||||
2) ReverseTextExecutor reverses the string.
|
||||
3) SubmitToLowerAgent prepares an AgentExecutorRequest for the lowercasing agent.
|
||||
4) lower_agent (AgentExecutor) converts text to lowercase via Azure OpenAI.
|
||||
5) FinalizeFromAgent yields the final result.
|
||||
This sample shows how to enable checkpointing for a long-running workflow
|
||||
that can be paused and resumed.
|
||||
|
||||
What you learn:
|
||||
- How to persist executor state using ctx.get_executor_state and ctx.set_executor_state.
|
||||
- How to persist shared workflow state using ctx.set_shared_state for cross-executor visibility.
|
||||
- How to configure FileCheckpointStorage and call with_checkpointing on WorkflowBuilder.
|
||||
- How to list and inspect checkpoints programmatically.
|
||||
- How to interactively choose a checkpoint to resume from (instead of always resuming
|
||||
from the most recent or a hard-coded one) using run_stream.
|
||||
- How workflows complete by yielding outputs when idle, not via explicit completion events.
|
||||
- How to configure checkpointing storage (InMemoryCheckpointStorage for testing)
|
||||
- How to resume a workflow from a checkpoint after interruption
|
||||
- How to implement executor state management with checkpoint hooks
|
||||
- How to handle workflow interruptions and automatic recovery
|
||||
|
||||
Pipeline:
|
||||
This sample shows a workflow that computes factor pairs for numbers up to a given limit:
|
||||
1) A start executor that receives the upper limit and creates the initial task
|
||||
2) A worker executor that processes each number to find its factor pairs
|
||||
3) The worker uses checkpoint hooks to save/restore its internal state
|
||||
|
||||
Prerequisites:
|
||||
- Azure AI or Azure OpenAI available for AzureOpenAIChatClient.
|
||||
- Authentication with azure-identity via AzureCliCredential. Run az login locally.
|
||||
- Filesystem access for writing JSON checkpoint files in a temp directory.
|
||||
- Basic understanding of workflow concepts, including executors, edges, events, etc.
|
||||
"""
|
||||
|
||||
# Define the temporary directory for storing checkpoints.
|
||||
# These files allow the workflow to be resumed later.
|
||||
DIR = os.path.dirname(__file__)
|
||||
TEMP_DIR = os.path.join(DIR, "tmp", "checkpoints")
|
||||
os.makedirs(TEMP_DIR, exist_ok=True)
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from random import random
|
||||
from typing import Any, override
|
||||
|
||||
from agent_framework import (
|
||||
Executor,
|
||||
InMemoryCheckpointStorage,
|
||||
SuperStepCompletedEvent,
|
||||
WorkflowBuilder,
|
||||
WorkflowCheckpoint,
|
||||
WorkflowContext,
|
||||
WorkflowOutputEvent,
|
||||
handler,
|
||||
)
|
||||
|
||||
|
||||
class UpperCaseExecutor(Executor):
|
||||
"""Uppercases the input text and persists both local and shared state."""
|
||||
@dataclass
|
||||
class ComputeTask:
|
||||
"""Task containing the list of numbers remaining to be processed."""
|
||||
|
||||
remaining_numbers: list[int]
|
||||
|
||||
|
||||
class StartExecutor(Executor):
|
||||
"""Initiates the workflow by providing the upper limit for factor pair computation."""
|
||||
|
||||
@handler
|
||||
async def to_upper_case(self, text: str, ctx: WorkflowContext[str]) -> None:
|
||||
result = text.upper()
|
||||
print(f"UpperCaseExecutor: '{text}' -> '{result}'")
|
||||
|
||||
# Persist executor-local state so it is captured in checkpoints
|
||||
# and available after resume for observability or logic.
|
||||
prev = await ctx.get_executor_state() or {}
|
||||
count = int(prev.get("count", 0)) + 1
|
||||
await ctx.set_executor_state({
|
||||
"count": count,
|
||||
"last_input": text,
|
||||
"last_output": result,
|
||||
})
|
||||
|
||||
# Write to shared_state so downstream executors and any resumed runs can read it.
|
||||
await ctx.set_shared_state("original_input", text)
|
||||
await ctx.set_shared_state("upper_output", result)
|
||||
|
||||
# Send transformed text to the next executor.
|
||||
await ctx.send_message(result)
|
||||
async def start(self, upper_limit: int, ctx: WorkflowContext[ComputeTask]) -> None:
|
||||
"""Start the workflow with a list of numbers to process."""
|
||||
print(f"StartExecutor: Starting factor pair computation up to {upper_limit}")
|
||||
await ctx.send_message(ComputeTask(remaining_numbers=list(range(1, upper_limit + 1))))
|
||||
|
||||
|
||||
class SubmitToLowerAgent(Executor):
|
||||
"""Builds an AgentExecutorRequest to send to the lowercasing agent while keeping shared-state visibility."""
|
||||
class WorkerExecutor(Executor):
|
||||
"""Processes numbers to compute their factor pairs and manages executor state for checkpointing."""
|
||||
|
||||
def __init__(self, id: str, agent_id: str):
|
||||
def __init__(self, id: str) -> None:
|
||||
super().__init__(id=id)
|
||||
self._agent_id = agent_id
|
||||
self._composite_number_pairs: dict[int, list[tuple[int, int]]] = {}
|
||||
|
||||
@handler
|
||||
async def submit(self, text: str, ctx: WorkflowContext[AgentExecutorRequest]) -> None:
|
||||
# Demonstrate reading shared_state written by UpperCaseExecutor.
|
||||
# Shared state survives across checkpoints and is visible to all executors.
|
||||
orig = await ctx.get_shared_state("original_input")
|
||||
upper = await ctx.get_shared_state("upper_output")
|
||||
print(f"LowerAgent (shared_state): original_input='{orig}', upper_output='{upper}'")
|
||||
async def compute(
|
||||
self,
|
||||
task: ComputeTask,
|
||||
ctx: WorkflowContext[ComputeTask, dict[int, list[tuple[int, int]]]],
|
||||
) -> None:
|
||||
"""Process the next number in the task, computing its factor pairs."""
|
||||
next_number = task.remaining_numbers.pop(0)
|
||||
|
||||
# Build a minimal, deterministic prompt for the AgentExecutor.
|
||||
prompt = f"Convert the following text to lowercase. Return ONLY the transformed text.\n\nText: {text}"
|
||||
print(f"WorkerExecutor: Computing factor pairs for {next_number}")
|
||||
pairs: list[tuple[int, int]] = []
|
||||
for i in range(1, next_number):
|
||||
if next_number % i == 0:
|
||||
pairs.append((i, next_number // i))
|
||||
self._composite_number_pairs[next_number] = pairs
|
||||
|
||||
# Send to the AgentExecutor. should_respond=True instructs the agent to produce a reply.
|
||||
await ctx.send_message(
|
||||
AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=prompt)], should_respond=True),
|
||||
target_id=self._agent_id,
|
||||
)
|
||||
if not task.remaining_numbers:
|
||||
# All numbers processed - output the results
|
||||
await ctx.yield_output(self._composite_number_pairs)
|
||||
else:
|
||||
# More numbers to process - continue with remaining task
|
||||
await ctx.send_message(task)
|
||||
|
||||
@override
|
||||
async def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
"""Save the executor's internal state for checkpointing."""
|
||||
return {"composite_number_pairs": self._composite_number_pairs}
|
||||
|
||||
class FinalizeFromAgent(Executor):
|
||||
"""Consumes the AgentExecutorResponse and yields the final result."""
|
||||
|
||||
@handler
|
||||
async def finalize(self, response: AgentExecutorResponse, ctx: WorkflowContext[Any, str]) -> None:
|
||||
result = response.agent_run_response.text or ""
|
||||
|
||||
# Persist executor-local state for auditability when inspecting checkpoints.
|
||||
prev = await ctx.get_executor_state() or {}
|
||||
count = int(prev.get("count", 0)) + 1
|
||||
await ctx.set_executor_state({
|
||||
"count": count,
|
||||
"last_output": result,
|
||||
"final": True,
|
||||
})
|
||||
|
||||
# Yield the final result so external consumers see the final value.
|
||||
await ctx.yield_output(result)
|
||||
|
||||
|
||||
class ReverseTextExecutor(Executor):
|
||||
"""Reverses the input text and persists local state."""
|
||||
|
||||
@handler
|
||||
async def reverse_text(self, text: str, ctx: WorkflowContext[str]) -> None:
|
||||
result = text[::-1]
|
||||
print(f"ReverseTextExecutor: '{text}' -> '{result}'")
|
||||
|
||||
# Persist executor-local state so checkpoint inspection can reveal progress.
|
||||
prev = await ctx.get_executor_state() or {}
|
||||
count = int(prev.get("count", 0)) + 1
|
||||
await ctx.set_executor_state({
|
||||
"count": count,
|
||||
"last_input": text,
|
||||
"last_output": result,
|
||||
})
|
||||
|
||||
# Forward the reversed string to the next stage.
|
||||
await ctx.send_message(result)
|
||||
|
||||
|
||||
def create_workflow(checkpoint_storage: FileCheckpointStorage) -> "Workflow":
|
||||
# Instantiate the pipeline executors.
|
||||
upper_case_executor = UpperCaseExecutor(id="upper-case")
|
||||
reverse_text_executor = ReverseTextExecutor(id="reverse-text")
|
||||
|
||||
# Configure the agent stage that lowercases the text.
|
||||
chat_client = AzureOpenAIChatClient(credential=AzureCliCredential())
|
||||
lower_agent = AgentExecutor(
|
||||
chat_client.create_agent(
|
||||
instructions=("You transform text to lowercase. Reply with ONLY the transformed text.")
|
||||
),
|
||||
id="lower_agent",
|
||||
)
|
||||
|
||||
# Bridge to the agent and terminalization stage.
|
||||
submit_lower = SubmitToLowerAgent(id="submit_lower", agent_id=lower_agent.id)
|
||||
finalize = FinalizeFromAgent(id="finalize")
|
||||
|
||||
# Build the workflow with checkpointing enabled.
|
||||
return (
|
||||
WorkflowBuilder(max_iterations=5)
|
||||
.add_edge(upper_case_executor, reverse_text_executor) # Uppercase -> Reverse
|
||||
.add_edge(reverse_text_executor, submit_lower) # Reverse -> Build Agent request
|
||||
.add_edge(submit_lower, lower_agent) # Submit to AgentExecutor
|
||||
.add_edge(lower_agent, finalize) # Agent output -> Finalize
|
||||
.set_start_executor(upper_case_executor) # Entry point
|
||||
.with_checkpointing(checkpoint_storage=checkpoint_storage) # Enable persistence
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
def _render_checkpoint_summary(checkpoints: list["WorkflowCheckpoint"]) -> None:
|
||||
"""Display human-friendly checkpoint metadata using framework summaries."""
|
||||
|
||||
if not checkpoints:
|
||||
return
|
||||
|
||||
print("\nCheckpoint summary:")
|
||||
for cp in sorted(checkpoints, key=lambda c: c.timestamp):
|
||||
summary = get_checkpoint_summary(cp)
|
||||
msg_count = sum(len(v) for v in cp.messages.values())
|
||||
state_keys = sorted(summary.executor_ids)
|
||||
orig = cp.shared_state.get("original_input")
|
||||
upper = cp.shared_state.get("upper_output")
|
||||
|
||||
line = (
|
||||
f"- {summary.checkpoint_id} | iter={summary.iteration_count} | messages={msg_count} | states={state_keys}"
|
||||
)
|
||||
if summary.status:
|
||||
line += f" | status={summary.status}"
|
||||
line += f" | shared_state: original_input='{orig}', upper_output='{upper}'"
|
||||
print(line)
|
||||
@override
|
||||
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
"""Restore the executor's internal state from a checkpoint."""
|
||||
self._composite_number_pairs = state.get("composite_number_pairs", {})
|
||||
|
||||
|
||||
async def main():
|
||||
# Clear existing checkpoints in this sample directory for a clean run.
|
||||
checkpoint_dir = Path(TEMP_DIR)
|
||||
for file in checkpoint_dir.glob("*.json"): # noqa: ASYNC240
|
||||
file.unlink()
|
||||
# Create workflow executors
|
||||
start_executor = StartExecutor(id="start")
|
||||
worker_executor = WorkerExecutor(id="worker")
|
||||
|
||||
# Backing store for checkpoints written by with_checkpointing.
|
||||
checkpoint_storage = FileCheckpointStorage(storage_path=TEMP_DIR)
|
||||
# Build workflow with checkpointing enabled
|
||||
workflow_builder = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(start_executor)
|
||||
.add_edge(start_executor, worker_executor)
|
||||
.add_edge(worker_executor, worker_executor) # Self-loop for iterative processing
|
||||
)
|
||||
checkpoint_storage = InMemoryCheckpointStorage()
|
||||
workflow_builder = workflow_builder.with_checkpointing(checkpoint_storage=checkpoint_storage)
|
||||
|
||||
workflow = create_workflow(checkpoint_storage=checkpoint_storage)
|
||||
# Run workflow with automatic checkpoint recovery
|
||||
latest_checkpoint: WorkflowCheckpoint | None = None
|
||||
while True:
|
||||
workflow = workflow_builder.build()
|
||||
|
||||
# Run the full workflow once and observe events as they stream.
|
||||
print("Running workflow with initial message...")
|
||||
async for event in workflow.run_stream(message="hello world"):
|
||||
print(f"Event: {event}")
|
||||
# Start from checkpoint or fresh execution
|
||||
print(f"\n** Workflow {workflow.id} started **")
|
||||
event_stream = (
|
||||
workflow.run_stream(message=10)
|
||||
if latest_checkpoint is None
|
||||
else workflow.run_stream(checkpoint_id=latest_checkpoint.checkpoint_id)
|
||||
)
|
||||
|
||||
# Inspect checkpoints written during the run.
|
||||
all_checkpoints = await checkpoint_storage.list_checkpoints()
|
||||
if not all_checkpoints:
|
||||
print("No checkpoints found!")
|
||||
return
|
||||
|
||||
# All checkpoints created by this run share the same workflow_id.
|
||||
workflow_id = all_checkpoints[0].workflow_id
|
||||
|
||||
_render_checkpoint_summary(all_checkpoints)
|
||||
|
||||
# Offer an interactive selection of checkpoints to resume from.
|
||||
sorted_cps = sorted([cp for cp in all_checkpoints if cp.workflow_id == workflow_id], key=lambda c: c.timestamp)
|
||||
|
||||
print("\nAvailable checkpoints to resume from:")
|
||||
for idx, cp in enumerate(sorted_cps):
|
||||
summary = get_checkpoint_summary(cp)
|
||||
line = f" [{idx}] id={summary.checkpoint_id} iter={summary.iteration_count}"
|
||||
if summary.status:
|
||||
line += f" status={summary.status}"
|
||||
msg_count = sum(len(v) for v in cp.messages.values())
|
||||
line += f" messages={msg_count}"
|
||||
print(line)
|
||||
|
||||
user_input = input( # noqa: ASYNC250
|
||||
"\nEnter checkpoint index (or paste checkpoint id) to resume from, or press Enter to skip resume: "
|
||||
).strip()
|
||||
|
||||
if not user_input:
|
||||
print("No checkpoint selected. Exiting without resuming.")
|
||||
return
|
||||
|
||||
chosen_cp_id: str | None = None
|
||||
|
||||
# Try as index first
|
||||
if user_input.isdigit():
|
||||
idx = int(user_input)
|
||||
if 0 <= idx < len(sorted_cps):
|
||||
chosen_cp_id = sorted_cps[idx].checkpoint_id
|
||||
# Fall back to direct id match
|
||||
if chosen_cp_id is None:
|
||||
for cp in sorted_cps:
|
||||
if cp.checkpoint_id.startswith(user_input): # allow prefix match for convenience
|
||||
chosen_cp_id = cp.checkpoint_id
|
||||
output: str | None = None
|
||||
async for event in event_stream:
|
||||
if isinstance(event, WorkflowOutputEvent):
|
||||
output = event.data
|
||||
break
|
||||
if isinstance(event, SuperStepCompletedEvent) and random() < 0.5:
|
||||
# Randomly simulate system interruptions
|
||||
# The `SuperStepCompletedEvent` ensures we only interrupt after
|
||||
# the current super-step is fully complete and checkpointed.
|
||||
# If we interrupt mid-step, the workflow may resume from an earlier point.
|
||||
print("\n** Simulating workflow interruption. Stopping execution. **")
|
||||
break
|
||||
|
||||
if chosen_cp_id is None:
|
||||
print("Input did not match any checkpoint. Exiting without resuming.")
|
||||
return
|
||||
# Find the latest checkpoint to resume from
|
||||
all_checkpoints = await checkpoint_storage.list_checkpoints()
|
||||
if not all_checkpoints:
|
||||
raise RuntimeError("No checkpoints available to resume from.")
|
||||
latest_checkpoint = all_checkpoints[-1]
|
||||
print(
|
||||
f"Checkpoint {latest_checkpoint.checkpoint_id}: "
|
||||
f"(iter={latest_checkpoint.iteration_count}, messages={latest_checkpoint.messages})"
|
||||
)
|
||||
|
||||
# You can reuse the same workflow graph definition and resume from a prior checkpoint.
|
||||
# This second workflow instance does not enable checkpointing to show that resumption
|
||||
# reads from stored state but need not write new checkpoints.
|
||||
new_workflow = create_workflow(checkpoint_storage=checkpoint_storage)
|
||||
|
||||
print(f"\nResuming from checkpoint: {chosen_cp_id}")
|
||||
async for event in new_workflow.run_stream(checkpoint_id=chosen_cp_id, checkpoint_storage=checkpoint_storage):
|
||||
print(f"Resumed Event: {event}")
|
||||
|
||||
"""
|
||||
Sample Output:
|
||||
|
||||
Running workflow with initial message...
|
||||
UpperCaseExecutor: 'hello world' -> 'HELLO WORLD'
|
||||
Event: ExecutorInvokeEvent(executor_id=upper_case_executor)
|
||||
Event: ExecutorCompletedEvent(executor_id=upper_case_executor)
|
||||
ReverseTextExecutor: 'HELLO WORLD' -> 'DLROW OLLEH'
|
||||
Event: ExecutorInvokeEvent(executor_id=reverse_text_executor)
|
||||
Event: ExecutorCompletedEvent(executor_id=reverse_text_executor)
|
||||
LowerAgent (shared_state): original_input='hello world', upper_output='HELLO WORLD'
|
||||
Event: ExecutorInvokeEvent(executor_id=submit_lower)
|
||||
Event: ExecutorInvokeEvent(executor_id=lower_agent)
|
||||
Event: ExecutorInvokeEvent(executor_id=finalize)
|
||||
|
||||
Checkpoint summary:
|
||||
- dfc63e72-8e8d-454f-9b6d-0d740b9062e6 | label='after_initial_execution' | iter=0 | messages=1 | states=['upper_case_executor'] | shared_state: original_input='hello world', upper_output='HELLO WORLD'
|
||||
- a78c345a-e5d9-45ba-82c0-cb725452d91b | label='superstep_1' | iter=1 | messages=1 | states=['reverse_text_executor', 'upper_case_executor'] | shared_state: original_input='hello world', upper_output='HELLO WORLD'
|
||||
- 637c1dbd-a525-4404-9583-da03980537a2 | label='superstep_2' | iter=2 | messages=0 | states=['finalize', 'lower_agent', 'reverse_text_executor', 'submit_lower', 'upper_case_executor'] | shared_state: original_input='hello world', upper_output='HELLO WORLD'
|
||||
|
||||
Available checkpoints to resume from:
|
||||
[0] id=dfc63e72-... iter=0 messages=1 label='after_initial_execution'
|
||||
[1] id=a78c345a-... iter=1 messages=1 label='superstep_1'
|
||||
[2] id=637c1dbd-... iter=2 messages=0 label='superstep_2'
|
||||
|
||||
Enter checkpoint index (or paste checkpoint id) to resume from, or press Enter to skip resume: 1
|
||||
|
||||
Resuming from checkpoint: a78c345a-e5d9-45ba-82c0-cb725452d91b
|
||||
LowerAgent (shared_state): original_input='hello world', upper_output='HELLO WORLD'
|
||||
Resumed Event: ExecutorInvokeEvent(executor_id=submit_lower)
|
||||
Resumed Event: ExecutorInvokeEvent(executor_id=lower_agent)
|
||||
Resumed Event: ExecutorInvokeEvent(executor_id=finalize)
|
||||
""" # noqa: E501
|
||||
if output is not None:
|
||||
print(f"\nWorkflow completed successfully with output: {output}")
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -7,6 +7,7 @@ import uuid
|
||||
from dataclasses import dataclass, field, replace
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, override
|
||||
|
||||
from agent_framework import (
|
||||
Executor,
|
||||
@@ -205,6 +206,8 @@ class LaunchCoordinator(Executor):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(id="launch_coordinator")
|
||||
# Track pending requests to match responses
|
||||
self._pending_requests: dict[str, SubWorkflowRequestMessage] = {}
|
||||
|
||||
@handler
|
||||
async def kick_off(self, topic: str, ctx: WorkflowContext[DraftTask]) -> None:
|
||||
@@ -244,11 +247,9 @@ class LaunchCoordinator(Executor):
|
||||
if not isinstance(request.source_event.data, ReviewRequest):
|
||||
raise TypeError(f"Expected 'ReviewRequest', got {type(request.source_event.data)}")
|
||||
|
||||
# Record the request to response matching
|
||||
# Record the request for response matching
|
||||
review_request = request.source_event.data
|
||||
executor_state = await ctx.get_executor_state() or {}
|
||||
executor_state[review_request.id] = request
|
||||
await ctx.set_executor_state(executor_state)
|
||||
self._pending_requests[review_request.id] = request
|
||||
|
||||
# Send the request without modification
|
||||
await ctx.request_info(request_data=review_request, response_type=str)
|
||||
@@ -265,17 +266,25 @@ class LaunchCoordinator(Executor):
|
||||
Note that the response must be sent back using SubWorkflowResponseMessage to route
|
||||
the response back to the sub-workflow.
|
||||
"""
|
||||
executor_state = await ctx.get_executor_state() or {}
|
||||
request_message = executor_state.pop(original_request.id, None)
|
||||
|
||||
# Save the executor state back to the context
|
||||
await ctx.set_executor_state(executor_state)
|
||||
request_message = self._pending_requests.pop(original_request.id, None)
|
||||
|
||||
if request_message is None:
|
||||
raise ValueError("No matching pending request found for the resource response")
|
||||
|
||||
await ctx.send_message(request_message.create_response(response))
|
||||
|
||||
@override
|
||||
async def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
"""Capture any additional state needed for checkpointing."""
|
||||
return {
|
||||
"pending_requests": self._pending_requests,
|
||||
}
|
||||
|
||||
@override
|
||||
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
"""Restore any additional state needed from checkpointing."""
|
||||
self._pending_requests = state.get("pending_requests", {})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workflow construction helpers
|
||||
@@ -356,9 +365,7 @@ async def main() -> None:
|
||||
workflow2 = build_parent_workflow(storage)
|
||||
|
||||
request_info_event: RequestInfoEvent | None = None
|
||||
async for event in workflow2.run_stream(
|
||||
resume_checkpoint.checkpoint_id,
|
||||
):
|
||||
async for event in workflow2.run_stream(checkpoint_id=resume_checkpoint.checkpoint_id):
|
||||
if isinstance(event, RequestInfoEvent):
|
||||
request_info_event = event
|
||||
|
||||
|
||||
Reference in New Issue
Block a user