From 99860a5d0772e50f4e62fe8885a9fbfa1d17ee94 Mon Sep 17 00:00:00 2001 From: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Date: Thu, 18 Sep 2025 16:30:05 -0700 Subject: [PATCH] Python: Agent and Function middleware (#770) * Initial middleware implementation * Small fixes * Small updates * Small updates in samples * Moved middleware functionality to decorator * Removed obsolete file * Renamed AgentInvocationContext to AzureRunContext * Added unit tests * Small settings update for test discovery in VS Code * Added unit tests * Reverted changes in environment settings * Added context result override * Renaming and updates to logic * Added more samples * Updated DEV_SETUP.md * Addressed PR feedback * Addressed PR feedback * Removed unused parameter * Small fix * Small fix in telemetry logic * Revert "Small fix in telemetry logic" This reverts commit 6f82660d2d096dd110a92b2d51e0e98706e549b0. * Small fix --------- Co-authored-by: Chris <66376200+crickman@users.noreply.github.com> --- .../packages/main/agent_framework/__init__.py | 1 + .../packages/main/agent_framework/_agents.py | 7 + .../packages/main/agent_framework/_clients.py | 22 +- .../main/agent_framework/_middleware.py | 632 ++++++++++++ .../packages/main/agent_framework/_tools.py | 52 +- .../main/agent_framework/guard_rails.py | 27 - .../main/agent_framework/telemetry.py | 6 +- python/packages/main/pyproject.toml | 10 +- python/packages/main/tests/main/conftest.py | 15 +- .../main/tests/main/test_middleware.py | 939 ++++++++++++++++++ .../main/test_middleware_context_result.py | 463 +++++++++ .../tests/main/test_middleware_with_agent.py | 544 ++++++++++ .../middleware/class_based_middleware.py | 125 +++ .../exception_handling_with_middleware.py | 75 ++ .../middleware/function_based_middleware.py | 109 ++ .../override_result_with_middleware.py | 84 ++ 16 files changed, 3071 insertions(+), 40 deletions(-) create mode 100644 python/packages/main/agent_framework/_middleware.py delete mode 100644 python/packages/main/agent_framework/guard_rails.py create mode 100644 python/packages/main/tests/main/test_middleware.py create mode 100644 python/packages/main/tests/main/test_middleware_context_result.py create mode 100644 python/packages/main/tests/main/test_middleware_with_agent.py create mode 100644 python/samples/getting_started/middleware/class_based_middleware.py create mode 100644 python/samples/getting_started/middleware/exception_handling_with_middleware.py create mode 100644 python/samples/getting_started/middleware/function_based_middleware.py create mode 100644 python/samples/getting_started/middleware/override_result_with_middleware.py diff --git a/python/packages/main/agent_framework/__init__.py b/python/packages/main/agent_framework/__init__.py index dc6c9511e1..a0b1088ffb 100644 --- a/python/packages/main/agent_framework/__init__.py +++ b/python/packages/main/agent_framework/__init__.py @@ -13,6 +13,7 @@ from ._clients import * # noqa: F403 from ._logging import * # noqa: F403 from ._mcp import * # noqa: F403 from ._memory import * # noqa: F403 +from ._middleware import * # noqa: F403 from ._threads import * # noqa: F403 from ._tools import * # noqa: F403 from ._types import * # noqa: F403 diff --git a/python/packages/main/agent_framework/_agents.py b/python/packages/main/agent_framework/_agents.py index 9c15b5d3a2..6977388588 100644 --- a/python/packages/main/agent_framework/_agents.py +++ b/python/packages/main/agent_framework/_agents.py @@ -12,6 +12,7 @@ from ._clients import BaseChatClient, ChatClientProtocol from ._logging import get_logger from ._mcp import MCPTool from ._memory import AggregateContextProvider, Context, ContextProvider +from ._middleware import Middleware, use_agent_middleware from ._pydantic import AFBaseModel from ._threads import AgentThread, ChatMessageStore, deserialize_thread_state, thread_on_new_messages from ._tools import FUNCTION_INVOKING_CHAT_CLIENT_MARKER, ToolProtocol @@ -138,12 +139,14 @@ class BaseAgent(AFBaseModel): description: The description of the agent. display_name: The display name of the agent, which is either the name or id. context_providers: The collection of multiple context providers to include during agent invocation. + middleware: List of middleware to intercept agent and function invocations. """ id: str = Field(default_factory=lambda: str(uuid4())) name: str | None = None description: str | None = None context_providers: AggregateContextProvider | None = None + middleware: Middleware | list[Middleware] | None = None async def _notify_thread_of_new_messages( self, thread: AgentThread, new_messages: ChatMessage | Sequence[ChatMessage] @@ -189,6 +192,7 @@ class BaseAgent(AFBaseModel): # region ChatAgent +@use_agent_middleware @use_agent_telemetry class ChatAgent(BaseAgent): """A Chat Client Agent.""" @@ -231,6 +235,7 @@ class ChatAgent(BaseAgent): additional_properties: dict[str, Any] | None = None, chat_message_store_factory: Callable[[], ChatMessageStore] | None = None, context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None, + middleware: Middleware | list[Middleware] | None = None, **kwargs: Any, ) -> None: """Create a ChatAgent. @@ -266,6 +271,7 @@ class ChatAgent(BaseAgent): chat_message_store_factory: factory function to create an instance of ChatMessageStore. If not provided, the default in-memory store will be used. context_providers: The collection of multiple context providers to include during agent invocation. + middleware: List of middleware to intercept agent and function invocations. kwargs: any additional keyword arguments. Unused, can be used by subclasses of this Agent. """ @@ -287,6 +293,7 @@ class ChatAgent(BaseAgent): "chat_client": chat_client, "chat_message_store_factory": chat_message_store_factory, "context_providers": aggregate_context_providers, + "middleware": middleware, "chat_options": ChatOptions( ai_model_id=model, frequency_penalty=frequency_penalty, diff --git a/python/packages/main/agent_framework/_clients.py b/python/packages/main/agent_framework/_clients.py index 4d5775c96b..fede0de4d8 100644 --- a/python/packages/main/agent_framework/_clients.py +++ b/python/packages/main/agent_framework/_clients.py @@ -10,6 +10,7 @@ from pydantic import BaseModel, Field from ._logging import get_logger from ._mcp import MCPTool from ._memory import AggregateContextProvider, ContextProvider +from ._middleware import Middleware from ._pydantic import AFBaseModel from ._threads import ChatMessageStore from ._tools import ToolProtocol @@ -344,7 +345,14 @@ class BaseChatClient(AFBaseModel, ABC): ) prepped_messages = self.prepare_messages(messages) self._prepare_tool_choice(chat_options=chat_options) - return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **kwargs) + + # Remove middleware pipeline from kwargs as it's only used by function invocation wrappers + if "_function_middleware_pipeline" in kwargs: + filtered_kwargs = {k: v for k, v in kwargs.items() if k != "_function_middleware_pipeline"} + else: + filtered_kwargs = kwargs + + return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **filtered_kwargs) async def get_streaming_response( self, @@ -424,8 +432,15 @@ class BaseChatClient(AFBaseModel, ABC): ) prepped_messages = self.prepare_messages(messages) self._prepare_tool_choice(chat_options=chat_options) + + # Remove middleware pipeline from kwargs as it's only used by function invocation wrappers + if "_function_middleware_pipeline" in kwargs: + filtered_kwargs = {k: v for k, v in kwargs.items() if k != "_function_middleware_pipeline"} + else: + filtered_kwargs = kwargs + async for update in self._inner_get_streaming_response( - messages=prepped_messages, chat_options=chat_options, **kwargs + messages=prepped_messages, chat_options=chat_options, **filtered_kwargs ): yield update @@ -465,6 +480,7 @@ class BaseChatClient(AFBaseModel, ABC): | None = None, chat_message_store_factory: Callable[[], ChatMessageStore] | None = None, context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None, + middleware: Middleware | list[Middleware] | None = None, **kwargs: Any, ) -> "ChatAgent": """Create an agent with the given name and instructions. @@ -476,6 +492,7 @@ class BaseChatClient(AFBaseModel, ABC): chat_message_store_factory: Factory function to create an instance of ChatMessageStore. If not provided, the default in-memory store will be used. context_providers: Context providers to include during agent invocation. + middleware: List of middleware to intercept agent and function invocations. **kwargs: Additional keyword arguments to pass to the agent. See ChatAgent for all the available options. @@ -491,6 +508,7 @@ class BaseChatClient(AFBaseModel, ABC): tools=tools, chat_message_store_factory=chat_message_store_factory, context_providers=context_providers, + middleware=middleware, **kwargs, ) diff --git a/python/packages/main/agent_framework/_middleware.py b/python/packages/main/agent_framework/_middleware.py new file mode 100644 index 0000000000..7150881f6b --- /dev/null +++ b/python/packages/main/agent_framework/_middleware.py @@ -0,0 +1,632 @@ +# Copyright (c) Microsoft. All rights reserved. + +from abc import ABC, abstractmethod +from collections.abc import AsyncIterable, Awaitable, Callable +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar + +from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage + +if TYPE_CHECKING: + from pydantic import BaseModel + + from ._agents import AgentProtocol + from ._tools import AIFunction + +TAgent = TypeVar("TAgent", bound="AgentProtocol") + +__all__ = [ + "AgentMiddleware", + "AgentRunContext", + "FunctionInvocationContext", + "FunctionMiddleware", + "Middleware", + "use_agent_middleware", +] + + +@dataclass +class AgentRunContext: + """Context object for agent middleware invocations. + + Attributes: + agent: The agent being invoked. + messages: The messages being sent to the agent. + is_streaming: Whether this is a streaming invocation. + metadata: Metadata dictionary for sharing data between agent middleware. + result: Agent execution result. Can be observed after calling next() + to see the actual execution result or can be set to override the execution result. + For non-streaming: should be AgentRunResponse + For streaming: should be AsyncIterable[AgentRunResponseUpdate] + """ + + agent: "AgentProtocol" + messages: list[ChatMessage] + is_streaming: bool = False + metadata: dict[str, Any] = field(default_factory=lambda: {}) + result: AgentRunResponse | AsyncIterable[AgentRunResponseUpdate] | None = None + + +@dataclass +class FunctionInvocationContext: + """Context object for function middleware invocations. + + Attributes: + function: The function being invoked. + arguments: The validated arguments for the function. + metadata: Metadata dictionary for sharing data between function middleware. + result: Function execution result. Can be observed after calling next() + to see the actual execution result or can be set to override the execution result. + """ + + function: "AIFunction[Any, Any]" + arguments: "BaseModel" + metadata: dict[str, Any] = field(default_factory=lambda: {}) + result: Any = None + + +class AgentMiddleware(ABC): + """Abstract base class for agent middleware that can intercept agent invocations.""" + + @abstractmethod + async def process( + self, + context: AgentRunContext, + next: Callable[[AgentRunContext], Awaitable[None]], + ) -> None: + """Process an agent invocation. + + Args: + context: Agent invocation context containing agent, messages, and metadata. + Use context.is_streaming to determine if this is a streaming call. + Middleware can set context.result to override execution, or observe + the actual execution result after calling next(). + For non-streaming: AgentRunResponse + For streaming: AsyncIterable[AgentRunResponseUpdate] + next: Function to call the next middleware or final agent execution. + Does not return anything - all data flows through the context. + + Note: + Middleware should not return anything. All data manipulation should happen + within the context object. Set context.result to override execution, + or observe context.result after calling next() for actual results. + """ + ... + + +class FunctionMiddleware(ABC): + """Abstract base class for function middleware that can intercept function invocations.""" + + @abstractmethod + async def process( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + """Process a function invocation. + + Args: + context: Function invocation context containing function, arguments, and metadata. + Middleware can set context.result to override execution, or observe + the actual execution result after calling next(). + next: Function to call the next middleware or final function execution. + Does not return anything - all data flows through the context. + + Note: + Middleware should not return anything. All data manipulation should happen + within the context object. Set context.result to override execution, + or observe context.result after calling next() for actual results. + """ + ... + + +# Pure function type definitions for convenience +AgentMiddlewareCallable = Callable[[AgentRunContext, Callable[[AgentRunContext], Awaitable[None]]], Awaitable[None]] + +FunctionMiddlewareCallable = Callable[ + [FunctionInvocationContext, Callable[[FunctionInvocationContext], Awaitable[None]]], Awaitable[None] +] + +# Type alias for all middleware types +Middleware: TypeAlias = AgentMiddleware | AgentMiddlewareCallable | FunctionMiddleware | FunctionMiddlewareCallable + + +class AgentMiddlewareWrapper(AgentMiddleware): + """Wrapper to convert pure functions into AgentMiddleware protocol objects.""" + + def __init__(self, func: AgentMiddlewareCallable): + self.func = func + + async def process( + self, + context: AgentRunContext, + next: Callable[[AgentRunContext], Awaitable[None]], + ) -> None: + await self.func(context, next) + + +class FunctionMiddlewareWrapper(FunctionMiddleware): + """Wrapper to convert pure functions into FunctionMiddleware protocol objects.""" + + def __init__(self, func: FunctionMiddlewareCallable): + self.func = func + + async def process( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + await self.func(context, next) + + +class BaseMiddlewarePipeline(ABC): + """Base class for middleware pipeline execution.""" + + def __init__(self) -> None: + """Initialize the base middleware pipeline.""" + self._middlewares: list[Any] = [] + + @abstractmethod + def _register_middleware(self, middleware: Any) -> None: + """Register a middleware item. Must be implemented by subclasses.""" + ... + + @property + def has_middlewares(self) -> bool: + """Check if there are any middlewares registered.""" + return bool(self._middlewares) + + def _create_handler_chain( + self, + final_handler: Callable[[Any], Awaitable[Any]], + result_container: dict[str, Any], + result_key: str = "result", + ) -> Callable[[Any], Awaitable[None]]: + """Create a chain of middleware handlers. + + Args: + final_handler: The final handler to execute + result_container: Container to store the result + result_key: Key to use in the result container + + Returns: + The first handler in the chain + """ + + def create_next_handler(index: int) -> Callable[[Any], Awaitable[None]]: + if index >= len(self._middlewares): + + async def final_wrapper(c: Any) -> None: + # Execute actual handler and populate context for observability + result = await final_handler(c) + result_container[result_key] = result + c.result = result + + return final_wrapper + + middleware = self._middlewares[index] + next_handler = create_next_handler(index + 1) + + async def current_handler(c: Any) -> None: + await middleware.process(c, next_handler) + + return current_handler + + return create_next_handler(0) + + +class AgentMiddlewarePipeline(BaseMiddlewarePipeline): + """Executes agent middleware in a chain.""" + + def __init__(self, middlewares: list[AgentMiddleware | AgentMiddlewareCallable] | None = None): + """Initialize the agent middleware pipeline. + + Args: + middlewares: List of agent middleware to include in the pipeline. + """ + super().__init__() + self._middlewares: list[AgentMiddleware] = [] + + if middlewares: + for middleware in middlewares: + self._register_middleware(middleware) + + def _register_middleware(self, middleware: AgentMiddleware | AgentMiddlewareCallable) -> None: + """Register an agent middleware item.""" + if isinstance(middleware, AgentMiddleware): + self._middlewares.append(middleware) + elif callable(middleware): + self._middlewares.append(AgentMiddlewareWrapper(middleware)) + + async def execute( + self, + agent: "AgentProtocol", + messages: list[ChatMessage], + context: AgentRunContext, + final_handler: Callable[[AgentRunContext], Awaitable[AgentRunResponse]], + ) -> AgentRunResponse | None: + """Execute the agent middleware pipeline for non-streaming. + + Args: + agent: The agent being invoked. + messages: The messages to send to the agent. + context: The agent invocation context. + final_handler: The final handler that performs the actual agent execution. + + Returns: + The agent response after processing through all middleware. + """ + # Update context with agent and messages + context.agent = agent + context.messages = messages + context.is_streaming = False + + if not self._middlewares: + return await final_handler(context) + + # Store the final result + result_container: dict[str, AgentRunResponse | None] = {"response": None} + + def create_next_handler(index: int) -> Callable[[AgentRunContext], Awaitable[None]]: + if index >= len(self._middlewares): + + async def final_wrapper(c: AgentRunContext) -> None: + # Execute actual handler and populate context for observability + result = await final_handler(c) + result_container["result"] = result + c.result = result + + return final_wrapper + + middleware = self._middlewares[index] + next_handler = create_next_handler(index + 1) + + async def current_handler(c: AgentRunContext) -> None: + await middleware.process(c, next_handler) + # After middleware execution, check if response was overridden + if c.result is not None and isinstance(c.result, AgentRunResponse): + result_container["result"] = c.result + + return current_handler + + first_handler = create_next_handler(0) + await first_handler(context) + + # Return the result from result container or overridden result + if context.result is not None and isinstance(context.result, AgentRunResponse): + return context.result + + # If no result was set (next() not called), return empty AgentRunResponse + response = result_container.get("result") + if response is None: + return AgentRunResponse() + return response + + async def execute_stream( + self, + agent: "AgentProtocol", + messages: list[ChatMessage], + context: AgentRunContext, + final_handler: Callable[[AgentRunContext], AsyncIterable[AgentRunResponseUpdate]], + ) -> AsyncIterable[AgentRunResponseUpdate]: + """Execute the agent middleware pipeline for streaming. + + Args: + agent: The agent being invoked. + messages: The messages to send to the agent. + context: The agent invocation context. + final_handler: The final handler that performs the actual agent streaming execution. + + Yields: + Agent response updates after processing through all middleware. + """ + # Update context with agent and messages + context.agent = agent + context.messages = messages + context.is_streaming = True + + if not self._middlewares: + async for update in final_handler(context): + yield update + return + + # Store the final result + result_container: dict[str, AsyncIterable[AgentRunResponseUpdate] | None] = {"result_stream": None} + + def create_next_handler(index: int) -> Callable[[AgentRunContext], Awaitable[None]]: + if index >= len(self._middlewares): + + async def final_wrapper(c: AgentRunContext) -> None: # noqa: RUF029 + # Execute actual handler and populate context for observability + result = final_handler(c) + result_container["result_stream"] = result + c.result = result + + return final_wrapper + + middleware = self._middlewares[index] + next_handler = create_next_handler(index + 1) + + async def current_handler(c: AgentRunContext) -> None: + await middleware.process(c, next_handler) + + return current_handler + + first_handler = create_next_handler(0) + await first_handler(context) + + # Yield from the result stream in result container or overridden result + if context.result is not None and hasattr(context.result, "__aiter__"): + async for update in context.result: # type: ignore + yield update + return + + result_stream = result_container["result_stream"] + if result_stream is None: + # If no result stream was set (next() not called), yield nothing + return + + async for update in result_stream: + yield update + + +class FunctionMiddlewarePipeline(BaseMiddlewarePipeline): + """Executes function middleware in a chain.""" + + def __init__(self, middlewares: list[FunctionMiddleware | FunctionMiddlewareCallable] | None = None): + """Initialize the function middleware pipeline. + + Args: + middlewares: List of function middleware to include in the pipeline. + """ + super().__init__() + self._middlewares: list[FunctionMiddleware] = [] + + if middlewares: + for middleware in middlewares: + self._register_middleware(middleware) + + def _register_middleware(self, middleware: FunctionMiddleware | FunctionMiddlewareCallable) -> None: + """Register a function middleware item.""" + # Check if it's a class instance inheriting from FunctionMiddleware + if isinstance(middleware, FunctionMiddleware): + self._middlewares.append(middleware) + elif callable(middleware): + self._middlewares.append(FunctionMiddlewareWrapper(middleware)) + + async def execute( + self, + function: Any, + arguments: "BaseModel", + context: FunctionInvocationContext, + final_handler: Callable[[FunctionInvocationContext], Awaitable[Any]], + ) -> Any: + """Execute the function middleware pipeline. + + Args: + function: The function being invoked. + arguments: The validated arguments for the function. + context: The function invocation context. + final_handler: The final handler that performs the actual function execution. + + Returns: + The function result after processing through all middleware. + """ + # Update context with function and arguments + context.function = function + context.arguments = arguments + + if not self._middlewares: + return await final_handler(context) + + # Store the final result + result_container: dict[str, Any] = {"result": None} + + # Custom final handler that handles pre-existing results + async def function_final_handler(c: FunctionInvocationContext) -> Any: + # If result was set before calling next(), skip execution + if c.result is not None: + return c.result + # Execute actual handler and populate context for observability + return await final_handler(c) + + first_handler = self._create_handler_chain(function_final_handler, result_container, "result") + await first_handler(context) + + # Return the result from result container or overridden result + if context.result is not None: + return context.result + return result_container["result"] + + +# Decorator for adding middleware support to agent classes +def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]: + """Class decorator that adds middleware support to an agent class. + + This decorator adds middleware functionality to any agent class. + It wraps the run() and run_stream() methods to provide middleware execution. + + Args: + agent_class: The agent class to add middleware support to. + + Returns: + The modified agent class with middleware support. + """ + import inspect + + # Store original methods + original_run = agent_class.run # type: ignore[attr-defined] + original_run_stream = agent_class.run_stream # type: ignore[attr-defined] + + def _initialize_middleware_pipelines(self: Any, middlewares: Middleware | list[Middleware] | None) -> None: + """Initialize agent and function middleware pipelines from the provided middleware list.""" + if not middlewares: + return + + middleware_list: list[Middleware] = middlewares if isinstance(middlewares, list) else [middlewares] # type: ignore + + # Separate agent and function middleware using isinstance checks + agent_middlewares: list[AgentMiddleware | AgentMiddlewareCallable] = [] + function_middlewares: list[FunctionMiddleware | FunctionMiddlewareCallable] = [] + + for middleware in middleware_list: + if isinstance(middleware, AgentMiddleware): + agent_middlewares.append(middleware) + elif isinstance(middleware, FunctionMiddleware): + function_middlewares.append(middleware) + elif callable(middleware): # type: ignore[arg-type] + # Check function signature to determine type + try: + sig = inspect.signature(middleware) + params = list(sig.parameters.values()) + if len(params) >= 1: + first_param = params[0] + # Check if first parameter is AgentRunContext or FunctionInvocationContext + if ( + hasattr(first_param.annotation, "__name__") + and first_param.annotation.__name__ == "AgentRunContext" + ): + agent_middlewares.append(middleware) # type: ignore + elif ( + hasattr(first_param.annotation, "__name__") + and first_param.annotation.__name__ == "FunctionInvocationContext" + ): + function_middlewares.append(middleware) # type: ignore + else: + # Default to agent middleware if uncertain + agent_middlewares.append(middleware) # type: ignore + else: + agent_middlewares.append(middleware) # type: ignore + except Exception: + # If signature inspection fails, assume it's an agent middleware + agent_middlewares.append(middleware) # type: ignore + else: + # Fallback + agent_middlewares.append(middleware) # type: ignore + + self._agent_middleware_pipeline = AgentMiddlewarePipeline(agent_middlewares) + self._function_middleware_pipeline = FunctionMiddlewarePipeline(function_middlewares) + + async def middleware_enabled_run( + self: Any, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: Any = None, + **kwargs: Any, + ) -> AgentRunResponse: + """Middleware-enabled run method.""" + # Initialize middleware pipelines if not already done + if ( + hasattr(self, "middleware") + and self.middleware + and not ( + hasattr(self, "_agent_middleware_pipeline") + and hasattr(self, "_function_middleware_pipeline") + and ( + self._agent_middleware_pipeline.has_middlewares + or self._function_middleware_pipeline.has_middlewares + ) + ) + ): + _initialize_middleware_pipelines(self, self.middleware) + + # Ensure pipelines exist even if empty + if not hasattr(self, "_agent_middleware_pipeline"): + self._agent_middleware_pipeline = AgentMiddlewarePipeline() + if not hasattr(self, "_function_middleware_pipeline"): + self._function_middleware_pipeline = FunctionMiddlewarePipeline() + + # Add function middleware pipeline to kwargs if available + if self._function_middleware_pipeline.has_middlewares: + kwargs["_function_middleware_pipeline"] = self._function_middleware_pipeline + + normalized_messages = self._normalize_messages(messages) + + # Execute with middleware if available + if self._agent_middleware_pipeline.has_middlewares: + context = AgentRunContext( + agent=self, # type: ignore[arg-type] + messages=normalized_messages, + is_streaming=False, + ) + + async def _execute_handler(ctx: AgentRunContext) -> AgentRunResponse: + return await original_run(self, ctx.messages, thread=thread, **kwargs) # type: ignore + + result = await self._agent_middleware_pipeline.execute( + self, # type: ignore[arg-type] + normalized_messages, + context, + _execute_handler, + ) + + return result if result else AgentRunResponse() + + # No middleware, execute directly + return await original_run(self, normalized_messages, thread=thread, **kwargs) # type: ignore[return-value] + + def middleware_enabled_run_stream( + self: Any, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: Any = None, + **kwargs: Any, + ) -> AsyncIterable[AgentRunResponseUpdate]: + """Middleware-enabled run_stream method.""" + # Initialize middleware pipelines if not already done + if ( + hasattr(self, "middleware") + and self.middleware + and not ( + hasattr(self, "_agent_middleware_pipeline") + and hasattr(self, "_function_middleware_pipeline") + and ( + self._agent_middleware_pipeline.has_middlewares + or self._function_middleware_pipeline.has_middlewares + ) + ) + ): + _initialize_middleware_pipelines(self, self.middleware) + + # Ensure pipelines exist even if empty + if not hasattr(self, "_agent_middleware_pipeline"): + self._agent_middleware_pipeline = AgentMiddlewarePipeline() + if not hasattr(self, "_function_middleware_pipeline"): + self._function_middleware_pipeline = FunctionMiddlewarePipeline() + + # Add function middleware pipeline to kwargs if available + if self._function_middleware_pipeline.has_middlewares: + kwargs["_function_middleware_pipeline"] = self._function_middleware_pipeline + + normalized_messages = self._normalize_messages(messages) + + # Execute with middleware if available + if self._agent_middleware_pipeline.has_middlewares: + context = AgentRunContext( + agent=self, # type: ignore[arg-type] + messages=normalized_messages, + is_streaming=True, + ) + + async def _execute_stream_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]: + async for update in original_run_stream(self, ctx.messages, thread=thread, **kwargs): # type: ignore[misc] + yield update + + async def _stream_generator() -> AsyncIterable[AgentRunResponseUpdate]: + async for update in self._agent_middleware_pipeline.execute_stream( + self, # type: ignore[arg-type] + normalized_messages, + context, + _execute_stream_handler, + ): + yield update + + return _stream_generator() + + # No middleware, execute directly + return original_run_stream(self, normalized_messages, thread=thread, **kwargs) # type: ignore + + agent_class.run = middleware_enabled_run # type: ignore + agent_class.run_stream = middleware_enabled_run_stream # type: ignore + + return agent_class diff --git a/python/packages/main/agent_framework/_tools.py b/python/packages/main/agent_framework/_tools.py index 3ab096999b..6a3d1eb615 100644 --- a/python/packages/main/agent_framework/_tools.py +++ b/python/packages/main/agent_framework/_tools.py @@ -576,6 +576,7 @@ async def _auto_invoke_function( tool_map: dict[str, AIFunction[BaseModel, Any]], sequence_index: int | None = None, request_index: int | None = None, + middleware_pipeline: Any = None, # Optional MiddlewarePipeline ) -> "Contents": """Invoke a function call requested by the agent, applying filters that are defined in the agent.""" from ._types import FunctionResultContent @@ -590,14 +591,43 @@ async def _auto_invoke_function( merged_args: dict[str, Any] = (custom_args or {}) | parsed_args args = tool.input_model.model_validate(merged_args) exception = None - try: - function_result = await tool.invoke( + + # Execute through middleware pipeline if available + if middleware_pipeline and hasattr(middleware_pipeline, "has_middlewares") and middleware_pipeline.has_middlewares: + from ._middleware import FunctionInvocationContext + + middleware_context = FunctionInvocationContext( + function=tool, arguments=args, - tool_call_id=function_call_content.call_id, - ) # type: ignore[arg-type] - except Exception as ex: - exception = ex - function_result = None + ) + + async def final_function_handler(context_obj: Any) -> Any: + return await tool.invoke( + arguments=context_obj.arguments, + tool_call_id=function_call_content.call_id, + ) + + try: + function_result = await middleware_pipeline.execute( + function=tool, + arguments=args, + context=middleware_context, + final_handler=final_function_handler, + ) + except Exception as ex: + exception = ex + function_result = None + else: + # No middleware - execute directly + try: + function_result = await tool.invoke( + arguments=args, + tool_call_id=function_call_content.call_id, + ) # type: ignore[arg-type] + except Exception as ex: + exception = ex + function_result = None + return FunctionResultContent( call_id=function_call_content.call_id, exception=exception, @@ -631,6 +661,7 @@ async def execute_function_calls( | Callable[..., Any] \ | MutableMapping[str, Any] \ | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]", + middleware_pipeline: Any = None, # Optional MiddlewarePipeline to avoid circular imports ) -> list["Contents"]: tool_map = _get_tool_map(tools) # Run all function calls concurrently @@ -641,6 +672,7 @@ async def execute_function_calls( tool_map=tool_map, sequence_index=seq_idx, request_index=attempt_idx, + middleware_pipeline=middleware_pipeline, ) for seq_idx, function_call in enumerate(function_calls) ]) @@ -706,11 +738,14 @@ def _handle_function_calls_response( if not tools and (chat_options := kwargs.get("chat_options")) and isinstance(chat_options, ChatOptions): tools = chat_options.tools if function_calls and tools: + # Extract function middleware pipeline from kwargs if available + middleware_pipeline = kwargs.get("_function_middleware_pipeline") function_results = await execute_function_calls( custom_args=kwargs, attempt_idx=attempt_idx, function_calls=function_calls, tools=tools, # type: ignore + middleware_pipeline=middleware_pipeline, ) # add a single ChatMessage to the response with the results result_message = ChatMessage(role="tool", contents=function_results) # type: ignore[call-overload] @@ -815,11 +850,14 @@ def _handle_function_calls_streaming_response( tools = chat_options.tools if function_calls and tools: + # Extract function middleware pipeline from kwargs if available + middleware_pipeline = kwargs.get("_function_middleware_pipeline") function_results = await execute_function_calls( custom_args=kwargs, attempt_idx=attempt_idx, function_calls=function_calls, tools=tools, # type: ignore[reportArgumentType] + middleware_pipeline=middleware_pipeline, ) function_result_msg = ChatMessage(role="tool", contents=function_results) yield ChatResponseUpdate(contents=function_results, role="tool") diff --git a/python/packages/main/agent_framework/guard_rails.py b/python/packages/main/agent_framework/guard_rails.py deleted file mode 100644 index 46f9d44352..0000000000 --- a/python/packages/main/agent_framework/guard_rails.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - - -from typing import Generic, Protocol, TypeVar, runtime_checkable - -TInput = TypeVar("TInput") -TResponse = TypeVar("TResponse") - -__all__ = ["InputGuardrail", "OutputGuardrail"] - - -@runtime_checkable -class InputGuardrail(Protocol, Generic[TInput]): - """A protocol for input guardrails that can validate and transform input messages.""" - - def __call__(self, message: TInput) -> TInput: - """Validate and possibly transform the input message.""" - ... - - -@runtime_checkable -class OutputGuardrail(Protocol, Generic[TResponse]): - """A protocol for output guardrails that can validate and transform output messages.""" - - def __call__(self, message: TResponse) -> TResponse: - """Validate and possibly transform the output message.""" - ... diff --git a/python/packages/main/agent_framework/telemetry.py b/python/packages/main/agent_framework/telemetry.py index 5382d413db..aabbfd9c06 100644 --- a/python/packages/main/agent_framework/telemetry.py +++ b/python/packages/main/agent_framework/telemetry.py @@ -1026,8 +1026,12 @@ def _capture_messages( prepped = prepare_messages(messages) for index, message in enumerate(prepped): + try: + message_data = message.model_dump(exclude_none=True) + except Exception: + message_data = {"role": message.role.value, "contents": message.contents} logger.info( - message.model_dump_json(exclude_none=True), + message_data, extra={ OtelAttr.EVENT_NAME: OtelAttr.CHOICE if output else ROLE_EVENT_MAP.get(message.role.value), OtelAttr.PROVIDER_NAME: provider_name, diff --git a/python/packages/main/pyproject.toml b/python/packages/main/pyproject.toml index 24eec9eee5..fe422ae882 100644 --- a/python/packages/main/pyproject.toml +++ b/python/packages/main/pyproject.toml @@ -72,7 +72,15 @@ environments = [ fallback-version = "0.0.0" [tool.pytest.ini_options] -testpaths = 'tests' +testpaths = [ + 'tests', + 'packages/main/tests', + 'packages/azure/tests', + 'packages/foundry/tests', + 'packages/copilotstudio/tests', + 'packages/mem0/tests', + 'packages/runtime/tests' +] addopts = "-ra -q -r fEX" asyncio_mode = "auto" asyncio_default_fixture_loop_scope = "function" diff --git a/python/packages/main/tests/main/conftest.py b/python/packages/main/tests/main/conftest.py index 001510f877..9cb938f8ea 100644 --- a/python/packages/main/tests/main/conftest.py +++ b/python/packages/main/tests/main/conftest.py @@ -105,6 +105,9 @@ class MockChatClient: def __init__(self) -> None: self.additional_properties: dict[str, Any] = {} + self.call_count: int = 0 + self.responses: list[ChatResponse] = [] + self.streaming_responses: list[list[ChatResponseUpdate]] = [] async def get_response( self, @@ -112,6 +115,9 @@ class MockChatClient: **kwargs: Any, ) -> ChatResponse: logger.debug(f"Running custom chat client, with: {messages=}, {kwargs=}") + self.call_count += 1 + if self.responses: + return self.responses.pop(0) return ChatResponse(messages=ChatMessage(role="assistant", text="test response")) async def get_streaming_response( @@ -120,8 +126,13 @@ class MockChatClient: **kwargs: Any, ) -> AsyncIterable[ChatResponseUpdate]: logger.debug(f"Running custom chat client stream, with: {messages=}, {kwargs=}") - yield ChatResponseUpdate(text=TextContent(text="test streaming response "), role="assistant") - yield ChatResponseUpdate(contents=[TextContent(text="another update")], role="assistant") + self.call_count += 1 + if self.streaming_responses: + for update in self.streaming_responses.pop(0): + yield update + else: + yield ChatResponseUpdate(text=TextContent(text="test streaming response "), role="assistant") + yield ChatResponseUpdate(contents=[TextContent(text="another update")], role="assistant") class MockBaseChatClient(BaseChatClient): diff --git a/python/packages/main/tests/main/test_middleware.py b/python/packages/main/tests/main/test_middleware.py new file mode 100644 index 0000000000..f0f8b536c1 --- /dev/null +++ b/python/packages/main/tests/main/test_middleware.py @@ -0,0 +1,939 @@ +# Copyright (c) Microsoft. All rights reserved. + +from collections.abc import AsyncIterable, Awaitable, Callable +from typing import Any +from unittest.mock import MagicMock + +import pytest +from pydantic import BaseModel, Field + +from agent_framework import ( + AgentProtocol, + AgentRunResponse, + AgentRunResponseUpdate, + ChatMessage, + Role, + TextContent, +) +from agent_framework._middleware import ( + AgentMiddleware, + AgentMiddlewarePipeline, + AgentRunContext, + FunctionInvocationContext, + FunctionMiddleware, + FunctionMiddlewarePipeline, +) +from agent_framework._tools import AIFunction + + +class TestAgentRunContext: + """Test cases for AgentRunContext.""" + + def test_init_with_defaults(self, mock_agent: AgentProtocol) -> None: + """Test AgentRunContext initialization with default values.""" + messages = [ChatMessage(role=Role.USER, text="test")] + context = AgentRunContext(agent=mock_agent, messages=messages) + + assert context.agent is mock_agent + assert context.messages == messages + assert context.is_streaming is False + assert context.metadata == {} + + def test_init_with_custom_values(self, mock_agent: AgentProtocol) -> None: + """Test AgentRunContext initialization with custom values.""" + messages = [ChatMessage(role=Role.USER, text="test")] + metadata = {"key": "value"} + context = AgentRunContext(agent=mock_agent, messages=messages, is_streaming=True, metadata=metadata) + + assert context.agent is mock_agent + assert context.messages == messages + assert context.is_streaming is True + assert context.metadata == metadata + + +class TestFunctionInvocationContext: + """Test cases for FunctionInvocationContext.""" + + def test_init_with_defaults(self, mock_function: AIFunction[Any, Any]) -> None: + """Test FunctionInvocationContext initialization with default values.""" + arguments = FunctionTestArgs(name="test") + context = FunctionInvocationContext(function=mock_function, arguments=arguments) + + assert context.function is mock_function + assert context.arguments == arguments + assert context.metadata == {} + + def test_init_with_custom_metadata(self, mock_function: AIFunction[Any, Any]) -> None: + """Test FunctionInvocationContext initialization with custom metadata.""" + arguments = FunctionTestArgs(name="test") + metadata = {"key": "value"} + context = FunctionInvocationContext(function=mock_function, arguments=arguments, metadata=metadata) + + assert context.function is mock_function + assert context.arguments == arguments + assert context.metadata == metadata + + +class TestAgentMiddlewarePipeline: + """Test cases for AgentMiddlewarePipeline.""" + + def test_init_empty(self) -> None: + """Test AgentMiddlewarePipeline initialization with no middlewares.""" + pipeline = AgentMiddlewarePipeline() + assert not pipeline.has_middlewares + + def test_init_with_class_middleware(self) -> None: + """Test AgentMiddlewarePipeline initialization with class-based middleware.""" + middleware = TestAgentMiddleware() + pipeline = AgentMiddlewarePipeline([middleware]) + assert pipeline.has_middlewares + + def test_init_with_function_middleware(self) -> None: + """Test AgentMiddlewarePipeline initialization with function-based middleware.""" + + async def test_middleware(context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + await next(context) + + pipeline = AgentMiddlewarePipeline([test_middleware]) + assert pipeline.has_middlewares + + async def test_execute_no_middleware(self, mock_agent: AgentProtocol) -> None: + """Test pipeline execution with no middleware.""" + pipeline = AgentMiddlewarePipeline() + messages = [ChatMessage(role=Role.USER, text="test")] + context = AgentRunContext(agent=mock_agent, messages=messages) + + expected_response = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + + async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + return expected_response + + result = await pipeline.execute(mock_agent, messages, context, final_handler) + assert result == expected_response + + async def test_execute_with_middleware(self, mock_agent: AgentProtocol) -> None: + """Test pipeline execution with middleware.""" + execution_order: list[str] = [] + + class OrderTrackingMiddleware(AgentMiddleware): + def __init__(self, name: str): + self.name = name + + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_order.append(f"{self.name}_before") + await next(context) + execution_order.append(f"{self.name}_after") + + middleware = OrderTrackingMiddleware("test") + pipeline = AgentMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + context = AgentRunContext(agent=mock_agent, messages=messages) + + expected_response = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + + async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + execution_order.append("handler") + return expected_response + + result = await pipeline.execute(mock_agent, messages, context, final_handler) + assert result == expected_response + assert execution_order == ["test_before", "handler", "test_after"] + + async def test_execute_stream_no_middleware(self, mock_agent: AgentProtocol) -> None: + """Test pipeline streaming execution with no middleware.""" + pipeline = AgentMiddlewarePipeline() + messages = [ChatMessage(role=Role.USER, text="test")] + context = AgentRunContext(agent=mock_agent, messages=messages) + + async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]: + yield AgentRunResponseUpdate(contents=[TextContent(text="chunk1")]) + yield AgentRunResponseUpdate(contents=[TextContent(text="chunk2")]) + + updates: list[AgentRunResponseUpdate] = [] + async for update in pipeline.execute_stream(mock_agent, messages, context, final_handler): + updates.append(update) + + assert len(updates) == 2 + assert updates[0].text == "chunk1" + assert updates[1].text == "chunk2" + + async def test_execute_stream_with_middleware(self, mock_agent: AgentProtocol) -> None: + """Test pipeline streaming execution with middleware.""" + execution_order: list[str] = [] + + class StreamOrderTrackingMiddleware(AgentMiddleware): + def __init__(self, name: str): + self.name = name + + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_order.append(f"{self.name}_before") + await next(context) + execution_order.append(f"{self.name}_after") + + middleware = StreamOrderTrackingMiddleware("test") + pipeline = AgentMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + context = AgentRunContext(agent=mock_agent, messages=messages) + + async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]: + execution_order.append("handler_start") + yield AgentRunResponseUpdate(contents=[TextContent(text="chunk1")]) + yield AgentRunResponseUpdate(contents=[TextContent(text="chunk2")]) + execution_order.append("handler_end") + + updates: list[AgentRunResponseUpdate] = [] + async for update in pipeline.execute_stream(mock_agent, messages, context, final_handler): + updates.append(update) + + assert len(updates) == 2 + assert updates[0].text == "chunk1" + assert updates[1].text == "chunk2" + assert execution_order == ["test_before", "test_after", "handler_start", "handler_end"] + + +class TestFunctionMiddlewarePipeline: + """Test cases for FunctionMiddlewarePipeline.""" + + def test_init_empty(self) -> None: + """Test FunctionMiddlewarePipeline initialization with no middlewares.""" + pipeline = FunctionMiddlewarePipeline() + assert not pipeline.has_middlewares + + def test_init_with_class_middleware(self) -> None: + """Test FunctionMiddlewarePipeline initialization with class-based middleware.""" + middleware = TestFunctionMiddleware() + pipeline = FunctionMiddlewarePipeline([middleware]) + assert pipeline.has_middlewares + + def test_init_with_function_middleware(self) -> None: + """Test FunctionMiddlewarePipeline initialization with function-based middleware.""" + + async def test_middleware( + context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]] + ) -> None: + await next(context) + + pipeline = FunctionMiddlewarePipeline([test_middleware]) + assert pipeline.has_middlewares + + async def test_execute_no_middleware(self, mock_function: AIFunction[Any, Any]) -> None: + """Test pipeline execution with no middleware.""" + pipeline = FunctionMiddlewarePipeline() + arguments = FunctionTestArgs(name="test") + context = FunctionInvocationContext(function=mock_function, arguments=arguments) + + expected_result = "function_result" + + async def final_handler(ctx: FunctionInvocationContext) -> str: + return expected_result + + result = await pipeline.execute(mock_function, arguments, context, final_handler) + assert result == expected_result + + async def test_execute_with_middleware(self, mock_function: AIFunction[Any, Any]) -> None: + """Test pipeline execution with middleware.""" + execution_order: list[str] = [] + + class OrderTrackingFunctionMiddleware(FunctionMiddleware): + def __init__(self, name: str): + self.name = name + + async def process( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + execution_order.append(f"{self.name}_before") + await next(context) + execution_order.append(f"{self.name}_after") + + middleware = OrderTrackingFunctionMiddleware("test") + pipeline = FunctionMiddlewarePipeline([middleware]) + arguments = FunctionTestArgs(name="test") + context = FunctionInvocationContext(function=mock_function, arguments=arguments) + + expected_result = "function_result" + + async def final_handler(ctx: FunctionInvocationContext) -> str: + execution_order.append("handler") + return expected_result + + result = await pipeline.execute(mock_function, arguments, context, final_handler) + assert result == expected_result + assert execution_order == ["test_before", "handler", "test_after"] + + +class TestClassBasedMiddleware: + """Test cases for class-based middleware implementations.""" + + async def test_agent_middleware_execution(self, mock_agent: AgentProtocol) -> None: + """Test class-based agent middleware execution.""" + metadata_updates: list[str] = [] + + class MetadataAgentMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + context.metadata["before"] = True + metadata_updates.append("before") + await next(context) + context.metadata["after"] = True + metadata_updates.append("after") + + middleware = MetadataAgentMiddleware() + pipeline = AgentMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + context = AgentRunContext(agent=mock_agent, messages=messages) + + async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + metadata_updates.append("handler") + return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + + result = await pipeline.execute(mock_agent, messages, context, final_handler) + + assert result is not None + assert context.metadata["before"] is True + assert context.metadata["after"] is True + assert metadata_updates == ["before", "handler", "after"] + + async def test_function_middleware_execution(self, mock_function: AIFunction[Any, Any]) -> None: + """Test class-based function middleware execution.""" + metadata_updates: list[str] = [] + + class MetadataFunctionMiddleware(FunctionMiddleware): + async def process( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + context.metadata["before"] = True + metadata_updates.append("before") + await next(context) + context.metadata["after"] = True + metadata_updates.append("after") + + middleware = MetadataFunctionMiddleware() + pipeline = FunctionMiddlewarePipeline([middleware]) + arguments = FunctionTestArgs(name="test") + context = FunctionInvocationContext(function=mock_function, arguments=arguments) + + async def final_handler(ctx: FunctionInvocationContext) -> str: + metadata_updates.append("handler") + return "result" + + result = await pipeline.execute(mock_function, arguments, context, final_handler) + + assert result == "result" + assert context.metadata["before"] is True + assert context.metadata["after"] is True + assert metadata_updates == ["before", "handler", "after"] + + +class TestFunctionBasedMiddleware: + """Test cases for function-based middleware implementations.""" + + async def test_agent_function_middleware(self, mock_agent: AgentProtocol) -> None: + """Test function-based agent middleware.""" + execution_order: list[str] = [] + + async def test_agent_middleware( + context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_order.append("function_before") + context.metadata["function_middleware"] = True + await next(context) + execution_order.append("function_after") + + pipeline = AgentMiddlewarePipeline([test_agent_middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + context = AgentRunContext(agent=mock_agent, messages=messages) + + async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + execution_order.append("handler") + return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + + result = await pipeline.execute(mock_agent, messages, context, final_handler) + + assert result is not None + assert context.metadata["function_middleware"] is True + assert execution_order == ["function_before", "handler", "function_after"] + + async def test_function_function_middleware(self, mock_function: AIFunction[Any, Any]) -> None: + """Test function-based function middleware.""" + execution_order: list[str] = [] + + async def test_function_middleware( + context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]] + ) -> None: + execution_order.append("function_before") + context.metadata["function_middleware"] = True + await next(context) + execution_order.append("function_after") + + pipeline = FunctionMiddlewarePipeline([test_function_middleware]) + arguments = FunctionTestArgs(name="test") + context = FunctionInvocationContext(function=mock_function, arguments=arguments) + + async def final_handler(ctx: FunctionInvocationContext) -> str: + execution_order.append("handler") + return "result" + + result = await pipeline.execute(mock_function, arguments, context, final_handler) + + assert result == "result" + assert context.metadata["function_middleware"] is True + assert execution_order == ["function_before", "handler", "function_after"] + + +class TestMixedMiddleware: + """Test cases for mixed class and function-based middleware.""" + + async def test_mixed_agent_middleware(self, mock_agent: AgentProtocol) -> None: + """Test mixed class and function-based agent middleware.""" + execution_order: list[str] = [] + + class ClassMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_order.append("class_before") + await next(context) + execution_order.append("class_after") + + async def function_middleware( + context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_order.append("function_before") + await next(context) + execution_order.append("function_after") + + pipeline = AgentMiddlewarePipeline([ClassMiddleware(), function_middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + context = AgentRunContext(agent=mock_agent, messages=messages) + + async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + execution_order.append("handler") + return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + + result = await pipeline.execute(mock_agent, messages, context, final_handler) + + assert result is not None + assert execution_order == ["class_before", "function_before", "handler", "function_after", "class_after"] + + async def test_mixed_function_middleware(self, mock_function: AIFunction[Any, Any]) -> None: + """Test mixed class and function-based function middleware.""" + execution_order: list[str] = [] + + class ClassMiddleware(FunctionMiddleware): + async def process( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + execution_order.append("class_before") + await next(context) + execution_order.append("class_after") + + async def function_middleware( + context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]] + ) -> None: + execution_order.append("function_before") + await next(context) + execution_order.append("function_after") + + pipeline = FunctionMiddlewarePipeline([ClassMiddleware(), function_middleware]) + arguments = FunctionTestArgs(name="test") + context = FunctionInvocationContext(function=mock_function, arguments=arguments) + + async def final_handler(ctx: FunctionInvocationContext) -> str: + execution_order.append("handler") + return "result" + + result = await pipeline.execute(mock_function, arguments, context, final_handler) + + assert result == "result" + assert execution_order == ["class_before", "function_before", "handler", "function_after", "class_after"] + + +class TestMultipleMiddlewareOrdering: + """Test cases for multiple middleware execution order.""" + + async def test_agent_middleware_execution_order(self, mock_agent: AgentProtocol) -> None: + """Test that multiple agent middlewares execute in registration order.""" + execution_order: list[str] = [] + + class FirstMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_order.append("first_before") + await next(context) + execution_order.append("first_after") + + class SecondMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_order.append("second_before") + await next(context) + execution_order.append("second_after") + + class ThirdMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_order.append("third_before") + await next(context) + execution_order.append("third_after") + + middlewares = [FirstMiddleware(), SecondMiddleware(), ThirdMiddleware()] + pipeline = AgentMiddlewarePipeline(middlewares) # type: ignore + messages = [ChatMessage(role=Role.USER, text="test")] + context = AgentRunContext(agent=mock_agent, messages=messages) + + async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + execution_order.append("handler") + return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + + result = await pipeline.execute(mock_agent, messages, context, final_handler) + + assert result is not None + expected_order = [ + "first_before", + "second_before", + "third_before", + "handler", + "third_after", + "second_after", + "first_after", + ] + assert execution_order == expected_order + + async def test_function_middleware_execution_order(self, mock_function: AIFunction[Any, Any]) -> None: + """Test that multiple function middlewares execute in registration order.""" + execution_order: list[str] = [] + + class FirstMiddleware(FunctionMiddleware): + async def process( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + execution_order.append("first_before") + await next(context) + execution_order.append("first_after") + + class SecondMiddleware(FunctionMiddleware): + async def process( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + execution_order.append("second_before") + await next(context) + execution_order.append("second_after") + + middlewares = [FirstMiddleware(), SecondMiddleware()] + pipeline = FunctionMiddlewarePipeline(middlewares) # type: ignore + arguments = FunctionTestArgs(name="test") + context = FunctionInvocationContext(function=mock_function, arguments=arguments) + + async def final_handler(ctx: FunctionInvocationContext) -> str: + execution_order.append("handler") + return "result" + + result = await pipeline.execute(mock_function, arguments, context, final_handler) + + assert result == "result" + expected_order = ["first_before", "second_before", "handler", "second_after", "first_after"] + assert execution_order == expected_order + + +class TestContextContentValidation: + """Test cases for validating middleware context content.""" + + async def test_agent_context_validation(self, mock_agent: AgentProtocol) -> None: + """Test that agent context contains expected data.""" + + class ContextValidationMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + # Verify context has all expected attributes + assert hasattr(context, "agent") + assert hasattr(context, "messages") + assert hasattr(context, "is_streaming") + assert hasattr(context, "metadata") + + # Verify context content + assert context.agent is mock_agent + assert len(context.messages) == 1 + assert context.messages[0].role == Role.USER + assert context.messages[0].text == "test" + assert context.is_streaming is False + assert isinstance(context.metadata, dict) + + # Add custom metadata + context.metadata["validated"] = True + + await next(context) + + middleware = ContextValidationMiddleware() + pipeline = AgentMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + context = AgentRunContext(agent=mock_agent, messages=messages) + + async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + # Verify metadata was set by middleware + assert ctx.metadata.get("validated") is True + return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + + result = await pipeline.execute(mock_agent, messages, context, final_handler) + assert result is not None + + async def test_function_context_validation(self, mock_function: AIFunction[Any, Any]) -> None: + """Test that function context contains expected data.""" + + class ContextValidationMiddleware(FunctionMiddleware): + async def process( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + # Verify context has all expected attributes + assert hasattr(context, "function") + assert hasattr(context, "arguments") + assert hasattr(context, "metadata") + + # Verify context content + assert context.function is mock_function + assert isinstance(context.arguments, FunctionTestArgs) + assert context.arguments.name == "test" + assert isinstance(context.metadata, dict) + + # Add custom metadata + context.metadata["validated"] = True + + await next(context) + + middleware = ContextValidationMiddleware() + pipeline = FunctionMiddlewarePipeline([middleware]) + arguments = FunctionTestArgs(name="test") + context = FunctionInvocationContext(function=mock_function, arguments=arguments) + + async def final_handler(ctx: FunctionInvocationContext) -> str: + # Verify metadata was set by middleware + assert ctx.metadata.get("validated") is True + return "result" + + result = await pipeline.execute(mock_function, arguments, context, final_handler) + assert result == "result" + + +class TestStreamingScenarios: + """Test cases for streaming and non-streaming scenarios.""" + + async def test_streaming_flag_validation(self, mock_agent: AgentProtocol) -> None: + """Test that is_streaming flag is correctly set for streaming calls.""" + streaming_flags: list[bool] = [] + + class StreamingFlagMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + streaming_flags.append(context.is_streaming) + await next(context) + + middleware = StreamingFlagMiddleware() + pipeline = AgentMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + + # Test non-streaming + context = AgentRunContext(agent=mock_agent, messages=messages) + + async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + streaming_flags.append(ctx.is_streaming) + return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")]) + + await pipeline.execute(mock_agent, messages, context, final_handler) + + # Test streaming + context_stream = AgentRunContext(agent=mock_agent, messages=messages) + + async def final_stream_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]: + streaming_flags.append(ctx.is_streaming) + yield AgentRunResponseUpdate(contents=[TextContent(text="chunk")]) + + updates: list[AgentRunResponseUpdate] = [] + async for update in pipeline.execute_stream(mock_agent, messages, context_stream, final_stream_handler): + updates.append(update) + + # Verify flags: [non-streaming middleware, non-streaming handler, streaming middleware, streaming handler] + assert streaming_flags == [False, False, True, True] + + async def test_streaming_middleware_behavior(self, mock_agent: AgentProtocol) -> None: + """Test middleware behavior with streaming responses.""" + chunks_processed: list[str] = [] + + class StreamProcessingMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + chunks_processed.append("before_stream") + await next(context) + chunks_processed.append("after_stream") + + middleware = StreamProcessingMiddleware() + pipeline = AgentMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + context = AgentRunContext(agent=mock_agent, messages=messages) + + async def final_stream_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]: + chunks_processed.append("stream_start") + yield AgentRunResponseUpdate(contents=[TextContent(text="chunk1")]) + chunks_processed.append("chunk1_yielded") + yield AgentRunResponseUpdate(contents=[TextContent(text="chunk2")]) + chunks_processed.append("chunk2_yielded") + chunks_processed.append("stream_end") + + updates: list[str] = [] + async for update in pipeline.execute_stream(mock_agent, messages, context, final_stream_handler): + updates.append(update.text) + + assert updates == ["chunk1", "chunk2"] + assert chunks_processed == [ + "before_stream", + "after_stream", + "stream_start", + "chunk1_yielded", + "chunk2_yielded", + "stream_end", + ] + + +# region Helper classes and fixtures + + +class FunctionTestArgs(BaseModel): + """Test arguments for function middleware tests.""" + + name: str = Field(description="Test name parameter") + + +class TestAgentMiddleware(AgentMiddleware): + """Test implementation of AgentMiddleware.""" + + async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + await next(context) + + +class TestFunctionMiddleware(FunctionMiddleware): + """Test implementation of FunctionMiddleware.""" + + async def process( + self, context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]] + ) -> None: + await next(context) + + +class MockFunctionArgs(BaseModel): + """Test arguments for function middleware tests.""" + + name: str = Field(description="Test name parameter") + + +class TestMiddlewareExecutionControl: + """Test cases for middleware execution control (when next() is called vs not called).""" + + async def test_agent_middleware_no_next_no_execution(self, mock_agent: AgentProtocol) -> None: + """Test that when agent middleware doesn't call next(), no execution happens.""" + + class NoNextMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + # Don't call next() - this should prevent any execution + pass + + middleware = NoNextMiddleware() + pipeline = AgentMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + context = AgentRunContext(agent=mock_agent, messages=messages) + + handler_called = False + + async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + nonlocal handler_called + handler_called = True + return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="should not execute")]) + + result = await pipeline.execute(mock_agent, messages, context, final_handler) + + # Verify no execution happened - should return empty AgentRunResponse + assert result is not None + assert isinstance(result, AgentRunResponse) + assert result.messages == [] # Empty response + assert not handler_called + assert context.result is None + + async def test_agent_middleware_no_next_no_streaming_execution(self, mock_agent: AgentProtocol) -> None: + """Test that when agent middleware doesn't call next(), no streaming execution happens.""" + + class NoNextStreamingMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + # Don't call next() - this should prevent any execution + pass + + middleware = NoNextStreamingMiddleware() + pipeline = AgentMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + context = AgentRunContext(agent=mock_agent, messages=messages) + + handler_called = False + + async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]: + nonlocal handler_called + handler_called = True + yield AgentRunResponseUpdate(contents=[TextContent(text="should not execute")]) + + # When middleware doesn't call next(), streaming should yield no updates + updates: list[AgentRunResponseUpdate] = [] + async for update in pipeline.execute_stream(mock_agent, messages, context, final_handler): + updates.append(update) + + # Verify no execution happened and no updates were yielded + assert len(updates) == 0 + assert not handler_called + assert context.result is None + + async def test_function_middleware_no_next_no_execution(self, mock_function: AIFunction[Any, Any]) -> None: + """Test that when function middleware doesn't call next(), no execution happens.""" + + class FunctionTestArgs(BaseModel): + name: str = Field(description="Test name parameter") + + class NoNextFunctionMiddleware(FunctionMiddleware): + async def process( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + # Don't call next() - this should prevent any execution + pass + + middleware = NoNextFunctionMiddleware() + pipeline = FunctionMiddlewarePipeline([middleware]) + arguments = FunctionTestArgs(name="test") + context = FunctionInvocationContext(function=mock_function, arguments=arguments) + + handler_called = False + + async def final_handler(ctx: FunctionInvocationContext) -> str: + nonlocal handler_called + handler_called = True + return "should not execute" + + result = await pipeline.execute(mock_function, arguments, context, final_handler) + + # Verify no execution happened + assert result is None + assert not handler_called + assert context.result is None + + async def test_multiple_middlewares_early_stop(self, mock_agent: AgentProtocol) -> None: + """Test that when first middleware doesn't call next(), subsequent middlewares are not called.""" + execution_order: list[str] = [] + + class FirstMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_order.append("first") + # Don't call next() - this should stop the pipeline + + class SecondMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_order.append("second") + await next(context) + + pipeline = AgentMiddlewarePipeline([FirstMiddleware(), SecondMiddleware()]) + messages = [ChatMessage(role=Role.USER, text="test")] + context = AgentRunContext(agent=mock_agent, messages=messages) + + handler_called = False + + async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + nonlocal handler_called + handler_called = True + return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="should not execute")]) + + result = await pipeline.execute(mock_agent, messages, context, final_handler) + + # Verify only first middleware was called and empty response returned + assert execution_order == ["first"] + assert result is not None + assert isinstance(result, AgentRunResponse) + assert result.messages == [] # Empty response + assert not handler_called + + async def test_function_middleware_pre_execution_override_with_next( + self, mock_function: AIFunction[Any, Any] + ) -> None: + """Test that function middleware can override result before calling next() - this skips handler execution.""" + + class FunctionTestArgs(BaseModel): + name: str = Field(description="Test name parameter") + + class PreOverrideFunctionMiddleware(FunctionMiddleware): + async def process( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + # Set override first + context.result = "pre-override result" + # Then call next() to continue middleware pipeline + await next(context) + + middleware = PreOverrideFunctionMiddleware() + pipeline = FunctionMiddlewarePipeline([middleware]) + arguments = FunctionTestArgs(name="test") + context = FunctionInvocationContext(function=mock_function, arguments=arguments) + + handler_called = False + + async def final_handler(ctx: FunctionInvocationContext) -> str: + nonlocal handler_called + handler_called = True + # This should not be called when result is pre-set + return "original result" + + result = await pipeline.execute(mock_function, arguments, context, final_handler) + + # Verify pre-override worked and handler was NOT called (because result was already set) + assert result == "pre-override result" + assert not handler_called + + +@pytest.fixture +def mock_agent() -> AgentProtocol: + """Mock agent for testing.""" + agent = MagicMock(spec=AgentProtocol) + agent.name = "test_agent" + return agent + + +@pytest.fixture +def mock_function() -> AIFunction[Any, Any]: + """Mock function for testing.""" + function = MagicMock(spec=AIFunction[Any, Any]) + function.name = "test_function" + return function diff --git a/python/packages/main/tests/main/test_middleware_context_result.py b/python/packages/main/tests/main/test_middleware_context_result.py new file mode 100644 index 0000000000..447ba0d4b9 --- /dev/null +++ b/python/packages/main/tests/main/test_middleware_context_result.py @@ -0,0 +1,463 @@ +# Copyright (c) Microsoft. All rights reserved. + +from collections.abc import AsyncIterable, Awaitable, Callable +from typing import Any +from unittest.mock import MagicMock + +import pytest +from pydantic import BaseModel, Field + +from agent_framework import ( + AgentProtocol, + AgentRunResponse, + AgentRunResponseUpdate, + ChatAgent, + ChatMessage, + Role, + TextContent, +) +from agent_framework._middleware import ( + AgentMiddleware, + AgentMiddlewarePipeline, + AgentRunContext, + FunctionInvocationContext, + FunctionMiddleware, + FunctionMiddlewarePipeline, +) +from agent_framework._tools import AIFunction + +from .conftest import MockChatClient + + +class FunctionTestArgs(BaseModel): + """Test arguments for function middleware tests.""" + + name: str = Field(description="Test name parameter") + + +class TestResultOverrideMiddleware: + """Test cases for middleware result override functionality.""" + + async def test_agent_middleware_response_override_non_streaming(self, mock_agent: AgentProtocol) -> None: + """Test that agent middleware can override response for non-streaming execution.""" + override_response = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="overridden response")]) + + class ResponseOverrideMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + # Execute the pipeline first, then override the response + await next(context) + context.result = override_response + + middleware = ResponseOverrideMiddleware() + pipeline = AgentMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + context = AgentRunContext(agent=mock_agent, messages=messages) + + handler_called = False + + async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + nonlocal handler_called + handler_called = True + return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="original response")]) + + result = await pipeline.execute(mock_agent, messages, context, final_handler) + + # Verify the overridden response is returned + assert result is not None + assert result == override_response + assert result.messages[0].text == "overridden response" + # Verify original handler was called since middleware called next() + assert handler_called + + async def test_agent_middleware_response_override_streaming(self, mock_agent: AgentProtocol) -> None: + """Test that agent middleware can override response for streaming execution.""" + + async def override_stream() -> AsyncIterable[AgentRunResponseUpdate]: + yield AgentRunResponseUpdate(contents=[TextContent(text="overridden")]) + yield AgentRunResponseUpdate(contents=[TextContent(text=" stream")]) + + class StreamResponseOverrideMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + # Execute the pipeline first, then override the response stream + await next(context) + context.result = override_stream() + + middleware = StreamResponseOverrideMiddleware() + pipeline = AgentMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + context = AgentRunContext(agent=mock_agent, messages=messages) + + async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]: + yield AgentRunResponseUpdate(contents=[TextContent(text="original")]) + + updates: list[AgentRunResponseUpdate] = [] + async for update in pipeline.execute_stream(mock_agent, messages, context, final_handler): + updates.append(update) + + # Verify the overridden response stream is returned + assert len(updates) == 2 + assert updates[0].text == "overridden" + assert updates[1].text == " stream" + + async def test_function_middleware_result_override(self, mock_function: AIFunction[Any, Any]) -> None: + """Test that function middleware can override result.""" + override_result = "overridden function result" + + class ResultOverrideMiddleware(FunctionMiddleware): + async def process( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + # Execute the pipeline first, then override the result + await next(context) + context.result = override_result + + middleware = ResultOverrideMiddleware() + pipeline = FunctionMiddlewarePipeline([middleware]) + arguments = FunctionTestArgs(name="test") + context = FunctionInvocationContext(function=mock_function, arguments=arguments) + + handler_called = False + + async def final_handler(ctx: FunctionInvocationContext) -> str: + nonlocal handler_called + handler_called = True + return "original function result" + + result = await pipeline.execute(mock_function, arguments, context, final_handler) + + # Verify the overridden result is returned + assert result == override_result + # Verify original handler was called since middleware called next() + assert handler_called + + async def test_chat_agent_middleware_response_override(self) -> None: + """Test result override functionality with ChatAgent integration.""" + mock_chat_client = MockChatClient() + + class ChatAgentResponseOverrideMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + # Always call next() first to allow execution + await next(context) + # Then conditionally override based on content + if any("special" in msg.text for msg in context.messages if msg.text): + context.result = AgentRunResponse( + messages=[ChatMessage(role=Role.ASSISTANT, text="Special response from middleware!")] + ) + + # Create ChatAgent with override middleware + middleware = ChatAgentResponseOverrideMiddleware() + agent = ChatAgent(chat_client=mock_chat_client, middleware=[middleware]) + + # Test override case + override_messages = [ChatMessage(role=Role.USER, text="Give me a special response")] + override_response = await agent.run(override_messages) + assert override_response.messages[0].text == "Special response from middleware!" + # Verify chat client was called since middleware called next() + assert mock_chat_client.call_count == 1 + + # Test normal case + normal_messages = [ChatMessage(role=Role.USER, text="Normal request")] + normal_response = await agent.run(normal_messages) + assert normal_response.messages[0].text == "test response" + # Verify chat client was called for normal case + assert mock_chat_client.call_count == 2 + + async def test_chat_agent_middleware_streaming_override(self) -> None: + """Test streaming result override functionality with ChatAgent integration.""" + mock_chat_client = MockChatClient() + + async def custom_stream() -> AsyncIterable[AgentRunResponseUpdate]: + yield AgentRunResponseUpdate(contents=[TextContent(text="Custom")]) + yield AgentRunResponseUpdate(contents=[TextContent(text=" streaming")]) + yield AgentRunResponseUpdate(contents=[TextContent(text=" response!")]) + + class ChatAgentStreamOverrideMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + # Always call next() first to allow execution + await next(context) + # Then conditionally override based on content + if any("custom stream" in msg.text for msg in context.messages if msg.text): + context.result = custom_stream() + + # Create ChatAgent with override middleware + middleware = ChatAgentStreamOverrideMiddleware() + agent = ChatAgent(chat_client=mock_chat_client, middleware=[middleware]) + + # Test streaming override case + override_messages = [ChatMessage(role=Role.USER, text="Give me a custom stream")] + override_updates: list[AgentRunResponseUpdate] = [] + async for update in agent.run_stream(override_messages): + override_updates.append(update) + + assert len(override_updates) == 3 + assert override_updates[0].text == "Custom" + assert override_updates[1].text == " streaming" + assert override_updates[2].text == " response!" + + # Test normal streaming case + normal_messages = [ChatMessage(role=Role.USER, text="Normal streaming request")] + normal_updates: list[AgentRunResponseUpdate] = [] + async for update in agent.run_stream(normal_messages): + normal_updates.append(update) + + assert len(normal_updates) == 2 + assert normal_updates[0].text == "test streaming response " + assert normal_updates[1].text == "another update" + + async def test_agent_middleware_conditional_no_next(self, mock_agent: AgentProtocol) -> None: + """Test that when agent middleware conditionally doesn't call next(), no execution happens.""" + + class ConditionalNoNextMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + # Only call next() if message contains "execute" + if any("execute" in msg.text for msg in context.messages if msg.text): + await next(context) + # Otherwise, don't call next() - no execution should happen + + middleware = ConditionalNoNextMiddleware() + pipeline = AgentMiddlewarePipeline([middleware]) + + handler_called = False + + async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + nonlocal handler_called + handler_called = True + return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="executed response")]) + + # Test case where next() is NOT called + no_execute_messages = [ChatMessage(role=Role.USER, text="Don't run this")] + no_execute_context = AgentRunContext(agent=mock_agent, messages=no_execute_messages) + no_execute_result = await pipeline.execute(mock_agent, no_execute_messages, no_execute_context, final_handler) + + # When middleware doesn't call next(), result should be empty AgentRunResponse + assert no_execute_result is not None + assert isinstance(no_execute_result, AgentRunResponse) + assert no_execute_result.messages == [] # Empty response + assert not handler_called + assert no_execute_context.result is None + + # Reset for next test + handler_called = False + + # Test case where next() IS called + execute_messages = [ChatMessage(role=Role.USER, text="Please execute this")] + execute_context = AgentRunContext(agent=mock_agent, messages=execute_messages) + execute_result = await pipeline.execute(mock_agent, execute_messages, execute_context, final_handler) + + assert execute_result is not None + assert execute_result.messages[0].text == "executed response" + assert handler_called + + async def test_function_middleware_conditional_no_next(self, mock_function: AIFunction[Any, Any]) -> None: + """Test that when function middleware conditionally doesn't call next(), no execution happens.""" + + class ConditionalNoNextFunctionMiddleware(FunctionMiddleware): + async def process( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + # Only call next() if argument name contains "execute" + args = context.arguments + assert isinstance(args, FunctionTestArgs) + if "execute" in args.name: + await next(context) + # Otherwise, don't call next() - no execution should happen + + middleware = ConditionalNoNextFunctionMiddleware() + pipeline = FunctionMiddlewarePipeline([middleware]) + + handler_called = False + + async def final_handler(ctx: FunctionInvocationContext) -> str: + nonlocal handler_called + handler_called = True + return "executed function result" + + # Test case where next() is NOT called + no_execute_args = FunctionTestArgs(name="test_no_action") + no_execute_context = FunctionInvocationContext(function=mock_function, arguments=no_execute_args) + no_execute_result = await pipeline.execute(mock_function, no_execute_args, no_execute_context, final_handler) + + # When middleware doesn't call next(), function result should be None (functions can return None) + assert no_execute_result is None + assert not handler_called + assert no_execute_context.result is None + + # Reset for next test + handler_called = False + + # Test case where next() IS called + execute_args = FunctionTestArgs(name="test_execute") + execute_context = FunctionInvocationContext(function=mock_function, arguments=execute_args) + execute_result = await pipeline.execute(mock_function, execute_args, execute_context, final_handler) + + assert execute_result == "executed function result" + assert handler_called + + +class TestResultObservability: + """Test cases for middleware result observability functionality.""" + + async def test_agent_middleware_response_observability(self, mock_agent: AgentProtocol) -> None: + """Test that middleware can observe response after execution.""" + observed_responses: list[AgentRunResponse] = [] + + class ObservabilityMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + # Context should be empty before next() + assert context.result is None + + # Call next to execute + await next(context) + + # Context should now contain the response for observability + assert context.result is not None + assert isinstance(context.result, AgentRunResponse) + observed_responses.append(context.result) + + middleware = ObservabilityMiddleware() + pipeline = AgentMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + context = AgentRunContext(agent=mock_agent, messages=messages) + + async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="executed response")]) + + result = await pipeline.execute(mock_agent, messages, context, final_handler) + + # Verify response was observed + assert len(observed_responses) == 1 + assert observed_responses[0].messages[0].text == "executed response" + assert result == observed_responses[0] + + async def test_function_middleware_result_observability(self, mock_function: AIFunction[Any, Any]) -> None: + """Test that middleware can observe function result after execution.""" + observed_results: list[str] = [] + + class ObservabilityMiddleware(FunctionMiddleware): + async def process( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + # Context should be empty before next() + assert context.result is None + + # Call next to execute + await next(context) + + # Context should now contain the result for observability + assert context.result is not None + observed_results.append(context.result) + + middleware = ObservabilityMiddleware() + pipeline = FunctionMiddlewarePipeline([middleware]) + arguments = FunctionTestArgs(name="test") + context = FunctionInvocationContext(function=mock_function, arguments=arguments) + + async def final_handler(ctx: FunctionInvocationContext) -> str: + return "executed function result" + + result = await pipeline.execute(mock_function, arguments, context, final_handler) + + # Verify result was observed + assert len(observed_results) == 1 + assert observed_results[0] == "executed function result" + assert result == observed_results[0] + + async def test_agent_middleware_post_execution_override(self, mock_agent: AgentProtocol) -> None: + """Test that middleware can override response after observing execution.""" + + class PostExecutionOverrideMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + # Call next to execute first + await next(context) + + # Now observe and conditionally override + assert context.result is not None + assert isinstance(context.result, AgentRunResponse) + + if "modify" in context.result.messages[0].text: + # Override after observing + context.result = AgentRunResponse( + messages=[ChatMessage(role=Role.ASSISTANT, text="modified after execution")] + ) + + middleware = PostExecutionOverrideMiddleware() + pipeline = AgentMiddlewarePipeline([middleware]) + messages = [ChatMessage(role=Role.USER, text="test")] + context = AgentRunContext(agent=mock_agent, messages=messages) + + async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: + return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response to modify")]) + + result = await pipeline.execute(mock_agent, messages, context, final_handler) + + # Verify response was modified after execution + assert result is not None + assert result.messages[0].text == "modified after execution" + + async def test_function_middleware_post_execution_override(self, mock_function: AIFunction[Any, Any]) -> None: + """Test that middleware can override function result after observing execution.""" + + class PostExecutionOverrideMiddleware(FunctionMiddleware): + async def process( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + # Call next to execute first + await next(context) + + # Now observe and conditionally override + assert context.result is not None + + if "modify" in context.result: + # Override after observing + context.result = "modified after execution" + + middleware = PostExecutionOverrideMiddleware() + pipeline = FunctionMiddlewarePipeline([middleware]) + arguments = FunctionTestArgs(name="test") + context = FunctionInvocationContext(function=mock_function, arguments=arguments) + + async def final_handler(ctx: FunctionInvocationContext) -> str: + return "result to modify" + + result = await pipeline.execute(mock_function, arguments, context, final_handler) + + # Verify result was modified after execution + assert result == "modified after execution" + + +@pytest.fixture +def mock_agent() -> AgentProtocol: + """Mock agent for testing.""" + agent = MagicMock(spec=AgentProtocol) + agent.name = "test_agent" + return agent + + +@pytest.fixture +def mock_function() -> AIFunction[Any, Any]: + """Mock function for testing.""" + function = MagicMock(spec=AIFunction[Any, Any]) + function.name = "test_function" + return function diff --git a/python/packages/main/tests/main/test_middleware_with_agent.py b/python/packages/main/tests/main/test_middleware_with_agent.py new file mode 100644 index 0000000000..7a2c030285 --- /dev/null +++ b/python/packages/main/tests/main/test_middleware_with_agent.py @@ -0,0 +1,544 @@ +# Copyright (c) Microsoft. All rights reserved. + +from collections.abc import Awaitable, Callable + +from agent_framework import ( + AgentRunResponseUpdate, + ChatAgent, + ChatMessage, + ChatResponse, + ChatResponseUpdate, + FunctionCallContent, + FunctionResultContent, + Role, + TextContent, +) +from agent_framework._middleware import ( + AgentMiddleware, + AgentRunContext, + FunctionInvocationContext, + FunctionMiddleware, +) + +from .conftest import MockChatClient + +# region ChatAgent Tests + + +class TestChatAgentClassBasedMiddleware: + """Test cases for class-based middleware integration with ChatAgent.""" + + async def test_class_based_agent_middleware_with_chat_agent(self, chat_client: "MockChatClient") -> None: + """Test class-based agent middleware with ChatAgent.""" + execution_order: list[str] = [] + + class TrackingAgentMiddleware(AgentMiddleware): + def __init__(self, name: str): + self.name = name + + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_order.append(f"{self.name}_before") + await next(context) + execution_order.append(f"{self.name}_after") + + # Create ChatAgent with middleware + middleware = TrackingAgentMiddleware("agent_middleware") + agent = ChatAgent(chat_client=chat_client, middleware=[middleware]) + + # Execute the agent + messages = [ChatMessage(role=Role.USER, text="test message")] + response = await agent.run(messages) + + # Verify response + assert response is not None + assert len(response.messages) > 0 + assert response.messages[0].role == Role.ASSISTANT + # Note: conftest "MockChatClient" returns different text format + assert "test response" in response.messages[0].text + + # Verify middleware execution order + assert execution_order == ["agent_middleware_before", "agent_middleware_after"] + + async def test_class_based_function_middleware_with_chat_agent(self, chat_client: "MockChatClient") -> None: + """Test class-based function middleware with ChatAgent.""" + execution_order: list[str] = [] + + class TrackingFunctionMiddleware(FunctionMiddleware): + def __init__(self, name: str): + self.name = name + + async def process( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + execution_order.append(f"{self.name}_before") + await next(context) + execution_order.append(f"{self.name}_after") + + # Create ChatAgent with function middleware (no tools, so function middleware won't be triggered) + middleware = TrackingFunctionMiddleware("function_middleware") + agent = ChatAgent(chat_client=chat_client, middleware=[middleware]) + + # Execute the agent + messages = [ChatMessage(role=Role.USER, text="test message")] + response = await agent.run(messages) + + # Verify response + assert response is not None + assert len(response.messages) > 0 + assert chat_client.call_count == 1 + + # Note: Function middleware won't execute since no function calls are made + assert execution_order == [] + + +class TestChatAgentFunctionBasedMiddleware: + """Test cases for function-based middleware integration with ChatAgent.""" + + async def test_function_based_agent_middleware_with_chat_agent(self, chat_client: "MockChatClient") -> None: + """Test function-based agent middleware with ChatAgent.""" + execution_order: list[str] = [] + + async def tracking_agent_middleware( + context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_order.append("agent_function_before") + await next(context) + execution_order.append("agent_function_after") + + # Create ChatAgent with function middleware + agent = ChatAgent(chat_client=chat_client, middleware=[tracking_agent_middleware]) + + # Execute the agent + messages = [ChatMessage(role=Role.USER, text="test message")] + response = await agent.run(messages) + + # Verify response + assert response is not None + assert len(response.messages) > 0 + assert response.messages[0].role == Role.ASSISTANT + assert response.messages[0].text == "test response" + assert chat_client.call_count == 1 + + # Verify middleware execution order + assert execution_order == ["agent_function_before", "agent_function_after"] + + async def test_function_based_function_middleware_with_chat_agent(self, chat_client: "MockChatClient") -> None: + """Test function-based function middleware with ChatAgent.""" + execution_order: list[str] = [] + + async def tracking_function_middleware( + context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]] + ) -> None: + execution_order.append("function_function_before") + await next(context) + execution_order.append("function_function_after") + + # Create ChatAgent with function middleware (no tools, so function middleware won't be triggered) + agent = ChatAgent(chat_client=chat_client, middleware=[tracking_function_middleware]) + + # Execute the agent + messages = [ChatMessage(role=Role.USER, text="test message")] + response = await agent.run(messages) + + # Verify response + assert response is not None + assert len(response.messages) > 0 + assert chat_client.call_count == 1 + + # Note: Function middleware won't execute since no function calls are made + assert execution_order == [] + + +class TestChatAgentStreamingMiddleware: + """Test cases for streaming middleware integration with ChatAgent.""" + + async def test_agent_middleware_with_streaming(self, chat_client: "MockChatClient") -> None: + """Test agent middleware with streaming ChatAgent responses.""" + execution_order: list[str] = [] + streaming_flags: list[bool] = [] + + class StreamingTrackingMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_order.append("middleware_before") + streaming_flags.append(context.is_streaming) + await next(context) + execution_order.append("middleware_after") + + # Create ChatAgent with middleware + middleware = StreamingTrackingMiddleware() + agent = ChatAgent(chat_client=chat_client, middleware=[middleware]) + + # Set up mock streaming responses + chat_client.streaming_responses = [ + [ + ChatResponseUpdate(contents=[TextContent(text="Streaming")], role=Role.ASSISTANT), + ChatResponseUpdate(contents=[TextContent(text=" response")], role=Role.ASSISTANT), + ] + ] + + # Execute streaming + messages = [ChatMessage(role=Role.USER, text="test message")] + updates: list[AgentRunResponseUpdate] = [] + async for update in agent.run_stream(messages): + updates.append(update) + + # Verify streaming response + assert len(updates) == 2 + assert updates[0].text == "Streaming" + assert updates[1].text == " response" + assert chat_client.call_count == 1 + + # Verify middleware was called and streaming flag was set correctly + assert execution_order == ["middleware_before", "middleware_after"] + assert streaming_flags == [True] # Context should indicate streaming + + async def test_non_streaming_vs_streaming_flag_validation(self, chat_client: "MockChatClient") -> None: + """Test that is_streaming flag is correctly set for different execution modes.""" + streaming_flags: list[bool] = [] + + class FlagTrackingMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + streaming_flags.append(context.is_streaming) + await next(context) + + # Create ChatAgent with middleware + middleware = FlagTrackingMiddleware() + agent = ChatAgent(chat_client=chat_client, middleware=[middleware]) + messages = [ChatMessage(role=Role.USER, text="test message")] + + # Test non-streaming execution + response = await agent.run(messages) + assert response is not None + + # Test streaming execution + async for _ in agent.run_stream(messages): + pass + + # Verify flags: [non-streaming, streaming] + assert streaming_flags == [False, True] + + +class TestChatAgentMultipleMiddlewareOrdering: + """Test cases for multiple middleware execution order with ChatAgent.""" + + async def test_multiple_agent_middleware_execution_order(self, chat_client: "MockChatClient") -> None: + """Test that multiple agent middlewares execute in correct order with ChatAgent.""" + execution_order: list[str] = [] + + class OrderedMiddleware(AgentMiddleware): + def __init__(self, name: str): + self.name = name + + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_order.append(f"{self.name}_before") + await next(context) + execution_order.append(f"{self.name}_after") + + # Create multiple middlewares + middleware1 = OrderedMiddleware("first") + middleware2 = OrderedMiddleware("second") + middleware3 = OrderedMiddleware("third") + + # Create ChatAgent with multiple middlewares + agent = ChatAgent(chat_client=chat_client, middleware=[middleware1, middleware2, middleware3]) + + # Execute the agent + messages = [ChatMessage(role=Role.USER, text="test message")] + response = await agent.run(messages) + + # Verify response + assert response is not None + assert chat_client.call_count == 1 + + # Verify execution order (should be nested: first wraps second wraps third) + expected_order = ["first_before", "second_before", "third_before", "third_after", "second_after", "first_after"] + assert execution_order == expected_order + + async def test_mixed_middleware_types_with_chat_agent(self, chat_client: "MockChatClient") -> None: + """Test mixed class and function-based middlewares with ChatAgent.""" + execution_order: list[str] = [] + + class ClassAgentMiddleware(AgentMiddleware): + async def process( + self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_order.append("class_agent_before") + await next(context) + execution_order.append("class_agent_after") + + async def function_agent_middleware( + context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + execution_order.append("function_agent_before") + await next(context) + execution_order.append("function_agent_after") + + class ClassFunctionMiddleware(FunctionMiddleware): + async def process( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + execution_order.append("class_function_before") + await next(context) + execution_order.append("class_function_after") + + async def function_function_middleware( + context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]] + ) -> None: + execution_order.append("function_function_before") + await next(context) + execution_order.append("function_function_after") + + # Create ChatAgent with mixed middleware types (no tools, focusing on agent middleware) + agent = ChatAgent( + chat_client=chat_client, + middleware=[ + ClassAgentMiddleware(), + function_agent_middleware, + ClassFunctionMiddleware(), # Won't execute without function calls + function_function_middleware, # Won't execute without function calls + ], + ) + + # Execute the agent + messages = [ChatMessage(role=Role.USER, text="test message")] + response = await agent.run(messages) + + # Verify response + assert response is not None + assert chat_client.call_count == 1 + + # Verify that agent middlewares were executed in correct order + # (Function middlewares won't execute since no functions are called) + expected_order = ["class_agent_before", "function_agent_before", "function_agent_after", "class_agent_after"] + assert execution_order == expected_order + + +# region Tool Functions for Testing + + +def sample_tool_function(location: str) -> str: + """A simple tool function for middleware testing.""" + return f"Weather in {location}: sunny" + + +# region ChatAgent Function Middleware Tests with Tools + + +class TestChatAgentFunctionMiddlewareWithTools: + """Test cases for function middleware integration with ChatAgent when tools are used.""" + + async def test_class_based_function_middleware_with_tool_calls(self, chat_client: "MockChatClient") -> None: + """Test class-based function middleware with ChatAgent when function calls are made.""" + execution_order: list[str] = [] + + class TrackingFunctionMiddleware(FunctionMiddleware): + def __init__(self, name: str): + self.name = name + + async def process( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + execution_order.append(f"{self.name}_before") + await next(context) + execution_order.append(f"{self.name}_after") + + # Set up mock to return a function call first, then a regular response + function_call_response = ChatResponse( + messages=[ + ChatMessage( + role=Role.ASSISTANT, + contents=[ + FunctionCallContent( + call_id="call_123", + name="sample_tool_function", + arguments='{"location": "Seattle"}', + ) + ], + ) + ] + ) + final_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Final response")]) + + chat_client.responses = [function_call_response, final_response] + + # Create ChatAgent with function middleware and tools + middleware = TrackingFunctionMiddleware("function_middleware") + agent = ChatAgent( + chat_client=chat_client, + middleware=[middleware], + tools=[sample_tool_function], + ) + + # Execute the agent + messages = [ChatMessage(role=Role.USER, text="Get weather for Seattle")] + response = await agent.run(messages) + + # Verify response + assert response is not None + assert len(response.messages) > 0 + assert chat_client.call_count == 2 # Two calls: one for function call, one for final response + + # Verify function middleware was executed + assert execution_order == ["function_middleware_before", "function_middleware_after"] + + # Verify function call and result are in the response + all_contents = [content for message in response.messages for content in message.contents] + function_calls = [c for c in all_contents if isinstance(c, FunctionCallContent)] + function_results = [c for c in all_contents if isinstance(c, FunctionResultContent)] + + assert len(function_calls) == 1 + assert len(function_results) == 1 + assert function_calls[0].name == "sample_tool_function" + assert function_results[0].call_id == function_calls[0].call_id + + async def test_function_based_function_middleware_with_tool_calls(self, chat_client: "MockChatClient") -> None: + """Test function-based function middleware with ChatAgent when function calls are made.""" + execution_order: list[str] = [] + + async def tracking_function_middleware( + context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]] + ) -> None: + execution_order.append("function_middleware_before") + await next(context) + execution_order.append("function_middleware_after") + + # Set up mock to return a function call first, then a regular response + function_call_response = ChatResponse( + messages=[ + ChatMessage( + role=Role.ASSISTANT, + contents=[ + FunctionCallContent( + call_id="call_456", + name="sample_tool_function", + arguments='{"location": "San Francisco"}', + ) + ], + ) + ] + ) + final_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Final response")]) + + chat_client.responses = [function_call_response, final_response] + + # Create ChatAgent with function middleware and tools + agent = ChatAgent( + chat_client=chat_client, + middleware=[tracking_function_middleware], + tools=[sample_tool_function], + ) + + # Execute the agent + messages = [ChatMessage(role=Role.USER, text="Get weather for San Francisco")] + response = await agent.run(messages) + + # Verify response + assert response is not None + assert len(response.messages) > 0 + assert chat_client.call_count == 2 # Two calls: one for function call, one for final response + + # Verify function middleware was executed + assert execution_order == ["function_middleware_before", "function_middleware_after"] + + # Verify function call and result are in the response + all_contents = [content for message in response.messages for content in message.contents] + function_calls = [c for c in all_contents if isinstance(c, FunctionCallContent)] + function_results = [c for c in all_contents if isinstance(c, FunctionResultContent)] + + assert len(function_calls) == 1 + assert len(function_results) == 1 + assert function_calls[0].name == "sample_tool_function" + assert function_results[0].call_id == function_calls[0].call_id + + async def test_mixed_agent_and_function_middleware_with_tool_calls(self, chat_client: "MockChatClient") -> None: + """Test both agent and function middleware with ChatAgent when function calls are made.""" + execution_order: list[str] = [] + + class TrackingAgentMiddleware(AgentMiddleware): + async def process( + self, + context: AgentRunContext, + next: Callable[[AgentRunContext], Awaitable[None]], + ) -> None: + execution_order.append("agent_middleware_before") + await next(context) + execution_order.append("agent_middleware_after") + + class TrackingFunctionMiddleware(FunctionMiddleware): + async def process( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + execution_order.append("function_middleware_before") + await next(context) + execution_order.append("function_middleware_after") + + # Set up mock to return a function call first, then a regular response + function_call_response = ChatResponse( + messages=[ + ChatMessage( + role=Role.ASSISTANT, + contents=[ + FunctionCallContent( + call_id="call_789", + name="sample_tool_function", + arguments='{"location": "New York"}', + ) + ], + ) + ] + ) + final_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Final response")]) + + chat_client.responses = [function_call_response, final_response] + + # Create ChatAgent with both agent and function middleware and tools + agent = ChatAgent( + chat_client=chat_client, + middleware=[TrackingAgentMiddleware(), TrackingFunctionMiddleware()], + tools=[sample_tool_function], + ) + + # Execute the agent + messages = [ChatMessage(role=Role.USER, text="Get weather for New York")] + response = await agent.run(messages) + + # Verify response + assert response is not None + assert len(response.messages) > 0 + assert chat_client.call_count == 2 # Two calls: one for function call, one for final response + + # Verify middleware execution order: agent middleware wraps everything, + # function middleware only for function calls + expected_order = [ + "agent_middleware_before", + "function_middleware_before", + "function_middleware_after", + "agent_middleware_after", + ] + assert execution_order == expected_order + + # Verify function call and result are in the response + all_contents = [content for message in response.messages for content in message.contents] + function_calls = [c for c in all_contents if isinstance(c, FunctionCallContent)] + function_results = [c for c in all_contents if isinstance(c, FunctionResultContent)] + + assert len(function_calls) == 1 + assert len(function_results) == 1 + assert function_calls[0].name == "sample_tool_function" + assert function_results[0].call_id == function_calls[0].call_id diff --git a/python/samples/getting_started/middleware/class_based_middleware.py b/python/samples/getting_started/middleware/class_based_middleware.py new file mode 100644 index 0000000000..4958e3d094 --- /dev/null +++ b/python/samples/getting_started/middleware/class_based_middleware.py @@ -0,0 +1,125 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import time +from collections.abc import Awaitable, Callable +from random import randint +from typing import Annotated + +from agent_framework import ( + AgentMiddleware, + AgentRunContext, + AgentRunResponse, + ChatMessage, + FunctionInvocationContext, + FunctionMiddleware, + Role, +) +from agent_framework.foundry import FoundryChatClient +from azure.identity.aio import AzureCliCredential +from pydantic import Field + +""" +Class-based Middleware Example + +This sample demonstrates how to implement middleware using class-based approach by inheriting +from AgentMiddleware and FunctionMiddleware base classes. The example includes: + +- SecurityAgentMiddleware: Checks for security violations in user queries and blocks requests + containing sensitive information like passwords or secrets +- LoggingFunctionMiddleware: Logs function execution details including timing and parameters + +This approach is useful when you need stateful middleware or complex logic that benefits +from object-oriented design patterns. +""" + + +def get_weather( + location: Annotated[str, Field(description="The location to get the weather for.")], +) -> str: + """Get the weather for a given location.""" + conditions = ["sunny", "cloudy", "rainy", "stormy"] + return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." + + +class SecurityAgentMiddleware(AgentMiddleware): + """Agent middleware that checks for security violations.""" + + async def process( + self, + context: AgentRunContext, + next: Callable[[AgentRunContext], Awaitable[None]], + ) -> None: + # Check for potential security violations in the query + # Look at the last user message + last_message = context.messages[-1] if context.messages else None + if last_message and last_message.text: + query = last_message.text + if "password" in query.lower() or "secret" in query.lower(): + print("[SecurityAgentMiddleware] Security Warning: Detected sensitive information, blocking request.") + # Override the result with warning message + context.result = AgentRunResponse( + messages=[ + ChatMessage(role=Role.ASSISTANT, text="Detected sensitive information, the request is blocked.") + ] + ) + # Simply don't call next() to prevent execution + return + + print("[SecurityAgentMiddleware] Security check passed.") + await next(context) + + +class LoggingFunctionMiddleware(FunctionMiddleware): + """Function middleware that logs function calls.""" + + async def process( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + function_name = context.function.name + print(f"[LoggingFunctionMiddleware] About to call function: {function_name}.") + + start_time = time.time() + + await next(context) + + end_time = time.time() + duration = end_time - start_time + + print(f"[LoggingFunctionMiddleware] Function {function_name} completed in {duration:.5f}s.") + + +async def main() -> None: + """Example demonstrating class-based middleware.""" + print("=== Class-based Middleware Example ===") + + # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred + # authentication option. + async with ( + AzureCliCredential() as credential, + FoundryChatClient(async_credential=credential).create_agent( + name="WeatherAgent", + instructions="You are a helpful weather assistant.", + tools=get_weather, + middleware=[SecurityAgentMiddleware(), LoggingFunctionMiddleware()], + ) as agent, + ): + # Test with normal query + print("\n--- Normal Query ---") + query = "What's the weather like in Seattle?" + print(f"User: {query}") + result = await agent.run(query) + print(f"Agent: {result.text}\n") + + # Test with security-related query + print("--- Security Test ---") + query = "What's the password for the weather service?" + print(f"User: {query}") + result = await agent.run(query) + print(f"Agent: {result.text}\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/middleware/exception_handling_with_middleware.py b/python/samples/getting_started/middleware/exception_handling_with_middleware.py new file mode 100644 index 0000000000..ca37a75bd1 --- /dev/null +++ b/python/samples/getting_started/middleware/exception_handling_with_middleware.py @@ -0,0 +1,75 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +from collections.abc import Awaitable, Callable +from typing import Annotated + +from agent_framework import FunctionInvocationContext +from agent_framework.foundry import FoundryChatClient +from azure.identity.aio import AzureCliCredential +from pydantic import Field + +""" +Exception Handling with Middleware + +This sample demonstrates how to use middleware for centralized exception handling in function calls. +The example shows: + +- How to catch exceptions thrown by functions and provide graceful error responses +- Overriding function results when errors occur to provide user-friendly messages +- Using middleware to implement retry logic, fallback mechanisms, or error reporting + +The middleware catches TimeoutError from an unstable data service and replaces it with +a helpful message for the user, preventing raw exceptions from reaching the end user. +""" + + +def unstable_data_service( + query: Annotated[str, Field(description="The data query to execute.")], +) -> str: + """A simulated data service that sometimes throws exceptions.""" + # Simulate failure + raise TimeoutError("Data service request timed out") + + +async def exception_handling_middleware( + context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]] +) -> None: + function_name = context.function.name + + try: + print(f"[ExceptionHandlingMiddleware] Executing function: {function_name}") + await next(context) + print(f"[ExceptionHandlingMiddleware] Function {function_name} completed successfully.") + except TimeoutError as e: + print(f"[ExceptionHandlingMiddleware] Caught TimeoutError: {e}") + # Override function result to provide custom message in response. + context.result = ( + "Request Timeout: The data service is taking longer than expected to respond.", + "Respond with message - 'Sorry for the inconvenience, please try again later.'", + ) + + +async def main() -> None: + """Example demonstrating exception handling with middleware.""" + print("=== Exception Handling Middleware Example ===") + + # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred + # authentication option. + async with ( + AzureCliCredential() as credential, + FoundryChatClient(async_credential=credential).create_agent( + name="DataAgent", + instructions="You are a helpful data assistant. Use the data service tool to fetch information for users.", + tools=unstable_data_service, + middleware=exception_handling_middleware, + ) as agent, + ): + query = "Get user statistics" + print(f"User: {query}") + result = await agent.run(query) + print(f"Agent: {result}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/middleware/function_based_middleware.py b/python/samples/getting_started/middleware/function_based_middleware.py new file mode 100644 index 0000000000..e6b92494b3 --- /dev/null +++ b/python/samples/getting_started/middleware/function_based_middleware.py @@ -0,0 +1,109 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import time +from collections.abc import Awaitable, Callable +from random import randint +from typing import Annotated + +from agent_framework import ( + AgentRunContext, + FunctionInvocationContext, +) +from agent_framework.foundry import FoundryChatClient +from azure.identity.aio import AzureCliCredential +from pydantic import Field + +""" +Function-based Middleware Example + +This sample demonstrates how to implement middleware using simple async functions instead of classes. +The example includes: + +- Security middleware that validates agent requests for sensitive information +- Logging middleware that tracks function execution timing and parameters +- Performance monitoring to measure execution duration + +Function-based middleware is ideal for simple, stateless operations and provides a more +lightweight approach compared to class-based middleware. Both agent and function middleware +can be implemented as async functions that accept context and next parameters. +""" + + +def get_weather( + location: Annotated[str, Field(description="The location to get the weather for.")], +) -> str: + """Get the weather for a given location.""" + conditions = ["sunny", "cloudy", "rainy", "stormy"] + return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." + + +async def security_agent_middleware( + context: AgentRunContext, + next: Callable[[AgentRunContext], Awaitable[None]], +) -> None: + """Agent middleware that checks for security violations.""" + # Check for potential security violations in the query + # For this example, we'll check the last user message + last_message = context.messages[-1] if context.messages else None + if last_message and last_message.text: + query = last_message.text + if "password" in query.lower() or "secret" in query.lower(): + print("[SecurityAgentMiddleware] Security Warning: Detected sensitive information, blocking request.") + # Simply don't call next() to prevent execution + return + + print("[SecurityAgentMiddleware] Security check passed.") + await next(context) + + +async def logging_function_middleware( + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], +) -> None: + """Function middleware that logs function calls.""" + function_name = context.function.name + print(f"[LoggingFunctionMiddleware] About to call function: {function_name}.") + + start_time = time.time() + + await next(context) + + end_time = time.time() + duration = end_time - start_time + + print(f"[LoggingFunctionMiddleware] Function {function_name} completed in {duration:.5f}s.") + + +async def main() -> None: + """Example demonstrating function-based middleware.""" + print("=== Function-based Middleware Example ===") + + # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred + # authentication option. + async with ( + AzureCliCredential() as credential, + FoundryChatClient(async_credential=credential).create_agent( + name="WeatherAgent", + instructions="You are a helpful weather assistant.", + tools=get_weather, + middleware=[security_agent_middleware, logging_function_middleware], + ) as agent, + ): + # Test with normal query + print("\n--- Normal Query ---") + query = "What's the weather like in Tokyo?" + print(f"User: {query}") + result = await agent.run(query) + print(f"Agent: {result.text if result.text else 'No response'}\n") + + # Test with security violation + print("--- Security Test ---") + query = "What's the secret weather password?" + print(f"User: {query}") + result = await agent.run(query) + print(f"Agent: {result.text if result.text else 'No response'}\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/middleware/override_result_with_middleware.py b/python/samples/getting_started/middleware/override_result_with_middleware.py new file mode 100644 index 0000000000..b56cb6c78b --- /dev/null +++ b/python/samples/getting_started/middleware/override_result_with_middleware.py @@ -0,0 +1,84 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +from collections.abc import Awaitable, Callable +from random import randint +from typing import Annotated + +from agent_framework import FunctionInvocationContext +from agent_framework.foundry import FoundryChatClient +from azure.identity.aio import AzureCliCredential +from pydantic import Field + +""" +Result Override with Middleware + +This sample demonstrates how to use middleware to intercept and modify function results +after execution. The example shows: + +- How to execute the original function first and then modify its result +- Replacing function outputs with custom messages or transformed data +- Using middleware for result filtering, formatting, or enhancement + +The weather override middleware lets the original weather function execute normally, +then replaces its result with a custom "perfect weather" message, demonstrating +how middleware can be used for content filtering, A/B testing, or result enhancement. +""" + + +def get_weather( + location: Annotated[str, Field(description="The location to get the weather for.")], +) -> str: + """Get the weather for a given location.""" + conditions = ["sunny", "cloudy", "rainy", "stormy"] + return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." + + +async def weather_override_middleware( + context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]] +) -> None: + function_name = context.function.name + + # Let the original function execute first + await next(context) + + # Override the result if it's a weather function + if function_name == "get_weather" and context.result is not None: + original_result = str(context.result) + print(f"[WeatherOverrideMiddleware] Original result: {original_result}") + + # Override with a custom message + # It's also possible to override the result before "next()" call if needed + custom_message = ( + "Weather Advisory - due to special atmospheric conditions, " + "all locations are experiencing perfect weather today! " + "Temperature is a comfortable 22°C with gentle breezes. " + "Perfect day for outdoor activities!" + ) + context.result = custom_message + print(f"[WeatherOverrideMiddleware] Overriding with custom message: {custom_message}") + + +async def main() -> None: + """Example demonstrating result override with middleware.""" + print("=== Result Override Middleware Example ===") + + # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred + # authentication option. + async with ( + AzureCliCredential() as credential, + FoundryChatClient(async_credential=credential).create_agent( + name="WeatherAgent", + instructions="You are a helpful weather assistant. Use the weather tool to get current conditions.", + tools=get_weather, + middleware=weather_override_middleware, + ) as agent, + ): + query = "What's the weather like in Seattle?" + print(f"User: {query}") + result = await agent.run(query) + print(f"Agent: {result}") + + +if __name__ == "__main__": + asyncio.run(main())