mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: added approval_mode and allowed_tools to local MCP (#1203)
* added approval_mode and allowed_tools to local MCP * updated docs
This commit is contained in:
committed by
GitHub
Unverified
parent
7ebe00ec3d
commit
b49395fa3d
@@ -5,10 +5,11 @@ import logging
|
||||
import re
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Collection
|
||||
from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore
|
||||
from datetime import timedelta
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from mcp import types
|
||||
from mcp.client.session import ClientSession
|
||||
@@ -20,7 +21,7 @@ from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.session import RequestResponder
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
from ._tools import AIFunction
|
||||
from ._tools import AIFunction, HostedMCPSpecificApproval
|
||||
from ._types import ChatMessage, Contents, DataContent, Role, TextContent, UriContent
|
||||
from .exceptions import ToolException, ToolExecutionException
|
||||
|
||||
@@ -264,27 +265,25 @@ class MCPTool:
|
||||
self,
|
||||
name: str,
|
||||
description: str | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
approval_mode: Literal["always_require", "never_require"] | HostedMCPSpecificApproval | None = None,
|
||||
allowed_tools: Collection[str] | None = None,
|
||||
load_tools: bool = True,
|
||||
load_prompts: bool = True,
|
||||
session: ClientSession | None = None,
|
||||
request_timeout: int | None = None,
|
||||
chat_client: "ChatClientProtocol | None" = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the MCP Tool base.
|
||||
|
||||
Args:
|
||||
name: The name of the MCP tool.
|
||||
description: The description of the tool.
|
||||
additional_properties: Additional properties for the tool.
|
||||
load_tools: Whether to automatically load tools from the MCP server.
|
||||
load_prompts: Whether to automatically load prompts from the MCP server.
|
||||
session: Pre-existing session to use for the MCP connection.
|
||||
request_timeout: The default timeout in seconds for all requests.
|
||||
chat_client: The chat client to use for sampling callbacks.
|
||||
Note:
|
||||
Do not use this method, use one of the subclasses: MCPStreamableHTTPTool, MCPWebsocketTool
|
||||
or MCPStdioTool.
|
||||
"""
|
||||
self.name = name
|
||||
self.description = description or ""
|
||||
self.approval_mode = approval_mode
|
||||
self.allowed_tools = allowed_tools
|
||||
self.additional_properties = additional_properties
|
||||
self.load_tools_flag = load_tools
|
||||
self.load_prompts_flag = load_prompts
|
||||
@@ -292,12 +291,19 @@ class MCPTool:
|
||||
self.session = session
|
||||
self.request_timeout = request_timeout
|
||||
self.chat_client = chat_client
|
||||
self.functions: list[AIFunction[Any, Any]] = []
|
||||
self._functions: list[AIFunction[Any, Any]] = []
|
||||
self.is_connected: bool = False
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"MCPTool(name={self.name}, description={self.description})"
|
||||
|
||||
@property
|
||||
def functions(self) -> list[AIFunction[Any, Any]]:
|
||||
"""Get the list of functions that are allowed."""
|
||||
if not self.allowed_tools:
|
||||
return self._functions
|
||||
return [func for func in self._functions if func.name in self.allowed_tools]
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Connect to the MCP server.
|
||||
|
||||
@@ -458,6 +464,18 @@ class MCPTool:
|
||||
case _:
|
||||
logger.debug("Unhandled notification: %s", message.root.method)
|
||||
|
||||
def _determine_approval_mode(
|
||||
self,
|
||||
local_name: str,
|
||||
) -> Literal["always_require", "never_require"] | None:
|
||||
if isinstance(self.approval_mode, dict):
|
||||
if (always_require := self.approval_mode.get("always_require_approval")) and local_name in always_require:
|
||||
return "always_require"
|
||||
if (never_require := self.approval_mode.get("never_require_approval")) and local_name in never_require:
|
||||
return "never_require"
|
||||
return None
|
||||
return self.approval_mode # type: ignore[reportReturnType]
|
||||
|
||||
async def load_prompts(self) -> None:
|
||||
"""Load prompts from the MCP server.
|
||||
|
||||
@@ -480,13 +498,15 @@ class MCPTool:
|
||||
for prompt in prompt_list.prompts if prompt_list else []:
|
||||
local_name = _normalize_mcp_name(prompt.name)
|
||||
input_model = _get_input_model_from_mcp_prompt(prompt)
|
||||
approval_mode = self._determine_approval_mode(local_name)
|
||||
func: AIFunction[BaseModel, list[ChatMessage]] = AIFunction(
|
||||
func=partial(self.get_prompt, prompt.name),
|
||||
name=local_name,
|
||||
description=prompt.description or "",
|
||||
approval_mode=approval_mode,
|
||||
input_model=input_model,
|
||||
)
|
||||
self.functions.append(func)
|
||||
self._functions.append(func)
|
||||
|
||||
async def load_tools(self) -> None:
|
||||
"""Load tools from the MCP server.
|
||||
@@ -510,14 +530,16 @@ class MCPTool:
|
||||
for tool in tool_list.tools if tool_list else []:
|
||||
local_name = _normalize_mcp_name(tool.name)
|
||||
input_model = _get_input_model_from_mcp_tool(tool)
|
||||
approval_mode = self._determine_approval_mode(local_name)
|
||||
# Create AIFunctions out of each tool
|
||||
func: AIFunction[BaseModel, list[Contents]] = AIFunction(
|
||||
func=partial(self.call_tool, tool.name),
|
||||
name=local_name,
|
||||
description=tool.description or "",
|
||||
approval_mode=approval_mode,
|
||||
input_model=input_model,
|
||||
)
|
||||
self.functions.append(func)
|
||||
self._functions.append(func)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Disconnect from the MCP server.
|
||||
@@ -670,11 +692,13 @@ class MCPStdioTool(MCPTool):
|
||||
request_timeout: int | None = None,
|
||||
session: ClientSession | None = None,
|
||||
description: str | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
approval_mode: Literal["always_require", "never_require"] | HostedMCPSpecificApproval | None = None,
|
||||
allowed_tools: Collection[str] | None = None,
|
||||
args: list[str] | None = None,
|
||||
env: dict[str, str] | None = None,
|
||||
encoding: str | None = None,
|
||||
chat_client: "ChatClientProtocol | None" = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the MCP stdio tool.
|
||||
@@ -694,6 +718,13 @@ class MCPStdioTool(MCPTool):
|
||||
request_timeout: The default timeout in seconds for all requests.
|
||||
session: The session to use for the MCP connection.
|
||||
description: The description of the tool.
|
||||
approval_mode: The approval mode for the tool. This can be:
|
||||
- "always_require": The tool always requires approval before use.
|
||||
- "never_require": The tool never requires approval before use.
|
||||
- A dict with keys `always_require_approval` or `never_require_approval`,
|
||||
followed by a sequence of strings with the names of the relevant tools.
|
||||
A tool should not be listed in both, if so, it will require approval.
|
||||
allowed_tools: A list of tools that are allowed to use this tool.
|
||||
additional_properties: Additional properties.
|
||||
args: The arguments to pass to the command.
|
||||
env: The environment variables to set for the command.
|
||||
@@ -704,6 +735,8 @@ class MCPStdioTool(MCPTool):
|
||||
super().__init__(
|
||||
name=name,
|
||||
description=description,
|
||||
approval_mode=approval_mode,
|
||||
allowed_tools=allowed_tools,
|
||||
additional_properties=additional_properties,
|
||||
session=session,
|
||||
chat_client=chat_client,
|
||||
@@ -769,12 +802,14 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
request_timeout: int | None = None,
|
||||
session: ClientSession | None = None,
|
||||
description: str | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
approval_mode: Literal["always_require", "never_require"] | HostedMCPSpecificApproval | None = None,
|
||||
allowed_tools: Collection[str] | None = None,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
terminate_on_close: bool | None = None,
|
||||
chat_client: "ChatClientProtocol | None" = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the MCP streamable HTTP tool.
|
||||
@@ -795,6 +830,13 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
request_timeout: The default timeout in seconds for all requests.
|
||||
session: The session to use for the MCP connection.
|
||||
description: The description of the tool.
|
||||
approval_mode: The approval mode for the tool. This can be:
|
||||
- "always_require": The tool always requires approval before use.
|
||||
- "never_require": The tool never requires approval before use.
|
||||
- A dict with keys `always_require_approval` or `never_require_approval`,
|
||||
followed by a sequence of strings with the names of the relevant tools.
|
||||
A tool should not be listed in both, if so, it will require approval.
|
||||
allowed_tools: A list of tools that are allowed to use this tool.
|
||||
additional_properties: Additional properties.
|
||||
headers: The headers to send with the request.
|
||||
timeout: The timeout for the request.
|
||||
@@ -806,6 +848,8 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
super().__init__(
|
||||
name=name,
|
||||
description=description,
|
||||
approval_mode=approval_mode,
|
||||
allowed_tools=allowed_tools,
|
||||
additional_properties=additional_properties,
|
||||
session=session,
|
||||
chat_client=chat_client,
|
||||
@@ -873,8 +917,10 @@ class MCPWebsocketTool(MCPTool):
|
||||
request_timeout: int | None = None,
|
||||
session: ClientSession | None = None,
|
||||
description: str | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
approval_mode: Literal["always_require", "never_require"] | HostedMCPSpecificApproval | None = None,
|
||||
allowed_tools: Collection[str] | None = None,
|
||||
chat_client: "ChatClientProtocol | None" = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the MCP WebSocket tool.
|
||||
@@ -895,6 +941,13 @@ class MCPWebsocketTool(MCPTool):
|
||||
request_timeout: The default timeout in seconds for all requests.
|
||||
session: The session to use for the MCP connection.
|
||||
description: The description of the tool.
|
||||
approval_mode: The approval mode for the tool. This can be:
|
||||
- "always_require": The tool always requires approval before use.
|
||||
- "never_require": The tool never requires approval before use.
|
||||
- A dict with keys `always_require_approval` or `never_require_approval`,
|
||||
followed by a sequence of strings with the names of the relevant tools.
|
||||
A tool should not be listed in both, if so, it will require approval.
|
||||
allowed_tools: A list of tools that are allowed to use this tool.
|
||||
additional_properties: Additional properties.
|
||||
chat_client: The chat client to use for sampling.
|
||||
kwargs: Any extra arguments to pass to the WebSocket client.
|
||||
@@ -902,6 +955,8 @@ class MCPWebsocketTool(MCPTool):
|
||||
super().__init__(
|
||||
name=name,
|
||||
description=description,
|
||||
approval_mode=approval_mode,
|
||||
allowed_tools=allowed_tools,
|
||||
additional_properties=additional_properties,
|
||||
session=session,
|
||||
chat_client=chat_client,
|
||||
|
||||
@@ -604,6 +604,129 @@ async def test_local_mcp_server_prompt_execution():
|
||||
assert result[0].contents[0].text == "Test message"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"approval_mode,expected_approvals",
|
||||
[
|
||||
("always_require", {"tool_one": "always_require", "tool_two": "always_require"}),
|
||||
("never_require", {"tool_one": "never_require", "tool_two": "never_require"}),
|
||||
(
|
||||
{"always_require_approval": ["tool_one"], "never_require_approval": ["tool_two"]},
|
||||
{"tool_one": "always_require", "tool_two": "never_require"},
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_mcp_tool_approval_mode(approval_mode, expected_approvals):
|
||||
"""Test MCPTool approval_mode parameter with various configurations.
|
||||
|
||||
The approval_mode parameter controls whether tools require approval before execution.
|
||||
It can be set globally ("always_require" or "never_require") or per-tool using a dict.
|
||||
"""
|
||||
|
||||
class TestServer(MCPTool):
|
||||
async def connect(self):
|
||||
self.session = Mock(spec=ClientSession)
|
||||
self.session.list_tools = AsyncMock(
|
||||
return_value=types.ListToolsResult(
|
||||
tools=[
|
||||
types.Tool(
|
||||
name="tool_one",
|
||||
description="First tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {"param": {"type": "string"}},
|
||||
},
|
||||
),
|
||||
types.Tool(
|
||||
name="tool_two",
|
||||
description="Second tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {"param": {"type": "string"}},
|
||||
},
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
|
||||
return None
|
||||
|
||||
server = TestServer(name="test_server", approval_mode=approval_mode)
|
||||
async with server:
|
||||
await server.load_tools()
|
||||
assert len(server.functions) == 2
|
||||
|
||||
# Verify each tool has the expected approval mode
|
||||
for func in server.functions:
|
||||
assert func.approval_mode == expected_approvals[func.name]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"allowed_tools,expected_count,expected_names",
|
||||
[
|
||||
(None, 3, ["tool_one", "tool_two", "tool_three"]), # None means all tools are allowed
|
||||
(["tool_one"], 1, ["tool_one"]), # Only tool_one is allowed
|
||||
(["tool_one", "tool_three"], 2, ["tool_one", "tool_three"]), # Two tools allowed
|
||||
(["nonexistent_tool"], 0, []), # No matching tools
|
||||
],
|
||||
)
|
||||
async def test_mcp_tool_allowed_tools(allowed_tools, expected_count, expected_names):
|
||||
"""Test MCPTool allowed_tools parameter with various configurations.
|
||||
|
||||
The allowed_tools parameter filters which tools are exposed via the functions property.
|
||||
When None, all loaded tools are available. When set to a list, only tools whose names
|
||||
are in that list are exposed.
|
||||
"""
|
||||
|
||||
class TestServer(MCPTool):
|
||||
async def connect(self):
|
||||
self.session = Mock(spec=ClientSession)
|
||||
self.session.list_tools = AsyncMock(
|
||||
return_value=types.ListToolsResult(
|
||||
tools=[
|
||||
types.Tool(
|
||||
name="tool_one",
|
||||
description="First tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {"param": {"type": "string"}},
|
||||
},
|
||||
),
|
||||
types.Tool(
|
||||
name="tool_two",
|
||||
description="Second tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {"param": {"type": "string"}},
|
||||
},
|
||||
),
|
||||
types.Tool(
|
||||
name="tool_three",
|
||||
description="Third tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {"param": {"type": "string"}},
|
||||
},
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
|
||||
return None
|
||||
|
||||
server = TestServer(name="test_server", allowed_tools=allowed_tools)
|
||||
async with server:
|
||||
await server.load_tools()
|
||||
# _functions should contain all tools
|
||||
assert len(server._functions) == 3
|
||||
|
||||
# functions property should filter based on allowed_tools
|
||||
assert len(server.functions) == expected_count
|
||||
actual_names = [func.name for func in server.functions]
|
||||
assert sorted(actual_names) == sorted(expected_names)
|
||||
|
||||
|
||||
# Server implementation tests
|
||||
def test_local_mcp_stdio_tool_init():
|
||||
"""Test MCPStdioTool initialization."""
|
||||
|
||||
Reference in New Issue
Block a user