mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Added chat middleware and more examples (#883)
* Added example with stateful middleware
* Added chat middleware
* Updated middleware example with override scenario
* Small revert
* Small fixes
* Added kwargs to context objects
* Added README
* Added function middleware to chat client
* Small refactoring
* Reverted example files
* Made MiddlewareWrapper generic
* Added Middleware exception
* Small refactoring
* Small fix
This commit is contained in:
committed by
GitHub
Unverified
parent
863c8d7471
commit
eec7f192eb
@@ -34,6 +34,7 @@ from agent_framework import (
|
||||
UsageContent,
|
||||
UsageDetails,
|
||||
get_logger,
|
||||
use_chat_middleware,
|
||||
use_function_invocation,
|
||||
)
|
||||
from agent_framework._pydantic import AFBaseSettings
|
||||
@@ -123,6 +124,7 @@ TAzureAIAgentClient = TypeVar("TAzureAIAgentClient", bound="AzureAIAgentClient")
|
||||
|
||||
@use_function_invocation
|
||||
@use_observability
|
||||
@use_chat_middleware
|
||||
class AzureAIAgentClient(BaseChatClient):
|
||||
"""Azure AI Agent Chat client."""
|
||||
|
||||
|
||||
@@ -10,7 +10,13 @@ from pydantic import BaseModel, Field
|
||||
from ._logging import get_logger
|
||||
from ._mcp import MCPTool
|
||||
from ._memory import AggregateContextProvider, ContextProvider
|
||||
from ._middleware import Middleware
|
||||
from ._middleware import (
|
||||
ChatMiddleware,
|
||||
ChatMiddlewareCallable,
|
||||
FunctionMiddleware,
|
||||
FunctionMiddlewareCallable,
|
||||
Middleware,
|
||||
)
|
||||
from ._pydantic import AFBaseModel
|
||||
from ._threads import ChatMessageStore
|
||||
from ._tools import ToolProtocol
|
||||
@@ -189,6 +195,14 @@ class BaseChatClient(AFBaseModel, ABC):
|
||||
"""Base class for chat clients."""
|
||||
|
||||
additional_properties: dict[str, Any] = Field(default_factory=dict)
|
||||
middleware: (
|
||||
ChatMiddleware
|
||||
| ChatMiddlewareCallable
|
||||
| FunctionMiddleware
|
||||
| FunctionMiddlewareCallable
|
||||
| list[ChatMiddleware | ChatMiddlewareCallable | FunctionMiddleware | FunctionMiddlewareCallable]
|
||||
| None
|
||||
) = None
|
||||
OTEL_PROVIDER_NAME: str = "unknown"
|
||||
# This is used for OTel setup, should be overridden in subclasses
|
||||
|
||||
@@ -346,13 +360,7 @@ class BaseChatClient(AFBaseModel, ABC):
|
||||
prepped_messages = self.prepare_messages(messages)
|
||||
self._prepare_tool_choice(chat_options=chat_options)
|
||||
|
||||
# Remove middleware pipeline from kwargs as it's only used by function invocation wrappers
|
||||
if "_function_middleware_pipeline" in kwargs:
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "_function_middleware_pipeline"}
|
||||
else:
|
||||
filtered_kwargs = kwargs
|
||||
|
||||
return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **filtered_kwargs)
|
||||
return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **kwargs)
|
||||
|
||||
async def get_streaming_response(
|
||||
self,
|
||||
@@ -432,14 +440,8 @@ class BaseChatClient(AFBaseModel, ABC):
|
||||
prepped_messages = self.prepare_messages(messages)
|
||||
self._prepare_tool_choice(chat_options=chat_options)
|
||||
|
||||
# Remove middleware pipeline from kwargs as it's only used by function invocation wrappers
|
||||
if "_function_middleware_pipeline" in kwargs:
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "_function_middleware_pipeline"}
|
||||
else:
|
||||
filtered_kwargs = kwargs
|
||||
|
||||
async for update in self._inner_get_streaming_response(
|
||||
messages=prepped_messages, chat_options=chat_options, **filtered_kwargs
|
||||
messages=prepped_messages, chat_options=chat_options, **kwargs
|
||||
):
|
||||
yield update
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -607,6 +607,7 @@ async def _auto_invoke_function(
|
||||
middleware_context = FunctionInvocationContext(
|
||||
function=tool,
|
||||
arguments=args,
|
||||
kwargs=custom_args or {},
|
||||
)
|
||||
|
||||
async def final_function_handler(context_obj: Any) -> Any:
|
||||
@@ -721,8 +722,16 @@ def _handle_function_calls_response(
|
||||
**kwargs: Any,
|
||||
) -> "ChatResponse":
|
||||
from ._clients import prepare_messages
|
||||
from ._middleware import extract_and_merge_function_middleware
|
||||
from ._types import ChatMessage, ChatOptions, FunctionCallContent, FunctionResultContent
|
||||
|
||||
# Extract and merge function middleware from chat client with kwargs pipeline
|
||||
extract_and_merge_function_middleware(self, kwargs)
|
||||
|
||||
# Extract the middleware pipeline before calling the underlying function
|
||||
# because the underlying function may not preserve it in kwargs
|
||||
stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline")
|
||||
|
||||
prepped_messages = prepare_messages(messages)
|
||||
response: "ChatResponse | None" = None
|
||||
fcc_messages: "list[ChatMessage]" = []
|
||||
@@ -746,8 +755,9 @@ def _handle_function_calls_response(
|
||||
if not tools and (chat_options := kwargs.get("chat_options")) and isinstance(chat_options, ChatOptions):
|
||||
tools = chat_options.tools
|
||||
if function_calls and tools:
|
||||
# Extract function middleware pipeline from kwargs if available
|
||||
middleware_pipeline = kwargs.get("_function_middleware_pipeline")
|
||||
# Use the stored middleware pipeline instead of extracting from kwargs
|
||||
# because kwargs may have been modified by the underlying function
|
||||
middleware_pipeline = stored_middleware_pipeline
|
||||
function_results = await execute_function_calls(
|
||||
custom_args=kwargs,
|
||||
attempt_idx=attempt_idx,
|
||||
@@ -820,8 +830,16 @@ def _handle_function_calls_streaming_response(
|
||||
) -> AsyncIterable["ChatResponseUpdate"]:
|
||||
"""Wrap the inner get streaming response method to handle tool calls."""
|
||||
from ._clients import prepare_messages
|
||||
from ._middleware import extract_and_merge_function_middleware
|
||||
from ._types import ChatMessage, ChatOptions, ChatResponse, ChatResponseUpdate, FunctionCallContent
|
||||
|
||||
# Extract and merge function middleware from chat client with kwargs pipeline
|
||||
extract_and_merge_function_middleware(self, kwargs)
|
||||
|
||||
# Extract the middleware pipeline before calling the underlying function
|
||||
# because the underlying function may not preserve it in kwargs
|
||||
stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline")
|
||||
|
||||
prepped_messages = prepare_messages(messages)
|
||||
for attempt_idx in range(max_iterations):
|
||||
all_updates: list["ChatResponseUpdate"] = []
|
||||
@@ -858,8 +876,9 @@ def _handle_function_calls_streaming_response(
|
||||
tools = chat_options.tools
|
||||
|
||||
if function_calls and tools:
|
||||
# Extract function middleware pipeline from kwargs if available
|
||||
middleware_pipeline = kwargs.get("_function_middleware_pipeline")
|
||||
# Use the stored middleware pipeline instead of extracting from kwargs
|
||||
# because kwargs may have been modified by the underlying function
|
||||
middleware_pipeline = stored_middleware_pipeline
|
||||
function_results = await execute_function_calls(
|
||||
custom_args=kwargs,
|
||||
attempt_idx=attempt_idx,
|
||||
|
||||
@@ -13,16 +13,18 @@ from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
|
||||
from pydantic import SecretStr, ValidationError
|
||||
from pydantic.networks import AnyUrl
|
||||
|
||||
from .._tools import use_function_invocation
|
||||
from .._types import (
|
||||
from agent_framework import (
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
CitationAnnotation,
|
||||
TextContent,
|
||||
use_chat_middleware,
|
||||
use_function_invocation,
|
||||
)
|
||||
from ..exceptions import ServiceInitializationError
|
||||
from ..observability import use_observability
|
||||
from ..openai._chat_client import OpenAIBaseChatClient
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from agent_framework.observability import use_observability
|
||||
from agent_framework.openai._chat_client import OpenAIBaseChatClient
|
||||
|
||||
from ._shared import (
|
||||
AzureOpenAIConfigMixin,
|
||||
AzureOpenAISettings,
|
||||
@@ -41,6 +43,7 @@ TAzureOpenAIChatClient = TypeVar("TAzureOpenAIChatClient", bound="AzureOpenAICha
|
||||
|
||||
@use_function_invocation
|
||||
@use_observability
|
||||
@use_chat_middleware
|
||||
class AzureOpenAIChatClient(AzureOpenAIConfigMixin, OpenAIBaseChatClient):
|
||||
"""Azure OpenAI Chat completion class."""
|
||||
|
||||
|
||||
@@ -9,10 +9,11 @@ from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI
|
||||
from pydantic import SecretStr, ValidationError
|
||||
from pydantic.networks import AnyUrl
|
||||
|
||||
from .._tools import use_function_invocation
|
||||
from ..exceptions import ServiceInitializationError
|
||||
from ..observability import use_observability
|
||||
from ..openai._responses_client import OpenAIBaseResponsesClient
|
||||
from agent_framework import use_chat_middleware, use_function_invocation
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from agent_framework.observability import use_observability
|
||||
from agent_framework.openai._responses_client import OpenAIBaseResponsesClient
|
||||
|
||||
from ._shared import (
|
||||
AzureOpenAIConfigMixin,
|
||||
AzureOpenAISettings,
|
||||
@@ -23,6 +24,7 @@ TAzureOpenAIResponsesClient = TypeVar("TAzureOpenAIResponsesClient", bound="Azur
|
||||
|
||||
@use_observability
|
||||
@use_function_invocation
|
||||
@use_chat_middleware
|
||||
class AzureOpenAIResponsesClient(AzureOpenAIConfigMixin, OpenAIBaseResponsesClient):
|
||||
"""Azure Responses completion class."""
|
||||
|
||||
|
||||
@@ -128,3 +128,9 @@ class AdditionItemMismatch(AgentFrameworkException):
|
||||
"""An error occurred while adding two types."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MiddlewareException(AgentFrameworkException):
    """Signals that an error occurred while middleware was executing."""
|
||||
|
||||
@@ -21,6 +21,7 @@ from openai.types.beta.threads.runs import RunStep
|
||||
from pydantic import Field, PrivateAttr, SecretStr, ValidationError
|
||||
|
||||
from .._clients import BaseChatClient
|
||||
from .._middleware import use_chat_middleware
|
||||
from .._tools import AIFunction, HostedCodeInterpreterTool, HostedFileSearchTool, use_function_invocation
|
||||
from .._types import (
|
||||
ChatMessage,
|
||||
@@ -52,6 +53,7 @@ __all__ = ["OpenAIAssistantsClient"]
|
||||
|
||||
@use_function_invocation
|
||||
@use_observability
|
||||
@use_chat_middleware
|
||||
class OpenAIAssistantsClient(OpenAIConfigMixin, BaseChatClient):
|
||||
"""OpenAI Assistants client."""
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from pydantic import BaseModel, SecretStr, ValidationError
|
||||
|
||||
from .._clients import BaseChatClient
|
||||
from .._logging import get_logger
|
||||
from .._middleware import use_chat_middleware
|
||||
from .._tools import AIFunction, HostedWebSearchTool, ToolProtocol, use_function_invocation
|
||||
from .._types import (
|
||||
ChatMessage,
|
||||
@@ -452,6 +453,7 @@ TOpenAIChatClient = TypeVar("TOpenAIChatClient", bound="OpenAIChatClient")
|
||||
|
||||
@use_function_invocation
|
||||
@use_observability
|
||||
@use_chat_middleware
|
||||
class OpenAIChatClient(OpenAIConfigMixin, OpenAIBaseChatClient):
|
||||
"""OpenAI Chat completion class."""
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from pydantic import BaseModel, SecretStr, ValidationError
|
||||
|
||||
from .._clients import BaseChatClient
|
||||
from .._logging import get_logger
|
||||
from .._middleware import use_chat_middleware
|
||||
from .._tools import (
|
||||
AIFunction,
|
||||
HostedCodeInterpreterTool,
|
||||
@@ -933,6 +934,7 @@ TOpenAIResponsesClient = TypeVar("TOpenAIResponsesClient", bound="OpenAIResponse
|
||||
|
||||
@use_function_invocation
|
||||
@use_observability
|
||||
@use_chat_middleware
|
||||
class OpenAIResponsesClient(OpenAIConfigMixin, OpenAIBaseResponsesClient):
|
||||
"""OpenAI Responses client class."""
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ from agent_framework import (
|
||||
TextContent,
|
||||
ToolProtocol,
|
||||
ai_function,
|
||||
use_chat_middleware,
|
||||
use_function_invocation,
|
||||
)
|
||||
|
||||
@@ -111,11 +112,13 @@ class MockChatClient:
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="another update")], role="assistant")
|
||||
|
||||
|
||||
@use_chat_middleware
|
||||
class MockBaseChatClient(BaseChatClient):
|
||||
"""Mock implementation of the BaseChatClient."""
|
||||
|
||||
run_responses: list[ChatResponse] = Field(default_factory=list)
|
||||
streaming_responses: list[list[ChatResponseUpdate]] = Field(default_factory=list)
|
||||
call_count: int = Field(default=0)
|
||||
|
||||
@override
|
||||
async def _inner_get_response(
|
||||
@@ -136,6 +139,7 @@ class MockBaseChatClient(BaseChatClient):
|
||||
The chat response contents representing the response(s).
|
||||
"""
|
||||
logger.debug(f"Running base chat client inner, with: {messages=}, {chat_options=}, {kwargs=}")
|
||||
self.call_count += 1
|
||||
if not self.run_responses:
|
||||
return ChatResponse(messages=ChatMessage(role="assistant", text=f"test response - {messages[0].text}"))
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@ from agent_framework import (
|
||||
AgentRunResponse,
|
||||
AgentRunResponseUpdate,
|
||||
ChatMessage,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
Role,
|
||||
TextContent,
|
||||
)
|
||||
@@ -19,11 +21,15 @@ from agent_framework._middleware import (
|
||||
AgentMiddleware,
|
||||
AgentMiddlewarePipeline,
|
||||
AgentRunContext,
|
||||
ChatContext,
|
||||
ChatMiddleware,
|
||||
ChatMiddlewarePipeline,
|
||||
FunctionInvocationContext,
|
||||
FunctionMiddleware,
|
||||
FunctionMiddlewarePipeline,
|
||||
)
|
||||
from agent_framework._tools import AIFunction
|
||||
from agent_framework._types import ChatOptions
|
||||
|
||||
|
||||
class TestAgentRunContext:
|
||||
@@ -74,6 +80,46 @@ class TestFunctionInvocationContext:
|
||||
assert context.metadata == metadata
|
||||
|
||||
|
||||
class TestChatContext:
    """Unit tests for constructing ChatContext objects."""

    def test_init_with_defaults(self, mock_chat_client: Any) -> None:
        """Build a context from the required arguments only and check the remaining defaults."""
        user_messages = [ChatMessage(role=Role.USER, text="test")]
        options = ChatOptions()

        ctx = ChatContext(chat_client=mock_chat_client, messages=user_messages, chat_options=options)

        # The arguments that were passed in are exposed unchanged.
        assert ctx.chat_client is mock_chat_client
        assert ctx.messages == user_messages
        assert ctx.chat_options is options
        # Fields that were not supplied keep their default values.
        assert ctx.is_streaming is False
        assert ctx.metadata == {}
        assert ctx.result is None
        assert ctx.terminate is False

    def test_init_with_custom_values(self, mock_chat_client: Any) -> None:
        """Build a context with every optional field overridden and check each override is kept."""
        user_messages = [ChatMessage(role=Role.USER, text="test")]
        options = ChatOptions(temperature=0.5)
        custom_metadata = {"key": "value"}

        ctx = ChatContext(
            chat_client=mock_chat_client,
            messages=user_messages,
            chat_options=options,
            is_streaming=True,
            metadata=custom_metadata,
            terminate=True,
        )

        assert ctx.chat_client is mock_chat_client
        assert ctx.messages == user_messages
        assert ctx.chat_options is options
        assert ctx.is_streaming is True
        assert ctx.metadata == custom_metadata
        assert ctx.terminate is True
|
||||
|
||||
|
||||
class TestAgentMiddlewarePipeline:
|
||||
"""Test cases for AgentMiddlewarePipeline."""
|
||||
|
||||
@@ -410,6 +456,233 @@ class TestFunctionMiddlewarePipeline:
|
||||
assert execution_order == ["test_before", "handler", "test_after"]
|
||||
|
||||
|
||||
class TestChatMiddlewarePipeline:
    """Tests for ChatMiddlewarePipeline: construction, execution, streaming and termination."""

    class PreNextTerminateChatMiddleware(ChatMiddleware):
        """Middleware that sets ``terminate`` on the context before delegating to ``next``."""

        async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
            context.terminate = True
            await next(context)

    class PostNextTerminateChatMiddleware(ChatMiddleware):
        """Middleware that delegates to ``next`` first and only then sets ``terminate``."""

        async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
            await next(context)
            context.terminate = True

    def test_init_empty(self) -> None:
        """A pipeline created without arguments reports that it has no middlewares."""
        assert not ChatMiddlewarePipeline().has_middlewares

    def test_init_with_class_middleware(self) -> None:
        """A pipeline created with a class-based middleware reports that it has middlewares."""
        class_based = TestChatMiddleware()
        assert ChatMiddlewarePipeline([class_based]).has_middlewares

    def test_init_with_function_middleware(self) -> None:
        """A pipeline created with a function-based middleware reports that it has middlewares."""

        async def passthrough_middleware(
            context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
        ) -> None:
            await next(context)

        assert ChatMiddlewarePipeline([passthrough_middleware]).has_middlewares

    async def test_execute_no_middleware(self, mock_chat_client: Any) -> None:
        """With no middleware registered, execute() returns the final handler's response."""
        pipeline = ChatMiddlewarePipeline()
        request = [ChatMessage(role=Role.USER, text="test")]
        options = ChatOptions()
        chat_ctx = ChatContext(chat_client=mock_chat_client, messages=request, chat_options=options)

        canned_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])

        async def final_handler(ctx: ChatContext) -> ChatResponse:
            return canned_response

        outcome = await pipeline.execute(mock_chat_client, request, options, chat_ctx, final_handler)

        assert outcome == canned_response

    async def test_execute_with_middleware(self, mock_chat_client: Any) -> None:
        """A single middleware wraps the final handler: before, handler, after."""
        execution_order: list[str] = []

        class RecordingChatMiddleware(ChatMiddleware):
            def __init__(self, name: str):
                self.name = name

            async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
                execution_order.append(f"{self.name}_before")
                await next(context)
                execution_order.append(f"{self.name}_after")

        pipeline = ChatMiddlewarePipeline([RecordingChatMiddleware("test")])
        request = [ChatMessage(role=Role.USER, text="test")]
        options = ChatOptions()
        chat_ctx = ChatContext(chat_client=mock_chat_client, messages=request, chat_options=options)

        canned_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])

        async def final_handler(ctx: ChatContext) -> ChatResponse:
            execution_order.append("handler")
            return canned_response

        outcome = await pipeline.execute(mock_chat_client, request, options, chat_ctx, final_handler)

        assert outcome == canned_response
        assert execution_order == ["test_before", "handler", "test_after"]

    async def test_execute_stream_no_middleware(self, mock_chat_client: Any) -> None:
        """With no middleware registered, execute_stream() yields the handler's updates in order."""
        pipeline = ChatMiddlewarePipeline()
        request = [ChatMessage(role=Role.USER, text="test")]
        options = ChatOptions()
        chat_ctx = ChatContext(chat_client=mock_chat_client, messages=request, chat_options=options)

        async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
            yield ChatResponseUpdate(contents=[TextContent(text="chunk1")])
            yield ChatResponseUpdate(contents=[TextContent(text="chunk2")])

        stream = pipeline.execute_stream(mock_chat_client, request, options, chat_ctx, final_handler)
        received: list[ChatResponseUpdate] = [update async for update in stream]

        assert len(received) == 2
        assert received[0].text == "chunk1"
        assert received[1].text == "chunk2"

    async def test_execute_stream_with_middleware(self, mock_chat_client: Any) -> None:
        """Streaming through one middleware yields every update and records the expected call order."""
        execution_order: list[str] = []

        class RecordingStreamChatMiddleware(ChatMiddleware):
            def __init__(self, name: str):
                self.name = name

            async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
                execution_order.append(f"{self.name}_before")
                await next(context)
                execution_order.append(f"{self.name}_after")

        pipeline = ChatMiddlewarePipeline([RecordingStreamChatMiddleware("test")])
        request = [ChatMessage(role=Role.USER, text="test")]
        options = ChatOptions()
        chat_ctx = ChatContext(
            chat_client=mock_chat_client, messages=request, chat_options=options, is_streaming=True
        )

        async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
            execution_order.append("handler_start")
            yield ChatResponseUpdate(contents=[TextContent(text="chunk1")])
            yield ChatResponseUpdate(contents=[TextContent(text="chunk2")])
            execution_order.append("handler_end")

        stream = pipeline.execute_stream(mock_chat_client, request, options, chat_ctx, final_handler)
        received: list[ChatResponseUpdate] = [update async for update in stream]

        assert len(received) == 2
        assert received[0].text == "chunk1"
        assert received[1].text == "chunk2"
        # In the streaming case both middleware markers are recorded ahead of the handler's markers.
        assert execution_order == ["test_before", "test_after", "handler_start", "handler_end"]

    async def test_execute_with_pre_next_termination(self, mock_chat_client: Any) -> None:
        """Setting terminate before next() skips the final handler and yields no response."""
        pipeline = ChatMiddlewarePipeline([self.PreNextTerminateChatMiddleware()])
        request = [ChatMessage(role=Role.USER, text="test")]
        options = ChatOptions()
        chat_ctx = ChatContext(chat_client=mock_chat_client, messages=request, chat_options=options)
        execution_order: list[str] = []

        async def final_handler(ctx: ChatContext) -> ChatResponse:
            # Only reached if the pipeline ignores the terminate flag; the assertions below rule that out.
            execution_order.append("handler")
            return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])

        outcome = await pipeline.execute(mock_chat_client, request, options, chat_ctx, final_handler)

        assert outcome is None
        assert chat_ctx.terminate
        # The handler must not have been invoked.
        assert execution_order == []

    async def test_execute_with_post_next_termination(self, mock_chat_client: Any) -> None:
        """Setting terminate after next() still runs the handler and returns its response."""
        pipeline = ChatMiddlewarePipeline([self.PostNextTerminateChatMiddleware()])
        request = [ChatMessage(role=Role.USER, text="test")]
        options = ChatOptions()
        chat_ctx = ChatContext(chat_client=mock_chat_client, messages=request, chat_options=options)
        execution_order: list[str] = []

        async def final_handler(ctx: ChatContext) -> ChatResponse:
            execution_order.append("handler")
            return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])

        outcome = await pipeline.execute(mock_chat_client, request, options, chat_ctx, final_handler)

        assert outcome is not None
        assert len(outcome.messages) == 1
        assert outcome.messages[0].text == "response"
        assert chat_ctx.terminate
        assert execution_order == ["handler"]

    async def test_execute_stream_with_pre_next_termination(self, mock_chat_client: Any) -> None:
        """Setting terminate before next() in streaming mode skips the handler and yields nothing."""
        pipeline = ChatMiddlewarePipeline([self.PreNextTerminateChatMiddleware()])
        request = [ChatMessage(role=Role.USER, text="test")]
        options = ChatOptions()
        chat_ctx = ChatContext(
            chat_client=mock_chat_client, messages=request, chat_options=options, is_streaming=True
        )
        execution_order: list[str] = []

        async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
            # Only reached if the pipeline ignores the terminate flag; the assertions below rule that out.
            execution_order.append("handler_start")
            yield ChatResponseUpdate(contents=[TextContent(text="chunk1")])
            yield ChatResponseUpdate(contents=[TextContent(text="chunk2")])
            execution_order.append("handler_end")

        stream = pipeline.execute_stream(mock_chat_client, request, options, chat_ctx, final_handler)
        received: list[ChatResponseUpdate] = [update async for update in stream]

        assert chat_ctx.terminate
        # The handler must not have been invoked.
        assert execution_order == []
        assert not received

    async def test_execute_stream_with_post_next_termination(self, mock_chat_client: Any) -> None:
        """Setting terminate after next() in streaming mode still streams every handler update."""
        pipeline = ChatMiddlewarePipeline([self.PostNextTerminateChatMiddleware()])
        request = [ChatMessage(role=Role.USER, text="test")]
        options = ChatOptions()
        chat_ctx = ChatContext(
            chat_client=mock_chat_client, messages=request, chat_options=options, is_streaming=True
        )
        execution_order: list[str] = []

        async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
            execution_order.append("handler_start")
            yield ChatResponseUpdate(contents=[TextContent(text="chunk1")])
            yield ChatResponseUpdate(contents=[TextContent(text="chunk2")])
            execution_order.append("handler_end")

        stream = pipeline.execute_stream(mock_chat_client, request, options, chat_ctx, final_handler)
        received: list[ChatResponseUpdate] = [update async for update in stream]

        assert len(received) == 2
        assert received[0].text == "chunk1"
        assert received[1].text == "chunk2"
        assert chat_ctx.terminate
        assert execution_order == ["handler_start", "handler_end"]
|
||||
|
||||
|
||||
class TestClassBasedMiddleware:
|
||||
"""Test cases for class-based middleware implementations."""
|
||||
|
||||
@@ -601,6 +874,37 @@ class TestMixedMiddleware:
|
||||
assert result == "result"
|
||||
assert execution_order == ["class_before", "function_before", "handler", "function_after", "class_after"]
|
||||
|
||||
async def test_mixed_chat_middleware(self, mock_chat_client: Any) -> None:
    """Run a class-based and a function-based chat middleware together in one pipeline."""
    call_log: list[str] = []

    class ClassStyleChatMiddleware(ChatMiddleware):
        async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
            call_log.append("class_before")
            await next(context)
            call_log.append("class_after")

    async def function_style_chat_middleware(
        context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
    ) -> None:
        call_log.append("function_before")
        await next(context)
        call_log.append("function_after")

    # The class-based middleware is registered first, the function-based one second.
    pipeline = ChatMiddlewarePipeline([ClassStyleChatMiddleware(), function_style_chat_middleware])
    request = [ChatMessage(role=Role.USER, text="test")]
    options = ChatOptions()
    chat_ctx = ChatContext(chat_client=mock_chat_client, messages=request, chat_options=options)

    async def final_handler(ctx: ChatContext) -> ChatResponse:
        call_log.append("handler")
        return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])

    outcome = await pipeline.execute(mock_chat_client, request, options, chat_ctx, final_handler)

    assert outcome is not None
    assert call_log == ["class_before", "function_before", "handler", "function_after", "class_after"]
