Fix chat middleware: add streaming support, terminate flag, and check only last message (#2120)

This commit fixes three issues in the security_filter_middleware:

1. Missing context.terminate flag - Without this, middleware continues processing after setting blocked response
2. No streaming support - When context.is_streaming is True, middleware now returns async generator with ChatResponseUpdate
3. Checks all messages - Changed to check only context.messages[-1] (most recent user message) instead of iterating through conversation history

Changes:
- Added AsyncIterable import
- Added ChatResponseUpdate and TextContent imports
- Modified security_filter_middleware to handle both streaming and non-streaming modes
- Added context.terminate = True to properly stop execution
- Changed message checking logic to only inspect the last user message

Co-authored-by: Victor Dibia <chuvidi2003@gmail.com>
Co-authored-by: Evan Mattson <35585003+moonbox3@users.noreply.github.com>
This commit is contained in:
t-anjan
2025-11-13 09:50:44 +05:30
committed by GitHub
Unverified
parent a4e82f4e04
commit 665aacf1ad
@@ -3,7 +3,7 @@
import logging
import os
from collections.abc import Awaitable, Callable
from collections.abc import AsyncIterable, Awaitable, Callable
from typing import Annotated
from agent_framework import (
@@ -11,8 +11,10 @@ from agent_framework import (
ChatContext,
ChatMessage,
ChatResponse,
ChatResponseUpdate,
FunctionInvocationContext,
Role,
TextContent,
chat_middleware,
function_middleware,
ai_function
@@ -37,28 +39,42 @@ async def security_filter_middleware(
next: Callable[[ChatContext], Awaitable[None]],
) -> None:
"""Chat middleware that blocks requests containing sensitive information."""
# Block requests with sensitive information
blocked_terms = ["password", "secret", "api_key", "token"]
for message in context.messages:
if message.text:
message_lower = message.text.lower()
for term in blocked_terms:
if term in message_lower:
# Override the response without calling the LLM
# Check only the last message (most recent user input)
last_message = context.messages[-1] if context.messages else None
if last_message and last_message.role == Role.USER and last_message.text:
message_lower = last_message.text.lower()
for term in blocked_terms:
if term in message_lower:
error_message = (
"I cannot process requests containing sensitive information. "
"Please rephrase your question without including passwords, secrets, "
"or other sensitive data."
)
if context.is_streaming:
# Streaming mode: return async generator
async def blocked_stream() -> AsyncIterable[ChatResponseUpdate]:
yield ChatResponseUpdate(
contents=[TextContent(text=error_message)],
role=Role.ASSISTANT,
)
context.result = blocked_stream()
else:
# Non-streaming mode: return complete response
context.result = ChatResponse(
messages=[
ChatMessage(
role=Role.ASSISTANT,
text=(
"I cannot process requests containing sensitive information. "
"Please rephrase your question without including passwords, secrets, "
"or other sensitive data."
),
text=error_message,
)
]
)
return
context.terminate = True
return
await next(context)