Python: Reorganize workflows modules (#1282)

* Reorganize modules

* Fix unit tests

* Remove submodules
This commit is contained in:
Tao Chen
2025-10-08 10:04:06 -07:00
committed by GitHub
Unverified
parent 1c5e607a1f
commit c2c8ec3d4e
12 changed files with 873 additions and 854 deletions
@@ -48,9 +48,6 @@ from ._events import (
)
from ._executor import (
Executor,
RequestInfoExecutor,
RequestInfoMessage,
RequestResponse,
handler,
)
from ._function_executor import FunctionExecutor, executor
@@ -76,6 +73,11 @@ from ._magentic import (
MagenticStartMessage,
StandardMagenticManager,
)
from ._request_info_executor import (
RequestInfoExecutor,
RequestInfoMessage,
RequestResponse,
)
from ._runner import Runner
from ._runner_context import (
InProcRunnerContext,
@@ -46,9 +46,6 @@ from ._events import (
)
from ._executor import (
Executor,
RequestInfoExecutor,
RequestInfoMessage,
RequestResponse,
handler,
)
from ._function_executor import FunctionExecutor, executor
@@ -74,6 +71,11 @@ from ._magentic import (
MagenticStartMessage,
StandardMagenticManager,
)
from ._request_info_executor import (
RequestInfoExecutor,
RequestInfoMessage,
RequestResponse,
)
from ._runner import Runner
from ._runner_context import (
InProcRunnerContext,
@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Any, TypeAlias
from agent_framework import AgentRunResponse, AgentRunResponseUpdate
if TYPE_CHECKING:
from ._executor import RequestInfoMessage
from ._request_info_executor import RequestInfoMessage
class WorkflowEventSource(str, Enum):
@@ -2,61 +2,29 @@
import contextlib
import functools
import importlib
import inspect
import json
import logging
import uuid
from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from textwrap import shorten
from typing import Any, ClassVar, Generic, TypeVar, cast
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar
from ..observability import create_processing_span
from ._checkpoint import WorkflowCheckpoint
from ._events import (
ExecutorCompletedEvent,
ExecutorFailedEvent,
ExecutorInvokedEvent,
RequestInfoEvent,
WorkflowErrorDetails,
_framework_event_origin, # type: ignore[reportPrivateUsage]
)
from ._model_utils import DictConvertible
from ._runner_context import Message, RunnerContext, _decode_checkpoint_value # type: ignore
from ._runner_context import Message, RunnerContext # type: ignore
from ._shared_state import SharedState
from ._typing_utils import is_instance_of
from ._workflow_context import WorkflowContext, validate_function_signature
logger = logging.getLogger(__name__)
# region Executor
@dataclass
class PendingRequestDetails:
"""Lightweight information about a pending request captured in a checkpoint."""
request_id: str
prompt: str | None = None
draft: str | None = None
iteration: int | None = None
source_executor_id: str | None = None
original_request: "RequestInfoMessage | dict[str, Any] | None" = None
@dataclass
class WorkflowCheckpointSummary:
"""Human-readable summary of a workflow checkpoint."""
checkpoint_id: str
iteration_count: int
targets: list[str]
executor_states: list[str]
status: str
draft_preview: str | None
pending_requests: list[PendingRequestDetails]
class Executor(DictConvertible):
"""Base class for all workflow executors that process messages and perform computations.
@@ -542,802 +510,3 @@ def handler(
# endregion: Handler Decorator
# region Request/Response Types
@dataclass
class RequestInfoMessage:
"""Base class for all request messages in workflows.
Any message that should be routed to the RequestInfoExecutor for external
handling must inherit from this class. This ensures type safety and makes
the request/response pattern explicit.
"""
request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
"""Unique identifier for correlating requests and responses."""
source_executor_id: str | None = None
"""ID of the executor expecting a response to this request.
May differ from the executor that sent the request if intercepted and forwarded."""
TRequest = TypeVar("TRequest", bound="RequestInfoMessage")
TResponse = TypeVar("TResponse")
@dataclass
class RequestResponse(Generic[TRequest, TResponse]):
"""Response type for request/response correlation in workflows.
This type is used by RequestInfoExecutor to create correlated responses
that include the original request context for proper message routing.
"""
data: TResponse
"""The response data returned from handling the request."""
original_request: TRequest
"""The original request that this response corresponds to."""
request_id: str
"""The ID of the original request."""
# endregion: Request/Response Types
# region Request Info Executor
class RequestInfoExecutor(Executor):
"""Built-in executor that handles request/response patterns in workflows.
This executor acts as a gateway for external information requests. When it receives
a request message, it saves the request details and emits a RequestInfoEvent. When
a response is provided externally, it emits the response as a message.
"""
_PENDING_SHARED_STATE_KEY: ClassVar[str] = "_af_pending_request_info"
def __init__(self, id: str):
"""Initialize the RequestInfoExecutor with a unique ID.
Args:
id: Unique ID for this RequestInfoExecutor.
"""
super().__init__(id=id)
self._request_events: dict[str, RequestInfoEvent] = {}
@handler
async def run(self, message: RequestInfoMessage, ctx: WorkflowContext) -> None:
"""Run the RequestInfoExecutor with the given message."""
# Use source_executor_id from message if available, otherwise fall back to context
source_executor_id = message.source_executor_id or ctx.get_source_executor_id()
event = RequestInfoEvent(
request_id=message.request_id,
source_executor_id=source_executor_id,
request_type=type(message),
request_data=message,
)
self._request_events[message.request_id] = event
await self._record_pending_request_snapshot(message, source_executor_id, ctx)
await ctx.add_event(event)
async def handle_response(
self,
response_data: Any,
request_id: str,
ctx: WorkflowContext[RequestResponse[RequestInfoMessage, Any]],
) -> None:
"""Handle a response to a request.
Args:
request_id: The ID of the request to which this response corresponds.
response_data: The data returned in the response.
ctx: The workflow context for sending the response.
"""
event = self._request_events.get(request_id)
if event is None:
event = await self._rehydrate_request_event(request_id, ctx)
if event is None:
raise ValueError(f"No request found with ID: {request_id}")
self._request_events.pop(request_id, None)
# Create a correlated response that includes both the response data and original request
if not isinstance(event.data, RequestInfoMessage):
raise TypeError(f"Expected RequestInfoMessage, got {type(event.data)}")
correlated_response = RequestResponse(data=response_data, original_request=event.data, request_id=request_id)
await ctx.send_message(correlated_response, target_id=event.source_executor_id)
await self._clear_pending_request_snapshot(request_id, ctx)
async def _record_pending_request_snapshot(
self,
request: RequestInfoMessage,
source_executor_id: str,
ctx: WorkflowContext[Any],
) -> None:
snapshot = self._build_request_snapshot(request, source_executor_id)
pending = await self._load_pending_request_state(ctx)
pending[request.request_id] = snapshot
await self._persist_pending_request_state(pending, ctx)
await self._write_executor_state(ctx, pending)
async def _clear_pending_request_snapshot(self, request_id: str, ctx: WorkflowContext[Any]) -> None:
pending = await self._load_pending_request_state(ctx)
if request_id in pending:
pending.pop(request_id, None)
await self._persist_pending_request_state(pending, ctx)
await self._write_executor_state(ctx, pending)
async def _load_pending_request_state(self, ctx: WorkflowContext[Any]) -> dict[str, Any]:
try:
existing = await ctx.get_shared_state(self._PENDING_SHARED_STATE_KEY)
except KeyError:
return {}
except Exception as exc: # pragma: no cover - transport specific
logger.warning(f"RequestInfoExecutor {self.id} failed to read pending request state: {exc}")
return {}
if not isinstance(existing, dict):
if existing not in (None, {}):
logger.warning(
f"RequestInfoExecutor {self.id} encountered non-dict pending state "
f"({type(existing).__name__}); resetting."
)
return {}
return dict(existing) # type: ignore[arg-type]
async def _persist_pending_request_state(self, pending: dict[str, Any], ctx: WorkflowContext[Any]) -> None:
await self._safe_set_shared_state(ctx, pending)
await self._safe_set_runner_state(ctx, pending)
async def _safe_set_shared_state(self, ctx: WorkflowContext[Any], pending: dict[str, Any]) -> None:
try:
await ctx.set_shared_state(self._PENDING_SHARED_STATE_KEY, pending)
except Exception as exc: # pragma: no cover - transport specific
logger.warning(f"RequestInfoExecutor {self.id} failed to update shared pending state: {exc}")
async def _safe_set_runner_state(self, ctx: WorkflowContext[Any], pending: dict[str, Any]) -> None:
try:
await ctx.set_state({"pending_requests": pending})
except Exception as exc: # pragma: no cover - transport specific
logger.warning(f"RequestInfoExecutor {self.id} failed to update runner state with pending requests: {exc}")
def snapshot_state(self) -> dict[str, Any]:
"""Serialize pending requests so checkpoint restoration can resume seamlessly."""
def _encode_event(event: RequestInfoEvent) -> dict[str, Any]:
request_data = event.data
payload: dict[str, Any]
data_cls = request_data.__class__ if request_data is not None else type(None)
payload = self._encode_request_payload(request_data, data_cls)
return {
"source_executor_id": event.source_executor_id,
"request_type": f"{event.request_type.__module__}:{event.request_type.__qualname__}",
"request_data": payload,
}
return {
"request_events": {rid: _encode_event(event) for rid, event in self._request_events.items()},
}
def _encode_request_payload(self, request_data: RequestInfoMessage | None, data_cls: type[Any]) -> dict[str, Any]:
if request_data is None or isinstance(request_data, (str, int, float, bool)):
return {
"kind": "raw",
"type": f"{data_cls.__module__}:{data_cls.__qualname__}",
"value": request_data,
}
if is_dataclass(request_data) and not isinstance(request_data, type):
dataclass_instance = cast(Any, request_data)
safe_value = self._make_json_safe(asdict(dataclass_instance))
return {
"kind": "dataclass",
"type": f"{data_cls.__module__}:{data_cls.__qualname__}",
"value": safe_value,
}
to_dict_fn = getattr(request_data, "to_dict", None)
if callable(to_dict_fn):
try:
dumped = to_dict_fn()
except TypeError:
dumped = to_dict_fn()
safe_value = self._make_json_safe(dumped)
return {
"kind": "dict",
"type": f"{data_cls.__module__}:{data_cls.__qualname__}",
"value": safe_value,
}
to_json_fn = getattr(request_data, "to_json", None)
if callable(to_json_fn):
try:
dumped = to_json_fn()
except TypeError:
dumped = to_json_fn()
converted = dumped
if isinstance(dumped, (str, bytes, bytearray)):
decoded: str | bytes | bytearray
if isinstance(dumped, (bytes, bytearray)):
try:
decoded = dumped.decode()
except Exception:
decoded = dumped
else:
decoded = dumped
try:
converted = json.loads(decoded)
except Exception:
converted = decoded
safe_value = self._make_json_safe(converted)
return {
"kind": "dict" if isinstance(converted, dict) else "json",
"type": f"{data_cls.__module__}:{data_cls.__qualname__}",
"value": safe_value,
}
details = self._serialise_request_details(request_data)
if details is not None:
safe_value = self._make_json_safe(details)
return {
"kind": "raw",
"type": f"{data_cls.__module__}:{data_cls.__qualname__}",
"value": safe_value,
}
safe_value = self._make_json_safe(request_data)
return {
"kind": "raw",
"type": f"{data_cls.__module__}:{data_cls.__qualname__}",
"value": safe_value,
}
def restore_state(self, state: dict[str, Any]) -> None:
"""Restore pending request bookkeeping from checkpoint state."""
self._request_events.clear()
stored_events = state.get("request_events", {})
for request_id, payload in stored_events.items():
request_type_qual = payload.get("request_type", "")
try:
request_type = self._import_qualname(request_type_qual)
except Exception as exc: # pragma: no cover - defensive fallback
logger.debug(
"RequestInfoExecutor %s failed to import %s during restore: %s",
self.id,
request_type_qual,
exc,
)
request_type = RequestInfoMessage
request_data_meta = payload.get("request_data", {})
request_data = self._decode_request_data(request_data_meta)
event = RequestInfoEvent(
request_id=request_id,
source_executor_id=payload.get("source_executor_id", ""),
request_type=request_type,
request_data=request_data,
)
self._request_events[request_id] = event
@staticmethod
def _import_qualname(qualname: str) -> type[Any]:
module_name, _, type_name = qualname.partition(":")
if not module_name or not type_name:
raise ValueError(f"Invalid qualified name: {qualname}")
module = importlib.import_module(module_name)
attr: Any = module
for part in type_name.split("."):
attr = getattr(attr, part)
if not isinstance(attr, type):
raise TypeError(f"Resolved object is not a type: {qualname}")
return attr
def _decode_request_data(self, metadata: dict[str, Any]) -> RequestInfoMessage:
kind = metadata.get("kind")
type_name = metadata.get("type", "")
value: Any = metadata.get("value", {})
if type_name:
try:
imported = self._import_qualname(type_name)
except Exception as exc: # pragma: no cover - defensive fallback
logger.debug(
"RequestInfoExecutor %s failed to import %s during decode: %s",
self.id,
type_name,
exc,
)
imported = RequestInfoMessage
else:
imported = RequestInfoMessage
target_cls: type[RequestInfoMessage]
if isinstance(imported, type) and issubclass(imported, RequestInfoMessage):
target_cls = imported
else:
target_cls = RequestInfoMessage
if kind == "dataclass" and isinstance(value, dict):
with contextlib.suppress(TypeError):
return target_cls(**value) # type: ignore[arg-type]
# Backwards-compat handling for checkpoints that used to store pydantic as "dict"
if kind in {"dict", "pydantic", "json"} and isinstance(value, dict):
from_dict = getattr(target_cls, "from_dict", None)
if callable(from_dict):
with contextlib.suppress(Exception):
return cast(RequestInfoMessage, from_dict(value))
if kind == "json" and isinstance(value, str):
from_json = getattr(target_cls, "from_json", None)
if callable(from_json):
with contextlib.suppress(Exception):
return cast(RequestInfoMessage, from_json(value))
with contextlib.suppress(Exception):
parsed = json.loads(value)
if isinstance(parsed, dict):
return self._decode_request_data({"kind": "dict", "type": type_name, "value": parsed})
if isinstance(value, dict):
with contextlib.suppress(TypeError):
return target_cls(**value) # type: ignore[arg-type]
instance = object.__new__(target_cls)
instance.__dict__.update(value) # type: ignore[arg-type]
return instance
with contextlib.suppress(Exception):
return target_cls()
return RequestInfoMessage()
async def _write_executor_state(self, ctx: WorkflowContext[Any], pending: dict[str, Any]) -> None:
state = self.snapshot_state()
state["pending_requests"] = pending
try:
await ctx.set_state(state)
except Exception as exc: # pragma: no cover - transport specific
logger.warning(f"RequestInfoExecutor {self.id} failed to persist executor state: {exc}")
def _build_request_snapshot(
self,
request: RequestInfoMessage,
source_executor_id: str,
) -> dict[str, Any]:
snapshot: dict[str, Any] = {
"request_id": request.request_id,
"source_executor_id": source_executor_id,
"request_type": f"{type(request).__module__}:{type(request).__name__}",
"summary": repr(request),
}
details = self._serialise_request_details(request)
if details:
snapshot["details"] = details
for key in ("prompt", "draft", "iteration"):
if key in details and key not in snapshot:
snapshot[key] = details[key]
return snapshot
def _serialise_request_details(self, request: RequestInfoMessage) -> dict[str, Any] | None:
if is_dataclass(request):
data = self._make_json_safe(asdict(request))
if isinstance(data, dict):
return cast(dict[str, Any], data)
return None
to_dict = getattr(request, "to_dict", None)
if callable(to_dict):
try:
dump = self._make_json_safe(to_dict())
except TypeError:
dump = self._make_json_safe(to_dict())
if isinstance(dump, dict):
return cast(dict[str, Any], dump)
return None
to_json = getattr(request, "to_json", None)
if callable(to_json):
try:
raw = to_json()
except TypeError:
raw = to_json()
converted = raw
if isinstance(raw, (str, bytes, bytearray)):
decoded: str | bytes | bytearray
if isinstance(raw, (bytes, bytearray)):
try:
decoded = raw.decode()
except Exception:
decoded = raw
else:
decoded = raw
try:
converted = json.loads(decoded)
except Exception:
converted = decoded
dump = self._make_json_safe(converted)
if isinstance(dump, dict):
return cast(dict[str, Any], dump)
return None
attrs = getattr(request, "__dict__", None)
if isinstance(attrs, dict):
cleaned = self._make_json_safe(attrs)
if isinstance(cleaned, dict):
return cast(dict[str, Any], cleaned)
return None
def _make_json_safe(self, value: Any) -> Any:
if value is None or isinstance(value, (str, int, float, bool)):
return value
if isinstance(value, Mapping):
safe_dict: dict[str, Any] = {}
for key, val in value.items(): # type: ignore[attr-defined]
safe_dict[str(key)] = self._make_json_safe(val) # type: ignore[arg-type]
return safe_dict
if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
return [self._make_json_safe(item) for item in value] # type: ignore[misc]
return repr(value)
async def has_pending_request(self, request_id: str, ctx: WorkflowContext[Any]) -> bool:
if request_id in self._request_events:
return True
snapshot = await self._get_pending_request_snapshot(request_id, ctx)
return snapshot is not None
async def _rehydrate_request_event(
self,
request_id: str,
ctx: WorkflowContext[Any],
) -> RequestInfoEvent | None:
snapshot = await self._get_pending_request_snapshot(request_id, ctx)
if snapshot is None:
return None
source_executor_id = snapshot.get("source_executor_id")
if not isinstance(source_executor_id, str) or not source_executor_id:
return None
request = self._construct_request_from_snapshot(snapshot)
if request is None:
return None
event = RequestInfoEvent(
request_id=request_id,
source_executor_id=source_executor_id,
request_type=type(request),
request_data=request,
)
self._request_events[request_id] = event
return event
async def _get_pending_request_snapshot(self, request_id: str, ctx: WorkflowContext[Any]) -> dict[str, Any] | None:
pending = await self._collect_pending_request_snapshots(ctx)
snapshot = pending.get(request_id)
if snapshot is None:
return None
return snapshot
async def _collect_pending_request_snapshots(self, ctx: WorkflowContext[Any]) -> dict[str, dict[str, Any]]:
combined: dict[str, dict[str, Any]] = {}
try:
shared_pending = await ctx.get_shared_state(self._PENDING_SHARED_STATE_KEY)
except KeyError:
shared_pending = None
except Exception as exc: # pragma: no cover - transport specific
logger.warning(f"RequestInfoExecutor {self.id} failed to read shared pending state during rehydrate: {exc}")
shared_pending = None
if isinstance(shared_pending, dict):
for key, value in shared_pending.items(): # type: ignore[attr-defined]
if isinstance(key, str) and isinstance(value, dict):
combined[key] = cast(dict[str, Any], value)
try:
state = await ctx.get_state()
except Exception as exc: # pragma: no cover - transport specific
logger.warning(f"RequestInfoExecutor {self.id} failed to read runner state during rehydrate: {exc}")
state = None
if isinstance(state, dict):
state_pending = state.get("pending_requests")
if isinstance(state_pending, dict):
for key, value in state_pending.items(): # type: ignore[attr-defined]
if isinstance(key, str) and isinstance(value, dict) and key not in combined:
combined[key] = cast(dict[str, Any], value)
return combined
def _construct_request_from_snapshot(self, snapshot: dict[str, Any]) -> RequestInfoMessage | None:
details_raw = snapshot.get("details")
details: dict[str, Any] = cast(dict[str, Any], details_raw) if isinstance(details_raw, dict) else {}
request_cls: type[RequestInfoMessage] = RequestInfoMessage
request_type_str = snapshot.get("request_type")
if isinstance(request_type_str, str) and ":" in request_type_str:
module_name, class_name = request_type_str.split(":", 1)
try:
module = importlib.import_module(module_name)
candidate = getattr(module, class_name)
if isinstance(candidate, type) and issubclass(candidate, RequestInfoMessage):
request_cls = candidate
except Exception as exc:
logger.warning(f"RequestInfoExecutor {self.id} could not import {module_name}.{class_name}: {exc}")
request_cls = RequestInfoMessage
request: RequestInfoMessage | None = self._instantiate_request(request_cls, details)
if request is None and request_cls is not RequestInfoMessage:
request = self._instantiate_request(RequestInfoMessage, details)
if request is None:
logger.warning(
f"RequestInfoExecutor {self.id} could not reconstruct request "
f"{request_type_str or RequestInfoMessage.__name__} from snapshot keys {sorted(details.keys())}"
)
return None
for key, value in details.items():
if key == "request_id":
continue
try:
setattr(request, key, value)
except Exception as exc:
logger.debug(
f"RequestInfoExecutor {self.id} could not set attribute {key} on {type(request).__name__}: {exc}"
)
continue
snapshot_request_id = snapshot.get("request_id")
if isinstance(snapshot_request_id, str) and snapshot_request_id:
try:
request.request_id = snapshot_request_id
except Exception as exc:
logger.debug(
f"RequestInfoExecutor {self.id} could not apply snapshot "
f"request_id to {type(request).__name__}: {exc}"
)
return request
def _instantiate_request(
self,
request_cls: type[RequestInfoMessage],
details: dict[str, Any],
) -> RequestInfoMessage | None:
try:
from_dict = getattr(request_cls, "from_dict", None)
if callable(from_dict):
return cast(RequestInfoMessage, from_dict(details))
except (TypeError, ValueError) as exc:
logger.debug(f"RequestInfoExecutor {self.id} failed to hydrate {request_cls.__name__} via from_dict: {exc}")
except Exception as exc:
logger.warning(
f"RequestInfoExecutor {self.id} encountered unexpected error during "
f"{request_cls.__name__}.from_dict: {exc}"
)
if is_dataclass(request_cls):
try:
field_names = {f.name for f in fields(request_cls)}
ctor_kwargs = {name: details[name] for name in field_names if name in details}
return request_cls(**ctor_kwargs)
except (TypeError, ValueError) as exc:
logger.debug(
f"RequestInfoExecutor {self.id} could not instantiate dataclass "
f"{request_cls.__name__} with snapshot data: {exc}"
)
except Exception as exc:
logger.warning(
f"RequestInfoExecutor {self.id} encountered unexpected error "
f"constructing dataclass {request_cls.__name__}: {exc}"
)
try:
instance = request_cls()
except Exception as exc:
logger.warning(
f"RequestInfoExecutor {self.id} could not instantiate {request_cls.__name__} without arguments: {exc}"
)
return None
for key, value in details.items():
if key == "request_id":
continue
try:
setattr(instance, key, value)
except Exception as exc:
logger.debug(
f"RequestInfoExecutor {self.id} could not set attribute {key} on "
f"{request_cls.__name__} during instantiation: {exc}"
)
continue
return instance
@staticmethod
def pending_requests_from_checkpoint(
checkpoint: WorkflowCheckpoint,
*,
request_executor_ids: Iterable[str] | None = None,
) -> list[PendingRequestDetails]:
executor_filter: set[str] | None = None
if request_executor_ids is not None:
executor_filter = {str(value) for value in request_executor_ids}
pending: dict[str, PendingRequestDetails] = {}
shared_map = checkpoint.shared_state.get(RequestInfoExecutor._PENDING_SHARED_STATE_KEY)
if isinstance(shared_map, Mapping):
for request_id, snapshot in shared_map.items(): # type: ignore[attr-defined]
RequestInfoExecutor._merge_snapshot(pending, str(request_id), snapshot) # type: ignore[arg-type]
for state in checkpoint.executor_states.values():
if not isinstance(state, Mapping):
continue
inner = state.get("pending_requests")
if isinstance(inner, Mapping):
for request_id, snapshot in inner.items(): # type: ignore[attr-defined]
RequestInfoExecutor._merge_snapshot(pending, str(request_id), snapshot) # type: ignore[arg-type]
for source_id, message_list in checkpoint.messages.items():
if executor_filter is not None and source_id not in executor_filter:
continue
if not isinstance(message_list, list):
continue
for message in message_list:
if not isinstance(message, Mapping):
continue
payload = _decode_checkpoint_value(message.get("data"))
RequestInfoExecutor._merge_message_payload(pending, payload, message)
return list(pending.values())
@staticmethod
def checkpoint_summary(
checkpoint: WorkflowCheckpoint,
*,
request_executor_ids: Iterable[str] | None = None,
preview_width: int = 70,
) -> WorkflowCheckpointSummary:
targets = sorted(checkpoint.messages.keys())
executor_states = sorted(checkpoint.executor_states.keys())
pending = RequestInfoExecutor.pending_requests_from_checkpoint(
checkpoint, request_executor_ids=request_executor_ids
)
draft_preview: str | None = None
for entry in pending:
if entry.draft:
draft_preview = shorten(entry.draft, width=preview_width, placeholder="")
break
status = "idle"
if pending:
status = "awaiting human response"
elif not checkpoint.messages and "finalise" in executor_states:
status = "completed"
elif checkpoint.messages:
status = "awaiting next superstep"
elif request_executor_ids is not None and any(tid in targets for tid in request_executor_ids):
status = "awaiting request delivery"
return WorkflowCheckpointSummary(
checkpoint_id=checkpoint.checkpoint_id,
iteration_count=checkpoint.iteration_count,
targets=targets,
executor_states=executor_states,
status=status,
draft_preview=draft_preview,
pending_requests=pending,
)
@staticmethod
def _merge_snapshot(
pending: dict[str, PendingRequestDetails],
request_id: str,
snapshot: Any,
) -> None:
if not request_id or not isinstance(snapshot, Mapping):
return
details = pending.setdefault(request_id, PendingRequestDetails(request_id=request_id))
RequestInfoExecutor._apply_update(
details,
prompt=snapshot.get("prompt"), # type: ignore[attr-defined]
draft=snapshot.get("draft"), # type: ignore[attr-defined]
iteration=snapshot.get("iteration"), # type: ignore[attr-defined]
source_executor_id=snapshot.get("source_executor_id"), # type: ignore[attr-defined]
)
extra = snapshot.get("details") # type: ignore[attr-defined]
if isinstance(extra, Mapping):
RequestInfoExecutor._apply_update(
details,
prompt=extra.get("prompt"), # type: ignore[attr-defined]
draft=extra.get("draft"), # type: ignore[attr-defined]
iteration=extra.get("iteration"), # type: ignore[attr-defined]
)
@staticmethod
def _merge_message_payload(
pending: dict[str, PendingRequestDetails],
payload: Any,
raw_message: Mapping[str, Any],
) -> None:
if isinstance(payload, RequestResponse):
request_id = payload.request_id or RequestInfoExecutor._get_field(payload.original_request, "request_id") # type: ignore[arg-type]
if not request_id:
return
details = pending.setdefault(request_id, PendingRequestDetails(request_id=request_id))
RequestInfoExecutor._apply_update(
details,
prompt=RequestInfoExecutor._get_field(payload.original_request, "prompt"), # type: ignore[arg-type]
draft=RequestInfoExecutor._get_field(payload.original_request, "draft"), # type: ignore[arg-type]
iteration=RequestInfoExecutor._get_field(payload.original_request, "iteration"), # type: ignore[arg-type]
source_executor_id=raw_message.get("source_id"),
original_request=payload.original_request, # type: ignore[arg-type]
)
elif isinstance(payload, RequestInfoMessage):
request_id = getattr(payload, "request_id", None)
if not request_id:
return
details = pending.setdefault(request_id, PendingRequestDetails(request_id=request_id))
RequestInfoExecutor._apply_update(
details,
prompt=getattr(payload, "prompt", None),
draft=getattr(payload, "draft", None),
iteration=getattr(payload, "iteration", None),
source_executor_id=raw_message.get("source_id"),
original_request=payload,
)
@staticmethod
def _apply_update(
details: PendingRequestDetails,
*,
prompt: Any = None,
draft: Any = None,
iteration: Any = None,
source_executor_id: Any = None,
original_request: Any = None,
) -> None:
if prompt and not details.prompt:
details.prompt = str(prompt)
if draft and not details.draft:
details.draft = str(draft)
if iteration is not None and details.iteration is None:
coerced = RequestInfoExecutor._coerce_int(iteration)
if coerced is not None:
details.iteration = coerced
if source_executor_id and not details.source_executor_id:
details.source_executor_id = str(source_executor_id)
if original_request is not None and details.original_request is None:
details.original_request = original_request
@staticmethod
def _get_field(obj: Any, key: str) -> Any:
if obj is None:
return None
if isinstance(obj, Mapping):
return obj.get(key) # type: ignore[attr-defined,return-value]
return getattr(obj, key, None)
@staticmethod
def _coerce_int(value: Any) -> int | None:
try:
return int(value)
except (TypeError, ValueError):
return None
# endregion: Request Info Executor
@@ -27,8 +27,9 @@ from agent_framework._agents import BaseAgent
from ._checkpoint import CheckpointStorage, WorkflowCheckpoint
from ._events import WorkflowEvent
from ._executor import Executor, RequestInfoMessage, RequestResponse, handler
from ._executor import Executor, handler
from ._model_utils import DictConvertible, encode_value
from ._request_info_executor import RequestInfoMessage, RequestResponse
from ._workflow import Workflow, WorkflowBuilder, WorkflowRunResult
from ._workflow_context import WorkflowContext
@@ -1939,7 +1940,7 @@ class MagenticBuilder:
workflow_builder = WorkflowBuilder().set_start_executor(orchestrator_executor)
if self._enable_plan_review:
from ._executor import RequestInfoExecutor
from ._request_info_executor import RequestInfoExecutor
request_info = RequestInfoExecutor(id="magentic_plan_review")
workflow_builder = (
@@ -0,0 +1,841 @@
# Copyright (c) Microsoft. All rights reserved.
import contextlib
import importlib
import json
import logging
import uuid
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from textwrap import shorten
from typing import Any, ClassVar, Generic, TypeVar, cast
from ._checkpoint import WorkflowCheckpoint
from ._events import (
RequestInfoEvent, # type: ignore[reportPrivateUsage]
)
from ._executor import Executor, handler
from ._runner_context import _decode_checkpoint_value # type: ignore
from ._workflow_context import WorkflowContext
logger = logging.getLogger(__name__)
@dataclass
class PendingRequestDetails:
"""Lightweight information about a pending request captured in a checkpoint."""
request_id: str
prompt: str | None = None
draft: str | None = None
iteration: int | None = None
source_executor_id: str | None = None
original_request: "RequestInfoMessage | dict[str, Any] | None" = None
@dataclass
class WorkflowCheckpointSummary:
"""Human-readable summary of a workflow checkpoint."""
checkpoint_id: str
iteration_count: int
targets: list[str]
executor_states: list[str]
status: str
draft_preview: str | None
pending_requests: list[PendingRequestDetails]
@dataclass
class RequestInfoMessage:
"""Base class for all request messages in workflows.
Any message that should be routed to the RequestInfoExecutor for external
handling must inherit from this class. This ensures type safety and makes
the request/response pattern explicit.
"""
request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
"""Unique identifier for correlating requests and responses."""
source_executor_id: str | None = None
"""ID of the executor expecting a response to this request.
May differ from the executor that sent the request if intercepted and forwarded."""
TRequest = TypeVar("TRequest", bound="RequestInfoMessage")
TResponse = TypeVar("TResponse")
@dataclass
class RequestResponse(Generic[TRequest, TResponse]):
"""Response type for request/response correlation in workflows.
This type is used by RequestInfoExecutor to create correlated responses
that include the original request context for proper message routing.
"""
data: TResponse
"""The response data returned from handling the request."""
original_request: TRequest
"""The original request that this response corresponds to."""
request_id: str
"""The ID of the original request."""
# endregion: Request/Response Types
# region Request Info Executor
class RequestInfoExecutor(Executor):
"""Built-in executor that handles request/response patterns in workflows.
This executor acts as a gateway for external information requests. When it receives
a request message, it saves the request details and emits a RequestInfoEvent. When
a response is provided externally, it emits the response as a message.
"""
_PENDING_SHARED_STATE_KEY: ClassVar[str] = "_af_pending_request_info"
def __init__(self, id: str):
"""Initialize the RequestInfoExecutor with a unique ID.
Args:
id: Unique ID for this RequestInfoExecutor.
"""
super().__init__(id=id)
self._request_events: dict[str, RequestInfoEvent] = {}
@handler
async def run(self, message: RequestInfoMessage, ctx: WorkflowContext) -> None:
"""Run the RequestInfoExecutor with the given message."""
# Use source_executor_id from message if available, otherwise fall back to context
source_executor_id = message.source_executor_id or ctx.get_source_executor_id()
event = RequestInfoEvent(
request_id=message.request_id,
source_executor_id=source_executor_id,
request_type=type(message),
request_data=message,
)
self._request_events[message.request_id] = event
await self._record_pending_request_snapshot(message, source_executor_id, ctx)
await ctx.add_event(event)
async def handle_response(
self,
response_data: Any,
request_id: str,
ctx: WorkflowContext[RequestResponse[RequestInfoMessage, Any]],
) -> None:
"""Handle a response to a request.
Args:
request_id: The ID of the request to which this response corresponds.
response_data: The data returned in the response.
ctx: The workflow context for sending the response.
"""
event = self._request_events.get(request_id)
if event is None:
event = await self._rehydrate_request_event(request_id, ctx)
if event is None:
raise ValueError(f"No request found with ID: {request_id}")
self._request_events.pop(request_id, None)
# Create a correlated response that includes both the response data and original request
if not isinstance(event.data, RequestInfoMessage):
raise TypeError(f"Expected RequestInfoMessage, got {type(event.data)}")
correlated_response = RequestResponse(data=response_data, original_request=event.data, request_id=request_id)
await ctx.send_message(correlated_response, target_id=event.source_executor_id)
await self._clear_pending_request_snapshot(request_id, ctx)
async def _record_pending_request_snapshot(
self,
request: RequestInfoMessage,
source_executor_id: str,
ctx: WorkflowContext[Any],
) -> None:
snapshot = self._build_request_snapshot(request, source_executor_id)
pending = await self._load_pending_request_state(ctx)
pending[request.request_id] = snapshot
await self._persist_pending_request_state(pending, ctx)
await self._write_executor_state(ctx, pending)
async def _clear_pending_request_snapshot(self, request_id: str, ctx: WorkflowContext[Any]) -> None:
pending = await self._load_pending_request_state(ctx)
if request_id in pending:
pending.pop(request_id, None)
await self._persist_pending_request_state(pending, ctx)
await self._write_executor_state(ctx, pending)
async def _load_pending_request_state(self, ctx: WorkflowContext[Any]) -> dict[str, Any]:
try:
existing = await ctx.get_shared_state(self._PENDING_SHARED_STATE_KEY)
except KeyError:
return {}
except Exception as exc: # pragma: no cover - transport specific
logger.warning(f"RequestInfoExecutor {self.id} failed to read pending request state: {exc}")
return {}
if not isinstance(existing, dict):
if existing not in (None, {}):
logger.warning(
f"RequestInfoExecutor {self.id} encountered non-dict pending state "
f"({type(existing).__name__}); resetting."
)
return {}
return dict(existing) # type: ignore[arg-type]
async def _persist_pending_request_state(self, pending: dict[str, Any], ctx: WorkflowContext[Any]) -> None:
await self._safe_set_shared_state(ctx, pending)
await self._safe_set_runner_state(ctx, pending)
async def _safe_set_shared_state(self, ctx: WorkflowContext[Any], pending: dict[str, Any]) -> None:
try:
await ctx.set_shared_state(self._PENDING_SHARED_STATE_KEY, pending)
except Exception as exc: # pragma: no cover - transport specific
logger.warning(f"RequestInfoExecutor {self.id} failed to update shared pending state: {exc}")
async def _safe_set_runner_state(self, ctx: WorkflowContext[Any], pending: dict[str, Any]) -> None:
try:
await ctx.set_state({"pending_requests": pending})
except Exception as exc: # pragma: no cover - transport specific
logger.warning(f"RequestInfoExecutor {self.id} failed to update runner state with pending requests: {exc}")
def snapshot_state(self) -> dict[str, Any]:
"""Serialize pending requests so checkpoint restoration can resume seamlessly."""
def _encode_event(event: RequestInfoEvent) -> dict[str, Any]:
request_data = event.data
payload: dict[str, Any]
data_cls = request_data.__class__ if request_data is not None else type(None)
payload = self._encode_request_payload(request_data, data_cls)
return {
"source_executor_id": event.source_executor_id,
"request_type": f"{event.request_type.__module__}:{event.request_type.__qualname__}",
"request_data": payload,
}
return {
"request_events": {rid: _encode_event(event) for rid, event in self._request_events.items()},
}
def _encode_request_payload(self, request_data: RequestInfoMessage | None, data_cls: type[Any]) -> dict[str, Any]:
if request_data is None or isinstance(request_data, (str, int, float, bool)):
return {
"kind": "raw",
"type": f"{data_cls.__module__}:{data_cls.__qualname__}",
"value": request_data,
}
if is_dataclass(request_data) and not isinstance(request_data, type):
dataclass_instance = cast(Any, request_data)
safe_value = self._make_json_safe(asdict(dataclass_instance))
return {
"kind": "dataclass",
"type": f"{data_cls.__module__}:{data_cls.__qualname__}",
"value": safe_value,
}
to_dict_fn = getattr(request_data, "to_dict", None)
if callable(to_dict_fn):
try:
dumped = to_dict_fn()
except TypeError:
dumped = to_dict_fn()
safe_value = self._make_json_safe(dumped)
return {
"kind": "dict",
"type": f"{data_cls.__module__}:{data_cls.__qualname__}",
"value": safe_value,
}
to_json_fn = getattr(request_data, "to_json", None)
if callable(to_json_fn):
try:
dumped = to_json_fn()
except TypeError:
dumped = to_json_fn()
converted = dumped
if isinstance(dumped, (str, bytes, bytearray)):
decoded: str | bytes | bytearray
if isinstance(dumped, (bytes, bytearray)):
try:
decoded = dumped.decode()
except Exception:
decoded = dumped
else:
decoded = dumped
try:
converted = json.loads(decoded)
except Exception:
converted = decoded
safe_value = self._make_json_safe(converted)
return {
"kind": "dict" if isinstance(converted, dict) else "json",
"type": f"{data_cls.__module__}:{data_cls.__qualname__}",
"value": safe_value,
}
details = self._serialise_request_details(request_data)
if details is not None:
safe_value = self._make_json_safe(details)
return {
"kind": "raw",
"type": f"{data_cls.__module__}:{data_cls.__qualname__}",
"value": safe_value,
}
safe_value = self._make_json_safe(request_data)
return {
"kind": "raw",
"type": f"{data_cls.__module__}:{data_cls.__qualname__}",
"value": safe_value,
}
def restore_state(self, state: dict[str, Any]) -> None:
"""Restore pending request bookkeeping from checkpoint state."""
self._request_events.clear()
stored_events = state.get("request_events", {})
for request_id, payload in stored_events.items():
request_type_qual = payload.get("request_type", "")
try:
request_type = self._import_qualname(request_type_qual)
except Exception as exc: # pragma: no cover - defensive fallback
logger.debug(
"RequestInfoExecutor %s failed to import %s during restore: %s",
self.id,
request_type_qual,
exc,
)
request_type = RequestInfoMessage
request_data_meta = payload.get("request_data", {})
request_data = self._decode_request_data(request_data_meta)
event = RequestInfoEvent(
request_id=request_id,
source_executor_id=payload.get("source_executor_id", ""),
request_type=request_type,
request_data=request_data,
)
self._request_events[request_id] = event
@staticmethod
def _import_qualname(qualname: str) -> type[Any]:
module_name, _, type_name = qualname.partition(":")
if not module_name or not type_name:
raise ValueError(f"Invalid qualified name: {qualname}")
module = importlib.import_module(module_name)
attr: Any = module
for part in type_name.split("."):
attr = getattr(attr, part)
if not isinstance(attr, type):
raise TypeError(f"Resolved object is not a type: {qualname}")
return attr
def _decode_request_data(self, metadata: dict[str, Any]) -> RequestInfoMessage:
kind = metadata.get("kind")
type_name = metadata.get("type", "")
value: Any = metadata.get("value", {})
if type_name:
try:
imported = self._import_qualname(type_name)
except Exception as exc: # pragma: no cover - defensive fallback
logger.debug(
"RequestInfoExecutor %s failed to import %s during decode: %s",
self.id,
type_name,
exc,
)
imported = RequestInfoMessage
else:
imported = RequestInfoMessage
target_cls: type[RequestInfoMessage]
if isinstance(imported, type) and issubclass(imported, RequestInfoMessage):
target_cls = imported
else:
target_cls = RequestInfoMessage
if kind == "dataclass" and isinstance(value, dict):
with contextlib.suppress(TypeError):
return target_cls(**value) # type: ignore[arg-type]
# Backwards-compat handling for checkpoints that used to store pydantic as "dict"
if kind in {"dict", "pydantic", "json"} and isinstance(value, dict):
from_dict = getattr(target_cls, "from_dict", None)
if callable(from_dict):
with contextlib.suppress(Exception):
return cast(RequestInfoMessage, from_dict(value))
if kind == "json" and isinstance(value, str):
from_json = getattr(target_cls, "from_json", None)
if callable(from_json):
with contextlib.suppress(Exception):
return cast(RequestInfoMessage, from_json(value))
with contextlib.suppress(Exception):
parsed = json.loads(value)
if isinstance(parsed, dict):
return self._decode_request_data({"kind": "dict", "type": type_name, "value": parsed})
if isinstance(value, dict):
with contextlib.suppress(TypeError):
return target_cls(**value) # type: ignore[arg-type]
instance = object.__new__(target_cls)
instance.__dict__.update(value) # type: ignore[arg-type]
return instance
with contextlib.suppress(Exception):
return target_cls()
return RequestInfoMessage()
async def _write_executor_state(self, ctx: WorkflowContext[Any], pending: dict[str, Any]) -> None:
state = self.snapshot_state()
state["pending_requests"] = pending
try:
await ctx.set_state(state)
except Exception as exc: # pragma: no cover - transport specific
logger.warning(f"RequestInfoExecutor {self.id} failed to persist executor state: {exc}")
def _build_request_snapshot(
self,
request: RequestInfoMessage,
source_executor_id: str,
) -> dict[str, Any]:
snapshot: dict[str, Any] = {
"request_id": request.request_id,
"source_executor_id": source_executor_id,
"request_type": f"{type(request).__module__}:{type(request).__name__}",
"summary": repr(request),
}
details = self._serialise_request_details(request)
if details:
snapshot["details"] = details
for key in ("prompt", "draft", "iteration"):
if key in details and key not in snapshot:
snapshot[key] = details[key]
return snapshot
def _serialise_request_details(self, request: RequestInfoMessage) -> dict[str, Any] | None:
if is_dataclass(request):
data = self._make_json_safe(asdict(request))
if isinstance(data, dict):
return cast(dict[str, Any], data)
return None
to_dict = getattr(request, "to_dict", None)
if callable(to_dict):
try:
dump = self._make_json_safe(to_dict())
except TypeError:
dump = self._make_json_safe(to_dict())
if isinstance(dump, dict):
return cast(dict[str, Any], dump)
return None
to_json = getattr(request, "to_json", None)
if callable(to_json):
try:
raw = to_json()
except TypeError:
raw = to_json()
converted = raw
if isinstance(raw, (str, bytes, bytearray)):
decoded: str | bytes | bytearray
if isinstance(raw, (bytes, bytearray)):
try:
decoded = raw.decode()
except Exception:
decoded = raw
else:
decoded = raw
try:
converted = json.loads(decoded)
except Exception:
converted = decoded
dump = self._make_json_safe(converted)
if isinstance(dump, dict):
return cast(dict[str, Any], dump)
return None
attrs = getattr(request, "__dict__", None)
if isinstance(attrs, dict):
cleaned = self._make_json_safe(attrs)
if isinstance(cleaned, dict):
return cast(dict[str, Any], cleaned)
return None
def _make_json_safe(self, value: Any) -> Any:
if value is None or isinstance(value, (str, int, float, bool)):
return value
if isinstance(value, Mapping):
safe_dict: dict[str, Any] = {}
for key, val in value.items(): # type: ignore[attr-defined]
safe_dict[str(key)] = self._make_json_safe(val) # type: ignore[arg-type]
return safe_dict
if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
return [self._make_json_safe(item) for item in value] # type: ignore[misc]
return repr(value)
async def has_pending_request(self, request_id: str, ctx: WorkflowContext[Any]) -> bool:
if request_id in self._request_events:
return True
snapshot = await self._get_pending_request_snapshot(request_id, ctx)
return snapshot is not None
async def _rehydrate_request_event(
self,
request_id: str,
ctx: WorkflowContext[Any],
) -> RequestInfoEvent | None:
snapshot = await self._get_pending_request_snapshot(request_id, ctx)
if snapshot is None:
return None
source_executor_id = snapshot.get("source_executor_id")
if not isinstance(source_executor_id, str) or not source_executor_id:
return None
request = self._construct_request_from_snapshot(snapshot)
if request is None:
return None
event = RequestInfoEvent(
request_id=request_id,
source_executor_id=source_executor_id,
request_type=type(request),
request_data=request,
)
self._request_events[request_id] = event
return event
async def _get_pending_request_snapshot(self, request_id: str, ctx: WorkflowContext[Any]) -> dict[str, Any] | None:
pending = await self._collect_pending_request_snapshots(ctx)
snapshot = pending.get(request_id)
if snapshot is None:
return None
return snapshot
async def _collect_pending_request_snapshots(self, ctx: WorkflowContext[Any]) -> dict[str, dict[str, Any]]:
combined: dict[str, dict[str, Any]] = {}
try:
shared_pending = await ctx.get_shared_state(self._PENDING_SHARED_STATE_KEY)
except KeyError:
shared_pending = None
except Exception as exc: # pragma: no cover - transport specific
logger.warning(f"RequestInfoExecutor {self.id} failed to read shared pending state during rehydrate: {exc}")
shared_pending = None
if isinstance(shared_pending, dict):
for key, value in shared_pending.items(): # type: ignore[attr-defined]
if isinstance(key, str) and isinstance(value, dict):
combined[key] = cast(dict[str, Any], value)
try:
state = await ctx.get_state()
except Exception as exc: # pragma: no cover - transport specific
logger.warning(f"RequestInfoExecutor {self.id} failed to read runner state during rehydrate: {exc}")
state = None
if isinstance(state, dict):
state_pending = state.get("pending_requests")
if isinstance(state_pending, dict):
for key, value in state_pending.items(): # type: ignore[attr-defined]
if isinstance(key, str) and isinstance(value, dict) and key not in combined:
combined[key] = cast(dict[str, Any], value)
return combined
def _construct_request_from_snapshot(self, snapshot: dict[str, Any]) -> RequestInfoMessage | None:
details_raw = snapshot.get("details")
details: dict[str, Any] = cast(dict[str, Any], details_raw) if isinstance(details_raw, dict) else {}
request_cls: type[RequestInfoMessage] = RequestInfoMessage
request_type_str = snapshot.get("request_type")
if isinstance(request_type_str, str) and ":" in request_type_str:
module_name, class_name = request_type_str.split(":", 1)
try:
module = importlib.import_module(module_name)
candidate = getattr(module, class_name)
if isinstance(candidate, type) and issubclass(candidate, RequestInfoMessage):
request_cls = candidate
except Exception as exc:
logger.warning(f"RequestInfoExecutor {self.id} could not import {module_name}.{class_name}: {exc}")
request_cls = RequestInfoMessage
request: RequestInfoMessage | None = self._instantiate_request(request_cls, details)
if request is None and request_cls is not RequestInfoMessage:
request = self._instantiate_request(RequestInfoMessage, details)
if request is None:
logger.warning(
f"RequestInfoExecutor {self.id} could not reconstruct request "
f"{request_type_str or RequestInfoMessage.__name__} from snapshot keys {sorted(details.keys())}"
)
return None
for key, value in details.items():
if key == "request_id":
continue
try:
setattr(request, key, value)
except Exception as exc:
logger.debug(
f"RequestInfoExecutor {self.id} could not set attribute {key} on {type(request).__name__}: {exc}"
)
continue
snapshot_request_id = snapshot.get("request_id")
if isinstance(snapshot_request_id, str) and snapshot_request_id:
try:
request.request_id = snapshot_request_id
except Exception as exc:
logger.debug(
f"RequestInfoExecutor {self.id} could not apply snapshot "
f"request_id to {type(request).__name__}: {exc}"
)
return request
def _instantiate_request(
self,
request_cls: type[RequestInfoMessage],
details: dict[str, Any],
) -> RequestInfoMessage | None:
try:
from_dict = getattr(request_cls, "from_dict", None)
if callable(from_dict):
return cast(RequestInfoMessage, from_dict(details))
except (TypeError, ValueError) as exc:
logger.debug(f"RequestInfoExecutor {self.id} failed to hydrate {request_cls.__name__} via from_dict: {exc}")
except Exception as exc:
logger.warning(
f"RequestInfoExecutor {self.id} encountered unexpected error during "
f"{request_cls.__name__}.from_dict: {exc}"
)
if is_dataclass(request_cls):
try:
field_names = {f.name for f in fields(request_cls)}
ctor_kwargs = {name: details[name] for name in field_names if name in details}
return request_cls(**ctor_kwargs)
except (TypeError, ValueError) as exc:
logger.debug(
f"RequestInfoExecutor {self.id} could not instantiate dataclass "
f"{request_cls.__name__} with snapshot data: {exc}"
)
except Exception as exc:
logger.warning(
f"RequestInfoExecutor {self.id} encountered unexpected error "
f"constructing dataclass {request_cls.__name__}: {exc}"
)
try:
instance = request_cls()
except Exception as exc:
logger.warning(
f"RequestInfoExecutor {self.id} could not instantiate {request_cls.__name__} without arguments: {exc}"
)
return None
for key, value in details.items():
if key == "request_id":
continue
try:
setattr(instance, key, value)
except Exception as exc:
logger.debug(
f"RequestInfoExecutor {self.id} could not set attribute {key} on "
f"{request_cls.__name__} during instantiation: {exc}"
)
continue
return instance
@staticmethod
def pending_requests_from_checkpoint(
checkpoint: WorkflowCheckpoint,
*,
request_executor_ids: Iterable[str] | None = None,
) -> list[PendingRequestDetails]:
executor_filter: set[str] | None = None
if request_executor_ids is not None:
executor_filter = {str(value) for value in request_executor_ids}
pending: dict[str, PendingRequestDetails] = {}
shared_map = checkpoint.shared_state.get(RequestInfoExecutor._PENDING_SHARED_STATE_KEY)
if isinstance(shared_map, Mapping):
for request_id, snapshot in shared_map.items(): # type: ignore[attr-defined]
RequestInfoExecutor._merge_snapshot(pending, str(request_id), snapshot) # type: ignore[arg-type]
for state in checkpoint.executor_states.values():
if not isinstance(state, Mapping):
continue
inner = state.get("pending_requests")
if isinstance(inner, Mapping):
for request_id, snapshot in inner.items(): # type: ignore[attr-defined]
RequestInfoExecutor._merge_snapshot(pending, str(request_id), snapshot) # type: ignore[arg-type]
for source_id, message_list in checkpoint.messages.items():
if executor_filter is not None and source_id not in executor_filter:
continue
if not isinstance(message_list, list):
continue
for message in message_list:
if not isinstance(message, Mapping):
continue
payload = _decode_checkpoint_value(message.get("data"))
RequestInfoExecutor._merge_message_payload(pending, payload, message)
return list(pending.values())
@staticmethod
def checkpoint_summary(
checkpoint: WorkflowCheckpoint,
*,
request_executor_ids: Iterable[str] | None = None,
preview_width: int = 70,
) -> WorkflowCheckpointSummary:
targets = sorted(checkpoint.messages.keys())
executor_states = sorted(checkpoint.executor_states.keys())
pending = RequestInfoExecutor.pending_requests_from_checkpoint(
checkpoint, request_executor_ids=request_executor_ids
)
draft_preview: str | None = None
for entry in pending:
if entry.draft:
draft_preview = shorten(entry.draft, width=preview_width, placeholder="")
break
status = "idle"
if pending:
status = "awaiting human response"
elif not checkpoint.messages and "finalise" in executor_states:
status = "completed"
elif checkpoint.messages:
status = "awaiting next superstep"
elif request_executor_ids is not None and any(tid in targets for tid in request_executor_ids):
status = "awaiting request delivery"
return WorkflowCheckpointSummary(
checkpoint_id=checkpoint.checkpoint_id,
iteration_count=checkpoint.iteration_count,
targets=targets,
executor_states=executor_states,
status=status,
draft_preview=draft_preview,
pending_requests=pending,
)
@staticmethod
def _merge_snapshot(
pending: dict[str, PendingRequestDetails],
request_id: str,
snapshot: Any,
) -> None:
if not request_id or not isinstance(snapshot, Mapping):
return
details = pending.setdefault(request_id, PendingRequestDetails(request_id=request_id))
RequestInfoExecutor._apply_update(
details,
prompt=snapshot.get("prompt"), # type: ignore[attr-defined]
draft=snapshot.get("draft"), # type: ignore[attr-defined]
iteration=snapshot.get("iteration"), # type: ignore[attr-defined]
source_executor_id=snapshot.get("source_executor_id"), # type: ignore[attr-defined]
)
extra = snapshot.get("details") # type: ignore[attr-defined]
if isinstance(extra, Mapping):
RequestInfoExecutor._apply_update(
details,
prompt=extra.get("prompt"), # type: ignore[attr-defined]
draft=extra.get("draft"), # type: ignore[attr-defined]
iteration=extra.get("iteration"), # type: ignore[attr-defined]
)
@staticmethod
def _merge_message_payload(
pending: dict[str, PendingRequestDetails],
payload: Any,
raw_message: Mapping[str, Any],
) -> None:
if isinstance(payload, RequestResponse):
request_id = payload.request_id or RequestInfoExecutor._get_field(payload.original_request, "request_id") # type: ignore[arg-type]
if not request_id:
return
details = pending.setdefault(request_id, PendingRequestDetails(request_id=request_id))
RequestInfoExecutor._apply_update(
details,
prompt=RequestInfoExecutor._get_field(payload.original_request, "prompt"), # type: ignore[arg-type]
draft=RequestInfoExecutor._get_field(payload.original_request, "draft"), # type: ignore[arg-type]
iteration=RequestInfoExecutor._get_field(payload.original_request, "iteration"), # type: ignore[arg-type]
source_executor_id=raw_message.get("source_id"),
original_request=payload.original_request, # type: ignore[arg-type]
)
elif isinstance(payload, RequestInfoMessage):
request_id = getattr(payload, "request_id", None)
if not request_id:
return
details = pending.setdefault(request_id, PendingRequestDetails(request_id=request_id))
RequestInfoExecutor._apply_update(
details,
prompt=getattr(payload, "prompt", None),
draft=getattr(payload, "draft", None),
iteration=getattr(payload, "iteration", None),
source_executor_id=raw_message.get("source_id"),
original_request=payload,
)
@staticmethod
def _apply_update(
details: PendingRequestDetails,
*,
prompt: Any = None,
draft: Any = None,
iteration: Any = None,
source_executor_id: Any = None,
original_request: Any = None,
) -> None:
if prompt and not details.prompt:
details.prompt = str(prompt)
if draft and not details.draft:
details.draft = str(draft)
if iteration is not None and details.iteration is None:
coerced = RequestInfoExecutor._coerce_int(iteration)
if coerced is not None:
details.iteration = coerced
if source_executor_id and not details.source_executor_id:
details.source_executor_id = str(source_executor_id)
if original_request is not None and details.original_request is None:
details.original_request = original_request
@staticmethod
def _get_field(obj: Any, key: str) -> Any:
if obj is None:
return None
if isinstance(obj, Mapping):
return obj.get(key) # type: ignore[attr-defined,return-value]
return getattr(obj, key, None)
@staticmethod
def _coerce_int(value: Any) -> int | None:
try:
return int(value)
except (TypeError, ValueError):
return None
@@ -22,7 +22,7 @@ from ._runner_context import (
from ._shared_state import SharedState
if TYPE_CHECKING:
from ._executor import RequestInfoExecutor
from ._request_info_executor import RequestInfoExecutor
logger = logging.getLogger(__name__)
@@ -372,7 +372,7 @@ class Runner:
Returns:
The RequestInfoExecutor instance if found, None otherwise.
"""
from ._executor import RequestInfoExecutor
from ._request_info_executor import RequestInfoExecutor
for executor in self._executors.values():
if isinstance(executor, RequestInfoExecutor):
@@ -388,7 +388,7 @@ class Runner:
Returns:
True if the message targets a RequestInfoExecutor, False otherwise.
"""
from ._executor import RequestInfoExecutor
from ._request_info_executor import RequestInfoExecutor
if not msg.target_id:
return False
@@ -9,7 +9,8 @@ from types import UnionType
from typing import Any, Union, get_args, get_origin
from ._edge import Edge, EdgeGroup, FanInEdgeGroup
from ._executor import Executor, RequestInfoExecutor
from ._executor import Executor
from ._request_info_executor import RequestInfoExecutor
logger = logging.getLogger(__name__)
@@ -37,8 +37,9 @@ from ._events import (
WorkflowStatusEvent,
_framework_event_origin, # type: ignore
)
from ._executor import Executor, RequestInfoExecutor
from ._executor import Executor
from ._model_utils import DictConvertible
from ._request_info_executor import RequestInfoExecutor
from ._runner import Runner
from ._runner_context import InProcRunnerContext, RunnerContext
from ._shared_state import SharedState
@@ -742,7 +743,7 @@ class Workflow(DictConvertible):
Returns:
The RequestInfoExecutor instance if found, None otherwise.
"""
from ._executor import RequestInfoExecutor
from ._request_info_executor import RequestInfoExecutor
for executor in self.executors.values():
if isinstance(executor, RequestInfoExecutor):
@@ -19,10 +19,12 @@ from ._events import (
)
from ._executor import (
Executor,
handler,
)
from ._request_info_executor import (
RequestInfoExecutor,
RequestInfoMessage,
RequestResponse,
handler,
)
from ._typing_utils import is_instance_of
from ._workflow_context import WorkflowContext
@@ -3,7 +3,7 @@
from dataclasses import dataclass # noqa: I001
from typing import Any, cast
from agent_framework._workflows._executor import RequestInfoMessage, RequestResponse
from agent_framework._workflows._request_info_executor import RequestInfoMessage, RequestResponse
from agent_framework._workflows._runner_context import ( # type: ignore
_decode_checkpoint_value, # type: ignore
_encode_checkpoint_value, # type: ignore
@@ -7,7 +7,7 @@ from typing import Any
from agent_framework._workflows._checkpoint import CheckpointStorage, WorkflowCheckpoint
from agent_framework._workflows._events import RequestInfoEvent, WorkflowEvent
from agent_framework._workflows._executor import (
from agent_framework._workflows._request_info_executor import (
PendingRequestDetails,
RequestInfoExecutor,
RequestInfoMessage,