mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
10d10364a9
* cleanup of threads and serialization * fix for sliding window * fix redis test * updated from comments * updated context provider and threads * updated lock * add asyncio default * fix redis tests * fix tests * fix tests * renamed to invoking * fixed tests * fix for instructions
710 lines
28 KiB
Python
710 lines
28 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
import sys
|
|
from abc import abstractmethod
|
|
from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore
|
|
from datetime import timedelta
|
|
from functools import partial
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
from mcp import types
|
|
from mcp.client.session import ClientSession
|
|
from mcp.client.stdio import StdioServerParameters, stdio_client
|
|
from mcp.client.streamable_http import streamablehttp_client
|
|
from mcp.client.websocket import websocket_client
|
|
from mcp.shared.context import RequestContext
|
|
from mcp.shared.exceptions import McpError
|
|
from mcp.shared.session import RequestResponder
|
|
from pydantic import BaseModel, create_model
|
|
|
|
from ._tools import AIFunction
|
|
from ._types import ChatMessage, Contents, DataContent, Role, TextContent, UriContent
|
|
from .exceptions import ToolException, ToolExecutionException
|
|
|
|
if sys.version_info >= (3, 11):
|
|
from typing import Self # pragma: no cover
|
|
else:
|
|
from typing_extensions import Self # pragma: no cover
|
|
|
|
if TYPE_CHECKING:
|
|
from ._clients import ChatClientProtocol
|
|
|
|
# Module-level logger. MCPTool.connect() reads its level and logging_callback writes to it.
logger = logging.getLogger(__name__)

# region: Helpers

# Translation table from MCP logging level names to standard library logging levels.
# Several MCP names share one Python level: "notice" is logged as INFO, and
# "alert" / "emergency" are both logged as CRITICAL.
LOG_LEVEL_MAPPING: dict[types.LoggingLevel, int] = {
    "debug": logging.DEBUG,
    "info": logging.INFO,
    "notice": logging.INFO,
    "warning": logging.WARNING,
    "error": logging.ERROR,
    "critical": logging.CRITICAL,
    "alert": logging.CRITICAL,
    "emergency": logging.CRITICAL,
}

# Public API of this module: the three transport-specific tool classes.
__all__ = [
    "MCPStdioTool",
    "MCPStreamableHTTPTool",
    "MCPWebsocketTool",
]
|
|
|
|
|
|
def _mcp_prompt_message_to_chat_message(
    mcp_type: types.PromptMessage | types.SamplingMessage,
) -> ChatMessage:
    """Build an Agent Framework ChatMessage from an MCP prompt or sampling message.

    Args:
        mcp_type: The MCP message to convert.

    Returns:
        A ChatMessage whose role wraps the MCP role, whose contents hold the single
        converted MCP content item, and whose raw representation is the MCP message.
    """
    message_role = Role(value=mcp_type.role)
    # An MCP message carries exactly one content item, so the list has one entry.
    converted = _mcp_type_to_ai_content(mcp_type.content)
    return ChatMessage(
        role=message_role,
        contents=[converted],  # type: ignore[call-arg]
        raw_representation=mcp_type,
    )
|
|
|
|
|
|
def _mcp_call_tool_result_to_ai_contents(
    mcp_type: types.CallToolResult,
) -> list[Contents]:
    """Convert every content item of an MCP tool call result to an Agent Framework content.

    Args:
        mcp_type: The result returned by an MCP tool call.

    Returns:
        The converted contents, in the same order as in the MCP result.
    """
    converted: list[Contents] = []
    for mcp_item in mcp_type.content:
        converted.append(_mcp_type_to_ai_content(mcp_item))
    return converted
