Merge branch 'main' into local-branch-python-add-reset-to-workflow

This commit is contained in:
Tao Chen
2026-06-09 09:38:45 -07:00
Unverified
109 changed files with 11534 additions and 4467 deletions
@@ -9,7 +9,7 @@ from typing import Any, cast
from ag_ui.core import BaseEvent
from agent_framework import SupportsAgentRun
from ._agent_run import run_agent_stream
from ._agent_run import PendingApprovalEntry, run_agent_stream
class AgentConfig:
@@ -107,7 +107,7 @@ class AgentFrameworkAgent:
# Populated when approval requests are emitted; consumed when responses arrive.
# Prevents bypass, function name spoofing, and replay attacks.
# Bounded to prevent unbounded growth from abandoned approval requests.
self._pending_approvals: OrderedDict[str, str] = OrderedDict()
self._pending_approvals: OrderedDict[str, PendingApprovalEntry] = OrderedDict()
self._pending_approvals_max_size: int = 10_000
async def run(
@@ -8,7 +8,7 @@ import json
import logging
import uuid
from collections.abc import AsyncIterable, Awaitable
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, TypedDict, cast
from ag_ui.core import (
BaseEvent,
@@ -56,6 +56,7 @@ from ._run_common import (
_stringify_tool_result, # type: ignore
)
from ._utils import (
canonical_function_arguments,
convert_agui_tools_to_agent_framework,
generate_event_id,
get_conversation_id_from_update,
@@ -407,7 +408,33 @@ def _make_approval_tool_result_events(resolved_approval_results: list[Content])
return events
def _evict_oldest_approvals(registry: dict[str, str], max_size: int = 10_000) -> None:
class _PendingApproval(TypedDict):
    """Pending approval details for a requested function call.

    Dict-shaped member of the ``PendingApprovalEntry`` union; built by
    ``_make_pending_approval_entry``.
    """

    # Name of the function whose call is awaiting approval.
    name: str
    # String form of the call's arguments recorded with the request; the type
    # also permits ``None``.
    arguments: str | None
PendingApprovalEntry = _PendingApproval | str
def _make_pending_approval_entry(name: str, arguments: str | None) -> _PendingApproval:
    """Build the dict entry that records a pending approval.

    Args:
        name: Name of the function awaiting approval.
        arguments: String form of the function-call arguments, or ``None``.

    Returns:
        A mapping with exactly the ``"name"`` and ``"arguments"`` keys.
    """
    entry: _PendingApproval = {"name": name, "arguments": arguments}
    return entry
def _pending_approval_name(entry: PendingApprovalEntry) -> str | None:
    """Return the function name recorded for a pending approval.

    A bare ``str`` entry is itself the function name; a dict entry carries the
    name under its ``"name"`` key.
    """
    return entry if isinstance(entry, str) else entry["name"]
def _pending_approval_arguments(entry: PendingApprovalEntry) -> str | None:
    """Return the arguments recorded for a pending approval, if any.

    A bare ``str`` entry holds only a function name, so it yields ``None``;
    a dict entry yields whatever is stored under its ``"arguments"`` key.
    """
    return None if isinstance(entry, str) else entry["arguments"]
def _evict_oldest_approvals(registry: dict[str, PendingApprovalEntry], max_size: int = 10_000) -> None:
"""Evict the oldest entries from the pending-approvals registry (LRU).
Only effective when *registry* is an ``OrderedDict``; plain dicts are
@@ -427,7 +454,7 @@ async def _resolve_approval_responses(
tools: list[Any],
agent: SupportsAgentRun,
run_kwargs: dict[str, Any],
pending_approvals: dict[str, str] | None = None,
pending_approvals: dict[str, PendingApprovalEntry] | None = None,
thread_id: str = "",
) -> list[Content]:
"""Execute approved function calls and replace approval content with results.
@@ -480,7 +507,8 @@ async def _resolve_approval_responses(
invalid_ids.add(resp_id)
continue
pending_name = pending_approvals[registry_key]
pending_entry = pending_approvals[registry_key]
pending_name = _pending_approval_name(pending_entry)
if resp_name != pending_name:
logger.warning(
"Rejected approval response id=%s: function name mismatch (response=%s, pending=%s)",
@@ -491,6 +519,16 @@ async def _resolve_approval_responses(
invalid_ids.add(resp_id)
continue
pending_arguments = _pending_approval_arguments(pending_entry)
response_arguments = canonical_function_arguments(resp.function_call)
if pending_arguments is not None and response_arguments != pending_arguments:
logger.warning(
"Rejected approval response id=%s: function arguments mismatch",
resp_id,
)
invalid_ids.add(resp_id)
continue
# Valid — consume entry to prevent replay
del pending_approvals[registry_key]
if resp.approved:
@@ -714,7 +752,7 @@ async def run_agent_stream(
input_data: dict[str, Any],
agent: SupportsAgentRun,
config: AgentConfig,
pending_approvals: dict[str, str] | None = None,
pending_approvals: dict[str, PendingApprovalEntry] | None = None,
) -> AsyncGenerator[BaseEvent]:
"""Run agent and yield AG-UI events.
@@ -917,7 +955,10 @@ async def run_agent_stream(
# Register pending approval requests so we can validate responses later
if content_type == "function_approval_request" and pending_approvals is not None:
if content.id and content.function_call and content.function_call.name:
pending_approvals[f"{thread_id}:{content.id}"] = content.function_call.name
pending_approvals[f"{thread_id}:{content.id}"] = _make_pending_approval_entry(
content.function_call.name,
canonical_function_arguments(content.function_call),
)
# Evict oldest entries if the registry exceeds a safe bound (LRU)
_evict_oldest_approvals(pending_approvals, max_size=10_000)
else:
@@ -56,6 +56,22 @@ def safe_json_parse(value: Any) -> dict[str, Any] | None:
return None
def canonical_function_arguments(function_call: Any) -> str | None:
    """Return a stable string representation of a function call's arguments.

    Args:
        function_call: Object exposing ``parse_arguments()`` and/or an
            ``arguments`` attribute, or ``None``.

    Returns:
        ``None`` when *function_call* is ``None``; otherwise compact JSON with
        sorted keys. Missing arguments serialize as ``{}``.
    """
    if function_call is None:
        return None

    try:
        arguments = function_call.parse_arguments()
    except Exception:
        # Parsing failed (or is unavailable): fall back to the raw attribute.
        arguments = getattr(function_call, "arguments", None)

    payload = {} if arguments is None else arguments
    return json.dumps(make_json_safe(payload), sort_keys=True, separators=(",", ":"))
def get_role_value(message: Any) -> str:
"""Extract role string from a message object.
@@ -35,7 +35,7 @@ from ._run_common import (
_extract_resume_payload,
_normalize_resume_interrupts,
)
from ._utils import generate_event_id, make_json_safe
from ._utils import canonical_function_arguments, generate_event_id, make_json_safe
logger = logging.getLogger(__name__)
@@ -324,6 +324,29 @@ def _coerce_response_for_request(request_event: Any, value: Any) -> Any | None:
return candidate
def _approval_response_matches_request(request_id: str, request_event: Any, response: Any) -> bool:
    """Tell whether *response* answers the pending approval held by *request_event*.

    Requests whose ``data`` is not a ``function_approval_request`` content are
    not subject to this check and always match. For approval requests the
    response must be a ``function_approval_response`` content with the same id,
    the same function name, and the same canonical arguments.
    """
    pending = getattr(request_event, "data", None)
    is_approval_request = isinstance(pending, Content) and pending.type == "function_approval_request"
    if not is_approval_request:
        return True

    is_approval_response = isinstance(response, Content) and response.type == "function_approval_response"
    if not is_approval_response:
        return False
    if str(getattr(response, "id", "")) != request_id:
        return False

    expected_call = getattr(pending, "function_call", None)
    actual_call = getattr(response, "function_call", None)
    if expected_call is None or actual_call is None:
        return False
    if getattr(actual_call, "name", None) != getattr(expected_call, "name", None):
        return False

    actual_arguments = canonical_function_arguments(actual_call)
    expected_arguments = canonical_function_arguments(expected_call)
    return actual_arguments == expected_arguments
def _single_pending_response_from_value(pending_events: dict[str, Any], value: Any) -> dict[str, Any]:
"""Map a scalar resume payload to the single pending request (if unambiguous)."""
if value is None or len(pending_events) != 1:
@@ -343,6 +366,13 @@ def _single_pending_response_from_value(pending_events: dict[str, Any], value: A
)
return {}
if not _approval_response_matches_request(str(request_id), request_event, coerced_value):
logger.info(
"Ignoring pending request response for request_id=%s: approval response does not match pending request",
request_id,
)
return {}
return {str(request_id): coerced_value}
@@ -372,6 +402,12 @@ def _coerce_responses_for_pending_requests(
_response_type_name(request_event),
)
continue
if not _approval_response_matches_request(request_key, request_event, coerced_value):
logger.info(
"Ignoring resume response for request_id=%s: approval response does not match pending request",
request_key,
)
continue
normalized[request_key] = coerced_value
return normalized
@@ -1407,6 +1407,92 @@ async def test_fabricated_rejection_without_pending_approval_is_blocked(streamin
assert False, "Fabricated rejection response leaked as function_result into LLM messages"
async def test_approval_argument_mismatch_is_blocked(streaming_chat_client_stub):
    """An approval whose arguments differ from the pending call must not run the tool.

    The rejected response must also leave the pending approval registered.
    """
    from agent_framework import tool
    from agent_framework.ag_ui import AgentFrameworkAgent

    executed_args: list[dict[str, Any]] = []

    @tool(
        name="update_record",
        description="Update a record",
        approval_mode="always_require",
    )
    def update_record(record_id: str, value: str) -> str:
        executed_args.append({"record_id": record_id, "value": value})
        return f"updated {record_id} to {value}"

    def build_agent(stream_fn: Any) -> Any:
        # Both turns use the same agent shape; only the streamed reply differs.
        return Agent(
            client=streaming_chat_client_stub(stream_fn),
            name="test_agent",
            instructions="Test",
            tools=[update_record],
        )

    async def stream_fn_approval(
        messages: MutableSequence[Message], options: ChatOptions, **kwargs: Any
    ) -> AsyncIterator[ChatResponseUpdate]:
        requested_call = Content.from_function_call(
            name="update_record",
            call_id="call_update_001",
            arguments={"record_id": "alpha", "value": "approved"},
        )
        yield ChatResponseUpdate(contents=[requested_call])

    wrapper = AgentFrameworkAgent(agent=build_agent(stream_fn_approval))
    thread_id = "thread-argument-mismatch-test"

    # Turn 1: the model asks to call the tool, which registers a pending approval.
    turn1_input: dict[str, Any] = {"thread_id": thread_id, "messages": [{"role": "user", "content": "update"}]}
    events1: list[Any] = [event async for event in wrapper.run(turn1_input)]

    assert any("call_update_001" in k for k in wrapper._pending_approvals)

    async def stream_fn_post(
        messages: MutableSequence[Message], options: ChatOptions, **kwargs: Any
    ) -> AsyncIterator[ChatResponseUpdate]:
        yield ChatResponseUpdate(contents=[Content.from_text(text="Done")])

    wrapper.agent = build_agent(stream_fn_post)

    # Turn 2: approve the same call id and name, but with different arguments.
    tampered_approval: dict[str, Any] = {
        "id": "call_update_001",
        "call_id": "call_update_001",
        "name": "update_record",
        "approved": True,
        "arguments": {"record_id": "beta", "value": "changed"},
    }
    turn2_input: dict[str, Any] = {
        "thread_id": thread_id,
        "messages": [
            {
                "role": "user",
                "content": "approve",
                "function_approvals": [tampered_approval],
            },
        ],
    }
    events2: list[Any] = [event async for event in wrapper.run(turn2_input)]

    assert executed_args == []
    assert any("call_update_001" in k for k in wrapper._pending_approvals), (
        "Pending approval should be preserved after argument mismatch for legitimate retry"
    )
async def test_state_update_end_to_end_via_real_tool_invocation(streaming_chat_client_stub):
"""End-to-end coverage for issue #3167: a real ``@tool`` returning ``state_update`` must
emit a deterministic STATE_SNAPSHOT through the full pipeline.
@@ -1352,6 +1352,70 @@ async def test_workflow_run_approval_via_messages_approved() -> None:
assert not resumed_finished.get("interrupt")
async def test_workflow_run_approval_argument_mismatch_keeps_interrupt_pending() -> None:
    """A workflow approval carrying altered arguments must leave the interrupt pending.

    The response handler must never see the tampered response.
    """
    handled_responses: list[dict[str, Any]] = []

    class ApprovalExecutor(Executor):
        def __init__(self) -> None:
            super().__init__(id="approval_executor")

        @handler
        async def start(self, message: Any, ctx: WorkflowContext) -> None:
            del message
            approval_request = Content.from_function_approval_request(
                id="approval-1",
                function_call=Content.from_function_call(
                    call_id="refund-call",
                    name="submit_refund",
                    arguments={"order_id": "12345", "amount": "$89.99"},
                ),
            )
            await ctx.request_info(approval_request, Content, request_id="approval-1")

        @response_handler
        async def handle_approval(self, original_request: Content, response: Content, ctx: WorkflowContext) -> None:
            del original_request
            approved_call = response.function_call
            if approved_call is not None:
                handled_responses.append(approved_call.parse_arguments() or {})
            await ctx.yield_output("handled")

    def run_finished_payload(events: list[Any]) -> dict[str, Any]:
        # Dump the first RUN_FINISHED event of a run.
        finished = [item for item in events if item.type == "RUN_FINISHED"]
        return finished[0].model_dump()

    workflow = WorkflowBuilder(start_executor=ApprovalExecutor()).build()

    # First run: the executor requests approval, so the run ends with one interrupt.
    first_events: list[Any] = []
    async for event in run_workflow_stream({"messages": [{"role": "user", "content": "go"}]}, workflow):
        first_events.append(event)

    first_finished = run_finished_payload(first_events)
    interrupt_payload = cast(list[dict[str, Any]], first_finished.get("interrupt"))
    assert isinstance(interrupt_payload, list) and len(interrupt_payload) == 1

    # Second run: approve the same request id and name, but with different arguments.
    tampered_approval: dict[str, Any] = {
        "approved": True,
        "id": "approval-1",
        "call_id": "refund-call",
        "name": "submit_refund",
        "arguments": {"order_id": "99999", "amount": "$1000.00"},
    }
    resume_input: dict[str, Any] = {
        "messages": [
            {
                "role": "user",
                "content": "",
                "function_approvals": [tampered_approval],
            }
        ],
    }
    resumed_events: list[Any] = []
    async for event in run_workflow_stream(resume_input, workflow):
        resumed_events.append(event)

    assert handled_responses == []
    resumed_finished = run_finished_payload(resumed_events)
    assert resumed_finished.get("interrupt")
async def test_workflow_run_approval_via_messages_denied() -> None:
"""Denied approval response sent via messages (function_approvals) should satisfy the pending request."""
@@ -221,9 +221,31 @@ class ClaudeAgentOptions(TypedDict, total=False):
thinking: ThinkingConfig
"""Extended thinking configuration (adaptive, enabled, or disabled)."""
effort: Literal["low", "medium", "high", "max"]
effort: Literal["low", "medium", "high", "xhigh", "max"]
"""Effort level for thinking depth."""
skills: list[str] | Literal["all"]
"""Skills to enable for the main session. Use ``"all"`` for every discovered skill,
a list of named skills, or ``[]`` to suppress all skills."""
session_id: str
"""Use a specific session ID (must be a valid UUID) instead of auto-generated."""
task_budget: dict[str, int]
"""API-side task budget in tokens for pacing tool use."""
include_hook_events: bool
"""When True, hook lifecycle events are emitted in the message stream."""
strict_mcp_config: bool
"""When True, only use MCP servers passed via ``mcp_servers``, ignoring all others."""
continue_conversation: bool
"""Continue the most recent conversation instead of starting a new one."""
fork_session: bool
"""When True, resumed sessions fork to a new session ID."""
on_function_approval: FunctionApprovalCallback
"""Approval callback for ``FunctionTool`` instances declared with
``approval_mode="always_require"``. The callback is awaited (sync or async)
+1 -1
View File
@@ -24,7 +24,7 @@ classifiers = [
]
dependencies = [
"agent-framework-core>=1.6.0,<2",
"claude-agent-sdk>=0.1.36,<0.1.49",
"claude-agent-sdk>=0.1.36,<0.3",
]
[tool.uv]
+2
View File
@@ -80,6 +80,8 @@ agent_framework/
- **`MCPTool`** - Base wrapper that owns the MCP `ClientSession` and exposes the remote server's tools as `FunctionTool`s.
- **`MCPStdioTool`** / **`MCPStreamableHTTPTool`** / **`MCPWebsocketTool`** - Transport-specific subclasses.
- **Argument allowlist (`_prepare_call_kwargs`)** - Before each `tools/call`, kwargs are filtered to an **allowlist** built from the tool's declared parameters (`inputSchema.properties`) plus any user-configured extras. Framework runtime kwargs injected through the function-invocation pipeline (e.g. `thread`, `conversation_id`, `chat_options`, `options`, `response_format`) are stripped by default rather than forwarded. A tool that declares no usable `properties` (including schemas with `additionalProperties: true`) forwards only the configured extras. The `_MCP_FRAMEWORK_DENYLIST` is a safety net for framework-named params a server *declares* in its schema (those are dropped); names explicitly opted in via `additional_tool_argument_names` always win. The reserved `_meta` key is extracted as MCP request metadata, never forwarded as an argument.
- **`additional_tool_argument_names`** (constructor arg on all `MCPTool` subclasses) - Opt extra argument names back into the allowlist. Accepts a `Sequence[str]` (applied to every tool) or a `Mapping[str, Sequence[str]]` keyed by **remote tool name**, where the reserved key `"*"` denotes global extras. It is configured only in user code at construction; there is **no per-call/runtime override**, so a model-issued tool call cannot change which names pass through. To use a server that accepts `additionalProperties: true`, list the extra names here and then either (1) manually extend that tool's `inputSchema` (via the `.functions` list after connecting) so the model is prompted to supply them, or (2) supply the values yourself via `function_invocation_kwargs`. If a name is supplied by both the model and `function_invocation_kwargs`, the model-supplied value wins.
- **`MCPTaskOptions`** (experimental, `MCP_LONG_RUNNING_TASKS` feature, **frozen**) - Per-tool-instance options controlling the SEP-2663 long-running task lifecycle. When the server advertises a tool with `execution.taskSupport == "required"`, `MCPTool.call_tool` transparently routes through `call_tool_as_task`, which sends an augmented `tools/call`, polls `tasks/get` until terminal, and reinterprets `tasks/result` as a normal `CallToolResult`. Instances are immutable; replace via `MCPTool.task_options = MCPTaskOptions(...)`. Fields:
- `default_ttl: timedelta | None` — forwarded to the server as `params.task.ttl` (milliseconds). When `None`, the server's default applies.
- `cancel_remote_task_on_local_cancellation: bool = True` — only gates the `CancelledError` path. Abandonment paths (see below) always cancel.
+80 -29
View File
@@ -92,12 +92,16 @@ OptionsCoT = TypeVar(
def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
"""Merge two options dicts, with override values taking precedence.
``None`` is treated as "unset": ``None`` overrides are skipped so they don't clobber a base
value, and the merged result is stripped of any remaining ``None`` values in a final pass so
unset options are never forwarded (e.g. an unset ``store`` is left for the service to default).
Args:
base: The base options dict.
override: The override options dict (values take precedence).
Returns:
A new merged options dict.
A new merged options dict containing no ``None`` values.
"""
result = dict(base)
@@ -123,7 +127,7 @@ def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str,
result["instructions"] = f"{result['instructions']}\n{value}"
else:
result[key] = value
return result
return {key: value for key, value in result.items() if value is not None}
def _sanitize_agent_name(agent_name: str | None) -> str | None:
@@ -460,6 +464,9 @@ class BaseAgent(SerializationMixin):
if provider_session is None and self.context_providers:
provider_session = AgentSession()
# When per-service-call persistence is enabled, the per-service-call middleware owns
# HistoryProvider persistence (in both the local and service-managed cases), so skip
# them on the once-per-run path to avoid double persistence.
per_service_call_history_required = self.require_per_service_call_history_persistence and any(
isinstance(provider, HistoryProvider) for provider in self.context_providers
)
@@ -686,11 +693,16 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
description: A brief description of the agent's purpose.
context_providers: Context providers to include during agent invocation.
middleware: List of middleware to intercept agent and function invocations.
require_per_service_call_history_persistence: When True, history providers are invoked
around each model call instead of once per ``run()`` when the service
is not already storing history. If service-side storage is active for
the run, the agent skips local history providers and relies on the
service-managed conversation instead.
require_per_service_call_history_persistence: When True (and a HistoryProvider is
present), the provider always persists history via per-service-call middleware,
regardless of whether the client stores history server-side. If the client does
not store history, the middleware also loads providers around each model call and
drives the function loop with a local conversation; if it does, loading is skipped
(the service-managed conversation is the source of truth) and the middleware only
persists. A warning is logged for providers with ``load_messages=True`` when
loading is skipped because service-side storage is active. When no HistoryProvider
is present, this flag has no effect (no middleware is installed and nothing is
persisted).
default_options: A TypedDict containing chat options. When using a typed agent like
``Agent[OpenAIChatOptions]``, this enables IDE autocomplete for
provider-specific options including temperature, max_tokens, model,
@@ -791,22 +803,20 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
self,
*,
session: AgentSession | None,
options: Mapping[str, Any] | None,
conversation_id: str | None,
service_stores_history: bool,
) -> list[HistoryProvider]:
history_providers = self._get_history_providers()
if not self.require_per_service_call_history_persistence or not history_providers:
return []
conversation_id = (
session.service_session_id
if session and session.service_session_id
else cast(str | None, (options or {}).get("conversation_id") or self.default_options.get("conversation_id"))
)
if service_stores_history:
return []
if conversation_id is not None:
# A live service-managed session id takes precedence over the resolved conversation id.
if session and session.service_session_id:
conversation_id = session.service_session_id
# Without service-side storage the middleware persists locally and drives the function
# loop with a local sentinel, which cannot be reconciled with an existing service-managed
# conversation. When the service stores history, an existing conversation id is expected.
if conversation_id is not None and not service_stores_history:
raise AgentInvalidRequestException(
"require_per_service_call_history_persistence cannot be used "
"with an existing service-managed conversation."
@@ -1167,18 +1177,34 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
input_messages = normalize_messages(messages)
# `store` in runtime or agent options takes precedence over client-level storage
# indicators. An explicit `store=False` forces local (in-memory) history injection,
# even if the client is configured to use service-side storage by default.
store_ = opts.get("store", self.default_options.get("store", getattr(self.client, "STORES_BY_DEFAULT", False)))
# Combine agent-level defaults with runtime options up front so the decisions below read
# `store` from a single place rather than introspecting both dicts. _merge_options applies
# the same precedence used for the actual client call (runtime wins; unset/None falls back
# to the agent default).
effective_options = _merge_options(self.default_options, opts)
# `store` in runtime or agent options takes precedence over the client's default
# storage behavior. An explicit `store=False` forces local (in-memory) history
# injection even when the client stores server-side by default; an explicit
# `store=True` forces service-side storage. A `store=None`/unset value means the
# service falls back to its own default.
explicit_store = effective_options.get("store")
# Internal behavior hint: will the service own history for this run? Only when the
# user left `store` unset do we fall back to the client's STORES_BY_DEFAULT.
service_stores_history = (
explicit_store if explicit_store is not None else getattr(self.client, "STORES_BY_DEFAULT", False)
)
# Resolve conversation_id from the same combined view so an agent-level default is honored
# when the runtime omits it (a live session id still takes precedence below).
effective_conversation_id = effective_options.get("conversation_id")
# Auto-inject InMemoryHistoryProvider when session is provided, no context providers
# registered, and no service-side storage indicators
if (
session is not None
and not self.context_providers
and not session.service_session_id
and not opts.get("conversation_id")
and not store_
and not effective_conversation_id
and not service_stores_history
):
self.context_providers.append(InMemoryHistoryProvider())
@@ -1188,10 +1214,30 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
per_service_call_history_providers = self._resolve_per_service_call_history_providers(
session=active_session,
options=opts,
service_stores_history=bool(store_),
conversation_id=effective_conversation_id,
service_stores_history=service_stores_history,
)
# When require_per_service_call_history_persistence is set together with a
# HistoryProvider, the per-service-call middleware (installed below) always persists
# the provider. ``service_stores_history`` only selects how the middleware behaves:
# - service does not store: the middleware also loads providers and drives the function
# loop with a local sentinel conversation id, or
# - service stores: the middleware skips loading (the service owns history) and simply
# persists each service call while the real conversation id flows through.
# In the service-managed case loading is skipped, so warn for providers that expect to load.
history_providers = self._get_history_providers()
if self.require_per_service_call_history_persistence and history_providers and service_stores_history:
for provider in history_providers:
if provider.load_messages:
logger.warning(
"HistoryProvider '%s' has load_messages=True but the chat client stores history "
"server-side; skipping local history load and relying on the service-managed "
"conversation. Set store=False to load from the provider, or load_messages=False "
"to silence this warning.",
provider.source_id,
)
session_context, chat_options = await self._prepare_session_and_messages(
session=active_session,
input_messages=input_messages,
@@ -1265,8 +1311,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
}
if model is not None:
run_opts["model"] = model
# Remove None values and merge with chat_options
run_opts = {k: v for k, v in run_opts.items() if v is not None}
# _merge_options strips unset (None) options, so e.g. an unset `store` is not forwarded
# and the service decides its own default.
co = _merge_options(chat_options, run_opts)
# Build session_messages from session context: context messages + input messages
@@ -1280,6 +1326,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
agent=self,
session=active_session,
providers=per_service_call_history_providers,
service_stores_history=service_stores_history,
)
existing_middleware = effective_client_kwargs.get("middleware")
if isinstance(existing_middleware, Sequence) and not isinstance(existing_middleware, (str, bytes)):
@@ -1319,7 +1366,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
"input_messages": input_messages,
"session_messages": session_messages,
"agent_name": agent_name,
"suppress_response_id": bool(per_service_call_history_providers),
"suppress_response_id": bool(per_service_call_history_providers) and not service_stores_history,
"chat_options": co,
"compaction_strategy": compaction_strategy or self.compaction_strategy,
"tokenizer": tokenizer or self.tokenizer,
@@ -1413,11 +1460,15 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
options=options or {},
)
# When per-service-call persistence is enabled, the per-service-call middleware owns
# HistoryProvider loading (it loads locally when the service does not store history, or
# relies on the service when it does), so skip them on the once-per-run before_run path.
per_service_call_history_required = self.require_per_service_call_history_persistence and bool(
self._get_history_providers()
)
# Run before_run providers (forward order, skip HistoryProvider when per-service-call persistence owns history)
# Run before_run providers (forward order, skip HistoryProvider when per-service-call
# persistence owns loading)
for provider in self.context_providers:
if per_service_call_history_required and isinstance(provider, HistoryProvider):
continue
@@ -604,10 +604,13 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
and dict literals are accepted without specialized option typing.
context_providers: Context providers to include during agent invocation.
middleware: List of middleware to intercept agent and function invocations.
require_per_service_call_history_persistence: Whether to require per-service-call
chat history persistence. When enabled, history providers are invoked around
each model call instead of once per ``run()`` when the service is not already
storing history.
require_per_service_call_history_persistence: When enabled (and a HistoryProvider is
present), the provider always persists history after each model call. If the
client does not store history server-side, history providers are also loaded and
injected around each model call; if it does, provider loading is skipped and the
service-managed conversation is the source of truth (persistence still happens
after each model call). When no HistoryProvider is present, this flag has no
effect (no middleware is installed and nothing is persisted).
function_invocation_configuration: Optional function invocation configuration override.
compaction_strategy: Optional agent-level compaction override. When omitted,
client-level compaction defaults remain in effect for each call.
+154 -27
View File
@@ -70,6 +70,31 @@ class MCPSpecificApproval(TypedDict, total=False):
_MCP_REMOTE_NAME_KEY = "_mcp_remote_name"
_MCP_NORMALIZED_NAME_KEY = "_mcp_normalized_name"
# Reserved key in an ``additional_tool_argument_names`` mapping that applies its
# values to every tool on the server rather than a single named tool.
_MCP_GLOBAL_EXTRA_ARGS_KEY = "*"
# Framework kwargs that flow through the function-invocation pipeline (via
# ``FunctionInvocationContext.kwargs``) but must not be forwarded to an MCP
# server by default: they are framework-internal values, several of which the
# MCP SDK cannot serialize. They are dropped as a safety net when a tool
# declares one of them in its schema, unless the user explicitly opts the name
# back in via ``additional_tool_argument_names`` (explicit extras always win
# over the denylist).
# - chat_options/tools/tool_choice/session/thread: framework runtime objects.
# - conversation_id: internal tracking ID used by services like Azure AI.
# - options: metadata/store used by AG-UI for Azure AI client requirements.
# - response_format: a Pydantic model class for structured output (not serializable).
# - _meta: reserved key extracted separately as MCP request metadata.
_MCP_FRAMEWORK_DENYLIST: frozenset[str] = frozenset({
"chat_options",
"tools",
"tool_choice",
"session",
"thread",
"conversation_id",
"options",
"response_format",
"_meta",
})
_mcp_call_headers: contextvars.ContextVar[dict[str, str]] = contextvars.ContextVar("_mcp_call_headers")
MCP_DEFAULT_TIMEOUT = 30
MCP_DEFAULT_SSE_READ_TIMEOUT = 60 * 5
@@ -135,6 +160,34 @@ def _build_prefixed_mcp_name(
return f"{normalized_prefix}_{trimmed_name}" if trimmed_name else normalized_prefix
def _normalize_additional_tool_argument_names(
    additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None,
) -> tuple[set[str], dict[str, set[str]]]:
    """Separate configured extra argument names into global and per-tool groups.

    Args:
        additional_tool_argument_names: ``None``, a single name, a sequence of
            names (applied to every tool), or a mapping keyed by remote tool
            name whose values are a sequence of names or a single name. The
            reserved mapping key ``"*"`` contributes to the global group.

    Returns:
        A ``(global_extras, per_tool_extras)`` tuple.
    """
    spec = additional_tool_argument_names
    if spec is None:
        return set(), {}
    if isinstance(spec, str):
        # A lone string is one name, not an iterable of characters.
        return {spec}, {}
    if not isinstance(spec, Mapping):
        return set(spec), {}

    shared: set[str] = set()
    scoped: dict[str, set[str]] = {}
    for remote_name, raw_names in spec.items():
        # The same single-string rule applies to mapping values.
        extras = {raw_names} if isinstance(raw_names, str) else set(raw_names)
        if remote_name == _MCP_GLOBAL_EXTRA_ARGS_KEY:
            shared |= extras
        else:
            scoped[remote_name] = extras
    return shared, scoped
def _inject_otel_into_mcp_meta(meta: dict[str, Any] | None = None) -> dict[str, Any] | None:
"""Inject OpenTelemetry trace context into MCP request _meta via the global propagator(s)."""
carrier: dict[str, str] = {}
@@ -294,6 +347,7 @@ class MCPTool:
client: SupportsChatGetResponse | None = None,
additional_properties: dict[str, Any] | None = None,
task_options: MCPTaskOptions | None = None,
additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None,
) -> None:
"""Initialize the MCP Tool base.
@@ -328,6 +382,10 @@ class MCPTool:
task_options: Options controlling how long-running MCP tasks are driven for
tools that advertise ``execution.taskSupport == "required"``. When ``None``,
the defaults from :class:`MCPTaskOptions` are used.
additional_tool_argument_names: Extra argument names to forward to the MCP server
in addition to each tool's declared parameters. A ``Sequence[str]`` applies to
every tool; a ``Mapping[str, Sequence[str]]`` is keyed by remote tool name with
``"*"`` as a global key. See the transport subclasses for full details.
"""
self.name = name
self.description = description or ""
@@ -355,6 +413,10 @@ class MCPTool:
self._functions: list[FunctionTool] = []
self._tool_call_meta_by_name: dict[str, dict[str, Any]] = {}
self._tool_task_support_by_name: dict[str, str] = {}
self._tool_param_names_by_name: dict[str, set[str]] = {}
self._global_extra_arg_names, self._tool_extra_arg_names = _normalize_additional_tool_argument_names(
additional_tool_argument_names
)
self.is_connected: bool = False
self._tools_loaded: bool = False
self._prompts_loaded: bool = False
@@ -1229,6 +1291,7 @@ class MCPTool:
existing_names = {func.name for func in self._functions}
tool_call_meta_by_name: dict[str, dict[str, Any]] = {}
tool_task_support_by_name: dict[str, str] = {}
tool_param_names_by_name: dict[str, set[str]] = {}
params: types.PaginatedRequestParams | None = None
while True:
@@ -1271,14 +1334,6 @@ class MCPTool:
if task_support is not None:
tool_task_support_by_name[tool.name] = task_support
normalized_name = _normalize_mcp_name(tool.name)
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
# Skip if already loaded
if local_name in existing_names:
continue
approval_mode = self._determine_approval_mode(local_name, normalized_name, tool.name)
# Normalize inputSchema: ensure "properties" exists for object schemas.
# Some MCP servers (e.g. zero-argument tools) omit "properties",
# which causes OpenAI API to reject the schema with a 400 error.
@@ -1288,6 +1343,24 @@ class MCPTool:
if input_schema.get("type") == "object" and "properties" not in input_schema:
input_schema["properties"] = {}
# Register declared param names before the existing-tool skip below so that
# reloads (e.g. notifications/tools/list_changed) preserve the allowlist for
# tools that are already loaded, consistent with tool_call_meta_by_name and
# tool_task_support_by_name above.
schema_properties = input_schema.get("properties")
tool_param_names_by_name[tool.name] = (
set(cast(dict[str, Any], schema_properties)) if isinstance(schema_properties, dict) else set()
)
normalized_name = _normalize_mcp_name(tool.name)
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
# Skip if already loaded
if local_name in existing_names:
continue
approval_mode = self._determine_approval_mode(local_name, normalized_name, tool.name)
async def _call_tool_with_runtime_kwargs(
ctx: FunctionInvocationContext,
*,
@@ -1320,6 +1393,7 @@ class MCPTool:
self._tool_call_meta_by_name = tool_call_meta_by_name
self._tool_task_support_by_name = tool_task_support_by_name
self._tool_param_names_by_name = tool_param_names_by_name
async def _close_on_owner(self) -> None:
# Cancel any pending reload tasks before tearing down the session.
@@ -1530,10 +1604,14 @@ class MCPTool:
raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex
raise ToolExecutionException(f"Failed to call tool '{tool_name}' after retries.")
def _resolved_extra_args(self, tool_name: str) -> set[str]:
"""Return the user-configured extra argument names allowed for a tool."""
return self._global_extra_arg_names | self._tool_extra_arg_names.get(tool_name, set())
def _prepare_call_kwargs(
self, tool_name: str, kwargs: dict[str, Any]
) -> tuple[dict[str, Any], dict[str, Any] | None]:
"""Filter framework-only kwargs and build the merged MCP request metadata."""
"""Filter kwargs down to the tool's arguments and build the merged MCP request metadata."""
raw_user_meta: object | None = kwargs.get("_meta")
user_meta: dict[str, Any] | None = None
if raw_user_meta is not None and not isinstance(raw_user_meta, dict):
@@ -1546,27 +1624,28 @@ class MCPTool:
raise ToolExecutionException("MCP tool metadata provided via _meta must use string keys.")
user_meta[key] = value
# Filter out framework kwargs that cannot be serialized by the MCP SDK.
# These are internal objects passed through the function invocation pipeline
# that should not be forwarded to external MCP servers.
# conversation_id is an internal tracking ID used by services like Azure AI.
# options contains metadata/store used by AG-UI for Azure AI client requirements.
# response_format is a Pydantic model class used for structured output (not serializable).
# Allowlist: forward only the tool's declared parameters (from inputSchema.properties)
# plus any user-configured extra argument names. Everything else - notably the
# framework runtime kwargs injected through the function-invocation pipeline - is
# stripped so it is never forwarded to the MCP server. Tools that declare no usable
# properties forward only the user-configured extras.
#
# The extra names come exclusively from additional_tool_argument_names, which is set in
# user code at construction time; there is no per-call override, so a model-issued tool
# call cannot change which names are allowed through.
#
# The framework denylist acts as a safety net for keys a server *declares* in its
# schema that collide with internal, non-serializable framework objects (e.g. a tool
# that declares a parameter literally named "thread"): such declared-but-denylisted
# keys are dropped. Names the user explicitly opts in via additional_tool_argument_names
# always win. The reserved _meta key is handled separately above and never forwarded as
# an argument.
declared = self._tool_param_names_by_name.get(tool_name, set())
extras = self._resolved_extra_args(tool_name)
filtered_kwargs = {
k: v
for k, v in kwargs.items()
if k
not in {
"chat_options",
"tools",
"tool_choice",
"session",
"thread",
"conversation_id",
"options",
"response_format",
"_meta",
}
if k != "_meta" and (k in extras or (k in declared and k not in _MCP_FRAMEWORK_DENYLIST))
}
# Some MCP proxies require their tools/list metadata to be echoed on tools/call.
@@ -2142,6 +2221,7 @@ class MCPStdioTool(MCPTool):
client: SupportsChatGetResponse | None = None,
additional_properties: dict[str, Any] | None = None,
task_options: MCPTaskOptions | None = None,
additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None,
**kwargs: Any,
) -> None:
"""Initialize the MCP stdio tool.
@@ -2188,6 +2268,20 @@ class MCPStdioTool(MCPTool):
client: The chat client to use for sampling.
task_options: Options for tools that advertise
``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`.
additional_tool_argument_names: Extra argument names to forward to the MCP server in
addition to each tool's declared parameters (from its ``inputSchema.properties``).
By default only declared parameters are sent; framework runtime kwargs injected
through the function-invocation pipeline are stripped. Use this to opt specific
keys back in. Accepts either a ``Sequence[str]`` applied to every tool, or a
``Mapping[str, Sequence[str]]`` keyed by remote tool name where the reserved key
``"*"`` applies to every tool. This is configured only here in user code; there is
no per-call override, so a model-issued tool call cannot change which names pass
through. To use a server that accepts ``additionalProperties: true``, list the
extra names here and then either (1) manually extend that tool's ``inputSchema``
(via the ``.functions`` list after connecting) so the model is prompted to supply
them, or (2) supply the values yourself through ``function_invocation_kwargs``. If
a name is supplied via both the model and ``function_invocation_kwargs``, the
model-supplied value wins.
kwargs: Any extra arguments to pass to the stdio client.
"""
super().__init__(
@@ -2205,6 +2299,7 @@ class MCPStdioTool(MCPTool):
parse_prompt_results=parse_prompt_results,
request_timeout=request_timeout,
task_options=task_options,
additional_tool_argument_names=additional_tool_argument_names,
)
self.command = command
self.args = args or []
@@ -2284,6 +2379,7 @@ class MCPStreamableHTTPTool(MCPTool):
http_client: AsyncClient | None = None,
header_provider: Callable[[dict[str, Any]], dict[str, str]] | None = None,
task_options: MCPTaskOptions | None = None,
additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None,
**kwargs: Any,
) -> None:
"""Initialize the MCP streamable HTTP tool.
@@ -2338,6 +2434,20 @@ class MCPStreamableHTTPTool(MCPTool):
agent middleware) without creating a separate ``httpx.AsyncClient``.
task_options: Options for tools that advertise
``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`.
additional_tool_argument_names: Extra argument names to forward to the MCP server in
addition to each tool's declared parameters (from its ``inputSchema.properties``).
By default only declared parameters are sent; framework runtime kwargs injected
through the function-invocation pipeline are stripped. Use this to opt specific
keys back in. Accepts either a ``Sequence[str]`` applied to every tool, or a
``Mapping[str, Sequence[str]]`` keyed by remote tool name where the reserved key
``"*"`` applies to every tool. This is configured only here in user code; there is
no per-call override, so a model-issued tool call cannot change which names pass
through. To use a server that accepts ``additionalProperties: true``, list the
extra names here and then either (1) manually extend that tool's ``inputSchema``
(via the ``.functions`` list after connecting) so the model is prompted to supply
them, or (2) supply the values yourself through ``function_invocation_kwargs``. If
a name is supplied via both the model and ``function_invocation_kwargs``, the
model-supplied value wins.
kwargs: Additional keyword arguments (accepted for backward compatibility but not used).
"""
super().__init__(
@@ -2355,6 +2465,7 @@ class MCPStreamableHTTPTool(MCPTool):
parse_prompt_results=parse_prompt_results,
request_timeout=request_timeout,
task_options=task_options,
additional_tool_argument_names=additional_tool_argument_names,
)
self.url = url
self.terminate_on_close = terminate_on_close
@@ -2481,6 +2592,7 @@ class MCPWebsocketTool(MCPTool):
client: SupportsChatGetResponse | None = None,
additional_properties: dict[str, Any] | None = None,
task_options: MCPTaskOptions | None = None,
additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None,
**kwargs: Any,
) -> None:
"""Initialize the MCP WebSocket tool.
@@ -2525,6 +2637,20 @@ class MCPWebsocketTool(MCPTool):
client: The chat client to use for sampling.
task_options: Options for tools that advertise
``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`.
additional_tool_argument_names: Extra argument names to forward to the MCP server in
addition to each tool's declared parameters (from its ``inputSchema.properties``).
By default only declared parameters are sent; framework runtime kwargs injected
through the function-invocation pipeline are stripped. Use this to opt specific
keys back in. Accepts either a ``Sequence[str]`` applied to every tool, or a
``Mapping[str, Sequence[str]]`` keyed by remote tool name where the reserved key
``"*"`` applies to every tool. This is configured only here in user code; there is
no per-call override, so a model-issued tool call cannot change which names pass
through. To use a server that accepts ``additionalProperties: true``, list the
extra names here and then either (1) manually extend that tool's ``inputSchema``
(via the ``.functions`` list after connecting) so the model is prompted to supply
them, or (2) supply the values yourself through ``function_invocation_kwargs``. If
a name is supplied via both the model and ``function_invocation_kwargs``, the
model-supplied value wins.
kwargs: Any extra arguments to pass to the WebSocket client.
"""
super().__init__(
@@ -2542,6 +2668,7 @@ class MCPWebsocketTool(MCPTool):
parse_prompt_results=parse_prompt_results,
request_timeout=request_timeout,
task_options=task_options,
additional_tool_argument_names=additional_tool_argument_names,
)
self.url = url
self._client_kwargs = kwargs
@@ -16,6 +16,7 @@ from __future__ import annotations
import asyncio
import copy
import json
import logging
import threading
import uuid
import weakref
@@ -36,6 +37,8 @@ if TYPE_CHECKING:
from ._middleware import MiddlewareTypes
logger = logging.getLogger("agent_framework")
# Registry of known types for state deserialization
_STATE_TYPE_REGISTRY: dict[str, type] = {}
@@ -580,6 +583,7 @@ class PerServiceCallHistoryPersistingMiddleware(ChatMiddleware):
agent: SupportsAgentRun,
session: AgentSession,
providers: Sequence[HistoryProvider],
service_stores_history: bool = False,
) -> None:
"""Initialize the middleware.
@@ -587,10 +591,16 @@ class PerServiceCallHistoryPersistingMiddleware(ChatMiddleware):
agent: The agent that owns the history providers.
session: The active session for the current run.
providers: The history providers participating in per-service-call persistence.
service_stores_history: When True, the chat client stores history server-side. The
middleware then skips loading providers and leaves the real conversation id
untouched, persisting each service call without driving the function loop with a
local sentinel. When False, the middleware loads providers and uses a local
sentinel conversation id so the function loop runs without service-side storage.
"""
self._agent = agent
self._session = session
self._providers = list(providers)
self._service_stores_history = service_stores_history
async def _prepare_service_call_context(self, messages: Sequence[Message]) -> SessionContext:
"""Create a per-call SessionContext and load history providers into it."""
@@ -602,6 +612,9 @@ class PerServiceCallHistoryPersistingMiddleware(ChatMiddleware):
)
for source_id, source_messages in context_messages.items():
service_call_context.extend_messages(source_id, source_messages)
# When the service stores history, it owns loading; the providers are write-only sinks.
if self._service_stores_history:
return service_call_context
for provider in self._providers:
if not provider.load_messages:
continue
@@ -652,17 +665,35 @@ class PerServiceCallHistoryPersistingMiddleware(ChatMiddleware):
response: ChatResponse,
) -> ChatResponse:
"""Persist a model response and apply the local follow-up sentinel when needed."""
if response.conversation_id is not None and not is_local_history_conversation_id(response.conversation_id):
if (
not self._service_stores_history
and response.conversation_id is not None
and not is_local_history_conversation_id(response.conversation_id)
):
raise ChatClientInvalidResponseException(
"require_per_service_call_history_persistence cannot be used "
"when the chat client returns a real conversation_id."
)
# In storing mode the service is expected to echo a conversation id that the next run
# resumes from. If it comes back empty, the provider still captures this turn but there is
# no service id to load from next time, so cross-turn history can be lost silently. Warn
# every time so this uncommon, easy-to-miss failure mode cannot fail quietly.
if self._service_stores_history and response.conversation_id is None:
logger.warning(
"require_per_service_call_history_persistence is enabled with a chat client that "
"stores history server-side, but the client returned no conversation_id; cross-turn "
"history may not resume. Set store=False to load and resume from the HistoryProvider "
"instead."
)
await self._persist_service_call_response(
service_call_context=service_call_context,
response=response,
)
if _response_contains_follow_up_request(response):
# The local sentinel only applies when the service does not store history; when it does,
# the real conversation id already drives function-loop continuation.
if not self._service_stores_history and _response_contains_follow_up_request(response):
response.mark_internal_conversation_id()
response.conversation_id = LOCAL_HISTORY_CONVERSATION_ID
return response
@@ -681,8 +712,12 @@ class PerServiceCallHistoryPersistingMiddleware(ChatMiddleware):
result type for streaming or non-streaming execution.
"""
service_call_context = await self._prepare_service_call_context(context.messages)
context.messages = service_call_context.get_messages(include_input=True)
self._strip_local_conversation_id(context)
# When the service stores history, leave the outgoing messages and the real conversation
# id untouched (pass-through); the middleware only persists. Otherwise reconstruct the
# outgoing messages from the loaded local history and strip the local sentinel.
if not self._service_stores_history:
context.messages = service_call_context.get_messages(include_input=True)
self._strip_local_conversation_id(context)
await call_next()
+472 -2
View File
@@ -3,6 +3,7 @@
import contextlib
import inspect
import json
import logging
from collections.abc import AsyncIterable, Awaitable, Callable, MutableSequence, Sequence
from typing import Any, cast
from unittest.mock import AsyncMock, MagicMock, patch
@@ -42,6 +43,8 @@ from agent_framework._mcp import MCPTool, _build_prefixed_mcp_name, _normalize_m
from agent_framework._middleware import FunctionInvocationContext
from agent_framework.exceptions import AgentInvalidRequestException, ChatClientInvalidResponseException
from .conftest import MockBaseChatClient
class _FixedTokenizer:
def __init__(self, token_count: int) -> None:
@@ -609,6 +612,7 @@ async def test_streaming_per_service_call_persistence_hides_response_id_from_aft
async def test_per_service_call_persistence_uses_real_service_storage_when_client_stores_by_default(
chat_client_base: SupportsChatGetResponse,
caplog: pytest.LogCaptureFixture,
) -> None:
provider = _RecordingHistoryProvider()
@@ -649,15 +653,22 @@ async def test_per_service_call_persistence_uses_real_service_storage_when_clien
require_per_service_call_history_persistence=True,
)
result = await agent.run("What's the weather in Seattle?", session=session)
with caplog.at_level(logging.WARNING, logger="agent_framework"):
result = await agent.run("What's the weather in Seattle?", session=session)
provider_state = session.state[provider.source_id]
assert result.text == "It is sunny in Seattle."
assert result.response_id == "resp_call_2"
assert chat_client_base.call_count == 2
# The service owns the conversation, so the provider never loads (issue #5798).
assert "get_call_count" not in provider_state
assert "save_call_count" not in provider_state
# Persistence is owned by the per-service-call middleware: it persists once per service call
# (issue #5798: the provider must never be silently bypassed when the service stores history).
# This run makes two service calls (function call + final answer), so it persists twice.
assert provider_state["save_call_count"] == 2
# load_messages=True while the service stores history surfaces a warning.
assert any("load_messages" in record.message for record in caplog.records)
assert session.service_session_id == "resp_service_managed"
@@ -1996,6 +2007,19 @@ def test_merge_options_none_values_ignored():
assert result["key2"] == "value2"
def test_merge_options_drops_none_base_values():
    """_merge_options must not carry a None-valued base option into the merged result."""
    merged = _merge_options({"store": None, "temperature": 0.5}, {"top_p": 0.9})

    # A base value left unset (e.g. store=None from default_options) is dropped by the merge.
    assert "store" not in merged
    assert merged["temperature"] == 0.5
    assert merged["top_p"] == 0.9
def test_merge_options_runtime_model_overrides_default_model() -> None:
"""Test _merge_options lets a runtime model override a default model."""
result = _merge_options({"model": "default-model"}, {"model": "runtime-model"})
@@ -2658,3 +2682,449 @@ async def test_as_tool_raises_on_user_input_request(client: SupportsChatGetRespo
assert len(exc_info.value.contents) == 1
assert exc_info.value.contents[0].type == "oauth_consent_request"
assert exc_info.value.contents[0].consent_link == "https://login.microsoftonline.com/consent"
# region Per-service-call history persistence scenario matrix
#
# The driving field is ``require_per_service_call_history_persistence``. Every scenario runs a
# single agent run that makes **two service calls** -- a function call followed by a final
# completion -- so the *timing* of persistence is observable:
#
# * When the flag is ``True``, the per-service-call middleware persists the provider **after each
# service call**. So the function-call turn is already saved by the time the second (final)
# service call starts. This holds regardless of whether the chat client stores history
# server-side (the bug in issue #5798 was that a storing client silently bypassed persistence).
# * When the flag is ``False``, the provider persists **once, at the end of the run** -- nothing is
# saved between the two service calls.
#
# ``_PscSpyChatClient.saves_before_call`` records ``provider.save_calls`` at the start of every service
# call, so ``[0, 1]`` means "the function-call turn was persisted before the final call" and
# ``[0, 0]`` means "no persistence happened mid-run". The client's ``store`` / ``STORES_BY_DEFAULT``
# only selects *how* the middleware behaves -- never *whether* the provider persists.
_PSC_SERVICE_CONVERSATION_ID = "svc-conversation"
_psc_stream_params = pytest.mark.parametrize("stream", [False, True], ids=["sync", "stream"])
# Stand-in weather tool for the scenario matrix. It is registered under the name
# "lookup_weather" (the name the scripted function call targets) with
# approval_mode="never_require".
@tool(name="lookup_weather", approval_mode="never_require")
def _psc_lookup_weather(location: str) -> str:
    # Fixed canned reply; only the location is interpolated.
    return f"Weather in {location}: sunny"
def _psc_function_call_script() -> list[tuple[str, ...]]:
"""A fresh function-call-then-final-completion script (the client mutates it)."""
return [
("call", "call_1", "lookup_weather", '{"location": "Seattle"}'),
("text", "It is sunny in Seattle."),
]
class _PscSpyHistoryProvider(HistoryProvider):
    """History provider backed by a plain list that counts every load and save.

    ``get_calls`` and ``save_calls`` count invocations, ``saved_batches`` keeps each saved
    batch separately, and ``stored_messages`` exposes a copy of everything saved so far.
    """

    def __init__(self, source_id: str = "spy_history", **kwargs: Any) -> None:
        super().__init__(source_id, **kwargs)
        self._messages: list[Message] = []
        self.get_calls: int = 0
        self.save_calls: int = 0
        self.saved_batches: list[list[Message]] = []

    async def get_messages(
        self, session_id: str | None, *, state: dict[str, Any] | None = None, **kwargs: Any
    ) -> list[Message]:
        """Count the load and return a copy of the stored messages."""
        self.get_calls += 1
        return [*self._messages]

    async def save_messages(
        self,
        session_id: str | None,
        messages: Sequence[Message],
        *,
        state: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Count the save, remember the batch on its own, and append it to the store."""
        self.save_calls += 1
        self.saved_batches.append([*messages])
        self._messages.extend(messages)

    @property
    def stored_messages(self) -> list[Message]:
        """A copy of every message saved so far, in save order."""
        return [*self._messages]
class _PscSpyChatClient(MockBaseChatClient):
    """Scripted chat client that observes when provider persistence happens.

    Each service call pops the next turn from the script (a function call or a text reply)
    and records, before answering, the messages it received, the options it received, and
    the provider's ``save_calls`` count at that moment (``saves_before_call``). When the
    effective ``store`` option is truthy and ``echo_conversation_id`` is set, every response
    carries ``_PSC_SERVICE_CONVERSATION_ID``; otherwise the conversation id is ``None``.
    """

    def __init__(
        self,
        *,
        provider: _PscSpyHistoryProvider,
        stores_by_default: bool = False,
        script: list[tuple[str, ...]] | None = None,
        echo_conversation_id: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.STORES_BY_DEFAULT = stores_by_default  # type: ignore[attr-defined]
        self._provider = provider
        # Copy the caller's script so popping turns never mutates the caller's list.
        self._script = list(script) if script is not None else [("text", "ok")]
        self._echo_conversation_id = echo_conversation_id
        self.received_messages: list[list[Message]] = []
        self.received_options: list[dict[str, Any]] = []
        self.saves_before_call: list[int] = []

    def _effective_store(self, options: dict[str, Any]) -> bool:
        """Resolve ``store``: an explicit value wins, ``None``/absent falls back to the default."""
        requested = options.get("store")
        return bool(self.STORES_BY_DEFAULT if requested is None else requested)

    def _next_contents(self) -> list[Content]:
        """Pop the next scripted turn (or a default text turn) and convert it to contents."""
        turn = self._script.pop(0) if self._script else ("text", "ok")
        if turn[0] != "call":
            return [Content.from_text(turn[1])]
        _, call_id, name, args = turn
        return [Content.from_function_call(call_id=call_id, name=name, arguments=args)]

    def _inner_get_response(  # type: ignore[override]
        self,
        *,
        messages: MutableSequence[Message],
        stream: bool,
        options: dict[str, Any],
        **kwargs: Any,
    ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
        """Record the call, then answer with the next scripted turn (streamed or not)."""
        # Snapshot the inputs and the provider's save count before producing any output.
        self.received_messages.append(list(messages))
        self.received_options.append(dict(options))
        self.saves_before_call.append(self._provider.save_calls)

        store_and_echo = self._effective_store(options) and self._echo_conversation_id
        conv_id = _PSC_SERVICE_CONVERSATION_ID if store_and_echo else None
        contents = self._next_contents()

        if not stream:

            async def _get() -> ChatResponse:
                self.call_count += 1
                return ChatResponse(
                    messages=Message(role="assistant", contents=contents),
                    conversation_id=conv_id,
                )

            return _get()

        async def _stream() -> AsyncIterable[ChatResponseUpdate]:
            self.call_count += 1
            yield ChatResponseUpdate(
                contents=contents,
                role="assistant",
                finish_reason="stop",
                conversation_id=conv_id,
            )

        def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse:
            response = ChatResponse.from_updates(updates, output_format_type=options.get("response_format"))
            if conv_id:
                response.conversation_id = conv_id
            return response

        return ResponseStream(_stream(), finalizer=_finalize)
def _psc_build_agent(
    client: _PscSpyChatClient,
    provider: _PscSpyHistoryProvider,
    *,
    require_per_service_call_history_persistence: bool,
    default_options: dict[str, Any] | None = None,
) -> Agent:
    """Create the agent under test, wired to the spy client, spy provider, and weather tool.

    ``default_options`` is forwarded to ``Agent`` only when the caller supplies it, so the
    agent's own default applies otherwise.
    """
    optional_kwargs: dict[str, Any] = {} if default_options is None else {"default_options": default_options}
    return Agent(
        client=client,
        tools=[_psc_lookup_weather],
        context_providers=[provider],
        require_per_service_call_history_persistence=require_per_service_call_history_persistence,
        **optional_kwargs,
    )
async def _psc_run(agent: Agent, text: str, session: AgentSession, *, stream: bool) -> str:
if stream:
chunks: list[str] = []
async for update in agent.run(text, session=session, stream=True):
chunks.append(update.text or "")
return "".join(chunks)
result = await agent.run(text, session=session)
return result.text
# driver=True (the contract under test): persistence happens per service call
@_psc_stream_params
async def test_psc_flag_on_store_false_persists_after_each_service_call(stream: bool) -> None:
    """Mode A: flag on with a non-storing service saves the function-call turn before the final call."""
    spy_provider = _PscSpyHistoryProvider()
    spy_client = _PscSpyChatClient(
        provider=spy_provider,
        stores_by_default=False,
        script=_psc_function_call_script(),
    )
    agent = _psc_build_agent(spy_client, spy_provider, require_per_service_call_history_persistence=True)
    session = agent.create_session()

    answer = await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)

    assert answer == "It is sunny in Seattle."
    # One service call for the function call, one for the final completion.
    assert spy_client.call_count == 2
    # The contract: the first turn was already saved when the second service call began.
    assert spy_client.saves_before_call == [0, 1]
    assert spy_provider.save_calls == 2
    # Mode A reads local history before service calls.
    assert spy_provider.get_calls >= 1
    # Nothing is stored service-side, so no conversation id reaches the session.
    assert session.service_session_id is None
@_psc_stream_params
async def test_psc_flag_on_stores_by_default_persists_after_each_service_call(
    stream: bool, caplog: pytest.LogCaptureFixture
) -> None:
    """Mode B: flag on with a storing service still saves per service call but never loads (issue #5798)."""
    spy_provider = _PscSpyHistoryProvider()  # load_messages=True by default
    spy_client = _PscSpyChatClient(
        provider=spy_provider,
        stores_by_default=True,
        script=_psc_function_call_script(),
    )
    agent = _psc_build_agent(spy_client, spy_provider, require_per_service_call_history_persistence=True)
    session = agent.create_session()

    with caplog.at_level(logging.WARNING, logger="agent_framework"):
        answer = await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)

    assert answer == "It is sunny in Seattle."
    assert spy_client.call_count == 2
    # The invariant the bug broke: a storing service must not suppress per-call persistence.
    assert spy_client.saves_before_call == [0, 1]
    assert spy_provider.save_calls == 2
    # Loading belongs to the service here, so the provider is never read.
    assert spy_provider.get_calls == 0
    # The skipped load (load_messages=True) is reported through a warning.
    assert any("load_messages" in record.message for record in caplog.records)
    # The service's real conversation id ends up on the session.
    assert session.service_session_id == _PSC_SERVICE_CONVERSATION_ID
@_psc_stream_params
async def test_psc_flag_on_store_only_provider_no_load_no_warning(
    stream: bool, caplog: pytest.LogCaptureFixture
) -> None:
    """Mode B with load_messages=False: saves per call, never loads, and logs no load warning."""
    spy_provider = _PscSpyHistoryProvider(load_messages=False)
    spy_client = _PscSpyChatClient(
        provider=spy_provider,
        stores_by_default=True,
        script=_psc_function_call_script(),
    )
    agent = _psc_build_agent(spy_client, spy_provider, require_per_service_call_history_persistence=True)
    session = agent.create_session()

    with caplog.at_level(logging.WARNING, logger="agent_framework"):
        await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)

    assert spy_client.saves_before_call == [0, 1]
    assert spy_provider.save_calls == 2
    assert spy_provider.get_calls == 0
    # A store-only provider has no load to bypass, so no record may mention load_messages.
    assert all("load_messages" not in record.message for record in caplog.records)
