Files
agent-framework/python/packages/main/agent_framework/_middleware.py
T
Dmytro Struk eec7f192eb Python: Added chat middleware and more examples (#883)
* Added example with stateful middleware

* Added chat middleware

* Updated middleware example with override scenario

* Small revert

* Small fixes

* Added kwargs to context objects

* Added README

* Added function middleware to chat client

* Small refactoring

* Reverted example files

* Made MiddlewareWrapper generic

* Added Middleware exception

* Small refactoring

* Small fix
2025-09-26 15:10:56 +00:00

1186 lines
46 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
import inspect
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable, Awaitable, Callable
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar
from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage
from .exceptions import MiddlewareException
if TYPE_CHECKING:
from collections.abc import AsyncIterable, MutableSequence
from pydantic import BaseModel
from ._agents import AgentProtocol
from ._clients import ChatClientProtocol
from ._tools import AIFunction
from ._types import ChatOptions, ChatResponse, ChatResponseUpdate
# Agent class type accepted and returned by the use_agent_middleware class decorator.
TAgent = TypeVar("TAgent", bound="AgentProtocol")
# Chat client class type accepted and returned by the use_chat_middleware class decorator.
TChatClient = TypeVar("TChatClient", bound="ChatClientProtocol")
# Context type handled by the generic MiddlewareWrapper.
TContext = TypeVar("TContext")
class MiddlewareType(Enum):
    """Closed set of middleware categories handled by this module.

    Each member's value is the lowercase category name.
    """

    AGENT = "agent"
    FUNCTION = "function"
    CHAT = "chat"
# Names exported by a star-import of this module.
__all__ = [
    "AgentMiddleware",
    "AgentRunContext",
    "ChatContext",
    "ChatMiddleware",
    "FunctionInvocationContext",
    "FunctionMiddleware",
    "Middleware",
    "agent_middleware",
    "chat_middleware",
    "function_middleware",
    "use_agent_middleware",
    "use_chat_middleware",
]
@dataclass
class AgentRunContext:
    """State shared by all agent middleware during a single agent run.

    Attributes:
        agent: Agent whose run is being intercepted.
        messages: Messages handed to the agent for this run.
        is_streaming: True when the run goes through the streaming pipeline.
        metadata: Free-form dictionary that agent middleware can use to pass
            data to one another.
        result: Outcome of the run. It is populated once ``next()`` has executed
            the agent, and middleware may assign it to replace that outcome.
            Non-streaming runs expect an ``AgentRunResponse``; streaming runs
            expect an ``AsyncIterable[AgentRunResponseUpdate]``.
        terminate: Set to True to stop execution after the current middleware,
            as soon as control returns to the framework.
        kwargs: Extra keyword arguments given to the agent run method.
    """

    agent: "AgentProtocol"
    messages: list[ChatMessage]
    is_streaming: bool = False
    metadata: dict[str, Any] = field(default_factory=dict)  # type: ignore
    result: AgentRunResponse | AsyncIterable[AgentRunResponseUpdate] | None = None
    terminate: bool = False
    kwargs: dict[str, Any] = field(default_factory=dict)  # type: ignore
@dataclass
class FunctionInvocationContext:
    """State shared by all function middleware during a single function call.

    Attributes:
        function: Function that is about to be invoked.
        arguments: Validated arguments for that function.
        metadata: Free-form dictionary that function middleware can use to pass
            data to one another.
        result: Outcome of the call. It is populated once ``next()`` has executed
            the function, and middleware may assign it to replace that outcome.
        terminate: Set to True to stop execution after the current middleware,
            as soon as control returns to the framework.
        kwargs: Extra keyword arguments given to the chat method that triggered
            this function invocation.
    """

    function: "AIFunction[Any, Any]"
    arguments: "BaseModel"
    metadata: dict[str, Any] = field(default_factory=dict)  # type: ignore
    result: Any = None
    terminate: bool = False
    kwargs: dict[str, Any] = field(default_factory=dict)  # type: ignore
@dataclass
class ChatContext:
    """State shared by all chat middleware during a single chat client request.

    Attributes:
        chat_client: Chat client whose request is being intercepted.
        messages: Messages being sent to the chat client.
        chat_options: Options that apply to the chat request.
        is_streaming: True when the request goes through the streaming pipeline.
        metadata: Free-form metadata dictionary.
        result: Outcome of the request. It is populated once ``next()`` has
            executed the chat call, and middleware may assign it to replace that
            outcome. Non-streaming requests expect a ``ChatResponse``; streaming
            requests expect an ``AsyncIterable[ChatResponseUpdate]``.
        terminate: Set to True to stop execution after the current middleware,
            as soon as control returns to the framework.
        kwargs: Extra keyword arguments given to the chat client.
    """

    chat_client: "ChatClientProtocol"
    messages: "MutableSequence[ChatMessage]"
    chat_options: "ChatOptions"
    is_streaming: bool = False
    metadata: dict[str, Any] = field(default_factory=dict)  # type: ignore
    result: "ChatResponse | AsyncIterable[ChatResponseUpdate] | None" = None
    terminate: bool = False
    kwargs: dict[str, Any] = field(default_factory=dict)  # type: ignore
class AgentMiddleware(ABC):
    """Interface for class-based middleware that intercepts agent invocations."""

    @abstractmethod
    async def process(
        self,
        context: AgentRunContext,
        next: Callable[[AgentRunContext], Awaitable[None]],
    ) -> None:
        """Handle one agent invocation.

        Args:
            context: Invocation context carrying the agent, messages, and metadata.
                Check ``context.is_streaming`` to tell streaming calls apart.
                Assign ``context.result`` to override execution, or read it after
                awaiting ``next()`` to observe the actual result
                (``AgentRunResponse`` for non-streaming calls,
                ``AsyncIterable[AgentRunResponseUpdate]`` for streaming calls).
            next: Callable that runs the next middleware or the final agent
                execution. It returns nothing; data travels through the context.

        Note:
            Implementations should return nothing. Every change goes through
            the context object: set ``context.result`` to override execution, or
            inspect it after ``next()`` for the actual result.
        """
        ...
class FunctionMiddleware(ABC):
    """Interface for class-based middleware that intercepts function invocations."""

    @abstractmethod
    async def process(
        self,
        context: FunctionInvocationContext,
        next: Callable[[FunctionInvocationContext], Awaitable[None]],
    ) -> None:
        """Handle one function invocation.

        Args:
            context: Invocation context carrying the function, its arguments,
                and metadata. Assign ``context.result`` to override execution,
                or read it after awaiting ``next()`` to observe the actual result.
            next: Callable that runs the next middleware or the final function
                execution. It returns nothing; data travels through the context.

        Note:
            Implementations should return nothing. Every change goes through
            the context object: set ``context.result`` to override execution, or
            inspect it after ``next()`` for the actual result.
        """
        ...
class ChatMiddleware(ABC):
    """Interface for class-based middleware that intercepts chat client requests."""

    @abstractmethod
    async def process(
        self,
        context: ChatContext,
        next: Callable[[ChatContext], Awaitable[None]],
    ) -> None:
        """Handle one chat client request.

        Args:
            context: Invocation context carrying the chat client, messages,
                options, and metadata. Check ``context.is_streaming`` to tell
                streaming calls apart. Assign ``context.result`` to override
                execution, or read it after awaiting ``next()`` to observe the
                actual result (``ChatResponse`` for non-streaming calls,
                ``AsyncIterable[ChatResponseUpdate]`` for streaming calls).
            next: Callable that runs the next middleware or the final chat
                execution. It returns nothing; data travels through the context.

        Note:
            Implementations should return nothing. Every change goes through
            the context object: set ``context.result`` to override execution, or
            inspect it after ``next()`` for the actual result.
        """
        ...
# Pure function type definitions for convenience.
# Each alias describes a callable taking (context, next) and returning an
# awaitable of None, where next is a callable that accepts the same context
# type and also returns an awaitable of None.
AgentMiddlewareCallable = Callable[[AgentRunContext, Callable[[AgentRunContext], Awaitable[None]]], Awaitable[None]]
FunctionMiddlewareCallable = Callable[
    [FunctionInvocationContext, Callable[[FunctionInvocationContext], Awaitable[None]]], Awaitable[None]
]
ChatMiddlewareCallable = Callable[[ChatContext, Callable[[ChatContext], Awaitable[None]]], Awaitable[None]]
# Type alias for all middleware types: the class-based and function-based
# form of each of the three categories (agent, function, chat).
Middleware: TypeAlias = (
    AgentMiddleware
    | AgentMiddlewareCallable
    | FunctionMiddleware
    | FunctionMiddlewareCallable
    | ChatMiddleware
    | ChatMiddlewareCallable
)
# Middleware type markers for decorators
def agent_middleware(func: AgentMiddlewareCallable) -> AgentMiddlewareCallable:
    """Tag a function as agent middleware.

    The tag states explicitly that the function handles ``AgentRunContext``
    objects. The function itself is returned, not a wrapper.

    Args:
        func: Middleware function to tag.

    Returns:
        The same function object, now carrying the agent middleware marker.

    Example:
        @agent_middleware
        async def my_middleware(context: AgentRunContext, next):
            # Process agent invocation
            await next(context)
    """
    # Store the marker on the function object under ``_middleware_type``.
    setattr(func, "_middleware_type", MiddlewareType.AGENT)
    return func
def function_middleware(func: FunctionMiddlewareCallable) -> FunctionMiddlewareCallable:
    """Tag a function as function middleware.

    The tag states explicitly that the function handles
    ``FunctionInvocationContext`` objects. The function itself is returned,
    not a wrapper.

    Args:
        func: Middleware function to tag.

    Returns:
        The same function object, now carrying the function middleware marker.

    Example:
        @function_middleware
        async def my_middleware(context: FunctionInvocationContext, next):
            # Process function invocation
            await next(context)
    """
    # Store the marker on the function object under ``_middleware_type``.
    setattr(func, "_middleware_type", MiddlewareType.FUNCTION)
    return func
def chat_middleware(func: ChatMiddlewareCallable) -> ChatMiddlewareCallable:
    """Tag a function as chat middleware.

    The tag states explicitly that the function handles ``ChatContext``
    objects. The function itself is returned, not a wrapper.

    Args:
        func: Middleware function to tag.

    Returns:
        The same function object, now carrying the chat middleware marker.

    Example:
        @chat_middleware
        async def my_middleware(context: ChatContext, next):
            # Process chat invocation
            await next(context)
    """
    # Store the marker on the function object under ``_middleware_type``.
    setattr(func, "_middleware_type", MiddlewareType.CHAT)
    return func
class MiddlewareWrapper(Generic[TContext]):
    """Adapter that exposes a plain middleware function through a ``process`` method.

    Type Parameters:
        TContext: Context type accepted by the wrapped function.
    """

    def __init__(self, func: Callable[[TContext, Callable[[TContext], Awaitable[None]]], Awaitable[None]]) -> None:
        # Keep the callable so ``process`` can delegate to it.
        self.func = func

    async def process(self, context: TContext, next: Callable[[TContext], Awaitable[None]]) -> None:
        """Delegate to the wrapped function with the same context and ``next``."""
        wrapped = self.func
        await wrapped(context, next)
class BaseMiddlewarePipeline(ABC):
    """Base class for middleware pipeline execution.

    Keeps the ordered list of registered middleware and builds the nested
    handler chain: every middleware gets a ``next`` callable that runs the
    following middleware, and the last ``next`` runs the final handler.
    """

    def __init__(self) -> None:
        """Initialize the base middleware pipeline."""
        self._middlewares: list[Any] = []

    @abstractmethod
    def _register_middleware(self, middleware: Any) -> None:
        """Register a middleware item. Must be implemented by subclasses."""
        ...

    @property
    def has_middlewares(self) -> bool:
        """Check if there are any middlewares registered."""
        return bool(self._middlewares)

    def _register_middleware_with_wrapper(
        self,
        middleware: Any,
        expected_type: type,
    ) -> None:
        """Generic middleware registration with automatic wrapping.

        Instances of ``expected_type`` are stored as they are; any other
        callable is wrapped in ``MiddlewareWrapper``. Objects that are neither
        are ignored.

        Args:
            middleware: The middleware instance or callable to register.
            expected_type: The expected middleware base class type.
        """
        if isinstance(middleware, expected_type):
            self._middlewares.append(middleware)
        elif callable(middleware):
            self._middlewares.append(MiddlewareWrapper(middleware))  # type: ignore[arg-type]

    def _link_middlewares(
        self,
        terminal: Callable[[Any], Awaitable[None]],
    ) -> Callable[[Any], Awaitable[None]]:
        """Wrap ``terminal`` with every registered middleware, first middleware outermost."""

        def make_step(middleware: Any, next_handler: Callable[[Any], Awaitable[None]]) -> Callable[[Any], Awaitable[None]]:
            # A factory function gives each step its own middleware/next binding.
            async def current_handler(c: Any) -> None:
                await middleware.process(c, next_handler)

            return current_handler

        handler = terminal
        # Build from the innermost handler outwards.
        for middleware in reversed(self._middlewares):
            handler = make_step(middleware, handler)
        return handler

    def _create_handler_chain(
        self,
        final_handler: Callable[[Any], Awaitable[Any]],
        result_container: dict[str, Any],
        result_key: str = "result",
    ) -> Callable[[Any], Awaitable[None]]:
        """Create a chain of middleware handlers.

        Args:
            final_handler: The final handler to execute
            result_container: Container to store the result
            result_key: Key to use in the result container

        Returns:
            The first handler in the chain
        """

        async def final_wrapper(c: Any) -> None:
            # Execute actual handler and populate context for observability
            result = await final_handler(c)
            result_container[result_key] = result
            c.result = result

        return self._link_middlewares(final_wrapper)

    def _create_streaming_handler_chain(
        self,
        final_handler: Callable[[Any], Any],
        result_container: dict[str, Any],
        result_key: str = "result_stream",
    ) -> Callable[[Any], Awaitable[None]]:
        """Create a chain of middleware handlers for streaming operations.

        Args:
            final_handler: The final handler to execute. It may return the
                stream directly (for example an async generator function) or
                return an awaitable that resolves to the stream.
            result_container: Container to store the result
            result_key: Key to use in the result container

        Returns:
            The first handler in the chain
        """

        async def final_wrapper(c: Any) -> None:
            # If terminate was set, skip execution
            if c.terminate:
                return
            # Call the handler exactly once and await only when it produced an
            # awaitable. Catching TypeError around ``await`` would invoke the
            # handler a second time and would hide a TypeError raised inside it.
            result = final_handler(c)
            if inspect.isawaitable(result):
                result = await result
            result_container[result_key] = result
            c.result = result

        return self._link_middlewares(final_wrapper)
class AgentMiddlewarePipeline(BaseMiddlewarePipeline):
    """Runs agent middleware as a chain around the actual agent invocation."""

    def __init__(self, middlewares: list[AgentMiddleware | AgentMiddlewareCallable] | None = None):
        """Create the pipeline and register the given middleware in order.

        Args:
            middlewares: Agent middleware (instances or plain callables) to include.
        """
        super().__init__()
        self._middlewares: list[AgentMiddleware] = []
        for item in middlewares or []:
            self._register_middleware(item)

    def _register_middleware(self, middleware: AgentMiddleware | AgentMiddlewareCallable) -> None:
        """Register one agent middleware, wrapping plain callables."""
        self._register_middleware_with_wrapper(middleware, AgentMiddleware)

    async def execute(
        self,
        agent: "AgentProtocol",
        messages: list[ChatMessage],
        context: AgentRunContext,
        final_handler: Callable[[AgentRunContext], Awaitable[AgentRunResponse]],
    ) -> AgentRunResponse | None:
        """Run the non-streaming agent pipeline.

        Args:
            agent: The agent being invoked.
            messages: The messages to send to the agent.
            context: The agent invocation context; it is updated in place.
            final_handler: Performs the actual agent execution.

        Returns:
            ``context.result`` when it holds an ``AgentRunResponse``, otherwise
            the response produced by the chain, otherwise an empty
            ``AgentRunResponse``.
        """
        context.agent = agent
        context.messages = messages
        context.is_streaming = False

        # Nothing to chain: call the agent directly.
        if not self._middlewares:
            return await final_handler(context)

        outcome: dict[str, AgentRunResponse | None] = {"result": None}

        async def agent_final_handler(c: AgentRunContext) -> AgentRunResponse:
            # Normal path: run the agent.
            if not c.terminate:
                return await final_handler(c)
            # Terminated: reuse a response set by middleware, or fall back to an empty one.
            if isinstance(c.result, AgentRunResponse):
                return c.result
            return AgentRunResponse()

        entry_point = self._create_handler_chain(agent_final_handler, outcome, "result")
        await entry_point(context)

        # A response present on the context takes precedence.
        overridden = context.result
        if isinstance(overridden, AgentRunResponse):
            return overridden

        # Nothing stored means next() was never called.
        stored = outcome.get("result")
        return AgentRunResponse() if stored is None else stored

    async def execute_stream(
        self,
        agent: "AgentProtocol",
        messages: list[ChatMessage],
        context: AgentRunContext,
        final_handler: Callable[[AgentRunContext], AsyncIterable[AgentRunResponseUpdate]],
    ) -> AsyncIterable[AgentRunResponseUpdate]:
        """Run the streaming agent pipeline.

        Args:
            agent: The agent being invoked.
            messages: The messages to send to the agent.
            context: The agent invocation context; it is updated in place.
            final_handler: Performs the actual streaming agent execution.

        Yields:
            Updates from ``context.result`` when it is async-iterable, otherwise
            from the stream produced by the chain. Nothing is yielded when
            neither exists.
        """
        context.agent = agent
        context.messages = messages
        context.is_streaming = True

        # Nothing to chain: stream from the agent directly.
        if not self._middlewares:
            async for update in final_handler(context):
                yield update
            return

        outcome: dict[str, AsyncIterable[AgentRunResponseUpdate] | None] = {"result_stream": None}
        entry_point = self._create_streaming_handler_chain(final_handler, outcome, "result_stream")
        await entry_point(context)

        # Prefer an async-iterable result on the context, else the stored stream.
        stream = context.result
        if stream is None or not hasattr(stream, "__aiter__"):
            stream = outcome["result_stream"]
        if stream is None:
            # No stream was set (next() not called): yield nothing.
            return
        async for update in stream:  # type: ignore
            yield update
class FunctionMiddlewarePipeline(BaseMiddlewarePipeline):
    """Runs function middleware as a chain around the actual function call."""

    def __init__(self, middlewares: list[FunctionMiddleware | FunctionMiddlewareCallable] | None = None):
        """Create the pipeline and register the given middleware in order.

        Args:
            middlewares: Function middleware (instances or plain callables) to include.
        """
        super().__init__()
        self._middlewares: list[FunctionMiddleware] = []
        for item in middlewares or []:
            self._register_middleware(item)

    def _register_middleware(self, middleware: FunctionMiddleware | FunctionMiddlewareCallable) -> None:
        """Register one function middleware, wrapping plain callables."""
        self._register_middleware_with_wrapper(middleware, FunctionMiddleware)

    async def execute(
        self,
        function: Any,
        arguments: "BaseModel",
        context: FunctionInvocationContext,
        final_handler: Callable[[FunctionInvocationContext], Awaitable[Any]],
    ) -> Any:
        """Run the function middleware pipeline.

        Args:
            function: The function being invoked.
            arguments: The validated arguments for the function.
            context: The function invocation context; it is updated in place.
            final_handler: Performs the actual function execution.

        Returns:
            ``context.result`` when it is not None, otherwise the value stored
            by the chain.
        """
        context.function = function
        context.arguments = arguments

        # Nothing to chain: call the function directly.
        if not self._middlewares:
            return await final_handler(context)

        outcome: dict[str, Any] = {"result": None}

        async def function_final_handler(c: FunctionInvocationContext) -> Any:
            # When terminated, skip execution and hand back whatever result is set (may be None).
            return c.result if c.terminate else await final_handler(c)

        entry_point = self._create_handler_chain(function_final_handler, outcome, "result")
        await entry_point(context)

        # A non-None result on the context takes precedence over the stored value.
        return context.result if context.result is not None else outcome["result"]
class ChatMiddlewarePipeline(BaseMiddlewarePipeline):
    """Executes chat middleware in a chain."""

    def __init__(self, middlewares: list[ChatMiddleware | ChatMiddlewareCallable] | None = None):
        """Initialize the chat middleware pipeline.

        Args:
            middlewares: List of chat middleware to include in the pipeline.
        """
        super().__init__()
        self._middlewares: list[ChatMiddleware] = []
        if middlewares:
            for middleware in middlewares:
                self._register_middleware(middleware)

    def _register_middleware(self, middleware: ChatMiddleware | ChatMiddlewareCallable) -> None:
        """Register a chat middleware item."""
        self._register_middleware_with_wrapper(middleware, ChatMiddleware)

    async def execute(
        self,
        chat_client: "ChatClientProtocol",
        messages: "MutableSequence[ChatMessage]",
        chat_options: "ChatOptions",
        context: ChatContext,
        final_handler: Callable[[ChatContext], Awaitable["ChatResponse"]],
        **kwargs: Any,
    ) -> "ChatResponse":
        """Execute the chat middleware pipeline.

        Args:
            chat_client: The chat client being invoked.
            messages: The messages being sent to the chat client.
            chat_options: The options for the chat request.
            context: The chat invocation context; it is updated in place.
            final_handler: The final handler that performs the actual chat execution.
            **kwargs: Additional keyword arguments. They are accepted but not
                used by this method.

        Returns:
            ``context.result`` when it is not None, otherwise the value stored
            by the chain.
        """
        # Update context with chat client, messages, and options
        context.chat_client = chat_client
        context.messages = messages
        context.chat_options = chat_options
        # Mark the context as non-streaming, mirroring execute_stream which
        # sets the flag to True, so middleware always sees the right mode.
        context.is_streaming = False

        if not self._middlewares:
            return await final_handler(context)

        # Store the final result
        result_container: dict[str, Any] = {"result": None}

        # Custom final handler that handles pre-existing results
        async def chat_final_handler(c: ChatContext) -> "ChatResponse":
            # If terminate was set, skip execution and return the result (which might be None)
            if c.terminate:
                return c.result  # type: ignore
            # Execute actual handler and populate context for observability
            return await final_handler(c)

        first_handler = self._create_handler_chain(chat_final_handler, result_container, "result")
        await first_handler(context)

        # Return the result from result container or overridden result
        if context.result is not None:
            return context.result  # type: ignore
        return result_container["result"]  # type: ignore

    async def execute_stream(
        self,
        chat_client: "ChatClientProtocol",
        messages: "MutableSequence[ChatMessage]",
        chat_options: "ChatOptions",
        context: ChatContext,
        final_handler: Callable[[ChatContext], AsyncIterable["ChatResponseUpdate"]],
        **kwargs: Any,
    ) -> AsyncIterable["ChatResponseUpdate"]:
        """Execute the chat middleware pipeline for streaming.

        Args:
            chat_client: The chat client being invoked.
            messages: The messages being sent to the chat client.
            chat_options: The options for the chat request.
            context: The chat invocation context; it is updated in place.
            final_handler: The final handler that performs the actual streaming chat execution.
            **kwargs: Additional keyword arguments. They are accepted but not
                used by this method.

        Yields:
            Chat response updates from ``context.result`` when it is
            async-iterable, otherwise from the stream stored by the chain.
            Nothing is yielded when neither exists.
        """
        # Update context with chat client, messages, and options
        context.chat_client = chat_client
        context.messages = messages
        context.chat_options = chat_options
        context.is_streaming = True

        if not self._middlewares:
            async for update in final_handler(context):
                yield update
            return

        # Store the final result stream
        result_container: dict[str, Any] = {"result_stream": None}
        first_handler = self._create_streaming_handler_chain(final_handler, result_container, "result_stream")
        await first_handler(context)

        # Yield from the result stream in result container or overridden result
        if context.result is not None and hasattr(context.result, "__aiter__"):
            async for update in context.result:  # type: ignore
                yield update
            return

        result_stream = result_container["result_stream"]
        if result_stream is None:
            # If no result stream was set (next() not called), yield nothing
            return
        async for update in result_stream:
            yield update
def _determine_middleware_type(middleware: Any) -> MiddlewareType:
    """Determine middleware type using decorator and/or parameter type annotation.

    The decorator marker (``_middleware_type`` attribute) and the annotation of
    the first parameter are both consulted. The annotation is matched by name,
    and string annotations (for example produced by postponed evaluation of
    annotations) are recognised as well.

    Args:
        middleware: The middleware function to analyze.

    Returns:
        MiddlewareType.AGENT, MiddlewareType.FUNCTION, or MiddlewareType.CHAT indicating the middleware type.

    Raises:
        MiddlewareException: When the callable has fewer than two parameters,
            when decorator and annotation disagree, or when neither identifies
            the middleware type.
    """
    # Callable objects such as functools.partial instances have no __name__;
    # fall back to repr so building an error message never raises AttributeError.
    display_name = getattr(middleware, "__name__", None) or repr(middleware)

    # Check for decorator marker
    decorator_type: MiddlewareType | None = getattr(middleware, "_middleware_type", None)

    # inspect.signature raises TypeError or ValueError for objects it cannot
    # introspect; in that case only the decorator marker is available.
    try:
        params = list(inspect.signature(middleware).parameters.values())
    except (TypeError, ValueError):
        params = None

    # Check for parameter type annotation
    param_type: MiddlewareType | None = None
    if params is not None:
        # Must have at least 2 parameters (context and next)
        if len(params) < 2:
            raise MiddlewareException(
                f"Middleware function must have at least 2 parameters (context, next), "
                f"but {display_name} has {len(params)}"
            )
        annotation = params[0].annotation
        if isinstance(annotation, str):
            # String annotation: drop surrounding quotes and any module prefix.
            annotation_name = annotation.strip().strip("\"'").rsplit(".", 1)[-1]
        else:
            annotation_name = getattr(annotation, "__name__", None)
        if isinstance(annotation_name, str):
            param_type = {
                "AgentRunContext": MiddlewareType.AGENT,
                "FunctionInvocationContext": MiddlewareType.FUNCTION,
                "ChatContext": MiddlewareType.CHAT,
            }.get(annotation_name)

    if decorator_type and param_type:
        # Both decorator and parameter type specified - they must match
        if decorator_type != param_type:
            raise MiddlewareException(
                f"Middleware type mismatch: decorator indicates '{decorator_type.value}' "
                f"but parameter type indicates '{param_type.value}' for function {display_name}"
            )
        return decorator_type

    if decorator_type:
        # Just decorator specified - rely on decorator
        return decorator_type

    if param_type:
        # Just parameter type specified - rely on types
        return param_type

    # Neither decorator nor parameter type specified - throw exception
    raise MiddlewareException(
        f"Cannot determine middleware type for function {display_name}. "
        f"Please either use @agent_middleware/@function_middleware/@chat_middleware decorators "
        f"or specify parameter types (AgentRunContext, FunctionInvocationContext, or ChatContext)."
    )
# Decorator for adding middleware support to agent classes
def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]:
    """Class decorator that gives an agent class middleware support.

    The class's ``run()`` and ``run_stream()`` methods are replaced by wrappers
    that build middleware pipelines on every call and route the call through
    them.

    Middleware can stop execution at any point by setting ``context.terminate``
    to True; the pipeline then stops running further middleware as soon as
    control returns to it.

    Args:
        agent_class: The agent class to add middleware support to.

    Returns:
        The same class object, with ``run`` and ``run_stream`` replaced.
    """
    # Keep the undecorated methods; the wrappers delegate to them.
    original_run = agent_class.run  # type: ignore[attr-defined]
    original_run_stream = agent_class.run_stream  # type: ignore[attr-defined]

    def _prepare_pipelines(
        agent: Any,
        run_level_middlewares: Middleware | list[Middleware] | None,
        call_kwargs: dict[str, Any],
    ) -> AgentMiddlewarePipeline:
        """Build fresh pipelines for one call and return the agent pipeline.

        Agent-level middleware is read from ``agent.middleware`` and categorized
        together with the run-level middleware. ``call_kwargs`` is updated in
        place with the function pipeline and the chat middleware, when present.
        """
        categorized = categorize_middleware(getattr(agent, "middleware", None), run_level_middlewares)
        agent_pipeline = AgentMiddlewarePipeline(categorized["agent"])  # type: ignore[arg-type]
        function_pipeline = FunctionMiddlewarePipeline(categorized["function"])  # type: ignore[arg-type]
        chat_middlewares = categorized["chat"]

        # Hand the function middleware pipeline on through kwargs if it has anything in it.
        if function_pipeline.has_middlewares:
            call_kwargs["_function_middleware_pipeline"] = function_pipeline
        # Chat middleware travels through kwargs for run-level application.
        if chat_middlewares:
            call_kwargs["middleware"] = chat_middlewares
        return agent_pipeline

    async def middleware_enabled_run(
        self: Any,
        messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
        *,
        thread: Any = None,
        middleware: Middleware | list[Middleware] | None = None,
        **kwargs: Any,
    ) -> AgentRunResponse:
        """Middleware-enabled run method."""
        agent_pipeline = _prepare_pipelines(self, middleware, kwargs)
        normalized_messages = self._normalize_messages(messages)

        # No agent middleware: call the original method directly.
        if not agent_pipeline.has_middlewares:
            return await original_run(self, normalized_messages, thread=thread, **kwargs)  # type: ignore[return-value]

        context = AgentRunContext(
            agent=self,  # type: ignore[arg-type]
            messages=normalized_messages,
            is_streaming=False,
            kwargs=kwargs,
        )

        async def _execute_handler(ctx: AgentRunContext) -> AgentRunResponse:
            # Read messages and kwargs from the context so middleware edits take effect.
            return await original_run(self, ctx.messages, thread=thread, **ctx.kwargs)  # type: ignore

        result = await agent_pipeline.execute(
            self,  # type: ignore[arg-type]
            normalized_messages,
            context,
            _execute_handler,
        )
        return result or AgentRunResponse()

    def middleware_enabled_run_stream(
        self: Any,
        messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
        *,
        thread: Any = None,
        middleware: Middleware | list[Middleware] | None = None,
        **kwargs: Any,
    ) -> AsyncIterable[AgentRunResponseUpdate]:
        """Middleware-enabled run_stream method."""
        agent_pipeline = _prepare_pipelines(self, middleware, kwargs)
        normalized_messages = self._normalize_messages(messages)

        # No agent middleware: return the original stream directly.
        if not agent_pipeline.has_middlewares:
            return original_run_stream(self, normalized_messages, thread=thread, **kwargs)  # type: ignore

        context = AgentRunContext(
            agent=self,  # type: ignore[arg-type]
            messages=normalized_messages,
            is_streaming=True,
            kwargs=kwargs,
        )

        async def _execute_stream_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]:
            # Read messages and kwargs from the context so middleware edits take effect.
            async for update in original_run_stream(self, ctx.messages, thread=thread, **ctx.kwargs):  # type: ignore[misc]
                yield update

        async def _stream_generator() -> AsyncIterable[AgentRunResponseUpdate]:
            async for update in agent_pipeline.execute_stream(
                self,  # type: ignore[arg-type]
                normalized_messages,
                context,
                _execute_stream_handler,
            ):
                yield update

        return _stream_generator()

    agent_class.run = middleware_enabled_run  # type: ignore
    agent_class.run_stream = middleware_enabled_run_stream  # type: ignore
    return agent_class
