# Copyright (c) Microsoft. All rights reserved.

"""
Class-based Middleware Example

This sample demonstrates how to implement middleware using a class-based approach by inheriting
from the AgentMiddleware and FunctionMiddleware base classes. The example includes:

- SecurityAgentMiddleware: Checks the latest message for security violations and blocks
  requests containing sensitive keywords such as "password" or "secret"
- LoggingFunctionMiddleware: Logs function execution details, namely the function name and
  how long the call took

This approach is useful when you need stateful middleware or complex logic that benefits from
object-oriented design patterns.
"""

import asyncio
import time
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated

from agent_framework import (
    Agent,
    AgentContext,
    AgentMiddleware,
    AgentResponse,
    FunctionInvocationContext,
    FunctionMiddleware,
    Message,
    tool,
)
from agent_framework.foundry import FoundryChatClient
from azure.identity.aio import AzureCliCredential
from dotenv import load_dotenv
from pydantic import Field

# Load environment variables from .env file
load_dotenv()


# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production;
# see samples/02-agents/tools/function_tool_with_approval.py
# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
@tool(approval_mode="never_require")
def get_weather(
    location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
    """Get the weather for a given location."""
    # Simulated data: a random condition and high temperature; no weather service is called.
    conditions = ["sunny", "cloudy", "rainy", "stormy"]
    # Derive the index bound from the list so editing `conditions` can never go out of range.
    condition = conditions[randint(0, len(conditions) - 1)]
    return f"The weather in {location} is {condition} with a high of {randint(10, 30)}°C."
class SecurityAgentMiddleware(AgentMiddleware):
    """Agent middleware that checks for security violations."""

    # Lower-case keywords; if any appears in the latest message, the request is blocked.
    # Subclasses may override this tuple to customize the check.
    _BLOCKED_KEYWORDS: tuple[str, ...] = ("password", "secret")

    async def process(
        self,
        context: AgentContext,
        call_next: Callable[[], Awaitable[None]],
    ) -> None:
        """Block the run when the latest message mentions a sensitive keyword.

        Args:
            context: The agent invocation context. ``context.messages`` is inspected, and
                ``context.result`` is set to a warning response when the request is blocked.
            call_next: Continues the middleware pipeline; it is not called for blocked requests.
        """
        # Inspect only the most recent message (its role is not checked).
        last_message = context.messages[-1] if context.messages else None
        if last_message and last_message.text:
            query = last_message.text.lower()
            if any(keyword in query for keyword in self._BLOCKED_KEYWORDS):
                print("[SecurityAgentMiddleware] Security Warning: Detected sensitive information, blocking request.")
                # Override the result with a warning message
                context.result = AgentResponse(
                    messages=[Message("assistant", ["Detected sensitive information, the request is blocked."])]
                )
                # Simply don't call call_next() to prevent execution
                return

        print("[SecurityAgentMiddleware] Security check passed.")
        await call_next()


class LoggingFunctionMiddleware(FunctionMiddleware):
    """Function middleware that logs function calls and how long they take."""

    async def process(
        self,
        context: FunctionInvocationContext,
        call_next: Callable[[], Awaitable[None]],
    ) -> None:
        """Print the function name before the call and the elapsed time after it.

        Args:
            context: The function invocation context; only ``context.function.name`` is read.
            call_next: Executes the function (and any later middleware).
        """
        function_name = context.function.name
        print(f"[LoggingFunctionMiddleware] About to call function: {function_name}.")

        # perf_counter() is monotonic and high-resolution; time.time() is wall-clock time and
        # can jump (e.g. clock adjustments), yielding wrong or even negative durations.
        start_time = time.perf_counter()
        await call_next()
        duration = time.perf_counter() - start_time

        print(f"[LoggingFunctionMiddleware] Function {function_name} completed in {duration:.5f}s.")


async def main() -> None:
    """Example demonstrating class-based middleware."""
    print("=== Class-based MiddlewareTypes Example ===")

    # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
    # authentication option.
    async with (
        AzureCliCredential() as credential,
        Agent(
            client=FoundryChatClient(credential=credential),
            name="WeatherAgent",
            instructions="You are a helpful weather assistant.",
            tools=get_weather,
            middleware=[SecurityAgentMiddleware(), LoggingFunctionMiddleware()],
        ) as agent,
    ):
        # Test with normal query
        print("\n--- Normal Query ---")
        query = "What's the weather like in Seattle?"
        print(f"User: {query}")
        result = await agent.run(query)
        print(f"Agent: {result.text}\n")

        # Test with security-related query: SecurityAgentMiddleware should block this one
        # because it contains the keyword "password".
        print("--- Security Test ---")
        query = "What's the password for the weather service?"
        print(f"User: {query}")
        result = await agent.run(query)
        print(f"Agent: {result.text}\n")


if __name__ == "__main__":
    asyncio.run(main())