Files
agent-framework/python/packages/purview/tests/test_middleware.py
T
Eduard van Valkenburg 838a7fd61d Python: [BREAKING] Types API Review improvements (#3647)
* Replace Role and FinishReason classes with NewType + Literal

- Remove EnumLike metaclass from _types.py
- Replace Role class with NewType('Role', str) + RoleLiteral
- Replace FinishReason class with NewType('FinishReason', str) + FinishReasonLiteral
- Update all usages across codebase to use string literals
- Remove .value access patterns (direct string comparison now works)
- Add backward compatibility for legacy dict serialization format
- Update tests to reflect new string-based types

Addresses #3591, #3615

* Simplify ChatResponse and AgentResponse type hints (#3592)

- Remove overloads from ChatResponse.__init__
- Remove text parameter from ChatResponse.__init__
- Remove | dict[str, Any] from finish_reason and usage_details params
- Remove **kwargs from AgentResponse.__init__
- Both now accept ChatMessage | Sequence[ChatMessage] | None for messages
- Update docstrings and examples to reflect changes
- Fix tests that were using removed kwargs
- Fix Role type hint usage in ag-ui utils

* Remove text parameter from ChatResponseUpdate and AgentResponseUpdate (#3597)

- Remove text parameter from ChatResponseUpdate.__init__
- Remove text parameter from AgentResponseUpdate.__init__
- Remove **kwargs from both update classes
- Simplify contents parameter type to Sequence[Content] | None
- Update all usages to use contents=[Content.from_text(...)] pattern
- Fix imports in test files
- Update docstrings and examples

* Rename from_chat_response_updates to from_updates (#3593)

- ChatResponse.from_chat_response_updates → ChatResponse.from_updates
- ChatResponse.from_chat_response_generator → ChatResponse.from_update_generator
- AgentResponse.from_agent_run_response_updates → AgentResponse.from_updates

* Remove try_parse_value method from ChatResponse and AgentResponse (#3595)

- Remove try_parse_value method from ChatResponse
- Remove try_parse_value method from AgentResponse
- Remove try_parse_value calls from from_updates and from_update_generator methods
- Update samples to use try/except with response.value instead
- Update tests to use response.value pattern
- Users should now use response.value with try/except for safe parsing

* Add agent_id to AgentResponse and clarify author_name documentation (#3596)

- Add agent_id parameter to AgentResponse class
- Document that author_name is on ChatMessage objects, not responses
- Update ChatResponse docstring with author_name note
- Update AgentResponse docstring with author_name note

* Simplify ChatMessage.__init__ signature (#3618)

- Make contents a positional argument accepting Sequence[Content | str]
- Auto-convert strings in contents to TextContent
- Remove overloads, keep text kwarg for backward compatibility with serialization
- Update _parse_content_list to handle string items
- Update all usages across codebase to use new format: ChatMessage("role", ["text"])

* Allow Content as input on run and get_response

- Update prepare_messages and normalize_messages to accept Content
- Update type signatures in _agents.py and _clients.py
- Add tests for Content input handling

* Fix ChatMessage usage across packages and samples

Update all remaining ChatMessage(role=..., text=...) to use new
ChatMessage('role', ['text']) signature.

* Fix Role string usage and response format parsing

- Fix redis provider: remove .value access on string literals
- Fix durabletask ensure_response_format: set _response_format before accessing .value

* Fix ollama .value and ai_model_id issues, handle None in content list

- Fix ollama _chat_client: remove .value on string literals
- Fix ollama _chat_client: rename ai_model_id to model_id
- Fix _parse_content_list: skip None values gracefully

* Fix A2AAgent type signature to include Content

* Fix Role/FinishReason NewType dict annotations and improve test coverage to 95%

* Fix mypy errors for Role/FinishReason NewType usage

* Fix Role.TOOL and Role.ASSISTANT usage in _orchestrator_helpers.py

* Fix Role NewType usage in durabletask _models.py
2026-02-04 10:13:23 +00:00

339 lines
14 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
"""Tests for Purview middleware."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agent_framework import AgentResponse, AgentRunContext, ChatMessage
from azure.core.credentials import AccessToken
from agent_framework_purview import PurviewPolicyMiddleware, PurviewSettings
class TestPurviewPolicyMiddleware:
    """Test PurviewPolicyMiddleware functionality.

    Apart from the initialization test, each test replaces
    ``middleware._processor.process_messages`` with a mock. The tests rely on
    the middleware invoking that method once before ``next`` (the "pre-check")
    and, for non-streaming runs, once more after ``next`` (the "post-check").
    The mocks return a ``(should_block, user_id)``-shaped tuple or raise to
    simulate processor failures.
    """

    @pytest.fixture
    def mock_credential(self) -> AsyncMock:
        """Create a mock async credential."""
        credential = AsyncMock()
        credential.get_token = AsyncMock(return_value=AccessToken("fake-token", 9999999999))
        return credential

    @pytest.fixture
    def settings(self) -> PurviewSettings:
        """Create test settings."""
        return PurviewSettings(app_name="Test App", tenant_id="test-tenant")

    @pytest.fixture
    def middleware(self, mock_credential: AsyncMock, settings: PurviewSettings) -> PurviewPolicyMiddleware:
        """Create PurviewPolicyMiddleware instance."""
        return PurviewPolicyMiddleware(mock_credential, settings)

    @pytest.fixture
    def mock_agent(self) -> MagicMock:
        """Create a mock agent."""
        agent = MagicMock()
        agent.name = "test-agent"
        return agent

    def test_middleware_initialization(self, mock_credential: AsyncMock, settings: PurviewSettings) -> None:
        """Test PurviewPolicyMiddleware initialization."""
        middleware = PurviewPolicyMiddleware(mock_credential, settings)

        assert middleware._client is not None
        assert middleware._processor is not None

    async def test_middleware_allows_clean_prompt(
        self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
    ) -> None:
        """Test middleware allows prompt that passes policy check."""
        context = AgentRunContext(agent=mock_agent, messages=[ChatMessage("user", ["Hello, how are you?"])])

        # (False, ...) means "do not block" for every processor call.
        with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")):
            next_called = False

            async def mock_next(ctx: AgentRunContext) -> None:
                nonlocal next_called
                next_called = True
                ctx.result = AgentResponse(messages=[ChatMessage("assistant", ["I'm good, thanks!"])])

            await middleware.process(context, mock_next)

            assert next_called
            assert context.result is not None
            assert not context.terminate

    async def test_middleware_blocks_prompt_on_policy_violation(
        self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
    ) -> None:
        """Test middleware blocks prompt that violates policy."""
        context = AgentRunContext(agent=mock_agent, messages=[ChatMessage("user", ["Sensitive information"])])

        # (True, ...) means "block": the pipeline must stop before `next` runs.
        with patch.object(middleware._processor, "process_messages", return_value=(True, "user-123")):
            next_called = False

            async def mock_next(ctx: AgentRunContext) -> None:
                nonlocal next_called
                next_called = True

            await middleware.process(context, mock_next)

            assert not next_called
            assert context.result is not None
            assert context.terminate
            # The middleware substitutes a single system message for the result.
            assert len(context.result.messages) == 1
            assert context.result.messages[0].role == "system"
            assert "blocked by policy" in context.result.messages[0].text.lower()

    async def test_middleware_checks_response(
        self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
    ) -> None:
        """Test middleware checks agent response for policy violations."""
        context = AgentRunContext(agent=mock_agent, messages=[ChatMessage("user", ["Hello"])])

        call_count = 0

        async def mock_process_messages(messages, activity, user_id=None):
            # First call (pre-check) passes; every later call (post-check) blocks.
            nonlocal call_count
            call_count += 1
            should_block = call_count != 1
            return (should_block, "user-123")

        with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):

            async def mock_next(ctx: AgentRunContext) -> None:
                ctx.result = AgentResponse(messages=[ChatMessage("assistant", ["Here's some sensitive information"])])

            await middleware.process(context, mock_next)

            assert call_count == 2
            assert context.result is not None
            # The agent's response must have been replaced by the block message.
            assert len(context.result.messages) == 1
            assert context.result.messages[0].role == "system"
            assert "blocked by policy" in context.result.messages[0].text.lower()

    async def test_middleware_handles_result_without_messages(
        self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
    ) -> None:
        """Test middleware handles result that doesn't have messages attribute."""
        # Set ignore_exceptions to True so AttributeError is caught and logged
        middleware._settings.ignore_exceptions = True

        context = AgentRunContext(agent=mock_agent, messages=[ChatMessage("user", ["Hello"])])

        with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")):

            async def mock_next(ctx: AgentRunContext) -> None:
                # A plain string has no `.messages`, unlike AgentResponse.
                ctx.result = "Some non-standard result"

            await middleware.process(context, mock_next)

            # The unusual result must pass through untouched.
            assert context.result == "Some non-standard result"

    async def test_middleware_processor_receives_correct_activity(
        self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
    ) -> None:
        """Test middleware passes correct activity type to processor."""
        from agent_framework_purview._models import Activity

        context = AgentRunContext(agent=mock_agent, messages=[ChatMessage("user", ["Test"])])

        with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_process:

            async def mock_next(ctx: AgentRunContext) -> None:
                ctx.result = AgentResponse(messages=[ChatMessage("assistant", ["Response"])])

            await middleware.process(context, mock_next)

            # One call for the prompt and one for the response.
            assert mock_process.call_count == 2
            for call in mock_process.call_args_list:
                # call[0] is the positional-args tuple; index 1 is the activity.
                assert call[0][1] == Activity.UPLOAD_TEXT

    async def test_middleware_streaming_skips_post_check(
        self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
    ) -> None:
        """Test that streaming results skip post-check evaluation."""
        context = AgentRunContext(agent=mock_agent, messages=[ChatMessage("user", ["Hello"])])
        context.is_streaming = True

        with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:

            async def mock_next(ctx: AgentRunContext) -> None:
                ctx.result = AgentResponse(messages=[ChatMessage("assistant", ["streaming"])])

            await middleware.process(context, mock_next)

            # Only the pre-check call is expected when streaming.
            assert mock_proc.call_count == 1

    async def test_middleware_payment_required_in_pre_check_raises_by_default(
        self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
    ) -> None:
        """Test that 402 in pre-check is raised when ignore_payment_required=False."""
        from agent_framework_purview._exceptions import PurviewPaymentRequiredError

        context = AgentRunContext(agent=mock_agent, messages=[ChatMessage("user", ["Hello"])])

        with patch.object(
            middleware._processor,
            "process_messages",
            side_effect=PurviewPaymentRequiredError("Payment required"),
        ):

            async def mock_next(_: AgentRunContext) -> None:
                # Reaching `next` would mean the pre-check error was swallowed.
                raise AssertionError("next should not be called")

            with pytest.raises(PurviewPaymentRequiredError):
                await middleware.process(context, mock_next)

    async def test_middleware_payment_required_in_post_check_raises_by_default(
        self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
    ) -> None:
        """Test that 402 in post-check is raised when ignore_payment_required=False."""
        from agent_framework_purview._exceptions import PurviewPaymentRequiredError

        context = AgentRunContext(agent=mock_agent, messages=[ChatMessage("user", ["Hello"])])

        call_count = 0

        async def side_effect(*args, **kwargs):
            # Pre-check (first call) passes; the post-check call raises.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return (False, "user-123")
            raise PurviewPaymentRequiredError("Payment required")

        with patch.object(middleware._processor, "process_messages", side_effect=side_effect):

            async def mock_next(ctx: AgentRunContext) -> None:
                ctx.result = AgentResponse(messages=[ChatMessage("assistant", ["OK"])])

            with pytest.raises(PurviewPaymentRequiredError):
                await middleware.process(context, mock_next)

    async def test_middleware_post_check_exception_raises_when_ignore_exceptions_false(
        self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
    ) -> None:
        """Test that post-check exceptions are propagated when ignore_exceptions=False."""
        middleware._settings.ignore_exceptions = False

        context = AgentRunContext(agent=mock_agent, messages=[ChatMessage("user", ["Hello"])])

        call_count = 0

        async def side_effect(*args, **kwargs):
            # Pre-check (first call) passes; the post-check call raises.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return (False, "user-123")
            raise ValueError("Post-check blew up")

        with patch.object(middleware._processor, "process_messages", side_effect=side_effect):

            async def mock_next(ctx: AgentRunContext) -> None:
                ctx.result = AgentResponse(messages=[ChatMessage("assistant", ["OK"])])

            with pytest.raises(ValueError, match="Post-check blew up"):
                await middleware.process(context, mock_next)

    async def test_middleware_handles_pre_check_exception(
        self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
    ) -> None:
        """Test that exceptions in pre-check are logged but don't stop processing when ignore_exceptions=True."""
        # Set ignore_exceptions to True
        middleware._settings.ignore_exceptions = True

        context = AgentRunContext(agent=mock_agent, messages=[ChatMessage("user", ["Test"])])

        with patch.object(
            middleware._processor, "process_messages", side_effect=Exception("Pre-check error")
        ) as mock_process:

            async def mock_next(ctx: AgentRunContext) -> None:
                ctx.result = AgentResponse(messages=[ChatMessage("assistant", ["Response"])])

            await middleware.process(context, mock_next)

            # Should have been called twice (pre-check raises, then post-check also raises)
            assert mock_process.call_count == 2
            # Context should not be terminated
            assert not context.terminate
            # Result should be set by mock_next
            assert context.result is not None

    async def test_middleware_handles_post_check_exception(
        self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
    ) -> None:
        """Test that exceptions in post-check are logged but don't affect result when ignore_exceptions=True."""
        # Set ignore_exceptions to True
        middleware._settings.ignore_exceptions = True

        context = AgentRunContext(agent=mock_agent, messages=[ChatMessage("user", ["Test"])])

        call_count = 0

        async def mock_process_messages(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return (False, "user-123")  # Pre-check succeeds
            raise Exception("Post-check error")  # Post-check fails

        with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):

            async def mock_next(ctx: AgentRunContext) -> None:
                ctx.result = AgentResponse(messages=[ChatMessage("assistant", ["Response"])])

            await middleware.process(context, mock_next)

            # Should have been called twice (pre and post)
            assert call_count == 2
            # Result should still be set
            assert context.result is not None
            assert hasattr(context.result, "messages")

    async def test_middleware_with_ignore_exceptions_true(
        self, mock_credential: AsyncMock, mock_agent: MagicMock
    ) -> None:
        """Test that middleware logs but doesn't throw when ignore_exceptions is True."""
        # Built locally (not via the `middleware` fixture) so the flag is set
        # through the PurviewSettings constructor.
        settings = PurviewSettings(app_name="Test App", ignore_exceptions=True)
        middleware = PurviewPolicyMiddleware(mock_credential, settings)

        context = AgentRunContext(agent=mock_agent, messages=[ChatMessage("user", ["Test"])])

        # Mock processor to raise an exception
        async def mock_process_messages(*args, **kwargs):
            raise ValueError("Test error")

        with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):

            async def mock_next(ctx: AgentRunContext) -> None:
                ctx.result = AgentResponse(messages=[ChatMessage("assistant", ["Response"])])

            # Should not raise, just log
            await middleware.process(context, mock_next)

            # Result should be set because next was called despite the error
            assert context.result is not None

    async def test_middleware_with_ignore_exceptions_false(
        self, mock_credential: AsyncMock, mock_agent: MagicMock
    ) -> None:
        """Test that middleware throws exceptions when ignore_exceptions is False."""
        # Built locally (not via the `middleware` fixture) so the flag is set
        # through the PurviewSettings constructor.
        settings = PurviewSettings(app_name="Test App", ignore_exceptions=False)
        middleware = PurviewPolicyMiddleware(mock_credential, settings)

        context = AgentRunContext(agent=mock_agent, messages=[ChatMessage("user", ["Test"])])

        # Mock processor to raise an exception
        async def mock_process_messages(*args, **kwargs):
            raise ValueError("Test error")

        with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):

            async def mock_next(ctx: AgentRunContext) -> None:
                pass

            # Should raise the exception
            with pytest.raises(ValueError, match="Test error"):
                await middleware.process(context, mock_next)