Python: Filter MCP tool kwargs to declared params via allowlist (#6399)

* Filter MCP tool kwargs to declared params via allowlist

Previously MCPTool combined framework runtime kwargs (from
FunctionInvocationContext.kwargs) with the LLM-supplied arguments and
stripped only a hardcoded denylist of known framework keys before
forwarding to the MCP server. Any new framework-injected kwarg leaked to
the server unless the denylist was updated.

Switch to an allowlist built from each tool's declared parameters
(inputSchema.properties). Only declared params are forwarded; everything
else is stripped. Add an `additional_tool_argument_names` constructor
argument so users can opt extra names back in, globally (Sequence[str])
and/or per remote tool name (Mapping with reserved "*" global key). The
existing denylist is kept as a safety net for framework-named params a
server declares in its schema; explicitly opted-in extras always win. The
reserved _meta handling is unchanged.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address MCP allowlist review comments and fix reload arg loss

- Fix pyright reportUnknownArgumentType in _load_tools (cast schema properties).
- Register declared param names before the existing-tool skip guard so that
  tool-list reloads preserve the allowlist for already-loaded tools (previously
  unchanged tools silently dropped all declared args after a background reload).
- Handle bare-string values in an additional_tool_argument_names mapping instead
  of iterating their characters.
- Clarify the framework denylist comment: explicit extras override the denylist.
- Make the extras-override-denylist test unambiguous (opt in a denylisted name).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Eduard van Valkenburg
2026-06-09 09:37:11 +02:00
committed by GitHub
Unverified
parent d222079df9
commit cfb033e5d4
3 changed files with 365 additions and 44 deletions
+203
View File
@@ -30,6 +30,7 @@ from agent_framework._mcp import (
MCPTool,
_build_prefixed_mcp_name,
_get_input_model_from_mcp_prompt,
_normalize_additional_tool_argument_names,
_normalize_mcp_name,
_should_propagate_cancelled_error,
logger,
@@ -6057,3 +6058,205 @@ async def test_max_wait_interrupts_long_poll_sleep(monkeypatch: pytest.MonkeyPat
# endregion
# region additional_tool_argument_names / allowlist filtering
def test_normalize_additional_tool_argument_names_none() -> None:
    """``None`` normalizes to an empty global set and an empty per-tool mapping."""
    normalized = _normalize_additional_tool_argument_names(None)
    shared_names, scoped_names = normalized
    assert shared_names == set()
    assert scoped_names == {}
def test_normalize_additional_tool_argument_names_sequence() -> None:
    """A sequence becomes the global set (duplicates collapse); no per-tool entries result."""
    requested = ["a", "b", "a"]
    shared_names, scoped_names = _normalize_additional_tool_argument_names(requested)
    assert shared_names == {"a", "b"}
    assert scoped_names == {}
def test_normalize_additional_tool_argument_names_single_string() -> None:
    """A bare string is one argument name; it must not be exploded into characters."""
    lone_name = "conversation_id"
    shared_names, scoped_names = _normalize_additional_tool_argument_names(lone_name)
    assert shared_names == {"conversation_id"}
    assert scoped_names == {}
def test_normalize_additional_tool_argument_names_mapping_with_global_key() -> None:
    """In a mapping, the ``"*"`` key feeds the global set; other keys become per-tool sets."""
    configured = {
        "*": ["g1"],
        "tool_a": ["a1", "a2"],
        "tool_b": ["b1"],
    }
    shared_names, scoped_names = _normalize_additional_tool_argument_names(configured)
    assert shared_names == {"g1"}
    assert scoped_names == {"tool_a": {"a1", "a2"}, "tool_b": {"b1"}}
def test_normalize_additional_tool_argument_names_mapping_with_string_values() -> None:
    """Bare-string mapping values count as single names rather than iterables of characters."""
    configured = {
        "*": "conversation_id",
        "tool_a": "custom",
    }
    shared_names, scoped_names = _normalize_additional_tool_argument_names(configured)
    assert shared_names == {"conversation_id"}
    assert scoped_names == {"tool_a": {"custom"}}
def test_prepare_call_kwargs_strips_undeclared_arguments() -> None:
    """Only the declared ``param`` is kept; with no ``_meta`` supplied the meta result is ``None``."""
    tool = MCPTool(name="test_server")
    tool._tool_param_names_by_name = {"test_tool": {"param"}}
    incoming = {"param": "value", "conversation_id": "c", "thread": object(), "unexpected": 1}
    forwarded, request_meta = tool._prepare_call_kwargs("test_tool", incoming)
    assert forwarded == {"param": "value"}
    assert request_meta is None
def test_prepare_call_kwargs_global_extras_allowed() -> None:
    """A globally opted-in extra name is forwarded alongside the declared parameter."""
    tool = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"])
    tool._tool_param_names_by_name = {"test_tool": {"param"}}
    incoming = {"param": "value", "conversation_id": "c", "options": {}}
    forwarded, _ = tool._prepare_call_kwargs("test_tool", incoming)
    # "options" is neither declared nor opted in, so it must not appear.
    assert forwarded == {"param": "value", "conversation_id": "c"}
def test_prepare_call_kwargs_per_tool_and_global_extras() -> None:
    """Per-tool extras apply to their own tool only, while ``"*"`` extras apply to every tool."""
    extras_config = {"*": ["conversation_id"], "test_tool": ["custom"]}
    tool = MCPTool(name="test_server", additional_tool_argument_names=extras_config)
    tool._tool_param_names_by_name = {"test_tool": {"param"}, "other_tool": {"x"}}

    first_call = {"param": "v", "conversation_id": "c", "custom": "y", "thread": object()}
    forwarded, _ = tool._prepare_call_kwargs("test_tool", first_call)
    assert forwarded == {"param": "v", "conversation_id": "c", "custom": "y"}

    # "custom" was opted in for test_tool only, so it is dropped for other_tool;
    # the global "conversation_id" extra is still forwarded.
    second_call = {"x": 1, "conversation_id": "c", "custom": "y"}
    forwarded_other, _ = tool._prepare_call_kwargs("other_tool", second_call)
    assert forwarded_other == {"x": 1, "conversation_id": "c"}
def test_prepare_call_kwargs_denylist_guards_server_declared_names() -> None:
    """A schema-declared name that is on the framework denylist is still dropped.

    The denylist acts as a safety net so internal objects never leak to the server,
    even when the server's schema *declares* a framework-named parameter. Names that
    are explicitly opted in through extras are kept regardless.
    """
    tool = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"])
    tool._tool_param_names_by_name = {"test_tool": {"param", "thread"}}
    incoming = {"param": "v", "thread": object(), "conversation_id": "c"}
    forwarded, _ = tool._prepare_call_kwargs("test_tool", incoming)
    # Declared-but-denylisted "thread" is removed; opted-in "conversation_id" survives.
    assert forwarded == {"param": "v", "conversation_id": "c"}
def test_prepare_call_kwargs_extras_override_denylist() -> None:
    """An explicit extra takes precedence over the framework denylist.

    "thread" is a denylisted framework name, but opting it in via
    ``additional_tool_argument_names`` must make it pass through untouched.
    """
    tool = MCPTool(name="test_server", additional_tool_argument_names=["thread"])
    tool._tool_param_names_by_name = {"test_tool": {"param"}}
    marker = object()
    incoming = {"param": "v", "thread": marker, "conversation_id": "c"}
    forwarded, _ = tool._prepare_call_kwargs("test_tool", incoming)
    # "thread" is kept because it was opted in; "conversation_id" is denylisted,
    # undeclared, and not opted in, so it is removed.
    assert forwarded == {"param": "v", "thread": marker}
def test_prepare_call_kwargs_zero_arg_tool_passes_no_arguments() -> None:
    """A tool that declares no parameters receives an empty argument dict."""
    tool = MCPTool(name="test_server")
    tool._tool_param_names_by_name = {"test_tool": set()}
    incoming = {"conversation_id": "c", "thread": object(), "stray": 1}
    forwarded, _ = tool._prepare_call_kwargs("test_tool", incoming)
    assert forwarded == {}
def test_prepare_call_kwargs_unknown_tool_passes_only_global_extras() -> None:
    """For a tool with no registered parameter names, only global extras are forwarded."""
    tool = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"])
    # Deliberately leave _tool_param_names_by_name without an entry for "unknown_tool".
    incoming = {"conversation_id": "c", "other": 1}
    forwarded, _ = tool._prepare_call_kwargs("unknown_tool", incoming)
    assert forwarded == {"conversation_id": "c"}
def test_prepare_call_kwargs_extracts_meta() -> None:
    """The reserved ``_meta`` entry is returned separately and not left in the arguments."""
    tool = MCPTool(name="test_server")
    tool._tool_param_names_by_name = {"test_tool": {"param"}}
    incoming = {"param": "v", "_meta": {"trace": "abc"}}
    forwarded, request_meta = tool._prepare_call_kwargs("test_tool", incoming)
    assert forwarded == {"param": "v"}
    assert request_meta is not None
    assert request_meta.get("trace") == "abc"
async def test_call_tool_forwards_only_declared_arguments() -> None:
    """End-to-end: framework runtime kwargs are stripped before reaching the server."""

    class TestServer(MCPTool):
        async def connect(self):
            # Stand-in session exposing one tool whose schema declares only "param".
            declared_tool = types.Tool(
                name="test_tool",
                description="Test tool",
                inputSchema={
                    "type": "object",
                    "properties": {"param": {"type": "string"}},
                    "required": ["param"],
                },
            )
            call_result = types.CallToolResult(content=[types.TextContent(type="text", text="ok")])

            mock_session = Mock(spec=ClientSession)
            mock_session.list_tools = AsyncMock(return_value=types.ListToolsResult(tools=[declared_tool]))
            mock_session.call_tool = AsyncMock(return_value=call_result)
            self.session = mock_session

        def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
            return None

    server = TestServer(name="test_server", additional_tool_argument_names=["conversation_id"])
    async with server:
        await server.load_tools()
        # Capture the mock before the call so the assertions do not depend on server state.
        session_mock = server.session

        runtime_kwargs = {
            "param": "value",
            "conversation_id": "c",
            "thread": object(),
            "response_format": object(),
        }
        await server.call_tool("test_tool", **runtime_kwargs)

        session_mock.call_tool.assert_called_once()
        call_kwargs = session_mock.call_tool.call_args.kwargs
        # Only the declared "param" and the opted-in "conversation_id" reach the server.
        assert call_kwargs["arguments"] == {"param": "value", "conversation_id": "c"}
# endregion