Remove reset

This commit is contained in:
Tao Chen
2026-06-11 11:14:36 -07:00
Unverified
parent 6534a739d0
commit 9da83347c8
19 changed files with 15 additions and 1311 deletions
@@ -166,10 +166,6 @@ class AgentExecutor(Executor):
raise ValueError("Agent must have a non-empty name or id or an explicit id must be provided.")
super().__init__(exec_id)
self._agent = agent
# Track whether the caller supplied a session so reset() can preserve their
# session reference (which may be wired to external/service-side storage)
# and only replace sessions we created ourselves.
self._session_supplied_by_caller = session is not None
self._session = session or self._agent.create_session()
self._pending_agent_requests: dict[str, Content] = {}
@@ -369,37 +365,10 @@ class AgentExecutor(Executor):
pending_responses_payload = state.get("pending_responses_to_agent")
self._pending_responses_to_agent = pending_responses_payload or []
@override
async def reset(self) -> None:
"""Reset the executor to its initial state for a new workflow run.
Clears the message cache, full conversation snapshot, and any pending
user-input request/response bookkeeping.
Session handling:
* If the session was created by this executor (no ``session`` argument
was passed to ``__init__``), it is replaced with a fresh one via
``agent.create_session()`` so prior conversation history does not
leak into the next run.
* If the session was supplied by the caller, it is left untouched.
The caller owns the session lifecycle (it may be backed by
service-side or external storage) and is responsible for clearing
or rotating it if a clean slate is desired.
"""
logger.debug("AgentExecutor %s: Resetting state", self.id)
def reset(self) -> None:
"""Reset the internal cache of the executor."""
logger.debug("AgentExecutor %s: Resetting cache", self.id)
self._cache.clear()
self._full_conversation.clear()
self._pending_agent_requests.clear()
self._pending_responses_to_agent.clear()
if not self._session_supplied_by_caller:
self._session = self._agent.create_session()
else:
logger.warning(
"AgentExecutor %s: Session was supplied by the caller and will not be reset. "
"Prior conversation history retained in the session may leak into the next run. "
"Reset or rotate the session externally if a clean slate is required.",
self.id,
)
async def _run_agent_and_emit(
self,
@@ -516,14 +516,6 @@ class Executor(RequestInfoMixin, DictConvertible):
"""
...
async def reset(self) -> None:
"""Reset the executor to its initial state.
Override this method in subclasses to implement custom logic that should
run when the workflow is reset.
"""
...
# endregion: Executor
@@ -88,23 +88,6 @@ class Runner:
"""
self._iteration = 0
async def reset_for_new_run(self) -> None:
"""Reset the runner for a new run.
This is useful when reusing the same workflow instance for a different run
that is independent from prior runs.
Concurrent-run rejection lives at the :class:`Workflow` level via the
active-run weak reference, so this method does not re-validate that
no run is in progress; callers are expected to have already done so.
"""
self.reset_iteration_count()
self._ctx.reset_for_new_run()
self._state.clear()
self._resumed_from_checkpoint = False
for executor in self._executors.values():
await executor.reset()
async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]:
"""Run the workflow until no more messages are sent."""
# Emit any events already produced prior to entering loop
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Literal, overload
from .._sessions import ContextProvider
from .._types import ResponseStream
from ..exceptions import WorkflowCheckpointException, WorkflowException, WorkflowRunnerException
from ..exceptions import WorkflowCheckpointException, WorkflowException
from ..observability import OtelAttr, capture_exception, create_workflow_span
from ._checkpoint import CheckpointID, CheckpointStorage
from ._const import DEFAULT_MAX_ITERATIONS, GLOBAL_KWARGS_KEY, WORKFLOW_RUN_KWARGS_KEY
@@ -568,13 +568,12 @@ class Workflow(DictConvertible):
yield in_progress # noqa: RUF070
# Per-run reset for fresh-message runs only. We deliberately
# do NOT clear shared workflow state (`_state.clear()`) or the
# runner context's in-flight messages (`reset_for_new_run()`)
# here - state and pending work persist across `run()` calls
# so that a `WorkflowAgent` can deliver multi-turn input on
# the same instance and have prior turns' context survive.
# Iteration counting and per-run kwargs ARE per-run though,
# so they're reset here.
# do NOT clear shared workflow state or the runner context's
# in-flight messages here - state and pending work persist
# across `run()` calls so that a `WorkflowAgent` can deliver
# multi-turn input on the same instance and have prior turns'
# context survive. Iteration counting and per-run kwargs ARE
# per-run though, so they're reset here.
if not is_continuation:
self._runner.reset_iteration_count()
@@ -1210,24 +1209,3 @@ class Workflow(DictConvertible):
"""
existing_stream = self._active_run() if self._active_run is not None else None
return existing_stream is not None
async def reset_for_new_run(self) -> None:
"""Reset the workflow for a new run that is independent from prior runs.
Note:
This will reset EVERYTHING - executor states, workflow state, and runner
context (including pending requests/messages).
Raises:
WorkflowRunnerException: If a run is currently in progress. Reset is only
allowed when the workflow is idle to avoid clobbering in-flight run state.
"""
existing_stream = self._active_run() if self._active_run is not None else None
if existing_stream is not None:
raise WorkflowRunnerException(
"Cannot reset the workflow while a run is in progress. "
"Wait for the current run to complete before calling reset_for_new_run()."
)
self._active_run = None
await self._runner.reset_for_new_run()
@@ -534,19 +534,6 @@ class WorkflowExecutor(Executor):
for event in request_info_events
])
@override
async def reset(self) -> None:
"""Reset the WorkflowExecutor to its initial state for a new workflow run.
Clears in-flight execution contexts and the request-to-execution mapping,
then recursively resets the wrapped sub-workflow so its executors, runner
context, and shared state are also returned to a clean state.
"""
logger.debug("WorkflowExecutor %s: Resetting state", self.id)
self._execution_contexts.clear()
self._request_to_execution.clear()
await self.workflow.reset_for_new_run()
async def _process_workflow_result(
self,
result: WorkflowRunResult,
@@ -1,6 +1,5 @@
# Copyright (c) Microsoft. All rights reserved.
import logging
from collections.abc import AsyncIterable, Awaitable
from typing import Any, Literal, overload
@@ -20,7 +19,7 @@ from agent_framework import (
WorkflowEvent,
WorkflowRunState,
)
from agent_framework._workflows._agent_executor import AgentExecutorRequest, AgentExecutorResponse
from agent_framework._workflows._agent_executor import AgentExecutorResponse
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage
from agent_framework._workflows._const import GLOBAL_KWARGS_KEY
@@ -307,108 +306,6 @@ async def test_agent_executor_save_and_restore_state_directly() -> None:
assert restored_session.session_id == session.session_id
# region: Tests for AgentExecutor.reset()
async def test_agent_executor_reset_clears_per_run_state() -> None:
"""reset() clears cache, conversation snapshot, and pending request/response buffers."""
agent = _CountingAgent(id="reset_agent", name="ResetAgent")
executor = AgentExecutor(agent, id="reset_exec")
# Populate every per-run buffer.
executor._cache = [Message(role="user", contents=["cached"])] # type: ignore[reportPrivateUsage]
executor._full_conversation = [ # type: ignore[reportPrivateUsage]
Message(role="user", contents=["prior turn"]),
Message(role="assistant", contents=["prior response"]),
]
pending_request = Content.from_text(text="approve?")
executor._pending_agent_requests = {"req-1": pending_request} # type: ignore[reportPrivateUsage]
executor._pending_responses_to_agent = [Content.from_text(text="approved")] # type: ignore[reportPrivateUsage]
await executor.reset()
assert executor._cache == [] # type: ignore[reportPrivateUsage]
assert executor._full_conversation == [] # type: ignore[reportPrivateUsage]
assert executor._pending_agent_requests == {} # type: ignore[reportPrivateUsage]
assert executor._pending_responses_to_agent == [] # type: ignore[reportPrivateUsage]
async def test_agent_executor_reset_creates_fresh_session_when_auto_created() -> None:
"""reset() replaces the agent session when the executor created it itself."""
agent = _CountingAgent(id="reset_session_agent", name="ResetSessionAgent")
# No session passed in — executor creates one via agent.create_session().
executor = AgentExecutor(agent, id="reset_session_exec")
auto_created = executor._session # type: ignore[reportPrivateUsage]
auto_created.state["history"] = {"messages": [Message(role="user", contents=["old"])]}
await executor.reset()
new_session = executor._session # type: ignore[reportPrivateUsage]
assert new_session is not auto_created
assert new_session.session_id != auto_created.session_id
assert "history" not in new_session.state
async def test_agent_executor_reset_preserves_caller_supplied_session(caplog: pytest.LogCaptureFixture) -> None:
"""reset() leaves a session passed in via __init__ untouched and warns the caller."""
agent = _CountingAgent(id="reset_session_agent", name="ResetSessionAgent")
caller_session = AgentSession()
history_payload = {"messages": [Message(role="user", contents=["old"])]}
caller_session.state["history"] = history_payload
executor = AgentExecutor(agent, id="reset_session_exec", session=caller_session)
assert executor._session is caller_session # type: ignore[reportPrivateUsage]
with caplog.at_level(logging.WARNING, logger="agent_framework._workflows._agent_executor"):
await executor.reset()
# Same instance, state untouched — the caller is responsible for managing the session.
assert executor._session is caller_session # type: ignore[reportPrivateUsage]
assert caller_session.state["history"] is history_payload
assert any("Session was supplied by the caller" in record.message for record in caplog.records)
async def test_agent_executor_reset_allows_subsequent_run() -> None:
"""After reset(), the executor can be reused for a fresh workflow run without leaking state."""
agent = _CountingAgent(id="reset_reuse_agent", name="ResetReuseAgent")
executor = AgentExecutor(agent, id="reset_reuse_exec")
workflow = WorkflowBuilder(start_executor=executor, output_from=[executor]).build()
first_outputs: list[WorkflowEvent] = []
async for event in workflow.run(
AgentExecutorRequest(messages=[Message(role="user", contents=["hello"])]),
stream=True,
):
if event.type == "output":
first_outputs.append(event)
assert first_outputs, "first run should have produced at least one output event"
# After a normal run the cache is drained but the conversation snapshot remains.
assert executor._cache == [] # type: ignore[reportPrivateUsage]
assert executor._full_conversation != [] # type: ignore[reportPrivateUsage]
first_session_id = executor._session.session_id # type: ignore[reportPrivateUsage]
await workflow.reset_for_new_run()
assert executor._full_conversation == [] # type: ignore[reportPrivateUsage]
# Session was auto-created, so reset() rotates it to a fresh one.
assert executor._session.session_id != first_session_id # type: ignore[reportPrivateUsage]
second_outputs: list[WorkflowEvent] = []
async for event in workflow.run(
AgentExecutorRequest(messages=[Message(role="user", contents=["second"])]),
stream=True,
):
if event.type == "output":
second_outputs.append(event)
assert second_outputs, "second run after reset should have produced at least one output event"
assert agent.call_count == 2
# endregion: Tests for AgentExecutor.reset()
async def test_prepare_agent_run_args_extracts_invocation_kwargs() -> None:
"""_prepare_agent_run_args extracts function_invocation_kwargs and client_kwargs."""
agent = _CountingAgent(id="test_agent", name="TestAgent")
@@ -1004,53 +1004,3 @@ def test_handler_typevar_error_takes_priority_over_context_error():
@handler
async def process(self, message: _T, ctx) -> None: # type: ignore[no-untyped-def]
pass
# region: Tests for Executor.reset()
async def test_executor_default_reset_is_noop():
"""The base Executor.reset() is a no-op and must complete without raising.
Subclasses that don't carry reset-relevant state should be able to rely on the
default implementation.
"""
class StatelessExecutor(Executor):
@handler
async def handle(self, message: str, ctx: WorkflowContext[str]) -> None:
await ctx.send_message(message)
executor_instance = StatelessExecutor(id="stateless")
# Must complete without raising and return None.
assert await executor_instance.reset() is None
async def test_executor_subclass_reset_is_invoked():
"""A subclass that overrides reset() can clear its own internal state."""
class CounterExecutor(Executor):
def __init__(self, id: str) -> None:
super().__init__(id=id)
self.counter = 0
self.reset_calls = 0
@handler
async def handle(self, message: int, ctx: WorkflowContext[int]) -> None:
self.counter += message
async def reset(self) -> None:
self.counter = 0
self.reset_calls += 1
executor_instance = CounterExecutor(id="counter")
executor_instance.counter = 42
await executor_instance.reset()
assert executor_instance.counter == 0
assert executor_instance.reset_calls == 1
# endregion: Tests for Executor.reset()
@@ -1111,204 +1111,3 @@ async def test_runner_drains_straggler_events_at_iteration_end():
output_events = [e for e in events if e.type == "output"]
# We should have output events from both executors
assert len(output_events) >= 2
# region: Tests for InProcRunnerContext.reset_for_new_run()
async def test_runner_context_reset_clears_in_flight_messages():
"""reset_for_new_run drops queued executor-to-executor messages."""
ctx = InProcRunnerContext()
await ctx.send_message(WorkflowMessage(data=MockMessage(data=1), source_id="src"))
assert await ctx.has_messages() is True
ctx.reset_for_new_run()
assert await ctx.has_messages() is False
assert await ctx.drain_messages() == {}
async def test_runner_context_reset_drains_pending_events():
"""reset_for_new_run discards any events buffered for streaming."""
ctx = InProcRunnerContext()
await ctx.add_event(WorkflowEvent.superstep_started(iteration=1))
assert await ctx.has_events() is True
ctx.reset_for_new_run()
assert await ctx.has_events() is False
assert await ctx.drain_events() == []
async def test_runner_context_reset_resets_streaming_flag():
"""reset_for_new_run resets streaming back to its non-streaming default."""
ctx = InProcRunnerContext()
ctx.set_streaming(True)
assert ctx.is_streaming() is True
ctx.reset_for_new_run()
assert ctx.is_streaming() is False
async def test_runner_context_reset_clears_pending_request_info_events():
"""reset_for_new_run clears any pending request_info events tracked for correlation."""
ctx = InProcRunnerContext()
request_info_event = WorkflowEvent.request_info(
request_id="request-123",
source_executor_id="source",
request_data=MockMessage(data=0),
response_type=bool,
)
await ctx.add_request_info_event(request_info_event)
assert "request-123" in await ctx.get_pending_request_info_events()
ctx.reset_for_new_run()
assert await ctx.get_pending_request_info_events() == {}
# endregion: Tests for InProcRunnerContext.reset_for_new_run()
# region: Tests for Runner.reset_for_new_run()
async def test_runner_reset_for_new_run_resets_iteration_count():
"""reset_for_new_run resets the iteration counter back to zero."""
runner = _make_runner()
runner._iteration = 7 # pyright: ignore[reportPrivateUsage]
await runner.reset_for_new_run()
assert runner._iteration == 0 # pyright: ignore[reportPrivateUsage]
async def test_runner_reset_for_new_run_clears_shared_state():
"""reset_for_new_run wipes both committed and pending entries from shared state."""
state = State()
state.set("committed_key", "committed_value")
state.commit()
state.set("pending_key", "pending_value") # uncommitted
runner = Runner(
[],
{},
state,
InProcRunnerContext(),
"test_name",
graph_signature_hash="test_hash",
)
await runner.reset_for_new_run()
assert state.get("committed_key") is None
assert state.get("pending_key") is None
assert state.has("committed_key") is False
assert state.has("pending_key") is False
async def test_runner_reset_for_new_run_clears_resumed_from_checkpoint_flag():
"""reset_for_new_run clears the flag set by restore_from_checkpoint."""
runner = _make_runner()
resumed_checkpoint = WorkflowCheckpoint(
checkpoint_id="resumed-cp",
workflow_name="test_name",
graph_signature_hash="test_hash",
iteration_count=5,
)
runner._mark_resumed(resumed_checkpoint) # pyright: ignore[reportPrivateUsage]
assert runner._resumed_from_checkpoint is True # pyright: ignore[reportPrivateUsage]
await runner.reset_for_new_run()
assert runner._resumed_from_checkpoint is False # pyright: ignore[reportPrivateUsage]
# And the iteration count restored from the checkpoint must be wiped, too.
assert runner._iteration == 0 # pyright: ignore[reportPrivateUsage]
async def test_runner_reset_for_new_run_invokes_executor_reset_for_each_executor():
"""reset_for_new_run calls reset() on every registered executor exactly once."""
class TrackingExecutor(MockExecutor):
def __init__(self, id: str) -> None:
super().__init__(id=id)
self.reset_calls = 0
async def reset(self) -> None:
self.reset_calls += 1
executor_a = TrackingExecutor(id="executor_a")
executor_b = TrackingExecutor(id="executor_b")
runner = Runner(
[],
{executor_a.id: executor_a, executor_b.id: executor_b},
State(),
InProcRunnerContext(),
"test_name",
graph_signature_hash="test_hash",
)
await runner.reset_for_new_run()
assert executor_a.reset_calls == 1
assert executor_b.reset_calls == 1
async def test_runner_reset_for_new_run_resets_runner_context():
"""reset_for_new_run forwards the reset to the underlying runner context."""
ctx = InProcRunnerContext()
await ctx.send_message(WorkflowMessage(data=MockMessage(data=0), source_id="src"))
await ctx.add_event(WorkflowEvent.superstep_started(iteration=1))
ctx.set_streaming(True)
runner = Runner([], {}, State(), ctx, "test_name", graph_signature_hash="test_hash")
await runner.reset_for_new_run()
assert await ctx.has_messages() is False
assert await ctx.has_events() is False
assert ctx.is_streaming() is False
async def test_runner_can_run_again_after_reset_for_new_run():
"""After reset_for_new_run the runner can be reserved and converge a fresh workload."""
executor_a = MockExecutor(id="executor_a")
executor_b = MockExecutor(id="executor_b")
edges = [
SingleEdgeGroup(executor_a.id, executor_b.id),
SingleEdgeGroup(executor_b.id, executor_a.id),
]
executors: dict[str, Executor] = {
executor_a.id: executor_a,
executor_b.id: executor_b,
}
state = State()
ctx = InProcRunnerContext()
runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash")
# First run: drives MockExecutor's loop until it yields the terminal value.
await executor_a.execute(MockMessage(data=0), ["START"], state, ctx)
async for _ in runner.run_until_convergence():
pass
assert runner._iteration == 10 # pyright: ignore[reportPrivateUsage]
await runner.reset_for_new_run()
# Second run: must succeed cleanly using the same runner instance.
await executor_a.execute(MockMessage(data=0), ["START"], state, ctx)
second_run_outputs: list[int] = []
async for event in runner.run_until_convergence():
if event.type == "output":
second_run_outputs.append(event.data)
assert second_run_outputs == [10]
assert runner._iteration == 10 # pyright: ignore[reportPrivateUsage]
# endregion: Tests for Runner.reset_for_new_run()
@@ -689,89 +689,3 @@ async def test_sub_workflow_intermediate_outputs_propagate_to_parent() -> None:
# The parent's own terminal output is unaffected.
assert any(e.executor_id == "parent_sink" and e.data == "final: hello" for e in output_events)
# region: Tests for WorkflowExecutor.reset()
async def test_workflow_executor_reset_clears_execution_state() -> None:
"""reset() clears the WorkflowExecutor's per-run execution contexts and request mappings."""
validation_workflow = create_email_validation_workflow()
parent = Coordinator()
workflow_executor = WorkflowExecutor(validation_workflow, "email_validation_workflow")
main_workflow = (
WorkflowBuilder(start_executor=parent)
.add_edge(parent, workflow_executor)
.add_edge(workflow_executor, parent)
.build()
)
# First run pauses with a pending request from the sub-workflow.
result = await main_workflow.run("test@example.com")
assert len(result.get_request_info_events()) == 1
assert len(workflow_executor._execution_contexts) == 1 # type: ignore[reportPrivateUsage]
assert len(workflow_executor._request_to_execution) == 1 # type: ignore[reportPrivateUsage]
await main_workflow.reset_for_new_run()
assert workflow_executor._execution_contexts == {} # type: ignore[reportPrivateUsage]
assert workflow_executor._request_to_execution == {} # type: ignore[reportPrivateUsage]
async def test_workflow_executor_reset_resets_wrapped_workflow() -> None:
"""reset() recursively resets the wrapped workflow (runner iteration counter cleared)."""
validation_workflow = create_email_validation_workflow()
parent = Coordinator()
workflow_executor = WorkflowExecutor(validation_workflow, "email_validation_workflow")
main_workflow = (
WorkflowBuilder(start_executor=parent)
.add_edge(parent, workflow_executor)
.add_edge(workflow_executor, parent)
.build()
)
await main_workflow.run("test@example.com")
# The sub-workflow's runner advanced past iteration 0 during execution.
assert validation_workflow._runner._iteration > 0 # type: ignore[reportPrivateUsage]
await main_workflow.reset_for_new_run()
# The wrapped workflow's runner was reset along with the parent.
assert validation_workflow._runner._iteration == 0 # type: ignore[reportPrivateUsage]
async def test_workflow_executor_reset_allows_subsequent_run() -> None:
"""After reset(), the parent + WorkflowExecutor can be reused for a fresh run with no leakage."""
validation_workflow = create_email_validation_workflow()
parent = Coordinator()
workflow_executor = WorkflowExecutor(validation_workflow, "email_validation_workflow")
main_workflow = (
WorkflowBuilder(start_executor=parent)
.add_edge(parent, workflow_executor)
.add_edge(workflow_executor, parent)
.build()
)
first_result = await main_workflow.run("first@example.com")
assert len(first_result.get_request_info_events()) == 1
await main_workflow.reset_for_new_run()
# State on the WorkflowExecutor and parent's pending-request bookkeeping is clean.
assert workflow_executor._execution_contexts == {} # type: ignore[reportPrivateUsage]
assert workflow_executor._request_to_execution == {} # type: ignore[reportPrivateUsage]
second_result = await main_workflow.run("second@example.com")
second_requests = second_result.get_request_info_events()
assert len(second_requests) == 1
assert isinstance(second_requests[0].data, DomainCheckRequest)
# Confirm the new run produced a request from the second email, not the cached first one.
assert second_requests[0].data.email == "second@example.com"
# And the WorkflowExecutor is now tracking exactly one fresh execution.
assert len(workflow_executor._execution_contexts) == 1 # type: ignore[reportPrivateUsage]
# endregion: Tests for WorkflowExecutor.reset()
@@ -29,7 +29,6 @@ from agent_framework import (
WorkflowEvent,
WorkflowException,
WorkflowMessage,
WorkflowRunnerException,
WorkflowRunState,
handler,
response_handler,
@@ -1354,144 +1353,3 @@ async def test_output_executors_filtering_with_run_responses_streaming() -> None
# endregion
# region: Tests for Workflow.reset_for_new_run()
async def test_workflow_reset_for_new_run_allows_subsequent_run() -> None:
"""After reset_for_new_run() the same workflow instance can be run again from scratch."""
executor_a = IncrementExecutor(id="executor_a")
executor_b = IncrementExecutor(id="executor_b")
workflow = (
WorkflowBuilder(start_executor=executor_a, output_from=[executor_a, executor_b])
.add_edge(executor_a, executor_b)
.add_edge(executor_b, executor_a)
.build()
)
first = await workflow.run(NumberMessage(data=0))
assert first.get_outputs() == [10]
await workflow.reset_for_new_run()
second = await workflow.run(NumberMessage(data=0))
assert second.get_outputs() == [10]
async def test_workflow_reset_for_new_run_clears_workflow_state() -> None:
"""reset_for_new_run() clears values that executors persisted in shared workflow state."""
class StateWritingExecutor(Executor):
@handler
async def handle(self, message: NumberMessage, ctx: WorkflowContext[Any, int]) -> None:
previous = ctx.get_state("seen") or 0
ctx.set_state("seen", previous + 1)
await ctx.yield_output(previous + 1)
state_writer = StateWritingExecutor(id="state_writer")
workflow = WorkflowBuilder(start_executor=state_writer, output_from=[state_writer]).build()
first = await workflow.run(NumberMessage(data=1))
assert first.get_outputs() == [1]
# State was persisted by the executor.
assert workflow._runner.state.get("seen") == 1 # pyright: ignore[reportPrivateUsage]
await workflow.reset_for_new_run()
# The runner's shared state has been wiped.
assert workflow._runner.state.get("seen") is None # pyright: ignore[reportPrivateUsage]
second = await workflow.run(NumberMessage(data=1))
# Counter started fresh from 0 again; output is 1, not 2.
assert second.get_outputs() == [1]
async def test_workflow_reset_for_new_run_invokes_executor_reset_hook() -> None:
"""reset_for_new_run() calls Executor.reset() on every executor in the workflow."""
class ResettableExecutor(Executor):
def __init__(self, id: str) -> None:
super().__init__(id=id)
self.reset_calls = 0
self.handled = 0
@handler
async def handle(self, message: NumberMessage, ctx: WorkflowContext[Any, int]) -> None:
self.handled += 1
await ctx.yield_output(self.handled)
async def reset(self) -> None:
self.reset_calls += 1
self.handled = 0
executor = ResettableExecutor(id="resettable")
workflow = WorkflowBuilder(start_executor=executor, output_from=[executor]).build()
await workflow.run(NumberMessage(data=1))
assert executor.handled == 1
assert executor.reset_calls == 0
await workflow.reset_for_new_run()
assert executor.reset_calls == 1
# The executor's own counter was wiped by its overridden reset().
assert executor.handled == 0
async def test_workflow_reset_for_new_run_resets_runner_iteration_counter() -> None:
"""reset_for_new_run() drops the iteration counter accumulated during a prior run."""
executor_a = IncrementExecutor(id="executor_a")
executor_b = IncrementExecutor(id="executor_b")
workflow = (
WorkflowBuilder(start_executor=executor_a, output_from=[executor_a, executor_b])
.add_edge(executor_a, executor_b)
.add_edge(executor_b, executor_a)
.build()
)
await workflow.run(NumberMessage(data=0))
assert workflow._runner._iteration > 0 # pyright: ignore[reportPrivateUsage]
await workflow.reset_for_new_run()
assert workflow._runner._iteration == 0 # pyright: ignore[reportPrivateUsage]
async def test_workflow_reset_for_new_run_rejected_during_streaming_run() -> None:
"""reset_for_new_run() raises WorkflowRunnerException while a streaming run is in progress."""
executor_a = IncrementExecutor(id="executor_a")
executor_b = IncrementExecutor(id="executor_b")
workflow = (
WorkflowBuilder(start_executor=executor_a, output_from=[executor_a, executor_b])
.add_edge(executor_a, executor_b)
.add_edge(executor_b, executor_a)
.build()
)
async def consume_stream_slowly() -> list[WorkflowEvent]:
events: list[WorkflowEvent] = []
async for event in workflow.run(NumberMessage(data=0), stream=True):
events.append(event)
await asyncio.sleep(0.01)
return events
task = asyncio.create_task(consume_stream_slowly())
# Let the streaming run start.
await asyncio.sleep(0.02)
try:
with pytest.raises(WorkflowRunnerException, match="Cannot reset the workflow while a run is in progress"):
await workflow.reset_for_new_run()
finally:
await task
# After the run completes, reset succeeds again.
await workflow.reset_for_new_run()
assert workflow._runner._iteration == 0 # pyright: ignore[reportPrivateUsage]
# endregion
@@ -3032,6 +3032,7 @@ class TestCheckpointContextPathValidation:
agent.workflow = MagicMock()
agent.workflow.name = "wf"
agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False)
agent.workflow.create_checkpoint = AsyncMock(return_value="cp_initial")
agent.run = AsyncMock(
side_effect=[
AgentResponse(messages=[]),
@@ -3155,6 +3156,8 @@ class TestCheckpointContextPathValidation:
agent.workflow = MagicMock()
agent.workflow.name = "wf"
agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False)
agent.workflow.create_checkpoint = AsyncMock(return_value="cp_initial")
agent.run = AsyncMock(return_value=AgentResponse(messages=[]))
# Constructor inspects WorkflowAgent.workflow internals; bypass setup
# by feeding a configured mock through a normal init.
@@ -3971,80 +3974,5 @@ class TestWorkflowAgentHosting:
assert len(approval_responses) == 1
assert approval_responses[0].approved is False # type: ignore[attr-defined]
async def test_workflow_is_reset_when_no_prior_conversation(self) -> None:
"""Without ``previous_response_id`` or ``conversation_id`` the host must
reset the workflow so any in-memory state from a previous request is
cleared before the new turn runs."""
workflow_agent = _build_text_workflow_agent("fresh")
server = _make_server(workflow_agent)
with patch.object(
workflow_agent.workflow,
"reset_for_new_run",
new=AsyncMock(wraps=workflow_agent.workflow.reset_for_new_run),
) as reset_spy:
resp = await _post(server, input_text="hi", stream=False)
assert resp.status_code == 200
assert reset_spy.await_count == 1
async def test_workflow_is_reset_across_independent_requests(self) -> None:
"""Two consecutive requests without a chaining context id must each
reset the workflow."""
workflow_agent = _build_text_workflow_agent("again")
server = _make_server(workflow_agent)
with patch.object(
workflow_agent.workflow,
"reset_for_new_run",
new=AsyncMock(wraps=workflow_agent.workflow.reset_for_new_run),
) as reset_spy:
first = await _post(server, input_text="hi", stream=False)
second = await _post(server, input_text="hi again", stream=False)
assert first.status_code == 200
assert second.status_code == 200
assert reset_spy.await_count == 2
async def test_workflow_is_not_reset_when_resuming_from_checkpoint(self) -> None:
"""When ``previous_response_id`` resolves to an existing workflow
checkpoint the host restores the checkpoint instead of resetting
the workflow."""
workflow_agent, _ = _build_approval_workflow_agent(
approval_request_id="apr_no_reset",
final_text="resumed",
)
server = _make_server(workflow_agent)
first = await _post(server, stream=False)
assert first.status_code == 200
first_body = first.json()
first_response_id = first_body["id"]
approval_request_id = next(it["id"] for it in first_body["output"] if it["type"] == "mcp_approval_request")
with patch.object(
workflow_agent.workflow,
"reset_for_new_run",
new=AsyncMock(wraps=workflow_agent.workflow.reset_for_new_run),
) as reset_spy:
second = await _post_json(
server,
{
"model": "test-model",
"input": [
{
"type": "mcp_approval_response",
"approval_request_id": approval_request_id,
"approve": True,
}
],
"stream": False,
"previous_response_id": first_response_id,
},
)
assert second.status_code == 200
assert reset_spy.await_count == 0
# endregion
@@ -198,12 +198,7 @@ class TestBeforeRun:
"""OSS client with all scoping parameters passes them as isolated concurrent kwargs."""
mock_oss_mem0_client.search.return_value = []
provider = Mem0ContextProvider(
source_id="mem0",
mem0_client=mock_oss_mem0_client,
user_id="u1",
agent_id="a1"
)
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_oss_mem0_client, user_id="u1", agent_id="a1")
mock_context = MagicMock(spec=SessionContext)
mock_msg = MagicMock()
@@ -598,25 +598,3 @@ class BaseGroupChatOrchestrator(Executor, ABC):
metadata: Pattern-specific state dict
"""
pass
@override
async def reset(self) -> None:
"""Reset the orchestrator to its initial state for a new workflow run.
Clears the shared conversation history and round counter, then delegates
to ``_reset_pattern_state()`` so subclasses can clean up any
pattern-specific per-run state (caches, sessions, ledgers, etc.).
"""
logger.debug("%s %s: Resetting state", self.__class__.__name__, self.id)
self._full_conversation.clear()
self._round_index = 0
self._reset_pattern_state()
def _reset_pattern_state(self) -> None:
"""Reset pattern-specific state.
Override this method in subclasses to clear pattern-specific per-run state
when ``reset()`` is invoked. Called after the base class clears the shared
conversation and round counter.
"""
pass
@@ -327,7 +327,6 @@ class AgentBasedGroupChatOrchestrator(BaseGroupChatOrchestrator):
)
self._agent = agent
self._retry_attempts = retry_attempts
self._session_supplied_by_caller = session is not None
self._session = session or agent.create_session()
# Cache for messages since last agent invocation
# This is different from the full conversation history maintained by the base orchestrator
@@ -338,25 +337,6 @@ class AgentBasedGroupChatOrchestrator(BaseGroupChatOrchestrator):
self._cache.extend(messages)
return super()._append_messages(messages)
@override
def _reset_pattern_state(self) -> None:
"""Reset pattern-specific state for a new workflow run.
Clears the per-run message cache and rotates the orchestrator agent's
session unless the caller supplied a session explicitly (in which case
the caller is responsible for the session's lifecycle).
"""
self._cache.clear()
if self._session_supplied_by_caller:
logger.warning(
"%s %s: Session was supplied by the caller and will not be reset. "
"If you want a fresh session for the next run, reset or replace it before invoking the workflow.",
self.__class__.__name__,
self.id,
)
else:
self._session = self._agent.create_session()
@override
async def _handle_messages(
self,
@@ -509,14 +509,6 @@ class MagenticManagerBase(ABC):
"""Restore runtime state from checkpoint data."""
return
def on_reset(self) -> None:
"""Reset per-run runtime state for a new workflow run.
Subclasses should clear any state accumulated during a run (e.g. cached
ledgers, agent sessions) so the manager is ready to start fresh.
"""
return
class StandardMagenticManager(MagenticManagerBase):
"""Standard Magentic manager that performs real LLM calls via a Agent.
@@ -769,12 +761,6 @@ class StandardMagenticManager(MagenticManagerBase):
except Exception: # pragma: no cover - defensive
logger.warning("Failed to restore manager agent session from checkpoint state")
@override
def on_reset(self) -> None:
"""Clear cached ledger and start a fresh agent session for a new run."""
self.task_ledger = None
self._session = self._agent.create_session()
# endregion Magentic Manager
@@ -1277,15 +1263,6 @@ class MagenticOrchestrator(BaseGroupChatOrchestrator):
# a target will broadcast to all.
await ctx.send_message(MagenticResetSignal())
@override
def _reset_pattern_state(self) -> None:
"""Reset Magentic-specific per-run state for a new workflow run."""
self._magentic_context = None
self._task_ledger = None
self._progress_ledger = None
self._terminated = False
self._manager.on_reset()
@override
async def on_checkpoint_save(self) -> dict[str, Any]:
"""Capture current orchestrator state for checkpointing."""
@@ -1097,139 +1097,3 @@ def test_group_chat_orchestrator_factory_invalid_return_type():
# endregion
# region Reset
async def test_base_orchestrator_reset_clears_conversation_and_round_index() -> None:
"""reset() clears the conversation history and the round counter."""
from agent_framework.orchestrations import GroupChatOrchestrator
from agent_framework_orchestrations._base_group_chat_orchestrator import ParticipantRegistry
selector = make_sequence_selector()
orchestrator = GroupChatOrchestrator(
id="orch",
participant_registry=ParticipantRegistry([]),
selection_func=selector,
max_rounds=2,
)
orchestrator._full_conversation = [Message(role="user", contents=["hi"], author_name="user")]
orchestrator._round_index = 4
await orchestrator.reset()
assert orchestrator._full_conversation == []
assert orchestrator._round_index == 0
async def test_base_orchestrator_reset_invokes_pattern_state_hook() -> None:
"""reset() calls _reset_pattern_state() so subclasses can clean up their own state."""
from agent_framework.orchestrations import GroupChatOrchestrator
from agent_framework_orchestrations._base_group_chat_orchestrator import ParticipantRegistry
selector = make_sequence_selector()
class TrackingOrchestrator(GroupChatOrchestrator):
reset_calls: int = 0
def _reset_pattern_state(self) -> None:
type(self).reset_calls += 1
orchestrator = TrackingOrchestrator(
id="orch",
participant_registry=ParticipantRegistry([]),
selection_func=selector,
max_rounds=2,
)
await orchestrator.reset()
await orchestrator.reset()
assert TrackingOrchestrator.reset_calls == 2
async def test_agent_based_orchestrator_reset_clears_cache_and_rotates_session() -> None:
"""When the session was not supplied by the caller, reset() rotates the session and clears the cache."""
from agent_framework.orchestrations import AgentBasedGroupChatOrchestrator
from agent_framework_orchestrations._base_group_chat_orchestrator import ParticipantRegistry
agent = cast(Agent, StubManagerAgent())
orchestrator = AgentBasedGroupChatOrchestrator(
agent=agent,
participant_registry=ParticipantRegistry([]),
max_rounds=2,
)
original_session = orchestrator._session
orchestrator._cache = [Message(role="assistant", contents=["x"], author_name="agent")]
orchestrator._full_conversation = [Message(role="user", contents=["x"], author_name="user")]
orchestrator._round_index = 3
await orchestrator.reset()
assert orchestrator._cache == []
assert orchestrator._full_conversation == []
assert orchestrator._round_index == 0
assert orchestrator._session is not original_session
async def test_agent_based_orchestrator_reset_warns_when_session_supplied(caplog: pytest.LogCaptureFixture) -> None:
"""When the caller supplied a session, reset() preserves it and logs a warning."""
import logging
from agent_framework.orchestrations import AgentBasedGroupChatOrchestrator
from agent_framework_orchestrations._base_group_chat_orchestrator import ParticipantRegistry
agent = cast(Agent, StubManagerAgent())
supplied_session = agent.create_session()
orchestrator = AgentBasedGroupChatOrchestrator(
agent=agent,
participant_registry=ParticipantRegistry([]),
session=supplied_session,
max_rounds=2,
)
orchestrator._cache = [Message(role="assistant", contents=["x"], author_name="agent")]
with caplog.at_level(logging.WARNING, logger="agent_framework_orchestrations._group_chat"):
await orchestrator.reset()
assert orchestrator._cache == []
# The caller-owned session must be preserved.
assert orchestrator._session is supplied_session
warnings = [
r for r in caplog.records if r.levelno == logging.WARNING and "Session was supplied by the caller" in r.message
]
assert warnings, f"expected a warning about caller-supplied session, got: {[r.message for r in caplog.records]}"
async def test_workflow_reset_resets_group_chat_orchestrator() -> None:
"""End-to-end: workflow.reset_for_new_run() resets the orchestrator's conversation state."""
selector = make_sequence_selector()
alpha = StubAgent("alpha", "ack from alpha")
beta = StubAgent("beta", "ack from beta")
workflow = GroupChatBuilder(
participants=[alpha, beta],
max_rounds=2,
selection_func=selector,
orchestrator_name="manager",
).build()
async for _ in workflow.run("first task", stream=True):
pass
orchestrator = cast(BaseGroupChatOrchestrator, workflow.executors[GroupChatBuilder.DEFAULT_ORCHESTRATOR_ID])
assert orchestrator._full_conversation, "orchestrator should have accumulated conversation after first run"
assert orchestrator._round_index > 0
await workflow.reset_for_new_run()
assert orchestrator._full_conversation == []
assert orchestrator._round_index == 0
# endregion
@@ -1243,135 +1243,3 @@ def test_standard_manager_checkpoint_restore_empty_state():
# endregion
# region Manager Reset Tests
def test_magentic_manager_base_on_reset_is_noop_by_default():
"""MagenticManagerBase.on_reset() is a no-op so subclasses can opt in."""
mgr = FakeManager()
# Seed some state on the fake to confirm the base hook does not touch it.
mgr.task_ledger = _SimpleLedger(
facts=Message("assistant", ["facts"]),
plan=Message("assistant", ["plan"]),
)
mgr.on_reset() # base implementation is a no-op
assert mgr.task_ledger is not None
def test_standard_manager_on_reset_clears_ledger_and_rotates_session():
"""StandardMagenticManager.on_reset() clears the cached ledger and creates a fresh session."""
from agent_framework_orchestrations._magentic import _MagenticTaskLedger # type: ignore[reportPrivateUsage]
agent = StubManagerAgent()
mgr = StandardMagenticManager(agent=agent)
mgr.task_ledger = _MagenticTaskLedger(
facts=Message("assistant", ["facts"]),
plan=Message("assistant", ["plan"]),
)
original_session = mgr._session
mgr.on_reset()
assert mgr.task_ledger is None
assert mgr._session is not original_session
assert mgr._session.session_id != original_session.session_id
async def test_magentic_orchestrator_reset_invokes_manager_on_reset() -> None:
"""_reset_pattern_state() clears orchestrator state and delegates to manager.on_reset()."""
from agent_framework_orchestrations._base_group_chat_orchestrator import (
ParticipantRegistry, # type: ignore[reportPrivateUsage]
)
class TrackingManager(FakeManager):
reset_calls: int = 0
@override
def on_reset(self) -> None:
type(self).reset_calls += 1
manager = TrackingManager()
orchestrator = MagenticOrchestrator(
manager=manager,
participant_registry=ParticipantRegistry([]),
)
# Seed magentic-specific state to confirm it is cleared.
orchestrator._magentic_context = MagenticContext( # type: ignore[reportPrivateUsage]
task="task",
participant_descriptions={"agentA": "desc"},
)
orchestrator._task_ledger = Message("assistant", ["ledger"]) # type: ignore[reportPrivateUsage]
orchestrator._progress_ledger = MagenticProgressLedger( # type: ignore[reportPrivateUsage]
is_request_satisfied=MagenticProgressLedgerItem(reason="r", answer=False),
is_in_loop=MagenticProgressLedgerItem(reason="r", answer=False),
is_progress_being_made=MagenticProgressLedgerItem(reason="r", answer=True),
next_speaker=MagenticProgressLedgerItem(reason="r", answer="agentA"),
instruction_or_question=MagenticProgressLedgerItem(reason="r", answer="do"),
)
orchestrator._terminated = True # type: ignore[reportPrivateUsage]
await orchestrator.reset()
assert orchestrator._magentic_context is None # type: ignore[reportPrivateUsage]
assert orchestrator._task_ledger is None # type: ignore[reportPrivateUsage]
assert orchestrator._progress_ledger is None # type: ignore[reportPrivateUsage]
assert orchestrator._terminated is False # type: ignore[reportPrivateUsage]
assert TrackingManager.reset_calls == 1
async def test_magentic_orchestrator_reset_propagates_manager_on_reset_failure() -> None:
"""Failures in manager.on_reset() must propagate so callers can react."""
from agent_framework_orchestrations._base_group_chat_orchestrator import (
ParticipantRegistry, # type: ignore[reportPrivateUsage]
)
class FailingManager(FakeManager):
@override
def on_reset(self) -> None:
raise RuntimeError("boom")
manager = FailingManager()
orchestrator = MagenticOrchestrator(
manager=manager,
participant_registry=ParticipantRegistry([]),
)
with pytest.raises(RuntimeError, match="boom"):
await orchestrator.reset()
async def test_workflow_reset_resets_magentic_orchestrator_and_manager() -> None:
"""End-to-end: workflow.reset_for_new_run() resets Magentic orchestrator and manager state."""
manager = FakeManager()
manager.task_ledger = _SimpleLedger(
facts=Message("assistant", ["seeded facts"]),
plan=Message("assistant", ["seeded plan"]),
)
workflow = MagenticBuilder(
participants=[StubAgent(manager.next_speaker_name, "first draft")],
manager=manager,
).build()
async for _ in workflow.run("first task", stream=True):
pass
orchestrator = next(e for e in workflow.executors.values() if isinstance(e, MagenticOrchestrator))
assert orchestrator._terminated is True # type: ignore[reportPrivateUsage]
assert orchestrator._task_ledger is not None # type: ignore[reportPrivateUsage]
assert manager.task_ledger is not None
await workflow.reset_for_new_run()
assert orchestrator._magentic_context is None # type: ignore[reportPrivateUsage]
assert orchestrator._task_ledger is None # type: ignore[reportPrivateUsage]
assert orchestrator._progress_ledger is None # type: ignore[reportPrivateUsage]
assert orchestrator._terminated is False # type: ignore[reportPrivateUsage]
# FakeManager.on_reset is the base no-op, but the orchestrator still tolerates that case.
# For the standard manager we exercise full clearing in the dedicated unit test above.
# endregion
-1
View File
@@ -169,7 +169,6 @@ callers can still inspect progress or supporting work from the response messages
| Sample | File | Concepts |
| -------------------------------- | ------------------------------------------------------------------------------------------------ | ----------------------------------------------------------------- |
| State with Agents | [state-management/state_with_agents.py](./state-management/state_with_agents.py) | Store in state once and later reuse across agents |
| Reset Workflow Between Runs | [state-management/workflow_reset.py](./state-management/workflow_reset.py) | Reuse one workflow instance across independent runs via `reset_for_new_run()` |
| Workflow Kwargs - Global Context | [state-management/workflow_kwargs_global.py](./state-management/workflow_kwargs_global.py) | Pass custom context (data, user tokens) via kwargs to `@tool` tools in all agents |
| Workflow Kwargs - Per Agent | [state-management/workflow_kwargs_per_agent.py](./state-management/workflow_kwargs_per_agent.py) | Pass custom context (data, user tokens) via kwargs to `@tool` tools in individual agents |
@@ -1,212 +0,0 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from dataclasses import dataclass
from agent_framework import Executor, Workflow, WorkflowBuilder, WorkflowContext, handler
from typing_extensions import Never, override
"""
Sample: Reusing a Workflow across independent jobs with reset_for_new_run().
Build a small moderation pipeline that silently accumulates stats as messages
flow through it and emits a summary only when the caller asks for one. Drive
the same workflow instance across multiple independent jobs and show that
``Workflow.reset_for_new_run()`` clears all per-executor state without
rebuilding the graph.
Two custom executors share the work, each with its own per-job state and its
own ``reset()`` override:
- ``FlaggedKeywordCounter`` is the start executor. It accepts message strings
to inspect (silently updating local stats, sending nothing downstream) and
``ReportRequest`` markers that cause it to forward a ``StatsSnapshot``.
- ``StatsReporter`` formats the snapshot, increments its own emitted-reports
counter, and yields the summary as the workflow output.
A run with a string produces no output, just a state update. A run with a
``ReportRequest`` produces exactly one summary. Job boundaries are entirely
controlled by the caller via ``reset_for_new_run()``, which calls ``reset()``
on every executor in the graph.
Purpose:
Show how to:
- Hold per-job aggregate state on a custom Executor subclass.
- Override ``Executor.reset()`` on every executor that owns per-run state, so
it is cleared automatically when the workflow is reset.
- Call ``Workflow.reset_for_new_run()`` between independent jobs so a single
workflow instance can serve a stream of unrelated batches without leaking
state.
Prerequisites:
- No external services or credentials required; this sample runs entirely in-process.
- Familiarity with WorkflowBuilder and Executor subclasses.
"""
@dataclass
class ReportRequest:
"""Marker input that asks the workflow to emit a summary of stats so far."""
@dataclass
class StatsSnapshot:
"""Snapshot the counter forwards to the reporter when a report is requested."""
messages_seen: int
flagged_messages: int
flagged_keywords: list[str]
class FlaggedKeywordCounter(Executor):
"""Executor that silently accumulates per-job stats; emits on demand.
Holds three instance attributes that build up across the runs that make up a
single job:
- ``_messages_seen``: how many messages have been inspected.
- ``_flagged_messages``: how many of those messages contained any flagged keyword.
- ``_flagged_keywords``: the set of distinct keywords actually observed.
Two handlers dispatch by input type:
- ``inspect`` accepts a string, updates the counters, and sends nothing.
- ``emit_report`` accepts a ``ReportRequest`` and forwards a current
``StatsSnapshot`` to the downstream reporter.
Without overriding ``reset()`` this state would leak into the next job when the
workflow is reused via ``Workflow.reset_for_new_run()``. The override below
clears these attributes so each fresh job starts empty.
"""
FLAGGED_KEYWORDS = frozenset({"spam", "scam", "phishing"})
def __init__(self, id: str) -> None:
super().__init__(id=id)
self._messages_seen: int = 0
self._flagged_messages: int = 0
self._flagged_keywords: set[str] = set()
@handler
async def inspect(self, message: str, ctx: WorkflowContext[StatsSnapshot]) -> None:
"""Inspect ``message`` and update local stats. Sends nothing downstream."""
self._messages_seen += 1
hits = {kw for kw in self.FLAGGED_KEYWORDS if kw in message.lower()}
if hits:
self._flagged_messages += 1
self._flagged_keywords.update(hits)
@handler
async def emit_report(self, _: ReportRequest, ctx: WorkflowContext[StatsSnapshot]) -> None:
"""Forward the current stats snapshot to the reporter on request."""
await ctx.send_message(
StatsSnapshot(
messages_seen=self._messages_seen,
flagged_messages=self._flagged_messages,
flagged_keywords=sorted(self._flagged_keywords),
)
)
@override
async def reset(self) -> None:
"""Clear per-job aggregate state when the workflow is reset.
``Workflow.reset_for_new_run()`` calls ``reset()`` on every executor in the
graph; overriding it here is what makes a reused workflow safe to drive with
a brand-new job.
"""
self._messages_seen = 0
self._flagged_messages = 0
self._flagged_keywords.clear()
class StatsReporter(Executor):
"""Terminal executor that formats a snapshot and yields it as workflow output.
Holds a single instance attribute, ``_reports_emitted``, that tracks how many
summaries this reporter has produced on this workflow instance, and clears
it on reset so a reset workflow behaves identically to a freshly built one.
"""
def __init__(self, id: str) -> None:
super().__init__(id=id)
self._reports_emitted: int = 0
@handler
async def report(self, snapshot: StatsSnapshot, ctx: WorkflowContext[Never, str]) -> None:
self._reports_emitted += 1
summary = (
f"messages={snapshot.messages_seen}, "
f"flagged={snapshot.flagged_messages}, "
f"keywords={snapshot.flagged_keywords or 'none'}, "
f"reports_emitted={self._reports_emitted}"
)
await ctx.yield_output(summary)
@override
async def reset(self) -> None:
"""Clear the emitted-reports counter when the workflow is reset."""
self._reports_emitted = 0
async def _process(workflow: Workflow, messages: list[str]) -> None:
"""Send each message through the workflow; no output is produced."""
for message in messages:
await workflow.run(message)
async def _request_report(workflow: Workflow) -> str:
"""Ask the workflow for a summary of the stats accumulated so far."""
events = await workflow.run(ReportRequest())
outputs = events.get_outputs()
return outputs[0] if outputs else ""
async def main() -> None:
"""Build the moderation workflow once, then run it across three independent jobs."""
# 1. Build the moderation pipeline once. The same workflow instance will be
# reused for every job; that's the whole point of this sample.
counter = FlaggedKeywordCounter(id="counter")
reporter = StatsReporter(id="reporter")
workflow = WorkflowBuilder(start_executor=counter, output_from=[reporter]).add_edge(counter, reporter).build()
# 2. First job -- inspect three messages, then request a report. Note this
# batch happens to be three messages, but any size works.
await _process(workflow, ["hello there", "free phishing kit", "lunch plans?"])
print(f"Batch A summary: {await _request_report(workflow)}")
# 3. Second job WITHOUT reset. State from batch A leaks in: the counter's
# tallies and the reporter's emitted-reports counter both keep
# accumulating even though batch B is conceptually a separate job.
await _process(workflow, ["weekly status update", "team offsite agenda", "quarterly review"])
print(f"Batch B summary (no reset): {await _request_report(workflow)}")
# 4. Now reset between jobs and process the same batch B again. The summary
# reflects only batch B and every per-run counter starts fresh, because
# reset_for_new_run() calls reset() on every executor in the graph:
# - FlaggedKeywordCounter clears its message / flag / keyword tallies.
# - StatsReporter clears its emitted-reports counter.
await workflow.reset_for_new_run()
await _process(workflow, ["weekly status update", "team offsite agenda", "quarterly review"])
print(f"Batch B summary (after reset): {await _request_report(workflow)}")
# 5. Reset again before a final unrelated job. A reset workflow is
# indistinguishable from a freshly built one for state purposes, but
# cheaper because the graph and executor objects are reused.
await workflow.reset_for_new_run()
await _process(workflow, ["spam offer #1", "scam alert", "phishing attempt"])
print(f"Batch C summary (after reset): {await _request_report(workflow)}")
"""
Sample Output:
Batch A summary: messages=3, flagged=1, keywords=['phishing'], reports_emitted=1
Batch B summary (no reset): messages=6, flagged=1, keywords=['phishing'], reports_emitted=2
Batch B summary (after reset): messages=3, flagged=0, keywords=none, reports_emitted=1
Batch C summary (after reset): messages=3, flagged=3, keywords=['phishing', 'scam', 'spam'], reports_emitted=1
"""
if __name__ == "__main__":
asyncio.run(main())