mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add header_provider to Streamable HTTP MCP servers (#4849)
* Python: Add header_provider to MCPStreamableHTTPTool (#4808) Add a header_provider callback parameter to MCPStreamableHTTPTool that enables injecting dynamic per-request HTTP headers from runtime kwargs (originating from FunctionInvocationContext.kwargs set in agent middleware). The implementation uses contextvars and httpx event hooks to ensure headers are task-local and safe for concurrent tool calls: - header_provider receives the runtime kwargs dict and returns headers - call_tool sets a ContextVar before delegating to MCPTool.call_tool - An httpx request event hook reads from the ContextVar and injects headers Example usage: mcp_tool = MCPStreamableHTTPTool( name="web-api", url="https://api.example.com/mcp", header_provider=lambda kwargs: { "X-Auth-Token": kwargs.get("auth_token", ""), }, ) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address review feedback for #4808: Python: [Bug]: Unable to pass AgentContext to MCPStreamableHTTPTool * Add test for header_provider via FunctionTool.invoke with FunctionInvocationContext Addresses PR review comment: exercises the full pipeline from FunctionInvocationContext.kwargs through FunctionTool.invoke to MCPStreamableHTTPTool.call_tool and header_provider, rather than testing call_tool in isolation. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address review feedback for #4808: review comment fixes * Fix streamable MCP transport defaults Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix Azure AI test client mocks Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix MCP runtime kwarg regressions Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Stabilize MCP tool runtime kwargs Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Use context kwargs in MCP wrappers Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * updated mcp samples * fix link --------- Co-authored-by: Copilot <copilot@github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
7c2dae8855
commit
9c57680f00
@@ -77,7 +77,7 @@ def mock_project_client() -> MagicMock:
|
||||
mock_client.telemetry.get_application_insights_connection_string = AsyncMock()
|
||||
|
||||
# Mock get_openai_client method
|
||||
mock_client.get_openai_client = AsyncMock()
|
||||
mock_client.get_openai_client = MagicMock()
|
||||
|
||||
# Mock close method
|
||||
mock_client.close = AsyncMock()
|
||||
|
||||
@@ -34,7 +34,7 @@ def mock_project_client() -> MagicMock:
|
||||
mock_client.telemetry.get_application_insights_connection_string = AsyncMock()
|
||||
|
||||
# Mock get_openai_client method
|
||||
mock_client.get_openai_client = AsyncMock()
|
||||
mock_client.get_openai_client = MagicMock()
|
||||
|
||||
# Mock close method
|
||||
mock_client.close = AsyncMock()
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
@@ -38,6 +39,7 @@ if TYPE_CHECKING:
|
||||
from mcp.shared.session import RequestResponder
|
||||
|
||||
from ._clients import SupportsChatGetResponse
|
||||
from ._middleware import FunctionInvocationContext
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -59,6 +61,9 @@ class MCPSpecificApproval(TypedDict, total=False):
|
||||
|
||||
_MCP_REMOTE_NAME_KEY = "_mcp_remote_name"
|
||||
_MCP_NORMALIZED_NAME_KEY = "_mcp_normalized_name"
|
||||
_mcp_call_headers: contextvars.ContextVar[dict[str, str]] = contextvars.ContextVar("_mcp_call_headers")
|
||||
MCP_DEFAULT_TIMEOUT = 30
|
||||
MCP_DEFAULT_SSE_READ_TIMEOUT = 60 * 5
|
||||
|
||||
# region: Helpers
|
||||
|
||||
@@ -137,6 +142,22 @@ def _inject_otel_into_mcp_meta(meta: dict[str, Any] | None = None) -> dict[str,
|
||||
return meta
|
||||
|
||||
|
||||
def streamable_http_client(*args: Any, **kwargs: Any) -> _AsyncGeneratorContextManager[Any, None]:
|
||||
"""Lazily import the MCP streamable HTTP transport."""
|
||||
try:
|
||||
from mcp.client.streamable_http import streamable_http_client as _streamable_http_client
|
||||
except ModuleNotFoundError as ex:
|
||||
missing_name = ex.name or str(ex)
|
||||
if missing_name == "mcp" or missing_name.startswith("mcp.") or "mcp" in missing_name:
|
||||
raise ModuleNotFoundError("`MCPStreamableHTTPTool` requires `mcp`. Please install `mcp`.") from ex
|
||||
raise ModuleNotFoundError(
|
||||
f"`MCPStreamableHTTPTool` requires streamable HTTP transport support. "
|
||||
f"The optional dependency `{missing_name}` is not installed. Please update your dependencies."
|
||||
) from ex
|
||||
|
||||
return _streamable_http_client(*args, **kwargs) # type: ignore[return-value]
|
||||
|
||||
|
||||
# region: MCP Plugin
|
||||
|
||||
|
||||
@@ -951,9 +972,20 @@ class MCPTool:
|
||||
input_schema = dict(tool.inputSchema or {})
|
||||
if input_schema.get("type") == "object" and "properties" not in input_schema:
|
||||
input_schema["properties"] = {}
|
||||
|
||||
async def _call_tool_with_runtime_kwargs(
|
||||
ctx: FunctionInvocationContext,
|
||||
*,
|
||||
_remote_tool_name: str = tool.name,
|
||||
**kwargs: Any,
|
||||
) -> str | list[Content]:
|
||||
call_kwargs = dict(ctx.kwargs)
|
||||
call_kwargs.update(kwargs)
|
||||
return await self.call_tool(_remote_tool_name, **call_kwargs)
|
||||
|
||||
# Create FunctionTools out of each tool
|
||||
func: FunctionTool = FunctionTool(
|
||||
func=partial(self.call_tool, tool.name),
|
||||
func=_call_tool_with_runtime_kwargs,
|
||||
name=local_name,
|
||||
description=tool.description or "",
|
||||
approval_mode=approval_mode,
|
||||
@@ -1386,6 +1418,7 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
client: SupportsChatGetResponse | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
http_client: AsyncClient | None = None,
|
||||
header_provider: Callable[[dict[str, Any]], dict[str, str]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the MCP streamable HTTP tool.
|
||||
@@ -1433,6 +1466,11 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
``streamable_http_client`` API will create and manage a default client.
|
||||
To configure headers, timeouts, or other HTTP client settings, create
|
||||
and pass your own ``asyncClient`` instance.
|
||||
header_provider: Optional callable that receives the runtime keyword arguments
|
||||
(from ``FunctionInvocationContext.kwargs``) and returns a ``dict[str, str]``
|
||||
of HTTP headers to inject into every outbound request to the MCP server.
|
||||
Use this to forward per-request context (e.g. authentication tokens set in
|
||||
agent middleware) without creating a separate ``httpx.AsyncClient``.
|
||||
kwargs: Additional keyword arguments (accepted for backward compatibility but not used).
|
||||
"""
|
||||
super().__init__(
|
||||
@@ -1453,6 +1491,7 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
self.url = url
|
||||
self.terminate_on_close = terminate_on_close
|
||||
self._httpx_client: AsyncClient | None = http_client
|
||||
self._header_provider = header_provider
|
||||
|
||||
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
|
||||
"""Get an MCP streamable HTTP client.
|
||||
@@ -1460,18 +1499,59 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
Returns:
|
||||
An async context manager for the streamable HTTP client transport.
|
||||
"""
|
||||
try:
|
||||
from mcp.client.streamable_http import streamable_http_client
|
||||
except ModuleNotFoundError as ex:
|
||||
raise ModuleNotFoundError("`mcp` is required to use `MCPStreamableHTTPTool`. Please install `mcp`.") from ex
|
||||
from httpx import AsyncClient, Request, Timeout
|
||||
|
||||
http_client = self._httpx_client
|
||||
if self._header_provider is not None:
|
||||
if http_client is None:
|
||||
http_client = AsyncClient(
|
||||
follow_redirects=True,
|
||||
timeout=Timeout(MCP_DEFAULT_TIMEOUT, read=MCP_DEFAULT_SSE_READ_TIMEOUT),
|
||||
)
|
||||
self._httpx_client = http_client
|
||||
|
||||
if not hasattr(self, "_inject_headers_hook"):
|
||||
|
||||
async def _inject_headers(request: Request) -> None: # noqa: RUF029
|
||||
headers = _mcp_call_headers.get({})
|
||||
for key, value in headers.items():
|
||||
request.headers[key] = value
|
||||
|
||||
self._inject_headers_hook = _inject_headers # type: ignore[attr-defined]
|
||||
http_client.event_hooks["request"].append(self._inject_headers_hook) # type: ignore[attr-defined]
|
||||
|
||||
# Pass the http_client (which may be None) to streamable_http_client
|
||||
return streamable_http_client(
|
||||
url=self.url,
|
||||
http_client=self._httpx_client,
|
||||
http_client=http_client,
|
||||
terminate_on_close=self.terminate_on_close if self.terminate_on_close is not None else True,
|
||||
)
|
||||
|
||||
async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]:
|
||||
"""Call a tool, injecting headers from the header_provider if configured.
|
||||
|
||||
When a ``header_provider`` was supplied at construction time, the runtime
|
||||
*kwargs* (originating from ``FunctionInvocationContext.kwargs``) are passed
|
||||
to the provider. The returned headers are attached to every HTTP request
|
||||
made during this tool call via a ``contextvars.ContextVar``.
|
||||
|
||||
Args:
|
||||
tool_name: The name of the tool to call.
|
||||
|
||||
Keyword Args:
|
||||
kwargs: Arguments to pass to the tool.
|
||||
|
||||
Returns:
|
||||
A list of Content items representing the tool output.
|
||||
"""
|
||||
if self._header_provider is not None:
|
||||
headers = self._header_provider(kwargs)
|
||||
token = _mcp_call_headers.set(headers)
|
||||
try:
|
||||
return await super().call_tool(tool_name, **kwargs)
|
||||
finally:
|
||||
_mcp_call_headers.reset(token)
|
||||
return await super().call_tool(tool_name, **kwargs)
|
||||
|
||||
|
||||
class MCPWebsocketTool(MCPTool):
|
||||
"""MCP tool for connecting to WebSocket-based MCP servers.
|
||||
|
||||
@@ -3804,4 +3804,377 @@ async def test_mcp_tool_call_tool_otel_meta(use_span, expect_traceparent, span_e
|
||||
assert meta is None
|
||||
|
||||
|
||||
async def test_mcp_streamable_http_tool_hook_not_duplicated_on_repeated_get_mcp_client():
|
||||
"""Test that calling get_mcp_client multiple times does not accumulate duplicate hooks."""
|
||||
tool = MCPStreamableHTTPTool(
|
||||
name="test",
|
||||
url="http://example.com/mcp",
|
||||
header_provider=lambda kw: {"X-Token": kw.get("token", "")},
|
||||
)
|
||||
|
||||
try:
|
||||
with patch("agent_framework._mcp.streamable_http_client"):
|
||||
tool.get_mcp_client()
|
||||
tool.get_mcp_client()
|
||||
tool.get_mcp_client()
|
||||
|
||||
assert tool._httpx_client is not None
|
||||
hooks = tool._httpx_client.event_hooks.get("request", [])
|
||||
assert len(hooks) == 1, f"Expected exactly one hook, got {len(hooks)}"
|
||||
finally:
|
||||
if getattr(tool, "_httpx_client", None) is not None:
|
||||
await tool._httpx_client.aclose()
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region: MCPStreamableHTTPTool header_provider
|
||||
|
||||
|
||||
async def test_mcp_streamable_http_tool_header_provider_injects_headers():
|
||||
"""Test that header_provider integrates with call_tool via runtime kwargs.
|
||||
|
||||
When header_provider is configured, runtime kwargs from FunctionInvocationContext
|
||||
are passed to the provider and the MCP session.call_tool is invoked successfully.
|
||||
"""
|
||||
|
||||
class _TestServer(MCPStreamableHTTPTool):
|
||||
async def connect(self):
|
||||
self.session = Mock(spec=ClientSession)
|
||||
self.session.list_tools = AsyncMock(
|
||||
return_value=types.ListToolsResult(
|
||||
tools=[
|
||||
types.Tool(
|
||||
name="greet",
|
||||
description="Says hello",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
"required": ["name"],
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
self.session.call_tool = AsyncMock(
|
||||
return_value=types.CallToolResult(content=[types.TextContent(type="text", text="Hello!")])
|
||||
)
|
||||
self.session.send_ping = AsyncMock()
|
||||
self.is_connected = True
|
||||
|
||||
def get_mcp_client(self):
|
||||
return None
|
||||
|
||||
def provider(kwargs):
|
||||
return {"X-Some-Token": kwargs.get("some_token", "")}
|
||||
|
||||
server = _TestServer(
|
||||
name="test",
|
||||
url="http://example.com/mcp",
|
||||
header_provider=provider,
|
||||
)
|
||||
async with server:
|
||||
await server.load_tools()
|
||||
|
||||
# Simulate the runtime kwargs that flow from FunctionInvocationContext.kwargs
|
||||
await server.call_tool("greet", name="Alice", some_token="my-secret")
|
||||
|
||||
# Verify the MCP session.call_tool was called
|
||||
server.session.call_tool.assert_called_once()
|
||||
|
||||
|
||||
async def test_mcp_streamable_http_tool_header_provider_sets_contextvar():
|
||||
"""Test that call_tool sets the contextvar with headers from header_provider."""
|
||||
from agent_framework._mcp import _mcp_call_headers
|
||||
|
||||
observed_headers: list[dict[str, str]] = []
|
||||
original_call_tool = MCPTool.call_tool
|
||||
|
||||
async def spy_call_tool(self, tool_name, **kwargs):
|
||||
# Capture the contextvar value during the super call
|
||||
try:
|
||||
observed_headers.append(_mcp_call_headers.get())
|
||||
except LookupError:
|
||||
observed_headers.append({})
|
||||
return await original_call_tool(self, tool_name, **kwargs)
|
||||
|
||||
class _TestServer(MCPStreamableHTTPTool):
|
||||
async def connect(self):
|
||||
self.session = Mock(spec=ClientSession)
|
||||
self.session.list_tools = AsyncMock(
|
||||
return_value=types.ListToolsResult(
|
||||
tools=[
|
||||
types.Tool(
|
||||
name="greet",
|
||||
description="Says hello",
|
||||
inputSchema={"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
self.session.call_tool = AsyncMock(
|
||||
return_value=types.CallToolResult(content=[types.TextContent(type="text", text="Hello!")])
|
||||
)
|
||||
self.session.send_ping = AsyncMock()
|
||||
self.is_connected = True
|
||||
|
||||
def get_mcp_client(self):
|
||||
return None
|
||||
|
||||
server = _TestServer(
|
||||
name="test",
|
||||
url="http://example.com/mcp",
|
||||
header_provider=lambda kw: {"X-Auth": kw.get("auth_token", "")},
|
||||
)
|
||||
async with server:
|
||||
await server.load_tools()
|
||||
|
||||
with patch.object(MCPTool, "call_tool", spy_call_tool):
|
||||
await server.call_tool("greet", name="Alice", auth_token="bearer-xyz")
|
||||
|
||||
assert len(observed_headers) == 1
|
||||
assert observed_headers[0] == {"X-Auth": "bearer-xyz"}
|
||||
|
||||
|
||||
async def test_mcp_streamable_http_tool_header_provider_contextvar_reset_after_call():
|
||||
"""Test that the contextvar is properly reset after call_tool completes."""
|
||||
from agent_framework._mcp import _mcp_call_headers
|
||||
|
||||
class _TestServer(MCPStreamableHTTPTool):
|
||||
async def connect(self):
|
||||
self.session = Mock(spec=ClientSession)
|
||||
self.session.list_tools = AsyncMock(
|
||||
return_value=types.ListToolsResult(
|
||||
tools=[
|
||||
types.Tool(
|
||||
name="greet",
|
||||
description="Says hello",
|
||||
inputSchema={"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
self.session.call_tool = AsyncMock(
|
||||
return_value=types.CallToolResult(content=[types.TextContent(type="text", text="Hello!")])
|
||||
)
|
||||
self.session.send_ping = AsyncMock()
|
||||
self.is_connected = True
|
||||
|
||||
def get_mcp_client(self):
|
||||
return None
|
||||
|
||||
server = _TestServer(
|
||||
name="test",
|
||||
url="http://example.com/mcp",
|
||||
header_provider=lambda kw: {"X-Token": kw.get("token", "")},
|
||||
)
|
||||
async with server:
|
||||
await server.load_tools()
|
||||
await server.call_tool("greet", name="Alice", token="secret")
|
||||
|
||||
# After call_tool, the contextvar should be unset (reset to no value)
|
||||
with pytest.raises(LookupError):
|
||||
_mcp_call_headers.get()
|
||||
|
||||
|
||||
async def test_mcp_streamable_http_tool_without_header_provider():
|
||||
"""Test that call_tool works normally when no header_provider is configured."""
|
||||
|
||||
class _TestServer(MCPStreamableHTTPTool):
|
||||
async def connect(self):
|
||||
self.session = Mock(spec=ClientSession)
|
||||
self.session.list_tools = AsyncMock(
|
||||
return_value=types.ListToolsResult(
|
||||
tools=[
|
||||
types.Tool(
|
||||
name="greet",
|
||||
description="Says hello",
|
||||
inputSchema={"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
self.session.call_tool = AsyncMock(
|
||||
return_value=types.CallToolResult(content=[types.TextContent(type="text", text="Hello!")])
|
||||
)
|
||||
self.session.send_ping = AsyncMock()
|
||||
self.is_connected = True
|
||||
|
||||
def get_mcp_client(self):
|
||||
return None
|
||||
|
||||
server = _TestServer(
|
||||
name="test",
|
||||
url="http://example.com/mcp",
|
||||
)
|
||||
async with server:
|
||||
await server.load_tools()
|
||||
await server.call_tool("greet", name="Alice")
|
||||
server.session.call_tool.assert_called_once()
|
||||
|
||||
# Without header_provider, call_tool should delegate directly to MCPTool
|
||||
assert server._header_provider is None
|
||||
|
||||
|
||||
async def test_mcp_streamable_http_tool_header_provider_with_httpx_event_hook():
|
||||
"""Test that the httpx event hook injects headers from the contextvar."""
|
||||
import httpx
|
||||
|
||||
from agent_framework._mcp import MCP_DEFAULT_SSE_READ_TIMEOUT, MCP_DEFAULT_TIMEOUT, _mcp_call_headers
|
||||
|
||||
tool = MCPStreamableHTTPTool(
|
||||
name="test",
|
||||
url="http://example.com/mcp",
|
||||
header_provider=lambda kw: {"X-Custom": kw.get("custom", "")},
|
||||
)
|
||||
|
||||
try:
|
||||
with patch("agent_framework._mcp.streamable_http_client"):
|
||||
# Trigger get_mcp_client to set up the event hook
|
||||
tool.get_mcp_client()
|
||||
|
||||
# The tool should have created an httpx client with the event hook
|
||||
assert tool._httpx_client is not None
|
||||
assert tool._httpx_client.follow_redirects is True
|
||||
assert tool._httpx_client.timeout.connect == MCP_DEFAULT_TIMEOUT
|
||||
assert tool._httpx_client.timeout.read == MCP_DEFAULT_SSE_READ_TIMEOUT
|
||||
hooks = tool._httpx_client.event_hooks.get("request", [])
|
||||
assert len(hooks) == 1, "Expected one request event hook"
|
||||
|
||||
# Simulate what happens during a call_tool: contextvar is set
|
||||
token = _mcp_call_headers.set({"X-Custom": "test-value"})
|
||||
try:
|
||||
request = httpx.Request("POST", "http://example.com/mcp")
|
||||
await hooks[0](request)
|
||||
assert request.headers.get("X-Custom") == "test-value"
|
||||
finally:
|
||||
_mcp_call_headers.reset(token)
|
||||
finally:
|
||||
# Ensure any created httpx client is properly closed
|
||||
if getattr(tool, "_httpx_client", None) is not None:
|
||||
await tool._httpx_client.aclose()
|
||||
|
||||
|
||||
async def test_mcp_streamable_http_tool_header_provider_with_user_httpx_client():
|
||||
"""Test that header_provider works when the user provides their own httpx client."""
|
||||
import httpx
|
||||
|
||||
from agent_framework._mcp import _mcp_call_headers
|
||||
|
||||
user_client = httpx.AsyncClient(headers={"X-Base": "static"})
|
||||
|
||||
tool = MCPStreamableHTTPTool(
|
||||
name="test",
|
||||
url="http://example.com/mcp",
|
||||
http_client=user_client,
|
||||
header_provider=lambda kw: {"X-Dynamic": kw.get("dynamic", "")},
|
||||
)
|
||||
|
||||
with patch("agent_framework._mcp.streamable_http_client"):
|
||||
tool.get_mcp_client()
|
||||
|
||||
# The user's client should still be used
|
||||
assert tool._httpx_client is user_client
|
||||
hooks = user_client.event_hooks.get("request", [])
|
||||
assert len(hooks) == 1
|
||||
|
||||
# Verify the hook injects headers
|
||||
token = _mcp_call_headers.set({"X-Dynamic": "per-request"})
|
||||
try:
|
||||
request = httpx.Request("POST", "http://example.com/mcp")
|
||||
await hooks[0](request)
|
||||
assert request.headers.get("X-Dynamic") == "per-request"
|
||||
finally:
|
||||
_mcp_call_headers.reset(token)
|
||||
|
||||
await user_client.aclose()
|
||||
|
||||
|
||||
async def test_mcp_streamable_http_tool_header_provider_via_invoke_with_context():
|
||||
"""Test that header_provider receives kwargs via FunctionTool.invoke with FunctionInvocationContext.
|
||||
|
||||
This exercises the full pipeline: FunctionInvocationContext.kwargs -> FunctionTool.invoke
|
||||
-> MCPStreamableHTTPTool.call_tool -> header_provider.
|
||||
"""
|
||||
from agent_framework._mcp import _mcp_call_headers
|
||||
|
||||
observed_headers: list[dict[str, str]] = []
|
||||
original_call_tool = MCPStreamableHTTPTool.call_tool
|
||||
|
||||
async def spy_call_tool(self, tool_name, **kwargs):
|
||||
# Capture the contextvar value set by call_tool before delegating
|
||||
result = await original_call_tool(self, tool_name, **kwargs)
|
||||
try:
|
||||
observed_headers.append(_mcp_call_headers.get())
|
||||
except LookupError:
|
||||
observed_headers.append({})
|
||||
return result
|
||||
|
||||
class _TestServer(MCPStreamableHTTPTool):
|
||||
async def connect(self):
|
||||
self.session = Mock(spec=ClientSession)
|
||||
self.session.list_tools = AsyncMock(
|
||||
return_value=types.ListToolsResult(
|
||||
tools=[
|
||||
types.Tool(
|
||||
name="greet",
|
||||
description="Says hello",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
"required": ["name"],
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
self.session.call_tool = AsyncMock(
|
||||
return_value=types.CallToolResult(content=[types.TextContent(type="text", text="Hello!")])
|
||||
)
|
||||
self.session.send_ping = AsyncMock()
|
||||
self.is_connected = True
|
||||
|
||||
def get_mcp_client(self):
|
||||
return None
|
||||
|
||||
provider_received: list[dict] = []
|
||||
|
||||
def provider(kwargs):
|
||||
provider_received.append(dict(kwargs))
|
||||
return {"X-Some-Token": kwargs.get("some_token", "")}
|
||||
|
||||
server = _TestServer(
|
||||
name="test",
|
||||
url="http://example.com/mcp",
|
||||
header_provider=provider,
|
||||
)
|
||||
async with server:
|
||||
await server.load_tools()
|
||||
func = server.functions[0]
|
||||
|
||||
# Build a FunctionInvocationContext with runtime kwargs, as the agent framework would
|
||||
context = FunctionInvocationContext(
|
||||
function=func,
|
||||
arguments={"name": "Alice"},
|
||||
kwargs={"some_token": "my-secret"},
|
||||
)
|
||||
|
||||
with patch.object(MCPStreamableHTTPTool, "call_tool", spy_call_tool):
|
||||
result = await func.invoke(arguments={"name": "Alice"}, context=context)
|
||||
|
||||
# Verify the invoke produced a result
|
||||
assert isinstance(result, list)
|
||||
assert result[0].text == "Hello!"
|
||||
|
||||
# Verify header_provider was called with the runtime kwargs
|
||||
assert len(provider_received) == 1
|
||||
assert provider_received[0]["some_token"] == "my-secret"
|
||||
|
||||
# Verify session.call_tool was called with the tool arguments (not the runtime kwargs)
|
||||
server.session.call_tool.assert_called_once()
|
||||
call_args = server.session.call_tool.call_args
|
||||
assert call_args.kwargs.get("arguments", {}).get("name") == "Alice"
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -11,7 +11,7 @@ The Model Context Protocol (MCP) is an open standard for connecting AI agents to
|
||||
| Sample | File | Description |
|
||||
|--------|------|-------------|
|
||||
| **Agent as MCP Server** | [`agent_as_mcp_server.py`](agent_as_mcp_server.py) | Shows how to expose an Agent Framework agent as an MCP server that other AI applications can connect to |
|
||||
| **API Key Authentication** | [`mcp_api_key_auth.py`](mcp_api_key_auth.py) | Demonstrates API key authentication with MCP servers |
|
||||
| **API Key Authentication** | [`mcp_api_key_auth.py`](mcp_api_key_auth.py) | Demonstrates API key authentication with MCP servers using `header_provider`, runtime invocation kwargs, and a command-line API key argument |
|
||||
| **GitHub Integration with PAT** | [`mcp_github_pat.py`](mcp_github_pat.py) | Demonstrates connecting to GitHub's MCP server using Personal Access Token (PAT) authentication |
|
||||
|
||||
## Prerequisites
|
||||
@@ -19,5 +19,7 @@ The Model Context Protocol (MCP) is an open standard for connecting AI agents to
|
||||
- `OPENAI_API_KEY` environment variable
|
||||
- `OPENAI_RESPONSES_MODEL` environment variable
|
||||
|
||||
Run `mcp_api_key_auth.py` with the MCP API key as the first command-line argument.
|
||||
|
||||
For `mcp_github_pat.py`:
|
||||
- `GITHUB_PAT` - Your GitHub Personal Access Token (create at https://github.com/settings/tokens)
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Annotated, Any
|
||||
|
||||
import anyio
|
||||
from agent_framework import Agent, tool
|
||||
from agent_framework.openai import OpenAIResponsesClient
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
@@ -57,7 +57,7 @@ async def run() -> None:
|
||||
# Define an agent
|
||||
# Agent's name and description provide better context for AI model
|
||||
agent = Agent(
|
||||
client=OpenAIResponsesClient(),
|
||||
client=OpenAIChatClient(),
|
||||
name="RestaurantAgent",
|
||||
description="Answer questions about the menu.",
|
||||
tools=[get_specials, get_item_price],
|
||||
|
||||
@@ -1,20 +1,31 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
from agent_framework import Agent, MCPStreamableHTTPTool
|
||||
from agent_framework.openai import OpenAIResponsesClient
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from dotenv import load_dotenv
|
||||
from httpx import AsyncClient
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
"""
|
||||
MCP Authentication Example
|
||||
MCP API Key Authentication Example
|
||||
|
||||
This example demonstrates how to authenticate with MCP servers using API key headers.
|
||||
This sample demonstrates the runtime ``header_provider`` pattern for
|
||||
``MCPStreamableHTTPTool``. The MCP tool derives authentication headers from
|
||||
``function_invocation_kwargs`` passed to ``Agent.run(...)`` so the API key stays
|
||||
in runtime context instead of being baked into a shared ``httpx.AsyncClient``.
|
||||
|
||||
Replace the ``url`` parameter in the ``MCPStreamableHTTPTool`` with your authenticated server URL and
|
||||
run the sample with your API key as a command-line argument:
|
||||
python mcp_api_key_auth.py <your_api_key>
|
||||
|
||||
The ``header_provider`` here is just a simple lambda, but it can be a more complex function that retrieves and
|
||||
formats headers as needed, allowing for flexible authentication schemes.
|
||||
For more complex scenarios, you could implement token refresh logic or support multiple authentication methods
|
||||
within the header provider function.
|
||||
|
||||
For more authentication examples including OAuth 2.0 flows, see:
|
||||
- https://github.com/modelcontextprotocol/python-sdk/tree/main/examples/clients/simple-auth-client
|
||||
@@ -22,44 +33,28 @@ For more authentication examples including OAuth 2.0 flows, see:
|
||||
"""
|
||||
|
||||
|
||||
async def api_key_auth_example() -> None:
|
||||
"""Example of using API key authentication with MCP server."""
|
||||
# Configuration
|
||||
mcp_server_url = os.getenv("MCP_SERVER_URL", "your-mcp-server-url")
|
||||
api_key = os.getenv("MCP_API_KEY")
|
||||
async def api_key_auth_example(api_key: str) -> None:
|
||||
"""Run an agent against an MCP server using runtime-provided API key headers."""
|
||||
|
||||
# Create authentication headers
|
||||
# Common patterns:
|
||||
# - Bearer token: "Authorization": f"Bearer {api_key}"
|
||||
# - API key header: "X-API-Key": api_key
|
||||
# - Custom header: "Authorization": f"ApiKey {api_key}"
|
||||
auth_headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
}
|
||||
|
||||
# Create HTTP client with authentication headers
|
||||
http_client = AsyncClient(headers=auth_headers)
|
||||
|
||||
# Create MCP tool with the configured HTTP client
|
||||
async with (
|
||||
MCPStreamableHTTPTool(
|
||||
async with Agent(
|
||||
client=OpenAIChatClient(),
|
||||
name="Agent",
|
||||
instructions="You are a helpful assistant. Use your MCP tool when answering the user's question.",
|
||||
tools=MCPStreamableHTTPTool(
|
||||
name="MCP tool",
|
||||
description="MCP tool description",
|
||||
url=mcp_server_url,
|
||||
http_client=http_client, # Pass HTTP client with authentication headers
|
||||
) as mcp_tool,
|
||||
Agent(
|
||||
client=OpenAIResponsesClient(),
|
||||
name="Agent",
|
||||
instructions="You are a helpful assistant.",
|
||||
tools=mcp_tool,
|
||||
) as agent,
|
||||
):
|
||||
query = "What tools are available to you?"
|
||||
description="MCP tool description.",
|
||||
url="<your authenticated server url>",
|
||||
header_provider=lambda kwargs: {"Authorization": f"Bearer {kwargs['mcp_api_key']}"},
|
||||
),
|
||||
) as agent:
|
||||
query = "Use your MCP tool to tell me what tools are available to you."
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
result = await agent.run(
|
||||
query,
|
||||
function_invocation_kwargs={"mcp_api_key": api_key},
|
||||
)
|
||||
print(f"Agent: {result.text}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(api_key_auth_example())
|
||||
asyncio.run(api_key_auth_example(sys.argv[1]))
|
||||
|
||||
@@ -4,7 +4,7 @@ import asyncio
|
||||
import os
|
||||
|
||||
from agent_framework import Agent
|
||||
from agent_framework.openai import OpenAIResponsesClient
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from dotenv import load_dotenv
|
||||
|
||||
"""
|
||||
@@ -45,7 +45,7 @@ async def github_mcp_example() -> None:
|
||||
# 4. Create agent with the GitHub MCP tool using instance method
|
||||
# The MCP tool manages the connection to the MCP server and makes its tools available
|
||||
# Set approval_mode="never_require" to allow the MCP tool to execute without approval
|
||||
client = OpenAIResponsesClient()
|
||||
client = OpenAIChatClient()
|
||||
github_mcp_tool = client.get_mcp_tool(
|
||||
name="GitHub",
|
||||
url="https://api.githubcopilot.com/mcp/",
|
||||
|
||||
Reference in New Issue
Block a user