Python: Added chat middleware and more examples (#883)

* Added example with stateful middleware

* Added chat middleware

* Updated middleware example with override scenario

* Small revert

* Small fixes

* Added kwargs to context objects

* Added README

* Added function middleware to chat client

* Small refactoring

* Reverted example files

* Made MiddlewareWrapper generic

* Added Middleware exception

* Small refactoring

* Small fix
This commit is contained in:
Dmytro Struk
2025-09-26 08:10:56 -07:00
committed by GitHub
Unverified
parent 863c8d7471
commit eec7f192eb
19 changed files with 2667 additions and 267 deletions
@@ -34,6 +34,7 @@ from agent_framework import (
UsageContent,
UsageDetails,
get_logger,
use_chat_middleware,
use_function_invocation,
)
from agent_framework._pydantic import AFBaseSettings
@@ -123,6 +124,7 @@ TAzureAIAgentClient = TypeVar("TAzureAIAgentClient", bound="AzureAIAgentClient")
@use_function_invocation
@use_observability
@use_chat_middleware
class AzureAIAgentClient(BaseChatClient):
"""Azure AI Agent Chat client."""
@@ -10,7 +10,13 @@ from pydantic import BaseModel, Field
from ._logging import get_logger
from ._mcp import MCPTool
from ._memory import AggregateContextProvider, ContextProvider
from ._middleware import Middleware
from ._middleware import (
ChatMiddleware,
ChatMiddlewareCallable,
FunctionMiddleware,
FunctionMiddlewareCallable,
Middleware,
)
from ._pydantic import AFBaseModel
from ._threads import ChatMessageStore
from ._tools import ToolProtocol
@@ -189,6 +195,14 @@ class BaseChatClient(AFBaseModel, ABC):
"""Base class for chat clients."""
additional_properties: dict[str, Any] = Field(default_factory=dict)
middleware: (
ChatMiddleware
| ChatMiddlewareCallable
| FunctionMiddleware
| FunctionMiddlewareCallable
| list[ChatMiddleware | ChatMiddlewareCallable | FunctionMiddleware | FunctionMiddlewareCallable]
| None
) = None
OTEL_PROVIDER_NAME: str = "unknown"
# This is used for OTel setup, should be overridden in subclasses
@@ -346,13 +360,7 @@ class BaseChatClient(AFBaseModel, ABC):
prepped_messages = self.prepare_messages(messages)
self._prepare_tool_choice(chat_options=chat_options)
# Remove middleware pipeline from kwargs as it's only used by function invocation wrappers
if "_function_middleware_pipeline" in kwargs:
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "_function_middleware_pipeline"}
else:
filtered_kwargs = kwargs
return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **filtered_kwargs)
return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **kwargs)
async def get_streaming_response(
self,
@@ -432,14 +440,8 @@ class BaseChatClient(AFBaseModel, ABC):
prepped_messages = self.prepare_messages(messages)
self._prepare_tool_choice(chat_options=chat_options)
# Remove middleware pipeline from kwargs as it's only used by function invocation wrappers
if "_function_middleware_pipeline" in kwargs:
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "_function_middleware_pipeline"}
else:
filtered_kwargs = kwargs
async for update in self._inner_get_streaming_response(
messages=prepped_messages, chat_options=chat_options, **filtered_kwargs
messages=prepped_messages, chat_options=chat_options, **kwargs
):
yield update
File diff suppressed because it is too large Load Diff
+23 -4
View File
@@ -607,6 +607,7 @@ async def _auto_invoke_function(
middleware_context = FunctionInvocationContext(
function=tool,
arguments=args,
kwargs=custom_args or {},
)
async def final_function_handler(context_obj: Any) -> Any:
@@ -721,8 +722,16 @@ def _handle_function_calls_response(
**kwargs: Any,
) -> "ChatResponse":
from ._clients import prepare_messages
from ._middleware import extract_and_merge_function_middleware
from ._types import ChatMessage, ChatOptions, FunctionCallContent, FunctionResultContent
# Extract and merge function middleware from chat client with kwargs pipeline
extract_and_merge_function_middleware(self, kwargs)
# Extract the middleware pipeline before calling the underlying function
# because the underlying function may not preserve it in kwargs
stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline")
prepped_messages = prepare_messages(messages)
response: "ChatResponse | None" = None
fcc_messages: "list[ChatMessage]" = []
@@ -746,8 +755,9 @@ def _handle_function_calls_response(
if not tools and (chat_options := kwargs.get("chat_options")) and isinstance(chat_options, ChatOptions):
tools = chat_options.tools
if function_calls and tools:
# Extract function middleware pipeline from kwargs if available
middleware_pipeline = kwargs.get("_function_middleware_pipeline")
# Use the stored middleware pipeline instead of extracting from kwargs
# because kwargs may have been modified by the underlying function
middleware_pipeline = stored_middleware_pipeline
function_results = await execute_function_calls(
custom_args=kwargs,
attempt_idx=attempt_idx,
@@ -820,8 +830,16 @@ def _handle_function_calls_streaming_response(
) -> AsyncIterable["ChatResponseUpdate"]:
"""Wrap the inner get streaming response method to handle tool calls."""
from ._clients import prepare_messages
from ._middleware import extract_and_merge_function_middleware
from ._types import ChatMessage, ChatOptions, ChatResponse, ChatResponseUpdate, FunctionCallContent
# Extract and merge function middleware from chat client with kwargs pipeline
extract_and_merge_function_middleware(self, kwargs)
# Extract the middleware pipeline before calling the underlying function
# because the underlying function may not preserve it in kwargs
stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline")
prepped_messages = prepare_messages(messages)
for attempt_idx in range(max_iterations):
all_updates: list["ChatResponseUpdate"] = []
@@ -858,8 +876,9 @@ def _handle_function_calls_streaming_response(
tools = chat_options.tools
if function_calls and tools:
# Extract function middleware pipeline from kwargs if available
middleware_pipeline = kwargs.get("_function_middleware_pipeline")
# Use the stored middleware pipeline instead of extracting from kwargs
# because kwargs may have been modified by the underlying function
middleware_pipeline = stored_middleware_pipeline
function_results = await execute_function_calls(
custom_args=kwargs,
attempt_idx=attempt_idx,
@@ -13,16 +13,18 @@ from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from pydantic import SecretStr, ValidationError
from pydantic.networks import AnyUrl
from .._tools import use_function_invocation
from .._types import (
from agent_framework import (
ChatResponse,
ChatResponseUpdate,
CitationAnnotation,
TextContent,
use_chat_middleware,
use_function_invocation,
)
from ..exceptions import ServiceInitializationError
from ..observability import use_observability
from ..openai._chat_client import OpenAIBaseChatClient
from agent_framework.exceptions import ServiceInitializationError
from agent_framework.observability import use_observability
from agent_framework.openai._chat_client import OpenAIBaseChatClient
from ._shared import (
AzureOpenAIConfigMixin,
AzureOpenAISettings,
@@ -41,6 +43,7 @@ TAzureOpenAIChatClient = TypeVar("TAzureOpenAIChatClient", bound="AzureOpenAICha
@use_function_invocation
@use_observability
@use_chat_middleware
class AzureOpenAIChatClient(AzureOpenAIConfigMixin, OpenAIBaseChatClient):
"""Azure OpenAI Chat completion class."""
@@ -9,10 +9,11 @@ from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI
from pydantic import SecretStr, ValidationError
from pydantic.networks import AnyUrl
from .._tools import use_function_invocation
from ..exceptions import ServiceInitializationError
from ..observability import use_observability
from ..openai._responses_client import OpenAIBaseResponsesClient
from agent_framework import use_chat_middleware, use_function_invocation
from agent_framework.exceptions import ServiceInitializationError
from agent_framework.observability import use_observability
from agent_framework.openai._responses_client import OpenAIBaseResponsesClient
from ._shared import (
AzureOpenAIConfigMixin,
AzureOpenAISettings,
@@ -23,6 +24,7 @@ TAzureOpenAIResponsesClient = TypeVar("TAzureOpenAIResponsesClient", bound="Azur
@use_observability
@use_function_invocation
@use_chat_middleware
class AzureOpenAIResponsesClient(AzureOpenAIConfigMixin, OpenAIBaseResponsesClient):
"""Azure Responses completion class."""
@@ -128,3 +128,9 @@ class AdditionItemMismatch(AgentFrameworkException):
"""An error occurred while adding two types."""
pass
class MiddlewareException(AgentFrameworkException):
"""An error occurred during middleware execution."""
pass
@@ -21,6 +21,7 @@ from openai.types.beta.threads.runs import RunStep
from pydantic import Field, PrivateAttr, SecretStr, ValidationError
from .._clients import BaseChatClient
from .._middleware import use_chat_middleware
from .._tools import AIFunction, HostedCodeInterpreterTool, HostedFileSearchTool, use_function_invocation
from .._types import (
ChatMessage,
@@ -52,6 +53,7 @@ __all__ = ["OpenAIAssistantsClient"]
@use_function_invocation
@use_observability
@use_chat_middleware
class OpenAIAssistantsClient(OpenAIConfigMixin, BaseChatClient):
"""OpenAI Assistants client."""
@@ -18,6 +18,7 @@ from pydantic import BaseModel, SecretStr, ValidationError
from .._clients import BaseChatClient
from .._logging import get_logger
from .._middleware import use_chat_middleware
from .._tools import AIFunction, HostedWebSearchTool, ToolProtocol, use_function_invocation
from .._types import (
ChatMessage,
@@ -452,6 +453,7 @@ TOpenAIChatClient = TypeVar("TOpenAIChatClient", bound="OpenAIChatClient")
@use_function_invocation
@use_observability
@use_chat_middleware
class OpenAIChatClient(OpenAIConfigMixin, OpenAIBaseChatClient):
"""OpenAI Chat completion class."""
@@ -26,6 +26,7 @@ from pydantic import BaseModel, SecretStr, ValidationError
from .._clients import BaseChatClient
from .._logging import get_logger
from .._middleware import use_chat_middleware
from .._tools import (
AIFunction,
HostedCodeInterpreterTool,
@@ -933,6 +934,7 @@ TOpenAIResponsesClient = TypeVar("TOpenAIResponsesClient", bound="OpenAIResponse
@use_function_invocation
@use_observability
@use_chat_middleware
class OpenAIResponsesClient(OpenAIConfigMixin, OpenAIBaseResponsesClient):
"""OpenAI Responses client class."""
@@ -25,6 +25,7 @@ from agent_framework import (
TextContent,
ToolProtocol,
ai_function,
use_chat_middleware,
use_function_invocation,
)
@@ -111,11 +112,13 @@ class MockChatClient:
yield ChatResponseUpdate(contents=[TextContent(text="another update")], role="assistant")
@use_chat_middleware
class MockBaseChatClient(BaseChatClient):
"""Mock implementation of the BaseChatClient."""
run_responses: list[ChatResponse] = Field(default_factory=list)
streaming_responses: list[list[ChatResponseUpdate]] = Field(default_factory=list)
call_count: int = Field(default=0)
@override
async def _inner_get_response(
@@ -136,6 +139,7 @@ class MockBaseChatClient(BaseChatClient):
The chat response contents representing the response(s).
"""
logger.debug(f"Running base chat client inner, with: {messages=}, {chat_options=}, {kwargs=}")
self.call_count += 1
if not self.run_responses:
return ChatResponse(messages=ChatMessage(role="assistant", text=f"test response - {messages[0].text}"))
@@ -12,6 +12,8 @@ from agent_framework import (
AgentRunResponse,
AgentRunResponseUpdate,
ChatMessage,
ChatResponse,
ChatResponseUpdate,
Role,
TextContent,
)
@@ -19,11 +21,15 @@ from agent_framework._middleware import (
AgentMiddleware,
AgentMiddlewarePipeline,
AgentRunContext,
ChatContext,
ChatMiddleware,
ChatMiddlewarePipeline,
FunctionInvocationContext,
FunctionMiddleware,
FunctionMiddlewarePipeline,
)
from agent_framework._tools import AIFunction
from agent_framework._types import ChatOptions
class TestAgentRunContext:
@@ -74,6 +80,46 @@ class TestFunctionInvocationContext:
assert context.metadata == metadata
class TestChatContext:
"""Test cases for ChatContext."""
def test_init_with_defaults(self, mock_chat_client: Any) -> None:
"""Test ChatContext initialization with default values."""
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
assert context.chat_client is mock_chat_client
assert context.messages == messages
assert context.chat_options is chat_options
assert context.is_streaming is False
assert context.metadata == {}
assert context.result is None
assert context.terminate is False
def test_init_with_custom_values(self, mock_chat_client: Any) -> None:
"""Test ChatContext initialization with custom values."""
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions(temperature=0.5)
metadata = {"key": "value"}
context = ChatContext(
chat_client=mock_chat_client,
messages=messages,
chat_options=chat_options,
is_streaming=True,
metadata=metadata,
terminate=True,
)
assert context.chat_client is mock_chat_client
assert context.messages == messages
assert context.chat_options is chat_options
assert context.is_streaming is True
assert context.metadata == metadata
assert context.terminate is True
class TestAgentMiddlewarePipeline:
"""Test cases for AgentMiddlewarePipeline."""
@@ -410,6 +456,233 @@ class TestFunctionMiddlewarePipeline:
assert execution_order == ["test_before", "handler", "test_after"]
class TestChatMiddlewarePipeline:
"""Test cases for ChatMiddlewarePipeline."""
class PreNextTerminateChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
context.terminate = True
await next(context)
class PostNextTerminateChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
await next(context)
context.terminate = True
def test_init_empty(self) -> None:
"""Test ChatMiddlewarePipeline initialization with no middlewares."""
pipeline = ChatMiddlewarePipeline()
assert not pipeline.has_middlewares
def test_init_with_class_middleware(self) -> None:
"""Test ChatMiddlewarePipeline initialization with class-based middleware."""
middleware = TestChatMiddleware()
pipeline = ChatMiddlewarePipeline([middleware])
assert pipeline.has_middlewares
def test_init_with_function_middleware(self) -> None:
"""Test ChatMiddlewarePipeline initialization with function-based middleware."""
async def test_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
await next(context)
pipeline = ChatMiddlewarePipeline([test_middleware])
assert pipeline.has_middlewares
async def test_execute_no_middleware(self, mock_chat_client: Any) -> None:
"""Test pipeline execution with no middleware."""
pipeline = ChatMiddlewarePipeline()
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
expected_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
async def final_handler(ctx: ChatContext) -> ChatResponse:
return expected_response
result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
assert result == expected_response
async def test_execute_with_middleware(self, mock_chat_client: Any) -> None:
"""Test pipeline execution with middleware."""
execution_order: list[str] = []
class OrderTrackingChatMiddleware(ChatMiddleware):
def __init__(self, name: str):
self.name = name
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append(f"{self.name}_before")
await next(context)
execution_order.append(f"{self.name}_after")
middleware = OrderTrackingChatMiddleware("test")
pipeline = ChatMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
expected_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
async def final_handler(ctx: ChatContext) -> ChatResponse:
execution_order.append("handler")
return expected_response
result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
assert result == expected_response
assert execution_order == ["test_before", "handler", "test_after"]
async def test_execute_stream_no_middleware(self, mock_chat_client: Any) -> None:
"""Test pipeline streaming execution with no middleware."""
pipeline = ChatMiddlewarePipeline()
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[TextContent(text="chunk1")])
yield ChatResponseUpdate(contents=[TextContent(text="chunk2")])
updates: list[ChatResponseUpdate] = []
async for update in pipeline.execute_stream(mock_chat_client, messages, chat_options, context, final_handler):
updates.append(update)
assert len(updates) == 2
assert updates[0].text == "chunk1"
assert updates[1].text == "chunk2"
async def test_execute_stream_with_middleware(self, mock_chat_client: Any) -> None:
"""Test pipeline streaming execution with middleware."""
execution_order: list[str] = []
class StreamOrderTrackingChatMiddleware(ChatMiddleware):
def __init__(self, name: str):
self.name = name
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append(f"{self.name}_before")
await next(context)
execution_order.append(f"{self.name}_after")
middleware = StreamOrderTrackingChatMiddleware("test")
pipeline = ChatMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(
chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True
)
async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
execution_order.append("handler_start")
yield ChatResponseUpdate(contents=[TextContent(text="chunk1")])
yield ChatResponseUpdate(contents=[TextContent(text="chunk2")])
execution_order.append("handler_end")
updates: list[ChatResponseUpdate] = []
async for update in pipeline.execute_stream(mock_chat_client, messages, chat_options, context, final_handler):
updates.append(update)
assert len(updates) == 2
assert updates[0].text == "chunk1"
assert updates[1].text == "chunk2"
assert execution_order == ["test_before", "test_after", "handler_start", "handler_end"]
async def test_execute_with_pre_next_termination(self, mock_chat_client: Any) -> None:
"""Test pipeline execution with termination before next()."""
middleware = self.PreNextTerminateChatMiddleware()
pipeline = ChatMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
execution_order: list[str] = []
async def final_handler(ctx: ChatContext) -> ChatResponse:
# Handler should not be executed when terminated before next()
execution_order.append("handler")
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
response = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
assert response is None
assert context.terminate
# Handler should not be called when terminated before next()
assert execution_order == []
async def test_execute_with_post_next_termination(self, mock_chat_client: Any) -> None:
"""Test pipeline execution with termination after next()."""
middleware = self.PostNextTerminateChatMiddleware()
pipeline = ChatMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
execution_order: list[str] = []
async def final_handler(ctx: ChatContext) -> ChatResponse:
execution_order.append("handler")
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
response = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
assert response is not None
assert len(response.messages) == 1
assert response.messages[0].text == "response"
assert context.terminate
assert execution_order == ["handler"]
async def test_execute_stream_with_pre_next_termination(self, mock_chat_client: Any) -> None:
"""Test pipeline streaming execution with termination before next()."""
middleware = self.PreNextTerminateChatMiddleware()
pipeline = ChatMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(
chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True
)
execution_order: list[str] = []
async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
# Handler should not be executed when terminated before next()
execution_order.append("handler_start")
yield ChatResponseUpdate(contents=[TextContent(text="chunk1")])
yield ChatResponseUpdate(contents=[TextContent(text="chunk2")])
execution_order.append("handler_end")
updates: list[ChatResponseUpdate] = []
async for update in pipeline.execute_stream(mock_chat_client, messages, chat_options, context, final_handler):
updates.append(update)
assert context.terminate
# Handler should not be called when terminated before next()
assert execution_order == []
assert not updates
async def test_execute_stream_with_post_next_termination(self, mock_chat_client: Any) -> None:
"""Test pipeline streaming execution with termination after next()."""
middleware = self.PostNextTerminateChatMiddleware()
pipeline = ChatMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(
chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True
)
execution_order: list[str] = []
async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
execution_order.append("handler_start")
yield ChatResponseUpdate(contents=[TextContent(text="chunk1")])
yield ChatResponseUpdate(contents=[TextContent(text="chunk2")])
execution_order.append("handler_end")
updates: list[ChatResponseUpdate] = []
async for update in pipeline.execute_stream(mock_chat_client, messages, chat_options, context, final_handler):
updates.append(update)
assert len(updates) == 2
assert updates[0].text == "chunk1"
assert updates[1].text == "chunk2"
assert context.terminate
assert execution_order == ["handler_start", "handler_end"]
class TestClassBasedMiddleware:
"""Test cases for class-based middleware implementations."""
@@ -601,6 +874,37 @@ class TestMixedMiddleware:
assert result == "result"
assert execution_order == ["class_before", "function_before", "handler", "function_after", "class_after"]
async def test_mixed_chat_middleware(self, mock_chat_client: Any) -> None:
"""Test mixed class and function-based chat middleware."""
execution_order: list[str] = []
class ClassChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("class_before")
await next(context)
execution_order.append("class_after")
async def function_chat_middleware(
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
) -> None:
execution_order.append("function_before")
await next(context)
execution_order.append("function_after")
pipeline = ChatMiddlewarePipeline([ClassChatMiddleware(), function_chat_middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
async def final_handler(ctx: ChatContext) -> ChatResponse:
execution_order.append("handler")
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
assert result is not None
assert execution_order == ["class_before", "function_before", "handler", "function_after", "class_after"]
class TestMultipleMiddlewareOrdering:
"""Test cases for multiple middleware execution order."""
@@ -695,6 +999,52 @@ class TestMultipleMiddlewareOrdering:
expected_order = ["first_before", "second_before", "handler", "second_after", "first_after"]
assert execution_order == expected_order
async def test_chat_middleware_execution_order(self, mock_chat_client: Any) -> None:
"""Test that multiple chat middlewares execute in registration order."""
execution_order: list[str] = []
class FirstChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("first_before")
await next(context)
execution_order.append("first_after")
class SecondChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("second_before")
await next(context)
execution_order.append("second_after")
class ThirdChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("third_before")
await next(context)
execution_order.append("third_after")
middlewares = [FirstChatMiddleware(), SecondChatMiddleware(), ThirdChatMiddleware()]
pipeline = ChatMiddlewarePipeline(middlewares) # type: ignore
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
async def final_handler(ctx: ChatContext) -> ChatResponse:
execution_order.append("handler")
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
assert result is not None
expected_order = [
"first_before",
"second_before",
"third_before",
"handler",
"third_after",
"second_after",
"first_after",
]
assert execution_order == expected_order
class TestContextContentValidation:
"""Test cases for validating middleware context content."""
@@ -776,6 +1126,49 @@ class TestContextContentValidation:
result = await pipeline.execute(mock_function, arguments, context, final_handler)
assert result == "result"
async def test_chat_context_validation(self, mock_chat_client: Any) -> None:
"""Test that chat context contains expected data."""
class ChatContextValidationMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
# Verify context has all expected attributes
assert hasattr(context, "chat_client")
assert hasattr(context, "messages")
assert hasattr(context, "chat_options")
assert hasattr(context, "is_streaming")
assert hasattr(context, "metadata")
assert hasattr(context, "result")
assert hasattr(context, "terminate")
# Verify context content
assert context.chat_client is mock_chat_client
assert len(context.messages) == 1
assert context.messages[0].role == Role.USER
assert context.messages[0].text == "test"
assert context.is_streaming is False
assert isinstance(context.metadata, dict)
assert isinstance(context.chat_options, ChatOptions)
assert context.chat_options.temperature == 0.5
# Add custom metadata
context.metadata["validated"] = True
await next(context)
middleware = ChatContextValidationMiddleware()
pipeline = ChatMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions(temperature=0.5)
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
async def final_handler(ctx: ChatContext) -> ChatResponse:
# Verify metadata was set by middleware
assert ctx.metadata.get("validated") is True
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
assert result is not None
class TestStreamingScenarios:
"""Test cases for streaming and non-streaming scenarios."""
@@ -857,6 +1250,89 @@ class TestStreamingScenarios:
"stream_end",
]
async def test_chat_streaming_flag_validation(self, mock_chat_client: Any) -> None:
"""Test that is_streaming flag is correctly set for chat streaming calls."""
streaming_flags: list[bool] = []
class ChatStreamingFlagMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
streaming_flags.append(context.is_streaming)
await next(context)
middleware = ChatStreamingFlagMiddleware()
pipeline = ChatMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
# Test non-streaming
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
async def final_handler(ctx: ChatContext) -> ChatResponse:
streaming_flags.append(ctx.is_streaming)
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
# Test streaming
context_stream = ChatContext(
chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True
)
async def final_stream_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
streaming_flags.append(ctx.is_streaming)
yield ChatResponseUpdate(contents=[TextContent(text="chunk")])
updates: list[ChatResponseUpdate] = []
async for update in pipeline.execute_stream(
mock_chat_client, messages, chat_options, context_stream, final_stream_handler
):
updates.append(update)
# Verify flags: [non-streaming middleware, non-streaming handler, streaming middleware, streaming handler]
assert streaming_flags == [False, False, True, True]
async def test_chat_streaming_middleware_behavior(self, mock_chat_client: Any) -> None:
"""Test chat middleware behavior with streaming responses."""
chunks_processed: list[str] = []
class ChatStreamProcessingMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
chunks_processed.append("before_stream")
await next(context)
chunks_processed.append("after_stream")
middleware = ChatStreamProcessingMiddleware()
pipeline = ChatMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(
chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True
)
async def final_stream_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
chunks_processed.append("stream_start")
yield ChatResponseUpdate(contents=[TextContent(text="chunk1")])
chunks_processed.append("chunk1_yielded")
yield ChatResponseUpdate(contents=[TextContent(text="chunk2")])
chunks_processed.append("chunk2_yielded")
chunks_processed.append("stream_end")
updates: list[str] = []
async for update in pipeline.execute_stream(
mock_chat_client, messages, chat_options, context, final_stream_handler
):
updates.append(update.text)
assert updates == ["chunk1", "chunk2"]
assert chunks_processed == [
"before_stream",
"after_stream",
"stream_start",
"chunk1_yielded",
"chunk2_yielded",
"stream_end",
]
# region Helper classes and fixtures
@@ -883,6 +1359,13 @@ class TestFunctionMiddleware(FunctionMiddleware):
await next(context)
class TestChatMiddleware(ChatMiddleware):
"""Test implementation of ChatMiddleware."""
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
await next(context)
class MockFunctionArgs(BaseModel):
"""Test arguments for function middleware tests."""
@@ -1027,6 +1510,100 @@ class TestMiddlewareExecutionControl:
assert result.messages == [] # Empty response
assert not handler_called
async def test_chat_middleware_no_next_no_execution(self, mock_chat_client: Any) -> None:
"""Test that when chat middleware doesn't call next(), no execution happens."""
class NoNextChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
# Don't call next() - this should prevent any execution
pass
middleware = NoNextChatMiddleware()
pipeline = ChatMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
handler_called = False
async def final_handler(ctx: ChatContext) -> ChatResponse:
nonlocal handler_called
handler_called = True
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="should not execute")])
result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
# Verify no execution happened
assert result is None
assert not handler_called
assert context.result is None
async def test_chat_middleware_no_next_no_streaming_execution(self, mock_chat_client: Any) -> None:
"""Test that when chat middleware doesn't call next(), no streaming execution happens."""
class NoNextStreamingChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
# Don't call next() - this should prevent any execution
pass
middleware = NoNextStreamingChatMiddleware()
pipeline = ChatMiddlewarePipeline([middleware])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(
chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True
)
handler_called = False
async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
nonlocal handler_called
handler_called = True
yield ChatResponseUpdate(contents=[TextContent(text="should not execute")])
# When middleware doesn't call next(), streaming should yield no updates
updates: list[ChatResponseUpdate] = []
async for update in pipeline.execute_stream(mock_chat_client, messages, chat_options, context, final_handler):
updates.append(update)
# Verify no execution happened and no updates were yielded
assert len(updates) == 0
assert not handler_called
assert context.result is None
async def test_multiple_chat_middlewares_early_stop(self, mock_chat_client: Any) -> None:
"""Test that when first chat middleware doesn't call next(), subsequent middlewares are not called."""
execution_order: list[str] = []
class FirstChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("first")
# Don't call next() - this should stop the pipeline
class SecondChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("second")
await next(context)
pipeline = ChatMiddlewarePipeline([FirstChatMiddleware(), SecondChatMiddleware()])
messages = [ChatMessage(role=Role.USER, text="test")]
chat_options = ChatOptions()
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
handler_called = False
async def final_handler(ctx: ChatContext) -> ChatResponse:
nonlocal handler_called
handler_called = True
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="should not execute")])
result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
# Verify only first middleware was called and no result returned
assert execution_order == ["first"]
assert result is None
assert not handler_called
@pytest.fixture
def mock_agent() -> AgentProtocol:
@@ -1042,3 +1619,13 @@ def mock_function() -> AIFunction[Any, Any]:
function = MagicMock(spec=AIFunction[Any, Any])
function.name = "test_function"
return function
@pytest.fixture
def mock_chat_client() -> Any:
"""Mock chat client for testing."""
from agent_framework._clients import ChatClientProtocol
client = MagicMock(spec=ChatClientProtocol)
client.service_url = MagicMock(return_value="mock://test")
return client
@@ -8,7 +8,9 @@ import pytest
from agent_framework import (
AgentRunResponseUpdate,
ChatAgent,
ChatContext,
ChatMessage,
ChatMiddleware,
ChatResponse,
ChatResponseUpdate,
FunctionCallContent,
@@ -16,7 +18,9 @@ from agent_framework import (
Role,
TextContent,
agent_middleware,
chat_middleware,
function_middleware,
use_function_invocation,
)
from agent_framework._middleware import (
AgentMiddleware,
@@ -25,8 +29,9 @@ from agent_framework._middleware import (
FunctionMiddleware,
MiddlewareType,
)
from agent_framework.exceptions import MiddlewareException
from .conftest import MockChatClient
from .conftest import MockBaseChatClient, MockChatClient
# region ChatAgent Tests
@@ -713,6 +718,101 @@ class TestChatAgentFunctionMiddlewareWithTools:
assert function_calls[0].name == "sample_tool_function"
assert function_results[0].call_id == function_calls[0].call_id
async def test_function_middleware_can_access_and_override_custom_kwargs(
self, chat_client: "MockChatClient"
) -> None:
"""Test that function middleware can access and override custom parameters like temperature."""
captured_kwargs: dict[str, Any] = {}
modified_kwargs: dict[str, Any] = {}
middleware_called = False
@function_middleware
async def kwargs_middleware(
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
nonlocal middleware_called
middleware_called = True
# Capture the original kwargs
captured_kwargs["has_chat_options"] = "chat_options" in context.kwargs
captured_kwargs["has_custom_param"] = "custom_param" in context.kwargs
captured_kwargs["custom_param"] = context.kwargs.get("custom_param")
# Capture original chat_options values if present
if "chat_options" in context.kwargs:
chat_options = context.kwargs["chat_options"]
captured_kwargs["original_temperature"] = getattr(chat_options, "temperature", None)
captured_kwargs["original_max_tokens"] = getattr(chat_options, "max_tokens", None)
# Modify some kwargs
context.kwargs["temperature"] = 0.9
context.kwargs["max_tokens"] = 500
context.kwargs["new_param"] = "added_by_middleware"
# Also modify chat_options if present
if "chat_options" in context.kwargs:
context.kwargs["chat_options"].temperature = 0.9
context.kwargs["chat_options"].max_tokens = 500
# Store modified kwargs for verification
modified_kwargs["temperature"] = context.kwargs.get("temperature")
modified_kwargs["max_tokens"] = context.kwargs.get("max_tokens")
modified_kwargs["new_param"] = context.kwargs.get("new_param")
modified_kwargs["custom_param"] = context.kwargs.get("custom_param")
# Capture modified chat_options values if present
if "chat_options" in context.kwargs:
chat_options = context.kwargs["chat_options"]
modified_kwargs["chat_options_temperature"] = getattr(chat_options, "temperature", None)
modified_kwargs["chat_options_max_tokens"] = getattr(chat_options, "max_tokens", None)
await next(context)
chat_client.responses = [
ChatResponse(
messages=[
ChatMessage(
role=Role.ASSISTANT,
contents=[
FunctionCallContent(
call_id="test_call", name="sample_tool_function", arguments={"location": "Seattle"}
)
],
)
]
),
ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent("Function completed")])]),
]
# Create ChatAgent with function middleware
agent = ChatAgent(chat_client=chat_client, middleware=[kwargs_middleware], tools=[sample_tool_function])
# Execute the agent with custom parameters
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages, temperature=0.7, max_tokens=100, custom_param="test_value")
# Verify response
assert response is not None
assert len(response.messages) > 0
# First check if middleware was called at all
assert middleware_called, "Function middleware was not called"
# Verify middleware captured the original kwargs
assert captured_kwargs["has_chat_options"] is True
assert captured_kwargs["has_custom_param"] is True
assert captured_kwargs["custom_param"] == "test_value"
assert captured_kwargs["original_temperature"] == 0.7
assert captured_kwargs["original_max_tokens"] == 100
# Verify middleware could modify the kwargs
assert modified_kwargs["temperature"] == 0.9
assert modified_kwargs["max_tokens"] == 500
assert modified_kwargs["new_param"] == "added_by_middleware"
assert modified_kwargs["custom_param"] == "test_value"
assert modified_kwargs["chat_options_temperature"] == 0.9
assert modified_kwargs["chat_options_max_tokens"] == 500
class TestMiddlewareDynamicRebuild:
"""Test cases for dynamic middleware pipeline rebuilding with ChatAgent."""
@@ -1187,8 +1287,8 @@ class TestMiddlewareDecoratorLogic:
"""Both decorator and parameter type specified but don't match."""
# This will cause a type error at decoration time, so we need to test differently
# Should raise ValueError due to mismatch during agent creation
with pytest.raises(ValueError, match="Middleware type mismatch"):
# Should raise MiddlewareException due to mismatch during agent creation
with pytest.raises(MiddlewareException, match="Middleware type mismatch"):
@agent_middleware # type: ignore[arg-type]
async def mismatched_middleware(
@@ -1304,8 +1404,8 @@ class TestMiddlewareDecoratorLogic:
async def no_info_middleware(context: Any, next: Any) -> None: # No decorator, no type
await next(context)
# Should raise ValueError
with pytest.raises(ValueError, match="Cannot determine middleware type"):
# Should raise MiddlewareException
with pytest.raises(MiddlewareException, match="Cannot determine middleware type"):
agent = ChatAgent(chat_client=chat_client, middleware=[no_info_middleware])
await agent.run([ChatMessage(role=Role.USER, text="test")])
@@ -1313,8 +1413,8 @@ class TestMiddlewareDecoratorLogic:
"""Test that middleware with insufficient parameters raises an error."""
from agent_framework import ChatAgent, agent_middleware
# Should raise ValueError about insufficient parameters
with pytest.raises(ValueError, match="must have at least 2 parameters"):
# Should raise MiddlewareException about insufficient parameters
with pytest.raises(MiddlewareException, match="must have at least 2 parameters"):
@agent_middleware # type: ignore[arg-type]
async def insufficient_params_middleware(context: Any) -> None: # Missing 'next' parameter
@@ -1340,3 +1440,359 @@ class TestMiddlewareDecoratorLogic:
assert hasattr(test_function_middleware, "_middleware_type")
assert test_function_middleware._middleware_type == MiddlewareType.FUNCTION # type: ignore[attr-defined]
class TestChatAgentChatMiddleware:
"""Test cases for chat middleware integration with ChatAgent."""
async def test_class_based_chat_middleware_with_chat_agent(self) -> None:
"""Test class-based chat middleware with ChatAgent."""
execution_order: list[str] = []
class TrackingChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("chat_middleware_before")
await next(context)
execution_order.append("chat_middleware_after")
# Create ChatAgent with chat middleware
chat_client = MockBaseChatClient()
middleware = TrackingChatMiddleware()
agent = ChatAgent(chat_client=chat_client, middleware=[middleware])
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert response.messages[0].role == Role.ASSISTANT
assert "test response" in response.messages[0].text
assert execution_order == ["chat_middleware_before", "chat_middleware_after"]
async def test_function_based_chat_middleware_with_chat_agent(self) -> None:
"""Test function-based chat middleware with ChatAgent."""
execution_order: list[str] = []
async def tracking_chat_middleware(
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
) -> None:
execution_order.append("chat_middleware_before")
await next(context)
execution_order.append("chat_middleware_after")
# Create ChatAgent with function-based chat middleware
chat_client = MockBaseChatClient()
agent = ChatAgent(chat_client=chat_client, middleware=[tracking_chat_middleware])
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert response.messages[0].role == Role.ASSISTANT
assert "test response" in response.messages[0].text
assert execution_order == ["chat_middleware_before", "chat_middleware_after"]
async def test_chat_middleware_can_modify_messages(self) -> None:
"""Test that chat middleware can modify messages before sending to model."""
@chat_middleware
async def message_modifier_middleware(
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
) -> None:
# Modify the first message by adding a prefix
if context.messages and len(context.messages) > 0:
original_text = context.messages[0].text or ""
context.messages[0] = ChatMessage(role=context.messages[0].role, text=f"MODIFIED: {original_text}")
await next(context)
# Create ChatAgent with message-modifying middleware
chat_client = MockBaseChatClient()
agent = ChatAgent(chat_client=chat_client, middleware=[message_modifier_middleware])
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify that the message was modified (MockBaseChatClient echoes back the input)
assert response is not None
assert len(response.messages) > 0
assert "MODIFIED: test message" in response.messages[0].text
async def test_chat_middleware_can_override_response(self) -> None:
"""Test that chat middleware can override the response."""
@chat_middleware
async def response_override_middleware(
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
) -> None:
# Override the response without calling next()
context.result = ChatResponse(
messages=[ChatMessage(role=Role.ASSISTANT, text="Middleware overridden response")],
response_id="middleware-response-123",
)
context.terminate = True
# Create ChatAgent with response-overriding middleware
chat_client = MockBaseChatClient()
agent = ChatAgent(chat_client=chat_client, middleware=[response_override_middleware])
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify that the response was overridden
assert response is not None
assert len(response.messages) > 0
assert response.messages[0].text == "Middleware overridden response"
assert response.response_id == "middleware-response-123"
async def test_multiple_chat_middleware_execution_order(self) -> None:
"""Test that multiple chat middleware execute in the correct order."""
execution_order: list[str] = []
@chat_middleware
async def first_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("first_before")
await next(context)
execution_order.append("first_after")
@chat_middleware
async def second_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("second_before")
await next(context)
execution_order.append("second_after")
# Create ChatAgent with multiple chat middleware
chat_client = MockBaseChatClient()
agent = ChatAgent(chat_client=chat_client, middleware=[first_middleware, second_middleware])
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify response
assert response is not None
assert execution_order == ["first_before", "second_before", "second_after", "first_after"]
async def test_chat_middleware_with_streaming(self) -> None:
"""Test chat middleware with streaming responses."""
execution_order: list[str] = []
streaming_flags: list[bool] = []
class StreamingTrackingChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("streaming_chat_before")
streaming_flags.append(context.is_streaming)
await next(context)
execution_order.append("streaming_chat_after")
# Create ChatAgent with chat middleware
chat_client = MockBaseChatClient()
agent = ChatAgent(chat_client=chat_client, middleware=[StreamingTrackingChatMiddleware()])
# Set up mock streaming responses
chat_client.streaming_responses = [
[
ChatResponseUpdate(contents=[TextContent(text="Stream")], role=Role.ASSISTANT),
ChatResponseUpdate(contents=[TextContent(text=" response")], role=Role.ASSISTANT),
]
]
# Execute streaming
messages = [ChatMessage(role=Role.USER, text="test message")]
updates: list[AgentRunResponseUpdate] = []
async for update in agent.run_stream(messages):
updates.append(update)
# Verify streaming response
assert len(updates) >= 1 # At least some updates
assert execution_order == ["streaming_chat_before", "streaming_chat_after"]
# Verify streaming flag was set (at least one True)
assert True in streaming_flags
async def test_chat_middleware_termination_before_execution(self) -> None:
"""Test that chat middleware can terminate execution before calling next()."""
execution_order: list[str] = []
class PreTerminationChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("middleware_before")
context.terminate = True
# Set a custom response since we're terminating
context.result = ChatResponse(
messages=[ChatMessage(role=Role.ASSISTANT, text="Terminated by middleware")]
)
# We call next() but since terminate=True, execution should stop
await next(context)
execution_order.append("middleware_after")
# Create ChatAgent with terminating middleware
chat_client = MockBaseChatClient()
agent = ChatAgent(chat_client=chat_client, middleware=[PreTerminationChatMiddleware()])
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify response was from middleware
assert response is not None
assert len(response.messages) > 0
assert response.messages[0].text == "Terminated by middleware"
assert execution_order == ["middleware_before", "middleware_after"]
async def test_chat_middleware_termination_after_execution(self) -> None:
"""Test that chat middleware can terminate execution after calling next()."""
execution_order: list[str] = []
class PostTerminationChatMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("middleware_before")
await next(context)
execution_order.append("middleware_after")
context.terminate = True
# Create ChatAgent with terminating middleware
chat_client = MockBaseChatClient()
agent = ChatAgent(chat_client=chat_client, middleware=[PostTerminationChatMiddleware()])
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify response is from actual execution
assert response is not None
assert len(response.messages) > 0
assert "test response" in response.messages[0].text
assert execution_order == ["middleware_before", "middleware_after"]
async def test_combined_middleware(self) -> None:
"""Test ChatAgent with combined middleware types."""
execution_order: list[str] = []
async def agent_middleware(
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
execution_order.append("agent_middleware_before")
await next(context)
execution_order.append("agent_middleware_after")
async def chat_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("chat_middleware_before")
await next(context)
execution_order.append("chat_middleware_after")
async def function_middleware(
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
execution_order.append("function_middleware_before")
await next(context)
execution_order.append("function_middleware_after")
# Set up mock to return a function call first, then a regular response
function_call_response = ChatResponse(
messages=[
ChatMessage(
role=Role.ASSISTANT,
contents=[
FunctionCallContent(
call_id="call_456",
name="sample_tool_function",
arguments='{"location": "San Francisco"}',
)
],
)
]
)
final_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Final response")])
chat_client = use_function_invocation(MockBaseChatClient)()
chat_client.run_responses = [function_call_response, final_response]
# Create ChatAgent with function middleware and tools
agent = ChatAgent(
chat_client=chat_client,
middleware=[chat_middleware, function_middleware, agent_middleware],
tools=[sample_tool_function],
)
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="Get weather for San Francisco")]
response = await agent.run(messages)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert chat_client.call_count == 2 # Two calls: one for function call, one for final response
# Verify function middleware was executed
assert execution_order == [
"agent_middleware_before",
"chat_middleware_before",
"chat_middleware_after",
"function_middleware_before",
"function_middleware_after",
"chat_middleware_before",
"chat_middleware_after",
"agent_middleware_after",
]
# Verify function call and result are in the response
all_contents = [content for message in response.messages for content in message.contents]
function_calls = [c for c in all_contents if isinstance(c, FunctionCallContent)]
function_results = [c for c in all_contents if isinstance(c, FunctionResultContent)]
assert len(function_calls) == 1
assert len(function_results) == 1
assert function_calls[0].name == "sample_tool_function"
assert function_results[0].call_id == function_calls[0].call_id
async def test_agent_middleware_can_access_and_override_custom_kwargs(self) -> None:
"""Test that agent middleware can access and override custom parameters like temperature."""
captured_kwargs: dict[str, Any] = {}
modified_kwargs: dict[str, Any] = {}
@agent_middleware
async def kwargs_middleware(
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
) -> None:
# Capture the original kwargs
captured_kwargs.update(context.kwargs)
# Modify some kwargs
context.kwargs["temperature"] = 0.9
context.kwargs["max_tokens"] = 500
context.kwargs["new_param"] = "added_by_middleware"
# Store modified kwargs for verification
modified_kwargs.update(context.kwargs)
await next(context)
# Create ChatAgent with agent middleware
chat_client = MockBaseChatClient()
agent = ChatAgent(chat_client=chat_client, middleware=[kwargs_middleware])
# Execute the agent with custom parameters
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages, temperature=0.7, max_tokens=100, custom_param="test_value")
# Verify response
assert response is not None
assert len(response.messages) > 0
# Verify middleware captured the original kwargs
assert captured_kwargs["temperature"] == 0.7
assert captured_kwargs["max_tokens"] == 100
assert captured_kwargs["custom_param"] == "test_value"
# Verify middleware could modify the kwargs
assert modified_kwargs["temperature"] == 0.9
assert modified_kwargs["max_tokens"] == 500
assert modified_kwargs["new_param"] == "added_by_middleware"
assert modified_kwargs["custom_param"] == "test_value" # Should still be there
@@ -0,0 +1,436 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import Awaitable, Callable
from typing import Any
from agent_framework import (
ChatAgent,
ChatContext,
ChatMessage,
ChatMiddleware,
ChatResponse,
FunctionCallContent,
FunctionInvocationContext,
Role,
chat_middleware,
function_middleware,
use_function_invocation,
)
from .conftest import MockBaseChatClient
class TestChatMiddleware:
"""Test cases for chat middleware functionality."""
async def test_class_based_chat_middleware(self, chat_client_base: "MockBaseChatClient") -> None:
"""Test class-based chat middleware with ChatClient."""
execution_order: list[str] = []
class LoggingChatMiddleware(ChatMiddleware):
async def process(
self,
context: ChatContext,
next: Callable[[ChatContext], Awaitable[None]],
) -> None:
execution_order.append("chat_middleware_before")
await next(context)
execution_order.append("chat_middleware_after")
# Add middleware to chat client
chat_client_base.middleware = [LoggingChatMiddleware()]
# Execute chat client directly
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await chat_client_base.get_response(messages)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert response.messages[0].role == Role.ASSISTANT
# Verify middleware execution order
assert execution_order == ["chat_middleware_before", "chat_middleware_after"]
async def test_function_based_chat_middleware(self, chat_client_base: "MockBaseChatClient") -> None:
"""Test function-based chat middleware with ChatClient."""
execution_order: list[str] = []
@chat_middleware
async def logging_chat_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("function_middleware_before")
await next(context)
execution_order.append("function_middleware_after")
# Add middleware to chat client
chat_client_base.middleware = [logging_chat_middleware]
# Execute chat client directly
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await chat_client_base.get_response(messages)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert response.messages[0].role == Role.ASSISTANT
# Verify middleware execution order
assert execution_order == ["function_middleware_before", "function_middleware_after"]
async def test_chat_middleware_can_modify_messages(self, chat_client_base: "MockBaseChatClient") -> None:
"""Test that chat middleware can modify messages before sending to model."""
@chat_middleware
async def message_modifier_middleware(
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
) -> None:
# Modify the first message by adding a prefix
if context.messages and len(context.messages) > 0:
original_text = context.messages[0].text or ""
context.messages[0] = ChatMessage(role=context.messages[0].role, text=f"MODIFIED: {original_text}")
await next(context)
# Add middleware to chat client
chat_client_base.middleware = [message_modifier_middleware]
# Execute chat client
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await chat_client_base.get_response(messages)
# Verify that the message was modified (MockChatClient echoes back the input)
assert response is not None
assert len(response.messages) > 0
# The mock client should receive the modified message
assert "MODIFIED: test message" in response.messages[0].text
async def test_chat_middleware_can_override_response(self, chat_client_base: "MockBaseChatClient") -> None:
"""Test that chat middleware can override the response."""
@chat_middleware
async def response_override_middleware(
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
) -> None:
# Override the response without calling next()
context.result = ChatResponse(
messages=[ChatMessage(role=Role.ASSISTANT, text="Middleware overridden response")],
response_id="middleware-response-123",
)
context.terminate = True
# Add middleware to chat client
chat_client_base.middleware = [response_override_middleware]
# Execute chat client
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await chat_client_base.get_response(messages)
# Verify that the response was overridden
assert response is not None
assert len(response.messages) > 0
assert response.messages[0].text == "Middleware overridden response"
assert response.response_id == "middleware-response-123"
async def test_multiple_chat_middleware_execution_order(self, chat_client_base: "MockBaseChatClient") -> None:
"""Test that multiple chat middleware execute in the correct order."""
execution_order: list[str] = []
@chat_middleware
async def first_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("first_before")
await next(context)
execution_order.append("first_after")
@chat_middleware
async def second_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("second_before")
await next(context)
execution_order.append("second_after")
# Add middleware to chat client (order should be preserved)
chat_client_base.middleware = [first_middleware, second_middleware]
# Execute chat client
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await chat_client_base.get_response(messages)
# Verify response
assert response is not None
# Verify middleware execution order (nested execution)
expected_order = ["first_before", "second_before", "second_after", "first_after"]
assert execution_order == expected_order
async def test_chat_agent_with_chat_middleware(self) -> None:
"""Test ChatAgent with chat middleware specified at agent level."""
execution_order: list[str] = []
@chat_middleware
async def agent_level_chat_middleware(
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
) -> None:
execution_order.append("agent_chat_middleware_before")
await next(context)
execution_order.append("agent_chat_middleware_after")
chat_client = MockBaseChatClient()
# Create ChatAgent with chat middleware
agent = ChatAgent(chat_client=chat_client, middleware=[agent_level_chat_middleware])
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert response.messages[0].role == Role.ASSISTANT
# Verify middleware execution order
assert execution_order == ["agent_chat_middleware_before", "agent_chat_middleware_after"]
async def test_chat_agent_with_multiple_chat_middleware(self, chat_client_base: "MockBaseChatClient") -> None:
"""Test that ChatAgent can have multiple chat middleware."""
execution_order: list[str] = []
@chat_middleware
async def first_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("first_before")
await next(context)
execution_order.append("first_after")
@chat_middleware
async def second_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("second_before")
await next(context)
execution_order.append("second_after")
# Create ChatAgent with multiple chat middleware
agent = ChatAgent(chat_client=chat_client_base, middleware=[first_middleware, second_middleware])
# Execute the agent
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await agent.run(messages)
# Verify response
assert response is not None
# Verify both middleware executed (nested execution order)
expected_order = ["first_before", "second_before", "second_after", "first_after"]
assert execution_order == expected_order
async def test_chat_middleware_with_streaming(self, chat_client_base: "MockBaseChatClient") -> None:
"""Test chat middleware with streaming responses."""
execution_order: list[str] = []
@chat_middleware
async def streaming_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_order.append("streaming_before")
# Verify it's a streaming context
assert context.is_streaming is True
await next(context)
execution_order.append("streaming_after")
# Add middleware to chat client
chat_client_base.middleware = [streaming_middleware]
# Execute streaming response
messages = [ChatMessage(role=Role.USER, text="test message")]
updates: list[object] = []
async for update in chat_client_base.get_streaming_response(messages):
updates.append(update)
# Verify we got updates
assert len(updates) > 0
# Verify middleware executed
assert execution_order == ["streaming_before", "streaming_after"]
async def test_run_level_middleware_isolation(self, chat_client_base: "MockBaseChatClient") -> None:
"""Test that run-level middleware is isolated and doesn't persist across calls."""
execution_count = {"count": 0}
@chat_middleware
async def counting_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
execution_count["count"] += 1
await next(context)
# First call with run-level middleware
messages = [ChatMessage(role=Role.USER, text="first message")]
response1 = await chat_client_base.get_response(messages, middleware=[counting_middleware])
assert response1 is not None
assert execution_count["count"] == 1
# Second call WITHOUT run-level middleware - should not execute the middleware
messages = [ChatMessage(role=Role.USER, text="second message")]
response2 = await chat_client_base.get_response(messages)
assert response2 is not None
assert execution_count["count"] == 1 # Should still be 1, not 2
# Third call with run-level middleware again - should execute
messages = [ChatMessage(role=Role.USER, text="third message")]
response3 = await chat_client_base.get_response(messages, middleware=[counting_middleware])
assert response3 is not None
assert execution_count["count"] == 2 # Should be 2 now
async def test_chat_client_middleware_can_access_and_override_custom_kwargs(
self, chat_client_base: "MockBaseChatClient"
) -> None:
"""Test that chat client middleware can access and override custom parameters like temperature."""
captured_kwargs: dict[str, Any] = {}
modified_kwargs: dict[str, Any] = {}
@chat_middleware
async def kwargs_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
# Capture the original kwargs
captured_kwargs.update(context.kwargs)
# Modify some kwargs
context.kwargs["temperature"] = 0.9
context.kwargs["max_tokens"] = 500
context.kwargs["new_param"] = "added_by_middleware"
# Store modified kwargs for verification
modified_kwargs.update(context.kwargs)
await next(context)
# Add middleware to chat client
chat_client_base.middleware = [kwargs_middleware]
# Execute chat client with custom parameters
messages = [ChatMessage(role=Role.USER, text="test message")]
response = await chat_client_base.get_response(
messages, temperature=0.7, max_tokens=100, custom_param="test_value"
)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert captured_kwargs["temperature"] == 0.7
assert captured_kwargs["max_tokens"] == 100
assert captured_kwargs["custom_param"] == "test_value"
# Verify middleware could modify the kwargs
assert modified_kwargs["temperature"] == 0.9
assert modified_kwargs["max_tokens"] == 500
assert modified_kwargs["new_param"] == "added_by_middleware"
assert modified_kwargs["custom_param"] == "test_value" # Should still be there
async def test_function_middleware_registration_on_chat_client(self) -> None:
"""Test function middleware registered on ChatClient is executed during function calls."""
execution_order: list[str] = []
@function_middleware
async def test_function_middleware(
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
execution_order.append(f"function_middleware_before_{context.function.name}")
await next(context)
execution_order.append(f"function_middleware_after_{context.function.name}")
# Define a simple tool function
def sample_tool(location: str) -> str:
"""Get weather for a location."""
return f"Weather in {location}: sunny"
# Create function-invocation enabled chat client
chat_client = use_function_invocation(MockBaseChatClient)()
# Set function middleware directly on the chat client
chat_client.middleware = [test_function_middleware]
# Prepare responses that will trigger function invocation
function_call_response = ChatResponse(
messages=[
ChatMessage(
role=Role.ASSISTANT,
contents=[
FunctionCallContent(
call_id="call_1",
name="sample_tool",
arguments={"location": "San Francisco"},
)
],
)
]
)
final_response = ChatResponse(
messages=[ChatMessage(role=Role.ASSISTANT, text="Based on the weather data, it's sunny!")]
)
chat_client.run_responses = [function_call_response, final_response]
# Execute the chat client directly with tools - this should trigger function invocation and middleware
messages = [ChatMessage(role=Role.USER, text="What's the weather in San Francisco?")]
response = await chat_client.get_response(messages, tools=[sample_tool])
# Verify response
assert response is not None
assert len(response.messages) > 0
assert chat_client.call_count == 2 # Two calls: function call + final response
# Verify function middleware was executed
assert execution_order == [
"function_middleware_before_sample_tool",
"function_middleware_after_sample_tool",
]
async def test_run_level_function_middleware(self) -> None:
"""Test that function middleware passed to get_response method is also invoked."""
execution_order: list[str] = []
@function_middleware
async def run_level_function_middleware(
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
execution_order.append("run_level_function_middleware_before")
await next(context)
execution_order.append("run_level_function_middleware_after")
# Define a simple tool function
def sample_tool(location: str) -> str:
"""Get weather for a location."""
return f"Weather in {location}: sunny"
# Create function-invocation enabled chat client
chat_client = use_function_invocation(MockBaseChatClient)()
# Prepare responses that will trigger function invocation
function_call_response = ChatResponse(
messages=[
ChatMessage(
role=Role.ASSISTANT,
contents=[
FunctionCallContent(
call_id="call_2",
name="sample_tool",
arguments={"location": "New York"},
)
],
)
]
)
final_response = ChatResponse(
messages=[ChatMessage(role=Role.ASSISTANT, text="The weather information has been retrieved!")]
)
chat_client.run_responses = [function_call_response, final_response]
# Execute the chat client directly with run-level middleware and tools
messages = [ChatMessage(role=Role.USER, text="What's the weather in New York?")]
response = await chat_client.get_response(
messages, tools=[sample_tool], middleware=[run_level_function_middleware]
)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert chat_client.call_count == 2 # Two calls: function call + final response
# Verify run-level function middleware was executed once (during function invocation)
assert execution_order == [
"run_level_function_middleware_before",
"run_level_function_middleware_after",
]