diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py index 758263cefc..7ef3e73111 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py @@ -34,6 +34,7 @@ from agent_framework import ( UsageContent, UsageDetails, get_logger, + use_chat_middleware, use_function_invocation, ) from agent_framework._pydantic import AFBaseSettings @@ -123,6 +124,7 @@ TAzureAIAgentClient = TypeVar("TAzureAIAgentClient", bound="AzureAIAgentClient") @use_function_invocation @use_observability +@use_chat_middleware class AzureAIAgentClient(BaseChatClient): """Azure AI Agent Chat client.""" diff --git a/python/packages/main/agent_framework/_clients.py b/python/packages/main/agent_framework/_clients.py index f1a8d970ef..fcebe612ce 100644 --- a/python/packages/main/agent_framework/_clients.py +++ b/python/packages/main/agent_framework/_clients.py @@ -10,7 +10,13 @@ from pydantic import BaseModel, Field from ._logging import get_logger from ._mcp import MCPTool from ._memory import AggregateContextProvider, ContextProvider -from ._middleware import Middleware +from ._middleware import ( + ChatMiddleware, + ChatMiddlewareCallable, + FunctionMiddleware, + FunctionMiddlewareCallable, + Middleware, +) from ._pydantic import AFBaseModel from ._threads import ChatMessageStore from ._tools import ToolProtocol @@ -189,6 +195,14 @@ class BaseChatClient(AFBaseModel, ABC): """Base class for chat clients.""" additional_properties: dict[str, Any] = Field(default_factory=dict) + middleware: ( + ChatMiddleware + | ChatMiddlewareCallable + | FunctionMiddleware + | FunctionMiddlewareCallable + | list[ChatMiddleware | ChatMiddlewareCallable | FunctionMiddleware | FunctionMiddlewareCallable] + | None + ) = None OTEL_PROVIDER_NAME: str = "unknown" # This is used for OTel setup, should be overridden in subclasses @@ -346,13 +360,7 @@ class BaseChatClient(AFBaseModel, ABC): prepped_messages = self.prepare_messages(messages) self._prepare_tool_choice(chat_options=chat_options) - # Remove middleware pipeline from kwargs as it's only used by function invocation wrappers - if "_function_middleware_pipeline" in kwargs: - filtered_kwargs = {k: v for k, v in kwargs.items() if k != "_function_middleware_pipeline"} - else: - filtered_kwargs = kwargs - - return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **filtered_kwargs) + return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **kwargs) async def get_streaming_response( self, @@ -432,14 +440,8 @@ class BaseChatClient(AFBaseModel, ABC): prepped_messages = self.prepare_messages(messages) self._prepare_tool_choice(chat_options=chat_options) - # Remove middleware pipeline from kwargs as it's only used by function invocation wrappers - if "_function_middleware_pipeline" in kwargs: - filtered_kwargs = {k: v for k, v in kwargs.items() if k != "_function_middleware_pipeline"} - else: - filtered_kwargs = kwargs - async for update in self._inner_get_streaming_response( - messages=prepped_messages, chat_options=chat_options, **filtered_kwargs + messages=prepped_messages, chat_options=chat_options, **kwargs ): yield update diff --git a/python/packages/main/agent_framework/_middleware.py b/python/packages/main/agent_framework/_middleware.py index ba157dcc28..d557e45ed3 100644 --- a/python/packages/main/agent_framework/_middleware.py +++ b/python/packages/main/agent_framework/_middleware.py @@ -1,20 +1,28 @@ # Copyright (c) Microsoft. All rights reserved. +import inspect from abc import ABC, abstractmethod from collections.abc import AsyncIterable, Awaitable, Callable from dataclasses import dataclass, field from enum import Enum -from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage +from .exceptions import MiddlewareException if TYPE_CHECKING: + from collections.abc import AsyncIterable, MutableSequence + from pydantic import BaseModel from ._agents import AgentProtocol + from ._clients import ChatClientProtocol from ._tools import AIFunction + from ._types import ChatOptions, ChatResponse, ChatResponseUpdate TAgent = TypeVar("TAgent", bound="AgentProtocol") +TChatClient = TypeVar("TChatClient", bound="ChatClientProtocol") +TContext = TypeVar("TContext") class MiddlewareType(Enum): @@ -22,17 +30,22 @@ class MiddlewareType(Enum): AGENT = "agent" FUNCTION = "function" + CHAT = "chat" __all__ = [ "AgentMiddleware", "AgentRunContext", + "ChatContext", + "ChatMiddleware", "FunctionInvocationContext", "FunctionMiddleware", "Middleware", "agent_middleware", + "chat_middleware", "function_middleware", "use_agent_middleware", + "use_chat_middleware", ] @@ -51,14 +64,16 @@ class AgentRunContext: For streaming: should be AsyncIterable[AgentRunResponseUpdate] terminate: A flag indicating whether to terminate execution after current middleware. When set to True, execution will stop as soon as control returns to framework. + kwargs: Additional keyword arguments passed to the agent run method. """ agent: "AgentProtocol" messages: list[ChatMessage] is_streaming: bool = False - metadata: dict[str, Any] = field(default_factory=lambda: {}) + metadata: dict[str, Any] = field(default_factory=dict) # type: ignore result: AgentRunResponse | AsyncIterable[AgentRunResponseUpdate] | None = None terminate: bool = False + kwargs: dict[str, Any] = field(default_factory=dict) # type: ignore @dataclass @@ -73,13 +88,44 @@ class FunctionInvocationContext: to see the actual execution result or can be set to override the execution result. terminate: A flag indicating whether to terminate execution after current middleware. When set to True, execution will stop as soon as control returns to framework. + kwargs: Additional keyword arguments passed to the chat method that invoked this function. """ function: "AIFunction[Any, Any]" arguments: "BaseModel" - metadata: dict[str, Any] = field(default_factory=lambda: {}) + metadata: dict[str, Any] = field(default_factory=dict) # type: ignore result: Any = None terminate: bool = False + kwargs: dict[str, Any] = field(default_factory=dict) # type: ignore + + +@dataclass +class ChatContext: + """Context object for chat middleware invocations. + + Attributes: + chat_client: The chat client being invoked. + messages: The messages being sent to the chat client. + chat_options: The options for the chat request. + is_streaming: Whether this is a streaming invocation. + metadata: Metadata dictionary. + result: Chat execution result. Can be observed after calling next() + to see the actual execution result or can be set to override the execution result. + For non-streaming: should be ChatResponse + For streaming: should be AsyncIterable[ChatResponseUpdate] + terminate: A flag indicating whether to terminate execution after current middleware. + When set to True, execution will stop as soon as control returns to framework. + kwargs: Additional keyword arguments passed to the chat client. + """ + + chat_client: "ChatClientProtocol" + messages: "MutableSequence[ChatMessage]" + chat_options: "ChatOptions" + is_streaming: bool = False + metadata: dict[str, Any] = field(default_factory=dict) # type: ignore + result: "ChatResponse | AsyncIterable[ChatResponseUpdate] | None" = None + terminate: bool = False + kwargs: dict[str, Any] = field(default_factory=dict) # type: ignore class AgentMiddleware(ABC): @@ -137,6 +183,35 @@ class FunctionMiddleware(ABC): ... +class ChatMiddleware(ABC): + """Abstract base class for chat middleware that can intercept chat client requests.""" + + @abstractmethod + async def process( + self, + context: ChatContext, + next: Callable[[ChatContext], Awaitable[None]], + ) -> None: + """Process a chat client request. + + Args: + context: Chat invocation context containing chat client, messages, options, and metadata. + Use context.is_streaming to determine if this is a streaming call. + Middleware can set context.result to override execution, or observe + the actual execution result after calling next(). + For non-streaming: ChatResponse + For streaming: AsyncIterable[ChatResponseUpdate] + next: Function to call the next middleware or final chat execution. + Does not return anything - all data flows through the context. + + Note: + Middleware should not return anything. All data manipulation should happen + within the context object. Set context.result to override execution, + or observe context.result after calling next() for actual results. + """ + ... + + # Pure function type definitions for convenience AgentMiddlewareCallable = Callable[[AgentRunContext, Callable[[AgentRunContext], Awaitable[None]]], Awaitable[None]] @@ -144,8 +219,17 @@ FunctionMiddlewareCallable = Callable[ [FunctionInvocationContext, Callable[[FunctionInvocationContext], Awaitable[None]]], Awaitable[None] ] +ChatMiddlewareCallable = Callable[[ChatContext, Callable[[ChatContext], Awaitable[None]]], Awaitable[None]] + # Type alias for all middleware types -Middleware: TypeAlias = AgentMiddleware | AgentMiddlewareCallable | FunctionMiddleware | FunctionMiddlewareCallable +Middleware: TypeAlias = ( + AgentMiddleware + | AgentMiddlewareCallable + | FunctionMiddleware + | FunctionMiddlewareCallable + | ChatMiddleware + | ChatMiddlewareCallable +) # Middleware type markers for decorators @@ -195,31 +279,40 @@ def function_middleware(func: FunctionMiddlewareCallable) -> FunctionMiddlewareC return func -class AgentMiddlewareWrapper(AgentMiddleware): - """Wrapper to convert pure functions into AgentMiddleware protocol objects.""" +def chat_middleware(func: ChatMiddlewareCallable) -> ChatMiddlewareCallable: + """Decorator to mark a function as chat middleware. - def __init__(self, func: AgentMiddlewareCallable): + This decorator explicitly identifies a function as chat middleware, + which processes ChatContext objects. + + Args: + func: The middleware function to mark as chat middleware. + + Returns: + The same function with chat middleware marker. + + Example: + @chat_middleware + async def my_middleware(context: ChatContext, next): + # Process chat invocation + await next(context) + """ + # Add marker attribute to identify this as chat middleware + func._middleware_type: MiddlewareType = MiddlewareType.CHAT # type: ignore + return func + + +class MiddlewareWrapper(Generic[TContext]): + """Generic wrapper to convert pure functions into middleware protocol objects. + + Type Parameters: + TContext: The type of context object this middleware operates on. + """ + + def __init__(self, func: Callable[[TContext, Callable[[TContext], Awaitable[None]]], Awaitable[None]]) -> None: self.func = func - async def process( - self, - context: AgentRunContext, - next: Callable[[AgentRunContext], Awaitable[None]], - ) -> None: - await self.func(context, next) - - -class FunctionMiddlewareWrapper(FunctionMiddleware): - """Wrapper to convert pure functions into FunctionMiddleware protocol objects.""" - - def __init__(self, func: FunctionMiddlewareCallable): - self.func = func - - async def process( - self, - context: FunctionInvocationContext, - next: Callable[[FunctionInvocationContext], Awaitable[None]], - ) -> None: + async def process(self, context: TContext, next: Callable[[TContext], Awaitable[None]]) -> None: await self.func(context, next) @@ -240,6 +333,22 @@ class BaseMiddlewarePipeline(ABC): """Check if there are any middlewares registered.""" return bool(self._middlewares) + def _register_middleware_with_wrapper( + self, + middleware: Any, + expected_type: type, + ) -> None: + """Generic middleware registration with automatic wrapping. + + Args: + middleware: The middleware instance or callable to register. + expected_type: The expected middleware base class type. + """ + if isinstance(middleware, expected_type): + self._middlewares.append(middleware) + elif callable(middleware): + self._middlewares.append(MiddlewareWrapper(middleware)) # type: ignore[arg-type] + def _create_handler_chain( self, final_handler: Callable[[Any], Awaitable[Any]], @@ -278,6 +387,56 @@ class BaseMiddlewarePipeline(ABC): return create_next_handler(0) + def _create_streaming_handler_chain( + self, + final_handler: Callable[[Any], Any], + result_container: dict[str, Any], + result_key: str = "result_stream", + ) -> Callable[[Any], Awaitable[None]]: + """Create a chain of middleware handlers for streaming operations. + + Args: + final_handler: The final handler to execute + result_container: Container to store the result + result_key: Key to use in the result container + + Returns: + The first handler in the chain + """ + + def create_next_handler(index: int) -> Callable[[Any], Awaitable[None]]: + if index >= len(self._middlewares): + + async def final_wrapper(c: Any) -> None: + # If terminate was set, skip execution + if c.terminate: + return + + # Execute actual handler and populate context for observability + # Note: final_handler might not be awaitable for streaming cases + try: + result = await final_handler(c) + except TypeError: + # Handle non-awaitable case (e.g., generator functions) + result = final_handler(c) + result_container[result_key] = result + c.result = result + + return final_wrapper + + middleware = self._middlewares[index] + next_handler = create_next_handler(index + 1) + + async def current_handler(c: Any) -> None: + await middleware.process(c, next_handler) + # If terminate is set, don't continue the pipeline + if c.terminate: + return + + return current_handler + + return create_next_handler(0) + class AgentMiddlewarePipeline(BaseMiddlewarePipeline): """Executes agent middleware in a chain.""" @@ -297,10 +456,7 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline): def _register_middleware(self, middleware: AgentMiddleware | AgentMiddlewareCallable) -> None: """Register an agent middleware item.""" - if isinstance(middleware, AgentMiddleware): - self._middlewares.append(middleware) - elif callable(middleware): - self._middlewares.append(AgentMiddlewareWrapper(middleware)) + self._register_middleware_with_wrapper(middleware, AgentMiddleware) async def execute( self, @@ -329,39 +485,19 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline): return await final_handler(context) # Store the final result - result_container: dict[str, AgentRunResponse | None] = {"response": None} + result_container: dict[str, AgentRunResponse | None] = {"result": None} - def create_next_handler(index: int) -> Callable[[AgentRunContext], Awaitable[None]]: - if index >= len(self._middlewares): - - async def final_wrapper(c: AgentRunContext) -> None: - # If terminate was set, skip execution - if c.terminate: - return - - # Execute actual handler and populate context for observability - result = await final_handler(c) - result_container["result"] = result - c.result = result - - return final_wrapper - - middleware = self._middlewares[index] - next_handler = create_next_handler(index + 1) - - async def current_handler(c: AgentRunContext) -> None: - # If terminate is set, don't continue the pipeline - if c.terminate: - return - - await middleware.process(c, next_handler) - # After middleware execution, check if response was overridden + # Custom final handler that handles termination and result override + async def agent_final_handler(c: AgentRunContext) -> AgentRunResponse: + # If terminate was set, return the result (which might be None) + if c.terminate: if c.result is not None and isinstance(c.result, AgentRunResponse): - result_container["result"] = c.result + return c.result + return AgentRunResponse() + # Execute actual handler and populate context for observability + return await final_handler(c) - return current_handler - - first_handler = create_next_handler(0) + first_handler = self._create_handler_chain(agent_final_handler, result_container, "result") await first_handler(context) # Return the result from result container or overridden result @@ -405,33 +541,7 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline): # Store the final result result_container: dict[str, AsyncIterable[AgentRunResponseUpdate] | None] = {"result_stream": None} - def create_next_handler(index: int) -> Callable[[AgentRunContext], Awaitable[None]]: - if index >= len(self._middlewares): - - async def final_wrapper(c: AgentRunContext) -> None: # noqa: RUF029 - # If terminate was set, skip execution - if c.terminate: - return - - # Execute actual handler and populate context for observability - result = final_handler(c) - result_container["result_stream"] = result - c.result = result - - return final_wrapper - - middleware = self._middlewares[index] - next_handler = create_next_handler(index + 1) - - async def current_handler(c: AgentRunContext) -> None: - await middleware.process(c, next_handler) - # If terminate is set, don't continue the pipeline - if c.terminate: - return - - return current_handler - - first_handler = create_next_handler(0) + first_handler = self._create_streaming_handler_chain(final_handler, result_container, "result_stream") await first_handler(context) # Yield from the result stream in result container or overridden result @@ -467,11 +577,7 @@ class FunctionMiddlewarePipeline(BaseMiddlewarePipeline): def _register_middleware(self, middleware: FunctionMiddleware | FunctionMiddlewareCallable) -> None: """Register a function middleware item.""" - # Check if it's a class instance inheriting from FunctionMiddleware - if isinstance(middleware, FunctionMiddleware): - self._middlewares.append(middleware) - elif callable(middleware): - self._middlewares.append(FunctionMiddlewareWrapper(middleware)) + self._register_middleware_with_wrapper(middleware, FunctionMiddleware) async def execute( self, @@ -518,6 +624,198 @@ class FunctionMiddlewarePipeline(BaseMiddlewarePipeline): return result_container["result"] +class ChatMiddlewarePipeline(BaseMiddlewarePipeline): + """Executes chat middleware in a chain.""" + + def __init__(self, middlewares: list[ChatMiddleware | ChatMiddlewareCallable] | None = None): + """Initialize the chat middleware pipeline. + + Args: + middlewares: List of chat middleware to include in the pipeline. + """ + super().__init__() + self._middlewares: list[ChatMiddleware] = [] + + if middlewares: + for middleware in middlewares: + self._register_middleware(middleware) + + def _register_middleware(self, middleware: ChatMiddleware | ChatMiddlewareCallable) -> None: + """Register a chat middleware item.""" + self._register_middleware_with_wrapper(middleware, ChatMiddleware) + + async def execute( + self, + chat_client: "ChatClientProtocol", + messages: "MutableSequence[ChatMessage]", + chat_options: "ChatOptions", + context: ChatContext, + final_handler: Callable[[ChatContext], Awaitable["ChatResponse"]], + **kwargs: Any, + ) -> "ChatResponse": + """Execute the chat middleware pipeline. + + Args: + chat_client: The chat client being invoked. + messages: The messages being sent to the chat client. + chat_options: The options for the chat request. + context: The chat invocation context. + final_handler: The final handler that performs the actual chat execution. + **kwargs: Additional keyword arguments. + + Returns: + The chat response after processing through all middleware. + """ + # Update context with chat client, messages, and options + context.chat_client = chat_client + context.messages = messages + context.chat_options = chat_options + + if not self._middlewares: + return await final_handler(context) + + # Store the final result + result_container: dict[str, Any] = {"result": None} + + # Custom final handler that handles pre-existing results + async def chat_final_handler(c: ChatContext) -> "ChatResponse": + # If terminate was set, skip execution and return the result (which might be None) + if c.terminate: + return c.result # type: ignore + # Execute actual handler and populate context for observability + return await final_handler(c) + + first_handler = self._create_handler_chain(chat_final_handler, result_container, "result") + await first_handler(context) + + # Return the result from result container or overridden result + if context.result is not None: + return context.result # type: ignore + return result_container["result"] # type: ignore + + async def execute_stream( + self, + chat_client: "ChatClientProtocol", + messages: "MutableSequence[ChatMessage]", + chat_options: "ChatOptions", + context: ChatContext, + final_handler: Callable[[ChatContext], AsyncIterable["ChatResponseUpdate"]], + **kwargs: Any, + ) -> AsyncIterable["ChatResponseUpdate"]: + """Execute the chat middleware pipeline for streaming. + + Args: + chat_client: The chat client being invoked. + messages: The messages being sent to the chat client. + chat_options: The options for the chat request. + context: The chat invocation context. + final_handler: The final handler that performs the actual streaming chat execution. + **kwargs: Additional keyword arguments. + + Yields: + Chat response updates after processing through all middleware. + """ + # Update context with chat client, messages, and options + context.chat_client = chat_client + context.messages = messages + context.chat_options = chat_options + context.is_streaming = True + + if not self._middlewares: + async for update in final_handler(context): + yield update + return + + # Store the final result stream + result_container: dict[str, Any] = {"result_stream": None} + + first_handler = self._create_streaming_handler_chain(final_handler, result_container, "result_stream") + await first_handler(context) + + # Yield from the result stream in result container or overridden result + if context.result is not None and hasattr(context.result, "__aiter__"): + async for update in context.result: # type: ignore + yield update + return + + result_stream = result_container["result_stream"] + if result_stream is None: + # If no result stream was set (next() not called), yield nothing + return + + async for update in result_stream: + yield update + + +def _determine_middleware_type(middleware: Any) -> MiddlewareType: + """Determine middleware type using decorator and/or parameter type annotation. + + Args: + middleware: The middleware function to analyze. + + Returns: + MiddlewareType.AGENT, MiddlewareType.FUNCTION, or MiddlewareType.CHAT indicating the middleware type. + + Raises: + MiddlewareException: When middleware type cannot be determined or there's a mismatch. + """ + # Check for decorator marker + decorator_type: MiddlewareType | None = getattr(middleware, "_middleware_type", None) + + # Check for parameter type annotation + param_type: MiddlewareType | None = None + try: + sig = inspect.signature(middleware) + params = list(sig.parameters.values()) + + # Must have at least 2 parameters (context and next) + if len(params) >= 2: + first_param = params[0] + if hasattr(first_param.annotation, "__name__"): + annotation_name = first_param.annotation.__name__ + if annotation_name == "AgentRunContext": + param_type = MiddlewareType.AGENT + elif annotation_name == "FunctionInvocationContext": + param_type = MiddlewareType.FUNCTION + elif annotation_name == "ChatContext": + param_type = MiddlewareType.CHAT + else: + # Not enough parameters - can't be valid middleware + raise MiddlewareException( + f"Middleware function must have at least 2 parameters (context, next), " + f"but {middleware.__name__} has {len(params)}" + ) + except Exception as e: + if isinstance(e, MiddlewareException): + raise + # Signature inspection failed - continue with other checks + pass + + if decorator_type and param_type: + # Both decorator and parameter type specified - they must match + if decorator_type != param_type: + raise MiddlewareException( + f"Middleware type mismatch: decorator indicates '{decorator_type.value}' " + f"but parameter type indicates '{param_type.value}' for function {middleware.__name__}" + ) + return decorator_type + + if decorator_type: + # Just decorator specified - rely on decorator + return decorator_type + + if param_type: + # Just parameter type specified - rely on types + return param_type + + # Neither decorator nor parameter type specified - throw exception + raise MiddlewareException( + f"Cannot determine middleware type for function {middleware.__name__}. " + f"Please either use @agent_middleware/@function_middleware/@chat_middleware decorators " + f"or specify parameter types (AgentRunContext, FunctionInvocationContext, or ChatContext)." + ) + + # Decorator for adding middleware support to agent classes def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]: """Class decorator that adds middleware support to an agent class. @@ -535,132 +833,27 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]: Returns: The modified agent class with middleware support. """ - import inspect - # Store original methods original_run = agent_class.run # type: ignore[attr-defined] original_run_stream = agent_class.run_stream # type: ignore[attr-defined] - def _determine_middleware_type(middleware: Any) -> MiddlewareType: - """Determine middleware type using decorator and/or parameter type annotation. - - Args: - middleware: The middleware function to analyze. - - Returns: - MiddlewareType.AGENT or MiddlewareType.FUNCTION indicating the middleware type. - - Raises: - ValueError: When middleware type cannot be determined or there's a mismatch. - """ - # Check for decorator marker - decorator_type: MiddlewareType | None = getattr(middleware, "_middleware_type", None) - - # Check for parameter type annotation - param_type: MiddlewareType | None = None - try: - sig = inspect.signature(middleware) - params = list(sig.parameters.values()) - - # Must have at least 2 parameters (context and next) - if len(params) >= 2: - first_param = params[0] - if hasattr(first_param.annotation, "__name__"): - annotation_name = first_param.annotation.__name__ - if annotation_name == "AgentRunContext": - param_type = MiddlewareType.AGENT - elif annotation_name == "FunctionInvocationContext": - param_type = MiddlewareType.FUNCTION - else: - # Not enough parameters - can't be valid middleware - raise ValueError( - f"Middleware function must have at least 2 parameters (context, next), " - f"but {middleware.__name__} has {len(params)}" - ) - except Exception as e: - if isinstance(e, ValueError): - raise # Re-raise our custom errors - # Signature inspection failed - continue with other checks - pass - - if decorator_type and param_type: - # Both decorator and parameter type specified - they must match - if decorator_type != param_type: - raise ValueError( - f"Middleware type mismatch: decorator indicates '{decorator_type.value}' " - f"but parameter type indicates '{param_type.value}' for function {middleware.__name__}" - ) - return decorator_type - - if decorator_type: - # Just decorator specified - rely on decorator - return decorator_type - - if param_type: - # Just parameter type specified - rely on types - return param_type - - # Neither decorator nor parameter type specified - throw exception - raise ValueError( - f"Cannot determine middleware type for function {middleware.__name__}. " - f"Please either use @agent_middleware/@function_middleware decorators " - f"or specify parameter types (AgentRunContext or FunctionInvocationContext)." - ) - def _build_middleware_pipelines( agent_level_middlewares: Middleware | list[Middleware] | None, run_level_middlewares: Middleware | list[Middleware] | None = None, - ) -> tuple[AgentMiddlewarePipeline, FunctionMiddlewarePipeline]: + ) -> tuple[AgentMiddlewarePipeline, FunctionMiddlewarePipeline, list[ChatMiddleware | ChatMiddlewareCallable]]: """Build fresh agent and function middleware pipelines from the provided middleware lists. Args: agent_level_middlewares: Agent-level middleware (executed first) run_level_middlewares: Run-level middleware (executed after agent middleware) """ - # Merge middleware lists: agent middleware first, then run middleware - combined_middlewares: list[Middleware] = [] + middleware = categorize_middleware(agent_level_middlewares, run_level_middlewares) - if agent_level_middlewares: - if isinstance(agent_level_middlewares, list): - combined_middlewares.extend(agent_level_middlewares) # type: ignore[arg-type] - else: - combined_middlewares.append(agent_level_middlewares) - - if run_level_middlewares: - if isinstance(run_level_middlewares, list): - combined_middlewares.extend(run_level_middlewares) # type: ignore[arg-type] - else: - combined_middlewares.append(run_level_middlewares) - - if not combined_middlewares: - return AgentMiddlewarePipeline(), FunctionMiddlewarePipeline() - - middleware_list = combined_middlewares - - # Separate agent and function middleware using isinstance checks - agent_middlewares: list[AgentMiddleware | AgentMiddlewareCallable] = [] - function_middlewares: list[FunctionMiddleware | FunctionMiddlewareCallable] = [] - - for middleware in middleware_list: - if isinstance(middleware, AgentMiddleware): - agent_middlewares.append(middleware) - elif isinstance(middleware, FunctionMiddleware): - function_middlewares.append(middleware) - elif callable(middleware): # type: ignore[arg-type] - # Determine middleware type using decorator and/or parameter type annotation - middleware_type = _determine_middleware_type(middleware) - if middleware_type == MiddlewareType.AGENT: - agent_middlewares.append(middleware) # type: ignore - elif middleware_type == MiddlewareType.FUNCTION: - function_middlewares.append(middleware) # type: ignore - else: - # This should not happen if _determine_middleware_type is implemented correctly - raise ValueError(f"Unknown middleware type: {middleware_type}") - else: - # Fallback - agent_middlewares.append(middleware) # type: ignore - - return AgentMiddlewarePipeline(agent_middlewares), FunctionMiddlewarePipeline(function_middlewares) + return ( + AgentMiddlewarePipeline(middleware["agent"]), # type: ignore[arg-type] + FunctionMiddlewarePipeline(middleware["function"]), # type: ignore[arg-type] + middleware["chat"], # type: ignore[return-value] + ) async def middleware_enabled_run( self: Any, @@ -673,12 +866,17 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]: """Middleware-enabled run method.""" # Build fresh middleware pipelines from current middleware collection and run-level middleware agent_middleware = getattr(self, "middleware", None) - agent_pipeline, function_pipeline = _build_middleware_pipelines(agent_middleware, middleware) + + agent_pipeline, function_pipeline, chat_middlewares = _build_middleware_pipelines(agent_middleware, middleware) # Add function middleware pipeline to kwargs if available if function_pipeline.has_middlewares: kwargs["_function_middleware_pipeline"] = function_pipeline + # Pass chat middleware through kwargs for run-level application + if chat_middlewares: + kwargs["middleware"] = chat_middlewares + normalized_messages = self._normalize_messages(messages) # Execute with middleware if available @@ -687,10 +885,11 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]: agent=self, # type: ignore[arg-type] messages=normalized_messages, is_streaming=False, + kwargs=kwargs, ) async def _execute_handler(ctx: AgentRunContext) -> AgentRunResponse: - return await original_run(self, ctx.messages, thread=thread, **kwargs) # type: ignore + return await original_run(self, ctx.messages, thread=thread, **ctx.kwargs) # type: ignore result = await agent_pipeline.execute( self, # type: ignore[arg-type] @@ -715,12 +914,16 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]: """Middleware-enabled run_stream method.""" # Build fresh middleware pipelines from current middleware collection and run-level middleware agent_middleware = getattr(self, "middleware", None) - agent_pipeline, function_pipeline = _build_middleware_pipelines(agent_middleware, middleware) + agent_pipeline, function_pipeline, chat_middlewares = _build_middleware_pipelines(agent_middleware, middleware) # Add function middleware pipeline to kwargs if available if function_pipeline.has_middlewares: kwargs["_function_middleware_pipeline"] = function_pipeline + # Pass chat middleware through kwargs for run-level application + if chat_middlewares: + kwargs["middleware"] = chat_middlewares + normalized_messages = self._normalize_messages(messages) # Execute with middleware if available @@ -729,10 +932,11 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]: agent=self, # type: ignore[arg-type] messages=normalized_messages, is_streaming=True, + kwargs=kwargs, ) async def _execute_stream_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]: - async for update in original_run_stream(self, ctx.messages, thread=thread, **kwargs): # type: ignore[misc] + async for update in original_run_stream(self, ctx.messages, thread=thread, **ctx.kwargs): # type: ignore[misc] yield update async def _stream_generator() -> AsyncIterable[AgentRunResponseUpdate]: @@ -753,3 +957,229 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]: agent_class.run_stream = middleware_enabled_run_stream # type: ignore return agent_class + + +def use_chat_middleware(chat_client_class: type[TChatClient]) -> type[TChatClient]: + """Class decorator that adds middleware support to a chat client class. + + This decorator adds middleware functionality to any chat client class. + It wraps the get_response() and get_streaming_response() methods to provide middleware execution. + + Args: + chat_client_class: The chat client class to add middleware support to. + + Returns: + The modified chat client class with middleware support. + """ + # Store original methods + original_get_response = chat_client_class.get_response + original_get_streaming_response = chat_client_class.get_streaming_response + + async def middleware_enabled_get_response( + self: Any, + messages: Any, + **kwargs: Any, + ) -> Any: + """Middleware-enabled get_response method.""" + # Check if middleware is provided at call level or instance level + call_middleware = kwargs.pop("middleware", None) + instance_middleware = getattr(self, "middleware", None) + + # Merge all middleware and separate by type + middleware = categorize_middleware(instance_middleware, call_middleware) + chat_middleware_list = middleware["chat"] # type: ignore[assignment] + + # Extract function middleware for the function invocation pipeline + function_middleware_list = middleware["function"] + + # Pass function middleware to function invocation system if present + if function_middleware_list: + kwargs["_function_middleware_pipeline"] = FunctionMiddlewarePipeline(function_middleware_list) # type: ignore[arg-type] + + # If no chat middleware, use original method + if not chat_middleware_list: + return await original_get_response(self, messages, **kwargs) + + # Create pipeline and execute with middleware + from ._types import ChatOptions + + # Extract chat_options or create default + chat_options = kwargs.pop("chat_options", ChatOptions()) + + pipeline = ChatMiddlewarePipeline(chat_middleware_list) # type: ignore[arg-type] + context = ChatContext( + chat_client=self, + messages=self.prepare_messages(messages), + chat_options=chat_options, + is_streaming=False, + kwargs=kwargs, + ) + + async def final_handler(ctx: ChatContext) -> Any: + return await original_get_response(self, list(ctx.messages), chat_options=ctx.chat_options, **ctx.kwargs) + + return await pipeline.execute( + chat_client=self, + messages=context.messages, + chat_options=context.chat_options, + context=context, + final_handler=final_handler, + **kwargs, + ) + + def middleware_enabled_get_streaming_response( + self: Any, + messages: Any, + **kwargs: Any, + ) -> Any: + """Middleware-enabled get_streaming_response method.""" + + async def _stream_generator() -> Any: + # Check if middleware is provided at call level or instance level + call_middleware = kwargs.pop("middleware", None) + instance_middleware = getattr(self, "middleware", None) + + # Merge middleware from both sources, filtering for chat middleware only + all_middleware: list[ChatMiddleware | ChatMiddlewareCallable] = _merge_and_filter_chat_middleware( + instance_middleware, call_middleware + ) + + # If no middleware, use original method + if not all_middleware: + async for update in original_get_streaming_response(self, messages, **kwargs): + yield update + return + + # Create pipeline and execute with middleware + from ._types import ChatOptions + + # Extract chat_options or create default + chat_options = kwargs.pop("chat_options", ChatOptions()) + + pipeline = ChatMiddlewarePipeline(all_middleware) # type: ignore[arg-type] + context = ChatContext( + chat_client=self, + messages=self.prepare_messages(messages), + chat_options=chat_options, + is_streaming=True, + kwargs=kwargs, + ) + + def final_handler(ctx: ChatContext) -> Any: + return original_get_streaming_response( + self, list(ctx.messages), chat_options=ctx.chat_options, **ctx.kwargs + ) + + async for update in pipeline.execute_stream( + chat_client=self, + messages=context.messages, + chat_options=context.chat_options, + context=context, + final_handler=final_handler, + **kwargs, + ): + yield update + + return _stream_generator() + + # Replace methods + chat_client_class.get_response = middleware_enabled_get_response # type: ignore + chat_client_class.get_streaming_response = middleware_enabled_get_streaming_response # type: ignore + + return chat_client_class + + +def categorize_middleware( + *middleware_sources: Any | list[Any] | None, +) -> dict[str, list[Any]]: + """Categorize middleware from multiple sources into agent, function, and chat types. + + Args: + *middleware_sources: Variable number of middleware sources to categorize. + + Returns: + Dict with keys "agent", "function", "chat" containing lists of categorized middleware. + """ + result: dict[str, list[Any]] = {"agent": [], "function": [], "chat": []} + + # Merge all middleware sources into a single list + all_middleware: list[Any] = [] + for source in middleware_sources: + if source: + if isinstance(source, list): + all_middleware.extend(source) # type: ignore + else: + all_middleware.append(source) + + # Categorize each middleware item + for middleware in all_middleware: + if isinstance(middleware, AgentMiddleware): + result["agent"].append(middleware) + elif isinstance(middleware, FunctionMiddleware): + result["function"].append(middleware) + elif isinstance(middleware, ChatMiddleware): + result["chat"].append(middleware) + elif callable(middleware): + # Always call _determine_middleware_type to ensure proper validation + middleware_type = _determine_middleware_type(middleware) + if middleware_type == MiddlewareType.AGENT: + result["agent"].append(middleware) + elif middleware_type == MiddlewareType.FUNCTION: + result["function"].append(middleware) + elif middleware_type == MiddlewareType.CHAT: + result["chat"].append(middleware) + else: + # Fallback to agent middleware for unknown types + result["agent"].append(middleware) + + return result + + +def create_function_middleware_pipeline( + *middleware_sources: list[Middleware] | None, +) -> FunctionMiddlewarePipeline | None: + """Create a function middleware pipeline from multiple middleware sources.""" + middleware = categorize_middleware(*middleware_sources) + function_middlewares = middleware["function"] + return FunctionMiddlewarePipeline(function_middlewares) if function_middlewares else None # type: ignore[arg-type] + + +def _merge_and_filter_chat_middleware( + instance_middleware: Any | list[Any] | None, + call_middleware: Any | list[Any] | None, +) -> list[ChatMiddleware | ChatMiddlewareCallable]: + """Merge instance-level and call-level middleware, filtering for chat middleware only. + + Args: + instance_middleware: Middleware defined at the instance level. + call_middleware: Middleware provided at the call level. + + Returns: + A merged list of chat middleware only. + """ + middleware = categorize_middleware(instance_middleware, call_middleware) + return middleware["chat"] # type: ignore[return-value] + + +def extract_and_merge_function_middleware(chat_client: Any, kwargs: dict[str, Any]) -> None: + """Extract function middleware from chat client and merge with existing pipeline in kwargs. + + Args: + chat_client: The chat client instance to extract middleware from. + kwargs: Dictionary containing middleware and pipeline information. + """ + # Get middleware sources + client_middleware = getattr(chat_client, "middleware", None) if hasattr(chat_client, "middleware") else None + run_level_middleware = kwargs.get("middleware") + existing_pipeline = kwargs.get("_function_middleware_pipeline") + + # Extract existing pipeline middlewares if present + existing_middlewares = existing_pipeline._middlewares if existing_pipeline else None + + # Create combined pipeline from all sources using existing helper + combined_pipeline = create_function_middleware_pipeline( + client_middleware, run_level_middleware, existing_middlewares + ) + + if combined_pipeline: + kwargs["_function_middleware_pipeline"] = combined_pipeline diff --git a/python/packages/main/agent_framework/_tools.py b/python/packages/main/agent_framework/_tools.py index 0e1f08295f..43181702c7 100644 --- a/python/packages/main/agent_framework/_tools.py +++ b/python/packages/main/agent_framework/_tools.py @@ -607,6 +607,7 @@ async def _auto_invoke_function( middleware_context = FunctionInvocationContext( function=tool, arguments=args, + kwargs=custom_args or {}, ) async def final_function_handler(context_obj: Any) -> Any: @@ -721,8 +722,16 @@ def _handle_function_calls_response( **kwargs: Any, ) -> "ChatResponse": from ._clients import prepare_messages + from ._middleware import extract_and_merge_function_middleware from ._types import ChatMessage, ChatOptions, FunctionCallContent, FunctionResultContent + # Extract and merge function middleware from chat client with kwargs pipeline + extract_and_merge_function_middleware(self, kwargs) + + # Extract the middleware pipeline before calling the underlying function + # because the underlying function may not preserve it in kwargs + stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline") + prepped_messages = prepare_messages(messages) response: "ChatResponse | None" = None fcc_messages: "list[ChatMessage]" = [] @@ -746,8 +755,9 @@ def _handle_function_calls_response( if not tools and (chat_options := kwargs.get("chat_options")) and isinstance(chat_options, ChatOptions): tools = chat_options.tools if function_calls and tools: - # Extract function middleware pipeline from kwargs if available - middleware_pipeline = kwargs.get("_function_middleware_pipeline") + # Use the stored middleware pipeline instead of extracting from kwargs + # because kwargs may have been modified by the underlying function + middleware_pipeline = stored_middleware_pipeline function_results = await execute_function_calls( custom_args=kwargs, attempt_idx=attempt_idx, @@ -820,8 +830,16 @@ def _handle_function_calls_streaming_response( ) -> AsyncIterable["ChatResponseUpdate"]: """Wrap the inner get streaming response method to handle tool calls.""" from ._clients import prepare_messages + from ._middleware import extract_and_merge_function_middleware from ._types import ChatMessage, ChatOptions, ChatResponse, ChatResponseUpdate, FunctionCallContent + # Extract and merge function middleware from chat client with kwargs pipeline + extract_and_merge_function_middleware(self, kwargs) + + # Extract the middleware pipeline before calling the underlying function + # because the underlying function may not preserve it in kwargs + stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline") + prepped_messages = prepare_messages(messages) for attempt_idx in range(max_iterations): all_updates: list["ChatResponseUpdate"] = [] @@ -858,8 +876,9 @@ def _handle_function_calls_streaming_response( tools = chat_options.tools if function_calls and tools: - # Extract function middleware pipeline from kwargs if available - middleware_pipeline = kwargs.get("_function_middleware_pipeline") + # Use the stored middleware pipeline instead of extracting from kwargs + # because kwargs may have been modified by the underlying function + middleware_pipeline = stored_middleware_pipeline function_results = await execute_function_calls( custom_args=kwargs, attempt_idx=attempt_idx, diff --git a/python/packages/main/agent_framework/azure/_chat_client.py b/python/packages/main/agent_framework/azure/_chat_client.py index 045a22aba7..53500e9726 100644 --- a/python/packages/main/agent_framework/azure/_chat_client.py +++ b/python/packages/main/agent_framework/azure/_chat_client.py @@ -13,16 +13,18 @@ from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice from pydantic import SecretStr, ValidationError from pydantic.networks import AnyUrl -from .._tools import use_function_invocation -from .._types import ( +from agent_framework import ( ChatResponse, ChatResponseUpdate, CitationAnnotation, TextContent, + use_chat_middleware, + use_function_invocation, ) -from ..exceptions import ServiceInitializationError -from ..observability import use_observability -from ..openai._chat_client import OpenAIBaseChatClient +from agent_framework.exceptions import ServiceInitializationError +from agent_framework.observability import use_observability +from agent_framework.openai._chat_client import OpenAIBaseChatClient + from ._shared import ( AzureOpenAIConfigMixin, AzureOpenAISettings, @@ -41,6 +43,7 @@ TAzureOpenAIChatClient = TypeVar("TAzureOpenAIChatClient", bound="AzureOpenAICha @use_function_invocation @use_observability +@use_chat_middleware class AzureOpenAIChatClient(AzureOpenAIConfigMixin, OpenAIBaseChatClient): """Azure OpenAI Chat completion class.""" diff --git a/python/packages/main/agent_framework/azure/_responses_client.py b/python/packages/main/agent_framework/azure/_responses_client.py index 5f61690a81..19dc4b3fbb 100644 --- a/python/packages/main/agent_framework/azure/_responses_client.py +++ b/python/packages/main/agent_framework/azure/_responses_client.py @@ -9,10 +9,11 @@ from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI from pydantic import SecretStr, ValidationError from pydantic.networks import AnyUrl -from .._tools import use_function_invocation -from ..exceptions import ServiceInitializationError -from ..observability import use_observability -from ..openai._responses_client import OpenAIBaseResponsesClient +from agent_framework import use_chat_middleware, use_function_invocation +from agent_framework.exceptions import ServiceInitializationError +from agent_framework.observability import use_observability +from agent_framework.openai._responses_client import OpenAIBaseResponsesClient + from ._shared import ( AzureOpenAIConfigMixin, AzureOpenAISettings, @@ -23,6 +24,7 @@ TAzureOpenAIResponsesClient = TypeVar("TAzureOpenAIResponsesClient", bound="Azur @use_observability @use_function_invocation +@use_chat_middleware class AzureOpenAIResponsesClient(AzureOpenAIConfigMixin, OpenAIBaseResponsesClient): """Azure Responses completion class.""" diff --git a/python/packages/main/agent_framework/exceptions.py b/python/packages/main/agent_framework/exceptions.py index 242b83c995..128981022b 100644 --- a/python/packages/main/agent_framework/exceptions.py +++ b/python/packages/main/agent_framework/exceptions.py @@ -128,3 +128,9 @@ class AdditionItemMismatch(AgentFrameworkException): """An error occurred while adding two types.""" pass + + +class MiddlewareException(AgentFrameworkException): + """An error occurred during middleware execution.""" + + pass diff --git a/python/packages/main/agent_framework/openai/_assistants_client.py b/python/packages/main/agent_framework/openai/_assistants_client.py index b9dd574f65..70733f7287 100644 --- a/python/packages/main/agent_framework/openai/_assistants_client.py +++ b/python/packages/main/agent_framework/openai/_assistants_client.py @@ -21,6 +21,7 @@ from openai.types.beta.threads.runs import RunStep from pydantic import Field, PrivateAttr, SecretStr, ValidationError from .._clients import BaseChatClient +from .._middleware import use_chat_middleware from .._tools import AIFunction, HostedCodeInterpreterTool, HostedFileSearchTool, use_function_invocation from .._types import ( ChatMessage, @@ -52,6 +53,7 @@ __all__ = ["OpenAIAssistantsClient"] @use_function_invocation @use_observability +@use_chat_middleware class OpenAIAssistantsClient(OpenAIConfigMixin, BaseChatClient): """OpenAI Assistants client.""" diff --git a/python/packages/main/agent_framework/openai/_chat_client.py b/python/packages/main/agent_framework/openai/_chat_client.py index 8696bdc737..abc12bcc70 100644 --- a/python/packages/main/agent_framework/openai/_chat_client.py +++ b/python/packages/main/agent_framework/openai/_chat_client.py @@ -18,6 +18,7 @@ from pydantic import BaseModel, SecretStr, ValidationError from .._clients import BaseChatClient from .._logging import get_logger +from .._middleware import use_chat_middleware from .._tools import AIFunction, HostedWebSearchTool, ToolProtocol, use_function_invocation from .._types import ( ChatMessage, @@ -452,6 +453,7 @@ TOpenAIChatClient = TypeVar("TOpenAIChatClient", bound="OpenAIChatClient") @use_function_invocation @use_observability +@use_chat_middleware class OpenAIChatClient(OpenAIConfigMixin, OpenAIBaseChatClient): """OpenAI Chat completion class.""" diff --git a/python/packages/main/agent_framework/openai/_responses_client.py b/python/packages/main/agent_framework/openai/_responses_client.py index db4db96a43..9c17e35cc6 100644 --- a/python/packages/main/agent_framework/openai/_responses_client.py +++ b/python/packages/main/agent_framework/openai/_responses_client.py @@ -26,6 +26,7 @@ from pydantic import BaseModel, SecretStr, ValidationError from .._clients import BaseChatClient from .._logging import get_logger +from .._middleware import use_chat_middleware from .._tools import ( AIFunction, HostedCodeInterpreterTool, @@ -933,6 +934,7 @@ TOpenAIResponsesClient = TypeVar("TOpenAIResponsesClient", bound="OpenAIResponse @use_function_invocation @use_observability +@use_chat_middleware class OpenAIResponsesClient(OpenAIConfigMixin, OpenAIBaseResponsesClient): """OpenAI Responses client class.""" diff --git a/python/packages/main/tests/main/conftest.py b/python/packages/main/tests/main/conftest.py index 579d6f0eee..f7813cdd0f 100644 --- a/python/packages/main/tests/main/conftest.py +++ b/python/packages/main/tests/main/conftest.py @@ -25,6 +25,7 @@ from agent_framework import ( TextContent, ToolProtocol, ai_function, + use_chat_middleware, use_function_invocation, ) @@ -111,11 +112,13 @@ class MockChatClient: yield ChatResponseUpdate(contents=[TextContent(text="another update")], role="assistant") +@use_chat_middleware class MockBaseChatClient(BaseChatClient): """Mock implementation of the BaseChatClient.""" run_responses: list[ChatResponse] = Field(default_factory=list) streaming_responses: list[list[ChatResponseUpdate]] = Field(default_factory=list) + call_count: int = Field(default=0) @override async def _inner_get_response( @@ -136,6 +139,7 @@ class MockBaseChatClient(BaseChatClient): The chat response contents representing the response(s). """ logger.debug(f"Running base chat client inner, with: {messages=}, {chat_options=}, {kwargs=}") + self.call_count += 1 if not self.run_responses: return ChatResponse(messages=ChatMessage(role="assistant", text=f"test response - {messages[0].text}")) diff --git a/python/packages/main/tests/main/test_middleware.py b/python/packages/main/tests/main/test_middleware.py index a530c4c5e8..4051fdad9d 100644 --- a/python/packages/main/tests/main/test_middleware.py +++ b/python/packages/main/tests/main/test_middleware.py @@ -12,6 +12,8 @@ from agent_framework import ( AgentRunResponse, AgentRunResponseUpdate, ChatMessage, + ChatResponse, + ChatResponseUpdate, Role, TextContent, ) @@ -19,11 +21,15 @@ from agent_framework._middleware import ( AgentMiddleware, AgentMiddlewarePipeline, AgentRunContext, + ChatContext, + ChatMiddleware, + ChatMiddlewarePipeline, FunctionInvocationContext, FunctionMiddleware, FunctionMiddlewarePipeline, ) from agent_framework._tools import AIFunction +from agent_framework._types import ChatOptions class TestAgentRunContext: @@ -74,6 +80,46 @@ class TestFunctionInvocationContext: assert context.metadata == metadata +class TestChatContext: + """Test cases for ChatContext.""" + + def test_init_with_defaults(self, mock_chat_client: Any) -> None: + """Test ChatContext initialization with default values.""" + messages = [ChatMessage(role=Role.USER, text="test")] + chat_options = ChatOptions() + context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + + assert context.chat_client is mock_chat_client + assert context.messages == messages + assert context.chat_options is chat_options + assert context.is_streaming is False + assert context.metadata == {} + assert context.result is None + assert context.terminate is False + + def test_init_with_custom_values(self, mock_chat_client: Any) -> None: + """Test ChatContext initialization with custom values.""" + messages = [ChatMessage(role=Role.USER, text="test")] + chat_options = ChatOptions(temperature=0.5) + metadata = {"key": "value"} + + context = ChatContext( + chat_client=mock_chat_client, + messages=messages, + chat_options=chat_options, + is_streaming=True, + metadata=metadata, + terminate=True, + ) + + assert context.chat_client is mock_chat_client + assert context.messages == messages + assert context.chat_options is chat_options + assert context.is_streaming is True + assert context.metadata == metadata + assert context.terminate is True + + class TestAgentMiddlewarePipeline: """Test cases for AgentMiddlewarePipeline.""" @@ -410,6 +456,233 @@ class TestFunctionMiddlewarePipeline: assert execution_order == ["test_before", "handler", "test_after"] +class TestChatMiddlewarePipeline: + """Test cases for ChatMiddlewarePipeline.""" + + class PreNextTerminateChatMiddleware(ChatMiddleware): + async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + context.terminate = True + await next(context) + + class PostNextTerminateChatMiddleware(ChatMiddleware): + async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + await next(context) + context.terminate = True + + def test_init_empty(self) -> None: + """Test ChatMiddlewarePipeline initialization with no middlewares.""" + pipeline = ChatMiddlewarePipeline() + assert not pipeline.has_middlewares + + def test_init_with_class_middleware(self) -> None: + """Test ChatMiddlewarePipeline initialization with class-based middleware.""" + middleware = TestChatMiddleware() + pipeline = ChatMiddlewarePipeline([middleware]) + assert pipeline.has_middlewares + + def test_init_with_function_middleware(self) -> None: + """Test ChatMiddlewarePipeline initialization with function-based middleware.""" + + async def test_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + await next(context) + + pipeline = ChatMiddlewarePipeline([test_middleware]) + assert pipeline.has_middlewares + + async def test_execute_no_middleware(self, mock_chat_client: Any) -> None: + """Test pipeline execution with no middleware.""" + pipeline = ChatMiddlewarePipeline() + messages = [ChatMessage(role=Role.USER, text="test")] + chat_options = ChatOptions() + context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + + expected_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + + async def final_handler(ctx: ChatContext) -> ChatResponse: + return expected_response + + result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler) + assert result == expected_response + + async def test_execute_with_middleware(self, mock_chat_client: Any) -> None: + """Test pipeline execution with middleware.""" + execution_order: list[str] = [] + + class OrderTrackingChatMiddleware(ChatMiddleware): + def __init__(self, name: str): + self.name = name + + async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + execution_order.append(f"{self.name}_before") + await next(context) + execution_order.append(f"{self.name}_after") + + middleware = OrderTrackingChatMiddleware("test") + pipeline = ChatMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + chat_options = ChatOptions() + context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + + expected_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + + async def final_handler(ctx: ChatContext) -> ChatResponse: + execution_order.append("handler") + return expected_response + + result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler) + assert result == expected_response + assert execution_order == ["test_before", "handler", "test_after"] + + async def test_execute_stream_no_middleware(self, mock_chat_client: Any) -> None: + """Test pipeline streaming execution with no middleware.""" + pipeline = ChatMiddlewarePipeline() + messages = [ChatMessage(role=Role.USER, text="test")] + chat_options = ChatOptions() + context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + + async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[TextContent(text="chunk1")]) + yield ChatResponseUpdate(contents=[TextContent(text="chunk2")]) + + updates: list[ChatResponseUpdate] = [] + async for update in pipeline.execute_stream(mock_chat_client, messages, chat_options, context, final_handler): + updates.append(update) + + assert len(updates) == 2 + assert updates[0].text == "chunk1" + assert updates[1].text == "chunk2" + + async def test_execute_stream_with_middleware(self, mock_chat_client: Any) -> None: + """Test pipeline streaming execution with middleware.""" + execution_order: list[str] = [] + + class StreamOrderTrackingChatMiddleware(ChatMiddleware): + def __init__(self, name: str): + self.name = name + + async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + execution_order.append(f"{self.name}_before") + await next(context) + execution_order.append(f"{self.name}_after") + + middleware = StreamOrderTrackingChatMiddleware("test") + pipeline = ChatMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + chat_options = ChatOptions() + context = ChatContext( + chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True + ) + + async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]: + execution_order.append("handler_start") + yield ChatResponseUpdate(contents=[TextContent(text="chunk1")]) + yield ChatResponseUpdate(contents=[TextContent(text="chunk2")]) + execution_order.append("handler_end") + + updates: list[ChatResponseUpdate] = [] + async for update in pipeline.execute_stream(mock_chat_client, messages, chat_options, context, final_handler): + updates.append(update) + + assert len(updates) == 2 + assert updates[0].text == "chunk1" + assert updates[1].text == "chunk2" + assert execution_order == ["test_before", "test_after", "handler_start", "handler_end"] + + async def test_execute_with_pre_next_termination(self, mock_chat_client: Any) -> None: + """Test pipeline execution with termination before next().""" + middleware = self.PreNextTerminateChatMiddleware() + pipeline = ChatMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + chat_options = ChatOptions() + context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + execution_order: list[str] = [] + + async def final_handler(ctx: ChatContext) -> ChatResponse: + # Handler should not be executed when terminated before next() + execution_order.append("handler") + return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + + response = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler) + assert response is None + assert context.terminate + # Handler should not be called when terminated before next() + assert execution_order == [] + + async def test_execute_with_post_next_termination(self, mock_chat_client: Any) -> None: + """Test pipeline execution with termination after next().""" + middleware = self.PostNextTerminateChatMiddleware() + pipeline = ChatMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + chat_options = ChatOptions() + context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + execution_order: list[str] = [] + + async def final_handler(ctx: ChatContext) -> ChatResponse: + execution_order.append("handler") + return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + + response = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler) + assert response is not None + assert len(response.messages) == 1 + assert response.messages[0].text == "response" + assert context.terminate + assert execution_order == ["handler"] + + async def test_execute_stream_with_pre_next_termination(self, mock_chat_client: Any) -> None: + """Test pipeline streaming execution with termination before next().""" + middleware = self.PreNextTerminateChatMiddleware() + pipeline = ChatMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + chat_options = ChatOptions() + context = ChatContext( + chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True + ) + execution_order: list[str] = [] + + async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]: + # Handler should not be executed when terminated before next() + execution_order.append("handler_start") + yield ChatResponseUpdate(contents=[TextContent(text="chunk1")]) + yield ChatResponseUpdate(contents=[TextContent(text="chunk2")]) + execution_order.append("handler_end") + + updates: list[ChatResponseUpdate] = [] + async for update in pipeline.execute_stream(mock_chat_client, messages, chat_options, context, final_handler): + updates.append(update) + + assert context.terminate + # Handler should not be called when terminated before next() + assert execution_order == [] + assert not updates + + async def test_execute_stream_with_post_next_termination(self, mock_chat_client: Any) -> None: + """Test pipeline streaming execution with termination after next().""" + middleware = self.PostNextTerminateChatMiddleware() + pipeline = ChatMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + chat_options = ChatOptions() + context = ChatContext( + chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True + ) + execution_order: list[str] = [] + + async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]: + execution_order.append("handler_start") + yield ChatResponseUpdate(contents=[TextContent(text="chunk1")]) + yield ChatResponseUpdate(contents=[TextContent(text="chunk2")]) + execution_order.append("handler_end") + + updates: list[ChatResponseUpdate] = [] + async for update in pipeline.execute_stream(mock_chat_client, messages, chat_options, context, final_handler): + updates.append(update) + + assert len(updates) == 2 + assert updates[0].text == "chunk1" + assert updates[1].text == "chunk2" + assert context.terminate + assert execution_order == ["handler_start", "handler_end"] + + class TestClassBasedMiddleware: """Test cases for class-based middleware implementations.""" @@ -601,6 +874,37 @@ class TestMixedMiddleware: assert result == "result" assert execution_order == ["class_before", "function_before", "handler", "function_after", "class_after"] + async def test_mixed_chat_middleware(self, mock_chat_client: Any) -> None: + """Test mixed class and function-based chat middleware.""" + execution_order: list[str] = [] + + class ClassChatMiddleware(ChatMiddleware): + async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + execution_order.append("class_before") + await next(context) + execution_order.append("class_after") + + async def function_chat_middleware( + context: ChatContext, next: Callable[[ChatContext], Awaitable[None]] + ) -> None: + execution_order.append("function_before") + await next(context) + execution_order.append("function_after") + + pipeline = ChatMiddlewarePipeline([ClassChatMiddleware(), function_chat_middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + chat_options = ChatOptions() + context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + + async def final_handler(ctx: ChatContext) -> ChatResponse: + execution_order.append("handler") + return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + + result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler) + + assert result is not None + assert execution_order == ["class_before", "function_before", "handler", "function_after", "class_after"] + class TestMultipleMiddlewareOrdering: """Test cases for multiple middleware execution order.""" @@ -695,6 +999,52 @@ class TestMultipleMiddlewareOrdering: expected_order = ["first_before", "second_before", "handler", "second_after", "first_after"] assert execution_order == expected_order + async def test_chat_middleware_execution_order(self, mock_chat_client: Any) -> None: + """Test that multiple chat middlewares execute in registration order.""" + execution_order: list[str] = [] + + class FirstChatMiddleware(ChatMiddleware): + async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + execution_order.append("first_before") + await next(context) + execution_order.append("first_after") + + class SecondChatMiddleware(ChatMiddleware): + async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + execution_order.append("second_before") + await next(context) + execution_order.append("second_after") + + class ThirdChatMiddleware(ChatMiddleware): + async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + execution_order.append("third_before") + await next(context) + execution_order.append("third_after") + + middlewares = [FirstChatMiddleware(), SecondChatMiddleware(), ThirdChatMiddleware()] + pipeline = ChatMiddlewarePipeline(middlewares) # type: ignore + messages = [ChatMessage(role=Role.USER, text="test")] + chat_options = ChatOptions() + context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + + async def final_handler(ctx: ChatContext) -> ChatResponse: + execution_order.append("handler") + return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + + result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler) + + assert result is not None + expected_order = [ + "first_before", + "second_before", + "third_before", + "handler", + "third_after", + "second_after", + "first_after", + ] + assert execution_order == expected_order + class TestContextContentValidation: """Test cases for validating middleware context content.""" @@ -776,6 +1126,49 @@ class TestContextContentValidation: result = await pipeline.execute(mock_function, arguments, context, final_handler) assert result == "result" + async def test_chat_context_validation(self, mock_chat_client: Any) -> None: + """Test that chat context contains expected data.""" + + class ChatContextValidationMiddleware(ChatMiddleware): + async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + # Verify context has all expected attributes + assert hasattr(context, "chat_client") + assert hasattr(context, "messages") + assert hasattr(context, "chat_options") + assert hasattr(context, "is_streaming") + assert hasattr(context, "metadata") + assert hasattr(context, "result") + assert hasattr(context, "terminate") + + # Verify context content + assert context.chat_client is mock_chat_client + assert len(context.messages) == 1 + assert context.messages[0].role == Role.USER + assert context.messages[0].text == "test" + assert context.is_streaming is False + assert isinstance(context.metadata, dict) + assert isinstance(context.chat_options, ChatOptions) + assert context.chat_options.temperature == 0.5 + + # Add custom metadata + context.metadata["validated"] = True + + await next(context) + + middleware = ChatContextValidationMiddleware() + pipeline = ChatMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + chat_options = ChatOptions(temperature=0.5) + context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + + async def final_handler(ctx: ChatContext) -> ChatResponse: + # Verify metadata was set by middleware + assert ctx.metadata.get("validated") is True + return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + + result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler) + assert result is not None + class TestStreamingScenarios: """Test cases for streaming and non-streaming scenarios.""" @@ -857,6 +1250,89 @@ class TestStreamingScenarios: "stream_end", ] + async def test_chat_streaming_flag_validation(self, mock_chat_client: Any) -> None: + """Test that is_streaming flag is correctly set for chat streaming calls.""" + streaming_flags: list[bool] = [] + + class ChatStreamingFlagMiddleware(ChatMiddleware): + async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + streaming_flags.append(context.is_streaming) + await next(context) + + middleware = ChatStreamingFlagMiddleware() + pipeline = ChatMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + chat_options = ChatOptions() + + # Test non-streaming + context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + + async def final_handler(ctx: ChatContext) -> ChatResponse: + streaming_flags.append(ctx.is_streaming) + return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + + await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler) + + # Test streaming + context_stream = ChatContext( + chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True + ) + + async def final_stream_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]: + streaming_flags.append(ctx.is_streaming) + yield ChatResponseUpdate(contents=[TextContent(text="chunk")]) + + updates: list[ChatResponseUpdate] = [] + async for update in pipeline.execute_stream( + mock_chat_client, messages, chat_options, context_stream, final_stream_handler + ): + updates.append(update) + + # Verify flags: [non-streaming middleware, non-streaming handler, streaming middleware, streaming handler] + assert streaming_flags == [False, False, True, True] + + async def test_chat_streaming_middleware_behavior(self, mock_chat_client: Any) -> None: + """Test chat middleware behavior with streaming responses.""" + chunks_processed: list[str] = [] + + class ChatStreamProcessingMiddleware(ChatMiddleware): + async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + chunks_processed.append("before_stream") + await next(context) + chunks_processed.append("after_stream") + + middleware = ChatStreamProcessingMiddleware() + pipeline = ChatMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + chat_options = ChatOptions() + context = ChatContext( + chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True + ) + + async def final_stream_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]: + chunks_processed.append("stream_start") + yield ChatResponseUpdate(contents=[TextContent(text="chunk1")]) + chunks_processed.append("chunk1_yielded") + yield ChatResponseUpdate(contents=[TextContent(text="chunk2")]) + chunks_processed.append("chunk2_yielded") + chunks_processed.append("stream_end") + + updates: list[str] = [] + async for update in pipeline.execute_stream( + mock_chat_client, messages, chat_options, context, final_stream_handler + ): + updates.append(update.text) + + assert updates == ["chunk1", "chunk2"] + assert chunks_processed == [ + "before_stream", + "after_stream", + "stream_start", + "chunk1_yielded", + "chunk2_yielded", + "stream_end", + ] + # region Helper classes and fixtures @@ -883,6 +1359,13 @@ class TestFunctionMiddleware(FunctionMiddleware): await next(context) +class TestChatMiddleware(ChatMiddleware): + """Test implementation of ChatMiddleware.""" + + async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + await next(context) + + class MockFunctionArgs(BaseModel): """Test arguments for function middleware tests.""" @@ -1027,6 +1510,100 @@ class TestMiddlewareExecutionControl: assert result.messages == [] # Empty response assert not handler_called + async def test_chat_middleware_no_next_no_execution(self, mock_chat_client: Any) -> None: + """Test that when chat middleware doesn't call next(), no execution happens.""" + + class NoNextChatMiddleware(ChatMiddleware): + async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + # Don't call next() - this should prevent any execution + pass + + middleware = NoNextChatMiddleware() + pipeline = ChatMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + chat_options = ChatOptions() + context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + + handler_called = False + + async def final_handler(ctx: ChatContext) -> ChatResponse: + nonlocal handler_called + handler_called = True + return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="should not execute")]) + + result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler) + + # Verify no execution happened + assert result is None + assert not handler_called + assert context.result is None + + async def test_chat_middleware_no_next_no_streaming_execution(self, mock_chat_client: Any) -> None: + """Test that when chat middleware doesn't call next(), no streaming execution happens.""" + + class NoNextStreamingChatMiddleware(ChatMiddleware): + async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + # Don't call next() - this should prevent any execution + pass + + middleware = NoNextStreamingChatMiddleware() + pipeline = ChatMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + chat_options = ChatOptions() + context = ChatContext( + chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True + ) + + handler_called = False + + async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]: + nonlocal handler_called + handler_called = True + yield ChatResponseUpdate(contents=[TextContent(text="should not execute")]) + + # When middleware doesn't call next(), streaming should yield no updates + updates: list[ChatResponseUpdate] = [] + async for update in pipeline.execute_stream(mock_chat_client, messages, chat_options, context, final_handler): + updates.append(update) + + # Verify no execution happened and no updates were yielded + assert len(updates) == 0 + assert not handler_called + assert context.result is None + + async def test_multiple_chat_middlewares_early_stop(self, mock_chat_client: Any) -> None: + """Test that when first chat middleware doesn't call next(), subsequent middlewares are not called.""" + execution_order: list[str] = [] + + class FirstChatMiddleware(ChatMiddleware): + async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + execution_order.append("first") + # Don't call next() - this should stop the pipeline + + class SecondChatMiddleware(ChatMiddleware): + async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + execution_order.append("second") + await next(context) + + pipeline = ChatMiddlewarePipeline([FirstChatMiddleware(), SecondChatMiddleware()]) + messages = [ChatMessage(role=Role.USER, text="test")] + chat_options = ChatOptions() + context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) + + handler_called = False + + async def final_handler(ctx: ChatContext) -> ChatResponse: + nonlocal handler_called + handler_called = True + return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="should not execute")]) + + result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler) + + # Verify only first middleware was called and no result returned + assert execution_order == ["first"] + assert result is None + assert not handler_called + @pytest.fixture def mock_agent() -> AgentProtocol: @@ -1042,3 +1619,13 @@ def mock_function() -> AIFunction[Any, Any]: function = MagicMock(spec=AIFunction[Any, Any]) function.name = "test_function" return function + + +@pytest.fixture +def mock_chat_client() -> Any: + """Mock chat client for testing.""" + from agent_framework._clients import ChatClientProtocol + + client = MagicMock(spec=ChatClientProtocol) + client.service_url = MagicMock(return_value="mock://test") + return client diff --git a/python/packages/main/tests/main/test_middleware_with_agent.py b/python/packages/main/tests/main/test_middleware_with_agent.py index 6ef89e5c97..025ce9b3ae 100644 --- a/python/packages/main/tests/main/test_middleware_with_agent.py +++ b/python/packages/main/tests/main/test_middleware_with_agent.py @@ -8,7 +8,9 @@ import pytest from agent_framework import ( AgentRunResponseUpdate, ChatAgent, + ChatContext, ChatMessage, + ChatMiddleware, ChatResponse, ChatResponseUpdate, FunctionCallContent, @@ -16,7 +18,9 @@ from agent_framework import ( Role, TextContent, agent_middleware, + chat_middleware, function_middleware, + use_function_invocation, ) from agent_framework._middleware import ( AgentMiddleware, @@ -25,8 +29,9 @@ from agent_framework._middleware import ( FunctionMiddleware, MiddlewareType, ) +from agent_framework.exceptions import MiddlewareException -from .conftest import MockChatClient +from .conftest import MockBaseChatClient, MockChatClient # region ChatAgent Tests @@ -713,6 +718,101 @@ class TestChatAgentFunctionMiddlewareWithTools: assert function_calls[0].name == "sample_tool_function" assert function_results[0].call_id == function_calls[0].call_id + async def test_function_middleware_can_access_and_override_custom_kwargs( + self, chat_client: "MockChatClient" + ) -> None: + """Test that function middleware can access and override custom parameters like temperature.""" + captured_kwargs: dict[str, Any] = {} + modified_kwargs: dict[str, Any] = {} + middleware_called = False + + @function_middleware + async def kwargs_middleware( + context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]] + ) -> None: + nonlocal middleware_called + middleware_called = True + + # Capture the original kwargs + captured_kwargs["has_chat_options"] = "chat_options" in context.kwargs + captured_kwargs["has_custom_param"] = "custom_param" in context.kwargs + captured_kwargs["custom_param"] = context.kwargs.get("custom_param") + + # Capture original chat_options values if present + if "chat_options" in context.kwargs: + chat_options = context.kwargs["chat_options"] + captured_kwargs["original_temperature"] = getattr(chat_options, "temperature", None) + captured_kwargs["original_max_tokens"] = getattr(chat_options, "max_tokens", None) + + # Modify some kwargs + context.kwargs["temperature"] = 0.9 + context.kwargs["max_tokens"] = 500 + context.kwargs["new_param"] = "added_by_middleware" + + # Also modify chat_options if present + if "chat_options" in context.kwargs: + context.kwargs["chat_options"].temperature = 0.9 + context.kwargs["chat_options"].max_tokens = 500 + + # Store modified kwargs for verification + modified_kwargs["temperature"] = context.kwargs.get("temperature") + modified_kwargs["max_tokens"] = context.kwargs.get("max_tokens") + modified_kwargs["new_param"] = context.kwargs.get("new_param") + modified_kwargs["custom_param"] = context.kwargs.get("custom_param") + + # Capture modified chat_options values if present + if "chat_options" in context.kwargs: + chat_options = context.kwargs["chat_options"] + modified_kwargs["chat_options_temperature"] = getattr(chat_options, "temperature", None) + modified_kwargs["chat_options_max_tokens"] = getattr(chat_options, "max_tokens", None) + + await next(context) + + chat_client.responses = [ + ChatResponse( + messages=[ + ChatMessage( + role=Role.ASSISTANT, + contents=[ + FunctionCallContent( + call_id="test_call", name="sample_tool_function", arguments={"location": "Seattle"} + ) + ], + ) + ] + ), + ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent("Function completed")])]), + ] + + # Create ChatAgent with function middleware + agent = ChatAgent(chat_client=chat_client, middleware=[kwargs_middleware], tools=[sample_tool_function]) + + # Execute the agent with custom parameters + messages = [ChatMessage(role=Role.USER, text="test message")] + response = await agent.run(messages, temperature=0.7, max_tokens=100, custom_param="test_value") + + # Verify response + assert response is not None + assert len(response.messages) > 0 + + # First check if middleware was called at all + assert middleware_called, "Function middleware was not called" + + # Verify middleware captured the original kwargs + assert captured_kwargs["has_chat_options"] is True + assert captured_kwargs["has_custom_param"] is True + assert captured_kwargs["custom_param"] == "test_value" + assert captured_kwargs["original_temperature"] == 0.7 + assert captured_kwargs["original_max_tokens"] == 100 + + # Verify middleware could modify the kwargs + assert modified_kwargs["temperature"] == 0.9 + assert modified_kwargs["max_tokens"] == 500 + assert modified_kwargs["new_param"] == "added_by_middleware" + assert modified_kwargs["custom_param"] == "test_value" + assert modified_kwargs["chat_options_temperature"] == 0.9 + assert modified_kwargs["chat_options_max_tokens"] == 500 + class TestMiddlewareDynamicRebuild: """Test cases for dynamic middleware pipeline rebuilding with ChatAgent.""" @@ -1187,8 +1287,8 @@ class TestMiddlewareDecoratorLogic: """Both decorator and parameter type specified but don't match.""" # This will cause a type error at decoration time, so we need to test differently - # Should raise ValueError due to mismatch during agent creation - with pytest.raises(ValueError, match="Middleware type mismatch"): + # Should raise MiddlewareException due to mismatch during agent creation + with pytest.raises(MiddlewareException, match="Middleware type mismatch"): @agent_middleware # type: ignore[arg-type] async def mismatched_middleware( @@ -1304,8 +1404,8 @@ class TestMiddlewareDecoratorLogic: async def no_info_middleware(context: Any, next: Any) -> None: # No decorator, no type await next(context) - # Should raise ValueError - with pytest.raises(ValueError, match="Cannot determine middleware type"): + # Should raise MiddlewareException + with pytest.raises(MiddlewareException, match="Cannot determine middleware type"): agent = ChatAgent(chat_client=chat_client, middleware=[no_info_middleware]) await agent.run([ChatMessage(role=Role.USER, text="test")]) @@ -1313,8 +1413,8 @@ class TestMiddlewareDecoratorLogic: """Test that middleware with insufficient parameters raises an error.""" from agent_framework import ChatAgent, agent_middleware - # Should raise ValueError about insufficient parameters - with pytest.raises(ValueError, match="must have at least 2 parameters"): + # Should raise MiddlewareException about insufficient parameters + with pytest.raises(MiddlewareException, match="must have at least 2 parameters"): @agent_middleware # type: ignore[arg-type] async def insufficient_params_middleware(context: Any) -> None: # Missing 'next' parameter @@ -1340,3 +1440,359 @@ class TestMiddlewareDecoratorLogic: assert hasattr(test_function_middleware, "_middleware_type") assert test_function_middleware._middleware_type == MiddlewareType.FUNCTION # type: ignore[attr-defined] + + +class TestChatAgentChatMiddleware: + """Test cases for chat middleware integration with ChatAgent.""" + + async def test_class_based_chat_middleware_with_chat_agent(self) -> None: + """Test class-based chat middleware with ChatAgent.""" + execution_order: list[str] = [] + + class TrackingChatMiddleware(ChatMiddleware): + async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + execution_order.append("chat_middleware_before") + await next(context) + execution_order.append("chat_middleware_after") + + # Create ChatAgent with chat middleware + chat_client = MockBaseChatClient() + middleware = TrackingChatMiddleware() + agent = ChatAgent(chat_client=chat_client, middleware=[middleware]) + + # Execute the agent + messages = [ChatMessage(role=Role.USER, text="test message")] + response = await agent.run(messages) + + # Verify response + assert response is not None + assert len(response.messages) > 0 + assert response.messages[0].role == Role.ASSISTANT + assert "test response" in response.messages[0].text + assert execution_order == ["chat_middleware_before", "chat_middleware_after"] + + async def test_function_based_chat_middleware_with_chat_agent(self) -> None: + """Test function-based chat middleware with ChatAgent.""" + execution_order: list[str] = [] + + async def tracking_chat_middleware( + context: ChatContext, next: Callable[[ChatContext], Awaitable[None]] + ) -> None: + execution_order.append("chat_middleware_before") + await next(context) + execution_order.append("chat_middleware_after") + + # Create ChatAgent with function-based chat middleware + chat_client = MockBaseChatClient() + agent = ChatAgent(chat_client=chat_client, middleware=[tracking_chat_middleware]) + + # Execute the agent + messages = [ChatMessage(role=Role.USER, text="test message")] + response = await agent.run(messages) + + # Verify response + assert response is not None + assert len(response.messages) > 0 + assert response.messages[0].role == Role.ASSISTANT + assert "test response" in response.messages[0].text + assert execution_order == ["chat_middleware_before", "chat_middleware_after"] + + async def test_chat_middleware_can_modify_messages(self) -> None: + """Test that chat middleware can modify messages before sending to model.""" + + @chat_middleware + async def message_modifier_middleware( + context: ChatContext, next: Callable[[ChatContext], Awaitable[None]] + ) -> None: + # Modify the first message by adding a prefix + if context.messages and len(context.messages) > 0: + original_text = context.messages[0].text or "" + context.messages[0] = ChatMessage(role=context.messages[0].role, text=f"MODIFIED: {original_text}") + await next(context) + + # Create ChatAgent with message-modifying middleware + chat_client = MockBaseChatClient() + agent = ChatAgent(chat_client=chat_client, middleware=[message_modifier_middleware]) + + # Execute the agent + messages = [ChatMessage(role=Role.USER, text="test message")] + response = await agent.run(messages) + + # Verify that the message was modified (MockBaseChatClient echoes back the input) + assert response is not None + assert len(response.messages) > 0 + assert "MODIFIED: test message" in response.messages[0].text + + async def test_chat_middleware_can_override_response(self) -> None: + """Test that chat middleware can override the response.""" + + @chat_middleware + async def response_override_middleware( + context: ChatContext, next: Callable[[ChatContext], Awaitable[None]] + ) -> None: + # Override the response without calling next() + context.result = ChatResponse( + messages=[ChatMessage(role=Role.ASSISTANT, text="Middleware overridden response")], + response_id="middleware-response-123", + ) + context.terminate = True + + # Create ChatAgent with response-overriding middleware + chat_client = MockBaseChatClient() + agent = ChatAgent(chat_client=chat_client, middleware=[response_override_middleware]) + + # Execute the agent + messages = [ChatMessage(role=Role.USER, text="test message")] + response = await agent.run(messages) + + # Verify that the response was overridden + assert response is not None + assert len(response.messages) > 0 + assert response.messages[0].text == "Middleware overridden response" + assert response.response_id == "middleware-response-123" + + async def test_multiple_chat_middleware_execution_order(self) -> None: + """Test that multiple chat middleware execute in the correct order.""" + execution_order: list[str] = [] + + @chat_middleware + async def first_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + execution_order.append("first_before") + await next(context) + execution_order.append("first_after") + + @chat_middleware + async def second_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + execution_order.append("second_before") + await next(context) + execution_order.append("second_after") + + # Create ChatAgent with multiple chat middleware + chat_client = MockBaseChatClient() + agent = ChatAgent(chat_client=chat_client, middleware=[first_middleware, second_middleware]) + + # Execute the agent + messages = [ChatMessage(role=Role.USER, text="test message")] + response = await agent.run(messages) + + # Verify response + assert response is not None + assert execution_order == ["first_before", "second_before", "second_after", "first_after"] + + async def test_chat_middleware_with_streaming(self) -> None: + """Test chat middleware with streaming responses.""" + execution_order: list[str] = [] + streaming_flags: list[bool] = [] + + class StreamingTrackingChatMiddleware(ChatMiddleware): + async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + execution_order.append("streaming_chat_before") + streaming_flags.append(context.is_streaming) + await next(context) + execution_order.append("streaming_chat_after") + + # Create ChatAgent with chat middleware + chat_client = MockBaseChatClient() + agent = ChatAgent(chat_client=chat_client, middleware=[StreamingTrackingChatMiddleware()]) + + # Set up mock streaming responses + chat_client.streaming_responses = [ + [ + ChatResponseUpdate(contents=[TextContent(text="Stream")], role=Role.ASSISTANT), + ChatResponseUpdate(contents=[TextContent(text=" response")], role=Role.ASSISTANT), + ] + ] + + # Execute streaming + messages = [ChatMessage(role=Role.USER, text="test message")] + updates: list[AgentRunResponseUpdate] = [] + async for update in agent.run_stream(messages): + updates.append(update) + + # Verify streaming response + assert len(updates) >= 1 # At least some updates + assert execution_order == ["streaming_chat_before", "streaming_chat_after"] + + # Verify streaming flag was set (at least one True) + assert True in streaming_flags + + async def test_chat_middleware_termination_before_execution(self) -> None: + """Test that chat middleware can terminate execution before calling next().""" + execution_order: list[str] = [] + + class PreTerminationChatMiddleware(ChatMiddleware): + async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + execution_order.append("middleware_before") + context.terminate = True + # Set a custom response since we're terminating + context.result = ChatResponse( + messages=[ChatMessage(role=Role.ASSISTANT, text="Terminated by middleware")] + ) + # We call next() but since terminate=True, execution should stop + await next(context) + execution_order.append("middleware_after") + + # Create ChatAgent with terminating middleware + chat_client = MockBaseChatClient() + agent = ChatAgent(chat_client=chat_client, middleware=[PreTerminationChatMiddleware()]) + + # Execute the agent + messages = [ChatMessage(role=Role.USER, text="test message")] + response = await agent.run(messages) + + # Verify response was from middleware + assert response is not None + assert len(response.messages) > 0 + assert response.messages[0].text == "Terminated by middleware" + assert execution_order == ["middleware_before", "middleware_after"] + + async def test_chat_middleware_termination_after_execution(self) -> None: + """Test that chat middleware can terminate execution after calling next().""" + execution_order: list[str] = [] + + class PostTerminationChatMiddleware(ChatMiddleware): + async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + execution_order.append("middleware_before") + await next(context) + execution_order.append("middleware_after") + context.terminate = True + + # Create ChatAgent with terminating middleware + chat_client = MockBaseChatClient() + agent = ChatAgent(chat_client=chat_client, middleware=[PostTerminationChatMiddleware()]) + + # Execute the agent + messages = [ChatMessage(role=Role.USER, text="test message")] + response = await agent.run(messages) + + # Verify response is from actual execution + assert response is not None + assert len(response.messages) > 0 + assert "test response" in response.messages[0].text + assert execution_order == ["middleware_before", "middleware_after"] + + async def test_combined_middleware(self) -> None: + """Test ChatAgent with combined middleware types.""" + execution_order: list[str] = [] + + async def agent_middleware( + context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_order.append("agent_middleware_before") + await next(context) + execution_order.append("agent_middleware_after") + + async def chat_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + execution_order.append("chat_middleware_before") + await next(context) + execution_order.append("chat_middleware_after") + + async def function_middleware( + context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]] + ) -> None: + execution_order.append("function_middleware_before") + await next(context) + execution_order.append("function_middleware_after") + + # Set up mock to return a function call first, then a regular response + function_call_response = ChatResponse( + messages=[ + ChatMessage( + role=Role.ASSISTANT, + contents=[ + FunctionCallContent( + call_id="call_456", + name="sample_tool_function", + arguments='{"location": "San Francisco"}', + ) + ], + ) + ] + ) + final_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Final response")]) + + chat_client = use_function_invocation(MockBaseChatClient)() + chat_client.run_responses = [function_call_response, final_response] + + # Create ChatAgent with function middleware and tools + agent = ChatAgent( + chat_client=chat_client, + middleware=[chat_middleware, function_middleware, agent_middleware], + tools=[sample_tool_function], + ) + + # Execute the agent + messages = [ChatMessage(role=Role.USER, text="Get weather for San Francisco")] + response = await agent.run(messages) + + # Verify response + assert response is not None + assert len(response.messages) > 0 + assert chat_client.call_count == 2 # Two calls: one for function call, one for final response + + # Verify function middleware was executed + assert execution_order == [ + "agent_middleware_before", + "chat_middleware_before", + "chat_middleware_after", + "function_middleware_before", + "function_middleware_after", + "chat_middleware_before", + "chat_middleware_after", + "agent_middleware_after", + ] + + # Verify function call and result are in the response + all_contents = [content for message in response.messages for content in message.contents] + function_calls = [c for c in all_contents if isinstance(c, FunctionCallContent)] + function_results = [c for c in all_contents if isinstance(c, FunctionResultContent)] + + assert len(function_calls) == 1 + assert len(function_results) == 1 + assert function_calls[0].name == "sample_tool_function" + assert function_results[0].call_id == function_calls[0].call_id + + async def test_agent_middleware_can_access_and_override_custom_kwargs(self) -> None: + """Test that agent middleware can access and override custom parameters like temperature.""" + captured_kwargs: dict[str, Any] = {} + modified_kwargs: dict[str, Any] = {} + + @agent_middleware + async def kwargs_middleware( + context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + # Capture the original kwargs + captured_kwargs.update(context.kwargs) + + # Modify some kwargs + context.kwargs["temperature"] = 0.9 + context.kwargs["max_tokens"] = 500 + context.kwargs["new_param"] = "added_by_middleware" + + # Store modified kwargs for verification + modified_kwargs.update(context.kwargs) + + await next(context) + + # Create ChatAgent with agent middleware + chat_client = MockBaseChatClient() + agent = ChatAgent(chat_client=chat_client, middleware=[kwargs_middleware]) + + # Execute the agent with custom parameters + messages = [ChatMessage(role=Role.USER, text="test message")] + response = await agent.run(messages, temperature=0.7, max_tokens=100, custom_param="test_value") + + # Verify response + assert response is not None + assert len(response.messages) > 0 + + # Verify middleware captured the original kwargs + assert captured_kwargs["temperature"] == 0.7 + assert captured_kwargs["max_tokens"] == 100 + assert captured_kwargs["custom_param"] == "test_value" + + # Verify middleware could modify the kwargs + assert modified_kwargs["temperature"] == 0.9 + assert modified_kwargs["max_tokens"] == 500 + assert modified_kwargs["new_param"] == "added_by_middleware" + assert modified_kwargs["custom_param"] == "test_value" # Should still be there diff --git a/python/packages/main/tests/main/test_middleware_with_chat.py b/python/packages/main/tests/main/test_middleware_with_chat.py new file mode 100644 index 0000000000..e850710e49 --- /dev/null +++ b/python/packages/main/tests/main/test_middleware_with_chat.py @@ -0,0 +1,436 @@ +# Copyright (c) Microsoft. All rights reserved. + +from collections.abc import Awaitable, Callable +from typing import Any + +from agent_framework import ( + ChatAgent, + ChatContext, + ChatMessage, + ChatMiddleware, + ChatResponse, + FunctionCallContent, + FunctionInvocationContext, + Role, + chat_middleware, + function_middleware, + use_function_invocation, +) + +from .conftest import MockBaseChatClient + + +class TestChatMiddleware: + """Test cases for chat middleware functionality.""" + + async def test_class_based_chat_middleware(self, chat_client_base: "MockBaseChatClient") -> None: + """Test class-based chat middleware with ChatClient.""" + execution_order: list[str] = [] + + class LoggingChatMiddleware(ChatMiddleware): + async def process( + self, + context: ChatContext, + next: Callable[[ChatContext], Awaitable[None]], + ) -> None: + execution_order.append("chat_middleware_before") + await next(context) + execution_order.append("chat_middleware_after") + + # Add middleware to chat client + chat_client_base.middleware = [LoggingChatMiddleware()] + + # Execute chat client directly + messages = [ChatMessage(role=Role.USER, text="test message")] + response = await chat_client_base.get_response(messages) + + # Verify response + assert response is not None + assert len(response.messages) > 0 + assert response.messages[0].role == Role.ASSISTANT + + # Verify middleware execution order + assert execution_order == ["chat_middleware_before", "chat_middleware_after"] + + async def test_function_based_chat_middleware(self, chat_client_base: "MockBaseChatClient") -> None: + """Test function-based chat middleware with ChatClient.""" + execution_order: list[str] = [] + + @chat_middleware + async def logging_chat_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + execution_order.append("function_middleware_before") + await next(context) + execution_order.append("function_middleware_after") + + # Add middleware to chat client + chat_client_base.middleware = [logging_chat_middleware] + + # Execute chat client directly + messages = [ChatMessage(role=Role.USER, text="test message")] + response = await chat_client_base.get_response(messages) + + # Verify response + assert response is not None + assert len(response.messages) > 0 + assert response.messages[0].role == Role.ASSISTANT + + # Verify middleware execution order + assert execution_order == ["function_middleware_before", "function_middleware_after"] + + async def test_chat_middleware_can_modify_messages(self, chat_client_base: "MockBaseChatClient") -> None: + """Test that chat middleware can modify messages before sending to model.""" + + @chat_middleware + async def message_modifier_middleware( + context: ChatContext, next: Callable[[ChatContext], Awaitable[None]] + ) -> None: + # Modify the first message by adding a prefix + if context.messages and len(context.messages) > 0: + original_text = context.messages[0].text or "" + context.messages[0] = ChatMessage(role=context.messages[0].role, text=f"MODIFIED: {original_text}") + await next(context) + + # Add middleware to chat client + chat_client_base.middleware = [message_modifier_middleware] + + # Execute chat client + messages = [ChatMessage(role=Role.USER, text="test message")] + response = await chat_client_base.get_response(messages) + + # Verify that the message was modified (MockChatClient echoes back the input) + assert response is not None + assert len(response.messages) > 0 + # The mock client should receive the modified message + assert "MODIFIED: test message" in response.messages[0].text + + async def test_chat_middleware_can_override_response(self, chat_client_base: "MockBaseChatClient") -> None: + """Test that chat middleware can override the response.""" + + @chat_middleware + async def response_override_middleware( + context: ChatContext, next: Callable[[ChatContext], Awaitable[None]] + ) -> None: + # Override the response without calling next() + context.result = ChatResponse( + messages=[ChatMessage(role=Role.ASSISTANT, text="Middleware overridden response")], + response_id="middleware-response-123", + ) + context.terminate = True + + # Add middleware to chat client + chat_client_base.middleware = [response_override_middleware] + + # Execute chat client + messages = [ChatMessage(role=Role.USER, text="test message")] + response = await chat_client_base.get_response(messages) + + # Verify that the response was overridden + assert response is not None + assert len(response.messages) > 0 + assert response.messages[0].text == "Middleware overridden response" + assert response.response_id == "middleware-response-123" + + async def test_multiple_chat_middleware_execution_order(self, chat_client_base: "MockBaseChatClient") -> None: + """Test that multiple chat middleware execute in the correct order.""" + execution_order: list[str] = [] + + @chat_middleware + async def first_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + execution_order.append("first_before") + await next(context) + execution_order.append("first_after") + + @chat_middleware + async def second_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + execution_order.append("second_before") + await next(context) + execution_order.append("second_after") + + # Add middleware to chat client (order should be preserved) + chat_client_base.middleware = [first_middleware, second_middleware] + + # Execute chat client + messages = [ChatMessage(role=Role.USER, text="test message")] + response = await chat_client_base.get_response(messages) + + # Verify response + assert response is not None + + # Verify middleware execution order (nested execution) + expected_order = ["first_before", "second_before", "second_after", "first_after"] + assert execution_order == expected_order + + async def test_chat_agent_with_chat_middleware(self) -> None: + """Test ChatAgent with chat middleware specified at agent level.""" + execution_order: list[str] = [] + + @chat_middleware + async def agent_level_chat_middleware( + context: ChatContext, next: Callable[[ChatContext], Awaitable[None]] + ) -> None: + execution_order.append("agent_chat_middleware_before") + await next(context) + execution_order.append("agent_chat_middleware_after") + + chat_client = MockBaseChatClient() + + # Create ChatAgent with chat middleware + agent = ChatAgent(chat_client=chat_client, middleware=[agent_level_chat_middleware]) + + # Execute the agent + messages = [ChatMessage(role=Role.USER, text="test message")] + response = await agent.run(messages) + + # Verify response + assert response is not None + assert len(response.messages) > 0 + assert response.messages[0].role == Role.ASSISTANT + + # Verify middleware execution order + assert execution_order == ["agent_chat_middleware_before", "agent_chat_middleware_after"] + + async def test_chat_agent_with_multiple_chat_middleware(self, chat_client_base: "MockBaseChatClient") -> None: + """Test that ChatAgent can have multiple chat middleware.""" + execution_order: list[str] = [] + + @chat_middleware + async def first_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + execution_order.append("first_before") + await next(context) + execution_order.append("first_after") + + @chat_middleware + async def second_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + execution_order.append("second_before") + await next(context) + execution_order.append("second_after") + + # Create ChatAgent with multiple chat middleware + agent = ChatAgent(chat_client=chat_client_base, middleware=[first_middleware, second_middleware]) + + # Execute the agent + messages = [ChatMessage(role=Role.USER, text="test message")] + response = await agent.run(messages) + + # Verify response + assert response is not None + + # Verify both middleware executed (nested execution order) + expected_order = ["first_before", "second_before", "second_after", "first_after"] + assert execution_order == expected_order + + async def test_chat_middleware_with_streaming(self, chat_client_base: "MockBaseChatClient") -> None: + """Test chat middleware with streaming responses.""" + execution_order: list[str] = [] + + @chat_middleware + async def streaming_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + execution_order.append("streaming_before") + # Verify it's a streaming context + assert context.is_streaming is True + await next(context) + execution_order.append("streaming_after") + + # Add middleware to chat client + chat_client_base.middleware = [streaming_middleware] + + # Execute streaming response + messages = [ChatMessage(role=Role.USER, text="test message")] + updates: list[object] = [] + async for update in chat_client_base.get_streaming_response(messages): + updates.append(update) + + # Verify we got updates + assert len(updates) > 0 + + # Verify middleware executed + assert execution_order == ["streaming_before", "streaming_after"] + + async def test_run_level_middleware_isolation(self, chat_client_base: "MockBaseChatClient") -> None: + """Test that run-level middleware is isolated and doesn't persist across calls.""" + execution_count = {"count": 0} + + @chat_middleware + async def counting_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + execution_count["count"] += 1 + await next(context) + + # First call with run-level middleware + messages = [ChatMessage(role=Role.USER, text="first message")] + response1 = await chat_client_base.get_response(messages, middleware=[counting_middleware]) + assert response1 is not None + assert execution_count["count"] == 1 + + # Second call WITHOUT run-level middleware - should not execute the middleware + messages = [ChatMessage(role=Role.USER, text="second message")] + response2 = await chat_client_base.get_response(messages) + assert response2 is not None + assert execution_count["count"] == 1 # Should still be 1, not 2 + + # Third call with run-level middleware again - should execute + messages = [ChatMessage(role=Role.USER, text="third message")] + response3 = await chat_client_base.get_response(messages, middleware=[counting_middleware]) + assert response3 is not None + assert execution_count["count"] == 2 # Should be 2 now + + async def test_chat_client_middleware_can_access_and_override_custom_kwargs( + self, chat_client_base: "MockBaseChatClient" + ) -> None: + """Test that chat client middleware can access and override custom parameters like temperature.""" + captured_kwargs: dict[str, Any] = {} + modified_kwargs: dict[str, Any] = {} + + @chat_middleware + async def kwargs_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None: + # Capture the original kwargs + captured_kwargs.update(context.kwargs) + + # Modify some kwargs + context.kwargs["temperature"] = 0.9 + context.kwargs["max_tokens"] = 500 + context.kwargs["new_param"] = "added_by_middleware" + + # Store modified kwargs for verification + modified_kwargs.update(context.kwargs) + + await next(context) + + # Add middleware to chat client + chat_client_base.middleware = [kwargs_middleware] + + # Execute chat client with custom parameters + messages = [ChatMessage(role=Role.USER, text="test message")] + response = await chat_client_base.get_response( + messages, temperature=0.7, max_tokens=100, custom_param="test_value" + ) + + # Verify response + assert response is not None + assert len(response.messages) > 0 + + assert captured_kwargs["temperature"] == 0.7 + assert captured_kwargs["max_tokens"] == 100 + assert captured_kwargs["custom_param"] == "test_value" + + # Verify middleware could modify the kwargs + assert modified_kwargs["temperature"] == 0.9 + assert modified_kwargs["max_tokens"] == 500 + assert modified_kwargs["new_param"] == "added_by_middleware" + assert modified_kwargs["custom_param"] == "test_value" # Should still be there + + async def test_function_middleware_registration_on_chat_client(self) -> None: + """Test function middleware registered on ChatClient is executed during function calls.""" + execution_order: list[str] = [] + + @function_middleware + async def test_function_middleware( + context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]] + ) -> None: + execution_order.append(f"function_middleware_before_{context.function.name}") + await next(context) + execution_order.append(f"function_middleware_after_{context.function.name}") + + # Define a simple tool function + def sample_tool(location: str) -> str: + """Get weather for a location.""" + return f"Weather in {location}: sunny" + + # Create function-invocation enabled chat client + chat_client = use_function_invocation(MockBaseChatClient)() + + # Set function middleware directly on the chat client + chat_client.middleware = [test_function_middleware] + + # Prepare responses that will trigger function invocation + function_call_response = ChatResponse( + messages=[ + ChatMessage( + role=Role.ASSISTANT, + contents=[ + FunctionCallContent( + call_id="call_1", + name="sample_tool", + arguments={"location": "San Francisco"}, + ) + ], + ) + ] + ) + final_response = ChatResponse( + messages=[ChatMessage(role=Role.ASSISTANT, text="Based on the weather data, it's sunny!")] + ) + + chat_client.run_responses = [function_call_response, final_response] + + # Execute the chat client directly with tools - this should trigger function invocation and middleware + messages = [ChatMessage(role=Role.USER, text="What's the weather in San Francisco?")] + response = await chat_client.get_response(messages, tools=[sample_tool]) + + # Verify response + assert response is not None + assert len(response.messages) > 0 + assert chat_client.call_count == 2 # Two calls: function call + final response + + # Verify function middleware was executed + assert execution_order == [ + "function_middleware_before_sample_tool", + "function_middleware_after_sample_tool", + ] + + async def test_run_level_function_middleware(self) -> None: + """Test that function middleware passed to get_response method is also invoked.""" + execution_order: list[str] = [] + + @function_middleware + async def run_level_function_middleware( + context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]] + ) -> None: + execution_order.append("run_level_function_middleware_before") + await next(context) + execution_order.append("run_level_function_middleware_after") + + # Define a simple tool function + def sample_tool(location: str) -> str: + """Get weather for a location.""" + return f"Weather in {location}: sunny" + + # Create function-invocation enabled chat client + chat_client = use_function_invocation(MockBaseChatClient)() + + # Prepare responses that will trigger function invocation + function_call_response = ChatResponse( + messages=[ + ChatMessage( + role=Role.ASSISTANT, + contents=[ + FunctionCallContent( + call_id="call_2", + name="sample_tool", + arguments={"location": "New York"}, + ) + ], + ) + ] + ) + final_response = ChatResponse( + messages=[ChatMessage(role=Role.ASSISTANT, text="The weather information has been retrieved!")] + ) + + chat_client.run_responses = [function_call_response, final_response] + + # Execute the chat client directly with run-level middleware and tools + messages = [ChatMessage(role=Role.USER, text="What's the weather in New York?")] + response = await chat_client.get_response( + messages, tools=[sample_tool], middleware=[run_level_function_middleware] + ) + + # Verify response + assert response is not None + assert len(response.messages) > 0 + assert chat_client.call_count == 2 # Two calls: function call + final response + + # Verify run-level function middleware was executed once (during function invocation) + assert execution_order == [ + "run_level_function_middleware_before", + "run_level_function_middleware_after", + ] diff --git a/python/samples/getting_started/agents/custom/custom_chat_client.py b/python/samples/getting_started/agents/custom/custom_chat_client.py index b138413194..f7a2ff9eae 100644 --- a/python/samples/getting_started/agents/custom/custom_chat_client.py +++ b/python/samples/getting_started/agents/custom/custom_chat_client.py @@ -13,6 +13,7 @@ from agent_framework import ( ChatResponseUpdate, Role, TextContent, + use_chat_middleware, use_function_invocation, ) @@ -37,6 +38,7 @@ custom client with ChatAgent through the create_agent() method. @use_function_invocation +@use_chat_middleware class EchoingChatClient(BaseChatClient): """A custom chat client that echoes messages back with modifications. diff --git a/python/samples/getting_started/middleware/README.md b/python/samples/getting_started/middleware/README.md new file mode 100644 index 0000000000..c833855c86 --- /dev/null +++ b/python/samples/getting_started/middleware/README.md @@ -0,0 +1,45 @@ +# Middleware Examples + +This folder contains examples demonstrating various middleware patterns with the Agent Framework. Middleware allows you to intercept and modify behavior at different execution stages, including agent runs, function calls, and chat interactions. + +## Examples + +| File | Description | +|------|-------------| +| [`function_based_middleware.py`](function_based_middleware.py) | Demonstrates how to implement middleware using simple async functions instead of classes. Shows security validation, logging, and performance monitoring middleware. Function-based middleware is ideal for simple, stateless operations and provides a lightweight approach. | +| [`class_based_middleware.py`](class_based_middleware.py) | Shows how to implement middleware using class-based approach by inheriting from `AgentMiddleware` and `FunctionMiddleware` base classes. Includes security checks for sensitive information and detailed function execution logging with timing. | +| [`decorator_middleware.py`](decorator_middleware.py) | Demonstrates how to use `@agent_middleware` and `@function_middleware` decorators to explicitly mark middleware functions without requiring type annotations. Shows different middleware detection scenarios and explicit decorator usage. | +| [`middleware_termination.py`](middleware_termination.py) | Shows how middleware can terminate execution using the `context.terminate` flag. Includes examples of pre-termination (prevents agent processing) and post-termination (allows processing but stops further execution). Useful for security checks, rate limiting, or early exit conditions. | +| [`exception_handling_with_middleware.py`](exception_handling_with_middleware.py) | Demonstrates how to use middleware for centralized exception handling in function calls. Shows how to catch exceptions from functions, provide graceful error responses, and override function results when errors occur to provide user-friendly messages. | +| [`override_result_with_middleware.py`](override_result_with_middleware.py) | Shows how to use middleware to intercept and modify function results after execution, supporting both regular and streaming agent responses. Demonstrates result filtering, formatting, enhancement, and custom streaming response generation. | +| [`shared_state_middleware.py`](shared_state_middleware.py) | Demonstrates how to implement function-based middleware within a class to share state between multiple middleware functions. Shows how middleware can work together by sharing state, including call counting and result enhancement. | +| [`agent_and_run_level_middleware.py`](agent_and_run_level_middleware.py) | Explains the difference between agent-level middleware (applied to ALL runs of the agent) and run-level middleware (applied to specific runs only). Shows security validation, performance monitoring, and context-specific middleware patterns. | +| [`chat_middleware.py`](chat_middleware.py) | Demonstrates how to use chat middleware to observe and override inputs sent to AI models. Shows how to intercept chat requests, log and modify input messages, and override entire responses before they reach the underlying AI service. | + +## Key Concepts + +### Middleware Types + +- **Agent Middleware**: Intercepts agent run execution, allowing you to modify requests and responses +- **Function Middleware**: Intercepts function calls within agents, enabling logging, validation, and result modification +- **Chat Middleware**: Intercepts chat requests sent to AI models, allowing input/output transformation + +### Implementation Approaches + +- **Function-based**: Simple async functions for lightweight, stateless operations +- **Class-based**: Inherit from base middleware classes for complex, stateful operations +- **Decorator-based**: Use decorators for explicit middleware marking + +### Common Use Cases + +- **Security**: Validate requests, block sensitive information, implement access controls +- **Logging**: Track execution timing, log parameters and results, monitor performance +- **Error Handling**: Catch exceptions, provide graceful fallbacks, implement retry logic +- **Result Transformation**: Filter, format, or enhance function outputs +- **State Management**: Share data between middleware functions, maintain execution context + +### Execution Control + +- **Termination**: Use `context.terminate` to stop execution early +- **Result Override**: Modify or replace function/agent results +- **Streaming Support**: Handle both regular and streaming responses \ No newline at end of file diff --git a/python/samples/getting_started/middleware/chat_middleware.py b/python/samples/getting_started/middleware/chat_middleware.py new file mode 100644 index 0000000000..37674a86d9 --- /dev/null +++ b/python/samples/getting_started/middleware/chat_middleware.py @@ -0,0 +1,245 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +from collections.abc import Awaitable, Callable +from random import randint +from typing import Annotated + +from agent_framework import ( + ChatContext, + ChatMessage, + ChatMiddleware, + ChatResponse, + Role, + chat_middleware, +) +from agent_framework.azure import AzureAIAgentClient +from azure.identity.aio import AzureCliCredential +from pydantic import Field + +""" +Chat Middleware Example + +This sample demonstrates how to use chat middleware to observe and override +inputs sent to AI models. Chat middleware intercepts chat requests before they reach +the underlying AI service, allowing you to: + +1. Observe and log input messages +2. Modify input messages before sending to AI +3. Override the entire response + +The example covers: +- Class-based chat middleware inheriting from ChatMiddleware +- Function-based chat middleware with @chat_middleware decorator +- Middleware registration at agent level (applies to all runs) +- Middleware registration at run level (applies to specific run only) +""" + + +def get_weather( + location: Annotated[str, Field(description="The location to get the weather for.")], +) -> str: + """Get the weather for a given location.""" + conditions = ["sunny", "cloudy", "rainy", "stormy"] + return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." + + +class InputObserverMiddleware(ChatMiddleware): + """Class-based middleware that observes and modifies input messages.""" + + def __init__(self, replacement: str | None = None): + """Initialize with a replacement for user messages.""" + self.replacement = replacement + + async def process( + self, + context: ChatContext, + next: Callable[[ChatContext], Awaitable[None]], + ) -> None: + """Observe and modify input messages before they are sent to AI.""" + print("[InputObserverMiddleware] Observing input messages:") + + for i, message in enumerate(context.messages): + content = message.text if message.text else str(message.contents) + print(f" Message {i + 1} ({message.role.value}): {content}") + + print(f"[InputObserverMiddleware] Total messages: {len(context.messages)}") + + # Modify user messages by creating new messages with enhanced text + modified_messages: list[ChatMessage] = [] + modified_count = 0 + + for message in context.messages: + if message.role == Role.USER and message.text: + original_text = message.text + updated_text = original_text + + if self.replacement: + updated_text = self.replacement + print(f"[InputObserverMiddleware] Updated: '{original_text}' -> '{updated_text}'") + + modified_message = ChatMessage(role=message.role, text=updated_text) + modified_messages.append(modified_message) + modified_count += 1 + else: + modified_messages.append(message) + + # Replace messages in context + context.messages[:] = modified_messages + + # Continue to next middleware or AI execution + await next(context) + + # Observe that processing is complete + print("[InputObserverMiddleware] Processing completed") + + +@chat_middleware +async def security_and_override_middleware( + context: ChatContext, + next: Callable[[ChatContext], Awaitable[None]], +) -> None: + """Function-based middleware that implements security filtering and response override.""" + print("[SecurityMiddleware] Processing input...") + + # Security check - block sensitive information + blocked_terms = ["password", "secret", "api_key", "token"] + + for message in context.messages: + if message.text: + message_lower = message.text.lower() + for term in blocked_terms: + if term in message_lower: + print(f"[SecurityMiddleware] BLOCKED: Found '{term}' in message") + + # Override the response instead of calling AI + context.result = ChatResponse( + messages=[ + ChatMessage( + role=Role.ASSISTANT, + text="I cannot process requests containing sensitive information. " + "Please rephrase your question without including passwords, secrets, or other " + "sensitive data.", + ) + ] + ) + + # Set terminate flag to stop execution + context.terminate = True + return + + # Continue to next middleware or AI execution + await next(context) + + +async def class_based_chat_middleware() -> None: + """Demonstrate class-based middleware at agent level.""" + print("\n" + "=" * 60) + print("Class-based Chat Middleware (Agent Level)") + print("=" * 60) + + # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred + # authentication option. + async with ( + AzureCliCredential() as credential, + AzureAIAgentClient(async_credential=credential).create_agent( + name="EnhancedChatAgent", + instructions="You are a helpful AI assistant.", + # Register class-based middleware at agent level (applies to all runs) + middleware=InputObserverMiddleware(), + tools=get_weather, + ) as agent, + ): + query = "What's the weather in Seattle?" + print(f"User: {query}") + result = await agent.run(query) + print(f"Final Response: {result.text if result.text else 'No response'}") + + +async def function_based_chat_middleware() -> None: + """Demonstrate function-based middleware at agent level.""" + print("\n" + "=" * 60) + print("Function-based Chat Middleware (Agent Level)") + print("=" * 60) + + async with ( + AzureCliCredential() as credential, + AzureAIAgentClient(async_credential=credential).create_agent( + name="FunctionMiddlewareAgent", + instructions="You are a helpful AI assistant.", + # Register function-based middleware at agent level + middleware=security_and_override_middleware, + ) as agent, + ): + # Scenario with normal query + print("\n--- Scenario 1: Normal Query ---") + query = "Hello, how are you?" + print(f"User: {query}") + result = await agent.run(query) + print(f"Final Response: {result.text if result.text else 'No response'}") + + # Scenario with security violation + print("\n--- Scenario 2: Security Violation ---") + query = "What is my password for this account?" + print(f"User: {query}") + result = await agent.run(query) + print(f"Final Response: {result.text if result.text else 'No response'}") + + +async def run_level_middleware() -> None: + """Demonstrate middleware registration at run level.""" + print("\n" + "=" * 60) + print("Run-level Chat Middleware") + print("=" * 60) + + async with ( + AzureCliCredential() as credential, + AzureAIAgentClient(async_credential=credential).create_agent( + name="RunLevelAgent", + instructions="You are a helpful AI assistant.", + tools=get_weather, + # No middleware at agent level + ) as agent, + ): + # Scenario 1: Run without any middleware + print("\n--- Scenario 1: No Middleware ---") + query = "What's the weather in Tokyo?" + print(f"User: {query}") + result = await agent.run(query) + print(f"Response: {result.text if result.text else 'No response'}") + + # Scenario 2: Run with specific middleware for this call only (both enhancement and security) + print("\n--- Scenario 2: With Run-level Middleware ---") + print(f"User: {query}") + result = await agent.run( + query, + middleware=[ + InputObserverMiddleware(replacement="What's the weather in Madrid?"), + security_and_override_middleware, + ], + ) + print(f"Response: {result.text if result.text else 'No response'}") + + # Scenario 3: Security test with run-level middleware + print("\n--- Scenario 3: Security Test with Run-level Middleware ---") + query = "Can you help me with my secret API key?" + print(f"User: {query}") + result = await agent.run( + query, + middleware=security_and_override_middleware, + ) + print(f"Response: {result.text if result.text else 'No response'}") + + +async def main() -> None: + """Run all chat middleware examples.""" + print("Chat Middleware Examples") + print("========================") + + await class_based_chat_middleware() + await function_based_chat_middleware() + await run_level_middleware() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/middleware/override_result_with_middleware.py b/python/samples/getting_started/middleware/override_result_with_middleware.py index 1b030193fa..cc7a349e0b 100644 --- a/python/samples/getting_started/middleware/override_result_with_middleware.py +++ b/python/samples/getting_started/middleware/override_result_with_middleware.py @@ -1,28 +1,37 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio -from collections.abc import Awaitable, Callable +from collections.abc import AsyncIterable, Awaitable, Callable from random import randint from typing import Annotated -from agent_framework import FunctionInvocationContext +from agent_framework import ( + AgentRunContext, + AgentRunResponse, + AgentRunResponseUpdate, + ChatMessage, + Role, + TextContent, +) from agent_framework.azure import AzureAIAgentClient from azure.identity.aio import AzureCliCredential from pydantic import Field """ -Result Override with Middleware +Result Override with Middleware (Regular and Streaming) This sample demonstrates how to use middleware to intercept and modify function results -after execution. The example shows: +after execution, supporting both regular and streaming agent responses. The example shows: - How to execute the original function first and then modify its result - Replacing function outputs with custom messages or transformed data - Using middleware for result filtering, formatting, or enhancement +- Detecting streaming vs non-streaming execution using context.is_streaming +- Overriding streaming results with custom async generators The weather override middleware lets the original weather function execute normally, -then replaces its result with a custom "perfect weather" message, demonstrating -how middleware can be used for content filtering, A/B testing, or result enhancement. +then replaces its result with a custom "perfect weather" message. For streaming responses, +it creates a custom async generator that yields the override message in chunks. """ @@ -35,32 +44,39 @@ def get_weather( async def weather_override_middleware( - context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]] + context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] ) -> None: - function_name = context.function.name + """Middleware that overrides weather results for both streaming and non-streaming cases.""" - # Let the original function execute first + # Let the original agent execution complete first await next(context) - # Override the result if it's a weather function - if function_name == "get_weather" and context.result is not None: - original_result = str(context.result) - print(f"[WeatherOverrideMiddleware] Original result: {original_result}") + # Check if there's a result to override (agent called weather function) + if context.result is not None: + # Create custom weather message + chunks = [ + "Weather Advisory - ", + "due to special atmospheric conditions, ", + "all locations are experiencing perfect weather today! ", + "Temperature is a comfortable 22°C with gentle breezes. ", + "Perfect day for outdoor activities!", + ] - # Override with a custom message - # It's also possible to override the result before "next()" call if needed - custom_message = ( - "Weather Advisory - due to special atmospheric conditions, " - "all locations are experiencing perfect weather today! " - "Temperature is a comfortable 22°C with gentle breezes. " - "Perfect day for outdoor activities!" - ) - context.result = custom_message - print(f"[WeatherOverrideMiddleware] Overriding with custom message: {custom_message}") + if context.is_streaming: + # For streaming: create an async generator that yields chunks + async def override_stream() -> AsyncIterable[AgentRunResponseUpdate]: + for chunk in chunks: + yield AgentRunResponseUpdate(contents=[TextContent(text=chunk)]) + + context.result = override_stream() + else: + # For non-streaming: just replace with the string message + custom_message = "".join(chunks) + context.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=custom_message)]) async def main() -> None: - """Example demonstrating result override with middleware.""" + """Example demonstrating result override with middleware for both streaming and non-streaming.""" print("=== Result Override Middleware Example ===") # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred @@ -74,11 +90,22 @@ async def main() -> None: middleware=weather_override_middleware, ) as agent, ): + # Non-streaming example + print("\n--- Non-streaming Example ---") query = "What's the weather like in Seattle?" print(f"User: {query}") result = await agent.run(query) print(f"Agent: {result}") + # Streaming example + print("\n--- Streaming Example ---") + query = "What's the weather like in Portland?" + print(f"User: {query}") + print("Agent: ", end="", flush=True) + async for chunk in agent.run_stream(query): + if chunk.text: + print(chunk.text, end="", flush=True) + if __name__ == "__main__": asyncio.run(main()) diff --git a/python/samples/getting_started/middleware/shared_state_middleware.py b/python/samples/getting_started/middleware/shared_state_middleware.py new file mode 100644 index 0000000000..e4ce077038 --- /dev/null +++ b/python/samples/getting_started/middleware/shared_state_middleware.py @@ -0,0 +1,128 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +from collections.abc import Awaitable, Callable +from random import randint +from typing import Annotated + +from agent_framework import ( + FunctionInvocationContext, +) +from agent_framework.azure import AzureAIAgentClient +from azure.identity.aio import AzureCliCredential +from pydantic import Field + +""" +Shared State Function-based Middleware Example + +This sample demonstrates how to implement function-based middleware within a class to share state. +The example includes: + +- A MiddlewareContainer class with two simple function middleware methods +- First middleware: Counts function calls and stores the count in shared state +- Second middleware: Uses the shared count to add call numbers to function results + +This approach shows how middleware can work together by sharing state within the same class instance. +""" + + +def get_weather( + location: Annotated[str, Field(description="The location to get the weather for.")], +) -> str: + """Get the weather for a given location.""" + conditions = ["sunny", "cloudy", "rainy", "stormy"] + return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." + + +def get_time( + timezone: Annotated[str, Field(description="The timezone to get the time for.")] = "UTC", +) -> str: + """Get the current time for a given timezone.""" + import datetime + + return f"The current time in {timezone} is {datetime.datetime.now().strftime('%H:%M:%S')}" + + +class MiddlewareContainer: + """Container class that holds middleware functions with shared state.""" + + def __init__(self) -> None: + # Simple shared state: count function calls + self.call_count: int = 0 + + async def call_counter_middleware( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + """First middleware: increments call count in shared state.""" + # Increment the shared call count + self.call_count += 1 + + print(f"[CallCounter] This is function call #{self.call_count}") + + # Call the next middleware/function + await next(context) + + async def result_enhancer_middleware( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + """Second middleware: uses shared call count to enhance function results.""" + print(f"[ResultEnhancer] Current total calls so far: {self.call_count}") + + # Call the next middleware/function + await next(context) + + # After function execution, enhance the result using shared state + if context.result: + enhanced_result = f"[Call #{self.call_count}] {context.result}" + context.result = enhanced_result + print("[ResultEnhancer] Enhanced result with call number") + + +async def main() -> None: + """Example demonstrating shared state function-based middleware.""" + print("=== Shared State Function-based Middleware Example ===") + + # Create middleware container with shared state + middleware_container = MiddlewareContainer() + + # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred + # authentication option. + async with ( + AzureCliCredential() as credential, + AzureAIAgentClient(async_credential=credential).create_agent( + name="UtilityAgent", + instructions="You are a helpful assistant that can provide weather information and current time.", + tools=[get_weather, get_time], + # Pass both middleware functions from the same container instance + # Order matters: counter runs first to increment count, + # then result enhancer uses the updated count + middleware=[ + middleware_container.call_counter_middleware, + middleware_container.result_enhancer_middleware, + ], + ) as agent, + ): + # Test multiple requests to see shared state in action + queries = [ + "What's the weather like in New York?", + "What time is it in London?", + "What's the weather in Tokyo?", + ] + + for i, query in enumerate(queries, 1): + print(f"\n--- Query {i} ---") + print(f"User: {query}") + result = await agent.run(query) + print(f"Agent: {result.text if result.text else 'No response'}") + + # Display final statistics + print("\n=== Final Statistics ===") + print(f"Total function calls made: {middleware_container.call_count}") + + +if __name__ == "__main__": + asyncio.run(main())