@_psc_stream_params
async def test_psc_flag_on_store_false_override_behaves_as_mode_a(stream: bool) -> None:
    """Flag on, storing client, store=False in default options: behaves like Mode A (local, per call)."""
    spy_provider = _PscSpyHistoryProvider()
    spy_client = _PscSpyChatClient(
        provider=spy_provider,
        stores_by_default=True,
        script=_psc_function_call_script(),
    )
    agent = _psc_build_agent(
        spy_client,
        spy_provider,
        require_per_service_call_history_persistence=True,
        default_options={"store": False},
    )
    session = agent.create_session()

    await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)

    assert spy_client.saves_before_call == [0, 1]
    assert spy_provider.save_calls == 2
    assert spy_provider.get_calls >= 1
    # store=False keeps history local, so the session never receives a service conversation id.
    assert session.service_session_id is None
@_psc_stream_params
async def test_psc_flag_on_store_none_treated_as_absent(stream: bool, caplog: pytest.LogCaptureFixture) -> None:
    """Flag on, storing client, explicit store=None: None means "unset", so the storing default (Mode B) applies."""
    spy_provider = _PscSpyHistoryProvider()
    spy_client = _PscSpyChatClient(
        provider=spy_provider,
        stores_by_default=True,
        script=_psc_function_call_script(),
    )
    agent = _psc_build_agent(
        spy_client,
        spy_provider,
        require_per_service_call_history_persistence=True,
        default_options={"store": None},
    )
    session = agent.create_session()

    with caplog.at_level(logging.WARNING, logger="agent_framework"):
        await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)

    assert spy_client.saves_before_call == [0, 1]
    assert spy_provider.save_calls == 2
    assert spy_provider.get_calls == 0
    assert session.service_session_id == _PSC_SERVICE_CONVERSATION_ID
    assert any("load_messages" in record.message for record in caplog.records)
    # The client must never see a "store" key: the service picks its own default.
    assert not any("store" in options for options in spy_client.received_options)
