mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Extending middleware capabilities (#844)
* Implemented termination * Added termination sample * Allowed middleware pipeline modification * Added run-level middleware * Added more validation to function-based middleware * Added example with function-based decorator approach * Update python/samples/getting_started/middleware/decorator_middleware.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update python/samples/getting_started/middleware/decorator_middleware.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Small improvements * Fixed tests --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
08f792e511
commit
f61d8abe58
@@ -3,6 +3,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar
|
||||
|
||||
from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage
|
||||
@@ -15,12 +16,22 @@ if TYPE_CHECKING:
|
||||
|
||||
TAgent = TypeVar("TAgent", bound="AgentProtocol")
|
||||
|
||||
|
||||
class MiddlewareType(Enum):
    """Enum representing the type of middleware."""

    # Marker for middleware that handles agent runs.
    AGENT = "agent"
    # Marker for middleware that handles function invocations.
    FUNCTION = "function"
|
||||
|
||||
|
||||
# Public names exported by this module.
__all__ = [
    "AgentMiddleware",
    "AgentRunContext",
    "FunctionInvocationContext",
    "FunctionMiddleware",
    "Middleware",
    "agent_middleware",
    "function_middleware",
    "use_agent_middleware",
]
|
||||
|
||||
@@ -38,6 +49,8 @@ class AgentRunContext:
|
||||
to see the actual execution result or can be set to override the execution result.
|
||||
For non-streaming: should be AgentRunResponse
|
||||
For streaming: should be AsyncIterable[AgentRunResponseUpdate]
|
||||
terminate: A flag indicating whether to terminate execution after current middleware.
|
||||
When set to True, execution will stop as soon as control returns to framework.
|
||||
"""
|
||||
|
||||
agent: "AgentProtocol"
|
||||
@@ -45,6 +58,7 @@ class AgentRunContext:
|
||||
is_streaming: bool = False
|
||||
metadata: dict[str, Any] = field(default_factory=lambda: {})
|
||||
result: AgentRunResponse | AsyncIterable[AgentRunResponseUpdate] | None = None
|
||||
terminate: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -57,12 +71,15 @@ class FunctionInvocationContext:
|
||||
metadata: Metadata dictionary for sharing data between function middleware.
|
||||
result: Function execution result. Can be observed after calling next()
|
||||
to see the actual execution result or can be set to override the execution result.
|
||||
terminate: A flag indicating whether to terminate execution after current middleware.
|
||||
When set to True, execution will stop as soon as control returns to framework.
|
||||
"""
|
||||
|
||||
function: "AIFunction[Any, Any]"
|
||||
arguments: "BaseModel"
|
||||
metadata: dict[str, Any] = field(default_factory=lambda: {})
|
||||
result: Any = None
|
||||
terminate: bool = False
|
||||
|
||||
|
||||
class AgentMiddleware(ABC):
|
||||
@@ -131,6 +148,53 @@ FunctionMiddlewareCallable = Callable[
|
||||
Middleware: TypeAlias = AgentMiddleware | AgentMiddlewareCallable | FunctionMiddleware | FunctionMiddlewareCallable
|
||||
|
||||
|
||||
# Middleware type markers for decorators
|
||||
def agent_middleware(func: AgentMiddlewareCallable) -> AgentMiddlewareCallable:
    """Decorator to mark a function as agent middleware.

    This decorator explicitly identifies a function as agent middleware,
    which processes AgentRunContext objects. The function is tagged in place
    by setting its ``_middleware_type`` attribute; no wrapper is created.

    Args:
        func: The middleware function to mark as agent middleware.

    Returns:
        The same function with agent middleware marker.

    Example:
        .. code-block:: python

            @agent_middleware
            async def my_middleware(context: AgentRunContext, next):
                # Process agent invocation
                await next(context)
    """
    # Plain (unannotated) assignment: a type cannot be declared when assigning
    # to an attribute of a non-self object, so only the missing attribute
    # needs to be silenced for the type checker.
    func._middleware_type = MiddlewareType.AGENT  # type: ignore[attr-defined]
    return func
|
||||
|
||||
|
||||
def function_middleware(func: FunctionMiddlewareCallable) -> FunctionMiddlewareCallable:
    """Decorator to mark a function as function middleware.

    This decorator explicitly identifies a function as function middleware,
    which processes FunctionInvocationContext objects. The function is tagged
    in place by setting its ``_middleware_type`` attribute; no wrapper is created.

    Args:
        func: The middleware function to mark as function middleware.

    Returns:
        The same function with function middleware marker.

    Example:
        .. code-block:: python

            @function_middleware
            async def my_middleware(context: FunctionInvocationContext, next):
                # Process function invocation
                await next(context)
    """
    # Plain (unannotated) assignment: a type cannot be declared when assigning
    # to an attribute of a non-self object, so only the missing attribute
    # needs to be silenced for the type checker.
    func._middleware_type = MiddlewareType.FUNCTION  # type: ignore[attr-defined]
    return func
|
||||
|
||||
|
||||
class AgentMiddlewareWrapper(AgentMiddleware):
|
||||
"""Wrapper to convert pure functions into AgentMiddleware protocol objects."""
|
||||
|
||||
@@ -271,6 +335,10 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline):
|
||||
if index >= len(self._middlewares):
|
||||
|
||||
async def final_wrapper(c: AgentRunContext) -> None:
|
||||
# If terminate was set, skip execution
|
||||
if c.terminate:
|
||||
return
|
||||
|
||||
# Execute actual handler and populate context for observability
|
||||
result = await final_handler(c)
|
||||
result_container["result"] = result
|
||||
@@ -282,6 +350,10 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline):
|
||||
next_handler = create_next_handler(index + 1)
|
||||
|
||||
async def current_handler(c: AgentRunContext) -> None:
|
||||
# If terminate is set, don't continue the pipeline
|
||||
if c.terminate:
|
||||
return
|
||||
|
||||
await middleware.process(c, next_handler)
|
||||
# After middleware execution, check if response was overridden
|
||||
if c.result is not None and isinstance(c.result, AgentRunResponse):
|
||||
@@ -337,6 +409,10 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline):
|
||||
if index >= len(self._middlewares):
|
||||
|
||||
async def final_wrapper(c: AgentRunContext) -> None: # noqa: RUF029
|
||||
# If terminate was set, skip execution
|
||||
if c.terminate:
|
||||
return
|
||||
|
||||
# Execute actual handler and populate context for observability
|
||||
result = final_handler(c)
|
||||
result_container["result_stream"] = result
|
||||
@@ -349,6 +425,9 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline):
|
||||
|
||||
async def current_handler(c: AgentRunContext) -> None:
    """Invoke this stage's middleware unless termination was already requested.

    Args:
        c: The agent run context passed along the middleware chain.
    """
    # If terminate is set, don't continue the pipeline. The check has to come
    # before the middleware is invoked: placed after ``process()`` it is a
    # no-op, because the handler returns at that point anyway.
    if c.terminate:
        return

    await middleware.process(c, next_handler)
|
||||
|
||||
return current_handler
|
||||
|
||||
@@ -424,8 +503,8 @@ class FunctionMiddlewarePipeline(BaseMiddlewarePipeline):
|
||||
|
||||
# Custom final handler that handles pre-existing results
|
||||
async def function_final_handler(c: FunctionInvocationContext) -> Any:
|
||||
# If result was set before calling next(), skip execution
|
||||
if c.result is not None:
|
||||
# If terminate was set, skip execution and return the result (which might be None)
|
||||
if c.terminate:
|
||||
return c.result
|
||||
# Execute actual handler and populate context for observability
|
||||
return await final_handler(c)
|
||||
@@ -446,6 +525,10 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]:
|
||||
This decorator adds middleware functionality to any agent class.
|
||||
It wraps the run() and run_stream() methods to provide middleware execution.
|
||||
|
||||
The middleware execution can be terminated at any point by setting the
|
||||
context.terminate property to True. Once set, the pipeline will stop executing
|
||||
further middleware as soon as control returns to the pipeline.
|
||||
|
||||
Args:
|
||||
agent_class: The agent class to add middleware support to.
|
||||
|
||||
@@ -458,12 +541,101 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]:
|
||||
original_run = agent_class.run # type: ignore[attr-defined]
|
||||
original_run_stream = agent_class.run_stream # type: ignore[attr-defined]
|
||||
|
||||
def _initialize_middleware_pipelines(self: Any, middlewares: Middleware | list[Middleware] | None) -> None:
|
||||
"""Initialize agent and function middleware pipelines from the provided middleware list."""
|
||||
if not middlewares:
|
||||
return
|
||||
def _determine_middleware_type(middleware: Any) -> MiddlewareType:
    """Determine middleware type using decorator and/or parameter type annotation.

    Two optional sources are consulted: the ``_middleware_type`` marker
    attribute on the callable, and the annotation of its first parameter
    (``AgentRunContext`` or ``FunctionInvocationContext``, given either as a
    class or as a string annotation). When both are present they must agree.

    Args:
        middleware: The middleware function to analyze.

    Returns:
        MiddlewareType.AGENT or MiddlewareType.FUNCTION indicating the middleware type.

    Raises:
        ValueError: When middleware type cannot be determined or there's a mismatch.
    """
    # Not every callable has __name__ (e.g. functools.partial objects or
    # instances defining __call__); fall back to repr for error messages.
    display_name = getattr(middleware, "__name__", None) or repr(middleware)

    # Check for decorator marker
    decorator_type: MiddlewareType | None = getattr(middleware, "_middleware_type", None)

    # Check for parameter type annotation
    param_type: MiddlewareType | None = None
    try:
        params = list(inspect.signature(middleware).parameters.values())
    except (TypeError, ValueError):
        # Signature inspection failed - continue with other checks.
        # Only the inspection call is guarded so that the validation error
        # raised below can never be confused with an inspection failure.
        params = None

    if params is not None:
        # Must have at least 2 parameters (context and next)
        if len(params) < 2:
            raise ValueError(
                "Middleware function must have at least 2 parameters (context, next), "
                f"but {display_name} has {len(params)}"
            )

        annotation = params[0].annotation
        if isinstance(annotation, str):
            # String annotation (postponed evaluation or explicit quotes):
            # drop surrounding quotes and any module qualifier.
            annotation_name = annotation.strip().strip("'\"").rsplit(".", 1)[-1]
        else:
            annotation_name = getattr(annotation, "__name__", None)

        if annotation_name == "AgentRunContext":
            param_type = MiddlewareType.AGENT
        elif annotation_name == "FunctionInvocationContext":
            param_type = MiddlewareType.FUNCTION

    if decorator_type and param_type:
        # Both decorator and parameter type specified - they must match
        if decorator_type != param_type:
            raise ValueError(
                f"Middleware type mismatch: decorator indicates '{decorator_type.value}' "
                f"but parameter type indicates '{param_type.value}' for function {display_name}"
            )
        return decorator_type

    if decorator_type:
        # Just decorator specified - rely on decorator
        return decorator_type

    if param_type:
        # Just parameter type specified - rely on types
        return param_type

    # Neither decorator nor parameter type specified - throw exception
    raise ValueError(
        f"Cannot determine middleware type for function {display_name}. "
        "Please either use @agent_middleware/@function_middleware decorators "
        "or specify parameter types (AgentRunContext or FunctionInvocationContext)."
    )