|
|
|
|
|
|
def _mcp_type_to_ai_content(
    mcp_type: types.ImageContent | types.TextContent | types.AudioContent | types.EmbeddedResource | types.ResourceLink,
) -> Contents:
    """Convert a single MCP content item to the matching Agent Framework content.

    Text becomes TextContent, image and audio become DataContent, and a resource link
    becomes UriContent (media type defaults to "application/json"). Anything else is
    treated as an embedded resource and converted based on the kind of resource it holds.

    Args:
        mcp_type: The MCP content item to convert.

    Returns:
        The converted content, with the MCP item stored as its raw representation.
    """
    if isinstance(mcp_type, types.TextContent):
        return TextContent(text=mcp_type.text, raw_representation=mcp_type)

    if isinstance(mcp_type, (types.ImageContent, types.AudioContent)):
        return DataContent(uri=mcp_type.data, media_type=mcp_type.mimeType, raw_representation=mcp_type)

    if isinstance(mcp_type, types.ResourceLink):
        return UriContent(
            uri=str(mcp_type.uri),
            media_type=mcp_type.mimeType or "application/json",
            raw_representation=mcp_type,
        )

    # Remaining case: an embedded resource, which wraps either text or blob contents.
    embedded = mcp_type.resource
    if isinstance(embedded, types.TextResourceContents):
        return TextContent(
            text=embedded.text,
            raw_representation=mcp_type,
            additional_properties=mcp_type.annotations.model_dump() if mcp_type.annotations else None,
        )
    if isinstance(embedded, types.BlobResourceContents):
        return DataContent(
            uri=embedded.blob,
            media_type=embedded.mimeType,
            raw_representation=mcp_type,
            additional_properties=mcp_type.annotations.model_dump() if mcp_type.annotations else None,
        )
|
|
|
|
|
|
def _ai_content_to_mcp_types(
    content: Contents,
) -> types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource | types.ResourceLink | None:
    """Convert an Agent Framework content item to its MCP counterpart.

    Args:
        content: The content item to convert.

    Returns:
        The MCP content, or None when the content type (or, for DataContent, its
        media type) has no MCP equivalent handled here.
    """
    if isinstance(content, TextContent):
        return types.TextContent(type="text", text=content.text)

    if isinstance(content, DataContent):
        media_type = content.media_type
        if not media_type:
            return None
        if media_type.startswith("image/"):
            return types.ImageContent(type="image", data=content.uri, mimeType=media_type)
        if media_type.startswith("audio/"):
            return types.AudioContent(type="audio", data=content.uri, mimeType=media_type)
        if media_type.startswith("application/"):
            # uri's are not limited in MCP but they have to be set.
            # the uri of data content, contains the data uri, which
            # is not the uri meant here, UriContent would match this.
            properties = content.additional_properties
            resource_uri = properties.get("uri", "af://binary") if properties else "af://binary"
            return types.EmbeddedResource(
                type="resource",
                resource=types.BlobResourceContents(
                    blob=content.uri,
                    mimeType=media_type,
                    uri=resource_uri,  # type: ignore[reportArgumentType]
                ),
            )
        return None

    if isinstance(content, UriContent):
        properties = content.additional_properties
        link_name = properties.get("name", "Unknown") if properties else "Unknown"
        return types.ResourceLink(
            type="resource_link",
            uri=content.uri,  # type: ignore[reportArgumentType]
            mimeType=content.media_type,
            name=link_name,
        )

    return None
|
|
|
|
|
|
def _chat_message_to_mcp_types(
    content: ChatMessage,
) -> list[types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource | types.ResourceLink]:
    """Convert the contents of a ChatMessage to MCP content items.

    Args:
        content: The chat message whose contents are converted.

    Returns:
        The converted items in their original order; contents without an MCP
        equivalent are left out.
    """
    return [converted for item in content.contents if (converted := _ai_content_to_mcp_types(item))]
|
|
|
|
|
|
def _get_input_model_from_mcp_prompt(prompt: types.Prompt) -> type[BaseModel]:
    """Create a Pydantic input model named "<prompt name>_input" for an MCP prompt.

    Every prompt argument becomes a ``str`` field: required arguments have no default,
    optional arguments default to None. A prompt without arguments yields an empty model.

    Args:
        prompt: The MCP prompt definition.

    Returns:
        The generated Pydantic model class.
    """
    model_name = f"{prompt.name}_input"
    if not prompt.arguments:
        return create_model(model_name)

    fields: dict[str, Any] = {
        argument.name: (str, (... if argument.required else None)) for argument in prompt.arguments
    }
    return create_model(model_name, **fields)
