mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Filter MCP tool kwargs to declared params via allowlist (#6399)
* Filter MCP tool kwargs to declared params via allowlist Previously MCPTool combined framework runtime kwargs (from FunctionInvocationContext.kwargs) with the LLM-supplied arguments and stripped only a hardcoded denylist of known framework keys before forwarding to the MCP server. Any new framework-injected kwarg leaked to the server unless the denylist was updated. Switch to an allowlist built from each tool's declared parameters (inputSchema.properties). Only declared params are forwarded; everything else is stripped. Add an `additional_tool_argument_names` constructor argument so users can opt extra names back in, globally (Sequence[str]) and/or per remote tool name (Mapping with reserved "*" global key). The existing denylist is kept as a safety net for framework-named params a server declares in its schema; explicitly opted-in extras always win. The reserved _meta handling is unchanged. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address MCP allowlist review comments and fix reload arg loss - Fix pyright reportUnknownArgumentType in _load_tools (cast schema properties). - Register declared param names before the existing-tool skip guard so that tool-list reloads preserve the allowlist for already-loaded tools (previously unchanged tools silently dropped all declared args after a background reload). - Handle bare-string values in an additional_tool_argument_names mapping instead of iterating their characters. - Clarify the framework denylist comment: explicit extras override the denylist. - Make the extras-override-denylist test unambiguous (opt in a denylisted name). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
d222079df9
commit
cfb033e5d4
@@ -80,6 +80,8 @@ agent_framework/
|
||||
|
||||
- **`MCPTool`** - Base wrapper that owns the MCP `ClientSession` and exposes the remote server's tools as `FunctionTool`s.
|
||||
- **`MCPStdioTool`** / **`MCPStreamableHTTPTool`** / **`MCPWebsocketTool`** - Transport-specific subclasses.
|
||||
- **Argument allowlist (`_prepare_call_kwargs`)** - Before each `tools/call`, kwargs are filtered to an **allowlist** built from the tool's declared parameters (`inputSchema.properties`) plus any user-configured extras. Framework runtime kwargs injected through the function-invocation pipeline (e.g. `thread`, `conversation_id`, `chat_options`, `options`, `response_format`) are stripped by default rather than forwarded. A tool that declares no usable `properties` (including schemas with `additionalProperties: true`) forwards only the configured extras. The `_MCP_FRAMEWORK_DENYLIST` is a safety net for framework-named params a server *declares* in its schema (those are dropped); names explicitly opted in via `additional_tool_argument_names` always win. The reserved `_meta` key is extracted as MCP request metadata, never forwarded as an argument.
|
||||
- **`additional_tool_argument_names`** (constructor arg on all `MCPTool` subclasses) - Opt extra argument names back into the allowlist. Accepts a `Sequence[str]` (applied to every tool) or a `Mapping[str, Sequence[str]]` keyed by **remote tool name**, where the reserved key `"*"` denotes global extras. It is configured only in user code at construction; there is **no per-call/runtime override**, so a model-issued tool call cannot change which names pass through. To use a server that accepts `additionalProperties: true`, list the extra names here and then either (1) manually extend that tool's `inputSchema` (via the `.functions` list after connecting) so the model is prompted to supply them, or (2) supply the values yourself via `function_invocation_kwargs`. If a name is supplied by both the model and `function_invocation_kwargs`, the model-supplied value wins.
|
||||
- **`MCPTaskOptions`** (experimental, `MCP_LONG_RUNNING_TASKS` feature, **frozen**) - Per-tool-instance options controlling the SEP-2663 long-running task lifecycle. When the server advertises a tool with `execution.taskSupport == "required"`, `MCPTool.call_tool` transparently routes through `call_tool_as_task`, which sends an augmented `tools/call`, polls `tasks/get` until terminal, and reinterprets `tasks/result` as a normal `CallToolResult`. Instances are immutable; replace via `MCPTool.task_options = MCPTaskOptions(...)`. Fields:
|
||||
- `default_ttl: timedelta | None` — forwarded to the server as `params.task.ttl` (milliseconds). When `None`, the server's default applies.
|
||||
- `cancel_remote_task_on_local_cancellation: bool = True` — only gates the `CancelledError` path. Abandonment paths (see below) always cancel.
|
||||
|
||||
@@ -70,6 +70,31 @@ class MCPSpecificApproval(TypedDict, total=False):
|
||||
|
||||
_MCP_REMOTE_NAME_KEY = "_mcp_remote_name"
|
||||
_MCP_NORMALIZED_NAME_KEY = "_mcp_normalized_name"
|
||||
# Reserved key in an ``additional_tool_argument_names`` mapping that applies its
|
||||
# values to every tool on the server rather than a single named tool.
|
||||
_MCP_GLOBAL_EXTRA_ARGS_KEY = "*"
|
||||
# Framework kwargs that flow through the function-invocation pipeline (via
|
||||
# ``FunctionInvocationContext.kwargs``) but must never be forwarded to an MCP
|
||||
# server: they are internal objects that the MCP SDK cannot serialize. They are
|
||||
# dropped as a safety net when a tool declares one of them in its schema, unless
|
||||
# the user explicitly opts the name back in via ``additional_tool_argument_names``
|
||||
# (explicit extras always win over the denylist).
|
||||
# - chat_options/tools/tool_choice/session/thread: framework runtime objects.
|
||||
# - conversation_id: internal tracking ID used by services like Azure AI.
|
||||
# - options: metadata/store used by AG-UI for Azure AI client requirements.
|
||||
# - response_format: a Pydantic model class for structured output (not serializable).
|
||||
# - _meta: reserved key extracted separately as MCP request metadata.
|
||||
_MCP_FRAMEWORK_DENYLIST: frozenset[str] = frozenset({
|
||||
"chat_options",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"session",
|
||||
"thread",
|
||||
"conversation_id",
|
||||
"options",
|
||||
"response_format",
|
||||
"_meta",
|
||||
})
|
||||
_mcp_call_headers: contextvars.ContextVar[dict[str, str]] = contextvars.ContextVar("_mcp_call_headers")
|
||||
MCP_DEFAULT_TIMEOUT = 30
|
||||
MCP_DEFAULT_SSE_READ_TIMEOUT = 60 * 5
|
||||
@@ -135,6 +160,34 @@ def _build_prefixed_mcp_name(
|
||||
return f"{normalized_prefix}_{trimmed_name}" if trimmed_name else normalized_prefix
|
||||
|
||||
|
||||
def _normalize_additional_tool_argument_names(
|
||||
additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None,
|
||||
) -> tuple[set[str], dict[str, set[str]]]:
|
||||
"""Split user-supplied extra argument names into global and per-tool sets.
|
||||
|
||||
Accepts either a sequence (applied to every tool) or a mapping keyed by remote
|
||||
tool name, where the reserved key ``"*"`` is treated as global. Mapping values
|
||||
may be a sequence or a single string. Returns a
|
||||
``(global_extras, per_tool_extras)`` tuple.
|
||||
"""
|
||||
if additional_tool_argument_names is None:
|
||||
return set(), {}
|
||||
if isinstance(additional_tool_argument_names, str):
|
||||
return {additional_tool_argument_names}, {}
|
||||
if isinstance(additional_tool_argument_names, Mapping):
|
||||
global_extras: set[str] = set()
|
||||
per_tool_extras: dict[str, set[str]] = {}
|
||||
for tool_name, names in additional_tool_argument_names.items():
|
||||
# Treat a bare string value as a single name rather than iterating its characters.
|
||||
names_set = {names} if isinstance(names, str) else set(names)
|
||||
if tool_name == _MCP_GLOBAL_EXTRA_ARGS_KEY:
|
||||
global_extras.update(names_set)
|
||||
else:
|
||||
per_tool_extras[tool_name] = names_set
|
||||
return global_extras, per_tool_extras
|
||||
return set(additional_tool_argument_names), {}
|
||||
|
||||
|
||||
def _inject_otel_into_mcp_meta(meta: dict[str, Any] | None = None) -> dict[str, Any] | None:
|
||||
"""Inject OpenTelemetry trace context into MCP request _meta via the global propagator(s)."""
|
||||
carrier: dict[str, str] = {}
|
||||
@@ -294,6 +347,7 @@ class MCPTool:
|
||||
client: SupportsChatGetResponse | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
task_options: MCPTaskOptions | None = None,
|
||||
additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the MCP Tool base.
|
||||
|
||||
@@ -328,6 +382,10 @@ class MCPTool:
|
||||
task_options: Options controlling how long-running MCP tasks are driven for
|
||||
tools that advertise ``execution.taskSupport == "required"``. When ``None``,
|
||||
the defaults from :class:`MCPTaskOptions` are used.
|
||||
additional_tool_argument_names: Extra argument names to forward to the MCP server
|
||||
in addition to each tool's declared parameters. A ``Sequence[str]`` applies to
|
||||
every tool; a ``Mapping[str, Sequence[str]]`` is keyed by remote tool name with
|
||||
``"*"`` as a global key. See the transport subclasses for full details.
|
||||
"""
|
||||
self.name = name
|
||||
self.description = description or ""
|
||||
@@ -355,6 +413,10 @@ class MCPTool:
|
||||
self._functions: list[FunctionTool] = []
|
||||
self._tool_call_meta_by_name: dict[str, dict[str, Any]] = {}
|
||||
self._tool_task_support_by_name: dict[str, str] = {}
|
||||
self._tool_param_names_by_name: dict[str, set[str]] = {}
|
||||
self._global_extra_arg_names, self._tool_extra_arg_names = _normalize_additional_tool_argument_names(
|
||||
additional_tool_argument_names
|
||||
)
|
||||
self.is_connected: bool = False
|
||||
self._tools_loaded: bool = False
|
||||
self._prompts_loaded: bool = False
|
||||
@@ -1229,6 +1291,7 @@ class MCPTool:
|
||||
existing_names = {func.name for func in self._functions}
|
||||
tool_call_meta_by_name: dict[str, dict[str, Any]] = {}
|
||||
tool_task_support_by_name: dict[str, str] = {}
|
||||
tool_param_names_by_name: dict[str, set[str]] = {}
|
||||
|
||||
params: types.PaginatedRequestParams | None = None
|
||||
while True:
|
||||
@@ -1271,14 +1334,6 @@ class MCPTool:
|
||||
if task_support is not None:
|
||||
tool_task_support_by_name[tool.name] = task_support
|
||||
|
||||
normalized_name = _normalize_mcp_name(tool.name)
|
||||
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
|
||||
|
||||
# Skip if already loaded
|
||||
if local_name in existing_names:
|
||||
continue
|
||||
|
||||
approval_mode = self._determine_approval_mode(local_name, normalized_name, tool.name)
|
||||
# Normalize inputSchema: ensure "properties" exists for object schemas.
|
||||
# Some MCP servers (e.g. zero-argument tools) omit "properties",
|
||||
# which causes OpenAI API to reject the schema with a 400 error.
|
||||
@@ -1288,6 +1343,24 @@ class MCPTool:
|
||||
if input_schema.get("type") == "object" and "properties" not in input_schema:
|
||||
input_schema["properties"] = {}
|
||||
|
||||
# Register declared param names before the existing-tool skip below so that
|
||||
# reloads (e.g. notifications/tools/list_changed) preserve the allowlist for
|
||||
# tools that are already loaded, consistent with tool_call_meta_by_name and
|
||||
# tool_task_support_by_name above.
|
||||
schema_properties = input_schema.get("properties")
|
||||
tool_param_names_by_name[tool.name] = (
|
||||
set(cast(dict[str, Any], schema_properties)) if isinstance(schema_properties, dict) else set()
|
||||
)
|
||||
|
||||
normalized_name = _normalize_mcp_name(tool.name)
|
||||
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
|
||||
|
||||
# Skip if already loaded
|
||||
if local_name in existing_names:
|
||||
continue
|
||||
|
||||
approval_mode = self._determine_approval_mode(local_name, normalized_name, tool.name)
|
||||
|
||||
async def _call_tool_with_runtime_kwargs(
|
||||
ctx: FunctionInvocationContext,
|
||||
*,
|
||||
@@ -1320,6 +1393,7 @@ class MCPTool:
|
||||
|
||||
self._tool_call_meta_by_name = tool_call_meta_by_name
|
||||
self._tool_task_support_by_name = tool_task_support_by_name
|
||||
self._tool_param_names_by_name = tool_param_names_by_name
|
||||
|
||||
async def _close_on_owner(self) -> None:
|
||||
# Cancel any pending reload tasks before tearing down the session.
|
||||
@@ -1530,10 +1604,14 @@ class MCPTool:
|
||||
raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex
|
||||
raise ToolExecutionException(f"Failed to call tool '{tool_name}' after retries.")
|
||||
|
||||
def _resolved_extra_args(self, tool_name: str) -> set[str]:
|
||||
"""Return the user-configured extra argument names allowed for a tool."""
|
||||
return self._global_extra_arg_names | self._tool_extra_arg_names.get(tool_name, set())
|
||||
|
||||
def _prepare_call_kwargs(
|
||||
self, tool_name: str, kwargs: dict[str, Any]
|
||||
) -> tuple[dict[str, Any], dict[str, Any] | None]:
|
||||
"""Filter framework-only kwargs and build the merged MCP request metadata."""
|
||||
"""Filter kwargs down to the tool's arguments and build the merged MCP request metadata."""
|
||||
raw_user_meta: object | None = kwargs.get("_meta")
|
||||
user_meta: dict[str, Any] | None = None
|
||||
if raw_user_meta is not None and not isinstance(raw_user_meta, dict):
|
||||
@@ -1546,27 +1624,28 @@ class MCPTool:
|
||||
raise ToolExecutionException("MCP tool metadata provided via _meta must use string keys.")
|
||||
user_meta[key] = value
|
||||
|
||||
# Filter out framework kwargs that cannot be serialized by the MCP SDK.
|
||||
# These are internal objects passed through the function invocation pipeline
|
||||
# that should not be forwarded to external MCP servers.
|
||||
# conversation_id is an internal tracking ID used by services like Azure AI.
|
||||
# options contains metadata/store used by AG-UI for Azure AI client requirements.
|
||||
# response_format is a Pydantic model class used for structured output (not serializable).
|
||||
# Allowlist: forward only the tool's declared parameters (from inputSchema.properties)
|
||||
# plus any user-configured extra argument names. Everything else - notably the
|
||||
# framework runtime kwargs injected through the function-invocation pipeline - is
|
||||
# stripped so it is never forwarded to the MCP server. Tools that declare no usable
|
||||
# properties forward only the user-configured extras.
|
||||
#
|
||||
# The extra names come exclusively from additional_tool_argument_names, which is set in
|
||||
# user code at construction time; there is no per-call override, so a model-issued tool
|
||||
# call cannot change which names are allowed through.
|
||||
#
|
||||
# The framework denylist acts as a safety net for keys a server *declares* in its
|
||||
# schema that collide with internal, non-serializable framework objects (e.g. a tool
|
||||
# that declares a parameter literally named "thread"): such declared-but-denylisted
|
||||
# keys are dropped. Names the user explicitly opts in via additional_tool_argument_names
|
||||
# always win. The reserved _meta key is handled separately above and never forwarded as
|
||||
# an argument.
|
||||
declared = self._tool_param_names_by_name.get(tool_name, set())
|
||||
extras = self._resolved_extra_args(tool_name)
|
||||
filtered_kwargs = {
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if k
|
||||
not in {
|
||||
"chat_options",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"session",
|
||||
"thread",
|
||||
"conversation_id",
|
||||
"options",
|
||||
"response_format",
|
||||
"_meta",
|
||||
}
|
||||
if k != "_meta" and (k in extras or (k in declared and k not in _MCP_FRAMEWORK_DENYLIST))
|
||||
}
|
||||
|
||||
# Some MCP proxies require their tools/list metadata to be echoed on tools/call.
|
||||
@@ -1643,9 +1722,7 @@ class MCPTool:
|
||||
return parser(fallback_result)
|
||||
|
||||
if task_id is None:
|
||||
raise ToolExecutionException(
|
||||
f"MCP server did not return a task_id or fallback result for '{tool_name}'."
|
||||
)
|
||||
raise ToolExecutionException(f"MCP server did not return a task_id or fallback result for '{tool_name}'.")
|
||||
|
||||
# Track to completion: poll until terminal, then fetch payload. Never re-issue
|
||||
# tools/call past this point; reconnect-and-retry only against the same task_id.
|
||||
@@ -1765,9 +1842,7 @@ class MCPTool:
|
||||
transient_codes: frozenset[int] = frozenset({int(httpx.codes.REQUEST_TIMEOUT)})
|
||||
|
||||
while True:
|
||||
request = types.ClientRequest(
|
||||
types.GetTaskRequest(params=types.GetTaskRequestParams(taskId=task_id))
|
||||
)
|
||||
request = types.ClientRequest(types.GetTaskRequest(params=types.GetTaskRequestParams(taskId=task_id)))
|
||||
try:
|
||||
# GetTaskResult.ttl is required-but-Optional in the SDK; coerce below.
|
||||
lenient = await self._send_with_one_reconnect(
|
||||
@@ -1775,9 +1850,7 @@ class MCPTool:
|
||||
)
|
||||
except McpError as ex:
|
||||
if ex.error.code in transient_codes:
|
||||
logger.debug(
|
||||
"Transient %s on tasks/get for '%s'; will retry.", ex.error.code, task_id
|
||||
)
|
||||
logger.debug("Transient %s on tasks/get for '%s'; will retry.", ex.error.code, task_id)
|
||||
await asyncio.sleep(_MCP_TASK_MIN_POLL_INTERVAL.total_seconds())
|
||||
continue
|
||||
# Hard server error mid-poll: task may still be running.
|
||||
@@ -1906,9 +1979,7 @@ class MCPTool:
|
||||
if not self._is_connection_lost(ex):
|
||||
raise
|
||||
if attempt < _MCP_RECONNECT_ATTEMPTS - 1:
|
||||
logger.info(
|
||||
"MCP connection lost during %s; reconnecting (task_id=%s).", operation, task_id
|
||||
)
|
||||
logger.info("MCP connection lost during %s; reconnecting (task_id=%s).", operation, task_id)
|
||||
try:
|
||||
await self.connect(reset=True)
|
||||
except Exception as reconn_ex:
|
||||
@@ -1967,9 +2038,7 @@ class MCPTool:
|
||||
"""
|
||||
from mcp import types
|
||||
|
||||
request = types.ClientRequest(
|
||||
types.CancelTaskRequest(params=types.CancelTaskRequestParams(taskId=task_id))
|
||||
)
|
||||
request = types.ClientRequest(types.CancelTaskRequest(params=types.CancelTaskRequestParams(taskId=task_id)))
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self.session.send_request(request, types.CancelTaskResult), # type: ignore[union-attr]
|
||||
@@ -1979,8 +2048,7 @@ class MCPTool:
|
||||
raise
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"Best-effort tasks/cancel for '%s' timed out after %.1fs; "
|
||||
"remote task may still be running.",
|
||||
"Best-effort tasks/cancel for '%s' timed out after %.1fs; remote task may still be running.",
|
||||
task_id,
|
||||
_MCP_TASK_CANCEL_TIMEOUT.total_seconds(),
|
||||
)
|
||||
@@ -2153,6 +2221,7 @@ class MCPStdioTool(MCPTool):
|
||||
client: SupportsChatGetResponse | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
task_options: MCPTaskOptions | None = None,
|
||||
additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the MCP stdio tool.
|
||||
@@ -2199,6 +2268,20 @@ class MCPStdioTool(MCPTool):
|
||||
client: The chat client to use for sampling.
|
||||
task_options: Options for tools that advertise
|
||||
``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`.
|
||||
additional_tool_argument_names: Extra argument names to forward to the MCP server in
|
||||
addition to each tool's declared parameters (from its ``inputSchema.properties``).
|
||||
By default only declared parameters are sent; framework runtime kwargs injected
|
||||
through the function-invocation pipeline are stripped. Use this to opt specific
|
||||
keys back in. Accepts either a ``Sequence[str]`` applied to every tool, or a
|
||||
``Mapping[str, Sequence[str]]`` keyed by remote tool name where the reserved key
|
||||
``"*"`` applies to every tool. This is configured only here in user code; there is
|
||||
no per-call override, so a model-issued tool call cannot change which names pass
|
||||
through. To use a server that accepts ``additionalProperties: true``, list the
|
||||
extra names here and then either (1) manually extend that tool's ``inputSchema``
|
||||
(via the ``.functions`` list after connecting) so the model is prompted to supply
|
||||
them, or (2) supply the values yourself through ``function_invocation_kwargs``. If
|
||||
a name is supplied via both the model and ``function_invocation_kwargs``, the
|
||||
model-supplied value wins.
|
||||
kwargs: Any extra arguments to pass to the stdio client.
|
||||
"""
|
||||
super().__init__(
|
||||
@@ -2216,6 +2299,7 @@ class MCPStdioTool(MCPTool):
|
||||
parse_prompt_results=parse_prompt_results,
|
||||
request_timeout=request_timeout,
|
||||
task_options=task_options,
|
||||
additional_tool_argument_names=additional_tool_argument_names,
|
||||
)
|
||||
self.command = command
|
||||
self.args = args or []
|
||||
@@ -2295,6 +2379,7 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
http_client: AsyncClient | None = None,
|
||||
header_provider: Callable[[dict[str, Any]], dict[str, str]] | None = None,
|
||||
task_options: MCPTaskOptions | None = None,
|
||||
additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the MCP streamable HTTP tool.
|
||||
@@ -2349,6 +2434,20 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
agent middleware) without creating a separate ``httpx.AsyncClient``.
|
||||
task_options: Options for tools that advertise
|
||||
``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`.
|
||||
additional_tool_argument_names: Extra argument names to forward to the MCP server in
|
||||
addition to each tool's declared parameters (from its ``inputSchema.properties``).
|
||||
By default only declared parameters are sent; framework runtime kwargs injected
|
||||
through the function-invocation pipeline are stripped. Use this to opt specific
|
||||
keys back in. Accepts either a ``Sequence[str]`` applied to every tool, or a
|
||||
``Mapping[str, Sequence[str]]`` keyed by remote tool name where the reserved key
|
||||
``"*"`` applies to every tool. This is configured only here in user code; there is
|
||||
no per-call override, so a model-issued tool call cannot change which names pass
|
||||
through. To use a server that accepts ``additionalProperties: true``, list the
|
||||
extra names here and then either (1) manually extend that tool's ``inputSchema``
|
||||
(via the ``.functions`` list after connecting) so the model is prompted to supply
|
||||
them, or (2) supply the values yourself through ``function_invocation_kwargs``. If
|
||||
a name is supplied via both the model and ``function_invocation_kwargs``, the
|
||||
model-supplied value wins.
|
||||
kwargs: Additional keyword arguments (accepted for backward compatibility but not used).
|
||||
"""
|
||||
super().__init__(
|
||||
@@ -2366,6 +2465,7 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
parse_prompt_results=parse_prompt_results,
|
||||
request_timeout=request_timeout,
|
||||
task_options=task_options,
|
||||
additional_tool_argument_names=additional_tool_argument_names,
|
||||
)
|
||||
self.url = url
|
||||
self.terminate_on_close = terminate_on_close
|
||||
@@ -2492,6 +2592,7 @@ class MCPWebsocketTool(MCPTool):
|
||||
client: SupportsChatGetResponse | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
task_options: MCPTaskOptions | None = None,
|
||||
additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the MCP WebSocket tool.
|
||||
@@ -2536,6 +2637,20 @@ class MCPWebsocketTool(MCPTool):
|
||||
client: The chat client to use for sampling.
|
||||
task_options: Options for tools that advertise
|
||||
``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`.
|
||||
additional_tool_argument_names: Extra argument names to forward to the MCP server in
|
||||
addition to each tool's declared parameters (from its ``inputSchema.properties``).
|
||||
By default only declared parameters are sent; framework runtime kwargs injected
|
||||
through the function-invocation pipeline are stripped. Use this to opt specific
|
||||
keys back in. Accepts either a ``Sequence[str]`` applied to every tool, or a
|
||||
``Mapping[str, Sequence[str]]`` keyed by remote tool name where the reserved key
|
||||
``"*"`` applies to every tool. This is configured only here in user code; there is
|
||||
no per-call override, so a model-issued tool call cannot change which names pass
|
||||
through. To use a server that accepts ``additionalProperties: true``, list the
|
||||
extra names here and then either (1) manually extend that tool's ``inputSchema``
|
||||
(via the ``.functions`` list after connecting) so the model is prompted to supply
|
||||
them, or (2) supply the values yourself through ``function_invocation_kwargs``. If
|
||||
a name is supplied via both the model and ``function_invocation_kwargs``, the
|
||||
model-supplied value wins.
|
||||
kwargs: Any extra arguments to pass to the WebSocket client.
|
||||
"""
|
||||
super().__init__(
|
||||
@@ -2553,6 +2668,7 @@ class MCPWebsocketTool(MCPTool):
|
||||
parse_prompt_results=parse_prompt_results,
|
||||
request_timeout=request_timeout,
|
||||
task_options=task_options,
|
||||
additional_tool_argument_names=additional_tool_argument_names,
|
||||
)
|
||||
self.url = url
|
||||
self._client_kwargs = kwargs
|
||||
|
||||
@@ -30,6 +30,7 @@ from agent_framework._mcp import (
|
||||
MCPTool,
|
||||
_build_prefixed_mcp_name,
|
||||
_get_input_model_from_mcp_prompt,
|
||||
_normalize_additional_tool_argument_names,
|
||||
_normalize_mcp_name,
|
||||
_should_propagate_cancelled_error,
|
||||
logger,
|
||||
@@ -6057,3 +6058,205 @@ async def test_max_wait_interrupts_long_poll_sleep(monkeypatch: pytest.MonkeyPat
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region additional_tool_argument_names / allowlist filtering
|
||||
|
||||
|
||||
def test_normalize_additional_tool_argument_names_none() -> None:
|
||||
global_extras, per_tool = _normalize_additional_tool_argument_names(None)
|
||||
assert global_extras == set()
|
||||
assert per_tool == {}
|
||||
|
||||
|
||||
def test_normalize_additional_tool_argument_names_sequence() -> None:
|
||||
global_extras, per_tool = _normalize_additional_tool_argument_names(["a", "b", "a"])
|
||||
assert global_extras == {"a", "b"}
|
||||
assert per_tool == {}
|
||||
|
||||
|
||||
def test_normalize_additional_tool_argument_names_single_string() -> None:
|
||||
# A bare string must be treated as a single name, not split into characters.
|
||||
global_extras, per_tool = _normalize_additional_tool_argument_names("conversation_id")
|
||||
assert global_extras == {"conversation_id"}
|
||||
assert per_tool == {}
|
||||
|
||||
|
||||
def test_normalize_additional_tool_argument_names_mapping_with_global_key() -> None:
|
||||
global_extras, per_tool = _normalize_additional_tool_argument_names({
|
||||
"*": ["g1"],
|
||||
"tool_a": ["a1", "a2"],
|
||||
"tool_b": ["b1"],
|
||||
})
|
||||
assert global_extras == {"g1"}
|
||||
assert per_tool == {"tool_a": {"a1", "a2"}, "tool_b": {"b1"}}
|
||||
|
||||
|
||||
def test_normalize_additional_tool_argument_names_mapping_with_string_values() -> None:
|
||||
# A bare string mapping value is a single name, not an iterable of characters.
|
||||
global_extras, per_tool = _normalize_additional_tool_argument_names({
|
||||
"*": "conversation_id",
|
||||
"tool_a": "custom",
|
||||
})
|
||||
assert global_extras == {"conversation_id"}
|
||||
assert per_tool == {"tool_a": {"custom"}}
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_strips_undeclared_arguments() -> None:
|
||||
server = MCPTool(name="test_server")
|
||||
server._tool_param_names_by_name = {"test_tool": {"param"}}
|
||||
|
||||
filtered, meta = server._prepare_call_kwargs(
|
||||
"test_tool",
|
||||
{"param": "value", "conversation_id": "c", "thread": object(), "unexpected": 1},
|
||||
)
|
||||
|
||||
assert filtered == {"param": "value"}
|
||||
assert meta is None
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_global_extras_allowed() -> None:
|
||||
server = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"])
|
||||
server._tool_param_names_by_name = {"test_tool": {"param"}}
|
||||
|
||||
filtered, _ = server._prepare_call_kwargs(
|
||||
"test_tool",
|
||||
{"param": "value", "conversation_id": "c", "options": {}},
|
||||
)
|
||||
|
||||
assert filtered == {"param": "value", "conversation_id": "c"}
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_per_tool_and_global_extras() -> None:
|
||||
server = MCPTool(
|
||||
name="test_server",
|
||||
additional_tool_argument_names={"*": ["conversation_id"], "test_tool": ["custom"]},
|
||||
)
|
||||
server._tool_param_names_by_name = {"test_tool": {"param"}, "other_tool": {"x"}}
|
||||
|
||||
filtered, _ = server._prepare_call_kwargs(
|
||||
"test_tool",
|
||||
{"param": "v", "conversation_id": "c", "custom": "y", "thread": object()},
|
||||
)
|
||||
assert filtered == {"param": "v", "conversation_id": "c", "custom": "y"}
|
||||
|
||||
# The per-tool extra does not leak to other tools; the global one still applies.
|
||||
filtered_other, _ = server._prepare_call_kwargs(
|
||||
"other_tool",
|
||||
{"x": 1, "conversation_id": "c", "custom": "y"},
|
||||
)
|
||||
assert filtered_other == {"x": 1, "conversation_id": "c"}
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_denylist_guards_server_declared_names() -> None:
|
||||
# The denylist is a safety net for framework-named params a server *declares* in its
|
||||
# schema: they are dropped so internal objects never leak. Names explicitly opted in
|
||||
# via extras always win.
|
||||
server = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"])
|
||||
server._tool_param_names_by_name = {"test_tool": {"param", "thread"}}
|
||||
|
||||
filtered, _ = server._prepare_call_kwargs(
|
||||
"test_tool",
|
||||
{"param": "v", "thread": object(), "conversation_id": "c"},
|
||||
)
|
||||
# "thread" is declared by the schema but denylisted -> dropped; conversation_id opted in -> kept.
|
||||
assert filtered == {"param": "v", "conversation_id": "c"}
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_extras_override_denylist() -> None:
|
||||
# Opting a denylisted framework name back in via extras takes precedence over the
|
||||
# denylist safety net. "thread" is on the framework denylist, but an explicit extra wins.
|
||||
server = MCPTool(name="test_server", additional_tool_argument_names=["thread"])
|
||||
server._tool_param_names_by_name = {"test_tool": {"param"}}
|
||||
|
||||
sentinel = object()
|
||||
filtered, _ = server._prepare_call_kwargs(
|
||||
"test_tool",
|
||||
{"param": "v", "thread": sentinel, "conversation_id": "c"},
|
||||
)
|
||||
# "thread" opted in via extras -> kept despite the denylist; conversation_id is denylisted,
|
||||
# not declared, and not opted in -> dropped.
|
||||
assert filtered == {"param": "v", "thread": sentinel}
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_zero_arg_tool_passes_no_arguments() -> None:
|
||||
server = MCPTool(name="test_server")
|
||||
server._tool_param_names_by_name = {"test_tool": set()}
|
||||
|
||||
filtered, _ = server._prepare_call_kwargs(
|
||||
"test_tool",
|
||||
{"conversation_id": "c", "thread": object(), "stray": 1},
|
||||
)
|
||||
assert filtered == {}
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_unknown_tool_passes_only_global_extras() -> None:
|
||||
server = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"])
|
||||
# No entry in _tool_param_names_by_name for this tool name.
|
||||
|
||||
filtered, _ = server._prepare_call_kwargs(
|
||||
"unknown_tool",
|
||||
{"conversation_id": "c", "other": 1},
|
||||
)
|
||||
assert filtered == {"conversation_id": "c"}
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_extracts_meta() -> None:
|
||||
server = MCPTool(name="test_server")
|
||||
server._tool_param_names_by_name = {"test_tool": {"param"}}
|
||||
|
||||
filtered, meta = server._prepare_call_kwargs(
|
||||
"test_tool",
|
||||
{"param": "v", "_meta": {"trace": "abc"}},
|
||||
)
|
||||
assert filtered == {"param": "v"}
|
||||
assert meta is not None
|
||||
assert meta.get("trace") == "abc"
|
||||
|
||||
|
||||
async def test_call_tool_forwards_only_declared_arguments() -> None:
|
||||
"""End-to-end: framework runtime kwargs are stripped before reaching the server."""
|
||||
|
||||
class TestServer(MCPTool):
|
||||
async def connect(self):
|
||||
self.session = Mock(spec=ClientSession)
|
||||
self.session.list_tools = AsyncMock(
|
||||
return_value=types.ListToolsResult(
|
||||
tools=[
|
||||
types.Tool(
|
||||
name="test_tool",
|
||||
description="Test tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {"param": {"type": "string"}},
|
||||
"required": ["param"],
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
self.session.call_tool = AsyncMock(
|
||||
return_value=types.CallToolResult(content=[types.TextContent(type="text", text="ok")])
|
||||
)
|
||||
|
||||
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
|
||||
return None
|
||||
|
||||
server = TestServer(name="test_server", additional_tool_argument_names=["conversation_id"])
|
||||
async with server:
|
||||
await server.load_tools()
|
||||
session_mock = server.session
|
||||
await server.call_tool(
|
||||
"test_tool",
|
||||
param="value",
|
||||
conversation_id="c",
|
||||
thread=object(),
|
||||
response_format=object(),
|
||||
)
|
||||
|
||||
session_mock.call_tool.assert_called_once()
|
||||
_, call_kwargs = session_mock.call_tool.call_args
|
||||
assert call_kwargs["arguments"] == {"param": "value", "conversation_id": "c"}
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
Reference in New Issue
Block a user