mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Merge branch 'main' into local-branch-python-add-reset-to-workflow
This commit is contained in:
@@ -9,7 +9,7 @@ from typing import Any, cast
|
||||
from ag_ui.core import BaseEvent
|
||||
from agent_framework import SupportsAgentRun
|
||||
|
||||
from ._agent_run import run_agent_stream
|
||||
from ._agent_run import PendingApprovalEntry, run_agent_stream
|
||||
|
||||
|
||||
class AgentConfig:
|
||||
@@ -107,7 +107,7 @@ class AgentFrameworkAgent:
|
||||
# Populated when approval requests are emitted; consumed when responses arrive.
|
||||
# Prevents bypass, function name spoofing, and replay attacks.
|
||||
# Bounded to prevent unbounded growth from abandoned approval requests.
|
||||
self._pending_approvals: OrderedDict[str, str] = OrderedDict()
|
||||
self._pending_approvals: OrderedDict[str, PendingApprovalEntry] = OrderedDict()
|
||||
self._pending_approvals_max_size: int = 10_000
|
||||
|
||||
async def run(
|
||||
|
||||
@@ -8,7 +8,7 @@ import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncIterable, Awaitable
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, TypedDict, cast
|
||||
|
||||
from ag_ui.core import (
|
||||
BaseEvent,
|
||||
@@ -56,6 +56,7 @@ from ._run_common import (
|
||||
_stringify_tool_result, # type: ignore
|
||||
)
|
||||
from ._utils import (
|
||||
canonical_function_arguments,
|
||||
convert_agui_tools_to_agent_framework,
|
||||
generate_event_id,
|
||||
get_conversation_id_from_update,
|
||||
@@ -407,7 +408,33 @@ def _make_approval_tool_result_events(resolved_approval_results: list[Content])
|
||||
return events
|
||||
|
||||
|
||||
def _evict_oldest_approvals(registry: dict[str, str], max_size: int = 10_000) -> None:
|
||||
class _PendingApproval(TypedDict):
|
||||
"""Pending approval details for a requested function call."""
|
||||
|
||||
name: str
|
||||
arguments: str | None
|
||||
|
||||
|
||||
PendingApprovalEntry = _PendingApproval | str
|
||||
|
||||
|
||||
def _make_pending_approval_entry(name: str, arguments: str | None) -> _PendingApproval:
|
||||
return {"name": name, "arguments": arguments}
|
||||
|
||||
|
||||
def _pending_approval_name(entry: PendingApprovalEntry) -> str | None:
|
||||
if isinstance(entry, str):
|
||||
return entry
|
||||
return entry["name"]
|
||||
|
||||
|
||||
def _pending_approval_arguments(entry: PendingApprovalEntry) -> str | None:
|
||||
if isinstance(entry, str):
|
||||
return None
|
||||
return entry["arguments"]
|
||||
|
||||
|
||||
def _evict_oldest_approvals(registry: dict[str, PendingApprovalEntry], max_size: int = 10_000) -> None:
|
||||
"""Evict the oldest entries from the pending-approvals registry (LRU).
|
||||
|
||||
Only effective when *registry* is an ``OrderedDict``; plain dicts are
|
||||
@@ -427,7 +454,7 @@ async def _resolve_approval_responses(
|
||||
tools: list[Any],
|
||||
agent: SupportsAgentRun,
|
||||
run_kwargs: dict[str, Any],
|
||||
pending_approvals: dict[str, str] | None = None,
|
||||
pending_approvals: dict[str, PendingApprovalEntry] | None = None,
|
||||
thread_id: str = "",
|
||||
) -> list[Content]:
|
||||
"""Execute approved function calls and replace approval content with results.
|
||||
@@ -480,7 +507,8 @@ async def _resolve_approval_responses(
|
||||
invalid_ids.add(resp_id)
|
||||
continue
|
||||
|
||||
pending_name = pending_approvals[registry_key]
|
||||
pending_entry = pending_approvals[registry_key]
|
||||
pending_name = _pending_approval_name(pending_entry)
|
||||
if resp_name != pending_name:
|
||||
logger.warning(
|
||||
"Rejected approval response id=%s: function name mismatch (response=%s, pending=%s)",
|
||||
@@ -491,6 +519,16 @@ async def _resolve_approval_responses(
|
||||
invalid_ids.add(resp_id)
|
||||
continue
|
||||
|
||||
pending_arguments = _pending_approval_arguments(pending_entry)
|
||||
response_arguments = canonical_function_arguments(resp.function_call)
|
||||
if pending_arguments is not None and response_arguments != pending_arguments:
|
||||
logger.warning(
|
||||
"Rejected approval response id=%s: function arguments mismatch",
|
||||
resp_id,
|
||||
)
|
||||
invalid_ids.add(resp_id)
|
||||
continue
|
||||
|
||||
# Valid — consume entry to prevent replay
|
||||
del pending_approvals[registry_key]
|
||||
if resp.approved:
|
||||
@@ -714,7 +752,7 @@ async def run_agent_stream(
|
||||
input_data: dict[str, Any],
|
||||
agent: SupportsAgentRun,
|
||||
config: AgentConfig,
|
||||
pending_approvals: dict[str, str] | None = None,
|
||||
pending_approvals: dict[str, PendingApprovalEntry] | None = None,
|
||||
) -> AsyncGenerator[BaseEvent]:
|
||||
"""Run agent and yield AG-UI events.
|
||||
|
||||
@@ -917,7 +955,10 @@ async def run_agent_stream(
|
||||
# Register pending approval requests so we can validate responses later
|
||||
if content_type == "function_approval_request" and pending_approvals is not None:
|
||||
if content.id and content.function_call and content.function_call.name:
|
||||
pending_approvals[f"{thread_id}:{content.id}"] = content.function_call.name
|
||||
pending_approvals[f"{thread_id}:{content.id}"] = _make_pending_approval_entry(
|
||||
content.function_call.name,
|
||||
canonical_function_arguments(content.function_call),
|
||||
)
|
||||
# Evict oldest entries if the registry exceeds a safe bound (LRU)
|
||||
_evict_oldest_approvals(pending_approvals, max_size=10_000)
|
||||
else:
|
||||
|
||||
@@ -56,6 +56,22 @@ def safe_json_parse(value: Any) -> dict[str, Any] | None:
|
||||
return None
|
||||
|
||||
|
||||
def canonical_function_arguments(function_call: Any) -> str | None:
|
||||
"""Return a stable representation of function-call arguments."""
|
||||
if function_call is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
parsed_arguments = function_call.parse_arguments()
|
||||
except Exception:
|
||||
parsed_arguments = getattr(function_call, "arguments", None)
|
||||
|
||||
if parsed_arguments is None:
|
||||
parsed_arguments = {}
|
||||
|
||||
return json.dumps(make_json_safe(parsed_arguments), sort_keys=True, separators=(",", ":"))
|
||||
|
||||
|
||||
def get_role_value(message: Any) -> str:
|
||||
"""Extract role string from a message object.
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from ._run_common import (
|
||||
_extract_resume_payload,
|
||||
_normalize_resume_interrupts,
|
||||
)
|
||||
from ._utils import generate_event_id, make_json_safe
|
||||
from ._utils import canonical_function_arguments, generate_event_id, make_json_safe
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -324,6 +324,29 @@ def _coerce_response_for_request(request_event: Any, value: Any) -> Any | None:
|
||||
return candidate
|
||||
|
||||
|
||||
def _approval_response_matches_request(request_id: str, request_event: Any, response: Any) -> bool:
    """Check whether an approval response matches the pending approval request.

    Args:
        request_id: Identifier of the pending request, compared with the
            response's ``id`` after ``str()`` conversion.
        request_event: Pending request event; its ``data`` attribute is inspected.
        response: Candidate response value, already coerced by the caller.

    Returns:
        ``True`` when the pending request is not a function-approval request
        (nothing to validate), or when the response is a function-approval
        response with the same id, function name and canonical arguments.
        ``False`` otherwise.
    """
    request_data = getattr(request_event, "data", None)
    # Only function-approval requests are validated here; all other pending
    # requests pass through unchanged.
    if not isinstance(request_data, Content) or request_data.type != "function_approval_request":
        return True

    if not isinstance(response, Content) or response.type != "function_approval_response":
        return False

    if str(getattr(response, "id", "")) != request_id:
        return False

    request_call = getattr(request_data, "function_call", None)
    response_call = getattr(response, "function_call", None)
    # Both sides must carry a function call to be comparable.
    if request_call is None or response_call is None:
        return False

    if getattr(response_call, "name", None) != getattr(request_call, "name", None):
        return False

    # Compare canonical serializations so the check is insensitive to key order and formatting.
    return canonical_function_arguments(response_call) == canonical_function_arguments(request_call)
|
||||
|
||||
|
||||
def _single_pending_response_from_value(pending_events: dict[str, Any], value: Any) -> dict[str, Any]:
|
||||
"""Map a scalar resume payload to the single pending request (if unambiguous)."""
|
||||
if value is None or len(pending_events) != 1:
|
||||
@@ -343,6 +366,13 @@ def _single_pending_response_from_value(pending_events: dict[str, Any], value: A
|
||||
)
|
||||
return {}
|
||||
|
||||
if not _approval_response_matches_request(str(request_id), request_event, coerced_value):
|
||||
logger.info(
|
||||
"Ignoring pending request response for request_id=%s: approval response does not match pending request",
|
||||
request_id,
|
||||
)
|
||||
return {}
|
||||
|
||||
return {str(request_id): coerced_value}
|
||||
|
||||
|
||||
@@ -372,6 +402,12 @@ def _coerce_responses_for_pending_requests(
|
||||
_response_type_name(request_event),
|
||||
)
|
||||
continue
|
||||
if not _approval_response_matches_request(request_key, request_event, coerced_value):
|
||||
logger.info(
|
||||
"Ignoring resume response for request_id=%s: approval response does not match pending request",
|
||||
request_key,
|
||||
)
|
||||
continue
|
||||
normalized[request_key] = coerced_value
|
||||
return normalized
|
||||
|
||||
|
||||
@@ -1407,6 +1407,92 @@ async def test_fabricated_rejection_without_pending_approval_is_blocked(streamin
|
||||
assert False, "Fabricated rejection response leaked as function_result into LLM messages"
|
||||
|
||||
|
||||
async def test_approval_argument_mismatch_is_blocked(streaming_chat_client_stub):
    """An approval response must not execute changed arguments for the pending call."""
    from agent_framework import tool
    from agent_framework.ag_ui import AgentFrameworkAgent

    # Records every real invocation of the tool; must stay empty for the test to pass.
    executed_args: list[dict[str, Any]] = []

    @tool(
        name="update_record",
        description="Update a record",
        approval_mode="always_require",
    )
    def update_record(record_id: str, value: str) -> str:
        executed_args.append({"record_id": record_id, "value": value})
        return f"updated {record_id} to {value}"

    # Turn 1: the model requests the tool call that needs approval.
    async def stream_fn_approval(
        messages: MutableSequence[Message], options: ChatOptions, **kwargs: Any
    ) -> AsyncIterator[ChatResponseUpdate]:
        yield ChatResponseUpdate(
            contents=[
                Content.from_function_call(
                    name="update_record",
                    call_id="call_update_001",
                    arguments={"record_id": "alpha", "value": "approved"},
                )
            ]
        )

    wrapper = AgentFrameworkAgent(
        agent=Agent(
            client=streaming_chat_client_stub(stream_fn_approval),
            name="test_agent",
            instructions="Test",
            tools=[update_record],
        )
    )
    thread_id = "thread-argument-mismatch-test"

    events1: list[Any] = []
    async for event in wrapper.run({"thread_id": thread_id, "messages": [{"role": "user", "content": "update"}]}):
        events1.append(event)

    # The approval request must have been registered under a key containing the call id.
    assert any("call_update_001" in k for k in wrapper._pending_approvals)

    # Turn 2: swap in a client that only replies with text.
    async def stream_fn_post(
        messages: MutableSequence[Message], options: ChatOptions, **kwargs: Any
    ) -> AsyncIterator[ChatResponseUpdate]:
        yield ChatResponseUpdate(contents=[Content.from_text(text="Done")])

    wrapper.agent = Agent(
        client=streaming_chat_client_stub(stream_fn_post),
        name="test_agent",
        instructions="Test",
        tools=[update_record],
    )

    # The approval targets the same call id and function name but carries different arguments.
    turn2_input: dict[str, Any] = {
        "thread_id": thread_id,
        "messages": [
            {
                "role": "user",
                "content": "approve",
                "function_approvals": [
                    {
                        "id": "call_update_001",
                        "call_id": "call_update_001",
                        "name": "update_record",
                        "approved": True,
                        "arguments": {"record_id": "beta", "value": "changed"},
                    }
                ],
            },
        ],
    }

    events2: list[Any] = []
    async for event in wrapper.run(turn2_input):
        events2.append(event)

    # The tool never ran, and the original request is still pending.
    assert executed_args == []
    assert any("call_update_001" in k for k in wrapper._pending_approvals), (
        "Pending approval should be preserved after argument mismatch for legitimate retry"
    )
|
||||
|
||||
|
||||
async def test_state_update_end_to_end_via_real_tool_invocation(streaming_chat_client_stub):
|
||||
"""End-to-end coverage for issue #3167: a real ``@tool`` returning ``state_update`` must
|
||||
emit a deterministic STATE_SNAPSHOT through the full pipeline.
|
||||
|
||||
@@ -1352,6 +1352,70 @@ async def test_workflow_run_approval_via_messages_approved() -> None:
|
||||
assert not resumed_finished.get("interrupt")
|
||||
|
||||
|
||||
async def test_workflow_run_approval_argument_mismatch_keeps_interrupt_pending() -> None:
    """Workflow approval responses must not resume with changed function arguments."""

    # Arguments seen by the response handler; must stay empty for the test to pass.
    handled_responses: list[dict[str, Any]] = []

    class ApprovalExecutor(Executor):
        def __init__(self) -> None:
            super().__init__(id="approval_executor")

        @handler
        async def start(self, message: Any, ctx: WorkflowContext) -> None:
            del message
            function_call = Content.from_function_call(
                call_id="refund-call",
                name="submit_refund",
                arguments={"order_id": "12345", "amount": "$89.99"},
            )
            approval_request = Content.from_function_approval_request(id="approval-1", function_call=function_call)
            await ctx.request_info(approval_request, Content, request_id="approval-1")

        @response_handler
        async def handle_approval(self, original_request: Content, response: Content, ctx: WorkflowContext) -> None:
            del original_request
            if response.function_call is not None:
                handled_responses.append(response.function_call.parse_arguments() or {})
            await ctx.yield_output("handled")

    workflow = WorkflowBuilder(start_executor=ApprovalExecutor()).build()
    first_events = [
        event async for event in run_workflow_stream({"messages": [{"role": "user", "content": "go"}]}, workflow)
    ]
    first_finished = [event for event in first_events if event.type == "RUN_FINISHED"][0].model_dump()
    interrupt_payload = cast(list[dict[str, Any]], first_finished.get("interrupt"))
    assert isinstance(interrupt_payload, list) and len(interrupt_payload) == 1

    # Resume with an approval for the same request id, call id and name but different arguments.
    resumed_events = [
        event
        async for event in run_workflow_stream(
            {
                "messages": [
                    {
                        "role": "user",
                        "content": "",
                        "function_approvals": [
                            {
                                "approved": True,
                                "id": "approval-1",
                                "call_id": "refund-call",
                                "name": "submit_refund",
                                "arguments": {"order_id": "99999", "amount": "$1000.00"},
                            }
                        ],
                    }
                ],
            },
            workflow,
        )
    ]

    # The handler never ran, and the run still reports an interrupt.
    assert handled_responses == []
    resumed_finished = [event for event in resumed_events if event.type == "RUN_FINISHED"][0].model_dump()
    assert resumed_finished.get("interrupt")
|
||||
|
||||
|
||||
async def test_workflow_run_approval_via_messages_denied() -> None:
|
||||
"""Denied approval response sent via messages (function_approvals) should satisfy the pending request."""
|
||||
|
||||
|
||||
@@ -221,9 +221,31 @@ class ClaudeAgentOptions(TypedDict, total=False):
|
||||
thinking: ThinkingConfig
|
||||
"""Extended thinking configuration (adaptive, enabled, or disabled)."""
|
||||
|
||||
effort: Literal["low", "medium", "high", "max"]
|
||||
effort: Literal["low", "medium", "high", "xhigh", "max"]
|
||||
"""Effort level for thinking depth."""
|
||||
|
||||
skills: list[str] | Literal["all"]
|
||||
"""Skills to enable for the main session. Use ``"all"`` for every discovered skill,
|
||||
a list of named skills, or ``[]`` to suppress all skills."""
|
||||
|
||||
session_id: str
|
||||
"""Use a specific session ID (must be a valid UUID) instead of auto-generated."""
|
||||
|
||||
task_budget: dict[str, int]
|
||||
"""API-side task budget in tokens for pacing tool use."""
|
||||
|
||||
include_hook_events: bool
|
||||
"""When True, hook lifecycle events are emitted in the message stream."""
|
||||
|
||||
strict_mcp_config: bool
|
||||
"""When True, only use MCP servers passed via ``mcp_servers``, ignoring all others."""
|
||||
|
||||
continue_conversation: bool
|
||||
"""Continue the most recent conversation instead of starting a new one."""
|
||||
|
||||
fork_session: bool
|
||||
"""When True, resumed sessions fork to a new session ID."""
|
||||
|
||||
on_function_approval: FunctionApprovalCallback
|
||||
"""Approval callback for ``FunctionTool`` instances declared with
|
||||
``approval_mode="always_require"``. The callback is awaited (sync or async)
|
||||
|
||||
@@ -24,7 +24,7 @@ classifiers = [
|
||||
]
|
||||
dependencies = [
|
||||
"agent-framework-core>=1.6.0,<2",
|
||||
"claude-agent-sdk>=0.1.36,<0.1.49",
|
||||
"claude-agent-sdk>=0.1.36,<0.3",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
|
||||
@@ -80,6 +80,8 @@ agent_framework/
|
||||
|
||||
- **`MCPTool`** - Base wrapper that owns the MCP `ClientSession` and exposes the remote server's tools as `FunctionTool`s.
|
||||
- **`MCPStdioTool`** / **`MCPStreamableHTTPTool`** / **`MCPWebsocketTool`** - Transport-specific subclasses.
|
||||
- **Argument allowlist (`_prepare_call_kwargs`)** - Before each `tools/call`, kwargs are filtered to an **allowlist** built from the tool's declared parameters (`inputSchema.properties`) plus any user-configured extras. Framework runtime kwargs injected through the function-invocation pipeline (e.g. `thread`, `conversation_id`, `chat_options`, `options`, `response_format`) are stripped by default rather than forwarded. A tool that declares no usable `properties` (including schemas with `additionalProperties: true`) forwards only the configured extras. The `_MCP_FRAMEWORK_DENYLIST` is a safety net for framework-named params a server *declares* in its schema (those are dropped); names explicitly opted in via `additional_tool_argument_names` always win. The reserved `_meta` key is extracted as MCP request metadata, never forwarded as an argument.
|
||||
- **`additional_tool_argument_names`** (constructor arg on all `MCPTool` subclasses) - Opt extra argument names back into the allowlist. Accepts a `Sequence[str]` (applied to every tool) or a `Mapping[str, Sequence[str]]` keyed by **remote tool name**, where the reserved key `"*"` denotes global extras. It is configured only in user code at construction; there is **no per-call/runtime override**, so a model-issued tool call cannot change which names pass through. To use a server that accepts `additionalProperties: true`, list the extra names here and then either (1) manually extend that tool's `inputSchema` (via the `.functions` list after connecting) so the model is prompted to supply them, or (2) supply the values yourself via `function_invocation_kwargs`. If a name is supplied by both the model and `function_invocation_kwargs`, the model-supplied value wins.
|
||||
- **`MCPTaskOptions`** (experimental, `MCP_LONG_RUNNING_TASKS` feature, **frozen**) - Per-tool-instance options controlling the SEP-2663 long-running task lifecycle. When the server advertises a tool with `execution.taskSupport == "required"`, `MCPTool.call_tool` transparently routes through `call_tool_as_task`, which sends an augmented `tools/call`, polls `tasks/get` until terminal, and reinterprets `tasks/result` as a normal `CallToolResult`. Instances are immutable; replace via `MCPTool.task_options = MCPTaskOptions(...)`. Fields:
|
||||
- `default_ttl: timedelta | None` — forwarded to the server as `params.task.ttl` (milliseconds). When `None`, the server's default applies.
|
||||
- `cancel_remote_task_on_local_cancellation: bool = True` — only gates the `CancelledError` path. Abandonment paths (see below) always cancel.
|
||||
|
||||
@@ -92,12 +92,16 @@ OptionsCoT = TypeVar(
|
||||
def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Merge two options dicts, with override values taking precedence.
|
||||
|
||||
``None`` is treated as "unset": ``None`` overrides are skipped so they don't clobber a base
|
||||
value, and the merged result is stripped of any remaining ``None`` values in a final pass so
|
||||
unset options are never forwarded (e.g. an unset ``store`` is left for the service to default).
|
||||
|
||||
Args:
|
||||
base: The base options dict.
|
||||
override: The override options dict (values take precedence).
|
||||
|
||||
Returns:
|
||||
A new merged options dict.
|
||||
A new merged options dict containing no ``None`` values.
|
||||
"""
|
||||
result = dict(base)
|
||||
|
||||
@@ -123,7 +127,7 @@ def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str,
|
||||
result["instructions"] = f"{result['instructions']}\n{value}"
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
return {key: value for key, value in result.items() if value is not None}
|
||||
|
||||
|
||||
def _sanitize_agent_name(agent_name: str | None) -> str | None:
|
||||
@@ -460,6 +464,9 @@ class BaseAgent(SerializationMixin):
|
||||
if provider_session is None and self.context_providers:
|
||||
provider_session = AgentSession()
|
||||
|
||||
# When per-service-call persistence is enabled, the per-service-call middleware owns
|
||||
# HistoryProvider persistence (in both the local and service-managed cases), so skip
|
||||
# them on the once-per-run path to avoid double persistence.
|
||||
per_service_call_history_required = self.require_per_service_call_history_persistence and any(
|
||||
isinstance(provider, HistoryProvider) for provider in self.context_providers
|
||||
)
|
||||
@@ -686,11 +693,16 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
description: A brief description of the agent's purpose.
|
||||
context_providers: Context providers to include during agent invocation.
|
||||
middleware: List of middleware to intercept agent and function invocations.
|
||||
require_per_service_call_history_persistence: When True, history providers are invoked
|
||||
around each model call instead of once per ``run()`` when the service
|
||||
is not already storing history. If service-side storage is active for
|
||||
the run, the agent skips local history providers and relies on the
|
||||
service-managed conversation instead.
|
||||
require_per_service_call_history_persistence: When True (and a HistoryProvider is
|
||||
present), the provider always persists history via per-service-call middleware,
|
||||
regardless of whether the client stores history server-side. If the client does
|
||||
not store history, the middleware also loads providers around each model call and
|
||||
drives the function loop with a local conversation; if it does, loading is skipped
|
||||
(the service-managed conversation is the source of truth) and the middleware only
|
||||
persists. A warning is logged for providers with ``load_messages=True`` when
|
||||
loading is skipped because service-side storage is active. When no HistoryProvider
|
||||
is present, this flag has no effect (no middleware is installed and nothing is
|
||||
persisted).
|
||||
default_options: A TypedDict containing chat options. When using a typed agent like
|
||||
``Agent[OpenAIChatOptions]``, this enables IDE autocomplete for
|
||||
provider-specific options including temperature, max_tokens, model,
|
||||
@@ -791,22 +803,20 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
self,
|
||||
*,
|
||||
session: AgentSession | None,
|
||||
options: Mapping[str, Any] | None,
|
||||
conversation_id: str | None,
|
||||
service_stores_history: bool,
|
||||
) -> list[HistoryProvider]:
|
||||
history_providers = self._get_history_providers()
|
||||
if not self.require_per_service_call_history_persistence or not history_providers:
|
||||
return []
|
||||
|
||||
conversation_id = (
|
||||
session.service_session_id
|
||||
if session and session.service_session_id
|
||||
else cast(str | None, (options or {}).get("conversation_id") or self.default_options.get("conversation_id"))
|
||||
)
|
||||
if service_stores_history:
|
||||
return []
|
||||
|
||||
if conversation_id is not None:
|
||||
# A live service-managed session id takes precedence over the resolved conversation id.
|
||||
if session and session.service_session_id:
|
||||
conversation_id = session.service_session_id
|
||||
# Without service-side storage the middleware persists locally and drives the function
|
||||
# loop with a local sentinel, which cannot be reconciled with an existing service-managed
|
||||
# conversation. When the service stores history, an existing conversation id is expected.
|
||||
if conversation_id is not None and not service_stores_history:
|
||||
raise AgentInvalidRequestException(
|
||||
"require_per_service_call_history_persistence cannot be used "
|
||||
"with an existing service-managed conversation."
|
||||
@@ -1167,18 +1177,34 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
|
||||
input_messages = normalize_messages(messages)
|
||||
|
||||
# `store` in runtime or agent options takes precedence over client-level storage
|
||||
# indicators. An explicit `store=False` forces local (in-memory) history injection,
|
||||
# even if the client is configured to use service-side storage by default.
|
||||
store_ = opts.get("store", self.default_options.get("store", getattr(self.client, "STORES_BY_DEFAULT", False)))
|
||||
# Combine agent-level defaults with runtime options up front so the decisions below read
|
||||
# `store` from a single place rather than introspecting both dicts. _merge_options applies
|
||||
# the same precedence used for the actual client call (runtime wins; unset/None falls back
|
||||
# to the agent default).
|
||||
effective_options = _merge_options(self.default_options, opts)
|
||||
|
||||
# `store` in runtime or agent options takes precedence over the client's default
|
||||
# storage behavior. An explicit `store=False` forces local (in-memory) history
|
||||
# injection even when the client stores server-side by default; an explicit
|
||||
# `store=True` forces service-side storage. A `store=None`/unset value means the
|
||||
# service falls back to its own default.
|
||||
explicit_store = effective_options.get("store")
|
||||
# Internal behavior hint: will the service own history for this run? Only when the
|
||||
# user left `store` unset do we fall back to the client's STORES_BY_DEFAULT.
|
||||
service_stores_history = (
|
||||
explicit_store if explicit_store is not None else getattr(self.client, "STORES_BY_DEFAULT", False)
|
||||
)
|
||||
# Resolve conversation_id from the same combined view so an agent-level default is honored
|
||||
# when the runtime omits it (a live session id still takes precedence below).
|
||||
effective_conversation_id = effective_options.get("conversation_id")
|
||||
# Auto-inject InMemoryHistoryProvider when session is provided, no context providers
|
||||
# registered, and no service-side storage indicators
|
||||
if (
|
||||
session is not None
|
||||
and not self.context_providers
|
||||
and not session.service_session_id
|
||||
and not opts.get("conversation_id")
|
||||
and not store_
|
||||
and not effective_conversation_id
|
||||
and not service_stores_history
|
||||
):
|
||||
self.context_providers.append(InMemoryHistoryProvider())
|
||||
|
||||
@@ -1188,10 +1214,30 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
|
||||
per_service_call_history_providers = self._resolve_per_service_call_history_providers(
|
||||
session=active_session,
|
||||
options=opts,
|
||||
service_stores_history=bool(store_),
|
||||
conversation_id=effective_conversation_id,
|
||||
service_stores_history=service_stores_history,
|
||||
)
|
||||
|
||||
# When require_per_service_call_history_persistence is set together with a
|
||||
# HistoryProvider, the per-service-call middleware (installed below) always persists
|
||||
# the provider. ``service_stores_history`` only selects how the middleware behaves:
|
||||
# - service does not store: the middleware also loads providers and drives the function
|
||||
# loop with a local sentinel conversation id, or
|
||||
# - service stores: the middleware skips loading (the service owns history) and simply
|
||||
# persists each service call while the real conversation id flows through.
|
||||
# In the service-managed case loading is skipped, so warn for providers that expect to load.
|
||||
history_providers = self._get_history_providers()
|
||||
if self.require_per_service_call_history_persistence and history_providers and service_stores_history:
|
||||
for provider in history_providers:
|
||||
if provider.load_messages:
|
||||
logger.warning(
|
||||
"HistoryProvider '%s' has load_messages=True but the chat client stores history "
|
||||
"server-side; skipping local history load and relying on the service-managed "
|
||||
"conversation. Set store=False to load from the provider, or load_messages=False "
|
||||
"to silence this warning.",
|
||||
provider.source_id,
|
||||
)
|
||||
|
||||
session_context, chat_options = await self._prepare_session_and_messages(
|
||||
session=active_session,
|
||||
input_messages=input_messages,
|
||||
@@ -1265,8 +1311,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
}
|
||||
if model is not None:
|
||||
run_opts["model"] = model
|
||||
# Remove None values and merge with chat_options
|
||||
run_opts = {k: v for k, v in run_opts.items() if v is not None}
|
||||
# _merge_options strips unset (None) options, so e.g. an unset `store` is not forwarded
|
||||
# and the service decides its own default.
|
||||
co = _merge_options(chat_options, run_opts)
|
||||
|
||||
# Build session_messages from session context: context messages + input messages
|
||||
@@ -1280,6 +1326,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
agent=self,
|
||||
session=active_session,
|
||||
providers=per_service_call_history_providers,
|
||||
service_stores_history=service_stores_history,
|
||||
)
|
||||
existing_middleware = effective_client_kwargs.get("middleware")
|
||||
if isinstance(existing_middleware, Sequence) and not isinstance(existing_middleware, (str, bytes)):
|
||||
@@ -1319,7 +1366,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
"input_messages": input_messages,
|
||||
"session_messages": session_messages,
|
||||
"agent_name": agent_name,
|
||||
"suppress_response_id": bool(per_service_call_history_providers),
|
||||
"suppress_response_id": bool(per_service_call_history_providers) and not service_stores_history,
|
||||
"chat_options": co,
|
||||
"compaction_strategy": compaction_strategy or self.compaction_strategy,
|
||||
"tokenizer": tokenizer or self.tokenizer,
|
||||
@@ -1413,11 +1460,15 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
options=options or {},
|
||||
)
|
||||
|
||||
# When per-service-call persistence is enabled, the per-service-call middleware owns
|
||||
# HistoryProvider loading (it loads locally when the service does not store history, or
|
||||
# relies on the service when it does), so skip them on the once-per-run before_run path.
|
||||
per_service_call_history_required = self.require_per_service_call_history_persistence and bool(
|
||||
self._get_history_providers()
|
||||
)
|
||||
|
||||
# Run before_run providers (forward order, skip HistoryProvider when per-service-call persistence owns history)
|
||||
# Run before_run providers (forward order, skip HistoryProvider when per-service-call
|
||||
# persistence owns loading)
|
||||
for provider in self.context_providers:
|
||||
if per_service_call_history_required and isinstance(provider, HistoryProvider):
|
||||
continue
|
||||
|
||||
@@ -604,10 +604,13 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
and dict literals are accepted without specialized option typing.
|
||||
context_providers: Context providers to include during agent invocation.
|
||||
middleware: List of middleware to intercept agent and function invocations.
|
||||
require_per_service_call_history_persistence: Whether to require per-service-call
|
||||
chat history persistence. When enabled, history providers are invoked around
|
||||
each model call instead of once per ``run()`` when the service is not already
|
||||
storing history.
|
||||
require_per_service_call_history_persistence: When enabled (and a HistoryProvider is
|
||||
present), the provider always persists history after each model call. If the
|
||||
client does not store history server-side, history providers are also loaded and
|
||||
injected around each model call; if it does, provider loading is skipped and the
|
||||
service-managed conversation is the source of truth (persistence still happens
|
||||
after each model call). When no HistoryProvider is present, this flag has no
|
||||
effect (no middleware is installed and nothing is persisted).
|
||||
function_invocation_configuration: Optional function invocation configuration override.
|
||||
compaction_strategy: Optional agent-level compaction override. When omitted,
|
||||
client-level compaction defaults remain in effect for each call.
|
||||
|
||||
@@ -70,6 +70,31 @@ class MCPSpecificApproval(TypedDict, total=False):
|
||||
|
||||
_MCP_REMOTE_NAME_KEY = "_mcp_remote_name"
_MCP_NORMALIZED_NAME_KEY = "_mcp_normalized_name"
# Reserved key in an ``additional_tool_argument_names`` mapping that applies its
# values to every tool on the server rather than a single named tool.
_MCP_GLOBAL_EXTRA_ARGS_KEY = "*"
# Framework kwargs that flow through the function-invocation pipeline (via
# ``FunctionInvocationContext.kwargs``) but must never be forwarded to an MCP
# server: they are internal objects that the MCP SDK cannot serialize. They are
# dropped as a safety net when a tool declares one of them in its schema, unless
# the user explicitly opts the name back in via ``additional_tool_argument_names``
# (explicit extras always win over the denylist).
# - chat_options/tools/tool_choice/session/thread: framework runtime objects.
# - conversation_id: internal tracking ID used by services like Azure AI.
# - options: metadata/store used by AG-UI for Azure AI client requirements.
# - response_format: a Pydantic model class for structured output (not serializable).
# - _meta: reserved key extracted separately as MCP request metadata.
_MCP_FRAMEWORK_DENYLIST: frozenset[str] = frozenset({
    "chat_options",
    "tools",
    "tool_choice",
    "session",
    "thread",
    "conversation_id",
    "options",
    "response_format",
    "_meta",
})
# Per-call HTTP headers for MCP requests, carried in a context variable (no default value).
_mcp_call_headers: contextvars.ContextVar[dict[str, str]] = contextvars.ContextVar("_mcp_call_headers")
MCP_DEFAULT_TIMEOUT = 30
MCP_DEFAULT_SSE_READ_TIMEOUT = 60 * 5
|
||||
@@ -135,6 +160,34 @@ def _build_prefixed_mcp_name(
|
||||
return f"{normalized_prefix}_{trimmed_name}" if trimmed_name else normalized_prefix
|
||||
|
||||
|
||||
def _normalize_additional_tool_argument_names(
|
||||
additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None,
|
||||
) -> tuple[set[str], dict[str, set[str]]]:
|
||||
"""Split user-supplied extra argument names into global and per-tool sets.
|
||||
|
||||
Accepts either a sequence (applied to every tool) or a mapping keyed by remote
|
||||
tool name, where the reserved key ``"*"`` is treated as global. Mapping values
|
||||
may be a sequence or a single string. Returns a
|
||||
``(global_extras, per_tool_extras)`` tuple.
|
||||
"""
|
||||
if additional_tool_argument_names is None:
|
||||
return set(), {}
|
||||
if isinstance(additional_tool_argument_names, str):
|
||||
return {additional_tool_argument_names}, {}
|
||||
if isinstance(additional_tool_argument_names, Mapping):
|
||||
global_extras: set[str] = set()
|
||||
per_tool_extras: dict[str, set[str]] = {}
|
||||
for tool_name, names in additional_tool_argument_names.items():
|
||||
# Treat a bare string value as a single name rather than iterating its characters.
|
||||
names_set = {names} if isinstance(names, str) else set(names)
|
||||
if tool_name == _MCP_GLOBAL_EXTRA_ARGS_KEY:
|
||||
global_extras.update(names_set)
|
||||
else:
|
||||
per_tool_extras[tool_name] = names_set
|
||||
return global_extras, per_tool_extras
|
||||
return set(additional_tool_argument_names), {}
|
||||
|
||||
|
||||
def _inject_otel_into_mcp_meta(meta: dict[str, Any] | None = None) -> dict[str, Any] | None:
|
||||
"""Inject OpenTelemetry trace context into MCP request _meta via the global propagator(s)."""
|
||||
carrier: dict[str, str] = {}
|
||||
@@ -294,6 +347,7 @@ class MCPTool:
|
||||
client: SupportsChatGetResponse | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
task_options: MCPTaskOptions | None = None,
|
||||
additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the MCP Tool base.
|
||||
|
||||
@@ -328,6 +382,10 @@ class MCPTool:
|
||||
task_options: Options controlling how long-running MCP tasks are driven for
|
||||
tools that advertise ``execution.taskSupport == "required"``. When ``None``,
|
||||
the defaults from :class:`MCPTaskOptions` are used.
|
||||
additional_tool_argument_names: Extra argument names to forward to the MCP server
|
||||
in addition to each tool's declared parameters. A ``Sequence[str]`` applies to
|
||||
every tool; a ``Mapping[str, Sequence[str]]`` is keyed by remote tool name with
|
||||
``"*"`` as a global key. See the transport subclasses for full details.
|
||||
"""
|
||||
self.name = name
|
||||
self.description = description or ""
|
||||
@@ -355,6 +413,10 @@ class MCPTool:
|
||||
self._functions: list[FunctionTool] = []
|
||||
self._tool_call_meta_by_name: dict[str, dict[str, Any]] = {}
|
||||
self._tool_task_support_by_name: dict[str, str] = {}
|
||||
self._tool_param_names_by_name: dict[str, set[str]] = {}
|
||||
self._global_extra_arg_names, self._tool_extra_arg_names = _normalize_additional_tool_argument_names(
|
||||
additional_tool_argument_names
|
||||
)
|
||||
self.is_connected: bool = False
|
||||
self._tools_loaded: bool = False
|
||||
self._prompts_loaded: bool = False
|
||||
@@ -1229,6 +1291,7 @@ class MCPTool:
|
||||
existing_names = {func.name for func in self._functions}
|
||||
tool_call_meta_by_name: dict[str, dict[str, Any]] = {}
|
||||
tool_task_support_by_name: dict[str, str] = {}
|
||||
tool_param_names_by_name: dict[str, set[str]] = {}
|
||||
|
||||
params: types.PaginatedRequestParams | None = None
|
||||
while True:
|
||||
@@ -1271,14 +1334,6 @@ class MCPTool:
|
||||
if task_support is not None:
|
||||
tool_task_support_by_name[tool.name] = task_support
|
||||
|
||||
normalized_name = _normalize_mcp_name(tool.name)
|
||||
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
|
||||
|
||||
# Skip if already loaded
|
||||
if local_name in existing_names:
|
||||
continue
|
||||
|
||||
approval_mode = self._determine_approval_mode(local_name, normalized_name, tool.name)
|
||||
# Normalize inputSchema: ensure "properties" exists for object schemas.
|
||||
# Some MCP servers (e.g. zero-argument tools) omit "properties",
|
||||
# which causes OpenAI API to reject the schema with a 400 error.
|
||||
@@ -1288,6 +1343,24 @@ class MCPTool:
|
||||
if input_schema.get("type") == "object" and "properties" not in input_schema:
|
||||
input_schema["properties"] = {}
|
||||
|
||||
# Register declared param names before the existing-tool skip below so that
|
||||
# reloads (e.g. notifications/tools/list_changed) preserve the allowlist for
|
||||
# tools that are already loaded, consistent with tool_call_meta_by_name and
|
||||
# tool_task_support_by_name above.
|
||||
schema_properties = input_schema.get("properties")
|
||||
tool_param_names_by_name[tool.name] = (
|
||||
set(cast(dict[str, Any], schema_properties)) if isinstance(schema_properties, dict) else set()
|
||||
)
|
||||
|
||||
normalized_name = _normalize_mcp_name(tool.name)
|
||||
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
|
||||
|
||||
# Skip if already loaded
|
||||
if local_name in existing_names:
|
||||
continue
|
||||
|
||||
approval_mode = self._determine_approval_mode(local_name, normalized_name, tool.name)
|
||||
|
||||
async def _call_tool_with_runtime_kwargs(
|
||||
ctx: FunctionInvocationContext,
|
||||
*,
|
||||
@@ -1320,6 +1393,7 @@ class MCPTool:
|
||||
|
||||
self._tool_call_meta_by_name = tool_call_meta_by_name
|
||||
self._tool_task_support_by_name = tool_task_support_by_name
|
||||
self._tool_param_names_by_name = tool_param_names_by_name
|
||||
|
||||
async def _close_on_owner(self) -> None:
|
||||
# Cancel any pending reload tasks before tearing down the session.
|
||||
@@ -1530,10 +1604,14 @@ class MCPTool:
|
||||
raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex
|
||||
raise ToolExecutionException(f"Failed to call tool '{tool_name}' after retries.")
|
||||
|
||||
def _resolved_extra_args(self, tool_name: str) -> set[str]:
    """Return the user-configured extra argument names allowed for a tool.

    Args:
        tool_name: Key used to look up per-tool extras in ``_tool_extra_arg_names``.

    Returns:
        A new set containing the global extras plus any extras registered for
        ``tool_name``; the stored configuration sets are not modified.
    """
    return self._global_extra_arg_names | self._tool_extra_arg_names.get(tool_name, set())
|
||||
|
||||
def _prepare_call_kwargs(
|
||||
self, tool_name: str, kwargs: dict[str, Any]
|
||||
) -> tuple[dict[str, Any], dict[str, Any] | None]:
|
||||
"""Filter framework-only kwargs and build the merged MCP request metadata."""
|
||||
"""Filter kwargs down to the tool's arguments and build the merged MCP request metadata."""
|
||||
raw_user_meta: object | None = kwargs.get("_meta")
|
||||
user_meta: dict[str, Any] | None = None
|
||||
if raw_user_meta is not None and not isinstance(raw_user_meta, dict):
|
||||
@@ -1546,27 +1624,28 @@ class MCPTool:
|
||||
raise ToolExecutionException("MCP tool metadata provided via _meta must use string keys.")
|
||||
user_meta[key] = value
|
||||
|
||||
# Filter out framework kwargs that cannot be serialized by the MCP SDK.
|
||||
# These are internal objects passed through the function invocation pipeline
|
||||
# that should not be forwarded to external MCP servers.
|
||||
# conversation_id is an internal tracking ID used by services like Azure AI.
|
||||
# options contains metadata/store used by AG-UI for Azure AI client requirements.
|
||||
# response_format is a Pydantic model class used for structured output (not serializable).
|
||||
# Allowlist: forward only the tool's declared parameters (from inputSchema.properties)
|
||||
# plus any user-configured extra argument names. Everything else - notably the
|
||||
# framework runtime kwargs injected through the function-invocation pipeline - is
|
||||
# stripped so it is never forwarded to the MCP server. Tools that declare no usable
|
||||
# properties forward only the user-configured extras.
|
||||
#
|
||||
# The extra names come exclusively from additional_tool_argument_names, which is set in
|
||||
# user code at construction time; there is no per-call override, so a model-issued tool
|
||||
# call cannot change which names are allowed through.
|
||||
#
|
||||
# The framework denylist acts as a safety net for keys a server *declares* in its
|
||||
# schema that collide with internal, non-serializable framework objects (e.g. a tool
|
||||
# that declares a parameter literally named "thread"): such declared-but-denylisted
|
||||
# keys are dropped. Names the user explicitly opts in via additional_tool_argument_names
|
||||
# always win. The reserved _meta key is handled separately above and never forwarded as
|
||||
# an argument.
|
||||
declared = self._tool_param_names_by_name.get(tool_name, set())
|
||||
extras = self._resolved_extra_args(tool_name)
|
||||
filtered_kwargs = {
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if k
|
||||
not in {
|
||||
"chat_options",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"session",
|
||||
"thread",
|
||||
"conversation_id",
|
||||
"options",
|
||||
"response_format",
|
||||
"_meta",
|
||||
}
|
||||
if k != "_meta" and (k in extras or (k in declared and k not in _MCP_FRAMEWORK_DENYLIST))
|
||||
}
|
||||
|
||||
# Some MCP proxies require their tools/list metadata to be echoed on tools/call.
|
||||
@@ -2142,6 +2221,7 @@ class MCPStdioTool(MCPTool):
|
||||
client: SupportsChatGetResponse | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
task_options: MCPTaskOptions | None = None,
|
||||
additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the MCP stdio tool.
|
||||
@@ -2188,6 +2268,20 @@ class MCPStdioTool(MCPTool):
|
||||
client: The chat client to use for sampling.
|
||||
task_options: Options for tools that advertise
|
||||
``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`.
|
||||
additional_tool_argument_names: Extra argument names to forward to the MCP server in
|
||||
addition to each tool's declared parameters (from its ``inputSchema.properties``).
|
||||
By default only declared parameters are sent; framework runtime kwargs injected
|
||||
through the function-invocation pipeline are stripped. Use this to opt specific
|
||||
keys back in. Accepts either a ``Sequence[str]`` applied to every tool, or a
|
||||
``Mapping[str, Sequence[str]]`` keyed by remote tool name where the reserved key
|
||||
``"*"`` applies to every tool. This is configured only here in user code; there is
|
||||
no per-call override, so a model-issued tool call cannot change which names pass
|
||||
through. To use a server that accepts ``additionalProperties: true``, list the
|
||||
extra names here and then either (1) manually extend that tool's ``inputSchema``
|
||||
(via the ``.functions`` list after connecting) so the model is prompted to supply
|
||||
them, or (2) supply the values yourself through ``function_invocation_kwargs``. If
|
||||
a name is supplied via both the model and ``function_invocation_kwargs``, the
|
||||
model-supplied value wins.
|
||||
kwargs: Any extra arguments to pass to the stdio client.
|
||||
"""
|
||||
super().__init__(
|
||||
@@ -2205,6 +2299,7 @@ class MCPStdioTool(MCPTool):
|
||||
parse_prompt_results=parse_prompt_results,
|
||||
request_timeout=request_timeout,
|
||||
task_options=task_options,
|
||||
additional_tool_argument_names=additional_tool_argument_names,
|
||||
)
|
||||
self.command = command
|
||||
self.args = args or []
|
||||
@@ -2284,6 +2379,7 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
http_client: AsyncClient | None = None,
|
||||
header_provider: Callable[[dict[str, Any]], dict[str, str]] | None = None,
|
||||
task_options: MCPTaskOptions | None = None,
|
||||
additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the MCP streamable HTTP tool.
|
||||
@@ -2338,6 +2434,20 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
agent middleware) without creating a separate ``httpx.AsyncClient``.
|
||||
task_options: Options for tools that advertise
|
||||
``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`.
|
||||
additional_tool_argument_names: Extra argument names to forward to the MCP server in
|
||||
addition to each tool's declared parameters (from its ``inputSchema.properties``).
|
||||
By default only declared parameters are sent; framework runtime kwargs injected
|
||||
through the function-invocation pipeline are stripped. Use this to opt specific
|
||||
keys back in. Accepts either a ``Sequence[str]`` applied to every tool, or a
|
||||
``Mapping[str, Sequence[str]]`` keyed by remote tool name where the reserved key
|
||||
``"*"`` applies to every tool. This is configured only here in user code; there is
|
||||
no per-call override, so a model-issued tool call cannot change which names pass
|
||||
through. To use a server that accepts ``additionalProperties: true``, list the
|
||||
extra names here and then either (1) manually extend that tool's ``inputSchema``
|
||||
(via the ``.functions`` list after connecting) so the model is prompted to supply
|
||||
them, or (2) supply the values yourself through ``function_invocation_kwargs``. If
|
||||
a name is supplied via both the model and ``function_invocation_kwargs``, the
|
||||
model-supplied value wins.
|
||||
kwargs: Additional keyword arguments (accepted for backward compatibility but not used).
|
||||
"""
|
||||
super().__init__(
|
||||
@@ -2355,6 +2465,7 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
parse_prompt_results=parse_prompt_results,
|
||||
request_timeout=request_timeout,
|
||||
task_options=task_options,
|
||||
additional_tool_argument_names=additional_tool_argument_names,
|
||||
)
|
||||
self.url = url
|
||||
self.terminate_on_close = terminate_on_close
|
||||
@@ -2481,6 +2592,7 @@ class MCPWebsocketTool(MCPTool):
|
||||
client: SupportsChatGetResponse | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
task_options: MCPTaskOptions | None = None,
|
||||
additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the MCP WebSocket tool.
|
||||
@@ -2525,6 +2637,20 @@ class MCPWebsocketTool(MCPTool):
|
||||
client: The chat client to use for sampling.
|
||||
task_options: Options for tools that advertise
|
||||
``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`.
|
||||
additional_tool_argument_names: Extra argument names to forward to the MCP server in
|
||||
addition to each tool's declared parameters (from its ``inputSchema.properties``).
|
||||
By default only declared parameters are sent; framework runtime kwargs injected
|
||||
through the function-invocation pipeline are stripped. Use this to opt specific
|
||||
keys back in. Accepts either a ``Sequence[str]`` applied to every tool, or a
|
||||
``Mapping[str, Sequence[str]]`` keyed by remote tool name where the reserved key
|
||||
``"*"`` applies to every tool. This is configured only here in user code; there is
|
||||
no per-call override, so a model-issued tool call cannot change which names pass
|
||||
through. To use a server that accepts ``additionalProperties: true``, list the
|
||||
extra names here and then either (1) manually extend that tool's ``inputSchema``
|
||||
(via the ``.functions`` list after connecting) so the model is prompted to supply
|
||||
them, or (2) supply the values yourself through ``function_invocation_kwargs``. If
|
||||
a name is supplied via both the model and ``function_invocation_kwargs``, the
|
||||
model-supplied value wins.
|
||||
kwargs: Any extra arguments to pass to the WebSocket client.
|
||||
"""
|
||||
super().__init__(
|
||||
@@ -2542,6 +2668,7 @@ class MCPWebsocketTool(MCPTool):
|
||||
parse_prompt_results=parse_prompt_results,
|
||||
request_timeout=request_timeout,
|
||||
task_options=task_options,
|
||||
additional_tool_argument_names=additional_tool_argument_names,
|
||||
)
|
||||
self.url = url
|
||||
self._client_kwargs = kwargs
|
||||
|
||||
@@ -16,6 +16,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
import weakref
|
||||
@@ -36,6 +37,8 @@ if TYPE_CHECKING:
|
||||
from ._middleware import MiddlewareTypes
|
||||
|
||||
|
||||
logger = logging.getLogger("agent_framework")
|
||||
|
||||
# Registry of known types for state deserialization
|
||||
_STATE_TYPE_REGISTRY: dict[str, type] = {}
|
||||
|
||||
@@ -580,6 +583,7 @@ class PerServiceCallHistoryPersistingMiddleware(ChatMiddleware):
|
||||
agent: SupportsAgentRun,
|
||||
session: AgentSession,
|
||||
providers: Sequence[HistoryProvider],
|
||||
service_stores_history: bool = False,
|
||||
) -> None:
|
||||
"""Initialize the middleware.
|
||||
|
||||
@@ -587,10 +591,16 @@ class PerServiceCallHistoryPersistingMiddleware(ChatMiddleware):
|
||||
agent: The agent that owns the history providers.
|
||||
session: The active session for the current run.
|
||||
providers: The history providers participating in per-service-call persistence.
|
||||
service_stores_history: When True, the chat client stores history server-side. The
|
||||
middleware then skips loading providers and leaves the real conversation id
|
||||
untouched, persisting each service call without driving the function loop with a
|
||||
local sentinel. When False, the middleware loads providers and uses a local
|
||||
sentinel conversation id so the function loop runs without service-side storage.
|
||||
"""
|
||||
self._agent = agent
|
||||
self._session = session
|
||||
self._providers = list(providers)
|
||||
self._service_stores_history = service_stores_history
|
||||
|
||||
async def _prepare_service_call_context(self, messages: Sequence[Message]) -> SessionContext:
|
||||
"""Create a per-call SessionContext and load history providers into it."""
|
||||
@@ -602,6 +612,9 @@ class PerServiceCallHistoryPersistingMiddleware(ChatMiddleware):
|
||||
)
|
||||
for source_id, source_messages in context_messages.items():
|
||||
service_call_context.extend_messages(source_id, source_messages)
|
||||
# When the service stores history, it owns loading; the providers are write-only sinks.
|
||||
if self._service_stores_history:
|
||||
return service_call_context
|
||||
for provider in self._providers:
|
||||
if not provider.load_messages:
|
||||
continue
|
||||
@@ -652,17 +665,35 @@ class PerServiceCallHistoryPersistingMiddleware(ChatMiddleware):
|
||||
response: ChatResponse,
|
||||
) -> ChatResponse:
|
||||
"""Persist a model response and apply the local follow-up sentinel when needed."""
|
||||
if response.conversation_id is not None and not is_local_history_conversation_id(response.conversation_id):
|
||||
if (
|
||||
not self._service_stores_history
|
||||
and response.conversation_id is not None
|
||||
and not is_local_history_conversation_id(response.conversation_id)
|
||||
):
|
||||
raise ChatClientInvalidResponseException(
|
||||
"require_per_service_call_history_persistence cannot be used "
|
||||
"when the chat client returns a real conversation_id."
|
||||
)
|
||||
|
||||
# In storing mode the service is expected to echo a conversation id that the next run
|
||||
# resumes from. If it comes back empty, the provider still captures this turn but there is
|
||||
# no service id to load from next time, so cross-turn history can be lost silently. Warn
|
||||
# every time so this uncommon, easy-to-miss failure mode cannot fail quietly.
|
||||
if self._service_stores_history and response.conversation_id is None:
|
||||
logger.warning(
|
||||
"require_per_service_call_history_persistence is enabled with a chat client that "
|
||||
"stores history server-side, but the client returned no conversation_id; cross-turn "
|
||||
"history may not resume. Set store=False to load and resume from the HistoryProvider "
|
||||
"instead."
|
||||
)
|
||||
|
||||
await self._persist_service_call_response(
|
||||
service_call_context=service_call_context,
|
||||
response=response,
|
||||
)
|
||||
if _response_contains_follow_up_request(response):
|
||||
# The local sentinel only applies when the service does not store history; when it does,
|
||||
# the real conversation id already drives function-loop continuation.
|
||||
if not self._service_stores_history and _response_contains_follow_up_request(response):
|
||||
response.mark_internal_conversation_id()
|
||||
response.conversation_id = LOCAL_HISTORY_CONVERSATION_ID
|
||||
return response
|
||||
@@ -681,8 +712,12 @@ class PerServiceCallHistoryPersistingMiddleware(ChatMiddleware):
|
||||
result type for streaming or non-streaming execution.
|
||||
"""
|
||||
service_call_context = await self._prepare_service_call_context(context.messages)
|
||||
context.messages = service_call_context.get_messages(include_input=True)
|
||||
self._strip_local_conversation_id(context)
|
||||
# When the service stores history, leave the outgoing messages and the real conversation
|
||||
# id untouched (pass-through); the middleware only persists. Otherwise reconstruct the
|
||||
# outgoing messages from the loaded local history and strip the local sentinel.
|
||||
if not self._service_stores_history:
|
||||
context.messages = service_call_context.get_messages(include_input=True)
|
||||
self._strip_local_conversation_id(context)
|
||||
|
||||
await call_next()
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import contextlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, MutableSequence, Sequence
|
||||
from typing import Any, cast
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
@@ -42,6 +43,8 @@ from agent_framework._mcp import MCPTool, _build_prefixed_mcp_name, _normalize_m
|
||||
from agent_framework._middleware import FunctionInvocationContext
|
||||
from agent_framework.exceptions import AgentInvalidRequestException, ChatClientInvalidResponseException
|
||||
|
||||
from .conftest import MockBaseChatClient
|
||||
|
||||
|
||||
class _FixedTokenizer:
|
||||
def __init__(self, token_count: int) -> None:
|
||||
@@ -609,6 +612,7 @@ async def test_streaming_per_service_call_persistence_hides_response_id_from_aft
|
||||
|
||||
async def test_per_service_call_persistence_uses_real_service_storage_when_client_stores_by_default(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
provider = _RecordingHistoryProvider()
|
||||
|
||||
@@ -649,15 +653,22 @@ async def test_per_service_call_persistence_uses_real_service_storage_when_clien
|
||||
require_per_service_call_history_persistence=True,
|
||||
)
|
||||
|
||||
result = await agent.run("What's the weather in Seattle?", session=session)
|
||||
with caplog.at_level(logging.WARNING, logger="agent_framework"):
|
||||
result = await agent.run("What's the weather in Seattle?", session=session)
|
||||
|
||||
provider_state = session.state[provider.source_id]
|
||||
|
||||
assert result.text == "It is sunny in Seattle."
|
||||
assert result.response_id == "resp_call_2"
|
||||
assert chat_client_base.call_count == 2
|
||||
# The service owns the conversation, so the provider never loads (issue #5798).
|
||||
assert "get_call_count" not in provider_state
|
||||
assert "save_call_count" not in provider_state
|
||||
# Persistence is owned by the per-service-call middleware: it persists once per service call
|
||||
# (issue #5798: the provider must never be silently bypassed when the service stores history).
|
||||
# This run makes two service calls (function call + final answer), so it persists twice.
|
||||
assert provider_state["save_call_count"] == 2
|
||||
# load_messages=True while the service stores history surfaces a warning.
|
||||
assert any("load_messages" in record.message for record in caplog.records)
|
||||
assert session.service_session_id == "resp_service_managed"
|
||||
|
||||
|
||||
@@ -1996,6 +2007,19 @@ def test_merge_options_none_values_ignored():
|
||||
assert result["key2"] == "value2"
|
||||
|
||||
|
||||
def test_merge_options_drops_none_base_values() -> None:
    """Test _merge_options strips None values so unset options are never forwarded."""
    base = {"store": None, "temperature": 0.5}
    override = {"top_p": 0.9}

    result = _merge_options(base, override)

    # An unset base value (e.g. store=None from default_options) must not survive the merge.
    assert "store" not in result
    assert result["temperature"] == 0.5
    assert result["top_p"] == 0.9
|
||||
|
||||
|
||||
def test_merge_options_runtime_model_overrides_default_model() -> None:
|
||||
"""Test _merge_options lets a runtime model override a default model."""
|
||||
result = _merge_options({"model": "default-model"}, {"model": "runtime-model"})
|
||||
@@ -2658,3 +2682,449 @@ async def test_as_tool_raises_on_user_input_request(client: SupportsChatGetRespo
|
||||
assert len(exc_info.value.contents) == 1
|
||||
assert exc_info.value.contents[0].type == "oauth_consent_request"
|
||||
assert exc_info.value.contents[0].consent_link == "https://login.microsoftonline.com/consent"
|
||||
|
||||
|
||||
# region Per-service-call history persistence scenario matrix
|
||||
#
|
||||
# The driving field is ``require_per_service_call_history_persistence``. Every scenario runs a
|
||||
# single agent run that makes **two service calls** -- a function call followed by a final
|
||||
# completion -- so the *timing* of persistence is observable:
|
||||
#
|
||||
# * When the flag is ``True``, the per-service-call middleware persists the provider **after each
|
||||
# service call**. So the function-call turn is already saved by the time the second (final)
|
||||
# service call starts. This holds regardless of whether the chat client stores history
|
||||
# server-side (the bug in issue #5798 was that a storing client silently bypassed persistence).
|
||||
# * When the flag is ``False``, the provider persists **once, at the end of the run** -- nothing is
|
||||
# saved between the two service calls.
|
||||
#
|
||||
# ``_PscSpyChatClient.saves_before_call`` records ``provider.save_calls`` at the start of every service
|
||||
# call, so ``[0, 1]`` means "the function-call turn was persisted before the final call" and
|
||||
# ``[0, 0]`` means "no persistence happened mid-run". The client's ``store`` / ``STORES_BY_DEFAULT``
|
||||
# only selects *how* the middleware behaves -- never *whether* the provider persists.
|
||||
|
||||
_PSC_SERVICE_CONVERSATION_ID = "svc-conversation"

# Shared parametrization that runs a scenario once with stream=False ("sync") and once with stream=True.
_psc_stream_params = pytest.mark.parametrize("stream", [False, True], ids=["sync", "stream"])
|
||||
|
||||
|
||||
@tool(name="lookup_weather", approval_mode="never_require")
def _psc_lookup_weather(location: str) -> str:
    """Return a fixed sunny-weather string for ``location`` (scripted test tool)."""
    return f"Weather in {location}: sunny"
|
||||
|
||||
|
||||
def _psc_function_call_script() -> list[tuple[str, ...]]:
    """A fresh function-call-then-final-completion script (the client mutates it).

    Returns:
        A new two-entry list on every call: a ``("call", call_id, name, arguments)``
        turn followed by a ``("text", text)`` turn.
    """
    return [
        ("call", "call_1", "lookup_weather", '{"location": "Seattle"}'),
        ("text", "It is sunny in Seattle."),
    ]
|
||||
|
||||
|
||||
class _PscSpyHistoryProvider(HistoryProvider):
    """In-memory history provider that records load/save calls for assertions."""

    def __init__(self, source_id: str = "spy_history", **kwargs: Any) -> None:
        super().__init__(source_id, **kwargs)
        self._messages: list[Message] = []
        # Number of get_messages / save_messages invocations seen so far.
        self.get_calls: int = 0
        self.save_calls: int = 0
        # One entry per save_messages call, holding a copy of the messages passed in.
        self.saved_batches: list[list[Message]] = []

    async def get_messages(
        self, session_id: str | None, *, state: dict[str, Any] | None = None, **kwargs: Any
    ) -> list[Message]:
        """Count the load and return a copy of everything saved so far."""
        self.get_calls += 1
        return list(self._messages)

    async def save_messages(
        self,
        session_id: str | None,
        messages: Sequence[Message],
        *,
        state: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Count the save, record the batch, and append it to the in-memory history."""
        self.save_calls += 1
        self.saved_batches.append(list(messages))
        self._messages.extend(messages)

    @property
    def stored_messages(self) -> list[Message]:
        """A copy of all messages saved so far, in save order."""
        return list(self._messages)
|
||||
|
||||
|
||||
class _PscSpyChatClient(MockBaseChatClient):
    """Scripted chat client that replays a function-call turn followed by a final completion.

    Each service call records the messages and options it received, plus how many provider
    saves had already happened at that moment (``saves_before_call``). When the effective
    ``store`` value is truthy and echoing is enabled, responses carry a fixed
    ``conversation_id`` to imitate a server-managed conversation.
    """

    def __init__(
        self,
        *,
        provider: _PscSpyHistoryProvider,
        stores_by_default: bool = False,
        script: list[tuple[str, ...]] | None = None,
        echo_conversation_id: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.STORES_BY_DEFAULT = stores_by_default  # type: ignore[attr-defined]
        self._provider = provider
        # Work on a private copy so popping turns never touches the caller's list.
        self._script = [("text", "ok")] if script is None else list(script)
        self._echo_conversation_id = echo_conversation_id
        self.received_messages: list[list[Message]] = []
        self.received_options: list[dict[str, Any]] = []
        self.saves_before_call: list[int] = []

    def _effective_store(self, options: dict[str, Any]) -> bool:
        """Resolve ``store``: an explicit non-None option wins, otherwise the client default."""
        requested = options.get("store")
        return bool(self.STORES_BY_DEFAULT if requested is None else requested)

    def _next_contents(self) -> list[Content]:
        """Pop the next scripted turn (or a default ``"ok"`` text turn) and convert it to contents."""
        if self._script:
            turn = self._script.pop(0)
        else:
            turn = ("text", "ok")
        if turn[0] != "call":
            return [Content.from_text(turn[1])]
        _, call_id, name, args = turn
        return [Content.from_function_call(call_id=call_id, name=name, arguments=args)]

    def _inner_get_response(  # type: ignore[override]
        self,
        *,
        messages: MutableSequence[Message],
        stream: bool,
        options: dict[str, Any],
        **kwargs: Any,
    ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
        """Record the call, then answer with the next scripted turn (streaming or not)."""
        # Bookkeeping happens eagerly, before any awaitable or stream is consumed.
        self.received_messages.append(list(messages))
        self.received_options.append(dict(options))
        self.saves_before_call.append(self._provider.save_calls)

        conv_id = None
        if self._effective_store(options) and self._echo_conversation_id:
            conv_id = _PSC_SERVICE_CONVERSATION_ID
        contents = self._next_contents()

        if not stream:

            async def _get() -> ChatResponse:
                self.call_count += 1
                assistant_message = Message(role="assistant", contents=contents)
                return ChatResponse(messages=assistant_message, conversation_id=conv_id)

            return _get()

        async def _stream() -> AsyncIterable[ChatResponseUpdate]:
            self.call_count += 1
            yield ChatResponseUpdate(
                contents=contents,
                role="assistant",
                finish_reason="stop",
                conversation_id=conv_id,
            )

        def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse:
            final = ChatResponse.from_updates(updates, output_format_type=options.get("response_format"))
            if conv_id:
                final.conversation_id = conv_id
            return final

        return ResponseStream(_stream(), finalizer=_finalize)
|
||||
|
||||
|
||||
def _psc_build_agent(
    client: _PscSpyChatClient,
    provider: _PscSpyHistoryProvider,
    *,
    require_per_service_call_history_persistence: bool,
    default_options: dict[str, Any] | None = None,
) -> Agent:
    """Create an ``Agent`` wired to the spy client, the spy provider and the weather tool.

    ``default_options`` is forwarded to ``Agent`` only when it is not ``None``.
    """
    optional_kwargs: dict[str, Any] = {} if default_options is None else {"default_options": default_options}
    return Agent(
        client=client,
        tools=[_psc_lookup_weather],
        context_providers=[provider],
        require_per_service_call_history_persistence=require_per_service_call_history_persistence,
        **optional_kwargs,
    )
|
||||
|
||||
|
||||
async def _psc_run(agent: Agent, text: str, session: AgentSession, *, stream: bool) -> str:
    """Run the agent once and return the response text, joining update texts when streaming."""
    if not stream:
        response = await agent.run(text, session=session)
        return response.text
    # Updates without text contribute an empty string to the joined result.
    pieces: list[str] = [update.text or "" async for update in agent.run(text, session=session, stream=True)]
    return "".join(pieces)
|
||||
|
||||
|
||||
# driver=True (the contract under test): persistence happens per service call
|
||||
|
||||
|
||||
@_psc_stream_params
async def test_psc_flag_on_store_false_persists_after_each_service_call(stream: bool) -> None:
    """Mode A (flag on, non-storing service): the function-call turn is saved before the final call."""
    history = _PscSpyHistoryProvider()
    chat_client = _PscSpyChatClient(
        provider=history,
        stores_by_default=False,
        script=_psc_function_call_script(),
    )
    agent = _psc_build_agent(chat_client, history, require_per_service_call_history_persistence=True)
    session = agent.create_session()

    reply = await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)

    assert reply == "It is sunny in Seattle."
    # One service call for the function call, one for the final completion.
    assert chat_client.call_count == 2
    # Contract under test: a save had already happened when the second service call began.
    assert chat_client.saves_before_call == [0, 1]
    assert history.save_calls == 2
    # Local history is loaded in Mode A.
    assert history.get_calls >= 1
    # Nothing is stored service-side, so the session carries no conversation id.
    assert session.service_session_id is None
|
||||
|
||||
|
||||
@_psc_stream_params
async def test_psc_flag_on_stores_by_default_persists_after_each_service_call(
    stream: bool, caplog: pytest.LogCaptureFixture
) -> None:
    """Mode B (flag on, service stores by default): still persists per service call, but skips load (issue #5798)."""
    history = _PscSpyHistoryProvider()  # load_messages=True by default
    chat_client = _PscSpyChatClient(
        provider=history,
        stores_by_default=True,
        script=_psc_function_call_script(),
    )
    agent = _psc_build_agent(chat_client, history, require_per_service_call_history_persistence=True)
    session = agent.create_session()

    with caplog.at_level(logging.WARNING, logger="agent_framework"):
        reply = await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)

    assert reply == "It is sunny in Seattle."
    assert chat_client.call_count == 2
    # Per-service-call persistence must survive the service-stores case.
    assert chat_client.saves_before_call == [0, 1]
    assert history.save_calls == 2
    # Loading belongs to the service here, so the provider is never asked to load.
    assert history.get_calls == 0
    # The bypassed load (load_messages=True) is reported through a warning.
    assert any("load_messages" in record.message for record in caplog.records)
    # The service conversation id ends up on the session.
    assert session.service_session_id == _PSC_SERVICE_CONVERSATION_ID
|
||||
|
||||
|
||||
@_psc_stream_params
async def test_psc_flag_on_store_only_provider_no_load_no_warning(
    stream: bool, caplog: pytest.LogCaptureFixture
) -> None:
    """Mode B with load_messages=False: saves happen per call, nothing is loaded, nothing is warned."""
    history = _PscSpyHistoryProvider(load_messages=False)
    chat_client = _PscSpyChatClient(
        provider=history,
        stores_by_default=True,
        script=_psc_function_call_script(),
    )
    agent = _psc_build_agent(chat_client, history, require_per_service_call_history_persistence=True)
    session = agent.create_session()

    with caplog.at_level(logging.WARNING, logger="agent_framework"):
        await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)

    assert chat_client.saves_before_call == [0, 1]
    assert history.save_calls == 2
    assert history.get_calls == 0
    # A store-only provider has no load to bypass, so no load_messages warning is expected.
    assert not any("load_messages" in record.message for record in caplog.records)
|
||||
|
||||
|
||||
@_psc_stream_params
async def test_psc_flag_on_store_false_override_behaves_as_mode_a(stream: bool) -> None:
    """A storing client overridden with store=False (flag on) behaves like Mode A: local, per call."""
    history = _PscSpyHistoryProvider()
    chat_client = _PscSpyChatClient(
        provider=history,
        stores_by_default=True,
        script=_psc_function_call_script(),
    )
    agent = _psc_build_agent(
        chat_client,
        history,
        require_per_service_call_history_persistence=True,
        default_options={"store": False},
    )
    session = agent.create_session()

    await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)

    assert chat_client.saves_before_call == [0, 1]
    assert history.save_calls == 2
    assert history.get_calls >= 1
    # With store=False the history is handled locally, so no service conversation id appears.
    assert session.service_session_id is None
|
||||
|
||||
|
||||
@_psc_stream_params
async def test_psc_flag_on_store_none_treated_as_absent(stream: bool, caplog: pytest.LogCaptureFixture) -> None:
    """An explicit store=None counts as "unset": the storing client's default applies (Mode B)."""
    history = _PscSpyHistoryProvider()
    chat_client = _PscSpyChatClient(
        provider=history,
        stores_by_default=True,
        script=_psc_function_call_script(),
    )
    agent = _psc_build_agent(
        chat_client,
        history,
        require_per_service_call_history_persistence=True,
        default_options={"store": None},
    )
    session = agent.create_session()

    with caplog.at_level(logging.WARNING, logger="agent_framework"):
        await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)

    assert chat_client.saves_before_call == [0, 1]
    assert history.save_calls == 2
    assert history.get_calls == 0
    assert session.service_session_id == _PSC_SERVICE_CONVERSATION_ID
    assert any("load_messages" in record.message for record in caplog.records)
    # The None value must never reach the client; the service picks its own default.
    assert not any("store" in seen_options for seen_options in chat_client.received_options)
|
||||
|
||||
|
||||
@_psc_stream_params
async def test_psc_flag_on_respects_store_outputs_flag(stream: bool) -> None:
    """With the flag on, store_inputs/store_outputs on the provider are still honoured per service call."""
    history = _PscSpyHistoryProvider(store_inputs=True, store_outputs=False)
    chat_client = _PscSpyChatClient(
        provider=history,
        stores_by_default=True,
        script=_psc_function_call_script(),
    )
    agent = _psc_build_agent(chat_client, history, require_per_service_call_history_persistence=True)
    session = agent.create_session()

    await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)

    assert history.save_calls == 2
    # store_outputs=False: something was stored, but none of it has the assistant role.
    persisted = history.stored_messages
    assert persisted
    assert all(message.role != "assistant" for message in persisted)
|
||||
|
||||
|
||||
# driver=False (control): persistence happens once, at the end of the run
|
||||
|
||||
|
||||
@_psc_stream_params
async def test_psc_flag_off_store_false_persists_once_at_end(stream: bool) -> None:
    """Control case (flag off, non-storing client): no mid-run save, exactly one save at the end."""
    history = _PscSpyHistoryProvider()
    chat_client = _PscSpyChatClient(
        provider=history,
        stores_by_default=False,
        script=_psc_function_call_script(),
    )
    agent = _psc_build_agent(chat_client, history, require_per_service_call_history_persistence=False)
    session = agent.create_session()

    reply = await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)

    assert reply == "It is sunny in Seattle."
    assert chat_client.call_count == 2
    # Neither service call saw a prior save: nothing was persisted between them.
    assert chat_client.saves_before_call == [0, 0]
    assert history.save_calls == 1
|
||||
|
||||
|
||||
@_psc_stream_params
async def test_psc_flag_off_stores_by_default_persists_once_at_end(stream: bool) -> None:
    """Control case (flag off, storing client): one save per run and the service id reaches the session."""
    history = _PscSpyHistoryProvider()
    chat_client = _PscSpyChatClient(
        provider=history,
        stores_by_default=True,
        script=_psc_function_call_script(),
    )
    agent = _psc_build_agent(chat_client, history, require_per_service_call_history_persistence=False)
    session = agent.create_session()

    await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)

    assert chat_client.saves_before_call == [0, 0]
    assert history.save_calls == 1
    assert session.service_session_id == _PSC_SERVICE_CONVERSATION_ID
|
||||
|
||||
|
||||
@_psc_stream_params
async def test_psc_flag_on_storing_with_existing_conversation_id_does_not_raise(stream: bool) -> None:
    """Flag on + storing client + a caller-supplied conversation_id: the run resumes without raising.

    This is the "allow" side of the guard. The non-storing path raises when given an existing
    service-managed conversation id; with a storing client the run must go through and the
    service conversation id must land on the session.
    """
    history = _PscSpyHistoryProvider()
    chat_client = _PscSpyChatClient(
        provider=history,
        stores_by_default=True,
        script=_psc_function_call_script(),
    )
    agent = _psc_build_agent(chat_client, history, require_per_service_call_history_persistence=True)
    session = agent.create_session()

    prompt = "What's the weather in Seattle?"
    run_options = {"conversation_id": "existing_conversation"}

    if not stream:
        response = await agent.run(prompt, session=session, options=run_options)
        reply = response.text
    else:
        pieces: list[str] = []
        async for update in agent.run(prompt, session=session, stream=True, options=run_options):
            pieces.append(update.text or "")
        reply = "".join(pieces)

    assert reply == "It is sunny in Seattle."
    # Saves still happen once per service call and the real service id reaches the session.
    assert history.save_calls == 2
    assert history.get_calls == 0
    assert session.service_session_id == _PSC_SERVICE_CONVERSATION_ID
|
||||
|
||||
|
||||
@_psc_stream_params
async def test_psc_flag_on_storing_two_runs_same_session(stream: bool) -> None:
    """Two runs on one session in storing mode: saves continue, the id stays put, nothing is loaded.

    By the second run the session already holds a service_session_id; that branch must keep
    skipping provider loads while still persisting after each service call.
    """
    history = _PscSpyHistoryProvider()
    chat_client = _PscSpyChatClient(
        provider=history,
        stores_by_default=True,
        script=_psc_function_call_script(),
    )
    agent = _psc_build_agent(chat_client, history, require_per_service_call_history_persistence=True)
    session = agent.create_session()

    await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)

    assert history.save_calls == 2
    assert history.get_calls == 0
    service_id_after_first_run = session.service_session_id
    assert service_id_after_first_run == _PSC_SERVICE_CONVERSATION_ID

    # Re-arm the scripted client so the same session can be driven a second time.
    chat_client._script = _psc_function_call_script()
    chat_client.call_count = 0
    chat_client.saves_before_call = []

    await _psc_run(agent, "And in Portland?", session, stream=stream)

    # Two further saves, again one per service call (the provider counter was not reset).
    assert chat_client.saves_before_call == [2, 3]
    assert history.save_calls == 4
    # Still no provider load, and the service conversation id did not change.
    assert history.get_calls == 0
    assert session.service_session_id == service_id_after_first_run
