Move runner state management out of Workflow

This commit is contained in:
Tao Chen
2026-06-05 16:29:19 -07:00
Unverified
parent dcc218dbac
commit c5e6a7797f
4 changed files with 251 additions and 95 deletions
@@ -5,6 +5,7 @@ import contextlib
import logging
from collections import defaultdict
from collections.abc import AsyncGenerator, Sequence
from enum import Enum
from typing import Any
from ..exceptions import (
@@ -27,6 +28,25 @@ from ._state import State
logger = logging.getLogger(__name__)
class _RunnerLifecycle(Enum):
"""Lifecycle of a single :class:`Runner` invocation.
Three states keep concurrent-run protection in one place while still letting
:meth:`Workflow.run` reserve the runner synchronously (before any await):
* ``IDLE`` - no run in progress.
* ``RESERVED`` - :meth:`Runner.reserve` was called by an external caller but
:meth:`Runner.run_until_convergence` has not yet entered its body. This is the
window during which ``Workflow.run`` has handed back a ``ResponseStream`` that
hasn't been iterated yet. ``run_until_convergence`` requires this state.
* ``RUNNING`` - the body of :meth:`Runner.run_until_convergence` is executing.
"""
IDLE = "idle"
RESERVED = "reserved"
RUNNING = "running"
class Runner:
"""A class to run a workflow in Pregel supersteps."""
@@ -63,7 +83,7 @@ class Runner:
self._iteration = 0
self._max_iterations = max_iterations
self._state = state
self._running = False
self._lifecycle: _RunnerLifecycle = _RunnerLifecycle.IDLE
self._resumed_from_checkpoint = False # Track whether we resumed
@property
@@ -71,16 +91,48 @@ class Runner:
"""Get the workflow context."""
return self._ctx
def reserve(self) -> None:
"""Synchronously reserve the runner for an upcoming run.
This is the **only** way to acquire the run lock. :meth:`run_until_convergence`
requires the runner to be in :attr:`_RunnerLifecycle.RESERVED` and will refuse
to start otherwise. Reserving synchronously lets callers (notably
:meth:`Workflow.run`) reject concurrent runs *before* any ``await`` or
async-generator suspension - otherwise a second caller could slip past a
flag-based guard while the first is still suspended above
:meth:`run_until_convergence`. The lock is released by
:meth:`run_until_convergence` in its ``finally`` clause once it begins, or by
:meth:`release` when the run never starts (for example, if early validation
raises before ``run_until_convergence`` is reached).
Raises:
WorkflowRunnerException: If the runner is already reserved or running.
"""
if self._lifecycle is not _RunnerLifecycle.IDLE:
raise WorkflowRunnerException("Runner is already running.")
self._lifecycle = _RunnerLifecycle.RESERVED
def release(self) -> None:
"""Release the runner's run lock. Idempotent; safe to call when already idle."""
self._lifecycle = _RunnerLifecycle.IDLE
def reset_iteration_count(self) -> None:
"""Reset the iteration count to zero."""
self._iteration = 0
async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]:
"""Run the workflow until no more messages are sent."""
if self._running:
raise WorkflowRunnerException("Runner is already running.")
# Mandatory reservation: callers must reserve() the runner first. This makes
# ``reserve`` the single entry point that takes the run lock, so all
# concurrent-run rejection happens there. Any non-RESERVED state here is a
# contract violation - either the caller forgot to reserve, or another run
# is already in progress.
if self._lifecycle is not _RunnerLifecycle.RESERVED:
raise WorkflowRunnerException(
"Runner must be reserved via Runner.reserve() before calling run_until_convergence()."
)
self._lifecycle = _RunnerLifecycle.RUNNING
self._running = True
previous_checkpoint_id: CheckpointID | None = None
try:
# Emit any events already produced prior to entering loop
@@ -155,7 +207,7 @@ class Runner:
logger.info(f"Workflow completed after {self._iteration} supersteps")
self._resumed_from_checkpoint = False # Reset resume flag for next run
finally:
self._running = False
self._lifecycle = _RunnerLifecycle.IDLE
async def _run_iteration(self) -> None:
"""Run a single iteration of the workflow.
@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Any, Literal, overload
from .._sessions import ContextProvider
from .._types import ResponseStream
from ..exceptions import WorkflowException, WorkflowRunnerException
from ..observability import OtelAttr, capture_exception, create_workflow_span
from ._checkpoint import CheckpointStorage
from ._const import DEFAULT_MAX_ITERATIONS, GLOBAL_KWARGS_KEY, WORKFLOW_RUN_KWARGS_KEY
@@ -357,9 +358,6 @@ class Workflow(DictConvertible):
max_iterations=max_iterations,
)
# Flag to prevent concurrent workflow executions
self._is_running = False
# Current run-level status of this workflow instance. Updated in lockstep with
# the status events emitted from `_run_workflow_with_tracing`. Defaults to IDLE
# for a freshly built workflow that has not yet been run.
@@ -376,16 +374,6 @@ class Workflow(DictConvertible):
"""
return self._status
def _ensure_not_running(self) -> None:
"""Ensure the workflow is not already running."""
if self._is_running:
raise RuntimeError("Workflow is already running. Concurrent executions are not allowed.")
self._is_running = True
def _reset_running_flag(self) -> None:
"""Reset the running flag."""
self._is_running = False
def to_dict(self) -> dict[str, Any]:
"""Serialize the workflow definition into a JSON-ready dictionary."""
data: dict[str, Any] = {
@@ -745,9 +733,19 @@ class Workflow(DictConvertible):
Raises:
ValueError: If parameter combination is invalid.
"""
# Validate parameters and set running flag eagerly (before any async work)
# Validate parameters first so misuse fails before we touch any run state.
self._validate_run_params(message, responses, checkpoint_id)
self._ensure_not_running()
# Acquire the run lock synchronously - before constructing the ResponseStream
# or yielding control to the event loop - so a second concurrent ``run`` call
# is rejected immediately rather than slipping past the guard while the first
# call is suspended inside its async generator.
try:
self._runner.reserve()
except WorkflowRunnerException as exc:
raise WorkflowException(
"Workflow is already running; concurrent runs are not allowed on the same instance."
) from exc
response_stream = ResponseStream[WorkflowEvent, WorkflowRunResult](
self._run_core(
@@ -760,9 +758,6 @@ class Workflow(DictConvertible):
client_kwargs=client_kwargs,
),
finalizer=functools.partial(self._finalize_events, include_status_events=include_status_events),
cleanup_hooks=[
functools.partial(self._run_cleanup, checkpoint_storage),
],
)
if stream:
@@ -789,51 +784,55 @@ class Workflow(DictConvertible):
if checkpoint_storage is not None:
self._runner.context.set_runtime_checkpoint_storage(checkpoint_storage)
# Async validation: a fresh-message run is only allowed when the
# runner context has fully drained from any prior run. If it still
# has in-flight executor messages, the prior run didn't complete -
# the caller must either resume from a checkpoint or wait for the
# prior run to drain. (Pending request_info events are intentionally
# NOT blocked here: a follow-up run with message=... is the normal
# way to deliver a response to those pending requests, e.g. via
# WorkflowAgent._process_pending_requests.)
# NOTE: _validate_run_params already enforces that ``message`` is
# mutually exclusive with both ``checkpoint_id`` and ``responses``,
# so we don't need to re-check those here.
if message is not None and await self._runner.context.has_messages():
raise RuntimeError(
"Cannot start a new run with 'message' while in-flight executor "
"messages remain from a prior run. Resume from a checkpoint "
"(checkpoint_id=...) or wait for the prior run to complete. "
"Workflows that need to recover from a mid-run failure must use "
"checkpointing; there is no in-process recovery path."
)
try:
# Async validation: a fresh-message run is only allowed when the
# runner context has fully drained from any prior run. If it still
# has in-flight executor messages, the prior run didn't complete -
# the caller must either resume from a checkpoint or wait for the
# prior run to drain. (Pending request_info events are intentionally
# NOT blocked here: a follow-up run with message=... is the normal
# way to deliver a response to those pending requests, e.g. via
# WorkflowAgent._process_pending_requests.)
# NOTE: _validate_run_params already enforces that ``message`` is
# mutually exclusive with both ``checkpoint_id`` and ``responses``,
# so we don't need to re-check those here.
if message is not None and await self._runner.context.has_messages():
raise RuntimeError(
"Cannot start a new run with 'message' while in-flight executor "
"messages remain from a prior run. Resume from a checkpoint "
"(checkpoint_id=...) or wait for the prior run to complete. "
"Workflows that need to recover from a mid-run failure must use "
"checkpointing; there is no in-process recovery path."
)
initial_executor_fn = self._resolve_execution_mode(message, responses, checkpoint_id, checkpoint_storage)
initial_executor_fn = self._resolve_execution_mode(message, responses, checkpoint_id, checkpoint_storage)
async for event in self._run_workflow_with_tracing(
initial_executor_fn=initial_executor_fn,
is_continuation=(message is None),
streaming=streaming,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
if event.type == "request_info" and event.request_id in (responses or {}):
# Don't yield request_info events for which we have responses to send -
# these are considered "handled". This prevents the caller from seeing
# events for requests they are already responding to.
# This usually happens when responses are provided with a checkpoint
# (restore then send), because the request_info events are stored in the
# checkpoint and would be emitted on restoration by the runner regardless
# of if a response is provided or not.
continue
yield event
async def _run_cleanup(self, checkpoint_storage: CheckpointStorage | None) -> None:
"""Cleanup hook called after stream consumption."""
if checkpoint_storage is not None:
self._runner.context.clear_runtime_checkpoint_storage()
self._reset_running_flag()
async for event in self._run_workflow_with_tracing(
initial_executor_fn=initial_executor_fn,
is_continuation=(message is None),
streaming=streaming,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
):
if event.type == "request_info" and event.request_id in (responses or {}):
# Don't yield request_info events for which we have responses to send -
# these are considered "handled". This prevents the caller from seeing
# events for requests they are already responding to.
# This usually happens when responses are provided with a checkpoint
# (restore then send), because the request_info events are stored in the
# checkpoint and would be emitted on restoration by the runner regardless
# of if a response is provided or not.
continue
yield event
finally:
# Release the run lock acquired synchronously by ``run()`` via ``reserve``.
# ``run_until_convergence`` also clears it in its own ``finally`` once its
# body runs; this clause additionally covers the case where this generator
# raises (or is closed) before that body is reached - e.g. the in-flight
# messages check above. ``release`` is idempotent.
self._runner.release()
if checkpoint_storage is not None:
self._runner.context.clear_runtime_checkpoint_storage()
@staticmethod
def _finalize_events(
@@ -106,6 +106,7 @@ async def test_runner_run_until_convergence():
state, # state
ctx, # runner_context
)
runner.reserve()
async for event in runner.run_until_convergence():
assert isinstance(event, WorkflowEvent)
if event.type == "output":
@@ -147,6 +148,7 @@ async def test_runner_run_until_convergence_not_completed():
WorkflowConvergenceException,
match="Runner did not converge after 5 iterations.",
):
runner.reserve()
async for event in runner.run_until_convergence():
assert event.type != "status" or event.state != WorkflowRunState.IDLE
@@ -305,40 +307,137 @@ async def test_fanout_edge_runner_delivers_to_multiple_targets_concurrently() ->
assert probe_target.call_count == 1
async def test_runner_already_running():
"""Test that running the runner while it is already running raises an error."""
async def test_runner_run_until_convergence_requires_reservation():
"""run_until_convergence refuses to start without a prior reserve()."""
runner = _make_runner()
with pytest.raises(WorkflowRunnerException, match="Runner must be reserved"):
async for _ in runner.run_until_convergence():
pass
def _make_runner() -> Runner:
"""Build a minimal runner for lifecycle tests."""
return Runner(
[],
{},
State(),
InProcRunnerContext(),
"test_name",
graph_signature_hash="test_hash",
)
def test_runner_reserve_twice_raises():
"""Calling reserve() while already reserved rejects the second caller.
This is what guards Workflow.run against a concurrent caller slipping in
between the first call's synchronous reserve() and its first await.
"""
runner = _make_runner()
runner.reserve()
with pytest.raises(WorkflowRunnerException, match="Runner is already running."):
runner.reserve()
def test_runner_reserve_after_release_is_accepted():
"""Sequential runs are permitted; only concurrent ones are blocked."""
runner = _make_runner()
runner.reserve()
runner.release()
runner.reserve() # should not raise
def test_runner_release_when_idle_is_noop():
"""release() on an idle runner does not affect a subsequent reserve().
Workflow._run_core's finally always calls release(), even when
run_until_convergence already cleared the lock in its own finally;
that double-release must not lock out the next run.
"""
runner = _make_runner()
runner.release() # already idle - must not raise or wedge state
runner.reserve() # next run still allowed
async def test_runner_run_until_convergence_consumes_reservation():
"""run_until_convergence accepts a prior reservation and runs to completion."""
runner = _make_runner()
runner.reserve()
async for _ in runner.run_until_convergence():
pass
# A second run after the first completes must be accepted.
runner.reserve()
async for _ in runner.run_until_convergence():
pass
async def test_runner_accepts_new_run_after_previous_failure():
"""A failed run must not leave the runner locked out of future runs.
After the first run raises, a fresh ``reserve()`` and
``run_until_convergence()`` must succeed (or fail for a different reason -
e.g. residual messages still don't converge - but never with the
lock-rejection ``"Runner is already running."``).
"""
executor_a = MockExecutor(id="executor_a")
executor_b = MockExecutor(id="executor_b")
# Create a loop
edges = [
SingleEdgeGroup(executor_a.id, executor_b.id),
SingleEdgeGroup(executor_b.id, executor_a.id),
]
executors: dict[str, Executor] = {
executor_a.id: executor_a,
executor_b.id: executor_b,
}
executors: dict[str, Executor] = {executor_a.id: executor_a, executor_b.id: executor_b}
state = State()
ctx = InProcRunnerContext()
runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash", max_iterations=2)
runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash")
await executor_a.execute(MockMessage(data=0), ["START"], state, ctx)
await executor_a.execute(
MockMessage(data=0),
["START"], # source_executor_ids
state, # state
ctx, # runner_context
)
runner.reserve()
with pytest.raises(WorkflowConvergenceException):
async for _ in runner.run_until_convergence():
pass
with pytest.raises(WorkflowRunnerException, match="Runner is already running."):
# The runner should accept a fresh reservation and run again.
runner.reserve() # must not raise
try:
async for _ in runner.run_until_convergence():
pass
except Exception as exc:
assert "Runner is already running" not in str(exc), "Runner stayed locked after a failed run"
async def _run():
async for _ in runner.run_until_convergence():
pass
await asyncio.gather(_run(), _run())
async def test_runner_rejects_concurrent_run_until_convergence():
"""While a run is in progress, a second ``reserve()`` is rejected.
Confirms the run lock is held for the full duration of the run, not just
synchronously between ``reserve()`` and the first ``__anext__`` call.
"""
runner = _make_runner()
started = asyncio.Event()
release = asyncio.Event()
async def _slow_run():
runner.reserve()
async for _ in runner.run_until_convergence():
if not started.is_set():
started.set()
await release.wait()
task = asyncio.create_task(_slow_run())
await started.wait() # first run is now executing
try:
with pytest.raises(WorkflowRunnerException, match="Runner is already running."):
runner.reserve()
finally:
release.set()
await task
# And after the first run finishes, a new reservation + run must be accepted.
runner.reserve()
async for _ in runner.run_until_convergence():
pass
async def test_runner_emits_runner_completion_for_agent_response_without_targets():
@@ -352,6 +451,7 @@ async def test_runner_emits_runner_completion_for_agent_response_without_targets
)
)
runner.reserve()
events: list[WorkflowEvent] = [event async for event in runner.run_until_convergence()]
# The runner should complete without errors when handling AgentExecutorResponse without targets
# No specific events are expected since there are no executors to process the message
@@ -408,6 +508,7 @@ async def test_runner_cancellation_stops_active_executor():
async for _ in runner.run_until_convergence():
pass
runner.reserve()
task = asyncio.create_task(run_workflow())
# Wait for executor_a to complete (0.3s) and executor_b to start but not finish
@@ -469,6 +570,7 @@ async def test_runner_iteration_exception_drains_events():
)
events: list[WorkflowEvent] = []
runner.reserve()
with pytest.raises(RuntimeError, match="Simulated executor failure"):
async for event in runner.run_until_convergence():
events.append(event)
@@ -579,6 +681,7 @@ async def test_runner_checkpoint_creation_failure():
# Should complete without raising, even though checkpointing fails
result: int | None = None
runner.reserve()
async for event in runner.run_until_convergence():
if event.type == "output":
result = event.data
@@ -775,6 +878,7 @@ async def test_runner_with_pre_loop_events():
await ctx.add_event(WorkflowEvent("output", executor_id="test_executor", data="pre-loop-output"))
events: list[WorkflowEvent] = []
runner.reserve()
async for event in runner.run_until_convergence():
events.append(event)
@@ -822,6 +926,7 @@ async def test_runner_drains_straggler_events():
)
events: list[WorkflowEvent] = []
runner.reserve()
async for event in runner.run_until_convergence():
events.append(event)
@@ -875,6 +980,7 @@ async def test_runner_checkpoint_with_resumed_flag():
)
# Run until convergence
runner.reserve()
async for _ in runner.run_until_convergence():
pass
@@ -941,6 +1047,7 @@ async def test_runner_drains_events_on_iteration_exception():
)
events: list[WorkflowEvent] = []
runner.reserve()
with pytest.raises(RuntimeError, match="Executor failed with pending events"):
async for event in runner.run_until_convergence():
events.append(event)
@@ -997,6 +1104,7 @@ async def test_runner_drains_straggler_events_at_iteration_end():
)
events: list[WorkflowEvent] = []
runner.reserve()
async for event in runner.run_until_convergence():
events.append(event)
@@ -26,6 +26,7 @@ from agent_framework import (
WorkflowContext,
WorkflowConvergenceException,
WorkflowEvent,
WorkflowException,
WorkflowMessage,
WorkflowRunState,
handler,
@@ -759,8 +760,7 @@ async def test_workflow_concurrent_execution_prevention():
# Try to start a second concurrent execution - this should fail
with pytest.raises(
RuntimeError,
match="Workflow is already running. Concurrent executions are not allowed.",
WorkflowException, match="Workflow is already running; concurrent runs are not allowed on the same instance."
):
await workflow.run(NumberMessage(data=0))
@@ -795,8 +795,7 @@ async def test_workflow_concurrent_execution_prevention_streaming():
# Try to start a second concurrent execution - this should fail
with pytest.raises(
RuntimeError,
match="Workflow is already running. Concurrent executions are not allowed.",
WorkflowException, match="Workflow is already running; concurrent runs are not allowed on the same instance."
):
await workflow.run(NumberMessage(data=0))
@@ -828,14 +827,12 @@ async def test_workflow_concurrent_execution_prevention_mixed_methods():
# Try different execution methods - all should fail
with pytest.raises(
RuntimeError,
match="Workflow is already running. Concurrent executions are not allowed.",
WorkflowException, match="Workflow is already running; concurrent runs are not allowed on the same instance."
):
await workflow.run(NumberMessage(data=0))
with pytest.raises(
RuntimeError,
match="Workflow is already running. Concurrent executions are not allowed.",
WorkflowException, match="Workflow is already running; concurrent runs are not allowed on the same instance."
):
async for _ in workflow.run(NumberMessage(data=0), stream=True):
break