@_psc_stream_params
async def test_psc_flag_on_respects_store_outputs_flag(stream: bool) -> None:
    """Flag on: the provider's store_inputs/store_outputs settings are honored on every service call."""
    spy_provider = _PscSpyHistoryProvider(store_inputs=True, store_outputs=False)
    spy_client = _PscSpyChatClient(
        provider=spy_provider,
        stores_by_default=True,
        script=_psc_function_call_script(),
    )
    agent = _psc_build_agent(spy_client, spy_provider, require_per_service_call_history_persistence=True)
    session = agent.create_session()

    await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)

    assert spy_provider.save_calls == 2
    # With outputs disabled, something was stored but none of it is an assistant message.
    saved = spy_provider.stored_messages
    assert saved
    assert all(message.role != "assistant" for message in saved)
# driver=False (control): persistence happens once, at the end of the run
@_psc_stream_params
async def test_psc_flag_off_store_false_persists_once_at_end(stream: bool) -> None:
    """Flag off + non-storing client: no persistence mid-run, a single save once the run finishes."""
    history_provider = _PscSpyHistoryProvider()
    chat_client = _PscSpyChatClient(
        provider=history_provider,
        stores_by_default=False,
        script=_psc_function_call_script(),
    )
    agent = _psc_build_agent(chat_client, history_provider, require_per_service_call_history_persistence=False)
    session = agent.create_session()

    reply_text = await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)

    assert reply_text == "It is sunny in Seattle."
    assert chat_client.call_count == 2
    # Control contract: nothing was saved between the function call and the final completion.
    assert chat_client.saves_before_call == [0, 0]
    assert history_provider.save_calls == 1
