Files
agent-framework/python/packages/purview/tests/test_chat_middleware.py
T
Rishabh Chawla 59da578902 Python: Add Purview Middleware (#1142)
* [Py Purview] Purview Python Initial Commit

* [Py Purview] Purview Python Minor Fixes

* [Py Purview] Purview Python Comment Fixesish

* [Py Purview] Purview Python Agent Middleware Done

* [Py Purview] Purview Python Agent Middleware Done

* [Py Purview] Purview Python Lint Errors

* [Py Purview] Purview Python Final Hopefully

* [Py Purview] Purview Python Final Hopefully

* [Py Purview] Purview Python Fix ReadMe

* [Py Purview] Purview Python Fix MyPy

* [Py Purview] Purview Python Minor Updates on comments

* [Py Purview] Purview Python Fix Build Error

---------

Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com>
2025-10-16 21:46:04 +00:00

150 lines
6.0 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
"""Tests for Purview chat middleware."""
from dataclasses import dataclass
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agent_framework import ChatContext, ChatMessage, Role
from azure.core.credentials import AccessToken
from agent_framework_purview import PurviewChatPolicyMiddleware, PurviewSettings
@dataclass
class DummyChatClient:
    """Minimal stand-in object used where a chat client instance is required.

    It carries only a ``name`` field and implements no chat behaviour.
    """

    # Label for the dummy client; defaults to "dummy".
    name: str = "dummy"
class TestPurviewChatPolicyMiddleware:
    """Tests for ``PurviewChatPolicyMiddleware.process``.

    Every test that drives ``process`` patches ``middleware._processor.process_messages``.
    The mocks return a ``(should_block, user_id)`` tuple, which is the shape these
    tests assume the real processor produces.
    """

    @pytest.fixture
    def mock_credential(self) -> AsyncMock:
        """Return an async credential mock whose ``get_token`` yields a fixed fake token."""
        credential = AsyncMock()
        credential.get_token = AsyncMock(return_value=AccessToken("fake-token", 9999999999))
        return credential

    @pytest.fixture
    def settings(self) -> PurviewSettings:
        """Return Purview settings with a test app name and tenant id."""
        return PurviewSettings(app_name="Test App", tenant_id="test-tenant")

    @pytest.fixture
    def middleware(self, mock_credential: AsyncMock, settings: PurviewSettings) -> PurviewChatPolicyMiddleware:
        """Return the middleware under test, built from the mock credential and settings."""
        return PurviewChatPolicyMiddleware(mock_credential, settings)

    @pytest.fixture
    def chat_context(self) -> ChatContext:
        """Return a chat context holding a single user message ("Hello")."""
        chat_client = DummyChatClient()
        chat_options = MagicMock()
        chat_options.model = "test-model"
        return ChatContext(
            chat_client=chat_client, messages=[ChatMessage(role=Role.USER, text="Hello")], chat_options=chat_options
        )

    async def test_initialization(self, middleware: PurviewChatPolicyMiddleware) -> None:
        """The constructor populates both the ``_client`` and ``_processor`` attributes."""
        assert middleware._client is not None
        assert middleware._processor is not None

    async def test_allows_clean_prompt(
        self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext
    ) -> None:
        """When nothing is blocked, ``next`` runs and the assistant result is kept."""
        with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
            next_called = False

            async def mock_next(ctx: ChatContext) -> None:
                nonlocal next_called
                next_called = True

                class Result:
                    def __init__(self):
                        self.messages = [ChatMessage(role=Role.ASSISTANT, text="Hi there")]

                ctx.result = Result()

            await middleware.process(chat_context, mock_next)

            assert next_called
            # One evaluation for the prompt and one for the response.
            assert mock_proc.call_count == 2
            assert chat_context.result.messages[0].role == Role.ASSISTANT

    async def test_blocks_prompt(self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext) -> None:
        """A blocked prompt terminates the pipeline and yields a system "blocked" message."""
        with patch.object(middleware._processor, "process_messages", return_value=(True, "user-123")):

            async def mock_next(ctx: ChatContext) -> None:  # should not run
                raise AssertionError("next should not be called when prompt blocked")

            await middleware.process(chat_context, mock_next)

            assert chat_context.terminate
            assert chat_context.result
            msg = chat_context.result[0]  # type: ignore[index]
            assert msg.role in ("system", Role.SYSTEM)
            assert "blocked" in msg.text.lower()

    async def test_blocks_response(self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext) -> None:
        """A response blocked on the second evaluation is replaced by a system "blocked" message."""
        call_state = {"count": 0}

        async def side_effect(messages, activity, user_id=None):
            # Allow the first call (prompt) and block the second (response).
            call_state["count"] += 1
            should_block = call_state["count"] == 2
            return (should_block, "user-123")

        with patch.object(middleware._processor, "process_messages", side_effect=side_effect):

            async def mock_next(ctx: ChatContext) -> None:
                class Result:
                    def __init__(self):
                        self.messages = [ChatMessage(role=Role.ASSISTANT, text="Sensitive output")]  # pragma: no cover

                ctx.result = Result()

            await middleware.process(chat_context, mock_next)

            assert call_state["count"] == 2
            # The replacement result may be a bare list or an object exposing ``messages``.
            msgs = getattr(chat_context.result, "messages", None) or chat_context.result
            first_msg = msgs[0]
            assert first_msg.role in ("system", Role.SYSTEM)
            assert "blocked" in first_msg.text.lower()

    async def test_streaming_skips_post_check(self, middleware: PurviewChatPolicyMiddleware) -> None:
        """With ``is_streaming=True`` the processor is consulted exactly once."""
        chat_client = DummyChatClient()
        chat_options = MagicMock()
        chat_options.model = "test-model"
        streaming_context = ChatContext(
            chat_client=chat_client,
            messages=[ChatMessage(role=Role.USER, text="Hello")],
            chat_options=chat_options,
            is_streaming=True,
        )
        # Return the same (should_block, user_id) tuple as the other mocks in this
        # class; a bare ``False`` would not match that shape.
        with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:

            async def mock_next(ctx: ChatContext) -> None:
                ctx.result = MagicMock()

            await middleware.process(streaming_context, mock_next)

            assert mock_proc.call_count == 1

    async def test_chat_middleware_handles_post_check_exception(
        self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext
    ) -> None:
        """Test that an exception raised by the post-check does not discard the result."""
        call_count = 0

        async def mock_process_messages(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return (False, "user-123")  # Pre-check succeeds
            raise Exception("Post-check error")  # Post-check fails

        with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):

            async def mock_next(ctx: ChatContext) -> None:
                # Create a mock result with messages attribute
                result = MagicMock()
                result.messages = [ChatMessage(role=Role.ASSISTANT, text="Response")]
                ctx.result = result

            await middleware.process(chat_context, mock_next)

            # Should have been called twice (pre and post)
            assert call_count == 2
            # Result should still be set
            assert chat_context.result is not None