|
||||
|
||||
def _build_middleware_pipelines(
|
||||
agent_level_middlewares: Middleware | list[Middleware] | None,
|
||||
run_level_middlewares: Middleware | list[Middleware] | None = None,
|
||||
) -> tuple[AgentMiddlewarePipeline, FunctionMiddlewarePipeline]:
|
||||
"""Build fresh agent and function middleware pipelines from the provided middleware lists.
|
||||
|
||||
Args:
|
||||
agent_level_middlewares: Agent-level middleware (executed first)
|
||||
run_level_middlewares: Run-level middleware (executed after agent middleware)
|
||||
"""
|
||||
# Merge middleware lists: agent middleware first, then run middleware
|
||||
combined_middlewares: list[Middleware] = []
|
||||
|
||||
if agent_level_middlewares:
|
||||
if isinstance(agent_level_middlewares, list):
|
||||
combined_middlewares.extend(agent_level_middlewares) # type: ignore[arg-type]
|
||||
else:
|
||||
combined_middlewares.append(agent_level_middlewares)
|
||||
|
||||
if run_level_middlewares:
|
||||
if isinstance(run_level_middlewares, list):
|
||||
combined_middlewares.extend(run_level_middlewares) # type: ignore[arg-type]
|
||||
else:
|
||||
combined_middlewares.append(run_level_middlewares)
|
||||
|
||||
if not combined_middlewares:
|
||||
return AgentMiddlewarePipeline(), FunctionMiddlewarePipeline()
|
||||
|
||||
middleware_list = combined_middlewares
|
||||
|
||||
# Separate agent and function middleware using isinstance checks
|
||||
agent_middlewares: list[AgentMiddleware | AgentMiddlewareCallable] = []
|
||||
@@ -475,75 +647,42 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]:
|
||||
elif isinstance(middleware, FunctionMiddleware):
|
||||
function_middlewares.append(middleware)
|
||||
elif callable(middleware): # type: ignore[arg-type]
|
||||
# Check function signature to determine type
|
||||
try:
|
||||
sig = inspect.signature(middleware)
|
||||
params = list(sig.parameters.values())
|
||||
if len(params) >= 1:
|
||||
first_param = params[0]
|
||||
# Check if first parameter is AgentRunContext or FunctionInvocationContext
|
||||
if (
|
||||
hasattr(first_param.annotation, "__name__")
|
||||
and first_param.annotation.__name__ == "AgentRunContext"
|
||||
):
|
||||
agent_middlewares.append(middleware) # type: ignore
|
||||
elif (
|
||||
hasattr(first_param.annotation, "__name__")
|
||||
and first_param.annotation.__name__ == "FunctionInvocationContext"
|
||||
):
|
||||
function_middlewares.append(middleware) # type: ignore
|
||||
else:
|
||||
# Default to agent middleware if uncertain
|
||||
agent_middlewares.append(middleware) # type: ignore
|
||||
else:
|
||||
agent_middlewares.append(middleware) # type: ignore
|
||||
except Exception:
|
||||
# If signature inspection fails, assume it's an agent middleware
|
||||
# Determine middleware type using decorator and/or parameter type annotation
|
||||
middleware_type = _determine_middleware_type(middleware)
|
||||
if middleware_type == MiddlewareType.AGENT:
|
||||
agent_middlewares.append(middleware) # type: ignore
|
||||
elif middleware_type == MiddlewareType.FUNCTION:
|
||||
function_middlewares.append(middleware) # type: ignore
|
||||
else:
|
||||
# This should not happen if _determine_middleware_type is implemented correctly
|
||||
raise ValueError(f"Unknown middleware type: {middleware_type}")
|
||||
else:
|
||||
# Fallback
|
||||
agent_middlewares.append(middleware) # type: ignore
|
||||
|
||||
self._agent_middleware_pipeline = AgentMiddlewarePipeline(agent_middlewares)
|
||||
self._function_middleware_pipeline = FunctionMiddlewarePipeline(function_middlewares)
|
||||
return AgentMiddlewarePipeline(agent_middlewares), FunctionMiddlewarePipeline(function_middlewares)
|
||||
|
||||
async def middleware_enabled_run(
|
||||
self: Any,
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
|
||||
*,
|
||||
thread: Any = None,
|
||||
middleware: Middleware | list[Middleware] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentRunResponse:
|
||||
"""Middleware-enabled run method."""
|
||||
# Initialize middleware pipelines if not already done
|
||||
if (
|
||||
hasattr(self, "middleware")
|
||||
and self.middleware
|
||||
and not (
|
||||
hasattr(self, "_agent_middleware_pipeline")
|
||||
and hasattr(self, "_function_middleware_pipeline")
|
||||
and (
|
||||
self._agent_middleware_pipeline.has_middlewares
|
||||
or self._function_middleware_pipeline.has_middlewares
|
||||
)
|
||||
)
|
||||
):
|
||||
_initialize_middleware_pipelines(self, self.middleware)
|
||||
|
||||
# Ensure pipelines exist even if empty
|
||||
if not hasattr(self, "_agent_middleware_pipeline"):
|
||||
self._agent_middleware_pipeline = AgentMiddlewarePipeline()
|
||||
if not hasattr(self, "_function_middleware_pipeline"):
|
||||
self._function_middleware_pipeline = FunctionMiddlewarePipeline()
|
||||
# Build fresh middleware pipelines from current middleware collection and run-level middleware
|
||||
agent_middleware = getattr(self, "middleware", None)
|
||||
agent_pipeline, function_pipeline = _build_middleware_pipelines(agent_middleware, middleware)
|
||||
|
||||
# Add function middleware pipeline to kwargs if available
|
||||
if self._function_middleware_pipeline.has_middlewares:
|
||||
kwargs["_function_middleware_pipeline"] = self._function_middleware_pipeline
|
||||
if function_pipeline.has_middlewares:
|
||||
kwargs["_function_middleware_pipeline"] = function_pipeline
|
||||
|
||||
normalized_messages = self._normalize_messages(messages)
|
||||
|
||||
# Execute with middleware if available
|
||||
if self._agent_middleware_pipeline.has_middlewares:
|
||||
if agent_pipeline.has_middlewares:
|
||||
context = AgentRunContext(
|
||||
agent=self, # type: ignore[arg-type]
|
||||
messages=normalized_messages,
|
||||
@@ -553,7 +692,7 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]:
|
||||
async def _execute_handler(ctx: AgentRunContext) -> AgentRunResponse:
|
||||
return await original_run(self, ctx.messages, thread=thread, **kwargs) # type: ignore
|
||||
|
||||
result = await self._agent_middleware_pipeline.execute(
|
||||
result = await agent_pipeline.execute(
|
||||
self, # type: ignore[arg-type]
|
||||
normalized_messages,
|
||||
context,
|
||||
@@ -570,38 +709,22 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]:
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
|
||||
*,
|
||||
thread: Any = None,
|
||||
middleware: Middleware | list[Middleware] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[AgentRunResponseUpdate]:
|
||||
"""Middleware-enabled run_stream method."""
|
||||
# Initialize middleware pipelines if not already done
|
||||
if (
|
||||
hasattr(self, "middleware")
|
||||
and self.middleware
|
||||
and not (
|
||||
hasattr(self, "_agent_middleware_pipeline")
|
||||
and hasattr(self, "_function_middleware_pipeline")
|
||||
and (
|
||||
self._agent_middleware_pipeline.has_middlewares
|
||||
or self._function_middleware_pipeline.has_middlewares
|
||||
)
|
||||
)
|
||||
):
|
||||
_initialize_middleware_pipelines(self, self.middleware)
|
||||
|
||||
# Ensure pipelines exist even if empty
|
||||
if not hasattr(self, "_agent_middleware_pipeline"):
|
||||
self._agent_middleware_pipeline = AgentMiddlewarePipeline()
|
||||
if not hasattr(self, "_function_middleware_pipeline"):
|
||||
self._function_middleware_pipeline = FunctionMiddlewarePipeline()
|
||||
# Build fresh middleware pipelines from current middleware collection and run-level middleware
|
||||
agent_middleware = getattr(self, "middleware", None)
|
||||
agent_pipeline, function_pipeline = _build_middleware_pipelines(agent_middleware, middleware)
|
||||
|
||||
# Add function middleware pipeline to kwargs if available
|
||||
if self._function_middleware_pipeline.has_middlewares:
|
||||
kwargs["_function_middleware_pipeline"] = self._function_middleware_pipeline
|
||||
if function_pipeline.has_middlewares:
|
||||
kwargs["_function_middleware_pipeline"] = function_pipeline
|
||||
|
||||
normalized_messages = self._normalize_messages(messages)
|
||||
|
||||
# Execute with middleware if available
|
||||
if self._agent_middleware_pipeline.has_middlewares:
|
||||
if agent_pipeline.has_middlewares:
|
||||
context = AgentRunContext(
|
||||
agent=self, # type: ignore[arg-type]
|
||||
messages=normalized_messages,
|
||||
@@ -613,7 +736,7 @@ def use_agent_middleware(agent_class: type[TAgent]) -> type[TAgent]:
|
||||
yield update
|
||||
|
||||
async def _stream_generator() -> AsyncIterable[AgentRunResponseUpdate]:
|
||||
async for update in self._agent_middleware_pipeline.execute_stream(
|
||||
async for update in agent_pipeline.execute_stream(
|
||||
self, # type: ignore[arg-type]
|
||||
normalized_messages,
|
||||
context,
|
||||
|
||||
@@ -77,6 +77,16 @@ class TestFunctionInvocationContext:
|
||||
class TestAgentMiddlewarePipeline:
|
||||
"""Test cases for AgentMiddlewarePipeline."""
|
||||
|
||||
class PreNextTerminateMiddleware(AgentMiddleware):
    """Agent middleware that sets ``terminate`` and then still awaits ``next``."""

    async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None:
        # Request termination first; next() is awaited afterwards regardless.
        context.terminate = True
        await next(context)
|
||||
|
||||
class PostNextTerminateMiddleware(AgentMiddleware):
    """Agent middleware that awaits ``next`` and only then sets ``terminate``."""

    async def process(self, context: AgentRunContext, next: Any) -> None:
        # Let the rest of the chain run before flagging termination.
        await next(context)
        context.terminate = True
|
||||
|
||||
def test_init_empty(self) -> None:
|
||||
"""Test AgentMiddlewarePipeline initialization with no middlewares."""
|
||||
pipeline = AgentMiddlewarePipeline()
|
||||
@@ -194,10 +204,143 @@ class TestAgentMiddlewarePipeline:
|
||||
assert updates[1].text == "chunk2"
|
||||
assert execution_order == ["test_before", "test_after", "handler_start", "handler_end"]
|
||||
|
||||
async def test_execute_with_pre_next_termination(self, mock_agent: AgentProtocol) -> None:
    """Test pipeline execution with termination before next()."""
    calls: list[str] = []

    async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
        # Handler should not be executed when terminated before next()
        calls.append("handler")
        return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])

    pipeline = AgentMiddlewarePipeline([self.PreNextTerminateMiddleware()])
    user_messages = [ChatMessage(role=Role.USER, text="test")]
    run_context = AgentRunContext(agent=mock_agent, messages=user_messages)

    response = await pipeline.execute(mock_agent, user_messages, run_context, final_handler)

    assert response is not None
    assert run_context.terminate
    # Handler should not be called when terminated before next()
    assert calls == []
    assert not response.messages
|
||||
|
||||
async def test_execute_with_post_next_termination(self, mock_agent: AgentProtocol) -> None:
    """Test pipeline execution with termination after next()."""
    calls: list[str] = []

    async def final_handler(ctx: AgentRunContext) -> AgentRunResponse:
        calls.append("handler")
        return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="response")])

    pipeline = AgentMiddlewarePipeline([self.PostNextTerminateMiddleware()])
    user_messages = [ChatMessage(role=Role.USER, text="test")]
    run_context = AgentRunContext(agent=mock_agent, messages=user_messages)

    response = await pipeline.execute(mock_agent, user_messages, run_context, final_handler)

    # The handler ran and its response is returned; terminate is set only afterwards.
    assert response is not None
    assert len(response.messages) == 1
    assert response.messages[0].text == "response"
    assert run_context.terminate
    assert calls == ["handler"]