|
||||
|
||||
|
||||
class TestMultipleMiddlewareOrdering:
|
||||
"""Test cases for multiple middleware execution order."""
|
||||
@@ -695,6 +999,52 @@ class TestMultipleMiddlewareOrdering:
|
||||
expected_order = ["first_before", "second_before", "handler", "second_after", "first_after"]
|
||||
assert execution_order == expected_order
|
||||
|
||||
async def test_chat_middleware_execution_order(self, mock_chat_client: Any) -> None:
    """Three chat middlewares run in registration order going in and reverse order coming out."""
    call_log: list[str] = []

    class OuterChatMiddleware(ChatMiddleware):
        async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
            call_log.append("first_before")
            await next(context)
            call_log.append("first_after")

    class MiddleChatMiddleware(ChatMiddleware):
        async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
            call_log.append("second_before")
            await next(context)
            call_log.append("second_after")

    class InnerChatMiddleware(ChatMiddleware):
        async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
            call_log.append("third_before")
            await next(context)
            call_log.append("third_after")

    registered = [OuterChatMiddleware(), MiddleChatMiddleware(), InnerChatMiddleware()]
    pipeline = ChatMiddlewarePipeline(registered)  # type: ignore
    request = [ChatMessage(role=Role.USER, text="test")]
    options = ChatOptions()
    chat_ctx = ChatContext(chat_client=mock_chat_client, messages=request, chat_options=options)

    async def final_handler(ctx: ChatContext) -> ChatResponse:
        call_log.append("handler")
        return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])

    outcome = await pipeline.execute(mock_chat_client, request, options, chat_ctx, final_handler)

    assert outcome is not None
    # "before" markers follow registration order; "after" markers unwind in reverse.
    assert call_log == [
        "first_before",
        "second_before",
        "third_before",
        "handler",
        "third_after",
        "second_after",
        "first_after",
    ]
