mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: [BREAKING] Add sampling guardrails to MCP tools (#6413)
* Add sampling guardrails to MCP tools Add approval, token, and request-count controls to the MCP sampling callback used when an MCPTool is configured with a chat client. - Add `sampling_approval_callback`, `sampling_max_tokens`, and `sampling_max_requests` parameters to `MCPTool` and its `MCPStdioTool`, `MCPStreamableHTTPTool`, and `MCPWebsocketTool` subclasses, positioned directly after `client`. - Gate each server-initiated `sampling/createMessage` request behind the approval callback, which denies by default when no callback is provided. - Clamp the requested `maxTokens` to `sampling_max_tokens` and enforce a per-session request count via `sampling_max_requests`. - Log incoming sampling requests at WARNING level (counts only). - Export `SamplingApprovalCallback` from the public API. - Add tests, a sample, and documentation updates. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Make sampling denial message context-aware Distinguish the deny-by-default case (no approval callback configured) from an explicit denial by a configured `sampling_approval_callback`, so the returned ErrorData message is accurate for callback-driven denials and exceptions. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
cea83bd8d5
commit
9a56bc9f16
@@ -82,6 +82,7 @@ agent_framework/
|
||||
- **`MCPStdioTool`** / **`MCPStreamableHTTPTool`** / **`MCPWebsocketTool`** - Transport-specific subclasses.
|
||||
- **Argument allowlist (`_prepare_call_kwargs`)** - Before each `tools/call`, kwargs are filtered to an **allowlist** built from the tool's declared parameters (`inputSchema.properties`) plus any user-configured extras. Framework runtime kwargs injected through the function-invocation pipeline (e.g. `thread`, `conversation_id`, `chat_options`, `options`, `response_format`) are stripped by default rather than forwarded. A tool that declares no usable `properties` (including schemas with `additionalProperties: true`) forwards only the configured extras. The `_MCP_FRAMEWORK_DENYLIST` is a safety net for framework-named params a server *declares* in its schema (those are dropped); names explicitly opted in via `additional_tool_argument_names` always win. The reserved `_meta` key is extracted as MCP request metadata, never forwarded as an argument.
|
||||
- **`additional_tool_argument_names`** (constructor arg on all `MCPTool` subclasses) - Opt extra argument names back into the allowlist. Accepts a `Sequence[str]` (applied to every tool) or a `Mapping[str, Sequence[str]]` keyed by **remote tool name**, where the reserved key `"*"` denotes global extras. It is configured only in user code at construction; there is **no per-call/runtime override**, so a model-issued tool call cannot change which names pass through. To use a server that accepts `additionalProperties: true`, list the extra names here and then either (1) manually extend that tool's `inputSchema` (via the `.functions` list after connecting) so the model is prompted to supply them, or (2) supply the values yourself via `function_invocation_kwargs`. If a name is supplied by both the model and `function_invocation_kwargs`, the model-supplied value wins.
|
||||
- **Sampling guardrails** (`sampling_callback`) - Passing `client=` advertises `SamplingCapability` so the server can send `sampling/createMessage`. Because remote servers are untrusted (confused-deputy risk), the default `sampling_callback` is **deny-by-default** and applies, in order: a per-session rate limit (`sampling_max_requests`, default `_DEFAULT_SAMPLING_MAX_REQUESTS`), an approval gate (`sampling_approval_callback`), and a `maxTokens` cap (`sampling_max_tokens`, default `_DEFAULT_SAMPLING_MAX_TOKENS`). The approval callback (constructor arg on all subclasses; exported type alias `SamplingApprovalCallback`) receives the raw `CreateMessageRequestParams`, may be sync or async, and must return truthy to approve. When it is `None` (the default) every sampling request is denied; pass `lambda params: True` to restore legacy auto-approve as an explicit opt-in. Requests and denials are logged at WARNING (content is not logged). The per-session counter resets in `_reset_session_state`.
|
||||
- **`MCPTaskOptions`** (experimental, `MCP_LONG_RUNNING_TASKS` feature, **frozen**) - Per-tool-instance options controlling the SEP-2663 long-running task lifecycle. When the server advertises a tool with `execution.taskSupport == "required"`, `MCPTool.call_tool` transparently routes through `call_tool_as_task`, which sends an augmented `tools/call`, polls `tasks/get` until terminal, and reinterprets `tasks/result` as a normal `CallToolResult`. Instances are immutable; replace via `MCPTool.task_options = MCPTaskOptions(...)`. Fields:
|
||||
- `default_ttl: timedelta | None` — forwarded to the server as `params.task.ttl` (milliseconds). When `None`, the server's default applies.
|
||||
- `cancel_remote_task_on_local_cancellation: bool = True` — only gates the `CancelledError` path. Abandonment paths (see below) always cancel.
|
||||
|
||||
@@ -124,7 +124,7 @@ from ._harness._todo import (
|
||||
TodoSessionStore,
|
||||
TodoStore,
|
||||
)
|
||||
from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPTaskOptions, MCPWebsocketTool
|
||||
from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPTaskOptions, MCPWebsocketTool, SamplingApprovalCallback
|
||||
from ._middleware import (
|
||||
AgentContext,
|
||||
AgentMiddleware,
|
||||
@@ -472,6 +472,7 @@ __all__ = [
|
||||
"RunContext",
|
||||
"Runner",
|
||||
"RunnerContext",
|
||||
"SamplingApprovalCallback",
|
||||
"SecretString",
|
||||
"SelectiveToolCallCompactionStrategy",
|
||||
"SessionContext",
|
||||
|
||||
@@ -16,6 +16,7 @@ from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ig
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from functools import partial
|
||||
from inspect import isawaitable
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast
|
||||
|
||||
from opentelemetry import propagate
|
||||
@@ -99,6 +100,22 @@ _mcp_call_headers: contextvars.ContextVar[dict[str, str]] = contextvars.ContextV
|
||||
MCP_DEFAULT_TIMEOUT = 30
|
||||
MCP_DEFAULT_SSE_READ_TIMEOUT = 60 * 5
|
||||
|
||||
# Default safety limits applied to server-initiated MCP sampling requests
|
||||
# (``sampling/createMessage``). MCP servers are untrusted third parties, so the
|
||||
# default ``sampling_callback`` denies requests unless an approval callback is
|
||||
# supplied, and bounds the cost of any approved request.
|
||||
# - ``_DEFAULT_SAMPLING_MAX_TOKENS`` clamps the server-requested ``maxTokens``.
|
||||
# - ``_DEFAULT_SAMPLING_MAX_REQUESTS`` caps the number of sampling requests per
|
||||
# session connection (the counter resets on reconnect).
|
||||
_DEFAULT_SAMPLING_MAX_TOKENS = 4096
|
||||
_DEFAULT_SAMPLING_MAX_REQUESTS = 25
|
||||
|
||||
# A user-supplied gate invoked before each server-initiated sampling request is
|
||||
# forwarded to the chat client. It receives the raw ``CreateMessageRequestParams``
|
||||
# and returns (or awaits to) a truthy value to approve the request or a falsy
|
||||
# value to deny it. Both synchronous and asynchronous callables are supported.
|
||||
SamplingApprovalCallback = Callable[["types.CreateMessageRequestParams"], "bool | Coroutine[Any, Any, bool]"]
|
||||
|
||||
# region: Helpers
|
||||
|
||||
LOG_LEVEL_MAPPING: dict[str, int] = {
|
||||
@@ -345,6 +362,9 @@ class MCPTool:
|
||||
session: ClientSession | None = None,
|
||||
request_timeout: int | None = None,
|
||||
client: SupportsChatGetResponse | None = None,
|
||||
sampling_approval_callback: SamplingApprovalCallback | None = None,
|
||||
sampling_max_tokens: int | None = _DEFAULT_SAMPLING_MAX_TOKENS,
|
||||
sampling_max_requests: int | None = _DEFAULT_SAMPLING_MAX_REQUESTS,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
task_options: MCPTaskOptions | None = None,
|
||||
additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None,
|
||||
@@ -378,6 +398,20 @@ class MCPTool:
|
||||
session: An existing MCP client session to use.
|
||||
request_timeout: Timeout in seconds for MCP requests.
|
||||
client: A chat client for sampling callbacks.
|
||||
sampling_approval_callback: Optional gate invoked before each server-initiated
|
||||
``sampling/createMessage`` request is forwarded to ``client``. It receives the
|
||||
raw ``CreateMessageRequestParams`` and may be synchronous or asynchronous;
|
||||
returning a truthy value approves the request and a falsy value denies it. When
|
||||
``None`` (the default), every sampling request is **denied** because MCP servers
|
||||
are untrusted third parties (confused-deputy risk). To restore the legacy
|
||||
auto-approve behavior, pass ``lambda params: True`` as an explicit, conscious
|
||||
opt-in.
|
||||
sampling_max_tokens: Upper bound applied to the server-requested ``maxTokens`` for an
|
||||
approved sampling request. The effective value is ``min(requested, cap)``. Set to
|
||||
``None`` to disable the cap. Defaults to ``_DEFAULT_SAMPLING_MAX_TOKENS``.
|
||||
sampling_max_requests: Maximum number of sampling requests allowed per session
|
||||
connection; further requests are rejected. The counter resets on reconnect. Set
|
||||
to ``None`` to disable the limit. Defaults to ``_DEFAULT_SAMPLING_MAX_REQUESTS``.
|
||||
additional_properties: Additional properties for the tool.
|
||||
task_options: Options controlling how long-running MCP tasks are driven for
|
||||
tools that advertise ``execution.taskSupport == "required"``. When ``None``,
|
||||
@@ -410,6 +444,10 @@ class MCPTool:
|
||||
self.session = session
|
||||
self.request_timeout = request_timeout
|
||||
self.client = client
|
||||
self.sampling_approval_callback = sampling_approval_callback
|
||||
self.sampling_max_tokens = sampling_max_tokens
|
||||
self.sampling_max_requests = sampling_max_requests
|
||||
self._sampling_request_count = 0
|
||||
self._functions: list[FunctionTool] = []
|
||||
self._tool_call_meta_by_name: dict[str, dict[str, Any]] = {}
|
||||
self._tool_task_support_by_name: dict[str, str] = {}
|
||||
@@ -840,6 +878,7 @@ class MCPTool:
|
||||
self._supports_prompts = True
|
||||
self._supports_logging = None
|
||||
self._ping_available = True
|
||||
self._sampling_request_count = 0
|
||||
|
||||
def _set_server_capabilities(self, capabilities: types.ServerCapabilities | None) -> None:
|
||||
self._server_capabilities = capabilities
|
||||
@@ -994,6 +1033,49 @@ class MCPTool:
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to set log level to %s", logger.level, exc_info=exc)
|
||||
|
||||
async def _sampling_request_approved(self, params: types.CreateMessageRequestParams) -> bool:
|
||||
"""Run the configured sampling approval gate.
|
||||
|
||||
Returns ``True`` only when an approval callback is configured and approves the request.
|
||||
When no callback is set, the request is denied (safe default for untrusted servers).
|
||||
"""
|
||||
callback = self.sampling_approval_callback
|
||||
if callback is None:
|
||||
logger.warning(
|
||||
"Denying MCP sampling request from '%s': no 'sampling_approval_callback' configured.",
|
||||
self.name,
|
||||
)
|
||||
return False
|
||||
try:
|
||||
outcome = callback(params)
|
||||
if isawaitable(outcome):
|
||||
outcome = await outcome
|
||||
except Exception as ex:
|
||||
logger.warning(
|
||||
"Denying MCP sampling request from '%s': approval callback raised %s.",
|
||||
self.name,
|
||||
ex,
|
||||
exc_info=True,
|
||||
)
|
||||
return False
|
||||
approved = bool(outcome)
|
||||
if not approved:
|
||||
logger.warning("MCP sampling request from '%s' was denied by the approval callback.", self.name)
|
||||
return approved
|
||||
|
||||
def _capped_sampling_max_tokens(self, requested: int) -> int:
|
||||
"""Clamp the server-requested ``maxTokens`` to ``sampling_max_tokens`` when configured."""
|
||||
cap = self.sampling_max_tokens
|
||||
if cap is not None and requested > cap:
|
||||
logger.warning(
|
||||
"Capping MCP sampling maxTokens for '%s' from %d to %d.",
|
||||
self.name,
|
||||
requested,
|
||||
cap,
|
||||
)
|
||||
return cap
|
||||
return requested
|
||||
|
||||
async def sampling_callback(
|
||||
self,
|
||||
context: RequestContext[ClientSession, Any],
|
||||
@@ -1001,20 +1083,32 @@ class MCPTool:
|
||||
) -> types.CreateMessageResult | types.ErrorData:
|
||||
"""Callback function for sampling.
|
||||
|
||||
This function is called when the MCP server needs to get a message completed.
|
||||
It uses the configured chat client to generate responses.
|
||||
This function is called when the MCP server sends a ``sampling/createMessage``
|
||||
request. It enforces safety guardrails and, if the request is approved, uses the
|
||||
configured chat client to generate a response.
|
||||
|
||||
Safety:
|
||||
MCP servers are untrusted third parties, so forwarding server-controlled prompts
|
||||
to the chat client without review is a confused-deputy risk. This callback
|
||||
therefore applies, in order: a per-session rate limit
|
||||
(``sampling_max_requests``), an approval gate (``sampling_approval_callback``,
|
||||
which **denies by default** when not configured), and a ``maxTokens`` cap
|
||||
(``sampling_max_tokens``). To allow sampling, pass a ``sampling_approval_callback``
|
||||
that returns a truthy value (use ``lambda params: True`` to auto-approve as an
|
||||
explicit opt-in).
|
||||
|
||||
Note:
|
||||
This is a simple version of this function. It can be overridden to allow
|
||||
more complex sampling. It gets added to the session at initialization time,
|
||||
so overriding it is the best way to customize this behavior.
|
||||
This is the default implementation. It can be overridden to allow more complex
|
||||
sampling. It gets added to the session at initialization time, so overriding it is
|
||||
the best way to customize this behavior.
|
||||
|
||||
Args:
|
||||
context: The request context from the MCP server.
|
||||
params: The message creation request parameters.
|
||||
|
||||
Returns:
|
||||
Either a CreateMessageResult with the generated message or ErrorData if generation fails.
|
||||
Either a CreateMessageResult with the generated message or ErrorData if the request
|
||||
is denied, rate limited, or generation fails.
|
||||
"""
|
||||
from mcp import types
|
||||
|
||||
@@ -1023,7 +1117,38 @@ class MCPTool:
|
||||
code=types.INTERNAL_ERROR,
|
||||
message="No chat client available. Please set a chat client.",
|
||||
)
|
||||
logger.debug("Sampling callback called with params: %s", params)
|
||||
|
||||
logger.warning(
|
||||
"MCP server '%s' sent a sampling/createMessage request (%d message(s), maxTokens=%s).",
|
||||
self.name,
|
||||
len(params.messages),
|
||||
params.maxTokens,
|
||||
)
|
||||
|
||||
if self.sampling_max_requests is not None:
|
||||
if self._sampling_request_count >= self.sampling_max_requests:
|
||||
logger.warning(
|
||||
"Denying MCP sampling request from '%s': per-session limit of %d reached.",
|
||||
self.name,
|
||||
self.sampling_max_requests,
|
||||
)
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="Sampling rate limit exceeded for this MCP session.",
|
||||
)
|
||||
self._sampling_request_count += 1
|
||||
|
||||
if not await self._sampling_request_approved(params):
|
||||
if self.sampling_approval_callback is None:
|
||||
message = (
|
||||
"Sampling request denied. MCP sampling is disabled by default for untrusted "
|
||||
"servers; provide a 'sampling_approval_callback' that approves the request to "
|
||||
"enable it."
|
||||
)
|
||||
else:
|
||||
message = "Sampling request denied by the 'sampling_approval_callback'."
|
||||
return types.ErrorData(code=types.INVALID_REQUEST, message=message)
|
||||
|
||||
messages: list[Message] = []
|
||||
for msg in params.messages:
|
||||
messages.append(self._parse_message_from_mcp(msg))
|
||||
@@ -1045,7 +1170,7 @@ class MCPTool:
|
||||
|
||||
if params.temperature is not None:
|
||||
options["temperature"] = params.temperature
|
||||
options["max_tokens"] = params.maxTokens
|
||||
options["max_tokens"] = self._capped_sampling_max_tokens(params.maxTokens)
|
||||
if params.stopSequences is not None:
|
||||
options["stop"] = params.stopSequences
|
||||
|
||||
@@ -2219,6 +2344,9 @@ class MCPStdioTool(MCPTool):
|
||||
env: dict[str, str] | None = None,
|
||||
encoding: str | None = None,
|
||||
client: SupportsChatGetResponse | None = None,
|
||||
sampling_approval_callback: SamplingApprovalCallback | None = None,
|
||||
sampling_max_tokens: int | None = _DEFAULT_SAMPLING_MAX_TOKENS,
|
||||
sampling_max_requests: int | None = _DEFAULT_SAMPLING_MAX_REQUESTS,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
task_options: MCPTaskOptions | None = None,
|
||||
additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None,
|
||||
@@ -2266,6 +2394,16 @@ class MCPStdioTool(MCPTool):
|
||||
env: The environment variables to set for the command.
|
||||
encoding: The encoding to use for the command output.
|
||||
client: The chat client to use for sampling.
|
||||
sampling_approval_callback: Optional gate run before each server-initiated
|
||||
``sampling/createMessage`` request reaches ``client``. Receives the raw
|
||||
``CreateMessageRequestParams`` (sync or async); a truthy return approves the
|
||||
request, a falsy return denies it. When ``None`` (the default) every sampling
|
||||
request is **denied**, since MCP servers are untrusted (confused-deputy risk).
|
||||
Pass ``lambda params: True`` to auto-approve as an explicit opt-in.
|
||||
sampling_max_tokens: Cap applied to an approved request's ``maxTokens``
|
||||
(``min(requested, cap)``); ``None`` disables it.
|
||||
sampling_max_requests: Per-session cap on the number of sampling requests; further
|
||||
requests are rejected. Resets on reconnect. ``None`` disables it.
|
||||
task_options: Options for tools that advertise
|
||||
``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`.
|
||||
additional_tool_argument_names: Extra argument names to forward to the MCP server in
|
||||
@@ -2300,6 +2438,9 @@ class MCPStdioTool(MCPTool):
|
||||
request_timeout=request_timeout,
|
||||
task_options=task_options,
|
||||
additional_tool_argument_names=additional_tool_argument_names,
|
||||
sampling_approval_callback=sampling_approval_callback,
|
||||
sampling_max_tokens=sampling_max_tokens,
|
||||
sampling_max_requests=sampling_max_requests,
|
||||
)
|
||||
self.command = command
|
||||
self.args = args or []
|
||||
@@ -2375,6 +2516,9 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
allowed_tools: Collection[str] | None = None,
|
||||
terminate_on_close: bool | None = None,
|
||||
client: SupportsChatGetResponse | None = None,
|
||||
sampling_approval_callback: SamplingApprovalCallback | None = None,
|
||||
sampling_max_tokens: int | None = _DEFAULT_SAMPLING_MAX_TOKENS,
|
||||
sampling_max_requests: int | None = _DEFAULT_SAMPLING_MAX_REQUESTS,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
http_client: AsyncClient | None = None,
|
||||
header_provider: Callable[[dict[str, Any]], dict[str, str]] | None = None,
|
||||
@@ -2423,6 +2567,16 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
additional_properties: Additional properties.
|
||||
terminate_on_close: Close the transport when the MCP client is terminated.
|
||||
client: The chat client to use for sampling.
|
||||
sampling_approval_callback: Optional gate run before each server-initiated
|
||||
``sampling/createMessage`` request reaches ``client``. Receives the raw
|
||||
``CreateMessageRequestParams`` (sync or async); a truthy return approves the
|
||||
request, a falsy return denies it. When ``None`` (the default) every sampling
|
||||
request is **denied**, since MCP servers are untrusted (confused-deputy risk).
|
||||
Pass ``lambda params: True`` to auto-approve as an explicit opt-in.
|
||||
sampling_max_tokens: Cap applied to an approved request's ``maxTokens``
|
||||
(``min(requested, cap)``); ``None`` disables it.
|
||||
sampling_max_requests: Per-session cap on the number of sampling requests; further
|
||||
requests are rejected. Resets on reconnect. ``None`` disables it.
|
||||
http_client: Optional asyncClient to use. If not provided, the
|
||||
``streamable_http_client`` API will create and manage a default client.
|
||||
To configure headers, timeouts, or other HTTP client settings, create
|
||||
@@ -2466,6 +2620,9 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
request_timeout=request_timeout,
|
||||
task_options=task_options,
|
||||
additional_tool_argument_names=additional_tool_argument_names,
|
||||
sampling_approval_callback=sampling_approval_callback,
|
||||
sampling_max_tokens=sampling_max_tokens,
|
||||
sampling_max_requests=sampling_max_requests,
|
||||
)
|
||||
self.url = url
|
||||
self.terminate_on_close = terminate_on_close
|
||||
@@ -2590,6 +2747,9 @@ class MCPWebsocketTool(MCPTool):
|
||||
approval_mode: (Literal["always_require", "never_require"] | MCPSpecificApproval | None) = None,
|
||||
allowed_tools: Collection[str] | None = None,
|
||||
client: SupportsChatGetResponse | None = None,
|
||||
sampling_approval_callback: SamplingApprovalCallback | None = None,
|
||||
sampling_max_tokens: int | None = _DEFAULT_SAMPLING_MAX_TOKENS,
|
||||
sampling_max_requests: int | None = _DEFAULT_SAMPLING_MAX_REQUESTS,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
task_options: MCPTaskOptions | None = None,
|
||||
additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None,
|
||||
@@ -2635,6 +2795,16 @@ class MCPWebsocketTool(MCPTool):
|
||||
allowed_tools: A list of tools that are allowed to use this tool.
|
||||
additional_properties: Additional properties.
|
||||
client: The chat client to use for sampling.
|
||||
sampling_approval_callback: Optional gate run before each server-initiated
|
||||
``sampling/createMessage`` request reaches ``client``. Receives the raw
|
||||
``CreateMessageRequestParams`` (sync or async); a truthy return approves the
|
||||
request, a falsy return denies it. When ``None`` (the default) every sampling
|
||||
request is **denied**, since MCP servers are untrusted (confused-deputy risk).
|
||||
Pass ``lambda params: True`` to auto-approve as an explicit opt-in.
|
||||
sampling_max_tokens: Cap applied to an approved request's ``maxTokens``
|
||||
(``min(requested, cap)``); ``None`` disables it.
|
||||
sampling_max_requests: Per-session cap on the number of sampling requests; further
|
||||
requests are rejected. Resets on reconnect. ``None`` disables it.
|
||||
task_options: Options for tools that advertise
|
||||
``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`.
|
||||
additional_tool_argument_names: Extra argument names to forward to the MCP server in
|
||||
@@ -2669,6 +2839,9 @@ class MCPWebsocketTool(MCPTool):
|
||||
request_timeout=request_timeout,
|
||||
task_options=task_options,
|
||||
additional_tool_argument_names=additional_tool_argument_names,
|
||||
sampling_approval_callback=sampling_approval_callback,
|
||||
sampling_max_tokens=sampling_max_tokens,
|
||||
sampling_max_requests=sampling_max_requests,
|
||||
)
|
||||
self.url = url
|
||||
self._client_kwargs = kwargs
|
||||
|
||||
@@ -1813,6 +1813,18 @@ async def test_mcp_tool_message_handler_cancel_and_replace():
|
||||
assert len(tool._pending_reload_tasks) == 0
|
||||
|
||||
|
||||
def _approve(_params: object) -> bool:
|
||||
"""Approving sampling gate used by tests that exercise forwarding behavior."""
|
||||
return True
|
||||
|
||||
|
||||
def _make_sampling_response(text: str = "response", model: str = "test-model") -> Mock:
|
||||
mock_response = Mock()
|
||||
mock_response.messages = [Message(role="assistant", contents=[Content.from_text(text)])]
|
||||
mock_response.model = model
|
||||
return mock_response
|
||||
|
||||
|
||||
async def test_mcp_tool_sampling_callback_no_client():
|
||||
"""Test sampling callback error path when no chat client is available."""
|
||||
tool = MCPStdioTool(name="test_tool", command="python")
|
||||
@@ -1828,9 +1840,190 @@ async def test_mcp_tool_sampling_callback_no_client():
|
||||
assert "No chat client available" in result.message
|
||||
|
||||
|
||||
async def test_mcp_tool_sampling_callback_denies_by_default():
|
||||
"""Sampling is denied when no approval callback is configured (safe default)."""
|
||||
tool = MCPStdioTool(name="test_tool", command="python")
|
||||
mock_chat_client = AsyncMock()
|
||||
tool.client = mock_chat_client
|
||||
|
||||
params = Mock()
|
||||
params.messages = []
|
||||
params.maxTokens = 128
|
||||
|
||||
result = await tool.sampling_callback(Mock(), params)
|
||||
|
||||
assert isinstance(result, types.ErrorData)
|
||||
assert result.code == types.INVALID_REQUEST
|
||||
assert "denied" in result.message
|
||||
assert "sampling_approval_callback" in result.message
|
||||
mock_chat_client.get_response.assert_not_called()
|
||||
|
||||
|
||||
async def test_mcp_tool_sampling_callback_denied_by_callback():
|
||||
"""Sampling is denied when the approval callback returns a falsy value."""
|
||||
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=lambda params: False)
|
||||
mock_chat_client = AsyncMock()
|
||||
tool.client = mock_chat_client
|
||||
|
||||
params = Mock()
|
||||
params.messages = []
|
||||
params.maxTokens = 128
|
||||
|
||||
result = await tool.sampling_callback(Mock(), params)
|
||||
|
||||
assert isinstance(result, types.ErrorData)
|
||||
assert result.code == types.INVALID_REQUEST
|
||||
assert "denied by the 'sampling_approval_callback'" in result.message
|
||||
mock_chat_client.get_response.assert_not_called()
|
||||
|
||||
|
||||
async def test_mcp_tool_sampling_callback_callback_exception_denies():
|
||||
"""An approval callback that raises results in denial, not an LLM call."""
|
||||
|
||||
def boom(_params: object) -> bool:
|
||||
raise RuntimeError("approval error")
|
||||
|
||||
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=boom)
|
||||
mock_chat_client = AsyncMock()
|
||||
tool.client = mock_chat_client
|
||||
|
||||
params = Mock()
|
||||
params.messages = []
|
||||
params.maxTokens = 128
|
||||
|
||||
result = await tool.sampling_callback(Mock(), params)
|
||||
|
||||
assert isinstance(result, types.ErrorData)
|
||||
assert result.code == types.INVALID_REQUEST
|
||||
mock_chat_client.get_response.assert_not_called()
|
||||
|
||||
|
||||
async def test_mcp_tool_sampling_callback_async_approval():
|
||||
"""An async approval callback that approves allows the request through."""
|
||||
|
||||
async def approve(_params: object) -> bool:
|
||||
return True
|
||||
|
||||
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=approve)
|
||||
mock_chat_client = AsyncMock()
|
||||
mock_chat_client.get_response.return_value = _make_sampling_response("ok")
|
||||
tool.client = mock_chat_client
|
||||
|
||||
params = Mock()
|
||||
params.messages = [types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hi"))]
|
||||
params.temperature = None
|
||||
params.maxTokens = 100
|
||||
params.stopSequences = None
|
||||
params.systemPrompt = None
|
||||
params.tools = None
|
||||
params.toolChoice = None
|
||||
|
||||
result = await tool.sampling_callback(Mock(), params)
|
||||
|
||||
assert isinstance(result, types.CreateMessageResult)
|
||||
assert result.content.text == "ok"
|
||||
mock_chat_client.get_response.assert_awaited_once()
|
||||
|
||||
|
||||
async def test_mcp_tool_sampling_callback_clamps_max_tokens():
|
||||
"""An approved request's maxTokens is clamped to sampling_max_tokens."""
|
||||
tool = MCPStdioTool(
|
||||
name="test_tool",
|
||||
command="python",
|
||||
sampling_approval_callback=_approve,
|
||||
sampling_max_tokens=512,
|
||||
)
|
||||
mock_chat_client = AsyncMock()
|
||||
mock_chat_client.get_response.return_value = _make_sampling_response()
|
||||
tool.client = mock_chat_client
|
||||
|
||||
params = Mock()
|
||||
params.messages = [types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hi"))]
|
||||
params.temperature = None
|
||||
params.maxTokens = 1_000_000
|
||||
params.stopSequences = None
|
||||
params.systemPrompt = None
|
||||
params.tools = None
|
||||
params.toolChoice = None
|
||||
|
||||
result = await tool.sampling_callback(Mock(), params)
|
||||
|
||||
assert isinstance(result, types.CreateMessageResult)
|
||||
options = mock_chat_client.get_response.call_args.kwargs.get("options") or {}
|
||||
assert options["max_tokens"] == 512
|
||||
|
||||
|
||||
async def test_mcp_tool_sampling_callback_does_not_clamp_under_cap():
|
||||
"""A request below the cap keeps its requested maxTokens."""
|
||||
tool = MCPStdioTool(
|
||||
name="test_tool",
|
||||
command="python",
|
||||
sampling_approval_callback=_approve,
|
||||
sampling_max_tokens=512,
|
||||
)
|
||||
mock_chat_client = AsyncMock()
|
||||
mock_chat_client.get_response.return_value = _make_sampling_response()
|
||||
tool.client = mock_chat_client
|
||||
|
||||
params = Mock()
|
||||
params.messages = [types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hi"))]
|
||||
params.temperature = None
|
||||
params.maxTokens = 100
|
||||
params.stopSequences = None
|
||||
params.systemPrompt = None
|
||||
params.tools = None
|
||||
params.toolChoice = None
|
||||
|
||||
result = await tool.sampling_callback(Mock(), params)
|
||||
|
||||
assert isinstance(result, types.CreateMessageResult)
|
||||
options = mock_chat_client.get_response.call_args.kwargs.get("options") or {}
|
||||
assert options["max_tokens"] == 100
|
||||
|
||||
|
||||
async def test_mcp_tool_sampling_callback_rate_limited():
|
||||
"""Sampling requests beyond sampling_max_requests are rejected per session."""
|
||||
tool = MCPStdioTool(
|
||||
name="test_tool",
|
||||
command="python",
|
||||
sampling_approval_callback=_approve,
|
||||
sampling_max_requests=2,
|
||||
)
|
||||
mock_chat_client = AsyncMock()
|
||||
mock_chat_client.get_response.return_value = _make_sampling_response()
|
||||
tool.client = mock_chat_client
|
||||
|
||||
def make_params() -> Mock:
|
||||
params = Mock()
|
||||
params.messages = [types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hi"))]
|
||||
params.temperature = None
|
||||
params.maxTokens = 100
|
||||
params.stopSequences = None
|
||||
params.systemPrompt = None
|
||||
params.tools = None
|
||||
params.toolChoice = None
|
||||
return params
|
||||
|
||||
first = await tool.sampling_callback(Mock(), make_params())
|
||||
second = await tool.sampling_callback(Mock(), make_params())
|
||||
third = await tool.sampling_callback(Mock(), make_params())
|
||||
|
||||
assert isinstance(first, types.CreateMessageResult)
|
||||
assert isinstance(second, types.CreateMessageResult)
|
||||
assert isinstance(third, types.ErrorData)
|
||||
assert third.code == types.INVALID_REQUEST
|
||||
assert "rate limit" in third.message.lower()
|
||||
assert mock_chat_client.get_response.await_count == 2
|
||||
|
||||
# The counter resets on a session reset.
|
||||
tool._reset_session_state()
|
||||
fourth = await tool.sampling_callback(Mock(), make_params())
|
||||
assert isinstance(fourth, types.CreateMessageResult)
|
||||
|
||||
|
||||
async def test_mcp_tool_sampling_callback_chat_client_exception():
|
||||
"""Test sampling callback when chat client raises exception."""
|
||||
tool = MCPStdioTool(name="test_tool", command="python")
|
||||
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve)
|
||||
|
||||
# Mock chat client that raises exception
|
||||
mock_chat_client = AsyncMock()
|
||||
@@ -1846,7 +2039,7 @@ async def test_mcp_tool_sampling_callback_chat_client_exception():
|
||||
mock_message.content.text = "Test question"
|
||||
params.messages = [mock_message]
|
||||
params.temperature = None
|
||||
params.maxTokens = None
|
||||
params.maxTokens = 100
|
||||
params.stopSequences = None
|
||||
params.systemPrompt = None
|
||||
params.tools = None
|
||||
@@ -1863,7 +2056,7 @@ async def test_mcp_tool_sampling_callback_no_valid_content():
|
||||
"""Test sampling callback when response has no valid content types."""
|
||||
from agent_framework import Message
|
||||
|
||||
tool = MCPStdioTool(name="test_tool", command="python")
|
||||
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve)
|
||||
|
||||
# Mock chat client with response containing only invalid content types
|
||||
mock_chat_client = AsyncMock()
|
||||
@@ -1892,7 +2085,7 @@ async def test_mcp_tool_sampling_callback_no_valid_content():
|
||||
mock_message.content.text = "Test question"
|
||||
params.messages = [mock_message]
|
||||
params.temperature = None
|
||||
params.maxTokens = None
|
||||
params.maxTokens = 100
|
||||
params.stopSequences = None
|
||||
params.systemPrompt = None
|
||||
params.tools = None
|
||||
@@ -1905,18 +2098,18 @@ async def test_mcp_tool_sampling_callback_no_valid_content():
|
||||
assert "Failed to get right content types from the response." in result.message
|
||||
mock_chat_client.get_response.assert_awaited_once()
|
||||
_, kwargs = mock_chat_client.get_response.await_args
|
||||
assert kwargs["options"] == {"max_tokens": None}
|
||||
assert kwargs["options"] == {"max_tokens": 100}
|
||||
|
||||
|
||||
async def test_mcp_tool_sampling_callback_no_response_and_successful_message_creation():
|
||||
"""Test sampling callback when the chat client returns no response and then valid content."""
|
||||
tool = MCPStdioTool(name="test_tool", command="python")
|
||||
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve)
|
||||
tool.client = AsyncMock()
|
||||
|
||||
params = Mock()
|
||||
params.messages = [types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hi"))]
|
||||
params.temperature = None
|
||||
params.maxTokens = None
|
||||
params.maxTokens = 100
|
||||
params.stopSequences = None
|
||||
params.systemPrompt = None
|
||||
params.tools = None
|
||||
@@ -1955,7 +2148,7 @@ async def test_mcp_tool_sampling_callback_forwards_system_prompt():
|
||||
"""Test sampling callback passes systemPrompt as instructions in options."""
|
||||
from agent_framework import Message
|
||||
|
||||
tool = MCPStdioTool(name="test_tool", command="python")
|
||||
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve)
|
||||
|
||||
mock_chat_client = AsyncMock()
|
||||
mock_response = Mock()
|
||||
@@ -1972,7 +2165,7 @@ async def test_mcp_tool_sampling_callback_forwards_system_prompt():
|
||||
mock_message.content.text = "Test question"
|
||||
params.messages = [mock_message]
|
||||
params.temperature = None
|
||||
params.maxTokens = None
|
||||
params.maxTokens = 100
|
||||
params.stopSequences = None
|
||||
params.systemPrompt = "You are a helpful assistant"
|
||||
params.tools = None
|
||||
@@ -1990,7 +2183,7 @@ async def test_mcp_tool_sampling_callback_forwards_tools():
|
||||
"""Test sampling callback converts MCP tools to FunctionTools and passes them in options."""
|
||||
from agent_framework import FunctionTool, Message
|
||||
|
||||
tool = MCPStdioTool(name="test_tool", command="python")
|
||||
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve)
|
||||
|
||||
mock_chat_client = AsyncMock()
|
||||
mock_response = Mock()
|
||||
@@ -2013,7 +2206,7 @@ async def test_mcp_tool_sampling_callback_forwards_tools():
|
||||
mock_message.content.text = "Test question"
|
||||
params.messages = [mock_message]
|
||||
params.temperature = None
|
||||
params.maxTokens = None
|
||||
params.maxTokens = 100
|
||||
params.stopSequences = None
|
||||
params.systemPrompt = None
|
||||
params.tools = [mcp_tool]
|
||||
@@ -2036,7 +2229,7 @@ async def test_mcp_tool_sampling_callback_forwards_tool_choice():
|
||||
"""Test sampling callback passes toolChoice mode in options."""
|
||||
from agent_framework import Message
|
||||
|
||||
tool = MCPStdioTool(name="test_tool", command="python")
|
||||
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve)
|
||||
|
||||
mock_chat_client = AsyncMock()
|
||||
mock_response = Mock()
|
||||
@@ -2053,7 +2246,7 @@ async def test_mcp_tool_sampling_callback_forwards_tool_choice():
|
||||
mock_message.content.text = "Test question"
|
||||
params.messages = [mock_message]
|
||||
params.temperature = None
|
||||
params.maxTokens = None
|
||||
params.maxTokens = 100
|
||||
params.stopSequences = None
|
||||
params.systemPrompt = None
|
||||
params.tools = None
|
||||
@@ -2071,7 +2264,7 @@ async def test_mcp_tool_sampling_callback_forwards_empty_system_prompt():
|
||||
"""Test sampling callback forwards empty string systemPrompt as instructions."""
|
||||
from agent_framework import Message
|
||||
|
||||
tool = MCPStdioTool(name="test_tool", command="python")
|
||||
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve)
|
||||
|
||||
mock_chat_client = AsyncMock()
|
||||
mock_response = Mock()
|
||||
@@ -2088,7 +2281,7 @@ async def test_mcp_tool_sampling_callback_forwards_empty_system_prompt():
|
||||
mock_message.content.text = "Test question"
|
||||
params.messages = [mock_message]
|
||||
params.temperature = None
|
||||
params.maxTokens = None
|
||||
params.maxTokens = 100
|
||||
params.stopSequences = None
|
||||
params.systemPrompt = ""
|
||||
params.tools = None
|
||||
@@ -2106,7 +2299,7 @@ async def test_mcp_tool_sampling_callback_forwards_empty_tools_list():
|
||||
"""Test sampling callback forwards empty tools list in options."""
|
||||
from agent_framework import Message
|
||||
|
||||
tool = MCPStdioTool(name="test_tool", command="python")
|
||||
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve)
|
||||
|
||||
mock_chat_client = AsyncMock()
|
||||
mock_response = Mock()
|
||||
@@ -2123,7 +2316,7 @@ async def test_mcp_tool_sampling_callback_forwards_empty_tools_list():
|
||||
mock_message.content.text = "Test question"
|
||||
params.messages = [mock_message]
|
||||
params.temperature = None
|
||||
params.maxTokens = None
|
||||
params.maxTokens = 100
|
||||
params.stopSequences = None
|
||||
params.systemPrompt = None
|
||||
params.tools = []
|
||||
@@ -2141,7 +2334,7 @@ async def test_mcp_tool_sampling_callback_forwards_generation_params_in_options(
|
||||
"""Test sampling callback passes temperature, max_tokens, and stop in options."""
|
||||
from agent_framework import Message
|
||||
|
||||
tool = MCPStdioTool(name="test_tool", command="python")
|
||||
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve)
|
||||
|
||||
mock_chat_client = AsyncMock()
|
||||
mock_response = Mock()
|
||||
@@ -2182,7 +2375,7 @@ async def test_mcp_tool_sampling_callback_omits_temperature_when_none():
|
||||
"""Test sampling callback does not set temperature in options when it is None."""
|
||||
from agent_framework import Message
|
||||
|
||||
tool = MCPStdioTool(name="test_tool", command="python")
|
||||
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve)
|
||||
|
||||
mock_chat_client = AsyncMock()
|
||||
mock_response = Mock()
|
||||
@@ -2219,7 +2412,7 @@ async def test_mcp_tool_sampling_callback_always_passes_max_tokens():
|
||||
"""Test sampling callback always sets max_tokens in options since maxTokens is a required int field."""
|
||||
from agent_framework import Message
|
||||
|
||||
tool = MCPStdioTool(name="test_tool", command="python")
|
||||
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve)
|
||||
|
||||
mock_chat_client = AsyncMock()
|
||||
mock_response = Mock()
|
||||
|
||||
@@ -14,6 +14,7 @@ The Model Context Protocol (MCP) is an open standard for connecting AI agents to
|
||||
| **API Key Authentication** | [`mcp_api_key_auth.py`](mcp_api_key_auth.py) | Demonstrates API key authentication with MCP servers using `header_provider`, runtime invocation kwargs, and a command-line API key argument |
|
||||
| **GitHub Integration with PAT** | [`mcp_github_pat.py`](mcp_github_pat.py) | Demonstrates connecting to GitHub's MCP server using Personal Access Token (PAT) authentication |
|
||||
| **Long-Running Task** | [`mcp_long_running_task.py`](mcp_long_running_task.py) | Demonstrates transparent SEP-2663 long-running task handling for MCP tools that advertise `taskSupport=required`. Self-spawns a stdio MCP child server |
|
||||
| **Sampling Approval** | [`mcp_sampling_approval.py`](mcp_sampling_approval.py) | Demonstrates gating server-initiated `sampling/createMessage` requests with a `sampling_approval_callback`, plus the `sampling_max_tokens` and `sampling_max_requests` guardrails. MCP sampling is denied by default |
|
||||
|
||||
## Prerequisites
|
||||
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
|
||||
from agent_framework import Agent, MCPStreamableHTTPTool
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from dotenv import load_dotenv
|
||||
from mcp import types
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
"""
|
||||
MCP Sampling Approval Example
|
||||
|
||||
MCP servers can send the client a ``sampling/createMessage`` request, asking the
|
||||
client to run an LLM completion on the server's behalf. Because remote MCP
|
||||
servers are untrusted third parties, forwarding these server-controlled prompts
|
||||
to your chat client without review is a confused-deputy risk: a malicious server
|
||||
could exfiltrate context, force tool calls, or burn through your token budget.
|
||||
|
||||
For that reason Agent Framework **denies MCP sampling by default**. To allow it,
|
||||
pass a ``sampling_approval_callback`` to the MCP tool. The callback receives the
|
||||
raw ``CreateMessageRequestParams`` and returns ``True`` to approve or ``False``
|
||||
to deny. It may be synchronous or asynchronous, so you can implement a
|
||||
human-in-the-loop prompt, a policy check, or an audit log.
|
||||
|
||||
Two further guardrails apply to approved requests:
|
||||
- ``sampling_max_tokens`` caps the server-requested ``maxTokens``.
|
||||
- ``sampling_max_requests`` limits how many sampling requests a single session
|
||||
may make.
|
||||
|
||||
To restore the legacy "always approve" behavior (only do this for servers you
|
||||
trust), pass ``sampling_approval_callback=lambda params: True``.
|
||||
"""
|
||||
|
||||
|
||||
async def approve_sampling(params: types.CreateMessageRequestParams) -> bool:
|
||||
"""Human-in-the-loop approval gate for server-initiated sampling.
|
||||
|
||||
Shows the server-supplied system prompt and messages, then asks the user to
|
||||
approve or deny. Returning ``False`` rejects the request.
|
||||
"""
|
||||
print("\n--- MCP server requested a sampling/createMessage ---")
|
||||
if params.systemPrompt:
|
||||
print(f"System prompt: {params.systemPrompt}")
|
||||
for message in params.messages:
|
||||
text = getattr(message.content, "text", message.content)
|
||||
print(f"{message.role}: {text}")
|
||||
answer = await asyncio.to_thread(input, "Approve this sampling request? [y/N]: ")
|
||||
return answer.strip().lower() in {"y", "yes"}
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Run an agent against an MCP server with a sampling approval gate."""
|
||||
async with Agent(
|
||||
client=OpenAIChatClient(),
|
||||
name="Agent",
|
||||
instructions="You are a helpful assistant. Use your MCP tool when answering the user's question.",
|
||||
tools=MCPStreamableHTTPTool(
|
||||
name="MCP tool",
|
||||
description="MCP tool description.",
|
||||
url="<your mcp server url>",
|
||||
# Passing ``client`` enables sampling; the approval callback gates it.
|
||||
client=OpenAIChatClient(),
|
||||
sampling_approval_callback=approve_sampling,
|
||||
sampling_max_tokens=2048,
|
||||
sampling_max_requests=5,
|
||||
),
|
||||
) as agent:
|
||||
query = "Use your MCP tool to help answer this question."
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Agent: {result.text}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user