# Copyright (c) Microsoft. All rights reserved. import asyncio from collections.abc import Awaitable, Callable from random import randint from typing import Annotated from agent_framework import ( Agent, ChatContext, ChatMiddleware, ChatResponse, Message, MiddlewareTermination, chat_middleware, tool, ) from agent_framework.foundry import FoundryChatClient from azure.identity.aio import AzureCliCredential from dotenv import load_dotenv from pydantic import Field # Load environment variables from .env file load_dotenv() """ Chat MiddlewareTypes Example This sample demonstrates how to use chat middleware to observe and override inputs sent to AI models. Chat middleware intercepts chat requests before they reach the underlying AI service, allowing you to: 1. Observe and log input messages 2. Modify input messages before sending to AI 3. Override the entire response The example covers: - Class-based chat middleware inheriting from ChatMiddleware - Function-based chat middleware with @chat_middleware decorator - MiddlewareTypes registration at agent level (applies to all runs) - MiddlewareTypes registration at run level (applies to specific run only) """ # NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; # see samples/02-agents/tools/function_tool_with_approval.py # and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], ) -> str: """Get the weather for a given location.""" conditions = ["sunny", "cloudy", "rainy", "stormy"] return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." class InputObserverMiddleware(ChatMiddleware): """Class-based middleware that observes and modifies input messages.""" def __init__(self, replacement: str | None = None): """Initialize with a replacement for user messages.""" self.replacement = replacement async def process( self, context: ChatContext, call_next: Callable[[], Awaitable[None]], ) -> None: """Observe and modify input messages before they are sent to AI.""" print("[InputObserverMiddleware] Observing input messages:") for i, message in enumerate(context.messages): content = message.text if message.text else str(message.contents) print(f" Message {i + 1} ({message.role}): {content}") print(f"[InputObserverMiddleware] Total messages: {len(context.messages)}") # Modify user messages by creating new messages with enhanced text modified_messages: list[Message] = [] modified_count = 0 for message in context.messages: if message.role == "user" and message.text: original_text = message.text updated_text = original_text if self.replacement: updated_text = self.replacement print(f"[InputObserverMiddleware] Updated: '{original_text}' -> '{updated_text}'") modified_message = Message(message.role, [updated_text]) modified_messages.append(modified_message) modified_count += 1 else: modified_messages.append(message) # Replace messages in context context.messages[:] = modified_messages # Continue to next middleware or AI execution await call_next() # Observe that processing is complete print("[InputObserverMiddleware] Processing completed") @chat_middleware async def security_and_override_middleware( context: ChatContext, call_next: Callable[[], Awaitable[None]], ) -> None: """Function-based middleware that implements security filtering and response override.""" print("[SecurityMiddleware] Processing input...") # Security check - block sensitive information blocked_terms = ["password", "secret", "api_key", "token"] for message in context.messages: if message.text: message_lower = message.text.lower() for term in blocked_terms: if term in message_lower: print(f"[SecurityMiddleware] BLOCKED: Found '{term}' in message") # Override the response instead of calling AI context.result = ChatResponse( messages=[ Message( role="assistant", text="I cannot process requests containing sensitive information. " "Please rephrase your question without including passwords, secrets, or other " "sensitive data.", ) ] ) # Set terminate flag to stop execution raise MiddlewareTermination # Continue to next middleware or AI execution await call_next() async def class_based_chat_middleware() -> None: """Demonstrate class-based middleware at agent level.""" print("\n" + "=" * 60) print("Class-based Chat MiddlewareTypes (Agent Level)") print("=" * 60) # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred # authentication option. async with ( AzureCliCredential() as credential, Agent( client=FoundryChatClient(credential=credential), name="EnhancedChatAgent", instructions="You are a helpful AI assistant.", # Register class-based middleware at agent level (applies to all runs) middleware=[InputObserverMiddleware()], tools=get_weather, ) as agent, ): query = "What's the weather in Seattle?" print(f"User: {query}") result = await agent.run(query) print(f"Final Response: {result.text if result.text else 'No response'}") async def function_based_chat_middleware() -> None: """Demonstrate function-based middleware at agent level.""" print("\n" + "=" * 60) print("Function-based Chat MiddlewareTypes (Agent Level)") print("=" * 60) async with ( AzureCliCredential() as credential, Agent( client=FoundryChatClient(credential=credential), name="FunctionMiddlewareAgent", instructions="You are a helpful AI assistant.", # Register function-based middleware at agent level middleware=[security_and_override_middleware], ) as agent, ): # Scenario with normal query print("\n--- Scenario 1: Normal Query ---") query = "Hello, how are you?" print(f"User: {query}") result = await agent.run(query) print(f"Final Response: {result.text if result.text else 'No response'}") # Scenario with security violation print("\n--- Scenario 2: Security Violation ---") query = "What is my password for this account?" print(f"User: {query}") result = await agent.run(query) print(f"Final Response: {result.text if result.text else 'No response'}") async def run_level_middleware() -> None: """Demonstrate middleware registration at run level.""" print("\n" + "=" * 60) print("Run-level Chat MiddlewareTypes") print("=" * 60) async with ( AzureCliCredential() as credential, Agent( client=FoundryChatClient(credential=credential), name="RunLevelAgent", instructions="You are a helpful AI assistant.", tools=get_weather, # No middleware at agent level ) as agent, ): # Scenario 1: Run without any middleware print("\n--- Scenario 1: No MiddlewareTypes ---") query = "What's the weather in Tokyo?" print(f"User: {query}") result = await agent.run(query) print(f"Response: {result.text if result.text else 'No response'}") # Scenario 2: Run with specific middleware for this call only (both enhancement and security) print("\n--- Scenario 2: With Run-level MiddlewareTypes ---") print(f"User: {query}") result = await agent.run( query, middleware=[ InputObserverMiddleware(replacement="What's the weather in Madrid?"), security_and_override_middleware, ], ) print(f"Response: {result.text if result.text else 'No response'}") # Scenario 3: Security test with run-level middleware print("\n--- Scenario 3: Security Test with Run-level MiddlewareTypes ---") query = "Can you help me with my secret API key?" print(f"User: {query}") result = await agent.run( query, middleware=[security_and_override_middleware], ) print(f"Response: {result.text if result.text else 'No response'}") async def main() -> None: """Run all chat middleware examples.""" print("Chat MiddlewareTypes Examples") print("========================") await class_based_chat_middleware() await function_based_chat_middleware() await run_level_middleware() if __name__ == "__main__": asyncio.run(main())