Python: fixed middleware samples (#5026)

* fixed samples

* small update to explanation

* add snippet fix on root readme
This commit is contained in:
Eduard van Valkenburg
2026-04-01 15:40:27 +02:00
committed by GitHub
Unverified
parent 4b9856e66f
commit cee0a458fe
4 changed files with 85 additions and 45 deletions
@@ -13,8 +13,8 @@ This folder contains focused middleware samples for `Agent`, chat clients, tools
| [`exception_handling_with_middleware.py`](./exception_handling_with_middleware.py) | Shows how middleware can handle failures and recover cleanly. |
| [`function_based_middleware.py`](./function_based_middleware.py) | Shows function-based agent and function middleware. |
| [`middleware_termination.py`](./middleware_termination.py) | Demonstrates stopping a middleware pipeline early. |
| [`override_result_with_middleware.py`](./override_result_with_middleware.py) | Shows how middleware can replace the normal result. |
| [`runtime_context_delegation.py`](./runtime_context_delegation.py) | Demonstrates delegating work with runtime context data. |
| [`override_result_with_middleware.py`](./override_result_with_middleware.py) | Shows how middleware can replace regular and streaming results, then post-process the final response. |
| [`runtime_context_delegation.py`](./runtime_context_delegation.py) | Demonstrates delegating arguments with runtime context data. |
| [`session_behavior_middleware.py`](./session_behavior_middleware.py) | Shows how middleware interacts with session-backed runs. |
| [`shared_state_middleware.py`](./shared_state_middleware.py) | Demonstrates sharing mutable state across middleware invocations. |
| [`usage_tracking_middleware.py`](./usage_tracking_middleware.py) | Demonstrates one chat middleware function that tracks per-call usage in non-streaming and streaming tool-loop runs. |
@@ -81,7 +81,7 @@ async def weather_override_middleware(context: ChatContext, call_next: Callable[
role="assistant",
)
context.result = ResponseStream(_override_stream())
context.result = ResponseStream(_override_stream(), finalizer=ChatResponse.from_updates)
else:
# For non-streaming: just replace with a new message
current_text = context.result.text if isinstance(context.result, ChatResponse) else ""
@@ -99,12 +99,17 @@ async def validate_weather_middleware(context: ChatContext, call_next: Callable[
return
if context.stream and isinstance(context.result, ResponseStream):
result_stream = context.result
def _append_validation_note(response: ChatResponse) -> ChatResponse:
response.messages.append(Message(role="assistant", text=validation_note))
return response
async def _validated_stream() -> AsyncIterable[ChatResponseUpdate]:
async for update in result_stream:
yield update
yield ChatResponseUpdate(
contents=[Content.from_text(text=validation_note)],
role="assistant",
)
context.result.with_finalizer(_append_validation_note)
context.result = ResponseStream(_validated_stream(), finalizer=ChatResponse.from_updates)
elif isinstance(context.result, ChatResponse):
context.result.messages.append(Message(role="assistant", text=validation_note))
@@ -118,11 +123,11 @@ async def agent_cleanup_middleware(context: AgentContext, call_next: Callable[[]
validation_note = "Validation: weather data verified."
state = {"found_prefix": False}
state = {"found_prefix": False, "found_validation": False}
def _sanitize(response: AgentResponse) -> AgentResponse:
found_prefix = state["found_prefix"]
found_validation = False
found_validation = state["found_validation"]
cleaned_messages: list[Message] = []
for message in response.messages:
@@ -141,12 +146,14 @@ async def agent_cleanup_middleware(context: AgentContext, call_next: Callable[[]
found_prefix = True
text = text.replace("Weather Advisory:", "")
text = re.sub(r"\[\d+\]\s*", "", text)
text = re.sub(r"\[\d+\]\s*", "", text).strip()
if not text:
continue
cleaned_messages.append(
Message(
role=message.role,
text=text.strip(),
text=text,
author_name=message.author_name,
message_id=message.message_id,
additional_properties=message.additional_properties,
@@ -166,19 +173,30 @@ async def agent_cleanup_middleware(context: AgentContext, call_next: Callable[[]
if context.stream and isinstance(context.result, ResponseStream):
def _clean_update(update: AgentResponseUpdate) -> AgentResponseUpdate:
cleaned_contents: list[Content] = []
for content in update.contents or []:
if not content.text:
cleaned_contents.append(content)
continue
text = content.text
if "Weather Advisory:" in text:
state["found_prefix"] = True
text = text.replace("Weather Advisory:", "")
if validation_note in text:
state["found_validation"] = True
text = text.replace(validation_note, "").strip()
if not text:
continue
text = re.sub(r"\[\d+\]\s*", "", text)
content.text = text
cleaned_contents.append(content)
update.contents = cleaned_contents
return update
context.result.with_transform_hook(_clean_update)
context.result.with_finalizer(_sanitize)
context.result.with_result_hook(_sanitize)
elif isinstance(context.result, AgentResponse):
context.result = _sanitize(context.result)
@@ -6,6 +6,7 @@ from typing import Annotated
from agent_framework import Agent, FunctionInvocationContext, function_middleware, tool
from agent_framework.foundry import FoundryChatClient
from azure.identity import AzureCliCredential
from dotenv import load_dotenv
from pydantic import Field
@@ -43,6 +44,13 @@ Key Concepts:
- MiddlewareTypes: Intercepts function calls to access/modify kwargs
- Closure: Functions capturing variables from outer scope
- kwargs Propagation: Automatic forwarding of runtime context through delegation chains
Environment Setup:
- Configure Azure credentials (e.g., via Azure CLI)
- Run `az login` to authenticate
- Set FOUNDRY_PROJECT_ENDPOINT to your Azure AI Foundry project endpoint
- Set FOUNDRY_MODEL to the model deployment name (for example: gpt-4o)
"""
@@ -85,7 +93,7 @@ class SessionContextContainer:
runtime_context = SessionContextContainer()
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production.
@tool(approval_mode="never_require")
async def send_email(
to: Annotated[str, Field(description="Recipient email address")],
@@ -149,7 +157,7 @@ async def pattern_1_single_agent_with_closure() -> None:
print("Use case: Single agent with multiple tools sharing runtime context")
print()
client = FoundryChatClient(model="gpt-4o-mini")
client = FoundryChatClient(credential=AzureCliCredential())
# Create agent with both tools and shared context via middleware
communication_agent = Agent(
@@ -177,9 +185,11 @@ async def pattern_1_single_agent_with_closure() -> None:
result1 = await communication_agent.run(
user_query,
# Runtime context passed as kwargs
api_token="sk-test-token-xyz-789",
user_id="user-12345",
session_metadata={"tenant": "acme-corp", "region": "us-west"},
function_invocation_kwargs={
"api_token": "sk-test-token-xyz-789",
"user_id": "user-12345",
"session_metadata": {"tenant": "acme-corp", "region": "us-west"},
},
)
print(f"\nAgent: {result1.text}")
@@ -195,9 +205,11 @@ async def pattern_1_single_agent_with_closure() -> None:
result2 = await communication_agent.run(
user_query2,
# Different runtime context for this request
api_token="sk-prod-token-abc-456",
user_id="user-67890",
session_metadata={"tenant": "store-inc", "region": "eu-central"},
function_invocation_kwargs={
"api_token": "sk-prod-token-abc-456",
"user_id": "user-67890",
"session_metadata": {"tenant": "store-inc", "region": "eu-central"},
},
)
print(f"\nAgent: {result2.text}")
@@ -215,9 +227,11 @@ async def pattern_1_single_agent_with_closure() -> None:
result3 = await communication_agent.run(
user_query3,
api_token="sk-dev-token-def-123",
user_id="user-11111",
session_metadata={"tenant": "dev-team", "region": "us-east"},
function_invocation_kwargs={
"api_token": "sk-dev-token-def-123",
"user_id": "user-11111",
"session_metadata": {"tenant": "dev-team", "region": "us-east"},
},
)
print(f"\nAgent: {result3.text}")
@@ -234,7 +248,9 @@ async def pattern_1_single_agent_with_closure() -> None:
result4 = await communication_agent.run(
user_query4,
# Missing api_token - tools should handle gracefully
user_id="user-22222",
function_invocation_kwargs={
"user_id": "user-22222",
},
)
print(f"\nAgent: {result4.text}")
@@ -295,7 +311,7 @@ async def pattern_2_hierarchical_with_kwargs_propagation() -> None:
print(f"[SMSAgent] Received runtime context: {list(context.kwargs.keys())}")
await call_next()
client = FoundryChatClient(model="gpt-4o-mini")
client = FoundryChatClient(credential=AzureCliCredential())
# Create specialized sub-agents
email_agent = Agent(
@@ -341,9 +357,11 @@ async def pattern_2_hierarchical_with_kwargs_propagation() -> None:
print("Test: Send email with runtime context\n")
await coordinator.run(
"Send an email to john@example.com with subject 'Meeting' and body 'See you at 2pm'",
api_token="secret-token-abc",
user_id="user-999",
tenant_id="tenant-acme",
function_invocation_kwargs={
"api_token": "secret-token-abc",
"user_id": "user-999",
"tenant_id": "tenant-acme",
},
)
print(f"\n[Verification] EmailAgent received kwargs keys: {list(email_agent_kwargs.keys())}")
@@ -400,7 +418,7 @@ async def pattern_3_hierarchical_with_middleware() -> None:
auth_middleware = AuthContextMiddleware()
client = FoundryChatClient(model="gpt-4o-mini")
client = FoundryChatClient(credential=AzureCliCredential())
# Sub-agent with validation middleware
protected_agent = Agent(
@@ -428,16 +446,20 @@ async def pattern_3_hierarchical_with_middleware() -> None:
print("Test 1: Valid token\n")
await coordinator.run(
"Execute operation: backup_database",
api_token="valid-token-xyz-789",
user_id="admin-123",
function_invocation_kwargs={
"api_token": "valid-token-xyz-789",
"user_id": "admin-123",
},
)
# Test with invalid token
print("\nTest 2: Invalid token\n")
await coordinator.run(
"Execute operation: delete_records",
api_token="invalid-token-bad",
user_id="user-456",
function_invocation_kwargs={
"api_token": "invalid-token-bad",
"user_id": "user-456",
},
)
print(f"\n[Validation Summary] Validated tokens: {len(auth_middleware.validated_tokens)}")