Python: [BREAKING] Reduce core dependencies and simplify optional integrations (#4904)

* improved dependencies and some fixes

* fix for mypy

* improve mcp
This commit is contained in:
Eduard van Valkenburg
2026-03-25 19:03:43 +01:00
committed by GitHub
Unverified
parent 49d69b3bf5
commit c012aac5f2
14 changed files with 985 additions and 440 deletions
+2
View File
@@ -25,8 +25,10 @@ classifiers = [
dependencies = [
"agent-framework-core>=1.0.0rc5",
"agent-framework-openai>=1.0.0rc5",
"azure-ai-projects>=2.0.0,<3.0",
"azure-ai-agents>=1.2.0b5,<1.2.0b6",
"azure-ai-inference>=1.0.0b9,<1.0.0b10",
"azure-identity>=1,<2",
"aiohttp>=3.7.0,<4",
]
@@ -24,9 +24,6 @@ from typing import (
)
from uuid import uuid4
from mcp import types
from mcp.server.lowlevel import Server
from mcp.shared.exceptions import McpError
from pydantic import BaseModel
from . import _tools as _tool_utils # pyright: ignore[reportPrivateUsage]
@@ -71,6 +68,9 @@ else:
from typing_extensions import Self, TypedDict # pragma: no cover
if TYPE_CHECKING:
from mcp import types
from mcp.server.lowlevel import Server
from ._compaction import CompactionStrategy, TokenizerProtocol
from ._types import ChatOptions
@@ -1369,6 +1369,15 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
Returns:
The MCP server instance.
"""
try:
from mcp import types
from mcp.server.lowlevel import Server
from mcp.shared.exceptions import McpError
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"`mcp` is required to use `Agent.as_mcp_server()`. Please install `mcp`."
) from exc
server_args: dict[str, Any] = {
"name": server_name,
"version": version,
@@ -1469,8 +1478,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
class Agent(
AgentTelemetryLayer,
AgentMiddlewareLayer,
AgentTelemetryLayer,
RawAgent[OptionsCoT],
Generic[OptionsCoT],
):
+338 -300
View File
@@ -13,25 +13,12 @@ from collections.abc import Callable, Collection, Sequence
from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore
from datetime import timedelta
from functools import partial
from typing import TYPE_CHECKING, Any, Literal, TypedDict
from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast
import httpx
from anyio import ClosedResourceError
from mcp import types
from mcp.client.session import ClientSession
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.client.streamable_http import streamable_http_client
from mcp.client.websocket import websocket_client
from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError
from mcp.shared.session import RequestResponder
from opentelemetry import propagate
from ._tools import FunctionTool
from ._types import (
Content,
Message,
)
from ._types import Content, Message
from .exceptions import ToolException, ToolExecutionException
if sys.version_info >= (3, 11):
@@ -40,9 +27,18 @@ else:
from typing_extensions import Self # pragma: no cover
if TYPE_CHECKING:
from httpx import AsyncClient
from mcp import types
from mcp.client.session import ClientSession
from mcp.shared.context import RequestContext
from mcp.shared.session import RequestResponder
from ._clients import SupportsChatGetResponse
logger = logging.getLogger(__name__)
class MCPSpecificApproval(TypedDict, total=False):
"""Represents the specific approval mode for an MCP tool.
@@ -57,13 +53,12 @@ class MCPSpecificApproval(TypedDict, total=False):
never_require_approval: Collection[str] | None
logger = logging.getLogger(__name__)
_MCP_REMOTE_NAME_KEY = "_mcp_remote_name"
_MCP_NORMALIZED_NAME_KEY = "_mcp_normalized_name"
# region: Helpers
LOG_LEVEL_MAPPING: dict[types.LoggingLevel, int] = {
LOG_LEVEL_MAPPING: dict[str, int] = {
"debug": logging.DEBUG,
"info": logging.INFO,
"notice": logging.INFO,
@@ -75,269 +70,6 @@ LOG_LEVEL_MAPPING: dict[types.LoggingLevel, int] = {
}
def _parse_prompt_result_from_mcp(
mcp_type: types.GetPromptResult,
) -> str:
"""Parse an MCP GetPromptResult directly into a string representation.
Converts each message in the prompt result to its string form and combines them.
Args:
mcp_type: The MCP GetPromptResult object to convert.
Returns:
A string representation of the prompt result.
"""
parts: list[str] = []
for message in mcp_type.messages:
content = message.content
if isinstance(content, types.TextContent):
parts.append(content.text)
elif isinstance(content, (types.ImageContent, types.AudioContent)):
parts.append(
json.dumps(
{
"type": "image" if isinstance(content, types.ImageContent) else "audio",
"data": content.data,
"mimeType": content.mimeType,
},
default=str,
)
)
elif isinstance(content, types.EmbeddedResource):
match content.resource:
case types.TextResourceContents():
parts.append(content.resource.text)
case types.BlobResourceContents():
parts.append(
json.dumps(
{
"type": "blob",
"data": content.resource.blob,
"mimeType": content.resource.mimeType,
},
default=str,
)
)
else:
parts.append(str(content))
if not parts:
return ""
if len(parts) == 1:
return parts[0]
return json.dumps(parts, default=str)
def _parse_message_from_mcp(
mcp_type: types.PromptMessage | types.SamplingMessage,
) -> Message:
"""Parse an MCP container type into an Agent Framework type."""
return Message(
role=mcp_type.role,
contents=_parse_content_from_mcp(mcp_type.content),
raw_representation=mcp_type,
)
def _parse_tool_result_from_mcp(
mcp_type: types.CallToolResult,
) -> list[Content]:
"""Parse an MCP CallToolResult into a list of Content items.
Converts each content item in the MCP result to its appropriate
Content form. Text items become ``Content(type="text")`` and media
items (images, audio) are preserved as rich Content.
Args:
mcp_type: The MCP CallToolResult object to convert.
Returns:
A list of Content items representing the tool result.
"""
result: list[Content] = []
for item in mcp_type.content:
match item:
case types.TextContent():
result.append(Content.from_text(item.text))
case types.ImageContent() | types.AudioContent():
decoded = base64.b64decode(item.data)
result.append(
Content.from_data(
data=decoded,
media_type=item.mimeType,
)
)
case types.ResourceLink():
result.append(
Content.from_uri(
uri=str(item.uri),
media_type=item.mimeType,
)
)
case types.EmbeddedResource():
match item.resource:
case types.TextResourceContents():
result.append(Content.from_text(item.resource.text))
case types.BlobResourceContents():
blob = item.resource.blob
mime = item.resource.mimeType or "application/octet-stream"
if not blob.startswith("data:"):
blob = f"data:{mime};base64,{blob}"
result.append(
Content.from_uri(
uri=blob,
media_type=mime,
)
)
case _:
result.append(Content.from_text(str(item)))
if not result:
result.append(Content.from_text("null"))
return result
def _parse_content_from_mcp(
mcp_type: types.ImageContent
| types.TextContent
| types.AudioContent
| types.EmbeddedResource
| types.ResourceLink
| types.ToolUseContent
| types.ToolResultContent
| Sequence[
types.ImageContent
| types.TextContent
| types.AudioContent
| types.EmbeddedResource
| types.ResourceLink
| types.ToolUseContent
| types.ToolResultContent
],
) -> list[Content]:
"""Parse an MCP type into an Agent Framework type."""
mcp_types = mcp_type if isinstance(mcp_type, Sequence) else [mcp_type]
return_types: list[Content] = []
for mcp_type in mcp_types:
match mcp_type:
case types.TextContent():
return_types.append(Content.from_text(text=mcp_type.text, raw_representation=mcp_type))
case types.ImageContent() | types.AudioContent():
# MCP protocol uses base64-encoded strings, convert to bytes
data_bytes = base64.b64decode(mcp_type.data) if isinstance(mcp_type.data, str) else mcp_type.data
return_types.append(
Content.from_data(
data=data_bytes,
media_type=mcp_type.mimeType,
raw_representation=mcp_type,
)
)
case types.ResourceLink():
return_types.append(
Content.from_uri(
uri=str(mcp_type.uri),
media_type=mcp_type.mimeType or "application/json",
raw_representation=mcp_type,
)
)
case types.ToolUseContent():
return_types.append(
Content.from_function_call(
call_id=mcp_type.id,
name=mcp_type.name,
arguments=mcp_type.input,
raw_representation=mcp_type,
)
)
case types.ToolResultContent():
return_types.append(
Content.from_function_result(
call_id=mcp_type.toolUseId,
result=_parse_content_from_mcp(mcp_type.content)
if mcp_type.content
else mcp_type.structuredContent,
exception=str(Exception()) if mcp_type.isError else None, # type: ignore[arg-type]
raw_representation=mcp_type,
)
)
case types.EmbeddedResource():
match mcp_type.resource:
case types.TextResourceContents():
return_types.append(
Content.from_text(
text=mcp_type.resource.text,
raw_representation=mcp_type,
additional_properties=(
mcp_type.annotations.model_dump() if mcp_type.annotations else None
),
)
)
case types.BlobResourceContents():
return_types.append(
Content.from_uri(
uri=mcp_type.resource.blob,
media_type=mcp_type.resource.mimeType,
raw_representation=mcp_type,
additional_properties=(
mcp_type.annotations.model_dump() if mcp_type.annotations else None
),
)
)
return return_types
def _prepare_content_for_mcp(
content: Content,
) -> types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource | types.ResourceLink | None:
"""Prepare an Agent Framework content type for MCP."""
if content.type == "text":
return types.TextContent(type="text", text=content.text) # type: ignore[attr-defined]
if content.type == "data":
if content.media_type and content.media_type.startswith("image/"): # type: ignore[attr-defined]
return types.ImageContent(type="image", data=content.uri, mimeType=content.media_type) # type: ignore[attr-defined]
if content.media_type and content.media_type.startswith("audio/"): # type: ignore[attr-defined]
return types.AudioContent(type="audio", data=content.uri, mimeType=content.media_type) # type: ignore[attr-defined]
if content.media_type and content.media_type.startswith("application/"): # type: ignore[attr-defined]
return types.EmbeddedResource(
type="resource",
resource=types.BlobResourceContents(
blob=content.uri, # type: ignore[attr-defined]
mimeType=content.media_type, # type: ignore[attr-defined]
# uri's are not limited in MCP but they have to be set.
# the uri of data content, contains the data uri, which
# is not the uri meant here, UriContent would match this.
uri=(
content.additional_properties.get("uri", "af://binary")
if content.additional_properties
else "af://binary"
), # type: ignore[reportArgumentType]
),
)
return None
if content.type == "uri":
return types.ResourceLink(
type="resource_link",
uri=content.uri, # type: ignore[reportArgumentType,attr-defined]
mimeType=content.media_type, # type: ignore[attr-defined]
name=(content.additional_properties.get("name", "Unknown") if content.additional_properties else "Unknown"),
)
return None
def _prepare_message_for_mcp(
content: Message,
) -> list[types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource | types.ResourceLink]:
"""Prepare a Message for MCP format."""
messages: list[
types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource | types.ResourceLink
] = []
for item in content.contents:
mcp_content = _prepare_content_for_mcp(item)
if mcp_content:
messages.append(mcp_content)
return messages
def _get_input_model_from_mcp_prompt(prompt: types.Prompt) -> dict[str, Any]:
"""Get the input model from an MCP prompt.
@@ -462,8 +194,8 @@ class MCPTool:
``Callable[[types.GetPromptResult], str]`` that overrides the default prompt
result parsing. When ``None`` (the default), the built-in parser converts
MCP prompt results to a string. If you need per-function result parsing,
access the ``.functions`` list after connecting and set ``result_parser`` on
individual ``FunctionTool`` instances.
access the ``.functions`` list after connecting and set ``result_parser`` on
individual ``FunctionTool`` instances.
session: An existing MCP client session to use.
request_timeout: Timeout in seconds for MCP requests.
client: A chat client for sampling callbacks.
@@ -495,6 +227,264 @@ class MCPTool:
def __str__(self) -> str:
return f"MCPTool(name={self.name}, description={self.description})"
def _parse_prompt_result_from_mcp(
self,
mcp_type: types.GetPromptResult,
) -> str:
"""Parse an MCP GetPromptResult directly into a string representation."""
from mcp import types
parts: list[str] = []
for message in mcp_type.messages:
content = message.content
if isinstance(content, types.TextContent):
parts.append(content.text)
elif isinstance(content, (types.ImageContent, types.AudioContent)):
parts.append(
json.dumps(
{
"type": "image" if isinstance(content, types.ImageContent) else "audio",
"data": content.data,
"mimeType": content.mimeType,
},
default=str,
)
)
elif isinstance(content, types.EmbeddedResource):
match content.resource:
case types.TextResourceContents():
parts.append(content.resource.text)
case types.BlobResourceContents():
parts.append(
json.dumps(
{
"type": "blob",
"data": content.resource.blob,
"mimeType": content.resource.mimeType,
},
default=str,
)
)
else:
parts.append(str(content))
if not parts:
return ""
if len(parts) == 1:
return parts[0]
return json.dumps(parts, default=str)
def _parse_message_from_mcp(
self,
mcp_type: types.PromptMessage | types.SamplingMessage,
) -> Message:
"""Parse an MCP container type into an Agent Framework type."""
return Message(
role=mcp_type.role,
contents=self._parse_content_from_mcp(mcp_type.content),
raw_representation=mcp_type,
)
def _parse_tool_result_from_mcp(
self,
mcp_type: types.CallToolResult,
) -> list[Content]:
"""Parse an MCP CallToolResult into a list of Content items."""
from mcp import types
result: list[Content] = []
for item in mcp_type.content:
match item:
case types.TextContent():
result.append(Content.from_text(item.text))
case types.ImageContent() | types.AudioContent():
decoded = base64.b64decode(item.data)
result.append(
Content.from_data(
data=decoded,
media_type=item.mimeType,
)
)
case types.ResourceLink():
result.append(
Content.from_uri(
uri=str(item.uri),
media_type=item.mimeType,
)
)
case types.EmbeddedResource():
match item.resource:
case types.TextResourceContents():
result.append(Content.from_text(item.resource.text))
case types.BlobResourceContents():
blob = item.resource.blob
mime = item.resource.mimeType or "application/octet-stream"
if not blob.startswith("data:"):
blob = f"data:{mime};base64,{blob}"
result.append(
Content.from_uri(
uri=blob,
media_type=mime,
)
)
case _:
result.append(Content.from_text(str(item)))
if not result:
result.append(Content.from_text("null"))
return result
def _parse_content_from_mcp(
self,
mcp_type: types.ImageContent
| types.TextContent
| types.AudioContent
| types.EmbeddedResource
| types.ResourceLink
| types.ToolUseContent
| types.ToolResultContent
| Sequence[
types.ImageContent
| types.TextContent
| types.AudioContent
| types.EmbeddedResource
| types.ResourceLink
| types.ToolUseContent
| types.ToolResultContent
],
) -> list[Content]:
"""Parse an MCP type into an Agent Framework type."""
from mcp import types
mcp_content_types: Sequence[Any] = (
cast(Sequence[Any], mcp_type) if isinstance(mcp_type, Sequence) else [mcp_type]
) # type: ignore[redundant-cast]
return_types: list[Content] = []
for mcp_type in mcp_content_types:
match mcp_type:
case types.TextContent():
return_types.append(Content.from_text(text=mcp_type.text, raw_representation=mcp_type))
case types.ImageContent() | types.AudioContent():
data_bytes = base64.b64decode(mcp_type.data) if isinstance(mcp_type.data, str) else mcp_type.data
return_types.append(
Content.from_data(
data=data_bytes,
media_type=mcp_type.mimeType,
raw_representation=mcp_type,
)
)
case types.ResourceLink():
return_types.append(
Content.from_uri(
uri=str(mcp_type.uri),
media_type=mcp_type.mimeType or "application/json",
raw_representation=mcp_type,
)
)
case types.ToolUseContent():
return_types.append(
Content.from_function_call(
call_id=mcp_type.id,
name=mcp_type.name,
arguments=mcp_type.input,
raw_representation=mcp_type,
)
)
case types.ToolResultContent():
return_types.append(
Content.from_function_result(
call_id=mcp_type.toolUseId,
result=self._parse_content_from_mcp(mcp_type.content)
if mcp_type.content
else mcp_type.structuredContent,
exception=str(Exception()) if mcp_type.isError else None, # type: ignore[arg-type]
raw_representation=mcp_type,
)
)
case types.EmbeddedResource():
match mcp_type.resource:
case types.TextResourceContents():
return_types.append(
Content.from_text(
text=mcp_type.resource.text,
raw_representation=mcp_type,
additional_properties=(
mcp_type.annotations.model_dump() if mcp_type.annotations else None
),
)
)
case types.BlobResourceContents():
return_types.append(
Content.from_uri(
uri=mcp_type.resource.blob,
media_type=mcp_type.resource.mimeType,
raw_representation=mcp_type,
additional_properties=(
mcp_type.annotations.model_dump() if mcp_type.annotations else None
),
)
)
case _:
pass
return return_types
def _prepare_content_for_mcp(
self,
content: Content,
) -> (
types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource | types.ResourceLink | None
):
"""Prepare an Agent Framework content type for MCP."""
from mcp import types
if content.type == "text":
return types.TextContent(type="text", text=content.text) # type: ignore[attr-defined]
if content.type == "data":
if content.media_type and content.media_type.startswith("image/"): # type: ignore[attr-defined]
return types.ImageContent(type="image", data=content.uri, mimeType=content.media_type) # type: ignore[attr-defined]
if content.media_type and content.media_type.startswith("audio/"): # type: ignore[attr-defined]
return types.AudioContent(type="audio", data=content.uri, mimeType=content.media_type) # type: ignore[attr-defined]
if content.media_type and content.media_type.startswith("application/"): # type: ignore[attr-defined]
return types.EmbeddedResource(
type="resource",
resource=types.BlobResourceContents(
blob=content.uri, # type: ignore[attr-defined]
mimeType=content.media_type, # type: ignore[attr-defined]
uri=(
content.additional_properties.get("uri", "af://binary")
if content.additional_properties
else "af://binary"
), # type: ignore[arg-type]
),
)
return None
if content.type == "uri":
resource_name = (
content.additional_properties.get("name", "Unknown") if content.additional_properties else "Unknown"
)
return types.ResourceLink(
type="resource_link",
uri=content.uri, # type: ignore[arg-type,attr-defined]
mimeType=content.media_type, # type: ignore[attr-defined]
name=resource_name,
)
return None
def _prepare_message_for_mcp(
self,
content: Message,
) -> list[
types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource | types.ResourceLink
]:
"""Prepare a Message for MCP format."""
messages: list[
types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource | types.ResourceLink
] = []
for item in content.contents:
mcp_content = self._prepare_content_for_mcp(item)
if mcp_content:
messages.append(mcp_content)
return messages
@property
def functions(self) -> list[FunctionTool]:
"""Get the list of functions that are allowed."""
@@ -649,8 +639,16 @@ class MCPTool:
error_msg = f"Failed to connect to MCP server: {ex}"
raise ToolException(error_msg, inner_exception=ex) from ex
try:
try:
from mcp.client.session import ClientSession as runtime_client_session
except ModuleNotFoundError as ex:
await self._safe_close_exit_stack()
raise ToolException(
"MCP support requires `mcp`. Please install `mcp`.",
inner_exception=ex,
) from ex
session = await self._exit_stack.enter_async_context(
ClientSession(
runtime_client_session(
read_stream=transport[0],
write_stream=transport[1],
read_timeout_seconds=(
@@ -681,7 +679,7 @@ class MCPTool:
error_msg = f"MCP server failed to initialize: {ex}"
raise ToolException(error_msg, inner_exception=ex) from ex
self.session = session
elif self.session._request_id == 0: # type: ignore[reportPrivateUsage]
elif self.session._request_id == 0: # type: ignore[attr-defined]
# If the session is not initialized, we need to reinitialize it
await self.session.initialize()
logger.debug("Connected to MCP server: %s", self.session)
@@ -695,9 +693,10 @@ class MCPTool:
if logger.level != logging.NOTSET:
try:
await self.session.set_logging_level(
next(level for level, value in LOG_LEVEL_MAPPING.items() if value == logger.level)
level_name = cast(
Any, next(level for level, value in LOG_LEVEL_MAPPING.items() if value == logger.level)
)
await self.session.set_logging_level(level_name)
except Exception as exc:
logger.warning("Failed to set log level to %s", logger.level, exc_info=exc)
@@ -723,6 +722,8 @@ class MCPTool:
Returns:
Either a CreateMessageResult with the generated message or ErrorData if generation fails.
"""
from mcp import types
if not self.client:
return types.ErrorData(
code=types.INTERNAL_ERROR,
@@ -731,7 +732,7 @@ class MCPTool:
logger.debug("Sampling callback called with params: %s", params)
messages: list[Message] = []
for msg in params.messages:
messages.append(_parse_message_from_mcp(msg))
messages.append(self._parse_message_from_mcp(msg))
try:
response = await self.client.get_response(
messages,
@@ -749,7 +750,7 @@ class MCPTool:
code=types.INTERNAL_ERROR,
message="Failed to get chat message content.",
)
mcp_contents = _prepare_message_for_mcp(response.messages[0])
mcp_contents = self._prepare_message_for_mcp(response.messages[0])
# grab the first content that is of type TextContent or ImageContent
mcp_content = next(
(content for content in mcp_contents if isinstance(content, (types.TextContent, types.ImageContent))),
@@ -798,6 +799,8 @@ class MCPTool:
Args:
message: The message from the MCP server (request responder, notification, or exception).
"""
from mcp import types
if isinstance(message, Exception):
logger.error("Error from MCP server: %s", message, exc_info=message)
return
@@ -824,7 +827,7 @@ class MCPTool:
):
return "never_require"
return None
return self.approval_mode # type: ignore[reportReturnType]
return self.approval_mode # type: ignore[return-value]
async def load_prompts(self) -> None:
"""Load prompts from the MCP server.
@@ -835,6 +838,8 @@ class MCPTool:
Raises:
ToolExecutionException: If the MCP server is not connected.
"""
from mcp import types
# Track existing function names to prevent duplicates
existing_names = {func.name for func in self._functions}
@@ -883,6 +888,8 @@ class MCPTool:
Raises:
ToolExecutionException: If the MCP server is not connected.
"""
from mcp import types
# Track existing function names to prevent duplicates
existing_names = {func.name for func in self._functions}
@@ -996,6 +1003,9 @@ class MCPTool:
ToolExecutionException: If the MCP server is not connected, tools are not loaded,
or the tool call fails.
"""
from anyio import ClosedResourceError
from mcp.shared.exceptions import McpError
if not self.load_tools_flag:
raise ToolExecutionException(
"Tools are not loaded for this server, please set load_tools=True in the constructor."
@@ -1025,8 +1035,7 @@ class MCPTool:
# Inject OpenTelemetry trace context into MCP _meta for distributed tracing.
otel_meta = _inject_otel_into_mcp_meta()
parser = self.parse_tool_results or _parse_tool_result_from_mcp
parser = self.parse_tool_results or self._parse_tool_result_from_mcp
# Try the operation, reconnecting once if the connection is closed
for attempt in range(2):
try:
@@ -1062,7 +1071,8 @@ class MCPTool:
inner_exception=cl_ex,
) from cl_ex
except McpError as mcp_exc:
raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc
error_message = mcp_exc.error.message
raise ToolExecutionException(error_message, inner_exception=mcp_exc) from mcp_exc
except Exception as ex:
raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex
raise ToolExecutionException(f"Failed to call tool '{tool_name}' after retries.")
@@ -1083,13 +1093,15 @@ class MCPTool:
ToolExecutionException: If the MCP server is not connected, prompts are not loaded,
or the prompt call fails.
"""
from anyio import ClosedResourceError
from mcp.shared.exceptions import McpError
if not self.load_prompts_flag:
raise ToolExecutionException(
"Prompts are not loaded for this server, please set load_prompts=True in the constructor."
)
parser = self.parse_prompt_results or _parse_prompt_result_from_mcp
parser = self.parse_prompt_results or self._parse_prompt_result_from_mcp
# Try the operation, reconnecting once if the connection is closed
for attempt in range(2):
try:
@@ -1115,7 +1127,8 @@ class MCPTool:
inner_exception=cl_ex,
) from cl_ex
except McpError as mcp_exc:
raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc
error_message = mcp_exc.error.message
raise ToolExecutionException(error_message, inner_exception=mcp_exc) from mcp_exc
except Exception as ex:
raise ToolExecutionException(f"Failed to call prompt '{prompt_name}'.", inner_exception=ex) from ex
raise ToolExecutionException(f"Failed to get prompt '{prompt_name}' after retries.")
@@ -1289,6 +1302,11 @@ class MCPStdioTool(MCPTool):
args["encoding"] = self.encoding
if self._client_kwargs:
args.update(self._client_kwargs)
try:
from mcp.client.stdio import StdioServerParameters, stdio_client
except ModuleNotFoundError as ex:
raise ModuleNotFoundError("`mcp` is required to use `MCPStdioTool`. Please install `mcp`.") from ex
return stdio_client(server=StdioServerParameters(**args))
@@ -1333,7 +1351,7 @@ class MCPStreamableHTTPTool(MCPTool):
terminate_on_close: bool | None = None,
client: SupportsChatGetResponse | None = None,
additional_properties: dict[str, Any] | None = None,
http_client: httpx.AsyncClient | None = None,
http_client: AsyncClient | None = None,
**kwargs: Any,
) -> None:
"""Initialize the MCP streamable HTTP tool.
@@ -1341,7 +1359,7 @@ class MCPStreamableHTTPTool(MCPTool):
Note:
The arguments are used to create a streamable HTTP client using the
new ``mcp.client.streamable_http.streamable_http_client`` API.
If an httpx.AsyncClient is provided via ``http_client``, it will be used directly.
If an asyncClient is provided via ``http_client``, it will be used directly.
Otherwise, the ``streamable_http_client`` API will create and manage a default client.
Args:
@@ -1377,10 +1395,10 @@ class MCPStreamableHTTPTool(MCPTool):
additional_properties: Additional properties.
terminate_on_close: Close the transport when the MCP client is terminated.
client: The chat client to use for sampling.
http_client: Optional httpx.AsyncClient to use. If not provided, the
http_client: Optional asyncClient to use. If not provided, the
``streamable_http_client`` API will create and manage a default client.
To configure headers, timeouts, or other HTTP client settings, create
and pass your own ``httpx.AsyncClient`` instance.
and pass your own ``asyncClient`` instance.
kwargs: Additional keyword arguments (accepted for backward compatibility but not used).
"""
super().__init__(
@@ -1400,7 +1418,7 @@ class MCPStreamableHTTPTool(MCPTool):
)
self.url = url
self.terminate_on_close = terminate_on_close
self._httpx_client: httpx.AsyncClient | None = http_client
self._httpx_client: AsyncClient | None = http_client
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
"""Get an MCP streamable HTTP client.
@@ -1408,6 +1426,11 @@ class MCPStreamableHTTPTool(MCPTool):
Returns:
An async context manager for the streamable HTTP client transport.
"""
try:
from mcp.client.streamable_http import streamable_http_client
except ModuleNotFoundError as ex:
raise ModuleNotFoundError("`mcp` is required to use `MCPStreamableHTTPTool`. Please install `mcp`.") from ex
# Pass the http_client (which may be None) to streamable_http_client
return streamable_http_client(
url=self.url,
@@ -1522,6 +1545,21 @@ class MCPWebsocketTool(MCPTool):
Returns:
An async context manager for the WebSocket client transport.
"""
try:
from mcp.client.websocket import websocket_client
except ModuleNotFoundError as ex:
missing_name = ex.name or "mcp/websocket dependencies"
if missing_name == "mcp" or missing_name.startswith("mcp."):
reason = "The `mcp` package is not installed."
elif missing_name == "websockets" or missing_name.startswith("websockets."):
reason = "WebSocket transport support is not installed."
else:
reason = f"The optional dependency `{missing_name}` is not installed."
raise ModuleNotFoundError(
f"`MCPWebsocketTool` requires websocket transport support. {reason} "
"Please install `mcp[ws]` and update your dependencies."
) from ex
args: dict[str, Any] = {
"url": self.url,
}
@@ -27,9 +27,6 @@ from typing import TYPE_CHECKING, Any, ClassVar, Final, Generic, Literal, TypedD
from dotenv import load_dotenv
from opentelemetry import metrics, trace
from opentelemetry.sdk.resources import Resource
from opentelemetry.semconv.attributes import service_attributes
from opentelemetry.semconv_ai import Meters
from . import __version__ as version_info
from ._settings import load_settings
@@ -43,6 +40,7 @@ if TYPE_CHECKING: # pragma: no cover
from opentelemetry.sdk._logs.export import LogRecordExporter
from opentelemetry.sdk.metrics.export import MetricExporter
from opentelemetry.sdk.metrics.view import View
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.export import SpanExporter
from opentelemetry.trace import Tracer
from opentelemetry.util._decorator import _AgnosticContextManager # type: ignore[reportPrivateUsage]
@@ -206,6 +204,8 @@ class OtelAttr(str, Enum):
TOOL_RESULT = "gen_ai.tool.call.result"
# Agent attributes
AGENT_ID = "gen_ai.agent.id"
SERVICE_NAME = "service.name"
SERVICE_VERSION = "service.version"
# Client attributes
# replaced TOKEN with T, because both ruff and bandit,
# complain about TOKEN being a potential secret
@@ -214,6 +214,8 @@ class OtelAttr(str, Enum):
T_TYPE_INPUT = "input"
T_TYPE_OUTPUT = "output"
DURATION_UNIT = "s"
LLM_OPERATION_DURATION = "gen_ai.client.operation.duration"
LLM_TOKEN_USAGE = "gen_ai.client.token.usage" # nosec B105 # noqa: S105 - OpenTelemetry metric name, not a secret.
# Agent attributes
AGENT_NAME = "gen_ai.agent.name"
@@ -224,8 +226,6 @@ class OtelAttr(str, Enum):
INPUT_MESSAGES = "gen_ai.input.messages"
OUTPUT_MESSAGES = "gen_ai.output.messages"
SYSTEM_INSTRUCTIONS = "gen_ai.system_instructions"
# Attributes previously from opentelemetry-semantic-conventions-ai SpanAttributes,
# removed in v0.4.14. Defined here for forward compatibility.
SYSTEM = "gen_ai.system"
REQUEST_MAX_TOKENS = "gen_ai.request.max_tokens"
REQUEST_TEMPERATURE = "gen_ai.request.temperature"
@@ -282,7 +282,7 @@ class OtelAttr(str, Enum):
CHAT_COMPLETION_OPERATION = "chat"
EMBEDDING_OPERATION = "embeddings"
TOOL_EXECUTION_OPERATION = "execute_tool"
# Describes GenAI agent creation and is usually applicable when working with remote agent services.
# Describes GenAI agent creation and is usually applicable when working with remote agent services.
AGENT_CREATE_OPERATION = "create_agent"
AGENT_INVOKE_OPERATION = "invoke_agent"
@@ -576,25 +576,27 @@ def create_resource(
# Load from custom .env file
resource = create_resource(env_file_path="config/.env")
"""
# Load environment variables from a .env file only when explicitly provided
try:
from opentelemetry.sdk.resources import Resource
except ModuleNotFoundError as ex:
raise ModuleNotFoundError(
"`opentelemetry-sdk` is required to use `create_resource()`. "
"Please install `opentelemetry-sdk` and update your dependencies."
) from ex
if env_file_path is not None:
load_dotenv(dotenv_path=env_file_path, encoding=env_file_encoding)
# Start with provided attributes
resource_attributes: dict[str, Any] = dict(attributes)
# Set service name
if service_name is None:
service_name = os.getenv("OTEL_SERVICE_NAME", "agent_framework")
resource_attributes[service_attributes.SERVICE_NAME] = service_name
resource_attributes[OtelAttr.SERVICE_NAME] = service_name
# Set service version
if service_version is None:
service_version = os.getenv("OTEL_SERVICE_VERSION", version_info)
resource_attributes[service_attributes.SERVICE_VERSION] = service_version
resource_attributes[OtelAttr.SERVICE_VERSION] = service_version
# Parse OTEL_RESOURCE_ATTRIBUTES environment variable
# Format: key1=value1,key2=value2
if resource_attrs_env := os.getenv("OTEL_RESOURCE_ATTRIBUTES"):
resource_attributes.update(_parse_headers(resource_attrs_env))
return Resource.create(resource_attributes)
@@ -602,10 +604,15 @@ def create_resource(
def create_metric_views() -> list[View]:
"""Create the default OpenTelemetry metric views for Agent Framework."""
from opentelemetry.sdk.metrics.view import DropAggregation, View
try:
from opentelemetry.sdk.metrics.view import DropAggregation, View
except ModuleNotFoundError as ex:
raise ModuleNotFoundError(
"`opentelemetry-sdk` is required to use `create_metric_views()`. "
"Please install `opentelemetry-sdk` and update your dependencies."
) from ex
return [
# Dropping all enable_instrumentation names except for those starting with "agent_framework"
View(instrument_name="agent_framework*"),
View(instrument_name="gen_ai*"),
View(instrument_name="*", aggregation=DropAggregation()),
@@ -659,7 +666,7 @@ class ObservabilitySettings:
"""
def __init__(self, **kwargs: Any) -> None:
"""Initialize the settings and create the resource."""
"""Initialize the settings."""
env_file_path = kwargs.pop("env_file_path", None)
env_file_encoding = kwargs.pop("env_file_encoding", None)
data = load_settings(
@@ -674,10 +681,6 @@ class ObservabilitySettings:
self.vs_code_extension_port: int | None = data.get("vs_code_extension_port")
self.env_file_path = env_file_path
self.env_file_encoding = env_file_encoding
self._resource = create_resource(
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
self._executed_setup = False
@property
@@ -762,17 +765,27 @@ class ObservabilitySettings:
exporters: A list of exporters for logs, metrics and/or spans.
views: Optional list of OpenTelemetry views for metrics. Default is empty list.
"""
from opentelemetry._logs import set_logger_provider
from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler
from opentelemetry.sdk._logs.export import BatchLogRecordProcessor, LogRecordExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import MetricExporter, PeriodicExportingMetricReader
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter
try:
from opentelemetry._logs import set_logger_provider
from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler
from opentelemetry.sdk._logs.export import BatchLogRecordProcessor, LogRecordExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import MetricExporter, PeriodicExportingMetricReader
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter
except ModuleNotFoundError as ex:
raise ModuleNotFoundError(
"`opentelemetry-sdk` is required to use `configure_otel_providers()`. "
"Please install `opentelemetry-sdk` and update your dependencies."
) from ex
span_exporters: list[SpanExporter] = []
log_exporters: list[LogRecordExporter] = []
metric_exporters: list[MetricExporter] = []
resource = create_resource(
env_file_path=self.env_file_path,
env_file_encoding=self.env_file_encoding,
)
for exp in exporters:
if isinstance(exp, SpanExporter):
span_exporters.append(exp)
@@ -783,14 +796,14 @@ class ObservabilitySettings:
# Tracing
if span_exporters:
tracer_provider = TracerProvider(resource=self._resource)
tracer_provider = TracerProvider(resource=resource)
trace.set_tracer_provider(tracer_provider)
for exporter in span_exporters:
tracer_provider.add_span_processor(BatchSpanProcessor(exporter))
# Logging
if log_exporters:
logger_provider = LoggerProvider(resource=self._resource)
logger_provider = LoggerProvider(resource=resource)
for log_exporter in log_exporters:
logger_provider.add_log_record_processor(BatchLogRecordProcessor(log_exporter))
# Attach a handler with the provider to the root logger
@@ -805,7 +818,7 @@ class ObservabilitySettings:
PeriodicExportingMetricReader(exporter, export_interval_millis=5000)
for exporter in metric_exporters
],
resource=self._resource,
resource=resource,
views=views or [],
)
metrics.set_meter_provider(meter_provider)
@@ -1106,7 +1119,6 @@ def configure_otel_providers(
OBSERVABILITY_SETTINGS.vs_code_extension_port = updated_settings.vs_code_extension_port
OBSERVABILITY_SETTINGS.env_file_path = updated_settings.env_file_path
OBSERVABILITY_SETTINGS.env_file_encoding = updated_settings.env_file_encoding
OBSERVABILITY_SETTINGS._resource = updated_settings._resource # type: ignore[reportPrivateUsage]
OBSERVABILITY_SETTINGS._executed_setup = False # type: ignore[reportPrivateUsage]
else:
# Re-read settings from current environment in case env vars were set
@@ -1123,7 +1135,6 @@ def configure_otel_providers(
OBSERVABILITY_SETTINGS.vs_code_extension_port = (
vs_code_extension_port if vs_code_extension_port is not None else _read_int_env("VS_CODE_EXTENSION_PORT")
)
OBSERVABILITY_SETTINGS._resource = create_resource() # type: ignore[reportPrivateUsage]
OBSERVABILITY_SETTINGS._executed_setup = False # type: ignore[reportPrivateUsage]
OBSERVABILITY_SETTINGS._configure( # type: ignore[reportPrivateUsage]
@@ -1137,7 +1148,7 @@ def configure_otel_providers(
def _get_duration_histogram() -> metrics.Histogram:
return get_meter().create_histogram(
name=Meters.LLM_OPERATION_DURATION,
name=OtelAttr.LLM_OPERATION_DURATION,
unit=OtelAttr.DURATION_UNIT,
description="Captures the duration of operations of function-invoking chat clients",
explicit_bucket_boundaries_advisory=OPERATION_DURATION_BUCKET_BOUNDARIES,
@@ -1146,7 +1157,7 @@ def _get_duration_histogram() -> metrics.Histogram:
def _get_token_usage_histogram() -> metrics.Histogram:
return get_meter().create_histogram(
name=Meters.LLM_TOKEN_USAGE,
name=OtelAttr.LLM_TOKEN_USAGE,
unit=OtelAttr.T_UNIT,
description="Captures the token usage of chat clients",
explicit_bucket_boundaries_advisory=TOKEN_USAGE_BUCKET_BOUNDARIES,
+2 -10
View File
@@ -23,28 +23,20 @@ classifiers = [
"Typing :: Typed",
]
dependencies = [
# utilities
"typing-extensions>=4.15.0,<5",
"pydantic>=2,<3",
"python-dotenv>=1,<2",
# telemetry
"opentelemetry-api>=1.39.0,<2",
"opentelemetry-sdk>=1.39.0,<2",
"opentelemetry-semantic-conventions-ai>=0.4.13,<0.4.14",
# connectors and functions
"openai>=1.99.0,<3",
"azure-identity>=1,<2",
"azure-ai-projects>=2.0.0,<3.0",
"mcp[ws]>=1.24.0,<2",
"packaging>=24.1,<25",
]
[project.optional-dependencies]
all = [
"mcp>=1.24.0,<2",
"agent-framework-a2a",
"agent-framework-ag-ui",
"agent-framework-azure-ai-search",
"agent-framework-anthropic",
"agent-framework-openai",
"agent-framework-claude",
"agent-framework-azure-ai",
"agent-framework-azurefunctions",
+1 -1
View File
@@ -67,7 +67,7 @@ def span_exporter(monkeypatch, enable_instrumentation: bool, enable_sensitive_da
if enable_instrumentation or enable_sensitive_data:
from opentelemetry.sdk.trace import TracerProvider
tracer_provider = TracerProvider(resource=observability_settings._resource)
tracer_provider = TracerProvider(resource=observability.create_resource())
trace.set_tracer_provider(tracer_provider)
monkeypatch.setattr(observability, "OBSERVABILITY_SETTINGS", observability_settings, raising=False) # type: ignore
+336 -37
View File
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.
# type: ignore[reportPrivateUsage]
import json
import logging
import os
from contextlib import _AsyncGeneratorContextManager # type: ignore
@@ -23,13 +24,9 @@ from agent_framework import (
)
from agent_framework._mcp import (
MCPTool,
_build_prefixed_mcp_name,
_get_input_model_from_mcp_prompt,
_normalize_mcp_name,
_parse_content_from_mcp,
_parse_message_from_mcp,
_parse_tool_result_from_mcp,
_prepare_content_for_mcp,
_prepare_message_for_mcp,
logger,
)
from agent_framework._middleware import FunctionMiddlewarePipeline
@@ -50,6 +47,9 @@ def _mcp_result_to_text(result: str | list[Content]) -> str:
return text or str(result)
_HELPER_MCP_TOOL = MCPTool(name="helper")
# Helper function tests
def test_normalize_mcp_name():
"""Test MCP name normalization."""
@@ -61,6 +61,10 @@ def test_normalize_mcp_name():
assert _normalize_mcp_name("name/with\\slashes") == "name-with-slashes"
def test_build_prefixed_mcp_name_ignores_empty_normalized_prefix() -> None:
assert _build_prefixed_mcp_name("search", "---") == "search"
def test_mcp_transport_subclasses_accept_tool_name_prefix() -> None:
assert MCPStdioTool(name="stdio", command="python", tool_name_prefix="stdio").tool_name_prefix == "stdio"
assert (
@@ -139,7 +143,7 @@ async def test_load_prompts_with_tool_name_prefix() -> None:
def test_mcp_prompt_message_to_ai_content():
"""Test conversion from MCP prompt message to AI content."""
mcp_message = types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hello, world!"))
ai_content = _parse_message_from_mcp(mcp_message)
ai_content = _HELPER_MCP_TOOL._parse_message_from_mcp(mcp_message)
assert isinstance(ai_content, Message)
assert ai_content.role == "user"
@@ -149,6 +153,55 @@ def test_mcp_prompt_message_to_ai_content():
assert ai_content.raw_representation == mcp_message
def test_mcp_tool_str_and_parse_prompt_result_rich_content() -> None:
tool = MCPTool(name="helper", description="Helper MCP tool")
prompt_result = types.GetPromptResult(
messages=[
types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hello")),
types.PromptMessage(
role="assistant",
content=types.ImageContent(type="image", data="eHl6", mimeType="image/png"),
),
types.PromptMessage(
role="assistant",
content=types.AudioContent(type="audio", data="YXVkaW8=", mimeType="audio/wav"),
),
types.PromptMessage(
role="assistant",
content=types.EmbeddedResource(
type="resource",
resource=types.TextResourceContents(
uri=AnyUrl("file://prompt.txt"),
mimeType="text/plain",
text="Embedded prompt",
),
),
),
types.PromptMessage(
role="assistant",
content=types.EmbeddedResource(
type="resource",
resource=types.BlobResourceContents(
uri=AnyUrl("file://prompt.bin"),
mimeType="application/pdf",
blob="ZGF0YQ==",
),
),
),
]
)
result = tool._parse_prompt_result_from_mcp(prompt_result)
parsed = json.loads(result)
assert str(tool) == "MCPTool(name=helper, description=Helper MCP tool)"
assert parsed[0] == "Hello"
assert json.loads(parsed[1]) == {"type": "image", "data": "eHl6", "mimeType": "image/png"}
assert json.loads(parsed[2]) == {"type": "audio", "data": "YXVkaW8=", "mimeType": "audio/wav"}
assert parsed[3] == "Embedded prompt"
assert json.loads(parsed[4]) == {"type": "blob", "data": "ZGF0YQ==", "mimeType": "application/pdf"}
def test_parse_tool_result_from_mcp():
"""Test conversion from MCP tool result with images preserves original order."""
mcp_result = types.CallToolResult(
@@ -159,7 +212,7 @@ def test_parse_tool_result_from_mcp():
types.ImageContent(type="image", data="YWJj", mimeType="image/webp"),
]
)
result = _parse_tool_result_from_mcp(mcp_result)
result = _HELPER_MCP_TOOL._parse_tool_result_from_mcp(mcp_result)
# Results with images return a list of Content objects in original order
assert isinstance(result, list)
@@ -180,7 +233,7 @@ def test_parse_tool_result_from_mcp():
def test_parse_tool_result_from_mcp_single_text():
"""Test conversion from MCP tool result with a single text item."""
mcp_result = types.CallToolResult(content=[types.TextContent(type="text", text="Simple result")])
result = _parse_tool_result_from_mcp(mcp_result)
result = _HELPER_MCP_TOOL._parse_tool_result_from_mcp(mcp_result)
# Single text item returns list with one text Content
assert isinstance(result, list)
@@ -196,7 +249,7 @@ def test_parse_tool_result_from_mcp_meta_not_in_string():
_meta={"isError": True, "errorCode": "TOOL_ERROR"},
)
result = _parse_tool_result_from_mcp(mcp_result)
result = _HELPER_MCP_TOOL._parse_tool_result_from_mcp(mcp_result)
assert isinstance(result, list)
assert len(result) == 1
assert result[0].text == "Error occurred"
@@ -205,7 +258,7 @@ def test_parse_tool_result_from_mcp_meta_not_in_string():
def test_parse_tool_result_from_mcp_empty_content():
"""Test that empty MCP content normalizes to JSON null text content."""
mcp_result = types.CallToolResult(content=[])
result = _parse_tool_result_from_mcp(mcp_result)
result = _HELPER_MCP_TOOL._parse_tool_result_from_mcp(mcp_result)
assert isinstance(result, list)
assert len(result) == 1
assert result[0].type == "text"
@@ -222,7 +275,7 @@ def test_parse_tool_result_from_mcp_audio_content():
types.AudioContent(type="audio", data="YXVkaW8=", mimeType="audio/wav"),
]
)
result = _parse_tool_result_from_mcp(mcp_result)
result = _HELPER_MCP_TOOL._parse_tool_result_from_mcp(mcp_result)
assert isinstance(result, list)
assert len(result) == 1
@@ -245,7 +298,7 @@ def test_parse_tool_result_from_mcp_blob_plain_base64():
),
]
)
result = _parse_tool_result_from_mcp(mcp_result)
result = _HELPER_MCP_TOOL._parse_tool_result_from_mcp(mcp_result)
assert isinstance(result, list)
assert len(result) == 1
@@ -254,10 +307,39 @@ def test_parse_tool_result_from_mcp_blob_plain_base64():
assert "dGVzdCBkYXRh" in result[0].uri
def test_parse_tool_result_from_mcp_resource_link_text_resource_and_unknown():
"""Test additional MCP tool result variants."""
mcp_result = types.CallToolResult(
content=[
types.ResourceLink(
type="resource_link",
uri=AnyUrl("https://example.com/resource"),
name="resource",
mimeType="application/json",
),
types.EmbeddedResource(
type="resource",
resource=types.TextResourceContents(
uri=AnyUrl("file://prompt.txt"),
mimeType="text/plain",
text="Embedded result",
),
),
]
)
result = _HELPER_MCP_TOOL._parse_tool_result_from_mcp(mcp_result)
assert result[0].type == "uri"
assert result[0].uri == "https://example.com/resource"
assert result[1].type == "text"
assert result[1].text == "Embedded result"
def test_mcp_content_types_to_ai_content_text():
"""Test conversion of MCP text content to AI content."""
mcp_content = types.TextContent(type="text", text="Sample text")
ai_content = _parse_content_from_mcp(mcp_content)[0]
ai_content = _HELPER_MCP_TOOL._parse_content_from_mcp(mcp_content)[0]
assert ai_content.type == "text"
assert ai_content.text == "Sample text"
@@ -268,7 +350,7 @@ def test_mcp_content_types_to_ai_content_image():
"""Test conversion of MCP image content to AI content."""
# MCP can send data as base64 string or as bytes
mcp_content = types.ImageContent(type="image", data="YWJj", mimeType="image/jpeg") # base64 for b"abc"
ai_content = _parse_content_from_mcp(mcp_content)[0]
ai_content = _HELPER_MCP_TOOL._parse_content_from_mcp(mcp_content)[0]
assert ai_content.type == "data"
assert ai_content.uri == "data:image/jpeg;base64,YWJj"
@@ -280,7 +362,7 @@ def test_mcp_content_types_to_ai_content_audio():
"""Test conversion of MCP audio content to AI content."""
# Use properly padded base64
mcp_content = types.AudioContent(type="audio", data="ZGVm", mimeType="audio/wav") # base64 for b"def"
ai_content = _parse_content_from_mcp(mcp_content)[0]
ai_content = _HELPER_MCP_TOOL._parse_content_from_mcp(mcp_content)[0]
assert ai_content.type == "data"
assert ai_content.uri == "data:audio/wav;base64,ZGVm"
@@ -296,7 +378,7 @@ def test_mcp_content_types_to_ai_content_resource_link():
name="test_resource",
mimeType="application/json",
)
ai_content = _parse_content_from_mcp(mcp_content)[0]
ai_content = _HELPER_MCP_TOOL._parse_content_from_mcp(mcp_content)[0]
assert ai_content.type == "uri"
assert ai_content.uri == "https://example.com/resource"
@@ -312,7 +394,7 @@ def test_mcp_content_types_to_ai_content_embedded_resource_text():
text="Embedded text content",
)
mcp_content = types.EmbeddedResource(type="resource", resource=text_resource)
ai_content = _parse_content_from_mcp(mcp_content)[0]
ai_content = _HELPER_MCP_TOOL._parse_content_from_mcp(mcp_content)[0]
assert ai_content.type == "text"
assert ai_content.text == "Embedded text content"
@@ -328,7 +410,7 @@ def test_mcp_content_types_to_ai_content_embedded_resource_blob():
blob="data:application/octet-stream;base64,dGVzdCBkYXRh",
)
mcp_content = types.EmbeddedResource(type="resource", resource=blob_resource)
ai_content = _parse_content_from_mcp(mcp_content)[0]
ai_content = _HELPER_MCP_TOOL._parse_content_from_mcp(mcp_content)[0]
assert ai_content.type == "data"
assert ai_content.uri == "data:application/octet-stream;base64,dGVzdCBkYXRh"
@@ -336,10 +418,33 @@ def test_mcp_content_types_to_ai_content_embedded_resource_blob():
assert ai_content.raw_representation == mcp_content
def test_mcp_content_types_to_ai_content_tool_use_and_tool_result():
"""Test conversion of MCP tool use/result content to AI function call/result content."""
tool_use_content = types.ToolUseContent(type="tool_use", id="call-1", name="calculator", input={"x": 1})
tool_result_content = types.ToolResultContent(
type="tool_result",
toolUseId="call-1",
content=[types.TextContent(type="text", text="done")],
isError=True,
)
function_call = _HELPER_MCP_TOOL._parse_content_from_mcp(tool_use_content)[0]
function_result = _HELPER_MCP_TOOL._parse_content_from_mcp(tool_result_content)[0]
assert function_call.type == "function_call"
assert function_call.call_id == "call-1"
assert function_call.name == "calculator"
assert function_call.arguments == {"x": 1}
assert function_result.type == "function_result"
assert function_result.call_id == "call-1"
assert function_result.result == "done"
assert function_result.exception == ""
def test_ai_content_to_mcp_content_types_text():
"""Test conversion of AI text content to MCP content."""
ai_content = Content.from_text(text="Sample text")
mcp_content = _prepare_content_for_mcp(ai_content)
mcp_content = _HELPER_MCP_TOOL._prepare_content_for_mcp(ai_content)
assert isinstance(mcp_content, types.TextContent)
assert mcp_content.type == "text"
@@ -349,7 +454,7 @@ def test_ai_content_to_mcp_content_types_text():
def test_ai_content_to_mcp_content_types_data_image():
"""Test conversion of AI data content to MCP content."""
ai_content = Content.from_uri(uri="data:image/png;base64,xyz", media_type="image/png")
mcp_content = _prepare_content_for_mcp(ai_content)
mcp_content = _HELPER_MCP_TOOL._prepare_content_for_mcp(ai_content)
assert isinstance(mcp_content, types.ImageContent)
assert mcp_content.type == "image"
@@ -360,7 +465,7 @@ def test_ai_content_to_mcp_content_types_data_image():
def test_ai_content_to_mcp_content_types_data_audio():
"""Test conversion of AI data content to MCP content."""
ai_content = Content.from_uri(uri="data:audio/mpeg;base64,xyz", media_type="audio/mpeg")
mcp_content = _prepare_content_for_mcp(ai_content)
mcp_content = _HELPER_MCP_TOOL._prepare_content_for_mcp(ai_content)
assert isinstance(mcp_content, types.AudioContent)
assert mcp_content.type == "audio"
@@ -374,7 +479,7 @@ def test_ai_content_to_mcp_content_types_data_binary():
uri="data:application/octet-stream;base64,xyz",
media_type="application/octet-stream",
)
mcp_content = _prepare_content_for_mcp(ai_content)
mcp_content = _HELPER_MCP_TOOL._prepare_content_for_mcp(ai_content)
assert isinstance(mcp_content, types.EmbeddedResource)
assert mcp_content.type == "resource"
@@ -385,7 +490,7 @@ def test_ai_content_to_mcp_content_types_data_binary():
def test_ai_content_to_mcp_content_types_uri():
"""Test conversion of AI URI content to MCP content."""
ai_content = Content.from_uri(uri="https://example.com/resource", media_type="application/json")
mcp_content = _prepare_content_for_mcp(ai_content)
mcp_content = _HELPER_MCP_TOOL._prepare_content_for_mcp(ai_content)
assert isinstance(mcp_content, types.ResourceLink)
assert mcp_content.type == "resource_link"
@@ -401,12 +506,24 @@ def test_prepare_message_for_mcp():
Content.from_uri(uri="data:image/png;base64,xyz", media_type="image/png"),
],
)
mcp_contents = _prepare_message_for_mcp(message)
mcp_contents = _HELPER_MCP_TOOL._prepare_message_for_mcp(message)
assert len(mcp_contents) == 2
assert isinstance(mcp_contents[0], types.TextContent)
assert isinstance(mcp_contents[1], types.ImageContent)
def test_prepare_message_for_mcp_skips_unsupported_content() -> None:
unsupported = Content(type="annotations", text="ignored")
assert _HELPER_MCP_TOOL._prepare_content_for_mcp(unsupported) is None
mcp_contents = _HELPER_MCP_TOOL._prepare_message_for_mcp(
Message(role="user", contents=[Content.from_text("kept"), unsupported])
)
assert len(mcp_contents) == 1
assert isinstance(mcp_contents[0], types.TextContent)
@pytest.mark.parametrize(
"test_id,input_schema",
[
@@ -1287,6 +1404,18 @@ async def test_mcp_tool_approval_mode(approval_mode, expected_approvals):
assert func.approval_mode == expected_approvals[func.name]
def test_mcp_tool_approval_mode_returns_none_for_unmatched_names() -> None:
tool = MCPTool(
name="test_tool",
approval_mode={
"always_require_approval": ["tool_one"],
"never_require_approval": ["tool_two"],
},
)
assert tool._determine_approval_mode("tool_three") is None
@pytest.mark.parametrize(
"allowed_tools,expected_count,expected_names",
[
@@ -1618,6 +1747,46 @@ async def test_mcp_tool_sampling_callback_no_valid_content():
assert "Failed to get right content types from the response." in result.message
async def test_mcp_tool_sampling_callback_no_response_and_successful_message_creation():
"""Test sampling callback when the chat client returns no response and then valid content."""
tool = MCPStdioTool(name="test_tool", command="python")
tool.client = AsyncMock()
params = Mock()
params.messages = [types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hi"))]
params.temperature = None
params.maxTokens = None
params.stopSequences = None
tool.client.get_response.return_value = None
no_response = await tool.sampling_callback(Mock(), params)
assert isinstance(no_response, types.ErrorData)
assert no_response.message == "Failed to get chat message content."
tool.client.get_response.return_value = Mock(
messages=[Message(role="assistant", contents=[Content.from_text("Hello")])],
model_id="test-model",
)
success = await tool.sampling_callback(Mock(), params)
assert isinstance(success, types.CreateMessageResult)
assert success.role == "assistant"
assert success.model == "test-model"
assert isinstance(success.content, types.TextContent)
assert success.content.text == "Hello"
async def test_mcp_tool_logging_callback_logs_at_requested_level() -> None:
tool = MCPStdioTool(name="test_tool", command="python")
with patch.object(logger, "log") as mock_log:
await tool.logging_callback(types.LoggingMessageNotificationParams(level="warning", data="be careful"))
mock_log.assert_called_once_with(logging.WARNING, "be careful")
# Test error handling in connect() method
@@ -1633,7 +1802,7 @@ async def test_connect_session_creation_failure():
tool.get_mcp_client = Mock(return_value=mock_context_manager)
# Mock ClientSession to raise an exception
with patch("agent_framework._mcp.ClientSession") as mock_session_class:
with patch("mcp.client.session.ClientSession") as mock_session_class:
mock_session_class.side_effect = RuntimeError("Session creation failed")
with pytest.raises(ToolException) as exc_info:
@@ -1658,7 +1827,7 @@ async def test_connect_initialization_failure_http_no_command():
mock_session = Mock()
mock_session.initialize = AsyncMock(side_effect=ConnectionError("Server not ready"))
with patch("agent_framework._mcp.ClientSession") as mock_session_class:
with patch("mcp.client.session.ClientSession") as mock_session_class:
mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_class.return_value.__aexit__ = AsyncMock(return_value=None)
@@ -1687,6 +1856,18 @@ async def test_connect_cleanup_on_transport_failure():
tool._exit_stack.aclose.assert_called_once()
async def test_connect_cleanup_on_transport_failure_http_uses_generic_message():
"""Test HTTP transport failures use the generic connection message when no command exists."""
tool = MCPStreamableHTTPTool(name="test", url="https://example.com/mcp")
tool._exit_stack.aclose = AsyncMock()
tool.get_mcp_client = Mock(side_effect=RuntimeError("Transport failed"))
with pytest.raises(ToolException, match="Failed to connect to MCP server: Transport failed"):
await tool.connect()
tool._exit_stack.aclose.assert_called_once()
async def test_connect_cleanup_on_initialization_failure():
"""Test that _exit_stack.aclose() is called when initialization fails."""
tool = MCPStdioTool(name="test", command="test-command")
@@ -1705,7 +1886,7 @@ async def test_connect_cleanup_on_initialization_failure():
mock_session = Mock()
mock_session.initialize = AsyncMock(side_effect=RuntimeError("Init failed"))
with patch("agent_framework._mcp.ClientSession") as mock_session_class:
with patch("mcp.client.session.ClientSession") as mock_session_class:
mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_class.return_value.__aexit__ = AsyncMock(return_value=None)
@@ -1722,18 +1903,20 @@ def test_mcp_stdio_tool_get_mcp_client_with_env_and_kwargs():
tool = MCPStdioTool(
name="test",
command="test-command",
encoding="utf-16",
env=env_vars,
custom_param="value1",
another_param=42,
)
with patch("agent_framework._mcp.stdio_client"), patch("agent_framework._mcp.StdioServerParameters") as mock_params:
with patch("mcp.client.stdio.stdio_client"), patch("mcp.client.stdio.StdioServerParameters") as mock_params:
tool.get_mcp_client()
# Verify all parameters including custom kwargs were passed
mock_params.assert_called_once_with(
command="test-command",
args=[],
encoding="utf-16",
env=env_vars,
custom_param="value1",
another_param=42,
@@ -1748,7 +1931,7 @@ def test_mcp_streamable_http_tool_get_mcp_client_all_params():
terminate_on_close=True,
)
with patch("agent_framework._mcp.streamable_http_client") as mock_http_client:
with patch("mcp.client.streamable_http.streamable_http_client") as mock_http_client:
tool.get_mcp_client()
# Verify streamable_http_client was called with None for http_client
@@ -1770,7 +1953,7 @@ def test_mcp_websocket_tool_get_mcp_client_with_kwargs():
compression="deflate",
)
with patch("agent_framework._mcp.websocket_client") as mock_ws_client:
with patch("mcp.client.websocket.websocket_client") as mock_ws_client:
tool.get_mcp_client()
# Verify all kwargs were passed
@@ -1928,8 +2111,8 @@ async def test_mcp_streamable_http_tool_httpx_client_cleanup():
# Mock the streamable_http_client to avoid actual connections
with (
patch("agent_framework._mcp.streamable_http_client") as mock_client,
patch("agent_framework._mcp.ClientSession") as mock_session_class,
patch("mcp.client.streamable_http.streamable_http_client") as mock_client,
patch("mcp.client.session.ClientSession") as mock_session_class,
):
# Setup mock context manager for streamable_http_client
mock_transport = (Mock(), Mock())
@@ -2624,6 +2807,80 @@ async def test_mcp_tool_get_prompt_reconnection_on_closed_resource_error():
assert "failed to reconnect" in str(exc_info.value).lower()
async def test_mcp_tool_call_tool_requires_loaded_tools() -> None:
tool = MCPTool(name="test_tool", load_tools=False)
with pytest.raises(ToolExecutionException, match="Tools are not loaded"):
await tool.call_tool("remote_tool")
async def test_mcp_tool_get_prompt_requires_loaded_prompts() -> None:
tool = MCPTool(name="test_tool", load_prompts=False)
with pytest.raises(ToolExecutionException, match="Prompts are not loaded"):
await tool.get_prompt("remote_prompt")
async def test_mcp_tool_call_tool_raises_after_reconnection_still_fails() -> None:
from anyio.streams.memory import ClosedResourceError
tool = MCPTool(name="test_tool", load_tools=True)
tool.session = Mock(call_tool=AsyncMock(side_effect=[ClosedResourceError(), ClosedResourceError()]))
with (
patch.object(tool, "connect", AsyncMock()) as mock_connect,
patch.object(logger, "error") as mock_error,
pytest.raises(ToolExecutionException, match="connection lost"),
):
await tool.call_tool("remote_tool")
mock_connect.assert_awaited_once_with(reset=True)
mock_error.assert_called_once()
async def test_mcp_tool_get_prompt_raises_after_reconnection_still_fails() -> None:
from anyio.streams.memory import ClosedResourceError
tool = MCPTool(name="test_tool", load_prompts=True)
tool.session = Mock(get_prompt=AsyncMock(side_effect=[ClosedResourceError(), ClosedResourceError()]))
with (
patch.object(tool, "connect", AsyncMock()) as mock_connect,
patch.object(logger, "error") as mock_error,
pytest.raises(ToolExecutionException, match="connection lost"),
):
await tool.get_prompt("remote_prompt")
mock_connect.assert_awaited_once_with(reset=True)
mock_error.assert_called_once()
async def test_mcp_tool_wraps_unexpected_call_tool_and_get_prompt_errors() -> None:
tool = MCPTool(name="test_tool", load_tools=True, load_prompts=True)
tool.session = Mock()
tool.session.call_tool = AsyncMock(side_effect=RuntimeError("tool boom"))
tool.session.get_prompt = AsyncMock(side_effect=RuntimeError("prompt boom"))
with pytest.raises(ToolExecutionException, match="Failed to call tool 'remote_tool'"):
await tool.call_tool("remote_tool")
with pytest.raises(ToolExecutionException, match="Failed to call prompt 'remote_prompt'"):
await tool.get_prompt("remote_prompt")
async def test_mcp_tool_aenter_wraps_unexpected_errors_and_closes() -> None:
tool = MCPStdioTool(name="test_tool", command="python")
with (
patch.object(tool, "connect", AsyncMock(side_effect=RuntimeError("boom"))),
patch.object(tool, "close", AsyncMock()) as mock_close,
pytest.raises(ToolExecutionException, match="Failed to enter context manager"),
):
await tool.__aenter__()
mock_close.assert_awaited_once()
async def test_mcp_tool_close_cleans_up_in_original_task(caplog):
"""Closing an MCP tool from another task should still unwind contexts in the owner task."""
import asyncio
@@ -2663,7 +2920,7 @@ async def test_mcp_tool_close_cleans_up_in_original_task(caplog):
with (
patch.object(tool, "get_mcp_client", return_value=transport_context),
patch("agent_framework._mcp.ClientSession", return_value=mock_session_context),
patch("mcp.client.session.ClientSession", return_value=mock_session_context),
):
await asyncio.create_task(tool.connect())
@@ -2721,7 +2978,7 @@ async def test_mcp_tool_connect_reset_cleans_up_in_original_task(caplog):
with (
patch.object(tool, "get_mcp_client", side_effect=transport_contexts),
patch("agent_framework._mcp.ClientSession", side_effect=session_contexts),
patch("mcp.client.session.ClientSession", side_effect=session_contexts),
):
await tool.connect()
@@ -2905,7 +3162,7 @@ async def test_connect_sets_logging_level_when_logger_level_is_set():
with (
patch.object(tool, "get_mcp_client", return_value=mock_context),
patch("agent_framework._mcp.ClientSession", return_value=mock_session_context),
patch("mcp.client.session.ClientSession", return_value=mock_session_context),
patch.object(logger, "level", logging.DEBUG), # Set logger level to DEBUG
):
await tool.connect()
@@ -2942,7 +3199,7 @@ async def test_connect_does_not_set_logging_level_when_logger_level_is_notset():
with (
patch.object(tool, "get_mcp_client", return_value=mock_context),
patch("agent_framework._mcp.ClientSession", return_value=mock_session_context),
patch("mcp.client.session.ClientSession", return_value=mock_session_context),
patch.object(logger, "level", logging.NOTSET), # Set logger level to NOTSET
):
await tool.connect()
@@ -2980,7 +3237,7 @@ async def test_connect_handles_set_logging_level_exception():
with (
patch.object(tool, "get_mcp_client", return_value=mock_context),
patch("agent_framework._mcp.ClientSession", return_value=mock_session_context),
patch("mcp.client.session.ClientSession", return_value=mock_session_context),
patch.object(logger, "level", logging.INFO), # Set logger level to INFO
patch.object(logger, "warning") as mock_warning,
):
@@ -2996,6 +3253,48 @@ async def test_connect_handles_set_logging_level_exception():
assert "Failed to set log level" in call_args[0][0]
async def test_connect_reinitializes_existing_session_and_loads_tools_and_prompts() -> None:
tool = MCPTool(name="test_tool", load_tools=True, load_prompts=True)
tool.is_connected = True
tool.session = Mock()
tool.session._request_id = 0
tool.session.initialize = AsyncMock()
with (
patch.object(tool, "load_tools", AsyncMock()) as mock_load_tools,
patch.object(tool, "load_prompts", AsyncMock()) as mock_load_prompts,
patch.object(logger, "level", logging.NOTSET),
):
await tool._connect_on_owner()
tool.session.initialize.assert_awaited_once()
mock_load_tools.assert_awaited_once()
mock_load_prompts.assert_awaited_once()
assert tool._tools_loaded is True
assert tool._prompts_loaded is True
async def test_ensure_connected_reconnects_on_failed_ping() -> None:
tool = MCPTool(name="test_tool")
tool.session = Mock(send_ping=AsyncMock(side_effect=RuntimeError("closed")))
with patch.object(tool, "connect", AsyncMock()) as mock_connect:
await tool._ensure_connected()
mock_connect.assert_awaited_once_with(reset=True)
async def test_ensure_connected_wraps_reconnect_failure() -> None:
tool = MCPTool(name="test_tool")
tool.session = Mock(send_ping=AsyncMock(side_effect=RuntimeError("closed")))
with (
patch.object(tool, "connect", AsyncMock(side_effect=RuntimeError("still closed"))),
pytest.raises(ToolExecutionException, match="Failed to establish MCP connection"),
):
await tool._ensure_connected()
async def test_mcp_tool_filters_framework_kwargs():
"""Test that call_tool filters out framework-specific kwargs before calling MCP session.
@@ -3,7 +3,7 @@
import logging
from collections.abc import AsyncIterable, Awaitable, MutableSequence, Sequence
from typing import Any
from unittest.mock import Mock
from unittest.mock import Mock, patch
import pytest
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
@@ -1080,7 +1080,8 @@ def test_configure_otel_providers_reads_env_sensitive_data(monkeypatch):
# Simulate load_dotenv() setting env var after import
monkeypatch.setenv("ENABLE_SENSITIVE_DATA", "true")
observability.configure_otel_providers()
with patch.object(observability.OBSERVABILITY_SETTINGS, "_configure"):
observability.configure_otel_providers()
assert observability.OBSERVABILITY_SETTINGS.enable_instrumentation is True
assert observability.OBSERVABILITY_SETTINGS.enable_sensitive_data is True
@@ -1135,7 +1136,8 @@ def test_configure_otel_providers_explicit_param_overrides_env(monkeypatch):
importlib.reload(observability)
# Explicit False should override the env var True
observability.configure_otel_providers(enable_sensitive_data=False)
with patch.object(observability.OBSERVABILITY_SETTINGS, "_configure"):
observability.configure_otel_providers(enable_sensitive_data=False)
assert observability.OBSERVABILITY_SETTINGS.enable_sensitive_data is False
@@ -1196,7 +1198,8 @@ def test_enable_instrumentation_does_not_clobber_console_exporters(monkeypatch):
importlib.reload(observability)
# Set console exporters via configure_otel_providers
observability.configure_otel_providers(enable_console_exporters=True)
with patch.object(observability.OBSERVABILITY_SETTINGS, "_configure"):
observability.configure_otel_providers(enable_console_exporters=True)
assert observability.OBSERVABILITY_SETTINGS.enable_console_exporters is True
# Calling enable_instrumentation should not clobber the value
@@ -1224,7 +1227,8 @@ def test_enable_instrumentation_with_sensitive_data_does_not_touch_console_expor
importlib.reload(observability)
# Set console exporters via configure_otel_providers
observability.configure_otel_providers(enable_console_exporters=True)
with patch.object(observability.OBSERVABILITY_SETTINGS, "_configure"):
observability.configure_otel_providers(enable_console_exporters=True)
assert observability.OBSERVABILITY_SETTINGS.enable_console_exporters is True
# Calling enable_instrumentation with explicit sensitive_data should not clobber console exporters
@@ -1275,7 +1279,8 @@ def test_configure_otel_providers_reads_env_console_exporters(monkeypatch):
# Simulate load_dotenv() setting env var after import
monkeypatch.setenv("ENABLE_CONSOLE_EXPORTERS", "true")
observability.configure_otel_providers()
with patch.object(observability.OBSERVABILITY_SETTINGS, "_configure"):
observability.configure_otel_providers()
assert observability.OBSERVABILITY_SETTINGS.enable_console_exporters is True
@@ -1298,7 +1303,8 @@ def test_configure_otel_providers_explicit_console_exporters_overrides_env(monke
importlib.reload(observability)
# Explicit False should override the env var True
observability.configure_otel_providers(enable_console_exporters=False)
with patch.object(observability.OBSERVABILITY_SETTINGS, "_configure"):
observability.configure_otel_providers(enable_console_exporters=False)
assert observability.OBSERVABILITY_SETTINGS.enable_console_exporters is False
@@ -2005,6 +2011,14 @@ async def test_agent_streaming_observability(span_exporter: InMemorySpanExporter
assert len(spans) == 1
def test_agent_middleware_wraps_agent_telemetry() -> None:
"""Agent middleware must run outside telemetry so middleware time is excluded from agent latency."""
from agent_framework import Agent
from agent_framework._middleware import AgentMiddlewareLayer
assert Agent.__mro__.index(AgentMiddlewareLayer) < Agent.__mro__.index(AgentTelemetryLayer)
# region Test AgentTelemetryLayer error cases
@@ -3049,11 +3063,12 @@ def test_configure_otel_providers_with_env_file_path(monkeypatch, tmp_path):
env_file = tmp_path / ".env"
env_file.write_text("ENABLE_INSTRUMENTATION=true\n")
observability.configure_otel_providers(
env_file_path=str(env_file),
enable_sensitive_data=True,
vs_code_extension_port=None,
)
with patch.object(observability.OBSERVABILITY_SETTINGS, "_configure"):
observability.configure_otel_providers(
env_file_path=str(env_file),
enable_sensitive_data=True,
vs_code_extension_port=None,
)
assert observability.OBSERVABILITY_SETTINGS.enable_instrumentation is True
assert observability.OBSERVABILITY_SETTINGS.enable_sensitive_data is True
@@ -3078,11 +3093,12 @@ def test_configure_otel_providers_with_env_file_and_vs_code_port(monkeypatch, tm
env_file = tmp_path / ".env"
env_file.write_text("ENABLE_INSTRUMENTATION=true\n")
observability.configure_otel_providers(
env_file_path=str(env_file),
env_file_encoding="utf-8",
vs_code_extension_port=4317,
)
with patch.object(observability.OBSERVABILITY_SETTINGS, "_configure"):
observability.configure_otel_providers(
env_file_path=str(env_file),
env_file_encoding="utf-8",
vs_code_extension_port=4317,
)
assert observability.OBSERVABILITY_SETTINGS.enable_instrumentation is True
assert observability.OBSERVABILITY_SETTINGS.vs_code_extension_port == 4317
@@ -0,0 +1,181 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
import sys
import pytest
import agent_framework
import agent_framework.observability as observability
from agent_framework import Agent
def _hide_otel_sdk(monkeypatch: pytest.MonkeyPatch) -> None:
import builtins
real_import = builtins.__import__
for module_name in list(sys.modules):
if module_name == "opentelemetry.sdk" or module_name.startswith("opentelemetry.sdk."):
sys.modules.pop(module_name, None)
def _import_without_otel_sdk(
name: str,
globals_: dict[str, object] | None = None,
locals_: dict[str, object] | None = None,
fromlist: tuple[str, ...] = (),
level: int = 0,
) -> object:
if name == "opentelemetry.sdk" or name.startswith("opentelemetry.sdk."):
raise ModuleNotFoundError(f"No module named '{name}'", name=name)
return real_import(name, globals_, locals_, fromlist, level)
monkeypatch.setattr(builtins, "__import__", _import_without_otel_sdk)
def test_create_resource_requires_otel_sdk(monkeypatch: pytest.MonkeyPatch) -> None:
_hide_otel_sdk(monkeypatch)
with pytest.raises(ModuleNotFoundError, match="opentelemetry-sdk"):
observability.create_resource()
def test_observability_settings_initializes_without_cached_resource(monkeypatch: pytest.MonkeyPatch) -> None:
_hide_otel_sdk(monkeypatch)
settings = observability.ObservabilitySettings()
assert not hasattr(settings, "_resource")
def test_configure_otel_providers_requires_otel_sdk(monkeypatch: pytest.MonkeyPatch) -> None:
_hide_otel_sdk(monkeypatch)
for key in [
"OTEL_EXPORTER_OTLP_ENDPOINT",
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT",
"OTEL_EXPORTER_OTLP_METRICS_ENDPOINT",
"OTEL_EXPORTER_OTLP_LOGS_ENDPOINT",
"VS_CODE_EXTENSION_PORT",
]:
monkeypatch.delenv(key, raising=False)
with pytest.raises(ModuleNotFoundError, match="opentelemetry-sdk"):
observability.configure_otel_providers()
def test_agent_framework_mcp_exports_remain_importable_without_mcp(monkeypatch: pytest.MonkeyPatch) -> None:
import builtins
import agent_framework._mcp as mcp_module
real_import = builtins.__import__
def _import_without_mcp(
name: str,
globals_: dict[str, object] | None = None,
locals_: dict[str, object] | None = None,
fromlist: tuple[str, ...] = (),
level: int = 0,
) -> object:
if name == "mcp" or name.startswith("mcp."):
raise ModuleNotFoundError("No module named 'mcp'")
return real_import(name, globals_, locals_, fromlist, level)
monkeypatch.setattr(builtins, "__import__", _import_without_mcp)
assert agent_framework.MCPStdioTool is mcp_module.MCPStdioTool
with pytest.raises(ModuleNotFoundError, match=r"Please install `mcp`\.$"):
agent_framework.MCPStdioTool(name="test", command="python").get_mcp_client()
def test_mcp_streamable_http_tool_requires_mcp(monkeypatch: pytest.MonkeyPatch) -> None:
import builtins
real_import = builtins.__import__
def _import_without_mcp(
name: str,
globals_: dict[str, object] | None = None,
locals_: dict[str, object] | None = None,
fromlist: tuple[str, ...] = (),
level: int = 0,
) -> object:
if name == "mcp" or name.startswith("mcp."):
raise ModuleNotFoundError("No module named 'mcp'")
return real_import(name, globals_, locals_, fromlist, level)
monkeypatch.setattr(builtins, "__import__", _import_without_mcp)
with pytest.raises(ModuleNotFoundError, match=r"Please install `mcp`\.$"):
agent_framework.MCPStreamableHTTPTool(name="test", url="https://example.com").get_mcp_client()
def test_agent_as_mcp_server_requires_mcp(client, monkeypatch: pytest.MonkeyPatch) -> None:
import builtins
real_import = builtins.__import__
def _import_without_mcp(
name: str,
globals_: dict[str, object] | None = None,
locals_: dict[str, object] | None = None,
fromlist: tuple[str, ...] = (),
level: int = 0,
) -> object:
if name == "mcp" or name.startswith("mcp."):
raise ModuleNotFoundError("No module named 'mcp'")
return real_import(name, globals_, locals_, fromlist, level)
monkeypatch.setattr(builtins, "__import__", _import_without_mcp)
agent = Agent(client=client)
with pytest.raises(ModuleNotFoundError, match=r"Please install `mcp`\.$"):
agent.as_mcp_server()
def test_mcp_websocket_tool_requires_ws_support(monkeypatch: pytest.MonkeyPatch) -> None:
import builtins
real_import = builtins.__import__
sys.modules.pop("mcp.client.websocket", None)
def _import_without_websocket_support(
name: str,
globals_: dict[str, object] | None = None,
locals_: dict[str, object] | None = None,
fromlist: tuple[str, ...] = (),
level: int = 0,
) -> object:
if name == "mcp.client.websocket":
raise ModuleNotFoundError("No module named 'websockets'", name="websockets")
return real_import(name, globals_, locals_, fromlist, level)
monkeypatch.setattr(builtins, "__import__", _import_without_websocket_support)
with pytest.raises(ModuleNotFoundError, match=r"mcp\[ws\]"):
agent_framework.MCPWebsocketTool(name="test", url="wss://example.com").get_mcp_client()
def test_mcp_websocket_tool_requires_mcp(monkeypatch: pytest.MonkeyPatch) -> None:
import builtins
real_import = builtins.__import__
sys.modules.pop("mcp.client.websocket", None)
def _import_without_mcp(
name: str,
globals_: dict[str, object] | None = None,
locals_: dict[str, object] | None = None,
fromlist: tuple[str, ...] = (),
level: int = 0,
) -> object:
if name == "mcp.client.websocket":
raise ModuleNotFoundError("No module named 'mcp.client.websocket'", name="mcp.client.websocket")
return real_import(name, globals_, locals_, fromlist, level)
monkeypatch.setattr(builtins, "__import__", _import_without_mcp)
with pytest.raises(ModuleNotFoundError, match=r"agent-framework-core\[mcp\]|mcp\[ws\]"):
agent_framework.MCPWebsocketTool(name="test", url="wss://example.com").get_mcp_client()
+7 -6
View File
@@ -22,12 +22,13 @@ classifiers = [
"Programming Language :: Python :: 3.14",
"Typing :: Typed",
]
dependencies = [
"agent-framework-core>=1.0.0rc5",
"openai>=1.99.0,<3",
"fastapi>=0.115.0,<0.133.1",
"uvicorn[standard]>=0.30.0,<0.42.0"
]
dependencies = [
"agent-framework-core>=1.0.0rc5",
"openai>=1.99.0,<3",
"opentelemetry-sdk>=1.39.0,<2",
"fastapi>=0.115.0,<0.133.1",
"uvicorn[standard]>=0.30.0,<0.42.0"
]
[project.optional-dependencies]
dev = [
@@ -195,8 +195,8 @@ class RawFoundryAgent( # type: ignore[misc]
class FoundryAgent( # type: ignore[misc]
AgentTelemetryLayer,
AgentMiddlewareLayer,
AgentTelemetryLayer,
RawFoundryAgent[FoundryAgentOptionsT],
):
"""Microsoft Foundry Agent with full middleware and telemetry support.
+7 -6
View File
@@ -27,12 +27,13 @@ dependencies = [
[project.optional-dependencies]
# GAIA benchmark module dependencies
gaia = [
"pydantic>=2.0.0",
"opentelemetry-api>=1.39.0",
"tqdm>=4.60.0",
"huggingface-hub>=0.20.0",
"orjson>=3.10.7,<4",
gaia = [
"pydantic>=2.0.0",
"opentelemetry-api>=1.39.0",
"opentelemetry-sdk>=1.39.0,<2",
"tqdm>=4.60.0",
"huggingface-hub>=0.20.0",
"orjson>=3.10.7,<4",
"pyarrow>=18.0.0", # For reading parquet files
]
+2
View File
@@ -39,6 +39,8 @@ dev = [
"pytest-retry==1.7.0",
"mypy==1.19.1",
"pyright==1.1.408",
"mcp[ws]>=1.24.0,<2",
"opentelemetry-sdk>=1.39.0,<2",
#tasks
"poethepoet==0.42.1",
"rich==13.7.1",
+16 -23
View File
@@ -103,7 +103,9 @@ dependencies = [
[package.dev-dependencies]
dev = [
{ name = "flit", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "mcp", extra = ["ws"], marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "mypy", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "opentelemetry-sdk", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "poethepoet", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "prek", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "pyright", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
@@ -125,7 +127,9 @@ requires-dist = [{ name = "agent-framework-core", extras = ["all"], editable = "
[package.metadata.requires-dev]
dev = [
{ name = "flit", specifier = "==3.12.0" },
{ name = "mcp", extras = ["ws"], specifier = ">=1.24.0,<2" },
{ name = "mypy", specifier = "==1.19.1" },
{ name = "opentelemetry-sdk", specifier = ">=1.39.0,<2" },
{ name = "poethepoet", specifier = "==0.42.1" },
{ name = "prek", specifier = "==0.3.4" },
{ name = "pyright", specifier = "==1.1.408" },
@@ -209,6 +213,8 @@ dependencies = [
{ name = "aiohttp", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "azure-ai-agents", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "azure-ai-inference", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "azure-ai-projects", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "azure-identity", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
]
[package.metadata]
@@ -218,6 +224,8 @@ requires-dist = [
{ name = "aiohttp", specifier = ">=3.7.0,<4" },
{ name = "azure-ai-agents", specifier = ">=1.2.0b5,<1.2.0b6" },
{ name = "azure-ai-inference", specifier = ">=1.0.0b9,<1.0.0b10" },
{ name = "azure-ai-projects", specifier = ">=2.0.0,<3.0" },
{ name = "azure-identity", specifier = ">=1,<2" },
]
[[package]]
@@ -339,14 +347,7 @@ name = "agent-framework-core"
version = "1.0.0rc5"
source = { editable = "packages/core" }
dependencies = [
{ name = "azure-ai-projects", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "azure-identity", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "mcp", extra = ["ws"], marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "opentelemetry-api", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "opentelemetry-sdk", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "opentelemetry-semantic-conventions-ai", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "packaging", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "python-dotenv", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
@@ -373,9 +374,11 @@ all = [
{ name = "agent-framework-lab", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "agent-framework-mem0", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "agent-framework-ollama", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "agent-framework-openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "agent-framework-orchestrations", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "agent-framework-purview", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "agent-framework-redis", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "mcp", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
]
[package.metadata]
@@ -399,17 +402,12 @@ requires-dist = [
{ name = "agent-framework-lab", marker = "extra == 'all'", editable = "packages/lab" },
{ name = "agent-framework-mem0", marker = "extra == 'all'", editable = "packages/mem0" },
{ name = "agent-framework-ollama", marker = "extra == 'all'", editable = "packages/ollama" },
{ name = "agent-framework-openai", marker = "extra == 'all'", editable = "packages/openai" },
{ name = "agent-framework-orchestrations", marker = "extra == 'all'", editable = "packages/orchestrations" },
{ name = "agent-framework-purview", marker = "extra == 'all'", editable = "packages/purview" },
{ name = "agent-framework-redis", marker = "extra == 'all'", editable = "packages/redis" },
{ name = "azure-ai-projects", specifier = ">=2.0.0,<3.0" },
{ name = "azure-identity", specifier = ">=1,<2" },
{ name = "mcp", extras = ["ws"], specifier = ">=1.24.0,<2" },
{ name = "openai", specifier = ">=1.99.0,<3" },
{ name = "mcp", marker = "extra == 'all'", specifier = ">=1.24.0,<2" },
{ name = "opentelemetry-api", specifier = ">=1.39.0,<2" },
{ name = "opentelemetry-sdk", specifier = ">=1.39.0,<2" },
{ name = "opentelemetry-semantic-conventions-ai", specifier = ">=0.4.13,<0.4.14" },
{ name = "packaging", specifier = ">=24.1,<25" },
{ name = "pydantic", specifier = ">=2,<3" },
{ name = "python-dotenv", specifier = ">=1,<2" },
{ name = "typing-extensions", specifier = ">=4.15.0,<5" },
@@ -449,6 +447,7 @@ dependencies = [
{ name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "fastapi", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "opentelemetry-sdk", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "uvicorn", extra = ["standard"], marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
]
@@ -469,6 +468,7 @@ requires-dist = [
{ name = "agent-framework-orchestrations", marker = "extra == 'dev'", editable = "packages/orchestrations" },
{ name = "fastapi", specifier = ">=0.115.0,<0.133.1" },
{ name = "openai", specifier = ">=1.99.0,<3" },
{ name = "opentelemetry-sdk", specifier = ">=1.39.0,<2" },
{ name = "pytest", marker = "extra == 'all'", specifier = "==9.0.2" },
{ name = "pytest", marker = "extra == 'dev'", specifier = "==9.0.2" },
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.30.0,<0.42.0" },
@@ -565,6 +565,7 @@ dependencies = [
gaia = [
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "opentelemetry-api", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "opentelemetry-sdk", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "orjson", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "pyarrow", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
@@ -607,6 +608,7 @@ requires-dist = [
{ name = "loguru", marker = "extra == 'tau2'", specifier = ">=0.7.3" },
{ name = "numpy", marker = "extra == 'tau2'" },
{ name = "opentelemetry-api", marker = "extra == 'gaia'", specifier = ">=1.39.0" },
{ name = "opentelemetry-sdk", marker = "extra == 'gaia'", specifier = ">=1.39.0,<2" },
{ name = "orjson", marker = "extra == 'gaia'", specifier = ">=3.10.7,<4" },
{ name = "pyarrow", marker = "extra == 'gaia'", specifier = ">=18.0.0" },
{ name = "pydantic", marker = "extra == 'gaia'", specifier = ">=2.0.0" },
@@ -4220,15 +4222,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b2/37/cc6a55e448deaa9b27377d087da8615a3416d8ad523d5960b78dbeadd02a/opentelemetry_semantic_conventions-0.61b0-py3-none-any.whl", hash = "sha256:fa530a96be229795f8cef353739b618148b0fe2b4b3f005e60e262926c4d38e2", size = 231621, upload-time = "2026-03-04T14:17:19.33Z" },
]
[[package]]
name = "opentelemetry-semantic-conventions-ai"
version = "0.4.13"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/ba/e6/40b59eda51ac47009fb47afcdf37c6938594a0bd7f3b9fadcbc6058248e3/opentelemetry_semantic_conventions_ai-0.4.13.tar.gz", hash = "sha256:94efa9fb4ffac18c45f54a3a338ffeb7eedb7e1bb4d147786e77202e159f0036", size = 5368, upload-time = "2025-08-22T10:14:17.387Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/35/b5/cf25da2218910f0d6cdf7f876a06bed118c4969eacaf60a887cbaef44f44/opentelemetry_semantic_conventions_ai-0.4.13-py3-none-any.whl", hash = "sha256:883a30a6bb5deaec0d646912b5f9f6dcbb9f6f72557b73d0f2560bf25d13e2d5", size = 6080, upload-time = "2025-08-22T10:14:16.477Z" },
]
[[package]]
name = "ordered-set"
version = "4.1.0"