def use_chat_middleware(chat_client_class: type[TChatClient]) -> type[TChatClient]:
    """Class decorator that routes a chat client's response methods through middleware.

    The class's ``get_response()`` and ``get_streaming_response()`` are replaced with wrappers.
    Each wrapper gathers middleware from the ``middleware`` keyword argument of the call
    (which is removed from ``kwargs``) and from the instance's ``middleware`` attribute. When
    no chat middleware is found, the call is forwarded to the original method.

    The non-streaming wrapper additionally forwards any function middleware to the wrapped
    method as the ``_function_middleware_pipeline`` keyword argument; the streaming wrapper
    only acts on chat middleware.

    Args:
        chat_client_class: The chat client class to add middleware support to.

    Returns:
        The same class object, with both methods replaced.
    """
    # Keep references to the undecorated implementations; the wrappers delegate to them.
    original_get_response = chat_client_class.get_response
    original_get_streaming_response = chat_client_class.get_streaming_response

    async def middleware_enabled_get_response(
        self: Any,
        messages: Any,
        **kwargs: Any,
    ) -> Any:
        """Middleware-enabled get_response method."""
        # Call-level middleware is consumed here so it is not forwarded as a plain kwarg.
        per_call = kwargs.pop("middleware", None)
        per_instance = getattr(self, "middleware", None)

        # Instance-level entries come first, then call-level ones, split by kind.
        buckets = categorize_middleware(per_instance, per_call)
        chat_chain = buckets["chat"]  # type: ignore[assignment]
        function_chain = buckets["function"]

        # Function middleware is handed on to the function invocation pipeline via kwargs.
        if function_chain:
            kwargs["_function_middleware_pipeline"] = FunctionMiddlewarePipeline(function_chain)  # type: ignore[arg-type]

        if chat_chain:
            from ._types import ChatOptions

            # Use the caller's chat_options when supplied, otherwise a default instance.
            options = kwargs.pop("chat_options", ChatOptions())
            chat_pipeline = ChatMiddlewarePipeline(chat_chain)  # type: ignore[arg-type]
            run_context = ChatContext(
                chat_client=self,
                messages=self.prepare_messages(messages),
                chat_options=options,
                is_streaming=False,
                kwargs=kwargs,
            )

            async def _invoke_wrapped(ctx: ChatContext) -> Any:
                # Read everything back from the context so middleware edits take effect.
                return await original_get_response(
                    self, list(ctx.messages), chat_options=ctx.chat_options, **ctx.kwargs
                )

            return await chat_pipeline.execute(
                chat_client=self,
                messages=run_context.messages,
                chat_options=run_context.chat_options,
                context=run_context,
                final_handler=_invoke_wrapped,
                **kwargs,
            )

        # No chat middleware: call straight through.
        return await original_get_response(self, messages, **kwargs)

    def middleware_enabled_get_streaming_response(
        self: Any,
        messages: Any,
        **kwargs: Any,
    ) -> Any:
        """Middleware-enabled get_streaming_response method."""

        async def _stream_generator() -> Any:
            # Call-level middleware is consumed here so it is not forwarded as a plain kwarg.
            per_call = kwargs.pop("middleware", None)
            per_instance = getattr(self, "middleware", None)

            # Only chat middleware is considered on the streaming path.
            chat_chain: list[ChatMiddleware | ChatMiddlewareCallable] = _merge_and_filter_chat_middleware(
                per_instance, per_call
            )

            if chat_chain:
                from ._types import ChatOptions

                # Use the caller's chat_options when supplied, otherwise a default instance.
                options = kwargs.pop("chat_options", ChatOptions())
                chat_pipeline = ChatMiddlewarePipeline(chat_chain)  # type: ignore[arg-type]
                run_context = ChatContext(
                    chat_client=self,
                    messages=self.prepare_messages(messages),
                    chat_options=options,
                    is_streaming=True,
                    kwargs=kwargs,
                )

                def _open_wrapped_stream(ctx: ChatContext) -> Any:
                    # Read everything back from the context so middleware edits take effect.
                    return original_get_streaming_response(
                        self, list(ctx.messages), chat_options=ctx.chat_options, **ctx.kwargs
                    )

                updates = chat_pipeline.execute_stream(
                    chat_client=self,
                    messages=run_context.messages,
                    chat_options=run_context.chat_options,
                    context=run_context,
                    final_handler=_open_wrapped_stream,
                    **kwargs,
                )
            else:
                # No chat middleware: stream directly from the original method.
                updates = original_get_streaming_response(self, messages, **kwargs)

            async for update in updates:
                yield update

        return _stream_generator()

    # Install the wrappers on the class itself.
    chat_client_class.get_response = middleware_enabled_get_response  # type: ignore
    chat_client_class.get_streaming_response = middleware_enabled_get_streaming_response  # type: ignore
    return chat_client_class
