Python: Agent and Function middleware (#770)

* Initial middleware implementation

* Small fixes

* Small updates

* Small updates in samples

* Moved middleware functionality to decorator

* Removed obsolete file

* Renamed AgentInvocationContext to AzureRunContext

* Added unit tests

* Small settings update for test discovery in VS Code

* Added unit tests

* Reverted changes in environment settings

* Added context result override

* Renaming and updates to logic

* Added more samples

* Updated DEV_SETUP.md

* Addressed PR feedback

* Addressed PR feedback

* Removed unused parameter

* Small fix

* Small fix in telemetry logic

* Revert "Small fix in telemetry logic"

This reverts commit 6f82660d2d.

* Small fix

---------

Co-authored-by: Chris <66376200+crickman@users.noreply.github.com>
This commit is contained in:
Dmytro Struk
2025-09-18 16:30:05 -07:00
committed by GitHub
Unverified
parent 538be4c149
commit 99860a5d07
16 changed files with 3071 additions and 40 deletions
@@ -0,0 +1,125 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import time
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated
from agent_framework import (
AgentMiddleware,
AgentRunContext,
AgentRunResponse,
ChatMessage,
FunctionInvocationContext,
FunctionMiddleware,
Role,
)
from agent_framework.foundry import FoundryChatClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field
"""
Class-based Middleware Example
This sample demonstrates how to implement middleware using class-based approach by inheriting
from AgentMiddleware and FunctionMiddleware base classes. The example includes:
- SecurityAgentMiddleware: Checks for security violations in user queries and blocks requests
containing sensitive information like passwords or secrets
- LoggingFunctionMiddleware: Logs function execution details including timing and parameters
This approach is useful when you need stateful middleware or complex logic that benefits
from object-oriented design patterns.
"""
def get_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
"""Get the weather for a given location."""
conditions = ["sunny", "cloudy", "rainy", "stormy"]
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
class SecurityAgentMiddleware(AgentMiddleware):
"""Agent middleware that checks for security violations."""
async def process(
self,
context: AgentRunContext,
next: Callable[[AgentRunContext], Awaitable[None]],
) -> None:
# Check for potential security violations in the query
# Look at the last user message
last_message = context.messages[-1] if context.messages else None
if last_message and last_message.text:
query = last_message.text
if "password" in query.lower() or "secret" in query.lower():
print("[SecurityAgentMiddleware] Security Warning: Detected sensitive information, blocking request.")
# Override the result with warning message
context.result = AgentRunResponse(
messages=[
ChatMessage(role=Role.ASSISTANT, text="Detected sensitive information, the request is blocked.")
]
)
# Simply don't call next() to prevent execution
return
print("[SecurityAgentMiddleware] Security check passed.")
await next(context)
class LoggingFunctionMiddleware(FunctionMiddleware):
"""Function middleware that logs function calls."""
async def process(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
function_name = context.function.name
print(f"[LoggingFunctionMiddleware] About to call function: {function_name}.")
start_time = time.time()
await next(context)
end_time = time.time()
duration = end_time - start_time
print(f"[LoggingFunctionMiddleware] Function {function_name} completed in {duration:.5f}s.")
async def main() -> None:
"""Example demonstrating class-based middleware."""
print("=== Class-based Middleware Example ===")
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
# authentication option.
async with (
AzureCliCredential() as credential,
FoundryChatClient(async_credential=credential).create_agent(
name="WeatherAgent",
instructions="You are a helpful weather assistant.",
tools=get_weather,
middleware=[SecurityAgentMiddleware(), LoggingFunctionMiddleware()],
) as agent,
):
# Test with normal query
print("\n--- Normal Query ---")
query = "What's the weather like in Seattle?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text}\n")
# Test with security-related query
print("--- Security Test ---")
query = "What's the password for the weather service?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text}\n")
if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,75 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from collections.abc import Awaitable, Callable
from typing import Annotated
from agent_framework import FunctionInvocationContext
from agent_framework.foundry import FoundryChatClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field
"""
Exception Handling with Middleware
This sample demonstrates how to use middleware for centralized exception handling in function calls.
The example shows:
- How to catch exceptions thrown by functions and provide graceful error responses
- Overriding function results when errors occur to provide user-friendly messages
- Using middleware to implement retry logic, fallback mechanisms, or error reporting
The middleware catches TimeoutError from an unstable data service and replaces it with
a helpful message for the user, preventing raw exceptions from reaching the end user.
"""
def unstable_data_service(
query: Annotated[str, Field(description="The data query to execute.")],
) -> str:
"""A simulated data service that sometimes throws exceptions."""
# Simulate failure
raise TimeoutError("Data service request timed out")
async def exception_handling_middleware(
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
function_name = context.function.name
try:
print(f"[ExceptionHandlingMiddleware] Executing function: {function_name}")
await next(context)
print(f"[ExceptionHandlingMiddleware] Function {function_name} completed successfully.")
except TimeoutError as e:
print(f"[ExceptionHandlingMiddleware] Caught TimeoutError: {e}")
# Override function result to provide custom message in response.
context.result = (
"Request Timeout: The data service is taking longer than expected to respond.",
"Respond with message - 'Sorry for the inconvenience, please try again later.'",
)
async def main() -> None:
"""Example demonstrating exception handling with middleware."""
print("=== Exception Handling Middleware Example ===")
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
# authentication option.
async with (
AzureCliCredential() as credential,
FoundryChatClient(async_credential=credential).create_agent(
name="DataAgent",
instructions="You are a helpful data assistant. Use the data service tool to fetch information for users.",
tools=unstable_data_service,
middleware=exception_handling_middleware,
) as agent,
):
query = "Get user statistics"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result}")
if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,109 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import time
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated
from agent_framework import (
AgentRunContext,
FunctionInvocationContext,
)
from agent_framework.foundry import FoundryChatClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field
"""
Function-based Middleware Example
This sample demonstrates how to implement middleware using simple async functions instead of classes.
The example includes:
- Security middleware that validates agent requests for sensitive information
- Logging middleware that tracks function execution timing and parameters
- Performance monitoring to measure execution duration
Function-based middleware is ideal for simple, stateless operations and provides a more
lightweight approach compared to class-based middleware. Both agent and function middleware
can be implemented as async functions that accept context and next parameters.
"""
def get_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
"""Get the weather for a given location."""
conditions = ["sunny", "cloudy", "rainy", "stormy"]
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
async def security_agent_middleware(
context: AgentRunContext,
next: Callable[[AgentRunContext], Awaitable[None]],
) -> None:
"""Agent middleware that checks for security violations."""
# Check for potential security violations in the query
# For this example, we'll check the last user message
last_message = context.messages[-1] if context.messages else None
if last_message and last_message.text:
query = last_message.text
if "password" in query.lower() or "secret" in query.lower():
print("[SecurityAgentMiddleware] Security Warning: Detected sensitive information, blocking request.")
# Simply don't call next() to prevent execution
return
print("[SecurityAgentMiddleware] Security check passed.")
await next(context)
async def logging_function_middleware(
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
"""Function middleware that logs function calls."""
function_name = context.function.name
print(f"[LoggingFunctionMiddleware] About to call function: {function_name}.")
start_time = time.time()
await next(context)
end_time = time.time()
duration = end_time - start_time
print(f"[LoggingFunctionMiddleware] Function {function_name} completed in {duration:.5f}s.")
async def main() -> None:
"""Example demonstrating function-based middleware."""
print("=== Function-based Middleware Example ===")
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
# authentication option.
async with (
AzureCliCredential() as credential,
FoundryChatClient(async_credential=credential).create_agent(
name="WeatherAgent",
instructions="You are a helpful weather assistant.",
tools=get_weather,
middleware=[security_agent_middleware, logging_function_middleware],
) as agent,
):
# Test with normal query
print("\n--- Normal Query ---")
query = "What's the weather like in Tokyo?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text if result.text else 'No response'}\n")
# Test with security violation
print("--- Security Test ---")
query = "What's the secret weather password?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text if result.text else 'No response'}\n")
if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,84 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated
from agent_framework import FunctionInvocationContext
from agent_framework.foundry import FoundryChatClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field
"""
Result Override with Middleware
This sample demonstrates how to use middleware to intercept and modify function results
after execution. The example shows:
- How to execute the original function first and then modify its result
- Replacing function outputs with custom messages or transformed data
- Using middleware for result filtering, formatting, or enhancement
The weather override middleware lets the original weather function execute normally,
then replaces its result with a custom "perfect weather" message, demonstrating
how middleware can be used for content filtering, A/B testing, or result enhancement.
"""
def get_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
"""Get the weather for a given location."""
conditions = ["sunny", "cloudy", "rainy", "stormy"]
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
async def weather_override_middleware(
context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
function_name = context.function.name
# Let the original function execute first
await next(context)
# Override the result if it's a weather function
if function_name == "get_weather" and context.result is not None:
original_result = str(context.result)
print(f"[WeatherOverrideMiddleware] Original result: {original_result}")
# Override with a custom message
# It's also possible to override the result before "next()" call if needed
custom_message = (
"Weather Advisory - due to special atmospheric conditions, "
"all locations are experiencing perfect weather today! "
"Temperature is a comfortable 22°C with gentle breezes. "
"Perfect day for outdoor activities!"
)
context.result = custom_message
print(f"[WeatherOverrideMiddleware] Overriding with custom message: {custom_message}")
async def main() -> None:
"""Example demonstrating result override with middleware."""
print("=== Result Override Middleware Example ===")
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
# authentication option.
async with (
AzureCliCredential() as credential,
FoundryChatClient(async_credential=credential).create_agent(
name="WeatherAgent",
instructions="You are a helpful weather assistant. Use the weather tool to get current conditions.",
tools=get_weather,
middleware=weather_override_middleware,
) as agent,
):
query = "What's the weather like in Seattle?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result}")
if __name__ == "__main__":
asyncio.run(main())