mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Extending middleware capabilities (#844)
* Implemented termination * Added termination sample * Allowed middleware pipeline modification * Added run-level middleware * Added more validation to function-based middleware * Added example with function-based decorator approach * Update python/samples/getting_started/middleware/decorator_middleware.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update python/samples/getting_started/middleware/decorator_middleware.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Small improvements * Fixed tests --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
08f792e511
commit
f61d8abe58
@@ -0,0 +1,271 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from random import randint
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import (
|
||||
AgentMiddleware,
|
||||
AgentRunContext,
|
||||
AgentRunResponse,
|
||||
FunctionInvocationContext,
|
||||
)
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from pydantic import Field
|
||||
|
||||
"""
|
||||
Agent-Level and Run-Level Middleware Example
|
||||
|
||||
This sample demonstrates the difference between agent-level and run-level middleware:
|
||||
|
||||
- Agent-level middleware: Applied to ALL runs of the agent (persistent across runs)
|
||||
- Run-level middleware: Applied to specific runs only (isolated per run)
|
||||
|
||||
The example shows:
|
||||
1. Agent-level security middleware that validates all requests
|
||||
2. Agent-level performance monitoring across all runs
|
||||
3. Run-level context middleware for specific use cases (high priority, debugging)
|
||||
4. Run-level caching middleware for expensive operations
|
||||
|
||||
Execution order: Agent middleware (outermost) -> Run middleware (innermost) -> Agent execution
|
||||
"""
|
||||
|
||||
|
||||
def get_weather(
    location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
    """Get the weather for a given location."""
    # NOTE(review): the docstring above is kept verbatim because it is presumably
    # surfaced to the model as the tool description - confirm before rewording.
    # The report is demo data: condition and temperature are drawn at random on every call.
    conditions = ("sunny", "cloudy", "rainy", "stormy")
    condition = conditions[randint(0, len(conditions) - 1)]
    high = randint(10, 30)
    return f"The weather in {location} is {condition} with a high of {high}°C."
|
||||
|
||||
|
||||
# Agent-level middleware (applied to ALL runs)
|
||||
class SecurityAgentMiddleware(AgentMiddleware):
    """Agent-level guard that refuses requests mentioning sensitive terms.

    Only the last message in the run context is inspected. When it contains one of
    the sensitive terms, the middleware returns without invoking ``next``; otherwise
    it records ``security_validated`` in the context metadata and continues.
    """

    async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None:
        print("[SecurityMiddleware] Checking security for all requests...")

        # Look only at the most recent message, if there is one with text.
        messages = context.messages
        latest = messages[-1] if messages else None
        if latest and latest.text:
            lowered = latest.text.lower()
            for term in ("password", "secret", "credentials"):
                if term in lowered:
                    print("[SecurityMiddleware] Security violation detected! Blocking request.")
                    # Returning without awaiting next() keeps the rest of the pipeline from running.
                    return

        print("[SecurityMiddleware] Security check passed.")
        # Leave a marker that later middleware can read from the shared metadata.
        context.metadata["security_validated"] = True
        await next(context)
|
||||
|
||||
|
||||
async def performance_monitor_middleware(
    context: AgentRunContext,
    next: Callable[[AgentRunContext], Awaitable[None]],
) -> None:
    """Agent-level performance monitoring for all runs.

    Times everything that happens inside ``next(context)``, prints the elapsed
    time and stores it (in seconds, as a float) under
    ``context.metadata["execution_time"]``.

    Args:
        context: The run context shared by the middleware chain.
        next: Continuation that invokes the rest of the pipeline.
    """
    print("[PerformanceMonitor] Starting performance monitoring...")
    # perf_counter() is monotonic and high-resolution; time.time() follows the wall
    # clock and can jump (NTP sync, manual changes), which would corrupt the duration.
    started = time.perf_counter()

    await next(context)

    duration = time.perf_counter() - started
    print(f"[PerformanceMonitor] Total execution time: {duration:.3f}s")
    context.metadata["execution_time"] = duration
|
||||
|
||||
|
||||
# Run-level middleware (applied to specific runs only)
|
||||
class HighPriorityMiddleware(AgentMiddleware):
    """Run-level middleware that tags a run as high priority.

    Reads the ``security_validated`` marker from the context metadata (if present)
    and adds ``priority`` and ``expedited`` entries before continuing the pipeline.
    """

    async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None:
        print("[HighPriority] Processing high priority request with expedited handling...")

        metadata = context.metadata

        # Report the marker written by the security middleware, when it is there.
        if metadata.get("security_validated"):
            print("[HighPriority] Security validation confirmed from agent middleware")

        # Flag the run so anything further down the chain can see its priority.
        metadata["priority"] = "high"
        metadata["expedited"] = True

        await next(context)
        print("[HighPriority] High priority processing completed")
|
||||
|
||||
|
||||
async def debugging_middleware(
    context: AgentRunContext,
    next: Callable[[AgentRunContext], Awaitable[None]],
) -> None:
    """Run-level middleware that prints diagnostic details about a single run.

    Prints the message count, the streaming flag and any metadata already present,
    sets ``debug_enabled`` in the metadata, then continues the pipeline.
    """
    print("[Debug] Debug mode enabled for this run")
    message_count = len(context.messages)
    print(f"[Debug] Messages count: {message_count}")
    print(f"[Debug] Is streaming: {context.is_streaming}")

    # Show whatever earlier middleware has already stored, skipping empty metadata.
    existing = context.metadata
    if existing:
        print(f"[Debug] Existing metadata: {existing}")

    context.metadata["debug_enabled"] = True

    await next(context)

    print("[Debug] Debug information collected")
|
||||
|
||||
|
||||
class CachingMiddleware(AgentMiddleware):
    """Run-level caching middleware for expensive operations.

    Responses are cached per middleware instance, keyed by the text of the last
    message in the run (or ``"no_message"`` when there is none). On a cache hit the
    stored response is assigned to ``context.result`` and ``next`` is not invoked.
    The cache is an unbounded in-memory dict, which is adequate for a sample only.
    """

    def __init__(self) -> None:
        # Maps the cache key (last message text) to the response produced for it.
        self.cache: dict[str, AgentRunResponse] = {}

    async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None:
        # NOTE(review): for streaming runs context.result is presumably a one-shot
        # stream rather than an AgentRunResponse (the cache is typed for responses
        # only), so storing and replaying it would hand later callers an exhausted
        # stream. Bypass the cache for such runs - confirm against the framework.
        if context.is_streaming:
            print("[Cache] Streaming run - bypassing cache")
            await next(context)
            return

        # Create a simple cache key from the last message
        last_message = context.messages[-1] if context.messages else None
        cache_key: str = last_message.text if last_message and last_message.text else "no_message"

        if cache_key in self.cache:
            print(f"[Cache] Cache HIT for: '{cache_key[:30]}...'")
            context.result = self.cache[cache_key]  # type: ignore
            return  # Don't call next(), return cached result

        print(f"[Cache] Cache MISS for: '{cache_key[:30]}...'")
        context.metadata["cache_key"] = cache_key

        await next(context)

        # Cache the result if we have one
        if context.result:
            self.cache[cache_key] = context.result  # type: ignore
            print("[Cache] Result cached for future use")
|
||||
|
||||
|
||||
async def function_logging_middleware(
    context: FunctionInvocationContext,
    next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
    """Function middleware that prints a line before and after each function call.

    The name and arguments are read from the invocation context before the call is
    forwarded with ``next``.
    """
    name = context.function.name
    arguments = context.arguments
    print(f"[FunctionLog] Calling function: {name} with args: {arguments}")

    await next(context)

    print(f"[FunctionLog] Function {name} completed")
|
||||
|
||||
|
||||
async def main() -> None:
    """Run seven queries that contrast agent-level and run-level middleware.

    The agent is created with three agent-level middleware; individual runs then
    add run-level middleware through the ``middleware`` argument of ``agent.run``.
    """

    def announce(title: str, query: str) -> None:
        # Print the banner for one run followed by the user query.
        rule = "=" * 60
        print(rule)
        print(title)
        print(rule)
        print(f"User: {query}")

    def show(response: AgentRunResponse, fallback: str = "No response") -> None:
        # Print the agent reply (or the fallback when it has no text) and a blank line.
        print(f"Agent: {response.text or fallback}")
        print()

    print("=== Agent-Level and Run-Level Middleware Example ===\n")

    # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
    # authentication option.
    async with (
        AzureCliCredential() as credential,
        FoundryChatClient(async_credential=credential).create_agent(
            name="WeatherAgent",
            instructions="You are a helpful weather assistant.",
            tools=get_weather,
            # Agent-level middleware: applied to ALL runs
            middleware=[
                SecurityAgentMiddleware(),
                performance_monitor_middleware,
                function_logging_middleware,
            ],
        ) as agent,
    ):
        print("Agent created with agent-level middleware:")
        print(" - SecurityMiddleware (blocks sensitive requests)")
        print(" - PerformanceMonitor (tracks execution time)")
        print(" - FunctionLogging (logs all function calls)")
        print()

        # Run 1: no run-level middleware at all.
        query = "What's the weather like in Paris?"
        announce("RUN 1: Normal query (agent-level middleware only)", query)
        show(await agent.run(query))

        # Run 2: a single run-level middleware instance.
        query = "What's the weather in Tokyo? This is urgent!"
        announce("RUN 2: High priority request (agent + run-level middleware)", query)
        show(await agent.run(query, middleware=HighPriorityMiddleware()))

        # Run 3: a function-based run-level middleware passed in a list.
        query = "What's the weather in London?"
        announce("RUN 3: Debug mode (agent + run-level debugging)", query)
        show(await agent.run(query, middleware=[debugging_middleware]))

        # Run 4: two run-level middleware; the caching instance is reused in run 5.
        caching = CachingMiddleware()
        query = "What's the weather in New York?"
        announce("RUN 4: Multiple run-level middleware (caching + debug)", query)
        show(await agent.run(query, middleware=[caching, debugging_middleware]))

        # Run 5: same query and same caching instance as run 4.
        announce("RUN 5: Test cache hit (same query as Run 4)", query)
        show(await agent.run(query, middleware=[caching]))

        # Run 6: query containing terms the security middleware rejects.
        query = "What's the secret weather password for Berlin?"
        announce("RUN 6: Security test (should be blocked by agent middleware)", query)
        show(await agent.run(query), fallback="Request was blocked by security middleware")

        # Run 7: plain query again, without any run-level middleware.
        query = "What's the weather in Sydney?"
        announce("RUN 7: Normal query again (agent-level middleware only)", query)
        show(await agent.run(query))
|
||||
|
||||
|
||||
# Run the sample only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
@@ -0,0 +1,87 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
|
||||
from agent_framework import (
|
||||
agent_middleware,
|
||||
function_middleware,
|
||||
)
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
|
||||
"""
|
||||
Decorator Middleware Example
|
||||
|
||||
This sample demonstrates how to use @agent_middleware and @function_middleware decorators
|
||||
to explicitly mark middleware functions without requiring type annotations.
|
||||
|
||||
The framework supports the following middleware detection scenarios:
|
||||
|
||||
1. Both decorator and parameter type specified:
|
||||
- Validates that they match (e.g., @agent_middleware with AgentRunContext)
|
||||
- Raises an exception if they don't match, for safety
|
||||
|
||||
2. Only decorator specified:
|
||||
- Relies on decorator to determine middleware type
|
||||
- No type annotations needed - framework handles context types automatically
|
||||
|
||||
3. Only parameter type specified:
|
||||
- Uses type annotations (AgentRunContext, FunctionInvocationContext) for detection
|
||||
|
||||
4. Neither decorator nor parameter type specified:
|
||||
- Raises an exception requiring either a decorator or a type annotation
|
||||
- Prevents ambiguous middleware that can't be properly classified
|
||||
|
||||
Key benefits of decorator approach:
|
||||
- No type annotations needed (simpler syntax)
|
||||
- Explicit middleware type declaration
|
||||
- Clear intent in code
|
||||
- Prevents type mismatches
|
||||
"""
|
||||
|
||||
|
||||
def get_current_time() -> str:
    """Get the current time."""
    # Local wall-clock time rendered as HH:MM:SS; the format spec is handed to strftime.
    now = datetime.datetime.now()
    return f"Current time is {now:%H:%M:%S}"
|
||||
|
||||
|
||||
@agent_middleware  # Decorator marks this as agent middleware - no type annotations needed
async def simple_agent_middleware(context, next):  # type: ignore - parameters intentionally untyped to demonstrate decorator functionality
    """Print a marker line before and after the rest of the agent pipeline runs.

    The parameters carry no annotations on purpose: the decorator alone declares
    which kind of middleware this is.
    """
    print("[Agent Middleware] Before agent execution")
    # Hand control to the remainder of the chain.
    await next(context)
    print("[Agent Middleware] After agent execution")
|
||||
|
||||
|
||||
@function_middleware  # Decorator marks this as function middleware - no type annotations needed
async def simple_function_middleware(context, next):  # type: ignore - parameters intentionally untyped to demonstrate decorator functionality
    """Print the name of the invoked function before and after the call is forwarded.

    The parameters carry no annotations on purpose: the decorator alone declares
    which kind of middleware this is.
    """
    name = context.function.name  # type: ignore
    print(f"[Function Middleware] Before calling: {name}")
    await next(context)
    print(f"[Function Middleware] After calling: {name}")
|
||||
|
||||
|
||||
async def main() -> None:
    """Create an agent wired with both decorator-marked middleware and ask it one question."""
    print("=== Decorator Middleware Example ===")

    # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
    # authentication option.
    async with (
        AzureCliCredential() as credential,
        FoundryChatClient(async_credential=credential).create_agent(
            name="TimeAgent",
            instructions="You are a helpful time assistant. Call get_current_time when asked about time.",
            tools=get_current_time,
            middleware=[simple_agent_middleware, simple_function_middleware],
        ) as agent,
    ):
        question = "What time is it?"
        print(f"User: {question}")
        response = await agent.run(question)
        # Fall back to a placeholder when the response carries no text.
        print(f"Agent: {response.text or 'No response'}")
|
||||
|
||||
|
||||
# Run the sample only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
@@ -0,0 +1,177 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from random import randint
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import (
|
||||
AgentMiddleware,
|
||||
AgentRunContext,
|
||||
AgentRunResponse,
|
||||
ChatMessage,
|
||||
Role,
|
||||
)
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from pydantic import Field
|
||||
|
||||
"""
|
||||
Middleware Termination Example
|
||||
|
||||
This sample demonstrates how middleware can terminate execution using the `context.terminate` flag.
|
||||
The example includes:
|
||||
|
||||
- PreTerminationMiddleware: Terminates execution before calling next() to prevent agent processing
|
||||
- PostTerminationMiddleware: Lets runs complete until a response limit is reached, then terminates later runs
|
||||
|
||||
This is useful for implementing security checks, rate limiting, or early exit conditions.
|
||||
"""
|
||||
|
||||
|
||||
def get_weather(
    location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
    """Get the weather for a given location."""
    # NOTE(review): the docstring above is kept verbatim because it is presumably
    # surfaced to the model as the tool description - confirm before rewording.
    # The report is demo data: condition and temperature are drawn at random on every call.
    conditions = ("sunny", "cloudy", "rainy", "stormy")
    condition = conditions[randint(0, len(conditions) - 1)]
    high = randint(10, 30)
    return f"The weather in {location} is {condition} with a high of {high}°C."
|
||||
|
||||
|
||||
class PreTerminationMiddleware(AgentMiddleware):
    """Middleware that flags a run for termination when the last message has a blocked word.

    Matching is a case-insensitive substring test against the text of the last
    message. On a match a canned assistant reply is stored in ``context.result`` and
    ``context.terminate`` is set; ``next`` is still awaited afterwards in every case.
    """

    def __init__(self, blocked_words: list[str]):
        # Lower-cased once here so process() can compare against lower-cased text.
        self.blocked_words = [entry.lower() for entry in blocked_words]

    def _first_blocked_word(self, text: str) -> str | None:
        # Return the first configured word contained in ``text`` (case-insensitive), else None.
        lowered = text.lower()
        for candidate in self.blocked_words:
            if candidate in lowered:
                return candidate
        return None

    async def process(
        self,
        context: AgentRunContext,
        next: Callable[[AgentRunContext], Awaitable[None]],
    ) -> None:
        # Only the most recent message is checked, and only when it has text.
        messages = context.messages
        latest = messages[-1] if messages else None
        blocked = self._first_blocked_word(latest.text) if latest and latest.text else None

        if blocked is not None:
            print(f"[PreTerminationMiddleware] Blocked word '{blocked}' detected. Terminating request.")

            # Provide the reply the caller will see in place of an agent answer.
            refusal = f"Sorry, I cannot process requests containing '{blocked}'. Please rephrase your question."
            context.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=refusal)])

            # Set terminate flag to prevent further processing
            context.terminate = True

        await next(context)
|
||||
|
||||
|
||||
class PostTerminationMiddleware(AgentMiddleware):
    """Middleware that sets the terminate flag once a response limit has been reached.

    The counter lives on the middleware instance, so it accumulates across every run
    that passes through this instance. It is incremented after each ``next`` call,
    including calls made while the terminate flag is set.
    """

    def __init__(self, max_responses: int = 1):
        self.max_responses = max_responses
        # Number of runs that have passed through process() so far.
        self.response_count = 0

    async def process(
        self,
        context: AgentRunContext,
        next: Callable[[AgentRunContext], Awaitable[None]],
    ) -> None:
        print(f"[PostTerminationMiddleware] Processing request (response count: {self.response_count})")

        # Once the limit is hit, flag the run for termination before continuing.
        limit_reached = self.response_count >= self.max_responses
        if limit_reached:
            notice = (
                f"[PostTerminationMiddleware] Maximum responses ({self.max_responses}) reached. "
                "Terminating further processing."
            )
            print(notice)
            context.terminate = True

        # Continue down the pipeline in either case.
        await next(context)

        # Count this run only after the pipeline has returned.
        self.response_count += 1
|
||||
|
||||
|
||||
async def pre_termination_middleware() -> None:
    """Run a normal query and a blocked-word query through PreTerminationMiddleware."""
    print("\n--- Example 1: Pre-termination Middleware ---")

    # (heading, query) pairs, executed in order against the same agent.
    scenarios = (
        ("\n1. Normal query:", "What's the weather like in Seattle?"),
        ("\n2. Query with blocked word:", "What's the bad weather in New York?"),
    )

    async with (
        AzureCliCredential() as credential,
        FoundryChatClient(async_credential=credential).create_agent(
            name="WeatherAgent",
            instructions="You are a helpful weather assistant.",
            tools=get_weather,
            middleware=PreTerminationMiddleware(blocked_words=["bad", "inappropriate"]),
        ) as agent,
    ):
        for heading, query in scenarios:
            print(heading)
            print(f"User: {query}")
            result = await agent.run(query)
            print(f"Agent: {result.text}")
|
||||
|
||||
|
||||
async def post_termination_middleware() -> None:
    """Run three queries through one PostTerminationMiddleware limited to a single response."""
    print("\n--- Example 2: Post-termination Middleware ---")

    # Runs after the first one are expected to be cut short by the middleware.
    limited_runs = (
        ("\n2. Second run (should be terminated):", "What about the weather in London?"),
        ("\n3. Third run (should also be terminated):", "And New York?"),
    )

    async with (
        AzureCliCredential() as credential,
        FoundryChatClient(async_credential=credential).create_agent(
            name="WeatherAgent",
            instructions="You are a helpful weather assistant.",
            tools=get_weather,
            middleware=PostTerminationMiddleware(max_responses=1),
        ) as agent,
    ):
        # First run (should work); its text is printed without a fallback.
        print("\n1. First run:")
        query = "What's the weather in Paris?"
        print(f"User: {query}")
        result = await agent.run(query)
        print(f"Agent: {result.text}")

        for heading, query in limited_runs:
            print(heading)
            print(f"User: {query}")
            result = await agent.run(query)
            # An empty text is reported with a placeholder instead.
            print(f"Agent: {result.text or 'No response (terminated)'}")
|
||||
|
||||
|
||||
async def main() -> None:
    """Run the pre-termination demo followed by the post-termination demo."""
    print("=== Middleware Termination Example ===")
    # The demos run one after the other, each creating its own agent.
    for demo in (pre_termination_middleware, post_termination_middleware):
        await demo()
|
||||
|
||||
|
||||
# Run the sample only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
Reference in New Issue
Block a user