Python: Added chat middleware and more examples (#883)

* Added example with stateful middleware

* Added chat middleware

* Updated middleware example with override scenario

* Small revert

* Small fixes

* Added kwargs to context objects

* Added README

* Added function middleware to chat client

* Small refactoring

* Reverted example files

* Made MiddlewareWrapper generic

* Added Middleware exception

* Small refactoring

* Small fix
This commit is contained in:
Dmytro Struk
2025-09-26 08:10:56 -07:00
committed by GitHub
Unverified
parent 863c8d7471
commit eec7f192eb
19 changed files with 2667 additions and 267 deletions
@@ -34,6 +34,7 @@ from agent_framework import (
UsageContent,
UsageDetails,
get_logger,
use_chat_middleware,
use_function_invocation,
)
from agent_framework._pydantic import AFBaseSettings
@@ -123,6 +124,7 @@ TAzureAIAgentClient = TypeVar("TAzureAIAgentClient", bound="AzureAIAgentClient")
@use_function_invocation
@use_observability
@use_chat_middleware
class AzureAIAgentClient(BaseChatClient):
"""Azure AI Agent Chat client."""
@@ -10,7 +10,13 @@ from pydantic import BaseModel, Field
from ._logging import get_logger
from ._mcp import MCPTool
from ._memory import AggregateContextProvider, ContextProvider
from ._middleware import Middleware
from ._middleware import (
ChatMiddleware,
ChatMiddlewareCallable,
FunctionMiddleware,
FunctionMiddlewareCallable,
Middleware,
)
from ._pydantic import AFBaseModel
from ._threads import ChatMessageStore
from ._tools import ToolProtocol
@@ -189,6 +195,14 @@ class BaseChatClient(AFBaseModel, ABC):
"""Base class for chat clients."""
additional_properties: dict[str, Any] = Field(default_factory=dict)
middleware: (
ChatMiddleware
| ChatMiddlewareCallable
| FunctionMiddleware
| FunctionMiddlewareCallable
| list[ChatMiddleware | ChatMiddlewareCallable | FunctionMiddleware | FunctionMiddlewareCallable]
| None
) = None
OTEL_PROVIDER_NAME: str = "unknown"
# This is used for OTel setup, should be overridden in subclasses
@@ -346,13 +360,7 @@ class BaseChatClient(AFBaseModel, ABC):
prepped_messages = self.prepare_messages(messages)
self._prepare_tool_choice(chat_options=chat_options)
# Remove middleware pipeline from kwargs as it's only used by function invocation wrappers
if "_function_middleware_pipeline" in kwargs:
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "_function_middleware_pipeline"}
else:
filtered_kwargs = kwargs
return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **filtered_kwargs)
return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **kwargs)
async def get_streaming_response(
self,
@@ -432,14 +440,8 @@ class BaseChatClient(AFBaseModel, ABC):
prepped_messages = self.prepare_messages(messages)
self._prepare_tool_choice(chat_options=chat_options)
# Remove middleware pipeline from kwargs as it's only used by function invocation wrappers
if "_function_middleware_pipeline" in kwargs:
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "_function_middleware_pipeline"}
else:
filtered_kwargs = kwargs
async for update in self._inner_get_streaming_response(
messages=prepped_messages, chat_options=chat_options, **filtered_kwargs
messages=prepped_messages, chat_options=chat_options, **kwargs
):
yield update
File diff suppressed because it is too large Load Diff
+23 -4
View File
@@ -607,6 +607,7 @@ async def _auto_invoke_function(
middleware_context = FunctionInvocationContext(
function=tool,
arguments=args,
kwargs=custom_args or {},
)
async def final_function_handler(context_obj: Any) -> Any:
@@ -721,8 +722,16 @@ def _handle_function_calls_response(
**kwargs: Any,
) -> "ChatResponse":
from ._clients import prepare_messages
from ._middleware import extract_and_merge_function_middleware
from ._types import ChatMessage, ChatOptions, FunctionCallContent, FunctionResultContent
# Extract and merge function middleware from chat client with kwargs pipeline
extract_and_merge_function_middleware(self, kwargs)
# Extract the middleware pipeline before calling the underlying function
# because the underlying function may not preserve it in kwargs
stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline")
prepped_messages = prepare_messages(messages)
response: "ChatResponse | None" = None
fcc_messages: "list[ChatMessage]" = []
@@ -746,8 +755,9 @@ def _handle_function_calls_response(
if not tools and (chat_options := kwargs.get("chat_options")) and isinstance(chat_options, ChatOptions):
tools = chat_options.tools
if function_calls and tools:
# Extract function middleware pipeline from kwargs if available
middleware_pipeline = kwargs.get("_function_middleware_pipeline")
# Use the stored middleware pipeline instead of extracting from kwargs
# because kwargs may have been modified by the underlying function
middleware_pipeline = stored_middleware_pipeline
function_results = await execute_function_calls(
custom_args=kwargs,
attempt_idx=attempt_idx,
@@ -820,8 +830,16 @@ def _handle_function_calls_streaming_response(
) -> AsyncIterable["ChatResponseUpdate"]:
"""Wrap the inner get streaming response method to handle tool calls."""
from ._clients import prepare_messages
from ._middleware import extract_and_merge_function_middleware
from ._types import ChatMessage, ChatOptions, ChatResponse, ChatResponseUpdate, FunctionCallContent
# Extract and merge function middleware from chat client with kwargs pipeline
extract_and_merge_function_middleware(self, kwargs)
# Extract the middleware pipeline before calling the underlying function
# because the underlying function may not preserve it in kwargs
stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline")
prepped_messages = prepare_messages(messages)
for attempt_idx in range(max_iterations):
all_updates: list["ChatResponseUpdate"] = []
@@ -858,8 +876,9 @@ def _handle_function_calls_streaming_response(
tools = chat_options.tools
if function_calls and tools:
# Extract function middleware pipeline from kwargs if available
middleware_pipeline = kwargs.get("_function_middleware_pipeline")
# Use the stored middleware pipeline instead of extracting from kwargs
# because kwargs may have been modified by the underlying function
middleware_pipeline = stored_middleware_pipeline
function_results = await execute_function_calls(
custom_args=kwargs,
attempt_idx=attempt_idx,
@@ -13,16 +13,18 @@ from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from pydantic import SecretStr, ValidationError
from pydantic.networks import AnyUrl
from .._tools import use_function_invocation
from .._types import (
from agent_framework import (
ChatResponse,
ChatResponseUpdate,
CitationAnnotation,
TextContent,
use_chat_middleware,
use_function_invocation,
)
from ..exceptions import ServiceInitializationError
from ..observability import use_observability
from ..openai._chat_client import OpenAIBaseChatClient
from agent_framework.exceptions import ServiceInitializationError
from agent_framework.observability import use_observability
from agent_framework.openai._chat_client import OpenAIBaseChatClient
from ._shared import (
AzureOpenAIConfigMixin,
AzureOpenAISettings,
@@ -41,6 +43,7 @@ TAzureOpenAIChatClient = TypeVar("TAzureOpenAIChatClient", bound="AzureOpenAICha
@use_function_invocation
@use_observability
@use_chat_middleware
class AzureOpenAIChatClient(AzureOpenAIConfigMixin, OpenAIBaseChatClient):
"""Azure OpenAI Chat completion class."""
@@ -9,10 +9,11 @@ from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI
from pydantic import SecretStr, ValidationError
from pydantic.networks import AnyUrl
from .._tools import use_function_invocation
from ..exceptions import ServiceInitializationError
from ..observability import use_observability
from ..openai._responses_client import OpenAIBaseResponsesClient
from agent_framework import use_chat_middleware, use_function_invocation
from agent_framework.exceptions import ServiceInitializationError
from agent_framework.observability import use_observability
from agent_framework.openai._responses_client import OpenAIBaseResponsesClient
from ._shared import (
AzureOpenAIConfigMixin,
AzureOpenAISettings,
@@ -23,6 +24,7 @@ TAzureOpenAIResponsesClient = TypeVar("TAzureOpenAIResponsesClient", bound="Azur
@use_observability
@use_function_invocation
@use_chat_middleware
class AzureOpenAIResponsesClient(AzureOpenAIConfigMixin, OpenAIBaseResponsesClient):
"""Azure Responses completion class."""
@@ -128,3 +128,9 @@ class AdditionItemMismatch(AgentFrameworkException):
"""An error occurred while adding two types."""
pass
class MiddlewareException(AgentFrameworkException):
"""An error occurred during middleware execution."""
pass
@@ -21,6 +21,7 @@ from openai.types.beta.threads.runs import RunStep
from pydantic import Field, PrivateAttr, SecretStr, ValidationError
from .._clients import BaseChatClient
from .._middleware import use_chat_middleware
from .._tools import AIFunction, HostedCodeInterpreterTool, HostedFileSearchTool, use_function_invocation
from .._types import (
ChatMessage,
@@ -52,6 +53,7 @@ __all__ = ["OpenAIAssistantsClient"]
@use_function_invocation
@use_observability
@use_chat_middleware
class OpenAIAssistantsClient(OpenAIConfigMixin, BaseChatClient):
"""OpenAI Assistants client."""
@@ -18,6 +18,7 @@ from pydantic import BaseModel, SecretStr, ValidationError
from .._clients import BaseChatClient
from .._logging import get_logger
from .._middleware import use_chat_middleware
from .._tools import AIFunction, HostedWebSearchTool, ToolProtocol, use_function_invocation
from .._types import (
ChatMessage,
@@ -452,6 +453,7 @@ TOpenAIChatClient = TypeVar("TOpenAIChatClient", bound="OpenAIChatClient")
@use_function_invocation
@use_observability
@use_chat_middleware
class OpenAIChatClient(OpenAIConfigMixin, OpenAIBaseChatClient):
"""OpenAI Chat completion class."""
@@ -26,6 +26,7 @@ from pydantic import BaseModel, SecretStr, ValidationError
from .._clients import BaseChatClient
from .._logging import get_logger
from .._middleware import use_chat_middleware
from .._tools import (
AIFunction,
HostedCodeInterpreterTool,
@@ -933,6 +934,7 @@ TOpenAIResponsesClient = TypeVar("TOpenAIResponsesClient", bound="OpenAIResponse
@use_function_invocation
@use_observability
@use_chat_middleware
class OpenAIResponsesClient(OpenAIConfigMixin, OpenAIBaseResponsesClient):
"""OpenAI Responses client class."""
@@ -25,6 +25,7 @@ from agent_framework import (
TextContent,
ToolProtocol,
ai_function,
use_chat_middleware,
use_function_invocation,
)
@@ -111,11 +112,13 @@ class MockChatClient:
yield ChatResponseUpdate(contents=[TextContent(text="another update")], role="assistant")
@use_chat_middleware
class MockBaseChatClient(BaseChatClient):
"""Mock implementation of the BaseChatClient."""
run_responses: list[ChatResponse] = Field(default_factory=list)
streaming_responses: list[list[ChatResponseUpdate]] = Field(default_factory=list)
call_count: int = Field(default=0)
@override
async def _inner_get_response(
@@ -136,6 +139,7 @@ class MockBaseChatClient(BaseChatClient):
The chat response contents representing the response(s).
"""
logger.debug(f"Running base chat client inner, with: {messages=}, {chat_options=}, {kwargs=}")
self.call_count += 1
if not self.run_responses:
return ChatResponse(messages=ChatMessage(role="assistant", text=f"test response - {messages[0].text}"))
@@ -12,6 +12,8 @@ from agent_framework import (
AgentRunResponse,
AgentRunResponseUpdate,
ChatMessage,
ChatResponse,
ChatResponseUpdate,
Role,
TextContent,
)
@@ -19,11 +21,15 @@ from agent_framework._middleware import (
AgentMiddleware,
AgentMiddlewarePipeline,
AgentRunContext,
ChatContext,
ChatMiddleware,
ChatMiddlewarePipeline,
FunctionInvocationContext,
FunctionMiddleware,
FunctionMiddlewarePipeline,
)
from agent_framework._tools import AIFunction
from agent_framework._types import ChatOptions
class TestAgentRunContext:
@@ -74,6 +80,46 @@ class TestFunctionInvocationContext:
assert context.metadata == metadata
class TestChatContext:
"""Test cases for ChatContext."""
def test_init_with_defaults(self, mock_chat_client: Any) -> None:
"""Test ChatContext initialization with default values."""
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
assert context.chat_client is mock_chat_client
assert context.messages == messages
assert context.chat_options is chat_options
assert context.is_streaming is False
assert context.metadata == {}
assert context.result is None
assert context.terminate is False
def test_init_with_custom_values(self, mock_chat_client: Any) -> None:
"""Test ChatContext initialization with custom values."""
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions(temperature=0.5)
metadata = {"key": "value"}
context = ChatContext(
chat_client=mock_chat_client,
messages=messages,
chat_options=chat_options,
is_streaming=True,
metadata=metadata,
terminate=True,
)
assert context.chat_client is mock_chat_client
assert context.messages == messages
assert context.chat_options is chat_options
assert context.is_streaming is True
assert context.metadata == metadata
assert context.terminate is True
class TestAgentMiddlewarePipeline:
"""Test cases for AgentMiddlewarePipeline."""
@@ -410,6 +456,233 @@ class TestFunctionMiddlewarePipeline:
assert execution_order == ["test_before", "handler", "test_after"]
class TestChatMiddlewarePipeline:
"""Test cases for ChatMiddlewarePipeline."""
class PreNextTerminateChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
context.terminate = True
await next(context)
class PostNextTerminateChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
await next(context)
context.terminate = True
def test_init_empty(self) -> None:
"""Test ChatMiddlewarePipeline initialization with no middlewares."""
pipeline = ChatMiddlewarePipeline()
assert not pipeline.has_middlewares
def test_init_with_class_middleware(self) -> None:
"""Test ChatMiddlewarePipeline initialization with class-based middleware."""
middleware = TestChatMiddleware()
pipeline = ChatMiddlewarePipeline([middleware])
assert pipeline.has_middlewares
def test_init_with_function_middleware(self) -> None:
"""Test ChatMiddlewarePipeline initialization with function-based middleware."""
async def test_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
await next(context)
pipeline = ChatMiddlewarePipeline([test_middleware])
assert pipeline.has_middlewares
async def test_execute_no_middleware(self, mock_chat_client: Any) -> None:
"""Test pipeline execution with no middleware."""
pipeline = ChatMiddlewarePipeline()
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
expected_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
async def final_handler(ctx: ChatContext) -> ChatResponse:
return expected_response
result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
assert result == expected_response
async def test_execute_with_middleware(self, mock_chat_client: Any) -> None:
"""Test pipeline execution with middleware."""
execution_order: list[str] = []
class OrderTrackingChatMiddleware(ChatMiddleware):
def __init__(self, name: str):
self.name = name
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append(f"{self.name}_before")
await next(context)
execution_order.append(f"{self.name}_after")
middleware = OrderTrackingChatMiddleware("test")
pipeline = ChatMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
expected_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
async def final_handler(ctx: ChatContext) -> ChatResponse:
execution_order.append("handler")
return expected_response
result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
assert result == expected_response
assert execution_order == ["test_before", "handler", "test_after"]
async def test_execute_stream_no_middleware(self, mock_chat_client: Any) -> None:
"""Test pipeline streaming execution with no middleware."""
pipeline = ChatMiddlewarePipeline()
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[TextContent(text="chunk1")])
yield ChatResponseUpdate(contents=[TextContent(text="chunk2")])
updates: list[ChatResponseUpdate] = []
async for update in pipeline.execute_stream(mock_chat_client, messages, chat_options, context, final_handler):
updates.append(update)
assert len(updates) == 2
assert updates[0].text == "chunk1"
assert updates[1].text == "chunk2"
async def test_execute_stream_with_middleware(self, mock_chat_client: Any) -> None:
"""Test pipeline streaming execution with middleware."""
execution_order: list[str] = []
class StreamOrderTrackingChatMiddleware(ChatMiddleware):
def __init__(self, name: str):
self.name = name
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append(f"{self.name}_before")
await next(context)
execution_order.append(f"{self.name}_after")
middleware = StreamOrderTrackingChatMiddleware("test")
pipeline = ChatMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(
chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True
)
async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
execution_order.append("handler_start")
yield ChatResponseUpdate(contents=[TextContent(text="chunk1")])
yield ChatResponseUpdate(contents=[TextContent(text="chunk2")])
execution_order.append("handler_end")
updates: list[ChatResponseUpdate] = []
async for update in pipeline.execute_stream(mock_chat_client, messages, chat_options, context, final_handler):
updates.append(update)
assert len(updates) == 2
assert updates[0].text == "chunk1"
assert updates[1].text == "chunk2"
assert execution_order == ["test_before", "test_after", "handler_start", "handler_end"]
async def test_execute_with_pre_next_termination(self, mock_chat_client: Any) -> None:
"""Test pipeline execution with termination before next()."""
middleware = self.PreNextTerminateChatMiddleware()
pipeline = ChatMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
execution_order: list[str] = []
async def final_handler(ctx: ChatContext) -> ChatResponse:
# Handler should not be executed when terminated before next()
execution_order.append("handler")
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
response = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
assert response is None
assert context.terminate
# Handler should not be called when terminated before next()
assert execution_order == []
async def test_execute_with_post_next_termination(self, mock_chat_client: Any) -> None:
"""Test pipeline execution with termination after next()."""
middleware = self.PostNextTerminateChatMiddleware()
pipeline = ChatMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
execution_order: list[str] = []
async def final_handler(ctx: ChatContext) -> ChatResponse:
execution_order.append("handler")
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
response = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
assert response is not None
assert len(response.messages) == 1
assert response.messages[0].text == "response"
assert context.terminate
assert execution_order == ["handler"]
async def test_execute_stream_with_pre_next_termination(self, mock_chat_client: Any) -> None:
"""Test pipeline streaming execution with termination before next()."""
middleware = self.PreNextTerminateChatMiddleware()
pipeline = ChatMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(
chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True
)
execution_order: list[str] = []
async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
# Handler should not be executed when terminated before next()
execution_order.append("handler_start")
yield ChatResponseUpdate(contents=[TextContent(text="chunk1")])
yield ChatResponseUpdate(contents=[TextContent(text="chunk2")])
execution_order.append("handler_end")
updates: list[ChatResponseUpdate] = []
async for update in pipeline.execute_stream(mock_chat_client, messages, chat_options, context, final_handler):
updates.append(update)
assert context.terminate
# Handler should not be called when terminated before next()
assert execution_order == []
assert not updates
async def test_execute_stream_with_post_next_termination(self, mock_chat_client: Any) -> None:
"""Test pipeline streaming execution with termination after next()."""
middleware = self.PostNextTerminateChatMiddleware()
pipeline = ChatMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(
chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True
)
execution_order: list[str] = []
async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
execution_order.append("handler_start")
yield ChatResponseUpdate(contents=[TextContent(text="chunk1")])
yield ChatResponseUpdate(contents=[TextContent(text="chunk2")])
execution_order.append("handler_end")
updates: list[ChatResponseUpdate] = []
async for update in pipeline.execute_stream(mock_chat_client, messages, chat_options, context, final_handler):
updates.append(update)
assert len(updates) == 2
assert updates[0].text == "chunk1"
assert updates[1].text == "chunk2"
assert context.terminate
assert execution_order == ["handler_start", "handler_end"]
class TestClassBasedMiddleware:
"""Test cases for class-based middleware implementations."""
@@ -601,6 +874,37 @@ class TestMixedMiddleware:
assert result == "result"
assert execution_order == ["class_before", "function_before", "handler", "function_after", "class_after"]
async def test_mixed_chat_middleware(self, mock_chat_client: Any) -> None:
"""Test mixed class and function-based chat middleware."""
execution_order: list[str] = []
class ClassChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("class_before")
await next(context)
execution_order.append("class_after")
async def function_chat_middleware(
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
) -> None:
execution_order.append("function_before")
await next(context)
execution_order.append("function_after")
pipeline = ChatMiddlewarePipeline([ClassChatMiddleware(), function_chat_middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
async def final_handler(ctx: ChatContext) -> ChatResponse:
execution_order.append("handler")
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
assert result is not None
assert execution_order == ["class_before", "function_before", "handler", "function_after", "class_after"]
class TestMultipleMiddlewareOrdering:
"""Test cases for multiple middleware execution order."""
@@ -695,6 +999,52 @@ class TestMultipleMiddlewareOrdering:
expected_order = ["first_before", "second_before", "handler", "second_after", "first_after"]
assert execution_order == expected_order
async def test_chat_middleware_execution_order(self, mock_chat_client: Any) -> None:
"""Test that multiple chat middlewares execute in registration order."""
execution_order: list[str] = []
class FirstChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("first_before")
await next(context)
execution_order.append("first_after")
class SecondChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("second_before")
await next(context)
execution_order.append("second_after")
class ThirdChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("third_before")
await next(context)
execution_order.append("third_after")
middlewares = [FirstChatMiddleware(), SecondChatMiddleware(), ThirdChatMiddleware()]
pipeline = ChatMiddlewarePipeline(middlewares) # type: ignore
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
async def final_handler(ctx: ChatContext) -> ChatResponse:
execution_order.append("handler")
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
assert result is not None
expected_order = [
"first_before",
"second_before",
"third_before",
"handler",
"third_after",
"second_after",
"first_after",
]
assert execution_order == expected_order
class TestContextContentValidation:
"""Test cases for validating middleware context content."""
@@ -776,6 +1126,49 @@ class TestContextContentValidation:
result = await pipeline.execute(mock_function, arguments, context, final_handler)
assert result == "result"
async def test_chat_context_validation(self, mock_chat_client: Any) -> None:
"""Test that chat context contains expected data."""
class ChatContextValidationMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
# Verify context has all expected attributes
assert hasattr(context, "chat_client")
assert hasattr(context, "messages")
assert hasattr(context, "chat_options")
assert hasattr(context, "is_streaming")
assert hasattr(context, "metadata")
assert hasattr(context, "result")
assert hasattr(context, "terminate")
# Verify context content
assert context.chat_client is mock_chat_client
assert len(context.messages) == 1
assert context.messages[0].role == Role.USER
assert context.messages[0].text == "test"
assert context.is_streaming is False
assert isinstance(context.metadata, dict)
assert isinstance(context.chat_options, ChatOptions)
assert context.chat_options.temperature == 0.5
# Add custom metadata
context.metadata["validated"] = True
await next(context)
middleware = ChatContextValidationMiddleware()
pipeline = ChatMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions(temperature=0.5)
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
async def final_handler(ctx: ChatContext) -> ChatResponse:
# Verify metadata was set by middleware
assert ctx.metadata.get("validated") is True
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
assert result is not None
class TestStreamingScenarios:
"""Test cases for streaming and non-streaming scenarios."""
@@ -857,6 +1250,89 @@ class TestStreamingScenarios:
"stream_end",
]
async def test_chat_streaming_flag_validation(self, mock_chat_client: Any) -> None:
"""Test that is_streaming flag is correctly set for chat streaming calls."""
streaming_flags: list[bool] = []
class ChatStreamingFlagMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
streaming_flags.append(context.is_streaming)
await next(context)
middleware = ChatStreamingFlagMiddleware()
pipeline = ChatMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
# Test non-streaming
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
async def final_handler(ctx: ChatContext) -> ChatResponse:
streaming_flags.append(ctx.is_streaming)
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
# Test streaming
context_stream = ChatContext(
chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True
)
async def final_stream_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
streaming_flags.append(ctx.is_streaming)
yield ChatResponseUpdate(contents=[TextContent(text="chunk")])
updates: list[ChatResponseUpdate] = []
async for update in pipeline.execute_stream(
mock_chat_client, messages, chat_options, context_stream, final_stream_handler
):
updates.append(update)
# Verify flags: [non-streaming middleware, non-streaming handler, streaming middleware, streaming handler]
assert streaming_flags == [False, False, True, True]
async def test_chat_streaming_middleware_behavior(self, mock_chat_client: Any) -> None:
"""Test chat middleware behavior with streaming responses."""
chunks_processed: list[str] = []
class ChatStreamProcessingMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
chunks_processed.append("before_stream")
await next(context)
chunks_processed.append("after_stream")
middleware = ChatStreamProcessingMiddleware()
pipeline = ChatMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(
chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True
)
async def final_stream_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
chunks_processed.append("stream_start")
yield ChatResponseUpdate(contents=[TextContent(text="chunk1")])
chunks_processed.append("chunk1_yielded")
yield ChatResponseUpdate(contents=[TextContent(text="chunk2")])
chunks_processed.append("chunk2_yielded")
chunks_processed.append("stream_end")
updates: list[str] = []
async for update in pipeline.execute_stream(
mock_chat_client, messages, chat_options, context, final_stream_handler
):
updates.append(update.text)
assert updates == ["chunk1", "chunk2"]
assert chunks_processed == [
"before_stream",
"after_stream",
"stream_start",
"chunk1_yielded",
"chunk2_yielded",
"stream_end",
]
# region Helper classes and fixtures
@@ -883,6 +1359,13 @@ class TestFunctionMiddleware(FunctionMiddleware):
await next(context)
class TestChatMiddleware(ChatMiddleware):
"""Test implementation of ChatMiddleware."""
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
await next(context)
class MockFunctionArgs(BaseModel):
"""Test arguments for function middleware tests."""
@@ -1027,6 +1510,100 @@ class TestMiddlewareExecutionControl:
assert result.messages == [] # Empty response
assert not handler_called
async def test_chat_middleware_no_next_no_execution(self, mock_chat_client: Any) -> None:
"""Test that when chat middleware doesn't call next(), no execution happens."""
class NoNextChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
# Don't call next() - this should prevent any execution
pass
middleware = NoNextChatMiddleware()
pipeline = ChatMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
handler_called = False
async def final_handler(ctx: ChatContext) -> ChatResponse:
nonlocal handler_called
handler_called = True
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="should not execute")])
result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
# Verify no execution happened
assert result is None
assert not handler_called
assert context.result is None
async def test_chat_middleware_no_next_no_streaming_execution(self, mock_chat_client: Any) -> None:
"""Test that when chat middleware doesn't call next(), no streaming execution happens."""
class NoNextStreamingChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
# Don't call next() - this should prevent any execution
pass
middleware = NoNextStreamingChatMiddleware()
pipeline = ChatMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(
chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True
)
handler_called = False
async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
nonlocal handler_called
handler_called = True
yield ChatResponseUpdate(contents=[TextContent(text="should not execute")])
# When middleware doesn't call next(), streaming should yield no updates
updates: list[ChatResponseUpdate] = []
async for update in pipeline.execute_stream(mock_chat_client, messages, chat_options, context, final_handler):
updates.append(update)
# Verify no execution happened and no updates were yielded
assert len(updates) == 0
assert not handler_called
assert context.result is None
async def test_multiple_chat_middlewares_early_stop(self, mock_chat_client: Any) -> None:
"""Test that when first chat middleware doesn't call next(), subsequent middlewares are not called."""
execution_order: list[str] = []
class FirstChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("first")
# Don't call next() - this should stop the pipeline
class SecondChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("second")
await next(context)
pipeline = ChatMiddlewarePipeline([FirstChatMiddleware(), SecondChatMiddleware()])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
handler_called = False
async def final_handler(ctx: ChatContext) -> ChatResponse:
nonlocal handler_called
handler_called = True
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="should not execute")])
result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
# Verify only first middleware was called and no result returned
assert execution_order == ["first"]
assert result is None
assert not handler_called
@pytest.fixture
def mock_agent() -> AgentProtocol:
@@ -1042,3 +1619,13 @@ def mock_function() -> AIFunction[Any, Any]:
function = MagicMock(spec=AIFunction[Any, Any])
function.name = "test_function"
return function
@pytest.fixture
def mock_chat_client() -> Any:
"""Mock chat client for testing."""
from agent_framework._clients import ChatClientProtocol
client = MagicMock(spec=ChatClientProtocol)
client.service_url = MagicMock(return_value="mock://test")
return client
@@ -8,7 +8,9 @@ import pytest
from agent_framework import (
AgentRunResponseUpdate,
ChatAgent,
ChatContext,
ChatMessage,
ChatMiddleware,
ChatResponse,
ChatResponseUpdate,
FunctionCallContent,
@@ -16,7 +18,9 @@ from agent_framework import (
Role,
TextContent,
agent_middleware,
chat_middleware,
function_middleware,
use_function_invocation,
)
from agent_framework._middleware import (
AgentMiddleware,
@@ -25,8 +29,9 @@ from agent_framework._middleware import (
FunctionMiddleware,
MiddlewareType,
)
from agent_framework.exceptions import MiddlewareException
from .conftest import MockChatClient
from .conftest import MockBaseChatClient, MockChatClient
# region ChatAgent Tests
@@ -713,6 +718,101 @@ class TestChatAgentFunctionMiddlewareWithTools:
assert function_calls[0].name == "sample_tool_function"
assert function_results[0].call_id == function_calls[0].call_id
async def test_function_middleware_can_access_and_override_custom_kwargs(
self, chat_client: "MockChatClient"
) -> None:
"""Test that function middleware can access and override custom parameters like temperature."""
captured_kwargs: dict[str, Any] = {}
modified_kwargs: dict[str, Any] = {}
middleware_called = False
@function_middleware
async def kwargs_middleware(
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
nonlocal middleware_called
middleware_called = True
# Capture the original kwargs
captured_kwargs["has_chat_options"] = "chat_options" in context.kwargs
captured_kwargs["has_custom_param"] = "custom_param" in context.kwargs
captured_kwargs["custom_param"] = context.kwargs.get("custom_param")
# Capture original chat_options values if present
if "chat_options" in context.kwargs:
chat_options = context.kwargs["chat_options"]
captured_kwargs["original_temperature"] = getattr(chat_options, "temperature", None)
captured_kwargs["original_max_tokens"] = getattr(chat_options, "max_tokens", None)
# Modify some kwargs
context.kwargs["temperature"] = 0.9
context.kwargs["max_tokens"] = 500
context.kwargs["new_param"] = "added_by_middleware"
# Also modify chat_options if present
if "chat_options" in context.kwargs:
context.kwargs["chat_options"].temperature = 0.9
context.kwargs["chat_options"].max_tokens = 500
# Store modified kwargs for verification
modified_kwargs["temperature"] = context.kwargs.get("temperature")
modified_kwargs["max_tokens"] = context.kwargs.get("max_tokens")
modified_kwargs["new_param"] = context.kwargs.get("new_param")
modified_kwargs["custom_param"] = context.kwargs.get("custom_param")
# Capture modified chat_options values if present
if "chat_options" in context.kwargs:
chat_options = context.kwargs["chat_options"]
modified_kwargs["chat_options_temperature"] = getattr(chat_options, "temperature", None)
modified_kwargs["chat_options_max_tokens"] = getattr(chat_options, "max_tokens", None)
await next(context)
chat_client.responses = [
ChatResponse(
messages=[
ChatMessage(
role=Role.ASSISTANT,
contents=[
FunctionCallContent(
call_id="test_call", name="sample_tool_function", arguments={"location": "Seattle"}
)
],
)
]
),
ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent("Function completed")])]),
]
# Create ChatAgent with function middleware
agent = ChatAgent(chat_client=chat_client, middleware=[kwargs_middleware], tools=[sample_tool_function])
# Execute the agent with custom parameters
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages, temperature=0.7, max_tokens=100, custom_param="test_value")
# Verify response
assert response is not None
assert len(response.messages) > 0
# First check if middleware was called at all
assert middleware_called, "Function middleware was not called"
# Verify middleware captured the original kwargs
assert captured_kwargs["has_chat_options"] is True
assert captured_kwargs["has_custom_param"] is True
assert captured_kwargs["custom_param"] == "test_value"
assert captured_kwargs["original_temperature"] == 0.7
assert captured_kwargs["original_max_tokens"] == 100
# Verify middleware could modify the kwargs
assert modified_kwargs["temperature"] == 0.9
assert modified_kwargs["max_tokens"] == 500
assert modified_kwargs["new_param"] == "added_by_middleware"
assert modified_kwargs["custom_param"] == "test_value"
assert modified_kwargs["chat_options_temperature"] == 0.9
assert modified_kwargs["chat_options_max_tokens"] == 500
class TestMiddlewareDynamicRebuild:
"""Test cases for dynamic middleware pipeline rebuilding with ChatAgent."""
@@ -1187,8 +1287,8 @@ class TestMiddlewareDecoratorLogic:
"""Both decorator and parameter type specified but don't match."""
# This will cause a type error at decoration time, so we need to test differently
# Should raise ValueError due to mismatch during agent creation
with pytest.raises(ValueError, match="Middleware type mismatch"):
# Should raise MiddlewareException due to mismatch during agent creation
with pytest.raises(MiddlewareException, match="Middleware type mismatch"):
@agent_middleware # type: ignore[arg-type]
async def mismatched_middleware(
@@ -1304,8 +1404,8 @@ class TestMiddlewareDecoratorLogic:
async def no_info_middleware(context: Any, next: Any) -> None: # No decorator, no type
await next(context)
# Should raise ValueError
with pytest.raises(ValueError, match="Cannot determine middleware type"):
# Should raise MiddlewareException
with pytest.raises(MiddlewareException, match="Cannot determine middleware type"):
agent = ChatAgent(chat_client=chat_client, middleware=[no_info_middleware])
await agent.run([ChatMessage(role=Role.USER, text="test")])
@@ -1313,8 +1413,8 @@ class TestMiddlewareDecoratorLogic:
"""Test that middleware with insufficient parameters raises an error."""
from agent_framework import ChatAgent, agent_middleware
# Should raise ValueError about insufficient parameters
with pytest.raises(ValueError, match="must have at least 2 parameters"):
# Should raise MiddlewareException about insufficient parameters
with pytest.raises(MiddlewareException, match="must have at least 2 parameters"):
@agent_middleware # type: ignore[arg-type]
async def insufficient_params_middleware(context: Any) -> None: # Missing 'next' parameter
@@ -1340,3 +1440,359 @@ class TestMiddlewareDecoratorLogic:
assert hasattr(test_function_middleware, "_middleware_type")
assert test_function_middleware._middleware_type == MiddlewareType.FUNCTION # type: ignore[attr-defined]
class TestChatAgentChatMiddleware:
"""Test cases for chat middleware integration with ChatAgent."""
async def test_class_based_chat_middleware_with_chat_agent(self) -> None:
"""Test class-based chat middleware with ChatAgent."""
execution_order: list[str] = []
class TrackingChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("chat_middleware_before")
await next(context)
execution_order.append("chat_middleware_after")
# Create ChatAgent with chat middleware
chat_client = MockBaseChatClient()
middleware = TrackingChatMiddleware()
agent = ChatAgent(chat_client=chat_client, middleware=[middleware])
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert response.messages[0].role == Role.ASSISTANT
assert "test response" in response.messages[0].text
assert execution_order == ["chat_middleware_before", "chat_middleware_after"]
async def test_function_based_chat_middleware_with_chat_agent(self) -> None:
"""Test function-based chat middleware with ChatAgent."""
execution_order: list[str] = []
async def tracking_chat_middleware(
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
) -> None:
execution_order.append("chat_middleware_before")
await next(context)
execution_order.append("chat_middleware_after")
# Create ChatAgent with function-based chat middleware
chat_client = MockBaseChatClient()
agent = ChatAgent(chat_client=chat_client, middleware=[tracking_chat_middleware])
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert response.messages[0].role == Role.ASSISTANT
assert "test response" in response.messages[0].text
assert execution_order == ["chat_middleware_before", "chat_middleware_after"]
async def test_chat_middleware_can_modify_messages(self) -> None:
"""Test that chat middleware can modify messages before sending to model."""
@chat_middleware
async def message_modifier_middleware(
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
) -> None:
# Modify the first message by adding a prefix
if context.messages and len(context.messages) > 0:
original_text = context.messages[0].text or ""
context.messages[0] = ChatMessage(role=context.messages[0].role, text=f"MODIFIED: {original_text}")
await next(context)
# Create ChatAgent with message-modifying middleware
chat_client = MockBaseChatClient()
agent = ChatAgent(chat_client=chat_client, middleware=[message_modifier_middleware])
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify that the message was modified (MockBaseChatClient echoes back the input)
assert response is not None
assert len(response.messages) > 0
assert "MODIFIED: test message" in response.messages[0].text
async def test_chat_middleware_can_override_response(self) -> None:
"""Test that chat middleware can override the response."""
@chat_middleware
async def response_override_middleware(
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
) -> None:
# Override the response without calling next()
context.result = ChatResponse(
messages=[ChatMessage(role=Role.ASSISTANT, text="Middleware overridden response")],
response_id="middleware-response-123",
)
context.terminate = True
# Create ChatAgent with response-overriding middleware
chat_client = MockBaseChatClient()
agent = ChatAgent(chat_client=chat_client, middleware=[response_override_middleware])
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify that the response was overridden
assert response is not None
assert len(response.messages) > 0
assert response.messages[0].text == "Middleware overridden response"
assert response.response_id == "middleware-response-123"
async def test_multiple_chat_middleware_execution_order(self) -> None:
"""Test that multiple chat middleware execute in the correct order."""
execution_order: list[str] = []
@chat_middleware
async def first_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("first_before")
await next(context)
execution_order.append("first_after")
@chat_middleware
async def second_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("second_before")
await next(context)
execution_order.append("second_after")
# Create ChatAgent with multiple chat middleware
chat_client = MockBaseChatClient()
agent = ChatAgent(chat_client=chat_client, middleware=[first_middleware, second_middleware])
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify response
assert response is not None
assert execution_order == ["first_before", "second_before", "second_after", "first_after"]
async def test_chat_middleware_with_streaming(self) -> None:
"""Test chat middleware with streaming responses."""
execution_order: list[str] = []
streaming_flags: list[bool] = []
class StreamingTrackingChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("streaming_chat_before")
streaming_flags.append(context.is_streaming)
await next(context)
execution_order.append("streaming_chat_after")
# Create ChatAgent with chat middleware
chat_client = MockBaseChatClient()
agent = ChatAgent(chat_client=chat_client, middleware=[StreamingTrackingChatMiddleware()])
# Set up mock streaming responses
chat_client.streaming_responses = [
[
ChatResponseUpdate(contents=[TextContent(text="Stream")], role=Role.ASSISTANT),
ChatResponseUpdate(contents=[TextContent(text=" response")], role=Role.ASSISTANT),
]
]
# Execute streaming
messages = [ChatMessage(role=Role.USER, text="test message")]
updates: list[AgentRunResponseUpdate] = []
async for update in agent.run_stream(messages):
updates.append(update)
# Verify streaming response
assert len(updates) >= 1 # At least some updates
assert execution_order == ["streaming_chat_before", "streaming_chat_after"]
# Verify streaming flag was set (at least one True)
assert True in streaming_flags
async def test_chat_middleware_termination_before_execution(self) -> None:
"""Test that chat middleware can terminate execution before calling next()."""
execution_order: list[str] = []
class PreTerminationChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("middleware_before")
context.terminate = True
# Set a custom response since we're terminating
context.result = ChatResponse(
messages=[ChatMessage(role=Role.ASSISTANT, text="Terminated by middleware")]
)
# We call next() but since terminate=True, execution should stop
await next(context)
execution_order.append("middleware_after")
# Create ChatAgent with terminating middleware
chat_client = MockBaseChatClient()
agent = ChatAgent(chat_client=chat_client, middleware=[PreTerminationChatMiddleware()])
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify response was from middleware
assert response is not None
assert len(response.messages) > 0
assert response.messages[0].text == "Terminated by middleware"
assert execution_order == ["middleware_before", "middleware_after"]
async def test_chat_middleware_termination_after_execution(self) -> None:
"""Test that chat middleware can terminate execution after calling next()."""
execution_order: list[str] = []
class PostTerminationChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("middleware_before")
await next(context)
execution_order.append("middleware_after")
context.terminate = True
# Create ChatAgent with terminating middleware
chat_client = MockBaseChatClient()
agent = ChatAgent(chat_client=chat_client, middleware=[PostTerminationChatMiddleware()])
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify response is from actual execution
assert response is not None
assert len(response.messages) > 0
assert "test response" in response.messages[0].text
assert execution_order == ["middleware_before", "middleware_after"]
async def test_combined_middleware(self) -> None:
"""Test ChatAgent with combined middleware types."""
execution_order: list[str] = []
async def agent_middleware(
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
execution_order.append("agent_middleware_before")
await next(context)
execution_order.append("agent_middleware_after")
async def chat_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("chat_middleware_before")
await next(context)
execution_order.append("chat_middleware_after")
async def function_middleware(
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
execution_order.append("function_middleware_before")
await next(context)
execution_order.append("function_middleware_after")
# Set up mock to return a function call first, then a regular response
function_call_response = ChatResponse(
messages=[
ChatMessage(
role=Role.ASSISTANT,
contents=[
FunctionCallContent(
call_id="call_456",
name="sample_tool_function",
arguments='{"location": "San Francisco"}',
)
],
)
]
)
final_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Final response")])
chat_client = use_function_invocation(MockBaseChatClient)()
chat_client.run_responses = [function_call_response, final_response]
# Create ChatAgent with function middleware and tools
agent = ChatAgent(
chat_client=chat_client,
middleware=[chat_middleware, function_middleware, agent_middleware],
tools=[sample_tool_function],
)
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="Get weather for San Francisco")]
response = await agent.run(messages)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert chat_client.call_count == 2 # Two calls: one for function call, one for final response
# Verify function middleware was executed
assert execution_order == [
"agent_middleware_before",
"chat_middleware_before",
"chat_middleware_after",
"function_middleware_before",
"function_middleware_after",
"chat_middleware_before",
"chat_middleware_after",
"agent_middleware_after",
]
# Verify function call and result are in the response
all_contents = [content for message in response.messages for content in message.contents]
function_calls = [c for c in all_contents if isinstance(c, FunctionCallContent)]
function_results = [c for c in all_contents if isinstance(c, FunctionResultContent)]
assert len(function_calls) == 1
assert len(function_results) == 1
assert function_calls[0].name == "sample_tool_function"
assert function_results[0].call_id == function_calls[0].call_id
async def test_agent_middleware_can_access_and_override_custom_kwargs(self) -> None:
"""Test that agent middleware can access and override custom parameters like temperature."""
captured_kwargs: dict[str, Any] = {}
modified_kwargs: dict[str, Any] = {}
@agent_middleware
async def kwargs_middleware(
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
# Capture the original kwargs
captured_kwargs.update(context.kwargs)
# Modify some kwargs
context.kwargs["temperature"] = 0.9
context.kwargs["max_tokens"] = 500
context.kwargs["new_param"] = "added_by_middleware"
# Store modified kwargs for verification
modified_kwargs.update(context.kwargs)
await next(context)
# Create ChatAgent with agent middleware
chat_client = MockBaseChatClient()
agent = ChatAgent(chat_client=chat_client, middleware=[kwargs_middleware])
# Execute the agent with custom parameters
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages, temperature=0.7, max_tokens=100, custom_param="test_value")
# Verify response
assert response is not None
assert len(response.messages) > 0
# Verify middleware captured the original kwargs
assert captured_kwargs["temperature"] == 0.7
assert captured_kwargs["max_tokens"] == 100
assert captured_kwargs["custom_param"] == "test_value"
# Verify middleware could modify the kwargs
assert modified_kwargs["temperature"] == 0.9
assert modified_kwargs["max_tokens"] == 500
assert modified_kwargs["new_param"] == "added_by_middleware"
assert modified_kwargs["custom_param"] == "test_value" # Should still be there
@@ -0,0 +1,436 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import Awaitable, Callable
from typing import Any
from agent_framework import (
ChatAgent,
ChatContext,
ChatMessage,
ChatMiddleware,
ChatResponse,
FunctionCallContent,
FunctionInvocationContext,
Role,
chat_middleware,
function_middleware,
use_function_invocation,
)
from .conftest import MockBaseChatClient
class TestChatMiddleware:
"""Test cases for chat middleware functionality."""
async def test_class_based_chat_middleware(self, chat_client_base: "MockBaseChatClient") -> None:
"""Test class-based chat middleware with ChatClient."""
execution_order: list[str] = []
class LoggingChatMiddleware(ChatMiddleware):
async def process(
self,
context: ChatContext,
next: Callable[[ChatContext], Awaitable[None]],
) -> None:
execution_order.append("chat_middleware_before")
await next(context)
execution_order.append("chat_middleware_after")
# Add middleware to chat client
chat_client_base.middleware = [LoggingChatMiddleware()]
# Execute chat client directly
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await chat_client_base.get_response(messages)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert response.messages[0].role == Role.ASSISTANT
# Verify middleware execution order
assert execution_order == ["chat_middleware_before", "chat_middleware_after"]
async def test_function_based_chat_middleware(self, chat_client_base: "MockBaseChatClient") -> None:
"""Test function-based chat middleware with ChatClient."""
execution_order: list[str] = []
@chat_middleware
async def logging_chat_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("function_middleware_before")
await next(context)
execution_order.append("function_middleware_after")
# Add middleware to chat client
chat_client_base.middleware = [logging_chat_middleware]
# Execute chat client directly
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await chat_client_base.get_response(messages)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert response.messages[0].role == Role.ASSISTANT
# Verify middleware execution order
assert execution_order == ["function_middleware_before", "function_middleware_after"]
async def test_chat_middleware_can_modify_messages(self, chat_client_base: "MockBaseChatClient") -> None:
"""Test that chat middleware can modify messages before sending to model."""
@chat_middleware
async def message_modifier_middleware(
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
) -> None:
# Modify the first message by adding a prefix
if context.messages and len(context.messages) > 0:
original_text = context.messages[0].text or ""
context.messages[0] = ChatMessage(role=context.messages[0].role, text=f"MODIFIED: {original_text}")
await next(context)
# Add middleware to chat client
chat_client_base.middleware = [message_modifier_middleware]
# Execute chat client
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await chat_client_base.get_response(messages)
# Verify that the message was modified (MockChatClient echoes back the input)
assert response is not None
assert len(response.messages) > 0
# The mock client should receive the modified message
assert "MODIFIED: test message" in response.messages[0].text
async def test_chat_middleware_can_override_response(self, chat_client_base: "MockBaseChatClient") -> None:
"""Test that chat middleware can override the response."""
@chat_middleware
async def response_override_middleware(
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
) -> None:
# Override the response without calling next()
context.result = ChatResponse(
messages=[ChatMessage(role=Role.ASSISTANT, text="Middleware overridden response")],
response_id="middleware-response-123",
)
context.terminate = True
# Add middleware to chat client
chat_client_base.middleware = [response_override_middleware]
# Execute chat client
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await chat_client_base.get_response(messages)
# Verify that the response was overridden
assert response is not None
assert len(response.messages) > 0
assert response.messages[0].text == "Middleware overridden response"
assert response.response_id == "middleware-response-123"
async def test_multiple_chat_middleware_execution_order(self, chat_client_base: "MockBaseChatClient") -> None:
"""Test that multiple chat middleware execute in the correct order."""
execution_order: list[str] = []
@chat_middleware
async def first_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("first_before")
await next(context)
execution_order.append("first_after")
@chat_middleware
async def second_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("second_before")
await next(context)
execution_order.append("second_after")
# Add middleware to chat client (order should be preserved)
chat_client_base.middleware = [first_middleware, second_middleware]
# Execute chat client
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await chat_client_base.get_response(messages)
# Verify response
assert response is not None
# Verify middleware execution order (nested execution)
expected_order = ["first_before", "second_before", "second_after", "first_after"]
assert execution_order == expected_order
async def test_chat_agent_with_chat_middleware(self) -> None:
"""Test ChatAgent with chat middleware specified at agent level."""
execution_order: list[str] = []
@chat_middleware
async def agent_level_chat_middleware(
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
) -> None:
execution_order.append("agent_chat_middleware_before")
await next(context)
execution_order.append("agent_chat_middleware_after")
chat_client = MockBaseChatClient()
# Create ChatAgent with chat middleware
agent = ChatAgent(chat_client=chat_client, middleware=[agent_level_chat_middleware])
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert response.messages[0].role == Role.ASSISTANT
# Verify middleware execution order
assert execution_order == ["agent_chat_middleware_before", "agent_chat_middleware_after"]
async def test_chat_agent_with_multiple_chat_middleware(self, chat_client_base: "MockBaseChatClient") -> None:
"""Test that ChatAgent can have multiple chat middleware."""
execution_order: list[str] = []
@chat_middleware
async def first_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("first_before")
await next(context)
execution_order.append("first_after")
@chat_middleware
async def second_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("second_before")
await next(context)
execution_order.append("second_after")
# Create ChatAgent with multiple chat middleware
agent = ChatAgent(chat_client=chat_client_base, middleware=[first_middleware, second_middleware])
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify response
assert response is not None
# Verify both middleware executed (nested execution order)
expected_order = ["first_before", "second_before", "second_after", "first_after"]
assert execution_order == expected_order
async def test_chat_middleware_with_streaming(self, chat_client_base: "MockBaseChatClient") -> None:
"""Test chat middleware with streaming responses."""
execution_order: list[str] = []
@chat_middleware
async def streaming_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("streaming_before")
# Verify it's a streaming context
assert context.is_streaming is True
await next(context)
execution_order.append("streaming_after")
# Add middleware to chat client
chat_client_base.middleware = [streaming_middleware]
# Execute streaming response
messages = [ChatMessage(role=Role.USER, text="test message")]
updates: list[object] = []
async for update in chat_client_base.get_streaming_response(messages):
updates.append(update)
# Verify we got updates
assert len(updates) > 0
# Verify middleware executed
assert execution_order == ["streaming_before", "streaming_after"]
async def test_run_level_middleware_isolation(self, chat_client_base: "MockBaseChatClient") -> None:
"""Test that run-level middleware is isolated and doesn't persist across calls."""
execution_count = {"count": 0}
@chat_middleware
async def counting_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_count["count"] += 1
await next(context)
# First call with run-level middleware
messages = [ChatMessage(role=Role.USER, text="first message")]
response1 = await chat_client_base.get_response(messages, middleware=[counting_middleware])
assert response1 is not None
assert execution_count["count"] == 1
# Second call WITHOUT run-level middleware - should not execute the middleware
messages = [ChatMessage(role=Role.USER, text="second message")]
response2 = await chat_client_base.get_response(messages)
assert response2 is not None
assert execution_count["count"] == 1 # Should still be 1, not 2
# Third call with run-level middleware again - should execute
messages = [ChatMessage(role=Role.USER, text="third message")]
response3 = await chat_client_base.get_response(messages, middleware=[counting_middleware])
assert response3 is not None
assert execution_count["count"] == 2 # Should be 2 now
async def test_chat_client_middleware_can_access_and_override_custom_kwargs(
self, chat_client_base: "MockBaseChatClient"
) -> None:
"""Test that chat client middleware can access and override custom parameters like temperature."""
captured_kwargs: dict[str, Any] = {}
modified_kwargs: dict[str, Any] = {}
@chat_middleware
async def kwargs_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
# Capture the original kwargs
captured_kwargs.update(context.kwargs)
# Modify some kwargs
context.kwargs["temperature"] = 0.9
context.kwargs["max_tokens"] = 500
context.kwargs["new_param"] = "added_by_middleware"
# Store modified kwargs for verification
modified_kwargs.update(context.kwargs)
await next(context)
# Add middleware to chat client
chat_client_base.middleware = [kwargs_middleware]
# Execute chat client with custom parameters
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await chat_client_base.get_response(
messages, temperature=0.7, max_tokens=100, custom_param="test_value"
)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert captured_kwargs["temperature"] == 0.7
assert captured_kwargs["max_tokens"] == 100
assert captured_kwargs["custom_param"] == "test_value"
# Verify middleware could modify the kwargs
assert modified_kwargs["temperature"] == 0.9
assert modified_kwargs["max_tokens"] == 500
assert modified_kwargs["new_param"] == "added_by_middleware"
assert modified_kwargs["custom_param"] == "test_value" # Should still be there
async def test_function_middleware_registration_on_chat_client(self) -> None:
"""Test function middleware registered on ChatClient is executed during function calls."""
execution_order: list[str] = []
@function_middleware
async def test_function_middleware(
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
execution_order.append(f"function_middleware_before_{context.function.name}")
await next(context)
execution_order.append(f"function_middleware_after_{context.function.name}")
# Define a simple tool function
def sample_tool(location: str) -> str:
"""Get weather for a location."""
return f"Weather in {location}: sunny"
# Create function-invocation enabled chat client
chat_client = use_function_invocation(MockBaseChatClient)()
# Set function middleware directly on the chat client
chat_client.middleware = [test_function_middleware]
# Prepare responses that will trigger function invocation
function_call_response = ChatResponse(
messages=[
ChatMessage(
role=Role.ASSISTANT,
contents=[
FunctionCallContent(
call_id="call_1",
name="sample_tool",
arguments={"location": "San Francisco"},
)
],
)
]
)
final_response = ChatResponse(
messages=[ChatMessage(role=Role.ASSISTANT, text="Based on the weather data, it's sunny!")]
)
chat_client.run_responses = [function_call_response, final_response]
# Execute the chat client directly with tools - this should trigger function invocation and middleware
messages = [ChatMessage(role=Role.USER, text="What's the weather in San Francisco?")]
response = await chat_client.get_response(messages, tools=[sample_tool])
# Verify response
assert response is not None
assert len(response.messages) > 0
assert chat_client.call_count == 2 # Two calls: function call + final response
# Verify function middleware was executed
assert execution_order == [
"function_middleware_before_sample_tool",
"function_middleware_after_sample_tool",
]
async def test_run_level_function_middleware(self) -> None:
"""Test that function middleware passed to get_response method is also invoked."""
execution_order: list[str] = []
@function_middleware
async def run_level_function_middleware(
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
execution_order.append("run_level_function_middleware_before")
await next(context)
execution_order.append("run_level_function_middleware_after")
# Define a simple tool function
def sample_tool(location: str) -> str:
"""Get weather for a location."""
return f"Weather in {location}: sunny"
# Create function-invocation enabled chat client
chat_client = use_function_invocation(MockBaseChatClient)()
# Prepare responses that will trigger function invocation
function_call_response = ChatResponse(
messages=[
ChatMessage(
role=Role.ASSISTANT,
contents=[
FunctionCallContent(
call_id="call_2",
name="sample_tool",
arguments={"location": "New York"},
)
],
)
]
)
final_response = ChatResponse(
messages=[ChatMessage(role=Role.ASSISTANT, text="The weather information has been retrieved!")]
)
chat_client.run_responses = [function_call_response, final_response]
# Execute the chat client directly with run-level middleware and tools
messages = [ChatMessage(role=Role.USER, text="What's the weather in New York?")]
response = await chat_client.get_response(
messages, tools=[sample_tool], middleware=[run_level_function_middleware]
)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert chat_client.call_count == 2 # Two calls: function call + final response
# Verify run-level function middleware was executed once (during function invocation)
assert execution_order == [
"run_level_function_middleware_before",
"run_level_function_middleware_after",
]
@@ -13,6 +13,7 @@ from agent_framework import (
ChatResponseUpdate,
Role,
TextContent,
use_chat_middleware,
use_function_invocation,
)
@@ -37,6 +38,7 @@ custom client with ChatAgent through the create_agent() method.
@use_function_invocation
@use_chat_middleware
class EchoingChatClient(BaseChatClient):
"""A custom chat client that echoes messages back with modifications.
@@ -0,0 +1,45 @@
# Middleware Examples
This folder contains examples demonstrating various middleware patterns with the Agent Framework. Middleware allows you to intercept and modify behavior at different execution stages, including agent runs, function calls, and chat interactions.
## Examples
| File | Description |
|------|-------------|
| [`function_based_middleware.py`](function_based_middleware.py) | Demonstrates how to implement middleware using simple async functions instead of classes. Shows security validation, logging, and performance monitoring middleware. Function-based middleware is ideal for simple, stateless operations and provides a lightweight approach. |
| [`class_based_middleware.py`](class_based_middleware.py) | Shows how to implement middleware using class-based approach by inheriting from `AgentMiddleware` and `FunctionMiddleware` base classes. Includes security checks for sensitive information and detailed function execution logging with timing. |
| [`decorator_middleware.py`](decorator_middleware.py) | Demonstrates how to use `@agent_middleware` and `@function_middleware` decorators to explicitly mark middleware functions without requiring type annotations. Shows different middleware detection scenarios and explicit decorator usage. |
| [`middleware_termination.py`](middleware_termination.py) | Shows how middleware can terminate execution using the `context.terminate` flag. Includes examples of pre-termination (prevents agent processing) and post-termination (allows processing but stops further execution). Useful for security checks, rate limiting, or early exit conditions. |
| [`exception_handling_with_middleware.py`](exception_handling_with_middleware.py) | Demonstrates how to use middleware for centralized exception handling in function calls. Shows how to catch exceptions from functions, provide graceful error responses, and override function results when errors occur to provide user-friendly messages. |
| [`override_result_with_middleware.py`](override_result_with_middleware.py) | Shows how to use middleware to intercept and modify function results after execution, supporting both regular and streaming agent responses. Demonstrates result filtering, formatting, enhancement, and custom streaming response generation. |
| [`shared_state_middleware.py`](shared_state_middleware.py) | Demonstrates how to implement function-based middleware within a class to share state between multiple middleware functions. Shows how middleware can work together by sharing state, including call counting and result enhancement. |
| [`agent_and_run_level_middleware.py`](agent_and_run_level_middleware.py) | Explains the difference between agent-level middleware (applied to ALL runs of the agent) and run-level middleware (applied to specific runs only). Shows security validation, performance monitoring, and context-specific middleware patterns. |
| [`chat_middleware.py`](chat_middleware.py) | Demonstrates how to use chat middleware to observe and override inputs sent to AI models. Shows how to intercept chat requests, log and modify input messages, and override entire responses before they reach the underlying AI service. |
## Key Concepts
### Middleware Types
- **Agent Middleware**: Intercepts agent run execution, allowing you to modify requests and responses
- **Function Middleware**: Intercepts function calls within agents, enabling logging, validation, and result modification
- **Chat Middleware**: Intercepts chat requests sent to AI models, allowing input/output transformation
### Implementation Approaches
- **Function-based**: Simple async functions for lightweight, stateless operations
- **Class-based**: Inherit from base middleware classes for complex, stateful operations
- **Decorator-based**: Use decorators for explicit middleware marking
### Common Use Cases
- **Security**: Validate requests, block sensitive information, implement access controls
- **Logging**: Track execution timing, log parameters and results, monitor performance
- **Error Handling**: Catch exceptions, provide graceful fallbacks, implement retry logic
- **Result Transformation**: Filter, format, or enhance function outputs
- **State Management**: Share data between middleware functions, maintain execution context
### Execution Control
- **Termination**: Use `context.terminate` to stop execution early
- **Result Override**: Modify or replace function/agent results
- **Streaming Support**: Handle both regular and streaming responses
@@ -0,0 +1,245 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated
from agent_framework import (
ChatContext,
ChatMessage,
ChatMiddleware,
ChatResponse,
Role,
chat_middleware,
)
from agent_framework.azure import AzureAIAgentClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field
"""
Chat Middleware Example
This sample demonstrates how to use chat middleware to observe and override
inputs sent to AI models. Chat middleware intercepts chat requests before they reach
the underlying AI service, allowing you to:
1. Observe and log input messages
2. Modify input messages before sending to AI
3. Override the entire response
The example covers:
- Class-based chat middleware inheriting from ChatMiddleware
- Function-based chat middleware with @chat_middleware decorator
- Middleware registration at agent level (applies to all runs)
- Middleware registration at run level (applies to specific run only)
"""
def get_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
"""Get the weather for a given location."""
conditions = ["sunny", "cloudy", "rainy", "stormy"]
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
class InputObserverMiddleware(ChatMiddleware):
"""Class-based middleware that observes and modifies input messages."""
def __init__(self, replacement: str | None = None):
"""Initialize with a replacement for user messages."""
self.replacement = replacement
async def process(
self,
context: ChatContext,
next: Callable[[ChatContext], Awaitable[None]],
) -> None:
"""Observe and modify input messages before they are sent to AI."""
print("[InputObserverMiddleware] Observing input messages:")
for i, message in enumerate(context.messages):
content = message.text if message.text else str(message.contents)
print(f" Message {i + 1} ({message.role.value}): {content}")
print(f"[InputObserverMiddleware] Total messages: {len(context.messages)}")
# Modify user messages by creating new messages with enhanced text
modified_messages: list[ChatMessage] = []
modified_count = 0
for message in context.messages:
if message.role == Role.USER and message.text:
original_text = message.text
updated_text = original_text
if self.replacement:
updated_text = self.replacement
print(f"[InputObserverMiddleware] Updated: '{original_text}' -> '{updated_text}'")
modified_message = ChatMessage(role=message.role, text=updated_text)
modified_messages.append(modified_message)
modified_count += 1
else:
modified_messages.append(message)
# Replace messages in context
context.messages[:] = modified_messages
# Continue to next middleware or AI execution
await next(context)
# Observe that processing is complete
print("[InputObserverMiddleware] Processing completed")
@chat_middleware
async def security_and_override_middleware(
context: ChatContext,
next: Callable[[ChatContext], Awaitable[None]],
) -> None:
"""Function-based middleware that implements security filtering and response override."""
print("[SecurityMiddleware] Processing input...")
# Security check - block sensitive information
blocked_terms = ["password", "secret", "api_key", "token"]
for message in context.messages:
if message.text:
message_lower = message.text.lower()
for term in blocked_terms:
if term in message_lower:
print(f"[SecurityMiddleware] BLOCKED: Found '{term}' in message")
# Override the response instead of calling AI
context.result = ChatResponse(
messages=[
ChatMessage(
role=Role.ASSISTANT,
text="I cannot process requests containing sensitive information. "
"Please rephrase your question without including passwords, secrets, or other "
"sensitive data.",
)
]
)
# Set terminate flag to stop execution
context.terminate = True
return
# Continue to next middleware or AI execution
await next(context)
async def class_based_chat_middleware() -> None:
"""Demonstrate class-based middleware at agent level."""
print("\n" + "=" * 60)
print("Class-based Chat Middleware (Agent Level)")
print("=" * 60)
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
# authentication option.
async with (
AzureCliCredential() as credential,
AzureAIAgentClient(async_credential=credential).create_agent(
name="EnhancedChatAgent",
instructions="You are a helpful AI assistant.",
# Register class-based middleware at agent level (applies to all runs)
middleware=InputObserverMiddleware(),
tools=get_weather,
) as agent,
):
query = "What's the weather in Seattle?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Final Response: {result.text if result.text else 'No response'}")
async def function_based_chat_middleware() -> None:
"""Demonstrate function-based middleware at agent level."""
print("\n" + "=" * 60)
print("Function-based Chat Middleware (Agent Level)")
print("=" * 60)
async with (
AzureCliCredential() as credential,
AzureAIAgentClient(async_credential=credential).create_agent(
name="FunctionMiddlewareAgent",
instructions="You are a helpful AI assistant.",
# Register function-based middleware at agent level
middleware=security_and_override_middleware,
) as agent,
):
# Scenario with normal query
print("\n--- Scenario 1: Normal Query ---")
query = "Hello, how are you?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Final Response: {result.text if result.text else 'No response'}")
# Scenario with security violation
print("\n--- Scenario 2: Security Violation ---")
query = "What is my password for this account?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Final Response: {result.text if result.text else 'No response'}")
async def run_level_middleware() -> None:
"""Demonstrate middleware registration at run level."""
print("\n" + "=" * 60)
print("Run-level Chat Middleware")
print("=" * 60)
async with (
AzureCliCredential() as credential,
AzureAIAgentClient(async_credential=credential).create_agent(
name="RunLevelAgent",
instructions="You are a helpful AI assistant.",
tools=get_weather,
# No middleware at agent level
) as agent,
):
# Scenario 1: Run without any middleware
print("\n--- Scenario 1: No Middleware ---")
query = "What's the weather in Tokyo?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Response: {result.text if result.text else 'No response'}")
# Scenario 2: Run with specific middleware for this call only (both enhancement and security)
print("\n--- Scenario 2: With Run-level Middleware ---")
print(f"User: {query}")
result = await agent.run(
query,
middleware=[
InputObserverMiddleware(replacement="What's the weather in Madrid?"),
security_and_override_middleware,
],
)
print(f"Response: {result.text if result.text else 'No response'}")
# Scenario 3: Security test with run-level middleware
print("\n--- Scenario 3: Security Test with Run-level Middleware ---")
query = "Can you help me with my secret API key?"
print(f"User: {query}")
result = await agent.run(
query,
middleware=security_and_override_middleware,
)
print(f"Response: {result.text if result.text else 'No response'}")
async def main() -> None:
"""Run all chat middleware examples."""
print("Chat Middleware Examples")
print("========================")
await class_based_chat_middleware()
await function_based_chat_middleware()
await run_level_middleware()
if __name__ == "__main__":
asyncio.run(main())
@@ -1,28 +1,37 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from collections.abc import Awaitable, Callable
from collections.abc import AsyncIterable, Awaitable, Callable
from random import randint
from typing import Annotated
from agent_framework import FunctionInvocationContext
from agent_framework import (
AgentRunContext,
AgentRunResponse,
AgentRunResponseUpdate,
ChatMessage,
Role,
TextContent,
)
from agent_framework.azure import AzureAIAgentClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field
"""
Result Override with Middleware
Result Override with Middleware (Regular and Streaming)
This sample demonstrates how to use middleware to intercept and modify function results
after execution. The example shows:
after execution, supporting both regular and streaming agent responses. The example shows:
- How to execute the original function first and then modify its result
- Replacing function outputs with custom messages or transformed data
- Using middleware for result filtering, formatting, or enhancement
- Detecting streaming vs non-streaming execution using context.is_streaming
- Overriding streaming results with custom async generators
The weather override middleware lets the original weather function execute normally,
then replaces its result with a custom "perfect weather" message, demonstrating
how middleware can be used for content filtering, A/B testing, or result enhancement.
then replaces its result with a custom "perfect weather" message. For streaming responses,
it creates a custom async generator that yields the override message in chunks.
"""
@@ -35,32 +44,39 @@ def get_weather(
async def weather_override_middleware(
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
function_name = context.function.name
"""Middleware that overrides weather results for both streaming and non-streaming cases."""
# Let the original function execute first
# Let the original agent execution complete first
await next(context)
# Override the result if it's a weather function
if function_name == "get_weather" and context.result is not None:
original_result = str(context.result)
print(f"[WeatherOverrideMiddleware] Original result: {original_result}")
# Check if there's a result to override (agent called weather function)
if context.result is not None:
# Create custom weather message
chunks = [
"Weather Advisory - ",
"due to special atmospheric conditions, ",
"all locations are experiencing perfect weather today! ",
"Temperature is a comfortable 22°C with gentle breezes. ",
"Perfect day for outdoor activities!",
]
# Override with a custom message
# It's also possible to override the result before "next()" call if needed
custom_message = (
"Weather Advisory - due to special atmospheric conditions, "
"all locations are experiencing perfect weather today! "
"Temperature is a comfortable 22°C with gentle breezes. "
"Perfect day for outdoor activities!"
)
context.result = custom_message
print(f"[WeatherOverrideMiddleware] Overriding with custom message: {custom_message}")
if context.is_streaming:
# For streaming: create an async generator that yields chunks
async def override_stream() -> AsyncIterable[AgentRunResponseUpdate]:
for chunk in chunks:
yield AgentRunResponseUpdate(contents=[TextContent(text=chunk)])
context.result = override_stream()
else:
# For non-streaming: just replace with the string message
custom_message = "".join(chunks)
context.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=custom_message)])
async def main() -> None:
"""Example demonstrating result override with middleware."""
"""Example demonstrating result override with middleware for both streaming and non-streaming."""
print("=== Result Override Middleware Example ===")
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
@@ -74,11 +90,22 @@ async def main() -> None:
middleware=weather_override_middleware,
) as agent,
):
# Non-streaming example
print("\n--- Non-streaming Example ---")
query = "What's the weather like in Seattle?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result}")
# Streaming example
print("\n--- Streaming Example ---")
query = "What's the weather like in Portland?"
print(f"User: {query}")
print("Agent: ", end="", flush=True)
async for chunk in agent.run_stream(query):
if chunk.text:
print(chunk.text, end="", flush=True)
if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,128 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated
from agent_framework import (
FunctionInvocationContext,
)
from agent_framework.azure import AzureAIAgentClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field
"""
Shared State Function-based Middleware Example
This sample demonstrates how to implement function-based middleware within a class to share state.
The example includes:
- A MiddlewareContainer class with two simple function middleware methods
- First middleware: Counts function calls and stores the count in shared state
- Second middleware: Uses the shared count to add call numbers to function results
This approach shows how middleware can work together by sharing state within the same class instance.
"""
def get_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
"""Get the weather for a given location."""
conditions = ["sunny", "cloudy", "rainy", "stormy"]
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
def get_time(
timezone: Annotated[str, Field(description="The timezone to get the time for.")] = "UTC",
) -> str:
"""Get the current time for a given timezone."""
import datetime
return f"The current time in {timezone} is {datetime.datetime.now().strftime('%H:%M:%S')}"
class MiddlewareContainer:
"""Container class that holds middleware functions with shared state."""
def __init__(self) -> None:
# Simple shared state: count function calls
self.call_count: int = 0
async def call_counter_middleware(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
"""First middleware: increments call count in shared state."""
# Increment the shared call count
self.call_count += 1
print(f"[CallCounter] This is function call #{self.call_count}")
# Call the next middleware/function
await next(context)
async def result_enhancer_middleware(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
"""Second middleware: uses shared call count to enhance function results."""
print(f"[ResultEnhancer] Current total calls so far: {self.call_count}")
# Call the next middleware/function
await next(context)
# After function execution, enhance the result using shared state
if context.result:
enhanced_result = f"[Call #{self.call_count}] {context.result}"
context.result = enhanced_result
print("[ResultEnhancer] Enhanced result with call number")
async def main() -> None:
"""Example demonstrating shared state function-based middleware."""
print("=== Shared State Function-based Middleware Example ===")
# Create middleware container with shared state
middleware_container = MiddlewareContainer()
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
# authentication option.
async with (
AzureCliCredential() as credential,
AzureAIAgentClient(async_credential=credential).create_agent(
name="UtilityAgent",
instructions="You are a helpful assistant that can provide weather information and current time.",
tools=[get_weather, get_time],
# Pass both middleware functions from the same container instance
# Order matters: counter runs first to increment count,
# then result enhancer uses the updated count
middleware=[
middleware_container.call_counter_middleware,
middleware_container.result_enhancer_middleware,
],
) as agent,
):
# Test multiple requests to see shared state in action
queries = [
"What's the weather like in New York?",
"What time is it in London?",
"What's the weather in Tokyo?",
]
for i, query in enumerate(queries, 1):
print(f"\n--- Query {i} ---")
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text if result.text else 'No response'}")
# Display final statistics
print("\n=== Final Statistics ===")
print(f"Total function calls made: {middleware_container.call_count}")
if __name__ == "__main__":
asyncio.run(main())