|
||||
|
||||
async def test_execute_stream_with_pre_next_termination(self, mock_agent: AgentProtocol) -> None:
    """Test pipeline streaming execution with termination before next()."""
    calls: list[str] = []

    async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]:
        # Handler should not be executed when terminated before next()
        calls.append("handler_start")
        yield AgentRunResponseUpdate(contents=[TextContent(text="chunk1")])
        yield AgentRunResponseUpdate(contents=[TextContent(text="chunk2")])
        calls.append("handler_end")

    pipeline = AgentMiddlewarePipeline([self.PreNextTerminateMiddleware()])
    user_messages = [ChatMessage(role=Role.USER, text="test")]
    run_context = AgentRunContext(agent=mock_agent, messages=user_messages)

    received: list[AgentRunResponseUpdate] = [
        update async for update in pipeline.execute_stream(mock_agent, user_messages, run_context, final_handler)
    ]

    assert run_context.terminate
    # Handler should not be called when terminated before next()
    assert calls == []
    assert not received
|
||||
|
||||
async def test_execute_stream_with_post_next_termination(self, mock_agent: AgentProtocol) -> None:
    """Test pipeline streaming execution with termination after next()."""
    calls: list[str] = []

    async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentRunResponseUpdate]:
        calls.append("handler_start")
        yield AgentRunResponseUpdate(contents=[TextContent(text="chunk1")])
        yield AgentRunResponseUpdate(contents=[TextContent(text="chunk2")])
        calls.append("handler_end")

    pipeline = AgentMiddlewarePipeline([self.PostNextTerminateMiddleware()])
    user_messages = [ChatMessage(role=Role.USER, text="test")]
    run_context = AgentRunContext(agent=mock_agent, messages=user_messages)

    received: list[AgentRunResponseUpdate] = [
        update async for update in pipeline.execute_stream(mock_agent, user_messages, run_context, final_handler)
    ]

    # Both chunks are delivered; terminate is set only after next() returned.
    assert len(received) == 2
    assert received[0].text == "chunk1"
    assert received[1].text == "chunk2"
    assert run_context.terminate
    assert calls == ["handler_start", "handler_end"]
|
||||
|
||||
|
||||
class TestFunctionMiddlewarePipeline:
|
||||
"""Test cases for FunctionMiddlewarePipeline."""
|
||||
|
||||
class PreNextTerminateFunctionMiddleware(FunctionMiddleware):
    """Function middleware that sets ``terminate`` and then still awaits ``next``."""

    async def process(self, context: FunctionInvocationContext, next: Any) -> None:
        # Request termination first; next() is awaited afterwards regardless.
        context.terminate = True
        await next(context)
|
||||
|
||||
class PostNextTerminateFunctionMiddleware(FunctionMiddleware):
    """Function middleware that awaits ``next`` and only then sets ``terminate``."""

    async def process(self, context: FunctionInvocationContext, next: Any) -> None:
        # Let the rest of the chain run before flagging termination.
        await next(context)
        context.terminate = True
|
||||
|
||||
async def test_execute_with_pre_next_termination(self, mock_function: AIFunction[Any, Any]) -> None:
    """Test pipeline execution with termination before next()."""
    calls: list[str] = []

    async def final_handler(ctx: FunctionInvocationContext) -> str:
        # Handler should not be executed when terminated before next()
        calls.append("handler")
        return "test result"

    pipeline = FunctionMiddlewarePipeline([self.PreNextTerminateFunctionMiddleware()])
    call_args = FunctionTestArgs(name="test")
    invocation = FunctionInvocationContext(function=mock_function, arguments=call_args)

    result = await pipeline.execute(mock_function, call_args, invocation, final_handler)

    assert result is None
    assert invocation.terminate
    # Handler should not be called when terminated before next()
    assert calls == []
|
||||
|
||||
async def test_execute_with_post_next_termination(self, mock_function: AIFunction[Any, Any]) -> None:
    """Test pipeline execution with termination after next()."""
    calls: list[str] = []

    async def final_handler(ctx: FunctionInvocationContext) -> str:
        calls.append("handler")
        return "test result"

    pipeline = FunctionMiddlewarePipeline([self.PostNextTerminateFunctionMiddleware()])
    call_args = FunctionTestArgs(name="test")
    invocation = FunctionInvocationContext(function=mock_function, arguments=call_args)

    result = await pipeline.execute(mock_function, call_args, invocation, final_handler)

    # The handler ran and its result is returned; terminate is set only afterwards.
    assert result == "test result"
    assert invocation.terminate
    assert calls == ["handler"]
