From c5e6a7797f3ef024aee8d43823cf911d428c9385 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Fri, 5 Jun 2026 16:29:19 -0700 Subject: [PATCH] Move runner state management out of Workflow --- .../agent_framework/_workflows/_runner.py | 62 +++++++- .../agent_framework/_workflows/_workflow.py | 121 +++++++------- .../core/tests/workflow/test_runner.py | 150 +++++++++++++++--- .../core/tests/workflow/test_workflow.py | 13 +- 4 files changed, 251 insertions(+), 95 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index 51a3312e2b..3e1b1f3b50 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -5,6 +5,7 @@ import contextlib import logging from collections import defaultdict from collections.abc import AsyncGenerator, Sequence +from enum import Enum from typing import Any from ..exceptions import ( @@ -27,6 +28,25 @@ from ._state import State logger = logging.getLogger(__name__) +class _RunnerLifecycle(Enum): + """Lifecycle of a single :class:`Runner` invocation. + + Three states keep concurrent-run protection in one place while still letting + :meth:`Workflow.run` reserve the runner synchronously (before any await): + + * ``IDLE`` - no run in progress. + * ``RESERVED`` - :meth:`Runner.reserve` was called by an external caller but + :meth:`Runner.run_until_convergence` has not yet entered its body. This is the + window during which ``Workflow.run`` has handed back a ``ResponseStream`` that + hasn't been iterated yet. ``run_until_convergence`` requires this state. + * ``RUNNING`` - the body of :meth:`Runner.run_until_convergence` is executing. + """ + + IDLE = "idle" + RESERVED = "reserved" + RUNNING = "running" + + class Runner: """A class to run a workflow in Pregel supersteps.""" @@ -63,7 +83,7 @@ class Runner: self._iteration = 0 self._max_iterations = max_iterations self._state = state - self._running = False + self._lifecycle: _RunnerLifecycle = _RunnerLifecycle.IDLE self._resumed_from_checkpoint = False # Track whether we resumed @property @@ -71,16 +91,48 @@ class Runner: """Get the workflow context.""" return self._ctx + def reserve(self) -> None: + """Synchronously reserve the runner for an upcoming run. + + This is the **only** way to acquire the run lock. :meth:`run_until_convergence` + requires the runner to be in :attr:`_RunnerLifecycle.RESERVED` and will refuse + to start otherwise. Reserving synchronously lets callers (notably + :meth:`Workflow.run`) reject concurrent runs *before* any ``await`` or + async-generator suspension - otherwise a second caller could slip past a + flag-based guard while the first is still suspended above + :meth:`run_until_convergence`. The lock is released by + :meth:`run_until_convergence` in its ``finally`` clause once it begins, or by + :meth:`release` when the run never starts (for example, if early validation + raises before ``run_until_convergence`` is reached). + + Raises: + WorkflowRunnerException: If the runner is already reserved or running. + """ + if self._lifecycle is not _RunnerLifecycle.IDLE: + raise WorkflowRunnerException("Runner is already running.") + self._lifecycle = _RunnerLifecycle.RESERVED + + def release(self) -> None: + """Release the runner's run lock. Idempotent; safe to call when already idle.""" + self._lifecycle = _RunnerLifecycle.IDLE + def reset_iteration_count(self) -> None: """Reset the iteration count to zero.""" self._iteration = 0 async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: """Run the workflow until no more messages are sent.""" - if self._running: - raise WorkflowRunnerException("Runner is already running.") + # Mandatory reservation: callers must reserve() the runner first. This makes + # ``reserve`` the single entry point that takes the run lock, so all + # concurrent-run rejection happens there. Any non-RESERVED state here is a + # contract violation - either the caller forgot to reserve, or another run + # is already in progress. + if self._lifecycle is not _RunnerLifecycle.RESERVED: + raise WorkflowRunnerException( + "Runner must be reserved via Runner.reserve() before calling run_until_convergence()." + ) + self._lifecycle = _RunnerLifecycle.RUNNING - self._running = True previous_checkpoint_id: CheckpointID | None = None try: # Emit any events already produced prior to entering loop @@ -155,7 +207,7 @@ class Runner: logger.info(f"Workflow completed after {self._iteration} supersteps") self._resumed_from_checkpoint = False # Reset resume flag for next run finally: - self._running = False + self._lifecycle = _RunnerLifecycle.IDLE async def _run_iteration(self) -> None: """Run a single iteration of the workflow. diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index c4840bb045..780669272d 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Any, Literal, overload from .._sessions import ContextProvider from .._types import ResponseStream +from ..exceptions import WorkflowException, WorkflowRunnerException from ..observability import OtelAttr, capture_exception, create_workflow_span from ._checkpoint import CheckpointStorage from ._const import DEFAULT_MAX_ITERATIONS, GLOBAL_KWARGS_KEY, WORKFLOW_RUN_KWARGS_KEY @@ -357,9 +358,6 @@ class Workflow(DictConvertible): max_iterations=max_iterations, ) - # Flag to prevent concurrent workflow executions - self._is_running = False - # Current run-level status of this workflow instance. Updated in lockstep with # the status events emitted from `_run_workflow_with_tracing`. Defaults to IDLE # for a freshly built workflow that has not yet been run. @@ -376,16 +374,6 @@ class Workflow(DictConvertible): """ return self._status - def _ensure_not_running(self) -> None: - """Ensure the workflow is not already running.""" - if self._is_running: - raise RuntimeError("Workflow is already running. Concurrent executions are not allowed.") - self._is_running = True - - def _reset_running_flag(self) -> None: - """Reset the running flag.""" - self._is_running = False - def to_dict(self) -> dict[str, Any]: """Serialize the workflow definition into a JSON-ready dictionary.""" data: dict[str, Any] = { @@ -745,9 +733,19 @@ class Workflow(DictConvertible): Raises: ValueError: If parameter combination is invalid. """ - # Validate parameters and set running flag eagerly (before any async work) + # Validate parameters first so misuse fails before we touch any run state. self._validate_run_params(message, responses, checkpoint_id) - self._ensure_not_running() + + # Acquire the run lock synchronously - before constructing the ResponseStream + # or yielding control to the event loop - so a second concurrent ``run`` call + # is rejected immediately rather than slipping past the guard while the first + # call is suspended inside its async generator. + try: + self._runner.reserve() + except WorkflowRunnerException as exc: + raise WorkflowException( + "Workflow is already running; concurrent runs are not allowed on the same instance." + ) from exc response_stream = ResponseStream[WorkflowEvent, WorkflowRunResult]( self._run_core( @@ -760,9 +758,6 @@ class Workflow(DictConvertible): client_kwargs=client_kwargs, ), finalizer=functools.partial(self._finalize_events, include_status_events=include_status_events), - cleanup_hooks=[ - functools.partial(self._run_cleanup, checkpoint_storage), - ], ) if stream: @@ -789,51 +784,55 @@ class Workflow(DictConvertible): if checkpoint_storage is not None: self._runner.context.set_runtime_checkpoint_storage(checkpoint_storage) - # Async validation: a fresh-message run is only allowed when the - # runner context has fully drained from any prior run. If it still - # has in-flight executor messages, the prior run didn't complete - - # the caller must either resume from a checkpoint or wait for the - # prior run to drain. (Pending request_info events are intentionally - # NOT blocked here: a follow-up run with message=... is the normal - # way to deliver a response to those pending requests, e.g. via - # WorkflowAgent._process_pending_requests.) - # NOTE: _validate_run_params already enforces that ``message`` is - # mutually exclusive with both ``checkpoint_id`` and ``responses``, - # so we don't need to re-check those here. - if message is not None and await self._runner.context.has_messages(): - raise RuntimeError( - "Cannot start a new run with 'message' while in-flight executor " - "messages remain from a prior run. Resume from a checkpoint " - "(checkpoint_id=...) or wait for the prior run to complete. " - "Workflows that need to recover from a mid-run failure must use " - "checkpointing; there is no in-process recovery path." - ) + try: + # Async validation: a fresh-message run is only allowed when the + # runner context has fully drained from any prior run. If it still + # has in-flight executor messages, the prior run didn't complete - + # the caller must either resume from a checkpoint or wait for the + # prior run to drain. (Pending request_info events are intentionally + # NOT blocked here: a follow-up run with message=... is the normal + # way to deliver a response to those pending requests, e.g. via + # WorkflowAgent._process_pending_requests.) + # NOTE: _validate_run_params already enforces that ``message`` is + # mutually exclusive with both ``checkpoint_id`` and ``responses``, + # so we don't need to re-check those here. + if message is not None and await self._runner.context.has_messages(): + raise RuntimeError( + "Cannot start a new run with 'message' while in-flight executor " + "messages remain from a prior run. Resume from a checkpoint " + "(checkpoint_id=...) or wait for the prior run to complete. " + "Workflows that need to recover from a mid-run failure must use " + "checkpointing; there is no in-process recovery path." + ) - initial_executor_fn = self._resolve_execution_mode(message, responses, checkpoint_id, checkpoint_storage) + initial_executor_fn = self._resolve_execution_mode(message, responses, checkpoint_id, checkpoint_storage) - async for event in self._run_workflow_with_tracing( - initial_executor_fn=initial_executor_fn, - is_continuation=(message is None), - streaming=streaming, - function_invocation_kwargs=function_invocation_kwargs, - client_kwargs=client_kwargs, - ): - if event.type == "request_info" and event.request_id in (responses or {}): - # Don't yield request_info events for which we have responses to send - - # these are considered "handled". This prevents the caller from seeing - # events for requests they are already responding to. - # This usually happens when responses are provided with a checkpoint - # (restore then send), because the request_info events are stored in the - # checkpoint and would be emitted on restoration by the runner regardless - # of if a response is provided or not. - continue - yield event - - async def _run_cleanup(self, checkpoint_storage: CheckpointStorage | None) -> None: - """Cleanup hook called after stream consumption.""" - if checkpoint_storage is not None: - self._runner.context.clear_runtime_checkpoint_storage() - self._reset_running_flag() + async for event in self._run_workflow_with_tracing( + initial_executor_fn=initial_executor_fn, + is_continuation=(message is None), + streaming=streaming, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, + ): + if event.type == "request_info" and event.request_id in (responses or {}): + # Don't yield request_info events for which we have responses to send - + # these are considered "handled". This prevents the caller from seeing + # events for requests they are already responding to. + # This usually happens when responses are provided with a checkpoint + # (restore then send), because the request_info events are stored in the + # checkpoint and would be emitted on restoration by the runner regardless + # of if a response is provided or not. + continue + yield event + finally: + # Release the run lock acquired synchronously by ``run()`` via ``reserve``. + # ``run_until_convergence`` also clears it in its own ``finally`` once its + # body runs; this clause additionally covers the case where this generator + # raises (or is closed) before that body is reached - e.g. the in-flight + # messages check above. ``release`` is idempotent. + self._runner.release() + if checkpoint_storage is not None: + self._runner.context.clear_runtime_checkpoint_storage() @staticmethod def _finalize_events( diff --git a/python/packages/core/tests/workflow/test_runner.py b/python/packages/core/tests/workflow/test_runner.py index 4fef26bd2d..6756bb9a9f 100644 --- a/python/packages/core/tests/workflow/test_runner.py +++ b/python/packages/core/tests/workflow/test_runner.py @@ -106,6 +106,7 @@ async def test_runner_run_until_convergence(): state, # state ctx, # runner_context ) + runner.reserve() async for event in runner.run_until_convergence(): assert isinstance(event, WorkflowEvent) if event.type == "output": @@ -147,6 +148,7 @@ async def test_runner_run_until_convergence_not_completed(): WorkflowConvergenceException, match="Runner did not converge after 5 iterations.", ): + runner.reserve() async for event in runner.run_until_convergence(): assert event.type != "status" or event.state != WorkflowRunState.IDLE @@ -305,40 +307,137 @@ async def test_fanout_edge_runner_delivers_to_multiple_targets_concurrently() -> assert probe_target.call_count == 1 -async def test_runner_already_running(): - """Test that running the runner while it is already running raises an error.""" +async def test_runner_run_until_convergence_requires_reservation(): + """run_until_convergence refuses to start without a prior reserve().""" + runner = _make_runner() + with pytest.raises(WorkflowRunnerException, match="Runner must be reserved"): + async for _ in runner.run_until_convergence(): + pass + + +def _make_runner() -> Runner: + """Build a minimal runner for lifecycle tests.""" + return Runner( + [], + {}, + State(), + InProcRunnerContext(), + "test_name", + graph_signature_hash="test_hash", + ) + + +def test_runner_reserve_twice_raises(): + """Calling reserve() while already reserved rejects the second caller. + + This is what guards Workflow.run against a concurrent caller slipping in + between the first call's synchronous reserve() and its first await. + """ + runner = _make_runner() + runner.reserve() + with pytest.raises(WorkflowRunnerException, match="Runner is already running."): + runner.reserve() + + +def test_runner_reserve_after_release_is_accepted(): + """Sequential runs are permitted; only concurrent ones are blocked.""" + runner = _make_runner() + runner.reserve() + runner.release() + runner.reserve() # should not raise + + +def test_runner_release_when_idle_is_noop(): + """release() on an idle runner does not affect a subsequent reserve(). + + Workflow._run_core's finally always calls release(), even when + run_until_convergence already cleared the lock in its own finally; + that double-release must not lock out the next run. + """ + runner = _make_runner() + runner.release() # already idle - must not raise or wedge state + runner.reserve() # next run still allowed + + +async def test_runner_run_until_convergence_consumes_reservation(): + """run_until_convergence accepts a prior reservation and runs to completion.""" + runner = _make_runner() + runner.reserve() + async for _ in runner.run_until_convergence(): + pass + # A second run after the first completes must be accepted. + runner.reserve() + async for _ in runner.run_until_convergence(): + pass + + +async def test_runner_accepts_new_run_after_previous_failure(): + """A failed run must not leave the runner locked out of future runs. + + After the first run raises, a fresh ``reserve()`` and + ``run_until_convergence()`` must succeed (or fail for a different reason - + e.g. residual messages still don't converge - but never with the + lock-rejection ``"Runner is already running."``). + """ executor_a = MockExecutor(id="executor_a") executor_b = MockExecutor(id="executor_b") - - # Create a loop edges = [ SingleEdgeGroup(executor_a.id, executor_b.id), SingleEdgeGroup(executor_b.id, executor_a.id), ] - - executors: dict[str, Executor] = { - executor_a.id: executor_a, - executor_b.id: executor_b, - } + executors: dict[str, Executor] = {executor_a.id: executor_a, executor_b.id: executor_b} state = State() ctx = InProcRunnerContext() + runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash", max_iterations=2) - runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash") + await executor_a.execute(MockMessage(data=0), ["START"], state, ctx) - await executor_a.execute( - MockMessage(data=0), - ["START"], # source_executor_ids - state, # state - ctx, # runner_context - ) + runner.reserve() + with pytest.raises(WorkflowConvergenceException): + async for _ in runner.run_until_convergence(): + pass - with pytest.raises(WorkflowRunnerException, match="Runner is already running."): + # The runner should accept a fresh reservation and run again. + runner.reserve() # must not raise + try: + async for _ in runner.run_until_convergence(): + pass + except Exception as exc: + assert "Runner is already running" not in str(exc), "Runner stayed locked after a failed run" - async def _run(): - async for _ in runner.run_until_convergence(): - pass - await asyncio.gather(_run(), _run()) +async def test_runner_rejects_concurrent_run_until_convergence(): + """While a run is in progress, a second ``reserve()`` is rejected. + + Confirms the run lock is held for the full duration of the run, not just + synchronously between ``reserve()`` and the first ``__anext__`` call. + """ + runner = _make_runner() + + started = asyncio.Event() + release = asyncio.Event() + + async def _slow_run(): + runner.reserve() + async for _ in runner.run_until_convergence(): + if not started.is_set(): + started.set() + await release.wait() + + task = asyncio.create_task(_slow_run()) + await started.wait() # first run is now executing + + try: + with pytest.raises(WorkflowRunnerException, match="Runner is already running."): + runner.reserve() + finally: + release.set() + await task + + # And after the first run finishes, a new reservation + run must be accepted. + runner.reserve() + async for _ in runner.run_until_convergence(): + pass async def test_runner_emits_runner_completion_for_agent_response_without_targets(): @@ -352,6 +451,7 @@ async def test_runner_emits_runner_completion_for_agent_response_without_targets ) ) + runner.reserve() events: list[WorkflowEvent] = [event async for event in runner.run_until_convergence()] # The runner should complete without errors when handling AgentExecutorResponse without targets # No specific events are expected since there are no executors to process the message @@ -408,6 +508,7 @@ async def test_runner_cancellation_stops_active_executor(): async for _ in runner.run_until_convergence(): pass + runner.reserve() task = asyncio.create_task(run_workflow()) # Wait for executor_a to complete (0.3s) and executor_b to start but not finish @@ -469,6 +570,7 @@ async def test_runner_iteration_exception_drains_events(): ) events: list[WorkflowEvent] = [] + runner.reserve() with pytest.raises(RuntimeError, match="Simulated executor failure"): async for event in runner.run_until_convergence(): events.append(event) @@ -579,6 +681,7 @@ async def test_runner_checkpoint_creation_failure(): # Should complete without raising, even though checkpointing fails result: int | None = None + runner.reserve() async for event in runner.run_until_convergence(): if event.type == "output": result = event.data @@ -775,6 +878,7 @@ async def test_runner_with_pre_loop_events(): await ctx.add_event(WorkflowEvent("output", executor_id="test_executor", data="pre-loop-output")) events: list[WorkflowEvent] = [] + runner.reserve() async for event in runner.run_until_convergence(): events.append(event) @@ -822,6 +926,7 @@ async def test_runner_drains_straggler_events(): ) events: list[WorkflowEvent] = [] + runner.reserve() async for event in runner.run_until_convergence(): events.append(event) @@ -875,6 +980,7 @@ async def test_runner_checkpoint_with_resumed_flag(): ) # Run until convergence + runner.reserve() async for _ in runner.run_until_convergence(): pass @@ -941,6 +1047,7 @@ async def test_runner_drains_events_on_iteration_exception(): ) events: list[WorkflowEvent] = [] + runner.reserve() with pytest.raises(RuntimeError, match="Executor failed with pending events"): async for event in runner.run_until_convergence(): events.append(event) @@ -997,6 +1104,7 @@ async def test_runner_drains_straggler_events_at_iteration_end(): ) events: list[WorkflowEvent] = [] + runner.reserve() async for event in runner.run_until_convergence(): events.append(event) diff --git a/python/packages/core/tests/workflow/test_workflow.py b/python/packages/core/tests/workflow/test_workflow.py index 27f24d26f9..1c77ae003b 100644 --- a/python/packages/core/tests/workflow/test_workflow.py +++ b/python/packages/core/tests/workflow/test_workflow.py @@ -26,6 +26,7 @@ from agent_framework import ( WorkflowContext, WorkflowConvergenceException, WorkflowEvent, + WorkflowException, WorkflowMessage, WorkflowRunState, handler, @@ -759,8 +760,7 @@ async def test_workflow_concurrent_execution_prevention(): # Try to start a second concurrent execution - this should fail with pytest.raises( - RuntimeError, - match="Workflow is already running. Concurrent executions are not allowed.", + WorkflowException, match="Workflow is already running; concurrent runs are not allowed on the same instance." ): await workflow.run(NumberMessage(data=0)) @@ -795,8 +795,7 @@ async def test_workflow_concurrent_execution_prevention_streaming(): # Try to start a second concurrent execution - this should fail with pytest.raises( - RuntimeError, - match="Workflow is already running. Concurrent executions are not allowed.", + WorkflowException, match="Workflow is already running; concurrent runs are not allowed on the same instance." ): await workflow.run(NumberMessage(data=0)) @@ -828,14 +827,12 @@ async def test_workflow_concurrent_execution_prevention_mixed_methods(): # Try different execution methods - all should fail with pytest.raises( - RuntimeError, - match="Workflow is already running. Concurrent executions are not allowed.", + WorkflowException, match="Workflow is already running; concurrent runs are not allowed on the same instance." ): await workflow.run(NumberMessage(data=0)) with pytest.raises( - RuntimeError, - match="Workflow is already running. Concurrent executions are not allowed.", + WorkflowException, match="Workflow is already running; concurrent runs are not allowed on the same instance." ): async for _ in workflow.run(NumberMessage(data=0), stream=True): break