mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Remove reset
This commit is contained in:
@@ -166,10 +166,6 @@ class AgentExecutor(Executor):
|
||||
raise ValueError("Agent must have a non-empty name or id or an explicit id must be provided.")
|
||||
super().__init__(exec_id)
|
||||
self._agent = agent
|
||||
# Track whether the caller supplied a session so reset() can preserve their
|
||||
# session reference (which may be wired to external/service-side storage)
|
||||
# and only replace sessions we created ourselves.
|
||||
self._session_supplied_by_caller = session is not None
|
||||
self._session = session or self._agent.create_session()
|
||||
|
||||
self._pending_agent_requests: dict[str, Content] = {}
|
||||
@@ -369,37 +365,10 @@ class AgentExecutor(Executor):
|
||||
pending_responses_payload = state.get("pending_responses_to_agent")
|
||||
self._pending_responses_to_agent = pending_responses_payload or []
|
||||
|
||||
@override
|
||||
async def reset(self) -> None:
|
||||
"""Reset the executor to its initial state for a new workflow run.
|
||||
|
||||
Clears the message cache, full conversation snapshot, and any pending
|
||||
user-input request/response bookkeeping.
|
||||
|
||||
Session handling:
|
||||
* If the session was created by this executor (no ``session`` argument
|
||||
was passed to ``__init__``), it is replaced with a fresh one via
|
||||
``agent.create_session()`` so prior conversation history does not
|
||||
leak into the next run.
|
||||
* If the session was supplied by the caller, it is left untouched.
|
||||
The caller owns the session lifecycle (it may be backed by
|
||||
service-side or external storage) and is responsible for clearing
|
||||
or rotating it if a clean slate is desired.
|
||||
"""
|
||||
logger.debug("AgentExecutor %s: Resetting state", self.id)
|
||||
def reset(self) -> None:
|
||||
"""Reset the internal cache of the executor."""
|
||||
logger.debug("AgentExecutor %s: Resetting cache", self.id)
|
||||
self._cache.clear()
|
||||
self._full_conversation.clear()
|
||||
self._pending_agent_requests.clear()
|
||||
self._pending_responses_to_agent.clear()
|
||||
if not self._session_supplied_by_caller:
|
||||
self._session = self._agent.create_session()
|
||||
else:
|
||||
logger.warning(
|
||||
"AgentExecutor %s: Session was supplied by the caller and will not be reset. "
|
||||
"Prior conversation history retained in the session may leak into the next run. "
|
||||
"Reset or rotate the session externally if a clean slate is required.",
|
||||
self.id,
|
||||
)
|
||||
|
||||
async def _run_agent_and_emit(
|
||||
self,
|
||||
|
||||
@@ -516,14 +516,6 @@ class Executor(RequestInfoMixin, DictConvertible):
|
||||
"""
|
||||
...
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset the executor to its initial state.
|
||||
|
||||
Override this method in subclasses to implement custom logic that should
|
||||
run when the workflow is reset.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
# endregion: Executor
|
||||
|
||||
|
||||
@@ -88,23 +88,6 @@ class Runner:
|
||||
"""
|
||||
self._iteration = 0
|
||||
|
||||
async def reset_for_new_run(self) -> None:
|
||||
"""Reset the runner for a new run.
|
||||
|
||||
This is useful when reusing the same workflow instance for a different run
|
||||
that is independent from prior runs.
|
||||
|
||||
Concurrent-run rejection lives at the :class:`Workflow` level via the
|
||||
active-run weak reference, so this method does not re-validate that
|
||||
no run is in progress; callers are expected to have already done so.
|
||||
"""
|
||||
self.reset_iteration_count()
|
||||
self._ctx.reset_for_new_run()
|
||||
self._state.clear()
|
||||
self._resumed_from_checkpoint = False
|
||||
for executor in self._executors.values():
|
||||
await executor.reset()
|
||||
|
||||
async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]:
|
||||
"""Run the workflow until no more messages are sent."""
|
||||
# Emit any events already produced prior to entering loop
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Literal, overload
|
||||
|
||||
from .._sessions import ContextProvider
|
||||
from .._types import ResponseStream
|
||||
from ..exceptions import WorkflowCheckpointException, WorkflowException, WorkflowRunnerException
|
||||
from ..exceptions import WorkflowCheckpointException, WorkflowException
|
||||
from ..observability import OtelAttr, capture_exception, create_workflow_span
|
||||
from ._checkpoint import CheckpointID, CheckpointStorage
|
||||
from ._const import DEFAULT_MAX_ITERATIONS, GLOBAL_KWARGS_KEY, WORKFLOW_RUN_KWARGS_KEY
|
||||
@@ -568,13 +568,12 @@ class Workflow(DictConvertible):
|
||||
yield in_progress # noqa: RUF070
|
||||
|
||||
# Per-run reset for fresh-message runs only. We deliberately
|
||||
# do NOT clear shared workflow state (`_state.clear()`) or the
|
||||
# runner context's in-flight messages (`reset_for_new_run()`)
|
||||
# here - state and pending work persist across `run()` calls
|
||||
# so that a `WorkflowAgent` can deliver multi-turn input on
|
||||
# the same instance and have prior turns' context survive.
|
||||
# Iteration counting and per-run kwargs ARE per-run though,
|
||||
# so they're reset here.
|
||||
# do NOT clear shared workflow state or the runner context's
|
||||
# in-flight messages here - state and pending work persist
|
||||
# across `run()` calls so that a `WorkflowAgent` can deliver
|
||||
# multi-turn input on the same instance and have prior turns'
|
||||
# context survive. Iteration counting and per-run kwargs ARE
|
||||
# per-run though, so they're reset here.
|
||||
if not is_continuation:
|
||||
self._runner.reset_iteration_count()
|
||||
|
||||
@@ -1210,24 +1209,3 @@ class Workflow(DictConvertible):
|
||||
"""
|
||||
existing_stream = self._active_run() if self._active_run is not None else None
|
||||
return existing_stream is not None
|
||||
|
||||
async def reset_for_new_run(self) -> None:
|
||||
"""Reset the workflow for a new run that is independent from prior runs.
|
||||
|
||||
Note:
|
||||
This will reset EVERYTHING - executor states, workflow state, and runner
|
||||
context (including pending requests/messages).
|
||||
|
||||
Raises:
|
||||
WorkflowRunnerException: If a run is currently in progress. Reset is only
|
||||
allowed when the workflow is idle to avoid clobbering in-flight run state.
|
||||
|
||||
"""
|
||||
existing_stream = self._active_run() if self._active_run is not None else None
|
||||
if existing_stream is not None:
|
||||
raise WorkflowRunnerException(
|
||||
"Cannot reset the workflow while a run is in progress. "
|
||||
"Wait for the current run to complete before calling reset_for_new_run()."
|
||||
)
|
||||
self._active_run = None
|
||||
await self._runner.reset_for_new_run()
|
||||
|
||||
@@ -534,19 +534,6 @@ class WorkflowExecutor(Executor):
|
||||
for event in request_info_events
|
||||
])
|
||||
|
||||
@override
|
||||
async def reset(self) -> None:
|
||||
"""Reset the WorkflowExecutor to its initial state for a new workflow run.
|
||||
|
||||
Clears in-flight execution contexts and the request-to-execution mapping,
|
||||
then recursively resets the wrapped sub-workflow so its executors, runner
|
||||
context, and shared state are also returned to a clean state.
|
||||
"""
|
||||
logger.debug("WorkflowExecutor %s: Resetting state", self.id)
|
||||
self._execution_contexts.clear()
|
||||
self._request_to_execution.clear()
|
||||
await self.workflow.reset_for_new_run()
|
||||
|
||||
async def _process_workflow_result(
|
||||
self,
|
||||
result: WorkflowRunResult,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncIterable, Awaitable
|
||||
from typing import Any, Literal, overload
|
||||
|
||||
@@ -20,7 +19,7 @@ from agent_framework import (
|
||||
WorkflowEvent,
|
||||
WorkflowRunState,
|
||||
)
|
||||
from agent_framework._workflows._agent_executor import AgentExecutorRequest, AgentExecutorResponse
|
||||
from agent_framework._workflows._agent_executor import AgentExecutorResponse
|
||||
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage
|
||||
from agent_framework._workflows._const import GLOBAL_KWARGS_KEY
|
||||
|
||||
@@ -307,108 +306,6 @@ async def test_agent_executor_save_and_restore_state_directly() -> None:
|
||||
assert restored_session.session_id == session.session_id
|
||||
|
||||
|
||||
# region: Tests for AgentExecutor.reset()
|
||||
|
||||
|
||||
async def test_agent_executor_reset_clears_per_run_state() -> None:
|
||||
"""reset() clears cache, conversation snapshot, and pending request/response buffers."""
|
||||
agent = _CountingAgent(id="reset_agent", name="ResetAgent")
|
||||
executor = AgentExecutor(agent, id="reset_exec")
|
||||
|
||||
# Populate every per-run buffer.
|
||||
executor._cache = [Message(role="user", contents=["cached"])] # type: ignore[reportPrivateUsage]
|
||||
executor._full_conversation = [ # type: ignore[reportPrivateUsage]
|
||||
Message(role="user", contents=["prior turn"]),
|
||||
Message(role="assistant", contents=["prior response"]),
|
||||
]
|
||||
pending_request = Content.from_text(text="approve?")
|
||||
executor._pending_agent_requests = {"req-1": pending_request} # type: ignore[reportPrivateUsage]
|
||||
executor._pending_responses_to_agent = [Content.from_text(text="approved")] # type: ignore[reportPrivateUsage]
|
||||
|
||||
await executor.reset()
|
||||
|
||||
assert executor._cache == [] # type: ignore[reportPrivateUsage]
|
||||
assert executor._full_conversation == [] # type: ignore[reportPrivateUsage]
|
||||
assert executor._pending_agent_requests == {} # type: ignore[reportPrivateUsage]
|
||||
assert executor._pending_responses_to_agent == [] # type: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_agent_executor_reset_creates_fresh_session_when_auto_created() -> None:
|
||||
"""reset() replaces the agent session when the executor created it itself."""
|
||||
agent = _CountingAgent(id="reset_session_agent", name="ResetSessionAgent")
|
||||
# No session passed in — executor creates one via agent.create_session().
|
||||
executor = AgentExecutor(agent, id="reset_session_exec")
|
||||
auto_created = executor._session # type: ignore[reportPrivateUsage]
|
||||
auto_created.state["history"] = {"messages": [Message(role="user", contents=["old"])]}
|
||||
|
||||
await executor.reset()
|
||||
|
||||
new_session = executor._session # type: ignore[reportPrivateUsage]
|
||||
assert new_session is not auto_created
|
||||
assert new_session.session_id != auto_created.session_id
|
||||
assert "history" not in new_session.state
|
||||
|
||||
|
||||
async def test_agent_executor_reset_preserves_caller_supplied_session(caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""reset() leaves a session passed in via __init__ untouched and warns the caller."""
|
||||
agent = _CountingAgent(id="reset_session_agent", name="ResetSessionAgent")
|
||||
caller_session = AgentSession()
|
||||
history_payload = {"messages": [Message(role="user", contents=["old"])]}
|
||||
caller_session.state["history"] = history_payload
|
||||
executor = AgentExecutor(agent, id="reset_session_exec", session=caller_session)
|
||||
|
||||
assert executor._session is caller_session # type: ignore[reportPrivateUsage]
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="agent_framework._workflows._agent_executor"):
|
||||
await executor.reset()
|
||||
|
||||
# Same instance, state untouched — the caller is responsible for managing the session.
|
||||
assert executor._session is caller_session # type: ignore[reportPrivateUsage]
|
||||
assert caller_session.state["history"] is history_payload
|
||||
assert any("Session was supplied by the caller" in record.message for record in caplog.records)
|
||||
|
||||
|
||||
async def test_agent_executor_reset_allows_subsequent_run() -> None:
|
||||
"""After reset(), the executor can be reused for a fresh workflow run without leaking state."""
|
||||
agent = _CountingAgent(id="reset_reuse_agent", name="ResetReuseAgent")
|
||||
executor = AgentExecutor(agent, id="reset_reuse_exec")
|
||||
workflow = WorkflowBuilder(start_executor=executor, output_from=[executor]).build()
|
||||
|
||||
first_outputs: list[WorkflowEvent] = []
|
||||
async for event in workflow.run(
|
||||
AgentExecutorRequest(messages=[Message(role="user", contents=["hello"])]),
|
||||
stream=True,
|
||||
):
|
||||
if event.type == "output":
|
||||
first_outputs.append(event)
|
||||
|
||||
assert first_outputs, "first run should have produced at least one output event"
|
||||
# After a normal run the cache is drained but the conversation snapshot remains.
|
||||
assert executor._cache == [] # type: ignore[reportPrivateUsage]
|
||||
assert executor._full_conversation != [] # type: ignore[reportPrivateUsage]
|
||||
first_session_id = executor._session.session_id # type: ignore[reportPrivateUsage]
|
||||
|
||||
await workflow.reset_for_new_run()
|
||||
|
||||
assert executor._full_conversation == [] # type: ignore[reportPrivateUsage]
|
||||
# Session was auto-created, so reset() rotates it to a fresh one.
|
||||
assert executor._session.session_id != first_session_id # type: ignore[reportPrivateUsage]
|
||||
|
||||
second_outputs: list[WorkflowEvent] = []
|
||||
async for event in workflow.run(
|
||||
AgentExecutorRequest(messages=[Message(role="user", contents=["second"])]),
|
||||
stream=True,
|
||||
):
|
||||
if event.type == "output":
|
||||
second_outputs.append(event)
|
||||
|
||||
assert second_outputs, "second run after reset should have produced at least one output event"
|
||||
assert agent.call_count == 2
|
||||
|
||||
|
||||
# endregion: Tests for AgentExecutor.reset()
|
||||
|
||||
|
||||
async def test_prepare_agent_run_args_extracts_invocation_kwargs() -> None:
|
||||
"""_prepare_agent_run_args extracts function_invocation_kwargs and client_kwargs."""
|
||||
agent = _CountingAgent(id="test_agent", name="TestAgent")
|
||||
|
||||
@@ -1004,53 +1004,3 @@ def test_handler_typevar_error_takes_priority_over_context_error():
|
||||
@handler
|
||||
async def process(self, message: _T, ctx) -> None: # type: ignore[no-untyped-def]
|
||||
pass
|
||||
|
||||
|
||||
# region: Tests for Executor.reset()
|
||||
|
||||
|
||||
async def test_executor_default_reset_is_noop():
|
||||
"""The base Executor.reset() is a no-op and must complete without raising.
|
||||
|
||||
Subclasses that don't carry reset-relevant state should be able to rely on the
|
||||
default implementation.
|
||||
"""
|
||||
|
||||
class StatelessExecutor(Executor):
|
||||
@handler
|
||||
async def handle(self, message: str, ctx: WorkflowContext[str]) -> None:
|
||||
await ctx.send_message(message)
|
||||
|
||||
executor_instance = StatelessExecutor(id="stateless")
|
||||
|
||||
# Must complete without raising and return None.
|
||||
assert await executor_instance.reset() is None
|
||||
|
||||
|
||||
async def test_executor_subclass_reset_is_invoked():
|
||||
"""A subclass that overrides reset() can clear its own internal state."""
|
||||
|
||||
class CounterExecutor(Executor):
|
||||
def __init__(self, id: str) -> None:
|
||||
super().__init__(id=id)
|
||||
self.counter = 0
|
||||
self.reset_calls = 0
|
||||
|
||||
@handler
|
||||
async def handle(self, message: int, ctx: WorkflowContext[int]) -> None:
|
||||
self.counter += message
|
||||
|
||||
async def reset(self) -> None:
|
||||
self.counter = 0
|
||||
self.reset_calls += 1
|
||||
|
||||
executor_instance = CounterExecutor(id="counter")
|
||||
executor_instance.counter = 42
|
||||
|
||||
await executor_instance.reset()
|
||||
|
||||
assert executor_instance.counter == 0
|
||||
assert executor_instance.reset_calls == 1
|
||||
|
||||
|
||||
# endregion: Tests for Executor.reset()
|
||||
|
||||
@@ -1111,204 +1111,3 @@ async def test_runner_drains_straggler_events_at_iteration_end():
|
||||
output_events = [e for e in events if e.type == "output"]
|
||||
# We should have output events from both executors
|
||||
assert len(output_events) >= 2
|
||||
|
||||
|
||||
# region: Tests for InProcRunnerContext.reset_for_new_run()
|
||||
|
||||
|
||||
async def test_runner_context_reset_clears_in_flight_messages():
|
||||
"""reset_for_new_run drops queued executor-to-executor messages."""
|
||||
ctx = InProcRunnerContext()
|
||||
await ctx.send_message(WorkflowMessage(data=MockMessage(data=1), source_id="src"))
|
||||
|
||||
assert await ctx.has_messages() is True
|
||||
|
||||
ctx.reset_for_new_run()
|
||||
|
||||
assert await ctx.has_messages() is False
|
||||
assert await ctx.drain_messages() == {}
|
||||
|
||||
|
||||
async def test_runner_context_reset_drains_pending_events():
|
||||
"""reset_for_new_run discards any events buffered for streaming."""
|
||||
ctx = InProcRunnerContext()
|
||||
await ctx.add_event(WorkflowEvent.superstep_started(iteration=1))
|
||||
assert await ctx.has_events() is True
|
||||
|
||||
ctx.reset_for_new_run()
|
||||
|
||||
assert await ctx.has_events() is False
|
||||
assert await ctx.drain_events() == []
|
||||
|
||||
|
||||
async def test_runner_context_reset_resets_streaming_flag():
|
||||
"""reset_for_new_run resets streaming back to its non-streaming default."""
|
||||
ctx = InProcRunnerContext()
|
||||
ctx.set_streaming(True)
|
||||
assert ctx.is_streaming() is True
|
||||
|
||||
ctx.reset_for_new_run()
|
||||
|
||||
assert ctx.is_streaming() is False
|
||||
|
||||
|
||||
async def test_runner_context_reset_clears_pending_request_info_events():
|
||||
"""reset_for_new_run clears any pending request_info events tracked for correlation."""
|
||||
ctx = InProcRunnerContext()
|
||||
request_info_event = WorkflowEvent.request_info(
|
||||
request_id="request-123",
|
||||
source_executor_id="source",
|
||||
request_data=MockMessage(data=0),
|
||||
response_type=bool,
|
||||
)
|
||||
await ctx.add_request_info_event(request_info_event)
|
||||
|
||||
assert "request-123" in await ctx.get_pending_request_info_events()
|
||||
|
||||
ctx.reset_for_new_run()
|
||||
|
||||
assert await ctx.get_pending_request_info_events() == {}
|
||||
|
||||
|
||||
# endregion: Tests for InProcRunnerContext.reset_for_new_run()
|
||||
|
||||
|
||||
# region: Tests for Runner.reset_for_new_run()
|
||||
|
||||
|
||||
async def test_runner_reset_for_new_run_resets_iteration_count():
|
||||
"""reset_for_new_run resets the iteration counter back to zero."""
|
||||
runner = _make_runner()
|
||||
runner._iteration = 7 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
await runner.reset_for_new_run()
|
||||
|
||||
assert runner._iteration == 0 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_runner_reset_for_new_run_clears_shared_state():
|
||||
"""reset_for_new_run wipes both committed and pending entries from shared state."""
|
||||
state = State()
|
||||
state.set("committed_key", "committed_value")
|
||||
state.commit()
|
||||
state.set("pending_key", "pending_value") # uncommitted
|
||||
|
||||
runner = Runner(
|
||||
[],
|
||||
{},
|
||||
state,
|
||||
InProcRunnerContext(),
|
||||
"test_name",
|
||||
graph_signature_hash="test_hash",
|
||||
)
|
||||
|
||||
await runner.reset_for_new_run()
|
||||
|
||||
assert state.get("committed_key") is None
|
||||
assert state.get("pending_key") is None
|
||||
assert state.has("committed_key") is False
|
||||
assert state.has("pending_key") is False
|
||||
|
||||
|
||||
async def test_runner_reset_for_new_run_clears_resumed_from_checkpoint_flag():
|
||||
"""reset_for_new_run clears the flag set by restore_from_checkpoint."""
|
||||
runner = _make_runner()
|
||||
resumed_checkpoint = WorkflowCheckpoint(
|
||||
checkpoint_id="resumed-cp",
|
||||
workflow_name="test_name",
|
||||
graph_signature_hash="test_hash",
|
||||
iteration_count=5,
|
||||
)
|
||||
runner._mark_resumed(resumed_checkpoint) # pyright: ignore[reportPrivateUsage]
|
||||
assert runner._resumed_from_checkpoint is True # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
await runner.reset_for_new_run()
|
||||
|
||||
assert runner._resumed_from_checkpoint is False # pyright: ignore[reportPrivateUsage]
|
||||
# And the iteration count restored from the checkpoint must be wiped, too.
|
||||
assert runner._iteration == 0 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_runner_reset_for_new_run_invokes_executor_reset_for_each_executor():
|
||||
"""reset_for_new_run calls reset() on every registered executor exactly once."""
|
||||
|
||||
class TrackingExecutor(MockExecutor):
|
||||
def __init__(self, id: str) -> None:
|
||||
super().__init__(id=id)
|
||||
self.reset_calls = 0
|
||||
|
||||
async def reset(self) -> None:
|
||||
self.reset_calls += 1
|
||||
|
||||
executor_a = TrackingExecutor(id="executor_a")
|
||||
executor_b = TrackingExecutor(id="executor_b")
|
||||
|
||||
runner = Runner(
|
||||
[],
|
||||
{executor_a.id: executor_a, executor_b.id: executor_b},
|
||||
State(),
|
||||
InProcRunnerContext(),
|
||||
"test_name",
|
||||
graph_signature_hash="test_hash",
|
||||
)
|
||||
|
||||
await runner.reset_for_new_run()
|
||||
|
||||
assert executor_a.reset_calls == 1
|
||||
assert executor_b.reset_calls == 1
|
||||
|
||||
|
||||
async def test_runner_reset_for_new_run_resets_runner_context():
|
||||
"""reset_for_new_run forwards the reset to the underlying runner context."""
|
||||
ctx = InProcRunnerContext()
|
||||
await ctx.send_message(WorkflowMessage(data=MockMessage(data=0), source_id="src"))
|
||||
await ctx.add_event(WorkflowEvent.superstep_started(iteration=1))
|
||||
ctx.set_streaming(True)
|
||||
|
||||
runner = Runner([], {}, State(), ctx, "test_name", graph_signature_hash="test_hash")
|
||||
|
||||
await runner.reset_for_new_run()
|
||||
|
||||
assert await ctx.has_messages() is False
|
||||
assert await ctx.has_events() is False
|
||||
assert ctx.is_streaming() is False
|
||||
|
||||
|
||||
async def test_runner_can_run_again_after_reset_for_new_run():
|
||||
"""After reset_for_new_run the runner can be reserved and converge a fresh workload."""
|
||||
executor_a = MockExecutor(id="executor_a")
|
||||
executor_b = MockExecutor(id="executor_b")
|
||||
|
||||
edges = [
|
||||
SingleEdgeGroup(executor_a.id, executor_b.id),
|
||||
SingleEdgeGroup(executor_b.id, executor_a.id),
|
||||
]
|
||||
executors: dict[str, Executor] = {
|
||||
executor_a.id: executor_a,
|
||||
executor_b.id: executor_b,
|
||||
}
|
||||
state = State()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash")
|
||||
|
||||
# First run: drives MockExecutor's loop until it yields the terminal value.
|
||||
await executor_a.execute(MockMessage(data=0), ["START"], state, ctx)
|
||||
async for _ in runner.run_until_convergence():
|
||||
pass
|
||||
assert runner._iteration == 10 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
await runner.reset_for_new_run()
|
||||
|
||||
# Second run: must succeed cleanly using the same runner instance.
|
||||
await executor_a.execute(MockMessage(data=0), ["START"], state, ctx)
|
||||
second_run_outputs: list[int] = []
|
||||
async for event in runner.run_until_convergence():
|
||||
if event.type == "output":
|
||||
second_run_outputs.append(event.data)
|
||||
|
||||
assert second_run_outputs == [10]
|
||||
assert runner._iteration == 10 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
# endregion: Tests for Runner.reset_for_new_run()
|
||||
|
||||
@@ -689,89 +689,3 @@ async def test_sub_workflow_intermediate_outputs_propagate_to_parent() -> None:
|
||||
|
||||
# The parent's own terminal output is unaffected.
|
||||
assert any(e.executor_id == "parent_sink" and e.data == "final: hello" for e in output_events)
|
||||
|
||||
|
||||
# region: Tests for WorkflowExecutor.reset()
|
||||
|
||||
|
||||
async def test_workflow_executor_reset_clears_execution_state() -> None:
|
||||
"""reset() clears the WorkflowExecutor's per-run execution contexts and request mappings."""
|
||||
validation_workflow = create_email_validation_workflow()
|
||||
parent = Coordinator()
|
||||
workflow_executor = WorkflowExecutor(validation_workflow, "email_validation_workflow")
|
||||
|
||||
main_workflow = (
|
||||
WorkflowBuilder(start_executor=parent)
|
||||
.add_edge(parent, workflow_executor)
|
||||
.add_edge(workflow_executor, parent)
|
||||
.build()
|
||||
)
|
||||
|
||||
# First run pauses with a pending request from the sub-workflow.
|
||||
result = await main_workflow.run("test@example.com")
|
||||
assert len(result.get_request_info_events()) == 1
|
||||
assert len(workflow_executor._execution_contexts) == 1 # type: ignore[reportPrivateUsage]
|
||||
assert len(workflow_executor._request_to_execution) == 1 # type: ignore[reportPrivateUsage]
|
||||
|
||||
await main_workflow.reset_for_new_run()
|
||||
|
||||
assert workflow_executor._execution_contexts == {} # type: ignore[reportPrivateUsage]
|
||||
assert workflow_executor._request_to_execution == {} # type: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_workflow_executor_reset_resets_wrapped_workflow() -> None:
|
||||
"""reset() recursively resets the wrapped workflow (runner iteration counter cleared)."""
|
||||
validation_workflow = create_email_validation_workflow()
|
||||
parent = Coordinator()
|
||||
workflow_executor = WorkflowExecutor(validation_workflow, "email_validation_workflow")
|
||||
|
||||
main_workflow = (
|
||||
WorkflowBuilder(start_executor=parent)
|
||||
.add_edge(parent, workflow_executor)
|
||||
.add_edge(workflow_executor, parent)
|
||||
.build()
|
||||
)
|
||||
|
||||
await main_workflow.run("test@example.com")
|
||||
# The sub-workflow's runner advanced past iteration 0 during execution.
|
||||
assert validation_workflow._runner._iteration > 0 # type: ignore[reportPrivateUsage]
|
||||
|
||||
await main_workflow.reset_for_new_run()
|
||||
|
||||
# The wrapped workflow's runner was reset along with the parent.
|
||||
assert validation_workflow._runner._iteration == 0 # type: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_workflow_executor_reset_allows_subsequent_run() -> None:
|
||||
"""After reset(), the parent + WorkflowExecutor can be reused for a fresh run with no leakage."""
|
||||
validation_workflow = create_email_validation_workflow()
|
||||
parent = Coordinator()
|
||||
workflow_executor = WorkflowExecutor(validation_workflow, "email_validation_workflow")
|
||||
|
||||
main_workflow = (
|
||||
WorkflowBuilder(start_executor=parent)
|
||||
.add_edge(parent, workflow_executor)
|
||||
.add_edge(workflow_executor, parent)
|
||||
.build()
|
||||
)
|
||||
|
||||
first_result = await main_workflow.run("first@example.com")
|
||||
assert len(first_result.get_request_info_events()) == 1
|
||||
|
||||
await main_workflow.reset_for_new_run()
|
||||
|
||||
# State on the WorkflowExecutor and parent's pending-request bookkeeping is clean.
|
||||
assert workflow_executor._execution_contexts == {} # type: ignore[reportPrivateUsage]
|
||||
assert workflow_executor._request_to_execution == {} # type: ignore[reportPrivateUsage]
|
||||
|
||||
second_result = await main_workflow.run("second@example.com")
|
||||
second_requests = second_result.get_request_info_events()
|
||||
assert len(second_requests) == 1
|
||||
assert isinstance(second_requests[0].data, DomainCheckRequest)
|
||||
# Confirm the new run produced a request from the second email, not the cached first one.
|
||||
assert second_requests[0].data.email == "second@example.com"
|
||||
# And the WorkflowExecutor is now tracking exactly one fresh execution.
|
||||
assert len(workflow_executor._execution_contexts) == 1 # type: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
# endregion: Tests for WorkflowExecutor.reset()
|
||||
|
||||
@@ -29,7 +29,6 @@ from agent_framework import (
|
||||
WorkflowEvent,
|
||||
WorkflowException,
|
||||
WorkflowMessage,
|
||||
WorkflowRunnerException,
|
||||
WorkflowRunState,
|
||||
handler,
|
||||
response_handler,
|
||||
@@ -1354,144 +1353,3 @@ async def test_output_executors_filtering_with_run_responses_streaming() -> None
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region: Tests for Workflow.reset_for_new_run()
|
||||
|
||||
|
||||
async def test_workflow_reset_for_new_run_allows_subsequent_run() -> None:
|
||||
"""After reset_for_new_run() the same workflow instance can be run again from scratch."""
|
||||
executor_a = IncrementExecutor(id="executor_a")
|
||||
executor_b = IncrementExecutor(id="executor_b")
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor=executor_a, output_from=[executor_a, executor_b])
|
||||
.add_edge(executor_a, executor_b)
|
||||
.add_edge(executor_b, executor_a)
|
||||
.build()
|
||||
)
|
||||
|
||||
first = await workflow.run(NumberMessage(data=0))
|
||||
assert first.get_outputs() == [10]
|
||||
|
||||
await workflow.reset_for_new_run()
|
||||
|
||||
second = await workflow.run(NumberMessage(data=0))
|
||||
assert second.get_outputs() == [10]
|
||||
|
||||
|
||||
async def test_workflow_reset_for_new_run_clears_workflow_state() -> None:
|
||||
"""reset_for_new_run() clears values that executors persisted in shared workflow state."""
|
||||
|
||||
class StateWritingExecutor(Executor):
|
||||
@handler
|
||||
async def handle(self, message: NumberMessage, ctx: WorkflowContext[Any, int]) -> None:
|
||||
previous = ctx.get_state("seen") or 0
|
||||
ctx.set_state("seen", previous + 1)
|
||||
await ctx.yield_output(previous + 1)
|
||||
|
||||
state_writer = StateWritingExecutor(id="state_writer")
|
||||
workflow = WorkflowBuilder(start_executor=state_writer, output_from=[state_writer]).build()
|
||||
|
||||
first = await workflow.run(NumberMessage(data=1))
|
||||
assert first.get_outputs() == [1]
|
||||
# State was persisted by the executor.
|
||||
assert workflow._runner.state.get("seen") == 1 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
await workflow.reset_for_new_run()
|
||||
|
||||
# The runner's shared state has been wiped.
|
||||
assert workflow._runner.state.get("seen") is None # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
second = await workflow.run(NumberMessage(data=1))
|
||||
# Counter started fresh from 0 again; output is 1, not 2.
|
||||
assert second.get_outputs() == [1]
|
||||
|
||||
|
||||
async def test_workflow_reset_for_new_run_invokes_executor_reset_hook() -> None:
|
||||
"""reset_for_new_run() calls Executor.reset() on every executor in the workflow."""
|
||||
|
||||
class ResettableExecutor(Executor):
|
||||
def __init__(self, id: str) -> None:
|
||||
super().__init__(id=id)
|
||||
self.reset_calls = 0
|
||||
self.handled = 0
|
||||
|
||||
@handler
|
||||
async def handle(self, message: NumberMessage, ctx: WorkflowContext[Any, int]) -> None:
|
||||
self.handled += 1
|
||||
await ctx.yield_output(self.handled)
|
||||
|
||||
async def reset(self) -> None:
|
||||
self.reset_calls += 1
|
||||
self.handled = 0
|
||||
|
||||
executor = ResettableExecutor(id="resettable")
|
||||
workflow = WorkflowBuilder(start_executor=executor, output_from=[executor]).build()
|
||||
|
||||
await workflow.run(NumberMessage(data=1))
|
||||
assert executor.handled == 1
|
||||
assert executor.reset_calls == 0
|
||||
|
||||
await workflow.reset_for_new_run()
|
||||
|
||||
assert executor.reset_calls == 1
|
||||
# The executor's own counter was wiped by its overridden reset().
|
||||
assert executor.handled == 0
|
||||
|
||||
|
||||
async def test_workflow_reset_for_new_run_resets_runner_iteration_counter() -> None:
    """Verify that reset_for_new_run() zeroes the runner's iteration counter left by a prior run."""
    first_node = IncrementExecutor(id="executor_a")
    second_node = IncrementExecutor(id="executor_b")

    # Two executors wired in a cycle so the run takes more than zero iterations.
    builder = WorkflowBuilder(start_executor=first_node, output_from=[first_node, second_node])
    builder = builder.add_edge(first_node, second_node)
    builder = builder.add_edge(second_node, first_node)
    wf = builder.build()

    await wf.run(NumberMessage(data=0))
    assert wf._runner._iteration > 0  # pyright: ignore[reportPrivateUsage]

    await wf.reset_for_new_run()

    assert wf._runner._iteration == 0  # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_workflow_reset_for_new_run_rejected_during_streaming_run() -> None:
    """Verify that reset_for_new_run() raises WorkflowRunnerException while a streaming run is active."""
    node_a = IncrementExecutor(id="executor_a")
    node_b = IncrementExecutor(id="executor_b")

    builder = WorkflowBuilder(start_executor=node_a, output_from=[node_a, node_b])
    builder = builder.add_edge(node_a, node_b)
    builder = builder.add_edge(node_b, node_a)
    workflow = builder.build()

    async def consume_stream_slowly() -> list[WorkflowEvent]:
        """Drain the event stream with a short pause per event to keep the run alive."""
        collected: list[WorkflowEvent] = []
        stream = workflow.run(NumberMessage(data=0), stream=True)
        async for item in stream:
            collected.append(item)
            await asyncio.sleep(0.01)
        return collected

    streaming_run = asyncio.create_task(consume_stream_slowly())
    # Yield control long enough for the streaming run to get going.
    await asyncio.sleep(0.02)

    expected_error = "Cannot reset the workflow while a run is in progress"
    try:
        with pytest.raises(WorkflowRunnerException, match=expected_error):
            await workflow.reset_for_new_run()
    finally:
        # Always let the background run finish so it does not leak past the test.
        await streaming_run

    # With the run finished, resetting is permitted again.
    await workflow.reset_for_new_run()
    assert workflow._runner._iteration == 0  # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -3032,6 +3032,7 @@ class TestCheckpointContextPathValidation:
