mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Warn on unsupported AzureAIClient runtime tool/structured_output overrides (#3919)
* Guard AzureAIClient runtime tool and structured output overrides * Simplify AzureAI runtime option pruning logic * small fix * slight update * fix error message in test * fix test var * Move Azure AI runtime override checks Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
fc9c81b0b1
commit
b68d0f93e3
@@ -2,8 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from collections.abc import Callable, Mapping, MutableMapping, Sequence
|
||||
from contextlib import suppress
|
||||
from typing import Any, ClassVar, Generic, Literal, TypedDict, TypeVar, cast
|
||||
|
||||
from agent_framework import (
|
||||
@@ -218,6 +220,10 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[
|
||||
self._is_application_endpoint = "/applications/" in project_client._config.endpoint # type: ignore
|
||||
# Track whether we should close client connection
|
||||
self._should_close_client = should_close_client
|
||||
# Track creation-time agent configuration for runtime mismatch warnings.
|
||||
self.warn_runtime_tools_and_structure_changed = False
|
||||
self._created_agent_tool_names: set[str] = set()
|
||||
self._created_agent_structured_output_signature: str | None = None
|
||||
|
||||
async def configure_azure_monitor(
|
||||
self,
|
||||
@@ -341,18 +347,18 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[
|
||||
"Agent name is required. Provide 'agent_name' when initializing AzureAIClient "
|
||||
"or 'name' when initializing Agent."
|
||||
)
|
||||
# If the agent exists and we do not want to track agent configuration, return early
|
||||
if self.agent_version is not None and not self.warn_runtime_tools_and_structure_changed:
|
||||
return {"name": self.agent_name, "version": self.agent_version, "type": "agent_reference"}
|
||||
|
||||
# If no agent_version is provided, either use latest version or create a new agent:
|
||||
if self.agent_version is None:
|
||||
# Try to use latest version if requested and agent exists
|
||||
if self.use_latest_version:
|
||||
try:
|
||||
with suppress(ResourceNotFoundError):
|
||||
existing_agent = await self.project_client.agents.get(self.agent_name)
|
||||
self.agent_version = existing_agent.versions.latest.version
|
||||
return {"name": self.agent_name, "version": self.agent_version, "type": "agent_reference"}
|
||||
except ResourceNotFoundError:
|
||||
# Agent doesn't exist, fall through to creation logic
|
||||
pass
|
||||
|
||||
if "model" not in run_options or not run_options["model"]:
|
||||
raise ServiceInitializationError(
|
||||
@@ -395,7 +401,9 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[
|
||||
)
|
||||
|
||||
self.agent_version = created_agent.version
|
||||
|
||||
self.warn_runtime_tools_and_structure_changed = True
|
||||
self._created_agent_tool_names = self._extract_tool_names(run_options.get("tools"))
|
||||
self._created_agent_structured_output_signature = self._get_structured_output_signature(chat_options)
|
||||
return {"name": self.agent_name, "version": self.agent_version, "type": "agent_reference"}
|
||||
|
||||
async def _close_client_if_needed(self) -> None:
|
||||
@@ -403,6 +411,91 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[
|
||||
if self._should_close_client:
|
||||
await self.project_client.close()
|
||||
|
||||
def _extract_tool_names(self, tools: Any) -> set[str]:
|
||||
"""Extract comparable tool names from runtime tool payloads."""
|
||||
if not isinstance(tools, Sequence) or isinstance(tools, str | bytes):
|
||||
return set()
|
||||
return {self._get_tool_name(tool) for tool in tools}
|
||||
|
||||
def _get_tool_name(self, tool: Any) -> str:
|
||||
"""Get a stable name for a tool for runtime comparison."""
|
||||
if isinstance(tool, FunctionTool):
|
||||
return tool.name
|
||||
if isinstance(tool, Mapping):
|
||||
tool_type = tool.get("type")
|
||||
if tool_type == "function":
|
||||
if isinstance(function_data := tool.get("function"), Mapping) and function_data.get("name"):
|
||||
return str(function_data["name"])
|
||||
if tool.get("name"):
|
||||
return str(tool["name"])
|
||||
if tool.get("name"):
|
||||
return str(tool["name"])
|
||||
if tool.get("server_label"):
|
||||
return f"mcp:{tool['server_label']}"
|
||||
if tool_type:
|
||||
return str(tool_type)
|
||||
if getattr(tool, "name", None):
|
||||
return str(tool.name)
|
||||
if getattr(tool, "server_label", None):
|
||||
return f"mcp:{tool.server_label}"
|
||||
if getattr(tool, "type", None):
|
||||
return str(tool.type)
|
||||
return type(tool).__name__
|
||||
|
||||
def _get_structured_output_signature(self, chat_options: Mapping[str, Any] | None) -> str | None:
|
||||
"""Build a stable signature for structured_output/response_format values."""
|
||||
if not chat_options:
|
||||
return None
|
||||
response_format = chat_options.get("response_format")
|
||||
if response_format is None:
|
||||
return None
|
||||
if isinstance(response_format, type):
|
||||
return f"{response_format.__module__}.{response_format.__qualname__}"
|
||||
if isinstance(response_format, Mapping):
|
||||
return json.dumps(response_format, sort_keys=True, default=str)
|
||||
return str(response_format)
|
||||
|
||||
def _remove_agent_level_run_options(
|
||||
self,
|
||||
run_options: dict[str, Any],
|
||||
chat_options: Mapping[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Remove request-level options that Azure AI only supports at agent creation time."""
|
||||
runtime_tools = run_options.get("tools")
|
||||
runtime_structured_output = self._get_structured_output_signature(chat_options)
|
||||
|
||||
if runtime_tools is not None or runtime_structured_output is not None:
|
||||
tools_changed = runtime_tools is not None
|
||||
structured_output_changed = runtime_structured_output is not None
|
||||
|
||||
if self.warn_runtime_tools_and_structure_changed:
|
||||
if runtime_tools is not None:
|
||||
tools_changed = self._extract_tool_names(runtime_tools) != self._created_agent_tool_names
|
||||
if runtime_structured_output is not None:
|
||||
structured_output_changed = (
|
||||
runtime_structured_output != self._created_agent_structured_output_signature
|
||||
)
|
||||
|
||||
if tools_changed or structured_output_changed:
|
||||
logger.warning(
|
||||
"AzureAIClient does not support runtime tools or structured_output overrides after agent creation. "
|
||||
"Use AzureOpenAIResponsesClient instead."
|
||||
)
|
||||
|
||||
agent_level_option_to_run_keys = {
|
||||
"model_id": ("model",),
|
||||
"tools": ("tools",),
|
||||
"response_format": ("response_format", "text", "text_format"),
|
||||
"rai_config": ("rai_config",),
|
||||
"temperature": ("temperature",),
|
||||
"top_p": ("top_p",),
|
||||
"reasoning": ("reasoning",),
|
||||
}
|
||||
|
||||
for run_keys in agent_level_option_to_run_keys.values():
|
||||
for run_key in run_keys:
|
||||
run_options.pop(run_key, None)
|
||||
|
||||
@override
|
||||
async def _prepare_options(
|
||||
self,
|
||||
@@ -427,22 +520,8 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[
|
||||
agent_reference = await self._get_agent_reference_or_create(run_options, instructions, options)
|
||||
run_options["extra_body"] = {"agent": agent_reference}
|
||||
|
||||
# Remove properties that are not supported on request level
|
||||
# but were configured on agent level
|
||||
exclude = [
|
||||
"model",
|
||||
"tools",
|
||||
"response_format",
|
||||
"rai_config",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"text",
|
||||
"text_format",
|
||||
"reasoning",
|
||||
]
|
||||
|
||||
for property in exclude:
|
||||
run_options.pop(property, None)
|
||||
# Remove only keys that map to this client's declared options TypedDict.
|
||||
self._remove_agent_level_run_options(run_options, options)
|
||||
|
||||
return run_options
|
||||
|
||||
|
||||
@@ -130,6 +130,9 @@ def create_test_azure_ai_client(
|
||||
client.conversation_id = conversation_id
|
||||
client._is_application_endpoint = False # type: ignore
|
||||
client._should_close_client = should_close_client # type: ignore
|
||||
client.warn_runtime_tools_and_structure_changed = False # type: ignore
|
||||
client._created_agent_tool_names = set() # type: ignore
|
||||
client._created_agent_structured_output_signature = None # type: ignore
|
||||
client.additional_properties = {}
|
||||
client.middleware = None
|
||||
|
||||
@@ -773,6 +776,82 @@ async def test_agent_creation_with_tools(
|
||||
assert call_args[1]["definition"].tools == test_tools
|
||||
|
||||
|
||||
async def test_runtime_tools_override_logs_warning(
|
||||
mock_project_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test warning is logged when runtime tools differ from creation-time tools."""
|
||||
client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent")
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.name = "test-agent"
|
||||
mock_agent.version = "1.0"
|
||||
mock_project_client.agents.create_version = AsyncMock(return_value=mock_agent)
|
||||
messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
|
||||
|
||||
with patch(
|
||||
"agent_framework.openai._responses_client.RawOpenAIResponsesClient._prepare_options",
|
||||
return_value={"model": "test-model", "tools": [{"type": "function", "name": "tool_one"}]},
|
||||
):
|
||||
await client._prepare_options(messages, {})
|
||||
|
||||
with (
|
||||
patch(
|
||||
"agent_framework.openai._responses_client.RawOpenAIResponsesClient._prepare_options",
|
||||
return_value={"model": "test-model", "tools": [{"type": "function", "name": "tool_two"}]},
|
||||
),
|
||||
patch("agent_framework_azure_ai._client.logger.warning") as mock_warning,
|
||||
):
|
||||
await client._prepare_options(messages, {})
|
||||
mock_warning.assert_called_once()
|
||||
assert "Use AzureOpenAIResponsesClient instead." in mock_warning.call_args[0][0]
|
||||
|
||||
|
||||
async def test_prepare_options_logs_warning_for_tools_with_existing_agent_version(
|
||||
mock_project_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test warning is logged when tools are supplied against an existing agent version."""
|
||||
client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent", agent_version="1.0")
|
||||
messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"agent_framework.openai._responses_client.RawOpenAIResponsesClient._prepare_options",
|
||||
return_value={"model": "test-model", "tools": [{"type": "function", "name": "tool_one"}]},
|
||||
),
|
||||
patch("agent_framework_azure_ai._client.logger.warning") as mock_warning,
|
||||
):
|
||||
run_options = await client._prepare_options(messages, {})
|
||||
|
||||
mock_warning.assert_called_once()
|
||||
assert "Use AzureOpenAIResponsesClient instead." in mock_warning.call_args[0][0]
|
||||
assert "tools" not in run_options
|
||||
|
||||
|
||||
async def test_prepare_options_logs_warning_for_tools_on_application_endpoint(
|
||||
mock_project_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test warning is logged when runtime tools are removed for application endpoints."""
|
||||
client = create_test_azure_ai_client(mock_project_client)
|
||||
client._is_application_endpoint = True # type: ignore
|
||||
messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"agent_framework.openai._responses_client.RawOpenAIResponsesClient._prepare_options",
|
||||
return_value={"model": "test-model", "tools": [{"type": "function", "name": "tool_one"}]},
|
||||
),
|
||||
patch.object(client, "_get_agent_reference_or_create", new_callable=AsyncMock) as mock_get_agent_reference,
|
||||
patch("agent_framework_azure_ai._client.logger.warning") as mock_warning,
|
||||
):
|
||||
run_options = await client._prepare_options(messages, {})
|
||||
|
||||
mock_get_agent_reference.assert_not_called()
|
||||
mock_warning.assert_called_once()
|
||||
assert "Use AzureOpenAIResponsesClient instead." in mock_warning.call_args[0][0]
|
||||
assert "tools" not in run_options
|
||||
assert "extra_body" not in run_options
|
||||
|
||||
|
||||
async def test_use_latest_version_existing_agent(
|
||||
mock_project_client: MagicMock,
|
||||
) -> None:
|
||||
@@ -872,6 +951,13 @@ class ResponseFormatModel(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class AlternateResponseFormatModel(BaseModel):
|
||||
"""Alternate model for structured output warning checks."""
|
||||
|
||||
summary: str
|
||||
confidence: float
|
||||
|
||||
|
||||
async def test_agent_creation_with_response_format(
|
||||
mock_project_client: MagicMock,
|
||||
) -> None:
|
||||
@@ -964,6 +1050,36 @@ async def test_agent_creation_with_mapping_response_format(
|
||||
assert format_config.strict is True
|
||||
|
||||
|
||||
async def test_runtime_structured_output_override_logs_warning(
|
||||
mock_project_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test warning is logged when runtime structured_output differs from creation-time configuration."""
|
||||
client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent")
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.name = "test-agent"
|
||||
mock_agent.version = "1.0"
|
||||
mock_project_client.agents.create_version = AsyncMock(return_value=mock_agent)
|
||||
messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
|
||||
|
||||
with patch(
|
||||
"agent_framework.openai._responses_client.RawOpenAIResponsesClient._prepare_options",
|
||||
return_value={"model": "test-model"},
|
||||
):
|
||||
await client._prepare_options(messages, {"response_format": ResponseFormatModel})
|
||||
|
||||
with (
|
||||
patch(
|
||||
"agent_framework.openai._responses_client.RawOpenAIResponsesClient._prepare_options",
|
||||
return_value={"model": "test-model"},
|
||||
),
|
||||
patch("agent_framework_azure_ai._client.logger.warning") as mock_warning,
|
||||
):
|
||||
await client._prepare_options(messages, {"response_format": AlternateResponseFormatModel})
|
||||
mock_warning.assert_called_once()
|
||||
assert "Use AzureOpenAIResponsesClient instead." in mock_warning.call_args[0][0]
|
||||
|
||||
|
||||
async def test_prepare_options_excludes_response_format(
|
||||
mock_project_client: MagicMock,
|
||||
) -> None:
|
||||
@@ -1001,6 +1117,39 @@ async def test_prepare_options_excludes_response_format(
|
||||
assert run_options["extra_body"]["agent"]["name"] == "test-agent"
|
||||
|
||||
|
||||
async def test_prepare_options_keeps_values_for_unsupported_option_keys(
|
||||
mock_project_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test that run_options removal only applies to known AzureAI agent-level option mappings."""
|
||||
client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent", agent_version="1.0")
|
||||
messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"agent_framework.openai._responses_client.RawOpenAIResponsesClient._prepare_options",
|
||||
return_value={
|
||||
"model": "test-model",
|
||||
"tools": [{"type": "function", "name": "weather"}],
|
||||
"text": {"format": {"type": "json_schema", "name": "schema"}},
|
||||
"text_format": ResponseFormatModel,
|
||||
"custom_option": "keep-me",
|
||||
},
|
||||
),
|
||||
patch.object(
|
||||
client,
|
||||
"_get_agent_reference_or_create",
|
||||
return_value={"name": "test-agent", "version": "1.0", "type": "agent_reference"},
|
||||
),
|
||||
):
|
||||
run_options = await client._prepare_options(messages, {})
|
||||
|
||||
assert "model" not in run_options
|
||||
assert "tools" not in run_options
|
||||
assert "text" not in run_options
|
||||
assert "text_format" not in run_options
|
||||
assert run_options["custom_option"] == "keep-me"
|
||||
|
||||
|
||||
def test_get_conversation_id_with_store_true_and_conversation_id() -> None:
|
||||
"""Test _get_conversation_id returns conversation ID when store is True and conversation exists."""
|
||||
client = create_test_azure_ai_client(MagicMock())
|
||||
|
||||
Reference in New Issue
Block a user