mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add checkpoint save and restore hooks to executor (#2097)
* Add checkpoint hooks * Deprecate get_executor_state and set_executor_state * Fix tests and samples * Add doc strings * Add sample * Fix import * Address comments and fix tests * Address comments * conditional import
This commit is contained in:
committed by
GitHub
Unverified
parent
132597957a
commit
c361ad8d33
@@ -37,6 +37,8 @@ from ._events import (
|
||||
ExecutorFailedEvent,
|
||||
ExecutorInvokedEvent,
|
||||
RequestInfoEvent,
|
||||
SuperStepCompletedEvent,
|
||||
SuperStepStartedEvent,
|
||||
WorkflowErrorDetails,
|
||||
WorkflowEvent,
|
||||
WorkflowEventSource,
|
||||
@@ -152,6 +154,8 @@ __all__ = [
|
||||
"StandardMagenticManager",
|
||||
"SubWorkflowRequestMessage",
|
||||
"SubWorkflowResponseMessage",
|
||||
"SuperStepCompletedEvent",
|
||||
"SuperStepStartedEvent",
|
||||
"SwitchCaseEdgeGroup",
|
||||
"SwitchCaseEdgeGroupCase",
|
||||
"SwitchCaseEdgeGroupDefault",
|
||||
|
||||
@@ -35,6 +35,8 @@ from ._events import (
|
||||
ExecutorFailedEvent,
|
||||
ExecutorInvokedEvent,
|
||||
RequestInfoEvent,
|
||||
SuperStepCompletedEvent,
|
||||
SuperStepStartedEvent,
|
||||
WorkflowErrorDetails,
|
||||
WorkflowEvent,
|
||||
WorkflowEventSource,
|
||||
@@ -148,6 +150,8 @@ __all__ = [
|
||||
"StandardMagenticManager",
|
||||
"SubWorkflowRequestMessage",
|
||||
"SubWorkflowResponseMessage",
|
||||
"SuperStepCompletedEvent",
|
||||
"SuperStepStartedEvent",
|
||||
"SwitchCaseEdgeGroup",
|
||||
"SwitchCaseEdgeGroupCase",
|
||||
"SwitchCaseEdgeGroupDefault",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -20,6 +21,11 @@ from ._message_utils import normalize_messages_input
|
||||
from ._request_info_mixin import response_handler
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -179,7 +185,8 @@ class AgentExecutor(Executor):
|
||||
self._pending_responses_to_agent.clear()
|
||||
await self._run_agent_and_emit(ctx)
|
||||
|
||||
async def snapshot_state(self) -> dict[str, Any]:
|
||||
@override
|
||||
async def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
"""Capture current executor state for checkpointing.
|
||||
|
||||
NOTE: if the thread storage is on the server side, the full thread state
|
||||
@@ -196,9 +203,6 @@ class AgentExecutor(Executor):
|
||||
client_module = self._agent.chat_client.__class__.__module__
|
||||
|
||||
if client_class_name == "AzureAIAgentClient" and "azure_ai" in client_module:
|
||||
# TODO(TaoChenOSU): update this warning when we surface the hooks for
|
||||
# custom executor checkpointing.
|
||||
# https://github.com/microsoft/agent-framework/issues/1816
|
||||
logger.warning(
|
||||
"Checkpointing an AgentExecutor with AzureAIAgentClient that uses server-side threads. "
|
||||
"Currently, checkpointing does not capture messages from server-side threads "
|
||||
@@ -217,7 +221,8 @@ class AgentExecutor(Executor):
|
||||
"pending_responses_to_agent": encode_checkpoint_value(self._pending_responses_to_agent),
|
||||
}
|
||||
|
||||
async def restore_state(self, state: dict[str, Any]) -> None:
|
||||
@override
|
||||
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
"""Restore executor state from checkpoint.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from typing import Any
|
||||
@@ -13,6 +14,12 @@ from ._executor import Executor
|
||||
from ._orchestrator_helpers import ParticipantRegistry
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -210,11 +217,12 @@ class BaseGroupChatOrchestrator(Executor, ABC):
|
||||
|
||||
# State persistence (shared across all patterns)
|
||||
|
||||
def snapshot_state(self) -> dict[str, Any]:
|
||||
@override
|
||||
async def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
"""Capture current orchestrator state for checkpointing.
|
||||
|
||||
Default implementation uses OrchestrationState to serialize common state.
|
||||
Subclasses should override _snapshot_pattern_metadata() to add pattern-specific data.
|
||||
Subclasses can override this method or _snapshot_pattern_metadata() to add pattern-specific data.
|
||||
|
||||
Returns:
|
||||
Serialized state dict
|
||||
@@ -238,11 +246,12 @@ class BaseGroupChatOrchestrator(Executor, ABC):
|
||||
"""
|
||||
return {}
|
||||
|
||||
def restore_state(self, state: dict[str, Any]) -> None:
|
||||
@override
|
||||
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
"""Restore orchestrator state from checkpoint.
|
||||
|
||||
Default implementation uses OrchestrationState to deserialize common state.
|
||||
Subclasses should override _restore_pattern_metadata() to restore pattern-specific data.
|
||||
Subclasses can override this method or _restore_pattern_metadata() to restore pattern-specific data.
|
||||
|
||||
Args:
|
||||
state: Serialized state dict
|
||||
|
||||
@@ -6,9 +6,7 @@ These utilities operate on standard `list[ChatMessage]` collections and simple
|
||||
dictionary snapshots so orchestrators can share logic without new mixins.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
from collections.abc import Sequence
|
||||
|
||||
from .._types import ChatMessage
|
||||
|
||||
@@ -26,25 +24,3 @@ def ensure_author(message: ChatMessage, fallback: str) -> ChatMessage:
|
||||
"""Attach `fallback` author if message is missing `author_name`."""
|
||||
message.author_name = message.author_name or fallback
|
||||
return message
|
||||
|
||||
|
||||
def snapshot_state(conversation: Sequence[ChatMessage]) -> dict[str, Any]:
|
||||
"""Build an immutable snapshot for checkpoint storage."""
|
||||
if hasattr(conversation, "to_dict"):
|
||||
result = conversation.to_dict() # type: ignore[attr-defined]
|
||||
if isinstance(result, dict):
|
||||
return result # type: ignore[return-value]
|
||||
if isinstance(result, Mapping):
|
||||
return dict(result) # type: ignore[arg-type]
|
||||
serialisable: list[dict[str, Any]] = []
|
||||
for message in conversation:
|
||||
if hasattr(message, "to_dict") and callable(message.to_dict): # type: ignore[attr-defined]
|
||||
msg_dict = message.to_dict() # type: ignore[attr-defined]
|
||||
serialisable.append(dict(msg_dict) if isinstance(msg_dict, Mapping) else msg_dict) # type: ignore[arg-type]
|
||||
elif hasattr(message, "to_json") and callable(message.to_json): # type: ignore[attr-defined]
|
||||
json_payload = message.to_json() # type: ignore[attr-defined]
|
||||
parsed = json.loads(json_payload) if isinstance(json_payload, str) else json_payload
|
||||
serialisable.append(dict(parsed) if isinstance(parsed, Mapping) else parsed) # type: ignore[arg-type]
|
||||
else:
|
||||
serialisable.append(dict(getattr(message, "__dict__", {}))) # type: ignore[arg-type]
|
||||
return {"messages": serialisable}
|
||||
|
||||
@@ -294,6 +294,36 @@ class WorkflowOutputEvent(WorkflowEvent):
|
||||
return f"{self.__class__.__name__}(data={self.data}, source_executor_id={self.source_executor_id})"
|
||||
|
||||
|
||||
class SuperStepEvent(WorkflowEvent):
|
||||
"""Event triggered when a superstep starts or ends."""
|
||||
|
||||
def __init__(self, iteration: int, data: Any | None = None):
|
||||
"""Initialize the superstep event.
|
||||
|
||||
Args:
|
||||
iteration: The number of the superstep (1-based index).
|
||||
data: Optional data associated with the superstep event.
|
||||
"""
|
||||
super().__init__(data)
|
||||
self.iteration = iteration
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return a string representation of the superstep event."""
|
||||
return f"{self.__class__.__name__}(iteration={self.iteration}, data={self.data})"
|
||||
|
||||
|
||||
class SuperStepStartedEvent(SuperStepEvent):
|
||||
"""Event triggered when a superstep starts."""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class SuperStepCompletedEvent(SuperStepEvent):
|
||||
"""Event triggered when a superstep ends."""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class ExecutorEvent(WorkflowEvent):
|
||||
"""Base class for executor events."""
|
||||
|
||||
@@ -310,17 +340,13 @@ class ExecutorEvent(WorkflowEvent):
|
||||
class ExecutorInvokedEvent(ExecutorEvent):
|
||||
"""Event triggered when an executor handler is invoked."""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return a string representation of the executor handler invoke event."""
|
||||
return f"{self.__class__.__name__}(executor_id={self.executor_id}, data={self.data})"
|
||||
...
|
||||
|
||||
|
||||
class ExecutorCompletedEvent(ExecutorEvent):
|
||||
"""Event triggered when an executor handler is completed."""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return a string representation of the executor handler complete event."""
|
||||
return f"{self.__class__.__name__}(executor_id={self.executor_id}, data={self.data})"
|
||||
...
|
||||
|
||||
|
||||
class ExecutorFailedEvent(ExecutorEvent):
|
||||
|
||||
@@ -155,6 +155,11 @@ class Executor(RequestInfoMixin, DictConvertible):
|
||||
that parent workflows can intercept. See WorkflowExecutor documentation for details on
|
||||
workflow composition patterns and request/response handling.
|
||||
|
||||
## State Management
|
||||
Executors can contain states that persist across workflow runs and checkpoints. Override the
|
||||
`on_checkpoint_save` and `on_checkpoint_restore` methods to implement custom state
|
||||
serialization and restoration logic.
|
||||
|
||||
## Implementation Notes
|
||||
- Do not call `execute()` directly - it's invoked by the workflow engine
|
||||
- Do not override `execute()` - define handlers using decorators instead
|
||||
@@ -460,6 +465,32 @@ class Executor(RequestInfoMixin, DictConvertible):
|
||||
return self._handlers[message_type]
|
||||
raise RuntimeError(f"Executor {self.__class__.__name__} cannot handle message of type {type(message)}.")
|
||||
|
||||
async def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
"""Hook called when the workflow is being saved to a checkpoint.
|
||||
|
||||
Override this method in subclasses to implement custom logic that should
|
||||
return state to be saved in the checkpoint.
|
||||
|
||||
The returned state dictionary will be passed to `on_checkpoint_restore`
|
||||
when the workflow is restored from the checkpoint. The dictionary should
|
||||
only contain JSON-serializable data.
|
||||
|
||||
Returns:
|
||||
A state dictionary to be saved during checkpointing.
|
||||
"""
|
||||
return {}
|
||||
|
||||
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
"""Hook called when the workflow is restored from a checkpoint.
|
||||
|
||||
Override this method in subclasses to implement custom logic that should
|
||||
run when the workflow is restored from a checkpoint.
|
||||
|
||||
Args:
|
||||
state: The state dictionary that was saved during checkpointing.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
# endregion: Executor
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ Key properties:
|
||||
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable, Mapping, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
@@ -50,6 +51,12 @@ from ._workflow import Workflow
|
||||
from ._workflow_builder import WorkflowBuilder
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -307,15 +314,6 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
ctx: WorkflowContext[AgentExecutorRequest | list[ChatMessage], list[ChatMessage] | _ConversationForUserInput],
|
||||
) -> None:
|
||||
"""Process an agent's response and determine whether to route, request input, or terminate."""
|
||||
# Hydrate coordinator state (and detect new run) using checkpointable executor state
|
||||
state = await ctx.get_executor_state()
|
||||
if not state:
|
||||
self._clear_conversation()
|
||||
elif not self._get_conversation():
|
||||
restored = self._restore_conversation_from_state(state)
|
||||
if restored:
|
||||
self._conversation = list(restored)
|
||||
|
||||
source = ctx.get_source_executor_id()
|
||||
is_starting_agent = source == self._starting_agent_id
|
||||
|
||||
@@ -343,7 +341,7 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
# Update current agent when handoff occurs
|
||||
self._current_agent_id = target
|
||||
logger.info(f"Handoff detected: {source} -> {target}. Routing control to specialist '{target}'.")
|
||||
await self._persist_state(ctx)
|
||||
|
||||
# Clean tool-related content before sending to next agent
|
||||
cleaned = clean_conversation_for_handoff(conversation)
|
||||
request = AgentExecutorRequest(messages=cleaned, should_respond=True)
|
||||
@@ -360,7 +358,6 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
f"Agent '{source}' responded without handoff. "
|
||||
f"Requesting user input. Return-to-previous: {self._return_to_previous}"
|
||||
)
|
||||
await self._persist_state(ctx)
|
||||
|
||||
if await self._check_termination():
|
||||
# Clean the output conversation for display
|
||||
@@ -388,7 +385,6 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
"""Receive full conversation with new user input from gateway, update history, trim for agent."""
|
||||
# Update authoritative conversation
|
||||
self._conversation = list(message.full_conversation)
|
||||
await self._persist_state(ctx)
|
||||
|
||||
# Check termination before sending to agent
|
||||
if await self._check_termination():
|
||||
@@ -473,11 +469,7 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
)
|
||||
return list(conversation)
|
||||
|
||||
async def _persist_state(self, ctx: WorkflowContext[Any, Any]) -> None:
|
||||
"""Store authoritative conversation snapshot without losing rich metadata."""
|
||||
state_payload = self.snapshot_state()
|
||||
await ctx.set_executor_state(state_payload)
|
||||
|
||||
@override
|
||||
def _snapshot_pattern_metadata(self) -> dict[str, Any]:
|
||||
"""Serialize pattern-specific state.
|
||||
|
||||
@@ -492,6 +484,7 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
}
|
||||
return {}
|
||||
|
||||
@override
|
||||
def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None:
|
||||
"""Restore pattern-specific state.
|
||||
|
||||
@@ -503,17 +496,6 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
if self._return_to_previous and "current_agent_id" in metadata:
|
||||
self._current_agent_id = metadata["current_agent_id"]
|
||||
|
||||
def _restore_conversation_from_state(self, state: Mapping[str, Any]) -> list[ChatMessage]:
|
||||
"""Rehydrate the coordinator's conversation history from checkpointed state.
|
||||
|
||||
DEPRECATED: Use restore_state() instead. Kept for backward compatibility.
|
||||
"""
|
||||
from ._orchestration_state import OrchestrationState
|
||||
|
||||
orch_state_dict = {"conversation": state.get("full_conversation", state.get("conversation", []))}
|
||||
temp_state = OrchestrationState.from_dict(orch_state_dict)
|
||||
return list(temp_state.conversation)
|
||||
|
||||
def _apply_response_metadata(self, conversation: list[ChatMessage], agent_response: AgentRunResponse) -> None:
|
||||
"""Merge top-level response metadata into the latest assistant message."""
|
||||
if not agent_response.additional_properties:
|
||||
|
||||
@@ -45,9 +45,15 @@ from ._workflow import Workflow, WorkflowRunResult
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self # pragma: no cover
|
||||
from typing import Self
|
||||
else:
|
||||
from typing_extensions import Self # pragma: no cover
|
||||
from typing_extensions import Self
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -673,11 +679,11 @@ class MagenticManagerBase(ABC):
|
||||
"""Prepare the final answer."""
|
||||
...
|
||||
|
||||
def snapshot_state(self) -> dict[str, Any]:
|
||||
def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
"""Serialize runtime state for checkpointing."""
|
||||
return {}
|
||||
|
||||
def restore_state(self, state: dict[str, Any]) -> None:
|
||||
def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
"""Restore runtime state from checkpoint data."""
|
||||
return
|
||||
|
||||
@@ -695,22 +701,6 @@ class StandardMagenticManager(MagenticManagerBase):
|
||||
|
||||
task_ledger: _MagenticTaskLedger | None
|
||||
|
||||
def snapshot_state(self) -> dict[str, Any]:
|
||||
state = super().snapshot_state()
|
||||
if self.task_ledger is not None:
|
||||
state = dict(state)
|
||||
state["task_ledger"] = self.task_ledger.to_dict()
|
||||
return state
|
||||
|
||||
def restore_state(self, state: dict[str, Any]) -> None:
|
||||
super().restore_state(state)
|
||||
ledger = state.get("task_ledger")
|
||||
if ledger is not None:
|
||||
try:
|
||||
self.task_ledger = _MagenticTaskLedger.from_dict(ledger)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
logger.warning("Failed to restore manager task ledger from checkpoint state")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chat_client: ChatClientProtocol,
|
||||
@@ -940,6 +930,22 @@ class StandardMagenticManager(MagenticManagerBase):
|
||||
author_name=response.author_name or MAGENTIC_MANAGER_NAME,
|
||||
)
|
||||
|
||||
@override
|
||||
def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
state: dict[str, Any] = {}
|
||||
if self.task_ledger is not None:
|
||||
state["task_ledger"] = self.task_ledger.to_dict()
|
||||
return state
|
||||
|
||||
@override
|
||||
def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
ledger = state.get("task_ledger")
|
||||
if ledger is not None:
|
||||
try:
|
||||
self.task_ledger = _MagenticTaskLedger.from_dict(ledger)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
logger.warning("Failed to restore manager task ledger from checkpoint state")
|
||||
|
||||
|
||||
# endregion Magentic Manager
|
||||
|
||||
@@ -997,7 +1003,6 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
|
||||
# Terminal state marker to stop further processing after completion/limits
|
||||
self._terminated = False
|
||||
# Tracks whether checkpoint state has been applied for this run
|
||||
self._state_restored = False
|
||||
|
||||
def _get_author_name(self) -> str:
|
||||
"""Get the magentic manager name for orchestrator-generated messages."""
|
||||
@@ -1036,7 +1041,8 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
|
||||
)
|
||||
await ctx.add_event(event)
|
||||
|
||||
def snapshot_state(self) -> dict[str, Any]:
|
||||
@override
|
||||
async def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
"""Capture current orchestrator state for checkpointing.
|
||||
|
||||
Uses OrchestrationState for structure but maintains Magentic's complex metadata
|
||||
@@ -1055,14 +1061,16 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
|
||||
state["magentic_context"] = self._context.to_dict()
|
||||
if self._task_ledger is not None:
|
||||
state["task_ledger"] = _message_to_payload(self._task_ledger)
|
||||
manager_state: dict[str, Any] | None = None
|
||||
with contextlib.suppress(Exception):
|
||||
manager_state = self._manager.snapshot_state()
|
||||
if manager_state:
|
||||
state["manager_state"] = manager_state
|
||||
|
||||
try:
|
||||
state["manager_state"] = self._manager.on_checkpoint_save()
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to save manager state for checkpoint: %s\nSkipping...", exc)
|
||||
|
||||
return state
|
||||
|
||||
def restore_state(self, state: dict[str, Any]) -> None:
|
||||
@override
|
||||
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
"""Restore orchestrator state from checkpoint.
|
||||
|
||||
Maintains backward compatibility with existing Magentic checkpoints
|
||||
@@ -1112,7 +1120,7 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
|
||||
manager_state = state.get("manager_state")
|
||||
if manager_state is not None:
|
||||
try:
|
||||
self._manager.restore_state(manager_state)
|
||||
self._manager.on_checkpoint_restore(manager_state)
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.warning("Failed to restore manager state: %s", exc)
|
||||
|
||||
@@ -1142,49 +1150,6 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
|
||||
for name, description in expected.items():
|
||||
restored[name] = description
|
||||
|
||||
def _snapshot_pattern_metadata(self) -> dict[str, Any]:
|
||||
"""Serialize pattern-specific state.
|
||||
|
||||
Magentic uses custom snapshot_state() instead of base class hooks.
|
||||
This method exists to satisfy the base class contract.
|
||||
|
||||
Returns:
|
||||
Empty dict (Magentic manages its own state)
|
||||
"""
|
||||
return {}
|
||||
|
||||
def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None:
|
||||
"""Restore pattern-specific state.
|
||||
|
||||
Magentic uses custom restore_state() instead of base class hooks.
|
||||
This method exists to satisfy the base class contract.
|
||||
|
||||
Args:
|
||||
metadata: Pattern-specific state dict (ignored)
|
||||
"""
|
||||
pass
|
||||
|
||||
async def _ensure_state_restored(
|
||||
self,
|
||||
context: WorkflowContext[Any, Any],
|
||||
) -> None:
|
||||
if self._state_restored and self._context is not None:
|
||||
return
|
||||
state = await context.get_executor_state()
|
||||
if not state:
|
||||
self._state_restored = True
|
||||
return
|
||||
if not isinstance(state, dict):
|
||||
self._state_restored = True
|
||||
return
|
||||
try:
|
||||
self.restore_state(state)
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.warning("Magentic Orchestrator: Failed to apply checkpoint state: %s", exc, exc_info=True)
|
||||
raise
|
||||
else:
|
||||
self._state_restored = True
|
||||
|
||||
@handler
|
||||
async def handle_start_message(
|
||||
self,
|
||||
@@ -1204,7 +1169,7 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
|
||||
)
|
||||
if message.messages:
|
||||
self._context.chat_history.extend(message.messages)
|
||||
self._state_restored = True
|
||||
|
||||
# Non-streaming callback for the orchestrator receipt of the task
|
||||
await self._emit_orchestrator_message(context, message.task, ORCH_MSG_KIND_USER_TASK)
|
||||
|
||||
@@ -1269,7 +1234,7 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
|
||||
"""Handle responses from agents."""
|
||||
if getattr(self, "_terminated", False):
|
||||
return
|
||||
await self._ensure_state_restored(context)
|
||||
|
||||
if self._context is None:
|
||||
raise RuntimeError("Magentic Orchestrator: Received response but not initialized")
|
||||
|
||||
@@ -1301,7 +1266,7 @@ class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator):
|
||||
) -> None:
|
||||
if getattr(self, "_terminated", False):
|
||||
return
|
||||
await self._ensure_state_restored(context)
|
||||
|
||||
if self._context is None:
|
||||
return
|
||||
|
||||
@@ -1636,9 +1601,9 @@ class MagenticAgentExecutor(Executor):
|
||||
self._agent = agent
|
||||
self._agent_id = agent_id
|
||||
self._chat_history: list[ChatMessage] = []
|
||||
self._state_restored = False
|
||||
|
||||
def snapshot_state(self) -> dict[str, Any]:
|
||||
@override
|
||||
async def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
"""Capture current executor state for checkpointing.
|
||||
|
||||
Returns:
|
||||
@@ -1650,7 +1615,8 @@ class MagenticAgentExecutor(Executor):
|
||||
"chat_history": encode_chat_messages(self._chat_history),
|
||||
}
|
||||
|
||||
def restore_state(self, state: dict[str, Any]) -> None:
|
||||
@override
|
||||
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
"""Restore executor state from checkpoint.
|
||||
|
||||
Args:
|
||||
@@ -1668,24 +1634,6 @@ class MagenticAgentExecutor(Executor):
|
||||
else:
|
||||
self._chat_history = []
|
||||
|
||||
async def _ensure_state_restored(self, context: WorkflowContext[Any, Any]) -> None:
|
||||
if self._state_restored and self._chat_history:
|
||||
return
|
||||
state = await context.get_executor_state()
|
||||
if not state:
|
||||
self._state_restored = True
|
||||
return
|
||||
if not isinstance(state, dict):
|
||||
self._state_restored = True
|
||||
return
|
||||
try:
|
||||
self.restore_state(state)
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.warning("Agent %s: Failed to apply checkpoint state: %s", self._agent_id, exc, exc_info=True)
|
||||
raise
|
||||
else:
|
||||
self._state_restored = True
|
||||
|
||||
@handler
|
||||
async def handle_response_message(
|
||||
self, message: _MagenticResponseMessage, context: WorkflowContext[_MagenticResponseMessage]
|
||||
@@ -1693,8 +1641,6 @@ class MagenticAgentExecutor(Executor):
|
||||
"""Handle response message (task ledger broadcast)."""
|
||||
logger.debug("Agent %s: Received response message", self._agent_id)
|
||||
|
||||
await self._ensure_state_restored(context)
|
||||
|
||||
# Check if this message is intended for this agent
|
||||
if message.target_agent is not None and message.target_agent != self._agent_id and not message.broadcast:
|
||||
# Message is targeted to a different agent, ignore it
|
||||
@@ -1735,8 +1681,6 @@ class MagenticAgentExecutor(Executor):
|
||||
|
||||
logger.info("Agent %s: Received request to respond", self._agent_id)
|
||||
|
||||
await self._ensure_state_restored(context)
|
||||
|
||||
# Add persona adoption message with appropriate role
|
||||
persona_role = self._get_persona_adoption_role()
|
||||
persona_msg = ChatMessage(
|
||||
@@ -1783,7 +1727,6 @@ class MagenticAgentExecutor(Executor):
|
||||
"""Reset the internal chat history of the agent (internal operation)."""
|
||||
logger.debug("Agent %s: Resetting chat history", self._agent_id)
|
||||
self._chat_history.clear()
|
||||
self._state_restored = True
|
||||
|
||||
async def _emit_agent_delta_event(
|
||||
self,
|
||||
|
||||
@@ -11,7 +11,7 @@ from ._checkpoint_encoding import DATACLASS_MARKER, MODEL_MARKER, decode_checkpo
|
||||
from ._const import EXECUTOR_STATE_KEY
|
||||
from ._edge import EdgeGroup
|
||||
from ._edge_runner import EdgeRunner, create_edge_runner
|
||||
from ._events import WorkflowEvent
|
||||
from ._events import SuperStepCompletedEvent, SuperStepStartedEvent, WorkflowEvent
|
||||
from ._executor import Executor
|
||||
from ._runner_context import (
|
||||
Message,
|
||||
@@ -92,6 +92,7 @@ class Runner:
|
||||
|
||||
while self._iteration < self._max_iterations:
|
||||
logger.info(f"Starting superstep {self._iteration + 1}")
|
||||
yield SuperStepStartedEvent(iteration=self._iteration + 1)
|
||||
|
||||
# Run iteration concurrently with live event streaming: we poll
|
||||
# for new events while the iteration coroutine progresses.
|
||||
@@ -126,6 +127,9 @@ class Runner:
|
||||
# Create checkpoint after each superstep iteration
|
||||
await self._create_checkpoint_if_enabled(f"superstep_{self._iteration}")
|
||||
|
||||
yield SuperStepCompletedEvent(iteration=self._iteration)
|
||||
|
||||
# Check for convergence: no more messages to process
|
||||
if not await self._ctx.has_messages():
|
||||
break
|
||||
|
||||
@@ -183,8 +187,8 @@ class Runner:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Auto-snapshot executor states
|
||||
await self._auto_snapshot_executor_states()
|
||||
# Snapshot executor states
|
||||
await self._save_executor_states()
|
||||
checkpoint_category = "initial" if checkpoint_type == "after_initial_execution" else "superstep"
|
||||
metadata = {
|
||||
"superstep": self._iteration,
|
||||
@@ -203,41 +207,6 @@ class Runner:
|
||||
logger.warning(f"Failed to create {checkpoint_type} checkpoint: {e}")
|
||||
return None
|
||||
|
||||
async def _auto_snapshot_executor_states(self) -> None:
|
||||
"""Populate executor state by calling snapshot hooks on executors if available.
|
||||
|
||||
TODO(@taochen#1614): this method is potentially problematic if executors also call
|
||||
set_executor_state on the context directly. We should clarify the intended usage
|
||||
pattern for executor state management.
|
||||
|
||||
Convention:
|
||||
- If an executor defines an async or sync method `snapshot_state(self) -> dict`, use it.
|
||||
- Else if it has a plain attribute `state` that is a dict, use that.
|
||||
Only JSON-serializable dicts should be provided by executors.
|
||||
"""
|
||||
for exec_id, executor in self._executors.items():
|
||||
state_dict: dict[str, Any] | None = None
|
||||
snapshot = getattr(executor, "snapshot_state", None)
|
||||
try:
|
||||
if callable(snapshot):
|
||||
maybe = snapshot()
|
||||
if asyncio.iscoroutine(maybe): # type: ignore[arg-type]
|
||||
maybe = await maybe # type: ignore[assignment]
|
||||
if isinstance(maybe, dict):
|
||||
state_dict = maybe # type: ignore[assignment]
|
||||
else:
|
||||
state_attr = getattr(executor, "state", None)
|
||||
if isinstance(state_attr, dict):
|
||||
state_dict = state_attr # type: ignore[assignment]
|
||||
except Exception as ex: # pragma: no cover
|
||||
logger.debug(f"Executor {exec_id} snapshot_state failed: {ex}")
|
||||
|
||||
if state_dict is not None:
|
||||
try:
|
||||
await self._set_executor_state(exec_id, state_dict)
|
||||
except Exception as ex: # pragma: no cover
|
||||
logger.debug(f"Failed to persist state for executor {exec_id}: {ex}")
|
||||
|
||||
async def restore_from_checkpoint(
|
||||
self,
|
||||
checkpoint_id: str,
|
||||
@@ -300,7 +269,65 @@ class Runner:
|
||||
logger.error(f"Failed to restore from checkpoint {checkpoint_id}: {e}")
|
||||
return False
|
||||
|
||||
async def _save_executor_states(self) -> None:
|
||||
"""Populate executor state by calling checkpoint hooks on executors.
|
||||
|
||||
Backward compatibility behavior:
|
||||
- If an executor defines an async or sync method `snapshot_state(self) -> dict`, use it.
|
||||
- Else if it has a plain attribute `state` that is a dict, use that.
|
||||
|
||||
Updated behavior:
|
||||
- Executors should implement `on_checkpoint_save(self) -> dict` to provide state.
|
||||
|
||||
This method will try the backward compatibility behavior first; if that does not yield state,
|
||||
it falls back to the updated behavior.
|
||||
|
||||
Only JSON-serializable dicts should be provided by executors.
|
||||
"""
|
||||
for exec_id, executor in self._executors.items():
|
||||
state_dict: dict[str, Any] | None = None
|
||||
# Try backward compatibility behavior first
|
||||
# TODO(@taochen): Remove backward compatibility
|
||||
snapshot = getattr(executor, "snapshot_state", None)
|
||||
try:
|
||||
if callable(snapshot):
|
||||
maybe = snapshot()
|
||||
if asyncio.iscoroutine(maybe): # type: ignore[arg-type]
|
||||
maybe = await maybe # type: ignore[assignment]
|
||||
if isinstance(maybe, dict):
|
||||
state_dict = maybe # type: ignore[assignment]
|
||||
else:
|
||||
state_attr = getattr(executor, "state", None)
|
||||
if isinstance(state_attr, dict):
|
||||
state_dict = state_attr # type: ignore[assignment]
|
||||
except Exception as ex: # pragma: no cover
|
||||
logger.debug(f"Executor {exec_id} snapshot_state failed: {ex}")
|
||||
|
||||
if state_dict is None:
|
||||
# Try the updated behavior only if backward compatibility did not yield state
|
||||
try:
|
||||
state_dict = await executor.on_checkpoint_save()
|
||||
except Exception as ex: # pragma: no cover
|
||||
raise ValueError(f"Executor {exec_id} on_checkpoint_save failed: {ex}") from ex
|
||||
|
||||
try:
|
||||
await self._set_executor_state(exec_id, state_dict)
|
||||
except Exception as ex: # pragma: no cover
|
||||
logger.debug(f"Failed to persist state for executor {exec_id}: {ex}")
|
||||
|
||||
async def _restore_executor_states(self) -> None:
|
||||
"""Restore executor state by calling restore hooks on executors.
|
||||
|
||||
Backward compatibility behavior:
|
||||
- If an executor defines an async or sync method `restore_state(self, state: dict)`, use it.
|
||||
- Else, skip restoration for that executor.
|
||||
|
||||
Updated behavior:
|
||||
- Executors should implement `on_checkpoint_restore(self, state: dict)` to restore state.
|
||||
|
||||
This method will try the backward compatibility behavior first; if that does not restore state,
|
||||
it falls back to the updated behavior.
|
||||
"""
|
||||
has_executor_states = await self._shared_state.has(EXECUTOR_STATE_KEY)
|
||||
if not has_executor_states:
|
||||
return
|
||||
@@ -309,16 +336,18 @@ class Runner:
|
||||
if not isinstance(executor_states, dict):
|
||||
raise ValueError("Executor states in shared state is not a dictionary. Unable to restore.")
|
||||
|
||||
for executor_id, state in executor_states.items():
|
||||
for executor_id, state in executor_states.items(): # pyright: ignore[reportUnknownVariableType]
|
||||
if not isinstance(executor_id, str):
|
||||
raise ValueError("Executor ID in executor states is not a string. Unable to restore.")
|
||||
if not isinstance(state, dict):
|
||||
raise ValueError(f"Executor state for {executor_id} is not a dictionary. Unable to restore.")
|
||||
if not isinstance(state, dict) or not all(isinstance(k, str) for k in state): # pyright: ignore[reportUnknownVariableType]
|
||||
raise ValueError(f"Executor state for {executor_id} is not a dict[str, Any]. Unable to restore.")
|
||||
|
||||
executor = self._executors.get(executor_id)
|
||||
if not executor:
|
||||
raise ValueError(f"Executor {executor_id} not found during state restoration.")
|
||||
|
||||
# Try backward compatibility behavior first
|
||||
# TODO(@taochen): Remove backward compatibility
|
||||
restored = False
|
||||
restore_method = getattr(executor, "restore_state", None)
|
||||
try:
|
||||
@@ -330,6 +359,14 @@ class Runner:
|
||||
except Exception as ex: # pragma: no cover - defensive
|
||||
raise ValueError(f"Executor {executor_id} restore_state failed: {ex}") from ex
|
||||
|
||||
if not restored:
|
||||
# Try the updated behavior only if backward compatibility did not restore
|
||||
try:
|
||||
await executor.on_checkpoint_restore(state) # pyright: ignore[reportUnknownArgumentType]
|
||||
restored = True
|
||||
except Exception as ex: # pragma: no cover - defensive
|
||||
raise ValueError(f"Executor {executor_id} on_checkpoint_restore failed: {ex}") from ex
|
||||
|
||||
if not restored:
|
||||
logger.debug(f"Executor {executor_id} does not support state restoration; skipping.")
|
||||
|
||||
|
||||
@@ -109,9 +109,9 @@ class Workflow(DictConvertible):
|
||||
"""A graph-based execution engine that orchestrates connected executors.
|
||||
|
||||
## Overview
|
||||
A workflow executes a directed graph of executors connected via edge groups using a Pregel-like model,
|
||||
running in supersteps until the graph becomes idle. Workflows are created using the
|
||||
WorkflowBuilder class - do not instantiate this class directly.
|
||||
A workflow executes a directed graph of executors connected via edge groups using a
|
||||
Pregel-like model, running in supersteps until the graph becomes idle. Workflows
|
||||
are created using the WorkflowBuilder class - do not instantiate this class directly.
|
||||
|
||||
## Execution Model
|
||||
Executors run in synchronized supersteps where each executor:
|
||||
@@ -142,6 +142,10 @@ class Workflow(DictConvertible):
|
||||
- HIL continuation: Provide `responses` to continue after RequestInfoExecutor requests
|
||||
- Runtime checkpointing: Provide `checkpoint_storage` to enable/override checkpointing for this run
|
||||
|
||||
## State Management
|
||||
Workflow instances contain states and states are preserved across calls to `run` and `run_stream`.
|
||||
To execute multiple independent runs, create separate Workflow instances via WorkflowBuilder.
|
||||
|
||||
## External Input Requests
|
||||
Executors within a workflow can request external input using `ctx.request_info()`:
|
||||
1. Executor calls `ctx.request_info()` to request input
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Generic, Union, cast, get_args, get_origi
|
||||
|
||||
from opentelemetry.propagate import inject
|
||||
from opentelemetry.trace import SpanKind
|
||||
from typing_extensions import Never, TypeVar
|
||||
from typing_extensions import Never, TypeVar, deprecated
|
||||
|
||||
from ..observability import OtelAttr, create_workflow_span
|
||||
from ._const import EXECUTOR_STATE_KEY
|
||||
@@ -410,6 +410,11 @@ class WorkflowContext(Generic[T_Out, T_W_Out]):
|
||||
"""Get the shared state."""
|
||||
return self._shared_state
|
||||
|
||||
@deprecated(
|
||||
"Override `on_checkpoint_save()` methods instead. "
|
||||
"For cross-executor state sharing, use set_shared_state() instead. "
|
||||
"This API will be removed after 12/01/2025."
|
||||
)
|
||||
async def set_executor_state(self, state: dict[str, Any]) -> None:
|
||||
"""Store executor state in shared state under a reserved key.
|
||||
|
||||
@@ -428,6 +433,11 @@ class WorkflowContext(Generic[T_Out, T_W_Out]):
|
||||
existing_states[self._executor_id] = state
|
||||
await self._shared_state.set(EXECUTOR_STATE_KEY, existing_states)
|
||||
|
||||
@deprecated(
|
||||
"Override `on_checkpoint_restore()` methods instead. "
|
||||
"For cross-executor state sharing, use get_shared_state() instead. "
|
||||
"This API will be removed after 12/01/2025."
|
||||
)
|
||||
async def get_executor_state(self) -> dict[str, Any] | None:
|
||||
"""Retrieve previously persisted state for this executor, if any."""
|
||||
has_existing_states = await self._shared_state.has(EXECUTOR_STATE_KEY)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import sys
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@@ -26,6 +26,12 @@ from ._typing_utils import is_instance_of
|
||||
from ._workflow import WorkflowRunResult
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -181,8 +187,7 @@ class WorkflowExecutor(Executor):
|
||||
|
||||
# Includes all sub-workflow output types
|
||||
# Plus SubWorkflowRequestMessage if sub-workflow can make requests
|
||||
output_types = workflow.output_types + [SubWorkflowRequestMessage] # if applicable
|
||||
```
|
||||
output_types = workflow.output_types + [SubWorkflowRequestMessage] # if applicable
|
||||
|
||||
## Error Handling
|
||||
WorkflowExecutor propagates sub-workflow failures:
|
||||
@@ -221,23 +226,10 @@ class WorkflowExecutor(Executor):
|
||||
|
||||
### Important Considerations
|
||||
**Shared Workflow Instance**: All concurrent executions use the same underlying workflow instance.
|
||||
For proper isolation, ensure that:
|
||||
- The wrapped workflow and its executors are stateless
|
||||
- Executors use WorkflowContext state management instead of instance variables
|
||||
- Any shared state is managed through WorkflowContext.get_shared_state/set_shared_state
|
||||
For proper isolation, ensure that the wrapped workflow and its executors are stateless.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Good: Stateless executor using context state
|
||||
class StatelessExecutor(Executor):
|
||||
@handler
|
||||
async def process(self, data: str, ctx: WorkflowContext[str]) -> None:
|
||||
# Use context state instead of instance variables
|
||||
state = await ctx.get_executor_state() or {}
|
||||
state["processed"] = data
|
||||
await ctx.set_executor_state(state)
|
||||
|
||||
|
||||
# Avoid: Stateful executor with instance variables
|
||||
class StatefulExecutor(Executor):
|
||||
def __init__(self):
|
||||
@@ -246,23 +238,23 @@ class WorkflowExecutor(Executor):
|
||||
|
||||
## Integration with Parent Workflows
|
||||
Parent workflows can intercept sub-workflow requests:
|
||||
```python
|
||||
class ParentExecutor(Executor):
|
||||
@handler
|
||||
async def handle_subworkflow_request(
|
||||
self,
|
||||
request: SubWorkflowRequestMessage,
|
||||
ctx: WorkflowContext[SubWorkflowResponseMessage],
|
||||
) -> None:
|
||||
# Handle request locally or forward to external source
|
||||
if self.can_handle_locally(request):
|
||||
# Send response back to sub-workflow
|
||||
response = request.create_response(data="local response data")
|
||||
await ctx.send_message(response, target_id=request.source_executor_id)
|
||||
else:
|
||||
# Forward to external handler
|
||||
await ctx.request_info(request.source_event, response_type=request.source_event.response_type)
|
||||
```
|
||||
|
||||
.. code-block:: python
|
||||
class ParentExecutor(Executor):
|
||||
@handler
|
||||
async def handle_subworkflow_request(
|
||||
self,
|
||||
request: SubWorkflowRequestMessage,
|
||||
ctx: WorkflowContext[SubWorkflowResponseMessage],
|
||||
) -> None:
|
||||
# Handle request locally or forward to external source
|
||||
if self.can_handle_locally(request):
|
||||
# Send response back to sub-workflow
|
||||
response = request.create_response(data="local response data")
|
||||
await ctx.send_message(response, target_id=request.source_executor_id)
|
||||
else:
|
||||
# Forward to external handler
|
||||
await ctx.request_info(request.source_event, response_type=request.source_event.response_type)
|
||||
|
||||
## Implementation Notes
|
||||
- Sub-workflows run to completion before processing their results
|
||||
@@ -296,7 +288,6 @@ class WorkflowExecutor(Executor):
|
||||
self._execution_contexts: dict[str, ExecutionContext] = {} # execution_id -> ExecutionContext
|
||||
# Map request_id to execution_id for response routing
|
||||
self._request_to_execution: dict[str, str] = {} # request_id -> execution_id
|
||||
self._state_loaded: bool = False
|
||||
|
||||
@property
|
||||
def input_types(self) -> list[type[Any]]:
|
||||
@@ -362,8 +353,6 @@ class WorkflowExecutor(Executor):
|
||||
input_data: The input data to send to the sub-workflow.
|
||||
ctx: The workflow context from the parent.
|
||||
"""
|
||||
await self._ensure_state_loaded(ctx)
|
||||
|
||||
# Create execution context for this sub-workflow run
|
||||
execution_id = str(uuid.uuid4())
|
||||
execution_context = ExecutionContext(
|
||||
@@ -405,8 +394,6 @@ class WorkflowExecutor(Executor):
|
||||
response: The response to a previous request.
|
||||
ctx: The workflow context.
|
||||
"""
|
||||
await self._ensure_state_loaded(ctx)
|
||||
|
||||
# Find the execution context for this request
|
||||
original_request = response.source_event
|
||||
execution_id = self._request_to_execution.get(original_request.request_id)
|
||||
@@ -434,8 +421,6 @@ class WorkflowExecutor(Executor):
|
||||
# Accumulate the response in this execution's context
|
||||
execution_context.collected_responses[original_request.request_id] = response.data
|
||||
|
||||
await self._persist_execution_state(ctx)
|
||||
|
||||
# Check if we have all expected responses for this execution
|
||||
if len(execution_context.collected_responses) < execution_context.expected_response_count:
|
||||
logger.debug(
|
||||
@@ -459,25 +444,20 @@ class WorkflowExecutor(Executor):
|
||||
if not execution_context.pending_requests:
|
||||
del self._execution_contexts[execution_id]
|
||||
|
||||
async def _ensure_state_loaded(self, ctx: WorkflowContext[Any]) -> None:
|
||||
if self._state_loaded:
|
||||
return
|
||||
@override
|
||||
async def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
"""Get the current state of the WorkflowExecutor for checkpointing purposes."""
|
||||
return {
|
||||
"execution_contexts": {
|
||||
execution_id: encode_checkpoint_value(execution_context)
|
||||
for execution_id, execution_context in self._execution_contexts.items()
|
||||
},
|
||||
"request_to_execution": dict(self._request_to_execution),
|
||||
}
|
||||
|
||||
state: dict[str, Any] | None = None
|
||||
try:
|
||||
state = await ctx.get_executor_state()
|
||||
except Exception:
|
||||
state = None
|
||||
|
||||
if isinstance(state, dict) and state:
|
||||
with contextlib.suppress(Exception):
|
||||
await self.restore_state(state)
|
||||
self._state_loaded = True
|
||||
else:
|
||||
self._state_loaded = True
|
||||
|
||||
async def restore_state(self, state: dict[str, Any]) -> None:
|
||||
"""Restore pending request bookkeeping from a checkpoint snapshot."""
|
||||
@override
|
||||
async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
"""Restore the WorkflowExecutor state from a checkpoint snapshot."""
|
||||
# Validate the state contains the right keys
|
||||
if "execution_contexts" not in state:
|
||||
raise KeyError("Missing 'execution_contexts' in WorkflowExecutor state.")
|
||||
@@ -529,23 +509,6 @@ class WorkflowExecutor(Executor):
|
||||
for event in request_info_events
|
||||
])
|
||||
|
||||
self._state_loaded = True
|
||||
|
||||
async def _persist_execution_state(self, ctx: WorkflowContext) -> None:
|
||||
"""Persist the state of the WorkflowExecutor for checkpointing purposes."""
|
||||
state = {
|
||||
"execution_contexts": {
|
||||
execution_id: encode_checkpoint_value(execution_context)
|
||||
for execution_id, execution_context in self._execution_contexts.items()
|
||||
},
|
||||
"request_to_execution": dict(self._request_to_execution),
|
||||
}
|
||||
|
||||
try:
|
||||
await ctx.set_executor_state(state)
|
||||
except Exception as exc: # pragma: no cover - transport specific
|
||||
logger.warning(f"WorkflowExecutor {self.id} failed to persist state: {exc}")
|
||||
|
||||
async def _process_workflow_result(
|
||||
self,
|
||||
result: WorkflowRunResult,
|
||||
@@ -635,5 +598,3 @@ class WorkflowExecutor(Executor):
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Unexpected workflow run state: {workflow_run_state}")
|
||||
|
||||
await self._persist_execution_state(ctx)
|
||||
|
||||
@@ -158,8 +158,8 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
|
||||
assert thread_messages[1].text == "Initial response 1"
|
||||
|
||||
|
||||
async def test_agent_executor_snapshot_and_restore_state_directly() -> None:
|
||||
"""Test AgentExecutor's snapshot_state and restore_state methods directly."""
|
||||
async def test_agent_executor_save_and_restore_state_directly() -> None:
|
||||
"""Test AgentExecutor's on_checkpoint_save and on_checkpoint_restore methods directly."""
|
||||
# Create agent with thread containing messages
|
||||
agent = _CountingAgent(id="direct_test_agent", name="DirectTestAgent")
|
||||
thread = AgentThread(message_store=ChatMessageStore())
|
||||
@@ -182,7 +182,7 @@ async def test_agent_executor_snapshot_and_restore_state_directly() -> None:
|
||||
executor._cache = list(cache_messages) # type: ignore[reportPrivateUsage]
|
||||
|
||||
# Snapshot the state
|
||||
state = await executor.snapshot_state() # type: ignore[reportUnknownMemberType]
|
||||
state = await executor.on_checkpoint_save()
|
||||
|
||||
# Verify snapshot contains both cache and thread
|
||||
assert "cache" in state
|
||||
@@ -206,7 +206,7 @@ async def test_agent_executor_snapshot_and_restore_state_directly() -> None:
|
||||
assert len(initial_thread_msgs) == 0
|
||||
|
||||
# Restore state
|
||||
await new_executor.restore_state(state) # type: ignore[reportUnknownMemberType]
|
||||
await new_executor.on_checkpoint_restore(state)
|
||||
|
||||
# Verify cache is restored
|
||||
restored_cache = new_executor._cache # type: ignore[reportPrivateUsage]
|
||||
|
||||
@@ -288,57 +288,6 @@ def test_build_fails_without_participants():
|
||||
HandoffBuilder().build()
|
||||
|
||||
|
||||
async def test_multiple_runs_dont_leak_conversation():
|
||||
"""Verify that running the same workflow multiple times doesn't leak conversation history."""
|
||||
triage = _RecordingAgent(name="triage", handoff_to="specialist")
|
||||
specialist = _RecordingAgent(name="specialist")
|
||||
|
||||
workflow = (
|
||||
HandoffBuilder(participants=[triage, specialist])
|
||||
.set_coordinator("triage")
|
||||
.with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 2)
|
||||
.build()
|
||||
)
|
||||
|
||||
# First run
|
||||
events = await _drain(workflow.run_stream("First run message"))
|
||||
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
|
||||
assert requests
|
||||
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Second message"}))
|
||||
outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
|
||||
assert outputs, "First run should emit output"
|
||||
|
||||
first_run_conversation = outputs[-1].data
|
||||
assert isinstance(first_run_conversation, list)
|
||||
first_run_conv_list = cast(list[ChatMessage], first_run_conversation)
|
||||
first_run_user_messages = [msg for msg in first_run_conv_list if msg.role == Role.USER]
|
||||
assert len(first_run_user_messages) == 2
|
||||
assert any("First run message" in msg.text for msg in first_run_user_messages if msg.text)
|
||||
|
||||
# Second run - should start fresh, not include first run's messages
|
||||
triage.calls.clear()
|
||||
specialist.calls.clear()
|
||||
|
||||
events = await _drain(workflow.run_stream("Second run different message"))
|
||||
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
|
||||
assert requests
|
||||
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Another message"}))
|
||||
outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
|
||||
assert outputs, "Second run should emit output"
|
||||
|
||||
second_run_conversation = outputs[-1].data
|
||||
assert isinstance(second_run_conversation, list)
|
||||
second_run_conv_list = cast(list[ChatMessage], second_run_conversation)
|
||||
second_run_user_messages = [msg for msg in second_run_conv_list if msg.role == Role.USER]
|
||||
assert len(second_run_user_messages) == 2, (
|
||||
"Second run should have exactly 2 user messages, not accumulate first run"
|
||||
)
|
||||
assert any("Second run different message" in msg.text for msg in second_run_user_messages if msg.text)
|
||||
assert not any("First run message" in msg.text for msg in second_run_user_messages if msg.text), (
|
||||
"Second run should NOT contain first run's messages"
|
||||
)
|
||||
|
||||
|
||||
async def test_handoff_async_termination_condition() -> None:
|
||||
"""Test that async termination conditions work correctly."""
|
||||
termination_call_count = 0
|
||||
@@ -585,7 +534,7 @@ async def test_return_to_previous_state_serialization():
|
||||
coordinator._current_agent_id = "specialist_a" # type: ignore[reportPrivateUsage]
|
||||
|
||||
# Snapshot the state
|
||||
state = coordinator.snapshot_state()
|
||||
state = await coordinator.on_checkpoint_save()
|
||||
|
||||
# Verify pattern metadata includes current_agent_id
|
||||
assert "metadata" in state
|
||||
@@ -603,7 +552,7 @@ async def test_return_to_previous_state_serialization():
|
||||
)
|
||||
|
||||
# Restore state
|
||||
coordinator2.restore_state(state)
|
||||
await coordinator2.on_checkpoint_restore(state)
|
||||
|
||||
# Verify current_agent_id was restored
|
||||
assert coordinator2._current_agent_id == "specialist_a", "Current agent should be restored from checkpoint" # type: ignore[reportPrivateUsage]
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import sys
|
||||
from collections.abc import AsyncIterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
@@ -42,6 +43,11 @@ from agent_framework._workflows._magentic import ( # type: ignore[reportPrivate
|
||||
_MagenticStartMessage, # type: ignore
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
def test_magentic_start_message_from_string():
|
||||
msg = _MagenticStartMessage.from_string("Do the thing")
|
||||
@@ -101,8 +107,9 @@ class FakeManager(MagenticManagerBase):
|
||||
next_speaker_name: str = "agentA"
|
||||
instruction_text: str = "Proceed with step 1"
|
||||
|
||||
def snapshot_state(self) -> dict[str, Any]:
|
||||
state = super().snapshot_state()
|
||||
@override
|
||||
def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
state = super().on_checkpoint_save()
|
||||
if self.task_ledger is not None:
|
||||
state = dict(state)
|
||||
state["task_ledger"] = {
|
||||
@@ -111,8 +118,9 @@ class FakeManager(MagenticManagerBase):
|
||||
}
|
||||
return state
|
||||
|
||||
def restore_state(self, state: dict[str, Any]) -> None:
|
||||
super().restore_state(state)
|
||||
@override
|
||||
def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
|
||||
super().on_checkpoint_restore(state)
|
||||
ledger_state = state.get("task_ledger")
|
||||
if isinstance(ledger_state, dict):
|
||||
ledger_dict = cast(dict[str, Any], ledger_state)
|
||||
@@ -185,7 +193,6 @@ async def test_standard_manager_progress_ledger_and_fallback():
|
||||
assert ledger2.is_request_satisfied.answer is False
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Response handling refactored - responses no longer passed to run_stream()")
|
||||
async def test_magentic_workflow_plan_review_approval_to_completion():
|
||||
manager = FakeManager(max_round_count=10)
|
||||
wf = (
|
||||
@@ -204,7 +211,7 @@ async def test_magentic_workflow_plan_review_approval_to_completion():
|
||||
|
||||
completed = False
|
||||
output: ChatMessage | None = None
|
||||
async for ev in wf.run_stream(
|
||||
async for ev in wf.send_responses_streaming(
|
||||
responses={req_event.request_id: MagenticPlanReviewReply(decision=MagenticPlanReviewDecision.APPROVE)}
|
||||
):
|
||||
if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE:
|
||||
@@ -218,7 +225,6 @@ async def test_magentic_workflow_plan_review_approval_to_completion():
|
||||
assert isinstance(output, ChatMessage)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Response handling refactored - responses no longer passed to run_stream()")
|
||||
async def test_magentic_plan_review_approve_with_comments_replans_and_proceeds():
|
||||
class CountingManager(FakeManager):
|
||||
# Declare as a model field so assignment is allowed under Pydantic
|
||||
@@ -250,7 +256,7 @@ async def test_magentic_plan_review_approve_with_comments_replans_and_proceeds()
|
||||
# Reply APPROVE with comments (no edited text). Expect one replan and no second review round.
|
||||
saw_second_review = False
|
||||
completed = False
|
||||
async for ev in wf.run_stream(
|
||||
async for ev in wf.send_responses_streaming(
|
||||
responses={
|
||||
req_event.request_id: MagenticPlanReviewReply(
|
||||
decision=MagenticPlanReviewDecision.APPROVE,
|
||||
@@ -298,7 +304,6 @@ async def test_magentic_orchestrator_round_limit_produces_partial_result():
|
||||
assert data.role == Role.ASSISTANT
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Response handling refactored - send_responses_streaming no longer exists")
|
||||
async def test_magentic_checkpoint_resume_round_trip():
|
||||
storage = InMemoryCheckpointStorage()
|
||||
|
||||
@@ -369,7 +374,7 @@ class _DummyExec(Executor):
|
||||
pass
|
||||
|
||||
|
||||
def test_magentic_agent_executor_snapshot_roundtrip():
|
||||
async def test_magentic_agent_executor_on_checkpoint_save_and_restore_roundtrip():
|
||||
backing_executor = _DummyExec("backing")
|
||||
agent_exec = MagenticAgentExecutor(backing_executor, "agentA")
|
||||
agent_exec._chat_history.extend([ # type: ignore[reportPrivateUsage]
|
||||
@@ -377,10 +382,10 @@ def test_magentic_agent_executor_snapshot_roundtrip():
|
||||
ChatMessage(role=Role.ASSISTANT, text="world", author_name="agentA"),
|
||||
])
|
||||
|
||||
state = agent_exec.snapshot_state()
|
||||
state = await agent_exec.on_checkpoint_save()
|
||||
|
||||
restored_executor = MagenticAgentExecutor(_DummyExec("backing2"), "agentA")
|
||||
restored_executor.restore_state(state)
|
||||
await restored_executor.on_checkpoint_restore(state)
|
||||
|
||||
assert len(restored_executor._chat_history) == 2 # type: ignore[reportPrivateUsage]
|
||||
assert restored_executor._chat_history[0].text == "hello" # type: ignore[reportPrivateUsage]
|
||||
|
||||
@@ -199,7 +199,10 @@ async def test_fan_out():
|
||||
|
||||
# Each executor will emit two events: ExecutorInvokedEvent and ExecutorCompletedEvent
|
||||
# executor_b will also emit a WorkflowOutputEvent (no WorkflowCompletedEvent anymore)
|
||||
assert len(events) == 7
|
||||
# Each superstep will emit also emit a WorkflowStartedEvent and WorkflowCompletedEvent
|
||||
# This workflow will converge in 2 supersteps because executor_c will send one more message
|
||||
# after executor_b completes
|
||||
assert len(events) == 11
|
||||
|
||||
assert events.get_final_state() == WorkflowRunState.IDLE
|
||||
outputs = events.get_outputs()
|
||||
@@ -220,7 +223,9 @@ async def test_fan_out_multiple_completed_events():
|
||||
|
||||
# Each executor will emit two events: ExecutorInvokedEvent and ExecutorCompletedEvent
|
||||
# executor_b and executor_c will also emit a WorkflowOutputEvent (no WorkflowCompletedEvent anymore)
|
||||
assert len(events) == 8
|
||||
# Each superstep will emit also emit a WorkflowStartedEvent and WorkflowCompletedEvent
|
||||
# This workflow will converge in 1 superstep because executor_a and executor_b will not send further messages
|
||||
assert len(events) == 10
|
||||
|
||||
# Multiple outputs are expected from both executors
|
||||
outputs = events.get_outputs()
|
||||
@@ -246,7 +251,8 @@ async def test_fan_in():
|
||||
|
||||
# Each executor will emit two events: ExecutorInvokedEvent and ExecutorCompletedEvent
|
||||
# aggregator will also emit a WorkflowOutputEvent (no WorkflowCompletedEvent anymore)
|
||||
assert len(events) == 9
|
||||
# Each superstep will emit also emit a WorkflowStartedEvent and WorkflowCompletedEvent
|
||||
assert len(events) == 13
|
||||
|
||||
assert events.get_final_state() == WorkflowRunState.IDLE
|
||||
outputs = events.get_outputs()
|
||||
|
||||
@@ -37,10 +37,14 @@ class WorkflowHILRequest:
|
||||
class WorkflowTestExecutor(Executor):
|
||||
"""Test executor with HIL."""
|
||||
|
||||
def __init__(self, id: str) -> None:
|
||||
super().__init__(id=id)
|
||||
self._data_value: str | None = None
|
||||
|
||||
@handler
|
||||
async def process(self, data: WorkflowTestData, ctx: WorkflowContext) -> None:
|
||||
"""Process data and request approval."""
|
||||
await ctx.set_executor_state({"data_value": data.value})
|
||||
self._data_value = data.value
|
||||
|
||||
# Request HIL (checkpoint created here)
|
||||
await ctx.request_info(request_data=WorkflowHILRequest(question=f"Approve {data.value}?"), response_type=str)
|
||||
@@ -50,8 +54,7 @@ class WorkflowTestExecutor(Executor):
|
||||
self, original_request: WorkflowHILRequest, response: str, ctx: WorkflowContext[str]
|
||||
) -> None:
|
||||
"""Handle HIL response."""
|
||||
state = await ctx.get_executor_state() or {}
|
||||
value = state.get("data_value", "")
|
||||
value = self._data_value or ""
|
||||
await ctx.send_message(f"{value}_approved" if response.lower() == "yes" else f"{value}_rejected")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user