|
||||
agent.workflow = MagicMock()
|
||||
agent.workflow.name = "wf"
|
||||
agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False)
|
||||
agent.workflow.create_checkpoint = AsyncMock(return_value="cp_initial")
|
||||
agent.run = AsyncMock(
|
||||
side_effect=[
|
||||
AgentResponse(messages=[]),
|
||||
@@ -3155,6 +3156,8 @@ class TestCheckpointContextPathValidation:
|
||||
agent.workflow = MagicMock()
|
||||
agent.workflow.name = "wf"
|
||||
agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False)
|
||||
agent.workflow.create_checkpoint = AsyncMock(return_value="cp_initial")
|
||||
agent.run = AsyncMock(return_value=AgentResponse(messages=[]))
|
||||
|
||||
# Constructor inspects WorkflowAgent.workflow internals; bypass setup
|
||||
# by feeding a configured mock through a normal init.
|
||||
@@ -3971,80 +3974,5 @@ class TestWorkflowAgentHosting:
|
||||
assert len(approval_responses) == 1
|
||||
assert approval_responses[0].approved is False # type: ignore[attr-defined]
|
||||
|
||||
async def test_workflow_is_reset_when_no_prior_conversation(self) -> None:
    """A request carrying neither ``previous_response_id`` nor ``conversation_id``
    must make the host reset the workflow, so in-memory state left over from an
    earlier request is gone before the new turn runs."""
    workflow_agent = _build_text_workflow_agent("fresh")
    server = _make_server(workflow_agent)

    # Wrap the real coroutine so the reset still happens while calls are counted.
    workflow = workflow_agent.workflow
    reset_spy = AsyncMock(wraps=workflow.reset_for_new_run)
    with patch.object(workflow, "reset_for_new_run", new=reset_spy):
        response = await _post(server, input_text="hi", stream=False)

    assert response.status_code == 200
    assert reset_spy.await_count == 1
|
||||
|
||||
async def test_workflow_is_reset_across_independent_requests(self) -> None:
    """Each of two back-to-back requests that carry no chaining context id
    must trigger its own workflow reset."""
    workflow_agent = _build_text_workflow_agent("again")
    server = _make_server(workflow_agent)

    # Wrap the real coroutine so the reset still happens while calls are counted.
    workflow = workflow_agent.workflow
    reset_spy = AsyncMock(wraps=workflow.reset_for_new_run)
    with patch.object(workflow, "reset_for_new_run", new=reset_spy):
        first_response = await _post(server, input_text="hi", stream=False)
        second_response = await _post(server, input_text="hi again", stream=False)

    assert first_response.status_code == 200
    assert second_response.status_code == 200
    assert reset_spy.await_count == 2
|
||||
|
||||
async def test_workflow_is_not_reset_when_resuming_from_checkpoint(self) -> None:
    """If ``previous_response_id`` maps to an existing workflow checkpoint, the
    host restores that checkpoint and must not reset the workflow."""
    workflow_agent, _ = _build_approval_workflow_agent(
        approval_request_id="apr_no_reset",
        final_text="resumed",
    )
    server = _make_server(workflow_agent)

    # First turn: produces the approval request that the second turn answers.
    initial = await _post(server, stream=False)
    assert initial.status_code == 200
    initial_body = initial.json()
    first_response_id = initial_body["id"]
    approval_request_id = next(
        item["id"] for item in initial_body["output"] if item["type"] == "mcp_approval_request"
    )

    approval_item = {
        "type": "mcp_approval_response",
        "approval_request_id": approval_request_id,
        "approve": True,
    }
    resume_payload = {
        "model": "test-model",
        "input": [approval_item],
        "stream": False,
        "previous_response_id": first_response_id,
    }

    # Second turn: chained to the first response, so no reset is expected.
    workflow = workflow_agent.workflow
    reset_spy = AsyncMock(wraps=workflow.reset_for_new_run)
    with patch.object(workflow, "reset_for_new_run", new=reset_spy):
        resumed = await _post_json(server, resume_payload)

    assert resumed.status_code == 200
    assert reset_spy.await_count == 0
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -198,12 +198,7 @@ class TestBeforeRun:
|
||||
"""OSS client with all scoping parameters passes them as isolated concurrent kwargs."""
|
||||
mock_oss_mem0_client.search.return_value = []
|
||||
|
||||
provider = Mem0ContextProvider(
|
||||
source_id="mem0",
|
||||
mem0_client=mock_oss_mem0_client,
|
||||
user_id="u1",
|
||||
agent_id="a1"
|
||||
)
|
||||
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_oss_mem0_client, user_id="u1", agent_id="a1")
|
||||
|
||||
mock_context = MagicMock(spec=SessionContext)
|
||||
mock_msg = MagicMock()
|
||||
|
||||
-22
@@ -598,25 +598,3 @@ class BaseGroupChatOrchestrator(Executor, ABC):
|
||||
metadata: Pattern-specific state dict
|
||||
"""
|
||||
pass
|
||||
|
||||
@override
async def reset(self) -> None:
    """Return the orchestrator to a clean slate before a new workflow run.

    Empties the shared conversation history, zeroes the round counter, and then
    calls ``_reset_pattern_state()`` so a subclass can discard whatever
    pattern-specific per-run state it keeps.
    """
    orchestrator_kind = self.__class__.__name__
    logger.debug("%s %s: Resetting state", orchestrator_kind, self.id)

    # Shared state owned by the base class.
    self._full_conversation.clear()
    self._round_index = 0

    # Subclass hook runs last, after the shared state is already cleared.
    self._reset_pattern_state()
|
||||
|
||||
def _reset_pattern_state(self) -> None:
    """Hook for clearing pattern-specific per-run state; does nothing by default.

    Subclasses override this to drop their own per-run state. ``reset()``
    invokes it after the shared conversation and round counter have been
    cleared.
    """
    return None
|
||||
|
||||
@@ -327,7 +327,6 @@ class AgentBasedGroupChatOrchestrator(BaseGroupChatOrchestrator):
|
||||
)
|
||||
self._agent = agent
|
||||
self._retry_attempts = retry_attempts
|
||||
self._session_supplied_by_caller = session is not None
|
||||
self._session = session or agent.create_session()
|
||||
# Cache for messages since last agent invocation
|
||||
# This is different from the full conversation history maintained by the base orchestrator
|
||||
@@ -338,25 +337,6 @@ class AgentBasedGroupChatOrchestrator(BaseGroupChatOrchestrator):
|
||||
self._cache.extend(messages)
|
||||
return super()._append_messages(messages)
|
||||
|
||||
@override
def _reset_pattern_state(self) -> None:
    """Clear agent-orchestrator per-run state ahead of a new workflow run.

    The per-run message cache is always emptied. The orchestrator agent's
    session is replaced with a new one from ``agent.create_session()`` only
    when this orchestrator created the session itself; a caller-supplied
    session is left in place and a warning is logged instead.
    """
    self._cache.clear()

    if not self._session_supplied_by_caller:
        # We own the session, so swap in a fresh one.
        self._session = self._agent.create_session()
        return

    # The caller owns the session lifecycle: keep it and flag that it was not reset.
    logger.warning(
        "%s %s: Session was supplied by the caller and will not be reset. "
        "If you want a fresh session for the next run, reset or replace it before invoking the workflow.",
        self.__class__.__name__,
        self.id,
    )
|
||||
|
||||
@override
|
||||
async def _handle_messages(
|
||||
self,
|
||||
|
||||
@@ -509,14 +509,6 @@ class MagenticManagerBase(ABC):
|
||||
"""Restore runtime state from checkpoint data."""
|
||||
return
|
||||
|
||||
def on_reset(self) -> None:
    """Hook for clearing per-run runtime state before a new workflow run.

    The base implementation does nothing. Subclasses should discard anything
    accumulated during a run (for example cached ledgers or agent sessions) so
    the manager starts the next run fresh.
    """
    return None
|
||||
|
||||
|
||||
class StandardMagenticManager(MagenticManagerBase):
|
||||
"""Standard Magentic manager that performs real LLM calls via a Agent.
|
||||
@@ -769,12 +761,6 @@ class StandardMagenticManager(MagenticManagerBase):
|
||||
except Exception: # pragma: no cover - defensive
|
||||
logger.warning("Failed to restore manager agent session from checkpoint state")
|
||||
|
||||
@override
def on_reset(self) -> None:
    """Drop the cached task ledger and switch to a newly created agent session."""
    self.task_ledger = None
    fresh_session = self._agent.create_session()
    self._session = fresh_session
