From e462d209fdbe87e2e3e7c3905b2ab061ff5bb15e Mon Sep 17 00:00:00 2001 From: Eduard van Valkenburg Date: Mon, 3 Nov 2025 19:42:59 +0100 Subject: [PATCH] Python: fix middleware and cleanup confusing function (#1865) * fix middleware and cleanup redundant function * added test to validate --- .../tests/test_azure_ai_agent_client.py | 13 ---- .../packages/core/agent_framework/_clients.py | 66 ++++--------------- .../core/agent_framework/_middleware.py | 6 +- .../packages/core/agent_framework/_tools.py | 4 +- .../packages/core/agent_framework/_types.py | 21 ++++++ .../core/agent_framework/observability.py | 2 +- .../azure/test_azure_assistants_client.py | 13 ---- .../tests/azure/test_azure_chat_client.py | 13 ---- .../azure/test_azure_responses_client.py | 14 ---- .../packages/core/tests/core/test_clients.py | 20 ++++++ .../openai/test_openai_assistants_client.py | 12 ---- .../tests/openai/test_openai_chat_client.py | 13 ---- .../openai/test_openai_responses_client.py | 13 ---- 13 files changed, 61 insertions(+), 149 deletions(-) diff --git a/python/packages/azure-ai/tests/test_azure_ai_agent_client.py b/python/packages/azure-ai/tests/test_azure_ai_agent_client.py index 3658f549ce..9f3a06ebee 100644 --- a/python/packages/azure-ai/tests/test_azure_ai_agent_client.py +++ b/python/packages/azure-ai/tests/test_azure_ai_agent_client.py @@ -559,19 +559,6 @@ async def test_azure_ai_chat_client_create_run_options_with_messages(mock_ai_pro assert len(run_options["additional_messages"]) == 1 # Only user message -async def test_azure_ai_chat_client_instructions_sent_once(mock_ai_project_client: MagicMock) -> None: - """Ensure instructions are only sent once for AzureAIAgentClient.""" - chat_client = create_test_azure_ai_chat_client(mock_ai_project_client) - - instructions = "You are a helpful assistant." - chat_options = ChatOptions(instructions=instructions) - messages = chat_client.prepare_messages([ChatMessage(role=Role.USER, text="Hello")], chat_options) - - run_options, _ = await chat_client._create_run_options(messages, chat_options) # type: ignore - - assert run_options.get("instructions") == instructions - - async def test_azure_ai_chat_client_inner_get_response(mock_ai_project_client: MagicMock) -> None: """Test _inner_get_response method.""" chat_client = create_test_azure_ai_chat_client(mock_ai_project_client, agent_id="test-agent") diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index 4089aea9e3..0b36b486c8 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -20,13 +20,7 @@ from ._middleware import ( from ._serialization import SerializationMixin from ._threads import ChatMessageStoreProtocol from ._tools import ToolProtocol -from ._types import ( - ChatMessage, - ChatOptions, - ChatResponse, - ChatResponseUpdate, - ToolMode, -) +from ._types import ChatMessage, ChatOptions, ChatResponse, ChatResponseUpdate, ToolMode, prepare_messages if TYPE_CHECKING: from ._agents import ChatAgent @@ -216,28 +210,7 @@ class ChatClientProtocol(Protocol): # region ChatClientBase -def prepare_messages(messages: str | ChatMessage | list[str] | list[ChatMessage]) -> list[ChatMessage]: - """Convert various message input formats into a list of ChatMessage objects. - - Args: - messages: The input messages in various supported formats. - - Returns: - A list of ChatMessage objects. - """ - if isinstance(messages, str): - return [ChatMessage(role="user", text=messages)] - if isinstance(messages, ChatMessage): - return [messages] - return_messages: list[ChatMessage] = [] - for msg in messages: - if isinstance(msg, str): - msg = ChatMessage(role="user", text=msg) - return_messages.append(msg) - return return_messages - - -def merge_chat_options( +def _merge_chat_options( *, base_chat_options: ChatOptions | Any | None, model_id: str | None = None, @@ -405,25 +378,6 @@ class BaseChatClient(SerializationMixin, ABC): return result - def prepare_messages( - self, messages: str | ChatMessage | list[str] | list[ChatMessage], chat_options: ChatOptions - ) -> MutableSequence[ChatMessage]: - """Convert various message input formats into a list of ChatMessage objects. - - Prepends system instructions if present in chat_options. - - Args: - messages: The input messages in various supported formats. - chat_options: The chat options containing instructions and other settings. - - Returns: - A mutable sequence of ChatMessage objects. - """ - if chat_options.instructions: - system_msg = ChatMessage(role="system", text=chat_options.instructions) - return [system_msg, *prepare_messages(messages)] - return prepare_messages(messages) - def _filter_internal_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]: """Filter out internal framework parameters that shouldn't be passed to chat client implementations. @@ -584,7 +538,7 @@ class BaseChatClient(SerializationMixin, ABC): """ # Normalize tools and merge with base chat_options normalized_tools = await self._normalize_tools(tools) - chat_options = merge_chat_options( + chat_options = _merge_chat_options( base_chat_options=kwargs.pop("chat_options", None), model_id=model_id, frequency_penalty=frequency_penalty, @@ -612,7 +566,11 @@ class BaseChatClient(SerializationMixin, ABC): ) chat_options.store = True - prepped_messages = self.prepare_messages(messages, chat_options) + if chat_options.instructions: + system_msg = ChatMessage(role="system", text=chat_options.instructions) + prepped_messages = [system_msg, *prepare_messages(messages)] + else: + prepped_messages = prepare_messages(messages) self._prepare_tool_choice(chat_options=chat_options) filtered_kwargs = self._filter_internal_kwargs(kwargs) @@ -679,7 +637,7 @@ class BaseChatClient(SerializationMixin, ABC): """ # Normalize tools and merge with base chat_options normalized_tools = await self._normalize_tools(tools) - chat_options = merge_chat_options( + chat_options = _merge_chat_options( base_chat_options=kwargs.pop("chat_options", None), model_id=model_id, frequency_penalty=frequency_penalty, @@ -707,7 +665,11 @@ class BaseChatClient(SerializationMixin, ABC): ) chat_options.store = True - prepped_messages = self.prepare_messages(messages, chat_options) + if chat_options.instructions: + system_msg = ChatMessage(role="system", text=chat_options.instructions) + prepped_messages = [system_msg, *prepare_messages(messages)] + else: + prepped_messages = prepare_messages(messages) self._prepare_tool_choice(chat_options=chat_options) filtered_kwargs = self._filter_internal_kwargs(kwargs) diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index d61f9e9ee8..9bb730ba62 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -8,7 +8,7 @@ from functools import update_wrapper from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeAlias, TypeVar from ._serialization import SerializationMixin -from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage +from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage, prepare_messages from .exceptions import MiddlewareException if TYPE_CHECKING: @@ -1375,7 +1375,7 @@ def use_chat_middleware(chat_client_class: type[TChatClient]) -> type[TChatClien pipeline = ChatMiddlewarePipeline(chat_middleware_list) # type: ignore[arg-type] context = ChatContext( chat_client=self, - messages=self.prepare_messages(messages, chat_options), + messages=prepare_messages(messages), chat_options=chat_options, is_streaming=False, kwargs=kwargs, @@ -1425,7 +1425,7 @@ def use_chat_middleware(chat_client_class: type[TChatClient]) -> type[TChatClien pipeline = ChatMiddlewarePipeline(all_middleware) # type: ignore[arg-type] context = ChatContext( chat_client=self, - messages=self.prepare_messages(messages, chat_options), + messages=prepare_messages(messages), chat_options=chat_options, is_streaming=True, kwargs=kwargs, diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 22f3dee18a..22b9921e49 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -1318,13 +1318,13 @@ def _handle_function_calls_response( messages: "str | ChatMessage | list[str] | list[ChatMessage]", **kwargs: Any, ) -> "ChatResponse": - from ._clients import prepare_messages from ._middleware import extract_and_merge_function_middleware from ._types import ( ChatMessage, FunctionApprovalRequestContent, FunctionCallContent, FunctionResultContent, + prepare_messages, ) # Extract and merge function middleware from chat client with kwargs pipeline @@ -1465,7 +1465,6 @@ def _handle_function_calls_streaming_response( **kwargs: Any, ) -> AsyncIterable["ChatResponseUpdate"]: """Wrap the inner get streaming response method to handle tool calls.""" - from ._clients import prepare_messages from ._middleware import extract_and_merge_function_middleware from ._types import ( ChatMessage, @@ -1473,6 +1472,7 @@ def _handle_function_calls_streaming_response( ChatResponseUpdate, FunctionCallContent, FunctionResultContent, + prepare_messages, ) # Extract and merge function middleware from chat client with kwargs pipeline diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index 3294f78a5d..f1a12f813e 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -2052,6 +2052,27 @@ class ChatMessage(SerializationMixin): return " ".join(content.text for content in self.contents if isinstance(content, TextContent)) +def prepare_messages(messages: str | ChatMessage | list[str] | list[ChatMessage]) -> list[ChatMessage]: + """Convert various message input formats into a list of ChatMessage objects. + + Args: + messages: The input messages in various supported formats. + + Returns: + A list of ChatMessage objects. + """ + if isinstance(messages, str): + return [ChatMessage(role="user", text=messages)] + if isinstance(messages, ChatMessage): + return [messages] + return_messages: list[ChatMessage] = [] + for msg in messages: + if isinstance(msg, str): + msg = ChatMessage(role="user", text=msg) + return_messages.append(msg) + return return_messages + + # region ChatResponse diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index 2bdaa90206..c17ce12666 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -1413,7 +1413,7 @@ def _capture_messages( finish_reason: "FinishReason | None" = None, ) -> None: """Log messages with extra information.""" - from ._clients import prepare_messages + from ._types import prepare_messages prepped = prepare_messages(messages) otel_messages: list[dict[str, Any]] = [] diff --git a/python/packages/core/tests/azure/test_azure_assistants_client.py b/python/packages/core/tests/azure/test_azure_assistants_client.py index 307e6c7ac1..758be68d3b 100644 --- a/python/packages/core/tests/azure/test_azure_assistants_client.py +++ b/python/packages/core/tests/azure/test_azure_assistants_client.py @@ -15,7 +15,6 @@ from agent_framework import ( ChatAgent, ChatClientProtocol, ChatMessage, - ChatOptions, ChatResponse, ChatResponseUpdate, HostedCodeInterpreterTool, @@ -155,18 +154,6 @@ def test_azure_assistants_client_init_with_default_headers(azure_openai_unit_tes assert chat_client.client.default_headers[key] == value -def test_azure_assistants_client_instructions_sent_once(mock_async_azure_openai: MagicMock) -> None: - """Ensure instructions are only included once for Azure OpenAI Assistants requests.""" - chat_client = create_test_azure_assistants_client(mock_async_azure_openai) - instructions = "You are a helpful assistant." - chat_options = ChatOptions(instructions=instructions) - - prepared_messages = chat_client.prepare_messages([ChatMessage(role="user", text="Hello")], chat_options) - run_options, _ = chat_client._prepare_options(prepared_messages, chat_options) # type: ignore[reportPrivateUsage] - - assert run_options.get("instructions") == instructions - - async def test_azure_assistants_client_get_assistant_id_or_create_existing_assistant( mock_async_azure_openai: MagicMock, ) -> None: diff --git a/python/packages/core/tests/azure/test_azure_chat_client.py b/python/packages/core/tests/azure/test_azure_chat_client.py index d43302d472..1b7dbb904b 100644 --- a/python/packages/core/tests/azure/test_azure_chat_client.py +++ b/python/packages/core/tests/azure/test_azure_chat_client.py @@ -23,7 +23,6 @@ from agent_framework import ( ChatAgent, ChatClientProtocol, ChatMessage, - ChatOptions, ChatResponse, ChatResponseUpdate, TextContent, @@ -84,18 +83,6 @@ def test_init_base_url(azure_openai_unit_test_env: dict[str, str]) -> None: assert azure_chat_client.client.default_headers[key] == value -def test_azure_openai_chat_client_instructions_sent_once(azure_openai_unit_test_env: dict[str, str]) -> None: - """Ensure instructions are only included once when preparing Azure OpenAI chat requests.""" - client = AzureOpenAIChatClient() - instructions = "You are a helpful assistant." - chat_options = ChatOptions(instructions=instructions) - - prepared_messages = client.prepare_messages([ChatMessage(role="user", text="Hello")], chat_options) - request_options = client._prepare_options(prepared_messages, chat_options) # type: ignore[reportPrivateUsage] - - assert json.dumps(request_options).count(instructions) == 1 - - @pytest.mark.parametrize("exclude_list", [["AZURE_OPENAI_BASE_URL"]], indirect=True) def test_init_endpoint(azure_openai_unit_test_env: dict[str, str]) -> None: azure_chat_client = AzureOpenAIChatClient() diff --git a/python/packages/core/tests/azure/test_azure_responses_client.py b/python/packages/core/tests/azure/test_azure_responses_client.py index 658aa21457..a495d05837 100644 --- a/python/packages/core/tests/azure/test_azure_responses_client.py +++ b/python/packages/core/tests/azure/test_azure_responses_client.py @@ -1,6 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. -import json import os from typing import Annotated @@ -15,7 +14,6 @@ from agent_framework import ( ChatAgent, ChatClientProtocol, ChatMessage, - ChatOptions, ChatResponse, ChatResponseUpdate, HostedCodeInterpreterTool, @@ -114,18 +112,6 @@ def test_init_with_default_header(azure_openai_unit_test_env: dict[str, str]) -> assert azure_responses_client.client.default_headers[key] == value -def test_azure_responses_client_instructions_sent_once(azure_openai_unit_test_env: dict[str, str]) -> None: - """Ensure instructions are only included once for Azure OpenAI Responses requests.""" - client = AzureOpenAIResponsesClient() - instructions = "You are a helpful assistant." - chat_options = ChatOptions(instructions=instructions) - - prepared_messages = client.prepare_messages([ChatMessage(role="user", text="Hello")], chat_options) - request_options = client._prepare_options(prepared_messages, chat_options) # type: ignore[reportPrivateUsage] - - assert json.dumps(request_options).count(instructions) == 1 - - @pytest.mark.parametrize("exclude_list", [["AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME"]], indirect=True) def test_init_with_empty_model_id(azure_openai_unit_test_env: dict[str, str]) -> None: with pytest.raises(ServiceInitializationError): diff --git a/python/packages/core/tests/core/test_clients.py b/python/packages/core/tests/core/test_clients.py index c0e319e34b..423a7e42b5 100644 --- a/python/packages/core/tests/core/test_clients.py +++ b/python/packages/core/tests/core/test_clients.py @@ -1,10 +1,13 @@ # Copyright (c) Microsoft. All rights reserved. +from unittest.mock import patch + from agent_framework import ( BaseChatClient, ChatClientProtocol, ChatMessage, + ChatOptions, Role, ) @@ -39,3 +42,20 @@ async def test_base_client_get_response(chat_client_base: ChatClientProtocol): async def test_base_client_get_streaming_response(chat_client_base: ChatClientProtocol): async for update in chat_client_base.get_streaming_response(ChatMessage(role="user", text="Hello")): assert update.text == "update - Hello" or update.text == "another update" + + +async def test_chat_client_instructions_handling(chat_client_base: ChatClientProtocol): + instructions = "You are a helpful assistant." + with patch.object( + chat_client_base, + "_inner_get_response", + ) as mock_inner_get_response: + await chat_client_base.get_response("hello", chat_options=ChatOptions(instructions=instructions)) + mock_inner_get_response.assert_called_once() + _, kwargs = mock_inner_get_response.call_args + messages = kwargs.get("messages", []) + assert len(messages) == 2 + assert messages[0].role == Role.SYSTEM + assert messages[0].text == instructions + assert messages[1].role == Role.USER + assert messages[1].text == "hello" diff --git a/python/packages/core/tests/openai/test_openai_assistants_client.py b/python/packages/core/tests/openai/test_openai_assistants_client.py index be1a059b58..90947dd437 100644 --- a/python/packages/core/tests/openai/test_openai_assistants_client.py +++ b/python/packages/core/tests/openai/test_openai_assistants_client.py @@ -193,18 +193,6 @@ def test_openai_assistants_client_init_with_default_headers(openai_unit_test_env assert chat_client.client.default_headers[key] == value -def test_openai_assistants_client_instructions_sent_once(mock_async_openai: MagicMock) -> None: - """Ensure instructions are only included once for OpenAI Assistants requests.""" - chat_client = create_test_openai_assistants_client(mock_async_openai) - instructions = "You are a helpful assistant." - chat_options = ChatOptions(instructions=instructions) - - prepared_messages = chat_client.prepare_messages([ChatMessage(role=Role.USER, text="Hello")], chat_options) - run_options, _ = chat_client._prepare_options(prepared_messages, chat_options) # type: ignore[reportPrivateUsage] - - assert run_options.get("instructions") == instructions - - async def test_openai_assistants_client_get_assistant_id_or_create_existing_assistant( mock_async_openai: MagicMock, ) -> None: diff --git a/python/packages/core/tests/openai/test_openai_chat_client.py b/python/packages/core/tests/openai/test_openai_chat_client.py index 63db8c071e..d159091311 100644 --- a/python/packages/core/tests/openai/test_openai_chat_client.py +++ b/python/packages/core/tests/openai/test_openai_chat_client.py @@ -1,6 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. -import json import os from typing import Annotated from unittest.mock import MagicMock, patch @@ -100,18 +99,6 @@ def test_init_base_url_from_settings_env() -> None: assert str(client.client.base_url) == "https://custom-openai-endpoint.com/v1/" -def test_openai_chat_client_instructions_sent_once(openai_unit_test_env: dict[str, str]) -> None: - """Ensure instructions are only included once for OpenAI chat requests.""" - client = OpenAIChatClient() - instructions = "You are a helpful assistant." - chat_options = ChatOptions(instructions=instructions) - - prepared_messages = client.prepare_messages([ChatMessage(role="user", text="Hello")], chat_options) - request_options = client._prepare_options(prepared_messages, chat_options) # type: ignore[reportPrivateUsage] - - assert json.dumps(request_options).count(instructions) == 1 - - @pytest.mark.parametrize("exclude_list", [["OPENAI_CHAT_MODEL_ID"]], indirect=True) def test_init_with_empty_model_id(openai_unit_test_env: dict[str, str]) -> None: with pytest.raises(ServiceInitializationError): diff --git a/python/packages/core/tests/openai/test_openai_responses_client.py b/python/packages/core/tests/openai/test_openai_responses_client.py index c6ce05c2a3..cb4f0dc0d3 100644 --- a/python/packages/core/tests/openai/test_openai_responses_client.py +++ b/python/packages/core/tests/openai/test_openai_responses_client.py @@ -2,7 +2,6 @@ import asyncio import base64 -import json import os from typing import Annotated from unittest.mock import MagicMock, patch @@ -138,18 +137,6 @@ def test_init_with_default_header(openai_unit_test_env: dict[str, str]) -> None: assert openai_responses_client.client.default_headers[key] == value -def test_openai_responses_client_instructions_sent_once(openai_unit_test_env: dict[str, str]) -> None: - """Ensure instructions are only included once for OpenAI Responses requests.""" - client = OpenAIResponsesClient() - instructions = "You are a helpful assistant." - chat_options = ChatOptions(instructions=instructions) - - prepared_messages = client.prepare_messages([ChatMessage(role="user", text="Hello")], chat_options) - request_options = client._prepare_options(prepared_messages, chat_options) # type: ignore[reportPrivateUsage] - - assert json.dumps(request_options).count(instructions) == 1 - - @pytest.mark.parametrize("exclude_list", [["OPENAI_RESPONSES_MODEL_ID"]], indirect=True) def test_init_with_empty_model_id(openai_unit_test_env: dict[str, str]) -> None: with pytest.raises(ServiceInitializationError):