def categorize_middleware(
*middleware_sources: Any | list[Any] | None,
) -> dict[str, list[Any]]:
"""Categorize middleware from multiple sources into agent, function, and chat types.
Args:
*middleware_sources: Variable number of middleware sources to categorize.
Returns:
Dict with keys "agent", "function", "chat" containing lists of categorized middleware.
"""
result: dict[str, list[Any]] = {"agent": [], "function": [], "chat": []}
# Merge all middleware sources into a single list
all_middleware: list[Any] = []
for source in middleware_sources:
if source:
if isinstance(source, list):
all_middleware.extend(source) # type: ignore
else:
all_middleware.append(source)
# Categorize each middleware item
for middleware in all_middleware:
if isinstance(middleware, AgentMiddleware):
result["agent"].append(middleware)
elif isinstance(middleware, FunctionMiddleware):
result["function"].append(middleware)
elif isinstance(middleware, ChatMiddleware):
result["chat"].append(middleware)
elif callable(middleware):
# Always call _determine_middleware_type to ensure proper validation
middleware_type = _determine_middleware_type(middleware)
if middleware_type == MiddlewareType.AGENT:
result["agent"].append(middleware)
elif middleware_type == MiddlewareType.FUNCTION:
result["function"].append(middleware)
elif middleware_type == MiddlewareType.CHAT:
result["chat"].append(middleware)
else:
# Fallback to agent middleware for unknown types
result["agent"].append(middleware)
return result
def create_function_middleware_pipeline(
    *middleware_sources: list[Middleware] | None,
) -> FunctionMiddlewarePipeline | None:
    """Build a function middleware pipeline from the given middleware sources.

    Args:
        *middleware_sources: Middleware sources to search for function middleware.

    Returns:
        A pipeline wrapping the function middleware that was found, or ``None`` when
        none of the sources contributed any.
    """
    function_middlewares = categorize_middleware(*middleware_sources)["function"]
    if not function_middlewares:
        return None
    return FunctionMiddlewarePipeline(function_middlewares)  # type: ignore[arg-type]
