Python: Add Cosmos DB NoSQL Checkpoint Storage for Python Workflows (#4916)

* Add CosmosCheckpointStorage for Python workflow checkpointing

Add native Cosmos DB NoSQL support for workflow checkpoint storage in the
Python agent-framework-azure-cosmos package, achieving parity with the
existing .NET CosmosCheckpointStore.

New files:
- _checkpoint_storage.py: CosmosCheckpointStorage implementing the
  CheckpointStorage protocol with 6 methods (save, load, list_checkpoints,
  delete, get_latest, list_checkpoint_ids)
- test_cosmos_checkpoint_storage.py: Unit and integration tests
- workflow_checkpointing.py: Sample demonstrating Cosmos DB-backed
  workflow checkpoint/resume

Auth support:
- Managed identity / RBAC via Azure credential objects
  (DefaultAzureCredential, ManagedIdentityCredential, etc.)
- Key-based auth via account key string or AZURE_COSMOS_KEY env var
- Pre-created CosmosClient or ContainerProxy

Key design decisions:
- Partition key: /workflow_name for efficient per-workflow queries
- Serialization: Reuses encode/decode_checkpoint_value for full Python
  object fidelity (hybrid JSON + pickle approach)
- Container auto-creation via create_container_if_not_exists

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Adding cosmos checkpointer

* Resolving comments

* Fixing builds

* Adding sample for history provider and checkpoint storage

* Resolving comments

* fixing builds

* Resolving comments

---------

Co-authored-by: Aayush Kataria <aayushkataria@Aayushs-MacBook-Pro-2.local>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Evan Mattson <35585003+moonbox3@users.noreply.github.com>
Aayush Kataria
2026-04-08 22:01:41 -07:00
committed by GitHub
parent a7a02c1abd
commit 30a2bc3dcb
11 changed files with 1989 additions and 8 deletions
+86 -7
@@ -14,7 +14,7 @@ The Azure Cosmos DB integration provides `CosmosHistoryProvider` for persistent
```python
from azure.identity.aio import DefaultAzureCredential
from agent_framework.azure import CosmosHistoryProvider
from agent_framework_azure_cosmos import CosmosHistoryProvider
provider = CosmosHistoryProvider(
    endpoint="https://<account>.documents.azure.com:443/",
@@ -35,13 +35,92 @@ Container naming behavior:
- Container name is configured on the provider (`container_name` or `AZURE_COSMOS_CONTAINER_NAME`)
- `session_id` is used as the Cosmos partition key for reads/writes
See the [conversation samples](../../samples/02-agents/conversations/) for runnable examples, including
[`cosmos_history_provider.py`](../../samples/02-agents/conversations/cosmos_history_provider.py).
See `samples/02-agents/conversations/cosmos_history_provider.py` for a runnable example.
## Import Paths
## Cosmos DB Workflow Checkpoint Storage
`CosmosCheckpointStorage` implements the `CheckpointStorage` protocol, enabling
durable workflow checkpointing backed by Azure Cosmos DB NoSQL. Workflows can be
paused and resumed across process restarts by persisting checkpoint state in Cosmos DB.
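
To make the protocol surface concrete, here is a minimal in-memory sketch of the six methods (`save`, `load`, `list_checkpoints`, `delete`, `get_latest`, `list_checkpoint_ids`). It is not the Cosmos implementation: `Checkpoint` and `InMemoryCheckpointStorage` are hypothetical stand-ins, and the sketch assumes timestamps are ISO 8601 strings with the same UTC offset so that string order matches time order.

```python
from __future__ import annotations

import asyncio
from dataclasses import dataclass, field
from typing import Any


@dataclass
class Checkpoint:
    """Hypothetical stand-in for WorkflowCheckpoint (only the fields used here)."""

    workflow_name: str
    checkpoint_id: str
    timestamp: str  # assumed ISO 8601 with a consistent UTC offset
    state: dict[str, Any] = field(default_factory=dict)


class InMemoryCheckpointStorage:
    """Toy storage exposing the six protocol methods, backed by a dict."""

    def __init__(self) -> None:
        self._items: dict[str, Checkpoint] = {}

    async def save(self, checkpoint: Checkpoint) -> str:
        self._items[checkpoint.checkpoint_id] = checkpoint
        return checkpoint.checkpoint_id

    async def load(self, checkpoint_id: str) -> Checkpoint:
        if checkpoint_id not in self._items:
            raise KeyError(f"No checkpoint found with ID {checkpoint_id}")
        return self._items[checkpoint_id]

    async def list_checkpoints(self, *, workflow_name: str) -> list[Checkpoint]:
        # Oldest first, mirroring an ORDER BY timestamp ASC query.
        matches = [c for c in self._items.values() if c.workflow_name == workflow_name]
        return sorted(matches, key=lambda c: c.timestamp)

    async def delete(self, checkpoint_id: str) -> bool:
        return self._items.pop(checkpoint_id, None) is not None

    async def get_latest(self, *, workflow_name: str) -> Checkpoint | None:
        ordered = await self.list_checkpoints(workflow_name=workflow_name)
        return ordered[-1] if ordered else None

    async def list_checkpoint_ids(self, *, workflow_name: str) -> list[str]:
        ordered = await self.list_checkpoints(workflow_name=workflow_name)
        return [c.checkpoint_id for c in ordered]


async def demo() -> tuple[str | None, list[str]]:
    storage = InMemoryCheckpointStorage()
    await storage.save(Checkpoint("wf", "cp-1", "2025-01-01T00:00:00+00:00"))
    await storage.save(Checkpoint("wf", "cp-2", "2025-01-02T00:00:00+00:00"))
    latest = await storage.get_latest(workflow_name="wf")
    ids = await storage.list_checkpoint_ids(workflow_name="wf")
    return (latest.checkpoint_id if latest else None, ids)


latest_id, ids = asyncio.run(demo())
```

The Cosmos-backed storage follows the same contract but issues queries scoped to the `workflow_name` partition instead of filtering a dict.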
### Basic Usage
#### Managed Identity / RBAC (recommended for production)
```python
from agent_framework.azure import CosmosHistoryProvider
# or directly:
from agent_framework_azure_cosmos import CosmosHistoryProvider
from azure.identity.aio import DefaultAzureCredential
from agent_framework import WorkflowBuilder
from agent_framework_azure_cosmos import CosmosCheckpointStorage
checkpoint_storage = CosmosCheckpointStorage(
    endpoint="https://<account>.documents.azure.com:443/",
    credential=DefaultAzureCredential(),
    database_name="agent-framework",
    container_name="workflow-checkpoints",
)
```
#### Account Key
```python
from agent_framework_azure_cosmos import CosmosCheckpointStorage
checkpoint_storage = CosmosCheckpointStorage(
    endpoint="https://<account>.documents.azure.com:443/",
    credential="<your-account-key>",
    database_name="agent-framework",
    container_name="workflow-checkpoints",
)
```
#### Then use with a workflow
```python
from agent_framework import WorkflowBuilder
# Build a workflow with checkpointing enabled
workflow = WorkflowBuilder(
    start_executor=start,
    checkpoint_storage=checkpoint_storage,
).build()
# Run the workflow — checkpoints are automatically saved after each superstep
result = await workflow.run(message="input data")
# Resume from a checkpoint
latest = await checkpoint_storage.get_latest(workflow_name=workflow.name)
if latest:
    resumed = await workflow.run(checkpoint_id=latest.checkpoint_id)
```
### Authentication Options
`CosmosCheckpointStorage` supports the same authentication modes as `CosmosHistoryProvider`:
- **Managed identity / RBAC** (recommended): Pass `DefaultAzureCredential()`,
`ManagedIdentityCredential()`, or any Azure `TokenCredential`
- **Account key**: Pass a key string via the `credential` parameter
- **Environment variables**: Set `AZURE_COSMOS_ENDPOINT`, `AZURE_COSMOS_DATABASE_NAME`,
`AZURE_COSMOS_CONTAINER_NAME`, and `AZURE_COSMOS_KEY` (key not required when using
Azure credentials)
- **Pre-created client**: Pass an existing `CosmosClient` or `ContainerProxy`
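
The precedence among these options can be sketched as a small pure function. This is an illustration of the behavior described above, not the library's code: `resolve_credential` is a hypothetical helper, and the environment is passed in as a plain mapping so the example runs without Azure packages.

```python
from __future__ import annotations

from typing import Any, Mapping


def resolve_credential(credential: Any | None, env: Mapping[str, str]) -> tuple[str, Any]:
    """Return (mode, value) for the credential that would reach the Cosmos client.

    Hypothetical helper: a string is treated as an account key, any other
    object as a token credential, and AZURE_COSMOS_KEY is the fallback.
    """
    if isinstance(credential, str):
        return ("key", credential)
    if credential is not None:
        return ("token", credential)
    key = env.get("AZURE_COSMOS_KEY")
    if key:
        return ("key", key)
    raise ValueError("No credential given and AZURE_COSMOS_KEY is not set")


token = object()  # stands in for DefaultAzureCredential()
from_string = resolve_credential("my-account-key", {})
from_object = resolve_credential(token, {"AZURE_COSMOS_KEY": "ignored"})
from_env = resolve_credential(None, {"AZURE_COSMOS_KEY": "env-key"})
```

An explicit credential object wins over `AZURE_COSMOS_KEY`, which is why the key is optional when Azure credentials are used.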
### Database and Container Setup
The database and container are created automatically on first use (via
`create_database_if_not_exists` and `create_container_if_not_exists`). The container
uses `/workflow_name` as the partition key. You can also pre-create them in the Azure
portal with this partition key configuration.
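
As a rough picture of what ends up in the container, the sketch below builds a document with a composite `id` and the `workflow_name` partition key value, then strips Cosmos system properties on the way back. The helper names are hypothetical, and the real storage additionally encodes and decodes values (the hybrid JSON + pickle step), which is omitted here.

```python
from typing import Any

# Composite id plus Cosmos DB system properties removed before decoding.
SYSTEM_KEYS = {"id", "_rid", "_self", "_etag", "_attachments", "_ts"}


def make_document_id(workflow_name: str, checkpoint_id: str) -> str:
    """Composite id: workflow name and checkpoint id joined by an underscore."""
    return f"{workflow_name}_{checkpoint_id}"


def to_document(checkpoint: dict[str, Any]) -> dict[str, Any]:
    """Wrap a checkpoint dict as a Cosmos-style document (encoding step omitted)."""
    return {
        "id": make_document_id(checkpoint["workflow_name"], checkpoint["checkpoint_id"]),
        "workflow_name": checkpoint["workflow_name"],  # partition key value
        **checkpoint,
    }


def from_document(document: dict[str, Any]) -> dict[str, Any]:
    """Drop the composite id and system properties to recover the checkpoint dict."""
    return {k: v for k, v in document.items() if k not in SYSTEM_KEYS}


original = {"workflow_name": "wf", "checkpoint_id": "cp-1", "state": {"counter": 42}}
document = to_document(original)
# Simulate the extra properties Cosmos DB attaches to stored items.
stored = {**document, "_rid": "abc", "_etag": '"0"', "_ts": 1700000000}
restored = from_document(stored)
```

Because the partition key is `/workflow_name`, reads that know the workflow name can be routed to a single partition, while lookups by `checkpoint_id` alone have to query across partitions.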
### Environment Variables
| Variable | Description |
|---|---|
| `AZURE_COSMOS_ENDPOINT` | Cosmos DB account endpoint |
| `AZURE_COSMOS_DATABASE_NAME` | Database name |
| `AZURE_COSMOS_CONTAINER_NAME` | Container name |
| `AZURE_COSMOS_KEY` | Account key (optional if using Azure credentials) |
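
A minimal sketch of how these variables might combine with constructor arguments, assuming explicit arguments take precedence and the environment is only a fallback. `resolve_settings` is a hypothetical helper written for this example; the package itself delegates to the framework's settings loader.

```python
from __future__ import annotations

from typing import Mapping

ENV_PREFIX = "AZURE_COSMOS_"
FIELDS = ("endpoint", "database_name", "container_name", "key")


def resolve_settings(
    env: Mapping[str, str],
    required: list[str],
    **explicit: str | None,
) -> dict[str, str | None]:
    """Resolve each field from explicit arguments first, then from the environment."""
    resolved: dict[str, str | None] = {}
    for name in FIELDS:
        value = explicit.get(name)
        if value is None:
            # e.g. database_name -> AZURE_COSMOS_DATABASE_NAME
            value = env.get(ENV_PREFIX + name.upper())
        resolved[name] = value
    missing = [name for name in required if not resolved.get(name)]
    if missing:
        raise ValueError(f"Missing required settings: {missing}")
    return resolved


env = {
    "AZURE_COSMOS_ENDPOINT": "https://example.documents.azure.com:443/",
    "AZURE_COSMOS_DATABASE_NAME": "env-db",
    "AZURE_COSMOS_CONTAINER_NAME": "checkpoints",
}
# The key is not required here, as when an Azure credential object is supplied.
settings = resolve_settings(
    env,
    ["endpoint", "database_name", "container_name"],
    database_name="arg-db",
)
```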
See `samples/03-workflows/checkpoint/cosmos_workflow_checkpointing.py` for a standalone example,
or `samples/03-workflows/checkpoint/cosmos_workflow_checkpointing_foundry.py` for an end-to-end
example with Azure AI Foundry agents.
@@ -2,6 +2,7 @@
import importlib.metadata
from ._checkpoint_storage import CosmosCheckpointStorage
from ._history_provider import CosmosHistoryProvider
try:
@@ -10,6 +11,7 @@ except importlib.metadata.PackageNotFoundError:
__version__ = "0.0.0" # Fallback for development mode
__all__ = [
"CosmosCheckpointStorage",
"CosmosHistoryProvider",
"__version__",
]
@@ -0,0 +1,432 @@
# Copyright (c) Microsoft. All rights reserved.

"""Azure Cosmos DB checkpoint storage for workflow checkpointing."""

from __future__ import annotations

import logging
from typing import Any, TypedDict

from agent_framework import AGENT_FRAMEWORK_USER_AGENT
from agent_framework._settings import SecretString, load_settings
from agent_framework._workflows._checkpoint import CheckpointID, WorkflowCheckpoint
from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value
from agent_framework.exceptions import WorkflowCheckpointException
from azure.core.credentials import TokenCredential
from azure.core.credentials_async import AsyncTokenCredential
from azure.cosmos import PartitionKey
from azure.cosmos.aio import ContainerProxy, CosmosClient
from azure.cosmos.exceptions import CosmosResourceNotFoundError

AzureCredentialTypes = TokenCredential | AsyncTokenCredential

logger = logging.getLogger(__name__)


class AzureCosmosCheckpointSettings(TypedDict, total=False):
    """Settings for CosmosCheckpointStorage resolved from args and environment."""

    endpoint: str | None
    database_name: str | None
    container_name: str | None
    key: SecretString | None


class CosmosCheckpointStorage:
    """Azure Cosmos DB-backed checkpoint storage for workflow checkpointing.

    Implements the ``CheckpointStorage`` protocol using Azure Cosmos DB NoSQL
    as the persistent backend. Checkpoints are stored as JSON documents with
    ``workflow_name`` as the partition key, enabling efficient per-workflow queries.

    This storage uses the same hybrid JSON + pickle encoding as
    ``FileCheckpointStorage``, allowing full Python object fidelity for
    complex workflow state while keeping the document structure human-readable.

    SECURITY WARNING: Checkpoints use pickle for data serialization. Only load
    checkpoints from trusted sources. Loading a malicious checkpoint can execute
    arbitrary code.

    The database and container are created automatically on first use
    if they do not already exist. The container uses partition key
    ``/workflow_name``.

    Example using managed identity / RBAC:

    .. code-block:: python

        from azure.identity.aio import DefaultAzureCredential
        from agent_framework_azure_cosmos import CosmosCheckpointStorage

        storage = CosmosCheckpointStorage(
            endpoint="https://my-account.documents.azure.com:443/",
            credential=DefaultAzureCredential(),
            database_name="agent-db",
            container_name="checkpoints",
        )

    Example using account key:

    .. code-block:: python

        storage = CosmosCheckpointStorage(
            endpoint="https://my-account.documents.azure.com:443/",
            credential="my-account-key",
            database_name="agent-db",
            container_name="checkpoints",
        )

    Then use with a workflow builder:

    .. code-block:: python

        workflow = WorkflowBuilder(
            start_executor=start,
            checkpoint_storage=storage,
        ).build()
    """

    def __init__(
        self,
        *,
        endpoint: str | None = None,
        database_name: str | None = None,
        container_name: str | None = None,
        credential: str | AzureCredentialTypes | None = None,
        cosmos_client: CosmosClient | None = None,
        container_client: ContainerProxy | None = None,
        env_file_path: str | None = None,
        env_file_encoding: str | None = None,
    ) -> None:
        """Initialize the Azure Cosmos DB checkpoint storage.

        Supports multiple authentication modes:

        - **Container client** (``container_client``): Use a pre-created
          Cosmos async container proxy. No client lifecycle is managed.
        - **Cosmos client** (``cosmos_client``): Use a pre-created Cosmos
          async client. The caller is responsible for closing it.
        - **Endpoint + credential**: Create a new Cosmos client. The storage
          owns the client and closes it on ``close()``.
        - **Environment variables**: Falls back to ``AZURE_COSMOS_ENDPOINT``,
          ``AZURE_COSMOS_DATABASE_NAME``, ``AZURE_COSMOS_CONTAINER_NAME``,
          and ``AZURE_COSMOS_KEY``.

        Args:
            endpoint: Cosmos DB account endpoint.
                Can be set via ``AZURE_COSMOS_ENDPOINT``.
            database_name: Cosmos DB database name.
                Can be set via ``AZURE_COSMOS_DATABASE_NAME``.
            container_name: Cosmos DB container name.
                Can be set via ``AZURE_COSMOS_CONTAINER_NAME``.
            credential: Credential to authenticate with Cosmos DB.
                For **managed identity / RBAC**, pass an Azure credential object
                such as ``DefaultAzureCredential()`` or
                ``ManagedIdentityCredential()``.
                For **key-based auth**, pass the account key as a string,
                or set ``AZURE_COSMOS_KEY`` in the environment.
            cosmos_client: Pre-created Cosmos async client.
            container_client: Pre-created Cosmos container client.
            env_file_path: Path to environment file for loading settings.
            env_file_encoding: Encoding of the environment file.
        """
        self._cosmos_client: CosmosClient | None = cosmos_client
        self._container_proxy: ContainerProxy | None = container_client
        self._owns_client = False

        if self._container_proxy is not None:
            self.database_name: str = database_name or ""
            self.container_name: str = container_name or ""
            return

        required_fields: list[str] = ["database_name", "container_name"]
        if cosmos_client is None:
            required_fields.append("endpoint")
            if credential is None:
                required_fields.append("key")

        settings = load_settings(
            AzureCosmosCheckpointSettings,
            env_prefix="AZURE_COSMOS_",
            required_fields=required_fields,
            endpoint=endpoint,
            database_name=database_name,
            container_name=container_name,
            key=credential if isinstance(credential, str) else None,
            env_file_path=env_file_path,
            env_file_encoding=env_file_encoding,
        )
        self.database_name = settings["database_name"]  # type: ignore[assignment]
        self.container_name = settings["container_name"]  # type: ignore[assignment]

        if self._cosmos_client is None:
            self._cosmos_client = CosmosClient(
                url=settings["endpoint"],  # type: ignore[arg-type]
                credential=credential or settings["key"].get_secret_value(),  # type: ignore[arg-type,union-attr]
                user_agent_suffix=AGENT_FRAMEWORK_USER_AGENT,
            )
            self._owns_client = True

    async def save(self, checkpoint: WorkflowCheckpoint) -> CheckpointID:
        """Save a checkpoint to Cosmos DB and return its ID.

        The checkpoint is encoded to a JSON-compatible form (using pickle for
        non-JSON-native values) and stored as a Cosmos DB document with the
        ``workflow_name`` as the partition key.

        The document ``id`` is a composite of ``workflow_name`` and
        ``checkpoint_id`` to ensure global uniqueness across partitions.

        Args:
            checkpoint: The WorkflowCheckpoint object to save.

        Returns:
            The unique ID of the saved checkpoint.
        """
        await self._ensure_container_proxy()

        checkpoint_dict = checkpoint.to_dict()
        encoded = encode_checkpoint_value(checkpoint_dict)

        document: dict[str, Any] = {
            "id": self._make_document_id(checkpoint.workflow_name, checkpoint.checkpoint_id),
            "workflow_name": checkpoint.workflow_name,
            **encoded,
        }

        await self._container_proxy.upsert_item(body=document)  # type: ignore[union-attr]
        logger.info("Saved checkpoint %s to Cosmos DB", checkpoint.checkpoint_id)
        return checkpoint.checkpoint_id

    async def load(self, checkpoint_id: CheckpointID) -> WorkflowCheckpoint:
        """Load a checkpoint from Cosmos DB by ID.

        Args:
            checkpoint_id: The unique ID of the checkpoint to load.

        Returns:
            The WorkflowCheckpoint object corresponding to the given ID.

        Raises:
            WorkflowCheckpointException: If no checkpoint with the given ID exists,
                or if multiple checkpoints share the same ID across workflows.
        """
        await self._ensure_container_proxy()

        query = "SELECT * FROM c WHERE c.checkpoint_id = @checkpoint_id"
        parameters: list[dict[str, object]] = [
            {"name": "@checkpoint_id", "value": checkpoint_id},
        ]

        items = self._container_proxy.query_items(  # type: ignore[union-attr]
            query=query,
            parameters=parameters,
        )

        results: list[dict[str, Any]] = []
        async for item in items:
            results.append(item)

        if not results:
            raise WorkflowCheckpointException(f"No checkpoint found with ID {checkpoint_id}")

        if len(results) > 1:
            workflow_names = [r.get("workflow_name", "unknown") for r in results]
            raise WorkflowCheckpointException(
                f"Multiple checkpoints found with ID {checkpoint_id} across workflows: "
                f"{workflow_names}. Use list_checkpoints(workflow_name=...) to query "
                f"by workflow instead."
            )

        return self._document_to_checkpoint(results[0])

    async def list_checkpoints(self, *, workflow_name: str) -> list[WorkflowCheckpoint]:
        """List checkpoint objects for a given workflow name.

        Args:
            workflow_name: The name of the workflow to list checkpoints for.

        Returns:
            A list of WorkflowCheckpoint objects for the specified workflow name.
        """
        await self._ensure_container_proxy()

        query = "SELECT * FROM c WHERE c.workflow_name = @workflow_name ORDER BY c.timestamp ASC"
        parameters: list[dict[str, object]] = [
            {"name": "@workflow_name", "value": workflow_name},
        ]

        items = self._container_proxy.query_items(  # type: ignore[union-attr]
            query=query,
            parameters=parameters,
            partition_key=workflow_name,
        )

        checkpoints: list[WorkflowCheckpoint] = []
        async for item in items:
            try:
                checkpoints.append(self._document_to_checkpoint(item))
            except Exception as e:
                logger.warning("Failed to decode checkpoint document: %s", e)

        return checkpoints

    async def delete(self, checkpoint_id: CheckpointID) -> bool:
        """Delete a checkpoint from Cosmos DB by ID.

        Args:
            checkpoint_id: The unique ID of the checkpoint to delete.

        Returns:
            True if the checkpoint was successfully deleted, False if not found.
        """
        await self._ensure_container_proxy()

        query = "SELECT c.id, c.workflow_name FROM c WHERE c.checkpoint_id = @checkpoint_id"
        parameters: list[dict[str, object]] = [
            {"name": "@checkpoint_id", "value": checkpoint_id},
        ]

        items = self._container_proxy.query_items(  # type: ignore[union-attr]
            query=query,
            parameters=parameters,
        )

        async for item in items:
            try:
                await self._container_proxy.delete_item(  # type: ignore[union-attr]
                    item=item["id"],
                    partition_key=item["workflow_name"],
                )
                logger.info("Deleted checkpoint %s from Cosmos DB", checkpoint_id)
                return True
            except CosmosResourceNotFoundError:
                return False

        return False

    async def get_latest(self, *, workflow_name: str) -> WorkflowCheckpoint | None:
        """Get the latest checkpoint for a given workflow name.

        Args:
            workflow_name: The name of the workflow to get the latest checkpoint for.

        Returns:
            The latest WorkflowCheckpoint, or None if no checkpoints exist.
        """
        await self._ensure_container_proxy()

        query = (
            "SELECT * FROM c WHERE c.workflow_name = @workflow_name "
            "ORDER BY c.timestamp DESC OFFSET 0 LIMIT 1"
        )
        parameters: list[dict[str, object]] = [
            {"name": "@workflow_name", "value": workflow_name},
        ]

        items = self._container_proxy.query_items(  # type: ignore[union-attr]
            query=query,
            parameters=parameters,
            partition_key=workflow_name,
        )

        async for item in items:
            checkpoint = self._document_to_checkpoint(item)
            logger.debug(
                "Latest checkpoint for workflow %s is %s",
                workflow_name,
                checkpoint.checkpoint_id,
            )
            return checkpoint

        return None

    async def list_checkpoint_ids(self, *, workflow_name: str) -> list[CheckpointID]:
        """List checkpoint IDs for a given workflow name.

        Args:
            workflow_name: The name of the workflow to list checkpoint IDs for.

        Returns:
            A list of checkpoint IDs for the specified workflow name.
        """
        await self._ensure_container_proxy()

        query = (
            "SELECT c.checkpoint_id FROM c WHERE c.workflow_name = @workflow_name "
            "ORDER BY c.timestamp ASC"
        )
        parameters: list[dict[str, object]] = [
            {"name": "@workflow_name", "value": workflow_name},
        ]

        items = self._container_proxy.query_items(  # type: ignore[union-attr]
            query=query,
            parameters=parameters,
            partition_key=workflow_name,
        )

        checkpoint_ids: list[CheckpointID] = []
        async for item in items:
            cid = item.get("checkpoint_id")
            if isinstance(cid, str):
                checkpoint_ids.append(cid)

        return checkpoint_ids

    async def close(self) -> None:
        """Close the underlying Cosmos client when this storage owns it."""
        if self._owns_client and self._cosmos_client is not None:
            await self._cosmos_client.close()

    async def __aenter__(self) -> CosmosCheckpointStorage:
        """Async context manager entry."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        """Async context manager exit."""
        try:
            await self.close()
        except Exception:
            if exc_type is None:
                raise

    async def _ensure_container_proxy(self) -> None:
        """Get or create the Cosmos DB database and container for storing checkpoints."""
        if self._container_proxy is not None:
            return
        if self._cosmos_client is None:
            raise RuntimeError("Cosmos client is not initialized.")

        database = await self._cosmos_client.create_database_if_not_exists(id=self.database_name)
        self._container_proxy = await database.create_container_if_not_exists(
            id=self.container_name,
            partition_key=PartitionKey(path="/workflow_name"),
        )

    @staticmethod
    def _document_to_checkpoint(document: dict[str, Any]) -> WorkflowCheckpoint:
        """Convert a Cosmos DB document back to a WorkflowCheckpoint.

        Strips Cosmos DB system properties (``_rid``, ``_self``, ``_etag``,
        ``_attachments``, ``_ts``) before decoding.
        """
        # Remove Cosmos DB system properties and the composite 'id' field
        # (checkpoints use 'checkpoint_id', not 'id')
        cosmos_keys = {"id", "_rid", "_self", "_etag", "_attachments", "_ts"}
        cleaned = {k: v for k, v in document.items() if k not in cosmos_keys}

        decoded = decode_checkpoint_value(cleaned)
        return WorkflowCheckpoint.from_dict(decoded)

    @staticmethod
    def _make_document_id(workflow_name: str, checkpoint_id: str) -> str:
        """Create a composite Cosmos DB document ID.

        Combines ``workflow_name`` and ``checkpoint_id`` to ensure global
        uniqueness across partitions.
        """
        return f"{workflow_name}_{checkpoint_id}"
@@ -0,0 +1,599 @@
# Copyright (c) Microsoft. All rights reserved.

from __future__ import annotations

import os
import uuid
from collections.abc import AsyncIterator
from contextlib import suppress
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from agent_framework._workflows._checkpoint import WorkflowCheckpoint
from agent_framework._workflows._checkpoint_encoding import encode_checkpoint_value
from agent_framework.exceptions import SettingNotFoundError, WorkflowCheckpointException
from azure.cosmos.aio import CosmosClient
from azure.cosmos.exceptions import CosmosResourceNotFoundError

import agent_framework_azure_cosmos._checkpoint_storage as checkpoint_storage_module
from agent_framework_azure_cosmos._checkpoint_storage import CosmosCheckpointStorage

skip_if_cosmos_integration_tests_disabled = pytest.mark.skipif(
    any(
        os.getenv(name, "") == ""
        for name in (
            "AZURE_COSMOS_ENDPOINT",
            "AZURE_COSMOS_KEY",
            "AZURE_COSMOS_DATABASE_NAME",
            "AZURE_COSMOS_CONTAINER_NAME",
        )
    ),
    reason=(
        "AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_KEY, AZURE_COSMOS_DATABASE_NAME, and "
        "AZURE_COSMOS_CONTAINER_NAME are required for Cosmos integration tests."
    ),
)


def _to_async_iter(items: list[Any]) -> AsyncIterator[Any]:
    async def _iterator() -> AsyncIterator[Any]:
        for item in items:
            yield item

    return _iterator()


def _make_checkpoint(
    workflow_name: str = "test-workflow",
    checkpoint_id: str | None = None,
    previous_checkpoint_id: str | None = None,
    timestamp: str | None = None,
) -> WorkflowCheckpoint:
    """Create a minimal WorkflowCheckpoint for testing."""
    return WorkflowCheckpoint(
        workflow_name=workflow_name,
        graph_signature_hash="abc123",
        checkpoint_id=checkpoint_id or str(uuid.uuid4()),
        previous_checkpoint_id=previous_checkpoint_id,
        timestamp=timestamp or "2025-01-01T00:00:00+00:00",
        state={"counter": 42},
        iteration_count=1,
    )


def _checkpoint_to_cosmos_document(checkpoint: WorkflowCheckpoint) -> dict[str, Any]:
    """Simulate what a Cosmos DB document looks like after save."""
    encoded = encode_checkpoint_value(checkpoint.to_dict())
    doc: dict[str, Any] = {
        "id": f"{checkpoint.workflow_name}_{checkpoint.checkpoint_id}",
        "workflow_name": checkpoint.workflow_name,
        **encoded,
        # Cosmos system properties
        "_rid": "abc",
        "_self": "dbs/abc/colls/def/docs/ghi",
        "_etag": '"00000000-0000-0000-0000-000000000000"',
        "_attachments": "attachments/",
        "_ts": 1700000000,
    }
    return doc


@pytest.fixture
def mock_container() -> MagicMock:
    container = MagicMock()
    container.query_items = MagicMock(return_value=_to_async_iter([]))
    container.upsert_item = AsyncMock(return_value={})
    container.delete_item = AsyncMock(return_value={})
    return container


@pytest.fixture
def mock_cosmos_client(mock_container: MagicMock) -> MagicMock:
    database_client = MagicMock()
    database_client.create_container_if_not_exists = AsyncMock(return_value=mock_container)

    client = MagicMock()
    client.create_database_if_not_exists = AsyncMock(return_value=database_client)
    client.close = AsyncMock()
    return client


# --- Tests for initialization ---


async def test_init_uses_provided_container_client(mock_container: MagicMock) -> None:
    storage = CosmosCheckpointStorage(container_client=mock_container)
    assert storage.database_name == ""
    assert storage.container_name == ""


async def test_init_uses_provided_cosmos_client(mock_cosmos_client: MagicMock) -> None:
    storage = CosmosCheckpointStorage(
        cosmos_client=mock_cosmos_client,
        database_name="db1",
        container_name="checkpoints",
    )
    assert storage.database_name == "db1"
    assert storage.container_name == "checkpoints"


async def test_init_missing_required_settings_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.delenv("AZURE_COSMOS_ENDPOINT", raising=False)
    monkeypatch.delenv("AZURE_COSMOS_DATABASE_NAME", raising=False)
    monkeypatch.delenv("AZURE_COSMOS_CONTAINER_NAME", raising=False)
    monkeypatch.delenv("AZURE_COSMOS_KEY", raising=False)

    with pytest.raises(SettingNotFoundError, match="database_name"):
        CosmosCheckpointStorage()


async def test_init_constructs_client_with_credential(
    monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock
) -> None:
    """Uses key-based auth when a key string is provided, otherwise falls back to Azure credential (RBAC)."""
    mock_factory = MagicMock(return_value=mock_cosmos_client)
    monkeypatch.setattr(checkpoint_storage_module, "CosmosClient", mock_factory)
    monkeypatch.delenv("AZURE_COSMOS_KEY", raising=False)

    # Simulate real-world pattern: use key if available, else RBAC credential
    cosmos_key = os.getenv("AZURE_COSMOS_KEY")
    credential: Any = cosmos_key if cosmos_key else MagicMock()  # MagicMock simulates DefaultAzureCredential()

    CosmosCheckpointStorage(
        endpoint="https://account.documents.azure.com:443/",
        credential=credential,
        database_name="db1",
        container_name="checkpoints",
    )

    mock_factory.assert_called_once()
    kwargs = mock_factory.call_args.kwargs
    assert kwargs["url"] == "https://account.documents.azure.com:443/"
    assert kwargs["credential"] is credential


async def test_init_creates_database_and_container(mock_cosmos_client: MagicMock) -> None:
    storage = CosmosCheckpointStorage(
        cosmos_client=mock_cosmos_client,
        database_name="db1",
        container_name="custom-checkpoints",
    )

    await storage.list_checkpoint_ids(workflow_name="wf")

    mock_cosmos_client.create_database_if_not_exists.assert_awaited_once_with(id="db1")
    database_client = mock_cosmos_client.create_database_if_not_exists.return_value
    assert database_client.create_container_if_not_exists.await_count == 1
    kwargs = database_client.create_container_if_not_exists.await_args.kwargs
    assert kwargs["id"] == "custom-checkpoints"


# --- Tests for save ---


async def test_save_upserts_document(mock_container: MagicMock) -> None:
    storage = CosmosCheckpointStorage(container_client=mock_container)
    checkpoint = _make_checkpoint()

    result = await storage.save(checkpoint)

    assert result == checkpoint.checkpoint_id
    mock_container.upsert_item.assert_awaited_once()
    document = mock_container.upsert_item.await_args.kwargs["body"]
    assert document["id"] == f"test-workflow_{checkpoint.checkpoint_id}"
    assert document["workflow_name"] == "test-workflow"
    assert document["graph_signature_hash"] == "abc123"
    assert document["state"]["counter"] == 42


async def test_save_returns_checkpoint_id(mock_container: MagicMock) -> None:
    storage = CosmosCheckpointStorage(container_client=mock_container)
    checkpoint = _make_checkpoint(checkpoint_id="cp-123")

    result = await storage.save(checkpoint)

    assert result == "cp-123"


# --- Tests for load ---


async def test_load_returns_checkpoint(mock_container: MagicMock) -> None:
    checkpoint = _make_checkpoint(checkpoint_id="cp-load")
    doc = _checkpoint_to_cosmos_document(checkpoint)
    mock_container.query_items.return_value = _to_async_iter([doc])

    storage = CosmosCheckpointStorage(container_client=mock_container)
    loaded = await storage.load("cp-load")

    assert loaded.checkpoint_id == "cp-load"
    assert loaded.workflow_name == "test-workflow"
    assert loaded.graph_signature_hash == "abc123"
    assert loaded.state["counter"] == 42


async def test_load_nonexistent_raises(mock_container: MagicMock) -> None:
    mock_container.query_items.return_value = _to_async_iter([])

    storage = CosmosCheckpointStorage(container_client=mock_container)
    with pytest.raises(WorkflowCheckpointException, match="No checkpoint found"):
        await storage.load("nonexistent-id")


async def test_load_queries_without_partition_key(mock_container: MagicMock) -> None:
    mock_container.query_items.return_value = _to_async_iter([])

    storage = CosmosCheckpointStorage(container_client=mock_container)
    with suppress(WorkflowCheckpointException):
        await storage.load("cp-id")

    kwargs = mock_container.query_items.call_args.kwargs
    assert "partition_key" not in kwargs


async def test_load_multiple_workflows_same_checkpoint_id_raises(mock_container: MagicMock) -> None:
    cp1 = _make_checkpoint(checkpoint_id="shared-id", workflow_name="workflow-a")
    cp2 = _make_checkpoint(checkpoint_id="shared-id", workflow_name="workflow-b")
    mock_container.query_items.return_value = _to_async_iter([
        _checkpoint_to_cosmos_document(cp1),
        _checkpoint_to_cosmos_document(cp2),
    ])

    storage = CosmosCheckpointStorage(container_client=mock_container)
    with pytest.raises(WorkflowCheckpointException, match="Multiple checkpoints found"):
        await storage.load("shared-id")


# --- Tests for list_checkpoints ---


async def test_list_checkpoints_returns_checkpoints_for_workflow(mock_container: MagicMock) -> None:
    cp1 = _make_checkpoint(checkpoint_id="cp-1", timestamp="2025-01-01T00:00:00+00:00")
    cp2 = _make_checkpoint(checkpoint_id="cp-2", timestamp="2025-01-02T00:00:00+00:00")
    mock_container.query_items.return_value = _to_async_iter([
        _checkpoint_to_cosmos_document(cp1),
        _checkpoint_to_cosmos_document(cp2),
    ])

    storage = CosmosCheckpointStorage(container_client=mock_container)
    results = await storage.list_checkpoints(workflow_name="test-workflow")

    assert len(results) == 2
    assert results[0].checkpoint_id == "cp-1"
    assert results[1].checkpoint_id == "cp-2"


async def test_list_checkpoints_uses_partition_key(mock_container: MagicMock) -> None:
    mock_container.query_items.return_value = _to_async_iter([])

    storage = CosmosCheckpointStorage(container_client=mock_container)
    await storage.list_checkpoints(workflow_name="my-workflow")

    kwargs = mock_container.query_items.call_args.kwargs
    assert kwargs["partition_key"] == "my-workflow"


async def test_list_checkpoints_empty_returns_empty(mock_container: MagicMock) -> None:
    mock_container.query_items.return_value = _to_async_iter([])

    storage = CosmosCheckpointStorage(container_client=mock_container)
    results = await storage.list_checkpoints(workflow_name="test-workflow")

    assert results == []


async def test_list_checkpoints_skips_malformed_documents(mock_container: MagicMock) -> None:
    valid_cp = _make_checkpoint(checkpoint_id="cp-valid")
    mock_container.query_items.return_value = _to_async_iter([
        {"id": "bad_doc", "workflow_name": "test-workflow", "not_a_checkpoint": True},
        _checkpoint_to_cosmos_document(valid_cp),
    ])

    storage = CosmosCheckpointStorage(container_client=mock_container)
    results = await storage.list_checkpoints(workflow_name="test-workflow")

    assert len(results) == 1
    assert results[0].checkpoint_id == "cp-valid"


# --- Tests for delete ---


async def test_delete_existing_returns_true(mock_container: MagicMock) -> None:
    mock_container.query_items.return_value = _to_async_iter([
        {"id": "test-workflow_cp-del", "workflow_name": "test-workflow"},
    ])

    storage = CosmosCheckpointStorage(container_client=mock_container)
    result = await storage.delete("cp-del")

    assert result is True
    mock_container.delete_item.assert_awaited_once_with(
        item="test-workflow_cp-del",
        partition_key="test-workflow",
    )


async def test_delete_nonexistent_returns_false(mock_container: MagicMock) -> None:
    mock_container.query_items.return_value = _to_async_iter([])

    storage = CosmosCheckpointStorage(container_client=mock_container)
    result = await storage.delete("nonexistent")

    assert result is False
    mock_container.delete_item.assert_not_awaited()


async def test_delete_cosmos_not_found_returns_false(mock_container: MagicMock) -> None:
    mock_container.query_items.return_value = _to_async_iter([
        {"id": "test-workflow_cp-del", "workflow_name": "test-workflow"},
    ])
    mock_container.delete_item = AsyncMock(side_effect=CosmosResourceNotFoundError)

    storage = CosmosCheckpointStorage(container_client=mock_container)
    result = await storage.delete("cp-del")

    assert result is False


# --- Tests for get_latest ---


async def test_get_latest_returns_latest_checkpoint(mock_container: MagicMock) -> None:
    cp = _make_checkpoint(checkpoint_id="cp-latest", timestamp="2025-06-01T00:00:00+00:00")
    mock_container.query_items.return_value = _to_async_iter([
        _checkpoint_to_cosmos_document(cp),
    ])

    storage = CosmosCheckpointStorage(container_client=mock_container)
    result = await storage.get_latest(workflow_name="test-workflow")

    assert result is not None
    assert result.checkpoint_id == "cp-latest"


async def test_get_latest_returns_none_when_empty(mock_container: MagicMock) -> None:
    mock_container.query_items.return_value = _to_async_iter([])

    storage = CosmosCheckpointStorage(container_client=mock_container)
    result = await storage.get_latest(workflow_name="test-workflow")

    assert result is None


async def test_get_latest_uses_order_by_desc_with_limit(mock_container: MagicMock) -> None:
    mock_container.query_items.return_value = _to_async_iter([])

    storage = CosmosCheckpointStorage(container_client=mock_container)
    await storage.get_latest(workflow_name="test-workflow")

    kwargs = mock_container.query_items.call_args.kwargs
    assert "ORDER BY c.timestamp DESC" in kwargs["query"]
    assert "OFFSET 0 LIMIT 1" in kwargs["query"]



# --- Tests for list_checkpoint_ids ---


async def test_list_checkpoint_ids_returns_ids(mock_container: MagicMock) -> None:
    mock_container.query_items.return_value = _to_async_iter([
        {"checkpoint_id": "cp-1"},
        {"checkpoint_id": "cp-2"},
    ])

    storage = CosmosCheckpointStorage(container_client=mock_container)
    ids = await storage.list_checkpoint_ids(workflow_name="test-workflow")

    assert ids == ["cp-1", "cp-2"]


async def test_list_checkpoint_ids_empty_returns_empty(mock_container: MagicMock) -> None:
    mock_container.query_items.return_value = _to_async_iter([])

    storage = CosmosCheckpointStorage(container_client=mock_container)
    ids = await storage.list_checkpoint_ids(workflow_name="test-workflow")

    assert ids == []


# --- Tests for close and context manager ---


async def test_close_closes_owned_client(
    monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock
) -> None:
    mock_factory = MagicMock(return_value=mock_cosmos_client)
    monkeypatch.setattr(checkpoint_storage_module, "CosmosClient", mock_factory)

    storage = CosmosCheckpointStorage(
        endpoint="https://account.documents.azure.com:443/",
        credential="key-123",
        database_name="db1",
        container_name="checkpoints",
    )

    await storage.close()

    mock_cosmos_client.close.assert_awaited_once()


async def test_close_does_not_close_external_client(mock_cosmos_client: MagicMock) -> None:
    storage = CosmosCheckpointStorage(
        cosmos_client=mock_cosmos_client,
        database_name="db1",
        container_name="checkpoints",
    )

    await storage.close()

    mock_cosmos_client.close.assert_not_awaited()


async def test_context_manager_closes_owned_client(
    monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock
) -> None:
    mock_factory = MagicMock(return_value=mock_cosmos_client)
    monkeypatch.setattr(checkpoint_storage_module, "CosmosClient", mock_factory)

    async with CosmosCheckpointStorage(
        endpoint="https://account.documents.azure.com:443/",
        credential="key-123",
        database_name="db1",
        container_name="checkpoints",
    ) as storage:
        assert storage is not None

    mock_cosmos_client.close.assert_awaited_once()


async def test_context_manager_preserves_original_exception(mock_container: MagicMock) -> None:
    storage = CosmosCheckpointStorage(container_client=mock_container)

    with (
        patch.object(storage, "close", AsyncMock(side_effect=RuntimeError("close failed"))),
        pytest.raises(ValueError, match="inner error"),
    ):
        async with storage:
            raise ValueError("inner error")


async def test_context_manager_reraises_close_error(mock_container: MagicMock) -> None:
    storage = CosmosCheckpointStorage(container_client=mock_container)

    with (
        patch.object(storage, "close", AsyncMock(side_effect=RuntimeError("close failed"))),
        pytest.raises(RuntimeError, match="close failed"),
    ):
        async with storage:
            pass  # no inner exception, so the close error should propagate


# --- Tests for save/load round-trip ---


async def test_round_trip_preserves_data(mock_container: MagicMock) -> None:
    checkpoint = _make_checkpoint(
        checkpoint_id="cp-roundtrip",
        previous_checkpoint_id="cp-parent",
    )
    checkpoint.state = {"key": "value", "nested": {"a": 1}}
    checkpoint.metadata = {"superstep": 3}
    checkpoint.iteration_count = 5

    saved_doc: dict[str, Any] = {}

    async def capture_upsert(body: dict[str, Any]) -> dict[str, Any]:
        saved_doc.update(body)
        return body

    mock_container.upsert_item = AsyncMock(side_effect=capture_upsert)

    storage = CosmosCheckpointStorage(container_client=mock_container)
    await storage.save(checkpoint)

    returned_doc = {
        **saved_doc,
        "_rid": "abc",
        "_self": "dbs/abc/colls/def/docs/ghi",
        "_etag": '"etag"',
        "_attachments": "attachments/",
        "_ts": 1700000000,
    }
    mock_container.query_items.return_value = _to_async_iter([returned_doc])

    loaded = await storage.load("cp-roundtrip")

    assert loaded.checkpoint_id == checkpoint.checkpoint_id
    assert loaded.workflow_name == checkpoint.workflow_name
    assert loaded.graph_signature_hash == checkpoint.graph_signature_hash
    assert loaded.previous_checkpoint_id == "cp-parent"
    assert loaded.state == {"key": "value", "nested": {"a": 1}}
    assert loaded.metadata == {"superstep": 3}
    assert loaded.iteration_count == 5
    assert loaded.version == "1.0"


# --- Integration test ---


@pytest.mark.integration
@skip_if_cosmos_integration_tests_disabled
async def test_cosmos_checkpoint_storage_roundtrip_with_emulator() -> None:
    endpoint = os.getenv("AZURE_COSMOS_ENDPOINT", "")
    key = os.getenv("AZURE_COSMOS_KEY", "")
    database_prefix = os.getenv("AZURE_COSMOS_DATABASE_NAME", "")
    container_prefix = os.getenv("AZURE_COSMOS_CONTAINER_NAME", "")
    unique = uuid.uuid4().hex[:8]
    database_name = f"{database_prefix}-cp-{unique}"
    container_name = f"{container_prefix}-cp-{unique}"

    async with CosmosClient(url=endpoint, credential=key) as cosmos_client:
        await cosmos_client.create_database_if_not_exists(id=database_name)

        storage = CosmosCheckpointStorage(
            cosmos_client=cosmos_client,
            database_name=database_name,
            container_name=container_name,
        )

        try:
            # Save two checkpoints for the same workflow
            cp1 = _make_checkpoint(
                checkpoint_id="cp-int-1",
                workflow_name="integration-wf",
                timestamp="2025-01-01T00:00:00+00:00",
            )
            cp2 = _make_checkpoint(
                checkpoint_id="cp-int-2",
                workflow_name="integration-wf",
                previous_checkpoint_id="cp-int-1",
                timestamp="2025-01-02T00:00:00+00:00",
            )
            cp2.state = {"step": 2}

            await storage.save(cp1)
            await storage.save(cp2)

            # Load by ID
            loaded = await storage.load("cp-int-1")
            assert loaded.checkpoint_id == "cp-int-1"
            assert loaded.workflow_name == "integration-wf"

            # List all checkpoints for workflow
            all_cps = await storage.list_checkpoints(workflow_name="integration-wf")
            assert len(all_cps) == 2

            # List checkpoint IDs
            ids = await storage.list_checkpoint_ids(workflow_name="integration-wf")
            assert "cp-int-1" in ids
            assert "cp-int-2" in ids

            # Get latest
            latest = await storage.get_latest(workflow_name="integration-wf")
            assert latest is not None
            assert latest.checkpoint_id == "cp-int-2"
            assert latest.state == {"step": 2}

            # Delete
            assert await storage.delete("cp-int-1") is True
            assert await storage.delete("cp-int-1") is False

            remaining = await storage.list_checkpoint_ids(workflow_name="integration-wf")
            assert remaining == ["cp-int-2"]

            # Cross-workflow isolation
            other_cp = _make_checkpoint(
                checkpoint_id="cp-other",
                workflow_name="other-wf",
            )
            await storage.save(other_cp)

            wf_cps = await storage.list_checkpoints(workflow_name="integration-wf")
            assert len(wf_cps) == 1
            assert wf_cps[0].checkpoint_id == "cp-int-2"
        finally:
            with suppress(Exception):
                await cosmos_client.delete_database(database_name)