|
||||
|
||||
|
||||
# endregion Magentic Manager
|
||||
|
||||
@@ -1277,15 +1263,6 @@ class MagenticOrchestrator(BaseGroupChatOrchestrator):
|
||||
# a target will broadcast to all.
|
||||
await ctx.send_message(MagenticResetSignal())
|
||||
|
||||
@override
def _reset_pattern_state(self) -> None:
    """Clear Magentic-specific per-run state, then let the manager reset itself."""
    # Context and both ledgers go back to "not yet created".
    self._magentic_context = self._task_ledger = self._progress_ledger = None
    self._terminated = False
    # The manager keeps its own per-run state; delegate to its hook last.
    self._manager.on_reset()
|
||||
|
||||
@override
|
||||
async def on_checkpoint_save(self) -> dict[str, Any]:
|
||||
"""Capture current orchestrator state for checkpointing."""
|
||||
|
||||
@@ -1097,139 +1097,3 @@ def test_group_chat_orchestrator_factory_invalid_return_type():
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Reset
|
||||
|
||||
|
||||
async def test_base_orchestrator_reset_clears_conversation_and_round_index() -> None:
    """Verify that reset() empties the conversation history and zeroes the round counter."""
    from agent_framework.orchestrations import GroupChatOrchestrator

    from agent_framework_orchestrations._base_group_chat_orchestrator import ParticipantRegistry

    pick_next = make_sequence_selector()
    orch = GroupChatOrchestrator(
        id="orch",
        participant_registry=ParticipantRegistry([]),
        selection_func=pick_next,
        max_rounds=2,
    )

    # Seed the state that reset() is expected to clear.
    seeded_message = Message(role="user", contents=["hi"], author_name="user")
    orch._full_conversation = [seeded_message]
    orch._round_index = 4

    await orch.reset()

    assert orch._full_conversation == []
    assert orch._round_index == 0
