mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: fix middleware and cleanup confusing function (#1865)
* fix middleware and cleanup redundant function * added test to validate
This commit is contained in:
committed by
GitHub
Unverified
parent
0b843d2b3e
commit
e462d209fd
@@ -559,19 +559,6 @@ async def test_azure_ai_chat_client_create_run_options_with_messages(mock_ai_pro
|
||||
assert len(run_options["additional_messages"]) == 1 # Only user message
|
||||
|
||||
|
||||
async def test_azure_ai_chat_client_instructions_sent_once(mock_ai_project_client: MagicMock) -> None:
|
||||
"""Ensure instructions are only sent once for AzureAIAgentClient."""
|
||||
chat_client = create_test_azure_ai_chat_client(mock_ai_project_client)
|
||||
|
||||
instructions = "You are a helpful assistant."
|
||||
chat_options = ChatOptions(instructions=instructions)
|
||||
messages = chat_client.prepare_messages([ChatMessage(role=Role.USER, text="Hello")], chat_options)
|
||||
|
||||
run_options, _ = await chat_client._create_run_options(messages, chat_options) # type: ignore
|
||||
|
||||
assert run_options.get("instructions") == instructions
|
||||
|
||||
|
||||
async def test_azure_ai_chat_client_inner_get_response(mock_ai_project_client: MagicMock) -> None:
|
||||
"""Test _inner_get_response method."""
|
||||
chat_client = create_test_azure_ai_chat_client(mock_ai_project_client, agent_id="test-agent")
|
||||
|
||||
@@ -20,13 +20,7 @@ from ._middleware import (
|
||||
from ._serialization import SerializationMixin
|
||||
from ._threads import ChatMessageStoreProtocol
|
||||
from ._tools import ToolProtocol
|
||||
from ._types import (
|
||||
ChatMessage,
|
||||
ChatOptions,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
ToolMode,
|
||||
)
|
||||
from ._types import ChatMessage, ChatOptions, ChatResponse, ChatResponseUpdate, ToolMode, prepare_messages
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._agents import ChatAgent
|
||||
@@ -216,28 +210,7 @@ class ChatClientProtocol(Protocol):
|
||||
# region ChatClientBase
|
||||
|
||||
|
||||
def prepare_messages(messages: str | ChatMessage | list[str] | list[ChatMessage]) -> list[ChatMessage]:
|
||||
"""Convert various message input formats into a list of ChatMessage objects.
|
||||
|
||||
Args:
|
||||
messages: The input messages in various supported formats.
|
||||
|
||||
Returns:
|
||||
A list of ChatMessage objects.
|
||||
"""
|
||||
if isinstance(messages, str):
|
||||
return [ChatMessage(role="user", text=messages)]
|
||||
if isinstance(messages, ChatMessage):
|
||||
return [messages]
|
||||
return_messages: list[ChatMessage] = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, str):
|
||||
msg = ChatMessage(role="user", text=msg)
|
||||
return_messages.append(msg)
|
||||
return return_messages
|
||||
|
||||
|
||||
def merge_chat_options(
|
||||
def _merge_chat_options(
|
||||
*,
|
||||
base_chat_options: ChatOptions | Any | None,
|
||||
model_id: str | None = None,
|
||||
@@ -405,25 +378,6 @@ class BaseChatClient(SerializationMixin, ABC):
|
||||
|
||||
return result
|
||||
|
||||
def prepare_messages(
|
||||
self, messages: str | ChatMessage | list[str] | list[ChatMessage], chat_options: ChatOptions
|
||||
) -> MutableSequence[ChatMessage]:
|
||||
"""Convert various message input formats into a list of ChatMessage objects.
|
||||
|
||||
Prepends system instructions if present in chat_options.
|
||||
|
||||
Args:
|
||||
messages: The input messages in various supported formats.
|
||||
chat_options: The chat options containing instructions and other settings.
|
||||
|
||||
Returns:
|
||||
A mutable sequence of ChatMessage objects.
|
||||
"""
|
||||
if chat_options.instructions:
|
||||
system_msg = ChatMessage(role="system", text=chat_options.instructions)
|
||||
return [system_msg, *prepare_messages(messages)]
|
||||
return prepare_messages(messages)
|
||||
|
||||
def _filter_internal_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Filter out internal framework parameters that shouldn't be passed to chat client implementations.
|
||||
|
||||
@@ -584,7 +538,7 @@ class BaseChatClient(SerializationMixin, ABC):
|
||||
"""
|
||||
# Normalize tools and merge with base chat_options
|
||||
normalized_tools = await self._normalize_tools(tools)
|
||||
chat_options = merge_chat_options(
|
||||
chat_options = _merge_chat_options(
|
||||
base_chat_options=kwargs.pop("chat_options", None),
|
||||
model_id=model_id,
|
||||
frequency_penalty=frequency_penalty,
|
||||
@@ -612,7 +566,11 @@ class BaseChatClient(SerializationMixin, ABC):
|
||||
)
|
||||
chat_options.store = True
|
||||
|
||||
prepped_messages = self.prepare_messages(messages, chat_options)
|
||||
if chat_options.instructions:
|
||||
system_msg = ChatMessage(role="system", text=chat_options.instructions)
|
||||
prepped_messages = [system_msg, *prepare_messages(messages)]
|
||||
else:
|
||||
prepped_messages = prepare_messages(messages)
|
||||
self._prepare_tool_choice(chat_options=chat_options)
|
||||
|
||||
filtered_kwargs = self._filter_internal_kwargs(kwargs)
|
||||
@@ -679,7 +637,7 @@ class BaseChatClient(SerializationMixin, ABC):
|
||||
"""
|
||||
# Normalize tools and merge with base chat_options
|
||||
normalized_tools = await self._normalize_tools(tools)
|
||||
chat_options = merge_chat_options(
|
||||
chat_options = _merge_chat_options(
|
||||
base_chat_options=kwargs.pop("chat_options", None),
|
||||
model_id=model_id,
|
||||
frequency_penalty=frequency_penalty,
|
||||
@@ -707,7 +665,11 @@ class BaseChatClient(SerializationMixin, ABC):
|
||||
)
|
||||
chat_options.store = True
|
||||
|
||||
prepped_messages = self.prepare_messages(messages, chat_options)
|
||||
if chat_options.instructions:
|
||||
system_msg = ChatMessage(role="system", text=chat_options.instructions)
|
||||
prepped_messages = [system_msg, *prepare_messages(messages)]
|
||||
else:
|
||||
prepped_messages = prepare_messages(messages)
|
||||
self._prepare_tool_choice(chat_options=chat_options)
|
||||
|
||||
filtered_kwargs = self._filter_internal_kwargs(kwargs)
|
||||
|
||||
@@ -8,7 +8,7 @@ from functools import update_wrapper
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeAlias, TypeVar
|
||||
|
||||
from ._serialization import SerializationMixin
|
||||
from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage
|
||||
from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage, prepare_messages
|
||||
from .exceptions import MiddlewareException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -1375,7 +1375,7 @@ def use_chat_middleware(chat_client_class: type[TChatClient]) -> type[TChatClien
|
||||
pipeline = ChatMiddlewarePipeline(chat_middleware_list) # type: ignore[arg-type]
|
||||
context = ChatContext(
|
||||
chat_client=self,
|
||||
messages=self.prepare_messages(messages, chat_options),
|
||||
messages=prepare_messages(messages),
|
||||
chat_options=chat_options,
|
||||
is_streaming=False,
|
||||
kwargs=kwargs,
|
||||
@@ -1425,7 +1425,7 @@ def use_chat_middleware(chat_client_class: type[TChatClient]) -> type[TChatClien
|
||||
pipeline = ChatMiddlewarePipeline(all_middleware) # type: ignore[arg-type]
|
||||
context = ChatContext(
|
||||
chat_client=self,
|
||||
messages=self.prepare_messages(messages, chat_options),
|
||||
messages=prepare_messages(messages),
|
||||
chat_options=chat_options,
|
||||
is_streaming=True,
|
||||
kwargs=kwargs,
|
||||
|
||||
@@ -1318,13 +1318,13 @@ def _handle_function_calls_response(
|
||||
messages: "str | ChatMessage | list[str] | list[ChatMessage]",
|
||||
**kwargs: Any,
|
||||
) -> "ChatResponse":
|
||||
from ._clients import prepare_messages
|
||||
from ._middleware import extract_and_merge_function_middleware
|
||||
from ._types import (
|
||||
ChatMessage,
|
||||
FunctionApprovalRequestContent,
|
||||
FunctionCallContent,
|
||||
FunctionResultContent,
|
||||
prepare_messages,
|
||||
)
|
||||
|
||||
# Extract and merge function middleware from chat client with kwargs pipeline
|
||||
@@ -1465,7 +1465,6 @@ def _handle_function_calls_streaming_response(
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable["ChatResponseUpdate"]:
|
||||
"""Wrap the inner get streaming response method to handle tool calls."""
|
||||
from ._clients import prepare_messages
|
||||
from ._middleware import extract_and_merge_function_middleware
|
||||
from ._types import (
|
||||
ChatMessage,
|
||||
@@ -1473,6 +1472,7 @@ def _handle_function_calls_streaming_response(
|
||||
ChatResponseUpdate,
|
||||
FunctionCallContent,
|
||||
FunctionResultContent,
|
||||
prepare_messages,
|
||||
)
|
||||
|
||||
# Extract and merge function middleware from chat client with kwargs pipeline
|
||||
|
||||
@@ -2052,6 +2052,27 @@ class ChatMessage(SerializationMixin):
|
||||
return " ".join(content.text for content in self.contents if isinstance(content, TextContent))
|
||||
|
||||
|
||||
def prepare_messages(messages: str | ChatMessage | list[str] | list[ChatMessage]) -> list[ChatMessage]:
|
||||
"""Convert various message input formats into a list of ChatMessage objects.
|
||||
|
||||
Args:
|
||||
messages: The input messages in various supported formats.
|
||||
|
||||
Returns:
|
||||
A list of ChatMessage objects.
|
||||
"""
|
||||
if isinstance(messages, str):
|
||||
return [ChatMessage(role="user", text=messages)]
|
||||
if isinstance(messages, ChatMessage):
|
||||
return [messages]
|
||||
return_messages: list[ChatMessage] = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, str):
|
||||
msg = ChatMessage(role="user", text=msg)
|
||||
return_messages.append(msg)
|
||||
return return_messages
|
||||
|
||||
|
||||
# region ChatResponse
|
||||
|
||||
|
||||
|
||||
@@ -1413,7 +1413,7 @@ def _capture_messages(
|
||||
finish_reason: "FinishReason | None" = None,
|
||||
) -> None:
|
||||
"""Log messages with extra information."""
|
||||
from ._clients import prepare_messages
|
||||
from ._types import prepare_messages
|
||||
|
||||
prepped = prepare_messages(messages)
|
||||
otel_messages: list[dict[str, Any]] = []
|
||||
|
||||
@@ -15,7 +15,6 @@ from agent_framework import (
|
||||
ChatAgent,
|
||||
ChatClientProtocol,
|
||||
ChatMessage,
|
||||
ChatOptions,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
HostedCodeInterpreterTool,
|
||||
@@ -155,18 +154,6 @@ def test_azure_assistants_client_init_with_default_headers(azure_openai_unit_tes
|
||||
assert chat_client.client.default_headers[key] == value
|
||||
|
||||
|
||||
def test_azure_assistants_client_instructions_sent_once(mock_async_azure_openai: MagicMock) -> None:
|
||||
"""Ensure instructions are only included once for Azure OpenAI Assistants requests."""
|
||||
chat_client = create_test_azure_assistants_client(mock_async_azure_openai)
|
||||
instructions = "You are a helpful assistant."
|
||||
chat_options = ChatOptions(instructions=instructions)
|
||||
|
||||
prepared_messages = chat_client.prepare_messages([ChatMessage(role="user", text="Hello")], chat_options)
|
||||
run_options, _ = chat_client._prepare_options(prepared_messages, chat_options) # type: ignore[reportPrivateUsage]
|
||||
|
||||
assert run_options.get("instructions") == instructions
|
||||
|
||||
|
||||
async def test_azure_assistants_client_get_assistant_id_or_create_existing_assistant(
|
||||
mock_async_azure_openai: MagicMock,
|
||||
) -> None:
|
||||
|
||||
@@ -23,7 +23,6 @@ from agent_framework import (
|
||||
ChatAgent,
|
||||
ChatClientProtocol,
|
||||
ChatMessage,
|
||||
ChatOptions,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
TextContent,
|
||||
@@ -84,18 +83,6 @@ def test_init_base_url(azure_openai_unit_test_env: dict[str, str]) -> None:
|
||||
assert azure_chat_client.client.default_headers[key] == value
|
||||
|
||||
|
||||
def test_azure_openai_chat_client_instructions_sent_once(azure_openai_unit_test_env: dict[str, str]) -> None:
|
||||
"""Ensure instructions are only included once when preparing Azure OpenAI chat requests."""
|
||||
client = AzureOpenAIChatClient()
|
||||
instructions = "You are a helpful assistant."
|
||||
chat_options = ChatOptions(instructions=instructions)
|
||||
|
||||
prepared_messages = client.prepare_messages([ChatMessage(role="user", text="Hello")], chat_options)
|
||||
request_options = client._prepare_options(prepared_messages, chat_options) # type: ignore[reportPrivateUsage]
|
||||
|
||||
assert json.dumps(request_options).count(instructions) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("exclude_list", [["AZURE_OPENAI_BASE_URL"]], indirect=True)
|
||||
def test_init_endpoint(azure_openai_unit_test_env: dict[str, str]) -> None:
|
||||
azure_chat_client = AzureOpenAIChatClient()
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Annotated
|
||||
|
||||
@@ -15,7 +14,6 @@ from agent_framework import (
|
||||
ChatAgent,
|
||||
ChatClientProtocol,
|
||||
ChatMessage,
|
||||
ChatOptions,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
HostedCodeInterpreterTool,
|
||||
@@ -114,18 +112,6 @@ def test_init_with_default_header(azure_openai_unit_test_env: dict[str, str]) ->
|
||||
assert azure_responses_client.client.default_headers[key] == value
|
||||
|
||||
|
||||
def test_azure_responses_client_instructions_sent_once(azure_openai_unit_test_env: dict[str, str]) -> None:
|
||||
"""Ensure instructions are only included once for Azure OpenAI Responses requests."""
|
||||
client = AzureOpenAIResponsesClient()
|
||||
instructions = "You are a helpful assistant."
|
||||
chat_options = ChatOptions(instructions=instructions)
|
||||
|
||||
prepared_messages = client.prepare_messages([ChatMessage(role="user", text="Hello")], chat_options)
|
||||
request_options = client._prepare_options(prepared_messages, chat_options) # type: ignore[reportPrivateUsage]
|
||||
|
||||
assert json.dumps(request_options).count(instructions) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("exclude_list", [["AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME"]], indirect=True)
|
||||
def test_init_with_empty_model_id(azure_openai_unit_test_env: dict[str, str]) -> None:
|
||||
with pytest.raises(ServiceInitializationError):
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent_framework import (
|
||||
BaseChatClient,
|
||||
ChatClientProtocol,
|
||||
ChatMessage,
|
||||
ChatOptions,
|
||||
Role,
|
||||
)
|
||||
|
||||
@@ -39,3 +42,20 @@ async def test_base_client_get_response(chat_client_base: ChatClientProtocol):
|
||||
async def test_base_client_get_streaming_response(chat_client_base: ChatClientProtocol):
|
||||
async for update in chat_client_base.get_streaming_response(ChatMessage(role="user", text="Hello")):
|
||||
assert update.text == "update - Hello" or update.text == "another update"
|
||||
|
||||
|
||||
async def test_chat_client_instructions_handling(chat_client_base: ChatClientProtocol):
|
||||
instructions = "You are a helpful assistant."
|
||||
with patch.object(
|
||||
chat_client_base,
|
||||
"_inner_get_response",
|
||||
) as mock_inner_get_response:
|
||||
await chat_client_base.get_response("hello", chat_options=ChatOptions(instructions=instructions))
|
||||
mock_inner_get_response.assert_called_once()
|
||||
_, kwargs = mock_inner_get_response.call_args
|
||||
messages = kwargs.get("messages", [])
|
||||
assert len(messages) == 2
|
||||
assert messages[0].role == Role.SYSTEM
|
||||
assert messages[0].text == instructions
|
||||
assert messages[1].role == Role.USER
|
||||
assert messages[1].text == "hello"
|
||||
|
||||
@@ -193,18 +193,6 @@ def test_openai_assistants_client_init_with_default_headers(openai_unit_test_env
|
||||
assert chat_client.client.default_headers[key] == value
|
||||
|
||||
|
||||
def test_openai_assistants_client_instructions_sent_once(mock_async_openai: MagicMock) -> None:
|
||||
"""Ensure instructions are only included once for OpenAI Assistants requests."""
|
||||
chat_client = create_test_openai_assistants_client(mock_async_openai)
|
||||
instructions = "You are a helpful assistant."
|
||||
chat_options = ChatOptions(instructions=instructions)
|
||||
|
||||
prepared_messages = chat_client.prepare_messages([ChatMessage(role=Role.USER, text="Hello")], chat_options)
|
||||
run_options, _ = chat_client._prepare_options(prepared_messages, chat_options) # type: ignore[reportPrivateUsage]
|
||||
|
||||
assert run_options.get("instructions") == instructions
|
||||
|
||||
|
||||
async def test_openai_assistants_client_get_assistant_id_or_create_existing_assistant(
|
||||
mock_async_openai: MagicMock,
|
||||
) -> None:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Annotated
|
||||
from unittest.mock import MagicMock, patch
|
||||
@@ -100,18 +99,6 @@ def test_init_base_url_from_settings_env() -> None:
|
||||
assert str(client.client.base_url) == "https://custom-openai-endpoint.com/v1/"
|
||||
|
||||
|
||||
def test_openai_chat_client_instructions_sent_once(openai_unit_test_env: dict[str, str]) -> None:
|
||||
"""Ensure instructions are only included once for OpenAI chat requests."""
|
||||
client = OpenAIChatClient()
|
||||
instructions = "You are a helpful assistant."
|
||||
chat_options = ChatOptions(instructions=instructions)
|
||||
|
||||
prepared_messages = client.prepare_messages([ChatMessage(role="user", text="Hello")], chat_options)
|
||||
request_options = client._prepare_options(prepared_messages, chat_options) # type: ignore[reportPrivateUsage]
|
||||
|
||||
assert json.dumps(request_options).count(instructions) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("exclude_list", [["OPENAI_CHAT_MODEL_ID"]], indirect=True)
|
||||
def test_init_with_empty_model_id(openai_unit_test_env: dict[str, str]) -> None:
|
||||
with pytest.raises(ServiceInitializationError):
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from typing import Annotated
|
||||
from unittest.mock import MagicMock, patch
|
||||
@@ -138,18 +137,6 @@ def test_init_with_default_header(openai_unit_test_env: dict[str, str]) -> None:
|
||||
assert openai_responses_client.client.default_headers[key] == value
|
||||
|
||||
|
||||
def test_openai_responses_client_instructions_sent_once(openai_unit_test_env: dict[str, str]) -> None:
|
||||
"""Ensure instructions are only included once for OpenAI Responses requests."""
|
||||
client = OpenAIResponsesClient()
|
||||
instructions = "You are a helpful assistant."
|
||||
chat_options = ChatOptions(instructions=instructions)
|
||||
|
||||
prepared_messages = client.prepare_messages([ChatMessage(role="user", text="Hello")], chat_options)
|
||||
request_options = client._prepare_options(prepared_messages, chat_options) # type: ignore[reportPrivateUsage]
|
||||
|
||||
assert json.dumps(request_options).count(instructions) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("exclude_list", [["OPENAI_RESPONSES_MODEL_ID"]], indirect=True)
|
||||
def test_init_with_empty_model_id(openai_unit_test_env: dict[str, str]) -> None:
|
||||
with pytest.raises(ServiceInitializationError):
|
||||
|
||||
Reference in New Issue
Block a user