Python: added approval_mode and allowed_tools to local MCP (#1203)

* added approval_mode and allowed_tools to local MCP

* updated docs
This commit is contained in:
Eduard van Valkenburg
2025-10-07 09:44:08 +02:00
committed by GitHub
Unverified
parent 7ebe00ec3d
commit b49395fa3d
2 changed files with 196 additions and 18 deletions
+73 -18
View File
@@ -5,10 +5,11 @@ import logging
import re
import sys
from abc import abstractmethod
from collections.abc import Collection
from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore
from datetime import timedelta
from functools import partial
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal
from mcp import types
from mcp.client.session import ClientSession
@@ -20,7 +21,7 @@ from mcp.shared.exceptions import McpError
from mcp.shared.session import RequestResponder
from pydantic import BaseModel, create_model
from ._tools import AIFunction
from ._tools import AIFunction, HostedMCPSpecificApproval
from ._types import ChatMessage, Contents, DataContent, Role, TextContent, UriContent
from .exceptions import ToolException, ToolExecutionException
@@ -264,27 +265,25 @@ class MCPTool:
self,
name: str,
description: str | None = None,
additional_properties: dict[str, Any] | None = None,
approval_mode: Literal["always_require", "never_require"] | HostedMCPSpecificApproval | None = None,
allowed_tools: Collection[str] | None = None,
load_tools: bool = True,
load_prompts: bool = True,
session: ClientSession | None = None,
request_timeout: int | None = None,
chat_client: "ChatClientProtocol | None" = None,
additional_properties: dict[str, Any] | None = None,
) -> None:
"""Initialize the MCP Tool base.
Args:
name: The name of the MCP tool.
description: The description of the tool.
additional_properties: Additional properties for the tool.
load_tools: Whether to automatically load tools from the MCP server.
load_prompts: Whether to automatically load prompts from the MCP server.
session: Pre-existing session to use for the MCP connection.
request_timeout: The default timeout in seconds for all requests.
chat_client: The chat client to use for sampling callbacks.
Note:
Do not use this method, use one of the subclasses: MCPStreamableHTTPTool, MCPWebsocketTool
or MCPStdioTool.
"""
self.name = name
self.description = description or ""
self.approval_mode = approval_mode
self.allowed_tools = allowed_tools
self.additional_properties = additional_properties
self.load_tools_flag = load_tools
self.load_prompts_flag = load_prompts
@@ -292,12 +291,19 @@ class MCPTool:
self.session = session
self.request_timeout = request_timeout
self.chat_client = chat_client
self.functions: list[AIFunction[Any, Any]] = []
self._functions: list[AIFunction[Any, Any]] = []
self.is_connected: bool = False
def __str__(self) -> str:
return f"MCPTool(name={self.name}, description={self.description})"
@property
def functions(self) -> list[AIFunction[Any, Any]]:
"""Get the list of functions that are allowed."""
if not self.allowed_tools:
return self._functions
return [func for func in self._functions if func.name in self.allowed_tools]
async def connect(self) -> None:
"""Connect to the MCP server.
@@ -458,6 +464,18 @@ class MCPTool:
case _:
logger.debug("Unhandled notification: %s", message.root.method)
def _determine_approval_mode(
self,
local_name: str,
) -> Literal["always_require", "never_require"] | None:
if isinstance(self.approval_mode, dict):
if (always_require := self.approval_mode.get("always_require_approval")) and local_name in always_require:
return "always_require"
if (never_require := self.approval_mode.get("never_require_approval")) and local_name in never_require:
return "never_require"
return None
return self.approval_mode # type: ignore[reportReturnType]
async def load_prompts(self) -> None:
"""Load prompts from the MCP server.
@@ -480,13 +498,15 @@ class MCPTool:
for prompt in prompt_list.prompts if prompt_list else []:
local_name = _normalize_mcp_name(prompt.name)
input_model = _get_input_model_from_mcp_prompt(prompt)
approval_mode = self._determine_approval_mode(local_name)
func: AIFunction[BaseModel, list[ChatMessage]] = AIFunction(
func=partial(self.get_prompt, prompt.name),
name=local_name,
description=prompt.description or "",
approval_mode=approval_mode,
input_model=input_model,
)
self.functions.append(func)
self._functions.append(func)
async def load_tools(self) -> None:
"""Load tools from the MCP server.
@@ -510,14 +530,16 @@ class MCPTool:
for tool in tool_list.tools if tool_list else []:
local_name = _normalize_mcp_name(tool.name)
input_model = _get_input_model_from_mcp_tool(tool)
approval_mode = self._determine_approval_mode(local_name)
# Create AIFunctions out of each tool
func: AIFunction[BaseModel, list[Contents]] = AIFunction(
func=partial(self.call_tool, tool.name),
name=local_name,
description=tool.description or "",
approval_mode=approval_mode,
input_model=input_model,
)
self.functions.append(func)
self._functions.append(func)
async def close(self) -> None:
"""Disconnect from the MCP server.
@@ -670,11 +692,13 @@ class MCPStdioTool(MCPTool):
request_timeout: int | None = None,
session: ClientSession | None = None,
description: str | None = None,
additional_properties: dict[str, Any] | None = None,
approval_mode: Literal["always_require", "never_require"] | HostedMCPSpecificApproval | None = None,
allowed_tools: Collection[str] | None = None,
args: list[str] | None = None,
env: dict[str, str] | None = None,
encoding: str | None = None,
chat_client: "ChatClientProtocol | None" = None,
additional_properties: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
"""Initialize the MCP stdio tool.
@@ -694,6 +718,13 @@ class MCPStdioTool(MCPTool):
request_timeout: The default timeout in seconds for all requests.
session: The session to use for the MCP connection.
description: The description of the tool.
approval_mode: The approval mode for the tool. This can be:
- "always_require": The tool always requires approval before use.
- "never_require": The tool never requires approval before use.
- A dict with keys `always_require_approval` or `never_require_approval`,
followed by a sequence of strings with the names of the relevant tools.
A tool should not be listed in both, if so, it will require approval.
allowed_tools: A list of tools that are allowed to use this tool.
additional_properties: Additional properties.
args: The arguments to pass to the command.
env: The environment variables to set for the command.
@@ -704,6 +735,8 @@ class MCPStdioTool(MCPTool):
super().__init__(
name=name,
description=description,
approval_mode=approval_mode,
allowed_tools=allowed_tools,
additional_properties=additional_properties,
session=session,
chat_client=chat_client,
@@ -769,12 +802,14 @@ class MCPStreamableHTTPTool(MCPTool):
request_timeout: int | None = None,
session: ClientSession | None = None,
description: str | None = None,
additional_properties: dict[str, Any] | None = None,
approval_mode: Literal["always_require", "never_require"] | HostedMCPSpecificApproval | None = None,
allowed_tools: Collection[str] | None = None,
headers: dict[str, Any] | None = None,
timeout: float | None = None,
sse_read_timeout: float | None = None,
terminate_on_close: bool | None = None,
chat_client: "ChatClientProtocol | None" = None,
additional_properties: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
"""Initialize the MCP streamable HTTP tool.
@@ -795,6 +830,13 @@ class MCPStreamableHTTPTool(MCPTool):
request_timeout: The default timeout in seconds for all requests.
session: The session to use for the MCP connection.
description: The description of the tool.
approval_mode: The approval mode for the tool. This can be:
- "always_require": The tool always requires approval before use.
- "never_require": The tool never requires approval before use.
- A dict with keys `always_require_approval` or `never_require_approval`,
followed by a sequence of strings with the names of the relevant tools.
A tool should not be listed in both, if so, it will require approval.
allowed_tools: A list of tools that are allowed to use this tool.
additional_properties: Additional properties.
headers: The headers to send with the request.
timeout: The timeout for the request.
@@ -806,6 +848,8 @@ class MCPStreamableHTTPTool(MCPTool):
super().__init__(
name=name,
description=description,
approval_mode=approval_mode,
allowed_tools=allowed_tools,
additional_properties=additional_properties,
session=session,
chat_client=chat_client,
@@ -873,8 +917,10 @@ class MCPWebsocketTool(MCPTool):
request_timeout: int | None = None,
session: ClientSession | None = None,
description: str | None = None,
additional_properties: dict[str, Any] | None = None,
approval_mode: Literal["always_require", "never_require"] | HostedMCPSpecificApproval | None = None,
allowed_tools: Collection[str] | None = None,
chat_client: "ChatClientProtocol | None" = None,
additional_properties: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
"""Initialize the MCP WebSocket tool.
@@ -895,6 +941,13 @@ class MCPWebsocketTool(MCPTool):
request_timeout: The default timeout in seconds for all requests.
session: The session to use for the MCP connection.
description: The description of the tool.
approval_mode: The approval mode for the tool. This can be:
- "always_require": The tool always requires approval before use.
- "never_require": The tool never requires approval before use.
- A dict with keys `always_require_approval` or `never_require_approval`,
followed by a sequence of strings with the names of the relevant tools.
A tool should not be listed in both, if so, it will require approval.
allowed_tools: A list of tools that are allowed to use this tool.
additional_properties: Additional properties.
chat_client: The chat client to use for sampling.
kwargs: Any extra arguments to pass to the WebSocket client.
@@ -902,6 +955,8 @@ class MCPWebsocketTool(MCPTool):
super().__init__(
name=name,
description=description,
approval_mode=approval_mode,
allowed_tools=allowed_tools,
additional_properties=additional_properties,
session=session,
chat_client=chat_client,
+123
View File
@@ -604,6 +604,129 @@ async def test_local_mcp_server_prompt_execution():
assert result[0].contents[0].text == "Test message"
@pytest.mark.parametrize(
"approval_mode,expected_approvals",
[
("always_require", {"tool_one": "always_require", "tool_two": "always_require"}),
("never_require", {"tool_one": "never_require", "tool_two": "never_require"}),
(
{"always_require_approval": ["tool_one"], "never_require_approval": ["tool_two"]},
{"tool_one": "always_require", "tool_two": "never_require"},
),
],
)
async def test_mcp_tool_approval_mode(approval_mode, expected_approvals):
"""Test MCPTool approval_mode parameter with various configurations.
The approval_mode parameter controls whether tools require approval before execution.
It can be set globally ("always_require" or "never_require") or per-tool using a dict.
"""
class TestServer(MCPTool):
async def connect(self):
self.session = Mock(spec=ClientSession)
self.session.list_tools = AsyncMock(
return_value=types.ListToolsResult(
tools=[
types.Tool(
name="tool_one",
description="First tool",
inputSchema={
"type": "object",
"properties": {"param": {"type": "string"}},
},
),
types.Tool(
name="tool_two",
description="Second tool",
inputSchema={
"type": "object",
"properties": {"param": {"type": "string"}},
},
),
]
)
)
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
return None
server = TestServer(name="test_server", approval_mode=approval_mode)
async with server:
await server.load_tools()
assert len(server.functions) == 2
# Verify each tool has the expected approval mode
for func in server.functions:
assert func.approval_mode == expected_approvals[func.name]
@pytest.mark.parametrize(
"allowed_tools,expected_count,expected_names",
[
(None, 3, ["tool_one", "tool_two", "tool_three"]), # None means all tools are allowed
(["tool_one"], 1, ["tool_one"]), # Only tool_one is allowed
(["tool_one", "tool_three"], 2, ["tool_one", "tool_three"]), # Two tools allowed
(["nonexistent_tool"], 0, []), # No matching tools
],
)
async def test_mcp_tool_allowed_tools(allowed_tools, expected_count, expected_names):
"""Test MCPTool allowed_tools parameter with various configurations.
The allowed_tools parameter filters which tools are exposed via the functions property.
When None, all loaded tools are available. When set to a list, only tools whose names
are in that list are exposed.
"""
class TestServer(MCPTool):
async def connect(self):
self.session = Mock(spec=ClientSession)
self.session.list_tools = AsyncMock(
return_value=types.ListToolsResult(
tools=[
types.Tool(
name="tool_one",
description="First tool",
inputSchema={
"type": "object",
"properties": {"param": {"type": "string"}},
},
),
types.Tool(
name="tool_two",
description="Second tool",
inputSchema={
"type": "object",
"properties": {"param": {"type": "string"}},
},
),
types.Tool(
name="tool_three",
description="Third tool",
inputSchema={
"type": "object",
"properties": {"param": {"type": "string"}},
},
),
]
)
)
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
return None
server = TestServer(name="test_server", allowed_tools=allowed_tools)
async with server:
await server.load_tools()
# _functions should contain all tools
assert len(server._functions) == 3
# functions property should filter based on allowed_tools
assert len(server.functions) == expected_count
actual_names = [func.name for func in server.functions]
assert sorted(actual_names) == sorted(expected_names)
# Server implementation tests
def test_local_mcp_stdio_tool_init():
"""Test MCPStdioTool initialization."""