From 0cd40f8354e91f3e89cbc96e2d8740cc40a6af2b Mon Sep 17 00:00:00 2001 From: Eduard van Valkenburg Date: Fri, 20 Mar 2026 01:43:37 +0100 Subject: [PATCH] Python: [BREAKING] Refactor middleware layering and split Anthropic raw client (#4746) * [BREAKING] Refactor middleware layering and raw clients Reorder chat client layers so function invocation wraps chat middleware, and chat middleware stays outside telemetry while still running for each inner model call. Add middleware pipeline caching, refresh docs and samples, and split Anthropic into raw and public clients to match the standard layering model. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Tighten typing ignores in ancillary modules Add targeted typing ignores in workflow visualization and lab modules so pyright stays clean alongside the middleware refactor work. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix categorize_middleware to unpack tuple/Sequence and use relative MRO assertions - Broaden isinstance check in categorize_middleware from list to Sequence so tuples and other Sequence types are properly unpacked instead of being appended as a single item. - Replace fragile hardcoded MRO index assertions in anthropic test with relative ordering via mro.index(). - Add regression tests for categorize_middleware with tuple, list, and None inputs. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix middleware string decomposition, add middleware param to FunctionInvocationLayer, and add tests (#4710) - Guard categorize_middleware Sequence check against str/bytes to prevent character-by-character decomposition of accidentally passed strings - Add explicit middleware parameter to FunctionInvocationLayer.get_response and merge it into client_kwargs before categorization, fixing the inconsistency where only OpenAIChatClient supported this parameter - Add assertions that RawAnthropicClient does not inherit convenience layers - Add chat middleware cache test with non-empty base middleware - Add tests for single unwrapped middleware item and string input Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Apply pre-commit auto-fixes * Apply pre-commit auto-fixes * Address review feedback for #4710: review comment fixes --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Copilot --- .../ag-ui/agent_framework_ag_ui/_client.py | 2 +- python/packages/ag-ui/tests/ag_ui/conftest.py | 4 +- .../agent_framework_anthropic/__init__.py | 3 +- .../agent_framework_anthropic/_chat_client.py | 131 ++++++++-- .../anthropic/tests/test_anthropic_client.py | 20 +- .../agent_framework_azure_ai/_chat_client.py | 2 +- .../agent_framework_azure_ai/_client.py | 8 +- .../tests/test_azure_ai_agent_client.py | 6 + .../agent_framework_bedrock/_chat_client.py | 2 +- .../packages/core/agent_framework/_clients.py | 11 +- .../core/agent_framework/_middleware.py | 64 +++-- .../packages/core/agent_framework/_tools.py | 64 +++-- .../core/agent_framework/_workflows/_viz.py | 4 +- .../agent_framework/azure/_chat_client.py | 2 +- .../azure/_responses_client.py | 2 +- .../core/agent_framework/observability.py | 18 +- .../openai/_assistants_client.py | 2 +- .../agent_framework/openai/_chat_client.py | 25 +- .../openai/_responses_client.py | 8 +- python/packages/core/tests/core/conftest.py | 4 +- .../packages/core/tests/core/test_clients.py | 3 +- .../core/test_function_invocation_logic.py | 10 +- .../test_kwargs_propagation_to_ai_function.py | 2 +- .../core/tests/core/test_middleware.py | 47 ++++ .../tests/core/test_middleware_with_agent.py | 91 +++++++ .../tests/core/test_middleware_with_chat.py | 242 +++++++++++++++++- .../core/tests/core/test_observability.py | 8 +- .../_foundry_local_client.py | 2 +- .../lab/gaia/agent_framework_lab_gaia/gaia.py | 2 +- .../agent_framework_lab_lightning/__init__.py | 4 +- .../agent_framework_ollama/_chat_client.py | 2 +- .../orchestrations/tests/test_handoff.py | 16 +- python/samples/02-agents/auto_retry.py | 10 +- .../chat_client/custom_chat_client.py | 7 +- python/samples/02-agents/middleware/README.md | 37 +++ .../agent_and_run_level_middleware.py | 8 +- .../middleware/usage_tracking_middleware.py | 185 +++++++++++++ .../advanced_manual_setup_console_output.py | 14 +- .../observability/advanced_zero_code.py | 10 +- .../02-agents/providers/custom/README.md | 8 +- .../agent_with_local_tools/main.py | 1 - 41 files changed, 936 insertions(+), 155 deletions(-) create mode 100644 python/samples/02-agents/middleware/README.md create mode 100644 python/samples/02-agents/middleware/usage_tracking_middleware.py diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_client.py b/python/packages/ag-ui/agent_framework_ag_ui/_client.py index d2fb59bbb6..7a1b974a38 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_client.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_client.py @@ -111,8 +111,8 @@ def _apply_server_function_call_unwrap(client: BaseChatClientT) -> BaseChatClien @_apply_server_function_call_unwrap class AGUIChatClient( - ChatMiddlewareLayer[AGUIChatOptionsT], FunctionInvocationLayer[AGUIChatOptionsT], + ChatMiddlewareLayer[AGUIChatOptionsT], ChatTelemetryLayer[AGUIChatOptionsT], BaseChatClient[AGUIChatOptionsT], Generic[AGUIChatOptionsT], diff --git a/python/packages/ag-ui/tests/ag_ui/conftest.py b/python/packages/ag-ui/tests/ag_ui/conftest.py index 42a6967371..744196dbdf 100644 --- a/python/packages/ag-ui/tests/ag_ui/conftest.py +++ b/python/packages/ag-ui/tests/ag_ui/conftest.py @@ -45,8 +45,8 @@ def pytest_configure() -> None: class StreamingChatClientStub( - ChatMiddlewareLayer[OptionsCoT], FunctionInvocationLayer[OptionsCoT], + ChatMiddlewareLayer[OptionsCoT], ChatTelemetryLayer[OptionsCoT], BaseChatClient[OptionsCoT], Generic[OptionsCoT], @@ -54,7 +54,7 @@ class StreamingChatClientStub( """Typed streaming stub that satisfies SupportsChatGetResponse.""" def __init__(self, stream_fn: StreamFn, response_fn: ResponseFn | None = None) -> None: - super().__init__(function_middleware=[]) + super().__init__(middleware=[]) self._stream_fn = stream_fn self._response_fn = response_fn self.last_session: AgentSession | None = None diff --git a/python/packages/anthropic/agent_framework_anthropic/__init__.py b/python/packages/anthropic/agent_framework_anthropic/__init__.py index 706740a127..ad0cff9648 100644 --- a/python/packages/anthropic/agent_framework_anthropic/__init__.py +++ b/python/packages/anthropic/agent_framework_anthropic/__init__.py @@ -2,7 +2,7 @@ import importlib.metadata -from ._chat_client import AnthropicChatOptions, AnthropicClient +from ._chat_client import AnthropicChatOptions, AnthropicClient, RawAnthropicClient try: __version__ = importlib.metadata.version(__name__) @@ -12,5 +12,6 @@ except importlib.metadata.PackageNotFoundError: __all__ = [ "AnthropicChatOptions", "AnthropicClient", + "RawAnthropicClient", "__version__", ] diff --git a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py index a1915a69fb..b3b61a4640 100644 --- a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py +++ b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py @@ -68,6 +68,7 @@ else: __all__ = [ "AnthropicChatOptions", "AnthropicClient", + "RawAnthropicClient", "ThinkingConfig", ] @@ -210,14 +211,24 @@ class AnthropicSettings(TypedDict, total=False): chat_model_id: str | None -class AnthropicClient( - ChatMiddlewareLayer[AnthropicOptionsT], - FunctionInvocationLayer[AnthropicOptionsT], - ChatTelemetryLayer[AnthropicOptionsT], +class RawAnthropicClient( BaseChatClient[AnthropicOptionsT], Generic[AnthropicOptionsT], ): - """Anthropic Chat client with middleware, telemetry, and function invocation support.""" + """Raw Anthropic chat client without middleware, telemetry, or function invocation support. + + Warning: + **This class should not normally be used directly.** It does not include middleware, + telemetry, or function invocation support that you most likely need. If you do use it, + you should consider which additional layers to apply. There is a defined ordering that + you should follow: + + 1. **FunctionInvocationLayer** - Owns the tool/function calling loop and routes function middleware + 2. **ChatMiddlewareLayer** - Applies chat middleware per model call and stays outside telemetry + 3. **ChatTelemetryLayer** - Must stay inside chat middleware for correct per-call telemetry + + Use ``AnthropicClient`` instead for a fully-featured client with all layers applied. + """ OTEL_PROVIDER_NAME: ClassVar[str] = "anthropic" # type: ignore[reportIncompatibleVariableOverride, misc] @@ -229,12 +240,10 @@ class AnthropicClient( anthropic_client: AsyncAnthropic | None = None, additional_beta_flags: list[str] | None = None, additional_properties: dict[str, Any] | None = None, - middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, - function_invocation_configuration: FunctionInvocationConfiguration | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, ) -> None: - """Initialize an Anthropic Agent client. + """Initialize a raw Anthropic client. Keyword Args: api_key: The Anthropic API key to use for authentication. @@ -245,15 +254,13 @@ class AnthropicClient( additional_beta_flags: Additional beta flags to enable on the client. Default flags are: "mcp-client-2025-04-04", "code-execution-2025-08-25". additional_properties: Additional properties stored on the client instance. - middleware: Optional middleware to apply to the client. - function_invocation_configuration: Optional function invocation configuration override. env_file_path: Path to environment file for loading settings. env_file_encoding: Encoding of the environment file. Examples: .. code-block:: python - from agent_framework.anthropic import AnthropicClient + from agent_framework.anthropic import RawAnthropicClient from azure.identity.aio import DefaultAzureCredential # Using environment variables @@ -261,13 +268,13 @@ class AnthropicClient( # ANTHROPIC_CHAT_MODEL_ID=claude-sonnet-4-5-20250929 # Or passing parameters directly - client = AnthropicClient( + client = RawAnthropicClient( model_id="claude-sonnet-4-5-20250929", api_key="your_anthropic_api_key", ) # Or loading from a .env file - client = AnthropicClient(env_file_path="path/to/.env") + client = RawAnthropicClient(env_file_path="path/to/.env") # Or passing in an existing client from anthropic import AsyncAnthropic @@ -275,7 +282,7 @@ class AnthropicClient( anthropic_client = AsyncAnthropic( api_key="your_anthropic_api_key", base_url="https://custom-anthropic-endpoint.com" ) - client = AnthropicClient( + client = RawAnthropicClient( model_id="claude-sonnet-4-5-20250929", anthropic_client=anthropic_client, ) @@ -289,7 +296,7 @@ class AnthropicClient( my_custom_option: str - client: AnthropicClient[MyOptions] = AnthropicClient(model_id="claude-sonnet-4-5-20250929") + client: RawAnthropicClient[MyOptions] = RawAnthropicClient(model_id="claude-sonnet-4-5-20250929") response = await client.get_response("Hello", options={"my_custom_option": "value"}) """ @@ -320,8 +327,6 @@ class AnthropicClient( # Initialize parent super().__init__( additional_properties=additional_properties, - middleware=middleware, - function_invocation_configuration=function_invocation_configuration, ) # Initialize instance variables @@ -1376,3 +1381,95 @@ class AnthropicClient( The service URL for the chat client, or None if not set. """ return str(self.anthropic_client.base_url) + + +class AnthropicClient( + FunctionInvocationLayer[AnthropicOptionsT], + ChatMiddlewareLayer[AnthropicOptionsT], + ChatTelemetryLayer[AnthropicOptionsT], + RawAnthropicClient[AnthropicOptionsT], + Generic[AnthropicOptionsT], +): + """Anthropic chat client with middleware, telemetry, and function invocation support.""" + + def __init__( + self, + *, + api_key: str | None = None, + model_id: str | None = None, + anthropic_client: AsyncAnthropic | None = None, + additional_beta_flags: list[str] | None = None, + additional_properties: dict[str, Any] | None = None, + middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, + function_invocation_configuration: FunctionInvocationConfiguration | None = None, + env_file_path: str | None = None, + env_file_encoding: str | None = None, + ) -> None: + """Initialize an Anthropic client. + + Keyword Args: + api_key: The Anthropic API key to use for authentication. + model_id: The ID of the model to use. + anthropic_client: An existing Anthropic client to use. If not provided, one will be created. + This can be used to further configure the client before passing it in. + For instance if you need to set a different base_url for testing or private deployments. + additional_beta_flags: Additional beta flags to enable on the client. + Default flags are: "mcp-client-2025-04-04", "code-execution-2025-08-25". + additional_properties: Additional properties stored on the client instance. + middleware: Optional middleware to apply to the client. + function_invocation_configuration: Optional function invocation configuration override. + env_file_path: Path to environment file for loading settings. + env_file_encoding: Encoding of the environment file. + + Examples: + .. code-block:: python + + from agent_framework.anthropic import AnthropicClient + + # Using environment variables + # Set ANTHROPIC_API_KEY=your_anthropic_api_key + # ANTHROPIC_CHAT_MODEL_ID=claude-sonnet-4-5-20250929 + + # Or passing parameters directly + client = AnthropicClient( + model_id="claude-sonnet-4-5-20250929", + api_key="your_anthropic_api_key", + ) + + # Or loading from a .env file + client = AnthropicClient(env_file_path="path/to/.env") + + # Or passing in an existing client + from anthropic import AsyncAnthropic + + anthropic_client = AsyncAnthropic( + api_key="your_anthropic_api_key", base_url="https://custom-anthropic-endpoint.com" + ) + client = AnthropicClient( + model_id="claude-sonnet-4-5-20250929", + anthropic_client=anthropic_client, + ) + + # Using custom ChatOptions with type safety: + from typing import TypedDict + from agent_framework.anthropic import AnthropicChatOptions + + + class MyOptions(AnthropicChatOptions, total=False): + my_custom_option: str + + + client: AnthropicClient[MyOptions] = AnthropicClient(model_id="claude-sonnet-4-5-20250929") + response = await client.get_response("Hello", options={"my_custom_option": "value"}) + """ + super().__init__( + api_key=api_key, + model_id=model_id, + anthropic_client=anthropic_client, + additional_beta_flags=additional_beta_flags, + additional_properties=additional_properties, + middleware=middleware, + function_invocation_configuration=function_invocation_configuration, + env_file_path=env_file_path, + env_file_encoding=env_file_encoding, + ) diff --git a/python/packages/anthropic/tests/test_anthropic_client.py b/python/packages/anthropic/tests/test_anthropic_client.py index 272239b1d7..258cc275ca 100644 --- a/python/packages/anthropic/tests/test_anthropic_client.py +++ b/python/packages/anthropic/tests/test_anthropic_client.py @@ -6,15 +6,18 @@ from unittest.mock import MagicMock, patch import pytest from agent_framework import ( + ChatMiddlewareLayer, ChatOptions, ChatResponseUpdate, Content, + FunctionInvocationLayer, Message, SupportsChatGetResponse, tool, ) from agent_framework._settings import load_settings from agent_framework._tools import SHELL_TOOL_KIND_VALUE +from agent_framework.observability import ChatTelemetryLayer from anthropic.types.beta import ( BetaMessage, BetaTextBlock, @@ -23,7 +26,7 @@ from anthropic.types.beta import ( ) from pydantic import BaseModel, Field -from agent_framework_anthropic import AnthropicClient +from agent_framework_anthropic import AnthropicClient, RawAnthropicClient from agent_framework_anthropic._chat_client import AnthropicSettings # Test constants @@ -64,6 +67,8 @@ def create_test_anthropic_client( client.additional_beta_flags = [] client.chat_middleware = [] client.function_middleware = [] + client._cached_chat_middleware_pipeline = None + client._cached_function_middleware_pipeline = None client.function_invocation_configuration = normalize_function_invocation_configuration(None) return client @@ -117,6 +122,19 @@ def test_anthropic_client_init_with_client(mock_anthropic_client: MagicMock) -> assert isinstance(client, SupportsChatGetResponse) +def test_anthropic_client_wraps_raw_client_with_standard_layer_order() -> None: + """Test AnthropicClient composes the standard public layer stack around the raw client.""" + assert issubclass(AnthropicClient, RawAnthropicClient) + mro = AnthropicClient.__mro__ + assert mro.index(FunctionInvocationLayer) < mro.index(ChatMiddlewareLayer) + assert mro.index(ChatMiddlewareLayer) < mro.index(ChatTelemetryLayer) + assert mro.index(ChatTelemetryLayer) < mro.index(RawAnthropicClient) + # RawAnthropicClient must not include the convenience layers + assert not issubclass(RawAnthropicClient, FunctionInvocationLayer) + assert not issubclass(RawAnthropicClient, ChatMiddlewareLayer) + assert not issubclass(RawAnthropicClient, ChatTelemetryLayer) + + def test_anthropic_client_init_auto_create_client( anthropic_unit_test_env: dict[str, str], ) -> None: diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py index d349ef3247..63db1663d8 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py @@ -206,8 +206,8 @@ AzureAIAgentOptionsT = TypeVar( class AzureAIAgentClient( - ChatMiddlewareLayer[AzureAIAgentOptionsT], FunctionInvocationLayer[AzureAIAgentOptionsT], + ChatMiddlewareLayer[AzureAIAgentOptionsT], ChatTelemetryLayer[AzureAIAgentOptionsT], BaseChatClient[AzureAIAgentOptionsT], Generic[AzureAIAgentOptionsT], diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index 1fc6c7c1c9..34ac6f29a5 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -97,9 +97,9 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[ you should consider which additional layers to apply. There is a defined ordering that you should follow: - 1. **ChatMiddlewareLayer** - Should be applied first as it also prepares function middleware - 2. **FunctionInvocationLayer** - Handles tool/function calling loop - 3. **ChatTelemetryLayer** - Must be inside the function calling loop for correct per-call telemetry + 1. **FunctionInvocationLayer** - Owns the tool/function calling loop and routes function middleware + 2. **ChatMiddlewareLayer** - Applies chat middleware per model call and stays outside telemetry + 3. **ChatTelemetryLayer** - Must stay inside chat middleware for correct per-call telemetry Use ``AzureAIClient`` instead for a fully-featured client with all layers applied. """ @@ -1214,8 +1214,8 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[ class AzureAIClient( - ChatMiddlewareLayer[AzureAIClientOptionsT], FunctionInvocationLayer[AzureAIClientOptionsT], + ChatMiddlewareLayer[AzureAIClientOptionsT], ChatTelemetryLayer[AzureAIClientOptionsT], RawAzureAIClient[AzureAIClientOptionsT], Generic[AzureAIClientOptionsT], diff --git a/python/packages/azure-ai/tests/test_azure_ai_agent_client.py b/python/packages/azure-ai/tests/test_azure_ai_agent_client.py index afa073c6ab..65922e76b2 100644 --- a/python/packages/azure-ai/tests/test_azure_ai_agent_client.py +++ b/python/packages/azure-ai/tests/test_azure_ai_agent_client.py @@ -87,6 +87,8 @@ def create_test_azure_ai_chat_client( client.middleware = None client.chat_middleware = [] client.function_middleware = [] + client._cached_chat_middleware_pipeline = None + client._cached_function_middleware_pipeline = None client.otel_provider_name = "azure.ai" client.function_invocation_configuration = { "enabled": True, @@ -151,6 +153,10 @@ def test_azure_ai_chat_client_init_auto_create_client( chat_client.agent_name = None chat_client.additional_properties = {} chat_client.middleware = None + chat_client.chat_middleware = [] + chat_client.function_middleware = [] + chat_client._cached_chat_middleware_pipeline = None + chat_client._cached_function_middleware_pipeline = None assert chat_client.agents_client is mock_agents_client assert chat_client.agent_id is None diff --git a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py index c546ef5535..0aefbe12f3 100644 --- a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py +++ b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py @@ -216,8 +216,8 @@ class BedrockSettings(TypedDict, total=False): class BedrockChatClient( - ChatMiddlewareLayer[BedrockChatOptionsT], FunctionInvocationLayer[BedrockChatOptionsT], + ChatMiddlewareLayer[BedrockChatOptionsT], ChatTelemetryLayer[BedrockChatOptionsT], BaseChatClient[BedrockChatOptionsT], Generic[BedrockChatOptionsT], diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index 4fd563d3e0..66740f5bf8 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -966,16 +966,7 @@ def _apply_get_response_docstrings() -> None: from .observability import ChatTelemetryLayer apply_layered_docstring(ChatTelemetryLayer.get_response, BaseChatClient.get_response) - apply_layered_docstring( - FunctionInvocationLayer.get_response, - ChatTelemetryLayer.get_response, - extra_keyword_args={ - "function_middleware": """ - Optional per-call function middleware. - When omitted, middleware configured on the client or forwarded from higher layers is used. - """, - }, - ) + apply_layered_docstring(FunctionInvocationLayer.get_response, ChatTelemetryLayer.get_response) apply_layered_docstring( ChatMiddlewareLayer.get_response, FunctionInvocationLayer.get_response, diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index 66845a2e9d..381482b91a 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -742,12 +742,17 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline): middleware: The list of agent middleware to include in the pipeline. """ super().__init__() + self._source_middleware: tuple[AgentMiddlewareTypes, ...] = tuple(middleware) self._middleware: list[AgentMiddleware] = [] if middleware: for mdlware in middleware: self._register_middleware(mdlware) + def matches(self, middleware: Sequence[AgentMiddlewareTypes]) -> bool: + """Return whether this pipeline was built from the provided middleware sequence.""" + return self._source_middleware == tuple(middleware) + def _register_middleware(self, middleware: AgentMiddlewareTypes) -> None: """Register an agent middleware item. @@ -824,12 +829,17 @@ class FunctionMiddlewarePipeline(BaseMiddlewarePipeline): middleware: The list of function middleware to include in the pipeline. """ super().__init__() + self._source_middleware: tuple[FunctionMiddlewareTypes, ...] = tuple(middleware) self._middleware: list[FunctionMiddleware] = [] if middleware: for mdlware in middleware: self._register_middleware(mdlware) + def matches(self, middleware: Sequence[FunctionMiddlewareTypes]) -> bool: + """Return whether this pipeline was built from the provided middleware sequence.""" + return self._source_middleware == tuple(middleware) + def _register_middleware(self, middleware: FunctionMiddlewareTypes) -> None: """Register a function middleware item. @@ -892,12 +902,17 @@ class ChatMiddlewarePipeline(BaseMiddlewarePipeline): middleware: The list of chat middleware to include in the pipeline. """ super().__init__() + self._source_middleware: tuple[ChatMiddlewareTypes, ...] = tuple(middleware) self._middleware: list[ChatMiddleware] = [] if middleware: for mdlware in middleware: self._register_middleware(mdlware) + def matches(self, middleware: Sequence[ChatMiddlewareTypes]) -> bool: + """Return whether this pipeline was built from the provided middleware sequence.""" + return self._source_middleware == tuple(middleware) + def _register_middleware(self, middleware: ChatMiddlewareTypes) -> None: """Register a chat middleware item. @@ -980,16 +995,26 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]): def __init__( self, *, - middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, + middleware: Sequence[ChatMiddlewareTypes] | None = None, **kwargs: Any, ) -> None: - middleware_list = categorize_middleware(*(middleware or [])) - self.chat_middleware = middleware_list["chat"] - if "function_middleware" in kwargs and middleware_list["function"]: - raise ValueError("Cannot specify 'function_middleware' and 'middleware' at the same time.") - kwargs["function_middleware"] = middleware_list["function"] + self.chat_middleware = list(middleware) if middleware else [] + self._cached_chat_middleware_pipeline: ChatMiddlewarePipeline | None = None super().__init__(**kwargs) + def _get_chat_middleware_pipeline( + self, + middleware: Sequence[ChatMiddlewareTypes], + ) -> ChatMiddlewarePipeline: + effective_middleware = [*self.chat_middleware, *middleware] + if self._cached_chat_middleware_pipeline is not None and self._cached_chat_middleware_pipeline.matches( + effective_middleware + ): + return self._cached_chat_middleware_pipeline + + self._cached_chat_middleware_pipeline = ChatMiddlewarePipeline(*effective_middleware) + return self._cached_chat_middleware_pipeline + @overload def get_response( self, @@ -1052,14 +1077,8 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]): kwargs["tokenizer"] = tokenizer effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {} - call_middleware = kwargs.pop("middleware", effective_client_kwargs.pop("middleware", [])) - middleware = categorize_middleware(call_middleware) - effective_client_kwargs["function_middleware"] = middleware["function"] - - pipeline = ChatMiddlewarePipeline( - *self.chat_middleware, - *middleware["chat"], - ) + call_middleware = effective_client_kwargs.pop("middleware", []) + pipeline = self._get_chat_middleware_pipeline(call_middleware) # type: ignore[reportUnknownArgumentType] if not pipeline.has_middlewares: return super_get_response( # type: ignore[no-any-return] messages=messages, @@ -1134,12 +1153,25 @@ class AgentMiddlewareLayer: ) -> None: middleware_list = categorize_middleware(middleware) self.agent_middleware = middleware_list["agent"] + self._cached_agent_middleware_pipeline: AgentMiddlewarePipeline | None = None # Pass middleware to super so BaseAgent can store it for dynamic rebuild super().__init__(*args, middleware=middleware, **kwargs) # type: ignore[call-arg] # Note: We intentionally don't extend client's middleware lists here. # Chat and function middleware is passed to the chat client at runtime via kwargs # in AgentMiddlewareLayer.run(), where it's properly combined with run-level middleware. + def _get_agent_middleware_pipeline( + self, + middleware: Sequence[AgentMiddlewareTypes], + ) -> AgentMiddlewarePipeline: + if self._cached_agent_middleware_pipeline is not None and self._cached_agent_middleware_pipeline.matches( + middleware + ): + return self._cached_agent_middleware_pipeline + + self._cached_agent_middleware_pipeline = AgentMiddlewarePipeline(*middleware) + return self._cached_agent_middleware_pipeline + @overload def run( self, @@ -1210,7 +1242,7 @@ class AgentMiddlewareLayer: ) base_middleware_list = categorize_middleware(base_middleware) run_middleware_list = categorize_middleware(middleware) - pipeline = AgentMiddlewarePipeline(*base_middleware_list["agent"], *run_middleware_list["agent"]) + pipeline = self._get_agent_middleware_pipeline([*base_middleware_list["agent"], *run_middleware_list["agent"]]) # Combine base and run-level function/chat middleware for forwarding to chat client combined_function_chat_middleware = ( @@ -1392,7 +1424,7 @@ def categorize_middleware( all_middleware: list[Any] = [] for source in middleware_sources: if source: - if isinstance(source, list): + if isinstance(source, Sequence) and not isinstance(source, (str, bytes)): all_middleware.extend(source) # type: ignore else: all_middleware.append(source) diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index c9810771de..cf7384588f 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -63,7 +63,12 @@ if TYPE_CHECKING: from ._clients import SupportsChatGetResponse from ._compaction import CompactionStrategy, TokenizerProtocol from ._mcp import MCPTool - from ._middleware import FunctionInvocationContext, FunctionMiddlewarePipeline, FunctionMiddlewareTypes + from ._middleware import ( + ChatAndFunctionMiddlewareTypes, + FunctionInvocationContext, + FunctionMiddlewarePipeline, + FunctionMiddlewareTypes, + ) from ._sessions import AgentSession from ._types import ( ChatOptions, @@ -2024,18 +2029,37 @@ class FunctionInvocationLayer(Generic[OptionsCoT]): def __init__( self, *, - function_middleware: Sequence[FunctionMiddlewareTypes] | None = None, + middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, **kwargs: Any, ) -> None: - self.function_middleware: list[FunctionMiddlewareTypes] = ( - list(function_middleware) if function_middleware else [] - ) + from ._middleware import categorize_middleware + + middleware_list = categorize_middleware(middleware) + self.function_middleware: list[FunctionMiddlewareTypes] = list(middleware_list["function"]) + self._cached_function_middleware_pipeline: FunctionMiddlewarePipeline | None = None self.function_invocation_configuration = normalize_function_invocation_configuration( function_invocation_configuration ) + if (chat_middleware := (middleware_list["chat"] or None)) is not None: + kwargs["middleware"] = chat_middleware super().__init__(**kwargs) + def _get_function_middleware_pipeline( + self, + middleware: Sequence[FunctionMiddlewareTypes], + ) -> FunctionMiddlewarePipeline: + from ._middleware import FunctionMiddlewarePipeline + + effective_middleware = [*self.function_middleware, *middleware] + if self._cached_function_middleware_pipeline is not None and self._cached_function_middleware_pipeline.matches( + effective_middleware + ): + return self._cached_function_middleware_pipeline + + self._cached_function_middleware_pipeline = FunctionMiddlewarePipeline(*effective_middleware) + return self._cached_function_middleware_pipeline + @overload def get_response( self, @@ -2043,6 +2067,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]): *, stream: Literal[False] = ..., options: ChatOptions[ResponseModelBoundT], + middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, function_invocation_kwargs: Mapping[str, Any] | None = None, @@ -2057,6 +2082,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]): *, stream: Literal[False] = ..., options: OptionsCoT | ChatOptions[None] | None = None, + middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, function_invocation_kwargs: Mapping[str, Any] | None = None, @@ -2071,6 +2097,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]): *, stream: Literal[True], options: OptionsCoT | ChatOptions[Any] | None = None, + middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, function_invocation_kwargs: Mapping[str, Any] | None = None, @@ -2084,14 +2111,14 @@ class FunctionInvocationLayer(Generic[OptionsCoT]): *, stream: bool = False, options: OptionsCoT | ChatOptions[Any] | None = None, - function_middleware: Sequence[FunctionMiddlewareTypes] | None = None, + middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, function_invocation_kwargs: Mapping[str, Any] | None = None, client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: - from ._middleware import FunctionMiddlewarePipeline + from ._middleware import categorize_middleware from ._types import ( ChatResponse, ChatResponseUpdate, @@ -2109,16 +2136,21 @@ class FunctionInvocationLayer(Generic[OptionsCoT]): ) effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {} - effective_function_middleware = function_middleware - if effective_function_middleware is None: - middleware_from_client_kwargs = effective_client_kwargs.pop("function_middleware", None) - if middleware_from_client_kwargs is not None: - effective_function_middleware = cast(Sequence[Any], middleware_from_client_kwargs) + if middleware is not None: + existing = effective_client_kwargs.get("middleware", []) + effective_client_kwargs["middleware"] = [ + *( + existing + if isinstance(existing, Sequence) and not isinstance(existing, (str, bytes)) + else [existing] + ), + *middleware, + ] + runtime_middleware = categorize_middleware(effective_client_kwargs.pop("middleware", [])) - # ChatMiddleware adds this kwarg - function_middleware_pipeline = FunctionMiddlewarePipeline( - *(self.function_middleware), *(effective_function_middleware or []) - ) + function_middleware_pipeline = self._get_function_middleware_pipeline(runtime_middleware["function"]) + if runtime_middleware["chat"]: + effective_client_kwargs["middleware"] = runtime_middleware["chat"] max_errors = self.function_invocation_configuration.get( "max_consecutive_errors_per_request", DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST ) diff --git a/python/packages/core/agent_framework/_workflows/_viz.py b/python/packages/core/agent_framework/_workflows/_viz.py index 0fcf8af32d..54015b066c 100644 --- a/python/packages/core/agent_framework/_workflows/_viz.py +++ b/python/packages/core/agent_framework/_workflows/_viz.py @@ -109,7 +109,7 @@ class WorkflowViz: # Create a temporary graphviz Source object dot_content = self.to_digraph(include_internal_executors=include_internal_executors) - source = graphviz.Source(dot_content) + source = graphviz.Source(dot_content) # type: ignore[reportUnknownVariableType] try: if filename: @@ -131,7 +131,7 @@ class WorkflowViz: source.render(base_name, format=format, cleanup=True) # type: ignore return f"{base_name}.{format}" - except graphviz.backend.execute.ExecutableNotFound as e: + except graphviz.backend.execute.ExecutableNotFound as e: # type: ignore raise ImportError( "The graphviz executables are not found. The graphviz Python package is installed, but the " "graphviz executables (dot, neato, etc.) are not available on your system's PATH. " diff --git a/python/packages/core/agent_framework/azure/_chat_client.py b/python/packages/core/agent_framework/azure/_chat_client.py index 2ae21d124c..ef598ebe21 100644 --- a/python/packages/core/agent_framework/azure/_chat_client.py +++ b/python/packages/core/agent_framework/azure/_chat_client.py @@ -152,8 +152,8 @@ AzureOpenAIChatClientT = TypeVar("AzureOpenAIChatClientT", bound="AzureOpenAICha class AzureOpenAIChatClient( # type: ignore[misc] AzureOpenAIConfigMixin, - ChatMiddlewareLayer[AzureOpenAIChatOptionsT], FunctionInvocationLayer[AzureOpenAIChatOptionsT], + ChatMiddlewareLayer[AzureOpenAIChatOptionsT], ChatTelemetryLayer[AzureOpenAIChatOptionsT], RawOpenAIChatClient[AzureOpenAIChatOptionsT], Generic[AzureOpenAIChatOptionsT], diff --git a/python/packages/core/agent_framework/azure/_responses_client.py b/python/packages/core/agent_framework/azure/_responses_client.py index 192576bd04..8387e49591 100644 --- a/python/packages/core/agent_framework/azure/_responses_client.py +++ b/python/packages/core/agent_framework/azure/_responses_client.py @@ -51,8 +51,8 @@ AzureOpenAIResponsesOptionsT = TypeVar( class AzureOpenAIResponsesClient( # type: ignore[misc] AzureOpenAIConfigMixin, - ChatMiddlewareLayer[AzureOpenAIResponsesOptionsT], FunctionInvocationLayer[AzureOpenAIResponsesOptionsT], + ChatMiddlewareLayer[AzureOpenAIResponsesOptionsT], ChatTelemetryLayer[AzureOpenAIResponsesOptionsT], RawOpenAIResponsesClient[AzureOpenAIResponsesOptionsT], Generic[AzureOpenAIResponsesOptionsT], diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index dcabaae8fc..a0cbd6a1a0 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -362,11 +362,15 @@ def _create_otlp_exporters( if protocol == "grpc": # Import all gRPC exporters try: - from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter as GRPCLogExporter - from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import ( - OTLPMetricExporter as GRPCMetricExporter, + from opentelemetry.exporter.otlp.proto.grpc._log_exporter import ( # type: ignore[reportMissingImports] + OTLPLogExporter as GRPCLogExporter, # type: ignore[reportUnknownVariableType] + ) + from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import ( # type: ignore[reportMissingImports] + OTLPMetricExporter as GRPCMetricExporter, # type: ignore[reportUnknownVariableType] + ) + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( # type: ignore[reportMissingImports] + OTLPSpanExporter as GRPCSpanExporter, # type: ignore[reportUnknownVariableType] ) - from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter except ImportError as exc: raise ImportError( "opentelemetry-exporter-otlp-proto-grpc is required for OTLP gRPC exporters. " @@ -375,21 +379,21 @@ def _create_otlp_exporters( if actual_logs_endpoint: exporters.append( - GRPCLogExporter( + GRPCLogExporter( # type: ignore[reportUnknownArgumentType] endpoint=actual_logs_endpoint, headers=actual_logs_headers if actual_logs_headers else None, ) ) if actual_traces_endpoint: exporters.append( - GRPCSpanExporter( + GRPCSpanExporter( # type: ignore[reportUnknownArgumentType] endpoint=actual_traces_endpoint, headers=actual_traces_headers if actual_traces_headers else None, ) ) if actual_metrics_endpoint: exporters.append( - GRPCMetricExporter( + GRPCMetricExporter( # type: ignore[reportUnknownArgumentType] endpoint=actual_metrics_endpoint, headers=actual_metrics_headers if actual_metrics_headers else None, ) diff --git a/python/packages/core/agent_framework/openai/_assistants_client.py b/python/packages/core/agent_framework/openai/_assistants_client.py index b1d5e8795c..9179fb4a8c 100644 --- a/python/packages/core/agent_framework/openai/_assistants_client.py +++ b/python/packages/core/agent_framework/openai/_assistants_client.py @@ -210,8 +210,8 @@ OpenAIAssistantsOptionsT = TypeVar( class OpenAIAssistantsClient( # type: ignore[misc] OpenAIConfigMixin, - ChatMiddlewareLayer[OpenAIAssistantsOptionsT], FunctionInvocationLayer[OpenAIAssistantsOptionsT], + ChatMiddlewareLayer[OpenAIAssistantsOptionsT], ChatTelemetryLayer[OpenAIAssistantsOptionsT], BaseChatClient[OpenAIAssistantsOptionsT], Generic[OpenAIAssistantsOptionsT], diff --git a/python/packages/core/agent_framework/openai/_chat_client.py b/python/packages/core/agent_framework/openai/_chat_client.py index bb7362a9bb..a77d44d933 100644 --- a/python/packages/core/agent_framework/openai/_chat_client.py +++ b/python/packages/core/agent_framework/openai/_chat_client.py @@ -31,7 +31,7 @@ from pydantic import BaseModel from .._clients import BaseChatClient from .._docstrings import apply_layered_docstring -from .._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer, FunctionMiddlewareTypes +from .._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer from .._settings import load_settings from .._tools import ( FunctionInvocationConfiguration, @@ -156,9 +156,9 @@ class RawOpenAIChatClient( # type: ignore[misc] you should consider which additional layers to apply. There is a defined ordering that you should follow: - 1. **ChatMiddlewareLayer** - Should be applied first as it also prepares function middleware - 2. **FunctionInvocationLayer** - Handles tool/function calling loop - 3. **ChatTelemetryLayer** - Must be inside the function calling loop for correct per-call telemetry + 1. **FunctionInvocationLayer** - Owns the tool/function calling loop and routes function middleware + 2. **ChatMiddlewareLayer** - Applies chat middleware per model call and stays outside telemetry + 3. **ChatTelemetryLayer** - Must stay inside chat middleware for correct per-call telemetry Use ``OpenAIChatClient`` instead for a fully-featured client with all layers applied. """ @@ -776,8 +776,8 @@ class RawOpenAIChatClient( # type: ignore[misc] class OpenAIChatClient( # type: ignore[misc] OpenAIConfigMixin, - ChatMiddlewareLayer[OpenAIChatOptionsT], FunctionInvocationLayer[OpenAIChatOptionsT], + ChatMiddlewareLayer[OpenAIChatOptionsT], ChatTelemetryLayer[OpenAIChatOptionsT], RawOpenAIChatClient[OpenAIChatOptionsT], Generic[OpenAIChatOptionsT], @@ -791,7 +791,6 @@ class OpenAIChatClient( # type: ignore[misc] *, stream: Literal[False] = ..., options: ChatOptions[ResponseModelBoundT], - function_middleware: Sequence[FunctionMiddlewareTypes] | None = None, function_invocation_kwargs: Mapping[str, Any] | None = None, client_kwargs: Mapping[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, @@ -805,7 +804,6 @@ class OpenAIChatClient( # type: ignore[misc] *, stream: Literal[False] = ..., options: OpenAIChatOptionsT | ChatOptions[None] | None = None, - function_middleware: Sequence[FunctionMiddlewareTypes] | None = None, function_invocation_kwargs: Mapping[str, Any] | None = None, client_kwargs: Mapping[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, @@ -819,7 +817,6 @@ class OpenAIChatClient( # type: ignore[misc] *, stream: Literal[True], options: OpenAIChatOptionsT | ChatOptions[Any] | None = None, - function_middleware: Sequence[FunctionMiddlewareTypes] | None = None, function_invocation_kwargs: Mapping[str, Any] | None = None, client_kwargs: Mapping[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, @@ -833,7 +830,6 @@ class OpenAIChatClient( # type: ignore[misc] *, stream: bool = False, options: OpenAIChatOptionsT | ChatOptions[Any] | None = None, - function_middleware: Sequence[FunctionMiddlewareTypes] | None = None, function_invocation_kwargs: Mapping[str, Any] | None = None, client_kwargs: Mapping[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, @@ -844,14 +840,15 @@ class OpenAIChatClient( # type: ignore[misc] "Callable[..., Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]]", super().get_response, # type: ignore[misc] ) + effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {} + if middleware is not None: + effective_client_kwargs["middleware"] = middleware return super_get_response( # type: ignore[no-any-return] messages=messages, stream=stream, options=options, - function_middleware=function_middleware, function_invocation_kwargs=function_invocation_kwargs, - client_kwargs=client_kwargs, - middleware=middleware, + client_kwargs=effective_client_kwargs, **kwargs, ) @@ -967,10 +964,6 @@ def _apply_openai_chat_client_docstrings() -> None: OpenAIChatClient.get_response, RawOpenAIChatClient.get_response, extra_keyword_args={ - "function_middleware": """ - Optional per-call function middleware. - When omitted, middleware configured on the client or forwarded from higher layers is used. - """, "middleware": """ Optional per-call chat and function middleware. This is merged with any middleware configured on the client for the current request. diff --git a/python/packages/core/agent_framework/openai/_responses_client.py b/python/packages/core/agent_framework/openai/_responses_client.py index 0769c3f1f9..0c57dffb39 100644 --- a/python/packages/core/agent_framework/openai/_responses_client.py +++ b/python/packages/core/agent_framework/openai/_responses_client.py @@ -249,9 +249,9 @@ class RawOpenAIResponsesClient( # type: ignore[misc] you should consider which additional layers to apply. There is a defined ordering that you should follow: - 1. **ChatMiddlewareLayer** - Should be applied first as it also prepares function middleware - 2. **FunctionInvocationLayer** - Handles tool/function calling loop - 3. **ChatTelemetryLayer** - Must be inside the function calling loop for correct per-call telemetry + 1. **FunctionInvocationLayer** - Owns the tool/function calling loop and routes function middleware + 2. **ChatMiddlewareLayer** - Applies chat middleware per model call and stays outside telemetry + 3. **ChatTelemetryLayer** - Must stay inside chat middleware for correct per-call telemetry Use ``OpenAIResponsesClient`` instead for a fully-featured client with all layers applied. """ @@ -2259,8 +2259,8 @@ class RawOpenAIResponsesClient( # type: ignore[misc] class OpenAIResponsesClient( # type: ignore[misc] OpenAIConfigMixin, - ChatMiddlewareLayer[OpenAIResponsesOptionsT], FunctionInvocationLayer[OpenAIResponsesOptionsT], + ChatMiddlewareLayer[OpenAIResponsesOptionsT], ChatTelemetryLayer[OpenAIResponsesOptionsT], RawOpenAIResponsesClient[OpenAIResponsesOptionsT], Generic[OpenAIResponsesOptionsT], diff --git a/python/packages/core/tests/core/conftest.py b/python/packages/core/tests/core/conftest.py index 2d1eec2d9a..57c0cf5217 100644 --- a/python/packages/core/tests/core/conftest.py +++ b/python/packages/core/tests/core/conftest.py @@ -128,8 +128,8 @@ class MockChatClient: class MockBaseChatClient( - ChatMiddlewareLayer[OptionsCoT], FunctionInvocationLayer[OptionsCoT], + ChatMiddlewareLayer[OptionsCoT], ChatTelemetryLayer[OptionsCoT], BaseChatClient[OptionsCoT], Generic[OptionsCoT], @@ -137,7 +137,7 @@ class MockBaseChatClient( """Mock implementation of a full-featured ChatClient.""" def __init__(self, **kwargs: Any): - super().__init__(function_middleware=[], **kwargs) + super().__init__(middleware=[], **kwargs) self.run_responses: list[ChatResponse] = [] self.streaming_responses: list[list[ChatResponseUpdate]] = [] self.call_count: int = 0 diff --git a/python/packages/core/tests/core/test_clients.py b/python/packages/core/tests/core/test_clients.py index 7e150c47c6..258a31d73b 100644 --- a/python/packages/core/tests/core/test_clients.py +++ b/python/packages/core/tests/core/test_clients.py @@ -74,8 +74,8 @@ def test_openai_chat_client_get_response_docstring_surfaces_layered_runtime_docs assert docstring is not None assert "Get a response from a chat client." in docstring assert "function_invocation_kwargs" in docstring - assert "function_middleware: Optional per-call function middleware." in docstring assert "middleware: Optional per-call chat and function middleware." in docstring + assert "function_middleware: Optional per-call function middleware." not in docstring def test_openai_chat_client_get_response_is_defined_on_openai_class() -> None: @@ -84,7 +84,6 @@ def test_openai_chat_client_get_response_is_defined_on_openai_class() -> None: signature = inspect.signature(OpenAIChatClient.get_response) assert OpenAIChatClient.get_response.__qualname__ == "OpenAIChatClient.get_response" - assert "function_middleware" in signature.parameters assert "middleware" in signature.parameters diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index 3c61040289..d9659837a8 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -3226,7 +3226,7 @@ async def test_terminate_loop_single_function_call(chat_client_base: SupportsCha response = await chat_client_base.get_response( "hello", options={"tool_choice": "auto", "tools": [ai_func]}, - middleware=[TerminateLoopMiddleware()], + client_kwargs={"middleware": [TerminateLoopMiddleware()]}, ) # Function should NOT have been executed - middleware intercepted it @@ -3292,7 +3292,7 @@ async def test_terminate_loop_multiple_function_calls_one_terminates(chat_client response = await chat_client_base.get_response( "hello", options={"tool_choice": "auto", "tools": [normal_func, terminating_func]}, - middleware=[SelectiveTerminateMiddleware()], + client_kwargs={"middleware": [SelectiveTerminateMiddleware()]}, ) # normal_function should have executed (middleware calls next_handler) @@ -3345,7 +3345,7 @@ async def test_terminate_loop_streaming_single_function_call(chat_client_base: S async for update in chat_client_base.get_response( "hello", options={"tool_choice": "auto", "tools": [ai_func]}, - middleware=[TerminateLoopMiddleware()], + client_kwargs={"middleware": [TerminateLoopMiddleware()]}, stream=True, ): updates.append(update) @@ -3389,12 +3389,12 @@ async def test_conversation_id_updated_in_options_between_tool_iterations(): conversation_ids_received: list[str | None] = [] class TrackingChatClient( - ChatMiddlewareLayer, FunctionInvocationLayer, + ChatMiddlewareLayer, BaseChatClient, ): def __init__(self) -> None: - super().__init__(function_middleware=[]) + super().__init__(middleware=[]) self.run_responses: list[ChatResponse] = [] self.streaming_responses: list[list[ChatResponseUpdate]] = [] self.call_count: int = 0 diff --git a/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py b/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py index 160ea0fcc4..11a738a0b9 100644 --- a/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py +++ b/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py @@ -84,8 +84,8 @@ class _MockBaseChatClient(BaseChatClient[Any]): class FunctionInvokingMockClient( - ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], + ChatMiddlewareLayer[Any], ChatTelemetryLayer[Any], _MockBaseChatClient, ): diff --git a/python/packages/core/tests/core/test_middleware.py b/python/packages/core/tests/core/test_middleware.py index 6c559c40d4..0026cbf98f 100644 --- a/python/packages/core/tests/core/test_middleware.py +++ b/python/packages/core/tests/core/test_middleware.py @@ -28,6 +28,7 @@ from agent_framework._middleware import ( FunctionMiddleware, FunctionMiddlewarePipeline, MiddlewareTermination, + categorize_middleware, ) from agent_framework._tools import FunctionTool @@ -1681,3 +1682,49 @@ def mock_chat_client() -> Any: client = MagicMock(spec=SupportsChatGetResponse) client.service_url = MagicMock(return_value="mock://test") return client + + +class TestCategorizeMiddleware: + """Test cases for categorize_middleware.""" + + def test_categorize_middleware_with_tuple(self) -> None: + """Test that tuple middleware sources are unpacked, not appended as a single item.""" + chat_mw = TestChatMiddleware() + function_mw = TestFunctionMiddleware() + agent_mw = TestAgentMiddleware() + result = categorize_middleware((chat_mw, function_mw, agent_mw)) + assert result["chat"] == [chat_mw] + assert result["function"] == [function_mw] + assert result["agent"] == [agent_mw] + + def test_categorize_middleware_with_list(self) -> None: + """Test that list middleware sources are unpacked correctly.""" + chat_mw = TestChatMiddleware() + function_mw = TestFunctionMiddleware() + result = categorize_middleware([chat_mw, function_mw]) + assert result["chat"] == [chat_mw] + assert result["function"] == [function_mw] + assert result["agent"] == [] + + def test_categorize_middleware_with_none(self) -> None: + """Test that None middleware sources are handled.""" + result = categorize_middleware(None) + assert result["chat"] == [] + assert result["function"] == [] + assert result["agent"] == [] + + def test_categorize_middleware_with_single_item(self) -> None: + """Test that a single unwrapped middleware item is appended correctly.""" + chat_mw = TestChatMiddleware() + result = categorize_middleware(chat_mw) + assert result["chat"] == [chat_mw] + assert result["function"] == [] + assert result["agent"] == [] + + def test_categorize_middleware_with_string_does_not_decompose(self) -> None: + """Test that a string is not decomposed character-by-character.""" + result = categorize_middleware("not_a_middleware") + # String should be treated as a single item, not decomposed into characters + total_items = len(result["chat"]) + len(result["function"]) + len(result["agent"]) + assert total_items == 1 + assert result["agent"] == ["not_a_middleware"] diff --git a/python/packages/core/tests/core/test_middleware_with_agent.py b/python/packages/core/tests/core/test_middleware_with_agent.py index bfe5ec1293..6470a8202e 100644 --- a/python/packages/core/tests/core/test_middleware_with_agent.py +++ b/python/packages/core/tests/core/test_middleware_with_agent.py @@ -697,6 +697,26 @@ class TestChatAgentFunctionMiddlewareWithTools: assert function_calls[0].name == "sample_tool_function" assert function_results[0].call_id == function_calls[0].call_id + def test_agent_middleware_pipeline_cache_reuses_matching_middleware(self) -> None: + """Test that identical agent middleware sets reuse the cached pipeline.""" + + @agent_middleware + async def first_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() + + @agent_middleware + async def second_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() + + agent = Agent(client=MockBaseChatClient()) + + first_pipeline = agent._get_agent_middleware_pipeline([first_middleware]) + second_pipeline = agent._get_agent_middleware_pipeline([first_middleware]) + third_pipeline = agent._get_agent_middleware_pipeline([second_middleware]) + + assert first_pipeline is second_pipeline + assert third_pipeline is not first_pipeline + async def test_function_middleware_can_access_and_override_custom_kwargs( self, chat_client_base: "MockBaseChatClient" ) -> None: @@ -1969,6 +1989,77 @@ class TestChatAgentChatMiddleware: "agent_middleware_after", ] + async def test_combined_middleware_with_tool_loop(self) -> None: + """Test Agent middleware ordering when tool calls trigger multiple chat rounds.""" + execution_order: list[str] = [] + chat_round = 0 + client = MockBaseChatClient() + client.run_responses = [ + ChatResponse( + messages=[ + Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_123", + name="sample_tool_function", + arguments='{"location": "Seattle"}', + ) + ], + ) + ] + ), + ChatResponse(messages=[Message(role="assistant", text="Final response")]), + ] + + async def tracking_agent_middleware( + context: AgentContext, + call_next: Callable[[], Awaitable[None]], + ) -> None: + execution_order.append("agent_middleware_before") + await call_next() + execution_order.append("agent_middleware_after") + + async def tracking_chat_middleware( + context: ChatContext, + call_next: Callable[[], Awaitable[None]], + ) -> None: + nonlocal chat_round + chat_round += 1 + execution_order.append(f"chat_middleware_before_{chat_round}") + await call_next() + execution_order.append(f"chat_middleware_after_{chat_round}") + + async def tracking_function_middleware( + context: FunctionInvocationContext, + call_next: Callable[[], Awaitable[None]], + ) -> None: + execution_order.append("function_middleware_before") + await call_next() + execution_order.append("function_middleware_after") + + agent = Agent( + client=client, + middleware=[tracking_chat_middleware, tracking_function_middleware, tracking_agent_middleware], + tools=[sample_tool_function], + ) + + response = await agent.run([Message(role="user", text="test")]) + + assert response is not None + assert client.call_count == 2 + assert response.messages[-1].text == "Final response" + assert execution_order == [ + "agent_middleware_before", + "chat_middleware_before_1", + "chat_middleware_after_1", + "function_middleware_before", + "function_middleware_after", + "chat_middleware_before_2", + "chat_middleware_after_2", + "agent_middleware_after", + ] + async def test_agent_middleware_can_access_and_override_custom_kwargs(self) -> None: """Test that agent middleware can access and override custom parameters like temperature.""" captured_kwargs: dict[str, Any] = {} diff --git a/python/packages/core/tests/core/test_middleware_with_chat.py b/python/packages/core/tests/core/test_middleware_with_chat.py index 62a168ccb0..5fa9d64031 100644 --- a/python/packages/core/tests/core/test_middleware_with_chat.py +++ b/python/packages/core/tests/core/test_middleware_with_chat.py @@ -274,7 +274,10 @@ class TestChatMiddleware: # First call with run-level middleware messages = [Message(role="user", text="first message")] - response1 = await chat_client_base.get_response(messages, middleware=[counting_middleware]) + response1 = await chat_client_base.get_response( + messages, + client_kwargs={"middleware": [counting_middleware]}, + ) assert response1 is not None assert execution_count["count"] == 1 @@ -286,7 +289,10 @@ class TestChatMiddleware: # Third call with run-level middleware again - should execute messages = [Message(role="user", text="third message")] - response3 = await chat_client_base.get_response(messages, middleware=[counting_middleware]) + response3 = await chat_client_base.get_response( + messages, + client_kwargs={"middleware": [counting_middleware]}, + ) assert response3 is not None assert execution_count["count"] == 2 # Should be 2 now @@ -335,6 +341,81 @@ class TestChatMiddleware: assert modified_kwargs["new_param"] == "added_by_middleware" assert modified_kwargs["custom_param"] == "test_value" # Should still be there + def test_chat_middleware_pipeline_cache_reuses_matching_middleware( + self, + chat_client_base: "MockBaseChatClient", + ) -> None: + """Test that identical chat middleware sets reuse the cached pipeline.""" + + @chat_middleware + async def first_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() + + @chat_middleware + async def second_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() + + first_pipeline = chat_client_base._get_chat_middleware_pipeline([first_middleware]) + second_pipeline = chat_client_base._get_chat_middleware_pipeline([first_middleware]) + third_pipeline = chat_client_base._get_chat_middleware_pipeline([second_middleware]) + + assert first_pipeline is second_pipeline + assert third_pipeline is not first_pipeline + + def test_chat_middleware_pipeline_cache_includes_base_middleware( + self, + chat_client_base: "MockBaseChatClient", + ) -> None: + """Test that chat middleware cache key includes base middleware to prevent incorrect reuse.""" + + @chat_middleware + async def base_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() + + @chat_middleware + async def runtime_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() + + # Without base middleware + pipeline_no_base = chat_client_base._get_chat_middleware_pipeline([runtime_middleware]) + + # With base middleware + chat_client_base.chat_middleware = [base_middleware] + pipeline_with_base = chat_client_base._get_chat_middleware_pipeline([runtime_middleware]) + + assert pipeline_with_base is not pipeline_no_base + + def test_function_middleware_pipeline_cache_reuses_matching_middleware( + self, + chat_client_base: "MockBaseChatClient", + ) -> None: + """Test that identical function middleware sets reuse the cached pipeline.""" + + @function_middleware + async def base_middleware(context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]) -> None: + await call_next() + + @function_middleware + async def first_runtime_middleware( + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] + ) -> None: + await call_next() + + @function_middleware + async def second_runtime_middleware( + context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]] + ) -> None: + await call_next() + + chat_client_base.function_middleware = [base_middleware] + + first_pipeline = chat_client_base._get_function_middleware_pipeline([first_runtime_middleware]) + second_pipeline = chat_client_base._get_function_middleware_pipeline([first_runtime_middleware]) + third_pipeline = chat_client_base._get_function_middleware_pipeline([second_runtime_middleware]) + + assert first_pipeline is second_pipeline + assert third_pipeline is not first_pipeline + async def test_function_middleware_registration_on_chat_client( self, chat_client_base: "MockBaseChatClient" ) -> None: @@ -450,7 +531,9 @@ class TestChatMiddleware: # Execute the chat client directly with run-level middleware and tools messages = [Message(role="user", text="What's the weather in New York?")] response = await client.get_response( - messages, options={"tools": [sample_tool_wrapped]}, middleware=[run_level_function_middleware] + messages, + options={"tools": [sample_tool_wrapped]}, + client_kwargs={"middleware": [run_level_function_middleware]}, ) # Verify response @@ -463,3 +546,156 @@ class TestChatMiddleware: "run_level_function_middleware_before", "run_level_function_middleware_after", ] + + async def test_run_level_chat_and_function_middleware_split_per_function_loop_round(self) -> None: + """Test mixed run-level middleware is split so chat middleware runs per model call.""" + execution_order: list[str] = [] + chat_round = 0 + + @chat_middleware + async def run_level_chat_middleware( + context: ChatContext, + call_next: Callable[[], Awaitable[None]], + ) -> None: + nonlocal chat_round + chat_round += 1 + execution_order.append(f"chat_middleware_before_{chat_round}") + await call_next() + execution_order.append(f"chat_middleware_after_{chat_round}") + + @function_middleware + async def run_level_function_middleware( + context: FunctionInvocationContext, + call_next: Callable[[], Awaitable[None]], + ) -> None: + execution_order.append("function_middleware_before") + await call_next() + execution_order.append("function_middleware_after") + + def sample_tool(location: str) -> str: + """Get weather for a location.""" + return f"Weather in {location}: sunny" + + sample_tool_wrapped = FunctionTool( + func=sample_tool, + name="sample_tool", + description="Get weather for a location", + approval_mode="never_require", + ) + + client = MockBaseChatClient() + client.run_responses = [ + ChatResponse( + messages=[ + Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_3", + name="sample_tool", + arguments={"location": "Seattle"}, + ) + ], + ) + ] + ), + ChatResponse(messages=[Message(role="assistant", text="Based on the weather data, it's sunny!")]), + ] + + response = await client.get_response( + [Message(role="user", text="What's the weather in Seattle?")], + options={"tools": [sample_tool_wrapped]}, + client_kwargs={"middleware": [run_level_chat_middleware, run_level_function_middleware]}, + ) + + assert response is not None + assert client.call_count == 2 + assert response.messages[-1].text == "Based on the weather data, it's sunny!" + assert execution_order == [ + "chat_middleware_before_1", + "chat_middleware_after_1", + "function_middleware_before", + "function_middleware_after", + "chat_middleware_before_2", + "chat_middleware_after_2", + ] + + async def test_run_level_chat_and_function_middleware_split_per_function_loop_round_streaming(self) -> None: + """Test mixed run-level middleware is split so chat middleware runs per model call in streaming mode.""" + execution_order: list[str] = [] + chat_round = 0 + + @chat_middleware + async def run_level_chat_middleware( + context: ChatContext, + call_next: Callable[[], Awaitable[None]], + ) -> None: + nonlocal chat_round + chat_round += 1 + execution_order.append(f"chat_middleware_before_{chat_round}") + await call_next() + execution_order.append(f"chat_middleware_after_{chat_round}") + + @function_middleware + async def run_level_function_middleware( + context: FunctionInvocationContext, + call_next: Callable[[], Awaitable[None]], + ) -> None: + execution_order.append("function_middleware_before") + await call_next() + execution_order.append("function_middleware_after") + + def sample_tool(location: str) -> str: + """Get weather for a location.""" + return f"Weather in {location}: sunny" + + sample_tool_wrapped = FunctionTool( + func=sample_tool, + name="sample_tool", + description="Get weather for a location", + approval_mode="never_require", + ) + + client = MockBaseChatClient() + client.streaming_responses = [ + [ + ChatResponseUpdate( + contents=[ + Content.from_function_call( + call_id="call_3", + name="sample_tool", + arguments='{"location": "Seattle"}', + ) + ], + role="assistant", + finish_reason="tool_calls", + ), + ], + [ + ChatResponseUpdate( + contents=[Content.from_text("Based on the weather data, it's sunny!")], + role="assistant", + finish_reason="stop", + ), + ], + ] + + updates: list[ChatResponseUpdate] = [] + async for update in client.get_response( + [Message(role="user", text="What's the weather in Seattle?")], + options={"tools": [sample_tool_wrapped]}, + client_kwargs={"middleware": [run_level_chat_middleware, run_level_function_middleware]}, + stream=True, + ): + updates.append(update) + + assert client.call_count == 2 + assert len(updates) > 0 + assert execution_order == [ + "chat_middleware_before_1", + "chat_middleware_after_1", + "function_middleware_before", + "function_middleware_after", + "chat_middleware_before_2", + "chat_middleware_after_2", + ] diff --git a/python/packages/core/tests/core/test_observability.py b/python/packages/core/tests/core/test_observability.py index 367e32bf92..7982985b94 100644 --- a/python/packages/core/tests/core/test_observability.py +++ b/python/packages/core/tests/core/test_observability.py @@ -2437,7 +2437,7 @@ def test_capture_response(span_exporter: InMemorySpanExporter): async def test_layer_ordering_span_sequence_with_function_calling(span_exporter: InMemorySpanExporter): """Test that with correct layer ordering, spans appear in the expected sequence. - When using the correct layer ordering (ChatMiddlewareLayer, FunctionInvocationLayer, + When using the correct layer ordering (FunctionInvocationLayer, ChatMiddlewareLayer, ChatTelemetryLayer, BaseChatClient), the spans should appear in this order: 1. First 'chat' span (initial LLM call that returns function call) 2. 'execute_tool' span (function invocation) @@ -2454,11 +2454,11 @@ async def test_layer_ordering_span_sequence_with_function_calling(span_exporter: def get_weather(location: str) -> str: return f"The weather in {location} is sunny." - # Correct layer ordering: FunctionInvocationLayer BEFORE ChatTelemetryLayer - # This ensures each inner LLM call gets its own telemetry span + # Correct layer ordering: FunctionInvocationLayer BEFORE ChatMiddlewareLayer BEFORE ChatTelemetryLayer + # This ensures each inner LLM call traverses chat middleware and still gets its own telemetry span class MockChatClientWithLayers( - ChatMiddlewareLayer, FunctionInvocationLayer, + ChatMiddlewareLayer, ChatTelemetryLayer, BaseChatClient, ): diff --git a/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py b/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py index 4c1e64cd7c..2566d031aa 100644 --- a/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py +++ b/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py @@ -130,8 +130,8 @@ class FoundryLocalSettings(TypedDict, total=False): class FoundryLocalClient( - ChatMiddlewareLayer[FoundryLocalChatOptionsT], FunctionInvocationLayer[FoundryLocalChatOptionsT], + ChatMiddlewareLayer[FoundryLocalChatOptionsT], ChatTelemetryLayer[FoundryLocalChatOptionsT], RawOpenAIChatClient[FoundryLocalChatOptionsT], Generic[FoundryLocalChatOptionsT], diff --git a/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py b/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py index 07b1945882..266ae8a107 100644 --- a/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py +++ b/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py @@ -273,7 +273,7 @@ def _load_gaia_local(repo_dir: Path, wanted_levels: list[int] | None = None, max for p in parquet_files: try: - import pyarrow.parquet as pq + import pyarrow.parquet as pq # type: ignore[reportMissingImports] pq_any = cast(Any, pq) table: Any = pq_any.read_table(p) diff --git a/python/packages/lab/lightning/agent_framework_lab_lightning/__init__.py b/python/packages/lab/lightning/agent_framework_lab_lightning/__init__.py index 9526498cc2..3da1121910 100644 --- a/python/packages/lab/lightning/agent_framework_lab_lightning/__init__.py +++ b/python/packages/lab/lightning/agent_framework_lab_lightning/__init__.py @@ -7,8 +7,8 @@ from __future__ import annotations import importlib.metadata from agent_framework.observability import enable_instrumentation -from agentlightning.tracer import ( - AgentOpsTracer, # pyright: ignore[reportMissingImports] # type: ignore[import-not-found] +from agentlightning.tracer import ( # type: ignore[reportMissingImports] + AgentOpsTracer, # type: ignore[reportMissingImports, import-not-found] ) try: diff --git a/python/packages/ollama/agent_framework_ollama/_chat_client.py b/python/packages/ollama/agent_framework_ollama/_chat_client.py index b931c89499..0c7f232797 100644 --- a/python/packages/ollama/agent_framework_ollama/_chat_client.py +++ b/python/packages/ollama/agent_framework_ollama/_chat_client.py @@ -285,8 +285,8 @@ logger = logging.getLogger("agent_framework.ollama") class OllamaChatClient( - ChatMiddlewareLayer[OllamaChatOptionsT], FunctionInvocationLayer[OllamaChatOptionsT], + ChatMiddlewareLayer[OllamaChatOptionsT], ChatTelemetryLayer[OllamaChatOptionsT], BaseChatClient[OllamaChatOptionsT], ): diff --git a/python/packages/orchestrations/tests/test_handoff.py b/python/packages/orchestrations/tests/test_handoff.py index 43c2f9153a..5c594ed537 100644 --- a/python/packages/orchestrations/tests/test_handoff.py +++ b/python/packages/orchestrations/tests/test_handoff.py @@ -33,7 +33,7 @@ from agent_framework_orchestrations._handoff import ( from agent_framework_orchestrations._orchestrator_helpers import clean_conversation_for_handoff -class MockChatClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]): +class MockChatClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]): """Mock chat client for testing handoff workflows.""" def __init__( @@ -134,7 +134,7 @@ class MockHandoffAgent(Agent): super().__init__(client=MockChatClient(name=name, handoff_to=handoff_to), name=name, id=name) -class ContextAwareRefundClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]): +class ContextAwareRefundClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]): """Mock client that expects prior user context to remain available on resume.""" def __init__(self) -> None: @@ -298,7 +298,7 @@ async def test_tool_approval_responses_are_not_replayed_from_history() -> None: execution_count += 1 return "ok" - class ApprovalReplayClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]): + class ApprovalReplayClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]): def __init__(self) -> None: ChatMiddlewareLayer.__init__(self) FunctionInvocationLayer.__init__(self) @@ -383,7 +383,7 @@ async def test_handoff_resume_preserves_approval_function_call_for_stateless_run def submit_refund() -> str: return "ok" - class StrictStatelessApprovalClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]): + class StrictStatelessApprovalClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]): def __init__(self) -> None: ChatMiddlewareLayer.__init__(self) FunctionInvocationLayer.__init__(self) @@ -475,7 +475,7 @@ async def test_handoff_resume_preserves_approval_function_call_for_stateless_run async def test_handoff_replay_serializes_handoff_function_results() -> None: """Returning to the same agent must not replay dict tool outputs.""" - class ReplaySafeHandoffClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]): + class ReplaySafeHandoffClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]): def __init__(self, name: str, handoff_sequence: list[str | None]) -> None: ChatMiddlewareLayer.__init__(self) FunctionInvocationLayer.__init__(self) @@ -550,7 +550,7 @@ async def test_handoff_resume_preserves_approved_tool_output_for_stateless_runs( def submit_refund() -> str: return "submitted" - class RefundReplayClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]): + class RefundReplayClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]): def __init__(self) -> None: ChatMiddlewareLayer.__init__(self) FunctionInvocationLayer.__init__(self) @@ -608,7 +608,7 @@ async def test_handoff_resume_preserves_approved_tool_output_for_stateless_runs( return _get() - class OrderReplayClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]): + class OrderReplayClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]): def __init__(self) -> None: ChatMiddlewareLayer.__init__(self) FunctionInvocationLayer.__init__(self) @@ -907,7 +907,7 @@ async def test_handoff_async_termination_condition() -> None: async def test_handoff_terminates_without_request_info_when_latest_response_meets_condition() -> None: """Termination triggered by the latest assistant response should not emit request_info.""" - class FinalizingClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]): + class FinalizingClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]): def __init__(self) -> None: ChatMiddlewareLayer.__init__(self) FunctionInvocationLayer.__init__(self) diff --git a/python/samples/02-agents/auto_retry.py b/python/samples/02-agents/auto_retry.py index 7c985bd0c1..0a5169ad3d 100644 --- a/python/samples/02-agents/auto_retry.py +++ b/python/samples/02-agents/auto_retry.py @@ -114,10 +114,11 @@ class RetryingAzureOpenAIChatClient(AzureOpenAIChatClient): class RateLimitRetryMiddleware(ChatMiddleware): - """Chat middleware that retries the full request pipeline on rate limit errors. + """Chat middleware that retries a single model-call pipeline on rate limit errors. Register this middleware on an agent (or at the run level) to automatically - retry any call_next() invocation that raises RateLimitError. + retry any chat-model call that raises RateLimitError. In tool-loop scenarios, + the middleware applies independently to each inner model call. """ def __init__(self, *, max_attempts: int = RETRY_ATTEMPTS) -> None: @@ -154,8 +155,9 @@ async def rate_limit_retry_middleware( """Function-based chat middleware that retries on rate limit errors. Wrap call_next() with a tenacity @retry decorator so any RateLimitError - raised during model inference triggers an automatic retry with exponential - back-off. + raised during a single model call triggers an automatic retry with exponential + back-off. In tool-loop scenarios, the middleware applies independently to + each inner model call. """ @retry( diff --git a/python/samples/02-agents/chat_client/custom_chat_client.py b/python/samples/02-agents/chat_client/custom_chat_client.py index 5adcf50d15..7a9aaa95f6 100644 --- a/python/samples/02-agents/chat_client/custom_chat_client.py +++ b/python/samples/02-agents/chat_client/custom_chat_client.py @@ -29,7 +29,10 @@ else: Custom Chat Client Implementation Example This sample demonstrates implementing a custom chat client and optionally composing -middleware, telemetry, and function invocation layers explicitly. +middleware, telemetry, and function invocation layers explicitly. The recommended +layer order is `FunctionInvocationLayer -> ChatMiddlewareLayer -> ChatTelemetryLayer` +so chat middleware runs within each tool-loop iteration while telemetry records +per-call spans without middleware latency. """ @@ -124,9 +127,9 @@ class EchoingChatClient(BaseChatClient[OptionsT]): class EchoingChatClientWithLayers( # type: ignore[misc] + FunctionInvocationLayer[OptionsT], ChatMiddlewareLayer[OptionsT], ChatTelemetryLayer[OptionsT], - FunctionInvocationLayer[OptionsT], EchoingChatClient, ): """Echoing chat client that explicitly composes middleware, telemetry, and function layers.""" diff --git a/python/samples/02-agents/middleware/README.md b/python/samples/02-agents/middleware/README.md new file mode 100644 index 0000000000..754f96e815 --- /dev/null +++ b/python/samples/02-agents/middleware/README.md @@ -0,0 +1,37 @@ +# Middleware samples + +This folder contains focused middleware samples for `Agent`, chat clients, tools, sessions, and runtime context behavior. + +## Files + +| File | Description | +|------|-------------| +| [`agent_and_run_level_middleware.py`](./agent_and_run_level_middleware.py) | Demonstrates combining agent-level and run-level middleware. | +| [`chat_middleware.py`](./chat_middleware.py) | Shows class-based and function-based chat middleware that can observe, modify, and override model calls. | +| [`class_based_middleware.py`](./class_based_middleware.py) | Shows class-based agent and function middleware. | +| [`decorator_middleware.py`](./decorator_middleware.py) | Demonstrates middleware registration with decorators. | +| [`exception_handling_with_middleware.py`](./exception_handling_with_middleware.py) | Shows how middleware can handle failures and recover cleanly. | +| [`function_based_middleware.py`](./function_based_middleware.py) | Shows function-based agent and function middleware. | +| [`middleware_termination.py`](./middleware_termination.py) | Demonstrates stopping a middleware pipeline early. | +| [`override_result_with_middleware.py`](./override_result_with_middleware.py) | Shows how middleware can replace the normal result. | +| [`runtime_context_delegation.py`](./runtime_context_delegation.py) | Demonstrates delegating work with runtime context data. | +| [`session_behavior_middleware.py`](./session_behavior_middleware.py) | Shows how middleware interacts with session-backed runs. | +| [`shared_state_middleware.py`](./shared_state_middleware.py) | Demonstrates sharing mutable state across middleware invocations. | +| [`usage_tracking_middleware.py`](./usage_tracking_middleware.py) | Demonstrates one chat middleware function that tracks per-call usage in non-streaming and streaming tool-loop runs. | + +## Running the usage tracking sample + +The new usage tracking sample uses `OpenAIResponsesClient`, so set the usual OpenAI responses environment variables first: + +```bash +export OPENAI_API_KEY="your-openai-api-key" +export OPENAI_RESPONSES_MODEL_ID="gpt-4.1-mini" +``` + +Then run: + +```bash +uv run samples/02-agents/middleware/usage_tracking_middleware.py +``` + +The sample forces a tool call so you can see middleware output for each inner model call in both non-streaming and streaming modes. diff --git a/python/samples/02-agents/middleware/agent_and_run_level_middleware.py b/python/samples/02-agents/middleware/agent_and_run_level_middleware.py index 55ccce3507..158d90daee 100644 --- a/python/samples/02-agents/middleware/agent_and_run_level_middleware.py +++ b/python/samples/02-agents/middleware/agent_and_run_level_middleware.py @@ -51,10 +51,10 @@ Agent Middleware Execution Order: - Run middleware wraps only the agent for that specific run - Each middleware can modify the context before AND after calling next() - Note: Function and chat middleware (e.g., ``function_logging_middleware``) execute - during tool invocation *inside* the agent execution, not in the outer agent-middleware - chain shown above. They follow the same ordering principle: agent-level function/chat - middleware runs before run-level function/chat middleware. + Note: Function middleware executes during tool invocation, and chat middleware + executes around each model call inside the agent execution, not in the outer + agent-middleware chain shown above. They follow the same ordering principle: + agent-level function/chat middleware runs before run-level function/chat middleware. """ diff --git a/python/samples/02-agents/middleware/usage_tracking_middleware.py b/python/samples/02-agents/middleware/usage_tracking_middleware.py new file mode 100644 index 0000000000..877d2a8a82 --- /dev/null +++ b/python/samples/02-agents/middleware/usage_tracking_middleware.py @@ -0,0 +1,185 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +This sample demonstrates a single chat middleware that tracks per-model-call usage +for both non-streaming and streaming tool-loop runs. +""" + +import asyncio +from collections.abc import Awaitable, Callable +from random import randint +from typing import Annotated + +from agent_framework import ( + Agent, + ChatContext, + ChatResponse, + ChatResponseUpdate, + ResponseStream, + chat_middleware, + tool, +) +from agent_framework.openai import OpenAIResponsesClient +from dotenv import load_dotenv +from pydantic import Field + +# Load environment variables from .env file +load_dotenv() + + +NON_STREAMING_CALL_COUNT = 0 +STREAMING_CALL_COUNT = 0 + + +# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; +# see samples/02-agents/tools/function_tool_with_approval.py +# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. +@tool(approval_mode="never_require") +def get_weather( + location: Annotated[str, Field(description="The location to get the weather for.")], +) -> str: + """Get the weather for a given location.""" + conditions = ["sunny", "cloudy", "rainy", "stormy"] + return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." + + +def _reset_usage_counters() -> None: + """Reset call counters between sample runs.""" + global NON_STREAMING_CALL_COUNT, STREAMING_CALL_COUNT + NON_STREAMING_CALL_COUNT = 0 + STREAMING_CALL_COUNT = 0 + + +def _create_agent( +) -> Agent: + """Create the shared agent used by both demonstrations.""" + return Agent( + client=OpenAIResponsesClient(), + instructions=( + "You are a weather assistant. Always call the weather tool before answering weather questions, " + "then summarize the tool result in one short paragraph." + ), + tools=[get_weather], + middleware=[print_usage], + ) + + +@chat_middleware +async def print_usage( + context: ChatContext, + call_next: Callable[[], Awaitable[None]], +) -> None: + """Print usage for each inner model call in both non-streaming and streaming runs.""" + global NON_STREAMING_CALL_COUNT, STREAMING_CALL_COUNT + + if context.stream: + STREAMING_CALL_COUNT += 1 + call_number = STREAMING_CALL_COUNT + usage_seen_in_updates = False + + def capture_usage_update(update: ChatResponseUpdate) -> ChatResponseUpdate: + nonlocal usage_seen_in_updates + + for content in update.contents: + if content.type == "usage": + usage_seen_in_updates = True + print(f"\n[Streaming model call #{call_number}] Usage update: {content.usage_details}") + return update + + def capture_final_usage(result: ChatResponse) -> ChatResponse: + if not usage_seen_in_updates and result.usage_details: + print(f"\n[Streaming model call #{call_number}] Final usage: {result.usage_details}") + return result + + context.stream_transform_hooks.append(capture_usage_update) + context.stream_result_hooks.append(capture_final_usage) + await call_next() + return + + NON_STREAMING_CALL_COUNT += 1 + call_number = NON_STREAMING_CALL_COUNT + + await call_next() + + response = context.result + if isinstance(response, ChatResponse) and response.usage_details: + print(f"[Non-streaming model call #{call_number}] Usage: {response.usage_details}") + + +async def non_streaming_usage_example() -> None: + """Run the non-streaming usage tracking example.""" + _reset_usage_counters() + print("\n=== Non-streaming per-call usage tracking ===") + + # 1. Create an agent with middleware that prints usage after each inner model call. + agent = _create_agent() + + # 2. Run a weather question and require a tool call so the function loop performs multiple model calls. + query = "What is the weather in Seattle, and should I bring an umbrella?" + print(f"User: {query}") + result = await agent.run( + query, + options={"tool_choice": "required"}, + ) + + # 3. Print the final user-visible answer after the middleware already logged per-call usage. + print(f"Assistant: {result.text}") + + +async def streaming_usage_example() -> None: + """Run the streaming usage tracking example.""" + _reset_usage_counters() + print("\n=== Streaming per-call usage tracking ===") + + # 1. Create an agent with middleware that watches streaming usage for each inner model call. + agent = _create_agent() + + # 2. Start a streaming run and force tool usage so the function loop performs multiple model calls. + query = "What is the weather in Portland, and should I bring a jacket?" + print(f"User: {query}") + print("Assistant: ", end="", flush=True) + stream: ResponseStream = agent.run( + query, + stream=True, + options={"tool_choice": "required"}, + ) + + # 3. Consume the stream normally while the middleware reports usage in the background. + async for update in stream: + if update.text: + print(update.text, end="", flush=True) + print() + + # 4. Finalize the stream so you can inspect the final response if needed. + final_response = await stream.get_final_response() + print(f"Final assistant message: {final_response.text}") + + +async def main() -> None: + """Run both usage tracking demonstrations.""" + print("=== Usage Tracking Middleware Example ===") + + await non_streaming_usage_example() + await streaming_usage_example() + + +if __name__ == "__main__": + asyncio.run(main()) + +""" +Sample output: +=== Usage Tracking Middleware Example === + +=== Non-streaming per-call usage tracking === +User: What is the weather in Seattle, and should I bring an umbrella? +[Non-streaming model call #1] Usage: {'input_tokens': ..., 'output_tokens': ..., ...} +[Non-streaming model call #2] Usage: {'input_tokens': ..., 'output_tokens': ..., ...} +Assistant: Based on the weather in Seattle, ... + +=== Streaming per-call usage tracking === +User: What is the weather in Portland, and should I bring a jacket? +Assistant: Based on the weather in Portland, ... +[Streaming model call #1] Usage update: {'input_tokens': ..., 'output_tokens': ..., ...} +[Streaming model call #2] Usage update: {'input_tokens': ..., 'output_tokens': ..., ...} +Final assistant message: Based on the weather in Portland, ... +""" diff --git a/python/samples/02-agents/observability/advanced_manual_setup_console_output.py b/python/samples/02-agents/observability/advanced_manual_setup_console_output.py index 7afd359264..af7fcc6287 100644 --- a/python/samples/02-agents/observability/advanced_manual_setup_console_output.py +++ b/python/samples/02-agents/observability/advanced_manual_setup_console_output.py @@ -96,10 +96,16 @@ async def run_chat_client() -> None: stream: Whether to use streaming for the plugin Remarks: - When function calling is outside the open telemetry loop - each of the call to the model is handled as a seperate span, - while when the open telemetry is put last, a single span - is shown, which might include one or more rounds of function calling. + By default, the built-in non-`Raw...Client` chat clients already compose + the layers in this order: + `FunctionInvocationLayer -> ChatMiddlewareLayer -> ChatTelemetryLayer -> Raw/Base client`. + + When `FunctionInvocationLayer` is outside `ChatTelemetryLayer`, + each call to the model is handled as a separate span. + Keep `ChatMiddlewareLayer` outside telemetry + so middleware latency does not skew those timings. + By contrast, when telemetry is placed outside the function loop, + a single span can cover one or more rounds of function calling. So for the scenario below, you should see the following: diff --git a/python/samples/02-agents/observability/advanced_zero_code.py b/python/samples/02-agents/observability/advanced_zero_code.py index 477a5b4d9b..981b14a0e6 100644 --- a/python/samples/02-agents/observability/advanced_zero_code.py +++ b/python/samples/02-agents/observability/advanced_zero_code.py @@ -71,10 +71,12 @@ async def run_chat_client(client: "SupportsChatGetResponse", stream: bool = Fals stream: Whether to use streaming for the plugin Remarks: - When function calling is outside the open telemetry loop - each of the call to the model is handled as a separate span, - while when the open telemetry is put last, a single span - is shown, which might include one or more rounds of function calling. + When `FunctionInvocationLayer` is outside `ChatTelemetryLayer`, + each call to the model is handled as a separate span. + If `ChatMiddlewareLayer` is present, keep it outside telemetry + so middleware latency does not skew those timings. + By contrast, when telemetry is placed outside the function loop, + a single span can cover one or more rounds of function calling. So for the scenario below, you should see the following: diff --git a/python/samples/02-agents/providers/custom/README.md b/python/samples/02-agents/providers/custom/README.md index f2d67e0315..ac58a77e69 100644 --- a/python/samples/02-agents/providers/custom/README.md +++ b/python/samples/02-agents/providers/custom/README.md @@ -37,17 +37,17 @@ The framework provides `Raw...Client` classes (e.g., `RawOpenAIChatClient`, `Raw There is a defined ordering for applying layers that you should follow: -1. **ChatMiddlewareLayer** - Should be applied **first** because it also prepares function middleware -2. **FunctionInvocationLayer** - Handles tool/function calling loop -3. **ChatTelemetryLayer** - Must be **inside** the function calling loop for correct per-call telemetry +1. **FunctionInvocationLayer** - Handles the tool/function calling loop and should stay outermost +2. **ChatMiddlewareLayer** - Wraps each model call in the loop and stays outside telemetry +3. **ChatTelemetryLayer** - Must be inside the function calling loop so each model call gets its own telemetry span 4. **Raw...Client** - The base implementation (e.g., `RawOpenAIChatClient`) Example of correct layer composition: ```python class MyCustomClient( - ChatMiddlewareLayer[TOptions], FunctionInvocationLayer[TOptions], + ChatMiddlewareLayer[TOptions], ChatTelemetryLayer[TOptions], RawOpenAIChatClient[TOptions], # or BaseChatClient for custom implementations Generic[TOptions], diff --git a/python/samples/05-end-to-end/hosted_agents/agent_with_local_tools/main.py b/python/samples/05-end-to-end/hosted_agents/agent_with_local_tools/main.py index 1faad6a2e9..4c60902dc2 100644 --- a/python/samples/05-end-to-end/hosted_agents/agent_with_local_tools/main.py +++ b/python/samples/05-end-to-end/hosted_agents/agent_with_local_tools/main.py @@ -16,7 +16,6 @@ from azure.ai.agentserver.agentframework import from_agent_framework from azure.identity.aio import AzureCliCredential, ManagedIdentityCredential from dotenv import load_dotenv - load_dotenv(override=True) # Configure these for your Foundry project