@_psc_stream_params
async def test_psc_flag_off_stores_by_default_persists_once_at_end(stream: bool) -> None:
    """Flag off + storing client: persistence happens once per run and the service conversation id propagates."""
    history_provider = _PscSpyHistoryProvider()
    chat_client = _PscSpyChatClient(
        provider=history_provider,
        stores_by_default=True,
        script=_psc_function_call_script(),
    )
    agent = _psc_build_agent(chat_client, history_provider, require_per_service_call_history_persistence=False)
    session = agent.create_session()

    await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)

    assert chat_client.saves_before_call == [0, 0]
    assert history_provider.save_calls == 1
    assert session.service_session_id == _PSC_SERVICE_CONVERSATION_ID
@_psc_stream_params
async def test_psc_flag_on_storing_with_existing_conversation_id_does_not_raise(stream: bool) -> None:
    """Allow side of the guard: flag on + storing client + an existing conversation_id resumes (no raise).

    The non-storing path raises when handed an existing service-managed conversation id. With a
    storing client the run has to go through, and the service conversation id has to end up on
    the session.
    """
    history_provider = _PscSpyHistoryProvider()
    chat_client = _PscSpyChatClient(
        provider=history_provider,
        stores_by_default=True,
        script=_psc_function_call_script(),
    )
    agent = _psc_build_agent(chat_client, history_provider, require_per_service_call_history_persistence=True)
    session = agent.create_session()

    if not stream:
        response = await agent.run(
            "What's the weather in Seattle?",
            session=session,
            options={"conversation_id": "existing_conversation"},
        )
        text = response.text
    else:
        # Collect the streamed pieces; an update without text contributes an empty string.
        pieces: list[str] = [
            update.text or ""
            async for update in agent.run(
                "What's the weather in Seattle?",
                session=session,
                stream=True,
                options={"conversation_id": "existing_conversation"},
            )
        ]
        text = "".join(pieces)

    assert text == "It is sunny in Seattle."
    # Saves still happen once per service call, and the real service id lands on the session.
    assert history_provider.save_calls == 2
    assert history_provider.get_calls == 0
    assert session.service_session_id == _PSC_SERVICE_CONVERSATION_ID
