From 47012f1dcfa68cb0254e360221058556b1774851 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 9 Jun 2026 11:02:17 -0700 Subject: [PATCH] Address comments --- .../_workflows/_runner_context.py | 1 + .../core/tests/workflow/test_runner.py | 18 +++ .../_magentic.py | 15 ++ .../orchestrations/tests/test_magentic.py | 132 ++++++++++++++++++ 4 files changed, 166 insertions(+) diff --git a/python/packages/core/agent_framework/_workflows/_runner_context.py b/python/packages/core/agent_framework/_workflows/_runner_context.py index 2e4901f411..23a748f0be 100644 --- a/python/packages/core/agent_framework/_workflows/_runner_context.py +++ b/python/packages/core/agent_framework/_workflows/_runner_context.py @@ -409,6 +409,7 @@ class InProcRunnerContext: self._messages.clear() # Clear any pending events (best-effort) by recreating the queue self._event_queue = asyncio.Queue() + self._pending_request_info_events.clear() self._streaming = False # Reset streaming flag async def apply_checkpoint(self, checkpoint: WorkflowCheckpoint) -> None: diff --git a/python/packages/core/tests/workflow/test_runner.py b/python/packages/core/tests/workflow/test_runner.py index b41e43beee..37c17fd1ca 100644 --- a/python/packages/core/tests/workflow/test_runner.py +++ b/python/packages/core/tests/workflow/test_runner.py @@ -1066,6 +1066,24 @@ async def test_runner_context_reset_resets_streaming_flag(): assert ctx.is_streaming() is False +async def test_runner_context_reset_clears_pending_request_info_events(): + """reset_for_new_run clears any pending request_info events tracked for correlation.""" + ctx = InProcRunnerContext() + request_info_event = WorkflowEvent.request_info( + request_id="request-123", + source_executor_id="source", + request_data=MockMessage(data=0), + response_type=bool, + ) + await ctx.add_request_info_event(request_info_event) + + assert "request-123" in await ctx.get_pending_request_info_events() + + ctx.reset_for_new_run() + + assert await ctx.get_pending_request_info_events() == {} + + # endregion: Tests for InProcRunnerContext.reset_for_new_run() diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py b/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py index ab79dbbfa3..7288b67d2f 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py @@ -509,6 +509,14 @@ class MagenticManagerBase(ABC): """Restore runtime state from checkpoint data.""" return + def on_reset(self) -> None: + """Reset per-run runtime state for a new workflow run. + + Subclasses should clear any state accumulated during a run (e.g. cached + ledgers, agent sessions) so the manager is ready to start fresh. + """ + return + class StandardMagenticManager(MagenticManagerBase): """Standard Magentic manager that performs real LLM calls via a Agent. @@ -761,6 +769,12 @@ class StandardMagenticManager(MagenticManagerBase): except Exception: # pragma: no cover - defensive logger.warning("Failed to restore manager agent session from checkpoint state") + @override + def on_reset(self) -> None: + """Clear cached ledger and start a fresh agent session for a new run.""" + self.task_ledger = None + self._session = self._agent.create_session() + # endregion Magentic Manager @@ -1270,6 +1284,7 @@ class MagenticOrchestrator(BaseGroupChatOrchestrator): self._task_ledger = None self._progress_ledger = None self._terminated = False + self._manager.on_reset() @override async def on_checkpoint_save(self) -> dict[str, Any]: diff --git a/python/packages/orchestrations/tests/test_magentic.py b/python/packages/orchestrations/tests/test_magentic.py index 615ba998bc..343d4eff7a 100644 --- a/python/packages/orchestrations/tests/test_magentic.py +++ b/python/packages/orchestrations/tests/test_magentic.py @@ -1243,3 +1243,135 @@ def test_standard_manager_checkpoint_restore_empty_state(): # endregion + + +# region Manager Reset Tests + + +def test_magentic_manager_base_on_reset_is_noop_by_default(): + """MagenticManagerBase.on_reset() is a no-op so subclasses can opt in.""" + mgr = FakeManager() + # Seed some state on the fake to confirm the base hook does not touch it. + mgr.task_ledger = _SimpleLedger( + facts=Message("assistant", ["facts"]), + plan=Message("assistant", ["plan"]), + ) + + mgr.on_reset() # base implementation is a no-op + + assert mgr.task_ledger is not None + + +def test_standard_manager_on_reset_clears_ledger_and_rotates_session(): + """StandardMagenticManager.on_reset() clears the cached ledger and creates a fresh session.""" + from agent_framework_orchestrations._magentic import _MagenticTaskLedger # type: ignore[reportPrivateUsage] + + agent = StubManagerAgent() + mgr = StandardMagenticManager(agent=agent) + mgr.task_ledger = _MagenticTaskLedger( + facts=Message("assistant", ["facts"]), + plan=Message("assistant", ["plan"]), + ) + original_session = mgr._session + + mgr.on_reset() + + assert mgr.task_ledger is None + assert mgr._session is not original_session + assert mgr._session.session_id != original_session.session_id + + +async def test_magentic_orchestrator_reset_invokes_manager_on_reset() -> None: + """_reset_pattern_state() clears orchestrator state and delegates to manager.on_reset().""" + from agent_framework_orchestrations._base_group_chat_orchestrator import ( + ParticipantRegistry, # type: ignore[reportPrivateUsage] + ) + + class TrackingManager(FakeManager): + reset_calls: int = 0 + + @override + def on_reset(self) -> None: + type(self).reset_calls += 1 + + manager = TrackingManager() + orchestrator = MagenticOrchestrator( + manager=manager, + participant_registry=ParticipantRegistry([]), + ) + # Seed magentic-specific state to confirm it is cleared. + orchestrator._magentic_context = MagenticContext( # type: ignore[reportPrivateUsage] + task="task", + participant_descriptions={"agentA": "desc"}, + ) + orchestrator._task_ledger = Message("assistant", ["ledger"]) # type: ignore[reportPrivateUsage] + orchestrator._progress_ledger = MagenticProgressLedger( # type: ignore[reportPrivateUsage] + is_request_satisfied=MagenticProgressLedgerItem(reason="r", answer=False), + is_in_loop=MagenticProgressLedgerItem(reason="r", answer=False), + is_progress_being_made=MagenticProgressLedgerItem(reason="r", answer=True), + next_speaker=MagenticProgressLedgerItem(reason="r", answer="agentA"), + instruction_or_question=MagenticProgressLedgerItem(reason="r", answer="do"), + ) + orchestrator._terminated = True # type: ignore[reportPrivateUsage] + + await orchestrator.reset() + + assert orchestrator._magentic_context is None # type: ignore[reportPrivateUsage] + assert orchestrator._task_ledger is None # type: ignore[reportPrivateUsage] + assert orchestrator._progress_ledger is None # type: ignore[reportPrivateUsage] + assert orchestrator._terminated is False # type: ignore[reportPrivateUsage] + assert TrackingManager.reset_calls == 1 + + +async def test_magentic_orchestrator_reset_propagates_manager_on_reset_failure() -> None: + """Failures in manager.on_reset() must propagate so callers can react.""" + from agent_framework_orchestrations._base_group_chat_orchestrator import ( + ParticipantRegistry, # type: ignore[reportPrivateUsage] + ) + + class FailingManager(FakeManager): + @override + def on_reset(self) -> None: + raise RuntimeError("boom") + + manager = FailingManager() + orchestrator = MagenticOrchestrator( + manager=manager, + participant_registry=ParticipantRegistry([]), + ) + + with pytest.raises(RuntimeError, match="boom"): + await orchestrator.reset() + + +async def test_workflow_reset_resets_magentic_orchestrator_and_manager() -> None: + """End-to-end: workflow.reset_for_new_run() resets Magentic orchestrator and manager state.""" + manager = FakeManager() + manager.task_ledger = _SimpleLedger( + facts=Message("assistant", ["seeded facts"]), + plan=Message("assistant", ["seeded plan"]), + ) + workflow = MagenticBuilder( + participants=[StubAgent(manager.next_speaker_name, "first draft")], + manager=manager, + ).build() + + async for _ in workflow.run("first task", stream=True): + pass + + orchestrator = next(e for e in workflow.executors.values() if isinstance(e, MagenticOrchestrator)) + assert orchestrator._terminated is True # type: ignore[reportPrivateUsage] + assert orchestrator._task_ledger is not None # type: ignore[reportPrivateUsage] + assert manager.task_ledger is not None + + await workflow.reset_for_new_run() + + assert orchestrator._magentic_context is None # type: ignore[reportPrivateUsage] + assert orchestrator._task_ledger is None # type: ignore[reportPrivateUsage] + assert orchestrator._progress_ledger is None # type: ignore[reportPrivateUsage] + assert orchestrator._terminated is False # type: ignore[reportPrivateUsage] + # FakeManager.on_reset is the base no-op, but the orchestrator still tolerates that case. + # For the standard manager we exercise full clearing in the dedicated unit test above. + + +# endregion