Files
agent-framework/python/packages/purview/tests/test_middleware.py
T
Rishabh Chawla 64826b8f56 Python: [Purview] Add Caching and background processing in Python Purview Middleware (#1844)
* [PythonPurview] Add Caching and background processing

* [PythonPurview] Updates based on comments
2025-11-07 07:43:22 +00:00

257 lines
11 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
"""Tests for Purview middleware."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agent_framework import AgentRunContext, AgentRunResponse, ChatMessage, Role
from azure.core.credentials import AccessToken
from agent_framework_purview import PurviewPolicyMiddleware, PurviewSettings
class TestPurviewPolicyMiddleware:
    """Test PurviewPolicyMiddleware functionality.

    Most tests replace ``middleware._processor.process_messages`` with a mock so that the
    middleware's flow (policy pre-check on the prompt, the ``next`` call, and the policy
    post-check on the agent response) can be exercised without any real Purview calls.
    """

    @pytest.fixture
    def mock_credential(self) -> AsyncMock:
        """Create a mock async credential whose ``get_token`` returns a fake token."""
        credential = AsyncMock()
        credential.get_token = AsyncMock(return_value=AccessToken("fake-token", 9999999999))
        return credential

    @pytest.fixture
    def settings(self) -> PurviewSettings:
        """Create test settings."""
        return PurviewSettings(app_name="Test App", tenant_id="test-tenant")

    @pytest.fixture
    def middleware(self, mock_credential: AsyncMock, settings: PurviewSettings) -> PurviewPolicyMiddleware:
        """Create a PurviewPolicyMiddleware built from the mock credential and test settings."""
        return PurviewPolicyMiddleware(mock_credential, settings)

    @pytest.fixture
    def mock_agent(self) -> MagicMock:
        """Create a mock agent named ``test-agent``."""
        agent = MagicMock()
        agent.name = "test-agent"
        return agent

    def test_middleware_initialization(self, mock_credential: AsyncMock, settings: PurviewSettings) -> None:
        """Test PurviewPolicyMiddleware initialization."""
        middleware = PurviewPolicyMiddleware(mock_credential, settings)

        assert middleware._client is not None
        assert middleware._processor is not None

    async def test_middleware_allows_clean_prompt(
        self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
    ) -> None:
        """Test middleware allows prompt that passes policy check."""
        context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Hello, how are you?")])

        with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")):
            next_called = False

            async def mock_next(ctx: AgentRunContext) -> None:
                nonlocal next_called
                next_called = True
                ctx.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="I'm good, thanks!")])

            await middleware.process(context, mock_next)

            assert next_called
            assert context.result is not None
            assert not context.terminate

    async def test_middleware_blocks_prompt_on_policy_violation(
        self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
    ) -> None:
        """Test middleware blocks prompt that violates policy."""
        context = AgentRunContext(
            agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Sensitive information")]
        )

        with patch.object(middleware._processor, "process_messages", return_value=(True, "user-123")):
            next_called = False

            async def mock_next(ctx: AgentRunContext) -> None:
                nonlocal next_called
                next_called = True

            await middleware.process(context, mock_next)

            # A blocked prompt must short-circuit: the agent never runs and a
            # single system message replaces the result.
            assert not next_called
            assert context.result is not None
            assert context.terminate
            assert len(context.result.messages) == 1
            assert context.result.messages[0].role == Role.SYSTEM
            assert "blocked by policy" in context.result.messages[0].text.lower()

    async def test_middleware_checks_response(self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock) -> None:
        """Test middleware checks agent response for policy violations."""
        context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Hello")])
        call_count = 0

        async def mock_process_messages(
            messages: list[ChatMessage], activity: object, user_id: str | None = None
        ) -> tuple[bool, str]:
            nonlocal call_count
            call_count += 1
            # Let the prompt (first call) through; block the response (any later call).
            should_block = call_count != 1
            return (should_block, "user-123")

        with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):

            async def mock_next(ctx: AgentRunContext) -> None:
                ctx.result = AgentRunResponse(
                    messages=[ChatMessage(role=Role.ASSISTANT, text="Here's some sensitive information")]
                )

            await middleware.process(context, mock_next)

            assert call_count == 2
            assert context.result is not None
            assert len(context.result.messages) == 1
            assert context.result.messages[0].role == Role.SYSTEM
            assert "blocked by policy" in context.result.messages[0].text.lower()

    async def test_middleware_handles_result_without_messages(
        self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
    ) -> None:
        """Test middleware handles result that doesn't have messages attribute."""
        # Set ignore_exceptions to True so AttributeError is caught and logged
        middleware._settings.ignore_exceptions = True

        context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Hello")])

        with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")):

            async def mock_next(ctx: AgentRunContext) -> None:
                ctx.result = "Some non-standard result"

            await middleware.process(context, mock_next)

            # The non-standard result must be left exactly as the agent produced it.
            assert context.result == "Some non-standard result"

    async def test_middleware_processor_receives_correct_activity(
        self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
    ) -> None:
        """Test middleware passes correct activity type to processor."""
        from agent_framework_purview._models import Activity

        context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Test")])

        with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_process:

            async def mock_next(ctx: AgentRunContext) -> None:
                ctx.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Response")])

            await middleware.process(context, mock_next)

            # One call for the prompt pre-check and one for the response post-check;
            # the activity is expected as the second positional argument of each.
            assert mock_process.call_count == 2
            for recorded_call in mock_process.call_args_list:
                assert recorded_call.args[1] == Activity.UPLOAD_TEXT

    async def test_middleware_handles_pre_check_exception(
        self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
    ) -> None:
        """Test that exceptions in pre-check are logged but don't stop processing when ignore_exceptions=True."""
        # Set ignore_exceptions to True
        middleware._settings.ignore_exceptions = True

        context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Test")])

        with patch.object(
            middleware._processor, "process_messages", side_effect=Exception("Pre-check error")
        ) as mock_process:

            async def mock_next(ctx: AgentRunContext) -> None:
                ctx.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Response")])

            await middleware.process(context, mock_next)

            # Should have been called twice (pre-check raises, then post-check also raises)
            assert mock_process.call_count == 2
            # Context should not be terminated
            assert not context.terminate
            # Result should be set by mock_next
            assert context.result is not None

    async def test_middleware_handles_post_check_exception(
        self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
    ) -> None:
        """Test that exceptions in post-check are logged but don't affect result when ignore_exceptions=True."""
        # Set ignore_exceptions to True
        middleware._settings.ignore_exceptions = True

        context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Test")])
        call_count = 0

        async def mock_process_messages(*args: object, **kwargs: object) -> tuple[bool, str]:
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return (False, "user-123")  # Pre-check succeeds
            raise Exception("Post-check error")  # Post-check fails

        with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):

            async def mock_next(ctx: AgentRunContext) -> None:
                ctx.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Response")])

            await middleware.process(context, mock_next)

            # Should have been called twice (pre and post)
            assert call_count == 2
            # Result should still be set
            assert context.result is not None
            assert hasattr(context.result, "messages")

    async def test_middleware_with_ignore_exceptions_true(
        self, mock_credential: AsyncMock, mock_agent: MagicMock
    ) -> None:
        """Test that middleware logs but doesn't throw when ignore_exceptions is True."""
        settings = PurviewSettings(app_name="Test App", ignore_exceptions=True)
        middleware = PurviewPolicyMiddleware(mock_credential, settings)

        context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Test")])

        # Mock processor to raise an exception
        async def mock_process_messages(*args: object, **kwargs: object) -> tuple[bool, str]:
            raise ValueError("Test error")

        with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):

            async def mock_next(ctx: AgentRunContext) -> None:
                ctx.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Response")])

            # Should not raise, just log
            await middleware.process(context, mock_next)

            # Result should be set because next was called despite the error
            assert context.result is not None

    async def test_middleware_with_ignore_exceptions_false(
        self, mock_credential: AsyncMock, mock_agent: MagicMock
    ) -> None:
        """Test that middleware throws exceptions when ignore_exceptions is False."""
        settings = PurviewSettings(app_name="Test App", ignore_exceptions=False)
        middleware = PurviewPolicyMiddleware(mock_credential, settings)

        context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Test")])

        # Mock processor to raise an exception
        async def mock_process_messages(*args: object, **kwargs: object) -> tuple[bool, str]:
            raise ValueError("Test error")

        with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
            next_called = False

            async def mock_next(ctx: AgentRunContext) -> None:
                nonlocal next_called
                next_called = True

            # Should raise the exception
            with pytest.raises(ValueError, match="Test error"):
                await middleware.process(context, mock_next)

            # The pre-check fails first, so the agent must never have been invoked.
            assert not next_called