Python: Extending middleware capabilities (#844)

* Implemented termination

* Added termination sample

* Allowed middleware pipeline modification

* Added run-level middleware

* Added more validation to function-based middleware

* Added example with function-based decorator approach

* Update python/samples/getting_started/middleware/decorator_middleware.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update python/samples/getting_started/middleware/decorator_middleware.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Small improvements

* Fixed tests

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Dmytro Struk
2025-09-21 21:37:57 -07:00
committed by GitHub
Unverified
parent 08f792e511
commit f61d8abe58
6 changed files with 1680 additions and 119 deletions
@@ -0,0 +1,271 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import time
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated
from agent_framework import (
AgentMiddleware,
AgentRunContext,
AgentRunResponse,
FunctionInvocationContext,
)
from agent_framework.foundry import FoundryChatClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field
"""
Agent-Level and Run-Level Middleware Example
This sample demonstrates the difference between agent-level and run-level middleware:
- Agent-level middleware: Applied to ALL runs of the agent (persistent across runs)
- Run-level middleware: Applied to specific runs only (isolated per run)
The example shows:
1. Agent-level security middleware that validates all requests
2. Agent-level performance monitoring across all runs
3. Run-level context middleware for specific use cases (high priority, debugging)
4. Run-level caching middleware for expensive operations
Execution order: Agent middleware (outermost) -> Run middleware (innermost) -> Agent execution
"""
def get_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
"""Get the weather for a given location."""
conditions = ["sunny", "cloudy", "rainy", "stormy"]
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
# Agent-level middleware (applied to ALL runs)
class SecurityAgentMiddleware(AgentMiddleware):
"""Agent-level security middleware that validates all requests."""
async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None:
print("[SecurityMiddleware] Checking security for all requests...")
# Check for security violations in the last user message
last_message = context.messages[-1] if context.messages else None
if last_message and last_message.text:
query = last_message.text.lower()
if any(word in query for word in ["password", "secret", "credentials"]):
print("[SecurityMiddleware] Security violation detected! Blocking request.")
return # Don't call next() to prevent execution
print("[SecurityMiddleware] Security check passed.")
context.metadata["security_validated"] = True
await next(context)
async def performance_monitor_middleware(
context: AgentRunContext,
next: Callable[[AgentRunContext], Awaitable[None]],
) -> None:
"""Agent-level performance monitoring for all runs."""
print("[PerformanceMonitor] Starting performance monitoring...")
start_time = time.time()
await next(context)
end_time = time.time()
duration = end_time - start_time
print(f"[PerformanceMonitor] Total execution time: {duration:.3f}s")
context.metadata["execution_time"] = duration
# Run-level middleware (applied to specific runs only)
class HighPriorityMiddleware(AgentMiddleware):
"""Run-level middleware for high priority requests."""
async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None:
print("[HighPriority] Processing high priority request with expedited handling...")
# Read metadata set by agent-level middleware
if context.metadata.get("security_validated"):
print("[HighPriority] Security validation confirmed from agent middleware")
# Set high priority flag
context.metadata["priority"] = "high"
context.metadata["expedited"] = True
await next(context)
print("[HighPriority] High priority processing completed")
async def debugging_middleware(
context: AgentRunContext,
next: Callable[[AgentRunContext], Awaitable[None]],
) -> None:
"""Run-level debugging middleware for troubleshooting specific runs."""
print("[Debug] Debug mode enabled for this run")
print(f"[Debug] Messages count: {len(context.messages)}")
print(f"[Debug] Is streaming: {context.is_streaming}")
# Log existing metadata from agent middleware
if context.metadata:
print(f"[Debug] Existing metadata: {context.metadata}")
context.metadata["debug_enabled"] = True
await next(context)
print("[Debug] Debug information collected")
class CachingMiddleware(AgentMiddleware):
"""Run-level caching middleware for expensive operations."""
def __init__(self) -> None:
self.cache: dict[str, AgentRunResponse] = {}
async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None:
# Create a simple cache key from the last message
last_message = context.messages[-1] if context.messages else None
cache_key: str = last_message.text if last_message and last_message.text else "no_message"
if cache_key in self.cache:
print(f"[Cache] Cache HIT for: '{cache_key[:30]}...'")
context.result = self.cache[cache_key] # type: ignore
return # Don't call next(), return cached result
print(f"[Cache] Cache MISS for: '{cache_key[:30]}...'")
context.metadata["cache_key"] = cache_key
await next(context)
# Cache the result if we have one
if context.result:
self.cache[cache_key] = context.result # type: ignore
print("[Cache] Result cached for future use")
async def function_logging_middleware(
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
"""Function middleware that logs all function calls."""
function_name = context.function.name
args = context.arguments
print(f"[FunctionLog] Calling function: {function_name} with args: {args}")
await next(context)
print(f"[FunctionLog] Function {function_name} completed")
async def main() -> None:
"""Example demonstrating agent-level and run-level middleware."""
print("=== Agent-Level and Run-Level Middleware Example ===\n")
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
# authentication option.
async with (
AzureCliCredential() as credential,
FoundryChatClient(async_credential=credential).create_agent(
name="WeatherAgent",
instructions="You are a helpful weather assistant.",
tools=get_weather,
# Agent-level middleware: applied to ALL runs
middleware=[
SecurityAgentMiddleware(),
performance_monitor_middleware,
function_logging_middleware,
],
) as agent,
):
print("Agent created with agent-level middleware:")
print(" - SecurityMiddleware (blocks sensitive requests)")
print(" - PerformanceMonitor (tracks execution time)")
print(" - FunctionLogging (logs all function calls)")
print()
# Run 1: Normal query with no run-level middleware
print("=" * 60)
print("RUN 1: Normal query (agent-level middleware only)")
print("=" * 60)
query = "What's the weather like in Paris?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text if result.text else 'No response'}")
print()
# Run 2: High priority request with run-level middleware
print("=" * 60)
print("RUN 2: High priority request (agent + run-level middleware)")
print("=" * 60)
query = "What's the weather in Tokyo? This is urgent!"
print(f"User: {query}")
result = await agent.run(
query,
middleware=HighPriorityMiddleware(), # Run-level middleware
)
print(f"Agent: {result.text if result.text else 'No response'}")
print()
# Run 3: Debug mode with run-level debugging middleware
print("=" * 60)
print("RUN 3: Debug mode (agent + run-level debugging)")
print("=" * 60)
query = "What's the weather in London?"
print(f"User: {query}")
result = await agent.run(
query,
middleware=[debugging_middleware], # Run-level middleware
)
print(f"Agent: {result.text if result.text else 'No response'}")
print()
# Run 4: Multiple run-level middleware
print("=" * 60)
print("RUN 4: Multiple run-level middleware (caching + debug)")
print("=" * 60)
caching = CachingMiddleware()
query = "What's the weather in New York?"
print(f"User: {query}")
result = await agent.run(
query,
middleware=[caching, debugging_middleware], # Multiple run-level middleware
)
print(f"Agent: {result.text if result.text else 'No response'}")
print()
# Run 5: Test cache hit with same query
print("=" * 60)
print("RUN 5: Test cache hit (same query as Run 4)")
print("=" * 60)
print(f"User: {query}") # Same query as Run 4
result = await agent.run(
query,
middleware=[caching], # Same caching middleware instance
)
print(f"Agent: {result.text if result.text else 'No response'}")
print()
# Run 6: Security violation test
print("=" * 60)
print("RUN 6: Security test (should be blocked by agent middleware)")
print("=" * 60)
query = "What's the secret weather password for Berlin?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text if result.text else 'Request was blocked by security middleware'}")
print()
# Run 7: Normal query again (no run-level middleware interference)
print("=" * 60)
print("RUN 7: Normal query again (agent-level middleware only)")
print("=" * 60)
query = "What's the weather in Sydney?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text if result.text else 'No response'}")
print()
if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,87 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import datetime
from agent_framework import (
agent_middleware,
function_middleware,
)
from agent_framework.foundry import FoundryChatClient
from azure.identity.aio import AzureCliCredential
"""
Decorator Middleware Example
This sample demonstrates how to use @agent_middleware and @function_middleware decorators
to explicitly mark middleware functions without requiring type annotations.
The framework supports the following middleware detection scenarios:
1. Both decorator and parameter type specified:
- Validates that they match (e.g., @agent_middleware with AgentRunContext)
- Throws exception if they don't match for safety
2. Only decorator specified:
- Relies on decorator to determine middleware type
- No type annotations needed - framework handles context types automatically
3. Only parameter type specified:
- Uses type annotations (AgentRunContext, FunctionInvocationContext) for detection
4. Neither decorator nor parameter type specified:
- Throws exception requiring either decorator or type annotation
- Prevents ambiguous middleware that can't be properly classified
Key benefits of decorator approach:
- No type annotations needed (simpler syntax)
- Explicit middleware type declaration
- Clear intent in code
- Prevents type mismatches
"""
def get_current_time() -> str:
"""Get the current time."""
return f"Current time is {datetime.datetime.now().strftime('%H:%M:%S')}"
@agent_middleware # Decorator marks this as agent middleware - no type annotations needed
async def simple_agent_middleware(context, next): # type: ignore - parameters intentionally untyped to demonstrate decorator functionality
"""Agent middleware that runs before and after agent execution."""
print("[Agent Middleware] Before agent execution")
await next(context)
print("[Agent Middleware] After agent execution")
@function_middleware # Decorator marks this as function middleware - no type annotations needed
async def simple_function_middleware(context, next): # type: ignore - parameters intentionally untyped to demonstrate decorator functionality
"""Function middleware that runs before and after function calls."""
print(f"[Function Middleware] Before calling: {context.function.name}") # type: ignore
await next(context)
print(f"[Function Middleware] After calling: {context.function.name}") # type: ignore
async def main() -> None:
"""Example demonstrating decorator-based middleware."""
print("=== Decorator Middleware Example ===")
# For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
# authentication option.
async with (
AzureCliCredential() as credential,
FoundryChatClient(async_credential=credential).create_agent(
name="TimeAgent",
instructions="You are a helpful time assistant. Call get_current_time when asked about time.",
tools=get_current_time,
middleware=[simple_agent_middleware, simple_function_middleware],
) as agent,
):
query = "What time is it?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text if result.text else 'No response'}")
if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,177 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated
from agent_framework import (
AgentMiddleware,
AgentRunContext,
AgentRunResponse,
ChatMessage,
Role,
)
from agent_framework.foundry import FoundryChatClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field
"""
Middleware Termination Example
This sample demonstrates how middleware can terminate execution using the `context.terminate` flag.
The example includes:
- PreTerminationMiddleware: Terminates execution before calling next() to prevent agent processing
- PostTerminationMiddleware: Allows processing to complete but terminates further execution
This is useful for implementing security checks, rate limiting, or early exit conditions.
"""
def get_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
"""Get the weather for a given location."""
conditions = ["sunny", "cloudy", "rainy", "stormy"]
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
class PreTerminationMiddleware(AgentMiddleware):
"""Middleware that terminates execution before calling the agent."""
def __init__(self, blocked_words: list[str]):
self.blocked_words = [word.lower() for word in blocked_words]
async def process(
self,
context: AgentRunContext,
next: Callable[[AgentRunContext], Awaitable[None]],
) -> None:
# Check if the user message contains any blocked words
last_message = context.messages[-1] if context.messages else None
if last_message and last_message.text:
query = last_message.text.lower()
for blocked_word in self.blocked_words:
if blocked_word in query:
print(f"[PreTerminationMiddleware] Blocked word '{blocked_word}' detected. Terminating request.")
# Set a custom response
context.result = AgentRunResponse(
messages=[
ChatMessage(
role=Role.ASSISTANT,
text=(
f"Sorry, I cannot process requests containing '{blocked_word}'. "
"Please rephrase your question."
),
)
]
)
# Set terminate flag to prevent further processing
context.terminate = True
break
await next(context)
class PostTerminationMiddleware(AgentMiddleware):
"""Middleware that allows processing but terminates after reaching max responses across multiple runs."""
def __init__(self, max_responses: int = 1):
self.max_responses = max_responses
self.response_count = 0
async def process(
self,
context: AgentRunContext,
next: Callable[[AgentRunContext], Awaitable[None]],
) -> None:
print(f"[PostTerminationMiddleware] Processing request (response count: {self.response_count})")
# Check if we should terminate before processing
if self.response_count >= self.max_responses:
print(
f"[PostTerminationMiddleware] Maximum responses ({self.max_responses}) reached. "
"Terminating further processing."
)
context.terminate = True
# Allow the agent to process normally
await next(context)
# Increment response count after processing
self.response_count += 1
async def pre_termination_middleware() -> None:
"""Demonstrate pre-termination middleware that blocks requests with certain words."""
print("\n--- Example 1: Pre-termination Middleware ---")
async with (
AzureCliCredential() as credential,
FoundryChatClient(async_credential=credential).create_agent(
name="WeatherAgent",
instructions="You are a helpful weather assistant.",
tools=get_weather,
middleware=PreTerminationMiddleware(blocked_words=["bad", "inappropriate"]),
) as agent,
):
# Test with normal query
print("\n1. Normal query:")
query = "What's the weather like in Seattle?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text}")
# Test with blocked word
print("\n2. Query with blocked word:")
query = "What's the bad weather in New York?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text}")
async def post_termination_middleware() -> None:
"""Demonstrate post-termination middleware that limits responses across multiple runs."""
print("\n--- Example 2: Post-termination Middleware ---")
async with (
AzureCliCredential() as credential,
FoundryChatClient(async_credential=credential).create_agent(
name="WeatherAgent",
instructions="You are a helpful weather assistant.",
tools=get_weather,
middleware=PostTerminationMiddleware(max_responses=1),
) as agent,
):
# First run (should work)
print("\n1. First run:")
query = "What's the weather in Paris?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text}")
# Second run (should be terminated by middleware)
print("\n2. Second run (should be terminated):")
query = "What about the weather in London?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text if result.text else 'No response (terminated)'}")
# Third run (should also be terminated)
print("\n3. Third run (should also be terminated):")
query = "And New York?"
print(f"User: {query}")
result = await agent.run(query)
print(f"Agent: {result.text if result.text else 'No response (terminated)'}")
async def main() -> None:
"""Example demonstrating middleware termination functionality."""
print("=== Middleware Termination Example ===")
await pre_termination_middleware()
await post_termination_middleware()
if __name__ == "__main__":
asyncio.run(main())