Python: Add allowed_checkpoint_types support to CosmosCheckpointStorage for parity with FileCheckpointStorage (#5202)

* Python: Add allowed_checkpoint_types support to CosmosCheckpointStorage (#5200)

Add allowed_checkpoint_types parameter to CosmosCheckpointStorage for
parity with FileCheckpointStorage. This ensures both providers use the
same restricted pickle deserialization by default.

Changes:
- Accept allowed_checkpoint_types kwarg in __init__, stored as frozenset
- Convert _document_to_checkpoint from @staticmethod to instance method
- Forward allowed_types to decode_checkpoint_value on all load paths
- Update class docstring to describe the new parameter
- Add tests covering built-in safe types, app type opt-in/blocking,
  and all load paths (load, list_checkpoints, get_latest)
- Add changelog entry noting the breaking behavior change
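
For readers unfamiliar with the mechanism, the allow-list behavior listed above can be sketched with the standard library alone. This is an illustrative sketch, not the agent_framework implementation: `RestrictedUnpickler`, `restricted_loads`, `is_blocked`, and the contents of `SAFE_TYPES` are hypothetical names, and `fractions.Fraction` stands in for an application-defined type.

```python
import io
import pickle
from datetime import datetime
from fractions import Fraction

# Hypothetical stand-in for the built-in safe set (assumption, not the real list).
SAFE_TYPES = frozenset({"builtins:set", "builtins:frozenset", "datetime:datetime"})


class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that resolves only globals named in a "module:qualname" allow-list."""

    def __init__(self, data: bytes, allowed_types: frozenset[str]) -> None:
        super().__init__(io.BytesIO(data))
        self._allowed = SAFE_TYPES | allowed_types

    def find_class(self, module: str, name: str):
        # pickle calls find_class for every global it needs to import.
        key = f"{module}:{name}"
        if key not in self._allowed:
            raise pickle.UnpicklingError(f"deserialization blocked for {key}")
        return super().find_class(module, name)


def restricted_loads(data: bytes, allowed_types: frozenset[str] = frozenset()):
    """Deserialize data, permitting only safe types plus allowed_types."""
    return RestrictedUnpickler(data, allowed_types).load()


def is_blocked(data: bytes, allowed_types: frozenset[str] = frozenset()) -> bool:
    """Return True if restricted_loads rejects the payload."""
    try:
        restricted_loads(data, allowed_types)
    except pickle.UnpicklingError:
        return True
    return False


# Fraction plays the role of an application type that must be opted in.
app_blob = pickle.dumps(Fraction(1, 3))
safe_blob = pickle.dumps(datetime(2025, 1, 1))
```

With these definitions, `is_blocked(app_blob)` is true until `"fractions:Fraction"` is passed in `allowed_types`, while `safe_blob` loads without any opt-in. An allow-list narrows what pickle may import; it does not make untrusted input safe to load.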

BREAKING CHANGE: CosmosCheckpointStorage now uses restricted pickle
deserialization by default. Checkpoints containing application-defined
types will require passing those types via allowed_checkpoint_types.
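
When migrating, the `"module:qualname"` entries can be derived from the classes themselves rather than typed by hand. The sketch below mirrors the `_APP_STATE_TYPE_KEY` expression used in this commit's tests; `checkpoint_type_key`, `Outer`, and `Inner` are hypothetical names for illustration, not part of the framework API.

```python
from fractions import Fraction


def checkpoint_type_key(cls: type) -> str:
    """Build the "module:qualname" string for a class."""
    return f"{cls.__module__}:{cls.__qualname__}"


class Outer:
    class Inner:
        """Nested class, to show that the qualname keeps the dotted path."""


# A stdlib class gives a predictable key: "fractions:Fraction".
stdlib_key = checkpoint_type_key(Fraction)

# For a nested class the key ends with ":Outer.Inner"; the module part
# depends on where the class is defined.
nested_key = checkpoint_type_key(Outer.Inner)

# The resulting list is what would be passed as allowed_checkpoint_types.
allowed_checkpoint_types = [stdlib_key, nested_key]
```

Deriving the keys this way keeps the allow-list in sync if a class is later moved or renamed, although checkpoints already written under the old name would still need the old key.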

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Python: Add `allowed_checkpoint_types` support to `CosmosCheckpointStorage` for parity with `FileCheckpointStorage`

Fixes #5200

* Address PR review: add pickle security warning and fix docstring examples

- Reintroduce explicit security warning about pickle deserialization risks
- Convert Example:: block to .. code-block:: python with imports for
  consistency with other docstring examples
- Note: PR title should be updated to include [BREAKING] prefix per
  changelog convention (comment #3, requires GitHub UI change)

Fixes #5200

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Evan Mattson
2026-04-14 11:20:55 +09:00
committed by GitHub
parent b89adb280b
commit 1b95e8585d
3 changed files with 179 additions and 6 deletions
+3
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [Unreleased]

+### Changed
+- **agent-framework-azure-cosmos**: [BREAKING] `CosmosCheckpointStorage` now uses restricted pickle deserialization by default, matching `FileCheckpointStorage` behavior. If your checkpoints contain application-defined types, pass them via `allowed_checkpoint_types=["my_app.models:MyState"]`. ([#5200](https://github.com/microsoft/agent-framework/issues/5200))
+
 ## [1.0.1] - 2026-04-09

 ### Added
@@ -43,9 +43,34 @@ class CosmosCheckpointStorage:
     ``FileCheckpointStorage``, allowing full Python object fidelity for
     complex workflow state while keeping the document structure human-readable.

-    SECURITY WARNING: Checkpoints use pickle for data serialization. Only load
-    checkpoints from trusted sources. Loading a malicious checkpoint can execute
-    arbitrary code.
+    Security warning: checkpoints use pickle for non-JSON-native values. Loading
+    checkpoints from untrusted sources is unsafe and can execute arbitrary code
+    during deserialization. The built-in deserialization restrictions reduce risk,
+    but they do not make untrusted checkpoints safe to load. Extending
+    ``allowed_checkpoint_types`` may further increase risk and should only be done
+    for trusted application types.
+
+    By default, checkpoint deserialization is restricted to a built-in set of safe
+    Python types (primitives, datetime, uuid, ...) and all ``agent_framework``
+    internal types. To allow additional application-specific types, pass them via
+    the ``allowed_checkpoint_types`` parameter using ``"module:qualname"`` format.
+
+    Example:
+
+    .. code-block:: python
+
+        from azure.identity.aio import DefaultAzureCredential
+        from agent_framework_azure_cosmos import CosmosCheckpointStorage
+
+        storage = CosmosCheckpointStorage(
+            endpoint="https://my-account.documents.azure.com:443/",
+            credential=DefaultAzureCredential(),
+            database_name="agent-db",
+            container_name="checkpoints",
+            allowed_checkpoint_types=[
+                "my_app.models:MyState",
+            ],
+        )

     The database and container are created automatically on first use
     if they do not already exist. The container uses partition key
@@ -97,6 +122,7 @@ class CosmosCheckpointStorage:
         container_client: ContainerProxy | None = None,
         env_file_path: str | None = None,
         env_file_encoding: str | None = None,
+        allowed_checkpoint_types: list[str] | None = None,
     ) -> None:
         """Initialize the Azure Cosmos DB checkpoint storage.
@@ -129,10 +155,15 @@ class CosmosCheckpointStorage:
             container_client: Pre-created Cosmos container client.
             env_file_path: Path to environment file for loading settings.
             env_file_encoding: Encoding of the environment file.
+            allowed_checkpoint_types: Additional types (beyond the built-in safe set
+                and framework types) that are permitted during checkpoint
+                deserialization. Each entry should be a ``"module:qualname"``
+                string (e.g., ``"my_app.models:MyState"``).
         """
         self._cosmos_client: CosmosClient | None = cosmos_client
         self._container_proxy: ContainerProxy | None = container_client
         self._owns_client = False
+        self._allowed_types: frozenset[str] = frozenset(allowed_checkpoint_types or [])

         if self._container_proxy is not None:
             self.database_name: str = database_name or ""
@@ -401,8 +432,7 @@ class CosmosCheckpointStorage:
             partition_key=PartitionKey(path="/workflow_name"),
         )

-    @staticmethod
-    def _document_to_checkpoint(document: dict[str, Any]) -> WorkflowCheckpoint:
+    def _document_to_checkpoint(self, document: dict[str, Any]) -> WorkflowCheckpoint:
         """Convert a Cosmos DB document back to a WorkflowCheckpoint.

         Strips Cosmos DB system properties (``_rid``, ``_self``, ``_etag``,
@@ -413,7 +443,7 @@ class CosmosCheckpointStorage:
         cosmos_keys = {"id", "_rid", "_self", "_etag", "_attachments", "_ts"}
         cleaned = {k: v for k, v in document.items() if k not in cosmos_keys}
-        decoded = decode_checkpoint_value(cleaned)
+        decoded = decode_checkpoint_value(cleaned, allowed_types=self._allowed_types)
         return WorkflowCheckpoint.from_dict(decoded)

     @staticmethod
@@ -6,6 +6,7 @@ import os
 import uuid
 from collections.abc import AsyncIterator
 from contextlib import suppress
+from dataclasses import dataclass
 from typing import Any
 from unittest.mock import AsyncMock, MagicMock, patch
@@ -595,3 +596,142 @@ async def test_cosmos_checkpoint_storage_roundtrip_with_emulator() -> None:
     finally:
         with suppress(Exception):
             await cosmos_client.delete_database(database_name)
+
+
+# --- Tests for allowed_checkpoint_types ---
+
+
+@dataclass
+class _AppState:
+    """Application-defined state type used to test allowed_checkpoint_types."""
+
+    label: str
+    count: int
+
+
+_APP_STATE_TYPE_KEY = f"{_AppState.__module__}:{_AppState.__qualname__}"
+
+
+def _make_checkpoint_with_state(state: dict[str, Any]) -> WorkflowCheckpoint:
+    """Create a checkpoint with custom state for serialization tests."""
+    return WorkflowCheckpoint(
+        workflow_name="test-workflow",
+        graph_signature_hash="abc123",
+        timestamp="2025-01-01T00:00:00+00:00",
+        state=state,
+        iteration_count=1,
+    )
+
+
+async def test_init_accepts_allowed_checkpoint_types(mock_container: MagicMock) -> None:
+    """CosmosCheckpointStorage.__init__ accepts allowed_checkpoint_types."""
+    storage = CosmosCheckpointStorage(
+        container_client=mock_container,
+        allowed_checkpoint_types=["some.module:SomeType"],
+    )
+    assert storage is not None
+
+
+async def test_load_allows_builtin_safe_types(mock_container: MagicMock) -> None:
+    """Built-in safe types load without opt-in via allowed_checkpoint_types."""
+    from datetime import datetime, timezone
+    checkpoint = _make_checkpoint_with_state({
+        "ts": datetime(2025, 1, 1, tzinfo=timezone.utc),
+        "tags": {1, 2, 3},
+    })
+    doc = _checkpoint_to_cosmos_document(checkpoint)
+    mock_container.query_items.return_value = _to_async_iter([doc])
+    storage = CosmosCheckpointStorage(container_client=mock_container)
+    loaded = await storage.load(checkpoint.checkpoint_id)
+    assert loaded.state["ts"] == datetime(2025, 1, 1, tzinfo=timezone.utc)
+    assert loaded.state["tags"] == {1, 2, 3}
+
+
+async def test_load_blocks_unlisted_app_type(mock_container: MagicMock) -> None:
+    """Application types are blocked when not listed in allowed_checkpoint_types."""
+    checkpoint = _make_checkpoint_with_state({"data": _AppState(label="x", count=1)})
+    doc = _checkpoint_to_cosmos_document(checkpoint)
+    mock_container.query_items.return_value = _to_async_iter([doc])
+    storage = CosmosCheckpointStorage(container_client=mock_container)
+    with pytest.raises(WorkflowCheckpointException, match="deserialization blocked"):
+        await storage.load(checkpoint.checkpoint_id)
+
+
+async def test_load_allows_listed_app_type(mock_container: MagicMock) -> None:
+    """Application types are allowed when listed in allowed_checkpoint_types."""
+    checkpoint = _make_checkpoint_with_state({"data": _AppState(label="ok", count=7)})
+    doc = _checkpoint_to_cosmos_document(checkpoint)
+    mock_container.query_items.return_value = _to_async_iter([doc])
+    storage = CosmosCheckpointStorage(
+        container_client=mock_container,
+        allowed_checkpoint_types=[_APP_STATE_TYPE_KEY],
+    )
+    loaded = await storage.load(checkpoint.checkpoint_id)
+    assert isinstance(loaded.state["data"], _AppState)
+    assert loaded.state["data"].label == "ok"
+    assert loaded.state["data"].count == 7
+
+
+async def test_list_checkpoints_blocks_unlisted_app_type(mock_container: MagicMock) -> None:
+    """list_checkpoints skips documents with unlisted application types."""
+    checkpoint = _make_checkpoint_with_state({"data": _AppState(label="x", count=1)})
+    doc = _checkpoint_to_cosmos_document(checkpoint)
+    mock_container.query_items.return_value = _to_async_iter([doc])
+    storage = CosmosCheckpointStorage(container_client=mock_container)
+    results = await storage.list_checkpoints(workflow_name="test-workflow")
+    # The document is skipped (logged as warning) because the type is blocked
+    assert len(results) == 0
+
+
+async def test_list_checkpoints_allows_listed_app_type(mock_container: MagicMock) -> None:
+    """list_checkpoints decodes documents with listed application types."""
+    checkpoint = _make_checkpoint_with_state({"data": _AppState(label="ok", count=3)})
+    doc = _checkpoint_to_cosmos_document(checkpoint)
+    mock_container.query_items.return_value = _to_async_iter([doc])
+    storage = CosmosCheckpointStorage(
+        container_client=mock_container,
+        allowed_checkpoint_types=[_APP_STATE_TYPE_KEY],
+    )
+    results = await storage.list_checkpoints(workflow_name="test-workflow")
+    assert len(results) == 1
+    assert isinstance(results[0].state["data"], _AppState)
+
+
+async def test_get_latest_blocks_unlisted_app_type(mock_container: MagicMock) -> None:
+    """get_latest raises when the checkpoint contains an unlisted application type."""
+    checkpoint = _make_checkpoint_with_state({"data": _AppState(label="x", count=1)})
+    doc = _checkpoint_to_cosmos_document(checkpoint)
+    mock_container.query_items.return_value = _to_async_iter([doc])
+    storage = CosmosCheckpointStorage(container_client=mock_container)
+    with pytest.raises(WorkflowCheckpointException, match="deserialization blocked"):
+        await storage.get_latest(workflow_name="test-workflow")
+
+
+async def test_get_latest_allows_listed_app_type(mock_container: MagicMock) -> None:
+    """get_latest decodes checkpoints with listed application types."""
+    checkpoint = _make_checkpoint_with_state({"data": _AppState(label="latest", count=9)})
+    doc = _checkpoint_to_cosmos_document(checkpoint)
+    mock_container.query_items.return_value = _to_async_iter([doc])
+    storage = CosmosCheckpointStorage(
+        container_client=mock_container,
+        allowed_checkpoint_types=[_APP_STATE_TYPE_KEY],
+    )
+    result = await storage.get_latest(workflow_name="test-workflow")
+    assert result is not None
+    assert isinstance(result.state["data"], _AppState)
+    assert result.state["data"].label == "latest"