mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
59da578902
* [Py Purview] Purview Python Initial Commit * [Py Purview] Purview Python Minor Fixes * [Py Purview] Purview Python Comment Fixesish * [Py Purview] Purview Python Agent Middleware Done * [Py Purview] Purview Python Agent Middleware Done * [Py Purview] Purview Python Lint Errors * [Py Purview] Purview Python Final Hopefully * [Py Purview] Purview Python Final Hopefully * [Py Purview] Purview Python Fix ReadMe * [Py Purview] Purview Python Fix MyPy * [Py Purview] Purview Python Minor Updates on comments * [Py Purview] Purview Python Fix Build Error --------- Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com>
150 lines
6.0 KiB
Python
150 lines
6.0 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
"""Tests for Purview chat middleware."""
|
|
|
|
from dataclasses import dataclass
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from agent_framework import ChatContext, ChatMessage, Role
|
|
from azure.core.credentials import AccessToken
|
|
|
|
from agent_framework_purview import PurviewChatPolicyMiddleware, PurviewSettings
|
|
|
|
|
|
@dataclass()
class DummyChatClient:
    """Minimal stand-in passed as ``chat_client`` when building a ``ChatContext`` in these tests."""

    # Label for the fake client; the tests never override the default.
    name: str = "dummy"
|
|
|
|
|
|
class TestPurviewChatPolicyMiddleware:
    """Behavioral tests for ``PurviewChatPolicyMiddleware``.

    Most tests patch ``middleware._processor.process_messages`` so that the
    block/allow verdict is controlled directly by the test. The patched
    callable yields a ``(should_block, user_id)`` tuple.
    """

    @pytest.fixture
    def mock_credential(self) -> AsyncMock:
        """Return a credential mock whose ``get_token`` yields a fixed fake ``AccessToken``."""
        credential = AsyncMock()
        credential.get_token = AsyncMock(return_value=AccessToken("fake-token", 9999999999))
        return credential

    @pytest.fixture
    def settings(self) -> PurviewSettings:
        """Return ``PurviewSettings`` populated with test app name and tenant id."""
        return PurviewSettings(app_name="Test App", tenant_id="test-tenant")

    @pytest.fixture
    def middleware(self, mock_credential: AsyncMock, settings: PurviewSettings) -> PurviewChatPolicyMiddleware:
        """Return the middleware under test, built from the mocked credential and test settings."""
        return PurviewChatPolicyMiddleware(mock_credential, settings)

    @pytest.fixture
    def chat_context(self) -> ChatContext:
        """Return a ``ChatContext`` holding a single user message ("Hello")."""
        chat_client = DummyChatClient()
        chat_options = MagicMock()
        chat_options.model = "test-model"
        return ChatContext(
            chat_client=chat_client, messages=[ChatMessage(role=Role.USER, text="Hello")], chat_options=chat_options
        )

    async def test_initialization(self, middleware: PurviewChatPolicyMiddleware) -> None:
        """The constructed middleware exposes a client and a processor."""
        assert middleware._client is not None
        assert middleware._processor is not None

    async def test_allows_clean_prompt(
        self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext
    ) -> None:
        """When nothing is blocked, ``next`` runs and the processor is consulted twice."""
        with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
            next_called = False

            async def mock_next(ctx: ChatContext) -> None:
                nonlocal next_called
                next_called = True

                class Result:
                    def __init__(self):
                        self.messages = [ChatMessage(role=Role.ASSISTANT, text="Hi there")]

                ctx.result = Result()

            await middleware.process(chat_context, mock_next)
            assert next_called
            # One call for the prompt and one for the response produced by mock_next.
            assert mock_proc.call_count == 2
            assert chat_context.result.messages[0].role == Role.ASSISTANT

    async def test_blocks_prompt(self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext) -> None:
        """A blocked prompt terminates the context and yields a system "blocked" message."""
        with patch.object(middleware._processor, "process_messages", return_value=(True, "user-123")):

            async def mock_next(ctx: ChatContext) -> None:  # should not run
                raise AssertionError("next should not be called when prompt blocked")

            await middleware.process(chat_context, mock_next)
            assert chat_context.terminate
            assert chat_context.result
            msg = chat_context.result[0]  # type: ignore[index]
            assert msg.role in ("system", Role.SYSTEM)
            assert "blocked" in msg.text.lower()

    async def test_blocks_response(self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext) -> None:
        """A blocked response is replaced by a system "blocked" message."""
        call_state = {"count": 0}

        async def side_effect(messages, activity, user_id=None):
            # Allow the first call (prompt) and block the second (response).
            call_state["count"] += 1
            should_block = call_state["count"] == 2
            return (should_block, "user-123")

        with patch.object(middleware._processor, "process_messages", side_effect=side_effect):

            async def mock_next(ctx: ChatContext) -> None:
                class Result:
                    def __init__(self):
                        self.messages = [ChatMessage(role=Role.ASSISTANT, text="Sensitive output")]

                ctx.result = Result()

            await middleware.process(chat_context, mock_next)
            assert call_state["count"] == 2
            # The replacement result may be a bare message list or an object with ``messages``.
            msgs = getattr(chat_context.result, "messages", None) or chat_context.result
            first_msg = msgs[0]
            assert first_msg.role in ("system", Role.SYSTEM)
            assert "blocked" in first_msg.text.lower()

    async def test_streaming_skips_post_check(self, middleware: PurviewChatPolicyMiddleware) -> None:
        """With ``is_streaming=True`` the processor is consulted once (prompt only)."""
        chat_client = DummyChatClient()
        chat_options = MagicMock()
        chat_options.model = "test-model"
        streaming_context = ChatContext(
            chat_client=chat_client,
            messages=[ChatMessage(role=Role.USER, text="Hello")],
            chat_options=chat_options,
            is_streaming=True,
        )
        # Use the same (should_block, user_id) tuple shape as the other tests so this
        # test does not depend on how the middleware handles a malformed processor result.
        with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
            next_called = False

            async def mock_next(ctx: ChatContext) -> None:
                nonlocal next_called
                next_called = True
                ctx.result = MagicMock()

            await middleware.process(streaming_context, mock_next)
            assert next_called
            assert mock_proc.call_count == 1

    async def test_chat_middleware_handles_post_check_exception(
        self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext
    ) -> None:
        """Test that exceptions in post-check are logged but don't affect result."""
        call_count = 0

        async def mock_process_messages(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return (False, "user-123")  # Pre-check succeeds
            raise Exception("Post-check error")  # Post-check fails

        with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):

            async def mock_next(ctx: ChatContext) -> None:
                # Create a mock result with messages attribute
                result = MagicMock()
                result.messages = [ChatMessage(role=Role.ASSISTANT, text="Response")]
                ctx.result = result

            await middleware.process(chat_context, mock_next)

            # Should have been called twice (pre and post)
            assert call_count == 2
            # Result should still be set
            assert chat_context.result is not None
|