|
||||
|
||||
|
||||
class TestContextContentValidation:
|
||||
"""Test cases for validating middleware context content."""
|
||||
@@ -776,6 +1126,49 @@ class TestContextContentValidation:
|
||||
result = await pipeline.execute(mock_function, arguments, context, final_handler)
|
||||
assert result == "result"
|
||||
|
||||
async def test_chat_context_validation(self, mock_chat_client: Any) -> None:
    """Check the context handed to chat middleware is populated and that metadata reaches the handler."""
    expected_attributes = (
        "chat_client",
        "messages",
        "chat_options",
        "is_streaming",
        "metadata",
        "result",
        "terminate",
    )

    class ChatContextValidationMiddleware(ChatMiddleware):
        async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
            # Every expected attribute must be present on the context.
            for attribute in expected_attributes:
                assert hasattr(context, attribute)

            # The values must match what the test put into the context.
            assert context.chat_client is mock_chat_client
            assert len(context.messages) == 1
            only_message = context.messages[0]
            assert only_message.role == Role.USER
            assert only_message.text == "test"
            assert context.is_streaming is False
            assert isinstance(context.metadata, dict)
            assert isinstance(context.chat_options, ChatOptions)
            assert context.chat_options.temperature == 0.5

            # Leave a marker that the final handler checks for.
            context.metadata["validated"] = True

            await next(context)

    pipeline = ChatMiddlewarePipeline([ChatContextValidationMiddleware()])
    request = [ChatMessage(role=Role.USER, text="test")]
    options = ChatOptions(temperature=0.5)
    chat_ctx = ChatContext(chat_client=mock_chat_client, messages=request, chat_options=options)

    async def final_handler(ctx: ChatContext) -> ChatResponse:
        # The marker written by the middleware must be visible here.
        assert ctx.metadata.get("validated") is True
        return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])

    outcome = await pipeline.execute(mock_chat_client, request, options, chat_ctx, final_handler)
    assert outcome is not None