|
||||
|
||||
|
||||
async def test_base_orchestrator_reset_invokes_pattern_state_hook() -> None:
    """Verify that every reset() call reaches the subclass's _reset_pattern_state() hook."""
    from agent_framework.orchestrations import GroupChatOrchestrator

    from agent_framework_orchestrations._base_group_chat_orchestrator import ParticipantRegistry

    pick_next = make_sequence_selector()

    class TrackingOrchestrator(GroupChatOrchestrator):
        """Orchestrator that counts hook invocations on the class itself."""

        reset_calls: int = 0

        def _reset_pattern_state(self) -> None:
            type(self).reset_calls += 1

    orch = TrackingOrchestrator(
        id="orch",
        participant_registry=ParticipantRegistry([]),
        selection_func=pick_next,
        max_rounds=2,
    )

    # Two resets must produce two hook calls.
    for _ in range(2):
        await orch.reset()

    assert TrackingOrchestrator.reset_calls == 2
|
||||
|
||||
|
||||
async def test_agent_based_orchestrator_reset_clears_cache_and_rotates_session() -> None:
    """With no caller-supplied session, reset() clears the cache and swaps in a new session."""
    from agent_framework.orchestrations import AgentBasedGroupChatOrchestrator

    from agent_framework_orchestrations._base_group_chat_orchestrator import ParticipantRegistry

    manager_agent = cast(Agent, StubManagerAgent())
    orch = AgentBasedGroupChatOrchestrator(
        agent=manager_agent,
        participant_registry=ParticipantRegistry([]),
        max_rounds=2,
    )
    session_before = orch._session

    # Seed every piece of state that reset() should wipe.
    orch._cache = [Message(role="assistant", contents=["x"], author_name="agent")]
    orch._full_conversation = [Message(role="user", contents=["x"], author_name="user")]
    orch._round_index = 3

    await orch.reset()

    assert orch._cache == []
    assert orch._full_conversation == []
    assert orch._round_index == 0
    assert orch._session is not session_before
