mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: [BREAKING] Python: Make executor ID required, improvements around handling rehydrating checkpoints (#832)
* Make executor ID required, improvements around handling rehydrating checkpoints. * Duplicate executor validation added * fix remaining issues --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
7cd45e313b
commit
aba094b5cf
@@ -91,6 +91,7 @@ from ._shared_state import SharedState
|
||||
from ._telemetry import EdgeGroupDeliveryStatus, WorkflowTracer, workflow_tracer
|
||||
from ._validation import (
|
||||
EdgeDuplicationError,
|
||||
ExecutorDuplicationError,
|
||||
GraphConnectivityError,
|
||||
HandlerOutputAnnotationError,
|
||||
TypeCompatibilityError,
|
||||
@@ -118,6 +119,7 @@ __all__ = [
|
||||
"EdgeGroupDeliveryStatus",
|
||||
"Executor",
|
||||
"ExecutorCompletedEvent",
|
||||
"ExecutorDuplicationError",
|
||||
"ExecutorEvent",
|
||||
"ExecutorFailedEvent",
|
||||
"ExecutorInvokedEvent",
|
||||
|
||||
@@ -87,6 +87,7 @@ from ._shared_state import SharedState
|
||||
from ._telemetry import EdgeGroupDeliveryStatus, WorkflowTracer, workflow_tracer
|
||||
from ._validation import (
|
||||
EdgeDuplicationError,
|
||||
ExecutorDuplicationError,
|
||||
GraphConnectivityError,
|
||||
HandlerOutputAnnotationError,
|
||||
TypeCompatibilityError,
|
||||
@@ -114,6 +115,7 @@ __all__ = [
|
||||
"EdgeGroupDeliveryStatus",
|
||||
"Executor",
|
||||
"ExecutorCompletedEvent",
|
||||
"ExecutorDuplicationError",
|
||||
"ExecutorEvent",
|
||||
"ExecutorFailedEvent",
|
||||
"ExecutorInvokedEvent",
|
||||
|
||||
@@ -145,7 +145,10 @@ class _CallbackAggregator(Executor):
|
||||
"""
|
||||
|
||||
def __init__(self, callback: Callable[..., Any], id: str | None = None) -> None:
|
||||
super().__init__(id)
|
||||
derived_id = getattr(callback, "__name__", "") or ""
|
||||
if not derived_id or derived_id == "<lambda>":
|
||||
derived_id = f"{type(self).__name__}_unnamed"
|
||||
super().__init__(id or derived_id)
|
||||
self._callback = callback
|
||||
self._param_count = len(inspect.signature(callback).parameters)
|
||||
|
||||
|
||||
@@ -2,12 +2,15 @@
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
|
||||
from dataclasses import asdict, dataclass, field, fields, is_dataclass
|
||||
from textwrap import shorten
|
||||
from types import UnionType
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, get_args, get_origin, overload
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, Union, cast, get_args, get_origin, overload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._workflow import Workflow
|
||||
@@ -17,6 +20,7 @@ from pydantic import Field
|
||||
from agent_framework import AgentProtocol, AgentRunResponse, AgentRunResponseUpdate, AgentThread, ChatMessage
|
||||
from agent_framework._pydantic import AFBaseModel
|
||||
|
||||
from ._checkpoint import WorkflowCheckpoint
|
||||
from ._events import (
|
||||
AgentRunEvent,
|
||||
AgentRunUpdateEvent,
|
||||
@@ -25,35 +29,63 @@ from ._events import (
|
||||
RequestInfoEvent,
|
||||
_framework_event_origin, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
from ._runner_context import _decode_checkpoint_value
|
||||
from ._typing_utils import is_instance_of
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# region Executor
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingRequestDetails:
|
||||
"""Lightweight information about a pending request captured in a checkpoint."""
|
||||
|
||||
request_id: str
|
||||
prompt: str | None = None
|
||||
draft: str | None = None
|
||||
iteration: int | None = None
|
||||
source_executor_id: str | None = None
|
||||
original_request: "RequestInfoMessage | dict[str, Any] | None" = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkflowCheckpointSummary:
|
||||
"""Human-readable summary of a workflow checkpoint."""
|
||||
|
||||
checkpoint_id: str
|
||||
iteration_count: int
|
||||
targets: list[str]
|
||||
executor_states: list[str]
|
||||
status: str
|
||||
draft_preview: str | None
|
||||
pending_requests: list[PendingRequestDetails]
|
||||
|
||||
|
||||
class Executor(AFBaseModel):
|
||||
"""An executor is a component that processes messages in a workflow."""
|
||||
|
||||
# Provide a default so static analyzers (e.g., pyright) don't require passing `id`.
|
||||
# Runtime still sets a concrete value in __init__.
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
...,
|
||||
min_length=1,
|
||||
description="Unique identifier for the executor",
|
||||
)
|
||||
type_: str = Field(default="", alias="type", description="The type of executor, corresponding to the class name")
|
||||
|
||||
def __init__(self, id: str | None = None, **kwargs: Any) -> None:
|
||||
def __init__(self, id: str, **kwargs: Any) -> None:
|
||||
"""Initialize the executor with a unique identifier.
|
||||
|
||||
Args:
|
||||
id: A unique identifier for the executor. If None, a new ID will be generated
|
||||
following the format <class_name>/<uuid>.
|
||||
id: A unique identifier for the executor.
|
||||
kwargs: Additional keyword arguments. Unused in this implementation.
|
||||
"""
|
||||
executor_id = f"{self.__class__.__name__}/{uuid.uuid4()}" if id is None else id
|
||||
if not id:
|
||||
raise ValueError("Executor ID must be a non-empty string.")
|
||||
|
||||
kwargs.update({"id": executor_id})
|
||||
kwargs.update({"id": id})
|
||||
if "type" not in kwargs and "type_" not in kwargs:
|
||||
kwargs["type_"] = self.__class__.__name__
|
||||
|
||||
@@ -677,17 +709,15 @@ class RequestInfoExecutor(Executor):
|
||||
a response is provided externally, it emits the response as a message.
|
||||
"""
|
||||
|
||||
def __init__(self, id: str | None = None):
|
||||
"""Initialize the RequestInfoExecutor with an optional custom ID.
|
||||
_PENDING_SHARED_STATE_KEY: ClassVar[str] = "_af_pending_request_info"
|
||||
|
||||
def __init__(self, id: str):
|
||||
"""Initialize the RequestInfoExecutor with a unique ID.
|
||||
|
||||
Args:
|
||||
id: Optional custom ID for this RequestInfoExecutor. If not provided,
|
||||
a unique ID will be generated.
|
||||
id: Unique ID for this RequestInfoExecutor.
|
||||
"""
|
||||
import uuid
|
||||
|
||||
executor_id = id or f"request_info_{uuid.uuid4().hex[:8]}"
|
||||
super().__init__(id=executor_id)
|
||||
super().__init__(id=id)
|
||||
self._request_events: dict[str, RequestInfoEvent] = {}
|
||||
self._sub_workflow_contexts: dict[str, dict[str, str]] = {}
|
||||
|
||||
@@ -703,6 +733,7 @@ class RequestInfoExecutor(Executor):
|
||||
request_data=message,
|
||||
)
|
||||
self._request_events[message.request_id] = event
|
||||
await self._record_pending_request_snapshot(message, source_executor_id, ctx)
|
||||
await ctx.add_event(event)
|
||||
|
||||
@handler
|
||||
@@ -748,10 +779,13 @@ class RequestInfoExecutor(Executor):
|
||||
response_data: The data returned in the response.
|
||||
ctx: The workflow context for sending the response.
|
||||
"""
|
||||
if request_id not in self._request_events:
|
||||
event = self._request_events.get(request_id)
|
||||
if event is None:
|
||||
event = await self._rehydrate_request_event(request_id, ctx)
|
||||
if event is None:
|
||||
raise ValueError(f"No request found with ID: {request_id}")
|
||||
|
||||
event = self._request_events.pop(request_id)
|
||||
self._request_events.pop(request_id, None)
|
||||
|
||||
# Check if this was a forwarded sub-workflow request
|
||||
if request_id in self._sub_workflow_contexts:
|
||||
@@ -779,6 +813,472 @@ class RequestInfoExecutor(Executor):
|
||||
|
||||
await ctx.send_message(correlated_response, target_id=event.source_executor_id)
|
||||
|
||||
await self._clear_pending_request_snapshot(request_id, ctx)
|
||||
|
||||
async def _record_pending_request_snapshot(
|
||||
self,
|
||||
request: RequestInfoMessage,
|
||||
source_executor_id: str,
|
||||
ctx: WorkflowContext[Any],
|
||||
) -> None:
|
||||
snapshot = self._build_request_snapshot(request, source_executor_id)
|
||||
|
||||
pending = await self._load_pending_request_state(ctx)
|
||||
pending[request.request_id] = snapshot
|
||||
await self._persist_pending_request_state(pending, ctx)
|
||||
|
||||
async def _clear_pending_request_snapshot(self, request_id: str, ctx: WorkflowContext[Any]) -> None:
|
||||
pending = await self._load_pending_request_state(ctx)
|
||||
if request_id not in pending:
|
||||
return
|
||||
|
||||
pending.pop(request_id, None)
|
||||
await self._persist_pending_request_state(pending, ctx)
|
||||
|
||||
async def _load_pending_request_state(self, ctx: WorkflowContext[Any]) -> dict[str, Any]:
|
||||
try:
|
||||
existing = await ctx.get_shared_state(self._PENDING_SHARED_STATE_KEY)
|
||||
except Exception as exc: # pragma: no cover - transport specific
|
||||
logger.warning(f"RequestInfoExecutor {self.id} failed to read pending request state: {exc}")
|
||||
return {}
|
||||
|
||||
if not isinstance(existing, dict) or existing is None:
|
||||
if existing not in (None, {}):
|
||||
logger.warning(
|
||||
f"RequestInfoExecutor {self.id} encountered non-dict pending state "
|
||||
f"({type(existing).__name__}); resetting."
|
||||
)
|
||||
return {}
|
||||
|
||||
return dict(existing)
|
||||
|
||||
async def _persist_pending_request_state(self, pending: dict[str, Any], ctx: WorkflowContext[Any]) -> None:
|
||||
await self._safe_set_shared_state(ctx, pending)
|
||||
await self._safe_set_runner_state(ctx, pending)
|
||||
|
||||
async def _safe_set_shared_state(self, ctx: WorkflowContext[Any], pending: dict[str, Any]) -> None:
|
||||
try:
|
||||
await ctx.set_shared_state(self._PENDING_SHARED_STATE_KEY, pending)
|
||||
except Exception as exc: # pragma: no cover - transport specific
|
||||
logger.warning(f"RequestInfoExecutor {self.id} failed to update shared pending state: {exc}")
|
||||
|
||||
async def _safe_set_runner_state(self, ctx: WorkflowContext[Any], pending: dict[str, Any]) -> None:
|
||||
try:
|
||||
await ctx.set_state({"pending_requests": pending})
|
||||
except Exception as exc: # pragma: no cover - transport specific
|
||||
logger.warning(f"RequestInfoExecutor {self.id} failed to update runner state with pending requests: {exc}")
|
||||
|
||||
def _build_request_snapshot(
|
||||
self,
|
||||
request: RequestInfoMessage,
|
||||
source_executor_id: str,
|
||||
) -> dict[str, Any]:
|
||||
snapshot: dict[str, Any] = {
|
||||
"request_id": request.request_id,
|
||||
"source_executor_id": source_executor_id,
|
||||
"request_type": f"{type(request).__module__}:{type(request).__name__}",
|
||||
"summary": repr(request),
|
||||
}
|
||||
|
||||
details = self._serialise_request_details(request)
|
||||
if details:
|
||||
snapshot["details"] = details
|
||||
for key in ("prompt", "draft", "iteration"):
|
||||
if key in details and key not in snapshot:
|
||||
snapshot[key] = details[key]
|
||||
|
||||
return snapshot
|
||||
|
||||
def _serialise_request_details(self, request: RequestInfoMessage) -> dict[str, Any] | None:
|
||||
if is_dataclass(request):
|
||||
data = self._make_json_safe(asdict(request))
|
||||
if isinstance(data, dict):
|
||||
return cast(dict[str, Any], data)
|
||||
return None
|
||||
|
||||
model_dump = getattr(request, "model_dump", None)
|
||||
if callable(model_dump):
|
||||
try:
|
||||
dump = self._make_json_safe(model_dump(mode="json"))
|
||||
except TypeError:
|
||||
dump = self._make_json_safe(model_dump())
|
||||
if isinstance(dump, dict):
|
||||
return cast(dict[str, Any], dump)
|
||||
return None
|
||||
|
||||
attrs = getattr(request, "__dict__", None)
|
||||
if isinstance(attrs, dict):
|
||||
cleaned = self._make_json_safe(attrs)
|
||||
if isinstance(cleaned, dict):
|
||||
return cast(dict[str, Any], cleaned)
|
||||
|
||||
return None
|
||||
|
||||
def _make_json_safe(self, value: Any) -> Any:
|
||||
if value is None or isinstance(value, (str, int, float, bool)):
|
||||
return value
|
||||
if isinstance(value, Mapping):
|
||||
safe_dict: dict[str, Any] = {}
|
||||
for key, val in value.items():
|
||||
safe_dict[str(key)] = self._make_json_safe(val)
|
||||
return safe_dict
|
||||
if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
|
||||
return [self._make_json_safe(item) for item in value]
|
||||
return repr(value)
|
||||
|
||||
async def has_pending_request(self, request_id: str, ctx: WorkflowContext[Any]) -> bool:
|
||||
if request_id in self._request_events or request_id in self._sub_workflow_contexts:
|
||||
return True
|
||||
snapshot = await self._get_pending_request_snapshot(request_id, ctx)
|
||||
return snapshot is not None
|
||||
|
||||
async def _rehydrate_request_event(
|
||||
self,
|
||||
request_id: str,
|
||||
ctx: WorkflowContext[Any],
|
||||
) -> RequestInfoEvent | None:
|
||||
snapshot = await self._get_pending_request_snapshot(request_id, ctx)
|
||||
if snapshot is None:
|
||||
return None
|
||||
|
||||
source_executor_id = snapshot.get("source_executor_id")
|
||||
if not isinstance(source_executor_id, str) or not source_executor_id:
|
||||
return None
|
||||
|
||||
request = self._construct_request_from_snapshot(snapshot)
|
||||
if request is None:
|
||||
return None
|
||||
|
||||
event = RequestInfoEvent(
|
||||
request_id=request_id,
|
||||
source_executor_id=source_executor_id,
|
||||
request_type=type(request),
|
||||
request_data=request,
|
||||
)
|
||||
self._request_events[request_id] = event
|
||||
return event
|
||||
|
||||
async def _get_pending_request_snapshot(self, request_id: str, ctx: WorkflowContext[Any]) -> dict[str, Any] | None:
|
||||
pending = await self._collect_pending_request_snapshots(ctx)
|
||||
snapshot = pending.get(request_id)
|
||||
if snapshot is None:
|
||||
return None
|
||||
return snapshot
|
||||
|
||||
async def _collect_pending_request_snapshots(self, ctx: WorkflowContext[Any]) -> dict[str, dict[str, Any]]:
|
||||
combined: dict[str, dict[str, Any]] = {}
|
||||
|
||||
try:
|
||||
shared_pending = await ctx.get_shared_state(self._PENDING_SHARED_STATE_KEY)
|
||||
except Exception as exc: # pragma: no cover - transport specific
|
||||
logger.warning(f"RequestInfoExecutor {self.id} failed to read shared pending state during rehydrate: {exc}")
|
||||
shared_pending = None
|
||||
|
||||
if isinstance(shared_pending, dict):
|
||||
for key, value in shared_pending.items():
|
||||
if isinstance(key, str) and isinstance(value, dict):
|
||||
combined[key] = cast(dict[str, Any], value)
|
||||
|
||||
try:
|
||||
state = await ctx.get_state()
|
||||
except Exception as exc: # pragma: no cover - transport specific
|
||||
logger.warning(f"RequestInfoExecutor {self.id} failed to read runner state during rehydrate: {exc}")
|
||||
state = None
|
||||
|
||||
if isinstance(state, dict):
|
||||
state_pending = state.get("pending_requests")
|
||||
if isinstance(state_pending, dict):
|
||||
for key, value in state_pending.items():
|
||||
if isinstance(key, str) and isinstance(value, dict) and key not in combined:
|
||||
combined[key] = cast(dict[str, Any], value)
|
||||
|
||||
return combined
|
||||
|
||||
def _construct_request_from_snapshot(self, snapshot: dict[str, Any]) -> RequestInfoMessage | None:
|
||||
details_raw = snapshot.get("details")
|
||||
details: dict[str, Any] = cast(dict[str, Any], details_raw) if isinstance(details_raw, dict) else {}
|
||||
|
||||
request_cls: type[RequestInfoMessage] = RequestInfoMessage
|
||||
request_type_str = snapshot.get("request_type")
|
||||
if isinstance(request_type_str, str) and ":" in request_type_str:
|
||||
module_name, class_name = request_type_str.split(":", 1)
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
candidate = getattr(module, class_name)
|
||||
if isinstance(candidate, type) and issubclass(candidate, RequestInfoMessage):
|
||||
request_cls = candidate
|
||||
except Exception as exc:
|
||||
logger.warning(f"RequestInfoExecutor {self.id} could not import {module_name}.{class_name}: {exc}")
|
||||
request_cls = RequestInfoMessage
|
||||
|
||||
request: RequestInfoMessage | None = self._instantiate_request(request_cls, details)
|
||||
|
||||
if request is None and request_cls is not RequestInfoMessage:
|
||||
request = self._instantiate_request(RequestInfoMessage, details)
|
||||
|
||||
if request is None:
|
||||
logger.warning(
|
||||
f"RequestInfoExecutor {self.id} could not reconstruct request "
|
||||
f"{request_type_str or RequestInfoMessage.__name__} from snapshot keys {sorted(details.keys())}"
|
||||
)
|
||||
return None
|
||||
|
||||
for key, value in details.items():
|
||||
if key == "request_id":
|
||||
continue
|
||||
try:
|
||||
setattr(request, key, value)
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
f"RequestInfoExecutor {self.id} could not set attribute {key} on {type(request).__name__}: {exc}"
|
||||
)
|
||||
continue
|
||||
|
||||
snapshot_request_id = snapshot.get("request_id")
|
||||
if isinstance(snapshot_request_id, str) and snapshot_request_id:
|
||||
try:
|
||||
request.request_id = snapshot_request_id
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
f"RequestInfoExecutor {self.id} could not apply snapshot "
|
||||
f"request_id to {type(request).__name__}: {exc}"
|
||||
)
|
||||
|
||||
return request
|
||||
|
||||
def _instantiate_request(
|
||||
self,
|
||||
request_cls: type[RequestInfoMessage],
|
||||
details: dict[str, Any],
|
||||
) -> RequestInfoMessage | None:
|
||||
try:
|
||||
model_validate = getattr(request_cls, "model_validate", None)
|
||||
if callable(model_validate):
|
||||
return cast(RequestInfoMessage, model_validate(details))
|
||||
except (TypeError, ValueError) as exc:
|
||||
logger.debug(
|
||||
f"RequestInfoExecutor {self.id} validation failed for {request_cls.__name__} via model_validate: {exc}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
f"RequestInfoExecutor {self.id} encountered unexpected error during "
|
||||
f"{request_cls.__name__}.model_validate: {exc}"
|
||||
)
|
||||
|
||||
if is_dataclass(request_cls):
|
||||
try:
|
||||
field_names = {f.name for f in fields(request_cls)}
|
||||
ctor_kwargs = {name: details[name] for name in field_names if name in details}
|
||||
return request_cls(**ctor_kwargs) # type: ignore[call-arg]
|
||||
except (TypeError, ValueError) as exc:
|
||||
logger.debug(
|
||||
f"RequestInfoExecutor {self.id} could not instantiate dataclass "
|
||||
f"{request_cls.__name__} with snapshot data: {exc}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
f"RequestInfoExecutor {self.id} encountered unexpected error "
|
||||
f"constructing dataclass {request_cls.__name__}: {exc}"
|
||||
)
|
||||
|
||||
try:
|
||||
instance = request_cls() # type: ignore[call-arg]
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
f"RequestInfoExecutor {self.id} could not instantiate {request_cls.__name__} without arguments: {exc}"
|
||||
)
|
||||
return None
|
||||
|
||||
for key, value in details.items():
|
||||
if key == "request_id":
|
||||
continue
|
||||
try:
|
||||
setattr(instance, key, value)
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
f"RequestInfoExecutor {self.id} could not set attribute {key} on "
|
||||
f"{request_cls.__name__} during instantiation: {exc}"
|
||||
)
|
||||
continue
|
||||
|
||||
return instance
|
||||
|
||||
@staticmethod
|
||||
def pending_requests_from_checkpoint(
|
||||
checkpoint: WorkflowCheckpoint,
|
||||
*,
|
||||
request_executor_ids: Iterable[str] | None = None,
|
||||
) -> list[PendingRequestDetails]:
|
||||
executor_filter: set[str] | None = None
|
||||
if request_executor_ids is not None:
|
||||
executor_filter = {str(value) for value in request_executor_ids}
|
||||
|
||||
pending: dict[str, PendingRequestDetails] = {}
|
||||
|
||||
shared_map = checkpoint.shared_state.get(RequestInfoExecutor._PENDING_SHARED_STATE_KEY)
|
||||
if isinstance(shared_map, Mapping):
|
||||
for request_id, snapshot in shared_map.items():
|
||||
RequestInfoExecutor._merge_snapshot(pending, str(request_id), snapshot)
|
||||
|
||||
for state in checkpoint.executor_states.values():
|
||||
if not isinstance(state, Mapping):
|
||||
continue
|
||||
inner = state.get("pending_requests")
|
||||
if isinstance(inner, Mapping):
|
||||
for request_id, snapshot in inner.items():
|
||||
RequestInfoExecutor._merge_snapshot(pending, str(request_id), snapshot)
|
||||
|
||||
for source_id, message_list in checkpoint.messages.items():
|
||||
if executor_filter is not None and source_id not in executor_filter:
|
||||
continue
|
||||
if not isinstance(message_list, list):
|
||||
continue
|
||||
for message in message_list:
|
||||
if not isinstance(message, Mapping):
|
||||
continue
|
||||
payload = _decode_checkpoint_value(message.get("data"))
|
||||
RequestInfoExecutor._merge_message_payload(pending, payload, message)
|
||||
|
||||
return list(pending.values())
|
||||
|
||||
@staticmethod
|
||||
def checkpoint_summary(
|
||||
checkpoint: WorkflowCheckpoint,
|
||||
*,
|
||||
request_executor_ids: Iterable[str] | None = None,
|
||||
preview_width: int = 70,
|
||||
) -> WorkflowCheckpointSummary:
|
||||
targets = sorted(checkpoint.messages.keys())
|
||||
executor_states = sorted(checkpoint.executor_states.keys())
|
||||
pending = RequestInfoExecutor.pending_requests_from_checkpoint(
|
||||
checkpoint, request_executor_ids=request_executor_ids
|
||||
)
|
||||
|
||||
draft_preview: str | None = None
|
||||
for entry in pending:
|
||||
if entry.draft:
|
||||
draft_preview = shorten(entry.draft, width=preview_width, placeholder="…")
|
||||
break
|
||||
|
||||
status = "idle"
|
||||
if pending:
|
||||
status = "awaiting human response"
|
||||
elif not checkpoint.messages and "finalise" in executor_states:
|
||||
status = "completed"
|
||||
elif checkpoint.messages:
|
||||
status = "awaiting next superstep"
|
||||
elif request_executor_ids is not None and any(tid in targets for tid in request_executor_ids):
|
||||
status = "awaiting request delivery"
|
||||
|
||||
return WorkflowCheckpointSummary(
|
||||
checkpoint_id=checkpoint.checkpoint_id,
|
||||
iteration_count=checkpoint.iteration_count,
|
||||
targets=targets,
|
||||
executor_states=executor_states,
|
||||
status=status,
|
||||
draft_preview=draft_preview,
|
||||
pending_requests=pending,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _merge_snapshot(
|
||||
pending: dict[str, PendingRequestDetails],
|
||||
request_id: str,
|
||||
snapshot: Any,
|
||||
) -> None:
|
||||
if not request_id or not isinstance(snapshot, Mapping):
|
||||
return
|
||||
|
||||
details = pending.setdefault(request_id, PendingRequestDetails(request_id=request_id))
|
||||
|
||||
RequestInfoExecutor._apply_update(
|
||||
details,
|
||||
prompt=snapshot.get("prompt"),
|
||||
draft=snapshot.get("draft"),
|
||||
iteration=snapshot.get("iteration"),
|
||||
source_executor_id=snapshot.get("source_executor_id"),
|
||||
)
|
||||
|
||||
extra = snapshot.get("details")
|
||||
if isinstance(extra, Mapping):
|
||||
RequestInfoExecutor._apply_update(
|
||||
details,
|
||||
prompt=extra.get("prompt"),
|
||||
draft=extra.get("draft"),
|
||||
iteration=extra.get("iteration"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _merge_message_payload(
|
||||
pending: dict[str, PendingRequestDetails],
|
||||
payload: Any,
|
||||
raw_message: Mapping[str, Any],
|
||||
) -> None:
|
||||
if isinstance(payload, RequestResponse):
|
||||
request_id = payload.request_id or RequestInfoExecutor._get_field(payload.original_request, "request_id")
|
||||
if not request_id:
|
||||
return
|
||||
details = pending.setdefault(request_id, PendingRequestDetails(request_id=request_id))
|
||||
RequestInfoExecutor._apply_update(
|
||||
details,
|
||||
prompt=RequestInfoExecutor._get_field(payload.original_request, "prompt"),
|
||||
draft=RequestInfoExecutor._get_field(payload.original_request, "draft"),
|
||||
iteration=RequestInfoExecutor._get_field(payload.original_request, "iteration"),
|
||||
source_executor_id=raw_message.get("source_id"),
|
||||
original_request=payload.original_request,
|
||||
)
|
||||
elif isinstance(payload, RequestInfoMessage):
|
||||
request_id = getattr(payload, "request_id", None)
|
||||
if not request_id:
|
||||
return
|
||||
details = pending.setdefault(request_id, PendingRequestDetails(request_id=request_id))
|
||||
RequestInfoExecutor._apply_update(
|
||||
details,
|
||||
prompt=getattr(payload, "prompt", None),
|
||||
draft=getattr(payload, "draft", None),
|
||||
iteration=getattr(payload, "iteration", None),
|
||||
source_executor_id=raw_message.get("source_id"),
|
||||
original_request=payload,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _apply_update(
|
||||
details: PendingRequestDetails,
|
||||
*,
|
||||
prompt: Any = None,
|
||||
draft: Any = None,
|
||||
iteration: Any = None,
|
||||
source_executor_id: Any = None,
|
||||
original_request: Any = None,
|
||||
) -> None:
|
||||
if prompt and not details.prompt:
|
||||
details.prompt = str(prompt)
|
||||
if draft and not details.draft:
|
||||
details.draft = str(draft)
|
||||
if iteration is not None and details.iteration is None:
|
||||
coerced = RequestInfoExecutor._coerce_int(iteration)
|
||||
if coerced is not None:
|
||||
details.iteration = coerced
|
||||
if source_executor_id and not details.source_executor_id:
|
||||
details.source_executor_id = str(source_executor_id)
|
||||
if original_request is not None and details.original_request is None:
|
||||
details.original_request = original_request
|
||||
|
||||
@staticmethod
|
||||
def _get_field(obj: Any, key: str) -> Any:
|
||||
if obj is None:
|
||||
return None
|
||||
if isinstance(obj, Mapping):
|
||||
return obj.get(key)
|
||||
return getattr(obj, key, None)
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any) -> int | None:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
# endregion: Request Info Executor
|
||||
|
||||
@@ -840,7 +1340,11 @@ class AgentExecutor(Executor):
|
||||
exec_id = id
|
||||
else:
|
||||
agent_name = agent.name
|
||||
exec_id = str(agent_name) if agent_name else f"executor_{uuid.uuid4()}"
|
||||
if agent_name:
|
||||
exec_id = str(agent_name)
|
||||
else:
|
||||
logger.warning("Agent has no name, using fallback ID 'executor_unnamed'")
|
||||
exec_id = "executor_unnamed"
|
||||
super().__init__(exec_id)
|
||||
self._agent = agent
|
||||
self._agent_thread = agent_thread or self._agent.get_new_thread()
|
||||
@@ -952,12 +1456,12 @@ class WorkflowExecutor(Executor):
|
||||
|
||||
workflow: "Workflow" = Field(description="The workflow to execute as a sub-workflow")
|
||||
|
||||
def __init__(self, workflow: "Workflow", id: str | None = None, **kwargs: Any):
|
||||
def __init__(self, workflow: "Workflow", id: str, **kwargs: Any):
|
||||
"""Initialize the WorkflowExecutor.
|
||||
|
||||
Args:
|
||||
workflow: The workflow to execute as a sub-workflow.
|
||||
id: Optional unique identifier for this executor.
|
||||
id: Unique identifier for this executor.
|
||||
**kwargs: Additional keyword arguments passed to the parent constructor.
|
||||
"""
|
||||
kwargs.update({"workflow": workflow})
|
||||
|
||||
@@ -1635,7 +1635,7 @@ class MagenticBuilder:
|
||||
if self._enable_plan_review:
|
||||
from ._executor import RequestInfoExecutor
|
||||
|
||||
request_info = RequestInfoExecutor()
|
||||
request_info = RequestInfoExecutor(id="request_info")
|
||||
workflow_builder = (
|
||||
workflow_builder
|
||||
# Only route plan review asks to request_info
|
||||
|
||||
@@ -13,7 +13,14 @@ from ._edge import EdgeGroup
|
||||
from ._edge_runner import EdgeRunner, create_edge_runner
|
||||
from ._events import WorkflowCompletedEvent, WorkflowEvent, _framework_event_origin
|
||||
from ._executor import Executor
|
||||
from ._runner_context import Message, RunnerContext
|
||||
from ._runner_context import (
|
||||
_DATACLASS_MARKER,
|
||||
_PYDANTIC_MARKER,
|
||||
CheckpointState,
|
||||
Message,
|
||||
RunnerContext,
|
||||
_decode_checkpoint_value,
|
||||
)
|
||||
from ._shared_state import SharedState
|
||||
from ._typing_utils import is_instance_of
|
||||
from ._workflow_context import WorkflowContext
|
||||
@@ -53,6 +60,7 @@ class Runner:
|
||||
self._workflow_id = workflow_id
|
||||
self._running = False
|
||||
self._resumed_from_checkpoint = False # Track whether we resumed
|
||||
self.graph_signature_hash: str | None = None
|
||||
|
||||
# Set workflow ID in context if provided
|
||||
if workflow_id:
|
||||
@@ -244,6 +252,19 @@ class Runner:
|
||||
"""Inner loop to deliver a single message through an edge runner."""
|
||||
return await edge_runner.send_message(message, self._shared_state, self._ctx)
|
||||
|
||||
def _normalize_message_payload(message: Message) -> None:
|
||||
data = message.data
|
||||
if not isinstance(data, dict):
|
||||
return
|
||||
if _PYDANTIC_MARKER not in data and _DATACLASS_MARKER not in data:
|
||||
return
|
||||
try:
|
||||
decoded = _decode_checkpoint_value(data)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.debug("Failed to decode checkpoint payload during delivery: %s", exc)
|
||||
return
|
||||
message.data = decoded
|
||||
|
||||
# Handle SubWorkflowRequestInfo messages specially
|
||||
await _deliver_sub_workflow_requests(messages)
|
||||
|
||||
@@ -266,6 +287,7 @@ class Runner:
|
||||
|
||||
associated_edge_runners = self._edge_runner_map.get(source_executor_id, [])
|
||||
for message in non_sub_workflow_messages:
|
||||
_normalize_message_payload(message)
|
||||
# Deliver a message through all edge runners associated with the source executor concurrently.
|
||||
tasks = [_deliver_message_inner(edge_runner, message) for edge_runner in associated_edge_runners]
|
||||
if not tasks:
|
||||
@@ -332,6 +354,8 @@ class Runner:
|
||||
"superstep": self._iteration,
|
||||
"checkpoint_type": checkpoint_category,
|
||||
}
|
||||
if self.graph_signature_hash:
|
||||
metadata["graph_signature"] = self.graph_signature_hash
|
||||
checkpoint_id = await self._ctx.create_checkpoint(metadata=metadata)
|
||||
logger.info(f"Created {checkpoint_type} checkpoint: {checkpoint_id}")
|
||||
return checkpoint_id
|
||||
@@ -403,14 +427,45 @@ class Runner:
|
||||
return False
|
||||
|
||||
try:
|
||||
success = await self._ctx.restore_from_checkpoint(checkpoint_id)
|
||||
if not success:
|
||||
checkpoint = await self._ctx.load_checkpoint(checkpoint_id)
|
||||
if not checkpoint:
|
||||
logger.error(f"Checkpoint {checkpoint_id} not found")
|
||||
return False
|
||||
|
||||
graph_hash = getattr(self, "graph_signature_hash", None)
|
||||
checkpoint_hash = (checkpoint.metadata or {}).get("graph_signature")
|
||||
if graph_hash and checkpoint_hash and graph_hash != checkpoint_hash:
|
||||
raise ValueError(
|
||||
"Workflow graph has changed since the checkpoint was created. "
|
||||
"Please rebuild the original workflow before resuming."
|
||||
)
|
||||
if graph_hash and not checkpoint_hash:
|
||||
logger.warning(
|
||||
"Checkpoint %s does not include graph signature metadata; skipping topology validation.",
|
||||
checkpoint_id,
|
||||
)
|
||||
|
||||
state: CheckpointState = {
|
||||
"messages": checkpoint.messages,
|
||||
"shared_state": checkpoint.shared_state,
|
||||
"executor_states": checkpoint.executor_states,
|
||||
"iteration_count": checkpoint.iteration_count,
|
||||
"max_iterations": checkpoint.max_iterations,
|
||||
}
|
||||
await self._ctx.set_checkpoint_state(state)
|
||||
if checkpoint.workflow_id:
|
||||
self._ctx.set_workflow_id(checkpoint.workflow_id)
|
||||
self._workflow_id = checkpoint.workflow_id
|
||||
|
||||
await self._restore_shared_state_from_context()
|
||||
self.mark_resumed() # mark resumed; iteration/max already restored from context
|
||||
self.mark_resumed(
|
||||
iteration=checkpoint.iteration_count,
|
||||
max_iterations=checkpoint.max_iterations,
|
||||
)
|
||||
logger.info(f"Successfully restored workflow from checkpoint: {checkpoint_id}")
|
||||
return True
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to restore from checkpoint {checkpoint_id}: {e}")
|
||||
return False
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
import logging
|
||||
import sys
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, fields, is_dataclass
|
||||
@@ -60,6 +61,39 @@ _MAX_ENCODE_DEPTH = 100
|
||||
_CYCLE_SENTINEL = "<cycle>"
|
||||
|
||||
|
||||
def _instantiate_checkpoint_dataclass(cls: type[Any], payload: Any) -> Any | None:
|
||||
if not isinstance(cls, type):
|
||||
logger.debug(f"Checkpoint decoder received non-type dataclass reference: {cls!r}")
|
||||
return None
|
||||
|
||||
if isinstance(payload, dict):
|
||||
try:
|
||||
return cls(**payload) # type: ignore[arg-type]
|
||||
except TypeError as exc:
|
||||
logger.debug(f"Checkpoint decoder could not call {cls.__name__}(**payload): {exc}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"Checkpoint decoder encountered unexpected error calling {cls.__name__}(**payload): {exc}")
|
||||
try:
|
||||
instance = object.__new__(cls)
|
||||
except Exception as exc:
|
||||
logger.debug(f"Checkpoint decoder could not allocate {cls.__name__} without __init__: {exc}")
|
||||
return None
|
||||
for key, val in payload.items():
|
||||
try:
|
||||
setattr(instance, key, val)
|
||||
except Exception as exc:
|
||||
logger.debug(f"Checkpoint decoder could not set attribute {key} on {cls.__name__}: {exc}")
|
||||
return instance
|
||||
|
||||
try:
|
||||
return cls(payload) # type: ignore[call-arg]
|
||||
except TypeError as exc:
|
||||
logger.debug(f"Checkpoint decoder could not call {cls.__name__}({payload!r}): {exc}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"Checkpoint decoder encountered unexpected error calling {cls.__name__}({payload!r}): {exc}")
|
||||
return None
|
||||
|
||||
|
||||
def _is_pydantic_model(obj: object) -> bool:
|
||||
"""Best-effort check for Pydantic models (e.g., AFBaseModel).
|
||||
|
||||
@@ -99,7 +133,7 @@ def _encode_checkpoint_value(value: Any) -> Any:
|
||||
"value": v.model_dump(mode="json"),
|
||||
}
|
||||
except Exception as exc: # best-effort fallback
|
||||
logger.debug("Pydantic model_dump failed for %s: %s", cls, exc)
|
||||
logger.debug(f"Pydantic model_dump failed for {cls}: {exc}")
|
||||
return str(v)
|
||||
|
||||
# Dataclasses (instances only)
|
||||
@@ -178,36 +212,32 @@ def _decode_checkpoint_value(value: Any) -> Any:
|
||||
if isinstance(type_key, str):
|
||||
try:
|
||||
module_name, class_name = type_key.split(":", 1)
|
||||
module = importlib.import_module(module_name)
|
||||
module = sys.modules.get(module_name)
|
||||
if module is None:
|
||||
module = importlib.import_module(module_name)
|
||||
cls: Any = getattr(module, class_name)
|
||||
if hasattr(cls, "model_validate"):
|
||||
return cls.model_validate(raw)
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"Failed to decode pydantic model %s: %s; returning raw value",
|
||||
type_key,
|
||||
exc,
|
||||
)
|
||||
logger.debug(f"Failed to decode pydantic model {type_key}: {exc}; returning raw value")
|
||||
# Dataclass marker handling
|
||||
if _DATACLASS_MARKER in value_dict and "value" in value_dict:
|
||||
type_key_dc: str | None = value_dict.get(_DATACLASS_MARKER) # type: ignore[assignment]
|
||||
raw_dc: Any = value_dict.get("value")
|
||||
decoded_raw = _decode_checkpoint_value(raw_dc)
|
||||
if isinstance(type_key_dc, str):
|
||||
try:
|
||||
module_name, class_name = type_key_dc.split(":", 1)
|
||||
module = importlib.import_module(module_name)
|
||||
module = sys.modules.get(module_name)
|
||||
if module is None:
|
||||
module = importlib.import_module(module_name)
|
||||
cls_dc: Any = getattr(module, class_name)
|
||||
decoded_raw = _decode_checkpoint_value(raw_dc)
|
||||
if isinstance(decoded_raw, dict):
|
||||
return cls_dc(**decoded_raw)
|
||||
constructed = _instantiate_checkpoint_dataclass(cls_dc, decoded_raw)
|
||||
if constructed is not None:
|
||||
return constructed
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"Failed to decode dataclass %s: %s; returning raw value",
|
||||
type_key_dc,
|
||||
exc,
|
||||
)
|
||||
# Fallback to decoded raw value
|
||||
return _decode_checkpoint_value(raw_dc)
|
||||
logger.debug(f"Failed to decode dataclass {type_key_dc}: {exc}; returning raw value")
|
||||
return decoded_raw
|
||||
|
||||
# Regular dict: decode recursively
|
||||
decoded: dict[str, Any] = {}
|
||||
@@ -338,6 +368,10 @@ class RunnerContext(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None:
|
||||
"""Load a checkpoint without mutating the current context state."""
|
||||
...
|
||||
|
||||
async def get_checkpoint_state(self) -> CheckpointState:
|
||||
"""Get the current state of the context suitable for checkpointing."""
|
||||
...
|
||||
@@ -409,7 +443,7 @@ class InProcRunnerContext:
|
||||
return
|
||||
except Exception as exc: # pragma: no cover - defensive logging path
|
||||
# Best-effort filtering only; never block event delivery on filtering errors
|
||||
logger.debug("Error while filtering event %r: %s", event, exc, exc_info=True)
|
||||
logger.debug(f"Error while filtering event {event!r}: {exc}", exc_info=True)
|
||||
|
||||
await self._event_queue.put(event)
|
||||
|
||||
@@ -497,6 +531,11 @@ class InProcRunnerContext:
|
||||
logger.info(f"Restored state from checkpoint {checkpoint_id}'")
|
||||
return True
|
||||
|
||||
async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None:
|
||||
if not self._checkpoint_storage:
|
||||
raise ValueError("Checkpoint storage not configured")
|
||||
return await self._checkpoint_storage.load_checkpoint(checkpoint_id)
|
||||
|
||||
async def get_checkpoint_state(self) -> CheckpointState:
|
||||
serializable_messages: dict[str, list[dict[str, Any]]] = {}
|
||||
for source_id, message_list in self._messages.items():
|
||||
|
||||
@@ -1,7 +1,58 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import logging
|
||||
from dataclasses import fields, is_dataclass
|
||||
from typing import Any, Union, get_args, get_origin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _coerce_to_type(value: Any, target_type: type) -> Any | None:
|
||||
"""Best-effort conversion of value into target_type."""
|
||||
if isinstance(value, target_type):
|
||||
return value
|
||||
|
||||
# Convert dataclass instances or objects with __dict__ into dict first
|
||||
if not isinstance(value, dict):
|
||||
if is_dataclass(value):
|
||||
value = {f.name: getattr(value, f.name) for f in fields(value)}
|
||||
else:
|
||||
value_dict = getattr(value, "__dict__", None)
|
||||
if isinstance(value_dict, dict):
|
||||
value = dict(value_dict)
|
||||
|
||||
if isinstance(value, dict):
|
||||
ctor_kwargs: dict[str, Any] = dict(value)
|
||||
|
||||
if is_dataclass(target_type):
|
||||
field_names = {f.name for f in fields(target_type)}
|
||||
ctor_kwargs = {k: v for k, v in value.items() if k in field_names}
|
||||
|
||||
try:
|
||||
return target_type(**ctor_kwargs) # type: ignore[arg-type]
|
||||
except TypeError as exc:
|
||||
logger.debug(f"_coerce_to_type could not call {target_type.__name__}(**..): {exc}")
|
||||
except Exception as exc: # pragma: no cover - unexpected constructor failure
|
||||
logger.warning(
|
||||
f"_coerce_to_type encountered unexpected error calling {target_type.__name__} constructor: {exc}"
|
||||
)
|
||||
try:
|
||||
instance: Any = object.__new__(target_type)
|
||||
except Exception as exc: # pragma: no cover - pathological type
|
||||
logger.debug(f"_coerce_to_type could not allocate {target_type.__name__} without __init__: {exc}")
|
||||
return None
|
||||
for key, val in value.items():
|
||||
try:
|
||||
setattr(instance, key, val)
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
f"_coerce_to_type could not set {target_type.__name__}.{key} during fallback assignment: {exc}"
|
||||
)
|
||||
continue
|
||||
return instance
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def is_instance_of(data: Any, target_type: type) -> bool:
|
||||
"""Check if the data is an instance of the target type.
|
||||
@@ -63,7 +114,10 @@ def is_instance_of(data: Any, target_type: type) -> bool:
|
||||
and data.original_request is not None
|
||||
and not is_instance_of(data.original_request, request_type)
|
||||
):
|
||||
return False
|
||||
coerced = _coerce_to_type(data.original_request, request_type)
|
||||
if coerced is None:
|
||||
return False
|
||||
data.original_request = coerced
|
||||
if hasattr(data, "data") and data.data is not None and not is_instance_of(data.data, response_type):
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -34,6 +34,7 @@ class ValidationTypeEnum(Enum):
|
||||
"""Enumeration of workflow validation types."""
|
||||
|
||||
EDGE_DUPLICATION = "EDGE_DUPLICATION"
|
||||
EXECUTOR_DUPLICATION = "EXECUTOR_DUPLICATION"
|
||||
TYPE_COMPATIBILITY = "TYPE_COMPATIBILITY"
|
||||
GRAPH_CONNECTIVITY = "GRAPH_CONNECTIVITY"
|
||||
HANDLER_OUTPUT_ANNOTATION = "HANDLER_OUTPUT_ANNOTATION"
|
||||
@@ -63,6 +64,20 @@ class EdgeDuplicationError(WorkflowValidationError):
|
||||
self.edge_id = edge_id
|
||||
|
||||
|
||||
class ExecutorDuplicationError(WorkflowValidationError):
|
||||
"""Exception raised when duplicate executor identifiers are detected."""
|
||||
|
||||
def __init__(self, executor_id: str):
|
||||
super().__init__(
|
||||
message=(
|
||||
f"Duplicate executor id detected: '{executor_id}'. Executor ids must be globally unique within a "
|
||||
"workflow."
|
||||
),
|
||||
validation_type=ValidationTypeEnum.EXECUTOR_DUPLICATION,
|
||||
)
|
||||
self.executor_id = executor_id
|
||||
|
||||
|
||||
class TypeCompatibilityError(WorkflowValidationError):
|
||||
"""Exception raised when type incompatibility is detected between connected executors."""
|
||||
|
||||
@@ -133,10 +148,17 @@ class WorkflowGraphValidator:
|
||||
def __init__(self) -> None:
|
||||
self._edges: list[Edge] = []
|
||||
self._executors: dict[str, Executor] = {}
|
||||
self._duplicate_executor_ids: set[str] = set()
|
||||
self._start_executor_ref: Executor | str | None = None
|
||||
|
||||
# region Core Validation Methods
|
||||
def validate_workflow(
|
||||
self, edge_groups: Sequence[EdgeGroup], executors: dict[str, Executor], start_executor: Executor | str
|
||||
self,
|
||||
edge_groups: Sequence[EdgeGroup],
|
||||
executors: dict[str, Executor],
|
||||
start_executor: Executor | str,
|
||||
*,
|
||||
duplicate_executor_ids: Sequence[str] | None = None,
|
||||
) -> None:
|
||||
"""Validate the entire workflow graph.
|
||||
|
||||
@@ -144,6 +166,7 @@ class WorkflowGraphValidator:
|
||||
edge_groups: list of edge groups in the workflow
|
||||
executors: Map of executor IDs to executor instances
|
||||
start_executor: The starting executor (can be instance or ID)
|
||||
duplicate_executor_ids: Optional list of known duplicate executor IDs to pre-populate
|
||||
|
||||
Raises:
|
||||
WorkflowValidationError: If any validation fails
|
||||
@@ -151,6 +174,8 @@ class WorkflowGraphValidator:
|
||||
self._executors = executors
|
||||
self._edges = [edge for group in edge_groups for edge in group.edges]
|
||||
self._edge_groups = edge_groups
|
||||
self._duplicate_executor_ids = set(duplicate_executor_ids or [])
|
||||
self._start_executor_ref = start_executor
|
||||
|
||||
# If only the start executor exists, add it to the executor map
|
||||
# Handle the special case where the workflow consists of only a single executor and no edges.
|
||||
@@ -185,6 +210,7 @@ class WorkflowGraphValidator:
|
||||
)
|
||||
|
||||
# Run all checks
|
||||
self._validate_executor_id_uniqueness(start_executor_id)
|
||||
self._validate_edge_duplication()
|
||||
self._validate_handler_output_annotations()
|
||||
self._validate_type_compatibility()
|
||||
@@ -353,6 +379,26 @@ class WorkflowGraphValidator:
|
||||
|
||||
# endregion
|
||||
|
||||
def _validate_executor_id_uniqueness(self, start_executor_id: str) -> None:
|
||||
"""Ensure executor identifiers are unique throughout the workflow graph."""
|
||||
duplicates: set[str] = set(self._duplicate_executor_ids)
|
||||
|
||||
id_counts: defaultdict[str, int] = defaultdict(int)
|
||||
for key, executor in self._executors.items():
|
||||
id_counts[executor.id] += 1
|
||||
if key != executor.id:
|
||||
duplicates.add(executor.id)
|
||||
|
||||
duplicates.update({executor_id for executor_id, count in id_counts.items() if count > 1})
|
||||
|
||||
if isinstance(self._start_executor_ref, Executor):
|
||||
mapped = self._executors.get(start_executor_id)
|
||||
if mapped is not None and mapped is not self._start_executor_ref:
|
||||
duplicates.add(start_executor_id)
|
||||
|
||||
if duplicates:
|
||||
raise ExecutorDuplicationError(sorted(duplicates)[0])
|
||||
|
||||
# region Edge and Type Validation
|
||||
def _validate_edge_duplication(self) -> None:
|
||||
"""Validate that there are no duplicate edges in the workflow.
|
||||
@@ -793,7 +839,11 @@ class WorkflowGraphValidator:
|
||||
|
||||
|
||||
def validate_workflow_graph(
|
||||
edge_groups: Sequence[EdgeGroup], executors: dict[str, Executor], start_executor: Executor | str
|
||||
edge_groups: Sequence[EdgeGroup],
|
||||
executors: dict[str, Executor],
|
||||
start_executor: Executor | str,
|
||||
*,
|
||||
duplicate_executor_ids: Sequence[str] | None = None,
|
||||
) -> None:
|
||||
"""Convenience function to validate a workflow graph.
|
||||
|
||||
@@ -801,9 +851,15 @@ def validate_workflow_graph(
|
||||
edge_groups: list of edge groups in the workflow
|
||||
executors: Map of executor IDs to executor instances
|
||||
start_executor: The starting executor (can be instance or ID)
|
||||
duplicate_executor_ids: Optional list of known duplicate executor IDs to pre-populate
|
||||
|
||||
Raises:
|
||||
WorkflowValidationError: If any validation fails
|
||||
"""
|
||||
validator = WorkflowGraphValidator()
|
||||
validator.validate_workflow(edge_groups, executors, start_executor)
|
||||
validator.validate_workflow(
|
||||
edge_groups,
|
||||
executors,
|
||||
start_executor,
|
||||
duplicate_executor_ids=duplicate_executor_ids,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import uuid
|
||||
@@ -177,6 +179,12 @@ class Workflow(AFBaseModel):
|
||||
workflow_id=id,
|
||||
)
|
||||
|
||||
# Capture a canonical fingerprint of the workflow graph so checkpoints
|
||||
# can assert they are resumed with an equivalent topology.
|
||||
self._graph_signature = self._compute_graph_signature()
|
||||
self._graph_signature_hash = self._hash_graph_signature(self._graph_signature)
|
||||
self._runner.graph_signature_hash = self._graph_signature_hash
|
||||
|
||||
def model_dump(self, **kwargs: Any) -> dict[str, Any]:
|
||||
"""Custom serialization that properly handles WorkflowExecutor nested workflows."""
|
||||
data = super().model_dump(**kwargs)
|
||||
@@ -377,17 +385,26 @@ class Workflow(AFBaseModel):
|
||||
request_info_executor = self._find_request_info_executor()
|
||||
if request_info_executor:
|
||||
for request_id, response_data in responses.items():
|
||||
ctx: WorkflowContext[Any] = WorkflowContext(
|
||||
request_info_executor.id,
|
||||
[self.__class__.__name__],
|
||||
self._shared_state,
|
||||
self._runner.context,
|
||||
trace_contexts=None, # No parent trace context for new workflow span
|
||||
source_span_ids=None, # No source span for response handling
|
||||
)
|
||||
|
||||
if not await request_info_executor.has_pending_request(request_id, ctx):
|
||||
logger.debug(
|
||||
f"Skipping pre-supplied response for request {request_id}; no pending request found "
|
||||
f"after checkpoint restoration."
|
||||
)
|
||||
continue
|
||||
|
||||
await request_info_executor.handle_response(
|
||||
response_data,
|
||||
request_id,
|
||||
WorkflowContext(
|
||||
request_info_executor.id,
|
||||
[self.__class__.__name__],
|
||||
self._shared_state,
|
||||
self._runner.context,
|
||||
trace_contexts=None, # No parent trace context for new workflow span
|
||||
source_span_ids=None, # No source span for response handling
|
||||
),
|
||||
ctx,
|
||||
)
|
||||
|
||||
async for event in self._run_workflow_with_tracing(
|
||||
@@ -590,6 +607,19 @@ class Workflow(AFBaseModel):
|
||||
if not checkpoint:
|
||||
return False
|
||||
|
||||
graph_hash = getattr(self._runner, "graph_signature_hash", None)
|
||||
checkpoint_hash = (checkpoint.metadata or {}).get("graph_signature")
|
||||
if graph_hash and checkpoint_hash and graph_hash != checkpoint_hash:
|
||||
raise ValueError(
|
||||
"Workflow graph has changed since the checkpoint was created. "
|
||||
"Please rebuild the original workflow before resuming."
|
||||
)
|
||||
if graph_hash and not checkpoint_hash:
|
||||
logger.warning(
|
||||
f"Checkpoint {checkpoint_id} does not include graph signature metadata; "
|
||||
f"skipping topology validation."
|
||||
)
|
||||
|
||||
temp_context = InProcRunnerContext(checkpoint_storage)
|
||||
state: CheckpointState = {
|
||||
"messages": checkpoint.messages,
|
||||
@@ -608,10 +638,9 @@ class Workflow(AFBaseModel):
|
||||
|
||||
return True
|
||||
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"Failed to restore from external checkpoint {checkpoint_id}: {e}")
|
||||
return False
|
||||
|
||||
@@ -632,7 +661,7 @@ class Workflow(AFBaseModel):
|
||||
self._shared_state._state.clear() # type: ignore[attr-defined]
|
||||
self._shared_state._state.update(shared_state_data) # type: ignore[attr-defined]
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.debug("Failed to restore shared_state during external restore: %s", exc)
|
||||
logger.debug(f"Failed to restore shared_state during external restore: {exc}")
|
||||
|
||||
# Restore executor states into the context so ctx.get_state() calls after resume succeed
|
||||
try:
|
||||
@@ -641,9 +670,9 @@ class Workflow(AFBaseModel):
|
||||
try:
|
||||
await self._runner.context.set_state(exec_id, state)
|
||||
except Exception as exc: # pragma: no cover - ignore per-executor failures
|
||||
logger.debug("Failed to restore executor state for %s during external restore: %s", exec_id, exc)
|
||||
logger.debug(f"Failed to restore executor state for {exec_id} during external restore: {exc}")
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.debug("Failed to iterate executor_states during external restore: %s", exc)
|
||||
logger.debug(f"Failed to iterate executor_states during external restore: {exc}")
|
||||
|
||||
# Transfer pending messages into the context for delivery in the next superstep
|
||||
messages_data = restored_state["messages"]
|
||||
@@ -671,6 +700,71 @@ class Workflow(AFBaseModel):
|
||||
)
|
||||
)
|
||||
|
||||
# Graph signature helpers
|
||||
|
||||
def _compute_graph_signature(self) -> dict[str, Any]:
|
||||
"""Build a canonical fingerprint of the workflow graph topology for checkpoint validation.
|
||||
|
||||
This creates a minimal, stable representation that captures only the structural
|
||||
elements of the workflow (executor types, edge relationships, topology) while
|
||||
ignoring data/state changes. Used to verify that a workflow's structure hasn't
|
||||
changed when resuming from checkpoints.
|
||||
"""
|
||||
executors_signature = {
|
||||
executor_id: f"{executor.__class__.__module__}.{executor.__class__.__name__}"
|
||||
for executor_id, executor in self.executors.items()
|
||||
}
|
||||
|
||||
edge_groups_signature: list[dict[str, Any]] = []
|
||||
for group in self.edge_groups:
|
||||
edges = [
|
||||
{
|
||||
"source": edge.source_id,
|
||||
"target": edge.target_id,
|
||||
"condition": getattr(edge, "condition_name", None),
|
||||
}
|
||||
for edge in group.edges
|
||||
]
|
||||
edges.sort(key=lambda e: (e["source"], e["target"], e["condition"] or ""))
|
||||
|
||||
group_info: dict[str, Any] = {
|
||||
"group_type": group.__class__.__name__,
|
||||
"sources": sorted(group.source_executor_ids),
|
||||
"targets": sorted(group.target_executor_ids),
|
||||
"edges": edges,
|
||||
}
|
||||
|
||||
if isinstance(group, FanOutEdgeGroup):
|
||||
group_info["selection_func"] = group.selection_func_name
|
||||
|
||||
edge_groups_signature.append(group_info)
|
||||
|
||||
edge_groups_signature.sort(
|
||||
key=lambda info: (
|
||||
info["group_type"],
|
||||
tuple(info["sources"]),
|
||||
tuple(info["targets"]),
|
||||
json.dumps(info["edges"], sort_keys=True),
|
||||
json.dumps(info.get("selection_func")),
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
"start_executor": self.start_executor_id,
|
||||
"executors": executors_signature,
|
||||
"edge_groups": edge_groups_signature,
|
||||
"max_iterations": self.max_iterations,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _hash_graph_signature(signature: dict[str, Any]) -> str:
|
||||
canonical = json.dumps(signature, sort_keys=True, separators=(",", ":"))
|
||||
return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
|
||||
|
||||
@property
|
||||
def graph_signature_hash(self) -> str:
|
||||
return self._graph_signature_hash
|
||||
|
||||
def as_agent(self, name: str | None = None) -> WorkflowAgent:
|
||||
"""Create a WorkflowAgent that wraps this workflow.
|
||||
|
||||
@@ -699,6 +793,7 @@ class WorkflowBuilder:
|
||||
"""Initialize the WorkflowBuilder with an empty list of edges and no starting executor."""
|
||||
self._edge_groups: list[EdgeGroup] = []
|
||||
self._executors: dict[str, Executor] = {}
|
||||
self._duplicate_executor_ids: set[str] = set()
|
||||
self._start_executor: Executor | str | None = None
|
||||
self._checkpoint_storage: CheckpointStorage | None = None
|
||||
self._max_iterations: int = max_iterations
|
||||
@@ -712,7 +807,11 @@ class WorkflowBuilder:
|
||||
|
||||
def _add_executor(self, executor: Executor) -> str:
|
||||
"""Add an executor to the map and return its ID."""
|
||||
self._executors[executor.id] = executor
|
||||
existing = self._executors.get(executor.id)
|
||||
if existing is not None and existing is not executor:
|
||||
self._duplicate_executor_ids.add(executor.id)
|
||||
else:
|
||||
self._executors[executor.id] = executor
|
||||
return executor.id
|
||||
|
||||
def _maybe_wrap_agent(self, candidate: Executor | AgentProtocol) -> Executor:
|
||||
@@ -739,7 +838,10 @@ class WorkflowBuilder:
|
||||
if name:
|
||||
proposed_id = str(name)
|
||||
if proposed_id in self._executors:
|
||||
proposed_id = f"{proposed_id}-{uuid.uuid4().hex[:8]}"
|
||||
raise ValueError(
|
||||
f"Duplicate executor ID '{proposed_id}' from agent name. "
|
||||
"Agent names must be unique within a workflow."
|
||||
)
|
||||
wrapper = AgentExecutor(candidate, id=proposed_id, streaming=True)
|
||||
self._agent_wrappers[id(candidate)] = wrapper
|
||||
return wrapper
|
||||
@@ -934,8 +1036,9 @@ class WorkflowBuilder:
|
||||
self._start_executor = wrapped
|
||||
# Ensure the start executor is present in the executor map so validation succeeds
|
||||
# even if no edges are added yet, or before edges wrap the same agent again.
|
||||
if wrapped.id not in self._executors:
|
||||
self._executors[wrapped.id] = wrapped
|
||||
existing = self._executors.get(wrapped.id)
|
||||
if existing is not wrapped:
|
||||
self._add_executor(wrapped)
|
||||
return self
|
||||
|
||||
def set_max_iterations(self, max_iterations: int) -> Self:
|
||||
@@ -986,7 +1089,12 @@ class WorkflowBuilder:
|
||||
)
|
||||
|
||||
# Perform validation before creating the workflow
|
||||
validate_workflow_graph(self._edge_groups, self._executors, self._start_executor)
|
||||
validate_workflow_graph(
|
||||
self._edge_groups,
|
||||
self._executors,
|
||||
self._start_executor,
|
||||
duplicate_executor_ids=tuple(self._duplicate_executor_ids),
|
||||
)
|
||||
|
||||
# Add validation completed event
|
||||
workflow_tracer.add_build_event("build.validation_completed")
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
|
||||
from agent_framework._workflow._executor import RequestInfoMessage, RequestResponse
|
||||
from agent_framework._workflow._runner_context import _decode_checkpoint_value, _encode_checkpoint_value # type: ignore
|
||||
from agent_framework._workflow._typing_utils import is_instance_of
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class SampleRequest(RequestInfoMessage):
|
||||
prompt: str
|
||||
|
||||
|
||||
def test_decode_dataclass_with_nested_request() -> None:
|
||||
original = RequestResponse[SampleRequest, str].handled("approve")
|
||||
original = RequestResponse[SampleRequest, str].with_correlation(
|
||||
original,
|
||||
SampleRequest(request_id="abc", prompt="prompt"),
|
||||
"abc",
|
||||
)
|
||||
|
||||
encoded = _encode_checkpoint_value(original)
|
||||
decoded = cast(RequestResponse[SampleRequest, str], _decode_checkpoint_value(encoded))
|
||||
|
||||
assert isinstance(decoded, RequestResponse)
|
||||
assert decoded.data == "approve"
|
||||
assert decoded.request_id == "abc"
|
||||
assert isinstance(decoded.original_request, SampleRequest)
|
||||
assert decoded.original_request.prompt == "prompt"
|
||||
|
||||
|
||||
def test_is_instance_of_coerces_request_response_original_request_dict() -> None:
|
||||
response = RequestResponse[SampleRequest, str].handled("approve")
|
||||
response = RequestResponse[SampleRequest, str].with_correlation(
|
||||
response,
|
||||
SampleRequest(request_id="req-1", prompt="prompt"),
|
||||
"req-1",
|
||||
)
|
||||
|
||||
# Simulate checkpoint decode fallback leaving a dict
|
||||
response.original_request = cast(
|
||||
Any,
|
||||
{
|
||||
"request_id": "req-1",
|
||||
"prompt": "prompt",
|
||||
},
|
||||
)
|
||||
|
||||
assert is_instance_of(response, RequestResponse[SampleRequest, str])
|
||||
assert isinstance(response.original_request, SampleRequest)
|
||||
@@ -0,0 +1,73 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework import WorkflowBuilder, WorkflowCompletedEvent, WorkflowContext, handler
|
||||
from agent_framework._workflow._checkpoint import InMemoryCheckpointStorage
|
||||
from agent_framework._workflow._executor import Executor
|
||||
|
||||
|
||||
class StartExecutor(Executor):
|
||||
@handler
|
||||
async def run(self, message: str, ctx: WorkflowContext[str]) -> None:
|
||||
await ctx.send_message(message, target_id="finish")
|
||||
|
||||
|
||||
class FinishExecutor(Executor):
|
||||
@handler
|
||||
async def finish(self, message: str, ctx: WorkflowContext[None]) -> None:
|
||||
await ctx.add_event(WorkflowCompletedEvent(message))
|
||||
|
||||
|
||||
def build_workflow(storage: InMemoryCheckpointStorage, finish_id: str = "finish"):
|
||||
start = StartExecutor(id="start")
|
||||
finish = FinishExecutor(id=finish_id)
|
||||
|
||||
builder = WorkflowBuilder(max_iterations=3).set_start_executor(start).add_edge(start, finish)
|
||||
builder = builder.with_checkpointing(checkpoint_storage=storage)
|
||||
return builder.build()
|
||||
|
||||
|
||||
async def test_resume_fails_when_graph_mismatch() -> None:
|
||||
storage = InMemoryCheckpointStorage()
|
||||
workflow = build_workflow(storage, finish_id="finish")
|
||||
|
||||
# Run once to create checkpoints
|
||||
_ = [event async for event in workflow.run_stream("hello")] # noqa: F841
|
||||
|
||||
checkpoints = await storage.list_checkpoints()
|
||||
assert checkpoints, "expected at least one checkpoint to be created"
|
||||
target_checkpoint = checkpoints[-1]
|
||||
|
||||
# Build a structurally different workflow (different finish executor id)
|
||||
mismatched_workflow = build_workflow(storage, finish_id="finish_alt")
|
||||
|
||||
with pytest.raises(ValueError, match="Workflow graph has changed"):
|
||||
_ = [
|
||||
event
|
||||
async for event in mismatched_workflow.run_stream_from_checkpoint(
|
||||
target_checkpoint.checkpoint_id,
|
||||
checkpoint_storage=storage,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
async def test_resume_succeeds_when_graph_matches() -> None:
|
||||
storage = InMemoryCheckpointStorage()
|
||||
workflow = build_workflow(storage, finish_id="finish")
|
||||
_ = [event async for event in workflow.run_stream("hello")] # noqa: F841
|
||||
|
||||
checkpoints = sorted(await storage.list_checkpoints(), key=lambda c: c.timestamp)
|
||||
target_checkpoint = checkpoints[0]
|
||||
|
||||
resumed_workflow = build_workflow(storage, finish_id="finish")
|
||||
|
||||
events = [
|
||||
event
|
||||
async for event in resumed_workflow.run_stream_from_checkpoint(
|
||||
target_checkpoint.checkpoint_id,
|
||||
checkpoint_storage=storage,
|
||||
)
|
||||
]
|
||||
|
||||
assert any(isinstance(event, WorkflowCompletedEvent) for event in events)
|
||||
@@ -126,3 +126,17 @@ async def test_concurrent_custom_aggregator_sync_callback_is_used() -> None:
|
||||
assert completed is not None
|
||||
assert isinstance(completed.data, str)
|
||||
assert completed.data == "One | Two"
|
||||
|
||||
|
||||
def test_concurrent_custom_aggregator_uses_callback_name_for_id() -> None:
|
||||
e1 = _FakeAgentExec("agentA", "One")
|
||||
e2 = _FakeAgentExec("agentB", "Two")
|
||||
|
||||
def summarize(results: list[AgentExecutorResponse]) -> str: # type: ignore[override]
|
||||
return str(len(results))
|
||||
|
||||
wf = ConcurrentBuilder().participants([e1, e2]).with_aggregator(summarize).build()
|
||||
|
||||
assert "summarize" in wf.executors
|
||||
aggregator = wf.executors["summarize"]
|
||||
assert aggregator.id == "summarize"
|
||||
|
||||
@@ -5,16 +5,16 @@ import pytest
|
||||
from agent_framework import Executor, WorkflowContext, handler
|
||||
|
||||
|
||||
def test_executor_without_handlers():
|
||||
"""Test that an executor without handlers raises an error when trying to run."""
|
||||
def test_executor_without_id():
|
||||
"""Test that an executor without an ID raises an error when trying to run."""
|
||||
|
||||
class MockExecutorWithoutHandlers(Executor):
|
||||
class MockExecutorWithoutID(Executor):
|
||||
"""A mock executor that does not implement any handlers."""
|
||||
|
||||
pass
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
MockExecutorWithoutHandlers()
|
||||
MockExecutorWithoutID(id="")
|
||||
|
||||
|
||||
def test_executor_handler_without_annotations():
|
||||
@@ -61,7 +61,7 @@ def test_executor_with_valid_handlers():
|
||||
"""Another mock handler with a valid signature."""
|
||||
pass
|
||||
|
||||
executor = MockExecutorWithValidHandlers()
|
||||
executor = MockExecutorWithValidHandlers(id="test")
|
||||
assert executor.id is not None
|
||||
assert len(executor._handlers) == 2 # type: ignore
|
||||
assert executor.can_handle("text") is True
|
||||
@@ -85,7 +85,7 @@ def test_executor_handlers_with_output_types():
|
||||
"""A mock handler that outputs an integer."""
|
||||
pass
|
||||
|
||||
executor = MockExecutorWithOutputTypes()
|
||||
executor = MockExecutorWithOutputTypes(id="test")
|
||||
assert len(executor._handlers) == 2 # type: ignore
|
||||
|
||||
string_handler = executor._handlers[str] # type: ignore
|
||||
|
||||
@@ -0,0 +1,223 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework._workflow._checkpoint import WorkflowCheckpoint
|
||||
from agent_framework._workflow._events import WorkflowEvent
|
||||
from agent_framework._workflow._executor import (
|
||||
PendingRequestDetails,
|
||||
RequestInfoExecutor,
|
||||
RequestInfoMessage,
|
||||
RequestResponse,
|
||||
)
|
||||
from agent_framework._workflow._runner_context import CheckpointState, Message, _encode_checkpoint_value # type: ignore
|
||||
from agent_framework._workflow._shared_state import SharedState
|
||||
from agent_framework._workflow._workflow_context import WorkflowContext
|
||||
|
||||
PENDING_STATE_KEY = RequestInfoExecutor._PENDING_SHARED_STATE_KEY # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
class _StubRunnerContext:
|
||||
"""Minimal runner context stub for exercising WorkflowContext helpers."""
|
||||
|
||||
def __init__(self, stored_state: dict[str, Any] | None = None) -> None:
|
||||
self._state = stored_state or {}
|
||||
|
||||
async def send_message(self, message: Message) -> None: # pragma: no cover - unused in tests
|
||||
return None
|
||||
|
||||
async def drain_messages(self) -> dict[str, list[Message]]: # pragma: no cover - unused
|
||||
return {}
|
||||
|
||||
async def has_messages(self) -> bool: # pragma: no cover - unused
|
||||
return False
|
||||
|
||||
async def add_event(self, event: WorkflowEvent) -> None: # pragma: no cover - unused
|
||||
return None
|
||||
|
||||
async def drain_events(self) -> list[WorkflowEvent]: # pragma: no cover - unused
|
||||
return []
|
||||
|
||||
async def has_events(self) -> bool: # pragma: no cover - unused
|
||||
return False
|
||||
|
||||
async def next_event(self) -> WorkflowEvent: # pragma: no cover - unused
|
||||
raise RuntimeError("Not implemented in stub context")
|
||||
|
||||
async def get_state(self, executor_id: str) -> dict[str, Any] | None: # pragma: no cover - trivial
|
||||
return self._state
|
||||
|
||||
async def set_state(self, executor_id: str, state: dict[str, Any]) -> None: # pragma: no cover - unused
|
||||
self._state = state
|
||||
|
||||
def has_checkpointing(self) -> bool: # pragma: no cover - unused
|
||||
return False
|
||||
|
||||
def set_workflow_id(self, workflow_id: str) -> None: # pragma: no cover - unused
|
||||
pass
|
||||
|
||||
def reset_for_new_run(self, workflow_shared_state: SharedState | None = None) -> None: # pragma: no cover - unused
|
||||
pass
|
||||
|
||||
async def create_checkpoint(self, metadata: dict[str, Any] | None = None) -> str: # pragma: no cover - unused
|
||||
raise RuntimeError("Checkpointing not supported in stub context")
|
||||
|
||||
async def restore_from_checkpoint(self, checkpoint_id: str) -> bool: # pragma: no cover - unused
|
||||
return False
|
||||
|
||||
async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None: # pragma: no cover - unused
|
||||
return None
|
||||
|
||||
async def get_checkpoint_state(self) -> CheckpointState: # pragma: no cover - unused
|
||||
return {} # type: ignore[return-value]
|
||||
|
||||
async def set_checkpoint_state(self, state: CheckpointState) -> None: # pragma: no cover - unused
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class SimpleApproval(RequestInfoMessage):
|
||||
prompt: str = ""
|
||||
draft: str = ""
|
||||
iteration: int = 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rehydrate_falls_back_when_request_type_missing() -> None:
|
||||
"""Rehydration should succeed even if the original request type cannot be imported.
|
||||
|
||||
This simulates resuming a workflow where the HumanApprovalRequest class is unavailable
|
||||
in the current process (e.g., defined in __main__ during the original run).
|
||||
"""
|
||||
|
||||
request_id = "request-123"
|
||||
snapshot = {
|
||||
"request_id": request_id,
|
||||
"source_executor_id": "review_gateway",
|
||||
"request_type": "nonexistent.module:MissingRequest",
|
||||
"summary": "...",
|
||||
"details": {
|
||||
"request_id": request_id,
|
||||
"prompt": "Review draft",
|
||||
"draft": "Draft text",
|
||||
"iteration": 2,
|
||||
},
|
||||
}
|
||||
|
||||
shared_state = SharedState()
|
||||
async with shared_state.hold():
|
||||
await shared_state.set_within_hold(
|
||||
PENDING_STATE_KEY,
|
||||
{request_id: snapshot},
|
||||
)
|
||||
|
||||
runner_ctx = _StubRunnerContext({"pending_requests": {request_id: snapshot}})
|
||||
ctx: WorkflowContext[Any] = WorkflowContext("request_info", ["workflow"], shared_state, runner_ctx)
|
||||
|
||||
executor = RequestInfoExecutor(id="request_info")
|
||||
|
||||
event = await executor._rehydrate_request_event(request_id, ctx) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
assert event is not None
|
||||
assert event.request_id == request_id
|
||||
assert isinstance(event.data, RequestInfoMessage)
|
||||
assert getattr(event.data, "prompt", None) == "Review draft"
|
||||
assert getattr(event.data, "iteration", None) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_pending_request_detects_snapshot() -> None:
|
||||
request_id = "req-pending"
|
||||
snapshot = {
|
||||
"request_id": request_id,
|
||||
"source_executor_id": "review_gateway",
|
||||
"details": {
|
||||
"request_id": request_id,
|
||||
"prompt": "Review",
|
||||
"draft": "Draft",
|
||||
},
|
||||
}
|
||||
|
||||
shared_state = SharedState()
|
||||
async with shared_state.hold():
|
||||
await shared_state.set_within_hold(
|
||||
PENDING_STATE_KEY,
|
||||
{request_id: snapshot},
|
||||
)
|
||||
|
||||
runner_ctx = _StubRunnerContext({"pending_requests": {request_id: snapshot}})
|
||||
ctx: WorkflowContext[Any] = WorkflowContext("request_info", ["workflow"], shared_state, runner_ctx)
|
||||
|
||||
executor = RequestInfoExecutor(id="request_info")
|
||||
|
||||
assert await executor.has_pending_request(request_id, ctx)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_pending_request_false_when_snapshot_absent() -> None:
|
||||
shared_state = SharedState()
|
||||
runner_ctx = _StubRunnerContext({"pending_requests": {}})
|
||||
ctx: WorkflowContext[Any] = WorkflowContext("request_info", ["workflow"], shared_state, runner_ctx)
|
||||
|
||||
executor = RequestInfoExecutor(id="request_info")
|
||||
|
||||
assert not await executor.has_pending_request("missing", ctx)
|
||||
|
||||
|
||||
def test_pending_requests_from_checkpoint_and_summary() -> None:
|
||||
request = SimpleApproval(prompt="Review draft", draft="Draft text", iteration=3)
|
||||
request.request_id = "req-42"
|
||||
|
||||
response = RequestResponse[SimpleApproval, str].handled("approve")
|
||||
response = RequestResponse[SimpleApproval, str].with_correlation(
|
||||
response,
|
||||
request,
|
||||
request.request_id,
|
||||
)
|
||||
|
||||
encoded_response = _encode_checkpoint_value(response)
|
||||
|
||||
checkpoint = WorkflowCheckpoint(
|
||||
checkpoint_id="cp-1",
|
||||
workflow_id="wf",
|
||||
messages={
|
||||
"request_info": [
|
||||
{
|
||||
"data": encoded_response,
|
||||
"source_id": "request_info",
|
||||
"target_id": "review_gateway",
|
||||
}
|
||||
]
|
||||
},
|
||||
shared_state={
|
||||
PENDING_STATE_KEY: {
|
||||
request.request_id: {
|
||||
"request_id": request.request_id,
|
||||
"prompt": request.prompt,
|
||||
"draft": request.draft,
|
||||
"iteration": request.iteration,
|
||||
"source_executor_id": "review_gateway",
|
||||
}
|
||||
}
|
||||
},
|
||||
executor_states={},
|
||||
iteration_count=1,
|
||||
)
|
||||
|
||||
pending = RequestInfoExecutor.pending_requests_from_checkpoint(checkpoint)
|
||||
assert len(pending) == 1
|
||||
entry = pending[0]
|
||||
assert isinstance(entry, PendingRequestDetails)
|
||||
assert entry.request_id == "req-42"
|
||||
assert entry.prompt == "Review draft"
|
||||
assert entry.draft == "Draft text"
|
||||
assert entry.iteration == 3
|
||||
assert entry.original_request is not None
|
||||
|
||||
summary = RequestInfoExecutor.checkpoint_summary(checkpoint)
|
||||
assert summary.checkpoint_id == "cp-1"
|
||||
assert summary.status == "awaiting human response"
|
||||
assert summary.pending_requests[0].request_id == "req-42"
|
||||
@@ -605,10 +605,7 @@ class TestSerializationWorkflowClasses:
|
||||
executor = SampleExecutor(id="valid-id")
|
||||
assert executor.id == "valid-id"
|
||||
|
||||
# Test validation failure for empty id - pydantic automatically validates min_length=1
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
with pytest.raises(ValueError):
|
||||
SampleExecutor(id="")
|
||||
|
||||
def test_edge_field_validation(self) -> None:
|
||||
|
||||
@@ -200,7 +200,7 @@ async def test_sub_workflow_with_interception():
|
||||
# Create parent workflow with interception
|
||||
parent = ParentOrchestrator(approved_domains={"example.com", "internal.org"})
|
||||
workflow_executor = WorkflowExecutor(validation_workflow, id="email_workflow")
|
||||
parent_request_info = RequestInfoExecutor()
|
||||
parent_request_info = RequestInfoExecutor(id="request_info")
|
||||
|
||||
main_workflow = (
|
||||
WorkflowBuilder()
|
||||
@@ -280,7 +280,7 @@ async def test_conditional_forwarding() -> None:
|
||||
|
||||
# Setup workflows
|
||||
email_validator = EmailValidator()
|
||||
request_info = RequestInfoExecutor()
|
||||
request_info = RequestInfoExecutor(id="request_info")
|
||||
|
||||
validation_workflow = (
|
||||
WorkflowBuilder()
|
||||
@@ -292,7 +292,7 @@ async def test_conditional_forwarding() -> None:
|
||||
|
||||
parent = ConditionalParent()
|
||||
workflow_executor = WorkflowExecutor(validation_workflow, id="email_workflow")
|
||||
parent_request_info = RequestInfoExecutor()
|
||||
parent_request_info = RequestInfoExecutor(id="request_info")
|
||||
|
||||
main_workflow = (
|
||||
WorkflowBuilder()
|
||||
@@ -364,7 +364,7 @@ async def test_workflow_scoped_interception() -> None:
|
||||
# Create two identical sub-workflows
|
||||
def create_validation_workflow():
|
||||
validator = EmailValidator()
|
||||
request_info = RequestInfoExecutor()
|
||||
request_info = RequestInfoExecutor(id="request_info")
|
||||
return (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(validator)
|
||||
@@ -379,7 +379,7 @@ async def test_workflow_scoped_interception() -> None:
|
||||
parent = MultiWorkflowParent()
|
||||
executor_a = WorkflowExecutor(workflow_a, id="workflow_a")
|
||||
executor_b = WorkflowExecutor(workflow_b, id="workflow_b")
|
||||
parent_request_info = RequestInfoExecutor()
|
||||
parent_request_info = RequestInfoExecutor(id="request_info")
|
||||
|
||||
main_workflow = (
|
||||
WorkflowBuilder()
|
||||
|
||||
@@ -8,6 +8,7 @@ import pytest
|
||||
from agent_framework import (
|
||||
EdgeDuplicationError,
|
||||
Executor,
|
||||
ExecutorDuplicationError,
|
||||
GraphConnectivityError,
|
||||
TypeCompatibilityError,
|
||||
ValidationTypeEnum,
|
||||
@@ -79,6 +80,17 @@ def test_valid_workflow_passes_validation():
|
||||
assert workflow is not None
|
||||
|
||||
|
||||
def test_duplicate_executor_ids_fail_validation():
|
||||
executor1 = StringExecutor(id="dup")
|
||||
executor2 = IntExecutor(id="dup")
|
||||
|
||||
with pytest.raises(ExecutorDuplicationError) as exc_info:
|
||||
(WorkflowBuilder().add_edge(executor1, executor2).set_start_executor(executor1).build())
|
||||
|
||||
assert exc_info.value.executor_id == "dup"
|
||||
assert exc_info.value.validation_type == ValidationTypeEnum.EXECUTOR_DUPLICATION
|
||||
|
||||
|
||||
def test_edge_duplication_validation_fails():
|
||||
executor1 = StringExecutor(id="executor1")
|
||||
executor2 = StringExecutor(id="executor2")
|
||||
|
||||
@@ -163,7 +163,7 @@ async def test_workflow_send_responses_streaming():
|
||||
"""Test the workflow run with approval."""
|
||||
executor_a = IncrementExecutor(id="executor_a")
|
||||
executor_b = MockExecutorRequestApproval(id="executor_b")
|
||||
request_info_executor = RequestInfoExecutor()
|
||||
request_info_executor = RequestInfoExecutor(id="request_info")
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder()
|
||||
@@ -195,7 +195,7 @@ async def test_workflow_send_responses():
|
||||
"""Test the workflow run with approval."""
|
||||
executor_a = IncrementExecutor(id="executor_a")
|
||||
executor_b = MockExecutorRequestApproval(id="executor_b")
|
||||
request_info_executor = RequestInfoExecutor()
|
||||
request_info_executor = RequestInfoExecutor(id="request_info")
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder()
|
||||
|
||||
@@ -11,6 +11,7 @@ from agent_framework import (
|
||||
AgentRunUpdateEvent,
|
||||
ChatMessage,
|
||||
Executor,
|
||||
FunctionCallContent,
|
||||
FunctionResultContent,
|
||||
RequestInfoExecutor,
|
||||
RequestInfoMessage,
|
||||
@@ -94,8 +95,8 @@ class TestWorkflowAgent:
|
||||
assert len(result.messages) >= 2, f"Expected at least 2 messages, got {len(result.messages)}"
|
||||
|
||||
# Find messages from each executor
|
||||
step1_messages = []
|
||||
step2_messages = []
|
||||
step1_messages: list[ChatMessage] = []
|
||||
step2_messages: list[ChatMessage] = []
|
||||
|
||||
for message in result.messages:
|
||||
first_content = message.contents[0]
|
||||
@@ -111,8 +112,8 @@ class TestWorkflowAgent:
|
||||
assert len(step2_messages) >= 1, "Should have received message from Step2 executor"
|
||||
|
||||
# Verify the processing worked for both
|
||||
step1_text = step1_messages[0].contents[0].text
|
||||
step2_text = step2_messages[0].contents[0].text
|
||||
step1_text: str = step1_messages[0].contents[0].text # type: ignore[attr-defined]
|
||||
step2_text: str = step2_messages[0].contents[0].text # type: ignore[attr-defined]
|
||||
assert "Step1: Hello World" in step1_text
|
||||
assert "Step2: Step1: Hello World" in step2_text
|
||||
|
||||
@@ -128,7 +129,7 @@ class TestWorkflowAgent:
|
||||
agent = WorkflowAgent(workflow=workflow, name="Streaming Test Agent")
|
||||
|
||||
# Execute workflow streaming to capture streaming events
|
||||
updates = []
|
||||
updates: list[AgentRunResponseUpdate] = []
|
||||
async for update in agent.run_stream("Test input"):
|
||||
updates.append(update)
|
||||
|
||||
@@ -137,8 +138,8 @@ class TestWorkflowAgent:
|
||||
|
||||
# Verify we got a streaming update
|
||||
assert updates[0].contents is not None
|
||||
first_content = updates[0].contents[0]
|
||||
second_content = updates[1].contents[0]
|
||||
first_content: TextContent = updates[0].contents[0] # type: ignore[assignment]
|
||||
second_content: TextContent = updates[1].contents[0] # type: ignore[assignment]
|
||||
assert isinstance(first_content, TextContent)
|
||||
assert "Streaming1: Test input" in first_content.text
|
||||
assert isinstance(second_content, TextContent)
|
||||
@@ -148,7 +149,7 @@ class TestWorkflowAgent:
|
||||
"""Test end-to-end workflow with RequestInfoEvent handling."""
|
||||
# Create workflow with requesting executor -> request info executor (no cycle)
|
||||
requesting_executor = RequestingExecutor(id="requester")
|
||||
request_info_executor = RequestInfoExecutor()
|
||||
request_info_executor = RequestInfoExecutor(id="request_info")
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder()
|
||||
@@ -160,21 +161,21 @@ class TestWorkflowAgent:
|
||||
agent = WorkflowAgent(workflow=workflow, name="Request Test Agent")
|
||||
|
||||
# Execute workflow streaming to get request info event
|
||||
updates = []
|
||||
updates: list[AgentRunResponseUpdate] = []
|
||||
async for update in agent.run_stream("Start request"):
|
||||
updates.append(update)
|
||||
# Should have received a function call for the request info
|
||||
assert len(updates) > 0
|
||||
|
||||
# Find the function call update (RequestInfoEvent converted to function call)
|
||||
function_call_update = None
|
||||
function_call_update: AgentRunResponseUpdate | None = None
|
||||
for update in updates:
|
||||
if update.contents and hasattr(update.contents[0], "name") and update.contents[0].name == "request_info":
|
||||
if update.contents and hasattr(update.contents[0], "name") and update.contents[0].name == "request_info": # type: ignore[attr-defined]
|
||||
function_call_update = update
|
||||
break
|
||||
|
||||
assert function_call_update is not None, "Should have received a request_info function call"
|
||||
function_call = function_call_update.contents[0]
|
||||
function_call: FunctionCallContent = function_call_update.contents[0] # type: ignore[assignment]
|
||||
|
||||
# Verify the function call has expected structure
|
||||
assert function_call.call_id is not None
|
||||
@@ -230,7 +231,7 @@ class TestWorkflowAgent:
|
||||
raise ValueError("Unsupported message type")
|
||||
|
||||
# Create a simple workflow
|
||||
executor = _Executor()
|
||||
executor = _Executor(id="test")
|
||||
workflow = WorkflowBuilder().set_start_executor(executor).build()
|
||||
|
||||
# Try to create an agent with unsupported input types
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework import (
|
||||
@@ -155,3 +158,59 @@ async def test_run_includes_status_events_idle_with_requests():
|
||||
assert len(timeline) >= 3
|
||||
assert timeline[-2].state == WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS
|
||||
assert timeline[-1].state == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS
|
||||
|
||||
|
||||
@dataclass
|
||||
class SnapshotRequest(RequestInfoMessage):
|
||||
prompt: str = ""
|
||||
draft: str = ""
|
||||
iteration: int = 0
|
||||
|
||||
|
||||
class SnapshotRequester(Executor):
|
||||
"""Executor that emits a rich RequestInfoMessage for persistence tests."""
|
||||
|
||||
def __init__(self, id: str, prompt: str, draft: str) -> None:
|
||||
super().__init__(id=id)
|
||||
self._prompt = prompt
|
||||
self._draft = draft
|
||||
|
||||
@handler
|
||||
async def ask(self, _: str, ctx: WorkflowContext[SnapshotRequest]) -> None: # pragma: no cover - simple helper
|
||||
await ctx.send_message(SnapshotRequest(prompt=self._prompt, draft=self._draft, iteration=1))
|
||||
|
||||
|
||||
async def test_request_info_executor_tracks_pending_requests_via_shared_state():
|
||||
prompt = "Review the launch copy"
|
||||
draft = "Limited edition grinder now $249"
|
||||
requester = SnapshotRequester(id="snapshot_req", prompt=prompt, draft=draft)
|
||||
request_info = RequestInfoExecutor(id="request_info")
|
||||
|
||||
wf = WorkflowBuilder().set_start_executor(requester).add_edge(requester, request_info).build()
|
||||
|
||||
events = [event async for event in wf.run_stream("start")]
|
||||
assert any(isinstance(event, RequestInfoEvent) for event in events)
|
||||
|
||||
pending_map: dict[str, Any] = await wf._shared_state.get(RequestInfoExecutor._PENDING_SHARED_STATE_KEY) # type: ignore[reportPrivateUsage]
|
||||
assert isinstance(pending_map, dict)
|
||||
assert len(pending_map) == 1
|
||||
snapshot: dict[str, Any] = next(iter(pending_map.values()))
|
||||
assert snapshot["prompt"] == prompt
|
||||
assert snapshot["draft"] == draft
|
||||
assert snapshot.get("iteration") == 1
|
||||
|
||||
request_id: str = snapshot["request_id"]
|
||||
|
||||
request_info_resume = RequestInfoExecutor(id="request_info_resume")
|
||||
resume_context: WFContext[Any] = WFContext(
|
||||
executor_id=request_info_resume.id,
|
||||
source_executor_ids=[wf.__class__.__name__],
|
||||
shared_state=wf._shared_state, # type: ignore[reportPrivateUsage]
|
||||
runner_context=wf._runner_context, # type: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
await request_info_resume.handle_response("approve", request_id, resume_context)
|
||||
|
||||
updated_pending: dict[str, Any] = await wf._shared_state.get(RequestInfoExecutor._PENDING_SHARED_STATE_KEY) # type: ignore[reportPrivateUsage]
|
||||
assert isinstance(updated_pending, dict)
|
||||
assert request_id not in updated_pending
|
||||
|
||||
@@ -43,6 +43,7 @@ Once comfortable with these, explore the rest of the samples below.
|
||||
| Sample | File | Concepts |
|
||||
|---|---|---|
|
||||
| Checkpoint & Resume | [checkpoint/checkpoint_with_resume.py](./checkpoint/checkpoint_with_resume.py) | Create checkpoints, inspect them, and resume execution |
|
||||
| Checkpoint & HITL Resume | [checkpoint/checkpoint_with_human_in_the_loop.py](./checkpoint/checkpoint_with_human_in_the_loop.py) | Combine checkpointing with human approvals and resume pending HITL requests |
|
||||
|
||||
### composition
|
||||
| Sample | File | Concepts |
|
||||
|
||||
@@ -50,7 +50,7 @@ Prerequisites
|
||||
# - Compute a result
|
||||
# - Forward that result to downstream node(s) using ctx.send_message(result)
|
||||
class UpperCase(Executor):
|
||||
def __init__(self, id: str | None = None):
|
||||
def __init__(self, id: str):
|
||||
super().__init__(id=id)
|
||||
|
||||
@handler
|
||||
|
||||
@@ -1,108 +1,34 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
# import asyncio
|
||||
|
||||
# from agent_framework.foundry import FoundryChatClient
|
||||
# from agent_framework import AgentRunUpdateEvent, WorkflowBuilder, WorkflowCompletedEvent
|
||||
# from azure.identity.aio import AzureCliCredential
|
||||
|
||||
# """
|
||||
# Sample: Agents in a workflow with streaming
|
||||
|
||||
# A Writer agent generates content, then a Reviewer agent critiques it.
|
||||
# The workflow uses streaming so you can observe incremental AgentRunUpdateEvent chunks as each agent produces tokens.
|
||||
|
||||
# Purpose:
|
||||
# Show how to wire chat agents directly into a WorkflowBuilder pipeline where agents are auto wrapped as executors.
|
||||
|
||||
# Demonstrate:
|
||||
# - Automatic streaming of agent deltas via AgentRunUpdateEvent.
|
||||
# - A simple console aggregator that groups updates by executor id and prints them as they arrive.
|
||||
# - A final WorkflowCompletedEvent that contains the reviewer outcome after both agents finish.
|
||||
|
||||
# Prerequisites:
|
||||
# - Foundry Agent Service configured, along with the required environment variables.
|
||||
# - Authentication via azure-identity. Use AzureCliCredential and run az login before executing the sample.
|
||||
# - Basic familiarity with WorkflowBuilder, edges, events, and streaming runs.
|
||||
# """
|
||||
|
||||
|
||||
# async def main():
|
||||
# """Build and run a simple two node agent workflow: Writer then Reviewer."""
|
||||
# # Create the Foundry chat client.
|
||||
# async with (
|
||||
# AzureCliCredential() as credential,
|
||||
# FoundryChatClient(async_credential=credential).create_agent(
|
||||
# name="Writer",
|
||||
# instructions=(
|
||||
# "You are an excellent content writer.You create new content and edit contents based on the feedback."
|
||||
# ),
|
||||
# ) as writer_agent,
|
||||
# FoundryChatClient(async_credential=credential).create_agent(
|
||||
# name="Reviewer",
|
||||
# instructions=(
|
||||
# "You are an excellent content reviewer."
|
||||
# "Provide actionable feedback to the writer about the provided content."
|
||||
# "Provide the feedback in the most concise manner possible."
|
||||
# ),
|
||||
# ) as reviewer_agent,
|
||||
# ):
|
||||
# # Build the workflow using the fluent builder.
|
||||
# # Set the start node and connect an edge from writer to reviewer.
|
||||
# workflow = WorkflowBuilder().set_start_executor(writer_agent).add_edge(writer_agent, reviewer_agent).build()
|
||||
|
||||
# # Stream events from the workflow. We aggregate partial token updates per executor for readable output.
|
||||
# completed_event: WorkflowCompletedEvent | None = None
|
||||
# last_executor_id = None
|
||||
|
||||
# async for event in workflow.run_stream(
|
||||
# "Create a slogan for a new electric SUV that is affordable and fun to drive."
|
||||
# ):
|
||||
# if isinstance(event, AgentRunUpdateEvent):
|
||||
# # AgentRunUpdateEvent contains incremental text deltas from the underlying agent.
|
||||
# # Print a prefix when the executor changes, then append updates on the same line.
|
||||
# eid = event.executor_id
|
||||
# if eid != last_executor_id:
|
||||
# if last_executor_id is not None:
|
||||
# print()
|
||||
# print(f"{eid}:", end=" ", flush=True)
|
||||
# last_executor_id = eid
|
||||
# print(event.data, end="", flush=True)
|
||||
# elif isinstance(event, WorkflowCompletedEvent):
|
||||
# # Terminal event with the final reviewer output.
|
||||
# completed_event = event
|
||||
|
||||
# # Print the final consolidated reviewer result.
|
||||
# if completed_event:
|
||||
# print("\n===== Final Output =====")
|
||||
# print(completed_event.data)
|
||||
|
||||
# """
|
||||
# Sample Output:
|
||||
|
||||
# writer_agent: Charge Up Your Journey. Fun, Affordable, Electric.
|
||||
# reviewer_agent: Clear message, but consider highlighting SUV specific benefits
|
||||
# (space, versatility) for stronger impact. Try more vivid language to evoke
|
||||
# excitement. Example: "Big on Space. Big on Fun. Electric for Everyone."
|
||||
# ===== Final Output =====
|
||||
# Clear message, but consider highlighting SUV specific benefits (space, versatility)
|
||||
# for stronger impact. Try more vivid language to evoke excitement. Example:
|
||||
# "Big on Space. Big on Fun. Electric for Everyone."
|
||||
# """
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# asyncio.run(main())
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Any
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from agent_framework import AgentRunUpdateEvent, WorkflowBuilder, WorkflowCompletedEvent
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
|
||||
"""
|
||||
Sample: Agents in a workflow with streaming
|
||||
|
||||
A Writer agent generates content, then a Reviewer agent critiques it.
|
||||
The workflow uses streaming so you can observe incremental AgentRunUpdateEvent chunks as each agent produces tokens.
|
||||
|
||||
Purpose:
|
||||
Show how to wire chat agents directly into a WorkflowBuilder pipeline where agents are auto wrapped as executors.
|
||||
|
||||
Demonstrate:
|
||||
- Automatic streaming of agent deltas via AgentRunUpdateEvent.
|
||||
- A simple console aggregator that groups updates by executor id and prints them as they arrive.
|
||||
- A final WorkflowCompletedEvent that contains the reviewer outcome after both agents finish.
|
||||
|
||||
Prerequisites:
|
||||
- Foundry Agent Service configured, along with the required environment variables.
|
||||
- Authentication via azure-identity. Use AzureCliCredential and run az login before executing the sample.
|
||||
- Basic familiarity with WorkflowBuilder, edges, events, and streaming runs.
|
||||
"""
|
||||
|
||||
|
||||
async def create_foundry_agent() -> tuple[Callable[..., Awaitable[Any]], Callable[[], Awaitable[None]]]:
|
||||
"""Helper method to create a Foundry agent factory and a close function.
|
||||
|
||||
+17
-12
@@ -1,27 +1,31 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from agent_framework import (
|
||||
# Ensure local getting_started package can be imported when running as a script.
|
||||
_SAMPLES_ROOT = Path(__file__).resolve().parents[3]
|
||||
if str(_SAMPLES_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(_SAMPLES_ROOT))
|
||||
|
||||
from agent_framework import ( # noqa: E402
|
||||
ChatMessage,
|
||||
Executor,
|
||||
FunctionCallContent,
|
||||
FunctionResultContent,
|
||||
Role,
|
||||
)
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from agent_framework import (
|
||||
Executor,
|
||||
RequestInfoExecutor,
|
||||
RequestInfoMessage,
|
||||
RequestResponse,
|
||||
Role,
|
||||
WorkflowAgent,
|
||||
WorkflowBuilder,
|
||||
WorkflowContext,
|
||||
handler,
|
||||
)
|
||||
|
||||
from samples.getting_started.workflow.agents.workflow_as_agent_reflection_pattern import (
|
||||
from agent_framework.openai import OpenAIChatClient # noqa: E402
|
||||
from getting_started.workflow.agents.workflow_as_agent_reflection_pattern import ( # noqa: E402
|
||||
ReviewRequest,
|
||||
ReviewResponse,
|
||||
Worker,
|
||||
@@ -56,8 +60,9 @@ class HumanReviewRequest(RequestInfoMessage):
|
||||
class ReviewerWithHumanInTheLoop(Executor):
|
||||
"""Executor that always escalates reviews to a human manager."""
|
||||
|
||||
def __init__(self, worker_id: str, request_info_id: str) -> None:
|
||||
super().__init__()
|
||||
def __init__(self, worker_id: str, request_info_id: str, reviewer_id: str | None = None) -> None:
|
||||
unique_id = reviewer_id or f"{worker_id}-reviewer"
|
||||
super().__init__(id=unique_id)
|
||||
self._worker_id = worker_id
|
||||
self._request_info_id = request_info_id
|
||||
|
||||
@@ -96,8 +101,8 @@ async def main() -> None:
|
||||
# Create executors for the workflow.
|
||||
print("Creating chat client and executors...")
|
||||
mini_chat_client = OpenAIChatClient(ai_model_id="gpt-4.1-nano")
|
||||
worker = Worker(chat_client=mini_chat_client)
|
||||
request_info_executor = RequestInfoExecutor()
|
||||
worker = Worker(id="sub-worker", chat_client=mini_chat_client)
|
||||
request_info_executor = RequestInfoExecutor(id="request_info")
|
||||
reviewer = ReviewerWithHumanInTheLoop(worker_id=worker.id, request_info_id=request_info_executor.id)
|
||||
|
||||
print("Building workflow with Worker ↔ Reviewer cycle...")
|
||||
|
||||
+18
-8
@@ -4,9 +4,19 @@ import asyncio
|
||||
from dataclasses import dataclass
|
||||
from uuid import uuid4
|
||||
|
||||
from agent_framework import AgentRunResponseUpdate, ChatClientProtocol, ChatMessage, Contents, Role
|
||||
from agent_framework import (
|
||||
AgentRunResponseUpdate,
|
||||
AgentRunUpdateEvent,
|
||||
ChatClientProtocol,
|
||||
ChatMessage,
|
||||
Contents,
|
||||
Executor,
|
||||
Role,
|
||||
WorkflowBuilder,
|
||||
WorkflowContext,
|
||||
handler,
|
||||
)
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from agent_framework import AgentRunUpdateEvent, Executor, WorkflowBuilder, WorkflowContext, handler
|
||||
from pydantic import BaseModel
|
||||
|
||||
"""
|
||||
@@ -54,8 +64,8 @@ class ReviewResponse:
|
||||
class Reviewer(Executor):
|
||||
"""Executor that reviews agent responses and provides structured feedback."""
|
||||
|
||||
def __init__(self, chat_client: ChatClientProtocol) -> None:
|
||||
super().__init__()
|
||||
def __init__(self, id: str, chat_client: ChatClientProtocol) -> None:
|
||||
super().__init__(id=id)
|
||||
self._chat_client = chat_client
|
||||
|
||||
@handler
|
||||
@@ -106,8 +116,8 @@ class Reviewer(Executor):
|
||||
class Worker(Executor):
|
||||
"""Executor that generates responses and incorporates feedback when necessary."""
|
||||
|
||||
def __init__(self, chat_client: ChatClientProtocol) -> None:
|
||||
super().__init__()
|
||||
def __init__(self, id: str, chat_client: ChatClientProtocol) -> None:
|
||||
super().__init__(id=id)
|
||||
self._chat_client = chat_client
|
||||
self._pending_requests: dict[str, tuple[ReviewRequest, list[ChatMessage]]] = {}
|
||||
|
||||
@@ -189,8 +199,8 @@ async def main() -> None:
|
||||
print("Creating chat client and executors...")
|
||||
mini_chat_client = OpenAIChatClient(ai_model_id="gpt-4.1-nano")
|
||||
chat_client = OpenAIChatClient(ai_model_id="gpt-4.1")
|
||||
reviewer = Reviewer(chat_client=chat_client)
|
||||
worker = Worker(chat_client=mini_chat_client)
|
||||
reviewer = Reviewer(id="reviewer", chat_client=chat_client)
|
||||
worker = Worker(id="worker", chat_client=mini_chat_client)
|
||||
|
||||
print("Building workflow with Worker ↔ Reviewer cycle...")
|
||||
agent = (
|
||||
|
||||
+483
@@ -0,0 +1,483 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from agent_framework import (
|
||||
AgentExecutor,
|
||||
AgentExecutorRequest,
|
||||
AgentExecutorResponse,
|
||||
ChatMessage,
|
||||
Executor,
|
||||
FileCheckpointStorage,
|
||||
RequestInfoEvent,
|
||||
RequestInfoExecutor,
|
||||
RequestInfoMessage,
|
||||
RequestResponse,
|
||||
Role,
|
||||
WorkflowBuilder,
|
||||
WorkflowCompletedEvent,
|
||||
WorkflowContext,
|
||||
WorkflowRunState,
|
||||
WorkflowStatusEvent,
|
||||
handler,
|
||||
)
|
||||
from agent_framework.azure import AzureChatClient
|
||||
from azure.identity import AzureCliCredential
|
||||
|
||||
# NOTE: the Azure client imports above are real dependencies. When running this
|
||||
# sample outside of Azure-enabled environments you may wish to swap in the
|
||||
# `agent_framework.builtin` chat client or mock the writer executor. We keep the
|
||||
# concrete import here so readers can see an end-to-end configuration.
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agent_framework import Workflow
|
||||
from agent_framework._workflow._checkpoint import WorkflowCheckpoint
|
||||
|
||||
"""
|
||||
Sample: Checkpoint + human-in-the-loop quickstart.
|
||||
|
||||
This getting-started sample keeps the moving pieces to a minimum:
|
||||
|
||||
1. A brief is turned into a consistent prompt for an AI copywriter.
|
||||
2. The copywriter (an `AgentExecutor`) drafts release notes.
|
||||
3. A reviewer gateway routes every draft through `RequestInfoExecutor` so a human
|
||||
can approve or request tweaks.
|
||||
4. The workflow records checkpoints between each superstep so you can stop the
|
||||
program, restart later, and optionally pre-supply human answers on resume.
|
||||
|
||||
Key concepts demonstrated
|
||||
-------------------------
|
||||
- Minimal executor pipeline with checkpoint persistence.
|
||||
- Human-in-the-loop pause/resume by pairing `RequestInfoExecutor` with
|
||||
checkpoint restoration.
|
||||
- Supplying responses at restore time (`run_stream_from_checkpoint(..., responses=...)`).
|
||||
|
||||
Typical pause/resume flow
|
||||
-------------------------
|
||||
1. Run the workflow until a human approval request is emitted.
|
||||
2. If the human is offline, exit the program. A checkpoint with
|
||||
``status=awaiting human response`` now exists.
|
||||
3. Later, restart the script, select that checkpoint, and provide the stored
|
||||
human decision when prompted to pre-supply responses.
|
||||
Doing so applies the answer immediately on resume, so the system does **not**
|
||||
re-emit the same `RequestInfoEvent`.
|
||||
"""
|
||||
|
||||
# Directory used for the sample's temporary checkpoint files. We isolate the
|
||||
# demo artefacts so that repeated runs do not collide with other samples and so
|
||||
# the clean-up step at the end of the script can simply delete the directory.
|
||||
TEMP_DIR = Path(__file__).with_suffix("").parent / "tmp" / "checkpoints_hitl"
|
||||
TEMP_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
class BriefPreparer(Executor):
|
||||
"""Normalises the user brief and sends a single AgentExecutorRequest."""
|
||||
|
||||
# The first executor in the workflow. By keeping it tiny we make it easier
|
||||
# to reason about the state that will later be captured in the checkpoint.
|
||||
# It is responsible for tidying the human-provided brief and kicking off the
|
||||
# agent run with a deterministic prompt structure.
|
||||
|
||||
def __init__(self, id: str, agent_id: str) -> None:
|
||||
super().__init__(id=id)
|
||||
self._agent_id = agent_id
|
||||
|
||||
@handler
|
||||
async def prepare(self, brief: str, ctx: WorkflowContext[AgentExecutorRequest]) -> None:
|
||||
# Collapse errant whitespace so the prompt is stable between runs.
|
||||
normalized = " ".join(brief.split()).strip()
|
||||
if not normalized.endswith("."):
|
||||
normalized += "."
|
||||
# Persist the cleaned brief in shared state so downstream executors and
|
||||
# future checkpoints can recover the original intent.
|
||||
await ctx.set_shared_state("brief", normalized)
|
||||
prompt = (
|
||||
"You are drafting product release notes. Summarise the brief below in two sentences. "
|
||||
"Keep it positive and end with a call to action.\n\n"
|
||||
f"BRIEF: {normalized}"
|
||||
)
|
||||
# Hand the prompt to the writer agent. We always route through the
|
||||
# workflow context so the runtime can capture messages for checkpointing.
|
||||
await ctx.send_message(
|
||||
AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=prompt)], should_respond=True),
|
||||
target_id=self._agent_id,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HumanApprovalRequest(RequestInfoMessage):
|
||||
"""Message sent to the human reviewer via RequestInfoExecutor."""
|
||||
|
||||
# These fields are intentionally simple because they are serialised into
|
||||
# checkpoints. Keeping them primitive types guarantees the new
|
||||
# `pending_requests_from_checkpoint` helper can reconstruct them on resume.
|
||||
prompt: str = ""
|
||||
draft: str = ""
|
||||
iteration: int = 0
|
||||
|
||||
|
||||
class ReviewGateway(Executor):
|
||||
"""Routes agent drafts to humans and optionally back for revisions."""
|
||||
|
||||
def __init__(self, id: str, reviewer_id: str, writer_id: str, finalize_id: str) -> None:
|
||||
super().__init__(id=id)
|
||||
self._reviewer_id = reviewer_id
|
||||
self._writer_id = writer_id
|
||||
self._finalize_id = finalize_id
|
||||
|
||||
@handler
|
||||
async def on_agent_response(
|
||||
self,
|
||||
response: AgentExecutorResponse,
|
||||
ctx: WorkflowContext[HumanApprovalRequest],
|
||||
) -> None:
|
||||
# Capture the agent output so we can surface it to the reviewer and
|
||||
# persist iterations. The `RequestInfoExecutor` relies on this state to
|
||||
# rehydrate when checkpoints are restored.
|
||||
draft = response.agent_run_response.text or ""
|
||||
iteration = int((await ctx.get_state() or {}).get("iteration", 0)) + 1
|
||||
await ctx.set_state({"iteration": iteration, "last_draft": draft})
|
||||
# Emit a human approval request. Because this flows through
|
||||
# RequestInfoExecutor it will pause the workflow until an answer is
|
||||
# supplied either interactively or via pre-supplied responses.
|
||||
await ctx.send_message(
|
||||
HumanApprovalRequest(
|
||||
prompt="Review the draft. Reply 'approve' or provide edit instructions.",
|
||||
draft=draft,
|
||||
iteration=iteration,
|
||||
),
|
||||
target_id=self._reviewer_id,
|
||||
)
|
||||
|
||||
@handler
|
||||
async def on_human_feedback(
|
||||
self,
|
||||
feedback: RequestResponse[HumanApprovalRequest, str],
|
||||
ctx: WorkflowContext[AgentExecutorRequest | str],
|
||||
) -> None:
|
||||
# The RequestResponse wrapper gives us both the human data and the
|
||||
# original request message, even when resuming from checkpoints.
|
||||
reply = (feedback.data or "").strip()
|
||||
state = await ctx.get_state() or {}
|
||||
draft = state.get("last_draft") or (feedback.original_request.draft if feedback.original_request else "")
|
||||
|
||||
if reply.lower() == "approve":
|
||||
# When the human signs off we can short-circuit the workflow and
|
||||
# send the approved draft to the final executor.
|
||||
await ctx.send_message(draft, target_id=self._finalize_id)
|
||||
return
|
||||
|
||||
# Any other response loops us back to the writer with fresh guidance.
|
||||
guidance = reply or "Tighten the copy and emphasise customer benefit."
|
||||
iteration = int(state.get("iteration", 1)) + 1
|
||||
await ctx.set_state({"iteration": iteration, "last_draft": draft})
|
||||
prompt = (
|
||||
"Revise the launch note. Respond with the new copy only.\n\n"
|
||||
f"Previous draft:\n{draft}\n\n"
|
||||
f"Human guidance: {guidance}"
|
||||
)
|
||||
await ctx.send_message(
|
||||
AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=prompt)], should_respond=True),
|
||||
target_id=self._writer_id,
|
||||
)
|
||||
|
||||
|
||||
class FinaliseExecutor(Executor):
|
||||
"""Publishes the approved text."""
|
||||
|
||||
@handler
|
||||
async def publish(self, text: str, ctx: WorkflowContext[Any]) -> None:
|
||||
# Store the output so diagnostics or a UI could fetch the final copy.
|
||||
await ctx.set_state({"published_text": text})
|
||||
# Emit a workflow completion event so the runner stops cleanly.
|
||||
await ctx.add_event(WorkflowCompletedEvent(text))
|
||||
|
||||
|
||||
def create_workflow(*, checkpoint_storage: FileCheckpointStorage | None = None) -> "Workflow":
|
||||
"""Assemble the workflow graph used by both the initial run and resume."""
|
||||
|
||||
# The Azure client is created once so our agent executor can issue calls to
|
||||
# the hosted model. The agent id is stable across runs which keeps
|
||||
# checkpoints deterministic.
|
||||
chat_client = AzureChatClient(credential=AzureCliCredential())
|
||||
writer = AgentExecutor(
|
||||
chat_client.create_agent(
|
||||
instructions="Write concise, warm release notes that sound human and helpful.",
|
||||
),
|
||||
id="writer",
|
||||
)
|
||||
# RequestInfoExecutor is the lynchpin for human-in-the-loop: every draft is
|
||||
# routed through it so checkpoints can pause while waiting for responses.
|
||||
review = RequestInfoExecutor(id="request_info")
|
||||
finalise = FinaliseExecutor(id="finalise")
|
||||
gateway = ReviewGateway(
|
||||
id="review_gateway",
|
||||
reviewer_id=review.id,
|
||||
writer_id=writer.id,
|
||||
finalize_id=finalise.id,
|
||||
)
|
||||
prepare = BriefPreparer(id="prepare_brief", agent_id=writer.id)
|
||||
|
||||
# Wire the workflow DAG. Edges mirror the numbered steps described in the
|
||||
# module docstring. Because `WorkflowBuilder` is declarative, reading these
|
||||
# edges is often the quickest way to understand execution order.
|
||||
builder = (
|
||||
WorkflowBuilder(max_iterations=6)
|
||||
.set_start_executor(prepare)
|
||||
.add_edge(prepare, writer)
|
||||
.add_edge(writer, gateway)
|
||||
.add_edge(gateway, review)
|
||||
.add_edge(review, gateway) # human resumes loop
|
||||
.add_edge(gateway, writer) # revisions
|
||||
.add_edge(gateway, finalise)
|
||||
)
|
||||
# Opt-in to persistence when the caller provides storage. The workflow
|
||||
# object itself is identical whether or not checkpointing is enabled.
|
||||
if checkpoint_storage:
|
||||
builder = builder.with_checkpointing(checkpoint_storage=checkpoint_storage)
|
||||
return builder.build()
|
||||
|
||||
|
||||
def _render_checkpoint_summary(checkpoints: list["WorkflowCheckpoint"]) -> None:
|
||||
"""Pretty-print saved checkpoints with the new framework summaries."""
|
||||
|
||||
print("\nCheckpoint summary:")
|
||||
for summary in [
|
||||
RequestInfoExecutor.checkpoint_summary(cp) for cp in sorted(checkpoints, key=lambda c: c.timestamp)
|
||||
]:
|
||||
# Compose a single line per checkpoint so the user can scan the output
|
||||
# and pick the resume point that still has outstanding human work.
|
||||
line = (
|
||||
f"- {summary.checkpoint_id} | iter={summary.iteration_count} "
|
||||
f"| targets={summary.targets} | states={summary.executor_states}"
|
||||
)
|
||||
if summary.status:
|
||||
line += f" | status={summary.status}"
|
||||
if summary.draft_preview:
|
||||
line += f" | draft_preview={summary.draft_preview}"
|
||||
if summary.pending_requests:
|
||||
line += f" | pending_request_id={summary.pending_requests[0].request_id}"
|
||||
print(line)
|
||||
|
||||
|
||||
def _print_events(events: list[Any]) -> tuple[WorkflowCompletedEvent | None, list[tuple[str, HumanApprovalRequest]]]:
|
||||
"""Echo workflow events to the console and collect outstanding requests."""
|
||||
|
||||
completed: WorkflowCompletedEvent | None = None
|
||||
requests: list[tuple[str, HumanApprovalRequest]] = []
|
||||
|
||||
for event in events:
|
||||
print(f"Event: {event}")
|
||||
if isinstance(event, WorkflowCompletedEvent):
|
||||
completed = event
|
||||
elif isinstance(event, RequestInfoEvent) and isinstance(event.data, HumanApprovalRequest):
|
||||
# Capture pending human approvals so the caller can ask the user for
|
||||
# input after the current batch of events is processed.
|
||||
requests.append((event.request_id, event.data))
|
||||
elif isinstance(event, WorkflowStatusEvent) and event.state in {
|
||||
WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS,
|
||||
WorkflowRunState.IDLE_WITH_PENDING_REQUESTS,
|
||||
}:
|
||||
print(f"Workflow state: {event.state.name}")
|
||||
|
||||
return completed, requests
|
||||
|
||||
|
||||
def _prompt_for_responses(requests: list[tuple[str, HumanApprovalRequest]]) -> dict[str, str] | None:
|
||||
"""Interactive CLI prompt for any live RequestInfo requests."""
|
||||
|
||||
if not requests:
|
||||
return None
|
||||
answers: dict[str, str] = {}
|
||||
for request_id, request in requests:
|
||||
# Keep the prompt conversational so testers can use the script without
|
||||
# memorising the workflow APIs.
|
||||
print("\n=== Human approval needed ===")
|
||||
print(f"request_id: {request_id}")
|
||||
if request.iteration:
|
||||
print(f"Iteration: {request.iteration}")
|
||||
print(request.prompt)
|
||||
print("Draft: \n---\n" + request.draft + "\n---")
|
||||
answer = input("Type 'approve' or enter revision guidance (or 'exit' to quit): ").strip() # noqa: ASYNC250
|
||||
if answer.lower() == "exit":
|
||||
raise SystemExit("Stopped by user.")
|
||||
answers[request_id] = answer
|
||||
return answers
|
||||
|
||||
|
||||
def _maybe_pre_supply_responses(cp: "WorkflowCheckpoint") -> dict[str, str] | None:
|
||||
"""Offer to collect responses before resuming a checkpoint."""
|
||||
|
||||
pending = RequestInfoExecutor.pending_requests_from_checkpoint(cp)
|
||||
if not pending:
|
||||
return None
|
||||
|
||||
print(
|
||||
"This checkpoint still has pending human input. Provide the responses now so the resume step "
|
||||
"applies them immediately and does not re-emit the original RequestInfo event."
|
||||
)
|
||||
choice = input("Pre-supply responses for this checkpoint? [y/N]: ").strip().lower() # noqa: ASYNC250
|
||||
if choice not in {"y", "yes"}:
|
||||
return None
|
||||
|
||||
answers: dict[str, str] = {}
|
||||
for item in pending:
|
||||
iteration = item.iteration or 0
|
||||
print(f"\nPending draft (iteration {iteration} | request_id={item.request_id}):")
|
||||
draft_text = (item.draft or "").strip()
|
||||
if draft_text:
|
||||
# The shortened preview in the summary may truncate text; here we
|
||||
# show the full draft so the reviewer can make an informed choice.
|
||||
print("Draft:\n---\n" + draft_text + "\n---")
|
||||
else:
|
||||
print("Draft: [not captured in checkpoint payload - refer to your notes/log]")
|
||||
prompt_text = (item.prompt or "Review the draft").strip()
|
||||
print(prompt_text)
|
||||
answer = input("Response ('approve' or guidance, 'exit' to abort): ").strip() # noqa: ASYNC250
|
||||
if answer.lower() == "exit":
|
||||
raise SystemExit("Resume aborted by user.")
|
||||
answers[item.request_id] = answer
|
||||
return answers
|
||||
|
||||
|
||||
async def _consume(stream: AsyncIterable[Any]) -> list[Any]:
|
||||
"""Materialise an async event stream into a list."""
|
||||
|
||||
return [event async for event in stream]
|
||||
|
||||
|
||||
async def run_interactive_session(workflow: "Workflow", initial_message: str) -> WorkflowCompletedEvent | None:
|
||||
"""Run the workflow until it either finishes or pauses for human input."""
|
||||
|
||||
pending_responses: dict[str, str] | None = None
|
||||
completed: WorkflowCompletedEvent | None = None
|
||||
first = True
|
||||
|
||||
while completed is None:
|
||||
if first:
|
||||
# Kick off the workflow with the initial brief. The returned events
|
||||
# include RequestInfo events when the agent produces a draft.
|
||||
events = await _consume(workflow.run_stream(initial_message))
|
||||
first = False
|
||||
elif pending_responses:
|
||||
# Feed any answers the user just typed back into the workflow.
|
||||
events = await _consume(workflow.send_responses_streaming(pending_responses))
|
||||
else:
|
||||
break
|
||||
|
||||
completed, requests = _print_events(events)
|
||||
pending_responses = _prompt_for_responses(requests)
|
||||
|
||||
return completed
|
||||
|
||||
|
||||
async def resume_from_checkpoint(
|
||||
workflow: "Workflow",
|
||||
checkpoint_id: str,
|
||||
storage: FileCheckpointStorage,
|
||||
pre_supplied: dict[str, str] | None,
|
||||
) -> None:
|
||||
"""Resume a stored checkpoint and continue until completion or another pause."""
|
||||
|
||||
print(f"\nResuming from checkpoint: {checkpoint_id}")
|
||||
events = await _consume(
|
||||
workflow.run_stream_from_checkpoint(
|
||||
checkpoint_id,
|
||||
checkpoint_storage=storage,
|
||||
responses=pre_supplied,
|
||||
)
|
||||
)
|
||||
completed, requests = _print_events(events)
|
||||
if pre_supplied and not requests and completed is None:
|
||||
# When the checkpoint only needed the provided answers we let the user
|
||||
# know the workflow is waiting for the next superstep (usually another
|
||||
# agent response).
|
||||
print("Pre-supplied responses applied automatically; workflow is now waiting for the next step.")
|
||||
|
||||
pending = _prompt_for_responses(requests)
|
||||
while completed is None and pending:
|
||||
events = await _consume(workflow.send_responses_streaming(pending))
|
||||
completed, requests = _print_events(events)
|
||||
pending = _prompt_for_responses(requests)
|
||||
|
||||
if completed:
|
||||
print(f"Workflow completed with: {completed.data}")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Entry point used by both the initial run and subsequent resumes."""
|
||||
|
||||
for file in TEMP_DIR.glob("*.json"):
|
||||
# Start each execution with a clean slate so the demonstration is
|
||||
# deterministic even if the directory had stale checkpoints.
|
||||
file.unlink()
|
||||
|
||||
storage = FileCheckpointStorage(storage_path=TEMP_DIR)
|
||||
workflow = create_workflow(checkpoint_storage=storage)
|
||||
|
||||
brief = (
|
||||
"Introduce our limited edition smart coffee grinder. Mention the $249 price, highlight the "
|
||||
"sensor that auto-adjusts the grind, and invite customers to pre-order on the website."
|
||||
)
|
||||
|
||||
print("Running workflow (human approval required)...")
|
||||
completed = await run_interactive_session(workflow, initial_message=brief)
|
||||
if completed:
|
||||
print(f"Initial run completed with final copy: {completed.data}")
|
||||
else:
|
||||
print("Initial run paused for human input.")
|
||||
|
||||
checkpoints = await storage.list_checkpoints()
|
||||
if not checkpoints:
|
||||
print("No checkpoints recorded.")
|
||||
return
|
||||
|
||||
# Show the user what is available before we prompt for the index. The
|
||||
# summary helper keeps this output consistent with other tooling.
|
||||
_render_checkpoint_summary(checkpoints)
|
||||
|
||||
sorted_cps = sorted(checkpoints, key=lambda c: c.timestamp)
|
||||
print("\nAvailable checkpoints:")
|
||||
for idx, cp in enumerate(sorted_cps):
|
||||
print(f" [{idx}] id={cp.checkpoint_id} iter={cp.iteration_count}")
|
||||
|
||||
# For the pause/resume demo we typically pick the latest checkpoint whose summary
|
||||
# status reads "awaiting human response" - that is the saved state that proves the
|
||||
# workflow can rehydrate, collect the pending answer, and continue after a break.
|
||||
selection = input("\nResume from which checkpoint? (press Enter to skip): ").strip() # noqa: ASYNC250
|
||||
if not selection:
|
||||
print("No resume selected. Exiting.")
|
||||
return
|
||||
|
||||
try:
|
||||
idx = int(selection)
|
||||
except ValueError:
|
||||
print("Invalid input; exiting.")
|
||||
return
|
||||
|
||||
if not 0 <= idx < len(sorted_cps):
|
||||
print("Index out of range; exiting.")
|
||||
return
|
||||
|
||||
chosen = sorted_cps[idx]
|
||||
summary = RequestInfoExecutor.checkpoint_summary(chosen)
|
||||
if summary.status == "completed":
|
||||
print("Selected checkpoint already reflects a completed workflow; nothing to resume.")
|
||||
return
|
||||
|
||||
# If the user wants, capture their decisions now so the resume call can
|
||||
# push them into the workflow and avoid re-prompting.
|
||||
pre_responses = _maybe_pre_supply_responses(chosen)
|
||||
|
||||
resumed_workflow = create_workflow()
|
||||
# Resume with a fresh workflow instance. The checkpoint carries the
|
||||
# persistent state while this object holds the runtime wiring.
|
||||
await resume_from_checkpoint(resumed_workflow, chosen.checkpoint_id, storage, pre_responses)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -3,7 +3,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from agent_framework import (
|
||||
AgentExecutor,
|
||||
@@ -12,6 +12,7 @@ from agent_framework import (
|
||||
ChatMessage,
|
||||
Executor,
|
||||
FileCheckpointStorage,
|
||||
RequestInfoExecutor,
|
||||
Role,
|
||||
WorkflowBuilder,
|
||||
WorkflowCompletedEvent,
|
||||
@@ -21,6 +22,10 @@ from agent_framework import (
|
||||
from agent_framework.azure import AzureChatClient
|
||||
from azure.identity import AzureCliCredential
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agent_framework import Workflow
|
||||
from agent_framework._workflow._checkpoint import WorkflowCheckpoint
|
||||
|
||||
"""
|
||||
Sample: Checkpointing and Resuming a Workflow (with an Agent stage)
|
||||
|
||||
@@ -87,7 +92,7 @@ class UpperCaseExecutor(Executor):
|
||||
class SubmitToLowerAgent(Executor):
|
||||
"""Builds an AgentExecutorRequest to send to the lowercasing agent while keeping shared-state visibility."""
|
||||
|
||||
def __init__(self, agent_id: str, id: str | None = None):
|
||||
def __init__(self, id: str, agent_id: str):
|
||||
super().__init__(id=id)
|
||||
self._agent_id = agent_id
|
||||
|
||||
@@ -132,10 +137,6 @@ class FinalizeFromAgent(Executor):
|
||||
class ReverseTextExecutor(Executor):
|
||||
"""Reverses the input text and persists local state."""
|
||||
|
||||
def __init__(self, id: str):
|
||||
"""Initialize the executor with an ID."""
|
||||
super().__init__(id=id)
|
||||
|
||||
@handler
|
||||
async def reverse_text(self, text: str, ctx: WorkflowContext[str]) -> None:
|
||||
result = text[::-1]
|
||||
@@ -154,15 +155,10 @@ class ReverseTextExecutor(Executor):
|
||||
await ctx.send_message(result)
|
||||
|
||||
|
||||
async def main():
|
||||
# Clear existing checkpoints in this sample directory for a clean run.
|
||||
checkpoint_dir = Path(TEMP_DIR)
|
||||
for file in checkpoint_dir.glob("*.json"):
|
||||
file.unlink()
|
||||
|
||||
def create_workflow(checkpoint_storage: FileCheckpointStorage) -> "Workflow":
|
||||
# Instantiate the pipeline executors.
|
||||
upper_case_executor = UpperCaseExecutor(id="upper_case_executor")
|
||||
reverse_text_executor = ReverseTextExecutor(id="reverse_text_executor")
|
||||
upper_case_executor = UpperCaseExecutor(id="upper-case")
|
||||
reverse_text_executor = ReverseTextExecutor(id="reverse-text")
|
||||
|
||||
# Configure the agent stage that lowercases the text.
|
||||
chat_client = AzureChatClient(credential=AzureCliCredential())
|
||||
@@ -174,14 +170,11 @@ async def main():
|
||||
)
|
||||
|
||||
# Bridge to the agent and terminalization stage.
|
||||
submit_lower = SubmitToLowerAgent(agent_id=lower_agent.id, id="submit_lower")
|
||||
submit_lower = SubmitToLowerAgent(id="submit_lower", agent_id=lower_agent.id)
|
||||
finalize = FinalizeFromAgent(id="finalize")
|
||||
|
||||
# Backing store for checkpoints written by with_checkpointing.
|
||||
checkpoint_storage = FileCheckpointStorage(storage_path=TEMP_DIR)
|
||||
|
||||
# Build the workflow with checkpointing enabled.
|
||||
workflow = (
|
||||
return (
|
||||
WorkflowBuilder(max_iterations=5)
|
||||
.add_edge(upper_case_executor, reverse_text_executor) # Uppercase -> Reverse
|
||||
.add_edge(reverse_text_executor, submit_lower) # Reverse -> Build Agent request
|
||||
@@ -192,6 +185,40 @@ async def main():
|
||||
.build()
|
||||
)
|
||||
|
||||
def _render_checkpoint_summary(checkpoints: list["WorkflowCheckpoint"]) -> None:
|
||||
"""Display human-friendly checkpoint metadata using framework summaries."""
|
||||
|
||||
if not checkpoints:
|
||||
return
|
||||
|
||||
print("\nCheckpoint summary:")
|
||||
for cp in sorted(checkpoints, key=lambda c: c.timestamp):
|
||||
summary = RequestInfoExecutor.checkpoint_summary(cp)
|
||||
msg_count = sum(len(v) for v in cp.messages.values())
|
||||
state_keys = sorted(cp.executor_states.keys())
|
||||
orig = cp.shared_state.get("original_input")
|
||||
upper = cp.shared_state.get("upper_output")
|
||||
|
||||
line = (
|
||||
f"- {summary.checkpoint_id} | iter={summary.iteration_count} | messages={msg_count} | states={state_keys}"
|
||||
)
|
||||
if summary.status:
|
||||
line += f" | status={summary.status}"
|
||||
line += f" | shared_state: original_input='{orig}', upper_output='{upper}'"
|
||||
print(line)
|
||||
|
||||
|
||||
async def main():
|
||||
# Clear existing checkpoints in this sample directory for a clean run.
|
||||
checkpoint_dir = Path(TEMP_DIR)
|
||||
for file in checkpoint_dir.glob("*.json"):
|
||||
file.unlink()
|
||||
|
||||
# Backing store for checkpoints written by with_checkpointing.
|
||||
checkpoint_storage = FileCheckpointStorage(storage_path=TEMP_DIR)
|
||||
|
||||
workflow = create_workflow(checkpoint_storage=checkpoint_storage)
|
||||
|
||||
# Run the full workflow once and observe events as they stream.
|
||||
print("Running workflow with initial message...")
|
||||
async for event in workflow.run_stream(message="hello world"):
|
||||
@@ -206,26 +233,20 @@ async def main():
|
||||
# All checkpoints created by this run share the same workflow_id.
|
||||
workflow_id = all_checkpoints[0].workflow_id
|
||||
|
||||
# Dump a quick summary including shared_state keys to illustrate what persisted.
|
||||
print("\nCheckpoint summary:")
|
||||
for cp in sorted(all_checkpoints, key=lambda c: c.timestamp):
|
||||
msg_count = sum(len(v) for v in cp.messages.values())
|
||||
state_keys = sorted(list(cp.executor_states.keys())) if hasattr(cp, "executor_states") else []
|
||||
orig = cp.shared_state.get("original_input") if hasattr(cp, "shared_state") else None
|
||||
upper = cp.shared_state.get("upper_output") if hasattr(cp, "shared_state") else None
|
||||
print(
|
||||
f"- {cp.checkpoint_id} | "
|
||||
f"iter={cp.iteration_count} | messages={msg_count} | states={state_keys} | "
|
||||
f"shared_state: original_input='{orig}', upper_output='{upper}'"
|
||||
)
|
||||
_render_checkpoint_summary(all_checkpoints)
|
||||
|
||||
# Offer an interactive selection of checkpoints to resume from.
|
||||
sorted_cps = sorted([cp for cp in all_checkpoints if cp.workflow_id == workflow_id], key=lambda c: c.timestamp)
|
||||
|
||||
print("\nAvailable checkpoints to resume from:")
|
||||
for idx, cp in enumerate(sorted_cps):
|
||||
summary = RequestInfoExecutor.checkpoint_summary(cp)
|
||||
line = f" [{idx}] id={summary.checkpoint_id} iter={summary.iteration_count}"
|
||||
if summary.status:
|
||||
line += f" status={summary.status}"
|
||||
msg_count = sum(len(v) for v in cp.messages.values())
|
||||
print(f" [{idx}] id={cp.checkpoint_id} iter={cp.iteration_count} messages={msg_count}")
|
||||
line += f" messages={msg_count}"
|
||||
print(line)
|
||||
|
||||
user_input = input(
|
||||
"\nEnter checkpoint index (or paste checkpoint id) to resume from, or press Enter to skip resume: "
|
||||
@@ -256,15 +277,7 @@ async def main():
|
||||
# You can reuse the same workflow graph definition and resume from a prior checkpoint.
|
||||
# This second workflow instance does not enable checkpointing to show that resumption
|
||||
# reads from stored state but need not write new checkpoints.
|
||||
new_workflow = (
|
||||
WorkflowBuilder(max_iterations=5)
|
||||
.add_edge(upper_case_executor, reverse_text_executor)
|
||||
.add_edge(reverse_text_executor, submit_lower)
|
||||
.add_edge(submit_lower, lower_agent)
|
||||
.add_edge(lower_agent, finalize)
|
||||
.set_start_executor(upper_case_executor)
|
||||
.build()
|
||||
)
|
||||
new_workflow = create_workflow(checkpoint_storage=checkpoint_storage)
|
||||
|
||||
print(f"\nResuming from checkpoint: {chosen_cp_id}")
|
||||
async for event in new_workflow.run_stream_from_checkpoint(chosen_cp_id, checkpoint_storage=checkpoint_storage):
|
||||
@@ -275,15 +288,15 @@ async def main():
|
||||
|
||||
Running workflow with initial message...
|
||||
UpperCaseExecutor: 'hello world' -> 'HELLO WORLD'
|
||||
Event: ExecutorInvokedEvent(executor_id=upper_case_executor)
|
||||
Event: ExecutorInvokeEvent(executor_id=upper_case_executor)
|
||||
Event: ExecutorCompletedEvent(executor_id=upper_case_executor)
|
||||
ReverseTextExecutor: 'HELLO WORLD' -> 'DLROW OLLEH'
|
||||
Event: ExecutorInvokedEvent(executor_id=reverse_text_executor)
|
||||
Event: ExecutorInvokeEvent(executor_id=reverse_text_executor)
|
||||
Event: ExecutorCompletedEvent(executor_id=reverse_text_executor)
|
||||
LowerAgent (shared_state): original_input='hello world', upper_output='HELLO WORLD'
|
||||
Event: ExecutorInvokedEvent(executor_id=submit_lower)
|
||||
Event: ExecutorInvokedEvent(executor_id=lower_agent)
|
||||
Event: ExecutorInvokedEvent(executor_id=finalize)
|
||||
Event: ExecutorInvokeEvent(executor_id=submit_lower)
|
||||
Event: ExecutorInvokeEvent(executor_id=lower_agent)
|
||||
Event: ExecutorInvokeEvent(executor_id=finalize)
|
||||
Event: WorkflowCompletedEvent(data=dlrow olleh)
|
||||
|
||||
Checkpoint summary:
|
||||
@@ -300,9 +313,9 @@ async def main():
|
||||
|
||||
Resuming from checkpoint: a78c345a-e5d9-45ba-82c0-cb725452d91b
|
||||
LowerAgent (shared_state): original_input='hello world', upper_output='HELLO WORLD'
|
||||
Resumed Event: ExecutorInvokedEvent(executor_id=submit_lower)
|
||||
Resumed Event: ExecutorInvokedEvent(executor_id=lower_agent)
|
||||
Resumed Event: ExecutorInvokedEvent(executor_id=finalize)
|
||||
Resumed Event: ExecutorInvokeEvent(executor_id=submit_lower)
|
||||
Resumed Event: ExecutorInvokeEvent(executor_id=lower_agent)
|
||||
Resumed Event: ExecutorInvokeEvent(executor_id=finalize)
|
||||
Resumed Event: WorkflowCompletedEvent(data=dlrow olleh)
|
||||
""" # noqa: E501
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ class GuessNumberExecutor(Executor):
|
||||
|
||||
def __init__(self, bound: tuple[int, int], id: str | None = None):
|
||||
"""Initialize the executor with a target number."""
|
||||
super().__init__(id=id)
|
||||
super().__init__(id=id or "guess_number")
|
||||
self._lower = bound[0]
|
||||
self._upper = bound[1]
|
||||
|
||||
@@ -83,7 +83,7 @@ class SubmitToJudgeAgent(Executor):
|
||||
"""Send the numeric guess to a judge agent which replies ABOVE/BELOW/MATCHED."""
|
||||
|
||||
def __init__(self, judge_agent_id: str, target: int, id: str | None = None):
|
||||
super().__init__(id=id)
|
||||
super().__init__(id=id or "submit_to_judge")
|
||||
self._judge_agent_id = judge_agent_id
|
||||
self._target = target
|
||||
|
||||
|
||||
+1
-1
@@ -87,7 +87,7 @@ class TurnManager(Executor):
|
||||
"""
|
||||
|
||||
def __init__(self, id: str | None = None):
|
||||
super().__init__(id=id)
|
||||
super().__init__(id=id or "turn_manager")
|
||||
|
||||
@handler
|
||||
async def start(self, _: str, ctx: WorkflowContext[AgentExecutorRequest]) -> None:
|
||||
|
||||
@@ -43,7 +43,7 @@ class DispatchToExperts(Executor):
|
||||
"""Dispatches the incoming prompt to all expert agent executors for parallel processing (fan out)."""
|
||||
|
||||
def __init__(self, expert_ids: list[str], id: str | None = None):
|
||||
super().__init__(id)
|
||||
super().__init__(id=id or "dispatch_to_experts")
|
||||
self._expert_ids = expert_ids
|
||||
|
||||
@handler
|
||||
@@ -71,7 +71,7 @@ class AggregateInsights(Executor):
|
||||
"""Aggregates expert agent responses into a single consolidated result (fan in)."""
|
||||
|
||||
def __init__(self, expert_ids: list[str], id: str | None = None):
|
||||
super().__init__(id)
|
||||
super().__init__(id=id or "aggregate_insights")
|
||||
self._expert_ids = expert_ids
|
||||
|
||||
@handler
|
||||
|
||||
@@ -61,7 +61,7 @@ class Split(Executor):
|
||||
|
||||
def __init__(self, map_executor_ids: list[str], id: str | None = None):
|
||||
"""Store mapper ids so we can assign non overlapping ranges per mapper."""
|
||||
super().__init__(id)
|
||||
super().__init__(id=id or "split")
|
||||
self._map_executor_ids = map_executor_ids
|
||||
|
||||
@handler
|
||||
@@ -145,7 +145,7 @@ class Shuffle(Executor):
|
||||
|
||||
def __init__(self, reducer_ids: list[str], id: str | None = None):
|
||||
"""Remember reducer ids so we can partition work deterministically."""
|
||||
super().__init__(id)
|
||||
super().__init__(id=id or "shuffle")
|
||||
self._reducer_ids = reducer_ids
|
||||
|
||||
@handler
|
||||
|
||||
+2
-2
@@ -40,7 +40,7 @@ class DispatchToExperts(Executor):
|
||||
"""Dispatches the incoming prompt to all expert agent executors (fan-out)."""
|
||||
|
||||
def __init__(self, expert_ids: list[str], id: str | None = None):
|
||||
super().__init__(id)
|
||||
super().__init__(id=id or "dispatch_to_experts")
|
||||
self._expert_ids = expert_ids
|
||||
|
||||
@handler
|
||||
@@ -67,7 +67,7 @@ class AggregateInsights(Executor):
|
||||
"""Aggregates expert agent responses into a single consolidated result (fan-in)."""
|
||||
|
||||
def __init__(self, expert_ids: list[str], id: str | None = None):
|
||||
super().__init__(id)
|
||||
super().__init__(id=id or "aggregate_insights")
|
||||
self._expert_ids = expert_ids
|
||||
|
||||
@handler
|
||||
|
||||
Reference in New Issue
Block a user