@_psc_stream_params
async def test_psc_flag_on_storing_two_runs_same_session(stream: bool) -> None:
    """Storing mode across two runs on one session: persistence keeps happening, id is stable, no load.

    The second run covers the precedence branch in which the session already holds a
    service_session_id; provider loading must stay skipped and persistence must continue.
    """
    history_provider = _PscSpyHistoryProvider()
    chat_client = _PscSpyChatClient(
        provider=history_provider,
        stores_by_default=True,
        script=_psc_function_call_script(),
    )
    agent = _psc_build_agent(chat_client, history_provider, require_per_service_call_history_persistence=True)
    session = agent.create_session()

    # --- first run ---
    await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)
    assert history_provider.save_calls == 2
    assert history_provider.get_calls == 0
    service_id_after_first_run = session.service_session_id
    assert service_id_after_first_run == _PSC_SERVICE_CONVERSATION_ID

    # Rewind the scripted client so the same session can be driven a second time.
    chat_client._script = _psc_function_call_script()
    chat_client.call_count = 0
    chat_client.saves_before_call = []

    # --- second run ---
    await _psc_run(agent, "And in Portland?", session, stream=stream)
    # Two more saves on the second run, still one per service call.
    assert chat_client.saves_before_call == [2, 3]
    assert history_provider.save_calls == 4
    # Loading remains skipped and the service conversation id does not change between runs.
    assert history_provider.get_calls == 0
    assert session.service_session_id == service_id_after_first_run