|
||||
|
||||
|
||||
async def test_agent_based_orchestrator_reset_warns_when_session_supplied(caplog: pytest.LogCaptureFixture) -> None:
    """With a caller-supplied session, reset() keeps that session and logs a warning."""
    import logging

    from agent_framework.orchestrations import AgentBasedGroupChatOrchestrator

    from agent_framework_orchestrations._base_group_chat_orchestrator import ParticipantRegistry

    manager_agent = cast(Agent, StubManagerAgent())
    supplied_session = manager_agent.create_session()
    orch = AgentBasedGroupChatOrchestrator(
        agent=manager_agent,
        participant_registry=ParticipantRegistry([]),
        session=supplied_session,
        max_rounds=2,
    )
    orch._cache = [Message(role="assistant", contents=["x"], author_name="agent")]

    with caplog.at_level(logging.WARNING, logger="agent_framework_orchestrations._group_chat"):
        await orch.reset()

    # The cache is cleared regardless of who owns the session.
    assert orch._cache == []
    # The caller-owned session object must be the very same instance.
    assert orch._session is supplied_session

    warned = any(
        record.levelno == logging.WARNING and "Session was supplied by the caller" in record.message
        for record in caplog.records
    )
    assert warned, f"expected a warning about caller-supplied session, got: {[r.message for r in caplog.records]}"
