Python: fix(ag-ui): add MCP tool support for AG-UI approval flows (#3212)

* add MCP tool support for AG-UI approval flows

* use attribute in place of property
This commit is contained in:
Evan Mattson
2026-01-15 11:34:11 +09:00
committed by GitHub
Unverified
parent 80b25a782b
commit 620da7a829
7 changed files with 234 additions and 111 deletions
@@ -3,59 +3,85 @@
"""Tool handling helpers."""
import logging
from typing import Any
from typing import TYPE_CHECKING, Any
from agent_framework import BaseChatClient, ChatAgent
from agent_framework import BaseChatClient
if TYPE_CHECKING:
from agent_framework import AgentProtocol
logger = logging.getLogger(__name__)
def collect_server_tools(agent: Any) -> list[Any]:
"""Collect server tools from ChatAgent or duck-typed agent."""
if isinstance(agent, ChatAgent):
tools_from_agent = agent.default_options.get("tools")
server_tools = list(tools_from_agent) if tools_from_agent else []
logger.info(f"[TOOLS] Agent has {len(server_tools)} configured tools")
for tool in server_tools:
tool_name = getattr(tool, "name", "unknown")
approval_mode = getattr(tool, "approval_mode", None)
logger.info(f"[TOOLS] - {tool_name}: approval_mode={approval_mode}")
return server_tools
def _collect_mcp_tool_functions(mcp_tools: list[Any]) -> list[Any]:
"""Extract functions from connected MCP tools.
try:
default_options_attr = getattr(agent, "default_options", None)
if default_options_attr is not None:
if isinstance(default_options_attr, dict):
return default_options_attr.get("tools") or []
return getattr(default_options_attr, "tools", None) or []
except AttributeError:
Args:
mcp_tools: List of MCP tool instances.
Returns:
List of functions from connected MCP tools.
"""
functions: list[Any] = []
for mcp_tool in mcp_tools:
if getattr(mcp_tool, "is_connected", False) and hasattr(mcp_tool, "functions"):
functions.extend(mcp_tool.functions)
return functions
def collect_server_tools(agent: "AgentProtocol") -> list[Any]:
"""Collect server tools from an agent.
This includes both regular tools from default_options and MCP tools.
MCP tools are stored separately for lifecycle management but their
functions need to be included for tool execution during approval flows.
Args:
agent: Agent instance to collect tools from. Works with ChatAgent
or any agent with default_options and optional mcp_tools attributes.
Returns:
List of tools including both regular tools and connected MCP tool functions.
"""
# Get tools from default_options
default_options = getattr(agent, "default_options", None)
if default_options is None:
return []
return []
tools_from_agent = default_options.get("tools") if isinstance(default_options, dict) else None
server_tools = list(tools_from_agent) if tools_from_agent else []
# Include functions from connected MCP tools (only available on ChatAgent)
mcp_tools = getattr(agent, "mcp_tools", None)
if mcp_tools:
server_tools.extend(_collect_mcp_tool_functions(mcp_tools))
logger.info(f"[TOOLS] Agent has {len(server_tools)} configured tools")
for tool in server_tools:
tool_name = getattr(tool, "name", "unknown")
approval_mode = getattr(tool, "approval_mode", None)
logger.info(f"[TOOLS] - {tool_name}: approval_mode={approval_mode}")
return server_tools
def register_additional_client_tools(agent: Any, client_tools: list[Any] | None) -> None:
"""Register client tools as additional declaration-only tools to avoid server execution."""
def register_additional_client_tools(agent: "AgentProtocol", client_tools: list[Any] | None) -> None:
"""Register client tools as additional declaration-only tools to avoid server execution.
Args:
agent: Agent instance to register tools on. Works with ChatAgent
or any agent with a chat_client attribute.
client_tools: List of client tools to register.
"""
if not client_tools:
return
if isinstance(agent, ChatAgent):
chat_client = agent.chat_client
if isinstance(chat_client, BaseChatClient) and chat_client.function_invocation_configuration is not None:
chat_client.function_invocation_configuration.additional_tools = client_tools
logger.debug(f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)")
chat_client = getattr(agent, "chat_client", None)
if chat_client is None:
return
try:
chat_client_attr = getattr(agent, "chat_client", None)
if chat_client_attr is not None:
fic = getattr(chat_client_attr, "function_invocation_configuration", None)
if fic is not None:
fic.additional_tools = client_tools # type: ignore[attr-defined]
logger.debug(
f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)"
)
except AttributeError:
return
if isinstance(chat_client, BaseChatClient) and chat_client.function_invocation_configuration is not None:
chat_client.function_invocation_configuration.additional_tools = client_tools
logger.debug(f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)")
def merge_tools(server_tools: list[Any], client_tools: list[Any] | None) -> list[Any] | None:
@@ -3,10 +3,17 @@
"""Tests for AG-UI orchestrators."""
from collections.abc import AsyncGenerator
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock
from agent_framework import AgentResponseUpdate, FunctionInvocationConfiguration, TextContent, ai_function
from agent_framework import (
AgentResponseUpdate,
BaseChatClient,
ChatAgent,
FunctionInvocationConfiguration,
TextContent,
ai_function,
)
from agent_framework_ag_ui._agent import AgentConfig
from agent_framework_ag_ui._orchestrators import DefaultOrchestrator, ExecutionContext
@@ -18,56 +25,53 @@ def server_tool() -> str:
return "server"
class DummyAgent:
"""Minimal agent stub to capture run_stream parameters."""
def _create_mock_chat_agent(
tools: list[Any] | None = None,
response_format: Any = None,
capture_tools: list[Any] | None = None,
capture_messages: list[Any] | None = None,
) -> ChatAgent:
"""Create a ChatAgent with mocked chat client for testing.
def __init__(self) -> None:
self.default_options: dict[str, Any] = {"tools": [server_tool], "response_format": None}
self.tools = [server_tool]
self.chat_client = SimpleNamespace(
function_invocation_configuration=FunctionInvocationConfiguration(),
)
self.seen_tools: list[Any] | None = None
Args:
tools: Tools to configure on the agent.
response_format: Response format to configure.
capture_tools: If provided, tools passed to run_stream will be appended here.
capture_messages: If provided, messages passed to run_stream will be appended here.
"""
mock_chat_client = MagicMock(spec=BaseChatClient)
mock_chat_client.function_invocation_configuration = FunctionInvocationConfiguration()
async def run_stream(
self,
agent = ChatAgent(
chat_client=mock_chat_client,
tools=tools or [server_tool],
response_format=response_format,
)
# Create a mock run_stream that captures parameters and yields a simple response
async def mock_run_stream(
messages: list[Any],
*,
thread: Any,
thread: Any = None,
tools: list[Any] | None = None,
**kwargs: Any,
) -> AsyncGenerator[AgentResponseUpdate, None]:
self.seen_tools = tools
if capture_tools is not None and tools is not None:
capture_tools.extend(tools)
if capture_messages is not None:
capture_messages.extend(messages)
yield AgentResponseUpdate(contents=[TextContent(text="ok")], role="assistant")
# Patch the run_stream method
agent.run_stream = mock_run_stream # type: ignore[method-assign]
class RecordingAgent:
"""Agent stub that captures messages passed to run_stream."""
def __init__(self) -> None:
self.chat_options = SimpleNamespace(tools=[], response_format=None)
self.tools: list[Any] = []
self.chat_client = SimpleNamespace(
function_invocation_configuration=FunctionInvocationConfiguration(),
)
self.seen_messages: list[Any] | None = None
async def run_stream(
self,
messages: list[Any],
*,
thread: Any,
tools: list[Any] | None = None,
**kwargs: Any,
) -> AsyncGenerator[AgentResponseUpdate, None]:
self.seen_messages = messages
yield AgentResponseUpdate(contents=[TextContent(text="ok")], role="assistant")
return agent
async def test_default_orchestrator_merges_client_tools() -> None:
"""Client tool declarations are merged with server tools before running agent."""
agent = DummyAgent()
captured_tools: list[Any] = []
agent = _create_mock_chat_agent(tools=[server_tool], capture_tools=captured_tools)
orchestrator = DefaultOrchestrator()
input_data = {
@@ -100,8 +104,8 @@ async def test_default_orchestrator_merges_client_tools() -> None:
async for event in orchestrator.run(context):
events.append(event)
assert agent.seen_tools is not None
tool_names = [getattr(tool, "name", "?") for tool in agent.seen_tools]
assert len(captured_tools) > 0
tool_names = [getattr(tool, "name", "?") for tool in captured_tools]
assert "server_tool" in tool_names
assert "get_weather" in tool_names
assert agent.chat_client.function_invocation_configuration.additional_tools
@@ -109,8 +113,7 @@ async def test_default_orchestrator_merges_client_tools() -> None:
async def test_default_orchestrator_with_camel_case_ids() -> None:
"""Client tool is able to extract camelCase IDs."""
agent = DummyAgent()
agent = _create_mock_chat_agent()
orchestrator = DefaultOrchestrator()
input_data = {
@@ -143,8 +146,7 @@ async def test_default_orchestrator_with_camel_case_ids() -> None:
async def test_default_orchestrator_with_snake_case_ids() -> None:
"""Client tool is able to extract snake_case IDs."""
agent = DummyAgent()
agent = _create_mock_chat_agent()
orchestrator = DefaultOrchestrator()
input_data = {
@@ -177,8 +179,8 @@ async def test_default_orchestrator_with_snake_case_ids() -> None:
async def test_state_context_injected_when_tool_call_state_mismatch() -> None:
"""State context should be injected when current state differs from tool call args."""
agent = RecordingAgent()
captured_messages: list[Any] = []
agent = _create_mock_chat_agent(tools=[], capture_messages=captured_messages)
orchestrator = DefaultOrchestrator()
tool_recipe = {"title": "Salad", "special_preferences": []}
@@ -215,9 +217,9 @@ async def test_state_context_injected_when_tool_call_state_mismatch() -> None:
async for _event in orchestrator.run(context):
pass
assert agent.seen_messages is not None
assert len(captured_messages) > 0
state_messages = []
for msg in agent.seen_messages:
for msg in captured_messages:
role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
if role_value != "system":
continue
@@ -230,8 +232,8 @@ async def test_state_context_injected_when_tool_call_state_mismatch() -> None:
async def test_state_context_not_injected_when_tool_call_matches_state() -> None:
"""State context should be skipped when tool call args match current state."""
agent = RecordingAgent()
captured_messages: list[Any] = []
agent = _create_mock_chat_agent(tools=[], capture_messages=captured_messages)
orchestrator = DefaultOrchestrator()
input_data = {
@@ -264,9 +266,9 @@ async def test_state_context_not_injected_when_tool_call_matches_state() -> None
async for _event in orchestrator.run(context):
pass
assert agent.seen_messages is not None
assert len(captured_messages) > 0
state_messages = []
for msg in agent.seen_messages:
for msg in captured_messages:
role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
if role_value != "system":
continue
+103 -8
View File
@@ -1,8 +1,14 @@
# Copyright (c) Microsoft. All rights reserved.
from types import SimpleNamespace
from unittest.mock import MagicMock
from agent_framework_ag_ui._orchestration._tooling import merge_tools, register_additional_client_tools
from agent_framework import ChatAgent, ai_function
from agent_framework_ag_ui._orchestration._tooling import (
collect_server_tools,
merge_tools,
register_additional_client_tools,
)
class DummyTool:
@@ -11,6 +17,30 @@ class DummyTool:
self.declaration_only = True
class MockMCPTool:
"""Mock MCP tool that simulates connected MCP tool with functions."""
def __init__(self, functions: list[DummyTool], is_connected: bool = True) -> None:
self.functions = functions
self.is_connected = is_connected
@ai_function
def regular_tool() -> str:
"""Regular tool for testing."""
return "result"
def _create_chat_agent_with_tool(tool_name: str = "regular_tool") -> ChatAgent:
"""Create a ChatAgent with a mocked chat client and a simple tool.
Note: tool_name parameter is kept for API compatibility but the tool
will always be named 'regular_tool' since ai_function uses the function name.
"""
mock_chat_client = MagicMock()
return ChatAgent(chat_client=mock_chat_client, tools=[regular_tool])
def test_merge_tools_filters_duplicates() -> None:
server = [DummyTool("a"), DummyTool("b")]
client = [DummyTool("b"), DummyTool("c")]
@@ -23,14 +53,79 @@ def test_merge_tools_filters_duplicates() -> None:
def test_register_additional_client_tools_assigns_when_configured() -> None:
class Fic:
def __init__(self) -> None:
self.additional_tools = None
"""register_additional_client_tools should set additional_tools on the chat client."""
from agent_framework import BaseChatClient, FunctionInvocationConfiguration
holder = SimpleNamespace(function_invocation_configuration=Fic())
agent = SimpleNamespace(chat_client=holder)
mock_chat_client = MagicMock(spec=BaseChatClient)
mock_chat_client.function_invocation_configuration = FunctionInvocationConfiguration()
agent = ChatAgent(chat_client=mock_chat_client)
tools = [DummyTool("x")]
register_additional_client_tools(agent, tools)
assert holder.function_invocation_configuration.additional_tools == tools
assert mock_chat_client.function_invocation_configuration.additional_tools == tools
def test_collect_server_tools_includes_mcp_tools_when_connected() -> None:
"""MCP tool functions should be included when the MCP tool is connected."""
mcp_function1 = DummyTool("mcp_function_1")
mcp_function2 = DummyTool("mcp_function_2")
mock_mcp = MockMCPTool([mcp_function1, mcp_function2], is_connected=True)
agent = _create_chat_agent_with_tool("regular_tool")
agent.mcp_tools = [mock_mcp]
tools = collect_server_tools(agent)
names = [getattr(t, "name", None) for t in tools]
assert "regular_tool" in names
assert "mcp_function_1" in names
assert "mcp_function_2" in names
assert len(tools) == 3
def test_collect_server_tools_excludes_mcp_tools_when_not_connected() -> None:
"""MCP tool functions should be excluded when the MCP tool is not connected."""
mcp_function = DummyTool("mcp_function")
mock_mcp = MockMCPTool([mcp_function], is_connected=False)
agent = _create_chat_agent_with_tool("regular_tool")
agent.mcp_tools = [mock_mcp]
tools = collect_server_tools(agent)
names = [getattr(t, "name", None) for t in tools]
assert "regular_tool" in names
assert "mcp_function" not in names
assert len(tools) == 1
def test_collect_server_tools_works_with_no_mcp_tools() -> None:
"""collect_server_tools should work when there are no MCP tools."""
agent = _create_chat_agent_with_tool("regular_tool")
tools = collect_server_tools(agent)
names = [getattr(t, "name", None) for t in tools]
assert "regular_tool" in names
assert len(tools) == 1
def test_collect_server_tools_with_mcp_tools_via_public_property() -> None:
"""collect_server_tools should access MCP tools via the public mcp_tools property."""
mcp_function = DummyTool("mcp_function")
mock_mcp = MockMCPTool([mcp_function], is_connected=True)
agent = _create_chat_agent_with_tool("regular_tool")
agent.mcp_tools = [mock_mcp]
# Verify the public property works
assert agent.mcp_tools == [mock_mcp]
tools = collect_server_tools(agent)
names = [getattr(t, "name", None) for t in tools]
assert "regular_tool" in names
assert "mcp_function" in names
assert len(tools) == 2
@@ -678,7 +678,7 @@ class ChatAgent(BaseAgent, Generic[TOptions_co]): # type: ignore[misc]
normalized_tools: list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] = ( # type:ignore[reportUnknownVariableType]
[] if tools_ is None else tools_ if isinstance(tools_, list) else [tools_] # type: ignore[list-item]
)
self._local_mcp_tools = [tool for tool in normalized_tools if isinstance(tool, MCPTool)]
self.mcp_tools: list[MCPTool] = [tool for tool in normalized_tools if isinstance(tool, MCPTool)]
agent_tools = [tool for tool in normalized_tools if not isinstance(tool, MCPTool)]
# Build chat options dict
@@ -720,7 +720,7 @@ class ChatAgent(BaseAgent, Generic[TOptions_co]): # type: ignore[misc]
Returns:
The ChatAgent instance.
"""
for context_manager in chain([self.chat_client], self._local_mcp_tools):
for context_manager in chain([self.chat_client], self.mcp_tools):
if isinstance(context_manager, AbstractAsyncContextManager):
await self._async_exit_stack.enter_async_context(context_manager)
return self
@@ -817,7 +817,7 @@ class ChatAgent(BaseAgent, Generic[TOptions_co]): # type: ignore[misc]
else:
final_tools.append(tool) # type: ignore
for mcp_server in self._local_mcp_tools:
for mcp_server in self.mcp_tools:
if not mcp_server.is_connected:
await self._async_exit_stack.enter_async_context(mcp_server)
final_tools.extend(mcp_server.functions)
@@ -944,7 +944,7 @@ class ChatAgent(BaseAgent, Generic[TOptions_co]): # type: ignore[misc]
else:
final_tools.append(tool)
for mcp_server in self._local_mcp_tools:
for mcp_server in self.mcp_tools:
if not mcp_server.is_connected:
await self._async_exit_stack.enter_async_context(mcp_server)
final_tools.extend(mcp_server.functions)
@@ -275,13 +275,13 @@ class HandoffAgentExecutor(AgentExecutor):
middleware = list(agent.middleware or [])
# Reconstruct the original tools list by combining regular tools with MCP tools.
# ChatAgent.__init__ separates MCP tools into _local_mcp_tools during initialization,
# ChatAgent.__init__ separates MCP tools during initialization,
# so we need to recombine them here to pass the complete tools list to the constructor.
# This makes sure MCP tools are preserved when cloning agents for handoff workflows.
tools_from_options = options.get("tools")
all_tools = list(tools_from_options) if tools_from_options else []
if agent._local_mcp_tools: # type: ignore
all_tools.extend(agent._local_mcp_tools) # type: ignore
if agent.mcp_tools:
all_tools.extend(agent.mcp_tools)
logit_bias = options.get("logit_bias")
metadata = options.get("metadata")
@@ -114,10 +114,10 @@ class AgentFrameworkExecutor:
Args:
agent: Agent object that may have MCP tools
"""
if not hasattr(agent, "_local_mcp_tools"):
if not hasattr(agent, "mcp_tools"):
return
for mcp_tool in agent._local_mcp_tools:
for mcp_tool in agent.mcp_tools:
if not getattr(mcp_tool, "is_connected", False):
continue
@@ -248,9 +248,9 @@ class DevServer:
except Exception as e:
logger.warning(f"Error closing credential for {entity_info.id}: {e}")
# Close MCP tools (framework tracks them in _local_mcp_tools)
if entity_obj and hasattr(entity_obj, "_local_mcp_tools"):
for mcp_tool in entity_obj._local_mcp_tools:
# Close MCP tools (framework tracks them in mcp_tools)
if entity_obj and hasattr(entity_obj, "mcp_tools"):
for mcp_tool in entity_obj.mcp_tools:
if hasattr(mcp_tool, "close") and callable(mcp_tool.close):
try:
if inspect.iscoroutinefunction(mcp_tool.close):