|
||||
|
||||
def test_init_empty(self) -> None:
|
||||
"""Test FunctionMiddlewarePipeline initialization with no middlewares."""
|
||||
pipeline = FunctionMiddlewarePipeline()
|
||||
@@ -884,44 +1027,6 @@ class TestMiddlewareExecutionControl:
|
||||
assert result.messages == [] # Empty response
|
||||
assert not handler_called
|
||||
|
||||
async def test_function_middleware_pre_execution_override_with_next(
    self, mock_function: AIFunction[Any, Any]
) -> None:
    """Test that function middleware can override result before calling next() - this skips handler execution."""

    class FunctionTestArgs(BaseModel):
        name: str = Field(description="Test name parameter")

    class PreOverrideFunctionMiddleware(FunctionMiddleware):
        async def process(
            self,
            context: FunctionInvocationContext,
            next: Callable[[FunctionInvocationContext], Awaitable[None]],
        ) -> None:
            # Override the result up front, then keep the pipeline going.
            context.result = "pre-override result"
            await next(context)

    handler_ran = False

    async def final_handler(ctx: FunctionInvocationContext) -> str:
        nonlocal handler_ran
        handler_ran = True
        # This should not be called when result is pre-set
        return "original result"

    pipeline = FunctionMiddlewarePipeline([PreOverrideFunctionMiddleware()])
    call_args = FunctionTestArgs(name="test")
    invocation = FunctionInvocationContext(function=mock_function, arguments=call_args)

    result = await pipeline.execute(mock_function, call_args, invocation, final_handler)

    # Verify pre-override worked and handler was NOT called (because result was already set)
    assert result == "pre-override result"
    assert not handler_ran
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent() -> AgentProtocol:
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework import (
|
||||
AgentRunResponseUpdate,
|
||||
@@ -12,12 +15,15 @@ from agent_framework import (
|
||||
FunctionResultContent,
|
||||
Role,
|
||||
TextContent,
|
||||
agent_middleware,
|
||||
function_middleware,
|
||||
)
|
||||
from agent_framework._middleware import (
|
||||
AgentMiddleware,
|
||||
AgentRunContext,
|
||||
FunctionInvocationContext,
|
||||
FunctionMiddleware,
|
||||
MiddlewareType,
|
||||
)
|
||||
|
||||
from .conftest import MockChatClient
|
||||
@@ -98,6 +104,170 @@ class TestChatAgentClassBasedMiddleware:
|
||||
class TestChatAgentFunctionBasedMiddleware:
|
||||
"""Test cases for function-based middleware integration with ChatAgent."""
|
||||
|
||||
async def test_agent_middleware_with_pre_termination(self, chat_client: "MockChatClient") -> None:
    """Test that agent middleware can terminate execution before calling next()."""
    trace: list[str] = []

    class PreTerminationMiddleware(AgentMiddleware):
        async def process(
            self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
        ) -> None:
            trace.append("middleware_before")
            context.terminate = True
            # We call next() but since terminate=True, subsequent middleware and handler should not execute
            await next(context)
            trace.append("middleware_after")

    # Create ChatAgent with terminating middleware
    agent = ChatAgent(chat_client=chat_client, middleware=[PreTerminationMiddleware()])

    # Execute the agent with multiple messages
    response = await agent.run(
        [
            ChatMessage(role=Role.USER, text="message1"),
            ChatMessage(role=Role.USER, text="message2"),  # This should not be processed due to termination
        ]
    )

    # Verify response
    assert response is not None
    assert not response.messages  # No messages should be in response due to pre-termination
    assert trace == ["middleware_before", "middleware_after"]  # Middleware still completes
    assert chat_client.call_count == 0  # No calls should be made due to termination
|
||||
|
||||
async def test_agent_middleware_with_post_termination(self, chat_client: "MockChatClient") -> None:
    """Agent middleware that sets ``terminate`` after ``next()`` still yields the response.

    Asserts that the chat client is called exactly once and that the single
    assistant message it produced is returned to the caller.
    """
    trace: list[str] = []

    class TerminateLast(AgentMiddleware):
        async def process(
            self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
        ) -> None:
            trace.append("middleware_before")
            await next(context)
            trace.append("middleware_after")
            # The flag is only raised once the downstream call has already returned.
            context.terminate = True

    agent = ChatAgent(chat_client=chat_client, middleware=[TerminateLast()])

    user_messages = [ChatMessage(role=Role.USER, text=text) for text in ("message1", "message2")]
    result = await agent.run(user_messages)

    # The response produced before termination must be intact.
    assert result is not None
    assert len(result.messages) == 1
    reply = result.messages[0]
    assert reply.role == Role.ASSISTANT
    assert "test response" in reply.text

    # The middleware wrapped exactly one model call.
    assert trace == ["middleware_before", "middleware_after"]
    assert chat_client.call_count == 1
|
||||
|
||||
async def test_function_middleware_with_pre_termination(self, chat_client: "MockChatClient") -> None:
    """Test that function middleware can terminate execution before calling next().

    The mocked chat client requests a call to ``test_function``; the middleware
    sets ``context.terminate`` before awaiting ``next()``, and the test asserts
    that the tool body never ran while the middleware itself completed.
    """
    execution_order: list[str] = []

    class PreTerminationFunctionMiddleware(FunctionMiddleware):
        async def process(
            self,
            context: FunctionInvocationContext,
            next: Callable[[FunctionInvocationContext], Awaitable[None]],
        ) -> None:
            execution_order.append("middleware_before")
            context.terminate = True
            # We call next() but since terminate=True, subsequent middleware and handler should not execute
            await next(context)
            execution_order.append("middleware_after")

    # Create a message to start the conversation
    messages = [ChatMessage(role=Role.USER, text="test message")]

    # Set up chat client to return a function call
    chat_client.responses = [
        ChatResponse(
            messages=[
                ChatMessage(
                    role=Role.ASSISTANT,
                    contents=[
                        FunctionCallContent(call_id="test_call", name="test_function", arguments={"text": "test"})
                    ],
                )
            ]
        )
    ]

    # Create the test function with the expected signature
    def test_function(text: str) -> str:
        execution_order.append("function_called")
        return "test_result"

    # Create ChatAgent with function middleware and test function
    middleware = PreTerminationFunctionMiddleware()
    agent = ChatAgent(chat_client=chat_client, middleware=[middleware], tools=[test_function])

    # Execute the agent
    await agent.run(messages)

    # Verify that the function was not called and only the middleware executed.
    # The explicit membership check is kept so a failure names the offending entry.
    assert "function_called" not in execution_order
    assert execution_order == ["middleware_before", "middleware_after"]
|
||||
|
||||
async def test_function_middleware_with_post_termination(self, chat_client: "MockChatClient") -> None:
    """Function middleware that sets ``terminate`` after ``next()`` lets the tool run.

    The mocked chat client requests a call to ``test_function``; the test
    asserts the tool body executed between the middleware's two markers.
    """
    trace: list[str] = []

    class TerminateAfterCall(FunctionMiddleware):
        async def process(
            self,
            context: FunctionInvocationContext,
            next: Callable[[FunctionInvocationContext], Awaitable[None]],
        ) -> None:
            trace.append("middleware_before")
            await next(context)
            trace.append("middleware_after")
            # Raised only after the downstream call has returned.
            context.terminate = True

    # The tool the mocked model asks for; it records its own invocation.
    def test_function(text: str) -> str:
        trace.append("function_called")
        return "test_result"

    conversation = [ChatMessage(role=Role.USER, text="test message")]

    # Script the chat client so its reply is a single function-call request.
    tool_request = FunctionCallContent(call_id="test_call", name="test_function", arguments={"text": "test"})
    chat_client.responses = [
        ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, contents=[tool_request])])
    ]

    agent = ChatAgent(
        chat_client=chat_client,
        middleware=[TerminateAfterCall()],
        tools=[test_function],
    )
    result = await agent.run(conversation)

    # The tool ran, wrapped by the middleware.
    assert result is not None
    assert "function_called" in trace
    assert trace == ["middleware_before", "function_called", "middleware_after"]
|
||||
|
||||
async def test_function_based_agent_middleware_with_chat_agent(self, chat_client: "MockChatClient") -> None:
|
||||
"""Test function-based agent middleware with ChatAgent."""
|
||||
execution_order: list[str] = []
|
||||
@@ -542,3 +712,631 @@ class TestChatAgentFunctionMiddlewareWithTools:
|
||||
assert len(function_results) == 1
|
||||
assert function_calls[0].name == "sample_tool_function"
|
||||
assert function_results[0].call_id == function_calls[0].call_id
|
||||
|
||||
|
||||
class TestMiddlewareDynamicRebuild:
    """Checks that reassigning ``agent.middleware`` changes which middleware the next run uses."""

    class TrackingAgentMiddleware(AgentMiddleware):
        """Agent middleware that appends ``<name>_start`` / ``<name>_end`` to a shared log."""

        def __init__(self, name: str, execution_log: list[str]):
            self.execution_log = execution_log
            self.name = name

        async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None:
            record = self.execution_log.append
            record(f"{self.name}_start")
            await next(context)
            record(f"{self.name}_end")

    async def test_middleware_dynamic_rebuild_non_streaming(self, chat_client: "MockChatClient") -> None:
        """Non-streaming runs pick up additions, replacements and removal of middleware."""
        log: list[str] = []

        # Run 1: only middleware1 is installed.
        first = self.TrackingAgentMiddleware("middleware1", log)
        agent = ChatAgent(chat_client=chat_client, middleware=[first])
        await agent.run("Test message 1")
        assert "middleware1_start" in log
        assert "middleware1_end" in log

        # Run 2: middleware2 is appended to the collection.
        log.clear()
        second = self.TrackingAgentMiddleware("middleware2", log)
        agent.middleware = [first, second]
        await agent.run("Test message 2")
        assert "middleware1_start" in log
        assert "middleware1_end" in log
        assert "middleware2_start" in log
        assert "middleware2_end" in log

        # Run 3: the collection is replaced so only middleware2 remains.
        log.clear()
        agent.middleware = [second]
        await agent.run("Test message 3")
        assert "middleware1_start" not in log
        assert "middleware1_end" not in log
        assert "middleware2_start" in log
        assert "middleware2_end" in log

        # Run 4: with an empty collection nothing is logged.
        log.clear()
        agent.middleware = []
        await agent.run("Test message 4")
        assert not log

    async def test_middleware_dynamic_rebuild_streaming(self, chat_client: "MockChatClient") -> None:
        """Streaming runs also pick up a replaced middleware collection."""
        log: list[str] = []

        # First stream: only stream_middleware1 is installed.
        first = self.TrackingAgentMiddleware("stream_middleware1", log)
        agent = ChatAgent(chat_client=chat_client, middleware=[first])
        updates: list[AgentRunResponseUpdate] = [
            chunk async for chunk in agent.run_stream("Test stream message 1")
        ]
        assert "stream_middleware1_start" in log
        assert "stream_middleware1_end" in log

        # Second stream: the collection now holds only stream_middleware2.
        log.clear()
        second = self.TrackingAgentMiddleware("stream_middleware2", log)
        agent.middleware = [second]
        updates = [chunk async for chunk in agent.run_stream("Test stream message 2")]
        assert "stream_middleware1_start" not in log
        assert "stream_middleware1_end" not in log
        assert "stream_middleware2_start" in log
        assert "stream_middleware2_end" in log

    async def test_middleware_order_change_detection(self, chat_client: "MockChatClient") -> None:
        """Swapping the order of the same two middleware swaps their nesting on the next run."""
        log: list[str] = []
        first = self.TrackingAgentMiddleware("first", log)
        second = self.TrackingAgentMiddleware("second", log)

        # Installed as [first, second]: "first" is the outer wrapper.
        agent = ChatAgent(chat_client=chat_client, middleware=[first, second])
        await agent.run("Test message 1")
        assert log == ["first_start", "second_start", "second_end", "first_end"]

        # Reordered to [second, first]: "second" becomes the outer wrapper.
        log.clear()
        agent.middleware = [second, first]
        await agent.run("Test message 2")
        assert log == ["second_start", "first_start", "first_end", "second_end"]