|
||||
|
||||
|
||||
async def test_workflow_reset_resets_group_chat_orchestrator() -> None:
    """End-to-end check that workflow.reset_for_new_run() clears the orchestrator's conversation state."""
    pick_next = make_sequence_selector()
    participants = [
        StubAgent("alpha", "ack from alpha"),
        StubAgent("beta", "ack from beta"),
    ]

    builder = GroupChatBuilder(
        participants=participants,
        max_rounds=2,
        selection_func=pick_next,
        orchestrator_name="manager",
    )
    workflow = builder.build()

    # Drain the first run so the orchestrator accumulates state.
    async for _ in workflow.run("first task", stream=True):
        pass

    orch = cast(BaseGroupChatOrchestrator, workflow.executors[GroupChatBuilder.DEFAULT_ORCHESTRATOR_ID])
    assert orch._full_conversation, "orchestrator should have accumulated conversation after first run"
    assert orch._round_index > 0

    await workflow.reset_for_new_run()

    assert orch._full_conversation == []
    assert orch._round_index == 0
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -1243,135 +1243,3 @@ def test_standard_manager_checkpoint_restore_empty_state():
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Manager Reset Tests
|
||||
|
||||
|
||||
def test_magentic_manager_base_on_reset_is_noop_by_default():
    """The inherited MagenticManagerBase.on_reset() leaves manager state untouched."""
    manager = FakeManager()

    # Give the fake a ledger so an accidental clear would be noticed.
    seeded_ledger = _SimpleLedger(
        facts=Message("assistant", ["facts"]),
        plan=Message("assistant", ["plan"]),
    )
    manager.task_ledger = seeded_ledger

    manager.on_reset()  # base implementation is a no-op

    assert manager.task_ledger is not None