|
||||
|
||||
|
||||
class TestStreamingScenarios:
|
||||
"""Test cases for streaming and non-streaming scenarios."""
|
||||
@@ -857,6 +1250,89 @@ class TestStreamingScenarios:
|
||||
"stream_end",
|
||||
]
|
||||
|
||||
async def test_chat_streaming_flag_validation(self, mock_chat_client: Any) -> None:
    """Check the is_streaming flag seen by middleware and handler in non-streaming and streaming runs."""
    streaming_flags: list[bool] = []

    class ChatStreamingFlagMiddleware(ChatMiddleware):
        async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
            streaming_flags.append(context.is_streaming)
            await next(context)

    async def final_handler(ctx: ChatContext) -> ChatResponse:
        streaming_flags.append(ctx.is_streaming)
        return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])

    async def final_stream_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
        streaming_flags.append(ctx.is_streaming)
        yield ChatResponseUpdate(contents=[TextContent(text="chunk")])

    pipeline = ChatMiddlewarePipeline([ChatStreamingFlagMiddleware()])
    messages = [ChatMessage(role=Role.USER, text="test")]
    chat_options = ChatOptions()

    # Non-streaming run: is_streaming is not passed to the context.
    context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
    await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)

    # Streaming run: is_streaming=True is set explicitly on the context.
    context_stream = ChatContext(
        chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True
    )
    # Drain the stream so the streaming handler actually runs.
    updates: list[ChatResponseUpdate] = [
        update
        async for update in pipeline.execute_stream(
            mock_chat_client, messages, chat_options, context_stream, final_stream_handler
        )
    ]

    # Order: non-streaming middleware, non-streaming handler, streaming middleware, streaming handler.
    assert streaming_flags == [False, False, True, True]
|
||||
|
||||
async def test_chat_streaming_middleware_behavior(self, mock_chat_client: Any) -> None:
    """Check event ordering when a chat middleware wraps a streaming final handler."""
    chunks_processed: list[str] = []

    class ChatStreamProcessingMiddleware(ChatMiddleware):
        async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
            chunks_processed.append("before_stream")
            await next(context)
            chunks_processed.append("after_stream")

    async def final_stream_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
        chunks_processed.append("stream_start")
        yield ChatResponseUpdate(contents=[TextContent(text="chunk1")])
        chunks_processed.append("chunk1_yielded")
        yield ChatResponseUpdate(contents=[TextContent(text="chunk2")])
        chunks_processed.append("chunk2_yielded")
        chunks_processed.append("stream_end")

    pipeline = ChatMiddlewarePipeline([ChatStreamProcessingMiddleware()])
    messages = [ChatMessage(role=Role.USER, text="test")]
    chat_options = ChatOptions()
    context = ChatContext(
        chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True
    )

    updates: list[str] = [
        update.text
        async for update in pipeline.execute_stream(
            mock_chat_client, messages, chat_options, context, final_stream_handler
        )
    ]

    assert updates == ["chunk1", "chunk2"]
    # The expected trace has the middleware finishing before the stream body starts producing chunks.
    expected_trace = [
        "before_stream",
        "after_stream",
        "stream_start",
        "chunk1_yielded",
        "chunk2_yielded",
        "stream_end",
    ]
    assert chunks_processed == expected_trace
|
||||
|
||||
|
||||
# region Helper classes and fixtures
|
||||
|
||||
@@ -883,6 +1359,13 @@ class TestFunctionMiddleware(FunctionMiddleware):
|
||||
await next(context)
|
||||
|
||||
|
||||
class TestChatMiddleware(ChatMiddleware):
    """Pass-through ChatMiddleware used as a helper by the tests."""

    async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
        """Forward the context to the next step unchanged."""
        await next(context)