|
||||
|
||||
|
||||
class TestRunLevelMiddleware:
    """Tests for middleware passed per call via the ``middleware=`` argument of ``run`` / ``run_stream``."""

    class TrackingAgentMiddleware(AgentMiddleware):
        """Agent middleware that appends ``<name>_start`` / ``<name>_end`` to a shared log."""

        def __init__(self, name: str, execution_log: list[str]):
            self.execution_log = execution_log
            self.name = name

        async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None:
            record = self.execution_log.append
            record(f"{self.name}_start")
            await next(context)
            record(f"{self.name}_end")

    async def test_run_level_middleware_isolation(self, chat_client: "MockChatClient") -> None:
        """Middleware supplied to one run must not carry over into later runs."""
        log: list[str] = []

        # The agent itself has no middleware; everything comes from the run call.
        agent = ChatAgent(chat_client=chat_client)
        run_one = self.TrackingAgentMiddleware("run1", log)
        run_two = self.TrackingAgentMiddleware("run2", log)

        # Run 1: only run1 is supplied.
        await agent.run("Test message 1", middleware=[run_one])
        assert log == ["run1_start", "run1_end"]

        # Run 2: only run2 is supplied; run1 must be gone.
        log.clear()
        await agent.run("Test message 2", middleware=[run_two])
        assert log == ["run2_start", "run2_end"]
        assert "run1_start" not in log
        assert "run1_end" not in log

        # Run 3: no middleware argument, so nothing is logged.
        log.clear()
        await agent.run("Test message 3")
        assert log == []

        # Run 4: both supplied; they nest in list order.
        log.clear()
        await agent.run("Test message 4", middleware=[run_one, run_two])
        assert log == ["run1_start", "run2_start", "run2_end", "run1_end"]

    async def test_agent_plus_run_middleware_execution_order(self, chat_client: "MockChatClient") -> None:
        """Agent-level middleware wraps run-level middleware and shares ``context.metadata`` with it."""
        order: list[str] = []
        observed_metadata: list[str] = []

        class MetadataAgentMiddleware(AgentMiddleware):
            def __init__(self, name: str):
                self.name = name

            async def process(
                self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
            ) -> None:
                order.append(f"{self.name}_start")
                # Leave a marker for the run-level middleware to read.
                context.metadata[f"{self.name}_key"] = f"{self.name}_value"
                await next(context)
                order.append(f"{self.name}_end")

        class MetadataRunMiddleware(AgentMiddleware):
            def __init__(self, name: str):
                self.name = name

            async def process(
                self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
            ) -> None:
                order.append(f"{self.name}_start")
                # Record every metadata entry visible at this point, before adding our own.
                observed_metadata.extend(
                    f"{self.name}_reads_{key}:{value}" for key, value in context.metadata.items()
                )
                context.metadata[f"{self.name}_key"] = f"{self.name}_value"
                await next(context)
                order.append(f"{self.name}_end")

        agent = ChatAgent(chat_client=chat_client, middleware=[MetadataAgentMiddleware("agent")])
        per_run = MetadataRunMiddleware("run")

        await agent.run("Test message", middleware=[per_run])

        # Agent-level is the outer layer, run-level the inner one.
        assert order == ["agent_start", "run_start", "run_end", "agent_end"]

        # The run-level middleware saw the marker set by the agent-level one.
        assert "run_reads_agent_key:agent_value" in observed_metadata

    async def test_run_level_middleware_non_streaming(self, chat_client: "MockChatClient") -> None:
        """A run-level middleware wraps a non-streaming run without altering the response."""
        log: list[str] = []

        agent = ChatAgent(chat_client=chat_client)
        per_run = self.TrackingAgentMiddleware("run_nonstream", log)

        result = await agent.run("Test non-streaming", middleware=[per_run])

        # The mocked assistant reply comes through unchanged.
        assert result is not None
        assert len(result.messages) > 0
        first_reply = result.messages[0]
        assert first_reply.role == Role.ASSISTANT
        assert "test response" in first_reply.text

        # The middleware ran exactly once around the call.
        assert log == ["run_nonstream_start", "run_nonstream_end"]

    async def test_run_level_middleware_streaming(self, chat_client: "MockChatClient") -> None:
        """A run-level middleware wraps a streaming run and sees ``context.is_streaming``."""
        log: list[str] = []
        seen_streaming_flags: list[bool] = []

        class StreamingTrackingMiddleware(AgentMiddleware):
            def __init__(self, name: str):
                self.name = name

            async def process(
                self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
            ) -> None:
                log.append(f"{self.name}_start")
                seen_streaming_flags.append(context.is_streaming)
                await next(context)
                log.append(f"{self.name}_end")

        agent = ChatAgent(chat_client=chat_client)

        # Script one streamed reply made of two text chunks.
        chunks = [
            ChatResponseUpdate(contents=[TextContent(text="Stream")], role=Role.ASSISTANT),
            ChatResponseUpdate(contents=[TextContent(text=" response")], role=Role.ASSISTANT),
        ]
        chat_client.streaming_responses = [chunks]

        per_run = StreamingTrackingMiddleware("run_stream")

        updates: list[AgentRunResponseUpdate] = [
            piece async for piece in agent.run_stream("Test streaming", middleware=[per_run])
        ]

        # Both chunks arrive, in order.
        assert len(updates) == 2
        assert updates[0].text == "Stream"
        assert updates[1].text == " response"

        # The middleware ran once and was told the run is streaming.
        assert log == ["run_stream_start", "run_stream_end"]
        assert seen_streaming_flags == [True]

    async def test_agent_and_run_level_both_agent_and_function_middleware(self, chat_client: "MockChatClient") -> None:
        """Agent and function middleware at both agent level and run level nest in the expected order."""
        log: list[str] = []

        # --- agent-level middleware ---
        class AgentLevelAgentMiddleware(AgentMiddleware):
            async def process(
                self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
            ) -> None:
                log.append("agent_level_agent_start")
                context.metadata["agent_level_agent"] = "processed"
                await next(context)
                log.append("agent_level_agent_end")

        class AgentLevelFunctionMiddleware(FunctionMiddleware):
            async def process(
                self,
                context: FunctionInvocationContext,
                next: Callable[[FunctionInvocationContext], Awaitable[None]],
            ) -> None:
                log.append("agent_level_function_start")
                context.metadata["agent_level_function"] = "processed"
                await next(context)
                log.append("agent_level_function_end")

        # --- run-level middleware ---
        class RunLevelAgentMiddleware(AgentMiddleware):
            async def process(
                self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
            ) -> None:
                log.append("run_level_agent_start")
                # The agent-level agent middleware must already have left its marker.
                assert "agent_level_agent" in context.metadata
                context.metadata["run_level_agent"] = "processed"
                await next(context)
                log.append("run_level_agent_end")

        class RunLevelFunctionMiddleware(FunctionMiddleware):
            async def process(
                self,
                context: FunctionInvocationContext,
                next: Callable[[FunctionInvocationContext], Awaitable[None]],
            ) -> None:
                log.append("run_level_function_start")
                # The agent-level function middleware must already have left its marker.
                assert "agent_level_function" in context.metadata
                context.metadata["run_level_function"] = "processed"
                await next(context)
                log.append("run_level_function_end")

        # Tool invoked through the function-middleware pipeline.
        def custom_tool(message: str) -> str:
            log.append("tool_executed")
            return f"Tool response: {message}"

        # First scripted reply requests the tool; the second is a plain text answer.
        tool_request = FunctionCallContent(
            call_id="test_call",
            name="custom_tool",
            arguments='{"message": "test"}',
        )
        function_call_response = ChatResponse(
            messages=[ChatMessage(role=Role.ASSISTANT, contents=[tool_request])]
        )
        final_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Final response")])
        chat_client.responses = [function_call_response, final_response]

        agent = ChatAgent(
            chat_client=chat_client,
            middleware=[AgentLevelAgentMiddleware(), AgentLevelFunctionMiddleware()],
            tools=[custom_tool],
        )

        result = await agent.run(
            "Test message",
            middleware=[RunLevelAgentMiddleware(), RunLevelFunctionMiddleware()],
        )

        assert result is not None
        assert len(result.messages) > 0
        assert chat_client.call_count == 2  # Function call + final response

        # Agent-level wraps run-level, for both agent and function middleware.
        assert log == [
            "agent_level_agent_start",
            "run_level_agent_start",
            "agent_level_function_start",
            "run_level_function_start",
            "tool_executed",
            "run_level_function_end",
            "agent_level_function_end",
            "run_level_agent_end",
            "agent_level_agent_end",
        ]

        # Collect the function call and its result from every message in the response.
        function_calls = []
        function_results = []
        for message in result.messages:
            for item in message.contents:
                if isinstance(item, FunctionCallContent):
                    function_calls.append(item)
                if isinstance(item, FunctionResultContent):
                    function_results.append(item)

        assert len(function_calls) == 1
        assert len(function_results) == 1
        call, outcome = function_calls[0], function_results[0]
        assert call.name == "custom_tool"
        assert outcome.call_id == call.call_id
        assert outcome.result is not None
        assert "Tool response: test" in str(outcome.result)