|
|
|
|
|
|
def _get_input_model_from_mcp_tool(tool: types.Tool) -> type[BaseModel]:
    """Create a Pydantic input model named "<tool name>_input" from an MCP tool's input schema.

    Each entry under the schema's "properties" becomes a field. The JSON Schema "type"
    selects the Python type (unknown or missing types fall back to ``str``). Properties
    listed under "required" have no default; the others use the schema's "default" or None.

    Args:
        tool: The MCP tool definition.

    Returns:
        The generated Pydantic model class; an empty model when the schema has no properties.
    """
    schema = tool.inputSchema
    properties = schema.get("properties", None)
    required = schema.get("required", [])
    model_name = f"{tool.name}_input"
    if not properties:
        return create_model(model_name)

    json_to_python: dict[str, type] = {
        "integer": int,
        "number": float,
        "boolean": bool,
        "array": list,
        "object": dict,
    }

    fields: dict[str, Any] = {}
    for prop_name, raw_details in properties.items():
        # A property description may arrive as a JSON string; decode it first.
        details = json.loads(raw_details) if isinstance(raw_details, str) else raw_details

        json_type = details.get("type", "string")
        # Only plain string type names are looked up; anything else maps to str.
        python_type: type = json_to_python.get(json_type, str) if isinstance(json_type, str) else str

        if prop_name in required:
            fields[prop_name] = (python_type, ...)
        else:
            fields[prop_name] = (python_type, details.get("default", None))

    return create_model(model_name, **fields)
