mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fix StandardMagenticManager to propagate session to manager agent (#4409)
* Fix #4371: Propagate session to manager agent in StandardMagenticManager StandardMagenticManager._complete() was calling self._agent.run(messages) without passing a session. This caused context providers (e.g. RedisHistoryProvider) configured on the manager agent to silently fail, as each call created a new ephemeral session with a different session_id. Changes: - Create an AgentSession in StandardMagenticManager.__init__() - Pass session=self._session in _complete() calls to agent.run() - Persist/restore the session in checkpoint save/restore methods - Add regression tests for session propagation and checkpoint round-trip Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Add type: ignore[reportPrivateUsage] to private attribute assertions in tests Address PR review feedback: add # type: ignore[reportPrivateUsage] comments to _session attribute accesses in the new regression tests, matching the existing convention used elsewhere in test_magentic.py (e.g., lines 401-406). The @pytest.mark.asyncio decorator is not needed because pyproject.toml sets asyncio_mode = "auto". Fixes #4371 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address review: use getattr for private _session access in tests (#4371) Replace direct mgr._session access with getattr(mgr, "_session") to avoid reportPrivateUsage type-checking warnings without needing type: ignore comments. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Apply pre-commit auto-fixes * Address PR review: fix session restore guard and improve test robustness (#4371) - Use 'is not None' instead of truthiness check for session_payload restore - Use getattr() for private _session attribute access in tests - Add backward-compatibility test for on_checkpoint_restore with empty state Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Make non-async tests plain def to avoid pytest-asyncio dependency (#4409) Tests that never await anything don't need to be async. Using plain def ensures they always run regardless of pytest-asyncio configuration. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Apply pre-commit auto-fixes --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
d5da6e05d8
commit
c5ed8209df
@@ -14,6 +14,7 @@ from typing import Any, ClassVar, TypeVar, cast
|
||||
|
||||
from agent_framework import (
|
||||
AgentResponse,
|
||||
AgentSession,
|
||||
Message,
|
||||
SupportsAgentRun,
|
||||
)
|
||||
@@ -559,6 +560,7 @@ class StandardMagenticManager(MagenticManagerBase):
|
||||
)
|
||||
|
||||
self._agent: SupportsAgentRun = agent
|
||||
self._session: AgentSession = self._agent.create_session()
|
||||
self.task_ledger: _MagenticTaskLedger | None = task_ledger
|
||||
|
||||
# Prompts may be overridden if needed
|
||||
@@ -587,7 +589,7 @@ class StandardMagenticManager(MagenticManagerBase):
|
||||
The agent's run method is called which applies the agent's configured options
|
||||
(temperature, seed, instructions, etc.).
|
||||
"""
|
||||
response: AgentResponse = await self._agent.run(messages)
|
||||
response: AgentResponse = await self._agent.run(messages, session=self._session)
|
||||
if not response.messages:
|
||||
raise RuntimeError("Agent returned no messages in response.")
|
||||
if len(response.messages) > 1:
|
||||
@@ -730,6 +732,7 @@ class StandardMagenticManager(MagenticManagerBase):
|
||||
state: dict[str, Any] = {}
|
||||
if self.task_ledger is not None:
|
||||
state["task_ledger"] = self.task_ledger.to_dict()
|
||||
state["agent_session"] = self._session.to_dict()
|
||||
return state
|
||||
|
||||
@override
|
||||
@@ -740,6 +743,12 @@ class StandardMagenticManager(MagenticManagerBase):
|
||||
self.task_ledger = _MagenticTaskLedger.from_dict(ledger)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
logger.warning("Failed to restore manager task ledger from checkpoint state")
|
||||
session_payload = state.get("agent_session")
|
||||
if session_payload is not None:
|
||||
try:
|
||||
self._session = AgentSession.from_dict(session_payload)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
logger.warning("Failed to restore manager agent session from checkpoint state")
|
||||
|
||||
|
||||
# endregion Magentic Manager
|
||||
|
||||
@@ -1074,4 +1074,71 @@ def test_magentic_agent_factory_with_standard_manager_options():
|
||||
assert manager.final_answer_prompt == custom_final_prompt
|
||||
|
||||
|
||||
async def test_standard_manager_propagates_session_to_agent():
|
||||
"""Verify StandardMagenticManager passes a consistent session to the underlying agent.
|
||||
|
||||
Regression test for #4371: context providers (e.g. RedisHistoryProvider) configured on
|
||||
the manager agent silently failed because no session was propagated.
|
||||
"""
|
||||
captured_sessions: list[AgentSession | None] = []
|
||||
|
||||
class SessionCapturingAgent(BaseAgent):
|
||||
"""Agent that records the session passed to each run() call."""
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: Any = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse] | AsyncIterable[AgentResponseUpdate]:
|
||||
captured_sessions.append(session)
|
||||
|
||||
async def _run() -> AgentResponse:
|
||||
return AgentResponse(messages=[Message("assistant", ["ok"])])
|
||||
|
||||
return _run()
|
||||
|
||||
agent = SessionCapturingAgent()
|
||||
mgr = StandardMagenticManager(agent=agent)
|
||||
ctx = MagenticContext(task="task", participant_descriptions={"a": "desc"})
|
||||
|
||||
await mgr.plan(ctx.clone())
|
||||
|
||||
# plan() calls _complete twice (facts + plan), both should receive the same session
|
||||
assert len(captured_sessions) == 2
|
||||
assert all(s is not None for s in captured_sessions), "session must be passed to agent.run()"
|
||||
assert captured_sessions[0] is captured_sessions[1], "same session instance must be reused across calls"
|
||||
assert captured_sessions[0] is mgr._session
|
||||
|
||||
|
||||
def test_standard_manager_checkpoint_preserves_session():
|
||||
"""Verify that checkpoint save/restore preserves the manager's session identity."""
|
||||
agent = StubManagerAgent()
|
||||
mgr = StandardMagenticManager(agent=agent)
|
||||
original_session_id = mgr._session.session_id
|
||||
|
||||
state = mgr.on_checkpoint_save()
|
||||
assert "agent_session" in state
|
||||
|
||||
# Restore into a fresh manager and verify session_id is preserved
|
||||
mgr2 = StandardMagenticManager(agent=agent)
|
||||
assert mgr2._session.session_id != original_session_id
|
||||
mgr2.on_checkpoint_restore(state)
|
||||
assert mgr2._session.session_id == original_session_id
|
||||
|
||||
|
||||
def test_standard_manager_checkpoint_restore_empty_state():
|
||||
"""Verify that restoring from a state without agent_session leaves the session intact."""
|
||||
agent = StubManagerAgent()
|
||||
mgr = StandardMagenticManager(agent=agent)
|
||||
original_session = mgr._session
|
||||
original_session_id = original_session.session_id
|
||||
|
||||
mgr.on_checkpoint_restore({})
|
||||
assert mgr._session is original_session
|
||||
assert mgr._session.session_id == original_session_id
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
Reference in New Issue
Block a user