diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index fb4cb62a54..d96929ce79 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -166,10 +166,6 @@ class AgentExecutor(Executor): raise ValueError("Agent must have a non-empty name or id or an explicit id must be provided.") super().__init__(exec_id) self._agent = agent - # Track whether the caller supplied a session so reset() can preserve their - # session reference (which may be wired to external/service-side storage) - # and only replace sessions we created ourselves. - self._session_supplied_by_caller = session is not None self._session = session or self._agent.create_session() self._pending_agent_requests: dict[str, Content] = {} @@ -369,37 +365,10 @@ class AgentExecutor(Executor): pending_responses_payload = state.get("pending_responses_to_agent") self._pending_responses_to_agent = pending_responses_payload or [] - @override - async def reset(self) -> None: - """Reset the executor to its initial state for a new workflow run. - - Clears the message cache, full conversation snapshot, and any pending - user-input request/response bookkeeping. - - Session handling: - * If the session was created by this executor (no ``session`` argument - was passed to ``__init__``), it is replaced with a fresh one via - ``agent.create_session()`` so prior conversation history does not - leak into the next run. - * If the session was supplied by the caller, it is left untouched. - The caller owns the session lifecycle (it may be backed by - service-side or external storage) and is responsible for clearing - or rotating it if a clean slate is desired. - """ - logger.debug("AgentExecutor %s: Resetting state", self.id) + def reset(self) -> None: + """Reset the internal cache of the executor.""" + logger.debug("AgentExecutor %s: Resetting cache", self.id) self._cache.clear() - self._full_conversation.clear() - self._pending_agent_requests.clear() - self._pending_responses_to_agent.clear() - if not self._session_supplied_by_caller: - self._session = self._agent.create_session() - else: - logger.warning( - "AgentExecutor %s: Session was supplied by the caller and will not be reset. " - "Prior conversation history retained in the session may leak into the next run. " - "Reset or rotate the session externally if a clean slate is required.", - self.id, - ) async def _run_agent_and_emit( self, diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index d562a925b2..f57102b2bc 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -516,14 +516,6 @@ class Executor(RequestInfoMixin, DictConvertible): """ ... - async def reset(self) -> None: - """Reset the executor to its initial state. - - Override this method in subclasses to implement custom logic that should - run when the workflow is reset. - """ - ... - # endregion: Executor diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index 28f1bec7d5..d36c53186e 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -88,23 +88,6 @@ class Runner: """ self._iteration = 0 - async def reset_for_new_run(self) -> None: - """Reset the runner for a new run. - - This is useful when reusing the same workflow instance for a different run - that is independent from prior runs. - - Concurrent-run rejection lives at the :class:`Workflow` level via the - active-run weak reference, so this method does not re-validate that - no run is in progress; callers are expected to have already done so. - """ - self.reset_iteration_count() - self._ctx.reset_for_new_run() - self._state.clear() - self._resumed_from_checkpoint = False - for executor in self._executors.values(): - await executor.reset() - async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: """Run the workflow until no more messages are sent.""" # Emit any events already produced prior to entering loop diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index 5c9a78dd75..0e0d9fe6cd 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Literal, overload from .._sessions import ContextProvider from .._types import ResponseStream -from ..exceptions import WorkflowCheckpointException, WorkflowException, WorkflowRunnerException +from ..exceptions import WorkflowCheckpointException, WorkflowException from ..observability import OtelAttr, capture_exception, create_workflow_span from ._checkpoint import CheckpointID, CheckpointStorage from ._const import DEFAULT_MAX_ITERATIONS, GLOBAL_KWARGS_KEY, WORKFLOW_RUN_KWARGS_KEY @@ -568,13 +568,12 @@ class Workflow(DictConvertible): yield in_progress # noqa: RUF070 # Per-run reset for fresh-message runs only. We deliberately - # do NOT clear shared workflow state (`_state.clear()`) or the - # runner context's in-flight messages (`reset_for_new_run()`) - # here - state and pending work persist across `run()` calls - # so that a `WorkflowAgent` can deliver multi-turn input on - # the same instance and have prior turns' context survive. - # Iteration counting and per-run kwargs ARE per-run though, - # so they're reset here. + # do NOT clear shared workflow state or the runner context's + # in-flight messages here - state and pending work persist + # across `run()` calls so that a `WorkflowAgent` can deliver + # multi-turn input on the same instance and have prior turns' + # context survive. Iteration counting and per-run kwargs ARE + # per-run though, so they're reset here. if not is_continuation: self._runner.reset_iteration_count() @@ -1210,24 +1209,3 @@ class Workflow(DictConvertible): """ existing_stream = self._active_run() if self._active_run is not None else None return existing_stream is not None - - async def reset_for_new_run(self) -> None: - """Reset the workflow for a new run that is independent from prior runs. - - Note: - This will reset EVERYTHING - executor states, workflow state, and runner - context (including pending requests/messages). - - Raises: - WorkflowRunnerException: If a run is currently in progress. Reset is only - allowed when the workflow is idle to avoid clobbering in-flight run state. - - """ - existing_stream = self._active_run() if self._active_run is not None else None - if existing_stream is not None: - raise WorkflowRunnerException( - "Cannot reset the workflow while a run is in progress. " - "Wait for the current run to complete before calling reset_for_new_run()." - ) - self._active_run = None - await self._runner.reset_for_new_run() diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index 7868651b47..e131533429 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -534,19 +534,6 @@ class WorkflowExecutor(Executor): for event in request_info_events ]) - @override - async def reset(self) -> None: - """Reset the WorkflowExecutor to its initial state for a new workflow run. - - Clears in-flight execution contexts and the request-to-execution mapping, - then recursively resets the wrapped sub-workflow so its executors, runner - context, and shared state are also returned to a clean state. - """ - logger.debug("WorkflowExecutor %s: Resetting state", self.id) - self._execution_contexts.clear() - self._request_to_execution.clear() - await self.workflow.reset_for_new_run() - async def _process_workflow_result( self, result: WorkflowRunResult, diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index e324d52d0d..c9004f234b 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -1,6 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. -import logging from collections.abc import AsyncIterable, Awaitable from typing import Any, Literal, overload @@ -20,7 +19,7 @@ from agent_framework import ( WorkflowEvent, WorkflowRunState, ) -from agent_framework._workflows._agent_executor import AgentExecutorRequest, AgentExecutorResponse +from agent_framework._workflows._agent_executor import AgentExecutorResponse from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage from agent_framework._workflows._const import GLOBAL_KWARGS_KEY @@ -307,108 +306,6 @@ async def test_agent_executor_save_and_restore_state_directly() -> None: assert restored_session.session_id == session.session_id -# region: Tests for AgentExecutor.reset() - - -async def test_agent_executor_reset_clears_per_run_state() -> None: - """reset() clears cache, conversation snapshot, and pending request/response buffers.""" - agent = _CountingAgent(id="reset_agent", name="ResetAgent") - executor = AgentExecutor(agent, id="reset_exec") - - # Populate every per-run buffer. - executor._cache = [Message(role="user", contents=["cached"])] # type: ignore[reportPrivateUsage] - executor._full_conversation = [ # type: ignore[reportPrivateUsage] - Message(role="user", contents=["prior turn"]), - Message(role="assistant", contents=["prior response"]), - ] - pending_request = Content.from_text(text="approve?") - executor._pending_agent_requests = {"req-1": pending_request} # type: ignore[reportPrivateUsage] - executor._pending_responses_to_agent = [Content.from_text(text="approved")] # type: ignore[reportPrivateUsage] - - await executor.reset() - - assert executor._cache == [] # type: ignore[reportPrivateUsage] - assert executor._full_conversation == [] # type: ignore[reportPrivateUsage] - assert executor._pending_agent_requests == {} # type: ignore[reportPrivateUsage] - assert executor._pending_responses_to_agent == [] # type: ignore[reportPrivateUsage] - - -async def test_agent_executor_reset_creates_fresh_session_when_auto_created() -> None: - """reset() replaces the agent session when the executor created it itself.""" - agent = _CountingAgent(id="reset_session_agent", name="ResetSessionAgent") - # No session passed in — executor creates one via agent.create_session(). - executor = AgentExecutor(agent, id="reset_session_exec") - auto_created = executor._session # type: ignore[reportPrivateUsage] - auto_created.state["history"] = {"messages": [Message(role="user", contents=["old"])]} - - await executor.reset() - - new_session = executor._session # type: ignore[reportPrivateUsage] - assert new_session is not auto_created - assert new_session.session_id != auto_created.session_id - assert "history" not in new_session.state - - -async def test_agent_executor_reset_preserves_caller_supplied_session(caplog: pytest.LogCaptureFixture) -> None: - """reset() leaves a session passed in via __init__ untouched and warns the caller.""" - agent = _CountingAgent(id="reset_session_agent", name="ResetSessionAgent") - caller_session = AgentSession() - history_payload = {"messages": [Message(role="user", contents=["old"])]} - caller_session.state["history"] = history_payload - executor = AgentExecutor(agent, id="reset_session_exec", session=caller_session) - - assert executor._session is caller_session # type: ignore[reportPrivateUsage] - - with caplog.at_level(logging.WARNING, logger="agent_framework._workflows._agent_executor"): - await executor.reset() - - # Same instance, state untouched — the caller is responsible for managing the session. - assert executor._session is caller_session # type: ignore[reportPrivateUsage] - assert caller_session.state["history"] is history_payload - assert any("Session was supplied by the caller" in record.message for record in caplog.records) - - -async def test_agent_executor_reset_allows_subsequent_run() -> None: - """After reset(), the executor can be reused for a fresh workflow run without leaking state.""" - agent = _CountingAgent(id="reset_reuse_agent", name="ResetReuseAgent") - executor = AgentExecutor(agent, id="reset_reuse_exec") - workflow = WorkflowBuilder(start_executor=executor, output_from=[executor]).build() - - first_outputs: list[WorkflowEvent] = [] - async for event in workflow.run( - AgentExecutorRequest(messages=[Message(role="user", contents=["hello"])]), - stream=True, - ): - if event.type == "output": - first_outputs.append(event) - - assert first_outputs, "first run should have produced at least one output event" - # After a normal run the cache is drained but the conversation snapshot remains. - assert executor._cache == [] # type: ignore[reportPrivateUsage] - assert executor._full_conversation != [] # type: ignore[reportPrivateUsage] - first_session_id = executor._session.session_id # type: ignore[reportPrivateUsage] - - await workflow.reset_for_new_run() - - assert executor._full_conversation == [] # type: ignore[reportPrivateUsage] - # Session was auto-created, so reset() rotates it to a fresh one. - assert executor._session.session_id != first_session_id # type: ignore[reportPrivateUsage] - - second_outputs: list[WorkflowEvent] = [] - async for event in workflow.run( - AgentExecutorRequest(messages=[Message(role="user", contents=["second"])]), - stream=True, - ): - if event.type == "output": - second_outputs.append(event) - - assert second_outputs, "second run after reset should have produced at least one output event" - assert agent.call_count == 2 - - -# endregion: Tests for AgentExecutor.reset() - - async def test_prepare_agent_run_args_extracts_invocation_kwargs() -> None: """_prepare_agent_run_args extracts function_invocation_kwargs and client_kwargs.""" agent = _CountingAgent(id="test_agent", name="TestAgent") diff --git a/python/packages/core/tests/workflow/test_executor.py b/python/packages/core/tests/workflow/test_executor.py index 6908900515..df742b76f6 100644 --- a/python/packages/core/tests/workflow/test_executor.py +++ b/python/packages/core/tests/workflow/test_executor.py @@ -1004,53 +1004,3 @@ def test_handler_typevar_error_takes_priority_over_context_error(): @handler async def process(self, message: _T, ctx) -> None: # type: ignore[no-untyped-def] pass - - -# region: Tests for Executor.reset() - - -async def test_executor_default_reset_is_noop(): - """The base Executor.reset() is a no-op and must complete without raising. - - Subclasses that don't carry reset-relevant state should be able to rely on the - default implementation. - """ - - class StatelessExecutor(Executor): - @handler - async def handle(self, message: str, ctx: WorkflowContext[str]) -> None: - await ctx.send_message(message) - - executor_instance = StatelessExecutor(id="stateless") - - # Must complete without raising and return None. - assert await executor_instance.reset() is None - - -async def test_executor_subclass_reset_is_invoked(): - """A subclass that overrides reset() can clear its own internal state.""" - - class CounterExecutor(Executor): - def __init__(self, id: str) -> None: - super().__init__(id=id) - self.counter = 0 - self.reset_calls = 0 - - @handler - async def handle(self, message: int, ctx: WorkflowContext[int]) -> None: - self.counter += message - - async def reset(self) -> None: - self.counter = 0 - self.reset_calls += 1 - - executor_instance = CounterExecutor(id="counter") - executor_instance.counter = 42 - - await executor_instance.reset() - - assert executor_instance.counter == 0 - assert executor_instance.reset_calls == 1 - - -# endregion: Tests for Executor.reset() diff --git a/python/packages/core/tests/workflow/test_runner.py b/python/packages/core/tests/workflow/test_runner.py index f6b36ebe32..5c458be7c4 100644 --- a/python/packages/core/tests/workflow/test_runner.py +++ b/python/packages/core/tests/workflow/test_runner.py @@ -1111,204 +1111,3 @@ async def test_runner_drains_straggler_events_at_iteration_end(): output_events = [e for e in events if e.type == "output"] # We should have output events from both executors assert len(output_events) >= 2 - - -# region: Tests for InProcRunnerContext.reset_for_new_run() - - -async def test_runner_context_reset_clears_in_flight_messages(): - """reset_for_new_run drops queued executor-to-executor messages.""" - ctx = InProcRunnerContext() - await ctx.send_message(WorkflowMessage(data=MockMessage(data=1), source_id="src")) - - assert await ctx.has_messages() is True - - ctx.reset_for_new_run() - - assert await ctx.has_messages() is False - assert await ctx.drain_messages() == {} - - -async def test_runner_context_reset_drains_pending_events(): - """reset_for_new_run discards any events buffered for streaming.""" - ctx = InProcRunnerContext() - await ctx.add_event(WorkflowEvent.superstep_started(iteration=1)) - assert await ctx.has_events() is True - - ctx.reset_for_new_run() - - assert await ctx.has_events() is False - assert await ctx.drain_events() == [] - - -async def test_runner_context_reset_resets_streaming_flag(): - """reset_for_new_run resets streaming back to its non-streaming default.""" - ctx = InProcRunnerContext() - ctx.set_streaming(True) - assert ctx.is_streaming() is True - - ctx.reset_for_new_run() - - assert ctx.is_streaming() is False - - -async def test_runner_context_reset_clears_pending_request_info_events(): - """reset_for_new_run clears any pending request_info events tracked for correlation.""" - ctx = InProcRunnerContext() - request_info_event = WorkflowEvent.request_info( - request_id="request-123", - source_executor_id="source", - request_data=MockMessage(data=0), - response_type=bool, - ) - await ctx.add_request_info_event(request_info_event) - - assert "request-123" in await ctx.get_pending_request_info_events() - - ctx.reset_for_new_run() - - assert await ctx.get_pending_request_info_events() == {} - - -# endregion: Tests for InProcRunnerContext.reset_for_new_run() - - -# region: Tests for Runner.reset_for_new_run() - - -async def test_runner_reset_for_new_run_resets_iteration_count(): - """reset_for_new_run resets the iteration counter back to zero.""" - runner = _make_runner() - runner._iteration = 7 # pyright: ignore[reportPrivateUsage] - - await runner.reset_for_new_run() - - assert runner._iteration == 0 # pyright: ignore[reportPrivateUsage] - - -async def test_runner_reset_for_new_run_clears_shared_state(): - """reset_for_new_run wipes both committed and pending entries from shared state.""" - state = State() - state.set("committed_key", "committed_value") - state.commit() - state.set("pending_key", "pending_value") # uncommitted - - runner = Runner( - [], - {}, - state, - InProcRunnerContext(), - "test_name", - graph_signature_hash="test_hash", - ) - - await runner.reset_for_new_run() - - assert state.get("committed_key") is None - assert state.get("pending_key") is None - assert state.has("committed_key") is False - assert state.has("pending_key") is False - - -async def test_runner_reset_for_new_run_clears_resumed_from_checkpoint_flag(): - """reset_for_new_run clears the flag set by restore_from_checkpoint.""" - runner = _make_runner() - resumed_checkpoint = WorkflowCheckpoint( - checkpoint_id="resumed-cp", - workflow_name="test_name", - graph_signature_hash="test_hash", - iteration_count=5, - ) - runner._mark_resumed(resumed_checkpoint) # pyright: ignore[reportPrivateUsage] - assert runner._resumed_from_checkpoint is True # pyright: ignore[reportPrivateUsage] - - await runner.reset_for_new_run() - - assert runner._resumed_from_checkpoint is False # pyright: ignore[reportPrivateUsage] - # And the iteration count restored from the checkpoint must be wiped, too. - assert runner._iteration == 0 # pyright: ignore[reportPrivateUsage] - - -async def test_runner_reset_for_new_run_invokes_executor_reset_for_each_executor(): - """reset_for_new_run calls reset() on every registered executor exactly once.""" - - class TrackingExecutor(MockExecutor): - def __init__(self, id: str) -> None: - super().__init__(id=id) - self.reset_calls = 0 - - async def reset(self) -> None: - self.reset_calls += 1 - - executor_a = TrackingExecutor(id="executor_a") - executor_b = TrackingExecutor(id="executor_b") - - runner = Runner( - [], - {executor_a.id: executor_a, executor_b.id: executor_b}, - State(), - InProcRunnerContext(), - "test_name", - graph_signature_hash="test_hash", - ) - - await runner.reset_for_new_run() - - assert executor_a.reset_calls == 1 - assert executor_b.reset_calls == 1 - - -async def test_runner_reset_for_new_run_resets_runner_context(): - """reset_for_new_run forwards the reset to the underlying runner context.""" - ctx = InProcRunnerContext() - await ctx.send_message(WorkflowMessage(data=MockMessage(data=0), source_id="src")) - await ctx.add_event(WorkflowEvent.superstep_started(iteration=1)) - ctx.set_streaming(True) - - runner = Runner([], {}, State(), ctx, "test_name", graph_signature_hash="test_hash") - - await runner.reset_for_new_run() - - assert await ctx.has_messages() is False - assert await ctx.has_events() is False - assert ctx.is_streaming() is False - - -async def test_runner_can_run_again_after_reset_for_new_run(): - """After reset_for_new_run the runner can be reserved and converge a fresh workload.""" - executor_a = MockExecutor(id="executor_a") - executor_b = MockExecutor(id="executor_b") - - edges = [ - SingleEdgeGroup(executor_a.id, executor_b.id), - SingleEdgeGroup(executor_b.id, executor_a.id), - ] - executors: dict[str, Executor] = { - executor_a.id: executor_a, - executor_b.id: executor_b, - } - state = State() - ctx = InProcRunnerContext() - - runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash") - - # First run: drives MockExecutor's loop until it yields the terminal value. - await executor_a.execute(MockMessage(data=0), ["START"], state, ctx) - async for _ in runner.run_until_convergence(): - pass - assert runner._iteration == 10 # pyright: ignore[reportPrivateUsage] - - await runner.reset_for_new_run() - - # Second run: must succeed cleanly using the same runner instance. - await executor_a.execute(MockMessage(data=0), ["START"], state, ctx) - second_run_outputs: list[int] = [] - async for event in runner.run_until_convergence(): - if event.type == "output": - second_run_outputs.append(event.data) - - assert second_run_outputs == [10] - assert runner._iteration == 10 # pyright: ignore[reportPrivateUsage] - - -# endregion: Tests for Runner.reset_for_new_run() diff --git a/python/packages/core/tests/workflow/test_sub_workflow.py b/python/packages/core/tests/workflow/test_sub_workflow.py index 17b7e05961..7bf38a06f3 100644 --- a/python/packages/core/tests/workflow/test_sub_workflow.py +++ b/python/packages/core/tests/workflow/test_sub_workflow.py @@ -689,89 +689,3 @@ async def test_sub_workflow_intermediate_outputs_propagate_to_parent() -> None: # The parent's own terminal output is unaffected. assert any(e.executor_id == "parent_sink" and e.data == "final: hello" for e in output_events) - - -# region: Tests for WorkflowExecutor.reset() - - -async def test_workflow_executor_reset_clears_execution_state() -> None: - """reset() clears the WorkflowExecutor's per-run execution contexts and request mappings.""" - validation_workflow = create_email_validation_workflow() - parent = Coordinator() - workflow_executor = WorkflowExecutor(validation_workflow, "email_validation_workflow") - - main_workflow = ( - WorkflowBuilder(start_executor=parent) - .add_edge(parent, workflow_executor) - .add_edge(workflow_executor, parent) - .build() - ) - - # First run pauses with a pending request from the sub-workflow. - result = await main_workflow.run("test@example.com") - assert len(result.get_request_info_events()) == 1 - assert len(workflow_executor._execution_contexts) == 1 # type: ignore[reportPrivateUsage] - assert len(workflow_executor._request_to_execution) == 1 # type: ignore[reportPrivateUsage] - - await main_workflow.reset_for_new_run() - - assert workflow_executor._execution_contexts == {} # type: ignore[reportPrivateUsage] - assert workflow_executor._request_to_execution == {} # type: ignore[reportPrivateUsage] - - -async def test_workflow_executor_reset_resets_wrapped_workflow() -> None: - """reset() recursively resets the wrapped workflow (runner iteration counter cleared).""" - validation_workflow = create_email_validation_workflow() - parent = Coordinator() - workflow_executor = WorkflowExecutor(validation_workflow, "email_validation_workflow") - - main_workflow = ( - WorkflowBuilder(start_executor=parent) - .add_edge(parent, workflow_executor) - .add_edge(workflow_executor, parent) - .build() - ) - - await main_workflow.run("test@example.com") - # The sub-workflow's runner advanced past iteration 0 during execution. - assert validation_workflow._runner._iteration > 0 # type: ignore[reportPrivateUsage] - - await main_workflow.reset_for_new_run() - - # The wrapped workflow's runner was reset along with the parent. - assert validation_workflow._runner._iteration == 0 # type: ignore[reportPrivateUsage] - - -async def test_workflow_executor_reset_allows_subsequent_run() -> None: - """After reset(), the parent + WorkflowExecutor can be reused for a fresh run with no leakage.""" - validation_workflow = create_email_validation_workflow() - parent = Coordinator() - workflow_executor = WorkflowExecutor(validation_workflow, "email_validation_workflow") - - main_workflow = ( - WorkflowBuilder(start_executor=parent) - .add_edge(parent, workflow_executor) - .add_edge(workflow_executor, parent) - .build() - ) - - first_result = await main_workflow.run("first@example.com") - assert len(first_result.get_request_info_events()) == 1 - - await main_workflow.reset_for_new_run() - - # State on the WorkflowExecutor and parent's pending-request bookkeeping is clean. - assert workflow_executor._execution_contexts == {} # type: ignore[reportPrivateUsage] - assert workflow_executor._request_to_execution == {} # type: ignore[reportPrivateUsage] - - second_result = await main_workflow.run("second@example.com") - second_requests = second_result.get_request_info_events() - assert len(second_requests) == 1 - assert isinstance(second_requests[0].data, DomainCheckRequest) - # Confirm the new run produced a request from the second email, not the cached first one. - assert second_requests[0].data.email == "second@example.com" - # And the WorkflowExecutor is now tracking exactly one fresh execution. - assert len(workflow_executor._execution_contexts) == 1 # type: ignore[reportPrivateUsage] - - -# endregion: Tests for WorkflowExecutor.reset() diff --git a/python/packages/core/tests/workflow/test_workflow.py b/python/packages/core/tests/workflow/test_workflow.py index 4e607ca8f9..fb46e6d426 100644 --- a/python/packages/core/tests/workflow/test_workflow.py +++ b/python/packages/core/tests/workflow/test_workflow.py @@ -29,7 +29,6 @@ from agent_framework import ( WorkflowEvent, WorkflowException, WorkflowMessage, - WorkflowRunnerException, WorkflowRunState, handler, response_handler, @@ -1354,144 +1353,3 @@ async def test_output_executors_filtering_with_run_responses_streaming() -> None # endregion - - -# region: Tests for Workflow.reset_for_new_run() - - -async def test_workflow_reset_for_new_run_allows_subsequent_run() -> None: - """After reset_for_new_run() the same workflow instance can be run again from scratch.""" - executor_a = IncrementExecutor(id="executor_a") - executor_b = IncrementExecutor(id="executor_b") - - workflow = ( - WorkflowBuilder(start_executor=executor_a, output_from=[executor_a, executor_b]) - .add_edge(executor_a, executor_b) - .add_edge(executor_b, executor_a) - .build() - ) - - first = await workflow.run(NumberMessage(data=0)) - assert first.get_outputs() == [10] - - await workflow.reset_for_new_run() - - second = await workflow.run(NumberMessage(data=0)) - assert second.get_outputs() == [10] - - -async def test_workflow_reset_for_new_run_clears_workflow_state() -> None: - """reset_for_new_run() clears values that executors persisted in shared workflow state.""" - - class StateWritingExecutor(Executor): - @handler - async def handle(self, message: NumberMessage, ctx: WorkflowContext[Any, int]) -> None: - previous = ctx.get_state("seen") or 0 - ctx.set_state("seen", previous + 1) - await ctx.yield_output(previous + 1) - - state_writer = StateWritingExecutor(id="state_writer") - workflow = WorkflowBuilder(start_executor=state_writer, output_from=[state_writer]).build() - - first = await workflow.run(NumberMessage(data=1)) - assert first.get_outputs() == [1] - # State was persisted by the executor. - assert workflow._runner.state.get("seen") == 1 # pyright: ignore[reportPrivateUsage] - - await workflow.reset_for_new_run() - - # The runner's shared state has been wiped. - assert workflow._runner.state.get("seen") is None # pyright: ignore[reportPrivateUsage] - - second = await workflow.run(NumberMessage(data=1)) - # Counter started fresh from 0 again; output is 1, not 2. - assert second.get_outputs() == [1] - - -async def test_workflow_reset_for_new_run_invokes_executor_reset_hook() -> None: - """reset_for_new_run() calls Executor.reset() on every executor in the workflow.""" - - class ResettableExecutor(Executor): - def __init__(self, id: str) -> None: - super().__init__(id=id) - self.reset_calls = 0 - self.handled = 0 - - @handler - async def handle(self, message: NumberMessage, ctx: WorkflowContext[Any, int]) -> None: - self.handled += 1 - await ctx.yield_output(self.handled) - - async def reset(self) -> None: - self.reset_calls += 1 - self.handled = 0 - - executor = ResettableExecutor(id="resettable") - workflow = WorkflowBuilder(start_executor=executor, output_from=[executor]).build() - - await workflow.run(NumberMessage(data=1)) - assert executor.handled == 1 - assert executor.reset_calls == 0 - - await workflow.reset_for_new_run() - - assert executor.reset_calls == 1 - # The executor's own counter was wiped by its overridden reset(). - assert executor.handled == 0 - - -async def test_workflow_reset_for_new_run_resets_runner_iteration_counter() -> None: - """reset_for_new_run() drops the iteration counter accumulated during a prior run.""" - executor_a = IncrementExecutor(id="executor_a") - executor_b = IncrementExecutor(id="executor_b") - - workflow = ( - WorkflowBuilder(start_executor=executor_a, output_from=[executor_a, executor_b]) - .add_edge(executor_a, executor_b) - .add_edge(executor_b, executor_a) - .build() - ) - - await workflow.run(NumberMessage(data=0)) - assert workflow._runner._iteration > 0 # pyright: ignore[reportPrivateUsage] - - await workflow.reset_for_new_run() - - assert workflow._runner._iteration == 0 # pyright: ignore[reportPrivateUsage] - - -async def test_workflow_reset_for_new_run_rejected_during_streaming_run() -> None: - """reset_for_new_run() raises WorkflowRunnerException while a streaming run is in progress.""" - executor_a = IncrementExecutor(id="executor_a") - executor_b = IncrementExecutor(id="executor_b") - - workflow = ( - WorkflowBuilder(start_executor=executor_a, output_from=[executor_a, executor_b]) - .add_edge(executor_a, executor_b) - .add_edge(executor_b, executor_a) - .build() - ) - - async def consume_stream_slowly() -> list[WorkflowEvent]: - events: list[WorkflowEvent] = [] - async for event in workflow.run(NumberMessage(data=0), stream=True): - events.append(event) - await asyncio.sleep(0.01) - return events - - task = asyncio.create_task(consume_stream_slowly()) - # Let the streaming run start. - await asyncio.sleep(0.02) - - try: - with pytest.raises(WorkflowRunnerException, match="Cannot reset the workflow while a run is in progress"): - await workflow.reset_for_new_run() - finally: - await task - - # After the run completes, reset succeeds again. - await workflow.reset_for_new_run() - assert workflow._runner._iteration == 0 # pyright: ignore[reportPrivateUsage] - - -# endregion diff --git a/python/packages/foundry_hosting/tests/test_responses.py b/python/packages/foundry_hosting/tests/test_responses.py index df3aec0d65..f3061885e4 100644 --- a/python/packages/foundry_hosting/tests/test_responses.py +++ b/python/packages/foundry_hosting/tests/test_responses.py @@ -3032,6 +3032,7 @@ class TestCheckpointContextPathValidation: agent.workflow = MagicMock() agent.workflow.name = "wf" agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False) + agent.workflow.create_checkpoint = AsyncMock(return_value="cp_initial") agent.run = AsyncMock( side_effect=[ AgentResponse(messages=[]), @@ -3155,6 +3156,8 @@ class TestCheckpointContextPathValidation: agent.workflow = MagicMock() agent.workflow.name = "wf" agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False) + agent.workflow.create_checkpoint = AsyncMock(return_value="cp_initial") + agent.run = AsyncMock(return_value=AgentResponse(messages=[])) # Constructor inspects WorkflowAgent.workflow internals; bypass setup # by feeding a configured mock through a normal init. @@ -3971,80 +3974,5 @@ class TestWorkflowAgentHosting: assert len(approval_responses) == 1 assert approval_responses[0].approved is False # type: ignore[attr-defined] - async def test_workflow_is_reset_when_no_prior_conversation(self) -> None: - """Without ``previous_response_id`` or ``conversation_id`` the host must - reset the workflow so any in-memory state from a previous request is - cleared before the new turn runs.""" - workflow_agent = _build_text_workflow_agent("fresh") - server = _make_server(workflow_agent) - - with patch.object( - workflow_agent.workflow, - "reset_for_new_run", - new=AsyncMock(wraps=workflow_agent.workflow.reset_for_new_run), - ) as reset_spy: - resp = await _post(server, input_text="hi", stream=False) - - assert resp.status_code == 200 - assert reset_spy.await_count == 1 - - async def test_workflow_is_reset_across_independent_requests(self) -> None: - """Two consecutive requests without a chaining context id must each - reset the workflow.""" - workflow_agent = _build_text_workflow_agent("again") - server = _make_server(workflow_agent) - - with patch.object( - workflow_agent.workflow, - "reset_for_new_run", - new=AsyncMock(wraps=workflow_agent.workflow.reset_for_new_run), - ) as reset_spy: - first = await _post(server, input_text="hi", stream=False) - second = await _post(server, input_text="hi again", stream=False) - - assert first.status_code == 200 - assert second.status_code == 200 - assert reset_spy.await_count == 2 - - async def test_workflow_is_not_reset_when_resuming_from_checkpoint(self) -> None: - """When ``previous_response_id`` resolves to an existing workflow - checkpoint the host restores the checkpoint instead of resetting - the workflow.""" - workflow_agent, _ = _build_approval_workflow_agent( - approval_request_id="apr_no_reset", - final_text="resumed", - ) - server = _make_server(workflow_agent) - - first = await _post(server, stream=False) - assert first.status_code == 200 - first_body = first.json() - first_response_id = first_body["id"] - approval_request_id = next(it["id"] for it in first_body["output"] if it["type"] == "mcp_approval_request") - - with patch.object( - workflow_agent.workflow, - "reset_for_new_run", - new=AsyncMock(wraps=workflow_agent.workflow.reset_for_new_run), - ) as reset_spy: - second = await _post_json( - server, - { - "model": "test-model", - "input": [ - { - "type": "mcp_approval_response", - "approval_request_id": approval_request_id, - "approve": True, - } - ], - "stream": False, - "previous_response_id": first_response_id, - }, - ) - - assert second.status_code == 200 - assert reset_spy.await_count == 0 - # endregion diff --git a/python/packages/mem0/tests/test_mem0_context_provider.py b/python/packages/mem0/tests/test_mem0_context_provider.py index a047af1638..d5c9e2fe67 100644 --- a/python/packages/mem0/tests/test_mem0_context_provider.py +++ b/python/packages/mem0/tests/test_mem0_context_provider.py @@ -198,12 +198,7 @@ class TestBeforeRun: """OSS client with all scoping parameters passes them as isolated concurrent kwargs.""" mock_oss_mem0_client.search.return_value = [] - provider = Mem0ContextProvider( - source_id="mem0", - mem0_client=mock_oss_mem0_client, - user_id="u1", - agent_id="a1" - ) + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_oss_mem0_client, user_id="u1", agent_id="a1") mock_context = MagicMock(spec=SessionContext) mock_msg = MagicMock() diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_base_group_chat_orchestrator.py b/python/packages/orchestrations/agent_framework_orchestrations/_base_group_chat_orchestrator.py index 8cdf71a0d0..a4108b23f0 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_base_group_chat_orchestrator.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_base_group_chat_orchestrator.py @@ -598,25 +598,3 @@ class BaseGroupChatOrchestrator(Executor, ABC): metadata: Pattern-specific state dict """ pass - - @override - async def reset(self) -> None: - """Reset the orchestrator to its initial state for a new workflow run. - - Clears the shared conversation history and round counter, then delegates - to ``_reset_pattern_state()`` so subclasses can clean up any - pattern-specific per-run state (caches, sessions, ledgers, etc.). - """ - logger.debug("%s %s: Resetting state", self.__class__.__name__, self.id) - self._full_conversation.clear() - self._round_index = 0 - self._reset_pattern_state() - - def _reset_pattern_state(self) -> None: - """Reset pattern-specific state. - - Override this method in subclasses to clear pattern-specific per-run state - when ``reset()`` is invoked. Called after the base class clears the shared - conversation and round counter. - """ - pass diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_group_chat.py b/python/packages/orchestrations/agent_framework_orchestrations/_group_chat.py index f6bbac5400..728f3e388c 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_group_chat.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_group_chat.py @@ -327,7 +327,6 @@ class AgentBasedGroupChatOrchestrator(BaseGroupChatOrchestrator): ) self._agent = agent self._retry_attempts = retry_attempts - self._session_supplied_by_caller = session is not None self._session = session or agent.create_session() # Cache for messages since last agent invocation # This is different from the full conversation history maintained by the base orchestrator @@ -338,25 +337,6 @@ class AgentBasedGroupChatOrchestrator(BaseGroupChatOrchestrator): self._cache.extend(messages) return super()._append_messages(messages) - @override - def _reset_pattern_state(self) -> None: - """Reset pattern-specific state for a new workflow run. - - Clears the per-run message cache and rotates the orchestrator agent's - session unless the caller supplied a session explicitly (in which case - the caller is responsible for the session's lifecycle). - """ - self._cache.clear() - if self._session_supplied_by_caller: - logger.warning( - "%s %s: Session was supplied by the caller and will not be reset. " - "If you want a fresh session for the next run, reset or replace it before invoking the workflow.", - self.__class__.__name__, - self.id, - ) - else: - self._session = self._agent.create_session() - @override async def _handle_messages( self, diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py b/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py index 7288b67d2f..f8cbf88fd7 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py @@ -509,14 +509,6 @@ class MagenticManagerBase(ABC): """Restore runtime state from checkpoint data.""" return - def on_reset(self) -> None: - """Reset per-run runtime state for a new workflow run. - - Subclasses should clear any state accumulated during a run (e.g. cached - ledgers, agent sessions) so the manager is ready to start fresh. - """ - return - class StandardMagenticManager(MagenticManagerBase): """Standard Magentic manager that performs real LLM calls via a Agent. @@ -769,12 +761,6 @@ class StandardMagenticManager(MagenticManagerBase): except Exception: # pragma: no cover - defensive logger.warning("Failed to restore manager agent session from checkpoint state") - @override - def on_reset(self) -> None: - """Clear cached ledger and start a fresh agent session for a new run.""" - self.task_ledger = None - self._session = self._agent.create_session() - # endregion Magentic Manager @@ -1277,15 +1263,6 @@ class MagenticOrchestrator(BaseGroupChatOrchestrator): # a target will broadcast to all. await ctx.send_message(MagenticResetSignal()) - @override - def _reset_pattern_state(self) -> None: - """Reset Magentic-specific per-run state for a new workflow run.""" - self._magentic_context = None - self._task_ledger = None - self._progress_ledger = None - self._terminated = False - self._manager.on_reset() - @override async def on_checkpoint_save(self) -> dict[str, Any]: """Capture current orchestrator state for checkpointing.""" diff --git a/python/packages/orchestrations/tests/test_group_chat.py b/python/packages/orchestrations/tests/test_group_chat.py index 56a6f4a0b4..50f58e781a 100644 --- a/python/packages/orchestrations/tests/test_group_chat.py +++ b/python/packages/orchestrations/tests/test_group_chat.py @@ -1097,139 +1097,3 @@ def test_group_chat_orchestrator_factory_invalid_return_type(): # endregion - - -# region Reset - - -async def test_base_orchestrator_reset_clears_conversation_and_round_index() -> None: - """reset() clears the conversation history and the round counter.""" - from agent_framework.orchestrations import GroupChatOrchestrator - - from agent_framework_orchestrations._base_group_chat_orchestrator import ParticipantRegistry - - selector = make_sequence_selector() - orchestrator = GroupChatOrchestrator( - id="orch", - participant_registry=ParticipantRegistry([]), - selection_func=selector, - max_rounds=2, - ) - orchestrator._full_conversation = [Message(role="user", contents=["hi"], author_name="user")] - orchestrator._round_index = 4 - - await orchestrator.reset() - - assert orchestrator._full_conversation == [] - assert orchestrator._round_index == 0 - - -async def test_base_orchestrator_reset_invokes_pattern_state_hook() -> None: - """reset() calls _reset_pattern_state() so subclasses can clean up their own state.""" - from agent_framework.orchestrations import GroupChatOrchestrator - - from agent_framework_orchestrations._base_group_chat_orchestrator import ParticipantRegistry - - selector = make_sequence_selector() - - class TrackingOrchestrator(GroupChatOrchestrator): - reset_calls: int = 0 - - def _reset_pattern_state(self) -> None: - type(self).reset_calls += 1 - - orchestrator = TrackingOrchestrator( - id="orch", - participant_registry=ParticipantRegistry([]), - selection_func=selector, - max_rounds=2, - ) - - await orchestrator.reset() - await orchestrator.reset() - - assert TrackingOrchestrator.reset_calls == 2 - - -async def test_agent_based_orchestrator_reset_clears_cache_and_rotates_session() -> None: - """When the session was not supplied by the caller, reset() rotates the session and clears the cache.""" - from agent_framework.orchestrations import AgentBasedGroupChatOrchestrator - - from agent_framework_orchestrations._base_group_chat_orchestrator import ParticipantRegistry - - agent = cast(Agent, StubManagerAgent()) - orchestrator = AgentBasedGroupChatOrchestrator( - agent=agent, - participant_registry=ParticipantRegistry([]), - max_rounds=2, - ) - original_session = orchestrator._session - orchestrator._cache = [Message(role="assistant", contents=["x"], author_name="agent")] - orchestrator._full_conversation = [Message(role="user", contents=["x"], author_name="user")] - orchestrator._round_index = 3 - - await orchestrator.reset() - - assert orchestrator._cache == [] - assert orchestrator._full_conversation == [] - assert orchestrator._round_index == 0 - assert orchestrator._session is not original_session - - -async def test_agent_based_orchestrator_reset_warns_when_session_supplied(caplog: pytest.LogCaptureFixture) -> None: - """When the caller supplied a session, reset() preserves it and logs a warning.""" - import logging - - from agent_framework.orchestrations import AgentBasedGroupChatOrchestrator - - from agent_framework_orchestrations._base_group_chat_orchestrator import ParticipantRegistry - - agent = cast(Agent, StubManagerAgent()) - supplied_session = agent.create_session() - orchestrator = AgentBasedGroupChatOrchestrator( - agent=agent, - participant_registry=ParticipantRegistry([]), - session=supplied_session, - max_rounds=2, - ) - orchestrator._cache = [Message(role="assistant", contents=["x"], author_name="agent")] - - with caplog.at_level(logging.WARNING, logger="agent_framework_orchestrations._group_chat"): - await orchestrator.reset() - - assert orchestrator._cache == [] - # The caller-owned session must be preserved. - assert orchestrator._session is supplied_session - warnings = [ - r for r in caplog.records if r.levelno == logging.WARNING and "Session was supplied by the caller" in r.message - ] - assert warnings, f"expected a warning about caller-supplied session, got: {[r.message for r in caplog.records]}" - - -async def test_workflow_reset_resets_group_chat_orchestrator() -> None: - """End-to-end: workflow.reset_for_new_run() resets the orchestrator's conversation state.""" - selector = make_sequence_selector() - alpha = StubAgent("alpha", "ack from alpha") - beta = StubAgent("beta", "ack from beta") - - workflow = GroupChatBuilder( - participants=[alpha, beta], - max_rounds=2, - selection_func=selector, - orchestrator_name="manager", - ).build() - - async for _ in workflow.run("first task", stream=True): - pass - - orchestrator = cast(BaseGroupChatOrchestrator, workflow.executors[GroupChatBuilder.DEFAULT_ORCHESTRATOR_ID]) - assert orchestrator._full_conversation, "orchestrator should have accumulated conversation after first run" - assert orchestrator._round_index > 0 - - await workflow.reset_for_new_run() - - assert orchestrator._full_conversation == [] - assert orchestrator._round_index == 0 - - -# endregion diff --git a/python/packages/orchestrations/tests/test_magentic.py b/python/packages/orchestrations/tests/test_magentic.py index 343d4eff7a..615ba998bc 100644 --- a/python/packages/orchestrations/tests/test_magentic.py +++ b/python/packages/orchestrations/tests/test_magentic.py @@ -1243,135 +1243,3 @@ def test_standard_manager_checkpoint_restore_empty_state(): # endregion - - -# region Manager Reset Tests - - -def test_magentic_manager_base_on_reset_is_noop_by_default(): - """MagenticManagerBase.on_reset() is a no-op so subclasses can opt in.""" - mgr = FakeManager() - # Seed some state on the fake to confirm the base hook does not touch it. - mgr.task_ledger = _SimpleLedger( - facts=Message("assistant", ["facts"]), - plan=Message("assistant", ["plan"]), - ) - - mgr.on_reset() # base implementation is a no-op - - assert mgr.task_ledger is not None - - -def test_standard_manager_on_reset_clears_ledger_and_rotates_session(): - """StandardMagenticManager.on_reset() clears the cached ledger and creates a fresh session.""" - from agent_framework_orchestrations._magentic import _MagenticTaskLedger # type: ignore[reportPrivateUsage] - - agent = StubManagerAgent() - mgr = StandardMagenticManager(agent=agent) - mgr.task_ledger = _MagenticTaskLedger( - facts=Message("assistant", ["facts"]), - plan=Message("assistant", ["plan"]), - ) - original_session = mgr._session - - mgr.on_reset() - - assert mgr.task_ledger is None - assert mgr._session is not original_session - assert mgr._session.session_id != original_session.session_id - - -async def test_magentic_orchestrator_reset_invokes_manager_on_reset() -> None: - """_reset_pattern_state() clears orchestrator state and delegates to manager.on_reset().""" - from agent_framework_orchestrations._base_group_chat_orchestrator import ( - ParticipantRegistry, # type: ignore[reportPrivateUsage] - ) - - class TrackingManager(FakeManager): - reset_calls: int = 0 - - @override - def on_reset(self) -> None: - type(self).reset_calls += 1 - - manager = TrackingManager() - orchestrator = MagenticOrchestrator( - manager=manager, - participant_registry=ParticipantRegistry([]), - ) - # Seed magentic-specific state to confirm it is cleared. - orchestrator._magentic_context = MagenticContext( # type: ignore[reportPrivateUsage] - task="task", - participant_descriptions={"agentA": "desc"}, - ) - orchestrator._task_ledger = Message("assistant", ["ledger"]) # type: ignore[reportPrivateUsage] - orchestrator._progress_ledger = MagenticProgressLedger( # type: ignore[reportPrivateUsage] - is_request_satisfied=MagenticProgressLedgerItem(reason="r", answer=False), - is_in_loop=MagenticProgressLedgerItem(reason="r", answer=False), - is_progress_being_made=MagenticProgressLedgerItem(reason="r", answer=True), - next_speaker=MagenticProgressLedgerItem(reason="r", answer="agentA"), - instruction_or_question=MagenticProgressLedgerItem(reason="r", answer="do"), - ) - orchestrator._terminated = True # type: ignore[reportPrivateUsage] - - await orchestrator.reset() - - assert orchestrator._magentic_context is None # type: ignore[reportPrivateUsage] - assert orchestrator._task_ledger is None # type: ignore[reportPrivateUsage] - assert orchestrator._progress_ledger is None # type: ignore[reportPrivateUsage] - assert orchestrator._terminated is False # type: ignore[reportPrivateUsage] - assert TrackingManager.reset_calls == 1 - - -async def test_magentic_orchestrator_reset_propagates_manager_on_reset_failure() -> None: - """Failures in manager.on_reset() must propagate so callers can react.""" - from agent_framework_orchestrations._base_group_chat_orchestrator import ( - ParticipantRegistry, # type: ignore[reportPrivateUsage] - ) - - class FailingManager(FakeManager): - @override - def on_reset(self) -> None: - raise RuntimeError("boom") - - manager = FailingManager() - orchestrator = MagenticOrchestrator( - manager=manager, - participant_registry=ParticipantRegistry([]), - ) - - with pytest.raises(RuntimeError, match="boom"): - await orchestrator.reset() - - -async def test_workflow_reset_resets_magentic_orchestrator_and_manager() -> None: - """End-to-end: workflow.reset_for_new_run() resets Magentic orchestrator and manager state.""" - manager = FakeManager() - manager.task_ledger = _SimpleLedger( - facts=Message("assistant", ["seeded facts"]), - plan=Message("assistant", ["seeded plan"]), - ) - workflow = MagenticBuilder( - participants=[StubAgent(manager.next_speaker_name, "first draft")], - manager=manager, - ).build() - - async for _ in workflow.run("first task", stream=True): - pass - - orchestrator = next(e for e in workflow.executors.values() if isinstance(e, MagenticOrchestrator)) - assert orchestrator._terminated is True # type: ignore[reportPrivateUsage] - assert orchestrator._task_ledger is not None # type: ignore[reportPrivateUsage] - assert manager.task_ledger is not None - - await workflow.reset_for_new_run() - - assert orchestrator._magentic_context is None # type: ignore[reportPrivateUsage] - assert orchestrator._task_ledger is None # type: ignore[reportPrivateUsage] - assert orchestrator._progress_ledger is None # type: ignore[reportPrivateUsage] - assert orchestrator._terminated is False # type: ignore[reportPrivateUsage] - # FakeManager.on_reset is the base no-op, but the orchestrator still tolerates that case. - # For the standard manager we exercise full clearing in the dedicated unit test above. - - -# endregion diff --git a/python/samples/03-workflows/README.md b/python/samples/03-workflows/README.md index adcb57dadb..c79203d742 100644 --- a/python/samples/03-workflows/README.md +++ b/python/samples/03-workflows/README.md @@ -169,7 +169,6 @@ callers can still inspect progress or supporting work from the response messages | Sample | File | Concepts | | -------------------------------- | ------------------------------------------------------------------------------------------------ | ----------------------------------------------------------------- | | State with Agents | [state-management/state_with_agents.py](./state-management/state_with_agents.py) | Store in state once and later reuse across agents | -| Reset Workflow Between Runs | [state-management/workflow_reset.py](./state-management/workflow_reset.py) | Reuse one workflow instance across independent runs via `reset_for_new_run()` | | Workflow Kwargs - Global Context | [state-management/workflow_kwargs_global.py](./state-management/workflow_kwargs_global.py) | Pass custom context (data, user tokens) via kwargs to `@tool` tools in all agents | | Workflow Kwargs - Per Agent | [state-management/workflow_kwargs_per_agent.py](./state-management/workflow_kwargs_per_agent.py) | Pass custom context (data, user tokens) via kwargs to `@tool` tools in individual agents | diff --git a/python/samples/03-workflows/state-management/workflow_reset.py b/python/samples/03-workflows/state-management/workflow_reset.py deleted file mode 100644 index d751e6b066..0000000000 --- a/python/samples/03-workflows/state-management/workflow_reset.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -import asyncio -from dataclasses import dataclass - -from agent_framework import Executor, Workflow, WorkflowBuilder, WorkflowContext, handler -from typing_extensions import Never, override - -""" -Sample: Reusing a Workflow across independent jobs with reset_for_new_run(). - -Build a small moderation pipeline that silently accumulates stats as messages -flow through it and emits a summary only when the caller asks for one. Drive -the same workflow instance across multiple independent jobs and show that -``Workflow.reset_for_new_run()`` clears all per-executor state without -rebuilding the graph. - -Two custom executors share the work, each with its own per-job state and its -own ``reset()`` override: - -- ``FlaggedKeywordCounter`` is the start executor. It accepts message strings - to inspect (silently updating local stats, sending nothing downstream) and - ``ReportRequest`` markers that cause it to forward a ``StatsSnapshot``. -- ``StatsReporter`` formats the snapshot, increments its own emitted-reports - counter, and yields the summary as the workflow output. - -A run with a string produces no output, just a state update. A run with a -``ReportRequest`` produces exactly one summary. Job boundaries are entirely -controlled by the caller via ``reset_for_new_run()``, which calls ``reset()`` -on every executor in the graph. - -Purpose: -Show how to: -- Hold per-job aggregate state on a custom Executor subclass. -- Override ``Executor.reset()`` on every executor that owns per-run state, so - it is cleared automatically when the workflow is reset. -- Call ``Workflow.reset_for_new_run()`` between independent jobs so a single - workflow instance can serve a stream of unrelated batches without leaking - state. - -Prerequisites: -- No external services or credentials required; this sample runs entirely in-process. -- Familiarity with WorkflowBuilder and Executor subclasses. -""" - - -@dataclass -class ReportRequest: - """Marker input that asks the workflow to emit a summary of stats so far.""" - - -@dataclass -class StatsSnapshot: - """Snapshot the counter forwards to the reporter when a report is requested.""" - - messages_seen: int - flagged_messages: int - flagged_keywords: list[str] - - -class FlaggedKeywordCounter(Executor): - """Executor that silently accumulates per-job stats; emits on demand. - - Holds three instance attributes that build up across the runs that make up a - single job: - - - ``_messages_seen``: how many messages have been inspected. - - ``_flagged_messages``: how many of those messages contained any flagged keyword. - - ``_flagged_keywords``: the set of distinct keywords actually observed. - - Two handlers dispatch by input type: - - - ``inspect`` accepts a string, updates the counters, and sends nothing. - - ``emit_report`` accepts a ``ReportRequest`` and forwards a current - ``StatsSnapshot`` to the downstream reporter. - - Without overriding ``reset()`` this state would leak into the next job when the - workflow is reused via ``Workflow.reset_for_new_run()``. The override below - clears these attributes so each fresh job starts empty. - """ - - FLAGGED_KEYWORDS = frozenset({"spam", "scam", "phishing"}) - - def __init__(self, id: str) -> None: - super().__init__(id=id) - self._messages_seen: int = 0 - self._flagged_messages: int = 0 - self._flagged_keywords: set[str] = set() - - @handler - async def inspect(self, message: str, ctx: WorkflowContext[StatsSnapshot]) -> None: - """Inspect ``message`` and update local stats. Sends nothing downstream.""" - self._messages_seen += 1 - hits = {kw for kw in self.FLAGGED_KEYWORDS if kw in message.lower()} - if hits: - self._flagged_messages += 1 - self._flagged_keywords.update(hits) - - @handler - async def emit_report(self, _: ReportRequest, ctx: WorkflowContext[StatsSnapshot]) -> None: - """Forward the current stats snapshot to the reporter on request.""" - await ctx.send_message( - StatsSnapshot( - messages_seen=self._messages_seen, - flagged_messages=self._flagged_messages, - flagged_keywords=sorted(self._flagged_keywords), - ) - ) - - @override - async def reset(self) -> None: - """Clear per-job aggregate state when the workflow is reset. - - ``Workflow.reset_for_new_run()`` calls ``reset()`` on every executor in the - graph; overriding it here is what makes a reused workflow safe to drive with - a brand-new job. - """ - self._messages_seen = 0 - self._flagged_messages = 0 - self._flagged_keywords.clear() - - -class StatsReporter(Executor): - """Terminal executor that formats a snapshot and yields it as workflow output. - - Holds a single instance attribute, ``_reports_emitted``, that tracks how many - summaries this reporter has produced on this workflow instance, and clears - it on reset so a reset workflow behaves identically to a freshly built one. - """ - - def __init__(self, id: str) -> None: - super().__init__(id=id) - self._reports_emitted: int = 0 - - @handler - async def report(self, snapshot: StatsSnapshot, ctx: WorkflowContext[Never, str]) -> None: - self._reports_emitted += 1 - summary = ( - f"messages={snapshot.messages_seen}, " - f"flagged={snapshot.flagged_messages}, " - f"keywords={snapshot.flagged_keywords or 'none'}, " - f"reports_emitted={self._reports_emitted}" - ) - await ctx.yield_output(summary) - - @override - async def reset(self) -> None: - """Clear the emitted-reports counter when the workflow is reset.""" - self._reports_emitted = 0 - - -async def _process(workflow: Workflow, messages: list[str]) -> None: - """Send each message through the workflow; no output is produced.""" - for message in messages: - await workflow.run(message) - - -async def _request_report(workflow: Workflow) -> str: - """Ask the workflow for a summary of the stats accumulated so far.""" - events = await workflow.run(ReportRequest()) - outputs = events.get_outputs() - return outputs[0] if outputs else "" - - -async def main() -> None: - """Build the moderation workflow once, then run it across three independent jobs.""" - - # 1. Build the moderation pipeline once. The same workflow instance will be - # reused for every job; that's the whole point of this sample. - counter = FlaggedKeywordCounter(id="counter") - reporter = StatsReporter(id="reporter") - workflow = WorkflowBuilder(start_executor=counter, output_from=[reporter]).add_edge(counter, reporter).build() - - # 2. First job -- inspect three messages, then request a report. Note this - # batch happens to be three messages, but any size works. - await _process(workflow, ["hello there", "free phishing kit", "lunch plans?"]) - print(f"Batch A summary: {await _request_report(workflow)}") - - # 3. Second job WITHOUT reset. State from batch A leaks in: the counter's - # tallies and the reporter's emitted-reports counter both keep - # accumulating even though batch B is conceptually a separate job. - await _process(workflow, ["weekly status update", "team offsite agenda", "quarterly review"]) - print(f"Batch B summary (no reset): {await _request_report(workflow)}") - - # 4. Now reset between jobs and process the same batch B again. The summary - # reflects only batch B and every per-run counter starts fresh, because - # reset_for_new_run() calls reset() on every executor in the graph: - # - FlaggedKeywordCounter clears its message / flag / keyword tallies. - # - StatsReporter clears its emitted-reports counter. - await workflow.reset_for_new_run() - await _process(workflow, ["weekly status update", "team offsite agenda", "quarterly review"]) - print(f"Batch B summary (after reset): {await _request_report(workflow)}") - - # 5. Reset again before a final unrelated job. A reset workflow is - # indistinguishable from a freshly built one for state purposes, but - # cheaper because the graph and executor objects are reused. - await workflow.reset_for_new_run() - await _process(workflow, ["spam offer #1", "scam alert", "phishing attempt"]) - print(f"Batch C summary (after reset): {await _request_report(workflow)}") - - """ - Sample Output: - - Batch A summary: messages=3, flagged=1, keywords=['phishing'], reports_emitted=1 - Batch B summary (no reset): messages=6, flagged=1, keywords=['phishing'], reports_emitted=2 - Batch B summary (after reset): messages=3, flagged=0, keywords=none, reports_emitted=1 - Batch C summary (after reset): messages=3, flagged=3, keywords=['phishing', 'scam', 'spam'], reports_emitted=1 - """ - - -if __name__ == "__main__": - asyncio.run(main())