From c2c8ec3d4e6602bfef15d7a2849a565a7364ba2b Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Wed, 8 Oct 2025 10:04:06 -0700 Subject: [PATCH] Python: Reorganize workflows modules (#1282) * Reorganize modules * Fix unit tests * Remove submodules --- .../agent_framework/_workflows/__init__.py | 8 +- .../agent_framework/_workflows/__init__.pyi | 8 +- .../agent_framework/_workflows/_events.py | 2 +- .../agent_framework/_workflows/_executor.py | 841 +----------------- .../agent_framework/_workflows/_magentic.py | 5 +- .../_workflows/_request_info_executor.py | 841 ++++++++++++++++++ .../agent_framework/_workflows/_runner.py | 6 +- .../agent_framework/_workflows/_validation.py | 3 +- .../agent_framework/_workflows/_workflow.py | 5 +- .../_workflows/_workflow_executor.py | 4 +- .../tests/workflow/test_checkpoint_decode.py | 2 +- .../test_request_info_executor_rehydrate.py | 2 +- 12 files changed, 873 insertions(+), 854 deletions(-) create mode 100644 python/packages/core/agent_framework/_workflows/_request_info_executor.py diff --git a/python/packages/core/agent_framework/_workflows/__init__.py b/python/packages/core/agent_framework/_workflows/__init__.py index f9e292465a..5af4049128 100644 --- a/python/packages/core/agent_framework/_workflows/__init__.py +++ b/python/packages/core/agent_framework/_workflows/__init__.py @@ -48,9 +48,6 @@ from ._events import ( ) from ._executor import ( Executor, - RequestInfoExecutor, - RequestInfoMessage, - RequestResponse, handler, ) from ._function_executor import FunctionExecutor, executor @@ -76,6 +73,11 @@ from ._magentic import ( MagenticStartMessage, StandardMagenticManager, ) +from ._request_info_executor import ( + RequestInfoExecutor, + RequestInfoMessage, + RequestResponse, +) from ._runner import Runner from ._runner_context import ( InProcRunnerContext, diff --git a/python/packages/core/agent_framework/_workflows/__init__.pyi b/python/packages/core/agent_framework/_workflows/__init__.pyi index 5b0dcc799d..db8f87eb4e 100644 --- a/python/packages/core/agent_framework/_workflows/__init__.pyi +++ b/python/packages/core/agent_framework/_workflows/__init__.pyi @@ -46,9 +46,6 @@ from ._events import ( ) from ._executor import ( Executor, - RequestInfoExecutor, - RequestInfoMessage, - RequestResponse, handler, ) from ._function_executor import FunctionExecutor, executor @@ -74,6 +71,11 @@ from ._magentic import ( MagenticStartMessage, StandardMagenticManager, ) +from ._request_info_executor import ( + RequestInfoExecutor, + RequestInfoMessage, + RequestResponse, +) from ._runner import Runner from ._runner_context import ( InProcRunnerContext, diff --git a/python/packages/core/agent_framework/_workflows/_events.py b/python/packages/core/agent_framework/_workflows/_events.py index 90a2f1f174..58e699e2b4 100644 --- a/python/packages/core/agent_framework/_workflows/_events.py +++ b/python/packages/core/agent_framework/_workflows/_events.py @@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Any, TypeAlias from agent_framework import AgentRunResponse, AgentRunResponseUpdate if TYPE_CHECKING: - from ._executor import RequestInfoMessage + from ._request_info_executor import RequestInfoMessage class WorkflowEventSource(str, Enum): diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index 80dd30e402..1f822e870a 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -2,61 +2,29 @@ import contextlib import functools -import importlib import inspect -import json import logging -import uuid -from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence -from dataclasses import asdict, dataclass, field, fields, is_dataclass -from textwrap import shorten -from typing import Any, ClassVar, Generic, TypeVar, cast +from collections.abc import Awaitable, Callable +from typing import Any, TypeVar from ..observability import create_processing_span -from ._checkpoint import WorkflowCheckpoint from ._events import ( ExecutorCompletedEvent, ExecutorFailedEvent, ExecutorInvokedEvent, - RequestInfoEvent, WorkflowErrorDetails, _framework_event_origin, # type: ignore[reportPrivateUsage] ) from ._model_utils import DictConvertible -from ._runner_context import Message, RunnerContext, _decode_checkpoint_value # type: ignore +from ._runner_context import Message, RunnerContext # type: ignore from ._shared_state import SharedState from ._typing_utils import is_instance_of from ._workflow_context import WorkflowContext, validate_function_signature logger = logging.getLogger(__name__) + + # region Executor - - -@dataclass -class PendingRequestDetails: - """Lightweight information about a pending request captured in a checkpoint.""" - - request_id: str - prompt: str | None = None - draft: str | None = None - iteration: int | None = None - source_executor_id: str | None = None - original_request: "RequestInfoMessage | dict[str, Any] | None" = None - - -@dataclass -class WorkflowCheckpointSummary: - """Human-readable summary of a workflow checkpoint.""" - - checkpoint_id: str - iteration_count: int - targets: list[str] - executor_states: list[str] - status: str - draft_preview: str | None - pending_requests: list[PendingRequestDetails] - - class Executor(DictConvertible): """Base class for all workflow executors that process messages and perform computations. @@ -542,802 +510,3 @@ def handler( # endregion: Handler Decorator - - -# region Request/Response Types -@dataclass -class RequestInfoMessage: - """Base class for all request messages in workflows. - - Any message that should be routed to the RequestInfoExecutor for external - handling must inherit from this class. This ensures type safety and makes - the request/response pattern explicit. - """ - - request_id: str = field(default_factory=lambda: str(uuid.uuid4())) - """Unique identifier for correlating requests and responses.""" - - source_executor_id: str | None = None - """ID of the executor expecting a response to this request. - May differ from the executor that sent the request if intercepted and forwarded.""" - - -TRequest = TypeVar("TRequest", bound="RequestInfoMessage") -TResponse = TypeVar("TResponse") - - -@dataclass -class RequestResponse(Generic[TRequest, TResponse]): - """Response type for request/response correlation in workflows. - - This type is used by RequestInfoExecutor to create correlated responses - that include the original request context for proper message routing. - """ - - data: TResponse - """The response data returned from handling the request.""" - - original_request: TRequest - """The original request that this response corresponds to.""" - - request_id: str - """The ID of the original request.""" - - -# endregion: Request/Response Types - - -# region Request Info Executor -class RequestInfoExecutor(Executor): - """Built-in executor that handles request/response patterns in workflows. - - This executor acts as a gateway for external information requests. When it receives - a request message, it saves the request details and emits a RequestInfoEvent. When - a response is provided externally, it emits the response as a message. - """ - - _PENDING_SHARED_STATE_KEY: ClassVar[str] = "_af_pending_request_info" - - def __init__(self, id: str): - """Initialize the RequestInfoExecutor with a unique ID. - - Args: - id: Unique ID for this RequestInfoExecutor. - """ - super().__init__(id=id) - self._request_events: dict[str, RequestInfoEvent] = {} - - @handler - async def run(self, message: RequestInfoMessage, ctx: WorkflowContext) -> None: - """Run the RequestInfoExecutor with the given message.""" - # Use source_executor_id from message if available, otherwise fall back to context - source_executor_id = message.source_executor_id or ctx.get_source_executor_id() - - event = RequestInfoEvent( - request_id=message.request_id, - source_executor_id=source_executor_id, - request_type=type(message), - request_data=message, - ) - self._request_events[message.request_id] = event - await self._record_pending_request_snapshot(message, source_executor_id, ctx) - await ctx.add_event(event) - - async def handle_response( - self, - response_data: Any, - request_id: str, - ctx: WorkflowContext[RequestResponse[RequestInfoMessage, Any]], - ) -> None: - """Handle a response to a request. - - Args: - request_id: The ID of the request to which this response corresponds. - response_data: The data returned in the response. - ctx: The workflow context for sending the response. - """ - event = self._request_events.get(request_id) - if event is None: - event = await self._rehydrate_request_event(request_id, ctx) - if event is None: - raise ValueError(f"No request found with ID: {request_id}") - - self._request_events.pop(request_id, None) - - # Create a correlated response that includes both the response data and original request - if not isinstance(event.data, RequestInfoMessage): - raise TypeError(f"Expected RequestInfoMessage, got {type(event.data)}") - correlated_response = RequestResponse(data=response_data, original_request=event.data, request_id=request_id) - await ctx.send_message(correlated_response, target_id=event.source_executor_id) - - await self._clear_pending_request_snapshot(request_id, ctx) - - async def _record_pending_request_snapshot( - self, - request: RequestInfoMessage, - source_executor_id: str, - ctx: WorkflowContext[Any], - ) -> None: - snapshot = self._build_request_snapshot(request, source_executor_id) - - pending = await self._load_pending_request_state(ctx) - pending[request.request_id] = snapshot - await self._persist_pending_request_state(pending, ctx) - await self._write_executor_state(ctx, pending) - - async def _clear_pending_request_snapshot(self, request_id: str, ctx: WorkflowContext[Any]) -> None: - pending = await self._load_pending_request_state(ctx) - if request_id in pending: - pending.pop(request_id, None) - await self._persist_pending_request_state(pending, ctx) - await self._write_executor_state(ctx, pending) - - async def _load_pending_request_state(self, ctx: WorkflowContext[Any]) -> dict[str, Any]: - try: - existing = await ctx.get_shared_state(self._PENDING_SHARED_STATE_KEY) - except KeyError: - return {} - except Exception as exc: # pragma: no cover - transport specific - logger.warning(f"RequestInfoExecutor {self.id} failed to read pending request state: {exc}") - return {} - - if not isinstance(existing, dict): - if existing not in (None, {}): - logger.warning( - f"RequestInfoExecutor {self.id} encountered non-dict pending state " - f"({type(existing).__name__}); resetting." - ) - return {} - - return dict(existing) # type: ignore[arg-type] - - async def _persist_pending_request_state(self, pending: dict[str, Any], ctx: WorkflowContext[Any]) -> None: - await self._safe_set_shared_state(ctx, pending) - await self._safe_set_runner_state(ctx, pending) - - async def _safe_set_shared_state(self, ctx: WorkflowContext[Any], pending: dict[str, Any]) -> None: - try: - await ctx.set_shared_state(self._PENDING_SHARED_STATE_KEY, pending) - except Exception as exc: # pragma: no cover - transport specific - logger.warning(f"RequestInfoExecutor {self.id} failed to update shared pending state: {exc}") - - async def _safe_set_runner_state(self, ctx: WorkflowContext[Any], pending: dict[str, Any]) -> None: - try: - await ctx.set_state({"pending_requests": pending}) - except Exception as exc: # pragma: no cover - transport specific - logger.warning(f"RequestInfoExecutor {self.id} failed to update runner state with pending requests: {exc}") - - def snapshot_state(self) -> dict[str, Any]: - """Serialize pending requests so checkpoint restoration can resume seamlessly.""" - - def _encode_event(event: RequestInfoEvent) -> dict[str, Any]: - request_data = event.data - payload: dict[str, Any] - data_cls = request_data.__class__ if request_data is not None else type(None) - - payload = self._encode_request_payload(request_data, data_cls) - - return { - "source_executor_id": event.source_executor_id, - "request_type": f"{event.request_type.__module__}:{event.request_type.__qualname__}", - "request_data": payload, - } - - return { - "request_events": {rid: _encode_event(event) for rid, event in self._request_events.items()}, - } - - def _encode_request_payload(self, request_data: RequestInfoMessage | None, data_cls: type[Any]) -> dict[str, Any]: - if request_data is None or isinstance(request_data, (str, int, float, bool)): - return { - "kind": "raw", - "type": f"{data_cls.__module__}:{data_cls.__qualname__}", - "value": request_data, - } - - if is_dataclass(request_data) and not isinstance(request_data, type): - dataclass_instance = cast(Any, request_data) - safe_value = self._make_json_safe(asdict(dataclass_instance)) - return { - "kind": "dataclass", - "type": f"{data_cls.__module__}:{data_cls.__qualname__}", - "value": safe_value, - } - - to_dict_fn = getattr(request_data, "to_dict", None) - if callable(to_dict_fn): - try: - dumped = to_dict_fn() - except TypeError: - dumped = to_dict_fn() - safe_value = self._make_json_safe(dumped) - return { - "kind": "dict", - "type": f"{data_cls.__module__}:{data_cls.__qualname__}", - "value": safe_value, - } - - to_json_fn = getattr(request_data, "to_json", None) - if callable(to_json_fn): - try: - dumped = to_json_fn() - except TypeError: - dumped = to_json_fn() - converted = dumped - if isinstance(dumped, (str, bytes, bytearray)): - decoded: str | bytes | bytearray - if isinstance(dumped, (bytes, bytearray)): - try: - decoded = dumped.decode() - except Exception: - decoded = dumped - else: - decoded = dumped - try: - converted = json.loads(decoded) - except Exception: - converted = decoded - safe_value = self._make_json_safe(converted) - return { - "kind": "dict" if isinstance(converted, dict) else "json", - "type": f"{data_cls.__module__}:{data_cls.__qualname__}", - "value": safe_value, - } - - details = self._serialise_request_details(request_data) - if details is not None: - safe_value = self._make_json_safe(details) - return { - "kind": "raw", - "type": f"{data_cls.__module__}:{data_cls.__qualname__}", - "value": safe_value, - } - - safe_value = self._make_json_safe(request_data) - return { - "kind": "raw", - "type": f"{data_cls.__module__}:{data_cls.__qualname__}", - "value": safe_value, - } - - def restore_state(self, state: dict[str, Any]) -> None: - """Restore pending request bookkeeping from checkpoint state.""" - self._request_events.clear() - stored_events = state.get("request_events", {}) - - for request_id, payload in stored_events.items(): - request_type_qual = payload.get("request_type", "") - try: - request_type = self._import_qualname(request_type_qual) - except Exception as exc: # pragma: no cover - defensive fallback - logger.debug( - "RequestInfoExecutor %s failed to import %s during restore: %s", - self.id, - request_type_qual, - exc, - ) - request_type = RequestInfoMessage - request_data_meta = payload.get("request_data", {}) - request_data = self._decode_request_data(request_data_meta) - event = RequestInfoEvent( - request_id=request_id, - source_executor_id=payload.get("source_executor_id", ""), - request_type=request_type, - request_data=request_data, - ) - self._request_events[request_id] = event - - @staticmethod - def _import_qualname(qualname: str) -> type[Any]: - module_name, _, type_name = qualname.partition(":") - if not module_name or not type_name: - raise ValueError(f"Invalid qualified name: {qualname}") - module = importlib.import_module(module_name) - attr: Any = module - for part in type_name.split("."): - attr = getattr(attr, part) - if not isinstance(attr, type): - raise TypeError(f"Resolved object is not a type: {qualname}") - return attr - - def _decode_request_data(self, metadata: dict[str, Any]) -> RequestInfoMessage: - kind = metadata.get("kind") - type_name = metadata.get("type", "") - value: Any = metadata.get("value", {}) - if type_name: - try: - imported = self._import_qualname(type_name) - except Exception as exc: # pragma: no cover - defensive fallback - logger.debug( - "RequestInfoExecutor %s failed to import %s during decode: %s", - self.id, - type_name, - exc, - ) - imported = RequestInfoMessage - else: - imported = RequestInfoMessage - target_cls: type[RequestInfoMessage] - if isinstance(imported, type) and issubclass(imported, RequestInfoMessage): - target_cls = imported - else: - target_cls = RequestInfoMessage - - if kind == "dataclass" and isinstance(value, dict): - with contextlib.suppress(TypeError): - return target_cls(**value) # type: ignore[arg-type] - - # Backwards-compat handling for checkpoints that used to store pydantic as "dict" - if kind in {"dict", "pydantic", "json"} and isinstance(value, dict): - from_dict = getattr(target_cls, "from_dict", None) - if callable(from_dict): - with contextlib.suppress(Exception): - return cast(RequestInfoMessage, from_dict(value)) - - if kind == "json" and isinstance(value, str): - from_json = getattr(target_cls, "from_json", None) - if callable(from_json): - with contextlib.suppress(Exception): - return cast(RequestInfoMessage, from_json(value)) - with contextlib.suppress(Exception): - parsed = json.loads(value) - if isinstance(parsed, dict): - return self._decode_request_data({"kind": "dict", "type": type_name, "value": parsed}) - - if isinstance(value, dict): - with contextlib.suppress(TypeError): - return target_cls(**value) # type: ignore[arg-type] - instance = object.__new__(target_cls) - instance.__dict__.update(value) # type: ignore[arg-type] - return instance - - with contextlib.suppress(Exception): - return target_cls() - return RequestInfoMessage() - - async def _write_executor_state(self, ctx: WorkflowContext[Any], pending: dict[str, Any]) -> None: - state = self.snapshot_state() - state["pending_requests"] = pending - try: - await ctx.set_state(state) - except Exception as exc: # pragma: no cover - transport specific - logger.warning(f"RequestInfoExecutor {self.id} failed to persist executor state: {exc}") - - def _build_request_snapshot( - self, - request: RequestInfoMessage, - source_executor_id: str, - ) -> dict[str, Any]: - snapshot: dict[str, Any] = { - "request_id": request.request_id, - "source_executor_id": source_executor_id, - "request_type": f"{type(request).__module__}:{type(request).__name__}", - "summary": repr(request), - } - - details = self._serialise_request_details(request) - if details: - snapshot["details"] = details - for key in ("prompt", "draft", "iteration"): - if key in details and key not in snapshot: - snapshot[key] = details[key] - - return snapshot - - def _serialise_request_details(self, request: RequestInfoMessage) -> dict[str, Any] | None: - if is_dataclass(request): - data = self._make_json_safe(asdict(request)) - if isinstance(data, dict): - return cast(dict[str, Any], data) - return None - - to_dict = getattr(request, "to_dict", None) - if callable(to_dict): - try: - dump = self._make_json_safe(to_dict()) - except TypeError: - dump = self._make_json_safe(to_dict()) - if isinstance(dump, dict): - return cast(dict[str, Any], dump) - return None - - to_json = getattr(request, "to_json", None) - if callable(to_json): - try: - raw = to_json() - except TypeError: - raw = to_json() - converted = raw - if isinstance(raw, (str, bytes, bytearray)): - decoded: str | bytes | bytearray - if isinstance(raw, (bytes, bytearray)): - try: - decoded = raw.decode() - except Exception: - decoded = raw - else: - decoded = raw - try: - converted = json.loads(decoded) - except Exception: - converted = decoded - dump = self._make_json_safe(converted) - if isinstance(dump, dict): - return cast(dict[str, Any], dump) - return None - - attrs = getattr(request, "__dict__", None) - if isinstance(attrs, dict): - cleaned = self._make_json_safe(attrs) - if isinstance(cleaned, dict): - return cast(dict[str, Any], cleaned) - - return None - - def _make_json_safe(self, value: Any) -> Any: - if value is None or isinstance(value, (str, int, float, bool)): - return value - if isinstance(value, Mapping): - safe_dict: dict[str, Any] = {} - for key, val in value.items(): # type: ignore[attr-defined] - safe_dict[str(key)] = self._make_json_safe(val) # type: ignore[arg-type] - return safe_dict - if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): - return [self._make_json_safe(item) for item in value] # type: ignore[misc] - return repr(value) - - async def has_pending_request(self, request_id: str, ctx: WorkflowContext[Any]) -> bool: - if request_id in self._request_events: - return True - snapshot = await self._get_pending_request_snapshot(request_id, ctx) - return snapshot is not None - - async def _rehydrate_request_event( - self, - request_id: str, - ctx: WorkflowContext[Any], - ) -> RequestInfoEvent | None: - snapshot = await self._get_pending_request_snapshot(request_id, ctx) - if snapshot is None: - return None - - source_executor_id = snapshot.get("source_executor_id") - if not isinstance(source_executor_id, str) or not source_executor_id: - return None - - request = self._construct_request_from_snapshot(snapshot) - if request is None: - return None - - event = RequestInfoEvent( - request_id=request_id, - source_executor_id=source_executor_id, - request_type=type(request), - request_data=request, - ) - self._request_events[request_id] = event - return event - - async def _get_pending_request_snapshot(self, request_id: str, ctx: WorkflowContext[Any]) -> dict[str, Any] | None: - pending = await self._collect_pending_request_snapshots(ctx) - snapshot = pending.get(request_id) - if snapshot is None: - return None - return snapshot - - async def _collect_pending_request_snapshots(self, ctx: WorkflowContext[Any]) -> dict[str, dict[str, Any]]: - combined: dict[str, dict[str, Any]] = {} - - try: - shared_pending = await ctx.get_shared_state(self._PENDING_SHARED_STATE_KEY) - except KeyError: - shared_pending = None - except Exception as exc: # pragma: no cover - transport specific - logger.warning(f"RequestInfoExecutor {self.id} failed to read shared pending state during rehydrate: {exc}") - shared_pending = None - - if isinstance(shared_pending, dict): - for key, value in shared_pending.items(): # type: ignore[attr-defined] - if isinstance(key, str) and isinstance(value, dict): - combined[key] = cast(dict[str, Any], value) - - try: - state = await ctx.get_state() - except Exception as exc: # pragma: no cover - transport specific - logger.warning(f"RequestInfoExecutor {self.id} failed to read runner state during rehydrate: {exc}") - state = None - - if isinstance(state, dict): - state_pending = state.get("pending_requests") - if isinstance(state_pending, dict): - for key, value in state_pending.items(): # type: ignore[attr-defined] - if isinstance(key, str) and isinstance(value, dict) and key not in combined: - combined[key] = cast(dict[str, Any], value) - - return combined - - def _construct_request_from_snapshot(self, snapshot: dict[str, Any]) -> RequestInfoMessage | None: - details_raw = snapshot.get("details") - details: dict[str, Any] = cast(dict[str, Any], details_raw) if isinstance(details_raw, dict) else {} - - request_cls: type[RequestInfoMessage] = RequestInfoMessage - request_type_str = snapshot.get("request_type") - if isinstance(request_type_str, str) and ":" in request_type_str: - module_name, class_name = request_type_str.split(":", 1) - try: - module = importlib.import_module(module_name) - candidate = getattr(module, class_name) - if isinstance(candidate, type) and issubclass(candidate, RequestInfoMessage): - request_cls = candidate - except Exception as exc: - logger.warning(f"RequestInfoExecutor {self.id} could not import {module_name}.{class_name}: {exc}") - request_cls = RequestInfoMessage - - request: RequestInfoMessage | None = self._instantiate_request(request_cls, details) - - if request is None and request_cls is not RequestInfoMessage: - request = self._instantiate_request(RequestInfoMessage, details) - - if request is None: - logger.warning( - f"RequestInfoExecutor {self.id} could not reconstruct request " - f"{request_type_str or RequestInfoMessage.__name__} from snapshot keys {sorted(details.keys())}" - ) - return None - - for key, value in details.items(): - if key == "request_id": - continue - try: - setattr(request, key, value) - except Exception as exc: - logger.debug( - f"RequestInfoExecutor {self.id} could not set attribute {key} on {type(request).__name__}: {exc}" - ) - continue - - snapshot_request_id = snapshot.get("request_id") - if isinstance(snapshot_request_id, str) and snapshot_request_id: - try: - request.request_id = snapshot_request_id - except Exception as exc: - logger.debug( - f"RequestInfoExecutor {self.id} could not apply snapshot " - f"request_id to {type(request).__name__}: {exc}" - ) - - return request - - def _instantiate_request( - self, - request_cls: type[RequestInfoMessage], - details: dict[str, Any], - ) -> RequestInfoMessage | None: - try: - from_dict = getattr(request_cls, "from_dict", None) - if callable(from_dict): - return cast(RequestInfoMessage, from_dict(details)) - except (TypeError, ValueError) as exc: - logger.debug(f"RequestInfoExecutor {self.id} failed to hydrate {request_cls.__name__} via from_dict: {exc}") - except Exception as exc: - logger.warning( - f"RequestInfoExecutor {self.id} encountered unexpected error during " - f"{request_cls.__name__}.from_dict: {exc}" - ) - - if is_dataclass(request_cls): - try: - field_names = {f.name for f in fields(request_cls)} - ctor_kwargs = {name: details[name] for name in field_names if name in details} - return request_cls(**ctor_kwargs) - except (TypeError, ValueError) as exc: - logger.debug( - f"RequestInfoExecutor {self.id} could not instantiate dataclass " - f"{request_cls.__name__} with snapshot data: {exc}" - ) - except Exception as exc: - logger.warning( - f"RequestInfoExecutor {self.id} encountered unexpected error " - f"constructing dataclass {request_cls.__name__}: {exc}" - ) - - try: - instance = request_cls() - except Exception as exc: - logger.warning( - f"RequestInfoExecutor {self.id} could not instantiate {request_cls.__name__} without arguments: {exc}" - ) - return None - - for key, value in details.items(): - if key == "request_id": - continue - try: - setattr(instance, key, value) - except Exception as exc: - logger.debug( - f"RequestInfoExecutor {self.id} could not set attribute {key} on " - f"{request_cls.__name__} during instantiation: {exc}" - ) - continue - - return instance - - @staticmethod - def pending_requests_from_checkpoint( - checkpoint: WorkflowCheckpoint, - *, - request_executor_ids: Iterable[str] | None = None, - ) -> list[PendingRequestDetails]: - executor_filter: set[str] | None = None - if request_executor_ids is not None: - executor_filter = {str(value) for value in request_executor_ids} - - pending: dict[str, PendingRequestDetails] = {} - - shared_map = checkpoint.shared_state.get(RequestInfoExecutor._PENDING_SHARED_STATE_KEY) - if isinstance(shared_map, Mapping): - for request_id, snapshot in shared_map.items(): # type: ignore[attr-defined] - RequestInfoExecutor._merge_snapshot(pending, str(request_id), snapshot) # type: ignore[arg-type] - - for state in checkpoint.executor_states.values(): - if not isinstance(state, Mapping): - continue - inner = state.get("pending_requests") - if isinstance(inner, Mapping): - for request_id, snapshot in inner.items(): # type: ignore[attr-defined] - RequestInfoExecutor._merge_snapshot(pending, str(request_id), snapshot) # type: ignore[arg-type] - - for source_id, message_list in checkpoint.messages.items(): - if executor_filter is not None and source_id not in executor_filter: - continue - if not isinstance(message_list, list): - continue - for message in message_list: - if not isinstance(message, Mapping): - continue - payload = _decode_checkpoint_value(message.get("data")) - RequestInfoExecutor._merge_message_payload(pending, payload, message) - - return list(pending.values()) - - @staticmethod - def checkpoint_summary( - checkpoint: WorkflowCheckpoint, - *, - request_executor_ids: Iterable[str] | None = None, - preview_width: int = 70, - ) -> WorkflowCheckpointSummary: - targets = sorted(checkpoint.messages.keys()) - executor_states = sorted(checkpoint.executor_states.keys()) - pending = RequestInfoExecutor.pending_requests_from_checkpoint( - checkpoint, request_executor_ids=request_executor_ids - ) - - draft_preview: str | None = None - for entry in pending: - if entry.draft: - draft_preview = shorten(entry.draft, width=preview_width, placeholder="…") - break - - status = "idle" - if pending: - status = "awaiting human response" - elif not checkpoint.messages and "finalise" in executor_states: - status = "completed" - elif checkpoint.messages: - status = "awaiting next superstep" - elif request_executor_ids is not None and any(tid in targets for tid in request_executor_ids): - status = "awaiting request delivery" - - return WorkflowCheckpointSummary( - checkpoint_id=checkpoint.checkpoint_id, - iteration_count=checkpoint.iteration_count, - targets=targets, - executor_states=executor_states, - status=status, - draft_preview=draft_preview, - pending_requests=pending, - ) - - @staticmethod - def _merge_snapshot( - pending: dict[str, PendingRequestDetails], - request_id: str, - snapshot: Any, - ) -> None: - if not request_id or not isinstance(snapshot, Mapping): - return - - details = pending.setdefault(request_id, PendingRequestDetails(request_id=request_id)) - - RequestInfoExecutor._apply_update( - details, - prompt=snapshot.get("prompt"), # type: ignore[attr-defined] - draft=snapshot.get("draft"), # type: ignore[attr-defined] - iteration=snapshot.get("iteration"), # type: ignore[attr-defined] - source_executor_id=snapshot.get("source_executor_id"), # type: ignore[attr-defined] - ) - - extra = snapshot.get("details") # type: ignore[attr-defined] - if isinstance(extra, Mapping): - RequestInfoExecutor._apply_update( - details, - prompt=extra.get("prompt"), # type: ignore[attr-defined] - draft=extra.get("draft"), # type: ignore[attr-defined] - iteration=extra.get("iteration"), # type: ignore[attr-defined] - ) - - @staticmethod - def _merge_message_payload( - pending: dict[str, PendingRequestDetails], - payload: Any, - raw_message: Mapping[str, Any], - ) -> None: - if isinstance(payload, RequestResponse): - request_id = payload.request_id or RequestInfoExecutor._get_field(payload.original_request, "request_id") # type: ignore[arg-type] - if not request_id: - return - details = pending.setdefault(request_id, PendingRequestDetails(request_id=request_id)) - RequestInfoExecutor._apply_update( - details, - prompt=RequestInfoExecutor._get_field(payload.original_request, "prompt"), # type: ignore[arg-type] - draft=RequestInfoExecutor._get_field(payload.original_request, "draft"), # type: ignore[arg-type] - iteration=RequestInfoExecutor._get_field(payload.original_request, "iteration"), # type: ignore[arg-type] - source_executor_id=raw_message.get("source_id"), - original_request=payload.original_request, # type: ignore[arg-type] - ) - elif isinstance(payload, RequestInfoMessage): - request_id = getattr(payload, "request_id", None) - if not request_id: - return - details = pending.setdefault(request_id, PendingRequestDetails(request_id=request_id)) - RequestInfoExecutor._apply_update( - details, - prompt=getattr(payload, "prompt", None), - draft=getattr(payload, "draft", None), - iteration=getattr(payload, "iteration", None), - source_executor_id=raw_message.get("source_id"), - original_request=payload, - ) - - @staticmethod - def _apply_update( - details: PendingRequestDetails, - *, - prompt: Any = None, - draft: Any = None, - iteration: Any = None, - source_executor_id: Any = None, - original_request: Any = None, - ) -> None: - if prompt and not details.prompt: - details.prompt = str(prompt) - if draft and not details.draft: - details.draft = str(draft) - if iteration is not None and details.iteration is None: - coerced = RequestInfoExecutor._coerce_int(iteration) - if coerced is not None: - details.iteration = coerced - if source_executor_id and not details.source_executor_id: - details.source_executor_id = str(source_executor_id) - if original_request is not None and details.original_request is None: - details.original_request = original_request - - @staticmethod - def _get_field(obj: Any, key: str) -> Any: - if obj is None: - return None - if isinstance(obj, Mapping): - return obj.get(key) # type: ignore[attr-defined,return-value] - return getattr(obj, key, None) - - @staticmethod - def _coerce_int(value: Any) -> int | None: - try: - return int(value) - except (TypeError, ValueError): - return None - - -# endregion: Request Info Executor diff --git a/python/packages/core/agent_framework/_workflows/_magentic.py b/python/packages/core/agent_framework/_workflows/_magentic.py index 0f61f1ca7e..457116c65c 100644 --- a/python/packages/core/agent_framework/_workflows/_magentic.py +++ b/python/packages/core/agent_framework/_workflows/_magentic.py @@ -27,8 +27,9 @@ from agent_framework._agents import BaseAgent from ._checkpoint import CheckpointStorage, WorkflowCheckpoint from ._events import WorkflowEvent -from ._executor import Executor, RequestInfoMessage, RequestResponse, handler +from ._executor import Executor, handler from ._model_utils import DictConvertible, encode_value +from ._request_info_executor import RequestInfoMessage, RequestResponse from ._workflow import Workflow, WorkflowBuilder, WorkflowRunResult from ._workflow_context import WorkflowContext @@ -1939,7 +1940,7 @@ class MagenticBuilder: workflow_builder = WorkflowBuilder().set_start_executor(orchestrator_executor) if self._enable_plan_review: - from ._executor import RequestInfoExecutor + from ._request_info_executor import RequestInfoExecutor request_info = RequestInfoExecutor(id="magentic_plan_review") workflow_builder = ( diff --git a/python/packages/core/agent_framework/_workflows/_request_info_executor.py b/python/packages/core/agent_framework/_workflows/_request_info_executor.py new file mode 100644 index 0000000000..ef6eb3fc22 --- /dev/null +++ b/python/packages/core/agent_framework/_workflows/_request_info_executor.py @@ -0,0 +1,841 @@ +# Copyright (c) Microsoft. All rights reserved. + +import contextlib +import importlib +import json +import logging +import uuid +from collections.abc import Iterable, Mapping, Sequence +from dataclasses import asdict, dataclass, field, fields, is_dataclass +from textwrap import shorten +from typing import Any, ClassVar, Generic, TypeVar, cast + +from ._checkpoint import WorkflowCheckpoint +from ._events import ( + RequestInfoEvent, # type: ignore[reportPrivateUsage] +) +from ._executor import Executor, handler +from ._runner_context import _decode_checkpoint_value # type: ignore +from ._workflow_context import WorkflowContext + +logger = logging.getLogger(__name__) + + +@dataclass +class PendingRequestDetails: + """Lightweight information about a pending request captured in a checkpoint.""" + + request_id: str + prompt: str | None = None + draft: str | None = None + iteration: int | None = None + source_executor_id: str | None = None + original_request: "RequestInfoMessage | dict[str, Any] | None" = None + + +@dataclass +class WorkflowCheckpointSummary: + """Human-readable summary of a workflow checkpoint.""" + + checkpoint_id: str + iteration_count: int + targets: list[str] + executor_states: list[str] + status: str + draft_preview: str | None + pending_requests: list[PendingRequestDetails] + + +@dataclass +class RequestInfoMessage: + """Base class for all request messages in workflows. + + Any message that should be routed to the RequestInfoExecutor for external + handling must inherit from this class. This ensures type safety and makes + the request/response pattern explicit. + """ + + request_id: str = field(default_factory=lambda: str(uuid.uuid4())) + """Unique identifier for correlating requests and responses.""" + + source_executor_id: str | None = None + """ID of the executor expecting a response to this request. + May differ from the executor that sent the request if intercepted and forwarded.""" + + +TRequest = TypeVar("TRequest", bound="RequestInfoMessage") +TResponse = TypeVar("TResponse") + + +@dataclass +class RequestResponse(Generic[TRequest, TResponse]): + """Response type for request/response correlation in workflows. + + This type is used by RequestInfoExecutor to create correlated responses + that include the original request context for proper message routing. + """ + + data: TResponse + """The response data returned from handling the request.""" + + original_request: TRequest + """The original request that this response corresponds to.""" + + request_id: str + """The ID of the original request.""" + + +# endregion: Request/Response Types + + +# region Request Info Executor +class RequestInfoExecutor(Executor): + """Built-in executor that handles request/response patterns in workflows. + + This executor acts as a gateway for external information requests. When it receives + a request message, it saves the request details and emits a RequestInfoEvent. When + a response is provided externally, it emits the response as a message. + """ + + _PENDING_SHARED_STATE_KEY: ClassVar[str] = "_af_pending_request_info" + + def __init__(self, id: str): + """Initialize the RequestInfoExecutor with a unique ID. + + Args: + id: Unique ID for this RequestInfoExecutor. + """ + super().__init__(id=id) + self._request_events: dict[str, RequestInfoEvent] = {} + + @handler + async def run(self, message: RequestInfoMessage, ctx: WorkflowContext) -> None: + """Run the RequestInfoExecutor with the given message.""" + # Use source_executor_id from message if available, otherwise fall back to context + source_executor_id = message.source_executor_id or ctx.get_source_executor_id() + + event = RequestInfoEvent( + request_id=message.request_id, + source_executor_id=source_executor_id, + request_type=type(message), + request_data=message, + ) + self._request_events[message.request_id] = event + await self._record_pending_request_snapshot(message, source_executor_id, ctx) + await ctx.add_event(event) + + async def handle_response( + self, + response_data: Any, + request_id: str, + ctx: WorkflowContext[RequestResponse[RequestInfoMessage, Any]], + ) -> None: + """Handle a response to a request. + + Args: + request_id: The ID of the request to which this response corresponds. + response_data: The data returned in the response. + ctx: The workflow context for sending the response. + """ + event = self._request_events.get(request_id) + if event is None: + event = await self._rehydrate_request_event(request_id, ctx) + if event is None: + raise ValueError(f"No request found with ID: {request_id}") + + self._request_events.pop(request_id, None) + + # Create a correlated response that includes both the response data and original request + if not isinstance(event.data, RequestInfoMessage): + raise TypeError(f"Expected RequestInfoMessage, got {type(event.data)}") + correlated_response = RequestResponse(data=response_data, original_request=event.data, request_id=request_id) + await ctx.send_message(correlated_response, target_id=event.source_executor_id) + + await self._clear_pending_request_snapshot(request_id, ctx) + + async def _record_pending_request_snapshot( + self, + request: RequestInfoMessage, + source_executor_id: str, + ctx: WorkflowContext[Any], + ) -> None: + snapshot = self._build_request_snapshot(request, source_executor_id) + + pending = await self._load_pending_request_state(ctx) + pending[request.request_id] = snapshot + await self._persist_pending_request_state(pending, ctx) + await self._write_executor_state(ctx, pending) + + async def _clear_pending_request_snapshot(self, request_id: str, ctx: WorkflowContext[Any]) -> None: + pending = await self._load_pending_request_state(ctx) + if request_id in pending: + pending.pop(request_id, None) + await self._persist_pending_request_state(pending, ctx) + await self._write_executor_state(ctx, pending) + + async def _load_pending_request_state(self, ctx: WorkflowContext[Any]) -> dict[str, Any]: + try: + existing = await ctx.get_shared_state(self._PENDING_SHARED_STATE_KEY) + except KeyError: + return {} + except Exception as exc: # pragma: no cover - transport specific + logger.warning(f"RequestInfoExecutor {self.id} failed to read pending request state: {exc}") + return {} + + if not isinstance(existing, dict): + if existing not in (None, {}): + logger.warning( + f"RequestInfoExecutor {self.id} encountered non-dict pending state " + f"({type(existing).__name__}); resetting." + ) + return {} + + return dict(existing) # type: ignore[arg-type] + + async def _persist_pending_request_state(self, pending: dict[str, Any], ctx: WorkflowContext[Any]) -> None: + await self._safe_set_shared_state(ctx, pending) + await self._safe_set_runner_state(ctx, pending) + + async def _safe_set_shared_state(self, ctx: WorkflowContext[Any], pending: dict[str, Any]) -> None: + try: + await ctx.set_shared_state(self._PENDING_SHARED_STATE_KEY, pending) + except Exception as exc: # pragma: no cover - transport specific + logger.warning(f"RequestInfoExecutor {self.id} failed to update shared pending state: {exc}") + + async def _safe_set_runner_state(self, ctx: WorkflowContext[Any], pending: dict[str, Any]) -> None: + try: + await ctx.set_state({"pending_requests": pending}) + except Exception as exc: # pragma: no cover - transport specific + logger.warning(f"RequestInfoExecutor {self.id} failed to update runner state with pending requests: {exc}") + + def snapshot_state(self) -> dict[str, Any]: + """Serialize pending requests so checkpoint restoration can resume seamlessly.""" + + def _encode_event(event: RequestInfoEvent) -> dict[str, Any]: + request_data = event.data + payload: dict[str, Any] + data_cls = request_data.__class__ if request_data is not None else type(None) + + payload = self._encode_request_payload(request_data, data_cls) + + return { + "source_executor_id": event.source_executor_id, + "request_type": f"{event.request_type.__module__}:{event.request_type.__qualname__}", + "request_data": payload, + } + + return { + "request_events": {rid: _encode_event(event) for rid, event in self._request_events.items()}, + } + + def _encode_request_payload(self, request_data: RequestInfoMessage | None, data_cls: type[Any]) -> dict[str, Any]: + if request_data is None or isinstance(request_data, (str, int, float, bool)): + return { + "kind": "raw", + "type": f"{data_cls.__module__}:{data_cls.__qualname__}", + "value": request_data, + } + + if is_dataclass(request_data) and not isinstance(request_data, type): + dataclass_instance = cast(Any, request_data) + safe_value = self._make_json_safe(asdict(dataclass_instance)) + return { + "kind": "dataclass", + "type": f"{data_cls.__module__}:{data_cls.__qualname__}", + "value": safe_value, + } + + to_dict_fn = getattr(request_data, "to_dict", None) + if callable(to_dict_fn): + try: + dumped = to_dict_fn() + except TypeError: + dumped = to_dict_fn() + safe_value = self._make_json_safe(dumped) + return { + "kind": "dict", + "type": f"{data_cls.__module__}:{data_cls.__qualname__}", + "value": safe_value, + } + + to_json_fn = getattr(request_data, "to_json", None) + if callable(to_json_fn): + try: + dumped = to_json_fn() + except TypeError: + dumped = to_json_fn() + converted = dumped + if isinstance(dumped, (str, bytes, bytearray)): + decoded: str | bytes | bytearray + if isinstance(dumped, (bytes, bytearray)): + try: + decoded = dumped.decode() + except Exception: + decoded = dumped + else: + decoded = dumped + try: + converted = json.loads(decoded) + except Exception: + converted = decoded + safe_value = self._make_json_safe(converted) + return { + "kind": "dict" if isinstance(converted, dict) else "json", + "type": f"{data_cls.__module__}:{data_cls.__qualname__}", + "value": safe_value, + } + + details = self._serialise_request_details(request_data) + if details is not None: + safe_value = self._make_json_safe(details) + return { + "kind": "raw", + "type": f"{data_cls.__module__}:{data_cls.__qualname__}", + "value": safe_value, + } + + safe_value = self._make_json_safe(request_data) + return { + "kind": "raw", + "type": f"{data_cls.__module__}:{data_cls.__qualname__}", + "value": safe_value, + } + + def restore_state(self, state: dict[str, Any]) -> None: + """Restore pending request bookkeeping from checkpoint state.""" + self._request_events.clear() + stored_events = state.get("request_events", {}) + + for request_id, payload in stored_events.items(): + request_type_qual = payload.get("request_type", "") + try: + request_type = self._import_qualname(request_type_qual) + except Exception as exc: # pragma: no cover - defensive fallback + logger.debug( + "RequestInfoExecutor %s failed to import %s during restore: %s", + self.id, + request_type_qual, + exc, + ) + request_type = RequestInfoMessage + request_data_meta = payload.get("request_data", {}) + request_data = self._decode_request_data(request_data_meta) + event = RequestInfoEvent( + request_id=request_id, + source_executor_id=payload.get("source_executor_id", ""), + request_type=request_type, + request_data=request_data, + ) + self._request_events[request_id] = event + + @staticmethod + def _import_qualname(qualname: str) -> type[Any]: + module_name, _, type_name = qualname.partition(":") + if not module_name or not type_name: + raise ValueError(f"Invalid qualified name: {qualname}") + module = importlib.import_module(module_name) + attr: Any = module + for part in type_name.split("."): + attr = getattr(attr, part) + if not isinstance(attr, type): + raise TypeError(f"Resolved object is not a type: {qualname}") + return attr + + def _decode_request_data(self, metadata: dict[str, Any]) -> RequestInfoMessage: + kind = metadata.get("kind") + type_name = metadata.get("type", "") + value: Any = metadata.get("value", {}) + if type_name: + try: + imported = self._import_qualname(type_name) + except Exception as exc: # pragma: no cover - defensive fallback + logger.debug( + "RequestInfoExecutor %s failed to import %s during decode: %s", + self.id, + type_name, + exc, + ) + imported = RequestInfoMessage + else: + imported = RequestInfoMessage + target_cls: type[RequestInfoMessage] + if isinstance(imported, type) and issubclass(imported, RequestInfoMessage): + target_cls = imported + else: + target_cls = RequestInfoMessage + + if kind == "dataclass" and isinstance(value, dict): + with contextlib.suppress(TypeError): + return target_cls(**value) # type: ignore[arg-type] + + # Backwards-compat handling for checkpoints that used to store pydantic as "dict" + if kind in {"dict", "pydantic", "json"} and isinstance(value, dict): + from_dict = getattr(target_cls, "from_dict", None) + if callable(from_dict): + with contextlib.suppress(Exception): + return cast(RequestInfoMessage, from_dict(value)) + + if kind == "json" and isinstance(value, str): + from_json = getattr(target_cls, "from_json", None) + if callable(from_json): + with contextlib.suppress(Exception): + return cast(RequestInfoMessage, from_json(value)) + with contextlib.suppress(Exception): + parsed = json.loads(value) + if isinstance(parsed, dict): + return self._decode_request_data({"kind": "dict", "type": type_name, "value": parsed}) + + if isinstance(value, dict): + with contextlib.suppress(TypeError): + return target_cls(**value) # type: ignore[arg-type] + instance = object.__new__(target_cls) + instance.__dict__.update(value) # type: ignore[arg-type] + return instance + + with contextlib.suppress(Exception): + return target_cls() + return RequestInfoMessage() + + async def _write_executor_state(self, ctx: WorkflowContext[Any], pending: dict[str, Any]) -> None: + state = self.snapshot_state() + state["pending_requests"] = pending + try: + await ctx.set_state(state) + except Exception as exc: # pragma: no cover - transport specific + logger.warning(f"RequestInfoExecutor {self.id} failed to persist executor state: {exc}") + + def _build_request_snapshot( + self, + request: RequestInfoMessage, + source_executor_id: str, + ) -> dict[str, Any]: + snapshot: dict[str, Any] = { + "request_id": request.request_id, + "source_executor_id": source_executor_id, + "request_type": f"{type(request).__module__}:{type(request).__name__}", + "summary": repr(request), + } + + details = self._serialise_request_details(request) + if details: + snapshot["details"] = details + for key in ("prompt", "draft", "iteration"): + if key in details and key not in snapshot: + snapshot[key] = details[key] + + return snapshot + + def _serialise_request_details(self, request: RequestInfoMessage) -> dict[str, Any] | None: + if is_dataclass(request): + data = self._make_json_safe(asdict(request)) + if isinstance(data, dict): + return cast(dict[str, Any], data) + return None + + to_dict = getattr(request, "to_dict", None) + if callable(to_dict): + try: + dump = self._make_json_safe(to_dict()) + except TypeError: + dump = self._make_json_safe(to_dict()) + if isinstance(dump, dict): + return cast(dict[str, Any], dump) + return None + + to_json = getattr(request, "to_json", None) + if callable(to_json): + try: + raw = to_json() + except TypeError: + raw = to_json() + converted = raw + if isinstance(raw, (str, bytes, bytearray)): + decoded: str | bytes | bytearray + if isinstance(raw, (bytes, bytearray)): + try: + decoded = raw.decode() + except Exception: + decoded = raw + else: + decoded = raw + try: + converted = json.loads(decoded) + except Exception: + converted = decoded + dump = self._make_json_safe(converted) + if isinstance(dump, dict): + return cast(dict[str, Any], dump) + return None + + attrs = getattr(request, "__dict__", None) + if isinstance(attrs, dict): + cleaned = self._make_json_safe(attrs) + if isinstance(cleaned, dict): + return cast(dict[str, Any], cleaned) + + return None + + def _make_json_safe(self, value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, Mapping): + safe_dict: dict[str, Any] = {} + for key, val in value.items(): # type: ignore[attr-defined] + safe_dict[str(key)] = self._make_json_safe(val) # type: ignore[arg-type] + return safe_dict + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + return [self._make_json_safe(item) for item in value] # type: ignore[misc] + return repr(value) + + async def has_pending_request(self, request_id: str, ctx: WorkflowContext[Any]) -> bool: + if request_id in self._request_events: + return True + snapshot = await self._get_pending_request_snapshot(request_id, ctx) + return snapshot is not None + + async def _rehydrate_request_event( + self, + request_id: str, + ctx: WorkflowContext[Any], + ) -> RequestInfoEvent | None: + snapshot = await self._get_pending_request_snapshot(request_id, ctx) + if snapshot is None: + return None + + source_executor_id = snapshot.get("source_executor_id") + if not isinstance(source_executor_id, str) or not source_executor_id: + return None + + request = self._construct_request_from_snapshot(snapshot) + if request is None: + return None + + event = RequestInfoEvent( + request_id=request_id, + source_executor_id=source_executor_id, + request_type=type(request), + request_data=request, + ) + self._request_events[request_id] = event + return event + + async def _get_pending_request_snapshot(self, request_id: str, ctx: WorkflowContext[Any]) -> dict[str, Any] | None: + pending = await self._collect_pending_request_snapshots(ctx) + snapshot = pending.get(request_id) + if snapshot is None: + return None + return snapshot + + async def _collect_pending_request_snapshots(self, ctx: WorkflowContext[Any]) -> dict[str, dict[str, Any]]: + combined: dict[str, dict[str, Any]] = {} + + try: + shared_pending = await ctx.get_shared_state(self._PENDING_SHARED_STATE_KEY) + except KeyError: + shared_pending = None + except Exception as exc: # pragma: no cover - transport specific + logger.warning(f"RequestInfoExecutor {self.id} failed to read shared pending state during rehydrate: {exc}") + shared_pending = None + + if isinstance(shared_pending, dict): + for key, value in shared_pending.items(): # type: ignore[attr-defined] + if isinstance(key, str) and isinstance(value, dict): + combined[key] = cast(dict[str, Any], value) + + try: + state = await ctx.get_state() + except Exception as exc: # pragma: no cover - transport specific + logger.warning(f"RequestInfoExecutor {self.id} failed to read runner state during rehydrate: {exc}") + state = None + + if isinstance(state, dict): + state_pending = state.get("pending_requests") + if isinstance(state_pending, dict): + for key, value in state_pending.items(): # type: ignore[attr-defined] + if isinstance(key, str) and isinstance(value, dict) and key not in combined: + combined[key] = cast(dict[str, Any], value) + + return combined + + def _construct_request_from_snapshot(self, snapshot: dict[str, Any]) -> RequestInfoMessage | None: + details_raw = snapshot.get("details") + details: dict[str, Any] = cast(dict[str, Any], details_raw) if isinstance(details_raw, dict) else {} + + request_cls: type[RequestInfoMessage] = RequestInfoMessage + request_type_str = snapshot.get("request_type") + if isinstance(request_type_str, str) and ":" in request_type_str: + module_name, class_name = request_type_str.split(":", 1) + try: + module = importlib.import_module(module_name) + candidate = getattr(module, class_name) + if isinstance(candidate, type) and issubclass(candidate, RequestInfoMessage): + request_cls = candidate + except Exception as exc: + logger.warning(f"RequestInfoExecutor {self.id} could not import {module_name}.{class_name}: {exc}") + request_cls = RequestInfoMessage + + request: RequestInfoMessage | None = self._instantiate_request(request_cls, details) + + if request is None and request_cls is not RequestInfoMessage: + request = self._instantiate_request(RequestInfoMessage, details) + + if request is None: + logger.warning( + f"RequestInfoExecutor {self.id} could not reconstruct request " + f"{request_type_str or RequestInfoMessage.__name__} from snapshot keys {sorted(details.keys())}" + ) + return None + + for key, value in details.items(): + if key == "request_id": + continue + try: + setattr(request, key, value) + except Exception as exc: + logger.debug( + f"RequestInfoExecutor {self.id} could not set attribute {key} on {type(request).__name__}: {exc}" + ) + continue + + snapshot_request_id = snapshot.get("request_id") + if isinstance(snapshot_request_id, str) and snapshot_request_id: + try: + request.request_id = snapshot_request_id + except Exception as exc: + logger.debug( + f"RequestInfoExecutor {self.id} could not apply snapshot " + f"request_id to {type(request).__name__}: {exc}" + ) + + return request + + def _instantiate_request( + self, + request_cls: type[RequestInfoMessage], + details: dict[str, Any], + ) -> RequestInfoMessage | None: + try: + from_dict = getattr(request_cls, "from_dict", None) + if callable(from_dict): + return cast(RequestInfoMessage, from_dict(details)) + except (TypeError, ValueError) as exc: + logger.debug(f"RequestInfoExecutor {self.id} failed to hydrate {request_cls.__name__} via from_dict: {exc}") + except Exception as exc: + logger.warning( + f"RequestInfoExecutor {self.id} encountered unexpected error during " + f"{request_cls.__name__}.from_dict: {exc}" + ) + + if is_dataclass(request_cls): + try: + field_names = {f.name for f in fields(request_cls)} + ctor_kwargs = {name: details[name] for name in field_names if name in details} + return request_cls(**ctor_kwargs) + except (TypeError, ValueError) as exc: + logger.debug( + f"RequestInfoExecutor {self.id} could not instantiate dataclass " + f"{request_cls.__name__} with snapshot data: {exc}" + ) + except Exception as exc: + logger.warning( + f"RequestInfoExecutor {self.id} encountered unexpected error " + f"constructing dataclass {request_cls.__name__}: {exc}" + ) + + try: + instance = request_cls() + except Exception as exc: + logger.warning( + f"RequestInfoExecutor {self.id} could not instantiate {request_cls.__name__} without arguments: {exc}" + ) + return None + + for key, value in details.items(): + if key == "request_id": + continue + try: + setattr(instance, key, value) + except Exception as exc: + logger.debug( + f"RequestInfoExecutor {self.id} could not set attribute {key} on " + f"{request_cls.__name__} during instantiation: {exc}" + ) + continue + + return instance + + @staticmethod + def pending_requests_from_checkpoint( + checkpoint: WorkflowCheckpoint, + *, + request_executor_ids: Iterable[str] | None = None, + ) -> list[PendingRequestDetails]: + executor_filter: set[str] | None = None + if request_executor_ids is not None: + executor_filter = {str(value) for value in request_executor_ids} + + pending: dict[str, PendingRequestDetails] = {} + + shared_map = checkpoint.shared_state.get(RequestInfoExecutor._PENDING_SHARED_STATE_KEY) + if isinstance(shared_map, Mapping): + for request_id, snapshot in shared_map.items(): # type: ignore[attr-defined] + RequestInfoExecutor._merge_snapshot(pending, str(request_id), snapshot) # type: ignore[arg-type] + + for state in checkpoint.executor_states.values(): + if not isinstance(state, Mapping): + continue + inner = state.get("pending_requests") + if isinstance(inner, Mapping): + for request_id, snapshot in inner.items(): # type: ignore[attr-defined] + RequestInfoExecutor._merge_snapshot(pending, str(request_id), snapshot) # type: ignore[arg-type] + + for source_id, message_list in checkpoint.messages.items(): + if executor_filter is not None and source_id not in executor_filter: + continue + if not isinstance(message_list, list): + continue + for message in message_list: + if not isinstance(message, Mapping): + continue + payload = _decode_checkpoint_value(message.get("data")) + RequestInfoExecutor._merge_message_payload(pending, payload, message) + + return list(pending.values()) + + @staticmethod + def checkpoint_summary( + checkpoint: WorkflowCheckpoint, + *, + request_executor_ids: Iterable[str] | None = None, + preview_width: int = 70, + ) -> WorkflowCheckpointSummary: + targets = sorted(checkpoint.messages.keys()) + executor_states = sorted(checkpoint.executor_states.keys()) + pending = RequestInfoExecutor.pending_requests_from_checkpoint( + checkpoint, request_executor_ids=request_executor_ids + ) + + draft_preview: str | None = None + for entry in pending: + if entry.draft: + draft_preview = shorten(entry.draft, width=preview_width, placeholder="…") + break + + status = "idle" + if pending: + status = "awaiting human response" + elif not checkpoint.messages and "finalise" in executor_states: + status = "completed" + elif checkpoint.messages: + status = "awaiting next superstep" + elif request_executor_ids is not None and any(tid in targets for tid in request_executor_ids): + status = "awaiting request delivery" + + return WorkflowCheckpointSummary( + checkpoint_id=checkpoint.checkpoint_id, + iteration_count=checkpoint.iteration_count, + targets=targets, + executor_states=executor_states, + status=status, + draft_preview=draft_preview, + pending_requests=pending, + ) + + @staticmethod + def _merge_snapshot( + pending: dict[str, PendingRequestDetails], + request_id: str, + snapshot: Any, + ) -> None: + if not request_id or not isinstance(snapshot, Mapping): + return + + details = pending.setdefault(request_id, PendingRequestDetails(request_id=request_id)) + + RequestInfoExecutor._apply_update( + details, + prompt=snapshot.get("prompt"), # type: ignore[attr-defined] + draft=snapshot.get("draft"), # type: ignore[attr-defined] + iteration=snapshot.get("iteration"), # type: ignore[attr-defined] + source_executor_id=snapshot.get("source_executor_id"), # type: ignore[attr-defined] + ) + + extra = snapshot.get("details") # type: ignore[attr-defined] + if isinstance(extra, Mapping): + RequestInfoExecutor._apply_update( + details, + prompt=extra.get("prompt"), # type: ignore[attr-defined] + draft=extra.get("draft"), # type: ignore[attr-defined] + iteration=extra.get("iteration"), # type: ignore[attr-defined] + ) + + @staticmethod + def _merge_message_payload( + pending: dict[str, PendingRequestDetails], + payload: Any, + raw_message: Mapping[str, Any], + ) -> None: + if isinstance(payload, RequestResponse): + request_id = payload.request_id or RequestInfoExecutor._get_field(payload.original_request, "request_id") # type: ignore[arg-type] + if not request_id: + return + details = pending.setdefault(request_id, PendingRequestDetails(request_id=request_id)) + RequestInfoExecutor._apply_update( + details, + prompt=RequestInfoExecutor._get_field(payload.original_request, "prompt"), # type: ignore[arg-type] + draft=RequestInfoExecutor._get_field(payload.original_request, "draft"), # type: ignore[arg-type] + iteration=RequestInfoExecutor._get_field(payload.original_request, "iteration"), # type: ignore[arg-type] + source_executor_id=raw_message.get("source_id"), + original_request=payload.original_request, # type: ignore[arg-type] + ) + elif isinstance(payload, RequestInfoMessage): + request_id = getattr(payload, "request_id", None) + if not request_id: + return + details = pending.setdefault(request_id, PendingRequestDetails(request_id=request_id)) + RequestInfoExecutor._apply_update( + details, + prompt=getattr(payload, "prompt", None), + draft=getattr(payload, "draft", None), + iteration=getattr(payload, "iteration", None), + source_executor_id=raw_message.get("source_id"), + original_request=payload, + ) + + @staticmethod + def _apply_update( + details: PendingRequestDetails, + *, + prompt: Any = None, + draft: Any = None, + iteration: Any = None, + source_executor_id: Any = None, + original_request: Any = None, + ) -> None: + if prompt and not details.prompt: + details.prompt = str(prompt) + if draft and not details.draft: + details.draft = str(draft) + if iteration is not None and details.iteration is None: + coerced = RequestInfoExecutor._coerce_int(iteration) + if coerced is not None: + details.iteration = coerced + if source_executor_id and not details.source_executor_id: + details.source_executor_id = str(source_executor_id) + if original_request is not None and details.original_request is None: + details.original_request = original_request + + @staticmethod + def _get_field(obj: Any, key: str) -> Any: + if obj is None: + return None + if isinstance(obj, Mapping): + return obj.get(key) # type: ignore[attr-defined,return-value] + return getattr(obj, key, None) + + @staticmethod + def _coerce_int(value: Any) -> int | None: + try: + return int(value) + except (TypeError, ValueError): + return None diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index a8f845637e..11a0499acf 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -22,7 +22,7 @@ from ._runner_context import ( from ._shared_state import SharedState if TYPE_CHECKING: - from ._executor import RequestInfoExecutor + from ._request_info_executor import RequestInfoExecutor logger = logging.getLogger(__name__) @@ -372,7 +372,7 @@ class Runner: Returns: The RequestInfoExecutor instance if found, None otherwise. """ - from ._executor import RequestInfoExecutor + from ._request_info_executor import RequestInfoExecutor for executor in self._executors.values(): if isinstance(executor, RequestInfoExecutor): @@ -388,7 +388,7 @@ class Runner: Returns: True if the message targets a RequestInfoExecutor, False otherwise. """ - from ._executor import RequestInfoExecutor + from ._request_info_executor import RequestInfoExecutor if not msg.target_id: return False diff --git a/python/packages/core/agent_framework/_workflows/_validation.py b/python/packages/core/agent_framework/_workflows/_validation.py index 95f522c225..5cd7940ff3 100644 --- a/python/packages/core/agent_framework/_workflows/_validation.py +++ b/python/packages/core/agent_framework/_workflows/_validation.py @@ -9,7 +9,8 @@ from types import UnionType from typing import Any, Union, get_args, get_origin from ._edge import Edge, EdgeGroup, FanInEdgeGroup -from ._executor import Executor, RequestInfoExecutor +from ._executor import Executor +from ._request_info_executor import RequestInfoExecutor logger = logging.getLogger(__name__) diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index 0782dbd1cd..d9270bfe02 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -37,8 +37,9 @@ from ._events import ( WorkflowStatusEvent, _framework_event_origin, # type: ignore ) -from ._executor import Executor, RequestInfoExecutor +from ._executor import Executor from ._model_utils import DictConvertible +from ._request_info_executor import RequestInfoExecutor from ._runner import Runner from ._runner_context import InProcRunnerContext, RunnerContext from ._shared_state import SharedState @@ -742,7 +743,7 @@ class Workflow(DictConvertible): Returns: The RequestInfoExecutor instance if found, None otherwise. """ - from ._executor import RequestInfoExecutor + from ._request_info_executor import RequestInfoExecutor for executor in self.executors.values(): if isinstance(executor, RequestInfoExecutor): diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index 55f1fab264..501ce0d8f1 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -19,10 +19,12 @@ from ._events import ( ) from ._executor import ( Executor, + handler, +) +from ._request_info_executor import ( RequestInfoExecutor, RequestInfoMessage, RequestResponse, - handler, ) from ._typing_utils import is_instance_of from ._workflow_context import WorkflowContext diff --git a/python/packages/core/tests/workflow/test_checkpoint_decode.py b/python/packages/core/tests/workflow/test_checkpoint_decode.py index 1947b3fb41..08c10aa9a9 100644 --- a/python/packages/core/tests/workflow/test_checkpoint_decode.py +++ b/python/packages/core/tests/workflow/test_checkpoint_decode.py @@ -3,7 +3,7 @@ from dataclasses import dataclass # noqa: I001 from typing import Any, cast -from agent_framework._workflows._executor import RequestInfoMessage, RequestResponse +from agent_framework._workflows._request_info_executor import RequestInfoMessage, RequestResponse from agent_framework._workflows._runner_context import ( # type: ignore _decode_checkpoint_value, # type: ignore _encode_checkpoint_value, # type: ignore diff --git a/python/packages/core/tests/workflow/test_request_info_executor_rehydrate.py b/python/packages/core/tests/workflow/test_request_info_executor_rehydrate.py index dee1a30c25..2b52a04f61 100644 --- a/python/packages/core/tests/workflow/test_request_info_executor_rehydrate.py +++ b/python/packages/core/tests/workflow/test_request_info_executor_rehydrate.py @@ -7,7 +7,7 @@ from typing import Any from agent_framework._workflows._checkpoint import CheckpointStorage, WorkflowCheckpoint from agent_framework._workflows._events import RequestInfoEvent, WorkflowEvent -from agent_framework._workflows._executor import ( +from agent_framework._workflows._request_info_executor import ( PendingRequestDetails, RequestInfoExecutor, RequestInfoMessage,