From aba094b5cf145fdc2365d7b5362027c0f93ad44f Mon Sep 17 00:00:00 2001 From: Evan Mattson <35585003+moonbox3@users.noreply.github.com> Date: Sat, 20 Sep 2025 03:57:09 +0900 Subject: [PATCH] Python: [BREAKING] Python: Make executor ID required, improvements around handling rehydrating checkpoints (#832) * Make executor ID required, improvements around handling rehydrating checkpoints. * Duplicate executor validation added * fix remaining issues --------- Co-authored-by: Eric Zhu --- .../agent_framework/_workflow/__init__.py | 2 + .../agent_framework/_workflow/__init__.pyi | 2 + .../agent_framework/_workflow/_concurrent.py | 5 +- .../agent_framework/_workflow/_executor.py | 548 +++++++++++++++++- .../agent_framework/_workflow/_magentic.py | 2 +- .../main/agent_framework/_workflow/_runner.py | 63 +- .../_workflow/_runner_context.py | 77 ++- .../_workflow/_typing_utils.py | 56 +- .../agent_framework/_workflow/_validation.py | 62 +- .../agent_framework/_workflow/_workflow.py | 146 ++++- .../tests/workflow/test_checkpoint_decode.py | 52 ++ .../workflow/test_checkpoint_validation.py | 73 +++ .../main/tests/workflow/test_concurrent.py | 14 + .../main/tests/workflow/test_executor.py | 12 +- .../test_request_info_executor_rehydrate.py | 223 +++++++ .../main/tests/workflow/test_serialization.py | 5 +- .../main/tests/workflow/test_sub_workflow.py | 10 +- .../main/tests/workflow/test_validation.py | 12 + .../main/tests/workflow/test_workflow.py | 4 +- .../tests/workflow/test_workflow_agent.py | 27 +- .../tests/workflow/test_workflow_states.py | 59 ++ .../getting_started/workflow/README.md | 1 + .../_start-here/step1_executors_and_edges.py | 2 +- .../agents/foundry_chat_agents_streaming.py | 118 +--- .../workflow_as_agent_human_in_the_loop.py | 29 +- .../workflow_as_agent_reflection_pattern.py | 26 +- .../checkpoint_with_human_in_the_loop.py | 483 +++++++++++++++ .../checkpoint/checkpoint_with_resume.py | 111 ++-- .../workflow/control-flow/simple_loop.py | 4 +- .../guessing_game_with_human_input.py | 2 +- .../parallelism/fan_out_fan_in_edges.py | 4 +- .../map_reduce_and_visualization.py | 4 +- .../concurrent_with_visualization.py | 4 +- 33 files changed, 1967 insertions(+), 275 deletions(-) create mode 100644 python/packages/main/tests/workflow/test_checkpoint_decode.py create mode 100644 python/packages/main/tests/workflow/test_checkpoint_validation.py create mode 100644 python/packages/main/tests/workflow/test_request_info_executor_rehydrate.py create mode 100644 python/samples/getting_started/workflow/checkpoint/checkpoint_with_human_in_the_loop.py diff --git a/python/packages/main/agent_framework/_workflow/__init__.py b/python/packages/main/agent_framework/_workflow/__init__.py index 8848794a74..995d379d64 100644 --- a/python/packages/main/agent_framework/_workflow/__init__.py +++ b/python/packages/main/agent_framework/_workflow/__init__.py @@ -91,6 +91,7 @@ from ._shared_state import SharedState from ._telemetry import EdgeGroupDeliveryStatus, WorkflowTracer, workflow_tracer from ._validation import ( EdgeDuplicationError, + ExecutorDuplicationError, GraphConnectivityError, HandlerOutputAnnotationError, TypeCompatibilityError, @@ -118,6 +119,7 @@ __all__ = [ "EdgeGroupDeliveryStatus", "Executor", "ExecutorCompletedEvent", + "ExecutorDuplicationError", "ExecutorEvent", "ExecutorFailedEvent", "ExecutorInvokedEvent", diff --git a/python/packages/main/agent_framework/_workflow/__init__.pyi b/python/packages/main/agent_framework/_workflow/__init__.pyi index 90d2a4322e..7bc212c73e 100644 --- a/python/packages/main/agent_framework/_workflow/__init__.pyi +++ b/python/packages/main/agent_framework/_workflow/__init__.pyi @@ -87,6 +87,7 @@ from ._shared_state import SharedState from ._telemetry import EdgeGroupDeliveryStatus, WorkflowTracer, workflow_tracer from ._validation import ( EdgeDuplicationError, + ExecutorDuplicationError, GraphConnectivityError, HandlerOutputAnnotationError, TypeCompatibilityError, @@ -114,6 +115,7 @@ __all__ = [ "EdgeGroupDeliveryStatus", "Executor", "ExecutorCompletedEvent", + "ExecutorDuplicationError", "ExecutorEvent", "ExecutorFailedEvent", "ExecutorInvokedEvent", diff --git a/python/packages/main/agent_framework/_workflow/_concurrent.py b/python/packages/main/agent_framework/_workflow/_concurrent.py index 670a555f50..42b567aaab 100644 --- a/python/packages/main/agent_framework/_workflow/_concurrent.py +++ b/python/packages/main/agent_framework/_workflow/_concurrent.py @@ -145,7 +145,10 @@ class _CallbackAggregator(Executor): """ def __init__(self, callback: Callable[..., Any], id: str | None = None) -> None: - super().__init__(id) + derived_id = getattr(callback, "__name__", "") or "" + if not derived_id or derived_id == "": + derived_id = f"{type(self).__name__}_unnamed" + super().__init__(id or derived_id) self._callback = callback self._param_count = len(inspect.signature(callback).parameters) diff --git a/python/packages/main/agent_framework/_workflow/_executor.py b/python/packages/main/agent_framework/_workflow/_executor.py index 8a563e21f9..cb066b383d 100644 --- a/python/packages/main/agent_framework/_workflow/_executor.py +++ b/python/packages/main/agent_framework/_workflow/_executor.py @@ -2,12 +2,15 @@ import contextlib import functools +import importlib import inspect +import logging import uuid -from collections.abc import Awaitable, Callable -from dataclasses import dataclass, field +from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence +from dataclasses import asdict, dataclass, field, fields, is_dataclass +from textwrap import shorten from types import UnionType -from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, get_args, get_origin, overload +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, Union, cast, get_args, get_origin, overload if TYPE_CHECKING: from ._workflow import Workflow @@ -17,6 +20,7 @@ from pydantic import Field from agent_framework import AgentProtocol, AgentRunResponse, AgentRunResponseUpdate, AgentThread, ChatMessage from agent_framework._pydantic import AFBaseModel +from ._checkpoint import WorkflowCheckpoint from ._events import ( AgentRunEvent, AgentRunUpdateEvent, @@ -25,35 +29,63 @@ from ._events import ( RequestInfoEvent, _framework_event_origin, # pyright: ignore[reportPrivateUsage] ) +from ._runner_context import _decode_checkpoint_value from ._typing_utils import is_instance_of from ._workflow_context import WorkflowContext +logger = logging.getLogger(__name__) + # region Executor +@dataclass +class PendingRequestDetails: + """Lightweight information about a pending request captured in a checkpoint.""" + + request_id: str + prompt: str | None = None + draft: str | None = None + iteration: int | None = None + source_executor_id: str | None = None + original_request: "RequestInfoMessage | dict[str, Any] | None" = None + + +@dataclass +class WorkflowCheckpointSummary: + """Human-readable summary of a workflow checkpoint.""" + + checkpoint_id: str + iteration_count: int + targets: list[str] + executor_states: list[str] + status: str + draft_preview: str | None + pending_requests: list[PendingRequestDetails] + + class Executor(AFBaseModel): """An executor is a component that processes messages in a workflow.""" # Provide a default so static analyzers (e.g., pyright) don't require passing `id`. # Runtime still sets a concrete value in __init__. id: str = Field( - default_factory=lambda: str(uuid.uuid4()), + ..., min_length=1, description="Unique identifier for the executor", ) type_: str = Field(default="", alias="type", description="The type of executor, corresponding to the class name") - def __init__(self, id: str | None = None, **kwargs: Any) -> None: + def __init__(self, id: str, **kwargs: Any) -> None: """Initialize the executor with a unique identifier. Args: - id: A unique identifier for the executor. If None, a new ID will be generated - following the format /. + id: A unique identifier for the executor. kwargs: Additional keyword arguments. Unused in this implementation. """ - executor_id = f"{self.__class__.__name__}/{uuid.uuid4()}" if id is None else id + if not id: + raise ValueError("Executor ID must be a non-empty string.") - kwargs.update({"id": executor_id}) + kwargs.update({"id": id}) if "type" not in kwargs and "type_" not in kwargs: kwargs["type_"] = self.__class__.__name__ @@ -677,17 +709,15 @@ class RequestInfoExecutor(Executor): a response is provided externally, it emits the response as a message. """ - def __init__(self, id: str | None = None): - """Initialize the RequestInfoExecutor with an optional custom ID. + _PENDING_SHARED_STATE_KEY: ClassVar[str] = "_af_pending_request_info" + + def __init__(self, id: str): + """Initialize the RequestInfoExecutor with a unique ID. Args: - id: Optional custom ID for this RequestInfoExecutor. If not provided, - a unique ID will be generated. + id: Unique ID for this RequestInfoExecutor. """ - import uuid - - executor_id = id or f"request_info_{uuid.uuid4().hex[:8]}" - super().__init__(id=executor_id) + super().__init__(id=id) self._request_events: dict[str, RequestInfoEvent] = {} self._sub_workflow_contexts: dict[str, dict[str, str]] = {} @@ -703,6 +733,7 @@ class RequestInfoExecutor(Executor): request_data=message, ) self._request_events[message.request_id] = event + await self._record_pending_request_snapshot(message, source_executor_id, ctx) await ctx.add_event(event) @handler @@ -748,10 +779,13 @@ class RequestInfoExecutor(Executor): response_data: The data returned in the response. ctx: The workflow context for sending the response. """ - if request_id not in self._request_events: + event = self._request_events.get(request_id) + if event is None: + event = await self._rehydrate_request_event(request_id, ctx) + if event is None: raise ValueError(f"No request found with ID: {request_id}") - event = self._request_events.pop(request_id) + self._request_events.pop(request_id, None) # Check if this was a forwarded sub-workflow request if request_id in self._sub_workflow_contexts: @@ -779,6 +813,472 @@ class RequestInfoExecutor(Executor): await ctx.send_message(correlated_response, target_id=event.source_executor_id) + await self._clear_pending_request_snapshot(request_id, ctx) + + async def _record_pending_request_snapshot( + self, + request: RequestInfoMessage, + source_executor_id: str, + ctx: WorkflowContext[Any], + ) -> None: + snapshot = self._build_request_snapshot(request, source_executor_id) + + pending = await self._load_pending_request_state(ctx) + pending[request.request_id] = snapshot + await self._persist_pending_request_state(pending, ctx) + + async def _clear_pending_request_snapshot(self, request_id: str, ctx: WorkflowContext[Any]) -> None: + pending = await self._load_pending_request_state(ctx) + if request_id not in pending: + return + + pending.pop(request_id, None) + await self._persist_pending_request_state(pending, ctx) + + async def _load_pending_request_state(self, ctx: WorkflowContext[Any]) -> dict[str, Any]: + try: + existing = await ctx.get_shared_state(self._PENDING_SHARED_STATE_KEY) + except Exception as exc: # pragma: no cover - transport specific + logger.warning(f"RequestInfoExecutor {self.id} failed to read pending request state: {exc}") + return {} + + if not isinstance(existing, dict) or existing is None: + if existing not in (None, {}): + logger.warning( + f"RequestInfoExecutor {self.id} encountered non-dict pending state " + f"({type(existing).__name__}); resetting." + ) + return {} + + return dict(existing) + + async def _persist_pending_request_state(self, pending: dict[str, Any], ctx: WorkflowContext[Any]) -> None: + await self._safe_set_shared_state(ctx, pending) + await self._safe_set_runner_state(ctx, pending) + + async def _safe_set_shared_state(self, ctx: WorkflowContext[Any], pending: dict[str, Any]) -> None: + try: + await ctx.set_shared_state(self._PENDING_SHARED_STATE_KEY, pending) + except Exception as exc: # pragma: no cover - transport specific + logger.warning(f"RequestInfoExecutor {self.id} failed to update shared pending state: {exc}") + + async def _safe_set_runner_state(self, ctx: WorkflowContext[Any], pending: dict[str, Any]) -> None: + try: + await ctx.set_state({"pending_requests": pending}) + except Exception as exc: # pragma: no cover - transport specific + logger.warning(f"RequestInfoExecutor {self.id} failed to update runner state with pending requests: {exc}") + + def _build_request_snapshot( + self, + request: RequestInfoMessage, + source_executor_id: str, + ) -> dict[str, Any]: + snapshot: dict[str, Any] = { + "request_id": request.request_id, + "source_executor_id": source_executor_id, + "request_type": f"{type(request).__module__}:{type(request).__name__}", + "summary": repr(request), + } + + details = self._serialise_request_details(request) + if details: + snapshot["details"] = details + for key in ("prompt", "draft", "iteration"): + if key in details and key not in snapshot: + snapshot[key] = details[key] + + return snapshot + + def _serialise_request_details(self, request: RequestInfoMessage) -> dict[str, Any] | None: + if is_dataclass(request): + data = self._make_json_safe(asdict(request)) + if isinstance(data, dict): + return cast(dict[str, Any], data) + return None + + model_dump = getattr(request, "model_dump", None) + if callable(model_dump): + try: + dump = self._make_json_safe(model_dump(mode="json")) + except TypeError: + dump = self._make_json_safe(model_dump()) + if isinstance(dump, dict): + return cast(dict[str, Any], dump) + return None + + attrs = getattr(request, "__dict__", None) + if isinstance(attrs, dict): + cleaned = self._make_json_safe(attrs) + if isinstance(cleaned, dict): + return cast(dict[str, Any], cleaned) + + return None + + def _make_json_safe(self, value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, Mapping): + safe_dict: dict[str, Any] = {} + for key, val in value.items(): + safe_dict[str(key)] = self._make_json_safe(val) + return safe_dict + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + return [self._make_json_safe(item) for item in value] + return repr(value) + + async def has_pending_request(self, request_id: str, ctx: WorkflowContext[Any]) -> bool: + if request_id in self._request_events or request_id in self._sub_workflow_contexts: + return True + snapshot = await self._get_pending_request_snapshot(request_id, ctx) + return snapshot is not None + + async def _rehydrate_request_event( + self, + request_id: str, + ctx: WorkflowContext[Any], + ) -> RequestInfoEvent | None: + snapshot = await self._get_pending_request_snapshot(request_id, ctx) + if snapshot is None: + return None + + source_executor_id = snapshot.get("source_executor_id") + if not isinstance(source_executor_id, str) or not source_executor_id: + return None + + request = self._construct_request_from_snapshot(snapshot) + if request is None: + return None + + event = RequestInfoEvent( + request_id=request_id, + source_executor_id=source_executor_id, + request_type=type(request), + request_data=request, + ) + self._request_events[request_id] = event + return event + + async def _get_pending_request_snapshot(self, request_id: str, ctx: WorkflowContext[Any]) -> dict[str, Any] | None: + pending = await self._collect_pending_request_snapshots(ctx) + snapshot = pending.get(request_id) + if snapshot is None: + return None + return snapshot + + async def _collect_pending_request_snapshots(self, ctx: WorkflowContext[Any]) -> dict[str, dict[str, Any]]: + combined: dict[str, dict[str, Any]] = {} + + try: + shared_pending = await ctx.get_shared_state(self._PENDING_SHARED_STATE_KEY) + except Exception as exc: # pragma: no cover - transport specific + logger.warning(f"RequestInfoExecutor {self.id} failed to read shared pending state during rehydrate: {exc}") + shared_pending = None + + if isinstance(shared_pending, dict): + for key, value in shared_pending.items(): + if isinstance(key, str) and isinstance(value, dict): + combined[key] = cast(dict[str, Any], value) + + try: + state = await ctx.get_state() + except Exception as exc: # pragma: no cover - transport specific + logger.warning(f"RequestInfoExecutor {self.id} failed to read runner state during rehydrate: {exc}") + state = None + + if isinstance(state, dict): + state_pending = state.get("pending_requests") + if isinstance(state_pending, dict): + for key, value in state_pending.items(): + if isinstance(key, str) and isinstance(value, dict) and key not in combined: + combined[key] = cast(dict[str, Any], value) + + return combined + + def _construct_request_from_snapshot(self, snapshot: dict[str, Any]) -> RequestInfoMessage | None: + details_raw = snapshot.get("details") + details: dict[str, Any] = cast(dict[str, Any], details_raw) if isinstance(details_raw, dict) else {} + + request_cls: type[RequestInfoMessage] = RequestInfoMessage + request_type_str = snapshot.get("request_type") + if isinstance(request_type_str, str) and ":" in request_type_str: + module_name, class_name = request_type_str.split(":", 1) + try: + module = importlib.import_module(module_name) + candidate = getattr(module, class_name) + if isinstance(candidate, type) and issubclass(candidate, RequestInfoMessage): + request_cls = candidate + except Exception as exc: + logger.warning(f"RequestInfoExecutor {self.id} could not import {module_name}.{class_name}: {exc}") + request_cls = RequestInfoMessage + + request: RequestInfoMessage | None = self._instantiate_request(request_cls, details) + + if request is None and request_cls is not RequestInfoMessage: + request = self._instantiate_request(RequestInfoMessage, details) + + if request is None: + logger.warning( + f"RequestInfoExecutor {self.id} could not reconstruct request " + f"{request_type_str or RequestInfoMessage.__name__} from snapshot keys {sorted(details.keys())}" + ) + return None + + for key, value in details.items(): + if key == "request_id": + continue + try: + setattr(request, key, value) + except Exception as exc: + logger.debug( + f"RequestInfoExecutor {self.id} could not set attribute {key} on {type(request).__name__}: {exc}" + ) + continue + + snapshot_request_id = snapshot.get("request_id") + if isinstance(snapshot_request_id, str) and snapshot_request_id: + try: + request.request_id = snapshot_request_id + except Exception as exc: + logger.debug( + f"RequestInfoExecutor {self.id} could not apply snapshot " + f"request_id to {type(request).__name__}: {exc}" + ) + + return request + + def _instantiate_request( + self, + request_cls: type[RequestInfoMessage], + details: dict[str, Any], + ) -> RequestInfoMessage | None: + try: + model_validate = getattr(request_cls, "model_validate", None) + if callable(model_validate): + return cast(RequestInfoMessage, model_validate(details)) + except (TypeError, ValueError) as exc: + logger.debug( + f"RequestInfoExecutor {self.id} validation failed for {request_cls.__name__} via model_validate: {exc}" + ) + except Exception as exc: + logger.warning( + f"RequestInfoExecutor {self.id} encountered unexpected error during " + f"{request_cls.__name__}.model_validate: {exc}" + ) + + if is_dataclass(request_cls): + try: + field_names = {f.name for f in fields(request_cls)} + ctor_kwargs = {name: details[name] for name in field_names if name in details} + return request_cls(**ctor_kwargs) # type: ignore[call-arg] + except (TypeError, ValueError) as exc: + logger.debug( + f"RequestInfoExecutor {self.id} could not instantiate dataclass " + f"{request_cls.__name__} with snapshot data: {exc}" + ) + except Exception as exc: + logger.warning( + f"RequestInfoExecutor {self.id} encountered unexpected error " + f"constructing dataclass {request_cls.__name__}: {exc}" + ) + + try: + instance = request_cls() # type: ignore[call-arg] + except Exception as exc: + logger.warning( + f"RequestInfoExecutor {self.id} could not instantiate {request_cls.__name__} without arguments: {exc}" + ) + return None + + for key, value in details.items(): + if key == "request_id": + continue + try: + setattr(instance, key, value) + except Exception as exc: + logger.debug( + f"RequestInfoExecutor {self.id} could not set attribute {key} on " + f"{request_cls.__name__} during instantiation: {exc}" + ) + continue + + return instance + + @staticmethod + def pending_requests_from_checkpoint( + checkpoint: WorkflowCheckpoint, + *, + request_executor_ids: Iterable[str] | None = None, + ) -> list[PendingRequestDetails]: + executor_filter: set[str] | None = None + if request_executor_ids is not None: + executor_filter = {str(value) for value in request_executor_ids} + + pending: dict[str, PendingRequestDetails] = {} + + shared_map = checkpoint.shared_state.get(RequestInfoExecutor._PENDING_SHARED_STATE_KEY) + if isinstance(shared_map, Mapping): + for request_id, snapshot in shared_map.items(): + RequestInfoExecutor._merge_snapshot(pending, str(request_id), snapshot) + + for state in checkpoint.executor_states.values(): + if not isinstance(state, Mapping): + continue + inner = state.get("pending_requests") + if isinstance(inner, Mapping): + for request_id, snapshot in inner.items(): + RequestInfoExecutor._merge_snapshot(pending, str(request_id), snapshot) + + for source_id, message_list in checkpoint.messages.items(): + if executor_filter is not None and source_id not in executor_filter: + continue + if not isinstance(message_list, list): + continue + for message in message_list: + if not isinstance(message, Mapping): + continue + payload = _decode_checkpoint_value(message.get("data")) + RequestInfoExecutor._merge_message_payload(pending, payload, message) + + return list(pending.values()) + + @staticmethod + def checkpoint_summary( + checkpoint: WorkflowCheckpoint, + *, + request_executor_ids: Iterable[str] | None = None, + preview_width: int = 70, + ) -> WorkflowCheckpointSummary: + targets = sorted(checkpoint.messages.keys()) + executor_states = sorted(checkpoint.executor_states.keys()) + pending = RequestInfoExecutor.pending_requests_from_checkpoint( + checkpoint, request_executor_ids=request_executor_ids + ) + + draft_preview: str | None = None + for entry in pending: + if entry.draft: + draft_preview = shorten(entry.draft, width=preview_width, placeholder="…") + break + + status = "idle" + if pending: + status = "awaiting human response" + elif not checkpoint.messages and "finalise" in executor_states: + status = "completed" + elif checkpoint.messages: + status = "awaiting next superstep" + elif request_executor_ids is not None and any(tid in targets for tid in request_executor_ids): + status = "awaiting request delivery" + + return WorkflowCheckpointSummary( + checkpoint_id=checkpoint.checkpoint_id, + iteration_count=checkpoint.iteration_count, + targets=targets, + executor_states=executor_states, + status=status, + draft_preview=draft_preview, + pending_requests=pending, + ) + + @staticmethod + def _merge_snapshot( + pending: dict[str, PendingRequestDetails], + request_id: str, + snapshot: Any, + ) -> None: + if not request_id or not isinstance(snapshot, Mapping): + return + + details = pending.setdefault(request_id, PendingRequestDetails(request_id=request_id)) + + RequestInfoExecutor._apply_update( + details, + prompt=snapshot.get("prompt"), + draft=snapshot.get("draft"), + iteration=snapshot.get("iteration"), + source_executor_id=snapshot.get("source_executor_id"), + ) + + extra = snapshot.get("details") + if isinstance(extra, Mapping): + RequestInfoExecutor._apply_update( + details, + prompt=extra.get("prompt"), + draft=extra.get("draft"), + iteration=extra.get("iteration"), + ) + + @staticmethod + def _merge_message_payload( + pending: dict[str, PendingRequestDetails], + payload: Any, + raw_message: Mapping[str, Any], + ) -> None: + if isinstance(payload, RequestResponse): + request_id = payload.request_id or RequestInfoExecutor._get_field(payload.original_request, "request_id") + if not request_id: + return + details = pending.setdefault(request_id, PendingRequestDetails(request_id=request_id)) + RequestInfoExecutor._apply_update( + details, + prompt=RequestInfoExecutor._get_field(payload.original_request, "prompt"), + draft=RequestInfoExecutor._get_field(payload.original_request, "draft"), + iteration=RequestInfoExecutor._get_field(payload.original_request, "iteration"), + source_executor_id=raw_message.get("source_id"), + original_request=payload.original_request, + ) + elif isinstance(payload, RequestInfoMessage): + request_id = getattr(payload, "request_id", None) + if not request_id: + return + details = pending.setdefault(request_id, PendingRequestDetails(request_id=request_id)) + RequestInfoExecutor._apply_update( + details, + prompt=getattr(payload, "prompt", None), + draft=getattr(payload, "draft", None), + iteration=getattr(payload, "iteration", None), + source_executor_id=raw_message.get("source_id"), + original_request=payload, + ) + + @staticmethod + def _apply_update( + details: PendingRequestDetails, + *, + prompt: Any = None, + draft: Any = None, + iteration: Any = None, + source_executor_id: Any = None, + original_request: Any = None, + ) -> None: + if prompt and not details.prompt: + details.prompt = str(prompt) + if draft and not details.draft: + details.draft = str(draft) + if iteration is not None and details.iteration is None: + coerced = RequestInfoExecutor._coerce_int(iteration) + if coerced is not None: + details.iteration = coerced + if source_executor_id and not details.source_executor_id: + details.source_executor_id = str(source_executor_id) + if original_request is not None and details.original_request is None: + details.original_request = original_request + + @staticmethod + def _get_field(obj: Any, key: str) -> Any: + if obj is None: + return None + if isinstance(obj, Mapping): + return obj.get(key) + return getattr(obj, key, None) + + @staticmethod + def _coerce_int(value: Any) -> int | None: + try: + return int(value) + except (TypeError, ValueError): + return None + # endregion: Request Info Executor @@ -840,7 +1340,11 @@ class AgentExecutor(Executor): exec_id = id else: agent_name = agent.name - exec_id = str(agent_name) if agent_name else f"executor_{uuid.uuid4()}" + if agent_name: + exec_id = str(agent_name) + else: + logger.warning("Agent has no name, using fallback ID 'executor_unnamed'") + exec_id = "executor_unnamed" super().__init__(exec_id) self._agent = agent self._agent_thread = agent_thread or self._agent.get_new_thread() @@ -952,12 +1456,12 @@ class WorkflowExecutor(Executor): workflow: "Workflow" = Field(description="The workflow to execute as a sub-workflow") - def __init__(self, workflow: "Workflow", id: str | None = None, **kwargs: Any): + def __init__(self, workflow: "Workflow", id: str, **kwargs: Any): """Initialize the WorkflowExecutor. Args: workflow: The workflow to execute as a sub-workflow. - id: Optional unique identifier for this executor. + id: Unique identifier for this executor. **kwargs: Additional keyword arguments passed to the parent constructor. """ kwargs.update({"workflow": workflow}) diff --git a/python/packages/main/agent_framework/_workflow/_magentic.py b/python/packages/main/agent_framework/_workflow/_magentic.py index 204565c320..44a2adc9dd 100644 --- a/python/packages/main/agent_framework/_workflow/_magentic.py +++ b/python/packages/main/agent_framework/_workflow/_magentic.py @@ -1635,7 +1635,7 @@ class MagenticBuilder: if self._enable_plan_review: from ._executor import RequestInfoExecutor - request_info = RequestInfoExecutor() + request_info = RequestInfoExecutor(id="request_info") workflow_builder = ( workflow_builder # Only route plan review asks to request_info diff --git a/python/packages/main/agent_framework/_workflow/_runner.py b/python/packages/main/agent_framework/_workflow/_runner.py index b501ff205f..bfbbce5e9b 100644 --- a/python/packages/main/agent_framework/_workflow/_runner.py +++ b/python/packages/main/agent_framework/_workflow/_runner.py @@ -13,7 +13,14 @@ from ._edge import EdgeGroup from ._edge_runner import EdgeRunner, create_edge_runner from ._events import WorkflowCompletedEvent, WorkflowEvent, _framework_event_origin from ._executor import Executor -from ._runner_context import Message, RunnerContext +from ._runner_context import ( + _DATACLASS_MARKER, + _PYDANTIC_MARKER, + CheckpointState, + Message, + RunnerContext, + _decode_checkpoint_value, +) from ._shared_state import SharedState from ._typing_utils import is_instance_of from ._workflow_context import WorkflowContext @@ -53,6 +60,7 @@ class Runner: self._workflow_id = workflow_id self._running = False self._resumed_from_checkpoint = False # Track whether we resumed + self.graph_signature_hash: str | None = None # Set workflow ID in context if provided if workflow_id: @@ -244,6 +252,19 @@ class Runner: """Inner loop to deliver a single message through an edge runner.""" return await edge_runner.send_message(message, self._shared_state, self._ctx) + def _normalize_message_payload(message: Message) -> None: + data = message.data + if not isinstance(data, dict): + return + if _PYDANTIC_MARKER not in data and _DATACLASS_MARKER not in data: + return + try: + decoded = _decode_checkpoint_value(data) + except Exception as exc: # pragma: no cover - defensive + logger.debug("Failed to decode checkpoint payload during delivery: %s", exc) + return + message.data = decoded + # Handle SubWorkflowRequestInfo messages specially await _deliver_sub_workflow_requests(messages) @@ -266,6 +287,7 @@ class Runner: associated_edge_runners = self._edge_runner_map.get(source_executor_id, []) for message in non_sub_workflow_messages: + _normalize_message_payload(message) # Deliver a message through all edge runners associated with the source executor concurrently. tasks = [_deliver_message_inner(edge_runner, message) for edge_runner in associated_edge_runners] if not tasks: @@ -332,6 +354,8 @@ class Runner: "superstep": self._iteration, "checkpoint_type": checkpoint_category, } + if self.graph_signature_hash: + metadata["graph_signature"] = self.graph_signature_hash checkpoint_id = await self._ctx.create_checkpoint(metadata=metadata) logger.info(f"Created {checkpoint_type} checkpoint: {checkpoint_id}") return checkpoint_id @@ -403,14 +427,45 @@ class Runner: return False try: - success = await self._ctx.restore_from_checkpoint(checkpoint_id) - if not success: + checkpoint = await self._ctx.load_checkpoint(checkpoint_id) + if not checkpoint: + logger.error(f"Checkpoint {checkpoint_id} not found") return False + graph_hash = getattr(self, "graph_signature_hash", None) + checkpoint_hash = (checkpoint.metadata or {}).get("graph_signature") + if graph_hash and checkpoint_hash and graph_hash != checkpoint_hash: + raise ValueError( + "Workflow graph has changed since the checkpoint was created. " + "Please rebuild the original workflow before resuming." + ) + if graph_hash and not checkpoint_hash: + logger.warning( + "Checkpoint %s does not include graph signature metadata; skipping topology validation.", + checkpoint_id, + ) + + state: CheckpointState = { + "messages": checkpoint.messages, + "shared_state": checkpoint.shared_state, + "executor_states": checkpoint.executor_states, + "iteration_count": checkpoint.iteration_count, + "max_iterations": checkpoint.max_iterations, + } + await self._ctx.set_checkpoint_state(state) + if checkpoint.workflow_id: + self._ctx.set_workflow_id(checkpoint.workflow_id) + self._workflow_id = checkpoint.workflow_id + await self._restore_shared_state_from_context() - self.mark_resumed() # mark resumed; iteration/max already restored from context + self.mark_resumed( + iteration=checkpoint.iteration_count, + max_iterations=checkpoint.max_iterations, + ) logger.info(f"Successfully restored workflow from checkpoint: {checkpoint_id}") return True + except ValueError: + raise except Exception as e: logger.error(f"Failed to restore from checkpoint {checkpoint_id}: {e}") return False diff --git a/python/packages/main/agent_framework/_workflow/_runner_context.py b/python/packages/main/agent_framework/_workflow/_runner_context.py index c78ff62f87..b99f1b5c7a 100644 --- a/python/packages/main/agent_framework/_workflow/_runner_context.py +++ b/python/packages/main/agent_framework/_workflow/_runner_context.py @@ -3,6 +3,7 @@ import asyncio import importlib import logging +import sys import uuid from collections import defaultdict from dataclasses import dataclass, fields, is_dataclass @@ -60,6 +61,39 @@ _MAX_ENCODE_DEPTH = 100 _CYCLE_SENTINEL = "" +def _instantiate_checkpoint_dataclass(cls: type[Any], payload: Any) -> Any | None: + if not isinstance(cls, type): + logger.debug(f"Checkpoint decoder received non-type dataclass reference: {cls!r}") + return None + + if isinstance(payload, dict): + try: + return cls(**payload) # type: ignore[arg-type] + except TypeError as exc: + logger.debug(f"Checkpoint decoder could not call {cls.__name__}(**payload): {exc}") + except Exception as exc: + logger.warning(f"Checkpoint decoder encountered unexpected error calling {cls.__name__}(**payload): {exc}") + try: + instance = object.__new__(cls) + except Exception as exc: + logger.debug(f"Checkpoint decoder could not allocate {cls.__name__} without __init__: {exc}") + return None + for key, val in payload.items(): + try: + setattr(instance, key, val) + except Exception as exc: + logger.debug(f"Checkpoint decoder could not set attribute {key} on {cls.__name__}: {exc}") + return instance + + try: + return cls(payload) # type: ignore[call-arg] + except TypeError as exc: + logger.debug(f"Checkpoint decoder could not call {cls.__name__}({payload!r}): {exc}") + except Exception as exc: + logger.warning(f"Checkpoint decoder encountered unexpected error calling {cls.__name__}({payload!r}): {exc}") + return None + + def _is_pydantic_model(obj: object) -> bool: """Best-effort check for Pydantic models (e.g., AFBaseModel). @@ -99,7 +133,7 @@ def _encode_checkpoint_value(value: Any) -> Any: "value": v.model_dump(mode="json"), } except Exception as exc: # best-effort fallback - logger.debug("Pydantic model_dump failed for %s: %s", cls, exc) + logger.debug(f"Pydantic model_dump failed for {cls}: {exc}") return str(v) # Dataclasses (instances only) @@ -178,36 +212,32 @@ def _decode_checkpoint_value(value: Any) -> Any: if isinstance(type_key, str): try: module_name, class_name = type_key.split(":", 1) - module = importlib.import_module(module_name) + module = sys.modules.get(module_name) + if module is None: + module = importlib.import_module(module_name) cls: Any = getattr(module, class_name) if hasattr(cls, "model_validate"): return cls.model_validate(raw) except Exception as exc: - logger.debug( - "Failed to decode pydantic model %s: %s; returning raw value", - type_key, - exc, - ) + logger.debug(f"Failed to decode pydantic model {type_key}: {exc}; returning raw value") # Dataclass marker handling if _DATACLASS_MARKER in value_dict and "value" in value_dict: type_key_dc: str | None = value_dict.get(_DATACLASS_MARKER) # type: ignore[assignment] raw_dc: Any = value_dict.get("value") + decoded_raw = _decode_checkpoint_value(raw_dc) if isinstance(type_key_dc, str): try: module_name, class_name = type_key_dc.split(":", 1) - module = importlib.import_module(module_name) + module = sys.modules.get(module_name) + if module is None: + module = importlib.import_module(module_name) cls_dc: Any = getattr(module, class_name) - decoded_raw = _decode_checkpoint_value(raw_dc) - if isinstance(decoded_raw, dict): - return cls_dc(**decoded_raw) + constructed = _instantiate_checkpoint_dataclass(cls_dc, decoded_raw) + if constructed is not None: + return constructed except Exception as exc: - logger.debug( - "Failed to decode dataclass %s: %s; returning raw value", - type_key_dc, - exc, - ) - # Fallback to decoded raw value - return _decode_checkpoint_value(raw_dc) + logger.debug(f"Failed to decode dataclass {type_key_dc}: {exc}; returning raw value") + return decoded_raw # Regular dict: decode recursively decoded: dict[str, Any] = {} @@ -338,6 +368,10 @@ class RunnerContext(Protocol): """ ... + async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None: + """Load a checkpoint without mutating the current context state.""" + ... + async def get_checkpoint_state(self) -> CheckpointState: """Get the current state of the context suitable for checkpointing.""" ... @@ -409,7 +443,7 @@ class InProcRunnerContext: return except Exception as exc: # pragma: no cover - defensive logging path # Best-effort filtering only; never block event delivery on filtering errors - logger.debug("Error while filtering event %r: %s", event, exc, exc_info=True) + logger.debug(f"Error while filtering event {event!r}: {exc}", exc_info=True) await self._event_queue.put(event) @@ -497,6 +531,11 @@ class InProcRunnerContext: logger.info(f"Restored state from checkpoint {checkpoint_id}'") return True + async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None: + if not self._checkpoint_storage: + raise ValueError("Checkpoint storage not configured") + return await self._checkpoint_storage.load_checkpoint(checkpoint_id) + async def get_checkpoint_state(self) -> CheckpointState: serializable_messages: dict[str, list[dict[str, Any]]] = {} for source_id, message_list in self._messages.items(): diff --git a/python/packages/main/agent_framework/_workflow/_typing_utils.py b/python/packages/main/agent_framework/_workflow/_typing_utils.py index d8f54de786..586178d8d6 100644 --- a/python/packages/main/agent_framework/_workflow/_typing_utils.py +++ b/python/packages/main/agent_framework/_workflow/_typing_utils.py @@ -1,7 +1,58 @@ # Copyright (c) Microsoft. All rights reserved. +import logging +from dataclasses import fields, is_dataclass from typing import Any, Union, get_args, get_origin +logger = logging.getLogger(__name__) + + +def _coerce_to_type(value: Any, target_type: type) -> Any | None: + """Best-effort conversion of value into target_type.""" + if isinstance(value, target_type): + return value + + # Convert dataclass instances or objects with __dict__ into dict first + if not isinstance(value, dict): + if is_dataclass(value): + value = {f.name: getattr(value, f.name) for f in fields(value)} + else: + value_dict = getattr(value, "__dict__", None) + if isinstance(value_dict, dict): + value = dict(value_dict) + + if isinstance(value, dict): + ctor_kwargs: dict[str, Any] = dict(value) + + if is_dataclass(target_type): + field_names = {f.name for f in fields(target_type)} + ctor_kwargs = {k: v for k, v in value.items() if k in field_names} + + try: + return target_type(**ctor_kwargs) # type: ignore[arg-type] + except TypeError as exc: + logger.debug(f"_coerce_to_type could not call {target_type.__name__}(**..): {exc}") + except Exception as exc: # pragma: no cover - unexpected constructor failure + logger.warning( + f"_coerce_to_type encountered unexpected error calling {target_type.__name__} constructor: {exc}" + ) + try: + instance: Any = object.__new__(target_type) + except Exception as exc: # pragma: no cover - pathological type + logger.debug(f"_coerce_to_type could not allocate {target_type.__name__} without __init__: {exc}") + return None + for key, val in value.items(): + try: + setattr(instance, key, val) + except Exception as exc: + logger.debug( + f"_coerce_to_type could not set {target_type.__name__}.{key} during fallback assignment: {exc}" + ) + continue + return instance + + return None + def is_instance_of(data: Any, target_type: type) -> bool: """Check if the data is an instance of the target type. @@ -63,7 +114,10 @@ def is_instance_of(data: Any, target_type: type) -> bool: and data.original_request is not None and not is_instance_of(data.original_request, request_type) ): - return False + coerced = _coerce_to_type(data.original_request, request_type) + if coerced is None: + return False + data.original_request = coerced if hasattr(data, "data") and data.data is not None and not is_instance_of(data.data, response_type): return False return True diff --git a/python/packages/main/agent_framework/_workflow/_validation.py b/python/packages/main/agent_framework/_workflow/_validation.py index 94d4426c28..f7da01b969 100644 --- a/python/packages/main/agent_framework/_workflow/_validation.py +++ b/python/packages/main/agent_framework/_workflow/_validation.py @@ -34,6 +34,7 @@ class ValidationTypeEnum(Enum): """Enumeration of workflow validation types.""" EDGE_DUPLICATION = "EDGE_DUPLICATION" + EXECUTOR_DUPLICATION = "EXECUTOR_DUPLICATION" TYPE_COMPATIBILITY = "TYPE_COMPATIBILITY" GRAPH_CONNECTIVITY = "GRAPH_CONNECTIVITY" HANDLER_OUTPUT_ANNOTATION = "HANDLER_OUTPUT_ANNOTATION" @@ -63,6 +64,20 @@ class EdgeDuplicationError(WorkflowValidationError): self.edge_id = edge_id +class ExecutorDuplicationError(WorkflowValidationError): + """Exception raised when duplicate executor identifiers are detected.""" + + def __init__(self, executor_id: str): + super().__init__( + message=( + f"Duplicate executor id detected: '{executor_id}'. Executor ids must be globally unique within a " + "workflow." + ), + validation_type=ValidationTypeEnum.EXECUTOR_DUPLICATION, + ) + self.executor_id = executor_id + + class TypeCompatibilityError(WorkflowValidationError): """Exception raised when type incompatibility is detected between connected executors.""" @@ -133,10 +148,17 @@ class WorkflowGraphValidator: def __init__(self) -> None: self._edges: list[Edge] = [] self._executors: dict[str, Executor] = {} + self._duplicate_executor_ids: set[str] = set() + self._start_executor_ref: Executor | str | None = None # region Core Validation Methods def validate_workflow( - self, edge_groups: Sequence[EdgeGroup], executors: dict[str, Executor], start_executor: Executor | str + self, + edge_groups: Sequence[EdgeGroup], + executors: dict[str, Executor], + start_executor: Executor | str, + *, + duplicate_executor_ids: Sequence[str] | None = None, ) -> None: """Validate the entire workflow graph. @@ -144,6 +166,7 @@ class WorkflowGraphValidator: edge_groups: list of edge groups in the workflow executors: Map of executor IDs to executor instances start_executor: The starting executor (can be instance or ID) + duplicate_executor_ids: Optional list of known duplicate executor IDs to pre-populate Raises: WorkflowValidationError: If any validation fails @@ -151,6 +174,8 @@ class WorkflowGraphValidator: self._executors = executors self._edges = [edge for group in edge_groups for edge in group.edges] self._edge_groups = edge_groups + self._duplicate_executor_ids = set(duplicate_executor_ids or []) + self._start_executor_ref = start_executor # If only the start executor exists, add it to the executor map # Handle the special case where the workflow consists of only a single executor and no edges. @@ -185,6 +210,7 @@ class WorkflowGraphValidator: ) # Run all checks + self._validate_executor_id_uniqueness(start_executor_id) self._validate_edge_duplication() self._validate_handler_output_annotations() self._validate_type_compatibility() @@ -353,6 +379,26 @@ class WorkflowGraphValidator: # endregion + def _validate_executor_id_uniqueness(self, start_executor_id: str) -> None: + """Ensure executor identifiers are unique throughout the workflow graph.""" + duplicates: set[str] = set(self._duplicate_executor_ids) + + id_counts: defaultdict[str, int] = defaultdict(int) + for key, executor in self._executors.items(): + id_counts[executor.id] += 1 + if key != executor.id: + duplicates.add(executor.id) + + duplicates.update({executor_id for executor_id, count in id_counts.items() if count > 1}) + + if isinstance(self._start_executor_ref, Executor): + mapped = self._executors.get(start_executor_id) + if mapped is not None and mapped is not self._start_executor_ref: + duplicates.add(start_executor_id) + + if duplicates: + raise ExecutorDuplicationError(sorted(duplicates)[0]) + # region Edge and Type Validation def _validate_edge_duplication(self) -> None: """Validate that there are no duplicate edges in the workflow. @@ -793,7 +839,11 @@ class WorkflowGraphValidator: def validate_workflow_graph( - edge_groups: Sequence[EdgeGroup], executors: dict[str, Executor], start_executor: Executor | str + edge_groups: Sequence[EdgeGroup], + executors: dict[str, Executor], + start_executor: Executor | str, + *, + duplicate_executor_ids: Sequence[str] | None = None, ) -> None: """Convenience function to validate a workflow graph. @@ -801,9 +851,15 @@ def validate_workflow_graph( edge_groups: list of edge groups in the workflow executors: Map of executor IDs to executor instances start_executor: The starting executor (can be instance or ID) + duplicate_executor_ids: Optional list of known duplicate executor IDs to pre-populate Raises: WorkflowValidationError: If any validation fails """ validator = WorkflowGraphValidator() - validator.validate_workflow(edge_groups, executors, start_executor) + validator.validate_workflow( + edge_groups, + executors, + start_executor, + duplicate_executor_ids=duplicate_executor_ids, + ) diff --git a/python/packages/main/agent_framework/_workflow/_workflow.py b/python/packages/main/agent_framework/_workflow/_workflow.py index 2fe61c2209..4b3cf3646d 100644 --- a/python/packages/main/agent_framework/_workflow/_workflow.py +++ b/python/packages/main/agent_framework/_workflow/_workflow.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio +import hashlib +import json import logging import sys import uuid @@ -177,6 +179,12 @@ class Workflow(AFBaseModel): workflow_id=id, ) + # Capture a canonical fingerprint of the workflow graph so checkpoints + # can assert they are resumed with an equivalent topology. + self._graph_signature = self._compute_graph_signature() + self._graph_signature_hash = self._hash_graph_signature(self._graph_signature) + self._runner.graph_signature_hash = self._graph_signature_hash + def model_dump(self, **kwargs: Any) -> dict[str, Any]: """Custom serialization that properly handles WorkflowExecutor nested workflows.""" data = super().model_dump(**kwargs) @@ -377,17 +385,26 @@ class Workflow(AFBaseModel): request_info_executor = self._find_request_info_executor() if request_info_executor: for request_id, response_data in responses.items(): + ctx: WorkflowContext[Any] = WorkflowContext( + request_info_executor.id, + [self.__class__.__name__], + self._shared_state, + self._runner.context, + trace_contexts=None, # No parent trace context for new workflow span + source_span_ids=None, # No source span for response handling + ) + + if not await request_info_executor.has_pending_request(request_id, ctx): + logger.debug( + f"Skipping pre-supplied response for request {request_id}; no pending request found " + f"after checkpoint restoration." + ) + continue + await request_info_executor.handle_response( response_data, request_id, - WorkflowContext( - request_info_executor.id, - [self.__class__.__name__], - self._shared_state, - self._runner.context, - trace_contexts=None, # No parent trace context for new workflow span - source_span_ids=None, # No source span for response handling - ), + ctx, ) async for event in self._run_workflow_with_tracing( @@ -590,6 +607,19 @@ class Workflow(AFBaseModel): if not checkpoint: return False + graph_hash = getattr(self._runner, "graph_signature_hash", None) + checkpoint_hash = (checkpoint.metadata or {}).get("graph_signature") + if graph_hash and checkpoint_hash and graph_hash != checkpoint_hash: + raise ValueError( + "Workflow graph has changed since the checkpoint was created. " + "Please rebuild the original workflow before resuming." + ) + if graph_hash and not checkpoint_hash: + logger.warning( + f"Checkpoint {checkpoint_id} does not include graph signature metadata; " + f"skipping topology validation." + ) + temp_context = InProcRunnerContext(checkpoint_storage) state: CheckpointState = { "messages": checkpoint.messages, @@ -608,10 +638,9 @@ class Workflow(AFBaseModel): return True + except ValueError: + raise except Exception as e: - import logging - - logger = logging.getLogger(__name__) logger.error(f"Failed to restore from external checkpoint {checkpoint_id}: {e}") return False @@ -632,7 +661,7 @@ class Workflow(AFBaseModel): self._shared_state._state.clear() # type: ignore[attr-defined] self._shared_state._state.update(shared_state_data) # type: ignore[attr-defined] except Exception as exc: # pragma: no cover - logger.debug("Failed to restore shared_state during external restore: %s", exc) + logger.debug(f"Failed to restore shared_state during external restore: {exc}") # Restore executor states into the context so ctx.get_state() calls after resume succeed try: @@ -641,9 +670,9 @@ class Workflow(AFBaseModel): try: await self._runner.context.set_state(exec_id, state) except Exception as exc: # pragma: no cover - ignore per-executor failures - logger.debug("Failed to restore executor state for %s during external restore: %s", exec_id, exc) + logger.debug(f"Failed to restore executor state for {exec_id} during external restore: {exc}") except Exception as exc: # pragma: no cover - logger.debug("Failed to iterate executor_states during external restore: %s", exc) + logger.debug(f"Failed to iterate executor_states during external restore: {exc}") # Transfer pending messages into the context for delivery in the next superstep messages_data = restored_state["messages"] @@ -671,6 +700,71 @@ class Workflow(AFBaseModel): ) ) + # Graph signature helpers + + def _compute_graph_signature(self) -> dict[str, Any]: + """Build a canonical fingerprint of the workflow graph topology for checkpoint validation. + + This creates a minimal, stable representation that captures only the structural + elements of the workflow (executor types, edge relationships, topology) while + ignoring data/state changes. Used to verify that a workflow's structure hasn't + changed when resuming from checkpoints. + """ + executors_signature = { + executor_id: f"{executor.__class__.__module__}.{executor.__class__.__name__}" + for executor_id, executor in self.executors.items() + } + + edge_groups_signature: list[dict[str, Any]] = [] + for group in self.edge_groups: + edges = [ + { + "source": edge.source_id, + "target": edge.target_id, + "condition": getattr(edge, "condition_name", None), + } + for edge in group.edges + ] + edges.sort(key=lambda e: (e["source"], e["target"], e["condition"] or "")) + + group_info: dict[str, Any] = { + "group_type": group.__class__.__name__, + "sources": sorted(group.source_executor_ids), + "targets": sorted(group.target_executor_ids), + "edges": edges, + } + + if isinstance(group, FanOutEdgeGroup): + group_info["selection_func"] = group.selection_func_name + + edge_groups_signature.append(group_info) + + edge_groups_signature.sort( + key=lambda info: ( + info["group_type"], + tuple(info["sources"]), + tuple(info["targets"]), + json.dumps(info["edges"], sort_keys=True), + json.dumps(info.get("selection_func")), + ) + ) + + return { + "start_executor": self.start_executor_id, + "executors": executors_signature, + "edge_groups": edge_groups_signature, + "max_iterations": self.max_iterations, + } + + @staticmethod + def _hash_graph_signature(signature: dict[str, Any]) -> str: + canonical = json.dumps(signature, sort_keys=True, separators=(",", ":")) + return hashlib.sha256(canonical.encode("utf-8")).hexdigest() + + @property + def graph_signature_hash(self) -> str: + return self._graph_signature_hash + def as_agent(self, name: str | None = None) -> WorkflowAgent: """Create a WorkflowAgent that wraps this workflow. @@ -699,6 +793,7 @@ class WorkflowBuilder: """Initialize the WorkflowBuilder with an empty list of edges and no starting executor.""" self._edge_groups: list[EdgeGroup] = [] self._executors: dict[str, Executor] = {} + self._duplicate_executor_ids: set[str] = set() self._start_executor: Executor | str | None = None self._checkpoint_storage: CheckpointStorage | None = None self._max_iterations: int = max_iterations @@ -712,7 +807,11 @@ class WorkflowBuilder: def _add_executor(self, executor: Executor) -> str: """Add an executor to the map and return its ID.""" - self._executors[executor.id] = executor + existing = self._executors.get(executor.id) + if existing is not None and existing is not executor: + self._duplicate_executor_ids.add(executor.id) + else: + self._executors[executor.id] = executor return executor.id def _maybe_wrap_agent(self, candidate: Executor | AgentProtocol) -> Executor: @@ -739,7 +838,10 @@ class WorkflowBuilder: if name: proposed_id = str(name) if proposed_id in self._executors: - proposed_id = f"{proposed_id}-{uuid.uuid4().hex[:8]}" + raise ValueError( + f"Duplicate executor ID '{proposed_id}' from agent name. " + "Agent names must be unique within a workflow." + ) wrapper = AgentExecutor(candidate, id=proposed_id, streaming=True) self._agent_wrappers[id(candidate)] = wrapper return wrapper @@ -934,8 +1036,9 @@ class WorkflowBuilder: self._start_executor = wrapped # Ensure the start executor is present in the executor map so validation succeeds # even if no edges are added yet, or before edges wrap the same agent again. - if wrapped.id not in self._executors: - self._executors[wrapped.id] = wrapped + existing = self._executors.get(wrapped.id) + if existing is not wrapped: + self._add_executor(wrapped) return self def set_max_iterations(self, max_iterations: int) -> Self: @@ -986,7 +1089,12 @@ class WorkflowBuilder: ) # Perform validation before creating the workflow - validate_workflow_graph(self._edge_groups, self._executors, self._start_executor) + validate_workflow_graph( + self._edge_groups, + self._executors, + self._start_executor, + duplicate_executor_ids=tuple(self._duplicate_executor_ids), + ) # Add validation completed event workflow_tracer.add_build_event("build.validation_completed") diff --git a/python/packages/main/tests/workflow/test_checkpoint_decode.py b/python/packages/main/tests/workflow/test_checkpoint_decode.py new file mode 100644 index 0000000000..7c528a5149 --- /dev/null +++ b/python/packages/main/tests/workflow/test_checkpoint_decode.py @@ -0,0 +1,52 @@ +# Copyright (c) Microsoft. All rights reserved. + +from dataclasses import dataclass +from typing import Any, cast + +from agent_framework._workflow._executor import RequestInfoMessage, RequestResponse +from agent_framework._workflow._runner_context import _decode_checkpoint_value, _encode_checkpoint_value # type: ignore +from agent_framework._workflow._typing_utils import is_instance_of + + +@dataclass(kw_only=True) +class SampleRequest(RequestInfoMessage): + prompt: str + + +def test_decode_dataclass_with_nested_request() -> None: + original = RequestResponse[SampleRequest, str].handled("approve") + original = RequestResponse[SampleRequest, str].with_correlation( + original, + SampleRequest(request_id="abc", prompt="prompt"), + "abc", + ) + + encoded = _encode_checkpoint_value(original) + decoded = cast(RequestResponse[SampleRequest, str], _decode_checkpoint_value(encoded)) + + assert isinstance(decoded, RequestResponse) + assert decoded.data == "approve" + assert decoded.request_id == "abc" + assert isinstance(decoded.original_request, SampleRequest) + assert decoded.original_request.prompt == "prompt" + + +def test_is_instance_of_coerces_request_response_original_request_dict() -> None: + response = RequestResponse[SampleRequest, str].handled("approve") + response = RequestResponse[SampleRequest, str].with_correlation( + response, + SampleRequest(request_id="req-1", prompt="prompt"), + "req-1", + ) + + # Simulate checkpoint decode fallback leaving a dict + response.original_request = cast( + Any, + { + "request_id": "req-1", + "prompt": "prompt", + }, + ) + + assert is_instance_of(response, RequestResponse[SampleRequest, str]) + assert isinstance(response.original_request, SampleRequest) diff --git a/python/packages/main/tests/workflow/test_checkpoint_validation.py b/python/packages/main/tests/workflow/test_checkpoint_validation.py new file mode 100644 index 0000000000..defdbb1717 --- /dev/null +++ b/python/packages/main/tests/workflow/test_checkpoint_validation.py @@ -0,0 +1,73 @@ +# Copyright (c) Microsoft. All rights reserved. + +import pytest + +from agent_framework import WorkflowBuilder, WorkflowCompletedEvent, WorkflowContext, handler +from agent_framework._workflow._checkpoint import InMemoryCheckpointStorage +from agent_framework._workflow._executor import Executor + + +class StartExecutor(Executor): + @handler + async def run(self, message: str, ctx: WorkflowContext[str]) -> None: + await ctx.send_message(message, target_id="finish") + + +class FinishExecutor(Executor): + @handler + async def finish(self, message: str, ctx: WorkflowContext[None]) -> None: + await ctx.add_event(WorkflowCompletedEvent(message)) + + +def build_workflow(storage: InMemoryCheckpointStorage, finish_id: str = "finish"): + start = StartExecutor(id="start") + finish = FinishExecutor(id=finish_id) + + builder = WorkflowBuilder(max_iterations=3).set_start_executor(start).add_edge(start, finish) + builder = builder.with_checkpointing(checkpoint_storage=storage) + return builder.build() + + +async def test_resume_fails_when_graph_mismatch() -> None: + storage = InMemoryCheckpointStorage() + workflow = build_workflow(storage, finish_id="finish") + + # Run once to create checkpoints + _ = [event async for event in workflow.run_stream("hello")] # noqa: F841 + + checkpoints = await storage.list_checkpoints() + assert checkpoints, "expected at least one checkpoint to be created" + target_checkpoint = checkpoints[-1] + + # Build a structurally different workflow (different finish executor id) + mismatched_workflow = build_workflow(storage, finish_id="finish_alt") + + with pytest.raises(ValueError, match="Workflow graph has changed"): + _ = [ + event + async for event in mismatched_workflow.run_stream_from_checkpoint( + target_checkpoint.checkpoint_id, + checkpoint_storage=storage, + ) + ] + + +async def test_resume_succeeds_when_graph_matches() -> None: + storage = InMemoryCheckpointStorage() + workflow = build_workflow(storage, finish_id="finish") + _ = [event async for event in workflow.run_stream("hello")] # noqa: F841 + + checkpoints = sorted(await storage.list_checkpoints(), key=lambda c: c.timestamp) + target_checkpoint = checkpoints[0] + + resumed_workflow = build_workflow(storage, finish_id="finish") + + events = [ + event + async for event in resumed_workflow.run_stream_from_checkpoint( + target_checkpoint.checkpoint_id, + checkpoint_storage=storage, + ) + ] + + assert any(isinstance(event, WorkflowCompletedEvent) for event in events) diff --git a/python/packages/main/tests/workflow/test_concurrent.py b/python/packages/main/tests/workflow/test_concurrent.py index 1974b0b2e9..bfcf8802da 100644 --- a/python/packages/main/tests/workflow/test_concurrent.py +++ b/python/packages/main/tests/workflow/test_concurrent.py @@ -126,3 +126,17 @@ async def test_concurrent_custom_aggregator_sync_callback_is_used() -> None: assert completed is not None assert isinstance(completed.data, str) assert completed.data == "One | Two" + + +def test_concurrent_custom_aggregator_uses_callback_name_for_id() -> None: + e1 = _FakeAgentExec("agentA", "One") + e2 = _FakeAgentExec("agentB", "Two") + + def summarize(results: list[AgentExecutorResponse]) -> str: # type: ignore[override] + return str(len(results)) + + wf = ConcurrentBuilder().participants([e1, e2]).with_aggregator(summarize).build() + + assert "summarize" in wf.executors + aggregator = wf.executors["summarize"] + assert aggregator.id == "summarize" diff --git a/python/packages/main/tests/workflow/test_executor.py b/python/packages/main/tests/workflow/test_executor.py index c448212357..7f3d5bfc3e 100644 --- a/python/packages/main/tests/workflow/test_executor.py +++ b/python/packages/main/tests/workflow/test_executor.py @@ -5,16 +5,16 @@ import pytest from agent_framework import Executor, WorkflowContext, handler -def test_executor_without_handlers(): - """Test that an executor without handlers raises an error when trying to run.""" +def test_executor_without_id(): + """Test that an executor without an ID raises an error when trying to run.""" - class MockExecutorWithoutHandlers(Executor): + class MockExecutorWithoutID(Executor): """A mock executor that does not implement any handlers.""" pass with pytest.raises(ValueError): - MockExecutorWithoutHandlers() + MockExecutorWithoutID(id="") def test_executor_handler_without_annotations(): @@ -61,7 +61,7 @@ def test_executor_with_valid_handlers(): """Another mock handler with a valid signature.""" pass - executor = MockExecutorWithValidHandlers() + executor = MockExecutorWithValidHandlers(id="test") assert executor.id is not None assert len(executor._handlers) == 2 # type: ignore assert executor.can_handle("text") is True @@ -85,7 +85,7 @@ def test_executor_handlers_with_output_types(): """A mock handler that outputs an integer.""" pass - executor = MockExecutorWithOutputTypes() + executor = MockExecutorWithOutputTypes(id="test") assert len(executor._handlers) == 2 # type: ignore string_handler = executor._handlers[str] # type: ignore diff --git a/python/packages/main/tests/workflow/test_request_info_executor_rehydrate.py b/python/packages/main/tests/workflow/test_request_info_executor_rehydrate.py new file mode 100644 index 0000000000..b43d427477 --- /dev/null +++ b/python/packages/main/tests/workflow/test_request_info_executor_rehydrate.py @@ -0,0 +1,223 @@ +# Copyright (c) Microsoft. All rights reserved. + +from dataclasses import dataclass +from typing import Any + +import pytest + +from agent_framework._workflow._checkpoint import WorkflowCheckpoint +from agent_framework._workflow._events import WorkflowEvent +from agent_framework._workflow._executor import ( + PendingRequestDetails, + RequestInfoExecutor, + RequestInfoMessage, + RequestResponse, +) +from agent_framework._workflow._runner_context import CheckpointState, Message, _encode_checkpoint_value # type: ignore +from agent_framework._workflow._shared_state import SharedState +from agent_framework._workflow._workflow_context import WorkflowContext + +PENDING_STATE_KEY = RequestInfoExecutor._PENDING_SHARED_STATE_KEY # pyright: ignore[reportPrivateUsage] + + +class _StubRunnerContext: + """Minimal runner context stub for exercising WorkflowContext helpers.""" + + def __init__(self, stored_state: dict[str, Any] | None = None) -> None: + self._state = stored_state or {} + + async def send_message(self, message: Message) -> None: # pragma: no cover - unused in tests + return None + + async def drain_messages(self) -> dict[str, list[Message]]: # pragma: no cover - unused + return {} + + async def has_messages(self) -> bool: # pragma: no cover - unused + return False + + async def add_event(self, event: WorkflowEvent) -> None: # pragma: no cover - unused + return None + + async def drain_events(self) -> list[WorkflowEvent]: # pragma: no cover - unused + return [] + + async def has_events(self) -> bool: # pragma: no cover - unused + return False + + async def next_event(self) -> WorkflowEvent: # pragma: no cover - unused + raise RuntimeError("Not implemented in stub context") + + async def get_state(self, executor_id: str) -> dict[str, Any] | None: # pragma: no cover - trivial + return self._state + + async def set_state(self, executor_id: str, state: dict[str, Any]) -> None: # pragma: no cover - unused + self._state = state + + def has_checkpointing(self) -> bool: # pragma: no cover - unused + return False + + def set_workflow_id(self, workflow_id: str) -> None: # pragma: no cover - unused + pass + + def reset_for_new_run(self, workflow_shared_state: SharedState | None = None) -> None: # pragma: no cover - unused + pass + + async def create_checkpoint(self, metadata: dict[str, Any] | None = None) -> str: # pragma: no cover - unused + raise RuntimeError("Checkpointing not supported in stub context") + + async def restore_from_checkpoint(self, checkpoint_id: str) -> bool: # pragma: no cover - unused + return False + + async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None: # pragma: no cover - unused + return None + + async def get_checkpoint_state(self) -> CheckpointState: # pragma: no cover - unused + return {} # type: ignore[return-value] + + async def set_checkpoint_state(self, state: CheckpointState) -> None: # pragma: no cover - unused + pass + + +@dataclass(kw_only=True) +class SimpleApproval(RequestInfoMessage): + prompt: str = "" + draft: str = "" + iteration: int = 0 + + +@pytest.mark.asyncio +async def test_rehydrate_falls_back_when_request_type_missing() -> None: + """Rehydration should succeed even if the original request type cannot be imported. + + This simulates resuming a workflow where the HumanApprovalRequest class is unavailable + in the current process (e.g., defined in __main__ during the original run). + """ + + request_id = "request-123" + snapshot = { + "request_id": request_id, + "source_executor_id": "review_gateway", + "request_type": "nonexistent.module:MissingRequest", + "summary": "...", + "details": { + "request_id": request_id, + "prompt": "Review draft", + "draft": "Draft text", + "iteration": 2, + }, + } + + shared_state = SharedState() + async with shared_state.hold(): + await shared_state.set_within_hold( + PENDING_STATE_KEY, + {request_id: snapshot}, + ) + + runner_ctx = _StubRunnerContext({"pending_requests": {request_id: snapshot}}) + ctx: WorkflowContext[Any] = WorkflowContext("request_info", ["workflow"], shared_state, runner_ctx) + + executor = RequestInfoExecutor(id="request_info") + + event = await executor._rehydrate_request_event(request_id, ctx) # pyright: ignore[reportPrivateUsage] + + assert event is not None + assert event.request_id == request_id + assert isinstance(event.data, RequestInfoMessage) + assert getattr(event.data, "prompt", None) == "Review draft" + assert getattr(event.data, "iteration", None) == 2 + + +@pytest.mark.asyncio +async def test_has_pending_request_detects_snapshot() -> None: + request_id = "req-pending" + snapshot = { + "request_id": request_id, + "source_executor_id": "review_gateway", + "details": { + "request_id": request_id, + "prompt": "Review", + "draft": "Draft", + }, + } + + shared_state = SharedState() + async with shared_state.hold(): + await shared_state.set_within_hold( + PENDING_STATE_KEY, + {request_id: snapshot}, + ) + + runner_ctx = _StubRunnerContext({"pending_requests": {request_id: snapshot}}) + ctx: WorkflowContext[Any] = WorkflowContext("request_info", ["workflow"], shared_state, runner_ctx) + + executor = RequestInfoExecutor(id="request_info") + + assert await executor.has_pending_request(request_id, ctx) + + +@pytest.mark.asyncio +async def test_has_pending_request_false_when_snapshot_absent() -> None: + shared_state = SharedState() + runner_ctx = _StubRunnerContext({"pending_requests": {}}) + ctx: WorkflowContext[Any] = WorkflowContext("request_info", ["workflow"], shared_state, runner_ctx) + + executor = RequestInfoExecutor(id="request_info") + + assert not await executor.has_pending_request("missing", ctx) + + +def test_pending_requests_from_checkpoint_and_summary() -> None: + request = SimpleApproval(prompt="Review draft", draft="Draft text", iteration=3) + request.request_id = "req-42" + + response = RequestResponse[SimpleApproval, str].handled("approve") + response = RequestResponse[SimpleApproval, str].with_correlation( + response, + request, + request.request_id, + ) + + encoded_response = _encode_checkpoint_value(response) + + checkpoint = WorkflowCheckpoint( + checkpoint_id="cp-1", + workflow_id="wf", + messages={ + "request_info": [ + { + "data": encoded_response, + "source_id": "request_info", + "target_id": "review_gateway", + } + ] + }, + shared_state={ + PENDING_STATE_KEY: { + request.request_id: { + "request_id": request.request_id, + "prompt": request.prompt, + "draft": request.draft, + "iteration": request.iteration, + "source_executor_id": "review_gateway", + } + } + }, + executor_states={}, + iteration_count=1, + ) + + pending = RequestInfoExecutor.pending_requests_from_checkpoint(checkpoint) + assert len(pending) == 1 + entry = pending[0] + assert isinstance(entry, PendingRequestDetails) + assert entry.request_id == "req-42" + assert entry.prompt == "Review draft" + assert entry.draft == "Draft text" + assert entry.iteration == 3 + assert entry.original_request is not None + + summary = RequestInfoExecutor.checkpoint_summary(checkpoint) + assert summary.checkpoint_id == "cp-1" + assert summary.status == "awaiting human response" + assert summary.pending_requests[0].request_id == "req-42" diff --git a/python/packages/main/tests/workflow/test_serialization.py b/python/packages/main/tests/workflow/test_serialization.py index 2a515b2c57..59a754c9ac 100644 --- a/python/packages/main/tests/workflow/test_serialization.py +++ b/python/packages/main/tests/workflow/test_serialization.py @@ -605,10 +605,7 @@ class TestSerializationWorkflowClasses: executor = SampleExecutor(id="valid-id") assert executor.id == "valid-id" - # Test validation failure for empty id - pydantic automatically validates min_length=1 - from pydantic import ValidationError - - with pytest.raises(ValidationError): + with pytest.raises(ValueError): SampleExecutor(id="") def test_edge_field_validation(self) -> None: diff --git a/python/packages/main/tests/workflow/test_sub_workflow.py b/python/packages/main/tests/workflow/test_sub_workflow.py index 3fe8109cb8..5cc1781a2c 100644 --- a/python/packages/main/tests/workflow/test_sub_workflow.py +++ b/python/packages/main/tests/workflow/test_sub_workflow.py @@ -200,7 +200,7 @@ async def test_sub_workflow_with_interception(): # Create parent workflow with interception parent = ParentOrchestrator(approved_domains={"example.com", "internal.org"}) workflow_executor = WorkflowExecutor(validation_workflow, id="email_workflow") - parent_request_info = RequestInfoExecutor() + parent_request_info = RequestInfoExecutor(id="request_info") main_workflow = ( WorkflowBuilder() @@ -280,7 +280,7 @@ async def test_conditional_forwarding() -> None: # Setup workflows email_validator = EmailValidator() - request_info = RequestInfoExecutor() + request_info = RequestInfoExecutor(id="request_info") validation_workflow = ( WorkflowBuilder() @@ -292,7 +292,7 @@ async def test_conditional_forwarding() -> None: parent = ConditionalParent() workflow_executor = WorkflowExecutor(validation_workflow, id="email_workflow") - parent_request_info = RequestInfoExecutor() + parent_request_info = RequestInfoExecutor(id="request_info") main_workflow = ( WorkflowBuilder() @@ -364,7 +364,7 @@ async def test_workflow_scoped_interception() -> None: # Create two identical sub-workflows def create_validation_workflow(): validator = EmailValidator() - request_info = RequestInfoExecutor() + request_info = RequestInfoExecutor(id="request_info") return ( WorkflowBuilder() .set_start_executor(validator) @@ -379,7 +379,7 @@ async def test_workflow_scoped_interception() -> None: parent = MultiWorkflowParent() executor_a = WorkflowExecutor(workflow_a, id="workflow_a") executor_b = WorkflowExecutor(workflow_b, id="workflow_b") - parent_request_info = RequestInfoExecutor() + parent_request_info = RequestInfoExecutor(id="request_info") main_workflow = ( WorkflowBuilder() diff --git a/python/packages/main/tests/workflow/test_validation.py b/python/packages/main/tests/workflow/test_validation.py index 6057f4486a..da41ff4858 100644 --- a/python/packages/main/tests/workflow/test_validation.py +++ b/python/packages/main/tests/workflow/test_validation.py @@ -8,6 +8,7 @@ import pytest from agent_framework import ( EdgeDuplicationError, Executor, + ExecutorDuplicationError, GraphConnectivityError, TypeCompatibilityError, ValidationTypeEnum, @@ -79,6 +80,17 @@ def test_valid_workflow_passes_validation(): assert workflow is not None +def test_duplicate_executor_ids_fail_validation(): + executor1 = StringExecutor(id="dup") + executor2 = IntExecutor(id="dup") + + with pytest.raises(ExecutorDuplicationError) as exc_info: + (WorkflowBuilder().add_edge(executor1, executor2).set_start_executor(executor1).build()) + + assert exc_info.value.executor_id == "dup" + assert exc_info.value.validation_type == ValidationTypeEnum.EXECUTOR_DUPLICATION + + def test_edge_duplication_validation_fails(): executor1 = StringExecutor(id="executor1") executor2 = StringExecutor(id="executor2") diff --git a/python/packages/main/tests/workflow/test_workflow.py b/python/packages/main/tests/workflow/test_workflow.py index b644ba2cb4..5baf05ed80 100644 --- a/python/packages/main/tests/workflow/test_workflow.py +++ b/python/packages/main/tests/workflow/test_workflow.py @@ -163,7 +163,7 @@ async def test_workflow_send_responses_streaming(): """Test the workflow run with approval.""" executor_a = IncrementExecutor(id="executor_a") executor_b = MockExecutorRequestApproval(id="executor_b") - request_info_executor = RequestInfoExecutor() + request_info_executor = RequestInfoExecutor(id="request_info") workflow = ( WorkflowBuilder() @@ -195,7 +195,7 @@ async def test_workflow_send_responses(): """Test the workflow run with approval.""" executor_a = IncrementExecutor(id="executor_a") executor_b = MockExecutorRequestApproval(id="executor_b") - request_info_executor = RequestInfoExecutor() + request_info_executor = RequestInfoExecutor(id="request_info") workflow = ( WorkflowBuilder() diff --git a/python/packages/main/tests/workflow/test_workflow_agent.py b/python/packages/main/tests/workflow/test_workflow_agent.py index 145ac1c06a..a48b876a67 100644 --- a/python/packages/main/tests/workflow/test_workflow_agent.py +++ b/python/packages/main/tests/workflow/test_workflow_agent.py @@ -11,6 +11,7 @@ from agent_framework import ( AgentRunUpdateEvent, ChatMessage, Executor, + FunctionCallContent, FunctionResultContent, RequestInfoExecutor, RequestInfoMessage, @@ -94,8 +95,8 @@ class TestWorkflowAgent: assert len(result.messages) >= 2, f"Expected at least 2 messages, got {len(result.messages)}" # Find messages from each executor - step1_messages = [] - step2_messages = [] + step1_messages: list[ChatMessage] = [] + step2_messages: list[ChatMessage] = [] for message in result.messages: first_content = message.contents[0] @@ -111,8 +112,8 @@ class TestWorkflowAgent: assert len(step2_messages) >= 1, "Should have received message from Step2 executor" # Verify the processing worked for both - step1_text = step1_messages[0].contents[0].text - step2_text = step2_messages[0].contents[0].text + step1_text: str = step1_messages[0].contents[0].text # type: ignore[attr-defined] + step2_text: str = step2_messages[0].contents[0].text # type: ignore[attr-defined] assert "Step1: Hello World" in step1_text assert "Step2: Step1: Hello World" in step2_text @@ -128,7 +129,7 @@ class TestWorkflowAgent: agent = WorkflowAgent(workflow=workflow, name="Streaming Test Agent") # Execute workflow streaming to capture streaming events - updates = [] + updates: list[AgentRunResponseUpdate] = [] async for update in agent.run_stream("Test input"): updates.append(update) @@ -137,8 +138,8 @@ class TestWorkflowAgent: # Verify we got a streaming update assert updates[0].contents is not None - first_content = updates[0].contents[0] - second_content = updates[1].contents[0] + first_content: TextContent = updates[0].contents[0] # type: ignore[assignment] + second_content: TextContent = updates[1].contents[0] # type: ignore[assignment] assert isinstance(first_content, TextContent) assert "Streaming1: Test input" in first_content.text assert isinstance(second_content, TextContent) @@ -148,7 +149,7 @@ class TestWorkflowAgent: """Test end-to-end workflow with RequestInfoEvent handling.""" # Create workflow with requesting executor -> request info executor (no cycle) requesting_executor = RequestingExecutor(id="requester") - request_info_executor = RequestInfoExecutor() + request_info_executor = RequestInfoExecutor(id="request_info") workflow = ( WorkflowBuilder() @@ -160,21 +161,21 @@ class TestWorkflowAgent: agent = WorkflowAgent(workflow=workflow, name="Request Test Agent") # Execute workflow streaming to get request info event - updates = [] + updates: list[AgentRunResponseUpdate] = [] async for update in agent.run_stream("Start request"): updates.append(update) # Should have received a function call for the request info assert len(updates) > 0 # Find the function call update (RequestInfoEvent converted to function call) - function_call_update = None + function_call_update: AgentRunResponseUpdate | None = None for update in updates: - if update.contents and hasattr(update.contents[0], "name") and update.contents[0].name == "request_info": + if update.contents and hasattr(update.contents[0], "name") and update.contents[0].name == "request_info": # type: ignore[attr-defined] function_call_update = update break assert function_call_update is not None, "Should have received a request_info function call" - function_call = function_call_update.contents[0] + function_call: FunctionCallContent = function_call_update.contents[0] # type: ignore[assignment] # Verify the function call has expected structure assert function_call.call_id is not None @@ -230,7 +231,7 @@ class TestWorkflowAgent: raise ValueError("Unsupported message type") # Create a simple workflow - executor = _Executor() + executor = _Executor(id="test") workflow = WorkflowBuilder().set_start_executor(executor).build() # Try to create an agent with unsupported input types diff --git a/python/packages/main/tests/workflow/test_workflow_states.py b/python/packages/main/tests/workflow/test_workflow_states.py index c419682c7c..bef26a4c98 100644 --- a/python/packages/main/tests/workflow/test_workflow_states.py +++ b/python/packages/main/tests/workflow/test_workflow_states.py @@ -1,5 +1,8 @@ # Copyright (c) Microsoft. All rights reserved. +from dataclasses import dataclass +from typing import Any + import pytest from agent_framework import ( @@ -155,3 +158,59 @@ async def test_run_includes_status_events_idle_with_requests(): assert len(timeline) >= 3 assert timeline[-2].state == WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS assert timeline[-1].state == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + + +@dataclass +class SnapshotRequest(RequestInfoMessage): + prompt: str = "" + draft: str = "" + iteration: int = 0 + + +class SnapshotRequester(Executor): + """Executor that emits a rich RequestInfoMessage for persistence tests.""" + + def __init__(self, id: str, prompt: str, draft: str) -> None: + super().__init__(id=id) + self._prompt = prompt + self._draft = draft + + @handler + async def ask(self, _: str, ctx: WorkflowContext[SnapshotRequest]) -> None: # pragma: no cover - simple helper + await ctx.send_message(SnapshotRequest(prompt=self._prompt, draft=self._draft, iteration=1)) + + +async def test_request_info_executor_tracks_pending_requests_via_shared_state(): + prompt = "Review the launch copy" + draft = "Limited edition grinder now $249" + requester = SnapshotRequester(id="snapshot_req", prompt=prompt, draft=draft) + request_info = RequestInfoExecutor(id="request_info") + + wf = WorkflowBuilder().set_start_executor(requester).add_edge(requester, request_info).build() + + events = [event async for event in wf.run_stream("start")] + assert any(isinstance(event, RequestInfoEvent) for event in events) + + pending_map: dict[str, Any] = await wf._shared_state.get(RequestInfoExecutor._PENDING_SHARED_STATE_KEY) # type: ignore[reportPrivateUsage] + assert isinstance(pending_map, dict) + assert len(pending_map) == 1 + snapshot: dict[str, Any] = next(iter(pending_map.values())) + assert snapshot["prompt"] == prompt + assert snapshot["draft"] == draft + assert snapshot.get("iteration") == 1 + + request_id: str = snapshot["request_id"] + + request_info_resume = RequestInfoExecutor(id="request_info_resume") + resume_context: WFContext[Any] = WFContext( + executor_id=request_info_resume.id, + source_executor_ids=[wf.__class__.__name__], + shared_state=wf._shared_state, # type: ignore[reportPrivateUsage] + runner_context=wf._runner_context, # type: ignore[reportPrivateUsage] + ) + + await request_info_resume.handle_response("approve", request_id, resume_context) + + updated_pending: dict[str, Any] = await wf._shared_state.get(RequestInfoExecutor._PENDING_SHARED_STATE_KEY) # type: ignore[reportPrivateUsage] + assert isinstance(updated_pending, dict) + assert request_id not in updated_pending diff --git a/python/samples/getting_started/workflow/README.md b/python/samples/getting_started/workflow/README.md index 4b879e10dd..8758c53b6d 100644 --- a/python/samples/getting_started/workflow/README.md +++ b/python/samples/getting_started/workflow/README.md @@ -43,6 +43,7 @@ Once comfortable with these, explore the rest of the samples below. | Sample | File | Concepts | |---|---|---| | Checkpoint & Resume | [checkpoint/checkpoint_with_resume.py](./checkpoint/checkpoint_with_resume.py) | Create checkpoints, inspect them, and resume execution | +| Checkpoint & HITL Resume | [checkpoint/checkpoint_with_human_in_the_loop.py](./checkpoint/checkpoint_with_human_in_the_loop.py) | Combine checkpointing with human approvals and resume pending HITL requests | ### composition | Sample | File | Concepts | diff --git a/python/samples/getting_started/workflow/_start-here/step1_executors_and_edges.py b/python/samples/getting_started/workflow/_start-here/step1_executors_and_edges.py index 173c578a51..7744018c75 100644 --- a/python/samples/getting_started/workflow/_start-here/step1_executors_and_edges.py +++ b/python/samples/getting_started/workflow/_start-here/step1_executors_and_edges.py @@ -50,7 +50,7 @@ Prerequisites # - Compute a result # - Forward that result to downstream node(s) using ctx.send_message(result) class UpperCase(Executor): - def __init__(self, id: str | None = None): + def __init__(self, id: str): super().__init__(id=id) @handler diff --git a/python/samples/getting_started/workflow/agents/foundry_chat_agents_streaming.py b/python/samples/getting_started/workflow/agents/foundry_chat_agents_streaming.py index 20f04c2a00..e7bf4b0ffa 100644 --- a/python/samples/getting_started/workflow/agents/foundry_chat_agents_streaming.py +++ b/python/samples/getting_started/workflow/agents/foundry_chat_agents_streaming.py @@ -1,108 +1,34 @@ # Copyright (c) Microsoft. All rights reserved. -# import asyncio - -# from agent_framework.foundry import FoundryChatClient -# from agent_framework import AgentRunUpdateEvent, WorkflowBuilder, WorkflowCompletedEvent -# from azure.identity.aio import AzureCliCredential - -# """ -# Sample: Agents in a workflow with streaming - -# A Writer agent generates content, then a Reviewer agent critiques it. -# The workflow uses streaming so you can observe incremental AgentRunUpdateEvent chunks as each agent produces tokens. - -# Purpose: -# Show how to wire chat agents directly into a WorkflowBuilder pipeline where agents are auto wrapped as executors. - -# Demonstrate: -# - Automatic streaming of agent deltas via AgentRunUpdateEvent. -# - A simple console aggregator that groups updates by executor id and prints them as they arrive. -# - A final WorkflowCompletedEvent that contains the reviewer outcome after both agents finish. - -# Prerequisites: -# - Foundry Agent Service configured, along with the required environment variables. -# - Authentication via azure-identity. Use AzureCliCredential and run az login before executing the sample. -# - Basic familiarity with WorkflowBuilder, edges, events, and streaming runs. -# """ - - -# async def main(): -# """Build and run a simple two node agent workflow: Writer then Reviewer.""" -# # Create the Foundry chat client. -# async with ( -# AzureCliCredential() as credential, -# FoundryChatClient(async_credential=credential).create_agent( -# name="Writer", -# instructions=( -# "You are an excellent content writer.You create new content and edit contents based on the feedback." -# ), -# ) as writer_agent, -# FoundryChatClient(async_credential=credential).create_agent( -# name="Reviewer", -# instructions=( -# "You are an excellent content reviewer." -# "Provide actionable feedback to the writer about the provided content." -# "Provide the feedback in the most concise manner possible." -# ), -# ) as reviewer_agent, -# ): -# # Build the workflow using the fluent builder. -# # Set the start node and connect an edge from writer to reviewer. -# workflow = WorkflowBuilder().set_start_executor(writer_agent).add_edge(writer_agent, reviewer_agent).build() - -# # Stream events from the workflow. We aggregate partial token updates per executor for readable output. -# completed_event: WorkflowCompletedEvent | None = None -# last_executor_id = None - -# async for event in workflow.run_stream( -# "Create a slogan for a new electric SUV that is affordable and fun to drive." -# ): -# if isinstance(event, AgentRunUpdateEvent): -# # AgentRunUpdateEvent contains incremental text deltas from the underlying agent. -# # Print a prefix when the executor changes, then append updates on the same line. -# eid = event.executor_id -# if eid != last_executor_id: -# if last_executor_id is not None: -# print() -# print(f"{eid}:", end=" ", flush=True) -# last_executor_id = eid -# print(event.data, end="", flush=True) -# elif isinstance(event, WorkflowCompletedEvent): -# # Terminal event with the final reviewer output. -# completed_event = event - -# # Print the final consolidated reviewer result. -# if completed_event: -# print("\n===== Final Output =====") -# print(completed_event.data) - -# """ -# Sample Output: - -# writer_agent: Charge Up Your Journey. Fun, Affordable, Electric. -# reviewer_agent: Clear message, but consider highlighting SUV specific benefits -# (space, versatility) for stronger impact. Try more vivid language to evoke -# excitement. Example: "Big on Space. Big on Fun. Electric for Everyone." -# ===== Final Output ===== -# Clear message, but consider highlighting SUV specific benefits (space, versatility) -# for stronger impact. Try more vivid language to evoke excitement. Example: -# "Big on Space. Big on Fun. Electric for Everyone." -# """ - - -# if __name__ == "__main__": -# asyncio.run(main()) - import asyncio +from collections.abc import Awaitable, Callable from contextlib import AsyncExitStack from typing import Any -from collections.abc import Awaitable, Callable -from agent_framework.foundry import FoundryChatClient from agent_framework import AgentRunUpdateEvent, WorkflowBuilder, WorkflowCompletedEvent +from agent_framework.foundry import FoundryChatClient from azure.identity.aio import AzureCliCredential +""" +Sample: Agents in a workflow with streaming + +A Writer agent generates content, then a Reviewer agent critiques it. +The workflow uses streaming so you can observe incremental AgentRunUpdateEvent chunks as each agent produces tokens. + +Purpose: +Show how to wire chat agents directly into a WorkflowBuilder pipeline where agents are auto wrapped as executors. + +Demonstrate: +- Automatic streaming of agent deltas via AgentRunUpdateEvent. +- A simple console aggregator that groups updates by executor id and prints them as they arrive. +- A final WorkflowCompletedEvent that contains the reviewer outcome after both agents finish. + +Prerequisites: +- Foundry Agent Service configured, along with the required environment variables. +- Authentication via azure-identity. Use AzureCliCredential and run az login before executing the sample. +- Basic familiarity with WorkflowBuilder, edges, events, and streaming runs. +""" + async def create_foundry_agent() -> tuple[Callable[..., Awaitable[Any]], Callable[[], Awaitable[None]]]: """Helper method to create a Foundry agent factory and a close function. diff --git a/python/samples/getting_started/workflow/agents/workflow_as_agent_human_in_the_loop.py b/python/samples/getting_started/workflow/agents/workflow_as_agent_human_in_the_loop.py index c1af5d63cc..e80c181128 100644 --- a/python/samples/getting_started/workflow/agents/workflow_as_agent_human_in_the_loop.py +++ b/python/samples/getting_started/workflow/agents/workflow_as_agent_human_in_the_loop.py @@ -1,27 +1,31 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio +import sys from dataclasses import dataclass +from pathlib import Path -from agent_framework import ( +# Ensure local getting_started package can be imported when running as a script. +_SAMPLES_ROOT = Path(__file__).resolve().parents[3] +if str(_SAMPLES_ROOT) not in sys.path: + sys.path.insert(0, str(_SAMPLES_ROOT)) + +from agent_framework import ( # noqa: E402 ChatMessage, + Executor, FunctionCallContent, FunctionResultContent, - Role, -) -from agent_framework.openai import OpenAIChatClient -from agent_framework import ( - Executor, RequestInfoExecutor, RequestInfoMessage, RequestResponse, + Role, WorkflowAgent, WorkflowBuilder, WorkflowContext, handler, ) - -from samples.getting_started.workflow.agents.workflow_as_agent_reflection_pattern import ( +from agent_framework.openai import OpenAIChatClient # noqa: E402 +from getting_started.workflow.agents.workflow_as_agent_reflection_pattern import ( # noqa: E402 ReviewRequest, ReviewResponse, Worker, @@ -56,8 +60,9 @@ class HumanReviewRequest(RequestInfoMessage): class ReviewerWithHumanInTheLoop(Executor): """Executor that always escalates reviews to a human manager.""" - def __init__(self, worker_id: str, request_info_id: str) -> None: - super().__init__() + def __init__(self, worker_id: str, request_info_id: str, reviewer_id: str | None = None) -> None: + unique_id = reviewer_id or f"{worker_id}-reviewer" + super().__init__(id=unique_id) self._worker_id = worker_id self._request_info_id = request_info_id @@ -96,8 +101,8 @@ async def main() -> None: # Create executors for the workflow. print("Creating chat client and executors...") mini_chat_client = OpenAIChatClient(ai_model_id="gpt-4.1-nano") - worker = Worker(chat_client=mini_chat_client) - request_info_executor = RequestInfoExecutor() + worker = Worker(id="sub-worker", chat_client=mini_chat_client) + request_info_executor = RequestInfoExecutor(id="request_info") reviewer = ReviewerWithHumanInTheLoop(worker_id=worker.id, request_info_id=request_info_executor.id) print("Building workflow with Worker ↔ Reviewer cycle...") diff --git a/python/samples/getting_started/workflow/agents/workflow_as_agent_reflection_pattern.py b/python/samples/getting_started/workflow/agents/workflow_as_agent_reflection_pattern.py index 1dd66da437..f9d9e9792c 100644 --- a/python/samples/getting_started/workflow/agents/workflow_as_agent_reflection_pattern.py +++ b/python/samples/getting_started/workflow/agents/workflow_as_agent_reflection_pattern.py @@ -4,9 +4,19 @@ import asyncio from dataclasses import dataclass from uuid import uuid4 -from agent_framework import AgentRunResponseUpdate, ChatClientProtocol, ChatMessage, Contents, Role +from agent_framework import ( + AgentRunResponseUpdate, + AgentRunUpdateEvent, + ChatClientProtocol, + ChatMessage, + Contents, + Executor, + Role, + WorkflowBuilder, + WorkflowContext, + handler, +) from agent_framework.openai import OpenAIChatClient -from agent_framework import AgentRunUpdateEvent, Executor, WorkflowBuilder, WorkflowContext, handler from pydantic import BaseModel """ @@ -54,8 +64,8 @@ class ReviewResponse: class Reviewer(Executor): """Executor that reviews agent responses and provides structured feedback.""" - def __init__(self, chat_client: ChatClientProtocol) -> None: - super().__init__() + def __init__(self, id: str, chat_client: ChatClientProtocol) -> None: + super().__init__(id=id) self._chat_client = chat_client @handler @@ -106,8 +116,8 @@ class Reviewer(Executor): class Worker(Executor): """Executor that generates responses and incorporates feedback when necessary.""" - def __init__(self, chat_client: ChatClientProtocol) -> None: - super().__init__() + def __init__(self, id: str, chat_client: ChatClientProtocol) -> None: + super().__init__(id=id) self._chat_client = chat_client self._pending_requests: dict[str, tuple[ReviewRequest, list[ChatMessage]]] = {} @@ -189,8 +199,8 @@ async def main() -> None: print("Creating chat client and executors...") mini_chat_client = OpenAIChatClient(ai_model_id="gpt-4.1-nano") chat_client = OpenAIChatClient(ai_model_id="gpt-4.1") - reviewer = Reviewer(chat_client=chat_client) - worker = Worker(chat_client=mini_chat_client) + reviewer = Reviewer(id="reviewer", chat_client=chat_client) + worker = Worker(id="worker", chat_client=mini_chat_client) print("Building workflow with Worker ↔ Reviewer cycle...") agent = ( diff --git a/python/samples/getting_started/workflow/checkpoint/checkpoint_with_human_in_the_loop.py b/python/samples/getting_started/workflow/checkpoint/checkpoint_with_human_in_the_loop.py new file mode 100644 index 0000000000..06b75f8148 --- /dev/null +++ b/python/samples/getting_started/workflow/checkpoint/checkpoint_with_human_in_the_loop.py @@ -0,0 +1,483 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +from collections.abc import AsyncIterable +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from agent_framework import ( + AgentExecutor, + AgentExecutorRequest, + AgentExecutorResponse, + ChatMessage, + Executor, + FileCheckpointStorage, + RequestInfoEvent, + RequestInfoExecutor, + RequestInfoMessage, + RequestResponse, + Role, + WorkflowBuilder, + WorkflowCompletedEvent, + WorkflowContext, + WorkflowRunState, + WorkflowStatusEvent, + handler, +) +from agent_framework.azure import AzureChatClient +from azure.identity import AzureCliCredential + +# NOTE: the Azure client imports above are real dependencies. When running this +# sample outside of Azure-enabled environments you may wish to swap in the +# `agent_framework.builtin` chat client or mock the writer executor. We keep the +# concrete import here so readers can see an end-to-end configuration. + +if TYPE_CHECKING: + from agent_framework import Workflow + from agent_framework._workflow._checkpoint import WorkflowCheckpoint + +""" +Sample: Checkpoint + human-in-the-loop quickstart. + +This getting-started sample keeps the moving pieces to a minimum: + +1. A brief is turned into a consistent prompt for an AI copywriter. +2. The copywriter (an `AgentExecutor`) drafts release notes. +3. A reviewer gateway routes every draft through `RequestInfoExecutor` so a human + can approve or request tweaks. +4. The workflow records checkpoints between each superstep so you can stop the + program, restart later, and optionally pre-supply human answers on resume. + +Key concepts demonstrated +------------------------- +- Minimal executor pipeline with checkpoint persistence. +- Human-in-the-loop pause/resume by pairing `RequestInfoExecutor` with + checkpoint restoration. +- Supplying responses at restore time (`run_stream_from_checkpoint(..., responses=...)`). + +Typical pause/resume flow +------------------------- +1. Run the workflow until a human approval request is emitted. +2. If the human is offline, exit the program. A checkpoint with + ``status=awaiting human response`` now exists. +3. Later, restart the script, select that checkpoint, and provide the stored + human decision when prompted to pre-supply responses. + Doing so applies the answer immediately on resume, so the system does **not** + re-emit the same `RequestInfoEvent`. +""" + +# Directory used for the sample's temporary checkpoint files. We isolate the +# demo artefacts so that repeated runs do not collide with other samples and so +# the clean-up step at the end of the script can simply delete the directory. +TEMP_DIR = Path(__file__).with_suffix("").parent / "tmp" / "checkpoints_hitl" +TEMP_DIR.mkdir(parents=True, exist_ok=True) + + +class BriefPreparer(Executor): + """Normalises the user brief and sends a single AgentExecutorRequest.""" + + # The first executor in the workflow. By keeping it tiny we make it easier + # to reason about the state that will later be captured in the checkpoint. + # It is responsible for tidying the human-provided brief and kicking off the + # agent run with a deterministic prompt structure. + + def __init__(self, id: str, agent_id: str) -> None: + super().__init__(id=id) + self._agent_id = agent_id + + @handler + async def prepare(self, brief: str, ctx: WorkflowContext[AgentExecutorRequest]) -> None: + # Collapse errant whitespace so the prompt is stable between runs. + normalized = " ".join(brief.split()).strip() + if not normalized.endswith("."): + normalized += "." + # Persist the cleaned brief in shared state so downstream executors and + # future checkpoints can recover the original intent. + await ctx.set_shared_state("brief", normalized) + prompt = ( + "You are drafting product release notes. Summarise the brief below in two sentences. " + "Keep it positive and end with a call to action.\n\n" + f"BRIEF: {normalized}" + ) + # Hand the prompt to the writer agent. We always route through the + # workflow context so the runtime can capture messages for checkpointing. + await ctx.send_message( + AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=prompt)], should_respond=True), + target_id=self._agent_id, + ) + + +@dataclass +class HumanApprovalRequest(RequestInfoMessage): + """Message sent to the human reviewer via RequestInfoExecutor.""" + + # These fields are intentionally simple because they are serialised into + # checkpoints. Keeping them primitive types guarantees the new + # `pending_requests_from_checkpoint` helper can reconstruct them on resume. + prompt: str = "" + draft: str = "" + iteration: int = 0 + + +class ReviewGateway(Executor): + """Routes agent drafts to humans and optionally back for revisions.""" + + def __init__(self, id: str, reviewer_id: str, writer_id: str, finalize_id: str) -> None: + super().__init__(id=id) + self._reviewer_id = reviewer_id + self._writer_id = writer_id + self._finalize_id = finalize_id + + @handler + async def on_agent_response( + self, + response: AgentExecutorResponse, + ctx: WorkflowContext[HumanApprovalRequest], + ) -> None: + # Capture the agent output so we can surface it to the reviewer and + # persist iterations. The `RequestInfoExecutor` relies on this state to + # rehydrate when checkpoints are restored. + draft = response.agent_run_response.text or "" + iteration = int((await ctx.get_state() or {}).get("iteration", 0)) + 1 + await ctx.set_state({"iteration": iteration, "last_draft": draft}) + # Emit a human approval request. Because this flows through + # RequestInfoExecutor it will pause the workflow until an answer is + # supplied either interactively or via pre-supplied responses. + await ctx.send_message( + HumanApprovalRequest( + prompt="Review the draft. Reply 'approve' or provide edit instructions.", + draft=draft, + iteration=iteration, + ), + target_id=self._reviewer_id, + ) + + @handler + async def on_human_feedback( + self, + feedback: RequestResponse[HumanApprovalRequest, str], + ctx: WorkflowContext[AgentExecutorRequest | str], + ) -> None: + # The RequestResponse wrapper gives us both the human data and the + # original request message, even when resuming from checkpoints. + reply = (feedback.data or "").strip() + state = await ctx.get_state() or {} + draft = state.get("last_draft") or (feedback.original_request.draft if feedback.original_request else "") + + if reply.lower() == "approve": + # When the human signs off we can short-circuit the workflow and + # send the approved draft to the final executor. + await ctx.send_message(draft, target_id=self._finalize_id) + return + + # Any other response loops us back to the writer with fresh guidance. + guidance = reply or "Tighten the copy and emphasise customer benefit." + iteration = int(state.get("iteration", 1)) + 1 + await ctx.set_state({"iteration": iteration, "last_draft": draft}) + prompt = ( + "Revise the launch note. Respond with the new copy only.\n\n" + f"Previous draft:\n{draft}\n\n" + f"Human guidance: {guidance}" + ) + await ctx.send_message( + AgentExecutorRequest(messages=[ChatMessage(Role.USER, text=prompt)], should_respond=True), + target_id=self._writer_id, + ) + + +class FinaliseExecutor(Executor): + """Publishes the approved text.""" + + @handler + async def publish(self, text: str, ctx: WorkflowContext[Any]) -> None: + # Store the output so diagnostics or a UI could fetch the final copy. + await ctx.set_state({"published_text": text}) + # Emit a workflow completion event so the runner stops cleanly. + await ctx.add_event(WorkflowCompletedEvent(text)) + + +def create_workflow(*, checkpoint_storage: FileCheckpointStorage | None = None) -> "Workflow": + """Assemble the workflow graph used by both the initial run and resume.""" + + # The Azure client is created once so our agent executor can issue calls to + # the hosted model. The agent id is stable across runs which keeps + # checkpoints deterministic. + chat_client = AzureChatClient(credential=AzureCliCredential()) + writer = AgentExecutor( + chat_client.create_agent( + instructions="Write concise, warm release notes that sound human and helpful.", + ), + id="writer", + ) + # RequestInfoExecutor is the lynchpin for human-in-the-loop: every draft is + # routed through it so checkpoints can pause while waiting for responses. + review = RequestInfoExecutor(id="request_info") + finalise = FinaliseExecutor(id="finalise") + gateway = ReviewGateway( + id="review_gateway", + reviewer_id=review.id, + writer_id=writer.id, + finalize_id=finalise.id, + ) + prepare = BriefPreparer(id="prepare_brief", agent_id=writer.id) + + # Wire the workflow DAG. Edges mirror the numbered steps described in the + # module docstring. Because `WorkflowBuilder` is declarative, reading these + # edges is often the quickest way to understand execution order. + builder = ( + WorkflowBuilder(max_iterations=6) + .set_start_executor(prepare) + .add_edge(prepare, writer) + .add_edge(writer, gateway) + .add_edge(gateway, review) + .add_edge(review, gateway) # human resumes loop + .add_edge(gateway, writer) # revisions + .add_edge(gateway, finalise) + ) + # Opt-in to persistence when the caller provides storage. The workflow + # object itself is identical whether or not checkpointing is enabled. + if checkpoint_storage: + builder = builder.with_checkpointing(checkpoint_storage=checkpoint_storage) + return builder.build() + + +def _render_checkpoint_summary(checkpoints: list["WorkflowCheckpoint"]) -> None: + """Pretty-print saved checkpoints with the new framework summaries.""" + + print("\nCheckpoint summary:") + for summary in [ + RequestInfoExecutor.checkpoint_summary(cp) for cp in sorted(checkpoints, key=lambda c: c.timestamp) + ]: + # Compose a single line per checkpoint so the user can scan the output + # and pick the resume point that still has outstanding human work. + line = ( + f"- {summary.checkpoint_id} | iter={summary.iteration_count} " + f"| targets={summary.targets} | states={summary.executor_states}" + ) + if summary.status: + line += f" | status={summary.status}" + if summary.draft_preview: + line += f" | draft_preview={summary.draft_preview}" + if summary.pending_requests: + line += f" | pending_request_id={summary.pending_requests[0].request_id}" + print(line) + + +def _print_events(events: list[Any]) -> tuple[WorkflowCompletedEvent | None, list[tuple[str, HumanApprovalRequest]]]: + """Echo workflow events to the console and collect outstanding requests.""" + + completed: WorkflowCompletedEvent | None = None + requests: list[tuple[str, HumanApprovalRequest]] = [] + + for event in events: + print(f"Event: {event}") + if isinstance(event, WorkflowCompletedEvent): + completed = event + elif isinstance(event, RequestInfoEvent) and isinstance(event.data, HumanApprovalRequest): + # Capture pending human approvals so the caller can ask the user for + # input after the current batch of events is processed. + requests.append((event.request_id, event.data)) + elif isinstance(event, WorkflowStatusEvent) and event.state in { + WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS, + WorkflowRunState.IDLE_WITH_PENDING_REQUESTS, + }: + print(f"Workflow state: {event.state.name}") + + return completed, requests + + +def _prompt_for_responses(requests: list[tuple[str, HumanApprovalRequest]]) -> dict[str, str] | None: + """Interactive CLI prompt for any live RequestInfo requests.""" + + if not requests: + return None + answers: dict[str, str] = {} + for request_id, request in requests: + # Keep the prompt conversational so testers can use the script without + # memorising the workflow APIs. + print("\n=== Human approval needed ===") + print(f"request_id: {request_id}") + if request.iteration: + print(f"Iteration: {request.iteration}") + print(request.prompt) + print("Draft: \n---\n" + request.draft + "\n---") + answer = input("Type 'approve' or enter revision guidance (or 'exit' to quit): ").strip() # noqa: ASYNC250 + if answer.lower() == "exit": + raise SystemExit("Stopped by user.") + answers[request_id] = answer + return answers + + +def _maybe_pre_supply_responses(cp: "WorkflowCheckpoint") -> dict[str, str] | None: + """Offer to collect responses before resuming a checkpoint.""" + + pending = RequestInfoExecutor.pending_requests_from_checkpoint(cp) + if not pending: + return None + + print( + "This checkpoint still has pending human input. Provide the responses now so the resume step " + "applies them immediately and does not re-emit the original RequestInfo event." + ) + choice = input("Pre-supply responses for this checkpoint? [y/N]: ").strip().lower() # noqa: ASYNC250 + if choice not in {"y", "yes"}: + return None + + answers: dict[str, str] = {} + for item in pending: + iteration = item.iteration or 0 + print(f"\nPending draft (iteration {iteration} | request_id={item.request_id}):") + draft_text = (item.draft or "").strip() + if draft_text: + # The shortened preview in the summary may truncate text; here we + # show the full draft so the reviewer can make an informed choice. + print("Draft:\n---\n" + draft_text + "\n---") + else: + print("Draft: [not captured in checkpoint payload - refer to your notes/log]") + prompt_text = (item.prompt or "Review the draft").strip() + print(prompt_text) + answer = input("Response ('approve' or guidance, 'exit' to abort): ").strip() # noqa: ASYNC250 + if answer.lower() == "exit": + raise SystemExit("Resume aborted by user.") + answers[item.request_id] = answer + return answers + + +async def _consume(stream: AsyncIterable[Any]) -> list[Any]: + """Materialise an async event stream into a list.""" + + return [event async for event in stream] + + +async def run_interactive_session(workflow: "Workflow", initial_message: str) -> WorkflowCompletedEvent | None: + """Run the workflow until it either finishes or pauses for human input.""" + + pending_responses: dict[str, str] | None = None + completed: WorkflowCompletedEvent | None = None + first = True + + while completed is None: + if first: + # Kick off the workflow with the initial brief. The returned events + # include RequestInfo events when the agent produces a draft. + events = await _consume(workflow.run_stream(initial_message)) + first = False + elif pending_responses: + # Feed any answers the user just typed back into the workflow. + events = await _consume(workflow.send_responses_streaming(pending_responses)) + else: + break + + completed, requests = _print_events(events) + pending_responses = _prompt_for_responses(requests) + + return completed + + +async def resume_from_checkpoint( + workflow: "Workflow", + checkpoint_id: str, + storage: FileCheckpointStorage, + pre_supplied: dict[str, str] | None, +) -> None: + """Resume a stored checkpoint and continue until completion or another pause.""" + + print(f"\nResuming from checkpoint: {checkpoint_id}") + events = await _consume( + workflow.run_stream_from_checkpoint( + checkpoint_id, + checkpoint_storage=storage, + responses=pre_supplied, + ) + ) + completed, requests = _print_events(events) + if pre_supplied and not requests and completed is None: + # When the checkpoint only needed the provided answers we let the user + # know the workflow is waiting for the next superstep (usually another + # agent response). + print("Pre-supplied responses applied automatically; workflow is now waiting for the next step.") + + pending = _prompt_for_responses(requests) + while completed is None and pending: + events = await _consume(workflow.send_responses_streaming(pending)) + completed, requests = _print_events(events) + pending = _prompt_for_responses(requests) + + if completed: + print(f"Workflow completed with: {completed.data}") + + +async def main() -> None: + """Entry point used by both the initial run and subsequent resumes.""" + + for file in TEMP_DIR.glob("*.json"): + # Start each execution with a clean slate so the demonstration is + # deterministic even if the directory had stale checkpoints. + file.unlink() + + storage = FileCheckpointStorage(storage_path=TEMP_DIR) + workflow = create_workflow(checkpoint_storage=storage) + + brief = ( + "Introduce our limited edition smart coffee grinder. Mention the $249 price, highlight the " + "sensor that auto-adjusts the grind, and invite customers to pre-order on the website." + ) + + print("Running workflow (human approval required)...") + completed = await run_interactive_session(workflow, initial_message=brief) + if completed: + print(f"Initial run completed with final copy: {completed.data}") + else: + print("Initial run paused for human input.") + + checkpoints = await storage.list_checkpoints() + if not checkpoints: + print("No checkpoints recorded.") + return + + # Show the user what is available before we prompt for the index. The + # summary helper keeps this output consistent with other tooling. + _render_checkpoint_summary(checkpoints) + + sorted_cps = sorted(checkpoints, key=lambda c: c.timestamp) + print("\nAvailable checkpoints:") + for idx, cp in enumerate(sorted_cps): + print(f" [{idx}] id={cp.checkpoint_id} iter={cp.iteration_count}") + + # For the pause/resume demo we typically pick the latest checkpoint whose summary + # status reads "awaiting human response" - that is the saved state that proves the + # workflow can rehydrate, collect the pending answer, and continue after a break. + selection = input("\nResume from which checkpoint? (press Enter to skip): ").strip() # noqa: ASYNC250 + if not selection: + print("No resume selected. Exiting.") + return + + try: + idx = int(selection) + except ValueError: + print("Invalid input; exiting.") + return + + if not 0 <= idx < len(sorted_cps): + print("Index out of range; exiting.") + return + + chosen = sorted_cps[idx] + summary = RequestInfoExecutor.checkpoint_summary(chosen) + if summary.status == "completed": + print("Selected checkpoint already reflects a completed workflow; nothing to resume.") + return + + # If the user wants, capture their decisions now so the resume call can + # push them into the workflow and avoid re-prompting. + pre_responses = _maybe_pre_supply_responses(chosen) + + resumed_workflow = create_workflow() + # Resume with a fresh workflow instance. The checkpoint carries the + # persistent state while this object holds the runtime wiring. + await resume_from_checkpoint(resumed_workflow, chosen.checkpoint_id, storage, pre_responses) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/workflow/checkpoint/checkpoint_with_resume.py b/python/samples/getting_started/workflow/checkpoint/checkpoint_with_resume.py index 42dbb789fd..c8f675147c 100644 --- a/python/samples/getting_started/workflow/checkpoint/checkpoint_with_resume.py +++ b/python/samples/getting_started/workflow/checkpoint/checkpoint_with_resume.py @@ -3,7 +3,7 @@ import asyncio import os from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any from agent_framework import ( AgentExecutor, @@ -12,6 +12,7 @@ from agent_framework import ( ChatMessage, Executor, FileCheckpointStorage, + RequestInfoExecutor, Role, WorkflowBuilder, WorkflowCompletedEvent, @@ -21,6 +22,10 @@ from agent_framework import ( from agent_framework.azure import AzureChatClient from azure.identity import AzureCliCredential +if TYPE_CHECKING: + from agent_framework import Workflow + from agent_framework._workflow._checkpoint import WorkflowCheckpoint + """ Sample: Checkpointing and Resuming a Workflow (with an Agent stage) @@ -87,7 +92,7 @@ class UpperCaseExecutor(Executor): class SubmitToLowerAgent(Executor): """Builds an AgentExecutorRequest to send to the lowercasing agent while keeping shared-state visibility.""" - def __init__(self, agent_id: str, id: str | None = None): + def __init__(self, id: str, agent_id: str): super().__init__(id=id) self._agent_id = agent_id @@ -132,10 +137,6 @@ class FinalizeFromAgent(Executor): class ReverseTextExecutor(Executor): """Reverses the input text and persists local state.""" - def __init__(self, id: str): - """Initialize the executor with an ID.""" - super().__init__(id=id) - @handler async def reverse_text(self, text: str, ctx: WorkflowContext[str]) -> None: result = text[::-1] @@ -154,15 +155,10 @@ class ReverseTextExecutor(Executor): await ctx.send_message(result) -async def main(): - # Clear existing checkpoints in this sample directory for a clean run. - checkpoint_dir = Path(TEMP_DIR) - for file in checkpoint_dir.glob("*.json"): - file.unlink() - +def create_workflow(checkpoint_storage: FileCheckpointStorage) -> "Workflow": # Instantiate the pipeline executors. - upper_case_executor = UpperCaseExecutor(id="upper_case_executor") - reverse_text_executor = ReverseTextExecutor(id="reverse_text_executor") + upper_case_executor = UpperCaseExecutor(id="upper-case") + reverse_text_executor = ReverseTextExecutor(id="reverse-text") # Configure the agent stage that lowercases the text. chat_client = AzureChatClient(credential=AzureCliCredential()) @@ -174,14 +170,11 @@ async def main(): ) # Bridge to the agent and terminalization stage. - submit_lower = SubmitToLowerAgent(agent_id=lower_agent.id, id="submit_lower") + submit_lower = SubmitToLowerAgent(id="submit_lower", agent_id=lower_agent.id) finalize = FinalizeFromAgent(id="finalize") - # Backing store for checkpoints written by with_checkpointing. - checkpoint_storage = FileCheckpointStorage(storage_path=TEMP_DIR) - # Build the workflow with checkpointing enabled. - workflow = ( + return ( WorkflowBuilder(max_iterations=5) .add_edge(upper_case_executor, reverse_text_executor) # Uppercase -> Reverse .add_edge(reverse_text_executor, submit_lower) # Reverse -> Build Agent request @@ -192,6 +185,40 @@ async def main(): .build() ) +def _render_checkpoint_summary(checkpoints: list["WorkflowCheckpoint"]) -> None: + """Display human-friendly checkpoint metadata using framework summaries.""" + + if not checkpoints: + return + + print("\nCheckpoint summary:") + for cp in sorted(checkpoints, key=lambda c: c.timestamp): + summary = RequestInfoExecutor.checkpoint_summary(cp) + msg_count = sum(len(v) for v in cp.messages.values()) + state_keys = sorted(cp.executor_states.keys()) + orig = cp.shared_state.get("original_input") + upper = cp.shared_state.get("upper_output") + + line = ( + f"- {summary.checkpoint_id} | iter={summary.iteration_count} | messages={msg_count} | states={state_keys}" + ) + if summary.status: + line += f" | status={summary.status}" + line += f" | shared_state: original_input='{orig}', upper_output='{upper}'" + print(line) + + +async def main(): + # Clear existing checkpoints in this sample directory for a clean run. + checkpoint_dir = Path(TEMP_DIR) + for file in checkpoint_dir.glob("*.json"): + file.unlink() + + # Backing store for checkpoints written by with_checkpointing. + checkpoint_storage = FileCheckpointStorage(storage_path=TEMP_DIR) + + workflow = create_workflow(checkpoint_storage=checkpoint_storage) + # Run the full workflow once and observe events as they stream. print("Running workflow with initial message...") async for event in workflow.run_stream(message="hello world"): @@ -206,26 +233,20 @@ async def main(): # All checkpoints created by this run share the same workflow_id. workflow_id = all_checkpoints[0].workflow_id - # Dump a quick summary including shared_state keys to illustrate what persisted. - print("\nCheckpoint summary:") - for cp in sorted(all_checkpoints, key=lambda c: c.timestamp): - msg_count = sum(len(v) for v in cp.messages.values()) - state_keys = sorted(list(cp.executor_states.keys())) if hasattr(cp, "executor_states") else [] - orig = cp.shared_state.get("original_input") if hasattr(cp, "shared_state") else None - upper = cp.shared_state.get("upper_output") if hasattr(cp, "shared_state") else None - print( - f"- {cp.checkpoint_id} | " - f"iter={cp.iteration_count} | messages={msg_count} | states={state_keys} | " - f"shared_state: original_input='{orig}', upper_output='{upper}'" - ) + _render_checkpoint_summary(all_checkpoints) # Offer an interactive selection of checkpoints to resume from. sorted_cps = sorted([cp for cp in all_checkpoints if cp.workflow_id == workflow_id], key=lambda c: c.timestamp) print("\nAvailable checkpoints to resume from:") for idx, cp in enumerate(sorted_cps): + summary = RequestInfoExecutor.checkpoint_summary(cp) + line = f" [{idx}] id={summary.checkpoint_id} iter={summary.iteration_count}" + if summary.status: + line += f" status={summary.status}" msg_count = sum(len(v) for v in cp.messages.values()) - print(f" [{idx}] id={cp.checkpoint_id} iter={cp.iteration_count} messages={msg_count}") + line += f" messages={msg_count}" + print(line) user_input = input( "\nEnter checkpoint index (or paste checkpoint id) to resume from, or press Enter to skip resume: " @@ -256,15 +277,7 @@ async def main(): # You can reuse the same workflow graph definition and resume from a prior checkpoint. # This second workflow instance does not enable checkpointing to show that resumption # reads from stored state but need not write new checkpoints. - new_workflow = ( - WorkflowBuilder(max_iterations=5) - .add_edge(upper_case_executor, reverse_text_executor) - .add_edge(reverse_text_executor, submit_lower) - .add_edge(submit_lower, lower_agent) - .add_edge(lower_agent, finalize) - .set_start_executor(upper_case_executor) - .build() - ) + new_workflow = create_workflow(checkpoint_storage=checkpoint_storage) print(f"\nResuming from checkpoint: {chosen_cp_id}") async for event in new_workflow.run_stream_from_checkpoint(chosen_cp_id, checkpoint_storage=checkpoint_storage): @@ -275,15 +288,15 @@ async def main(): Running workflow with initial message... UpperCaseExecutor: 'hello world' -> 'HELLO WORLD' - Event: ExecutorInvokedEvent(executor_id=upper_case_executor) + Event: ExecutorInvokeEvent(executor_id=upper_case_executor) Event: ExecutorCompletedEvent(executor_id=upper_case_executor) ReverseTextExecutor: 'HELLO WORLD' -> 'DLROW OLLEH' - Event: ExecutorInvokedEvent(executor_id=reverse_text_executor) + Event: ExecutorInvokeEvent(executor_id=reverse_text_executor) Event: ExecutorCompletedEvent(executor_id=reverse_text_executor) LowerAgent (shared_state): original_input='hello world', upper_output='HELLO WORLD' - Event: ExecutorInvokedEvent(executor_id=submit_lower) - Event: ExecutorInvokedEvent(executor_id=lower_agent) - Event: ExecutorInvokedEvent(executor_id=finalize) + Event: ExecutorInvokeEvent(executor_id=submit_lower) + Event: ExecutorInvokeEvent(executor_id=lower_agent) + Event: ExecutorInvokeEvent(executor_id=finalize) Event: WorkflowCompletedEvent(data=dlrow olleh) Checkpoint summary: @@ -300,9 +313,9 @@ async def main(): Resuming from checkpoint: a78c345a-e5d9-45ba-82c0-cb725452d91b LowerAgent (shared_state): original_input='hello world', upper_output='HELLO WORLD' - Resumed Event: ExecutorInvokedEvent(executor_id=submit_lower) - Resumed Event: ExecutorInvokedEvent(executor_id=lower_agent) - Resumed Event: ExecutorInvokedEvent(executor_id=finalize) + Resumed Event: ExecutorInvokeEvent(executor_id=submit_lower) + Resumed Event: ExecutorInvokeEvent(executor_id=lower_agent) + Resumed Event: ExecutorInvokeEvent(executor_id=finalize) Resumed Event: WorkflowCompletedEvent(data=dlrow olleh) """ # noqa: E501 diff --git a/python/samples/getting_started/workflow/control-flow/simple_loop.py b/python/samples/getting_started/workflow/control-flow/simple_loop.py index d38c314b02..6a16a73ed9 100644 --- a/python/samples/getting_started/workflow/control-flow/simple_loop.py +++ b/python/samples/getting_started/workflow/control-flow/simple_loop.py @@ -50,7 +50,7 @@ class GuessNumberExecutor(Executor): def __init__(self, bound: tuple[int, int], id: str | None = None): """Initialize the executor with a target number.""" - super().__init__(id=id) + super().__init__(id=id or "guess_number") self._lower = bound[0] self._upper = bound[1] @@ -83,7 +83,7 @@ class SubmitToJudgeAgent(Executor): """Send the numeric guess to a judge agent which replies ABOVE/BELOW/MATCHED.""" def __init__(self, judge_agent_id: str, target: int, id: str | None = None): - super().__init__(id=id) + super().__init__(id=id or "submit_to_judge") self._judge_agent_id = judge_agent_id self._target = target diff --git a/python/samples/getting_started/workflow/human-in-the-loop/guessing_game_with_human_input.py b/python/samples/getting_started/workflow/human-in-the-loop/guessing_game_with_human_input.py index 1be8abf83b..c9a9880b68 100644 --- a/python/samples/getting_started/workflow/human-in-the-loop/guessing_game_with_human_input.py +++ b/python/samples/getting_started/workflow/human-in-the-loop/guessing_game_with_human_input.py @@ -87,7 +87,7 @@ class TurnManager(Executor): """ def __init__(self, id: str | None = None): - super().__init__(id=id) + super().__init__(id=id or "turn_manager") @handler async def start(self, _: str, ctx: WorkflowContext[AgentExecutorRequest]) -> None: diff --git a/python/samples/getting_started/workflow/parallelism/fan_out_fan_in_edges.py b/python/samples/getting_started/workflow/parallelism/fan_out_fan_in_edges.py index d28a2ca6a0..c2c95392c4 100644 --- a/python/samples/getting_started/workflow/parallelism/fan_out_fan_in_edges.py +++ b/python/samples/getting_started/workflow/parallelism/fan_out_fan_in_edges.py @@ -43,7 +43,7 @@ class DispatchToExperts(Executor): """Dispatches the incoming prompt to all expert agent executors for parallel processing (fan out).""" def __init__(self, expert_ids: list[str], id: str | None = None): - super().__init__(id) + super().__init__(id=id or "dispatch_to_experts") self._expert_ids = expert_ids @handler @@ -71,7 +71,7 @@ class AggregateInsights(Executor): """Aggregates expert agent responses into a single consolidated result (fan in).""" def __init__(self, expert_ids: list[str], id: str | None = None): - super().__init__(id) + super().__init__(id=id or "aggregate_insights") self._expert_ids = expert_ids @handler diff --git a/python/samples/getting_started/workflow/parallelism/map_reduce_and_visualization.py b/python/samples/getting_started/workflow/parallelism/map_reduce_and_visualization.py index 160b62ec6b..6098da7303 100644 --- a/python/samples/getting_started/workflow/parallelism/map_reduce_and_visualization.py +++ b/python/samples/getting_started/workflow/parallelism/map_reduce_and_visualization.py @@ -61,7 +61,7 @@ class Split(Executor): def __init__(self, map_executor_ids: list[str], id: str | None = None): """Store mapper ids so we can assign non overlapping ranges per mapper.""" - super().__init__(id) + super().__init__(id=id or "split") self._map_executor_ids = map_executor_ids @handler @@ -145,7 +145,7 @@ class Shuffle(Executor): def __init__(self, reducer_ids: list[str], id: str | None = None): """Remember reducer ids so we can partition work deterministically.""" - super().__init__(id) + super().__init__(id=id or "shuffle") self._reducer_ids = reducer_ids @handler diff --git a/python/samples/getting_started/workflow/visualization/concurrent_with_visualization.py b/python/samples/getting_started/workflow/visualization/concurrent_with_visualization.py index dfc4c3b0fd..5132299f43 100644 --- a/python/samples/getting_started/workflow/visualization/concurrent_with_visualization.py +++ b/python/samples/getting_started/workflow/visualization/concurrent_with_visualization.py @@ -40,7 +40,7 @@ class DispatchToExperts(Executor): """Dispatches the incoming prompt to all expert agent executors (fan-out).""" def __init__(self, expert_ids: list[str], id: str | None = None): - super().__init__(id) + super().__init__(id=id or "dispatch_to_experts") self._expert_ids = expert_ids @handler @@ -67,7 +67,7 @@ class AggregateInsights(Executor): """Aggregates expert agent responses into a single consolidated result (fan-in).""" def __init__(self, expert_ids: list[str], id: str | None = None): - super().__init__(id) + super().__init__(id=id or "aggregate_insights") self._expert_ids = expert_ids @handler