mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add allowed_checkpoint_types support to CosmosCheckpointStorage for parity with FileCheckpointStorage (#5202)
* Python: Add allowed_checkpoint_types support to CosmosCheckpointStorage (#5200) Add allowed_checkpoint_types parameter to CosmosCheckpointStorage for parity with FileCheckpointStorage. This ensures both providers use the same restricted pickle deserialization by default. Changes: - Accept allowed_checkpoint_types kwarg in __init__, stored as frozenset - Convert _document_to_checkpoint from @staticmethod to instance method - Forward allowed_types to decode_checkpoint_value on all load paths - Update class docstring to describe the new parameter - Add tests covering built-in safe types, app type opt-in/blocking, and all load paths (load, list_checkpoints, get_latest) - Add changelog entry noting the breaking behavior change BREAKING CHANGE: CosmosCheckpointStorage now uses restricted pickle deserialization by default. Checkpoints containing application-defined types will require passing those types via allowed_checkpoint_types. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Python: Add `allowed_checkpoint_types` support to `CosmosCheckpointStorage` for parity with `FileCheckpointStorage` Fixes #5200 * Address PR review: add pickle security warning and fix docstring examples - Reintroduce explicit security warning about pickle deserialization risks - Convert Example:: block to .. code-block:: python with imports for consistency with other docstring examples - Note: PR title should be updated to include [BREAKING] prefix per changelog convention (comment #3, requires GitHub UI change) Fixes #5200 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <copilot@github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
b89adb280b
commit
1b95e8585d
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Changed
|
||||
- **agent-framework-azure-cosmos**: [BREAKING] `CosmosCheckpointStorage` now uses restricted pickle deserialization by default, matching `FileCheckpointStorage` behavior. If your checkpoints contain application-defined types, pass them via `allowed_checkpoint_types=["my_app.models:MyState"]`. ([#5200](https://github.com/microsoft/agent-framework/issues/5200))
|
||||
|
||||
## [1.0.1] - 2026-04-09
|
||||
|
||||
### Added
|
||||
|
||||
@@ -43,9 +43,34 @@ class CosmosCheckpointStorage:
|
||||
``FileCheckpointStorage``, allowing full Python object fidelity for
|
||||
complex workflow state while keeping the document structure human-readable.
|
||||
|
||||
SECURITY WARNING: Checkpoints use pickle for data serialization. Only load
|
||||
checkpoints from trusted sources. Loading a malicious checkpoint can execute
|
||||
arbitrary code.
|
||||
Security warning: checkpoints use pickle for non-JSON-native values. Loading
|
||||
checkpoints from untrusted sources is unsafe and can execute arbitrary code
|
||||
during deserialization. The built-in deserialization restrictions reduce risk,
|
||||
but they do not make untrusted checkpoints safe to load. Extending
|
||||
``allowed_checkpoint_types`` may further increase risk and should only be done
|
||||
for trusted application types.
|
||||
|
||||
By default, checkpoint deserialization is restricted to a built-in set of safe
|
||||
Python types (primitives, datetime, uuid, ...) and all ``agent_framework``
|
||||
internal types. To allow additional application-specific types, pass them via
|
||||
the ``allowed_checkpoint_types`` parameter using ``"module:qualname"`` format.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from azure.identity.aio import DefaultAzureCredential
|
||||
from agent_framework_azure_cosmos import CosmosCheckpointStorage
|
||||
|
||||
storage = CosmosCheckpointStorage(
|
||||
endpoint="https://my-account.documents.azure.com:443/",
|
||||
credential=DefaultAzureCredential(),
|
||||
database_name="agent-db",
|
||||
container_name="checkpoints",
|
||||
allowed_checkpoint_types=[
|
||||
"my_app.models:MyState",
|
||||
],
|
||||
)
|
||||
|
||||
The database and container are created automatically on first use
|
||||
if they do not already exist. The container uses partition key
|
||||
@@ -97,6 +122,7 @@ class CosmosCheckpointStorage:
|
||||
container_client: ContainerProxy | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
allowed_checkpoint_types: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the Azure Cosmos DB checkpoint storage.
|
||||
|
||||
@@ -129,10 +155,15 @@ class CosmosCheckpointStorage:
|
||||
container_client: Pre-created Cosmos container client.
|
||||
env_file_path: Path to environment file for loading settings.
|
||||
env_file_encoding: Encoding of the environment file.
|
||||
allowed_checkpoint_types: Additional types (beyond the built-in safe set
|
||||
and framework types) that are permitted during checkpoint
|
||||
deserialization. Each entry should be a ``"module:qualname"``
|
||||
string (e.g., ``"my_app.models:MyState"``).
|
||||
"""
|
||||
self._cosmos_client: CosmosClient | None = cosmos_client
|
||||
self._container_proxy: ContainerProxy | None = container_client
|
||||
self._owns_client = False
|
||||
self._allowed_types: frozenset[str] = frozenset(allowed_checkpoint_types or [])
|
||||
|
||||
if self._container_proxy is not None:
|
||||
self.database_name: str = database_name or ""
|
||||
@@ -401,8 +432,7 @@ class CosmosCheckpointStorage:
|
||||
partition_key=PartitionKey(path="/workflow_name"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _document_to_checkpoint(document: dict[str, Any]) -> WorkflowCheckpoint:
|
||||
def _document_to_checkpoint(self, document: dict[str, Any]) -> WorkflowCheckpoint:
|
||||
"""Convert a Cosmos DB document back to a WorkflowCheckpoint.
|
||||
|
||||
Strips Cosmos DB system properties (``_rid``, ``_self``, ``_etag``,
|
||||
@@ -413,7 +443,7 @@ class CosmosCheckpointStorage:
|
||||
cosmos_keys = {"id", "_rid", "_self", "_etag", "_attachments", "_ts"}
|
||||
cleaned = {k: v for k, v in document.items() if k not in cosmos_keys}
|
||||
|
||||
decoded = decode_checkpoint_value(cleaned)
|
||||
decoded = decode_checkpoint_value(cleaned, allowed_types=self._allowed_types)
|
||||
return WorkflowCheckpoint.from_dict(decoded)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -6,6 +6,7 @@ import os
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
@@ -595,3 +596,142 @@ async def test_cosmos_checkpoint_storage_roundtrip_with_emulator() -> None:
|
||||
finally:
|
||||
with suppress(Exception):
|
||||
await cosmos_client.delete_database(database_name)
|
||||
|
||||
|
||||
# --- Tests for allowed_checkpoint_types ---
|
||||
|
||||
|
||||
@dataclass
|
||||
class _AppState:
|
||||
"""Application-defined state type used to test allowed_checkpoint_types."""
|
||||
|
||||
label: str
|
||||
count: int
|
||||
|
||||
|
||||
# "module:qualname" key identifying _AppState in allowed_checkpoint_types.
_APP_STATE_TYPE_KEY = ":".join((_AppState.__module__, _AppState.__qualname__))
|
||||
|
||||
|
||||
def _make_checkpoint_with_state(state: dict[str, Any]) -> WorkflowCheckpoint:
    """Build a WorkflowCheckpoint carrying ``state`` for serialization tests.

    Every other field is fixed, so tests differ only in the state payload.
    """
    fixed_fields: dict[str, Any] = {
        "workflow_name": "test-workflow",
        "graph_signature_hash": "abc123",
        "timestamp": "2025-01-01T00:00:00+00:00",
        "iteration_count": 1,
    }
    return WorkflowCheckpoint(state=state, **fixed_fields)
|
||||
|
||||
|
||||
async def test_init_accepts_allowed_checkpoint_types(mock_container: MagicMock) -> None:
    """CosmosCheckpointStorage.__init__ accepts and retains allowed_checkpoint_types."""
    storage = CosmosCheckpointStorage(
        container_client=mock_container,
        allowed_checkpoint_types=["some.module:SomeType"],
    )

    # A bare ``storage is not None`` check can never fail; verify instead that
    # the allow-list was actually stored (as a frozenset) for use at load time.
    assert storage._allowed_types == frozenset({"some.module:SomeType"})
|
||||
|
||||
|
||||
async def test_load_allows_builtin_safe_types(mock_container: MagicMock) -> None:
    """Built-in safe types load without opt-in via allowed_checkpoint_types."""
    from datetime import datetime, timezone

    expected_ts = datetime(2025, 1, 1, tzinfo=timezone.utc)
    expected_tags = {1, 2, 3}
    checkpoint = _make_checkpoint_with_state({"ts": expected_ts, "tags": set(expected_tags)})
    document = _checkpoint_to_cosmos_document(checkpoint)
    mock_container.query_items.return_value = _to_async_iter([document])

    # No allowed_checkpoint_types supplied: only the default restrictions apply.
    storage = CosmosCheckpointStorage(container_client=mock_container)
    loaded = await storage.load(checkpoint.checkpoint_id)

    assert loaded.state["ts"] == expected_ts
    assert loaded.state["tags"] == expected_tags
|
||||
|
||||
|
||||
async def test_load_blocks_unlisted_app_type(mock_container: MagicMock) -> None:
    """Application types are blocked when not listed in allowed_checkpoint_types."""
    app_state = _AppState(label="x", count=1)
    checkpoint = _make_checkpoint_with_state({"data": app_state})
    document = _checkpoint_to_cosmos_document(checkpoint)
    mock_container.query_items.return_value = _to_async_iter([document])

    # Constructed without an allow-list, so _AppState is not permitted.
    storage = CosmosCheckpointStorage(container_client=mock_container)

    with pytest.raises(WorkflowCheckpointException, match="deserialization blocked"):
        await storage.load(checkpoint.checkpoint_id)
|
||||
|
||||
|
||||
async def test_load_allows_listed_app_type(mock_container: MagicMock) -> None:
    """Application types are allowed when listed in allowed_checkpoint_types."""
    app_state = _AppState(label="ok", count=7)
    checkpoint = _make_checkpoint_with_state({"data": app_state})
    document = _checkpoint_to_cosmos_document(checkpoint)
    mock_container.query_items.return_value = _to_async_iter([document])

    # Opt the application type in explicitly via its "module:qualname" key.
    storage = CosmosCheckpointStorage(
        container_client=mock_container,
        allowed_checkpoint_types=[_APP_STATE_TYPE_KEY],
    )
    loaded = await storage.load(checkpoint.checkpoint_id)

    restored = loaded.state["data"]
    assert isinstance(restored, _AppState)
    assert restored.label == "ok"
    assert restored.count == 7
|
||||
|
||||
|
||||
async def test_list_checkpoints_blocks_unlisted_app_type(mock_container: MagicMock) -> None:
    """list_checkpoints skips documents with unlisted application types."""
    app_state = _AppState(label="x", count=1)
    checkpoint = _make_checkpoint_with_state({"data": app_state})
    document = _checkpoint_to_cosmos_document(checkpoint)
    mock_container.query_items.return_value = _to_async_iter([document])

    storage = CosmosCheckpointStorage(container_client=mock_container)
    results = await storage.list_checkpoints(workflow_name="test-workflow")

    # The document is skipped (logged as warning) because the type is blocked
    assert not results
|
||||
|
||||
|
||||
async def test_list_checkpoints_allows_listed_app_type(mock_container: MagicMock) -> None:
    """list_checkpoints decodes documents with listed application types."""
    app_state = _AppState(label="ok", count=3)
    checkpoint = _make_checkpoint_with_state({"data": app_state})
    document = _checkpoint_to_cosmos_document(checkpoint)
    mock_container.query_items.return_value = _to_async_iter([document])

    # Opt the application type in explicitly via its "module:qualname" key.
    storage = CosmosCheckpointStorage(
        container_client=mock_container,
        allowed_checkpoint_types=[_APP_STATE_TYPE_KEY],
    )
    results = await storage.list_checkpoints(workflow_name="test-workflow")

    assert len(results) == 1
    (only_result,) = results
    assert isinstance(only_result.state["data"], _AppState)
|
||||
|
||||
|
||||
async def test_get_latest_blocks_unlisted_app_type(mock_container: MagicMock) -> None:
    """get_latest raises when the checkpoint contains an unlisted application type."""
    app_state = _AppState(label="x", count=1)
    checkpoint = _make_checkpoint_with_state({"data": app_state})
    document = _checkpoint_to_cosmos_document(checkpoint)
    mock_container.query_items.return_value = _to_async_iter([document])

    # Constructed without an allow-list, so _AppState is not permitted.
    storage = CosmosCheckpointStorage(container_client=mock_container)

    with pytest.raises(WorkflowCheckpointException, match="deserialization blocked"):
        await storage.get_latest(workflow_name="test-workflow")
|
||||
|
||||
|
||||
async def test_get_latest_allows_listed_app_type(mock_container: MagicMock) -> None:
    """get_latest decodes checkpoints with listed application types."""
    app_state = _AppState(label="latest", count=9)
    checkpoint = _make_checkpoint_with_state({"data": app_state})
    document = _checkpoint_to_cosmos_document(checkpoint)
    mock_container.query_items.return_value = _to_async_iter([document])

    # Opt the application type in explicitly via its "module:qualname" key.
    storage = CosmosCheckpointStorage(
        container_client=mock_container,
        allowed_checkpoint_types=[_APP_STATE_TYPE_KEY],
    )
    result = await storage.get_latest(workflow_name="test-workflow")

    assert result is not None
    restored = result.state["data"]
    assert isinstance(restored, _AppState)
    assert restored.label == "latest"
|
||||
|
||||
Reference in New Issue
Block a user