Python: fix middleware and cleanup confusing function (#1865)

* fix middleware and cleanup redundant function

* added test to validate
This commit is contained in:
Eduard van Valkenburg
2025-11-03 19:42:59 +01:00
committed by GitHub
Unverified
parent 0b843d2b3e
commit e462d209fd
13 changed files with 61 additions and 149 deletions
@@ -559,19 +559,6 @@ async def test_azure_ai_chat_client_create_run_options_with_messages(mock_ai_pro
assert len(run_options["additional_messages"]) == 1 # Only user message
async def test_azure_ai_chat_client_instructions_sent_once(mock_ai_project_client: MagicMock) -> None:
"""Ensure instructions are only sent once for AzureAIAgentClient."""
chat_client = create_test_azure_ai_chat_client(mock_ai_project_client)
instructions = "You are a helpful assistant."
chat_options = ChatOptions(instructions=instructions)
messages = chat_client.prepare_messages([ChatMessage(role=Role.USER, text="Hello")], chat_options)
run_options, _ = await chat_client._create_run_options(messages, chat_options) # type: ignore
assert run_options.get("instructions") == instructions
async def test_azure_ai_chat_client_inner_get_response(mock_ai_project_client: MagicMock) -> None:
"""Test _inner_get_response method."""
chat_client = create_test_azure_ai_chat_client(mock_ai_project_client, agent_id="test-agent")
@@ -20,13 +20,7 @@ from ._middleware import (
from ._serialization import SerializationMixin
from ._threads import ChatMessageStoreProtocol
from ._tools import ToolProtocol
from ._types import (
ChatMessage,
ChatOptions,
ChatResponse,
ChatResponseUpdate,
ToolMode,
)
from ._types import ChatMessage, ChatOptions, ChatResponse, ChatResponseUpdate, ToolMode, prepare_messages
if TYPE_CHECKING:
from ._agents import ChatAgent
@@ -216,28 +210,7 @@ class ChatClientProtocol(Protocol):
# region ChatClientBase
def prepare_messages(messages: str | ChatMessage | list[str] | list[ChatMessage]) -> list[ChatMessage]:
"""Convert various message input formats into a list of ChatMessage objects.
Args:
messages: The input messages in various supported formats.
Returns:
A list of ChatMessage objects.
"""
if isinstance(messages, str):
return [ChatMessage(role="user", text=messages)]
if isinstance(messages, ChatMessage):
return [messages]
return_messages: list[ChatMessage] = []
for msg in messages:
if isinstance(msg, str):
msg = ChatMessage(role="user", text=msg)
return_messages.append(msg)
return return_messages
def merge_chat_options(
def _merge_chat_options(
*,
base_chat_options: ChatOptions | Any | None,
model_id: str | None = None,
@@ -405,25 +378,6 @@ class BaseChatClient(SerializationMixin, ABC):
return result
def prepare_messages(
self, messages: str | ChatMessage | list[str] | list[ChatMessage], chat_options: ChatOptions
) -> MutableSequence[ChatMessage]:
"""Convert various message input formats into a list of ChatMessage objects.
Prepends system instructions if present in chat_options.
Args:
messages: The input messages in various supported formats.
chat_options: The chat options containing instructions and other settings.
Returns:
A mutable sequence of ChatMessage objects.
"""
if chat_options.instructions:
system_msg = ChatMessage(role="system", text=chat_options.instructions)
return [system_msg, *prepare_messages(messages)]
return prepare_messages(messages)
def _filter_internal_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
"""Filter out internal framework parameters that shouldn't be passed to chat client implementations.
@@ -584,7 +538,7 @@ class BaseChatClient(SerializationMixin, ABC):
"""
# Normalize tools and merge with base chat_options
normalized_tools = await self._normalize_tools(tools)
chat_options = merge_chat_options(
chat_options = _merge_chat_options(
base_chat_options=kwargs.pop("chat_options", None),
model_id=model_id,
frequency_penalty=frequency_penalty,
@@ -612,7 +566,11 @@ class BaseChatClient(SerializationMixin, ABC):
)
chat_options.store = True
prepped_messages = self.prepare_messages(messages, chat_options)
if chat_options.instructions:
system_msg = ChatMessage(role="system", text=chat_options.instructions)
prepped_messages = [system_msg, *prepare_messages(messages)]
else:
prepped_messages = prepare_messages(messages)
self._prepare_tool_choice(chat_options=chat_options)
filtered_kwargs = self._filter_internal_kwargs(kwargs)
@@ -679,7 +637,7 @@ class BaseChatClient(SerializationMixin, ABC):
"""
# Normalize tools and merge with base chat_options
normalized_tools = await self._normalize_tools(tools)
chat_options = merge_chat_options(
chat_options = _merge_chat_options(
base_chat_options=kwargs.pop("chat_options", None),
model_id=model_id,
frequency_penalty=frequency_penalty,
@@ -707,7 +665,11 @@ class BaseChatClient(SerializationMixin, ABC):
)
chat_options.store = True
prepped_messages = self.prepare_messages(messages, chat_options)
if chat_options.instructions:
system_msg = ChatMessage(role="system", text=chat_options.instructions)
prepped_messages = [system_msg, *prepare_messages(messages)]
else:
prepped_messages = prepare_messages(messages)
self._prepare_tool_choice(chat_options=chat_options)
filtered_kwargs = self._filter_internal_kwargs(kwargs)
@@ -8,7 +8,7 @@ from functools import update_wrapper
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeAlias, TypeVar
from ._serialization import SerializationMixin
from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage
from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage, prepare_messages
from .exceptions import MiddlewareException
if TYPE_CHECKING:
@@ -1375,7 +1375,7 @@ def use_chat_middleware(chat_client_class: type[TChatClient]) -> type[TChatClien
pipeline = ChatMiddlewarePipeline(chat_middleware_list) # type: ignore[arg-type]
context = ChatContext(
chat_client=self,
messages=self.prepare_messages(messages, chat_options),
messages=prepare_messages(messages),
chat_options=chat_options,
is_streaming=False,
kwargs=kwargs,
@@ -1425,7 +1425,7 @@ def use_chat_middleware(chat_client_class: type[TChatClient]) -> type[TChatClien
pipeline = ChatMiddlewarePipeline(all_middleware) # type: ignore[arg-type]
context = ChatContext(
chat_client=self,
messages=self.prepare_messages(messages, chat_options),
messages=prepare_messages(messages),
chat_options=chat_options,
is_streaming=True,
kwargs=kwargs,
@@ -1318,13 +1318,13 @@ def _handle_function_calls_response(
messages: "str | ChatMessage | list[str] | list[ChatMessage]",
**kwargs: Any,
) -> "ChatResponse":
from ._clients import prepare_messages
from ._middleware import extract_and_merge_function_middleware
from ._types import (
ChatMessage,
FunctionApprovalRequestContent,
FunctionCallContent,
FunctionResultContent,
prepare_messages,
)
# Extract and merge function middleware from chat client with kwargs pipeline
@@ -1465,7 +1465,6 @@ def _handle_function_calls_streaming_response(
**kwargs: Any,
) -> AsyncIterable["ChatResponseUpdate"]:
"""Wrap the inner get streaming response method to handle tool calls."""
from ._clients import prepare_messages
from ._middleware import extract_and_merge_function_middleware
from ._types import (
ChatMessage,
@@ -1473,6 +1472,7 @@ def _handle_function_calls_streaming_response(
ChatResponseUpdate,
FunctionCallContent,
FunctionResultContent,
prepare_messages,
)
# Extract and merge function middleware from chat client with kwargs pipeline
@@ -2052,6 +2052,27 @@ class ChatMessage(SerializationMixin):
return " ".join(content.text for content in self.contents if isinstance(content, TextContent))
def prepare_messages(messages: str | ChatMessage | list[str] | list[ChatMessage]) -> list[ChatMessage]:
"""Convert various message input formats into a list of ChatMessage objects.
Args:
messages: The input messages in various supported formats.
Returns:
A list of ChatMessage objects.
"""
if isinstance(messages, str):
return [ChatMessage(role="user", text=messages)]
if isinstance(messages, ChatMessage):
return [messages]
return_messages: list[ChatMessage] = []
for msg in messages:
if isinstance(msg, str):
msg = ChatMessage(role="user", text=msg)
return_messages.append(msg)
return return_messages
# region ChatResponse
@@ -1413,7 +1413,7 @@ def _capture_messages(
finish_reason: "FinishReason | None" = None,
) -> None:
"""Log messages with extra information."""
from ._clients import prepare_messages
from ._types import prepare_messages
prepped = prepare_messages(messages)
otel_messages: list[dict[str, Any]] = []
@@ -15,7 +15,6 @@ from agent_framework import (
ChatAgent,
ChatClientProtocol,
ChatMessage,
ChatOptions,
ChatResponse,
ChatResponseUpdate,
HostedCodeInterpreterTool,
@@ -155,18 +154,6 @@ def test_azure_assistants_client_init_with_default_headers(azure_openai_unit_tes
assert chat_client.client.default_headers[key] == value
def test_azure_assistants_client_instructions_sent_once(mock_async_azure_openai: MagicMock) -> None:
"""Ensure instructions are only included once for Azure OpenAI Assistants requests."""
chat_client = create_test_azure_assistants_client(mock_async_azure_openai)
instructions = "You are a helpful assistant."
chat_options = ChatOptions(instructions=instructions)
prepared_messages = chat_client.prepare_messages([ChatMessage(role="user", text="Hello")], chat_options)
run_options, _ = chat_client._prepare_options(prepared_messages, chat_options) # type: ignore[reportPrivateUsage]
assert run_options.get("instructions") == instructions
async def test_azure_assistants_client_get_assistant_id_or_create_existing_assistant(
mock_async_azure_openai: MagicMock,
) -> None:
@@ -23,7 +23,6 @@ from agent_framework import (
ChatAgent,
ChatClientProtocol,
ChatMessage,
ChatOptions,
ChatResponse,
ChatResponseUpdate,
TextContent,
@@ -84,18 +83,6 @@ def test_init_base_url(azure_openai_unit_test_env: dict[str, str]) -> None:
assert azure_chat_client.client.default_headers[key] == value
def test_azure_openai_chat_client_instructions_sent_once(azure_openai_unit_test_env: dict[str, str]) -> None:
"""Ensure instructions are only included once when preparing Azure OpenAI chat requests."""
client = AzureOpenAIChatClient()
instructions = "You are a helpful assistant."
chat_options = ChatOptions(instructions=instructions)
prepared_messages = client.prepare_messages([ChatMessage(role="user", text="Hello")], chat_options)
request_options = client._prepare_options(prepared_messages, chat_options) # type: ignore[reportPrivateUsage]
assert json.dumps(request_options).count(instructions) == 1
@pytest.mark.parametrize("exclude_list", [["AZURE_OPENAI_BASE_URL"]], indirect=True)
def test_init_endpoint(azure_openai_unit_test_env: dict[str, str]) -> None:
azure_chat_client = AzureOpenAIChatClient()
@@ -1,6 +1,5 @@
# Copyright (c) Microsoft. All rights reserved.
import json
import os
from typing import Annotated
@@ -15,7 +14,6 @@ from agent_framework import (
ChatAgent,
ChatClientProtocol,
ChatMessage,
ChatOptions,
ChatResponse,
ChatResponseUpdate,
HostedCodeInterpreterTool,
@@ -114,18 +112,6 @@ def test_init_with_default_header(azure_openai_unit_test_env: dict[str, str]) ->
assert azure_responses_client.client.default_headers[key] == value
def test_azure_responses_client_instructions_sent_once(azure_openai_unit_test_env: dict[str, str]) -> None:
"""Ensure instructions are only included once for Azure OpenAI Responses requests."""
client = AzureOpenAIResponsesClient()
instructions = "You are a helpful assistant."
chat_options = ChatOptions(instructions=instructions)
prepared_messages = client.prepare_messages([ChatMessage(role="user", text="Hello")], chat_options)
request_options = client._prepare_options(prepared_messages, chat_options) # type: ignore[reportPrivateUsage]
assert json.dumps(request_options).count(instructions) == 1
@pytest.mark.parametrize("exclude_list", [["AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME"]], indirect=True)
def test_init_with_empty_model_id(azure_openai_unit_test_env: dict[str, str]) -> None:
with pytest.raises(ServiceInitializationError):
@@ -1,10 +1,13 @@
# Copyright (c) Microsoft. All rights reserved.
from unittest.mock import patch
from agent_framework import (
BaseChatClient,
ChatClientProtocol,
ChatMessage,
ChatOptions,
Role,
)
@@ -39,3 +42,20 @@ async def test_base_client_get_response(chat_client_base: ChatClientProtocol):
async def test_base_client_get_streaming_response(chat_client_base: ChatClientProtocol):
async for update in chat_client_base.get_streaming_response(ChatMessage(role="user", text="Hello")):
assert update.text == "update - Hello" or update.text == "another update"
async def test_chat_client_instructions_handling(chat_client_base: ChatClientProtocol):
instructions = "You are a helpful assistant."
with patch.object(
chat_client_base,
"_inner_get_response",
) as mock_inner_get_response:
await chat_client_base.get_response("hello", chat_options=ChatOptions(instructions=instructions))
mock_inner_get_response.assert_called_once()
_, kwargs = mock_inner_get_response.call_args
messages = kwargs.get("messages", [])
assert len(messages) == 2
assert messages[0].role == Role.SYSTEM
assert messages[0].text == instructions
assert messages[1].role == Role.USER
assert messages[1].text == "hello"
@@ -193,18 +193,6 @@ def test_openai_assistants_client_init_with_default_headers(openai_unit_test_env
assert chat_client.client.default_headers[key] == value
def test_openai_assistants_client_instructions_sent_once(mock_async_openai: MagicMock) -> None:
"""Ensure instructions are only included once for OpenAI Assistants requests."""
chat_client = create_test_openai_assistants_client(mock_async_openai)
instructions = "You are a helpful assistant."
chat_options = ChatOptions(instructions=instructions)
prepared_messages = chat_client.prepare_messages([ChatMessage(role=Role.USER, text="Hello")], chat_options)
run_options, _ = chat_client._prepare_options(prepared_messages, chat_options) # type: ignore[reportPrivateUsage]
assert run_options.get("instructions") == instructions
async def test_openai_assistants_client_get_assistant_id_or_create_existing_assistant(
mock_async_openai: MagicMock,
) -> None:
@@ -1,6 +1,5 @@
# Copyright (c) Microsoft. All rights reserved.
import json
import os
from typing import Annotated
from unittest.mock import MagicMock, patch
@@ -100,18 +99,6 @@ def test_init_base_url_from_settings_env() -> None:
assert str(client.client.base_url) == "https://custom-openai-endpoint.com/v1/"
def test_openai_chat_client_instructions_sent_once(openai_unit_test_env: dict[str, str]) -> None:
"""Ensure instructions are only included once for OpenAI chat requests."""
client = OpenAIChatClient()
instructions = "You are a helpful assistant."
chat_options = ChatOptions(instructions=instructions)
prepared_messages = client.prepare_messages([ChatMessage(role="user", text="Hello")], chat_options)
request_options = client._prepare_options(prepared_messages, chat_options) # type: ignore[reportPrivateUsage]
assert json.dumps(request_options).count(instructions) == 1
@pytest.mark.parametrize("exclude_list", [["OPENAI_CHAT_MODEL_ID"]], indirect=True)
def test_init_with_empty_model_id(openai_unit_test_env: dict[str, str]) -> None:
with pytest.raises(ServiceInitializationError):
@@ -2,7 +2,6 @@
import asyncio
import base64
import json
import os
from typing import Annotated
from unittest.mock import MagicMock, patch
@@ -138,18 +137,6 @@ def test_init_with_default_header(openai_unit_test_env: dict[str, str]) -> None:
assert openai_responses_client.client.default_headers[key] == value
def test_openai_responses_client_instructions_sent_once(openai_unit_test_env: dict[str, str]) -> None:
"""Ensure instructions are only included once for OpenAI Responses requests."""
client = OpenAIResponsesClient()
instructions = "You are a helpful assistant."
chat_options = ChatOptions(instructions=instructions)
prepared_messages = client.prepare_messages([ChatMessage(role="user", text="Hello")], chat_options)
request_options = client._prepare_options(prepared_messages, chat_options) # type: ignore[reportPrivateUsage]
assert json.dumps(request_options).count(instructions) == 1
@pytest.mark.parametrize("exclude_list", [["OPENAI_RESPONSES_MODEL_ID"]], indirect=True)
def test_init_with_empty_model_id(openai_unit_test_env: dict[str, str]) -> None:
with pytest.raises(ServiceInitializationError):