mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add Purview Middleware (#1142)
* [Py Purview] Purview Python Initial Commit * [Py Purview] Purview Python Minor Fixes * [Py Purview] Purview Python Comment Fixesish * [Py Purview] Purview Python Agent Middleware Done * [Py Purview] Purview Python Agent Middleware Done * [Py Purview] Purview Python Lint Errors * [Py Purview] Purview Python Final Hopefully * [Py Purview] Purview Python Final Hopefully * [Py Purview] Purview Python Fix ReadMe * [Py Purview] Purview Python Fix MyPy * [Py Purview] Purview Python Minor Updates on comments * [Py Purview] Purview Python Fix Build Error --------- Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
76ae0a62ac
commit
59da578902
@@ -243,4 +243,4 @@ class AzureOpenAIConfigMixin(OpenAIBase):
|
||||
def_headers = None
|
||||
self.default_headers = def_headers
|
||||
|
||||
super().__init__(model_id=deployment_name, client=client)
|
||||
super().__init__(model_id=deployment_name, client=client, **kwargs)
|
||||
|
||||
@@ -3,12 +3,20 @@
|
||||
import importlib
|
||||
from typing import Any
|
||||
|
||||
PACKAGE_NAME = "agent_framework_copilotstudio"
|
||||
PACKAGE_EXTRA = ["microsoft-copilotstudio", "copilotstudio"]
|
||||
_IMPORTS: dict[str, tuple[str, list[str]]] = {
|
||||
"CopilotStudioAgent": ("agent_framework_copilotstudio", ["microsoft-copilotstudio", "copilotstudio"]),
|
||||
"__version__": ("agent_framework_copilotstudio", ["microsoft-copilotstudio", "copilotstudio"]),
|
||||
"acquire_token": ("agent_framework_copilotstudio", ["microsoft-copilotstudio", "copilotstudio"]),
|
||||
# Purview (Graph Data Security & Governance) integration exports
|
||||
"PurviewPolicyMiddleware": ("agent_framework_purview", ["microsoft-purview", "purview"]),
|
||||
"PurviewChatPolicyMiddleware": ("agent_framework_purview", ["microsoft-purview", "purview"]),
|
||||
"PurviewSettings": ("agent_framework_purview", ["microsoft-purview", "purview"]),
|
||||
"PurviewAppLocation": ("agent_framework_purview", ["microsoft-purview", "purview"]),
|
||||
"PurviewLocationType": ("agent_framework_purview", ["microsoft-purview", "purview"]),
|
||||
"PurviewAuthenticationError": ("agent_framework_purview", ["microsoft-purview", "purview"]),
|
||||
"PurviewRateLimitError": ("agent_framework_purview", ["microsoft-purview", "purview"]),
|
||||
"PurviewRequestError": ("agent_framework_purview", ["microsoft-purview", "purview"]),
|
||||
"PurviewServiceError": ("agent_framework_purview", ["microsoft-purview", "purview"]),
|
||||
}
|
||||
|
||||
|
||||
@@ -23,7 +31,7 @@ def __getattr__(name: str) -> Any:
|
||||
f"please use `pip install agent-framework-{package_extra[0]}`, "
|
||||
"or update your requirements.txt or pyproject.toml file."
|
||||
) from exc
|
||||
raise AttributeError(f"Module `azure` has no attribute {name}.")
|
||||
raise AttributeError(f"Module `microsoft` has no attribute {name}.")
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
|
||||
@@ -1,5 +1,29 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from agent_framework_copilotstudio import CopilotStudioAgent, __version__, acquire_token
|
||||
from agent_framework_purview import (
|
||||
PurviewAppLocation,
|
||||
PurviewAuthenticationError,
|
||||
PurviewChatPolicyMiddleware,
|
||||
PurviewLocationType,
|
||||
PurviewPolicyMiddleware,
|
||||
PurviewRateLimitError,
|
||||
PurviewRequestError,
|
||||
PurviewServiceError,
|
||||
PurviewSettings,
|
||||
)
|
||||
|
||||
__all__ = ["CopilotStudioAgent", "__version__", "acquire_token"]
|
||||
__all__ = [
|
||||
"CopilotStudioAgent",
|
||||
"PurviewAppLocation",
|
||||
"PurviewAuthenticationError",
|
||||
"PurviewChatPolicyMiddleware",
|
||||
"PurviewLocationType",
|
||||
"PurviewPolicyMiddleware",
|
||||
"PurviewRateLimitError",
|
||||
"PurviewRequestError",
|
||||
"PurviewServiceError",
|
||||
"PurviewSettings",
|
||||
"__version__",
|
||||
"acquire_token",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) Microsoft Corporation.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE
|
||||
@@ -0,0 +1,224 @@
|
||||
## Microsoft Agent Framework – Purview Integration (Python)
|
||||
|
||||
`agent-framework-purview` adds Microsoft Purview (Microsoft Graph dataSecurityAndGovernance) policy evaluation to the Microsoft Agent Framework. It lets you enforce data security / governance policies on both the *prompt* (user input + conversation history) and the *model response* before they proceed further in your workflow.
|
||||
|
||||
> Status: **Preview**
|
||||
|
||||
### Key Features
|
||||
|
||||
- Middleware-based policy enforcement (agent-level and chat-client level)
|
||||
- Blocks or allows content at both ingress (prompt) and egress (response)
|
||||
- Works with any `ChatAgent` / agent orchestration using the standard Agent Framework middleware pipeline
|
||||
- Supports both synchronous `TokenCredential` and `AsyncTokenCredential` from `azure-identity`
|
||||
- Simple, typed configuration via `PurviewSettings` / `PurviewAppLocation`
|
||||
- Two middleware types:
|
||||
- `PurviewPolicyMiddleware` (Agent pipeline)
|
||||
- `PurviewChatPolicyMiddleware` (Chat client middleware list)
|
||||
|
||||
### When to Use
|
||||
Add Purview when you need to:
|
||||
|
||||
- Prevent sensitive or disallowed content from being sent to an LLM
|
||||
- Prevent model output containing disallowed data from leaving the system
|
||||
- Apply centrally managed policies without rewriting agent logic
|
||||
|
||||
---
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from agent_framework import ChatAgent, ChatMessage, Role
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
from agent_framework.microsoft import PurviewPolicyMiddleware, PurviewSettings
|
||||
from azure.identity import InteractiveBrowserCredential
|
||||
|
||||
async def main():
|
||||
chat_client = AzureOpenAIChatClient() # uses environment for endpoint + deployment
|
||||
|
||||
purview_middleware = PurviewPolicyMiddleware(
|
||||
credential=InteractiveBrowserCredential(),
|
||||
settings=PurviewSettings(app_name="My Sample App")
|
||||
)
|
||||
|
||||
agent = ChatAgent(
|
||||
chat_client=chat_client,
|
||||
instructions="You are a helpful assistant.",
|
||||
middleware=[purview_middleware]
|
||||
)
|
||||
|
||||
response = await agent.run(ChatMessage(role=Role.USER, text="Summarize zero trust in one sentence."))
|
||||
print(response)
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
If a policy violation is detected on the prompt, the middleware terminates the run and substitutes a system message: `"Prompt blocked by policy"`. If on the response, the result becomes `"Response blocked by policy"`.
|
||||
|
||||
---
|
||||
|
||||
## Authentication
|
||||
|
||||
`PurviewClient` uses the `azure-identity` library for token acquisition. You can use any `TokenCredential` or `AsyncTokenCredential` implementation.
|
||||
|
||||
The APIs require the following Graph Permissions:
|
||||
- ProtectionScopes.Compute.All : (userProtectionScopeContainer)[https://learn.microsoft.com/en-us/graph/api/userprotectionscopecontainer-compute]
|
||||
- Content.Process.All : (processContent)[https://learn.microsoft.com/en-us/graph/api/userdatasecurityandgovernance-processcontent]
|
||||
- ContentActivity.Write : (contentActivity)[https://learn.microsoft.com/en-us/graph/api/activitiescontainer-post-contentactivities]
|
||||
|
||||
### Scopes
|
||||
`PurviewSettings.get_scopes()` derives the Graph scope list (currently `https://graph.microsoft.com/.default` style).
|
||||
|
||||
### Tenant Enablement for Purview
|
||||
- The tenant requires an e5 license and consumptive billing setup.
|
||||
- There need to be (Data Loss Prevention)[https://learn.microsoft.com/en-us/purview/dlp-create-deploy-policy] or (Data Collection Policies)[https://learn.microsoft.com/en-us/purview/collection-policies-policy-reference] that apply to the user to call Process Content API else it calls Content Activities API for auditing the message.
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
### `PurviewSettings`
|
||||
|
||||
```python
|
||||
PurviewSettings(
|
||||
app_name="My App", # Display / logical name
|
||||
tenant_id=None, # Optional – used mainly for auth context
|
||||
purview_app_location=None, # Optional PurviewAppLocation for scoping
|
||||
graph_base_uri="https://graph.microsoft.com/v1.0/",
|
||||
process_inline=False, # Reserved for future inline processing optimizations
|
||||
blocked_prompt_message="Prompt blocked by policy", # Custom message for blocked prompts
|
||||
blocked_response_message="Response blocked by policy" # Custom message for blocked responses
|
||||
)
|
||||
```
|
||||
|
||||
To scope evaluation by location (application, URL, or domain):
|
||||
|
||||
```python
|
||||
from agent_framework.microsoft import (
|
||||
PurviewAppLocation,
|
||||
PurviewLocationType,
|
||||
PurviewSettings,
|
||||
)
|
||||
|
||||
settings = PurviewSettings(
|
||||
app_name="Contoso Support",
|
||||
purview_app_location=PurviewAppLocation(
|
||||
location_type=PurviewLocationType.APPLICATION,
|
||||
location_value="<app-client-id>"
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
### Customizing Blocked Messages
|
||||
|
||||
By default, when Purview blocks a prompt or response, the middleware returns a generic system message. You can customize these messages by providing your own text in the `PurviewSettings`:
|
||||
|
||||
```python
|
||||
from agent_framework.microsoft import PurviewSettings
|
||||
|
||||
settings = PurviewSettings(
|
||||
app_name="My App",
|
||||
blocked_prompt_message="Your request contains content that violates our policies. Please rephrase and try again.",
|
||||
blocked_response_message="The response was blocked due to policy restrictions. Please contact support if you need assistance."
|
||||
)
|
||||
```
|
||||
|
||||
This is useful for:
|
||||
- Providing more user-friendly error messages
|
||||
- Including support contact information
|
||||
- Localizing messages for different languages
|
||||
- Adding branding or specific guidance for your application
|
||||
|
||||
### Selecting Agent vs Chat Middleware
|
||||
|
||||
Use the agent middleware when you already have / want the full agent pipeline:
|
||||
|
||||
```python
|
||||
from agent_framework import ChatAgent
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
from agent_framework.microsoft import PurviewPolicyMiddleware, PurviewSettings
|
||||
from azure.identity import DefaultAzureCredential
|
||||
|
||||
credential = DefaultAzureCredential()
|
||||
client = AzureOpenAIChatClient()
|
||||
|
||||
agent = ChatAgent(
|
||||
chat_client=client,
|
||||
instructions="You are helpful.",
|
||||
middleware=[PurviewPolicyMiddleware(credential, PurviewSettings(app_name="My App"))]
|
||||
)
|
||||
```
|
||||
|
||||
Use the chat middleware when you attach directly to a chat client (e.g. minimal agent shell or custom orchestration):
|
||||
|
||||
```python
|
||||
import os
|
||||
from agent_framework import ChatAgent
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
from agent_framework.microsoft import PurviewChatPolicyMiddleware, PurviewSettings
|
||||
from azure.identity import DefaultAzureCredential
|
||||
|
||||
credential = DefaultAzureCredential()
|
||||
|
||||
chat_client = AzureOpenAIChatClient(
|
||||
deployment_name=os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"],
|
||||
endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
|
||||
credential=credential,
|
||||
middleware=[
|
||||
PurviewChatPolicyMiddleware(credential, PurviewSettings(app_name="My App (Chat)"))
|
||||
],
|
||||
)
|
||||
|
||||
agent = ChatAgent(chat_client=chat_client, instructions="You are helpful.")
|
||||
```
|
||||
|
||||
The policy logic is identical; the difference is only the hook point in the pipeline.
|
||||
|
||||
---
|
||||
|
||||
## Middleware Lifecycle
|
||||
|
||||
1. Before agent execution (`prompt phase`): all `context.messages` are evaluated.
|
||||
2. If blocked: `context.result` is replaced with a system message and `context.terminate = True`.
|
||||
3. After successful agent execution (`response phase`): the produced messages are evaluated.
|
||||
4. If blocked: result messages are replaced with a blocking notice.
|
||||
|
||||
When a user identifier is discovered (e.g. in `ChatMessage.additional_properties['user_id']`) during the prompt phase it is reused for the response phase so both evaluations map consistently to the same user.
|
||||
|
||||
You can customize the blocking messages using the `blocked_prompt_message` and `blocked_response_message` fields in `PurviewSettings`. For more advanced scenarios, you can wrap the middleware or post-process `context.result` in later middleware.
|
||||
|
||||
---
|
||||
|
||||
## Exceptions
|
||||
|
||||
| Exception | Scenario |
|
||||
|-----------|----------|
|
||||
| `PurviewAuthenticationError` | Token acquisition / validation issues |
|
||||
| `PurviewRateLimitError` | 429 responses from service |
|
||||
| `PurviewRequestError` | 4xx client errors (bad input, unauthorized, forbidden) |
|
||||
| `PurviewServiceError` | 5xx or unexpected service errors |
|
||||
|
||||
Catch broadly if you want unified fallback:
|
||||
|
||||
```python
|
||||
from agent_framework.microsoft import (
|
||||
PurviewAuthenticationError, PurviewRateLimitError,
|
||||
PurviewRequestError, PurviewServiceError
|
||||
)
|
||||
|
||||
try:
|
||||
...
|
||||
except (PurviewAuthenticationError, PurviewRateLimitError, PurviewRequestError, PurviewServiceError) as ex:
|
||||
# Log / degrade gracefully
|
||||
print(f"Purview enforcement skipped: {ex}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Notes
|
||||
- Provide a `user_id` per request (e.g. in `ChatMessage(..., additional_properties={"user_id": "<guid>"})`) when possible for per-user policy scoping; otherwise supply a default via settings or environment.
|
||||
- Blocking messages can be customized via `blocked_prompt_message` and `blocked_response_message` in `PurviewSettings`. By default, they are "Prompt blocked by policy" and "Response blocked by policy" respectively.
|
||||
- Streaming responses: post-response policy evaluation presently applies only to non-streaming chat responses.
|
||||
- Errors during policy checks are logged and do not fail the run; they degrade gracefully.
|
||||
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from ._exceptions import (
|
||||
PurviewAuthenticationError,
|
||||
PurviewRateLimitError,
|
||||
PurviewRequestError,
|
||||
PurviewServiceError,
|
||||
)
|
||||
from ._middleware import PurviewChatPolicyMiddleware, PurviewPolicyMiddleware
|
||||
from ._settings import PurviewAppLocation, PurviewLocationType, PurviewSettings
|
||||
|
||||
__all__ = [
|
||||
"PurviewAppLocation",
|
||||
"PurviewAuthenticationError",
|
||||
"PurviewChatPolicyMiddleware",
|
||||
"PurviewLocationType",
|
||||
"PurviewPolicyMiddleware",
|
||||
"PurviewRateLimitError",
|
||||
"PurviewRequestError",
|
||||
"PurviewServiceError",
|
||||
"PurviewSettings",
|
||||
]
|
||||
@@ -0,0 +1,126 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import inspect
|
||||
import json
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
from agent_framework import AGENT_FRAMEWORK_USER_AGENT
|
||||
from agent_framework.observability import get_tracer
|
||||
from azure.core.credentials import TokenCredential
|
||||
from azure.core.credentials_async import AsyncTokenCredential
|
||||
|
||||
from ._exceptions import (
|
||||
PurviewAuthenticationError,
|
||||
PurviewRateLimitError,
|
||||
PurviewRequestError,
|
||||
PurviewServiceError,
|
||||
)
|
||||
from ._models import (
|
||||
ContentActivitiesRequest,
|
||||
ContentActivitiesResponse,
|
||||
ProcessContentRequest,
|
||||
ProcessContentResponse,
|
||||
ProtectionScopesRequest,
|
||||
ProtectionScopesResponse,
|
||||
)
|
||||
from ._settings import PurviewSettings
|
||||
|
||||
|
||||
class PurviewClient:
|
||||
"""Async client for calling Graph Purview endpoints.
|
||||
|
||||
Supports both synchronous TokenCredential and asynchronous AsyncTokenCredential implementations.
|
||||
A sync credential will be invoked in a thread to avoid blocking the event loop.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
credential: TokenCredential | AsyncTokenCredential,
|
||||
settings: PurviewSettings,
|
||||
*,
|
||||
timeout: float | None = 10.0,
|
||||
):
|
||||
self._credential: TokenCredential | AsyncTokenCredential = credential
|
||||
self._settings = settings
|
||||
self._graph_uri = settings.graph_base_uri.rstrip("/")
|
||||
self._timeout = timeout
|
||||
self._client = httpx.AsyncClient(timeout=timeout)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._client.aclose()
|
||||
|
||||
async def _get_token(self, *, tenant_id: str | None = None) -> str:
|
||||
"""Acquire an access token using either async or sync credential."""
|
||||
scopes = self._settings.get_scopes()
|
||||
cred = self._credential
|
||||
token = cred.get_token(*scopes, tenant_id=tenant_id)
|
||||
token = await token if inspect.isawaitable(token) else token
|
||||
return token.token
|
||||
|
||||
@staticmethod
|
||||
def _extract_token_info(token: str) -> dict[str, Any]:
|
||||
parts = token.split(".")
|
||||
if len(parts) < 2:
|
||||
raise ValueError("Invalid JWT token format")
|
||||
payload = parts[1]
|
||||
rem = len(payload) % 4
|
||||
if rem:
|
||||
payload += "=" * (4 - rem)
|
||||
decoded = base64.urlsafe_b64decode(payload)
|
||||
data = json.loads(decoded.decode("utf-8"))
|
||||
return {
|
||||
"user_id": data.get("oid") if data.get("idtyp") == "user" else None,
|
||||
"tenant_id": data.get("tid"),
|
||||
"client_id": data.get("appid"),
|
||||
}
|
||||
|
||||
async def get_user_info_from_token(self, *, tenant_id: str | None = None) -> dict[str, Any]:
|
||||
token = await self._get_token(tenant_id=tenant_id)
|
||||
return self._extract_token_info(token)
|
||||
|
||||
async def process_content(self, request: ProcessContentRequest) -> ProcessContentResponse:
|
||||
with get_tracer().start_as_current_span("purview.process_content"):
|
||||
token = await self._get_token(tenant_id=request.tenant_id)
|
||||
url = f"{self._graph_uri}/users/{request.user_id}/dataSecurityAndGovernance/processContent"
|
||||
return cast(ProcessContentResponse, await self._post(url, request, ProcessContentResponse, token))
|
||||
|
||||
async def get_protection_scopes(self, request: ProtectionScopesRequest) -> ProtectionScopesResponse:
|
||||
with get_tracer().start_as_current_span("purview.get_protection_scopes"):
|
||||
token = await self._get_token()
|
||||
url = f"{self._graph_uri}/users/{request.user_id}/dataSecurityAndGovernance/protectionScopes/compute"
|
||||
return cast(ProtectionScopesResponse, await self._post(url, request, ProtectionScopesResponse, token))
|
||||
|
||||
async def send_content_activities(self, request: ContentActivitiesRequest) -> ContentActivitiesResponse:
|
||||
with get_tracer().start_as_current_span("purview.send_content_activities"):
|
||||
token = await self._get_token()
|
||||
url = f"{self._graph_uri}/users/{request.user_id}/dataSecurityAndGovernance/activities/contentActivities"
|
||||
return cast(ContentActivitiesResponse, await self._post(url, request, ContentActivitiesResponse, token))
|
||||
|
||||
async def _post(self, url: str, model: Any, response_type: type[Any], token: str) -> Any:
|
||||
payload = model.model_dump(by_alias=True, exclude_none=True, mode="json")
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"User-Agent": AGENT_FRAMEWORK_USER_AGENT,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
resp = await self._client.post(url, json=payload, headers=headers)
|
||||
if resp.status_code in (401, 403):
|
||||
raise PurviewAuthenticationError(f"Auth failure {resp.status_code}: {resp.text}")
|
||||
if resp.status_code == 429:
|
||||
raise PurviewRateLimitError(f"Rate limited {resp.status_code}: {resp.text}")
|
||||
if resp.status_code not in (200, 201, 202):
|
||||
raise PurviewRequestError(f"Purview request failed {resp.status_code}: {resp.text}")
|
||||
try:
|
||||
data = resp.json()
|
||||
except ValueError:
|
||||
data = {}
|
||||
try:
|
||||
# Prefer pydantic-style model_validate if present, else fall back to constructor.
|
||||
if hasattr(response_type, "model_validate"):
|
||||
return response_type.model_validate(data) # type: ignore[no-any-return]
|
||||
return response_type(**data) # type: ignore[call-arg, no-any-return]
|
||||
except Exception as ex: # pragma: no cover
|
||||
raise PurviewServiceError(f"Failed to deserialize Purview response: {ex}") from ex
|
||||
@@ -0,0 +1,29 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
"""Purview specific exceptions (minimal error shaping)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from agent_framework.exceptions import ServiceResponseException
|
||||
|
||||
__all__ = [
|
||||
"PurviewAuthenticationError",
|
||||
"PurviewRateLimitError",
|
||||
"PurviewRequestError",
|
||||
"PurviewServiceError",
|
||||
]
|
||||
|
||||
|
||||
class PurviewServiceError(ServiceResponseException):
|
||||
"""Base exception for Purview errors."""
|
||||
|
||||
|
||||
class PurviewAuthenticationError(PurviewServiceError):
|
||||
"""Authentication / authorization failure (401/403)."""
|
||||
|
||||
|
||||
class PurviewRateLimitError(PurviewServiceError):
|
||||
"""Rate limiting or throttling (429)."""
|
||||
|
||||
|
||||
class PurviewRequestError(PurviewServiceError):
|
||||
"""Other non-success HTTP errors."""
|
||||
@@ -0,0 +1,168 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from agent_framework import AgentMiddleware, AgentRunContext, ChatContext, ChatMiddleware
|
||||
from agent_framework._logging import get_logger
|
||||
from azure.core.credentials import TokenCredential
|
||||
from azure.core.credentials_async import AsyncTokenCredential
|
||||
|
||||
from ._client import PurviewClient
|
||||
from ._models import Activity
|
||||
from ._processor import ScopedContentProcessor
|
||||
from ._settings import PurviewSettings
|
||||
|
||||
logger = get_logger("agent_framework.purview")
|
||||
|
||||
|
||||
class PurviewPolicyMiddleware(AgentMiddleware):
|
||||
"""Agent middleware that enforces Purview policies on prompt and response.
|
||||
|
||||
Accepts either a synchronous TokenCredential or an AsyncTokenCredential.
|
||||
|
||||
Usage:
|
||||
|
||||
.. code-block:: python
|
||||
from agent_framework.microsoft import PurviewPolicyMiddleware, PurviewSettings
|
||||
from agent_framework import ChatAgent
|
||||
|
||||
credential = ... # TokenCredential or AsyncTokenCredential
|
||||
settings = PurviewSettings(app_name="My App")
|
||||
agent = ChatAgent(
|
||||
chat_client=client, instructions="...", middleware=[PurviewPolicyMiddleware(credential, settings)]
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
credential: TokenCredential | AsyncTokenCredential,
|
||||
settings: PurviewSettings,
|
||||
) -> None:
|
||||
self._client = PurviewClient(credential, settings)
|
||||
self._processor = ScopedContentProcessor(self._client, settings)
|
||||
self._settings = settings
|
||||
|
||||
async def process(
|
||||
self,
|
||||
context: AgentRunContext,
|
||||
next: Callable[[AgentRunContext], Awaitable[None]],
|
||||
) -> None: # type: ignore[override]
|
||||
resolved_user_id: str | None = None
|
||||
try:
|
||||
# Pre (prompt) check
|
||||
should_block_prompt, resolved_user_id = await self._processor.process_messages(
|
||||
context.messages, Activity.UPLOAD_TEXT
|
||||
)
|
||||
if should_block_prompt:
|
||||
from agent_framework import AgentRunResponse, ChatMessage, Role
|
||||
|
||||
context.result = AgentRunResponse(
|
||||
messages=[ChatMessage(role=Role.SYSTEM, text=self._settings.blocked_prompt_message)]
|
||||
)
|
||||
context.terminate = True
|
||||
return
|
||||
except Exception as ex:
|
||||
# Log and continue if there's an error in the pre-check
|
||||
logger.error(f"Error in Purview policy pre-check: {ex}")
|
||||
|
||||
await next(context)
|
||||
|
||||
try:
|
||||
# Post (response) check only if we have a normal AgentRunResponse
|
||||
# Use the same user_id from the request for the response evaluation
|
||||
if context.result and not context.is_streaming:
|
||||
should_block_response, _ = await self._processor.process_messages(
|
||||
context.result.messages, # type: ignore[union-attr]
|
||||
Activity.UPLOAD_TEXT,
|
||||
user_id=resolved_user_id,
|
||||
)
|
||||
if should_block_response:
|
||||
from agent_framework import AgentRunResponse, ChatMessage, Role
|
||||
|
||||
context.result = AgentRunResponse(
|
||||
messages=[ChatMessage(role=Role.SYSTEM, text=self._settings.blocked_response_message)]
|
||||
)
|
||||
else:
|
||||
# Streaming responses are not supported for post-checks
|
||||
logger.debug("Streaming responses are not supported for Purview policy post-checks")
|
||||
except Exception as ex:
|
||||
# Log and continue if there's an error in the post-check
|
||||
logger.error(f"Error in Purview policy post-check: {ex}")
|
||||
|
||||
|
||||
class PurviewChatPolicyMiddleware(ChatMiddleware):
|
||||
"""Chat middleware variant for Purview policy evaluation.
|
||||
|
||||
This allows users to attach Purview enforcement directly to a chat client
|
||||
|
||||
Behavior:
|
||||
* Pre-chat: evaluates outgoing (user + context) messages as an upload activity
|
||||
and can terminate execution if blocked.
|
||||
* Post-chat: evaluates the received response messages (streaming is not presently supported)
|
||||
and can replace them with a blocked message. Uses the same user_id from the request
|
||||
to ensure consistent user identity throughout the evaluation.
|
||||
|
||||
Usage:
|
||||
|
||||
.. code-block:: python
|
||||
from agent_framework.microsoft import PurviewChatPolicyMiddleware, PurviewSettings
|
||||
from agent_framework import ChatClient
|
||||
|
||||
credential = ... # TokenCredential or AsyncTokenCredential
|
||||
settings = PurviewSettings(app_name="My App")
|
||||
client = ChatClient(..., middleware=[PurviewChatPolicyMiddleware(credential, settings)])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
credential: TokenCredential | AsyncTokenCredential,
|
||||
settings: PurviewSettings,
|
||||
) -> None:
|
||||
self._client = PurviewClient(credential, settings)
|
||||
self._processor = ScopedContentProcessor(self._client, settings)
|
||||
self._settings = settings
|
||||
|
||||
async def process(
|
||||
self,
|
||||
context: ChatContext,
|
||||
next: Callable[[ChatContext], Awaitable[None]],
|
||||
) -> None: # type: ignore[override]
|
||||
resolved_user_id: str | None = None
|
||||
try:
|
||||
should_block_prompt, resolved_user_id = await self._processor.process_messages(
|
||||
context.messages, Activity.UPLOAD_TEXT
|
||||
)
|
||||
if should_block_prompt:
|
||||
from agent_framework import ChatMessage
|
||||
|
||||
context.result = [ # type: ignore[assignment]
|
||||
ChatMessage(role="system", text=self._settings.blocked_prompt_message)
|
||||
]
|
||||
context.terminate = True
|
||||
return
|
||||
except Exception as ex:
|
||||
logger.error(f"Error in Purview policy pre-check: {ex}")
|
||||
|
||||
await next(context)
|
||||
|
||||
try:
|
||||
# Post (response) evaluation only if non-streaming and we have messages result shape
|
||||
# Use the same user_id from the request for the response evaluation
|
||||
if context.result and not context.is_streaming:
|
||||
result_obj = context.result
|
||||
messages = getattr(result_obj, "messages", None)
|
||||
if messages:
|
||||
should_block_response, _ = await self._processor.process_messages(
|
||||
messages, Activity.UPLOAD_TEXT, user_id=resolved_user_id
|
||||
)
|
||||
if should_block_response:
|
||||
from agent_framework import ChatMessage
|
||||
|
||||
context.result = [ # type: ignore[assignment]
|
||||
ChatMessage(role="system", text=self._settings.blocked_response_message)
|
||||
]
|
||||
else:
|
||||
logger.debug("Streaming responses are not supported for Purview policy post-checks")
|
||||
except Exception as ex:
|
||||
logger.error(f"Error in Purview policy post-check: {ex}")
|
||||
@@ -0,0 +1,992 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Unified Purview model definitions and public export surface."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, MutableMapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum, Flag, auto
|
||||
from typing import Any, ClassVar, TypeVar, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from agent_framework._logging import get_logger
|
||||
from agent_framework._serialization import SerializationMixin
|
||||
|
||||
logger = get_logger("agent_framework.purview")
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# Enums & flag helpers
|
||||
# --------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Activity(str, Enum):
|
||||
"""High-level activity types representing user or agent operations."""
|
||||
|
||||
UNKNOWN = "unknown"
|
||||
UPLOAD_TEXT = "uploadText"
|
||||
UPLOAD_FILE = "uploadFile"
|
||||
DOWNLOAD_TEXT = "downloadText"
|
||||
DOWNLOAD_FILE = "downloadFile"
|
||||
|
||||
|
||||
class ProtectionScopeActivities(Flag):
|
||||
"""Flag enumeration of activities used in policy protection scopes."""
|
||||
|
||||
NONE = 0
|
||||
UPLOAD_TEXT = auto()
|
||||
UPLOAD_FILE = auto()
|
||||
DOWNLOAD_TEXT = auto()
|
||||
DOWNLOAD_FILE = auto()
|
||||
UNKNOWN_FUTURE_VALUE = auto()
|
||||
|
||||
def __int__(self) -> int: # pragma: no cover
|
||||
return self.value
|
||||
|
||||
|
||||
FlagT = TypeVar("FlagT", bound=Flag)
|
||||
|
||||
_PROTECTION_SCOPE_ACTIVITIES_MAP: dict[str, ProtectionScopeActivities] = {
|
||||
"none": ProtectionScopeActivities.NONE,
|
||||
"uploadText": ProtectionScopeActivities.UPLOAD_TEXT,
|
||||
"uploadFile": ProtectionScopeActivities.UPLOAD_FILE,
|
||||
"downloadText": ProtectionScopeActivities.DOWNLOAD_TEXT,
|
||||
"downloadFile": ProtectionScopeActivities.DOWNLOAD_FILE,
|
||||
"unknownFutureValue": ProtectionScopeActivities.UNKNOWN_FUTURE_VALUE,
|
||||
}
|
||||
_PROTECTION_SCOPE_ACTIVITIES_SERIALIZE_ORDER: list[tuple[str, ProtectionScopeActivities]] = [
|
||||
("uploadText", ProtectionScopeActivities.UPLOAD_TEXT),
|
||||
("uploadFile", ProtectionScopeActivities.UPLOAD_FILE),
|
||||
("downloadText", ProtectionScopeActivities.DOWNLOAD_TEXT),
|
||||
("downloadFile", ProtectionScopeActivities.DOWNLOAD_FILE),
|
||||
]
|
||||
|
||||
|
||||
def deserialize_flag(
|
||||
value: object, mapping: Mapping[str, FlagT], enum_cls: type[FlagT]
|
||||
) -> FlagT | None: # pragma: no cover
|
||||
"""Deserialize arbitrary input into a flag enum instance."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, enum_cls):
|
||||
return value
|
||||
if isinstance(value, int):
|
||||
try:
|
||||
return enum_cls(value)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
flag_value = enum_cls(0)
|
||||
parts: list[str] = []
|
||||
|
||||
if isinstance(value, str):
|
||||
raw = value.strip()
|
||||
if not raw:
|
||||
return enum_cls(0)
|
||||
parts.extend([p.strip() for p in raw.split(",") if p.strip()])
|
||||
elif isinstance(value, (list, tuple, set)):
|
||||
for item in value:
|
||||
if isinstance(item, str):
|
||||
parts.extend([p.strip() for p in item.split(",") if p.strip()])
|
||||
elif isinstance(item, enum_cls):
|
||||
flag_value |= item
|
||||
elif isinstance(item, int):
|
||||
try:
|
||||
flag_value |= enum_cls(item)
|
||||
except Exception:
|
||||
logger.warning(f"Failed to convert int {item} to {enum_cls.__name__}")
|
||||
else:
|
||||
return None
|
||||
|
||||
for part in parts:
|
||||
member = mapping.get(part)
|
||||
if member is not None:
|
||||
flag_value |= member
|
||||
|
||||
if flag_value == enum_cls(0):
|
||||
none_member = mapping.get("none")
|
||||
if none_member is not None:
|
||||
return none_member # type: ignore[return-value,index]
|
||||
return flag_value
|
||||
|
||||
|
||||
def serialize_flag(
|
||||
flag_value: Flag | int | None, ordered_parts: Sequence[tuple[str, Flag]]
|
||||
) -> str | None: # pragma: no cover
|
||||
"""Serialize a flag enum (or int) into a stable, comma-separated string."""
|
||||
if flag_value is None:
|
||||
return None
|
||||
if isinstance(flag_value, int):
|
||||
if flag_value == 0:
|
||||
return "none"
|
||||
int_parts: list[str] = []
|
||||
for name, member in ordered_parts:
|
||||
if flag_value & member.value:
|
||||
int_parts.append(name)
|
||||
return ",".join(int_parts) if int_parts else "none"
|
||||
if not isinstance(flag_value, Flag):
|
||||
return None
|
||||
if flag_value.value == 0:
|
||||
return "none"
|
||||
parts: list[str] = []
|
||||
for name, member in ordered_parts:
|
||||
if flag_value & member:
|
||||
parts.append(name)
|
||||
return ",".join(parts) if parts else "none"
|
||||
|
||||
|
||||
class DlpAction(str, Enum):
|
||||
BLOCK_ACCESS = "blockAccess"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
class RestrictionAction(str, Enum):
|
||||
BLOCK = "block"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
class ProtectionScopeState(str, Enum):
|
||||
NOT_MODIFIED = "notModified"
|
||||
MODIFIED = "modified"
|
||||
UNKNOWN_FUTURE_VALUE = "unknownFutureValue"
|
||||
|
||||
|
||||
class ExecutionMode(str, Enum):
|
||||
EVALUATE_INLINE = "evaluateInline"
|
||||
EVALUATE_OFFLINE = "evaluateOffline"
|
||||
UNKNOWN_FUTURE_VALUE = "unknownFutureValue"
|
||||
|
||||
|
||||
class PolicyPivotProperty(str, Enum):
|
||||
NONE = "none"
|
||||
ACTIVITY = "activity"
|
||||
LOCATION = "location"
|
||||
UNKNOWN_FUTURE_VALUE = "unknownFutureValue"
|
||||
|
||||
|
||||
def translate_activity(activity: Activity) -> ProtectionScopeActivities:
|
||||
mapping = {
|
||||
Activity.UNKNOWN: ProtectionScopeActivities.NONE,
|
||||
Activity.UPLOAD_TEXT: ProtectionScopeActivities.UPLOAD_TEXT,
|
||||
Activity.UPLOAD_FILE: ProtectionScopeActivities.UPLOAD_FILE,
|
||||
Activity.DOWNLOAD_TEXT: ProtectionScopeActivities.DOWNLOAD_TEXT,
|
||||
Activity.DOWNLOAD_FILE: ProtectionScopeActivities.DOWNLOAD_FILE,
|
||||
}
|
||||
return mapping.get(activity, ProtectionScopeActivities.UNKNOWN_FUTURE_VALUE)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# Simple value models
|
||||
# --------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _AliasSerializable(SerializationMixin):
|
||||
"""Base class adding alias mapping + pydantic-compat helpers.
|
||||
|
||||
Each subclass can define ``_ALIASES`` mapping internal attribute name -> external serialized key.
|
||||
``to_dict`` will emit external keys; ``from_dict`` (via ``__init__`` preprocessing) accepts either form.
|
||||
|
||||
Provides light-weight compatibility helpers ``model_dump`` / ``model_validate``
|
||||
"""
|
||||
|
||||
_ALIASES: ClassVar[dict[str, str]] = {}
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
# Normalize alias keys -> internal names across the entire class hierarchy
|
||||
# Collect all aliases from parent classes too
|
||||
all_aliases: dict[str, str] = {}
|
||||
for cls in type(self).__mro__:
|
||||
if hasattr(cls, "_ALIASES") and isinstance(cls._ALIASES, dict):
|
||||
for internal, external in cls._ALIASES.items():
|
||||
if external not in all_aliases:
|
||||
all_aliases[external] = internal
|
||||
|
||||
# Normalize all aliased keys in kwargs
|
||||
for external, internal in all_aliases.items():
|
||||
if external in kwargs and internal not in kwargs:
|
||||
kwargs[internal] = kwargs.pop(external)
|
||||
|
||||
# Set normalized kwargs as attributes
|
||||
# This will overwrite any None values that child __init__ may have set from default params
|
||||
for k, v in kwargs.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Compatibility helpers
|
||||
# ------------------------------------------------------------------
|
||||
def model_dump(self, *, by_alias: bool = True, exclude_none: bool = True, **_: Any) -> dict[str, Any]:
|
||||
# Use self.to_dict() to get alias translation
|
||||
d = self.to_dict(exclude_none=exclude_none)
|
||||
# If by_alias=False, translate external -> internal (rarely needed; default True)
|
||||
if not by_alias and self._ALIASES:
|
||||
reverse = {v: k for k, v in self._ALIASES.items()}
|
||||
translated: dict[str, Any] = {}
|
||||
for k, v in d.items():
|
||||
translated[reverse.get(k, k)] = v
|
||||
return translated
|
||||
return d
|
||||
|
||||
def model_dump_json(self, *, by_alias: bool = True, exclude_none: bool = True, **kwargs: Any) -> str:
|
||||
import json
|
||||
|
||||
return json.dumps(self.model_dump(by_alias=by_alias, exclude_none=exclude_none, **kwargs))
|
||||
|
||||
@classmethod
|
||||
def model_validate(cls, value: MutableMapping[str, Any]) -> _AliasSerializable: # type: ignore[name-defined]
|
||||
return cls(**value)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Override to handle alias emission
|
||||
# ------------------------------------------------------------------
|
||||
def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: # type: ignore[override]
|
||||
base = SerializationMixin.to_dict(self, exclude=exclude, exclude_none=exclude_none)
|
||||
|
||||
# For Graph API models, remove the auto-generated 'type' field if it's in DEFAULT_EXCLUDE
|
||||
if "type" in self.DEFAULT_EXCLUDE:
|
||||
base.pop("type", None)
|
||||
|
||||
# Collect all aliases from class hierarchy
|
||||
all_aliases: dict[str, str] = {}
|
||||
for cls in type(self).__mro__:
|
||||
if hasattr(cls, "_ALIASES") and isinstance(cls._ALIASES, dict):
|
||||
# Parent aliases first (will be overridden by child if same key)
|
||||
for internal, external in cls._ALIASES.items():
|
||||
if internal not in all_aliases:
|
||||
all_aliases[internal] = external
|
||||
|
||||
if not all_aliases:
|
||||
return base
|
||||
|
||||
# Translate internal -> external keys (except 'type' reserved)
|
||||
translated: dict[str, Any] = {}
|
||||
for k, v in base.items():
|
||||
if k == "type":
|
||||
translated[k] = v
|
||||
continue
|
||||
external = all_aliases.get(k, k)
|
||||
translated[external] = v
|
||||
return translated
|
||||
|
||||
|
||||
class PolicyLocation(_AliasSerializable):
|
||||
_ALIASES: ClassVar[dict[str, str]] = {"data_type": "@odata.type"}
|
||||
DEFAULT_EXCLUDE: ClassVar[set[str]] = {"type"} # Exclude auto-generated type field for Graph API
|
||||
|
||||
def __init__(self, data_type: str | None = None, value: str | None = None, **kwargs: Any) -> None:
|
||||
# Extract aliased values from kwargs
|
||||
if "@odata.type" in kwargs:
|
||||
data_type = kwargs["@odata.type"]
|
||||
|
||||
# Call parent without explicit params with aliases
|
||||
super().__init__(**kwargs)
|
||||
self.data_type = data_type
|
||||
self.value = value
|
||||
|
||||
|
||||
class ActivityMetadata(_AliasSerializable):
|
||||
_ALIASES: ClassVar[dict[str, str]] = {"activity": "activity"}
|
||||
|
||||
def __init__(self, activity: Activity, **kwargs: Any) -> None:
|
||||
super().__init__(activity=activity, **kwargs)
|
||||
self.activity = activity
|
||||
|
||||
|
||||
class OperatingSystemSpecifications(_AliasSerializable):
|
||||
_ALIASES: ClassVar[dict[str, str]] = {
|
||||
"operating_system_platform": "operatingSystemPlatform",
|
||||
"operating_system_version": "operatingSystemVersion",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
operating_system_platform: str | None = None,
|
||||
operating_system_version: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Extract aliased values from kwargs
|
||||
if "operatingSystemPlatform" in kwargs:
|
||||
operating_system_platform = kwargs["operatingSystemPlatform"]
|
||||
if "operatingSystemVersion" in kwargs:
|
||||
operating_system_version = kwargs["operatingSystemVersion"]
|
||||
|
||||
# Call parent without explicit params with aliases
|
||||
super().__init__(**kwargs)
|
||||
self.operating_system_platform = operating_system_platform
|
||||
self.operating_system_version = operating_system_version
|
||||
|
||||
|
||||
class DeviceMetadata(_AliasSerializable):
|
||||
_ALIASES: ClassVar[dict[str, str]] = {
|
||||
"ip_address": "ipAddress",
|
||||
"operating_system_specifications": "operatingSystemSpecifications",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ip_address: str | None = None,
|
||||
operating_system_specifications: OperatingSystemSpecifications | MutableMapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Extract aliased values from kwargs
|
||||
if "ipAddress" in kwargs:
|
||||
ip_address = kwargs["ipAddress"]
|
||||
if "operatingSystemSpecifications" in kwargs:
|
||||
operating_system_specifications = kwargs["operatingSystemSpecifications"]
|
||||
|
||||
# Convert nested objects
|
||||
if isinstance(operating_system_specifications, MutableMapping):
|
||||
operating_system_specifications = OperatingSystemSpecifications(**operating_system_specifications)
|
||||
|
||||
# Call parent without explicit params with aliases
|
||||
super().__init__(**kwargs)
|
||||
self.ip_address = ip_address
|
||||
self.operating_system_specifications = operating_system_specifications
|
||||
|
||||
|
||||
class IntegratedAppMetadata(_AliasSerializable):
|
||||
def __init__(self, name: str | None = None, version: str | None = None, **kwargs: Any) -> None:
|
||||
super().__init__(name=name, version=version, **kwargs)
|
||||
self.name = name
|
||||
self.version = version
|
||||
|
||||
|
||||
class ProtectedAppMetadata(_AliasSerializable):
|
||||
_ALIASES: ClassVar[dict[str, str]] = {"application_location": "applicationLocation"}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str | None = None,
|
||||
version: str | None = None,
|
||||
application_location: PolicyLocation | MutableMapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Extract aliased values from kwargs
|
||||
if "applicationLocation" in kwargs:
|
||||
application_location = kwargs["applicationLocation"]
|
||||
|
||||
# Convert nested objects
|
||||
if isinstance(application_location, MutableMapping):
|
||||
application_location = PolicyLocation(**application_location)
|
||||
|
||||
# Call parent without explicit params with aliases
|
||||
super().__init__(**kwargs)
|
||||
self.name = name
|
||||
self.version = version
|
||||
self.application_location = application_location # type: ignore[assignment]
|
||||
|
||||
|
||||
class DlpActionInfo(_AliasSerializable):
|
||||
_ALIASES: ClassVar[dict[str, str]] = {"restriction_action": "restrictionAction"}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
action: DlpAction | None = None,
|
||||
restriction_action: RestrictionAction | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Extract aliased values from kwargs
|
||||
if "restrictionAction" in kwargs:
|
||||
restriction_action = kwargs["restrictionAction"]
|
||||
|
||||
# Call parent without explicit params with aliases
|
||||
super().__init__(**kwargs)
|
||||
self.action = action
|
||||
self.restriction_action = restriction_action
|
||||
|
||||
|
||||
class AccessedResourceDetails(_AliasSerializable):
|
||||
_ALIASES: ClassVar[dict[str, str]] = {
|
||||
"label_id": "labelId",
|
||||
"access_type": "accessType",
|
||||
"is_cross_prompt_injection_detected": "isCrossPromptInjectionDetected",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
identifier: str | None = None,
|
||||
name: str | None = None,
|
||||
url: str | None = None,
|
||||
label_id: str | None = None,
|
||||
access_type: str | None = None,
|
||||
status: str | None = None,
|
||||
is_cross_prompt_injection_detected: bool | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Extract aliased values from kwargs
|
||||
if "labelId" in kwargs:
|
||||
label_id = kwargs["labelId"]
|
||||
if "accessType" in kwargs:
|
||||
access_type = kwargs["accessType"]
|
||||
if "isCrossPromptInjectionDetected" in kwargs:
|
||||
is_cross_prompt_injection_detected = kwargs["isCrossPromptInjectionDetected"]
|
||||
|
||||
# Call parent without explicit params with aliases
|
||||
super().__init__(**kwargs)
|
||||
self.identifier = identifier
|
||||
self.name = name
|
||||
self.url = url
|
||||
self.label_id = label_id
|
||||
self.access_type = access_type
|
||||
self.status = status
|
||||
self.is_cross_prompt_injection_detected = is_cross_prompt_injection_detected
|
||||
|
||||
|
||||
class AiInteractionPlugin(_AliasSerializable):
|
||||
def __init__(
|
||||
self,
|
||||
identifier: str | None = None,
|
||||
name: str | None = None,
|
||||
version: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(identifier=identifier, name=name, version=version, **kwargs)
|
||||
self.identifier = identifier
|
||||
self.name = name
|
||||
self.version = version
|
||||
|
||||
|
||||
class AiAgentInfo(_AliasSerializable):
|
||||
def __init__(
|
||||
self,
|
||||
identifier: str | None = None,
|
||||
name: str | None = None,
|
||||
version: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(identifier=identifier, name=name, version=version, **kwargs)
|
||||
self.identifier = identifier
|
||||
self.name = name
|
||||
self.version = version
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# Content models
|
||||
# --------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
class GraphDataTypeBase(_AliasSerializable):
|
||||
_ALIASES: ClassVar[dict[str, str]] = {"data_type": "@odata.type"}
|
||||
# Exclude the auto-generated 'type' field - Graph API uses @odata.type instead
|
||||
DEFAULT_EXCLUDE: ClassVar[set[str]] = {"type"}
|
||||
|
||||
def __init__(self, data_type: str, **kwargs: Any) -> None:
|
||||
super().__init__(data_type=data_type, **kwargs)
|
||||
self.data_type = data_type
|
||||
|
||||
|
||||
class ContentBase(GraphDataTypeBase):
|
||||
pass
|
||||
|
||||
|
||||
class PurviewTextContent(ContentBase):
|
||||
def __init__(self, data: str, data_type: str = "microsoft.graph.textContent", **kwargs: Any) -> None:
|
||||
super().__init__(data_type=data_type, **kwargs)
|
||||
self.data = data
|
||||
|
||||
|
||||
class PurviewBinaryContent(ContentBase):
|
||||
def __init__(self, data: bytes, data_type: str = "microsoft.graph.binaryContent", **kwargs: Any) -> None:
|
||||
super().__init__(data_type=data_type, **kwargs)
|
||||
self.data = data
|
||||
|
||||
def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: # type: ignore[override]
|
||||
import base64
|
||||
|
||||
base = super().to_dict(exclude=exclude, exclude_none=exclude_none)
|
||||
# Ensure bytes encoded as base64 string like pydantic
|
||||
data_bytes = getattr(self, "data", b"") or b""
|
||||
base["data"] = base64.b64encode(data_bytes).decode("utf-8")
|
||||
return base
|
||||
|
||||
|
||||
class ProcessConversationMetadata(GraphDataTypeBase):
|
||||
_ALIASES: ClassVar[dict[str, str]] = {
|
||||
"correlation_id": "correlationId",
|
||||
"sequence_number": "sequenceNumber",
|
||||
"is_truncated": "isTruncated",
|
||||
"created_date_time": "createdDateTime",
|
||||
"modified_date_time": "modifiedDateTime",
|
||||
"parent_message_id": "parentMessageId",
|
||||
"accessed_resources": "accessedResources_v2",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
identifier: str | None = None,
|
||||
content: PurviewTextContent | PurviewBinaryContent | ContentBase | MutableMapping[str, Any] | None = None,
|
||||
name: str | None = None,
|
||||
is_truncated: bool | None = None,
|
||||
data_type: str = "microsoft.graph.processConversationMetadata", # emitted via base
|
||||
correlation_id: str | None = None,
|
||||
sequence_number: int | None = None,
|
||||
length: int | None = None,
|
||||
created_date_time: datetime | None = None,
|
||||
modified_date_time: datetime | None = None,
|
||||
parent_message_id: str | None = None,
|
||||
accessed_resources: list[AccessedResourceDetails | MutableMapping[str, Any]] | None = None,
|
||||
plugins: list[AiInteractionPlugin | MutableMapping[str, Any]] | None = None,
|
||||
agents: list[AiAgentInfo | MutableMapping[str, Any]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Extract aliased values from kwargs
|
||||
if "correlationId" in kwargs:
|
||||
correlation_id = kwargs["correlationId"]
|
||||
if "sequenceNumber" in kwargs:
|
||||
sequence_number = kwargs["sequenceNumber"]
|
||||
if "isTruncated" in kwargs:
|
||||
is_truncated = kwargs["isTruncated"]
|
||||
if "createdDateTime" in kwargs:
|
||||
created_date_time = kwargs["createdDateTime"]
|
||||
if "modifiedDateTime" in kwargs:
|
||||
modified_date_time = kwargs["modifiedDateTime"]
|
||||
if "parentMessageId" in kwargs:
|
||||
parent_message_id = kwargs["parentMessageId"]
|
||||
if "accessedResources_v2" in kwargs:
|
||||
accessed_resources = kwargs["accessedResources_v2"]
|
||||
|
||||
# Convert nested objects
|
||||
if isinstance(content, MutableMapping):
|
||||
# determine by type? fall back to text content
|
||||
c_type = content.get("@odata.type") or content.get("data_type")
|
||||
if c_type and "binary" in str(c_type):
|
||||
content = PurviewBinaryContent(**content) # type: ignore[arg-type]
|
||||
else:
|
||||
content = PurviewTextContent(**content) # type: ignore[arg-type]
|
||||
accessed_list: list[AccessedResourceDetails] | None = None
|
||||
if accessed_resources:
|
||||
accessed_list = [
|
||||
ar if isinstance(ar, AccessedResourceDetails) else AccessedResourceDetails(**ar)
|
||||
for ar in accessed_resources
|
||||
]
|
||||
plugin_list: list[AiInteractionPlugin] | None = None
|
||||
if plugins:
|
||||
plugin_list = [p if isinstance(p, AiInteractionPlugin) else AiInteractionPlugin(**p) for p in plugins]
|
||||
agent_list: list[AiAgentInfo] | None = None
|
||||
if agents:
|
||||
agent_list = [a if isinstance(a, AiAgentInfo) else AiAgentInfo(**a) for a in agents]
|
||||
|
||||
# Call parent without explicit params with aliases
|
||||
super().__init__(data_type=data_type, **kwargs)
|
||||
self.identifier = identifier
|
||||
self.content = content # type: ignore[assignment]
|
||||
self.name = name
|
||||
self.correlation_id = correlation_id
|
||||
self.sequence_number = sequence_number
|
||||
self.length = length
|
||||
self.is_truncated = is_truncated
|
||||
self.created_date_time = created_date_time
|
||||
self.modified_date_time = modified_date_time
|
||||
self.parent_message_id = parent_message_id
|
||||
self.accessed_resources = accessed_list
|
||||
self.plugins = plugin_list
|
||||
self.agents = agent_list
|
||||
|
||||
|
||||
class ContentToProcess(_AliasSerializable):
|
||||
_ALIASES: ClassVar[dict[str, str]] = {
|
||||
"content_entries": "contentEntries",
|
||||
"activity_metadata": "activityMetadata",
|
||||
"device_metadata": "deviceMetadata",
|
||||
"integrated_app_metadata": "integratedAppMetadata",
|
||||
"protected_app_metadata": "protectedAppMetadata",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content_entries: list[ProcessConversationMetadata | MutableMapping[str, Any]],
|
||||
activity_metadata: ActivityMetadata | MutableMapping[str, Any],
|
||||
device_metadata: DeviceMetadata | MutableMapping[str, Any],
|
||||
integrated_app_metadata: IntegratedAppMetadata | MutableMapping[str, Any],
|
||||
protected_app_metadata: ProtectedAppMetadata | MutableMapping[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Extract aliased values from kwargs
|
||||
if "contentEntries" in kwargs:
|
||||
content_entries = kwargs["contentEntries"]
|
||||
if "activityMetadata" in kwargs:
|
||||
activity_metadata = kwargs["activityMetadata"]
|
||||
if "deviceMetadata" in kwargs:
|
||||
device_metadata = kwargs["deviceMetadata"]
|
||||
if "integratedAppMetadata" in kwargs:
|
||||
integrated_app_metadata = kwargs["integratedAppMetadata"]
|
||||
if "protectedAppMetadata" in kwargs:
|
||||
protected_app_metadata = kwargs["protectedAppMetadata"]
|
||||
|
||||
# Convert nested objects
|
||||
entries = [
|
||||
e if isinstance(e, ProcessConversationMetadata) else ProcessConversationMetadata(**e)
|
||||
for e in content_entries
|
||||
]
|
||||
if isinstance(activity_metadata, MutableMapping):
|
||||
activity_metadata = ActivityMetadata(**activity_metadata)
|
||||
if isinstance(device_metadata, MutableMapping):
|
||||
device_metadata = DeviceMetadata(**device_metadata)
|
||||
if isinstance(integrated_app_metadata, MutableMapping):
|
||||
integrated_app_metadata = IntegratedAppMetadata(**integrated_app_metadata)
|
||||
if isinstance(protected_app_metadata, MutableMapping):
|
||||
protected_app_metadata = ProtectedAppMetadata(**protected_app_metadata)
|
||||
|
||||
# Call parent without explicit params with aliases
|
||||
super().__init__(**kwargs)
|
||||
self.content_entries = entries
|
||||
self.activity_metadata = activity_metadata # type: ignore[assignment]
|
||||
self.device_metadata = device_metadata # type: ignore[assignment]
|
||||
self.integrated_app_metadata = integrated_app_metadata # type: ignore[assignment]
|
||||
self.protected_app_metadata = protected_app_metadata # type: ignore[assignment]
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# Request models
|
||||
# --------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ProcessContentRequest(_AliasSerializable):
|
||||
_ALIASES: ClassVar[dict[str, str]] = {"content_to_process": "contentToProcess"}
|
||||
DEFAULT_EXCLUDE: ClassVar[set[str]] = {"user_id", "tenant_id", "correlation_id", "process_inline"}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content_to_process: ContentToProcess | MutableMapping[str, Any],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
correlation_id: str | None = None,
|
||||
process_inline: bool | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Extract aliased values from kwargs
|
||||
if "contentToProcess" in kwargs:
|
||||
content_to_process = kwargs["contentToProcess"]
|
||||
|
||||
# Convert nested objects
|
||||
if isinstance(content_to_process, MutableMapping):
|
||||
content_to_process = ContentToProcess(**content_to_process)
|
||||
|
||||
# Call parent without explicit params with aliases
|
||||
super().__init__(**kwargs)
|
||||
self.content_to_process = content_to_process # type: ignore[assignment]
|
||||
self.user_id = user_id
|
||||
self.tenant_id = tenant_id
|
||||
self.correlation_id = correlation_id
|
||||
self.process_inline = process_inline
|
||||
|
||||
|
||||
class ProtectionScopesRequest(_AliasSerializable):
|
||||
DEFAULT_EXCLUDE: ClassVar[set[str]] = {"user_id", "tenant_id", "correlation_id", "scope_identifier"}
|
||||
_ALIASES: ClassVar[dict[str, str]] = {
|
||||
"pivot_on": "pivotOn",
|
||||
"device_metadata": "deviceMetadata",
|
||||
"integrated_app_metadata": "integratedAppMetadata",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
activities: ProtectionScopeActivities | str | int | Sequence[str] | None = None,
|
||||
locations: list[PolicyLocation | MutableMapping[str, Any]] | None = None,
|
||||
pivot_on: PolicyPivotProperty | None = None,
|
||||
device_metadata: DeviceMetadata | MutableMapping[str, Any] | None = None,
|
||||
integrated_app_metadata: IntegratedAppMetadata | MutableMapping[str, Any] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
scope_identifier: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Extract aliased values from kwargs
|
||||
if "pivotOn" in kwargs:
|
||||
pivot_on = kwargs["pivotOn"]
|
||||
if "deviceMetadata" in kwargs:
|
||||
device_metadata = kwargs["deviceMetadata"]
|
||||
if "integratedAppMetadata" in kwargs:
|
||||
integrated_app_metadata = kwargs["integratedAppMetadata"]
|
||||
|
||||
# Deserialize activities flag
|
||||
if not isinstance(activities, ProtectionScopeActivities) and activities is not None:
|
||||
activities = deserialize_flag(activities, _PROTECTION_SCOPE_ACTIVITIES_MAP, ProtectionScopeActivities)
|
||||
|
||||
# Convert nested objects
|
||||
if locations:
|
||||
locations = [loc if isinstance(loc, PolicyLocation) else PolicyLocation(**loc) for loc in locations]
|
||||
if isinstance(device_metadata, MutableMapping):
|
||||
device_metadata = DeviceMetadata(**device_metadata)
|
||||
if isinstance(integrated_app_metadata, MutableMapping):
|
||||
integrated_app_metadata = IntegratedAppMetadata(**integrated_app_metadata)
|
||||
|
||||
# Call parent without explicit params with aliases
|
||||
super().__init__(**kwargs)
|
||||
self.user_id = user_id
|
||||
self.tenant_id = tenant_id
|
||||
self.activities = activities # type: ignore[assignment]
|
||||
self.locations = locations
|
||||
self.pivot_on = pivot_on
|
||||
self.device_metadata = device_metadata
|
||||
self.integrated_app_metadata = integrated_app_metadata
|
||||
self.correlation_id = correlation_id
|
||||
self.scope_identifier = scope_identifier
|
||||
|
||||
def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: # type: ignore[override]
|
||||
# Get base dict (activities will be missing because Flag isn't JSON-serializable)
|
||||
base = super().to_dict(exclude=exclude, exclude_none=exclude_none)
|
||||
|
||||
# Manually serialize activities flag if present and not excluded
|
||||
if self.activities is not None or not exclude_none:
|
||||
if self.activities is not None:
|
||||
base["activities"] = serialize_flag(self.activities, _PROTECTION_SCOPE_ACTIVITIES_SERIALIZE_ORDER)
|
||||
elif not exclude_none:
|
||||
base["activities"] = None
|
||||
|
||||
return base
|
||||
|
||||
|
||||
class ContentActivitiesRequest(_AliasSerializable):
|
||||
_ALIASES: ClassVar[dict[str, str]] = {
|
||||
"user_id": "userId",
|
||||
"scope_identifier": "scopeIdentifier",
|
||||
"content_to_process": "contentMetadata",
|
||||
}
|
||||
DEFAULT_EXCLUDE: ClassVar[set[str]] = {"tenant_id", "correlation_id"}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str,
|
||||
content_to_process: ContentToProcess | MutableMapping[str, Any],
|
||||
tenant_id: str,
|
||||
id: str | None = None,
|
||||
scope_identifier: str | None = None,
|
||||
correlation_id: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Extract aliased values from kwargs
|
||||
if "userId" in kwargs:
|
||||
user_id = kwargs["userId"]
|
||||
if "scopeIdentifier" in kwargs:
|
||||
scope_identifier = kwargs["scopeIdentifier"]
|
||||
if "contentMetadata" in kwargs:
|
||||
content_to_process = kwargs["contentMetadata"]
|
||||
|
||||
# Convert nested objects
|
||||
if isinstance(content_to_process, MutableMapping):
|
||||
content_to_process = ContentToProcess(**content_to_process)
|
||||
|
||||
# Call parent without explicit params with aliases
|
||||
super().__init__(**kwargs)
|
||||
self.id = id or str(uuid4())
|
||||
self.user_id = user_id
|
||||
self.content_to_process = content_to_process # type: ignore[assignment]
|
||||
self.tenant_id = tenant_id
|
||||
self.scope_identifier = scope_identifier
|
||||
self.correlation_id = correlation_id
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# Response models
|
||||
# --------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ErrorDetails(_AliasSerializable):
|
||||
def __init__(self, code: str | None = None, message: str | None = None, **kwargs: Any) -> None:
|
||||
super().__init__(code=code, message=message, **kwargs)
|
||||
self.code = code
|
||||
self.message = message
|
||||
|
||||
|
||||
class ProcessingError(_AliasSerializable):
|
||||
def __init__(self, message: str | None = None, **kwargs: Any) -> None:
|
||||
super().__init__(message=message, **kwargs)
|
||||
self.message = message
|
||||
|
||||
|
||||
class ProcessContentResponse(_AliasSerializable):
|
||||
_ALIASES: ClassVar[dict[str, str]] = {
|
||||
"protection_scope_state": "protectionScopeState",
|
||||
"policy_actions": "policyActions",
|
||||
"processing_errors": "processingErrors",
|
||||
}
|
||||
|
||||
id: str | None
|
||||
protection_scope_state: ProtectionScopeState | None
|
||||
policy_actions: list[DlpActionInfo] | None
|
||||
processing_errors: list[ProcessingError] | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str | None = None,
|
||||
protection_scope_state: ProtectionScopeState | None = None,
|
||||
policy_actions: list[DlpActionInfo | MutableMapping[str, Any]] | None = None,
|
||||
processing_errors: list[ProcessingError | MutableMapping[str, Any]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Extract aliased values from kwargs
|
||||
if "protectionScopeState" in kwargs:
|
||||
protection_scope_state = kwargs["protectionScopeState"]
|
||||
if "policyActions" in kwargs:
|
||||
policy_actions = kwargs["policyActions"]
|
||||
if "processingErrors" in kwargs:
|
||||
processing_errors = kwargs["processingErrors"]
|
||||
|
||||
# Convert to objects
|
||||
converted_policy_actions: list[DlpActionInfo] | None = None
|
||||
if policy_actions is not None:
|
||||
converted_policy_actions = cast(
|
||||
list[DlpActionInfo],
|
||||
[p if isinstance(p, DlpActionInfo) else DlpActionInfo(**p) for p in policy_actions],
|
||||
)
|
||||
|
||||
converted_processing_errors: list[ProcessingError] | None = None
|
||||
if processing_errors is not None:
|
||||
converted_processing_errors = cast(
|
||||
list[ProcessingError],
|
||||
[pe if isinstance(pe, ProcessingError) else ProcessingError(**pe) for pe in processing_errors],
|
||||
)
|
||||
|
||||
# Call parent without explicit params with aliases
|
||||
super().__init__(**kwargs)
|
||||
self.id = id
|
||||
self.protection_scope_state = protection_scope_state
|
||||
self.policy_actions = converted_policy_actions
|
||||
self.processing_errors = converted_processing_errors
|
||||
|
||||
|
||||
class PolicyScope(_AliasSerializable):
|
||||
_ALIASES: ClassVar[dict[str, str]] = {"policy_actions": "policyActions", "execution_mode": "executionMode"}
|
||||
|
||||
activities: ProtectionScopeActivities | None
|
||||
locations: list[PolicyLocation] | None
|
||||
policy_actions: list[DlpActionInfo] | None
|
||||
execution_mode: ExecutionMode | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
activities: ProtectionScopeActivities | str | int | Sequence[str] | None = None,
|
||||
locations: list[PolicyLocation | MutableMapping[str, Any]] | None = None,
|
||||
policy_actions: list[DlpActionInfo | MutableMapping[str, Any]] | None = None,
|
||||
execution_mode: ExecutionMode | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Extract aliased values from kwargs
|
||||
if "policyActions" in kwargs:
|
||||
policy_actions = kwargs["policyActions"]
|
||||
if "executionMode" in kwargs:
|
||||
execution_mode = kwargs["executionMode"]
|
||||
|
||||
# Deserialize activities flag
|
||||
if not isinstance(activities, ProtectionScopeActivities) and activities is not None:
|
||||
activities = deserialize_flag(activities, _PROTECTION_SCOPE_ACTIVITIES_MAP, ProtectionScopeActivities)
|
||||
|
||||
# Convert nested objects
|
||||
converted_locations: list[PolicyLocation] | None = None
|
||||
if locations is not None:
|
||||
converted_locations = cast(
|
||||
list[PolicyLocation],
|
||||
[loc if isinstance(loc, PolicyLocation) else PolicyLocation(**loc) for loc in locations],
|
||||
)
|
||||
|
||||
converted_policy_actions: list[DlpActionInfo] | None = None
|
||||
if policy_actions is not None:
|
||||
converted_policy_actions = cast(
|
||||
list[DlpActionInfo],
|
||||
[p if isinstance(p, DlpActionInfo) else DlpActionInfo(**p) for p in policy_actions],
|
||||
)
|
||||
|
||||
# Call parent without explicit params with aliases
|
||||
super().__init__(**kwargs)
|
||||
self.activities = activities # type: ignore[assignment]
|
||||
self.locations = converted_locations
|
||||
self.policy_actions = converted_policy_actions
|
||||
self.execution_mode = execution_mode
|
||||
|
||||
def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: # type: ignore[override]
|
||||
# Get base dict (activities will be missing because Flag isn't JSON-serializable)
|
||||
base = super().to_dict(exclude=exclude, exclude_none=exclude_none)
|
||||
|
||||
# Manually serialize activities flag if present and not excluded
|
||||
if self.activities is not None or not exclude_none:
|
||||
if self.activities is not None:
|
||||
base["activities"] = serialize_flag(self.activities, _PROTECTION_SCOPE_ACTIVITIES_SERIALIZE_ORDER)
|
||||
elif not exclude_none:
|
||||
base["activities"] = None
|
||||
|
||||
return base
|
||||
|
||||
|
||||
class ProtectionScopesResponse(_AliasSerializable):
|
||||
_ALIASES: ClassVar[dict[str, str]] = {"scope_identifier": "scopeIdentifier", "scopes": "value"}
|
||||
|
||||
scope_identifier: str | None
|
||||
scopes: list[PolicyScope] | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scope_identifier: str | None = None,
|
||||
scopes: list[PolicyScope | MutableMapping[str, Any]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Extract aliased values from kwargs before they're normalized by parent
|
||||
if "scopeIdentifier" in kwargs:
|
||||
scope_identifier = kwargs["scopeIdentifier"]
|
||||
if "value" in kwargs:
|
||||
scopes = kwargs["value"]
|
||||
|
||||
converted_scopes: list[PolicyScope] | None = None
|
||||
if scopes is not None:
|
||||
converted_scopes = cast(
|
||||
list[PolicyScope], [s if isinstance(s, PolicyScope) else PolicyScope(**s) for s in scopes]
|
||||
)
|
||||
|
||||
# Don't pass parameters that have aliases - let parent normalize them
|
||||
super().__init__(**kwargs)
|
||||
self.scope_identifier = scope_identifier
|
||||
self.scopes = converted_scopes
|
||||
|
||||
|
||||
class ContentActivitiesResponse(_AliasSerializable):
|
||||
DEFAULT_EXCLUDE: ClassVar[set[str]] = {"status_code"}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int | None = None,
|
||||
error: ErrorDetails | MutableMapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if isinstance(error, MutableMapping):
|
||||
error = ErrorDetails(**error)
|
||||
super().__init__(status_code=status_code, error=error, **kwargs)
|
||||
self.status_code = status_code
|
||||
self.error = error # type: ignore[assignment]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AccessedResourceDetails",
|
||||
"Activity",
|
||||
"ActivityMetadata",
|
||||
"AiAgentInfo",
|
||||
"AiInteractionPlugin",
|
||||
"ContentActivitiesRequest",
|
||||
"ContentActivitiesResponse",
|
||||
"ContentBase",
|
||||
"ContentToProcess",
|
||||
"DeviceMetadata",
|
||||
"DlpAction",
|
||||
"DlpActionInfo",
|
||||
"ExecutionMode",
|
||||
"GraphDataTypeBase",
|
||||
"IntegratedAppMetadata",
|
||||
"OperatingSystemSpecifications",
|
||||
"PolicyLocation",
|
||||
"PolicyPivotProperty",
|
||||
"PolicyScope",
|
||||
"ProcessContentRequest",
|
||||
"ProcessContentResponse",
|
||||
"ProcessConversationMetadata",
|
||||
"ProcessingError",
|
||||
"ProtectedAppMetadata",
|
||||
"ProtectionScopeActivities",
|
||||
"ProtectionScopeState",
|
||||
"ProtectionScopesRequest",
|
||||
"ProtectionScopesResponse",
|
||||
"PurviewBinaryContent",
|
||||
"PurviewTextContent",
|
||||
"RestrictionAction",
|
||||
"deserialize_flag",
|
||||
"serialize_flag",
|
||||
"translate_activity",
|
||||
]
|
||||
@@ -0,0 +1,251 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import Iterable, MutableMapping
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import ChatMessage
|
||||
|
||||
from ._client import PurviewClient
|
||||
from ._models import (
|
||||
Activity,
|
||||
ActivityMetadata,
|
||||
ContentActivitiesRequest,
|
||||
ContentToProcess,
|
||||
DeviceMetadata,
|
||||
DlpAction,
|
||||
DlpActionInfo,
|
||||
IntegratedAppMetadata,
|
||||
OperatingSystemSpecifications,
|
||||
PolicyLocation,
|
||||
ProcessContentRequest,
|
||||
ProcessContentResponse,
|
||||
ProcessConversationMetadata,
|
||||
ProcessingError,
|
||||
ProtectedAppMetadata,
|
||||
ProtectionScopesRequest,
|
||||
ProtectionScopesResponse,
|
||||
PurviewTextContent,
|
||||
RestrictionAction,
|
||||
translate_activity,
|
||||
)
|
||||
from ._settings import PurviewSettings
|
||||
|
||||
|
||||
def _is_valid_guid(value: str | None) -> bool:
|
||||
"""Check if a string is a valid GUID/UUID format using uuid module."""
|
||||
if not value:
|
||||
return False
|
||||
try:
|
||||
uuid.UUID(value)
|
||||
return True
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
|
||||
|
||||
class ScopedContentProcessor:
|
||||
"""Combine protection scopes, process content, and content activities logic."""
|
||||
|
||||
def __init__(self, client: PurviewClient, settings: PurviewSettings):
|
||||
self._client = client
|
||||
self._settings = settings
|
||||
|
||||
async def process_messages(
|
||||
self, messages: Iterable[ChatMessage], activity: Activity, user_id: str | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Process messages for policy evaluation.
|
||||
|
||||
Args:
|
||||
messages: The messages to process
|
||||
activity: The activity type (e.g., UPLOAD_TEXT)
|
||||
user_id: Optional user_id to use for all messages. If provided, this is the fallback.
|
||||
|
||||
Returns:
|
||||
A tuple of (should_block: bool, resolved_user_id: str | None).
|
||||
The resolved_user_id can be stored and passed back when processing the response
|
||||
to ensure the same user context is maintained throughout the request/response cycle.
|
||||
"""
|
||||
pc_requests, resolved_user_id = await self._map_messages(messages, activity, user_id)
|
||||
should_block = False
|
||||
for req in pc_requests:
|
||||
resp = await self._process_with_scopes(req)
|
||||
if resp.policy_actions:
|
||||
for act in resp.policy_actions:
|
||||
if act.action == DlpAction.BLOCK_ACCESS or act.restriction_action == RestrictionAction.BLOCK:
|
||||
should_block = True
|
||||
break
|
||||
if should_block:
|
||||
break
|
||||
return should_block, resolved_user_id
|
||||
|
||||
async def _map_messages(
|
||||
self, messages: Iterable[ChatMessage], activity: Activity, provided_user_id: str | None = None
|
||||
) -> tuple[list[ProcessContentRequest], str | None]:
|
||||
"""Map messages to ProcessContentRequests.
|
||||
|
||||
Args:
|
||||
messages: The messages to map
|
||||
activity: The activity type
|
||||
provided_user_id: Optional user_id to use. If provided, this is the fallback.
|
||||
|
||||
Returns:
|
||||
A tuple of (requests, resolved_user_id)
|
||||
"""
|
||||
results: list[ProcessContentRequest] = []
|
||||
token_info = None
|
||||
|
||||
if not (self._settings.tenant_id and self._settings.purview_app_location):
|
||||
token_info = await self._client.get_user_info_from_token(tenant_id=self._settings.tenant_id)
|
||||
|
||||
tenant_id = (token_info or {}).get("tenant_id") or self._settings.tenant_id
|
||||
if not tenant_id or not _is_valid_guid(tenant_id):
|
||||
raise ValueError("Tenant id required or must be inferable from credential")
|
||||
|
||||
resolved_user_id = (token_info or {}).get("user_id")
|
||||
resolved_author_name = None
|
||||
if not resolved_user_id:
|
||||
for m in messages:
|
||||
if m.additional_properties:
|
||||
potential_user_id = m.additional_properties.get("user_id")
|
||||
if _is_valid_guid(potential_user_id):
|
||||
resolved_user_id = potential_user_id
|
||||
break
|
||||
if m.author_name and _is_valid_guid(m.author_name) and not resolved_author_name:
|
||||
resolved_author_name = m.author_name
|
||||
|
||||
if not resolved_user_id and resolved_author_name:
|
||||
resolved_user_id = resolved_author_name
|
||||
|
||||
if not resolved_user_id:
|
||||
resolved_user_id = provided_user_id if provided_user_id and _is_valid_guid(provided_user_id) else None
|
||||
|
||||
# Return empty results if user_id is empty
|
||||
if not resolved_user_id or not _is_valid_guid(resolved_user_id):
|
||||
return results, None
|
||||
|
||||
for m in messages:
|
||||
message_id = m.message_id or str(uuid.uuid4())
|
||||
content = PurviewTextContent(data=m.text or "")
|
||||
meta = ProcessConversationMetadata(
|
||||
identifier=message_id,
|
||||
content=content,
|
||||
name=f"Agent Framework Message {message_id}",
|
||||
is_truncated=False,
|
||||
correlation_id=str(uuid.uuid4()),
|
||||
)
|
||||
activity_meta = ActivityMetadata(activity=activity)
|
||||
|
||||
if self._settings.purview_app_location:
|
||||
policy_location = PolicyLocation(
|
||||
data_type=self._settings.purview_app_location.get_policy_location()["@odata.type"],
|
||||
value=self._settings.purview_app_location.location_value,
|
||||
)
|
||||
elif token_info and token_info.get("client_id"):
|
||||
policy_location = PolicyLocation(
|
||||
data_type="microsoft.graph.policyLocationApplication",
|
||||
value=token_info["client_id"],
|
||||
)
|
||||
else:
|
||||
raise ValueError("App location not provided or inferable")
|
||||
|
||||
protected_app = ProtectedAppMetadata(
|
||||
name=self._settings.app_name,
|
||||
version="1.0",
|
||||
application_location=policy_location,
|
||||
)
|
||||
integrated_app = IntegratedAppMetadata(name=self._settings.app_name, version="1.0")
|
||||
device_meta = DeviceMetadata(
|
||||
operating_system_specifications=OperatingSystemSpecifications(
|
||||
operating_system_platform="Unknown", operating_system_version="Unknown"
|
||||
)
|
||||
)
|
||||
|
||||
ctp = ContentToProcess(
|
||||
content_entries=[meta],
|
||||
activity_metadata=activity_meta,
|
||||
device_metadata=device_meta,
|
||||
integrated_app_metadata=integrated_app,
|
||||
protected_app_metadata=protected_app,
|
||||
)
|
||||
req = ProcessContentRequest(
|
||||
content_to_process=ctp,
|
||||
user_id=resolved_user_id, # Use the resolved user_id for all messages
|
||||
tenant_id=tenant_id,
|
||||
correlation_id=meta.correlation_id,
|
||||
process_inline=True if self._settings.process_inline else None,
|
||||
)
|
||||
results.append(req)
|
||||
return results, resolved_user_id
|
||||
|
||||
async def _process_with_scopes(self, pc_request: ProcessContentRequest) -> ProcessContentResponse:
|
||||
app_location = pc_request.content_to_process.protected_app_metadata.application_location
|
||||
locations: list[PolicyLocation | MutableMapping[str, Any]] = [app_location] if app_location is not None else []
|
||||
|
||||
ps_req = ProtectionScopesRequest(
|
||||
user_id=pc_request.user_id,
|
||||
tenant_id=pc_request.tenant_id,
|
||||
activities=translate_activity(pc_request.content_to_process.activity_metadata.activity),
|
||||
locations=locations,
|
||||
device_metadata=pc_request.content_to_process.device_metadata,
|
||||
integrated_app_metadata=pc_request.content_to_process.integrated_app_metadata,
|
||||
correlation_id=pc_request.correlation_id,
|
||||
)
|
||||
ps_resp = await self._client.get_protection_scopes(ps_req)
|
||||
should_process, dlp_actions = self._check_applicable_scopes(pc_request, ps_resp)
|
||||
|
||||
if should_process:
|
||||
pc_resp = await self._client.process_content(pc_request)
|
||||
pc_resp.policy_actions = self._combine_policy_actions(pc_resp.policy_actions, dlp_actions)
|
||||
return pc_resp
|
||||
ca_req = ContentActivitiesRequest(
|
||||
user_id=pc_request.user_id,
|
||||
tenant_id=pc_request.tenant_id,
|
||||
content_to_process=pc_request.content_to_process,
|
||||
correlation_id=pc_request.correlation_id,
|
||||
)
|
||||
ca_resp = await self._client.send_content_activities(ca_req)
|
||||
if ca_resp.error:
|
||||
return ProcessContentResponse(processing_errors=[ProcessingError(message=str(ca_resp.error))])
|
||||
return ProcessContentResponse()
|
||||
|
||||
@staticmethod
|
||||
def _combine_policy_actions(
|
||||
existing: list[DlpActionInfo] | None, new_actions: list[DlpActionInfo]
|
||||
) -> list[DlpActionInfo]:
|
||||
by_key: dict[str, DlpActionInfo] = {}
|
||||
for a in existing or []:
|
||||
if a.action:
|
||||
by_key[a.action] = a
|
||||
for a in new_actions:
|
||||
if a.action:
|
||||
by_key[a.action] = a
|
||||
return list(by_key.values())
|
||||
|
||||
@staticmethod
|
||||
def _check_applicable_scopes(
|
||||
pc_request: ProcessContentRequest, ps_response: ProtectionScopesResponse
|
||||
) -> tuple[bool, list[DlpActionInfo]]:
|
||||
req_activity = translate_activity(pc_request.content_to_process.activity_metadata.activity)
|
||||
location = pc_request.content_to_process.protected_app_metadata.application_location
|
||||
should_process: bool = False
|
||||
dlp_actions: list[DlpActionInfo] = []
|
||||
for scope in ps_response.scopes or []:
|
||||
# Check if all activities in req_activity are present in scope.activities using bitwise flags.
|
||||
activity_match = bool(scope.activities and (scope.activities & req_activity) == req_activity)
|
||||
location_match = False
|
||||
if location is not None:
|
||||
for loc in scope.locations or []:
|
||||
if (
|
||||
loc.data_type
|
||||
and location.data_type
|
||||
and loc.data_type.lower().endswith(location.data_type.split(".")[-1].lower())
|
||||
and loc.value == location.value
|
||||
):
|
||||
location_match = True
|
||||
break
|
||||
if activity_match and location_match:
|
||||
should_process = True
|
||||
if scope.policy_actions:
|
||||
dlp_actions.extend(scope.policy_actions)
|
||||
return should_process, dlp_actions
|
||||
@@ -0,0 +1,71 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from agent_framework._pydantic import AFBaseSettings
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import SettingsConfigDict
|
||||
|
||||
|
||||
class PurviewLocationType(str, Enum):
|
||||
"""The type of location for Purview policy evaluation."""
|
||||
|
||||
APPLICATION = "application"
|
||||
URI = "uri"
|
||||
DOMAIN = "domain"
|
||||
|
||||
|
||||
class PurviewAppLocation(BaseModel):
|
||||
"""Identifier representing the app's location for Purview policy evaluation."""
|
||||
|
||||
location_type: PurviewLocationType = Field(..., description="The location type.")
|
||||
location_value: str = Field(..., description="The location value.")
|
||||
|
||||
def get_policy_location(self) -> dict[str, str]:
|
||||
ns = "microsoft.graph"
|
||||
if self.location_type == PurviewLocationType.APPLICATION:
|
||||
dt = f"{ns}.policyLocationApplication"
|
||||
elif self.location_type == PurviewLocationType.URI:
|
||||
dt = f"{ns}.policyLocationUrl"
|
||||
elif self.location_type == PurviewLocationType.DOMAIN:
|
||||
dt = f"{ns}.policyLocationDomain"
|
||||
else: # pragma: no cover - defensive
|
||||
raise ValueError("Invalid Purview location type")
|
||||
return {"@odata.type": dt, "value": self.location_value}
|
||||
|
||||
|
||||
class PurviewSettings(AFBaseSettings):
|
||||
"""Settings for Purview integration mirroring .NET PurviewSettings.
|
||||
|
||||
Attributes:
|
||||
app_name: Public app name.
|
||||
tenant_id: Optional tenant id (guid) of the user making the request.
|
||||
purview_app_location: Optional app location for policy evaluation.
|
||||
graph_base_uri: Base URI for Microsoft Graph.
|
||||
blocked_prompt_message: Custom message to return when a prompt is blocked by policy.
|
||||
blocked_response_message: Custom message to return when a response is blocked by policy.
|
||||
"""
|
||||
|
||||
app_name: str = Field(...)
|
||||
tenant_id: str | None = Field(default=None)
|
||||
purview_app_location: PurviewAppLocation | None = Field(default=None)
|
||||
graph_base_uri: str = Field(default="https://graph.microsoft.com/v1.0/")
|
||||
process_inline: bool = Field(default=False, description="Process content inline if supported.")
|
||||
blocked_prompt_message: str = Field(
|
||||
default="Prompt blocked by policy",
|
||||
description="Message to return when a prompt is blocked by policy.",
|
||||
)
|
||||
blocked_response_message: str = Field(
|
||||
default="Response blocked by policy",
|
||||
description="Message to return when a response is blocked by policy.",
|
||||
)
|
||||
|
||||
model_config = SettingsConfigDict(populate_by_name=True, validate_assignment=True)
|
||||
|
||||
def get_scopes(self) -> list[str]:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
host = urlparse(self.graph_base_uri).hostname or "graph.microsoft.com"
|
||||
return [f"https://{host}/.default"]
|
||||
@@ -0,0 +1,87 @@
|
||||
[project]
|
||||
name = "agent-framework-purview"
|
||||
description = "Microsoft Purview (Graph dataSecurityAndGovernance) integration for Microsoft Agent Framework."
|
||||
authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}]
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
version = "0.1.0b1"
|
||||
license-files = ["LICENSE"]
|
||||
urls.homepage = "https://github.com/microsoft/agent-framework"
|
||||
urls.source = "https://github.com/microsoft/agent-framework/tree/main/python"
|
||||
urls.release_notes = "https://github.com/microsoft/agent-framework/releases"
|
||||
urls.issues = "https://github.com/microsoft/agent-framework/issues"
|
||||
classifiers = [
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Framework :: Pydantic :: 2",
|
||||
"Typing :: Typed",
|
||||
]
|
||||
dependencies = [
|
||||
"agent-framework-core",
|
||||
"azure-core>=1.30.0",
|
||||
"httpx>=0.27.0",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
prerelease = "if-necessary-or-explicit"
|
||||
environments = [
|
||||
"sys_platform == 'darwin'",
|
||||
"sys_platform == 'linux'",
|
||||
"sys_platform == 'win32'"
|
||||
]
|
||||
|
||||
[tool.uv-dynamic-versioning]
|
||||
fallback-version = "0.0.0"
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
addopts = "-ra -q -r fEX"
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
filterwarnings = []
|
||||
|
||||
[tool.ruff]
|
||||
extend = "../../pyproject.toml"
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = [
|
||||
"**/__init__.py"
|
||||
]
|
||||
|
||||
[tool.pyright]
|
||||
extend = "../../pyproject.toml"
|
||||
exclude = ['tests']
|
||||
|
||||
[tool.mypy]
|
||||
plugins = ['pydantic.mypy']
|
||||
strict = true
|
||||
python_version = "3.10"
|
||||
ignore_missing_imports = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
check_untyped_defs = true
|
||||
warn_return_any = true
|
||||
show_error_codes = true
|
||||
warn_unused_ignores = false
|
||||
disallow_incomplete_defs = true
|
||||
disallow_untyped_decorators = true
|
||||
|
||||
[tool.bandit]
|
||||
targets = ["agent_framework_purview"]
|
||||
exclude_dirs = ["tests"]
|
||||
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_purview"
|
||||
test = "pytest --cov=agent_framework_purview --cov-report=term-missing:skip-covered tests"
|
||||
|
||||
[build-system]
|
||||
requires = ["flit-core >= 3.9,<4.0"]
|
||||
build-backend = "flit_core.buildapi"
|
||||
@@ -0,0 +1,68 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
"""Shared pytest fixtures for Purview tests."""
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework_purview._models import (
|
||||
Activity,
|
||||
ActivityMetadata,
|
||||
ContentToProcess,
|
||||
DeviceMetadata,
|
||||
IntegratedAppMetadata,
|
||||
OperatingSystemSpecifications,
|
||||
PolicyLocation,
|
||||
ProcessContentRequest,
|
||||
ProcessConversationMetadata,
|
||||
ProtectedAppMetadata,
|
||||
PurviewTextContent,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def content_to_process_factory():
|
||||
"""Factory fixture to create ContentToProcess objects with test data."""
|
||||
|
||||
def _create_content(text: str = "Test") -> ContentToProcess:
|
||||
text_content = PurviewTextContent(data=text)
|
||||
metadata = ProcessConversationMetadata(
|
||||
identifier="msg-1",
|
||||
content=text_content,
|
||||
name="Test",
|
||||
is_truncated=False,
|
||||
)
|
||||
activity_meta = ActivityMetadata(activity=Activity.UPLOAD_TEXT)
|
||||
device_meta = DeviceMetadata(
|
||||
operating_system_specifications=OperatingSystemSpecifications(
|
||||
operating_system_platform="Windows", operating_system_version="10"
|
||||
)
|
||||
)
|
||||
integrated_app = IntegratedAppMetadata(name="App", version="1.0")
|
||||
location = PolicyLocation(data_type="microsoft.graph.policyLocationApplication", value="app-id")
|
||||
protected_app = ProtectedAppMetadata(name="Protected", version="1.0", application_location=location)
|
||||
|
||||
return ContentToProcess(
|
||||
content_entries=[metadata],
|
||||
activity_metadata=activity_meta,
|
||||
device_metadata=device_meta,
|
||||
integrated_app_metadata=integrated_app,
|
||||
protected_app_metadata=protected_app,
|
||||
)
|
||||
|
||||
return _create_content
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def process_content_request_factory(content_to_process_factory):
|
||||
"""Factory fixture to create ProcessContentRequest objects with test data."""
|
||||
|
||||
def _create_request(
|
||||
text: str = "Test", user_id: str = "user-123", tenant_id: str = "tenant-456"
|
||||
) -> ProcessContentRequest:
|
||||
content = content_to_process_factory(text)
|
||||
return ProcessContentRequest(
|
||||
content_to_process=content,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
return _create_request
|
||||
@@ -0,0 +1,149 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
"""Tests for Purview chat middleware."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from agent_framework import ChatContext, ChatMessage, Role
|
||||
from azure.core.credentials import AccessToken
|
||||
|
||||
from agent_framework_purview import PurviewChatPolicyMiddleware, PurviewSettings
|
||||
|
||||
|
||||
@dataclass
|
||||
class DummyChatClient:
|
||||
name: str = "dummy"
|
||||
|
||||
|
||||
class TestPurviewChatPolicyMiddleware:
|
||||
@pytest.fixture
|
||||
def mock_credential(self) -> AsyncMock:
|
||||
credential = AsyncMock()
|
||||
credential.get_token = AsyncMock(return_value=AccessToken("fake-token", 9999999999))
|
||||
return credential
|
||||
|
||||
@pytest.fixture
|
||||
def settings(self) -> PurviewSettings:
|
||||
return PurviewSettings(app_name="Test App", tenant_id="test-tenant")
|
||||
|
||||
@pytest.fixture
|
||||
def middleware(self, mock_credential: AsyncMock, settings: PurviewSettings) -> PurviewChatPolicyMiddleware:
|
||||
return PurviewChatPolicyMiddleware(mock_credential, settings)
|
||||
|
||||
@pytest.fixture
|
||||
def chat_context(self) -> ChatContext:
|
||||
chat_client = DummyChatClient()
|
||||
chat_options = MagicMock()
|
||||
chat_options.model = "test-model"
|
||||
return ChatContext(
|
||||
chat_client=chat_client, messages=[ChatMessage(role=Role.USER, text="Hello")], chat_options=chat_options
|
||||
)
|
||||
|
||||
async def test_initialization(self, middleware: PurviewChatPolicyMiddleware) -> None:
|
||||
assert middleware._client is not None
|
||||
assert middleware._processor is not None
|
||||
|
||||
async def test_allows_clean_prompt(
|
||||
self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext
|
||||
) -> None:
|
||||
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc:
|
||||
next_called = False
|
||||
|
||||
async def mock_next(ctx: ChatContext) -> None:
|
||||
nonlocal next_called
|
||||
next_called = True
|
||||
|
||||
class Result:
|
||||
def __init__(self):
|
||||
self.messages = [ChatMessage(role=Role.ASSISTANT, text="Hi there")]
|
||||
|
||||
ctx.result = Result()
|
||||
|
||||
await middleware.process(chat_context, mock_next)
|
||||
assert next_called
|
||||
assert mock_proc.call_count == 2
|
||||
assert chat_context.result.messages[0].role == Role.ASSISTANT
|
||||
|
||||
async def test_blocks_prompt(self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext) -> None:
|
||||
with patch.object(middleware._processor, "process_messages", return_value=(True, "user-123")):
|
||||
|
||||
async def mock_next(ctx: ChatContext) -> None: # should not run
|
||||
raise AssertionError("next should not be called when prompt blocked")
|
||||
|
||||
await middleware.process(chat_context, mock_next)
|
||||
assert chat_context.terminate
|
||||
assert chat_context.result
|
||||
msg = chat_context.result[0] # type: ignore[index]
|
||||
assert msg.role in ("system", Role.SYSTEM)
|
||||
assert "blocked" in msg.text.lower()
|
||||
|
||||
async def test_blocks_response(self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext) -> None:
|
||||
call_state = {"count": 0}
|
||||
|
||||
async def side_effect(messages, activity, user_id=None):
|
||||
call_state["count"] += 1
|
||||
should_block = call_state["count"] == 2
|
||||
return (should_block, "user-123")
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", side_effect=side_effect):
|
||||
|
||||
async def mock_next(ctx: ChatContext) -> None:
|
||||
class Result:
|
||||
def __init__(self):
|
||||
self.messages = [ChatMessage(role=Role.ASSISTANT, text="Sensitive output")] # pragma: no cover
|
||||
|
||||
ctx.result = Result()
|
||||
|
||||
await middleware.process(chat_context, mock_next)
|
||||
assert call_state["count"] == 2
|
||||
msgs = getattr(chat_context.result, "messages", None) or chat_context.result
|
||||
first_msg = msgs[0]
|
||||
assert first_msg.role in ("system", Role.SYSTEM)
|
||||
assert "blocked" in first_msg.text.lower()
|
||||
|
||||
async def test_streaming_skips_post_check(self, middleware: PurviewChatPolicyMiddleware) -> None:
|
||||
chat_client = DummyChatClient()
|
||||
chat_options = MagicMock()
|
||||
chat_options.model = "test-model"
|
||||
streaming_context = ChatContext(
|
||||
chat_client=chat_client,
|
||||
messages=[ChatMessage(role=Role.USER, text="Hello")],
|
||||
chat_options=chat_options,
|
||||
is_streaming=True,
|
||||
)
|
||||
with patch.object(middleware._processor, "process_messages", return_value=False) as mock_proc:
|
||||
|
||||
async def mock_next(ctx: ChatContext) -> None:
|
||||
ctx.result = MagicMock()
|
||||
|
||||
await middleware.process(streaming_context, mock_next)
|
||||
assert mock_proc.call_count == 1
|
||||
|
||||
async def test_chat_middleware_handles_post_check_exception(
|
||||
self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext
|
||||
) -> None:
|
||||
"""Test that exceptions in post-check are logged but don't affect result."""
|
||||
call_count = 0
|
||||
|
||||
async def mock_process_messages(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return (False, "user-123") # Pre-check succeeds
|
||||
raise Exception("Post-check error") # Post-check fails
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
|
||||
|
||||
async def mock_next(ctx: ChatContext) -> None:
|
||||
# Create a mock result with messages attribute
|
||||
result = MagicMock()
|
||||
result.messages = [ChatMessage(role=Role.ASSISTANT, text="Response")]
|
||||
ctx.result = result
|
||||
|
||||
await middleware.process(chat_context, mock_next)
|
||||
|
||||
# Should have been called twice (pre and post)
|
||||
assert call_count == 2
|
||||
# Result should still be set
|
||||
assert chat_context.result is not None
|
||||
@@ -0,0 +1,238 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for Purview client."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from azure.core.credentials import AccessToken
|
||||
|
||||
from agent_framework_purview import PurviewSettings
|
||||
from agent_framework_purview._client import PurviewClient
|
||||
from agent_framework_purview._exceptions import (
|
||||
PurviewAuthenticationError,
|
||||
PurviewRateLimitError,
|
||||
PurviewRequestError,
|
||||
PurviewServiceError,
|
||||
)
|
||||
from agent_framework_purview._models import (
|
||||
PolicyLocation,
|
||||
ProcessContentRequest,
|
||||
ProtectionScopesRequest,
|
||||
)
|
||||
|
||||
|
||||
class TestPurviewClient:
|
||||
"""Test PurviewClient functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credential(self) -> MagicMock:
|
||||
"""Create a mock async credential."""
|
||||
from azure.core.credentials_async import AsyncTokenCredential
|
||||
|
||||
credential = MagicMock(spec=AsyncTokenCredential)
|
||||
mock_token = AccessToken("fake-token", 9999999999)
|
||||
|
||||
async def mock_get_token(*args, **kwargs):
|
||||
return mock_token
|
||||
|
||||
credential.get_token = mock_get_token
|
||||
return credential
|
||||
|
||||
@pytest.fixture
|
||||
def settings(self) -> PurviewSettings:
|
||||
"""Create test settings."""
|
||||
return PurviewSettings(app_name="Test App", tenant_id="test-tenant", default_user_id="test-user")
|
||||
|
||||
@pytest.fixture
|
||||
async def client(self, mock_credential: MagicMock, settings: PurviewSettings) -> PurviewClient:
|
||||
"""Create a PurviewClient with mock credential."""
|
||||
client = PurviewClient(mock_credential, settings, timeout=10.0)
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
async def test_client_initialization(self, mock_credential: MagicMock, settings: PurviewSettings) -> None:
|
||||
"""Test PurviewClient initialization."""
|
||||
client = PurviewClient(mock_credential, settings)
|
||||
|
||||
assert client._credential == mock_credential
|
||||
assert client._settings == settings
|
||||
assert client._graph_uri == "https://graph.microsoft.com/v1.0"
|
||||
assert client._timeout == 10.0
|
||||
|
||||
await client.close()
|
||||
|
||||
async def test_get_token_async_credential(self, client: PurviewClient, mock_credential: MagicMock) -> None:
|
||||
"""Test _get_token with async credential."""
|
||||
token = await client._get_token(tenant_id="test-tenant")
|
||||
|
||||
assert token == "fake-token"
|
||||
|
||||
async def test_get_token_sync_credential(self, settings: PurviewSettings) -> None:
|
||||
"""Test _get_token with sync credential."""
|
||||
sync_credential = MagicMock()
|
||||
sync_credential.get_token = MagicMock(return_value=AccessToken("sync-token", 9999999999))
|
||||
|
||||
client = PurviewClient(sync_credential, settings)
|
||||
|
||||
with patch("asyncio.get_running_loop") as mock_loop:
|
||||
mock_executor = AsyncMock()
|
||||
mock_executor.return_value = AccessToken("sync-token", 9999999999)
|
||||
mock_loop.return_value.run_in_executor = mock_executor
|
||||
|
||||
token = await client._get_token(tenant_id="test-tenant")
|
||||
|
||||
assert token == "sync-token"
|
||||
|
||||
await client.close()
|
||||
|
||||
async def test_get_user_info_from_token(self, client: PurviewClient) -> None:
|
||||
"""Test get_user_info_from_token extracts user info."""
|
||||
import base64
|
||||
import json
|
||||
|
||||
payload = {"tid": "test-tenant", "oid": "test-user", "idtyp": "user"}
|
||||
payload_str = json.dumps(payload)
|
||||
payload_bytes = payload_str.encode("utf-8")
|
||||
payload_b64 = base64.urlsafe_b64encode(payload_bytes).decode("utf-8").rstrip("=")
|
||||
fake_token = f"header.{payload_b64}.signature"
|
||||
|
||||
with patch.object(client, "_get_token", return_value=fake_token):
|
||||
user_info = await client.get_user_info_from_token(tenant_id="test-tenant")
|
||||
|
||||
assert user_info["tenant_id"] == "test-tenant"
|
||||
assert user_info["user_id"] == "test-user"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code,exception_type",
|
||||
[
|
||||
(401, PurviewAuthenticationError),
|
||||
(403, PurviewAuthenticationError),
|
||||
(429, PurviewRateLimitError),
|
||||
(400, PurviewRequestError),
|
||||
(404, PurviewRequestError),
|
||||
(500, PurviewServiceError),
|
||||
(502, PurviewServiceError),
|
||||
],
|
||||
)
|
||||
async def test_post_error_handling(
|
||||
self, client: PurviewClient, content_to_process_factory, status_code: int, exception_type: type[Exception]
|
||||
) -> None:
|
||||
"""Test _post method handles different HTTP errors correctly."""
|
||||
from agent_framework_purview._models import ProcessContentResponse
|
||||
|
||||
content = content_to_process_factory()
|
||||
request = ProcessContentRequest(
|
||||
content_to_process=content,
|
||||
user_id="user-123",
|
||||
tenant_id="tenant-456",
|
||||
)
|
||||
|
||||
mock_response = MagicMock(spec=httpx.Response)
|
||||
mock_response.status_code = status_code
|
||||
mock_response.text = "Error message"
|
||||
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
"Error", request=MagicMock(), response=mock_response
|
||||
)
|
||||
|
||||
with patch.object(client._client, "post", return_value=mock_response), pytest.raises(exception_type):
|
||||
await client._post(
|
||||
"https://graph.microsoft.com/v1.0/test",
|
||||
request,
|
||||
ProcessContentResponse,
|
||||
"fake-token",
|
||||
)
|
||||
|
||||
async def test_process_content_success(
|
||||
self, client: PurviewClient, content_to_process_factory, mock_credential: MagicMock
|
||||
) -> None:
|
||||
"""Test process_content method success path."""
|
||||
content = content_to_process_factory("Test message")
|
||||
request = ProcessContentRequest(
|
||||
content_to_process=content,
|
||||
user_id="user-123",
|
||||
tenant_id="tenant-456",
|
||||
)
|
||||
|
||||
mock_response = MagicMock(spec=httpx.Response)
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"id": "response-123", "protectionScopeState": "notModified"}
|
||||
|
||||
with patch.object(client._client, "post", return_value=mock_response):
|
||||
response = await client.process_content(request)
|
||||
|
||||
assert response.id == "response-123"
|
||||
assert response.protection_scope_state == "notModified"
|
||||
|
||||
async def test_get_protection_scopes_success(self, client: PurviewClient) -> None:
|
||||
"""Test get_protection_scopes method success path."""
|
||||
location = PolicyLocation(**{"@odata.type": "microsoft.graph.policyLocationApplication", "value": "app-id"})
|
||||
request = ProtectionScopesRequest(
|
||||
user_id="user-123", tenant_id="tenant-456", locations=[location], correlation_id="corr-789"
|
||||
)
|
||||
|
||||
mock_response = MagicMock(spec=httpx.Response)
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"scopeIdentifier": "scope-123", "value": []}
|
||||
|
||||
with patch.object(client._client, "post", return_value=mock_response):
|
||||
response = await client.get_protection_scopes(request)
|
||||
|
||||
assert response.scope_identifier == "scope-123"
|
||||
assert response.scopes == []
|
||||
|
||||
async def test_client_close(self, mock_credential: AsyncMock, settings: PurviewSettings) -> None:
|
||||
"""Test client properly closes HTTP client."""
|
||||
client = PurviewClient(mock_credential, settings)
|
||||
|
||||
with patch.object(client._client, "aclose", new_callable=AsyncMock) as mock_close:
|
||||
await client.close()
|
||||
mock_close.assert_called_once()
|
||||
|
||||
async def test_invalid_jwt_token_format(self, client: PurviewClient) -> None:
|
||||
"""Test that invalid JWT token format raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid JWT token format"):
|
||||
client._extract_token_info("invalid-token-without-dots")
|
||||
|
||||
async def test_rate_limit_error(self, client: PurviewClient) -> None:
|
||||
"""Test that 429 status code raises PurviewRateLimitError."""
|
||||
request = ProcessContentRequest(
|
||||
user_id="test-user",
|
||||
tenant_id="test-tenant",
|
||||
content_to_process=[],
|
||||
correlation_id="test-correlation-id",
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(client, "_get_token", return_value="fake-token"),
|
||||
patch.object(
|
||||
client._client,
|
||||
"post",
|
||||
return_value=httpx.Response(429, text="Rate limited", request=httpx.Request("POST", "http://test")),
|
||||
),
|
||||
pytest.raises(PurviewRateLimitError, match="Rate limited"),
|
||||
):
|
||||
await client.process_content(request)
|
||||
|
||||
async def test_generic_request_error(self, client: PurviewClient) -> None:
|
||||
"""Test that non-200/201/202 status codes raise PurviewRequestError."""
|
||||
request = ProcessContentRequest(
|
||||
user_id="test-user",
|
||||
tenant_id="test-tenant",
|
||||
content_to_process=[],
|
||||
correlation_id="test-correlation-id",
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(client, "_get_token", return_value="fake-token"),
|
||||
patch.object(
|
||||
client._client,
|
||||
"post",
|
||||
return_value=httpx.Response(
|
||||
500, text="Internal server error", request=httpx.Request("POST", "http://test")
|
||||
),
|
||||
),
|
||||
pytest.raises(PurviewRequestError, match="Purview request failed"),
|
||||
):
|
||||
await client.process_content(request)
|
||||
@@ -0,0 +1,38 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for Purview exceptions."""
|
||||
|
||||
from agent_framework_purview import (
|
||||
PurviewAuthenticationError,
|
||||
PurviewRateLimitError,
|
||||
PurviewRequestError,
|
||||
PurviewServiceError,
|
||||
)
|
||||
|
||||
|
||||
class TestPurviewExceptions:
|
||||
"""Test custom Purview exception classes."""
|
||||
|
||||
def test_purview_service_error(self) -> None:
|
||||
"""Test PurviewServiceError base exception."""
|
||||
error = PurviewServiceError("Service error occurred")
|
||||
assert str(error) == "Service error occurred"
|
||||
assert isinstance(error, Exception)
|
||||
|
||||
def test_purview_authentication_error(self) -> None:
|
||||
"""Test PurviewAuthenticationError exception."""
|
||||
error = PurviewAuthenticationError("Authentication failed")
|
||||
assert str(error) == "Authentication failed"
|
||||
assert isinstance(error, PurviewServiceError)
|
||||
|
||||
def test_purview_rate_limit_error(self) -> None:
|
||||
"""Test PurviewRateLimitError exception."""
|
||||
error = PurviewRateLimitError("Rate limit exceeded")
|
||||
assert str(error) == "Rate limit exceeded"
|
||||
assert isinstance(error, PurviewServiceError)
|
||||
|
||||
def test_purview_request_error(self) -> None:
|
||||
"""Test PurviewRequestError exception."""
|
||||
error = PurviewRequestError("Request failed")
|
||||
assert str(error) == "Request failed"
|
||||
assert isinstance(error, PurviewServiceError)
|
||||
@@ -0,0 +1,201 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for Purview middleware."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from agent_framework import AgentRunContext, AgentRunResponse, ChatMessage, Role
|
||||
from azure.core.credentials import AccessToken
|
||||
|
||||
from agent_framework_purview import PurviewPolicyMiddleware, PurviewSettings
|
||||
|
||||
|
||||
class TestPurviewPolicyMiddleware:
|
||||
"""Test PurviewPolicyMiddleware functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credential(self) -> AsyncMock:
|
||||
"""Create a mock async credential."""
|
||||
credential = AsyncMock()
|
||||
credential.get_token = AsyncMock(return_value=AccessToken("fake-token", 9999999999))
|
||||
return credential
|
||||
|
||||
@pytest.fixture
|
||||
def settings(self) -> PurviewSettings:
|
||||
"""Create test settings."""
|
||||
return PurviewSettings(app_name="Test App", tenant_id="test-tenant")
|
||||
|
||||
@pytest.fixture
|
||||
def middleware(self, mock_credential: AsyncMock, settings: PurviewSettings) -> PurviewPolicyMiddleware:
|
||||
"""Create PurviewPolicyMiddleware instance."""
|
||||
return PurviewPolicyMiddleware(mock_credential, settings)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent(self) -> MagicMock:
|
||||
"""Create a mock agent."""
|
||||
agent = MagicMock()
|
||||
agent.name = "test-agent"
|
||||
return agent
|
||||
|
||||
def test_middleware_initialization(self, mock_credential: AsyncMock, settings: PurviewSettings) -> None:
|
||||
"""Test PurviewPolicyMiddleware initialization."""
|
||||
middleware = PurviewPolicyMiddleware(mock_credential, settings)
|
||||
|
||||
assert middleware._client is not None
|
||||
assert middleware._processor is not None
|
||||
|
||||
async def test_middleware_allows_clean_prompt(
|
||||
self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
|
||||
) -> None:
|
||||
"""Test middleware allows prompt that passes policy check."""
|
||||
context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Hello, how are you?")])
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")):
|
||||
next_called = False
|
||||
|
||||
async def mock_next(ctx: AgentRunContext) -> None:
|
||||
nonlocal next_called
|
||||
next_called = True
|
||||
ctx.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="I'm good, thanks!")])
|
||||
|
||||
await middleware.process(context, mock_next)
|
||||
|
||||
assert next_called
|
||||
assert context.result is not None
|
||||
assert not context.terminate
|
||||
|
||||
async def test_middleware_blocks_prompt_on_policy_violation(
|
||||
self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
|
||||
) -> None:
|
||||
"""Test middleware blocks prompt that violates policy."""
|
||||
context = AgentRunContext(
|
||||
agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Sensitive information")]
|
||||
)
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", return_value=(True, "user-123")):
|
||||
next_called = False
|
||||
|
||||
async def mock_next(ctx: AgentRunContext) -> None:
|
||||
nonlocal next_called
|
||||
next_called = True
|
||||
|
||||
await middleware.process(context, mock_next)
|
||||
|
||||
assert not next_called
|
||||
assert context.result is not None
|
||||
assert context.terminate
|
||||
assert len(context.result.messages) == 1
|
||||
assert context.result.messages[0].role == Role.SYSTEM
|
||||
assert "blocked by policy" in context.result.messages[0].text.lower()
|
||||
|
||||
async def test_middleware_checks_response(self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock) -> None:
|
||||
"""Test middleware checks agent response for policy violations."""
|
||||
context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Hello")])
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_process_messages(messages, activity, user_id=None):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
should_block = call_count != 1
|
||||
return (should_block, "user-123")
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
|
||||
|
||||
async def mock_next(ctx: AgentRunContext) -> None:
|
||||
ctx.result = AgentRunResponse(
|
||||
messages=[ChatMessage(role=Role.ASSISTANT, text="Here's some sensitive information")]
|
||||
)
|
||||
|
||||
await middleware.process(context, mock_next)
|
||||
|
||||
assert call_count == 2
|
||||
assert context.result is not None
|
||||
assert len(context.result.messages) == 1
|
||||
assert context.result.messages[0].role == Role.SYSTEM
|
||||
assert "blocked by policy" in context.result.messages[0].text.lower()
|
||||
|
||||
async def test_middleware_handles_result_without_messages(
|
||||
self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
|
||||
) -> None:
|
||||
"""Test middleware handles result that doesn't have messages attribute."""
|
||||
context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Hello")])
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")):
|
||||
|
||||
async def mock_next(ctx: AgentRunContext) -> None:
|
||||
ctx.result = "Some non-standard result"
|
||||
|
||||
await middleware.process(context, mock_next)
|
||||
|
||||
assert context.result == "Some non-standard result"
|
||||
|
||||
async def test_middleware_processor_receives_correct_activity(
|
||||
self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
|
||||
) -> None:
|
||||
"""Test middleware passes correct activity type to processor."""
|
||||
from agent_framework_purview._models import Activity
|
||||
|
||||
context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Test")])
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_process:
|
||||
|
||||
async def mock_next(ctx: AgentRunContext) -> None:
|
||||
ctx.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Response")])
|
||||
|
||||
await middleware.process(context, mock_next)
|
||||
|
||||
assert mock_process.call_count == 2
|
||||
for call in mock_process.call_args_list:
|
||||
assert call[0][1] == Activity.UPLOAD_TEXT
|
||||
|
||||
async def test_middleware_handles_pre_check_exception(
|
||||
self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
|
||||
) -> None:
|
||||
"""Test that exceptions in pre-check are logged but don't stop processing."""
|
||||
context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Test")])
|
||||
|
||||
with patch.object(
|
||||
middleware._processor, "process_messages", side_effect=Exception("Pre-check error")
|
||||
) as mock_process:
|
||||
|
||||
async def mock_next(ctx: AgentRunContext) -> None:
|
||||
ctx.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Response")])
|
||||
|
||||
await middleware.process(context, mock_next)
|
||||
|
||||
# Should have been called twice (pre-check raises, then post-check also raises)
|
||||
assert mock_process.call_count == 2
|
||||
# Context should not be terminated
|
||||
assert not context.terminate
|
||||
# Result should be set by mock_next
|
||||
assert context.result is not None
|
||||
|
||||
async def test_middleware_handles_post_check_exception(
|
||||
self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
|
||||
) -> None:
|
||||
"""Test that exceptions in post-check are logged but don't affect result."""
|
||||
context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Test")])
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_process_messages(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return (False, "user-123") # Pre-check succeeds
|
||||
raise Exception("Post-check error") # Post-check fails
|
||||
|
||||
with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages):
|
||||
|
||||
async def mock_next(ctx: AgentRunContext) -> None:
|
||||
ctx.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Response")])
|
||||
|
||||
await middleware.process(context, mock_next)
|
||||
|
||||
# Should have been called twice (pre and post)
|
||||
assert call_count == 2
|
||||
# Result should still be set
|
||||
assert context.result is not None
|
||||
assert hasattr(context.result, "messages")
|
||||
@@ -0,0 +1,246 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for Purview models and serialization."""
|
||||
|
||||
from agent_framework_purview._models import (
|
||||
Activity,
|
||||
ActivityMetadata,
|
||||
ContentToProcess,
|
||||
DeviceMetadata,
|
||||
IntegratedAppMetadata,
|
||||
OperatingSystemSpecifications,
|
||||
PolicyLocation,
|
||||
ProcessContentRequest,
|
||||
ProcessContentResponse,
|
||||
ProcessConversationMetadata,
|
||||
ProtectedAppMetadata,
|
||||
ProtectionScopeActivities,
|
||||
ProtectionScopesRequest,
|
||||
ProtectionScopesResponse,
|
||||
PurviewTextContent,
|
||||
deserialize_flag,
|
||||
serialize_flag,
|
||||
)
|
||||
|
||||
|
||||
class TestFlagOperations:
|
||||
"""Test flag serialization and deserialization operations."""
|
||||
|
||||
def test_protection_scope_activities_flag_combination(self) -> None:
|
||||
"""Test combining flags."""
|
||||
combined = ProtectionScopeActivities.UPLOAD_TEXT | ProtectionScopeActivities.UPLOAD_FILE
|
||||
assert combined.value == 3
|
||||
assert ProtectionScopeActivities.UPLOAD_TEXT in combined
|
||||
assert ProtectionScopeActivities.UPLOAD_FILE in combined
|
||||
|
||||
def test_deserialize_flag_with_string(self) -> None:
|
||||
"""Test deserializing flag from comma-separated string."""
|
||||
mapping = {
|
||||
"uploadText": ProtectionScopeActivities.UPLOAD_TEXT,
|
||||
"uploadFile": ProtectionScopeActivities.UPLOAD_FILE,
|
||||
}
|
||||
|
||||
result = deserialize_flag("uploadText,uploadFile", mapping, ProtectionScopeActivities)
|
||||
assert result is not None
|
||||
assert ProtectionScopeActivities.UPLOAD_TEXT in result
|
||||
assert ProtectionScopeActivities.UPLOAD_FILE in result
|
||||
|
||||
def test_deserialize_flag_with_none(self) -> None:
|
||||
"""Test deserializing None returns None."""
|
||||
mapping = {"uploadText": ProtectionScopeActivities.UPLOAD_TEXT}
|
||||
result = deserialize_flag(None, mapping, ProtectionScopeActivities)
|
||||
assert result is None
|
||||
|
||||
def test_serialize_flag_with_none(self) -> None:
|
||||
"""Test serializing None returns None."""
|
||||
result = serialize_flag(None, [])
|
||||
assert result is None
|
||||
|
||||
def test_serialize_flag_with_values(self) -> None:
|
||||
"""Test serializing flag with values."""
|
||||
flag = ProtectionScopeActivities.UPLOAD_TEXT | ProtectionScopeActivities.UPLOAD_FILE
|
||||
ordered = [
|
||||
("uploadText", ProtectionScopeActivities.UPLOAD_TEXT),
|
||||
("uploadFile", ProtectionScopeActivities.UPLOAD_FILE),
|
||||
]
|
||||
result = serialize_flag(flag, ordered)
|
||||
assert result == "uploadText,uploadFile"
|
||||
|
||||
|
||||
class TestComplexModels:
|
||||
"""Test complex models with nested structures."""
|
||||
|
||||
def test_content_to_process_with_nested_structures(self) -> None:
|
||||
"""Test ContentToProcess with all nested structures."""
|
||||
text_content = PurviewTextContent(data="Test")
|
||||
metadata = ProcessConversationMetadata(
|
||||
identifier="msg-1",
|
||||
content=text_content,
|
||||
name="Test",
|
||||
is_truncated=False,
|
||||
)
|
||||
|
||||
activity_meta = ActivityMetadata(activity=Activity.UPLOAD_TEXT)
|
||||
device_meta = DeviceMetadata(
|
||||
operating_system_specifications=OperatingSystemSpecifications(
|
||||
operating_system_platform="Windows", operating_system_version="10"
|
||||
)
|
||||
)
|
||||
integrated_app = IntegratedAppMetadata(name="App", version="1.0")
|
||||
location = PolicyLocation(data_type="microsoft.graph.policyLocationApplication", value="app-id")
|
||||
protected_app = ProtectedAppMetadata(name="Protected", version="1.0", application_location=location)
|
||||
|
||||
content = ContentToProcess(
|
||||
content_entries=[metadata],
|
||||
activity_metadata=activity_meta,
|
||||
device_metadata=device_meta,
|
||||
integrated_app_metadata=integrated_app,
|
||||
protected_app_metadata=protected_app,
|
||||
)
|
||||
|
||||
assert len(content.content_entries) == 1
|
||||
assert content.activity_metadata.activity == Activity.UPLOAD_TEXT
|
||||
assert content.device_metadata.operating_system_specifications.operating_system_platform == "Windows"
|
||||
assert content.integrated_app_metadata.name == "App"
|
||||
assert content.protected_app_metadata.name == "Protected"
|
||||
|
||||
|
||||
class TestRequestResponseSerialization:
|
||||
"""Test request/response serialization with aliases."""
|
||||
|
||||
def test_protection_scopes_request_serialization(self) -> None:
|
||||
"""Test ProtectionScopesRequest serializes activities correctly."""
|
||||
location = PolicyLocation(data_type="microsoft.graph.policyLocationApplication", value="app-id")
|
||||
|
||||
request = ProtectionScopesRequest(
|
||||
user_id="user-123",
|
||||
tenant_id="tenant-456",
|
||||
activities=ProtectionScopeActivities.UPLOAD_TEXT | ProtectionScopeActivities.UPLOAD_FILE,
|
||||
locations=[location],
|
||||
)
|
||||
|
||||
dumped = request.model_dump(by_alias=True, exclude_none=True, mode="json")
|
||||
|
||||
assert "activities" in dumped
|
||||
assert isinstance(dumped["activities"], str)
|
||||
assert "uploadText" in dumped["activities"]
|
||||
|
||||
|
||||
class TestModelDeserialization:
|
||||
"""Test model deserialization from API responses."""
|
||||
|
||||
def test_protection_scopes_response_deserialization(self) -> None:
|
||||
"""Test ProtectionScopesResponse deserializes 'value' to 'scopes'."""
|
||||
api_data = {
|
||||
"scopeIdentifier": "scope-123",
|
||||
"value": [
|
||||
{
|
||||
"activities": "uploadText,downloadText",
|
||||
"locations": [{"@odata.type": "location.type", "value": "/path"}],
|
||||
"policyActions": [{"action": "warn", "restrictionAction": "blockAccess"}],
|
||||
"executionMode": "evaluateInline",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
response = ProtectionScopesResponse.model_validate(api_data)
|
||||
|
||||
assert response.scope_identifier == "scope-123"
|
||||
assert response.scopes is not None
|
||||
assert len(response.scopes) == 1
|
||||
assert response.scopes[0].execution_mode == "evaluateInline"
|
||||
|
||||
def test_process_content_response_deserialization(self) -> None:
|
||||
"""Test ProcessContentResponse deserializes aliased fields correctly."""
|
||||
api_data = {
|
||||
"id": "response-123",
|
||||
"protectionScopeState": "blocked",
|
||||
"policyActions": [{"action": "block", "restrictionAction": "blockAccess"}],
|
||||
}
|
||||
|
||||
response = ProcessContentResponse.model_validate(api_data)
|
||||
|
||||
assert response.id == "response-123"
|
||||
assert response.protection_scope_state == "blocked"
|
||||
assert len(response.policy_actions) == 1
|
||||
|
||||
def test_content_serialization_uses_aliases(self) -> None:
|
||||
"""Test ContentToProcess serializes with camelCase aliases."""
|
||||
text_content = PurviewTextContent(data="Test")
|
||||
metadata = ProcessConversationMetadata(
|
||||
identifier="msg-1",
|
||||
content=text_content,
|
||||
name="Test",
|
||||
is_truncated=False,
|
||||
)
|
||||
|
||||
activity_meta = ActivityMetadata(activity=Activity.UPLOAD_TEXT)
|
||||
device_meta = DeviceMetadata(
|
||||
operating_system_specifications=OperatingSystemSpecifications(
|
||||
operating_system_platform="Windows", operating_system_version="10"
|
||||
)
|
||||
)
|
||||
integrated_app = IntegratedAppMetadata(name="App", version="1.0")
|
||||
location = PolicyLocation(data_type="microsoft.graph.policyLocationApplication", value="app-id")
|
||||
protected_app = ProtectedAppMetadata(name="Protected", version="1.0", application_location=location)
|
||||
|
||||
content = ContentToProcess(
|
||||
content_entries=[metadata],
|
||||
activity_metadata=activity_meta,
|
||||
device_metadata=device_meta,
|
||||
integrated_app_metadata=integrated_app,
|
||||
protected_app_metadata=protected_app,
|
||||
)
|
||||
|
||||
dumped = content.model_dump(by_alias=True, exclude_none=True, mode="json")
|
||||
|
||||
assert "contentEntries" in dumped
|
||||
assert "activityMetadata" in dumped
|
||||
assert "deviceMetadata" in dumped
|
||||
assert "integratedAppMetadata" in dumped
|
||||
assert "protectedAppMetadata" in dumped
|
||||
|
||||
def test_process_content_request_excludes_private_fields(self) -> None:
|
||||
"""Test ProcessContentRequest excludes private fields when serializing."""
|
||||
text_content = PurviewTextContent(data="Test")
|
||||
metadata = ProcessConversationMetadata(
|
||||
identifier="msg-1",
|
||||
content=text_content,
|
||||
name="Test",
|
||||
is_truncated=False,
|
||||
)
|
||||
|
||||
activity_meta = ActivityMetadata(activity=Activity.UPLOAD_TEXT)
|
||||
device_meta = DeviceMetadata(
|
||||
operating_system_specifications=OperatingSystemSpecifications(
|
||||
operating_system_platform="Windows", operating_system_version="10"
|
||||
)
|
||||
)
|
||||
integrated_app = IntegratedAppMetadata(name="App", version="1.0")
|
||||
location = PolicyLocation(data_type="microsoft.graph.policyLocationApplication", value="app-id")
|
||||
protected_app = ProtectedAppMetadata(name="Protected", version="1.0", application_location=location)
|
||||
|
||||
content = ContentToProcess(
|
||||
content_entries=[metadata],
|
||||
activity_metadata=activity_meta,
|
||||
device_metadata=device_meta,
|
||||
integrated_app_metadata=integrated_app,
|
||||
protected_app_metadata=protected_app,
|
||||
)
|
||||
|
||||
request = ProcessContentRequest(
|
||||
content_to_process=content,
|
||||
user_id="user-123",
|
||||
tenant_id="tenant-456",
|
||||
correlation_id="corr-789",
|
||||
)
|
||||
|
||||
dumped = request.model_dump(by_alias=True, exclude_none=True, mode="json")
|
||||
|
||||
# Check that excluded fields are not present
|
||||
assert "user_id" not in dumped
|
||||
assert "tenant_id" not in dumped
|
||||
assert "correlation_id" not in dumped
|
||||
|
||||
# Check that content is present
|
||||
assert "contentToProcess" in dumped
|
||||
@@ -0,0 +1,369 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for Purview processor."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from agent_framework import ChatMessage, Role
|
||||
|
||||
from agent_framework_purview import PurviewAppLocation, PurviewLocationType, PurviewSettings
|
||||
from agent_framework_purview._models import (
|
||||
Activity,
|
||||
DlpAction,
|
||||
DlpActionInfo,
|
||||
ProcessContentResponse,
|
||||
RestrictionAction,
|
||||
)
|
||||
from agent_framework_purview._processor import ScopedContentProcessor, _is_valid_guid
|
||||
|
||||
|
||||
class TestGuidValidation:
|
||||
"""Test GUID validation helper."""
|
||||
|
||||
def test_valid_guid(self) -> None:
|
||||
"""Test _is_valid_guid with valid GUIDs."""
|
||||
assert _is_valid_guid("12345678-1234-1234-1234-123456789012")
|
||||
assert _is_valid_guid("a1b2c3d4-e5f6-4a5b-8c9d-0e1f2a3b4c5d")
|
||||
|
||||
def test_invalid_guid(self) -> None:
|
||||
"""Test _is_valid_guid with invalid GUIDs."""
|
||||
assert not _is_valid_guid("not-a-guid")
|
||||
assert not _is_valid_guid("")
|
||||
assert not _is_valid_guid(None)
|
||||
|
||||
|
||||
class TestScopedContentProcessor:
|
||||
"""Test ScopedContentProcessor functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client(self) -> AsyncMock:
|
||||
"""Create a mock Purview client."""
|
||||
client = AsyncMock()
|
||||
client.get_user_info_from_token = AsyncMock(
|
||||
return_value={
|
||||
"tenant_id": "12345678-1234-1234-1234-123456789012",
|
||||
"user_id": "12345678-1234-1234-1234-123456789012",
|
||||
"client_id": "12345678-1234-1234-1234-123456789012",
|
||||
}
|
||||
)
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def settings_with_defaults(self) -> PurviewSettings:
|
||||
"""Create settings with default values."""
|
||||
app_location = PurviewAppLocation(
|
||||
location_type=PurviewLocationType.APPLICATION, location_value="12345678-1234-1234-1234-123456789012"
|
||||
)
|
||||
return PurviewSettings(
|
||||
app_name="Test App",
|
||||
tenant_id="12345678-1234-1234-1234-123456789012",
|
||||
purview_app_location=app_location,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def settings_without_defaults(self) -> PurviewSettings:
|
||||
"""Create settings without default values (requiring token info)."""
|
||||
return PurviewSettings(app_name="Test App")
|
||||
|
||||
@pytest.fixture
|
||||
def processor(self, mock_client: AsyncMock, settings_with_defaults: PurviewSettings) -> ScopedContentProcessor:
|
||||
"""Create a ScopedContentProcessor with mock client."""
|
||||
return ScopedContentProcessor(mock_client, settings_with_defaults)
|
||||
|
||||
async def test_processor_initialization(
|
||||
self, mock_client: AsyncMock, settings_with_defaults: PurviewSettings
|
||||
) -> None:
|
||||
"""Test ScopedContentProcessor initialization."""
|
||||
processor = ScopedContentProcessor(mock_client, settings_with_defaults)
|
||||
|
||||
assert processor._client == mock_client
|
||||
assert processor._settings == settings_with_defaults
|
||||
|
||||
async def test_process_messages_with_defaults(self, processor: ScopedContentProcessor) -> None:
|
||||
"""Test process_messages with settings that have defaults."""
|
||||
messages = [
|
||||
ChatMessage(role=Role.USER, text="Hello"),
|
||||
ChatMessage(role=Role.ASSISTANT, text="Hi there"),
|
||||
]
|
||||
|
||||
with patch.object(processor, "_map_messages", return_value=([], None)) as mock_map:
|
||||
should_block, user_id = await processor.process_messages(messages, Activity.UPLOAD_TEXT)
|
||||
|
||||
assert should_block is False
|
||||
assert user_id is None
|
||||
mock_map.assert_called_once_with(messages, Activity.UPLOAD_TEXT, None)
|
||||
|
||||
async def test_process_messages_blocks_content(
|
||||
self, processor: ScopedContentProcessor, process_content_request_factory
|
||||
) -> None:
|
||||
"""Test process_messages returns True when content should be blocked."""
|
||||
messages = [ChatMessage(role=Role.USER, text="Sensitive content")]
|
||||
|
||||
mock_request = process_content_request_factory("Sensitive content")
|
||||
|
||||
mock_response = ProcessContentResponse(**{
|
||||
"policyActions": [DlpActionInfo(action=DlpAction.BLOCK_ACCESS, restrictionAction=RestrictionAction.BLOCK)]
|
||||
})
|
||||
|
||||
with (
|
||||
patch.object(processor, "_map_messages", return_value=([mock_request], "user-123")),
|
||||
patch.object(processor, "_process_with_scopes", return_value=mock_response),
|
||||
):
|
||||
should_block, user_id = await processor.process_messages(messages, Activity.UPLOAD_TEXT)
|
||||
|
||||
assert should_block is True
|
||||
assert user_id == "user-123"
|
||||
|
||||
async def test_map_messages_creates_requests(
|
||||
self, processor: ScopedContentProcessor, mock_client: AsyncMock
|
||||
) -> None:
|
||||
"""Test _map_messages creates ProcessContentRequest objects."""
|
||||
messages = [
|
||||
ChatMessage(
|
||||
role=Role.USER,
|
||||
text="Test message",
|
||||
message_id="msg-123",
|
||||
author_name="12345678-1234-1234-1234-123456789012",
|
||||
),
|
||||
]
|
||||
|
||||
requests, user_id = await processor._map_messages(messages, Activity.UPLOAD_TEXT)
|
||||
|
||||
assert len(requests) == 1
|
||||
assert requests[0].user_id == "12345678-1234-1234-1234-123456789012"
|
||||
assert requests[0].tenant_id == "12345678-1234-1234-1234-123456789012"
|
||||
assert user_id == "12345678-1234-1234-1234-123456789012"
|
||||
|
||||
async def test_map_messages_without_defaults_gets_token_info(self, mock_client: AsyncMock) -> None:
|
||||
"""Test _map_messages gets token info when settings lack some defaults."""
|
||||
settings = PurviewSettings(app_name="Test App", tenant_id="12345678-1234-1234-1234-123456789012")
|
||||
processor = ScopedContentProcessor(mock_client, settings)
|
||||
messages = [ChatMessage(role=Role.USER, text="Test", message_id="msg-123")]
|
||||
|
||||
requests, user_id = await processor._map_messages(messages, Activity.UPLOAD_TEXT)
|
||||
|
||||
mock_client.get_user_info_from_token.assert_called_once()
|
||||
assert len(requests) == 1
|
||||
assert user_id is not None
|
||||
|
||||
async def test_map_messages_raises_on_missing_tenant_id(self, mock_client: AsyncMock) -> None:
|
||||
"""Test _map_messages raises ValueError when tenant_id cannot be determined."""
|
||||
settings = PurviewSettings(app_name="Test App") # No tenant_id
|
||||
processor = ScopedContentProcessor(mock_client, settings)
|
||||
|
||||
mock_client.get_user_info_from_token = AsyncMock(
|
||||
return_value={"user_id": "test-user", "client_id": "test-client"}
|
||||
)
|
||||
|
||||
messages = [ChatMessage(role=Role.USER, text="Test", message_id="msg-123")]
|
||||
|
||||
with pytest.raises(ValueError, match="Tenant id required"):
|
||||
await processor._map_messages(messages, Activity.UPLOAD_TEXT)
|
||||
|
||||
async def test_check_applicable_scopes_no_scopes(
|
||||
self, processor: ScopedContentProcessor, process_content_request_factory
|
||||
) -> None:
|
||||
"""Test _check_applicable_scopes when no scopes are returned."""
|
||||
from agent_framework_purview._models import ProtectionScopesResponse
|
||||
|
||||
request = process_content_request_factory()
|
||||
response = ProtectionScopesResponse(**{"value": None})
|
||||
|
||||
should_process, actions = processor._check_applicable_scopes(request, response)
|
||||
|
||||
assert should_process is False
|
||||
assert actions == []
|
||||
|
||||
async def test_check_applicable_scopes_with_block_action(
|
||||
self, processor: ScopedContentProcessor, process_content_request_factory
|
||||
) -> None:
|
||||
"""Test _check_applicable_scopes identifies block actions."""
|
||||
from agent_framework_purview._models import (
|
||||
PolicyLocation,
|
||||
PolicyScope,
|
||||
ProtectionScopeActivities,
|
||||
ProtectionScopesResponse,
|
||||
)
|
||||
|
||||
request = process_content_request_factory()
|
||||
|
||||
block_action = DlpActionInfo(action=DlpAction.BLOCK_ACCESS, restrictionAction=RestrictionAction.BLOCK)
|
||||
scope_location = PolicyLocation(**{
|
||||
"@odata.type": "microsoft.graph.policyLocationApplication",
|
||||
"value": "app-id",
|
||||
})
|
||||
scope = PolicyScope(**{
|
||||
"policyActions": [block_action],
|
||||
"activities": ProtectionScopeActivities.UPLOAD_TEXT,
|
||||
"locations": [scope_location],
|
||||
})
|
||||
response = ProtectionScopesResponse(**{"value": [scope]})
|
||||
|
||||
should_process, actions = processor._check_applicable_scopes(request, response)
|
||||
|
||||
assert should_process is True
|
||||
assert len(actions) == 1
|
||||
assert actions[0].action == DlpAction.BLOCK_ACCESS
|
||||
|
||||
async def test_combine_policy_actions(self, processor: ScopedContentProcessor) -> None:
|
||||
"""Test _combine_policy_actions merges action lists."""
|
||||
action1 = DlpActionInfo(action=DlpAction.BLOCK_ACCESS, restrictionAction=RestrictionAction.BLOCK)
|
||||
action2 = DlpActionInfo(action=DlpAction.OTHER, restrictionAction=RestrictionAction.OTHER)
|
||||
|
||||
combined = processor._combine_policy_actions([action1], [action2])
|
||||
|
||||
assert len(combined) == 2
|
||||
assert action1 in combined
|
||||
assert action2 in combined
|
||||
|
||||
async def test_process_with_scopes_calls_client_methods(
|
||||
self, processor: ScopedContentProcessor, mock_client: AsyncMock, process_content_request_factory
|
||||
) -> None:
|
||||
"""Test _process_with_scopes calls get_protection_scopes and process_content."""
|
||||
from agent_framework_purview._models import (
|
||||
ContentActivitiesResponse,
|
||||
ProtectionScopesResponse,
|
||||
)
|
||||
|
||||
request = process_content_request_factory()
|
||||
|
||||
mock_client.get_protection_scopes = AsyncMock(return_value=ProtectionScopesResponse(**{"value": []}))
|
||||
mock_client.process_content = AsyncMock(
|
||||
return_value=ProcessContentResponse(**{"id": "response-123", "protectionScopeState": "notModified"})
|
||||
)
|
||||
mock_client.send_content_activities = AsyncMock(return_value=ContentActivitiesResponse(**{"error": None}))
|
||||
|
||||
response = await processor._process_with_scopes(request)
|
||||
|
||||
mock_client.get_protection_scopes.assert_called_once()
|
||||
mock_client.process_content.assert_not_called()
|
||||
mock_client.send_content_activities.assert_called_once()
|
||||
assert response.id is None
|
||||
|
||||
async def test_map_messages_with_user_id_in_additional_properties(self, mock_client: AsyncMock) -> None:
|
||||
"""Test user_id extraction from message additional_properties."""
|
||||
settings = PurviewSettings(
|
||||
app_name="Test App",
|
||||
tenant_id="12345678-1234-1234-1234-123456789012",
|
||||
purview_app_location=PurviewAppLocation(
|
||||
location_type=PurviewLocationType.APPLICATION, location_value="app-id"
|
||||
),
|
||||
)
|
||||
processor = ScopedContentProcessor(mock_client, settings)
|
||||
|
||||
messages = [
|
||||
ChatMessage(
|
||||
role=Role.USER,
|
||||
text="Test message",
|
||||
additional_properties={"user_id": "22345678-1234-1234-1234-123456789012"},
|
||||
),
|
||||
]
|
||||
|
||||
requests, user_id = await processor._map_messages(messages, Activity.UPLOAD_TEXT)
|
||||
|
||||
assert len(requests) == 1
|
||||
assert user_id == "22345678-1234-1234-1234-123456789012"
|
||||
assert requests[0].user_id == "22345678-1234-1234-1234-123456789012"
|
||||
|
||||
async def test_map_messages_with_provided_user_id_fallback(self, mock_client: AsyncMock) -> None:
|
||||
"""Test using provided_user_id when no other source is available."""
|
||||
settings = PurviewSettings(
|
||||
app_name="Test App",
|
||||
tenant_id="12345678-1234-1234-1234-123456789012",
|
||||
purview_app_location=PurviewAppLocation(
|
||||
location_type=PurviewLocationType.APPLICATION, location_value="app-id"
|
||||
),
|
||||
)
|
||||
processor = ScopedContentProcessor(mock_client, settings)
|
||||
|
||||
messages = [ChatMessage(role=Role.USER, text="Test message")]
|
||||
|
||||
requests, user_id = await processor._map_messages(
|
||||
messages, Activity.UPLOAD_TEXT, provided_user_id="32345678-1234-1234-1234-123456789012"
|
||||
)
|
||||
|
||||
assert len(requests) == 1
|
||||
assert user_id == "32345678-1234-1234-1234-123456789012"
|
||||
assert requests[0].user_id == "32345678-1234-1234-1234-123456789012"
|
||||
|
||||
async def test_map_messages_returns_empty_when_no_user_id(self, mock_client: AsyncMock) -> None:
|
||||
"""Test that empty results are returned when user_id cannot be resolved."""
|
||||
settings = PurviewSettings(
|
||||
app_name="Test App",
|
||||
tenant_id="12345678-1234-1234-1234-123456789012",
|
||||
purview_app_location=PurviewAppLocation(
|
||||
location_type=PurviewLocationType.APPLICATION, location_value="app-id"
|
||||
),
|
||||
)
|
||||
processor = ScopedContentProcessor(mock_client, settings)
|
||||
|
||||
messages = [ChatMessage(role=Role.USER, text="Test message")]
|
||||
|
||||
requests, user_id = await processor._map_messages(messages, Activity.UPLOAD_TEXT)
|
||||
|
||||
assert len(requests) == 0
|
||||
assert user_id is None
|
||||
|
||||
async def test_process_content_sends_activities_when_not_applicable(
|
||||
self, mock_client: AsyncMock, process_content_request_factory
|
||||
) -> None:
|
||||
"""Test that content activities are sent when scopes don't apply."""
|
||||
settings = PurviewSettings(
|
||||
app_name="Test App",
|
||||
tenant_id="12345678-1234-1234-1234-123456789012",
|
||||
purview_app_location=PurviewAppLocation(
|
||||
location_type=PurviewLocationType.APPLICATION, location_value="app-id"
|
||||
),
|
||||
)
|
||||
processor = ScopedContentProcessor(mock_client, settings)
|
||||
|
||||
pc_request = process_content_request_factory()
|
||||
|
||||
# Mock get_protection_scopes to return no applicable scopes
|
||||
mock_ps_response = MagicMock()
|
||||
mock_ps_response.scopes = []
|
||||
mock_client.get_protection_scopes.return_value = mock_ps_response
|
||||
|
||||
# Mock send_content_activities to return success
|
||||
mock_ca_response = MagicMock()
|
||||
mock_ca_response.error = None
|
||||
mock_client.send_content_activities.return_value = mock_ca_response
|
||||
|
||||
response = await processor._process_with_scopes(pc_request)
|
||||
|
||||
mock_client.get_protection_scopes.assert_called_once()
|
||||
mock_client.process_content.assert_not_called()
|
||||
mock_client.send_content_activities.assert_called_once()
|
||||
# When content activities succeed, response has no errors (processing_errors can be None or empty)
|
||||
assert response.processing_errors is None or response.processing_errors == []
|
||||
|
||||
async def test_process_content_handles_activities_error(
|
||||
self, mock_client: AsyncMock, process_content_request_factory
|
||||
) -> None:
|
||||
"""Test error handling when content activities fail."""
|
||||
settings = PurviewSettings(
|
||||
app_name="Test App",
|
||||
tenant_id="12345678-1234-1234-1234-123456789012",
|
||||
purview_app_location=PurviewAppLocation(
|
||||
location_type=PurviewLocationType.APPLICATION, location_value="app-id"
|
||||
),
|
||||
)
|
||||
processor = ScopedContentProcessor(mock_client, settings)
|
||||
|
||||
pc_request = process_content_request_factory()
|
||||
|
||||
# Mock get_protection_scopes to return no applicable scopes
|
||||
mock_ps_response = MagicMock()
|
||||
mock_ps_response.scopes = []
|
||||
mock_client.get_protection_scopes.return_value = mock_ps_response
|
||||
|
||||
# Mock send_content_activities to return error
|
||||
mock_ca_response = MagicMock()
|
||||
mock_ca_response.error = "Test error message"
|
||||
mock_client.send_content_activities.return_value = mock_ca_response
|
||||
|
||||
response = await processor._process_with_scopes(pc_request)
|
||||
|
||||
assert len(response.processing_errors) == 1
|
||||
assert response.processing_errors[0].message == "Test error message"
|
||||
@@ -0,0 +1,85 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for Purview settings."""
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework_purview import PurviewAppLocation, PurviewLocationType, PurviewSettings
|
||||
|
||||
|
||||
class TestPurviewSettings:
|
||||
"""Test PurviewSettings configuration."""
|
||||
|
||||
def test_settings_defaults(self) -> None:
|
||||
"""Test PurviewSettings with default values."""
|
||||
settings = PurviewSettings(app_name="Test App")
|
||||
|
||||
assert settings.app_name == "Test App"
|
||||
assert settings.graph_base_uri == "https://graph.microsoft.com/v1.0/"
|
||||
assert settings.tenant_id is None
|
||||
assert settings.purview_app_location is None
|
||||
assert settings.process_inline is False
|
||||
|
||||
def test_settings_with_custom_values(self) -> None:
|
||||
"""Test PurviewSettings with custom values."""
|
||||
app_location = PurviewAppLocation(location_type=PurviewLocationType.APPLICATION, location_value="app-123")
|
||||
|
||||
settings = PurviewSettings(
|
||||
app_name="Test App",
|
||||
graph_base_uri="https://graph.microsoft-ppe.com",
|
||||
tenant_id="test-tenant-id",
|
||||
process_inline=True,
|
||||
purview_app_location=app_location,
|
||||
)
|
||||
|
||||
assert settings.graph_base_uri == "https://graph.microsoft-ppe.com"
|
||||
assert settings.tenant_id == "test-tenant-id"
|
||||
assert settings.process_inline is True
|
||||
assert settings.purview_app_location.location_value == "app-123"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"graph_uri,expected_scope",
|
||||
[
|
||||
("https://graph.microsoft.com/v1.0/", "https://graph.microsoft.com/.default"),
|
||||
("https://graph.microsoft-ppe.com/v1.0/", "https://graph.microsoft-ppe.com/.default"),
|
||||
],
|
||||
)
|
||||
def test_get_scopes(self, graph_uri: str, expected_scope: str) -> None:
|
||||
"""Test get_scopes returns correct scope for different URIs."""
|
||||
settings = PurviewSettings(app_name="Test App", graph_base_uri=graph_uri)
|
||||
scopes = settings.get_scopes()
|
||||
|
||||
assert len(scopes) == 1
|
||||
assert expected_scope in scopes
|
||||
|
||||
|
||||
class TestPurviewAppLocation:
|
||||
"""Test PurviewAppLocation configuration."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"location_type,location_value,expected_odata_type",
|
||||
[
|
||||
(PurviewLocationType.APPLICATION, "app-123", "microsoft.graph.policyLocationApplication"),
|
||||
(PurviewLocationType.URI, "https://example.com", "microsoft.graph.policyLocationUrl"),
|
||||
(PurviewLocationType.DOMAIN, "example.com", "microsoft.graph.policyLocationDomain"),
|
||||
],
|
||||
)
|
||||
def test_get_policy_location(
|
||||
self, location_type: PurviewLocationType, location_value: str, expected_odata_type: str
|
||||
) -> None:
|
||||
"""Test get_policy_location returns correct structure for all location types."""
|
||||
location = PurviewAppLocation(location_type=location_type, location_value=location_value)
|
||||
policy_location = location.get_policy_location()
|
||||
|
||||
assert policy_location["@odata.type"] == expected_odata_type
|
||||
assert policy_location["value"] == location_value
|
||||
|
||||
|
||||
class TestPurviewLocationType:
|
||||
"""Test PurviewLocationType enum."""
|
||||
|
||||
def test_location_type_values(self) -> None:
|
||||
"""Test PurviewLocationType enum has expected values."""
|
||||
assert PurviewLocationType.APPLICATION == "application"
|
||||
assert PurviewLocationType.URI == "uri"
|
||||
assert PurviewLocationType.DOMAIN == "domain"
|
||||
@@ -0,0 +1,88 @@
|
||||
## Purview Policy Enforcement Sample (Python)
|
||||
|
||||
This getting-started sample shows how to attach Microsoft Purview policy evaluation to an Agent Framework `ChatAgent` using the **middleware** approach.
|
||||
|
||||
1. Configure an Azure OpenAI chat client
|
||||
2. Add Purview policy enforcement middleware (`PurviewPolicyMiddleware`)
|
||||
3. Run a short conversation and observe prompt / response blocking behavior
|
||||
|
||||
---
|
||||
## 1. Setup
|
||||
### Required Environment Variables
|
||||
|
||||
| Variable | Required | Purpose |
|
||||
|----------|----------|---------|
|
||||
| `AZURE_OPENAI_ENDPOINT` | Yes | Azure OpenAI endpoint (https://<name>.openai.azure.com) |
|
||||
| `AZURE_OPENAI_DEPLOYMENT_NAME` | Optional | Model deployment name (defaults inside SDK if omitted) |
|
||||
| `PURVIEW_CLIENT_APP_ID` | Yes* | Client (application) ID used for Purview authentication |
|
||||
| `PURVIEW_USE_CERT_AUTH` | Optional (`true`/`false`) | Switch between certificate and interactive auth |
|
||||
| `PURVIEW_TENANT_ID` | Yes (when cert auth on) | Tenant ID for certificate authentication |
|
||||
| `PURVIEW_CERT_PATH` | Yes (when cert auth on) | Path to your .pfx certificate |
|
||||
| `PURVIEW_CERT_PASSWORD` | Optional | Password for encrypted certs |
|
||||
|
||||
*A demo default exists in code for illustration only—always set your own value.
|
||||
|
||||
### 2. Auth Modes Supported
|
||||
|
||||
#### A. Interactive Browser Authentication (default)
|
||||
Opens a browser on first run to sign in.
|
||||
|
||||
```powershell
|
||||
$env:AZURE_OPENAI_ENDPOINT = "https://your-openai-instance.openai.azure.com"
|
||||
$env:PURVIEW_CLIENT_APP_ID = "00000000-0000-0000-0000-000000000000"
|
||||
```
|
||||
|
||||
#### B. Certificate Authentication
|
||||
For headless / CI scenarios.
|
||||
|
||||
```powershell
|
||||
$env:PURVIEW_USE_CERT_AUTH = "true"
|
||||
$env:PURVIEW_TENANT_ID = "<tenant-guid>"
|
||||
$env:PURVIEW_CERT_PATH = "C:\path\to\cert.pfx"
|
||||
$env:PURVIEW_CERT_PASSWORD = "optional-password"
|
||||
```
|
||||
|
||||
Certificate steps (summary): create / register app, generate certificate, upload public key, export .pfx with private key, grant required Graph / Purview permissions.
|
||||
|
||||
---
|
||||
|
||||
## 3. Run the Sample
|
||||
|
||||
From repo root:
|
||||
|
||||
```powershell
|
||||
cd python/samples/getting_started/purview_agent
|
||||
python sample_purview_agent.py
|
||||
```
|
||||
|
||||
If interactive auth is used, a browser window will appear the first time.
|
||||
|
||||
---
|
||||
|
||||
## 4. How It Works
|
||||
|
||||
1. Builds an Azure OpenAI chat client (using the environment endpoint / deployment)
|
||||
2. Chooses credential mode (certificate vs interactive)
|
||||
3. Creates `PurviewPolicyMiddleware` with `PurviewSettings`
|
||||
4. Injects middleware into the agent at construction
|
||||
5. Sends two user messages sequentially
|
||||
6. Prints results (or policy block messages)
|
||||
|
||||
Prompt blocks set a system-level message: `Prompt blocked by policy` and terminate the run early. Response blocks rewrite the output to `Response blocked by policy`.
|
||||
|
||||
---
|
||||
|
||||
## 5. Code Snippet (Middleware Injection)
|
||||
|
||||
```python
|
||||
agent = ChatAgent(
|
||||
chat_client=chat_client,
|
||||
instructions="You are good at telling jokes.",
|
||||
name="Joker",
|
||||
middleware=[
|
||||
PurviewPolicyMiddleware(credential, PurviewSettings(app_name="Sample App", default_user_id="<guid>"))
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
@@ -0,0 +1,179 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
"""Purview policy enforcement sample (Python).
|
||||
|
||||
Shows:
|
||||
1. Creating a basic chat agent
|
||||
2. Adding Purview policy evaluation via AGENT middleware (agent-level)
|
||||
3. Adding Purview policy evaluation via CHAT middleware (chat-client level)
|
||||
4. Running a threaded conversation and printing results
|
||||
|
||||
Environment variables:
|
||||
- AZURE_OPENAI_ENDPOINT (required)
|
||||
- AZURE_OPENAI_DEPLOYMENT_NAME (optional, defaults to gpt-4o-mini)
|
||||
- PURVIEW_CLIENT_APP_ID (required)
|
||||
- PURVIEW_USE_CERT_AUTH (optional, set to "true" for certificate auth)
|
||||
- PURVIEW_TENANT_ID (required if certificate auth)
|
||||
- PURVIEW_CERT_PATH (required if certificate auth)
|
||||
- PURVIEW_CERT_PASSWORD (optional)
|
||||
- PURVIEW_DEFAULT_USER_ID (optional, user ID for Purview evaluation)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import AgentRunResponse, ChatAgent, ChatMessage, Role
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
from azure.identity import (
|
||||
AzureCliCredential,
|
||||
CertificateCredential,
|
||||
InteractiveBrowserCredential,
|
||||
)
|
||||
|
||||
# Purview integration pieces
|
||||
from agent_framework.microsoft import (
|
||||
PurviewPolicyMiddleware,
|
||||
PurviewChatPolicyMiddleware,
|
||||
PurviewSettings,
|
||||
)
|
||||
|
||||
JOKER_NAME = "Joker"
|
||||
JOKER_INSTRUCTIONS = "You are good at telling jokes. Keep responses concise."
|
||||
|
||||
|
||||
def _get_env(name: str, *, required: bool = True, default: str | None = None) -> str:
|
||||
val = os.environ.get(name, default)
|
||||
if required and not val:
|
||||
raise RuntimeError(f"Environment variable {name} is required")
|
||||
return val # type: ignore[return-value]
|
||||
|
||||
|
||||
def build_credential() -> Any:
|
||||
"""Select an Azure credential for Purview authentication.
|
||||
|
||||
Supported modes:
|
||||
1. CertificateCredential (if PURVIEW_USE_CERT_AUTH=true)
|
||||
2. InteractiveBrowserCredential (requires PURVIEW_CLIENT_APP_ID)
|
||||
"""
|
||||
client_id = _get_env("PURVIEW_CLIENT_APP_ID", required=True)
|
||||
use_cert_auth = _get_env("PURVIEW_USE_CERT_AUTH", required=False, default="false").lower() == "true"
|
||||
|
||||
if not client_id:
|
||||
raise RuntimeError(
|
||||
"PURVIEW_CLIENT_APP_ID is required for interactive browser authentication; "
|
||||
"set PURVIEW_USE_CERT_AUTH=true for certificate mode instead."
|
||||
)
|
||||
|
||||
if use_cert_auth:
|
||||
tenant_id = _get_env("PURVIEW_TENANT_ID")
|
||||
cert_path = _get_env("PURVIEW_CERT_PATH")
|
||||
cert_password = _get_env("PURVIEW_CERT_PASSWORD", required=False, default=None)
|
||||
print(f"Using Certificate Authentication (tenant: {tenant_id}, cert: {cert_path})")
|
||||
return CertificateCredential(
|
||||
tenant_id=tenant_id,
|
||||
client_id=client_id,
|
||||
certificate_path=cert_path,
|
||||
password=cert_password,
|
||||
)
|
||||
|
||||
print(f"Using Interactive Browser Authentication (client_id: {client_id})")
|
||||
return InteractiveBrowserCredential(client_id=client_id)
|
||||
|
||||
|
||||
async def run_with_agent_middleware() -> None:
|
||||
endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
|
||||
if not endpoint:
|
||||
print("Skipping run: AZURE_OPENAI_ENDPOINT not set")
|
||||
return
|
||||
|
||||
deployment = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", "gpt-4o-mini")
|
||||
user_id = os.environ.get("PURVIEW_DEFAULT_USER_ID")
|
||||
chat_client = AzureOpenAIChatClient(deployment_name=deployment, endpoint=endpoint, credential=AzureCliCredential())
|
||||
|
||||
purview_agent_middleware = PurviewPolicyMiddleware(
|
||||
build_credential(),
|
||||
PurviewSettings(
|
||||
app_name="Agent Framework Sample App",
|
||||
),
|
||||
)
|
||||
|
||||
agent = ChatAgent(
|
||||
chat_client=chat_client,
|
||||
instructions=JOKER_INSTRUCTIONS,
|
||||
name=JOKER_NAME,
|
||||
middleware=purview_agent_middleware,
|
||||
)
|
||||
|
||||
print("-- Agent Middleware Path --")
|
||||
first: AgentRunResponse = await agent.run(ChatMessage(role=Role.USER, text="Tell me a joke about a pirate.", additional_properties={"user_id": user_id}))
|
||||
print("First response (agent middleware):\n", first)
|
||||
|
||||
second: AgentRunResponse = await agent.run(ChatMessage(role=Role.USER, text="That was funny. Tell me another one.", additional_properties={"user_id": user_id}))
|
||||
print("Second response (agent middleware):\n", second)
|
||||
|
||||
|
||||
async def run_with_chat_middleware() -> None:
|
||||
endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
|
||||
if not endpoint:
|
||||
print("Skipping chat middleware run: AZURE_OPENAI_ENDPOINT not set")
|
||||
return
|
||||
|
||||
deployment = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", default="gpt-4o-mini")
|
||||
user_id = os.environ.get("PURVIEW_DEFAULT_USER_ID")
|
||||
|
||||
chat_client = AzureOpenAIChatClient(
|
||||
deployment_name=deployment,
|
||||
endpoint=endpoint,
|
||||
credential=AzureCliCredential(),
|
||||
middleware=[
|
||||
PurviewChatPolicyMiddleware(
|
||||
build_credential(),
|
||||
PurviewSettings(
|
||||
app_name="Agent Framework Sample App (Chat)",
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
agent = ChatAgent(
|
||||
chat_client=chat_client,
|
||||
instructions=JOKER_INSTRUCTIONS,
|
||||
name=JOKER_NAME,
|
||||
)
|
||||
|
||||
print("-- Chat Middleware Path --")
|
||||
first: AgentRunResponse = await agent.run(
|
||||
ChatMessage(
|
||||
role=Role.USER,
|
||||
text="Give me a short clean joke.",
|
||||
additional_properties={"user_id": user_id},
|
||||
)
|
||||
)
|
||||
print("First response (chat middleware):\n", first)
|
||||
|
||||
second: AgentRunResponse = await agent.run(
|
||||
ChatMessage(
|
||||
role=Role.USER,
|
||||
text="One more please.",
|
||||
additional_properties={"user_id": user_id},
|
||||
)
|
||||
)
|
||||
print("Second response (chat middleware):\n", second)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
print("== Purview Agent Sample (Agent & Chat Middleware) ==")
|
||||
try:
|
||||
await run_with_agent_middleware()
|
||||
except Exception as ex: # pragma: no cover - demo resilience
|
||||
print(f"Agent middleware path failed: {ex}")
|
||||
|
||||
try:
|
||||
await run_with_chat_middleware()
|
||||
except Exception as ex: # pragma: no cover - demo resilience
|
||||
print(f"Chat middleware path failed: {ex}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Generated
+18
@@ -31,6 +31,7 @@ members = [
|
||||
"agent-framework-devui",
|
||||
"agent-framework-lab",
|
||||
"agent-framework-mem0",
|
||||
"agent-framework-purview",
|
||||
"agent-framework-redis",
|
||||
]
|
||||
|
||||
@@ -383,6 +384,23 @@ requires-dist = [
|
||||
{ name = "mem0ai", specifier = ">=0.1.117" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agent-framework-purview"
|
||||
version = "0.1.0b1"
|
||||
source = { editable = "packages/purview" }
|
||||
dependencies = [
|
||||
{ name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "azure-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "agent-framework-core", editable = "packages/core" },
|
||||
{ name = "azure-core", specifier = ">=1.30.0" },
|
||||
{ name = "httpx", specifier = ">=0.27.0" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agent-framework-redis"
|
||||
version = "1.0.0b251007"
|
||||
|
||||
Reference in New Issue
Block a user