Code clean up: Checkpoint and WorkflowBuilder (#1557)

Co-authored-by: Evan Mattson <35585003+moonbox3@users.noreply.github.com>
This commit is contained in:
Tao Chen
2025-10-20 09:34:06 -07:00
committed by GitHub
Unverified
parent 9c3f52566f
commit 083d0de3f3
21 changed files with 812 additions and 809 deletions
@@ -98,7 +98,8 @@ from ._validation import (
validate_workflow_graph,
)
from ._viz import WorkflowViz
from ._workflow import Workflow, WorkflowBuilder, WorkflowRunResult
from ._workflow import Workflow, WorkflowRunResult
from ._workflow_builder import WorkflowBuilder
from ._workflow_context import WorkflowContext
from ._workflow_executor import WorkflowExecutor
@@ -96,7 +96,8 @@ from ._validation import (
validate_workflow_graph,
)
from ._viz import WorkflowViz
from ._workflow import Workflow, WorkflowBuilder, WorkflowRunResult
from ._workflow import Workflow, WorkflowRunResult
from ._workflow_builder import WorkflowBuilder
from ._workflow_context import WorkflowContext
from ._workflow_executor import WorkflowExecutor
@@ -0,0 +1,251 @@
# Copyright (c) Microsoft. All rights reserved.
import contextlib
import importlib
import logging
import sys
from dataclasses import fields, is_dataclass
from typing import Any, cast
# Checkpoint serialization helpers
MODEL_MARKER = "__af_model__"
DATACLASS_MARKER = "__af_dataclass__"
# Guards to prevent runaway recursion while encoding arbitrary user data
_MAX_ENCODE_DEPTH = 100
_CYCLE_SENTINEL = "<cycle>"
logger = logging.getLogger(__name__)
def encode_checkpoint_value(value: Any) -> Any:
"""Recursively encode values into JSON-serializable structures.
- Objects exposing to_dict/to_json -> { MODEL_MARKER: "module:Class", value: encoded }
- dataclass instances -> { DATACLASS_MARKER: "module:Class", value: {field: encoded} }
- dict -> encode keys as str and values recursively
- list/tuple/set -> list of encoded items
- other -> returned as-is if already JSON-serializable
Includes cycle and depth protection to avoid infinite recursion.
"""
def _enc(v: Any, stack: set[int], depth: int) -> Any:
# Depth guard
if depth > _MAX_ENCODE_DEPTH:
logger.debug(f"Max encode depth reached at depth={depth} for type={type(v)}")
return "<max_depth>"
# Structured model handling (objects exposing to_dict/to_json)
if _supports_model_protocol(v):
cls = cast(type[Any], type(v)) # type: ignore
try:
if hasattr(v, "to_dict") and callable(getattr(v, "to_dict", None)):
raw = v.to_dict() # type: ignore[attr-defined]
strategy = "to_dict"
elif hasattr(v, "to_json") and callable(getattr(v, "to_json", None)):
serialized = v.to_json() # type: ignore[attr-defined]
if isinstance(serialized, (bytes, bytearray)):
try:
serialized = serialized.decode()
except Exception:
serialized = serialized.decode(errors="replace")
raw = serialized
strategy = "to_json"
else:
raise AttributeError("Structured model lacks serialization hooks")
return {
MODEL_MARKER: f"{cls.__module__}:{cls.__name__}",
"strategy": strategy,
"value": _enc(raw, stack, depth + 1),
}
except Exception as exc: # best-effort fallback
logger.debug(f"Structured model serialization failed for {cls}: {exc}")
return str(v)
# Dataclasses (instances only)
if is_dataclass(v) and not isinstance(v, type):
oid = id(v)
if oid in stack:
logger.debug("Cycle detected while encoding dataclass instance")
return _CYCLE_SENTINEL
stack.add(oid)
try:
# type(v) already narrows sufficiently; cast was redundant
dc_cls: type[Any] = type(v)
field_values: dict[str, Any] = {}
for f in fields(v): # type: ignore[arg-type]
field_values[f.name] = _enc(getattr(v, f.name), stack, depth + 1)
return {
DATACLASS_MARKER: f"{dc_cls.__module__}:{dc_cls.__name__}",
"value": field_values,
}
finally:
stack.remove(oid)
# Collections
if isinstance(v, dict):
v_dict = cast("dict[object, object]", v)
oid = id(v_dict)
if oid in stack:
logger.debug("Cycle detected while encoding dict")
return _CYCLE_SENTINEL
stack.add(oid)
try:
json_dict: dict[str, Any] = {}
for k_any, val_any in v_dict.items(): # type: ignore[assignment]
k_str: str = str(k_any)
json_dict[k_str] = _enc(val_any, stack, depth + 1)
return json_dict
finally:
stack.remove(oid)
if isinstance(v, (list, tuple, set)):
iterable_v = cast("list[object] | tuple[object, ...] | set[object]", v)
oid = id(iterable_v)
if oid in stack:
logger.debug("Cycle detected while encoding iterable")
return _CYCLE_SENTINEL
stack.add(oid)
try:
seq: list[object] = list(iterable_v)
encoded_list: list[Any] = []
for item in seq:
encoded_list.append(_enc(item, stack, depth + 1))
return encoded_list
finally:
stack.remove(oid)
# Primitives (or unknown objects): ensure JSON-serializable
if isinstance(v, (str, int, float, bool)) or v is None:
return v
# Fallback: stringify unknown objects to avoid JSON serialization errors
try:
return str(v)
except Exception:
return f"<{type(v).__name__}>"
return _enc(value, set(), 0)
def decode_checkpoint_value(value: Any) -> Any:
"""Recursively decode values previously encoded by encode_checkpoint_value."""
if isinstance(value, dict):
value_dict = cast(dict[str, Any], value) # encoded form always uses string keys
# Structured model marker handling
if MODEL_MARKER in value_dict and "value" in value_dict:
type_key: str | None = value_dict.get(MODEL_MARKER) # type: ignore[assignment]
strategy: str | None = value_dict.get("strategy") # type: ignore[assignment]
raw_encoded: Any = value_dict.get("value")
decoded_payload = decode_checkpoint_value(raw_encoded)
if isinstance(type_key, str):
try:
cls = _import_qualified_name(type_key)
except Exception as exc:
logger.debug(f"Failed to import structured model {type_key}: {exc}")
cls = None
if cls is not None:
if strategy == "to_dict" and hasattr(cls, "from_dict"):
with contextlib.suppress(Exception):
return cls.from_dict(decoded_payload)
if strategy == "to_json" and hasattr(cls, "from_json"):
if isinstance(decoded_payload, (str, bytes, bytearray)):
with contextlib.suppress(Exception):
return cls.from_json(decoded_payload)
if isinstance(decoded_payload, dict) and hasattr(cls, "from_dict"):
with contextlib.suppress(Exception):
return cls.from_dict(decoded_payload)
return decoded_payload
# Dataclass marker handling
if DATACLASS_MARKER in value_dict and "value" in value_dict:
type_key_dc: str | None = value_dict.get(DATACLASS_MARKER) # type: ignore[assignment]
raw_dc: Any = value_dict.get("value")
decoded_raw = decode_checkpoint_value(raw_dc)
if isinstance(type_key_dc, str):
try:
module_name, class_name = type_key_dc.split(":", 1)
module = sys.modules.get(module_name)
if module is None:
module = importlib.import_module(module_name)
cls_dc: Any = getattr(module, class_name)
constructed = _instantiate_checkpoint_dataclass(cls_dc, decoded_raw)
if constructed is not None:
return constructed
except Exception as exc:
logger.debug(f"Failed to decode dataclass {type_key_dc}: {exc}; returning raw value")
return decoded_raw
# Regular dict: decode recursively
decoded: dict[str, Any] = {}
for k_any, v_any in value_dict.items():
decoded[k_any] = decode_checkpoint_value(v_any)
return decoded
if isinstance(value, list):
# After isinstance check, treat value as list[Any] for decoding
value_list: list[Any] = value # type: ignore[assignment]
return [decode_checkpoint_value(v_any) for v_any in value_list]
return value
def _instantiate_checkpoint_dataclass(cls: type[Any], payload: Any) -> Any | None:
if not isinstance(cls, type):
logger.debug(f"Checkpoint decoder received non-type dataclass reference: {cls!r}")
return None
if isinstance(payload, dict):
try:
return cls(**payload) # type: ignore[arg-type]
except TypeError as exc:
logger.debug(f"Checkpoint decoder could not call {cls.__name__}(**payload): {exc}")
except Exception as exc:
logger.warning(f"Checkpoint decoder encountered unexpected error calling {cls.__name__}(**payload): {exc}")
try:
instance = object.__new__(cls)
except Exception as exc:
logger.debug(f"Checkpoint decoder could not allocate {cls.__name__} without __init__: {exc}")
return None
for key, val in payload.items(): # type: ignore[attr-defined]
try:
setattr(instance, key, val) # type: ignore[arg-type]
except Exception as exc:
logger.debug(f"Checkpoint decoder could not set attribute {key} on {cls.__name__}: {exc}")
return instance
try:
return cls(payload) # type: ignore[call-arg]
except TypeError as exc:
logger.debug(f"Checkpoint decoder could not call {cls.__name__}({payload!r}): {exc}")
except Exception as exc:
logger.warning(f"Checkpoint decoder encountered unexpected error calling {cls.__name__}({payload!r}): {exc}")
return None
def _supports_model_protocol(obj: object) -> bool:
"""Detect objects that expose dictionary serialization hooks."""
try:
obj_type: type[Any] = type(obj)
except Exception:
return False
has_to_dict = hasattr(obj, "to_dict") and callable(getattr(obj, "to_dict", None)) # type: ignore[arg-type]
has_from_dict = hasattr(obj_type, "from_dict") and callable(getattr(obj_type, "from_dict", None))
has_to_json = hasattr(obj, "to_json") and callable(getattr(obj, "to_json", None)) # type: ignore[arg-type]
has_from_json = hasattr(obj_type, "from_json") and callable(getattr(obj_type, "from_json", None))
return (has_to_dict and has_from_dict) or (has_to_json and has_from_json)
def _import_qualified_name(qualname: str) -> type[Any] | None:
if ":" not in qualname:
return None
module_name, class_name = qualname.split(":", 1)
module = sys.modules.get(module_name)
if module is None:
module = importlib.import_module(module_name)
attr: Any = module
for part in class_name.split("."):
attr = getattr(attr, part)
return attr if isinstance(attr, type) else None
@@ -7,8 +7,8 @@ from textwrap import shorten
from typing import Any
from ._checkpoint import WorkflowCheckpoint
from ._checkpoint_encoding import decode_checkpoint_value
from ._request_info_executor import PendingRequestDetails, RequestInfoMessage, RequestResponse
from ._runner_context import _decode_checkpoint_value # type: ignore
logger = logging.getLogger(__name__)
@@ -90,7 +90,7 @@ def _pending_requests_from_checkpoint(
for message in message_list:
if not isinstance(message, Mapping):
continue
payload = _decode_checkpoint_value(message.get("data"))
payload = decode_checkpoint_value(message.get("data"))
_merge_message_payload(pending, payload, message)
return list(pending.values())
@@ -13,7 +13,8 @@ from agent_framework import AgentProtocol, ChatMessage, Role
from ._agent_executor import AgentExecutorRequest, AgentExecutorResponse
from ._checkpoint import CheckpointStorage
from ._executor import Executor, handler
from ._workflow import Workflow, WorkflowBuilder
from ._workflow import Workflow
from ._workflow_builder import WorkflowBuilder
from ._workflow_context import WorkflowContext
logger = logging.getLogger(__name__)
@@ -30,7 +30,8 @@ from ._events import WorkflowEvent
from ._executor import Executor, handler
from ._model_utils import DictConvertible, encode_value
from ._request_info_executor import RequestInfoMessage, RequestResponse
from ._workflow import Workflow, WorkflowBuilder, WorkflowRunResult
from ._workflow import Workflow, WorkflowRunResult
from ._workflow_builder import WorkflowBuilder
from ._workflow_context import WorkflowContext
if sys.version_info >= (3, 11):
@@ -1071,7 +1072,7 @@ class MagenticOrchestratorExecutor(Executor):
) -> None:
if self._state_restored and self._context is not None:
return
state = await context.get_state()
state = await context.get_executor_state()
if not state:
self._state_restored = True
return
@@ -1552,7 +1553,7 @@ class MagenticAgentExecutor(Executor):
async def _ensure_state_restored(self, context: WorkflowContext[Any, Any]) -> None:
if self._state_restored and self._chat_history:
return
state = await context.get_state()
state = await context.get_executor_state()
if not state:
self._state_restored = True
return
@@ -249,7 +249,7 @@ class RequestInfoExecutor(Executor):
async def _retrieve_existing_pending_requests(self, ctx: WorkflowContext) -> dict[str, PendingRequestSnapshot]:
"""Retrieve existing pending requests from executor state."""
executor_state = await ctx.get_state()
executor_state = await ctx.get_executor_state()
if executor_state is None:
return {}
@@ -271,9 +271,9 @@ class RequestInfoExecutor(Executor):
self, pending: dict[str, PendingRequestSnapshot], ctx: WorkflowContext
) -> None:
"""Persist the current pending requests to the executor's state."""
executor_state = await ctx.get_state() or {}
executor_state = await ctx.get_executor_state() or {}
executor_state[self._PENDING_SHARED_STATE_KEY] = pending
await ctx.set_state(executor_state)
await ctx.set_executor_state(executor_state)
def _build_pending_request_snapshot(
self, request: RequestInfoMessage, source_executor_id: str
@@ -7,17 +7,15 @@ from collections.abc import AsyncGenerator, Sequence
from typing import TYPE_CHECKING, Any
from ._checkpoint import CheckpointStorage, WorkflowCheckpoint
from ._checkpoint_encoding import DATACLASS_MARKER, MODEL_MARKER, decode_checkpoint_value
from ._edge import EdgeGroup
from ._edge_runner import EdgeRunner, create_edge_runner
from ._events import WorkflowEvent
from ._executor import Executor
from ._runner_context import (
_DATACLASS_MARKER, # type: ignore
_MODEL_MARKER, # type: ignore
CheckpointState,
Message,
RunnerContext,
_decode_checkpoint_value, # type: ignore
WorkflowState,
)
from ._shared_state import SharedState
@@ -168,10 +166,10 @@ class Runner:
data = message.data
if not isinstance(data, dict):
return
if _MODEL_MARKER not in data and _DATACLASS_MARKER not in data:
if MODEL_MARKER not in data and DATACLASS_MARKER not in data:
return
try:
decoded = _decode_checkpoint_value(data)
decoded = decode_checkpoint_value(data)
except Exception as exc: # pragma: no cover - defensive
logger.debug("Failed to decode checkpoint payload during delivery: %s", exc)
return
@@ -238,7 +236,7 @@ class Runner:
logger.debug(f"Executor {exec_id} snapshot_state failed: {ex}")
if state_dict is not None:
try:
await self._ctx.set_state(exec_id, state_dict)
await self._ctx.set_executor_state(exec_id, state_dict)
except Exception as ex: # pragma: no cover
logger.debug(f"Failed to persist state for executor {exec_id}: {ex}")
@@ -247,7 +245,7 @@ class Runner:
return
try:
current_state = await self._ctx.get_checkpoint_state()
current_state = await self._ctx.get_workflow_state()
shared_state_data = {}
async with self._shared_state.hold():
@@ -258,7 +256,7 @@ class Runner:
current_state["iteration_count"] = self._iteration
current_state["max_iterations"] = self._max_iterations
await self._ctx.set_checkpoint_state(current_state)
await self._ctx.set_workflow_state(current_state)
except Exception as e:
logger.warning(f"Failed to update context with shared state: {e}")
@@ -278,6 +276,7 @@ class Runner:
True if restoration was successful, False otherwise
"""
try:
# Load the checkpoint
checkpoint: WorkflowCheckpoint | None
if self._ctx.has_checkpointing():
checkpoint = await self._ctx.load_checkpoint(checkpoint_id)
@@ -291,6 +290,7 @@ class Runner:
logger.error(f"Checkpoint {checkpoint_id} not found")
return False
# Validate the loaded checkpoint against the workflow
graph_hash = getattr(self, "graph_signature_hash", None)
checkpoint_hash = (checkpoint.metadata or {}).get("graph_signature")
if graph_hash and checkpoint_hash and graph_hash != checkpoint_hash:
@@ -306,8 +306,9 @@ class Runner:
await self._restore_executor_states(checkpoint.executor_states)
state = self._checkpoint_to_state(checkpoint)
await self._ctx.set_checkpoint_state(state)
state = _convert_checkpoint_to_workflow_state(checkpoint)
await self._ctx.set_workflow_state(state)
if checkpoint.workflow_id:
self._ctx.set_workflow_id(checkpoint.workflow_id)
self._workflow_id = checkpoint.workflow_id
@@ -348,7 +349,7 @@ class Runner:
async def _restore_shared_state_from_context(self) -> None:
try:
restored_state = await self._ctx.get_checkpoint_state()
restored_state = await self._ctx.get_workflow_state()
shared_state_data = restored_state.get("shared_state", {})
if shared_state_data and hasattr(self._shared_state, "_state"):
@@ -362,16 +363,6 @@ class Runner:
except Exception as e:
logger.warning(f"Failed to restore shared state from context: {e}")
@staticmethod
def _checkpoint_to_state(checkpoint: WorkflowCheckpoint) -> CheckpointState:
return {
"messages": checkpoint.messages,
"shared_state": checkpoint.shared_state,
"executor_states": checkpoint.executor_states,
"iteration_count": checkpoint.iteration_count,
"max_iterations": checkpoint.max_iterations,
}
def _parse_edge_runners(self, edge_runners: list[EdgeRunner]) -> dict[str, list[EdgeRunner]]:
"""Parse the edge runners of the workflow into a mapping where each source executor ID maps to its edge runners.
@@ -421,3 +412,14 @@ class Runner:
if executor.id == msg.target_id and isinstance(executor, RequestInfoExecutor):
return True
return False
def _convert_checkpoint_to_workflow_state(checkpoint: WorkflowCheckpoint) -> WorkflowState:
"""Helper function to convert a WorkflowCheckpoint to a WorkflowState."""
return {
"messages": checkpoint.messages,
"shared_state": checkpoint.shared_state,
"executor_states": checkpoint.executor_states,
"iteration_count": checkpoint.iteration_count,
"max_iterations": checkpoint.max_iterations,
}
@@ -1,16 +1,14 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import contextlib
import importlib
import logging
import sys
import uuid
from copy import copy
from dataclasses import dataclass, fields, is_dataclass
from dataclasses import dataclass
from typing import Any, Protocol, TypedDict, TypeVar, cast, runtime_checkable
from ._checkpoint import CheckpointStorage, WorkflowCheckpoint
from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value
from ._const import DEFAULT_MAX_ITERATIONS
from ._events import WorkflowEvent
from ._shared_state import SharedState
@@ -45,7 +43,12 @@ class Message:
return self.source_span_ids[0] if self.source_span_ids else None
class CheckpointState(TypedDict):
class WorkflowState(TypedDict):
"""TypedDict representing the serializable state of a workflow execution.
This includes all state data needed for checkpointing and restoration.
"""
messages: dict[str, list[dict[str, Any]]]
shared_state: dict[str, Any]
executor_states: dict[str, dict[str, Any]]
@@ -53,248 +56,6 @@ class CheckpointState(TypedDict):
max_iterations: int
# Checkpoint serialization helpers
_MODEL_MARKER = "__af_model__"
_DATACLASS_MARKER = "__af_dataclass__"
_AF_MARKER = "__af__"
# Guards to prevent runaway recursion while encoding arbitrary user data
_MAX_ENCODE_DEPTH = 100
_CYCLE_SENTINEL = "<cycle>"
def _instantiate_checkpoint_dataclass(cls: type[Any], payload: Any) -> Any | None:
if not isinstance(cls, type):
logger.debug(f"Checkpoint decoder received non-type dataclass reference: {cls!r}")
return None
if isinstance(payload, dict):
try:
return cls(**payload) # type: ignore[arg-type]
except TypeError as exc:
logger.debug(f"Checkpoint decoder could not call {cls.__name__}(**payload): {exc}")
except Exception as exc:
logger.warning(f"Checkpoint decoder encountered unexpected error calling {cls.__name__}(**payload): {exc}")
try:
instance = object.__new__(cls)
except Exception as exc:
logger.debug(f"Checkpoint decoder could not allocate {cls.__name__} without __init__: {exc}")
return None
for key, val in payload.items(): # type: ignore[attr-defined]
try:
setattr(instance, key, val) # type: ignore[arg-type]
except Exception as exc:
logger.debug(f"Checkpoint decoder could not set attribute {key} on {cls.__name__}: {exc}")
return instance
try:
return cls(payload) # type: ignore[call-arg]
except TypeError as exc:
logger.debug(f"Checkpoint decoder could not call {cls.__name__}({payload!r}): {exc}")
except Exception as exc:
logger.warning(f"Checkpoint decoder encountered unexpected error calling {cls.__name__}({payload!r}): {exc}")
return None
def _supports_model_protocol(obj: object) -> bool:
"""Detect objects that expose dictionary serialization hooks."""
try:
obj_type: type[Any] = type(obj)
except Exception:
return False
has_to_dict = hasattr(obj, "to_dict") and callable(getattr(obj, "to_dict", None)) # type: ignore[arg-type]
has_from_dict = hasattr(obj_type, "from_dict") and callable(getattr(obj_type, "from_dict", None))
has_to_json = hasattr(obj, "to_json") and callable(getattr(obj, "to_json", None)) # type: ignore[arg-type]
has_from_json = hasattr(obj_type, "from_json") and callable(getattr(obj_type, "from_json", None))
return (has_to_dict and has_from_dict) or (has_to_json and has_from_json)
def _import_qualified_name(qualname: str) -> type[Any] | None:
if ":" not in qualname:
return None
module_name, class_name = qualname.split(":", 1)
module = sys.modules.get(module_name)
if module is None:
module = importlib.import_module(module_name)
attr: Any = module
for part in class_name.split("."):
attr = getattr(attr, part)
return attr if isinstance(attr, type) else None
def _encode_checkpoint_value(value: Any) -> Any:
"""Recursively encode values into JSON-serializable structures.
- Objects exposing to_dict/to_json -> { _MODEL_MARKER: "module:Class", value: encoded }
- dataclass instances -> { _DATACLASS_MARKER: "module:Class", value: {field: encoded} }
- dict -> encode keys as str and values recursively
- list/tuple/set -> list of encoded items
- other -> returned as-is if already JSON-serializable
Includes cycle and depth protection to avoid infinite recursion.
"""
def _enc(v: Any, stack: set[int], depth: int) -> Any:
# Depth guard
if depth > _MAX_ENCODE_DEPTH:
logger.debug(f"Max encode depth reached at depth={depth} for type={type(v)}")
return "<max_depth>"
# Structured model handling (objects exposing to_dict/to_json)
if _supports_model_protocol(v):
cls = cast(type[Any], type(v)) # type: ignore
try:
if hasattr(v, "to_dict") and callable(getattr(v, "to_dict", None)):
raw = v.to_dict() # type: ignore[attr-defined]
strategy = "to_dict"
elif hasattr(v, "to_json") and callable(getattr(v, "to_json", None)):
serialized = v.to_json() # type: ignore[attr-defined]
if isinstance(serialized, (bytes, bytearray)):
try:
serialized = serialized.decode()
except Exception:
serialized = serialized.decode(errors="replace")
raw = serialized
strategy = "to_json"
else:
raise AttributeError("Structured model lacks serialization hooks")
return {
_MODEL_MARKER: f"{cls.__module__}:{cls.__name__}",
"strategy": strategy,
"value": _enc(raw, stack, depth + 1),
}
except Exception as exc: # best-effort fallback
logger.debug(f"Structured model serialization failed for {cls}: {exc}")
return str(v)
# Dataclasses (instances only)
if is_dataclass(v) and not isinstance(v, type):
oid = id(v)
if oid in stack:
logger.debug("Cycle detected while encoding dataclass instance")
return _CYCLE_SENTINEL
stack.add(oid)
try:
# type(v) already narrows sufficiently; cast was redundant
dc_cls: type[Any] = type(v)
field_values: dict[str, Any] = {}
for f in fields(v): # type: ignore[arg-type]
field_values[f.name] = _enc(getattr(v, f.name), stack, depth + 1)
return {
_DATACLASS_MARKER: f"{dc_cls.__module__}:{dc_cls.__name__}",
"value": field_values,
}
finally:
stack.remove(oid)
# Collections
if isinstance(v, dict):
v_dict = cast("dict[object, object]", v)
oid = id(v_dict)
if oid in stack:
logger.debug("Cycle detected while encoding dict")
return _CYCLE_SENTINEL
stack.add(oid)
try:
json_dict: dict[str, Any] = {}
for k_any, val_any in v_dict.items(): # type: ignore[assignment]
k_str: str = str(k_any)
json_dict[k_str] = _enc(val_any, stack, depth + 1)
return json_dict
finally:
stack.remove(oid)
if isinstance(v, (list, tuple, set)):
iterable_v = cast("list[object] | tuple[object, ...] | set[object]", v)
oid = id(iterable_v)
if oid in stack:
logger.debug("Cycle detected while encoding iterable")
return _CYCLE_SENTINEL
stack.add(oid)
try:
seq: list[object] = list(iterable_v)
encoded_list: list[Any] = []
for item in seq:
encoded_list.append(_enc(item, stack, depth + 1))
return encoded_list
finally:
stack.remove(oid)
# Primitives (or unknown objects): ensure JSON-serializable
if isinstance(v, (str, int, float, bool)) or v is None:
return v
# Fallback: stringify unknown objects to avoid JSON serialization errors
try:
return str(v)
except Exception:
return f"<{type(v).__name__}>"
return _enc(value, set(), 0)
def _decode_checkpoint_value(value: Any) -> Any:
"""Recursively decode values previously encoded by _encode_checkpoint_value."""
if isinstance(value, dict):
value_dict = cast(dict[str, Any], value) # encoded form always uses string keys
# Structured model marker handling
if _MODEL_MARKER in value_dict and "value" in value_dict:
type_key: str | None = value_dict.get(_MODEL_MARKER) # type: ignore[assignment]
strategy: str | None = value_dict.get("strategy") # type: ignore[assignment]
raw_encoded: Any = value_dict.get("value")
decoded_payload = _decode_checkpoint_value(raw_encoded)
if isinstance(type_key, str):
try:
cls = _import_qualified_name(type_key)
except Exception as exc:
logger.debug(f"Failed to import structured model {type_key}: {exc}")
cls = None
if cls is not None:
if strategy == "to_dict" and hasattr(cls, "from_dict"):
with contextlib.suppress(Exception):
return cls.from_dict(decoded_payload)
if strategy == "to_json" and hasattr(cls, "from_json"):
if isinstance(decoded_payload, (str, bytes, bytearray)):
with contextlib.suppress(Exception):
return cls.from_json(decoded_payload)
if isinstance(decoded_payload, dict) and hasattr(cls, "from_dict"):
with contextlib.suppress(Exception):
return cls.from_dict(decoded_payload)
return decoded_payload
# Dataclass marker handling
if _DATACLASS_MARKER in value_dict and "value" in value_dict:
type_key_dc: str | None = value_dict.get(_DATACLASS_MARKER) # type: ignore[assignment]
raw_dc: Any = value_dict.get("value")
decoded_raw = _decode_checkpoint_value(raw_dc)
if isinstance(type_key_dc, str):
try:
module_name, class_name = type_key_dc.split(":", 1)
module = sys.modules.get(module_name)
if module is None:
module = importlib.import_module(module_name)
cls_dc: Any = getattr(module, class_name)
constructed = _instantiate_checkpoint_dataclass(cls_dc, decoded_raw)
if constructed is not None:
return constructed
except Exception as exc:
logger.debug(f"Failed to decode dataclass {type_key_dc}: {exc}; returning raw value")
return decoded_raw
# Regular dict: decode recursively
decoded: dict[str, Any] = {}
for k_any, v_any in value_dict.items():
decoded[k_any] = _decode_checkpoint_value(v_any)
return decoded
if isinstance(value, list):
# After isinstance check, treat value as list[Any] for decoding
value_list: list[Any] = value # type: ignore[assignment]
return [_decode_checkpoint_value(v_any) for v_any in value_list]
return value
@runtime_checkable
class RunnerContext(Protocol):
"""Protocol for the execution context used by the runner.
@@ -355,7 +116,7 @@ class RunnerContext(Protocol):
"""Wait for and return the next event emitted by the workflow run."""
...
async def set_state(self, executor_id: str, state: dict[str, Any]) -> None:
async def set_executor_state(self, executor_id: str, state: dict[str, Any]) -> None:
"""Set the state for a specific executor.
Args:
@@ -364,7 +125,7 @@ class RunnerContext(Protocol):
"""
...
async def get_state(self, executor_id: str) -> dict[str, Any] | None:
async def get_executor_state(self, executor_id: str) -> dict[str, Any] | None:
"""Get the state for a specific executor.
Args:
@@ -417,30 +178,19 @@ class RunnerContext(Protocol):
"""
...
async def restore_from_checkpoint(self, checkpoint_id: str) -> bool:
"""Restore the context from a checkpoint.
Args:
checkpoint_id: The ID of the checkpoint to restore from.
Returns:
True if the restoration was successful, False otherwise.
"""
...
async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None:
"""Load a checkpoint without mutating the current context state."""
...
async def get_checkpoint_state(self) -> CheckpointState:
"""Get the current state of the context suitable for checkpointing."""
async def get_workflow_state(self) -> WorkflowState:
"""Get the current state of the workflow suitable for checkpointing."""
...
async def set_checkpoint_state(self, state: CheckpointState) -> None:
"""Set the state of the context from a checkpoint.
async def set_workflow_state(self, state: WorkflowState) -> None:
"""Set the state of the workflow from a checkpoint.
Args:
state: The state data to set for the context.
state: The state data to set for the workflow.
"""
...
@@ -509,10 +259,10 @@ class InProcRunnerContext:
"""
return await self._event_queue.get()
async def set_state(self, executor_id: str, state: dict[str, Any]) -> None:
async def set_executor_state(self, executor_id: str, state: dict[str, Any]) -> None:
self._executor_states[executor_id] = state
async def get_state(self, executor_id: str) -> dict[str, Any] | None:
async def get_executor_state(self, executor_id: str) -> dict[str, Any] | None:
return self._executor_states.get(executor_id)
def has_checkpointing(self) -> bool:
@@ -554,7 +304,7 @@ class InProcRunnerContext:
wf_id = self._workflow_id or str(uuid.uuid4())
self._workflow_id = wf_id
state = await self.get_checkpoint_state()
state = await self.get_workflow_state()
checkpoint = WorkflowCheckpoint(
workflow_id=wf_id,
@@ -569,38 +319,17 @@ class InProcRunnerContext:
logger.info(f"Created checkpoint {checkpoint_id} for workflow {wf_id}'")
return checkpoint_id
async def restore_from_checkpoint(self, checkpoint_id: str) -> bool:
if not self._checkpoint_storage:
raise ValueError("Checkpoint storage not configured")
checkpoint = await self._checkpoint_storage.load_checkpoint(checkpoint_id)
if not checkpoint:
logger.error(f"Checkpoint {checkpoint_id} not found")
return False
state: CheckpointState = {
"messages": checkpoint.messages,
"shared_state": checkpoint.shared_state,
"executor_states": checkpoint.executor_states,
"iteration_count": checkpoint.iteration_count,
"max_iterations": checkpoint.max_iterations,
}
await self.set_checkpoint_state(state)
self._workflow_id = checkpoint.workflow_id
logger.info(f"Restored state from checkpoint {checkpoint_id}'")
return True
async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None:
if not self._checkpoint_storage:
raise ValueError("Checkpoint storage not configured")
return await self._checkpoint_storage.load_checkpoint(checkpoint_id)
async def get_checkpoint_state(self) -> CheckpointState:
async def get_workflow_state(self) -> WorkflowState:
serializable_messages: dict[str, list[dict[str, Any]]] = {}
for source_id, message_list in self._messages.items():
serializable_messages[source_id] = [
{
"data": _encode_checkpoint_value(msg.data),
"data": encode_checkpoint_value(msg.data),
"source_id": msg.source_id,
"target_id": msg.target_id,
"trace_contexts": msg.trace_contexts,
@@ -608,21 +337,22 @@ class InProcRunnerContext:
}
for msg in message_list
]
return {
"messages": serializable_messages,
"shared_state": _encode_checkpoint_value(self._shared_state),
"executor_states": _encode_checkpoint_value(self._executor_states),
"shared_state": encode_checkpoint_value(self._shared_state),
"executor_states": encode_checkpoint_value(self._executor_states),
"iteration_count": self._iteration_count,
"max_iterations": self._max_iterations,
}
async def set_checkpoint_state(self, state: CheckpointState) -> None:
async def set_workflow_state(self, state: WorkflowState) -> None:
self._messages.clear()
messages_data = state.get("messages", {})
for source_id, message_list in messages_data.items():
self._messages[source_id] = [
Message(
data=_decode_checkpoint_value(msg.get("data")),
data=decode_checkpoint_value(msg.get("data")),
source_id=msg.get("source_id", ""),
target_id=msg.get("target_id"),
trace_contexts=msg.get("trace_contexts"),
@@ -631,14 +361,14 @@ class InProcRunnerContext:
for msg in message_list
]
# Restore shared_state
decoded_shared_raw = _decode_checkpoint_value(state.get("shared_state", {}))
decoded_shared_raw = decode_checkpoint_value(state.get("shared_state", {}))
if isinstance(decoded_shared_raw, dict):
self._shared_state = cast(dict[str, Any], decoded_shared_raw)
else: # fallback to empty dict if corrupted
self._shared_state = {}
# Restore executor_states ensuring value types are dicts
decoded_exec_raw = _decode_checkpoint_value(state.get("executor_states", {}))
decoded_exec_raw = decode_checkpoint_value(state.get("executor_states", {}))
if isinstance(decoded_exec_raw, dict):
typed_exec: dict[str, dict[str, Any]] = {}
for k_raw, v_raw in decoded_exec_raw.items(): # type: ignore[assignment]
@@ -51,7 +51,8 @@ from ._executor import (
Executor,
handler,
)
from ._workflow import Workflow, WorkflowBuilder
from ._workflow import Workflow
from ._workflow_builder import WorkflowBuilder
from ._workflow_context import WorkflowContext
logger = logging.getLogger(__name__)
@@ -6,25 +6,16 @@ import json
import logging
import sys
import uuid
from collections.abc import AsyncIterable, Awaitable, Callable, Sequence
from collections.abc import AsyncIterable, Awaitable, Callable
from typing import Any
from .._agents import AgentProtocol
from ..observability import OtelAttr, capture_exception, create_workflow_span
from ._agent import WorkflowAgent
from ._agent_executor import AgentExecutor
from ._checkpoint import CheckpointStorage
from ._const import DEFAULT_MAX_ITERATIONS
from ._edge import (
Case,
Default,
EdgeGroup,
FanInEdgeGroup,
FanOutEdgeGroup,
SingleEdgeGroup,
SwitchCaseEdgeGroup,
SwitchCaseEdgeGroupCase,
SwitchCaseEdgeGroupDefault,
)
from ._events import (
RequestInfoEvent,
@@ -41,15 +32,14 @@ from ._executor import Executor
from ._model_utils import DictConvertible
from ._request_info_executor import RequestInfoExecutor
from ._runner import Runner
from ._runner_context import InProcRunnerContext, RunnerContext
from ._runner_context import RunnerContext
from ._shared_state import SharedState
from ._validation import validate_workflow_graph
from ._workflow_context import WorkflowContext
if sys.version_info >= (3, 11):
from typing import Self # pragma: no cover
pass # pragma: no cover
else:
from typing_extensions import Self # pragma: no cover
pass # pragma: no cover
logger = logging.getLogger(__name__)
@@ -858,422 +848,3 @@ class Workflow(DictConvertible):
from ._agent import WorkflowAgent
return WorkflowAgent(workflow=self, name=name)
# region WorkflowBuilder
class WorkflowBuilder:
"""A builder class for constructing workflows.
This class provides methods to add edges and set the starting executor for the workflow.
"""
def __init__(
self,
max_iterations: int = DEFAULT_MAX_ITERATIONS,
name: str | None = None,
description: str | None = None,
):
"""Initialize the WorkflowBuilder with an empty list of edges and no starting executor.
Args:
max_iterations: Maximum number of iterations for workflow convergence.
name: Optional human-readable name for the workflow.
description: Optional description of what the workflow does.
"""
self._edge_groups: list[EdgeGroup] = []
self._executors: dict[str, Executor] = {}
self._duplicate_executor_ids: set[str] = set()
self._start_executor: Executor | str | None = None
self._checkpoint_storage: CheckpointStorage | None = None
self._max_iterations: int = max_iterations
self._name: str | None = name
self._description: str | None = description
# Maps underlying AgentProtocol object id -> wrapped Executor so we reuse the same wrapper
# across set_start_executor / add_edge calls. Without this, unnamed agents (which receive
# random UUID based executor ids) end up wrapped multiple times, giving different ids for
# the start node vs edge nodes and triggering a GraphConnectivityError during validation.
self._agent_wrappers: dict[int, Executor] = {}
# Agents auto-wrapped by builder now always stream incremental updates.
def _add_executor(self, executor: Executor) -> str:
"""Add an executor to the map and return its ID."""
existing = self._executors.get(executor.id)
if existing is not None and existing is not executor:
self._duplicate_executor_ids.add(executor.id)
else:
self._executors[executor.id] = executor
return executor.id
def _maybe_wrap_agent(
self,
candidate: Executor | AgentProtocol,
agent_thread: Any | None = None,
output_response: bool = False,
executor_id: str | None = None,
) -> Executor:
"""If the provided object implements AgentProtocol, wrap it in an AgentExecutor.
This allows fluent builder APIs to directly accept agents instead of
requiring callers to manually instantiate AgentExecutor.
Args:
candidate: The executor or agent to wrap.
agent_thread: The thread to use for running the agent. If None, a new thread will be created.
output_response: Whether to yield an AgentRunResponse as a workflow output when the agent completes.
executor_id: A unique identifier for the executor. If None, the agent's name will be used if available.
"""
try: # Local import to avoid hard dependency at import time
from agent_framework import AgentProtocol # type: ignore
except Exception: # pragma: no cover - defensive
AgentProtocol = object # type: ignore
if isinstance(candidate, Executor): # Already an executor
return candidate
if isinstance(candidate, AgentProtocol): # type: ignore[arg-type]
# Reuse existing wrapper for the same agent instance if present
agent_instance_id = id(candidate)
existing = self._agent_wrappers.get(agent_instance_id)
if existing is not None:
return existing
# Use agent name if available and unique among current executors
name = getattr(candidate, "name", None)
proposed_id: str | None = executor_id
if proposed_id is None and name:
proposed_id = str(name)
if proposed_id in self._executors:
raise ValueError(
f"Duplicate executor ID '{proposed_id}' from agent name. "
"Agent names must be unique within a workflow."
)
wrapper = AgentExecutor(
candidate,
agent_thread=agent_thread,
output_response=output_response,
id=proposed_id,
)
self._agent_wrappers[agent_instance_id] = wrapper
return wrapper
raise TypeError(
f"WorkflowBuilder expected an Executor or AgentProtocol instance; got {type(candidate).__name__}."
)
def add_agent(
self,
agent: AgentProtocol,
agent_thread: Any | None = None,
output_response: bool = False,
id: str | None = None,
) -> Self:
"""Add an agent to the workflow by wrapping it in an AgentExecutor.
This method creates an AgentExecutor that wraps the agent with the given parameters
and ensures that subsequent uses of the same agent instance in other builder methods
(like add_edge, set_start_executor, etc.) will reuse the same wrapped executor.
Note: Agents adapt their behavior based on how the workflow is executed:
- run_stream(): Agents emit incremental AgentRunUpdateEvent events as tokens are produced
- run(): Agents emit a single AgentRunEvent containing the complete response
Args:
agent: The agent to add to the workflow.
agent_thread: The thread to use for running the agent. If None, a new thread will be created.
output_response: Whether to yield an AgentRunResponse as a workflow output when the agent completes.
id: A unique identifier for the executor. If None, the agent's name will be used if available.
Returns:
The WorkflowBuilder instance (for method chaining).
Raises:
ValueError: If the provided id or agent name conflicts with an existing executor.
"""
executor = self._maybe_wrap_agent(
agent, agent_thread=agent_thread, output_response=output_response, executor_id=id
)
self._add_executor(executor)
return self
def add_edge(
self,
source: Executor | AgentProtocol,
target: Executor | AgentProtocol,
condition: Callable[[Any], bool] | None = None,
) -> Self:
"""Add a directed edge between two executors.
The output types of the source and the input types of the target must be compatible.
Args:
source: The source executor of the edge.
target: The target executor of the edge.
condition: An optional condition function that determines whether the edge
should be traversed based on the message type.
"""
# TODO(@taochen): Support executor factories for lazy initialization
source_exec = self._maybe_wrap_agent(source)
target_exec = self._maybe_wrap_agent(target)
source_id = self._add_executor(source_exec)
target_id = self._add_executor(target_exec)
self._edge_groups.append(SingleEdgeGroup(source_id, target_id, condition)) # type: ignore[call-arg]
return self
def add_fan_out_edges(
self,
source: Executor | AgentProtocol,
targets: Sequence[Executor | AgentProtocol],
) -> Self:
"""Add multiple edges to the workflow where messages from the source will be sent to all target.
The output types of the source and the input types of the targets must be compatible.
Args:
source: The source executor of the edges.
targets: A list of target executors for the edges.
"""
source_exec = self._maybe_wrap_agent(source)
target_execs = [self._maybe_wrap_agent(t) for t in targets]
source_id = self._add_executor(source_exec)
target_ids = [self._add_executor(t) for t in target_execs]
self._edge_groups.append(FanOutEdgeGroup(source_id, target_ids)) # type: ignore[call-arg]
return self
def add_switch_case_edge_group(
self,
source: Executor | AgentProtocol,
cases: Sequence[Case | Default],
) -> Self:
"""Add an edge group that represents a switch-case statement.
The output types of the source and the input types of the targets must be compatible.
Messages from the source executor will be sent to one of the target executors based on
the provided conditions.
Think of this as a switch statement where each target executor corresponds to a case.
Each condition function will be evaluated in order, and the first one that returns True
will determine which target executor receives the message.
The last case (the default case) will receive messages that fall through all conditions
(i.e., no condition matched).
Args:
source: The source executor of the edges.
cases: A list of case objects that determine the target executor for each message.
"""
source_exec = self._maybe_wrap_agent(source)
source_id = self._add_executor(source_exec)
# Convert case data types to internal types that only uses target_id.
internal_cases: list[SwitchCaseEdgeGroupCase | SwitchCaseEdgeGroupDefault] = []
for case in cases:
# Allow case targets to be agents
case.target = self._maybe_wrap_agent(case.target) # type: ignore[attr-defined]
self._add_executor(case.target)
if isinstance(case, Default):
internal_cases.append(SwitchCaseEdgeGroupDefault(target_id=case.target.id))
else:
internal_cases.append(SwitchCaseEdgeGroupCase(condition=case.condition, target_id=case.target.id))
self._edge_groups.append(SwitchCaseEdgeGroup(source_id, internal_cases)) # type: ignore[call-arg]
return self
def add_multi_selection_edge_group(
self,
source: Executor | AgentProtocol,
targets: Sequence[Executor | AgentProtocol],
selection_func: Callable[[Any, list[str]], list[str]],
) -> Self:
"""Add an edge group that represents a multi-selection execution model.
The output types of the source and the input types of the targets must be compatible.
Messages from the source executor will be sent to multiple target executors based on
the provided selection function.
The selection function should take a message and the name of the target executors,
and return a list of indices indicating which target executors should receive the message.
Args:
source: The source executor of the edges.
targets: A list of target executors for the edges.
selection_func: A function that selects target executors for messages.
"""
source_exec = self._maybe_wrap_agent(source)
target_execs = [self._maybe_wrap_agent(t) for t in targets]
source_id = self._add_executor(source_exec)
target_ids = [self._add_executor(t) for t in target_execs]
self._edge_groups.append(FanOutEdgeGroup(source_id, target_ids, selection_func)) # type: ignore[call-arg]
return self
def add_fan_in_edges(
self,
sources: Sequence[Executor | AgentProtocol],
target: Executor | AgentProtocol,
) -> Self:
"""Add multiple edges from sources to a single target executor.
The edges will be grouped together for synchronized processing, meaning
the target executor will only be executed once all source executors have completed.
The target executor will receive a list of messages aggregated from all source executors.
Thus the input types of the target executor must be compatible with a list of the output
types of the source executors. For example:
class Target(Executor):
@handler
def handle_messages(self, messages: list[Message]) -> None:
# Process the aggregated messages from all sources
class Source(Executor):
@handler(output_type=[Message])
def handle_message(self, message: Message) -> None:
# Send a message to the target executor
self.send_message(message)
workflow = (
WorkflowBuilder()
.add_fan_in_edges(
[Source(id="source1"), Source(id="source2")],
Target(id="target")
)
.build()
)
Args:
sources: A list of source executors for the edges.
target: The target executor for the edges.
"""
source_execs = [self._maybe_wrap_agent(s) for s in sources]
target_exec = self._maybe_wrap_agent(target)
source_ids = [self._add_executor(s) for s in source_execs]
target_id = self._add_executor(target_exec)
self._edge_groups.append(FanInEdgeGroup(source_ids, target_id)) # type: ignore[call-arg]
return self
def add_chain(self, executors: Sequence[Executor | AgentProtocol]) -> Self:
"""Add a chain of executors to the workflow.
The output of each executor in the chain will be sent to the next executor in the chain.
The input types of each executor must be compatible with the output types of the previous executor.
Circles in the chain are not allowed, meaning the chain cannot have two executors with the same ID.
Args:
executors: A list of executors to be added to the chain.
"""
# Wrap each candidate first to ensure stable IDs before adding edges
wrapped: list[Executor] = [self._maybe_wrap_agent(e) for e in executors]
for i in range(len(wrapped) - 1):
self.add_edge(wrapped[i], wrapped[i + 1])
return self
def set_start_executor(self, executor: Executor | AgentProtocol | str) -> Self:
"""Set the starting executor for the workflow.
Args:
executor: The starting executor, which can be an Executor instance or its ID.
"""
if isinstance(executor, str):
self._start_executor = executor
else:
wrapped = self._maybe_wrap_agent(executor) # type: ignore[arg-type]
self._start_executor = wrapped
# Ensure the start executor is present in the executor map so validation succeeds
# even if no edges are added yet, or before edges wrap the same agent again.
existing = self._executors.get(wrapped.id)
if existing is not wrapped:
self._add_executor(wrapped)
return self
def set_max_iterations(self, max_iterations: int) -> Self:
"""Set the maximum number of iterations for the workflow.
Args:
max_iterations: The maximum number of iterations the workflow will run for convergence.
"""
self._max_iterations = max_iterations
return self
# Removed explicit set_agent_streaming() API; agents always stream updates.
def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> Self:
"""Enable checkpointing with the specified storage.
Args:
checkpoint_storage: The checkpoint storage to use.
"""
self._checkpoint_storage = checkpoint_storage
return self
def build(self) -> Workflow:
"""Build and return the constructed workflow.
This method performs validation before building the workflow.
Returns:
A Workflow instance with the defined edges and starting executor.
Raises:
ValueError: If starting executor is not set.
WorkflowValidationError: If workflow validation fails (includes EdgeDuplicationError,
TypeCompatibilityError, and GraphConnectivityError subclasses).
"""
# Create workflow build span that includes validation and workflow creation
with create_workflow_span(OtelAttr.WORKFLOW_BUILD_SPAN) as span:
try:
# Add workflow build started event
span.add_event(OtelAttr.BUILD_STARTED)
if not self._start_executor:
raise ValueError(
"Starting executor must be set using set_start_executor before building the workflow."
)
# Perform validation before creating the workflow
validate_workflow_graph(
self._edge_groups,
self._executors,
self._start_executor,
duplicate_executor_ids=tuple(self._duplicate_executor_ids),
)
# Add validation completed event
span.add_event(OtelAttr.BUILD_VALIDATION_COMPLETED)
context = InProcRunnerContext(self._checkpoint_storage)
# Create workflow instance after validation
workflow = Workflow(
self._edge_groups,
self._executors,
self._start_executor,
context,
self._max_iterations,
name=self._name,
description=self._description,
)
build_attributes: dict[str, Any] = {
OtelAttr.WORKFLOW_ID: workflow.id,
OtelAttr.WORKFLOW_DEFINITION: workflow.to_json(),
}
if workflow.name:
build_attributes[OtelAttr.WORKFLOW_NAME] = workflow.name
if workflow.description:
build_attributes[OtelAttr.WORKFLOW_DESCRIPTION] = workflow.description
span.set_attributes(build_attributes)
# Add workflow build completed event
span.add_event(OtelAttr.BUILD_COMPLETED)
return workflow
except Exception as exc:
attributes = {
OtelAttr.BUILD_ERROR_MESSAGE: str(exc),
OtelAttr.BUILD_ERROR_TYPE: type(exc).__name__,
}
span.add_event(OtelAttr.BUILD_ERROR, attributes) # type: ignore[reportArgumentType, arg-type]
capture_exception(span, exc)
raise
@@ -0,0 +1,451 @@
# Copyright (c) Microsoft. All rights reserved.
import logging
import sys
from collections.abc import Callable, Sequence
from typing import Any
from .._agents import AgentProtocol
from ..observability import OtelAttr, capture_exception, create_workflow_span
from ._agent_executor import AgentExecutor
from ._checkpoint import CheckpointStorage
from ._const import DEFAULT_MAX_ITERATIONS
from ._edge import (
Case,
Default,
EdgeGroup,
FanInEdgeGroup,
FanOutEdgeGroup,
SingleEdgeGroup,
SwitchCaseEdgeGroup,
SwitchCaseEdgeGroupCase,
SwitchCaseEdgeGroupDefault,
)
from ._executor import Executor
from ._runner_context import InProcRunnerContext
from ._validation import validate_workflow_graph
from ._workflow import Workflow
if sys.version_info >= (3, 11):
from typing import Self # pragma: no cover
else:
from typing_extensions import Self # pragma: no cover
logger = logging.getLogger(__name__)
class WorkflowBuilder:
"""A builder class for constructing workflows.
This class provides methods to add edges and set the starting executor for the workflow.
"""
def __init__(
self,
max_iterations: int = DEFAULT_MAX_ITERATIONS,
name: str | None = None,
description: str | None = None,
):
"""Initialize the WorkflowBuilder with an empty list of edges and no starting executor.
Args:
max_iterations: Maximum number of iterations for workflow convergence.
name: Optional human-readable name for the workflow.
description: Optional description of what the workflow does.
"""
self._edge_groups: list[EdgeGroup] = []
self._executors: dict[str, Executor] = {}
self._duplicate_executor_ids: set[str] = set()
self._start_executor: Executor | str | None = None
self._checkpoint_storage: CheckpointStorage | None = None
self._max_iterations: int = max_iterations
self._name: str | None = name
self._description: str | None = description
# Maps underlying AgentProtocol object id -> wrapped Executor so we reuse the same wrapper
# across set_start_executor / add_edge calls. Without this, unnamed agents (which receive
# random UUID based executor ids) end up wrapped multiple times, giving different ids for
# the start node vs edge nodes and triggering a GraphConnectivityError during validation.
self._agent_wrappers: dict[int, Executor] = {}
# Agents auto-wrapped by builder now always stream incremental updates.
def _add_executor(self, executor: Executor) -> str:
"""Add an executor to the map and return its ID."""
existing = self._executors.get(executor.id)
if existing is not None and existing is not executor:
self._duplicate_executor_ids.add(executor.id)
else:
self._executors[executor.id] = executor
return executor.id
def _maybe_wrap_agent(
self,
candidate: Executor | AgentProtocol,
agent_thread: Any | None = None,
output_response: bool = False,
executor_id: str | None = None,
) -> Executor:
"""If the provided object implements AgentProtocol, wrap it in an AgentExecutor.
This allows fluent builder APIs to directly accept agents instead of
requiring callers to manually instantiate AgentExecutor.
Args:
candidate: The executor or agent to wrap.
agent_thread: The thread to use for running the agent. If None, a new thread will be created.
output_response: Whether to yield an AgentRunResponse as a workflow output when the agent completes.
executor_id: A unique identifier for the executor. If None, the agent's name will be used if available.
"""
try: # Local import to avoid hard dependency at import time
from agent_framework import AgentProtocol # type: ignore
except Exception: # pragma: no cover - defensive
AgentProtocol = object # type: ignore
if isinstance(candidate, Executor): # Already an executor
return candidate
if isinstance(candidate, AgentProtocol): # type: ignore[arg-type]
# Reuse existing wrapper for the same agent instance if present
agent_instance_id = id(candidate)
existing = self._agent_wrappers.get(agent_instance_id)
if existing is not None:
return existing
# Use agent name if available and unique among current executors
name = getattr(candidate, "name", None)
proposed_id: str | None = executor_id
if proposed_id is None and name:
proposed_id = str(name)
if proposed_id in self._executors:
raise ValueError(
f"Duplicate executor ID '{proposed_id}' from agent name. "
"Agent names must be unique within a workflow."
)
wrapper = AgentExecutor(
candidate,
agent_thread=agent_thread,
output_response=output_response,
id=proposed_id,
)
self._agent_wrappers[agent_instance_id] = wrapper
return wrapper
raise TypeError(
f"WorkflowBuilder expected an Executor or AgentProtocol instance; got {type(candidate).__name__}."
)
def add_agent(
self,
agent: AgentProtocol,
agent_thread: Any | None = None,
output_response: bool = False,
id: str | None = None,
) -> Self:
"""Add an agent to the workflow by wrapping it in an AgentExecutor.
This method creates an AgentExecutor that wraps the agent with the given parameters
and ensures that subsequent uses of the same agent instance in other builder methods
(like add_edge, set_start_executor, etc.) will reuse the same wrapped executor.
Note: Agents adapt their behavior based on how the workflow is executed:
- run_stream(): Agents emit incremental AgentRunUpdateEvent events as tokens are produced
- run(): Agents emit a single AgentRunEvent containing the complete response
Args:
agent: The agent to add to the workflow.
agent_thread: The thread to use for running the agent. If None, a new thread will be created.
output_response: Whether to yield an AgentRunResponse as a workflow output when the agent completes.
id: A unique identifier for the executor. If None, the agent's name will be used if available.
Returns:
The WorkflowBuilder instance (for method chaining).
Raises:
ValueError: If the provided id or agent name conflicts with an existing executor.
"""
executor = self._maybe_wrap_agent(
agent, agent_thread=agent_thread, output_response=output_response, executor_id=id
)
self._add_executor(executor)
return self
def add_edge(
self,
source: Executor | AgentProtocol,
target: Executor | AgentProtocol,
condition: Callable[[Any], bool] | None = None,
) -> Self:
"""Add a directed edge between two executors.
The output types of the source and the input types of the target must be compatible.
Args:
source: The source executor of the edge.
target: The target executor of the edge.
condition: An optional condition function that determines whether the edge
should be traversed based on the message type.
"""
# TODO(@taochen): Support executor factories for lazy initialization
source_exec = self._maybe_wrap_agent(source)
target_exec = self._maybe_wrap_agent(target)
source_id = self._add_executor(source_exec)
target_id = self._add_executor(target_exec)
self._edge_groups.append(SingleEdgeGroup(source_id, target_id, condition)) # type: ignore[call-arg]
return self
def add_fan_out_edges(
self,
source: Executor | AgentProtocol,
targets: Sequence[Executor | AgentProtocol],
) -> Self:
"""Add multiple edges to the workflow where messages from the source will be sent to all target.
The output types of the source and the input types of the targets must be compatible.
Args:
source: The source executor of the edges.
targets: A list of target executors for the edges.
"""
source_exec = self._maybe_wrap_agent(source)
target_execs = [self._maybe_wrap_agent(t) for t in targets]
source_id = self._add_executor(source_exec)
target_ids = [self._add_executor(t) for t in target_execs]
self._edge_groups.append(FanOutEdgeGroup(source_id, target_ids)) # type: ignore[call-arg]
return self
def add_switch_case_edge_group(
self,
source: Executor | AgentProtocol,
cases: Sequence[Case | Default],
) -> Self:
"""Add an edge group that represents a switch-case statement.
The output types of the source and the input types of the targets must be compatible.
Messages from the source executor will be sent to one of the target executors based on
the provided conditions.
Think of this as a switch statement where each target executor corresponds to a case.
Each condition function will be evaluated in order, and the first one that returns True
will determine which target executor receives the message.
The last case (the default case) will receive messages that fall through all conditions
(i.e., no condition matched).
Args:
source: The source executor of the edges.
cases: A list of case objects that determine the target executor for each message.
"""
source_exec = self._maybe_wrap_agent(source)
source_id = self._add_executor(source_exec)
# Convert case data types to internal types that only uses target_id.
internal_cases: list[SwitchCaseEdgeGroupCase | SwitchCaseEdgeGroupDefault] = []
for case in cases:
# Allow case targets to be agents
case.target = self._maybe_wrap_agent(case.target) # type: ignore[attr-defined]
self._add_executor(case.target)
if isinstance(case, Default):
internal_cases.append(SwitchCaseEdgeGroupDefault(target_id=case.target.id))
else:
internal_cases.append(SwitchCaseEdgeGroupCase(condition=case.condition, target_id=case.target.id))
self._edge_groups.append(SwitchCaseEdgeGroup(source_id, internal_cases)) # type: ignore[call-arg]
return self
def add_multi_selection_edge_group(
self,
source: Executor | AgentProtocol,
targets: Sequence[Executor | AgentProtocol],
selection_func: Callable[[Any, list[str]], list[str]],
) -> Self:
"""Add an edge group that represents a multi-selection execution model.
The output types of the source and the input types of the targets must be compatible.
Messages from the source executor will be sent to multiple target executors based on
the provided selection function.
The selection function should take a message and the name of the target executors,
and return a list of indices indicating which target executors should receive the message.
Args:
source: The source executor of the edges.
targets: A list of target executors for the edges.
selection_func: A function that selects target executors for messages.
"""
source_exec = self._maybe_wrap_agent(source)
target_execs = [self._maybe_wrap_agent(t) for t in targets]
source_id = self._add_executor(source_exec)
target_ids = [self._add_executor(t) for t in target_execs]
self._edge_groups.append(FanOutEdgeGroup(source_id, target_ids, selection_func)) # type: ignore[call-arg]
return self
def add_fan_in_edges(
self,
sources: Sequence[Executor | AgentProtocol],
target: Executor | AgentProtocol,
) -> Self:
"""Add multiple edges from sources to a single target executor.
The edges will be grouped together for synchronized processing, meaning
the target executor will only be executed once all source executors have completed.
The target executor will receive a list of messages aggregated from all source executors.
Thus the input types of the target executor must be compatible with a list of the output
types of the source executors. For example:
class Target(Executor):
@handler
def handle_messages(self, messages: list[Message]) -> None:
# Process the aggregated messages from all sources
class Source(Executor):
@handler(output_type=[Message])
def handle_message(self, message: Message) -> None:
# Send a message to the target executor
self.send_message(message)
workflow = (
WorkflowBuilder()
.add_fan_in_edges(
[Source(id="source1"), Source(id="source2")],
Target(id="target")
)
.build()
)
Args:
sources: A list of source executors for the edges.
target: The target executor for the edges.
"""
source_execs = [self._maybe_wrap_agent(s) for s in sources]
target_exec = self._maybe_wrap_agent(target)
source_ids = [self._add_executor(s) for s in source_execs]
target_id = self._add_executor(target_exec)
self._edge_groups.append(FanInEdgeGroup(source_ids, target_id)) # type: ignore[call-arg]
return self
def add_chain(self, executors: Sequence[Executor | AgentProtocol]) -> Self:
"""Add a chain of executors to the workflow.
The output of each executor in the chain will be sent to the next executor in the chain.
The input types of each executor must be compatible with the output types of the previous executor.
Circles in the chain are not allowed, meaning the chain cannot have two executors with the same ID.
Args:
executors: A list of executors to be added to the chain.
"""
# Wrap each candidate first to ensure stable IDs before adding edges
wrapped: list[Executor] = [self._maybe_wrap_agent(e) for e in executors]
for i in range(len(wrapped) - 1):
self.add_edge(wrapped[i], wrapped[i + 1])
return self
def set_start_executor(self, executor: Executor | AgentProtocol | str) -> Self:
"""Set the starting executor for the workflow.
Args:
executor: The starting executor, which can be an Executor instance or its ID.
"""
if isinstance(executor, str):
self._start_executor = executor
else:
wrapped = self._maybe_wrap_agent(executor) # type: ignore[arg-type]
self._start_executor = wrapped
# Ensure the start executor is present in the executor map so validation succeeds
# even if no edges are added yet, or before edges wrap the same agent again.
existing = self._executors.get(wrapped.id)
if existing is not wrapped:
self._add_executor(wrapped)
return self
def set_max_iterations(self, max_iterations: int) -> Self:
"""Set the maximum number of iterations for the workflow.
Args:
max_iterations: The maximum number of iterations the workflow will run for convergence.
"""
self._max_iterations = max_iterations
return self
# Removed explicit set_agent_streaming() API; agents always stream updates.
def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> Self:
"""Enable checkpointing with the specified storage.
Args:
checkpoint_storage: The checkpoint storage to use.
"""
self._checkpoint_storage = checkpoint_storage
return self
def build(self) -> Workflow:
"""Build and return the constructed workflow.
This method performs validation before building the workflow.
Returns:
A Workflow instance with the defined edges and starting executor.
Raises:
ValueError: If starting executor is not set.
WorkflowValidationError: If workflow validation fails (includes EdgeDuplicationError,
TypeCompatibilityError, and GraphConnectivityError subclasses).
"""
# Create workflow build span that includes validation and workflow creation
with create_workflow_span(OtelAttr.WORKFLOW_BUILD_SPAN) as span:
try:
# Add workflow build started event
span.add_event(OtelAttr.BUILD_STARTED)
if not self._start_executor:
raise ValueError(
"Starting executor must be set using set_start_executor before building the workflow."
)
# Perform validation before creating the workflow
validate_workflow_graph(
self._edge_groups,
self._executors,
self._start_executor,
duplicate_executor_ids=tuple(self._duplicate_executor_ids),
)
# Add validation completed event
span.add_event(OtelAttr.BUILD_VALIDATION_COMPLETED)
context = InProcRunnerContext(self._checkpoint_storage)
# Create workflow instance after validation
workflow = Workflow(
self._edge_groups,
self._executors,
self._start_executor,
context,
self._max_iterations,
name=self._name,
description=self._description,
)
build_attributes: dict[str, Any] = {
OtelAttr.WORKFLOW_ID: workflow.id,
OtelAttr.WORKFLOW_DEFINITION: workflow.to_json(),
}
if workflow.name:
build_attributes[OtelAttr.WORKFLOW_NAME] = workflow.name
if workflow.description:
build_attributes[OtelAttr.WORKFLOW_DESCRIPTION] = workflow.description
span.set_attributes(build_attributes)
# Add workflow build completed event
span.add_event(OtelAttr.BUILD_COMPLETED)
return workflow
except Exception as exc:
attributes = {
OtelAttr.BUILD_ERROR_MESSAGE: str(exc),
OtelAttr.BUILD_ERROR_TYPE: type(exc).__name__,
}
span.add_event(OtelAttr.BUILD_ERROR, attributes) # type: ignore[reportArgumentType, arg-type]
capture_exception(span, exc)
raise
@@ -435,17 +435,17 @@ class WorkflowContext(Generic[T_Out, T_W_Out]):
"""Get the shared state."""
return self._shared_state
async def set_state(self, state: dict[str, Any]) -> None:
async def set_executor_state(self, state: dict[str, Any]) -> None:
"""Persist this executor's state into the checkpointable context.
Executors call this with a JSON-serializable dict capturing the minimal
state needed to resume. It replaces any previously stored state.
"""
await self._runner_context.set_state(self._executor_id, state)
await self._runner_context.set_executor_state(self._executor_id, state)
async def get_state(self) -> dict[str, Any] | None:
async def get_executor_state(self) -> dict[str, Any] | None:
"""Retrieve previously persisted state for this executor, if any."""
return await self._runner_context.get_state(self._executor_id)
return await self._runner_context.get_executor_state(self._executor_id)
def is_streaming(self) -> bool:
"""Check if the workflow is running in streaming mode.
@@ -167,9 +167,9 @@ class WorkflowExecutor(Executor):
@handler
async def process(self, data: str, ctx: WorkflowContext[str]) -> None:
# Use context state instead of instance variables
state = await ctx.get_state() or {}
state = await ctx.get_executor_state() or {}
state["processed"] = data
await ctx.set_state(state)
await ctx.set_executor_state(state)
# Avoid: Stateful executor with instance variables
@@ -501,7 +501,7 @@ class WorkflowExecutor(Executor):
state: dict[str, Any] | None = None
try:
state = await ctx.get_state()
state = await ctx.get_executor_state()
except Exception:
state = None
@@ -665,6 +665,6 @@ class WorkflowExecutor(Executor):
async def _persist_execution_state(self, ctx: WorkflowContext[Any]) -> None:
snapshot = self._build_state_snapshot()
try:
await ctx.set_state(snapshot)
await ctx.set_executor_state(snapshot)
except Exception as exc: # pragma: no cover - transport specific
logger.warning(f"WorkflowExecutor {self.id} failed to persist state: {exc}")
@@ -4,9 +4,9 @@ from dataclasses import dataclass # noqa: I001
from typing import Any, cast
from agent_framework._workflows._request_info_executor import RequestInfoMessage, RequestResponse
from agent_framework._workflows._runner_context import ( # type: ignore
_decode_checkpoint_value, # type: ignore
_encode_checkpoint_value, # type: ignore
from agent_framework._workflows._checkpoint_encoding import (
decode_checkpoint_value,
encode_checkpoint_value,
)
from agent_framework._workflows._typing_utils import is_instance_of
@@ -23,8 +23,8 @@ def test_decode_dataclass_with_nested_request() -> None:
request_id="abc",
)
encoded = _encode_checkpoint_value(original)
decoded = cast(RequestResponse[SampleRequest, str], _decode_checkpoint_value(encoded))
encoded = encode_checkpoint_value(original)
decoded = cast(RequestResponse[SampleRequest, str], decode_checkpoint_value(encoded))
assert isinstance(decoded, RequestResponse)
assert decoded.data == "approve"
@@ -5,7 +5,8 @@ from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any
from agent_framework._workflows._checkpoint import CheckpointStorage, WorkflowCheckpoint
from agent_framework._workflows._checkpoint import WorkflowCheckpoint
from agent_framework._workflows._checkpoint_encoding import encode_checkpoint_value
from agent_framework._workflows._checkpoint_summary import get_checkpoint_summary
from agent_framework._workflows._events import RequestInfoEvent, WorkflowEvent
from agent_framework._workflows._request_info_executor import (
@@ -16,9 +17,8 @@ from agent_framework._workflows._request_info_executor import (
RequestResponse,
)
from agent_framework._workflows._runner_context import (
CheckpointState,
Message,
_encode_checkpoint_value, # type: ignore
WorkflowState,
)
from agent_framework._workflows._shared_state import SharedState
from agent_framework._workflows._workflow_context import WorkflowContext
@@ -53,10 +53,10 @@ class _StubRunnerContext:
async def next_event(self) -> WorkflowEvent: # pragma: no cover - unused
raise RuntimeError("Not implemented in stub context")
async def get_state(self, executor_id: str) -> dict[str, Any] | None: # pragma: no cover - trivial
async def get_executor_state(self, executor_id: str) -> dict[str, Any] | None: # pragma: no cover - trivial
return self._state
async def set_state(self, executor_id: str, state: dict[str, Any]) -> None: # pragma: no cover - unused
async def set_executor_state(self, executor_id: str, state: dict[str, Any]) -> None: # pragma: no cover - unused
self._state = state
def has_checkpointing(self) -> bool: # pragma: no cover - unused
@@ -71,20 +71,13 @@ class _StubRunnerContext:
async def create_checkpoint(self, metadata: dict[str, Any] | None = None) -> str: # pragma: no cover - unused
raise RuntimeError("Checkpointing not supported in stub context")
async def restore_from_checkpoint(
self,
checkpoint_id: str,
checkpoint_storage: CheckpointStorage | None = None,
) -> bool: # pragma: no cover - unused
return False
async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None: # pragma: no cover - unused
return None
async def get_checkpoint_state(self) -> CheckpointState: # pragma: no cover - unused
async def get_workflow_state(self) -> WorkflowState: # pragma: no cover - unused
return {} # type: ignore[return-value]
async def set_checkpoint_state(self, state: CheckpointState) -> None: # pragma: no cover - unused
async def set_workflow_state(self, state: WorkflowState) -> None: # pragma: no cover - unused
pass
def set_streaming(self, streaming: bool) -> None: # pragma: no cover - unused
@@ -178,7 +171,7 @@ def test_pending_requests_from_checkpoint_and_summary() -> None:
request_id=request.request_id,
)
encoded_response = _encode_checkpoint_value(response)
encoded_response = encode_checkpoint_value(response)
checkpoint = WorkflowCheckpoint(
checkpoint_id="cp-1",
@@ -439,8 +439,8 @@ async def test_message_trace_context_serialization(span_exporter: InMemorySpanEx
await ctx.send_message(message)
# Get checkpoint state (which serializes messages)
state = await ctx.get_checkpoint_state()
# Get context state (which serializes messages)
state = await ctx.get_workflow_state()
# Check serialized message includes trace context
serialized_msg = state["messages"]["source"][0]
@@ -448,7 +448,7 @@ async def test_message_trace_context_serialization(span_exporter: InMemorySpanEx
assert serialized_msg["source_span_ids"] == ["span123"]
# Test deserialization
await ctx.set_checkpoint_state(state)
await ctx.set_workflow_state(state)
restored_messages = await ctx.drain_messages()
restored_msg = list(restored_messages.values())[0][0]
@@ -140,8 +140,8 @@ class ReviewGateway(Executor):
# persist iterations. The `RequestInfoExecutor` relies on this state to
# rehydrate when checkpoints are restored.
draft = response.agent_run_response.text or ""
iteration = int((await ctx.get_state() or {}).get("iteration", 0)) + 1
await ctx.set_state({"iteration": iteration, "last_draft": draft})
iteration = int((await ctx.get_executor_state() or {}).get("iteration", 0)) + 1
await ctx.set_executor_state({"iteration": iteration, "last_draft": draft})
# Emit a human approval request. Because this flows through
# RequestInfoExecutor it will pause the workflow until an answer is
# supplied either interactively or via pre-supplied responses.
@@ -163,7 +163,7 @@ class ReviewGateway(Executor):
# The RequestResponse wrapper gives us both the human data and the
# original request message, even when resuming from checkpoints.
reply = (feedback.data or "").strip()
state = await ctx.get_state() or {}
state = await ctx.get_executor_state() or {}
draft = state.get("last_draft") or (feedback.original_request.draft if feedback.original_request else "")
if reply.lower() == "approve":
@@ -175,7 +175,7 @@ class ReviewGateway(Executor):
# Any other response loops us back to the writer with fresh guidance.
guidance = reply or "Tighten the copy and emphasise customer benefit."
iteration = int(state.get("iteration", 1)) + 1
await ctx.set_state({"iteration": iteration, "last_draft": draft})
await ctx.set_executor_state({"iteration": iteration, "last_draft": draft})
prompt = (
"Revise the launch note. Respond with the new copy only.\n\n"
f"Previous draft:\n{draft}\n\n"
@@ -193,7 +193,7 @@ class FinaliseExecutor(Executor):
@handler
async def publish(self, text: str, ctx: WorkflowContext[Any, str]) -> None:
# Store the output so diagnostics or a UI could fetch the final copy.
await ctx.set_state({"published_text": text})
await ctx.set_executor_state({"published_text": text})
# Yield the final output so the workflow completes cleanly.
await ctx.yield_output(text)
@@ -42,7 +42,7 @@ Pipeline:
5) FinalizeFromAgent yields the final result.
What you learn:
- How to persist executor state using ctx.get_state and ctx.set_state.
- How to persist executor state using ctx.get_executor_state and ctx.set_executor_state.
- How to persist shared workflow state using ctx.set_shared_state for cross-executor visibility.
- How to configure FileCheckpointStorage and call with_checkpointing on WorkflowBuilder.
- How to list and inspect checkpoints programmatically.
@@ -73,9 +73,9 @@ class UpperCaseExecutor(Executor):
# Persist executor-local state so it is captured in checkpoints
# and available after resume for observability or logic.
prev = await ctx.get_state() or {}
prev = await ctx.get_executor_state() or {}
count = int(prev.get("count", 0)) + 1
await ctx.set_state({
await ctx.set_executor_state({
"count": count,
"last_input": text,
"last_output": result,
@@ -122,9 +122,9 @@ class FinalizeFromAgent(Executor):
result = response.agent_run_response.text or ""
# Persist executor-local state for auditability when inspecting checkpoints.
prev = await ctx.get_state() or {}
prev = await ctx.get_executor_state() or {}
count = int(prev.get("count", 0)) + 1
await ctx.set_state({
await ctx.set_executor_state({
"count": count,
"last_output": result,
"final": True,
@@ -143,9 +143,9 @@ class ReverseTextExecutor(Executor):
print(f"ReverseTextExecutor: '{text}' -> '{result}'")
# Persist executor-local state so checkpoint inspection can reveal progress.
prev = await ctx.get_state() or {}
prev = await ctx.get_executor_state() or {}
count = int(prev.get("count", 0)) + 1
await ctx.set_state({
await ctx.set_executor_state({
"count": count,
"last_input": text,
"last_output": result,
@@ -144,7 +144,7 @@ async def run_semantic_kernel_process_example() -> None:
kernel=kernel,
initial_event=KernelProcessEvent(id=CommonEvents.START_PROCESS.value, data="Initial"),
) as process_context:
process_state = await process_context.get_state()
process_state = await process_context.get_executor_state()
c_step_state: KernelProcessStepState[CStepState] | None = next(
(s.state for s in process_state.steps if s.state.name == "CStep"),
None,
@@ -136,7 +136,7 @@ async def run_semantic_kernel_nested_process() -> None:
initial_event=ProcessEvents.START_PROCESS.value,
data="Test",
)
process_info = await process_handle.get_state()
process_info = await process_handle.get_executor_state()
inner_process: KernelProcess | None = next(
(s for s in process_info.steps if s.state.name == "Inner"),