@_psc_stream_params
async def test_psc_flag_on_storing_without_conversation_id_warns_every_call(
    stream: bool, caplog: pytest.LogCaptureFixture
) -> None:
    """Storing mode but the client returns no conversation_id: warn on every service call.

    When no conversation id is echoed back, the next run has nothing to resume from and
    cross-turn history can be lost silently. The warning is emitted per service call (no
    dedup) so this uncommon failure mode cannot go unnoticed.
    """
    history_provider = _PscSpyHistoryProvider()
    chat_client = _PscSpyChatClient(
        provider=history_provider,
        stores_by_default=True,
        script=_psc_function_call_script(),
        echo_conversation_id=False,
    )
    agent = _psc_build_agent(chat_client, history_provider, require_per_service_call_history_persistence=True)
    session = agent.create_session()

    with caplog.at_level(logging.WARNING, logger="agent_framework"):
        await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)

    # Saves still occur, yet there is no service id on the session to resume from.
    assert history_provider.save_calls == 2
    assert session.service_session_id is None
    # Two service calls -> two warnings (one per call, not deduped).
    warning_count = sum(1 for record in caplog.records if "returned no conversation_id" in record.message)
    assert warning_count == 2
+203
View File
@@ -30,6 +30,7 @@ from agent_framework._mcp import (
MCPTool,
_build_prefixed_mcp_name,
_get_input_model_from_mcp_prompt,
_normalize_additional_tool_argument_names,
_normalize_mcp_name,
_should_propagate_cancelled_error,
logger,
@@ -6057,3 +6058,205 @@ async def test_max_wait_interrupts_long_poll_sleep(monkeypatch: pytest.MonkeyPat
# endregion
# region additional_tool_argument_names / allowlist filtering
def test_normalize_additional_tool_argument_names_none() -> None:
    """``None`` normalizes to an empty global set and an empty per-tool mapping."""
    normalized = _normalize_additional_tool_argument_names(None)
    shared_names, names_by_tool = normalized
    assert shared_names == set()
    assert names_by_tool == {}
def test_normalize_additional_tool_argument_names_sequence() -> None:
    """A sequence becomes the global set (duplicates collapse); the per-tool mapping stays empty."""
    shared_names, names_by_tool = _normalize_additional_tool_argument_names(["a", "b", "a"])
    assert shared_names == {"a", "b"}
    assert names_by_tool == {}
def test_normalize_additional_tool_argument_names_single_string() -> None:
    """A bare string is one argument name; it must not be exploded into characters."""
    shared_names, names_by_tool = _normalize_additional_tool_argument_names("conversation_id")
    assert shared_names == {"conversation_id"}
    assert names_by_tool == {}
def test_normalize_additional_tool_argument_names_mapping_with_global_key() -> None:
    """In a mapping, the ``"*"`` entry feeds the global set and every other key is per-tool."""
    configured = {
        "*": ["g1"],
        "tool_a": ["a1", "a2"],
        "tool_b": ["b1"],
    }
    shared_names, names_by_tool = _normalize_additional_tool_argument_names(configured)
    assert shared_names == {"g1"}
    assert names_by_tool == {"tool_a": {"a1", "a2"}, "tool_b": {"b1"}}
def test_normalize_additional_tool_argument_names_mapping_with_string_values() -> None:
    """A bare string used as a mapping value is a single name, not an iterable of characters."""
    configured = {
        "*": "conversation_id",
        "tool_a": "custom",
    }
    shared_names, names_by_tool = _normalize_additional_tool_argument_names(configured)
    assert shared_names == {"conversation_id"}
    assert names_by_tool == {"tool_a": {"custom"}}
def test_prepare_call_kwargs_strips_undeclared_arguments() -> None:
    """Without extras configured, only the tool's declared parameter survives and no meta is returned."""
    server = MCPTool(name="test_server")
    server._tool_param_names_by_name = {"test_tool": {"param"}}
    incoming = {"param": "value", "conversation_id": "c", "thread": object(), "unexpected": 1}

    forwarded, meta = server._prepare_call_kwargs("test_tool", incoming)

    assert forwarded == {"param": "value"}
    assert meta is None
def test_prepare_call_kwargs_global_extras_allowed() -> None:
    """A globally allowed extra name is forwarded alongside the declared parameter; others are dropped."""
    server = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"])
    server._tool_param_names_by_name = {"test_tool": {"param"}}
    incoming = {"param": "value", "conversation_id": "c", "options": {}}

    forwarded, _ = server._prepare_call_kwargs("test_tool", incoming)

    assert forwarded == {"param": "value", "conversation_id": "c"}
def test_prepare_call_kwargs_per_tool_and_global_extras() -> None:
    """Per-tool extras apply only to their own tool, while the ``"*"`` extras apply to every tool."""
    server = MCPTool(
        name="test_server",
        additional_tool_argument_names={"*": ["conversation_id"], "test_tool": ["custom"]},
    )
    server._tool_param_names_by_name = {"test_tool": {"param"}, "other_tool": {"x"}}

    forwarded_to_test_tool, _ = server._prepare_call_kwargs(
        "test_tool",
        {"param": "v", "conversation_id": "c", "custom": "y", "thread": object()},
    )
    assert forwarded_to_test_tool == {"param": "v", "conversation_id": "c", "custom": "y"}

    # "custom" was granted to test_tool only, so it must not leak here; the global extra still applies.
    forwarded_to_other_tool, _ = server._prepare_call_kwargs(
        "other_tool",
        {"x": 1, "conversation_id": "c", "custom": "y"},
    )
    assert forwarded_to_other_tool == {"x": 1, "conversation_id": "c"}
def test_prepare_call_kwargs_denylist_guards_server_declared_names() -> None:
    """A framework-named parameter declared by the server schema is still dropped by the denylist."""
    # The denylist is a safety net for framework-named params that a server *declares* in its
    # schema: they are dropped so internal objects never leak. Names explicitly opted in via
    # extras always win.
    server = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"])
    server._tool_param_names_by_name = {"test_tool": {"param", "thread"}}
    incoming = {"param": "v", "thread": object(), "conversation_id": "c"}

    forwarded, _ = server._prepare_call_kwargs("test_tool", incoming)

    # "thread": declared but denylisted -> dropped. "conversation_id": opted in -> kept.
    assert forwarded == {"param": "v", "conversation_id": "c"}
def test_prepare_call_kwargs_extras_override_denylist() -> None:
    """An explicit extra re-enables a denylisted framework name."""
    # "thread" is on the framework denylist, but opting it in via extras takes precedence
    # over that safety net.
    server = MCPTool(name="test_server", additional_tool_argument_names=["thread"])
    server._tool_param_names_by_name = {"test_tool": {"param"}}
    thread_marker = object()
    incoming = {"param": "v", "thread": thread_marker, "conversation_id": "c"}

    forwarded, _ = server._prepare_call_kwargs("test_tool", incoming)

    # "thread" is kept because it was opted in; conversation_id is denylisted, undeclared,
    # and not opted in, so it is dropped.
    assert forwarded == {"param": "v", "thread": thread_marker}
def test_prepare_call_kwargs_zero_arg_tool_passes_no_arguments() -> None:
    """A tool that declares no parameters receives an empty argument dict."""
    server = MCPTool(name="test_server")
    server._tool_param_names_by_name = {"test_tool": set()}
    incoming = {"conversation_id": "c", "thread": object(), "stray": 1}

    forwarded, _ = server._prepare_call_kwargs("test_tool", incoming)

    assert forwarded == {}
def test_prepare_call_kwargs_unknown_tool_passes_only_global_extras() -> None:
    """For a tool name with no recorded parameters, only the global extras get through."""
    # Deliberately leave _tool_param_names_by_name without an entry for this tool name.
    server = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"])
    incoming = {"conversation_id": "c", "other": 1}

    forwarded, _ = server._prepare_call_kwargs("unknown_tool", incoming)

    assert forwarded == {"conversation_id": "c"}
def test_prepare_call_kwargs_extracts_meta() -> None:
    """``_meta`` is split out of the arguments and handed back as the second return value."""
    server = MCPTool(name="test_server")
    server._tool_param_names_by_name = {"test_tool": {"param"}}
    incoming = {"param": "v", "_meta": {"trace": "abc"}}

    forwarded, meta = server._prepare_call_kwargs("test_tool", incoming)

    assert forwarded == {"param": "v"}
    assert meta is not None
    assert meta.get("trace") == "abc"
async def test_call_tool_forwards_only_declared_arguments() -> None:
    """End-to-end: framework runtime kwargs are stripped before reaching the server."""

    class TestServer(MCPTool):
        async def connect(self):
            # One tool whose schema declares a single required string parameter, "param".
            declared_tool = types.Tool(
                name="test_tool",
                description="Test tool",
                inputSchema={
                    "type": "object",
                    "properties": {"param": {"type": "string"}},
                    "required": ["param"],
                },
            )
            tool_reply = types.CallToolResult(content=[types.TextContent(type="text", text="ok")])

            self.session = Mock(spec=ClientSession)
            self.session.list_tools = AsyncMock(return_value=types.ListToolsResult(tools=[declared_tool]))
            self.session.call_tool = AsyncMock(return_value=tool_reply)

        def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
            return None

    server = TestServer(name="test_server", additional_tool_argument_names=["conversation_id"])
    async with server:
        await server.load_tools()
        # Keep a handle on the mocked session so its recorded calls can be inspected below.
        session_mock = server.session
        await server.call_tool(
            "test_tool",
            param="value",
            conversation_id="c",
            thread=object(),
            response_format=object(),
        )

        session_mock.call_tool.assert_called_once()
        forwarded_kwargs = session_mock.call_tool.call_args.kwargs
        assert forwarded_kwargs["arguments"] == {"param": "value", "conversation_id": "c"}
# endregion
@@ -8,29 +8,34 @@ This module provides ``Mem0ContextProvider``, built on the new
from __future__ import annotations
import asyncio
import logging
import sys
from collections.abc import Awaitable
from contextlib import AbstractAsyncContextManager
from typing import TYPE_CHECKING, Any, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, TypedDict
from agent_framework import Message
from agent_framework._sessions import AgentSession, ContextProvider, SessionContext
from mem0 import AsyncMemory, AsyncMemoryClient
if sys.version_info >= (3, 11):
from typing import NotRequired, Self, TypedDict # pragma: no cover
from typing import Self # pragma: no cover
else:
from typing_extensions import NotRequired, Self, TypedDict # pragma: no cover
from typing_extensions import Self # pragma: no cover
if TYPE_CHECKING:
from agent_framework._agents import SupportsAgentRun
class _MemorySearchResponse_v1_1(TypedDict):
results: list[dict[str, Any]]
relations: NotRequired[list[dict[str, Any]]]
logger = logging.getLogger(__name__)
MemoryRecord: TypeAlias = dict[str, object]
_MemorySearchResponse_v2 = list[dict[str, Any]]
class SearchResults(TypedDict):
    """Dict-shaped search response that carries its memory records under ``results``."""

    results: list[MemoryRecord]
SearchResponse: TypeAlias = list[MemoryRecord] | SearchResults
class Mem0ContextProvider(ContextProvider):
@@ -106,28 +111,85 @@ class Mem0ContextProvider(ContextProvider):
if not input_text.strip():
return
filters = self._build_filters()
# Query entity partitions independently to bypass strict logical AND limitations
# Mem0 OSS and Platform SDKs expose inconsistent search typings.
search_tasks: list[Awaitable[Any]] = []
# AsyncMemory (OSS) expects user_id/agent_id/run_id as direct kwargs
# AsyncMemoryClient (Platform) expects them in a filters dict
search_kwargs: dict[str, Any] = {"query": input_text}
if isinstance(self.mem0_client, AsyncMemory):
search_kwargs.update(filters)
else:
search_kwargs["filters"] = filters
# 1. Query User partition independently
if self.user_id:
user_kwargs = self._build_search_kwargs(input_text, "user_id", self.user_id)
search_tasks.append(self.mem0_client.search(**user_kwargs)) # type: ignore[reportUnknownMemberType, reportUnknownArgumentType]
search_response: _MemorySearchResponse_v1_1 | _MemorySearchResponse_v2 = await self.mem0_client.search( # type: ignore[misc]
**search_kwargs,
)
# 2. Query Agent partition independently
if self.agent_id:
agent_kwargs = self._build_search_kwargs(input_text, "agent_id", self.agent_id)
search_tasks.append(self.mem0_client.search(**agent_kwargs)) # type: ignore[reportUnknownMemberType, reportUnknownArgumentType]
if isinstance(search_response, list):
memories = search_response
elif isinstance(search_response, dict) and "results" in search_response:
memories = search_response["results"]
else:
memories = [search_response]
# Fall back to an app-scoped search when only application_id is configured
if not search_tasks and self.application_id:
app_kwargs: dict[str, Any] = {"query": input_text}
if isinstance(self.mem0_client, AsyncMemory):
app_kwargs["app_id"] = self.application_id
else:
app_kwargs["filters"] = {"app_id": self.application_id}
search_tasks.append(self.mem0_client.search(**app_kwargs)) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
if not search_tasks:
return
line_separated_memories = "\n".join(memory.get("memory", "") for memory in memories)
results: list[SearchResponse | BaseException] = await asyncio.gather(*search_tasks, return_exceptions=True)
# Merge and deduplicate results
memories: list[MemoryRecord] = []
seen_memory_ids: set[str] = set()
failed_tasks_count: int = 0
for search_response in results:
if isinstance(search_response, asyncio.CancelledError):
raise search_response
if isinstance(search_response, BaseException):
failed_tasks_count += 1
logger.error(
"Mem0 partition search task failed: %s",
search_response,
exc_info=(type(search_response), search_response, search_response.__traceback__),
)
continue
current_memories: list[MemoryRecord] = []
if isinstance(search_response, list):
current_memories = [mem for mem in search_response if isinstance(mem, dict)]
elif isinstance(search_response, dict):
results_field = search_response.get("results")
if isinstance(results_field, list):
current_memories = [
item
for item in results_field
if isinstance(item, dict) # pyright: ignore[reportUnknownVariableType]
]
else:
logger.warning(
"Unexpected Mem0 search response format: %s",
type(results_field).__name__,
)
for mem in current_memories:
mem_id = mem.get("id")
if mem_id is not None and not isinstance(mem_id, str):
mem_id = str(mem_id)
if mem_id is not None and mem_id in seen_memory_ids:
continue
if mem_id is not None:
seen_memory_ids.add(mem_id)
memories.append(mem)
if failed_tasks_count == len(search_tasks):
logger.error("All Mem0 retrieval tasks failed. Context provider is unable to verify memory state.")
line_separated_memories = "\n".join(str(memory.get("memory", "")) for memory in memories)
if line_separated_memories:
context.extend_messages(
self.source_id,
@@ -159,12 +221,21 @@ class Mem0ContextProvider(ContextProvider):
]
if messages:
await self.mem0_client.add( # type: ignore[misc]
messages=messages,
user_id=self.user_id,
agent_id=self.agent_id,
metadata={"application_id": self.application_id},
)
add_kwargs: dict[str, Any] = {
"messages": messages,
"user_id": self.user_id,
"agent_id": self.agent_id,
}
# Inject the application scope using the matching signature format for each SDK variant
if isinstance(self.mem0_client, AsyncMemory):
if self.application_id:
add_kwargs["app_id"] = self.application_id
else:
if self.application_id:
add_kwargs["filters"] = {"app_id": self.application_id}
await self.mem0_client.add(**add_kwargs) # type: ignore[misc, call-arg]
# -- Internal methods ------------------------------------------------------
@@ -173,15 +244,21 @@ class Mem0ContextProvider(ContextProvider):
if not self.agent_id and not self.user_id and not self.application_id:
raise ValueError("At least one of the filters: agent_id, user_id, or application_id is required.")
def _build_filters(self) -> dict[str, Any]:
"""Build search filters from initialization parameters."""
filters: dict[str, Any] = {}
if self.user_id:
filters["user_id"] = self.user_id
if self.agent_id:
filters["agent_id"] = self.agent_id
if self.application_id:
filters["app_id"] = self.application_id
def _build_search_kwargs(self, input_text: str, entity_key: str, entity_value: str) -> dict[str, Any]:
    """Assemble ``search`` keyword arguments for one entity scope.

    Args:
        input_text: Text passed as the ``query`` argument.
        entity_key: Name of the scoping field (callers in this file pass ``"user_id"`` or ``"agent_id"``).
        entity_value: Value for that scoping field.

    Returns:
        A kwargs dict containing ``query`` plus the scope. For an ``AsyncMemory`` client the
        scope entries are top-level keys; for any other client they are nested under
        ``filters``. ``app_id`` is added to the scope whenever ``application_id`` is set.
    """
    # The scope entries are the same for both client flavours; only their placement differs.
    scope: dict[str, Any] = {entity_key: entity_value}
    if self.application_id:
        scope["app_id"] = self.application_id

    if isinstance(self.mem0_client, AsyncMemory):
        # AsyncMemory (OSS) expects direct kwargs
        return {"query": input_text, **scope}
    # AsyncMemoryClient (Platform) expects a filters dict
    return {"query": input_text, "filters": scope}
@@ -3,7 +3,7 @@
from __future__ import annotations
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agent_framework import AgentResponse, Message
@@ -193,39 +193,59 @@ class TestBeforeRun:
assert call_kwargs["user_id"] == "u1"
assert "filters" not in call_kwargs
@pytest.mark.asyncio
async def test_oss_client_all_scoping_params_except_app_id(self, mock_oss_mem0_client: AsyncMock) -> None:
    """OSS client with user and agent scopes issues one search per scope, each with direct kwargs."""
    mock_oss_mem0_client.search.return_value = []
    provider = Mem0ContextProvider(
        source_id="mem0",
        mem0_client=mock_oss_mem0_client,
        user_id="u1",
        agent_id="a1",
    )

    # Stand-in context exposing a single input message whose text is "hello".
    mock_context = MagicMock(spec=SessionContext)
    mock_msg = MagicMock()
    mock_msg.text = "hello"
    mock_context.input_messages = [mock_msg]
    mock_context.response = None

    await provider.before_run(
        agent=MagicMock(), session=MagicMock(spec=AgentSession), context=mock_context, state={}
    )

    # Two separate searches are expected (one per entity) instead of a single combined call.
    assert mock_oss_mem0_client.search.call_count == 2
    mock_oss_mem0_client.search.assert_any_call(query="hello", user_id="u1")
    mock_oss_mem0_client.search.assert_any_call(query="hello", agent_id="a1")
@pytest.mark.asyncio
async def test_platform_client_passes_filters_dict_except_app_id(self, mock_mem0_client: AsyncMock) -> None:
    """Platform client issues one search per scope, each scope nested inside a ``filters`` dict."""
    mock_mem0_client.search.return_value = []
    provider = Mem0ContextProvider(
        source_id="mem0",
        mem0_client=mock_mem0_client,
        user_id="u1",
        agent_id="a1",
    )

    # Stand-in context exposing a single input message whose text is "hello".
    mock_context = MagicMock(spec=SessionContext)
    mock_msg = MagicMock()
    mock_msg.text = "hello"
    mock_context.input_messages = [mock_msg]
    mock_context.response = None

    await provider.before_run(
        agent=MagicMock(), session=MagicMock(spec=AgentSession), context=mock_context, state={}
    )

    # Filters are isolated per call: each search carries exactly one entity key.
    assert mock_mem0_client.search.call_count == 2
    mock_mem0_client.search.assert_any_call(query="hello", filters={"user_id": "u1"})
    mock_mem0_client.search.assert_any_call(query="hello", filters={"agent_id": "a1"})
# -- after_run tests -----------------------------------------------------------
@@ -318,8 +338,8 @@ class TestAfterRun:
with pytest.raises(ValueError, match="At least one of the filters"):
await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type]
async def test_stores_with_application_id_metadata(self, mock_mem0_client: AsyncMock) -> None:
"""application_id is passed in metadata."""
async def test_stores_with_application_id_filters(self, mock_mem0_client: AsyncMock) -> None:
"""application_id is passed in filters."""
provider = Mem0ContextProvider(
source_id="mem0", mem0_client=mock_mem0_client, user_id="u1", application_id="app1"
)
@@ -331,7 +351,7 @@ class TestAfterRun:
agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {})
) # type: ignore[arg-type]
assert mock_mem0_client.add.call_args.kwargs["metadata"] == {"application_id": "app1"}
assert mock_mem0_client.add.call_args.kwargs["filters"] == {"app_id": "app1"}
# -- _validate_filters tests --------------------------------------------------
@@ -358,15 +378,20 @@ class TestValidateFilters:
provider._validate_filters()
# -- _build_filters tests -----------------------------------------------------
# -- _build_search_kwargs tests -----------------------------------------------------
class TestBuildFilters:
"""Test _build_filters method."""
class TestBuildSearchKwargs:
"""Test _build_search_kwargs method."""
def test_user_id_only(self, mock_mem0_client: AsyncMock) -> None:
    """A user-scoped search nests the user id under ``filters`` when no application id is set."""
    provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")

    result = provider._build_search_kwargs("test query", "user_id", "u1")

    # AsyncMock triggers the Platform client nested 'filters' structure
    assert result == {"query": "test query", "filters": {"user_id": "u1"}}