|
||||
|
||||
|
||||
def test_standard_manager_on_reset_clears_ledger_and_rotates_session():
    """StandardMagenticManager.on_reset() drops the cached ledger and installs a new session."""
    from agent_framework_orchestrations._magentic import _MagenticTaskLedger  # type: ignore[reportPrivateUsage]

    manager = StandardMagenticManager(agent=StubManagerAgent())
    manager.task_ledger = _MagenticTaskLedger(
        facts=Message("assistant", ["facts"]),
        plan=Message("assistant", ["plan"]),
    )
    session_before = manager._session

    manager.on_reset()

    assert manager.task_ledger is None
    # The session must be a different object with a different identifier.
    session_after = manager._session
    assert session_after is not session_before
    assert session_after.session_id != session_before.session_id
|
||||
|
||||
|
||||
async def test_magentic_orchestrator_reset_invokes_manager_on_reset() -> None:
    """reset() wipes the orchestrator's Magentic state and calls manager.on_reset() once."""
    from agent_framework_orchestrations._base_group_chat_orchestrator import (
        ParticipantRegistry,  # type: ignore[reportPrivateUsage]
    )

    class TrackingManager(FakeManager):
        """Manager that counts on_reset() invocations on the class itself."""

        reset_calls: int = 0

        @override
        def on_reset(self) -> None:
            type(self).reset_calls += 1

    def ledger_item(answer: Any) -> MagenticProgressLedgerItem:
        """Build a progress-ledger item with a fixed placeholder reason."""
        return MagenticProgressLedgerItem(reason="r", answer=answer)

    orch = MagenticOrchestrator(
        manager=TrackingManager(),
        participant_registry=ParticipantRegistry([]),
    )

    # Populate every Magentic-specific field so clearing is observable.
    orch._magentic_context = MagenticContext(  # type: ignore[reportPrivateUsage]
        task="task",
        participant_descriptions={"agentA": "desc"},
    )
    orch._task_ledger = Message("assistant", ["ledger"])  # type: ignore[reportPrivateUsage]
    orch._progress_ledger = MagenticProgressLedger(  # type: ignore[reportPrivateUsage]
        is_request_satisfied=ledger_item(False),
        is_in_loop=ledger_item(False),
        is_progress_being_made=ledger_item(True),
        next_speaker=ledger_item("agentA"),
        instruction_or_question=ledger_item("do"),
    )
    orch._terminated = True  # type: ignore[reportPrivateUsage]

    await orch.reset()

    assert orch._magentic_context is None  # type: ignore[reportPrivateUsage]
    assert orch._task_ledger is None  # type: ignore[reportPrivateUsage]
    assert orch._progress_ledger is None  # type: ignore[reportPrivateUsage]
    assert orch._terminated is False  # type: ignore[reportPrivateUsage]
    assert TrackingManager.reset_calls == 1
|
||||
|
||||
|
||||
async def test_magentic_orchestrator_reset_propagates_manager_on_reset_failure() -> None:
    """An exception raised by manager.on_reset() must surface from orchestrator.reset()."""
    from agent_framework_orchestrations._base_group_chat_orchestrator import (
        ParticipantRegistry,  # type: ignore[reportPrivateUsage]
    )

    class FailingManager(FakeManager):
        """Manager whose reset hook always fails."""

        @override
        def on_reset(self) -> None:
            raise RuntimeError("boom")

    orch = MagenticOrchestrator(
        manager=FailingManager(),
        participant_registry=ParticipantRegistry([]),
    )

    with pytest.raises(RuntimeError, match="boom"):
        await orch.reset()
