mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Agent and Function middleware (#770)
* Initial middleware implementation
* Small fixes
* Small updates
* Small updates in samples
* Moved middleware functionality to decorator
* Removed obsolete file
* Renamed AgentInvocationContext to AzureRunContext
* Added unit tests
* Small settings update for test discovery in VS Code
* Added unit tests
* Reverted changes in environment settings
* Added context result override
* Renaming and updates to logic
* Added more samples
* Updated DEV_SETUP.md
* Addressed PR feedback
* Addressed PR feedback
* Removed unused parameter
* Small fix
* Small fix in telemetry logic
* Revert "Small fix in telemetry logic"
This reverts commit 6f82660d2d.
* Small fix
---------
Co-authored-by: Chris <66376200+crickman@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
538be4c149
commit
99860a5d07
@@ -13,6 +13,7 @@ from ._clients import * # noqa: F403
|
||||
from ._logging import * # noqa: F403
|
||||
from ._mcp import * # noqa: F403
|
||||
from ._memory import * # noqa: F403
|
||||
from ._middleware import * # noqa: F403
|
||||
from ._threads import * # noqa: F403
|
||||
from ._tools import * # noqa: F403
|
||||
from ._types import * # noqa: F403
|
||||
|
||||
@@ -12,6 +12,7 @@ from ._clients import BaseChatClient, ChatClientProtocol
|
||||
from ._logging import get_logger
|
||||
from ._mcp import MCPTool
|
||||
from ._memory import AggregateContextProvider, Context, ContextProvider
|
||||
from ._middleware import Middleware, use_agent_middleware
|
||||
from ._pydantic import AFBaseModel
|
||||
from ._threads import AgentThread, ChatMessageStore, deserialize_thread_state, thread_on_new_messages
|
||||
from ._tools import FUNCTION_INVOKING_CHAT_CLIENT_MARKER, ToolProtocol
|
||||
@@ -138,12 +139,14 @@ class BaseAgent(AFBaseModel):
|
||||
description: The description of the agent.
|
||||
display_name: The display name of the agent, which is either the name or id.
|
||||
context_providers: The collection of multiple context providers to include during agent invocation.
|
||||
middleware: List of middleware to intercept agent and function invocations.
|
||||
"""
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
context_providers: AggregateContextProvider | None = None
|
||||
middleware: Middleware | list[Middleware] | None = None
|
||||
|
||||
async def _notify_thread_of_new_messages(
|
||||
self, thread: AgentThread, new_messages: ChatMessage | Sequence[ChatMessage]
|
||||
@@ -189,6 +192,7 @@ class BaseAgent(AFBaseModel):
|
||||
# region ChatAgent
|
||||
|
||||
|
||||
@use_agent_middleware
|
||||
@use_agent_telemetry
|
||||
class ChatAgent(BaseAgent):
|
||||
"""A Chat Client Agent."""
|
||||
@@ -231,6 +235,7 @@ class ChatAgent(BaseAgent):
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
chat_message_store_factory: Callable[[], ChatMessageStore] | None = None,
|
||||
context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None,
|
||||
middleware: Middleware | list[Middleware] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create a ChatAgent.
|
||||
@@ -266,6 +271,7 @@ class ChatAgent(BaseAgent):
|
||||
chat_message_store_factory: factory function to create an instance of ChatMessageStore. If not provided,
|
||||
the default in-memory store will be used.
|
||||
context_providers: The collection of multiple context providers to include during agent invocation.
|
||||
middleware: List of middleware to intercept agent and function invocations.
|
||||
kwargs: any additional keyword arguments.
|
||||
Unused, can be used by subclasses of this Agent.
|
||||
"""
|
||||
@@ -287,6 +293,7 @@ class ChatAgent(BaseAgent):
|
||||
"chat_client": chat_client,
|
||||
"chat_message_store_factory": chat_message_store_factory,
|
||||
"context_providers": aggregate_context_providers,
|
||||
"middleware": middleware,
|
||||
"chat_options": ChatOptions(
|
||||
ai_model_id=model,
|
||||
frequency_penalty=frequency_penalty,
|
||||
|
||||
@@ -10,6 +10,7 @@ from pydantic import BaseModel, Field
|
||||
from ._logging import get_logger
|
||||
from ._mcp import MCPTool
|
||||
from ._memory import AggregateContextProvider, ContextProvider
|
||||
from ._middleware import Middleware
|
||||
from ._pydantic import AFBaseModel
|
||||
from ._threads import ChatMessageStore
|
||||
from ._tools import ToolProtocol
|
||||
@@ -344,7 +345,14 @@ class BaseChatClient(AFBaseModel, ABC):
|
||||
)
|
||||
prepped_messages = self.prepare_messages(messages)
|
||||
self._prepare_tool_choice(chat_options=chat_options)
|
||||
return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **kwargs)
|
||||
|
||||
# Remove middleware pipeline from kwargs as it's only used by function invocation wrappers
|
||||
if "_function_middleware_pipeline" in kwargs:
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "_function_middleware_pipeline"}
|
||||
else:
|
||||
filtered_kwargs = kwargs
|
||||
|
||||
return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **filtered_kwargs)
|
||||
|
||||
async def get_streaming_response(
|
||||
self,
|
||||
@@ -424,8 +432,15 @@ class BaseChatClient(AFBaseModel, ABC):
|
||||
)
|
||||
prepped_messages = self.prepare_messages(messages)
|
||||
self._prepare_tool_choice(chat_options=chat_options)
|
||||
|
||||
# Remove middleware pipeline from kwargs as it's only used by function invocation wrappers
|
||||
if "_function_middleware_pipeline" in kwargs:
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "_function_middleware_pipeline"}
|
||||
else:
|
||||
filtered_kwargs = kwargs
|
||||
|
||||
async for update in self._inner_get_streaming_response(
|
||||
messages=prepped_messages, chat_options=chat_options, **kwargs
|
||||
messages=prepped_messages, chat_options=chat_options, **filtered_kwargs
|
||||
):
|
||||
yield update
|
||||
|
||||
@@ -465,6 +480,7 @@ class BaseChatClient(AFBaseModel, ABC):
|
||||
| None = None,
|
||||
chat_message_store_factory: Callable[[], ChatMessageStore] | None = None,
|
||||
context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None,
|
||||
middleware: Middleware | list[Middleware] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> "ChatAgent":
|
||||
"""Create an agent with the given name and instructions.
|
||||
@@ -476,6 +492,7 @@ class BaseChatClient(AFBaseModel, ABC):
|
||||
chat_message_store_factory: Factory function to create an instance of ChatMessageStore. If not provided,
|
||||
the default in-memory store will be used.
|
||||
context_providers: Context providers to include during agent invocation.
|
||||
middleware: List of middleware to intercept agent and function invocations.
|
||||
**kwargs: Additional keyword arguments to pass to the agent.
|
||||
See ChatAgent for all the available options.
|
||||
|
||||
@@ -491,6 +508,7 @@ class BaseChatClient(AFBaseModel, ABC):
|
||||
tools=tools,
|
||||
chat_message_store_factory=chat_message_store_factory,
|
||||
context_providers=context_providers,
|
||||
middleware=middleware,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,632 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar
|
||||
|
||||
from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._agents import AgentProtocol
|
||||
from ._tools import AIFunction
|
||||
|
||||
TAgent = TypeVar("TAgent", bound="AgentProtocol")
|
||||
|
||||
__all__ = [
|
||||
"AgentMiddleware",
|
||||
"AgentRunContext",
|
||||
"FunctionInvocationContext",
|
||||
"FunctionMiddleware",
|
||||
"Middleware",
|
||||
"use_agent_middleware",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentRunContext:
|
||||
"""Context object for agent middleware invocations.
|
||||
|
||||
Attributes:
|
||||
agent: The agent being invoked.
|
||||
messages: The messages being sent to the agent.
|
||||
is_streaming: Whether this is a streaming invocation.
|
||||
metadata: Metadata dictionary for sharing data between agent middleware.
|
||||
result: Agent execution result. Can be observed after calling next()
|
||||
to see the actual execution result or can be set to override the execution result.
|
||||
For non-streaming: should be AgentRunResponse
|
||||
For streaming: should be AsyncIterable[AgentRunResponseUpdate]
|
||||
"""
|
||||
|
||||
agent: "AgentProtocol"
|
||||
messages: list[ChatMessage]
|
||||
is_streaming: bool = False
|
||||
metadata: dict[str, Any] = field(default_factory=lambda: {})
|
||||
result: AgentRunResponse | AsyncIterable[AgentRunResponseUpdate] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionInvocationContext:
|
||||
"""Context object for function middleware invocations.
|
||||
|
||||
Attributes:
|
||||
function: The function being invoked.
|
||||
arguments: The validated arguments for the function.
|
||||
metadata: Metadata dictionary for sharing data between function middleware.
|
||||
result: Function execution result. Can be observed after calling next()
|
||||
to see the actual execution result or can be set to override the execution result.
|
||||
"""
|
||||
|
||||
function: "AIFunction[Any, Any]"
|
||||
arguments: "BaseModel"
|
||||
metadata: dict[str, Any] = field(default_factory=lambda: {})
|
||||
result: Any = None
|
||||
|
||||
|
||||
class AgentMiddleware(ABC):
|
||||
"""Abstract base class for agent middleware that can intercept agent invocations."""
|
||||
|
||||
@abstractmethod
|
||||
async def process(
|
||||
self,
|
||||
context: AgentRunContext,
|
||||
next: Callable[[AgentRunContext], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Process an agent invocation.
|
||||
|
||||
Args:
|
||||
context: Agent invocation context containing agent, messages, and metadata.
|
||||
Use context.is_streaming to determine if this is a streaming call.
|
||||
Middleware can set context.result to override execution, or observe
|
||||
the actual execution result after calling next().
|
||||
For non-streaming: AgentRunResponse
|
||||
For streaming: AsyncIterable[AgentRunResponseUpdate]
|
||||
next: Function to call the next middleware or final agent execution.
|
||||
Does not return anything - all data flows through the context.
|
||||
|
||||
Note:
|
||||
Middleware should not return anything. All data manipulation should happen
|
||||
within the context object. Set context.result to override execution,
|
||||
or observe context.result after calling next() for actual results.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class FunctionMiddleware(ABC):
|
||||
"""Abstract base class for function middleware that can intercept function invocations."""
|
||||
|
||||
@abstractmethod
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Process a function invocation.
|
||||
|
||||
Args:
|
||||
context: Function invocation context containing function, arguments, and metadata.
|
||||
Middleware can set context.result to override execution, or observe
|
||||
the actual execution result after calling next().
|
||||
next: Function to call the next middleware or final function execution.
|
||||
Does not return anything - all data flows through the context.
|
||||
|
||||
Note:
|
||||
Middleware should not return anything. All data manipulation should happen
|
||||
within the context object. Set context.result to override execution,
|
||||
or observe context.result after calling next() for actual results.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
# Pure function type definitions for convenience
|
||||
AgentMiddlewareCallable = Callable[[AgentRunContext, Callable[[AgentRunContext], Awaitable[None]]], Awaitable[None]]
|
||||
|
||||
FunctionMiddlewareCallable = Callable[
|
||||
[FunctionInvocationContext, Callable[[FunctionInvocationContext], Awaitable[None]]], Awaitable[None]
|
||||
]
|
||||
|
||||
# Type alias for all middleware types
|
||||
Middleware: TypeAlias = AgentMiddleware | AgentMiddlewareCallable | FunctionMiddleware | FunctionMiddlewareCallable
|
||||
|
||||
|
||||
class AgentMiddlewareWrapper(AgentMiddleware):
|
||||
"""Wrapper to convert pure functions into AgentMiddleware protocol objects."""
|
||||
|
||||
def __init__(self, func: AgentMiddlewareCallable):
|
||||
self.func = func
|
||||
|
||||
async def process(
|
||||
self,
|
||||
context: AgentRunContext,
|
||||
next: Callable[[AgentRunContext], Awaitable[None]],
|
||||
) -> None:
|
||||
await self.func(context, next)
|
||||
|
||||
|
||||
class FunctionMiddlewareWrapper(FunctionMiddleware):
|
||||
"""Wrapper to convert pure functions into FunctionMiddleware protocol objects."""
|
||||
|
||||
def __init__(self, func: FunctionMiddlewareCallable):
|
||||
self.func = func
|
||||
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
) -> None:
|
||||
await self.func(context, next)
|
||||
|
||||
|
||||
class BaseMiddlewarePipeline(ABC):
|
||||
"""Base class for middleware pipeline execution."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the base middleware pipeline."""
|
||||
self._middlewares: list[Any] = []
|
||||
|
||||
@abstractmethod
|
||||
def _register_middleware(self, middleware: Any) -> None:
|
||||
"""Register a middleware item. Must be implemented by subclasses."""
|
||||
...
|
||||
|
||||
@property
|
||||
def has_middlewares(self) -> bool:
|
||||
"""Check if there are any middlewares registered."""
|
||||
return bool(self._middlewares)
|
||||
|
||||
def _create_handler_chain(
|
||||
self,
|
||||
final_handler: Callable[[Any], Awaitable[Any]],
|
||||
result_container: dict[str, Any],
|
||||
result_key: str = "result",
|
||||
) -> Callable[[Any], Awaitable[None]]:
|
||||
"""Create a chain of middleware handlers.
|
||||
|
||||
Args:
|
||||
final_handler: The final handler to execute
|
||||
result_container: Container to store the result
|
||||
result_key: Key to use in the result container
|
||||
|
||||
Returns:
|
||||
The first handler in the chain
|
||||
"""
|
||||
|
||||
def create_next_handler(index: int) -> Callable[[Any], Awaitable[None]]:
|
||||
if index >= len(self._middlewares):
|
||||
|
||||
async def final_wrapper(c: Any) -> None:
|
||||
# Execute actual handler and populate context for observability
|
||||
result = await final_handler(c)
|
||||
result_container[result_key] = result
|
||||
c.result = result
|
||||
|
||||
return final_wrapper
|
||||
|
||||
middleware = self._middlewares[index]
|
||||
next_handler = create_next_handler(index + 1)
|
||||
|
||||
async def current_handler(c: Any) -> None:
|
||||
await middleware.process(c, next_handler)
|
||||
|
||||
return current_handler
|
||||
|
||||
return create_next_handler(0)
|
||||
|
||||
|
||||
class AgentMiddlewarePipeline(BaseMiddlewarePipeline):
|
||||
"""Executes agent middleware in a chain."""
|
||||
|
||||
def __init__(self, middlewares: list[AgentMiddleware | AgentMiddlewareCallable] | None = None):
|
||||
"""Initialize the agent middleware pipeline.
|
||||
|
||||
Args:
|
||||
middlewares: List of agent middleware to include in the pipeline.
|
||||
"""
|
||||
super().__init__()
|
||||
self._middlewares: list[AgentMiddleware] = []
|
||||
|
||||
if middlewares:
|
||||
for middleware in middlewares:
|
||||
self._register_middleware(middleware)
|
||||
|
||||
def _register_middleware(self, middleware: AgentMiddleware | AgentMiddlewareCallable) -> None:
|
||||
"""Register an agent middleware item."""
|
||||
if isinstance(middleware, AgentMiddleware):
|
||||
self._middlewares.append(middleware)
|
||||
elif callable(middleware):
|
||||
self._middlewares.append(AgentMiddlewareWrapper(middleware))
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
agent: "AgentProtocol",
|
||||
messages: list[ChatMessage],
|
||||
context: AgentRunContext,
|
||||
final_handler: Callable[[AgentRunContext], Awaitable[AgentRunResponse]],
|
||||
) -> AgentRunResponse | None:
|
||||
"""Execute the agent middleware pipeline for non-streaming.
|
||||
|
||||
Args:
|
||||
agent: The agent being invoked.
|
||||
messages: The messages to send to the agent.
|
||||
context: The agent invocation context.
|
||||
final_handler: The final handler that performs the actual agent execution.
|
||||
|
||||
Returns:
|
||||
The agent response after processing through all middleware.
|
||||
"""
|
||||
# Update context with agent and messages
|
||||
context.agent = agent
|
||||
context.messages = messages
|
||||
context.is_streaming = False
|
||||
|
||||
if not self._middlewares:
|
||||
return await final_handler(context)
|
||||
|
||||
# Store the final result
|
||||
result_container: dict[str, AgentRunResponse | None] = {"response": None}
|
||||
|
||||
def create_next_handler(index: int) -> Callable[[AgentRunContext], Awaitable[None]]:
|
||||
if index >= len(self._middlewares):
|
||||
|
||||
async def final_wrapper(c: AgentRunContext) -> None:
|
||||
# Execute actual handler and populate context for observability
|
||||
result = await final_handler(c)
|
||||
result_container["result"] = result
|
||||
c.result = result
|
||||
|
||||
return final_wrapper
|
||||
|
||||
middleware = self._middlewares[index]
|
||||
next_handler = create_next_handler(index + 1)
|
||||
|
||||
async def current_handler(c: AgentRunContext) -> None:
|
||||
await middleware.process(c, next_handler)
|
||||
# After middleware execution, check if response was overridden
|
||||
if c.result is not None and isinstance(c.result, AgentRunResponse):
|
||||
result_container["result"] = c.result
|
||||
|
||||
return current_handler
|
||||
|
||||
first_handler = create_next_handler(0)
|
||||
await first_handler(context)
|
||||
|
||||
# Return the result from result container or overridden result
|
||||
if context.result is not None and isinstance(context.result, AgentRunResponse):
|
||||
return context.result
|
||||
|
||||
# If no result was set (next() not called), return empty AgentRunResponse
|
||||
response = result_container.get("result")
|
||||
if response is None:
|
||||
return AgentRunResponse()
|
||||
return response
|
||||
|
||||
async def execute_stream(
|
||||
self,
|
||||
agent: "AgentProtocol",
|
||||
messages: list[ChatMessage],
|
||||
context: AgentRunContext,
|
||||
final_handler: Callable[[AgentRunContext], AsyncIterable[AgentRunResponseUpdate]],
|
||||
) -> AsyncIterable[AgentRunResponseUpdate]:
|
||||
"""Execute the agent middleware pipeline for streaming.
|
||||
|
||||
Args:
|
||||
agent: The agent being invoked.
|
||||
messages: The messages to send to the agent.
|
||||
context: The agent invocation context.
|
||||
final_handler: The final handler that performs the actual agent streaming execution.
|
||||
|
||||
Yields:
|
||||
Agent response updates after processing through all middleware.
|
||||
"""
|
||||
# Update context with agent and messages
|
||||
context.agent = agent
|
||||
context.messages = messages
|
||||
context.is_streaming = True
|
||||
|
||||
if not self._middlewares:
|
||||
async for update in final_handler(context):
|
||||
yield update
|
||||
return
|
||||
|
||||
# Store the final result
|
||||
result_container: dict[str, AsyncIterable[AgentRunResponseUpdate] | None] = {"result_stream": None}
|
||||
|
||||
def create_next_handler(index: int) -> Callable[[AgentRunContext], Awaitable[None]]:
|
||||
if index >= len(self._middlewares):
|
||||
|
||||
async def final_wrapper(c: AgentRunContext) -> None: # noqa: RUF029
|
||||
# Execute actual handler and populate context for observability
|
||||
result = final_handler(c)
|
||||
result_container["result_stream"] = result
|
||||
c.result = result
|
||||
|
||||
return final_wrapper
|
||||
|
||||
middleware = self._middlewares[index]
|
||||
next_handler = create_next_handler(index + 1)
|
||||
|
||||
async def current_handler(c: AgentRunContext) -> None:
|
||||
await middleware.process(c, next_handler)
|
||||
|
||||
return current_handler
|
||||
|
||||
first_handler = create_next_handler(0)
|
||||
await first_handler(context)
|
||||
|
||||
# Yield from the result stream in result container or overridden result
|
||||
if context.result is not None and hasattr(context.result, "__aiter__"):
|
||||
async for update in context.result: # type: ignore
|
||||
yield update
|
||||
return
|
||||
|
||||
result_stream = result_container["result_stream"]
|
||||
if result_stream is None:
|
||||
# If no result stream was set (next() not called), yield nothing
|
||||
return
|
||||
|
||||
async for update in result_stream:
|
||||
yield update
|
||||
|
||||
|
||||
class FunctionMiddlewarePipeline(BaseMiddlewarePipeline):
|
||||
"""Executes function middleware in a chain."""
|
||||
|
||||
def __init__(self, middlewares: list[FunctionMiddleware | FunctionMiddlewareCallable] | None = None):
|
||||
"""Initialize the function middleware pipeline.
|
||||
|
||||
Args:
|
||||
middlewares: List of function middleware to include in the pipeline.
|
||||
"""
|
||||
super().__init__()
|
||||
self._middlewares: list[FunctionMiddleware] = []
|
||||
|
||||
if middlewares:
|
||||
for middleware in middlewares:
|
||||
self._register_middleware(middleware)
|
||||
|
||||
def _register_middleware(self, middleware: FunctionMiddleware | FunctionMiddlewareCallable) -> None:
|
||||
"""Register a function middleware item."""
|
||||
# Check if it's a class instance inheriting from FunctionMiddleware
|
||||
if isinstance(middleware, FunctionMiddleware):
|
||||
self._middlewares.append(middleware)
|
||||
elif callable(middleware):
|
||||
self._middlewares.append(FunctionMiddlewareWrapper(middleware))
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
function: Any,
|
||||
arguments: "BaseModel",
|
||||
context: FunctionInvocationContext,
|
||||
final_handler: Callable[[FunctionInvocationContext], Awaitable[Any]],
|
||||
) -> Any:
|
||||
"""Execute the function middleware pipeline.
|
||||
|
||||
Args:
|
||||
function: The function being invoked.
|
||||
arguments: The validated arguments for the function.
|
||||
context: The function invocation context.
|
||||
final_handler: The final handler that performs the actual function execution.
|
||||
|
||||
Returns:
|
||||
The function result after processing through all middleware.
|
||||
"""
|
||||
# Update context with function and arguments
|
||||
context.function = function
|
||||
context.arguments = arguments
|
||||
|
||||
if not self._middlewares:
|
||||
return await final_handler(context)
|
||||
|
||||
# Store the final result
|
||||
result_container: dict[str, Any] = {"result": None}
|
||||
|
||||
# Custom final handler that handles pre-existing results
|
||||
async def function_final_handler(c: FunctionInvocationContext) -> Any:
|
||||
# If result was set before calling next(), skip execution
|
||||
if c.result is not None:
|
||||
return c.result
|
||||
# Execute actual handler and populate context for observability
|
||||
return await final_handler(c)
|
||||
|
||||
first_handler = self._create_handler_chain(function_final_handler, result_container, "result")
|
||||
await first_handler(context)
|
||||
|
||||
# Return the result from result container or overridden result
|
||||
if context.result is not None:
|
||||
return context.result
|
||||
return result_container["result"]
|
||||
|
||||
|
||||
# Decorator for adding middleware support to agent classes
|
||||
def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]:
|
||||
"""Class decorator that adds middleware support to an agent class.
|
||||
|
||||
This decorator adds middleware functionality to any agent class.
|
||||
It wraps the run() and run_stream() methods to provide middleware execution.
|
||||
|
||||
Args:
|
||||
agent_class: The agent class to add middleware support to.
|
||||
|
||||
Returns:
|
||||
The modified agent class with middleware support.
|
||||
"""
|
||||
import inspect
|
||||
|
||||
# Store original methods
|
||||
original_run = agent_class.run # type: ignore[attr-defined]
|
||||
original_run_stream = agent_class.run_stream # type: ignore[attr-defined]
|
||||
|
||||
def _initialize_middleware_pipelines(self: Any, middlewares: Middleware | list[Middleware] | None) -> None:
|
||||
"""Initialize agent and function middleware pipelines from the provided middleware list."""
|
||||
if not middlewares:
|
||||
return
|
||||
|
||||
middleware_list: list[Middleware] = middlewares if isinstance(middlewares, list) else [middlewares] # type: ignore
|
||||
|
||||
# Separate agent and function middleware using isinstance checks
|
||||
agent_middlewares: list[AgentMiddleware | AgentMiddlewareCallable] = []
|
||||
function_middlewares: list[FunctionMiddleware | FunctionMiddlewareCallable] = []
|
||||
|
||||
for middleware in middleware_list:
|
||||
if isinstance(middleware, AgentMiddleware):
|
||||
agent_middlewares.append(middleware)
|
||||
elif isinstance(middleware, FunctionMiddleware):
|
||||
function_middlewares.append(middleware)
|
||||
elif callable(middleware): # type: ignore[arg-type]
|
||||
# Check function signature to determine type
|
||||
try:
|
||||
sig = inspect.signature(middleware)
|
||||
params = list(sig.parameters.values())
|
||||
if len(params) >= 1:
|
||||
first_param = params[0]
|
||||
# Check if first parameter is AgentRunContext or FunctionInvocationContext
|
||||
if (
|
||||
hasattr(first_param.annotation, "__name__")
|
||||
and first_param.annotation.__name__ == "AgentRunContext"
|
||||
):
|
||||
agent_middlewares.append(middleware) # type: ignore
|
||||
elif (
|
||||
hasattr(first_param.annotation, "__name__")
|
||||
and first_param.annotation.__name__ == "FunctionInvocationContext"
|
||||
):
|
||||
function_middlewares.append(middleware) # type: ignore
|
||||
else:
|
||||
# Default to agent middleware if uncertain
|
||||
agent_middlewares.append(middleware) # type: ignore
|
||||
else:
|
||||
agent_middlewares.append(middleware) # type: ignore
|
||||
except Exception:
|
||||
# If signature inspection fails, assume it's an agent middleware
|
||||
agent_middlewares.append(middleware) # type: ignore
|
||||
else:
|
||||
# Fallback
|
||||
agent_middlewares.append(middleware) # type: ignore
|
||||
|
||||
self._agent_middleware_pipeline = AgentMiddlewarePipeline(agent_middlewares)
|
||||
self._function_middleware_pipeline = FunctionMiddlewarePipeline(function_middlewares)
|
||||
|
||||
async def middleware_enabled_run(
|
||||
self: Any,
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
|
||||
*,
|
||||
thread: Any = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentRunResponse:
|
||||
"""Middleware-enabled run method."""
|
||||
# Initialize middleware pipelines if not already done
|
||||
if (
|
||||
hasattr(self, "middleware")
|
||||
and self.middleware
|
||||
and not (
|
||||
hasattr(self, "_agent_middleware_pipeline")
|
||||
and hasattr(self, "_function_middleware_pipeline")
|
||||
and (
|
||||
self._agent_middleware_pipeline.has_middlewares
|
||||
or self._function_middleware_pipeline.has_middlewares
|
||||
)
|
||||
)
|
||||
):
|
||||
_initialize_middleware_pipelines(self, self.middleware)
|
||||
|
||||
# Ensure pipelines exist even if empty
|
||||
if not hasattr(self, "_agent_middleware_pipeline"):
|
||||
self._agent_middleware_pipeline = AgentMiddlewarePipeline()
|
||||
if not hasattr(self, "_function_middleware_pipeline"):
|
||||
self._function_middleware_pipeline = FunctionMiddlewarePipeline()
|
||||
|
||||
# Add function middleware pipeline to kwargs if available
|
||||
if self._function_middleware_pipeline.has_middlewares:
|
||||
kwargs["_function_middleware_pipeline"] = self._function_middleware_pipeline
|
||||
|
||||
normalized_messages = self._normalize_messages(messages)
|
||||
|
||||
# Execute with middleware if available
|
||||
if self._agent_middleware_pipeline.has_middlewares:
|
||||
context = AgentRunContext(
|
||||
agent=self, # type: ignore[arg-type]
|
||||
messages=normalized_messages,
|
||||
is_streaming=False,
|
||||
)
|
||||
|
||||
async def _execute_handler(ctx: AgentRunContext) -> AgentRunResponse:
|
||||
return await original_run(self, ctx.messages, thread=thread, **kwargs) # type: ignore
|
||||
|
||||
result = await self._agent_middleware_pipeline.execute(
|
||||
self, # type: ignore[arg-type]
|
||||
normalized_messages,
|
||||
context,
|
||||
_execute_handler,
|
||||
)
|
||||
|
||||
return result if result else AgentRunResponse()
|
||||
|
||||
# No middleware, execute directly
|
||||
return await original_run(self, normalized_messages, thread=thread, **kwargs) # type: ignore[return-value]
|
||||
|
||||
def middleware_enabled_run_stream(
|
||||
self: Any,
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
|
||||
*,
|
||||
thread: Any = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[AgentRunResponseUpdate]:
|
||||
"""Middleware-enabled run_stream method."""
|
||||
# Initialize middleware pipelines if not already done
|
||||
if (
|
||||
hasattr(self, "middleware")
|
||||
and self.middleware
|
||||
and not (
|
||||
hasattr(self, "_agent_middleware_pipeline")
|
||||
and hasattr(self, "_function_middleware_pipeline")
|
||||
and (
|
||||
self._agent_middleware_pipeline.has_middlewares
|
||||
or self._function_middleware_pipeline.has_middlewares
|
||||
)
|
||||
)
|
||||
):
|
||||
_initialize_middleware_pipelines(self, self.middleware)
|
||||
|
||||
# Ensure pipelines exist even if empty
|
||||
if not hasattr(self, "_agent_middleware_pipeline"):
|
||||
self._agent_middleware_pipeline = AgentMiddlewarePipeline()
|
||||
if not hasattr(self, "_function_middleware_pipeline"):
|
||||
self._function_middleware_pipeline = FunctionMiddlewarePipeline()
|
||||
|
||||
# Add function middleware pipeline to kwargs if available
|
||||
if self._function_middleware_pipeline.has_middlewares:
|
||||
kwargs["_function_middleware_pipeline"] = self._function_middleware_pipeline
|
||||
|
||||
normalized_messages = self._normalize_messages(messages)
|
||||
|
||||
# Execute with middleware if available
|
||||
if self._agent_middleware_pipeline.has_middlewares:
|
||||
context = AgentRunContext(
|
||||
agent=self, # type: ignore[arg-type]
|
||||
messages=normalized_messages,
|
||||
is_streaming=True,
|
||||
)
|
||||
|
||||
async def _execute_stream_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]:
|
||||
async for update in original_run_stream(self, ctx.messages, thread=thread, **kwargs): # type: ignore[misc]
|
||||
yield update
|
||||
|
||||
async def _stream_generator() -> AsyncIterable[AgentRunResponseUpdate]:
|
||||
async for update in self._agent_middleware_pipeline.execute_stream(
|
||||
self, # type: ignore[arg-type]
|
||||
normalized_messages,
|
||||
context,
|
||||
_execute_stream_handler,
|
||||
):
|
||||
yield update
|
||||
|
||||
return _stream_generator()
|
||||
|
||||
# No middleware, execute directly
|
||||
return original_run_stream(self, normalized_messages, thread=thread, **kwargs) # type: ignore
|
||||
|
||||
agent_class.run = middleware_enabled_run # type: ignore
|
||||
agent_class.run_stream = middleware_enabled_run_stream # type: ignore
|
||||
|
||||
return agent_class
|
||||
@@ -576,6 +576,7 @@ async def _auto_invoke_function(
|
||||
tool_map: dict[str, AIFunction[BaseModel, Any]],
|
||||
sequence_index: int | None = None,
|
||||
request_index: int | None = None,
|
||||
middleware_pipeline: Any = None, # Optional MiddlewarePipeline
|
||||
) -> "Contents":
|
||||
"""Invoke a function call requested by the agent, applying filters that are defined in the agent."""
|
||||
from ._types import FunctionResultContent
|
||||
@@ -590,14 +591,43 @@ async def _auto_invoke_function(
|
||||
merged_args: dict[str, Any] = (custom_args or {}) | parsed_args
|
||||
args = tool.input_model.model_validate(merged_args)
|
||||
exception = None
|
||||
try:
|
||||
function_result = await tool.invoke(
|
||||
|
||||
# Execute through middleware pipeline if available
|
||||
if middleware_pipeline and hasattr(middleware_pipeline, "has_middlewares") and middleware_pipeline.has_middlewares:
|
||||
from ._middleware import FunctionInvocationContext
|
||||
|
||||
middleware_context = FunctionInvocationContext(
|
||||
function=tool,
|
||||
arguments=args,
|
||||
tool_call_id=function_call_content.call_id,
|
||||
) # type: ignore[arg-type]
|
||||
except Exception as ex:
|
||||
exception = ex
|
||||
function_result = None
|
||||
)
|
||||
|
||||
async def final_function_handler(context_obj: Any) -> Any:
|
||||
return await tool.invoke(
|
||||
arguments=context_obj.arguments,
|
||||
tool_call_id=function_call_content.call_id,
|
||||
)
|
||||
|
||||
try:
|
||||
function_result = await middleware_pipeline.execute(
|
||||
function=tool,
|
||||
arguments=args,
|
||||
context=middleware_context,
|
||||
final_handler=final_function_handler,
|
||||
)
|
||||
except Exception as ex:
|
||||
exception = ex
|
||||
function_result = None
|
||||
else:
|
||||
# No middleware - execute directly
|
||||
try:
|
||||
function_result = await tool.invoke(
|
||||
arguments=args,
|
||||
tool_call_id=function_call_content.call_id,
|
||||
) # type: ignore[arg-type]
|
||||
except Exception as ex:
|
||||
exception = ex
|
||||
function_result = None
|
||||
|
||||
return FunctionResultContent(
|
||||
call_id=function_call_content.call_id,
|
||||
exception=exception,
|
||||
@@ -631,6 +661,7 @@ async def execute_function_calls(
|
||||
| Callable[..., Any] \
|
||||
| MutableMapping[str, Any] \
|
||||
| list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]",
|
||||
middleware_pipeline: Any = None, # Optional MiddlewarePipeline to avoid circular imports
|
||||
) -> list["Contents"]:
|
||||
tool_map = _get_tool_map(tools)
|
||||
# Run all function calls concurrently
|
||||
@@ -641,6 +672,7 @@ async def execute_function_calls(
|
||||
tool_map=tool_map,
|
||||
sequence_index=seq_idx,
|
||||
request_index=attempt_idx,
|
||||
middleware_pipeline=middleware_pipeline,
|
||||
)
|
||||
for seq_idx, function_call in enumerate(function_calls)
|
||||
])
|
||||
@@ -706,11 +738,14 @@ def _handle_function_calls_response(
|
||||
if not tools and (chat_options := kwargs.get("chat_options")) and isinstance(chat_options, ChatOptions):
|
||||
tools = chat_options.tools
|
||||
if function_calls and tools:
|
||||
# Extract function middleware pipeline from kwargs if available
|
||||
middleware_pipeline = kwargs.get("_function_middleware_pipeline")
|
||||
function_results = await execute_function_calls(
|
||||
custom_args=kwargs,
|
||||
attempt_idx=attempt_idx,
|
||||
function_calls=function_calls,
|
||||
tools=tools, # type: ignore
|
||||
middleware_pipeline=middleware_pipeline,
|
||||
)
|
||||
# add a single ChatMessage to the response with the results
|
||||
result_message = ChatMessage(role="tool", contents=function_results) # type: ignore[call-overload]
|
||||
@@ -815,11 +850,14 @@ def _handle_function_calls_streaming_response(
|
||||
tools = chat_options.tools
|
||||
|
||||
if function_calls and tools:
|
||||
# Extract function middleware pipeline from kwargs if available
|
||||
middleware_pipeline = kwargs.get("_function_middleware_pipeline")
|
||||
function_results = await execute_function_calls(
|
||||
custom_args=kwargs,
|
||||
attempt_idx=attempt_idx,
|
||||
function_calls=function_calls,
|
||||
tools=tools, # type: ignore[reportArgumentType]
|
||||
middleware_pipeline=middleware_pipeline,
|
||||
)
|
||||
function_result_msg = ChatMessage(role="tool", contents=function_results)
|
||||
yield ChatResponseUpdate(contents=function_results, role="tool")
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
|
||||
from typing import Generic, Protocol, TypeVar, runtime_checkable
|
||||
|
||||
TInput = TypeVar("TInput")
|
||||
TResponse = TypeVar("TResponse")
|
||||
|
||||
__all__ = ["InputGuardrail", "OutputGuardrail"]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class InputGuardrail(Protocol, Generic[TInput]):
|
||||
"""A protocol for input guardrails that can validate and transform input messages."""
|
||||
|
||||
def __call__(self, message: TInput) -> TInput:
|
||||
"""Validate and possibly transform the input message."""
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class OutputGuardrail(Protocol, Generic[TResponse]):
|
||||
"""A protocol for output guardrails that can validate and transform output messages."""
|
||||
|
||||
def __call__(self, message: TResponse) -> TResponse:
|
||||
"""Validate and possibly transform the output message."""
|
||||
...
|
||||
@@ -1026,8 +1026,12 @@ def _capture_messages(
|
||||
|
||||
prepped = prepare_messages(messages)
|
||||
for index, message in enumerate(prepped):
|
||||
try:
|
||||
message_data = message.model_dump(exclude_none=True)
|
||||
except Exception:
|
||||
message_data = {"role": message.role.value, "contents": message.contents}
|
||||
logger.info(
|
||||
message.model_dump_json(exclude_none=True),
|
||||
message_data,
|
||||
extra={
|
||||
OtelAttr.EVENT_NAME: OtelAttr.CHOICE if output else ROLE_EVENT_MAP.get(message.role.value),
|
||||
OtelAttr.PROVIDER_NAME: provider_name,
|
||||
|
||||
@@ -72,7 +72,15 @@ environments = [
|
||||
fallback-version = "0.0.0"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
testpaths = [
|
||||
'tests',
|
||||
'packages/main/tests',
|
||||
'packages/azure/tests',
|
||||
'packages/foundry/tests',
|
||||
'packages/copilotstudio/tests',
|
||||
'packages/mem0/tests',
|
||||
'packages/runtime/tests'
|
||||
]
|
||||
addopts = "-ra -q -r fEX"
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
|
||||
@@ -105,6 +105,9 @@ class MockChatClient:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.additional_properties: dict[str, Any] = {}
|
||||
self.call_count: int = 0
|
||||
self.responses: list[ChatResponse] = []
|
||||
self.streaming_responses: list[list[ChatResponseUpdate]] = []
|
||||
|
||||
async def get_response(
|
||||
self,
|
||||
@@ -112,6 +115,9 @@ class MockChatClient:
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
logger.debug(f"Running custom chat client, with: {messages=}, {kwargs=}")
|
||||
self.call_count += 1
|
||||
if self.responses:
|
||||
return self.responses.pop(0)
|
||||
return ChatResponse(messages=ChatMessage(role="assistant", text="test response"))
|
||||
|
||||
async def get_streaming_response(
|
||||
@@ -120,8 +126,13 @@ class MockChatClient:
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[ChatResponseUpdate]:
|
||||
logger.debug(f"Running custom chat client stream, with: {messages=}, {kwargs=}")
|
||||
yield ChatResponseUpdate(text=TextContent(text="test streaming response "), role="assistant")
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="another update")], role="assistant")
|
||||
self.call_count += 1
|
||||
if self.streaming_responses:
|
||||
for update in self.streaming_responses.pop(0):
|
||||
yield update
|
||||
else:
|
||||
yield ChatResponseUpdate(text=TextContent(text="test streaming response "), role="assistant")
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="another update")], role="assistant")
|
||||
|
||||
|
||||
class MockBaseChatClient(BaseChatClient):
|
||||
|
||||
@@ -0,0 +1,939 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from agent_framework import (
|
||||
AgentProtocol,
|
||||
AgentRunResponse,
|
||||
AgentRunResponseUpdate,
|
||||
ChatMessage,
|
||||
Role,
|
||||
TextContent,
|
||||
)
|
||||
from agent_framework._middleware import (
|
||||
AgentMiddleware,
|
||||
AgentMiddlewarePipeline,
|
||||
AgentRunContext,
|
||||
FunctionInvocationContext,
|
||||
FunctionMiddleware,
|
||||
FunctionMiddlewarePipeline,
|
||||
)
|
||||
from agent_framework._tools import AIFunction
|
||||
|
||||
|
||||
class TestAgentRunContext:
|
||||
"""Test cases for AgentRunContext."""
|
||||
|
||||
def test_init_with_defaults(self, mock_agent: AgentProtocol) -> None:
|
||||
"""Test AgentRunContext initialization with default values."""
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
context = AgentRunContext(agent=mock_agent, messages=messages)
|
||||
|
||||
assert context.agent is mock_agent
|
||||
assert context.messages == messages
|
||||
assert context.is_streaming is False
|
||||
assert context.metadata == {}
|
||||
|
||||
def test_init_with_custom_values(self, mock_agent: AgentProtocol) -> None:
|
||||
"""Test AgentRunContext initialization with custom values."""
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
metadata = {"key": "value"}
|
||||
context = AgentRunContext(agent=mock_agent, messages=messages, is_streaming=True, metadata=metadata)
|
||||
|
||||
assert context.agent is mock_agent
|
||||
assert context.messages == messages
|
||||
assert context.is_streaming is True
|
||||
assert context.metadata == metadata
|
||||
|
||||
|
||||
class TestFunctionInvocationContext:
|
||||
"""Test cases for FunctionInvocationContext."""
|
||||
|
||||
def test_init_with_defaults(self, mock_function: AIFunction[Any, Any]) -> None:
|
||||
"""Test FunctionInvocationContext initialization with default values."""
|
||||
arguments = FunctionTestArgs(name="test")
|
||||
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
|
||||
|
||||
assert context.function is mock_function
|
||||
assert context.arguments == arguments
|
||||
assert context.metadata == {}
|
||||
|
||||
def test_init_with_custom_metadata(self, mock_function: AIFunction[Any, Any]) -> None:
|
||||
"""Test FunctionInvocationContext initialization with custom metadata."""
|
||||
arguments = FunctionTestArgs(name="test")
|
||||
metadata = {"key": "value"}
|
||||
context = FunctionInvocationContext(function=mock_function, arguments=arguments, metadata=metadata)
|
||||
|
||||
assert context.function is mock_function
|
||||
assert context.arguments == arguments
|
||||
assert context.metadata == metadata
|
||||
|
||||
|
||||
class TestAgentMiddlewarePipeline:
|
||||
"""Test cases for AgentMiddlewarePipeline."""
|
||||
|
||||
def test_init_empty(self) -> None:
|
||||
"""Test AgentMiddlewarePipeline initialization with no middlewares."""
|
||||
pipeline = AgentMiddlewarePipeline()
|
||||
assert not pipeline.has_middlewares
|
||||
|
||||
def test_init_with_class_middleware(self) -> None:
|
||||
"""Test AgentMiddlewarePipeline initialization with class-based middleware."""
|
||||
middleware = TestAgentMiddleware()
|
||||
pipeline = AgentMiddlewarePipeline([middleware])
|
||||
assert pipeline.has_middlewares
|
||||
|
||||
def test_init_with_function_middleware(self) -> None:
|
||||
"""Test AgentMiddlewarePipeline initialization with function-based middleware."""
|
||||
|
||||
async def test_middleware(context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None:
|
||||
await next(context)
|
||||
|
||||
pipeline = AgentMiddlewarePipeline([test_middleware])
|
||||
assert pipeline.has_middlewares
|
||||
|
||||
async def test_execute_no_middleware(self, mock_agent: AgentProtocol) -> None:
|
||||
"""Test pipeline execution with no middleware."""
|
||||
pipeline = AgentMiddlewarePipeline()
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
context = AgentRunContext(agent=mock_agent, messages=messages)
|
||||
|
||||
expected_response = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
|
||||
|
||||
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
|
||||
return expected_response
|
||||
|
||||
result = await pipeline.execute(mock_agent, messages, context, final_handler)
|
||||
assert result == expected_response
|
||||
|
||||
async def test_execute_with_middleware(self, mock_agent: AgentProtocol) -> None:
|
||||
"""Test pipeline execution with middleware."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class OrderTrackingMiddleware(AgentMiddleware):
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append(f"{self.name}_before")
|
||||
await next(context)
|
||||
execution_order.append(f"{self.name}_after")
|
||||
|
||||
middleware = OrderTrackingMiddleware("test")
|
||||
pipeline = AgentMiddlewarePipeline([middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
context = AgentRunContext(agent=mock_agent, messages=messages)
|
||||
|
||||
expected_response = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
|
||||
|
||||
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
|
||||
execution_order.append("handler")
|
||||
return expected_response
|
||||
|
||||
result = await pipeline.execute(mock_agent, messages, context, final_handler)
|
||||
assert result == expected_response
|
||||
assert execution_order == ["test_before", "handler", "test_after"]
|
||||
|
||||
async def test_execute_stream_no_middleware(self, mock_agent: AgentProtocol) -> None:
|
||||
"""Test pipeline streaming execution with no middleware."""
|
||||
pipeline = AgentMiddlewarePipeline()
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
context = AgentRunContext(agent=mock_agent, messages=messages)
|
||||
|
||||
async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]:
|
||||
yield AgentRunResponseUpdate(contents=[TextContent(text="chunk1")])
|
||||
yield AgentRunResponseUpdate(contents=[TextContent(text="chunk2")])
|
||||
|
||||
updates: list[AgentRunResponseUpdate] = []
|
||||
async for update in pipeline.execute_stream(mock_agent, messages, context, final_handler):
|
||||
updates.append(update)
|
||||
|
||||
assert len(updates) == 2
|
||||
assert updates[0].text == "chunk1"
|
||||
assert updates[1].text == "chunk2"
|
||||
|
||||
async def test_execute_stream_with_middleware(self, mock_agent: AgentProtocol) -> None:
|
||||
"""Test pipeline streaming execution with middleware."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class StreamOrderTrackingMiddleware(AgentMiddleware):
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append(f"{self.name}_before")
|
||||
await next(context)
|
||||
execution_order.append(f"{self.name}_after")
|
||||
|
||||
middleware = StreamOrderTrackingMiddleware("test")
|
||||
pipeline = AgentMiddlewarePipeline([middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
context = AgentRunContext(agent=mock_agent, messages=messages)
|
||||
|
||||
async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]:
|
||||
execution_order.append("handler_start")
|
||||
yield AgentRunResponseUpdate(contents=[TextContent(text="chunk1")])
|
||||
yield AgentRunResponseUpdate(contents=[TextContent(text="chunk2")])
|
||||
execution_order.append("handler_end")
|
||||
|
||||
updates: list[AgentRunResponseUpdate] = []
|
||||
async for update in pipeline.execute_stream(mock_agent, messages, context, final_handler):
|
||||
updates.append(update)
|
||||
|
||||
assert len(updates) == 2
|
||||
assert updates[0].text == "chunk1"
|
||||
assert updates[1].text == "chunk2"
|
||||
assert execution_order == ["test_before", "test_after", "handler_start", "handler_end"]
|
||||
|
||||
|
||||
class TestFunctionMiddlewarePipeline:
|
||||
"""Test cases for FunctionMiddlewarePipeline."""
|
||||
|
||||
def test_init_empty(self) -> None:
|
||||
"""Test FunctionMiddlewarePipeline initialization with no middlewares."""
|
||||
pipeline = FunctionMiddlewarePipeline()
|
||||
assert not pipeline.has_middlewares
|
||||
|
||||
def test_init_with_class_middleware(self) -> None:
|
||||
"""Test FunctionMiddlewarePipeline initialization with class-based middleware."""
|
||||
middleware = TestFunctionMiddleware()
|
||||
pipeline = FunctionMiddlewarePipeline([middleware])
|
||||
assert pipeline.has_middlewares
|
||||
|
||||
def test_init_with_function_middleware(self) -> None:
|
||||
"""Test FunctionMiddlewarePipeline initialization with function-based middleware."""
|
||||
|
||||
async def test_middleware(
|
||||
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
) -> None:
|
||||
await next(context)
|
||||
|
||||
pipeline = FunctionMiddlewarePipeline([test_middleware])
|
||||
assert pipeline.has_middlewares
|
||||
|
||||
async def test_execute_no_middleware(self, mock_function: AIFunction[Any, Any]) -> None:
|
||||
"""Test pipeline execution with no middleware."""
|
||||
pipeline = FunctionMiddlewarePipeline()
|
||||
arguments = FunctionTestArgs(name="test")
|
||||
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
|
||||
|
||||
expected_result = "function_result"
|
||||
|
||||
async def final_handler(ctx: FunctionInvocationContext) -> str:
|
||||
return expected_result
|
||||
|
||||
result = await pipeline.execute(mock_function, arguments, context, final_handler)
|
||||
assert result == expected_result
|
||||
|
||||
async def test_execute_with_middleware(self, mock_function: AIFunction[Any, Any]) -> None:
|
||||
"""Test pipeline execution with middleware."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class OrderTrackingFunctionMiddleware(FunctionMiddleware):
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_order.append(f"{self.name}_before")
|
||||
await next(context)
|
||||
execution_order.append(f"{self.name}_after")
|
||||
|
||||
middleware = OrderTrackingFunctionMiddleware("test")
|
||||
pipeline = FunctionMiddlewarePipeline([middleware])
|
||||
arguments = FunctionTestArgs(name="test")
|
||||
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
|
||||
|
||||
expected_result = "function_result"
|
||||
|
||||
async def final_handler(ctx: FunctionInvocationContext) -> str:
|
||||
execution_order.append("handler")
|
||||
return expected_result
|
||||
|
||||
result = await pipeline.execute(mock_function, arguments, context, final_handler)
|
||||
assert result == expected_result
|
||||
assert execution_order == ["test_before", "handler", "test_after"]
|
||||
|
||||
|
||||
class TestClassBasedMiddleware:
|
||||
"""Test cases for class-based middleware implementations."""
|
||||
|
||||
async def test_agent_middleware_execution(self, mock_agent: AgentProtocol) -> None:
|
||||
"""Test class-based agent middleware execution."""
|
||||
metadata_updates: list[str] = []
|
||||
|
||||
class MetadataAgentMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
context.metadata["before"] = True
|
||||
metadata_updates.append("before")
|
||||
await next(context)
|
||||
context.metadata["after"] = True
|
||||
metadata_updates.append("after")
|
||||
|
||||
middleware = MetadataAgentMiddleware()
|
||||
pipeline = AgentMiddlewarePipeline([middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
context = AgentRunContext(agent=mock_agent, messages=messages)
|
||||
|
||||
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
|
||||
metadata_updates.append("handler")
|
||||
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
|
||||
|
||||
result = await pipeline.execute(mock_agent, messages, context, final_handler)
|
||||
|
||||
assert result is not None
|
||||
assert context.metadata["before"] is True
|
||||
assert context.metadata["after"] is True
|
||||
assert metadata_updates == ["before", "handler", "after"]
|
||||
|
||||
async def test_function_middleware_execution(self, mock_function: AIFunction[Any, Any]) -> None:
|
||||
"""Test class-based function middleware execution."""
|
||||
metadata_updates: list[str] = []
|
||||
|
||||
class MetadataFunctionMiddleware(FunctionMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
) -> None:
|
||||
context.metadata["before"] = True
|
||||
metadata_updates.append("before")
|
||||
await next(context)
|
||||
context.metadata["after"] = True
|
||||
metadata_updates.append("after")
|
||||
|
||||
middleware = MetadataFunctionMiddleware()
|
||||
pipeline = FunctionMiddlewarePipeline([middleware])
|
||||
arguments = FunctionTestArgs(name="test")
|
||||
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
|
||||
|
||||
async def final_handler(ctx: FunctionInvocationContext) -> str:
|
||||
metadata_updates.append("handler")
|
||||
return "result"
|
||||
|
||||
result = await pipeline.execute(mock_function, arguments, context, final_handler)
|
||||
|
||||
assert result == "result"
|
||||
assert context.metadata["before"] is True
|
||||
assert context.metadata["after"] is True
|
||||
assert metadata_updates == ["before", "handler", "after"]
|
||||
|
||||
|
||||
class TestFunctionBasedMiddleware:
|
||||
"""Test cases for function-based middleware implementations."""
|
||||
|
||||
async def test_agent_function_middleware(self, mock_agent: AgentProtocol) -> None:
|
||||
"""Test function-based agent middleware."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
async def test_agent_middleware(
|
||||
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("function_before")
|
||||
context.metadata["function_middleware"] = True
|
||||
await next(context)
|
||||
execution_order.append("function_after")
|
||||
|
||||
pipeline = AgentMiddlewarePipeline([test_agent_middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
context = AgentRunContext(agent=mock_agent, messages=messages)
|
||||
|
||||
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
|
||||
execution_order.append("handler")
|
||||
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
|
||||
|
||||
result = await pipeline.execute(mock_agent, messages, context, final_handler)
|
||||
|
||||
assert result is not None
|
||||
assert context.metadata["function_middleware"] is True
|
||||
assert execution_order == ["function_before", "handler", "function_after"]
|
||||
|
||||
async def test_function_function_middleware(self, mock_function: AIFunction[Any, Any]) -> None:
|
||||
"""Test function-based function middleware."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
async def test_function_middleware(
|
||||
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("function_before")
|
||||
context.metadata["function_middleware"] = True
|
||||
await next(context)
|
||||
execution_order.append("function_after")
|
||||
|
||||
pipeline = FunctionMiddlewarePipeline([test_function_middleware])
|
||||
arguments = FunctionTestArgs(name="test")
|
||||
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
|
||||
|
||||
async def final_handler(ctx: FunctionInvocationContext) -> str:
|
||||
execution_order.append("handler")
|
||||
return "result"
|
||||
|
||||
result = await pipeline.execute(mock_function, arguments, context, final_handler)
|
||||
|
||||
assert result == "result"
|
||||
assert context.metadata["function_middleware"] is True
|
||||
assert execution_order == ["function_before", "handler", "function_after"]
|
||||
|
||||
|
||||
class TestMixedMiddleware:
|
||||
"""Test cases for mixed class and function-based middleware."""
|
||||
|
||||
async def test_mixed_agent_middleware(self, mock_agent: AgentProtocol) -> None:
|
||||
"""Test mixed class and function-based agent middleware."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class ClassMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("class_before")
|
||||
await next(context)
|
||||
execution_order.append("class_after")
|
||||
|
||||
async def function_middleware(
|
||||
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("function_before")
|
||||
await next(context)
|
||||
execution_order.append("function_after")
|
||||
|
||||
pipeline = AgentMiddlewarePipeline([ClassMiddleware(), function_middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
context = AgentRunContext(agent=mock_agent, messages=messages)
|
||||
|
||||
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
|
||||
execution_order.append("handler")
|
||||
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
|
||||
|
||||
result = await pipeline.execute(mock_agent, messages, context, final_handler)
|
||||
|
||||
assert result is not None
|
||||
assert execution_order == ["class_before", "function_before", "handler", "function_after", "class_after"]
|
||||
|
||||
async def test_mixed_function_middleware(self, mock_function: AIFunction[Any, Any]) -> None:
|
||||
"""Test mixed class and function-based function middleware."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class ClassMiddleware(FunctionMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_order.append("class_before")
|
||||
await next(context)
|
||||
execution_order.append("class_after")
|
||||
|
||||
async def function_middleware(
|
||||
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("function_before")
|
||||
await next(context)
|
||||
execution_order.append("function_after")
|
||||
|
||||
pipeline = FunctionMiddlewarePipeline([ClassMiddleware(), function_middleware])
|
||||
arguments = FunctionTestArgs(name="test")
|
||||
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
|
||||
|
||||
async def final_handler(ctx: FunctionInvocationContext) -> str:
|
||||
execution_order.append("handler")
|
||||
return "result"
|
||||
|
||||
result = await pipeline.execute(mock_function, arguments, context, final_handler)
|
||||
|
||||
assert result == "result"
|
||||
assert execution_order == ["class_before", "function_before", "handler", "function_after", "class_after"]
|
||||
|
||||
|
||||
class TestMultipleMiddlewareOrdering:
|
||||
"""Test cases for multiple middleware execution order."""
|
||||
|
||||
async def test_agent_middleware_execution_order(self, mock_agent: AgentProtocol) -> None:
|
||||
"""Test that multiple agent middlewares execute in registration order."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class FirstMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("first_before")
|
||||
await next(context)
|
||||
execution_order.append("first_after")
|
||||
|
||||
class SecondMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("second_before")
|
||||
await next(context)
|
||||
execution_order.append("second_after")
|
||||
|
||||
class ThirdMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("third_before")
|
||||
await next(context)
|
||||
execution_order.append("third_after")
|
||||
|
||||
middlewares = [FirstMiddleware(), SecondMiddleware(), ThirdMiddleware()]
|
||||
pipeline = AgentMiddlewarePipeline(middlewares) # type: ignore
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
context = AgentRunContext(agent=mock_agent, messages=messages)
|
||||
|
||||
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
|
||||
execution_order.append("handler")
|
||||
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
|
||||
|
||||
result = await pipeline.execute(mock_agent, messages, context, final_handler)
|
||||
|
||||
assert result is not None
|
||||
expected_order = [
|
||||
"first_before",
|
||||
"second_before",
|
||||
"third_before",
|
||||
"handler",
|
||||
"third_after",
|
||||
"second_after",
|
||||
"first_after",
|
||||
]
|
||||
assert execution_order == expected_order
|
||||
|
||||
async def test_function_middleware_execution_order(self, mock_function: AIFunction[Any, Any]) -> None:
|
||||
"""Test that multiple function middlewares execute in registration order."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class FirstMiddleware(FunctionMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_order.append("first_before")
|
||||
await next(context)
|
||||
execution_order.append("first_after")
|
||||
|
||||
class SecondMiddleware(FunctionMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_order.append("second_before")
|
||||
await next(context)
|
||||
execution_order.append("second_after")
|
||||
|
||||
middlewares = [FirstMiddleware(), SecondMiddleware()]
|
||||
pipeline = FunctionMiddlewarePipeline(middlewares) # type: ignore
|
||||
arguments = FunctionTestArgs(name="test")
|
||||
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
|
||||
|
||||
async def final_handler(ctx: FunctionInvocationContext) -> str:
|
||||
execution_order.append("handler")
|
||||
return "result"
|
||||
|
||||
result = await pipeline.execute(mock_function, arguments, context, final_handler)
|
||||
|
||||
assert result == "result"
|
||||
expected_order = ["first_before", "second_before", "handler", "second_after", "first_after"]
|
||||
assert execution_order == expected_order
|
||||
|
||||
|
||||
class TestContextContentValidation:
|
||||
"""Test cases for validating middleware context content."""
|
||||
|
||||
async def test_agent_context_validation(self, mock_agent: AgentProtocol) -> None:
|
||||
"""Test that agent context contains expected data."""
|
||||
|
||||
class ContextValidationMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
# Verify context has all expected attributes
|
||||
assert hasattr(context, "agent")
|
||||
assert hasattr(context, "messages")
|
||||
assert hasattr(context, "is_streaming")
|
||||
assert hasattr(context, "metadata")
|
||||
|
||||
# Verify context content
|
||||
assert context.agent is mock_agent
|
||||
assert len(context.messages) == 1
|
||||
assert context.messages[0].role == Role.USER
|
||||
assert context.messages[0].text == "test"
|
||||
assert context.is_streaming is False
|
||||
assert isinstance(context.metadata, dict)
|
||||
|
||||
# Add custom metadata
|
||||
context.metadata["validated"] = True
|
||||
|
||||
await next(context)
|
||||
|
||||
middleware = ContextValidationMiddleware()
|
||||
pipeline = AgentMiddlewarePipeline([middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
context = AgentRunContext(agent=mock_agent, messages=messages)
|
||||
|
||||
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
|
||||
# Verify metadata was set by middleware
|
||||
assert ctx.metadata.get("validated") is True
|
||||
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
|
||||
|
||||
result = await pipeline.execute(mock_agent, messages, context, final_handler)
|
||||
assert result is not None
|
||||
|
||||
async def test_function_context_validation(self, mock_function: AIFunction[Any, Any]) -> None:
|
||||
"""Test that function context contains expected data."""
|
||||
|
||||
class ContextValidationMiddleware(FunctionMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
) -> None:
|
||||
# Verify context has all expected attributes
|
||||
assert hasattr(context, "function")
|
||||
assert hasattr(context, "arguments")
|
||||
assert hasattr(context, "metadata")
|
||||
|
||||
# Verify context content
|
||||
assert context.function is mock_function
|
||||
assert isinstance(context.arguments, FunctionTestArgs)
|
||||
assert context.arguments.name == "test"
|
||||
assert isinstance(context.metadata, dict)
|
||||
|
||||
# Add custom metadata
|
||||
context.metadata["validated"] = True
|
||||
|
||||
await next(context)
|
||||
|
||||
middleware = ContextValidationMiddleware()
|
||||
pipeline = FunctionMiddlewarePipeline([middleware])
|
||||
arguments = FunctionTestArgs(name="test")
|
||||
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
|
||||
|
||||
async def final_handler(ctx: FunctionInvocationContext) -> str:
|
||||
# Verify metadata was set by middleware
|
||||
assert ctx.metadata.get("validated") is True
|
||||
return "result"
|
||||
|
||||
result = await pipeline.execute(mock_function, arguments, context, final_handler)
|
||||
assert result == "result"
|
||||
|
||||
|
||||
class TestStreamingScenarios:
|
||||
"""Test cases for streaming and non-streaming scenarios."""
|
||||
|
||||
async def test_streaming_flag_validation(self, mock_agent: AgentProtocol) -> None:
|
||||
"""Test that is_streaming flag is correctly set for streaming calls."""
|
||||
streaming_flags: list[bool] = []
|
||||
|
||||
class StreamingFlagMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
streaming_flags.append(context.is_streaming)
|
||||
await next(context)
|
||||
|
||||
middleware = StreamingFlagMiddleware()
|
||||
pipeline = AgentMiddlewarePipeline([middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
|
||||
# Test non-streaming
|
||||
context = AgentRunContext(agent=mock_agent, messages=messages)
|
||||
|
||||
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
|
||||
streaming_flags.append(ctx.is_streaming)
|
||||
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])
|
||||
|
||||
await pipeline.execute(mock_agent, messages, context, final_handler)
|
||||
|
||||
# Test streaming
|
||||
context_stream = AgentRunContext(agent=mock_agent, messages=messages)
|
||||
|
||||
async def final_stream_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]:
|
||||
streaming_flags.append(ctx.is_streaming)
|
||||
yield AgentRunResponseUpdate(contents=[TextContent(text="chunk")])
|
||||
|
||||
updates: list[AgentRunResponseUpdate] = []
|
||||
async for update in pipeline.execute_stream(mock_agent, messages, context_stream, final_stream_handler):
|
||||
updates.append(update)
|
||||
|
||||
# Verify flags: [non-streaming middleware, non-streaming handler, streaming middleware, streaming handler]
|
||||
assert streaming_flags == [False, False, True, True]
|
||||
|
||||
async def test_streaming_middleware_behavior(self, mock_agent: AgentProtocol) -> None:
|
||||
"""Test middleware behavior with streaming responses."""
|
||||
chunks_processed: list[str] = []
|
||||
|
||||
class StreamProcessingMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
chunks_processed.append("before_stream")
|
||||
await next(context)
|
||||
chunks_processed.append("after_stream")
|
||||
|
||||
middleware = StreamProcessingMiddleware()
|
||||
pipeline = AgentMiddlewarePipeline([middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
context = AgentRunContext(agent=mock_agent, messages=messages)
|
||||
|
||||
async def final_stream_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]:
|
||||
chunks_processed.append("stream_start")
|
||||
yield AgentRunResponseUpdate(contents=[TextContent(text="chunk1")])
|
||||
chunks_processed.append("chunk1_yielded")
|
||||
yield AgentRunResponseUpdate(contents=[TextContent(text="chunk2")])
|
||||
chunks_processed.append("chunk2_yielded")
|
||||
chunks_processed.append("stream_end")
|
||||
|
||||
updates: list[str] = []
|
||||
async for update in pipeline.execute_stream(mock_agent, messages, context, final_stream_handler):
|
||||
updates.append(update.text)
|
||||
|
||||
assert updates == ["chunk1", "chunk2"]
|
||||
assert chunks_processed == [
|
||||
"before_stream",
|
||||
"after_stream",
|
||||
"stream_start",
|
||||
"chunk1_yielded",
|
||||
"chunk2_yielded",
|
||||
"stream_end",
|
||||
]
|
||||
|
||||
|
||||
# region Helper classes and fixtures
|
||||
|
||||
|
||||
class FunctionTestArgs(BaseModel):
|
||||
"""Test arguments for function middleware tests."""
|
||||
|
||||
name: str = Field(description="Test name parameter")
|
||||
|
||||
|
||||
class TestAgentMiddleware(AgentMiddleware):
|
||||
"""Test implementation of AgentMiddleware."""
|
||||
|
||||
async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None:
|
||||
await next(context)
|
||||
|
||||
|
||||
class TestFunctionMiddleware(FunctionMiddleware):
|
||||
"""Test implementation of FunctionMiddleware."""
|
||||
|
||||
async def process(
|
||||
self, context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
) -> None:
|
||||
await next(context)
|
||||
|
||||
|
||||
class MockFunctionArgs(BaseModel):
|
||||
"""Test arguments for function middleware tests."""
|
||||
|
||||
name: str = Field(description="Test name parameter")
|
||||
|
||||
|
||||
class TestMiddlewareExecutionControl:
|
||||
"""Test cases for middleware execution control (when next() is called vs not called)."""
|
||||
|
||||
async def test_agent_middleware_no_next_no_execution(self, mock_agent: AgentProtocol) -> None:
|
||||
"""Test that when agent middleware doesn't call next(), no execution happens."""
|
||||
|
||||
class NoNextMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
# Don't call next() - this should prevent any execution
|
||||
pass
|
||||
|
||||
middleware = NoNextMiddleware()
|
||||
pipeline = AgentMiddlewarePipeline([middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
context = AgentRunContext(agent=mock_agent, messages=messages)
|
||||
|
||||
handler_called = False
|
||||
|
||||
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
|
||||
nonlocal handler_called
|
||||
handler_called = True
|
||||
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="should not execute")])
|
||||
|
||||
result = await pipeline.execute(mock_agent, messages, context, final_handler)
|
||||
|
||||
# Verify no execution happened - should return empty AgentRunResponse
|
||||
assert result is not None
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert result.messages == [] # Empty response
|
||||
assert not handler_called
|
||||
assert context.result is None
|
||||
|
||||
async def test_agent_middleware_no_next_no_streaming_execution(self, mock_agent: AgentProtocol) -> None:
|
||||
"""Test that when agent middleware doesn't call next(), no streaming execution happens."""
|
||||
|
||||
class NoNextStreamingMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
# Don't call next() - this should prevent any execution
|
||||
pass
|
||||
|
||||
middleware = NoNextStreamingMiddleware()
|
||||
pipeline = AgentMiddlewarePipeline([middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
context = AgentRunContext(agent=mock_agent, messages=messages)
|
||||
|
||||
handler_called = False
|
||||
|
||||
async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]:
|
||||
nonlocal handler_called
|
||||
handler_called = True
|
||||
yield AgentRunResponseUpdate(contents=[TextContent(text="should not execute")])
|
||||
|
||||
# When middleware doesn't call next(), streaming should yield no updates
|
||||
updates: list[AgentRunResponseUpdate] = []
|
||||
async for update in pipeline.execute_stream(mock_agent, messages, context, final_handler):
|
||||
updates.append(update)
|
||||
|
||||
# Verify no execution happened and no updates were yielded
|
||||
assert len(updates) == 0
|
||||
assert not handler_called
|
||||
assert context.result is None
|
||||
|
||||
async def test_function_middleware_no_next_no_execution(self, mock_function: AIFunction[Any, Any]) -> None:
|
||||
"""Test that when function middleware doesn't call next(), no execution happens."""
|
||||
|
||||
class FunctionTestArgs(BaseModel):
|
||||
name: str = Field(description="Test name parameter")
|
||||
|
||||
class NoNextFunctionMiddleware(FunctionMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
) -> None:
|
||||
# Don't call next() - this should prevent any execution
|
||||
pass
|
||||
|
||||
middleware = NoNextFunctionMiddleware()
|
||||
pipeline = FunctionMiddlewarePipeline([middleware])
|
||||
arguments = FunctionTestArgs(name="test")
|
||||
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
|
||||
|
||||
handler_called = False
|
||||
|
||||
async def final_handler(ctx: FunctionInvocationContext) -> str:
|
||||
nonlocal handler_called
|
||||
handler_called = True
|
||||
return "should not execute"
|
||||
|
||||
result = await pipeline.execute(mock_function, arguments, context, final_handler)
|
||||
|
||||
# Verify no execution happened
|
||||
assert result is None
|
||||
assert not handler_called
|
||||
assert context.result is None
|
||||
|
||||
async def test_multiple_middlewares_early_stop(self, mock_agent: AgentProtocol) -> None:
|
||||
"""Test that when first middleware doesn't call next(), subsequent middlewares are not called."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class FirstMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("first")
|
||||
# Don't call next() - this should stop the pipeline
|
||||
|
||||
class SecondMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("second")
|
||||
await next(context)
|
||||
|
||||
pipeline = AgentMiddlewarePipeline([FirstMiddleware(), SecondMiddleware()])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
context = AgentRunContext(agent=mock_agent, messages=messages)
|
||||
|
||||
handler_called = False
|
||||
|
||||
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
|
||||
nonlocal handler_called
|
||||
handler_called = True
|
||||
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="should not execute")])
|
||||
|
||||
result = await pipeline.execute(mock_agent, messages, context, final_handler)
|
||||
|
||||
# Verify only first middleware was called and empty response returned
|
||||
assert execution_order == ["first"]
|
||||
assert result is not None
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert result.messages == [] # Empty response
|
||||
assert not handler_called
|
||||
|
||||
async def test_function_middleware_pre_execution_override_with_next(
|
||||
self, mock_function: AIFunction[Any, Any]
|
||||
) -> None:
|
||||
"""Test that function middleware can override result before calling next() - this skips handler execution."""
|
||||
|
||||
class FunctionTestArgs(BaseModel):
|
||||
name: str = Field(description="Test name parameter")
|
||||
|
||||
class PreOverrideFunctionMiddleware(FunctionMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
) -> None:
|
||||
# Set override first
|
||||
context.result = "pre-override result"
|
||||
# Then call next() to continue middleware pipeline
|
||||
await next(context)
|
||||
|
||||
middleware = PreOverrideFunctionMiddleware()
|
||||
pipeline = FunctionMiddlewarePipeline([middleware])
|
||||
arguments = FunctionTestArgs(name="test")
|
||||
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
|
||||
|
||||
handler_called = False
|
||||
|
||||
async def final_handler(ctx: FunctionInvocationContext) -> str:
|
||||
nonlocal handler_called
|
||||
handler_called = True
|
||||
# This should not be called when result is pre-set
|
||||
return "original result"
|
||||
|
||||
result = await pipeline.execute(mock_function, arguments, context, final_handler)
|
||||
|
||||
# Verify pre-override worked and handler was NOT called (because result was already set)
|
||||
assert result == "pre-override result"
|
||||
assert not handler_called
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent() -> AgentProtocol:
|
||||
"""Mock agent for testing."""
|
||||
agent = MagicMock(spec=AgentProtocol)
|
||||
agent.name = "test_agent"
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_function() -> AIFunction[Any, Any]:
|
||||
"""Mock function for testing."""
|
||||
function = MagicMock(spec=AIFunction[Any, Any])
|
||||
function.name = "test_function"
|
||||
return function
|
||||
@@ -0,0 +1,463 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from agent_framework import (
|
||||
AgentProtocol,
|
||||
AgentRunResponse,
|
||||
AgentRunResponseUpdate,
|
||||
ChatAgent,
|
||||
ChatMessage,
|
||||
Role,
|
||||
TextContent,
|
||||
)
|
||||
from agent_framework._middleware import (
|
||||
AgentMiddleware,
|
||||
AgentMiddlewarePipeline,
|
||||
AgentRunContext,
|
||||
FunctionInvocationContext,
|
||||
FunctionMiddleware,
|
||||
FunctionMiddlewarePipeline,
|
||||
)
|
||||
from agent_framework._tools import AIFunction
|
||||
|
||||
from .conftest import MockChatClient
|
||||
|
||||
|
||||
class FunctionTestArgs(BaseModel):
|
||||
"""Test arguments for function middleware tests."""
|
||||
|
||||
name: str = Field(description="Test name parameter")
|
||||
|
||||
|
||||
class TestResultOverrideMiddleware:
|
||||
"""Test cases for middleware result override functionality."""
|
||||
|
||||
async def test_agent_middleware_response_override_non_streaming(self, mock_agent: AgentProtocol) -> None:
|
||||
"""Test that agent middleware can override response for non-streaming execution."""
|
||||
override_response = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="overridden response")])
|
||||
|
||||
class ResponseOverrideMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
# Execute the pipeline first, then override the response
|
||||
await next(context)
|
||||
context.result = override_response
|
||||
|
||||
middleware = ResponseOverrideMiddleware()
|
||||
pipeline = AgentMiddlewarePipeline([middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
context = AgentRunContext(agent=mock_agent, messages=messages)
|
||||
|
||||
handler_called = False
|
||||
|
||||
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
|
||||
nonlocal handler_called
|
||||
handler_called = True
|
||||
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="original response")])
|
||||
|
||||
result = await pipeline.execute(mock_agent, messages, context, final_handler)
|
||||
|
||||
# Verify the overridden response is returned
|
||||
assert result is not None
|
||||
assert result == override_response
|
||||
assert result.messages[0].text == "overridden response"
|
||||
# Verify original handler was called since middleware called next()
|
||||
assert handler_called
|
||||
|
||||
async def test_agent_middleware_response_override_streaming(self, mock_agent: AgentProtocol) -> None:
|
||||
"""Test that agent middleware can override response for streaming execution."""
|
||||
|
||||
async def override_stream() -> AsyncIterable[AgentRunResponseUpdate]:
|
||||
yield AgentRunResponseUpdate(contents=[TextContent(text="overridden")])
|
||||
yield AgentRunResponseUpdate(contents=[TextContent(text=" stream")])
|
||||
|
||||
class StreamResponseOverrideMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
# Execute the pipeline first, then override the response stream
|
||||
await next(context)
|
||||
context.result = override_stream()
|
||||
|
||||
middleware = StreamResponseOverrideMiddleware()
|
||||
pipeline = AgentMiddlewarePipeline([middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
context = AgentRunContext(agent=mock_agent, messages=messages)
|
||||
|
||||
async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]:
|
||||
yield AgentRunResponseUpdate(contents=[TextContent(text="original")])
|
||||
|
||||
updates: list[AgentRunResponseUpdate] = []
|
||||
async for update in pipeline.execute_stream(mock_agent, messages, context, final_handler):
|
||||
updates.append(update)
|
||||
|
||||
# Verify the overridden response stream is returned
|
||||
assert len(updates) == 2
|
||||
assert updates[0].text == "overridden"
|
||||
assert updates[1].text == " stream"
|
||||
|
||||
async def test_function_middleware_result_override(self, mock_function: AIFunction[Any, Any]) -> None:
|
||||
"""Test that function middleware can override result."""
|
||||
override_result = "overridden function result"
|
||||
|
||||
class ResultOverrideMiddleware(FunctionMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
) -> None:
|
||||
# Execute the pipeline first, then override the result
|
||||
await next(context)
|
||||
context.result = override_result
|
||||
|
||||
middleware = ResultOverrideMiddleware()
|
||||
pipeline = FunctionMiddlewarePipeline([middleware])
|
||||
arguments = FunctionTestArgs(name="test")
|
||||
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
|
||||
|
||||
handler_called = False
|
||||
|
||||
async def final_handler(ctx: FunctionInvocationContext) -> str:
|
||||
nonlocal handler_called
|
||||
handler_called = True
|
||||
return "original function result"
|
||||
|
||||
result = await pipeline.execute(mock_function, arguments, context, final_handler)
|
||||
|
||||
# Verify the overridden result is returned
|
||||
assert result == override_result
|
||||
# Verify original handler was called since middleware called next()
|
||||
assert handler_called
|
||||
|
||||
async def test_chat_agent_middleware_response_override(self) -> None:
|
||||
"""Test result override functionality with ChatAgent integration."""
|
||||
mock_chat_client = MockChatClient()
|
||||
|
||||
class ChatAgentResponseOverrideMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
# Always call next() first to allow execution
|
||||
await next(context)
|
||||
# Then conditionally override based on content
|
||||
if any("special" in msg.text for msg in context.messages if msg.text):
|
||||
context.result = AgentRunResponse(
|
||||
messages=[ChatMessage(role=Role.ASSISTANT, text="Special response from middleware!")]
|
||||
)
|
||||
|
||||
# Create ChatAgent with override middleware
|
||||
middleware = ChatAgentResponseOverrideMiddleware()
|
||||
agent = ChatAgent(chat_client=mock_chat_client, middleware=[middleware])
|
||||
|
||||
# Test override case
|
||||
override_messages = [ChatMessage(role=Role.USER, text="Give me a special response")]
|
||||
override_response = await agent.run(override_messages)
|
||||
assert override_response.messages[0].text == "Special response from middleware!"
|
||||
# Verify chat client was called since middleware called next()
|
||||
assert mock_chat_client.call_count == 1
|
||||
|
||||
# Test normal case
|
||||
normal_messages = [ChatMessage(role=Role.USER, text="Normal request")]
|
||||
normal_response = await agent.run(normal_messages)
|
||||
assert normal_response.messages[0].text == "test response"
|
||||
# Verify chat client was called for normal case
|
||||
assert mock_chat_client.call_count == 2
|
||||
|
||||
async def test_chat_agent_middleware_streaming_override(self) -> None:
|
||||
"""Test streaming result override functionality with ChatAgent integration."""
|
||||
mock_chat_client = MockChatClient()
|
||||
|
||||
async def custom_stream() -> AsyncIterable[AgentRunResponseUpdate]:
|
||||
yield AgentRunResponseUpdate(contents=[TextContent(text="Custom")])
|
||||
yield AgentRunResponseUpdate(contents=[TextContent(text=" streaming")])
|
||||
yield AgentRunResponseUpdate(contents=[TextContent(text=" response!")])
|
||||
|
||||
class ChatAgentStreamOverrideMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
# Always call next() first to allow execution
|
||||
await next(context)
|
||||
# Then conditionally override based on content
|
||||
if any("custom stream" in msg.text for msg in context.messages if msg.text):
|
||||
context.result = custom_stream()
|
||||
|
||||
# Create ChatAgent with override middleware
|
||||
middleware = ChatAgentStreamOverrideMiddleware()
|
||||
agent = ChatAgent(chat_client=mock_chat_client, middleware=[middleware])
|
||||
|
||||
# Test streaming override case
|
||||
override_messages = [ChatMessage(role=Role.USER, text="Give me a custom stream")]
|
||||
override_updates: list[AgentRunResponseUpdate] = []
|
||||
async for update in agent.run_stream(override_messages):
|
||||
override_updates.append(update)
|
||||
|
||||
assert len(override_updates) == 3
|
||||
assert override_updates[0].text == "Custom"
|
||||
assert override_updates[1].text == " streaming"
|
||||
assert override_updates[2].text == " response!"
|
||||
|
||||
# Test normal streaming case
|
||||
normal_messages = [ChatMessage(role=Role.USER, text="Normal streaming request")]
|
||||
normal_updates: list[AgentRunResponseUpdate] = []
|
||||
async for update in agent.run_stream(normal_messages):
|
||||
normal_updates.append(update)
|
||||
|
||||
assert len(normal_updates) == 2
|
||||
assert normal_updates[0].text == "test streaming response "
|
||||
assert normal_updates[1].text == "another update"
|
||||
|
||||
async def test_agent_middleware_conditional_no_next(self, mock_agent: AgentProtocol) -> None:
|
||||
"""Test that when agent middleware conditionally doesn't call next(), no execution happens."""
|
||||
|
||||
class ConditionalNoNextMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
# Only call next() if message contains "execute"
|
||||
if any("execute" in msg.text for msg in context.messages if msg.text):
|
||||
await next(context)
|
||||
# Otherwise, don't call next() - no execution should happen
|
||||
|
||||
middleware = ConditionalNoNextMiddleware()
|
||||
pipeline = AgentMiddlewarePipeline([middleware])
|
||||
|
||||
handler_called = False
|
||||
|
||||
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
|
||||
nonlocal handler_called
|
||||
handler_called = True
|
||||
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="executed response")])
|
||||
|
||||
# Test case where next() is NOT called
|
||||
no_execute_messages = [ChatMessage(role=Role.USER, text="Don't run this")]
|
||||
no_execute_context = AgentRunContext(agent=mock_agent, messages=no_execute_messages)
|
||||
no_execute_result = await pipeline.execute(mock_agent, no_execute_messages, no_execute_context, final_handler)
|
||||
|
||||
# When middleware doesn't call next(), result should be empty AgentRunResponse
|
||||
assert no_execute_result is not None
|
||||
assert isinstance(no_execute_result, AgentRunResponse)
|
||||
assert no_execute_result.messages == [] # Empty response
|
||||
assert not handler_called
|
||||
assert no_execute_context.result is None
|
||||
|
||||
# Reset for next test
|
||||
handler_called = False
|
||||
|
||||
# Test case where next() IS called
|
||||
execute_messages = [ChatMessage(role=Role.USER, text="Please execute this")]
|
||||
execute_context = AgentRunContext(agent=mock_agent, messages=execute_messages)
|
||||
execute_result = await pipeline.execute(mock_agent, execute_messages, execute_context, final_handler)
|
||||
|
||||
assert execute_result is not None
|
||||
assert execute_result.messages[0].text == "executed response"
|
||||
assert handler_called
|
||||
|
||||
async def test_function_middleware_conditional_no_next(self, mock_function: AIFunction[Any, Any]) -> None:
|
||||
"""Test that when function middleware conditionally doesn't call next(), no execution happens."""
|
||||
|
||||
class ConditionalNoNextFunctionMiddleware(FunctionMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
) -> None:
|
||||
# Only call next() if argument name contains "execute"
|
||||
args = context.arguments
|
||||
assert isinstance(args, FunctionTestArgs)
|
||||
if "execute" in args.name:
|
||||
await next(context)
|
||||
# Otherwise, don't call next() - no execution should happen
|
||||
|
||||
middleware = ConditionalNoNextFunctionMiddleware()
|
||||
pipeline = FunctionMiddlewarePipeline([middleware])
|
||||
|
||||
handler_called = False
|
||||
|
||||
async def final_handler(ctx: FunctionInvocationContext) -> str:
|
||||
nonlocal handler_called
|
||||
handler_called = True
|
||||
return "executed function result"
|
||||
|
||||
# Test case where next() is NOT called
|
||||
no_execute_args = FunctionTestArgs(name="test_no_action")
|
||||
no_execute_context = FunctionInvocationContext(function=mock_function, arguments=no_execute_args)
|
||||
no_execute_result = await pipeline.execute(mock_function, no_execute_args, no_execute_context, final_handler)
|
||||
|
||||
# When middleware doesn't call next(), function result should be None (functions can return None)
|
||||
assert no_execute_result is None
|
||||
assert not handler_called
|
||||
assert no_execute_context.result is None
|
||||
|
||||
# Reset for next test
|
||||
handler_called = False
|
||||
|
||||
# Test case where next() IS called
|
||||
execute_args = FunctionTestArgs(name="test_execute")
|
||||
execute_context = FunctionInvocationContext(function=mock_function, arguments=execute_args)
|
||||
execute_result = await pipeline.execute(mock_function, execute_args, execute_context, final_handler)
|
||||
|
||||
assert execute_result == "executed function result"
|
||||
assert handler_called
|
||||
|
||||
|
||||
class TestResultObservability:
|
||||
"""Test cases for middleware result observability functionality."""
|
||||
|
||||
async def test_agent_middleware_response_observability(self, mock_agent: AgentProtocol) -> None:
|
||||
"""Test that middleware can observe response after execution."""
|
||||
observed_responses: list[AgentRunResponse] = []
|
||||
|
||||
class ObservabilityMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
# Context should be empty before next()
|
||||
assert context.result is None
|
||||
|
||||
# Call next to execute
|
||||
await next(context)
|
||||
|
||||
# Context should now contain the response for observability
|
||||
assert context.result is not None
|
||||
assert isinstance(context.result, AgentRunResponse)
|
||||
observed_responses.append(context.result)
|
||||
|
||||
middleware = ObservabilityMiddleware()
|
||||
pipeline = AgentMiddlewarePipeline([middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
context = AgentRunContext(agent=mock_agent, messages=messages)
|
||||
|
||||
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
|
||||
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="executed response")])
|
||||
|
||||
result = await pipeline.execute(mock_agent, messages, context, final_handler)
|
||||
|
||||
# Verify response was observed
|
||||
assert len(observed_responses) == 1
|
||||
assert observed_responses[0].messages[0].text == "executed response"
|
||||
assert result == observed_responses[0]
|
||||
|
||||
async def test_function_middleware_result_observability(self, mock_function: AIFunction[Any, Any]) -> None:
|
||||
"""Test that middleware can observe function result after execution."""
|
||||
observed_results: list[str] = []
|
||||
|
||||
class ObservabilityMiddleware(FunctionMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
) -> None:
|
||||
# Context should be empty before next()
|
||||
assert context.result is None
|
||||
|
||||
# Call next to execute
|
||||
await next(context)
|
||||
|
||||
# Context should now contain the result for observability
|
||||
assert context.result is not None
|
||||
observed_results.append(context.result)
|
||||
|
||||
middleware = ObservabilityMiddleware()
|
||||
pipeline = FunctionMiddlewarePipeline([middleware])
|
||||
arguments = FunctionTestArgs(name="test")
|
||||
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
|
||||
|
||||
async def final_handler(ctx: FunctionInvocationContext) -> str:
|
||||
return "executed function result"
|
||||
|
||||
result = await pipeline.execute(mock_function, arguments, context, final_handler)
|
||||
|
||||
# Verify result was observed
|
||||
assert len(observed_results) == 1
|
||||
assert observed_results[0] == "executed function result"
|
||||
assert result == observed_results[0]
|
||||
|
||||
async def test_agent_middleware_post_execution_override(self, mock_agent: AgentProtocol) -> None:
|
||||
"""Test that middleware can override response after observing execution."""
|
||||
|
||||
class PostExecutionOverrideMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
# Call next to execute first
|
||||
await next(context)
|
||||
|
||||
# Now observe and conditionally override
|
||||
assert context.result is not None
|
||||
assert isinstance(context.result, AgentRunResponse)
|
||||
|
||||
if "modify" in context.result.messages[0].text:
|
||||
# Override after observing
|
||||
context.result = AgentRunResponse(
|
||||
messages=[ChatMessage(role=Role.ASSISTANT, text="modified after execution")]
|
||||
)
|
||||
|
||||
middleware = PostExecutionOverrideMiddleware()
|
||||
pipeline = AgentMiddlewarePipeline([middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test")]
|
||||
context = AgentRunContext(agent=mock_agent, messages=messages)
|
||||
|
||||
async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
|
||||
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response to modify")])
|
||||
|
||||
result = await pipeline.execute(mock_agent, messages, context, final_handler)
|
||||
|
||||
# Verify response was modified after execution
|
||||
assert result is not None
|
||||
assert result.messages[0].text == "modified after execution"
|
||||
|
||||
async def test_function_middleware_post_execution_override(self, mock_function: AIFunction[Any, Any]) -> None:
|
||||
"""Test that middleware can override function result after observing execution."""
|
||||
|
||||
class PostExecutionOverrideMiddleware(FunctionMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
) -> None:
|
||||
# Call next to execute first
|
||||
await next(context)
|
||||
|
||||
# Now observe and conditionally override
|
||||
assert context.result is not None
|
||||
|
||||
if "modify" in context.result:
|
||||
# Override after observing
|
||||
context.result = "modified after execution"
|
||||
|
||||
middleware = PostExecutionOverrideMiddleware()
|
||||
pipeline = FunctionMiddlewarePipeline([middleware])
|
||||
arguments = FunctionTestArgs(name="test")
|
||||
context = FunctionInvocationContext(function=mock_function, arguments=arguments)
|
||||
|
||||
async def final_handler(ctx: FunctionInvocationContext) -> str:
|
||||
return "result to modify"
|
||||
|
||||
result = await pipeline.execute(mock_function, arguments, context, final_handler)
|
||||
|
||||
# Verify result was modified after execution
|
||||
assert result == "modified after execution"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent() -> AgentProtocol:
|
||||
"""Mock agent for testing."""
|
||||
agent = MagicMock(spec=AgentProtocol)
|
||||
agent.name = "test_agent"
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_function() -> AIFunction[Any, Any]:
|
||||
"""Mock function for testing."""
|
||||
function = MagicMock(spec=AIFunction[Any, Any])
|
||||
function.name = "test_function"
|
||||
return function
|
||||
@@ -0,0 +1,544 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from agent_framework import (
|
||||
AgentRunResponseUpdate,
|
||||
ChatAgent,
|
||||
ChatMessage,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
FunctionCallContent,
|
||||
FunctionResultContent,
|
||||
Role,
|
||||
TextContent,
|
||||
)
|
||||
from agent_framework._middleware import (
|
||||
AgentMiddleware,
|
||||
AgentRunContext,
|
||||
FunctionInvocationContext,
|
||||
FunctionMiddleware,
|
||||
)
|
||||
|
||||
from .conftest import MockChatClient
|
||||
|
||||
# region ChatAgent Tests
|
||||
|
||||
|
||||
class TestChatAgentClassBasedMiddleware:
|
||||
"""Test cases for class-based middleware integration with ChatAgent."""
|
||||
|
||||
async def test_class_based_agent_middleware_with_chat_agent(self, chat_client: "MockChatClient") -> None:
|
||||
"""Test class-based agent middleware with ChatAgent."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class TrackingAgentMiddleware(AgentMiddleware):
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append(f"{self.name}_before")
|
||||
await next(context)
|
||||
execution_order.append(f"{self.name}_after")
|
||||
|
||||
# Create ChatAgent with middleware
|
||||
middleware = TrackingAgentMiddleware("agent_middleware")
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[middleware])
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert response.messages[0].role == Role.ASSISTANT
|
||||
# Note: conftest "MockChatClient" returns different text format
|
||||
assert "test response" in response.messages[0].text
|
||||
|
||||
# Verify middleware execution order
|
||||
assert execution_order == ["agent_middleware_before", "agent_middleware_after"]
|
||||
|
||||
async def test_class_based_function_middleware_with_chat_agent(self, chat_client: "MockChatClient") -> None:
|
||||
"""Test class-based function middleware with ChatAgent."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class TrackingFunctionMiddleware(FunctionMiddleware):
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_order.append(f"{self.name}_before")
|
||||
await next(context)
|
||||
execution_order.append(f"{self.name}_after")
|
||||
|
||||
# Create ChatAgent with function middleware (no tools, so function middleware won't be triggered)
|
||||
middleware = TrackingFunctionMiddleware("function_middleware")
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[middleware])
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert chat_client.call_count == 1
|
||||
|
||||
# Note: Function middleware won't execute since no function calls are made
|
||||
assert execution_order == []
|
||||
|
||||
|
||||
class TestChatAgentFunctionBasedMiddleware:
|
||||
"""Test cases for function-based middleware integration with ChatAgent."""
|
||||
|
||||
async def test_function_based_agent_middleware_with_chat_agent(self, chat_client: "MockChatClient") -> None:
|
||||
"""Test function-based agent middleware with ChatAgent."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
async def tracking_agent_middleware(
|
||||
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("agent_function_before")
|
||||
await next(context)
|
||||
execution_order.append("agent_function_after")
|
||||
|
||||
# Create ChatAgent with function middleware
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[tracking_agent_middleware])
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert response.messages[0].role == Role.ASSISTANT
|
||||
assert response.messages[0].text == "test response"
|
||||
assert chat_client.call_count == 1
|
||||
|
||||
# Verify middleware execution order
|
||||
assert execution_order == ["agent_function_before", "agent_function_after"]
|
||||
|
||||
async def test_function_based_function_middleware_with_chat_agent(self, chat_client: "MockChatClient") -> None:
|
||||
"""Test function-based function middleware with ChatAgent."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
async def tracking_function_middleware(
|
||||
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("function_function_before")
|
||||
await next(context)
|
||||
execution_order.append("function_function_after")
|
||||
|
||||
# Create ChatAgent with function middleware (no tools, so function middleware won't be triggered)
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[tracking_function_middleware])
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert chat_client.call_count == 1
|
||||
|
||||
# Note: Function middleware won't execute since no function calls are made
|
||||
assert execution_order == []
|
||||
|
||||
|
||||
class TestChatAgentStreamingMiddleware:
|
||||
"""Test cases for streaming middleware integration with ChatAgent."""
|
||||
|
||||
async def test_agent_middleware_with_streaming(self, chat_client: "MockChatClient") -> None:
|
||||
"""Test agent middleware with streaming ChatAgent responses."""
|
||||
execution_order: list[str] = []
|
||||
streaming_flags: list[bool] = []
|
||||
|
||||
class StreamingTrackingMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("middleware_before")
|
||||
streaming_flags.append(context.is_streaming)
|
||||
await next(context)
|
||||
execution_order.append("middleware_after")
|
||||
|
||||
# Create ChatAgent with middleware
|
||||
middleware = StreamingTrackingMiddleware()
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[middleware])
|
||||
|
||||
# Set up mock streaming responses
|
||||
chat_client.streaming_responses = [
|
||||
[
|
||||
ChatResponseUpdate(contents=[TextContent(text="Streaming")], role=Role.ASSISTANT),
|
||||
ChatResponseUpdate(contents=[TextContent(text=" response")], role=Role.ASSISTANT),
|
||||
]
|
||||
]
|
||||
|
||||
# Execute streaming
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
updates: list[AgentRunResponseUpdate] = []
|
||||
async for update in agent.run_stream(messages):
|
||||
updates.append(update)
|
||||
|
||||
# Verify streaming response
|
||||
assert len(updates) == 2
|
||||
assert updates[0].text == "Streaming"
|
||||
assert updates[1].text == " response"
|
||||
assert chat_client.call_count == 1
|
||||
|
||||
# Verify middleware was called and streaming flag was set correctly
|
||||
assert execution_order == ["middleware_before", "middleware_after"]
|
||||
assert streaming_flags == [True] # Context should indicate streaming
|
||||
|
||||
async def test_non_streaming_vs_streaming_flag_validation(self, chat_client: "MockChatClient") -> None:
|
||||
"""Test that is_streaming flag is correctly set for different execution modes."""
|
||||
streaming_flags: list[bool] = []
|
||||
|
||||
class FlagTrackingMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
streaming_flags.append(context.is_streaming)
|
||||
await next(context)
|
||||
|
||||
# Create ChatAgent with middleware
|
||||
middleware = FlagTrackingMiddleware()
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[middleware])
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
|
||||
# Test non-streaming execution
|
||||
response = await agent.run(messages)
|
||||
assert response is not None
|
||||
|
||||
# Test streaming execution
|
||||
async for _ in agent.run_stream(messages):
|
||||
pass
|
||||
|
||||
# Verify flags: [non-streaming, streaming]
|
||||
assert streaming_flags == [False, True]
|
||||
|
||||
|
||||
class TestChatAgentMultipleMiddlewareOrdering:
|
||||
"""Test cases for multiple middleware execution order with ChatAgent."""
|
||||
|
||||
async def test_multiple_agent_middleware_execution_order(self, chat_client: "MockChatClient") -> None:
|
||||
"""Test that multiple agent middlewares execute in correct order with ChatAgent."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class OrderedMiddleware(AgentMiddleware):
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append(f"{self.name}_before")
|
||||
await next(context)
|
||||
execution_order.append(f"{self.name}_after")
|
||||
|
||||
# Create multiple middlewares
|
||||
middleware1 = OrderedMiddleware("first")
|
||||
middleware2 = OrderedMiddleware("second")
|
||||
middleware3 = OrderedMiddleware("third")
|
||||
|
||||
# Create ChatAgent with multiple middlewares
|
||||
agent = ChatAgent(chat_client=chat_client, middleware=[middleware1, middleware2, middleware3])
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert chat_client.call_count == 1
|
||||
|
||||
# Verify execution order (should be nested: first wraps second wraps third)
|
||||
expected_order = ["first_before", "second_before", "third_before", "third_after", "second_after", "first_after"]
|
||||
assert execution_order == expected_order
|
||||
|
||||
async def test_mixed_middleware_types_with_chat_agent(self, chat_client: "MockChatClient") -> None:
|
||||
"""Test mixed class and function-based middlewares with ChatAgent."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class ClassAgentMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("class_agent_before")
|
||||
await next(context)
|
||||
execution_order.append("class_agent_after")
|
||||
|
||||
async def function_agent_middleware(
|
||||
context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("function_agent_before")
|
||||
await next(context)
|
||||
execution_order.append("function_agent_after")
|
||||
|
||||
class ClassFunctionMiddleware(FunctionMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_order.append("class_function_before")
|
||||
await next(context)
|
||||
execution_order.append("class_function_after")
|
||||
|
||||
async def function_function_middleware(
|
||||
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("function_function_before")
|
||||
await next(context)
|
||||
execution_order.append("function_function_after")
|
||||
|
||||
# Create ChatAgent with mixed middleware types (no tools, focusing on agent middleware)
|
||||
agent = ChatAgent(
|
||||
chat_client=chat_client,
|
||||
middleware=[
|
||||
ClassAgentMiddleware(),
|
||||
function_agent_middleware,
|
||||
ClassFunctionMiddleware(), # Won't execute without function calls
|
||||
function_function_middleware, # Won't execute without function calls
|
||||
],
|
||||
)
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert chat_client.call_count == 1
|
||||
|
||||
# Verify that agent middlewares were executed in correct order
|
||||
# (Function middlewares won't execute since no functions are called)
|
||||
expected_order = ["class_agent_before", "function_agent_before", "function_agent_after", "class_agent_after"]
|
||||
assert execution_order == expected_order
|
||||
|
||||
|
||||
# region Tool Functions for Testing
|
||||
|
||||
|
||||
def sample_tool_function(location: str) -> str:
|
||||
"""A simple tool function for middleware testing."""
|
||||
return f"Weather in {location}: sunny"
|
||||
|
||||
|
||||
# region ChatAgent Function Middleware Tests with Tools
|
||||
|
||||
|
||||
class TestChatAgentFunctionMiddlewareWithTools:
|
||||
"""Test cases for function middleware integration with ChatAgent when tools are used."""
|
||||
|
||||
async def test_class_based_function_middleware_with_tool_calls(self, chat_client: "MockChatClient") -> None:
|
||||
"""Test class-based function middleware with ChatAgent when function calls are made."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class TrackingFunctionMiddleware(FunctionMiddleware):
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_order.append(f"{self.name}_before")
|
||||
await next(context)
|
||||
execution_order.append(f"{self.name}_after")
|
||||
|
||||
# Set up mock to return a function call first, then a regular response
|
||||
function_call_response = ChatResponse(
|
||||
messages=[
|
||||
ChatMessage(
|
||||
role=Role.ASSISTANT,
|
||||
contents=[
|
||||
FunctionCallContent(
|
||||
call_id="call_123",
|
||||
name="sample_tool_function",
|
||||
arguments='{"location": "Seattle"}',
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
final_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Final response")])
|
||||
|
||||
chat_client.responses = [function_call_response, final_response]
|
||||
|
||||
# Create ChatAgent with function middleware and tools
|
||||
middleware = TrackingFunctionMiddleware("function_middleware")
|
||||
agent = ChatAgent(
|
||||
chat_client=chat_client,
|
||||
middleware=[middleware],
|
||||
tools=[sample_tool_function],
|
||||
)
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="Get weather for Seattle")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert chat_client.call_count == 2 # Two calls: one for function call, one for final response
|
||||
|
||||
# Verify function middleware was executed
|
||||
assert execution_order == ["function_middleware_before", "function_middleware_after"]
|
||||
|
||||
# Verify function call and result are in the response
|
||||
all_contents = [content for message in response.messages for content in message.contents]
|
||||
function_calls = [c for c in all_contents if isinstance(c, FunctionCallContent)]
|
||||
function_results = [c for c in all_contents if isinstance(c, FunctionResultContent)]
|
||||
|
||||
assert len(function_calls) == 1
|
||||
assert len(function_results) == 1
|
||||
assert function_calls[0].name == "sample_tool_function"
|
||||
assert function_results[0].call_id == function_calls[0].call_id
|
||||
|
||||
async def test_function_based_function_middleware_with_tool_calls(self, chat_client: "MockChatClient") -> None:
|
||||
"""Test function-based function middleware with ChatAgent when function calls are made."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
async def tracking_function_middleware(
|
||||
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
) -> None:
|
||||
execution_order.append("function_middleware_before")
|
||||
await next(context)
|
||||
execution_order.append("function_middleware_after")
|
||||
|
||||
# Set up mock to return a function call first, then a regular response
|
||||
function_call_response = ChatResponse(
|
||||
messages=[
|
||||
ChatMessage(
|
||||
role=Role.ASSISTANT,
|
||||
contents=[
|
||||
FunctionCallContent(
|
||||
call_id="call_456",
|
||||
name="sample_tool_function",
|
||||
arguments='{"location": "San Francisco"}',
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
final_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Final response")])
|
||||
|
||||
chat_client.responses = [function_call_response, final_response]
|
||||
|
||||
# Create ChatAgent with function middleware and tools
|
||||
agent = ChatAgent(
|
||||
chat_client=chat_client,
|
||||
middleware=[tracking_function_middleware],
|
||||
tools=[sample_tool_function],
|
||||
)
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="Get weather for San Francisco")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert chat_client.call_count == 2 # Two calls: one for function call, one for final response
|
||||
|
||||
# Verify function middleware was executed
|
||||
assert execution_order == ["function_middleware_before", "function_middleware_after"]
|
||||
|
||||
# Verify function call and result are in the response
|
||||
all_contents = [content for message in response.messages for content in message.contents]
|
||||
function_calls = [c for c in all_contents if isinstance(c, FunctionCallContent)]
|
||||
function_results = [c for c in all_contents if isinstance(c, FunctionResultContent)]
|
||||
|
||||
assert len(function_calls) == 1
|
||||
assert len(function_results) == 1
|
||||
assert function_calls[0].name == "sample_tool_function"
|
||||
assert function_results[0].call_id == function_calls[0].call_id
|
||||
|
||||
async def test_mixed_agent_and_function_middleware_with_tool_calls(self, chat_client: "MockChatClient") -> None:
|
||||
"""Test both agent and function middleware with ChatAgent when function calls are made."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class TrackingAgentMiddleware(AgentMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: AgentRunContext,
|
||||
next: Callable[[AgentRunContext], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_order.append("agent_middleware_before")
|
||||
await next(context)
|
||||
execution_order.append("agent_middleware_after")
|
||||
|
||||
class TrackingFunctionMiddleware(FunctionMiddleware):
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
) -> None:
|
||||
execution_order.append("function_middleware_before")
|
||||
await next(context)
|
||||
execution_order.append("function_middleware_after")
|
||||
|
||||
# Set up mock to return a function call first, then a regular response
|
||||
function_call_response = ChatResponse(
|
||||
messages=[
|
||||
ChatMessage(
|
||||
role=Role.ASSISTANT,
|
||||
contents=[
|
||||
FunctionCallContent(
|
||||
call_id="call_789",
|
||||
name="sample_tool_function",
|
||||
arguments='{"location": "New York"}',
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
final_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Final response")])
|
||||
|
||||
chat_client.responses = [function_call_response, final_response]
|
||||
|
||||
# Create ChatAgent with both agent and function middleware and tools
|
||||
agent = ChatAgent(
|
||||
chat_client=chat_client,
|
||||
middleware=[TrackingAgentMiddleware(), TrackingFunctionMiddleware()],
|
||||
tools=[sample_tool_function],
|
||||
)
|
||||
|
||||
# Execute the agent
|
||||
messages = [ChatMessage(role=Role.USER, text="Get weather for New York")]
|
||||
response = await agent.run(messages)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert chat_client.call_count == 2 # Two calls: one for function call, one for final response
|
||||
|
||||
# Verify middleware execution order: agent middleware wraps everything,
|
||||
# function middleware only for function calls
|
||||
expected_order = [
|
||||
"agent_middleware_before",
|
||||
"function_middleware_before",
|
||||
"function_middleware_after",
|
||||
"agent_middleware_after",
|
||||
]
|
||||
assert execution_order == expected_order
|
||||
|
||||
# Verify function call and result are in the response
|
||||
all_contents = [content for message in response.messages for content in message.contents]
|
||||
function_calls = [c for c in all_contents if isinstance(c, FunctionCallContent)]
|
||||
function_results = [c for c in all_contents if isinstance(c, FunctionResultContent)]
|
||||
|
||||
assert len(function_calls) == 1
|
||||
assert len(function_results) == 1
|
||||
assert function_calls[0].name == "sample_tool_function"
|
||||
assert function_results[0].call_id == function_calls[0].call_id
|
||||
@@ -0,0 +1,125 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from random import randint
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import (
|
||||
AgentMiddleware,
|
||||
AgentRunContext,
|
||||
AgentRunResponse,
|
||||
ChatMessage,
|
||||
FunctionInvocationContext,
|
||||
FunctionMiddleware,
|
||||
Role,
|
||||
)
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from pydantic import Field
|
||||
|
||||
"""
|
||||
Class-based Middleware Example
|
||||
|
||||
This sample demonstrates how to implement middleware using class-based approach by inheriting
|
||||
from AgentMiddleware and FunctionMiddleware base classes. The example includes:
|
||||
|
||||
- SecurityAgentMiddleware: Checks for security violations in user queries and blocks requests
|
||||
containing sensitive information like passwords or secrets
|
||||
- LoggingFunctionMiddleware: Logs function execution details including timing and parameters
|
||||
|
||||
This approach is useful when you need stateful middleware or complex logic that benefits
|
||||
from object-oriented design patterns.
|
||||
"""
|
||||
|
||||
|
||||
def get_weather(
|
||||
location: Annotated[str, Field(description="The location to get the weather for.")],
|
||||
) -> str:
|
||||
"""Get the weather for a given location."""
|
||||
conditions = ["sunny", "cloudy", "rainy", "stormy"]
|
||||
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
|
||||
|
||||
|
||||
class SecurityAgentMiddleware(AgentMiddleware):
|
||||
"""Agent middleware that checks for security violations."""
|
||||
|
||||
async def process(
|
||||
self,
|
||||
context: AgentRunContext,
|
||||
next: Callable[[AgentRunContext], Awaitable[None]],
|
||||
) -> None:
|
||||
# Check for potential security violations in the query
|
||||
# Look at the last user message
|
||||
last_message = context.messages[-1] if context.messages else None
|
||||
if last_message and last_message.text:
|
||||
query = last_message.text
|
||||
if "password" in query.lower() or "secret" in query.lower():
|
||||
print("[SecurityAgentMiddleware] Security Warning: Detected sensitive information, blocking request.")
|
||||
# Override the result with warning message
|
||||
context.result = AgentRunResponse(
|
||||
messages=[
|
||||
ChatMessage(role=Role.ASSISTANT, text="Detected sensitive information, the request is blocked.")
|
||||
]
|
||||
)
|
||||
# Simply don't call next() to prevent execution
|
||||
return
|
||||
|
||||
print("[SecurityAgentMiddleware] Security check passed.")
|
||||
await next(context)
|
||||
|
||||
|
||||
class LoggingFunctionMiddleware(FunctionMiddleware):
|
||||
"""Function middleware that logs function calls."""
|
||||
|
||||
async def process(
|
||||
self,
|
||||
context: FunctionInvocationContext,
|
||||
next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
) -> None:
|
||||
function_name = context.function.name
|
||||
print(f"[LoggingFunctionMiddleware] About to call function: {function_name}.")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
await next(context)
|
||||
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
|
||||
print(f"[LoggingFunctionMiddleware] Function {function_name} completed in {duration:.5f}s.")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Example demonstrating class-based middleware."""
|
||||
print("=== Class-based Middleware Example ===")
|
||||
|
||||
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
|
||||
# authentication option.
|
||||
async with (
|
||||
AzureCliCredential() as credential,
|
||||
FoundryChatClient(async_credential=credential).create_agent(
|
||||
name="WeatherAgent",
|
||||
instructions="You are a helpful weather assistant.",
|
||||
tools=get_weather,
|
||||
middleware=[SecurityAgentMiddleware(), LoggingFunctionMiddleware()],
|
||||
) as agent,
|
||||
):
|
||||
# Test with normal query
|
||||
print("\n--- Normal Query ---")
|
||||
query = "What's the weather like in Seattle?"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Agent: {result.text}\n")
|
||||
|
||||
# Test with security-related query
|
||||
print("--- Security Test ---")
|
||||
query = "What's the password for the weather service?"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Agent: {result.text}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,75 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import FunctionInvocationContext
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from pydantic import Field
|
||||
|
||||
"""
|
||||
Exception Handling with Middleware
|
||||
|
||||
This sample demonstrates how to use middleware for centralized exception handling in function calls.
|
||||
The example shows:
|
||||
|
||||
- How to catch exceptions thrown by functions and provide graceful error responses
|
||||
- Overriding function results when errors occur to provide user-friendly messages
|
||||
- Using middleware to implement retry logic, fallback mechanisms, or error reporting
|
||||
|
||||
The middleware catches TimeoutError from an unstable data service and replaces it with
|
||||
a helpful message for the user, preventing raw exceptions from reaching the end user.
|
||||
"""
|
||||
|
||||
|
||||
def unstable_data_service(
|
||||
query: Annotated[str, Field(description="The data query to execute.")],
|
||||
) -> str:
|
||||
"""A simulated data service that sometimes throws exceptions."""
|
||||
# Simulate failure
|
||||
raise TimeoutError("Data service request timed out")
|
||||
|
||||
|
||||
async def exception_handling_middleware(
|
||||
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
) -> None:
|
||||
function_name = context.function.name
|
||||
|
||||
try:
|
||||
print(f"[ExceptionHandlingMiddleware] Executing function: {function_name}")
|
||||
await next(context)
|
||||
print(f"[ExceptionHandlingMiddleware] Function {function_name} completed successfully.")
|
||||
except TimeoutError as e:
|
||||
print(f"[ExceptionHandlingMiddleware] Caught TimeoutError: {e}")
|
||||
# Override function result to provide custom message in response.
|
||||
context.result = (
|
||||
"Request Timeout: The data service is taking longer than expected to respond.",
|
||||
"Respond with message - 'Sorry for the inconvenience, please try again later.'",
|
||||
)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Example demonstrating exception handling with middleware."""
|
||||
print("=== Exception Handling Middleware Example ===")
|
||||
|
||||
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
|
||||
# authentication option.
|
||||
async with (
|
||||
AzureCliCredential() as credential,
|
||||
FoundryChatClient(async_credential=credential).create_agent(
|
||||
name="DataAgent",
|
||||
instructions="You are a helpful data assistant. Use the data service tool to fetch information for users.",
|
||||
tools=unstable_data_service,
|
||||
middleware=exception_handling_middleware,
|
||||
) as agent,
|
||||
):
|
||||
query = "Get user statistics"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Agent: {result}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,109 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from random import randint
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import (
|
||||
AgentRunContext,
|
||||
FunctionInvocationContext,
|
||||
)
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from pydantic import Field
|
||||
|
||||
"""
|
||||
Function-based Middleware Example
|
||||
|
||||
This sample demonstrates how to implement middleware using simple async functions instead of classes.
|
||||
The example includes:
|
||||
|
||||
- Security middleware that validates agent requests for sensitive information
|
||||
- Logging middleware that tracks function execution timing and parameters
|
||||
- Performance monitoring to measure execution duration
|
||||
|
||||
Function-based middleware is ideal for simple, stateless operations and provides a more
|
||||
lightweight approach compared to class-based middleware. Both agent and function middleware
|
||||
can be implemented as async functions that accept context and next parameters.
|
||||
"""
|
||||
|
||||
|
||||
def get_weather(
|
||||
location: Annotated[str, Field(description="The location to get the weather for.")],
|
||||
) -> str:
|
||||
"""Get the weather for a given location."""
|
||||
conditions = ["sunny", "cloudy", "rainy", "stormy"]
|
||||
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
|
||||
|
||||
|
||||
async def security_agent_middleware(
|
||||
context: AgentRunContext,
|
||||
next: Callable[[AgentRunContext], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Agent middleware that checks for security violations."""
|
||||
# Check for potential security violations in the query
|
||||
# For this example, we'll check the last user message
|
||||
last_message = context.messages[-1] if context.messages else None
|
||||
if last_message and last_message.text:
|
||||
query = last_message.text
|
||||
if "password" in query.lower() or "secret" in query.lower():
|
||||
print("[SecurityAgentMiddleware] Security Warning: Detected sensitive information, blocking request.")
|
||||
# Simply don't call next() to prevent execution
|
||||
return
|
||||
|
||||
print("[SecurityAgentMiddleware] Security check passed.")
|
||||
await next(context)
|
||||
|
||||
|
||||
async def logging_function_middleware(
|
||||
context: FunctionInvocationContext,
|
||||
next: Callable[[FunctionInvocationContext], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Function middleware that logs function calls."""
|
||||
function_name = context.function.name
|
||||
print(f"[LoggingFunctionMiddleware] About to call function: {function_name}.")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
await next(context)
|
||||
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
|
||||
print(f"[LoggingFunctionMiddleware] Function {function_name} completed in {duration:.5f}s.")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Example demonstrating function-based middleware."""
|
||||
print("=== Function-based Middleware Example ===")
|
||||
|
||||
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
|
||||
# authentication option.
|
||||
async with (
|
||||
AzureCliCredential() as credential,
|
||||
FoundryChatClient(async_credential=credential).create_agent(
|
||||
name="WeatherAgent",
|
||||
instructions="You are a helpful weather assistant.",
|
||||
tools=get_weather,
|
||||
middleware=[security_agent_middleware, logging_function_middleware],
|
||||
) as agent,
|
||||
):
|
||||
# Test with normal query
|
||||
print("\n--- Normal Query ---")
|
||||
query = "What's the weather like in Tokyo?"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Agent: {result.text if result.text else 'No response'}\n")
|
||||
|
||||
# Test with security violation
|
||||
print("--- Security Test ---")
|
||||
query = "What's the secret weather password?"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Agent: {result.text if result.text else 'No response'}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,84 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from random import randint
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import FunctionInvocationContext
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from pydantic import Field
|
||||
|
||||
"""
|
||||
Result Override with Middleware
|
||||
|
||||
This sample demonstrates how to use middleware to intercept and modify function results
|
||||
after execution. The example shows:
|
||||
|
||||
- How to execute the original function first and then modify its result
|
||||
- Replacing function outputs with custom messages or transformed data
|
||||
- Using middleware for result filtering, formatting, or enhancement
|
||||
|
||||
The weather override middleware lets the original weather function execute normally,
|
||||
then replaces its result with a custom "perfect weather" message, demonstrating
|
||||
how middleware can be used for content filtering, A/B testing, or result enhancement.
|
||||
"""
|
||||
|
||||
|
||||
def get_weather(
|
||||
location: Annotated[str, Field(description="The location to get the weather for.")],
|
||||
) -> str:
|
||||
"""Get the weather for a given location."""
|
||||
conditions = ["sunny", "cloudy", "rainy", "stormy"]
|
||||
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
|
||||
|
||||
|
||||
async def weather_override_middleware(
|
||||
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
|
||||
) -> None:
|
||||
function_name = context.function.name
|
||||
|
||||
# Let the original function execute first
|
||||
await next(context)
|
||||
|
||||
# Override the result if it's a weather function
|
||||
if function_name == "get_weather" and context.result is not None:
|
||||
original_result = str(context.result)
|
||||
print(f"[WeatherOverrideMiddleware] Original result: {original_result}")
|
||||
|
||||
# Override with a custom message
|
||||
# It's also possible to override the result before "next()" call if needed
|
||||
custom_message = (
|
||||
"Weather Advisory - due to special atmospheric conditions, "
|
||||
"all locations are experiencing perfect weather today! "
|
||||
"Temperature is a comfortable 22°C with gentle breezes. "
|
||||
"Perfect day for outdoor activities!"
|
||||
)
|
||||
context.result = custom_message
|
||||
print(f"[WeatherOverrideMiddleware] Overriding with custom message: {custom_message}")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Example demonstrating result override with middleware."""
|
||||
print("=== Result Override Middleware Example ===")
|
||||
|
||||
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
|
||||
# authentication option.
|
||||
async with (
|
||||
AzureCliCredential() as credential,
|
||||
FoundryChatClient(async_credential=credential).create_agent(
|
||||
name="WeatherAgent",
|
||||
instructions="You are a helpful weather assistant. Use the weather tool to get current conditions.",
|
||||
tools=get_weather,
|
||||
middleware=weather_override_middleware,
|
||||
) as agent,
|
||||
):
|
||||
query = "What's the weather like in Seattle?"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query)
|
||||
print(f"Agent: {result}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user