def test_all_params(self, mock_mem0_client: AsyncMock) -> None:
provider = Mem0ContextProvider(
@@ -376,28 +401,66 @@ class TestBuildFilters:
agent_id="a1",
application_id="app1",
)
assert provider._build_filters() == {
"user_id": "u1",
"agent_id": "a1",
"app_id": "app1",
# Test that app_id correctly merges with the isolated target entity
result = provider._build_search_kwargs("test query", "agent_id", "a1")
assert result == {
"query": "test query",
"filters": {
"agent_id": "a1",
"app_id": "app1",
},
}
def test_excludes_none_values(self, mock_mem0_client: AsyncMock) -> None:
    """An unset application_id must not show up as ``app_id`` in the search filters."""
    provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")

    # application_id is not passed to the provider here, so it should not appear in the dictionary
    result = provider._build_search_kwargs("test query", "user_id", "u1")

    assert "app_id" not in result.get("filters", {})
def test_no_run_id_in_search_filters(self, mock_mem0_client: AsyncMock) -> None:
    """run_id is excluded from search filters so memories work across sessions."""
    provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")

    result = provider._build_search_kwargs("test query", "user_id", "u1")

    # Check both placements: nested under "filters" and as a top-level kwarg.
    assert "run_id" not in result.get("filters", {})
    assert "run_id" not in result
def test_empty_when_no_params(self, mock_mem0_client: AsyncMock) -> None:
    """With no scoping ids on the provider, the payload holds only the query and the requested entity."""
    # Validates base query payload generation
    provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client)

    result = provider._build_search_kwargs("test query", "custom_key", "custom_val")

    assert result == {"query": "test query", "filters": {"custom_key": "custom_val"}}
