mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fix per-service-call history persistence with server-storing clients (#6310)
* Fix per-service-call history persistence with server-storing clients When an Agent set require_per_service_call_history_persistence=True together with a HistoryProvider, and the chat client stored history server-side by default (e.g. OpenAIChatClient, STORES_BY_DEFAULT=True), the external history provider was silently never persisted. Unify persistence on the per-service-call middleware: when the flag is set and a HistoryProvider exists, the middleware is always installed and owns persistence. service_stores_history now only selects middleware behavior: - service does not store: load providers and drive the function loop with a local sentinel conversation id, or - service stores: skip loading (the service owns history) and persist each service call while the real conversation id flows through. Also rationalize chat-options handling in _prepare_run_context: - _merge_options now skips None overrides and strips remaining None values, so an unset `store` is never forwarded and the service decides its own default. - Resolve `store` and `conversation_id` once from a single combined view (effective_options) instead of probing both default and runtime dicts; the auto-injection and per-service-call resolution now agree on conversation_id. Fixes #5798 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Correct as_agent() docstring: persistence is per service call, not once per run Address PR review: when the client stores history server-side, the per-service-call middleware still persists after each model call; only provider loading is skipped. The previous "persist once per run()" wording contradicted the implementation. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address PR review: docs, missing-conversation-id warning, and tests - Clarify that require_per_service_call_history_persistence is a no-op when no HistoryProvider is present (docstrings in _agents.py and _clients.py). - Warn on every service call when the client stores history server-side but returns no conversation_id, so the (uncommon) loss of cross-turn resumability cannot fail silently. - Add tests: storing client + existing conversation_id does not raise and the id propagates; two runs on the same session keep persisting with a stable service_session_id and no provider loading; storing-without-conversation-id warns per call. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
af772997af
commit
7e0767a0a0
@@ -92,12 +92,16 @@ OptionsCoT = TypeVar(
|
||||
def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Merge two options dicts, with override values taking precedence.
|
||||
|
||||
``None`` is treated as "unset": ``None`` overrides are skipped so they don't clobber a base
|
||||
value, and the merged result is stripped of any remaining ``None`` values in a final pass so
|
||||
unset options are never forwarded (e.g. an unset ``store`` is left for the service to default).
|
||||
|
||||
Args:
|
||||
base: The base options dict.
|
||||
override: The override options dict (values take precedence).
|
||||
|
||||
Returns:
|
||||
A new merged options dict.
|
||||
A new merged options dict containing no ``None`` values.
|
||||
"""
|
||||
result = dict(base)
|
||||
|
||||
@@ -123,7 +127,7 @@ def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str,
|
||||
result["instructions"] = f"{result['instructions']}\n{value}"
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
return {key: value for key, value in result.items() if value is not None}
|
||||
|
||||
|
||||
def _sanitize_agent_name(agent_name: str | None) -> str | None:
|
||||
@@ -460,6 +464,9 @@ class BaseAgent(SerializationMixin):
|
||||
if provider_session is None and self.context_providers:
|
||||
provider_session = AgentSession()
|
||||
|
||||
# When per-service-call persistence is enabled, the per-service-call middleware owns
|
||||
# HistoryProvider persistence (in both the local and service-managed cases), so skip
|
||||
# them on the once-per-run path to avoid double persistence.
|
||||
per_service_call_history_required = self.require_per_service_call_history_persistence and any(
|
||||
isinstance(provider, HistoryProvider) for provider in self.context_providers
|
||||
)
|
||||
@@ -686,11 +693,16 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
description: A brief description of the agent's purpose.
|
||||
context_providers: Context providers to include during agent invocation.
|
||||
middleware: List of middleware to intercept agent and function invocations.
|
||||
require_per_service_call_history_persistence: When True, history providers are invoked
|
||||
around each model call instead of once per ``run()`` when the service
|
||||
is not already storing history. If service-side storage is active for
|
||||
the run, the agent skips local history providers and relies on the
|
||||
service-managed conversation instead.
|
||||
require_per_service_call_history_persistence: When True (and a HistoryProvider is
|
||||
present), the provider always persists history via per-service-call middleware,
|
||||
regardless of whether the client stores history server-side. If the client does
|
||||
not store history, the middleware also loads providers around each model call and
|
||||
drives the function loop with a local conversation; if it does, loading is skipped
|
||||
(the service-managed conversation is the source of truth) and the middleware only
|
||||
persists. A warning is logged for providers with ``load_messages=True`` when
|
||||
loading is skipped because service-side storage is active. When no HistoryProvider
|
||||
is present, this flag has no effect (no middleware is installed and nothing is
|
||||
persisted).
|
||||
default_options: A TypedDict containing chat options. When using a typed agent like
|
||||
``Agent[OpenAIChatOptions]``, this enables IDE autocomplete for
|
||||
provider-specific options including temperature, max_tokens, model,
|
||||
@@ -791,22 +803,20 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
self,
|
||||
*,
|
||||
session: AgentSession | None,
|
||||
options: Mapping[str, Any] | None,
|
||||
conversation_id: str | None,
|
||||
service_stores_history: bool,
|
||||
) -> list[HistoryProvider]:
|
||||
history_providers = self._get_history_providers()
|
||||
if not self.require_per_service_call_history_persistence or not history_providers:
|
||||
return []
|
||||
|
||||
conversation_id = (
|
||||
session.service_session_id
|
||||
if session and session.service_session_id
|
||||
else cast(str | None, (options or {}).get("conversation_id") or self.default_options.get("conversation_id"))
|
||||
)
|
||||
if service_stores_history:
|
||||
return []
|
||||
|
||||
if conversation_id is not None:
|
||||
# A live service-managed session id takes precedence over the resolved conversation id.
|
||||
if session and session.service_session_id:
|
||||
conversation_id = session.service_session_id
|
||||
# Without service-side storage the middleware persists locally and drives the function
|
||||
# loop with a local sentinel, which cannot be reconciled with an existing service-managed
|
||||
# conversation. When the service stores history, an existing conversation id is expected.
|
||||
if conversation_id is not None and not service_stores_history:
|
||||
raise AgentInvalidRequestException(
|
||||
"require_per_service_call_history_persistence cannot be used "
|
||||
"with an existing service-managed conversation."
|
||||
@@ -1167,18 +1177,34 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
|
||||
input_messages = normalize_messages(messages)
|
||||
|
||||
# `store` in runtime or agent options takes precedence over client-level storage
|
||||
# indicators. An explicit `store=False` forces local (in-memory) history injection,
|
||||
# even if the client is configured to use service-side storage by default.
|
||||
store_ = opts.get("store", self.default_options.get("store", getattr(self.client, "STORES_BY_DEFAULT", False)))
|
||||
# Combine agent-level defaults with runtime options up front so the decisions below read
|
||||
# `store` from a single place rather than introspecting both dicts. _merge_options applies
|
||||
# the same precedence used for the actual client call (runtime wins; unset/None falls back
|
||||
# to the agent default).
|
||||
effective_options = _merge_options(self.default_options, opts)
|
||||
|
||||
# `store` in runtime or agent options takes precedence over the client's default
|
||||
# storage behavior. An explicit `store=False` forces local (in-memory) history
|
||||
# injection even when the client stores server-side by default; an explicit
|
||||
# `store=True` forces service-side storage. A `store=None`/unset value means the
|
||||
# service falls back to its own default.
|
||||
explicit_store = effective_options.get("store")
|
||||
# Internal behavior hint: will the service own history for this run? Only when the
|
||||
# user left `store` unset do we fall back to the client's STORES_BY_DEFAULT.
|
||||
service_stores_history = (
|
||||
explicit_store if explicit_store is not None else getattr(self.client, "STORES_BY_DEFAULT", False)
|
||||
)
|
||||
# Resolve conversation_id from the same combined view so an agent-level default is honored
|
||||
# when the runtime omits it (a live session id still takes precedence below).
|
||||
effective_conversation_id = effective_options.get("conversation_id")
|
||||
# Auto-inject InMemoryHistoryProvider when session is provided, no context providers
|
||||
# registered, and no service-side storage indicators
|
||||
if (
|
||||
session is not None
|
||||
and not self.context_providers
|
||||
and not session.service_session_id
|
||||
and not opts.get("conversation_id")
|
||||
and not store_
|
||||
and not effective_conversation_id
|
||||
and not service_stores_history
|
||||
):
|
||||
self.context_providers.append(InMemoryHistoryProvider())
|
||||
|
||||
@@ -1188,10 +1214,30 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
|
||||
per_service_call_history_providers = self._resolve_per_service_call_history_providers(
|
||||
session=active_session,
|
||||
options=opts,
|
||||
service_stores_history=bool(store_),
|
||||
conversation_id=effective_conversation_id,
|
||||
service_stores_history=service_stores_history,
|
||||
)
|
||||
|
||||
# When require_per_service_call_history_persistence is set together with a
|
||||
# HistoryProvider, the per-service-call middleware (installed below) always persists
|
||||
# the provider. ``service_stores_history`` only selects how the middleware behaves:
|
||||
# - service does not store: the middleware also loads providers and drives the function
|
||||
# loop with a local sentinel conversation id, or
|
||||
# - service stores: the middleware skips loading (the service owns history) and simply
|
||||
# persists each service call while the real conversation id flows through.
|
||||
# In the service-managed case loading is skipped, so warn for providers that expect to load.
|
||||
history_providers = self._get_history_providers()
|
||||
if self.require_per_service_call_history_persistence and history_providers and service_stores_history:
|
||||
for provider in history_providers:
|
||||
if provider.load_messages:
|
||||
logger.warning(
|
||||
"HistoryProvider '%s' has load_messages=True but the chat client stores history "
|
||||
"server-side; skipping local history load and relying on the service-managed "
|
||||
"conversation. Set store=False to load from the provider, or load_messages=False "
|
||||
"to silence this warning.",
|
||||
provider.source_id,
|
||||
)
|
||||
|
||||
session_context, chat_options = await self._prepare_session_and_messages(
|
||||
session=active_session,
|
||||
input_messages=input_messages,
|
||||
@@ -1265,8 +1311,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
}
|
||||
if model is not None:
|
||||
run_opts["model"] = model
|
||||
# Remove None values and merge with chat_options
|
||||
run_opts = {k: v for k, v in run_opts.items() if v is not None}
|
||||
# _merge_options strips unset (None) options, so e.g. an unset `store` is not forwarded
|
||||
# and the service decides its own default.
|
||||
co = _merge_options(chat_options, run_opts)
|
||||
|
||||
# Build session_messages from session context: context messages + input messages
|
||||
@@ -1280,6 +1326,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
agent=self,
|
||||
session=active_session,
|
||||
providers=per_service_call_history_providers,
|
||||
service_stores_history=service_stores_history,
|
||||
)
|
||||
existing_middleware = effective_client_kwargs.get("middleware")
|
||||
if isinstance(existing_middleware, Sequence) and not isinstance(existing_middleware, (str, bytes)):
|
||||
@@ -1319,7 +1366,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
"input_messages": input_messages,
|
||||
"session_messages": session_messages,
|
||||
"agent_name": agent_name,
|
||||
"suppress_response_id": bool(per_service_call_history_providers),
|
||||
"suppress_response_id": bool(per_service_call_history_providers) and not service_stores_history,
|
||||
"chat_options": co,
|
||||
"compaction_strategy": compaction_strategy or self.compaction_strategy,
|
||||
"tokenizer": tokenizer or self.tokenizer,
|
||||
@@ -1413,11 +1460,15 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
options=options or {},
|
||||
)
|
||||
|
||||
# When per-service-call persistence is enabled, the per-service-call middleware owns
|
||||
# HistoryProvider loading (it loads locally when the service does not store history, or
|
||||
# relies on the service when it does), so skip them on the once-per-run before_run path.
|
||||
per_service_call_history_required = self.require_per_service_call_history_persistence and bool(
|
||||
self._get_history_providers()
|
||||
)
|
||||
|
||||
# Run before_run providers (forward order, skip HistoryProvider when per-service-call persistence owns history)
|
||||
# Run before_run providers (forward order, skip HistoryProvider when per-service-call
|
||||
# persistence owns loading)
|
||||
for provider in self.context_providers:
|
||||
if per_service_call_history_required and isinstance(provider, HistoryProvider):
|
||||
continue
|
||||
|
||||
@@ -604,10 +604,13 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
and dict literals are accepted without specialized option typing.
|
||||
context_providers: Context providers to include during agent invocation.
|
||||
middleware: List of middleware to intercept agent and function invocations.
|
||||
require_per_service_call_history_persistence: Whether to require per-service-call
|
||||
chat history persistence. When enabled, history providers are invoked around
|
||||
each model call instead of once per ``run()`` when the service is not already
|
||||
storing history.
|
||||
require_per_service_call_history_persistence: When enabled (and a HistoryProvider is
|
||||
present), the provider always persists history after each model call. If the
|
||||
client does not store history server-side, history providers are also loaded and
|
||||
injected around each model call; if it does, provider loading is skipped and the
|
||||
service-managed conversation is the source of truth (persistence still happens
|
||||
after each model call). When no HistoryProvider is present, this flag has no
|
||||
effect (no middleware is installed and nothing is persisted).
|
||||
function_invocation_configuration: Optional function invocation configuration override.
|
||||
compaction_strategy: Optional agent-level compaction override. When omitted,
|
||||
client-level compaction defaults remain in effect for each call.
|
||||
|
||||
@@ -16,6 +16,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
import weakref
|
||||
@@ -36,6 +37,8 @@ if TYPE_CHECKING:
|
||||
from ._middleware import MiddlewareTypes
|
||||
|
||||
|
||||
logger = logging.getLogger("agent_framework")
|
||||
|
||||
# Registry of known types for state deserialization
|
||||
_STATE_TYPE_REGISTRY: dict[str, type] = {}
|
||||
|
||||
@@ -580,6 +583,7 @@ class PerServiceCallHistoryPersistingMiddleware(ChatMiddleware):
|
||||
agent: SupportsAgentRun,
|
||||
session: AgentSession,
|
||||
providers: Sequence[HistoryProvider],
|
||||
service_stores_history: bool = False,
|
||||
) -> None:
|
||||
"""Initialize the middleware.
|
||||
|
||||
@@ -587,10 +591,16 @@ class PerServiceCallHistoryPersistingMiddleware(ChatMiddleware):
|
||||
agent: The agent that owns the history providers.
|
||||
session: The active session for the current run.
|
||||
providers: The history providers participating in per-service-call persistence.
|
||||
service_stores_history: When True, the chat client stores history server-side. The
|
||||
middleware then skips loading providers and leaves the real conversation id
|
||||
untouched, persisting each service call without driving the function loop with a
|
||||
local sentinel. When False, the middleware loads providers and uses a local
|
||||
sentinel conversation id so the function loop runs without service-side storage.
|
||||
"""
|
||||
self._agent = agent
|
||||
self._session = session
|
||||
self._providers = list(providers)
|
||||
self._service_stores_history = service_stores_history
|
||||
|
||||
async def _prepare_service_call_context(self, messages: Sequence[Message]) -> SessionContext:
|
||||
"""Create a per-call SessionContext and load history providers into it."""
|
||||
@@ -602,6 +612,9 @@ class PerServiceCallHistoryPersistingMiddleware(ChatMiddleware):
|
||||
)
|
||||
for source_id, source_messages in context_messages.items():
|
||||
service_call_context.extend_messages(source_id, source_messages)
|
||||
# When the service stores history, it owns loading; the providers are write-only sinks.
|
||||
if self._service_stores_history:
|
||||
return service_call_context
|
||||
for provider in self._providers:
|
||||
if not provider.load_messages:
|
||||
continue
|
||||
@@ -652,17 +665,35 @@ class PerServiceCallHistoryPersistingMiddleware(ChatMiddleware):
|
||||
response: ChatResponse,
|
||||
) -> ChatResponse:
|
||||
"""Persist a model response and apply the local follow-up sentinel when needed."""
|
||||
if response.conversation_id is not None and not is_local_history_conversation_id(response.conversation_id):
|
||||
if (
|
||||
not self._service_stores_history
|
||||
and response.conversation_id is not None
|
||||
and not is_local_history_conversation_id(response.conversation_id)
|
||||
):
|
||||
raise ChatClientInvalidResponseException(
|
||||
"require_per_service_call_history_persistence cannot be used "
|
||||
"when the chat client returns a real conversation_id."
|
||||
)
|
||||
|
||||
# In storing mode the service is expected to echo a conversation id that the next run
|
||||
# resumes from. If it comes back empty, the provider still captures this turn but there is
|
||||
# no service id to load from next time, so cross-turn history can be lost silently. Warn
|
||||
# every time so this uncommon, easy-to-miss failure mode cannot fail quietly.
|
||||
if self._service_stores_history and response.conversation_id is None:
|
||||
logger.warning(
|
||||
"require_per_service_call_history_persistence is enabled with a chat client that "
|
||||
"stores history server-side, but the client returned no conversation_id; cross-turn "
|
||||
"history may not resume. Set store=False to load and resume from the HistoryProvider "
|
||||
"instead."
|
||||
)
|
||||
|
||||
await self._persist_service_call_response(
|
||||
service_call_context=service_call_context,
|
||||
response=response,
|
||||
)
|
||||
if _response_contains_follow_up_request(response):
|
||||
# The local sentinel only applies when the service does not store history; when it does,
|
||||
# the real conversation id already drives function-loop continuation.
|
||||
if not self._service_stores_history and _response_contains_follow_up_request(response):
|
||||
response.mark_internal_conversation_id()
|
||||
response.conversation_id = LOCAL_HISTORY_CONVERSATION_ID
|
||||
return response
|
||||
@@ -681,8 +712,12 @@ class PerServiceCallHistoryPersistingMiddleware(ChatMiddleware):
|
||||
result type for streaming or non-streaming execution.
|
||||
"""
|
||||
service_call_context = await self._prepare_service_call_context(context.messages)
|
||||
context.messages = service_call_context.get_messages(include_input=True)
|
||||
self._strip_local_conversation_id(context)
|
||||
# When the service stores history, leave the outgoing messages and the real conversation
|
||||
# id untouched (pass-through); the middleware only persists. Otherwise reconstruct the
|
||||
# outgoing messages from the loaded local history and strip the local sentinel.
|
||||
if not self._service_stores_history:
|
||||
context.messages = service_call_context.get_messages(include_input=True)
|
||||
self._strip_local_conversation_id(context)
|
||||
|
||||
await call_next()
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import contextlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, MutableSequence, Sequence
|
||||
from typing import Any, cast
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
@@ -42,6 +43,8 @@ from agent_framework._mcp import MCPTool, _build_prefixed_mcp_name, _normalize_m
|
||||
from agent_framework._middleware import FunctionInvocationContext
|
||||
from agent_framework.exceptions import AgentInvalidRequestException, ChatClientInvalidResponseException
|
||||
|
||||
from .conftest import MockBaseChatClient
|
||||
|
||||
|
||||
class _FixedTokenizer:
|
||||
def __init__(self, token_count: int) -> None:
|
||||
@@ -609,6 +612,7 @@ async def test_streaming_per_service_call_persistence_hides_response_id_from_aft
|
||||
|
||||
async def test_per_service_call_persistence_uses_real_service_storage_when_client_stores_by_default(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
provider = _RecordingHistoryProvider()
|
||||
|
||||
@@ -649,15 +653,22 @@ async def test_per_service_call_persistence_uses_real_service_storage_when_clien
|
||||
require_per_service_call_history_persistence=True,
|
||||
)
|
||||
|
||||
result = await agent.run("What's the weather in Seattle?", session=session)
|
||||
with caplog.at_level(logging.WARNING, logger="agent_framework"):
|
||||
result = await agent.run("What's the weather in Seattle?", session=session)
|
||||
|
||||
provider_state = session.state[provider.source_id]
|
||||
|
||||
assert result.text == "It is sunny in Seattle."
|
||||
assert result.response_id == "resp_call_2"
|
||||
assert chat_client_base.call_count == 2
|
||||
# The service owns the conversation, so the provider never loads (issue #5798).
|
||||
assert "get_call_count" not in provider_state
|
||||
assert "save_call_count" not in provider_state
|
||||
# Persistence is owned by the per-service-call middleware: it persists once per service call
|
||||
# (issue #5798: the provider must never be silently bypassed when the service stores history).
|
||||
# This run makes two service calls (function call + final answer), so it persists twice.
|
||||
assert provider_state["save_call_count"] == 2
|
||||
# load_messages=True while the service stores history surfaces a warning.
|
||||
assert any("load_messages" in record.message for record in caplog.records)
|
||||
assert session.service_session_id == "resp_service_managed"
|
||||
|
||||
|
||||
@@ -1996,6 +2007,19 @@ def test_merge_options_none_values_ignored():
|
||||
assert result["key2"] == "value2"
|
||||
|
||||
|
||||
def test_merge_options_drops_none_base_values():
|
||||
"""Test _merge_options strips None values so unset options are never forwarded."""
|
||||
base = {"store": None, "temperature": 0.5}
|
||||
override = {"top_p": 0.9}
|
||||
|
||||
result = _merge_options(base, override)
|
||||
|
||||
# An unset base value (e.g. store=None from default_options) must not survive the merge.
|
||||
assert "store" not in result
|
||||
assert result["temperature"] == 0.5
|
||||
assert result["top_p"] == 0.9
|
||||
|
||||
|
||||
def test_merge_options_runtime_model_overrides_default_model() -> None:
|
||||
"""Test _merge_options lets a runtime model override a default model."""
|
||||
result = _merge_options({"model": "default-model"}, {"model": "runtime-model"})
|
||||
@@ -2658,3 +2682,449 @@ async def test_as_tool_raises_on_user_input_request(client: SupportsChatGetRespo
|
||||
assert len(exc_info.value.contents) == 1
|
||||
assert exc_info.value.contents[0].type == "oauth_consent_request"
|
||||
assert exc_info.value.contents[0].consent_link == "https://login.microsoftonline.com/consent"
|
||||
|
||||
|
||||
# region Per-service-call history persistence scenario matrix
|
||||
#
|
||||
# The driving field is ``require_per_service_call_history_persistence``. Every scenario runs a
|
||||
# single agent run that makes **two service calls** -- a function call followed by a final
|
||||
# completion -- so the *timing* of persistence is observable:
|
||||
#
|
||||
# * When the flag is ``True``, the per-service-call middleware persists the provider **after each
|
||||
# service call**. So the function-call turn is already saved by the time the second (final)
|
||||
# service call starts. This holds regardless of whether the chat client stores history
|
||||
# server-side (the bug in issue #5798 was that a storing client silently bypassed persistence).
|
||||
# * When the flag is ``False``, the provider persists **once, at the end of the run** -- nothing is
|
||||
# saved between the two service calls.
|
||||
#
|
||||
# ``SpyChatClient.saves_before_call`` records ``provider.save_calls`` at the start of every service
|
||||
# call, so ``[0, 1]`` means "the function-call turn was persisted before the final call" and
|
||||
# ``[0, 0]`` means "no persistence happened mid-run". The client's ``store`` / ``STORES_BY_DEFAULT``
|
||||
# only selects *how* the middleware behaves -- never *whether* the provider persists.
|
||||
|
||||
_PSC_SERVICE_CONVERSATION_ID = "svc-conversation"
|
||||
|
||||
_psc_stream_params = pytest.mark.parametrize("stream", [False, True], ids=["sync", "stream"])
|
||||
|
||||
|
||||
@tool(name="lookup_weather", approval_mode="never_require")
|
||||
def _psc_lookup_weather(location: str) -> str:
|
||||
return f"Weather in {location}: sunny"
|
||||
|
||||
|
||||
def _psc_function_call_script() -> list[tuple[str, ...]]:
|
||||
"""A fresh function-call-then-final-completion script (the client mutates it)."""
|
||||
return [
|
||||
("call", "call_1", "lookup_weather", '{"location": "Seattle"}'),
|
||||
("text", "It is sunny in Seattle."),
|
||||
]
|
||||
|
||||
|
||||
class _PscSpyHistoryProvider(HistoryProvider):
|
||||
"""In-memory history provider that records load/save calls for assertions."""
|
||||
|
||||
def __init__(self, source_id: str = "spy_history", **kwargs: Any) -> None:
|
||||
super().__init__(source_id, **kwargs)
|
||||
self._messages: list[Message] = []
|
||||
self.get_calls: int = 0
|
||||
self.save_calls: int = 0
|
||||
self.saved_batches: list[list[Message]] = []
|
||||
|
||||
async def get_messages(
|
||||
self, session_id: str | None, *, state: dict[str, Any] | None = None, **kwargs: Any
|
||||
) -> list[Message]:
|
||||
self.get_calls += 1
|
||||
return list(self._messages)
|
||||
|
||||
async def save_messages(
|
||||
self,
|
||||
session_id: str | None,
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
state: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.save_calls += 1
|
||||
self.saved_batches.append(list(messages))
|
||||
self._messages.extend(messages)
|
||||
|
||||
@property
|
||||
def stored_messages(self) -> list[Message]:
|
||||
return list(self._messages)
|
||||
|
||||
|
||||
class _PscSpyChatClient(MockBaseChatClient):
|
||||
"""Chat client that scripts a function-call/final-completion sequence.
|
||||
|
||||
It records, at the start of each service call, how many provider saves have already happened
|
||||
(``saves_before_call``), what messages it received, and what options it saw. When the effective
|
||||
``store`` is truthy it returns a stable ``conversation_id`` to mimic a server-managed
|
||||
conversation, so the framework propagates ``session.service_session_id``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
provider: _PscSpyHistoryProvider,
|
||||
stores_by_default: bool = False,
|
||||
script: list[tuple[str, ...]] | None = None,
|
||||
echo_conversation_id: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.STORES_BY_DEFAULT = stores_by_default # type: ignore[attr-defined]
|
||||
self._provider = provider
|
||||
self._script = list(script) if script is not None else [("text", "ok")]
|
||||
self._echo_conversation_id = echo_conversation_id
|
||||
self.received_messages: list[list[Message]] = []
|
||||
self.received_options: list[dict[str, Any]] = []
|
||||
self.saves_before_call: list[int] = []
|
||||
|
||||
def _effective_store(self, options: dict[str, Any]) -> bool:
|
||||
store = options.get("store")
|
||||
if store is None:
|
||||
return bool(self.STORES_BY_DEFAULT)
|
||||
return bool(store)
|
||||
|
||||
def _next_contents(self) -> list[Content]:
|
||||
turn = self._script.pop(0) if self._script else ("text", "ok")
|
||||
if turn[0] == "call":
|
||||
_, call_id, name, args = turn
|
||||
return [Content.from_function_call(call_id=call_id, name=name, arguments=args)]
|
||||
return [Content.from_text(turn[1])]
|
||||
|
||||
def _inner_get_response( # type: ignore[override]
|
||||
self,
|
||||
*,
|
||||
messages: MutableSequence[Message],
|
||||
stream: bool,
|
||||
options: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
|
||||
self.received_messages.append(list(messages))
|
||||
self.received_options.append(dict(options))
|
||||
self.saves_before_call.append(self._provider.save_calls)
|
||||
store_and_echo = self._effective_store(options) and self._echo_conversation_id
|
||||
conv_id = _PSC_SERVICE_CONVERSATION_ID if store_and_echo else None
|
||||
contents = self._next_contents()
|
||||
|
||||
if stream:
|
||||
|
||||
async def _stream() -> AsyncIterable[ChatResponseUpdate]:
|
||||
self.call_count += 1
|
||||
yield ChatResponseUpdate(
|
||||
contents=contents,
|
||||
role="assistant",
|
||||
finish_reason="stop",
|
||||
conversation_id=conv_id,
|
||||
)
|
||||
|
||||
def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse:
|
||||
response = ChatResponse.from_updates(updates, output_format_type=options.get("response_format"))
|
||||
if conv_id:
|
||||
response.conversation_id = conv_id
|
||||
return response
|
||||
|
||||
return ResponseStream(_stream(), finalizer=_finalize)
|
||||
|
||||
async def _get() -> ChatResponse:
|
||||
self.call_count += 1
|
||||
return ChatResponse(
|
||||
messages=Message(role="assistant", contents=contents),
|
||||
conversation_id=conv_id,
|
||||
)
|
||||
|
||||
return _get()
|
||||
|
||||
|
||||
def _psc_build_agent(
|
||||
client: _PscSpyChatClient,
|
||||
provider: _PscSpyHistoryProvider,
|
||||
*,
|
||||
require_per_service_call_history_persistence: bool,
|
||||
default_options: dict[str, Any] | None = None,
|
||||
) -> Agent:
|
||||
kwargs: dict[str, Any] = {}
|
||||
if default_options is not None:
|
||||
kwargs["default_options"] = default_options
|
||||
return Agent(
|
||||
client=client,
|
||||
tools=[_psc_lookup_weather],
|
||||
context_providers=[provider],
|
||||
require_per_service_call_history_persistence=require_per_service_call_history_persistence,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def _psc_run(agent: Agent, text: str, session: AgentSession, *, stream: bool) -> str:
|
||||
if stream:
|
||||
chunks: list[str] = []
|
||||
async for update in agent.run(text, session=session, stream=True):
|
||||
chunks.append(update.text or "")
|
||||
return "".join(chunks)
|
||||
result = await agent.run(text, session=session)
|
||||
return result.text
|
||||
|
||||
|
||||
# driver=True (the contract under test): persistence happens per service call
|
||||
|
||||
|
||||
@_psc_stream_params
|
||||
async def test_psc_flag_on_store_false_persists_after_each_service_call(stream: bool) -> None:
|
||||
"""Mode A (flag on, service does not store): function-call turn is persisted before the final call."""
|
||||
provider = _PscSpyHistoryProvider()
|
||||
client = _PscSpyChatClient(provider=provider, stores_by_default=False, script=_psc_function_call_script())
|
||||
agent = _psc_build_agent(client, provider, require_per_service_call_history_persistence=True)
|
||||
session = agent.create_session()
|
||||
|
||||
text = await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)
|
||||
|
||||
assert text == "It is sunny in Seattle."
|
||||
# Two service calls: function call, then final completion.
|
||||
assert client.call_count == 2
|
||||
# The contract: the function-call turn was persisted *before* the second service call started.
|
||||
assert client.saves_before_call == [0, 1]
|
||||
assert provider.save_calls == 2
|
||||
# Mode A loads local history (the middleware injects it before each service call).
|
||||
assert provider.get_calls >= 1
|
||||
# No service-side storage, so no conversation id is propagated.
|
||||
assert session.service_session_id is None
|
||||
|
||||
|
||||
@_psc_stream_params
|
||||
async def test_psc_flag_on_stores_by_default_persists_after_each_service_call(
|
||||
stream: bool, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Mode B (flag on, service stores by default): still persists per service call, but skips load (issue #5798)."""
|
||||
provider = _PscSpyHistoryProvider() # load_messages=True by default
|
||||
client = _PscSpyChatClient(provider=provider, stores_by_default=True, script=_psc_function_call_script())
|
||||
agent = _psc_build_agent(client, provider, require_per_service_call_history_persistence=True)
|
||||
session = agent.create_session()
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="agent_framework"):
|
||||
text = await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)
|
||||
|
||||
assert text == "It is sunny in Seattle."
|
||||
assert client.call_count == 2
|
||||
# The invariant the bug violated: persistence still happens per service call when the service stores.
|
||||
assert client.saves_before_call == [0, 1]
|
||||
assert provider.save_calls == 2
|
||||
# The service owns loading, so the provider is never asked to load.
|
||||
assert provider.get_calls == 0
|
||||
# A warning surfaces the bypassed load (load_messages=True).
|
||||
assert any("load_messages" in record.message for record in caplog.records)
|
||||
# The real service conversation id propagates to the session.
|
||||
assert session.service_session_id == _PSC_SERVICE_CONVERSATION_ID
|
||||
|
||||
|
||||
@_psc_stream_params
|
||||
async def test_psc_flag_on_store_only_provider_no_load_no_warning(
|
||||
stream: bool, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Mode B with a store-only provider (load_messages=False): persists per call, no load, no warning."""
|
||||
provider = _PscSpyHistoryProvider(load_messages=False)
|
||||
client = _PscSpyChatClient(provider=provider, stores_by_default=True, script=_psc_function_call_script())
|
||||
agent = _psc_build_agent(client, provider, require_per_service_call_history_persistence=True)
|
||||
session = agent.create_session()
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="agent_framework"):
|
||||
await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)
|
||||
|
||||
assert client.saves_before_call == [0, 1]
|
||||
assert provider.save_calls == 2
|
||||
assert provider.get_calls == 0
|
||||
assert not any("load_messages" in record.message for record in caplog.records)
|
||||
|
||||
|
||||
@_psc_stream_params
|
||||
async def test_psc_flag_on_store_false_override_behaves_as_mode_a(stream: bool) -> None:
|
||||
"""Flag on + storing client but store=False override: falls back to Mode A (local, per call)."""
|
||||
provider = _PscSpyHistoryProvider()
|
||||
client = _PscSpyChatClient(provider=provider, stores_by_default=True, script=_psc_function_call_script())
|
||||
agent = _psc_build_agent(
|
||||
client, provider, require_per_service_call_history_persistence=True, default_options={"store": False}
|
||||
)
|
||||
session = agent.create_session()
|
||||
|
||||
await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)
|
||||
|
||||
assert client.saves_before_call == [0, 1]
|
||||
assert provider.save_calls == 2
|
||||
assert provider.get_calls >= 1
|
||||
# store=False forces local handling, so no real service conversation id.
|
||||
assert session.service_session_id is None
|
||||
|
||||
|
||||
@_psc_stream_params
|
||||
async def test_psc_flag_on_store_none_treated_as_absent(stream: bool, caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""Flag on + storing client + explicit store=None: None is "unset", so the storing default applies (Mode B)."""
|
||||
provider = _PscSpyHistoryProvider()
|
||||
client = _PscSpyChatClient(provider=provider, stores_by_default=True, script=_psc_function_call_script())
|
||||
agent = _psc_build_agent(
|
||||
client, provider, require_per_service_call_history_persistence=True, default_options={"store": None}
|
||||
)
|
||||
session = agent.create_session()
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="agent_framework"):
|
||||
await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)
|
||||
|
||||
assert client.saves_before_call == [0, 1]
|
||||
assert provider.save_calls == 2
|
||||
assert provider.get_calls == 0
|
||||
assert session.service_session_id == _PSC_SERVICE_CONVERSATION_ID
|
||||
assert any("load_messages" in record.message for record in caplog.records)
|
||||
# store=None must not be forwarded to the client; the service decides its own default.
|
||||
assert all("store" not in options for options in client.received_options)
|
||||
|
||||
|
||||
@_psc_stream_params
|
||||
async def test_psc_flag_on_respects_store_outputs_flag(stream: bool) -> None:
|
||||
"""Flag on: the provider's store_inputs/store_outputs flags still apply per service call."""
|
||||
provider = _PscSpyHistoryProvider(store_inputs=True, store_outputs=False)
|
||||
client = _PscSpyChatClient(provider=provider, stores_by_default=True, script=_psc_function_call_script())
|
||||
agent = _psc_build_agent(client, provider, require_per_service_call_history_persistence=True)
|
||||
session = agent.create_session()
|
||||
|
||||
await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)
|
||||
|
||||
assert provider.save_calls == 2
|
||||
# Outputs disabled, so no assistant/tool-call messages were stored, only user/tool inputs.
|
||||
assert provider.stored_messages
|
||||
assert all(message.role != "assistant" for message in provider.stored_messages)
|
||||
|
||||
|
||||
# driver=False (control): persistence happens once, at the end of the run
|
||||
|
||||
|
||||
@_psc_stream_params
|
||||
async def test_psc_flag_off_store_false_persists_once_at_end(stream: bool) -> None:
|
||||
"""Flag off + non-storing client: nothing is persisted mid-run; one save at the end."""
|
||||
provider = _PscSpyHistoryProvider()
|
||||
client = _PscSpyChatClient(provider=provider, stores_by_default=False, script=_psc_function_call_script())
|
||||
agent = _psc_build_agent(client, provider, require_per_service_call_history_persistence=False)
|
||||
session = agent.create_session()
|
||||
|
||||
text = await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)
|
||||
|
||||
assert text == "It is sunny in Seattle."
|
||||
assert client.call_count == 2
|
||||
# The control contract: no save happened between the function call and the final completion.
|
||||
assert client.saves_before_call == [0, 0]
|
||||
assert provider.save_calls == 1
|
||||
|
||||
|
||||
@_psc_stream_params
|
||||
async def test_psc_flag_off_stores_by_default_persists_once_at_end(stream: bool) -> None:
|
||||
"""Flag off + storing client: once-per-run persistence, and the service conversation id propagates."""
|
||||
provider = _PscSpyHistoryProvider()
|
||||
client = _PscSpyChatClient(provider=provider, stores_by_default=True, script=_psc_function_call_script())
|
||||
agent = _psc_build_agent(client, provider, require_per_service_call_history_persistence=False)
|
||||
session = agent.create_session()
|
||||
|
||||
await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)
|
||||
|
||||
assert client.saves_before_call == [0, 0]
|
||||
assert provider.save_calls == 1
|
||||
assert session.service_session_id == _PSC_SERVICE_CONVERSATION_ID
|
||||
|
||||
|
||||
@_psc_stream_params
|
||||
async def test_psc_flag_on_storing_with_existing_conversation_id_does_not_raise(stream: bool) -> None:
|
||||
"""Allow side of the guard: flag on + storing client + an existing conversation_id resumes (no raise).
|
||||
|
||||
The non-storing path raises on an existing service-managed conversation id, but with a storing
|
||||
client the run must proceed and the service conversation id must propagate to the session.
|
||||
"""
|
||||
provider = _PscSpyHistoryProvider()
|
||||
client = _PscSpyChatClient(provider=provider, stores_by_default=True, script=_psc_function_call_script())
|
||||
agent = _psc_build_agent(client, provider, require_per_service_call_history_persistence=True)
|
||||
session = agent.create_session()
|
||||
|
||||
if stream:
|
||||
chunks: list[str] = []
|
||||
async for update in agent.run(
|
||||
"What's the weather in Seattle?",
|
||||
session=session,
|
||||
stream=True,
|
||||
options={"conversation_id": "existing_conversation"},
|
||||
):
|
||||
chunks.append(update.text or "")
|
||||
text = "".join(chunks)
|
||||
else:
|
||||
result = await agent.run(
|
||||
"What's the weather in Seattle?",
|
||||
session=session,
|
||||
options={"conversation_id": "existing_conversation"},
|
||||
)
|
||||
text = result.text
|
||||
|
||||
assert text == "It is sunny in Seattle."
|
||||
# Persistence still happens per service call, and the real service id propagates to the session.
|
||||
assert provider.save_calls == 2
|
||||
assert provider.get_calls == 0
|
||||
assert session.service_session_id == _PSC_SERVICE_CONVERSATION_ID
|
||||
|
||||
|
||||
@_psc_stream_params
|
||||
async def test_psc_flag_on_storing_two_runs_same_session(stream: bool) -> None:
|
||||
"""Storing mode across two runs on one session: persistence keeps happening, id is stable, no load.
|
||||
|
||||
The second run exercises the precedence branch where the session already carries a
|
||||
service_session_id, which must continue to skip provider loading and keep persisting.
|
||||
"""
|
||||
provider = _PscSpyHistoryProvider()
|
||||
client = _PscSpyChatClient(provider=provider, stores_by_default=True, script=_psc_function_call_script())
|
||||
agent = _psc_build_agent(client, provider, require_per_service_call_history_persistence=True)
|
||||
session = agent.create_session()
|
||||
|
||||
await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)
|
||||
|
||||
assert provider.save_calls == 2
|
||||
assert provider.get_calls == 0
|
||||
first_run_service_id = session.service_session_id
|
||||
assert first_run_service_id == _PSC_SERVICE_CONVERSATION_ID
|
||||
|
||||
# Reset the scripted client for a second run on the same session.
|
||||
client._script = _psc_function_call_script()
|
||||
client.call_count = 0
|
||||
client.saves_before_call = []
|
||||
|
||||
await _psc_run(agent, "And in Portland?", session, stream=stream)
|
||||
|
||||
# Persistence keeps happening on the second run (two more saves), still per service call.
|
||||
assert client.saves_before_call == [2, 3]
|
||||
assert provider.save_calls == 4
|
||||
# Loading stays skipped and the service conversation id stays stable across runs.
|
||||
assert provider.get_calls == 0
|
||||
assert session.service_session_id == first_run_service_id
|
||||
|
||||
|
||||
@_psc_stream_params
|
||||
async def test_psc_flag_on_storing_without_conversation_id_warns_every_call(
|
||||
stream: bool, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Storing mode but the client returns no conversation_id: warn on every service call.
|
||||
|
||||
Without an echoed conversation id the next run has nothing to resume from, so cross-turn
|
||||
history can be lost silently. The warning fires per service call (no dedup) so the uncommon
|
||||
failure mode cannot pass unnoticed.
|
||||
"""
|
||||
provider = _PscSpyHistoryProvider()
|
||||
client = _PscSpyChatClient(
|
||||
provider=provider,
|
||||
stores_by_default=True,
|
||||
script=_psc_function_call_script(),
|
||||
echo_conversation_id=False,
|
||||
)
|
||||
agent = _psc_build_agent(client, provider, require_per_service_call_history_persistence=True)
|
||||
session = agent.create_session()
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="agent_framework"):
|
||||
await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)
|
||||
|
||||
# Persistence still happens, but no service id is captured to resume from.
|
||||
assert provider.save_calls == 2
|
||||
assert session.service_session_id is None
|
||||
# Two service calls -> the warning is emitted twice (one per call, not deduped).
|
||||
missing_id_warnings = [r for r in caplog.records if "returned no conversation_id" in r.message]
|
||||
assert len(missing_id_warnings) == 2
|
||||
|
||||
Reference in New Issue
Block a user