|
|
|
|
|
|
def _normalize_mcp_name(name: str) -> str:
|
|
"""Normalize MCP tool/prompt names to allowed identifier pattern (A-Za-z0-9_.-)."""
|
|
return re.sub(r"[^A-Za-z0-9_.-]", "-", name)
|
|
|
|
|
|
# region: MCP Plugin
|
|
|
|
|
|
class MCPTool:
|
|
"""Main MCP class, to initialize use one of the subclasses."""
|
|
|
|
def __init__(
|
|
self,
|
|
name: str,
|
|
description: str | None = None,
|
|
additional_properties: dict[str, Any] | None = None,
|
|
load_tools: bool = True,
|
|
load_prompts: bool = True,
|
|
session: ClientSession | None = None,
|
|
request_timeout: int | None = None,
|
|
chat_client: "ChatClientProtocol | None" = None,
|
|
) -> None:
|
|
"""Initialize the MCP Plugin Base."""
|
|
self.name = name
|
|
self.description = description or ""
|
|
self.additional_properties = additional_properties
|
|
self.load_tools_flag = load_tools
|
|
self.load_prompts_flag = load_prompts
|
|
self._exit_stack = AsyncExitStack()
|
|
self.session = session
|
|
self.request_timeout = request_timeout
|
|
self.chat_client = chat_client
|
|
self.functions: list[AIFunction[Any, Any]] = []
|
|
self.is_connected: bool = False
|
|
|
|
def __str__(self) -> str:
|
|
return f"MCPTool(name={self.name}, description={self.description})"
|
|
|
|
async def connect(self) -> None:
|
|
"""Connect to the MCP server."""
|
|
if not self.session:
|
|
try:
|
|
transport = await self._exit_stack.enter_async_context(self.get_mcp_client())
|
|
except Exception as ex:
|
|
await self._exit_stack.aclose()
|
|
raise ToolException(
|
|
"Failed to connect to the MCP server. Please check your configuration.", inner_exception=ex
|
|
) from ex
|
|
try:
|
|
session = await self._exit_stack.enter_async_context(
|
|
ClientSession(
|
|
read_stream=transport[0],
|
|
write_stream=transport[1],
|
|
read_timeout_seconds=timedelta(seconds=self.request_timeout) if self.request_timeout else None,
|
|
message_handler=self.message_handler,
|
|
logging_callback=self.logging_callback,
|
|
sampling_callback=self.sampling_callback,
|
|
)
|
|
)
|
|
except Exception as ex:
|
|
await self._exit_stack.aclose()
|
|
raise ToolException(
|
|
message="Failed to create a session. Please check your configuration.", inner_exception=ex
|
|
) from ex
|
|
await session.initialize()
|
|
self.session = session
|
|
elif self.session._request_id == 0: # type: ignore[reportPrivateUsage]
|
|
# If the session is not initialized, we need to reinitialize it
|
|
await self.session.initialize()
|
|
logger.debug("Connected to MCP server: %s", self.session)
|
|
self.is_connected = True
|
|
if self.load_tools_flag:
|
|
await self.load_tools()
|
|
if self.load_prompts_flag:
|
|
await self.load_prompts()
|
|
|
|
if logger.level != logging.NOTSET:
|
|
try:
|
|
await self.session.set_logging_level(
|
|
next(level for level, value in LOG_LEVEL_MAPPING.items() if value == logger.level)
|
|
)
|
|
except Exception as exc:
|
|
logger.warning("Failed to set log level to %s", logger.level, exc_info=exc)
|
|
|
|
async def sampling_callback(
|
|
self, context: RequestContext[ClientSession, Any], params: types.CreateMessageRequestParams
|
|
) -> types.CreateMessageResult | types.ErrorData:
|
|
"""Callback function for sampling.
|
|
|
|
This function is called when the MCP server needs to get a message completed.
|
|
|
|
This is a simple version of this function, it can be overridden to allow more complex sampling.
|
|
It get's added to the session at initialization time, so overriding it is the best way to do this.
|
|
"""
|
|
if not self.chat_client:
|
|
return types.ErrorData(
|
|
code=types.INTERNAL_ERROR,
|
|
message="No chat client available. Please set a chat client.",
|
|
)
|
|
logger.debug("Sampling callback called with params: %s", params)
|
|
messages: list[ChatMessage] = []
|
|
for msg in params.messages:
|
|
messages.append(_mcp_prompt_message_to_chat_message(msg))
|
|
try:
|
|
response = await self.chat_client.get_response(
|
|
messages,
|
|
temperature=params.temperature,
|
|
max_tokens=params.maxTokens,
|
|
stop=params.stopSequences,
|
|
)
|
|
except Exception as ex:
|
|
return types.ErrorData(
|
|
code=types.INTERNAL_ERROR,
|
|
message=f"Failed to get chat message content: {ex}",
|
|
)
|
|
if not response or not response.messages:
|
|
return types.ErrorData(
|
|
code=types.INTERNAL_ERROR,
|
|
message="Failed to get chat message content.",
|
|
)
|
|
mcp_contents = _chat_message_to_mcp_types(response.messages[0])
|
|
# grab the first content that is of type TextContent or ImageContent
|
|
mcp_content = next(
|
|
(content for content in mcp_contents if isinstance(content, (types.TextContent, types.ImageContent))),
|
|
None,
|
|
)
|
|
if not mcp_content:
|
|
return types.ErrorData(
|
|
code=types.INTERNAL_ERROR,
|
|
message="Failed to get right content types from the response.",
|
|
)
|
|
return types.CreateMessageResult(
|
|
role="assistant",
|
|
content=mcp_content,
|
|
model=response.ai_model_id or "unknown",
|
|
)
|
|
|
|
async def logging_callback(self, params: types.LoggingMessageNotificationParams) -> None:
|
|
"""Callback function for logging.
|
|
|
|
This function is called when the MCP Server sends a log message.
|
|
By default it will log the message to the logger with the level set in the params.
|
|
|
|
Please subclass the MCP*Plugin and override this function if you want to adapt the behavior.
|
|
"""
|
|
logger.log(LOG_LEVEL_MAPPING[params.level], params.data)
|
|
|
|
async def message_handler(
|
|
self,
|
|
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
|
) -> None:
|
|
"""Handle messages from the MCP server.
|
|
|
|
By default this function will handle exceptions on the server, by logging those.
|
|
|
|
And it will trigger a reload of the tools and prompts when the list changed notification is received.
|
|
|
|
If you want to extend this behavior you can subclass the MCPPlugin and override this function,
|
|
if you want to keep the default behavior, make sure to call `super().message_handler(message)`.
|
|
"""
|
|
if isinstance(message, Exception):
|
|
logger.error("Error from MCP server: %s", message, exc_info=message)
|
|
return
|
|
if isinstance(message, types.ServerNotification):
|
|
match message.root.method:
|
|
case "notifications/tools/list_changed":
|
|
await self.load_tools()
|
|
case "notifications/prompts/list_changed":
|
|
await self.load_prompts()
|
|
case _:
|
|
logger.debug("Unhandled notification: %s", message.root.method)
|
|
|
|
async def load_prompts(self) -> None:
|
|
"""Load prompts from the MCP server."""
|
|
if not self.session:
|
|
raise ToolExecutionException("MCP server not connected, please call connect() before using this method.")
|
|
try:
|
|
prompt_list = await self.session.list_prompts()
|
|
except Exception as exc:
|
|
logger.info(
|
|
"Prompt could not be loaded, you can exclude trying to load, by setting: load_prompts=False",
|
|
exc_info=exc,
|
|
)
|
|
prompt_list = None
|
|
for prompt in prompt_list.prompts if prompt_list else []:
|
|
local_name = _normalize_mcp_name(prompt.name)
|
|
input_model = _get_input_model_from_mcp_prompt(prompt)
|
|
func: AIFunction[BaseModel, list[ChatMessage]] = AIFunction(
|
|
func=partial(self.get_prompt, prompt.name),
|
|
name=local_name,
|
|
description=prompt.description or "",
|
|
input_model=input_model,
|
|
)
|
|
self.functions.append(func)
|
|
|
|
async def load_tools(self) -> None:
|
|
"""Load tools from the MCP server."""
|
|
if not self.session:
|
|
raise ToolExecutionException("MCP server not connected, please call connect() before using this method.")
|
|
try:
|
|
tool_list = await self.session.list_tools()
|
|
except Exception as exc:
|
|
logger.info(
|
|
"Tools could not be loaded, you can exclude trying to load, by setting: load_tools=False",
|
|
exc_info=exc,
|
|
)
|
|
tool_list = None
|
|
for tool in tool_list.tools if tool_list else []:
|
|
local_name = _normalize_mcp_name(tool.name)
|
|
input_model = _get_input_model_from_mcp_tool(tool)
|
|
# Create AIFunctions out of each tool
|
|
func: AIFunction[BaseModel, list[Contents]] = AIFunction(
|
|
func=partial(self.call_tool, tool.name),
|
|
name=local_name,
|
|
description=tool.description or "",
|
|
input_model=input_model,
|
|
)
|
|
self.functions.append(func)
|
|
|
|
async def close(self) -> None:
|
|
"""Disconnect from the MCP server."""
|
|
await self._exit_stack.aclose()
|
|
self.session = None
|
|
self.is_connected = False
|
|
|
|
@abstractmethod
|
|
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
|
|
"""Get an MCP client."""
|
|
pass
|
|
|
|
async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Contents]:
|
|
"""Call a tool with the given arguments."""
|
|
if not self.session:
|
|
raise ToolExecutionException("MCP server not connected, please call connect() before using this method.")
|
|
if not self.load_tools_flag:
|
|
raise ToolExecutionException(
|
|
"Tools are not loaded for this server, please set load_tools=True in the constructor."
|
|
)
|
|
try:
|
|
return _mcp_call_tool_result_to_ai_contents(await self.session.call_tool(tool_name, arguments=kwargs))
|
|
except McpError as mcp_exc:
|
|
raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc
|
|
except Exception as ex:
|
|
raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex
|
|
|
|
async def get_prompt(self, prompt_name: str, **kwargs: Any) -> list[ChatMessage]:
|
|
"""Call a prompt with the given arguments."""
|
|
if not self.session:
|
|
raise ToolExecutionException("MCP server not connected, please call connect() before using this method.")
|
|
if not self.load_prompts_flag:
|
|
raise ToolExecutionException(
|
|
"Prompts are not loaded for this server, please set load_prompts=True in the constructor."
|
|
)
|
|
try:
|
|
prompt_result = await self.session.get_prompt(prompt_name, arguments=kwargs)
|
|
return [_mcp_prompt_message_to_chat_message(message) for message in prompt_result.messages]
|
|
except McpError as mcp_exc:
|
|
raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc
|
|
except Exception as ex:
|
|
raise ToolExecutionException(f"Failed to call prompt '{prompt_name}'.", inner_exception=ex) from ex
|
|
|
|
async def __aenter__(self) -> Self:
|
|
"""Enter the context manager."""
|
|
try:
|
|
await self.connect()
|
|
return self
|
|
except ToolException:
|
|
raise
|
|
except Exception as ex:
|
|
await self._exit_stack.aclose()
|
|
raise ToolExecutionException("Failed to enter context manager.", inner_exception=ex) from ex
|
|
|
|
async def __aexit__(
|
|
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: Any
|
|
) -> None:
|
|
"""Exit the context manager."""
|
|
await self.close()
|
|
|
|
|
|
# region: MCP Plugin Implementations
|
|
|
|
|
|
class MCPStdioTool(MCPTool):
    """MCP tool that reaches its server through a stdio client."""

    def __init__(
        self,
        name: str,
        command: str,
        *,
        load_tools: bool = True,
        load_prompts: bool = True,
        request_timeout: int | None = None,
        session: ClientSession | None = None,
        description: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        args: list[str] | None = None,
        env: dict[str, str] | None = None,
        encoding: str | None = None,
        chat_client: "ChatClientProtocol | None" = None,
        **kwargs: Any,
    ) -> None:
        """Set up an MCP tool that uses the stdio transport.

        The command, args, env, encoding and any extra keyword arguments are combined
        into a StdioServerParameters object, which is handed to
        mcp.client.stdio.stdio_client when the client is requested.

        Args:
            name: The name of the tool.
            command: The command to run the MCP server.
            load_tools: Whether to load tools from the MCP server.
            load_prompts: Whether to load prompts from the MCP server.
            request_timeout: The default timeout used for all requests.
            session: The session to use for the MCP connection.
            description: The description of the tool.
            additional_properties: Additional properties.
            args: The arguments to pass to the command.
            env: The environment variables to set for the command.
            encoding: The encoding to use for the command output.
            chat_client: The chat client to use for sampling.
            kwargs: Extra server parameters; they override the values above on key clashes.
        """
        super().__init__(
            name=name,
            description=description,
            additional_properties=additional_properties,
            load_tools=load_tools,
            load_prompts=load_prompts,
            session=session,
            request_timeout=request_timeout,
            chat_client=chat_client,
        )
        self.command = command
        self.args = args or []
        self.env = env
        self.encoding = encoding
        self._client_kwargs = kwargs

    def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
        """Create the stdio client for the configured command."""
        server_params: dict[str, Any] = {"command": self.command, "args": self.args, "env": self.env}
        # The encoding is only forwarded when one was given.
        if self.encoding:
            server_params["encoding"] = self.encoding
        # Extra keyword arguments are applied last, so they win on key clashes.
        server_params.update(self._client_kwargs)
        return stdio_client(server=StdioServerParameters(**server_params))
