mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Code clean up: Checkpoint and WorkflowBuilder (#1557)
Co-authored-by: Evan Mattson <35585003+moonbox3@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
9c3f52566f
commit
083d0de3f3
@@ -98,7 +98,8 @@ from ._validation import (
|
||||
validate_workflow_graph,
|
||||
)
|
||||
from ._viz import WorkflowViz
|
||||
from ._workflow import Workflow, WorkflowBuilder, WorkflowRunResult
|
||||
from ._workflow import Workflow, WorkflowRunResult
|
||||
from ._workflow_builder import WorkflowBuilder
|
||||
from ._workflow_context import WorkflowContext
|
||||
from ._workflow_executor import WorkflowExecutor
|
||||
|
||||
|
||||
@@ -96,7 +96,8 @@ from ._validation import (
|
||||
validate_workflow_graph,
|
||||
)
|
||||
from ._viz import WorkflowViz
|
||||
from ._workflow import Workflow, WorkflowBuilder, WorkflowRunResult
|
||||
from ._workflow import Workflow, WorkflowRunResult
|
||||
from ._workflow_builder import WorkflowBuilder
|
||||
from ._workflow_context import WorkflowContext
|
||||
from ._workflow_executor import WorkflowExecutor
|
||||
|
||||
|
||||
@@ -0,0 +1,251 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import contextlib
|
||||
import importlib
|
||||
import logging
|
||||
import sys
|
||||
from dataclasses import fields, is_dataclass
|
||||
from typing import Any, cast
|
||||
|
||||
# Checkpoint serialization helpers
|
||||
MODEL_MARKER = "__af_model__"
|
||||
DATACLASS_MARKER = "__af_dataclass__"
|
||||
|
||||
# Guards to prevent runaway recursion while encoding arbitrary user data
|
||||
_MAX_ENCODE_DEPTH = 100
|
||||
_CYCLE_SENTINEL = "<cycle>"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def encode_checkpoint_value(value: Any) -> Any:
|
||||
"""Recursively encode values into JSON-serializable structures.
|
||||
|
||||
- Objects exposing to_dict/to_json -> { MODEL_MARKER: "module:Class", value: encoded }
|
||||
- dataclass instances -> { DATACLASS_MARKER: "module:Class", value: {field: encoded} }
|
||||
- dict -> encode keys as str and values recursively
|
||||
- list/tuple/set -> list of encoded items
|
||||
- other -> returned as-is if already JSON-serializable
|
||||
|
||||
Includes cycle and depth protection to avoid infinite recursion.
|
||||
"""
|
||||
|
||||
def _enc(v: Any, stack: set[int], depth: int) -> Any:
|
||||
# Depth guard
|
||||
if depth > _MAX_ENCODE_DEPTH:
|
||||
logger.debug(f"Max encode depth reached at depth={depth} for type={type(v)}")
|
||||
return "<max_depth>"
|
||||
|
||||
# Structured model handling (objects exposing to_dict/to_json)
|
||||
if _supports_model_protocol(v):
|
||||
cls = cast(type[Any], type(v)) # type: ignore
|
||||
try:
|
||||
if hasattr(v, "to_dict") and callable(getattr(v, "to_dict", None)):
|
||||
raw = v.to_dict() # type: ignore[attr-defined]
|
||||
strategy = "to_dict"
|
||||
elif hasattr(v, "to_json") and callable(getattr(v, "to_json", None)):
|
||||
serialized = v.to_json() # type: ignore[attr-defined]
|
||||
if isinstance(serialized, (bytes, bytearray)):
|
||||
try:
|
||||
serialized = serialized.decode()
|
||||
except Exception:
|
||||
serialized = serialized.decode(errors="replace")
|
||||
raw = serialized
|
||||
strategy = "to_json"
|
||||
else:
|
||||
raise AttributeError("Structured model lacks serialization hooks")
|
||||
return {
|
||||
MODEL_MARKER: f"{cls.__module__}:{cls.__name__}",
|
||||
"strategy": strategy,
|
||||
"value": _enc(raw, stack, depth + 1),
|
||||
}
|
||||
except Exception as exc: # best-effort fallback
|
||||
logger.debug(f"Structured model serialization failed for {cls}: {exc}")
|
||||
return str(v)
|
||||
|
||||
# Dataclasses (instances only)
|
||||
if is_dataclass(v) and not isinstance(v, type):
|
||||
oid = id(v)
|
||||
if oid in stack:
|
||||
logger.debug("Cycle detected while encoding dataclass instance")
|
||||
return _CYCLE_SENTINEL
|
||||
stack.add(oid)
|
||||
try:
|
||||
# type(v) already narrows sufficiently; cast was redundant
|
||||
dc_cls: type[Any] = type(v)
|
||||
field_values: dict[str, Any] = {}
|
||||
for f in fields(v): # type: ignore[arg-type]
|
||||
field_values[f.name] = _enc(getattr(v, f.name), stack, depth + 1)
|
||||
return {
|
||||
DATACLASS_MARKER: f"{dc_cls.__module__}:{dc_cls.__name__}",
|
||||
"value": field_values,
|
||||
}
|
||||
finally:
|
||||
stack.remove(oid)
|
||||
|
||||
# Collections
|
||||
if isinstance(v, dict):
|
||||
v_dict = cast("dict[object, object]", v)
|
||||
oid = id(v_dict)
|
||||
if oid in stack:
|
||||
logger.debug("Cycle detected while encoding dict")
|
||||
return _CYCLE_SENTINEL
|
||||
stack.add(oid)
|
||||
try:
|
||||
json_dict: dict[str, Any] = {}
|
||||
for k_any, val_any in v_dict.items(): # type: ignore[assignment]
|
||||
k_str: str = str(k_any)
|
||||
json_dict[k_str] = _enc(val_any, stack, depth + 1)
|
||||
return json_dict
|
||||
finally:
|
||||
stack.remove(oid)
|
||||
|
||||
if isinstance(v, (list, tuple, set)):
|
||||
iterable_v = cast("list[object] | tuple[object, ...] | set[object]", v)
|
||||
oid = id(iterable_v)
|
||||
if oid in stack:
|
||||
logger.debug("Cycle detected while encoding iterable")
|
||||
return _CYCLE_SENTINEL
|
||||
stack.add(oid)
|
||||
try:
|
||||
seq: list[object] = list(iterable_v)
|
||||
encoded_list: list[Any] = []
|
||||
for item in seq:
|
||||
encoded_list.append(_enc(item, stack, depth + 1))
|
||||
return encoded_list
|
||||
finally:
|
||||
stack.remove(oid)
|
||||
|
||||
# Primitives (or unknown objects): ensure JSON-serializable
|
||||
if isinstance(v, (str, int, float, bool)) or v is None:
|
||||
return v
|
||||
# Fallback: stringify unknown objects to avoid JSON serialization errors
|
||||
try:
|
||||
return str(v)
|
||||
except Exception:
|
||||
return f"<{type(v).__name__}>"
|
||||
|
||||
return _enc(value, set(), 0)
|
||||
|
||||
|
||||
def decode_checkpoint_value(value: Any) -> Any:
|
||||
"""Recursively decode values previously encoded by encode_checkpoint_value."""
|
||||
if isinstance(value, dict):
|
||||
value_dict = cast(dict[str, Any], value) # encoded form always uses string keys
|
||||
# Structured model marker handling
|
||||
if MODEL_MARKER in value_dict and "value" in value_dict:
|
||||
type_key: str | None = value_dict.get(MODEL_MARKER) # type: ignore[assignment]
|
||||
strategy: str | None = value_dict.get("strategy") # type: ignore[assignment]
|
||||
raw_encoded: Any = value_dict.get("value")
|
||||
decoded_payload = decode_checkpoint_value(raw_encoded)
|
||||
if isinstance(type_key, str):
|
||||
try:
|
||||
cls = _import_qualified_name(type_key)
|
||||
except Exception as exc:
|
||||
logger.debug(f"Failed to import structured model {type_key}: {exc}")
|
||||
cls = None
|
||||
|
||||
if cls is not None:
|
||||
if strategy == "to_dict" and hasattr(cls, "from_dict"):
|
||||
with contextlib.suppress(Exception):
|
||||
return cls.from_dict(decoded_payload)
|
||||
if strategy == "to_json" and hasattr(cls, "from_json"):
|
||||
if isinstance(decoded_payload, (str, bytes, bytearray)):
|
||||
with contextlib.suppress(Exception):
|
||||
return cls.from_json(decoded_payload)
|
||||
if isinstance(decoded_payload, dict) and hasattr(cls, "from_dict"):
|
||||
with contextlib.suppress(Exception):
|
||||
return cls.from_dict(decoded_payload)
|
||||
return decoded_payload
|
||||
# Dataclass marker handling
|
||||
if DATACLASS_MARKER in value_dict and "value" in value_dict:
|
||||
type_key_dc: str | None = value_dict.get(DATACLASS_MARKER) # type: ignore[assignment]
|
||||
raw_dc: Any = value_dict.get("value")
|
||||
decoded_raw = decode_checkpoint_value(raw_dc)
|
||||
if isinstance(type_key_dc, str):
|
||||
try:
|
||||
module_name, class_name = type_key_dc.split(":", 1)
|
||||
module = sys.modules.get(module_name)
|
||||
if module is None:
|
||||
module = importlib.import_module(module_name)
|
||||
cls_dc: Any = getattr(module, class_name)
|
||||
constructed = _instantiate_checkpoint_dataclass(cls_dc, decoded_raw)
|
||||
if constructed is not None:
|
||||
return constructed
|
||||
except Exception as exc:
|
||||
logger.debug(f"Failed to decode dataclass {type_key_dc}: {exc}; returning raw value")
|
||||
return decoded_raw
|
||||
|
||||
# Regular dict: decode recursively
|
||||
decoded: dict[str, Any] = {}
|
||||
for k_any, v_any in value_dict.items():
|
||||
decoded[k_any] = decode_checkpoint_value(v_any)
|
||||
return decoded
|
||||
if isinstance(value, list):
|
||||
# After isinstance check, treat value as list[Any] for decoding
|
||||
value_list: list[Any] = value # type: ignore[assignment]
|
||||
return [decode_checkpoint_value(v_any) for v_any in value_list]
|
||||
return value
|
||||
|
||||
|
||||
def _instantiate_checkpoint_dataclass(cls: type[Any], payload: Any) -> Any | None:
|
||||
if not isinstance(cls, type):
|
||||
logger.debug(f"Checkpoint decoder received non-type dataclass reference: {cls!r}")
|
||||
return None
|
||||
|
||||
if isinstance(payload, dict):
|
||||
try:
|
||||
return cls(**payload) # type: ignore[arg-type]
|
||||
except TypeError as exc:
|
||||
logger.debug(f"Checkpoint decoder could not call {cls.__name__}(**payload): {exc}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"Checkpoint decoder encountered unexpected error calling {cls.__name__}(**payload): {exc}")
|
||||
try:
|
||||
instance = object.__new__(cls)
|
||||
except Exception as exc:
|
||||
logger.debug(f"Checkpoint decoder could not allocate {cls.__name__} without __init__: {exc}")
|
||||
return None
|
||||
for key, val in payload.items(): # type: ignore[attr-defined]
|
||||
try:
|
||||
setattr(instance, key, val) # type: ignore[arg-type]
|
||||
except Exception as exc:
|
||||
logger.debug(f"Checkpoint decoder could not set attribute {key} on {cls.__name__}: {exc}")
|
||||
return instance
|
||||
|
||||
try:
|
||||
return cls(payload) # type: ignore[call-arg]
|
||||
except TypeError as exc:
|
||||
logger.debug(f"Checkpoint decoder could not call {cls.__name__}({payload!r}): {exc}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"Checkpoint decoder encountered unexpected error calling {cls.__name__}({payload!r}): {exc}")
|
||||
return None
|
||||
|
||||
|
||||
def _supports_model_protocol(obj: object) -> bool:
|
||||
"""Detect objects that expose dictionary serialization hooks."""
|
||||
try:
|
||||
obj_type: type[Any] = type(obj)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
has_to_dict = hasattr(obj, "to_dict") and callable(getattr(obj, "to_dict", None)) # type: ignore[arg-type]
|
||||
has_from_dict = hasattr(obj_type, "from_dict") and callable(getattr(obj_type, "from_dict", None))
|
||||
|
||||
has_to_json = hasattr(obj, "to_json") and callable(getattr(obj, "to_json", None)) # type: ignore[arg-type]
|
||||
has_from_json = hasattr(obj_type, "from_json") and callable(getattr(obj_type, "from_json", None))
|
||||
|
||||
return (has_to_dict and has_from_dict) or (has_to_json and has_from_json)
|
||||
|
||||
|
||||
def _import_qualified_name(qualname: str) -> type[Any] | None:
|
||||
if ":" not in qualname:
|
||||
return None
|
||||
module_name, class_name = qualname.split(":", 1)
|
||||
module = sys.modules.get(module_name)
|
||||
if module is None:
|
||||
module = importlib.import_module(module_name)
|
||||
attr: Any = module
|
||||
for part in class_name.split("."):
|
||||
attr = getattr(attr, part)
|
||||
return attr if isinstance(attr, type) else None
|
||||
@@ -7,8 +7,8 @@ from textwrap import shorten
|
||||
from typing import Any
|
||||
|
||||
from ._checkpoint import WorkflowCheckpoint
|
||||
from ._checkpoint_encoding import decode_checkpoint_value
|
||||
from ._request_info_executor import PendingRequestDetails, RequestInfoMessage, RequestResponse
|
||||
from ._runner_context import _decode_checkpoint_value # type: ignore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -90,7 +90,7 @@ def _pending_requests_from_checkpoint(
|
||||
for message in message_list:
|
||||
if not isinstance(message, Mapping):
|
||||
continue
|
||||
payload = _decode_checkpoint_value(message.get("data"))
|
||||
payload = decode_checkpoint_value(message.get("data"))
|
||||
_merge_message_payload(pending, payload, message)
|
||||
|
||||
return list(pending.values())
|
||||
|
||||
@@ -13,7 +13,8 @@ from agent_framework import AgentProtocol, ChatMessage, Role
|
||||
from ._agent_executor import AgentExecutorRequest, AgentExecutorResponse
|
||||
from ._checkpoint import CheckpointStorage
|
||||
from ._executor import Executor, handler
|
||||
from ._workflow import Workflow, WorkflowBuilder
|
||||
from ._workflow import Workflow
|
||||
from ._workflow_builder import WorkflowBuilder
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -30,7 +30,8 @@ from ._events import WorkflowEvent
|
||||
from ._executor import Executor, handler
|
||||
from ._model_utils import DictConvertible, encode_value
|
||||
from ._request_info_executor import RequestInfoMessage, RequestResponse
|
||||
from ._workflow import Workflow, WorkflowBuilder, WorkflowRunResult
|
||||
from ._workflow import Workflow, WorkflowRunResult
|
||||
from ._workflow_builder import WorkflowBuilder
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
@@ -1071,7 +1072,7 @@ class MagenticOrchestratorExecutor(Executor):
|
||||
) -> None:
|
||||
if self._state_restored and self._context is not None:
|
||||
return
|
||||
state = await context.get_state()
|
||||
state = await context.get_executor_state()
|
||||
if not state:
|
||||
self._state_restored = True
|
||||
return
|
||||
@@ -1552,7 +1553,7 @@ class MagenticAgentExecutor(Executor):
|
||||
async def _ensure_state_restored(self, context: WorkflowContext[Any, Any]) -> None:
|
||||
if self._state_restored and self._chat_history:
|
||||
return
|
||||
state = await context.get_state()
|
||||
state = await context.get_executor_state()
|
||||
if not state:
|
||||
self._state_restored = True
|
||||
return
|
||||
|
||||
@@ -249,7 +249,7 @@ class RequestInfoExecutor(Executor):
|
||||
|
||||
async def _retrieve_existing_pending_requests(self, ctx: WorkflowContext) -> dict[str, PendingRequestSnapshot]:
|
||||
"""Retrieve existing pending requests from executor state."""
|
||||
executor_state = await ctx.get_state()
|
||||
executor_state = await ctx.get_executor_state()
|
||||
if executor_state is None:
|
||||
return {}
|
||||
|
||||
@@ -271,9 +271,9 @@ class RequestInfoExecutor(Executor):
|
||||
self, pending: dict[str, PendingRequestSnapshot], ctx: WorkflowContext
|
||||
) -> None:
|
||||
"""Persist the current pending requests to the executor's state."""
|
||||
executor_state = await ctx.get_state() or {}
|
||||
executor_state = await ctx.get_executor_state() or {}
|
||||
executor_state[self._PENDING_SHARED_STATE_KEY] = pending
|
||||
await ctx.set_state(executor_state)
|
||||
await ctx.set_executor_state(executor_state)
|
||||
|
||||
def _build_pending_request_snapshot(
|
||||
self, request: RequestInfoMessage, source_executor_id: str
|
||||
|
||||
@@ -7,17 +7,15 @@ from collections.abc import AsyncGenerator, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ._checkpoint import CheckpointStorage, WorkflowCheckpoint
|
||||
from ._checkpoint_encoding import DATACLASS_MARKER, MODEL_MARKER, decode_checkpoint_value
|
||||
from ._edge import EdgeGroup
|
||||
from ._edge_runner import EdgeRunner, create_edge_runner
|
||||
from ._events import WorkflowEvent
|
||||
from ._executor import Executor
|
||||
from ._runner_context import (
|
||||
_DATACLASS_MARKER, # type: ignore
|
||||
_MODEL_MARKER, # type: ignore
|
||||
CheckpointState,
|
||||
Message,
|
||||
RunnerContext,
|
||||
_decode_checkpoint_value, # type: ignore
|
||||
WorkflowState,
|
||||
)
|
||||
from ._shared_state import SharedState
|
||||
|
||||
@@ -168,10 +166,10 @@ class Runner:
|
||||
data = message.data
|
||||
if not isinstance(data, dict):
|
||||
return
|
||||
if _MODEL_MARKER not in data and _DATACLASS_MARKER not in data:
|
||||
if MODEL_MARKER not in data and DATACLASS_MARKER not in data:
|
||||
return
|
||||
try:
|
||||
decoded = _decode_checkpoint_value(data)
|
||||
decoded = decode_checkpoint_value(data)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.debug("Failed to decode checkpoint payload during delivery: %s", exc)
|
||||
return
|
||||
@@ -238,7 +236,7 @@ class Runner:
|
||||
logger.debug(f"Executor {exec_id} snapshot_state failed: {ex}")
|
||||
if state_dict is not None:
|
||||
try:
|
||||
await self._ctx.set_state(exec_id, state_dict)
|
||||
await self._ctx.set_executor_state(exec_id, state_dict)
|
||||
except Exception as ex: # pragma: no cover
|
||||
logger.debug(f"Failed to persist state for executor {exec_id}: {ex}")
|
||||
|
||||
@@ -247,7 +245,7 @@ class Runner:
|
||||
return
|
||||
|
||||
try:
|
||||
current_state = await self._ctx.get_checkpoint_state()
|
||||
current_state = await self._ctx.get_workflow_state()
|
||||
|
||||
shared_state_data = {}
|
||||
async with self._shared_state.hold():
|
||||
@@ -258,7 +256,7 @@ class Runner:
|
||||
current_state["iteration_count"] = self._iteration
|
||||
current_state["max_iterations"] = self._max_iterations
|
||||
|
||||
await self._ctx.set_checkpoint_state(current_state)
|
||||
await self._ctx.set_workflow_state(current_state)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to update context with shared state: {e}")
|
||||
|
||||
@@ -278,6 +276,7 @@ class Runner:
|
||||
True if restoration was successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Load the checkpoint
|
||||
checkpoint: WorkflowCheckpoint | None
|
||||
if self._ctx.has_checkpointing():
|
||||
checkpoint = await self._ctx.load_checkpoint(checkpoint_id)
|
||||
@@ -291,6 +290,7 @@ class Runner:
|
||||
logger.error(f"Checkpoint {checkpoint_id} not found")
|
||||
return False
|
||||
|
||||
# Validate the loaded checkpoint against the workflow
|
||||
graph_hash = getattr(self, "graph_signature_hash", None)
|
||||
checkpoint_hash = (checkpoint.metadata or {}).get("graph_signature")
|
||||
if graph_hash and checkpoint_hash and graph_hash != checkpoint_hash:
|
||||
@@ -306,8 +306,9 @@ class Runner:
|
||||
|
||||
await self._restore_executor_states(checkpoint.executor_states)
|
||||
|
||||
state = self._checkpoint_to_state(checkpoint)
|
||||
await self._ctx.set_checkpoint_state(state)
|
||||
state = _convert_checkpoint_to_workflow_state(checkpoint)
|
||||
await self._ctx.set_workflow_state(state)
|
||||
|
||||
if checkpoint.workflow_id:
|
||||
self._ctx.set_workflow_id(checkpoint.workflow_id)
|
||||
self._workflow_id = checkpoint.workflow_id
|
||||
@@ -348,7 +349,7 @@ class Runner:
|
||||
|
||||
async def _restore_shared_state_from_context(self) -> None:
|
||||
try:
|
||||
restored_state = await self._ctx.get_checkpoint_state()
|
||||
restored_state = await self._ctx.get_workflow_state()
|
||||
|
||||
shared_state_data = restored_state.get("shared_state", {})
|
||||
if shared_state_data and hasattr(self._shared_state, "_state"):
|
||||
@@ -362,16 +363,6 @@ class Runner:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to restore shared state from context: {e}")
|
||||
|
||||
@staticmethod
|
||||
def _checkpoint_to_state(checkpoint: WorkflowCheckpoint) -> CheckpointState:
|
||||
return {
|
||||
"messages": checkpoint.messages,
|
||||
"shared_state": checkpoint.shared_state,
|
||||
"executor_states": checkpoint.executor_states,
|
||||
"iteration_count": checkpoint.iteration_count,
|
||||
"max_iterations": checkpoint.max_iterations,
|
||||
}
|
||||
|
||||
def _parse_edge_runners(self, edge_runners: list[EdgeRunner]) -> dict[str, list[EdgeRunner]]:
|
||||
"""Parse the edge runners of the workflow into a mapping where each source executor ID maps to its edge runners.
|
||||
|
||||
@@ -421,3 +412,14 @@ class Runner:
|
||||
if executor.id == msg.target_id and isinstance(executor, RequestInfoExecutor):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _convert_checkpoint_to_workflow_state(checkpoint: WorkflowCheckpoint) -> WorkflowState:
|
||||
"""Helper function to convert a WorkflowCheckpoint to a WorkflowState."""
|
||||
return {
|
||||
"messages": checkpoint.messages,
|
||||
"shared_state": checkpoint.shared_state,
|
||||
"executor_states": checkpoint.executor_states,
|
||||
"iteration_count": checkpoint.iteration_count,
|
||||
"max_iterations": checkpoint.max_iterations,
|
||||
}
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import importlib
|
||||
import logging
|
||||
import sys
|
||||
import uuid
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, fields, is_dataclass
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol, TypedDict, TypeVar, cast, runtime_checkable
|
||||
|
||||
from ._checkpoint import CheckpointStorage, WorkflowCheckpoint
|
||||
from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value
|
||||
from ._const import DEFAULT_MAX_ITERATIONS
|
||||
from ._events import WorkflowEvent
|
||||
from ._shared_state import SharedState
|
||||
@@ -45,7 +43,12 @@ class Message:
|
||||
return self.source_span_ids[0] if self.source_span_ids else None
|
||||
|
||||
|
||||
class CheckpointState(TypedDict):
|
||||
class WorkflowState(TypedDict):
|
||||
"""TypedDict representing the serializable state of a workflow execution.
|
||||
|
||||
This includes all state data needed for checkpointing and restoration.
|
||||
"""
|
||||
|
||||
messages: dict[str, list[dict[str, Any]]]
|
||||
shared_state: dict[str, Any]
|
||||
executor_states: dict[str, dict[str, Any]]
|
||||
@@ -53,248 +56,6 @@ class CheckpointState(TypedDict):
|
||||
max_iterations: int
|
||||
|
||||
|
||||
# Checkpoint serialization helpers
|
||||
_MODEL_MARKER = "__af_model__"
|
||||
_DATACLASS_MARKER = "__af_dataclass__"
|
||||
_AF_MARKER = "__af__"
|
||||
|
||||
# Guards to prevent runaway recursion while encoding arbitrary user data
|
||||
_MAX_ENCODE_DEPTH = 100
|
||||
_CYCLE_SENTINEL = "<cycle>"
|
||||
|
||||
|
||||
def _instantiate_checkpoint_dataclass(cls: type[Any], payload: Any) -> Any | None:
|
||||
if not isinstance(cls, type):
|
||||
logger.debug(f"Checkpoint decoder received non-type dataclass reference: {cls!r}")
|
||||
return None
|
||||
|
||||
if isinstance(payload, dict):
|
||||
try:
|
||||
return cls(**payload) # type: ignore[arg-type]
|
||||
except TypeError as exc:
|
||||
logger.debug(f"Checkpoint decoder could not call {cls.__name__}(**payload): {exc}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"Checkpoint decoder encountered unexpected error calling {cls.__name__}(**payload): {exc}")
|
||||
try:
|
||||
instance = object.__new__(cls)
|
||||
except Exception as exc:
|
||||
logger.debug(f"Checkpoint decoder could not allocate {cls.__name__} without __init__: {exc}")
|
||||
return None
|
||||
for key, val in payload.items(): # type: ignore[attr-defined]
|
||||
try:
|
||||
setattr(instance, key, val) # type: ignore[arg-type]
|
||||
except Exception as exc:
|
||||
logger.debug(f"Checkpoint decoder could not set attribute {key} on {cls.__name__}: {exc}")
|
||||
return instance
|
||||
|
||||
try:
|
||||
return cls(payload) # type: ignore[call-arg]
|
||||
except TypeError as exc:
|
||||
logger.debug(f"Checkpoint decoder could not call {cls.__name__}({payload!r}): {exc}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"Checkpoint decoder encountered unexpected error calling {cls.__name__}({payload!r}): {exc}")
|
||||
return None
|
||||
|
||||
|
||||
def _supports_model_protocol(obj: object) -> bool:
|
||||
"""Detect objects that expose dictionary serialization hooks."""
|
||||
try:
|
||||
obj_type: type[Any] = type(obj)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
has_to_dict = hasattr(obj, "to_dict") and callable(getattr(obj, "to_dict", None)) # type: ignore[arg-type]
|
||||
has_from_dict = hasattr(obj_type, "from_dict") and callable(getattr(obj_type, "from_dict", None))
|
||||
|
||||
has_to_json = hasattr(obj, "to_json") and callable(getattr(obj, "to_json", None)) # type: ignore[arg-type]
|
||||
has_from_json = hasattr(obj_type, "from_json") and callable(getattr(obj_type, "from_json", None))
|
||||
|
||||
return (has_to_dict and has_from_dict) or (has_to_json and has_from_json)
|
||||
|
||||
|
||||
def _import_qualified_name(qualname: str) -> type[Any] | None:
|
||||
if ":" not in qualname:
|
||||
return None
|
||||
module_name, class_name = qualname.split(":", 1)
|
||||
module = sys.modules.get(module_name)
|
||||
if module is None:
|
||||
module = importlib.import_module(module_name)
|
||||
attr: Any = module
|
||||
for part in class_name.split("."):
|
||||
attr = getattr(attr, part)
|
||||
return attr if isinstance(attr, type) else None
|
||||
|
||||
|
||||
def _encode_checkpoint_value(value: Any) -> Any:
|
||||
"""Recursively encode values into JSON-serializable structures.
|
||||
|
||||
- Objects exposing to_dict/to_json -> { _MODEL_MARKER: "module:Class", value: encoded }
|
||||
- dataclass instances -> { _DATACLASS_MARKER: "module:Class", value: {field: encoded} }
|
||||
- dict -> encode keys as str and values recursively
|
||||
- list/tuple/set -> list of encoded items
|
||||
- other -> returned as-is if already JSON-serializable
|
||||
|
||||
Includes cycle and depth protection to avoid infinite recursion.
|
||||
"""
|
||||
|
||||
def _enc(v: Any, stack: set[int], depth: int) -> Any:
|
||||
# Depth guard
|
||||
if depth > _MAX_ENCODE_DEPTH:
|
||||
logger.debug(f"Max encode depth reached at depth={depth} for type={type(v)}")
|
||||
return "<max_depth>"
|
||||
|
||||
# Structured model handling (objects exposing to_dict/to_json)
|
||||
if _supports_model_protocol(v):
|
||||
cls = cast(type[Any], type(v)) # type: ignore
|
||||
try:
|
||||
if hasattr(v, "to_dict") and callable(getattr(v, "to_dict", None)):
|
||||
raw = v.to_dict() # type: ignore[attr-defined]
|
||||
strategy = "to_dict"
|
||||
elif hasattr(v, "to_json") and callable(getattr(v, "to_json", None)):
|
||||
serialized = v.to_json() # type: ignore[attr-defined]
|
||||
if isinstance(serialized, (bytes, bytearray)):
|
||||
try:
|
||||
serialized = serialized.decode()
|
||||
except Exception:
|
||||
serialized = serialized.decode(errors="replace")
|
||||
raw = serialized
|
||||
strategy = "to_json"
|
||||
else:
|
||||
raise AttributeError("Structured model lacks serialization hooks")
|
||||
return {
|
||||
_MODEL_MARKER: f"{cls.__module__}:{cls.__name__}",
|
||||
"strategy": strategy,
|
||||
"value": _enc(raw, stack, depth + 1),
|
||||
}
|
||||
except Exception as exc: # best-effort fallback
|
||||
logger.debug(f"Structured model serialization failed for {cls}: {exc}")
|
||||
return str(v)
|
||||
|
||||
# Dataclasses (instances only)
|
||||
if is_dataclass(v) and not isinstance(v, type):
|
||||
oid = id(v)
|
||||
if oid in stack:
|
||||
logger.debug("Cycle detected while encoding dataclass instance")
|
||||
return _CYCLE_SENTINEL
|
||||
stack.add(oid)
|
||||
try:
|
||||
# type(v) already narrows sufficiently; cast was redundant
|
||||
dc_cls: type[Any] = type(v)
|
||||
field_values: dict[str, Any] = {}
|
||||
for f in fields(v): # type: ignore[arg-type]
|
||||
field_values[f.name] = _enc(getattr(v, f.name), stack, depth + 1)
|
||||
return {
|
||||
_DATACLASS_MARKER: f"{dc_cls.__module__}:{dc_cls.__name__}",
|
||||
"value": field_values,
|
||||
}
|
||||
finally:
|
||||
stack.remove(oid)
|
||||
|
||||
# Collections
|
||||
if isinstance(v, dict):
|
||||
v_dict = cast("dict[object, object]", v)
|
||||
oid = id(v_dict)
|
||||
if oid in stack:
|
||||
logger.debug("Cycle detected while encoding dict")
|
||||
return _CYCLE_SENTINEL
|
||||
stack.add(oid)
|
||||
try:
|
||||
json_dict: dict[str, Any] = {}
|
||||
for k_any, val_any in v_dict.items(): # type: ignore[assignment]
|
||||
k_str: str = str(k_any)
|
||||
json_dict[k_str] = _enc(val_any, stack, depth + 1)
|
||||
return json_dict
|
||||
finally:
|
||||
stack.remove(oid)
|
||||
|
||||
if isinstance(v, (list, tuple, set)):
|
||||
iterable_v = cast("list[object] | tuple[object, ...] | set[object]", v)
|
||||
oid = id(iterable_v)
|
||||
if oid in stack:
|
||||
logger.debug("Cycle detected while encoding iterable")
|
||||
return _CYCLE_SENTINEL
|
||||
stack.add(oid)
|
||||
try:
|
||||
seq: list[object] = list(iterable_v)
|
||||
encoded_list: list[Any] = []
|
||||
for item in seq:
|
||||
encoded_list.append(_enc(item, stack, depth + 1))
|
||||
return encoded_list
|
||||
finally:
|
||||
stack.remove(oid)
|
||||
|
||||
# Primitives (or unknown objects): ensure JSON-serializable
|
||||
if isinstance(v, (str, int, float, bool)) or v is None:
|
||||
return v
|
||||
# Fallback: stringify unknown objects to avoid JSON serialization errors
|
||||
try:
|
||||
return str(v)
|
||||
except Exception:
|
||||
return f"<{type(v).__name__}>"
|
||||
|
||||
return _enc(value, set(), 0)
|
||||
|
||||
|
||||
def _decode_checkpoint_value(value: Any) -> Any:
|
||||
"""Recursively decode values previously encoded by _encode_checkpoint_value."""
|
||||
if isinstance(value, dict):
|
||||
value_dict = cast(dict[str, Any], value) # encoded form always uses string keys
|
||||
# Structured model marker handling
|
||||
if _MODEL_MARKER in value_dict and "value" in value_dict:
|
||||
type_key: str | None = value_dict.get(_MODEL_MARKER) # type: ignore[assignment]
|
||||
strategy: str | None = value_dict.get("strategy") # type: ignore[assignment]
|
||||
raw_encoded: Any = value_dict.get("value")
|
||||
decoded_payload = _decode_checkpoint_value(raw_encoded)
|
||||
if isinstance(type_key, str):
|
||||
try:
|
||||
cls = _import_qualified_name(type_key)
|
||||
except Exception as exc:
|
||||
logger.debug(f"Failed to import structured model {type_key}: {exc}")
|
||||
cls = None
|
||||
|
||||
if cls is not None:
|
||||
if strategy == "to_dict" and hasattr(cls, "from_dict"):
|
||||
with contextlib.suppress(Exception):
|
||||
return cls.from_dict(decoded_payload)
|
||||
if strategy == "to_json" and hasattr(cls, "from_json"):
|
||||
if isinstance(decoded_payload, (str, bytes, bytearray)):
|
||||
with contextlib.suppress(Exception):
|
||||
return cls.from_json(decoded_payload)
|
||||
if isinstance(decoded_payload, dict) and hasattr(cls, "from_dict"):
|
||||
with contextlib.suppress(Exception):
|
||||
return cls.from_dict(decoded_payload)
|
||||
return decoded_payload
|
||||
# Dataclass marker handling
|
||||
if _DATACLASS_MARKER in value_dict and "value" in value_dict:
|
||||
type_key_dc: str | None = value_dict.get(_DATACLASS_MARKER) # type: ignore[assignment]
|
||||
raw_dc: Any = value_dict.get("value")
|
||||
decoded_raw = _decode_checkpoint_value(raw_dc)
|
||||
if isinstance(type_key_dc, str):
|
||||
try:
|
||||
module_name, class_name = type_key_dc.split(":", 1)
|
||||
module = sys.modules.get(module_name)
|
||||
if module is None:
|
||||
module = importlib.import_module(module_name)
|
||||
cls_dc: Any = getattr(module, class_name)
|
||||
constructed = _instantiate_checkpoint_dataclass(cls_dc, decoded_raw)
|
||||
if constructed is not None:
|
||||
return constructed
|
||||
except Exception as exc:
|
||||
logger.debug(f"Failed to decode dataclass {type_key_dc}: {exc}; returning raw value")
|
||||
return decoded_raw
|
||||
|
||||
# Regular dict: decode recursively
|
||||
decoded: dict[str, Any] = {}
|
||||
for k_any, v_any in value_dict.items():
|
||||
decoded[k_any] = _decode_checkpoint_value(v_any)
|
||||
return decoded
|
||||
if isinstance(value, list):
|
||||
# After isinstance check, treat value as list[Any] for decoding
|
||||
value_list: list[Any] = value # type: ignore[assignment]
|
||||
return [_decode_checkpoint_value(v_any) for v_any in value_list]
|
||||
return value
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class RunnerContext(Protocol):
|
||||
"""Protocol for the execution context used by the runner.
|
||||
@@ -355,7 +116,7 @@ class RunnerContext(Protocol):
|
||||
"""Wait for and return the next event emitted by the workflow run."""
|
||||
...
|
||||
|
||||
async def set_state(self, executor_id: str, state: dict[str, Any]) -> None:
|
||||
async def set_executor_state(self, executor_id: str, state: dict[str, Any]) -> None:
|
||||
"""Set the state for a specific executor.
|
||||
|
||||
Args:
|
||||
@@ -364,7 +125,7 @@ class RunnerContext(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
async def get_state(self, executor_id: str) -> dict[str, Any] | None:
|
||||
async def get_executor_state(self, executor_id: str) -> dict[str, Any] | None:
|
||||
"""Get the state for a specific executor.
|
||||
|
||||
Args:
|
||||
@@ -417,30 +178,19 @@ class RunnerContext(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
async def restore_from_checkpoint(self, checkpoint_id: str) -> bool:
|
||||
"""Restore the context from a checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_id: The ID of the checkpoint to restore from.
|
||||
|
||||
Returns:
|
||||
True if the restoration was successful, False otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None:
|
||||
"""Load a checkpoint without mutating the current context state."""
|
||||
...
|
||||
|
||||
async def get_checkpoint_state(self) -> CheckpointState:
|
||||
"""Get the current state of the context suitable for checkpointing."""
|
||||
async def get_workflow_state(self) -> WorkflowState:
|
||||
"""Get the current state of the workflow suitable for checkpointing."""
|
||||
...
|
||||
|
||||
async def set_checkpoint_state(self, state: CheckpointState) -> None:
|
||||
"""Set the state of the context from a checkpoint.
|
||||
async def set_workflow_state(self, state: WorkflowState) -> None:
|
||||
"""Set the state of the workflow from a checkpoint.
|
||||
|
||||
Args:
|
||||
state: The state data to set for the context.
|
||||
state: The state data to set for the workflow.
|
||||
"""
|
||||
...
|
||||
|
||||
@@ -509,10 +259,10 @@ class InProcRunnerContext:
|
||||
"""
|
||||
return await self._event_queue.get()
|
||||
|
||||
async def set_state(self, executor_id: str, state: dict[str, Any]) -> None:
|
||||
async def set_executor_state(self, executor_id: str, state: dict[str, Any]) -> None:
|
||||
self._executor_states[executor_id] = state
|
||||
|
||||
async def get_state(self, executor_id: str) -> dict[str, Any] | None:
|
||||
async def get_executor_state(self, executor_id: str) -> dict[str, Any] | None:
|
||||
return self._executor_states.get(executor_id)
|
||||
|
||||
def has_checkpointing(self) -> bool:
|
||||
@@ -554,7 +304,7 @@ class InProcRunnerContext:
|
||||
|
||||
wf_id = self._workflow_id or str(uuid.uuid4())
|
||||
self._workflow_id = wf_id
|
||||
state = await self.get_checkpoint_state()
|
||||
state = await self.get_workflow_state()
|
||||
|
||||
checkpoint = WorkflowCheckpoint(
|
||||
workflow_id=wf_id,
|
||||
@@ -569,38 +319,17 @@ class InProcRunnerContext:
|
||||
logger.info(f"Created checkpoint {checkpoint_id} for workflow {wf_id}'")
|
||||
return checkpoint_id
|
||||
|
||||
async def restore_from_checkpoint(self, checkpoint_id: str) -> bool:
|
||||
if not self._checkpoint_storage:
|
||||
raise ValueError("Checkpoint storage not configured")
|
||||
|
||||
checkpoint = await self._checkpoint_storage.load_checkpoint(checkpoint_id)
|
||||
if not checkpoint:
|
||||
logger.error(f"Checkpoint {checkpoint_id} not found")
|
||||
return False
|
||||
|
||||
state: CheckpointState = {
|
||||
"messages": checkpoint.messages,
|
||||
"shared_state": checkpoint.shared_state,
|
||||
"executor_states": checkpoint.executor_states,
|
||||
"iteration_count": checkpoint.iteration_count,
|
||||
"max_iterations": checkpoint.max_iterations,
|
||||
}
|
||||
await self.set_checkpoint_state(state)
|
||||
self._workflow_id = checkpoint.workflow_id
|
||||
logger.info(f"Restored state from checkpoint {checkpoint_id}'")
|
||||
return True
|
||||
|
||||
async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None:
|
||||
if not self._checkpoint_storage:
|
||||
raise ValueError("Checkpoint storage not configured")
|
||||
return await self._checkpoint_storage.load_checkpoint(checkpoint_id)
|
||||
|
||||
async def get_checkpoint_state(self) -> CheckpointState:
|
||||
async def get_workflow_state(self) -> WorkflowState:
|
||||
serializable_messages: dict[str, list[dict[str, Any]]] = {}
|
||||
for source_id, message_list in self._messages.items():
|
||||
serializable_messages[source_id] = [
|
||||
{
|
||||
"data": _encode_checkpoint_value(msg.data),
|
||||
"data": encode_checkpoint_value(msg.data),
|
||||
"source_id": msg.source_id,
|
||||
"target_id": msg.target_id,
|
||||
"trace_contexts": msg.trace_contexts,
|
||||
@@ -608,21 +337,22 @@ class InProcRunnerContext:
|
||||
}
|
||||
for msg in message_list
|
||||
]
|
||||
|
||||
return {
|
||||
"messages": serializable_messages,
|
||||
"shared_state": _encode_checkpoint_value(self._shared_state),
|
||||
"executor_states": _encode_checkpoint_value(self._executor_states),
|
||||
"shared_state": encode_checkpoint_value(self._shared_state),
|
||||
"executor_states": encode_checkpoint_value(self._executor_states),
|
||||
"iteration_count": self._iteration_count,
|
||||
"max_iterations": self._max_iterations,
|
||||
}
|
||||
|
||||
async def set_checkpoint_state(self, state: CheckpointState) -> None:
|
||||
async def set_workflow_state(self, state: WorkflowState) -> None:
|
||||
self._messages.clear()
|
||||
messages_data = state.get("messages", {})
|
||||
for source_id, message_list in messages_data.items():
|
||||
self._messages[source_id] = [
|
||||
Message(
|
||||
data=_decode_checkpoint_value(msg.get("data")),
|
||||
data=decode_checkpoint_value(msg.get("data")),
|
||||
source_id=msg.get("source_id", ""),
|
||||
target_id=msg.get("target_id"),
|
||||
trace_contexts=msg.get("trace_contexts"),
|
||||
@@ -631,14 +361,14 @@ class InProcRunnerContext:
|
||||
for msg in message_list
|
||||
]
|
||||
# Restore shared_state
|
||||
decoded_shared_raw = _decode_checkpoint_value(state.get("shared_state", {}))
|
||||
decoded_shared_raw = decode_checkpoint_value(state.get("shared_state", {}))
|
||||
if isinstance(decoded_shared_raw, dict):
|
||||
self._shared_state = cast(dict[str, Any], decoded_shared_raw)
|
||||
else: # fallback to empty dict if corrupted
|
||||
self._shared_state = {}
|
||||
|
||||
# Restore executor_states ensuring value types are dicts
|
||||
decoded_exec_raw = _decode_checkpoint_value(state.get("executor_states", {}))
|
||||
decoded_exec_raw = decode_checkpoint_value(state.get("executor_states", {}))
|
||||
if isinstance(decoded_exec_raw, dict):
|
||||
typed_exec: dict[str, dict[str, Any]] = {}
|
||||
for k_raw, v_raw in decoded_exec_raw.items(): # type: ignore[assignment]
|
||||
|
||||
@@ -51,7 +51,8 @@ from ._executor import (
|
||||
Executor,
|
||||
handler,
|
||||
)
|
||||
from ._workflow import Workflow, WorkflowBuilder
|
||||
from ._workflow import Workflow
|
||||
from ._workflow_builder import WorkflowBuilder
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -6,25 +6,16 @@ import json
|
||||
import logging
|
||||
import sys
|
||||
import uuid
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, Sequence
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from .._agents import AgentProtocol
|
||||
from ..observability import OtelAttr, capture_exception, create_workflow_span
|
||||
from ._agent import WorkflowAgent
|
||||
from ._agent_executor import AgentExecutor
|
||||
from ._checkpoint import CheckpointStorage
|
||||
from ._const import DEFAULT_MAX_ITERATIONS
|
||||
from ._edge import (
|
||||
Case,
|
||||
Default,
|
||||
EdgeGroup,
|
||||
FanInEdgeGroup,
|
||||
FanOutEdgeGroup,
|
||||
SingleEdgeGroup,
|
||||
SwitchCaseEdgeGroup,
|
||||
SwitchCaseEdgeGroupCase,
|
||||
SwitchCaseEdgeGroupDefault,
|
||||
)
|
||||
from ._events import (
|
||||
RequestInfoEvent,
|
||||
@@ -41,15 +32,14 @@ from ._executor import Executor
|
||||
from ._model_utils import DictConvertible
|
||||
from ._request_info_executor import RequestInfoExecutor
|
||||
from ._runner import Runner
|
||||
from ._runner_context import InProcRunnerContext, RunnerContext
|
||||
from ._runner_context import RunnerContext
|
||||
from ._shared_state import SharedState
|
||||
from ._validation import validate_workflow_graph
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self # pragma: no cover
|
||||
pass # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import Self # pragma: no cover
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -858,422 +848,3 @@ class Workflow(DictConvertible):
|
||||
from ._agent import WorkflowAgent
|
||||
|
||||
return WorkflowAgent(workflow=self, name=name)
|
||||
|
||||
|
||||
# region WorkflowBuilder
|
||||
|
||||
|
||||
class WorkflowBuilder:
|
||||
"""A builder class for constructing workflows.
|
||||
|
||||
This class provides methods to add edges and set the starting executor for the workflow.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_iterations: int = DEFAULT_MAX_ITERATIONS,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
):
|
||||
"""Initialize the WorkflowBuilder with an empty list of edges and no starting executor.
|
||||
|
||||
Args:
|
||||
max_iterations: Maximum number of iterations for workflow convergence.
|
||||
name: Optional human-readable name for the workflow.
|
||||
description: Optional description of what the workflow does.
|
||||
"""
|
||||
self._edge_groups: list[EdgeGroup] = []
|
||||
self._executors: dict[str, Executor] = {}
|
||||
self._duplicate_executor_ids: set[str] = set()
|
||||
self._start_executor: Executor | str | None = None
|
||||
self._checkpoint_storage: CheckpointStorage | None = None
|
||||
self._max_iterations: int = max_iterations
|
||||
self._name: str | None = name
|
||||
self._description: str | None = description
|
||||
# Maps underlying AgentProtocol object id -> wrapped Executor so we reuse the same wrapper
|
||||
# across set_start_executor / add_edge calls. Without this, unnamed agents (which receive
|
||||
# random UUID based executor ids) end up wrapped multiple times, giving different ids for
|
||||
# the start node vs edge nodes and triggering a GraphConnectivityError during validation.
|
||||
self._agent_wrappers: dict[int, Executor] = {}
|
||||
|
||||
# Agents auto-wrapped by builder now always stream incremental updates.
|
||||
|
||||
def _add_executor(self, executor: Executor) -> str:
|
||||
"""Add an executor to the map and return its ID."""
|
||||
existing = self._executors.get(executor.id)
|
||||
if existing is not None and existing is not executor:
|
||||
self._duplicate_executor_ids.add(executor.id)
|
||||
else:
|
||||
self._executors[executor.id] = executor
|
||||
return executor.id
|
||||
|
||||
def _maybe_wrap_agent(
|
||||
self,
|
||||
candidate: Executor | AgentProtocol,
|
||||
agent_thread: Any | None = None,
|
||||
output_response: bool = False,
|
||||
executor_id: str | None = None,
|
||||
) -> Executor:
|
||||
"""If the provided object implements AgentProtocol, wrap it in an AgentExecutor.
|
||||
|
||||
This allows fluent builder APIs to directly accept agents instead of
|
||||
requiring callers to manually instantiate AgentExecutor.
|
||||
|
||||
Args:
|
||||
candidate: The executor or agent to wrap.
|
||||
agent_thread: The thread to use for running the agent. If None, a new thread will be created.
|
||||
output_response: Whether to yield an AgentRunResponse as a workflow output when the agent completes.
|
||||
executor_id: A unique identifier for the executor. If None, the agent's name will be used if available.
|
||||
"""
|
||||
try: # Local import to avoid hard dependency at import time
|
||||
from agent_framework import AgentProtocol # type: ignore
|
||||
except Exception: # pragma: no cover - defensive
|
||||
AgentProtocol = object # type: ignore
|
||||
|
||||
if isinstance(candidate, Executor): # Already an executor
|
||||
return candidate
|
||||
if isinstance(candidate, AgentProtocol): # type: ignore[arg-type]
|
||||
# Reuse existing wrapper for the same agent instance if present
|
||||
agent_instance_id = id(candidate)
|
||||
existing = self._agent_wrappers.get(agent_instance_id)
|
||||
if existing is not None:
|
||||
return existing
|
||||
# Use agent name if available and unique among current executors
|
||||
name = getattr(candidate, "name", None)
|
||||
proposed_id: str | None = executor_id
|
||||
if proposed_id is None and name:
|
||||
proposed_id = str(name)
|
||||
if proposed_id in self._executors:
|
||||
raise ValueError(
|
||||
f"Duplicate executor ID '{proposed_id}' from agent name. "
|
||||
"Agent names must be unique within a workflow."
|
||||
)
|
||||
wrapper = AgentExecutor(
|
||||
candidate,
|
||||
agent_thread=agent_thread,
|
||||
output_response=output_response,
|
||||
id=proposed_id,
|
||||
)
|
||||
self._agent_wrappers[agent_instance_id] = wrapper
|
||||
return wrapper
|
||||
raise TypeError(
|
||||
f"WorkflowBuilder expected an Executor or AgentProtocol instance; got {type(candidate).__name__}."
|
||||
)
|
||||
|
||||
def add_agent(
|
||||
self,
|
||||
agent: AgentProtocol,
|
||||
agent_thread: Any | None = None,
|
||||
output_response: bool = False,
|
||||
id: str | None = None,
|
||||
) -> Self:
|
||||
"""Add an agent to the workflow by wrapping it in an AgentExecutor.
|
||||
|
||||
This method creates an AgentExecutor that wraps the agent with the given parameters
|
||||
and ensures that subsequent uses of the same agent instance in other builder methods
|
||||
(like add_edge, set_start_executor, etc.) will reuse the same wrapped executor.
|
||||
|
||||
Note: Agents adapt their behavior based on how the workflow is executed:
|
||||
- run_stream(): Agents emit incremental AgentRunUpdateEvent events as tokens are produced
|
||||
- run(): Agents emit a single AgentRunEvent containing the complete response
|
||||
|
||||
Args:
|
||||
agent: The agent to add to the workflow.
|
||||
agent_thread: The thread to use for running the agent. If None, a new thread will be created.
|
||||
output_response: Whether to yield an AgentRunResponse as a workflow output when the agent completes.
|
||||
id: A unique identifier for the executor. If None, the agent's name will be used if available.
|
||||
|
||||
Returns:
|
||||
The WorkflowBuilder instance (for method chaining).
|
||||
|
||||
Raises:
|
||||
ValueError: If the provided id or agent name conflicts with an existing executor.
|
||||
"""
|
||||
executor = self._maybe_wrap_agent(
|
||||
agent, agent_thread=agent_thread, output_response=output_response, executor_id=id
|
||||
)
|
||||
self._add_executor(executor)
|
||||
return self
|
||||
|
||||
def add_edge(
|
||||
self,
|
||||
source: Executor | AgentProtocol,
|
||||
target: Executor | AgentProtocol,
|
||||
condition: Callable[[Any], bool] | None = None,
|
||||
) -> Self:
|
||||
"""Add a directed edge between two executors.
|
||||
|
||||
The output types of the source and the input types of the target must be compatible.
|
||||
|
||||
Args:
|
||||
source: The source executor of the edge.
|
||||
target: The target executor of the edge.
|
||||
condition: An optional condition function that determines whether the edge
|
||||
should be traversed based on the message type.
|
||||
"""
|
||||
# TODO(@taochen): Support executor factories for lazy initialization
|
||||
source_exec = self._maybe_wrap_agent(source)
|
||||
target_exec = self._maybe_wrap_agent(target)
|
||||
source_id = self._add_executor(source_exec)
|
||||
target_id = self._add_executor(target_exec)
|
||||
self._edge_groups.append(SingleEdgeGroup(source_id, target_id, condition)) # type: ignore[call-arg]
|
||||
return self
|
||||
|
||||
def add_fan_out_edges(
|
||||
self,
|
||||
source: Executor | AgentProtocol,
|
||||
targets: Sequence[Executor | AgentProtocol],
|
||||
) -> Self:
|
||||
"""Add multiple edges to the workflow where messages from the source will be sent to all target.
|
||||
|
||||
The output types of the source and the input types of the targets must be compatible.
|
||||
|
||||
Args:
|
||||
source: The source executor of the edges.
|
||||
targets: A list of target executors for the edges.
|
||||
"""
|
||||
source_exec = self._maybe_wrap_agent(source)
|
||||
target_execs = [self._maybe_wrap_agent(t) for t in targets]
|
||||
source_id = self._add_executor(source_exec)
|
||||
target_ids = [self._add_executor(t) for t in target_execs]
|
||||
self._edge_groups.append(FanOutEdgeGroup(source_id, target_ids)) # type: ignore[call-arg]
|
||||
|
||||
return self
|
||||
|
||||
def add_switch_case_edge_group(
|
||||
self,
|
||||
source: Executor | AgentProtocol,
|
||||
cases: Sequence[Case | Default],
|
||||
) -> Self:
|
||||
"""Add an edge group that represents a switch-case statement.
|
||||
|
||||
The output types of the source and the input types of the targets must be compatible.
|
||||
Messages from the source executor will be sent to one of the target executors based on
|
||||
the provided conditions.
|
||||
|
||||
Think of this as a switch statement where each target executor corresponds to a case.
|
||||
Each condition function will be evaluated in order, and the first one that returns True
|
||||
will determine which target executor receives the message.
|
||||
|
||||
The last case (the default case) will receive messages that fall through all conditions
|
||||
(i.e., no condition matched).
|
||||
|
||||
Args:
|
||||
source: The source executor of the edges.
|
||||
cases: A list of case objects that determine the target executor for each message.
|
||||
"""
|
||||
source_exec = self._maybe_wrap_agent(source)
|
||||
source_id = self._add_executor(source_exec)
|
||||
# Convert case data types to internal types that only uses target_id.
|
||||
internal_cases: list[SwitchCaseEdgeGroupCase | SwitchCaseEdgeGroupDefault] = []
|
||||
for case in cases:
|
||||
# Allow case targets to be agents
|
||||
case.target = self._maybe_wrap_agent(case.target) # type: ignore[attr-defined]
|
||||
self._add_executor(case.target)
|
||||
if isinstance(case, Default):
|
||||
internal_cases.append(SwitchCaseEdgeGroupDefault(target_id=case.target.id))
|
||||
else:
|
||||
internal_cases.append(SwitchCaseEdgeGroupCase(condition=case.condition, target_id=case.target.id))
|
||||
self._edge_groups.append(SwitchCaseEdgeGroup(source_id, internal_cases)) # type: ignore[call-arg]
|
||||
|
||||
return self
|
||||
|
||||
def add_multi_selection_edge_group(
|
||||
self,
|
||||
source: Executor | AgentProtocol,
|
||||
targets: Sequence[Executor | AgentProtocol],
|
||||
selection_func: Callable[[Any, list[str]], list[str]],
|
||||
) -> Self:
|
||||
"""Add an edge group that represents a multi-selection execution model.
|
||||
|
||||
The output types of the source and the input types of the targets must be compatible.
|
||||
Messages from the source executor will be sent to multiple target executors based on
|
||||
the provided selection function.
|
||||
|
||||
The selection function should take a message and the name of the target executors,
|
||||
and return a list of indices indicating which target executors should receive the message.
|
||||
|
||||
Args:
|
||||
source: The source executor of the edges.
|
||||
targets: A list of target executors for the edges.
|
||||
selection_func: A function that selects target executors for messages.
|
||||
"""
|
||||
source_exec = self._maybe_wrap_agent(source)
|
||||
target_execs = [self._maybe_wrap_agent(t) for t in targets]
|
||||
source_id = self._add_executor(source_exec)
|
||||
target_ids = [self._add_executor(t) for t in target_execs]
|
||||
self._edge_groups.append(FanOutEdgeGroup(source_id, target_ids, selection_func)) # type: ignore[call-arg]
|
||||
|
||||
return self
|
||||
|
||||
def add_fan_in_edges(
|
||||
self,
|
||||
sources: Sequence[Executor | AgentProtocol],
|
||||
target: Executor | AgentProtocol,
|
||||
) -> Self:
|
||||
"""Add multiple edges from sources to a single target executor.
|
||||
|
||||
The edges will be grouped together for synchronized processing, meaning
|
||||
the target executor will only be executed once all source executors have completed.
|
||||
|
||||
The target executor will receive a list of messages aggregated from all source executors.
|
||||
Thus the input types of the target executor must be compatible with a list of the output
|
||||
types of the source executors. For example:
|
||||
|
||||
class Target(Executor):
|
||||
@handler
|
||||
def handle_messages(self, messages: list[Message]) -> None:
|
||||
# Process the aggregated messages from all sources
|
||||
|
||||
class Source(Executor):
|
||||
@handler(output_type=[Message])
|
||||
def handle_message(self, message: Message) -> None:
|
||||
# Send a message to the target executor
|
||||
self.send_message(message)
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder()
|
||||
.add_fan_in_edges(
|
||||
[Source(id="source1"), Source(id="source2")],
|
||||
Target(id="target")
|
||||
)
|
||||
.build()
|
||||
)
|
||||
|
||||
Args:
|
||||
sources: A list of source executors for the edges.
|
||||
target: The target executor for the edges.
|
||||
"""
|
||||
source_execs = [self._maybe_wrap_agent(s) for s in sources]
|
||||
target_exec = self._maybe_wrap_agent(target)
|
||||
source_ids = [self._add_executor(s) for s in source_execs]
|
||||
target_id = self._add_executor(target_exec)
|
||||
self._edge_groups.append(FanInEdgeGroup(source_ids, target_id)) # type: ignore[call-arg]
|
||||
|
||||
return self
|
||||
|
||||
def add_chain(self, executors: Sequence[Executor | AgentProtocol]) -> Self:
|
||||
"""Add a chain of executors to the workflow.
|
||||
|
||||
The output of each executor in the chain will be sent to the next executor in the chain.
|
||||
The input types of each executor must be compatible with the output types of the previous executor.
|
||||
|
||||
Circles in the chain are not allowed, meaning the chain cannot have two executors with the same ID.
|
||||
|
||||
Args:
|
||||
executors: A list of executors to be added to the chain.
|
||||
"""
|
||||
# Wrap each candidate first to ensure stable IDs before adding edges
|
||||
wrapped: list[Executor] = [self._maybe_wrap_agent(e) for e in executors]
|
||||
for i in range(len(wrapped) - 1):
|
||||
self.add_edge(wrapped[i], wrapped[i + 1])
|
||||
return self
|
||||
|
||||
def set_start_executor(self, executor: Executor | AgentProtocol | str) -> Self:
|
||||
"""Set the starting executor for the workflow.
|
||||
|
||||
Args:
|
||||
executor: The starting executor, which can be an Executor instance or its ID.
|
||||
"""
|
||||
if isinstance(executor, str):
|
||||
self._start_executor = executor
|
||||
else:
|
||||
wrapped = self._maybe_wrap_agent(executor) # type: ignore[arg-type]
|
||||
self._start_executor = wrapped
|
||||
# Ensure the start executor is present in the executor map so validation succeeds
|
||||
# even if no edges are added yet, or before edges wrap the same agent again.
|
||||
existing = self._executors.get(wrapped.id)
|
||||
if existing is not wrapped:
|
||||
self._add_executor(wrapped)
|
||||
return self
|
||||
|
||||
def set_max_iterations(self, max_iterations: int) -> Self:
|
||||
"""Set the maximum number of iterations for the workflow.
|
||||
|
||||
Args:
|
||||
max_iterations: The maximum number of iterations the workflow will run for convergence.
|
||||
"""
|
||||
self._max_iterations = max_iterations
|
||||
return self
|
||||
|
||||
# Removed explicit set_agent_streaming() API; agents always stream updates.
|
||||
|
||||
def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> Self:
|
||||
"""Enable checkpointing with the specified storage.
|
||||
|
||||
Args:
|
||||
checkpoint_storage: The checkpoint storage to use.
|
||||
"""
|
||||
self._checkpoint_storage = checkpoint_storage
|
||||
return self
|
||||
|
||||
def build(self) -> Workflow:
|
||||
"""Build and return the constructed workflow.
|
||||
|
||||
This method performs validation before building the workflow.
|
||||
|
||||
Returns:
|
||||
A Workflow instance with the defined edges and starting executor.
|
||||
|
||||
Raises:
|
||||
ValueError: If starting executor is not set.
|
||||
WorkflowValidationError: If workflow validation fails (includes EdgeDuplicationError,
|
||||
TypeCompatibilityError, and GraphConnectivityError subclasses).
|
||||
"""
|
||||
# Create workflow build span that includes validation and workflow creation
|
||||
with create_workflow_span(OtelAttr.WORKFLOW_BUILD_SPAN) as span:
|
||||
try:
|
||||
# Add workflow build started event
|
||||
span.add_event(OtelAttr.BUILD_STARTED)
|
||||
|
||||
if not self._start_executor:
|
||||
raise ValueError(
|
||||
"Starting executor must be set using set_start_executor before building the workflow."
|
||||
)
|
||||
|
||||
# Perform validation before creating the workflow
|
||||
validate_workflow_graph(
|
||||
self._edge_groups,
|
||||
self._executors,
|
||||
self._start_executor,
|
||||
duplicate_executor_ids=tuple(self._duplicate_executor_ids),
|
||||
)
|
||||
|
||||
# Add validation completed event
|
||||
span.add_event(OtelAttr.BUILD_VALIDATION_COMPLETED)
|
||||
|
||||
context = InProcRunnerContext(self._checkpoint_storage)
|
||||
|
||||
# Create workflow instance after validation
|
||||
workflow = Workflow(
|
||||
self._edge_groups,
|
||||
self._executors,
|
||||
self._start_executor,
|
||||
context,
|
||||
self._max_iterations,
|
||||
name=self._name,
|
||||
description=self._description,
|
||||
)
|
||||
build_attributes: dict[str, Any] = {
|
||||
OtelAttr.WORKFLOW_ID: workflow.id,
|
||||
OtelAttr.WORKFLOW_DEFINITION: workflow.to_json(),
|
||||
}
|
||||
if workflow.name:
|
||||
build_attributes[OtelAttr.WORKFLOW_NAME] = workflow.name
|
||||
if workflow.description:
|
||||
build_attributes[OtelAttr.WORKFLOW_DESCRIPTION] = workflow.description
|
||||
span.set_attributes(build_attributes)
|
||||
|
||||
# Add workflow build completed event
|
||||
span.add_event(OtelAttr.BUILD_COMPLETED)
|
||||
|
||||
return workflow
|
||||
|
||||
except Exception as exc:
|
||||
attributes = {
|
||||
OtelAttr.BUILD_ERROR_MESSAGE: str(exc),
|
||||
OtelAttr.BUILD_ERROR_TYPE: type(exc).__name__,
|
||||
}
|
||||
span.add_event(OtelAttr.BUILD_ERROR, attributes) # type: ignore[reportArgumentType, arg-type]
|
||||
capture_exception(span, exc)
|
||||
raise
|
||||
|
||||
@@ -0,0 +1,451 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
|
||||
from .._agents import AgentProtocol
|
||||
from ..observability import OtelAttr, capture_exception, create_workflow_span
|
||||
from ._agent_executor import AgentExecutor
|
||||
from ._checkpoint import CheckpointStorage
|
||||
from ._const import DEFAULT_MAX_ITERATIONS
|
||||
from ._edge import (
|
||||
Case,
|
||||
Default,
|
||||
EdgeGroup,
|
||||
FanInEdgeGroup,
|
||||
FanOutEdgeGroup,
|
||||
SingleEdgeGroup,
|
||||
SwitchCaseEdgeGroup,
|
||||
SwitchCaseEdgeGroupCase,
|
||||
SwitchCaseEdgeGroupDefault,
|
||||
)
|
||||
from ._executor import Executor
|
||||
from ._runner_context import InProcRunnerContext
|
||||
from ._validation import validate_workflow_graph
|
||||
from ._workflow import Workflow
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import Self # pragma: no cover
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowBuilder:
|
||||
"""A builder class for constructing workflows.
|
||||
|
||||
This class provides methods to add edges and set the starting executor for the workflow.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_iterations: int = DEFAULT_MAX_ITERATIONS,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
):
|
||||
"""Initialize the WorkflowBuilder with an empty list of edges and no starting executor.
|
||||
|
||||
Args:
|
||||
max_iterations: Maximum number of iterations for workflow convergence.
|
||||
name: Optional human-readable name for the workflow.
|
||||
description: Optional description of what the workflow does.
|
||||
"""
|
||||
self._edge_groups: list[EdgeGroup] = []
|
||||
self._executors: dict[str, Executor] = {}
|
||||
self._duplicate_executor_ids: set[str] = set()
|
||||
self._start_executor: Executor | str | None = None
|
||||
self._checkpoint_storage: CheckpointStorage | None = None
|
||||
self._max_iterations: int = max_iterations
|
||||
self._name: str | None = name
|
||||
self._description: str | None = description
|
||||
# Maps underlying AgentProtocol object id -> wrapped Executor so we reuse the same wrapper
|
||||
# across set_start_executor / add_edge calls. Without this, unnamed agents (which receive
|
||||
# random UUID based executor ids) end up wrapped multiple times, giving different ids for
|
||||
# the start node vs edge nodes and triggering a GraphConnectivityError during validation.
|
||||
self._agent_wrappers: dict[int, Executor] = {}
|
||||
|
||||
# Agents auto-wrapped by builder now always stream incremental updates.
|
||||
|
||||
def _add_executor(self, executor: Executor) -> str:
|
||||
"""Add an executor to the map and return its ID."""
|
||||
existing = self._executors.get(executor.id)
|
||||
if existing is not None and existing is not executor:
|
||||
self._duplicate_executor_ids.add(executor.id)
|
||||
else:
|
||||
self._executors[executor.id] = executor
|
||||
return executor.id
|
||||
|
||||
def _maybe_wrap_agent(
|
||||
self,
|
||||
candidate: Executor | AgentProtocol,
|
||||
agent_thread: Any | None = None,
|
||||
output_response: bool = False,
|
||||
executor_id: str | None = None,
|
||||
) -> Executor:
|
||||
"""If the provided object implements AgentProtocol, wrap it in an AgentExecutor.
|
||||
|
||||
This allows fluent builder APIs to directly accept agents instead of
|
||||
requiring callers to manually instantiate AgentExecutor.
|
||||
|
||||
Args:
|
||||
candidate: The executor or agent to wrap.
|
||||
agent_thread: The thread to use for running the agent. If None, a new thread will be created.
|
||||
output_response: Whether to yield an AgentRunResponse as a workflow output when the agent completes.
|
||||
executor_id: A unique identifier for the executor. If None, the agent's name will be used if available.
|
||||
"""
|
||||
try: # Local import to avoid hard dependency at import time
|
||||
from agent_framework import AgentProtocol # type: ignore
|
||||
except Exception: # pragma: no cover - defensive
|
||||
AgentProtocol = object # type: ignore
|
||||
|
||||
if isinstance(candidate, Executor): # Already an executor
|
||||
return candidate
|
||||
if isinstance(candidate, AgentProtocol): # type: ignore[arg-type]
|
||||
# Reuse existing wrapper for the same agent instance if present
|
||||
agent_instance_id = id(candidate)
|
||||
existing = self._agent_wrappers.get(agent_instance_id)
|
||||
if existing is not None:
|
||||
return existing
|
||||
# Use agent name if available and unique among current executors
|
||||
name = getattr(candidate, "name", None)
|
||||
proposed_id: str | None = executor_id
|
||||
if proposed_id is None and name:
|
||||
proposed_id = str(name)
|
||||
if proposed_id in self._executors:
|
||||
raise ValueError(
|
||||
f"Duplicate executor ID '{proposed_id}' from agent name. "
|
||||
"Agent names must be unique within a workflow."
|
||||
)
|
||||
wrapper = AgentExecutor(
|
||||
candidate,
|
||||
agent_thread=agent_thread,
|
||||
output_response=output_response,
|
||||
id=proposed_id,
|
||||
)
|
||||
self._agent_wrappers[agent_instance_id] = wrapper
|
||||
return wrapper
|
||||
raise TypeError(
|
||||
f"WorkflowBuilder expected an Executor or AgentProtocol instance; got {type(candidate).__name__}."
|
||||
)
|
||||
|
||||
def add_agent(
|
||||
self,
|
||||
agent: AgentProtocol,
|
||||
agent_thread: Any | None = None,
|
||||
output_response: bool = False,
|
||||
id: str | None = None,
|
||||
) -> Self:
|
||||
"""Add an agent to the workflow by wrapping it in an AgentExecutor.
|
||||
|
||||
This method creates an AgentExecutor that wraps the agent with the given parameters
|
||||
and ensures that subsequent uses of the same agent instance in other builder methods
|
||||
(like add_edge, set_start_executor, etc.) will reuse the same wrapped executor.
|
||||
|
||||
Note: Agents adapt their behavior based on how the workflow is executed:
|
||||
- run_stream(): Agents emit incremental AgentRunUpdateEvent events as tokens are produced
|
||||
- run(): Agents emit a single AgentRunEvent containing the complete response
|
||||
|
||||
Args:
|
||||
agent: The agent to add to the workflow.
|
||||
agent_thread: The thread to use for running the agent. If None, a new thread will be created.
|
||||
output_response: Whether to yield an AgentRunResponse as a workflow output when the agent completes.
|
||||
id: A unique identifier for the executor. If None, the agent's name will be used if available.
|
||||
|
||||
Returns:
|
||||
The WorkflowBuilder instance (for method chaining).
|
||||
|
||||
Raises:
|
||||
ValueError: If the provided id or agent name conflicts with an existing executor.
|
||||
"""
|
||||
executor = self._maybe_wrap_agent(
|
||||
agent, agent_thread=agent_thread, output_response=output_response, executor_id=id
|
||||
)
|
||||
self._add_executor(executor)
|
||||
return self
|
||||
|
||||
def add_edge(
|
||||
self,
|
||||
source: Executor | AgentProtocol,
|
||||
target: Executor | AgentProtocol,
|
||||
condition: Callable[[Any], bool] | None = None,
|
||||
) -> Self:
|
||||
"""Add a directed edge between two executors.
|
||||
|
||||
The output types of the source and the input types of the target must be compatible.
|
||||
|
||||
Args:
|
||||
source: The source executor of the edge.
|
||||
target: The target executor of the edge.
|
||||
condition: An optional condition function that determines whether the edge
|
||||
should be traversed based on the message type.
|
||||
"""
|
||||
# TODO(@taochen): Support executor factories for lazy initialization
|
||||
source_exec = self._maybe_wrap_agent(source)
|
||||
target_exec = self._maybe_wrap_agent(target)
|
||||
source_id = self._add_executor(source_exec)
|
||||
target_id = self._add_executor(target_exec)
|
||||
self._edge_groups.append(SingleEdgeGroup(source_id, target_id, condition)) # type: ignore[call-arg]
|
||||
return self
|
||||
|
||||
def add_fan_out_edges(
|
||||
self,
|
||||
source: Executor | AgentProtocol,
|
||||
targets: Sequence[Executor | AgentProtocol],
|
||||
) -> Self:
|
||||
"""Add multiple edges to the workflow where messages from the source will be sent to all target.
|
||||
|
||||
The output types of the source and the input types of the targets must be compatible.
|
||||
|
||||
Args:
|
||||
source: The source executor of the edges.
|
||||
targets: A list of target executors for the edges.
|
||||
"""
|
||||
source_exec = self._maybe_wrap_agent(source)
|
||||
target_execs = [self._maybe_wrap_agent(t) for t in targets]
|
||||
source_id = self._add_executor(source_exec)
|
||||
target_ids = [self._add_executor(t) for t in target_execs]
|
||||
self._edge_groups.append(FanOutEdgeGroup(source_id, target_ids)) # type: ignore[call-arg]
|
||||
|
||||
return self
|
||||
|
||||
def add_switch_case_edge_group(
|
||||
self,
|
||||
source: Executor | AgentProtocol,
|
||||
cases: Sequence[Case | Default],
|
||||
) -> Self:
|
||||
"""Add an edge group that represents a switch-case statement.
|
||||
|
||||
The output types of the source and the input types of the targets must be compatible.
|
||||
Messages from the source executor will be sent to one of the target executors based on
|
||||
the provided conditions.
|
||||
|
||||
Think of this as a switch statement where each target executor corresponds to a case.
|
||||
Each condition function will be evaluated in order, and the first one that returns True
|
||||
will determine which target executor receives the message.
|
||||
|
||||
The last case (the default case) will receive messages that fall through all conditions
|
||||
(i.e., no condition matched).
|
||||
|
||||
Args:
|
||||
source: The source executor of the edges.
|
||||
cases: A list of case objects that determine the target executor for each message.
|
||||
"""
|
||||
source_exec = self._maybe_wrap_agent(source)
|
||||
source_id = self._add_executor(source_exec)
|
||||
# Convert case data types to internal types that only uses target_id.
|
||||
internal_cases: list[SwitchCaseEdgeGroupCase | SwitchCaseEdgeGroupDefault] = []
|
||||
for case in cases:
|
||||
# Allow case targets to be agents
|
||||
case.target = self._maybe_wrap_agent(case.target) # type: ignore[attr-defined]
|
||||
self._add_executor(case.target)
|
||||
if isinstance(case, Default):
|
||||
internal_cases.append(SwitchCaseEdgeGroupDefault(target_id=case.target.id))
|
||||
else:
|
||||
internal_cases.append(SwitchCaseEdgeGroupCase(condition=case.condition, target_id=case.target.id))
|
||||
self._edge_groups.append(SwitchCaseEdgeGroup(source_id, internal_cases)) # type: ignore[call-arg]
|
||||
|
||||
return self
|
||||
|
||||
def add_multi_selection_edge_group(
|
||||
self,
|
||||
source: Executor | AgentProtocol,
|
||||
targets: Sequence[Executor | AgentProtocol],
|
||||
selection_func: Callable[[Any, list[str]], list[str]],
|
||||
) -> Self:
|
||||
"""Add an edge group that represents a multi-selection execution model.
|
||||
|
||||
The output types of the source and the input types of the targets must be compatible.
|
||||
Messages from the source executor will be sent to multiple target executors based on
|
||||
the provided selection function.
|
||||
|
||||
The selection function should take a message and the name of the target executors,
|
||||
and return a list of indices indicating which target executors should receive the message.
|
||||
|
||||
Args:
|
||||
source: The source executor of the edges.
|
||||
targets: A list of target executors for the edges.
|
||||
selection_func: A function that selects target executors for messages.
|
||||
"""
|
||||
source_exec = self._maybe_wrap_agent(source)
|
||||
target_execs = [self._maybe_wrap_agent(t) for t in targets]
|
||||
source_id = self._add_executor(source_exec)
|
||||
target_ids = [self._add_executor(t) for t in target_execs]
|
||||
self._edge_groups.append(FanOutEdgeGroup(source_id, target_ids, selection_func)) # type: ignore[call-arg]
|
||||
|
||||
return self
|
||||
|
||||
def add_fan_in_edges(
|
||||
self,
|
||||
sources: Sequence[Executor | AgentProtocol],
|
||||
target: Executor | AgentProtocol,
|
||||
) -> Self:
|
||||
"""Add multiple edges from sources to a single target executor.
|
||||
|
||||
The edges will be grouped together for synchronized processing, meaning
|
||||
the target executor will only be executed once all source executors have completed.
|
||||
|
||||
The target executor will receive a list of messages aggregated from all source executors.
|
||||
Thus the input types of the target executor must be compatible with a list of the output
|
||||
types of the source executors. For example:
|
||||
|
||||
class Target(Executor):
|
||||
@handler
|
||||
def handle_messages(self, messages: list[Message]) -> None:
|
||||
# Process the aggregated messages from all sources
|
||||
|
||||
class Source(Executor):
|
||||
@handler(output_type=[Message])
|
||||
def handle_message(self, message: Message) -> None:
|
||||
# Send a message to the target executor
|
||||
self.send_message(message)
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder()
|
||||
.add_fan_in_edges(
|
||||
[Source(id="source1"), Source(id="source2")],
|
||||
Target(id="target")
|
||||
)
|
||||
.build()
|
||||
)
|
||||
|
||||
Args:
|
||||
sources: A list of source executors for the edges.
|
||||
target: The target executor for the edges.
|
||||
"""
|
||||
source_execs = [self._maybe_wrap_agent(s) for s in sources]
|
||||
target_exec = self._maybe_wrap_agent(target)
|
||||
source_ids = [self._add_executor(s) for s in source_execs]
|
||||
target_id = self._add_executor(target_exec)
|
||||
self._edge_groups.append(FanInEdgeGroup(source_ids, target_id)) # type: ignore[call-arg]
|
||||
|
||||
return self
|
||||
|
||||
def add_chain(self, executors: Sequence[Executor | AgentProtocol]) -> Self:
|
||||
"""Add a chain of executors to the workflow.
|
||||
|
||||
The output of each executor in the chain will be sent to the next executor in the chain.
|
||||
The input types of each executor must be compatible with the output types of the previous executor.
|
||||
|
||||
Circles in the chain are not allowed, meaning the chain cannot have two executors with the same ID.
|
||||
|
||||
Args:
|
||||
executors: A list of executors to be added to the chain.
|
||||
"""
|
||||
# Wrap each candidate first to ensure stable IDs before adding edges
|
||||
wrapped: list[Executor] = [self._maybe_wrap_agent(e) for e in executors]
|
||||
for i in range(len(wrapped) - 1):
|
||||
self.add_edge(wrapped[i], wrapped[i + 1])
|
||||
return self
|
||||
|
||||
def set_start_executor(self, executor: Executor | AgentProtocol | str) -> Self:
|
||||
"""Set the starting executor for the workflow.
|
||||
|
||||
Args:
|
||||
executor: The starting executor, which can be an Executor instance or its ID.
|
||||
"""
|
||||
if isinstance(executor, str):
|
||||
self._start_executor = executor
|
||||
else:
|
||||
wrapped = self._maybe_wrap_agent(executor) # type: ignore[arg-type]
|
||||
self._start_executor = wrapped
|
||||
# Ensure the start executor is present in the executor map so validation succeeds
|
||||
# even if no edges are added yet, or before edges wrap the same agent again.
|
||||
existing = self._executors.get(wrapped.id)
|
||||
if existing is not wrapped:
|
||||
self._add_executor(wrapped)
|
||||
return self
|
||||
|
||||
def set_max_iterations(self, max_iterations: int) -> Self:
|
||||
"""Set the maximum number of iterations for the workflow.
|
||||
|
||||
Args:
|
||||
max_iterations: The maximum number of iterations the workflow will run for convergence.
|
||||
"""
|
||||
self._max_iterations = max_iterations
|
||||
return self
|
||||
|
||||
# Removed explicit set_agent_streaming() API; agents always stream updates.
|
||||
|
||||
def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> Self:
|
||||
"""Enable checkpointing with the specified storage.
|
||||
|
||||
Args:
|
||||
checkpoint_storage: The checkpoint storage to use.
|
||||
"""
|
||||
self._checkpoint_storage = checkpoint_storage
|
||||
return self
|
||||
|
||||
def build(self) -> Workflow:
|
||||
"""Build and return the constructed workflow.
|
||||
|
||||
This method performs validation before building the workflow.
|
||||
|
||||
Returns:
|
||||
A Workflow instance with the defined edges and starting executor.
|
||||
|
||||
Raises:
|
||||
ValueError: If starting executor is not set.
|
||||
WorkflowValidationError: If workflow validation fails (includes EdgeDuplicationError,
|
||||
TypeCompatibilityError, and GraphConnectivityError subclasses).
|
||||
"""
|
||||
# Create workflow build span that includes validation and workflow creation
|
||||
with create_workflow_span(OtelAttr.WORKFLOW_BUILD_SPAN) as span:
|
||||
try:
|
||||
# Add workflow build started event
|
||||
span.add_event(OtelAttr.BUILD_STARTED)
|
||||
|
||||
if not self._start_executor:
|
||||
raise ValueError(
|
||||
"Starting executor must be set using set_start_executor before building the workflow."
|
||||
)
|
||||
|
||||
# Perform validation before creating the workflow
|
||||
validate_workflow_graph(
|
||||
self._edge_groups,
|
||||
self._executors,
|
||||
self._start_executor,
|
||||
duplicate_executor_ids=tuple(self._duplicate_executor_ids),
|
||||
)
|
||||
|
||||
# Add validation completed event
|
||||
span.add_event(OtelAttr.BUILD_VALIDATION_COMPLETED)
|
||||
|
||||
context = InProcRunnerContext(self._checkpoint_storage)
|
||||
|
||||
# Create workflow instance after validation
|
||||
workflow = Workflow(
|
||||
self._edge_groups,
|
||||
self._executors,
|
||||
self._start_executor,
|
||||
context,
|
||||
self._max_iterations,
|
||||
name=self._name,
|
||||
description=self._description,
|
||||
)
|
||||
build_attributes: dict[str, Any] = {
|
||||
OtelAttr.WORKFLOW_ID: workflow.id,
|
||||
OtelAttr.WORKFLOW_DEFINITION: workflow.to_json(),
|
||||
}
|
||||
if workflow.name:
|
||||
build_attributes[OtelAttr.WORKFLOW_NAME] = workflow.name
|
||||
if workflow.description:
|
||||
build_attributes[OtelAttr.WORKFLOW_DESCRIPTION] = workflow.description
|
||||
span.set_attributes(build_attributes)
|
||||
|
||||
# Add workflow build completed event
|
||||
span.add_event(OtelAttr.BUILD_COMPLETED)
|
||||
|
||||
return workflow
|
||||
|
||||
except Exception as exc:
|
||||
attributes = {
|
||||
OtelAttr.BUILD_ERROR_MESSAGE: str(exc),
|
||||
OtelAttr.BUILD_ERROR_TYPE: type(exc).__name__,
|
||||
}
|
||||
span.add_event(OtelAttr.BUILD_ERROR, attributes) # type: ignore[reportArgumentType, arg-type]
|
||||
capture_exception(span, exc)
|
||||
raise
|
||||
@@ -435,17 +435,17 @@ class WorkflowContext(Generic[T_Out, T_W_Out]):
|
||||
"""Get the shared state."""
|
||||
return self._shared_state
|
||||
|
||||
async def set_state(self, state: dict[str, Any]) -> None:
|
||||
async def set_executor_state(self, state: dict[str, Any]) -> None:
|
||||
"""Persist this executor's state into the checkpointable context.
|
||||
|
||||
Executors call this with a JSON-serializable dict capturing the minimal
|
||||
state needed to resume. It replaces any previously stored state.
|
||||
"""
|
||||
await self._runner_context.set_state(self._executor_id, state)
|
||||
await self._runner_context.set_executor_state(self._executor_id, state)
|
||||
|
||||
async def get_state(self) -> dict[str, Any] | None:
|
||||
async def get_executor_state(self) -> dict[str, Any] | None:
|
||||
"""Retrieve previously persisted state for this executor, if any."""
|
||||
return await self._runner_context.get_state(self._executor_id)
|
||||
return await self._runner_context.get_executor_state(self._executor_id)
|
||||
|
||||
def is_streaming(self) -> bool:
|
||||
"""Check if the workflow is running in streaming mode.
|
||||
|
||||
@@ -167,9 +167,9 @@ class WorkflowExecutor(Executor):
|
||||
@handler
|
||||
async def process(self, data: str, ctx: WorkflowContext[str]) -> None:
|
||||
# Use context state instead of instance variables
|
||||
state = await ctx.get_state() or {}
|
||||
state = await ctx.get_executor_state() or {}
|
||||
state["processed"] = data
|
||||
await ctx.set_state(state)
|
||||
await ctx.set_executor_state(state)
|
||||
|
||||
|
||||
# Avoid: Stateful executor with instance variables
|
||||
@@ -501,7 +501,7 @@ class WorkflowExecutor(Executor):
|
||||
|
||||
state: dict[str, Any] | None = None
|
||||
try:
|
||||
state = await ctx.get_state()
|
||||
state = await ctx.get_executor_state()
|
||||
except Exception:
|
||||
state = None
|
||||
|
||||
@@ -665,6 +665,6 @@ class WorkflowExecutor(Executor):
|
||||
async def _persist_execution_state(self, ctx: WorkflowContext[Any]) -> None:
|
||||
snapshot = self._build_state_snapshot()
|
||||
try:
|
||||
await ctx.set_state(snapshot)
|
||||
await ctx.set_executor_state(snapshot)
|
||||
except Exception as exc: # pragma: no cover - transport specific
|
||||
logger.warning(f"WorkflowExecutor {self.id} failed to persist state: {exc}")
|
||||
|
||||
@@ -4,9 +4,9 @@ from dataclasses import dataclass # noqa: I001
|
||||
from typing import Any, cast
|
||||
|
||||
from agent_framework._workflows._request_info_executor import RequestInfoMessage, RequestResponse
|
||||
from agent_framework._workflows._runner_context import ( # type: ignore
|
||||
_decode_checkpoint_value, # type: ignore
|
||||
_encode_checkpoint_value, # type: ignore
|
||||
from agent_framework._workflows._checkpoint_encoding import (
|
||||
decode_checkpoint_value,
|
||||
encode_checkpoint_value,
|
||||
)
|
||||
from agent_framework._workflows._typing_utils import is_instance_of
|
||||
|
||||
@@ -23,8 +23,8 @@ def test_decode_dataclass_with_nested_request() -> None:
|
||||
request_id="abc",
|
||||
)
|
||||
|
||||
encoded = _encode_checkpoint_value(original)
|
||||
decoded = cast(RequestResponse[SampleRequest, str], _decode_checkpoint_value(encoded))
|
||||
encoded = encode_checkpoint_value(original)
|
||||
decoded = cast(RequestResponse[SampleRequest, str], decode_checkpoint_value(encoded))
|
||||
|
||||
assert isinstance(decoded, RequestResponse)
|
||||
assert decoded.data == "approve"
|
||||
|
||||
@@ -5,7 +5,8 @@ from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from agent_framework._workflows._checkpoint import CheckpointStorage, WorkflowCheckpoint
|
||||
from agent_framework._workflows._checkpoint import WorkflowCheckpoint
|
||||
from agent_framework._workflows._checkpoint_encoding import encode_checkpoint_value
|
||||
from agent_framework._workflows._checkpoint_summary import get_checkpoint_summary
|
||||
from agent_framework._workflows._events import RequestInfoEvent, WorkflowEvent
|
||||
from agent_framework._workflows._request_info_executor import (
|
||||
@@ -16,9 +17,8 @@ from agent_framework._workflows._request_info_executor import (
|
||||
RequestResponse,
|
||||
)
|
||||
from agent_framework._workflows._runner_context import (
|
||||
CheckpointState,
|
||||
Message,
|
||||
_encode_checkpoint_value, # type: ignore
|
||||
WorkflowState,
|
||||
)
|
||||
from agent_framework._workflows._shared_state import SharedState
|
||||
from agent_framework._workflows._workflow_context import WorkflowContext
|
||||
@@ -53,10 +53,10 @@ class _StubRunnerContext:
|
||||
async def next_event(self) -> WorkflowEvent: # pragma: no cover - unused
|
||||
raise RuntimeError("Not implemented in stub context")
|
||||
|
||||
async def get_state(self, executor_id: str) -> dict[str, Any] | None: # pragma: no cover - trivial
|
||||
async def get_executor_state(self, executor_id: str) -> dict[str, Any] | None: # pragma: no cover - trivial
|
||||
return self._state
|
||||
|
||||
async def set_state(self, executor_id: str, state: dict[str, Any]) -> None: # pragma: no cover - unused
|
||||
async def set_executor_state(self, executor_id: str, state: dict[str, Any]) -> None: # pragma: no cover - unused
|
||||
self._state = state
|
||||
|
||||
def has_checkpointing(self) -> bool: # pragma: no cover - unused
|
||||
@@ -71,20 +71,13 @@ class _StubRunnerContext:
|
||||
async def create_checkpoint(self, metadata: dict[str, Any] | None = None) -> str: # pragma: no cover - unused
|
||||
raise RuntimeError("Checkpointing not supported in stub context")
|
||||
|
||||
async def restore_from_checkpoint(
|
||||
self,
|
||||
checkpoint_id: str,
|
||||
checkpoint_storage: CheckpointStorage | None = None,
|
||||
) -> bool: # pragma: no cover - unused
|
||||
return False
|
||||
|
||||
async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None: # pragma: no cover - unused
|
||||
return None
|
||||
|
||||
async def get_checkpoint_state(self) -> CheckpointState: # pragma: no cover - unused
|
||||
async def get_workflow_state(self) -> WorkflowState: # pragma: no cover - unused
|
||||
return {} # type: ignore[return-value]
|
||||
|
||||
async def set_checkpoint_state(self, state: CheckpointState) -> None: # pragma: no cover - unused
|
||||
async def set_workflow_state(self, state: WorkflowState) -> None: # pragma: no cover - unused
|
||||
pass
|
||||
|
||||
def set_streaming(self, streaming: bool) -> None: # pragma: no cover - unused
|
||||
@@ -178,7 +171,7 @@ def test_pending_requests_from_checkpoint_and_summary() -> None:
|
||||
request_id=request.request_id,
|
||||
)
|
||||
|
||||
encoded_response = _encode_checkpoint_value(response)
|
||||
encoded_response = encode_checkpoint_value(response)
|
||||
|
||||
checkpoint = WorkflowCheckpoint(
|
||||
checkpoint_id="cp-1",
|
||||
|
||||
@@ -439,8 +439,8 @@ async def test_message_trace_context_serialization(span_exporter: InMemorySpanEx
|
||||
|
||||
await ctx.send_message(message)
|
||||
|
||||
# Get checkpoint state (which serializes messages)
|
||||
state = await ctx.get_checkpoint_state()
|
||||
# Get context state (which serializes messages)
|
||||
state = await ctx.get_workflow_state()
|
||||
|
||||
# Check serialized message includes trace context
|
||||
serialized_msg = state["messages"]["source"][0]
|
||||
@@ -448,7 +448,7 @@ async def test_message_trace_context_serialization(span_exporter: InMemorySpanEx
|
||||
assert serialized_msg["source_span_ids"] == ["span123"]
|
||||
|
||||
# Test deserialization
|
||||
await ctx.set_checkpoint_state(state)
|
||||
await ctx.set_workflow_state(state)
|
||||
restored_messages = await ctx.drain_messages()
|
||||
|
||||
restored_msg = list(restored_messages.values())[0][0]
|
||||
|
||||
+5
-5
@@ -140,8 +140,8 @@ class ReviewGateway(Executor):
|
||||
# persist iterations. The `RequestInfoExecutor` relies on this state to
|
||||
# rehydrate when checkpoints are restored.
|
||||
draft = response.agent_run_response.text or ""
|
||||
iteration = int((await ctx.get_state() or {}).get("iteration", 0)) + 1
|
||||
await ctx.set_state({"iteration": iteration, "last_draft": draft})
|
||||
iteration = int((await ctx.get_executor_state() or {}).get("iteration", 0)) + 1
|
||||
await ctx.set_executor_state({"iteration": iteration, "last_draft": draft})
|
||||
# Emit a human approval request. Because this flows through
|
||||
# RequestInfoExecutor it will pause the workflow until an answer is
|
||||
# supplied either interactively or via pre-supplied responses.
|
||||
@@ -163,7 +163,7 @@ class ReviewGateway(Executor):
|
||||
# The RequestResponse wrapper gives us both the human data and the
|
||||
# original request message, even when resuming from checkpoints.
|
||||
reply = (feedback.data or "").strip()
|
||||
state = await ctx.get_state() or {}
|
||||
state = await ctx.get_executor_state() or {}
|
||||
draft = state.get("last_draft") or (feedback.original_request.draft if feedback.original_request else "")
|
||||
|
||||
if reply.lower() == "approve":
|
||||
@@ -175,7 +175,7 @@ class ReviewGateway(Executor):
|
||||
# Any other response loops us back to the writer with fresh guidance.
|
||||
guidance = reply or "Tighten the copy and emphasise customer benefit."
|
||||
iteration = int(state.get("iteration", 1)) + 1
|
||||
await ctx.set_state({"iteration": iteration, "last_draft": draft})
|
||||
await ctx.set_executor_state({"iteration": iteration, "last_draft": draft})
|
||||
prompt = (
|
||||
"Revise the launch note. Respond with the new copy only.\n\n"
|
||||
f"Previous draft:\n{draft}\n\n"
|
||||
@@ -193,7 +193,7 @@ class FinaliseExecutor(Executor):
|
||||
@handler
|
||||
async def publish(self, text: str, ctx: WorkflowContext[Any, str]) -> None:
|
||||
# Store the output so diagnostics or a UI could fetch the final copy.
|
||||
await ctx.set_state({"published_text": text})
|
||||
await ctx.set_executor_state({"published_text": text})
|
||||
# Yield the final output so the workflow completes cleanly.
|
||||
await ctx.yield_output(text)
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ Pipeline:
|
||||
5) FinalizeFromAgent yields the final result.
|
||||
|
||||
What you learn:
|
||||
- How to persist executor state using ctx.get_state and ctx.set_state.
|
||||
- How to persist executor state using ctx.get_executor_state and ctx.set_executor_state.
|
||||
- How to persist shared workflow state using ctx.set_shared_state for cross-executor visibility.
|
||||
- How to configure FileCheckpointStorage and call with_checkpointing on WorkflowBuilder.
|
||||
- How to list and inspect checkpoints programmatically.
|
||||
@@ -73,9 +73,9 @@ class UpperCaseExecutor(Executor):
|
||||
|
||||
# Persist executor-local state so it is captured in checkpoints
|
||||
# and available after resume for observability or logic.
|
||||
prev = await ctx.get_state() or {}
|
||||
prev = await ctx.get_executor_state() or {}
|
||||
count = int(prev.get("count", 0)) + 1
|
||||
await ctx.set_state({
|
||||
await ctx.set_executor_state({
|
||||
"count": count,
|
||||
"last_input": text,
|
||||
"last_output": result,
|
||||
@@ -122,9 +122,9 @@ class FinalizeFromAgent(Executor):
|
||||
result = response.agent_run_response.text or ""
|
||||
|
||||
# Persist executor-local state for auditability when inspecting checkpoints.
|
||||
prev = await ctx.get_state() or {}
|
||||
prev = await ctx.get_executor_state() or {}
|
||||
count = int(prev.get("count", 0)) + 1
|
||||
await ctx.set_state({
|
||||
await ctx.set_executor_state({
|
||||
"count": count,
|
||||
"last_output": result,
|
||||
"final": True,
|
||||
@@ -143,9 +143,9 @@ class ReverseTextExecutor(Executor):
|
||||
print(f"ReverseTextExecutor: '{text}' -> '{result}'")
|
||||
|
||||
# Persist executor-local state so checkpoint inspection can reveal progress.
|
||||
prev = await ctx.get_state() or {}
|
||||
prev = await ctx.get_executor_state() or {}
|
||||
count = int(prev.get("count", 0)) + 1
|
||||
await ctx.set_state({
|
||||
await ctx.set_executor_state({
|
||||
"count": count,
|
||||
"last_input": text,
|
||||
"last_output": result,
|
||||
|
||||
@@ -144,7 +144,7 @@ async def run_semantic_kernel_process_example() -> None:
|
||||
kernel=kernel,
|
||||
initial_event=KernelProcessEvent(id=CommonEvents.START_PROCESS.value, data="Initial"),
|
||||
) as process_context:
|
||||
process_state = await process_context.get_state()
|
||||
process_state = await process_context.get_executor_state()
|
||||
c_step_state: KernelProcessStepState[CStepState] | None = next(
|
||||
(s.state for s in process_state.steps if s.state.name == "CStep"),
|
||||
None,
|
||||
|
||||
@@ -136,7 +136,7 @@ async def run_semantic_kernel_nested_process() -> None:
|
||||
initial_event=ProcessEvents.START_PROCESS.value,
|
||||
data="Test",
|
||||
)
|
||||
process_info = await process_handle.get_state()
|
||||
process_info = await process_handle.get_executor_state()
|
||||
|
||||
inner_process: KernelProcess | None = next(
|
||||
(s for s in process_info.steps if s.state.name == "Inner"),
|
||||
|
||||
Reference in New Issue
Block a user