@pytest.mark.asyncio
async def test_before_run_application_only_fallback(self, mock_mem0_client: AsyncMock) -> None:
    """With only application_id configured, before_run performs one search and extends the context."""
    provider = Mem0ContextProvider(
        source_id="mem0", mem0_client=mock_mem0_client, application_id="app_fallback_test"
    )

    # Stand-in context exposing one input message and no response.
    message = MagicMock()
    message.text = "Retrieve systemic fallback memory traces"
    context = MagicMock(spec=SessionContext)
    context.input_messages = [message]
    context.response = None

    # The search yields a single memory record.
    stored_memory = {"id": "m1", "memory": "System configuration template"}
    mock_mem0_client.search = AsyncMock(return_value=[stored_memory])

    await provider.before_run(
        agent=MagicMock(), session=MagicMock(spec=AgentSession), context=context, state={}
    )

    # Exactly one search ran and its result was pushed into the context.
    assert mock_mem0_client.search.call_count == 1
    context.extend_messages.assert_called_once()
# -- Context manager tests -----------------------------------------------------
@@ -1997,7 +1997,11 @@ class RawOpenAIChatClient( # type: ignore[misc]
metadata: dict[str, Any] = response.metadata or {}
contents: list[Content] = []
local_shell_tool_name = self._get_local_shell_tool_name(options.get("tools"))
for item in response.output: # type: ignore[reportUnknownMemberType]
try:
response_outputs = response.output # type: ignore[reportUnknownMemberType]
except AttributeError:
response_outputs = []
for item in response_outputs: # type: ignore[reportUnknownVariableType]
match item.type:
# types:
# ParsedResponseOutputMessage[Unknown] |
@@ -788,13 +788,13 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
def _get_metadata_from_chat_response(self, response: ChatCompletion) -> dict[str, Any]:
"""Get metadata from a chat response."""
return {
"system_fingerprint": response.system_fingerprint,
"system_fingerprint": getattr(response, "system_fingerprint", None),
}
def _get_metadata_from_streaming_chat_response(self, response: ChatCompletionChunk) -> dict[str, Any]:
"""Get metadata from a streaming chat response."""
return {
"system_fingerprint": response.system_fingerprint,
"system_fingerprint": getattr(response, "system_fingerprint", None),
}
def _get_metadata_from_chat_choice(self, choice: Choice | ChunkChoice) -> dict[str, Any]: