Python: Extending middleware capabilities (#844)

* Implemented termination

* Added termination sample

* Allowed middleware pipeline modification

* Added run-level middleware

* Added more validation to function-based middleware

* Added example with function-based decorator approach

* Update python/samples/getting_started/middleware/decorator_middleware.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update python/samples/getting_started/middleware/decorator_middleware.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Small improvements

* Fixed tests

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Dmytro Struk
2025-09-21 21:37:57 -07:00
committed by GitHub
Unverified
parent 08f792e511
commit f61d8abe58
6 changed files with 1680 additions and 119 deletions
@@ -3,6 +3,7 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable, Awaitable, Callable
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar
from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage
@@ -15,12 +16,22 @@ if TYPE_CHECKING:
TAgent = TypeVar("TAgent", bound="AgentProtocol")
class MiddlewareType(Enum):
    """Closed set of middleware kinds used to tag middleware callables.

    The member values are plain strings so they can be embedded directly in
    error messages via ``.value``.
    """

    # Middleware that processes agent runs.
    AGENT = "agent"
    # Middleware that processes function invocations.
    FUNCTION = "function"
# Names exported by ``from <this module> import *``.
__all__ = [
    "AgentMiddleware",
    "AgentRunContext",
    "FunctionInvocationContext",
    "FunctionMiddleware",
    "Middleware",
    "agent_middleware",
    "function_middleware",
    "use_agent_middleware",
]
@@ -38,6 +49,8 @@ class AgentRunContext:
to see the actual execution result or can be set to override the execution result.
For non-streaming: should be AgentRunResponse
For streaming: should be AsyncIterable[AgentRunResponseUpdate]
terminate: A flag indicating whether to terminate execution after current middleware.
When set to True, execution will stop as soon as control returns to framework.
"""
agent: "AgentProtocol"
@@ -45,6 +58,7 @@ class AgentRunContext:
is_streaming: bool = False
metadata: dict[str, Any] = field(default_factory=lambda: {})
result: AgentRunResponse | AsyncIterable[AgentRunResponseUpdate] | None = None
terminate: bool = False
@dataclass
@@ -57,12 +71,15 @@ class FunctionInvocationContext:
metadata: Metadata dictionary for sharing data between function middleware.
result: Function execution result. Can be observed after calling next()
to see the actual execution result or can be set to override the execution result.
terminate: A flag indicating whether to terminate execution after current middleware.
When set to True, execution will stop as soon as control returns to framework.
"""
function: "AIFunction[Any, Any]"
arguments: "BaseModel"
metadata: dict[str, Any] = field(default_factory=lambda: {})
result: Any = None
terminate: bool = False
class AgentMiddleware(ABC):
@@ -131,6 +148,53 @@ FunctionMiddlewareCallable = Callable[
Middleware: TypeAlias = AgentMiddleware | AgentMiddlewareCallable | FunctionMiddleware | FunctionMiddlewareCallable
# Middleware type markers for decorators
def agent_middleware(func: AgentMiddlewareCallable) -> AgentMiddlewareCallable:
    """Tag a callable as agent middleware.

    The callable is returned unchanged apart from a ``_middleware_type``
    attribute set to ``MiddlewareType.AGENT``, which explicitly identifies it
    as middleware operating on ``AgentRunContext`` objects.

    Args:
        func: The middleware function to mark as agent middleware.

    Returns:
        The same function object, now carrying the agent middleware marker.

    Example:
        @agent_middleware
        async def my_middleware(context: AgentRunContext, next):
            # Process agent invocation
            await next(context)
    """
    # The marker is stored on the function object itself; no wrapper is created.
    func._middleware_type = MiddlewareType.AGENT  # type: ignore
    return func
def function_middleware(func: FunctionMiddlewareCallable) -> FunctionMiddlewareCallable:
    """Tag a callable as function middleware.

    The callable is returned unchanged apart from a ``_middleware_type``
    attribute set to ``MiddlewareType.FUNCTION``, which explicitly identifies
    it as middleware operating on ``FunctionInvocationContext`` objects.

    Args:
        func: The middleware function to mark as function middleware.

    Returns:
        The same function object, now carrying the function middleware marker.

    Example:
        @function_middleware
        async def my_middleware(context: FunctionInvocationContext, next):
            # Process function invocation
            await next(context)
    """
    # The marker is stored on the function object itself; no wrapper is created.
    func._middleware_type = MiddlewareType.FUNCTION  # type: ignore
    return func
class AgentMiddlewareWrapper(AgentMiddleware):
"""Wrapper to convert pure functions into AgentMiddleware protocol objects."""
@@ -271,6 +335,10 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline):
if index >= len(self._middlewares):
async def final_wrapper(c: AgentRunContext) -> None:
# If terminate was set, skip execution
if c.terminate:
return
# Execute actual handler and populate context for observability
result = await final_handler(c)
result_container["result"] = result
@@ -282,6 +350,10 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline):
next_handler = create_next_handler(index + 1)
async def current_handler(c: AgentRunContext) -> None:
# If terminate is set, don't continue the pipeline
if c.terminate:
return
await middleware.process(c, next_handler)
# After middleware execution, check if response was overridden
if c.result is not None and isinstance(c.result, AgentRunResponse):
@@ -337,6 +409,10 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline):
if index >= len(self._middlewares):
async def final_wrapper(c: AgentRunContext) -> None: # noqa: RUF029
# If terminate was set, skip execution
if c.terminate:
return
# Execute actual handler and populate context for observability
result = final_handler(c)
result_container["result_stream"] = result
@@ -349,6 +425,9 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline):
async def current_handler(c: AgentRunContext) -> None:
    """Run this stage's middleware unless termination was already requested.

    ``middleware`` and ``next_handler`` are free variables supplied by the
    enclosing scope.

    Args:
        c: The run context shared by every stage of the pipeline.
    """
    # The guard must come BEFORE invoking the middleware: a check placed after
    # ``process`` is the last statement of the handler and therefore has no
    # effect. Checking first means that once an earlier middleware sets
    # ``terminate`` and calls ``next()``, this stage is skipped.
    if c.terminate:
        return
    await middleware.process(c, next_handler)
return current_handler
@@ -424,8 +503,8 @@ class FunctionMiddlewarePipeline(BaseMiddlewarePipeline):
# Custom final handler that handles pre-existing results
async def function_final_handler(c: FunctionInvocationContext) -> Any:
# If result was set before calling next(), skip execution
if c.result is not None:
# If terminate was set, skip execution and return the result (which might be None)
if c.terminate:
return c.result
# Execute actual handler and populate context for observability
return await final_handler(c)
@@ -446,6 +525,10 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]:
This decorator adds middleware functionality to any agent class.
It wraps the run() and run_stream() methods to provide middleware execution.
The middleware execution can be terminated at any point by setting the
context.terminate property to True. Once set, the pipeline will stop executing
further middleware as soon as control returns to the pipeline.
Args:
agent_class: The agent class to add middleware support to.
@@ -458,12 +541,101 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]:
original_run = agent_class.run # type: ignore[attr-defined]
original_run_stream = agent_class.run_stream # type: ignore[attr-defined]
def _initialize_middleware_pipelines(self: Any, middlewares: Middleware | list[Middleware] | None) -> None:
"""Initialize agent and function middleware pipelines from the provided middleware list."""
if not middlewares:
return
def _determine_middleware_type(middleware: Any) -> MiddlewareType:
"""Determine middleware type using decorator and/or parameter type annotation.
middleware_list: list[Middleware] = middlewares if isinstance(middlewares, list) else [middlewares] # type: ignore
Args:
middleware: The middleware function to analyze.
Returns:
MiddlewareType.AGENT or MiddlewareType.FUNCTION indicating the middleware type.
Raises:
ValueError: When middleware type cannot be determined or there's a mismatch.
"""
# Check for decorator marker
decorator_type: MiddlewareType | None = getattr(middleware, "_middleware_type", None)
# Check for parameter type annotation
param_type: MiddlewareType | None = None
try:
sig = inspect.signature(middleware)
params = list(sig.parameters.values())
# Must have at least 2 parameters (context and next)
if len(params) >= 2:
first_param = params[0]
if hasattr(first_param.annotation, "__name__"):
annotation_name = first_param.annotation.__name__
if annotation_name == "AgentRunContext":
param_type = MiddlewareType.AGENT
elif annotation_name == "FunctionInvocationContext":
param_type = MiddlewareType.FUNCTION
else:
# Not enough parameters - can't be valid middleware
raise ValueError(
f"Middleware function must have at least 2 parameters (context, next), "
f"but {middleware.__name__} has {len(params)}"
)
except Exception as e:
if isinstance(e, ValueError):
raise # Re-raise our custom errors
# Signature inspection failed - continue with other checks
pass
if decorator_type and param_type:
# Both decorator and parameter type specified - they must match
if decorator_type != param_type:
raise ValueError(
f"Middleware type mismatch: decorator indicates '{decorator_type.value}' "
f"but parameter type indicates '{param_type.value}' for function {middleware.__name__}"
)
return decorator_type
if decorator_type:
# Just decorator specified - rely on decorator
return decorator_type
if param_type:
# Just parameter type specified - rely on types
return param_type
# Neither decorator nor parameter type specified - throw exception
raise ValueError(
f"Cannot determine middleware type for function {middleware.__name__}. "
f"Please either use @agent_middleware/@function_middleware decorators "
f"or specify parameter types (AgentRunContext or FunctionInvocationContext)."
)
def _build_middleware_pipelines(
agent_level_middlewares: Middleware | list[Middleware] | None,
run_level_middlewares: Middleware | list[Middleware] | None = None,
) -> tuple[AgentMiddlewarePipeline, FunctionMiddlewarePipeline]:
"""Build fresh agent and function middleware pipelines from the provided middleware lists.
Args:
agent_level_middlewares: Agent-level middleware (executed first)
run_level_middlewares: Run-level middleware (executed after agent middleware)
"""
# Merge middleware lists: agent middleware first, then run middleware
combined_middlewares: list[Middleware] = []
if agent_level_middlewares:
if isinstance(agent_level_middlewares, list):
combined_middlewares.extend(agent_level_middlewares) # type: ignore[arg-type]
else:
combined_middlewares.append(agent_level_middlewares)
if run_level_middlewares:
if isinstance(run_level_middlewares, list):
combined_middlewares.extend(run_level_middlewares) # type: ignore[arg-type]
else:
combined_middlewares.append(run_level_middlewares)
if not combined_middlewares:
return AgentMiddlewarePipeline(), FunctionMiddlewarePipeline()
middleware_list = combined_middlewares
# Separate agent and function middleware using isinstance checks
agent_middlewares: list[AgentMiddleware | AgentMiddlewareCallable] = []
@@ -475,75 +647,42 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]:
elif isinstance(middleware, FunctionMiddleware):
function_middlewares.append(middleware)
elif callable(middleware): # type: ignore[arg-type]
# Check function signature to determine type
try:
sig = inspect.signature(middleware)
params = list(sig.parameters.values())
if len(params) >= 1:
first_param = params[0]
# Check if first parameter is AgentRunContext or FunctionInvocationContext
if (
hasattr(first_param.annotation, "__name__")
and first_param.annotation.__name__ == "AgentRunContext"
):
agent_middlewares.append(middleware) # type: ignore
elif (
hasattr(first_param.annotation, "__name__")
and first_param.annotation.__name__ == "FunctionInvocationContext"
):
function_middlewares.append(middleware) # type: ignore
else:
# Default to agent middleware if uncertain
agent_middlewares.append(middleware) # type: ignore
else:
agent_middlewares.append(middleware) # type: ignore
except Exception:
# If signature inspection fails, assume it's an agent middleware
# Determine middleware type using decorator and/or parameter type annotation
middleware_type = _determine_middleware_type(middleware)
if middleware_type == MiddlewareType.AGENT:
agent_middlewares.append(middleware) # type: ignore
elif middleware_type == MiddlewareType.FUNCTION:
function_middlewares.append(middleware) # type: ignore
else:
# This should not happen if _determine_middleware_type is implemented correctly
raise ValueError(f"Unknown middleware type: {middleware_type}")
else:
# Fallback
agent_middlewares.append(middleware) # type: ignore
self._agent_middleware_pipeline = AgentMiddlewarePipeline(agent_middlewares)
self._function_middleware_pipeline = FunctionMiddlewarePipeline(function_middlewares)
return AgentMiddlewarePipeline(agent_middlewares), FunctionMiddlewarePipeline(function_middlewares)
async def middleware_enabled_run(
self: Any,
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
*,
thread: Any = None,
middleware: Middleware | list[Middleware] | None = None,
**kwargs: Any,
) -> AgentRunResponse:
"""Middleware-enabled run method."""
# Initialize middleware pipelines if not already done
if (
hasattr(self, "middleware")
and self.middleware
and not (
hasattr(self, "_agent_middleware_pipeline")
and hasattr(self, "_function_middleware_pipeline")
and (
self._agent_middleware_pipeline.has_middlewares
or self._function_middleware_pipeline.has_middlewares
)
)
):
_initialize_middleware_pipelines(self, self.middleware)
# Ensure pipelines exist even if empty
if not hasattr(self, "_agent_middleware_pipeline"):
self._agent_middleware_pipeline = AgentMiddlewarePipeline()
if not hasattr(self, "_function_middleware_pipeline"):
self._function_middleware_pipeline = FunctionMiddlewarePipeline()
# Build fresh middleware pipelines from current middleware collection and run-level middleware
agent_middleware = getattr(self, "middleware", None)
agent_pipeline, function_pipeline = _build_middleware_pipelines(agent_middleware, middleware)
# Add function middleware pipeline to kwargs if available
if self._function_middleware_pipeline.has_middlewares:
kwargs["_function_middleware_pipeline"] = self._function_middleware_pipeline
if function_pipeline.has_middlewares:
kwargs["_function_middleware_pipeline"] = function_pipeline
normalized_messages = self._normalize_messages(messages)
# Execute with middleware if available
if self._agent_middleware_pipeline.has_middlewares:
if agent_pipeline.has_middlewares:
context = AgentRunContext(
agent=self, # type: ignore[arg-type]
messages=normalized_messages,
@@ -553,7 +692,7 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]:
async def _execute_handler(ctx: AgentRunContext) -> AgentRunResponse:
return await original_run(self, ctx.messages, thread=thread, **kwargs) # type: ignore
result = await self._agent_middleware_pipeline.execute(
result = await agent_pipeline.execute(
self, # type: ignore[arg-type]
normalized_messages,
context,
@@ -570,38 +709,22 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]:
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
*,
thread: Any = None,
middleware: Middleware | list[Middleware] | None = None,
**kwargs: Any,
) -> AsyncIterable[AgentRunResponseUpdate]:
"""Middleware-enabled run_stream method."""
# Initialize middleware pipelines if not already done
if (
hasattr(self, "middleware")
and self.middleware
and not (
hasattr(self, "_agent_middleware_pipeline")
and hasattr(self, "_function_middleware_pipeline")
and (
self._agent_middleware_pipeline.has_middlewares
or self._function_middleware_pipeline.has_middlewares
)
)
):
_initialize_middleware_pipelines(self, self.middleware)
# Ensure pipelines exist even if empty
if not hasattr(self, "_agent_middleware_pipeline"):
self._agent_middleware_pipeline = AgentMiddlewarePipeline()
if not hasattr(self, "_function_middleware_pipeline"):
self._function_middleware_pipeline = FunctionMiddlewarePipeline()
# Build fresh middleware pipelines from current middleware collection and run-level middleware
agent_middleware = getattr(self, "middleware", None)
agent_pipeline, function_pipeline = _build_middleware_pipelines(agent_middleware, middleware)
# Add function middleware pipeline to kwargs if available
if self._function_middleware_pipeline.has_middlewares:
kwargs["_function_middleware_pipeline"] = self._function_middleware_pipeline
if function_pipeline.has_middlewares:
kwargs["_function_middleware_pipeline"] = function_pipeline
normalized_messages = self._normalize_messages(messages)
# Execute with middleware if available
if self._agent_middleware_pipeline.has_middlewares:
if agent_pipeline.has_middlewares:
context = AgentRunContext(
agent=self, # type: ignore[arg-type]
messages=normalized_messages,
@@ -613,7 +736,7 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]:
yield update
async def _stream_generator() -> AsyncIterable[AgentRunResponseUpdate]:
async for update in self._agent_middleware_pipeline.execute_stream(
async for update in agent_pipeline.execute_stream(
self, # type: ignore[arg-type]
normalized_messages,
context,
@@ -77,6 +77,16 @@ class TestFunctionInvocationContext:
class TestAgentMiddlewarePipeline:
"""Test cases for AgentMiddlewarePipeline."""
class PreNextTerminateMiddleware(AgentMiddleware):
    """Agent middleware that sets ``terminate`` and then still calls ``next``."""

    async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None:
        # Order matters: the flag is raised BEFORE delegating down the pipeline.
        context.terminate = True
        await next(context)
class PostNextTerminateMiddleware(AgentMiddleware):
    """Agent middleware that calls ``next`` first and sets ``terminate`` afterwards."""

    async def process(self, context: AgentRunContext, next: Any) -> None:
        # Order matters: the rest of the pipeline runs BEFORE the flag is raised.
        await next(context)
        context.terminate = True
def test_init_empty(self) -> None:
"""Test AgentMiddlewarePipeline initialization with no middlewares."""
pipeline = AgentMiddlewarePipeline()
@@ -194,10 +204,143 @@ class TestAgentMiddlewarePipeline:
assert updates[1].text == "chunk2"
assert execution_order == ["test_before", "test_after", "handler_start", "handler_end"]
async def test_execute_with_pre_next_termination(self, mock_agent: AgentProtocol) -> None:
    """Terminating before ``next()`` must keep the final handler from running."""
    handler_calls: list[str] = []

    async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
        # Must stay unreached: termination was requested before next().
        handler_calls.append("handler")
        reply = ChatMessage(role=Role.ASSISTANT, text="response")
        return AgentRunResponse(messages=[reply])

    user_messages = [ChatMessage(role=Role.USER, text="test")]
    run_context = AgentRunContext(agent=mock_agent, messages=user_messages)
    pipeline = AgentMiddlewarePipeline([self.PreNextTerminateMiddleware()])

    response = await pipeline.execute(mock_agent, user_messages, run_context, final_handler)

    assert response is not None
    assert run_context.terminate
    # No handler invocation and therefore no messages in the response.
    assert handler_calls == []
    assert not response.messages
async def test_execute_with_post_next_termination(self, mock_agent: AgentProtocol) -> None:
    """Terminating after ``next()`` must leave the handler's response intact."""
    handler_calls: list[str] = []

    async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
        handler_calls.append("handler")
        reply = ChatMessage(role=Role.ASSISTANT, text="response")
        return AgentRunResponse(messages=[reply])

    user_messages = [ChatMessage(role=Role.USER, text="test")]
    run_context = AgentRunContext(agent=mock_agent, messages=user_messages)
    pipeline = AgentMiddlewarePipeline([self.PostNextTerminateMiddleware()])

    response = await pipeline.execute(mock_agent, user_messages, run_context, final_handler)

    assert response is not None
    # Exactly one message, carrying the handler's text.
    assert [message.text for message in response.messages] == ["response"]
    assert run_context.terminate
    assert handler_calls == ["handler"]
async def test_execute_stream_with_pre_next_termination(self, mock_agent: AgentProtocol) -> None:
    """Streaming: terminating before ``next()`` must yield nothing and skip the handler."""
    handler_events: list[str] = []

    async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]:
        # Must stay unreached: termination was requested before next().
        handler_events.append("handler_start")
        for chunk_text in ("chunk1", "chunk2"):
            yield AgentRunResponseUpdate(contents=[TextContent(text=chunk_text)])
        handler_events.append("handler_end")

    user_messages = [ChatMessage(role=Role.USER, text="test")]
    run_context = AgentRunContext(agent=mock_agent, messages=user_messages)
    pipeline = AgentMiddlewarePipeline([self.PreNextTerminateMiddleware()])

    stream = pipeline.execute_stream(mock_agent, user_messages, run_context, final_handler)
    updates: list[AgentRunResponseUpdate] = [update async for update in stream]

    assert run_context.terminate
    # Neither the handler ran nor were any updates produced.
    assert handler_events == []
    assert not updates
async def test_execute_stream_with_post_next_termination(self, mock_agent: AgentProtocol) -> None:
    """Streaming: terminating after ``next()`` must still deliver every update."""
    handler_events: list[str] = []

    async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]:
        handler_events.append("handler_start")
        for chunk_text in ("chunk1", "chunk2"):
            yield AgentRunResponseUpdate(contents=[TextContent(text=chunk_text)])
        handler_events.append("handler_end")

    user_messages = [ChatMessage(role=Role.USER, text="test")]
    run_context = AgentRunContext(agent=mock_agent, messages=user_messages)
    pipeline = AgentMiddlewarePipeline([self.PostNextTerminateMiddleware()])

    stream = pipeline.execute_stream(mock_agent, user_messages, run_context, final_handler)
    updates: list[AgentRunResponseUpdate] = [update async for update in stream]

    # Both chunks arrive, in order.
    assert [update.text for update in updates] == ["chunk1", "chunk2"]
    assert run_context.terminate
    assert handler_events == ["handler_start", "handler_end"]
class TestFunctionMiddlewarePipeline:
"""Test cases for FunctionMiddlewarePipeline."""
class PreNextTerminateFunctionMiddleware(FunctionMiddleware):
    """Function middleware that sets ``terminate`` and then still calls ``next``."""

    async def process(self, context: FunctionInvocationContext, next: Any) -> None:
        # Order matters: the flag is raised BEFORE delegating down the pipeline.
        context.terminate = True
        await next(context)
class PostNextTerminateFunctionMiddleware(FunctionMiddleware):
    """Function middleware that calls ``next`` first and sets ``terminate`` afterwards."""

    async def process(self, context: FunctionInvocationContext, next: Any) -> None:
        # Order matters: the rest of the pipeline runs BEFORE the flag is raised.
        await next(context)
        context.terminate = True
async def test_execute_with_pre_next_termination(self, mock_function: AIFunction[Any, Any]) -> None:
    """Terminating before ``next()`` must skip the function and produce ``None``."""
    handler_calls: list[str] = []

    async def final_handler(ctx: FunctionInvocationContext) -> str:
        # Must stay unreached: termination was requested before next().
        handler_calls.append("handler")
        return "test result"

    call_args = FunctionTestArgs(name="test")
    invocation = FunctionInvocationContext(function=mock_function, arguments=call_args)
    pipeline = FunctionMiddlewarePipeline([self.PreNextTerminateFunctionMiddleware()])

    result = await pipeline.execute(mock_function, call_args, invocation, final_handler)

    assert result is None
    assert invocation.terminate
    # The handler was never invoked.
    assert handler_calls == []
async def test_execute_with_post_next_termination(self, mock_function: AIFunction[Any, Any]) -> None:
    """Terminating after ``next()`` must keep the function's own result."""
    handler_calls: list[str] = []

    async def final_handler(ctx: FunctionInvocationContext) -> str:
        handler_calls.append("handler")
        return "test result"

    call_args = FunctionTestArgs(name="test")
    invocation = FunctionInvocationContext(function=mock_function, arguments=call_args)
    pipeline = FunctionMiddlewarePipeline([self.PostNextTerminateFunctionMiddleware()])

    result = await pipeline.execute(mock_function, call_args, invocation, final_handler)

    assert result == "test result"
    assert invocation.terminate
    assert handler_calls == ["handler"]
def test_init_empty(self) -> None:
"""Test FunctionMiddlewarePipeline initialization with no middlewares."""
pipeline = FunctionMiddlewarePipeline()
@@ -884,44 +1027,6 @@ class TestMiddlewareExecutionControl:
assert result.messages == [] # Empty response
assert not handler_called
async def test_function_middleware_pre_execution_override_with_next(
    self, mock_function: AIFunction[Any, Any]
) -> None:
    """Check that a result set before ``next()`` is returned and the handler is skipped."""

    class FunctionTestArgs(BaseModel):
        name: str = Field(description="Test name parameter")

    class PreOverrideFunctionMiddleware(FunctionMiddleware):
        async def process(
            self,
            context: FunctionInvocationContext,
            next: Callable[[FunctionInvocationContext], Awaitable[None]],
        ) -> None:
            # Override first, then continue down the pipeline.
            context.result = "pre-override result"
            await next(context)

    handler_calls: list[str] = []

    async def final_handler(ctx: FunctionInvocationContext) -> str:
        # Expected to stay unreached because the result was pre-set.
        handler_calls.append("handler")
        return "original result"

    call_args = FunctionTestArgs(name="test")
    invocation = FunctionInvocationContext(function=mock_function, arguments=call_args)
    pipeline = FunctionMiddlewarePipeline([PreOverrideFunctionMiddleware()])

    result = await pipeline.execute(mock_function, call_args, invocation, final_handler)

    # The pre-set value wins and the handler never ran.
    assert result == "pre-override result"
    assert not handler_calls
@pytest.fixture
def mock_agent() -> AgentProtocol:
@@ -1,6 +1,9 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import Awaitable, Callable
from typing import Any
import pytest
from agent_framework import (
AgentRunResponseUpdate,
@@ -12,12 +15,15 @@ from agent_framework import (
FunctionResultContent,
Role,
TextContent,
agent_middleware,
function_middleware,
)
from agent_framework._middleware import (
AgentMiddleware,
AgentRunContext,
FunctionInvocationContext,
FunctionMiddleware,
MiddlewareType,
)
from .conftest import MockChatClient
@@ -98,6 +104,170 @@ class TestChatAgentClassBasedMiddleware:
class TestChatAgentFunctionBasedMiddleware:
"""Test cases for function-based middleware integration with ChatAgent."""
async def test_agent_middleware_with_pre_termination(self, chat_client: "MockChatClient") -> None:
    """Agent middleware that terminates before ``next()`` must prevent any chat client call."""
    events: list[str] = []

    class PreTerminationMiddleware(AgentMiddleware):
        async def process(
            self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
        ) -> None:
            events.append("middleware_before")
            context.terminate = True
            # next() is still awaited; with terminate=True nothing downstream should execute.
            await next(context)
            events.append("middleware_after")

    agent = ChatAgent(chat_client=chat_client, middleware=[PreTerminationMiddleware()])

    # Two user messages; neither should reach the chat client.
    user_messages = [ChatMessage(role=Role.USER, text=text) for text in ("message1", "message2")]
    response = await agent.run(user_messages)

    assert response is not None
    # Pre-termination leaves the response without messages.
    assert not response.messages
    # The middleware itself still runs to completion.
    assert events == ["middleware_before", "middleware_after"]
    assert chat_client.call_count == 0
async def test_agent_middleware_with_post_termination(self, chat_client: "MockChatClient") -> None:
    """Agent middleware that terminates after ``next()`` must still return the agent's response."""
    events: list[str] = []

    class PostTerminationMiddleware(AgentMiddleware):
        async def process(
            self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
        ) -> None:
            events.append("middleware_before")
            await next(context)
            events.append("middleware_after")
            context.terminate = True

    agent = ChatAgent(chat_client=chat_client, middleware=[PostTerminationMiddleware()])

    user_messages = [ChatMessage(role=Role.USER, text=text) for text in ("message1", "message2")]
    response = await agent.run(user_messages)

    assert response is not None
    assert len(response.messages) == 1
    only_message = response.messages[0]
    assert only_message.role == Role.ASSISTANT
    assert "test response" in only_message.text

    # The middleware wrapped exactly one chat client call.
    assert events == ["middleware_before", "middleware_after"]
    assert chat_client.call_count == 1
async def test_function_middleware_with_pre_termination(self, chat_client: "MockChatClient") -> None:
    """Test that function middleware can terminate execution before calling next()."""
    execution_order: list[str] = []

    class PreTerminationFunctionMiddleware(FunctionMiddleware):
        async def process(
            self,
            context: FunctionInvocationContext,
            next: Callable[[FunctionInvocationContext], Awaitable[None]],
        ) -> None:
            execution_order.append("middleware_before")
            context.terminate = True
            # We call next() but since terminate=True, subsequent middleware and handler should not execute
            await next(context)
            execution_order.append("middleware_after")

    # Create a message to start the conversation
    messages = [ChatMessage(role=Role.USER, text="test message")]

    # Set up chat client to return a function call
    chat_client.responses = [
        ChatResponse(
            messages=[
                ChatMessage(
                    role=Role.ASSISTANT,
                    contents=[
                        FunctionCallContent(call_id="test_call", name="test_function", arguments={"text": "test"})
                    ],
                )
            ]
        )
    ]

    # Create the test function with the expected signature
    def test_function(text: str) -> str:
        execution_order.append("function_called")
        return "test_result"

    # Create ChatAgent with function middleware and test function
    middleware = PreTerminationFunctionMiddleware()
    agent = ChatAgent(chat_client=chat_client, middleware=[middleware], tools=[test_function])

    # Execute the agent
    await agent.run(messages)

    # Verify that the function was not called and only the middleware executed.
    # (The exact-order assertion was previously written twice; once is enough.)
    assert "function_called" not in execution_order
    assert execution_order == ["middleware_before", "middleware_after"]
async def test_function_middleware_with_post_termination(self, chat_client: "MockChatClient") -> None:
    """Function middleware that terminates after ``next()`` must still let the tool run."""
    events: list[str] = []

    class PostTerminationFunctionMiddleware(FunctionMiddleware):
        async def process(
            self,
            context: FunctionInvocationContext,
            next: Callable[[FunctionInvocationContext], Awaitable[None]],
        ) -> None:
            events.append("middleware_before")
            await next(context)
            events.append("middleware_after")
            context.terminate = True

    # The tool the model is scripted to call.
    def test_function(text: str) -> str:
        events.append("function_called")
        return "test_result"

    # Script the chat client so its reply is a single function call.
    tool_call = FunctionCallContent(call_id="test_call", name="test_function", arguments={"text": "test"})
    assistant_turn = ChatMessage(role=Role.ASSISTANT, contents=[tool_call])
    chat_client.responses = [ChatResponse(messages=[assistant_turn])]

    agent = ChatAgent(
        chat_client=chat_client,
        middleware=[PostTerminationFunctionMiddleware()],
        tools=[test_function],
    )

    response = await agent.run([ChatMessage(role=Role.USER, text="test message")])

    # The tool ran, wrapped by the middleware's before/after hooks.
    assert response is not None
    assert "function_called" in events
    assert events == ["middleware_before", "function_called", "middleware_after"]
async def test_function_based_agent_middleware_with_chat_agent(self, chat_client: "MockChatClient") -> None:
"""Test function-based agent middleware with ChatAgent."""
execution_order: list[str] = []
@@ -542,3 +712,631 @@ class TestChatAgentFunctionMiddlewareWithTools:
assert len(function_results) == 1
assert function_calls[0].name == "sample_tool_function"
assert function_results[0].call_id == function_calls[0].call_id
class TestMiddlewareDynamicRebuild:
    """Reassigning ``agent.middleware`` between runs must change which middleware a ChatAgent executes."""

    class TrackingAgentMiddleware(AgentMiddleware):
        """Agent middleware that logs ``<name>_start`` / ``<name>_end`` around the downstream call."""

        def __init__(self, name: str, execution_log: list[str]):
            self.execution_log = execution_log
            self.name = name

        async def process(
            self,
            context: AgentRunContext,
            next: Callable[[AgentRunContext], Awaitable[None]],
        ) -> None:
            self.execution_log.append(f"{self.name}_start")
            await next(context)
            self.execution_log.append(f"{self.name}_end")

    async def test_middleware_dynamic_rebuild_non_streaming(self, chat_client: "MockChatClient") -> None:
        """Non-streaming runs pick up every reassignment of ``agent.middleware``."""
        log: list[str] = []
        first = self.TrackingAgentMiddleware("middleware1", log)
        second = self.TrackingAgentMiddleware("middleware2", log)
        agent = ChatAgent(chat_client=chat_client, middleware=[first])

        # Run 1: only the middleware supplied at construction.
        await agent.run("Test message 1")
        for marker in ("middleware1_start", "middleware1_end"):
            assert marker in log
        log.clear()

        # Run 2: a second middleware is appended to the collection.
        agent.middleware = [first, second]
        await agent.run("Test message 2")
        for marker in ("middleware1_start", "middleware1_end", "middleware2_start", "middleware2_end"):
            assert marker in log
        log.clear()

        # Run 3: the collection is replaced by the second middleware alone.
        agent.middleware = [second]
        await agent.run("Test message 3")
        for marker in ("middleware1_start", "middleware1_end"):
            assert marker not in log
        for marker in ("middleware2_start", "middleware2_end"):
            assert marker in log
        log.clear()

        # Run 4: with an empty collection nothing may be logged.
        agent.middleware = []
        await agent.run("Test message 4")
        assert log == []

    async def test_middleware_dynamic_rebuild_streaming(self, chat_client: "MockChatClient") -> None:
        """Streaming runs pick up a reassignment of ``agent.middleware`` as well."""
        log: list[str] = []
        original = self.TrackingAgentMiddleware("stream_middleware1", log)
        agent = ChatAgent(chat_client=chat_client, middleware=[original])

        # Drain the first stream; only the middleware side effects are inspected.
        async for _ in agent.run_stream("Test stream message 1"):
            pass
        for marker in ("stream_middleware1_start", "stream_middleware1_end"):
            assert marker in log
        log.clear()

        # Swap in a different middleware and stream again.
        replacement = self.TrackingAgentMiddleware("stream_middleware2", log)
        agent.middleware = [replacement]
        async for _ in agent.run_stream("Test stream message 2"):
            pass
        for marker in ("stream_middleware1_start", "stream_middleware1_end"):
            assert marker not in log
        for marker in ("stream_middleware2_start", "stream_middleware2_end"):
            assert marker in log

    async def test_middleware_order_change_detection(self, chat_client: "MockChatClient") -> None:
        """Reordering the same middleware instances changes the nesting order of execution."""
        log: list[str] = []
        mw_first = self.TrackingAgentMiddleware("first", log)
        mw_second = self.TrackingAgentMiddleware("second", log)
        agent = ChatAgent(chat_client=chat_client, middleware=[mw_first, mw_second])

        # Initial order: "first" wraps "second".
        await agent.run("Test message 1")
        assert log == ["first_start", "second_start", "second_end", "first_end"]
        log.clear()

        # Reversed order: "second" now wraps "first".
        agent.middleware = [mw_second, mw_first]
        await agent.run("Test message 2")
        assert log == ["second_start", "first_start", "first_end", "second_end"]
class TestRunLevelMiddleware:
    """Tests for middleware passed per call to ``run`` / ``run_stream`` instead of at agent construction."""

    class TrackingAgentMiddleware(AgentMiddleware):
        """Logs ``<name>_start`` before and ``<name>_end`` after delegating to ``next``."""

        def __init__(self, name: str, execution_log: list[str]):
            self.execution_log = execution_log
            self.name = name

        async def process(
            self,
            context: AgentRunContext,
            next: Callable[[AgentRunContext], Awaitable[None]],
        ) -> None:
            self.execution_log.append(f"{self.name}_start")
            await next(context)
            self.execution_log.append(f"{self.name}_end")

    async def test_run_level_middleware_isolation(self, chat_client: "MockChatClient") -> None:
        """Middleware given to one run must not show up in any other run."""
        log: list[str] = []
        agent = ChatAgent(chat_client=chat_client)  # deliberately no agent-level middleware
        mw_one = self.TrackingAgentMiddleware("run1", log)
        mw_two = self.TrackingAgentMiddleware("run2", log)

        # Run 1 sees only mw_one.
        await agent.run("Test message 1", middleware=[mw_one])
        assert log == ["run1_start", "run1_end"]
        log.clear()

        # Run 2 sees only mw_two; nothing from run 1 leaks in.
        await agent.run("Test message 2", middleware=[mw_two])
        assert log == ["run2_start", "run2_end"]
        for stale in ("run1_start", "run1_end"):
            assert stale not in log
        log.clear()

        # Run 3 is given no middleware at all.
        await agent.run("Test message 3")
        assert log == []
        log.clear()

        # Run 4 is given both, nested in list order.
        await agent.run("Test message 4", middleware=[mw_one, mw_two])
        assert log == ["run1_start", "run2_start", "run2_end", "run1_end"]

    async def test_agent_plus_run_middleware_execution_order(self, chat_client: "MockChatClient") -> None:
        """Agent-level middleware wraps run-level middleware, and its metadata is visible to the latter."""
        execution_log: list[str] = []
        metadata_log: list[str] = []

        class MetadataAgentMiddleware(AgentMiddleware):
            def __init__(self, name: str):
                self.name = name

            async def process(
                self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
            ) -> None:
                execution_log.append(f"{self.name}_start")
                # Leave a marker for middleware further down the pipeline.
                context.metadata[f"{self.name}_key"] = f"{self.name}_value"
                await next(context)
                execution_log.append(f"{self.name}_end")

        class MetadataRunMiddleware(AgentMiddleware):
            def __init__(self, name: str):
                self.name = name

            async def process(
                self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
            ) -> None:
                execution_log.append(f"{self.name}_start")
                # Record everything already in the metadata before adding this middleware's own key.
                metadata_log.extend(f"{self.name}_reads_{key}:{value}" for key, value in context.metadata.items())
                context.metadata[f"{self.name}_key"] = f"{self.name}_value"
                await next(context)
                execution_log.append(f"{self.name}_end")

        agent = ChatAgent(chat_client=chat_client, middleware=[MetadataAgentMiddleware("agent")])
        per_run = MetadataRunMiddleware("run")

        await agent.run("Test message", middleware=[per_run])

        # Agent middleware is outermost, run middleware innermost.
        assert execution_log == ["agent_start", "run_start", "run_end", "agent_end"]
        # The run middleware observed the key written by the agent middleware.
        assert "run_reads_agent_key:agent_value" in metadata_log

    async def test_run_level_middleware_non_streaming(self, chat_client: "MockChatClient") -> None:
        """Run-level middleware executes around a non-streaming call and the response is returned."""
        log: list[str] = []
        agent = ChatAgent(chat_client=chat_client)
        per_run = self.TrackingAgentMiddleware("run_nonstream", log)

        response = await agent.run("Test non-streaming", middleware=[per_run])

        # The response produced by the mocked client is returned to the caller.
        assert response is not None
        assert len(response.messages) > 0
        leading = response.messages[0]
        assert leading.role == Role.ASSISTANT
        assert "test response" in leading.text
        # The middleware ran exactly once around the call.
        assert log == ["run_nonstream_start", "run_nonstream_end"]

    async def test_run_level_middleware_streaming(self, chat_client: "MockChatClient") -> None:
        """Run-level middleware executes around a streaming call and sees ``is_streaming`` set."""
        execution_log: list[str] = []
        streaming_flags: list[bool] = []

        class StreamingTrackingMiddleware(AgentMiddleware):
            def __init__(self, name: str):
                self.name = name

            async def process(
                self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
            ) -> None:
                execution_log.append(f"{self.name}_start")
                streaming_flags.append(context.is_streaming)
                await next(context)
                execution_log.append(f"{self.name}_end")

        agent = ChatAgent(chat_client=chat_client)

        # Two chunks for the mocked client to stream back.
        chunks = [
            ChatResponseUpdate(contents=[TextContent(text="Stream")], role=Role.ASSISTANT),
            ChatResponseUpdate(contents=[TextContent(text=" response")], role=Role.ASSISTANT),
        ]
        chat_client.streaming_responses = [chunks]

        per_run = StreamingTrackingMiddleware("run_stream")
        updates: list[AgentRunResponseUpdate] = [
            update async for update in agent.run_stream("Test streaming", middleware=[per_run])
        ]

        # Both chunks arrive in order.
        assert len(updates) == 2
        assert updates[0].text == "Stream"
        assert updates[1].text == " response"
        # The middleware ran once and was told this is a streaming run.
        assert execution_log == ["run_stream_start", "run_stream_end"]
        assert streaming_flags == [True]

    async def test_agent_and_run_level_both_agent_and_function_middleware(self, chat_client: "MockChatClient") -> None:
        """Agent and function middleware supplied at both agent level and run level nest in the expected order."""
        execution_log: list[str] = []

        # --- middleware registered on the agent -------------------------------------------------
        class AgentLevelAgentMiddleware(AgentMiddleware):
            async def process(
                self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
            ) -> None:
                execution_log.append("agent_level_agent_start")
                context.metadata["agent_level_agent"] = "processed"
                await next(context)
                execution_log.append("agent_level_agent_end")

        class AgentLevelFunctionMiddleware(FunctionMiddleware):
            async def process(
                self,
                context: FunctionInvocationContext,
                next: Callable[[FunctionInvocationContext], Awaitable[None]],
            ) -> None:
                execution_log.append("agent_level_function_start")
                context.metadata["agent_level_function"] = "processed"
                await next(context)
                execution_log.append("agent_level_function_end")

        # --- middleware supplied to the individual run ------------------------------------------
        class RunLevelAgentMiddleware(AgentMiddleware):
            async def process(
                self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
            ) -> None:
                execution_log.append("run_level_agent_start")
                # The agent-level agent middleware must already have written its marker.
                assert "agent_level_agent" in context.metadata
                context.metadata["run_level_agent"] = "processed"
                await next(context)
                execution_log.append("run_level_agent_end")

        class RunLevelFunctionMiddleware(FunctionMiddleware):
            async def process(
                self,
                context: FunctionInvocationContext,
                next: Callable[[FunctionInvocationContext], Awaitable[None]],
            ) -> None:
                execution_log.append("run_level_function_start")
                # The agent-level function middleware must already have written its marker.
                assert "agent_level_function" in context.metadata
                context.metadata["run_level_function"] = "processed"
                await next(context)
                execution_log.append("run_level_function_end")

        # Tool invoked through the function middleware pipeline.
        def custom_tool(message: str) -> str:
            execution_log.append("tool_executed")
            return f"Tool response: {message}"

        # First mocked reply requests the tool; the second closes the conversation.
        tool_request = FunctionCallContent(
            call_id="test_call",
            name="custom_tool",
            arguments='{"message": "test"}',
        )
        chat_client.responses = [
            ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, contents=[tool_request])]),
            ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Final response")]),
        ]

        agent = ChatAgent(
            chat_client=chat_client,
            middleware=[AgentLevelAgentMiddleware(), AgentLevelFunctionMiddleware()],
            tools=[custom_tool],
        )
        response = await agent.run(
            "Test message",
            middleware=[RunLevelAgentMiddleware(), RunLevelFunctionMiddleware()],
        )

        assert response is not None
        assert len(response.messages) > 0
        assert chat_client.call_count == 2  # Function call + final response

        # Agent-level wraps run-level, for agent middleware and function middleware alike.
        assert execution_log == [
            "agent_level_agent_start",
            "run_level_agent_start",
            "agent_level_function_start",
            "run_level_function_start",
            "tool_executed",
            "run_level_function_end",
            "agent_level_function_end",
            "run_level_agent_end",
            "agent_level_agent_end",
        ]

        # The response carries exactly one call and its matching result.
        function_calls: list[Any] = []
        function_results: list[Any] = []
        for message in response.messages:
            for item in message.contents:
                if isinstance(item, FunctionCallContent):
                    function_calls.append(item)
                if isinstance(item, FunctionResultContent):
                    function_results.append(item)

        assert len(function_calls) == 1
        assert len(function_results) == 1
        call, outcome = function_calls[0], function_results[0]
        assert call.name == "custom_tool"
        assert outcome.call_id == call.call_id
        assert outcome.result is not None
        assert "Tool response: test" in str(outcome.result)
class TestMiddlewareDecoratorLogic:
    """Test how the ``@agent_middleware`` / ``@function_middleware`` markers combine with context type annotations."""

    @staticmethod
    def _prime_tool_call(chat_client: "MockChatClient", execution_order: list[str]) -> Callable[[str], str]:
        """Queue a ``custom_tool`` call plus a final reply on the mock client and return the tool.

        Args:
            chat_client: Mock client whose ``responses`` list is replaced.
            execution_order: Log that the returned tool appends ``"tool_executed"`` to.

        Returns:
            The ``custom_tool`` function to pass to ``ChatAgent(tools=[...])``.
        """

        def custom_tool(message: str) -> str:
            execution_order.append("tool_executed")
            return f"Tool response: {message}"

        tool_request = FunctionCallContent(
            call_id="test_call",
            name="custom_tool",
            arguments='{"message": "test"}',
        )
        # First reply asks for the tool so function middleware gets exercised; the second ends the run.
        chat_client.responses = [
            ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, contents=[tool_request])]),
            ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Final response")]),
        ]
        return custom_tool

    async def test_decorator_and_type_match(self, chat_client: "MockChatClient") -> None:
        """Both decorator and parameter type specified and match."""
        execution_order: list[str] = []

        @agent_middleware
        async def matching_agent_middleware(
            context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
        ) -> None:
            execution_order.append("decorator_type_match_agent")
            await next(context)

        @function_middleware
        async def matching_function_middleware(
            context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
        ) -> None:
            execution_order.append("decorator_type_match_function")
            await next(context)

        custom_tool = self._prime_tool_call(chat_client, execution_order)

        # Should work without errors
        agent = ChatAgent(
            chat_client=chat_client,
            middleware=[matching_agent_middleware, matching_function_middleware],
            tools=[custom_tool],
        )
        response = await agent.run([ChatMessage(role=Role.USER, text="test")])

        assert response is not None
        assert "decorator_type_match_agent" in execution_order
        assert "decorator_type_match_function" in execution_order

    async def test_decorator_and_type_mismatch(self, chat_client: "MockChatClient") -> None:
        """Both decorator and parameter type specified but don't match."""
        # NOTE(review): this test does not pin down whether the ValueError surfaces at decoration,
        # agent construction or the first run, so all three steps sit inside the raises block.
        with pytest.raises(ValueError, match="Middleware type mismatch"):

            @agent_middleware  # type: ignore[arg-type]
            async def mismatched_middleware(
                context: FunctionInvocationContext,  # Wrong type for @agent_middleware
                next: Any,
            ) -> None:
                await next(context)

            agent = ChatAgent(chat_client=chat_client, middleware=[mismatched_middleware])
            await agent.run([ChatMessage(role=Role.USER, text="test")])

    async def test_only_decorator_specified(self, chat_client: "MockChatClient") -> None:
        """Only decorator specified - rely on decorator."""
        execution_order: list[str] = []

        @agent_middleware
        async def decorator_only_agent(context: Any, next: Any) -> None:  # No type annotation
            execution_order.append("decorator_only_agent")
            await next(context)

        @function_middleware
        async def decorator_only_function(context: Any, next: Any) -> None:  # No type annotation
            execution_order.append("decorator_only_function")
            await next(context)

        custom_tool = self._prime_tool_call(chat_client, execution_order)

        # Should work - relies on decorator
        agent = ChatAgent(
            chat_client=chat_client, middleware=[decorator_only_agent, decorator_only_function], tools=[custom_tool]
        )
        response = await agent.run([ChatMessage(role=Role.USER, text="test")])

        assert response is not None
        assert "decorator_only_agent" in execution_order
        assert "decorator_only_function" in execution_order

    async def test_only_type_specified(self, chat_client: "MockChatClient") -> None:
        """Only parameter type specified - rely on types."""
        execution_order: list[str] = []

        # No decorator
        async def type_only_agent(context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None:
            execution_order.append("type_only_agent")
            await next(context)

        # No decorator
        async def type_only_function(
            context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
        ) -> None:
            execution_order.append("type_only_function")
            await next(context)

        custom_tool = self._prime_tool_call(chat_client, execution_order)

        # Should work - relies on type annotations
        agent = ChatAgent(
            chat_client=chat_client, middleware=[type_only_agent, type_only_function], tools=[custom_tool]
        )
        response = await agent.run([ChatMessage(role=Role.USER, text="test")])

        assert response is not None
        assert "type_only_agent" in execution_order
        assert "type_only_function" in execution_order

    async def test_neither_decorator_nor_type(self, chat_client: "MockChatClient") -> None:
        """Neither decorator nor parameter type specified - should throw exception."""

        async def no_info_middleware(context: Any, next: Any) -> None:  # No decorator, no type
            await next(context)

        # Should raise ValueError
        with pytest.raises(ValueError, match="Cannot determine middleware type"):
            agent = ChatAgent(chat_client=chat_client, middleware=[no_info_middleware])
            await agent.run([ChatMessage(role=Role.USER, text="test")])

    async def test_insufficient_parameters_error(self, chat_client: "MockChatClient") -> None:
        """Test that middleware with insufficient parameters raises an error."""
        # ChatAgent and agent_middleware come from module scope, exactly as in the sibling tests above.
        with pytest.raises(ValueError, match="must have at least 2 parameters"):

            @agent_middleware  # type: ignore[arg-type]
            async def insufficient_params_middleware(context: Any) -> None:  # Missing 'next' parameter
                pass

            agent = ChatAgent(chat_client=chat_client, middleware=[insufficient_params_middleware])
            await agent.run([ChatMessage(role=Role.USER, text="test")])

    async def test_decorator_markers_preserved(self) -> None:
        """Test that decorator markers are properly set on functions."""

        @agent_middleware
        async def test_agent_middleware(context: Any, next: Any) -> None:
            pass

        @function_middleware
        async def test_function_middleware(context: Any, next: Any) -> None:
            pass

        # Each decorator must have stamped its MiddlewareType onto the wrapped function.
        assert hasattr(test_agent_middleware, "_middleware_type")
        assert test_agent_middleware._middleware_type == MiddlewareType.AGENT  # type: ignore[attr-defined]
        assert hasattr(test_function_middleware, "_middleware_type")
        assert test_function_middleware._middleware_type == MiddlewareType.FUNCTION  # type: ignore[attr-defined]
@@ -0,0 +1,271 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import time
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated
from agent_framework import (
AgentMiddleware,
AgentRunContext,
AgentRunResponse,
FunctionInvocationContext,
)
from agent_framework.foundry import FoundryChatClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field
"""
Agent-Level and Run-Level Middleware Example
This sample demonstrates the difference between agent-level and run-level middleware:
- Agent-level middleware: Applied to ALL runs of the agent (persistent across runs)
- Run-level middleware: Applied to specific runs only (isolated per run)
The example shows:
1. Agent-level security middleware that validates all requests
2. Agent-level performance monitoring across all runs
3. Run-level context middleware for specific use cases (high priority, debugging)
4. Run-level caching middleware for expensive operations
Execution order: Agent middleware (outermost) -> Run middleware (innermost) -> Agent execution
"""
def get_weather(
    location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
    """Get the weather for a given location."""
    # Demo data only: a random condition and a random high between 10 and 30 degrees.
    sky = ("sunny", "cloudy", "rainy", "stormy")[randint(0, 3)]
    high = randint(10, 30)
    return f"The weather in {location} is {sky} with a high of {high}°C."
# Agent-level middleware (applied to ALL runs)
class SecurityAgentMiddleware(AgentMiddleware):
    """Agent-level gate that refuses runs whose latest message mentions a sensitive term."""

    async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None:
        print("[SecurityMiddleware] Checking security for all requests...")

        # Only the most recent message is inspected; a missing or empty text yields nothing to match.
        latest = context.messages[-1] if context.messages else None
        lowered = latest.text.lower() if latest and latest.text else ""

        for term in ("password", "secret", "credentials"):
            if term in lowered:
                print("[SecurityMiddleware] Security violation detected! Blocking request.")
                # Returning without awaiting next() keeps the rest of the pipeline from running.
                return

        print("[SecurityMiddleware] Security check passed.")
        context.metadata["security_validated"] = True
        await next(context)
async def performance_monitor_middleware(
    context: AgentRunContext,
    next: Callable[[AgentRunContext], Awaitable[None]],
) -> None:
    """Agent-level performance monitoring for all runs.

    Times the downstream pipeline, prints the elapsed seconds and stores them in
    ``context.metadata["execution_time"]``.

    Args:
        context: Run context; its ``metadata`` mapping receives the measured duration.
        next: Continuation that invokes the remainder of the middleware pipeline.
    """
    print("[PerformanceMonitor] Starting performance monitoring...")
    # perf_counter() is monotonic, so the interval cannot be skewed by wall-clock
    # adjustments (NTP, DST, manual changes) the way a time.time() difference can.
    start_time = time.perf_counter()

    await next(context)

    duration = time.perf_counter() - start_time
    print(f"[PerformanceMonitor] Total execution time: {duration:.3f}s")
    context.metadata["execution_time"] = duration
# Run-level middleware (applied to specific runs only)
class HighPriorityMiddleware(AgentMiddleware):
    """Run-level middleware that tags a request as high priority before it proceeds."""

    async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None:
        print("[HighPriority] Processing high priority request with expedited handling...")

        meta = context.metadata
        # "security_validated" is the flag SecurityAgentMiddleware sets when a request passes its check.
        if meta.get("security_validated"):
            print("[HighPriority] Security validation confirmed from agent middleware")

        # Mark the run before handing off to the rest of the pipeline.
        meta["priority"] = "high"
        meta["expedited"] = True

        await next(context)
        print("[HighPriority] High priority processing completed")
async def debugging_middleware(
    context: AgentRunContext,
    next: Callable[[AgentRunContext], Awaitable[None]],
) -> None:
    """Run-level middleware that prints diagnostic details about the run it is attached to."""
    print("[Debug] Debug mode enabled for this run")
    print(f"[Debug] Messages count: {len(context.messages)}")
    print(f"[Debug] Is streaming: {context.is_streaming}")

    # Show whatever earlier middleware has already stored, if anything.
    seen_so_far = context.metadata
    if seen_so_far:
        print(f"[Debug] Existing metadata: {seen_so_far}")

    seen_so_far["debug_enabled"] = True

    await next(context)
    print("[Debug] Debug information collected")
class CachingMiddleware(AgentMiddleware):
    """Run-level middleware that replays a stored result when the same message text is seen again."""

    def __init__(self) -> None:
        # Maps the last message's text to the result produced for it.
        self.cache: dict[str, AgentRunResponse] = {}

    async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None:
        # The key is the text of the last message, or a fixed placeholder when there is none.
        newest = context.messages[-1] if context.messages else None
        cache_key: str = "no_message"
        if newest and newest.text:
            cache_key = newest.text

        if cache_key not in self.cache:
            print(f"[Cache] Cache MISS for: '{cache_key[:30]}...'")
            context.metadata["cache_key"] = cache_key

            await next(context)

            # Store only a truthy result so a missing result is not replayed later.
            if context.result:
                self.cache[cache_key] = context.result  # type: ignore
                print("[Cache] Result cached for future use")
            return

        # Hit: supply the stored result and skip next() entirely.
        print(f"[Cache] Cache HIT for: '{cache_key[:30]}...'")
        context.result = self.cache[cache_key]  # type: ignore
async def function_logging_middleware(
    context: FunctionInvocationContext,
    next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
    """Function middleware that prints a line before and after every tool invocation."""
    # Capture the name once so both log lines refer to the same value.
    name = context.function.name
    print(f"[FunctionLog] Calling function: {name} with args: {context.arguments}")

    await next(context)

    print(f"[FunctionLog] Function {name} completed")
async def main() -> None:
    """Run seven queries that contrast agent-level middleware with run-level middleware."""
    print("=== Agent-Level and Run-Level Middleware Example ===\n")

    def banner(title: str) -> None:
        # Three-line header that introduces each run.
        rule = "=" * 60
        print(rule)
        print(title)
        print(rule)

    # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
    # authentication option.
    async with (
        AzureCliCredential() as credential,
        FoundryChatClient(async_credential=credential).create_agent(
            name="WeatherAgent",
            instructions="You are a helpful weather assistant.",
            tools=get_weather,
            # Agent-level middleware: applied to ALL runs
            middleware=[
                SecurityAgentMiddleware(),
                performance_monitor_middleware,
                function_logging_middleware,
            ],
        ) as agent,
    ):

        async def ask(query: str, fallback: str = "No response", **run_options: object) -> None:
            # Echo the query, run the agent with any run-level options, then print its reply.
            print(f"User: {query}")
            result = await agent.run(query, **run_options)
            print(f"Agent: {result.text or fallback}")
            print()

        print("Agent created with agent-level middleware:")
        print(" - SecurityMiddleware (blocks sensitive requests)")
        print(" - PerformanceMonitor (tracks execution time)")
        print(" - FunctionLogging (logs all function calls)")
        print()

        banner("RUN 1: Normal query (agent-level middleware only)")
        await ask("What's the weather like in Paris?")

        banner("RUN 2: High priority request (agent + run-level middleware)")
        await ask(
            "What's the weather in Tokyo? This is urgent!",
            middleware=HighPriorityMiddleware(),  # Run-level middleware
        )

        banner("RUN 3: Debug mode (agent + run-level debugging)")
        await ask(
            "What's the weather in London?",
            middleware=[debugging_middleware],  # Run-level middleware
        )

        banner("RUN 4: Multiple run-level middleware (caching + debug)")
        caching = CachingMiddleware()
        repeated_query = "What's the weather in New York?"
        await ask(
            repeated_query,
            middleware=[caching, debugging_middleware],  # Multiple run-level middleware
        )

        banner("RUN 5: Test cache hit (same query as Run 4)")
        await ask(
            repeated_query,  # Same query as Run 4
            middleware=[caching],  # Same caching middleware instance
        )

        banner("RUN 6: Security test (should be blocked by agent middleware)")
        await ask(
            "What's the secret weather password for Berlin?",
            fallback="Request was blocked by security middleware",
        )

        banner("RUN 7: Normal query again (agent-level middleware only)")
        await ask("What's the weather in Sydney?")


if __name__ == "__main__":
    asyncio.run(main())
@@ -0,0 +1,87 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import datetime
from agent_framework import (
agent_middleware,
function_middleware,
)
from agent_framework.foundry import FoundryChatClient
from azure.identity.aio import AzureCliCredential
"""
Decorator Middleware Example
This sample demonstrates how to use @agent_middleware and @function_middleware decorators
to explicitly mark middleware functions without requiring type annotations.
The framework supports the following middleware detection scenarios:
1. Both decorator and parameter type specified:
- Validates that they match (e.g., @agent_middleware with AgentRunContext)
- Throws exception if they don't match for safety
2. Only decorator specified:
- Relies on decorator to determine middleware type
- No type annotations needed - framework handles context types automatically
3. Only parameter type specified:
- Uses type annotations (AgentRunContext, FunctionInvocationContext) for detection
4. Neither decorator nor parameter type specified:
- Throws exception requiring either decorator or type annotation
- Prevents ambiguous middleware that can't be properly classified
Key benefits of decorator approach:
- No type annotations needed (simpler syntax)
- Explicit middleware type declaration
- Clear intent in code
- Prevents type mismatches
"""
def get_current_time() -> str:
    """Return a sentence stating the current local time as HH:MM:SS.

    Registered as a tool on the sample agent, so the return value is the text
    handed back to the model.
    """
    now = datetime.datetime.now()
    return f"Current time is {now.strftime('%H:%M:%S')}"
@agent_middleware  # Decorator marks this as agent middleware - no type annotations needed
async def simple_agent_middleware(context, next):  # type: ignore - parameters intentionally untyped to demonstrate decorator functionality
    """Print a marker on each side of the agent run.

    The parameters are deliberately left unannotated: the decorator alone is
    what declares this callable to be agent middleware.

    Args:
        context: The run context supplied by the framework.
        next: Awaitable continuation that runs the remainder of the pipeline.
    """
    print("[Agent Middleware] Before agent execution")
    # Hand control to the rest of the pipeline, then resume here afterwards.
    await next(context)
    print("[Agent Middleware] After agent execution")
@function_middleware  # Decorator marks this as function middleware - no type annotations needed
async def simple_function_middleware(context, next):  # type: ignore - parameters intentionally untyped to demonstrate decorator functionality
    """Print the name of the invoked function before and after the call.

    The parameters are deliberately left unannotated: the decorator alone is
    what declares this callable to be function middleware.

    Args:
        context: The invocation context supplied by the framework; only
            ``context.function.name`` is read here.
        next: Awaitable continuation that runs the remainder of the pipeline.
    """
    print(f"[Function Middleware] Before calling: {context.function.name}")  # type: ignore
    # Hand control to the rest of the pipeline, then resume here afterwards.
    await next(context)
    print(f"[Function Middleware] After calling: {context.function.name}")  # type: ignore
async def main() -> None:
    """Run one query against an agent that has both decorated middleware attached."""
    print("=== Decorator Middleware Example ===")

    # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
    # authentication option.
    async with AzureCliCredential() as credential:
        chat_client = FoundryChatClient(async_credential=credential)
        # Both middleware callables are passed in one list; the decorators on them
        # declare which kind each one is.
        async with chat_client.create_agent(
            name="TimeAgent",
            instructions="You are a helpful time assistant. Call get_current_time when asked about time.",
            tools=get_current_time,
            middleware=[simple_agent_middleware, simple_function_middleware],
        ) as agent:
            query = "What time is it?"
            print(f"User: {query}")
            result = await agent.run(query)
            print(f"Agent: {result.text if result.text else 'No response'}")
# Script entry point: drive the async sample only when executed directly, not on import.
if __name__ == "__main__":
    asyncio.run(main())
@@ -0,0 +1,177 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated
from agent_framework import (
AgentMiddleware,
AgentRunContext,
AgentRunResponse,
ChatMessage,
Role,
)
from agent_framework.foundry import FoundryChatClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field
"""
Middleware Termination Example
This sample demonstrates how middleware can terminate execution using the `context.terminate` flag.
The example includes:
- PreTerminationMiddleware: Terminates execution before calling next() to prevent agent processing
- PostTerminationMiddleware: Allows processing to complete but terminates further execution
This is useful for implementing security checks, rate limiting, or early exit conditions.
"""
def get_weather(
    location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
    """Return a randomly generated weather sentence for ``location``.

    This is a stand-in tool for the sample: the condition and the high
    temperature are random, not looked up anywhere.
    """
    conditions = ("sunny", "cloudy", "rainy", "stormy")
    # Draw the condition first and the temperature second, one random draw each.
    sky = conditions[randint(0, 3)]
    high = randint(10, 30)
    return f"The weather in {location} is {sky} with a high of {high}°C."
class PreTerminationMiddleware(AgentMiddleware):
    """Agent middleware that rejects a request when its last message contains a blocked word.

    On a match it stores a canned assistant reply in ``context.result`` and sets
    ``context.terminate``. ``next`` is still awaited afterwards.
    NOTE(review): AgentRunContext documents that execution stops once control
    returns to the framework; presumably ``next`` honours the flag - confirm.
    """

    def __init__(self, blocked_words: list[str]):
        # Lower-case once up front so matching in process() is case-insensitive.
        self.blocked_words = [word.lower() for word in blocked_words]

    def _first_blocked_word(self, context: AgentRunContext) -> str | None:
        """Return the first blocked word found in the last message's text, else None."""
        last_message = context.messages[-1] if context.messages else None
        if not (last_message and last_message.text):
            return None
        query = last_message.text.lower()
        for candidate in self.blocked_words:
            if candidate in query:
                return candidate
        return None

    async def process(
        self,
        context: AgentRunContext,
        next: Callable[[AgentRunContext], Awaitable[None]],
    ) -> None:
        """Flag the run for termination on a blocked word, then continue the pipeline."""
        blocked_word = self._first_blocked_word(context)
        if blocked_word is not None:
            print(f"[PreTerminationMiddleware] Blocked word '{blocked_word}' detected. Terminating request.")

            # Provide the reply the caller will see in place of a model response.
            context.result = AgentRunResponse(
                messages=[
                    ChatMessage(
                        role=Role.ASSISTANT,
                        text=(
                            f"Sorry, I cannot process requests containing '{blocked_word}'. "
                            "Please rephrase your question."
                        ),
                    )
                ]
            )

            # Ask the framework to stop further processing of this run.
            context.terminate = True

        await next(context)
class PostTerminationMiddleware(AgentMiddleware):
    """Agent middleware that flags runs for termination once a response budget is used up.

    The counter lives on the instance, so it carries over between runs that
    share this middleware object.
    """

    def __init__(self, max_responses: int = 1):
        self.max_responses = max_responses
        # Number of completed process() calls so far, terminated ones included.
        self.response_count = 0

    async def process(
        self,
        context: AgentRunContext,
        next: Callable[[AgentRunContext], Awaitable[None]],
    ) -> None:
        """Set ``context.terminate`` when the budget is exhausted, then continue the pipeline."""
        print(f"[PostTerminationMiddleware] Processing request (response count: {self.response_count})")

        budget_exhausted = self.response_count >= self.max_responses
        if budget_exhausted:
            print(
                f"[PostTerminationMiddleware] Maximum responses ({self.max_responses}) reached. "
                "Terminating further processing."
            )
            context.terminate = True

        # The pipeline is always continued; the terminate flag is what the framework acts on.
        await next(context)

        # Count this call only after the pipeline has returned.
        self.response_count += 1
async def pre_termination_middleware() -> None:
    """Run a normal query and a blocked-word query through PreTerminationMiddleware."""
    print("\n--- Example 1: Pre-termination Middleware ---")

    # (heading, query) pairs, executed in order against the same agent.
    scenarios = (
        # Test with normal query
        ("\n1. Normal query:", "What's the weather like in Seattle?"),
        # Test with blocked word
        ("\n2. Query with blocked word:", "What's the bad weather in New York?"),
    )

    async with AzureCliCredential() as credential:
        chat_client = FoundryChatClient(async_credential=credential)
        async with chat_client.create_agent(
            name="WeatherAgent",
            instructions="You are a helpful weather assistant.",
            tools=get_weather,
            middleware=PreTerminationMiddleware(blocked_words=["bad", "inappropriate"]),
        ) as agent:
            for heading, query in scenarios:
                print(heading)
                print(f"User: {query}")
                result = await agent.run(query)
                print(f"Agent: {result.text}")
async def post_termination_middleware() -> None:
    """Run three queries through PostTerminationMiddleware with a budget of one response."""
    print("\n--- Example 2: Post-termination Middleware ---")

    # Runs issued after the budget is spent; each is expected to be terminated.
    terminated_runs = (
        ("\n2. Second run (should be terminated):", "What about the weather in London?"),
        ("\n3. Third run (should also be terminated):", "And New York?"),
    )

    async with AzureCliCredential() as credential:
        chat_client = FoundryChatClient(async_credential=credential)
        async with chat_client.create_agent(
            name="WeatherAgent",
            instructions="You are a helpful weather assistant.",
            tools=get_weather,
            middleware=PostTerminationMiddleware(max_responses=1),
        ) as agent:
            # First run (should work)
            print("\n1. First run:")
            query = "What's the weather in Paris?"
            print(f"User: {query}")
            result = await agent.run(query)
            print(f"Agent: {result.text}")

            # Remaining runs fall back to a placeholder when no text comes back.
            for heading, query in terminated_runs:
                print(heading)
                print(f"User: {query}")
                result = await agent.run(query)
                print(f"Agent: {result.text if result.text else 'No response (terminated)'}")
async def main() -> None:
    """Run the pre-termination demo and then the post-termination demo."""
    print("=== Middleware Termination Example ===")

    # Sequential on purpose: the second demo's output follows the first.
    for demo in (pre_termination_middleware, post_termination_middleware):
        await demo()
# Script entry point: drive the async sample only when executed directly, not on import.
if __name__ == "__main__":
    asyncio.run(main())