mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Fix chat middleware: add streaming support, terminate flag, and check only last message (#2120)
This commit fixes three issues in the security_filter_middleware: 1. Missing context.terminate flag - Without this, middleware continues processing after setting blocked response 2. No streaming support - When context.is_streaming is True, middleware now returns async generator with ChatResponseUpdate 3. Checks all messages - Changed to check only context.messages[-1] (most recent user message) instead of iterating through conversation history Changes: - Added AsyncIterable import - Added ChatResponseUpdate and TextContent imports - Modified security_filter_middleware to handle both streaming and non-streaming modes - Added context.terminate = True to properly stop execution - Changed message checking logic to only inspect the last user message Co-authored-by: Victor Dibia <chuvidi2003@gmail.com> Co-authored-by: Evan Mattson <35585003+moonbox3@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
a4e82f4e04
commit
665aacf1ad
@@ -3,7 +3,7 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import (
|
||||
@@ -11,8 +11,10 @@ from agent_framework import (
|
||||
ChatContext,
|
||||
ChatMessage,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
FunctionInvocationContext,
|
||||
Role,
|
||||
TextContent,
|
||||
chat_middleware,
|
||||
function_middleware,
|
||||
ai_function
|
||||
@@ -37,28 +39,42 @@ async def security_filter_middleware(
|
||||
next: Callable[[ChatContext], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Chat middleware that blocks requests containing sensitive information."""
|
||||
# Block requests with sensitive information
|
||||
blocked_terms = ["password", "secret", "api_key", "token"]
|
||||
|
||||
for message in context.messages:
|
||||
if message.text:
|
||||
message_lower = message.text.lower()
|
||||
for term in blocked_terms:
|
||||
if term in message_lower:
|
||||
# Override the response without calling the LLM
|
||||
# Check only the last message (most recent user input)
|
||||
last_message = context.messages[-1] if context.messages else None
|
||||
if last_message and last_message.role == Role.USER and last_message.text:
|
||||
message_lower = last_message.text.lower()
|
||||
for term in blocked_terms:
|
||||
if term in message_lower:
|
||||
error_message = (
|
||||
"I cannot process requests containing sensitive information. "
|
||||
"Please rephrase your question without including passwords, secrets, "
|
||||
"or other sensitive data."
|
||||
)
|
||||
|
||||
if context.is_streaming:
|
||||
# Streaming mode: return async generator
|
||||
async def blocked_stream() -> AsyncIterable[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(
|
||||
contents=[TextContent(text=error_message)],
|
||||
role=Role.ASSISTANT,
|
||||
)
|
||||
|
||||
context.result = blocked_stream()
|
||||
else:
|
||||
# Non-streaming mode: return complete response
|
||||
context.result = ChatResponse(
|
||||
messages=[
|
||||
ChatMessage(
|
||||
role=Role.ASSISTANT,
|
||||
text=(
|
||||
"I cannot process requests containing sensitive information. "
|
||||
"Please rephrase your question without including passwords, secrets, "
|
||||
"or other sensitive data."
|
||||
),
|
||||
text=error_message,
|
||||
)
|
||||
]
|
||||
)
|
||||
return
|
||||
|
||||
context.terminate = True
|
||||
return
|
||||
|
||||
await next(context)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user