Files
agent-framework/python/packages/main/agent_framework/_mcp.py
T
Eduard van Valkenburg 6aa746d891 Python: Introducing UserInputRequest and Response types and HostedMcpTool (#405)
* initial work on User Approval (and hosted mcp to validate)

* small update to the comments in the sample

* enable local MCP tools in chatClient get methods

* working streaming and improved setup

* fix for pyright

* updated create_approval -> create_response method

* added tests

* updated HostedMcpTool and addressed feedback

* update type name

* naming updates

* small docstring update

* mypy fix

* fixes and updates

* fixes for responses

* fix int tests

* removed broken tests

* updated test running

* removed specific content check on websearch

* increased timeout

* split slow foundry test

* don't parallel run samples

* add dist load to unit tests

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
2025-09-10 13:37:34 +00:00

707 lines
28 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
import json
import logging
import re
import sys
from abc import abstractmethod
from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore
from datetime import timedelta
from functools import partial
from typing import TYPE_CHECKING, Any
from mcp import types
from mcp.client.session import ClientSession
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.client.websocket import websocket_client
from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError
from mcp.shared.session import RequestResponder
from pydantic import BaseModel, create_model
from ._tools import AIFunction
from ._types import ChatMessage, Contents, DataContent, Role, TextContent, UriContent
from .exceptions import ToolException, ToolExecutionException
if sys.version_info >= (3, 11):
from typing import Self # pragma: no cover
else:
from typing_extensions import Self # pragma: no cover
if TYPE_CHECKING:
from ._clients import ChatClientProtocol
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
# region: Helpers
# Translation table from MCP logging level names to Python `logging` levels.
# Several MCP levels share one Python level: "notice" maps to INFO, and both
# "alert" and "emergency" map to CRITICAL.
LOG_LEVEL_MAPPING: dict[types.LoggingLevel, int] = {
"debug": logging.DEBUG,
"info": logging.INFO,
"notice": logging.INFO,
"warning": logging.WARNING,
"error": logging.ERROR,
"critical": logging.CRITICAL,
"alert": logging.CRITICAL,
"emergency": logging.CRITICAL,
}
# Public API of this module: the three transport-specific MCP tool classes.
__all__ = [
"MCPStdioTool",
"MCPStreamableHTTPTool",
"MCPWebsocketTool",
]
def _mcp_prompt_message_to_chat_message(
    mcp_type: types.PromptMessage | types.SamplingMessage,
) -> ChatMessage:
    """Build a ChatMessage from an MCP prompt or sampling message.

    The MCP role is wrapped in a ``Role``, the message's single content item is
    converted and placed in a one-element ``contents`` list, and the original
    MCP object is attached as ``raw_representation``.
    """
    role = Role(value=mcp_type.role)
    converted_content = _mcp_type_to_ai_content(mcp_type.content)
    return ChatMessage(
        role=role,
        contents=[converted_content],  # type: ignore[call-arg]
        raw_representation=mcp_type,
    )
def _mcp_call_tool_result_to_ai_contents(
    mcp_type: types.CallToolResult,
) -> list[Contents]:
    """Convert every content item of an MCP tool-call result, preserving order."""
    converted: list[Contents] = []
    for entry in mcp_type.content:
        converted.append(_mcp_type_to_ai_content(entry))
    return converted
def _mcp_type_to_ai_content(
    mcp_type: types.ImageContent | types.TextContent | types.AudioContent | types.EmbeddedResource | types.ResourceLink,
) -> Contents:
    """Translate a single MCP content object into its Agent Framework counterpart.

    Text becomes ``TextContent``, image/audio become ``DataContent``, resource
    links become ``UriContent`` and anything else is handled as an embedded
    resource whose text or blob payload is unwrapped. The MCP object is always
    kept as ``raw_representation``.
    """
    if isinstance(mcp_type, types.TextContent):
        return TextContent(text=mcp_type.text, raw_representation=mcp_type)

    if isinstance(mcp_type, (types.ImageContent, types.AudioContent)):
        return DataContent(uri=mcp_type.data, media_type=mcp_type.mimeType, raw_representation=mcp_type)

    if isinstance(mcp_type, types.ResourceLink):
        # Links without a declared mime type are reported as JSON.
        link_media_type = mcp_type.mimeType or "application/json"
        return UriContent(uri=str(mcp_type.uri), media_type=link_media_type, raw_representation=mcp_type)

    # Remaining case: an embedded resource wrapping text or blob contents.
    resource = mcp_type.resource
    if isinstance(resource, types.TextResourceContents):
        return TextContent(
            text=resource.text,
            raw_representation=mcp_type,
            additional_properties=mcp_type.annotations.model_dump() if mcp_type.annotations else None,
        )
    if isinstance(resource, types.BlobResourceContents):
        return DataContent(
            uri=resource.blob,
            media_type=resource.mimeType,
            raw_representation=mcp_type,
            additional_properties=mcp_type.annotations.model_dump() if mcp_type.annotations else None,
        )
def _ai_content_to_mcp_types(
    content: Contents,
) -> types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource | types.ResourceLink | None:
    """Translate an Agent Framework content object into an MCP content object.

    Returns ``None`` for content kinds that have no MCP mapping here, including
    ``DataContent`` whose media type is missing or is not image/audio/application.
    """
    if isinstance(content, TextContent):
        return types.TextContent(type="text", text=content.text)

    if isinstance(content, DataContent):
        media_type = content.media_type
        if not media_type:
            return None
        if media_type.startswith("image/"):
            return types.ImageContent(type="image", data=content.uri, mimeType=media_type)
        if media_type.startswith("audio/"):
            return types.AudioContent(type="audio", data=content.uri, mimeType=media_type)
        if media_type.startswith("application/"):
            # MCP requires a resource uri but does not constrain it. The uri on
            # DataContent is the data uri (the payload), not a resource
            # location, so take one from additional_properties or fall back to
            # a placeholder.
            extra = content.additional_properties
            resource_uri = extra.get("uri", "af://binary") if extra else "af://binary"
            return types.EmbeddedResource(
                type="resource",
                resource=types.BlobResourceContents(
                    blob=content.uri,
                    mimeType=media_type,
                    uri=resource_uri,  # type: ignore[reportArgumentType]
                ),
            )
        return None

    if isinstance(content, UriContent):
        extra = content.additional_properties
        link_name = extra.get("name", "Unknown") if extra else "Unknown"
        return types.ResourceLink(
            type="resource_link",
            uri=content.uri,  # type: ignore[reportArgumentType]
            mimeType=content.media_type,
            name=link_name,
        )

    return None
def _chat_message_to_mcp_types(
    content: ChatMessage,
) -> list[types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource | types.ResourceLink]:
    """Convert each content item of a ChatMessage to MCP, dropping items that do not convert."""
    return [
        converted
        for item in content.contents
        if (converted := _ai_content_to_mcp_types(item))
    ]
def _get_input_model_from_mcp_prompt(prompt: types.Prompt) -> type[BaseModel]:
    """Build a Pydantic input model named ``<prompt name>_input`` for an MCP prompt.

    Every prompt argument becomes a ``str`` field: required arguments have no
    default, all others default to ``None``. A prompt without arguments yields
    a model without fields.
    """
    model_name = f"{prompt.name}_input"
    arguments = prompt.arguments
    if not arguments:
        return create_model(model_name)

    fields: dict[str, Any] = {
        argument.name: (str, Ellipsis if argument.required else None) for argument in arguments
    }
    return create_model(model_name, **fields)
def _get_input_model_from_mcp_tool(tool: types.Tool) -> type[BaseModel]:
"""Creates a Pydantic model from a tools parameters."""
properties = tool.inputSchema.get("properties", None)
required = tool.inputSchema.get("required", [])
# Check if 'properties' is missing or not a dictionary
if not properties:
return create_model(f"{tool.name}_input")
field_definitions: dict[str, Any] = {}
for prop_name, prop_details in properties.items():
prop_details = json.loads(prop_details) if isinstance(prop_details, str) else prop_details
# Map JSON Schema types to Python types
json_type = prop_details.get("type", "string")
python_type: type = str # default
if json_type == "integer":
python_type = int
elif json_type == "number":
python_type = float
elif json_type == "boolean":
python_type = bool
elif json_type == "array":
python_type = list
elif json_type == "object":
python_type = dict
# Create field definition for create_model
if prop_name in required:
field_definitions[prop_name] = (python_type, ...)
else:
default_value = prop_details.get("default", None)
field_definitions[prop_name] = (python_type, default_value)
return create_model(f"{tool.name}_input", **field_definitions)
def _normalize_mcp_name(name: str) -> str:
"""Normalize MCP tool/prompt names to allowed identifier pattern (A-Za-z0-9_.-)."""
return re.sub(r"[^A-Za-z0-9_.-]", "-", name)
# region: MCP Plugin
class MCPTool:
"""Main MCP class, to initialize use one of the subclasses."""
def __init__(
self,
name: str,
description: str | None = None,
additional_properties: dict[str, Any] | None = None,
load_tools: bool = True,
load_prompts: bool = True,
session: ClientSession | None = None,
request_timeout: int | None = None,
chat_client: "ChatClientProtocol | None" = None,
) -> None:
"""Initialize the MCP Plugin Base."""
self.name = name
self.description = description or ""
self.additional_properties = additional_properties
self.load_tools_flag = load_tools
self.load_prompts_flag = load_prompts
self._exit_stack = AsyncExitStack()
self.session = session
self.request_timeout = request_timeout
self.chat_client = chat_client
self.functions: list[AIFunction[Any, Any]] = []
def __str__(self) -> str:
return f"MCPTool(name={self.name}, description={self.description})"
async def connect(self) -> None:
"""Connect to the MCP server."""
if not self.session:
try:
transport = await self._exit_stack.enter_async_context(self.get_mcp_client())
except Exception as ex:
await self._exit_stack.aclose()
raise ToolException(
"Failed to connect to the MCP server. Please check your configuration.", inner_exception=ex
) from ex
try:
session = await self._exit_stack.enter_async_context(
ClientSession(
read_stream=transport[0],
write_stream=transport[1],
read_timeout_seconds=timedelta(seconds=self.request_timeout) if self.request_timeout else None,
message_handler=self.message_handler,
logging_callback=self.logging_callback,
sampling_callback=self.sampling_callback,
)
)
except Exception as ex:
await self._exit_stack.aclose()
raise ToolException(
message="Failed to create a session. Please check your configuration.", inner_exception=ex
) from ex
await session.initialize()
self.session = session
elif self.session._request_id == 0: # type: ignore[reportPrivateUsage]
# If the session is not initialized, we need to reinitialize it
await self.session.initialize()
logger.debug("Connected to MCP server: %s", self.session)
if self.load_tools_flag:
await self.load_tools()
if self.load_prompts_flag:
await self.load_prompts()
if logger.level != logging.NOTSET:
try:
await self.session.set_logging_level(
next(level for level, value in LOG_LEVEL_MAPPING.items() if value == logger.level)
)
except Exception as exc:
logger.warning("Failed to set log level to %s", logger.level, exc_info=exc)
async def sampling_callback(
self, context: RequestContext[ClientSession, Any], params: types.CreateMessageRequestParams
) -> types.CreateMessageResult | types.ErrorData:
"""Callback function for sampling.
This function is called when the MCP server needs to get a message completed.
This is a simple version of this function, it can be overridden to allow more complex sampling.
It get's added to the session at initialization time, so overriding it is the best way to do this.
"""
if not self.chat_client:
return types.ErrorData(
code=types.INTERNAL_ERROR,
message="No chat client available. Please set a chat client.",
)
logger.debug("Sampling callback called with params: %s", params)
messages: list[ChatMessage] = []
for msg in params.messages:
messages.append(_mcp_prompt_message_to_chat_message(msg))
try:
response = await self.chat_client.get_response(
messages,
temperature=params.temperature,
max_tokens=params.maxTokens,
stop=params.stopSequences,
)
except Exception as ex:
return types.ErrorData(
code=types.INTERNAL_ERROR,
message=f"Failed to get chat message content: {ex}",
)
if not response or not response.messages:
return types.ErrorData(
code=types.INTERNAL_ERROR,
message="Failed to get chat message content.",
)
mcp_contents = _chat_message_to_mcp_types(response.messages[0])
# grab the first content that is of type TextContent or ImageContent
mcp_content = next(
(content for content in mcp_contents if isinstance(content, (types.TextContent, types.ImageContent))),
None,
)
if not mcp_content:
return types.ErrorData(
code=types.INTERNAL_ERROR,
message="Failed to get right content types from the response.",
)
return types.CreateMessageResult(
role="assistant",
content=mcp_content,
model=response.ai_model_id or "unknown",
)
async def logging_callback(self, params: types.LoggingMessageNotificationParams) -> None:
"""Callback function for logging.
This function is called when the MCP Server sends a log message.
By default it will log the message to the logger with the level set in the params.
Please subclass the MCP*Plugin and override this function if you want to adapt the behavior.
"""
logger.log(LOG_LEVEL_MAPPING[params.level], params.data)
async def message_handler(
self,
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
"""Handle messages from the MCP server.
By default this function will handle exceptions on the server, by logging those.
And it will trigger a reload of the tools and prompts when the list changed notification is received.
If you want to extend this behavior you can subclass the MCPPlugin and override this function,
if you want to keep the default behavior, make sure to call `super().message_handler(message)`.
"""
if isinstance(message, Exception):
logger.error("Error from MCP server: %s", message, exc_info=message)
return
if isinstance(message, types.ServerNotification):
match message.root.method:
case "notifications/tools/list_changed":
await self.load_tools()
case "notifications/prompts/list_changed":
await self.load_prompts()
case _:
logger.debug("Unhandled notification: %s", message.root.method)
async def load_prompts(self) -> None:
"""Load prompts from the MCP server."""
if not self.session:
raise ToolExecutionException("MCP server not connected, please call connect() before using this method.")
try:
prompt_list = await self.session.list_prompts()
except Exception as exc:
logger.info(
"Prompt could not be loaded, you can exclude trying to load, by setting: load_prompts=False",
exc_info=exc,
)
prompt_list = None
for prompt in prompt_list.prompts if prompt_list else []:
local_name = _normalize_mcp_name(prompt.name)
input_model = _get_input_model_from_mcp_prompt(prompt)
func: AIFunction[BaseModel, list[ChatMessage]] = AIFunction(
func=partial(self.get_prompt, prompt.name),
name=local_name,
description=prompt.description or "",
input_model=input_model,
)
self.functions.append(func)
async def load_tools(self) -> None:
"""Load tools from the MCP server."""
if not self.session:
raise ToolExecutionException("MCP server not connected, please call connect() before using this method.")
try:
tool_list = await self.session.list_tools()
except Exception as exc:
logger.info(
"Tools could not be loaded, you can exclude trying to load, by setting: load_tools=False",
exc_info=exc,
)
tool_list = None
for tool in tool_list.tools if tool_list else []:
local_name = _normalize_mcp_name(tool.name)
input_model = _get_input_model_from_mcp_tool(tool)
# Create AIFunctions out of each tool
func: AIFunction[BaseModel, list[Contents]] = AIFunction(
func=partial(self.call_tool, tool.name),
name=local_name,
description=tool.description or "",
input_model=input_model,
)
self.functions.append(func)
async def close(self) -> None:
"""Disconnect from the MCP server."""
await self._exit_stack.aclose()
self.session = None
@abstractmethod
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
"""Get an MCP client."""
pass
async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Contents]:
"""Call a tool with the given arguments."""
if not self.session:
raise ToolExecutionException("MCP server not connected, please call connect() before using this method.")
if not self.load_tools_flag:
raise ToolExecutionException(
"Tools are not loaded for this server, please set load_tools=True in the constructor."
)
try:
return _mcp_call_tool_result_to_ai_contents(await self.session.call_tool(tool_name, arguments=kwargs))
except McpError as mcp_exc:
raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc
except Exception as ex:
raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex
async def get_prompt(self, prompt_name: str, **kwargs: Any) -> list[ChatMessage]:
"""Call a prompt with the given arguments."""
if not self.session:
raise ToolExecutionException("MCP server not connected, please call connect() before using this method.")
if not self.load_prompts_flag:
raise ToolExecutionException(
"Prompts are not loaded for this server, please set load_prompts=True in the constructor."
)
try:
prompt_result = await self.session.get_prompt(prompt_name, arguments=kwargs)
return [_mcp_prompt_message_to_chat_message(message) for message in prompt_result.messages]
except McpError as mcp_exc:
raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc
except Exception as ex:
raise ToolExecutionException(f"Failed to call prompt '{prompt_name}'.", inner_exception=ex) from ex
async def __aenter__(self) -> Self:
"""Enter the context manager."""
try:
await self.connect()
return self
except ToolException:
raise
except Exception as ex:
await self._exit_stack.aclose()
raise ToolExecutionException("Failed to enter context manager.", inner_exception=ex) from ex
async def __aexit__(
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: Any
) -> None:
"""Exit the context manager."""
await self.close()
# region: MCP Plugin Implementations
class MCPStdioTool(MCPTool):
    """MCP tool that reaches its server through a stdio client."""

    def __init__(
        self,
        name: str,
        command: str,
        *,
        load_tools: bool = True,
        load_prompts: bool = True,
        request_timeout: int | None = None,
        session: ClientSession | None = None,
        description: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        args: list[str] | None = None,
        env: dict[str, str] | None = None,
        encoding: str | None = None,
        chat_client: "ChatClientProtocol | None" = None,
        **kwargs: Any,
    ) -> None:
        """Set up a stdio-based MCP tool.

        ``command``, ``args``, ``env``, ``encoding`` and any extra keyword
        arguments are collected into a ``StdioServerParameters`` object, which
        is handed to ``mcp.client.stdio.stdio_client`` when connecting.

        Args:
            name: The name of the tool.
            command: The command to run the MCP server.
            load_tools: Whether to load tools from the MCP server.
            load_prompts: Whether to load prompts from the MCP server.
            request_timeout: The default timeout used for all requests.
            session: The session to use for the MCP connection.
            description: The description of the tool.
            additional_properties: Additional properties.
            args: The arguments to pass to the command; defaults to an empty list.
            env: The environment variables to set for the command.
            encoding: The encoding to use for the command output.
            chat_client: The chat client to use for sampling.
            kwargs: Extra fields forwarded to ``StdioServerParameters``.
        """
        super().__init__(
            name=name,
            description=description,
            additional_properties=additional_properties,
            load_tools=load_tools,
            load_prompts=load_prompts,
            session=session,
            request_timeout=request_timeout,
            chat_client=chat_client,
        )
        self._client_kwargs = kwargs
        self.command = command
        self.args = args if args else []
        self.env = env
        self.encoding = encoding

    def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
        """Create the stdio client for the configured command."""
        server_params: dict[str, Any] = {"command": self.command, "args": self.args, "env": self.env}
        # Only pass an encoding when one was configured.
        if self.encoding:
            server_params["encoding"] = self.encoding
        # Extra keyword arguments go last, so they win over the named settings.
        server_params = {**server_params, **(self._client_kwargs or {})}
        return stdio_client(server=StdioServerParameters(**server_params))
class MCPStreamableHTTPTool(MCPTool):
    """MCP tool that reaches its server through a streamable HTTP client."""

    def __init__(
        self,
        name: str,
        url: str,
        *,
        load_tools: bool = True,
        load_prompts: bool = True,
        request_timeout: int | None = None,
        session: ClientSession | None = None,
        description: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        headers: dict[str, Any] | None = None,
        timeout: float | None = None,
        sse_read_timeout: float | None = None,
        terminate_on_close: bool | None = None,
        chat_client: "ChatClientProtocol | None" = None,
        **kwargs: Any,
    ) -> None:
        """Set up a streamable-HTTP-based MCP tool.

        The transport settings are forwarded to
        ``mcp.client.streamable_http.streamablehttp_client`` when connecting;
        settings left unset are not passed at all. Extra keyword arguments are
        forwarded to that client as well.

        Args:
            name: The name of the tool.
            url: The URL of the MCP server.
            load_tools: Whether to load tools from the MCP server.
            load_prompts: Whether to load prompts from the MCP server.
            request_timeout: The default timeout used for all requests.
            session: The session to use for the MCP connection.
            description: The description of the tool.
            additional_properties: Additional properties.
            headers: The headers to send with the request.
            timeout: The timeout for the request.
            sse_read_timeout: The timeout for reading from the SSE stream.
            terminate_on_close: Close the transport when the MCP client is terminated.
            chat_client: The chat client to use for sampling.
            kwargs: Extra arguments forwarded to the streamable HTTP client.
        """
        super().__init__(
            name=name,
            description=description,
            additional_properties=additional_properties,
            load_tools=load_tools,
            load_prompts=load_prompts,
            session=session,
            request_timeout=request_timeout,
            chat_client=chat_client,
        )
        self._client_kwargs = kwargs
        self.url = url
        self.headers = headers if headers else {}
        self.timeout = timeout
        self.sse_read_timeout = sse_read_timeout
        self.terminate_on_close = terminate_on_close

    def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
        """Create the streamable HTTP client for the configured URL."""
        client_args: dict[str, Any] = {"url": self.url}
        # Headers are skipped when empty; the remaining options only when None.
        if self.headers:
            client_args["headers"] = self.headers
        optional_settings = {
            "timeout": self.timeout,
            "sse_read_timeout": self.sse_read_timeout,
            "terminate_on_close": self.terminate_on_close,
        }
        client_args.update({key: value for key, value in optional_settings.items() if value is not None})
        # Extra keyword arguments go last, so they win over the named settings.
        client_args = {**client_args, **(self._client_kwargs or {})}
        return streamablehttp_client(**client_args)
class MCPWebsocketTool(MCPTool):
    """MCP tool that reaches its server through a websocket client."""

    def __init__(
        self,
        name: str,
        url: str,
        *,
        load_tools: bool = True,
        load_prompts: bool = True,
        request_timeout: int | None = None,
        session: ClientSession | None = None,
        description: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        chat_client: "ChatClientProtocol | None" = None,
        **kwargs: Any,
    ) -> None:
        """Set up a websocket-based MCP tool.

        The URL and any extra keyword arguments are forwarded to
        ``mcp.client.websocket.websocket_client`` when connecting.

        Args:
            name: The name of the tool.
            url: The URL of the MCP server.
            load_tools: Whether to load tools from the MCP server.
            load_prompts: Whether to load prompts from the MCP server.
            request_timeout: The default timeout used for all requests.
            session: The session to use for the MCP connection.
            description: The description of the tool.
            additional_properties: Additional properties.
            chat_client: The chat client to use for sampling.
            kwargs: Extra arguments forwarded to the websocket client.
        """
        super().__init__(
            name=name,
            description=description,
            additional_properties=additional_properties,
            load_tools=load_tools,
            load_prompts=load_prompts,
            session=session,
            request_timeout=request_timeout,
            chat_client=chat_client,
        )
        self._client_kwargs = kwargs
        self.url = url

    def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
        """Create the websocket client for the configured URL."""
        # Extra keyword arguments go last, so they win over the configured url.
        client_args: dict[str, Any] = {"url": self.url, **(self._client_kwargs or {})}
        return websocket_client(**client_args)