Python: [BREAKING] Refactor middleware layering and split Anthropic raw client (#4746)

* [BREAKING] Refactor middleware layering and raw clients

Reorder chat client layers so function invocation wraps chat middleware, and chat middleware stays outside telemetry while still running for each inner model call. Add middleware pipeline caching, refresh docs and samples, and split Anthropic into raw and public clients to match the standard layering model.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Tighten typing ignores in ancillary modules

Add targeted typing ignores in workflow visualization and lab modules so pyright stays clean alongside the middleware refactor work.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix categorize_middleware to unpack tuple/Sequence and use relative MRO assertions

- Broaden isinstance check in categorize_middleware from list to Sequence
  so tuples and other Sequence types are properly unpacked instead of
  being appended as a single item.
- Replace fragile hardcoded MRO index assertions in anthropic test with
  relative ordering via mro.index().
- Add regression tests for categorize_middleware with tuple, list, and
  None inputs.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix middleware string decomposition, add middleware param to FunctionInvocationLayer, and add tests (#4710)

- Guard categorize_middleware Sequence check against str/bytes to prevent
  character-by-character decomposition of accidentally passed strings
- Add explicit middleware parameter to FunctionInvocationLayer.get_response
  and merge it into client_kwargs before categorization, fixing the
  inconsistency where only OpenAIChatClient supported this parameter
- Add assertions that RawAnthropicClient does not inherit convenience layers
- Add chat middleware cache test with non-empty base middleware
- Add tests for single unwrapped middleware item and string input

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Apply pre-commit auto-fixes

* Apply pre-commit auto-fixes

* Address review feedback for #4710: review comment fixes

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
Eduard van Valkenburg
2026-03-20 01:43:37 +01:00
committed by GitHub
Unverified
parent cefda44283
commit 0cd40f8354
41 changed files with 936 additions and 155 deletions
@@ -111,8 +111,8 @@ def _apply_server_function_call_unwrap(client: BaseChatClientT) -> BaseChatClien
@_apply_server_function_call_unwrap
class AGUIChatClient(
ChatMiddlewareLayer[AGUIChatOptionsT],
FunctionInvocationLayer[AGUIChatOptionsT],
ChatMiddlewareLayer[AGUIChatOptionsT],
ChatTelemetryLayer[AGUIChatOptionsT],
BaseChatClient[AGUIChatOptionsT],
Generic[AGUIChatOptionsT],
@@ -45,8 +45,8 @@ def pytest_configure() -> None:
class StreamingChatClientStub(
ChatMiddlewareLayer[OptionsCoT],
FunctionInvocationLayer[OptionsCoT],
ChatMiddlewareLayer[OptionsCoT],
ChatTelemetryLayer[OptionsCoT],
BaseChatClient[OptionsCoT],
Generic[OptionsCoT],
@@ -54,7 +54,7 @@ class StreamingChatClientStub(
"""Typed streaming stub that satisfies SupportsChatGetResponse."""
def __init__(self, stream_fn: StreamFn, response_fn: ResponseFn | None = None) -> None:
super().__init__(function_middleware=[])
super().__init__(middleware=[])
self._stream_fn = stream_fn
self._response_fn = response_fn
self.last_session: AgentSession | None = None
@@ -2,7 +2,7 @@
import importlib.metadata
from ._chat_client import AnthropicChatOptions, AnthropicClient
from ._chat_client import AnthropicChatOptions, AnthropicClient, RawAnthropicClient
try:
__version__ = importlib.metadata.version(__name__)
@@ -12,5 +12,6 @@ except importlib.metadata.PackageNotFoundError:
__all__ = [
"AnthropicChatOptions",
"AnthropicClient",
"RawAnthropicClient",
"__version__",
]
@@ -68,6 +68,7 @@ else:
__all__ = [
"AnthropicChatOptions",
"AnthropicClient",
"RawAnthropicClient",
"ThinkingConfig",
]
@@ -210,14 +211,24 @@ class AnthropicSettings(TypedDict, total=False):
chat_model_id: str | None
class AnthropicClient(
ChatMiddlewareLayer[AnthropicOptionsT],
FunctionInvocationLayer[AnthropicOptionsT],
ChatTelemetryLayer[AnthropicOptionsT],
class RawAnthropicClient(
BaseChatClient[AnthropicOptionsT],
Generic[AnthropicOptionsT],
):
"""Anthropic Chat client with middleware, telemetry, and function invocation support."""
"""Raw Anthropic chat client without middleware, telemetry, or function invocation support.
Warning:
**This class should not normally be used directly.** It does not include middleware,
telemetry, or function invocation support that you most likely need. If you do use it,
you should consider which additional layers to apply. There is a defined ordering that
you should follow:
1. **FunctionInvocationLayer** - Owns the tool/function calling loop and routes function middleware
2. **ChatMiddlewareLayer** - Applies chat middleware per model call and stays outside telemetry
3. **ChatTelemetryLayer** - Must stay inside chat middleware for correct per-call telemetry
Use ``AnthropicClient`` instead for a fully-featured client with all layers applied.
"""
OTEL_PROVIDER_NAME: ClassVar[str] = "anthropic" # type: ignore[reportIncompatibleVariableOverride, misc]
@@ -229,12 +240,10 @@ class AnthropicClient(
anthropic_client: AsyncAnthropic | None = None,
additional_beta_flags: list[str] | None = None,
additional_properties: dict[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> None:
"""Initialize an Anthropic Agent client.
"""Initialize a raw Anthropic client.
Keyword Args:
api_key: The Anthropic API key to use for authentication.
@@ -245,15 +254,13 @@ class AnthropicClient(
additional_beta_flags: Additional beta flags to enable on the client.
Default flags are: "mcp-client-2025-04-04", "code-execution-2025-08-25".
additional_properties: Additional properties stored on the client instance.
middleware: Optional middleware to apply to the client.
function_invocation_configuration: Optional function invocation configuration override.
env_file_path: Path to environment file for loading settings.
env_file_encoding: Encoding of the environment file.
Examples:
.. code-block:: python
from agent_framework.anthropic import AnthropicClient
from agent_framework.anthropic import RawAnthropicClient
from azure.identity.aio import DefaultAzureCredential
# Using environment variables
@@ -261,13 +268,13 @@ class AnthropicClient(
# ANTHROPIC_CHAT_MODEL_ID=claude-sonnet-4-5-20250929
# Or passing parameters directly
client = AnthropicClient(
client = RawAnthropicClient(
model_id="claude-sonnet-4-5-20250929",
api_key="your_anthropic_api_key",
)
# Or loading from a .env file
client = AnthropicClient(env_file_path="path/to/.env")
client = RawAnthropicClient(env_file_path="path/to/.env")
# Or passing in an existing client
from anthropic import AsyncAnthropic
@@ -275,7 +282,7 @@ class AnthropicClient(
anthropic_client = AsyncAnthropic(
api_key="your_anthropic_api_key", base_url="https://custom-anthropic-endpoint.com"
)
client = AnthropicClient(
client = RawAnthropicClient(
model_id="claude-sonnet-4-5-20250929",
anthropic_client=anthropic_client,
)
@@ -289,7 +296,7 @@ class AnthropicClient(
my_custom_option: str
client: AnthropicClient[MyOptions] = AnthropicClient(model_id="claude-sonnet-4-5-20250929")
client: RawAnthropicClient[MyOptions] = RawAnthropicClient(model_id="claude-sonnet-4-5-20250929")
response = await client.get_response("Hello", options={"my_custom_option": "value"})
"""
@@ -320,8 +327,6 @@ class AnthropicClient(
# Initialize parent
super().__init__(
additional_properties=additional_properties,
middleware=middleware,
function_invocation_configuration=function_invocation_configuration,
)
# Initialize instance variables
@@ -1376,3 +1381,95 @@ class AnthropicClient(
The service URL for the chat client, or None if not set.
"""
return str(self.anthropic_client.base_url)
class AnthropicClient(
FunctionInvocationLayer[AnthropicOptionsT],
ChatMiddlewareLayer[AnthropicOptionsT],
ChatTelemetryLayer[AnthropicOptionsT],
RawAnthropicClient[AnthropicOptionsT],
Generic[AnthropicOptionsT],
):
"""Anthropic chat client with middleware, telemetry, and function invocation support."""
def __init__(
self,
*,
api_key: str | None = None,
model_id: str | None = None,
anthropic_client: AsyncAnthropic | None = None,
additional_beta_flags: list[str] | None = None,
additional_properties: dict[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> None:
"""Initialize an Anthropic client.
Keyword Args:
api_key: The Anthropic API key to use for authentication.
model_id: The ID of the model to use.
anthropic_client: An existing Anthropic client to use. If not provided, one will be created.
This can be used to further configure the client before passing it in.
For instance if you need to set a different base_url for testing or private deployments.
additional_beta_flags: Additional beta flags to enable on the client.
Default flags are: "mcp-client-2025-04-04", "code-execution-2025-08-25".
additional_properties: Additional properties stored on the client instance.
middleware: Optional middleware to apply to the client.
function_invocation_configuration: Optional function invocation configuration override.
env_file_path: Path to environment file for loading settings.
env_file_encoding: Encoding of the environment file.
Examples:
.. code-block:: python
from agent_framework.anthropic import AnthropicClient
# Using environment variables
# Set ANTHROPIC_API_KEY=your_anthropic_api_key
# ANTHROPIC_CHAT_MODEL_ID=claude-sonnet-4-5-20250929
# Or passing parameters directly
client = AnthropicClient(
model_id="claude-sonnet-4-5-20250929",
api_key="your_anthropic_api_key",
)
# Or loading from a .env file
client = AnthropicClient(env_file_path="path/to/.env")
# Or passing in an existing client
from anthropic import AsyncAnthropic
anthropic_client = AsyncAnthropic(
api_key="your_anthropic_api_key", base_url="https://custom-anthropic-endpoint.com"
)
client = AnthropicClient(
model_id="claude-sonnet-4-5-20250929",
anthropic_client=anthropic_client,
)
# Using custom ChatOptions with type safety:
from typing import TypedDict
from agent_framework.anthropic import AnthropicChatOptions
class MyOptions(AnthropicChatOptions, total=False):
my_custom_option: str
client: AnthropicClient[MyOptions] = AnthropicClient(model_id="claude-sonnet-4-5-20250929")
response = await client.get_response("Hello", options={"my_custom_option": "value"})
"""
super().__init__(
api_key=api_key,
model_id=model_id,
anthropic_client=anthropic_client,
additional_beta_flags=additional_beta_flags,
additional_properties=additional_properties,
middleware=middleware,
function_invocation_configuration=function_invocation_configuration,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
@@ -6,15 +6,18 @@ from unittest.mock import MagicMock, patch
import pytest
from agent_framework import (
ChatMiddlewareLayer,
ChatOptions,
ChatResponseUpdate,
Content,
FunctionInvocationLayer,
Message,
SupportsChatGetResponse,
tool,
)
from agent_framework._settings import load_settings
from agent_framework._tools import SHELL_TOOL_KIND_VALUE
from agent_framework.observability import ChatTelemetryLayer
from anthropic.types.beta import (
BetaMessage,
BetaTextBlock,
@@ -23,7 +26,7 @@ from anthropic.types.beta import (
)
from pydantic import BaseModel, Field
from agent_framework_anthropic import AnthropicClient
from agent_framework_anthropic import AnthropicClient, RawAnthropicClient
from agent_framework_anthropic._chat_client import AnthropicSettings
# Test constants
@@ -64,6 +67,8 @@ def create_test_anthropic_client(
client.additional_beta_flags = []
client.chat_middleware = []
client.function_middleware = []
client._cached_chat_middleware_pipeline = None
client._cached_function_middleware_pipeline = None
client.function_invocation_configuration = normalize_function_invocation_configuration(None)
return client
@@ -117,6 +122,19 @@ def test_anthropic_client_init_with_client(mock_anthropic_client: MagicMock) ->
assert isinstance(client, SupportsChatGetResponse)
def test_anthropic_client_wraps_raw_client_with_standard_layer_order() -> None:
"""Test AnthropicClient composes the standard public layer stack around the raw client."""
assert issubclass(AnthropicClient, RawAnthropicClient)
mro = AnthropicClient.__mro__
assert mro.index(FunctionInvocationLayer) < mro.index(ChatMiddlewareLayer)
assert mro.index(ChatMiddlewareLayer) < mro.index(ChatTelemetryLayer)
assert mro.index(ChatTelemetryLayer) < mro.index(RawAnthropicClient)
# RawAnthropicClient must not include the convenience layers
assert not issubclass(RawAnthropicClient, FunctionInvocationLayer)
assert not issubclass(RawAnthropicClient, ChatMiddlewareLayer)
assert not issubclass(RawAnthropicClient, ChatTelemetryLayer)
def test_anthropic_client_init_auto_create_client(
anthropic_unit_test_env: dict[str, str],
) -> None:
@@ -206,8 +206,8 @@ AzureAIAgentOptionsT = TypeVar(
class AzureAIAgentClient(
ChatMiddlewareLayer[AzureAIAgentOptionsT],
FunctionInvocationLayer[AzureAIAgentOptionsT],
ChatMiddlewareLayer[AzureAIAgentOptionsT],
ChatTelemetryLayer[AzureAIAgentOptionsT],
BaseChatClient[AzureAIAgentOptionsT],
Generic[AzureAIAgentOptionsT],
@@ -97,9 +97,9 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[
you should consider which additional layers to apply. There is a defined ordering that
you should follow:
1. **ChatMiddlewareLayer** - Should be applied first as it also prepares function middleware
2. **FunctionInvocationLayer** - Handles tool/function calling loop
3. **ChatTelemetryLayer** - Must be inside the function calling loop for correct per-call telemetry
1. **FunctionInvocationLayer** - Owns the tool/function calling loop and routes function middleware
2. **ChatMiddlewareLayer** - Applies chat middleware per model call and stays outside telemetry
3. **ChatTelemetryLayer** - Must stay inside chat middleware for correct per-call telemetry
Use ``AzureAIClient`` instead for a fully-featured client with all layers applied.
"""
@@ -1214,8 +1214,8 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[
class AzureAIClient(
ChatMiddlewareLayer[AzureAIClientOptionsT],
FunctionInvocationLayer[AzureAIClientOptionsT],
ChatMiddlewareLayer[AzureAIClientOptionsT],
ChatTelemetryLayer[AzureAIClientOptionsT],
RawAzureAIClient[AzureAIClientOptionsT],
Generic[AzureAIClientOptionsT],
@@ -87,6 +87,8 @@ def create_test_azure_ai_chat_client(
client.middleware = None
client.chat_middleware = []
client.function_middleware = []
client._cached_chat_middleware_pipeline = None
client._cached_function_middleware_pipeline = None
client.otel_provider_name = "azure.ai"
client.function_invocation_configuration = {
"enabled": True,
@@ -151,6 +153,10 @@ def test_azure_ai_chat_client_init_auto_create_client(
chat_client.agent_name = None
chat_client.additional_properties = {}
chat_client.middleware = None
chat_client.chat_middleware = []
chat_client.function_middleware = []
chat_client._cached_chat_middleware_pipeline = None
chat_client._cached_function_middleware_pipeline = None
assert chat_client.agents_client is mock_agents_client
assert chat_client.agent_id is None
@@ -216,8 +216,8 @@ class BedrockSettings(TypedDict, total=False):
class BedrockChatClient(
ChatMiddlewareLayer[BedrockChatOptionsT],
FunctionInvocationLayer[BedrockChatOptionsT],
ChatMiddlewareLayer[BedrockChatOptionsT],
ChatTelemetryLayer[BedrockChatOptionsT],
BaseChatClient[BedrockChatOptionsT],
Generic[BedrockChatOptionsT],
@@ -966,16 +966,7 @@ def _apply_get_response_docstrings() -> None:
from .observability import ChatTelemetryLayer
apply_layered_docstring(ChatTelemetryLayer.get_response, BaseChatClient.get_response)
apply_layered_docstring(
FunctionInvocationLayer.get_response,
ChatTelemetryLayer.get_response,
extra_keyword_args={
"function_middleware": """
Optional per-call function middleware.
When omitted, middleware configured on the client or forwarded from higher layers is used.
""",
},
)
apply_layered_docstring(FunctionInvocationLayer.get_response, ChatTelemetryLayer.get_response)
apply_layered_docstring(
ChatMiddlewareLayer.get_response,
FunctionInvocationLayer.get_response,
@@ -742,12 +742,17 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline):
middleware: The list of agent middleware to include in the pipeline.
"""
super().__init__()
self._source_middleware: tuple[AgentMiddlewareTypes, ...] = tuple(middleware)
self._middleware: list[AgentMiddleware] = []
if middleware:
for mdlware in middleware:
self._register_middleware(mdlware)
def matches(self, middleware: Sequence[AgentMiddlewareTypes]) -> bool:
"""Return whether this pipeline was built from the provided middleware sequence."""
return self._source_middleware == tuple(middleware)
def _register_middleware(self, middleware: AgentMiddlewareTypes) -> None:
"""Register an agent middleware item.
@@ -824,12 +829,17 @@ class FunctionMiddlewarePipeline(BaseMiddlewarePipeline):
middleware: The list of function middleware to include in the pipeline.
"""
super().__init__()
self._source_middleware: tuple[FunctionMiddlewareTypes, ...] = tuple(middleware)
self._middleware: list[FunctionMiddleware] = []
if middleware:
for mdlware in middleware:
self._register_middleware(mdlware)
def matches(self, middleware: Sequence[FunctionMiddlewareTypes]) -> bool:
"""Return whether this pipeline was built from the provided middleware sequence."""
return self._source_middleware == tuple(middleware)
def _register_middleware(self, middleware: FunctionMiddlewareTypes) -> None:
"""Register a function middleware item.
@@ -892,12 +902,17 @@ class ChatMiddlewarePipeline(BaseMiddlewarePipeline):
middleware: The list of chat middleware to include in the pipeline.
"""
super().__init__()
self._source_middleware: tuple[ChatMiddlewareTypes, ...] = tuple(middleware)
self._middleware: list[ChatMiddleware] = []
if middleware:
for mdlware in middleware:
self._register_middleware(mdlware)
def matches(self, middleware: Sequence[ChatMiddlewareTypes]) -> bool:
"""Return whether this pipeline was built from the provided middleware sequence."""
return self._source_middleware == tuple(middleware)
def _register_middleware(self, middleware: ChatMiddlewareTypes) -> None:
"""Register a chat middleware item.
@@ -980,16 +995,26 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
def __init__(
self,
*,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
middleware: Sequence[ChatMiddlewareTypes] | None = None,
**kwargs: Any,
) -> None:
middleware_list = categorize_middleware(*(middleware or []))
self.chat_middleware = middleware_list["chat"]
if "function_middleware" in kwargs and middleware_list["function"]:
raise ValueError("Cannot specify 'function_middleware' and 'middleware' at the same time.")
kwargs["function_middleware"] = middleware_list["function"]
self.chat_middleware = list(middleware) if middleware else []
self._cached_chat_middleware_pipeline: ChatMiddlewarePipeline | None = None
super().__init__(**kwargs)
def _get_chat_middleware_pipeline(
self,
middleware: Sequence[ChatMiddlewareTypes],
) -> ChatMiddlewarePipeline:
effective_middleware = [*self.chat_middleware, *middleware]
if self._cached_chat_middleware_pipeline is not None and self._cached_chat_middleware_pipeline.matches(
effective_middleware
):
return self._cached_chat_middleware_pipeline
self._cached_chat_middleware_pipeline = ChatMiddlewarePipeline(*effective_middleware)
return self._cached_chat_middleware_pipeline
@overload
def get_response(
self,
@@ -1052,14 +1077,8 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
kwargs["tokenizer"] = tokenizer
effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
call_middleware = kwargs.pop("middleware", effective_client_kwargs.pop("middleware", []))
middleware = categorize_middleware(call_middleware)
effective_client_kwargs["function_middleware"] = middleware["function"]
pipeline = ChatMiddlewarePipeline(
*self.chat_middleware,
*middleware["chat"],
)
call_middleware = effective_client_kwargs.pop("middleware", [])
pipeline = self._get_chat_middleware_pipeline(call_middleware) # type: ignore[reportUnknownArgumentType]
if not pipeline.has_middlewares:
return super_get_response( # type: ignore[no-any-return]
messages=messages,
@@ -1134,12 +1153,25 @@ class AgentMiddlewareLayer:
) -> None:
middleware_list = categorize_middleware(middleware)
self.agent_middleware = middleware_list["agent"]
self._cached_agent_middleware_pipeline: AgentMiddlewarePipeline | None = None
# Pass middleware to super so BaseAgent can store it for dynamic rebuild
super().__init__(*args, middleware=middleware, **kwargs) # type: ignore[call-arg]
# Note: We intentionally don't extend client's middleware lists here.
# Chat and function middleware is passed to the chat client at runtime via kwargs
# in AgentMiddlewareLayer.run(), where it's properly combined with run-level middleware.
def _get_agent_middleware_pipeline(
self,
middleware: Sequence[AgentMiddlewareTypes],
) -> AgentMiddlewarePipeline:
if self._cached_agent_middleware_pipeline is not None and self._cached_agent_middleware_pipeline.matches(
middleware
):
return self._cached_agent_middleware_pipeline
self._cached_agent_middleware_pipeline = AgentMiddlewarePipeline(*middleware)
return self._cached_agent_middleware_pipeline
@overload
def run(
self,
@@ -1210,7 +1242,7 @@ class AgentMiddlewareLayer:
)
base_middleware_list = categorize_middleware(base_middleware)
run_middleware_list = categorize_middleware(middleware)
pipeline = AgentMiddlewarePipeline(*base_middleware_list["agent"], *run_middleware_list["agent"])
pipeline = self._get_agent_middleware_pipeline([*base_middleware_list["agent"], *run_middleware_list["agent"]])
# Combine base and run-level function/chat middleware for forwarding to chat client
combined_function_chat_middleware = (
@@ -1392,7 +1424,7 @@ def categorize_middleware(
all_middleware: list[Any] = []
for source in middleware_sources:
if source:
if isinstance(source, list):
if isinstance(source, Sequence) and not isinstance(source, (str, bytes)):
all_middleware.extend(source) # type: ignore
else:
all_middleware.append(source)
+48 -16
View File
@@ -63,7 +63,12 @@ if TYPE_CHECKING:
from ._clients import SupportsChatGetResponse
from ._compaction import CompactionStrategy, TokenizerProtocol
from ._mcp import MCPTool
from ._middleware import FunctionInvocationContext, FunctionMiddlewarePipeline, FunctionMiddlewareTypes
from ._middleware import (
ChatAndFunctionMiddlewareTypes,
FunctionInvocationContext,
FunctionMiddlewarePipeline,
FunctionMiddlewareTypes,
)
from ._sessions import AgentSession
from ._types import (
ChatOptions,
@@ -2024,18 +2029,37 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
def __init__(
self,
*,
function_middleware: Sequence[FunctionMiddlewareTypes] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
**kwargs: Any,
) -> None:
self.function_middleware: list[FunctionMiddlewareTypes] = (
list(function_middleware) if function_middleware else []
)
from ._middleware import categorize_middleware
middleware_list = categorize_middleware(middleware)
self.function_middleware: list[FunctionMiddlewareTypes] = list(middleware_list["function"])
self._cached_function_middleware_pipeline: FunctionMiddlewarePipeline | None = None
self.function_invocation_configuration = normalize_function_invocation_configuration(
function_invocation_configuration
)
if (chat_middleware := (middleware_list["chat"] or None)) is not None:
kwargs["middleware"] = chat_middleware
super().__init__(**kwargs)
def _get_function_middleware_pipeline(
self,
middleware: Sequence[FunctionMiddlewareTypes],
) -> FunctionMiddlewarePipeline:
from ._middleware import FunctionMiddlewarePipeline
effective_middleware = [*self.function_middleware, *middleware]
if self._cached_function_middleware_pipeline is not None and self._cached_function_middleware_pipeline.matches(
effective_middleware
):
return self._cached_function_middleware_pipeline
self._cached_function_middleware_pipeline = FunctionMiddlewarePipeline(*effective_middleware)
return self._cached_function_middleware_pipeline
@overload
def get_response(
self,
@@ -2043,6 +2067,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
*,
stream: Literal[False] = ...,
options: ChatOptions[ResponseModelBoundT],
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
@@ -2057,6 +2082,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
*,
stream: Literal[False] = ...,
options: OptionsCoT | ChatOptions[None] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
@@ -2071,6 +2097,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
*,
stream: Literal[True],
options: OptionsCoT | ChatOptions[Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
@@ -2084,14 +2111,14 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
*,
stream: bool = False,
options: OptionsCoT | ChatOptions[Any] | None = None,
function_middleware: Sequence[FunctionMiddlewareTypes] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
from ._middleware import FunctionMiddlewarePipeline
from ._middleware import categorize_middleware
from ._types import (
ChatResponse,
ChatResponseUpdate,
@@ -2109,16 +2136,21 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
)
effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
effective_function_middleware = function_middleware
if effective_function_middleware is None:
middleware_from_client_kwargs = effective_client_kwargs.pop("function_middleware", None)
if middleware_from_client_kwargs is not None:
effective_function_middleware = cast(Sequence[Any], middleware_from_client_kwargs)
if middleware is not None:
existing = effective_client_kwargs.get("middleware", [])
effective_client_kwargs["middleware"] = [
*(
existing
if isinstance(existing, Sequence) and not isinstance(existing, (str, bytes))
else [existing]
),
*middleware,
]
runtime_middleware = categorize_middleware(effective_client_kwargs.pop("middleware", []))
# ChatMiddleware adds this kwarg
function_middleware_pipeline = FunctionMiddlewarePipeline(
*(self.function_middleware), *(effective_function_middleware or [])
)
function_middleware_pipeline = self._get_function_middleware_pipeline(runtime_middleware["function"])
if runtime_middleware["chat"]:
effective_client_kwargs["middleware"] = runtime_middleware["chat"]
max_errors = self.function_invocation_configuration.get(
"max_consecutive_errors_per_request", DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST
)
@@ -109,7 +109,7 @@ class WorkflowViz:
# Create a temporary graphviz Source object
dot_content = self.to_digraph(include_internal_executors=include_internal_executors)
source = graphviz.Source(dot_content)
source = graphviz.Source(dot_content) # type: ignore[reportUnknownVariableType]
try:
if filename:
@@ -131,7 +131,7 @@ class WorkflowViz:
source.render(base_name, format=format, cleanup=True) # type: ignore
return f"{base_name}.{format}"
except graphviz.backend.execute.ExecutableNotFound as e:
except graphviz.backend.execute.ExecutableNotFound as e: # type: ignore
raise ImportError(
"The graphviz executables are not found. The graphviz Python package is installed, but the "
"graphviz executables (dot, neato, etc.) are not available on your system's PATH. "
@@ -152,8 +152,8 @@ AzureOpenAIChatClientT = TypeVar("AzureOpenAIChatClientT", bound="AzureOpenAICha
class AzureOpenAIChatClient( # type: ignore[misc]
AzureOpenAIConfigMixin,
ChatMiddlewareLayer[AzureOpenAIChatOptionsT],
FunctionInvocationLayer[AzureOpenAIChatOptionsT],
ChatMiddlewareLayer[AzureOpenAIChatOptionsT],
ChatTelemetryLayer[AzureOpenAIChatOptionsT],
RawOpenAIChatClient[AzureOpenAIChatOptionsT],
Generic[AzureOpenAIChatOptionsT],
@@ -51,8 +51,8 @@ AzureOpenAIResponsesOptionsT = TypeVar(
class AzureOpenAIResponsesClient( # type: ignore[misc]
AzureOpenAIConfigMixin,
ChatMiddlewareLayer[AzureOpenAIResponsesOptionsT],
FunctionInvocationLayer[AzureOpenAIResponsesOptionsT],
ChatMiddlewareLayer[AzureOpenAIResponsesOptionsT],
ChatTelemetryLayer[AzureOpenAIResponsesOptionsT],
RawOpenAIResponsesClient[AzureOpenAIResponsesOptionsT],
Generic[AzureOpenAIResponsesOptionsT],
@@ -362,11 +362,15 @@ def _create_otlp_exporters(
if protocol == "grpc":
# Import all gRPC exporters
try:
from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter as GRPCLogExporter
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import (
OTLPMetricExporter as GRPCMetricExporter,
from opentelemetry.exporter.otlp.proto.grpc._log_exporter import ( # type: ignore[reportMissingImports]
OTLPLogExporter as GRPCLogExporter, # type: ignore[reportUnknownVariableType]
)
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import ( # type: ignore[reportMissingImports]
OTLPMetricExporter as GRPCMetricExporter, # type: ignore[reportUnknownVariableType]
)
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( # type: ignore[reportMissingImports]
OTLPSpanExporter as GRPCSpanExporter, # type: ignore[reportUnknownVariableType]
)
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter
except ImportError as exc:
raise ImportError(
"opentelemetry-exporter-otlp-proto-grpc is required for OTLP gRPC exporters. "
@@ -375,21 +379,21 @@ def _create_otlp_exporters(
if actual_logs_endpoint:
exporters.append(
GRPCLogExporter(
GRPCLogExporter( # type: ignore[reportUnknownArgumentType]
endpoint=actual_logs_endpoint,
headers=actual_logs_headers if actual_logs_headers else None,
)
)
if actual_traces_endpoint:
exporters.append(
GRPCSpanExporter(
GRPCSpanExporter( # type: ignore[reportUnknownArgumentType]
endpoint=actual_traces_endpoint,
headers=actual_traces_headers if actual_traces_headers else None,
)
)
if actual_metrics_endpoint:
exporters.append(
GRPCMetricExporter(
GRPCMetricExporter( # type: ignore[reportUnknownArgumentType]
endpoint=actual_metrics_endpoint,
headers=actual_metrics_headers if actual_metrics_headers else None,
)
@@ -210,8 +210,8 @@ OpenAIAssistantsOptionsT = TypeVar(
class OpenAIAssistantsClient( # type: ignore[misc]
OpenAIConfigMixin,
ChatMiddlewareLayer[OpenAIAssistantsOptionsT],
FunctionInvocationLayer[OpenAIAssistantsOptionsT],
ChatMiddlewareLayer[OpenAIAssistantsOptionsT],
ChatTelemetryLayer[OpenAIAssistantsOptionsT],
BaseChatClient[OpenAIAssistantsOptionsT],
Generic[OpenAIAssistantsOptionsT],
@@ -31,7 +31,7 @@ from pydantic import BaseModel
from .._clients import BaseChatClient
from .._docstrings import apply_layered_docstring
from .._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer, FunctionMiddlewareTypes
from .._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer
from .._settings import load_settings
from .._tools import (
FunctionInvocationConfiguration,
@@ -156,9 +156,9 @@ class RawOpenAIChatClient( # type: ignore[misc]
you should consider which additional layers to apply. There is a defined ordering that
you should follow:
1. **ChatMiddlewareLayer** - Should be applied first as it also prepares function middleware
2. **FunctionInvocationLayer** - Handles tool/function calling loop
3. **ChatTelemetryLayer** - Must be inside the function calling loop for correct per-call telemetry
1. **FunctionInvocationLayer** - Owns the tool/function calling loop and routes function middleware
2. **ChatMiddlewareLayer** - Applies chat middleware per model call and stays outside telemetry
3. **ChatTelemetryLayer** - Must stay inside chat middleware for correct per-call telemetry
Use ``OpenAIChatClient`` instead for a fully-featured client with all layers applied.
"""
@@ -776,8 +776,8 @@ class RawOpenAIChatClient( # type: ignore[misc]
class OpenAIChatClient( # type: ignore[misc]
OpenAIConfigMixin,
ChatMiddlewareLayer[OpenAIChatOptionsT],
FunctionInvocationLayer[OpenAIChatOptionsT],
ChatMiddlewareLayer[OpenAIChatOptionsT],
ChatTelemetryLayer[OpenAIChatOptionsT],
RawOpenAIChatClient[OpenAIChatOptionsT],
Generic[OpenAIChatOptionsT],
@@ -791,7 +791,6 @@ class OpenAIChatClient( # type: ignore[misc]
*,
stream: Literal[False] = ...,
options: ChatOptions[ResponseModelBoundT],
function_middleware: Sequence[FunctionMiddlewareTypes] | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
@@ -805,7 +804,6 @@ class OpenAIChatClient( # type: ignore[misc]
*,
stream: Literal[False] = ...,
options: OpenAIChatOptionsT | ChatOptions[None] | None = None,
function_middleware: Sequence[FunctionMiddlewareTypes] | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
@@ -819,7 +817,6 @@ class OpenAIChatClient( # type: ignore[misc]
*,
stream: Literal[True],
options: OpenAIChatOptionsT | ChatOptions[Any] | None = None,
function_middleware: Sequence[FunctionMiddlewareTypes] | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
@@ -833,7 +830,6 @@ class OpenAIChatClient( # type: ignore[misc]
*,
stream: bool = False,
options: OpenAIChatOptionsT | ChatOptions[Any] | None = None,
function_middleware: Sequence[FunctionMiddlewareTypes] | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
@@ -844,14 +840,15 @@ class OpenAIChatClient( # type: ignore[misc]
"Callable[..., Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]]",
super().get_response, # type: ignore[misc]
)
effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
if middleware is not None:
effective_client_kwargs["middleware"] = middleware
return super_get_response( # type: ignore[no-any-return]
messages=messages,
stream=stream,
options=options,
function_middleware=function_middleware,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
middleware=middleware,
client_kwargs=effective_client_kwargs,
**kwargs,
)
@@ -967,10 +964,6 @@ def _apply_openai_chat_client_docstrings() -> None:
OpenAIChatClient.get_response,
RawOpenAIChatClient.get_response,
extra_keyword_args={
"function_middleware": """
Optional per-call function middleware.
When omitted, middleware configured on the client or forwarded from higher layers is used.
""",
"middleware": """
Optional per-call chat and function middleware.
This is merged with any middleware configured on the client for the current request.
@@ -249,9 +249,9 @@ class RawOpenAIResponsesClient( # type: ignore[misc]
you should consider which additional layers to apply. There is a defined ordering that
you should follow:
1. **ChatMiddlewareLayer** - Should be applied first as it also prepares function middleware
2. **FunctionInvocationLayer** - Handles tool/function calling loop
3. **ChatTelemetryLayer** - Must be inside the function calling loop for correct per-call telemetry
1. **FunctionInvocationLayer** - Owns the tool/function calling loop and routes function middleware
2. **ChatMiddlewareLayer** - Applies chat middleware per model call and stays outside telemetry
3. **ChatTelemetryLayer** - Must stay inside chat middleware for correct per-call telemetry
Use ``OpenAIResponsesClient`` instead for a fully-featured client with all layers applied.
"""
@@ -2259,8 +2259,8 @@ class RawOpenAIResponsesClient( # type: ignore[misc]
class OpenAIResponsesClient( # type: ignore[misc]
OpenAIConfigMixin,
ChatMiddlewareLayer[OpenAIResponsesOptionsT],
FunctionInvocationLayer[OpenAIResponsesOptionsT],
ChatMiddlewareLayer[OpenAIResponsesOptionsT],
ChatTelemetryLayer[OpenAIResponsesOptionsT],
RawOpenAIResponsesClient[OpenAIResponsesOptionsT],
Generic[OpenAIResponsesOptionsT],
+2 -2
View File
@@ -128,8 +128,8 @@ class MockChatClient:
class MockBaseChatClient(
ChatMiddlewareLayer[OptionsCoT],
FunctionInvocationLayer[OptionsCoT],
ChatMiddlewareLayer[OptionsCoT],
ChatTelemetryLayer[OptionsCoT],
BaseChatClient[OptionsCoT],
Generic[OptionsCoT],
@@ -137,7 +137,7 @@ class MockBaseChatClient(
"""Mock implementation of a full-featured ChatClient."""
def __init__(self, **kwargs: Any):
super().__init__(function_middleware=[], **kwargs)
super().__init__(middleware=[], **kwargs)
self.run_responses: list[ChatResponse] = []
self.streaming_responses: list[list[ChatResponseUpdate]] = []
self.call_count: int = 0
@@ -74,8 +74,8 @@ def test_openai_chat_client_get_response_docstring_surfaces_layered_runtime_docs
assert docstring is not None
assert "Get a response from a chat client." in docstring
assert "function_invocation_kwargs" in docstring
assert "function_middleware: Optional per-call function middleware." in docstring
assert "middleware: Optional per-call chat and function middleware." in docstring
assert "function_middleware: Optional per-call function middleware." not in docstring
def test_openai_chat_client_get_response_is_defined_on_openai_class() -> None:
@@ -84,7 +84,6 @@ def test_openai_chat_client_get_response_is_defined_on_openai_class() -> None:
signature = inspect.signature(OpenAIChatClient.get_response)
assert OpenAIChatClient.get_response.__qualname__ == "OpenAIChatClient.get_response"
assert "function_middleware" in signature.parameters
assert "middleware" in signature.parameters
@@ -3226,7 +3226,7 @@ async def test_terminate_loop_single_function_call(chat_client_base: SupportsCha
response = await chat_client_base.get_response(
"hello",
options={"tool_choice": "auto", "tools": [ai_func]},
middleware=[TerminateLoopMiddleware()],
client_kwargs={"middleware": [TerminateLoopMiddleware()]},
)
# Function should NOT have been executed - middleware intercepted it
@@ -3292,7 +3292,7 @@ async def test_terminate_loop_multiple_function_calls_one_terminates(chat_client
response = await chat_client_base.get_response(
"hello",
options={"tool_choice": "auto", "tools": [normal_func, terminating_func]},
middleware=[SelectiveTerminateMiddleware()],
client_kwargs={"middleware": [SelectiveTerminateMiddleware()]},
)
# normal_function should have executed (middleware calls next_handler)
@@ -3345,7 +3345,7 @@ async def test_terminate_loop_streaming_single_function_call(chat_client_base: S
async for update in chat_client_base.get_response(
"hello",
options={"tool_choice": "auto", "tools": [ai_func]},
middleware=[TerminateLoopMiddleware()],
client_kwargs={"middleware": [TerminateLoopMiddleware()]},
stream=True,
):
updates.append(update)
@@ -3389,12 +3389,12 @@ async def test_conversation_id_updated_in_options_between_tool_iterations():
conversation_ids_received: list[str | None] = []
class TrackingChatClient(
ChatMiddlewareLayer,
FunctionInvocationLayer,
ChatMiddlewareLayer,
BaseChatClient,
):
def __init__(self) -> None:
super().__init__(function_middleware=[])
super().__init__(middleware=[])
self.run_responses: list[ChatResponse] = []
self.streaming_responses: list[list[ChatResponseUpdate]] = []
self.call_count: int = 0
@@ -84,8 +84,8 @@ class _MockBaseChatClient(BaseChatClient[Any]):
class FunctionInvokingMockClient(
ChatMiddlewareLayer[Any],
FunctionInvocationLayer[Any],
ChatMiddlewareLayer[Any],
ChatTelemetryLayer[Any],
_MockBaseChatClient,
):
@@ -28,6 +28,7 @@ from agent_framework._middleware import (
FunctionMiddleware,
FunctionMiddlewarePipeline,
MiddlewareTermination,
categorize_middleware,
)
from agent_framework._tools import FunctionTool
@@ -1681,3 +1682,49 @@ def mock_chat_client() -> Any:
client = MagicMock(spec=SupportsChatGetResponse)
client.service_url = MagicMock(return_value="mock://test")
return client
class TestCategorizeMiddleware:
"""Test cases for categorize_middleware."""
def test_categorize_middleware_with_tuple(self) -> None:
"""Test that tuple middleware sources are unpacked, not appended as a single item."""
chat_mw = TestChatMiddleware()
function_mw = TestFunctionMiddleware()
agent_mw = TestAgentMiddleware()
result = categorize_middleware((chat_mw, function_mw, agent_mw))
assert result["chat"] == [chat_mw]
assert result["function"] == [function_mw]
assert result["agent"] == [agent_mw]
def test_categorize_middleware_with_list(self) -> None:
"""Test that list middleware sources are unpacked correctly."""
chat_mw = TestChatMiddleware()
function_mw = TestFunctionMiddleware()
result = categorize_middleware([chat_mw, function_mw])
assert result["chat"] == [chat_mw]
assert result["function"] == [function_mw]
assert result["agent"] == []
def test_categorize_middleware_with_none(self) -> None:
"""Test that None middleware sources are handled."""
result = categorize_middleware(None)
assert result["chat"] == []
assert result["function"] == []
assert result["agent"] == []
def test_categorize_middleware_with_single_item(self) -> None:
"""Test that a single unwrapped middleware item is appended correctly."""
chat_mw = TestChatMiddleware()
result = categorize_middleware(chat_mw)
assert result["chat"] == [chat_mw]
assert result["function"] == []
assert result["agent"] == []
def test_categorize_middleware_with_string_does_not_decompose(self) -> None:
"""Test that a string is not decomposed character-by-character."""
result = categorize_middleware("not_a_middleware")
# String should be treated as a single item, not decomposed into characters
total_items = len(result["chat"]) + len(result["function"]) + len(result["agent"])
assert total_items == 1
assert result["agent"] == ["not_a_middleware"]
@@ -697,6 +697,26 @@ class TestChatAgentFunctionMiddlewareWithTools:
assert function_calls[0].name == "sample_tool_function"
assert function_results[0].call_id == function_calls[0].call_id
def test_agent_middleware_pipeline_cache_reuses_matching_middleware(self) -> None:
"""Test that identical agent middleware sets reuse the cached pipeline."""
@agent_middleware
async def first_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
await call_next()
@agent_middleware
async def second_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
await call_next()
agent = Agent(client=MockBaseChatClient())
first_pipeline = agent._get_agent_middleware_pipeline([first_middleware])
second_pipeline = agent._get_agent_middleware_pipeline([first_middleware])
third_pipeline = agent._get_agent_middleware_pipeline([second_middleware])
assert first_pipeline is second_pipeline
assert third_pipeline is not first_pipeline
async def test_function_middleware_can_access_and_override_custom_kwargs(
self, chat_client_base: "MockBaseChatClient"
) -> None:
@@ -1969,6 +1989,77 @@ class TestChatAgentChatMiddleware:
"agent_middleware_after",
]
async def test_combined_middleware_with_tool_loop(self) -> None:
"""Test Agent middleware ordering when tool calls trigger multiple chat rounds."""
execution_order: list[str] = []
chat_round = 0
client = MockBaseChatClient()
client.run_responses = [
ChatResponse(
messages=[
Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_123",
name="sample_tool_function",
arguments='{"location": "Seattle"}',
)
],
)
]
),
ChatResponse(messages=[Message(role="assistant", text="Final response")]),
]
async def tracking_agent_middleware(
context: AgentContext,
call_next: Callable[[], Awaitable[None]],
) -> None:
execution_order.append("agent_middleware_before")
await call_next()
execution_order.append("agent_middleware_after")
async def tracking_chat_middleware(
context: ChatContext,
call_next: Callable[[], Awaitable[None]],
) -> None:
nonlocal chat_round
chat_round += 1
execution_order.append(f"chat_middleware_before_{chat_round}")
await call_next()
execution_order.append(f"chat_middleware_after_{chat_round}")
async def tracking_function_middleware(
context: FunctionInvocationContext,
call_next: Callable[[], Awaitable[None]],
) -> None:
execution_order.append("function_middleware_before")
await call_next()
execution_order.append("function_middleware_after")
agent = Agent(
client=client,
middleware=[tracking_chat_middleware, tracking_function_middleware, tracking_agent_middleware],
tools=[sample_tool_function],
)
response = await agent.run([Message(role="user", text="test")])
assert response is not None
assert client.call_count == 2
assert response.messages[-1].text == "Final response"
assert execution_order == [
"agent_middleware_before",
"chat_middleware_before_1",
"chat_middleware_after_1",
"function_middleware_before",
"function_middleware_after",
"chat_middleware_before_2",
"chat_middleware_after_2",
"agent_middleware_after",
]
async def test_agent_middleware_can_access_and_override_custom_kwargs(self) -> None:
"""Test that agent middleware can access and override custom parameters like temperature."""
captured_kwargs: dict[str, Any] = {}
@@ -274,7 +274,10 @@ class TestChatMiddleware:
# First call with run-level middleware
messages = [Message(role="user", text="first message")]
response1 = await chat_client_base.get_response(messages, middleware=[counting_middleware])
response1 = await chat_client_base.get_response(
messages,
client_kwargs={"middleware": [counting_middleware]},
)
assert response1 is not None
assert execution_count["count"] == 1
@@ -286,7 +289,10 @@ class TestChatMiddleware:
# Third call with run-level middleware again - should execute
messages = [Message(role="user", text="third message")]
response3 = await chat_client_base.get_response(messages, middleware=[counting_middleware])
response3 = await chat_client_base.get_response(
messages,
client_kwargs={"middleware": [counting_middleware]},
)
assert response3 is not None
assert execution_count["count"] == 2 # Should be 2 now
@@ -335,6 +341,81 @@ class TestChatMiddleware:
assert modified_kwargs["new_param"] == "added_by_middleware"
assert modified_kwargs["custom_param"] == "test_value" # Should still be there
def test_chat_middleware_pipeline_cache_reuses_matching_middleware(
self,
chat_client_base: "MockBaseChatClient",
) -> None:
"""Test that identical chat middleware sets reuse the cached pipeline."""
@chat_middleware
async def first_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
await call_next()
@chat_middleware
async def second_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
await call_next()
first_pipeline = chat_client_base._get_chat_middleware_pipeline([first_middleware])
second_pipeline = chat_client_base._get_chat_middleware_pipeline([first_middleware])
third_pipeline = chat_client_base._get_chat_middleware_pipeline([second_middleware])
assert first_pipeline is second_pipeline
assert third_pipeline is not first_pipeline
def test_chat_middleware_pipeline_cache_includes_base_middleware(
self,
chat_client_base: "MockBaseChatClient",
) -> None:
"""Test that chat middleware cache key includes base middleware to prevent incorrect reuse."""
@chat_middleware
async def base_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
await call_next()
@chat_middleware
async def runtime_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
await call_next()
# Without base middleware
pipeline_no_base = chat_client_base._get_chat_middleware_pipeline([runtime_middleware])
# With base middleware
chat_client_base.chat_middleware = [base_middleware]
pipeline_with_base = chat_client_base._get_chat_middleware_pipeline([runtime_middleware])
assert pipeline_with_base is not pipeline_no_base
def test_function_middleware_pipeline_cache_reuses_matching_middleware(
self,
chat_client_base: "MockBaseChatClient",
) -> None:
"""Test that identical function middleware sets reuse the cached pipeline."""
@function_middleware
async def base_middleware(context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]) -> None:
await call_next()
@function_middleware
async def first_runtime_middleware(
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
await call_next()
@function_middleware
async def second_runtime_middleware(
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
await call_next()
chat_client_base.function_middleware = [base_middleware]
first_pipeline = chat_client_base._get_function_middleware_pipeline([first_runtime_middleware])
second_pipeline = chat_client_base._get_function_middleware_pipeline([first_runtime_middleware])
third_pipeline = chat_client_base._get_function_middleware_pipeline([second_runtime_middleware])
assert first_pipeline is second_pipeline
assert third_pipeline is not first_pipeline
async def test_function_middleware_registration_on_chat_client(
self, chat_client_base: "MockBaseChatClient"
) -> None:
@@ -450,7 +531,9 @@ class TestChatMiddleware:
# Execute the chat client directly with run-level middleware and tools
messages = [Message(role="user", text="What's the weather in New York?")]
response = await client.get_response(
messages, options={"tools": [sample_tool_wrapped]}, middleware=[run_level_function_middleware]
messages,
options={"tools": [sample_tool_wrapped]},
client_kwargs={"middleware": [run_level_function_middleware]},
)
# Verify response
@@ -463,3 +546,156 @@ class TestChatMiddleware:
"run_level_function_middleware_before",
"run_level_function_middleware_after",
]
async def test_run_level_chat_and_function_middleware_split_per_function_loop_round(self) -> None:
"""Test mixed run-level middleware is split so chat middleware runs per model call."""
execution_order: list[str] = []
chat_round = 0
@chat_middleware
async def run_level_chat_middleware(
context: ChatContext,
call_next: Callable[[], Awaitable[None]],
) -> None:
nonlocal chat_round
chat_round += 1
execution_order.append(f"chat_middleware_before_{chat_round}")
await call_next()
execution_order.append(f"chat_middleware_after_{chat_round}")
@function_middleware
async def run_level_function_middleware(
context: FunctionInvocationContext,
call_next: Callable[[], Awaitable[None]],
) -> None:
execution_order.append("function_middleware_before")
await call_next()
execution_order.append("function_middleware_after")
def sample_tool(location: str) -> str:
"""Get weather for a location."""
return f"Weather in {location}: sunny"
sample_tool_wrapped = FunctionTool(
func=sample_tool,
name="sample_tool",
description="Get weather for a location",
approval_mode="never_require",
)
client = MockBaseChatClient()
client.run_responses = [
ChatResponse(
messages=[
Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_3",
name="sample_tool",
arguments={"location": "Seattle"},
)
],
)
]
),
ChatResponse(messages=[Message(role="assistant", text="Based on the weather data, it's sunny!")]),
]
response = await client.get_response(
[Message(role="user", text="What's the weather in Seattle?")],
options={"tools": [sample_tool_wrapped]},
client_kwargs={"middleware": [run_level_chat_middleware, run_level_function_middleware]},
)
assert response is not None
assert client.call_count == 2
assert response.messages[-1].text == "Based on the weather data, it's sunny!"
assert execution_order == [
"chat_middleware_before_1",
"chat_middleware_after_1",
"function_middleware_before",
"function_middleware_after",
"chat_middleware_before_2",
"chat_middleware_after_2",
]
async def test_run_level_chat_and_function_middleware_split_per_function_loop_round_streaming(self) -> None:
"""Test mixed run-level middleware is split so chat middleware runs per model call in streaming mode."""
execution_order: list[str] = []
chat_round = 0
@chat_middleware
async def run_level_chat_middleware(
context: ChatContext,
call_next: Callable[[], Awaitable[None]],
) -> None:
nonlocal chat_round
chat_round += 1
execution_order.append(f"chat_middleware_before_{chat_round}")
await call_next()
execution_order.append(f"chat_middleware_after_{chat_round}")
@function_middleware
async def run_level_function_middleware(
context: FunctionInvocationContext,
call_next: Callable[[], Awaitable[None]],
) -> None:
execution_order.append("function_middleware_before")
await call_next()
execution_order.append("function_middleware_after")
def sample_tool(location: str) -> str:
"""Get weather for a location."""
return f"Weather in {location}: sunny"
sample_tool_wrapped = FunctionTool(
func=sample_tool,
name="sample_tool",
description="Get weather for a location",
approval_mode="never_require",
)
client = MockBaseChatClient()
client.streaming_responses = [
[
ChatResponseUpdate(
contents=[
Content.from_function_call(
call_id="call_3",
name="sample_tool",
arguments='{"location": "Seattle"}',
)
],
role="assistant",
finish_reason="tool_calls",
),
],
[
ChatResponseUpdate(
contents=[Content.from_text("Based on the weather data, it's sunny!")],
role="assistant",
finish_reason="stop",
),
],
]
updates: list[ChatResponseUpdate] = []
async for update in client.get_response(
[Message(role="user", text="What's the weather in Seattle?")],
options={"tools": [sample_tool_wrapped]},
client_kwargs={"middleware": [run_level_chat_middleware, run_level_function_middleware]},
stream=True,
):
updates.append(update)
assert client.call_count == 2
assert len(updates) > 0
assert execution_order == [
"chat_middleware_before_1",
"chat_middleware_after_1",
"function_middleware_before",
"function_middleware_after",
"chat_middleware_before_2",
"chat_middleware_after_2",
]
@@ -2437,7 +2437,7 @@ def test_capture_response(span_exporter: InMemorySpanExporter):
async def test_layer_ordering_span_sequence_with_function_calling(span_exporter: InMemorySpanExporter):
"""Test that with correct layer ordering, spans appear in the expected sequence.
When using the correct layer ordering (ChatMiddlewareLayer, FunctionInvocationLayer,
When using the correct layer ordering (FunctionInvocationLayer, ChatMiddlewareLayer,
ChatTelemetryLayer, BaseChatClient), the spans should appear in this order:
1. First 'chat' span (initial LLM call that returns function call)
2. 'execute_tool' span (function invocation)
@@ -2454,11 +2454,11 @@ async def test_layer_ordering_span_sequence_with_function_calling(span_exporter:
def get_weather(location: str) -> str:
return f"The weather in {location} is sunny."
# Correct layer ordering: FunctionInvocationLayer BEFORE ChatTelemetryLayer
# This ensures each inner LLM call gets its own telemetry span
# Correct layer ordering: FunctionInvocationLayer BEFORE ChatMiddlewareLayer BEFORE ChatTelemetryLayer
# This ensures each inner LLM call traverses chat middleware and still gets its own telemetry span
class MockChatClientWithLayers(
ChatMiddlewareLayer,
FunctionInvocationLayer,
ChatMiddlewareLayer,
ChatTelemetryLayer,
BaseChatClient,
):
@@ -130,8 +130,8 @@ class FoundryLocalSettings(TypedDict, total=False):
class FoundryLocalClient(
ChatMiddlewareLayer[FoundryLocalChatOptionsT],
FunctionInvocationLayer[FoundryLocalChatOptionsT],
ChatMiddlewareLayer[FoundryLocalChatOptionsT],
ChatTelemetryLayer[FoundryLocalChatOptionsT],
RawOpenAIChatClient[FoundryLocalChatOptionsT],
Generic[FoundryLocalChatOptionsT],
@@ -273,7 +273,7 @@ def _load_gaia_local(repo_dir: Path, wanted_levels: list[int] | None = None, max
for p in parquet_files:
try:
import pyarrow.parquet as pq
import pyarrow.parquet as pq # type: ignore[reportMissingImports]
pq_any = cast(Any, pq)
table: Any = pq_any.read_table(p)
@@ -7,8 +7,8 @@ from __future__ import annotations
import importlib.metadata
from agent_framework.observability import enable_instrumentation
from agentlightning.tracer import (
AgentOpsTracer, # pyright: ignore[reportMissingImports] # type: ignore[import-not-found]
from agentlightning.tracer import ( # type: ignore[reportMissingImports]
AgentOpsTracer, # type: ignore[reportMissingImports, import-not-found]
)
try:
@@ -285,8 +285,8 @@ logger = logging.getLogger("agent_framework.ollama")
class OllamaChatClient(
ChatMiddlewareLayer[OllamaChatOptionsT],
FunctionInvocationLayer[OllamaChatOptionsT],
ChatMiddlewareLayer[OllamaChatOptionsT],
ChatTelemetryLayer[OllamaChatOptionsT],
BaseChatClient[OllamaChatOptionsT],
):
@@ -33,7 +33,7 @@ from agent_framework_orchestrations._handoff import (
from agent_framework_orchestrations._orchestrator_helpers import clean_conversation_for_handoff
class MockChatClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]):
class MockChatClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
"""Mock chat client for testing handoff workflows."""
def __init__(
@@ -134,7 +134,7 @@ class MockHandoffAgent(Agent):
super().__init__(client=MockChatClient(name=name, handoff_to=handoff_to), name=name, id=name)
class ContextAwareRefundClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]):
class ContextAwareRefundClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
"""Mock client that expects prior user context to remain available on resume."""
def __init__(self) -> None:
@@ -298,7 +298,7 @@ async def test_tool_approval_responses_are_not_replayed_from_history() -> None:
execution_count += 1
return "ok"
class ApprovalReplayClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]):
class ApprovalReplayClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
def __init__(self) -> None:
ChatMiddlewareLayer.__init__(self)
FunctionInvocationLayer.__init__(self)
@@ -383,7 +383,7 @@ async def test_handoff_resume_preserves_approval_function_call_for_stateless_run
def submit_refund() -> str:
return "ok"
class StrictStatelessApprovalClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]):
class StrictStatelessApprovalClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
def __init__(self) -> None:
ChatMiddlewareLayer.__init__(self)
FunctionInvocationLayer.__init__(self)
@@ -475,7 +475,7 @@ async def test_handoff_resume_preserves_approval_function_call_for_stateless_run
async def test_handoff_replay_serializes_handoff_function_results() -> None:
"""Returning to the same agent must not replay dict tool outputs."""
class ReplaySafeHandoffClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]):
class ReplaySafeHandoffClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
def __init__(self, name: str, handoff_sequence: list[str | None]) -> None:
ChatMiddlewareLayer.__init__(self)
FunctionInvocationLayer.__init__(self)
@@ -550,7 +550,7 @@ async def test_handoff_resume_preserves_approved_tool_output_for_stateless_runs(
def submit_refund() -> str:
return "submitted"
class RefundReplayClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]):
class RefundReplayClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
def __init__(self) -> None:
ChatMiddlewareLayer.__init__(self)
FunctionInvocationLayer.__init__(self)
@@ -608,7 +608,7 @@ async def test_handoff_resume_preserves_approved_tool_output_for_stateless_runs(
return _get()
class OrderReplayClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]):
class OrderReplayClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
def __init__(self) -> None:
ChatMiddlewareLayer.__init__(self)
FunctionInvocationLayer.__init__(self)
@@ -907,7 +907,7 @@ async def test_handoff_async_termination_condition() -> None:
async def test_handoff_terminates_without_request_info_when_latest_response_meets_condition() -> None:
"""Termination triggered by the latest assistant response should not emit request_info."""
class FinalizingClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]):
class FinalizingClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]):
def __init__(self) -> None:
ChatMiddlewareLayer.__init__(self)
FunctionInvocationLayer.__init__(self)
+6 -4
View File
@@ -114,10 +114,11 @@ class RetryingAzureOpenAIChatClient(AzureOpenAIChatClient):
class RateLimitRetryMiddleware(ChatMiddleware):
"""Chat middleware that retries the full request pipeline on rate limit errors.
"""Chat middleware that retries a single model-call pipeline on rate limit errors.
Register this middleware on an agent (or at the run level) to automatically
retry any call_next() invocation that raises RateLimitError.
retry any chat-model call that raises RateLimitError. In tool-loop scenarios,
the middleware applies independently to each inner model call.
"""
def __init__(self, *, max_attempts: int = RETRY_ATTEMPTS) -> None:
@@ -154,8 +155,9 @@ async def rate_limit_retry_middleware(
"""Function-based chat middleware that retries on rate limit errors.
Wrap call_next() with a tenacity @retry decorator so any RateLimitError
raised during model inference triggers an automatic retry with exponential
back-off.
raised during a single model call triggers an automatic retry with exponential
back-off. In tool-loop scenarios, the middleware applies independently to
each inner model call.
"""
@retry(
@@ -29,7 +29,10 @@ else:
Custom Chat Client Implementation Example
This sample demonstrates implementing a custom chat client and optionally composing
middleware, telemetry, and function invocation layers explicitly.
middleware, telemetry, and function invocation layers explicitly. The recommended
layer order is `FunctionInvocationLayer -> ChatMiddlewareLayer -> ChatTelemetryLayer`
so chat middleware runs within each tool-loop iteration while telemetry records
per-call spans without middleware latency.
"""
@@ -124,9 +127,9 @@ class EchoingChatClient(BaseChatClient[OptionsT]):
class EchoingChatClientWithLayers( # type: ignore[misc]
FunctionInvocationLayer[OptionsT],
ChatMiddlewareLayer[OptionsT],
ChatTelemetryLayer[OptionsT],
FunctionInvocationLayer[OptionsT],
EchoingChatClient,
):
"""Echoing chat client that explicitly composes middleware, telemetry, and function layers."""
@@ -0,0 +1,37 @@
# Middleware samples
This folder contains focused middleware samples for `Agent`, chat clients, tools, sessions, and runtime context behavior.
## Files
| File | Description |
|------|-------------|
| [`agent_and_run_level_middleware.py`](./agent_and_run_level_middleware.py) | Demonstrates combining agent-level and run-level middleware. |
| [`chat_middleware.py`](./chat_middleware.py) | Shows class-based and function-based chat middleware that can observe, modify, and override model calls. |
| [`class_based_middleware.py`](./class_based_middleware.py) | Shows class-based agent and function middleware. |
| [`decorator_middleware.py`](./decorator_middleware.py) | Demonstrates middleware registration with decorators. |
| [`exception_handling_with_middleware.py`](./exception_handling_with_middleware.py) | Shows how middleware can handle failures and recover cleanly. |
| [`function_based_middleware.py`](./function_based_middleware.py) | Shows function-based agent and function middleware. |
| [`middleware_termination.py`](./middleware_termination.py) | Demonstrates stopping a middleware pipeline early. |
| [`override_result_with_middleware.py`](./override_result_with_middleware.py) | Shows how middleware can replace the normal result. |
| [`runtime_context_delegation.py`](./runtime_context_delegation.py) | Demonstrates delegating work with runtime context data. |
| [`session_behavior_middleware.py`](./session_behavior_middleware.py) | Shows how middleware interacts with session-backed runs. |
| [`shared_state_middleware.py`](./shared_state_middleware.py) | Demonstrates sharing mutable state across middleware invocations. |
| [`usage_tracking_middleware.py`](./usage_tracking_middleware.py) | Demonstrates one chat middleware function that tracks per-call usage in non-streaming and streaming tool-loop runs. |
## Running the usage tracking sample
The new usage tracking sample uses `OpenAIResponsesClient`, so set the usual OpenAI responses environment variables first:
```bash
export OPENAI_API_KEY="your-openai-api-key"
export OPENAI_RESPONSES_MODEL_ID="gpt-4.1-mini"
```
Then run:
```bash
uv run samples/02-agents/middleware/usage_tracking_middleware.py
```
The sample forces a tool call so you can see middleware output for each inner model call in both non-streaming and streaming modes.
@@ -51,10 +51,10 @@ Agent Middleware Execution Order:
- Run middleware wraps only the agent for that specific run
- Each middleware can modify the context before AND after calling next()
Note: Function and chat middleware (e.g., ``function_logging_middleware``) execute
during tool invocation *inside* the agent execution, not in the outer agent-middleware
chain shown above. They follow the same ordering principle: agent-level function/chat
middleware runs before run-level function/chat middleware.
Note: Function middleware executes during tool invocation, and chat middleware
executes around each model call inside the agent execution, not in the outer
agent-middleware chain shown above. They follow the same ordering principle:
agent-level function/chat middleware runs before run-level function/chat middleware.
"""
@@ -0,0 +1,185 @@
# Copyright (c) Microsoft. All rights reserved.
"""
This sample demonstrates a single chat middleware that tracks per-model-call usage
for both non-streaming and streaming tool-loop runs.
"""
import asyncio
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated
from agent_framework import (
Agent,
ChatContext,
ChatResponse,
ChatResponseUpdate,
ResponseStream,
chat_middleware,
tool,
)
from agent_framework.openai import OpenAIResponsesClient
from dotenv import load_dotenv
from pydantic import Field
# Load environment variables from .env file
load_dotenv()
NON_STREAMING_CALL_COUNT = 0
STREAMING_CALL_COUNT = 0
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production;
# see samples/02-agents/tools/function_tool_with_approval.py
# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
@tool(approval_mode="never_require")
def get_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
"""Get the weather for a given location."""
conditions = ["sunny", "cloudy", "rainy", "stormy"]
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
def _reset_usage_counters() -> None:
"""Reset call counters between sample runs."""
global NON_STREAMING_CALL_COUNT, STREAMING_CALL_COUNT
NON_STREAMING_CALL_COUNT = 0
STREAMING_CALL_COUNT = 0
def _create_agent(
) -> Agent:
"""Create the shared agent used by both demonstrations."""
return Agent(
client=OpenAIResponsesClient(),
instructions=(
"You are a weather assistant. Always call the weather tool before answering weather questions, "
"then summarize the tool result in one short paragraph."
),
tools=[get_weather],
middleware=[print_usage],
)
@chat_middleware
async def print_usage(
context: ChatContext,
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Print usage for each inner model call in both non-streaming and streaming runs."""
global NON_STREAMING_CALL_COUNT, STREAMING_CALL_COUNT
if context.stream:
STREAMING_CALL_COUNT += 1
call_number = STREAMING_CALL_COUNT
usage_seen_in_updates = False
def capture_usage_update(update: ChatResponseUpdate) -> ChatResponseUpdate:
nonlocal usage_seen_in_updates
for content in update.contents:
if content.type == "usage":
usage_seen_in_updates = True
print(f"\n[Streaming model call #{call_number}] Usage update: {content.usage_details}")
return update
def capture_final_usage(result: ChatResponse) -> ChatResponse:
if not usage_seen_in_updates and result.usage_details:
print(f"\n[Streaming model call #{call_number}] Final usage: {result.usage_details}")
return result
context.stream_transform_hooks.append(capture_usage_update)
context.stream_result_hooks.append(capture_final_usage)
await call_next()
return
NON_STREAMING_CALL_COUNT += 1
call_number = NON_STREAMING_CALL_COUNT
await call_next()
response = context.result
if isinstance(response, ChatResponse) and response.usage_details:
print(f"[Non-streaming model call #{call_number}] Usage: {response.usage_details}")
async def non_streaming_usage_example() -> None:
"""Run the non-streaming usage tracking example."""
_reset_usage_counters()
print("\n=== Non-streaming per-call usage tracking ===")
# 1. Create an agent with middleware that prints usage after each inner model call.
agent = _create_agent()
# 2. Run a weather question and require a tool call so the function loop performs multiple model calls.
query = "What is the weather in Seattle, and should I bring an umbrella?"
print(f"User: {query}")
result = await agent.run(
query,
options={"tool_choice": "required"},
)
# 3. Print the final user-visible answer after the middleware already logged per-call usage.
print(f"Assistant: {result.text}")
async def streaming_usage_example() -> None:
"""Run the streaming usage tracking example."""
_reset_usage_counters()
print("\n=== Streaming per-call usage tracking ===")
# 1. Create an agent with middleware that watches streaming usage for each inner model call.
agent = _create_agent()
# 2. Start a streaming run and force tool usage so the function loop performs multiple model calls.
query = "What is the weather in Portland, and should I bring a jacket?"
print(f"User: {query}")
print("Assistant: ", end="", flush=True)
stream: ResponseStream = agent.run(
query,
stream=True,
options={"tool_choice": "required"},
)
# 3. Consume the stream normally while the middleware reports usage in the background.
async for update in stream:
if update.text:
print(update.text, end="", flush=True)
print()
# 4. Finalize the stream so you can inspect the final response if needed.
final_response = await stream.get_final_response()
print(f"Final assistant message: {final_response.text}")
async def main() -> None:
"""Run both usage tracking demonstrations."""
print("=== Usage Tracking Middleware Example ===")
await non_streaming_usage_example()
await streaming_usage_example()
if __name__ == "__main__":
asyncio.run(main())
"""
Sample output:
=== Usage Tracking Middleware Example ===
=== Non-streaming per-call usage tracking ===
User: What is the weather in Seattle, and should I bring an umbrella?
[Non-streaming model call #1] Usage: {'input_tokens': ..., 'output_tokens': ..., ...}
[Non-streaming model call #2] Usage: {'input_tokens': ..., 'output_tokens': ..., ...}
Assistant: Based on the weather in Seattle, ...
=== Streaming per-call usage tracking ===
User: What is the weather in Portland, and should I bring a jacket?
Assistant: Based on the weather in Portland, ...
[Streaming model call #1] Usage update: {'input_tokens': ..., 'output_tokens': ..., ...}
[Streaming model call #2] Usage update: {'input_tokens': ..., 'output_tokens': ..., ...}
Final assistant message: Based on the weather in Portland, ...
"""
@@ -96,10 +96,16 @@ async def run_chat_client() -> None:
stream: Whether to use streaming for the plugin
Remarks:
When function calling is outside the open telemetry loop
each of the call to the model is handled as a seperate span,
while when the open telemetry is put last, a single span
is shown, which might include one or more rounds of function calling.
By default, the built-in non-`Raw...Client` chat clients already compose
the layers in this order:
`FunctionInvocationLayer -> ChatMiddlewareLayer -> ChatTelemetryLayer -> Raw/Base client`.
When `FunctionInvocationLayer` is outside `ChatTelemetryLayer`,
each call to the model is handled as a separate span.
Keep `ChatMiddlewareLayer` outside telemetry
so middleware latency does not skew those timings.
By contrast, when telemetry is placed outside the function loop,
a single span can cover one or more rounds of function calling.
So for the scenario below, you should see the following:
@@ -71,10 +71,12 @@ async def run_chat_client(client: "SupportsChatGetResponse", stream: bool = Fals
stream: Whether to use streaming for the plugin
Remarks:
When function calling is outside the open telemetry loop
each of the call to the model is handled as a separate span,
while when the open telemetry is put last, a single span
is shown, which might include one or more rounds of function calling.
When `FunctionInvocationLayer` is outside `ChatTelemetryLayer`,
each call to the model is handled as a separate span.
If `ChatMiddlewareLayer` is present, keep it outside telemetry
so middleware latency does not skew those timings.
By contrast, when telemetry is placed outside the function loop,
a single span can cover one or more rounds of function calling.
So for the scenario below, you should see the following:
@@ -37,17 +37,17 @@ The framework provides `Raw...Client` classes (e.g., `RawOpenAIChatClient`, `Raw
There is a defined ordering for applying layers that you should follow:
1. **ChatMiddlewareLayer** - Should be applied **first** because it also prepares function middleware
2. **FunctionInvocationLayer** - Handles tool/function calling loop
3. **ChatTelemetryLayer** - Must be **inside** the function calling loop for correct per-call telemetry
1. **FunctionInvocationLayer** - Handles the tool/function calling loop and should stay outermost
2. **ChatMiddlewareLayer** - Wraps each model call in the loop and stays outside telemetry
3. **ChatTelemetryLayer** - Must be inside the function calling loop so each model call gets its own telemetry span
4. **Raw...Client** - The base implementation (e.g., `RawOpenAIChatClient`)
Example of correct layer composition:
```python
class MyCustomClient(
ChatMiddlewareLayer[TOptions],
FunctionInvocationLayer[TOptions],
ChatMiddlewareLayer[TOptions],
ChatTelemetryLayer[TOptions],
RawOpenAIChatClient[TOptions], # or BaseChatClient for custom implementations
Generic[TOptions],
@@ -16,7 +16,6 @@ from azure.ai.agentserver.agentframework import from_agent_framework
from azure.identity.aio import AzureCliCredential, ManagedIdentityCredential
from dotenv import load_dotenv
load_dotenv(override=True)
# Configure these for your Foundry project