# Copyright (c) Microsoft. All rights reserved.

"""
Class-based Middleware Example

This sample demonstrates how to implement middleware using a class-based approach by inheriting
from the AgentMiddleware and FunctionMiddleware base classes. The example includes:

- SecurityAgentMiddleware: Checks for security violations in user queries and blocks requests
  containing sensitive information like passwords or secrets
- LoggingFunctionMiddleware: Logs function calls, including the function name and how long
  each call took

This approach is useful when you need stateful middleware or complex logic that benefits from
object-oriented design patterns.
"""

import asyncio
import time
from collections.abc import Awaitable, Callable
from random import choice, randint
from typing import Annotated

from agent_framework import (
    AgentMiddleware,
    AgentResponse,
    AgentRunContext,
    ChatMessage,
    FunctionInvocationContext,
    FunctionMiddleware,
    Role,
    tool,
)
from agent_framework.azure import AzureAIAgentClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field


# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production;
# see samples/getting_started/tools/function_tool_with_approval.py and
# samples/getting_started/tools/function_tool_with_approval_and_threads.py.
@tool(approval_mode="never_require")
def get_weather(
    location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
    """Get the weather for a given location."""
    # Returns fake, randomly generated weather data; no external service is called.
    conditions = ["sunny", "cloudy", "rainy", "stormy"]
    # choice() draws from the whole list. Unlike indexing with a hard-coded randint
    # bound, editing the list can never cause an IndexError or unreachable entries.
    return f"The weather in {location} is {choice(conditions)} with a high of {randint(10, 30)}°C."
class SecurityAgentMiddleware(AgentMiddleware):
    """Agent middleware that checks for security violations."""

    async def process(
        self,
        context: AgentRunContext,
        next: Callable[[AgentRunContext], Awaitable[None]],
    ) -> None:
        """Block the agent run if the latest user message mentions sensitive terms.

        Args:
            context: The agent run context. Its ``messages`` are inspected, and its
                ``result`` is overwritten with a warning when the request is blocked.
            next: Continuation that runs the rest of the middleware pipeline and the agent.
        """
        # Find the most recent *user* message; messages from other roles are not user input.
        # (A loop is used because the `next` parameter shadows the builtin next().)
        last_user_message = None
        for message in reversed(context.messages or []):
            if message.role == Role.USER:
                last_user_message = message
                break

        if last_user_message is not None and last_user_message.text:
            query = last_user_message.text.lower()
            if any(term in query for term in ("password", "secret")):
                print("[SecurityAgentMiddleware] Security Warning: Detected sensitive information, blocking request.")
                # Override the result with a warning message
                context.result = AgentResponse(
                    messages=[
                        ChatMessage(role=Role.ASSISTANT, text="Detected sensitive information, the request is blocked.")
                    ]
                )
                # Don't call next(), so the rest of the pipeline (and the agent) never runs
                return

        print("[SecurityAgentMiddleware] Security check passed.")
        await next(context)


class LoggingFunctionMiddleware(FunctionMiddleware):
    """Function middleware that logs function calls."""

    async def process(
        self,
        context: FunctionInvocationContext,
        next: Callable[[FunctionInvocationContext], Awaitable[None]],
    ) -> None:
        """Log the function name before invocation and its elapsed time afterwards.

        Args:
            context: The function invocation context; ``context.function.name`` is logged.
            next: Continuation that performs the actual function invocation.
        """
        function_name = context.function.name
        print(f"[LoggingFunctionMiddleware] About to call function: {function_name}.")
        # perf_counter() is monotonic and high-resolution. time.time() follows the wall
        # clock, which can jump (e.g. NTP adjustments), so it is unsuitable for durations.
        start_time = time.perf_counter()
        await next(context)
        duration = time.perf_counter() - start_time
        print(f"[LoggingFunctionMiddleware] Function {function_name} completed in {duration:.5f}s.")


async def main() -> None:
    """Example demonstrating class-based middleware."""
    print("=== Class-based Middleware Example ===")

    # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
    # authentication option.
    async with (
        AzureCliCredential() as credential,
        AzureAIAgentClient(credential=credential).as_agent(
            name="WeatherAgent",
            instructions="You are a helpful weather assistant.",
            tools=get_weather,
            middleware=[SecurityAgentMiddleware(), LoggingFunctionMiddleware()],
        ) as agent,
    ):
        # Test with normal query
        print("\n--- Normal Query ---")
        query = "What's the weather like in Seattle?"
        print(f"User: {query}")
        result = await agent.run(query)
        print(f"Agent: {result.text}\n")

        # Test with security-related query
        print("--- Security Test ---")
        query = "What's the password for the weather service?"
        print(f"User: {query}")
        result = await agent.run(query)
        print(f"Agent: {result.text}\n")


if __name__ == "__main__":
    asyncio.run(main())