Python: [BREAKING] Add sampling guardrails to MCP tools (#6413)

* Add sampling guardrails to MCP tools

Add approval, token, and request-count controls to the MCP sampling
callback used when an MCPTool is configured with a chat client.

- Add `sampling_approval_callback`, `sampling_max_tokens`, and
  `sampling_max_requests` parameters to `MCPTool` and its
  `MCPStdioTool`, `MCPStreamableHTTPTool`, and `MCPWebsocketTool`
  subclasses, positioned directly after `client`.
- Gate each server-initiated `sampling/createMessage` request behind the
  approval callback; requests are denied by default when no callback is
  provided.
- Clamp the requested `maxTokens` to `sampling_max_tokens` and enforce a
  per-session request count via `sampling_max_requests`.
- Log incoming sampling requests at WARNING level (message count and
  `maxTokens` only; message content is not logged).
- Export `SamplingApprovalCallback` from the public API.
- Add tests, a sample, and documentation updates.
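
The guardrail order listed above (request count, then approval, then token
clamp) can be illustrated with a minimal, framework-free sketch. `SamplingGuard`,
`check`, and the returned tuples are hypothetical names used only for
illustration; the actual logic lives in `MCPTool.sampling_callback` and returns
MCP `ErrorData` objects rather than strings.

```python
import asyncio
from inspect import isawaitable
from typing import Any, Awaitable, Callable, Optional, Union

# Hypothetical stand-in for the exported SamplingApprovalCallback alias.
ApprovalCallback = Callable[[Any], Union[bool, Awaitable[bool]]]


class SamplingGuard:
    """Illustrative model of the three guardrails; not the framework class."""

    def __init__(
        self,
        approval_callback: Optional[ApprovalCallback] = None,
        max_tokens: Optional[int] = 4096,
        max_requests: Optional[int] = 25,
    ) -> None:
        self.approval_callback = approval_callback
        self.max_tokens = max_tokens
        self.max_requests = max_requests
        self.request_count = 0  # would be reset when the session reconnects

    async def check(self, params: Any, requested_tokens: int) -> tuple[bool, Union[int, str]]:
        # 1. Per-session request count: a request that passes this check is
        #    counted whether or not it is later approved.
        if self.max_requests is not None:
            if self.request_count >= self.max_requests:
                return False, "rate limit exceeded"
            self.request_count += 1

        # 2. Approval gate: deny when no callback is set, when it raises,
        #    or when it returns (or resolves to) a falsy value.
        if self.approval_callback is None:
            return False, "denied by default"
        try:
            outcome = self.approval_callback(params)
            if isawaitable(outcome):
                outcome = await outcome
        except Exception:
            return False, "denied by callback"
        if not outcome:
            return False, "denied by callback"

        # 3. Token clamp: the effective value is min(requested, cap).
        if self.max_tokens is not None:
            requested_tokens = min(requested_tokens, self.max_tokens)
        return True, requested_tokens


guard = SamplingGuard(approval_callback=lambda params: True, max_tokens=512)
print(asyncio.run(guard.check({"messages": []}, 1_000_000)))  # (True, 512)
```

With the real classes, the equivalent opt-in would presumably be passing
`sampling_approval_callback=...` alongside `client=...` when constructing an
`MCPStdioTool`, `MCPStreamableHTTPTool`, or `MCPWebsocketTool`.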

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Make sampling denial message context-aware

Distinguish the deny-by-default case (no approval callback configured)
from an explicit denial by a configured `sampling_approval_callback`, so
the returned ErrorData message is accurate for callback-driven denials
and exceptions.
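
As a rough sketch of the distinction described above, the message can be chosen
by checking whether a callback was configured at all. The helper name
`denial_message` and the two constants are hypothetical; the message strings are
the ones used in this change.

```python
from typing import Callable, Optional

DENY_BY_DEFAULT = (
    "Sampling request denied. MCP sampling is disabled by default for untrusted "
    "servers; provide a 'sampling_approval_callback' that approves the request to "
    "enable it."
)
DENIED_BY_CALLBACK = "Sampling request denied by the 'sampling_approval_callback'."


def denial_message(approval_callback: Optional[Callable[..., object]]) -> str:
    """Pick the denial text for a request that was not approved.

    A missing callback means the deny-by-default path was taken; otherwise the
    configured callback either returned a falsy value or raised.
    """
    if approval_callback is None:
        return DENY_BY_DEFAULT
    return DENIED_BY_CALLBACK
```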

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Author: Eduard van Valkenburg
Date: 2026-06-10 12:17:36 +02:00
Committed by GitHub
Parent: cea83bd8d5
Commit: 9a56bc9f16
6 changed files with 476 additions and 29 deletions
+1
@@ -82,6 +82,7 @@ agent_framework/
- **`MCPStdioTool`** / **`MCPStreamableHTTPTool`** / **`MCPWebsocketTool`** - Transport-specific subclasses.
- **Argument allowlist (`_prepare_call_kwargs`)** - Before each `tools/call`, kwargs are filtered to an **allowlist** built from the tool's declared parameters (`inputSchema.properties`) plus any user-configured extras. Framework runtime kwargs injected through the function-invocation pipeline (e.g. `thread`, `conversation_id`, `chat_options`, `options`, `response_format`) are stripped by default rather than forwarded. A tool that declares no usable `properties` (including schemas with `additionalProperties: true`) forwards only the configured extras. The `_MCP_FRAMEWORK_DENYLIST` is a safety net for framework-named params a server *declares* in its schema (those are dropped); names explicitly opted in via `additional_tool_argument_names` always win. The reserved `_meta` key is extracted as MCP request metadata, never forwarded as an argument.
- **`additional_tool_argument_names`** (constructor arg on all `MCPTool` subclasses) - Opt extra argument names back into the allowlist. Accepts a `Sequence[str]` (applied to every tool) or a `Mapping[str, Sequence[str]]` keyed by **remote tool name**, where the reserved key `"*"` denotes global extras. It is configured only in user code at construction; there is **no per-call/runtime override**, so a model-issued tool call cannot change which names pass through. To use a server that accepts `additionalProperties: true`, list the extra names here and then either (1) manually extend that tool's `inputSchema` (via the `.functions` list after connecting) so the model is prompted to supply them, or (2) supply the values yourself via `function_invocation_kwargs`. If a name is supplied by both the model and `function_invocation_kwargs`, the model-supplied value wins.
- **Sampling guardrails** (`sampling_callback`) - Passing `client=` advertises `SamplingCapability` so the server can send `sampling/createMessage`. Because remote servers are untrusted (confused-deputy risk), the default `sampling_callback` is **deny-by-default** and applies, in order: a per-session rate limit (`sampling_max_requests`, default `_DEFAULT_SAMPLING_MAX_REQUESTS`), an approval gate (`sampling_approval_callback`), and a `maxTokens` cap (`sampling_max_tokens`, default `_DEFAULT_SAMPLING_MAX_TOKENS`). The approval callback (constructor arg on all subclasses; exported type alias `SamplingApprovalCallback`) receives the raw `CreateMessageRequestParams`, may be sync or async, and must return truthy to approve. When it is `None` (the default) every sampling request is denied; pass `lambda params: True` to restore legacy auto-approve as an explicit opt-in. Requests and denials are logged at WARNING (content is not logged). The per-session counter resets in `_reset_session_state`.
- **`MCPTaskOptions`** (experimental, `MCP_LONG_RUNNING_TASKS` feature, **frozen**) - Per-tool-instance options controlling the SEP-2663 long-running task lifecycle. When the server advertises a tool with `execution.taskSupport == "required"`, `MCPTool.call_tool` transparently routes through `call_tool_as_task`, which sends an augmented `tools/call`, polls `tasks/get` until terminal, and reinterprets `tasks/result` as a normal `CallToolResult`. Instances are immutable; replace via `MCPTool.task_options = MCPTaskOptions(...)`. Fields:
- `default_ttl: timedelta | None` — forwarded to the server as `params.task.ttl` (milliseconds). When `None`, the server's default applies.
- `cancel_remote_task_on_local_cancellation: bool = True` — only gates the `CancelledError` path. Abandonment paths (see below) always cancel.
@@ -124,7 +124,7 @@ from ._harness._todo import (
TodoSessionStore,
TodoStore,
)
from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPTaskOptions, MCPWebsocketTool
from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPTaskOptions, MCPWebsocketTool, SamplingApprovalCallback
from ._middleware import (
AgentContext,
AgentMiddleware,
@@ -472,6 +472,7 @@ __all__ = [
"RunContext",
"Runner",
"RunnerContext",
"SamplingApprovalCallback",
"SecretString",
"SelectiveToolCallCompactionStrategy",
"SessionContext",
+181 -8
@@ -16,6 +16,7 @@ from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ig
from dataclasses import dataclass
from datetime import timedelta
from functools import partial
from inspect import isawaitable
from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast
from opentelemetry import propagate
@@ -99,6 +100,22 @@ _mcp_call_headers: contextvars.ContextVar[dict[str, str]] = contextvars.ContextV
MCP_DEFAULT_TIMEOUT = 30
MCP_DEFAULT_SSE_READ_TIMEOUT = 60 * 5
# Default safety limits applied to server-initiated MCP sampling requests
# (``sampling/createMessage``). MCP servers are untrusted third parties, so the
# default ``sampling_callback`` denies requests unless an approval callback is
# supplied, and bounds the cost of any approved request.
# - ``_DEFAULT_SAMPLING_MAX_TOKENS`` clamps the server-requested ``maxTokens``.
# - ``_DEFAULT_SAMPLING_MAX_REQUESTS`` caps the number of sampling requests per
# session connection (the counter resets on reconnect).
_DEFAULT_SAMPLING_MAX_TOKENS = 4096
_DEFAULT_SAMPLING_MAX_REQUESTS = 25
# A user-supplied gate invoked before each server-initiated sampling request is
# forwarded to the chat client. It receives the raw ``CreateMessageRequestParams``
# and returns (or resolves to, if awaitable) a truthy value to approve the request
# or a falsy value to deny it. Both synchronous and asynchronous callables are supported.
SamplingApprovalCallback = Callable[["types.CreateMessageRequestParams"], "bool | Coroutine[Any, Any, bool]"]
# region: Helpers
LOG_LEVEL_MAPPING: dict[str, int] = {
@@ -345,6 +362,9 @@ class MCPTool:
session: ClientSession | None = None,
request_timeout: int | None = None,
client: SupportsChatGetResponse | None = None,
sampling_approval_callback: SamplingApprovalCallback | None = None,
sampling_max_tokens: int | None = _DEFAULT_SAMPLING_MAX_TOKENS,
sampling_max_requests: int | None = _DEFAULT_SAMPLING_MAX_REQUESTS,
additional_properties: dict[str, Any] | None = None,
task_options: MCPTaskOptions | None = None,
additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None,
@@ -378,6 +398,20 @@ class MCPTool:
session: An existing MCP client session to use.
request_timeout: Timeout in seconds for MCP requests.
client: A chat client for sampling callbacks.
sampling_approval_callback: Optional gate invoked before each server-initiated
``sampling/createMessage`` request is forwarded to ``client``. It receives the
raw ``CreateMessageRequestParams`` and may be synchronous or asynchronous;
returning a truthy value approves the request and a falsy value denies it. When
``None`` (the default), every sampling request is **denied** because MCP servers
are untrusted third parties (confused-deputy risk). To restore the legacy
auto-approve behavior, pass ``lambda params: True`` as an explicit, conscious
opt-in.
sampling_max_tokens: Upper bound applied to the server-requested ``maxTokens`` for an
approved sampling request. The effective value is ``min(requested, cap)``. Set to
``None`` to disable the cap. Defaults to ``_DEFAULT_SAMPLING_MAX_TOKENS``.
sampling_max_requests: Maximum number of sampling requests allowed per session
connection; further requests are rejected. The counter resets on reconnect. Set
to ``None`` to disable the limit. Defaults to ``_DEFAULT_SAMPLING_MAX_REQUESTS``.
additional_properties: Additional properties for the tool.
task_options: Options controlling how long-running MCP tasks are driven for
tools that advertise ``execution.taskSupport == "required"``. When ``None``,
@@ -410,6 +444,10 @@ class MCPTool:
self.session = session
self.request_timeout = request_timeout
self.client = client
self.sampling_approval_callback = sampling_approval_callback
self.sampling_max_tokens = sampling_max_tokens
self.sampling_max_requests = sampling_max_requests
self._sampling_request_count = 0
self._functions: list[FunctionTool] = []
self._tool_call_meta_by_name: dict[str, dict[str, Any]] = {}
self._tool_task_support_by_name: dict[str, str] = {}
@@ -840,6 +878,7 @@ class MCPTool:
self._supports_prompts = True
self._supports_logging = None
self._ping_available = True
self._sampling_request_count = 0
def _set_server_capabilities(self, capabilities: types.ServerCapabilities | None) -> None:
self._server_capabilities = capabilities
@@ -994,6 +1033,49 @@ class MCPTool:
except Exception as exc:
logger.warning("Failed to set log level to %s", logger.level, exc_info=exc)
async def _sampling_request_approved(self, params: types.CreateMessageRequestParams) -> bool:
"""Run the configured sampling approval gate.
Returns ``True`` only when an approval callback is configured and approves the request.
When no callback is set, the request is denied (safe default for untrusted servers).
"""
callback = self.sampling_approval_callback
if callback is None:
logger.warning(
"Denying MCP sampling request from '%s': no 'sampling_approval_callback' configured.",
self.name,
)
return False
try:
outcome = callback(params)
if isawaitable(outcome):
outcome = await outcome
except Exception as ex:
logger.warning(
"Denying MCP sampling request from '%s': approval callback raised %s.",
self.name,
ex,
exc_info=True,
)
return False
approved = bool(outcome)
if not approved:
logger.warning("MCP sampling request from '%s' was denied by the approval callback.", self.name)
return approved
def _capped_sampling_max_tokens(self, requested: int) -> int:
"""Clamp the server-requested ``maxTokens`` to ``sampling_max_tokens`` when configured."""
cap = self.sampling_max_tokens
if cap is not None and requested > cap:
logger.warning(
"Capping MCP sampling maxTokens for '%s' from %d to %d.",
self.name,
requested,
cap,
)
return cap
return requested
async def sampling_callback(
self,
context: RequestContext[ClientSession, Any],
@@ -1001,20 +1083,32 @@ class MCPTool:
) -> types.CreateMessageResult | types.ErrorData:
"""Callback function for sampling.
This function is called when the MCP server needs to get a message completed.
It uses the configured chat client to generate responses.
This function is called when the MCP server sends a ``sampling/createMessage``
request. It enforces safety guardrails and, if the request is approved, uses the
configured chat client to generate a response.
Safety:
MCP servers are untrusted third parties, so forwarding server-controlled prompts
to the chat client without review is a confused-deputy risk. This callback
therefore applies, in order: a per-session rate limit
(``sampling_max_requests``), an approval gate (``sampling_approval_callback``,
which **denies by default** when not configured), and a ``maxTokens`` cap
(``sampling_max_tokens``). To allow sampling, pass a ``sampling_approval_callback``
that returns a truthy value (use ``lambda params: True`` to auto-approve as an
explicit opt-in).
Note:
This is a simple version of this function. It can be overridden to allow
more complex sampling. It gets added to the session at initialization time,
so overriding it is the best way to customize this behavior.
This is the default implementation. It can be overridden to allow more complex
sampling. It gets added to the session at initialization time, so overriding it is
the best way to customize this behavior.
Args:
context: The request context from the MCP server.
params: The message creation request parameters.
Returns:
Either a CreateMessageResult with the generated message or ErrorData if generation fails.
Either a CreateMessageResult with the generated message or ErrorData if the request
is denied, rate limited, or generation fails.
"""
from mcp import types
@@ -1023,7 +1117,38 @@ class MCPTool:
code=types.INTERNAL_ERROR,
message="No chat client available. Please set a chat client.",
)
logger.debug("Sampling callback called with params: %s", params)
logger.warning(
"MCP server '%s' sent a sampling/createMessage request (%d message(s), maxTokens=%s).",
self.name,
len(params.messages),
params.maxTokens,
)
if self.sampling_max_requests is not None:
if self._sampling_request_count >= self.sampling_max_requests:
logger.warning(
"Denying MCP sampling request from '%s': per-session limit of %d reached.",
self.name,
self.sampling_max_requests,
)
return types.ErrorData(
code=types.INVALID_REQUEST,
message="Sampling rate limit exceeded for this MCP session.",
)
self._sampling_request_count += 1
if not await self._sampling_request_approved(params):
if self.sampling_approval_callback is None:
message = (
"Sampling request denied. MCP sampling is disabled by default for untrusted "
"servers; provide a 'sampling_approval_callback' that approves the request to "
"enable it."
)
else:
message = "Sampling request denied by the 'sampling_approval_callback'."
return types.ErrorData(code=types.INVALID_REQUEST, message=message)
messages: list[Message] = []
for msg in params.messages:
messages.append(self._parse_message_from_mcp(msg))
@@ -1045,7 +1170,7 @@ class MCPTool:
if params.temperature is not None:
options["temperature"] = params.temperature
options["max_tokens"] = params.maxTokens
options["max_tokens"] = self._capped_sampling_max_tokens(params.maxTokens)
if params.stopSequences is not None:
options["stop"] = params.stopSequences
@@ -2219,6 +2344,9 @@ class MCPStdioTool(MCPTool):
env: dict[str, str] | None = None,
encoding: str | None = None,
client: SupportsChatGetResponse | None = None,
sampling_approval_callback: SamplingApprovalCallback | None = None,
sampling_max_tokens: int | None = _DEFAULT_SAMPLING_MAX_TOKENS,
sampling_max_requests: int | None = _DEFAULT_SAMPLING_MAX_REQUESTS,
additional_properties: dict[str, Any] | None = None,
task_options: MCPTaskOptions | None = None,
additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None,
@@ -2266,6 +2394,16 @@ class MCPStdioTool(MCPTool):
env: The environment variables to set for the command.
encoding: The encoding to use for the command output.
client: The chat client to use for sampling.
sampling_approval_callback: Optional gate run before each server-initiated
``sampling/createMessage`` request reaches ``client``. Receives the raw
``CreateMessageRequestParams`` (sync or async); a truthy return approves the
request, a falsy return denies it. When ``None`` (the default) every sampling
request is **denied**, since MCP servers are untrusted (confused-deputy risk).
Pass ``lambda params: True`` to auto-approve as an explicit opt-in.
sampling_max_tokens: Cap applied to an approved request's ``maxTokens``
(``min(requested, cap)``); ``None`` disables it.
sampling_max_requests: Per-session cap on the number of sampling requests; further
requests are rejected. Resets on reconnect. ``None`` disables it.
task_options: Options for tools that advertise
``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`.
additional_tool_argument_names: Extra argument names to forward to the MCP server in
@@ -2300,6 +2438,9 @@ class MCPStdioTool(MCPTool):
request_timeout=request_timeout,
task_options=task_options,
additional_tool_argument_names=additional_tool_argument_names,
sampling_approval_callback=sampling_approval_callback,
sampling_max_tokens=sampling_max_tokens,
sampling_max_requests=sampling_max_requests,
)
self.command = command
self.args = args or []
@@ -2375,6 +2516,9 @@ class MCPStreamableHTTPTool(MCPTool):
allowed_tools: Collection[str] | None = None,
terminate_on_close: bool | None = None,
client: SupportsChatGetResponse | None = None,
sampling_approval_callback: SamplingApprovalCallback | None = None,
sampling_max_tokens: int | None = _DEFAULT_SAMPLING_MAX_TOKENS,
sampling_max_requests: int | None = _DEFAULT_SAMPLING_MAX_REQUESTS,
additional_properties: dict[str, Any] | None = None,
http_client: AsyncClient | None = None,
header_provider: Callable[[dict[str, Any]], dict[str, str]] | None = None,
@@ -2423,6 +2567,16 @@ class MCPStreamableHTTPTool(MCPTool):
additional_properties: Additional properties.
terminate_on_close: Close the transport when the MCP client is terminated.
client: The chat client to use for sampling.
sampling_approval_callback: Optional gate run before each server-initiated
``sampling/createMessage`` request reaches ``client``. Receives the raw
``CreateMessageRequestParams`` (sync or async); a truthy return approves the
request, a falsy return denies it. When ``None`` (the default) every sampling
request is **denied**, since MCP servers are untrusted (confused-deputy risk).
Pass ``lambda params: True`` to auto-approve as an explicit opt-in.
sampling_max_tokens: Cap applied to an approved request's ``maxTokens``
(``min(requested, cap)``); ``None`` disables it.
sampling_max_requests: Per-session cap on the number of sampling requests; further
requests are rejected. Resets on reconnect. ``None`` disables it.
http_client: Optional ``AsyncClient`` to use. If not provided, the
``streamable_http_client`` API will create and manage a default client.
To configure headers, timeouts, or other HTTP client settings, create
@@ -2466,6 +2620,9 @@ class MCPStreamableHTTPTool(MCPTool):
request_timeout=request_timeout,
task_options=task_options,
additional_tool_argument_names=additional_tool_argument_names,
sampling_approval_callback=sampling_approval_callback,
sampling_max_tokens=sampling_max_tokens,
sampling_max_requests=sampling_max_requests,
)
self.url = url
self.terminate_on_close = terminate_on_close
@@ -2590,6 +2747,9 @@ class MCPWebsocketTool(MCPTool):
approval_mode: (Literal["always_require", "never_require"] | MCPSpecificApproval | None) = None,
allowed_tools: Collection[str] | None = None,
client: SupportsChatGetResponse | None = None,
sampling_approval_callback: SamplingApprovalCallback | None = None,
sampling_max_tokens: int | None = _DEFAULT_SAMPLING_MAX_TOKENS,
sampling_max_requests: int | None = _DEFAULT_SAMPLING_MAX_REQUESTS,
additional_properties: dict[str, Any] | None = None,
task_options: MCPTaskOptions | None = None,
additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None,
@@ -2635,6 +2795,16 @@ class MCPWebsocketTool(MCPTool):
allowed_tools: A list of tools that are allowed to use this tool.
additional_properties: Additional properties.
client: The chat client to use for sampling.
sampling_approval_callback: Optional gate run before each server-initiated
``sampling/createMessage`` request reaches ``client``. Receives the raw
``CreateMessageRequestParams`` (sync or async); a truthy return approves the
request, a falsy return denies it. When ``None`` (the default) every sampling
request is **denied**, since MCP servers are untrusted (confused-deputy risk).
Pass ``lambda params: True`` to auto-approve as an explicit opt-in.
sampling_max_tokens: Cap applied to an approved request's ``maxTokens``
(``min(requested, cap)``); ``None`` disables it.
sampling_max_requests: Per-session cap on the number of sampling requests; further
requests are rejected. Resets on reconnect. ``None`` disables it.
task_options: Options for tools that advertise
``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`.
additional_tool_argument_names: Extra argument names to forward to the MCP server in
@@ -2669,6 +2839,9 @@ class MCPWebsocketTool(MCPTool):
request_timeout=request_timeout,
task_options=task_options,
additional_tool_argument_names=additional_tool_argument_names,
sampling_approval_callback=sampling_approval_callback,
sampling_max_tokens=sampling_max_tokens,
sampling_max_requests=sampling_max_requests,
)
self.url = url
self._client_kwargs = kwargs
+213 -20
@@ -1813,6 +1813,18 @@ async def test_mcp_tool_message_handler_cancel_and_replace():
assert len(tool._pending_reload_tasks) == 0
def _approve(_params: object) -> bool:
"""Approving sampling gate used by tests that exercise forwarding behavior."""
return True
def _make_sampling_response(text: str = "response", model: str = "test-model") -> Mock:
mock_response = Mock()
mock_response.messages = [Message(role="assistant", contents=[Content.from_text(text)])]
mock_response.model = model
return mock_response
async def test_mcp_tool_sampling_callback_no_client():
"""Test sampling callback error path when no chat client is available."""
tool = MCPStdioTool(name="test_tool", command="python")
@@ -1828,9 +1840,190 @@ async def test_mcp_tool_sampling_callback_no_client():
assert "No chat client available" in result.message
async def test_mcp_tool_sampling_callback_denies_by_default():
"""Sampling is denied when no approval callback is configured (safe default)."""
tool = MCPStdioTool(name="test_tool", command="python")
mock_chat_client = AsyncMock()
tool.client = mock_chat_client
params = Mock()
params.messages = []
params.maxTokens = 128
result = await tool.sampling_callback(Mock(), params)
assert isinstance(result, types.ErrorData)
assert result.code == types.INVALID_REQUEST
assert "denied" in result.message
assert "sampling_approval_callback" in result.message
mock_chat_client.get_response.assert_not_called()
async def test_mcp_tool_sampling_callback_denied_by_callback():
"""Sampling is denied when the approval callback returns a falsy value."""
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=lambda params: False)
mock_chat_client = AsyncMock()
tool.client = mock_chat_client
params = Mock()
params.messages = []
params.maxTokens = 128
result = await tool.sampling_callback(Mock(), params)
assert isinstance(result, types.ErrorData)
assert result.code == types.INVALID_REQUEST
assert "denied by the 'sampling_approval_callback'" in result.message
mock_chat_client.get_response.assert_not_called()
async def test_mcp_tool_sampling_callback_callback_exception_denies():
"""An approval callback that raises results in denial, not an LLM call."""
def boom(_params: object) -> bool:
raise RuntimeError("approval error")
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=boom)
mock_chat_client = AsyncMock()
tool.client = mock_chat_client
params = Mock()
params.messages = []
params.maxTokens = 128
result = await tool.sampling_callback(Mock(), params)
assert isinstance(result, types.ErrorData)
assert result.code == types.INVALID_REQUEST
mock_chat_client.get_response.assert_not_called()
async def test_mcp_tool_sampling_callback_async_approval():
"""An async approval callback that approves allows the request through."""
async def approve(_params: object) -> bool:
return True
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=approve)
mock_chat_client = AsyncMock()
mock_chat_client.get_response.return_value = _make_sampling_response("ok")
tool.client = mock_chat_client
params = Mock()
params.messages = [types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hi"))]
params.temperature = None
params.maxTokens = 100
params.stopSequences = None
params.systemPrompt = None
params.tools = None
params.toolChoice = None
result = await tool.sampling_callback(Mock(), params)
assert isinstance(result, types.CreateMessageResult)
assert result.content.text == "ok"
mock_chat_client.get_response.assert_awaited_once()
async def test_mcp_tool_sampling_callback_clamps_max_tokens():
"""An approved request's maxTokens is clamped to sampling_max_tokens."""
tool = MCPStdioTool(
name="test_tool",
command="python",
sampling_approval_callback=_approve,
sampling_max_tokens=512,
)
mock_chat_client = AsyncMock()
mock_chat_client.get_response.return_value = _make_sampling_response()
tool.client = mock_chat_client
params = Mock()
params.messages = [types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hi"))]
params.temperature = None
params.maxTokens = 1_000_000
params.stopSequences = None
params.systemPrompt = None
params.tools = None
params.toolChoice = None
result = await tool.sampling_callback(Mock(), params)
assert isinstance(result, types.CreateMessageResult)
options = mock_chat_client.get_response.call_args.kwargs.get("options") or {}
assert options["max_tokens"] == 512
async def test_mcp_tool_sampling_callback_does_not_clamp_under_cap():
"""A request below the cap keeps its requested maxTokens."""
tool = MCPStdioTool(
name="test_tool",
command="python",
sampling_approval_callback=_approve,
sampling_max_tokens=512,
)
mock_chat_client = AsyncMock()
mock_chat_client.get_response.return_value = _make_sampling_response()
tool.client = mock_chat_client
params = Mock()
params.messages = [types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hi"))]
params.temperature = None
params.maxTokens = 100
params.stopSequences = None
params.systemPrompt = None
params.tools = None
params.toolChoice = None
result = await tool.sampling_callback(Mock(), params)
assert isinstance(result, types.CreateMessageResult)
options = mock_chat_client.get_response.call_args.kwargs.get("options") or {}
assert options["max_tokens"] == 100
async def test_mcp_tool_sampling_callback_rate_limited():
"""Sampling requests beyond sampling_max_requests are rejected per session."""
tool = MCPStdioTool(
name="test_tool",
command="python",
sampling_approval_callback=_approve,
sampling_max_requests=2,
)
mock_chat_client = AsyncMock()
mock_chat_client.get_response.return_value = _make_sampling_response()
tool.client = mock_chat_client
def make_params() -> Mock:
params = Mock()
params.messages = [types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hi"))]
params.temperature = None
params.maxTokens = 100
params.stopSequences = None
params.systemPrompt = None
params.tools = None
params.toolChoice = None
return params
first = await tool.sampling_callback(Mock(), make_params())
second = await tool.sampling_callback(Mock(), make_params())
third = await tool.sampling_callback(Mock(), make_params())
assert isinstance(first, types.CreateMessageResult)
assert isinstance(second, types.CreateMessageResult)
assert isinstance(third, types.ErrorData)
assert third.code == types.INVALID_REQUEST
assert "rate limit" in third.message.lower()
assert mock_chat_client.get_response.await_count == 2
# The counter resets on a session reset.
tool._reset_session_state()
fourth = await tool.sampling_callback(Mock(), make_params())
assert isinstance(fourth, types.CreateMessageResult)
async def test_mcp_tool_sampling_callback_chat_client_exception():
"""Test sampling callback when chat client raises exception."""
tool = MCPStdioTool(name="test_tool", command="python")
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve)
# Mock chat client that raises exception
mock_chat_client = AsyncMock()
@@ -1846,7 +2039,7 @@ async def test_mcp_tool_sampling_callback_chat_client_exception():
mock_message.content.text = "Test question"
params.messages = [mock_message]
params.temperature = None
params.maxTokens = None
params.maxTokens = 100
params.stopSequences = None
params.systemPrompt = None
params.tools = None
@@ -1863,7 +2056,7 @@ async def test_mcp_tool_sampling_callback_no_valid_content():
"""Test sampling callback when response has no valid content types."""
from agent_framework import Message
tool = MCPStdioTool(name="test_tool", command="python")
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve)
# Mock chat client with response containing only invalid content types
mock_chat_client = AsyncMock()
@@ -1892,7 +2085,7 @@ async def test_mcp_tool_sampling_callback_no_valid_content():
mock_message.content.text = "Test question"
params.messages = [mock_message]
params.temperature = None
params.maxTokens = None
params.maxTokens = 100
params.stopSequences = None
params.systemPrompt = None
params.tools = None
@@ -1905,18 +2098,18 @@ async def test_mcp_tool_sampling_callback_no_valid_content():
assert "Failed to get right content types from the response." in result.message
mock_chat_client.get_response.assert_awaited_once()
_, kwargs = mock_chat_client.get_response.await_args
assert kwargs["options"] == {"max_tokens": None}
assert kwargs["options"] == {"max_tokens": 100}
async def test_mcp_tool_sampling_callback_no_response_and_successful_message_creation():
"""Test sampling callback when the chat client returns no response and then valid content."""
tool = MCPStdioTool(name="test_tool", command="python")
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve)
tool.client = AsyncMock()
params = Mock()
params.messages = [types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hi"))]
params.temperature = None
params.maxTokens = None
params.maxTokens = 100
params.stopSequences = None
params.systemPrompt = None
params.tools = None
@@ -1955,7 +2148,7 @@ async def test_mcp_tool_sampling_callback_forwards_system_prompt():
"""Test sampling callback passes systemPrompt as instructions in options."""
from agent_framework import Message
tool = MCPStdioTool(name="test_tool", command="python")
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve)
mock_chat_client = AsyncMock()
mock_response = Mock()
@@ -1972,7 +2165,7 @@ async def test_mcp_tool_sampling_callback_forwards_system_prompt():
mock_message.content.text = "Test question"
params.messages = [mock_message]
params.temperature = None
params.maxTokens = None
params.maxTokens = 100
params.stopSequences = None
params.systemPrompt = "You are a helpful assistant"
params.tools = None
@@ -1990,7 +2183,7 @@ async def test_mcp_tool_sampling_callback_forwards_tools():
"""Test sampling callback converts MCP tools to FunctionTools and passes them in options."""
from agent_framework import FunctionTool, Message
tool = MCPStdioTool(name="test_tool", command="python")
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve)
mock_chat_client = AsyncMock()
mock_response = Mock()
@@ -2013,7 +2206,7 @@ async def test_mcp_tool_sampling_callback_forwards_tools():
mock_message.content.text = "Test question"
params.messages = [mock_message]
params.temperature = None
params.maxTokens = None
params.maxTokens = 100
params.stopSequences = None
params.systemPrompt = None
params.tools = [mcp_tool]
@@ -2036,7 +2229,7 @@ async def test_mcp_tool_sampling_callback_forwards_tool_choice():
"""Test sampling callback passes toolChoice mode in options."""
from agent_framework import Message
tool = MCPStdioTool(name="test_tool", command="python")
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve)
mock_chat_client = AsyncMock()
mock_response = Mock()
@@ -2053,7 +2246,7 @@ async def test_mcp_tool_sampling_callback_forwards_tool_choice():
mock_message.content.text = "Test question"
params.messages = [mock_message]
params.temperature = None
params.maxTokens = None
params.maxTokens = 100
params.stopSequences = None
params.systemPrompt = None
params.tools = None
@@ -2071,7 +2264,7 @@ async def test_mcp_tool_sampling_callback_forwards_empty_system_prompt():
"""Test sampling callback forwards empty string systemPrompt as instructions."""
from agent_framework import Message
tool = MCPStdioTool(name="test_tool", command="python")
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve)
mock_chat_client = AsyncMock()
mock_response = Mock()
@@ -2088,7 +2281,7 @@ async def test_mcp_tool_sampling_callback_forwards_empty_system_prompt():
mock_message.content.text = "Test question"
params.messages = [mock_message]
params.temperature = None
params.maxTokens = None
params.maxTokens = 100
params.stopSequences = None
params.systemPrompt = ""
params.tools = None
@@ -2106,7 +2299,7 @@ async def test_mcp_tool_sampling_callback_forwards_empty_tools_list():
"""Test sampling callback forwards empty tools list in options."""
from agent_framework import Message
tool = MCPStdioTool(name="test_tool", command="python")
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve)
mock_chat_client = AsyncMock()
mock_response = Mock()
@@ -2123,7 +2316,7 @@ async def test_mcp_tool_sampling_callback_forwards_empty_tools_list():
mock_message.content.text = "Test question"
params.messages = [mock_message]
params.temperature = None
params.maxTokens = None
params.maxTokens = 100
params.stopSequences = None
params.systemPrompt = None
params.tools = []
@@ -2141,7 +2334,7 @@ async def test_mcp_tool_sampling_callback_forwards_generation_params_in_options(
"""Test sampling callback passes temperature, max_tokens, and stop in options."""
from agent_framework import Message
tool = MCPStdioTool(name="test_tool", command="python")
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve)
mock_chat_client = AsyncMock()
mock_response = Mock()
@@ -2182,7 +2375,7 @@ async def test_mcp_tool_sampling_callback_omits_temperature_when_none():
"""Test sampling callback does not set temperature in options when it is None."""
from agent_framework import Message
tool = MCPStdioTool(name="test_tool", command="python")
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve)
mock_chat_client = AsyncMock()
mock_response = Mock()
@@ -2219,7 +2412,7 @@ async def test_mcp_tool_sampling_callback_always_passes_max_tokens():
"""Test sampling callback always sets max_tokens in options since maxTokens is a required int field."""
from agent_framework import Message
tool = MCPStdioTool(name="test_tool", command="python")
tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve)
mock_chat_client = AsyncMock()
mock_response = Mock()
@@ -14,6 +14,7 @@ The Model Context Protocol (MCP) is an open standard for connecting AI agents to
| **API Key Authentication** | [`mcp_api_key_auth.py`](mcp_api_key_auth.py) | Demonstrates API key authentication with MCP servers using `header_provider`, runtime invocation kwargs, and a command-line API key argument |
| **GitHub Integration with PAT** | [`mcp_github_pat.py`](mcp_github_pat.py) | Demonstrates connecting to GitHub's MCP server using Personal Access Token (PAT) authentication |
| **Long-Running Task** | [`mcp_long_running_task.py`](mcp_long_running_task.py) | Demonstrates transparent SEP-2663 long-running task handling for MCP tools that advertise `taskSupport=required`. Self-spawns a stdio MCP child server |
| **Sampling Approval** | [`mcp_sampling_approval.py`](mcp_sampling_approval.py) | Demonstrates gating server-initiated `sampling/createMessage` requests with a `sampling_approval_callback`, plus the `sampling_max_tokens` and `sampling_max_requests` guardrails. MCP sampling is denied by default |
## Prerequisites
@@ -0,0 +1,78 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from agent_framework import Agent, MCPStreamableHTTPTool
from agent_framework.openai import OpenAIChatClient
from dotenv import load_dotenv
from mcp import types
# Load environment variables from .env file
load_dotenv()
"""
MCP Sampling Approval Example
MCP servers can send the client a ``sampling/createMessage`` request, asking the
client to run an LLM completion on the server's behalf. Because remote MCP
servers are untrusted third parties, forwarding these server-controlled prompts
to your chat client without review is a confused-deputy risk: a malicious server
could exfiltrate context, force tool calls, or burn through your token budget.
For that reason Agent Framework **denies MCP sampling by default**. To allow it,
pass a ``sampling_approval_callback`` to the MCP tool. The callback receives the
raw ``CreateMessageRequestParams`` and returns ``True`` to approve or ``False``
to deny. It may be synchronous or asynchronous, so you can implement a
human-in-the-loop prompt, a policy check, or an audit log.
Two further guardrails apply to approved requests:
- ``sampling_max_tokens`` caps the server-requested ``maxTokens``.
- ``sampling_max_requests`` limits how many sampling requests a single session
may make.
To restore the legacy "always approve" behavior (only do this for servers you
trust), pass ``sampling_approval_callback=lambda params: True``.
"""
async def approve_sampling(params: types.CreateMessageRequestParams) -> bool:
"""Human-in-the-loop approval gate for server-initiated sampling.
Shows the server-supplied system prompt and messages, then asks the user to
approve or deny. Returning ``False`` rejects the request.
"""
print("\n--- MCP server requested a sampling/createMessage ---")
if params.systemPrompt:
print(f"System prompt: {params.systemPrompt}")
for message in params.messages:
text = getattr(message.content, "text", message.content)
print(f"{message.role}: {text}")
answer = await asyncio.to_thread(input, "Approve this sampling request? [y/N]: ")
return answer.strip().lower() in {"y", "yes"}
async def main() -> None:
"""Run an agent against an MCP server with a sampling approval gate."""
async with Agent(
client=OpenAIChatClient(),
name="Agent",
instructions="You are a helpful assistant. Use your MCP tool when answering the user's question.",
tools=MCPStreamableHTTPTool(
name="MCP tool",
description="MCP tool description.",
url="<your mcp server url>",
# Passing ``client`` enables sampling; the approval callback gates it.
client=OpenAIChatClient(),
sampling_approval_callback=approve_sampling,
sampling_max_tokens=2048,
sampling_max_requests=5,
),
) as agent:
query = "Use your MCP tool to help answer this question."
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text}")
if __name__ == "__main__":
asyncio.run(main())
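
The guardrails described in the sample's docstring (deny by default, approval callback, token clamp, per-session request limit) can be sketched in isolation as follows. This is a minimal illustration and not the framework's implementation: `SamplingGuard`, `Decision`, and `policy_callback` are hypothetical names. The order of the checks, counting only approved requests toward the limit, and the reason strings are all assumptions. The callback here is synchronous only, although the sample notes that the real parameter also accepts an async callable.

```python
from __future__ import annotations

from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, Callable


@dataclass
class Decision:
    """Outcome of evaluating one sampling request (hypothetical type)."""

    allowed: bool
    max_tokens: int | None = None
    reason: str = ""


@dataclass
class SamplingGuard:
    """Hypothetical stand-in for the three sampling guardrails."""

    approval_callback: Callable[[Any], bool] | None = None
    max_tokens: int | None = None
    max_requests: int | None = None
    approved_count: int = 0

    def evaluate(self, params: Any) -> Decision:
        # Assumption: the request limit is checked before the callback runs.
        if self.max_requests is not None and self.approved_count >= self.max_requests:
            return Decision(False, reason="request limit reached")
        # Deny by default when no approval callback is configured.
        if self.approval_callback is None:
            return Decision(False, reason="no approval callback configured")
        if not self.approval_callback(params):
            return Decision(False, reason="denied by approval callback")
        # Assumption: only approved requests count toward the limit.
        self.approved_count += 1
        # Clamp the server-requested maxTokens to the configured ceiling.
        requested = params.maxTokens
        effective = requested if self.max_tokens is None else min(requested, self.max_tokens)
        return Decision(True, max_tokens=effective)


def policy_callback(params: Any) -> bool:
    """Synchronous policy check: reject requests that carry a system prompt."""
    return not params.systemPrompt


if __name__ == "__main__":
    guard = SamplingGuard(approval_callback=policy_callback, max_tokens=2048, max_requests=2)
    # SimpleNamespace stands in for CreateMessageRequestParams in this sketch.
    request = SimpleNamespace(systemPrompt=None, maxTokens=4096)
    print(guard.evaluate(request))
```

A policy callback like this one suits unattended runs where no human can answer a prompt. The interactive `approve_sampling` gate in the sample fits sessions where a person can review each request.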