|
|
|
|
|
|
class MCPStreamableHTTPTool(MCPTool):
    """MCP tool that reaches its server through a streamable HTTP client."""

    def __init__(
        self,
        name: str,
        url: str,
        *,
        load_tools: bool = True,
        load_prompts: bool = True,
        request_timeout: int | None = None,
        session: ClientSession | None = None,
        description: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        headers: dict[str, Any] | None = None,
        timeout: float | None = None,
        sse_read_timeout: float | None = None,
        terminate_on_close: bool | None = None,
        chat_client: "ChatClientProtocol | None" = None,
        **kwargs: Any,
    ) -> None:
        """Set up an MCP tool that uses the streamable HTTP transport.

        The url, the optional settings that were given, and any extra keyword arguments
        are passed to mcp.client.streamable_http.streamablehttp_client when the client
        is requested.

        Args:
            name: The name of the tool.
            url: The URL of the MCP server.
            load_tools: Whether to load tools from the MCP server.
            load_prompts: Whether to load prompts from the MCP server.
            request_timeout: The default timeout used for all requests.
            session: The session to use for the MCP connection.
            description: The description of the tool.
            additional_properties: Additional properties.
            headers: The headers to send with the request.
            timeout: The timeout for the request.
            sse_read_timeout: The timeout for reading from the SSE stream.
            terminate_on_close: Close the transport when the MCP client is terminated.
            chat_client: The chat client to use for sampling.
            kwargs: Any extra arguments to pass to the streamable http client.
        """
        super().__init__(
            name=name,
            description=description,
            additional_properties=additional_properties,
            load_tools=load_tools,
            load_prompts=load_prompts,
            session=session,
            request_timeout=request_timeout,
            chat_client=chat_client,
        )
        self.url = url
        self.headers = headers or {}
        self.timeout = timeout
        self.sse_read_timeout = sse_read_timeout
        self.terminate_on_close = terminate_on_close
        self._client_kwargs = kwargs

    def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
        """Create the streamable HTTP client for the configured URL."""
        client_args: dict[str, Any] = {"url": self.url}
        # Headers are only forwarded when non-empty.
        if self.headers:
            client_args["headers"] = self.headers
        # These settings are only forwarded when explicitly set, so the client's own
        # defaults apply otherwise.
        optional_settings = {
            "timeout": self.timeout,
            "sse_read_timeout": self.sse_read_timeout,
            "terminate_on_close": self.terminate_on_close,
        }
        for key, value in optional_settings.items():
            if value is not None:
                client_args[key] = value
        # Extra keyword arguments are applied last, so they win on key clashes.
        client_args.update(self._client_kwargs)
        return streamablehttp_client(**client_args)