|
||||
|
||||
|
||||
class TestMiddlewareDecoratorLogic:
    """Test the middleware decorator and type annotation logic.

    Covers the combinations of "decorator present / absent" and "context
    parameter annotated / not annotated" for function-based middleware.
    """

    @staticmethod
    def _queue_tool_call_then_final_response(chat_client: Any) -> None:
        """Script ``chat_client`` to request ``custom_tool`` once, then answer with plain text.

        Shared by the tests that need a function invocation so that function
        middleware is exercised. The leading underscore keeps the name outside
        pytest's default ``test*`` collection pattern.
        """
        function_call_response = ChatResponse(
            messages=[
                ChatMessage(
                    role=Role.ASSISTANT,
                    contents=[
                        FunctionCallContent(
                            call_id="test_call",
                            name="custom_tool",
                            arguments='{"message": "test"}',
                        )
                    ],
                )
            ]
        )
        final_response = ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Final response")])
        chat_client.responses = [function_call_response, final_response]

    async def test_decorator_and_type_match(self, chat_client: MockChatClient) -> None:
        """Both decorator and parameter type specified and match."""

        execution_order: list[str] = []

        @agent_middleware
        async def matching_agent_middleware(
            context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]
        ) -> None:
            execution_order.append("decorator_type_match_agent")
            await next(context)

        @function_middleware
        async def matching_function_middleware(
            context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
        ) -> None:
            execution_order.append("decorator_type_match_function")
            await next(context)

        # Create tool function for testing function middleware
        def custom_tool(message: str) -> str:
            execution_order.append("tool_executed")
            return f"Tool response: {message}"

        # Set up mock to return a function call first, then a regular response
        self._queue_tool_call_then_final_response(chat_client)

        # Should work without errors
        agent = ChatAgent(
            chat_client=chat_client,
            middleware=[matching_agent_middleware, matching_function_middleware],
            tools=[custom_tool],
        )

        response = await agent.run([ChatMessage(role=Role.USER, text="test")])

        assert response is not None
        assert "decorator_type_match_agent" in execution_order
        assert "decorator_type_match_function" in execution_order

    async def test_decorator_and_type_mismatch(self, chat_client: MockChatClient) -> None:
        """Both decorator and parameter type specified but don't match."""

        # Decoration, agent construction and the run are all kept inside
        # pytest.raises, so the test passes whichever of those steps raises the
        # "Middleware type mismatch" ValueError.
        with pytest.raises(ValueError, match="Middleware type mismatch"):

            @agent_middleware  # type: ignore[arg-type]
            async def mismatched_middleware(
                context: FunctionInvocationContext,  # Wrong type for @agent_middleware
                next: Any,
            ) -> None:
                await next(context)

            agent = ChatAgent(chat_client=chat_client, middleware=[mismatched_middleware])
            await agent.run([ChatMessage(role=Role.USER, text="test")])

    async def test_only_decorator_specified(self, chat_client: Any) -> None:
        """Only decorator specified - rely on decorator."""
        execution_order: list[str] = []

        @agent_middleware
        async def decorator_only_agent(context: Any, next: Any) -> None:  # No type annotation
            execution_order.append("decorator_only_agent")
            await next(context)

        @function_middleware
        async def decorator_only_function(context: Any, next: Any) -> None:  # No type annotation
            execution_order.append("decorator_only_function")
            await next(context)

        # Create tool function for testing function middleware
        def custom_tool(message: str) -> str:
            execution_order.append("tool_executed")
            return f"Tool response: {message}"

        # Set up mock to return a function call first, then a regular response
        self._queue_tool_call_then_final_response(chat_client)

        # Should work - relies on decorator
        agent = ChatAgent(
            chat_client=chat_client, middleware=[decorator_only_agent, decorator_only_function], tools=[custom_tool]
        )

        response = await agent.run([ChatMessage(role=Role.USER, text="test")])

        assert response is not None
        assert "decorator_only_agent" in execution_order
        assert "decorator_only_function" in execution_order

    async def test_only_type_specified(self, chat_client: Any) -> None:
        """Only parameter type specified - rely on types."""
        execution_order: list[str] = []

        # No decorator
        async def type_only_agent(context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None:
            execution_order.append("type_only_agent")
            await next(context)

        # No decorator
        async def type_only_function(
            context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
        ) -> None:
            execution_order.append("type_only_function")
            await next(context)

        # Create tool function for testing function middleware
        def custom_tool(message: str) -> str:
            execution_order.append("tool_executed")
            return f"Tool response: {message}"

        # Set up mock to return a function call first, then a regular response
        self._queue_tool_call_then_final_response(chat_client)

        # Should work - relies on type annotations
        agent = ChatAgent(
            chat_client=chat_client, middleware=[type_only_agent, type_only_function], tools=[custom_tool]
        )

        response = await agent.run([ChatMessage(role=Role.USER, text="test")])

        assert response is not None
        assert "type_only_agent" in execution_order
        assert "type_only_function" in execution_order

    async def test_neither_decorator_nor_type(self, chat_client: Any) -> None:
        """Neither decorator nor parameter type specified - should throw exception."""

        async def no_info_middleware(context: Any, next: Any) -> None:  # No decorator, no type
            await next(context)

        # Should raise ValueError
        with pytest.raises(ValueError, match="Cannot determine middleware type"):
            agent = ChatAgent(chat_client=chat_client, middleware=[no_info_middleware])
            await agent.run([ChatMessage(role=Role.USER, text="test")])

    async def test_insufficient_parameters_error(self, chat_client: Any) -> None:
        """Test that middleware with insufficient parameters raises an error."""

        # Decoration, agent construction and the run are all kept inside
        # pytest.raises, so the test passes whichever of those steps raises.
        with pytest.raises(ValueError, match="must have at least 2 parameters"):

            @agent_middleware  # type: ignore[arg-type]
            async def insufficient_params_middleware(context: Any) -> None:  # Missing 'next' parameter
                pass

            agent = ChatAgent(chat_client=chat_client, middleware=[insufficient_params_middleware])
            await agent.run([ChatMessage(role=Role.USER, text="test")])

    async def test_decorator_markers_preserved(self) -> None:
        """Test that decorator markers are properly set on functions."""

        @agent_middleware
        async def test_agent_middleware(context: Any, next: Any) -> None:
            pass

        @function_middleware
        async def test_function_middleware(context: Any, next: Any) -> None:
            pass

        # Check that decorator markers were set
        assert hasattr(test_agent_middleware, "_middleware_type")
        assert test_agent_middleware._middleware_type == MiddlewareType.AGENT  # type: ignore[attr-defined]

        assert hasattr(test_function_middleware, "_middleware_type")
        assert test_function_middleware._middleware_type == MiddlewareType.FUNCTION  # type: ignore[attr-defined]
|
||||
|
||||
@@ -0,0 +1,271 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from random import randint
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import (
|
||||
AgentMiddleware,
|
||||
AgentRunContext,
|
||||
AgentRunResponse,
|
||||
FunctionInvocationContext,
|
||||
)
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from pydantic import Field
|
||||
|
||||
"""
|
||||
Agent-Level and Run-Level Middleware Example
|
||||
|
||||
This sample demonstrates the difference between agent-level and run-level middleware:
|
||||
|
||||
- Agent-level middleware: Applied to ALL runs of the agent (persistent across runs)
|
||||
- Run-level middleware: Applied to specific runs only (isolated per run)
|
||||
|
||||
The example shows:
|
||||
1. Agent-level security middleware that validates all requests
|
||||
2. Agent-level performance monitoring across all runs
|
||||
3. Run-level context middleware for specific use cases (high priority, debugging)
|
||||
4. Run-level caching middleware for expensive operations
|
||||
|
||||
Execution order: Agent middleware (outermost) -> Run middleware (innermost) -> Agent execution
|
||||
"""
|
||||
|
||||
|
||||
def get_weather(
    location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
    """Return a randomly generated one-sentence weather report for ``location``."""
    conditions = ["sunny", "cloudy", "rainy", "stormy"]
    # The condition is drawn before the temperature so seeded runs stay reproducible.
    condition = conditions[randint(0, 3)]
    high = randint(10, 30)
    return f"The weather in {location} is {condition} with a high of {high}°C."
|
||||
|
||||
|
||||
# Agent-level middleware (applied to ALL runs)
|
||||
class SecurityAgentMiddleware(AgentMiddleware):
    """Agent-level middleware that screens the last message for sensitive keywords."""

    async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None:
        """Block the run when the last message mentions a sensitive term; otherwise continue.

        Args:
            context: The run context; its ``messages`` are inspected and ``metadata`` is updated.
            next: Continuation that invokes the rest of the pipeline.
        """
        print("[SecurityMiddleware] Checking security for all requests...")

        # Only the most recent message is screened.
        latest = context.messages[-1] if context.messages else None
        if latest and latest.text:
            lowered = latest.text.lower()
            for term in ("password", "secret", "credentials"):
                if term in lowered:
                    print("[SecurityMiddleware] Security violation detected! Blocking request.")
                    # Returning without awaiting next() keeps the rest of the pipeline from running.
                    return

        print("[SecurityMiddleware] Security check passed.")
        context.metadata["security_validated"] = True
        await next(context)
|
||||
|
||||
|
||||
async def performance_monitor_middleware(
    context: AgentRunContext,
    next: Callable[[AgentRunContext], Awaitable[None]],
) -> None:
    """Agent-level performance monitoring for all runs.

    Times the remainder of the pipeline, prints the elapsed time, and stores it
    (in seconds, as a float) under ``context.metadata["execution_time"]``.

    Args:
        context: The run context whose ``metadata`` receives the measured duration.
        next: Continuation that invokes the rest of the pipeline.
    """
    print("[PerformanceMonitor] Starting performance monitoring...")
    # perf_counter() is monotonic, so the measured interval cannot be skewed by
    # system clock adjustments the way a time.time() difference can.
    start_time = time.perf_counter()

    await next(context)

    duration = time.perf_counter() - start_time
    print(f"[PerformanceMonitor] Total execution time: {duration:.3f}s")
    context.metadata["execution_time"] = duration
|
||||
|
||||
|
||||
# Run-level middleware (applied to specific runs only)
|
||||
class HighPriorityMiddleware(AgentMiddleware):
    """Run-level middleware that tags a single run as high priority."""

    async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None:
        """Mark the run as high priority in ``context.metadata`` and continue the pipeline.

        Args:
            context: The run context whose ``metadata`` is read and updated.
            next: Continuation that invokes the rest of the pipeline.
        """
        print("[HighPriority] Processing high priority request with expedited handling...")

        # The flag is written by the agent-level security middleware when it has already run.
        security_ok = context.metadata.get("security_validated")
        if security_ok:
            print("[HighPriority] Security validation confirmed from agent middleware")

        # Record the priority markers for anything further down the pipeline.
        context.metadata["priority"] = "high"
        context.metadata["expedited"] = True

        await next(context)
        print("[HighPriority] High priority processing completed")
|
||||
|
||||
|
||||
async def debugging_middleware(
    context: AgentRunContext,
    next: Callable[[AgentRunContext], Awaitable[None]],
) -> None:
    """Run-level middleware that prints diagnostic details about the current run.

    Args:
        context: The run context; its messages, streaming flag and metadata are printed.
        next: Continuation that invokes the rest of the pipeline.
    """
    print("[Debug] Debug mode enabled for this run")
    message_count = len(context.messages)
    print(f"[Debug] Messages count: {message_count}")
    print(f"[Debug] Is streaming: {context.is_streaming}")

    # Show whatever earlier middleware already stored, but only when there is something to show.
    if context.metadata:
        print(f"[Debug] Existing metadata: {context.metadata}")

    context.metadata["debug_enabled"] = True

    await next(context)

    print("[Debug] Debug information collected")
|
||||
|
||||
|
||||
class CachingMiddleware(AgentMiddleware):
    """Run-level middleware that reuses a stored result when the same last message is seen again."""

    def __init__(self) -> None:
        # Maps the text of a run's last message to the result stored for it.
        self.cache: dict[str, AgentRunResponse] = {}

    async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None:
        """Serve ``context.result`` from the cache on a hit; otherwise run the pipeline and store the result.

        Args:
            context: The run context; ``messages`` supplies the key, ``result`` is read or written.
            next: Continuation that invokes the rest of the pipeline.
        """
        # The cache key is simply the text of the last message.
        latest = context.messages[-1] if context.messages else None
        if latest and latest.text:
            cache_key: str = latest.text
        else:
            cache_key = "no_message"

        if cache_key not in self.cache:
            print(f"[Cache] Cache MISS for: '{cache_key[:30]}...'")
            context.metadata["cache_key"] = cache_key

            await next(context)

            # Store the outcome only when the pipeline produced a truthy result.
            if context.result:
                self.cache[cache_key] = context.result  # type: ignore
                print("[Cache] Result cached for future use")
            return

        # Hit: hand back the stored result without awaiting next().
        print(f"[Cache] Cache HIT for: '{cache_key[:30]}...'")
        context.result = self.cache[cache_key]  # type: ignore
|
||||
|
||||
|
||||
async def function_logging_middleware(
    context: FunctionInvocationContext,
    next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
    """Function middleware that prints a line before and after each function call.

    Args:
        context: The invocation context; ``function.name`` and ``arguments`` are printed.
        next: Continuation that invokes the rest of the pipeline.
    """
    # The name is captured once so both log lines refer to the same function.
    function_name = context.function.name
    print(f"[FunctionLog] Calling function: {function_name} with args: {context.arguments}")

    await next(context)

    print(f"[FunctionLog] Function {function_name} completed")
|
||||
|
||||
|
||||
async def main() -> None:
    """Run seven queries that contrast agent-level and run-level middleware."""
    print("=== Agent-Level and Run-Level Middleware Example ===\n")

    divider = "=" * 60

    # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
    # authentication option.
    async with (
        AzureCliCredential() as credential,
        FoundryChatClient(async_credential=credential).create_agent(
            name="WeatherAgent",
            instructions="You are a helpful weather assistant.",
            tools=get_weather,
            # Agent-level middleware: applied to ALL runs
            middleware=[
                SecurityAgentMiddleware(),
                performance_monitor_middleware,
                function_logging_middleware,
            ],
        ) as agent,
    ):
        print("Agent created with agent-level middleware:")
        print(" - SecurityMiddleware (blocks sensitive requests)")
        print(" - PerformanceMonitor (tracks execution time)")
        print(" - FunctionLogging (logs all function calls)")
        print()

        # Run 1: Normal query with no run-level middleware
        print(divider)
        print("RUN 1: Normal query (agent-level middleware only)")
        print(divider)
        query = "What's the weather like in Paris?"
        print(f"User: {query}")
        response = await agent.run(query)
        print(f"Agent: {response.text or 'No response'}")
        print()

        # Run 2: High priority request with run-level middleware
        print(divider)
        print("RUN 2: High priority request (agent + run-level middleware)")
        print(divider)
        query = "What's the weather in Tokyo? This is urgent!"
        print(f"User: {query}")
        response = await agent.run(
            query,
            middleware=HighPriorityMiddleware(),  # Run-level middleware
        )
        print(f"Agent: {response.text or 'No response'}")
        print()

        # Run 3: Debug mode with run-level debugging middleware
        print(divider)
        print("RUN 3: Debug mode (agent + run-level debugging)")
        print(divider)
        query = "What's the weather in London?"
        print(f"User: {query}")
        response = await agent.run(
            query,
            middleware=[debugging_middleware],  # Run-level middleware
        )
        print(f"Agent: {response.text or 'No response'}")
        print()

        # Run 4: Multiple run-level middleware
        print(divider)
        print("RUN 4: Multiple run-level middleware (caching + debug)")
        print(divider)
        shared_cache = CachingMiddleware()
        query = "What's the weather in New York?"
        print(f"User: {query}")
        response = await agent.run(
            query,
            middleware=[shared_cache, debugging_middleware],  # Multiple run-level middleware
        )
        print(f"Agent: {response.text or 'No response'}")
        print()

        # Run 5: Test cache hit with same query (query is deliberately left unchanged from Run 4)
        print(divider)
        print("RUN 5: Test cache hit (same query as Run 4)")
        print(divider)
        print(f"User: {query}")
        response = await agent.run(
            query,
            middleware=[shared_cache],  # Same caching middleware instance
        )
        print(f"Agent: {response.text or 'No response'}")
        print()

        # Run 6: Security violation test
        print(divider)
        print("RUN 6: Security test (should be blocked by agent middleware)")
        print(divider)
        query = "What's the secret weather password for Berlin?"
        print(f"User: {query}")
        response = await agent.run(query)
        print(f"Agent: {response.text or 'Request was blocked by security middleware'}")
        print()

        # Run 7: Normal query again (no run-level middleware interference)
        print(divider)
        print("RUN 7: Normal query again (agent-level middleware only)")
        print(divider)
        query = "What's the weather in Sydney?"
        print(f"User: {query}")
        response = await agent.run(query)
        print(f"Agent: {response.text or 'No response'}")
        print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: drive the async sample to completion.
    asyncio.run(main())
|
||||
@@ -0,0 +1,87 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
|
||||
from agent_framework import (
|
||||
agent_middleware,
|
||||
function_middleware,
|
||||
)
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
|
||||
"""
|
||||
Decorator Middleware Example
|
||||
|
||||
This sample demonstrates how to use @agent_middleware and @function_middleware decorators
|
||||
to explicitly mark middleware functions without requiring type annotations.
|
||||
|
||||
The framework supports the following middleware detection scenarios:
|
||||
|
||||
1. Both decorator and parameter type specified:
|
||||
- Validates that they match (e.g., @agent_middleware with AgentRunContext)
|
||||
- Throws exception if they don't match for safety
|
||||
|
||||
2. Only decorator specified:
|
||||
- Relies on decorator to determine middleware type
|
||||
- No type annotations needed - framework handles context types automatically
|
||||
|
||||
3. Only parameter type specified:
|
||||
- Uses type annotations (AgentRunContext, FunctionInvocationContext) for detection
|
||||
|
||||
4. Neither decorator nor parameter type specified:
|
||||
- Throws exception requiring either decorator or type annotation
|
||||
- Prevents ambiguous middleware that can't be properly classified
|
||||
|
||||
Key benefits of decorator approach:
|
||||
- No type annotations needed (simpler syntax)
|
||||
- Explicit middleware type declaration
|
||||
- Clear intent in code
|
||||
- Prevents type mismatches
|
||||
"""
|
||||
|
||||
|
||||
def get_current_time() -> str:
    """Return the current local time formatted as ``Current time is HH:MM:SS``."""
    now = datetime.datetime.now()
    return f"Current time is {now:%H:%M:%S}"
|
||||
|
||||
|
||||
@agent_middleware  # Decorator marks this as agent middleware - no type annotations needed
async def simple_agent_middleware(context, next):  # type: ignore - parameters intentionally untyped to demonstrate decorator functionality
    """Print a marker line before and after the rest of the agent pipeline runs."""
    print("[Agent Middleware] Before agent execution")

    # Hand control to the rest of the pipeline.
    await next(context)

    print("[Agent Middleware] After agent execution")
|
||||
|
||||
|
||||
@function_middleware  # Decorator marks this as function middleware - no type annotations needed
async def simple_function_middleware(context, next):  # type: ignore - parameters intentionally untyped to demonstrate decorator functionality
    """Print the function name before and after the wrapped function call."""
    print(f"[Function Middleware] Before calling: {context.function.name}")  # type: ignore

    # Hand control to the rest of the pipeline.
    await next(context)

    print(f"[Function Middleware] After calling: {context.function.name}")  # type: ignore
|
||||
|
||||
|
||||
async def main() -> None:
    """Run one query against an agent configured with both decorator-marked middleware."""
    print("=== Decorator Middleware Example ===")

    # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
    # authentication option.
    async with (
        AzureCliCredential() as credential,
        FoundryChatClient(async_credential=credential).create_agent(
            name="TimeAgent",
            instructions="You are a helpful time assistant. Call get_current_time when asked about time.",
            tools=get_current_time,
            middleware=[simple_agent_middleware, simple_function_middleware],
        ) as agent,
    ):
        query = "What time is it?"
        print(f"User: {query}")
        response = await agent.run(query)
        print(f"Agent: {response.text or 'No response'}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: drive the async sample to completion.
    asyncio.run(main())
|
||||
@@ -0,0 +1,177 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from random import randint
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import (
|
||||
AgentMiddleware,
|
||||
AgentRunContext,
|
||||
AgentRunResponse,
|
||||
ChatMessage,
|
||||
Role,
|
||||
)
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from pydantic import Field
|
||||
|
||||
"""
|
||||
Middleware Termination Example
|
||||
|
||||
This sample demonstrates how middleware can terminate execution using the `context.terminate` flag.
|
||||
The example includes:
|
||||
|
||||
- PreTerminationMiddleware: Terminates execution before calling next() to prevent agent processing
|
||||
- PostTerminationMiddleware: Allows processing to complete but terminates further execution
|
||||
|
||||
This is useful for implementing security checks, rate limiting, or early exit conditions.
|
||||
"""
|
||||
|
||||
|
||||
def get_weather(
    location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
    """Return a randomly generated one-sentence weather report for ``location``."""
    conditions = ["sunny", "cloudy", "rainy", "stormy"]
    # The condition is drawn before the temperature so seeded runs stay reproducible.
    condition = conditions[randint(0, 3)]
    high = randint(10, 30)
    return f"The weather in {location} is {condition} with a high of {high}°C."
|
||||
|
||||
|
||||
class PreTerminationMiddleware(AgentMiddleware):
    """Middleware that flags a run for termination when the last message contains a blocked word."""

    def __init__(self, blocked_words: list[str]):
        # Lower-cased once here so matching in process() is case-insensitive.
        self.blocked_words = [word.lower() for word in blocked_words]

    async def process(
        self,
        context: AgentRunContext,
        next: Callable[[AgentRunContext], Awaitable[None]],
    ) -> None:
        """Set a canned reply and ``context.terminate`` on the first blocked word, then call ``next``.

        Args:
            context: The run context; ``messages`` is inspected, ``result`` and ``terminate`` may be set.
            next: Continuation that invokes the rest of the pipeline; it is awaited in every case.
        """
        # Only the most recent message is checked.
        latest = context.messages[-1] if context.messages else None
        lowered = latest.text.lower() if latest and latest.text else None

        if lowered is not None:
            for blocked_word in self.blocked_words:
                if blocked_word not in lowered:
                    continue

                print(f"[PreTerminationMiddleware] Blocked word '{blocked_word}' detected. Terminating request.")

                # Set a custom response
                refusal = ChatMessage(
                    role=Role.ASSISTANT,
                    text=(
                        f"Sorry, I cannot process requests containing '{blocked_word}'. "
                        "Please rephrase your question."
                    ),
                )
                context.result = AgentRunResponse(messages=[refusal])

                # Set terminate flag to prevent further processing
                context.terminate = True
                # The first match is enough; remaining blocked words are not checked.
                break

        await next(context)
|
||||
|
||||
|
||||
class PostTerminationMiddleware(AgentMiddleware):
    """Middleware that sets the terminate flag once a response budget shared across runs is used up."""

    def __init__(self, max_responses: int = 1):
        self.max_responses = max_responses
        # Counts completed process() calls on this instance, across runs.
        self.response_count = 0

    async def process(
        self,
        context: AgentRunContext,
        next: Callable[[AgentRunContext], Awaitable[None]],
    ) -> None:
        """Set ``context.terminate`` when the budget is exhausted, call ``next``, then count the call.

        Args:
            context: The run context; ``terminate`` is set when the limit has been reached.
            next: Continuation that invokes the rest of the pipeline; it is awaited in every case.
        """
        print(f"[PostTerminationMiddleware] Processing request (response count: {self.response_count})")

        # Decide whether to terminate before handing over to the pipeline.
        limit_reached = self.response_count >= self.max_responses
        if limit_reached:
            print(
                f"[PostTerminationMiddleware] Maximum responses ({self.max_responses}) reached. "
                "Terminating further processing."
            )
            context.terminate = True

        # next() is awaited whether or not the flag was set.
        await next(context)

        # The counter grows on every call, including ones flagged for termination.
        self.response_count += 1
|
||||
|
||||
|
||||
async def pre_termination_middleware() -> None:
    """Run a normal query and a blocked-word query through ``PreTerminationMiddleware``."""
    print("\n--- Example 1: Pre-termination Middleware ---")
    blocker = PreTerminationMiddleware(blocked_words=["bad", "inappropriate"])
    async with (
        AzureCliCredential() as credential,
        FoundryChatClient(async_credential=credential).create_agent(
            name="WeatherAgent",
            instructions="You are a helpful weather assistant.",
            tools=get_weather,
            middleware=blocker,
        ) as agent,
    ):
        # Test with normal query
        print("\n1. Normal query:")
        query = "What's the weather like in Seattle?"
        print(f"User: {query}")
        response = await agent.run(query)
        print(f"Agent: {response.text}")

        # Test with blocked word
        print("\n2. Query with blocked word:")
        query = "What's the bad weather in New York?"
        print(f"User: {query}")
        response = await agent.run(query)
        print(f"Agent: {response.text}")
|
||||
|
||||
|
||||
async def post_termination_middleware() -> None:
    """Run three queries through one ``PostTerminationMiddleware`` limited to a single response."""
    print("\n--- Example 2: Post-termination Middleware ---")
    limiter = PostTerminationMiddleware(max_responses=1)
    async with (
        AzureCliCredential() as credential,
        FoundryChatClient(async_credential=credential).create_agent(
            name="WeatherAgent",
            instructions="You are a helpful weather assistant.",
            tools=get_weather,
            middleware=limiter,
        ) as agent,
    ):
        # First run (should work)
        print("\n1. First run:")
        query = "What's the weather in Paris?"
        print(f"User: {query}")
        response = await agent.run(query)
        print(f"Agent: {response.text}")

        # Second run (should be terminated by middleware)
        print("\n2. Second run (should be terminated):")
        query = "What about the weather in London?"
        print(f"User: {query}")
        response = await agent.run(query)
        print(f"Agent: {response.text or 'No response (terminated)'}")

        # Third run (should also be terminated)
        print("\n3. Third run (should also be terminated):")
        query = "And New York?"
        print(f"User: {query}")
        response = await agent.run(query)
        print(f"Agent: {response.text or 'No response (terminated)'}")
|
||||
|
||||
|
||||
async def main() -> None:
    """Run the pre-termination demo followed by the post-termination demo."""
    print("=== Middleware Termination Example ===")
    # The demos run one after the other; each creates its own agent.
    for demo in (pre_termination_middleware, post_termination_middleware):
        await demo()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: drive the async sample to completion.
    asyncio.run(main())
|
||||
Reference in New Issue
Block a user