mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add Python parity sample for invoking Foundry Toolbox tools from declarative workflows (#5933)
* Add Python parity sample for invoking Foundry Toolbox tools from declarative workflows

* Python: address PR review on declarative toolbox sample

  Two security fixes for PR #5933:

  1. Add safe_mode flag to WorkflowFactory (default True) mirroring AgentFactory. Gates =Env.* exposure inside DeclarativeWorkflowState PowerFx symbols via _safe_mode_context, so workflow YAML loaded from untrusted sources no longer leaks the host's full os.environ snapshot into PowerFx evaluation. The flag is also forwarded to the internally-constructed AgentFactory so inline agent definitions follow the same policy.

  2. Pin the invoke_foundry_toolbox_mcp sample's _client_provider to the resolved toolbox endpoint. The bearer-authenticated httpx client is now only returned when MCPToolInvocation.server_url matches the toolbox URL case-insensitively; any other URL gets None (the default unauthenticated path), preventing the Foundry AAD bearer token from being attached to a mis-configured or injected server URL. Mirrors the .NET sample's httpClientProvider guard.

  The sample is updated to opt in to safe_mode=False because its YAML intentionally uses =Env.FOUNDRY_TOOLBOX_* to keep configuration in env vars under the developer's control.

  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix pyright issues.

* Addressed PR comments.

* Fix CI pipelines.

* Resolve PR comments

* Revamped sample to address PR comments.

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
bd4fc64b4d
commit
200488cb08
+128
-3
@@ -27,12 +27,15 @@ from __future__ import annotations
|
||||
|
||||
import locale
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from decimal import Decimal as _Decimal
|
||||
from enum import Enum
|
||||
from types import MappingProxyType
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from agent_framework import (
|
||||
@@ -58,6 +61,100 @@ else:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_ENV_REFERENCE_RE = re.compile(r"\bEnv\.([A-Za-z_][A-Za-z0-9_]*)")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DeclarativeEnvConfig:
    """Configuration that populates the PowerFx ``Env`` symbol for a workflow.

    Configuration values are always exposed under ``Env.<name>``;
    ``os.environ`` is consulted only when ``restrict_to_configuration``
    is ``False`` AND the name is listed in ``referenced_names`` (the
    allowlist of names the workflow YAML references in PowerFx expressions).

    Attributes:
        values: Caller-supplied configuration resolved by name when the
            workflow YAML references ``=Env.NAME``. Always exposed in
            the ``Env`` symbol regardless of ``restrict_to_configuration``.
        restrict_to_configuration: When ``True`` (default), the ``Env``
            symbol is populated exclusively from ``values``; ``os.environ``
            is never consulted. Set to ``False`` to additionally fall back
            to ``os.environ`` for names absent from ``values`` that appear
            in ``referenced_names``.
        referenced_names: The set of ``Env.NAME`` symbols discovered in
            PowerFx expressions inside the workflow definition. The
            ``os.environ`` fallback is constrained to this allowlist so
            unrelated environment variables never enter the PowerFx scope.
    """

    values: Mapping[str, str] = field(default_factory=lambda: MappingProxyType({}))
    restrict_to_configuration: bool = True
    referenced_names: frozenset[str] = field(default_factory=lambda: frozenset[str]())

    def __post_init__(self) -> None:
        """Snapshot ``values`` and ``referenced_names`` into immutable copies."""
        # Defensive snapshots so the frozen guarantee extends to the
        # contents of ``values`` / ``referenced_names``: caller mutations
        # to the original objects after construction cannot leak into
        # ``resolve()``. ``object.__setattr__`` is required because the
        # dataclass is frozen.
        object.__setattr__(self, "values", MappingProxyType(dict(self.values)))
        object.__setattr__(self, "referenced_names", frozenset(self.referenced_names))

    def resolve(self) -> dict[str, str]:
        """Return the resolved ``Env`` symbol mapping for the workflow.

        Configuration values are always included (stringified).
        ``os.environ`` is consulted only when ``restrict_to_configuration``
        is ``False`` and the name appears in ``referenced_names``, so
        unrelated environment variables never enter the PowerFx scope.
        Configuration values always win over the environment fallback.

        Returns:
            A new ``dict`` on every call; configuration entries come first
            (in ``values`` order), followed by environment fallbacks in
            sorted name order.
        """
        resolved = {name: str(value) for name, value in self.values.items()}
        if self.restrict_to_configuration:
            return resolved
        # Names already supplied by configuration are excluded, so the
        # environment can never override them. Sorting makes the key order
        # of the result independent of set hash ordering.
        for name in sorted(self.referenced_names.difference(resolved)):
            env_value = os.environ.get(name)
            if env_value is not None:
                resolved[name] = env_value
        return resolved
|
||||
|
||||
|
||||
def discover_env_references(node: Any) -> set[str]:
    """Discover ``Env.NAME`` references in PowerFx expressions inside ``node``.

    Walks any nested ``Mapping``/``list``/scalar structure and inspects every
    string value (mapping keys are not inspected). To avoid false positives
    from doc/description fields that happen to mention ``Env.SOMETHING`` as
    plain text, the scan only inspects strings that begin with ``=`` (PowerFx
    expression marker, matching the convention enforced by
    :meth:`DeclarativeWorkflowState.eval`).

    Args:
        node: A parsed workflow definition (typically the dict produced by
            ``yaml.safe_load``).

    Returns:
        The set of ``Env`` identifier names referenced in PowerFx
        expressions inside ``node``.
    """
    names: set[str] = set()

    def visit(value: Any) -> None:
        if isinstance(value, str):
            # Only PowerFx expressions are scanned; plain strings are ignored.
            if value.startswith("="):
                names.update(_ENV_REFERENCE_RE.findall(value))
            return
        if isinstance(value, Mapping):
            for inner in cast(Mapping[Any, Any], value).values():  # type: ignore[redundant-cast]
                visit(inner)
            return
        if isinstance(value, list):
            for item in cast(list[Any], value):  # type: ignore[redundant-cast]
                visit(item)

    visit(node)
    return names
|
||||
|
||||
|
||||
class ConversationData(TypedDict):
|
||||
"""Structure for conversation-related state data.
|
||||
|
||||
@@ -169,13 +266,18 @@ class DeclarativeWorkflowState:
|
||||
- Conversation: Conversation history
|
||||
"""
|
||||
|
||||
def __init__(self, state: State, env_config: DeclarativeEnvConfig | None = None):
    """Initialize with a State instance.

    Args:
        state: The workflow's state for persistence
        env_config: Configuration that populates the PowerFx ``Env``
            symbol when ``_to_powerfx_symbols`` is called. Defaults to
            an empty configuration which results in no ``Env`` binding,
            matching the safe default of the :class:`WorkflowFactory`.
    """
    self._state = state
    # An omitted config falls back to an empty DeclarativeEnvConfig.
    self._env_config = env_config if env_config is not None else DeclarativeEnvConfig()
|
||||
|
||||
def initialize(self, inputs: Mapping[str, Any] | None = None) -> None:
|
||||
"""Initialize the declarative state with inputs.
|
||||
@@ -714,6 +816,14 @@ class DeclarativeWorkflowState:
|
||||
# Custom namespaces
|
||||
**state_data.get("Custom", {}),
|
||||
}
|
||||
# Resolve the ``Env`` symbol from the workflow-level
|
||||
# :class:`DeclarativeEnvConfig`. When both ``values`` and the
|
||||
# ``os.environ`` allowlist produce no entries the symbol is
|
||||
# omitted so ``=Env.X`` falls back to the literal expression
|
||||
# string (preserving the legacy "unbound identifier" behaviour).
|
||||
env_bound = self._env_config.resolve()
|
||||
if env_bound:
|
||||
symbols["Env"] = env_bound
|
||||
# Debug log the Local symbols to help diagnose type issues
|
||||
if local_data:
|
||||
for key, value in local_data.items():
|
||||
@@ -867,6 +977,11 @@ class DeclarativeActionExecutor(Executor):
|
||||
action_id = id or action_def.get("id") or f"{action_def.get('kind', 'action')}_{hash(str(action_def)) % 10000}"
|
||||
super().__init__(id=action_id, defer_discovery=True)
|
||||
self._action_def = action_def
|
||||
# The active :class:`DeclarativeEnvConfig` is stamped onto the
|
||||
# executor by :class:`DeclarativeWorkflowBuilder` after construction.
|
||||
# Defaults to an empty configuration so direct ``DeclarativeActionExecutor``
|
||||
# construction (e.g. in unit tests) doesn't expose ``os.environ``.
|
||||
self._declarative_env_config: DeclarativeEnvConfig = DeclarativeEnvConfig()
|
||||
|
||||
# Manually register handlers after initialization
|
||||
self._handlers = {}
|
||||
@@ -874,6 +989,16 @@ class DeclarativeActionExecutor(Executor):
|
||||
self._discover_handlers()
|
||||
self._discover_response_handlers()
|
||||
|
||||
def set_declarative_env_config(self, env_config: DeclarativeEnvConfig) -> None:
    """Set the workflow-level :class:`DeclarativeEnvConfig` for this executor.

    Called by :class:`DeclarativeWorkflowBuilder` after each executor is
    created so that ``_to_powerfx_symbols`` populates the ``Env`` symbol
    according to the caller-supplied configuration on the
    :class:`WorkflowFactory`.

    Args:
        env_config: The configuration to store on this executor; it replaces
            whatever configuration was previously set.
    """
    self._declarative_env_config = env_config
|
||||
|
||||
@property
|
||||
def action_def(self) -> dict[str, Any]:
|
||||
"""Get the action definition."""
|
||||
@@ -886,7 +1011,7 @@ class DeclarativeActionExecutor(Executor):
|
||||
|
||||
def _get_state(self, state: State) -> DeclarativeWorkflowState:
    """Get the declarative workflow state wrapper.

    The wrapper is created with this executor's ``_declarative_env_config``
    so the ``Env`` configuration set on the executor is passed to the state.

    Args:
        state: The workflow state to wrap.

    Returns:
        A new :class:`DeclarativeWorkflowState` around ``state``.
    """
    return DeclarativeWorkflowState(state, env_config=self._declarative_env_config)
|
||||
|
||||
async def _ensure_state_initialized(
|
||||
self,
|
||||
|
||||
+16
@@ -24,6 +24,7 @@ from agent_framework import (
|
||||
from ._declarative_base import (
|
||||
ConditionResult,
|
||||
DeclarativeActionExecutor,
|
||||
DeclarativeEnvConfig,
|
||||
LoopIterationResult,
|
||||
)
|
||||
from ._errors import DeclarativeWorkflowError
|
||||
@@ -140,6 +141,7 @@ class DeclarativeWorkflowBuilder:
|
||||
max_iterations: int | None = None,
|
||||
http_request_handler: HttpRequestHandler | None = None,
|
||||
mcp_tool_handler: MCPToolHandler | None = None,
|
||||
env_config: DeclarativeEnvConfig | None = None,
|
||||
):
|
||||
"""Initialize the builder.
|
||||
|
||||
@@ -158,6 +160,10 @@ class DeclarativeWorkflowBuilder:
|
||||
mcp_tool_handler: Handler used to dispatch InvokeMcpTool calls.
|
||||
Must be supplied when the workflow contains any InvokeMcpTool;
|
||||
otherwise build raises ``DeclarativeWorkflowError``.
|
||||
env_config: Optional :class:`DeclarativeEnvConfig` controlling
|
||||
how the ``Env`` PowerFx symbol is populated for every
|
||||
executor built by this builder. Defaults to an empty
|
||||
configuration (``Env`` not exposed).
|
||||
"""
|
||||
self._yaml_def = yaml_definition
|
||||
self._workflow_id = workflow_id or yaml_definition.get("name", "declarative_workflow")
|
||||
@@ -171,6 +177,7 @@ class DeclarativeWorkflowBuilder:
|
||||
self._seen_explicit_ids: set[str] = set() # Track explicit IDs for duplicate detection
|
||||
self._http_request_handler = http_request_handler
|
||||
self._mcp_tool_handler = mcp_tool_handler
|
||||
self._env_config: DeclarativeEnvConfig = env_config if env_config is not None else DeclarativeEnvConfig()
|
||||
# Resolve max_iterations: explicit arg > YAML maxTurns > core default
|
||||
resolved = max_iterations if max_iterations is not None else yaml_definition.get("maxTurns")
|
||||
if resolved is not None and (not isinstance(resolved, int) or resolved <= 0):
|
||||
@@ -221,6 +228,15 @@ class DeclarativeWorkflowBuilder:
|
||||
# Resolve pending gotos (back-edges for loops, forward-edges for jumps)
|
||||
self._resolve_pending_gotos(builder)
|
||||
|
||||
# Stamp the resolved DeclarativeEnvConfig onto every executor so they
|
||||
# expose the configured Env binding through their _get_state(). This
|
||||
# happens after _create_executors_for_actions and _resolve_pending_gotos
|
||||
# so it covers the entry node, join nodes, evaluators, foreach
|
||||
# init/next/exit nodes, and goto placeholders.
|
||||
for executor in self._executors.values():
|
||||
if isinstance(executor, DeclarativeActionExecutor):
|
||||
executor.set_declarative_env_config(self._env_config)
|
||||
|
||||
return builder.build()
|
||||
|
||||
def _validate_workflow(self, actions: list[dict[str, Any]]) -> None:
|
||||
|
||||
@@ -26,6 +26,7 @@ from agent_framework import (
|
||||
)
|
||||
|
||||
from .._loader import AgentFactory
|
||||
from ._declarative_base import DeclarativeEnvConfig, discover_env_references
|
||||
from ._declarative_builder import DeclarativeWorkflowBuilder
|
||||
from ._errors import DeclarativeWorkflowError
|
||||
from ._http_handler import HttpRequestHandler
|
||||
@@ -93,6 +94,8 @@ class WorkflowFactory:
|
||||
max_iterations: int | None = None,
|
||||
http_request_handler: HttpRequestHandler | None = None,
|
||||
mcp_tool_handler: MCPToolHandler | None = None,
|
||||
configuration: Mapping[str, str] | None = None,
|
||||
restrict_env_to_configuration: bool = True,
|
||||
) -> None:
|
||||
"""Initialize the workflow factory.
|
||||
|
||||
@@ -119,6 +122,23 @@ class WorkflowFactory:
|
||||
for a default backed by :class:`agent_framework.MCPStreamableHTTPTool`,
|
||||
or supply your own implementation to enforce SSRF guards, allowlisting,
|
||||
or auth/connection resolution.
|
||||
configuration: Optional mapping that populates the PowerFx ``Env``
|
||||
symbol referenced from workflow YAML expressions (e.g.
|
||||
``=Env.MY_KEY``). Keys supplied here are always exposed
|
||||
under ``Env.<key>``; the process ``os.environ`` is consulted
|
||||
only when ``restrict_env_to_configuration`` is ``False``.
|
||||
When neither source produces a value the ``Env`` symbol is
|
||||
omitted so ``=Env.X`` evaluates to the literal expression
|
||||
string.
|
||||
restrict_env_to_configuration: When ``True`` (default), the
|
||||
``Env`` PowerFx symbol is populated exclusively from
|
||||
``configuration``; ``os.environ`` is never consulted. Set to
|
||||
``False`` to additionally fall back to ``os.environ`` for
|
||||
names absent from ``configuration`` that the workflow YAML
|
||||
explicitly references. The fallback is constrained to names
|
||||
discovered in PowerFx expressions inside the workflow
|
||||
definition so unrelated environment variables never enter
|
||||
the PowerFx scope.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
@@ -151,6 +171,18 @@ class WorkflowFactory:
|
||||
checkpoint_storage=FileCheckpointStorage("./checkpoints"),
|
||||
env_file=".env",
|
||||
)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agent_framework.declarative import WorkflowFactory
|
||||
|
||||
# Inject named values for =Env.* references in the workflow YAML
|
||||
factory = WorkflowFactory(
|
||||
configuration={
|
||||
"MY_SERVER_URL": "https://example.com",
|
||||
"MY_TOOL_NAME": "search",
|
||||
},
|
||||
)
|
||||
"""
|
||||
self._agent_factory = agent_factory or AgentFactory(env_file_path=env_file)
|
||||
self._agents: dict[str, SupportsAgentRun | AgentExecutor] = dict(agents) if agents else {}
|
||||
@@ -160,6 +192,8 @@ class WorkflowFactory:
|
||||
self._max_iterations = max_iterations
|
||||
self._http_request_handler = http_request_handler
|
||||
self._mcp_tool_handler = mcp_tool_handler
|
||||
self._configuration: dict[str, str] = dict(configuration) if configuration else {}
|
||||
self._restrict_env_to_configuration = restrict_env_to_configuration
|
||||
|
||||
def create_workflow_from_yaml_path(
|
||||
self,
|
||||
@@ -394,6 +428,16 @@ class WorkflowFactory:
|
||||
if description:
|
||||
normalized_def["description"] = description
|
||||
|
||||
# Build the DeclarativeEnvConfig from the factory's configuration and the
|
||||
# set of Env references actually used in the workflow PowerFx expressions.
|
||||
# The referenced-name allowlist constrains ``os.environ`` fallback (when
|
||||
# enabled) so unrelated variables never enter the PowerFx scope.
|
||||
env_config = DeclarativeEnvConfig(
|
||||
values=dict(self._configuration),
|
||||
restrict_to_configuration=self._restrict_env_to_configuration,
|
||||
referenced_names=frozenset(discover_env_references(normalized_def)),
|
||||
)
|
||||
|
||||
# Build the graph-based workflow, passing agents and tools for specialized executors
|
||||
try:
|
||||
graph_builder = DeclarativeWorkflowBuilder(
|
||||
@@ -405,6 +449,7 @@ class WorkflowFactory:
|
||||
max_iterations=self._max_iterations,
|
||||
http_request_handler=self._http_request_handler,
|
||||
mcp_tool_handler=self._mcp_tool_handler,
|
||||
env_config=env_config,
|
||||
)
|
||||
workflow = graph_builder.build()
|
||||
except ValueError as e:
|
||||
|
||||
@@ -33,7 +33,7 @@ import logging
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Protocol, cast, runtime_checkable
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, cast, runtime_checkable
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -194,6 +194,21 @@ class DefaultMCPToolHandler:
|
||||
Defaults to ``32``.
|
||||
"""
|
||||
|
||||
LIST_TOOLS_TOOL_NAME: ClassVar[str] = "tools/list"
"""Reserved ``tool_name`` that maps an :class:`MCPToolHandler` invocation
to the MCP protocol ``tools/list`` discovery operation.

The constant matches the underlying MCP method name so a single
string travels unchanged through host code, YAML, and the protocol
wire. When this handler receives an invocation with this name it
pages through ``session.list_tools()`` and returns the catalog as a
single ``TextContent`` containing JSON of shape
``{"tools": [{name, description, inputSchema, outputSchema}, ...]}``.
Workflows can reference this name from an ``InvokeMcpTool`` declarative
action to introspect a server's tool surface without an extra round-trip
from host code.
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -217,10 +232,27 @@ class DefaultMCPToolHandler:
|
||||
self._closed = False
|
||||
|
||||
async def invoke_tool(self, invocation: MCPToolInvocation) -> MCPToolResult:
|
||||
"""Invoke ``invocation.tool_name`` on the cached MCP client for the server."""
|
||||
"""Invoke ``invocation.tool_name`` on the cached MCP client for the server.
|
||||
|
||||
The reserved name :attr:`LIST_TOOLS_TOOL_NAME` (``"tools/list"``) is
|
||||
intercepted client-side: instead of being forwarded as a tool call,
|
||||
it is translated to an MCP ``session.list_tools()`` discovery
|
||||
operation (paginated automatically) and returned as a single
|
||||
``TextContent`` containing a JSON tool catalog.
|
||||
"""
|
||||
from agent_framework import Content
|
||||
from agent_framework.exceptions import ToolExecutionException
|
||||
|
||||
# Reserved-name args validation runs before connect: rejecting bad
|
||||
# input shouldn't require establishing an MCP session.
|
||||
if invocation.tool_name == self.LIST_TOOLS_TOOL_NAME and invocation.arguments:
|
||||
message = f"The reserved MCP '{self.LIST_TOOLS_TOOL_NAME}' operation does not accept tool arguments."
|
||||
return MCPToolResult(
|
||||
outputs=[Content.from_text(f"Error: {message}")],
|
||||
is_error=True,
|
||||
error_message=message,
|
||||
)
|
||||
|
||||
try:
|
||||
entry = await self._get_or_create_entry(invocation)
|
||||
except Exception as exc:
|
||||
@@ -240,6 +272,8 @@ class DefaultMCPToolHandler:
|
||||
)
|
||||
|
||||
try:
|
||||
if invocation.tool_name == self.LIST_TOOLS_TOOL_NAME:
|
||||
return await self._invoke_list_tools(entry)
|
||||
raw = await entry.tool.call_tool(invocation.tool_name, **invocation.arguments)
|
||||
except ToolExecutionException as exc:
|
||||
logger.info(
|
||||
@@ -284,6 +318,59 @@ class DefaultMCPToolHandler:
|
||||
outputs = list(raw)
|
||||
return MCPToolResult(outputs=outputs)
|
||||
|
||||
@staticmethod
async def _invoke_list_tools(entry: _CacheEntry) -> MCPToolResult:
    """Handle the reserved :attr:`LIST_TOOLS_TOOL_NAME` invocation.

    Pages through ``session.list_tools()`` (mirroring the pagination loop
    in :meth:`agent_framework.MCPTool.load_tools`) and serialises the
    full catalog as a single ``TextContent`` containing JSON of shape
    ``{"tools": [{name, description, inputSchema, outputSchema}, ...]}``.

    The output shape, property names, and property order are stable so
    downstream PowerFx expressions can rely on the schema. ``indent=2``
    produces human-readable JSON for the conversation log;
    ``allow_nan=False`` guards against producing non-conformant JSON
    ``NaN``/``Infinity`` tokens if a misbehaving server returns such
    values in a schema.

    Args:
        entry: The cache entry whose ``tool.session`` is used for discovery.

    Returns:
        An :class:`MCPToolResult` holding the JSON catalog, or an error
        result when the tool has no session or the server repeats a
        pagination cursor it already issued.
    """
    from agent_framework import Content

    session = getattr(entry.tool, "session", None)
    if session is None:
        message = "MCP session is not connected; cannot list tools."
        return MCPToolResult(
            outputs=[Content.from_text(f"Error: {message}")],
            is_error=True,
            error_message=message,
        )

    # Lazy import keeps ``mcp`` types out of module import time.
    from mcp import types as mcp_types

    collected: list[Any] = []
    seen_cursors: set[Any] = set()
    params: mcp_types.PaginatedRequestParams | None = None
    while True:
        tool_list = await session.list_tools(params=params)
        collected.extend(tool_list.tools)
        next_cursor = getattr(tool_list, "nextCursor", None)
        if not next_cursor:
            break
        # A cursor that was already issued would make this loop request the
        # same page forever; stop and report instead of hanging.
        if next_cursor in seen_cursors:
            message = "MCP server returned a repeated pagination cursor; cannot list tools."
            return MCPToolResult(
                outputs=[Content.from_text(f"Error: {message}")],
                is_error=True,
                error_message=message,
            )
        seen_cursors.add(next_cursor)
        params = mcp_types.PaginatedRequestParams(cursor=next_cursor)

    # Key order inside each tool dict is the documented, stable order.
    payload = {
        "tools": [
            {
                "name": tool.name,
                "description": tool.description,
                "inputSchema": tool.inputSchema,
                "outputSchema": tool.outputSchema,
            }
            for tool in collected
        ],
    }
    return MCPToolResult(outputs=[Content.from_text(json.dumps(payload, indent=2, allow_nan=False))])
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Close all cached MCP clients and the owned httpx clients.
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ owned-vs-caller httpx close semantics.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
@@ -33,6 +34,55 @@ pytestmark = pytest.mark.skipif(
|
||||
)
|
||||
|
||||
|
||||
class FakeListToolsResult:  # noqa: B903 - mimics ``mcp.types.ListToolsResult`` shape, not a value type
    """Stand-in for ``mcp.types.ListToolsResult`` returned by ``session.list_tools()``.

    Attributes are ``tools`` (the page of tools) and ``nextCursor`` (the
    cursor for the following page, or ``None`` when there is none).
    """

    def __init__(self, tools: list[Any], next_cursor: str | None = None) -> None:
        self.tools = tools
        # camelCase on purpose: the handler reads the ``nextCursor`` attribute.
        self.nextCursor = next_cursor
|
||||
|
||||
|
||||
class FakeMcpTool:
    """Stand-in for an MCP ``Tool`` (subset used by ``_invoke_list_tools``).

    Exposes ``name``, ``description``, ``inputSchema`` and ``outputSchema``.
    When ``inputSchema`` is omitted it defaults to an empty object schema.
    """

    def __init__(
        self,
        name: str,
        description: str | None = None,
        inputSchema: dict[str, Any] | None = None,
        outputSchema: dict[str, Any] | None = None,
    ) -> None:
        self.name = name
        self.description = description
        # A fresh dict per instance so tests cannot share a mutable default.
        self.inputSchema = inputSchema if inputSchema is not None else {"type": "object", "properties": {}}
        self.outputSchema = outputSchema
|
||||
|
||||
|
||||
class FakeMcpSession:
    """Stand-in for ``mcp.ClientSession``.

    ``list_tools_pages`` lets a test enqueue multiple paginated responses;
    the N-th call returns the N-th page. When it is None (default), or when
    more calls are made than pages were enqueued, an empty single-page
    result is returned. ``list_tools_error`` raises a synthetic error on
    every call for as long as it is set. Every call's ``params`` argument
    is recorded in ``list_tools_calls``, including calls that raise.
    """

    def __init__(self) -> None:
        self.list_tools_pages: list[FakeListToolsResult] | None = None
        self.list_tools_calls: list[Any] = []
        self.list_tools_error: BaseException | None = None

    async def list_tools(self, params: Any = None) -> FakeListToolsResult:
        # Record first so failed calls are still visible to assertions.
        self.list_tools_calls.append(params)
        if self.list_tools_error is not None:
            raise self.list_tools_error
        if self.list_tools_pages is None:
            return FakeListToolsResult(tools=[])
        # The call count (including this call) selects the page to serve.
        index = len(self.list_tools_calls) - 1
        if index >= len(self.list_tools_pages):
            return FakeListToolsResult(tools=[])
        return self.list_tools_pages[index]
|
||||
|
||||
|
||||
class FakeTool:
|
||||
"""Stand-in for ``MCPStreamableHTTPTool``.
|
||||
|
||||
@@ -50,6 +100,7 @@ class FakeTool:
|
||||
self.connect_error: BaseException | None = None
|
||||
self.call_handler: Any = lambda **_a: [Content.from_text("ok")]
|
||||
self._httpx_client: httpx.AsyncClient | None = None
|
||||
self.session: FakeMcpSession | None = None
|
||||
# Mimic MCPStreamableHTTPTool: when no caller client AND header_provider
|
||||
# is set, lazily allocate an owned httpx client during connect.
|
||||
FakeTool.instances.append(self)
|
||||
@@ -63,6 +114,9 @@ class FakeTool:
|
||||
# Mimic lazy httpx allocation when no client provided AND header_provider set.
|
||||
if self.kwargs.get("http_client") is None and self.kwargs.get("header_provider") is not None:
|
||||
self._httpx_client = httpx.AsyncClient()
|
||||
# Mimic MCPStreamableHTTPTool: a live session becomes available after connect.
|
||||
if self.session is None:
|
||||
self.session = FakeMcpSession()
|
||||
|
||||
async def close(self) -> None:
|
||||
self.close_count += 1
|
||||
@@ -541,3 +595,185 @@ class TestCacheKey:
|
||||
k1 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"X": "Bearer-A"})
|
||||
k2 = DefaultMCPToolHandler._cache_key("https://x/", None, None, {"X": "bearer-a"})
|
||||
assert k1 != k2
|
||||
|
||||
|
||||
# ---------- tools/list reserved name --------------------------------------
|
||||
|
||||
|
||||
class TestListTools:
    """Exercise the reserved :attr:`DefaultMCPToolHandler.LIST_TOOLS_TOOL_NAME` interception path."""

    @pytest.mark.asyncio
    async def test_list_tools_returns_json_catalog(self) -> None:
        """The catalog is returned as one text output containing the expected JSON."""
        handler = DefaultMCPToolHandler()
        with _patch_tool():
            # Prime the cache so the FakeTool session exists.
            await handler.invoke_tool(_invocation())
            FakeTool.instances[0].session.list_tools_pages = [  # type: ignore[union-attr]
                FakeListToolsResult(
                    tools=[
                        FakeMcpTool(
                            name="search",
                            description="Search docs",
                            inputSchema={"type": "object", "properties": {"q": {"type": "string"}}},
                            outputSchema={"type": "object"},
                        ),
                        FakeMcpTool(name="echo", description=None, outputSchema=None),
                    ],
                ),
            ]
            result = await handler.invoke_tool(_invocation(tool_name=DefaultMCPToolHandler.LIST_TOOLS_TOOL_NAME))
        assert result.is_error is False
        assert len(result.outputs) == 1
        payload = json.loads(result.outputs[0].text)  # type: ignore[reportAttributeAccessIssue]
        assert payload == {
            "tools": [
                {
                    "name": "search",
                    "description": "Search docs",
                    "inputSchema": {"type": "object", "properties": {"q": {"type": "string"}}},
                    "outputSchema": {"type": "object"},
                },
                {
                    "name": "echo",
                    "description": None,
                    "inputSchema": {"type": "object", "properties": {}},
                    "outputSchema": None,
                },
            ],
        }

    @pytest.mark.asyncio
    async def test_list_tools_property_order_is_stable(self) -> None:
        """JSON property order is stable: name, description, inputSchema, outputSchema."""
        handler = DefaultMCPToolHandler()
        with _patch_tool():
            await handler.invoke_tool(_invocation())
            FakeTool.instances[0].session.list_tools_pages = [  # type: ignore[union-attr]
                FakeListToolsResult(tools=[FakeMcpTool(name="t1", description="d")]),
            ]
            result = await handler.invoke_tool(_invocation(tool_name=DefaultMCPToolHandler.LIST_TOOLS_TOOL_NAME))
        text = result.outputs[0].text  # type: ignore[reportAttributeAccessIssue]
        name_idx = text.find('"name"')
        desc_idx = text.find('"description"')
        input_idx = text.find('"inputSchema"')
        output_idx = text.find('"outputSchema"')
        assert 0 <= name_idx < desc_idx < input_idx < output_idx

    @pytest.mark.asyncio
    async def test_list_tools_indented_output(self) -> None:
        """Output is JSON with a 2-space indent so the conversation log is human-readable."""
        handler = DefaultMCPToolHandler()
        with _patch_tool():
            await handler.invoke_tool(_invocation())
            FakeTool.instances[0].session.list_tools_pages = [  # type: ignore[union-attr]
                FakeListToolsResult(tools=[FakeMcpTool(name="t1")]),
            ]
            result = await handler.invoke_tool(_invocation(tool_name=DefaultMCPToolHandler.LIST_TOOLS_TOOL_NAME))
        text = result.outputs[0].text  # type: ignore[reportAttributeAccessIssue]
        # Indented output contains newlines and a 2-space indented key.
        assert "\n  " in text

    @pytest.mark.asyncio
    async def test_list_tools_rejects_arguments(self) -> None:
        """Reserved name does NOT accept tool arguments. Fails fast before connect."""
        handler = DefaultMCPToolHandler()
        with _patch_tool():
            result = await handler.invoke_tool(
                _invocation(tool_name=DefaultMCPToolHandler.LIST_TOOLS_TOOL_NAME, arguments={"q": "test"}),
            )
        assert result.is_error is True
        assert "does not accept tool arguments" in (result.error_message or "")
        # Args validation runs before connect, so no tool was instantiated.
        assert FakeTool.instances == []

    @pytest.mark.asyncio
    async def test_list_tools_empty_args_dict_is_accepted(self) -> None:
        """An empty arguments dict is equivalent to no arguments."""
        handler = DefaultMCPToolHandler()
        with _patch_tool():
            await handler.invoke_tool(_invocation())
            result = await handler.invoke_tool(
                _invocation(tool_name=DefaultMCPToolHandler.LIST_TOOLS_TOOL_NAME, arguments={}),
            )
        assert result.is_error is False

    @pytest.mark.asyncio
    async def test_list_tools_paginates(self) -> None:
        """Pagination loop calls list_tools repeatedly until nextCursor is empty."""
        handler = DefaultMCPToolHandler()
        with _patch_tool():
            await handler.invoke_tool(_invocation())
            FakeTool.instances[0].session.list_tools_pages = [  # type: ignore[union-attr]
                FakeListToolsResult(tools=[FakeMcpTool(name="a")], next_cursor="cursor1"),
                FakeListToolsResult(tools=[FakeMcpTool(name="b")], next_cursor="cursor2"),
                FakeListToolsResult(tools=[FakeMcpTool(name="c")], next_cursor=None),
            ]
            result = await handler.invoke_tool(_invocation(tool_name=DefaultMCPToolHandler.LIST_TOOLS_TOOL_NAME))
        payload = json.loads(result.outputs[0].text)  # type: ignore[reportAttributeAccessIssue]
        assert [t["name"] for t in payload["tools"]] == ["a", "b", "c"]
        session = FakeTool.instances[0].session
        assert session is not None
        assert len(session.list_tools_calls) == 3
        # First call has no cursor; second/third use the cursor from the prior page.
        assert session.list_tools_calls[0] is None
        assert getattr(session.list_tools_calls[1], "cursor", None) == "cursor1"
        assert getattr(session.list_tools_calls[2], "cursor", None) == "cursor2"

    @pytest.mark.asyncio
    async def test_list_tools_shares_cache_with_call_tool(self) -> None:
        """tools/list reuses the same cached MCP session as a regular call_tool."""
        handler = DefaultMCPToolHandler()
        with _patch_tool():
            await handler.invoke_tool(_invocation(tool_name="search"))
            await handler.invoke_tool(_invocation(tool_name=DefaultMCPToolHandler.LIST_TOOLS_TOOL_NAME))
        assert len(FakeTool.instances) == 1
        assert FakeTool.instances[0].connect_count == 1

    @pytest.mark.asyncio
    async def test_list_tools_propagates_session_errors_as_error_result(self) -> None:
        """Errors raised by session.list_tools become MCPToolResult(is_error=True), not crashes."""
        handler = DefaultMCPToolHandler()
        with _patch_tool():
            await handler.invoke_tool(_invocation())
            FakeTool.instances[0].session.list_tools_error = httpx.ReadTimeout("read timed out")  # type: ignore[union-attr]
            result = await handler.invoke_tool(_invocation(tool_name=DefaultMCPToolHandler.LIST_TOOLS_TOOL_NAME))
        assert result.is_error is True
        assert "ReadTimeout" in (result.error_message or "")

    @pytest.mark.asyncio
    async def test_list_tools_returns_error_when_session_is_none(self) -> None:
        """If somehow the cached tool has no session, return a clear error rather than crashing."""
        handler = DefaultMCPToolHandler()
        with _patch_tool():
            await handler.invoke_tool(_invocation())
            FakeTool.instances[0].session = None
            result = await handler.invoke_tool(_invocation(tool_name=DefaultMCPToolHandler.LIST_TOOLS_TOOL_NAME))
        assert result.is_error is True
        assert "not connected" in (result.error_message or "")

    @pytest.mark.asyncio
    async def test_list_tools_does_not_call_call_tool(self) -> None:
        """The reserved name is intercepted; the inner call_tool path is bypassed."""
        handler = DefaultMCPToolHandler()
        call_tool_invoked = False

        def fail(**_a: Any) -> Any:
            nonlocal call_tool_invoked
            call_tool_invoked = True
            raise AssertionError("call_tool should not run for tools/list")

        with _patch_tool():
            await handler.invoke_tool(_invocation())
            FakeTool.instances[0].call_handler = fail
            FakeTool.instances[0].session.list_tools_pages = [  # type: ignore[union-attr]
                FakeListToolsResult(tools=[]),
            ]
            result = await handler.invoke_tool(_invocation(tool_name=DefaultMCPToolHandler.LIST_TOOLS_TOOL_NAME))
        assert call_tool_invoked is False
        assert result.is_error is False

    def test_class_attribute_value(self) -> None:
        """The reserved name equals the MCP protocol method name."""
        # Constant must equal the MCP protocol method name so a single
        # string travels unchanged through host code, YAML, and the wire.
        assert DefaultMCPToolHandler.LIST_TOOLS_TOOL_NAME == "tools/list"
|
||||
|
||||
Reference in New Issue
Block a user