Python: Added chat middleware and more examples (#883)

* Added example with stateful middleware

* Added chat middleware

* Updated middleware example with override scenario

* Small revert

* Small fixes

* Added kwargs to context objects

* Added README

* Added function middleware to chat client

* Small refactoring

* Reverted example files

* Made MiddlewareWrapper generic

* Added Middleware exception

* Small refactoring

* Small fix
This commit is contained in:
Dmytro Struk
2025-09-26 08:10:56 -07:00
committed by GitHub
Unverified
parent 863c8d7471
commit eec7f192eb
19 changed files with 2667 additions and 267 deletions
@@ -10,7 +10,13 @@ from pydantic import BaseModel, Field
from ._logging import get_logger
from ._mcp import MCPTool
from ._memory import AggregateContextProvider, ContextProvider
from ._middleware import Middleware
from ._middleware import (
ChatMiddleware,
ChatMiddlewareCallable,
FunctionMiddleware,
FunctionMiddlewareCallable,
Middleware,
)
from ._pydantic import AFBaseModel
from ._threads import ChatMessageStore
from ._tools import ToolProtocol
@@ -189,6 +195,14 @@ class BaseChatClient(AFBaseModel, ABC):
"""Base class for chat clients."""
additional_properties: dict[str, Any] = Field(default_factory=dict)
middleware: (
ChatMiddleware
| ChatMiddlewareCallable
| FunctionMiddleware
| FunctionMiddlewareCallable
| list[ChatMiddleware | ChatMiddlewareCallable | FunctionMiddleware | FunctionMiddlewareCallable]
| None
) = None
OTEL_PROVIDER_NAME: str = "unknown"
# This is used for OTel setup, should be overridden in subclasses
@@ -346,13 +360,7 @@ class BaseChatClient(AFBaseModel, ABC):
prepped_messages = self.prepare_messages(messages)
self._prepare_tool_choice(chat_options=chat_options)
# Remove middleware pipeline from kwargs as it's only used by function invocation wrappers
if "_function_middleware_pipeline" in kwargs:
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "_function_middleware_pipeline"}
else:
filtered_kwargs = kwargs
return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **filtered_kwargs)
return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **kwargs)
async def get_streaming_response(
self,
@@ -432,14 +440,8 @@ class BaseChatClient(AFBaseModel, ABC):
prepped_messages = self.prepare_messages(messages)
self._prepare_tool_choice(chat_options=chat_options)
# Remove middleware pipeline from kwargs as it's only used by function invocation wrappers
if "_function_middleware_pipeline" in kwargs:
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "_function_middleware_pipeline"}
else:
filtered_kwargs = kwargs
async for update in self._inner_get_streaming_response(
messages=prepped_messages, chat_options=chat_options, **filtered_kwargs
messages=prepped_messages, chat_options=chat_options, **kwargs
):
yield update
File diff suppressed because it is too large Load Diff
+23 -4
View File
@@ -607,6 +607,7 @@ async def _auto_invoke_function(
middleware_context = FunctionInvocationContext(
function=tool,
arguments=args,
kwargs=custom_args or {},
)
async def final_function_handler(context_obj: Any) -> Any:
@@ -721,8 +722,16 @@ def _handle_function_calls_response(
**kwargs: Any,
) -> "ChatResponse":
from ._clients import prepare_messages
from ._middleware import extract_and_merge_function_middleware
from ._types import ChatMessage, ChatOptions, FunctionCallContent, FunctionResultContent
# Extract and merge function middleware from chat client with kwargs pipeline
extract_and_merge_function_middleware(self, kwargs)
# Extract the middleware pipeline before calling the underlying function
# because the underlying function may not preserve it in kwargs
stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline")
prepped_messages = prepare_messages(messages)
response: "ChatResponse | None" = None
fcc_messages: "list[ChatMessage]" = []
@@ -746,8 +755,9 @@ def _handle_function_calls_response(
if not tools and (chat_options := kwargs.get("chat_options")) and isinstance(chat_options, ChatOptions):
tools = chat_options.tools
if function_calls and tools:
# Extract function middleware pipeline from kwargs if available
middleware_pipeline = kwargs.get("_function_middleware_pipeline")
# Use the stored middleware pipeline instead of extracting from kwargs
# because kwargs may have been modified by the underlying function
middleware_pipeline = stored_middleware_pipeline
function_results = await execute_function_calls(
custom_args=kwargs,
attempt_idx=attempt_idx,
@@ -820,8 +830,16 @@ def _handle_function_calls_streaming_response(
) -> AsyncIterable["ChatResponseUpdate"]:
"""Wrap the inner get streaming response method to handle tool calls."""
from ._clients import prepare_messages
from ._middleware import extract_and_merge_function_middleware
from ._types import ChatMessage, ChatOptions, ChatResponse, ChatResponseUpdate, FunctionCallContent
# Extract and merge function middleware from chat client with kwargs pipeline
extract_and_merge_function_middleware(self, kwargs)
# Extract the middleware pipeline before calling the underlying function
# because the underlying function may not preserve it in kwargs
stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline")
prepped_messages = prepare_messages(messages)
for attempt_idx in range(max_iterations):
all_updates: list["ChatResponseUpdate"] = []
@@ -858,8 +876,9 @@ def _handle_function_calls_streaming_response(
tools = chat_options.tools
if function_calls and tools:
# Extract function middleware pipeline from kwargs if available
middleware_pipeline = kwargs.get("_function_middleware_pipeline")
# Use the stored middleware pipeline instead of extracting from kwargs
# because kwargs may have been modified by the underlying function
middleware_pipeline = stored_middleware_pipeline
function_results = await execute_function_calls(
custom_args=kwargs,
attempt_idx=attempt_idx,
@@ -13,16 +13,18 @@ from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from pydantic import SecretStr, ValidationError
from pydantic.networks import AnyUrl
from .._tools import use_function_invocation
from .._types import (
from agent_framework import (
ChatResponse,
ChatResponseUpdate,
CitationAnnotation,
TextContent,
use_chat_middleware,
use_function_invocation,
)
from ..exceptions import ServiceInitializationError
from ..observability import use_observability
from ..openai._chat_client import OpenAIBaseChatClient
from agent_framework.exceptions import ServiceInitializationError
from agent_framework.observability import use_observability
from agent_framework.openai._chat_client import OpenAIBaseChatClient
from ._shared import (
AzureOpenAIConfigMixin,
AzureOpenAISettings,
@@ -41,6 +43,7 @@ TAzureOpenAIChatClient = TypeVar("TAzureOpenAIChatClient", bound="AzureOpenAICha
@use_function_invocation
@use_observability
@use_chat_middleware
class AzureOpenAIChatClient(AzureOpenAIConfigMixin, OpenAIBaseChatClient):
"""Azure OpenAI Chat completion class."""
@@ -9,10 +9,11 @@ from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI
from pydantic import SecretStr, ValidationError
from pydantic.networks import AnyUrl
from .._tools import use_function_invocation
from ..exceptions import ServiceInitializationError
from ..observability import use_observability
from ..openai._responses_client import OpenAIBaseResponsesClient
from agent_framework import use_chat_middleware, use_function_invocation
from agent_framework.exceptions import ServiceInitializationError
from agent_framework.observability import use_observability
from agent_framework.openai._responses_client import OpenAIBaseResponsesClient
from ._shared import (
AzureOpenAIConfigMixin,
AzureOpenAISettings,
@@ -23,6 +24,7 @@ TAzureOpenAIResponsesClient = TypeVar("TAzureOpenAIResponsesClient", bound="Azur
@use_observability
@use_function_invocation
@use_chat_middleware
class AzureOpenAIResponsesClient(AzureOpenAIConfigMixin, OpenAIBaseResponsesClient):
"""Azure Responses completion class."""
@@ -128,3 +128,9 @@ class AdditionItemMismatch(AgentFrameworkException):
"""An error occurred while adding two types."""
pass
class MiddlewareException(AgentFrameworkException):
"""An error occurred during middleware execution."""
pass
@@ -21,6 +21,7 @@ from openai.types.beta.threads.runs import RunStep
from pydantic import Field, PrivateAttr, SecretStr, ValidationError
from .._clients import BaseChatClient
from .._middleware import use_chat_middleware
from .._tools import AIFunction, HostedCodeInterpreterTool, HostedFileSearchTool, use_function_invocation
from .._types import (
ChatMessage,
@@ -52,6 +53,7 @@ __all__ = ["OpenAIAssistantsClient"]
@use_function_invocation
@use_observability
@use_chat_middleware
class OpenAIAssistantsClient(OpenAIConfigMixin, BaseChatClient):
"""OpenAI Assistants client."""
@@ -18,6 +18,7 @@ from pydantic import BaseModel, SecretStr, ValidationError
from .._clients import BaseChatClient
from .._logging import get_logger
from .._middleware import use_chat_middleware
from .._tools import AIFunction, HostedWebSearchTool, ToolProtocol, use_function_invocation
from .._types import (
ChatMessage,
@@ -452,6 +453,7 @@ TOpenAIChatClient = TypeVar("TOpenAIChatClient", bound="OpenAIChatClient")
@use_function_invocation
@use_observability
@use_chat_middleware
class OpenAIChatClient(OpenAIConfigMixin, OpenAIBaseChatClient):
"""OpenAI Chat completion class."""
@@ -26,6 +26,7 @@ from pydantic import BaseModel, SecretStr, ValidationError
from .._clients import BaseChatClient
from .._logging import get_logger
from .._middleware import use_chat_middleware
from .._tools import (
AIFunction,
HostedCodeInterpreterTool,
@@ -933,6 +934,7 @@ TOpenAIResponsesClient = TypeVar("TOpenAIResponsesClient", bound="OpenAIResponse
@use_function_invocation
@use_observability
@use_chat_middleware
class OpenAIResponsesClient(OpenAIConfigMixin, OpenAIBaseResponsesClient):
"""OpenAI Responses client class."""