def _merge_and_filter_chat_middleware(
    instance_middleware: Any | list[Any] | None,
    call_middleware: Any | list[Any] | None,
) -> list[ChatMiddleware | ChatMiddlewareCallable]:
    """Combine instance-level and call-level middleware and keep only the chat entries.

    Instance-level middleware is passed to the categorizer ahead of call-level middleware.

    Args:
        instance_middleware: Middleware defined at the instance level.
        call_middleware: Middleware provided at the call level.

    Returns:
        The chat middleware found across both sources.
    """
    return categorize_middleware(instance_middleware, call_middleware)["chat"]  # type: ignore[return-value]
def extract_and_merge_function_middleware(chat_client: Any, kwargs: dict[str, Any]) -> None:
    """Merge function middleware from every available source into a single pipeline.

    Middleware is collected from the chat client's ``middleware`` attribute, the
    ``middleware`` entry of ``kwargs``, and any pipeline already stored in ``kwargs``
    under ``_function_middleware_pipeline``. When the combined pipeline is truthy it is
    written back to ``kwargs["_function_middleware_pipeline"]``; otherwise ``kwargs`` is
    left untouched. The ``middleware`` entry itself is read but not removed.

    Args:
        chat_client: The chat client instance to extract middleware from.
        kwargs: Dictionary containing middleware and pipeline information. Mutated in place.
    """
    # getattr with a default already handles a missing attribute, so no hasattr guard is needed.
    client_middleware = getattr(chat_client, "middleware", None)
    run_level_middleware = kwargs.get("middleware")
    existing_pipeline = kwargs.get("_function_middleware_pipeline")

    # Carry over the middleware held by a pipeline that is already present in kwargs.
    existing_middlewares = existing_pipeline._middlewares if existing_pipeline else None

    # Let the shared helper filter all three sources down to function middleware.
    combined_pipeline = create_function_middleware_pipeline(
        client_middleware, run_level_middleware, existing_middlewares
    )
    if combined_pipeline:
        kwargs["_function_middleware_pipeline"] = combined_pipeline