Python: Add Purview Middleware (#1142)

* [Py Purview] Purview Python Initial Commit

* [Py Purview] Purview Python Minor Fixes

* [Py Purview] Purview Python Comment Fixesish

* [Py Purview] Purview Python Agent Middleware Done

* [Py Purview] Purview Python Agent Middleware Done

* [Py Purview] Purview Python Lint Errors

* [Py Purview] Purview Python Final Hopefully

* [Py Purview] Purview Python Final Hopefully

* [Py Purview] Purview Python Fix ReadMe

* [Py Purview] Purview Python Fix MyPy

* [Py Purview] Purview Python Minor Updates on comments

* [Py Purview] Purview Python Fix Build Error

---------

Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com>
This commit is contained in:
Rishabh Chawla
2025-10-16 14:46:04 -07:00
committed by GitHub
Unverified
parent 76ae0a62ac
commit 59da578902
24 changed files with 3707 additions and 5 deletions
@@ -243,4 +243,4 @@ class AzureOpenAIConfigMixin(OpenAIBase):
def_headers = None
self.default_headers = def_headers
super().__init__(model_id=deployment_name, client=client)
super().__init__(model_id=deployment_name, client=client, **kwargs)
@@ -3,12 +3,20 @@
import importlib
from typing import Any
PACKAGE_NAME = "agent_framework_copilotstudio"
PACKAGE_EXTRA = ["microsoft-copilotstudio", "copilotstudio"]
_IMPORTS: dict[str, tuple[str, list[str]]] = {
"CopilotStudioAgent": ("agent_framework_copilotstudio", ["microsoft-copilotstudio", "copilotstudio"]),
"__version__": ("agent_framework_copilotstudio", ["microsoft-copilotstudio", "copilotstudio"]),
"acquire_token": ("agent_framework_copilotstudio", ["microsoft-copilotstudio", "copilotstudio"]),
# Purview (Graph Data Security & Governance) integration exports
"PurviewPolicyMiddleware": ("agent_framework_purview", ["microsoft-purview", "purview"]),
"PurviewChatPolicyMiddleware": ("agent_framework_purview", ["microsoft-purview", "purview"]),
"PurviewSettings": ("agent_framework_purview", ["microsoft-purview", "purview"]),
"PurviewAppLocation": ("agent_framework_purview", ["microsoft-purview", "purview"]),
"PurviewLocationType": ("agent_framework_purview", ["microsoft-purview", "purview"]),
"PurviewAuthenticationError": ("agent_framework_purview", ["microsoft-purview", "purview"]),
"PurviewRateLimitError": ("agent_framework_purview", ["microsoft-purview", "purview"]),
"PurviewRequestError": ("agent_framework_purview", ["microsoft-purview", "purview"]),
"PurviewServiceError": ("agent_framework_purview", ["microsoft-purview", "purview"]),
}
@@ -23,7 +31,7 @@ def __getattr__(name: str) -> Any:
f"please use `pip install agent-framework-{package_extra[0]}`, "
"or update your requirements.txt or pyproject.toml file."
) from exc
raise AttributeError(f"Module `azure` has no attribute {name}.")
raise AttributeError(f"Module `microsoft` has no attribute {name}.")
def __dir__() -> list[str]:
@@ -1,5 +1,29 @@
# Copyright (c) Microsoft. All rights reserved.
from agent_framework_copilotstudio import CopilotStudioAgent, __version__, acquire_token
from agent_framework_purview import (
PurviewAppLocation,
PurviewAuthenticationError,
PurviewChatPolicyMiddleware,
PurviewLocationType,
PurviewPolicyMiddleware,
PurviewRateLimitError,
PurviewRequestError,
PurviewServiceError,
PurviewSettings,
)
__all__ = ["CopilotStudioAgent", "__version__", "acquire_token"]
__all__ = [
"CopilotStudioAgent",
"PurviewAppLocation",
"PurviewAuthenticationError",
"PurviewChatPolicyMiddleware",
"PurviewLocationType",
"PurviewPolicyMiddleware",
"PurviewRateLimitError",
"PurviewRequestError",
"PurviewServiceError",
"PurviewSettings",
"__version__",
"acquire_token",
]
+21
View File
@@ -0,0 +1,21 @@
MIT License
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
+224
View File
@@ -0,0 +1,224 @@
## Microsoft Agent Framework Purview Integration (Python)
`agent-framework-purview` adds Microsoft Purview (Microsoft Graph dataSecurityAndGovernance) policy evaluation to the Microsoft Agent Framework. It lets you enforce data security / governance policies on both the *prompt* (user input + conversation history) and the *model response* before they proceed further in your workflow.
> Status: **Preview**
### Key Features
- Middleware-based policy enforcement (agent-level and chat-client level)
- Blocks or allows content at both ingress (prompt) and egress (response)
- Works with any `ChatAgent` / agent orchestration using the standard Agent Framework middleware pipeline
- Supports both synchronous `TokenCredential` and `AsyncTokenCredential` from `azure-identity`
- Simple, typed configuration via `PurviewSettings` / `PurviewAppLocation`
- Two middleware types:
- `PurviewPolicyMiddleware` (Agent pipeline)
- `PurviewChatPolicyMiddleware` (Chat client middleware list)
### When to Use
Add Purview when you need to:
- Prevent sensitive or disallowed content from being sent to an LLM
- Prevent model output containing disallowed data from leaving the system
- Apply centrally managed policies without rewriting agent logic
---
## Quick Start
```python
import asyncio
from agent_framework import ChatAgent, ChatMessage, Role
from agent_framework.azure import AzureOpenAIChatClient
from agent_framework.microsoft import PurviewPolicyMiddleware, PurviewSettings
from azure.identity import InteractiveBrowserCredential
async def main():
chat_client = AzureOpenAIChatClient() # uses environment for endpoint + deployment
purview_middleware = PurviewPolicyMiddleware(
credential=InteractiveBrowserCredential(),
settings=PurviewSettings(app_name="My Sample App")
)
agent = ChatAgent(
chat_client=chat_client,
instructions="You are a helpful assistant.",
middleware=[purview_middleware]
)
response = await agent.run(ChatMessage(role=Role.USER, text="Summarize zero trust in one sentence."))
print(response)
asyncio.run(main())
```
If a policy violation is detected on the prompt, the middleware terminates the run and substitutes a system message: `"Prompt blocked by policy"`. If on the response, the result becomes `"Response blocked by policy"`.
---
## Authentication
`PurviewClient` uses the `azure-identity` library for token acquisition. You can use any `TokenCredential` or `AsyncTokenCredential` implementation.
The APIs require the following Graph Permissions:
- ProtectionScopes.Compute.All : (userProtectionScopeContainer)[https://learn.microsoft.com/en-us/graph/api/userprotectionscopecontainer-compute]
- Content.Process.All : (processContent)[https://learn.microsoft.com/en-us/graph/api/userdatasecurityandgovernance-processcontent]
- ContentActivity.Write : (contentActivity)[https://learn.microsoft.com/en-us/graph/api/activitiescontainer-post-contentactivities]
### Scopes
`PurviewSettings.get_scopes()` derives the Graph scope list (currently `https://graph.microsoft.com/.default` style).
### Tenant Enablement for Purview
- The tenant requires an e5 license and consumptive billing setup.
- There need to be (Data Loss Prevention)[https://learn.microsoft.com/en-us/purview/dlp-create-deploy-policy] or (Data Collection Policies)[https://learn.microsoft.com/en-us/purview/collection-policies-policy-reference] that apply to the user to call Process Content API else it calls Content Activities API for auditing the message.
---
## Configuration
### `PurviewSettings`
```python
PurviewSettings(
app_name="My App", # Display / logical name
tenant_id=None, # Optional used mainly for auth context
purview_app_location=None, # Optional PurviewAppLocation for scoping
graph_base_uri="https://graph.microsoft.com/v1.0/",
process_inline=False, # Reserved for future inline processing optimizations
blocked_prompt_message="Prompt blocked by policy", # Custom message for blocked prompts
blocked_response_message="Response blocked by policy" # Custom message for blocked responses
)
```
To scope evaluation by location (application, URL, or domain):
```python
from agent_framework.microsoft import (
PurviewAppLocation,
PurviewLocationType,
PurviewSettings,
)
settings = PurviewSettings(
app_name="Contoso Support",
purview_app_location=PurviewAppLocation(
location_type=PurviewLocationType.APPLICATION,
location_value="<app-client-id>"
)
)
```
### Customizing Blocked Messages
By default, when Purview blocks a prompt or response, the middleware returns a generic system message. You can customize these messages by providing your own text in the `PurviewSettings`:
```python
from agent_framework.microsoft import PurviewSettings
settings = PurviewSettings(
app_name="My App",
blocked_prompt_message="Your request contains content that violates our policies. Please rephrase and try again.",
blocked_response_message="The response was blocked due to policy restrictions. Please contact support if you need assistance."
)
```
This is useful for:
- Providing more user-friendly error messages
- Including support contact information
- Localizing messages for different languages
- Adding branding or specific guidance for your application
### Selecting Agent vs Chat Middleware
Use the agent middleware when you already have / want the full agent pipeline:
```python
from agent_framework import ChatAgent
from agent_framework.azure import AzureOpenAIChatClient
from agent_framework.microsoft import PurviewPolicyMiddleware, PurviewSettings
from azure.identity import DefaultAzureCredential
credential = DefaultAzureCredential()
client = AzureOpenAIChatClient()
agent = ChatAgent(
chat_client=client,
instructions="You are helpful.",
middleware=[PurviewPolicyMiddleware(credential, PurviewSettings(app_name="My App"))]
)
```
Use the chat middleware when you attach directly to a chat client (e.g. minimal agent shell or custom orchestration):
```python
import os
from agent_framework import ChatAgent
from agent_framework.azure import AzureOpenAIChatClient
from agent_framework.microsoft import PurviewChatPolicyMiddleware, PurviewSettings
from azure.identity import DefaultAzureCredential
credential = DefaultAzureCredential()
chat_client = AzureOpenAIChatClient(
deployment_name=os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"],
endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
credential=credential,
middleware=[
PurviewChatPolicyMiddleware(credential, PurviewSettings(app_name="My App (Chat)"))
],
)
agent = ChatAgent(chat_client=chat_client, instructions="You are helpful.")
```
The policy logic is identical; the difference is only the hook point in the pipeline.
---
## Middleware Lifecycle
1. Before agent execution (`prompt phase`): all `context.messages` are evaluated.
2. If blocked: `context.result` is replaced with a system message and `context.terminate = True`.
3. After successful agent execution (`response phase`): the produced messages are evaluated.
4. If blocked: result messages are replaced with a blocking notice.
When a user identifier is discovered (e.g. in `ChatMessage.additional_properties['user_id']`) during the prompt phase it is reused for the response phase so both evaluations map consistently to the same user.
You can customize the blocking messages using the `blocked_prompt_message` and `blocked_response_message` fields in `PurviewSettings`. For more advanced scenarios, you can wrap the middleware or post-process `context.result` in later middleware.
---
## Exceptions
| Exception | Scenario |
|-----------|----------|
| `PurviewAuthenticationError` | Token acquisition / validation issues |
| `PurviewRateLimitError` | 429 responses from service |
| `PurviewRequestError` | 4xx client errors (bad input, unauthorized, forbidden) |
| `PurviewServiceError` | 5xx or unexpected service errors |
Catch broadly if you want unified fallback:
```python
from agent_framework.microsoft import (
PurviewAuthenticationError, PurviewRateLimitError,
PurviewRequestError, PurviewServiceError
)
try:
...
except (PurviewAuthenticationError, PurviewRateLimitError, PurviewRequestError, PurviewServiceError) as ex:
# Log / degrade gracefully
print(f"Purview enforcement skipped: {ex}")
```
---
## Notes
- Provide a `user_id` per request (e.g. in `ChatMessage(..., additional_properties={"user_id": "<guid>"})`) when possible for per-user policy scoping; otherwise supply a default via settings or environment.
- Blocking messages can be customized via `blocked_prompt_message` and `blocked_response_message` in `PurviewSettings`. By default, they are "Prompt blocked by policy" and "Response blocked by policy" respectively.
- Streaming responses: post-response policy evaluation presently applies only to non-streaming chat responses.
- Errors during policy checks are logged and do not fail the run; they degrade gracefully.
@@ -0,0 +1,22 @@
# Copyright (c) Microsoft. All rights reserved.
from ._exceptions import (
PurviewAuthenticationError,
PurviewRateLimitError,
PurviewRequestError,
PurviewServiceError,
)
from ._middleware import PurviewChatPolicyMiddleware, PurviewPolicyMiddleware
from ._settings import PurviewAppLocation, PurviewLocationType, PurviewSettings
__all__ = [
"PurviewAppLocation",
"PurviewAuthenticationError",
"PurviewChatPolicyMiddleware",
"PurviewLocationType",
"PurviewPolicyMiddleware",
"PurviewRateLimitError",
"PurviewRequestError",
"PurviewServiceError",
"PurviewSettings",
]
@@ -0,0 +1,126 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
import base64
import inspect
import json
from typing import Any, cast
import httpx
from agent_framework import AGENT_FRAMEWORK_USER_AGENT
from agent_framework.observability import get_tracer
from azure.core.credentials import TokenCredential
from azure.core.credentials_async import AsyncTokenCredential
from ._exceptions import (
PurviewAuthenticationError,
PurviewRateLimitError,
PurviewRequestError,
PurviewServiceError,
)
from ._models import (
ContentActivitiesRequest,
ContentActivitiesResponse,
ProcessContentRequest,
ProcessContentResponse,
ProtectionScopesRequest,
ProtectionScopesResponse,
)
from ._settings import PurviewSettings
class PurviewClient:
"""Async client for calling Graph Purview endpoints.
Supports both synchronous TokenCredential and asynchronous AsyncTokenCredential implementations.
A sync credential will be invoked in a thread to avoid blocking the event loop.
"""
def __init__(
self,
credential: TokenCredential | AsyncTokenCredential,
settings: PurviewSettings,
*,
timeout: float | None = 10.0,
):
self._credential: TokenCredential | AsyncTokenCredential = credential
self._settings = settings
self._graph_uri = settings.graph_base_uri.rstrip("/")
self._timeout = timeout
self._client = httpx.AsyncClient(timeout=timeout)
async def close(self) -> None:
await self._client.aclose()
async def _get_token(self, *, tenant_id: str | None = None) -> str:
"""Acquire an access token using either async or sync credential."""
scopes = self._settings.get_scopes()
cred = self._credential
token = cred.get_token(*scopes, tenant_id=tenant_id)
token = await token if inspect.isawaitable(token) else token
return token.token
@staticmethod
def _extract_token_info(token: str) -> dict[str, Any]:
parts = token.split(".")
if len(parts) < 2:
raise ValueError("Invalid JWT token format")
payload = parts[1]
rem = len(payload) % 4
if rem:
payload += "=" * (4 - rem)
decoded = base64.urlsafe_b64decode(payload)
data = json.loads(decoded.decode("utf-8"))
return {
"user_id": data.get("oid") if data.get("idtyp") == "user" else None,
"tenant_id": data.get("tid"),
"client_id": data.get("appid"),
}
async def get_user_info_from_token(self, *, tenant_id: str | None = None) -> dict[str, Any]:
token = await self._get_token(tenant_id=tenant_id)
return self._extract_token_info(token)
async def process_content(self, request: ProcessContentRequest) -> ProcessContentResponse:
with get_tracer().start_as_current_span("purview.process_content"):
token = await self._get_token(tenant_id=request.tenant_id)
url = f"{self._graph_uri}/users/{request.user_id}/dataSecurityAndGovernance/processContent"
return cast(ProcessContentResponse, await self._post(url, request, ProcessContentResponse, token))
async def get_protection_scopes(self, request: ProtectionScopesRequest) -> ProtectionScopesResponse:
with get_tracer().start_as_current_span("purview.get_protection_scopes"):
token = await self._get_token()
url = f"{self._graph_uri}/users/{request.user_id}/dataSecurityAndGovernance/protectionScopes/compute"
return cast(ProtectionScopesResponse, await self._post(url, request, ProtectionScopesResponse, token))
async def send_content_activities(self, request: ContentActivitiesRequest) -> ContentActivitiesResponse:
with get_tracer().start_as_current_span("purview.send_content_activities"):
token = await self._get_token()
url = f"{self._graph_uri}/users/{request.user_id}/dataSecurityAndGovernance/activities/contentActivities"
return cast(ContentActivitiesResponse, await self._post(url, request, ContentActivitiesResponse, token))
async def _post(self, url: str, model: Any, response_type: type[Any], token: str) -> Any:
payload = model.model_dump(by_alias=True, exclude_none=True, mode="json")
headers = {
"Authorization": f"Bearer {token}",
"User-Agent": AGENT_FRAMEWORK_USER_AGENT,
"Content-Type": "application/json",
}
resp = await self._client.post(url, json=payload, headers=headers)
if resp.status_code in (401, 403):
raise PurviewAuthenticationError(f"Auth failure {resp.status_code}: {resp.text}")
if resp.status_code == 429:
raise PurviewRateLimitError(f"Rate limited {resp.status_code}: {resp.text}")
if resp.status_code not in (200, 201, 202):
raise PurviewRequestError(f"Purview request failed {resp.status_code}: {resp.text}")
try:
data = resp.json()
except ValueError:
data = {}
try:
# Prefer pydantic-style model_validate if present, else fall back to constructor.
if hasattr(response_type, "model_validate"):
return response_type.model_validate(data) # type: ignore[no-any-return]
return response_type(**data) # type: ignore[call-arg, no-any-return]
except Exception as ex: # pragma: no cover
raise PurviewServiceError(f"Failed to deserialize Purview response: {ex}") from ex
@@ -0,0 +1,29 @@
# Copyright (c) Microsoft. All rights reserved.
"""Purview specific exceptions (minimal error shaping)."""
from __future__ import annotations
from agent_framework.exceptions import ServiceResponseException
__all__ = [
"PurviewAuthenticationError",
"PurviewRateLimitError",
"PurviewRequestError",
"PurviewServiceError",
]
class PurviewServiceError(ServiceResponseException):
"""Base exception for Purview errors."""
class PurviewAuthenticationError(PurviewServiceError):
"""Authentication / authorization failure (401/403)."""
class PurviewRateLimitError(PurviewServiceError):
"""Rate limiting or throttling (429)."""
class PurviewRequestError(PurviewServiceError):
"""Other non-success HTTP errors."""
@@ -0,0 +1,168 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
from collections.abc import Awaitable, Callable
from agent_framework import AgentMiddleware, AgentRunContext, ChatContext, ChatMiddleware
from agent_framework._logging import get_logger
from azure.core.credentials import TokenCredential
from azure.core.credentials_async import AsyncTokenCredential
from ._client import PurviewClient
from ._models import Activity
from ._processor import ScopedContentProcessor
from ._settings import PurviewSettings
logger = get_logger("agent_framework.purview")
class PurviewPolicyMiddleware(AgentMiddleware):
"""Agent middleware that enforces Purview policies on prompt and response.
Accepts either a synchronous TokenCredential or an AsyncTokenCredential.
Usage:
.. code-block:: python
from agent_framework.microsoft import PurviewPolicyMiddleware, PurviewSettings
from agent_framework import ChatAgent
credential = ... # TokenCredential or AsyncTokenCredential
settings = PurviewSettings(app_name="My App")
agent = ChatAgent(
chat_client=client, instructions="...", middleware=[PurviewPolicyMiddleware(credential, settings)]
)
"""
def __init__(
self,
credential: TokenCredential | AsyncTokenCredential,
settings: PurviewSettings,
) -> None:
self._client = PurviewClient(credential, settings)
self._processor = ScopedContentProcessor(self._client, settings)
self._settings = settings
async def process(
self,
context: AgentRunContext,
next: Callable[[AgentRunContext], Awaitable[None]],
) -> None: # type: ignore[override]
resolved_user_id: str | None = None
try:
# Pre (prompt) check
should_block_prompt, resolved_user_id = await self._processor.process_messages(
context.messages, Activity.UPLOAD_TEXT
)
if should_block_prompt:
from agent_framework import AgentRunResponse, ChatMessage, Role
context.result = AgentRunResponse(
messages=[ChatMessage(role=Role.SYSTEM, text=self._settings.blocked_prompt_message)]
)
context.terminate = True
return
except Exception as ex:
# Log and continue if there's an error in the pre-check
logger.error(f"Error in Purview policy pre-check: {ex}")
await next(context)
try:
# Post (response) check only if we have a normal AgentRunResponse
# Use the same user_id from the request for the response evaluation
if context.result and not context.is_streaming:
should_block_response, _ = await self._processor.process_messages(
context.result.messages, # type: ignore[union-attr]
Activity.UPLOAD_TEXT,
user_id=resolved_user_id,
)
if should_block_response:
from agent_framework import AgentRunResponse, ChatMessage, Role
context.result = AgentRunResponse(
messages=[ChatMessage(role=Role.SYSTEM, text=self._settings.blocked_response_message)]
)
else:
# Streaming responses are not supported for post-checks
logger.debug("Streaming responses are not supported for Purview policy post-checks")
except Exception as ex:
# Log and continue if there's an error in the post-check
logger.error(f"Error in Purview policy post-check: {ex}")
class PurviewChatPolicyMiddleware(ChatMiddleware):
"""Chat middleware variant for Purview policy evaluation.
This allows users to attach Purview enforcement directly to a chat client
Behavior:
* Pre-chat: evaluates outgoing (user + context) messages as an upload activity
and can terminate execution if blocked.
* Post-chat: evaluates the received response messages (streaming is not presently supported)
and can replace them with a blocked message. Uses the same user_id from the request
to ensure consistent user identity throughout the evaluation.
Usage:
.. code-block:: python
from agent_framework.microsoft import PurviewChatPolicyMiddleware, PurviewSettings
from agent_framework import ChatClient
credential = ... # TokenCredential or AsyncTokenCredential
settings = PurviewSettings(app_name="My App")
client = ChatClient(..., middleware=[PurviewChatPolicyMiddleware(credential, settings)])
"""
def __init__(
self,
credential: TokenCredential | AsyncTokenCredential,
settings: PurviewSettings,
) -> None:
self._client = PurviewClient(credential, settings)
self._processor = ScopedContentProcessor(self._client, settings)
self._settings = settings
async def process(
self,
context: ChatContext,
next: Callable[[ChatContext], Awaitable[None]],
) -> None: # type: ignore[override]
resolved_user_id: str | None = None
try:
should_block_prompt, resolved_user_id = await self._processor.process_messages(
context.messages, Activity.UPLOAD_TEXT
)
if should_block_prompt:
from agent_framework import ChatMessage
context.result = [ # type: ignore[assignment]
ChatMessage(role="system", text=self._settings.blocked_prompt_message)
]
context.terminate = True
return
except Exception as ex:
logger.error(f"Error in Purview policy pre-check: {ex}")
await next(context)
try:
# Post (response) evaluation only if non-streaming and we have messages result shape
# Use the same user_id from the request for the response evaluation
if context.result and not context.is_streaming:
result_obj = context.result
messages = getattr(result_obj, "messages", None)
if messages:
should_block_response, _ = await self._processor.process_messages(
messages, Activity.UPLOAD_TEXT, user_id=resolved_user_id
)
if should_block_response:
from agent_framework import ChatMessage
context.result = [ # type: ignore[assignment]
ChatMessage(role="system", text=self._settings.blocked_response_message)
]
else:
logger.debug("Streaming responses are not supported for Purview policy post-checks")
except Exception as ex:
logger.error(f"Error in Purview policy post-check: {ex}")
@@ -0,0 +1,992 @@
# Copyright (c) Microsoft. All rights reserved.
"""Unified Purview model definitions and public export surface."""
from __future__ import annotations
from collections.abc import Mapping, MutableMapping, Sequence
from datetime import datetime
from enum import Enum, Flag, auto
from typing import Any, ClassVar, TypeVar, cast
from uuid import uuid4
from agent_framework._logging import get_logger
from agent_framework._serialization import SerializationMixin
logger = get_logger("agent_framework.purview")
# --------------------------------------------------------------------------------------
# Enums & flag helpers
# --------------------------------------------------------------------------------------
class Activity(str, Enum):
"""High-level activity types representing user or agent operations."""
UNKNOWN = "unknown"
UPLOAD_TEXT = "uploadText"
UPLOAD_FILE = "uploadFile"
DOWNLOAD_TEXT = "downloadText"
DOWNLOAD_FILE = "downloadFile"
class ProtectionScopeActivities(Flag):
"""Flag enumeration of activities used in policy protection scopes."""
NONE = 0
UPLOAD_TEXT = auto()
UPLOAD_FILE = auto()
DOWNLOAD_TEXT = auto()
DOWNLOAD_FILE = auto()
UNKNOWN_FUTURE_VALUE = auto()
def __int__(self) -> int: # pragma: no cover
return self.value
FlagT = TypeVar("FlagT", bound=Flag)
_PROTECTION_SCOPE_ACTIVITIES_MAP: dict[str, ProtectionScopeActivities] = {
"none": ProtectionScopeActivities.NONE,
"uploadText": ProtectionScopeActivities.UPLOAD_TEXT,
"uploadFile": ProtectionScopeActivities.UPLOAD_FILE,
"downloadText": ProtectionScopeActivities.DOWNLOAD_TEXT,
"downloadFile": ProtectionScopeActivities.DOWNLOAD_FILE,
"unknownFutureValue": ProtectionScopeActivities.UNKNOWN_FUTURE_VALUE,
}
_PROTECTION_SCOPE_ACTIVITIES_SERIALIZE_ORDER: list[tuple[str, ProtectionScopeActivities]] = [
("uploadText", ProtectionScopeActivities.UPLOAD_TEXT),
("uploadFile", ProtectionScopeActivities.UPLOAD_FILE),
("downloadText", ProtectionScopeActivities.DOWNLOAD_TEXT),
("downloadFile", ProtectionScopeActivities.DOWNLOAD_FILE),
]
def deserialize_flag(
value: object, mapping: Mapping[str, FlagT], enum_cls: type[FlagT]
) -> FlagT | None: # pragma: no cover
"""Deserialize arbitrary input into a flag enum instance."""
if value is None:
return None
if isinstance(value, enum_cls):
return value
if isinstance(value, int):
try:
return enum_cls(value)
except Exception:
return None
flag_value = enum_cls(0)
parts: list[str] = []
if isinstance(value, str):
raw = value.strip()
if not raw:
return enum_cls(0)
parts.extend([p.strip() for p in raw.split(",") if p.strip()])
elif isinstance(value, (list, tuple, set)):
for item in value:
if isinstance(item, str):
parts.extend([p.strip() for p in item.split(",") if p.strip()])
elif isinstance(item, enum_cls):
flag_value |= item
elif isinstance(item, int):
try:
flag_value |= enum_cls(item)
except Exception:
logger.warning(f"Failed to convert int {item} to {enum_cls.__name__}")
else:
return None
for part in parts:
member = mapping.get(part)
if member is not None:
flag_value |= member
if flag_value == enum_cls(0):
none_member = mapping.get("none")
if none_member is not None:
return none_member # type: ignore[return-value,index]
return flag_value
def serialize_flag(
flag_value: Flag | int | None, ordered_parts: Sequence[tuple[str, Flag]]
) -> str | None: # pragma: no cover
"""Serialize a flag enum (or int) into a stable, comma-separated string."""
if flag_value is None:
return None
if isinstance(flag_value, int):
if flag_value == 0:
return "none"
int_parts: list[str] = []
for name, member in ordered_parts:
if flag_value & member.value:
int_parts.append(name)
return ",".join(int_parts) if int_parts else "none"
if not isinstance(flag_value, Flag):
return None
if flag_value.value == 0:
return "none"
parts: list[str] = []
for name, member in ordered_parts:
if flag_value & member:
parts.append(name)
return ",".join(parts) if parts else "none"
class DlpAction(str, Enum):
BLOCK_ACCESS = "blockAccess"
OTHER = "other"
class RestrictionAction(str, Enum):
BLOCK = "block"
OTHER = "other"
class ProtectionScopeState(str, Enum):
NOT_MODIFIED = "notModified"
MODIFIED = "modified"
UNKNOWN_FUTURE_VALUE = "unknownFutureValue"
class ExecutionMode(str, Enum):
EVALUATE_INLINE = "evaluateInline"
EVALUATE_OFFLINE = "evaluateOffline"
UNKNOWN_FUTURE_VALUE = "unknownFutureValue"
class PolicyPivotProperty(str, Enum):
NONE = "none"
ACTIVITY = "activity"
LOCATION = "location"
UNKNOWN_FUTURE_VALUE = "unknownFutureValue"
def translate_activity(activity: Activity) -> ProtectionScopeActivities:
mapping = {
Activity.UNKNOWN: ProtectionScopeActivities.NONE,
Activity.UPLOAD_TEXT: ProtectionScopeActivities.UPLOAD_TEXT,
Activity.UPLOAD_FILE: ProtectionScopeActivities.UPLOAD_FILE,
Activity.DOWNLOAD_TEXT: ProtectionScopeActivities.DOWNLOAD_TEXT,
Activity.DOWNLOAD_FILE: ProtectionScopeActivities.DOWNLOAD_FILE,
}
return mapping.get(activity, ProtectionScopeActivities.UNKNOWN_FUTURE_VALUE)
# --------------------------------------------------------------------------------------
# Simple value models
# --------------------------------------------------------------------------------------
class _AliasSerializable(SerializationMixin):
"""Base class adding alias mapping + pydantic-compat helpers.
Each subclass can define ``_ALIASES`` mapping internal attribute name -> external serialized key.
``to_dict`` will emit external keys; ``from_dict`` (via ``__init__`` preprocessing) accepts either form.
Provides light-weight compatibility helpers ``model_dump`` / ``model_validate``
"""
_ALIASES: ClassVar[dict[str, str]] = {}
def __init__(self, **kwargs: Any) -> None:
# Normalize alias keys -> internal names across the entire class hierarchy
# Collect all aliases from parent classes too
all_aliases: dict[str, str] = {}
for cls in type(self).__mro__:
if hasattr(cls, "_ALIASES") and isinstance(cls._ALIASES, dict):
for internal, external in cls._ALIASES.items():
if external not in all_aliases:
all_aliases[external] = internal
# Normalize all aliased keys in kwargs
for external, internal in all_aliases.items():
if external in kwargs and internal not in kwargs:
kwargs[internal] = kwargs.pop(external)
# Set normalized kwargs as attributes
# This will overwrite any None values that child __init__ may have set from default params
for k, v in kwargs.items():
setattr(self, k, v)
# ------------------------------------------------------------------
# Compatibility helpers
# ------------------------------------------------------------------
def model_dump(self, *, by_alias: bool = True, exclude_none: bool = True, **_: Any) -> dict[str, Any]:
# Use self.to_dict() to get alias translation
d = self.to_dict(exclude_none=exclude_none)
# If by_alias=False, translate external -> internal (rarely needed; default True)
if not by_alias and self._ALIASES:
reverse = {v: k for k, v in self._ALIASES.items()}
translated: dict[str, Any] = {}
for k, v in d.items():
translated[reverse.get(k, k)] = v
return translated
return d
def model_dump_json(self, *, by_alias: bool = True, exclude_none: bool = True, **kwargs: Any) -> str:
import json
return json.dumps(self.model_dump(by_alias=by_alias, exclude_none=exclude_none, **kwargs))
@classmethod
def model_validate(cls, value: MutableMapping[str, Any]) -> _AliasSerializable: # type: ignore[name-defined]
return cls(**value)
# ------------------------------------------------------------------
# Override to handle alias emission
# ------------------------------------------------------------------
def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: # type: ignore[override]
base = SerializationMixin.to_dict(self, exclude=exclude, exclude_none=exclude_none)
# For Graph API models, remove the auto-generated 'type' field if it's in DEFAULT_EXCLUDE
if "type" in self.DEFAULT_EXCLUDE:
base.pop("type", None)
# Collect all aliases from class hierarchy
all_aliases: dict[str, str] = {}
for cls in type(self).__mro__:
if hasattr(cls, "_ALIASES") and isinstance(cls._ALIASES, dict):
# Parent aliases first (will be overridden by child if same key)
for internal, external in cls._ALIASES.items():
if internal not in all_aliases:
all_aliases[internal] = external
if not all_aliases:
return base
# Translate internal -> external keys (except 'type' reserved)
translated: dict[str, Any] = {}
for k, v in base.items():
if k == "type":
translated[k] = v
continue
external = all_aliases.get(k, k)
translated[external] = v
return translated
class PolicyLocation(_AliasSerializable):
_ALIASES: ClassVar[dict[str, str]] = {"data_type": "@odata.type"}
DEFAULT_EXCLUDE: ClassVar[set[str]] = {"type"} # Exclude auto-generated type field for Graph API
def __init__(self, data_type: str | None = None, value: str | None = None, **kwargs: Any) -> None:
# Extract aliased values from kwargs
if "@odata.type" in kwargs:
data_type = kwargs["@odata.type"]
# Call parent without explicit params with aliases
super().__init__(**kwargs)
self.data_type = data_type
self.value = value
class ActivityMetadata(_AliasSerializable):
_ALIASES: ClassVar[dict[str, str]] = {"activity": "activity"}
def __init__(self, activity: Activity, **kwargs: Any) -> None:
super().__init__(activity=activity, **kwargs)
self.activity = activity
class OperatingSystemSpecifications(_AliasSerializable):
_ALIASES: ClassVar[dict[str, str]] = {
"operating_system_platform": "operatingSystemPlatform",
"operating_system_version": "operatingSystemVersion",
}
def __init__(
self,
operating_system_platform: str | None = None,
operating_system_version: str | None = None,
**kwargs: Any,
) -> None:
# Extract aliased values from kwargs
if "operatingSystemPlatform" in kwargs:
operating_system_platform = kwargs["operatingSystemPlatform"]
if "operatingSystemVersion" in kwargs:
operating_system_version = kwargs["operatingSystemVersion"]
# Call parent without explicit params with aliases
super().__init__(**kwargs)
self.operating_system_platform = operating_system_platform
self.operating_system_version = operating_system_version
class DeviceMetadata(_AliasSerializable):
_ALIASES: ClassVar[dict[str, str]] = {
"ip_address": "ipAddress",
"operating_system_specifications": "operatingSystemSpecifications",
}
def __init__(
self,
ip_address: str | None = None,
operating_system_specifications: OperatingSystemSpecifications | MutableMapping[str, Any] | None = None,
**kwargs: Any,
) -> None:
# Extract aliased values from kwargs
if "ipAddress" in kwargs:
ip_address = kwargs["ipAddress"]
if "operatingSystemSpecifications" in kwargs:
operating_system_specifications = kwargs["operatingSystemSpecifications"]
# Convert nested objects
if isinstance(operating_system_specifications, MutableMapping):
operating_system_specifications = OperatingSystemSpecifications(**operating_system_specifications)
# Call parent without explicit params with aliases
super().__init__(**kwargs)
self.ip_address = ip_address
self.operating_system_specifications = operating_system_specifications
class IntegratedAppMetadata(_AliasSerializable):
def __init__(self, name: str | None = None, version: str | None = None, **kwargs: Any) -> None:
super().__init__(name=name, version=version, **kwargs)
self.name = name
self.version = version
class ProtectedAppMetadata(_AliasSerializable):
_ALIASES: ClassVar[dict[str, str]] = {"application_location": "applicationLocation"}
def __init__(
self,
name: str | None = None,
version: str | None = None,
application_location: PolicyLocation | MutableMapping[str, Any] | None = None,
**kwargs: Any,
) -> None:
# Extract aliased values from kwargs
if "applicationLocation" in kwargs:
application_location = kwargs["applicationLocation"]
# Convert nested objects
if isinstance(application_location, MutableMapping):
application_location = PolicyLocation(**application_location)
# Call parent without explicit params with aliases
super().__init__(**kwargs)
self.name = name
self.version = version
self.application_location = application_location # type: ignore[assignment]
class DlpActionInfo(_AliasSerializable):
_ALIASES: ClassVar[dict[str, str]] = {"restriction_action": "restrictionAction"}
def __init__(
self,
action: DlpAction | None = None,
restriction_action: RestrictionAction | None = None,
**kwargs: Any,
) -> None:
# Extract aliased values from kwargs
if "restrictionAction" in kwargs:
restriction_action = kwargs["restrictionAction"]
# Call parent without explicit params with aliases
super().__init__(**kwargs)
self.action = action
self.restriction_action = restriction_action
class AccessedResourceDetails(_AliasSerializable):
_ALIASES: ClassVar[dict[str, str]] = {
"label_id": "labelId",
"access_type": "accessType",
"is_cross_prompt_injection_detected": "isCrossPromptInjectionDetected",
}
def __init__(
self,
identifier: str | None = None,
name: str | None = None,
url: str | None = None,
label_id: str | None = None,
access_type: str | None = None,
status: str | None = None,
is_cross_prompt_injection_detected: bool | None = None,
**kwargs: Any,
) -> None:
# Extract aliased values from kwargs
if "labelId" in kwargs:
label_id = kwargs["labelId"]
if "accessType" in kwargs:
access_type = kwargs["accessType"]
if "isCrossPromptInjectionDetected" in kwargs:
is_cross_prompt_injection_detected = kwargs["isCrossPromptInjectionDetected"]
# Call parent without explicit params with aliases
super().__init__(**kwargs)
self.identifier = identifier
self.name = name
self.url = url
self.label_id = label_id
self.access_type = access_type
self.status = status
self.is_cross_prompt_injection_detected = is_cross_prompt_injection_detected
class AiInteractionPlugin(_AliasSerializable):
def __init__(
self,
identifier: str | None = None,
name: str | None = None,
version: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(identifier=identifier, name=name, version=version, **kwargs)
self.identifier = identifier
self.name = name
self.version = version
class AiAgentInfo(_AliasSerializable):
def __init__(
self,
identifier: str | None = None,
name: str | None = None,
version: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(identifier=identifier, name=name, version=version, **kwargs)
self.identifier = identifier
self.name = name
self.version = version
# --------------------------------------------------------------------------------------
# Content models
# --------------------------------------------------------------------------------------
class GraphDataTypeBase(_AliasSerializable):
_ALIASES: ClassVar[dict[str, str]] = {"data_type": "@odata.type"}
# Exclude the auto-generated 'type' field - Graph API uses @odata.type instead
DEFAULT_EXCLUDE: ClassVar[set[str]] = {"type"}
def __init__(self, data_type: str, **kwargs: Any) -> None:
super().__init__(data_type=data_type, **kwargs)
self.data_type = data_type
class ContentBase(GraphDataTypeBase):
pass
class PurviewTextContent(ContentBase):
def __init__(self, data: str, data_type: str = "microsoft.graph.textContent", **kwargs: Any) -> None:
super().__init__(data_type=data_type, **kwargs)
self.data = data
class PurviewBinaryContent(ContentBase):
def __init__(self, data: bytes, data_type: str = "microsoft.graph.binaryContent", **kwargs: Any) -> None:
super().__init__(data_type=data_type, **kwargs)
self.data = data
def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: # type: ignore[override]
import base64
base = super().to_dict(exclude=exclude, exclude_none=exclude_none)
# Ensure bytes encoded as base64 string like pydantic
data_bytes = getattr(self, "data", b"") or b""
base["data"] = base64.b64encode(data_bytes).decode("utf-8")
return base
class ProcessConversationMetadata(GraphDataTypeBase):
_ALIASES: ClassVar[dict[str, str]] = {
"correlation_id": "correlationId",
"sequence_number": "sequenceNumber",
"is_truncated": "isTruncated",
"created_date_time": "createdDateTime",
"modified_date_time": "modifiedDateTime",
"parent_message_id": "parentMessageId",
"accessed_resources": "accessedResources_v2",
}
def __init__(
self,
identifier: str | None = None,
content: PurviewTextContent | PurviewBinaryContent | ContentBase | MutableMapping[str, Any] | None = None,
name: str | None = None,
is_truncated: bool | None = None,
data_type: str = "microsoft.graph.processConversationMetadata", # emitted via base
correlation_id: str | None = None,
sequence_number: int | None = None,
length: int | None = None,
created_date_time: datetime | None = None,
modified_date_time: datetime | None = None,
parent_message_id: str | None = None,
accessed_resources: list[AccessedResourceDetails | MutableMapping[str, Any]] | None = None,
plugins: list[AiInteractionPlugin | MutableMapping[str, Any]] | None = None,
agents: list[AiAgentInfo | MutableMapping[str, Any]] | None = None,
**kwargs: Any,
) -> None:
# Extract aliased values from kwargs
if "correlationId" in kwargs:
correlation_id = kwargs["correlationId"]
if "sequenceNumber" in kwargs:
sequence_number = kwargs["sequenceNumber"]
if "isTruncated" in kwargs:
is_truncated = kwargs["isTruncated"]
if "createdDateTime" in kwargs:
created_date_time = kwargs["createdDateTime"]
if "modifiedDateTime" in kwargs:
modified_date_time = kwargs["modifiedDateTime"]
if "parentMessageId" in kwargs:
parent_message_id = kwargs["parentMessageId"]
if "accessedResources_v2" in kwargs:
accessed_resources = kwargs["accessedResources_v2"]
# Convert nested objects
if isinstance(content, MutableMapping):
# determine by type? fall back to text content
c_type = content.get("@odata.type") or content.get("data_type")
if c_type and "binary" in str(c_type):
content = PurviewBinaryContent(**content) # type: ignore[arg-type]
else:
content = PurviewTextContent(**content) # type: ignore[arg-type]
accessed_list: list[AccessedResourceDetails] | None = None
if accessed_resources:
accessed_list = [
ar if isinstance(ar, AccessedResourceDetails) else AccessedResourceDetails(**ar)
for ar in accessed_resources
]
plugin_list: list[AiInteractionPlugin] | None = None
if plugins:
plugin_list = [p if isinstance(p, AiInteractionPlugin) else AiInteractionPlugin(**p) for p in plugins]
agent_list: list[AiAgentInfo] | None = None
if agents:
agent_list = [a if isinstance(a, AiAgentInfo) else AiAgentInfo(**a) for a in agents]
# Call parent without explicit params with aliases
super().__init__(data_type=data_type, **kwargs)
self.identifier = identifier
self.content = content # type: ignore[assignment]
self.name = name
self.correlation_id = correlation_id
self.sequence_number = sequence_number
self.length = length
self.is_truncated = is_truncated
self.created_date_time = created_date_time
self.modified_date_time = modified_date_time
self.parent_message_id = parent_message_id
self.accessed_resources = accessed_list
self.plugins = plugin_list
self.agents = agent_list
class ContentToProcess(_AliasSerializable):
_ALIASES: ClassVar[dict[str, str]] = {
"content_entries": "contentEntries",
"activity_metadata": "activityMetadata",
"device_metadata": "deviceMetadata",
"integrated_app_metadata": "integratedAppMetadata",
"protected_app_metadata": "protectedAppMetadata",
}
def __init__(
self,
content_entries: list[ProcessConversationMetadata | MutableMapping[str, Any]],
activity_metadata: ActivityMetadata | MutableMapping[str, Any],
device_metadata: DeviceMetadata | MutableMapping[str, Any],
integrated_app_metadata: IntegratedAppMetadata | MutableMapping[str, Any],
protected_app_metadata: ProtectedAppMetadata | MutableMapping[str, Any],
**kwargs: Any,
) -> None:
# Extract aliased values from kwargs
if "contentEntries" in kwargs:
content_entries = kwargs["contentEntries"]
if "activityMetadata" in kwargs:
activity_metadata = kwargs["activityMetadata"]
if "deviceMetadata" in kwargs:
device_metadata = kwargs["deviceMetadata"]
if "integratedAppMetadata" in kwargs:
integrated_app_metadata = kwargs["integratedAppMetadata"]
if "protectedAppMetadata" in kwargs:
protected_app_metadata = kwargs["protectedAppMetadata"]
# Convert nested objects
entries = [
e if isinstance(e, ProcessConversationMetadata) else ProcessConversationMetadata(**e)
for e in content_entries
]
if isinstance(activity_metadata, MutableMapping):
activity_metadata = ActivityMetadata(**activity_metadata)
if isinstance(device_metadata, MutableMapping):
device_metadata = DeviceMetadata(**device_metadata)
if isinstance(integrated_app_metadata, MutableMapping):
integrated_app_metadata = IntegratedAppMetadata(**integrated_app_metadata)
if isinstance(protected_app_metadata, MutableMapping):
protected_app_metadata = ProtectedAppMetadata(**protected_app_metadata)
# Call parent without explicit params with aliases
super().__init__(**kwargs)
self.content_entries = entries
self.activity_metadata = activity_metadata # type: ignore[assignment]
self.device_metadata = device_metadata # type: ignore[assignment]
self.integrated_app_metadata = integrated_app_metadata # type: ignore[assignment]
self.protected_app_metadata = protected_app_metadata # type: ignore[assignment]
# --------------------------------------------------------------------------------------
# Request models
# --------------------------------------------------------------------------------------
class ProcessContentRequest(_AliasSerializable):
_ALIASES: ClassVar[dict[str, str]] = {"content_to_process": "contentToProcess"}
DEFAULT_EXCLUDE: ClassVar[set[str]] = {"user_id", "tenant_id", "correlation_id", "process_inline"}
def __init__(
self,
content_to_process: ContentToProcess | MutableMapping[str, Any],
user_id: str,
tenant_id: str,
correlation_id: str | None = None,
process_inline: bool | None = None,
**kwargs: Any,
) -> None:
# Extract aliased values from kwargs
if "contentToProcess" in kwargs:
content_to_process = kwargs["contentToProcess"]
# Convert nested objects
if isinstance(content_to_process, MutableMapping):
content_to_process = ContentToProcess(**content_to_process)
# Call parent without explicit params with aliases
super().__init__(**kwargs)
self.content_to_process = content_to_process # type: ignore[assignment]
self.user_id = user_id
self.tenant_id = tenant_id
self.correlation_id = correlation_id
self.process_inline = process_inline
class ProtectionScopesRequest(_AliasSerializable):
DEFAULT_EXCLUDE: ClassVar[set[str]] = {"user_id", "tenant_id", "correlation_id", "scope_identifier"}
_ALIASES: ClassVar[dict[str, str]] = {
"pivot_on": "pivotOn",
"device_metadata": "deviceMetadata",
"integrated_app_metadata": "integratedAppMetadata",
}
def __init__(
self,
user_id: str,
tenant_id: str,
activities: ProtectionScopeActivities | str | int | Sequence[str] | None = None,
locations: list[PolicyLocation | MutableMapping[str, Any]] | None = None,
pivot_on: PolicyPivotProperty | None = None,
device_metadata: DeviceMetadata | MutableMapping[str, Any] | None = None,
integrated_app_metadata: IntegratedAppMetadata | MutableMapping[str, Any] | None = None,
correlation_id: str | None = None,
scope_identifier: str | None = None,
**kwargs: Any,
) -> None:
# Extract aliased values from kwargs
if "pivotOn" in kwargs:
pivot_on = kwargs["pivotOn"]
if "deviceMetadata" in kwargs:
device_metadata = kwargs["deviceMetadata"]
if "integratedAppMetadata" in kwargs:
integrated_app_metadata = kwargs["integratedAppMetadata"]
# Deserialize activities flag
if not isinstance(activities, ProtectionScopeActivities) and activities is not None:
activities = deserialize_flag(activities, _PROTECTION_SCOPE_ACTIVITIES_MAP, ProtectionScopeActivities)
# Convert nested objects
if locations:
locations = [loc if isinstance(loc, PolicyLocation) else PolicyLocation(**loc) for loc in locations]
if isinstance(device_metadata, MutableMapping):
device_metadata = DeviceMetadata(**device_metadata)
if isinstance(integrated_app_metadata, MutableMapping):
integrated_app_metadata = IntegratedAppMetadata(**integrated_app_metadata)
# Call parent without explicit params with aliases
super().__init__(**kwargs)
self.user_id = user_id
self.tenant_id = tenant_id
self.activities = activities # type: ignore[assignment]
self.locations = locations
self.pivot_on = pivot_on
self.device_metadata = device_metadata
self.integrated_app_metadata = integrated_app_metadata
self.correlation_id = correlation_id
self.scope_identifier = scope_identifier
def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: # type: ignore[override]
# Get base dict (activities will be missing because Flag isn't JSON-serializable)
base = super().to_dict(exclude=exclude, exclude_none=exclude_none)
# Manually serialize activities flag if present and not excluded
if self.activities is not None or not exclude_none:
if self.activities is not None:
base["activities"] = serialize_flag(self.activities, _PROTECTION_SCOPE_ACTIVITIES_SERIALIZE_ORDER)
elif not exclude_none:
base["activities"] = None
return base
class ContentActivitiesRequest(_AliasSerializable):
_ALIASES: ClassVar[dict[str, str]] = {
"user_id": "userId",
"scope_identifier": "scopeIdentifier",
"content_to_process": "contentMetadata",
}
DEFAULT_EXCLUDE: ClassVar[set[str]] = {"tenant_id", "correlation_id"}
def __init__(
self,
user_id: str,
content_to_process: ContentToProcess | MutableMapping[str, Any],
tenant_id: str,
id: str | None = None,
scope_identifier: str | None = None,
correlation_id: str | None = None,
**kwargs: Any,
) -> None:
# Extract aliased values from kwargs
if "userId" in kwargs:
user_id = kwargs["userId"]
if "scopeIdentifier" in kwargs:
scope_identifier = kwargs["scopeIdentifier"]
if "contentMetadata" in kwargs:
content_to_process = kwargs["contentMetadata"]
# Convert nested objects
if isinstance(content_to_process, MutableMapping):
content_to_process = ContentToProcess(**content_to_process)
# Call parent without explicit params with aliases
super().__init__(**kwargs)
self.id = id or str(uuid4())
self.user_id = user_id
self.content_to_process = content_to_process # type: ignore[assignment]
self.tenant_id = tenant_id
self.scope_identifier = scope_identifier
self.correlation_id = correlation_id
# --------------------------------------------------------------------------------------
# Response models
# --------------------------------------------------------------------------------------
class ErrorDetails(_AliasSerializable):
def __init__(self, code: str | None = None, message: str | None = None, **kwargs: Any) -> None:
super().__init__(code=code, message=message, **kwargs)
self.code = code
self.message = message
class ProcessingError(_AliasSerializable):
def __init__(self, message: str | None = None, **kwargs: Any) -> None:
super().__init__(message=message, **kwargs)
self.message = message
class ProcessContentResponse(_AliasSerializable):
_ALIASES: ClassVar[dict[str, str]] = {
"protection_scope_state": "protectionScopeState",
"policy_actions": "policyActions",
"processing_errors": "processingErrors",
}
id: str | None
protection_scope_state: ProtectionScopeState | None
policy_actions: list[DlpActionInfo] | None
processing_errors: list[ProcessingError] | None
def __init__(
self,
id: str | None = None,
protection_scope_state: ProtectionScopeState | None = None,
policy_actions: list[DlpActionInfo | MutableMapping[str, Any]] | None = None,
processing_errors: list[ProcessingError | MutableMapping[str, Any]] | None = None,
**kwargs: Any,
) -> None:
# Extract aliased values from kwargs
if "protectionScopeState" in kwargs:
protection_scope_state = kwargs["protectionScopeState"]
if "policyActions" in kwargs:
policy_actions = kwargs["policyActions"]
if "processingErrors" in kwargs:
processing_errors = kwargs["processingErrors"]
# Convert to objects
converted_policy_actions: list[DlpActionInfo] | None = None
if policy_actions is not None:
converted_policy_actions = cast(
list[DlpActionInfo],
[p if isinstance(p, DlpActionInfo) else DlpActionInfo(**p) for p in policy_actions],
)
converted_processing_errors: list[ProcessingError] | None = None
if processing_errors is not None:
converted_processing_errors = cast(
list[ProcessingError],
[pe if isinstance(pe, ProcessingError) else ProcessingError(**pe) for pe in processing_errors],
)
# Call parent without explicit params with aliases
super().__init__(**kwargs)
self.id = id
self.protection_scope_state = protection_scope_state
self.policy_actions = converted_policy_actions
self.processing_errors = converted_processing_errors
class PolicyScope(_AliasSerializable):
_ALIASES: ClassVar[dict[str, str]] = {"policy_actions": "policyActions", "execution_mode": "executionMode"}
activities: ProtectionScopeActivities | None
locations: list[PolicyLocation] | None
policy_actions: list[DlpActionInfo] | None
execution_mode: ExecutionMode | None
def __init__(
self,
activities: ProtectionScopeActivities | str | int | Sequence[str] | None = None,
locations: list[PolicyLocation | MutableMapping[str, Any]] | None = None,
policy_actions: list[DlpActionInfo | MutableMapping[str, Any]] | None = None,
execution_mode: ExecutionMode | None = None,
**kwargs: Any,
) -> None:
# Extract aliased values from kwargs
if "policyActions" in kwargs:
policy_actions = kwargs["policyActions"]
if "executionMode" in kwargs:
execution_mode = kwargs["executionMode"]
# Deserialize activities flag
if not isinstance(activities, ProtectionScopeActivities) and activities is not None:
activities = deserialize_flag(activities, _PROTECTION_SCOPE_ACTIVITIES_MAP, ProtectionScopeActivities)
# Convert nested objects
converted_locations: list[PolicyLocation] | None = None
if locations is not None:
converted_locations = cast(
list[PolicyLocation],
[loc if isinstance(loc, PolicyLocation) else PolicyLocation(**loc) for loc in locations],
)
converted_policy_actions: list[DlpActionInfo] | None = None
if policy_actions is not None:
converted_policy_actions = cast(
list[DlpActionInfo],
[p if isinstance(p, DlpActionInfo) else DlpActionInfo(**p) for p in policy_actions],
)
# Call parent without explicit params with aliases
super().__init__(**kwargs)
self.activities = activities # type: ignore[assignment]
self.locations = converted_locations
self.policy_actions = converted_policy_actions
self.execution_mode = execution_mode
def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: # type: ignore[override]
# Get base dict (activities will be missing because Flag isn't JSON-serializable)
base = super().to_dict(exclude=exclude, exclude_none=exclude_none)
# Manually serialize activities flag if present and not excluded
if self.activities is not None or not exclude_none:
if self.activities is not None:
base["activities"] = serialize_flag(self.activities, _PROTECTION_SCOPE_ACTIVITIES_SERIALIZE_ORDER)
elif not exclude_none:
base["activities"] = None
return base
class ProtectionScopesResponse(_AliasSerializable):
_ALIASES: ClassVar[dict[str, str]] = {"scope_identifier": "scopeIdentifier", "scopes": "value"}
scope_identifier: str | None
scopes: list[PolicyScope] | None
def __init__(
self,
scope_identifier: str | None = None,
scopes: list[PolicyScope | MutableMapping[str, Any]] | None = None,
**kwargs: Any,
) -> None:
# Extract aliased values from kwargs before they're normalized by parent
if "scopeIdentifier" in kwargs:
scope_identifier = kwargs["scopeIdentifier"]
if "value" in kwargs:
scopes = kwargs["value"]
converted_scopes: list[PolicyScope] | None = None
if scopes is not None:
converted_scopes = cast(
list[PolicyScope], [s if isinstance(s, PolicyScope) else PolicyScope(**s) for s in scopes]
)
# Don't pass parameters that have aliases - let parent normalize them
super().__init__(**kwargs)
self.scope_identifier = scope_identifier
self.scopes = converted_scopes
class ContentActivitiesResponse(_AliasSerializable):
DEFAULT_EXCLUDE: ClassVar[set[str]] = {"status_code"}
def __init__(
self,
status_code: int | None = None,
error: ErrorDetails | MutableMapping[str, Any] | None = None,
**kwargs: Any,
) -> None:
if isinstance(error, MutableMapping):
error = ErrorDetails(**error)
super().__init__(status_code=status_code, error=error, **kwargs)
self.status_code = status_code
self.error = error # type: ignore[assignment]
__all__ = [
"AccessedResourceDetails",
"Activity",
"ActivityMetadata",
"AiAgentInfo",
"AiInteractionPlugin",
"ContentActivitiesRequest",
"ContentActivitiesResponse",
"ContentBase",
"ContentToProcess",
"DeviceMetadata",
"DlpAction",
"DlpActionInfo",
"ExecutionMode",
"GraphDataTypeBase",
"IntegratedAppMetadata",
"OperatingSystemSpecifications",
"PolicyLocation",
"PolicyPivotProperty",
"PolicyScope",
"ProcessContentRequest",
"ProcessContentResponse",
"ProcessConversationMetadata",
"ProcessingError",
"ProtectedAppMetadata",
"ProtectionScopeActivities",
"ProtectionScopeState",
"ProtectionScopesRequest",
"ProtectionScopesResponse",
"PurviewBinaryContent",
"PurviewTextContent",
"RestrictionAction",
"deserialize_flag",
"serialize_flag",
"translate_activity",
]
@@ -0,0 +1,251 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
import uuid
from collections.abc import Iterable, MutableMapping
from typing import Any
from agent_framework import ChatMessage
from ._client import PurviewClient
from ._models import (
Activity,
ActivityMetadata,
ContentActivitiesRequest,
ContentToProcess,
DeviceMetadata,
DlpAction,
DlpActionInfo,
IntegratedAppMetadata,
OperatingSystemSpecifications,
PolicyLocation,
ProcessContentRequest,
ProcessContentResponse,
ProcessConversationMetadata,
ProcessingError,
ProtectedAppMetadata,
ProtectionScopesRequest,
ProtectionScopesResponse,
PurviewTextContent,
RestrictionAction,
translate_activity,
)
from ._settings import PurviewSettings
def _is_valid_guid(value: str | None) -> bool:
"""Check if a string is a valid GUID/UUID format using uuid module."""
if not value:
return False
try:
uuid.UUID(value)
return True
except (ValueError, AttributeError):
return False
class ScopedContentProcessor:
"""Combine protection scopes, process content, and content activities logic."""
def __init__(self, client: PurviewClient, settings: PurviewSettings):
self._client = client
self._settings = settings
async def process_messages(
self, messages: Iterable[ChatMessage], activity: Activity, user_id: str | None = None
) -> tuple[bool, str | None]:
"""Process messages for policy evaluation.
Args:
messages: The messages to process
activity: The activity type (e.g., UPLOAD_TEXT)
user_id: Optional user_id to use for all messages. If provided, this is the fallback.
Returns:
A tuple of (should_block: bool, resolved_user_id: str | None).
The resolved_user_id can be stored and passed back when processing the response
to ensure the same user context is maintained throughout the request/response cycle.
"""
pc_requests, resolved_user_id = await self._map_messages(messages, activity, user_id)
should_block = False
for req in pc_requests:
resp = await self._process_with_scopes(req)
if resp.policy_actions:
for act in resp.policy_actions:
if act.action == DlpAction.BLOCK_ACCESS or act.restriction_action == RestrictionAction.BLOCK:
should_block = True
break
if should_block:
break
return should_block, resolved_user_id
async def _map_messages(
self, messages: Iterable[ChatMessage], activity: Activity, provided_user_id: str | None = None
) -> tuple[list[ProcessContentRequest], str | None]:
"""Map messages to ProcessContentRequests.
Args:
messages: The messages to map
activity: The activity type
provided_user_id: Optional user_id to use. If provided, this is the fallback.
Returns:
A tuple of (requests, resolved_user_id)
"""
results: list[ProcessContentRequest] = []
token_info = None
if not (self._settings.tenant_id and self._settings.purview_app_location):
token_info = await self._client.get_user_info_from_token(tenant_id=self._settings.tenant_id)
tenant_id = (token_info or {}).get("tenant_id") or self._settings.tenant_id
if not tenant_id or not _is_valid_guid(tenant_id):
raise ValueError("Tenant id required or must be inferable from credential")
resolved_user_id = (token_info or {}).get("user_id")
resolved_author_name = None
if not resolved_user_id:
for m in messages:
if m.additional_properties:
potential_user_id = m.additional_properties.get("user_id")
if _is_valid_guid(potential_user_id):
resolved_user_id = potential_user_id
break
if m.author_name and _is_valid_guid(m.author_name) and not resolved_author_name:
resolved_author_name = m.author_name
if not resolved_user_id and resolved_author_name:
resolved_user_id = resolved_author_name
if not resolved_user_id:
resolved_user_id = provided_user_id if provided_user_id and _is_valid_guid(provided_user_id) else None
# Return empty results if user_id is empty
if not resolved_user_id or not _is_valid_guid(resolved_user_id):
return results, None
for m in messages:
message_id = m.message_id or str(uuid.uuid4())
content = PurviewTextContent(data=m.text or "")
meta = ProcessConversationMetadata(
identifier=message_id,
content=content,
name=f"Agent Framework Message {message_id}",
is_truncated=False,
correlation_id=str(uuid.uuid4()),
)
activity_meta = ActivityMetadata(activity=activity)
if self._settings.purview_app_location:
policy_location = PolicyLocation(
data_type=self._settings.purview_app_location.get_policy_location()["@odata.type"],
value=self._settings.purview_app_location.location_value,
)
elif token_info and token_info.get("client_id"):
policy_location = PolicyLocation(
data_type="microsoft.graph.policyLocationApplication",
value=token_info["client_id"],
)
else:
raise ValueError("App location not provided or inferable")
protected_app = ProtectedAppMetadata(
name=self._settings.app_name,
version="1.0",
application_location=policy_location,
)
integrated_app = IntegratedAppMetadata(name=self._settings.app_name, version="1.0")
device_meta = DeviceMetadata(
operating_system_specifications=OperatingSystemSpecifications(
operating_system_platform="Unknown", operating_system_version="Unknown"
)
)
ctp = ContentToProcess(
content_entries=[meta],
activity_metadata=activity_meta,
device_metadata=device_meta,
integrated_app_metadata=integrated_app,
protected_app_metadata=protected_app,
)
req = ProcessContentRequest(
content_to_process=ctp,
user_id=resolved_user_id, # Use the resolved user_id for all messages
tenant_id=tenant_id,
correlation_id=meta.correlation_id,
process_inline=True if self._settings.process_inline else None,
)
results.append(req)
return results, resolved_user_id
async def _process_with_scopes(self, pc_request: ProcessContentRequest) -> ProcessContentResponse:
app_location = pc_request.content_to_process.protected_app_metadata.application_location
locations: list[PolicyLocation | MutableMapping[str, Any]] = [app_location] if app_location is not None else []
ps_req = ProtectionScopesRequest(
user_id=pc_request.user_id,
tenant_id=pc_request.tenant_id,
activities=translate_activity(pc_request.content_to_process.activity_metadata.activity),
locations=locations,
device_metadata=pc_request.content_to_process.device_metadata,
integrated_app_metadata=pc_request.content_to_process.integrated_app_metadata,
correlation_id=pc_request.correlation_id,
)
ps_resp = await self._client.get_protection_scopes(ps_req)
should_process, dlp_actions = self._check_applicable_scopes(pc_request, ps_resp)
if should_process:
pc_resp = await self._client.process_content(pc_request)
pc_resp.policy_actions = self._combine_policy_actions(pc_resp.policy_actions, dlp_actions)
return pc_resp
ca_req = ContentActivitiesRequest(
user_id=pc_request.user_id,
tenant_id=pc_request.tenant_id,
content_to_process=pc_request.content_to_process,
correlation_id=pc_request.correlation_id,
)
ca_resp = await self._client.send_content_activities(ca_req)
if ca_resp.error:
return ProcessContentResponse(processing_errors=[ProcessingError(message=str(ca_resp.error))])
return ProcessContentResponse()
@staticmethod
def _combine_policy_actions(
existing: list[DlpActionInfo] | None, new_actions: list[DlpActionInfo]
) -> list[DlpActionInfo]:
by_key: dict[str, DlpActionInfo] = {}
for a in existing or []:
if a.action:
by_key[a.action] = a
for a in new_actions:
if a.action:
by_key[a.action] = a
return list(by_key.values())
@staticmethod
def _check_applicable_scopes(
pc_request: ProcessContentRequest, ps_response: ProtectionScopesResponse
) -> tuple[bool, list[DlpActionInfo]]:
req_activity = translate_activity(pc_request.content_to_process.activity_metadata.activity)
location = pc_request.content_to_process.protected_app_metadata.application_location
should_process: bool = False
dlp_actions: list[DlpActionInfo] = []
for scope in ps_response.scopes or []:
# Check if all activities in req_activity are present in scope.activities using bitwise flags.
activity_match = bool(scope.activities and (scope.activities & req_activity) == req_activity)
location_match = False
if location is not None:
for loc in scope.locations or []:
if (
loc.data_type
and location.data_type
and loc.data_type.lower().endswith(location.data_type.split(".")[-1].lower())
and loc.value == location.value
):
location_match = True
break
if activity_match and location_match:
should_process = True
if scope.policy_actions:
dlp_actions.extend(scope.policy_actions)
return should_process, dlp_actions
@@ -0,0 +1,71 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
from enum import Enum
from agent_framework._pydantic import AFBaseSettings
from pydantic import BaseModel, Field
from pydantic_settings import SettingsConfigDict
class PurviewLocationType(str, Enum):
"""The type of location for Purview policy evaluation."""
APPLICATION = "application"
URI = "uri"
DOMAIN = "domain"
class PurviewAppLocation(BaseModel):
"""Identifier representing the app's location for Purview policy evaluation."""
location_type: PurviewLocationType = Field(..., description="The location type.")
location_value: str = Field(..., description="The location value.")
def get_policy_location(self) -> dict[str, str]:
ns = "microsoft.graph"
if self.location_type == PurviewLocationType.APPLICATION:
dt = f"{ns}.policyLocationApplication"
elif self.location_type == PurviewLocationType.URI:
dt = f"{ns}.policyLocationUrl"
elif self.location_type == PurviewLocationType.DOMAIN:
dt = f"{ns}.policyLocationDomain"
else: # pragma: no cover - defensive
raise ValueError("Invalid Purview location type")
return {"@odata.type": dt, "value": self.location_value}
class PurviewSettings(AFBaseSettings):
"""Settings for Purview integration mirroring .NET PurviewSettings.
Attributes:
app_name: Public app name.
tenant_id: Optional tenant id (guid) of the user making the request.
purview_app_location: Optional app location for policy evaluation.
graph_base_uri: Base URI for Microsoft Graph.
blocked_prompt_message: Custom message to return when a prompt is blocked by policy.
blocked_response_message: Custom message to return when a response is blocked by policy.
"""
app_name: str = Field(...)
tenant_id: str | None = Field(default=None)
purview_app_location: PurviewAppLocation | None = Field(default=None)
graph_base_uri: str = Field(default="https://graph.microsoft.com/v1.0/")
process_inline: bool = Field(default=False, description="Process content inline if supported.")
blocked_prompt_message: str = Field(
default="Prompt blocked by policy",
description="Message to return when a prompt is blocked by policy.",
)
blocked_response_message: str = Field(
default="Response blocked by policy",
description="Message to return when a response is blocked by policy.",
)
model_config = SettingsConfigDict(populate_by_name=True, validate_assignment=True)
def get_scopes(self) -> list[str]:
from urllib.parse import urlparse
host = urlparse(self.graph_base_uri).hostname or "graph.microsoft.com"
return [f"https://{host}/.default"]
+87
View File
@@ -0,0 +1,87 @@
[project]
name = "agent-framework-purview"
description = "Microsoft Purview (Graph dataSecurityAndGovernance) integration for Microsoft Agent Framework."
authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}]
readme = "README.md"
requires-python = ">=3.10"
version = "0.1.0b1"
license-files = ["LICENSE"]
urls.homepage = "https://github.com/microsoft/agent-framework"
urls.source = "https://github.com/microsoft/agent-framework/tree/main/python"
urls.release_notes = "https://github.com/microsoft/agent-framework/releases"
urls.issues = "https://github.com/microsoft/agent-framework/issues"
classifiers = [
"License :: OSI Approved :: MIT License",
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Framework :: Pydantic :: 2",
"Typing :: Typed",
]
dependencies = [
"agent-framework-core",
"azure-core>=1.30.0",
"httpx>=0.27.0",
]
[tool.uv]
prerelease = "if-necessary-or-explicit"
environments = [
"sys_platform == 'darwin'",
"sys_platform == 'linux'",
"sys_platform == 'win32'"
]
[tool.uv-dynamic-versioning]
fallback-version = "0.0.0"
[tool.pytest.ini_options]
testpaths = 'tests'
addopts = "-ra -q -r fEX"
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
filterwarnings = []
[tool.ruff]
extend = "../../pyproject.toml"
[tool.coverage.run]
omit = [
"**/__init__.py"
]
[tool.pyright]
extend = "../../pyproject.toml"
exclude = ['tests']
[tool.mypy]
plugins = ['pydantic.mypy']
strict = true
python_version = "3.10"
ignore_missing_imports = true
disallow_untyped_defs = true
no_implicit_optional = true
check_untyped_defs = true
warn_return_any = true
show_error_codes = true
warn_unused_ignores = false
disallow_incomplete_defs = true
disallow_untyped_decorators = true
[tool.bandit]
targets = ["agent_framework_purview"]
exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_purview"
test = "pytest --cov=agent_framework_purview --cov-report=term-missing:skip-covered tests"
[build-system]
requires = ["flit-core >= 3.9,<4.0"]
build-backend = "flit_core.buildapi"
+68
View File
@@ -0,0 +1,68 @@
# Copyright (c) Microsoft. All rights reserved.
"""Shared pytest fixtures for Purview tests."""
import pytest
from agent_framework_purview._models import (
Activity,
ActivityMetadata,
ContentToProcess,
DeviceMetadata,
IntegratedAppMetadata,
OperatingSystemSpecifications,
PolicyLocation,
ProcessContentRequest,
ProcessConversationMetadata,
ProtectedAppMetadata,
PurviewTextContent,
)
@pytest.fixture
def content_to_process_factory():
"""Factory fixture to create ContentToProcess objects with test data."""
def _create_content(text: str = "Test") -> ContentToProcess:
text_content = PurviewTextContent(data=text)
metadata = ProcessConversationMetadata(
identifier="msg-1",
content=text_content,
name="Test",
is_truncated=False,
)
activity_meta = ActivityMetadata(activity=Activity.UPLOAD_TEXT)
device_meta = DeviceMetadata(
operating_system_specifications=OperatingSystemSpecifications(
operating_system_platform="Windows", operating_system_version="10"
)
)
integrated_app = IntegratedAppMetadata(name="App", version="1.0")
location = PolicyLocation(data_type="microsoft.graph.policyLocationApplication", value="app-id")
protected_app = ProtectedAppMetadata(name="Protected", version="1.0", application_location=location)
return ContentToProcess(
content_entries=[metadata],
activity_metadata=activity_meta,
device_metadata=device_meta,
integrated_app_metadata=integrated_app,
protected_app_metadata=protected_app,
)
return _create_content
@pytest.fixture
def process_content_request_factory(content_to_process_factory):
"""Factory fixture to create ProcessContentRequest objects with test data."""
def _create_request(
text: str = "Test", user_id: str = "user-123", tenant_id: str = "tenant-456"
) -> ProcessContentRequest:
content = content_to_process_factory(text)
return ProcessContentRequest(
content_to_process=content,
user_id=user_id,
tenant_id=tenant_id,
)
return _create_request
@@ -0,0 +1,149 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tests for Purview chat middleware."""
from dataclasses import dataclass
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agent_framework import ChatContext, ChatMessage, Role
from azure.core.credentials import AccessToken
from agent_framework_purview import PurviewChatPolicyMiddleware, PurviewSettings
@dataclass
class DummyChatClient:
name: str = "dummy"
class TestPurviewChatPolicyMiddleware:
@pytest.fixture
def mock_credential(self) -> AsyncMock:
credential = AsyncMock()
credential.get_token = AsyncMock(return_value=AccessToken("fake-token", 9999999999))
return credential
@pytest.fixture
def settings(self) -> PurviewSettings:
return PurviewSettings(app_name="Test App", tenant_id="test-tenant")
@pytest.fixture
def middleware(self, mock_credential: AsyncMock, settings: PurviewSettings) -> PurviewChatPolicyMiddleware:
return PurviewChatPolicyMiddleware(mock_credential, settings)
@pytest.fixture
def chat_context(self) -> ChatContext:
chat_client = DummyChatClient()
chat_options = MagicMock()
chat_options.model = "test-model"
return ChatContext(
chat_client=chat_client, messages=[ChatMessage(role=Role.USER, text="Hello")], chat_options=chat_options
)
async def test_initialization(self, middleware: PurviewChatPolicyMiddleware) -> None:
assert middleware._client is not None
assert middleware._processor is not None
async def test_allows_clean_prompt(
self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext
) -> None:
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
next_called = False
async def mock_next(ctx: ChatContext) -> None:
nonlocal next_called
next_called = True
class Result:
def __init__(self):
self.messages = [ChatMessage(role=Role.ASSISTANT, text="Hi there")]
ctx.result = Result()
await middleware.process(chat_context, mock_next)
assert next_called
assert mock_proc.call_count == 2
assert chat_context.result.messages[0].role == Role.ASSISTANT
async def test_blocks_prompt(self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext) -> None:
with patch.object(middleware._processor, "process_messages", return_value=(True, "user-123")):
async def mock_next(ctx: ChatContext) -> None: # should not run
raise AssertionError("next should not be called when prompt blocked")
await middleware.process(chat_context, mock_next)
assert chat_context.terminate
assert chat_context.result
msg = chat_context.result[0] # type: ignore[index]
assert msg.role in ("system", Role.SYSTEM)
assert "blocked" in msg.text.lower()
async def test_blocks_response(self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext) -> None:
call_state = {"count": 0}
async def side_effect(messages, activity, user_id=None):
call_state["count"] += 1
should_block = call_state["count"] == 2
return (should_block, "user-123")
with patch.object(middleware._processor, "process_messages", side_effect=side_effect):
async def mock_next(ctx: ChatContext) -> None:
class Result:
def __init__(self):
self.messages = [ChatMessage(role=Role.ASSISTANT, text="Sensitive output")] # pragma: no cover
ctx.result = Result()
await middleware.process(chat_context, mock_next)
assert call_state["count"] == 2
msgs = getattr(chat_context.result, "messages", None) or chat_context.result
first_msg = msgs[0]
assert first_msg.role in ("system", Role.SYSTEM)
assert "blocked" in first_msg.text.lower()
async def test_streaming_skips_post_check(self, middleware: PurviewChatPolicyMiddleware) -> None:
chat_client = DummyChatClient()
chat_options = MagicMock()
chat_options.model = "test-model"
streaming_context = ChatContext(
chat_client=chat_client,
messages=[ChatMessage(role=Role.USER, text="Hello")],
chat_options=chat_options,
is_streaming=True,
)
with patch.object(middleware._processor, "process_messages", return_value=False) as mock_proc:
async def mock_next(ctx: ChatContext) -> None:
ctx.result = MagicMock()
await middleware.process(streaming_context, mock_next)
assert mock_proc.call_count == 1
async def test_chat_middleware_handles_post_check_exception(
self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext
) -> None:
"""Test that exceptions in post-check are logged but don't affect result."""
call_count = 0
async def mock_process_messages(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return (False, "user-123") # Pre-check succeeds
raise Exception("Post-check error") # Post-check fails
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
async def mock_next(ctx: ChatContext) -> None:
# Create a mock result with messages attribute
result = MagicMock()
result.messages = [ChatMessage(role=Role.ASSISTANT, text="Response")]
ctx.result = result
await middleware.process(chat_context, mock_next)
# Should have been called twice (pre and post)
assert call_count == 2
# Result should still be set
assert chat_context.result is not None
@@ -0,0 +1,238 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tests for Purview client."""
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from azure.core.credentials import AccessToken
from agent_framework_purview import PurviewSettings
from agent_framework_purview._client import PurviewClient
from agent_framework_purview._exceptions import (
PurviewAuthenticationError,
PurviewRateLimitError,
PurviewRequestError,
PurviewServiceError,
)
from agent_framework_purview._models import (
PolicyLocation,
ProcessContentRequest,
ProtectionScopesRequest,
)
class TestPurviewClient:
"""Test PurviewClient functionality."""
@pytest.fixture
def mock_credential(self) -> MagicMock:
"""Create a mock async credential."""
from azure.core.credentials_async import AsyncTokenCredential
credential = MagicMock(spec=AsyncTokenCredential)
mock_token = AccessToken("fake-token", 9999999999)
async def mock_get_token(*args, **kwargs):
return mock_token
credential.get_token = mock_get_token
return credential
@pytest.fixture
def settings(self) -> PurviewSettings:
"""Create test settings."""
return PurviewSettings(app_name="Test App", tenant_id="test-tenant", default_user_id="test-user")
@pytest.fixture
async def client(self, mock_credential: MagicMock, settings: PurviewSettings) -> PurviewClient:
"""Create a PurviewClient with mock credential."""
client = PurviewClient(mock_credential, settings, timeout=10.0)
yield client
await client.close()
async def test_client_initialization(self, mock_credential: MagicMock, settings: PurviewSettings) -> None:
"""Test PurviewClient initialization."""
client = PurviewClient(mock_credential, settings)
assert client._credential == mock_credential
assert client._settings == settings
assert client._graph_uri == "https://graph.microsoft.com/v1.0"
assert client._timeout == 10.0
await client.close()
async def test_get_token_async_credential(self, client: PurviewClient, mock_credential: MagicMock) -> None:
"""Test _get_token with async credential."""
token = await client._get_token(tenant_id="test-tenant")
assert token == "fake-token"
async def test_get_token_sync_credential(self, settings: PurviewSettings) -> None:
"""Test _get_token with sync credential."""
sync_credential = MagicMock()
sync_credential.get_token = MagicMock(return_value=AccessToken("sync-token", 9999999999))
client = PurviewClient(sync_credential, settings)
with patch("asyncio.get_running_loop") as mock_loop:
mock_executor = AsyncMock()
mock_executor.return_value = AccessToken("sync-token", 9999999999)
mock_loop.return_value.run_in_executor = mock_executor
token = await client._get_token(tenant_id="test-tenant")
assert token == "sync-token"
await client.close()
async def test_get_user_info_from_token(self, client: PurviewClient) -> None:
"""Test get_user_info_from_token extracts user info."""
import base64
import json
payload = {"tid": "test-tenant", "oid": "test-user", "idtyp": "user"}
payload_str = json.dumps(payload)
payload_bytes = payload_str.encode("utf-8")
payload_b64 = base64.urlsafe_b64encode(payload_bytes).decode("utf-8").rstrip("=")
fake_token = f"header.{payload_b64}.signature"
with patch.object(client, "_get_token", return_value=fake_token):
user_info = await client.get_user_info_from_token(tenant_id="test-tenant")
assert user_info["tenant_id"] == "test-tenant"
assert user_info["user_id"] == "test-user"
@pytest.mark.parametrize(
"status_code,exception_type",
[
(401, PurviewAuthenticationError),
(403, PurviewAuthenticationError),
(429, PurviewRateLimitError),
(400, PurviewRequestError),
(404, PurviewRequestError),
(500, PurviewServiceError),
(502, PurviewServiceError),
],
)
async def test_post_error_handling(
self, client: PurviewClient, content_to_process_factory, status_code: int, exception_type: type[Exception]
) -> None:
"""Test _post method handles different HTTP errors correctly."""
from agent_framework_purview._models import ProcessContentResponse
content = content_to_process_factory()
request = ProcessContentRequest(
content_to_process=content,
user_id="user-123",
tenant_id="tenant-456",
)
mock_response = MagicMock(spec=httpx.Response)
mock_response.status_code = status_code
mock_response.text = "Error message"
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"Error", request=MagicMock(), response=mock_response
)
with patch.object(client._client, "post", return_value=mock_response), pytest.raises(exception_type):
await client._post(
"https://graph.microsoft.com/v1.0/test",
request,
ProcessContentResponse,
"fake-token",
)
async def test_process_content_success(
self, client: PurviewClient, content_to_process_factory, mock_credential: MagicMock
) -> None:
"""Test process_content method success path."""
content = content_to_process_factory("Test message")
request = ProcessContentRequest(
content_to_process=content,
user_id="user-123",
tenant_id="tenant-456",
)
mock_response = MagicMock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = {"id": "response-123", "protectionScopeState": "notModified"}
with patch.object(client._client, "post", return_value=mock_response):
response = await client.process_content(request)
assert response.id == "response-123"
assert response.protection_scope_state == "notModified"
async def test_get_protection_scopes_success(self, client: PurviewClient) -> None:
"""Test get_protection_scopes method success path."""
location = PolicyLocation(**{"@odata.type": "microsoft.graph.policyLocationApplication", "value": "app-id"})
request = ProtectionScopesRequest(
user_id="user-123", tenant_id="tenant-456", locations=[location], correlation_id="corr-789"
)
mock_response = MagicMock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = {"scopeIdentifier": "scope-123", "value": []}
with patch.object(client._client, "post", return_value=mock_response):
response = await client.get_protection_scopes(request)
assert response.scope_identifier == "scope-123"
assert response.scopes == []
async def test_client_close(self, mock_credential: AsyncMock, settings: PurviewSettings) -> None:
"""Test client properly closes HTTP client."""
client = PurviewClient(mock_credential, settings)
with patch.object(client._client, "aclose", new_callable=AsyncMock) as mock_close:
await client.close()
mock_close.assert_called_once()
async def test_invalid_jwt_token_format(self, client: PurviewClient) -> None:
"""Test that invalid JWT token format raises ValueError."""
with pytest.raises(ValueError, match="Invalid JWT token format"):
client._extract_token_info("invalid-token-without-dots")
async def test_rate_limit_error(self, client: PurviewClient) -> None:
"""Test that 429 status code raises PurviewRateLimitError."""
request = ProcessContentRequest(
user_id="test-user",
tenant_id="test-tenant",
content_to_process=[],
correlation_id="test-correlation-id",
)
with (
patch.object(client, "_get_token", return_value="fake-token"),
patch.object(
client._client,
"post",
return_value=httpx.Response(429, text="Rate limited", request=httpx.Request("POST", "http://test")),
),
pytest.raises(PurviewRateLimitError, match="Rate limited"),
):
await client.process_content(request)
async def test_generic_request_error(self, client: PurviewClient) -> None:
"""Test that non-200/201/202 status codes raise PurviewRequestError."""
request = ProcessContentRequest(
user_id="test-user",
tenant_id="test-tenant",
content_to_process=[],
correlation_id="test-correlation-id",
)
with (
patch.object(client, "_get_token", return_value="fake-token"),
patch.object(
client._client,
"post",
return_value=httpx.Response(
500, text="Internal server error", request=httpx.Request("POST", "http://test")
),
),
pytest.raises(PurviewRequestError, match="Purview request failed"),
):
await client.process_content(request)
@@ -0,0 +1,38 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tests for Purview exceptions."""
from agent_framework_purview import (
PurviewAuthenticationError,
PurviewRateLimitError,
PurviewRequestError,
PurviewServiceError,
)
class TestPurviewExceptions:
"""Test custom Purview exception classes."""
def test_purview_service_error(self) -> None:
"""Test PurviewServiceError base exception."""
error = PurviewServiceError("Service error occurred")
assert str(error) == "Service error occurred"
assert isinstance(error, Exception)
def test_purview_authentication_error(self) -> None:
"""Test PurviewAuthenticationError exception."""
error = PurviewAuthenticationError("Authentication failed")
assert str(error) == "Authentication failed"
assert isinstance(error, PurviewServiceError)
def test_purview_rate_limit_error(self) -> None:
"""Test PurviewRateLimitError exception."""
error = PurviewRateLimitError("Rate limit exceeded")
assert str(error) == "Rate limit exceeded"
assert isinstance(error, PurviewServiceError)
def test_purview_request_error(self) -> None:
"""Test PurviewRequestError exception."""
error = PurviewRequestError("Request failed")
assert str(error) == "Request failed"
assert isinstance(error, PurviewServiceError)
@@ -0,0 +1,201 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tests for Purview middleware."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agent_framework import AgentRunContext, AgentRunResponse, ChatMessage, Role
from azure.core.credentials import AccessToken
from agent_framework_purview import PurviewPolicyMiddleware, PurviewSettings
class TestPurviewPolicyMiddleware:
"""Test PurviewPolicyMiddleware functionality."""
@pytest.fixture
def mock_credential(self) -> AsyncMock:
"""Create a mock async credential."""
credential = AsyncMock()
credential.get_token = AsyncMock(return_value=AccessToken("fake-token", 9999999999))
return credential
@pytest.fixture
def settings(self) -> PurviewSettings:
"""Create test settings."""
return PurviewSettings(app_name="Test App", tenant_id="test-tenant")
@pytest.fixture
def middleware(self, mock_credential: AsyncMock, settings: PurviewSettings) -> PurviewPolicyMiddleware:
"""Create PurviewPolicyMiddleware instance."""
return PurviewPolicyMiddleware(mock_credential, settings)
@pytest.fixture
def mock_agent(self) -> MagicMock:
"""Create a mock agent."""
agent = MagicMock()
agent.name = "test-agent"
return agent
def test_middleware_initialization(self, mock_credential: AsyncMock, settings: PurviewSettings) -> None:
"""Test PurviewPolicyMiddleware initialization."""
middleware = PurviewPolicyMiddleware(mock_credential, settings)
assert middleware._client is not None
assert middleware._processor is not None
async def test_middleware_allows_clean_prompt(
self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
) -> None:
"""Test middleware allows prompt that passes policy check."""
context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Hello, how are you?")])
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")):
next_called = False
async def mock_next(ctx: AgentRunContext) -> None:
nonlocal next_called
next_called = True
ctx.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="I'm good, thanks!")])
await middleware.process(context, mock_next)
assert next_called
assert context.result is not None
assert not context.terminate
async def test_middleware_blocks_prompt_on_policy_violation(
self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
) -> None:
"""Test middleware blocks prompt that violates policy."""
context = AgentRunContext(
agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Sensitive information")]
)
with patch.object(middleware._processor, "process_messages", return_value=(True, "user-123")):
next_called = False
async def mock_next(ctx: AgentRunContext) -> None:
nonlocal next_called
next_called = True
await middleware.process(context, mock_next)
assert not next_called
assert context.result is not None
assert context.terminate
assert len(context.result.messages) == 1
assert context.result.messages[0].role == Role.SYSTEM
assert "blocked by policy" in context.result.messages[0].text.lower()
async def test_middleware_checks_response(self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock) -> None:
"""Test middleware checks agent response for policy violations."""
context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Hello")])
call_count = 0
async def mock_process_messages(messages, activity, user_id=None):
nonlocal call_count
call_count += 1
should_block = call_count != 1
return (should_block, "user-123")
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
async def mock_next(ctx: AgentRunContext) -> None:
ctx.result = AgentRunResponse(
messages=[ChatMessage(role=Role.ASSISTANT, text="Here's some sensitive information")]
)
await middleware.process(context, mock_next)
assert call_count == 2
assert context.result is not None
assert len(context.result.messages) == 1
assert context.result.messages[0].role == Role.SYSTEM
assert "blocked by policy" in context.result.messages[0].text.lower()
async def test_middleware_handles_result_without_messages(
self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
) -> None:
"""Test middleware handles result that doesn't have messages attribute."""
context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Hello")])
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")):
async def mock_next(ctx: AgentRunContext) -> None:
ctx.result = "Some non-standard result"
await middleware.process(context, mock_next)
assert context.result == "Some non-standard result"
async def test_middleware_processor_receives_correct_activity(
self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
) -> None:
"""Test middleware passes correct activity type to processor."""
from agent_framework_purview._models import Activity
context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Test")])
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_process:
async def mock_next(ctx: AgentRunContext) -> None:
ctx.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Response")])
await middleware.process(context, mock_next)
assert mock_process.call_count == 2
for call in mock_process.call_args_list:
assert call[0][1] == Activity.UPLOAD_TEXT
async def test_middleware_handles_pre_check_exception(
self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
) -> None:
"""Test that exceptions in pre-check are logged but don't stop processing."""
context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Test")])
with patch.object(
middleware._processor, "process_messages", side_effect=Exception("Pre-check error")
) as mock_process:
async def mock_next(ctx: AgentRunContext) -> None:
ctx.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Response")])
await middleware.process(context, mock_next)
# Should have been called twice (pre-check raises, then post-check also raises)
assert mock_process.call_count == 2
# Context should not be terminated
assert not context.terminate
# Result should be set by mock_next
assert context.result is not None
async def test_middleware_handles_post_check_exception(
self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
) -> None:
"""Test that exceptions in post-check are logged but don't affect result."""
context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Test")])
call_count = 0
async def mock_process_messages(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return (False, "user-123") # Pre-check succeeds
raise Exception("Post-check error") # Post-check fails
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
async def mock_next(ctx: AgentRunContext) -> None:
ctx.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Response")])
await middleware.process(context, mock_next)
# Should have been called twice (pre and post)
assert call_count == 2
# Result should still be set
assert context.result is not None
assert hasattr(context.result, "messages")
@@ -0,0 +1,246 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tests for Purview models and serialization."""
from agent_framework_purview._models import (
Activity,
ActivityMetadata,
ContentToProcess,
DeviceMetadata,
IntegratedAppMetadata,
OperatingSystemSpecifications,
PolicyLocation,
ProcessContentRequest,
ProcessContentResponse,
ProcessConversationMetadata,
ProtectedAppMetadata,
ProtectionScopeActivities,
ProtectionScopesRequest,
ProtectionScopesResponse,
PurviewTextContent,
deserialize_flag,
serialize_flag,
)
class TestFlagOperations:
"""Test flag serialization and deserialization operations."""
def test_protection_scope_activities_flag_combination(self) -> None:
"""Test combining flags."""
combined = ProtectionScopeActivities.UPLOAD_TEXT | ProtectionScopeActivities.UPLOAD_FILE
assert combined.value == 3
assert ProtectionScopeActivities.UPLOAD_TEXT in combined
assert ProtectionScopeActivities.UPLOAD_FILE in combined
def test_deserialize_flag_with_string(self) -> None:
"""Test deserializing flag from comma-separated string."""
mapping = {
"uploadText": ProtectionScopeActivities.UPLOAD_TEXT,
"uploadFile": ProtectionScopeActivities.UPLOAD_FILE,
}
result = deserialize_flag("uploadText,uploadFile", mapping, ProtectionScopeActivities)
assert result is not None
assert ProtectionScopeActivities.UPLOAD_TEXT in result
assert ProtectionScopeActivities.UPLOAD_FILE in result
def test_deserialize_flag_with_none(self) -> None:
"""Test deserializing None returns None."""
mapping = {"uploadText": ProtectionScopeActivities.UPLOAD_TEXT}
result = deserialize_flag(None, mapping, ProtectionScopeActivities)
assert result is None
def test_serialize_flag_with_none(self) -> None:
"""Test serializing None returns None."""
result = serialize_flag(None, [])
assert result is None
def test_serialize_flag_with_values(self) -> None:
"""Test serializing flag with values."""
flag = ProtectionScopeActivities.UPLOAD_TEXT | ProtectionScopeActivities.UPLOAD_FILE
ordered = [
("uploadText", ProtectionScopeActivities.UPLOAD_TEXT),
("uploadFile", ProtectionScopeActivities.UPLOAD_FILE),
]
result = serialize_flag(flag, ordered)
assert result == "uploadText,uploadFile"
class TestComplexModels:
"""Test complex models with nested structures."""
def test_content_to_process_with_nested_structures(self) -> None:
"""Test ContentToProcess with all nested structures."""
text_content = PurviewTextContent(data="Test")
metadata = ProcessConversationMetadata(
identifier="msg-1",
content=text_content,
name="Test",
is_truncated=False,
)
activity_meta = ActivityMetadata(activity=Activity.UPLOAD_TEXT)
device_meta = DeviceMetadata(
operating_system_specifications=OperatingSystemSpecifications(
operating_system_platform="Windows", operating_system_version="10"
)
)
integrated_app = IntegratedAppMetadata(name="App", version="1.0")
location = PolicyLocation(data_type="microsoft.graph.policyLocationApplication", value="app-id")
protected_app = ProtectedAppMetadata(name="Protected", version="1.0", application_location=location)
content = ContentToProcess(
content_entries=[metadata],
activity_metadata=activity_meta,
device_metadata=device_meta,
integrated_app_metadata=integrated_app,
protected_app_metadata=protected_app,
)
assert len(content.content_entries) == 1
assert content.activity_metadata.activity == Activity.UPLOAD_TEXT
assert content.device_metadata.operating_system_specifications.operating_system_platform == "Windows"
assert content.integrated_app_metadata.name == "App"
assert content.protected_app_metadata.name == "Protected"
class TestRequestResponseSerialization:
"""Test request/response serialization with aliases."""
def test_protection_scopes_request_serialization(self) -> None:
"""Test ProtectionScopesRequest serializes activities correctly."""
location = PolicyLocation(data_type="microsoft.graph.policyLocationApplication", value="app-id")
request = ProtectionScopesRequest(
user_id="user-123",
tenant_id="tenant-456",
activities=ProtectionScopeActivities.UPLOAD_TEXT | ProtectionScopeActivities.UPLOAD_FILE,
locations=[location],
)
dumped = request.model_dump(by_alias=True, exclude_none=True, mode="json")
assert "activities" in dumped
assert isinstance(dumped["activities"], str)
assert "uploadText" in dumped["activities"]
class TestModelDeserialization:
"""Test model deserialization from API responses."""
def test_protection_scopes_response_deserialization(self) -> None:
"""Test ProtectionScopesResponse deserializes 'value' to 'scopes'."""
api_data = {
"scopeIdentifier": "scope-123",
"value": [
{
"activities": "uploadText,downloadText",
"locations": [{"@odata.type": "location.type", "value": "/path"}],
"policyActions": [{"action": "warn", "restrictionAction": "blockAccess"}],
"executionMode": "evaluateInline",
}
],
}
response = ProtectionScopesResponse.model_validate(api_data)
assert response.scope_identifier == "scope-123"
assert response.scopes is not None
assert len(response.scopes) == 1
assert response.scopes[0].execution_mode == "evaluateInline"
def test_process_content_response_deserialization(self) -> None:
"""Test ProcessContentResponse deserializes aliased fields correctly."""
api_data = {
"id": "response-123",
"protectionScopeState": "blocked",
"policyActions": [{"action": "block", "restrictionAction": "blockAccess"}],
}
response = ProcessContentResponse.model_validate(api_data)
assert response.id == "response-123"
assert response.protection_scope_state == "blocked"
assert len(response.policy_actions) == 1
def test_content_serialization_uses_aliases(self) -> None:
"""Test ContentToProcess serializes with camelCase aliases."""
text_content = PurviewTextContent(data="Test")
metadata = ProcessConversationMetadata(
identifier="msg-1",
content=text_content,
name="Test",
is_truncated=False,
)
activity_meta = ActivityMetadata(activity=Activity.UPLOAD_TEXT)
device_meta = DeviceMetadata(
operating_system_specifications=OperatingSystemSpecifications(
operating_system_platform="Windows", operating_system_version="10"
)
)
integrated_app = IntegratedAppMetadata(name="App", version="1.0")
location = PolicyLocation(data_type="microsoft.graph.policyLocationApplication", value="app-id")
protected_app = ProtectedAppMetadata(name="Protected", version="1.0", application_location=location)
content = ContentToProcess(
content_entries=[metadata],
activity_metadata=activity_meta,
device_metadata=device_meta,
integrated_app_metadata=integrated_app,
protected_app_metadata=protected_app,
)
dumped = content.model_dump(by_alias=True, exclude_none=True, mode="json")
assert "contentEntries" in dumped
assert "activityMetadata" in dumped
assert "deviceMetadata" in dumped
assert "integratedAppMetadata" in dumped
assert "protectedAppMetadata" in dumped
def test_process_content_request_excludes_private_fields(self) -> None:
"""Test ProcessContentRequest excludes private fields when serializing."""
text_content = PurviewTextContent(data="Test")
metadata = ProcessConversationMetadata(
identifier="msg-1",
content=text_content,
name="Test",
is_truncated=False,
)
activity_meta = ActivityMetadata(activity=Activity.UPLOAD_TEXT)
device_meta = DeviceMetadata(
operating_system_specifications=OperatingSystemSpecifications(
operating_system_platform="Windows", operating_system_version="10"
)
)
integrated_app = IntegratedAppMetadata(name="App", version="1.0")
location = PolicyLocation(data_type="microsoft.graph.policyLocationApplication", value="app-id")
protected_app = ProtectedAppMetadata(name="Protected", version="1.0", application_location=location)
content = ContentToProcess(
content_entries=[metadata],
activity_metadata=activity_meta,
device_metadata=device_meta,
integrated_app_metadata=integrated_app,
protected_app_metadata=protected_app,
)
request = ProcessContentRequest(
content_to_process=content,
user_id="user-123",
tenant_id="tenant-456",
correlation_id="corr-789",
)
dumped = request.model_dump(by_alias=True, exclude_none=True, mode="json")
# Check that excluded fields are not present
assert "user_id" not in dumped
assert "tenant_id" not in dumped
assert "correlation_id" not in dumped
# Check that content is present
assert "contentToProcess" in dumped
@@ -0,0 +1,369 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tests for Purview processor."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agent_framework import ChatMessage, Role
from agent_framework_purview import PurviewAppLocation, PurviewLocationType, PurviewSettings
from agent_framework_purview._models import (
Activity,
DlpAction,
DlpActionInfo,
ProcessContentResponse,
RestrictionAction,
)
from agent_framework_purview._processor import ScopedContentProcessor, _is_valid_guid
class TestGuidValidation:
"""Test GUID validation helper."""
def test_valid_guid(self) -> None:
"""Test _is_valid_guid with valid GUIDs."""
assert _is_valid_guid("12345678-1234-1234-1234-123456789012")
assert _is_valid_guid("a1b2c3d4-e5f6-4a5b-8c9d-0e1f2a3b4c5d")
def test_invalid_guid(self) -> None:
"""Test _is_valid_guid with invalid GUIDs."""
assert not _is_valid_guid("not-a-guid")
assert not _is_valid_guid("")
assert not _is_valid_guid(None)
class TestScopedContentProcessor:
"""Test ScopedContentProcessor functionality."""
@pytest.fixture
def mock_client(self) -> AsyncMock:
"""Create a mock Purview client."""
client = AsyncMock()
client.get_user_info_from_token = AsyncMock(
return_value={
"tenant_id": "12345678-1234-1234-1234-123456789012",
"user_id": "12345678-1234-1234-1234-123456789012",
"client_id": "12345678-1234-1234-1234-123456789012",
}
)
return client
@pytest.fixture
def settings_with_defaults(self) -> PurviewSettings:
"""Create settings with default values."""
app_location = PurviewAppLocation(
location_type=PurviewLocationType.APPLICATION, location_value="12345678-1234-1234-1234-123456789012"
)
return PurviewSettings(
app_name="Test App",
tenant_id="12345678-1234-1234-1234-123456789012",
purview_app_location=app_location,
)
@pytest.fixture
def settings_without_defaults(self) -> PurviewSettings:
"""Create settings without default values (requiring token info)."""
return PurviewSettings(app_name="Test App")
@pytest.fixture
def processor(self, mock_client: AsyncMock, settings_with_defaults: PurviewSettings) -> ScopedContentProcessor:
"""Create a ScopedContentProcessor with mock client."""
return ScopedContentProcessor(mock_client, settings_with_defaults)
async def test_processor_initialization(
self, mock_client: AsyncMock, settings_with_defaults: PurviewSettings
) -> None:
"""Test ScopedContentProcessor initialization."""
processor = ScopedContentProcessor(mock_client, settings_with_defaults)
assert processor._client == mock_client
assert processor._settings == settings_with_defaults
async def test_process_messages_with_defaults(self, processor: ScopedContentProcessor) -> None:
"""Test process_messages with settings that have defaults."""
messages = [
ChatMessage(role=Role.USER, text="Hello"),
ChatMessage(role=Role.ASSISTANT, text="Hi there"),
]
with patch.object(processor, "_map_messages", return_value=([], None)) as mock_map:
should_block, user_id = await processor.process_messages(messages, Activity.UPLOAD_TEXT)
assert should_block is False
assert user_id is None
mock_map.assert_called_once_with(messages, Activity.UPLOAD_TEXT, None)
async def test_process_messages_blocks_content(
self, processor: ScopedContentProcessor, process_content_request_factory
) -> None:
"""Test process_messages returns True when content should be blocked."""
messages = [ChatMessage(role=Role.USER, text="Sensitive content")]
mock_request = process_content_request_factory("Sensitive content")
mock_response = ProcessContentResponse(**{
"policyActions": [DlpActionInfo(action=DlpAction.BLOCK_ACCESS, restrictionAction=RestrictionAction.BLOCK)]
})
with (
patch.object(processor, "_map_messages", return_value=([mock_request], "user-123")),
patch.object(processor, "_process_with_scopes", return_value=mock_response),
):
should_block, user_id = await processor.process_messages(messages, Activity.UPLOAD_TEXT)
assert should_block is True
assert user_id == "user-123"
async def test_map_messages_creates_requests(
self, processor: ScopedContentProcessor, mock_client: AsyncMock
) -> None:
"""Test _map_messages creates ProcessContentRequest objects."""
messages = [
ChatMessage(
role=Role.USER,
text="Test message",
message_id="msg-123",
author_name="12345678-1234-1234-1234-123456789012",
),
]
requests, user_id = await processor._map_messages(messages, Activity.UPLOAD_TEXT)
assert len(requests) == 1
assert requests[0].user_id == "12345678-1234-1234-1234-123456789012"
assert requests[0].tenant_id == "12345678-1234-1234-1234-123456789012"
assert user_id == "12345678-1234-1234-1234-123456789012"
async def test_map_messages_without_defaults_gets_token_info(self, mock_client: AsyncMock) -> None:
"""Test _map_messages gets token info when settings lack some defaults."""
settings = PurviewSettings(app_name="Test App", tenant_id="12345678-1234-1234-1234-123456789012")
processor = ScopedContentProcessor(mock_client, settings)
messages = [ChatMessage(role=Role.USER, text="Test", message_id="msg-123")]
requests, user_id = await processor._map_messages(messages, Activity.UPLOAD_TEXT)
mock_client.get_user_info_from_token.assert_called_once()
assert len(requests) == 1
assert user_id is not None
async def test_map_messages_raises_on_missing_tenant_id(self, mock_client: AsyncMock) -> None:
"""Test _map_messages raises ValueError when tenant_id cannot be determined."""
settings = PurviewSettings(app_name="Test App") # No tenant_id
processor = ScopedContentProcessor(mock_client, settings)
mock_client.get_user_info_from_token = AsyncMock(
return_value={"user_id": "test-user", "client_id": "test-client"}
)
messages = [ChatMessage(role=Role.USER, text="Test", message_id="msg-123")]
with pytest.raises(ValueError, match="Tenant id required"):
await processor._map_messages(messages, Activity.UPLOAD_TEXT)
async def test_check_applicable_scopes_no_scopes(
self, processor: ScopedContentProcessor, process_content_request_factory
) -> None:
"""Test _check_applicable_scopes when no scopes are returned."""
from agent_framework_purview._models import ProtectionScopesResponse
request = process_content_request_factory()
response = ProtectionScopesResponse(**{"value": None})
should_process, actions = processor._check_applicable_scopes(request, response)
assert should_process is False
assert actions == []
async def test_check_applicable_scopes_with_block_action(
self, processor: ScopedContentProcessor, process_content_request_factory
) -> None:
"""Test _check_applicable_scopes identifies block actions."""
from agent_framework_purview._models import (
PolicyLocation,
PolicyScope,
ProtectionScopeActivities,
ProtectionScopesResponse,
)
request = process_content_request_factory()
block_action = DlpActionInfo(action=DlpAction.BLOCK_ACCESS, restrictionAction=RestrictionAction.BLOCK)
scope_location = PolicyLocation(**{
"@odata.type": "microsoft.graph.policyLocationApplication",
"value": "app-id",
})
scope = PolicyScope(**{
"policyActions": [block_action],
"activities": ProtectionScopeActivities.UPLOAD_TEXT,
"locations": [scope_location],
})
response = ProtectionScopesResponse(**{"value": [scope]})
should_process, actions = processor._check_applicable_scopes(request, response)
assert should_process is True
assert len(actions) == 1
assert actions[0].action == DlpAction.BLOCK_ACCESS
async def test_combine_policy_actions(self, processor: ScopedContentProcessor) -> None:
"""Test _combine_policy_actions merges action lists."""
action1 = DlpActionInfo(action=DlpAction.BLOCK_ACCESS, restrictionAction=RestrictionAction.BLOCK)
action2 = DlpActionInfo(action=DlpAction.OTHER, restrictionAction=RestrictionAction.OTHER)
combined = processor._combine_policy_actions([action1], [action2])
assert len(combined) == 2
assert action1 in combined
assert action2 in combined
async def test_process_with_scopes_calls_client_methods(
self, processor: ScopedContentProcessor, mock_client: AsyncMock, process_content_request_factory
) -> None:
"""Test _process_with_scopes calls get_protection_scopes and process_content."""
from agent_framework_purview._models import (
ContentActivitiesResponse,
ProtectionScopesResponse,
)
request = process_content_request_factory()
mock_client.get_protection_scopes = AsyncMock(return_value=ProtectionScopesResponse(**{"value": []}))
mock_client.process_content = AsyncMock(
return_value=ProcessContentResponse(**{"id": "response-123", "protectionScopeState": "notModified"})
)
mock_client.send_content_activities = AsyncMock(return_value=ContentActivitiesResponse(**{"error": None}))
response = await processor._process_with_scopes(request)
mock_client.get_protection_scopes.assert_called_once()
mock_client.process_content.assert_not_called()
mock_client.send_content_activities.assert_called_once()
assert response.id is None
async def test_map_messages_with_user_id_in_additional_properties(self, mock_client: AsyncMock) -> None:
"""Test user_id extraction from message additional_properties."""
settings = PurviewSettings(
app_name="Test App",
tenant_id="12345678-1234-1234-1234-123456789012",
purview_app_location=PurviewAppLocation(
location_type=PurviewLocationType.APPLICATION, location_value="app-id"
),
)
processor = ScopedContentProcessor(mock_client, settings)
messages = [
ChatMessage(
role=Role.USER,
text="Test message",
additional_properties={"user_id": "22345678-1234-1234-1234-123456789012"},
),
]
requests, user_id = await processor._map_messages(messages, Activity.UPLOAD_TEXT)
assert len(requests) == 1
assert user_id == "22345678-1234-1234-1234-123456789012"
assert requests[0].user_id == "22345678-1234-1234-1234-123456789012"
async def test_map_messages_with_provided_user_id_fallback(self, mock_client: AsyncMock) -> None:
"""Test using provided_user_id when no other source is available."""
settings = PurviewSettings(
app_name="Test App",
tenant_id="12345678-1234-1234-1234-123456789012",
purview_app_location=PurviewAppLocation(
location_type=PurviewLocationType.APPLICATION, location_value="app-id"
),
)
processor = ScopedContentProcessor(mock_client, settings)
messages = [ChatMessage(role=Role.USER, text="Test message")]
requests, user_id = await processor._map_messages(
messages, Activity.UPLOAD_TEXT, provided_user_id="32345678-1234-1234-1234-123456789012"
)
assert len(requests) == 1
assert user_id == "32345678-1234-1234-1234-123456789012"
assert requests[0].user_id == "32345678-1234-1234-1234-123456789012"
async def test_map_messages_returns_empty_when_no_user_id(self, mock_client: AsyncMock) -> None:
"""Test that empty results are returned when user_id cannot be resolved."""
settings = PurviewSettings(
app_name="Test App",
tenant_id="12345678-1234-1234-1234-123456789012",
purview_app_location=PurviewAppLocation(
location_type=PurviewLocationType.APPLICATION, location_value="app-id"
),
)
processor = ScopedContentProcessor(mock_client, settings)
messages = [ChatMessage(role=Role.USER, text="Test message")]
requests, user_id = await processor._map_messages(messages, Activity.UPLOAD_TEXT)
assert len(requests) == 0
assert user_id is None
async def test_process_content_sends_activities_when_not_applicable(
self, mock_client: AsyncMock, process_content_request_factory
) -> None:
"""Test that content activities are sent when scopes don't apply."""
settings = PurviewSettings(
app_name="Test App",
tenant_id="12345678-1234-1234-1234-123456789012",
purview_app_location=PurviewAppLocation(
location_type=PurviewLocationType.APPLICATION, location_value="app-id"
),
)
processor = ScopedContentProcessor(mock_client, settings)
pc_request = process_content_request_factory()
# Mock get_protection_scopes to return no applicable scopes
mock_ps_response = MagicMock()
mock_ps_response.scopes = []
mock_client.get_protection_scopes.return_value = mock_ps_response
# Mock send_content_activities to return success
mock_ca_response = MagicMock()
mock_ca_response.error = None
mock_client.send_content_activities.return_value = mock_ca_response
response = await processor._process_with_scopes(pc_request)
mock_client.get_protection_scopes.assert_called_once()
mock_client.process_content.assert_not_called()
mock_client.send_content_activities.assert_called_once()
# When content activities succeed, response has no errors (processing_errors can be None or empty)
assert response.processing_errors is None or response.processing_errors == []
async def test_process_content_handles_activities_error(
self, mock_client: AsyncMock, process_content_request_factory
) -> None:
"""Test error handling when content activities fail."""
settings = PurviewSettings(
app_name="Test App",
tenant_id="12345678-1234-1234-1234-123456789012",
purview_app_location=PurviewAppLocation(
location_type=PurviewLocationType.APPLICATION, location_value="app-id"
),
)
processor = ScopedContentProcessor(mock_client, settings)
pc_request = process_content_request_factory()
# Mock get_protection_scopes to return no applicable scopes
mock_ps_response = MagicMock()
mock_ps_response.scopes = []
mock_client.get_protection_scopes.return_value = mock_ps_response
# Mock send_content_activities to return error
mock_ca_response = MagicMock()
mock_ca_response.error = "Test error message"
mock_client.send_content_activities.return_value = mock_ca_response
response = await processor._process_with_scopes(pc_request)
assert len(response.processing_errors) == 1
assert response.processing_errors[0].message == "Test error message"
@@ -0,0 +1,85 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tests for Purview settings."""
import pytest
from agent_framework_purview import PurviewAppLocation, PurviewLocationType, PurviewSettings
class TestPurviewSettings:
"""Test PurviewSettings configuration."""
def test_settings_defaults(self) -> None:
"""Test PurviewSettings with default values."""
settings = PurviewSettings(app_name="Test App")
assert settings.app_name == "Test App"
assert settings.graph_base_uri == "https://graph.microsoft.com/v1.0/"
assert settings.tenant_id is None
assert settings.purview_app_location is None
assert settings.process_inline is False
def test_settings_with_custom_values(self) -> None:
"""Test PurviewSettings with custom values."""
app_location = PurviewAppLocation(location_type=PurviewLocationType.APPLICATION, location_value="app-123")
settings = PurviewSettings(
app_name="Test App",
graph_base_uri="https://graph.microsoft-ppe.com",
tenant_id="test-tenant-id",
process_inline=True,
purview_app_location=app_location,
)
assert settings.graph_base_uri == "https://graph.microsoft-ppe.com"
assert settings.tenant_id == "test-tenant-id"
assert settings.process_inline is True
assert settings.purview_app_location.location_value == "app-123"
@pytest.mark.parametrize(
"graph_uri,expected_scope",
[
("https://graph.microsoft.com/v1.0/", "https://graph.microsoft.com/.default"),
("https://graph.microsoft-ppe.com/v1.0/", "https://graph.microsoft-ppe.com/.default"),
],
)
def test_get_scopes(self, graph_uri: str, expected_scope: str) -> None:
"""Test get_scopes returns correct scope for different URIs."""
settings = PurviewSettings(app_name="Test App", graph_base_uri=graph_uri)
scopes = settings.get_scopes()
assert len(scopes) == 1
assert expected_scope in scopes
class TestPurviewAppLocation:
"""Test PurviewAppLocation configuration."""
@pytest.mark.parametrize(
"location_type,location_value,expected_odata_type",
[
(PurviewLocationType.APPLICATION, "app-123", "microsoft.graph.policyLocationApplication"),
(PurviewLocationType.URI, "https://example.com", "microsoft.graph.policyLocationUrl"),
(PurviewLocationType.DOMAIN, "example.com", "microsoft.graph.policyLocationDomain"),
],
)
def test_get_policy_location(
self, location_type: PurviewLocationType, location_value: str, expected_odata_type: str
) -> None:
"""Test get_policy_location returns correct structure for all location types."""
location = PurviewAppLocation(location_type=location_type, location_value=location_value)
policy_location = location.get_policy_location()
assert policy_location["@odata.type"] == expected_odata_type
assert policy_location["value"] == location_value
class TestPurviewLocationType:
"""Test PurviewLocationType enum."""
def test_location_type_values(self) -> None:
"""Test PurviewLocationType enum has expected values."""
assert PurviewLocationType.APPLICATION == "application"
assert PurviewLocationType.URI == "uri"
assert PurviewLocationType.DOMAIN == "domain"
@@ -0,0 +1,88 @@
## Purview Policy Enforcement Sample (Python)
This getting-started sample shows how to attach Microsoft Purview policy evaluation to an Agent Framework `ChatAgent` using the **middleware** approach.
1. Configure an Azure OpenAI chat client
2. Add Purview policy enforcement middleware (`PurviewPolicyMiddleware`)
3. Run a short conversation and observe prompt / response blocking behavior
---
## 1. Setup
### Required Environment Variables
| Variable | Required | Purpose |
|----------|----------|---------|
| `AZURE_OPENAI_ENDPOINT` | Yes | Azure OpenAI endpoint (https://<name>.openai.azure.com) |
| `AZURE_OPENAI_DEPLOYMENT_NAME` | Optional | Model deployment name (defaults inside SDK if omitted) |
| `PURVIEW_CLIENT_APP_ID` | Yes* | Client (application) ID used for Purview authentication |
| `PURVIEW_USE_CERT_AUTH` | Optional (`true`/`false`) | Switch between certificate and interactive auth |
| `PURVIEW_TENANT_ID` | Yes (when cert auth on) | Tenant ID for certificate authentication |
| `PURVIEW_CERT_PATH` | Yes (when cert auth on) | Path to your .pfx certificate |
| `PURVIEW_CERT_PASSWORD` | Optional | Password for encrypted certs |
*A demo default exists in code for illustration only—always set your own value.
### 2. Auth Modes Supported
#### A. Interactive Browser Authentication (default)
Opens a browser on first run to sign in.
```powershell
$env:AZURE_OPENAI_ENDPOINT = "https://your-openai-instance.openai.azure.com"
$env:PURVIEW_CLIENT_APP_ID = "00000000-0000-0000-0000-000000000000"
```
#### B. Certificate Authentication
For headless / CI scenarios.
```powershell
$env:PURVIEW_USE_CERT_AUTH = "true"
$env:PURVIEW_TENANT_ID = "<tenant-guid>"
$env:PURVIEW_CERT_PATH = "C:\path\to\cert.pfx"
$env:PURVIEW_CERT_PASSWORD = "optional-password"
```
Certificate steps (summary): create / register app, generate certificate, upload public key, export .pfx with private key, grant required Graph / Purview permissions.
---
## 3. Run the Sample
From repo root:
```powershell
cd python/samples/getting_started/purview_agent
python sample_purview_agent.py
```
If interactive auth is used, a browser window will appear the first time.
---
## 4. How It Works
1. Builds an Azure OpenAI chat client (using the environment endpoint / deployment)
2. Chooses credential mode (certificate vs interactive)
3. Creates `PurviewPolicyMiddleware` with `PurviewSettings`
4. Injects middleware into the agent at construction
5. Sends two user messages sequentially
6. Prints results (or policy block messages)
Prompt blocks set a system-level message: `Prompt blocked by policy` and terminate the run early. Response blocks rewrite the output to `Response blocked by policy`.
---
## 5. Code Snippet (Middleware Injection)
```python
agent = ChatAgent(
chat_client=chat_client,
instructions="You are good at telling jokes.",
name="Joker",
middleware=[
PurviewPolicyMiddleware(credential, PurviewSettings(app_name="Sample App", default_user_id="<guid>"))
],
)
```
---
@@ -0,0 +1,179 @@
# Copyright (c) Microsoft. All rights reserved.
"""Purview policy enforcement sample (Python).
Shows:
1. Creating a basic chat agent
2. Adding Purview policy evaluation via AGENT middleware (agent-level)
3. Adding Purview policy evaluation via CHAT middleware (chat-client level)
4. Running a threaded conversation and printing results
Environment variables:
- AZURE_OPENAI_ENDPOINT (required)
- AZURE_OPENAI_DEPLOYMENT_NAME (optional, defaults to gpt-4o-mini)
- PURVIEW_CLIENT_APP_ID (required)
- PURVIEW_USE_CERT_AUTH (optional, set to "true" for certificate auth)
- PURVIEW_TENANT_ID (required if certificate auth)
- PURVIEW_CERT_PATH (required if certificate auth)
- PURVIEW_CERT_PASSWORD (optional)
- PURVIEW_DEFAULT_USER_ID (optional, user ID for Purview evaluation)
"""
from __future__ import annotations
import asyncio
import os
from typing import Any
from agent_framework import AgentRunResponse, ChatAgent, ChatMessage, Role
from agent_framework.azure import AzureOpenAIChatClient
from azure.identity import (
AzureCliCredential,
CertificateCredential,
InteractiveBrowserCredential,
)
# Purview integration pieces
from agent_framework.microsoft import (
PurviewPolicyMiddleware,
PurviewChatPolicyMiddleware,
PurviewSettings,
)
JOKER_NAME = "Joker"
JOKER_INSTRUCTIONS = "You are good at telling jokes. Keep responses concise."
def _get_env(name: str, *, required: bool = True, default: str | None = None) -> str:
val = os.environ.get(name, default)
if required and not val:
raise RuntimeError(f"Environment variable {name} is required")
return val # type: ignore[return-value]
def build_credential() -> Any:
"""Select an Azure credential for Purview authentication.
Supported modes:
1. CertificateCredential (if PURVIEW_USE_CERT_AUTH=true)
2. InteractiveBrowserCredential (requires PURVIEW_CLIENT_APP_ID)
"""
client_id = _get_env("PURVIEW_CLIENT_APP_ID", required=True)
use_cert_auth = _get_env("PURVIEW_USE_CERT_AUTH", required=False, default="false").lower() == "true"
if not client_id:
raise RuntimeError(
"PURVIEW_CLIENT_APP_ID is required for interactive browser authentication; "
"set PURVIEW_USE_CERT_AUTH=true for certificate mode instead."
)
if use_cert_auth:
tenant_id = _get_env("PURVIEW_TENANT_ID")
cert_path = _get_env("PURVIEW_CERT_PATH")
cert_password = _get_env("PURVIEW_CERT_PASSWORD", required=False, default=None)
print(f"Using Certificate Authentication (tenant: {tenant_id}, cert: {cert_path})")
return CertificateCredential(
tenant_id=tenant_id,
client_id=client_id,
certificate_path=cert_path,
password=cert_password,
)
print(f"Using Interactive Browser Authentication (client_id: {client_id})")
return InteractiveBrowserCredential(client_id=client_id)
async def run_with_agent_middleware() -> None:
endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
if not endpoint:
print("Skipping run: AZURE_OPENAI_ENDPOINT not set")
return
deployment = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", "gpt-4o-mini")
user_id = os.environ.get("PURVIEW_DEFAULT_USER_ID")
chat_client = AzureOpenAIChatClient(deployment_name=deployment, endpoint=endpoint, credential=AzureCliCredential())
purview_agent_middleware = PurviewPolicyMiddleware(
build_credential(),
PurviewSettings(
app_name="Agent Framework Sample App",
),
)
agent = ChatAgent(
chat_client=chat_client,
instructions=JOKER_INSTRUCTIONS,
name=JOKER_NAME,
middleware=purview_agent_middleware,
)
print("-- Agent Middleware Path --")
first: AgentRunResponse = await agent.run(ChatMessage(role=Role.USER, text="Tell me a joke about a pirate.", additional_properties={"user_id": user_id}))
print("First response (agent middleware):\n", first)
second: AgentRunResponse = await agent.run(ChatMessage(role=Role.USER, text="That was funny. Tell me another one.", additional_properties={"user_id": user_id}))
print("Second response (agent middleware):\n", second)
async def run_with_chat_middleware() -> None:
endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
if not endpoint:
print("Skipping chat middleware run: AZURE_OPENAI_ENDPOINT not set")
return
deployment = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", default="gpt-4o-mini")
user_id = os.environ.get("PURVIEW_DEFAULT_USER_ID")
chat_client = AzureOpenAIChatClient(
deployment_name=deployment,
endpoint=endpoint,
credential=AzureCliCredential(),
middleware=[
PurviewChatPolicyMiddleware(
build_credential(),
PurviewSettings(
app_name="Agent Framework Sample App (Chat)",
),
)
],
)
agent = ChatAgent(
chat_client=chat_client,
instructions=JOKER_INSTRUCTIONS,
name=JOKER_NAME,
)
print("-- Chat Middleware Path --")
first: AgentRunResponse = await agent.run(
ChatMessage(
role=Role.USER,
text="Give me a short clean joke.",
additional_properties={"user_id": user_id},
)
)
print("First response (chat middleware):\n", first)
second: AgentRunResponse = await agent.run(
ChatMessage(
role=Role.USER,
text="One more please.",
additional_properties={"user_id": user_id},
)
)
print("Second response (chat middleware):\n", second)
async def main() -> None:
print("== Purview Agent Sample (Agent & Chat Middleware) ==")
try:
await run_with_agent_middleware()
except Exception as ex: # pragma: no cover - demo resilience
print(f"Agent middleware path failed: {ex}")
try:
await run_with_chat_middleware()
except Exception as ex: # pragma: no cover - demo resilience
print(f"Chat middleware path failed: {ex}")
if __name__ == "__main__":
asyncio.run(main())
+18
View File
@@ -31,6 +31,7 @@ members = [
"agent-framework-devui",
"agent-framework-lab",
"agent-framework-mem0",
"agent-framework-purview",
"agent-framework-redis",
]
@@ -383,6 +384,23 @@ requires-dist = [
{ name = "mem0ai", specifier = ">=0.1.117" },
]
[[package]]
name = "agent-framework-purview"
version = "0.1.0b1"
source = { editable = "packages/purview" }
dependencies = [
{ name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "azure-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
]
[package.metadata]
requires-dist = [
{ name = "agent-framework-core", editable = "packages/core" },
{ name = "azure-core", specifier = ">=1.30.0" },
{ name = "httpx", specifier = ">=0.27.0" },
]
[[package]]
name = "agent-framework-redis"
version = "1.0.0b251007"