mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Remove lifecycle flag
This commit is contained in:
@@ -5,13 +5,11 @@ import contextlib
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from ..exceptions import (
|
||||
WorkflowCheckpointException,
|
||||
WorkflowConvergenceException,
|
||||
WorkflowRunnerException,
|
||||
)
|
||||
from ._checkpoint import CheckpointID, CheckpointStorage, WorkflowCheckpoint
|
||||
from ._const import EXECUTOR_STATE_KEY
|
||||
@@ -28,25 +26,6 @@ from ._state import State
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _RunnerLifecycle(Enum):
|
||||
"""Lifecycle of a single :class:`Runner` invocation.
|
||||
|
||||
Three states keep concurrent-run protection in one place while still letting
|
||||
:meth:`Workflow.run` reserve the runner synchronously (before any await):
|
||||
|
||||
* ``IDLE`` - no run in progress.
|
||||
* ``RESERVED`` - :meth:`Runner.reserve` was called by an external caller but
|
||||
:meth:`Runner.run_until_convergence` has not yet entered its body. This is the
|
||||
window during which ``Workflow.run`` has handed back a ``ResponseStream`` that
|
||||
hasn't been iterated yet. ``run_until_convergence`` requires this state.
|
||||
* ``RUNNING`` - the body of :meth:`Runner.run_until_convergence` is executing.
|
||||
"""
|
||||
|
||||
IDLE = "idle"
|
||||
RESERVED = "reserved"
|
||||
RUNNING = "running"
|
||||
|
||||
|
||||
class Runner:
|
||||
"""A class to run a workflow in Pregel supersteps."""
|
||||
|
||||
@@ -83,7 +62,6 @@ class Runner:
|
||||
self._iteration = 0
|
||||
self._max_iterations = max_iterations
|
||||
self._state = state
|
||||
self._lifecycle: _RunnerLifecycle = _RunnerLifecycle.IDLE
|
||||
self._resumed_from_checkpoint = False # Track whether we resumed
|
||||
|
||||
@property
|
||||
@@ -96,31 +74,6 @@ class Runner:
|
||||
"""Get the shared state for the workflow."""
|
||||
return self._state
|
||||
|
||||
def reserve(self) -> None:
|
||||
"""Synchronously reserve the runner for an upcoming run.
|
||||
|
||||
This is the **only** way to acquire the run lock. :meth:`run_until_convergence`
|
||||
requires the runner to be in :attr:`_RunnerLifecycle.RESERVED` and will refuse
|
||||
to start otherwise. Reserving synchronously lets callers (notably
|
||||
:meth:`Workflow.run`) reject concurrent runs *before* any ``await`` or
|
||||
async-generator suspension - otherwise a second caller could slip past a
|
||||
flag-based guard while the first is still suspended above
|
||||
:meth:`run_until_convergence`. The lock is released by
|
||||
:meth:`run_until_convergence` in its ``finally`` clause once it begins, or by
|
||||
:meth:`release` when the run never starts (for example, if early validation
|
||||
raises before ``run_until_convergence`` is reached).
|
||||
|
||||
Raises:
|
||||
WorkflowRunnerException: If the runner is already reserved or running.
|
||||
"""
|
||||
if self._lifecycle is not _RunnerLifecycle.IDLE:
|
||||
raise WorkflowRunnerException("Runner is already reserved or running.")
|
||||
self._lifecycle = _RunnerLifecycle.RESERVED
|
||||
|
||||
def release(self) -> None:
|
||||
"""Release the runner's run lock. Idempotent; safe to call when already idle."""
|
||||
self._lifecycle = _RunnerLifecycle.IDLE
|
||||
|
||||
def reset_iteration_count(self) -> None:
|
||||
"""Reset the iteration count to zero.
|
||||
|
||||
@@ -138,15 +91,10 @@ class Runner:
|
||||
This is useful when reusing the same workflow instance for a different run
|
||||
that is independent from prior runs.
|
||||
|
||||
Raises:
|
||||
WorkflowRunnerException: If the runner is reserved or running. Reset is only
|
||||
allowed when the runner is idle to avoid clobbering in-flight run state.
|
||||
Concurrent-run rejection lives at the :class:`Workflow` level via the
|
||||
active-run weak reference, so this method does not re-validate that
|
||||
no run is in progress; callers are expected to have already done so.
|
||||
"""
|
||||
if self._lifecycle is not _RunnerLifecycle.IDLE:
|
||||
raise WorkflowRunnerException(
|
||||
"Cannot reset the runner while a run is in progress. "
|
||||
"Wait for the current run to complete before calling reset_for_new_run()."
|
||||
)
|
||||
self.reset_iteration_count()
|
||||
self._ctx.reset_for_new_run()
|
||||
self._state.clear()
|
||||
@@ -156,92 +104,79 @@ class Runner:
|
||||
|
||||
async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]:
|
||||
"""Run the workflow until no more messages are sent."""
|
||||
# Mandatory reservation: callers must reserve() the runner first. This makes
|
||||
# ``reserve`` the single entry point that takes the run lock, so all
|
||||
# concurrent-run rejection happens there. Any non-RESERVED state here is a
|
||||
# contract violation - either the caller forgot to reserve, or another run
|
||||
# is already in progress.
|
||||
if self._lifecycle is not _RunnerLifecycle.RESERVED:
|
||||
raise WorkflowRunnerException(
|
||||
"Runner must be reserved via Runner.reserve() before calling run_until_convergence()."
|
||||
)
|
||||
self._lifecycle = _RunnerLifecycle.RUNNING
|
||||
|
||||
previous_checkpoint_id: CheckpointID | None = None
|
||||
try:
|
||||
# Emit any events already produced prior to entering loop
|
||||
if await self._ctx.has_events():
|
||||
logger.info("Yielding pre-loop events")
|
||||
for event in await self._ctx.drain_events():
|
||||
yield event
|
||||
|
||||
# Create the first checkpoint. Checkpoints are usually considered to be created at the end of an iteration,
|
||||
# we can think of the first checkpoint as being created at the end of a "superstep 0" which captures the
|
||||
# states after which the start executor has run. Note that we execute the start executor outside of the
|
||||
# main iteration loop.
|
||||
if await self._ctx.has_messages() and not self._resumed_from_checkpoint:
|
||||
previous_checkpoint_id = await self._create_checkpoint_if_enabled(previous_checkpoint_id)
|
||||
# Emit any events already produced prior to entering loop
|
||||
if await self._ctx.has_events():
|
||||
logger.info("Yielding pre-loop events")
|
||||
for event in await self._ctx.drain_events():
|
||||
yield event
|
||||
|
||||
while self._iteration < self._max_iterations:
|
||||
logger.info(f"Starting superstep {self._iteration + 1}")
|
||||
yield WorkflowEvent.superstep_started(iteration=self._iteration + 1)
|
||||
# Create the first checkpoint. Checkpoints are usually considered to be created at the end of an iteration,
|
||||
# we can think of the first checkpoint as being created at the end of a "superstep 0" which captures the
|
||||
# states after which the start executor has run. Note that we execute the start executor outside of the
|
||||
# main iteration loop.
|
||||
if await self._ctx.has_messages() and not self._resumed_from_checkpoint:
|
||||
previous_checkpoint_id = await self._create_checkpoint_if_enabled(previous_checkpoint_id)
|
||||
|
||||
# Run iteration concurrently with live event streaming: we poll
|
||||
# for new events while the iteration coroutine progresses.
|
||||
iteration_task = asyncio.create_task(self._run_iteration())
|
||||
try:
|
||||
while not iteration_task.done():
|
||||
try:
|
||||
# Wait briefly for any new event; timeout allows progress checks
|
||||
event = await asyncio.wait_for(self._ctx.next_event(), timeout=0.05)
|
||||
yield event
|
||||
except asyncio.TimeoutError:
|
||||
# Periodically continue to let iteration advance
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
# Propagate cancellation to the iteration task to avoid orphaned work
|
||||
iteration_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await iteration_task
|
||||
raise
|
||||
while self._iteration < self._max_iterations:
|
||||
logger.info(f"Starting superstep {self._iteration + 1}")
|
||||
yield WorkflowEvent.superstep_started(iteration=self._iteration + 1)
|
||||
|
||||
# Propagate errors from iteration, but first surface any pending events
|
||||
try:
|
||||
# Run iteration concurrently with live event streaming: we poll
|
||||
# for new events while the iteration coroutine progresses.
|
||||
iteration_task = asyncio.create_task(self._run_iteration())
|
||||
try:
|
||||
while not iteration_task.done():
|
||||
try:
|
||||
# Wait briefly for any new event; timeout allows progress checks
|
||||
event = await asyncio.wait_for(self._ctx.next_event(), timeout=0.05)
|
||||
yield event
|
||||
except asyncio.TimeoutError:
|
||||
# Periodically continue to let iteration advance
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
# Propagate cancellation to the iteration task to avoid orphaned work
|
||||
iteration_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await iteration_task
|
||||
except Exception:
|
||||
# Make sure failure-related events (like ExecutorFailedEvent) are surfaced
|
||||
if await self._ctx.has_events():
|
||||
for event in await self._ctx.drain_events():
|
||||
yield event
|
||||
raise
|
||||
self._iteration += 1
|
||||
raise
|
||||
|
||||
# Drain any straggler events emitted at tail end
|
||||
# Propagate errors from iteration, but first surface any pending events
|
||||
try:
|
||||
await iteration_task
|
||||
except Exception:
|
||||
# Make sure failure-related events (like ExecutorFailedEvent) are surfaced
|
||||
if await self._ctx.has_events():
|
||||
for event in await self._ctx.drain_events():
|
||||
yield event
|
||||
raise
|
||||
self._iteration += 1
|
||||
|
||||
logger.info(f"Completed superstep {self._iteration}")
|
||||
# Drain any straggler events emitted at tail end
|
||||
if await self._ctx.has_events():
|
||||
for event in await self._ctx.drain_events():
|
||||
yield event
|
||||
|
||||
# Commit pending state changes at superstep boundary
|
||||
self._state.commit()
|
||||
logger.info(f"Completed superstep {self._iteration}")
|
||||
|
||||
# Create checkpoint after each superstep iteration
|
||||
previous_checkpoint_id = await self._create_checkpoint_if_enabled(previous_checkpoint_id)
|
||||
# Commit pending state changes at superstep boundary
|
||||
self._state.commit()
|
||||
|
||||
yield WorkflowEvent.superstep_completed(iteration=self._iteration)
|
||||
# Create checkpoint after each superstep iteration
|
||||
previous_checkpoint_id = await self._create_checkpoint_if_enabled(previous_checkpoint_id)
|
||||
|
||||
# Check for convergence: no more messages to process
|
||||
if not await self._ctx.has_messages():
|
||||
break
|
||||
yield WorkflowEvent.superstep_completed(iteration=self._iteration)
|
||||
|
||||
if self._iteration >= self._max_iterations and await self._ctx.has_messages():
|
||||
raise WorkflowConvergenceException(f"Runner did not converge after {self._max_iterations} iterations.")
|
||||
# Check for convergence: no more messages to process
|
||||
if not await self._ctx.has_messages():
|
||||
break
|
||||
|
||||
logger.info(f"Workflow completed after {self._iteration} supersteps")
|
||||
self._resumed_from_checkpoint = False # Reset resume flag for next run
|
||||
finally:
|
||||
self._lifecycle = _RunnerLifecycle.IDLE
|
||||
if self._iteration >= self._max_iterations and await self._ctx.has_messages():
|
||||
raise WorkflowConvergenceException(f"Runner did not converge after {self._max_iterations} iterations.")
|
||||
|
||||
logger.info(f"Workflow completed after {self._iteration} supersteps")
|
||||
self._resumed_from_checkpoint = False # Reset resume flag for next run
|
||||
|
||||
async def _run_iteration(self) -> None:
|
||||
"""Run a single iteration of the workflow.
|
||||
|
||||
@@ -11,6 +11,7 @@ import logging
|
||||
import types
|
||||
import uuid
|
||||
import warnings
|
||||
import weakref
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Literal, overload
|
||||
@@ -362,6 +363,14 @@ class Workflow(DictConvertible):
|
||||
# for a freshly built workflow that has not yet been run.
|
||||
self._status: WorkflowRunState = WorkflowRunState.IDLE
|
||||
|
||||
# Weak reference to the in-flight run's ``ResponseStream``. Used as the single
|
||||
# concurrency lock: if the previous stream is still alive, ``run()`` rejects a
|
||||
# new run synchronously (before any await). When the stream is fully consumed
|
||||
# ``_run_core``'s finally clears this; if the caller drops the stream without
|
||||
# ever iterating, the weakref dereferences to ``None`` once Python collects it,
|
||||
# so a subsequent ``run()`` is allowed.
|
||||
self._active_run: weakref.ref[ResponseStream[WorkflowEvent, WorkflowRunResult]] | None = None
|
||||
|
||||
@property
|
||||
def status(self) -> WorkflowRunState:
|
||||
"""Return the current run-level status of this workflow instance.
|
||||
@@ -734,16 +743,20 @@ class Workflow(DictConvertible):
|
||||
# Validate parameters first so misuse fails before we touch any run state.
|
||||
self._validate_run_params(message, responses, checkpoint_id)
|
||||
|
||||
# Acquire the run lock synchronously - before constructing the ResponseStream
|
||||
# or yielding control to the event loop - so a second concurrent ``run`` call
|
||||
# is rejected immediately rather than slipping past the guard while the first
|
||||
# call is suspended inside its async generator.
|
||||
try:
|
||||
self._runner.reserve()
|
||||
except WorkflowRunnerException as exc:
|
||||
# Concurrency check: reject a second run synchronously - before constructing
|
||||
# the ResponseStream or yielding control to the event loop - so a concurrent
|
||||
# ``run`` call can't slip past the guard while the first call is suspended
|
||||
# inside its async generator. The ``ResponseStream`` returned below is the
|
||||
# lock: as long as the caller holds a reference to it, ``self._active_run()``
|
||||
# resolves to a live object and a new ``run`` is rejected. When the stream is
|
||||
# fully consumed, ``_run_core``'s finally clears the attribute. When the
|
||||
# caller drops the stream without iterating, garbage collection invalidates
|
||||
# the weakref, so a subsequent ``run`` is permitted.
|
||||
existing_stream = self._active_run() if self._active_run is not None else None
|
||||
if existing_stream is not None:
|
||||
raise WorkflowException(
|
||||
"Workflow is already running; concurrent runs are not allowed on the same instance."
|
||||
) from exc
|
||||
)
|
||||
|
||||
response_stream = ResponseStream[WorkflowEvent, WorkflowRunResult](
|
||||
self._run_core(
|
||||
@@ -757,6 +770,7 @@ class Workflow(DictConvertible):
|
||||
),
|
||||
finalizer=functools.partial(self._finalize_events, include_status_events=include_status_events),
|
||||
)
|
||||
self._active_run = weakref.ref(response_stream)
|
||||
|
||||
if stream:
|
||||
return response_stream
|
||||
@@ -823,12 +837,14 @@ class Workflow(DictConvertible):
|
||||
continue
|
||||
yield event
|
||||
finally:
|
||||
# Release the run lock acquired synchronously by ``run()`` via ``reserve``.
|
||||
# ``run_until_convergence`` also clears it in its own ``finally`` once its
|
||||
# body runs; this clause additionally covers the case where this generator
|
||||
# raises (or is closed) before that body is reached - e.g. the in-flight
|
||||
# messages check above. ``release`` is idempotent.
|
||||
self._runner.release()
|
||||
# Clear the active-run weakref so a subsequent ``run()`` is allowed.
|
||||
# ``run()`` set this synchronously after constructing the ResponseStream;
|
||||
# we clear it here once the run has finished (success, error, early
|
||||
# close, or partial iteration). This is in-band, so by the time the
|
||||
# caller's stream is later garbage collected, ``_active_run`` is already
|
||||
# ``None`` (or has been replaced by a newer run's weakref) - no GC-time
|
||||
# finalizer is needed.
|
||||
self._active_run = None
|
||||
if checkpoint_storage is not None:
|
||||
self._runner.context.clear_runtime_checkpoint_storage()
|
||||
|
||||
@@ -1161,4 +1177,11 @@ class Workflow(DictConvertible):
|
||||
allowed when the workflow is idle to avoid clobbering in-flight run state.
|
||||
|
||||
"""
|
||||
existing_stream = self._active_run() if self._active_run is not None else None
|
||||
if existing_stream is not None:
|
||||
raise WorkflowRunnerException(
|
||||
"Cannot reset the workflow while a run is in progress. "
|
||||
"Wait for the current run to complete before calling reset_for_new_run()."
|
||||
)
|
||||
self._active_run = None
|
||||
await self._runner.reset_for_new_run()
|
||||
|
||||
@@ -17,7 +17,6 @@ from agent_framework import (
|
||||
WorkflowContext,
|
||||
WorkflowConvergenceException,
|
||||
WorkflowEvent,
|
||||
WorkflowRunnerException,
|
||||
WorkflowRunState,
|
||||
handler,
|
||||
)
|
||||
@@ -106,7 +105,6 @@ async def test_runner_run_until_convergence():
|
||||
state, # state
|
||||
ctx, # runner_context
|
||||
)
|
||||
runner.reserve()
|
||||
async for event in runner.run_until_convergence():
|
||||
assert isinstance(event, WorkflowEvent)
|
||||
if event.type == "output":
|
||||
@@ -148,7 +146,6 @@ async def test_runner_run_until_convergence_not_completed():
|
||||
WorkflowConvergenceException,
|
||||
match="Runner did not converge after 5 iterations.",
|
||||
):
|
||||
runner.reserve()
|
||||
async for event in runner.run_until_convergence():
|
||||
assert event.type != "status" or event.state != WorkflowRunState.IDLE
|
||||
|
||||
@@ -307,16 +304,22 @@ async def test_fanout_edge_runner_delivers_to_multiple_targets_concurrently() ->
|
||||
assert probe_target.call_count == 1
|
||||
|
||||
|
||||
async def test_runner_run_until_convergence_requires_reservation():
|
||||
"""run_until_convergence refuses to start without a prior reserve()."""
|
||||
async def test_runner_run_until_convergence_runs_sequentially():
|
||||
"""run_until_convergence can be invoked back-to-back on the same Runner.
|
||||
|
||||
The Runner itself does not enforce concurrency; that responsibility lives on
|
||||
:class:`Workflow`. This test simply confirms the Runner is reusable across
|
||||
sequential runs.
|
||||
"""
|
||||
runner = _make_runner()
|
||||
with pytest.raises(WorkflowRunnerException, match="Runner must be reserved"):
|
||||
async for _ in runner.run_until_convergence():
|
||||
pass
|
||||
async for _ in runner.run_until_convergence():
|
||||
pass
|
||||
async for _ in runner.run_until_convergence():
|
||||
pass
|
||||
|
||||
|
||||
def _make_runner() -> Runner:
|
||||
"""Build a minimal runner for lifecycle tests."""
|
||||
"""Build a minimal runner for runner-level tests."""
|
||||
return Runner(
|
||||
[],
|
||||
{},
|
||||
@@ -327,57 +330,11 @@ def _make_runner() -> Runner:
|
||||
)
|
||||
|
||||
|
||||
def test_runner_reserve_twice_raises():
|
||||
"""Calling reserve() while already reserved rejects the second caller.
|
||||
|
||||
This is what guards Workflow.run against a concurrent caller slipping in
|
||||
between the first call's synchronous reserve() and its first await.
|
||||
"""
|
||||
runner = _make_runner()
|
||||
runner.reserve()
|
||||
with pytest.raises(WorkflowRunnerException, match="Runner is already reserved or running."):
|
||||
runner.reserve()
|
||||
|
||||
|
||||
def test_runner_reserve_after_release_is_accepted():
|
||||
"""Sequential runs are permitted; only concurrent ones are blocked."""
|
||||
runner = _make_runner()
|
||||
runner.reserve()
|
||||
runner.release()
|
||||
runner.reserve() # should not raise
|
||||
|
||||
|
||||
def test_runner_release_when_idle_is_noop():
|
||||
"""release() on an idle runner does not affect a subsequent reserve().
|
||||
|
||||
Workflow._run_core's finally always calls release(), even when
|
||||
run_until_convergence already cleared the lock in its own finally;
|
||||
that double-release must not lock out the next run.
|
||||
"""
|
||||
runner = _make_runner()
|
||||
runner.release() # already idle - must not raise or wedge state
|
||||
runner.reserve() # next run still allowed
|
||||
|
||||
|
||||
async def test_runner_run_until_convergence_consumes_reservation():
|
||||
"""run_until_convergence accepts a prior reservation and runs to completion."""
|
||||
runner = _make_runner()
|
||||
runner.reserve()
|
||||
async for _ in runner.run_until_convergence():
|
||||
pass
|
||||
# A second run after the first completes must be accepted.
|
||||
runner.reserve()
|
||||
async for _ in runner.run_until_convergence():
|
||||
pass
|
||||
|
||||
|
||||
async def test_runner_accepts_new_run_after_previous_failure():
|
||||
"""A failed run must not leave the runner locked out of future runs.
|
||||
"""A failed run must not leave the Runner unable to start a new run.
|
||||
|
||||
After the first run raises, a fresh ``reserve()`` and
|
||||
``run_until_convergence()`` must succeed (or fail for a different reason -
|
||||
e.g. residual messages still don't converge - but never with the
|
||||
lock-rejection ``"Runner is already running."``).
|
||||
After the first run raises, ``run_until_convergence()`` must be callable
|
||||
again and not surface any lifecycle-related rejection.
|
||||
"""
|
||||
executor_a = MockExecutor(id="executor_a")
|
||||
executor_b = MockExecutor(id="executor_b")
|
||||
@@ -392,13 +349,12 @@ async def test_runner_accepts_new_run_after_previous_failure():
|
||||
|
||||
await executor_a.execute(MockMessage(data=0), ["START"], state, ctx)
|
||||
|
||||
runner.reserve()
|
||||
with pytest.raises(WorkflowConvergenceException):
|
||||
async for _ in runner.run_until_convergence():
|
||||
pass
|
||||
|
||||
# The runner should accept a fresh reservation and run again.
|
||||
runner.reserve() # must not raise
|
||||
# A second run on the same Runner must not be blocked by stale lifecycle
|
||||
# state from the failed run.
|
||||
try:
|
||||
async for _ in runner.run_until_convergence():
|
||||
pass
|
||||
@@ -406,40 +362,6 @@ async def test_runner_accepts_new_run_after_previous_failure():
|
||||
assert "Runner is already running" not in str(exc), "Runner stayed locked after a failed run"
|
||||
|
||||
|
||||
async def test_runner_rejects_concurrent_run_until_convergence():
|
||||
"""While a run is in progress, a second ``reserve()`` is rejected.
|
||||
|
||||
Confirms the run lock is held for the full duration of the run, not just
|
||||
synchronously between ``reserve()`` and the first ``__anext__`` call.
|
||||
"""
|
||||
runner = _make_runner()
|
||||
|
||||
started = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def _slow_run():
|
||||
runner.reserve()
|
||||
async for _ in runner.run_until_convergence():
|
||||
if not started.is_set():
|
||||
started.set()
|
||||
await release.wait()
|
||||
|
||||
task = asyncio.create_task(_slow_run())
|
||||
await started.wait() # first run is now executing
|
||||
|
||||
try:
|
||||
with pytest.raises(WorkflowRunnerException, match="Runner is already reserved or running."):
|
||||
runner.reserve()
|
||||
finally:
|
||||
release.set()
|
||||
await task
|
||||
|
||||
# And after the first run finishes, a new reservation + run must be accepted.
|
||||
runner.reserve()
|
||||
async for _ in runner.run_until_convergence():
|
||||
pass
|
||||
|
||||
|
||||
async def test_runner_emits_runner_completion_for_agent_response_without_targets():
|
||||
ctx = InProcRunnerContext()
|
||||
runner = Runner([], {}, State(), ctx, "test_name", graph_signature_hash="test_hash")
|
||||
@@ -451,7 +373,6 @@ async def test_runner_emits_runner_completion_for_agent_response_without_targets
|
||||
)
|
||||
)
|
||||
|
||||
runner.reserve()
|
||||
events: list[WorkflowEvent] = [event async for event in runner.run_until_convergence()]
|
||||
# The runner should complete without errors when handling AgentExecutorResponse without targets
|
||||
# No specific events are expected since there are no executors to process the message
|
||||
@@ -508,7 +429,6 @@ async def test_runner_cancellation_stops_active_executor():
|
||||
async for _ in runner.run_until_convergence():
|
||||
pass
|
||||
|
||||
runner.reserve()
|
||||
task = asyncio.create_task(run_workflow())
|
||||
|
||||
# Wait for executor_a to complete (0.3s) and executor_b to start but not finish
|
||||
@@ -570,7 +490,6 @@ async def test_runner_iteration_exception_drains_events():
|
||||
)
|
||||
|
||||
events: list[WorkflowEvent] = []
|
||||
runner.reserve()
|
||||
with pytest.raises(RuntimeError, match="Simulated executor failure"):
|
||||
async for event in runner.run_until_convergence():
|
||||
events.append(event)
|
||||
@@ -681,7 +600,6 @@ async def test_runner_checkpoint_creation_failure():
|
||||
|
||||
# Should complete without raising, even though checkpointing fails
|
||||
result: int | None = None
|
||||
runner.reserve()
|
||||
async for event in runner.run_until_convergence():
|
||||
if event.type == "output":
|
||||
result = event.data
|
||||
@@ -878,7 +796,6 @@ async def test_runner_with_pre_loop_events():
|
||||
await ctx.add_event(WorkflowEvent("output", executor_id="test_executor", data="pre-loop-output"))
|
||||
|
||||
events: list[WorkflowEvent] = []
|
||||
runner.reserve()
|
||||
async for event in runner.run_until_convergence():
|
||||
events.append(event)
|
||||
|
||||
@@ -926,7 +843,6 @@ async def test_runner_drains_straggler_events():
|
||||
)
|
||||
|
||||
events: list[WorkflowEvent] = []
|
||||
runner.reserve()
|
||||
async for event in runner.run_until_convergence():
|
||||
events.append(event)
|
||||
|
||||
@@ -980,7 +896,6 @@ async def test_runner_checkpoint_with_resumed_flag():
|
||||
)
|
||||
|
||||
# Run until convergence
|
||||
runner.reserve()
|
||||
async for _ in runner.run_until_convergence():
|
||||
pass
|
||||
|
||||
@@ -1047,7 +962,6 @@ async def test_runner_drains_events_on_iteration_exception():
|
||||
)
|
||||
|
||||
events: list[WorkflowEvent] = []
|
||||
runner.reserve()
|
||||
with pytest.raises(RuntimeError, match="Executor failed with pending events"):
|
||||
async for event in runner.run_until_convergence():
|
||||
events.append(event)
|
||||
@@ -1104,7 +1018,6 @@ async def test_runner_drains_straggler_events_at_iteration_end():
|
||||
)
|
||||
|
||||
events: list[WorkflowEvent] = []
|
||||
runner.reserve()
|
||||
async for event in runner.run_until_convergence():
|
||||
events.append(event)
|
||||
|
||||
@@ -1271,7 +1184,6 @@ async def test_runner_can_run_again_after_reset_for_new_run():
|
||||
|
||||
# First run: drives MockExecutor's loop until it yields the terminal value.
|
||||
await executor_a.execute(MockMessage(data=0), ["START"], state, ctx)
|
||||
runner.reserve()
|
||||
async for _ in runner.run_until_convergence():
|
||||
pass
|
||||
assert runner._iteration == 10 # pyright: ignore[reportPrivateUsage]
|
||||
@@ -1280,7 +1192,6 @@ async def test_runner_can_run_again_after_reset_for_new_run():
|
||||
|
||||
# Second run: must succeed cleanly using the same runner instance.
|
||||
await executor_a.execute(MockMessage(data=0), ["START"], state, ctx)
|
||||
runner.reserve()
|
||||
second_run_outputs: list[int] = []
|
||||
async for event in runner.run_until_convergence():
|
||||
if event.type == "output":
|
||||
@@ -1290,77 +1201,4 @@ async def test_runner_can_run_again_after_reset_for_new_run():
|
||||
assert runner._iteration == 10 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_runner_reset_for_new_run_rejected_when_reserved():
|
||||
"""reset_for_new_run refuses to run when the runner is reserved but not yet running."""
|
||||
runner = _make_runner()
|
||||
runner.reserve()
|
||||
|
||||
with pytest.raises(WorkflowRunnerException, match="Cannot reset the runner while a run is in progress"):
|
||||
await runner.reset_for_new_run()
|
||||
|
||||
|
||||
async def test_runner_reset_for_new_run_rejected_while_running():
|
||||
"""reset_for_new_run refuses to run while a run is mid-execution."""
|
||||
runner = _make_runner()
|
||||
|
||||
started = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def _slow_run() -> None:
|
||||
runner.reserve()
|
||||
async for _ in runner.run_until_convergence():
|
||||
if not started.is_set():
|
||||
started.set()
|
||||
await release.wait()
|
||||
|
||||
task = asyncio.create_task(_slow_run())
|
||||
await started.wait() # first run is now executing inside run_until_convergence
|
||||
|
||||
try:
|
||||
with pytest.raises(WorkflowRunnerException, match="Cannot reset the runner while a run is in progress"):
|
||||
await runner.reset_for_new_run()
|
||||
finally:
|
||||
release.set()
|
||||
await task
|
||||
|
||||
# Once the run drained, reset must succeed again.
|
||||
await runner.reset_for_new_run()
|
||||
|
||||
|
||||
async def test_runner_reset_for_new_run_does_not_mutate_when_rejected():
|
||||
"""When reset is rejected, the runner's iteration counter and state are untouched."""
|
||||
|
||||
class TrackingExecutor(MockExecutor):
|
||||
def __init__(self, id: str) -> None:
|
||||
super().__init__(id=id)
|
||||
self.reset_calls = 0
|
||||
|
||||
async def reset(self) -> None:
|
||||
self.reset_calls += 1
|
||||
|
||||
executor = TrackingExecutor(id="executor")
|
||||
state = State()
|
||||
state.set("preserved", 42)
|
||||
state.commit()
|
||||
|
||||
runner = Runner(
|
||||
[],
|
||||
{executor.id: executor},
|
||||
state,
|
||||
InProcRunnerContext(),
|
||||
"test_name",
|
||||
graph_signature_hash="test_hash",
|
||||
)
|
||||
runner._iteration = 7 # pyright: ignore[reportPrivateUsage]
|
||||
runner.reserve()
|
||||
|
||||
with pytest.raises(WorkflowRunnerException):
|
||||
await runner.reset_for_new_run()
|
||||
|
||||
# Nothing was mutated by the failed reset.
|
||||
assert runner._iteration == 7 # pyright: ignore[reportPrivateUsage]
|
||||
assert state.get("preserved") == 42
|
||||
assert executor.reset_calls == 0
|
||||
|
||||
|
||||
# endregion: Tests for Runner.reset_for_new_run()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import tempfile
|
||||
from collections.abc import AsyncIterable, Awaitable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
@@ -846,6 +847,92 @@ async def test_workflow_concurrent_execution_prevention_mixed_methods():
|
||||
assert result.get_final_state() == WorkflowRunState.IDLE
|
||||
|
||||
|
||||
async def test_workflow_sequential_runs_after_completion() -> None:
|
||||
"""A completed run must release the runner so the next ``run`` succeeds.
|
||||
|
||||
This is the happy-path counterpart to the concurrent-run guard tests:
|
||||
those tests verify that a *concurrent* run is rejected, but they do not
|
||||
verify that the lock is actually released afterwards. This test
|
||||
exercises that release path explicitly across the three call shapes
|
||||
(non-streaming, streaming-iterated, streaming-via-get_final_response)
|
||||
and across multiple consecutive turns to catch lock leaks.
|
||||
"""
|
||||
executor = IncrementExecutor(id="seq_executor", limit=3, increment=1)
|
||||
workflow = WorkflowBuilder(start_executor=executor).build()
|
||||
|
||||
# Non-streaming -> non-streaming
|
||||
r1 = await workflow.run(NumberMessage(data=0))
|
||||
assert r1.get_final_state() == WorkflowRunState.IDLE
|
||||
|
||||
r2 = await workflow.run(NumberMessage(data=0))
|
||||
assert r2.get_final_state() == WorkflowRunState.IDLE
|
||||
|
||||
# Non-streaming -> streaming-iterated
|
||||
stream_events: list[WorkflowEvent] = []
|
||||
async for event in workflow.run(NumberMessage(data=0), stream=True):
|
||||
stream_events.append(event)
|
||||
assert any(e.type == "status" and e.state == WorkflowRunState.IDLE for e in stream_events)
|
||||
|
||||
# Streaming -> streaming via get_final_response (no manual iteration)
|
||||
r3 = await workflow.run(NumberMessage(data=0), stream=True).get_final_response()
|
||||
assert r3.get_final_state() == WorkflowRunState.IDLE
|
||||
|
||||
# Streaming -> non-streaming (back to the start)
|
||||
r4 = await workflow.run(NumberMessage(data=0))
|
||||
assert r4.get_final_state() == WorkflowRunState.IDLE
|
||||
|
||||
|
||||
async def test_workflow_unconsumed_stream_releases_run_lock() -> None:
|
||||
"""An unconsumed stream must not leak the run lock.
|
||||
|
||||
``Workflow.run`` stores a weak reference to the returned stream
|
||||
*synchronously* so that concurrent callers are rejected immediately. It
|
||||
is normally cleared by ``_run_core``'s ``finally`` once the stream is
|
||||
iterated. If the caller never iterates the stream, the weak reference
|
||||
must die when the stream is garbage collected - otherwise every
|
||||
subsequent ``Workflow.run`` call on this instance would fail with the
|
||||
concurrent-run error.
|
||||
"""
|
||||
executor = IncrementExecutor(id="unconsumed_stream_exec", limit=3, increment=1)
|
||||
workflow = WorkflowBuilder(start_executor=executor).build()
|
||||
|
||||
# Build a stream and immediately drop it without iterating.
|
||||
stream = workflow.run(NumberMessage(data=0), stream=True)
|
||||
assert stream is not None # silence unused-variable warnings; stream is GC'd below
|
||||
del stream
|
||||
gc.collect()
|
||||
# Yield to the event loop so any scheduled finalizer work can run.
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# The dropped stream's weak reference should now be dead; a fresh run must succeed.
|
||||
result = await workflow.run(NumberMessage(data=0))
|
||||
assert result.get_final_state() == WorkflowRunState.IDLE
|
||||
|
||||
|
||||
async def test_workflow_unawaited_run_coroutine_releases_run_lock() -> None:
|
||||
"""An un-awaited non-streaming ``run()`` coroutine must also not leak the lock.
|
||||
|
||||
``Workflow.run`` (non-streaming) returns a coroutine produced by
|
||||
``ResponseStream.get_final_response``. The underlying ResponseStream is
|
||||
held alive by that coroutine, so dropping the coroutine without
|
||||
awaiting it must still free the run lock through the same weak-reference
|
||||
expiry on garbage collection that covers unconsumed streams.
|
||||
"""
|
||||
executor = IncrementExecutor(id="unawaited_run_exec", limit=3, increment=1)
|
||||
workflow = WorkflowBuilder(start_executor=executor).build()
|
||||
|
||||
coro = workflow.run(NumberMessage(data=0))
|
||||
# Closing suppresses the "coroutine was never awaited" warning. We cast to
|
||||
# ``Any`` because the typed return is ``Awaitable[...]``; in practice it is
|
||||
# a coroutine that exposes ``close``.
|
||||
cast(Any, coro).close()
|
||||
del coro
|
||||
gc.collect()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
result = await workflow.run(NumberMessage(data=0))
|
||||
assert result.get_final_state() == WorkflowRunState.IDLE
|
||||
|
||||
|
||||
class _StreamingTestAgent(BaseAgent):
|
||||
"""Test agent that supports both streaming and non-streaming modes."""
|
||||
|
||||
@@ -1397,7 +1484,7 @@ async def test_workflow_reset_for_new_run_rejected_during_streaming_run() -> Non
|
||||
await asyncio.sleep(0.02)
|
||||
|
||||
try:
|
||||
with pytest.raises(WorkflowRunnerException, match="Cannot reset the runner while a run is in progress"):
|
||||
with pytest.raises(WorkflowRunnerException, match="Cannot reset the workflow while a run is in progress"):
|
||||
await workflow.reset_for_new_run()
|
||||
finally:
|
||||
await task
|
||||
|
||||
Reference in New Issue
Block a user