|
|
|
|
|
|
class MCPWebsocketTool(MCPTool):
    """MCP tool that reaches its server through a websocket client."""

    def __init__(
        self,
        name: str,
        url: str,
        *,
        load_tools: bool = True,
        load_prompts: bool = True,
        request_timeout: int | None = None,
        session: ClientSession | None = None,
        description: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        chat_client: "ChatClientProtocol | None" = None,
        **kwargs: Any,
    ) -> None:
        """Set up an MCP tool that uses the websocket transport.

        The url and any extra keyword arguments are passed to
        mcp.client.websocket.websocket_client when the client is requested.

        Args:
            name: The name of the tool.
            url: The URL of the MCP server.
            load_tools: Whether to load tools from the MCP server.
            load_prompts: Whether to load prompts from the MCP server.
            request_timeout: The default timeout used for all requests.
            session: The session to use for the MCP connection.
            description: The description of the tool.
            additional_properties: Additional properties.
            chat_client: The chat client to use for sampling.
            kwargs: Any extra arguments to pass to the websocket client.
        """
        super().__init__(
            name=name,
            description=description,
            additional_properties=additional_properties,
            load_tools=load_tools,
            load_prompts=load_prompts,
            session=session,
            request_timeout=request_timeout,
            chat_client=chat_client,
        )
        self.url = url
        self._client_kwargs = kwargs

    def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
        """Create the websocket client for the configured URL."""
        # Extra keyword arguments are applied after the url, so they win on key clashes.
        client_args: dict[str, Any] = {"url": self.url}
        client_args.update(self._client_kwargs)
        return websocket_client(**client_args)
|