mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Agent Name Sanitization (#1523)
* a2a agent name sanitization * fix * small fix --------- Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
85921eda68
commit
76ae0a62ac
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import inspect
|
||||
import re
|
||||
import sys
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, MutableMapping, Sequence
|
||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack
|
||||
@@ -48,6 +49,44 @@ logger = get_logger("agent_framework")
|
||||
|
||||
TThreadType = TypeVar("TThreadType", bound="AgentThread")
|
||||
|
||||
|
||||
def _sanitize_agent_name(agent_name: str | None) -> str | None:
|
||||
"""Sanitize agent name for use as a function name.
|
||||
|
||||
Replaces spaces and special characters with underscores to create
|
||||
a valid Python identifier.
|
||||
|
||||
Args:
|
||||
agent_name: The agent name to sanitize.
|
||||
|
||||
Returns:
|
||||
The sanitized agent name with invalid characters replaced by underscores.
|
||||
If the input is None, returns None.
|
||||
If sanitization results in an empty string (e.g., agent_name="@@@"), returns "agent" as a default.
|
||||
"""
|
||||
if agent_name is None:
|
||||
return None
|
||||
|
||||
# Replace any character that is not alphanumeric or underscore with underscore
|
||||
sanitized = re.sub(r"[^a-zA-Z0-9_]", "_", agent_name)
|
||||
|
||||
# Replace multiple consecutive underscores with a single underscore
|
||||
sanitized = re.sub(r"_+", "_", sanitized)
|
||||
|
||||
# Remove leading/trailing underscores
|
||||
sanitized = sanitized.strip("_")
|
||||
|
||||
# Handle empty string case
|
||||
if not sanitized:
|
||||
return "agent"
|
||||
|
||||
# Prefix with underscore if the sanitized name starts with a digit
|
||||
if sanitized and sanitized[0].isdigit():
|
||||
sanitized = f"_{sanitized}"
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
__all__ = ["AgentProtocol", "BaseAgent", "ChatAgent"]
|
||||
|
||||
|
||||
@@ -396,7 +435,7 @@ class BaseAgent(SerializationMixin):
|
||||
if not isinstance(self, AgentProtocol):
|
||||
raise TypeError(f"Agent {self.__class__.__name__} must implement AgentProtocol to be used as a tool")
|
||||
|
||||
tool_name = name or self.name
|
||||
tool_name = name or _sanitize_agent_name(self.name)
|
||||
if tool_name is None:
|
||||
raise ValueError("Agent tool name cannot be None. Either provide a name parameter or set the agent's name.")
|
||||
tool_description = description or self.description or ""
|
||||
@@ -404,7 +443,8 @@ class BaseAgent(SerializationMixin):
|
||||
|
||||
# Create dynamic input model with the specified argument name
|
||||
field_info = Field(..., description=argument_description)
|
||||
input_model = create_model(f"{name or self.name or 'agent'}_task", **{arg_name: (str, field_info)}) # type: ignore[call-overload]
|
||||
model_name = f"{name or _sanitize_agent_name(self.name) or 'agent'}_task"
|
||||
input_model = create_model(model_name, **{arg_name: (str, field_info)}) # type: ignore[call-overload]
|
||||
|
||||
# Check if callback is async once, outside the wrapper
|
||||
is_async_callback = stream_callback is not None and inspect.iscoroutinefunction(stream_callback)
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import contextlib
|
||||
from collections.abc import AsyncIterable, MutableSequence, Sequence
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from pytest import raises
|
||||
@@ -23,6 +25,7 @@ from agent_framework import (
|
||||
Role,
|
||||
TextContent,
|
||||
)
|
||||
from agent_framework._mcp import MCPTool
|
||||
from agent_framework.exceptions import AgentExecutionException
|
||||
|
||||
|
||||
@@ -506,3 +509,69 @@ async def test_chat_agent_as_tool_with_async_stream_callback(chat_client: ChatCl
|
||||
# Result should be concatenation of all streaming updates
|
||||
expected_text = "".join(update.text for update in collected_updates)
|
||||
assert result == expected_text
|
||||
|
||||
|
||||
async def test_chat_agent_as_tool_name_sanitization(chat_client: ChatClientProtocol) -> None:
|
||||
"""Test as_tool name sanitization."""
|
||||
test_cases = [
|
||||
("Invoice & Billing Agent", "Invoice_Billing_Agent"),
|
||||
("Travel & Logistics Agent", "Travel_Logistics_Agent"),
|
||||
("Agent@Company.com", "Agent_Company_com"),
|
||||
("Agent___Multiple___Underscores", "Agent_Multiple_Underscores"),
|
||||
("123Agent", "_123Agent"), # Test digit prefix handling
|
||||
("9to5Helper", "_9to5Helper"), # Another digit prefix case
|
||||
("@@@", "agent"), # Test empty sanitization fallback
|
||||
]
|
||||
|
||||
for agent_name, expected_tool_name in test_cases:
|
||||
agent = ChatAgent(chat_client=chat_client, name=agent_name, description="Test agent")
|
||||
tool = agent.as_tool()
|
||||
assert tool.name == expected_tool_name, f"Expected {expected_tool_name}, got {tool.name} for input {agent_name}"
|
||||
|
||||
|
||||
async def test_chat_agent_as_mcp_server_basic(chat_client: ChatClientProtocol) -> None:
|
||||
"""Test basic as_mcp_server functionality."""
|
||||
agent = ChatAgent(chat_client=chat_client, name="TestAgent", description="Test agent for MCP")
|
||||
|
||||
# Create MCP server with default parameters
|
||||
server = agent.as_mcp_server()
|
||||
|
||||
# Verify server is created
|
||||
assert server is not None
|
||||
assert hasattr(server, "name")
|
||||
assert hasattr(server, "version")
|
||||
|
||||
|
||||
async def test_chat_agent_run_with_mcp_tools(chat_client: ChatClientProtocol) -> None:
|
||||
"""Test run method with MCP tools to cover MCP tool handling code."""
|
||||
agent = ChatAgent(chat_client=chat_client, name="TestAgent", description="Test agent")
|
||||
|
||||
# Create a mock MCP tool
|
||||
mock_mcp_tool = MagicMock(spec=MCPTool)
|
||||
mock_mcp_tool.is_connected = False
|
||||
mock_mcp_tool.functions = [MagicMock()]
|
||||
|
||||
# Mock the async context manager entry
|
||||
mock_mcp_tool.__aenter__ = AsyncMock(return_value=mock_mcp_tool)
|
||||
mock_mcp_tool.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Test run with MCP tools - this should hit the MCP tool handling code
|
||||
with contextlib.suppress(Exception):
|
||||
# We expect this to fail since we're using mocks, but we want to exercise the code path
|
||||
await agent.run(messages="Test message", tools=[mock_mcp_tool])
|
||||
|
||||
|
||||
async def test_chat_agent_with_local_mcp_tools(chat_client: ChatClientProtocol) -> None:
|
||||
"""Test agent initialization with local MCP tools."""
|
||||
# Create a mock MCP tool
|
||||
mock_mcp_tool = MagicMock(spec=MCPTool)
|
||||
mock_mcp_tool.is_connected = False
|
||||
mock_mcp_tool.__aenter__ = AsyncMock(return_value=mock_mcp_tool)
|
||||
mock_mcp_tool.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Test agent with MCP tools in constructor
|
||||
with contextlib.suppress(Exception):
|
||||
agent = ChatAgent(chat_client=chat_client, name="TestAgent", description="Test agent", tools=[mock_mcp_tool])
|
||||
# Test async context manager with MCP tools
|
||||
async with agent:
|
||||
pass
|
||||
|
||||
@@ -439,7 +439,7 @@ async def test_get_streaming_response_with_all_parameters() -> None:
|
||||
instructions="Stream response test",
|
||||
max_tokens=50,
|
||||
parallel_tool_calls=False,
|
||||
model="gpt-4",
|
||||
model_id="gpt-4",
|
||||
previous_response_id="stream-prev-123",
|
||||
reasoning={"mode": "stream"},
|
||||
service_tier="default",
|
||||
|
||||
Reference in New Issue
Block a user