mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Filter MCP tool kwargs to declared params via allowlist (#6399)
* Filter MCP tool kwargs to declared params via allowlist Previously MCPTool combined framework runtime kwargs (from FunctionInvocationContext.kwargs) with the LLM-supplied arguments and stripped only a hardcoded denylist of known framework keys before forwarding to the MCP server. Any new framework-injected kwarg leaked to the server unless the denylist was updated. Switch to an allowlist built from each tool's declared parameters (inputSchema.properties). Only declared params are forwarded; everything else is stripped. Add an `additional_tool_argument_names` constructor argument so users can opt extra names back in, globally (Sequence[str]) and/or per remote tool name (Mapping with reserved "*" global key). The existing denylist is kept as a safety net for framework-named params a server declares in its schema; explicitly opted-in extras always win. The reserved _meta handling is unchanged. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address MCP allowlist review comments and fix reload arg loss - Fix pyright reportUnknownArgumentType in _load_tools (cast schema properties). - Register declared param names before the existing-tool skip guard so that tool-list reloads preserve the allowlist for already-loaded tools (previously unchanged tools silently dropped all declared args after a background reload). - Handle bare-string values in an additional_tool_argument_names mapping instead of iterating their characters. - Clarify the framework denylist comment: explicit extras override the denylist. - Make the extras-override-denylist test unambiguous (opt in a denylisted name). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
d222079df9
commit
cfb033e5d4
@@ -30,6 +30,7 @@ from agent_framework._mcp import (
|
||||
MCPTool,
|
||||
_build_prefixed_mcp_name,
|
||||
_get_input_model_from_mcp_prompt,
|
||||
_normalize_additional_tool_argument_names,
|
||||
_normalize_mcp_name,
|
||||
_should_propagate_cancelled_error,
|
||||
logger,
|
||||
@@ -6057,3 +6058,205 @@ async def test_max_wait_interrupts_long_poll_sleep(monkeypatch: pytest.MonkeyPat
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region additional_tool_argument_names / allowlist filtering
|
||||
|
||||
|
||||
def test_normalize_additional_tool_argument_names_none() -> None:
|
||||
global_extras, per_tool = _normalize_additional_tool_argument_names(None)
|
||||
assert global_extras == set()
|
||||
assert per_tool == {}
|
||||
|
||||
|
||||
def test_normalize_additional_tool_argument_names_sequence() -> None:
|
||||
global_extras, per_tool = _normalize_additional_tool_argument_names(["a", "b", "a"])
|
||||
assert global_extras == {"a", "b"}
|
||||
assert per_tool == {}
|
||||
|
||||
|
||||
def test_normalize_additional_tool_argument_names_single_string() -> None:
|
||||
# A bare string must be treated as a single name, not split into characters.
|
||||
global_extras, per_tool = _normalize_additional_tool_argument_names("conversation_id")
|
||||
assert global_extras == {"conversation_id"}
|
||||
assert per_tool == {}
|
||||
|
||||
|
||||
def test_normalize_additional_tool_argument_names_mapping_with_global_key() -> None:
|
||||
global_extras, per_tool = _normalize_additional_tool_argument_names({
|
||||
"*": ["g1"],
|
||||
"tool_a": ["a1", "a2"],
|
||||
"tool_b": ["b1"],
|
||||
})
|
||||
assert global_extras == {"g1"}
|
||||
assert per_tool == {"tool_a": {"a1", "a2"}, "tool_b": {"b1"}}
|
||||
|
||||
|
||||
def test_normalize_additional_tool_argument_names_mapping_with_string_values() -> None:
|
||||
# A bare string mapping value is a single name, not an iterable of characters.
|
||||
global_extras, per_tool = _normalize_additional_tool_argument_names({
|
||||
"*": "conversation_id",
|
||||
"tool_a": "custom",
|
||||
})
|
||||
assert global_extras == {"conversation_id"}
|
||||
assert per_tool == {"tool_a": {"custom"}}
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_strips_undeclared_arguments() -> None:
|
||||
server = MCPTool(name="test_server")
|
||||
server._tool_param_names_by_name = {"test_tool": {"param"}}
|
||||
|
||||
filtered, meta = server._prepare_call_kwargs(
|
||||
"test_tool",
|
||||
{"param": "value", "conversation_id": "c", "thread": object(), "unexpected": 1},
|
||||
)
|
||||
|
||||
assert filtered == {"param": "value"}
|
||||
assert meta is None
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_global_extras_allowed() -> None:
|
||||
server = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"])
|
||||
server._tool_param_names_by_name = {"test_tool": {"param"}}
|
||||
|
||||
filtered, _ = server._prepare_call_kwargs(
|
||||
"test_tool",
|
||||
{"param": "value", "conversation_id": "c", "options": {}},
|
||||
)
|
||||
|
||||
assert filtered == {"param": "value", "conversation_id": "c"}
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_per_tool_and_global_extras() -> None:
|
||||
server = MCPTool(
|
||||
name="test_server",
|
||||
additional_tool_argument_names={"*": ["conversation_id"], "test_tool": ["custom"]},
|
||||
)
|
||||
server._tool_param_names_by_name = {"test_tool": {"param"}, "other_tool": {"x"}}
|
||||
|
||||
filtered, _ = server._prepare_call_kwargs(
|
||||
"test_tool",
|
||||
{"param": "v", "conversation_id": "c", "custom": "y", "thread": object()},
|
||||
)
|
||||
assert filtered == {"param": "v", "conversation_id": "c", "custom": "y"}
|
||||
|
||||
# The per-tool extra does not leak to other tools; the global one still applies.
|
||||
filtered_other, _ = server._prepare_call_kwargs(
|
||||
"other_tool",
|
||||
{"x": 1, "conversation_id": "c", "custom": "y"},
|
||||
)
|
||||
assert filtered_other == {"x": 1, "conversation_id": "c"}
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_denylist_guards_server_declared_names() -> None:
|
||||
# The denylist is a safety net for framework-named params a server *declares* in its
|
||||
# schema: they are dropped so internal objects never leak. Names explicitly opted in
|
||||
# via extras always win.
|
||||
server = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"])
|
||||
server._tool_param_names_by_name = {"test_tool": {"param", "thread"}}
|
||||
|
||||
filtered, _ = server._prepare_call_kwargs(
|
||||
"test_tool",
|
||||
{"param": "v", "thread": object(), "conversation_id": "c"},
|
||||
)
|
||||
# "thread" is declared by the schema but denylisted -> dropped; conversation_id opted in -> kept.
|
||||
assert filtered == {"param": "v", "conversation_id": "c"}
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_extras_override_denylist() -> None:
|
||||
# Opting a denylisted framework name back in via extras takes precedence over the
|
||||
# denylist safety net. "thread" is on the framework denylist, but an explicit extra wins.
|
||||
server = MCPTool(name="test_server", additional_tool_argument_names=["thread"])
|
||||
server._tool_param_names_by_name = {"test_tool": {"param"}}
|
||||
|
||||
sentinel = object()
|
||||
filtered, _ = server._prepare_call_kwargs(
|
||||
"test_tool",
|
||||
{"param": "v", "thread": sentinel, "conversation_id": "c"},
|
||||
)
|
||||
# "thread" opted in via extras -> kept despite the denylist; conversation_id is denylisted,
|
||||
# not declared, and not opted in -> dropped.
|
||||
assert filtered == {"param": "v", "thread": sentinel}
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_zero_arg_tool_passes_no_arguments() -> None:
|
||||
server = MCPTool(name="test_server")
|
||||
server._tool_param_names_by_name = {"test_tool": set()}
|
||||
|
||||
filtered, _ = server._prepare_call_kwargs(
|
||||
"test_tool",
|
||||
{"conversation_id": "c", "thread": object(), "stray": 1},
|
||||
)
|
||||
assert filtered == {}
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_unknown_tool_passes_only_global_extras() -> None:
|
||||
server = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"])
|
||||
# No entry in _tool_param_names_by_name for this tool name.
|
||||
|
||||
filtered, _ = server._prepare_call_kwargs(
|
||||
"unknown_tool",
|
||||
{"conversation_id": "c", "other": 1},
|
||||
)
|
||||
assert filtered == {"conversation_id": "c"}
|
||||
|
||||
|
||||
def test_prepare_call_kwargs_extracts_meta() -> None:
|
||||
server = MCPTool(name="test_server")
|
||||
server._tool_param_names_by_name = {"test_tool": {"param"}}
|
||||
|
||||
filtered, meta = server._prepare_call_kwargs(
|
||||
"test_tool",
|
||||
{"param": "v", "_meta": {"trace": "abc"}},
|
||||
)
|
||||
assert filtered == {"param": "v"}
|
||||
assert meta is not None
|
||||
assert meta.get("trace") == "abc"
|
||||
|
||||
|
||||
async def test_call_tool_forwards_only_declared_arguments() -> None:
|
||||
"""End-to-end: framework runtime kwargs are stripped before reaching the server."""
|
||||
|
||||
class TestServer(MCPTool):
|
||||
async def connect(self):
|
||||
self.session = Mock(spec=ClientSession)
|
||||
self.session.list_tools = AsyncMock(
|
||||
return_value=types.ListToolsResult(
|
||||
tools=[
|
||||
types.Tool(
|
||||
name="test_tool",
|
||||
description="Test tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {"param": {"type": "string"}},
|
||||
"required": ["param"],
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
self.session.call_tool = AsyncMock(
|
||||
return_value=types.CallToolResult(content=[types.TextContent(type="text", text="ok")])
|
||||
)
|
||||
|
||||
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
|
||||
return None
|
||||
|
||||
server = TestServer(name="test_server", additional_tool_argument_names=["conversation_id"])
|
||||
async with server:
|
||||
await server.load_tools()
|
||||
session_mock = server.session
|
||||
await server.call_tool(
|
||||
"test_tool",
|
||||
param="value",
|
||||
conversation_id="c",
|
||||
thread=object(),
|
||||
response_format=object(),
|
||||
)
|
||||
|
||||
session_mock.call_tool.assert_called_once()
|
||||
_, call_kwargs = session_mock.call_tool.call_args
|
||||
assert call_kwargs["arguments"] == {"param": "value", "conversation_id": "c"}
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
Reference in New Issue
Block a user