Python: Agent Name Sanitization (#1523)

* a2a agent name sanitization

* fix

* small fix

---------

Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com>
This commit is contained in:
Giles Odigwe
2025-10-16 13:33:25 -07:00
committed by GitHub
Unverified
parent 85921eda68
commit 76ae0a62ac
3 changed files with 112 additions and 3 deletions
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
import inspect
import re
import sys
from collections.abc import AsyncIterable, Awaitable, Callable, MutableMapping, Sequence
from contextlib import AbstractAsyncContextManager, AsyncExitStack
@@ -48,6 +49,44 @@ logger = get_logger("agent_framework")
TThreadType = TypeVar("TThreadType", bound="AgentThread")
def _sanitize_agent_name(agent_name: str | None) -> str | None:
"""Sanitize agent name for use as a function name.
Replaces spaces and special characters with underscores to create
a valid Python identifier.
Args:
agent_name: The agent name to sanitize.
Returns:
The sanitized agent name with invalid characters replaced by underscores.
If the input is None, returns None.
If sanitization results in an empty string (e.g., agent_name="@@@"), returns "agent" as a default.
"""
if agent_name is None:
return None
# Replace any character that is not alphanumeric or underscore with underscore
sanitized = re.sub(r"[^a-zA-Z0-9_]", "_", agent_name)
# Replace multiple consecutive underscores with a single underscore
sanitized = re.sub(r"_+", "_", sanitized)
# Remove leading/trailing underscores
sanitized = sanitized.strip("_")
# Handle empty string case
if not sanitized:
return "agent"
# Prefix with underscore if the sanitized name starts with a digit
if sanitized and sanitized[0].isdigit():
sanitized = f"_{sanitized}"
return sanitized
__all__ = ["AgentProtocol", "BaseAgent", "ChatAgent"]
@@ -396,7 +435,7 @@ class BaseAgent(SerializationMixin):
if not isinstance(self, AgentProtocol):
raise TypeError(f"Agent {self.__class__.__name__} must implement AgentProtocol to be used as a tool")
tool_name = name or self.name
tool_name = name or _sanitize_agent_name(self.name)
if tool_name is None:
raise ValueError("Agent tool name cannot be None. Either provide a name parameter or set the agent's name.")
tool_description = description or self.description or ""
@@ -404,7 +443,8 @@ class BaseAgent(SerializationMixin):
# Create dynamic input model with the specified argument name
field_info = Field(..., description=argument_description)
input_model = create_model(f"{name or self.name or 'agent'}_task", **{arg_name: (str, field_info)}) # type: ignore[call-overload]
model_name = f"{name or _sanitize_agent_name(self.name) or 'agent'}_task"
input_model = create_model(model_name, **{arg_name: (str, field_info)}) # type: ignore[call-overload]
# Check if callback is async once, outside the wrapper
is_async_callback = stream_callback is not None and inspect.iscoroutinefunction(stream_callback)
@@ -1,7 +1,9 @@
# Copyright (c) Microsoft. All rights reserved.
import contextlib
from collections.abc import AsyncIterable, MutableSequence, Sequence
from typing import Any
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
from pytest import raises
@@ -23,6 +25,7 @@ from agent_framework import (
Role,
TextContent,
)
from agent_framework._mcp import MCPTool
from agent_framework.exceptions import AgentExecutionException
@@ -506,3 +509,69 @@ async def test_chat_agent_as_tool_with_async_stream_callback(chat_client: ChatCl
# Result should be concatenation of all streaming updates
expected_text = "".join(update.text for update in collected_updates)
assert result == expected_text
async def test_chat_agent_as_tool_name_sanitization(chat_client: ChatClientProtocol) -> None:
"""Test as_tool name sanitization."""
test_cases = [
("Invoice & Billing Agent", "Invoice_Billing_Agent"),
("Travel & Logistics Agent", "Travel_Logistics_Agent"),
("Agent@Company.com", "Agent_Company_com"),
("Agent___Multiple___Underscores", "Agent_Multiple_Underscores"),
("123Agent", "_123Agent"), # Test digit prefix handling
("9to5Helper", "_9to5Helper"), # Another digit prefix case
("@@@", "agent"), # Test empty sanitization fallback
]
for agent_name, expected_tool_name in test_cases:
agent = ChatAgent(chat_client=chat_client, name=agent_name, description="Test agent")
tool = agent.as_tool()
assert tool.name == expected_tool_name, f"Expected {expected_tool_name}, got {tool.name} for input {agent_name}"
async def test_chat_agent_as_mcp_server_basic(chat_client: ChatClientProtocol) -> None:
"""Test basic as_mcp_server functionality."""
agent = ChatAgent(chat_client=chat_client, name="TestAgent", description="Test agent for MCP")
# Create MCP server with default parameters
server = agent.as_mcp_server()
# Verify server is created
assert server is not None
assert hasattr(server, "name")
assert hasattr(server, "version")
async def test_chat_agent_run_with_mcp_tools(chat_client: ChatClientProtocol) -> None:
"""Test run method with MCP tools to cover MCP tool handling code."""
agent = ChatAgent(chat_client=chat_client, name="TestAgent", description="Test agent")
# Create a mock MCP tool
mock_mcp_tool = MagicMock(spec=MCPTool)
mock_mcp_tool.is_connected = False
mock_mcp_tool.functions = [MagicMock()]
# Mock the async context manager entry
mock_mcp_tool.__aenter__ = AsyncMock(return_value=mock_mcp_tool)
mock_mcp_tool.__aexit__ = AsyncMock(return_value=None)
# Test run with MCP tools - this should hit the MCP tool handling code
with contextlib.suppress(Exception):
# We expect this to fail since we're using mocks, but we want to exercise the code path
await agent.run(messages="Test message", tools=[mock_mcp_tool])
async def test_chat_agent_with_local_mcp_tools(chat_client: ChatClientProtocol) -> None:
"""Test agent initialization with local MCP tools."""
# Create a mock MCP tool
mock_mcp_tool = MagicMock(spec=MCPTool)
mock_mcp_tool.is_connected = False
mock_mcp_tool.__aenter__ = AsyncMock(return_value=mock_mcp_tool)
mock_mcp_tool.__aexit__ = AsyncMock(return_value=None)
# Test agent with MCP tools in constructor
with contextlib.suppress(Exception):
agent = ChatAgent(chat_client=chat_client, name="TestAgent", description="Test agent", tools=[mock_mcp_tool])
# Test async context manager with MCP tools
async with agent:
pass
@@ -439,7 +439,7 @@ async def test_get_streaming_response_with_all_parameters() -> None:
instructions="Stream response test",
max_tokens=50,
parallel_tool_calls=False,
model="gpt-4",
model_id="gpt-4",
previous_response_id="stream-prev-123",
reasoning={"mode": "stream"},
service_tier="default",