diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index bd02a22bc1..3c441863ec 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -5,13 +5,11 @@ import contextlib import logging from collections import defaultdict from collections.abc import AsyncGenerator, Sequence -from enum import Enum from typing import Any from ..exceptions import ( WorkflowCheckpointException, WorkflowConvergenceException, - WorkflowRunnerException, ) from ._checkpoint import CheckpointID, CheckpointStorage, WorkflowCheckpoint from ._const import EXECUTOR_STATE_KEY @@ -28,25 +26,6 @@ from ._state import State logger = logging.getLogger(__name__) -class _RunnerLifecycle(Enum): - """Lifecycle of a single :class:`Runner` invocation. - - Three states keep concurrent-run protection in one place while still letting - :meth:`Workflow.run` reserve the runner synchronously (before any await): - - * ``IDLE`` - no run in progress. - * ``RESERVED`` - :meth:`Runner.reserve` was called by an external caller but - :meth:`Runner.run_until_convergence` has not yet entered its body. This is the - window during which ``Workflow.run`` has handed back a ``ResponseStream`` that - hasn't been iterated yet. ``run_until_convergence`` requires this state. - * ``RUNNING`` - the body of :meth:`Runner.run_until_convergence` is executing. 
- """ - - IDLE = "idle" - RESERVED = "reserved" - RUNNING = "running" - - class Runner: """A class to run a workflow in Pregel supersteps.""" @@ -83,7 +62,6 @@ class Runner: self._iteration = 0 self._max_iterations = max_iterations self._state = state - self._lifecycle: _RunnerLifecycle = _RunnerLifecycle.IDLE self._resumed_from_checkpoint = False # Track whether we resumed @property @@ -96,31 +74,6 @@ class Runner: """Get the shared state for the workflow.""" return self._state - def reserve(self) -> None: - """Synchronously reserve the runner for an upcoming run. - - This is the **only** way to acquire the run lock. :meth:`run_until_convergence` - requires the runner to be in :attr:`_RunnerLifecycle.RESERVED` and will refuse - to start otherwise. Reserving synchronously lets callers (notably - :meth:`Workflow.run`) reject concurrent runs *before* any ``await`` or - async-generator suspension - otherwise a second caller could slip past a - flag-based guard while the first is still suspended above - :meth:`run_until_convergence`. The lock is released by - :meth:`run_until_convergence` in its ``finally`` clause once it begins, or by - :meth:`release` when the run never starts (for example, if early validation - raises before ``run_until_convergence`` is reached). - - Raises: - WorkflowRunnerException: If the runner is already reserved or running. - """ - if self._lifecycle is not _RunnerLifecycle.IDLE: - raise WorkflowRunnerException("Runner is already reserved or running.") - self._lifecycle = _RunnerLifecycle.RESERVED - - def release(self) -> None: - """Release the runner's run lock. Idempotent; safe to call when already idle.""" - self._lifecycle = _RunnerLifecycle.IDLE - def reset_iteration_count(self) -> None: """Reset the iteration count to zero. @@ -138,15 +91,10 @@ class Runner: This is useful when reusing the same workflow instance for a different run that is independent from prior runs. 
- Raises: - WorkflowRunnerException: If the runner is reserved or running. Reset is only - allowed when the runner is idle to avoid clobbering in-flight run state. + Concurrent-run rejection lives at the :class:`Workflow` level via the + active-run weak reference, so this method does not re-validate that + no run is in progress; callers are expected to have already done so. """ - if self._lifecycle is not _RunnerLifecycle.IDLE: - raise WorkflowRunnerException( - "Cannot reset the runner while a run is in progress. " - "Wait for the current run to complete before calling reset_for_new_run()." - ) self.reset_iteration_count() self._ctx.reset_for_new_run() self._state.clear() @@ -156,92 +104,79 @@ class Runner: async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: """Run the workflow until no more messages are sent.""" - # Mandatory reservation: callers must reserve() the runner first. This makes - # ``reserve`` the single entry point that takes the run lock, so all - # concurrent-run rejection happens there. Any non-RESERVED state here is a - # contract violation - either the caller forgot to reserve, or another run - # is already in progress. - if self._lifecycle is not _RunnerLifecycle.RESERVED: - raise WorkflowRunnerException( - "Runner must be reserved via Runner.reserve() before calling run_until_convergence()." - ) - self._lifecycle = _RunnerLifecycle.RUNNING - previous_checkpoint_id: CheckpointID | None = None - try: - # Emit any events already produced prior to entering loop - if await self._ctx.has_events(): - logger.info("Yielding pre-loop events") - for event in await self._ctx.drain_events(): - yield event - # Create the first checkpoint. Checkpoints are usually considered to be created at the end of an iteration, - # we can think of the first checkpoint as being created at the end of a "superstep 0" which captures the - # states after which the start executor has run. 
Note that we execute the start executor outside of the - # main iteration loop. - if await self._ctx.has_messages() and not self._resumed_from_checkpoint: - previous_checkpoint_id = await self._create_checkpoint_if_enabled(previous_checkpoint_id) + # Emit any events already produced prior to entering loop + if await self._ctx.has_events(): + logger.info("Yielding pre-loop events") + for event in await self._ctx.drain_events(): + yield event - while self._iteration < self._max_iterations: - logger.info(f"Starting superstep {self._iteration + 1}") - yield WorkflowEvent.superstep_started(iteration=self._iteration + 1) + # Create the first checkpoint. Checkpoints are usually considered to be created at the end of an iteration, + # we can think of the first checkpoint as being created at the end of a "superstep 0" which captures the + # states after which the start executor has run. Note that we execute the start executor outside of the + # main iteration loop. + if await self._ctx.has_messages() and not self._resumed_from_checkpoint: + previous_checkpoint_id = await self._create_checkpoint_if_enabled(previous_checkpoint_id) - # Run iteration concurrently with live event streaming: we poll - # for new events while the iteration coroutine progresses. 
- iteration_task = asyncio.create_task(self._run_iteration()) - try: - while not iteration_task.done(): - try: - # Wait briefly for any new event; timeout allows progress checks - event = await asyncio.wait_for(self._ctx.next_event(), timeout=0.05) - yield event - except asyncio.TimeoutError: - # Periodically continue to let iteration advance - continue - except asyncio.CancelledError: - # Propagate cancellation to the iteration task to avoid orphaned work - iteration_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await iteration_task - raise + while self._iteration < self._max_iterations: + logger.info(f"Starting superstep {self._iteration + 1}") + yield WorkflowEvent.superstep_started(iteration=self._iteration + 1) - # Propagate errors from iteration, but first surface any pending events - try: + # Run iteration concurrently with live event streaming: we poll + # for new events while the iteration coroutine progresses. + iteration_task = asyncio.create_task(self._run_iteration()) + try: + while not iteration_task.done(): + try: + # Wait briefly for any new event; timeout allows progress checks + event = await asyncio.wait_for(self._ctx.next_event(), timeout=0.05) + yield event + except asyncio.TimeoutError: + # Periodically continue to let iteration advance + continue + except asyncio.CancelledError: + # Propagate cancellation to the iteration task to avoid orphaned work + iteration_task.cancel() + with contextlib.suppress(asyncio.CancelledError): await iteration_task - except Exception: - # Make sure failure-related events (like ExecutorFailedEvent) are surfaced - if await self._ctx.has_events(): - for event in await self._ctx.drain_events(): - yield event - raise - self._iteration += 1 + raise - # Drain any straggler events emitted at tail end + # Propagate errors from iteration, but first surface any pending events + try: + await iteration_task + except Exception: + # Make sure failure-related events (like ExecutorFailedEvent) are surfaced 
if await self._ctx.has_events(): for event in await self._ctx.drain_events(): yield event + raise + self._iteration += 1 - logger.info(f"Completed superstep {self._iteration}") + # Drain any straggler events emitted at tail end + if await self._ctx.has_events(): + for event in await self._ctx.drain_events(): + yield event - # Commit pending state changes at superstep boundary - self._state.commit() + logger.info(f"Completed superstep {self._iteration}") - # Create checkpoint after each superstep iteration - previous_checkpoint_id = await self._create_checkpoint_if_enabled(previous_checkpoint_id) + # Commit pending state changes at superstep boundary + self._state.commit() - yield WorkflowEvent.superstep_completed(iteration=self._iteration) + # Create checkpoint after each superstep iteration + previous_checkpoint_id = await self._create_checkpoint_if_enabled(previous_checkpoint_id) - # Check for convergence: no more messages to process - if not await self._ctx.has_messages(): - break + yield WorkflowEvent.superstep_completed(iteration=self._iteration) - if self._iteration >= self._max_iterations and await self._ctx.has_messages(): - raise WorkflowConvergenceException(f"Runner did not converge after {self._max_iterations} iterations.") + # Check for convergence: no more messages to process + if not await self._ctx.has_messages(): + break - logger.info(f"Workflow completed after {self._iteration} supersteps") - self._resumed_from_checkpoint = False # Reset resume flag for next run - finally: - self._lifecycle = _RunnerLifecycle.IDLE + if self._iteration >= self._max_iterations and await self._ctx.has_messages(): + raise WorkflowConvergenceException(f"Runner did not converge after {self._max_iterations} iterations.") + + logger.info(f"Workflow completed after {self._iteration} supersteps") + self._resumed_from_checkpoint = False # Reset resume flag for next run async def _run_iteration(self) -> None: """Run a single iteration of the workflow. 
diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index 4a36e39574..6e240a1096 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -11,6 +11,7 @@ import logging import types import uuid import warnings +import weakref from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, Sequence from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Literal, overload @@ -362,6 +363,14 @@ class Workflow(DictConvertible): # for a freshly built workflow that has not yet been run. self._status: WorkflowRunState = WorkflowRunState.IDLE + # Weak reference to the in-flight run's ``ResponseStream``. Used as the single + # concurrency lock: if the previous stream is still alive, ``run()`` rejects a + # new run synchronously (before any await). When the stream is fully consumed + # ``_run_core``'s finally clears this; if the caller drops the stream without + # ever iterating, the weakref dereferences to ``None`` once Python collects it, + # so a subsequent ``run()`` is allowed. + self._active_run: weakref.ref[ResponseStream[WorkflowEvent, WorkflowRunResult]] | None = None + @property def status(self) -> WorkflowRunState: """Return the current run-level status of this workflow instance. @@ -734,16 +743,20 @@ class Workflow(DictConvertible): # Validate parameters first so misuse fails before we touch any run state. self._validate_run_params(message, responses, checkpoint_id) - # Acquire the run lock synchronously - before constructing the ResponseStream - # or yielding control to the event loop - so a second concurrent ``run`` call - # is rejected immediately rather than slipping past the guard while the first - # call is suspended inside its async generator. 
- try: - self._runner.reserve() - except WorkflowRunnerException as exc: + # Concurrency check: reject a second run synchronously - before constructing + # the ResponseStream or yielding control to the event loop - so a concurrent + # ``run`` call can't slip past the guard while the first call is suspended + # inside its async generator. The ``ResponseStream`` returned below is the + # lock: as long as the caller holds a reference to it, ``self._active_run()`` + # resolves to a live object and a new ``run`` is rejected. When the stream is + # fully consumed, ``_run_core``'s finally clears the attribute. When the + # caller drops the stream without iterating, garbage collection invalidates + # the weakref, so a subsequent ``run`` is permitted. + existing_stream = self._active_run() if self._active_run is not None else None + if existing_stream is not None: raise WorkflowException( "Workflow is already running; concurrent runs are not allowed on the same instance." - ) from exc + ) response_stream = ResponseStream[WorkflowEvent, WorkflowRunResult]( self._run_core( @@ -757,6 +770,7 @@ class Workflow(DictConvertible): ), finalizer=functools.partial(self._finalize_events, include_status_events=include_status_events), ) + self._active_run = weakref.ref(response_stream) if stream: return response_stream @@ -823,12 +837,14 @@ class Workflow(DictConvertible): continue yield event finally: - # Release the run lock acquired synchronously by ``run()`` via ``reserve``. - # ``run_until_convergence`` also clears it in its own ``finally`` once its - # body runs; this clause additionally covers the case where this generator - # raises (or is closed) before that body is reached - e.g. the in-flight - # messages check above. ``release`` is idempotent. - self._runner.release() + # Clear the active-run weakref so a subsequent ``run()`` is allowed. 
+ # ``run()`` set this synchronously after constructing the ResponseStream; + # we clear it here when this generator exits (success, error, or early + # close), so a fully driven stream needs no GC-time finalizer. + # NOTE(review): the assignment below is unconditional. If an abandoned, + # partially iterated stream is only closed after a newer run stored its + # own weakref, this would clear that newer lock - confirm this cannot occur. + self._active_run = None if checkpoint_storage is not None: self._runner.context.clear_runtime_checkpoint_storage() @@ -1161,4 +1177,11 @@ class Workflow(DictConvertible): allowed when the workflow is idle to avoid clobbering in-flight run state. """ + existing_stream = self._active_run() if self._active_run is not None else None + if existing_stream is not None: + raise WorkflowRunnerException( + "Cannot reset the workflow while a run is in progress. " + "Wait for the current run to complete before calling reset_for_new_run()." + ) + self._active_run = None await self._runner.reset_for_new_run() diff --git a/python/packages/core/tests/workflow/test_runner.py b/python/packages/core/tests/workflow/test_runner.py index 0019ed1e92..b41e43beee 100644 --- a/python/packages/core/tests/workflow/test_runner.py +++ b/python/packages/core/tests/workflow/test_runner.py @@ -17,7 +17,6 @@ from agent_framework import ( WorkflowContext, WorkflowConvergenceException, WorkflowEvent, - WorkflowRunnerException, WorkflowRunState, handler, ) @@ -106,7 +105,6 @@ async def test_runner_run_until_convergence(): state, # state ctx, # runner_context ) - runner.reserve() async for event in runner.run_until_convergence(): assert isinstance(event, WorkflowEvent) if event.type == "output": @@ -148,7 +146,6 @@ async def test_runner_run_until_convergence_not_completed(): WorkflowConvergenceException, match="Runner did not converge after 5 iterations.", ): - runner.reserve() async for event in runner.run_until_convergence(): assert event.type != "status" or event.state != WorkflowRunState.IDLE @@ -307,16 +304,22 @@ 
async def test_fanout_edge_runner_delivers_to_multiple_targets_concurrently() -> assert probe_target.call_count == 1 -async def test_runner_run_until_convergence_requires_reservation(): - """run_until_convergence refuses to start without a prior reserve().""" +async def test_runner_run_until_convergence_runs_sequentially(): + """run_until_convergence can be invoked back-to-back on the same Runner. + + The Runner itself does not enforce concurrency; that responsibility lives on + :class:`Workflow`. This test simply confirms the Runner is reusable across + sequential runs. + """ runner = _make_runner() - with pytest.raises(WorkflowRunnerException, match="Runner must be reserved"): - async for _ in runner.run_until_convergence(): - pass + async for _ in runner.run_until_convergence(): + pass + async for _ in runner.run_until_convergence(): + pass def _make_runner() -> Runner: - """Build a minimal runner for lifecycle tests.""" + """Build a minimal runner for runner-level tests.""" return Runner( [], {}, @@ -327,57 +330,11 @@ def _make_runner() -> Runner: ) -def test_runner_reserve_twice_raises(): - """Calling reserve() while already reserved rejects the second caller. - - This is what guards Workflow.run against a concurrent caller slipping in - between the first call's synchronous reserve() and its first await. - """ - runner = _make_runner() - runner.reserve() - with pytest.raises(WorkflowRunnerException, match="Runner is already reserved or running."): - runner.reserve() - - -def test_runner_reserve_after_release_is_accepted(): - """Sequential runs are permitted; only concurrent ones are blocked.""" - runner = _make_runner() - runner.reserve() - runner.release() - runner.reserve() # should not raise - - -def test_runner_release_when_idle_is_noop(): - """release() on an idle runner does not affect a subsequent reserve(). 
- - Workflow._run_core's finally always calls release(), even when - run_until_convergence already cleared the lock in its own finally; - that double-release must not lock out the next run. - """ - runner = _make_runner() - runner.release() # already idle - must not raise or wedge state - runner.reserve() # next run still allowed - - -async def test_runner_run_until_convergence_consumes_reservation(): - """run_until_convergence accepts a prior reservation and runs to completion.""" - runner = _make_runner() - runner.reserve() - async for _ in runner.run_until_convergence(): - pass - # A second run after the first completes must be accepted. - runner.reserve() - async for _ in runner.run_until_convergence(): - pass - - async def test_runner_accepts_new_run_after_previous_failure(): - """A failed run must not leave the runner locked out of future runs. + """A failed run must not leave the Runner unable to start a new run. - After the first run raises, a fresh ``reserve()`` and - ``run_until_convergence()`` must succeed (or fail for a different reason - - e.g. residual messages still don't converge - but never with the - lock-rejection ``"Runner is already running."``). + After the first run raises, ``run_until_convergence()`` must be callable + again and not surface any lifecycle-related rejection. """ executor_a = MockExecutor(id="executor_a") executor_b = MockExecutor(id="executor_b") @@ -392,13 +349,12 @@ async def test_runner_accepts_new_run_after_previous_failure(): await executor_a.execute(MockMessage(data=0), ["START"], state, ctx) - runner.reserve() with pytest.raises(WorkflowConvergenceException): async for _ in runner.run_until_convergence(): pass - # The runner should accept a fresh reservation and run again. - runner.reserve() # must not raise + # A second run on the same Runner must not be blocked by stale lifecycle + # state from the failed run. 
try: async for _ in runner.run_until_convergence(): pass @@ -406,40 +362,6 @@ async def test_runner_accepts_new_run_after_previous_failure(): assert "Runner is already running" not in str(exc), "Runner stayed locked after a failed run" -async def test_runner_rejects_concurrent_run_until_convergence(): - """While a run is in progress, a second ``reserve()`` is rejected. - - Confirms the run lock is held for the full duration of the run, not just - synchronously between ``reserve()`` and the first ``__anext__`` call. - """ - runner = _make_runner() - - started = asyncio.Event() - release = asyncio.Event() - - async def _slow_run(): - runner.reserve() - async for _ in runner.run_until_convergence(): - if not started.is_set(): - started.set() - await release.wait() - - task = asyncio.create_task(_slow_run()) - await started.wait() # first run is now executing - - try: - with pytest.raises(WorkflowRunnerException, match="Runner is already reserved or running."): - runner.reserve() - finally: - release.set() - await task - - # And after the first run finishes, a new reservation + run must be accepted. 
- runner.reserve() - async for _ in runner.run_until_convergence(): - pass - - async def test_runner_emits_runner_completion_for_agent_response_without_targets(): ctx = InProcRunnerContext() runner = Runner([], {}, State(), ctx, "test_name", graph_signature_hash="test_hash") @@ -451,7 +373,6 @@ async def test_runner_emits_runner_completion_for_agent_response_without_targets ) ) - runner.reserve() events: list[WorkflowEvent] = [event async for event in runner.run_until_convergence()] # The runner should complete without errors when handling AgentExecutorResponse without targets # No specific events are expected since there are no executors to process the message @@ -508,7 +429,6 @@ async def test_runner_cancellation_stops_active_executor(): async for _ in runner.run_until_convergence(): pass - runner.reserve() task = asyncio.create_task(run_workflow()) # Wait for executor_a to complete (0.3s) and executor_b to start but not finish @@ -570,7 +490,6 @@ async def test_runner_iteration_exception_drains_events(): ) events: list[WorkflowEvent] = [] - runner.reserve() with pytest.raises(RuntimeError, match="Simulated executor failure"): async for event in runner.run_until_convergence(): events.append(event) @@ -681,7 +600,6 @@ async def test_runner_checkpoint_creation_failure(): # Should complete without raising, even though checkpointing fails result: int | None = None - runner.reserve() async for event in runner.run_until_convergence(): if event.type == "output": result = event.data @@ -878,7 +796,6 @@ async def test_runner_with_pre_loop_events(): await ctx.add_event(WorkflowEvent("output", executor_id="test_executor", data="pre-loop-output")) events: list[WorkflowEvent] = [] - runner.reserve() async for event in runner.run_until_convergence(): events.append(event) @@ -926,7 +843,6 @@ async def test_runner_drains_straggler_events(): ) events: list[WorkflowEvent] = [] - runner.reserve() async for event in runner.run_until_convergence(): events.append(event) @@ -980,7 
+896,6 @@ async def test_runner_checkpoint_with_resumed_flag(): ) # Run until convergence - runner.reserve() async for _ in runner.run_until_convergence(): pass @@ -1047,7 +962,6 @@ async def test_runner_drains_events_on_iteration_exception(): ) events: list[WorkflowEvent] = [] - runner.reserve() with pytest.raises(RuntimeError, match="Executor failed with pending events"): async for event in runner.run_until_convergence(): events.append(event) @@ -1104,7 +1018,6 @@ async def test_runner_drains_straggler_events_at_iteration_end(): ) events: list[WorkflowEvent] = [] - runner.reserve() async for event in runner.run_until_convergence(): events.append(event) @@ -1271,7 +1184,6 @@ async def test_runner_can_run_again_after_reset_for_new_run(): # First run: drives MockExecutor's loop until it yields the terminal value. await executor_a.execute(MockMessage(data=0), ["START"], state, ctx) - runner.reserve() async for _ in runner.run_until_convergence(): pass assert runner._iteration == 10 # pyright: ignore[reportPrivateUsage] @@ -1280,7 +1192,6 @@ async def test_runner_can_run_again_after_reset_for_new_run(): # Second run: must succeed cleanly using the same runner instance. 
await executor_a.execute(MockMessage(data=0), ["START"], state, ctx) - runner.reserve() second_run_outputs: list[int] = [] async for event in runner.run_until_convergence(): if event.type == "output": @@ -1290,77 +1201,4 @@ async def test_runner_can_run_again_after_reset_for_new_run(): assert runner._iteration == 10 # pyright: ignore[reportPrivateUsage] -async def test_runner_reset_for_new_run_rejected_when_reserved(): - """reset_for_new_run refuses to run when the runner is reserved but not yet running.""" - runner = _make_runner() - runner.reserve() - - with pytest.raises(WorkflowRunnerException, match="Cannot reset the runner while a run is in progress"): - await runner.reset_for_new_run() - - -async def test_runner_reset_for_new_run_rejected_while_running(): - """reset_for_new_run refuses to run while a run is mid-execution.""" - runner = _make_runner() - - started = asyncio.Event() - release = asyncio.Event() - - async def _slow_run() -> None: - runner.reserve() - async for _ in runner.run_until_convergence(): - if not started.is_set(): - started.set() - await release.wait() - - task = asyncio.create_task(_slow_run()) - await started.wait() # first run is now executing inside run_until_convergence - - try: - with pytest.raises(WorkflowRunnerException, match="Cannot reset the runner while a run is in progress"): - await runner.reset_for_new_run() - finally: - release.set() - await task - - # Once the run drained, reset must succeed again. 
- await runner.reset_for_new_run() - - -async def test_runner_reset_for_new_run_does_not_mutate_when_rejected(): - """When reset is rejected, the runner's iteration counter and state are untouched.""" - - class TrackingExecutor(MockExecutor): - def __init__(self, id: str) -> None: - super().__init__(id=id) - self.reset_calls = 0 - - async def reset(self) -> None: - self.reset_calls += 1 - - executor = TrackingExecutor(id="executor") - state = State() - state.set("preserved", 42) - state.commit() - - runner = Runner( - [], - {executor.id: executor}, - state, - InProcRunnerContext(), - "test_name", - graph_signature_hash="test_hash", - ) - runner._iteration = 7 # pyright: ignore[reportPrivateUsage] - runner.reserve() - - with pytest.raises(WorkflowRunnerException): - await runner.reset_for_new_run() - - # Nothing was mutated by the failed reset. - assert runner._iteration == 7 # pyright: ignore[reportPrivateUsage] - assert state.get("preserved") == 42 - assert executor.reset_calls == 0 - - # endregion: Tests for Runner.reset_for_new_run() diff --git a/python/packages/core/tests/workflow/test_workflow.py b/python/packages/core/tests/workflow/test_workflow.py index 9267a17f68..4e607ca8f9 100644 --- a/python/packages/core/tests/workflow/test_workflow.py +++ b/python/packages/core/tests/workflow/test_workflow.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio +import gc import tempfile from collections.abc import AsyncIterable, Awaitable, Sequence from dataclasses import dataclass, field @@ -846,6 +847,92 @@ async def test_workflow_concurrent_execution_prevention_mixed_methods(): assert result.get_final_state() == WorkflowRunState.IDLE +async def test_workflow_sequential_runs_after_completion() -> None: + """A completed run must release the runner so the next ``run`` succeeds. 
+ + This is the happy-path counterpart to the concurrent-run guard tests: + those tests verify that a *concurrent* run is rejected, but they do not + verify that the lock is actually released afterwards. This test + exercises that release path explicitly across the three call shapes + (non-streaming, streaming-iterated, streaming-via-get_final_response) + and across multiple consecutive turns to catch lock leaks. + """ + executor = IncrementExecutor(id="seq_executor", limit=3, increment=1) + workflow = WorkflowBuilder(start_executor=executor).build() + + # Non-streaming -> non-streaming + r1 = await workflow.run(NumberMessage(data=0)) + assert r1.get_final_state() == WorkflowRunState.IDLE + + r2 = await workflow.run(NumberMessage(data=0)) + assert r2.get_final_state() == WorkflowRunState.IDLE + + # Non-streaming -> streaming-iterated + stream_events: list[WorkflowEvent] = [] + async for event in workflow.run(NumberMessage(data=0), stream=True): + stream_events.append(event) + assert any(e.type == "status" and e.state == WorkflowRunState.IDLE for e in stream_events) + + # Streaming -> streaming via get_final_response (no manual iteration) + r3 = await workflow.run(NumberMessage(data=0), stream=True).get_final_response() + assert r3.get_final_state() == WorkflowRunState.IDLE + + # Streaming -> non-streaming (back to the start) + r4 = await workflow.run(NumberMessage(data=0)) + assert r4.get_final_state() == WorkflowRunState.IDLE + + +async def test_workflow_unconsumed_stream_releases_run_lock() -> None: + """An unconsumed stream must not leak the run lock. + + ``Workflow.run`` records a weak reference to the returned stream + *synchronously* so that concurrent callers are rejected immediately. 
The + reference is normally cleared by ``_run_core``'s ``finally`` once the + stream is iterated. If the caller never iterates the stream, the weakref + must go dead when the stream is collected - otherwise every subsequent + ``Workflow.run`` call on this instance would fail with the concurrent-run error. + """ + executor = IncrementExecutor(id="unconsumed_stream_exec", limit=3, increment=1) + workflow = WorkflowBuilder(start_executor=executor).build() + + # Build a stream and immediately drop it without iterating. + stream = workflow.run(NumberMessage(data=0), stream=True) + assert stream is not None # silence unused-variable warnings; stream is GC'd below + del stream + gc.collect() + # Yield to the event loop so any scheduled finalizer work can run. + await asyncio.sleep(0) + + # The dropped stream's weakref is now dead; a fresh run must succeed. + result = await workflow.run(NumberMessage(data=0)) + assert result.get_final_state() == WorkflowRunState.IDLE + + +async def test_workflow_unawaited_run_coroutine_releases_run_lock() -> None: + """An un-awaited non-streaming ``run()`` coroutine must also not leak the lock. + + ``Workflow.run`` (non-streaming) returns a coroutine produced by + ``ResponseStream.get_final_response``. The underlying ResponseStream is + held alive by that coroutine, so dropping the coroutine without + awaiting it must still free the run lock via the same weakref + expiry relied on for unconsumed streams. + """ + executor = IncrementExecutor(id="unawaited_run_exec", limit=3, increment=1) + workflow = WorkflowBuilder(start_executor=executor).build() + + coro = workflow.run(NumberMessage(data=0)) + # Closing suppresses the "coroutine was never awaited" warning. We cast to + # ``Any`` because the typed return is ``Awaitable[...]``; in practice it is + # a coroutine that exposes ``close``. 
+ cast(Any, coro).close() + del coro + gc.collect() + await asyncio.sleep(0) + + result = await workflow.run(NumberMessage(data=0)) + assert result.get_final_state() == WorkflowRunState.IDLE + + class _StreamingTestAgent(BaseAgent): """Test agent that supports both streaming and non-streaming modes.""" @@ -1397,7 +1484,7 @@ async def test_workflow_reset_for_new_run_rejected_during_streaming_run() -> Non await asyncio.sleep(0.02) try: - with pytest.raises(WorkflowRunnerException, match="Cannot reset the runner while a run is in progress"): + with pytest.raises(WorkflowRunnerException, match="Cannot reset the workflow while a run is in progress"): await workflow.reset_for_new_run() finally: await task