From 59da578902eb86932d110f2ad9c6a799a5f856a8 Mon Sep 17 00:00:00 2001 From: Rishabh Chawla Date: Thu, 16 Oct 2025 14:46:04 -0700 Subject: [PATCH] Python: Add Purview Middleware (#1142) * [Py Purview] Purview Python Initial Commit * [Py Purview] Purview Python Minor Fixes * [Py Purview] Purview Python Comment Fixesish * [Py Purview] Purview Python Agent Middleware Done * [Py Purview] Purview Python Agent Middleware Done * [Py Purview] Purview Python Lint Errors * [Py Purview] Purview Python Final Hopefully * [Py Purview] Purview Python Final Hopefully * [Py Purview] Purview Python Fix ReadMe * [Py Purview] Purview Python Fix MyPy * [Py Purview] Purview Python Minor Updates on comments * [Py Purview] Purview Python Fix Build Error --------- Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> --- .../core/agent_framework/azure/_shared.py | 2 +- .../agent_framework/microsoft/__init__.py | 14 +- .../agent_framework/microsoft/__init__.pyi | 26 +- python/packages/purview/LICENSE | 21 + python/packages/purview/README.md | 224 ++++ .../agent_framework_purview/__init__.py | 22 + .../agent_framework_purview/_client.py | 126 +++ .../agent_framework_purview/_exceptions.py | 29 + .../agent_framework_purview/_middleware.py | 168 +++ .../agent_framework_purview/_models.py | 992 ++++++++++++++++++ .../agent_framework_purview/_processor.py | 251 +++++ .../agent_framework_purview/_settings.py | 71 ++ python/packages/purview/pyproject.toml | 87 ++ python/packages/purview/tests/conftest.py | 68 ++ .../purview/tests/test_chat_middleware.py | 149 +++ python/packages/purview/tests/test_client.py | 238 +++++ .../packages/purview/tests/test_exceptions.py | 38 + .../packages/purview/tests/test_middleware.py | 201 ++++ python/packages/purview/tests/test_models.py | 246 +++++ .../packages/purview/tests/test_processor.py | 369 +++++++ .../packages/purview/tests/test_settings.py | 85 ++ .../getting_started/purview_agent/README.md | 88 ++ 
.../purview_agent/sample_purview_agent.py | 179 ++++ python/uv.lock | 18 + 24 files changed, 3707 insertions(+), 5 deletions(-) create mode 100644 python/packages/purview/LICENSE create mode 100644 python/packages/purview/README.md create mode 100644 python/packages/purview/agent_framework_purview/__init__.py create mode 100644 python/packages/purview/agent_framework_purview/_client.py create mode 100644 python/packages/purview/agent_framework_purview/_exceptions.py create mode 100644 python/packages/purview/agent_framework_purview/_middleware.py create mode 100644 python/packages/purview/agent_framework_purview/_models.py create mode 100644 python/packages/purview/agent_framework_purview/_processor.py create mode 100644 python/packages/purview/agent_framework_purview/_settings.py create mode 100644 python/packages/purview/pyproject.toml create mode 100644 python/packages/purview/tests/conftest.py create mode 100644 python/packages/purview/tests/test_chat_middleware.py create mode 100644 python/packages/purview/tests/test_client.py create mode 100644 python/packages/purview/tests/test_exceptions.py create mode 100644 python/packages/purview/tests/test_middleware.py create mode 100644 python/packages/purview/tests/test_models.py create mode 100644 python/packages/purview/tests/test_processor.py create mode 100644 python/packages/purview/tests/test_settings.py create mode 100644 python/samples/getting_started/purview_agent/README.md create mode 100644 python/samples/getting_started/purview_agent/sample_purview_agent.py diff --git a/python/packages/core/agent_framework/azure/_shared.py b/python/packages/core/agent_framework/azure/_shared.py index 093a4086f1..e3eb37b26e 100644 --- a/python/packages/core/agent_framework/azure/_shared.py +++ b/python/packages/core/agent_framework/azure/_shared.py @@ -243,4 +243,4 @@ class AzureOpenAIConfigMixin(OpenAIBase): def_headers = None self.default_headers = def_headers - super().__init__(model_id=deployment_name, client=client) + 
super().__init__(model_id=deployment_name, client=client, **kwargs) diff --git a/python/packages/core/agent_framework/microsoft/__init__.py b/python/packages/core/agent_framework/microsoft/__init__.py index 55022357bf..2874488829 100644 --- a/python/packages/core/agent_framework/microsoft/__init__.py +++ b/python/packages/core/agent_framework/microsoft/__init__.py @@ -3,12 +3,20 @@ import importlib from typing import Any -PACKAGE_NAME = "agent_framework_copilotstudio" -PACKAGE_EXTRA = ["microsoft-copilotstudio", "copilotstudio"] _IMPORTS: dict[str, tuple[str, list[str]]] = { "CopilotStudioAgent": ("agent_framework_copilotstudio", ["microsoft-copilotstudio", "copilotstudio"]), "__version__": ("agent_framework_copilotstudio", ["microsoft-copilotstudio", "copilotstudio"]), "acquire_token": ("agent_framework_copilotstudio", ["microsoft-copilotstudio", "copilotstudio"]), + # Purview (Graph Data Security & Governance) integration exports + "PurviewPolicyMiddleware": ("agent_framework_purview", ["microsoft-purview", "purview"]), + "PurviewChatPolicyMiddleware": ("agent_framework_purview", ["microsoft-purview", "purview"]), + "PurviewSettings": ("agent_framework_purview", ["microsoft-purview", "purview"]), + "PurviewAppLocation": ("agent_framework_purview", ["microsoft-purview", "purview"]), + "PurviewLocationType": ("agent_framework_purview", ["microsoft-purview", "purview"]), + "PurviewAuthenticationError": ("agent_framework_purview", ["microsoft-purview", "purview"]), + "PurviewRateLimitError": ("agent_framework_purview", ["microsoft-purview", "purview"]), + "PurviewRequestError": ("agent_framework_purview", ["microsoft-purview", "purview"]), + "PurviewServiceError": ("agent_framework_purview", ["microsoft-purview", "purview"]), } @@ -23,7 +31,7 @@ def __getattr__(name: str) -> Any: f"please use `pip install agent-framework-{package_extra[0]}`, " "or update your requirements.txt or pyproject.toml file." 
) from exc - raise AttributeError(f"Module `azure` has no attribute {name}.") + raise AttributeError(f"Module `microsoft` has no attribute {name}.") def __dir__() -> list[str]: diff --git a/python/packages/core/agent_framework/microsoft/__init__.pyi b/python/packages/core/agent_framework/microsoft/__init__.pyi index c30a7b3031..99ba2af489 100644 --- a/python/packages/core/agent_framework/microsoft/__init__.pyi +++ b/python/packages/core/agent_framework/microsoft/__init__.pyi @@ -1,5 +1,29 @@ # Copyright (c) Microsoft. All rights reserved. from agent_framework_copilotstudio import CopilotStudioAgent, __version__, acquire_token +from agent_framework_purview import ( + PurviewAppLocation, + PurviewAuthenticationError, + PurviewChatPolicyMiddleware, + PurviewLocationType, + PurviewPolicyMiddleware, + PurviewRateLimitError, + PurviewRequestError, + PurviewServiceError, + PurviewSettings, +) -__all__ = ["CopilotStudioAgent", "__version__", "acquire_token"] +__all__ = [ + "CopilotStudioAgent", + "PurviewAppLocation", + "PurviewAuthenticationError", + "PurviewChatPolicyMiddleware", + "PurviewLocationType", + "PurviewPolicyMiddleware", + "PurviewRateLimitError", + "PurviewRequestError", + "PurviewServiceError", + "PurviewSettings", + "__version__", + "acquire_token", +] diff --git a/python/packages/purview/LICENSE b/python/packages/purview/LICENSE new file mode 100644 index 0000000000..9e841e7a26 --- /dev/null +++ b/python/packages/purview/LICENSE @@ -0,0 +1,21 @@ + MIT License + + Copyright (c) Microsoft Corporation. 
+ + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE diff --git a/python/packages/purview/README.md b/python/packages/purview/README.md new file mode 100644 index 0000000000..b42213b4e5 --- /dev/null +++ b/python/packages/purview/README.md @@ -0,0 +1,224 @@ +## Microsoft Agent Framework – Purview Integration (Python) + +`agent-framework-purview` adds Microsoft Purview (Microsoft Graph dataSecurityAndGovernance) policy evaluation to the Microsoft Agent Framework. It lets you enforce data security / governance policies on both the *prompt* (user input + conversation history) and the *model response* before they proceed further in your workflow. 
+ +> Status: **Preview** + +### Key Features + +- Middleware-based policy enforcement (agent-level and chat-client level) +- Blocks or allows content at both ingress (prompt) and egress (response) +- Works with any `ChatAgent` / agent orchestration using the standard Agent Framework middleware pipeline +- Supports both synchronous `TokenCredential` and `AsyncTokenCredential` from `azure-identity` +- Simple, typed configuration via `PurviewSettings` / `PurviewAppLocation` +- Two middleware types: + - `PurviewPolicyMiddleware` (Agent pipeline) + - `PurviewChatPolicyMiddleware` (Chat client middleware list) + +### When to Use +Add Purview when you need to: + +- Prevent sensitive or disallowed content from being sent to an LLM +- Prevent model output containing disallowed data from leaving the system +- Apply centrally managed policies without rewriting agent logic + +--- + +## Quick Start + +```python +import asyncio +from agent_framework import ChatAgent, ChatMessage, Role +from agent_framework.azure import AzureOpenAIChatClient +from agent_framework.microsoft import PurviewPolicyMiddleware, PurviewSettings +from azure.identity import InteractiveBrowserCredential + +async def main(): + chat_client = AzureOpenAIChatClient() # uses environment for endpoint + deployment + + purview_middleware = PurviewPolicyMiddleware( + credential=InteractiveBrowserCredential(), + settings=PurviewSettings(app_name="My Sample App") + ) + + agent = ChatAgent( + chat_client=chat_client, + instructions="You are a helpful assistant.", + middleware=[purview_middleware] + ) + + response = await agent.run(ChatMessage(role=Role.USER, text="Summarize zero trust in one sentence.")) + print(response) + +asyncio.run(main()) +``` + +If a policy violation is detected on the prompt, the middleware terminates the run and substitutes a system message: `"Prompt blocked by policy"`. If on the response, the result becomes `"Response blocked by policy"`. 
+ +--- + +## Authentication + +`PurviewClient` uses the `azure-identity` library for token acquisition. You can use any `TokenCredential` or `AsyncTokenCredential` implementation. + +The APIs require the following Microsoft Graph permissions: +- ProtectionScopes.Compute.All: [userProtectionScopeContainer](https://learn.microsoft.com/en-us/graph/api/userprotectionscopecontainer-compute) +- Content.Process.All: [processContent](https://learn.microsoft.com/en-us/graph/api/userdatasecurityandgovernance-processcontent) +- ContentActivity.Write: [contentActivity](https://learn.microsoft.com/en-us/graph/api/activitiescontainer-post-contentactivities) + +### Scopes +`PurviewSettings.get_scopes()` derives the Graph scope list (currently `https://graph.microsoft.com/.default` style). + +### Tenant Enablement for Purview +- The tenant requires an E5 license and consumptive billing setup. +- A [Data Loss Prevention](https://learn.microsoft.com/en-us/purview/dlp-create-deploy-policy) or [Data Collection](https://learn.microsoft.com/en-us/purview/collection-policies-policy-reference) policy must apply to the user for the Process Content API to be called; otherwise the Content Activities API is called instead to audit the message. 
+ +--- + +## Configuration + +### `PurviewSettings` + +```python +PurviewSettings( + app_name="My App", # Display / logical name + tenant_id=None, # Optional – used mainly for auth context + purview_app_location=None, # Optional PurviewAppLocation for scoping + graph_base_uri="https://graph.microsoft.com/v1.0/", + process_inline=False, # Reserved for future inline processing optimizations + blocked_prompt_message="Prompt blocked by policy", # Custom message for blocked prompts + blocked_response_message="Response blocked by policy" # Custom message for blocked responses +) +``` + +To scope evaluation by location (application, URL, or domain): + +```python +from agent_framework.microsoft import ( + PurviewAppLocation, + PurviewLocationType, + PurviewSettings, +) + +settings = PurviewSettings( + app_name="Contoso Support", + purview_app_location=PurviewAppLocation( + location_type=PurviewLocationType.APPLICATION, + location_value="" + ) +) +``` + +### Customizing Blocked Messages + +By default, when Purview blocks a prompt or response, the middleware returns a generic system message. You can customize these messages by providing your own text in the `PurviewSettings`: + +```python +from agent_framework.microsoft import PurviewSettings + +settings = PurviewSettings( + app_name="My App", + blocked_prompt_message="Your request contains content that violates our policies. Please rephrase and try again.", + blocked_response_message="The response was blocked due to policy restrictions. Please contact support if you need assistance." 
+) +``` + +This is useful for: +- Providing more user-friendly error messages +- Including support contact information +- Localizing messages for different languages +- Adding branding or specific guidance for your application + +### Selecting Agent vs Chat Middleware + +Use the agent middleware when you already have / want the full agent pipeline: + +```python +from agent_framework import ChatAgent +from agent_framework.azure import AzureOpenAIChatClient +from agent_framework.microsoft import PurviewPolicyMiddleware, PurviewSettings +from azure.identity import DefaultAzureCredential + +credential = DefaultAzureCredential() +client = AzureOpenAIChatClient() + +agent = ChatAgent( + chat_client=client, + instructions="You are helpful.", + middleware=[PurviewPolicyMiddleware(credential, PurviewSettings(app_name="My App"))] +) +``` + +Use the chat middleware when you attach directly to a chat client (e.g. minimal agent shell or custom orchestration): + +```python +import os +from agent_framework import ChatAgent +from agent_framework.azure import AzureOpenAIChatClient +from agent_framework.microsoft import PurviewChatPolicyMiddleware, PurviewSettings +from azure.identity import DefaultAzureCredential + +credential = DefaultAzureCredential() + +chat_client = AzureOpenAIChatClient( + deployment_name=os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"], + endpoint=os.environ["AZURE_OPENAI_ENDPOINT"], + credential=credential, + middleware=[ + PurviewChatPolicyMiddleware(credential, PurviewSettings(app_name="My App (Chat)")) + ], +) + +agent = ChatAgent(chat_client=chat_client, instructions="You are helpful.") +``` + +The policy logic is identical; the difference is only the hook point in the pipeline. + +--- + +## Middleware Lifecycle + +1. Before agent execution (`prompt phase`): all `context.messages` are evaluated. +2. If blocked: `context.result` is replaced with a system message and `context.terminate = True`. +3. 
After successful agent execution (`response phase`): the produced messages are evaluated. +4. If blocked: result messages are replaced with a blocking notice. + +When a user identifier is discovered (e.g. in `ChatMessage.additional_properties['user_id']`) during the prompt phase it is reused for the response phase so both evaluations map consistently to the same user. + +You can customize the blocking messages using the `blocked_prompt_message` and `blocked_response_message` fields in `PurviewSettings`. For more advanced scenarios, you can wrap the middleware or post-process `context.result` in later middleware. + +--- + +## Exceptions + +| Exception | Scenario | +|-----------|----------| +| `PurviewAuthenticationError` | Token acquisition / validation issues | +| `PurviewRateLimitError` | 429 responses from service | +| `PurviewRequestError` | 4xx client errors (bad input, unauthorized, forbidden) | +| `PurviewServiceError` | 5xx or unexpected service errors | + +Catch broadly if you want unified fallback: + +```python +from agent_framework.microsoft import ( + PurviewAuthenticationError, PurviewRateLimitError, + PurviewRequestError, PurviewServiceError +) + +try: + ... +except (PurviewAuthenticationError, PurviewRateLimitError, PurviewRequestError, PurviewServiceError) as ex: + # Log / degrade gracefully + print(f"Purview enforcement skipped: {ex}") +``` + +--- + +## Notes +- Provide a `user_id` per request (e.g. in `ChatMessage(..., additional_properties={"user_id": ""})`) when possible for per-user policy scoping; otherwise supply a default via settings or environment. +- Blocking messages can be customized via `blocked_prompt_message` and `blocked_response_message` in `PurviewSettings`. By default, they are "Prompt blocked by policy" and "Response blocked by policy" respectively. +- Streaming responses: post-response policy evaluation presently applies only to non-streaming chat responses. 
+- Errors during policy checks are logged and do not fail the run; they degrade gracefully. + + diff --git a/python/packages/purview/agent_framework_purview/__init__.py b/python/packages/purview/agent_framework_purview/__init__.py new file mode 100644 index 0000000000..74e09164f7 --- /dev/null +++ b/python/packages/purview/agent_framework_purview/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Microsoft. All rights reserved. + +from ._exceptions import ( + PurviewAuthenticationError, + PurviewRateLimitError, + PurviewRequestError, + PurviewServiceError, +) +from ._middleware import PurviewChatPolicyMiddleware, PurviewPolicyMiddleware +from ._settings import PurviewAppLocation, PurviewLocationType, PurviewSettings + +__all__ = [ + "PurviewAppLocation", + "PurviewAuthenticationError", + "PurviewChatPolicyMiddleware", + "PurviewLocationType", + "PurviewPolicyMiddleware", + "PurviewRateLimitError", + "PurviewRequestError", + "PurviewServiceError", + "PurviewSettings", +] diff --git a/python/packages/purview/agent_framework_purview/_client.py b/python/packages/purview/agent_framework_purview/_client.py new file mode 100644 index 0000000000..8679a3bf84 --- /dev/null +++ b/python/packages/purview/agent_framework_purview/_client.py @@ -0,0 +1,126 @@ +# Copyright (c) Microsoft. All rights reserved. 
class PurviewClient:
    """Async client for calling Graph Purview endpoints.

    Supports both synchronous ``TokenCredential`` and asynchronous ``AsyncTokenCredential``
    implementations. A synchronous credential is invoked in a worker thread so that token
    acquisition does not block the event loop.
    """

    def __init__(
        self,
        credential: TokenCredential | AsyncTokenCredential,
        settings: PurviewSettings,
        *,
        timeout: float | None = 10.0,
    ):
        """Create the client.

        Args:
            credential: Sync or async credential used to obtain bearer tokens.
            settings: Purview settings; supplies the Graph base URI and the token scopes.
            timeout: Timeout passed to the underlying ``httpx.AsyncClient``.
        """
        self._credential: TokenCredential | AsyncTokenCredential = credential
        self._settings = settings
        # Strip the trailing slash once so URL construction can always join with "/".
        self._graph_uri = settings.graph_base_uri.rstrip("/")
        self._timeout = timeout
        self._client = httpx.AsyncClient(timeout=timeout)

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self._client.aclose()

    async def _get_token(self, *, tenant_id: str | None = None) -> str:
        """Acquire an access token using either an async or a sync credential.

        Args:
            tenant_id: Optional tenant forwarded to the credential's ``get_token``.

        Returns:
            The raw bearer token string.
        """
        # Local import: keeps this block self-contained without touching file-level imports.
        import asyncio

        scopes = self._settings.get_scopes()
        get_token = cast(Any, self._credential).get_token
        if inspect.iscoroutinefunction(get_token):
            token = await get_token(*scopes, tenant_id=tenant_id)
        else:
            # A sync credential may block on I/O, so run it off the event loop.
            token = await asyncio.to_thread(get_token, *scopes, tenant_id=tenant_id)
            # A plain callable can still hand back an awaitable (e.g. a wrapped async method).
            if inspect.isawaitable(token):
                token = await token
        return cast(str, token.token)

    @staticmethod
    def _extract_token_info(token: str) -> dict[str, Any]:
        """Decode the (unverified) JWT payload and pick out the identity claims.

        Args:
            token: A JWT in ``header.payload[.signature]`` form.

        Returns:
            A dict with ``user_id`` (the ``oid`` claim, only when ``idtyp`` is ``"user"``),
            ``tenant_id`` (``tid``) and ``client_id`` (``appid``, falling back to ``azp``).

        Raises:
            ValueError: If the token has fewer than two dot-separated segments.
        """
        parts = token.split(".")
        if len(parts) < 2:
            raise ValueError("Invalid JWT token format")
        payload = parts[1]
        # JWT segments drop base64 padding; restore it before decoding.
        rem = len(payload) % 4
        if rem:
            payload += "=" * (4 - rem)
        decoded = base64.urlsafe_b64decode(payload)
        data = json.loads(decoded.decode("utf-8"))
        return {
            "user_id": data.get("oid") if data.get("idtyp") == "user" else None,
            "tenant_id": data.get("tid"),
            # v1.0 tokens carry the client id in "appid"; v2.0 tokens use "azp" instead.
            "client_id": data.get("appid") or data.get("azp"),
        }

    async def get_user_info_from_token(self, *, tenant_id: str | None = None) -> dict[str, Any]:
        """Acquire a token and return the identity claims extracted from it."""
        token = await self._get_token(tenant_id=tenant_id)
        return self._extract_token_info(token)

    async def process_content(self, request: ProcessContentRequest) -> ProcessContentResponse:
        """POST ``request`` to the user's ``processContent`` endpoint."""
        with get_tracer().start_as_current_span("purview.process_content"):
            token = await self._get_token(tenant_id=request.tenant_id)
            url = f"{self._graph_uri}/users/{request.user_id}/dataSecurityAndGovernance/processContent"
            return cast(ProcessContentResponse, await self._post(url, request, ProcessContentResponse, token))

    async def get_protection_scopes(self, request: ProtectionScopesRequest) -> ProtectionScopesResponse:
        """POST ``request`` to the user's ``protectionScopes/compute`` endpoint."""
        with get_tracer().start_as_current_span("purview.get_protection_scopes"):
            token = await self._get_token()
            url = f"{self._graph_uri}/users/{request.user_id}/dataSecurityAndGovernance/protectionScopes/compute"
            return cast(ProtectionScopesResponse, await self._post(url, request, ProtectionScopesResponse, token))

    async def send_content_activities(self, request: ContentActivitiesRequest) -> ContentActivitiesResponse:
        """POST ``request`` to the user's ``activities/contentActivities`` endpoint."""
        with get_tracer().start_as_current_span("purview.send_content_activities"):
            token = await self._get_token()
            url = f"{self._graph_uri}/users/{request.user_id}/dataSecurityAndGovernance/activities/contentActivities"
            return cast(ContentActivitiesResponse, await self._post(url, request, ContentActivitiesResponse, token))

    async def _post(self, url: str, model: Any, response_type: type[Any], token: str) -> Any:
        """Serialize ``model``, POST it to ``url`` and deserialize the reply.

        Args:
            url: Fully qualified endpoint URL.
            model: Request object exposing ``model_dump``.
            response_type: Class used to build the response object.
            token: Bearer token for the ``Authorization`` header.

        Raises:
            PurviewAuthenticationError: On HTTP 401 or 403.
            PurviewRateLimitError: On HTTP 429.
            PurviewRequestError: On any other status outside 200/201/202.
            PurviewServiceError: If the response body cannot be turned into ``response_type``.
        """
        payload = model.model_dump(by_alias=True, exclude_none=True, mode="json")
        headers = {
            "Authorization": f"Bearer {token}",
            "User-Agent": AGENT_FRAMEWORK_USER_AGENT,
            "Content-Type": "application/json",
        }
        resp = await self._client.post(url, json=payload, headers=headers)
        if resp.status_code in (401, 403):
            raise PurviewAuthenticationError(f"Auth failure {resp.status_code}: {resp.text}")
        if resp.status_code == 429:
            raise PurviewRateLimitError(f"Rate limited {resp.status_code}: {resp.text}")
        if resp.status_code not in (200, 201, 202):
            raise PurviewRequestError(f"Purview request failed {resp.status_code}: {resp.text}")
        try:
            data = resp.json()
        except ValueError:
            # A success status with a non-JSON (or empty) body is treated as an empty payload.
            data = {}
        try:
            # Prefer pydantic-style model_validate if present, else fall back to constructor.
            if hasattr(response_type, "model_validate"):
                return response_type.model_validate(data)  # type: ignore[no-any-return]
            return response_type(**data)  # type: ignore[call-arg, no-any-return]
        except Exception as ex:  # pragma: no cover
            raise PurviewServiceError(f"Failed to deserialize Purview response: {ex}") from ex
class PurviewPolicyMiddleware(AgentMiddleware):
    """Agent middleware that enforces Purview policies on prompt and response.

    Accepts either a synchronous TokenCredential or an AsyncTokenCredential.

    Usage:

    .. code-block:: python

        from agent_framework.microsoft import PurviewPolicyMiddleware, PurviewSettings
        from agent_framework import ChatAgent

        credential = ...  # TokenCredential or AsyncTokenCredential
        settings = PurviewSettings(app_name="My App")
        agent = ChatAgent(
            chat_client=client, instructions="...", middleware=[PurviewPolicyMiddleware(credential, settings)]
        )
    """

    def __init__(
        self,
        credential: TokenCredential | AsyncTokenCredential,
        settings: PurviewSettings,
    ) -> None:
        self._client = PurviewClient(credential, settings)
        self._processor = ScopedContentProcessor(self._client, settings)
        self._settings = settings

    async def process(
        self,
        context: AgentRunContext,
        next: Callable[[AgentRunContext], Awaitable[None]],
    ) -> None:  # type: ignore[override]
        """Evaluate the prompt, run the rest of the pipeline, then evaluate the response.

        A blocked prompt replaces ``context.result`` with the configured blocked-prompt
        message, sets ``context.terminate`` and skips ``next``. A blocked response replaces
        ``context.result`` with the configured blocked-response message. Errors raised by
        either check are logged and do not fail the run.
        """
        resolved_user_id: str | None = None
        try:
            # Pre (prompt) check
            should_block_prompt, resolved_user_id = await self._processor.process_messages(
                context.messages, Activity.UPLOAD_TEXT
            )
            if should_block_prompt:
                from agent_framework import AgentRunResponse, ChatMessage, Role

                context.result = AgentRunResponse(
                    messages=[ChatMessage(role=Role.SYSTEM, text=self._settings.blocked_prompt_message)]
                )
                context.terminate = True
                return
        except Exception as ex:
            # Log and continue if there's an error in the pre-check
            logger.error(f"Error in Purview policy pre-check: {ex}")

        await next(context)

        try:
            # Post (response) check only if we have a normal, non-streaming AgentRunResponse.
            if context.is_streaming:
                logger.debug("Streaming responses are not supported for Purview policy post-checks")
                return
            if not context.result:
                # Distinguish "nothing to evaluate" from the streaming case in the logs.
                logger.debug("No agent result available for Purview policy post-check")
                return

            # Use the same user_id from the request for the response evaluation
            should_block_response, _ = await self._processor.process_messages(
                context.result.messages,  # type: ignore[union-attr]
                Activity.UPLOAD_TEXT,
                user_id=resolved_user_id,
            )
            if should_block_response:
                from agent_framework import AgentRunResponse, ChatMessage, Role

                context.result = AgentRunResponse(
                    messages=[ChatMessage(role=Role.SYSTEM, text=self._settings.blocked_response_message)]
                )
        except Exception as ex:
            # Log and continue if there's an error in the post-check
            logger.error(f"Error in Purview policy post-check: {ex}")


class PurviewChatPolicyMiddleware(ChatMiddleware):
    """Chat middleware variant for Purview policy evaluation.

    This allows users to attach Purview enforcement directly to a chat client.

    Behavior:
        * Pre-chat: evaluates outgoing (user + context) messages as an upload activity
          and can terminate execution if blocked.
        * Post-chat: evaluates the received response messages (streaming is not presently supported)
          and can replace them with a blocked message. Uses the same user_id from the request
          to ensure consistent user identity throughout the evaluation.

    Usage:

    .. code-block:: python

        from agent_framework.microsoft import PurviewChatPolicyMiddleware, PurviewSettings
        from agent_framework import ChatClient

        credential = ...  # TokenCredential or AsyncTokenCredential
        settings = PurviewSettings(app_name="My App")
        client = ChatClient(..., middleware=[PurviewChatPolicyMiddleware(credential, settings)])
    """

    def __init__(
        self,
        credential: TokenCredential | AsyncTokenCredential,
        settings: PurviewSettings,
    ) -> None:
        self._client = PurviewClient(credential, settings)
        self._processor = ScopedContentProcessor(self._client, settings)
        self._settings = settings

    async def process(
        self,
        context: ChatContext,
        next: Callable[[ChatContext], Awaitable[None]],
    ) -> None:  # type: ignore[override]
        """Evaluate the outgoing messages, call the chat client, then evaluate the response.

        A blocked prompt sets ``context.result`` to a single system message, sets
        ``context.terminate`` and skips ``next``. A blocked response replaces
        ``context.result`` with a single system message. Errors raised by either check
        are logged and do not fail the call.
        """
        resolved_user_id: str | None = None
        try:
            should_block_prompt, resolved_user_id = await self._processor.process_messages(
                context.messages, Activity.UPLOAD_TEXT
            )
            if should_block_prompt:
                from agent_framework import ChatMessage

                context.result = [  # type: ignore[assignment]
                    ChatMessage(role="system", text=self._settings.blocked_prompt_message)
                ]
                context.terminate = True
                return
        except Exception as ex:
            logger.error(f"Error in Purview policy pre-check: {ex}")

        await next(context)

        try:
            # Post (response) evaluation only if non-streaming and we have a messages result shape.
            if context.is_streaming:
                logger.debug("Streaming responses are not supported for Purview policy post-checks")
                return
            if not context.result:
                # Distinguish "nothing to evaluate" from the streaming case in the logs.
                logger.debug("No chat result available for Purview policy post-check")
                return

            messages = getattr(context.result, "messages", None)
            if not messages:
                # Result has no ``messages`` to evaluate (e.g. a bare list); nothing to check.
                return

            # Use the same user_id from the request for the response evaluation
            should_block_response, _ = await self._processor.process_messages(
                messages, Activity.UPLOAD_TEXT, user_id=resolved_user_id
            )
            if should_block_response:
                from agent_framework import ChatMessage

                context.result = [  # type: ignore[assignment]
                    ChatMessage(role="system", text=self._settings.blocked_response_message)
                ]
        except Exception as ex:
            logger.error(f"Error in Purview policy post-check: {ex}")
+ +"""Unified Purview model definitions and public export surface.""" + +from __future__ import annotations + +from collections.abc import Mapping, MutableMapping, Sequence +from datetime import datetime +from enum import Enum, Flag, auto +from typing import Any, ClassVar, TypeVar, cast +from uuid import uuid4 + +from agent_framework._logging import get_logger +from agent_framework._serialization import SerializationMixin + +logger = get_logger("agent_framework.purview") + +# -------------------------------------------------------------------------------------- +# Enums & flag helpers +# -------------------------------------------------------------------------------------- + + +class Activity(str, Enum): + """High-level activity types representing user or agent operations.""" + + UNKNOWN = "unknown" + UPLOAD_TEXT = "uploadText" + UPLOAD_FILE = "uploadFile" + DOWNLOAD_TEXT = "downloadText" + DOWNLOAD_FILE = "downloadFile" + + +class ProtectionScopeActivities(Flag): + """Flag enumeration of activities used in policy protection scopes.""" + + NONE = 0 + UPLOAD_TEXT = auto() + UPLOAD_FILE = auto() + DOWNLOAD_TEXT = auto() + DOWNLOAD_FILE = auto() + UNKNOWN_FUTURE_VALUE = auto() + + def __int__(self) -> int: # pragma: no cover + return self.value + + +FlagT = TypeVar("FlagT", bound=Flag) + +_PROTECTION_SCOPE_ACTIVITIES_MAP: dict[str, ProtectionScopeActivities] = { + "none": ProtectionScopeActivities.NONE, + "uploadText": ProtectionScopeActivities.UPLOAD_TEXT, + "uploadFile": ProtectionScopeActivities.UPLOAD_FILE, + "downloadText": ProtectionScopeActivities.DOWNLOAD_TEXT, + "downloadFile": ProtectionScopeActivities.DOWNLOAD_FILE, + "unknownFutureValue": ProtectionScopeActivities.UNKNOWN_FUTURE_VALUE, +} +_PROTECTION_SCOPE_ACTIVITIES_SERIALIZE_ORDER: list[tuple[str, ProtectionScopeActivities]] = [ + ("uploadText", ProtectionScopeActivities.UPLOAD_TEXT), + ("uploadFile", ProtectionScopeActivities.UPLOAD_FILE), + ("downloadText", ProtectionScopeActivities.DOWNLOAD_TEXT), + 
def deserialize_flag(
    value: object, mapping: Mapping[str, FlagT], enum_cls: type[FlagT]
) -> FlagT | None:  # pragma: no cover
    """Deserialize arbitrary input into a flag enum instance.

    Args:
        value: ``None``, an ``enum_cls`` member, an int bit mask, a
            comma-separated string of serialized names, or a
            list/tuple/set/frozenset mixing names, members and int bit masks.
        mapping: Serialized name -> flag member lookup. Names that are not in
            the mapping are ignored.
        enum_cls: The flag enum type to produce.

    Returns:
        The combined flag value, or ``None`` when ``value`` is ``None``, is an
        int that ``enum_cls`` rejects, or has an unsupported type.
    """
    if value is None:
        return None
    if isinstance(value, enum_cls):
        return value
    if isinstance(value, int):
        try:
            return enum_cls(value)
        except (ValueError, TypeError):
            # The enum rejected the bit mask; treat it as undecodable input.
            return None

    flag_value = enum_cls(0)
    tokens: list[str] = []

    if isinstance(value, str):
        if not value.strip():
            return enum_cls(0)
        tokens.extend(piece.strip() for piece in value.split(",") if piece.strip())
    elif isinstance(value, (list, tuple, set, frozenset)):
        for item in value:
            if isinstance(item, str):
                # A single item may itself be a comma-separated list of names.
                tokens.extend(piece.strip() for piece in item.split(",") if piece.strip())
            elif isinstance(item, enum_cls):
                flag_value |= item
            elif isinstance(item, int):
                try:
                    flag_value |= enum_cls(item)
                except (ValueError, TypeError):
                    # Best effort: skip the bad element but keep the rest.
                    logger.warning("Failed to convert int %s to %s", item, enum_cls.__name__)
    else:
        return None

    for token in tokens:
        member = mapping.get(token)
        if member is not None:
            flag_value |= member

    if flag_value == enum_cls(0):
        # Prefer the mapping's explicit "none" member for an empty result.
        none_member = mapping.get("none")
        if none_member is not None:
            return none_member  # type: ignore[return-value,index]
    return flag_value
def translate_activity(activity: Activity) -> ProtectionScopeActivities:
    """Map a high-level ``Activity`` to its ``ProtectionScopeActivities`` flag.

    ``Activity.UNKNOWN`` maps to ``NONE``; a value with no entry in the lookup
    table maps to ``UNKNOWN_FUTURE_VALUE``.
    """
    scopes = ProtectionScopeActivities
    table = {
        Activity.UNKNOWN: scopes.NONE,
        Activity.UPLOAD_TEXT: scopes.UPLOAD_TEXT,
        Activity.UPLOAD_FILE: scopes.UPLOAD_FILE,
        Activity.DOWNLOAD_TEXT: scopes.DOWNLOAD_TEXT,
        Activity.DOWNLOAD_FILE: scopes.DOWNLOAD_FILE,
    }
    try:
        return table[activity]
    except KeyError:
        return scopes.UNKNOWN_FUTURE_VALUE
def model_dump(self, *, by_alias: bool = True, exclude_none: bool = True, **_: Any) -> dict[str, Any]:
    """Return the model as a plain dict (pydantic-style compatibility helper).

    Args:
        by_alias: When true (the default) the keys are exactly those produced
            by ``to_dict``, i.e. external/serialized names. When false,
            top-level aliased keys are mapped back to internal attribute names.
        exclude_none: Forwarded to ``to_dict``.
        **_: Additional pydantic-style options; accepted and ignored.

    Returns:
        The serialized dictionary. Only top-level keys are re-mapped when
        ``by_alias`` is false; nested values are returned as ``to_dict``
        produced them.
    """
    dumped = self.to_dict(exclude_none=exclude_none)
    if by_alias:
        return dumped

    # ``to_dict`` applies aliases declared anywhere in the class hierarchy, so
    # the reverse map has to be collected the same way. Looking only at
    # ``self._ALIASES`` would miss aliases a subclass inherits after overriding
    # ``_ALIASES`` (for example an inherited "@odata.type").
    forward: dict[str, str] = {}
    for klass in type(self).__mro__:
        declared = getattr(klass, "_ALIASES", None)
        if isinstance(declared, dict):
            for internal, external in declared.items():
                # Most-derived class wins, matching ``to_dict``.
                forward.setdefault(internal, external)

    if not forward:
        return dumped

    reverse = {external: internal for internal, external in forward.items()}
    return {reverse.get(key, key): value for key, value in dumped.items()}
def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]:  # type: ignore[override]
    """Serialize via ``SerializationMixin.to_dict`` and emit aliased keys.

    The auto-generated ``type`` entry is dropped when ``DEFAULT_EXCLUDE``
    lists it. Every other key is replaced by its external alias when one is
    declared in ``_ALIASES`` on this class or any class in its MRO; a
    remaining ``type`` key is always kept under that name.
    """
    raw = SerializationMixin.to_dict(self, exclude=exclude, exclude_none=exclude_none)

    # For Graph API models, remove the auto-generated 'type' field if excluded.
    if "type" in self.DEFAULT_EXCLUDE:
        raw.pop("type", None)

    # Walk the MRO from the most generic class down so that a more derived
    # declaration overwrites an inherited one for the same internal name.
    alias_for: dict[str, str] = {}
    for klass in reversed(type(self).__mro__):
        declared = getattr(klass, "_ALIASES", None)
        if isinstance(declared, dict):
            alias_for.update(declared)

    if not alias_for:
        return raw

    # 'type' is reserved and never renamed.
    return {(key if key == "type" else alias_for.get(key, key)): value for key, value in raw.items()}