|
||||
|
||||
|
||||
@_psc_stream_params
async def test_psc_flag_on_storing_without_conversation_id_warns_every_call(
    stream: bool, caplog: pytest.LogCaptureFixture
) -> None:
    """Storing mode with a client that never echoes a conversation_id: one warning per service call.

    With no conversation id coming back there is nothing for the next run to resume from, so
    cross-turn history could vanish silently. The warning is intentionally not deduplicated,
    which keeps this uncommon failure mode visible.
    """
    history = _PscSpyHistoryProvider()
    chat_client = _PscSpyChatClient(
        provider=history,
        stores_by_default=True,
        script=_psc_function_call_script(),
        echo_conversation_id=False,
    )
    agent = _psc_build_agent(chat_client, history, require_per_service_call_history_persistence=True)
    session = agent.create_session()

    with caplog.at_level(logging.WARNING, logger="agent_framework"):
        await _psc_run(agent, "What's the weather in Seattle?", session, stream=stream)

    # Saves still occur, yet the session has no service id to resume from.
    assert history.save_calls == 2
    assert session.service_session_id is None
    # Two service calls produce two warnings: one each, no dedup.
    warning_records = [record for record in caplog.records if "returned no conversation_id" in record.message]
    assert len(warning_records) == 2
|
||||
|
||||
@@ -30,6 +30,7 @@ from agent_framework._mcp import (
|
||||
MCPTool,
|
||||
_build_prefixed_mcp_name,
|
||||
_get_input_model_from_mcp_prompt,
|
||||
_normalize_additional_tool_argument_names,
|
||||
_normalize_mcp_name,
|
||||
_should_propagate_cancelled_error,
|
||||
logger,
|
||||
@@ -6057,3 +6058,205 @@ async def test_max_wait_interrupts_long_poll_sleep(monkeypatch: pytest.MonkeyPat
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region additional_tool_argument_names / allowlist filtering
|
||||
|
||||
|
||||
def test_normalize_additional_tool_argument_names_none() -> None:
    # ``None`` normalizes to "no extras at all": empty global set, empty per-tool mapping.
    shared_names, names_by_tool = _normalize_additional_tool_argument_names(None)
    assert shared_names == set()
    assert names_by_tool == {}
|
||||
|
||||
|
||||
def test_normalize_additional_tool_argument_names_sequence() -> None:
    # A flat sequence becomes the global set (the repeated "a" collapses); no per-tool entries.
    raw_names = ["a", "b", "a"]
    shared_names, names_by_tool = _normalize_additional_tool_argument_names(raw_names)
    assert shared_names == {"a", "b"}
    assert names_by_tool == {}
|
||||
|
||||
|
||||
def test_normalize_additional_tool_argument_names_single_string() -> None:
    # A lone string is one argument name; it must not be exploded into its characters.
    shared_names, names_by_tool = _normalize_additional_tool_argument_names("conversation_id")
    assert shared_names == {"conversation_id"}
    assert names_by_tool == {}
|
||||
|
||||
|
||||
def test_normalize_additional_tool_argument_names_mapping_with_global_key() -> None:
    # In a mapping, the "*" key feeds the global set; every other key is a per-tool entry.
    spec = {
        "*": ["g1"],
        "tool_a": ["a1", "a2"],
        "tool_b": ["b1"],
    }
    shared_names, names_by_tool = _normalize_additional_tool_argument_names(spec)
    assert shared_names == {"g1"}
    assert names_by_tool == {"tool_a": {"a1", "a2"}, "tool_b": {"b1"}}
|
||||
|
||||
|
||||
def test_normalize_additional_tool_argument_names_mapping_with_string_values() -> None:
    # String values inside a mapping are single names too, never iterables of characters.
    spec = {
        "*": "conversation_id",
        "tool_a": "custom",
    }
    shared_names, names_by_tool = _normalize_additional_tool_argument_names(spec)
    assert shared_names == {"conversation_id"}
    assert names_by_tool == {"tool_a": {"custom"}}
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_strips_undeclared_arguments() -> None:
    # With no extras configured, only the parameter the tool declares survives filtering.
    mcp_server = MCPTool(name="test_server")
    mcp_server._tool_param_names_by_name = {"test_tool": {"param"}}

    incoming = {"param": "value", "conversation_id": "c", "thread": object(), "unexpected": 1}
    kept, meta = mcp_server._prepare_call_kwargs("test_tool", incoming)

    assert kept == {"param": "value"}
    assert meta is None
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_global_extras_allowed() -> None:
    # A globally allowed extra passes through next to the declared parameter; "options" does not.
    mcp_server = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"])
    mcp_server._tool_param_names_by_name = {"test_tool": {"param"}}

    incoming = {"param": "value", "conversation_id": "c", "options": {}}
    kept, _ = mcp_server._prepare_call_kwargs("test_tool", incoming)

    assert kept == {"param": "value", "conversation_id": "c"}
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_per_tool_and_global_extras() -> None:
    # Combines a global extra ("*") with an extra scoped to a single tool.
    extras_spec = {"*": ["conversation_id"], "test_tool": ["custom"]}
    mcp_server = MCPTool(name="test_server", additional_tool_argument_names=extras_spec)
    mcp_server._tool_param_names_by_name = {"test_tool": {"param"}, "other_tool": {"x"}}

    kept_for_test_tool, _ = mcp_server._prepare_call_kwargs(
        "test_tool",
        {"param": "v", "conversation_id": "c", "custom": "y", "thread": object()},
    )
    assert kept_for_test_tool == {"param": "v", "conversation_id": "c", "custom": "y"}

    # "custom" is scoped to test_tool and must not reach other_tool; the global extra still does.
    kept_for_other_tool, _ = mcp_server._prepare_call_kwargs(
        "other_tool",
        {"x": 1, "conversation_id": "c", "custom": "y"},
    )
    assert kept_for_other_tool == {"x": 1, "conversation_id": "c"}
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_denylist_guards_server_declared_names() -> None:
    # Even when a server's schema *declares* a framework-named parameter, the denylist drops
    # it so internal objects never leak. Names opted in explicitly through extras always win.
    mcp_server = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"])
    mcp_server._tool_param_names_by_name = {"test_tool": {"param", "thread"}}

    incoming = {"param": "v", "thread": object(), "conversation_id": "c"}
    kept, _ = mcp_server._prepare_call_kwargs("test_tool", incoming)

    # "thread": declared but denylisted -> removed. "conversation_id": opted in -> retained.
    assert kept == {"param": "v", "conversation_id": "c"}
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_extras_override_denylist() -> None:
    # An explicit extra beats the denylist safety net: "thread" is a denylisted framework
    # name, yet listing it in the extras opts it back in.
    mcp_server = MCPTool(name="test_server", additional_tool_argument_names=["thread"])
    mcp_server._tool_param_names_by_name = {"test_tool": {"param"}}

    thread_marker = object()
    incoming = {"param": "v", "thread": thread_marker, "conversation_id": "c"}
    kept, _ = mcp_server._prepare_call_kwargs("test_tool", incoming)

    # "thread" stays because it was opted in; "conversation_id" is denylisted, undeclared
    # and not opted in, so it is removed.
    assert kept == {"param": "v", "thread": thread_marker}
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_zero_arg_tool_passes_no_arguments() -> None:
    # A tool that declares no parameters (and no extras configured) receives no arguments.
    mcp_server = MCPTool(name="test_server")
    mcp_server._tool_param_names_by_name = {"test_tool": set()}

    incoming = {"conversation_id": "c", "thread": object(), "stray": 1}
    kept, _ = mcp_server._prepare_call_kwargs("test_tool", incoming)

    assert kept == {}
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_unknown_tool_passes_only_global_extras() -> None:
    mcp_server = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"])
    # _tool_param_names_by_name is deliberately left without an entry for "unknown_tool".

    incoming = {"conversation_id": "c", "other": 1}
    kept, _ = mcp_server._prepare_call_kwargs("unknown_tool", incoming)

    # Only the globally allowed extra remains.
    assert kept == {"conversation_id": "c"}
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_extracts_meta() -> None:
    # "_meta" is returned separately as the second value and is absent from the filtered arguments.
    mcp_server = MCPTool(name="test_server")
    mcp_server._tool_param_names_by_name = {"test_tool": {"param"}}

    incoming = {"param": "v", "_meta": {"trace": "abc"}}
    kept, meta = mcp_server._prepare_call_kwargs("test_tool", incoming)

    assert kept == {"param": "v"}
    assert meta is not None
    assert meta.get("trace") == "abc"
|
||||
|
||||
|
||||
async def test_call_tool_forwards_only_declared_arguments() -> None:
    """End-to-end: framework runtime kwargs are removed before the call reaches the server session."""

    class TestServer(MCPTool):
        async def connect(self):
            # Swap the real session for a mock exposing one tool with a single "param" argument.
            declared_tool = types.Tool(
                name="test_tool",
                description="Test tool",
                inputSchema={
                    "type": "object",
                    "properties": {"param": {"type": "string"}},
                    "required": ["param"],
                },
            )
            canned_result = types.CallToolResult(content=[types.TextContent(type="text", text="ok")])

            self.session = Mock(spec=ClientSession)
            self.session.list_tools = AsyncMock(return_value=types.ListToolsResult(tools=[declared_tool]))
            self.session.call_tool = AsyncMock(return_value=canned_result)

        def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
            return None

    server = TestServer(name="test_server", additional_tool_argument_names=["conversation_id"])
    async with server:
        await server.load_tools()
        mocked_session = server.session
        await server.call_tool(
            "test_tool",
            param="value",
            conversation_id="c",
            thread=object(),
            response_format=object(),
        )

        mocked_session.call_tool.assert_called_once()
        _, forwarded_kwargs = mocked_session.call_tool.call_args
        # Only the declared parameter and the opted-in extra are forwarded.
        assert forwarded_kwargs["arguments"] == {"param": "value", "conversation_id": "c"}
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -8,29 +8,34 @@ This module provides ``Mem0ContextProvider``, built on the new
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Awaitable
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from typing import TYPE_CHECKING, Any, ClassVar
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, TypedDict
|
||||
|
||||
from agent_framework import Message
|
||||
from agent_framework._sessions import AgentSession, ContextProvider, SessionContext
|
||||
from mem0 import AsyncMemory, AsyncMemoryClient
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import NotRequired, Self, TypedDict # pragma: no cover
|
||||
from typing import Self # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import NotRequired, Self, TypedDict # pragma: no cover
|
||||
from typing_extensions import Self # pragma: no cover
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agent_framework._agents import SupportsAgentRun
|
||||
|
||||
|
||||
class _MemorySearchResponse_v1_1(TypedDict):
|
||||
results: list[dict[str, Any]]
|
||||
relations: NotRequired[list[dict[str, Any]]]
|
||||
logger = logging.getLogger(__name__)
|
||||
MemoryRecord: TypeAlias = dict[str, object]
|
||||
|
||||
|
||||
_MemorySearchResponse_v2 = list[dict[str, Any]]
|
||||
class SearchResults(TypedDict):
|
||||
results: list[MemoryRecord]
|
||||
|
||||
|
||||
SearchResponse: TypeAlias = list[MemoryRecord] | SearchResults
|
||||
|
||||
|
||||
class Mem0ContextProvider(ContextProvider):
|
||||
@@ -106,28 +111,85 @@ class Mem0ContextProvider(ContextProvider):
|
||||
if not input_text.strip():
|
||||
return
|
||||
|
||||
filters = self._build_filters()
|
||||
# Query entity partitions independently to bypass strict logical AND limitations
|
||||
# Mem0 OSS and Platform SDKs expose inconsistent search typings.
|
||||
search_tasks: list[Awaitable[Any]] = []
|
||||
|
||||
# AsyncMemory (OSS) expects user_id/agent_id/run_id as direct kwargs
|
||||
# AsyncMemoryClient (Platform) expects them in a filters dict
|
||||
search_kwargs: dict[str, Any] = {"query": input_text}
|
||||
if isinstance(self.mem0_client, AsyncMemory):
|
||||
search_kwargs.update(filters)
|
||||
else:
|
||||
search_kwargs["filters"] = filters
|
||||
# 1. Query User partition independently
|
||||
if self.user_id:
|
||||
user_kwargs = self._build_search_kwargs(input_text, "user_id", self.user_id)
|
||||
search_tasks.append(self.mem0_client.search(**user_kwargs)) # type: ignore[reportUnknownMemberType, reportUnknownArgumentType]
|
||||
|
||||
search_response: _MemorySearchResponse_v1_1 | _MemorySearchResponse_v2 = await self.mem0_client.search( # type: ignore[misc]
|
||||
**search_kwargs,
|
||||
)
|
||||
# 2. Query Agent partition independently
|
||||
if self.agent_id:
|
||||
agent_kwargs = self._build_search_kwargs(input_text, "agent_id", self.agent_id)
|
||||
search_tasks.append(self.mem0_client.search(**agent_kwargs)) # type: ignore[reportUnknownMemberType, reportUnknownArgumentType]
|
||||
|
||||
if isinstance(search_response, list):
|
||||
memories = search_response
|
||||
elif isinstance(search_response, dict) and "results" in search_response:
|
||||
memories = search_response["results"]
|
||||
else:
|
||||
memories = [search_response]
|
||||
# Fall back to an app-scoped search when only application_id is configured
|
||||
if not search_tasks and self.application_id:
|
||||
app_kwargs: dict[str, Any] = {"query": input_text}
|
||||
if isinstance(self.mem0_client, AsyncMemory):
|
||||
app_kwargs["app_id"] = self.application_id
|
||||
else:
|
||||
app_kwargs["filters"] = {"app_id": self.application_id}
|
||||
search_tasks.append(self.mem0_client.search(**app_kwargs)) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
|
||||
if not search_tasks:
|
||||
return
|
||||
|
||||
line_separated_memories = "\n".join(memory.get("memory", "") for memory in memories)
|
||||
results: list[SearchResponse | BaseException] = await asyncio.gather(*search_tasks, return_exceptions=True)
|
||||
|
||||
# Merge and deduplicate results
|
||||
memories: list[MemoryRecord] = []
|
||||
seen_memory_ids: set[str] = set()
|
||||
failed_tasks_count: int = 0
|
||||
|
||||
for search_response in results:
|
||||
if isinstance(search_response, asyncio.CancelledError):
|
||||
raise search_response
|
||||
|
||||
if isinstance(search_response, BaseException):
|
||||
failed_tasks_count += 1
|
||||
logger.error(
|
||||
"Mem0 partition search task failed: %s",
|
||||
search_response,
|
||||
exc_info=(type(search_response), search_response, search_response.__traceback__),
|
||||
)
|
||||
continue
|
||||
|
||||
current_memories: list[MemoryRecord] = []
|
||||
if isinstance(search_response, list):
|
||||
current_memories = [mem for mem in search_response if isinstance(mem, dict)]
|
||||
elif isinstance(search_response, dict):
|
||||
results_field = search_response.get("results")
|
||||
if isinstance(results_field, list):
|
||||
current_memories = [
|
||||
item
|
||||
for item in results_field
|
||||
if isinstance(item, dict) # pyright: ignore[reportUnknownVariableType]
|
||||
]
|
||||
else:
|
||||
logger.warning(
|
||||
"Unexpected Mem0 search response format: %s",
|
||||
type(results_field).__name__,
|
||||
)
|
||||
|
||||
for mem in current_memories:
|
||||
mem_id = mem.get("id")
|
||||
if mem_id is not None and not isinstance(mem_id, str):
|
||||
mem_id = str(mem_id)
|
||||
|
||||
if mem_id is not None and mem_id in seen_memory_ids:
|
||||
continue
|
||||
|
||||
if mem_id is not None:
|
||||
seen_memory_ids.add(mem_id)
|
||||
|
||||
memories.append(mem)
|
||||
|
||||
if failed_tasks_count == len(search_tasks):
|
||||
logger.error("All Mem0 retrieval tasks failed. Context provider is unable to verify memory state.")
|
||||
|
||||
line_separated_memories = "\n".join(str(memory.get("memory", "")) for memory in memories)
|
||||
if line_separated_memories:
|
||||
context.extend_messages(
|
||||
self.source_id,
|
||||
@@ -159,12 +221,21 @@ class Mem0ContextProvider(ContextProvider):
|
||||
]
|
||||
|
||||
if messages:
|
||||
await self.mem0_client.add( # type: ignore[misc]
|
||||
messages=messages,
|
||||
user_id=self.user_id,
|
||||
agent_id=self.agent_id,
|
||||
metadata={"application_id": self.application_id},
|
||||
)
|
||||
add_kwargs: dict[str, Any] = {
|
||||
"messages": messages,
|
||||
"user_id": self.user_id,
|
||||
"agent_id": self.agent_id,
|
||||
}
|
||||
|
||||
# Inject the application scope using the matching signature format for each SDK variant
|
||||
if isinstance(self.mem0_client, AsyncMemory):
|
||||
if self.application_id:
|
||||
add_kwargs["app_id"] = self.application_id
|
||||
else:
|
||||
if self.application_id:
|
||||
add_kwargs["filters"] = {"app_id": self.application_id}
|
||||
|
||||
await self.mem0_client.add(**add_kwargs) # type: ignore[misc, call-arg]
|
||||
|
||||
# -- Internal methods ------------------------------------------------------
|
||||
|
||||
@@ -173,15 +244,21 @@ class Mem0ContextProvider(ContextProvider):
|
||||
if not self.agent_id and not self.user_id and not self.application_id:
|
||||
raise ValueError("At least one of the filters: agent_id, user_id, or application_id is required.")
|
||||
|
||||
def _build_filters(self) -> dict[str, Any]:
|
||||
"""Build search filters from initialization parameters."""
|
||||
filters: dict[str, Any] = {}
|
||||
if self.user_id:
|
||||
filters["user_id"] = self.user_id
|
||||
if self.agent_id:
|
||||
filters["agent_id"] = self.agent_id
|
||||
if self.application_id:
|
||||
filters["app_id"] = self.application_id
|
||||
def _build_search_kwargs(self, input_text: str, entity_key: str, entity_value: str) -> dict[str, Any]:
|
||||
"""Build search keyword arguments formatted for OSS vs Platform clients."""
|
||||
filters: dict[str, Any] = {"query": input_text}
|
||||
|
||||
if isinstance(self.mem0_client, AsyncMemory):
|
||||
# AsyncMemory (OSS) expects direct kwargs
|
||||
filters[entity_key] = entity_value
|
||||
if self.application_id:
|
||||
filters["app_id"] = self.application_id
|
||||
else:
|
||||
# AsyncMemoryClient (Platform) expects a filters dict
|
||||
filters["filters"] = {entity_key: entity_value}
|
||||
if self.application_id:
|
||||
filters["filters"]["app_id"] = self.application_id
|
||||
|
||||
return filters
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from agent_framework import AgentResponse, Message
|
||||
@@ -193,39 +193,59 @@ class TestBeforeRun:
|
||||
assert call_kwargs["user_id"] == "u1"
|
||||
assert "filters" not in call_kwargs
|
||||
|
||||
async def test_oss_client_all_scoping_params(self, mock_oss_mem0_client: AsyncMock) -> None:
|
||||
"""OSS client with all scoping parameters passes them as direct kwargs."""
|
||||
@pytest.mark.asyncio
|
||||
async def test_oss_client_all_scoping_params_except_app_id(self, mock_oss_mem0_client: AsyncMock) -> None:
|
||||
"""OSS client with all scoping parameters passes them as isolated concurrent kwargs."""
|
||||
mock_oss_mem0_client.search.return_value = []
|
||||
|
||||
provider = Mem0ContextProvider(
|
||||
source_id="mem0", mem0_client=mock_oss_mem0_client, user_id="u1", agent_id="a1", application_id="app1"
|
||||
source_id="mem0",
|
||||
mem0_client=mock_oss_mem0_client,
|
||||
user_id="u1",
|
||||
agent_id="a1"
|
||||
)
|
||||
session = AgentSession(session_id="test-session")
|
||||
ctx = SessionContext(input_messages=[Message(role="user", contents=["Hello"])], session_id="s1")
|
||||
|
||||
mock_context = MagicMock(spec=SessionContext)
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.text = "hello"
|
||||
mock_context.input_messages = [mock_msg]
|
||||
mock_context.response = None
|
||||
|
||||
await provider.before_run(
|
||||
agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {})
|
||||
) # type: ignore[arg-type]
|
||||
agent=MagicMock(), session=MagicMock(spec=AgentSession), context=mock_context, state={}
|
||||
)
|
||||
|
||||
call_kwargs = mock_oss_mem0_client.search.call_args.kwargs
|
||||
assert call_kwargs["user_id"] == "u1"
|
||||
assert call_kwargs["agent_id"] == "a1"
|
||||
assert "filters" not in call_kwargs
|
||||
# Re-aligned assertion: We expect 2 separate concurrent calls instead of 1 combined call
|
||||
assert mock_oss_mem0_client.search.call_count == 2
|
||||
mock_oss_mem0_client.search.assert_any_call(query="hello", user_id="u1")
|
||||
mock_oss_mem0_client.search.assert_any_call(query="hello", agent_id="a1")
|
||||
|
||||
async def test_platform_client_passes_filters_dict(self, mock_mem0_client: AsyncMock) -> None:
|
||||
"""Platform AsyncMemoryClient should receive scoping params in a filters dict."""
|
||||
@pytest.mark.asyncio
|
||||
async def test_platform_client_passes_filters_dict_except_app_id(self, mock_mem0_client: AsyncMock) -> None:
|
||||
"""Platform client passes scoping parameters concurrently inside the nested filters dictionary."""
|
||||
mock_mem0_client.search.return_value = []
|
||||
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
|
||||
session = AgentSession(session_id="test-session")
|
||||
ctx = SessionContext(input_messages=[Message(role="user", contents=["Hello"])], session_id="s1")
|
||||
|
||||
provider = Mem0ContextProvider(
|
||||
source_id="mem0",
|
||||
mem0_client=mock_mem0_client,
|
||||
user_id="u1",
|
||||
agent_id="a1",
|
||||
)
|
||||
|
||||
mock_context = MagicMock(spec=SessionContext)
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.text = "hello"
|
||||
mock_context.input_messages = [mock_msg]
|
||||
mock_context.response = None
|
||||
|
||||
await provider.before_run(
|
||||
agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {})
|
||||
) # type: ignore[arg-type]
|
||||
agent=MagicMock(), session=MagicMock(spec=AgentSession), context=mock_context, state={}
|
||||
)
|
||||
|
||||
call_kwargs = mock_mem0_client.search.call_args.kwargs
|
||||
assert call_kwargs["query"] == "Hello"
|
||||
assert "filters" in call_kwargs
|
||||
assert call_kwargs["filters"]["user_id"] == "u1"
|
||||
# Re-aligned assertion: Platform client isolates filters per call to bypass AND limitations
|
||||
assert mock_mem0_client.search.call_count == 2
|
||||
mock_mem0_client.search.assert_any_call(query="hello", filters={"user_id": "u1"})
|
||||
mock_mem0_client.search.assert_any_call(query="hello", filters={"agent_id": "a1"})
|
||||
|
||||
|
||||
# -- after_run tests -----------------------------------------------------------
|
||||
@@ -318,8 +338,8 @@ class TestAfterRun:
|
||||
with pytest.raises(ValueError, match="At least one of the filters"):
|
||||
await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type]
|
||||
|
||||
async def test_stores_with_application_id_metadata(self, mock_mem0_client: AsyncMock) -> None:
|
||||
"""application_id is passed in metadata."""
|
||||
async def test_stores_with_application_id_filters(self, mock_mem0_client: AsyncMock) -> None:
|
||||
"""application_id is passed in filters."""
|
||||
provider = Mem0ContextProvider(
|
||||
source_id="mem0", mem0_client=mock_mem0_client, user_id="u1", application_id="app1"
|
||||
)
|
||||
@@ -331,7 +351,7 @@ class TestAfterRun:
|
||||
agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {})
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
assert mock_mem0_client.add.call_args.kwargs["metadata"] == {"application_id": "app1"}
|
||||
assert mock_mem0_client.add.call_args.kwargs["filters"] == {"app_id": "app1"}
|
||||
|
||||
|
||||
# -- _validate_filters tests --------------------------------------------------
|
||||
@@ -358,15 +378,20 @@ class TestValidateFilters:
|
||||
provider._validate_filters()
|
||||
|
||||
|
||||
# -- _build_filters tests -----------------------------------------------------
|
||||
# -- _build_search_kwargs tests -----------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildFilters:
|
||||
"""Test _build_filters method."""
|
||||
class TestBuildSearchKwargs:
|
||||
"""Test _build_search_kwargs method."""
|
||||
|
||||
def test_user_id_only(self, mock_mem0_client: AsyncMock) -> None:
|
||||
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
|
||||
assert provider._build_filters() == {"user_id": "u1"}
|
||||
|
||||
# Pass the 3 required arguments
|
||||
result = provider._build_search_kwargs("test query", "user_id", "u1")
|
||||
|
||||
# AsyncMock triggers the Platform client nested 'filters' structure
|
||||
assert result == {"query": "test query", "filters": {"user_id": "u1"}}
|
||||
|
||||
def test_all_params(self, mock_mem0_client: AsyncMock) -> None:
|
||||
provider = Mem0ContextProvider(
|
||||
@@ -376,28 +401,66 @@ class TestBuildFilters:
|
||||
agent_id="a1",
|
||||
application_id="app1",
|
||||
)
|
||||
assert provider._build_filters() == {
|
||||
"user_id": "u1",
|
||||
"agent_id": "a1",
|
||||
"app_id": "app1",
|
||||
|
||||
# Test that app_id correctly merges with the isolated target entity
|
||||
result = provider._build_search_kwargs("test query", "agent_id", "a1")
|
||||
|
||||
assert result == {
|
||||
"query": "test query",
|
||||
"filters": {
|
||||
"agent_id": "a1",
|
||||
"app_id": "app1",
|
||||
},
|
||||
}
|
||||
|
||||
def test_excludes_none_values(self, mock_mem0_client: AsyncMock) -> None:
|
||||
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
|
||||
filters = provider._build_filters()
|
||||
assert "agent_id" not in filters
|
||||
assert "run_id" not in filters
|
||||
assert "app_id" not in filters
|
||||
|
||||
# application_id is None by default, it should not appear in the dictionary
|
||||
result = provider._build_search_kwargs("test query", "user_id", "u1")
|
||||
|
||||
assert "app_id" not in result.get("filters", {})
|
||||
|
||||
def test_no_run_id_in_search_filters(self, mock_mem0_client: AsyncMock) -> None:
|
||||
"""run_id is excluded from search filters so memories work across sessions."""
|
||||
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
|
||||
filters = provider._build_filters()
|
||||
assert "run_id" not in filters
|
||||
|
||||
result = provider._build_search_kwargs("test query", "user_id", "u1")
|
||||
|
||||
assert "run_id" not in result.get("filters", {})
|
||||
assert "run_id" not in result
|
||||
|
||||
def test_empty_when_no_params(self, mock_mem0_client: AsyncMock) -> None:
|
||||
# Validates base query payload generation
|
||||
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client)
|
||||
assert provider._build_filters() == {}
|
||||
|
||||
result = provider._build_search_kwargs("test query", "custom_key", "custom_val")
|
||||
|
||||
assert result == {"query": "test query", "filters": {"custom_key": "custom_val"}}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_before_run_application_only_fallback(self, mock_mem0_client: AsyncMock) -> None:
|
||||
|
||||
provider = Mem0ContextProvider(
|
||||
source_id="mem0", mem0_client=mock_mem0_client, application_id="app_fallback_test"
|
||||
)
|
||||
|
||||
# Mock a valid message list and session container setup
|
||||
mock_context = MagicMock(spec=SessionContext)
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.text = "Retrieve systemic fallback memory traces"
|
||||
mock_context.input_messages = [mock_msg]
|
||||
mock_context.response = None
|
||||
|
||||
mock_mem0_client.search = AsyncMock(return_value=[{"id": "m1", "memory": "System configuration template"}])
|
||||
|
||||
await provider.before_run(
|
||||
agent=MagicMock(), session=MagicMock(spec=AgentSession), context=mock_context, state={}
|
||||
)
|
||||
|
||||
# Verify that an application-scoped search task executed successfully
|
||||
assert mock_mem0_client.search.call_count == 1
|
||||
mock_context.extend_messages.assert_called_once()
|
||||
|
||||
|
||||
# -- Context manager tests -----------------------------------------------------
|
||||
|
||||
@@ -1997,7 +1997,11 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
metadata: dict[str, Any] = response.metadata or {}
|
||||
contents: list[Content] = []
|
||||
local_shell_tool_name = self._get_local_shell_tool_name(options.get("tools"))
|
||||
for item in response.output: # type: ignore[reportUnknownMemberType]
|
||||
try:
|
||||
response_outputs = response.output # type: ignore[reportUnknownMemberType]
|
||||
except AttributeError:
|
||||
response_outputs = []
|
||||
for item in response_outputs: # type: ignore[reportUnknownVariableType]
|
||||
match item.type:
|
||||
# types:
|
||||
# ParsedResponseOutputMessage[Unknown] |
|
||||
|
||||
@@ -788,13 +788,13 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
|
||||
def _get_metadata_from_chat_response(self, response: ChatCompletion) -> dict[str, Any]:
|
||||
"""Get metadata from a chat response."""
|
||||
return {
|
||||
"system_fingerprint": response.system_fingerprint,
|
||||
"system_fingerprint": getattr(response, "system_fingerprint", None),
|
||||
}
|
||||
|
||||
def _get_metadata_from_streaming_chat_response(self, response: ChatCompletionChunk) -> dict[str, Any]:
|
||||
"""Get metadata from a streaming chat response."""
|
||||
return {
|
||||
"system_fingerprint": response.system_fingerprint,
|
||||
"system_fingerprint": getattr(response, "system_fingerprint", None),
|
||||
}
|
||||
|
||||
def _get_metadata_from_chat_choice(self, choice: Choice | ChunkChoice) -> dict[str, Any]:
|
||||
|
||||
Reference in New Issue
Block a user