|
||||
|
||||
|
||||
async def test_workflow_reset_resets_magentic_orchestrator_and_manager() -> None:
    """End-to-end check that workflow.reset_for_new_run() clears Magentic orchestrator state."""
    manager = FakeManager()
    manager.task_ledger = _SimpleLedger(
        facts=Message("assistant", ["seeded facts"]),
        plan=Message("assistant", ["seeded plan"]),
    )
    speaker = StubAgent(manager.next_speaker_name, "first draft")
    workflow = MagenticBuilder(participants=[speaker], manager=manager).build()

    # Drain one full run so the orchestrator reaches its terminated state.
    async for _ in workflow.run("first task", stream=True):
        pass

    orch = next(
        candidate for candidate in workflow.executors.values() if isinstance(candidate, MagenticOrchestrator)
    )
    assert orch._terminated is True  # type: ignore[reportPrivateUsage]
    assert orch._task_ledger is not None  # type: ignore[reportPrivateUsage]
    assert manager.task_ledger is not None

    await workflow.reset_for_new_run()

    assert orch._magentic_context is None  # type: ignore[reportPrivateUsage]
    assert orch._task_ledger is None  # type: ignore[reportPrivateUsage]
    assert orch._progress_ledger is None  # type: ignore[reportPrivateUsage]
    assert orch._terminated is False  # type: ignore[reportPrivateUsage]
    # FakeManager inherits the no-op on_reset, so manager state is intentionally
    # not asserted here; this only confirms the orchestrator copes with that case.
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -169,7 +169,6 @@ callers can still inspect progress or supporting work from the response messages
|
||||
| Sample | File | Concepts |
|
||||
| -------------------------------- | ------------------------------------------------------------------------------------------------ | ----------------------------------------------------------------- |
|
||||
| State with Agents | [state-management/state_with_agents.py](./state-management/state_with_agents.py) | Store in state once and later reuse across agents |
|
||||
| Reset Workflow Between Runs | [state-management/workflow_reset.py](./state-management/workflow_reset.py) | Reuse one workflow instance across independent runs via `reset_for_new_run()` |
|
||||
| Workflow Kwargs - Global Context | [state-management/workflow_kwargs_global.py](./state-management/workflow_kwargs_global.py) | Pass custom context (data, user tokens) via kwargs to `@tool` tools in all agents |
|
||||
| Workflow Kwargs - Per Agent | [state-management/workflow_kwargs_per_agent.py](./state-management/workflow_kwargs_per_agent.py) | Pass custom context (data, user tokens) via kwargs to `@tool` tools in individual agents |
|
||||
|
||||
|
||||
@@ -1,212 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agent_framework import Executor, Workflow, WorkflowBuilder, WorkflowContext, handler
|
||||
from typing_extensions import Never, override
|
||||
|
||||
"""
|
||||
Sample: Reusing a Workflow across independent jobs with reset_for_new_run().
|
||||
|
||||
Build a small moderation pipeline that silently accumulates stats as messages
|
||||
flow through it and emits a summary only when the caller asks for one. Drive
|
||||
the same workflow instance across multiple independent jobs and show that
|
||||
``Workflow.reset_for_new_run()`` clears all per-executor state without
|
||||
rebuilding the graph.
|
||||
|
||||
Two custom executors share the work, each with its own per-job state and its
|
||||
own ``reset()`` override:
|
||||
|
||||
- ``FlaggedKeywordCounter`` is the start executor. It accepts message strings
|
||||
to inspect (silently updating local stats, sending nothing downstream) and
|
||||
``ReportRequest`` markers that cause it to forward a ``StatsSnapshot``.
|
||||
- ``StatsReporter`` formats the snapshot, increments its own emitted-reports
|
||||
counter, and yields the summary as the workflow output.
|
||||
|
||||
A run with a string produces no output, just a state update. A run with a
|
||||
``ReportRequest`` produces exactly one summary. Job boundaries are entirely
|
||||
controlled by the caller via ``reset_for_new_run()``, which calls ``reset()``
|
||||
on every executor in the graph.
|
||||
|
||||
Purpose:
|
||||
Show how to:
|
||||
- Hold per-job aggregate state on a custom Executor subclass.
|
||||
- Override ``Executor.reset()`` on every executor that owns per-run state, so
|
||||
it is cleared automatically when the workflow is reset.
|
||||
- Call ``Workflow.reset_for_new_run()`` between independent jobs so a single
|
||||
workflow instance can serve a stream of unrelated batches without leaking
|
||||
state.
|
||||
|
||||
Prerequisites:
|
||||
- No external services or credentials required; this sample runs entirely in-process.
|
||||
- Familiarity with WorkflowBuilder and Executor subclasses.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
class ReportRequest:
    """Marker input that asks the workflow to emit a summary of stats so far.

    Carries no fields; only its type matters, since the start executor's
    ``emit_report`` handler is declared to accept a ``ReportRequest``.
    """
|
||||
|
||||
|
||||
@dataclass
class StatsSnapshot:
    """Snapshot the counter forwards to the reporter when a report is requested."""

    # Number of messages inspected so far.
    messages_seen: int
    # Number of inspected messages that contained at least one flagged keyword.
    flagged_messages: int
    # Distinct flagged keywords observed; the counter sends them sorted.
    flagged_keywords: list[str]
|
||||
|
||||
|
||||
class FlaggedKeywordCounter(Executor):
    """Start executor that quietly tallies per-job stats and reports on request.

    State accumulated across the runs that make up one job:

    - ``_messages_seen``: count of messages inspected.
    - ``_flagged_messages``: count of messages containing any flagged keyword.
    - ``_flagged_keywords``: distinct flagged keywords actually observed.

    Handlers, selected by input type:

    - ``inspect`` takes a string, updates the tallies, and sends nothing.
    - ``emit_report`` takes a ``ReportRequest`` and sends the current
      ``StatsSnapshot`` to the downstream reporter.

    ``reset()`` is overridden to clear the tallies; without it they would carry
    over into the next job when the workflow is reused via
    ``Workflow.reset_for_new_run()``.
    """

    FLAGGED_KEYWORDS = frozenset({"spam", "scam", "phishing"})

    def __init__(self, id: str) -> None:
        super().__init__(id=id)
        self._messages_seen: int = 0
        self._flagged_messages: int = 0
        self._flagged_keywords: set[str] = set()

    @handler
    async def inspect(self, message: str, ctx: WorkflowContext[StatsSnapshot]) -> None:
        """Update the local tallies for ``message``; nothing is sent downstream."""
        self._messages_seen += 1

        # Case-insensitive substring match against each flagged keyword.
        lowered = message.lower()
        matched = {keyword for keyword in self.FLAGGED_KEYWORDS if keyword in lowered}
        if not matched:
            return

        self._flagged_messages += 1
        self._flagged_keywords |= matched

    @handler
    async def emit_report(self, _: ReportRequest, ctx: WorkflowContext[StatsSnapshot]) -> None:
        """Send a snapshot of the current tallies to the reporter."""
        snapshot = StatsSnapshot(
            messages_seen=self._messages_seen,
            flagged_messages=self._flagged_messages,
            # Sorted so the reported keyword order is stable.
            flagged_keywords=sorted(self._flagged_keywords),
        )
        await ctx.send_message(snapshot)

    @override
    async def reset(self) -> None:
        """Zero the per-job tallies when the workflow is reset.

        ``Workflow.reset_for_new_run()`` invokes ``reset()`` on each executor in
        the graph, so this override is what lets a reused workflow start a new
        job with empty stats.
        """
        self._messages_seen = 0
        self._flagged_messages = 0
        self._flagged_keywords.clear()
|
||||
|
||||
|
||||
class StatsReporter(Executor):
    """Final executor that renders a snapshot as text and yields it as workflow output.

    Keeps one piece of state, ``_reports_emitted``, counting the summaries
    produced on this workflow instance. ``reset()`` zeroes it so a reset
    workflow behaves like a newly built one.
    """

    def __init__(self, id: str) -> None:
        super().__init__(id=id)
        self._reports_emitted: int = 0

    @handler
    async def report(self, snapshot: StatsSnapshot, ctx: WorkflowContext[Never, str]) -> None:
        """Format ``snapshot`` into a one-line summary and yield it."""
        # Count this report first so the summary includes itself.
        self._reports_emitted += 1
        # An empty keyword list is rendered as the word "none".
        keywords_shown = snapshot.flagged_keywords or 'none'
        text = (
            f"messages={snapshot.messages_seen}, "
            f"flagged={snapshot.flagged_messages}, "
            f"keywords={keywords_shown}, "
            f"reports_emitted={self._reports_emitted}"
        )
        await ctx.yield_output(text)

    @override
    async def reset(self) -> None:
        """Zero the emitted-reports counter when the workflow is reset."""
        self._reports_emitted = 0
|
||||
|
||||
|
||||
async def _process(workflow: Workflow, messages: list[str]) -> None:
    """Run the workflow once per message, in order; the run results are discarded."""
    for text in messages:
        # One sequential run per message.
        await workflow.run(text)
|
||||
|
||||
|
||||
async def _request_report(workflow: Workflow) -> str:
    """Run the workflow with a ``ReportRequest`` and return its first output.

    Returns an empty string when the run yields no outputs.
    """
    run_result = await workflow.run(ReportRequest())
    produced = run_result.get_outputs()
    if not produced:
        return ""
    return produced[0]
|
||||
|
||||
|
||||
async def main() -> None:
    """Drive one moderation workflow instance through several jobs, with and without reset."""

    # Message batches used below. Batch B is sent twice (once before and once
    # after a reset) so the two summaries can be compared directly.
    batch_a = ["hello there", "free phishing kit", "lunch plans?"]
    batch_b = ["weekly status update", "team offsite agenda", "quarterly review"]
    batch_c = ["spam offer #1", "scam alert", "phishing attempt"]

    # 1. Assemble the pipeline a single time: counter -> reporter, with the
    #    reporter as the output source. This one workflow object serves every
    #    job below; reusing it is what the sample demonstrates.
    counter = FlaggedKeywordCounter(id="counter")
    reporter = StatsReporter(id="reporter")
    builder = WorkflowBuilder(start_executor=counter, output_from=[reporter])
    workflow = builder.add_edge(counter, reporter).build()

    # 2. Job one: feed batch A, then ask for a report. The batch has three
    #    messages here, but nothing depends on that size.
    await _process(workflow, batch_a)
    summary_a = await _request_report(workflow)
    print(f"Batch A summary: {summary_a}")

    # 3. Job two, deliberately WITHOUT a reset. Batch A's state carries over:
    #    the counter's tallies and the reporter's emitted-reports counter keep
    #    growing even though batch B is meant to be an unrelated job.
    await _process(workflow, batch_b)
    summary_b_leaky = await _request_report(workflow)
    print(f"Batch B summary (no reset): {summary_b_leaky}")

    # 4. Reset, then replay batch B. This time the summary covers batch B only
    #    and every per-run counter starts from zero, because
    #    reset_for_new_run() calls reset() on every executor in the graph:
    #      - FlaggedKeywordCounter clears its message / flag / keyword tallies.
    #      - StatsReporter clears its emitted-reports counter.
    await workflow.reset_for_new_run()
    await _process(workflow, batch_b)
    summary_b_clean = await _request_report(workflow)
    print(f"Batch B summary (after reset): {summary_b_clean}")

    # 5. Reset once more ahead of a final, unrelated job. State-wise a reset
    #    workflow matches a freshly built one, while the graph and executor
    #    objects are reused instead of rebuilt.
    await workflow.reset_for_new_run()
    await _process(workflow, batch_c)
    summary_c = await _request_report(workflow)
    print(f"Batch C summary (after reset): {summary_c}")

    """
    Sample Output:

    Batch A summary: messages=3, flagged=1, keywords=['phishing'], reports_emitted=1
    Batch B summary (no reset): messages=6, flagged=1, keywords=['phishing'], reports_emitted=2
    Batch B summary (after reset): messages=3, flagged=0, keywords=none, reports_emitted=1
    Batch C summary (after reset): messages=3, flagged=3, keywords=['phishing', 'scam', 'spam'], reports_emitted=1
    """
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Only start the event loop when run as a script, so importing this module
    # has no side effects.
    asyncio.run(main())
|
||||
Reference in New Issue
Block a user