class DeviceMetadata(_AliasSerializable):
    """Device details: an IP address and operating-system specifications.

    Each argument may be supplied under its snake_case name or, through
    ``kwargs``, under its camelCase alias; the alias takes precedence when
    both are given. A mapping passed for the operating-system specifications
    is converted to ``OperatingSystemSpecifications``.
    """

    _ALIASES: ClassVar[dict[str, str]] = {
        "ip_address": "ipAddress",
        "operating_system_specifications": "operatingSystemSpecifications",
    }

    def __init__(
        self,
        ip_address: str | None = None,
        operating_system_specifications: OperatingSystemSpecifications | MutableMapping[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        # Alias keys in kwargs override the named parameters.
        address = kwargs.get("ipAddress", ip_address)
        os_specs = kwargs.get("operatingSystemSpecifications", operating_system_specifications)

        # Convert nested objects
        if isinstance(os_specs, MutableMapping):
            os_specs = OperatingSystemSpecifications(**os_specs)

        # The parent receives only the leftover kwargs; the resolved values
        # are assigned afterwards so they are what ends up on the instance.
        super().__init__(**kwargs)
        self.ip_address = address
        self.operating_system_specifications = os_specs
class ProtectedAppMetadata(_AliasSerializable):
    """Name, version and policy location of the protected application.

    ``application_location`` may also arrive in ``kwargs`` under its
    ``applicationLocation`` alias, which takes precedence. A mapping value is
    converted to ``PolicyLocation``.
    """

    _ALIASES: ClassVar[dict[str, str]] = {"application_location": "applicationLocation"}

    def __init__(
        self,
        name: str | None = None,
        version: str | None = None,
        application_location: PolicyLocation | MutableMapping[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        # Alias key in kwargs overrides the named parameter.
        location = kwargs.get("applicationLocation", application_location)

        # Convert nested objects
        location = PolicyLocation(**location) if isinstance(location, MutableMapping) else location

        # Parent sees only the leftover kwargs; resolved values are set below.
        super().__init__(**kwargs)
        self.name, self.version = name, version
        self.application_location = location  # type: ignore[assignment]
class GraphDataTypeBase(_AliasSerializable):
    """Base for Graph models that carry an ``@odata.type`` discriminator.

    The discriminator is stored as ``data_type`` and serialized under the
    ``@odata.type`` alias. It may be supplied under either name; when the
    alias key is present in ``kwargs`` with a non-``None`` value it takes
    precedence over ``data_type``, so a subclass default does not shadow the
    value found in an incoming payload.
    """

    _ALIASES: ClassVar[dict[str, str]] = {"data_type": "@odata.type"}
    # Exclude the auto-generated 'type' field - Graph API uses @odata.type instead
    DEFAULT_EXCLUDE: ClassVar[set[str]] = {"type"}

    def __init__(self, data_type: str, **kwargs: Any) -> None:
        # Pull the alias out of kwargs before delegating. Forwarding both
        # "data_type" and "@odata.type" would leave the alias un-normalized
        # and set as a stray attribute literally named "@odata.type".
        odata_type = kwargs.pop("@odata.type", None)
        if odata_type is not None:
            data_type = odata_type

        super().__init__(data_type=data_type, **kwargs)
        self.data_type = data_type
class PurviewBinaryContent(ContentBase):
    """Binary content payload whose ``data`` is serialized as base64 text."""

    def __init__(self, data: bytes, data_type: str = "microsoft.graph.binaryContent", **kwargs: Any) -> None:
        super().__init__(data_type=data_type, **kwargs)
        self.data = data

    def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]:  # type: ignore[override]
        """Serialize the model, emitting ``data`` as a base64 string.

        Args:
            exclude: Field names to omit. ``data`` is left out when listed here.
            exclude_none: Forwarded to the parent serializer.

        Returns:
            The parent's dictionary with ``data`` replaced by its base64
            encoding; missing or empty data is encoded as an empty string.
        """
        import base64

        base = super().to_dict(exclude=exclude, exclude_none=exclude_none)

        # Honour an explicit request to omit the payload instead of
        # unconditionally re-adding it below.
        if exclude and "data" in exclude:
            base.pop("data", None)
            return base

        # Ensure bytes encoded as base64 string like pydantic
        data_bytes = getattr(self, "data", b"") or b""
        base["data"] = base64.b64encode(data_bytes).decode("utf-8")
        return base
+ # Extract aliased values from kwargs + if "correlationId" in kwargs: + correlation_id = kwargs["correlationId"] + if "sequenceNumber" in kwargs: + sequence_number = kwargs["sequenceNumber"] + if "isTruncated" in kwargs: + is_truncated = kwargs["isTruncated"] + if "createdDateTime" in kwargs: + created_date_time = kwargs["createdDateTime"] + if "modifiedDateTime" in kwargs: + modified_date_time = kwargs["modifiedDateTime"] + if "parentMessageId" in kwargs: + parent_message_id = kwargs["parentMessageId"] + if "accessedResources_v2" in kwargs: + accessed_resources = kwargs["accessedResources_v2"] + + # Convert nested objects + if isinstance(content, MutableMapping): + # determine by type? fall back to text content + c_type = content.get("@odata.type") or content.get("data_type") + if c_type and "binary" in str(c_type): + content = PurviewBinaryContent(**content) # type: ignore[arg-type] + else: + content = PurviewTextContent(**content) # type: ignore[arg-type] + accessed_list: list[AccessedResourceDetails] | None = None + if accessed_resources: + accessed_list = [ + ar if isinstance(ar, AccessedResourceDetails) else AccessedResourceDetails(**ar) + for ar in accessed_resources + ] + plugin_list: list[AiInteractionPlugin] | None = None + if plugins: + plugin_list = [p if isinstance(p, AiInteractionPlugin) else AiInteractionPlugin(**p) for p in plugins] + agent_list: list[AiAgentInfo] | None = None + if agents: + agent_list = [a if isinstance(a, AiAgentInfo) else AiAgentInfo(**a) for a in agents] + + # Call parent without explicit params with aliases + super().__init__(data_type=data_type, **kwargs) + self.identifier = identifier + self.content = content # type: ignore[assignment] + self.name = name + self.correlation_id = correlation_id + self.sequence_number = sequence_number + self.length = length + self.is_truncated = is_truncated + self.created_date_time = created_date_time + self.modified_date_time = modified_date_time + self.parent_message_id = parent_message_id 
class ContentToProcess(_AliasSerializable):
    """Content entries plus the activity, device and app metadata sent with them.

    Every argument may be given under its snake_case name or, through
    ``kwargs``, under its camelCase alias; the alias takes precedence when
    both are present. Mapping values are converted to the matching model
    class, and each content entry that is not already a
    ``ProcessConversationMetadata`` is built from its mapping.
    """

    _ALIASES: ClassVar[dict[str, str]] = {
        "content_entries": "contentEntries",
        "activity_metadata": "activityMetadata",
        "device_metadata": "deviceMetadata",
        "integrated_app_metadata": "integratedAppMetadata",
        "protected_app_metadata": "protectedAppMetadata",
    }

    def __init__(
        self,
        content_entries: list[ProcessConversationMetadata | MutableMapping[str, Any]],
        activity_metadata: ActivityMetadata | MutableMapping[str, Any],
        device_metadata: DeviceMetadata | MutableMapping[str, Any],
        integrated_app_metadata: IntegratedAppMetadata | MutableMapping[str, Any],
        protected_app_metadata: ProtectedAppMetadata | MutableMapping[str, Any],
        **kwargs: Any,
    ) -> None:
        def _as_model(candidate: Any, model: type) -> Any:
            # Build ``model`` from a mapping; pass anything else through untouched.
            return model(**candidate) if isinstance(candidate, MutableMapping) else candidate

        # Alias keys in kwargs override the named parameters.
        raw_entries = kwargs.get("contentEntries", content_entries)
        raw_activity = kwargs.get("activityMetadata", activity_metadata)
        raw_device = kwargs.get("deviceMetadata", device_metadata)
        raw_integrated = kwargs.get("integratedAppMetadata", integrated_app_metadata)
        raw_protected = kwargs.get("protectedAppMetadata", protected_app_metadata)

        # Convert nested objects
        entries = []
        for entry in raw_entries:
            if not isinstance(entry, ProcessConversationMetadata):
                entry = ProcessConversationMetadata(**entry)
            entries.append(entry)
        activity = _as_model(raw_activity, ActivityMetadata)
        device = _as_model(raw_device, DeviceMetadata)
        integrated = _as_model(raw_integrated, IntegratedAppMetadata)
        protected = _as_model(raw_protected, ProtectedAppMetadata)

        # Parent sees only the leftover kwargs; converted values are set below.
        super().__init__(**kwargs)
        self.content_entries = entries
        self.activity_metadata = activity
        self.device_metadata = device
        self.integrated_app_metadata = integrated
        self.protected_app_metadata = protected
def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]:  # type: ignore[override]
    """Serialize the request, writing ``activities`` as a comma-separated string.

    Args:
        exclude: Field names to omit. ``activities`` is left out when listed here.
        exclude_none: When false, an unset ``activities`` is emitted as ``None``.

    Returns:
        The parent's dictionary with ``activities`` added via ``serialize_flag``.
    """
    # Get base dict (activities will be missing because Flag isn't JSON-serializable)
    base = super().to_dict(exclude=exclude, exclude_none=exclude_none)

    # Respect an explicit exclusion rather than re-adding the field by hand.
    if exclude and "activities" in exclude:
        base.pop("activities", None)
        return base

    if self.activities is not None:
        base["activities"] = serialize_flag(self.activities, _PROTECTION_SCOPE_ACTIVITIES_SERIALIZE_ORDER)
    elif not exclude_none:
        base["activities"] = None

    return base
class ProcessingError(_AliasSerializable):
    """A processing error entry carrying an optional ``message``."""

    def __init__(self, message: str | None = None, **kwargs: Any) -> None:
        # ``message`` goes first so it is handed to the parent ahead of any extras.
        init_args: dict[str, Any] = {"message": message}
        init_args.update(kwargs)
        super().__init__(**init_args)
        self.message = message
def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]:  # type: ignore[override]
    """Serialize the policy scope, writing ``activities`` as a comma-separated string.

    Args:
        exclude: Field names to omit. ``activities`` is left out when listed here.
        exclude_none: When false, an unset ``activities`` is emitted as ``None``.

    Returns:
        The parent's dictionary with ``activities`` added via ``serialize_flag``.
    """
    # Get base dict (activities will be missing because Flag isn't JSON-serializable)
    base = super().to_dict(exclude=exclude, exclude_none=exclude_none)

    # Respect an explicit exclusion rather than re-adding the field by hand.
    if exclude and "activities" in exclude:
        base.pop("activities", None)
        return base

    if self.activities is not None:
        base["activities"] = serialize_flag(self.activities, _PROTECTION_SCOPE_ACTIVITIES_SERIALIZE_ORDER)
    elif not exclude_none:
        base["activities"] = None

    return base
self.status_code = status_code + self.error = error # type: ignore[assignment] + + +__all__ = [ + "AccessedResourceDetails", + "Activity", + "ActivityMetadata", + "AiAgentInfo", + "AiInteractionPlugin", + "ContentActivitiesRequest", + "ContentActivitiesResponse", + "ContentBase", + "ContentToProcess", + "DeviceMetadata", + "DlpAction", + "DlpActionInfo", + "ExecutionMode", + "GraphDataTypeBase", + "IntegratedAppMetadata", + "OperatingSystemSpecifications", + "PolicyLocation", + "PolicyPivotProperty", + "PolicyScope", + "ProcessContentRequest", + "ProcessContentResponse", + "ProcessConversationMetadata", + "ProcessingError", + "ProtectedAppMetadata", + "ProtectionScopeActivities", + "ProtectionScopeState", + "ProtectionScopesRequest", + "ProtectionScopesResponse", + "PurviewBinaryContent", + "PurviewTextContent", + "RestrictionAction", + "deserialize_flag", + "serialize_flag", + "translate_activity", +] diff --git a/python/packages/purview/agent_framework_purview/_processor.py b/python/packages/purview/agent_framework_purview/_processor.py new file mode 100644 index 0000000000..81367dd58b --- /dev/null +++ b/python/packages/purview/agent_framework_purview/_processor.py @@ -0,0 +1,251 @@ +# Copyright (c) Microsoft. All rights reserved. 
+from __future__ import annotations + +import uuid +from collections.abc import Iterable, MutableMapping +from typing import Any + +from agent_framework import ChatMessage + +from ._client import PurviewClient +from ._models import ( + Activity, + ActivityMetadata, + ContentActivitiesRequest, + ContentToProcess, + DeviceMetadata, + DlpAction, + DlpActionInfo, + IntegratedAppMetadata, + OperatingSystemSpecifications, + PolicyLocation, + ProcessContentRequest, + ProcessContentResponse, + ProcessConversationMetadata, + ProcessingError, + ProtectedAppMetadata, + ProtectionScopesRequest, + ProtectionScopesResponse, + PurviewTextContent, + RestrictionAction, + translate_activity, +) +from ._settings import PurviewSettings + + +def _is_valid_guid(value: str | None) -> bool: + """Check if a string is a valid GUID/UUID format using uuid module.""" + if not value: + return False + try: + uuid.UUID(value) + return True + except (ValueError, AttributeError): + return False + + +class ScopedContentProcessor: + """Combine protection scopes, process content, and content activities logic.""" + + def __init__(self, client: PurviewClient, settings: PurviewSettings): + self._client = client + self._settings = settings + + async def process_messages( + self, messages: Iterable[ChatMessage], activity: Activity, user_id: str | None = None + ) -> tuple[bool, str | None]: + """Process messages for policy evaluation. + + Args: + messages: The messages to process + activity: The activity type (e.g., UPLOAD_TEXT) + user_id: Optional user_id to use for all messages. If provided, this is the fallback. + + Returns: + A tuple of (should_block: bool, resolved_user_id: str | None). + The resolved_user_id can be stored and passed back when processing the response + to ensure the same user context is maintained throughout the request/response cycle. 
+ """ + pc_requests, resolved_user_id = await self._map_messages(messages, activity, user_id) + should_block = False + for req in pc_requests: + resp = await self._process_with_scopes(req) + if resp.policy_actions: + for act in resp.policy_actions: + if act.action == DlpAction.BLOCK_ACCESS or act.restriction_action == RestrictionAction.BLOCK: + should_block = True + break + if should_block: + break + return should_block, resolved_user_id + + async def _map_messages( + self, messages: Iterable[ChatMessage], activity: Activity, provided_user_id: str | None = None + ) -> tuple[list[ProcessContentRequest], str | None]: + """Map messages to ProcessContentRequests. + + Args: + messages: The messages to map + activity: The activity type + provided_user_id: Optional user_id to use. If provided, this is the fallback. + + Returns: + A tuple of (requests, resolved_user_id) + """ + results: list[ProcessContentRequest] = [] + token_info = None + + if not (self._settings.tenant_id and self._settings.purview_app_location): + token_info = await self._client.get_user_info_from_token(tenant_id=self._settings.tenant_id) + + tenant_id = (token_info or {}).get("tenant_id") or self._settings.tenant_id + if not tenant_id or not _is_valid_guid(tenant_id): + raise ValueError("Tenant id required or must be inferable from credential") + + resolved_user_id = (token_info or {}).get("user_id") + resolved_author_name = None + if not resolved_user_id: + for m in messages: + if m.additional_properties: + potential_user_id = m.additional_properties.get("user_id") + if _is_valid_guid(potential_user_id): + resolved_user_id = potential_user_id + break + if m.author_name and _is_valid_guid(m.author_name) and not resolved_author_name: + resolved_author_name = m.author_name + + if not resolved_user_id and resolved_author_name: + resolved_user_id = resolved_author_name + + if not resolved_user_id: + resolved_user_id = provided_user_id if provided_user_id and _is_valid_guid(provided_user_id) else None + + 
# Return empty results if user_id is empty + if not resolved_user_id or not _is_valid_guid(resolved_user_id): + return results, None + + for m in messages: + message_id = m.message_id or str(uuid.uuid4()) + content = PurviewTextContent(data=m.text or "") + meta = ProcessConversationMetadata( + identifier=message_id, + content=content, + name=f"Agent Framework Message {message_id}", + is_truncated=False, + correlation_id=str(uuid.uuid4()), + ) + activity_meta = ActivityMetadata(activity=activity) + + if self._settings.purview_app_location: + policy_location = PolicyLocation( + data_type=self._settings.purview_app_location.get_policy_location()["@odata.type"], + value=self._settings.purview_app_location.location_value, + ) + elif token_info and token_info.get("client_id"): + policy_location = PolicyLocation( + data_type="microsoft.graph.policyLocationApplication", + value=token_info["client_id"], + ) + else: + raise ValueError("App location not provided or inferable") + + protected_app = ProtectedAppMetadata( + name=self._settings.app_name, + version="1.0", + application_location=policy_location, + ) + integrated_app = IntegratedAppMetadata(name=self._settings.app_name, version="1.0") + device_meta = DeviceMetadata( + operating_system_specifications=OperatingSystemSpecifications( + operating_system_platform="Unknown", operating_system_version="Unknown" + ) + ) + + ctp = ContentToProcess( + content_entries=[meta], + activity_metadata=activity_meta, + device_metadata=device_meta, + integrated_app_metadata=integrated_app, + protected_app_metadata=protected_app, + ) + req = ProcessContentRequest( + content_to_process=ctp, + user_id=resolved_user_id, # Use the resolved user_id for all messages + tenant_id=tenant_id, + correlation_id=meta.correlation_id, + process_inline=True if self._settings.process_inline else None, + ) + results.append(req) + return results, resolved_user_id + + async def _process_with_scopes(self, pc_request: ProcessContentRequest) -> 
ProcessContentResponse: + app_location = pc_request.content_to_process.protected_app_metadata.application_location + locations: list[PolicyLocation | MutableMapping[str, Any]] = [app_location] if app_location is not None else [] + + ps_req = ProtectionScopesRequest( + user_id=pc_request.user_id, + tenant_id=pc_request.tenant_id, + activities=translate_activity(pc_request.content_to_process.activity_metadata.activity), + locations=locations, + device_metadata=pc_request.content_to_process.device_metadata, + integrated_app_metadata=pc_request.content_to_process.integrated_app_metadata, + correlation_id=pc_request.correlation_id, + ) + ps_resp = await self._client.get_protection_scopes(ps_req) + should_process, dlp_actions = self._check_applicable_scopes(pc_request, ps_resp) + + if should_process: + pc_resp = await self._client.process_content(pc_request) + pc_resp.policy_actions = self._combine_policy_actions(pc_resp.policy_actions, dlp_actions) + return pc_resp + ca_req = ContentActivitiesRequest( + user_id=pc_request.user_id, + tenant_id=pc_request.tenant_id, + content_to_process=pc_request.content_to_process, + correlation_id=pc_request.correlation_id, + ) + ca_resp = await self._client.send_content_activities(ca_req) + if ca_resp.error: + return ProcessContentResponse(processing_errors=[ProcessingError(message=str(ca_resp.error))]) + return ProcessContentResponse() + + @staticmethod + def _combine_policy_actions( + existing: list[DlpActionInfo] | None, new_actions: list[DlpActionInfo] + ) -> list[DlpActionInfo]: + by_key: dict[str, DlpActionInfo] = {} + for a in existing or []: + if a.action: + by_key[a.action] = a + for a in new_actions: + if a.action: + by_key[a.action] = a + return list(by_key.values()) + + @staticmethod + def _check_applicable_scopes( + pc_request: ProcessContentRequest, ps_response: ProtectionScopesResponse + ) -> tuple[bool, list[DlpActionInfo]]: + req_activity = translate_activity(pc_request.content_to_process.activity_metadata.activity) + 
location = pc_request.content_to_process.protected_app_metadata.application_location + should_process: bool = False + dlp_actions: list[DlpActionInfo] = [] + for scope in ps_response.scopes or []: + # Check if all activities in req_activity are present in scope.activities using bitwise flags. + activity_match = bool(scope.activities and (scope.activities & req_activity) == req_activity) + location_match = False + if location is not None: + for loc in scope.locations or []: + if ( + loc.data_type + and location.data_type + and loc.data_type.lower().endswith(location.data_type.split(".")[-1].lower()) + and loc.value == location.value + ): + location_match = True + break + if activity_match and location_match: + should_process = True + if scope.policy_actions: + dlp_actions.extend(scope.policy_actions) + return should_process, dlp_actions diff --git a/python/packages/purview/agent_framework_purview/_settings.py b/python/packages/purview/agent_framework_purview/_settings.py new file mode 100644 index 0000000000..6b7e8d12f4 --- /dev/null +++ b/python/packages/purview/agent_framework_purview/_settings.py @@ -0,0 +1,71 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +from __future__ import annotations + +from enum import Enum + +from agent_framework._pydantic import AFBaseSettings +from pydantic import BaseModel, Field +from pydantic_settings import SettingsConfigDict + + +class PurviewLocationType(str, Enum): + """The type of location for Purview policy evaluation.""" + + APPLICATION = "application" + URI = "uri" + DOMAIN = "domain" + + +class PurviewAppLocation(BaseModel): + """Identifier representing the app's location for Purview policy evaluation.""" + + location_type: PurviewLocationType = Field(..., description="The location type.") + location_value: str = Field(..., description="The location value.") + + def get_policy_location(self) -> dict[str, str]: + ns = "microsoft.graph" + if self.location_type == PurviewLocationType.APPLICATION: + dt = f"{ns}.policyLocationApplication" + elif self.location_type == PurviewLocationType.URI: + dt = f"{ns}.policyLocationUrl" + elif self.location_type == PurviewLocationType.DOMAIN: + dt = f"{ns}.policyLocationDomain" + else: # pragma: no cover - defensive + raise ValueError("Invalid Purview location type") + return {"@odata.type": dt, "value": self.location_value} + + +class PurviewSettings(AFBaseSettings): + """Settings for Purview integration mirroring .NET PurviewSettings. + + Attributes: + app_name: Public app name. + tenant_id: Optional tenant id (guid) of the user making the request. + purview_app_location: Optional app location for policy evaluation. + graph_base_uri: Base URI for Microsoft Graph. + blocked_prompt_message: Custom message to return when a prompt is blocked by policy. + blocked_response_message: Custom message to return when a response is blocked by policy. + """ + + app_name: str = Field(...) 
+ tenant_id: str | None = Field(default=None) + purview_app_location: PurviewAppLocation | None = Field(default=None) + graph_base_uri: str = Field(default="https://graph.microsoft.com/v1.0/") + process_inline: bool = Field(default=False, description="Process content inline if supported.") + blocked_prompt_message: str = Field( + default="Prompt blocked by policy", + description="Message to return when a prompt is blocked by policy.", + ) + blocked_response_message: str = Field( + default="Response blocked by policy", + description="Message to return when a response is blocked by policy.", + ) + + model_config = SettingsConfigDict(populate_by_name=True, validate_assignment=True) + + def get_scopes(self) -> list[str]: + from urllib.parse import urlparse + + host = urlparse(self.graph_base_uri).hostname or "graph.microsoft.com" + return [f"https://{host}/.default"] diff --git a/python/packages/purview/pyproject.toml b/python/packages/purview/pyproject.toml new file mode 100644 index 0000000000..7336ae0c66 --- /dev/null +++ b/python/packages/purview/pyproject.toml @@ -0,0 +1,87 @@ +[project] +name = "agent-framework-purview" +description = "Microsoft Purview (Graph dataSecurityAndGovernance) integration for Microsoft Agent Framework." 
+authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] +readme = "README.md" +requires-python = ">=3.10" +version = "0.1.0b1" +license-files = ["LICENSE"] +urls.homepage = "https://github.com/microsoft/agent-framework" +urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" +urls.release_notes = "https://github.com/microsoft/agent-framework/releases" +urls.issues = "https://github.com/microsoft/agent-framework/issues" +classifiers = [ + "License :: OSI Approved :: MIT License", + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Framework :: Pydantic :: 2", + "Typing :: Typed", +] +dependencies = [ + "agent-framework-core", + "azure-core>=1.30.0", + "httpx>=0.27.0", +] + +[tool.uv] +prerelease = "if-necessary-or-explicit" +environments = [ + "sys_platform == 'darwin'", + "sys_platform == 'linux'", + "sys_platform == 'win32'" +] + +[tool.uv-dynamic-versioning] +fallback-version = "0.0.0" +[tool.pytest.ini_options] +testpaths = 'tests' +addopts = "-ra -q -r fEX" +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" +filterwarnings = [] + +[tool.ruff] +extend = "../../pyproject.toml" + +[tool.coverage.run] +omit = [ + "**/__init__.py" +] + +[tool.pyright] +extend = "../../pyproject.toml" +exclude = ['tests'] + +[tool.mypy] +plugins = ['pydantic.mypy'] +strict = true +python_version = "3.10" +ignore_missing_imports = true +disallow_untyped_defs = true +no_implicit_optional = true +check_untyped_defs = true +warn_return_any = true +show_error_codes = true +warn_unused_ignores = false +disallow_incomplete_defs = true +disallow_untyped_decorators = true + +[tool.bandit] +targets = ["agent_framework_purview"] +exclude_dirs = ["tests"] + +[tool.poe] +executor.type = "uv" +include 
= "../../shared_tasks.toml" +[tool.poe.tasks] +mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_purview" +test = "pytest --cov=agent_framework_purview --cov-report=term-missing:skip-covered tests" + +[build-system] +requires = ["flit-core >= 3.9,<4.0"] +build-backend = "flit_core.buildapi" diff --git a/python/packages/purview/tests/conftest.py b/python/packages/purview/tests/conftest.py new file mode 100644 index 0000000000..dc9a7024e8 --- /dev/null +++ b/python/packages/purview/tests/conftest.py @@ -0,0 +1,68 @@ +# Copyright (c) Microsoft. All rights reserved. +"""Shared pytest fixtures for Purview tests.""" + +import pytest + +from agent_framework_purview._models import ( + Activity, + ActivityMetadata, + ContentToProcess, + DeviceMetadata, + IntegratedAppMetadata, + OperatingSystemSpecifications, + PolicyLocation, + ProcessContentRequest, + ProcessConversationMetadata, + ProtectedAppMetadata, + PurviewTextContent, +) + + +@pytest.fixture +def content_to_process_factory(): + """Factory fixture to create ContentToProcess objects with test data.""" + + def _create_content(text: str = "Test") -> ContentToProcess: + text_content = PurviewTextContent(data=text) + metadata = ProcessConversationMetadata( + identifier="msg-1", + content=text_content, + name="Test", + is_truncated=False, + ) + activity_meta = ActivityMetadata(activity=Activity.UPLOAD_TEXT) + device_meta = DeviceMetadata( + operating_system_specifications=OperatingSystemSpecifications( + operating_system_platform="Windows", operating_system_version="10" + ) + ) + integrated_app = IntegratedAppMetadata(name="App", version="1.0") + location = PolicyLocation(data_type="microsoft.graph.policyLocationApplication", value="app-id") + protected_app = ProtectedAppMetadata(name="Protected", version="1.0", application_location=location) + + return ContentToProcess( + content_entries=[metadata], + activity_metadata=activity_meta, + device_metadata=device_meta, + 
integrated_app_metadata=integrated_app, + protected_app_metadata=protected_app, + ) + + return _create_content + + +@pytest.fixture +def process_content_request_factory(content_to_process_factory): + """Factory fixture to create ProcessContentRequest objects with test data.""" + + def _create_request( + text: str = "Test", user_id: str = "user-123", tenant_id: str = "tenant-456" + ) -> ProcessContentRequest: + content = content_to_process_factory(text) + return ProcessContentRequest( + content_to_process=content, + user_id=user_id, + tenant_id=tenant_id, + ) + + return _create_request diff --git a/python/packages/purview/tests/test_chat_middleware.py b/python/packages/purview/tests/test_chat_middleware.py new file mode 100644 index 0000000000..7e4faa0236 --- /dev/null +++ b/python/packages/purview/tests/test_chat_middleware.py @@ -0,0 +1,149 @@ +# Copyright (c) Microsoft. All rights reserved. +"""Tests for Purview chat middleware.""" + +from dataclasses import dataclass +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from agent_framework import ChatContext, ChatMessage, Role +from azure.core.credentials import AccessToken + +from agent_framework_purview import PurviewChatPolicyMiddleware, PurviewSettings + + +@dataclass +class DummyChatClient: + name: str = "dummy" + + +class TestPurviewChatPolicyMiddleware: + @pytest.fixture + def mock_credential(self) -> AsyncMock: + credential = AsyncMock() + credential.get_token = AsyncMock(return_value=AccessToken("fake-token", 9999999999)) + return credential + + @pytest.fixture + def settings(self) -> PurviewSettings: + return PurviewSettings(app_name="Test App", tenant_id="test-tenant") + + @pytest.fixture + def middleware(self, mock_credential: AsyncMock, settings: PurviewSettings) -> PurviewChatPolicyMiddleware: + return PurviewChatPolicyMiddleware(mock_credential, settings) + + @pytest.fixture + def chat_context(self) -> ChatContext: + chat_client = DummyChatClient() + chat_options = MagicMock() 
+ chat_options.model = "test-model" + return ChatContext( + chat_client=chat_client, messages=[ChatMessage(role=Role.USER, text="Hello")], chat_options=chat_options + ) + + async def test_initialization(self, middleware: PurviewChatPolicyMiddleware) -> None: + assert middleware._client is not None + assert middleware._processor is not None + + async def test_allows_clean_prompt( + self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext + ) -> None: + with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: + next_called = False + + async def mock_next(ctx: ChatContext) -> None: + nonlocal next_called + next_called = True + + class Result: + def __init__(self): + self.messages = [ChatMessage(role=Role.ASSISTANT, text="Hi there")] + + ctx.result = Result() + + await middleware.process(chat_context, mock_next) + assert next_called + assert mock_proc.call_count == 2 + assert chat_context.result.messages[0].role == Role.ASSISTANT + + async def test_blocks_prompt(self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext) -> None: + with patch.object(middleware._processor, "process_messages", return_value=(True, "user-123")): + + async def mock_next(ctx: ChatContext) -> None: # should not run + raise AssertionError("next should not be called when prompt blocked") + + await middleware.process(chat_context, mock_next) + assert chat_context.terminate + assert chat_context.result + msg = chat_context.result[0] # type: ignore[index] + assert msg.role in ("system", Role.SYSTEM) + assert "blocked" in msg.text.lower() + + async def test_blocks_response(self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext) -> None: + call_state = {"count": 0} + + async def side_effect(messages, activity, user_id=None): + call_state["count"] += 1 + should_block = call_state["count"] == 2 + return (should_block, "user-123") + + with patch.object(middleware._processor, "process_messages", 
side_effect=side_effect): + + async def mock_next(ctx: ChatContext) -> None: + class Result: + def __init__(self): + self.messages = [ChatMessage(role=Role.ASSISTANT, text="Sensitive output")] # pragma: no cover + + ctx.result = Result() + + await middleware.process(chat_context, mock_next) + assert call_state["count"] == 2 + msgs = getattr(chat_context.result, "messages", None) or chat_context.result + first_msg = msgs[0] + assert first_msg.role in ("system", Role.SYSTEM) + assert "blocked" in first_msg.text.lower() + + async def test_streaming_skips_post_check(self, middleware: PurviewChatPolicyMiddleware) -> None: + chat_client = DummyChatClient() + chat_options = MagicMock() + chat_options.model = "test-model" + streaming_context = ChatContext( + chat_client=chat_client, + messages=[ChatMessage(role=Role.USER, text="Hello")], + chat_options=chat_options, + is_streaming=True, + ) + with patch.object(middleware._processor, "process_messages", return_value=False) as mock_proc: + + async def mock_next(ctx: ChatContext) -> None: + ctx.result = MagicMock() + + await middleware.process(streaming_context, mock_next) + assert mock_proc.call_count == 1 + + async def test_chat_middleware_handles_post_check_exception( + self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext + ) -> None: + """Test that exceptions in post-check are logged but don't affect result.""" + call_count = 0 + + async def mock_process_messages(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return (False, "user-123") # Pre-check succeeds + raise Exception("Post-check error") # Post-check fails + + with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): + + async def mock_next(ctx: ChatContext) -> None: + # Create a mock result with messages attribute + result = MagicMock() + result.messages = [ChatMessage(role=Role.ASSISTANT, text="Response")] + ctx.result = result + + await middleware.process(chat_context, 
mock_next) + + # Should have been called twice (pre and post) + assert call_count == 2 + # Result should still be set + assert chat_context.result is not None diff --git a/python/packages/purview/tests/test_client.py b/python/packages/purview/tests/test_client.py new file mode 100644 index 0000000000..3becc9b49f --- /dev/null +++ b/python/packages/purview/tests/test_client.py @@ -0,0 +1,238 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests for Purview client.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest +from azure.core.credentials import AccessToken + +from agent_framework_purview import PurviewSettings +from agent_framework_purview._client import PurviewClient +from agent_framework_purview._exceptions import ( + PurviewAuthenticationError, + PurviewRateLimitError, + PurviewRequestError, + PurviewServiceError, +) +from agent_framework_purview._models import ( + PolicyLocation, + ProcessContentRequest, + ProtectionScopesRequest, +) + + +class TestPurviewClient: + """Test PurviewClient functionality.""" + + @pytest.fixture + def mock_credential(self) -> MagicMock: + """Create a mock async credential.""" + from azure.core.credentials_async import AsyncTokenCredential + + credential = MagicMock(spec=AsyncTokenCredential) + mock_token = AccessToken("fake-token", 9999999999) + + async def mock_get_token(*args, **kwargs): + return mock_token + + credential.get_token = mock_get_token + return credential + + @pytest.fixture + def settings(self) -> PurviewSettings: + """Create test settings.""" + return PurviewSettings(app_name="Test App", tenant_id="test-tenant", default_user_id="test-user") + + @pytest.fixture + async def client(self, mock_credential: MagicMock, settings: PurviewSettings) -> PurviewClient: + """Create a PurviewClient with mock credential.""" + client = PurviewClient(mock_credential, settings, timeout=10.0) + yield client + await client.close() + + async def test_client_initialization(self, 
mock_credential: MagicMock, settings: PurviewSettings) -> None: + """Test PurviewClient initialization.""" + client = PurviewClient(mock_credential, settings) + + assert client._credential == mock_credential + assert client._settings == settings + assert client._graph_uri == "https://graph.microsoft.com/v1.0" + assert client._timeout == 10.0 + + await client.close() + + async def test_get_token_async_credential(self, client: PurviewClient, mock_credential: MagicMock) -> None: + """Test _get_token with async credential.""" + token = await client._get_token(tenant_id="test-tenant") + + assert token == "fake-token" + + async def test_get_token_sync_credential(self, settings: PurviewSettings) -> None: + """Test _get_token with sync credential.""" + sync_credential = MagicMock() + sync_credential.get_token = MagicMock(return_value=AccessToken("sync-token", 9999999999)) + + client = PurviewClient(sync_credential, settings) + + with patch("asyncio.get_running_loop") as mock_loop: + mock_executor = AsyncMock() + mock_executor.return_value = AccessToken("sync-token", 9999999999) + mock_loop.return_value.run_in_executor = mock_executor + + token = await client._get_token(tenant_id="test-tenant") + + assert token == "sync-token" + + await client.close() + + async def test_get_user_info_from_token(self, client: PurviewClient) -> None: + """Test get_user_info_from_token extracts user info.""" + import base64 + import json + + payload = {"tid": "test-tenant", "oid": "test-user", "idtyp": "user"} + payload_str = json.dumps(payload) + payload_bytes = payload_str.encode("utf-8") + payload_b64 = base64.urlsafe_b64encode(payload_bytes).decode("utf-8").rstrip("=") + fake_token = f"header.{payload_b64}.signature" + + with patch.object(client, "_get_token", return_value=fake_token): + user_info = await client.get_user_info_from_token(tenant_id="test-tenant") + + assert user_info["tenant_id"] == "test-tenant" + assert user_info["user_id"] == "test-user" + + @pytest.mark.parametrize( + 
"status_code,exception_type", + [ + (401, PurviewAuthenticationError), + (403, PurviewAuthenticationError), + (429, PurviewRateLimitError), + (400, PurviewRequestError), + (404, PurviewRequestError), + (500, PurviewServiceError), + (502, PurviewServiceError), + ], + ) + async def test_post_error_handling( + self, client: PurviewClient, content_to_process_factory, status_code: int, exception_type: type[Exception] + ) -> None: + """Test _post method handles different HTTP errors correctly.""" + from agent_framework_purview._models import ProcessContentResponse + + content = content_to_process_factory() + request = ProcessContentRequest( + content_to_process=content, + user_id="user-123", + tenant_id="tenant-456", + ) + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = status_code + mock_response.text = "Error message" + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Error", request=MagicMock(), response=mock_response + ) + + with patch.object(client._client, "post", return_value=mock_response), pytest.raises(exception_type): + await client._post( + "https://graph.microsoft.com/v1.0/test", + request, + ProcessContentResponse, + "fake-token", + ) + + async def test_process_content_success( + self, client: PurviewClient, content_to_process_factory, mock_credential: MagicMock + ) -> None: + """Test process_content method success path.""" + content = content_to_process_factory("Test message") + request = ProcessContentRequest( + content_to_process=content, + user_id="user-123", + tenant_id="tenant-456", + ) + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = {"id": "response-123", "protectionScopeState": "notModified"} + + with patch.object(client._client, "post", return_value=mock_response): + response = await client.process_content(request) + + assert response.id == "response-123" + assert response.protection_scope_state == "notModified" + + async 
def test_get_protection_scopes_success(self, client: PurviewClient) -> None: + """Test get_protection_scopes method success path.""" + location = PolicyLocation(**{"@odata.type": "microsoft.graph.policyLocationApplication", "value": "app-id"}) + request = ProtectionScopesRequest( + user_id="user-123", tenant_id="tenant-456", locations=[location], correlation_id="corr-789" + ) + + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = {"scopeIdentifier": "scope-123", "value": []} + + with patch.object(client._client, "post", return_value=mock_response): + response = await client.get_protection_scopes(request) + + assert response.scope_identifier == "scope-123" + assert response.scopes == [] + + async def test_client_close(self, mock_credential: AsyncMock, settings: PurviewSettings) -> None: + """Test client properly closes HTTP client.""" + client = PurviewClient(mock_credential, settings) + + with patch.object(client._client, "aclose", new_callable=AsyncMock) as mock_close: + await client.close() + mock_close.assert_called_once() + + async def test_invalid_jwt_token_format(self, client: PurviewClient) -> None: + """Test that invalid JWT token format raises ValueError.""" + with pytest.raises(ValueError, match="Invalid JWT token format"): + client._extract_token_info("invalid-token-without-dots") + + async def test_rate_limit_error(self, client: PurviewClient) -> None: + """Test that 429 status code raises PurviewRateLimitError.""" + request = ProcessContentRequest( + user_id="test-user", + tenant_id="test-tenant", + content_to_process=[], + correlation_id="test-correlation-id", + ) + + with ( + patch.object(client, "_get_token", return_value="fake-token"), + patch.object( + client._client, + "post", + return_value=httpx.Response(429, text="Rate limited", request=httpx.Request("POST", "http://test")), + ), + pytest.raises(PurviewRateLimitError, match="Rate limited"), + ): + await 
client.process_content(request) + + async def test_generic_request_error(self, client: PurviewClient) -> None: + """Test that a 500 response raises PurviewRequestError with the generic failure message.""" + request = ProcessContentRequest( + user_id="test-user", + tenant_id="test-tenant", + content_to_process=[], + correlation_id="test-correlation-id", + ) + + with ( + patch.object(client, "_get_token", return_value="fake-token"), + patch.object( + client._client, + "post", + return_value=httpx.Response( + 500, text="Internal server error", request=httpx.Request("POST", "http://test") + ), + ), + pytest.raises(PurviewRequestError, match="Purview request failed"), + ): + await client.process_content(request) diff --git a/python/packages/purview/tests/test_exceptions.py b/python/packages/purview/tests/test_exceptions.py new file mode 100644 index 0000000000..cb417d028d --- /dev/null +++ b/python/packages/purview/tests/test_exceptions.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests for Purview exceptions.""" + +from agent_framework_purview import ( + PurviewAuthenticationError, + PurviewRateLimitError, + PurviewRequestError, + PurviewServiceError, +) + + +class TestPurviewExceptions: + """Test custom Purview exception classes.""" + + def test_purview_service_error(self) -> None: + """Test PurviewServiceError base exception.""" + error = PurviewServiceError("Service error occurred") + assert str(error) == "Service error occurred" + assert isinstance(error, Exception) + + def test_purview_authentication_error(self) -> None: + """Test PurviewAuthenticationError exception.""" + error = PurviewAuthenticationError("Authentication failed") + assert str(error) == "Authentication failed" + assert isinstance(error, PurviewServiceError) + + def test_purview_rate_limit_error(self) -> None: + """Test PurviewRateLimitError exception.""" + error = PurviewRateLimitError("Rate limit exceeded") + assert str(error) == "Rate limit exceeded" + assert isinstance(error,
PurviewServiceError) + + def test_purview_request_error(self) -> None: + """Test PurviewRequestError exception.""" + error = PurviewRequestError("Request failed") + assert str(error) == "Request failed" + assert isinstance(error, PurviewServiceError) diff --git a/python/packages/purview/tests/test_middleware.py b/python/packages/purview/tests/test_middleware.py new file mode 100644 index 0000000000..8d31d25950 --- /dev/null +++ b/python/packages/purview/tests/test_middleware.py @@ -0,0 +1,201 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests for Purview middleware.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from agent_framework import AgentRunContext, AgentRunResponse, ChatMessage, Role +from azure.core.credentials import AccessToken + +from agent_framework_purview import PurviewPolicyMiddleware, PurviewSettings + + +class TestPurviewPolicyMiddleware: + """Test PurviewPolicyMiddleware functionality.""" + + @pytest.fixture + def mock_credential(self) -> AsyncMock: + """Create a mock async credential.""" + credential = AsyncMock() + credential.get_token = AsyncMock(return_value=AccessToken("fake-token", 9999999999)) + return credential + + @pytest.fixture + def settings(self) -> PurviewSettings: + """Create test settings.""" + return PurviewSettings(app_name="Test App", tenant_id="test-tenant") + + @pytest.fixture + def middleware(self, mock_credential: AsyncMock, settings: PurviewSettings) -> PurviewPolicyMiddleware: + """Create PurviewPolicyMiddleware instance.""" + return PurviewPolicyMiddleware(mock_credential, settings) + + @pytest.fixture + def mock_agent(self) -> MagicMock: + """Create a mock agent.""" + agent = MagicMock() + agent.name = "test-agent" + return agent + + def test_middleware_initialization(self, mock_credential: AsyncMock, settings: PurviewSettings) -> None: + """Test PurviewPolicyMiddleware initialization.""" + middleware = PurviewPolicyMiddleware(mock_credential, settings) + + assert 
middleware._client is not None + assert middleware._processor is not None + + async def test_middleware_allows_clean_prompt( + self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock + ) -> None: + """Test middleware allows prompt that passes policy check.""" + context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Hello, how are you?")]) + + with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")): + next_called = False + + async def mock_next(ctx: AgentRunContext) -> None: + nonlocal next_called + next_called = True + ctx.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="I'm good, thanks!")]) + + await middleware.process(context, mock_next) + + assert next_called + assert context.result is not None + assert not context.terminate + + async def test_middleware_blocks_prompt_on_policy_violation( + self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock + ) -> None: + """Test middleware blocks prompt that violates policy.""" + context = AgentRunContext( + agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Sensitive information")] + ) + + with patch.object(middleware._processor, "process_messages", return_value=(True, "user-123")): + next_called = False + + async def mock_next(ctx: AgentRunContext) -> None: + nonlocal next_called + next_called = True + + await middleware.process(context, mock_next) + + assert not next_called + assert context.result is not None + assert context.terminate + assert len(context.result.messages) == 1 + assert context.result.messages[0].role == Role.SYSTEM + assert "blocked by policy" in context.result.messages[0].text.lower() + + async def test_middleware_checks_response(self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock) -> None: + """Test middleware checks agent response for policy violations.""" + context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Hello")]) + + 
call_count = 0 + + async def mock_process_messages(messages, activity, user_id=None): + nonlocal call_count + call_count += 1 + should_block = call_count != 1 + return (should_block, "user-123") + + with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): + + async def mock_next(ctx: AgentRunContext) -> None: + ctx.result = AgentRunResponse( + messages=[ChatMessage(role=Role.ASSISTANT, text="Here's some sensitive information")] + ) + + await middleware.process(context, mock_next) + + assert call_count == 2 + assert context.result is not None + assert len(context.result.messages) == 1 + assert context.result.messages[0].role == Role.SYSTEM + assert "blocked by policy" in context.result.messages[0].text.lower() + + async def test_middleware_handles_result_without_messages( + self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock + ) -> None: + """Test middleware handles result that doesn't have messages attribute.""" + context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Hello")]) + + with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")): + + async def mock_next(ctx: AgentRunContext) -> None: + ctx.result = "Some non-standard result" + + await middleware.process(context, mock_next) + + assert context.result == "Some non-standard result" + + async def test_middleware_processor_receives_correct_activity( + self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock + ) -> None: + """Test middleware passes correct activity type to processor.""" + from agent_framework_purview._models import Activity + + context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Test")]) + + with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_process: + + async def mock_next(ctx: AgentRunContext) -> None: + ctx.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, 
text="Response")]) + + await middleware.process(context, mock_next) + + assert mock_process.call_count == 2 + for call in mock_process.call_args_list: + assert call[0][1] == Activity.UPLOAD_TEXT + + async def test_middleware_handles_pre_check_exception( + self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock + ) -> None: + """Test that exceptions in pre-check are logged but don't stop processing.""" + context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Test")]) + + with patch.object( + middleware._processor, "process_messages", side_effect=Exception("Pre-check error") + ) as mock_process: + + async def mock_next(ctx: AgentRunContext) -> None: + ctx.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Response")]) + + await middleware.process(context, mock_next) + + # Should have been called twice (pre-check raises, then post-check also raises) + assert mock_process.call_count == 2 + # Context should not be terminated + assert not context.terminate + # Result should be set by mock_next + assert context.result is not None + + async def test_middleware_handles_post_check_exception( + self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock + ) -> None: + """Test that exceptions in post-check are logged but don't affect result.""" + context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role=Role.USER, text="Test")]) + + call_count = 0 + + async def mock_process_messages(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return (False, "user-123") # Pre-check succeeds + raise Exception("Post-check error") # Post-check fails + + with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): + + async def mock_next(ctx: AgentRunContext) -> None: + ctx.result = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Response")]) + + await middleware.process(context, mock_next) + + # Should have been called twice 
(pre and post) + assert call_count == 2 + # Result should still be set + assert context.result is not None + assert hasattr(context.result, "messages") diff --git a/python/packages/purview/tests/test_models.py b/python/packages/purview/tests/test_models.py new file mode 100644 index 0000000000..aa4f62d79d --- /dev/null +++ b/python/packages/purview/tests/test_models.py @@ -0,0 +1,246 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests for Purview models and serialization.""" + +from agent_framework_purview._models import ( + Activity, + ActivityMetadata, + ContentToProcess, + DeviceMetadata, + IntegratedAppMetadata, + OperatingSystemSpecifications, + PolicyLocation, + ProcessContentRequest, + ProcessContentResponse, + ProcessConversationMetadata, + ProtectedAppMetadata, + ProtectionScopeActivities, + ProtectionScopesRequest, + ProtectionScopesResponse, + PurviewTextContent, + deserialize_flag, + serialize_flag, +) + + +class TestFlagOperations: + """Test flag serialization and deserialization operations.""" + + def test_protection_scope_activities_flag_combination(self) -> None: + """Test combining flags.""" + combined = ProtectionScopeActivities.UPLOAD_TEXT | ProtectionScopeActivities.UPLOAD_FILE + assert combined.value == 3 + assert ProtectionScopeActivities.UPLOAD_TEXT in combined + assert ProtectionScopeActivities.UPLOAD_FILE in combined + + def test_deserialize_flag_with_string(self) -> None: + """Test deserializing flag from comma-separated string.""" + mapping = { + "uploadText": ProtectionScopeActivities.UPLOAD_TEXT, + "uploadFile": ProtectionScopeActivities.UPLOAD_FILE, + } + + result = deserialize_flag("uploadText,uploadFile", mapping, ProtectionScopeActivities) + assert result is not None + assert ProtectionScopeActivities.UPLOAD_TEXT in result + assert ProtectionScopeActivities.UPLOAD_FILE in result + + def test_deserialize_flag_with_none(self) -> None: + """Test deserializing None returns None.""" + mapping = {"uploadText": 
ProtectionScopeActivities.UPLOAD_TEXT} + result = deserialize_flag(None, mapping, ProtectionScopeActivities) + assert result is None + + def test_serialize_flag_with_none(self) -> None: + """Test serializing None returns None.""" + result = serialize_flag(None, []) + assert result is None + + def test_serialize_flag_with_values(self) -> None: + """Test serializing flag with values.""" + flag = ProtectionScopeActivities.UPLOAD_TEXT | ProtectionScopeActivities.UPLOAD_FILE + ordered = [ + ("uploadText", ProtectionScopeActivities.UPLOAD_TEXT), + ("uploadFile", ProtectionScopeActivities.UPLOAD_FILE), + ] + result = serialize_flag(flag, ordered) + assert result == "uploadText,uploadFile" + + +class TestComplexModels: + """Test complex models with nested structures.""" + + def test_content_to_process_with_nested_structures(self) -> None: + """Test ContentToProcess with all nested structures.""" + text_content = PurviewTextContent(data="Test") + metadata = ProcessConversationMetadata( + identifier="msg-1", + content=text_content, + name="Test", + is_truncated=False, + ) + + activity_meta = ActivityMetadata(activity=Activity.UPLOAD_TEXT) + device_meta = DeviceMetadata( + operating_system_specifications=OperatingSystemSpecifications( + operating_system_platform="Windows", operating_system_version="10" + ) + ) + integrated_app = IntegratedAppMetadata(name="App", version="1.0") + location = PolicyLocation(data_type="microsoft.graph.policyLocationApplication", value="app-id") + protected_app = ProtectedAppMetadata(name="Protected", version="1.0", application_location=location) + + content = ContentToProcess( + content_entries=[metadata], + activity_metadata=activity_meta, + device_metadata=device_meta, + integrated_app_metadata=integrated_app, + protected_app_metadata=protected_app, + ) + + assert len(content.content_entries) == 1 + assert content.activity_metadata.activity == Activity.UPLOAD_TEXT + assert 
content.device_metadata.operating_system_specifications.operating_system_platform == "Windows" + assert content.integrated_app_metadata.name == "App" + assert content.protected_app_metadata.name == "Protected" + + +class TestRequestResponseSerialization: + """Test request/response serialization with aliases.""" + + def test_protection_scopes_request_serialization(self) -> None: + """Test ProtectionScopesRequest serializes activities correctly.""" + location = PolicyLocation(data_type="microsoft.graph.policyLocationApplication", value="app-id") + + request = ProtectionScopesRequest( + user_id="user-123", + tenant_id="tenant-456", + activities=ProtectionScopeActivities.UPLOAD_TEXT | ProtectionScopeActivities.UPLOAD_FILE, + locations=[location], + ) + + dumped = request.model_dump(by_alias=True, exclude_none=True, mode="json") + + assert "activities" in dumped + assert isinstance(dumped["activities"], str) + assert "uploadText" in dumped["activities"] + + +class TestModelDeserialization: + """Test model deserialization from API responses.""" + + def test_protection_scopes_response_deserialization(self) -> None: + """Test ProtectionScopesResponse deserializes 'value' to 'scopes'.""" + api_data = { + "scopeIdentifier": "scope-123", + "value": [ + { + "activities": "uploadText,downloadText", + "locations": [{"@odata.type": "location.type", "value": "/path"}], + "policyActions": [{"action": "warn", "restrictionAction": "blockAccess"}], + "executionMode": "evaluateInline", + } + ], + } + + response = ProtectionScopesResponse.model_validate(api_data) + + assert response.scope_identifier == "scope-123" + assert response.scopes is not None + assert len(response.scopes) == 1 + assert response.scopes[0].execution_mode == "evaluateInline" + + def test_process_content_response_deserialization(self) -> None: + """Test ProcessContentResponse deserializes aliased fields correctly.""" + api_data = { + "id": "response-123", + "protectionScopeState": "blocked", + "policyActions": 
[{"action": "block", "restrictionAction": "blockAccess"}], + } + + response = ProcessContentResponse.model_validate(api_data) + + assert response.id == "response-123" + assert response.protection_scope_state == "blocked" + assert len(response.policy_actions) == 1 + + def test_content_serialization_uses_aliases(self) -> None: + """Test ContentToProcess serializes with camelCase aliases.""" + text_content = PurviewTextContent(data="Test") + metadata = ProcessConversationMetadata( + identifier="msg-1", + content=text_content, + name="Test", + is_truncated=False, + ) + + activity_meta = ActivityMetadata(activity=Activity.UPLOAD_TEXT) + device_meta = DeviceMetadata( + operating_system_specifications=OperatingSystemSpecifications( + operating_system_platform="Windows", operating_system_version="10" + ) + ) + integrated_app = IntegratedAppMetadata(name="App", version="1.0") + location = PolicyLocation(data_type="microsoft.graph.policyLocationApplication", value="app-id") + protected_app = ProtectedAppMetadata(name="Protected", version="1.0", application_location=location) + + content = ContentToProcess( + content_entries=[metadata], + activity_metadata=activity_meta, + device_metadata=device_meta, + integrated_app_metadata=integrated_app, + protected_app_metadata=protected_app, + ) + + dumped = content.model_dump(by_alias=True, exclude_none=True, mode="json") + + assert "contentEntries" in dumped + assert "activityMetadata" in dumped + assert "deviceMetadata" in dumped + assert "integratedAppMetadata" in dumped + assert "protectedAppMetadata" in dumped + + def test_process_content_request_excludes_private_fields(self) -> None: + """Test ProcessContentRequest excludes private fields when serializing.""" + text_content = PurviewTextContent(data="Test") + metadata = ProcessConversationMetadata( + identifier="msg-1", + content=text_content, + name="Test", + is_truncated=False, + ) + + activity_meta = ActivityMetadata(activity=Activity.UPLOAD_TEXT) + device_meta = 
DeviceMetadata( + operating_system_specifications=OperatingSystemSpecifications( + operating_system_platform="Windows", operating_system_version="10" + ) + ) + integrated_app = IntegratedAppMetadata(name="App", version="1.0") + location = PolicyLocation(data_type="microsoft.graph.policyLocationApplication", value="app-id") + protected_app = ProtectedAppMetadata(name="Protected", version="1.0", application_location=location) + + content = ContentToProcess( + content_entries=[metadata], + activity_metadata=activity_meta, + device_metadata=device_meta, + integrated_app_metadata=integrated_app, + protected_app_metadata=protected_app, + ) + + request = ProcessContentRequest( + content_to_process=content, + user_id="user-123", + tenant_id="tenant-456", + correlation_id="corr-789", + ) + + dumped = request.model_dump(by_alias=True, exclude_none=True, mode="json") + + # Check that excluded fields are not present + assert "user_id" not in dumped + assert "tenant_id" not in dumped + assert "correlation_id" not in dumped + + # Check that content is present + assert "contentToProcess" in dumped diff --git a/python/packages/purview/tests/test_processor.py b/python/packages/purview/tests/test_processor.py new file mode 100644 index 0000000000..96554af126 --- /dev/null +++ b/python/packages/purview/tests/test_processor.py @@ -0,0 +1,369 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +"""Tests for Purview processor.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from agent_framework import ChatMessage, Role + +from agent_framework_purview import PurviewAppLocation, PurviewLocationType, PurviewSettings +from agent_framework_purview._models import ( + Activity, + DlpAction, + DlpActionInfo, + ProcessContentResponse, + RestrictionAction, +) +from agent_framework_purview._processor import ScopedContentProcessor, _is_valid_guid + + +class TestGuidValidation: + """Test GUID validation helper.""" + + def test_valid_guid(self) -> None: + """Test _is_valid_guid with valid GUIDs.""" + assert _is_valid_guid("12345678-1234-1234-1234-123456789012") + assert _is_valid_guid("a1b2c3d4-e5f6-4a5b-8c9d-0e1f2a3b4c5d") + + def test_invalid_guid(self) -> None: + """Test _is_valid_guid with invalid GUIDs.""" + assert not _is_valid_guid("not-a-guid") + assert not _is_valid_guid("") + assert not _is_valid_guid(None) + + +class TestScopedContentProcessor: + """Test ScopedContentProcessor functionality.""" + + @pytest.fixture + def mock_client(self) -> AsyncMock: + """Create a mock Purview client.""" + client = AsyncMock() + client.get_user_info_from_token = AsyncMock( + return_value={ + "tenant_id": "12345678-1234-1234-1234-123456789012", + "user_id": "12345678-1234-1234-1234-123456789012", + "client_id": "12345678-1234-1234-1234-123456789012", + } + ) + return client + + @pytest.fixture + def settings_with_defaults(self) -> PurviewSettings: + """Create settings with default values.""" + app_location = PurviewAppLocation( + location_type=PurviewLocationType.APPLICATION, location_value="12345678-1234-1234-1234-123456789012" + ) + return PurviewSettings( + app_name="Test App", + tenant_id="12345678-1234-1234-1234-123456789012", + purview_app_location=app_location, + ) + + @pytest.fixture + def settings_without_defaults(self) -> PurviewSettings: + """Create settings without default values (requiring token info).""" + return 
PurviewSettings(app_name="Test App") + + @pytest.fixture + def processor(self, mock_client: AsyncMock, settings_with_defaults: PurviewSettings) -> ScopedContentProcessor: + """Create a ScopedContentProcessor with mock client.""" + return ScopedContentProcessor(mock_client, settings_with_defaults) + + async def test_processor_initialization( + self, mock_client: AsyncMock, settings_with_defaults: PurviewSettings + ) -> None: + """Test ScopedContentProcessor initialization.""" + processor = ScopedContentProcessor(mock_client, settings_with_defaults) + + assert processor._client == mock_client + assert processor._settings == settings_with_defaults + + async def test_process_messages_with_defaults(self, processor: ScopedContentProcessor) -> None: + """Test process_messages with settings that have defaults.""" + messages = [ + ChatMessage(role=Role.USER, text="Hello"), + ChatMessage(role=Role.ASSISTANT, text="Hi there"), + ] + + with patch.object(processor, "_map_messages", return_value=([], None)) as mock_map: + should_block, user_id = await processor.process_messages(messages, Activity.UPLOAD_TEXT) + + assert should_block is False + assert user_id is None + mock_map.assert_called_once_with(messages, Activity.UPLOAD_TEXT, None) + + async def test_process_messages_blocks_content( + self, processor: ScopedContentProcessor, process_content_request_factory + ) -> None: + """Test process_messages returns True when content should be blocked.""" + messages = [ChatMessage(role=Role.USER, text="Sensitive content")] + + mock_request = process_content_request_factory("Sensitive content") + + mock_response = ProcessContentResponse(**{ + "policyActions": [DlpActionInfo(action=DlpAction.BLOCK_ACCESS, restrictionAction=RestrictionAction.BLOCK)] + }) + + with ( + patch.object(processor, "_map_messages", return_value=([mock_request], "user-123")), + patch.object(processor, "_process_with_scopes", return_value=mock_response), + ): + should_block, user_id = await 
processor.process_messages(messages, Activity.UPLOAD_TEXT) + + assert should_block is True + assert user_id == "user-123" + + async def test_map_messages_creates_requests( + self, processor: ScopedContentProcessor, mock_client: AsyncMock + ) -> None: + """Test _map_messages creates ProcessContentRequest objects.""" + messages = [ + ChatMessage( + role=Role.USER, + text="Test message", + message_id="msg-123", + author_name="12345678-1234-1234-1234-123456789012", + ), + ] + + requests, user_id = await processor._map_messages(messages, Activity.UPLOAD_TEXT) + + assert len(requests) == 1 + assert requests[0].user_id == "12345678-1234-1234-1234-123456789012" + assert requests[0].tenant_id == "12345678-1234-1234-1234-123456789012" + assert user_id == "12345678-1234-1234-1234-123456789012" + + async def test_map_messages_without_defaults_gets_token_info(self, mock_client: AsyncMock) -> None: + """Test _map_messages gets token info when settings lack some defaults.""" + settings = PurviewSettings(app_name="Test App", tenant_id="12345678-1234-1234-1234-123456789012") + processor = ScopedContentProcessor(mock_client, settings) + messages = [ChatMessage(role=Role.USER, text="Test", message_id="msg-123")] + + requests, user_id = await processor._map_messages(messages, Activity.UPLOAD_TEXT) + + mock_client.get_user_info_from_token.assert_called_once() + assert len(requests) == 1 + assert user_id is not None + + async def test_map_messages_raises_on_missing_tenant_id(self, mock_client: AsyncMock) -> None: + """Test _map_messages raises ValueError when tenant_id cannot be determined.""" + settings = PurviewSettings(app_name="Test App") # No tenant_id + processor = ScopedContentProcessor(mock_client, settings) + + mock_client.get_user_info_from_token = AsyncMock( + return_value={"user_id": "test-user", "client_id": "test-client"} + ) + + messages = [ChatMessage(role=Role.USER, text="Test", message_id="msg-123")] + + with pytest.raises(ValueError, match="Tenant id required"): + 
await processor._map_messages(messages, Activity.UPLOAD_TEXT) + + async def test_check_applicable_scopes_no_scopes( + self, processor: ScopedContentProcessor, process_content_request_factory + ) -> None: + """Test _check_applicable_scopes when no scopes are returned.""" + from agent_framework_purview._models import ProtectionScopesResponse + + request = process_content_request_factory() + response = ProtectionScopesResponse(**{"value": None}) + + should_process, actions = processor._check_applicable_scopes(request, response) + + assert should_process is False + assert actions == [] + + async def test_check_applicable_scopes_with_block_action( + self, processor: ScopedContentProcessor, process_content_request_factory + ) -> None: + """Test _check_applicable_scopes identifies block actions.""" + from agent_framework_purview._models import ( + PolicyLocation, + PolicyScope, + ProtectionScopeActivities, + ProtectionScopesResponse, + ) + + request = process_content_request_factory() + + block_action = DlpActionInfo(action=DlpAction.BLOCK_ACCESS, restrictionAction=RestrictionAction.BLOCK) + scope_location = PolicyLocation(**{ + "@odata.type": "microsoft.graph.policyLocationApplication", + "value": "app-id", + }) + scope = PolicyScope(**{ + "policyActions": [block_action], + "activities": ProtectionScopeActivities.UPLOAD_TEXT, + "locations": [scope_location], + }) + response = ProtectionScopesResponse(**{"value": [scope]}) + + should_process, actions = processor._check_applicable_scopes(request, response) + + assert should_process is True + assert len(actions) == 1 + assert actions[0].action == DlpAction.BLOCK_ACCESS + + async def test_combine_policy_actions(self, processor: ScopedContentProcessor) -> None: + """Test _combine_policy_actions merges action lists.""" + action1 = DlpActionInfo(action=DlpAction.BLOCK_ACCESS, restrictionAction=RestrictionAction.BLOCK) + action2 = DlpActionInfo(action=DlpAction.OTHER, restrictionAction=RestrictionAction.OTHER) + + combined = 
processor._combine_policy_actions([action1], [action2]) + + assert len(combined) == 2 + assert action1 in combined + assert action2 in combined + + async def test_process_with_scopes_calls_client_methods( + self, processor: ScopedContentProcessor, mock_client: AsyncMock, process_content_request_factory + ) -> None: + """Test _process_with_scopes sends content activities instead of process_content when no scopes apply.""" + from agent_framework_purview._models import ( + ContentActivitiesResponse, + ProtectionScopesResponse, + ) + + request = process_content_request_factory() + + mock_client.get_protection_scopes = AsyncMock(return_value=ProtectionScopesResponse(**{"value": []})) + mock_client.process_content = AsyncMock( + return_value=ProcessContentResponse(**{"id": "response-123", "protectionScopeState": "notModified"}) + ) + mock_client.send_content_activities = AsyncMock(return_value=ContentActivitiesResponse(**{"error": None})) + + response = await processor._process_with_scopes(request) + + mock_client.get_protection_scopes.assert_called_once() + mock_client.process_content.assert_not_called() + mock_client.send_content_activities.assert_called_once() + assert response.id is None + + async def test_map_messages_with_user_id_in_additional_properties(self, mock_client: AsyncMock) -> None: + """Test user_id extraction from message additional_properties.""" + settings = PurviewSettings( + app_name="Test App", + tenant_id="12345678-1234-1234-1234-123456789012", + purview_app_location=PurviewAppLocation( + location_type=PurviewLocationType.APPLICATION, location_value="app-id" + ), + ) + processor = ScopedContentProcessor(mock_client, settings) + + messages = [ + ChatMessage( + role=Role.USER, + text="Test message", + additional_properties={"user_id": "22345678-1234-1234-1234-123456789012"}, + ), + ] + + requests, user_id = await processor._map_messages(messages, Activity.UPLOAD_TEXT) + + assert len(requests) == 1 + assert user_id == "22345678-1234-1234-1234-123456789012" + assert
requests[0].user_id == "22345678-1234-1234-1234-123456789012" + + async def test_map_messages_with_provided_user_id_fallback(self, mock_client: AsyncMock) -> None: + """Test using provided_user_id when no other source is available.""" + settings = PurviewSettings( + app_name="Test App", + tenant_id="12345678-1234-1234-1234-123456789012", + purview_app_location=PurviewAppLocation( + location_type=PurviewLocationType.APPLICATION, location_value="app-id" + ), + ) + processor = ScopedContentProcessor(mock_client, settings) + + messages = [ChatMessage(role=Role.USER, text="Test message")] + + requests, user_id = await processor._map_messages( + messages, Activity.UPLOAD_TEXT, provided_user_id="32345678-1234-1234-1234-123456789012" + ) + + assert len(requests) == 1 + assert user_id == "32345678-1234-1234-1234-123456789012" + assert requests[0].user_id == "32345678-1234-1234-1234-123456789012" + + async def test_map_messages_returns_empty_when_no_user_id(self, mock_client: AsyncMock) -> None: + """Test that empty results are returned when user_id cannot be resolved.""" + settings = PurviewSettings( + app_name="Test App", + tenant_id="12345678-1234-1234-1234-123456789012", + purview_app_location=PurviewAppLocation( + location_type=PurviewLocationType.APPLICATION, location_value="app-id" + ), + ) + processor = ScopedContentProcessor(mock_client, settings) + + messages = [ChatMessage(role=Role.USER, text="Test message")] + + requests, user_id = await processor._map_messages(messages, Activity.UPLOAD_TEXT) + + assert len(requests) == 0 + assert user_id is None + + async def test_process_content_sends_activities_when_not_applicable( + self, mock_client: AsyncMock, process_content_request_factory + ) -> None: + """Test that content activities are sent when scopes don't apply.""" + settings = PurviewSettings( + app_name="Test App", + tenant_id="12345678-1234-1234-1234-123456789012", + purview_app_location=PurviewAppLocation( + location_type=PurviewLocationType.APPLICATION, 
location_value="app-id" + ), + ) + processor = ScopedContentProcessor(mock_client, settings) + + pc_request = process_content_request_factory() + + # Mock get_protection_scopes to return no applicable scopes + mock_ps_response = MagicMock() + mock_ps_response.scopes = [] + mock_client.get_protection_scopes.return_value = mock_ps_response + + # Mock send_content_activities to return success + mock_ca_response = MagicMock() + mock_ca_response.error = None + mock_client.send_content_activities.return_value = mock_ca_response + + response = await processor._process_with_scopes(pc_request) + + mock_client.get_protection_scopes.assert_called_once() + mock_client.process_content.assert_not_called() + mock_client.send_content_activities.assert_called_once() + # When content activities succeed, response has no errors (processing_errors can be None or empty) + assert response.processing_errors is None or response.processing_errors == [] + + async def test_process_content_handles_activities_error( + self, mock_client: AsyncMock, process_content_request_factory + ) -> None: + """Test error handling when content activities fail.""" + settings = PurviewSettings( + app_name="Test App", + tenant_id="12345678-1234-1234-1234-123456789012", + purview_app_location=PurviewAppLocation( + location_type=PurviewLocationType.APPLICATION, location_value="app-id" + ), + ) + processor = ScopedContentProcessor(mock_client, settings) + + pc_request = process_content_request_factory() + + # Mock get_protection_scopes to return no applicable scopes + mock_ps_response = MagicMock() + mock_ps_response.scopes = [] + mock_client.get_protection_scopes.return_value = mock_ps_response + + # Mock send_content_activities to return error + mock_ca_response = MagicMock() + mock_ca_response.error = "Test error message" + mock_client.send_content_activities.return_value = mock_ca_response + + response = await processor._process_with_scopes(pc_request) + + assert len(response.processing_errors) == 1 + assert 
response.processing_errors[0].message == "Test error message" diff --git a/python/packages/purview/tests/test_settings.py b/python/packages/purview/tests/test_settings.py new file mode 100644 index 0000000000..dac8d68d63 --- /dev/null +++ b/python/packages/purview/tests/test_settings.py @@ -0,0 +1,85 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests for Purview settings.""" + +import pytest + +from agent_framework_purview import PurviewAppLocation, PurviewLocationType, PurviewSettings + + +class TestPurviewSettings: + """Test PurviewSettings configuration.""" + + def test_settings_defaults(self) -> None: + """Test PurviewSettings with default values.""" + settings = PurviewSettings(app_name="Test App") + + assert settings.app_name == "Test App" + assert settings.graph_base_uri == "https://graph.microsoft.com/v1.0/" + assert settings.tenant_id is None + assert settings.purview_app_location is None + assert settings.process_inline is False + + def test_settings_with_custom_values(self) -> None: + """Test PurviewSettings with custom values.""" + app_location = PurviewAppLocation(location_type=PurviewLocationType.APPLICATION, location_value="app-123") + + settings = PurviewSettings( + app_name="Test App", + graph_base_uri="https://graph.microsoft-ppe.com", + tenant_id="test-tenant-id", + process_inline=True, + purview_app_location=app_location, + ) + + assert settings.graph_base_uri == "https://graph.microsoft-ppe.com" + assert settings.tenant_id == "test-tenant-id" + assert settings.process_inline is True + assert settings.purview_app_location.location_value == "app-123" + + @pytest.mark.parametrize( + "graph_uri,expected_scope", + [ + ("https://graph.microsoft.com/v1.0/", "https://graph.microsoft.com/.default"), + ("https://graph.microsoft-ppe.com/v1.0/", "https://graph.microsoft-ppe.com/.default"), + ], + ) + def test_get_scopes(self, graph_uri: str, expected_scope: str) -> None: + """Test get_scopes returns correct scope for different URIs.""" + 
settings = PurviewSettings(app_name="Test App", graph_base_uri=graph_uri) + scopes = settings.get_scopes() + + assert len(scopes) == 1 + assert expected_scope in scopes + + +class TestPurviewAppLocation: + """Test PurviewAppLocation configuration.""" + + @pytest.mark.parametrize( + "location_type,location_value,expected_odata_type", + [ + (PurviewLocationType.APPLICATION, "app-123", "microsoft.graph.policyLocationApplication"), + (PurviewLocationType.URI, "https://example.com", "microsoft.graph.policyLocationUrl"), + (PurviewLocationType.DOMAIN, "example.com", "microsoft.graph.policyLocationDomain"), + ], + ) + def test_get_policy_location( + self, location_type: PurviewLocationType, location_value: str, expected_odata_type: str + ) -> None: + """Test get_policy_location returns correct structure for all location types.""" + location = PurviewAppLocation(location_type=location_type, location_value=location_value) + policy_location = location.get_policy_location() + + assert policy_location["@odata.type"] == expected_odata_type + assert policy_location["value"] == location_value + + +class TestPurviewLocationType: + """Test PurviewLocationType enum.""" + + def test_location_type_values(self) -> None: + """Test PurviewLocationType enum has expected values.""" + assert PurviewLocationType.APPLICATION == "application" + assert PurviewLocationType.URI == "uri" + assert PurviewLocationType.DOMAIN == "domain" diff --git a/python/samples/getting_started/purview_agent/README.md b/python/samples/getting_started/purview_agent/README.md new file mode 100644 index 0000000000..59b2ee918b --- /dev/null +++ b/python/samples/getting_started/purview_agent/README.md @@ -0,0 +1,88 @@ +## Purview Policy Enforcement Sample (Python) + +This getting-started sample shows how to attach Microsoft Purview policy evaluation to an Agent Framework `ChatAgent` using the **middleware** approach. + +1. Configure an Azure OpenAI chat client +2. 
Add Purview policy enforcement middleware (`PurviewPolicyMiddleware`) +3. Run a short conversation and observe prompt / response blocking behavior + +--- +## 1. Setup +### Required Environment Variables + +| Variable | Required | Purpose | +|----------|----------|---------| +| `AZURE_OPENAI_ENDPOINT` | Yes | Azure OpenAI endpoint (https://<your-openai-instance>.openai.azure.com) | +| `AZURE_OPENAI_DEPLOYMENT_NAME` | Optional | Model deployment name (the sample defaults to `gpt-4o-mini` if omitted) | +| `PURVIEW_CLIENT_APP_ID` | Yes* | Client (application) ID used for Purview authentication | +| `PURVIEW_USE_CERT_AUTH` | Optional (`true`/`false`) | Switch between certificate and interactive auth | +| `PURVIEW_TENANT_ID` | Yes (when cert auth on) | Tenant ID for certificate authentication | +| `PURVIEW_CERT_PATH` | Yes (when cert auth on) | Path to your .pfx certificate | +| `PURVIEW_CERT_PASSWORD` | Optional | Password for encrypted certs | + +*The sample has no built-in default for this value—you must set your own. + +## 2. Auth Modes Supported + +#### A. Interactive Browser Authentication (default) +Opens a browser on first run to sign in. + +```powershell +$env:AZURE_OPENAI_ENDPOINT = "https://your-openai-instance.openai.azure.com" +$env:PURVIEW_CLIENT_APP_ID = "00000000-0000-0000-0000-000000000000" +``` + +#### B. Certificate Authentication +For headless / CI scenarios. + +```powershell +$env:PURVIEW_USE_CERT_AUTH = "true" +$env:PURVIEW_TENANT_ID = "<your-tenant-id>" +$env:PURVIEW_CERT_PATH = "C:\path\to\cert.pfx" +$env:PURVIEW_CERT_PASSWORD = "optional-password" +``` + +Certificate steps (summary): create / register app, generate certificate, upload public key, export .pfx with private key, grant required Graph / Purview permissions. + +--- + +## 3. Run the Sample + +From repo root: + +```powershell +cd python/samples/getting_started/purview_agent +python sample_purview_agent.py +``` + +If interactive auth is used, a browser window will appear the first time. + +--- + +## 4. How It Works + +1.
Builds an Azure OpenAI chat client (using the environment endpoint / deployment) +2. Chooses credential mode (certificate vs interactive) +3. Creates `PurviewPolicyMiddleware` with `PurviewSettings` +4. Injects middleware into the agent at construction +5. Sends two user messages sequentially +6. Prints results (or policy block messages) + +Prompt blocks set a system-level message: `Prompt blocked by policy` and terminate the run early. Response blocks rewrite the output to `Response blocked by policy`. + +--- + +## 5. Code Snippet (Middleware Injection) + +```python +agent = ChatAgent( + chat_client=chat_client, + instructions="You are good at telling jokes.", + name="Joker", + middleware=[ + PurviewPolicyMiddleware(credential, PurviewSettings(app_name="Sample App", default_user_id="<user-guid>")) + ], +) +``` + +--- diff --git a/python/samples/getting_started/purview_agent/sample_purview_agent.py b/python/samples/getting_started/purview_agent/sample_purview_agent.py new file mode 100644 index 0000000000..00e596c32f --- /dev/null +++ b/python/samples/getting_started/purview_agent/sample_purview_agent.py @@ -0,0 +1,179 @@ +# Copyright (c) Microsoft. All rights reserved. +"""Purview policy enforcement sample (Python). + +Shows: +1. Creating a basic chat agent +2. Adding Purview policy evaluation via AGENT middleware (agent-level) +3. Adding Purview policy evaluation via CHAT middleware (chat-client level) +4.
Running two prompts through each middleware path and printing results + +Environment variables: +- AZURE_OPENAI_ENDPOINT (required) +- AZURE_OPENAI_DEPLOYMENT_NAME (optional, defaults to gpt-4o-mini) +- PURVIEW_CLIENT_APP_ID (required) +- PURVIEW_USE_CERT_AUTH (optional, set to "true" for certificate auth) +- PURVIEW_TENANT_ID (required if certificate auth) +- PURVIEW_CERT_PATH (required if certificate auth) +- PURVIEW_CERT_PASSWORD (optional) +- PURVIEW_DEFAULT_USER_ID (optional, user ID for Purview evaluation) +""" +from __future__ import annotations + +import asyncio +import os +from typing import Any + +from agent_framework import AgentRunResponse, ChatAgent, ChatMessage, Role +from agent_framework.azure import AzureOpenAIChatClient +from azure.identity import ( + AzureCliCredential, + CertificateCredential, + InteractiveBrowserCredential, +) + +# Purview integration pieces +from agent_framework.microsoft import ( + PurviewPolicyMiddleware, + PurviewChatPolicyMiddleware, + PurviewSettings, +) + +JOKER_NAME = "Joker" +JOKER_INSTRUCTIONS = "You are good at telling jokes. Keep responses concise." + + +def _get_env(name: str, *, required: bool = True, default: str | None = None) -> str: + val = os.environ.get(name, default) + if required and not val: + raise RuntimeError(f"Environment variable {name} is required") + return val # type: ignore[return-value] + + +def build_credential() -> Any: + """Select an Azure credential for Purview authentication. + + Supported modes: + 1. CertificateCredential (if PURVIEW_USE_CERT_AUTH=true) + 2. InteractiveBrowserCredential (otherwise; PURVIEW_CLIENT_APP_ID is required in both modes) + """ + client_id = _get_env("PURVIEW_CLIENT_APP_ID", required=True) + use_cert_auth = _get_env("PURVIEW_USE_CERT_AUTH", required=False, default="false").lower() == "true" + + if not client_id: + raise RuntimeError( + "PURVIEW_CLIENT_APP_ID is required for interactive browser authentication; " + "set PURVIEW_USE_CERT_AUTH=true for certificate mode instead."
+ ) + + if use_cert_auth: + tenant_id = _get_env("PURVIEW_TENANT_ID") + cert_path = _get_env("PURVIEW_CERT_PATH") + cert_password = _get_env("PURVIEW_CERT_PASSWORD", required=False, default=None) + print(f"Using Certificate Authentication (tenant: {tenant_id}, cert: {cert_path})") + return CertificateCredential( + tenant_id=tenant_id, + client_id=client_id, + certificate_path=cert_path, + password=cert_password, + ) + + print(f"Using Interactive Browser Authentication (client_id: {client_id})") + return InteractiveBrowserCredential(client_id=client_id) + + +async def run_with_agent_middleware() -> None: + endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT") + if not endpoint: + print("Skipping run: AZURE_OPENAI_ENDPOINT not set") + return + + deployment = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", "gpt-4o-mini") + user_id = os.environ.get("PURVIEW_DEFAULT_USER_ID") + chat_client = AzureOpenAIChatClient(deployment_name=deployment, endpoint=endpoint, credential=AzureCliCredential()) + + purview_agent_middleware = PurviewPolicyMiddleware( + build_credential(), + PurviewSettings( + app_name="Agent Framework Sample App", + ), + ) + + agent = ChatAgent( + chat_client=chat_client, + instructions=JOKER_INSTRUCTIONS, + name=JOKER_NAME, + middleware=purview_agent_middleware, + ) + + print("-- Agent Middleware Path --") + first: AgentRunResponse = await agent.run(ChatMessage(role=Role.USER, text="Tell me a joke about a pirate.", additional_properties={"user_id": user_id})) + print("First response (agent middleware):\n", first) + + second: AgentRunResponse = await agent.run(ChatMessage(role=Role.USER, text="That was funny. 
Tell me another one.", additional_properties={"user_id": user_id})) + print("Second response (agent middleware):\n", second) + + +async def run_with_chat_middleware() -> None: + endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT") + if not endpoint: + print("Skipping chat middleware run: AZURE_OPENAI_ENDPOINT not set") + return + + deployment = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", default="gpt-4o-mini") + user_id = os.environ.get("PURVIEW_DEFAULT_USER_ID") + + chat_client = AzureOpenAIChatClient( + deployment_name=deployment, + endpoint=endpoint, + credential=AzureCliCredential(), + middleware=[ + PurviewChatPolicyMiddleware( + build_credential(), + PurviewSettings( + app_name="Agent Framework Sample App (Chat)", + ), + ) + ], + ) + + agent = ChatAgent( + chat_client=chat_client, + instructions=JOKER_INSTRUCTIONS, + name=JOKER_NAME, + ) + + print("-- Chat Middleware Path --") + first: AgentRunResponse = await agent.run( + ChatMessage( + role=Role.USER, + text="Give me a short clean joke.", + additional_properties={"user_id": user_id}, + ) + ) + print("First response (chat middleware):\n", first) + + second: AgentRunResponse = await agent.run( + ChatMessage( + role=Role.USER, + text="One more please.", + additional_properties={"user_id": user_id}, + ) + ) + print("Second response (chat middleware):\n", second) + + +async def main() -> None: + print("== Purview Agent Sample (Agent & Chat Middleware) ==") + try: + await run_with_agent_middleware() + except Exception as ex: # pragma: no cover - demo resilience + print(f"Agent middleware path failed: {ex}") + + try: + await run_with_chat_middleware() + except Exception as ex: # pragma: no cover - demo resilience + print(f"Chat middleware path failed: {ex}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/uv.lock b/python/uv.lock index bfb8c600da..5ef7f01768 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -31,6 +31,7 @@ members = [ "agent-framework-devui", "agent-framework-lab", 
"agent-framework-mem0", + "agent-framework-purview", "agent-framework-redis", ] @@ -383,6 +384,23 @@ requires-dist = [ { name = "mem0ai", specifier = ">=0.1.117" }, ] +[[package]] +name = "agent-framework-purview" +version = "0.1.0b1" +source = { editable = "packages/purview" } +dependencies = [ + { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "azure-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, +] + +[package.metadata] +requires-dist = [ + { name = "agent-framework-core", editable = "packages/core" }, + { name = "azure-core", specifier = ">=1.30.0" }, + { name = "httpx", specifier = ">=0.27.0" }, +] + [[package]] name = "agent-framework-redis" version = "1.0.0b251007"