|
||||
|
||||
|
||||
class MockFunctionArgs(BaseModel):
|
||||
"""Test arguments for function middleware tests."""
|
||||
|
||||
@@ -1027,6 +1510,100 @@ class TestMiddlewareExecutionControl:
|
||||
assert result.messages == [] # Empty response
|
||||
assert not handler_called
|
||||
|
||||
async def test_chat_middleware_no_next_no_execution(self, mock_chat_client: Any) -> None:
    """Check that the final handler is skipped when a chat middleware never calls next()."""
    handler_invocations: list[ChatContext] = []

    class NoNextChatMiddleware(ChatMiddleware):
        async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
            # Intentionally returns without awaiting next().
            return None

    async def final_handler(ctx: ChatContext) -> ChatResponse:
        handler_invocations.append(ctx)
        return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="should not execute")])

    pipeline = ChatMiddlewarePipeline([NoNextChatMiddleware()])
    messages = [ChatMessage(role=Role.USER, text="test")]
    chat_options = ChatOptions()
    context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)

    result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)

    # Nothing ran: no result returned, handler untouched, context result unset.
    assert result is None
    assert not handler_invocations
    assert context.result is None
|
||||
|
||||
async def test_chat_middleware_no_next_no_streaming_execution(self, mock_chat_client: Any) -> None:
    """Check that a streaming run yields nothing when a chat middleware never calls next()."""
    handler_invocations: list[ChatContext] = []

    class NoNextStreamingChatMiddleware(ChatMiddleware):
        async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
            # Intentionally returns without awaiting next().
            return None

    async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
        handler_invocations.append(ctx)
        yield ChatResponseUpdate(contents=[TextContent(text="should not execute")])

    pipeline = ChatMiddlewarePipeline([NoNextStreamingChatMiddleware()])
    messages = [ChatMessage(role=Role.USER, text="test")]
    chat_options = ChatOptions()
    context = ChatContext(
        chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True
    )

    # Collect whatever the stream produces; the assertions below expect nothing.
    updates: list[ChatResponseUpdate] = [
        update
        async for update in pipeline.execute_stream(mock_chat_client, messages, chat_options, context, final_handler)
    ]

    assert len(updates) == 0
    assert not handler_invocations
    assert context.result is None
|
||||
|
||||
async def test_multiple_chat_middlewares_early_stop(self, mock_chat_client: Any) -> None:
    """Check that later middlewares and the handler are skipped when the first one never calls next()."""
    execution_order: list[str] = []
    handler_invocations: list[ChatContext] = []

    class FirstChatMiddleware(ChatMiddleware):
        async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
            # Records its visit and returns without awaiting next().
            execution_order.append("first")

    class SecondChatMiddleware(ChatMiddleware):
        async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
            execution_order.append("second")
            await next(context)

    async def final_handler(ctx: ChatContext) -> ChatResponse:
        handler_invocations.append(ctx)
        return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="should not execute")])

    pipeline = ChatMiddlewarePipeline([FirstChatMiddleware(), SecondChatMiddleware()])
    messages = [ChatMessage(role=Role.USER, text="test")]
    chat_options = ChatOptions()
    context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)

    result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)

    # Only the first middleware ran, and nothing was returned.
    assert execution_order == ["first"]
    assert result is None
    assert not handler_invocations
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent() -> AgentProtocol:
|
||||
@@ -1042,3 +1619,13 @@ def mock_function() -> AIFunction[Any, Any]:
|
||||
function = MagicMock(spec=AIFunction[Any, Any])
|
||||
function.name = "test_function"
|
||||
return function
|
||||
|
||||
|
||||
@pytest.fixture
def mock_chat_client() -> Any:
    """Provide a MagicMock constrained to the ChatClientProtocol spec."""
    # Imported inside the fixture body, as in the original.
    from agent_framework._clients import ChatClientProtocol

    chat_client_double = MagicMock(spec=ChatClientProtocol)
    # service_url is a callable mock that returns a fixed placeholder URL.
    chat_client_double.service_url = MagicMock(return_value="mock://test")
    return chat_client_double
|
||||
|
||||
@@ -8,7 +8,9 @@ import pytest
|
||||
from agent_framework import (
|
||||
AgentRunResponseUpdate,
|
||||
ChatAgent,
|
||||
ChatContext,
|
||||
ChatMessage,
|
||||
ChatMiddleware,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
FunctionCallContent,
|
||||
@@ -16,7 +18,9 @@ from agent_framework import (
|
||||
Role,
|
||||
TextContent,
|
||||
agent_middleware,
|
||||
chat_middleware,
|
||||
function_middleware,
|
||||
use_function_invocation,
|
||||
)
|
||||
from agent_framework._middleware import (
|
||||
AgentMiddleware,
|
||||
@@ -25,8 +29,9 @@ from agent_framework._middleware import (
|
||||
FunctionMiddleware,
|
||||
MiddlewareType,
|
||||
)
|
||||
from agent_framework.exceptions import MiddlewareException
|
||||
|
||||
from .conftest import MockChatClient
|
||||
from .conftest import MockBaseChatClient, MockChatClient
|
||||
|
||||
# region ChatAgent Tests
|
||||
|
||||
@@ -713,6 +718,101 @@ class TestChatAgentFunctionMiddlewareWithTools:
|
||||
assert function_calls[0].name == "sample_tool_function"
|
||||
assert function_results[0].call_id == function_calls[0].call_id
|
||||
|
||||
async def test_function_middleware_can_access_and_override_custom_kwargs(
    self, chat_client: "MockChatClient"
) -> None:
    """Check that function middleware can read run kwargs (including chat_options) and overwrite them."""
    captured_kwargs: dict[str, Any] = {}
    modified_kwargs: dict[str, Any] = {}
    middleware_called = False

    @function_middleware
    async def kwargs_middleware(
        context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
    ) -> None:
        nonlocal middleware_called
        middleware_called = True
        kwargs = context.kwargs

        # Step 1: snapshot what arrived before touching anything.
        captured_kwargs["has_chat_options"] = "chat_options" in kwargs
        captured_kwargs["has_custom_param"] = "custom_param" in kwargs
        captured_kwargs["custom_param"] = kwargs.get("custom_param")
        if "chat_options" in kwargs:
            incoming_options = kwargs["chat_options"]
            captured_kwargs["original_temperature"] = getattr(incoming_options, "temperature", None)
            captured_kwargs["original_max_tokens"] = getattr(incoming_options, "max_tokens", None)

        # Step 2: overwrite top-level kwargs, then mirror the change on chat_options.
        kwargs["temperature"] = 0.9
        kwargs["max_tokens"] = 500
        kwargs["new_param"] = "added_by_middleware"
        if "chat_options" in kwargs:
            kwargs["chat_options"].temperature = 0.9
            kwargs["chat_options"].max_tokens = 500

        # Step 3: record the post-modification view for the assertions below.
        for key in ("temperature", "max_tokens", "new_param", "custom_param"):
            modified_kwargs[key] = kwargs.get(key)
        if "chat_options" in kwargs:
            updated_options = kwargs["chat_options"]
            modified_kwargs["chat_options_temperature"] = getattr(updated_options, "temperature", None)
            modified_kwargs["chat_options_max_tokens"] = getattr(updated_options, "max_tokens", None)

        await next(context)

    # First scripted reply requests a tool call; the second is a plain text reply.
    tool_call = FunctionCallContent(
        call_id="test_call", name="sample_tool_function", arguments={"location": "Seattle"}
    )
    chat_client.responses = [
        ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, contents=[tool_call])]),
        ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent("Function completed")])]),
    ]

    agent = ChatAgent(chat_client=chat_client, middleware=[kwargs_middleware], tools=[sample_tool_function])

    # Run with two option-like kwargs plus one arbitrary custom kwarg.
    messages = [ChatMessage(role=Role.USER, text="test message")]
    response = await agent.run(messages, temperature=0.7, max_tokens=100, custom_param="test_value")

    assert response is not None
    assert len(response.messages) > 0

    # Fail early with a clear message if the middleware was never reached.
    assert middleware_called, "Function middleware was not called"

    # Values observed on entry to the middleware.
    assert captured_kwargs["has_chat_options"] is True
    assert captured_kwargs["has_custom_param"] is True
    assert captured_kwargs["custom_param"] == "test_value"
    assert captured_kwargs["original_temperature"] == 0.7
    assert captured_kwargs["original_max_tokens"] == 100

    # Values observed after the middleware's overrides.
    assert modified_kwargs["temperature"] == 0.9
    assert modified_kwargs["max_tokens"] == 500
    assert modified_kwargs["new_param"] == "added_by_middleware"
    assert modified_kwargs["custom_param"] == "test_value"
    assert modified_kwargs["chat_options_temperature"] == 0.9
    assert modified_kwargs["chat_options_max_tokens"] == 500
