Python: Agent and Function middleware (#770)

* Initial middleware implementation

* Small fixes

* Small updates

* Small updates in samples

* Moved middleware functionality to decorator

* Removed obsolete file

* Renamed AgentInvocationContext to AzureRunContext

* Added unit tests

* Small settings update for test discovery in VS Code

* Added unit tests

* Reverted changes in environment settings

* Added context result override

* Renaming and updates to logic

* Added more samples

* Updated DEV_SETUP.md

* Addressed PR feedback

* Addressed PR feedback

* Removed unused parameter

* Small fix

* Small fix in telemetry logic

* Revert "Small fix in telemetry logic"

This reverts commit 6f82660d2d.

* Small fix

---------

Co-authored-by: Chris <66376200+crickman@users.noreply.github.com>
This commit is contained in:
Dmytro Struk
2025-09-18 16:30:05 -07:00
committed by GitHub
Unverified
parent 538be4c149
commit 99860a5d07
16 changed files with 3071 additions and 40 deletions
@@ -13,6 +13,7 @@ from ._clients import * # noqa: F403
from ._logging import * # noqa: F403
from ._mcp import * # noqa: F403
from ._memory import * # noqa: F403
from ._middleware import * # noqa: F403
from ._threads import * # noqa: F403
from ._tools import * # noqa: F403
from ._types import * # noqa: F403
@@ -12,6 +12,7 @@ from ._clients import BaseChatClient, ChatClientProtocol
from ._logging import get_logger
from ._mcp import MCPTool
from ._memory import AggregateContextProvider, Context, ContextProvider
from ._middleware import Middleware, use_agent_middleware
from ._pydantic import AFBaseModel
from ._threads import AgentThread, ChatMessageStore, deserialize_thread_state, thread_on_new_messages
from ._tools import FUNCTION_INVOKING_CHAT_CLIENT_MARKER, ToolProtocol
@@ -138,12 +139,14 @@ class BaseAgent(AFBaseModel):
description: The description of the agent.
display_name: The display name of the agent, which is either the name or id.
context_providers: The collection of multiple context providers to include during agent invocation.
middleware: List of middleware to intercept agent and function invocations.
"""
id: str = Field(default_factory=lambda: str(uuid4()))
name: str | None = None
description: str | None = None
context_providers: AggregateContextProvider | None = None
middleware: Middleware | list[Middleware] | None = None
async def _notify_thread_of_new_messages(
self, thread: AgentThread, new_messages: ChatMessage | Sequence[ChatMessage]
@@ -189,6 +192,7 @@ class BaseAgent(AFBaseModel):
# region ChatAgent
@use_agent_middleware
@use_agent_telemetry
class ChatAgent(BaseAgent):
"""A Chat Client Agent."""
@@ -231,6 +235,7 @@ class ChatAgent(BaseAgent):
additional_properties: dict[str, Any] | None = None,
chat_message_store_factory: Callable[[], ChatMessageStore] | None = None,
context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None,
middleware: Middleware | list[Middleware] | None = None,
**kwargs: Any,
) -> None:
"""Create a ChatAgent.
@@ -266,6 +271,7 @@ class ChatAgent(BaseAgent):
chat_message_store_factory: factory function to create an instance of ChatMessageStore. If not provided,
the default in-memory store will be used.
context_providers: The collection of multiple context providers to include during agent invocation.
middleware: List of middleware to intercept agent and function invocations.
kwargs: any additional keyword arguments.
Unused, can be used by subclasses of this Agent.
"""
@@ -287,6 +293,7 @@ class ChatAgent(BaseAgent):
"chat_client": chat_client,
"chat_message_store_factory": chat_message_store_factory,
"context_providers": aggregate_context_providers,
"middleware": middleware,
"chat_options": ChatOptions(
ai_model_id=model,
frequency_penalty=frequency_penalty,
@@ -10,6 +10,7 @@ from pydantic import BaseModel, Field
from ._logging import get_logger
from ._mcp import MCPTool
from ._memory import AggregateContextProvider, ContextProvider
from ._middleware import Middleware
from ._pydantic import AFBaseModel
from ._threads import ChatMessageStore
from ._tools import ToolProtocol
@@ -344,7 +345,14 @@ class BaseChatClient(AFBaseModel, ABC):
)
prepped_messages = self.prepare_messages(messages)
self._prepare_tool_choice(chat_options=chat_options)
return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **kwargs)
# Remove middleware pipeline from kwargs as it's only used by function invocation wrappers
if "_function_middleware_pipeline" in kwargs:
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "_function_middleware_pipeline"}
else:
filtered_kwargs = kwargs
return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **filtered_kwargs)
async def get_streaming_response(
self,
@@ -424,8 +432,15 @@ class BaseChatClient(AFBaseModel, ABC):
)
prepped_messages = self.prepare_messages(messages)
self._prepare_tool_choice(chat_options=chat_options)
# Remove middleware pipeline from kwargs as it's only used by function invocation wrappers
if "_function_middleware_pipeline" in kwargs:
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "_function_middleware_pipeline"}
else:
filtered_kwargs = kwargs
async for update in self._inner_get_streaming_response(
messages=prepped_messages, chat_options=chat_options, **kwargs
messages=prepped_messages, chat_options=chat_options, **filtered_kwargs
):
yield update
@@ -465,6 +480,7 @@ class BaseChatClient(AFBaseModel, ABC):
| None = None,
chat_message_store_factory: Callable[[], ChatMessageStore] | None = None,
context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None,
middleware: Middleware | list[Middleware] | None = None,
**kwargs: Any,
) -> "ChatAgent":
"""Create an agent with the given name and instructions.
@@ -476,6 +492,7 @@ class BaseChatClient(AFBaseModel, ABC):
chat_message_store_factory: Factory function to create an instance of ChatMessageStore. If not provided,
the default in-memory store will be used.
context_providers: Context providers to include during agent invocation.
middleware: List of middleware to intercept agent and function invocations.
**kwargs: Additional keyword arguments to pass to the agent.
See ChatAgent for all the available options.
@@ -491,6 +508,7 @@ class BaseChatClient(AFBaseModel, ABC):
tools=tools,
chat_message_store_factory=chat_message_store_factory,
context_providers=context_providers,
middleware=middleware,
**kwargs,
)
@@ -0,0 +1,632 @@
# Copyright (c) Microsoft. All rights reserved.
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable, Awaitable, Callable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar
from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage
if TYPE_CHECKING:
from pydantic import BaseModel
from ._agents import AgentProtocol
from ._tools import AIFunction
TAgent = TypeVar("TAgent", bound="AgentProtocol")
__all__ = [
"AgentMiddleware",
"AgentRunContext",
"FunctionInvocationContext",
"FunctionMiddleware",
"Middleware",
"use_agent_middleware",
]
@dataclass
class AgentRunContext:
"""Context object for agent middleware invocations.
Attributes:
agent: The agent being invoked.
messages: The messages being sent to the agent.
is_streaming: Whether this is a streaming invocation.
metadata: Metadata dictionary for sharing data between agent middleware.
result: Agent execution result. Can be observed after calling next()
to see the actual execution result or can be set to override the execution result.
For non-streaming: should be AgentRunResponse
For streaming: should be AsyncIterable[AgentRunResponseUpdate]
"""
agent: "AgentProtocol"
messages: list[ChatMessage]
is_streaming: bool = False
metadata: dict[str, Any] = field(default_factory=lambda: {})
result: AgentRunResponse | AsyncIterable[AgentRunResponseUpdate] | None = None
@dataclass
class FunctionInvocationContext:
"""Context object for function middleware invocations.
Attributes:
function: The function being invoked.
arguments: The validated arguments for the function.
metadata: Metadata dictionary for sharing data between function middleware.
result: Function execution result. Can be observed after calling next()
to see the actual execution result or can be set to override the execution result.
"""
function: "AIFunction[Any, Any]"
arguments: "BaseModel"
metadata: dict[str, Any] = field(default_factory=lambda: {})
result: Any = None
class AgentMiddleware(ABC):
"""Abstract base class for agent middleware that can intercept agent invocations."""
@abstractmethod
async def process(
self,
context: AgentRunContext,
next: Callable[[AgentRunContext], Awaitable[None]],
) -> None:
"""Process an agent invocation.
Args:
context: Agent invocation context containing agent, messages, and metadata.
Use context.is_streaming to determine if this is a streaming call.
Middleware can set context.result to override execution, or observe
the actual execution result after calling next().
For non-streaming: AgentRunResponse
For streaming: AsyncIterable[AgentRunResponseUpdate]
next: Function to call the next middleware or final agent execution.
Does not return anything - all data flows through the context.
Note:
Middleware should not return anything. All data manipulation should happen
within the context object. Set context.result to override execution,
or observe context.result after calling next() for actual results.
"""
...
class FunctionMiddleware(ABC):
"""Abstract base class for function middleware that can intercept function invocations."""
@abstractmethod
async def process(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
"""Process a function invocation.
Args:
context: Function invocation context containing function, arguments, and metadata.
Middleware can set context.result to override execution, or observe
the actual execution result after calling next().
next: Function to call the next middleware or final function execution.
Does not return anything - all data flows through the context.
Note:
Middleware should not return anything. All data manipulation should happen
within the context object. Set context.result to override execution,
or observe context.result after calling next() for actual results.
"""
...
# Pure function type definitions for convenience
AgentMiddlewareCallable = Callable[[AgentRunContext, Callable[[AgentRunContext], Awaitable[None]]], Awaitable[None]]
FunctionMiddlewareCallable = Callable[
[FunctionInvocationContext, Callable[[FunctionInvocationContext], Awaitable[None]]], Awaitable[None]
]
# Type alias for all middleware types
Middleware: TypeAlias = AgentMiddleware | AgentMiddlewareCallable | FunctionMiddleware | FunctionMiddlewareCallable
class AgentMiddlewareWrapper(AgentMiddleware):
"""Wrapper to convert pure functions into AgentMiddleware protocol objects."""
def __init__(self, func: AgentMiddlewareCallable):
self.func = func
async def process(
self,
context: AgentRunContext,
next: Callable[[AgentRunContext], Awaitable[None]],
) -> None:
await self.func(context, next)
class FunctionMiddlewareWrapper(FunctionMiddleware):
"""Wrapper to convert pure functions into FunctionMiddleware protocol objects."""
def __init__(self, func: FunctionMiddlewareCallable):
self.func = func
async def process(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
await self.func(context, next)
class BaseMiddlewarePipeline(ABC):
"""Base class for middleware pipeline execution."""
def __init__(self) -> None:
"""Initialize the base middleware pipeline."""
self._middlewares: list[Any] = []
@abstractmethod
def _register_middleware(self, middleware: Any) -> None:
"""Register a middleware item. Must be implemented by subclasses."""
...
@property
def has_middlewares(self) -> bool:
"""Check if there are any middlewares registered."""
return bool(self._middlewares)
def _create_handler_chain(
self,
final_handler: Callable[[Any], Awaitable[Any]],
result_container: dict[str, Any],
result_key: str = "result",
) -> Callable[[Any], Awaitable[None]]:
"""Create a chain of middleware handlers.
Args:
final_handler: The final handler to execute
result_container: Container to store the result
result_key: Key to use in the result container
Returns:
The first handler in the chain
"""
def create_next_handler(index: int) -> Callable[[Any], Awaitable[None]]:
if index >= len(self._middlewares):
async def final_wrapper(c: Any) -> None:
# Execute actual handler and populate context for observability
result = await final_handler(c)
result_container[result_key] = result
c.result = result
return final_wrapper
middleware = self._middlewares[index]
next_handler = create_next_handler(index + 1)
async def current_handler(c: Any) -> None:
await middleware.process(c, next_handler)
return current_handler
return create_next_handler(0)
class AgentMiddlewarePipeline(BaseMiddlewarePipeline):
"""Executes agent middleware in a chain."""
def __init__(self, middlewares: list[AgentMiddleware | AgentMiddlewareCallable] | None = None):
"""Initialize the agent middleware pipeline.
Args:
middlewares: List of agent middleware to include in the pipeline.
"""
super().__init__()
self._middlewares: list[AgentMiddleware] = []
if middlewares:
for middleware in middlewares:
self._register_middleware(middleware)
def _register_middleware(self, middleware: AgentMiddleware | AgentMiddlewareCallable) -> None:
"""Register an agent middleware item."""
if isinstance(middleware, AgentMiddleware):
self._middlewares.append(middleware)
elif callable(middleware):
self._middlewares.append(AgentMiddlewareWrapper(middleware))
async def execute(
self,
agent: "AgentProtocol",
messages: list[ChatMessage],
context: AgentRunContext,
final_handler: Callable[[AgentRunContext], Awaitable[AgentRunResponse]],
) -> AgentRunResponse | None:
"""Execute the agent middleware pipeline for non-streaming.
Args:
agent: The agent being invoked.
messages: The messages to send to the agent.
context: The agent invocation context.
final_handler: The final handler that performs the actual agent execution.
Returns:
The agent response after processing through all middleware.
"""
# Update context with agent and messages
context.agent = agent
context.messages = messages
context.is_streaming = False
if not self._middlewares:
return await final_handler(context)
# Store the final result
result_container: dict[str, AgentRunResponse | None] = {"response": None}
def create_next_handler(index: int) -> Callable[[AgentRunContext], Awaitable[None]]:
if index >= len(self._middlewares):
async def final_wrapper(c: AgentRunContext) -> None:
# Execute actual handler and populate context for observability
result = await final_handler(c)
result_container["result"] = result
c.result = result
return final_wrapper
middleware = self._middlewares[index]
next_handler = create_next_handler(index + 1)
async def current_handler(c: AgentRunContext) -> None:
await middleware.process(c, next_handler)
# After middleware execution, check if response was overridden
if c.result is not None and isinstance(c.result, AgentRunResponse):
result_container["result"] = c.result
return current_handler
first_handler = create_next_handler(0)
await first_handler(context)
# Return the result from result container or overridden result
if context.result is not None and isinstance(context.result, AgentRunResponse):
return context.result
# If no result was set (next() not called), return empty AgentRunResponse
response = result_container.get("result")
if response is None:
return AgentRunResponse()
return response
async def execute_stream(
self,
agent: "AgentProtocol",
messages: list[ChatMessage],
context: AgentRunContext,
final_handler: Callable[[AgentRunContext], AsyncIterable[AgentRunResponseUpdate]],
) -> AsyncIterable[AgentRunResponseUpdate]:
"""Execute the agent middleware pipeline for streaming.
Args:
agent: The agent being invoked.
messages: The messages to send to the agent.
context: The agent invocation context.
final_handler: The final handler that performs the actual agent streaming execution.
Yields:
Agent response updates after processing through all middleware.
"""
# Update context with agent and messages
context.agent = agent
context.messages = messages
context.is_streaming = True
if not self._middlewares:
async for update in final_handler(context):
yield update
return
# Store the final result
result_container: dict[str, AsyncIterable[AgentRunResponseUpdate] | None] = {"result_stream": None}
def create_next_handler(index: int) -> Callable[[AgentRunContext], Awaitable[None]]:
if index >= len(self._middlewares):
async def final_wrapper(c: AgentRunContext) -> None: # noqa: RUF029
# Execute actual handler and populate context for observability
result = final_handler(c)
result_container["result_stream"] = result
c.result = result
return final_wrapper
middleware = self._middlewares[index]
next_handler = create_next_handler(index + 1)
async def current_handler(c: AgentRunContext) -> None:
await middleware.process(c, next_handler)
return current_handler
first_handler = create_next_handler(0)
await first_handler(context)
# Yield from the result stream in result container or overridden result
if context.result is not None and hasattr(context.result, "__aiter__"):
async for update in context.result: # type: ignore
yield update
return
result_stream = result_container["result_stream"]
if result_stream is None:
# If no result stream was set (next() not called), yield nothing
return
async for update in result_stream:
yield update
class FunctionMiddlewarePipeline(BaseMiddlewarePipeline):
"""Executes function middleware in a chain."""
def __init__(self, middlewares: list[FunctionMiddleware | FunctionMiddlewareCallable] | None = None):
"""Initialize the function middleware pipeline.
Args:
middlewares: List of function middleware to include in the pipeline.
"""
super().__init__()
self._middlewares: list[FunctionMiddleware] = []
if middlewares:
for middleware in middlewares:
self._register_middleware(middleware)
def _register_middleware(self, middleware: FunctionMiddleware | FunctionMiddlewareCallable) -> None:
"""Register a function middleware item."""
# Check if it's a class instance inheriting from FunctionMiddleware
if isinstance(middleware, FunctionMiddleware):
self._middlewares.append(middleware)
elif callable(middleware):
self._middlewares.append(FunctionMiddlewareWrapper(middleware))
async def execute(
self,
function: Any,
arguments: "BaseModel",
context: FunctionInvocationContext,
final_handler: Callable[[FunctionInvocationContext], Awaitable[Any]],
) -> Any:
"""Execute the function middleware pipeline.
Args:
function: The function being invoked.
arguments: The validated arguments for the function.
context: The function invocation context.
final_handler: The final handler that performs the actual function execution.
Returns:
The function result after processing through all middleware.
"""
# Update context with function and arguments
context.function = function
context.arguments = arguments
if not self._middlewares:
return await final_handler(context)
# Store the final result
result_container: dict[str, Any] = {"result": None}
# Custom final handler that handles pre-existing results
async def function_final_handler(c: FunctionInvocationContext) -> Any:
# If result was set before calling next(), skip execution
if c.result is not None:
return c.result
# Execute actual handler and populate context for observability
return await final_handler(c)
first_handler = self._create_handler_chain(function_final_handler, result_container, "result")
await first_handler(context)
# Return the result from result container or overridden result
if context.result is not None:
return context.result
return result_container["result"]
# Decorator for adding middleware support to agent classes
def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]:
"""Class decorator that adds middleware support to an agent class.
This decorator adds middleware functionality to any agent class.
It wraps the run() and run_stream() methods to provide middleware execution.
Args:
agent_class: The agent class to add middleware support to.
Returns:
The modified agent class with middleware support.
"""
import inspect
# Store original methods
original_run = agent_class.run # type: ignore[attr-defined]
original_run_stream = agent_class.run_stream # type: ignore[attr-defined]
def _initialize_middleware_pipelines(self: Any, middlewares: Middleware | list[Middleware] | None) -> None:
"""Initialize agent and function middleware pipelines from the provided middleware list."""
if not middlewares:
return
middleware_list: list[Middleware] = middlewares if isinstance(middlewares, list) else [middlewares] # type: ignore
# Separate agent and function middleware using isinstance checks
agent_middlewares: list[AgentMiddleware | AgentMiddlewareCallable] = []
function_middlewares: list[FunctionMiddleware | FunctionMiddlewareCallable] = []
for middleware in middleware_list:
if isinstance(middleware, AgentMiddleware):
agent_middlewares.append(middleware)
elif isinstance(middleware, FunctionMiddleware):
function_middlewares.append(middleware)
elif callable(middleware): # type: ignore[arg-type]
# Check function signature to determine type
try:
sig = inspect.signature(middleware)
params = list(sig.parameters.values())
if len(params) >= 1:
first_param = params[0]
# Check if first parameter is AgentRunContext or FunctionInvocationContext
if (
hasattr(first_param.annotation, "__name__")
and first_param.annotation.__name__ == "AgentRunContext"
):
agent_middlewares.append(middleware) # type: ignore
elif (
hasattr(first_param.annotation, "__name__")
and first_param.annotation.__name__ == "FunctionInvocationContext"
):
function_middlewares.append(middleware) # type: ignore
else:
# Default to agent middleware if uncertain
agent_middlewares.append(middleware) # type: ignore
else:
agent_middlewares.append(middleware) # type: ignore
except Exception:
# If signature inspection fails, assume it's an agent middleware
agent_middlewares.append(middleware) # type: ignore
else:
# Fallback
agent_middlewares.append(middleware) # type: ignore
self._agent_middleware_pipeline = AgentMiddlewarePipeline(agent_middlewares)
self._function_middleware_pipeline = FunctionMiddlewarePipeline(function_middlewares)
async def middleware_enabled_run(
self: Any,
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
*,
thread: Any = None,
**kwargs: Any,
) -> AgentRunResponse:
"""Middleware-enabled run method."""
# Initialize middleware pipelines if not already done
if (
hasattr(self, "middleware")
and self.middleware
and not (
hasattr(self, "_agent_middleware_pipeline")
and hasattr(self, "_function_middleware_pipeline")
and (
self._agent_middleware_pipeline.has_middlewares
or self._function_middleware_pipeline.has_middlewares
)
)
):
_initialize_middleware_pipelines(self, self.middleware)
# Ensure pipelines exist even if empty
if not hasattr(self, "_agent_middleware_pipeline"):
self._agent_middleware_pipeline = AgentMiddlewarePipeline()
if not hasattr(self, "_function_middleware_pipeline"):
self._function_middleware_pipeline = FunctionMiddlewarePipeline()
# Add function middleware pipeline to kwargs if available
if self._function_middleware_pipeline.has_middlewares:
kwargs["_function_middleware_pipeline"] = self._function_middleware_pipeline
normalized_messages = self._normalize_messages(messages)
# Execute with middleware if available
if self._agent_middleware_pipeline.has_middlewares:
context = AgentRunContext(
agent=self, # type: ignore[arg-type]
messages=normalized_messages,
is_streaming=False,
)
async def _execute_handler(ctx: AgentRunContext) -> AgentRunResponse:
return await original_run(self, ctx.messages, thread=thread, **kwargs) # type: ignore
result = await self._agent_middleware_pipeline.execute(
self, # type: ignore[arg-type]
normalized_messages,
context,
_execute_handler,
)
return result if result else AgentRunResponse()
# No middleware, execute directly
return await original_run(self, normalized_messages, thread=thread, **kwargs) # type: ignore[return-value]
def middleware_enabled_run_stream(
self: Any,
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
*,
thread: Any = None,
**kwargs: Any,
) -> AsyncIterable[AgentRunResponseUpdate]:
"""Middleware-enabled run_stream method."""
# Initialize middleware pipelines if not already done
if (
hasattr(self, "middleware")
and self.middleware
and not (
hasattr(self, "_agent_middleware_pipeline")
and hasattr(self, "_function_middleware_pipeline")
and (
self._agent_middleware_pipeline.has_middlewares
or self._function_middleware_pipeline.has_middlewares
)
)
):
_initialize_middleware_pipelines(self, self.middleware)
# Ensure pipelines exist even if empty
if not hasattr(self, "_agent_middleware_pipeline"):
self._agent_middleware_pipeline = AgentMiddlewarePipeline()
if not hasattr(self, "_function_middleware_pipeline"):
self._function_middleware_pipeline = FunctionMiddlewarePipeline()
# Add function middleware pipeline to kwargs if available
if self._function_middleware_pipeline.has_middlewares:
kwargs["_function_middleware_pipeline"] = self._function_middleware_pipeline
normalized_messages = self._normalize_messages(messages)
# Execute with middleware if available
if self._agent_middleware_pipeline.has_middlewares:
context = AgentRunContext(
agent=self, # type: ignore[arg-type]
messages=normalized_messages,
is_streaming=True,
)
async def _execute_stream_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]:
async for update in original_run_stream(self, ctx.messages, thread=thread, **kwargs): # type: ignore[misc]
yield update
async def _stream_generator() -> AsyncIterable[AgentRunResponseUpdate]:
async for update in self._agent_middleware_pipeline.execute_stream(
self, # type: ignore[arg-type]
normalized_messages,
context,
_execute_stream_handler,
):
yield update
return _stream_generator()
# No middleware, execute directly
return original_run_stream(self, normalized_messages, thread=thread, **kwargs) # type: ignore
agent_class.run = middleware_enabled_run # type: ignore
agent_class.run_stream = middleware_enabled_run_stream # type: ignore
return agent_class
+45 -7
View File
@@ -576,6 +576,7 @@ async def _auto_invoke_function(
tool_map: dict[str, AIFunction[BaseModel, Any]],
sequence_index: int | None = None,
request_index: int | None = None,
middleware_pipeline: Any = None, # Optional MiddlewarePipeline
) -> "Contents":
"""Invoke a function call requested by the agent, applying filters that are defined in the agent."""
from ._types import FunctionResultContent
@@ -590,14 +591,43 @@ async def _auto_invoke_function(
merged_args: dict[str, Any] = (custom_args or {}) | parsed_args
args = tool.input_model.model_validate(merged_args)
exception = None
try:
function_result = await tool.invoke(
# Execute through middleware pipeline if available
if middleware_pipeline and hasattr(middleware_pipeline, "has_middlewares") and middleware_pipeline.has_middlewares:
from ._middleware import FunctionInvocationContext
middleware_context = FunctionInvocationContext(
function=tool,
arguments=args,
tool_call_id=function_call_content.call_id,
) # type: ignore[arg-type]
except Exception as ex:
exception = ex
function_result = None
)
async def final_function_handler(context_obj: Any) -> Any:
return await tool.invoke(
arguments=context_obj.arguments,
tool_call_id=function_call_content.call_id,
)
try:
function_result = await middleware_pipeline.execute(
function=tool,
arguments=args,
context=middleware_context,
final_handler=final_function_handler,
)
except Exception as ex:
exception = ex
function_result = None
else:
# No middleware - execute directly
try:
function_result = await tool.invoke(
arguments=args,
tool_call_id=function_call_content.call_id,
) # type: ignore[arg-type]
except Exception as ex:
exception = ex
function_result = None
return FunctionResultContent(
call_id=function_call_content.call_id,
exception=exception,
@@ -631,6 +661,7 @@ async def execute_function_calls(
| Callable[..., Any] \
| MutableMapping[str, Any] \
| list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]",
middleware_pipeline: Any = None, # Optional MiddlewarePipeline to avoid circular imports
) -> list["Contents"]:
tool_map = _get_tool_map(tools)
# Run all function calls concurrently
@@ -641,6 +672,7 @@ async def execute_function_calls(
tool_map=tool_map,
sequence_index=seq_idx,
request_index=attempt_idx,
middleware_pipeline=middleware_pipeline,
)
for seq_idx, function_call in enumerate(function_calls)
])
@@ -706,11 +738,14 @@ def _handle_function_calls_response(
if not tools and (chat_options := kwargs.get("chat_options")) and isinstance(chat_options, ChatOptions):
tools = chat_options.tools
if function_calls and tools:
# Extract function middleware pipeline from kwargs if available
middleware_pipeline = kwargs.get("_function_middleware_pipeline")
function_results = await execute_function_calls(
custom_args=kwargs,
attempt_idx=attempt_idx,
function_calls=function_calls,
tools=tools, # type: ignore
middleware_pipeline=middleware_pipeline,
)
# add a single ChatMessage to the response with the results
result_message = ChatMessage(role="tool", contents=function_results) # type: ignore[call-overload]
@@ -815,11 +850,14 @@ def _handle_function_calls_streaming_response(
tools = chat_options.tools
if function_calls and tools:
# Extract function middleware pipeline from kwargs if available
middleware_pipeline = kwargs.get("_function_middleware_pipeline")
function_results = await execute_function_calls(
custom_args=kwargs,
attempt_idx=attempt_idx,
function_calls=function_calls,
tools=tools, # type: ignore[reportArgumentType]
middleware_pipeline=middleware_pipeline,
)
function_result_msg = ChatMessage(role="tool", contents=function_results)
yield ChatResponseUpdate(contents=function_results, role="tool")
@@ -1,27 +0,0 @@
# Copyright (c) Microsoft. All rights reserved.
from typing import Generic, Protocol, TypeVar, runtime_checkable
TInput = TypeVar("TInput")
TResponse = TypeVar("TResponse")
__all__ = ["InputGuardrail", "OutputGuardrail"]
@runtime_checkable
class InputGuardrail(Protocol, Generic[TInput]):
"""A protocol for input guardrails that can validate and transform input messages."""
def __call__(self, message: TInput) -> TInput:
"""Validate and possibly transform the input message."""
...
@runtime_checkable
class OutputGuardrail(Protocol, Generic[TResponse]):
"""A protocol for output guardrails that can validate and transform output messages."""
def __call__(self, message: TResponse) -> TResponse:
"""Validate and possibly transform the output message."""
...
@@ -1026,8 +1026,12 @@ def _capture_messages(
prepped = prepare_messages(messages)
for index, message in enumerate(prepped):
try:
message_data = message.model_dump(exclude_none=True)
except Exception:
message_data = {"role": message.role.value, "contents": message.contents}
logger.info(
message.model_dump_json(exclude_none=True),
message_data,
extra={
OtelAttr.EVENT_NAME: OtelAttr.CHOICE if output else ROLE_EVENT_MAP.get(message.role.value),
OtelAttr.PROVIDER_NAME: provider_name,
+9 -1
View File
@@ -72,7 +72,15 @@ environments = [
fallback-version = "0.0.0"
[tool.pytest.ini_options]
testpaths = 'tests'
testpaths = [
'tests',
'packages/main/tests',
'packages/azure/tests',
'packages/foundry/tests',
'packages/copilotstudio/tests',
'packages/mem0/tests',
'packages/runtime/tests'
]
addopts = "-ra -q -r fEX"
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
+13 -2
View File
@@ -105,6 +105,9 @@ class MockChatClient:
def __init__(self) -> None:
self.additional_properties: dict[str, Any] = {}
self.call_count: int = 0
self.responses: list[ChatResponse] = []
self.streaming_responses: list[list[ChatResponseUpdate]] = []
async def get_response(
self,
@@ -112,6 +115,9 @@ class MockChatClient:
**kwargs: Any,
) -> ChatResponse:
logger.debug(f"Running custom chat client, with: {messages=}, {kwargs=}")
self.call_count += 1
if self.responses:
return self.responses.pop(0)
return ChatResponse(messages=ChatMessage(role="assistant", text="test response"))
async def get_streaming_response(
@@ -120,8 +126,13 @@ class MockChatClient:
**kwargs: Any,
) -> AsyncIterable[ChatResponseUpdate]:
logger.debug(f"Running custom chat client stream, with: {messages=}, {kwargs=}")
yield ChatResponseUpdate(text=TextContent(text="test streaming response "), role="assistant")
yield ChatResponseUpdate(contents=[TextContent(text="another update")], role="assistant")
self.call_count += 1
if self.streaming_responses:
for update in self.streaming_responses.pop(0):
yield update
else:
yield ChatResponseUpdate(text=TextContent(text="test streaming response "), role="assistant")
yield ChatResponseUpdate(contents=[TextContent(text="another update")], role="assistant")
class MockBaseChatClient(BaseChatClient):
@@ -0,0 +1,939 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import AsyncIterable, Awaitable, Callable
from typing import Any
from unittest.mock import MagicMock
import pytest
from pydantic import BaseModel, Field
from agent_framework import (
AgentProtocol,
AgentRunResponse,
AgentRunResponseUpdate,
ChatMessage,
Role,
TextContent,
)
from agent_framework._middleware import (
AgentMiddleware,
AgentMiddlewarePipeline,
AgentRunContext,
FunctionInvocationContext,
FunctionMiddleware,
FunctionMiddlewarePipeline,
)
from agent_framework._tools import AIFunction
class TestAgentRunContext:
"""Test cases for AgentRunContext."""
def test_init_with_defaults(self, mock_agent: AgentProtocol) -> None:
"""Test AgentRunContext initialization with default values."""
messages = [ChatMessage(role=Role.USER, text="test")]
context = AgentRunContext(agent=mock_agent, messages=messages)
assert context.agent is mock_agent
assert context.messages == messages
assert context.is_streaming is False
assert context.metadata == {}
def test_init_with_custom_values(self, mock_agent: AgentProtocol) -> None:
"""Test AgentRunContext initialization with custom values."""
messages = [ChatMessage(role=Role.USER, text="test")]
metadata = {"key": "value"}
context = AgentRunContext(agent=mock_agent, messages=messages, is_streaming=True, metadata=metadata)
assert context.agent is mock_agent
assert context.messages == messages
assert context.is_streaming is True
assert context.metadata == metadata
class TestFunctionInvocationContext:
"""Test cases for FunctionInvocationContext."""
def test_init_with_defaults(self, mock_function: AIFunction[Any, Any]) -> None:
"""Test FunctionInvocationContext initialization with default values."""
arguments = FunctionTestArgs(name="test")
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
assert context.function is mock_function
assert context.arguments == arguments
assert context.metadata == {}
def test_init_with_custom_metadata(self, mock_function: AIFunction[Any, Any]) -> None:
"""Test FunctionInvocationContext initialization with custom metadata."""
arguments = FunctionTestArgs(name="test")
metadata = {"key": "value"}
context = FunctionInvocationContext(function=mock_function, arguments=arguments, metadata=metadata)
assert context.function is mock_function
assert context.arguments == arguments
assert context.metadata == metadata
class TestAgentMiddlewarePipeline:
"""Test cases for AgentMiddlewarePipeline."""
def test_init_empty(self) -> None:
"""Test AgentMiddlewarePipeline initialization with no middlewares."""
pipeline = AgentMiddlewarePipeline()
assert not pipeline.has_middlewares
def test_init_with_class_middleware(self) -> None:
"""Test AgentMiddlewarePipeline initialization with class-based middleware."""
middleware = TestAgentMiddleware()
pipeline = AgentMiddlewarePipeline([middleware])
assert pipeline.has_middlewares
def test_init_with_function_middleware(self) -> None:
"""Test AgentMiddlewarePipeline initialization with function-based middleware."""
async def test_middleware(context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None:
await next(context)
pipeline = AgentMiddlewarePipeline([test_middleware])
assert pipeline.has_middlewares
async def test_execute_no_middleware(self, mock_agent: AgentProtocol) -> None:
"""Test pipeline execution with no middleware."""
pipeline = AgentMiddlewarePipeline()
messages = [ChatMessage(role=Role.USER, text="test")]
context = AgentRunContext(agent=mock_agent, messages=messages)
expected_response = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
return expected_response
result = await pipeline.execute(mock_agent, messages, context, final_handler)
assert result == expected_response
async def test_execute_with_middleware(self, mock_agent: AgentProtocol) -> None:
"""Test pipeline execution with middleware."""
execution_order: list[str] = []
class OrderTrackingMiddleware(AgentMiddleware):
def __init__(self, name: str):
self.name = name
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
execution_order.append(f"{self.name}_before")
await next(context)
execution_order.append(f"{self.name}_after")
middleware = OrderTrackingMiddleware("test")
pipeline = AgentMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
context = AgentRunContext(agent=mock_agent, messages=messages)
expected_response = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
execution_order.append("handler")
return expected_response
result = await pipeline.execute(mock_agent, messages, context, final_handler)
assert result == expected_response
assert execution_order == ["test_before", "handler", "test_after"]
async def test_execute_stream_no_middleware(self, mock_agent: AgentProtocol) -> None:
"""Test pipeline streaming execution with no middleware."""
pipeline = AgentMiddlewarePipeline()
messages = [ChatMessage(role=Role.USER, text="test")]
context = AgentRunContext(agent=mock_agent, messages=messages)
async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]:
yield AgentRunResponseUpdate(contents=[TextContent(text="chunk1")])
yield AgentRunResponseUpdate(contents=[TextContent(text="chunk2")])
updates: list[AgentRunResponseUpdate] = []
async for update in pipeline.execute_stream(mock_agent, messages, context, final_handler):
updates.append(update)
assert len(updates) == 2
assert updates[0].text == "chunk1"
assert updates[1].text == "chunk2"
async def test_execute_stream_with_middleware(self, mock_agent: AgentProtocol) -> None:
"""Test pipeline streaming execution with middleware."""
execution_order: list[str] = []
class StreamOrderTrackingMiddleware(AgentMiddleware):
def __init__(self, name: str):
self.name = name
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
execution_order.append(f"{self.name}_before")
await next(context)
execution_order.append(f"{self.name}_after")
middleware = StreamOrderTrackingMiddleware("test")
pipeline = AgentMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
context = AgentRunContext(agent=mock_agent, messages=messages)
async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]:
execution_order.append("handler_start")
yield AgentRunResponseUpdate(contents=[TextContent(text="chunk1")])
yield AgentRunResponseUpdate(contents=[TextContent(text="chunk2")])
execution_order.append("handler_end")
updates: list[AgentRunResponseUpdate] = []
async for update in pipeline.execute_stream(mock_agent, messages, context, final_handler):
updates.append(update)
assert len(updates) == 2
assert updates[0].text == "chunk1"
assert updates[1].text == "chunk2"
assert execution_order == ["test_before", "test_after", "handler_start", "handler_end"]
class TestFunctionMiddlewarePipeline:
"""Test cases for FunctionMiddlewarePipeline."""
def test_init_empty(self) -> None:
"""Test FunctionMiddlewarePipeline initialization with no middlewares."""
pipeline = FunctionMiddlewarePipeline()
assert not pipeline.has_middlewares
def test_init_with_class_middleware(self) -> None:
"""Test FunctionMiddlewarePipeline initialization with class-based middleware."""
middleware = TestFunctionMiddleware()
pipeline = FunctionMiddlewarePipeline([middleware])
assert pipeline.has_middlewares
def test_init_with_function_middleware(self) -> None:
"""Test FunctionMiddlewarePipeline initialization with function-based middleware."""
async def test_middleware(
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
await next(context)
pipeline = FunctionMiddlewarePipeline([test_middleware])
assert pipeline.has_middlewares
async def test_execute_no_middleware(self, mock_function: AIFunction[Any, Any]) -> None:
"""Test pipeline execution with no middleware."""
pipeline = FunctionMiddlewarePipeline()
arguments = FunctionTestArgs(name="test")
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
expected_result = "function_result"
async def final_handler(ctx: FunctionInvocationContext) -> str:
return expected_result
result = await pipeline.execute(mock_function, arguments, context, final_handler)
assert result == expected_result
async def test_execute_with_middleware(self, mock_function: AIFunction[Any, Any]) -> None:
"""Test pipeline execution with middleware."""
execution_order: list[str] = []
class OrderTrackingFunctionMiddleware(FunctionMiddleware):
def __init__(self, name: str):
self.name = name
async def process(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
execution_order.append(f"{self.name}_before")
await next(context)
execution_order.append(f"{self.name}_after")
middleware = OrderTrackingFunctionMiddleware("test")
pipeline = FunctionMiddlewarePipeline([middleware])
arguments = FunctionTestArgs(name="test")
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
expected_result = "function_result"
async def final_handler(ctx: FunctionInvocationContext) -> str:
execution_order.append("handler")
return expected_result
result = await pipeline.execute(mock_function, arguments, context, final_handler)
assert result == expected_result
assert execution_order == ["test_before", "handler", "test_after"]
class TestClassBasedMiddleware:
"""Test cases for class-based middleware implementations."""
async def test_agent_middleware_execution(self, mock_agent: AgentProtocol) -> None:
"""Test class-based agent middleware execution."""
metadata_updates: list[str] = []
class MetadataAgentMiddleware(AgentMiddleware):
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
context.metadata["before"] = True
metadata_updates.append("before")
await next(context)
context.metadata["after"] = True
metadata_updates.append("after")
middleware = MetadataAgentMiddleware()
pipeline = AgentMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
context = AgentRunContext(agent=mock_agent, messages=messages)
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
metadata_updates.append("handler")
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
result = await pipeline.execute(mock_agent, messages, context, final_handler)
assert result is not None
assert context.metadata["before"] is True
assert context.metadata["after"] is True
assert metadata_updates == ["before", "handler", "after"]
async def test_function_middleware_execution(self, mock_function: AIFunction[Any, Any]) -> None:
"""Test class-based function middleware execution."""
metadata_updates: list[str] = []
class MetadataFunctionMiddleware(FunctionMiddleware):
async def process(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
context.metadata["before"] = True
metadata_updates.append("before")
await next(context)
context.metadata["after"] = True
metadata_updates.append("after")
middleware = MetadataFunctionMiddleware()
pipeline = FunctionMiddlewarePipeline([middleware])
arguments = FunctionTestArgs(name="test")
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
async def final_handler(ctx: FunctionInvocationContext) -> str:
metadata_updates.append("handler")
return "result"
result = await pipeline.execute(mock_function, arguments, context, final_handler)
assert result == "result"
assert context.metadata["before"] is True
assert context.metadata["after"] is True
assert metadata_updates == ["before", "handler", "after"]
class TestFunctionBasedMiddleware:
"""Test cases for function-based middleware implementations."""
async def test_agent_function_middleware(self, mock_agent: AgentProtocol) -> None:
"""Test function-based agent middleware."""
execution_order: list[str] = []
async def test_agent_middleware(
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
execution_order.append("function_before")
context.metadata["function_middleware"] = True
await next(context)
execution_order.append("function_after")
pipeline = AgentMiddlewarePipeline([test_agent_middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
context = AgentRunContext(agent=mock_agent, messages=messages)
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
execution_order.append("handler")
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
result = await pipeline.execute(mock_agent, messages, context, final_handler)
assert result is not None
assert context.metadata["function_middleware"] is True
assert execution_order == ["function_before", "handler", "function_after"]
async def test_function_function_middleware(self, mock_function: AIFunction[Any, Any]) -> None:
"""Test function-based function middleware."""
execution_order: list[str] = []
async def test_function_middleware(
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
execution_order.append("function_before")
context.metadata["function_middleware"] = True
await next(context)
execution_order.append("function_after")
pipeline = FunctionMiddlewarePipeline([test_function_middleware])
arguments = FunctionTestArgs(name="test")
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
async def final_handler(ctx: FunctionInvocationContext) -> str:
execution_order.append("handler")
return "result"
result = await pipeline.execute(mock_function, arguments, context, final_handler)
assert result == "result"
assert context.metadata["function_middleware"] is True
assert execution_order == ["function_before", "handler", "function_after"]
class TestMixedMiddleware:
"""Test cases for mixed class and function-based middleware."""
async def test_mixed_agent_middleware(self, mock_agent: AgentProtocol) -> None:
"""Test mixed class and function-based agent middleware."""
execution_order: list[str] = []
class ClassMiddleware(AgentMiddleware):
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
execution_order.append("class_before")
await next(context)
execution_order.append("class_after")
async def function_middleware(
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
execution_order.append("function_before")
await next(context)
execution_order.append("function_after")
pipeline = AgentMiddlewarePipeline([ClassMiddleware(), function_middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
context = AgentRunContext(agent=mock_agent, messages=messages)
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
execution_order.append("handler")
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
result = await pipeline.execute(mock_agent, messages, context, final_handler)
assert result is not None
assert execution_order == ["class_before", "function_before", "handler", "function_after", "class_after"]
async def test_mixed_function_middleware(self, mock_function: AIFunction[Any, Any]) -> None:
"""Test mixed class and function-based function middleware."""
execution_order: list[str] = []
class ClassMiddleware(FunctionMiddleware):
async def process(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
execution_order.append("class_before")
await next(context)
execution_order.append("class_after")
async def function_middleware(
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
execution_order.append("function_before")
await next(context)
execution_order.append("function_after")
pipeline = FunctionMiddlewarePipeline([ClassMiddleware(), function_middleware])
arguments = FunctionTestArgs(name="test")
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
async def final_handler(ctx: FunctionInvocationContext) -> str:
execution_order.append("handler")
return "result"
result = await pipeline.execute(mock_function, arguments, context, final_handler)
assert result == "result"
assert execution_order == ["class_before", "function_before", "handler", "function_after", "class_after"]
class TestMultipleMiddlewareOrdering:
"""Test cases for multiple middleware execution order."""
async def test_agent_middleware_execution_order(self, mock_agent: AgentProtocol) -> None:
"""Test that multiple agent middlewares execute in registration order."""
execution_order: list[str] = []
class FirstMiddleware(AgentMiddleware):
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
execution_order.append("first_before")
await next(context)
execution_order.append("first_after")
class SecondMiddleware(AgentMiddleware):
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
execution_order.append("second_before")
await next(context)
execution_order.append("second_after")
class ThirdMiddleware(AgentMiddleware):
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
execution_order.append("third_before")
await next(context)
execution_order.append("third_after")
middlewares = [FirstMiddleware(), SecondMiddleware(), ThirdMiddleware()]
pipeline = AgentMiddlewarePipeline(middlewares) # type: ignore
messages = [ChatMessage(role=Role.USER, text="test")]
context = AgentRunContext(agent=mock_agent, messages=messages)
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
execution_order.append("handler")
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
result = await pipeline.execute(mock_agent, messages, context, final_handler)
assert result is not None
expected_order = [
"first_before",
"second_before",
"third_before",
"handler",
"third_after",
"second_after",
"first_after",
]
assert execution_order == expected_order
async def test_function_middleware_execution_order(self, mock_function: AIFunction[Any, Any]) -> None:
"""Test that multiple function middlewares execute in registration order."""
execution_order: list[str] = []
class FirstMiddleware(FunctionMiddleware):
async def process(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
execution_order.append("first_before")
await next(context)
execution_order.append("first_after")
class SecondMiddleware(FunctionMiddleware):
async def process(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
execution_order.append("second_before")
await next(context)
execution_order.append("second_after")
middlewares = [FirstMiddleware(), SecondMiddleware()]
pipeline = FunctionMiddlewarePipeline(middlewares) # type: ignore
arguments = FunctionTestArgs(name="test")
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
async def final_handler(ctx: FunctionInvocationContext) -> str:
execution_order.append("handler")
return "result"
result = await pipeline.execute(mock_function, arguments, context, final_handler)
assert result == "result"
expected_order = ["first_before", "second_before", "handler", "second_after", "first_after"]
assert execution_order == expected_order
class TestContextContentValidation:
"""Test cases for validating middleware context content."""
async def test_agent_context_validation(self, mock_agent: AgentProtocol) -> None:
"""Test that agent context contains expected data."""
class ContextValidationMiddleware(AgentMiddleware):
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
# Verify context has all expected attributes
assert hasattr(context, "agent")
assert hasattr(context, "messages")
assert hasattr(context, "is_streaming")
assert hasattr(context, "metadata")
# Verify context content
assert context.agent is mock_agent
assert len(context.messages) == 1
assert context.messages[0].role == Role.USER
assert context.messages[0].text == "test"
assert context.is_streaming is False
assert isinstance(context.metadata, dict)
# Add custom metadata
context.metadata["validated"] = True
await next(context)
middleware = ContextValidationMiddleware()
pipeline = AgentMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
context = AgentRunContext(agent=mock_agent, messages=messages)
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
# Verify metadata was set by middleware
assert ctx.metadata.get("validated") is True
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
result = await pipeline.execute(mock_agent, messages, context, final_handler)
assert result is not None
async def test_function_context_validation(self, mock_function: AIFunction[Any, Any]) -> None:
"""Test that function context contains expected data."""
class ContextValidationMiddleware(FunctionMiddleware):
async def process(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
# Verify context has all expected attributes
assert hasattr(context, "function")
assert hasattr(context, "arguments")
assert hasattr(context, "metadata")
# Verify context content
assert context.function is mock_function
assert isinstance(context.arguments, FunctionTestArgs)
assert context.arguments.name == "test"
assert isinstance(context.metadata, dict)
# Add custom metadata
context.metadata["validated"] = True
await next(context)
middleware = ContextValidationMiddleware()
pipeline = FunctionMiddlewarePipeline([middleware])
arguments = FunctionTestArgs(name="test")
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
async def final_handler(ctx: FunctionInvocationContext) -> str:
# Verify metadata was set by middleware
assert ctx.metadata.get("validated") is True
return "result"
result = await pipeline.execute(mock_function, arguments, context, final_handler)
assert result == "result"
class TestStreamingScenarios:
"""Test cases for streaming and non-streaming scenarios."""
async def test_streaming_flag_validation(self, mock_agent: AgentProtocol) -> None:
"""Test that is_streaming flag is correctly set for streaming calls."""
streaming_flags: list[bool] = []
class StreamingFlagMiddleware(AgentMiddleware):
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
streaming_flags.append(context.is_streaming)
await next(context)
middleware = StreamingFlagMiddleware()
pipeline = AgentMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
# Test non-streaming
context = AgentRunContext(agent=mock_agent, messages=messages)
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
streaming_flags.append(ctx.is_streaming)
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
await pipeline.execute(mock_agent, messages, context, final_handler)
# Test streaming
context_stream = AgentRunContext(agent=mock_agent, messages=messages)
async def final_stream_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]:
streaming_flags.append(ctx.is_streaming)
yield AgentRunResponseUpdate(contents=[TextContent(text="chunk")])
updates: list[AgentRunResponseUpdate] = []
async for update in pipeline.execute_stream(mock_agent, messages, context_stream, final_stream_handler):
updates.append(update)
# Verify flags: [non-streaming middleware, non-streaming handler, streaming middleware, streaming handler]
assert streaming_flags == [False, False, True, True]
async def test_streaming_middleware_behavior(self, mock_agent: AgentProtocol) -> None:
"""Test middleware behavior with streaming responses."""
chunks_processed: list[str] = []
class StreamProcessingMiddleware(AgentMiddleware):
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
chunks_processed.append("before_stream")
await next(context)
chunks_processed.append("after_stream")
middleware = StreamProcessingMiddleware()
pipeline = AgentMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
context = AgentRunContext(agent=mock_agent, messages=messages)
async def final_stream_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]:
chunks_processed.append("stream_start")
yield AgentRunResponseUpdate(contents=[TextContent(text="chunk1")])
chunks_processed.append("chunk1_yielded")
yield AgentRunResponseUpdate(contents=[TextContent(text="chunk2")])
chunks_processed.append("chunk2_yielded")
chunks_processed.append("stream_end")
updates: list[str] = []
async for update in pipeline.execute_stream(mock_agent, messages, context, final_stream_handler):
updates.append(update.text)
assert updates == ["chunk1", "chunk2"]
assert chunks_processed == [
"before_stream",
"after_stream",
"stream_start",
"chunk1_yielded",
"chunk2_yielded",
"stream_end",
]
# region Helper classes and fixtures
class FunctionTestArgs(BaseModel):
"""Test arguments for function middleware tests."""
name: str = Field(description="Test name parameter")
class TestAgentMiddleware(AgentMiddleware):
"""Test implementation of AgentMiddleware."""
async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None:
await next(context)
class TestFunctionMiddleware(FunctionMiddleware):
"""Test implementation of FunctionMiddleware."""
async def process(
self, context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
await next(context)
class MockFunctionArgs(BaseModel):
"""Test arguments for function middleware tests."""
name: str = Field(description="Test name parameter")
class TestMiddlewareExecutionControl:
"""Test cases for middleware execution control (when next() is called vs not called)."""
async def test_agent_middleware_no_next_no_execution(self, mock_agent: AgentProtocol) -> None:
"""Test that when agent middleware doesn't call next(), no execution happens."""
class NoNextMiddleware(AgentMiddleware):
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
# Don't call next() - this should prevent any execution
pass
middleware = NoNextMiddleware()
pipeline = AgentMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
context = AgentRunContext(agent=mock_agent, messages=messages)
handler_called = False
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
nonlocal handler_called
handler_called = True
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="should not execute")])
result = await pipeline.execute(mock_agent, messages, context, final_handler)
# Verify no execution happened - should return empty AgentRunResponse
assert result is not None
assert isinstance(result, AgentRunResponse)
assert result.messages == [] # Empty response
assert not handler_called
assert context.result is None
async def test_agent_middleware_no_next_no_streaming_execution(self, mock_agent: AgentProtocol) -> None:
"""Test that when agent middleware doesn't call next(), no streaming execution happens."""
class NoNextStreamingMiddleware(AgentMiddleware):
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
# Don't call next() - this should prevent any execution
pass
middleware = NoNextStreamingMiddleware()
pipeline = AgentMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
context = AgentRunContext(agent=mock_agent, messages=messages)
handler_called = False
async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]:
nonlocal handler_called
handler_called = True
yield AgentRunResponseUpdate(contents=[TextContent(text="should not execute")])
# When middleware doesn't call next(), streaming should yield no updates
updates: list[AgentRunResponseUpdate] = []
async for update in pipeline.execute_stream(mock_agent, messages, context, final_handler):
updates.append(update)
# Verify no execution happened and no updates were yielded
assert len(updates) == 0
assert not handler_called
assert context.result is None
async def test_function_middleware_no_next_no_execution(self, mock_function: AIFunction[Any, Any]) -> None:
"""Test that when function middleware doesn't call next(), no execution happens."""
class FunctionTestArgs(BaseModel):
name: str = Field(description="Test name parameter")
class NoNextFunctionMiddleware(FunctionMiddleware):
async def process(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
# Don't call next() - this should prevent any execution
pass
middleware = NoNextFunctionMiddleware()
pipeline = FunctionMiddlewarePipeline([middleware])
arguments = FunctionTestArgs(name="test")
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
handler_called = False
async def final_handler(ctx: FunctionInvocationContext) -> str:
nonlocal handler_called
handler_called = True
return "should not execute"
result = await pipeline.execute(mock_function, arguments, context, final_handler)
# Verify no execution happened
assert result is None
assert not handler_called
assert context.result is None
async def test_multiple_middlewares_early_stop(self, mock_agent: AgentProtocol) -> None:
"""Test that when first middleware doesn't call next(), subsequent middlewares are not called."""
execution_order: list[str] = []
class FirstMiddleware(AgentMiddleware):
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
execution_order.append("first")
# Don't call next() - this should stop the pipeline
class SecondMiddleware(AgentMiddleware):
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
execution_order.append("second")
await next(context)
pipeline = AgentMiddlewarePipeline([FirstMiddleware(), SecondMiddleware()])
messages = [ChatMessage(role=Role.USER, text="test")]
context = AgentRunContext(agent=mock_agent, messages=messages)
handler_called = False
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
nonlocal handler_called
handler_called = True
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="should not execute")])
result = await pipeline.execute(mock_agent, messages, context, final_handler)
# Verify only first middleware was called and empty response returned
assert execution_order == ["first"]
assert result is not None
assert isinstance(result, AgentRunResponse)
assert result.messages == [] # Empty response
assert not handler_called
async def test_function_middleware_pre_execution_override_with_next(
self, mock_function: AIFunction[Any, Any]
) -> None:
"""Test that function middleware can override result before calling next() - this skips handler execution."""
class FunctionTestArgs(BaseModel):
name: str = Field(description="Test name parameter")
class PreOverrideFunctionMiddleware(FunctionMiddleware):
async def process(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
# Set override first
context.result = "pre-override result"
# Then call next() to continue middleware pipeline
await next(context)
middleware = PreOverrideFunctionMiddleware()
pipeline = FunctionMiddlewarePipeline([middleware])
arguments = FunctionTestArgs(name="test")
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
handler_called = False
async def final_handler(ctx: FunctionInvocationContext) -> str:
nonlocal handler_called
handler_called = True
# This should not be called when result is pre-set
return "original result"
result = await pipeline.execute(mock_function, arguments, context, final_handler)
# Verify pre-override worked and handler was NOT called (because result was already set)
assert result == "pre-override result"
assert not handler_called
@pytest.fixture
def mock_agent() -> AgentProtocol:
"""Mock agent for testing."""
agent = MagicMock(spec=AgentProtocol)
agent.name = "test_agent"
return agent
@pytest.fixture
def mock_function() -> AIFunction[Any, Any]:
"""Mock function for testing."""
function = MagicMock(spec=AIFunction[Any, Any])
function.name = "test_function"
return function
@@ -0,0 +1,463 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import AsyncIterable, Awaitable, Callable
from typing import Any
from unittest.mock import MagicMock
import pytest
from pydantic import BaseModel, Field
from agent_framework import (
AgentProtocol,
AgentRunResponse,
AgentRunResponseUpdate,
ChatAgent,
ChatMessage,
Role,
TextContent,
)
from agent_framework._middleware import (
AgentMiddleware,
AgentMiddlewarePipeline,
AgentRunContext,
FunctionInvocationContext,
FunctionMiddleware,
FunctionMiddlewarePipeline,
)
from agent_framework._tools import AIFunction
from .conftest import MockChatClient
class FunctionTestArgs(BaseModel):
"""Test arguments for function middleware tests."""
name: str = Field(description="Test name parameter")
class TestResultOverrideMiddleware:
"""Test cases for middleware result override functionality."""
async def test_agent_middleware_response_override_non_streaming(self, mock_agent: AgentProtocol) -> None:
"""Test that agent middleware can override response for non-streaming execution."""
override_response = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="overridden response")])
class ResponseOverrideMiddleware(AgentMiddleware):
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
# Execute the pipeline first, then override the response
await next(context)
context.result = override_response
middleware = ResponseOverrideMiddleware()
pipeline = AgentMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
context = AgentRunContext(agent=mock_agent, messages=messages)
handler_called = False
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
nonlocal handler_called
handler_called = True
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="original response")])
result = await pipeline.execute(mock_agent, messages, context, final_handler)
# Verify the overridden response is returned
assert result is not None
assert result == override_response
assert result.messages[0].text == "overridden response"
# Verify original handler was called since middleware called next()
assert handler_called
async def test_agent_middleware_response_override_streaming(self, mock_agent: AgentProtocol) -> None:
"""Test that agent middleware can override response for streaming execution."""
async def override_stream() -> AsyncIterable[AgentRunResponseUpdate]:
yield AgentRunResponseUpdate(contents=[TextContent(text="overridden")])
yield AgentRunResponseUpdate(contents=[TextContent(text=" stream")])
class StreamResponseOverrideMiddleware(AgentMiddleware):
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
# Execute the pipeline first, then override the response stream
await next(context)
context.result = override_stream()
middleware = StreamResponseOverrideMiddleware()
pipeline = AgentMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
context = AgentRunContext(agent=mock_agent, messages=messages)
async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]:
yield AgentRunResponseUpdate(contents=[TextContent(text="original")])
updates: list[AgentRunResponseUpdate] = []
async for update in pipeline.execute_stream(mock_agent, messages, context, final_handler):
updates.append(update)
# Verify the overridden response stream is returned
assert len(updates) == 2
assert updates[0].text == "overridden"
assert updates[1].text == " stream"
async def test_function_middleware_result_override(self, mock_function: AIFunction[Any, Any]) -> None:
"""Test that function middleware can override result."""
override_result = "overridden function result"
class ResultOverrideMiddleware(FunctionMiddleware):
async def process(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
# Execute the pipeline first, then override the result
await next(context)
context.result = override_result
middleware = ResultOverrideMiddleware()
pipeline = FunctionMiddlewarePipeline([middleware])
arguments = FunctionTestArgs(name="test")
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
handler_called = False
async def final_handler(ctx: FunctionInvocationContext) -> str:
nonlocal handler_called
handler_called = True
return "original function result"
result = await pipeline.execute(mock_function, arguments, context, final_handler)
# Verify the overridden result is returned
assert result == override_result
# Verify original handler was called since middleware called next()
assert handler_called
async def test_chat_agent_middleware_response_override(self) -> None:
"""Test result override functionality with ChatAgent integration."""
mock_chat_client = MockChatClient()
class ChatAgentResponseOverrideMiddleware(AgentMiddleware):
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
# Always call next() first to allow execution
await next(context)
# Then conditionally override based on content
if any("special" in msg.text for msg in context.messages if msg.text):
context.result = AgentRunResponse(
messages=[ChatMessage(role=Role.ASSISTANT, text="Special response from middleware!")]
)
# Create ChatAgent with override middleware
middleware = ChatAgentResponseOverrideMiddleware()
agent = ChatAgent(chat_client=mock_chat_client, middleware=[middleware])
# Test override case
override_messages = [ChatMessage(role=Role.USER, text="Give me a special response")]
override_response = await agent.run(override_messages)
assert override_response.messages[0].text == "Special response from middleware!"
# Verify chat client was called since middleware called next()
assert mock_chat_client.call_count == 1
# Test normal case
normal_messages = [ChatMessage(role=Role.USER, text="Normal request")]
normal_response = await agent.run(normal_messages)
assert normal_response.messages[0].text == "test response"
# Verify chat client was called for normal case
assert mock_chat_client.call_count == 2
async def test_chat_agent_middleware_streaming_override(self) -> None:
"""Test streaming result override functionality with ChatAgent integration."""
mock_chat_client = MockChatClient()
async def custom_stream() -> AsyncIterable[AgentRunResponseUpdate]:
yield AgentRunResponseUpdate(contents=[TextContent(text="Custom")])
yield AgentRunResponseUpdate(contents=[TextContent(text=" streaming")])
yield AgentRunResponseUpdate(contents=[TextContent(text=" response!")])
class ChatAgentStreamOverrideMiddleware(AgentMiddleware):
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
# Always call next() first to allow execution
await next(context)
# Then conditionally override based on content
if any("custom stream" in msg.text for msg in context.messages if msg.text):
context.result = custom_stream()
# Create ChatAgent with override middleware
middleware = ChatAgentStreamOverrideMiddleware()
agent = ChatAgent(chat_client=mock_chat_client, middleware=[middleware])
# Test streaming override case
override_messages = [ChatMessage(role=Role.USER, text="Give me a custom stream")]
override_updates: list[AgentRunResponseUpdate] = []
async for update in agent.run_stream(override_messages):
override_updates.append(update)
assert len(override_updates) == 3
assert override_updates[0].text == "Custom"
assert override_updates[1].text == " streaming"
assert override_updates[2].text == " response!"
# Test normal streaming case
normal_messages = [ChatMessage(role=Role.USER, text="Normal streaming request")]
normal_updates: list[AgentRunResponseUpdate] = []
async for update in agent.run_stream(normal_messages):
normal_updates.append(update)
assert len(normal_updates) == 2
assert normal_updates[0].text == "test streaming response "
assert normal_updates[1].text == "another update"
async def test_agent_middleware_conditional_no_next(self, mock_agent: AgentProtocol) -> None:
"""Test that when agent middleware conditionally doesn't call next(), no execution happens."""
class ConditionalNoNextMiddleware(AgentMiddleware):
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
# Only call next() if message contains "execute"
if any("execute" in msg.text for msg in context.messages if msg.text):
await next(context)
# Otherwise, don't call next() - no execution should happen
middleware = ConditionalNoNextMiddleware()
pipeline = AgentMiddlewarePipeline([middleware])
handler_called = False
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
nonlocal handler_called
handler_called = True
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="executed response")])
# Test case where next() is NOT called
no_execute_messages = [ChatMessage(role=Role.USER, text="Don't run this")]
no_execute_context = AgentRunContext(agent=mock_agent, messages=no_execute_messages)
no_execute_result = await pipeline.execute(mock_agent, no_execute_messages, no_execute_context, final_handler)
# When middleware doesn't call next(), result should be empty AgentRunResponse
assert no_execute_result is not None
assert isinstance(no_execute_result, AgentRunResponse)
assert no_execute_result.messages == [] # Empty response
assert not handler_called
assert no_execute_context.result is None
# Reset for next test
handler_called = False
# Test case where next() IS called
execute_messages = [ChatMessage(role=Role.USER, text="Please execute this")]
execute_context = AgentRunContext(agent=mock_agent, messages=execute_messages)
execute_result = await pipeline.execute(mock_agent, execute_messages, execute_context, final_handler)
assert execute_result is not None
assert execute_result.messages[0].text == "executed response"
assert handler_called
async def test_function_middleware_conditional_no_next(self, mock_function: AIFunction[Any, Any]) -> None:
"""Test that when function middleware conditionally doesn't call next(), no execution happens."""
class ConditionalNoNextFunctionMiddleware(FunctionMiddleware):
async def process(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
# Only call next() if argument name contains "execute"
args = context.arguments
assert isinstance(args, FunctionTestArgs)
if "execute" in args.name:
await next(context)
# Otherwise, don't call next() - no execution should happen
middleware = ConditionalNoNextFunctionMiddleware()
pipeline = FunctionMiddlewarePipeline([middleware])
handler_called = False
async def final_handler(ctx: FunctionInvocationContext) -> str:
nonlocal handler_called
handler_called = True
return "executed function result"
# Test case where next() is NOT called
no_execute_args = FunctionTestArgs(name="test_no_action")
no_execute_context = FunctionInvocationContext(function=mock_function, arguments=no_execute_args)
no_execute_result = await pipeline.execute(mock_function, no_execute_args, no_execute_context, final_handler)
# When middleware doesn't call next(), function result should be None (functions can return None)
assert no_execute_result is None
assert not handler_called
assert no_execute_context.result is None
# Reset for next test
handler_called = False
# Test case where next() IS called
execute_args = FunctionTestArgs(name="test_execute")
execute_context = FunctionInvocationContext(function=mock_function, arguments=execute_args)
execute_result = await pipeline.execute(mock_function, execute_args, execute_context, final_handler)
assert execute_result == "executed function result"
assert handler_called
class TestResultObservability:
"""Test cases for middleware result observability functionality."""
async def test_agent_middleware_response_observability(self, mock_agent: AgentProtocol) -> None:
"""Test that middleware can observe response after execution."""
observed_responses: list[AgentRunResponse] = []
class ObservabilityMiddleware(AgentMiddleware):
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
# Context should be empty before next()
assert context.result is None
# Call next to execute
await next(context)
# Context should now contain the response for observability
assert context.result is not None
assert isinstance(context.result, AgentRunResponse)
observed_responses.append(context.result)
middleware = ObservabilityMiddleware()
pipeline = AgentMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
context = AgentRunContext(agent=mock_agent, messages=messages)
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="executed response")])
result = await pipeline.execute(mock_agent, messages, context, final_handler)
# Verify response was observed
assert len(observed_responses) == 1
assert observed_responses[0].messages[0].text == "executed response"
assert result == observed_responses[0]
async def test_function_middleware_result_observability(self, mock_function: AIFunction[Any, Any]) -> None:
"""Test that middleware can observe function result after execution."""
observed_results: list[str] = []
class ObservabilityMiddleware(FunctionMiddleware):
async def process(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
# Context should be empty before next()
assert context.result is None
# Call next to execute
await next(context)
# Context should now contain the result for observability
assert context.result is not None
observed_results.append(context.result)
middleware = ObservabilityMiddleware()
pipeline = FunctionMiddlewarePipeline([middleware])
arguments = FunctionTestArgs(name="test")
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
async def final_handler(ctx: FunctionInvocationContext) -> str:
return "executed function result"
result = await pipeline.execute(mock_function, arguments, context, final_handler)
# Verify result was observed
assert len(observed_results) == 1
assert observed_results[0] == "executed function result"
assert result == observed_results[0]
async def test_agent_middleware_post_execution_override(self, mock_agent: AgentProtocol) -> None:
"""Test that middleware can override response after observing execution."""
class PostExecutionOverrideMiddleware(AgentMiddleware):
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
# Call next to execute first
await next(context)
# Now observe and conditionally override
assert context.result is not None
assert isinstance(context.result, AgentRunResponse)
if "modify" in context.result.messages[0].text:
# Override after observing
context.result = AgentRunResponse(
messages=[ChatMessage(role=Role.ASSISTANT, text="modified after execution")]
)
middleware = PostExecutionOverrideMiddleware()
pipeline = AgentMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
context = AgentRunContext(agent=mock_agent, messages=messages)
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response to modify")])
result = await pipeline.execute(mock_agent, messages, context, final_handler)
# Verify response was modified after execution
assert result is not None
assert result.messages[0].text == "modified after execution"
async def test_function_middleware_post_execution_override(self, mock_function: AIFunction[Any, Any]) -> None:
"""Test that middleware can override function result after observing execution."""
class PostExecutionOverrideMiddleware(FunctionMiddleware):
async def process(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
# Call next to execute first
await next(context)
# Now observe and conditionally override
assert context.result is not None
if "modify" in context.result:
# Override after observing
context.result = "modified after execution"
middleware = PostExecutionOverrideMiddleware()
pipeline = FunctionMiddlewarePipeline([middleware])
arguments = FunctionTestArgs(name="test")
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
async def final_handler(ctx: FunctionInvocationContext) -> str:
return "result to modify"
result = await pipeline.execute(mock_function, arguments, context, final_handler)
# Verify result was modified after execution
assert result == "modified after execution"
@pytest.fixture
def mock_agent() -> AgentProtocol:
"""Mock agent for testing."""
agent = MagicMock(spec=AgentProtocol)
agent.name = "test_agent"
return agent
@pytest.fixture
def mock_function() -> AIFunction[Any, Any]:
"""Mock function for testing."""
function = MagicMock(spec=AIFunction[Any, Any])
function.name = "test_function"
return function
@@ -0,0 +1,544 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import Awaitable, Callable
from agent_framework import (
AgentRunResponseUpdate,
ChatAgent,
ChatMessage,
ChatResponse,
ChatResponseUpdate,
FunctionCallContent,
FunctionResultContent,
Role,
TextContent,
)
from agent_framework._middleware import (
AgentMiddleware,
AgentRunContext,
FunctionInvocationContext,
FunctionMiddleware,
)
from .conftest import MockChatClient
# region ChatAgent Tests
class TestChatAgentClassBasedMiddleware:
"""Test cases for class-based middleware integration with ChatAgent."""
async def test_class_based_agent_middleware_with_chat_agent(self, chat_client: "MockChatClient") -> None:
"""Test class-based agent middleware with ChatAgent."""
execution_order: list[str] = []
class TrackingAgentMiddleware(AgentMiddleware):
def __init__(self, name: str):
self.name = name
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
execution_order.append(f"{self.name}_before")
await next(context)
execution_order.append(f"{self.name}_after")
# Create ChatAgent with middleware
middleware = TrackingAgentMiddleware("agent_middleware")
agent = ChatAgent(chat_client=chat_client, middleware=[middleware])
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert response.messages[0].role == Role.ASSISTANT
# Note: conftest "MockChatClient" returns different text format
assert "test response" in response.messages[0].text
# Verify middleware execution order
assert execution_order == ["agent_middleware_before", "agent_middleware_after"]
async def test_class_based_function_middleware_with_chat_agent(self, chat_client: "MockChatClient") -> None:
"""Test class-based function middleware with ChatAgent."""
execution_order: list[str] = []
class TrackingFunctionMiddleware(FunctionMiddleware):
def __init__(self, name: str):
self.name = name
async def process(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
execution_order.append(f"{self.name}_before")
await next(context)
execution_order.append(f"{self.name}_after")
# Create ChatAgent with function middleware (no tools, so function middleware won't be triggered)
middleware = TrackingFunctionMiddleware("function_middleware")
agent = ChatAgent(chat_client=chat_client, middleware=[middleware])
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert chat_client.call_count == 1
# Note: Function middleware won't execute since no function calls are made
assert execution_order == []
class TestChatAgentFunctionBasedMiddleware:
"""Test cases for function-based middleware integration with ChatAgent."""
async def test_function_based_agent_middleware_with_chat_agent(self, chat_client: "MockChatClient") -> None:
"""Test function-based agent middleware with ChatAgent."""
execution_order: list[str] = []
async def tracking_agent_middleware(
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
execution_order.append("agent_function_before")
await next(context)
execution_order.append("agent_function_after")
# Create ChatAgent with function middleware
agent = ChatAgent(chat_client=chat_client, middleware=[tracking_agent_middleware])
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert response.messages[0].role == Role.ASSISTANT
assert response.messages[0].text == "test response"
assert chat_client.call_count == 1
# Verify middleware execution order
assert execution_order == ["agent_function_before", "agent_function_after"]
async def test_function_based_function_middleware_with_chat_agent(self, chat_client: "MockChatClient") -> None:
"""Test function-based function middleware with ChatAgent."""
execution_order: list[str] = []
async def tracking_function_middleware(
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
execution_order.append("function_function_before")
await next(context)
execution_order.append("function_function_after")
# Create ChatAgent with function middleware (no tools, so function middleware won't be triggered)
agent = ChatAgent(chat_client=chat_client, middleware=[tracking_function_middleware])
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert chat_client.call_count == 1
# Note: Function middleware won't execute since no function calls are made
assert execution_order == []
class TestChatAgentStreamingMiddleware:
"""Test cases for streaming middleware integration with ChatAgent."""
async def test_agent_middleware_with_streaming(self, chat_client: "MockChatClient") -> None:
"""Test agent middleware with streaming ChatAgent responses."""
execution_order: list[str] = []
streaming_flags: list[bool] = []
class StreamingTrackingMiddleware(AgentMiddleware):
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
execution_order.append("middleware_before")
streaming_flags.append(context.is_streaming)
await next(context)
execution_order.append("middleware_after")
# Create ChatAgent with middleware
middleware = StreamingTrackingMiddleware()
agent = ChatAgent(chat_client=chat_client, middleware=[middleware])
# Set up mock streaming responses
chat_client.streaming_responses = [
[
ChatResponseUpdate(contents=[TextContent(text="Streaming")], role=Role.ASSISTANT),
ChatResponseUpdate(contents=[TextContent(text=" response")], role=Role.ASSISTANT),
]
]
# Execute streaming
messages = [ChatMessage(role=Role.USER, text="test message")]
updates: list[AgentRunResponseUpdate] = []
async for update in agent.run_stream(messages):
updates.append(update)
# Verify streaming response
assert len(updates) == 2
assert updates[0].text == "Streaming"
assert updates[1].text == " response"
assert chat_client.call_count == 1
# Verify middleware was called and streaming flag was set correctly
assert execution_order == ["middleware_before", "middleware_after"]
assert streaming_flags == [True] # Context should indicate streaming
async def test_non_streaming_vs_streaming_flag_validation(self, chat_client: "MockChatClient") -> None:
"""Test that is_streaming flag is correctly set for different execution modes."""
streaming_flags: list[bool] = []
class FlagTrackingMiddleware(AgentMiddleware):
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
streaming_flags.append(context.is_streaming)
await next(context)
# Create ChatAgent with middleware
middleware = FlagTrackingMiddleware()
agent = ChatAgent(chat_client=chat_client, middleware=[middleware])
messages = [ChatMessage(role=Role.USER, text="test message")]
# Test non-streaming execution
response = await agent.run(messages)
assert response is not None
# Test streaming execution
async for _ in agent.run_stream(messages):
pass
# Verify flags: [non-streaming, streaming]
assert streaming_flags == [False, True]
class TestChatAgentMultipleMiddlewareOrdering:
"""Test cases for multiple middleware execution order with ChatAgent."""
async def test_multiple_agent_middleware_execution_order(self, chat_client: "MockChatClient") -> None:
"""Test that multiple agent middlewares execute in correct order with ChatAgent."""
execution_order: list[str] = []
class OrderedMiddleware(AgentMiddleware):
def __init__(self, name: str):
self.name = name
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
execution_order.append(f"{self.name}_before")
await next(context)
execution_order.append(f"{self.name}_after")
# Create multiple middlewares
middleware1 = OrderedMiddleware("first")
middleware2 = OrderedMiddleware("second")
middleware3 = OrderedMiddleware("third")
# Create ChatAgent with multiple middlewares
agent = ChatAgent(chat_client=chat_client, middleware=[middleware1, middleware2, middleware3])
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify response
assert response is not None
assert chat_client.call_count == 1
# Verify execution order (should be nested: first wraps second wraps third)
expected_order = ["first_before", "second_before", "third_before", "third_after", "second_after", "first_after"]
assert execution_order == expected_order
async def test_mixed_middleware_types_with_chat_agent(self, chat_client: "MockChatClient") -> None:
"""Test mixed class and function-based middlewares with ChatAgent."""
execution_order: list[str] = []
class ClassAgentMiddleware(AgentMiddleware):
async def process(
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
execution_order.append("class_agent_before")
await next(context)
execution_order.append("class_agent_after")
async def function_agent_middleware(
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
execution_order.append("function_agent_before")
await next(context)
execution_order.append("function_agent_after")
class ClassFunctionMiddleware(FunctionMiddleware):
async def process(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
execution_order.append("class_function_before")
await next(context)
execution_order.append("class_function_after")
async def function_function_middleware(
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
execution_order.append("function_function_before")
await next(context)
execution_order.append("function_function_after")
# Create ChatAgent with mixed middleware types (no tools, focusing on agent middleware)
agent = ChatAgent(
chat_client=chat_client,
middleware=[
ClassAgentMiddleware(),
function_agent_middleware,
ClassFunctionMiddleware(), # Won't execute without function calls
function_function_middleware, # Won't execute without function calls
],
)
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify response
assert response is not None
assert chat_client.call_count == 1
# Verify that agent middlewares were executed in correct order
# (Function middlewares won't execute since no functions are called)
expected_order = ["class_agent_before", "function_agent_before", "function_agent_after", "class_agent_after"]
assert execution_order == expected_order
# region Tool Functions for Testing
def sample_tool_function(location: str) -> str:
"""A simple tool function for middleware testing."""
return f"Weather in {location}: sunny"
# region ChatAgent Function Middleware Tests with Tools
class TestChatAgentFunctionMiddlewareWithTools:
"""Test cases for function middleware integration with ChatAgent when tools are used."""
async def test_class_based_function_middleware_with_tool_calls(self, chat_client: "MockChatClient") -> None:
"""Test class-based function middleware with ChatAgent when function calls are made."""
execution_order: list[str] = []
class TrackingFunctionMiddleware(FunctionMiddleware):
def __init__(self, name: str):
self.name = name
async def process(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
execution_order.append(f"{self.name}_before")
await next(context)
execution_order.append(f"{self.name}_after")
# Set up mock to return a function call first, then a regular response
function_call_response = ChatResponse(
messages=[
ChatMessage(
role=Role.ASSISTANT,
contents=[
FunctionCallContent(
call_id="call_123",
name="sample_tool_function",
arguments='{"location": "Seattle"}',
)
],
)
]
)
final_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Final response")])
chat_client.responses = [function_call_response, final_response]
# Create ChatAgent with function middleware and tools
middleware = TrackingFunctionMiddleware("function_middleware")
agent = ChatAgent(
chat_client=chat_client,
middleware=[middleware],
tools=[sample_tool_function],
)
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="Get weather for Seattle")]
response = await agent.run(messages)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert chat_client.call_count == 2 # Two calls: one for function call, one for final response
# Verify function middleware was executed
assert execution_order == ["function_middleware_before", "function_middleware_after"]
# Verify function call and result are in the response
all_contents = [content for message in response.messages for content in message.contents]
function_calls = [c for c in all_contents if isinstance(c, FunctionCallContent)]
function_results = [c for c in all_contents if isinstance(c, FunctionResultContent)]
assert len(function_calls) == 1
assert len(function_results) == 1
assert function_calls[0].name == "sample_tool_function"
assert function_results[0].call_id == function_calls[0].call_id
async def test_function_based_function_middleware_with_tool_calls(self, chat_client: "MockChatClient") -> None:
"""Test function-based function middleware with ChatAgent when function calls are made."""
execution_order: list[str] = []
async def tracking_function_middleware(
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
execution_order.append("function_middleware_before")
await next(context)
execution_order.append("function_middleware_after")
# Set up mock to return a function call first, then a regular response
function_call_response = ChatResponse(
messages=[
ChatMessage(
role=Role.ASSISTANT,
contents=[
FunctionCallContent(
call_id="call_456",
name="sample_tool_function",
arguments='{"location": "San Francisco"}',
)
],
)
]
)
final_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Final response")])
chat_client.responses = [function_call_response, final_response]
# Create ChatAgent with function middleware and tools
agent = ChatAgent(
chat_client=chat_client,
middleware=[tracking_function_middleware],
tools=[sample_tool_function],
)
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="Get weather for San Francisco")]
response = await agent.run(messages)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert chat_client.call_count == 2 # Two calls: one for function call, one for final response
# Verify function middleware was executed
assert execution_order == ["function_middleware_before", "function_middleware_after"]
# Verify function call and result are in the response
all_contents = [content for message in response.messages for content in message.contents]
function_calls = [c for c in all_contents if isinstance(c, FunctionCallContent)]
function_results = [c for c in all_contents if isinstance(c, FunctionResultContent)]
assert len(function_calls) == 1
assert len(function_results) == 1
assert function_calls[0].name == "sample_tool_function"
assert function_results[0].call_id == function_calls[0].call_id
async def test_mixed_agent_and_function_middleware_with_tool_calls(self, chat_client: "MockChatClient") -> None:
"""Test both agent and function middleware with ChatAgent when function calls are made."""
execution_order: list[str] = []
class TrackingAgentMiddleware(AgentMiddleware):
async def process(
self,
context: AgentRunContext,
next: Callable[[AgentRunContext], Awaitable[None]],
) -> None:
execution_order.append("agent_middleware_before")
await next(context)
execution_order.append("agent_middleware_after")
class TrackingFunctionMiddleware(FunctionMiddleware):
async def process(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
execution_order.append("function_middleware_before")
await next(context)
execution_order.append("function_middleware_after")
# Set up mock to return a function call first, then a regular response
function_call_response = ChatResponse(
messages=[
ChatMessage(
role=Role.ASSISTANT,
contents=[
FunctionCallContent(
call_id="call_789",
name="sample_tool_function",
arguments='{"location": "New York"}',
)
],
)
]
)
final_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Final response")])
chat_client.responses = [function_call_response, final_response]
# Create ChatAgent with both agent and function middleware and tools
agent = ChatAgent(
chat_client=chat_client,
middleware=[TrackingAgentMiddleware(), TrackingFunctionMiddleware()],
tools=[sample_tool_function],
)
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="Get weather for New York")]
response = await agent.run(messages)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert chat_client.call_count == 2 # Two calls: one for function call, one for final response
# Verify middleware execution order: agent middleware wraps everything,
# function middleware only for function calls
expected_order = [
"agent_middleware_before",
"function_middleware_before",
"function_middleware_after",
"agent_middleware_after",
]
assert execution_order == expected_order
# Verify function call and result are in the response
all_contents = [content for message in response.messages for content in message.contents]
function_calls = [c for c in all_contents if isinstance(c, FunctionCallContent)]
function_results = [c for c in all_contents if isinstance(c, FunctionResultContent)]
assert len(function_calls) == 1
assert len(function_results) == 1
assert function_calls[0].name == "sample_tool_function"
assert function_results[0].call_id == function_calls[0].call_id
@@ -0,0 +1,125 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import time
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated
from agent_framework import (
AgentMiddleware,
AgentRunContext,
AgentRunResponse,
ChatMessage,
FunctionInvocationContext,
FunctionMiddleware,
Role,
)
from agent_framework.foundry import FoundryChatClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field
"""
Class-based Middleware Example
This sample demonstrates how to implement middleware using class-based approach by inheriting
from AgentMiddleware and FunctionMiddleware base classes. The example includes:
- SecurityAgentMiddleware: Checks for security violations in user queries and blocks requests
containing sensitive information like passwords or secrets
- LoggingFunctionMiddleware: Logs function execution details including timing and parameters
This approach is useful when you need stateful middleware or complex logic that benefits
from object-oriented design patterns.
"""
def get_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
"""Get the weather for a given location."""
conditions = ["sunny", "cloudy", "rainy", "stormy"]
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
class SecurityAgentMiddleware(AgentMiddleware):
"""Agent middleware that checks for security violations."""
async def process(
self,
context: AgentRunContext,
next: Callable[[AgentRunContext], Awaitable[None]],
) -> None:
# Check for potential security violations in the query
# Look at the last user message
last_message = context.messages[-1] if context.messages else None
if last_message and last_message.text:
query = last_message.text
if "password" in query.lower() or "secret" in query.lower():
print("[SecurityAgentMiddleware] Security Warning: Detected sensitive information, blocking request.")
# Override the result with warning message
context.result = AgentRunResponse(
messages=[
ChatMessage(role=Role.ASSISTANT, text="Detected sensitive information, the request is blocked.")
]
)
# Simply don't call next() to prevent execution
return
print("[SecurityAgentMiddleware] Security check passed.")
await next(context)
class LoggingFunctionMiddleware(FunctionMiddleware):
"""Function middleware that logs function calls."""
async def process(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
function_name = context.function.name
print(f"[LoggingFunctionMiddleware] About to call function: {function_name}.")
start_time = time.time()
await next(context)
end_time = time.time()
duration = end_time - start_time
print(f"[LoggingFunctionMiddleware] Function {function_name} completed in {duration:.5f}s.")
async def main() -> None:
"""Example demonstrating class-based middleware."""
print("=== Class-based Middleware Example ===")
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
# authentication option.
async with (
AzureCliCredential() as credential,
FoundryChatClient(async_credential=credential).create_agent(
name="WeatherAgent",
instructions="You are a helpful weather assistant.",
tools=get_weather,
middleware=[SecurityAgentMiddleware(), LoggingFunctionMiddleware()],
) as agent,
):
# Test with normal query
print("\n--- Normal Query ---")
query = "What's the weather like in Seattle?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text}\n")
# Test with security-related query
print("--- Security Test ---")
query = "What's the password for the weather service?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text}\n")
if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,75 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from collections.abc import Awaitable, Callable
from typing import Annotated
from agent_framework import FunctionInvocationContext
from agent_framework.foundry import FoundryChatClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field
"""
Exception Handling with Middleware
This sample demonstrates how to use middleware for centralized exception handling in function calls.
The example shows:
- How to catch exceptions thrown by functions and provide graceful error responses
- Overriding function results when errors occur to provide user-friendly messages
- Using middleware to implement retry logic, fallback mechanisms, or error reporting
The middleware catches TimeoutError from an unstable data service and replaces it with
a helpful message for the user, preventing raw exceptions from reaching the end user.
"""
def unstable_data_service(
query: Annotated[str, Field(description="The data query to execute.")],
) -> str:
"""A simulated data service that sometimes throws exceptions."""
# Simulate failure
raise TimeoutError("Data service request timed out")
async def exception_handling_middleware(
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
function_name = context.function.name
try:
print(f"[ExceptionHandlingMiddleware] Executing function: {function_name}")
await next(context)
print(f"[ExceptionHandlingMiddleware] Function {function_name} completed successfully.")
except TimeoutError as e:
print(f"[ExceptionHandlingMiddleware] Caught TimeoutError: {e}")
# Override function result to provide custom message in response.
context.result = (
"Request Timeout: The data service is taking longer than expected to respond.",
"Respond with message - 'Sorry for the inconvenience, please try again later.'",
)
async def main() -> None:
"""Example demonstrating exception handling with middleware."""
print("=== Exception Handling Middleware Example ===")
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
# authentication option.
async with (
AzureCliCredential() as credential,
FoundryChatClient(async_credential=credential).create_agent(
name="DataAgent",
instructions="You are a helpful data assistant. Use the data service tool to fetch information for users.",
tools=unstable_data_service,
middleware=exception_handling_middleware,
) as agent,
):
query = "Get user statistics"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result}")
if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,109 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import time
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated
from agent_framework import (
AgentRunContext,
FunctionInvocationContext,
)
from agent_framework.foundry import FoundryChatClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field
"""
Function-based Middleware Example
This sample demonstrates how to implement middleware using simple async functions instead of classes.
The example includes:
- Security middleware that validates agent requests for sensitive information
- Logging middleware that tracks function execution timing and parameters
- Performance monitoring to measure execution duration
Function-based middleware is ideal for simple, stateless operations and provides a more
lightweight approach compared to class-based middleware. Both agent and function middleware
can be implemented as async functions that accept context and next parameters.
"""
def get_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
"""Get the weather for a given location."""
conditions = ["sunny", "cloudy", "rainy", "stormy"]
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
async def security_agent_middleware(
context: AgentRunContext,
next: Callable[[AgentRunContext], Awaitable[None]],
) -> None:
"""Agent middleware that checks for security violations."""
# Check for potential security violations in the query
# For this example, we'll check the last user message
last_message = context.messages[-1] if context.messages else None
if last_message and last_message.text:
query = last_message.text
if "password" in query.lower() or "secret" in query.lower():
print("[SecurityAgentMiddleware] Security Warning: Detected sensitive information, blocking request.")
# Simply don't call next() to prevent execution
return
print("[SecurityAgentMiddleware] Security check passed.")
await next(context)
async def logging_function_middleware(
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
"""Function middleware that logs function calls."""
function_name = context.function.name
print(f"[LoggingFunctionMiddleware] About to call function: {function_name}.")
start_time = time.time()
await next(context)
end_time = time.time()
duration = end_time - start_time
print(f"[LoggingFunctionMiddleware] Function {function_name} completed in {duration:.5f}s.")
async def main() -> None:
"""Example demonstrating function-based middleware."""
print("=== Function-based Middleware Example ===")
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
# authentication option.
async with (
AzureCliCredential() as credential,
FoundryChatClient(async_credential=credential).create_agent(
name="WeatherAgent",
instructions="You are a helpful weather assistant.",
tools=get_weather,
middleware=[security_agent_middleware, logging_function_middleware],
) as agent,
):
# Test with normal query
print("\n--- Normal Query ---")
query = "What's the weather like in Tokyo?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text if result.text else 'No response'}\n")
# Test with security violation
print("--- Security Test ---")
query = "What's the secret weather password?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text if result.text else 'No response'}\n")
if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,84 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated
from agent_framework import FunctionInvocationContext
from agent_framework.foundry import FoundryChatClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field
"""
Result Override with Middleware
This sample demonstrates how to use middleware to intercept and modify function results
after execution. The example shows:
- How to execute the original function first and then modify its result
- Replacing function outputs with custom messages or transformed data
- Using middleware for result filtering, formatting, or enhancement
The weather override middleware lets the original weather function execute normally,
then replaces its result with a custom "perfect weather" message, demonstrating
how middleware can be used for content filtering, A/B testing, or result enhancement.
"""
def get_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
"""Get the weather for a given location."""
conditions = ["sunny", "cloudy", "rainy", "stormy"]
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
async def weather_override_middleware(
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
function_name = context.function.name
# Let the original function execute first
await next(context)
# Override the result if it's a weather function
if function_name == "get_weather" and context.result is not None:
original_result = str(context.result)
print(f"[WeatherOverrideMiddleware] Original result: {original_result}")
# Override with a custom message
# It's also possible to override the result before "next()" call if needed
custom_message = (
"Weather Advisory - due to special atmospheric conditions, "
"all locations are experiencing perfect weather today! "
"Temperature is a comfortable 22°C with gentle breezes. "
"Perfect day for outdoor activities!"
)
context.result = custom_message
print(f"[WeatherOverrideMiddleware] Overriding with custom message: {custom_message}")
async def main() -> None:
"""Example demonstrating result override with middleware."""
print("=== Result Override Middleware Example ===")
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
# authentication option.
async with (
AzureCliCredential() as credential,
FoundryChatClient(async_credential=credential).create_agent(
name="WeatherAgent",
instructions="You are a helpful weather assistant. Use the weather tool to get current conditions.",
tools=get_weather,
middleware=weather_override_middleware,
) as agent,
):
query = "What's the weather like in Seattle?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result}")
if __name__ == "__main__":
asyncio.run(main())