Python: [BREAKING] Python: Make executor ID required, improvements around handling rehydrating checkpoints (#832)

* Make executor ID required, improvements around handling rehydrating checkpoints.

* Duplicate executor validation added

* fix remaining issues

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Evan Mattson
2025-09-20 03:57:09 +09:00
committed by GitHub
Unverified
parent 7cd45e313b
commit aba094b5cf
33 changed files with 1967 additions and 275 deletions
@@ -91,6 +91,7 @@ from ._shared_state import SharedState
from ._telemetry import EdgeGroupDeliveryStatus, WorkflowTracer, workflow_tracer
from ._validation import (
EdgeDuplicationError,
ExecutorDuplicationError,
GraphConnectivityError,
HandlerOutputAnnotationError,
TypeCompatibilityError,
@@ -118,6 +119,7 @@ __all__ = [
"EdgeGroupDeliveryStatus",
"Executor",
"ExecutorCompletedEvent",
"ExecutorDuplicationError",
"ExecutorEvent",
"ExecutorFailedEvent",
"ExecutorInvokedEvent",
@@ -87,6 +87,7 @@ from ._shared_state import SharedState
from ._telemetry import EdgeGroupDeliveryStatus, WorkflowTracer, workflow_tracer
from ._validation import (
EdgeDuplicationError,
ExecutorDuplicationError,
GraphConnectivityError,
HandlerOutputAnnotationError,
TypeCompatibilityError,
@@ -114,6 +115,7 @@ __all__ = [
"EdgeGroupDeliveryStatus",
"Executor",
"ExecutorCompletedEvent",
"ExecutorDuplicationError",
"ExecutorEvent",
"ExecutorFailedEvent",
"ExecutorInvokedEvent",
@@ -145,7 +145,10 @@ class _CallbackAggregator(Executor):
"""
def __init__(self, callback: Callable[..., Any], id: str | None = None) -> None:
super().__init__(id)
derived_id = getattr(callback, "__name__", "") or ""
if not derived_id or derived_id == "<lambda>":
derived_id = f"{type(self).__name__}_unnamed"
super().__init__(id or derived_id)
self._callback = callback
self._param_count = len(inspect.signature(callback).parameters)
@@ -2,12 +2,15 @@
import contextlib
import functools
import importlib
import inspect
import logging
import uuid
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from textwrap import shorten
from types import UnionType
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, get_args, get_origin, overload
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, Union, cast, get_args, get_origin, overload
if TYPE_CHECKING:
from ._workflow import Workflow
@@ -17,6 +20,7 @@ from pydantic import Field
from agent_framework import AgentProtocol, AgentRunResponse, AgentRunResponseUpdate, AgentThread, ChatMessage
from agent_framework._pydantic import AFBaseModel
from ._checkpoint import WorkflowCheckpoint
from ._events import (
AgentRunEvent,
AgentRunUpdateEvent,
@@ -25,35 +29,63 @@ from ._events import (
RequestInfoEvent,
_framework_event_origin, # pyright: ignore[reportPrivateUsage]
)
from ._runner_context import _decode_checkpoint_value
from ._typing_utils import is_instance_of
from ._workflow_context import WorkflowContext
logger = logging.getLogger(__name__)
# region Executor
@dataclass
class PendingRequestDetails:
    """Compact record describing one outstanding request found in a checkpoint.

    Only ``request_id`` is mandatory; every other attribute defaults to ``None``
    and is filled in as more information about the request is discovered.
    """

    # Identifier of the request this record describes.
    request_id: str
    # Prompt text associated with the request, when known.
    prompt: str | None = None
    # Draft text associated with the request, when known.
    draft: str | None = None
    # Iteration number associated with the request, when known.
    iteration: int | None = None
    # ID of the executor that issued the request, when known.
    source_executor_id: str | None = None
    # The request object (or its dict form) as found in the checkpoint.
    original_request: "RequestInfoMessage | dict[str, Any] | None" = None
@dataclass
class WorkflowCheckpointSummary:
    """Readable digest of a single workflow checkpoint.

    All attributes are required and are supplied positionally or by keyword in
    the order declared below.
    """

    # Identifier of the summarised checkpoint.
    checkpoint_id: str
    # Iteration counter recorded in the checkpoint.
    iteration_count: int
    # Keys of the checkpoint's message map.
    targets: list[str]
    # Keys of the checkpoint's executor-state map.
    executor_states: list[str]
    # Short status label describing where the workflow stands.
    status: str
    # Shortened draft text taken from a pending request, if any.
    draft_preview: str | None
    # Pending requests discovered in the checkpoint.
    pending_requests: list[PendingRequestDetails]
class Executor(AFBaseModel):
"""An executor is a component that processes messages in a workflow."""
# Provide a default so static analyzers (e.g., pyright) don't require passing `id`.
# Runtime still sets a concrete value in __init__.
id: str = Field(
default_factory=lambda: str(uuid.uuid4()),
...,
min_length=1,
description="Unique identifier for the executor",
)
type_: str = Field(default="", alias="type", description="The type of executor, corresponding to the class name")
def __init__(self, id: str | None = None, **kwargs: Any) -> None:
def __init__(self, id: str, **kwargs: Any) -> None:
"""Initialize the executor with a unique identifier.
Args:
id: A unique identifier for the executor. If None, a new ID will be generated
following the format <class_name>/<uuid>.
id: A unique identifier for the executor.
kwargs: Additional keyword arguments. Unused in this implementation.
"""
executor_id = f"{self.__class__.__name__}/{uuid.uuid4()}" if id is None else id
if not id:
raise ValueError("Executor ID must be a non-empty string.")
kwargs.update({"id": executor_id})
kwargs.update({"id": id})
if "type" not in kwargs and "type_" not in kwargs:
kwargs["type_"] = self.__class__.__name__
@@ -677,17 +709,15 @@ class RequestInfoExecutor(Executor):
a response is provided externally, it emits the response as a message.
"""
def __init__(self, id: str | None = None):
"""Initialize the RequestInfoExecutor with an optional custom ID.
_PENDING_SHARED_STATE_KEY: ClassVar[str] = "_af_pending_request_info"
def __init__(self, id: str):
"""Initialize the RequestInfoExecutor with a unique ID.
Args:
id: Optional custom ID for this RequestInfoExecutor. If not provided,
a unique ID will be generated.
id: Unique ID for this RequestInfoExecutor.
"""
import uuid
executor_id = id or f"request_info_{uuid.uuid4().hex[:8]}"
super().__init__(id=executor_id)
super().__init__(id=id)
self._request_events: dict[str, RequestInfoEvent] = {}
self._sub_workflow_contexts: dict[str, dict[str, str]] = {}
@@ -703,6 +733,7 @@ class RequestInfoExecutor(Executor):
request_data=message,
)
self._request_events[message.request_id] = event
await self._record_pending_request_snapshot(message, source_executor_id, ctx)
await ctx.add_event(event)
@handler
@@ -748,10 +779,13 @@ class RequestInfoExecutor(Executor):
response_data: The data returned in the response.
ctx: The workflow context for sending the response.
"""
if request_id not in self._request_events:
event = self._request_events.get(request_id)
if event is None:
event = await self._rehydrate_request_event(request_id, ctx)
if event is None:
raise ValueError(f"No request found with ID: {request_id}")
event = self._request_events.pop(request_id)
self._request_events.pop(request_id, None)
# Check if this was a forwarded sub-workflow request
if request_id in self._sub_workflow_contexts:
@@ -779,6 +813,472 @@ class RequestInfoExecutor(Executor):
await ctx.send_message(correlated_response, target_id=event.source_executor_id)
await self._clear_pending_request_snapshot(request_id, ctx)
async def _record_pending_request_snapshot(
    self,
    request: RequestInfoMessage,
    source_executor_id: str,
    ctx: WorkflowContext[Any],
) -> None:
    """Store a snapshot of ``request`` in the persisted pending-request map.

    Args:
        request: The request being recorded; its ``request_id`` is the map key.
        source_executor_id: ID of the executor that issued the request.
        ctx: Workflow context used to load and persist the pending map.
    """
    entry = self._build_request_snapshot(request, source_executor_id)
    state = await self._load_pending_request_state(ctx)
    state.update({request.request_id: entry})
    await self._persist_pending_request_state(state, ctx)
async def _clear_pending_request_snapshot(self, request_id: str, ctx: WorkflowContext[Any]) -> None:
    """Drop ``request_id`` from the persisted pending-request map.

    Nothing is written back when the ID is not present in the loaded map.
    """
    state = await self._load_pending_request_state(ctx)
    if request_id in state:
        del state[request_id]
        await self._persist_pending_request_state(state, ctx)
async def _load_pending_request_state(self, ctx: WorkflowContext[Any]) -> dict[str, Any]:
    """Read the pending-request map from shared state.

    Returns:
        A shallow copy of the stored dict. An empty dict is returned when the
        read fails, when nothing is stored, or when the stored value is not a
        dict (a warning is logged for unexpected non-dict values).
    """
    try:
        existing = await ctx.get_shared_state(self._PENDING_SHARED_STATE_KEY)
    except Exception as exc:  # pragma: no cover - transport specific
        logger.warning(f"RequestInfoExecutor {self.id} failed to read pending request state: {exc}")
        return {}

    if isinstance(existing, dict):
        # Copy so callers can mutate the result without touching stored state.
        return dict(existing)

    if existing not in (None, {}):
        logger.warning(
            f"RequestInfoExecutor {self.id} encountered non-dict pending state "
            f"({type(existing).__name__}); resetting."
        )
    return {}
async def _persist_pending_request_state(self, pending: dict[str, Any], ctx: WorkflowContext[Any]) -> None:
    """Write ``pending`` to shared state first, then to the runner state."""
    for writer in (self._safe_set_shared_state, self._safe_set_runner_state):
        await writer(ctx, pending)
async def _safe_set_shared_state(self, ctx: WorkflowContext[Any], pending: dict[str, Any]) -> None:
    """Best-effort write of ``pending`` to shared state; failures are logged, not raised."""
    try:
        write = ctx.set_shared_state(self._PENDING_SHARED_STATE_KEY, pending)
        await write
    except Exception as exc:  # pragma: no cover - transport specific
        logger.warning(f"RequestInfoExecutor {self.id} failed to update shared pending state: {exc}")
async def _safe_set_runner_state(self, ctx: WorkflowContext[Any], pending: dict[str, Any]) -> None:
    """Best-effort write of ``pending`` to the runner state under ``"pending_requests"``.

    Failures are logged as warnings and never propagate.
    """
    runner_state = {"pending_requests": pending}
    try:
        await ctx.set_state(runner_state)
    except Exception as exc:  # pragma: no cover - transport specific
        logger.warning(f"RequestInfoExecutor {self.id} failed to update runner state with pending requests: {exc}")
def _build_request_snapshot(
    self,
    request: RequestInfoMessage,
    source_executor_id: str,
) -> dict[str, Any]:
    """Build the dict persisted for one pending request.

    The snapshot always carries the request ID, the source executor ID, a
    ``module:ClassName`` type string and the request's ``repr``. When the
    request can be serialised, the result is stored under ``"details"`` and
    its ``prompt`` / ``draft`` / ``iteration`` entries are lifted to the top level.
    """
    snapshot: dict[str, Any] = {
        "request_id": request.request_id,
        "source_executor_id": source_executor_id,
        "request_type": f"{type(request).__module__}:{type(request).__name__}",
        "summary": repr(request),
    }

    details = self._serialise_request_details(request)
    if not details:
        return snapshot

    snapshot["details"] = details
    for key in ("prompt", "draft", "iteration"):
        if key in details:
            # setdefault never overwrites a key the snapshot already holds.
            snapshot.setdefault(key, details[key])
    return snapshot
def _serialise_request_details(self, request: RequestInfoMessage) -> dict[str, Any] | None:
    """Turn ``request`` into a JSON-safe dict, or return ``None`` when that is not possible.

    Sources are tried in order and the first that applies decides the result:
    dataclass fields, a callable ``model_dump`` (with ``mode="json"``, falling
    back to no arguments on ``TypeError``), then the instance ``__dict__``.
    """
    if is_dataclass(request):
        converted = self._make_json_safe(asdict(request))
    elif callable(dumper := getattr(request, "model_dump", None)):
        try:
            converted = self._make_json_safe(dumper(mode="json"))
        except TypeError:
            converted = self._make_json_safe(dumper())
    else:
        attrs = getattr(request, "__dict__", None)
        if not isinstance(attrs, dict):
            return None
        converted = self._make_json_safe(attrs)

    return cast(dict[str, Any], converted) if isinstance(converted, dict) else None
def _make_json_safe(self, value: Any) -> Any:
    """Recursively reduce ``value`` to JSON-friendly primitives.

    ``None``, strings, ints, floats and bools pass through unchanged. Mappings
    become dicts with string keys, non-string sequences become lists, and
    anything else (including ``bytes`` / ``bytearray``) is replaced by its ``repr``.
    """
    if value is None or isinstance(value, (str, int, float, bool)):
        return value

    if isinstance(value, Mapping):
        return {str(key): self._make_json_safe(val) for key, val in value.items()}

    # Byte strings are sequences too, but are deliberately rendered via repr.
    if isinstance(value, (bytes, bytearray)) or not isinstance(value, Sequence):
        return repr(value)

    return [self._make_json_safe(item) for item in value]
async def has_pending_request(self, request_id: str, ctx: WorkflowContext[Any]) -> bool:
    """Report whether ``request_id`` is still awaiting a response.

    In-memory request events and sub-workflow contexts are checked first; only
    when neither knows the ID is the persisted snapshot store consulted.
    """
    known_in_memory = request_id in self._request_events or request_id in self._sub_workflow_contexts
    if known_in_memory:
        return True
    return (await self._get_pending_request_snapshot(request_id, ctx)) is not None
async def _rehydrate_request_event(
    self,
    request_id: str,
    ctx: WorkflowContext[Any],
) -> RequestInfoEvent | None:
    """Rebuild the in-memory event for ``request_id`` from its persisted snapshot.

    Returns:
        The reconstructed event, which is also cached in ``_request_events``,
        or ``None`` when no snapshot exists, the snapshot lacks a non-empty
        string ``source_executor_id``, or the request cannot be reconstructed.
    """
    snapshot = await self._get_pending_request_snapshot(request_id, ctx)
    if snapshot is None:
        return None

    origin = snapshot.get("source_executor_id")
    if not (isinstance(origin, str) and origin):
        return None

    rebuilt = self._construct_request_from_snapshot(snapshot)
    if rebuilt is None:
        return None

    rehydrated = RequestInfoEvent(
        request_id=request_id,
        source_executor_id=origin,
        request_type=type(rebuilt),
        request_data=rebuilt,
    )
    self._request_events[request_id] = rehydrated
    return rehydrated
async def _get_pending_request_snapshot(self, request_id: str, ctx: WorkflowContext[Any]) -> dict[str, Any] | None:
    """Return the persisted snapshot for ``request_id``, or ``None`` when there is none."""
    snapshots = await self._collect_pending_request_snapshots(ctx)
    return snapshots.get(request_id)
async def _collect_pending_request_snapshots(self, ctx: WorkflowContext[Any]) -> dict[str, dict[str, Any]]:
    """Merge pending-request snapshots from shared state and from the runner state.

    Only entries with a string key and a dict value are kept. Shared-state
    entries take precedence: a runner-state entry is used only for request IDs
    the shared state did not supply. Read failures are logged and treated as
    "nothing stored".
    """
    combined: dict[str, dict[str, Any]] = {}

    try:
        shared_pending = await ctx.get_shared_state(self._PENDING_SHARED_STATE_KEY)
    except Exception as exc:  # pragma: no cover - transport specific
        logger.warning(f"RequestInfoExecutor {self.id} failed to read shared pending state during rehydrate: {exc}")
        shared_pending = None

    if isinstance(shared_pending, dict):
        combined.update({
            request_id: cast(dict[str, Any], entry)
            for request_id, entry in shared_pending.items()
            if isinstance(request_id, str) and isinstance(entry, dict)
        })

    try:
        state = await ctx.get_state()
    except Exception as exc:  # pragma: no cover - transport specific
        logger.warning(f"RequestInfoExecutor {self.id} failed to read runner state during rehydrate: {exc}")
        state = None

    runner_pending = state.get("pending_requests") if isinstance(state, dict) else None
    if isinstance(runner_pending, dict):
        for request_id, entry in runner_pending.items():
            if isinstance(request_id, str) and isinstance(entry, dict):
                # Never override what shared state already provided.
                combined.setdefault(request_id, cast(dict[str, Any], entry))

    return combined
def _construct_request_from_snapshot(self, snapshot: dict[str, Any]) -> RequestInfoMessage | None:
"""Rebuild a request object from a persisted snapshot dict.

The snapshot's ``"details"`` dict supplies the field values and its
``"request_type"`` string (``module:ClassName``) selects the class to build.

Returns:
The reconstructed request, or ``None`` when no instance could be created.
"""
# Anything other than a dict under "details" is treated as "no details".
details_raw = snapshot.get("details")
details: dict[str, Any] = cast(dict[str, Any], details_raw) if isinstance(details_raw, dict) else {}
# Start from the base class; a more specific class is used only if it resolves below.
request_cls: type[RequestInfoMessage] = RequestInfoMessage
request_type_str = snapshot.get("request_type")
if isinstance(request_type_str, str) and ":" in request_type_str:
module_name, class_name = request_type_str.split(":", 1)
try:
# NOTE(review): the module name comes from persisted snapshot data and importing it
# runs that module's top-level code - confirm checkpoints are a trusted input.
module = importlib.import_module(module_name)
candidate = getattr(module, class_name)
# Only classes deriving from RequestInfoMessage are accepted.
if isinstance(candidate, type) and issubclass(candidate, RequestInfoMessage):
request_cls = candidate
except Exception as exc:
logger.warning(f"RequestInfoExecutor {self.id} could not import {module_name}.{class_name}: {exc}")
request_cls = RequestInfoMessage
request: RequestInfoMessage | None = self._instantiate_request(request_cls, details)
# If the specific class could not be built, retry once with the base class.
if request is None and request_cls is not RequestInfoMessage:
request = self._instantiate_request(RequestInfoMessage, details)
if request is None:
logger.warning(
f"RequestInfoExecutor {self.id} could not reconstruct request "
f"{request_type_str or RequestInfoMessage.__name__} from snapshot keys {sorted(details.keys())}"
)
return None
# Copy every detail except request_id onto the instance; attributes that cannot be set
# are logged at debug level and skipped.
for key, value in details.items():
if key == "request_id":
continue
try:
setattr(request, key, value)
except Exception as exc:
logger.debug(
f"RequestInfoExecutor {self.id} could not set attribute {key} on {type(request).__name__}: {exc}"
)
continue
# The request_id is taken from the snapshot's top-level entry, not from the details dict.
snapshot_request_id = snapshot.get("request_id")
if isinstance(snapshot_request_id, str) and snapshot_request_id:
try:
request.request_id = snapshot_request_id
except Exception as exc:
logger.debug(
f"RequestInfoExecutor {self.id} could not apply snapshot "
f"request_id to {type(request).__name__}: {exc}"
)
return request
def _instantiate_request(
self,
request_cls: type[RequestInfoMessage],
details: dict[str, Any],
) -> RequestInfoMessage | None:
"""Create an instance of ``request_cls`` populated from ``details``.

Three strategies are tried in order: ``request_cls.model_validate(details)``
when that attribute is callable, the dataclass constructor with the matching
fields, and finally a no-argument constructor followed by attribute assignment.

Returns:
The created instance, or ``None`` when the no-argument constructor also fails.
"""
# Strategy 1: model_validate, if the class exposes a callable of that name.
try:
model_validate = getattr(request_cls, "model_validate", None)
if callable(model_validate):
return cast(RequestInfoMessage, model_validate(details))
except (TypeError, ValueError) as exc:
logger.debug(
f"RequestInfoExecutor {self.id} validation failed for {request_cls.__name__} via model_validate: {exc}"
)
except Exception as exc:
logger.warning(
f"RequestInfoExecutor {self.id} encountered unexpected error during "
f"{request_cls.__name__}.model_validate: {exc}"
)
# Strategy 2: dataclass constructor, passing only the keys that are declared fields.
if is_dataclass(request_cls):
try:
field_names = {f.name for f in fields(request_cls)}
ctor_kwargs = {name: details[name] for name in field_names if name in details}
return request_cls(**ctor_kwargs) # type: ignore[call-arg]
except (TypeError, ValueError) as exc:
logger.debug(
f"RequestInfoExecutor {self.id} could not instantiate dataclass "
f"{request_cls.__name__} with snapshot data: {exc}"
)
except Exception as exc:
logger.warning(
f"RequestInfoExecutor {self.id} encountered unexpected error "
f"constructing dataclass {request_cls.__name__}: {exc}"
)
# Strategy 3: construct with no arguments, then assign the details one by one.
try:
instance = request_cls() # type: ignore[call-arg]
except Exception as exc:
logger.warning(
f"RequestInfoExecutor {self.id} could not instantiate {request_cls.__name__} without arguments: {exc}"
)
return None
# request_id is skipped here; attributes that cannot be set are logged at debug level.
for key, value in details.items():
if key == "request_id":
continue
try:
setattr(instance, key, value)
except Exception as exc:
logger.debug(
f"RequestInfoExecutor {self.id} could not set attribute {key} on "
f"{request_cls.__name__} during instantiation: {exc}"
)
continue
return instance
@staticmethod
def pending_requests_from_checkpoint(
    checkpoint: WorkflowCheckpoint,
    *,
    request_executor_ids: Iterable[str] | None = None,
) -> list[PendingRequestDetails]:
    """Collect the pending requests recorded in ``checkpoint``.

    Three places are inspected, in this order: the shared-state pending map,
    each executor state's ``"pending_requests"`` map, and the queued messages.

    Args:
        checkpoint: The checkpoint to inspect.
        request_executor_ids: When given, only messages whose source ID is in
            this collection are inspected. Snapshot maps are not filtered.

    Returns:
        One entry per request ID, in order of first discovery.
    """
    executor_filter: set[str] | None = (
        None if request_executor_ids is None else {str(value) for value in request_executor_ids}
    )
    pending: dict[str, PendingRequestDetails] = {}

    def absorb(candidate: Any) -> None:
        # Fold a {request_id: snapshot} mapping into the result; ignore anything else.
        if isinstance(candidate, Mapping):
            for request_id, snapshot in candidate.items():
                RequestInfoExecutor._merge_snapshot(pending, str(request_id), snapshot)

    absorb(checkpoint.shared_state.get(RequestInfoExecutor._PENDING_SHARED_STATE_KEY))

    for state in checkpoint.executor_states.values():
        if isinstance(state, Mapping):
            absorb(state.get("pending_requests"))

    for source_id, message_list in checkpoint.messages.items():
        if executor_filter is not None and source_id not in executor_filter:
            continue
        if not isinstance(message_list, list):
            continue
        for message in message_list:
            if isinstance(message, Mapping):
                payload = _decode_checkpoint_value(message.get("data"))
                RequestInfoExecutor._merge_message_payload(pending, payload, message)

    return list(pending.values())
@staticmethod
def checkpoint_summary(
    checkpoint: WorkflowCheckpoint,
    *,
    request_executor_ids: Iterable[str] | None = None,
    preview_width: int = 70,
) -> WorkflowCheckpointSummary:
    """Summarise ``checkpoint`` in a human-readable form.

    Args:
        checkpoint: The checkpoint to summarise.
        request_executor_ids: IDs of request-info executors. They restrict which
            queued messages are scanned for pending requests and are used to
            detect requests that are queued but not yet delivered.
        preview_width: Maximum width of the draft preview.

    Returns:
        A summary whose ``status`` is one of ``"awaiting human response"``,
        ``"completed"``, ``"awaiting request delivery"``,
        ``"awaiting next superstep"`` or ``"idle"``.
    """
    # Materialise once: the IDs are needed twice and the caller may pass a one-shot iterable.
    request_ids: tuple[str, ...] | None = (
        None if request_executor_ids is None else tuple(request_executor_ids)
    )

    targets = sorted(checkpoint.messages.keys())
    executor_states = sorted(checkpoint.executor_states.keys())
    pending = RequestInfoExecutor.pending_requests_from_checkpoint(
        checkpoint, request_executor_ids=request_ids
    )

    # Preview the first pending request that carries a draft.
    draft_preview: str | None = next(
        (shorten(entry.draft, width=preview_width, placeholder="") for entry in pending if entry.draft),
        None,
    )

    if pending:
        status = "awaiting human response"
    elif not checkpoint.messages:
        status = "completed" if "finalise" in executor_states else "idle"
    elif request_ids is not None and any(tid in targets for tid in request_ids):
        # Must be tested before the generic "messages queued" case below;
        # otherwise this status could never be produced.
        status = "awaiting request delivery"
    else:
        status = "awaiting next superstep"

    return WorkflowCheckpointSummary(
        checkpoint_id=checkpoint.checkpoint_id,
        iteration_count=checkpoint.iteration_count,
        targets=targets,
        executor_states=executor_states,
        status=status,
        draft_preview=draft_preview,
        pending_requests=pending,
    )
@staticmethod
def _merge_snapshot(
    pending: dict[str, PendingRequestDetails],
    request_id: str,
    snapshot: Any,
) -> None:
    """Fold one persisted snapshot into ``pending``.

    Top-level snapshot values are applied first, then the values nested under
    ``"details"``. Snapshots that are not mappings, or that come with an empty
    request ID, are ignored.
    """
    if not (request_id and isinstance(snapshot, Mapping)):
        return

    if request_id not in pending:
        pending[request_id] = PendingRequestDetails(request_id=request_id)
    details = pending[request_id]

    RequestInfoExecutor._apply_update(
        details,
        prompt=snapshot.get("prompt"),
        draft=snapshot.get("draft"),
        iteration=snapshot.get("iteration"),
        source_executor_id=snapshot.get("source_executor_id"),
    )

    extra = snapshot.get("details")
    if isinstance(extra, Mapping):
        nested = {name: extra.get(name) for name in ("prompt", "draft", "iteration")}
        RequestInfoExecutor._apply_update(details, **nested)
@staticmethod
def _merge_message_payload(
    pending: dict[str, PendingRequestDetails],
    payload: Any,
    raw_message: Mapping[str, Any],
) -> None:
    """Fold a decoded message payload into ``pending``.

    ``RequestResponse`` payloads contribute the fields of their original
    request; ``RequestInfoMessage`` payloads contribute their own attributes.
    Any other payload, or one without a request ID, is ignored.
    """
    if isinstance(payload, RequestResponse):
        original = payload.original_request

        def read(key: str) -> Any:
            return RequestInfoExecutor._get_field(original, key)

        request_id = payload.request_id or read("request_id")
    elif isinstance(payload, RequestInfoMessage):
        original = payload

        def read(key: str) -> Any:
            return getattr(payload, key, None)

        request_id = read("request_id")
    else:
        return

    if not request_id:
        return

    details = pending.setdefault(request_id, PendingRequestDetails(request_id=request_id))
    RequestInfoExecutor._apply_update(
        details,
        prompt=read("prompt"),
        draft=read("draft"),
        iteration=read("iteration"),
        source_executor_id=raw_message.get("source_id"),
        original_request=original,
    )
@staticmethod
def _apply_update(
details: PendingRequestDetails,
*,
prompt: Any = None,
draft: Any = None,
iteration: Any = None,
source_executor_id: Any = None,
original_request: Any = None,
) -> None:
if prompt and not details.prompt:
details.prompt = str(prompt)
if draft and not details.draft:
details.draft = str(draft)
if iteration is not None and details.iteration is None:
coerced = RequestInfoExecutor._coerce_int(iteration)
if coerced is not None:
details.iteration = coerced
if source_executor_id and not details.source_executor_id:
details.source_executor_id = str(source_executor_id)
if original_request is not None and details.original_request is None:
details.original_request = original_request
@staticmethod
def _get_field(obj: Any, key: str) -> Any:
if obj is None:
return None
if isinstance(obj, Mapping):
return obj.get(key)
return getattr(obj, key, None)
@staticmethod
def _coerce_int(value: Any) -> int | None:
try:
return int(value)
except (TypeError, ValueError):
return None
# endregion: Request Info Executor
@@ -840,7 +1340,11 @@ class AgentExecutor(Executor):
exec_id = id
else:
agent_name = agent.name
exec_id = str(agent_name) if agent_name else f"executor_{uuid.uuid4()}"
if agent_name:
exec_id = str(agent_name)
else:
logger.warning("Agent has no name, using fallback ID 'executor_unnamed'")
exec_id = "executor_unnamed"
super().__init__(exec_id)
self._agent = agent
self._agent_thread = agent_thread or self._agent.get_new_thread()
@@ -952,12 +1456,12 @@ class WorkflowExecutor(Executor):
workflow: "Workflow" = Field(description="The workflow to execute as a sub-workflow")
def __init__(self, workflow: "Workflow", id: str | None = None, **kwargs: Any):
def __init__(self, workflow: "Workflow", id: str, **kwargs: Any):
"""Initialize the WorkflowExecutor.
Args:
workflow: The workflow to execute as a sub-workflow.
id: Optional unique identifier for this executor.
id: Unique identifier for this executor.
**kwargs: Additional keyword arguments passed to the parent constructor.
"""
kwargs.update({"workflow": workflow})
@@ -1635,7 +1635,7 @@ class MagenticBuilder:
if self._enable_plan_review:
from ._executor import RequestInfoExecutor
request_info = RequestInfoExecutor()
request_info = RequestInfoExecutor(id="request_info")
workflow_builder = (
workflow_builder
# Only route plan review asks to request_info
@@ -13,7 +13,14 @@ from ._edge import EdgeGroup
from ._edge_runner import EdgeRunner, create_edge_runner
from ._events import WorkflowCompletedEvent, WorkflowEvent, _framework_event_origin
from ._executor import Executor
from ._runner_context import Message, RunnerContext
from ._runner_context import (
_DATACLASS_MARKER,
_PYDANTIC_MARKER,
CheckpointState,
Message,
RunnerContext,
_decode_checkpoint_value,
)
from ._shared_state import SharedState
from ._typing_utils import is_instance_of
from ._workflow_context import WorkflowContext
@@ -53,6 +60,7 @@ class Runner:
self._workflow_id = workflow_id
self._running = False
self._resumed_from_checkpoint = False # Track whether we resumed
self.graph_signature_hash: str | None = None
# Set workflow ID in context if provided
if workflow_id:
@@ -244,6 +252,19 @@ class Runner:
"""Inner loop to deliver a single message through an edge runner."""
return await edge_runner.send_message(message, self._shared_state, self._ctx)
def _normalize_message_payload(message: Message) -> None:
    """Decode a checkpoint-encoded payload on ``message`` in place.

    Only dict payloads carrying the pydantic or dataclass marker key are
    decoded. When decoding fails the payload is left untouched and the failure
    is logged at debug level.
    """
    payload = message.data
    carries_marker = isinstance(payload, dict) and (
        _PYDANTIC_MARKER in payload or _DATACLASS_MARKER in payload
    )
    if not carries_marker:
        return

    try:
        restored = _decode_checkpoint_value(payload)
    except Exception as exc:  # pragma: no cover - defensive
        logger.debug("Failed to decode checkpoint payload during delivery: %s", exc)
    else:
        message.data = restored
# Handle SubWorkflowRequestInfo messages specially
await _deliver_sub_workflow_requests(messages)
@@ -266,6 +287,7 @@ class Runner:
associated_edge_runners = self._edge_runner_map.get(source_executor_id, [])
for message in non_sub_workflow_messages:
_normalize_message_payload(message)
# Deliver a message through all edge runners associated with the source executor concurrently.
tasks = [_deliver_message_inner(edge_runner, message) for edge_runner in associated_edge_runners]
if not tasks:
@@ -332,6 +354,8 @@ class Runner:
"superstep": self._iteration,
"checkpoint_type": checkpoint_category,
}
if self.graph_signature_hash:
metadata["graph_signature"] = self.graph_signature_hash
checkpoint_id = await self._ctx.create_checkpoint(metadata=metadata)
logger.info(f"Created {checkpoint_type} checkpoint: {checkpoint_id}")
return checkpoint_id
@@ -403,14 +427,45 @@ class Runner:
return False
try:
success = await self._ctx.restore_from_checkpoint(checkpoint_id)
if not success:
checkpoint = await self._ctx.load_checkpoint(checkpoint_id)
if not checkpoint:
logger.error(f"Checkpoint {checkpoint_id} not found")
return False
graph_hash = getattr(self, "graph_signature_hash", None)
checkpoint_hash = (checkpoint.metadata or {}).get("graph_signature")
if graph_hash and checkpoint_hash and graph_hash != checkpoint_hash:
raise ValueError(
"Workflow graph has changed since the checkpoint was created. "
"Please rebuild the original workflow before resuming."
)
if graph_hash and not checkpoint_hash:
logger.warning(
"Checkpoint %s does not include graph signature metadata; skipping topology validation.",
checkpoint_id,
)
state: CheckpointState = {
"messages": checkpoint.messages,
"shared_state": checkpoint.shared_state,
"executor_states": checkpoint.executor_states,
"iteration_count": checkpoint.iteration_count,
"max_iterations": checkpoint.max_iterations,
}
await self._ctx.set_checkpoint_state(state)
if checkpoint.workflow_id:
self._ctx.set_workflow_id(checkpoint.workflow_id)
self._workflow_id = checkpoint.workflow_id
await self._restore_shared_state_from_context()
self.mark_resumed() # mark resumed; iteration/max already restored from context
self.mark_resumed(
iteration=checkpoint.iteration_count,
max_iterations=checkpoint.max_iterations,
)
logger.info(f"Successfully restored workflow from checkpoint: {checkpoint_id}")
return True
except ValueError:
raise
except Exception as e:
logger.error(f"Failed to restore from checkpoint {checkpoint_id}: {e}")
return False
@@ -3,6 +3,7 @@
import asyncio
import importlib
import logging
import sys
import uuid
from collections import defaultdict
from dataclasses import dataclass, fields, is_dataclass
@@ -60,6 +61,39 @@ _MAX_ENCODE_DEPTH = 100
_CYCLE_SENTINEL = "<cycle>"
def _instantiate_checkpoint_dataclass(cls: type[Any], payload: Any) -> Any | None:
    """Best-effort construction of ``cls`` from a decoded checkpoint payload.

    Dict payloads are passed as keyword arguments; if that call fails, an
    instance is allocated without running ``__init__`` and the dict entries are
    assigned as attributes. Any other payload is passed as the single
    positional argument.

    Returns:
        The constructed instance, or ``None`` when ``cls`` is not a class or no
        strategy succeeded.
    """
    if not isinstance(cls, type):
        logger.debug(f"Checkpoint decoder received non-type dataclass reference: {cls!r}")
        return None

    if not isinstance(payload, dict):
        try:
            return cls(payload)  # type: ignore[call-arg]
        except TypeError as exc:
            logger.debug(f"Checkpoint decoder could not call {cls.__name__}({payload!r}): {exc}")
        except Exception as exc:
            logger.warning(f"Checkpoint decoder encountered unexpected error calling {cls.__name__}({payload!r}): {exc}")
        return None

    try:
        return cls(**payload)  # type: ignore[arg-type]
    except TypeError as exc:
        logger.debug(f"Checkpoint decoder could not call {cls.__name__}(**payload): {exc}")
    except Exception as exc:
        logger.warning(f"Checkpoint decoder encountered unexpected error calling {cls.__name__}(**payload): {exc}")

    # The constructor rejected the payload: allocate a bare instance and assign attributes.
    try:
        bare = object.__new__(cls)
    except Exception as exc:
        logger.debug(f"Checkpoint decoder could not allocate {cls.__name__} without __init__: {exc}")
        return None

    for key, attr_value in payload.items():
        try:
            setattr(bare, key, attr_value)
        except Exception as exc:
            logger.debug(f"Checkpoint decoder could not set attribute {key} on {cls.__name__}: {exc}")
    return bare
def _is_pydantic_model(obj: object) -> bool:
"""Best-effort check for Pydantic models (e.g., AFBaseModel).
@@ -99,7 +133,7 @@ def _encode_checkpoint_value(value: Any) -> Any:
"value": v.model_dump(mode="json"),
}
except Exception as exc: # best-effort fallback
logger.debug("Pydantic model_dump failed for %s: %s", cls, exc)
logger.debug(f"Pydantic model_dump failed for {cls}: {exc}")
return str(v)
# Dataclasses (instances only)
@@ -178,36 +212,32 @@ def _decode_checkpoint_value(value: Any) -> Any:
if isinstance(type_key, str):
try:
module_name, class_name = type_key.split(":", 1)
module = importlib.import_module(module_name)
module = sys.modules.get(module_name)
if module is None:
module = importlib.import_module(module_name)
cls: Any = getattr(module, class_name)
if hasattr(cls, "model_validate"):
return cls.model_validate(raw)
except Exception as exc:
logger.debug(
"Failed to decode pydantic model %s: %s; returning raw value",
type_key,
exc,
)
logger.debug(f"Failed to decode pydantic model {type_key}: {exc}; returning raw value")
# Dataclass marker handling
if _DATACLASS_MARKER in value_dict and "value" in value_dict:
type_key_dc: str | None = value_dict.get(_DATACLASS_MARKER) # type: ignore[assignment]
raw_dc: Any = value_dict.get("value")
decoded_raw = _decode_checkpoint_value(raw_dc)
if isinstance(type_key_dc, str):
try:
module_name, class_name = type_key_dc.split(":", 1)
module = importlib.import_module(module_name)
module = sys.modules.get(module_name)
if module is None:
module = importlib.import_module(module_name)
cls_dc: Any = getattr(module, class_name)
decoded_raw = _decode_checkpoint_value(raw_dc)
if isinstance(decoded_raw, dict):
return cls_dc(**decoded_raw)
constructed = _instantiate_checkpoint_dataclass(cls_dc, decoded_raw)
if constructed is not None:
return constructed
except Exception as exc:
logger.debug(
"Failed to decode dataclass %s: %s; returning raw value",
type_key_dc,
exc,
)
# Fallback to decoded raw value
return _decode_checkpoint_value(raw_dc)
logger.debug(f"Failed to decode dataclass {type_key_dc}: {exc}; returning raw value")
return decoded_raw
# Regular dict: decode recursively
decoded: dict[str, Any] = {}
@@ -338,6 +368,10 @@ class RunnerContext(Protocol):
"""
...
async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None:
"""Load a checkpoint without mutating the current context state.

Args:
checkpoint_id: Identifier of the checkpoint to load.

Returns:
The loaded checkpoint, or ``None`` (per the return annotation; presumably
when no checkpoint with that ID exists - confirm against implementations).
"""
...
async def get_checkpoint_state(self) -> CheckpointState:
"""Get the current state of the context suitable for checkpointing."""
...
@@ -409,7 +443,7 @@ class InProcRunnerContext:
return
except Exception as exc: # pragma: no cover - defensive logging path
# Best-effort filtering only; never block event delivery on filtering errors
logger.debug("Error while filtering event %r: %s", event, exc, exc_info=True)
logger.debug(f"Error while filtering event {event!r}: {exc}", exc_info=True)
await self._event_queue.put(event)
@@ -497,6 +531,11 @@ class InProcRunnerContext:
logger.info(f"Restored state from checkpoint {checkpoint_id}'")
return True
async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None:
    """Fetch a checkpoint from the configured storage.

    Returns:
        Whatever the storage's ``load_checkpoint`` returns for ``checkpoint_id``.

    Raises:
        ValueError: If no checkpoint storage is configured.
    """
    storage = self._checkpoint_storage
    if storage:
        return await storage.load_checkpoint(checkpoint_id)
    raise ValueError("Checkpoint storage not configured")
async def get_checkpoint_state(self) -> CheckpointState:
serializable_messages: dict[str, list[dict[str, Any]]] = {}
for source_id, message_list in self._messages.items():
@@ -1,7 +1,58 @@
# Copyright (c) Microsoft. All rights reserved.
import logging
from dataclasses import fields, is_dataclass
from typing import Any, Union, get_args, get_origin
logger = logging.getLogger(__name__)
def _coerce_to_type(value: Any, target_type: type) -> Any | None:
"""Best-effort conversion of value into target_type."""
if isinstance(value, target_type):
return value
# Convert dataclass instances or objects with __dict__ into dict first
if not isinstance(value, dict):
if is_dataclass(value):
value = {f.name: getattr(value, f.name) for f in fields(value)}
else:
value_dict = getattr(value, "__dict__", None)
if isinstance(value_dict, dict):
value = dict(value_dict)
if isinstance(value, dict):
ctor_kwargs: dict[str, Any] = dict(value)
if is_dataclass(target_type):
field_names = {f.name for f in fields(target_type)}
ctor_kwargs = {k: v for k, v in value.items() if k in field_names}
try:
return target_type(**ctor_kwargs) # type: ignore[arg-type]
except TypeError as exc:
logger.debug(f"_coerce_to_type could not call {target_type.__name__}(**..): {exc}")
except Exception as exc: # pragma: no cover - unexpected constructor failure
logger.warning(
f"_coerce_to_type encountered unexpected error calling {target_type.__name__} constructor: {exc}"
)
try:
instance: Any = object.__new__(target_type)
except Exception as exc: # pragma: no cover - pathological type
logger.debug(f"_coerce_to_type could not allocate {target_type.__name__} without __init__: {exc}")
return None
for key, val in value.items():
try:
setattr(instance, key, val)
except Exception as exc:
logger.debug(
f"_coerce_to_type could not set {target_type.__name__}.{key} during fallback assignment: {exc}"
)
continue
return instance
return None
def is_instance_of(data: Any, target_type: type) -> bool:
"""Check if the data is an instance of the target type.
@@ -63,7 +114,10 @@ def is_instance_of(data: Any, target_type: type) -> bool:
and data.original_request is not None
and not is_instance_of(data.original_request, request_type)
):
return False
coerced = _coerce_to_type(data.original_request, request_type)
if coerced is None:
return False
data.original_request = coerced
if hasattr(data, "data") and data.data is not None and not is_instance_of(data.data, response_type):
return False
return True
@@ -34,6 +34,7 @@ class ValidationTypeEnum(Enum):
"""Enumeration of workflow validation types."""
EDGE_DUPLICATION = "EDGE_DUPLICATION"
EXECUTOR_DUPLICATION = "EXECUTOR_DUPLICATION"
TYPE_COMPATIBILITY = "TYPE_COMPATIBILITY"
GRAPH_CONNECTIVITY = "GRAPH_CONNECTIVITY"
HANDLER_OUTPUT_ANNOTATION = "HANDLER_OUTPUT_ANNOTATION"
@@ -63,6 +64,20 @@ class EdgeDuplicationError(WorkflowValidationError):
self.edge_id = edge_id
class ExecutorDuplicationError(WorkflowValidationError):
    """Raised when two executors in one workflow share the same identifier.

    Attributes:
        executor_id: The identifier that was found more than once.
    """

    def __init__(self, executor_id: str):
        detail = (
            f"Duplicate executor id detected: '{executor_id}'. Executor ids must be globally unique within a "
            "workflow."
        )
        super().__init__(message=detail, validation_type=ValidationTypeEnum.EXECUTOR_DUPLICATION)
        self.executor_id = executor_id
class TypeCompatibilityError(WorkflowValidationError):
"""Exception raised when type incompatibility is detected between connected executors."""
@@ -133,10 +148,17 @@ class WorkflowGraphValidator:
def __init__(self) -> None:
    """Create a validator with empty state; ``validate_workflow`` fills it in."""
    # Flattened list of edges across all edge groups.
    self._edges: list[Edge] = []
    # Executor id -> executor instance.
    self._executors: dict[str, Executor] = {}
    # Ids already known to be duplicated, supplied by the caller of validate_workflow.
    self._duplicate_executor_ids: set[str] = set()
    # The start executor as given by the caller (instance or id).
    self._start_executor_ref: Executor | str | None = None
# region Core Validation Methods
def validate_workflow(
self, edge_groups: Sequence[EdgeGroup], executors: dict[str, Executor], start_executor: Executor | str
self,
edge_groups: Sequence[EdgeGroup],
executors: dict[str, Executor],
start_executor: Executor | str,
*,
duplicate_executor_ids: Sequence[str] | None = None,
) -> None:
"""Validate the entire workflow graph.
@@ -144,6 +166,7 @@ class WorkflowGraphValidator:
edge_groups: list of edge groups in the workflow
executors: Map of executor IDs to executor instances
start_executor: The starting executor (can be instance or ID)
duplicate_executor_ids: Optional list of known duplicate executor IDs to pre-populate
Raises:
WorkflowValidationError: If any validation fails
@@ -151,6 +174,8 @@ class WorkflowGraphValidator:
self._executors = executors
self._edges = [edge for group in edge_groups for edge in group.edges]
self._edge_groups = edge_groups
self._duplicate_executor_ids = set(duplicate_executor_ids or [])
self._start_executor_ref = start_executor
# If only the start executor exists, add it to the executor map
# Handle the special case where the workflow consists of only a single executor and no edges.
@@ -185,6 +210,7 @@ class WorkflowGraphValidator:
)
# Run all checks
self._validate_executor_id_uniqueness(start_executor_id)
self._validate_edge_duplication()
self._validate_handler_output_annotations()
self._validate_type_compatibility()
@@ -353,6 +379,26 @@ class WorkflowGraphValidator:
# endregion
def _validate_executor_id_uniqueness(self, start_executor_id: str) -> None:
    """Reject workflows in which an executor id is used by more than one executor.

    An id is treated as duplicated when it was pre-reported in
    ``self._duplicate_executor_ids``, when an executor is registered under a key
    different from its own id, when the same id occurs on several registered
    executors, or when the start executor instance is not the instance
    registered under ``start_executor_id``.

    Raises:
        ExecutorDuplicationError: For the alphabetically first offending id.
    """
    offending: set[str] = set(self._duplicate_executor_ids)

    occurrences: defaultdict[str, int] = defaultdict(int)
    for registered_key, candidate in self._executors.items():
        occurrences[candidate.id] += 1
        if registered_key != candidate.id:
            offending.add(candidate.id)

    for seen_id, times in occurrences.items():
        if times > 1:
            offending.add(seen_id)

    if isinstance(self._start_executor_ref, Executor):
        registered = self._executors.get(start_executor_id)
        if registered is not None and registered is not self._start_executor_ref:
            offending.add(start_executor_id)

    if not offending:
        return
    raise ExecutorDuplicationError(min(offending))
# region Edge and Type Validation
def _validate_edge_duplication(self) -> None:
"""Validate that there are no duplicate edges in the workflow.
@@ -793,7 +839,11 @@ class WorkflowGraphValidator:
def validate_workflow_graph(
edge_groups: Sequence[EdgeGroup], executors: dict[str, Executor], start_executor: Executor | str
edge_groups: Sequence[EdgeGroup],
executors: dict[str, Executor],
start_executor: Executor | str,
*,
duplicate_executor_ids: Sequence[str] | None = None,
) -> None:
"""Convenience function to validate a workflow graph.
@@ -801,9 +851,15 @@ def validate_workflow_graph(
edge_groups: list of edge groups in the workflow
executors: Map of executor IDs to executor instances
start_executor: The starting executor (can be instance or ID)
duplicate_executor_ids: Optional list of known duplicate executor IDs to pre-populate
Raises:
WorkflowValidationError: If any validation fails
"""
validator = WorkflowGraphValidator()
validator.validate_workflow(edge_groups, executors, start_executor)
validator.validate_workflow(
edge_groups,
executors,
start_executor,
duplicate_executor_ids=duplicate_executor_ids,
)
@@ -1,6 +1,8 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import hashlib
import json
import logging
import sys
import uuid
@@ -177,6 +179,12 @@ class Workflow(AFBaseModel):
workflow_id=id,
)
# Capture a canonical fingerprint of the workflow graph so checkpoints
# can assert they are resumed with an equivalent topology.
self._graph_signature = self._compute_graph_signature()
self._graph_signature_hash = self._hash_graph_signature(self._graph_signature)
self._runner.graph_signature_hash = self._graph_signature_hash
def model_dump(self, **kwargs: Any) -> dict[str, Any]:
"""Custom serialization that properly handles WorkflowExecutor nested workflows."""
data = super().model_dump(**kwargs)
@@ -377,17 +385,26 @@ class Workflow(AFBaseModel):
request_info_executor = self._find_request_info_executor()
if request_info_executor:
for request_id, response_data in responses.items():
ctx: WorkflowContext[Any] = WorkflowContext(
request_info_executor.id,
[self.__class__.__name__],
self._shared_state,
self._runner.context,
trace_contexts=None, # No parent trace context for new workflow span
source_span_ids=None, # No source span for response handling
)
if not await request_info_executor.has_pending_request(request_id, ctx):
logger.debug(
f"Skipping pre-supplied response for request {request_id}; no pending request found "
f"after checkpoint restoration."
)
continue
await request_info_executor.handle_response(
response_data,
request_id,
WorkflowContext(
request_info_executor.id,
[self.__class__.__name__],
self._shared_state,
self._runner.context,
trace_contexts=None, # No parent trace context for new workflow span
source_span_ids=None, # No source span for response handling
),
ctx,
)
async for event in self._run_workflow_with_tracing(
@@ -590,6 +607,19 @@ class Workflow(AFBaseModel):
if not checkpoint:
return False
graph_hash = getattr(self._runner, "graph_signature_hash", None)
checkpoint_hash = (checkpoint.metadata or {}).get("graph_signature")
if graph_hash and checkpoint_hash and graph_hash != checkpoint_hash:
raise ValueError(
"Workflow graph has changed since the checkpoint was created. "
"Please rebuild the original workflow before resuming."
)
if graph_hash and not checkpoint_hash:
logger.warning(
f"Checkpoint {checkpoint_id} does not include graph signature metadata; "
f"skipping topology validation."
)
temp_context = InProcRunnerContext(checkpoint_storage)
state: CheckpointState = {
"messages": checkpoint.messages,
@@ -608,10 +638,9 @@ class Workflow(AFBaseModel):
return True
except ValueError:
raise
except Exception as e:
import logging
logger = logging.getLogger(__name__)
logger.error(f"Failed to restore from external checkpoint {checkpoint_id}: {e}")
return False
@@ -632,7 +661,7 @@ class Workflow(AFBaseModel):
self._shared_state._state.clear() # type: ignore[attr-defined]
self._shared_state._state.update(shared_state_data) # type: ignore[attr-defined]
except Exception as exc: # pragma: no cover
logger.debug("Failed to restore shared_state during external restore: %s", exc)
logger.debug(f"Failed to restore shared_state during external restore: {exc}")
# Restore executor states into the context so ctx.get_state() calls after resume succeed
try:
@@ -641,9 +670,9 @@ class Workflow(AFBaseModel):
try:
await self._runner.context.set_state(exec_id, state)
except Exception as exc: # pragma: no cover - ignore per-executor failures
logger.debug("Failed to restore executor state for %s during external restore: %s", exec_id, exc)
logger.debug(f"Failed to restore executor state for {exec_id} during external restore: {exc}")
except Exception as exc: # pragma: no cover
logger.debug("Failed to iterate executor_states during external restore: %s", exc)
logger.debug(f"Failed to iterate executor_states during external restore: {exc}")
# Transfer pending messages into the context for delivery in the next superstep
messages_data = restored_state["messages"]
@@ -671,6 +700,71 @@ class Workflow(AFBaseModel):
)
)
# Graph signature helpers
def _compute_graph_signature(self) -> dict[str, Any]:
"""Build a canonical fingerprint of the workflow graph topology for checkpoint validation.
This creates a minimal, stable representation that captures only the structural
elements of the workflow (executor types, edge relationships, topology) while
ignoring data/state changes. Used to verify that a workflow's structure hasn't
changed when resuming from checkpoints.
"""
executors_signature = {
executor_id: f"{executor.__class__.__module__}.{executor.__class__.__name__}"
for executor_id, executor in self.executors.items()
}
edge_groups_signature: list[dict[str, Any]] = []
for group in self.edge_groups:
edges = [
{
"source": edge.source_id,
"target": edge.target_id,
"condition": getattr(edge, "condition_name", None),
}
for edge in group.edges
]
edges.sort(key=lambda e: (e["source"], e["target"], e["condition"] or ""))
group_info: dict[str, Any] = {
"group_type": group.__class__.__name__,
"sources": sorted(group.source_executor_ids),
"targets": sorted(group.target_executor_ids),
"edges": edges,
}
if isinstance(group, FanOutEdgeGroup):
group_info["selection_func"] = group.selection_func_name
edge_groups_signature.append(group_info)
edge_groups_signature.sort(
key=lambda info: (
info["group_type"],
tuple(info["sources"]),
tuple(info["targets"]),
json.dumps(info["edges"], sort_keys=True),
json.dumps(info.get("selection_func")),
)
)
return {
"start_executor": self.start_executor_id,
"executors": executors_signature,
"edge_groups": edge_groups_signature,
"max_iterations": self.max_iterations,
}
@staticmethod
def _hash_graph_signature(signature: dict[str, Any]) -> str:
canonical = json.dumps(signature, sort_keys=True, separators=(",", ":"))
return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
@property
def graph_signature_hash(self) -> str:
    """Hash of the workflow graph signature stored in ``_graph_signature_hash``."""
    return self._graph_signature_hash
def as_agent(self, name: str | None = None) -> WorkflowAgent:
"""Create a WorkflowAgent that wraps this workflow.
@@ -699,6 +793,7 @@ class WorkflowBuilder:
"""Initialize the WorkflowBuilder with an empty list of edges and no starting executor."""
self._edge_groups: list[EdgeGroup] = []
self._executors: dict[str, Executor] = {}
self._duplicate_executor_ids: set[str] = set()
self._start_executor: Executor | str | None = None
self._checkpoint_storage: CheckpointStorage | None = None
self._max_iterations: int = max_iterations
@@ -712,7 +807,11 @@ class WorkflowBuilder:
def _add_executor(self, executor: Executor) -> str:
"""Add an executor to the map and return its ID."""
self._executors[executor.id] = executor
existing = self._executors.get(executor.id)
if existing is not None and existing is not executor:
self._duplicate_executor_ids.add(executor.id)
else:
self._executors[executor.id] = executor
return executor.id
def _maybe_wrap_agent(self, candidate: Executor | AgentProtocol) -> Executor:
@@ -739,7 +838,10 @@ class WorkflowBuilder:
if name:
proposed_id = str(name)
if proposed_id in self._executors:
proposed_id = f"{proposed_id}-{uuid.uuid4().hex[:8]}"
raise ValueError(
f"Duplicate executor ID '{proposed_id}' from agent name. "
"Agent names must be unique within a workflow."
)
wrapper = AgentExecutor(candidate, id=proposed_id, streaming=True)
self._agent_wrappers[id(candidate)] = wrapper
return wrapper
@@ -934,8 +1036,9 @@ class WorkflowBuilder:
self._start_executor = wrapped
# Ensure the start executor is present in the executor map so validation succeeds
# even if no edges are added yet, or before edges wrap the same agent again.
if wrapped.id not in self._executors:
self._executors[wrapped.id] = wrapped
existing = self._executors.get(wrapped.id)
if existing is not wrapped:
self._add_executor(wrapped)
return self
def set_max_iterations(self, max_iterations: int) -> Self:
@@ -986,7 +1089,12 @@ class WorkflowBuilder:
)
# Perform validation before creating the workflow
validate_workflow_graph(self._edge_groups, self._executors, self._start_executor)
validate_workflow_graph(
self._edge_groups,
self._executors,
self._start_executor,
duplicate_executor_ids=tuple(self._duplicate_executor_ids),
)
# Add validation completed event
workflow_tracer.add_build_event("build.validation_completed")
@@ -0,0 +1,52 @@
# Copyright (c) Microsoft. All rights reserved.
from dataclasses import dataclass
from typing import Any, cast
from agent_framework._workflow._executor import RequestInfoMessage, RequestResponse
from agent_framework._workflow._runner_context import _decode_checkpoint_value, _encode_checkpoint_value # type: ignore
from agent_framework._workflow._typing_utils import is_instance_of
@dataclass(kw_only=True)
class SampleRequest(RequestInfoMessage):
    """Keyword-only request message used as the nested request type in these tests."""

    # Text carried by the request; the tests assert it survives decoding and coercion.
    prompt: str
def test_decode_dataclass_with_nested_request() -> None:
    """Round-trip a correlated RequestResponse through the checkpoint encoder and decoder."""
    handled = RequestResponse[SampleRequest, str].handled("approve")
    correlated = RequestResponse[SampleRequest, str].with_correlation(
        handled,
        SampleRequest(request_id="abc", prompt="prompt"),
        "abc",
    )

    payload = _encode_checkpoint_value(correlated)
    round_tripped = cast(RequestResponse[SampleRequest, str], _decode_checkpoint_value(payload))

    assert isinstance(round_tripped, RequestResponse)
    assert round_tripped.data == "approve"
    assert round_tripped.request_id == "abc"

    # The nested request must come back as the dataclass, not as a raw dict.
    nested = round_tripped.original_request
    assert isinstance(nested, SampleRequest)
    assert nested.prompt == "prompt"
def test_is_instance_of_coerces_request_response_original_request_dict() -> None:
    """is_instance_of should turn a dict ``original_request`` back into the request type."""
    handled = RequestResponse[SampleRequest, str].handled("approve")
    correlated = RequestResponse[SampleRequest, str].with_correlation(
        handled,
        SampleRequest(request_id="req-1", prompt="prompt"),
        "req-1",
    )

    # Simulate checkpoint decode fallback leaving a dict
    raw_request = {
        "request_id": "req-1",
        "prompt": "prompt",
    }
    correlated.original_request = cast(Any, raw_request)

    assert is_instance_of(correlated, RequestResponse[SampleRequest, str])
    # The check above replaces the dict in place with a coerced SampleRequest.
    assert isinstance(correlated.original_request, SampleRequest)
@@ -0,0 +1,73 @@
# Copyright (c) Microsoft. All rights reserved.
import pytest
from agent_framework import WorkflowBuilder, WorkflowCompletedEvent, WorkflowContext, handler
from agent_framework._workflow._checkpoint import InMemoryCheckpointStorage
from agent_framework._workflow._executor import Executor
class StartExecutor(Executor):
    """Test executor that forwards every string it receives to the executor with id "finish"."""

    @handler
    async def run(self, message: str, ctx: WorkflowContext[str]) -> None:
        # The target id is hard-coded; it does not follow build_workflow's finish_id argument.
        await ctx.send_message(message, target_id="finish")
class FinishExecutor(Executor):
    """Test executor that emits a WorkflowCompletedEvent carrying the received message."""

    @handler
    async def finish(self, message: str, ctx: WorkflowContext[None]) -> None:
        await ctx.add_event(WorkflowCompletedEvent(message))
def build_workflow(storage: InMemoryCheckpointStorage, finish_id: str = "finish"):
    """Build a two-executor workflow (start -> finish) with checkpointing enabled.

    Args:
        storage: Checkpoint storage handed to the builder.
        finish_id: Id assigned to the finishing executor; varying it changes the graph.
    """
    entry = StartExecutor(id="start")
    terminal = FinishExecutor(id=finish_id)
    return (
        WorkflowBuilder(max_iterations=3)
        .set_start_executor(entry)
        .add_edge(entry, terminal)
        .with_checkpointing(checkpoint_storage=storage)
        .build()
    )
async def test_resume_fails_when_graph_mismatch() -> None:
    """Resuming a checkpoint on a workflow with a different graph must raise ValueError."""
    storage = InMemoryCheckpointStorage()
    original_workflow = build_workflow(storage, finish_id="finish")

    # Run once to create checkpoints
    async for _ in original_workflow.run_stream("hello"):
        pass

    saved = await storage.list_checkpoints()
    assert saved, "expected at least one checkpoint to be created"
    resume_from = saved[-1]

    # Build a structurally different workflow (different finish executor id)
    altered_workflow = build_workflow(storage, finish_id="finish_alt")

    with pytest.raises(ValueError, match="Workflow graph has changed"):
        async for _ in altered_workflow.run_stream_from_checkpoint(
            resume_from.checkpoint_id,
            checkpoint_storage=storage,
        ):
            pass
async def test_resume_succeeds_when_graph_matches() -> None:
    """Resuming the earliest checkpoint on an identically built workflow completes."""
    storage = InMemoryCheckpointStorage()
    first_run = build_workflow(storage, finish_id="finish")
    async for _ in first_run.run_stream("hello"):
        pass

    by_time = sorted(await storage.list_checkpoints(), key=lambda c: c.timestamp)
    earliest = by_time[0]

    second_run = build_workflow(storage, finish_id="finish")

    collected = []
    async for event in second_run.run_stream_from_checkpoint(
        earliest.checkpoint_id,
        checkpoint_storage=storage,
    ):
        collected.append(event)

    assert any(isinstance(event, WorkflowCompletedEvent) for event in collected)
@@ -126,3 +126,17 @@ async def test_concurrent_custom_aggregator_sync_callback_is_used() -> None:
assert completed is not None
assert isinstance(completed.data, str)
assert completed.data == "One | Two"
def test_concurrent_custom_aggregator_uses_callback_name_for_id() -> None:
    """A named aggregator callback should lend its ``__name__`` to the aggregator executor id."""

    def summarize(results: list[AgentExecutorResponse]) -> str:  # type: ignore[override]
        return str(len(results))

    first = _FakeAgentExec("agentA", "One")
    second = _FakeAgentExec("agentB", "Two")

    workflow = ConcurrentBuilder().participants([first, second]).with_aggregator(summarize).build()

    assert "summarize" in workflow.executors
    registered = workflow.executors["summarize"]
    assert registered.id == "summarize"
@@ -5,16 +5,16 @@ import pytest
from agent_framework import Executor, WorkflowContext, handler
def test_executor_without_handlers():
"""Test that an executor without handlers raises an error when trying to run."""
def test_executor_without_id():
"""Test that an executor without an ID raises an error when trying to run."""
class MockExecutorWithoutHandlers(Executor):
class MockExecutorWithoutID(Executor):
"""A mock executor that does not implement any handlers."""
pass
with pytest.raises(ValueError):
MockExecutorWithoutHandlers()
MockExecutorWithoutID(id="")
def test_executor_handler_without_annotations():
@@ -61,7 +61,7 @@ def test_executor_with_valid_handlers():
"""Another mock handler with a valid signature."""
pass
executor = MockExecutorWithValidHandlers()
executor = MockExecutorWithValidHandlers(id="test")
assert executor.id is not None
assert len(executor._handlers) == 2 # type: ignore
assert executor.can_handle("text") is True
@@ -85,7 +85,7 @@ def test_executor_handlers_with_output_types():
"""A mock handler that outputs an integer."""
pass
executor = MockExecutorWithOutputTypes()
executor = MockExecutorWithOutputTypes(id="test")
assert len(executor._handlers) == 2 # type: ignore
string_handler = executor._handlers[str] # type: ignore
@@ -0,0 +1,223 @@
# Copyright (c) Microsoft. All rights reserved.
from dataclasses import dataclass
from typing import Any
import pytest
from agent_framework._workflow._checkpoint import WorkflowCheckpoint
from agent_framework._workflow._events import WorkflowEvent
from agent_framework._workflow._executor import (
PendingRequestDetails,
RequestInfoExecutor,
RequestInfoMessage,
RequestResponse,
)
from agent_framework._workflow._runner_context import CheckpointState, Message, _encode_checkpoint_value # type: ignore
from agent_framework._workflow._shared_state import SharedState
from agent_framework._workflow._workflow_context import WorkflowContext
PENDING_STATE_KEY = RequestInfoExecutor._PENDING_SHARED_STATE_KEY # pyright: ignore[reportPrivateUsage]
class _StubRunnerContext:
    """Minimal runner context stub for exercising WorkflowContext helpers.

    Only state storage is functional: a single dict is kept for all executors.
    Messaging, events and checkpointing are inert or raise.
    """

    def __init__(self, stored_state: dict[str, Any] | None = None) -> None:
        # A falsy argument (None or an empty dict) is replaced by a fresh empty dict.
        self._state = stored_state or {}

    async def send_message(self, message: Message) -> None:  # pragma: no cover - unused in tests
        return None

    async def drain_messages(self) -> dict[str, list[Message]]:  # pragma: no cover - unused
        return {}

    async def has_messages(self) -> bool:  # pragma: no cover - unused
        return False

    async def add_event(self, event: WorkflowEvent) -> None:  # pragma: no cover - unused
        return None

    async def drain_events(self) -> list[WorkflowEvent]:  # pragma: no cover - unused
        return []

    async def has_events(self) -> bool:  # pragma: no cover - unused
        return False

    async def next_event(self) -> WorkflowEvent:  # pragma: no cover - unused
        raise RuntimeError("Not implemented in stub context")

    async def get_state(self, executor_id: str) -> dict[str, Any] | None:  # pragma: no cover - trivial
        # executor_id is ignored: the same dict is returned for every executor.
        return self._state

    async def set_state(self, executor_id: str, state: dict[str, Any]) -> None:  # pragma: no cover - unused
        # Replaces the single shared dict regardless of executor_id.
        self._state = state

    def has_checkpointing(self) -> bool:  # pragma: no cover - unused
        return False

    def set_workflow_id(self, workflow_id: str) -> None:  # pragma: no cover - unused
        pass

    def reset_for_new_run(self, workflow_shared_state: SharedState | None = None) -> None:  # pragma: no cover - unused
        pass

    async def create_checkpoint(self, metadata: dict[str, Any] | None = None) -> str:  # pragma: no cover - unused
        raise RuntimeError("Checkpointing not supported in stub context")

    async def restore_from_checkpoint(self, checkpoint_id: str) -> bool:  # pragma: no cover - unused
        return False

    async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None:  # pragma: no cover - unused
        return None

    async def get_checkpoint_state(self) -> CheckpointState:  # pragma: no cover - unused
        return {}  # type: ignore[return-value]

    async def set_checkpoint_state(self, state: CheckpointState) -> None:  # pragma: no cover - unused
        pass
@dataclass(kw_only=True)
class SimpleApproval(RequestInfoMessage):
    """Keyword-only approval request used by the checkpoint summary test; all fields default."""

    # Question shown to the reviewer.
    prompt: str = ""
    # Draft content under review.
    draft: str = ""
    # Review round counter.
    iteration: int = 0
@pytest.mark.asyncio
async def test_rehydrate_falls_back_when_request_type_missing() -> None:
    """Rehydrate a pending request whose recorded request type cannot be imported.

    The snapshot names ``nonexistent.module:MissingRequest`` as its request type, as
    happens when the original class is unavailable in the resuming process. The
    rehydrated event must still carry a ``RequestInfoMessage`` exposing the stored details.
    """
    pending_id = "request-123"
    stored_details = {
        "request_id": pending_id,
        "prompt": "Review draft",
        "draft": "Draft text",
        "iteration": 2,
    }
    pending_snapshot = {
        "request_id": pending_id,
        "source_executor_id": "review_gateway",
        "request_type": "nonexistent.module:MissingRequest",
        "summary": "...",
        "details": stored_details,
    }

    state = SharedState()
    async with state.hold():
        await state.set_within_hold(PENDING_STATE_KEY, {pending_id: pending_snapshot})

    stub_ctx = _StubRunnerContext({"pending_requests": {pending_id: pending_snapshot}})
    workflow_ctx: WorkflowContext[Any] = WorkflowContext("request_info", ["workflow"], state, stub_ctx)
    info_executor = RequestInfoExecutor(id="request_info")

    rehydrated = await info_executor._rehydrate_request_event(pending_id, workflow_ctx)  # pyright: ignore[reportPrivateUsage]

    assert rehydrated is not None
    assert rehydrated.request_id == pending_id
    assert isinstance(rehydrated.data, RequestInfoMessage)
    assert getattr(rehydrated.data, "prompt", None) == "Review draft"
    assert getattr(rehydrated.data, "iteration", None) == 2
@pytest.mark.asyncio
async def test_has_pending_request_detects_snapshot() -> None:
    """has_pending_request is truthy when a snapshot for the request id is stored."""
    pending_id = "req-pending"
    pending_snapshot = {
        "request_id": pending_id,
        "source_executor_id": "review_gateway",
        "details": {
            "request_id": pending_id,
            "prompt": "Review",
            "draft": "Draft",
        },
    }

    state = SharedState()
    async with state.hold():
        await state.set_within_hold(PENDING_STATE_KEY, {pending_id: pending_snapshot})

    stub_ctx = _StubRunnerContext({"pending_requests": {pending_id: pending_snapshot}})
    workflow_ctx: WorkflowContext[Any] = WorkflowContext("request_info", ["workflow"], state, stub_ctx)
    info_executor = RequestInfoExecutor(id="request_info")

    found = await info_executor.has_pending_request(pending_id, workflow_ctx)
    assert found
@pytest.mark.asyncio
async def test_has_pending_request_false_when_snapshot_absent() -> None:
    """has_pending_request is falsy when nothing is stored for the request id."""
    state = SharedState()
    stub_ctx = _StubRunnerContext({"pending_requests": {}})
    workflow_ctx: WorkflowContext[Any] = WorkflowContext("request_info", ["workflow"], state, stub_ctx)
    info_executor = RequestInfoExecutor(id="request_info")

    found = await info_executor.has_pending_request("missing", workflow_ctx)
    assert not found
def test_pending_requests_from_checkpoint_and_summary() -> None:
    """Pending requests and the checkpoint summary are derived from a stored checkpoint."""
    approval = SimpleApproval(prompt="Review draft", draft="Draft text", iteration=3)
    approval.request_id = "req-42"

    handled = RequestResponse[SimpleApproval, str].handled("approve")
    correlated = RequestResponse[SimpleApproval, str].with_correlation(
        handled,
        approval,
        approval.request_id,
    )

    # One in-flight message carrying the encoded response back to the gateway.
    queued_message = {
        "data": _encode_checkpoint_value(correlated),
        "source_id": "request_info",
        "target_id": "review_gateway",
    }
    # Shared-state snapshot of the pending request, keyed by request id.
    pending_snapshot = {
        "request_id": approval.request_id,
        "prompt": approval.prompt,
        "draft": approval.draft,
        "iteration": approval.iteration,
        "source_executor_id": "review_gateway",
    }

    stored = WorkflowCheckpoint(
        checkpoint_id="cp-1",
        workflow_id="wf",
        messages={"request_info": [queued_message]},
        shared_state={PENDING_STATE_KEY: {approval.request_id: pending_snapshot}},
        executor_states={},
        iteration_count=1,
    )

    pending_list = RequestInfoExecutor.pending_requests_from_checkpoint(stored)
    assert len(pending_list) == 1
    only_entry = pending_list[0]
    assert isinstance(only_entry, PendingRequestDetails)
    assert only_entry.request_id == "req-42"
    assert only_entry.prompt == "Review draft"
    assert only_entry.draft == "Draft text"
    assert only_entry.iteration == 3
    assert only_entry.original_request is not None

    overview = RequestInfoExecutor.checkpoint_summary(stored)
    assert overview.checkpoint_id == "cp-1"
    assert overview.status == "awaiting human response"
    assert overview.pending_requests[0].request_id == "req-42"
@@ -605,10 +605,7 @@ class TestSerializationWorkflowClasses:
executor = SampleExecutor(id="valid-id")
assert executor.id == "valid-id"
# Test validation failure for empty id - pydantic automatically validates min_length=1
from pydantic import ValidationError
with pytest.raises(ValidationError):
with pytest.raises(ValueError):
SampleExecutor(id="")
def test_edge_field_validation(self) -> None:
@@ -200,7 +200,7 @@ async def test_sub_workflow_with_interception():
# Create parent workflow with interception
parent = ParentOrchestrator(approved_domains={"example.com", "internal.org"})
workflow_executor = WorkflowExecutor(validation_workflow, id="email_workflow")
parent_request_info = RequestInfoExecutor()
parent_request_info = RequestInfoExecutor(id="request_info")
main_workflow = (
WorkflowBuilder()
@@ -280,7 +280,7 @@ async def test_conditional_forwarding() -> None:
# Setup workflows
email_validator = EmailValidator()
request_info = RequestInfoExecutor()
request_info = RequestInfoExecutor(id="request_info")
validation_workflow = (
WorkflowBuilder()
@@ -292,7 +292,7 @@ async def test_conditional_forwarding() -> None:
parent = ConditionalParent()
workflow_executor = WorkflowExecutor(validation_workflow, id="email_workflow")
parent_request_info = RequestInfoExecutor()
parent_request_info = RequestInfoExecutor(id="request_info")
main_workflow = (
WorkflowBuilder()
@@ -364,7 +364,7 @@ async def test_workflow_scoped_interception() -> None:
# Create two identical sub-workflows
def create_validation_workflow():
validator = EmailValidator()
request_info = RequestInfoExecutor()
request_info = RequestInfoExecutor(id="request_info")
return (
WorkflowBuilder()
.set_start_executor(validator)
@@ -379,7 +379,7 @@ async def test_workflow_scoped_interception() -> None:
parent = MultiWorkflowParent()
executor_a = WorkflowExecutor(workflow_a, id="workflow_a")
executor_b = WorkflowExecutor(workflow_b, id="workflow_b")
parent_request_info = RequestInfoExecutor()
parent_request_info = RequestInfoExecutor(id="request_info")
main_workflow = (
WorkflowBuilder()
@@ -8,6 +8,7 @@ import pytest
from agent_framework import (
EdgeDuplicationError,
Executor,
ExecutorDuplicationError,
GraphConnectivityError,
TypeCompatibilityError,
ValidationTypeEnum,
@@ -79,6 +80,17 @@ def test_valid_workflow_passes_validation():
assert workflow is not None
def test_duplicate_executor_ids_fail_validation():
    """Two distinct executors sharing an id must make the build raise ExecutorDuplicationError."""
    first = StringExecutor(id="dup")
    second = IntExecutor(id="dup")

    with pytest.raises(ExecutorDuplicationError) as exc_info:
        builder = WorkflowBuilder()
        builder.add_edge(first, second).set_start_executor(first).build()

    raised = exc_info.value
    assert raised.executor_id == "dup"
    assert raised.validation_type == ValidationTypeEnum.EXECUTOR_DUPLICATION
def test_edge_duplication_validation_fails():
executor1 = StringExecutor(id="executor1")
executor2 = StringExecutor(id="executor2")
@@ -163,7 +163,7 @@ async def test_workflow_send_responses_streaming():
"""Test the workflow run with approval."""
executor_a = IncrementExecutor(id="executor_a")
executor_b = MockExecutorRequestApproval(id="executor_b")
request_info_executor = RequestInfoExecutor()
request_info_executor = RequestInfoExecutor(id="request_info")
workflow = (
WorkflowBuilder()
@@ -195,7 +195,7 @@ async def test_workflow_send_responses():
"""Test the workflow run with approval."""
executor_a = IncrementExecutor(id="executor_a")
executor_b = MockExecutorRequestApproval(id="executor_b")
request_info_executor = RequestInfoExecutor()
request_info_executor = RequestInfoExecutor(id="request_info")
workflow = (
WorkflowBuilder()
@@ -11,6 +11,7 @@ from agent_framework import (
AgentRunUpdateEvent,
ChatMessage,
Executor,
FunctionCallContent,
FunctionResultContent,
RequestInfoExecutor,
RequestInfoMessage,
@@ -94,8 +95,8 @@ class TestWorkflowAgent:
assert len(result.messages) >= 2, f"Expected at least 2 messages, got {len(result.messages)}"
# Find messages from each executor
step1_messages = []
step2_messages = []
step1_messages: list[ChatMessage] = []
step2_messages: list[ChatMessage] = []
for message in result.messages:
first_content = message.contents[0]
@@ -111,8 +112,8 @@ class TestWorkflowAgent:
assert len(step2_messages) >= 1, "Should have received message from Step2 executor"
# Verify the processing worked for both
step1_text = step1_messages[0].contents[0].text
step2_text = step2_messages[0].contents[0].text
step1_text: str = step1_messages[0].contents[0].text # type: ignore[attr-defined]
step2_text: str = step2_messages[0].contents[0].text # type: ignore[attr-defined]
assert "Step1: Hello World" in step1_text
assert "Step2: Step1: Hello World" in step2_text
@@ -128,7 +129,7 @@ class TestWorkflowAgent:
agent = WorkflowAgent(workflow=workflow, name="Streaming Test Agent")
# Execute workflow streaming to capture streaming events
updates = []
updates: list[AgentRunResponseUpdate] = []
async for update in agent.run_stream("Test input"):
updates.append(update)
@@ -137,8 +138,8 @@ class TestWorkflowAgent:
# Verify we got a streaming update
assert updates[0].contents is not None
first_content = updates[0].contents[0]
second_content = updates[1].contents[0]
first_content: TextContent = updates[0].contents[0] # type: ignore[assignment]
second_content: TextContent = updates[1].contents[0] # type: ignore[assignment]
assert isinstance(first_content, TextContent)
assert "Streaming1: Test input" in first_content.text
assert isinstance(second_content, TextContent)
@@ -148,7 +149,7 @@ class TestWorkflowAgent:
"""Test end-to-end workflow with RequestInfoEvent handling."""
# Create workflow with requesting executor -> request info executor (no cycle)
requesting_executor = RequestingExecutor(id="requester")
request_info_executor = RequestInfoExecutor()
request_info_executor = RequestInfoExecutor(id="request_info")
workflow = (
WorkflowBuilder()
@@ -160,21 +161,21 @@ class TestWorkflowAgent:
agent = WorkflowAgent(workflow=workflow, name="Request Test Agent")
# Execute workflow streaming to get request info event
updates = []
updates: list[AgentRunResponseUpdate] = []
async for update in agent.run_stream("Start request"):
updates.append(update)
# Should have received a function call for the request info
assert len(updates) > 0
# Find the function call update (RequestInfoEvent converted to function call)
function_call_update = None
function_call_update: AgentRunResponseUpdate | None = None
for update in updates:
if update.contents and hasattr(update.contents[0], "name") and update.contents[0].name == "request_info":
if update.contents and hasattr(update.contents[0], "name") and update.contents[0].name == "request_info": # type: ignore[attr-defined]
function_call_update = update
break
assert function_call_update is not None, "Should have received a request_info function call"
function_call = function_call_update.contents[0]
function_call: FunctionCallContent = function_call_update.contents[0] # type: ignore[assignment]
# Verify the function call has expected structure
assert function_call.call_id is not None
@@ -230,7 +231,7 @@ class TestWorkflowAgent:
raise ValueError("Unsupported message type")
# Create a simple workflow
executor = _Executor()
executor = _Executor(id="test")
workflow = WorkflowBuilder().set_start_executor(executor).build()
# Try to create an agent with unsupported input types
@@ -1,5 +1,8 @@
# Copyright (c) Microsoft. All rights reserved.
from dataclasses import dataclass
from typing import Any
import pytest
from agent_framework import (
@@ -155,3 +158,59 @@ async def test_run_includes_status_events_idle_with_requests():
assert len(timeline) >= 3
assert timeline[-2].state == WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS
assert timeline[-1].state == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS
@dataclass
class SnapshotRequest(RequestInfoMessage):
    """Request payload sent by ``SnapshotRequester`` in the pending-request persistence test."""

    # Every field has a default value.
    prompt: str = ""  # the test asserts this value appears in the stored snapshot
    draft: str = ""  # the test asserts this value appears in the stored snapshot
    iteration: int = 0  # SnapshotRequester sends 1; the test checks the snapshot for it
class SnapshotRequester(Executor):
    """Test helper executor that forwards a populated ``SnapshotRequest``.

    The prompt and draft supplied at construction time are copied into every
    request this executor sends; the iteration is always ``1``.
    """

    def __init__(self, id: str, prompt: str, draft: str) -> None:
        """Register the executor under *id* and remember the request contents."""
        super().__init__(id=id)
        self._prompt, self._draft = prompt, draft

    @handler
    async def ask(self, _: str, ctx: WorkflowContext[SnapshotRequest]) -> None:  # pragma: no cover - simple helper
        """Ignore the incoming string and send one ``SnapshotRequest`` downstream."""
        request = SnapshotRequest(prompt=self._prompt, draft=self._draft, iteration=1)
        await ctx.send_message(request)
async def test_request_info_executor_tracks_pending_requests_via_shared_state():
    """Pending requests are recorded in shared state and removed once answered.

    Phase 1 runs a requester -> RequestInfoExecutor workflow and checks that a
    single snapshot of the request is stored under the executor's pending key.
    Phase 2 answers the request through a second RequestInfoExecutor instance
    that shares the workflow's state, and checks the entry disappears.
    """
    expected_prompt = "Review the launch copy"
    expected_draft = "Limited edition grinder now $249"
    pending_key = RequestInfoExecutor._PENDING_SHARED_STATE_KEY  # type: ignore[reportPrivateUsage]

    # Phase 1: run until the request is emitted.
    source = SnapshotRequester(id="snapshot_req", prompt=expected_prompt, draft=expected_draft)
    gate = RequestInfoExecutor(id="request_info")
    builder = WorkflowBuilder().set_start_executor(source)
    wf = builder.add_edge(source, gate).build()

    seen: list[Any] = []
    async for event in wf.run_stream("start"):
        seen.append(event)
    assert any(isinstance(event, RequestInfoEvent) for event in seen)

    stored: dict[str, Any] = await wf._shared_state.get(pending_key)  # type: ignore[reportPrivateUsage]
    assert isinstance(stored, dict)
    assert len(stored) == 1

    # Exactly one entry exists (asserted above), so unpacking is safe.
    (snapshot,) = stored.values()
    assert snapshot["prompt"] == expected_prompt
    assert snapshot["draft"] == expected_draft
    assert snapshot.get("iteration") == 1
    request_id: str = snapshot["request_id"]

    # Phase 2: deliver the answer through a different executor instance.
    resumed_gate = RequestInfoExecutor(id="request_info_resume")
    resume_context: WFContext[Any] = WFContext(
        executor_id=resumed_gate.id,
        source_executor_ids=[wf.__class__.__name__],
        shared_state=wf._shared_state,  # type: ignore[reportPrivateUsage]
        runner_context=wf._runner_context,  # type: ignore[reportPrivateUsage]
    )
    await resumed_gate.handle_response("approve", request_id, resume_context)

    remaining: dict[str, Any] = await wf._shared_state.get(pending_key)  # type: ignore[reportPrivateUsage]
    assert isinstance(remaining, dict)
    assert request_id not in remaining
@@ -43,6 +43,7 @@ Once comfortable with these, explore the rest of the samples below.
| Sample | File | Concepts |
|---|---|---|
| Checkpoint & Resume | [checkpoint/checkpoint_with_resume.py](./checkpoint/checkpoint_with_resume.py) | Create checkpoints, inspect them, and resume execution |
| Checkpoint & HITL Resume | [checkpoint/checkpoint_with_human_in_the_loop.py](./checkpoint/checkpoint_with_human_in_the_loop.py) | Combine checkpointing with human approvals and resume pending HITL requests |
### composition
| Sample | File | Concepts |
@@ -50,7 +50,7 @@ Prerequisites
# - Compute a result
# - Forward that result to downstream node(s) using ctx.send_message(result)
class UpperCase(Executor):
def __init__(self, id: str | None = None):
def __init__(self, id: str):
super().__init__(id=id)
@handler
@@ -1,108 +1,34 @@
# Copyright (c) Microsoft. All rights reserved.
# import asyncio
# from agent_framework.foundry import FoundryChatClient
# from agent_framework import AgentRunUpdateEvent, WorkflowBuilder, WorkflowCompletedEvent
# from azure.identity.aio import AzureCliCredential
# """
# Sample: Agents in a workflow with streaming
# A Writer agent generates content, then a Reviewer agent critiques it.
# The workflow uses streaming so you can observe incremental AgentRunUpdateEvent chunks as each agent produces tokens.
# Purpose:
# Show how to wire chat agents directly into a WorkflowBuilder pipeline where agents are auto wrapped as executors.
# Demonstrate:
# - Automatic streaming of agent deltas via AgentRunUpdateEvent.
# - A simple console aggregator that groups updates by executor id and prints them as they arrive.
# - A final WorkflowCompletedEvent that contains the reviewer outcome after both agents finish.
# Prerequisites:
# - Foundry Agent Service configured, along with the required environment variables.
# - Authentication via azure-identity. Use AzureCliCredential and run az login before executing the sample.
# - Basic familiarity with WorkflowBuilder, edges, events, and streaming runs.
# """
# async def main():
# """Build and run a simple two node agent workflow: Writer then Reviewer."""
# # Create the Foundry chat client.
# async with (
# AzureCliCredential() as credential,
# FoundryChatClient(async_credential=credential).create_agent(
# name="Writer",
# instructions=(
# "You are an excellent content writer.You create new content and edit contents based on the feedback."
# ),
# ) as writer_agent,
# FoundryChatClient(async_credential=credential).create_agent(
# name="Reviewer",
# instructions=(
# "You are an excellent content reviewer."
# "Provide actionable feedback to the writer about the provided content."
# "Provide the feedback in the most concise manner possible."
# ),
# ) as reviewer_agent,
# ):
# # Build the workflow using the fluent builder.
# # Set the start node and connect an edge from writer to reviewer.
# workflow = WorkflowBuilder().set_start_executor(writer_agent).add_edge(writer_agent, reviewer_agent).build()
# # Stream events from the workflow. We aggregate partial token updates per executor for readable output.
# completed_event: WorkflowCompletedEvent | None = None
# last_executor_id = None
# async for event in workflow.run_stream(
# "Create a slogan for a new electric SUV that is affordable and fun to drive."
# ):
# if isinstance(event, AgentRunUpdateEvent):
# # AgentRunUpdateEvent contains incremental text deltas from the underlying agent.
# # Print a prefix when the executor changes, then append updates on the same line.
# eid = event.executor_id
# if eid != last_executor_id:
# if last_executor_id is not None:
# print()
# print(f"{eid}:", end=" ", flush=True)
# last_executor_id = eid
# print(event.data, end="", flush=True)
# elif isinstance(event, WorkflowCompletedEvent):
# # Terminal event with the final reviewer output.
# completed_event = event
# # Print the final consolidated reviewer result.
# if completed_event:
# print("\n===== Final Output =====")
# print(completed_event.data)
# """
# Sample Output:
# writer_agent: Charge Up Your Journey. Fun, Affordable, Electric.
# reviewer_agent: Clear message, but consider highlighting SUV specific benefits
# (space, versatility) for stronger impact. Try more vivid language to evoke
# excitement. Example: "Big on Space. Big on Fun. Electric for Everyone."
# ===== Final Output =====
# Clear message, but consider highlighting SUV specific benefits (space, versatility)
# for stronger impact. Try more vivid language to evoke excitement. Example:
# "Big on Space. Big on Fun. Electric for Everyone."
# """
# if __name__ == "__main__":
# asyncio.run(main())
import asyncio
from collections.abc import Awaitable, Callable
from contextlib import AsyncExitStack
from typing import Any
from collections.abc import Awaitable, Callable
from agent_framework.foundry import FoundryChatClient
from agent_framework import AgentRunUpdateEvent, WorkflowBuilder, WorkflowCompletedEvent
from agent_framework.foundry import FoundryChatClient
from azure.identity.aio import AzureCliCredential
"""
Sample: Agents in a workflow with streaming
A Writer agent generates content, then a Reviewer agent critiques it.
The workflow uses streaming so you can observe incremental AgentRunUpdateEvent chunks as each agent produces tokens.
Purpose:
Show how to wire chat agents directly into a WorkflowBuilder pipeline where agents are auto wrapped as executors.
Demonstrate:
- Automatic streaming of agent deltas via AgentRunUpdateEvent.
- A simple console aggregator that groups updates by executor id and prints them as they arrive.
- A final WorkflowCompletedEvent that contains the reviewer outcome after both agents finish.
Prerequisites:
- Foundry Agent Service configured, along with the required environment variables.
- Authentication via azure-identity. Use AzureCliCredential and run az login before executing the sample.
- Basic familiarity with WorkflowBuilder, edges, events, and streaming runs.
"""
async def create_foundry_agent() -> tuple[Callable[..., Awaitable[Any]], Callable[[], Awaitable[None]]]:
"""Helper method to create a Foundry agent factory and a close function.
@@ -1,27 +1,31 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import sys
from dataclasses import dataclass
from pathlib import Path
from agent_framework import (
# Ensure local getting_started package can be imported when running as a script.
_SAMPLES_ROOT = Path(__file__).resolve().parents[3]
if str(_SAMPLES_ROOT) not in sys.path:
sys.path.insert(0, str(_SAMPLES_ROOT))
from agent_framework import ( # noqa: E402
ChatMessage,
Executor,
FunctionCallContent,
FunctionResultContent,
Role,
)
from agent_framework.openai import OpenAIChatClient
from agent_framework import (
Executor,
RequestInfoExecutor,
RequestInfoMessage,
RequestResponse,
Role,
WorkflowAgent,
WorkflowBuilder,
WorkflowContext,
handler,
)
from samples.getting_started.workflow.agents.workflow_as_agent_reflection_pattern import (
from agent_framework.openai import OpenAIChatClient # noqa: E402
from getting_started.workflow.agents.workflow_as_agent_reflection_pattern import ( # noqa: E402
ReviewRequest,
ReviewResponse,
Worker,
@@ -56,8 +60,9 @@ class HumanReviewRequest(RequestInfoMessage):
class ReviewerWithHumanInTheLoop(Executor):
"""Executor that always escalates reviews to a human manager."""
def __init__(self, worker_id: str, request_info_id: str) -> None:
super().__init__()
def __init__(self, worker_id: str, request_info_id: str, reviewer_id: str | None = None) -> None:
unique_id = reviewer_id or f"{worker_id}-reviewer"
super().__init__(id=unique_id)
self._worker_id = worker_id
self._request_info_id = request_info_id
@@ -96,8 +101,8 @@ async def main() -> None:
# Create executors for the workflow.
print("Creating chat client and executors...")
mini_chat_client = OpenAIChatClient(ai_model_id="gpt-4.1-nano")
worker = Worker(chat_client=mini_chat_client)
request_info_executor = RequestInfoExecutor()
worker = Worker(id="sub-worker", chat_client=mini_chat_client)
request_info_executor = RequestInfoExecutor(id="request_info")
reviewer = ReviewerWithHumanInTheLoop(worker_id=worker.id, request_info_id=request_info_executor.id)
print("Building workflow with Worker ↔ Reviewer cycle...")
@@ -4,9 +4,19 @@ import asyncio
from dataclasses import dataclass
from uuid import uuid4
from agent_framework import AgentRunResponseUpdate, ChatClientProtocol, ChatMessage, Contents, Role
from agent_framework import (
AgentRunResponseUpdate,
AgentRunUpdateEvent,
ChatClientProtocol,
ChatMessage,
Contents,
Executor,
Role,
WorkflowBuilder,
WorkflowContext,
handler,
)
from agent_framework.openai import OpenAIChatClient
from agent_framework import AgentRunUpdateEvent, Executor, WorkflowBuilder, WorkflowContext, handler
from pydantic import BaseModel
"""
@@ -54,8 +64,8 @@ class ReviewResponse:
class Reviewer(Executor):
"""Executor that reviews agent responses and provides structured feedback."""
def __init__(self, chat_client: ChatClientProtocol) -> None:
super().__init__()
def __init__(self, id: str, chat_client: ChatClientProtocol) -> None:
super().__init__(id=id)
self._chat_client = chat_client
@handler
@@ -106,8 +116,8 @@ class Reviewer(Executor):
class Worker(Executor):
"""Executor that generates responses and incorporates feedback when necessary."""
def __init__(self, chat_client: ChatClientProtocol) -> None:
super().__init__()
def __init__(self, id: str, chat_client: ChatClientProtocol) -> None:
super().__init__(id=id)
self._chat_client = chat_client
self._pending_requests: dict[str, tuple[ReviewRequest, list[ChatMessage]]] = {}
@@ -189,8 +199,8 @@ async def main() -> None:
print("Creating chat client and executors...")
mini_chat_client = OpenAIChatClient(ai_model_id="gpt-4.1-nano")
chat_client = OpenAIChatClient(ai_model_id="gpt-4.1")
reviewer = Reviewer(chat_client=chat_client)
worker = Worker(chat_client=mini_chat_client)
reviewer = Reviewer(id="reviewer", chat_client=chat_client)
worker = Worker(id="worker", chat_client=mini_chat_client)
print("Building workflow with Worker ↔ Reviewer cycle...")
agent = (
@@ -0,0 +1,483 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from collections.abc import AsyncIterable
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any
from agent_framework import (
AgentExecutor,
AgentExecutorRequest,
AgentExecutorResponse,
ChatMessage,
Executor,
FileCheckpointStorage,
RequestInfoEvent,
RequestInfoExecutor,
RequestInfoMessage,
RequestResponse,
Role,
WorkflowBuilder,
WorkflowCompletedEvent,
WorkflowContext,
WorkflowRunState,
WorkflowStatusEvent,
handler,
)
from agent_framework.azure import AzureChatClient
from azure.identity import AzureCliCredential
# NOTE: the Azure client imports above are real dependencies. When running this
# sample outside of Azure-enabled environments you may wish to swap in the
# `agent_framework.builtin` chat client or mock the writer executor. We keep the
# concrete import here so readers can see an end-to-end configuration.
if TYPE_CHECKING:
from agent_framework import Workflow
from agent_framework._workflow._checkpoint import WorkflowCheckpoint
"""
Sample: Checkpoint + human-in-the-loop quickstart.
This getting-started sample keeps the moving pieces to a minimum:
1. A brief is turned into a consistent prompt for an AI copywriter.
2. The copywriter (an `AgentExecutor`) drafts release notes.
3. A reviewer gateway routes every draft through `RequestInfoExecutor` so a human
can approve or request tweaks.
4. The workflow records checkpoints between each superstep so you can stop the
program, restart later, and optionally pre-supply human answers on resume.
Key concepts demonstrated
-------------------------
- Minimal executor pipeline with checkpoint persistence.
- Human-in-the-loop pause/resume by pairing `RequestInfoExecutor` with
checkpoint restoration.
- Supplying responses at restore time (`run_stream_from_checkpoint(..., responses=...)`).
Typical pause/resume flow
-------------------------
1. Run the workflow until a human approval request is emitted.
2. If the human is offline, exit the program. A checkpoint with
``status=awaiting human response`` now exists.
3. Later, restart the script, select that checkpoint, and provide the stored
human decision when prompted to pre-supply responses.
Doing so applies the answer immediately on resume, so the system does **not**
re-emit the same `RequestInfoEvent`.
"""
# Checkpoint files for this sample live in their own folder next to the script,
# so repeated runs never collide with artefacts written by other samples.
# The folder is created at import time; main() removes any *.json files it
# finds there before starting a new run.
_SAMPLE_DIR = Path(__file__).parent
TEMP_DIR = _SAMPLE_DIR / "tmp" / "checkpoints_hitl"
TEMP_DIR.mkdir(parents=True, exist_ok=True)
class BriefPreparer(Executor):
    """Start executor: tidies the raw brief and issues one AgentExecutorRequest.

    Keeping this step small limits how much state a checkpoint has to capture:
    the cleaned brief (stored in shared state under ``"brief"``) and a single
    outgoing request addressed to the writer agent.
    """

    def __init__(self, id: str, agent_id: str) -> None:
        """Register under *id* and remember which executor receives the prompt."""
        super().__init__(id=id)
        self._agent_id = agent_id

    @handler
    async def prepare(self, brief: str, ctx: WorkflowContext[AgentExecutorRequest]) -> None:
        """Normalise *brief*, store it in shared state, and prompt the writer."""
        # split()/join collapses every run of whitespace to one space and drops
        # leading/trailing whitespace, so the prompt is stable between runs.
        cleaned = " ".join(brief.split())
        if not cleaned.endswith("."):
            cleaned = cleaned + "."

        # Shared state lets downstream executors read the cleaned brief.
        await ctx.set_shared_state("brief", cleaned)

        prompt = "".join(
            (
                "You are drafting product release notes. Summarise the brief below in two sentences. ",
                "Keep it positive and end with a call to action.\n\n",
                f"BRIEF: {cleaned}",
            )
        )

        # Address the writer explicitly; the message travels through the
        # workflow context rather than calling the agent directly.
        request = AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=prompt)], should_respond=True)
        await ctx.send_message(request, target_id=self._agent_id)
@dataclass
class HumanApprovalRequest(RequestInfoMessage):
    """Message sent to the human reviewer via RequestInfoExecutor."""

    # All three fields are primitives with defaults. The original author notes
    # that they are serialised into checkpoints and that primitive types let
    # `RequestInfoExecutor.pending_requests_from_checkpoint` rebuild them on
    # resume - NOTE(review): not verifiable from this file; confirm against
    # the framework.
    prompt: str = ""  # instruction shown to the reviewer before the draft
    draft: str = ""  # text the reviewer is asked to approve or revise
    iteration: int = 0  # draft number; the CLI prints it only when non-zero
class ReviewGateway(Executor):
    """Routes agent drafts to a human reviewer and, on request, back to the writer.

    Executor state (``ctx.get_state`` / ``ctx.set_state``) holds two keys:

    * ``iteration``  - number of drafts received from the writer so far.
    * ``last_draft`` - text of the most recent draft.

    Only ``on_agent_response`` advances ``iteration``. It is the single place
    where a new draft arrives, so the number shown to the reviewer goes
    1, 2, 3, ... with no gaps.
    """

    def __init__(self, id: str, reviewer_id: str, writer_id: str, finalize_id: str) -> None:
        """Register under *id* and remember the ids of the three routing targets.

        Args:
            id: Executor id of this gateway.
            reviewer_id: Executor that receives ``HumanApprovalRequest`` messages.
            writer_id: Executor that receives revision requests.
            finalize_id: Executor that receives the approved draft.
        """
        super().__init__(id=id)
        self._reviewer_id = reviewer_id
        self._writer_id = writer_id
        self._finalize_id = finalize_id

    @handler
    async def on_agent_response(
        self,
        response: AgentExecutorResponse,
        ctx: WorkflowContext[HumanApprovalRequest],
    ) -> None:
        """Record a new draft and ask the human reviewer to approve or revise it."""
        draft = response.agent_run_response.text or ""

        # A new draft has arrived: this is the only place the counter grows.
        state = await ctx.get_state() or {}
        iteration = int(state.get("iteration", 0)) + 1
        await ctx.set_state({"iteration": iteration, "last_draft": draft})

        await ctx.send_message(
            HumanApprovalRequest(
                prompt="Review the draft. Reply 'approve' or provide edit instructions.",
                draft=draft,
                iteration=iteration,
            ),
            target_id=self._reviewer_id,
        )

    @handler
    async def on_human_feedback(
        self,
        feedback: RequestResponse[HumanApprovalRequest, str],
        ctx: WorkflowContext[AgentExecutorRequest | str],
    ) -> None:
        """Forward an approved draft to the finaliser, or request a revision.

        A reply of ``approve`` (case-insensitive, surrounding whitespace
        ignored) sends the draft to the finalising executor. Any other reply
        is treated as revision guidance for the writer. An empty reply uses
        default guidance.
        """
        reply = (feedback.data or "").strip()
        state = await ctx.get_state() or {}
        # Prefer the draft held in state; fall back to the one carried by the
        # original request when state has none.
        draft = state.get("last_draft") or (feedback.original_request.draft if feedback.original_request else "")

        if reply.lower() == "approve":
            await ctx.send_message(draft, target_id=self._finalize_id)
            return

        guidance = reply or "Tighten the copy and emphasise customer benefit."

        # Persist the draft without touching the counter. on_agent_response
        # increments it when the revised draft arrives; incrementing here as
        # well made every revision round advance the count by two.
        iteration = int(state.get("iteration", 1))
        await ctx.set_state({"iteration": iteration, "last_draft": draft})

        prompt = (
            "Revise the launch note. Respond with the new copy only.\n\n"
            f"Previous draft:\n{draft}\n\n"
            f"Human guidance: {guidance}"
        )
        await ctx.send_message(
            AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=prompt)], should_respond=True),
            target_id=self._writer_id,
        )
class FinaliseExecutor(Executor):
    """Terminal executor: records the approved copy and completes the workflow."""

    @handler
    async def publish(self, text: str, ctx: WorkflowContext[Any]) -> None:
        """Save *text* as executor state, then emit the completion event."""
        # Keep the final copy in state so it can be read back later.
        final_state = {"published_text": text}
        await ctx.set_state(final_state)

        # The completion event carries the approved text as its payload.
        completion = WorkflowCompletedEvent(text)
        await ctx.add_event(completion)
def create_workflow(*, checkpoint_storage: FileCheckpointStorage | None = None) -> "Workflow":
    """Build the sample's workflow graph, optionally with checkpointing enabled.

    The same wiring is used for the first run and for resuming; executor ids
    are fixed strings so they are identical on every call.

    Args:
        checkpoint_storage: When truthy, the workflow is built with
            checkpointing backed by this storage. Otherwise it is built
            without persistence.
    """
    # One Azure chat client backs the writer agent.
    client = AzureChatClient(credential=AzureCliCredential())
    drafting_agent = client.create_agent(
        instructions="Write concise, warm release notes that sound human and helpful.",
    )
    writer = AgentExecutor(drafting_agent, id="writer")

    # Every draft passes through the RequestInfoExecutor on its way to a human.
    review = RequestInfoExecutor(id="request_info")
    finalise = FinaliseExecutor(id="finalise")
    gateway = ReviewGateway(
        id="review_gateway",
        reviewer_id=review.id,
        writer_id=writer.id,
        finalize_id=finalise.id,
    )
    prepare = BriefPreparer(id="prepare_brief", agent_id=writer.id)

    # Edges in the order they are registered. Note the graph contains cycles:
    # gateway <-> review, and gateway -> writer -> gateway.
    edges = (
        (prepare, writer),
        (writer, gateway),
        (gateway, review),
        (review, gateway),  # human resumes loop
        (gateway, writer),  # revisions
        (gateway, finalise),
    )
    builder = WorkflowBuilder(max_iterations=6).set_start_executor(prepare)
    for source, target in edges:
        builder = builder.add_edge(source, target)

    if not checkpoint_storage:
        return builder.build()
    return builder.with_checkpointing(checkpoint_storage=checkpoint_storage).build()
def _render_checkpoint_summary(checkpoints: list["WorkflowCheckpoint"]) -> None:
    """Print one line per checkpoint, oldest first, using the framework summary helper.

    Each line always shows the checkpoint id, iteration count, targets and
    executor states. Status, draft preview and the first pending request id
    are appended only when the summary provides them.
    """
    print("\nCheckpoint summary:")
    oldest_first = sorted(checkpoints, key=lambda c: c.timestamp)
    # Summarise everything before printing any line.
    summaries = [RequestInfoExecutor.checkpoint_summary(cp) for cp in oldest_first]
    for summary in summaries:
        fields = [
            f"- {summary.checkpoint_id}",
            f"iter={summary.iteration_count}",
            f"targets={summary.targets}",
            f"states={summary.executor_states}",
        ]
        if summary.status:
            fields.append(f"status={summary.status}")
        if summary.draft_preview:
            fields.append(f"draft_preview={summary.draft_preview}")
        if summary.pending_requests:
            fields.append(f"pending_request_id={summary.pending_requests[0].request_id}")
        print(" | ".join(fields))
def _print_events(events: list[Any]) -> tuple[WorkflowCompletedEvent | None, list[tuple[str, HumanApprovalRequest]]]:
    """Print every event and pick out the completion event and open approvals.

    Returns:
        A pair ``(completed, requests)``. ``completed`` is the last
        ``WorkflowCompletedEvent`` seen, or ``None``. ``requests`` lists
        ``(request_id, HumanApprovalRequest)`` for each ``RequestInfoEvent``
        whose payload is a ``HumanApprovalRequest``, in arrival order.
    """
    waiting_states = {
        WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS,
        WorkflowRunState.IDLE_WITH_PENDING_REQUESTS,
    }
    finished: WorkflowCompletedEvent | None = None
    outstanding: list[tuple[str, HumanApprovalRequest]] = []

    for event in events:
        print(f"Event: {event}")

        if isinstance(event, WorkflowCompletedEvent):
            finished = event
            continue

        if isinstance(event, RequestInfoEvent) and isinstance(event.data, HumanApprovalRequest):
            # Remember the request so the caller can prompt for an answer
            # once this batch of events has been printed.
            outstanding.append((event.request_id, event.data))
            continue

        if isinstance(event, WorkflowStatusEvent) and event.state in waiting_states:
            print(f"Workflow state: {event.state.name}")

    return finished, outstanding
def _prompt_for_responses(requests: list[tuple[str, HumanApprovalRequest]]) -> dict[str, str] | None:
    """Ask the user on the console to answer each open approval request.

    Returns:
        ``None`` when *requests* is empty, otherwise a mapping of request id
        to the stripped text the user typed.

    Raises:
        SystemExit: If the user types ``exit`` (any letter case).
    """
    if not requests:
        return None

    collected: dict[str, str] = {}
    for request_id, approval in requests:
        print("\n=== Human approval needed ===")
        print(f"request_id: {request_id}")
        # An iteration of 0 (the field's default) is not shown.
        if approval.iteration:
            print(f"Iteration: {approval.iteration}")
        print(approval.prompt)
        print("".join(("Draft: \n---\n", approval.draft, "\n---")))

        typed = input("Type 'approve' or enter revision guidance (or 'exit' to quit): ").strip()  # noqa: ASYNC250
        if typed.lower() == "exit":
            raise SystemExit("Stopped by user.")
        collected[request_id] = typed

    return collected
def _maybe_pre_supply_responses(cp: "WorkflowCheckpoint") -> dict[str, str] | None:
    """Optionally collect answers for a checkpoint's pending requests before resuming.

    Returns:
        ``None`` when the checkpoint has no pending requests or the user
        declines, otherwise a mapping of request id to the typed answer.

    Raises:
        SystemExit: If the user types ``exit`` while answering.
    """
    pending = RequestInfoExecutor.pending_requests_from_checkpoint(cp)
    if not pending:
        return None

    print(
        "This checkpoint still has pending human input. Provide the responses now so the resume step "
        "applies them immediately and does not re-emit the original RequestInfo event."
    )
    decision = input("Pre-supply responses for this checkpoint? [y/N]: ").strip().lower()  # noqa: ASYNC250
    if decision != "y" and decision != "yes":
        return None

    collected: dict[str, str] = {}
    for request in pending:
        round_number = request.iteration or 0
        print(f"\nPending draft (iteration {round_number} | request_id={request.request_id}):")

        # Print the whole draft here; the summary line only carries a preview.
        full_draft = (request.draft or "").strip()
        if not full_draft:
            print("Draft: [not captured in checkpoint payload - refer to your notes/log]")
        else:
            print("".join(("Draft:\n---\n", full_draft, "\n---")))

        print((request.prompt or "Review the draft").strip())

        typed = input("Response ('approve' or guidance, 'exit' to abort): ").strip()  # noqa: ASYNC250
        if typed.lower() == "exit":
            raise SystemExit("Resume aborted by user.")
        collected[request.request_id] = typed

    return collected
async def _consume(stream: AsyncIterable[Any]) -> list[Any]:
    """Drain *stream* and return its items, in arrival order, as a list."""
    collected: list[Any] = []
    async for item in stream:
        collected.append(item)
    return collected
async def run_interactive_session(workflow: "Workflow", initial_message: str) -> WorkflowCompletedEvent | None:
    """Run the workflow, answering approval requests interactively, until it stops.

    The workflow is started with *initial_message*. After each batch of events
    the user is prompted for every open approval request, and the answers are
    sent back. The loop ends when the workflow completes or when a batch
    leaves no requests to answer.

    Returns:
        The completion event, or ``None`` if the run stopped without one.
    """
    # First batch: start the workflow with the brief.
    events = await _consume(workflow.run_stream(initial_message))
    outcome, open_requests = _print_events(events)
    answers = _prompt_for_responses(open_requests)

    # Later batches: feed the user's answers back in for as long as the
    # workflow has not completed and there is something to send.
    while outcome is None and answers:
        events = await _consume(workflow.send_responses_streaming(answers))
        outcome, open_requests = _print_events(events)
        answers = _prompt_for_responses(open_requests)

    return outcome
async def resume_from_checkpoint(
    workflow: "Workflow",
    checkpoint_id: str,
    storage: FileCheckpointStorage,
    pre_supplied: dict[str, str] | None,
) -> None:
    """Restore *checkpoint_id* from *storage* and drive the workflow onward.

    Any *pre_supplied* answers are passed to the restore call. Afterwards the
    user is prompted for each new approval request until the workflow
    completes or no requests remain.
    """
    print(f"\nResuming from checkpoint: {checkpoint_id}")
    restored_stream = workflow.run_stream_from_checkpoint(
        checkpoint_id,
        checkpoint_storage=storage,
        responses=pre_supplied,
    )
    outcome, open_requests = _print_events(await _consume(restored_stream))

    # The answers went in, but the restore produced neither a completion nor a
    # new request: tell the user the workflow is simply waiting.
    still_running_quietly = outcome is None and not open_requests
    if pre_supplied and still_running_quietly:
        print("Pre-supplied responses applied automatically; workflow is now waiting for the next step.")

    answers = _prompt_for_responses(open_requests)
    while outcome is None and answers:
        follow_up = await _consume(workflow.send_responses_streaming(answers))
        outcome, open_requests = _print_events(follow_up)
        answers = _prompt_for_responses(open_requests)

    if outcome:
        print(f"Workflow completed with: {outcome.data}")
async def main() -> None:
    """Run the sample once, then offer to resume from one of its checkpoints."""
    # Remove checkpoint files left behind by earlier executions. Because of
    # this, the checkpoints offered for resume below are always the ones
    # written during this same run.
    for stale in TEMP_DIR.glob("*.json"):
        stale.unlink()

    storage = FileCheckpointStorage(storage_path=TEMP_DIR)
    workflow = create_workflow(checkpoint_storage=storage)

    brief = (
        "Introduce our limited edition smart coffee grinder. Mention the $249 price, highlight the "
        "sensor that auto-adjusts the grind, and invite customers to pre-order on the website."
    )

    print("Running workflow (human approval required)...")
    outcome = await run_interactive_session(workflow, initial_message=brief)
    if not outcome:
        print("Initial run paused for human input.")
    else:
        print(f"Initial run completed with final copy: {outcome.data}")

    checkpoints = await storage.list_checkpoints()
    if not checkpoints:
        print("No checkpoints recorded.")
        return

    # Show what is available before asking the user to pick an index.
    _render_checkpoint_summary(checkpoints)

    oldest_first = sorted(checkpoints, key=lambda c: c.timestamp)
    print("\nAvailable checkpoints:")
    for position, checkpoint in enumerate(oldest_first):
        print(f" [{position}] id={checkpoint.checkpoint_id} iter={checkpoint.iteration_count}")

    # For the pause/resume demo the interesting choice is usually the latest
    # checkpoint whose summary status reads "awaiting human response".
    choice = input("\nResume from which checkpoint? (press Enter to skip): ").strip()  # noqa: ASYNC250
    if not choice:
        print("No resume selected. Exiting.")
        return

    try:
        index = int(choice)
    except ValueError:
        print("Invalid input; exiting.")
        return

    if index < 0 or index >= len(oldest_first):
        print("Index out of range; exiting.")
        return

    chosen = oldest_first[index]
    if RequestInfoExecutor.checkpoint_summary(chosen).status == "completed":
        print("Selected checkpoint already reflects a completed workflow; nothing to resume.")
        return

    # Collect the user's decisions up front, if they want, so the resume call
    # can apply them instead of prompting again.
    pre_responses = _maybe_pre_supply_responses(chosen)

    # Resume on a newly built workflow (created without checkpoint storage);
    # the storage is handed to the resume call itself.
    fresh_workflow = create_workflow()
    await resume_from_checkpoint(fresh_workflow, chosen.checkpoint_id, storage, pre_responses)


if __name__ == "__main__":
    asyncio.run(main())
@@ -3,7 +3,7 @@
import asyncio
import os
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any
from agent_framework import (
AgentExecutor,
@@ -12,6 +12,7 @@ from agent_framework import (
ChatMessage,
Executor,
FileCheckpointStorage,
RequestInfoExecutor,
Role,
WorkflowBuilder,
WorkflowCompletedEvent,
@@ -21,6 +22,10 @@ from agent_framework import (
from agent_framework.azure import AzureChatClient
from azure.identity import AzureCliCredential
if TYPE_CHECKING:
from agent_framework import Workflow
from agent_framework._workflow._checkpoint import WorkflowCheckpoint
"""
Sample: Checkpointing and Resuming a Workflow (with an Agent stage)
@@ -87,7 +92,7 @@ class UpperCaseExecutor(Executor):
class SubmitToLowerAgent(Executor):
"""Builds an AgentExecutorRequest to send to the lowercasing agent while keeping shared-state visibility."""
def __init__(self, agent_id: str, id: str | None = None):
def __init__(self, id: str, agent_id: str):
super().__init__(id=id)
self._agent_id = agent_id
@@ -132,10 +137,6 @@ class FinalizeFromAgent(Executor):
class ReverseTextExecutor(Executor):
"""Reverses the input text and persists local state."""
def __init__(self, id: str):
"""Initialize the executor with an ID."""
super().__init__(id=id)
@handler
async def reverse_text(self, text: str, ctx: WorkflowContext[str]) -> None:
result = text[::-1]
@@ -154,15 +155,10 @@ class ReverseTextExecutor(Executor):
await ctx.send_message(result)
async def main():
# Clear existing checkpoints in this sample directory for a clean run.
checkpoint_dir = Path(TEMP_DIR)
for file in checkpoint_dir.glob("*.json"):
file.unlink()
def create_workflow(checkpoint_storage: FileCheckpointStorage) -> "Workflow":
# Instantiate the pipeline executors.
upper_case_executor = UpperCaseExecutor(id="upper_case_executor")
reverse_text_executor = ReverseTextExecutor(id="reverse_text_executor")
upper_case_executor = UpperCaseExecutor(id="upper-case")
reverse_text_executor = ReverseTextExecutor(id="reverse-text")
# Configure the agent stage that lowercases the text.
chat_client = AzureChatClient(credential=AzureCliCredential())
@@ -174,14 +170,11 @@ async def main():
)
# Bridge to the agent and terminalization stage.
submit_lower = SubmitToLowerAgent(agent_id=lower_agent.id, id="submit_lower")
submit_lower = SubmitToLowerAgent(id="submit_lower", agent_id=lower_agent.id)
finalize = FinalizeFromAgent(id="finalize")
# Backing store for checkpoints written by with_checkpointing.
checkpoint_storage = FileCheckpointStorage(storage_path=TEMP_DIR)
# Build the workflow with checkpointing enabled.
workflow = (
return (
WorkflowBuilder(max_iterations=5)
.add_edge(upper_case_executor, reverse_text_executor) # Uppercase -> Reverse
.add_edge(reverse_text_executor, submit_lower) # Reverse -> Build Agent request
@@ -192,6 +185,40 @@ async def main():
.build()
)
def _render_checkpoint_summary(checkpoints: list["WorkflowCheckpoint"]) -> None:
    """Print one summary line per checkpoint, in ascending timestamp order.

    Each line combines the framework summary returned by
    ``RequestInfoExecutor.checkpoint_summary`` (checkpoint id, iteration count
    and, when set, status) with the total number of messages stored on the
    checkpoint, the sorted executor-state keys, and the ``original_input`` /
    ``upper_output`` entries read from shared state.

    Nothing is printed when ``checkpoints`` is empty.
    """
    if checkpoints:
        print("\nCheckpoint summary:")
        for checkpoint in sorted(checkpoints, key=lambda c: c.timestamp):
            info = RequestInfoExecutor.checkpoint_summary(checkpoint)

            # Total messages across every entry of the checkpoint's message map.
            total_messages = 0
            for queued in checkpoint.messages.values():
                total_messages += len(queued)

            executor_keys = sorted(checkpoint.executor_states.keys())
            original_input = checkpoint.shared_state.get("original_input")
            upper_output = checkpoint.shared_state.get("upper_output")

            head = (
                f"- {info.checkpoint_id} | iter={info.iteration_count} | messages={total_messages} | states={executor_keys}"
            )
            # The status segment is included only when the summary reports one.
            status_part = f" | status={info.status}" if info.status else ""
            tail = f" | shared_state: original_input='{original_input}', upper_output='{upper_output}'"
            print(head + status_part + tail)
async def main():
# Clear existing checkpoints in this sample directory for a clean run.
checkpoint_dir = Path(TEMP_DIR)
for file in checkpoint_dir.glob("*.json"):
file.unlink()
# Backing store for checkpoints written by with_checkpointing.
checkpoint_storage = FileCheckpointStorage(storage_path=TEMP_DIR)
workflow = create_workflow(checkpoint_storage=checkpoint_storage)
# Run the full workflow once and observe events as they stream.
print("Running workflow with initial message...")
async for event in workflow.run_stream(message="hello world"):
@@ -206,26 +233,20 @@ async def main():
# All checkpoints created by this run share the same workflow_id.
workflow_id = all_checkpoints[0].workflow_id
# Dump a quick summary including shared_state keys to illustrate what persisted.
print("\nCheckpoint summary:")
for cp in sorted(all_checkpoints, key=lambda c: c.timestamp):
msg_count = sum(len(v) for v in cp.messages.values())
state_keys = sorted(list(cp.executor_states.keys())) if hasattr(cp, "executor_states") else []
orig = cp.shared_state.get("original_input") if hasattr(cp, "shared_state") else None
upper = cp.shared_state.get("upper_output") if hasattr(cp, "shared_state") else None
print(
f"- {cp.checkpoint_id} | "
f"iter={cp.iteration_count} | messages={msg_count} | states={state_keys} | "
f"shared_state: original_input='{orig}', upper_output='{upper}'"
)
_render_checkpoint_summary(all_checkpoints)
# Offer an interactive selection of checkpoints to resume from.
sorted_cps = sorted([cp for cp in all_checkpoints if cp.workflow_id == workflow_id], key=lambda c: c.timestamp)
print("\nAvailable checkpoints to resume from:")
for idx, cp in enumerate(sorted_cps):
summary = RequestInfoExecutor.checkpoint_summary(cp)
line = f" [{idx}] id={summary.checkpoint_id} iter={summary.iteration_count}"
if summary.status:
line += f" status={summary.status}"
msg_count = sum(len(v) for v in cp.messages.values())
print(f" [{idx}] id={cp.checkpoint_id} iter={cp.iteration_count} messages={msg_count}")
line += f" messages={msg_count}"
print(line)
user_input = input(
"\nEnter checkpoint index (or paste checkpoint id) to resume from, or press Enter to skip resume: "
@@ -256,15 +277,7 @@ async def main():
# You can reuse the same workflow graph definition and resume from a prior checkpoint.
# This second workflow instance is built by the same create_workflow factory, so it carries
# the same executor IDs as the checkpointed run; resumption reads state from checkpoint_storage.
new_workflow = (
WorkflowBuilder(max_iterations=5)
.add_edge(upper_case_executor, reverse_text_executor)
.add_edge(reverse_text_executor, submit_lower)
.add_edge(submit_lower, lower_agent)
.add_edge(lower_agent, finalize)
.set_start_executor(upper_case_executor)
.build()
)
new_workflow = create_workflow(checkpoint_storage=checkpoint_storage)
print(f"\nResuming from checkpoint: {chosen_cp_id}")
async for event in new_workflow.run_stream_from_checkpoint(chosen_cp_id, checkpoint_storage=checkpoint_storage):
@@ -275,15 +288,15 @@ async def main():
Running workflow with initial message...
UpperCaseExecutor: 'hello world' -> 'HELLO WORLD'
Event: ExecutorInvokedEvent(executor_id=upper_case_executor)
Event: ExecutorInvokeEvent(executor_id=upper_case_executor)
Event: ExecutorCompletedEvent(executor_id=upper_case_executor)
ReverseTextExecutor: 'HELLO WORLD' -> 'DLROW OLLEH'
Event: ExecutorInvokedEvent(executor_id=reverse_text_executor)
Event: ExecutorInvokeEvent(executor_id=reverse_text_executor)
Event: ExecutorCompletedEvent(executor_id=reverse_text_executor)
LowerAgent (shared_state): original_input='hello world', upper_output='HELLO WORLD'
Event: ExecutorInvokedEvent(executor_id=submit_lower)
Event: ExecutorInvokedEvent(executor_id=lower_agent)
Event: ExecutorInvokedEvent(executor_id=finalize)
Event: ExecutorInvokeEvent(executor_id=submit_lower)
Event: ExecutorInvokeEvent(executor_id=lower_agent)
Event: ExecutorInvokeEvent(executor_id=finalize)
Event: WorkflowCompletedEvent(data=dlrow olleh)
Checkpoint summary:
@@ -300,9 +313,9 @@ async def main():
Resuming from checkpoint: a78c345a-e5d9-45ba-82c0-cb725452d91b
LowerAgent (shared_state): original_input='hello world', upper_output='HELLO WORLD'
Resumed Event: ExecutorInvokedEvent(executor_id=submit_lower)
Resumed Event: ExecutorInvokedEvent(executor_id=lower_agent)
Resumed Event: ExecutorInvokedEvent(executor_id=finalize)
Resumed Event: ExecutorInvokeEvent(executor_id=submit_lower)
Resumed Event: ExecutorInvokeEvent(executor_id=lower_agent)
Resumed Event: ExecutorInvokeEvent(executor_id=finalize)
Resumed Event: WorkflowCompletedEvent(data=dlrow olleh)
""" # noqa: E501
@@ -50,7 +50,7 @@ class GuessNumberExecutor(Executor):
def __init__(self, bound: tuple[int, int], id: str | None = None):
"""Initialize the executor with a target number."""
super().__init__(id=id)
super().__init__(id=id or "guess_number")
self._lower = bound[0]
self._upper = bound[1]
@@ -83,7 +83,7 @@ class SubmitToJudgeAgent(Executor):
"""Send the numeric guess to a judge agent which replies ABOVE/BELOW/MATCHED."""
def __init__(self, judge_agent_id: str, target: int, id: str | None = None):
super().__init__(id=id)
super().__init__(id=id or "submit_to_judge")
self._judge_agent_id = judge_agent_id
self._target = target
@@ -87,7 +87,7 @@ class TurnManager(Executor):
"""
def __init__(self, id: str | None = None):
super().__init__(id=id)
super().__init__(id=id or "turn_manager")
@handler
async def start(self, _: str, ctx: WorkflowContext[AgentExecutorRequest]) -> None:
@@ -43,7 +43,7 @@ class DispatchToExperts(Executor):
"""Dispatches the incoming prompt to all expert agent executors for parallel processing (fan out)."""
def __init__(self, expert_ids: list[str], id: str | None = None):
super().__init__(id)
super().__init__(id=id or "dispatch_to_experts")
self._expert_ids = expert_ids
@handler
@@ -71,7 +71,7 @@ class AggregateInsights(Executor):
"""Aggregates expert agent responses into a single consolidated result (fan in)."""
def __init__(self, expert_ids: list[str], id: str | None = None):
super().__init__(id)
super().__init__(id=id or "aggregate_insights")
self._expert_ids = expert_ids
@handler
@@ -61,7 +61,7 @@ class Split(Executor):
def __init__(self, map_executor_ids: list[str], id: str | None = None):
"""Store mapper ids so we can assign non overlapping ranges per mapper."""
super().__init__(id)
super().__init__(id=id or "split")
self._map_executor_ids = map_executor_ids
@handler
@@ -145,7 +145,7 @@ class Shuffle(Executor):
def __init__(self, reducer_ids: list[str], id: str | None = None):
"""Remember reducer ids so we can partition work deterministically."""
super().__init__(id)
super().__init__(id=id or "shuffle")
self._reducer_ids = reducer_ids
@handler
@@ -40,7 +40,7 @@ class DispatchToExperts(Executor):
"""Dispatches the incoming prompt to all expert agent executors (fan-out)."""
def __init__(self, expert_ids: list[str], id: str | None = None):
super().__init__(id)
super().__init__(id=id or "dispatch_to_experts")
self._expert_ids = expert_ids
@handler
@@ -67,7 +67,7 @@ class AggregateInsights(Executor):
"""Aggregates expert agent responses into a single consolidated result (fan-in)."""
def __init__(self, expert_ids: list[str], id: str | None = None):
super().__init__(id)
super().__init__(id=id or "aggregate_insights")
self._expert_ids = expert_ids
@handler