Address comments

This commit is contained in:
Tao Chen
2026-06-09 11:02:17 -07:00
Unverified
parent 4e623d561f
commit 47012f1dcf
4 changed files with 166 additions and 0 deletions
@@ -409,6 +409,7 @@ class InProcRunnerContext:
self._messages.clear()
# Clear any pending events (best-effort) by recreating the queue
self._event_queue = asyncio.Queue()
self._pending_request_info_events.clear()
self._streaming = False # Reset streaming flag
async def apply_checkpoint(self, checkpoint: WorkflowCheckpoint) -> None:
@@ -1066,6 +1066,24 @@ async def test_runner_context_reset_resets_streaming_flag():
assert ctx.is_streaming() is False
async def test_runner_context_reset_clears_pending_request_info_events():
"""reset_for_new_run clears any pending request_info events tracked for correlation."""
ctx = InProcRunnerContext()
request_info_event = WorkflowEvent.request_info(
request_id="request-123",
source_executor_id="source",
request_data=MockMessage(data=0),
response_type=bool,
)
await ctx.add_request_info_event(request_info_event)
assert "request-123" in await ctx.get_pending_request_info_events()
ctx.reset_for_new_run()
assert await ctx.get_pending_request_info_events() == {}
# endregion: Tests for InProcRunnerContext.reset_for_new_run()
@@ -509,6 +509,14 @@ class MagenticManagerBase(ABC):
"""Restore runtime state from checkpoint data."""
return
def on_reset(self) -> None:
"""Reset per-run runtime state for a new workflow run.
Subclasses should clear any state accumulated during a run (e.g. cached
ledgers, agent sessions) so the manager is ready to start fresh.
"""
return
class StandardMagenticManager(MagenticManagerBase):
"""Standard Magentic manager that performs real LLM calls via a Agent.
@@ -761,6 +769,12 @@ class StandardMagenticManager(MagenticManagerBase):
except Exception: # pragma: no cover - defensive
logger.warning("Failed to restore manager agent session from checkpoint state")
@override
def on_reset(self) -> None:
"""Clear cached ledger and start a fresh agent session for a new run."""
self.task_ledger = None
self._session = self._agent.create_session()
# endregion Magentic Manager
@@ -1270,6 +1284,7 @@ class MagenticOrchestrator(BaseGroupChatOrchestrator):
self._task_ledger = None
self._progress_ledger = None
self._terminated = False
self._manager.on_reset()
@override
async def on_checkpoint_save(self) -> dict[str, Any]:
@@ -1243,3 +1243,135 @@ def test_standard_manager_checkpoint_restore_empty_state():
# endregion
# region Manager Reset Tests
def test_magentic_manager_base_on_reset_is_noop_by_default():
"""MagenticManagerBase.on_reset() is a no-op so subclasses can opt in."""
mgr = FakeManager()
# Seed some state on the fake to confirm the base hook does not touch it.
mgr.task_ledger = _SimpleLedger(
facts=Message("assistant", ["facts"]),
plan=Message("assistant", ["plan"]),
)
mgr.on_reset() # base implementation is a no-op
assert mgr.task_ledger is not None
def test_standard_manager_on_reset_clears_ledger_and_rotates_session():
"""StandardMagenticManager.on_reset() clears the cached ledger and creates a fresh session."""
from agent_framework_orchestrations._magentic import _MagenticTaskLedger # type: ignore[reportPrivateUsage]
agent = StubManagerAgent()
mgr = StandardMagenticManager(agent=agent)
mgr.task_ledger = _MagenticTaskLedger(
facts=Message("assistant", ["facts"]),
plan=Message("assistant", ["plan"]),
)
original_session = mgr._session
mgr.on_reset()
assert mgr.task_ledger is None
assert mgr._session is not original_session
assert mgr._session.session_id != original_session.session_id
async def test_magentic_orchestrator_reset_invokes_manager_on_reset() -> None:
"""_reset_pattern_state() clears orchestrator state and delegates to manager.on_reset()."""
from agent_framework_orchestrations._base_group_chat_orchestrator import (
ParticipantRegistry, # type: ignore[reportPrivateUsage]
)
class TrackingManager(FakeManager):
reset_calls: int = 0
@override
def on_reset(self) -> None:
type(self).reset_calls += 1
manager = TrackingManager()
orchestrator = MagenticOrchestrator(
manager=manager,
participant_registry=ParticipantRegistry([]),
)
# Seed magentic-specific state to confirm it is cleared.
orchestrator._magentic_context = MagenticContext( # type: ignore[reportPrivateUsage]
task="task",
participant_descriptions={"agentA": "desc"},
)
orchestrator._task_ledger = Message("assistant", ["ledger"]) # type: ignore[reportPrivateUsage]
orchestrator._progress_ledger = MagenticProgressLedger( # type: ignore[reportPrivateUsage]
is_request_satisfied=MagenticProgressLedgerItem(reason="r", answer=False),
is_in_loop=MagenticProgressLedgerItem(reason="r", answer=False),
is_progress_being_made=MagenticProgressLedgerItem(reason="r", answer=True),
next_speaker=MagenticProgressLedgerItem(reason="r", answer="agentA"),
instruction_or_question=MagenticProgressLedgerItem(reason="r", answer="do"),
)
orchestrator._terminated = True # type: ignore[reportPrivateUsage]
await orchestrator.reset()
assert orchestrator._magentic_context is None # type: ignore[reportPrivateUsage]
assert orchestrator._task_ledger is None # type: ignore[reportPrivateUsage]
assert orchestrator._progress_ledger is None # type: ignore[reportPrivateUsage]
assert orchestrator._terminated is False # type: ignore[reportPrivateUsage]
assert TrackingManager.reset_calls == 1
async def test_magentic_orchestrator_reset_propagates_manager_on_reset_failure() -> None:
"""Failures in manager.on_reset() must propagate so callers can react."""
from agent_framework_orchestrations._base_group_chat_orchestrator import (
ParticipantRegistry, # type: ignore[reportPrivateUsage]
)
class FailingManager(FakeManager):
@override
def on_reset(self) -> None:
raise RuntimeError("boom")
manager = FailingManager()
orchestrator = MagenticOrchestrator(
manager=manager,
participant_registry=ParticipantRegistry([]),
)
with pytest.raises(RuntimeError, match="boom"):
await orchestrator.reset()
async def test_workflow_reset_resets_magentic_orchestrator_and_manager() -> None:
"""End-to-end: workflow.reset_for_new_run() resets Magentic orchestrator and manager state."""
manager = FakeManager()
manager.task_ledger = _SimpleLedger(
facts=Message("assistant", ["seeded facts"]),
plan=Message("assistant", ["seeded plan"]),
)
workflow = MagenticBuilder(
participants=[StubAgent(manager.next_speaker_name, "first draft")],
manager=manager,
).build()
async for _ in workflow.run("first task", stream=True):
pass
orchestrator = next(e for e in workflow.executors.values() if isinstance(e, MagenticOrchestrator))
assert orchestrator._terminated is True # type: ignore[reportPrivateUsage]
assert orchestrator._task_ledger is not None # type: ignore[reportPrivateUsage]
assert manager.task_ledger is not None
await workflow.reset_for_new_run()
assert orchestrator._magentic_context is None # type: ignore[reportPrivateUsage]
assert orchestrator._task_ledger is None # type: ignore[reportPrivateUsage]
assert orchestrator._progress_ledger is None # type: ignore[reportPrivateUsage]
assert orchestrator._terminated is False # type: ignore[reportPrivateUsage]
# FakeManager.on_reset is the base no-op, but the orchestrator still tolerates that case.
# For the standard manager we exercise full clearing in the dedicated unit test above.
# endregion