mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Added chat middleware and more examples (#883)
* Added example with stateful middleware * Added chat middleware * Updated middleware example with override scenario * Small revert * Small fixes * Added kwargs to context objects * Added README * Added function middleware to chat client * Small refactoring * Reverted example files * Made MiddlewareWrapper generic * Added Middleware exception * Small refactoring * Small fix
This commit is contained in:
committed by
GitHub
Unverified
parent
863c8d7471
commit
eec7f192eb
@@ -34,6 +34,7 @@ from agent_framework import (
|
||||
UsageContent,
|
||||
UsageDetails,
|
||||
get_logger,
|
||||
use_chat_middleware,
|
||||
use_function_invocation,
|
||||
)
|
||||
from agent_framework._pydantic import AFBaseSettings
|
||||
@@ -123,6 +124,7 @@ TAzureAIAgentClient = TypeVar("TAzureAIAgentClient", bound="AzureAIAgentClient")
|
||||
|
||||
@use_function_invocation
|
||||
@use_observability
|
||||
@use_chat_middleware
|
||||
class AzureAIAgentClient(BaseChatClient):
|
||||
"""Azure AI Agent Chat client."""
|
||||
|
||||
|
||||
@@ -10,7 +10,13 @@ from pydantic import BaseModel, Field
|
||||
from ._logging import get_logger
|
||||
from ._mcp import MCPTool
|
||||
from ._memory import AggregateContextProvider, ContextProvider
|
||||
from ._middleware import Middleware
|
||||
from ._middleware import (
|
||||
ChatMiddleware,
|
||||
ChatMiddlewareCallable,
|
||||
FunctionMiddleware,
|
||||
FunctionMiddlewareCallable,
|
||||
Middleware,
|
||||
)
|
||||
from ._pydantic import AFBaseModel
|
||||
from ._threads import ChatMessageStore
|
||||
from ._tools import ToolProtocol
|
||||
@@ -189,6 +195,14 @@ class BaseChatClient(AFBaseModel, ABC):
|
||||
"""Base class for chat clients."""
|
||||
|
||||
additional_properties: dict[str, Any] = Field(default_factory=dict)
|
||||
middleware: (
|
||||
ChatMiddleware
|
||||
| ChatMiddlewareCallable
|
||||
| FunctionMiddleware
|
||||
| FunctionMiddlewareCallable
|
||||
| list[ChatMiddleware | ChatMiddlewareCallable | FunctionMiddleware | FunctionMiddlewareCallable]
|
||||
| None
|
||||
) = None
|
||||
OTEL_PROVIDER_NAME: str = "unknown"
|
||||
# This is used for OTel setup, should be overridden in subclasses
|
||||
|
||||
@@ -346,13 +360,7 @@ class BaseChatClient(AFBaseModel, ABC):
|
||||
prepped_messages = self.prepare_messages(messages)
|
||||
self._prepare_tool_choice(chat_options=chat_options)
|
||||
|
||||
# Remove middleware pipeline from kwargs as it's only used by function invocation wrappers
|
||||
if "_function_middleware_pipeline" in kwargs:
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "_function_middleware_pipeline"}
|
||||
else:
|
||||
filtered_kwargs = kwargs
|
||||
|
||||
return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **filtered_kwargs)
|
||||
return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **kwargs)
|
||||
|
||||
async def get_streaming_response(
|
||||
self,
|
||||
@@ -432,14 +440,8 @@ class BaseChatClient(AFBaseModel, ABC):
|
||||
prepped_messages = self.prepare_messages(messages)
|
||||
self._prepare_tool_choice(chat_options=chat_options)
|
||||
|
||||
# Remove middleware pipeline from kwargs as it's only used by function invocation wrappers
|
||||
if "_function_middleware_pipeline" in kwargs:
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "_function_middleware_pipeline"}
|
||||
else:
|
||||
filtered_kwargs = kwargs
|
||||
|
||||
async for update in self._inner_get_streaming_response(
|
||||
messages=prepped_messages, chat_options=chat_options, **filtered_kwargs
|
||||
messages=prepped_messages, chat_options=chat_options, **kwargs
|
||||
):
|
||||
yield update
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -607,6 +607,7 @@ async def _auto_invoke_function(
|
||||
middleware_context = FunctionInvocationContext(
|
||||
function=tool,
|
||||
arguments=args,
|
||||
kwargs=custom_args or {},
|
||||
)
|
||||
|
||||
async def final_function_handler(context_obj: Any) -> Any:
|
||||
@@ -721,8 +722,16 @@ def _handle_function_calls_response(
|
||||
**kwargs: Any,
|
||||
) -> "ChatResponse":
|
||||
from ._clients import prepare_messages
|
||||
from ._middleware import extract_and_merge_function_middleware
|
||||
from ._types import ChatMessage, ChatOptions, FunctionCallContent, FunctionResultContent
|
||||
|
||||
# Extract and merge function middleware from chat client with kwargs pipeline
|
||||
extract_and_merge_function_middleware(self, kwargs)
|
||||
|
||||
# Extract the middleware pipeline before calling the underlying function
|
||||
# because the underlying function may not preserve it in kwargs
|
||||
stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline")
|
||||
|
||||
prepped_messages = prepare_messages(messages)
|
||||
response: "ChatResponse | None" = None
|
||||
fcc_messages: "list[ChatMessage]" = []
|
||||
@@ -746,8 +755,9 @@ def _handle_function_calls_response(
|
||||
if not tools and (chat_options := kwargs.get("chat_options")) and isinstance(chat_options, ChatOptions):
|
||||
tools = chat_options.tools
|
||||
if function_calls and tools:
|
||||
# Extract function middleware pipeline from kwargs if available
|
||||
middleware_pipeline = kwargs.get("_function_middleware_pipeline")
|
||||
# Use the stored middleware pipeline instead of extracting from kwargs
|
||||
# because kwargs may have been modified by the underlying function
|
||||
middleware_pipeline = stored_middleware_pipeline
|
||||
function_results = await execute_function_calls(
|
||||
custom_args=kwargs,
|
||||
attempt_idx=attempt_idx,
|
||||
@@ -820,8 +830,16 @@ def _handle_function_calls_streaming_response(
|
||||
) -> AsyncIterable["ChatResponseUpdate"]:
|
||||
"""Wrap the inner get streaming response method to handle tool calls."""
|
||||
from ._clients import prepare_messages
|
||||
from ._middleware import extract_and_merge_function_middleware
|
||||
from ._types import ChatMessage, ChatOptions, ChatResponse, ChatResponseUpdate, FunctionCallContent
|
||||
|
||||
# Extract and merge function middleware from chat client with kwargs pipeline
|
||||
extract_and_merge_function_middleware(self, kwargs)
|
||||
|
||||
# Extract the middleware pipeline before calling the underlying function
|
||||
# because the underlying function may not preserve it in kwargs
|
||||
stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline")
|
||||
|
||||
prepped_messages = prepare_messages(messages)
|
||||
for attempt_idx in range(max_iterations):
|
||||
all_updates: list["ChatResponseUpdate"] = []
|
||||
@@ -858,8 +876,9 @@ def _handle_function_calls_streaming_response(
|
||||
tools = chat_options.tools
|
||||
|
||||
if function_calls and tools:
|
||||
# Extract function middleware pipeline from kwargs if available
|
||||
middleware_pipeline = kwargs.get("_function_middleware_pipeline")
|
||||
# Use the stored middleware pipeline instead of extracting from kwargs
|
||||
# because kwargs may have been modified by the underlying function
|
||||
middleware_pipeline = stored_middleware_pipeline
|
||||
function_results = await execute_function_calls(
|
||||
custom_args=kwargs,
|
||||
attempt_idx=attempt_idx,
|
||||
|
||||
@@ -13,16 +13,18 @@ from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
|
||||
from pydantic import SecretStr, ValidationError
|
||||
from pydantic.networks import AnyUrl
|
||||
|
||||
from .._tools import use_function_invocation
|
||||
from .._types import (
|
||||
from agent_framework import (
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
CitationAnnotation,
|
||||
TextContent,
|
||||
use_chat_middleware,
|
||||
use_function_invocation,
|
||||
)
|
||||
from ..exceptions import ServiceInitializationError
|
||||
from ..observability import use_observability
|
||||
from ..openai._chat_client import OpenAIBaseChatClient
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from agent_framework.observability import use_observability
|
||||
from agent_framework.openai._chat_client import OpenAIBaseChatClient
|
||||
|
||||
from ._shared import (
|
||||
AzureOpenAIConfigMixin,
|
||||
AzureOpenAISettings,
|
||||
@@ -41,6 +43,7 @@ TAzureOpenAIChatClient = TypeVar("TAzureOpenAIChatClient", bound="AzureOpenAICha
|
||||
|
||||
@use_function_invocation
|
||||
@use_observability
|
||||
@use_chat_middleware
|
||||
class AzureOpenAIChatClient(AzureOpenAIConfigMixin, OpenAIBaseChatClient):
|
||||
"""Azure OpenAI Chat completion class."""
|
||||
|
||||
|
||||
@@ -9,10 +9,11 @@ from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI
|
||||
from pydantic import SecretStr, ValidationError
|
||||
from pydantic.networks import AnyUrl
|
||||
|
||||
from .._tools import use_function_invocation
|
||||
from ..exceptions import ServiceInitializationError
|
||||
from ..observability import use_observability
|
||||
from ..openai._responses_client import OpenAIBaseResponsesClient
|
||||
from agent_framework import use_chat_middleware, use_function_invocation
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from agent_framework.observability import use_observability
|
||||
from agent_framework.openai._responses_client import OpenAIBaseResponsesClient
|
||||
|
||||
from ._shared import (
|
||||
AzureOpenAIConfigMixin,
|
||||
AzureOpenAISettings,
|
||||
@@ -23,6 +24,7 @@ TAzureOpenAIResponsesClient = TypeVar("TAzureOpenAIResponsesClient", bound="Azur
|
||||
|
||||
@use_observability
|
||||
@use_function_invocation
|
||||
@use_chat_middleware
|
||||
class AzureOpenAIResponsesClient(AzureOpenAIConfigMixin, OpenAIBaseResponsesClient):
|
||||
"""Azure Responses completion class."""
|
||||
|
||||
|
||||
@@ -128,3 +128,9 @@ class AdditionItemMismatch(AgentFrameworkException):
|
||||
"""An error occurred while adding two types."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MiddlewareException(AgentFrameworkException):
|
||||
"""An error occurred during middleware execution."""
|
||||
|
||||
pass
|
||||
|
||||
@@ -21,6 +21,7 @@ from openai.types.beta.threads.runs import RunStep
|
||||
from pydantic import Field, PrivateAttr, SecretStr, ValidationError
|
||||
|
||||
from .._clients import BaseChatClient
|
||||
from .._middleware import use_chat_middleware
|
||||
from .._tools import AIFunction, HostedCodeInterpreterTool, HostedFileSearchTool, use_function_invocation
|
||||
from .._types import (
|
||||
ChatMessage,
|
||||
@@ -52,6 +53,7 @@ __all__ = ["OpenAIAssistantsClient"]
|
||||
|
||||
@use_function_invocation
|
||||
@use_observability
|
||||
@use_chat_middleware
|
||||
class OpenAIAssistantsClient(OpenAIConfigMixin, BaseChatClient):
|
||||
"""OpenAI Assistants client."""
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from pydantic import BaseModel, SecretStr, ValidationError
|
||||
|
||||
from .._clients import BaseChatClient
|
||||
from .._logging import get_logger
|
||||
from .._middleware import use_chat_middleware
|
||||
from .._tools import AIFunction, HostedWebSearchTool, ToolProtocol, use_function_invocation
|
||||
from .._types import (
|
||||
ChatMessage,
|
||||
@@ -452,6 +453,7 @@ TOpenAIChatClient = TypeVar("TOpenAIChatClient", bound="OpenAIChatClient")
|
||||
|
||||
@use_function_invocation
|
||||
@use_observability
|
||||
@use_chat_middleware
|
||||
class OpenAIChatClient(OpenAIConfigMixin, OpenAIBaseChatClient):
|
||||
"""OpenAI Chat completion class."""
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from pydantic import BaseModel, SecretStr, ValidationError
|
||||
|
||||
from .._clients import BaseChatClient
|
||||
from .._logging import get_logger
|
||||
from .._middleware import use_chat_middleware
|
||||
from .._tools import (
|
||||
AIFunction,
|
||||
HostedCodeInterpreterTool,
|
||||
@@ -933,6 +934,7 @@ TOpenAIResponsesClient = TypeVar("TOpenAIResponsesClient", bound="OpenAIResponse
|
||||
|
||||
@use_function_invocation
|
||||
@use_observability
|
||||
@use_chat_middleware
|
||||
class OpenAIResponsesClient(OpenAIConfigMixin, OpenAIBaseResponsesClient):
|
||||
"""OpenAI Responses client class."""
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ from agent_framework import (
|
||||
TextContent,
|
||||
ToolProtocol,
|
||||
ai_function,
|
||||
use_chat_middleware,
|
||||
use_function_invocation,
|
||||
)
|
||||
|
||||
@@ -111,11 +112,13 @@ class MockChatClient:
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="another update")], role="assistant")
|
||||
|
||||
|
||||
@use_chat_middleware
|
||||
class MockBaseChatClient(BaseChatClient):
|
||||
"""Mock implementation of the BaseChatClient."""
|
||||
|
||||
run_responses: list[ChatResponse] = Field(default_factory=list)
|
||||
streaming_responses: list[list[ChatResponseUpdate]] = Field(default_factory=list)
|
||||
call_count: int = Field(default=0)
|
||||
|
||||
@override
|
||||
async def _inner_get_response(
|
||||
@@ -136,6 +139,7 @@ class MockBaseChatClient(BaseChatClient):
|
||||
The chat response contents representing the response(s).
|
||||
"""
|
||||
logger.debug(f"Running base chat client inner, with: {messages=}, {chat_options=}, {kwargs=}")
|
||||
self.call_count += 1
|
||||
if not self.run_responses:
|
||||
return ChatResponse(messages=ChatMessage(role="assistant", text=f"test response - {messages[0].text}"))
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@ from agent_framework import (
|
||||
AgentRunResponse,
|
||||
AgentRunResponseUpdate,
|
||||
ChatMessage,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
Role,
|
||||
TextContent,
|
||||
)
|
||||
@@ -19,11 +21,15 @@ from agent_framework._middleware import (
|
||||
AgentMiddleware,
|
||||
AgentMiddlewarePipeline,
|
||||
AgentRunContext,
|
||||
ChatContext,
|
||||
ChatMiddleware,
|
||||
ChatMiddlewarePipeline,
|
||||
FunctionInvocationContext,
|
||||
FunctionMiddleware,
|
||||
FunctionMiddlewarePipeline,
|
||||
)
|
||||
from agent_framework._tools import AIFunction
|
||||
from agent_framework._types import ChatOptions
|
||||
|
||||
|
||||
class TestAgentRunContext:
|
||||
@@ -74,6 +80,46 @@ class TestFunctionInvocationContext:
|
||||
assert context.metadata == metadata
|
||||
|
||||
|
||||
class TestChatContext:
|
||||
"""Test cases for ChatContext."""
|
||||
|
||||
def test_init_with_defaults(self, mock_chat_client: Any) -> None:
|
||||
"""Test ChatContext initialization with default values."""
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
chat_options = ChatOptions()
|
||||
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
|
||||
|
||||
assert context.chat_client is mock_chat_client
|
||||
assert context.messages == messages
|
||||
assert context.chat_options is chat_options
|
||||
assert context.is_streaming is False
|
||||
assert context.metadata == {}
|
||||
assert context.result is None
|
||||
assert context.terminate is False
|
||||
|
||||
def test_init_with_custom_values(self, mock_chat_client: Any) -> None:
|
||||
"""Test ChatContext initialization with custom values."""
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
chat_options = ChatOptions(temperature=0.5)
|
||||
metadata = {"key": "value"}
|
||||
|
||||
context = ChatContext(
|
||||
chat_client=mock_chat_client,
|
||||
messages=messages,
|
||||
chat_options=chat_options,
|
||||
is_streaming=True,
|
||||
metadata=metadata,
|
||||
terminate=True,
|
||||
)
|
||||
|
||||
assert context.chat_client is mock_chat_client
|
||||
assert context.messages == messages
|
||||
assert context.chat_options is chat_options
|
||||
assert context.is_streaming is True
|
||||
assert context.metadata == metadata
|
||||
assert context.terminate is True
|
||||
|
||||
|
||||
class TestAgentMiddlewarePipeline:
|
||||
"""Test cases for AgentMiddlewarePipeline."""
|
||||
|
||||
@@ -410,6 +456,233 @@ class TestFunctionMiddlewarePipeline:
|
||||
assert execution_order == ["test_before", "handler", "test_after"]
|
||||
|
||||
|
||||
class TestChatMiddlewarePipeline:
|
||||
"""Test cases for ChatMiddlewarePipeline."""
|
||||
|
||||
class PreNextTerminateChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
context.terminate = True
|
||||
await next(context)
|
||||
|
||||
class PostNextTerminateChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
await next(context)
|
||||
context.terminate = True
|
||||
|
||||
def test_init_empty(self) -> None:
|
||||
"""Test ChatMiddlewarePipeline initialization with no middlewares."""
|
||||
pipeline = ChatMiddlewarePipeline()
|
||||
assert not pipeline.has_middlewares
|
||||
|
||||
def test_init_with_class_middleware(self) -> None:
|
||||
"""Test ChatMiddlewarePipeline initialization with class-based middleware."""
|
||||
middleware = TestChatMiddleware()
|
||||
pipeline = ChatMiddlewarePipeline([middleware])
|
||||
assert pipeline.has_middlewares
|
||||
|
||||
def test_init_with_function_middleware(self) -> None:
|
||||
"""Test ChatMiddlewarePipeline initialization with function-based middleware."""
|
||||
|
||||
async def test_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
await next(context)
|
||||
|
||||
pipeline = ChatMiddlewarePipeline([test_middleware])
|
||||
assert pipeline.has_middlewares
|
||||
|
||||
async def test_execute_no_middleware(self, mock_chat_client: Any) -> None:
|
||||
"""Test pipeline execution with no middleware."""
|
||||
pipeline = ChatMiddlewarePipeline()
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
chat_options = ChatOptions()
|
||||
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
|
||||
|
||||
expected_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
|
||||
|
||||
async def final_handler(ctx: ChatContext) -> ChatResponse:
|
||||
return expected_response
|
||||
|
||||
result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
|
||||
assert result == expected_response
|
||||
|
||||
async def test_execute_with_middleware(self, mock_chat_client: Any) -> None:
|
||||
"""Test pipeline execution with middleware."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class OrderTrackingChatMiddleware(ChatMiddleware):
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append(f"{self.name}_before")
|
||||
await next(context)
|
||||
execution_order.append(f"{self.name}_after")
|
||||
|
||||
middleware = OrderTrackingChatMiddleware("test")
|
||||
pipeline = ChatMiddlewarePipeline([middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
chat_options = ChatOptions()
|
||||
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
|
||||
|
||||
expected_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
|
||||
|
||||
async def final_handler(ctx: ChatContext) -> ChatResponse:
|
||||
execution_order.append("handler")
|
||||
return expected_response
|
||||
|
||||
result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
|
||||
assert result == expected_response
|
||||
assert execution_order == ["test_before", "handler", "test_after"]
|
||||
|
||||
async def test_execute_stream_no_middleware(self, mock_chat_client: Any) -> None:
|
||||
"""Test pipeline streaming execution with no middleware."""
|
||||
pipeline = ChatMiddlewarePipeline()
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
chat_options = ChatOptions()
|
||||
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
|
||||
|
||||
async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="chunk1")])
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="chunk2")])
|
||||
|
||||
updates: list[ChatResponseUpdate] = []
|
||||
async for update in pipeline.execute_stream(mock_chat_client, messages, chat_options, context, final_handler):
|
||||
updates.append(update)
|
||||
|
||||
assert len(updates) == 2
|
||||
assert updates[0].text == "chunk1"
|
||||
assert updates[1].text == "chunk2"
|
||||
|
||||
async def test_execute_stream_with_middleware(self, mock_chat_client: Any) -> None:
|
||||
"""Test pipeline streaming execution with middleware."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class StreamOrderTrackingChatMiddleware(ChatMiddleware):
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append(f"{self.name}_before")
|
||||
await next(context)
|
||||
execution_order.append(f"{self.name}_after")
|
||||
|
||||
middleware = StreamOrderTrackingChatMiddleware("test")
|
||||
pipeline = ChatMiddlewarePipeline([middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
chat_options = ChatOptions()
|
||||
context = ChatContext(
|
||||
chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True
|
||||
)
|
||||
|
||||
async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
|
||||
execution_order.append("handler_start")
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="chunk1")])
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="chunk2")])
|
||||
execution_order.append("handler_end")
|
||||
|
||||
updates: list[ChatResponseUpdate] = []
|
||||
async for update in pipeline.execute_stream(mock_chat_client, messages, chat_options, context, final_handler):
|
||||
updates.append(update)
|
||||
|
||||
assert len(updates) == 2
|
||||
assert updates[0].text == "chunk1"
|
||||
assert updates[1].text == "chunk2"
|
||||
assert execution_order == ["test_before", "test_after", "handler_start", "handler_end"]
|
||||
|
||||
async def test_execute_with_pre_next_termination(self, mock_chat_client: Any) -> None:
|
||||
"""Test pipeline execution with termination before next()."""
|
||||
middleware = self.PreNextTerminateChatMiddleware()
|
||||
pipeline = ChatMiddlewarePipeline([middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
chat_options = ChatOptions()
|
||||
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
|
||||
execution_order: list[str] = []
|
||||
|
||||
async def final_handler(ctx: ChatContext) -> ChatResponse:
|
||||
# Handler should not be executed when terminated before next()
|
||||
execution_order.append("handler")
|
||||
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
|
||||
|
||||
response = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
|
||||
assert response is None
|
||||
assert context.terminate
|
||||
# Handler should not be called when terminated before next()
|
||||
assert execution_order == []
|
||||
|
||||
async def test_execute_with_post_next_termination(self, mock_chat_client: Any) -> None:
|
||||
"""Test pipeline execution with termination after next()."""
|
||||
middleware = self.PostNextTerminateChatMiddleware()
|
||||
pipeline = ChatMiddlewarePipeline([middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
chat_options = ChatOptions()
|
||||
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
|
||||
execution_order: list[str] = []
|
||||
|
||||
async def final_handler(ctx: ChatContext) -> ChatResponse:
|
||||
execution_order.append("handler")
|
||||
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
|
||||
|
||||
response = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
|
||||
assert response is not None
|
||||
assert len(response.messages) == 1
|
||||
assert response.messages[0].text == "response"
|
||||
assert context.terminate
|
||||
assert execution_order == ["handler"]
|
||||
|
||||
async def test_execute_stream_with_pre_next_termination(self, mock_chat_client: Any) -> None:
|
||||
"""Test pipeline streaming execution with termination before next()."""
|
||||
middleware = self.PreNextTerminateChatMiddleware()
|
||||
pipeline = ChatMiddlewarePipeline([middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
chat_options = ChatOptions()
|
||||
context = ChatContext(
|
||||
chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True
|
||||
)
|
||||
execution_order: list[str] = []
|
||||
|
||||
async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
|
||||
# Handler should not be executed when terminated before next()
|
||||
execution_order.append("handler_start")
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="chunk1")])
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="chunk2")])
|
||||
execution_order.append("handler_end")
|
||||
|
||||
updates: list[ChatResponseUpdate] = []
|
||||
async for update in pipeline.execute_stream(mock_chat_client, messages, chat_options, context, final_handler):
|
||||
updates.append(update)
|
||||
|
||||
assert context.terminate
|
||||
# Handler should not be called when terminated before next()
|
||||
assert execution_order == []
|
||||
assert not updates
|
||||
|
||||
async def test_execute_stream_with_post_next_termination(self, mock_chat_client: Any) -> None:
|
||||
"""Test pipeline streaming execution with termination after next()."""
|
||||
middleware = self.PostNextTerminateChatMiddleware()
|
||||
pipeline = ChatMiddlewarePipeline([middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
chat_options = ChatOptions()
|
||||
context = ChatContext(
|
||||
chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True
|
||||
)
|
||||
execution_order: list[str] = []
|
||||
|
||||
async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
|
||||
execution_order.append("handler_start")
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="chunk1")])
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="chunk2")])
|
||||
execution_order.append("handler_end")
|
||||
|
||||
updates: list[ChatResponseUpdate] = []
|
||||
async for update in pipeline.execute_stream(mock_chat_client, messages, chat_options, context, final_handler):
|
||||
updates.append(update)
|
||||
|
||||
assert len(updates) == 2
|
||||
assert updates[0].text == "chunk1"
|
||||
assert updates[1].text == "chunk2"
|
||||
assert context.terminate
|
||||
assert execution_order == ["handler_start", "handler_end"]
|
||||
|
||||
|
||||
class TestClassBasedMiddleware:
|
||||
"""Test cases for class-based middleware implementations."""
|
||||
|
||||
@@ -601,6 +874,37 @@ class TestMixedMiddleware:
|
||||
assert result == "result"
|
||||
assert execution_order == ["class_before", "function_before", "handler", "function_after", "class_after"]
|
||||
|
||||
async def test_mixed_chat_middleware(self, mock_chat_client: Any) -> None:
|
||||
"""Test mixed class and function-based chat middleware."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class ClassChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("class_before")
|
||||
await next(context)
|
||||
execution_order.append("class_after")
|
||||
|
||||
async def function_chat_middleware(
|
||||
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("function_before")
|
||||
await next(context)
|
||||
execution_order.append("function_after")
|
||||
|
||||
pipeline = ChatMiddlewarePipeline([ClassChatMiddleware(), function_chat_middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
chat_options = ChatOptions()
|
||||
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
|
||||
|
||||
async def final_handler(ctx: ChatContext) -> ChatResponse:
|
||||
execution_order.append("handler")
|
||||
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
|
||||
|
||||
result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
|
||||
|
||||
assert result is not None
|
||||
assert execution_order == ["class_before", "function_before", "handler", "function_after", "class_after"]
|
||||
|
||||
|
||||
class TestMultipleMiddlewareOrdering:
|
||||
"""Test cases for multiple middleware execution order."""
|
||||
@@ -695,6 +999,52 @@ class TestMultipleMiddlewareOrdering:
|
||||
expected_order = ["first_before", "second_before", "handler", "second_after", "first_after"]
|
||||
assert execution_order == expected_order
|
||||
|
||||
async def test_chat_middleware_execution_order(self, mock_chat_client: Any) -> None:
|
||||
"""Test that multiple chat middlewares execute in registration order."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class FirstChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("first_before")
|
||||
await next(context)
|
||||
execution_order.append("first_after")
|
||||
|
||||
class SecondChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("second_before")
|
||||
await next(context)
|
||||
execution_order.append("second_after")
|
||||
|
||||
class ThirdChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("third_before")
|
||||
await next(context)
|
||||
execution_order.append("third_after")
|
||||
|
||||
middlewares = [FirstChatMiddleware(), SecondChatMiddleware(), ThirdChatMiddleware()]
|
||||
pipeline = ChatMiddlewarePipeline(middlewares) # type: ignore
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
chat_options = ChatOptions()
|
||||
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
|
||||
|
||||
async def final_handler(ctx: ChatContext) -> ChatResponse:
|
||||
execution_order.append("handler")
|
||||
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
|
||||
|
||||
result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
|
||||
|
||||
assert result is not None
|
||||
expected_order = [
|
||||
"first_before",
|
||||
"second_before",
|
||||
"third_before",
|
||||
"handler",
|
||||
"third_after",
|
||||
"second_after",
|
||||
"first_after",
|
||||
]
|
||||
assert execution_order == expected_order
|
||||
|
||||
|
||||
class TestContextContentValidation:
|
||||
"""Test cases for validating middleware context content."""
|
||||
@@ -776,6 +1126,49 @@ class TestContextContentValidation:
|
||||
result = await pipeline.execute(mock_function, arguments, context, final_handler)
|
||||
assert result == "result"
|
||||
|
||||
async def test_chat_context_validation(self, mock_chat_client: Any) -> None:
|
||||
"""Test that chat context contains expected data."""
|
||||
|
||||
class ChatContextValidationMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
# Verify context has all expected attributes
|
||||
assert hasattr(context, "chat_client")
|
||||
assert hasattr(context, "messages")
|
||||
assert hasattr(context, "chat_options")
|
||||
assert hasattr(context, "is_streaming")
|
||||
assert hasattr(context, "metadata")
|
||||
assert hasattr(context, "result")
|
||||
assert hasattr(context, "terminate")
|
||||
|
||||
# Verify context content
|
||||
assert context.chat_client is mock_chat_client
|
||||
assert len(context.messages) == 1
|
||||
assert context.messages[0].role == Role.USER
|
||||
assert context.messages[0].text == "test"
|
||||
assert context.is_streaming is False
|
||||
assert isinstance(context.metadata, dict)
|
||||
assert isinstance(context.chat_options, ChatOptions)
|
||||
assert context.chat_options.temperature == 0.5
|
||||
|
||||
# Add custom metadata
|
||||
context.metadata["validated"] = True
|
||||
|
||||
await next(context)
|
||||
|
||||
middleware = ChatContextValidationMiddleware()
|
||||
pipeline = ChatMiddlewarePipeline([middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
chat_options = ChatOptions(temperature=0.5)
|
||||
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
|
||||
|
||||
async def final_handler(ctx: ChatContext) -> ChatResponse:
|
||||
# Verify metadata was set by middleware
|
||||
assert ctx.metadata.get("validated") is True
|
||||
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
|
||||
|
||||
result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
|
||||
assert result is not None
|
||||
|
||||
|
||||
class TestStreamingScenarios:
|
||||
"""Test cases for streaming and non-streaming scenarios."""
|
||||
@@ -857,6 +1250,89 @@ class TestStreamingScenarios:
|
||||
"stream_end",
|
||||
]
|
||||
|
||||
async def test_chat_streaming_flag_validation(self, mock_chat_client: Any) -> None:
|
||||
"""Test that is_streaming flag is correctly set for chat streaming calls."""
|
||||
streaming_flags: list[bool] = []
|
||||
|
||||
class ChatStreamingFlagMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
streaming_flags.append(context.is_streaming)
|
||||
await next(context)
|
||||
|
||||
middleware = ChatStreamingFlagMiddleware()
|
||||
pipeline = ChatMiddlewarePipeline([middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
chat_options = ChatOptions()
|
||||
|
||||
# Test non-streaming
|
||||
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
|
||||
|
||||
async def final_handler(ctx: ChatContext) -> ChatResponse:
|
||||
streaming_flags.append(ctx.is_streaming)
|
||||
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
|
||||
|
||||
await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
|
||||
|
||||
# Test streaming
|
||||
context_stream = ChatContext(
|
||||
chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True
|
||||
)
|
||||
|
||||
async def final_stream_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
|
||||
streaming_flags.append(ctx.is_streaming)
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="chunk")])
|
||||
|
||||
updates: list[ChatResponseUpdate] = []
|
||||
async for update in pipeline.execute_stream(
|
||||
mock_chat_client, messages, chat_options, context_stream, final_stream_handler
|
||||
):
|
||||
updates.append(update)
|
||||
|
||||
# Verify flags: [non-streaming middleware, non-streaming handler, streaming middleware, streaming handler]
|
||||
assert streaming_flags == [False, False, True, True]
|
||||
|
||||
async def test_chat_streaming_middleware_behavior(self, mock_chat_client: Any) -> None:
|
||||
"""Test chat middleware behavior with streaming responses."""
|
||||
chunks_processed: list[str] = []
|
||||
|
||||
class ChatStreamProcessingMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
chunks_processed.append("before_stream")
|
||||
await next(context)
|
||||
chunks_processed.append("after_stream")
|
||||
|
||||
middleware = ChatStreamProcessingMiddleware()
|
||||
pipeline = ChatMiddlewarePipeline([middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
chat_options = ChatOptions()
|
||||
context = ChatContext(
|
||||
chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True
|
||||
)
|
||||
|
||||
async def final_stream_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
|
||||
chunks_processed.append("stream_start")
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="chunk1")])
|
||||
chunks_processed.append("chunk1_yielded")
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="chunk2")])
|
||||
chunks_processed.append("chunk2_yielded")
|
||||
chunks_processed.append("stream_end")
|
||||
|
||||
updates: list[str] = []
|
||||
async for update in pipeline.execute_stream(
|
||||
mock_chat_client, messages, chat_options, context, final_stream_handler
|
||||
):
|
||||
updates.append(update.text)
|
||||
|
||||
assert updates == ["chunk1", "chunk2"]
|
||||
assert chunks_processed == [
|
||||
"before_stream",
|
||||
"after_stream",
|
||||
"stream_start",
|
||||
"chunk1_yielded",
|
||||
"chunk2_yielded",
|
||||
"stream_end",
|
||||
]
|
||||
|
||||
|
||||
# region Helper classes and fixtures
|
||||
|
||||
@@ -883,6 +1359,13 @@ class TestFunctionMiddleware(FunctionMiddleware):
|
||||
await next(context)
|
||||
|
||||
|
||||
class TestChatMiddleware(ChatMiddleware):
|
||||
"""Test implementation of ChatMiddleware."""
|
||||
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
await next(context)
|
||||
|
||||
|
||||
class MockFunctionArgs(BaseModel):
|
||||
"""Test arguments for function middleware tests."""
|
||||
|
||||
@@ -1027,6 +1510,100 @@ class TestMiddlewareExecutionControl:
|
||||
assert result.messages == [] # Empty response
|
||||
assert not handler_called
|
||||
|
||||
async def test_chat_middleware_no_next_no_execution(self, mock_chat_client: Any) -> None:
|
||||
"""Test that when chat middleware doesn't call next(), no execution happens."""
|
||||
|
||||
class NoNextChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
# Don't call next() - this should prevent any execution
|
||||
pass
|
||||
|
||||
middleware = NoNextChatMiddleware()
|
||||
pipeline = ChatMiddlewarePipeline([middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
chat_options = ChatOptions()
|
||||
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
|
||||
|
||||
handler_called = False
|
||||
|
||||
async def final_handler(ctx: ChatContext) -> ChatResponse:
|
||||
nonlocal handler_called
|
||||
handler_called = True
|
||||
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="should not execute")])
|
||||
|
||||
result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
|
||||
|
||||
# Verify no execution happened
|
||||
assert result is None
|
||||
assert not handler_called
|
||||
assert context.result is None
|
||||
|
||||
async def test_chat_middleware_no_next_no_streaming_execution(self, mock_chat_client: Any) -> None:
|
||||
"""Test that when chat middleware doesn't call next(), no streaming execution happens."""
|
||||
|
||||
class NoNextStreamingChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
# Don't call next() - this should prevent any execution
|
||||
pass
|
||||
|
||||
middleware = NoNextStreamingChatMiddleware()
|
||||
pipeline = ChatMiddlewarePipeline([middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
chat_options = ChatOptions()
|
||||
context = ChatContext(
|
||||
chat_client=mock_chat_client, messages=messages, chat_options=chat_options, is_streaming=True
|
||||
)
|
||||
|
||||
handler_called = False
|
||||
|
||||
async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]:
|
||||
nonlocal handler_called
|
||||
handler_called = True
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="should not execute")])
|
||||
|
||||
# When middleware doesn't call next(), streaming should yield no updates
|
||||
updates: list[ChatResponseUpdate] = []
|
||||
async for update in pipeline.execute_stream(mock_chat_client, messages, chat_options, context, final_handler):
|
||||
updates.append(update)
|
||||
|
||||
# Verify no execution happened and no updates were yielded
|
||||
assert len(updates) == 0
|
||||
assert not handler_called
|
||||
assert context.result is None
|
||||
|
||||
async def test_multiple_chat_middlewares_early_stop(self, mock_chat_client: Any) -> None:
|
||||
"""Test that when first chat middleware doesn't call next(), subsequent middlewares are not called."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class FirstChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("first")
|
||||
# Don't call next() - this should stop the pipeline
|
||||
|
||||
class SecondChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("second")
|
||||
await next(context)
|
||||
|
||||
pipeline = ChatMiddlewarePipeline([FirstChatMiddleware(), SecondChatMiddleware()])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
chat_options = ChatOptions()
|
||||
context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options)
|
||||
|
||||
handler_called = False
|
||||
|
||||
async def final_handler(ctx: ChatContext) -> ChatResponse:
|
||||
nonlocal handler_called
|
||||
handler_called = True
|
||||
return ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="should not execute")])
|
||||
|
||||
result = await pipeline.execute(mock_chat_client, messages, chat_options, context, final_handler)
|
||||
|
||||
# Verify only first middleware was called and no result returned
|
||||
assert execution_order == ["first"]
|
||||
assert result is None
|
||||
assert not handler_called
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent() -> AgentProtocol:
|
||||
@@ -1042,3 +1619,13 @@ def mock_function() -> AIFunction[Any, Any]:
|
||||
function = MagicMock(spec=AIFunction[Any, Any])
|
||||
function.name = "test_function"
|
||||
return function
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_chat_client() -> Any:
|
||||
"""Mock chat client for testing."""
|
||||
from agent_framework._clients import ChatClientProtocol
|
||||
|
||||
client = MagicMock(spec=ChatClientProtocol)
|
||||
client.service_url = MagicMock(return_value="mock://test")
|
||||
return client
|
||||
|
||||
@@ -8,7 +8,9 @@ import pytest
|
||||
from agent_framework import (
|
||||
AgentRunResponseUpdate,
|
||||
ChatAgent,
|
||||
ChatContext,
|
||||
ChatMessage,
|
||||
ChatMiddleware,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
FunctionCallContent,
|
||||
@@ -16,7 +18,9 @@ from agent_framework import (
|
||||
Role,
|
||||
TextContent,
|
||||
agent_middleware,
|
||||
chat_middleware,
|
||||
function_middleware,
|
||||
use_function_invocation,
|
||||
)
|
||||
from agent_framework._middleware import (
|
||||
AgentMiddleware,
|
||||
@@ -25,8 +29,9 @@ from agent_framework._middleware import (
|
||||
FunctionMiddleware,
|
||||
MiddlewareType,
|
||||
)
|
||||
from agent_framework.exceptions import MiddlewareException
|
||||
|
||||
from .conftest import MockChatClient
|
||||
from .conftest import MockBaseChatClient, MockChatClient
|
||||
|
||||
# region ChatAgent Tests
|
||||
|
||||
@@ -713,6 +718,101 @@ class TestChatAgentFunctionMiddlewareWithTools:
|
||||
assert function_calls[0].name == "sample_tool_function"
|
||||
assert function_results[0].call_id == function_calls[0].call_id
|
||||
|
||||
async def test_function_middleware_can_access_and_override_custom_kwargs(
|
||||
self, chat_client: "MockChatClient"
|
||||
) -> None:
|
||||
"""Test that function middleware can access and override custom parameters like temperature."""
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
modified_kwargs: dict[str, Any] = {}
|
||||
middleware_called = False
|
||||
|
||||
@function_middleware
|
||||
async def kwargs_middleware(
|
||||
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
) -> None:
|
||||
nonlocal middleware_called
|
||||
middleware_called = True
|
||||
|
||||
# Capture the original kwargs
|
||||
captured_kwargs["has_chat_options"] = "chat_options" in context.kwargs
|
||||
captured_kwargs["has_custom_param"] = "custom_param" in context.kwargs
|
||||
captured_kwargs["custom_param"] = context.kwargs.get("custom_param")
|
||||
|
||||
# Capture original chat_options values if present
|
||||
if "chat_options" in context.kwargs:
|
||||
chat_options = context.kwargs["chat_options"]
|
||||
captured_kwargs["original_temperature"] = getattr(chat_options, "temperature", None)
|
||||
captured_kwargs["original_max_tokens"] = getattr(chat_options, "max_tokens", None)
|
||||
|
||||
# Modify some kwargs
|
||||
context.kwargs["temperature"] = 0.9
|
||||
context.kwargs["max_tokens"] = 500
|
||||
context.kwargs["new_param"] = "added_by_middleware"
|
||||
|
||||
# Also modify chat_options if present
|
||||
if "chat_options" in context.kwargs:
|
||||
context.kwargs["chat_options"].temperature = 0.9
|
||||
context.kwargs["chat_options"].max_tokens = 500
|
||||
|
||||
# Store modified kwargs for verification
|
||||
modified_kwargs["temperature"] = context.kwargs.get("temperature")
|
||||
modified_kwargs["max_tokens"] = context.kwargs.get("max_tokens")
|
||||
modified_kwargs["new_param"] = context.kwargs.get("new_param")
|
||||
modified_kwargs["custom_param"] = context.kwargs.get("custom_param")
|
||||
|
||||
# Capture modified chat_options values if present
|
||||
if "chat_options" in context.kwargs:
|
||||
chat_options = context.kwargs["chat_options"]
|
||||
modified_kwargs["chat_options_temperature"] = getattr(chat_options, "temperature", None)
|
||||
modified_kwargs["chat_options_max_tokens"] = getattr(chat_options, "max_tokens", None)
|
||||
|
||||
await next(context)
|
||||
|
||||
chat_client.responses = [
|
||||
ChatResponse(
|
||||
messages=[
|
||||
ChatMessage(
|
||||
role=Role.ASSISTANT,
|
||||
contents=[
|
||||
FunctionCallContent(
|
||||
call_id="test_call", name="sample_tool_function", arguments={"location": "Seattle"}
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
),
|
||||
ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent("Function completed")])]),
|
||||
]
|
||||
|
||||
# Create ChatAgent with function middleware
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[kwargs_middleware], tools=[sample_tool_function])
|
||||
|
||||
# Execute the agent with custom parameters
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages, temperature=0.7, max_tokens=100, custom_param="test_value")
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
|
||||
# First check if middleware was called at all
|
||||
assert middleware_called, "Function middleware was not called"
|
||||
|
||||
# Verify middleware captured the original kwargs
|
||||
assert captured_kwargs["has_chat_options"] is True
|
||||
assert captured_kwargs["has_custom_param"] is True
|
||||
assert captured_kwargs["custom_param"] == "test_value"
|
||||
assert captured_kwargs["original_temperature"] == 0.7
|
||||
assert captured_kwargs["original_max_tokens"] == 100
|
||||
|
||||
# Verify middleware could modify the kwargs
|
||||
assert modified_kwargs["temperature"] == 0.9
|
||||
assert modified_kwargs["max_tokens"] == 500
|
||||
assert modified_kwargs["new_param"] == "added_by_middleware"
|
||||
assert modified_kwargs["custom_param"] == "test_value"
|
||||
assert modified_kwargs["chat_options_temperature"] == 0.9
|
||||
assert modified_kwargs["chat_options_max_tokens"] == 500
|
||||
|
||||
|
||||
class TestMiddlewareDynamicRebuild:
|
||||
"""Test cases for dynamic middleware pipeline rebuilding with ChatAgent."""
|
||||
@@ -1187,8 +1287,8 @@ class TestMiddlewareDecoratorLogic:
|
||||
"""Both decorator and parameter type specified but don't match."""
|
||||
|
||||
# This will cause a type error at decoration time, so we need to test differently
|
||||
# Should raise ValueError due to mismatch during agent creation
|
||||
with pytest.raises(ValueError, match="Middleware type mismatch"):
|
||||
# Should raise MiddlewareException due to mismatch during agent creation
|
||||
with pytest.raises(MiddlewareException, match="Middleware type mismatch"):
|
||||
|
||||
@agent_middleware # type: ignore[arg-type]
|
||||
async def mismatched_middleware(
|
||||
@@ -1304,8 +1404,8 @@ class TestMiddlewareDecoratorLogic:
|
||||
async def no_info_middleware(context: Any, next: Any) -> None: # No decorator, no type
|
||||
await next(context)
|
||||
|
||||
# Should raise ValueError
|
||||
with pytest.raises(ValueError, match="Cannot determine middleware type"):
|
||||
# Should raise MiddlewareException
|
||||
with pytest.raises(MiddlewareException, match="Cannot determine middleware type"):
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[no_info_middleware])
|
||||
await agent.run([ChatMessage(role=Role.USER, text="test")])
|
||||
|
||||
@@ -1313,8 +1413,8 @@ class TestMiddlewareDecoratorLogic:
|
||||
"""Test that middleware with insufficient parameters raises an error."""
|
||||
from agent_framework import ChatAgent, agent_middleware
|
||||
|
||||
# Should raise ValueError about insufficient parameters
|
||||
with pytest.raises(ValueError, match="must have at least 2 parameters"):
|
||||
# Should raise MiddlewareException about insufficient parameters
|
||||
with pytest.raises(MiddlewareException, match="must have at least 2 parameters"):
|
||||
|
||||
@agent_middleware # type: ignore[arg-type]
|
||||
async def insufficient_params_middleware(context: Any) -> None: # Missing 'next' parameter
|
||||
@@ -1340,3 +1440,359 @@ class TestMiddlewareDecoratorLogic:
|
||||
|
||||
assert hasattr(test_function_middleware, "_middleware_type")
|
||||
assert test_function_middleware._middleware_type == MiddlewareType.FUNCTION # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class TestChatAgentChatMiddleware:
|
||||
"""Test cases for chat middleware integration with ChatAgent."""
|
||||
|
||||
async def test_class_based_chat_middleware_with_chat_agent(self) -> None:
|
||||
"""Test class-based chat middleware with ChatAgent."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class TrackingChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("chat_middleware_before")
|
||||
await next(context)
|
||||
execution_order.append("chat_middleware_after")
|
||||
|
||||
# Create ChatAgent with chat middleware
|
||||
chat_client = MockBaseChatClient()
|
||||
middleware = TrackingChatMiddleware()
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[middleware])
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert response.messages[0].role == Role.ASSISTANT
|
||||
assert "test response" in response.messages[0].text
|
||||
assert execution_order == ["chat_middleware_before", "chat_middleware_after"]
|
||||
|
||||
async def test_function_based_chat_middleware_with_chat_agent(self) -> None:
|
||||
"""Test function-based chat middleware with ChatAgent."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
async def tracking_chat_middleware(
|
||||
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("chat_middleware_before")
|
||||
await next(context)
|
||||
execution_order.append("chat_middleware_after")
|
||||
|
||||
# Create ChatAgent with function-based chat middleware
|
||||
chat_client = MockBaseChatClient()
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[tracking_chat_middleware])
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert response.messages[0].role == Role.ASSISTANT
|
||||
assert "test response" in response.messages[0].text
|
||||
assert execution_order == ["chat_middleware_before", "chat_middleware_after"]
|
||||
|
||||
async def test_chat_middleware_can_modify_messages(self) -> None:
|
||||
"""Test that chat middleware can modify messages before sending to model."""
|
||||
|
||||
@chat_middleware
|
||||
async def message_modifier_middleware(
|
||||
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
|
||||
) -> None:
|
||||
# Modify the first message by adding a prefix
|
||||
if context.messages and len(context.messages) > 0:
|
||||
original_text = context.messages[0].text or ""
|
||||
context.messages[0] = ChatMessage(role=context.messages[0].role, text=f"MODIFIED: {original_text}")
|
||||
await next(context)
|
||||
|
||||
# Create ChatAgent with message-modifying middleware
|
||||
chat_client = MockBaseChatClient()
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[message_modifier_middleware])
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify that the message was modified (MockBaseChatClient echoes back the input)
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert "MODIFIED: test message" in response.messages[0].text
|
||||
|
||||
async def test_chat_middleware_can_override_response(self) -> None:
|
||||
"""Test that chat middleware can override the response."""
|
||||
|
||||
@chat_middleware
|
||||
async def response_override_middleware(
|
||||
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
|
||||
) -> None:
|
||||
# Override the response without calling next()
|
||||
context.result = ChatResponse(
|
||||
messages=[ChatMessage(role=Role.ASSISTANT, text="Middleware overridden response")],
|
||||
response_id="middleware-response-123",
|
||||
)
|
||||
context.terminate = True
|
||||
|
||||
# Create ChatAgent with response-overriding middleware
|
||||
chat_client = MockBaseChatClient()
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[response_override_middleware])
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify that the response was overridden
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert response.messages[0].text == "Middleware overridden response"
|
||||
assert response.response_id == "middleware-response-123"
|
||||
|
||||
async def test_multiple_chat_middleware_execution_order(self) -> None:
|
||||
"""Test that multiple chat middleware execute in the correct order."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
@chat_middleware
|
||||
async def first_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("first_before")
|
||||
await next(context)
|
||||
execution_order.append("first_after")
|
||||
|
||||
@chat_middleware
|
||||
async def second_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("second_before")
|
||||
await next(context)
|
||||
execution_order.append("second_after")
|
||||
|
||||
# Create ChatAgent with multiple chat middleware
|
||||
chat_client = MockBaseChatClient()
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[first_middleware, second_middleware])
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert execution_order == ["first_before", "second_before", "second_after", "first_after"]
|
||||
|
||||
async def test_chat_middleware_with_streaming(self) -> None:
|
||||
"""Test chat middleware with streaming responses."""
|
||||
execution_order: list[str] = []
|
||||
streaming_flags: list[bool] = []
|
||||
|
||||
class StreamingTrackingChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("streaming_chat_before")
|
||||
streaming_flags.append(context.is_streaming)
|
||||
await next(context)
|
||||
execution_order.append("streaming_chat_after")
|
||||
|
||||
# Create ChatAgent with chat middleware
|
||||
chat_client = MockBaseChatClient()
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[StreamingTrackingChatMiddleware()])
|
||||
|
||||
# Set up mock streaming responses
|
||||
chat_client.streaming_responses = [
|
||||
[
|
||||
ChatResponseUpdate(contents=[TextContent(text="Stream")], role=Role.ASSISTANT),
|
||||
ChatResponseUpdate(contents=[TextContent(text=" response")], role=Role.ASSISTANT),
|
||||
]
|
||||
]
|
||||
|
||||
# Execute streaming
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
updates: list[AgentRunResponseUpdate] = []
|
||||
async for update in agent.run_stream(messages):
|
||||
updates.append(update)
|
||||
|
||||
# Verify streaming response
|
||||
assert len(updates) >= 1 # At least some updates
|
||||
assert execution_order == ["streaming_chat_before", "streaming_chat_after"]
|
||||
|
||||
# Verify streaming flag was set (at least one True)
|
||||
assert True in streaming_flags
|
||||
|
||||
async def test_chat_middleware_termination_before_execution(self) -> None:
|
||||
"""Test that chat middleware can terminate execution before calling next()."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class PreTerminationChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("middleware_before")
|
||||
context.terminate = True
|
||||
# Set a custom response since we're terminating
|
||||
context.result = ChatResponse(
|
||||
messages=[ChatMessage(role=Role.ASSISTANT, text="Terminated by middleware")]
|
||||
)
|
||||
# We call next() but since terminate=True, execution should stop
|
||||
await next(context)
|
||||
execution_order.append("middleware_after")
|
||||
|
||||
# Create ChatAgent with terminating middleware
|
||||
chat_client = MockBaseChatClient()
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[PreTerminationChatMiddleware()])
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify response was from middleware
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert response.messages[0].text == "Terminated by middleware"
|
||||
assert execution_order == ["middleware_before", "middleware_after"]
|
||||
|
||||
async def test_chat_middleware_termination_after_execution(self) -> None:
|
||||
"""Test that chat middleware can terminate execution after calling next()."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class PostTerminationChatMiddleware(ChatMiddleware):
|
||||
async def process(self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("middleware_before")
|
||||
await next(context)
|
||||
execution_order.append("middleware_after")
|
||||
context.terminate = True
|
||||
|
||||
# Create ChatAgent with terminating middleware
|
||||
chat_client = MockBaseChatClient()
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[PostTerminationChatMiddleware()])
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify response is from actual execution
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert "test response" in response.messages[0].text
|
||||
assert execution_order == ["middleware_before", "middleware_after"]
|
||||
|
||||
async def test_combined_middleware(self) -> None:
|
||||
"""Test ChatAgent with combined middleware types."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
async def agent_middleware(
|
||||
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("agent_middleware_before")
|
||||
await next(context)
|
||||
execution_order.append("agent_middleware_after")
|
||||
|
||||
async def chat_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("chat_middleware_before")
|
||||
await next(context)
|
||||
execution_order.append("chat_middleware_after")
|
||||
|
||||
async def function_middleware(
|
||||
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("function_middleware_before")
|
||||
await next(context)
|
||||
execution_order.append("function_middleware_after")
|
||||
|
||||
# Set up mock to return a function call first, then a regular response
|
||||
function_call_response = ChatResponse(
|
||||
messages=[
|
||||
ChatMessage(
|
||||
role=Role.ASSISTANT,
|
||||
contents=[
|
||||
FunctionCallContent(
|
||||
call_id="call_456",
|
||||
name="sample_tool_function",
|
||||
arguments='{"location": "San Francisco"}',
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
final_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Final response")])
|
||||
|
||||
chat_client = use_function_invocation(MockBaseChatClient)()
|
||||
chat_client.run_responses = [function_call_response, final_response]
|
||||
|
||||
# Create ChatAgent with function middleware and tools
|
||||
agent = ChatAgent(
|
||||
chat_client=chat_client,
|
||||
middleware=[chat_middleware, function_middleware, agent_middleware],
|
||||
tools=[sample_tool_function],
|
||||
)
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="Get weather for San Francisco")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert chat_client.call_count == 2 # Two calls: one for function call, one for final response
|
||||
|
||||
# Verify function middleware was executed
|
||||
assert execution_order == [
|
||||
"agent_middleware_before",
|
||||
"chat_middleware_before",
|
||||
"chat_middleware_after",
|
||||
"function_middleware_before",
|
||||
"function_middleware_after",
|
||||
"chat_middleware_before",
|
||||
"chat_middleware_after",
|
||||
"agent_middleware_after",
|
||||
]
|
||||
|
||||
# Verify function call and result are in the response
|
||||
all_contents = [content for message in response.messages for content in message.contents]
|
||||
function_calls = [c for c in all_contents if isinstance(c, FunctionCallContent)]
|
||||
function_results = [c for c in all_contents if isinstance(c, FunctionResultContent)]
|
||||
|
||||
assert len(function_calls) == 1
|
||||
assert len(function_results) == 1
|
||||
assert function_calls[0].name == "sample_tool_function"
|
||||
assert function_results[0].call_id == function_calls[0].call_id
|
||||
|
||||
async def test_agent_middleware_can_access_and_override_custom_kwargs(self) -> None:
|
||||
"""Test that agent middleware can access and override custom parameters like temperature."""
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
modified_kwargs: dict[str, Any] = {}
|
||||
|
||||
@agent_middleware
|
||||
async def kwargs_middleware(
|
||||
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
# Capture the original kwargs
|
||||
captured_kwargs.update(context.kwargs)
|
||||
|
||||
# Modify some kwargs
|
||||
context.kwargs["temperature"] = 0.9
|
||||
context.kwargs["max_tokens"] = 500
|
||||
context.kwargs["new_param"] = "added_by_middleware"
|
||||
|
||||
# Store modified kwargs for verification
|
||||
modified_kwargs.update(context.kwargs)
|
||||
|
||||
await next(context)
|
||||
|
||||
# Create ChatAgent with agent middleware
|
||||
chat_client = MockBaseChatClient()
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[kwargs_middleware])
|
||||
|
||||
# Execute the agent with custom parameters
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages, temperature=0.7, max_tokens=100, custom_param="test_value")
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
|
||||
# Verify middleware captured the original kwargs
|
||||
assert captured_kwargs["temperature"] == 0.7
|
||||
assert captured_kwargs["max_tokens"] == 100
|
||||
assert captured_kwargs["custom_param"] == "test_value"
|
||||
|
||||
# Verify middleware could modify the kwargs
|
||||
assert modified_kwargs["temperature"] == 0.9
|
||||
assert modified_kwargs["max_tokens"] == 500
|
||||
assert modified_kwargs["new_param"] == "added_by_middleware"
|
||||
assert modified_kwargs["custom_param"] == "test_value" # Should still be there
|
||||
|
||||
@@ -0,0 +1,436 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import (
|
||||
ChatAgent,
|
||||
ChatContext,
|
||||
ChatMessage,
|
||||
ChatMiddleware,
|
||||
ChatResponse,
|
||||
FunctionCallContent,
|
||||
FunctionInvocationContext,
|
||||
Role,
|
||||
chat_middleware,
|
||||
function_middleware,
|
||||
use_function_invocation,
|
||||
)
|
||||
|
||||
from .conftest import MockBaseChatClient
|
||||
|
||||
|
||||
class TestChatMiddleware:
|
||||
"""Test cases for chat middleware functionality."""
|
||||
|
||||
async def test_class_based_chat_middleware(self, chat_client_base: "MockBaseChatClient") -> None:
|
||||
"""Test class-based chat middleware with ChatClient."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class LoggingChatMiddleware(ChatMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: ChatContext,
|
||||
next: Callable[[ChatContext], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_order.append("chat_middleware_before")
|
||||
await next(context)
|
||||
execution_order.append("chat_middleware_after")
|
||||
|
||||
# Add middleware to chat client
|
||||
chat_client_base.middleware = [LoggingChatMiddleware()]
|
||||
|
||||
# Execute chat client directly
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await chat_client_base.get_response(messages)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert response.messages[0].role == Role.ASSISTANT
|
||||
|
||||
# Verify middleware execution order
|
||||
assert execution_order == ["chat_middleware_before", "chat_middleware_after"]
|
||||
|
||||
async def test_function_based_chat_middleware(self, chat_client_base: "MockBaseChatClient") -> None:
|
||||
"""Test function-based chat middleware with ChatClient."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
@chat_middleware
|
||||
async def logging_chat_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("function_middleware_before")
|
||||
await next(context)
|
||||
execution_order.append("function_middleware_after")
|
||||
|
||||
# Add middleware to chat client
|
||||
chat_client_base.middleware = [logging_chat_middleware]
|
||||
|
||||
# Execute chat client directly
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await chat_client_base.get_response(messages)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert response.messages[0].role == Role.ASSISTANT
|
||||
|
||||
# Verify middleware execution order
|
||||
assert execution_order == ["function_middleware_before", "function_middleware_after"]
|
||||
|
||||
async def test_chat_middleware_can_modify_messages(self, chat_client_base: "MockBaseChatClient") -> None:
|
||||
"""Test that chat middleware can modify messages before sending to model."""
|
||||
|
||||
@chat_middleware
|
||||
async def message_modifier_middleware(
|
||||
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
|
||||
) -> None:
|
||||
# Modify the first message by adding a prefix
|
||||
if context.messages and len(context.messages) > 0:
|
||||
original_text = context.messages[0].text or ""
|
||||
context.messages[0] = ChatMessage(role=context.messages[0].role, text=f"MODIFIED: {original_text}")
|
||||
await next(context)
|
||||
|
||||
# Add middleware to chat client
|
||||
chat_client_base.middleware = [message_modifier_middleware]
|
||||
|
||||
# Execute chat client
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await chat_client_base.get_response(messages)
|
||||
|
||||
# Verify that the message was modified (MockChatClient echoes back the input)
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
# The mock client should receive the modified message
|
||||
assert "MODIFIED: test message" in response.messages[0].text
|
||||
|
||||
async def test_chat_middleware_can_override_response(self, chat_client_base: "MockBaseChatClient") -> None:
|
||||
"""Test that chat middleware can override the response."""
|
||||
|
||||
@chat_middleware
|
||||
async def response_override_middleware(
|
||||
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
|
||||
) -> None:
|
||||
# Override the response without calling next()
|
||||
context.result = ChatResponse(
|
||||
messages=[ChatMessage(role=Role.ASSISTANT, text="Middleware overridden response")],
|
||||
response_id="middleware-response-123",
|
||||
)
|
||||
context.terminate = True
|
||||
|
||||
# Add middleware to chat client
|
||||
chat_client_base.middleware = [response_override_middleware]
|
||||
|
||||
# Execute chat client
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await chat_client_base.get_response(messages)
|
||||
|
||||
# Verify that the response was overridden
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert response.messages[0].text == "Middleware overridden response"
|
||||
assert response.response_id == "middleware-response-123"
|
||||
|
||||
async def test_multiple_chat_middleware_execution_order(self, chat_client_base: "MockBaseChatClient") -> None:
|
||||
"""Test that multiple chat middleware execute in the correct order."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
@chat_middleware
|
||||
async def first_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("first_before")
|
||||
await next(context)
|
||||
execution_order.append("first_after")
|
||||
|
||||
@chat_middleware
|
||||
async def second_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("second_before")
|
||||
await next(context)
|
||||
execution_order.append("second_after")
|
||||
|
||||
# Add middleware to chat client (order should be preserved)
|
||||
chat_client_base.middleware = [first_middleware, second_middleware]
|
||||
|
||||
# Execute chat client
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await chat_client_base.get_response(messages)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
|
||||
# Verify middleware execution order (nested execution)
|
||||
expected_order = ["first_before", "second_before", "second_after", "first_after"]
|
||||
assert execution_order == expected_order
|
||||
|
||||
async def test_chat_agent_with_chat_middleware(self) -> None:
|
||||
"""Test ChatAgent with chat middleware specified at agent level."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
@chat_middleware
|
||||
async def agent_level_chat_middleware(
|
||||
context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("agent_chat_middleware_before")
|
||||
await next(context)
|
||||
execution_order.append("agent_chat_middleware_after")
|
||||
|
||||
chat_client = MockBaseChatClient()
|
||||
|
||||
# Create ChatAgent with chat middleware
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[agent_level_chat_middleware])
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert response.messages[0].role == Role.ASSISTANT
|
||||
|
||||
# Verify middleware execution order
|
||||
assert execution_order == ["agent_chat_middleware_before", "agent_chat_middleware_after"]
|
||||
|
||||
async def test_chat_agent_with_multiple_chat_middleware(self, chat_client_base: "MockBaseChatClient") -> None:
|
||||
"""Test that ChatAgent can have multiple chat middleware."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
@chat_middleware
|
||||
async def first_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("first_before")
|
||||
await next(context)
|
||||
execution_order.append("first_after")
|
||||
|
||||
@chat_middleware
|
||||
async def second_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("second_before")
|
||||
await next(context)
|
||||
execution_order.append("second_after")
|
||||
|
||||
# Create ChatAgent with multiple chat middleware
|
||||
agent = ChatAgent(chat_client=chat_client_base, middleware=[first_middleware, second_middleware])
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
|
||||
# Verify both middleware executed (nested execution order)
|
||||
expected_order = ["first_before", "second_before", "second_after", "first_after"]
|
||||
assert execution_order == expected_order
|
||||
|
||||
async def test_chat_middleware_with_streaming(self, chat_client_base: "MockBaseChatClient") -> None:
|
||||
"""Test chat middleware with streaming responses."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
@chat_middleware
|
||||
async def streaming_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_order.append("streaming_before")
|
||||
# Verify it's a streaming context
|
||||
assert context.is_streaming is True
|
||||
await next(context)
|
||||
execution_order.append("streaming_after")
|
||||
|
||||
# Add middleware to chat client
|
||||
chat_client_base.middleware = [streaming_middleware]
|
||||
|
||||
# Execute streaming response
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
updates: list[object] = []
|
||||
async for update in chat_client_base.get_streaming_response(messages):
|
||||
updates.append(update)
|
||||
|
||||
# Verify we got updates
|
||||
assert len(updates) > 0
|
||||
|
||||
# Verify middleware executed
|
||||
assert execution_order == ["streaming_before", "streaming_after"]
|
||||
|
||||
async def test_run_level_middleware_isolation(self, chat_client_base: "MockBaseChatClient") -> None:
|
||||
"""Test that run-level middleware is isolated and doesn't persist across calls."""
|
||||
execution_count = {"count": 0}
|
||||
|
||||
@chat_middleware
|
||||
async def counting_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
execution_count["count"] += 1
|
||||
await next(context)
|
||||
|
||||
# First call with run-level middleware
|
||||
messages = [ChatMessage(role=Role.USER, text="first message")]
|
||||
response1 = await chat_client_base.get_response(messages, middleware=[counting_middleware])
|
||||
assert response1 is not None
|
||||
assert execution_count["count"] == 1
|
||||
|
||||
# Second call WITHOUT run-level middleware - should not execute the middleware
|
||||
messages = [ChatMessage(role=Role.USER, text="second message")]
|
||||
response2 = await chat_client_base.get_response(messages)
|
||||
assert response2 is not None
|
||||
assert execution_count["count"] == 1 # Should still be 1, not 2
|
||||
|
||||
# Third call with run-level middleware again - should execute
|
||||
messages = [ChatMessage(role=Role.USER, text="third message")]
|
||||
response3 = await chat_client_base.get_response(messages, middleware=[counting_middleware])
|
||||
assert response3 is not None
|
||||
assert execution_count["count"] == 2 # Should be 2 now
|
||||
|
||||
async def test_chat_client_middleware_can_access_and_override_custom_kwargs(
|
||||
self, chat_client_base: "MockBaseChatClient"
|
||||
) -> None:
|
||||
"""Test that chat client middleware can access and override custom parameters like temperature."""
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
modified_kwargs: dict[str, Any] = {}
|
||||
|
||||
@chat_middleware
|
||||
async def kwargs_middleware(context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]) -> None:
|
||||
# Capture the original kwargs
|
||||
captured_kwargs.update(context.kwargs)
|
||||
|
||||
# Modify some kwargs
|
||||
context.kwargs["temperature"] = 0.9
|
||||
context.kwargs["max_tokens"] = 500
|
||||
context.kwargs["new_param"] = "added_by_middleware"
|
||||
|
||||
# Store modified kwargs for verification
|
||||
modified_kwargs.update(context.kwargs)
|
||||
|
||||
await next(context)
|
||||
|
||||
# Add middleware to chat client
|
||||
chat_client_base.middleware = [kwargs_middleware]
|
||||
|
||||
# Execute chat client with custom parameters
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await chat_client_base.get_response(
|
||||
messages, temperature=0.7, max_tokens=100, custom_param="test_value"
|
||||
)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
|
||||
assert captured_kwargs["temperature"] == 0.7
|
||||
assert captured_kwargs["max_tokens"] == 100
|
||||
assert captured_kwargs["custom_param"] == "test_value"
|
||||
|
||||
# Verify middleware could modify the kwargs
|
||||
assert modified_kwargs["temperature"] == 0.9
|
||||
assert modified_kwargs["max_tokens"] == 500
|
||||
assert modified_kwargs["new_param"] == "added_by_middleware"
|
||||
assert modified_kwargs["custom_param"] == "test_value" # Should still be there
|
||||
|
||||
async def test_function_middleware_registration_on_chat_client(self) -> None:
|
||||
"""Test function middleware registered on ChatClient is executed during function calls."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
@function_middleware
|
||||
async def test_function_middleware(
|
||||
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append(f"function_middleware_before_{context.function.name}")
|
||||
await next(context)
|
||||
execution_order.append(f"function_middleware_after_{context.function.name}")
|
||||
|
||||
# Define a simple tool function
|
||||
def sample_tool(location: str) -> str:
|
||||
"""Get weather for a location."""
|
||||
return f"Weather in {location}: sunny"
|
||||
|
||||
# Create function-invocation enabled chat client
|
||||
chat_client = use_function_invocation(MockBaseChatClient)()
|
||||
|
||||
# Set function middleware directly on the chat client
|
||||
chat_client.middleware = [test_function_middleware]
|
||||
|
||||
# Prepare responses that will trigger function invocation
|
||||
function_call_response = ChatResponse(
|
||||
messages=[
|
||||
ChatMessage(
|
||||
role=Role.ASSISTANT,
|
||||
contents=[
|
||||
FunctionCallContent(
|
||||
call_id="call_1",
|
||||
name="sample_tool",
|
||||
arguments={"location": "San Francisco"},
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
final_response = ChatResponse(
|
||||
messages=[ChatMessage(role=Role.ASSISTANT, text="Based on the weather data, it's sunny!")]
|
||||
)
|
||||
|
||||
chat_client.run_responses = [function_call_response, final_response]
|
||||
|
||||
# Execute the chat client directly with tools - this should trigger function invocation and middleware
|
||||
messages = [ChatMessage(role=Role.USER, text="What's the weather in San Francisco?")]
|
||||
response = await chat_client.get_response(messages, tools=[sample_tool])
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert chat_client.call_count == 2 # Two calls: function call + final response
|
||||
|
||||
# Verify function middleware was executed
|
||||
assert execution_order == [
|
||||
"function_middleware_before_sample_tool",
|
||||
"function_middleware_after_sample_tool",
|
||||
]
|
||||
|
||||
async def test_run_level_function_middleware(self) -> None:
|
||||
"""Test that function middleware passed to get_response method is also invoked."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
@function_middleware
|
||||
async def run_level_function_middleware(
|
||||
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("run_level_function_middleware_before")
|
||||
await next(context)
|
||||
execution_order.append("run_level_function_middleware_after")
|
||||
|
||||
# Define a simple tool function
|
||||
def sample_tool(location: str) -> str:
|
||||
"""Get weather for a location."""
|
||||
return f"Weather in {location}: sunny"
|
||||
|
||||
# Create function-invocation enabled chat client
|
||||
chat_client = use_function_invocation(MockBaseChatClient)()
|
||||
|
||||
# Prepare responses that will trigger function invocation
|
||||
function_call_response = ChatResponse(
|
||||
messages=[
|
||||
ChatMessage(
|
||||
role=Role.ASSISTANT,
|
||||
contents=[
|
||||
FunctionCallContent(
|
||||
call_id="call_2",
|
||||
name="sample_tool",
|
||||
arguments={"location": "New York"},
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
final_response = ChatResponse(
|
||||
messages=[ChatMessage(role=Role.ASSISTANT, text="The weather information has been retrieved!")]
|
||||
)
|
||||
|
||||
chat_client.run_responses = [function_call_response, final_response]
|
||||
|
||||
# Execute the chat client directly with run-level middleware and tools
|
||||
messages = [ChatMessage(role=Role.USER, text="What's the weather in New York?")]
|
||||
response = await chat_client.get_response(
|
||||
messages, tools=[sample_tool], middleware=[run_level_function_middleware]
|
||||
)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert chat_client.call_count == 2 # Two calls: function call + final response
|
||||
|
||||
# Verify run-level function middleware was executed once (during function invocation)
|
||||
assert execution_order == [
|
||||
"run_level_function_middleware_before",
|
||||
"run_level_function_middleware_after",
|
||||
]
|
||||
@@ -13,6 +13,7 @@ from agent_framework import (
|
||||
ChatResponseUpdate,
|
||||
Role,
|
||||
TextContent,
|
||||
use_chat_middleware,
|
||||
use_function_invocation,
|
||||
)
|
||||
|
||||
@@ -37,6 +38,7 @@ custom client with ChatAgent through the create_agent() method.
|
||||
|
||||
|
||||
@use_function_invocation
|
||||
@use_chat_middleware
|
||||
class EchoingChatClient(BaseChatClient):
|
||||
"""A custom chat client that echoes messages back with modifications.
|
||||
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
# Middleware Examples
|
||||
|
||||
This folder contains examples demonstrating various middleware patterns with the Agent Framework. Middleware allows you to intercept and modify behavior at different execution stages, including agent runs, function calls, and chat interactions.
|
||||
|
||||
## Examples
|
||||
|
||||
| File | Description |
|
||||
|------|-------------|
|
||||
| [`function_based_middleware.py`](function_based_middleware.py) | Demonstrates how to implement middleware using simple async functions instead of classes. Shows security validation, logging, and performance monitoring middleware. Function-based middleware is ideal for simple, stateless operations and provides a lightweight approach. |
|
||||
| [`class_based_middleware.py`](class_based_middleware.py) | Shows how to implement middleware using class-based approach by inheriting from `AgentMiddleware` and `FunctionMiddleware` base classes. Includes security checks for sensitive information and detailed function execution logging with timing. |
|
||||
| [`decorator_middleware.py`](decorator_middleware.py) | Demonstrates how to use `@agent_middleware` and `@function_middleware` decorators to explicitly mark middleware functions without requiring type annotations. Shows different middleware detection scenarios and explicit decorator usage. |
|
||||
| [`middleware_termination.py`](middleware_termination.py) | Shows how middleware can terminate execution using the `context.terminate` flag. Includes examples of pre-termination (prevents agent processing) and post-termination (allows processing but stops further execution). Useful for security checks, rate limiting, or early exit conditions. |
|
||||
| [`exception_handling_with_middleware.py`](exception_handling_with_middleware.py) | Demonstrates how to use middleware for centralized exception handling in function calls. Shows how to catch exceptions from functions, provide graceful error responses, and override function results when errors occur to provide user-friendly messages. |
|
||||
| [`override_result_with_middleware.py`](override_result_with_middleware.py) | Shows how to use middleware to intercept and modify function results after execution, supporting both regular and streaming agent responses. Demonstrates result filtering, formatting, enhancement, and custom streaming response generation. |
|
||||
| [`shared_state_middleware.py`](shared_state_middleware.py) | Demonstrates how to implement function-based middleware within a class to share state between multiple middleware functions. Shows how middleware can work together by sharing state, including call counting and result enhancement. |
|
||||
| [`agent_and_run_level_middleware.py`](agent_and_run_level_middleware.py) | Explains the difference between agent-level middleware (applied to ALL runs of the agent) and run-level middleware (applied to specific runs only). Shows security validation, performance monitoring, and context-specific middleware patterns. |
|
||||
| [`chat_middleware.py`](chat_middleware.py) | Demonstrates how to use chat middleware to observe and override inputs sent to AI models. Shows how to intercept chat requests, log and modify input messages, and override entire responses before they reach the underlying AI service. |
|
||||
|
||||
## Key Concepts
|
||||
|
||||
### Middleware Types
|
||||
|
||||
- **Agent Middleware**: Intercepts agent run execution, allowing you to modify requests and responses
|
||||
- **Function Middleware**: Intercepts function calls within agents, enabling logging, validation, and result modification
|
||||
- **Chat Middleware**: Intercepts chat requests sent to AI models, allowing input/output transformation
|
||||
|
||||
### Implementation Approaches
|
||||
|
||||
- **Function-based**: Simple async functions for lightweight, stateless operations
|
||||
- **Class-based**: Inherit from base middleware classes for complex, stateful operations
|
||||
- **Decorator-based**: Use decorators for explicit middleware marking
|
||||
|
||||
### Common Use Cases
|
||||
|
||||
- **Security**: Validate requests, block sensitive information, implement access controls
|
||||
- **Logging**: Track execution timing, log parameters and results, monitor performance
|
||||
- **Error Handling**: Catch exceptions, provide graceful fallbacks, implement retry logic
|
||||
- **Result Transformation**: Filter, format, or enhance function outputs
|
||||
- **State Management**: Share data between middleware functions, maintain execution context
|
||||
|
||||
### Execution Control
|
||||
|
||||
- **Termination**: Use `context.terminate` to stop execution early
|
||||
- **Result Override**: Modify or replace function/agent results
|
||||
- **Streaming Support**: Handle both regular and streaming responses
|
||||
@@ -0,0 +1,245 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from random import randint
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import (
|
||||
ChatContext,
|
||||
ChatMessage,
|
||||
ChatMiddleware,
|
||||
ChatResponse,
|
||||
Role,
|
||||
chat_middleware,
|
||||
)
|
||||
from agent_framework.azure import AzureAIAgentClient
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from pydantic import Field
|
||||
|
||||
"""
|
||||
Chat Middleware Example
|
||||
|
||||
This sample demonstrates how to use chat middleware to observe and override
|
||||
inputs sent to AI models. Chat middleware intercepts chat requests before they reach
|
||||
the underlying AI service, allowing you to:
|
||||
|
||||
1. Observe and log input messages
|
||||
2. Modify input messages before sending to AI
|
||||
3. Override the entire response
|
||||
|
||||
The example covers:
|
||||
- Class-based chat middleware inheriting from ChatMiddleware
|
||||
- Function-based chat middleware with @chat_middleware decorator
|
||||
- Middleware registration at agent level (applies to all runs)
|
||||
- Middleware registration at run level (applies to specific run only)
|
||||
"""
|
||||
|
||||
|
||||
def get_weather(
|
||||
location: Annotated[str, Field(description="The location to get the weather for.")],
|
||||
) -> str:
|
||||
"""Get the weather for a given location."""
|
||||
conditions = ["sunny", "cloudy", "rainy", "stormy"]
|
||||
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
|
||||
|
||||
|
||||
class InputObserverMiddleware(ChatMiddleware):
|
||||
"""Class-based middleware that observes and modifies input messages."""
|
||||
|
||||
def __init__(self, replacement: str | None = None):
|
||||
"""Initialize with a replacement for user messages."""
|
||||
self.replacement = replacement
|
||||
|
||||
async def process(
|
||||
self,
|
||||
context: ChatContext,
|
||||
next: Callable[[ChatContext], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Observe and modify input messages before they are sent to AI."""
|
||||
print("[InputObserverMiddleware] Observing input messages:")
|
||||
|
||||
for i, message in enumerate(context.messages):
|
||||
content = message.text if message.text else str(message.contents)
|
||||
print(f" Message {i + 1} ({message.role.value}): {content}")
|
||||
|
||||
print(f"[InputObserverMiddleware] Total messages: {len(context.messages)}")
|
||||
|
||||
# Modify user messages by creating new messages with enhanced text
|
||||
modified_messages: list[ChatMessage] = []
|
||||
modified_count = 0
|
||||
|
||||
for message in context.messages:
|
||||
if message.role == Role.USER and message.text:
|
||||
original_text = message.text
|
||||
updated_text = original_text
|
||||
|
||||
if self.replacement:
|
||||
updated_text = self.replacement
|
||||
print(f"[InputObserverMiddleware] Updated: '{original_text}' -> '{updated_text}'")
|
||||
|
||||
modified_message = ChatMessage(role=message.role, text=updated_text)
|
||||
modified_messages.append(modified_message)
|
||||
modified_count += 1
|
||||
else:
|
||||
modified_messages.append(message)
|
||||
|
||||
# Replace messages in context
|
||||
context.messages[:] = modified_messages
|
||||
|
||||
# Continue to next middleware or AI execution
|
||||
await next(context)
|
||||
|
||||
# Observe that processing is complete
|
||||
print("[InputObserverMiddleware] Processing completed")
|
||||
|
||||
|
||||
@chat_middleware
|
||||
async def security_and_override_middleware(
|
||||
context: ChatContext,
|
||||
next: Callable[[ChatContext], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Function-based middleware that implements security filtering and response override."""
|
||||
print("[SecurityMiddleware] Processing input...")
|
||||
|
||||
# Security check - block sensitive information
|
||||
blocked_terms = ["password", "secret", "api_key", "token"]
|
||||
|
||||
for message in context.messages:
|
||||
if message.text:
|
||||
message_lower = message.text.lower()
|
||||
for term in blocked_terms:
|
||||
if term in message_lower:
|
||||
print(f"[SecurityMiddleware] BLOCKED: Found '{term}' in message")
|
||||
|
||||
# Override the response instead of calling AI
|
||||
context.result = ChatResponse(
|
||||
messages=[
|
||||
ChatMessage(
|
||||
role=Role.ASSISTANT,
|
||||
text="I cannot process requests containing sensitive information. "
|
||||
"Please rephrase your question without including passwords, secrets, or other "
|
||||
"sensitive data.",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Set terminate flag to stop execution
|
||||
context.terminate = True
|
||||
return
|
||||
|
||||
# Continue to next middleware or AI execution
|
||||
await next(context)
|
||||
|
||||
|
||||
async def class_based_chat_middleware() -> None:
|
||||
"""Demonstrate class-based middleware at agent level."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Class-based Chat Middleware (Agent Level)")
|
||||
print("=" * 60)
|
||||
|
||||
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
|
||||
# authentication option.
|
||||
async with (
|
||||
AzureCliCredential() as credential,
|
||||
AzureAIAgentClient(async_credential=credential).create_agent(
|
||||
name="EnhancedChatAgent",
|
||||
instructions="You are a helpful AI assistant.",
|
||||
# Register class-based middleware at agent level (applies to all runs)
|
||||
middleware=InputObserverMiddleware(),
|
||||
tools=get_weather,
|
||||
) as agent,
|
||||
):
|
||||
query = "What's the weather in Seattle?"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Final Response: {result.text if result.text else 'No response'}")
|
||||
|
||||
|
||||
async def function_based_chat_middleware() -> None:
|
||||
"""Demonstrate function-based middleware at agent level."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Function-based Chat Middleware (Agent Level)")
|
||||
print("=" * 60)
|
||||
|
||||
async with (
|
||||
AzureCliCredential() as credential,
|
||||
AzureAIAgentClient(async_credential=credential).create_agent(
|
||||
name="FunctionMiddlewareAgent",
|
||||
instructions="You are a helpful AI assistant.",
|
||||
# Register function-based middleware at agent level
|
||||
middleware=security_and_override_middleware,
|
||||
) as agent,
|
||||
):
|
||||
# Scenario with normal query
|
||||
print("\n--- Scenario 1: Normal Query ---")
|
||||
query = "Hello, how are you?"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Final Response: {result.text if result.text else 'No response'}")
|
||||
|
||||
# Scenario with security violation
|
||||
print("\n--- Scenario 2: Security Violation ---")
|
||||
query = "What is my password for this account?"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Final Response: {result.text if result.text else 'No response'}")
|
||||
|
||||
|
||||
async def run_level_middleware() -> None:
|
||||
"""Demonstrate middleware registration at run level."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Run-level Chat Middleware")
|
||||
print("=" * 60)
|
||||
|
||||
async with (
|
||||
AzureCliCredential() as credential,
|
||||
AzureAIAgentClient(async_credential=credential).create_agent(
|
||||
name="RunLevelAgent",
|
||||
instructions="You are a helpful AI assistant.",
|
||||
tools=get_weather,
|
||||
# No middleware at agent level
|
||||
) as agent,
|
||||
):
|
||||
# Scenario 1: Run without any middleware
|
||||
print("\n--- Scenario 1: No Middleware ---")
|
||||
query = "What's the weather in Tokyo?"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Response: {result.text if result.text else 'No response'}")
|
||||
|
||||
# Scenario 2: Run with specific middleware for this call only (both enhancement and security)
|
||||
print("\n--- Scenario 2: With Run-level Middleware ---")
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(
|
||||
query,
|
||||
middleware=[
|
||||
InputObserverMiddleware(replacement="What's the weather in Madrid?"),
|
||||
security_and_override_middleware,
|
||||
],
|
||||
)
|
||||
print(f"Response: {result.text if result.text else 'No response'}")
|
||||
|
||||
# Scenario 3: Security test with run-level middleware
|
||||
print("\n--- Scenario 3: Security Test with Run-level Middleware ---")
|
||||
query = "Can you help me with my secret API key?"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(
|
||||
query,
|
||||
middleware=security_and_override_middleware,
|
||||
)
|
||||
print(f"Response: {result.text if result.text else 'No response'}")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Run all chat middleware examples."""
|
||||
print("Chat Middleware Examples")
|
||||
print("========================")
|
||||
|
||||
await class_based_chat_middleware()
|
||||
await function_based_chat_middleware()
|
||||
await run_level_middleware()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,28 +1,37 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable
|
||||
from random import randint
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import FunctionInvocationContext
|
||||
from agent_framework import (
|
||||
AgentRunContext,
|
||||
AgentRunResponse,
|
||||
AgentRunResponseUpdate,
|
||||
ChatMessage,
|
||||
Role,
|
||||
TextContent,
|
||||
)
|
||||
from agent_framework.azure import AzureAIAgentClient
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from pydantic import Field
|
||||
|
||||
"""
|
||||
Result Override with Middleware
|
||||
Result Override with Middleware (Regular and Streaming)
|
||||
|
||||
This sample demonstrates how to use middleware to intercept and modify function results
|
||||
after execution. The example shows:
|
||||
after execution, supporting both regular and streaming agent responses. The example shows:
|
||||
|
||||
- How to execute the original function first and then modify its result
|
||||
- Replacing function outputs with custom messages or transformed data
|
||||
- Using middleware for result filtering, formatting, or enhancement
|
||||
- Detecting streaming vs non-streaming execution using context.is_streaming
|
||||
- Overriding streaming results with custom async generators
|
||||
|
||||
The weather override middleware lets the original weather function execute normally,
|
||||
then replaces its result with a custom "perfect weather" message, demonstrating
|
||||
how middleware can be used for content filtering, A/B testing, or result enhancement.
|
||||
then replaces its result with a custom "perfect weather" message. For streaming responses,
|
||||
it creates a custom async generator that yields the override message in chunks.
|
||||
"""
|
||||
|
||||
|
||||
@@ -35,32 +44,39 @@ def get_weather(
|
||||
|
||||
|
||||
async def weather_override_middleware(
|
||||
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
function_name = context.function.name
|
||||
"""Middleware that overrides weather results for both streaming and non-streaming cases."""
|
||||
|
||||
# Let the original function execute first
|
||||
# Let the original agent execution complete first
|
||||
await next(context)
|
||||
|
||||
# Override the result if it's a weather function
|
||||
if function_name == "get_weather" and context.result is not None:
|
||||
original_result = str(context.result)
|
||||
print(f"[WeatherOverrideMiddleware] Original result: {original_result}")
|
||||
# Check if there's a result to override (agent called weather function)
|
||||
if context.result is not None:
|
||||
# Create custom weather message
|
||||
chunks = [
|
||||
"Weather Advisory - ",
|
||||
"due to special atmospheric conditions, ",
|
||||
"all locations are experiencing perfect weather today! ",
|
||||
"Temperature is a comfortable 22°C with gentle breezes. ",
|
||||
"Perfect day for outdoor activities!",
|
||||
]
|
||||
|
||||
# Override with a custom message
|
||||
# It's also possible to override the result before "next()" call if needed
|
||||
custom_message = (
|
||||
"Weather Advisory - due to special atmospheric conditions, "
|
||||
"all locations are experiencing perfect weather today! "
|
||||
"Temperature is a comfortable 22°C with gentle breezes. "
|
||||
"Perfect day for outdoor activities!"
|
||||
)
|
||||
context.result = custom_message
|
||||
print(f"[WeatherOverrideMiddleware] Overriding with custom message: {custom_message}")
|
||||
if context.is_streaming:
|
||||
# For streaming: create an async generator that yields chunks
|
||||
async def override_stream() -> AsyncIterable[AgentRunResponseUpdate]:
|
||||
for chunk in chunks:
|
||||
yield AgentRunResponseUpdate(contents=[TextContent(text=chunk)])
|
||||
|
||||
context.result = override_stream()
|
||||
else:
|
||||
# For non-streaming: just replace with the string message
|
||||
custom_message = "".join(chunks)
|
||||
context.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=custom_message)])
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Example demonstrating result override with middleware."""
|
||||
"""Example demonstrating result override with middleware for both streaming and non-streaming."""
|
||||
print("=== Result Override Middleware Example ===")
|
||||
|
||||
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
|
||||
@@ -74,11 +90,22 @@ async def main() -> None:
|
||||
middleware=weather_override_middleware,
|
||||
) as agent,
|
||||
):
|
||||
# Non-streaming example
|
||||
print("\n--- Non-streaming Example ---")
|
||||
query = "What's the weather like in Seattle?"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Agent: {result}")
|
||||
|
||||
# Streaming example
|
||||
print("\n--- Streaming Example ---")
|
||||
query = "What's the weather like in Portland?"
|
||||
print(f"User: {query}")
|
||||
print("Agent: ", end="", flush=True)
|
||||
async for chunk in agent.run_stream(query):
|
||||
if chunk.text:
|
||||
print(chunk.text, end="", flush=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from random import randint
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import (
|
||||
FunctionInvocationContext,
|
||||
)
|
||||
from agent_framework.azure import AzureAIAgentClient
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from pydantic import Field
|
||||
|
||||
"""
|
||||
Shared State Function-based Middleware Example
|
||||
|
||||
This sample demonstrates how to implement function-based middleware within a class to share state.
|
||||
The example includes:
|
||||
|
||||
- A MiddlewareContainer class with two simple function middleware methods
|
||||
- First middleware: Counts function calls and stores the count in shared state
|
||||
- Second middleware: Uses the shared count to add call numbers to function results
|
||||
|
||||
This approach shows how middleware can work together by sharing state within the same class instance.
|
||||
"""
|
||||
|
||||
|
||||
def get_weather(
|
||||
location: Annotated[str, Field(description="The location to get the weather for.")],
|
||||
) -> str:
|
||||
"""Get the weather for a given location."""
|
||||
conditions = ["sunny", "cloudy", "rainy", "stormy"]
|
||||
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
|
||||
|
||||
|
||||
def get_time(
|
||||
timezone: Annotated[str, Field(description="The timezone to get the time for.")] = "UTC",
|
||||
) -> str:
|
||||
"""Get the current time for a given timezone."""
|
||||
import datetime
|
||||
|
||||
return f"The current time in {timezone} is {datetime.datetime.now().strftime('%H:%M:%S')}"
|
||||
|
||||
|
||||
class MiddlewareContainer:
|
||||
"""Container class that holds middleware functions with shared state."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Simple shared state: count function calls
|
||||
self.call_count: int = 0
|
||||
|
||||
async def call_counter_middleware(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
) -> None:
|
||||
"""First middleware: increments call count in shared state."""
|
||||
# Increment the shared call count
|
||||
self.call_count += 1
|
||||
|
||||
print(f"[CallCounter] This is function call #{self.call_count}")
|
||||
|
||||
# Call the next middleware/function
|
||||
await next(context)
|
||||
|
||||
async def result_enhancer_middleware(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Second middleware: uses shared call count to enhance function results."""
|
||||
print(f"[ResultEnhancer] Current total calls so far: {self.call_count}")
|
||||
|
||||
# Call the next middleware/function
|
||||
await next(context)
|
||||
|
||||
# After function execution, enhance the result using shared state
|
||||
if context.result:
|
||||
enhanced_result = f"[Call #{self.call_count}] {context.result}"
|
||||
context.result = enhanced_result
|
||||
print("[ResultEnhancer] Enhanced result with call number")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Example demonstrating shared state function-based middleware."""
|
||||
print("=== Shared State Function-based Middleware Example ===")
|
||||
|
||||
# Create middleware container with shared state
|
||||
middleware_container = MiddlewareContainer()
|
||||
|
||||
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
|
||||
# authentication option.
|
||||
async with (
|
||||
AzureCliCredential() as credential,
|
||||
AzureAIAgentClient(async_credential=credential).create_agent(
|
||||
name="UtilityAgent",
|
||||
instructions="You are a helpful assistant that can provide weather information and current time.",
|
||||
tools=[get_weather, get_time],
|
||||
# Pass both middleware functions from the same container instance
|
||||
# Order matters: counter runs first to increment count,
|
||||
# then result enhancer uses the updated count
|
||||
middleware=[
|
||||
middleware_container.call_counter_middleware,
|
||||
middleware_container.result_enhancer_middleware,
|
||||
],
|
||||
) as agent,
|
||||
):
|
||||
# Test multiple requests to see shared state in action
|
||||
queries = [
|
||||
"What's the weather like in New York?",
|
||||
"What time is it in London?",
|
||||
"What's the weather in Tokyo?",
|
||||
]
|
||||
|
||||
for i, query in enumerate(queries, 1):
|
||||
print(f"\n--- Query {i} ---")
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Agent: {result.text if result.text else 'No response'}")
|
||||
|
||||
# Display final statistics
|
||||
print("\n=== Final Statistics ===")
|
||||
print(f"Total function calls made: {middleware_container.call_count}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user