|
||||
|
||||
|
||||
class TestMiddlewareDynamicRebuild:
|
||||
"""Test cases for dynamic middleware pipeline rebuilding with ChatAgent."""
|
||||
@@ -1187,8 +1287,8 @@ class TestMiddlewareDecoratorLogic:
|
||||
"""Both decorator and parameter type specified but don't match."""
|
||||
|
||||
# This will cause a type error at decoration time, so we need to test differently
|
||||
# Should raise ValueError due to mismatch during agent creation
|
||||
with pytest.raises(ValueError, match="Middleware type mismatch"):
|
||||
# Should raise MiddlewareException due to mismatch during agent creation
|
||||
with pytest.raises(MiddlewareException, match="Middleware type mismatch"):
|
||||
|
||||
@agent_middleware # type: ignore[arg-type]
|
||||
async def mismatched_middleware(
|
||||
@@ -1304,8 +1404,8 @@ class TestMiddlewareDecoratorLogic:
|
||||
async def no_info_middleware(context: Any, next: Any) -> None: # No decorator, no type
|
||||
await next(context)
|
||||
|
||||
# Should raise ValueError
|
||||
with pytest.raises(ValueError, match="Cannot determine middleware type"):
|
||||
# Should raise MiddlewareException
|
||||
with pytest.raises(MiddlewareException, match="Cannot determine middleware type"):
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[no_info_middleware])
|
||||
await agent.run([ChatMessage(role=Role.USER, text="test")])
|
||||
|
||||
@@ -1313,8 +1413,8 @@ class TestMiddlewareDecoratorLogic:
|
||||
"""Test that middleware with insufficient parameters raises an error."""
|
||||
from agent_framework import ChatAgent, agent_middleware
|
||||
|
||||
# Should raise ValueError about insufficient parameters
|
||||
with pytest.raises(ValueError, match="must have at least 2 parameters"):
|
||||
# Should raise MiddlewareException about insufficient parameters
|
||||
with pytest.raises(MiddlewareException, match="must have at least 2 parameters"):
|
||||
|
||||
@agent_middleware # type: ignore[arg-type]
|
||||
async def insufficient_params_middleware(context: Any) -> None: # Missing 'next' parameter
|
||||
@@ -1340,3 +1440,359 @@ class TestMiddlewareDecoratorLogic:
|
||||
|
||||
assert hasattr(test_function_middleware, "_middleware_type")
|
||||
assert test_function_middleware._middleware_type == MiddlewareType.FUNCTION # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class TestChatAgentChatMiddleware:
|
||||
"""Test cases for chat middleware integration with ChatAgent."""
|
||||
|
||||
async def test_class_based_chat_middleware_with_chat_agent(self) -> None:
|
||||
"""Test class-based chat middleware with ChatAgent."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class TrackingChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("chat_middleware_before")
|
||||
await next(context)
|
||||
execution_order.append("chat_middleware_after")
|
||||
|
||||
# Create ChatAgent with chat middleware
|
||||
chat_client = MockBaseChatClient()
|
||||
middleware = TrackingChatMiddleware()
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[middleware])
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert response.messages[0].role == Role.ASSISTANT
|
||||
assert "test response" in response.messages[0].text
|
||||
assert execution_order == ["chat_middleware_before", "chat_middleware_after"]
|
||||
|
||||
async def test_function_based_chat_middleware_with_chat_agent(self) -> None:
|
||||
"""Test function-based chat middleware with ChatAgent."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
async def tracking_chat_middleware(
|
||||
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("chat_middleware_before")
|
||||
await next(context)
|
||||
execution_order.append("chat_middleware_after")
|
||||
|
||||
# Create ChatAgent with function-based chat middleware
|
||||
chat_client = MockBaseChatClient()
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[tracking_chat_middleware])
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert response.messages[0].role == Role.ASSISTANT
|
||||
assert "test response" in response.messages[0].text
|
||||
assert execution_order == ["chat_middleware_before", "chat_middleware_after"]
|
||||
|
||||
async def test_chat_middleware_can_modify_messages(self) -> None:
|
||||
"""Test that chat middleware can modify messages before sending to model."""
|
||||
|
||||
@chat_middleware
|
||||
async def message_modifier_middleware(
|
||||
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
|
||||
) -> None:
|
||||
# Modify the first message by adding a prefix
|
||||
if context.messages and len(context.messages) > 0:
|
||||
original_text = context.messages[0].text or ""
|
||||
context.messages[0] = ChatMessage(role=context.messages[0].role, text=f"MODIFIED: {original_text}")
|
||||
await next(context)
|
||||
|
||||
# Create ChatAgent with message-modifying middleware
|
||||
chat_client = MockBaseChatClient()
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[message_modifier_middleware])
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify that the message was modified (MockBaseChatClient echoes back the input)
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert "MODIFIED: test message" in response.messages[0].text
|
||||
|
||||
async def test_chat_middleware_can_override_response(self) -> None:
|
||||
"""Test that chat middleware can override the response."""
|
||||
|
||||
@chat_middleware
|
||||
async def response_override_middleware(
|
||||
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
|
||||
) -> None:
|
||||
# Override the response without calling next()
|
||||
context.result = ChatResponse(
|
||||
messages=[ChatMessage(role=Role.ASSISTANT, text="Middleware overridden response")],
|
||||
response_id="middleware-response-123",
|
||||
)
|
||||
context.terminate = True
|
||||
|
||||
# Create ChatAgent with response-overriding middleware
|
||||
chat_client = MockBaseChatClient()
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[response_override_middleware])
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify that the response was overridden
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert response.messages[0].text == "Middleware overridden response"
|
||||
assert response.response_id == "middleware-response-123"
|
||||
|
||||
async def test_multiple_chat_middleware_execution_order(self) -> None:
|
||||
"""Test that multiple chat middleware execute in the correct order."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
@chat_middleware
|
||||
async def first_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("first_before")
|
||||
await next(context)
|
||||
execution_order.append("first_after")
|
||||
|
||||
@chat_middleware
|
||||
async def second_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("second_before")
|
||||
await next(context)
|
||||
execution_order.append("second_after")
|
||||
|
||||
# Create ChatAgent with multiple chat middleware
|
||||
chat_client = MockBaseChatClient()
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[first_middleware, second_middleware])
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert execution_order == ["first_before", "second_before", "second_after", "first_after"]
|
||||
|
||||
async def test_chat_middleware_with_streaming(self) -> None:
|
||||
"""Test chat middleware with streaming responses."""
|
||||
execution_order: list[str] = []
|
||||
streaming_flags: list[bool] = []
|
||||
|
||||
class StreamingTrackingChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("streaming_chat_before")
|
||||
streaming_flags.append(context.is_streaming)
|
||||
await next(context)
|
||||
execution_order.append("streaming_chat_after")
|
||||
|
||||
# Create ChatAgent with chat middleware
|
||||
chat_client = MockBaseChatClient()
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[StreamingTrackingChatMiddleware()])
|
||||
|
||||
# Set up mock streaming responses
|
||||
chat_client.streaming_responses = [
|
||||
[
|
||||
ChatResponseUpdate(contents=[TextContent(text="Stream")], role=Role.ASSISTANT),
|
||||
ChatResponseUpdate(contents=[TextContent(text=" response")], role=Role.ASSISTANT),
|
||||
]
|
||||
]
|
||||
|
||||
# Execute streaming
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
updates: list[AgentRunResponseUpdate] = []
|
||||
async for update in agent.run_stream(messages):
|
||||
updates.append(update)
|
||||
|
||||
# Verify streaming response
|
||||
assert len(updates) >= 1 # At least some updates
|
||||
assert execution_order == ["streaming_chat_before", "streaming_chat_after"]
|
||||
|
||||
# Verify streaming flag was set (at least one True)
|
||||
assert True in streaming_flags
|
||||
|
||||
async def test_chat_middleware_termination_before_execution(self) -> None:
|
||||
"""Test that chat middleware can terminate execution before calling next()."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class PreTerminationChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("middleware_before")
|
||||
context.terminate = True
|
||||
# Set a custom response since we're terminating
|
||||
context.result = ChatResponse(
|
||||
messages=[ChatMessage(role=Role.ASSISTANT, text="Terminated by middleware")]
|
||||
)
|
||||
# We call next() but since terminate=True, execution should stop
|
||||
await next(context)
|
||||
execution_order.append("middleware_after")
|
||||
|
||||
# Create ChatAgent with terminating middleware
|
||||
chat_client = MockBaseChatClient()
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[PreTerminationChatMiddleware()])
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify response was from middleware
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert response.messages[0].text == "Terminated by middleware"
|
||||
assert execution_order == ["middleware_before", "middleware_after"]
|
||||
|
||||
async def test_chat_middleware_termination_after_execution(self) -> None:
|
||||
"""Test that chat middleware can terminate execution after calling next()."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class PostTerminationChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("middleware_before")
|
||||
await next(context)
|
||||
execution_order.append("middleware_after")
|
||||
context.terminate = True
|
||||
|
||||
# Create ChatAgent with terminating middleware
|
||||
chat_client = MockBaseChatClient()
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[PostTerminationChatMiddleware()])
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify response is from actual execution
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert "test response" in response.messages[0].text
|
||||
assert execution_order == ["middleware_before", "middleware_after"]
|
||||
|
||||
async def test_combined_middleware(self) -> None:
|
||||
"""Test ChatAgent with combined middleware types."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
async def agent_middleware(
|
||||
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("agent_middleware_before")
|
||||
await next(context)
|
||||
execution_order.append("agent_middleware_after")
|
||||
|
||||
async def chat_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("chat_middleware_before")
|
||||
await next(context)
|
||||
execution_order.append("chat_middleware_after")
|
||||
|
||||
async def function_middleware(
|
||||
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("function_middleware_before")
|
||||
await next(context)
|
||||
execution_order.append("function_middleware_after")
|
||||
|
||||
# Set up mock to return a function call first, then a regular response
|
||||
function_call_response = ChatResponse(
|
||||
messages=[
|
||||
ChatMessage(
|
||||
role=Role.ASSISTANT,
|
||||
contents=[
|
||||
FunctionCallContent(
|
||||
call_id="call_456",
|
||||
name="sample_tool_function",
|
||||
arguments='{"location": "San Francisco"}',
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
final_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Final response")])
|
||||
|
||||
chat_client = use_function_invocation(MockBaseChatClient)()
|
||||
chat_client.run_responses = [function_call_response, final_response]
|
||||
|
||||
# Create ChatAgent with function middleware and tools
|
||||
agent = ChatAgent(
|
||||
chat_client=chat_client,
|
||||
middleware=[chat_middleware, function_middleware, agent_middleware],
|
||||
tools=[sample_tool_function],
|
||||
)
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="Get weather for San Francisco")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert chat_client.call_count == 2 # Two calls: one for function call, one for final response
|
||||
|
||||
# Verify function middleware was executed
|
||||
assert execution_order == [
|
||||
"agent_middleware_before",
|
||||
"chat_middleware_before",
|
||||
"chat_middleware_after",
|
||||
"function_middleware_before",
|
||||
"function_middleware_after",
|
||||
"chat_middleware_before",
|
||||
"chat_middleware_after",
|
||||
"agent_middleware_after",
|
||||
]
|
||||
|
||||
# Verify function call and result are in the response
|
||||
all_contents = [content for message in response.messages for content in message.contents]
|
||||
function_calls = [c for c in all_contents if isinstance(c, FunctionCallContent)]
|
||||
function_results = [c for c in all_contents if isinstance(c, FunctionResultContent)]
|
||||
|
||||
assert len(function_calls) == 1
|
||||
assert len(function_results) == 1
|
||||
assert function_calls[0].name == "sample_tool_function"
|
||||
assert function_results[0].call_id == function_calls[0].call_id
|
||||
|
||||
async def test_agent_middleware_can_access_and_override_custom_kwargs(self) -> None:
    """Check that agent middleware can read and rewrite the kwargs of a run.

    The middleware snapshots ``context.kwargs`` as it receives them, overwrites
    ``temperature`` and ``max_tokens``, adds ``new_param``, snapshots again and
    then continues the pipeline. Only the two snapshots are asserted on; this
    test does not inspect what the chat client finally receives.
    """
    seen_on_entry: dict[str, Any] = {}
    seen_after_edit: dict[str, Any] = {}

    @agent_middleware
    async def kwargs_middleware(
        context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
    ) -> None:
        # Snapshot first so the later edits cannot leak into the "original" view.
        seen_on_entry.update(context.kwargs)

        # Overwrite two caller-supplied values and introduce one new key.
        context.kwargs.update(temperature=0.9, max_tokens=500, new_param="added_by_middleware")

        # Second snapshot: what the rest of the pipeline is handed.
        seen_after_edit.update(context.kwargs)

        await next(context)

    # Agent-level registration of the middleware on a plain mock client.
    agent = ChatAgent(chat_client=MockBaseChatClient(), middleware=[kwargs_middleware])

    # Run with custom keyword arguments for the middleware to observe.
    response = await agent.run(
        [ChatMessage(role=Role.USER, text="test message")],
        temperature=0.7,
        max_tokens=100,
        custom_param="test_value",
    )

    # The run still produced a response.
    assert response is not None
    assert len(response.messages) > 0

    # The middleware saw exactly what the caller passed in.
    for key, expected in (("temperature", 0.7), ("max_tokens", 100), ("custom_param", "test_value")):
        assert seen_on_entry[key] == expected

    # The middleware's edits are visible, and the untouched key survived.
    assert seen_after_edit["temperature"] == 0.9
    assert seen_after_edit["max_tokens"] == 500
    assert seen_after_edit["new_param"] == "added_by_middleware"
    assert seen_after_edit["custom_param"] == "test_value"  # Should still be there
|
||||
|
||||
@@ -0,0 +1,436 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import (
|
||||
ChatAgent,
|
||||
ChatContext,
|
||||
ChatMessage,
|
||||
ChatMiddleware,
|
||||
ChatResponse,
|
||||
FunctionCallContent,
|
||||
FunctionInvocationContext,
|
||||
Role,
|
||||
chat_middleware,
|
||||
function_middleware,
|
||||
use_function_invocation,
|
||||
)
|
||||
|
||||
from .conftest import MockBaseChatClient
|
||||
|
||||
|
||||
class TestChatMiddleware:
    """Tests for chat middleware on chat clients and on ``ChatAgent``.

    Also covers function middleware registered on, or passed per call to, a
    function-invocation-enabled chat client.
    """

    @staticmethod
    def _user_turn(text: str = "test message") -> list[ChatMessage]:
        """Return a fresh one-element conversation holding a single user message."""
        return [ChatMessage(role=Role.USER, text=text)]

    @staticmethod
    def _sample_tool_call(call_id: str, location: str) -> ChatResponse:
        """Return an assistant response that requests one ``sample_tool`` call."""
        return ChatResponse(
            messages=[
                ChatMessage(
                    role=Role.ASSISTANT,
                    contents=[
                        FunctionCallContent(
                            call_id=call_id,
                            name="sample_tool",
                            arguments={"location": location},
                        )
                    ],
                )
            ]
        )

    async def test_class_based_chat_middleware(self, chat_client_base: "MockBaseChatClient") -> None:
        """A ``ChatMiddleware`` subclass set on the client wraps ``get_response``."""
        events: list[str] = []

        class LoggingChatMiddleware(ChatMiddleware):
            async def process(
                self,
                context: ChatContext,
                next: Callable[[ChatContext], Awaitable[None]],
            ) -> None:
                events.append("chat_middleware_before")
                await next(context)
                events.append("chat_middleware_after")

        # Register an instance directly on the chat client.
        chat_client_base.middleware = [LoggingChatMiddleware()]

        # Call the chat client itself; no agent is involved.
        reply = await chat_client_base.get_response(self._user_turn())

        # A normal assistant response came back.
        assert reply is not None
        assert len(reply.messages) > 0
        assert reply.messages[0].role == Role.ASSISTANT

        # The middleware ran once, around the call.
        assert events == ["chat_middleware_before", "chat_middleware_after"]

    async def test_function_based_chat_middleware(self, chat_client_base: "MockBaseChatClient") -> None:
        """A ``@chat_middleware`` coroutine set on the client wraps ``get_response``."""
        events: list[str] = []

        @chat_middleware
        async def logging_chat_middleware(
            context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
        ) -> None:
            # "function" in these labels means function-style (as opposed to class-style) middleware.
            events.append("function_middleware_before")
            await next(context)
            events.append("function_middleware_after")

        # Register the decorated coroutine directly on the chat client.
        chat_client_base.middleware = [logging_chat_middleware]

        reply = await chat_client_base.get_response(self._user_turn())

        # A normal assistant response came back.
        assert reply is not None
        assert len(reply.messages) > 0
        assert reply.messages[0].role == Role.ASSISTANT

        # The middleware ran once, around the call.
        assert events == ["function_middleware_before", "function_middleware_after"]

    async def test_chat_middleware_can_modify_messages(self, chat_client_base: "MockBaseChatClient") -> None:
        """Chat middleware can replace a message before the client processes it."""

        @chat_middleware
        async def message_modifier_middleware(
            context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
        ) -> None:
            # Swap the first message for a copy whose text carries a marker prefix.
            if context.messages:
                first = context.messages[0]
                context.messages[0] = ChatMessage(role=first.role, text=f"MODIFIED: {first.text or ''}")
            await next(context)

        chat_client_base.middleware = [message_modifier_middleware]

        reply = await chat_client_base.get_response(self._user_turn())

        # NOTE(review): this relies on the mock client echoing its input back in
        # the response text - confirm against the conftest mock.
        assert reply is not None
        assert len(reply.messages) > 0
        assert "MODIFIED: test message" in reply.messages[0].text

    async def test_chat_middleware_can_override_response(self, chat_client_base: "MockBaseChatClient") -> None:
        """Chat middleware can supply the result itself instead of calling ``next``."""

        @chat_middleware
        async def response_override_middleware(
            context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
        ) -> None:
            # Deliberately never awaits next(): the result is provided here.
            context.result = ChatResponse(
                messages=[ChatMessage(role=Role.ASSISTANT, text="Middleware overridden response")],
                response_id="middleware-response-123",
            )
            context.terminate = True

        chat_client_base.middleware = [response_override_middleware]

        reply = await chat_client_base.get_response(self._user_turn())

        # The caller receives the middleware's response, id included.
        assert reply is not None
        assert len(reply.messages) > 0
        assert reply.messages[0].text == "Middleware overridden response"
        assert reply.response_id == "middleware-response-123"

    async def test_multiple_chat_middleware_execution_order(self, chat_client_base: "MockBaseChatClient") -> None:
        """Two chat middleware on the client run nested, in registration order."""
        events: list[str] = []

        @chat_middleware
        async def first_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
            events.append("first_before")
            await next(context)
            events.append("first_after")

        @chat_middleware
        async def second_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
            events.append("second_before")
            await next(context)
            events.append("second_after")

        # List order is the order the assertions below expect to be preserved.
        chat_client_base.middleware = [first_middleware, second_middleware]

        reply = await chat_client_base.get_response(self._user_turn())

        assert reply is not None

        # First wraps second: before-hooks in order, after-hooks reversed.
        assert events == ["first_before", "second_before", "second_after", "first_after"]

    async def test_chat_agent_with_chat_middleware(self) -> None:
        """Chat middleware passed to ``ChatAgent`` runs during ``agent.run``."""
        events: list[str] = []

        @chat_middleware
        async def agent_level_chat_middleware(
            context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
        ) -> None:
            events.append("agent_chat_middleware_before")
            await next(context)
            events.append("agent_chat_middleware_after")

        # The middleware is registered on the agent, not on the client.
        agent = ChatAgent(chat_client=MockBaseChatClient(), middleware=[agent_level_chat_middleware])

        reply = await agent.run(self._user_turn())

        # A normal assistant response came back.
        assert reply is not None
        assert len(reply.messages) > 0
        assert reply.messages[0].role == Role.ASSISTANT

        # The middleware ran once, around the call.
        assert events == ["agent_chat_middleware_before", "agent_chat_middleware_after"]

    async def test_chat_agent_with_multiple_chat_middleware(self, chat_client_base: "MockBaseChatClient") -> None:
        """Two chat middleware passed to ``ChatAgent`` run nested, in list order."""
        events: list[str] = []

        @chat_middleware
        async def first_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
            events.append("first_before")
            await next(context)
            events.append("first_after")

        @chat_middleware
        async def second_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
            events.append("second_before")
            await next(context)
            events.append("second_after")

        agent = ChatAgent(chat_client=chat_client_base, middleware=[first_middleware, second_middleware])

        reply = await agent.run(self._user_turn())

        assert reply is not None

        # First wraps second: before-hooks in order, after-hooks reversed.
        assert events == ["first_before", "second_before", "second_after", "first_after"]

    async def test_chat_middleware_with_streaming(self, chat_client_base: "MockBaseChatClient") -> None:
        """Chat middleware on the client also wraps ``get_streaming_response``."""
        events: list[str] = []

        @chat_middleware
        async def streaming_middleware(
            context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
        ) -> None:
            events.append("streaming_before")
            # The context must be flagged as streaming on this code path.
            assert context.is_streaming is True
            await next(context)
            events.append("streaming_after")

        chat_client_base.middleware = [streaming_middleware]

        # Drain the whole stream so the middleware's "after" step has a chance to run.
        updates: list[object] = [
            update async for update in chat_client_base.get_streaming_response(self._user_turn())
        ]

        # At least one update was produced.
        assert len(updates) > 0

        # The middleware ran once, around the stream.
        assert events == ["streaming_before", "streaming_after"]

    async def test_run_level_middleware_isolation(self, chat_client_base: "MockBaseChatClient") -> None:
        """Middleware passed to one ``get_response`` call does not persist to later calls."""
        invocations = 0

        @chat_middleware
        async def counting_middleware(
            context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
        ) -> None:
            nonlocal invocations
            invocations += 1
            await next(context)

        # Call 1: middleware supplied for this call only.
        first = await chat_client_base.get_response(
            self._user_turn("first message"), middleware=[counting_middleware]
        )
        assert first is not None
        assert invocations == 1

        # Call 2: no middleware argument, so the counter must not move.
        second = await chat_client_base.get_response(self._user_turn("second message"))
        assert second is not None
        assert invocations == 1  # Should still be 1, not 2

        # Call 3: supplied again, so it runs again.
        third = await chat_client_base.get_response(
            self._user_turn("third message"), middleware=[counting_middleware]
        )
        assert third is not None
        assert invocations == 2  # Should be 2 now

    async def test_chat_client_middleware_can_access_and_override_custom_kwargs(
        self, chat_client_base: "MockBaseChatClient"
    ) -> None:
        """Chat client middleware can read and rewrite the kwargs of a call.

        Only the middleware's own before and after snapshots of
        ``context.kwargs`` are asserted on.
        """
        seen_on_entry: dict[str, Any] = {}
        seen_after_edit: dict[str, Any] = {}

        @chat_middleware
        async def kwargs_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
            # Snapshot first so the later edits cannot leak into the "original" view.
            seen_on_entry.update(context.kwargs)

            # Overwrite two caller-supplied values and introduce one new key.
            context.kwargs.update(temperature=0.9, max_tokens=500, new_param="added_by_middleware")

            # Second snapshot: what the rest of the pipeline is handed.
            seen_after_edit.update(context.kwargs)

            await next(context)

        chat_client_base.middleware = [kwargs_middleware]

        # Call with custom keyword arguments for the middleware to observe.
        reply = await chat_client_base.get_response(
            self._user_turn(), temperature=0.7, max_tokens=100, custom_param="test_value"
        )

        assert reply is not None
        assert len(reply.messages) > 0

        # The middleware saw exactly what the caller passed in.
        for key, expected in (("temperature", 0.7), ("max_tokens", 100), ("custom_param", "test_value")):
            assert seen_on_entry[key] == expected

        # The middleware's edits are visible, and the untouched key survived.
        assert seen_after_edit["temperature"] == 0.9
        assert seen_after_edit["max_tokens"] == 500
        assert seen_after_edit["new_param"] == "added_by_middleware"
        assert seen_after_edit["custom_param"] == "test_value"  # Should still be there

    async def test_function_middleware_registration_on_chat_client(self) -> None:
        """Function middleware set on the chat client runs around a tool invocation."""
        events: list[str] = []

        @function_middleware
        async def tracking_function_middleware(
            context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
        ) -> None:
            events.append(f"function_middleware_before_{context.function.name}")
            await next(context)
            events.append(f"function_middleware_after_{context.function.name}")

        # Tool the scripted function call refers to by name.
        def sample_tool(location: str) -> str:
            """Get weather for a location."""
            return f"Weather in {location}: sunny"

        # Client class wrapped so that function calls in responses are executed.
        chat_client = use_function_invocation(MockBaseChatClient)()

        # Function middleware is registered on the client itself.
        chat_client.middleware = [tracking_function_middleware]

        # Scripted replies: first a tool call, then the final answer.
        chat_client.run_responses = [
            self._sample_tool_call("call_1", "San Francisco"),
            ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Based on the weather data, it's sunny!")]),
        ]

        # Calling the client with tools should trigger the invocation and the middleware.
        reply = await chat_client.get_response(
            self._user_turn("What's the weather in San Francisco?"), tools=[sample_tool]
        )

        assert reply is not None
        assert len(reply.messages) > 0
        assert chat_client.call_count == 2  # Two calls: function call + final response

        # The middleware wrapped exactly one invocation of sample_tool.
        assert events == [
            "function_middleware_before_sample_tool",
            "function_middleware_after_sample_tool",
        ]

    async def test_run_level_function_middleware(self) -> None:
        """Function middleware passed to ``get_response`` runs around a tool invocation."""
        events: list[str] = []

        @function_middleware
        async def run_level_function_middleware(
            context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
        ) -> None:
            events.append("run_level_function_middleware_before")
            await next(context)
            events.append("run_level_function_middleware_after")

        # Tool the scripted function call refers to by name.
        def sample_tool(location: str) -> str:
            """Get weather for a location."""
            return f"Weather in {location}: sunny"

        # Client class wrapped so that function calls in responses are executed.
        chat_client = use_function_invocation(MockBaseChatClient)()

        # Scripted replies: first a tool call, then the final answer.
        chat_client.run_responses = [
            self._sample_tool_call("call_2", "New York"),
            ChatResponse(
                messages=[ChatMessage(role=Role.ASSISTANT, text="The weather information has been retrieved!")]
            ),
        ]

        # The middleware is supplied for this call only, alongside the tool.
        reply = await chat_client.get_response(
            self._user_turn("What's the weather in New York?"),
            tools=[sample_tool],
            middleware=[run_level_function_middleware],
        )

        assert reply is not None
        assert len(reply.messages) > 0
        assert chat_client.call_count == 2  # Two calls: function call + final response

        # The run-level middleware ran exactly once, around the function invocation.
        assert events == [
            "run_level_function_middleware_before",
            "run_level_function_middleware_after",
        ]
|
||||
Reference in New Issue
Block a user