mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add Cosmos DB NoSQL Checkpoint Storage for Python Workflows (#4916)
* Add CosmosCheckpointStorage for Python workflow checkpointing Add native Cosmos DB NoSQL support for workflow checkpoint storage in the Python agent-framework-azure-cosmos package, achieving parity with the existing .NET CosmosCheckpointStore. New files: - _checkpoint_storage.py: CosmosCheckpointStorage implementing the CheckpointStorage protocol with 6 methods (save, load, list_checkpoints, delete, get_latest, list_checkpoint_ids) - test_cosmos_checkpoint_storage.py: Unit and integration tests - workflow_checkpointing.py: Sample demonstrating Cosmos DB-backed workflow checkpoint/resume Auth support: - Managed identity / RBAC via Azure credential objects (DefaultAzureCredential, ManagedIdentityCredential, etc.) - Key-based auth via account key string or AZURE_COSMOS_KEY env var - Pre-created CosmosClient or ContainerProxy Key design decisions: - Partition key: /workflow_name for efficient per-workflow queries - Serialization: Reuses encode/decode_checkpoint_value for full Python object fidelity (hybrid JSON + pickle approach) - Container auto-creation via create_container_if_not_exists Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Adding cosmos checkpointer * Resolving comments * Fixing builds * Adding sample for history provider and checkpoint storage * Resolving comments * fixing builds * Resolving comments --------- Co-authored-by: Aayush Kataria <aayushkataria@Aayushs-MacBook-Pro-2.local> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Evan Mattson <35585003+moonbox3@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
a7a02c1abd
commit
30a2bc3dcb
@@ -14,7 +14,7 @@ The Azure Cosmos DB integration provides `CosmosHistoryProvider` for persistent
|
||||
|
||||
```python
|
||||
from azure.identity.aio import DefaultAzureCredential
|
||||
from agent_framework.azure import CosmosHistoryProvider
|
||||
from agent_framework_azure_cosmos import CosmosHistoryProvider
|
||||
|
||||
provider = CosmosHistoryProvider(
|
||||
endpoint="https://<account>.documents.azure.com:443/",
|
||||
@@ -35,13 +35,92 @@ Container naming behavior:
|
||||
- Container name is configured on the provider (`container_name` or `AZURE_COSMOS_CONTAINER_NAME`)
|
||||
- `session_id` is used as the Cosmos partition key for reads/writes
|
||||
|
||||
See the [conversation samples](../../samples/02-agents/conversations/) for runnable examples, including
|
||||
[`cosmos_history_provider.py`](../../samples/02-agents/conversations/cosmos_history_provider.py).
|
||||
See `samples/02-agents/conversations/cosmos_history_provider.py` for a runnable example.
|
||||
|
||||
## Import Paths
|
||||
## Cosmos DB Workflow Checkpoint Storage
|
||||
|
||||
`CosmosCheckpointStorage` implements the `CheckpointStorage` protocol, enabling
|
||||
durable workflow checkpointing backed by Azure Cosmos DB NoSQL. Workflows can be
|
||||
paused and resumed across process restarts by persisting checkpoint state in Cosmos DB.
|
||||
|
||||
### Basic Usage
|
||||
|
||||
#### Managed Identity / RBAC (recommended for production)
|
||||
|
||||
```python
|
||||
from agent_framework.azure import CosmosHistoryProvider
|
||||
# or directly:
|
||||
from agent_framework_azure_cosmos import CosmosHistoryProvider
|
||||
from azure.identity.aio import DefaultAzureCredential
|
||||
from agent_framework import WorkflowBuilder
|
||||
from agent_framework_azure_cosmos import CosmosCheckpointStorage
|
||||
|
||||
checkpoint_storage = CosmosCheckpointStorage(
|
||||
endpoint="https://<account>.documents.azure.com:443/",
|
||||
credential=DefaultAzureCredential(),
|
||||
database_name="agent-framework",
|
||||
container_name="workflow-checkpoints",
|
||||
)
|
||||
```
|
||||
|
||||
#### Account Key
|
||||
|
||||
```python
|
||||
from agent_framework_azure_cosmos import CosmosCheckpointStorage
|
||||
|
||||
checkpoint_storage = CosmosCheckpointStorage(
|
||||
endpoint="https://<account>.documents.azure.com:443/",
|
||||
credential="<your-account-key>",
|
||||
database_name="agent-framework",
|
||||
container_name="workflow-checkpoints",
|
||||
)
|
||||
```
|
||||
|
||||
#### Then use with a workflow
|
||||
|
||||
```python
|
||||
from agent_framework import WorkflowBuilder
|
||||
|
||||
# Build a workflow with checkpointing enabled
|
||||
workflow = WorkflowBuilder(
|
||||
start_executor=start,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
).build()
|
||||
|
||||
# Run the workflow — checkpoints are automatically saved after each superstep
|
||||
result = await workflow.run(message="input data")
|
||||
|
||||
# Resume from a checkpoint
|
||||
latest = await checkpoint_storage.get_latest(workflow_name=workflow.name)
|
||||
if latest:
|
||||
resumed = await workflow.run(checkpoint_id=latest.checkpoint_id)
|
||||
```
|
||||
|
||||
### Authentication Options
|
||||
|
||||
`CosmosCheckpointStorage` supports the same authentication modes as `CosmosHistoryProvider`:
|
||||
|
||||
- **Managed identity / RBAC** (recommended): Pass `DefaultAzureCredential()`,
|
||||
`ManagedIdentityCredential()`, or any Azure `TokenCredential`
|
||||
- **Account key**: Pass a key string via `credential` parameter
|
||||
- **Environment variables**: Set `AZURE_COSMOS_ENDPOINT`, `AZURE_COSMOS_DATABASE_NAME`,
|
||||
`AZURE_COSMOS_CONTAINER_NAME`, and `AZURE_COSMOS_KEY` (key not required when using
|
||||
Azure credentials)
|
||||
- **Pre-created client**: Pass an existing `CosmosClient` or `ContainerProxy`
|
||||
|
||||
### Database and Container Setup
|
||||
|
||||
The database and container are created automatically on first use (via
|
||||
`create_database_if_not_exists` and `create_container_if_not_exists`). The container
|
||||
uses `/workflow_name` as the partition key. You can also pre-create them in the Azure
|
||||
portal with this partition key configuration.
|
||||
|
||||
### Environment Variables
|
||||
|
||||
| Variable | Description |
|
||||
|---|---|
|
||||
| `AZURE_COSMOS_ENDPOINT` | Cosmos DB account endpoint |
|
||||
| `AZURE_COSMOS_DATABASE_NAME` | Database name |
|
||||
| `AZURE_COSMOS_CONTAINER_NAME` | Container name |
|
||||
| `AZURE_COSMOS_KEY` | Account key (optional if using Azure credentials) |
|
||||
|
||||
See `samples/03-workflows/checkpoint/cosmos_workflow_checkpointing.py` for a standalone example,
|
||||
or `samples/03-workflows/checkpoint/cosmos_workflow_checkpointing_foundry.py` for an end-to-end
|
||||
example with Azure AI Foundry agents.
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import importlib.metadata
|
||||
|
||||
from ._checkpoint_storage import CosmosCheckpointStorage
|
||||
from ._history_provider import CosmosHistoryProvider
|
||||
|
||||
try:
|
||||
@@ -10,6 +11,7 @@ except importlib.metadata.PackageNotFoundError:
|
||||
__version__ = "0.0.0" # Fallback for development mode
|
||||
|
||||
__all__ = [
|
||||
"CosmosCheckpointStorage",
|
||||
"CosmosHistoryProvider",
|
||||
"__version__",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,432 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Azure Cosmos DB checkpoint storage for workflow checkpointing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from agent_framework import AGENT_FRAMEWORK_USER_AGENT
|
||||
from agent_framework._settings import SecretString, load_settings
|
||||
from agent_framework._workflows._checkpoint import CheckpointID, WorkflowCheckpoint
|
||||
from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value
|
||||
from agent_framework.exceptions import WorkflowCheckpointException
|
||||
from azure.core.credentials import TokenCredential
|
||||
from azure.core.credentials_async import AsyncTokenCredential
|
||||
from azure.cosmos import PartitionKey
|
||||
from azure.cosmos.aio import ContainerProxy, CosmosClient
|
||||
from azure.cosmos.exceptions import CosmosResourceNotFoundError
|
||||
|
||||
# Both the sync and the async Azure token credential types are accepted
# wherever a credential object (as opposed to an account key string) is passed.
AzureCredentialTypes = TokenCredential | AsyncTokenCredential

# Module-level logger used by CosmosCheckpointStorage.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AzureCosmosCheckpointSettings(TypedDict, total=False):
    """Settings for CosmosCheckpointStorage resolved from args and environment.

    Every field is optional (``total=False``); which ones are actually
    required is decided per call via the ``required_fields`` argument that
    ``CosmosCheckpointStorage.__init__`` passes to ``load_settings``.
    """

    # Cosmos DB account endpoint (``AZURE_COSMOS_ENDPOINT``).
    endpoint: str | None
    # Cosmos DB database name (``AZURE_COSMOS_DATABASE_NAME``).
    database_name: str | None
    # Cosmos DB container name (``AZURE_COSMOS_CONTAINER_NAME``).
    container_name: str | None
    # Account key for key-based auth (``AZURE_COSMOS_KEY``).
    key: SecretString | None
|
||||
|
||||
|
||||
class CosmosCheckpointStorage:
    """Azure Cosmos DB-backed checkpoint storage for workflow checkpointing.

    Implements the ``CheckpointStorage`` protocol using Azure Cosmos DB NoSQL
    as the persistent backend. Checkpoints are stored as JSON documents with
    ``workflow_name`` as the partition key, enabling efficient per-workflow queries.

    This storage uses the same hybrid JSON + pickle encoding as
    ``FileCheckpointStorage``, allowing full Python object fidelity for
    complex workflow state while keeping the document structure human-readable.

    SECURITY WARNING: Checkpoints use pickle for data serialization. Only load
    checkpoints from trusted sources. Loading a malicious checkpoint can execute
    arbitrary code.

    The database and container are created automatically on first use
    if they do not already exist. The container uses partition key
    ``/workflow_name``.

    Example using managed identity / RBAC:

    .. code-block:: python

        from azure.identity.aio import DefaultAzureCredential
        from agent_framework_azure_cosmos import CosmosCheckpointStorage

        storage = CosmosCheckpointStorage(
            endpoint="https://my-account.documents.azure.com:443/",
            credential=DefaultAzureCredential(),
            database_name="agent-db",
            container_name="checkpoints",
        )

    Example using account key:

    .. code-block:: python

        storage = CosmosCheckpointStorage(
            endpoint="https://my-account.documents.azure.com:443/",
            credential="my-account-key",
            database_name="agent-db",
            container_name="checkpoints",
        )

    Then use with a workflow builder:

    .. code-block:: python

        workflow = WorkflowBuilder(
            start_executor=start,
            checkpoint_storage=storage,
        ).build()
    """

    def __init__(
        self,
        *,
        endpoint: str | None = None,
        database_name: str | None = None,
        container_name: str | None = None,
        credential: str | AzureCredentialTypes | None = None,
        cosmos_client: CosmosClient | None = None,
        container_client: ContainerProxy | None = None,
        env_file_path: str | None = None,
        env_file_encoding: str | None = None,
    ) -> None:
        """Initialize the Azure Cosmos DB checkpoint storage.

        Supports multiple authentication modes:

        - **Container client** (``container_client``): Use a pre-created
          Cosmos async container proxy. No client lifecycle is managed.
        - **Cosmos client** (``cosmos_client``): Use a pre-created Cosmos
          async client. The caller is responsible for closing it.
        - **Endpoint + credential**: Create a new Cosmos client. The storage
          owns the client and closes it on ``close()``.
        - **Environment variables**: Falls back to ``AZURE_COSMOS_ENDPOINT``,
          ``AZURE_COSMOS_DATABASE_NAME``, ``AZURE_COSMOS_CONTAINER_NAME``,
          and ``AZURE_COSMOS_KEY``.

        Args:
            endpoint: Cosmos DB account endpoint.
                Can be set via ``AZURE_COSMOS_ENDPOINT``.
            database_name: Cosmos DB database name.
                Can be set via ``AZURE_COSMOS_DATABASE_NAME``.
            container_name: Cosmos DB container name.
                Can be set via ``AZURE_COSMOS_CONTAINER_NAME``.
            credential: Credential to authenticate with Cosmos DB.
                For **managed identity / RBAC**, pass an Azure credential object
                such as ``DefaultAzureCredential()`` or
                ``ManagedIdentityCredential()``.
                For **key-based auth**, pass the account key as a string,
                or set ``AZURE_COSMOS_KEY`` in the environment.
            cosmos_client: Pre-created Cosmos async client.
            container_client: Pre-created Cosmos container client.
            env_file_path: Path to environment file for loading settings.
            env_file_encoding: Encoding of the environment file.
        """
        self._cosmos_client: CosmosClient | None = cosmos_client
        self._container_proxy: ContainerProxy | None = container_client
        self._owns_client = False

        if self._container_proxy is not None:
            # A ready-made container needs no settings resolution; the names are
            # kept only as informational attributes and default to "".
            self.database_name: str = database_name or ""
            self.container_name: str = container_name or ""
            return

        # The endpoint (and a key, when no credential was supplied) is only
        # needed when this storage has to build its own client.
        required_fields: list[str] = ["database_name", "container_name"]
        if cosmos_client is None:
            required_fields.append("endpoint")
            if credential is None:
                required_fields.append("key")

        settings = load_settings(
            AzureCosmosCheckpointSettings,
            env_prefix="AZURE_COSMOS_",
            required_fields=required_fields,
            endpoint=endpoint,
            database_name=database_name,
            container_name=container_name,
            key=credential if isinstance(credential, str) else None,
            env_file_path=env_file_path,
            env_file_encoding=env_file_encoding,
        )
        self.database_name = settings["database_name"]  # type: ignore[assignment]
        self.container_name = settings["container_name"]  # type: ignore[assignment]

        if self._cosmos_client is None:
            self._cosmos_client = CosmosClient(
                url=settings["endpoint"],  # type: ignore[arg-type]
                credential=credential or settings["key"].get_secret_value(),  # type: ignore[arg-type,union-attr]
                user_agent_suffix=AGENT_FRAMEWORK_USER_AGENT,
            )
            # Only a client created here is closed by close().
            self._owns_client = True

    async def save(self, checkpoint: WorkflowCheckpoint) -> CheckpointID:
        """Save a checkpoint to Cosmos DB and return its ID.

        The checkpoint is encoded to a JSON-compatible form (using pickle for
        non-JSON-native values) and stored as a Cosmos DB document with the
        ``workflow_name`` as the partition key.

        The document ``id`` is a composite of ``workflow_name`` and
        ``checkpoint_id`` to ensure global uniqueness across partitions.

        Args:
            checkpoint: The WorkflowCheckpoint object to save.

        Returns:
            The unique ID of the saved checkpoint.
        """
        await self._ensure_container_proxy()

        encoded = encode_checkpoint_value(checkpoint.to_dict())
        document: dict[str, Any] = {
            "id": self._make_document_id(checkpoint.workflow_name, checkpoint.checkpoint_id),
            "workflow_name": checkpoint.workflow_name,
            **encoded,
        }

        # Upsert so that saving the same checkpoint twice overwrites the document.
        await self._container_proxy.upsert_item(body=document)  # type: ignore[union-attr]
        logger.info("Saved checkpoint %s to Cosmos DB", checkpoint.checkpoint_id)
        return checkpoint.checkpoint_id

    async def load(self, checkpoint_id: CheckpointID) -> WorkflowCheckpoint:
        """Load a checkpoint from Cosmos DB by ID.

        The lookup is a query without a partition key because the workflow
        name is not known from the checkpoint ID alone.

        Args:
            checkpoint_id: The unique ID of the checkpoint to load.

        Returns:
            The WorkflowCheckpoint object corresponding to the given ID.

        Raises:
            WorkflowCheckpointException: If no checkpoint with the given ID exists,
                or if multiple checkpoints share the same ID across workflows.
        """
        await self._ensure_container_proxy()

        query = "SELECT * FROM c WHERE c.checkpoint_id = @checkpoint_id"
        parameters: list[dict[str, object]] = [
            {"name": "@checkpoint_id", "value": checkpoint_id},
        ]

        items = self._container_proxy.query_items(  # type: ignore[union-attr]
            query=query,
            parameters=parameters,
        )
        results: list[dict[str, Any]] = [item async for item in items]

        if not results:
            raise WorkflowCheckpointException(f"No checkpoint found with ID {checkpoint_id}")

        if len(results) > 1:
            workflow_names = [r.get("workflow_name", "unknown") for r in results]
            raise WorkflowCheckpointException(
                f"Multiple checkpoints found with ID {checkpoint_id} across workflows: "
                f"{workflow_names}. Use list_checkpoints(workflow_name=...) to query "
                f"by workflow instead."
            )

        return self._document_to_checkpoint(results[0])

    async def list_checkpoints(self, *, workflow_name: str) -> list[WorkflowCheckpoint]:
        """List checkpoint objects for a given workflow name.

        Documents that fail to decode are logged and skipped rather than
        aborting the whole listing.

        Args:
            workflow_name: The name of the workflow to list checkpoints for.

        Returns:
            A list of WorkflowCheckpoint objects for the specified workflow name,
            ordered by ascending ``timestamp``.
        """
        await self._ensure_container_proxy()

        query = "SELECT * FROM c WHERE c.workflow_name = @workflow_name ORDER BY c.timestamp ASC"
        parameters: list[dict[str, object]] = [
            {"name": "@workflow_name", "value": workflow_name},
        ]

        items = self._container_proxy.query_items(  # type: ignore[union-attr]
            query=query,
            parameters=parameters,
            partition_key=workflow_name,
        )

        checkpoints: list[WorkflowCheckpoint] = []
        async for item in items:
            try:
                checkpoints.append(self._document_to_checkpoint(item))
            except Exception as e:
                # Best effort: one undecodable document must not hide the rest.
                logger.warning("Failed to decode checkpoint document: %s", e)
        return checkpoints

    async def delete(self, checkpoint_id: CheckpointID) -> bool:
        """Delete a checkpoint from Cosmos DB by ID.

        The checkpoint ID alone does not identify a partition, so matching
        documents are first located with a query and then deleted one by one.
        Every document carrying the ID is removed, so that a later ``load`` of
        the same ID cannot find a leftover copy in another workflow's partition.

        Args:
            checkpoint_id: The unique ID of the checkpoint to delete.

        Returns:
            True if at least one checkpoint document was deleted, False if none
            was found.
        """
        await self._ensure_container_proxy()

        query = "SELECT c.id, c.workflow_name FROM c WHERE c.checkpoint_id = @checkpoint_id"
        parameters: list[dict[str, object]] = [
            {"name": "@checkpoint_id", "value": checkpoint_id},
        ]

        items = self._container_proxy.query_items(  # type: ignore[union-attr]
            query=query,
            parameters=parameters,
        )
        # Materialize the matches first so nothing is deleted while the query
        # is still being paged through.
        matches: list[dict[str, Any]] = [item async for item in items]

        deleted = False
        for match in matches:
            try:
                await self._container_proxy.delete_item(  # type: ignore[union-attr]
                    item=match["id"],
                    partition_key=match["workflow_name"],
                )
            except CosmosResourceNotFoundError:
                # The document vanished between the query and the delete;
                # keep going so any remaining matches are still removed.
                continue
            deleted = True

        if deleted:
            logger.info("Deleted checkpoint %s from Cosmos DB", checkpoint_id)
        return deleted

    async def get_latest(self, *, workflow_name: str) -> WorkflowCheckpoint | None:
        """Get the latest checkpoint for a given workflow name.

        "Latest" is the document with the greatest ``timestamp`` value within
        the workflow's partition.

        Args:
            workflow_name: The name of the workflow to get the latest checkpoint for.

        Returns:
            The latest WorkflowCheckpoint, or None if no checkpoints exist.
        """
        await self._ensure_container_proxy()

        query = (
            "SELECT * FROM c WHERE c.workflow_name = @workflow_name "
            "ORDER BY c.timestamp DESC OFFSET 0 LIMIT 1"
        )
        parameters: list[dict[str, object]] = [
            {"name": "@workflow_name", "value": workflow_name},
        ]

        items = self._container_proxy.query_items(  # type: ignore[union-attr]
            query=query,
            parameters=parameters,
            partition_key=workflow_name,
        )

        # The query is limited to one row, so the first item (if any) is the answer.
        async for item in items:
            checkpoint = self._document_to_checkpoint(item)
            logger.debug(
                "Latest checkpoint for workflow %s is %s",
                workflow_name,
                checkpoint.checkpoint_id,
            )
            return checkpoint

        return None

    async def list_checkpoint_ids(self, *, workflow_name: str) -> list[CheckpointID]:
        """List checkpoint IDs for a given workflow name.

        Args:
            workflow_name: The name of the workflow to list checkpoint IDs for.

        Returns:
            A list of checkpoint IDs for the specified workflow name, ordered by
            ascending ``timestamp``. Rows without a string ``checkpoint_id`` are
            ignored.
        """
        await self._ensure_container_proxy()

        query = (
            "SELECT c.checkpoint_id FROM c WHERE c.workflow_name = @workflow_name "
            "ORDER BY c.timestamp ASC"
        )
        parameters: list[dict[str, object]] = [
            {"name": "@workflow_name", "value": workflow_name},
        ]

        items = self._container_proxy.query_items(  # type: ignore[union-attr]
            query=query,
            parameters=parameters,
            partition_key=workflow_name,
        )

        checkpoint_ids: list[CheckpointID] = []
        async for item in items:
            cid = item.get("checkpoint_id")
            if isinstance(cid, str):
                checkpoint_ids.append(cid)
        return checkpoint_ids

    async def close(self) -> None:
        """Close the underlying Cosmos client when this storage owns it."""
        if self._owns_client and self._cosmos_client is not None:
            await self._cosmos_client.close()

    async def __aenter__(self) -> CosmosCheckpointStorage:
        """Async context manager entry."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        """Async context manager exit.

        Closes the storage. A failure while closing is re-raised only when the
        ``async with`` body itself completed without an exception; otherwise it
        is logged so that it does not mask the exception already propagating.
        """
        try:
            await self.close()
        except Exception:
            if exc_type is None:
                raise
            logger.warning(
                "Failed to close Cosmos client while handling another exception",
                exc_info=True,
            )

    async def _ensure_container_proxy(self) -> None:
        """Get or create the Cosmos DB database and container for storing checkpoints.

        Raises:
            RuntimeError: If neither a container proxy nor a Cosmos client is available.
        """
        if self._container_proxy is not None:
            return
        if self._cosmos_client is None:
            raise RuntimeError("Cosmos client is not initialized.")

        database = await self._cosmos_client.create_database_if_not_exists(id=self.database_name)
        self._container_proxy = await database.create_container_if_not_exists(
            id=self.container_name,
            partition_key=PartitionKey(path="/workflow_name"),
        )

    @staticmethod
    def _document_to_checkpoint(document: dict[str, Any]) -> WorkflowCheckpoint:
        """Convert a Cosmos DB document back to a WorkflowCheckpoint.

        Strips Cosmos DB system properties (``_rid``, ``_self``, ``_etag``,
        ``_attachments``, ``_ts``) and the composite ``id`` before decoding.

        NOTE: decoding may unpickle stored values (see the class-level security
        warning); only use this on documents from a trusted container.
        """
        # Remove Cosmos DB system properties and the composite 'id' field
        # (checkpoints use 'checkpoint_id', not 'id')
        cosmos_keys = {"id", "_rid", "_self", "_etag", "_attachments", "_ts"}
        cleaned = {k: v for k, v in document.items() if k not in cosmos_keys}

        decoded = decode_checkpoint_value(cleaned)
        return WorkflowCheckpoint.from_dict(decoded)

    @staticmethod
    def _make_document_id(workflow_name: str, checkpoint_id: str) -> str:
        """Create a composite Cosmos DB document ID.

        Combines ``workflow_name`` and ``checkpoint_id`` to ensure global
        uniqueness across partitions.
        """
        return f"{workflow_name}_{checkpoint_id}"
|
||||
@@ -0,0 +1,599 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import suppress
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from agent_framework._workflows._checkpoint import WorkflowCheckpoint
|
||||
from agent_framework._workflows._checkpoint_encoding import encode_checkpoint_value
|
||||
from agent_framework.exceptions import SettingNotFoundError, WorkflowCheckpointException
|
||||
from azure.cosmos.aio import CosmosClient
|
||||
from azure.cosmos.exceptions import CosmosResourceNotFoundError
|
||||
|
||||
import agent_framework_azure_cosmos._checkpoint_storage as checkpoint_storage_module
|
||||
from agent_framework_azure_cosmos._checkpoint_storage import CosmosCheckpointStorage
|
||||
|
||||
# Integration tests talk to a real Cosmos DB account, so they are skipped
# unless every connection setting is present and non-empty in the environment.
skip_if_cosmos_integration_tests_disabled = pytest.mark.skipif(
    any(
        os.getenv(name, "") == ""
        for name in (
            "AZURE_COSMOS_ENDPOINT",
            "AZURE_COSMOS_KEY",
            "AZURE_COSMOS_DATABASE_NAME",
            "AZURE_COSMOS_CONTAINER_NAME",
        )
    ),
    reason=(
        "AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_KEY, AZURE_COSMOS_DATABASE_NAME, and "
        "AZURE_COSMOS_CONTAINER_NAME are required for Cosmos integration tests."
    ),
)
|
||||
|
||||
|
||||
def _to_async_iter(items: list[Any]) -> AsyncIterator[Any]:
    """Wrap a list in a single-use async iterator, mimicking a Cosmos query result."""

    async def _iterator() -> AsyncIterator[Any]:
        for item in items:
            yield item

    return _iterator()
|
||||
|
||||
|
||||
def _make_checkpoint(
    workflow_name: str = "test-workflow",
    checkpoint_id: str | None = None,
    previous_checkpoint_id: str | None = None,
    timestamp: str | None = None,
) -> WorkflowCheckpoint:
    """Create a minimal WorkflowCheckpoint for testing.

    A random UUID is used when ``checkpoint_id`` is omitted, and a fixed
    timestamp when ``timestamp`` is omitted.
    """
    return WorkflowCheckpoint(
        workflow_name=workflow_name,
        graph_signature_hash="abc123",
        checkpoint_id=checkpoint_id or str(uuid.uuid4()),
        previous_checkpoint_id=previous_checkpoint_id,
        timestamp=timestamp or "2025-01-01T00:00:00+00:00",
        state={"counter": 42},
        iteration_count=1,
    )
|
||||
|
||||
|
||||
def _checkpoint_to_cosmos_document(checkpoint: WorkflowCheckpoint) -> dict[str, Any]:
    """Simulate what a Cosmos DB document looks like after save.

    The result carries the composite ``id``, the partition key field, the
    encoded checkpoint payload, and sample Cosmos system properties.
    """
    encoded = encode_checkpoint_value(checkpoint.to_dict())
    doc: dict[str, Any] = {
        "id": f"{checkpoint.workflow_name}_{checkpoint.checkpoint_id}",
        "workflow_name": checkpoint.workflow_name,
        **encoded,
        # Cosmos system properties
        "_rid": "abc",
        "_self": "dbs/abc/colls/def/docs/ghi",
        "_etag": '"00000000-0000-0000-0000-000000000000"',
        "_attachments": "attachments/",
        "_ts": 1700000000,
    }
    return doc
|
||||
|
||||
|
||||
@pytest.fixture
def mock_container() -> MagicMock:
    """Provide a fake Cosmos container whose query yields nothing by default."""
    container = MagicMock()
    # query_items is a plain (non-async) call that returns an async iterable.
    container.query_items = MagicMock(return_value=_to_async_iter([]))
    container.upsert_item = AsyncMock(return_value={})
    container.delete_item = AsyncMock(return_value={})
    return container
|
||||
|
||||
|
||||
@pytest.fixture
def mock_cosmos_client(mock_container: MagicMock) -> MagicMock:
    """Provide a fake Cosmos client that resolves to ``mock_container``."""
    database_client = MagicMock()
    database_client.create_container_if_not_exists = AsyncMock(return_value=mock_container)

    client = MagicMock()
    client.create_database_if_not_exists = AsyncMock(return_value=database_client)
    client.close = AsyncMock()
    return client
|
||||
|
||||
|
||||
# --- Tests for initialization ---
|
||||
|
||||
|
||||
async def test_init_uses_provided_container_client(mock_container: MagicMock) -> None:
    """A pre-created container skips settings resolution; names default to empty."""
    storage = CosmosCheckpointStorage(container_client=mock_container)

    assert storage.database_name == ""
    assert storage.container_name == ""
|
||||
|
||||
|
||||
async def test_init_uses_provided_cosmos_client(mock_cosmos_client: MagicMock) -> None:
    """A pre-created client only needs the database and container names."""
    storage = CosmosCheckpointStorage(
        cosmos_client=mock_cosmos_client,
        database_name="db1",
        container_name="checkpoints",
    )

    assert storage.database_name == "db1"
    assert storage.container_name == "checkpoints"
|
||||
|
||||
|
||||
async def test_init_missing_required_settings_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """With no arguments and no environment settings, construction fails."""
    for name in (
        "AZURE_COSMOS_ENDPOINT",
        "AZURE_COSMOS_DATABASE_NAME",
        "AZURE_COSMOS_CONTAINER_NAME",
        "AZURE_COSMOS_KEY",
    ):
        monkeypatch.delenv(name, raising=False)

    with pytest.raises(SettingNotFoundError, match="database_name"):
        CosmosCheckpointStorage()
|
||||
|
||||
|
||||
async def test_init_constructs_client_with_credential(
    monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock
) -> None:
    """An explicit credential object is passed straight through to ``CosmosClient``."""
    mock_factory = MagicMock(return_value=mock_cosmos_client)
    monkeypatch.setattr(checkpoint_storage_module, "CosmosClient", mock_factory)
    # Make sure no account key from the environment can take part.
    monkeypatch.delenv("AZURE_COSMOS_KEY", raising=False)

    credential = MagicMock()  # stands in for DefaultAzureCredential()

    CosmosCheckpointStorage(
        endpoint="https://account.documents.azure.com:443/",
        credential=credential,
        database_name="db1",
        container_name="checkpoints",
    )

    mock_factory.assert_called_once()
    kwargs = mock_factory.call_args.kwargs
    assert kwargs["url"] == "https://account.documents.azure.com:443/"
    assert kwargs["credential"] is credential
|
||||
|
||||
|
||||
async def test_init_creates_database_and_container(mock_cosmos_client: MagicMock) -> None:
    """The database and container are created lazily on the first storage call."""
    storage = CosmosCheckpointStorage(
        cosmos_client=mock_cosmos_client,
        database_name="db1",
        container_name="custom-checkpoints",
    )

    # Any operation triggers the lazy database/container setup.
    await storage.list_checkpoint_ids(workflow_name="wf")

    mock_cosmos_client.create_database_if_not_exists.assert_awaited_once_with(id="db1")
    database_client = mock_cosmos_client.create_database_if_not_exists.return_value
    assert database_client.create_container_if_not_exists.await_count == 1
    kwargs = database_client.create_container_if_not_exists.await_args.kwargs
    assert kwargs["id"] == "custom-checkpoints"
|
||||
|
||||
|
||||
# --- Tests for save ---
|
||||
|
||||
|
||||
async def test_save_upserts_document(mock_container: MagicMock) -> None:
    """Saving upserts one document with the composite id and encoded payload."""
    storage = CosmosCheckpointStorage(container_client=mock_container)
    checkpoint = _make_checkpoint()

    result = await storage.save(checkpoint)

    assert result == checkpoint.checkpoint_id
    mock_container.upsert_item.assert_awaited_once()
    document = mock_container.upsert_item.await_args.kwargs["body"]
    assert document["id"] == f"test-workflow_{checkpoint.checkpoint_id}"
    assert document["workflow_name"] == "test-workflow"
    assert document["graph_signature_hash"] == "abc123"
    assert document["state"]["counter"] == 42
|
||||
|
||||
|
||||
async def test_save_returns_checkpoint_id(mock_container: MagicMock) -> None:
    """``save`` returns the checkpoint's own ID, not the composite document id."""
    storage = CosmosCheckpointStorage(container_client=mock_container)
    checkpoint = _make_checkpoint(checkpoint_id="cp-123")

    result = await storage.save(checkpoint)

    assert result == "cp-123"
|
||||
|
||||
|
||||
# --- Tests for load ---
|
||||
|
||||
|
||||
async def test_load_returns_checkpoint(mock_container: MagicMock) -> None:
    """A single matching document is decoded back into a WorkflowCheckpoint."""
    checkpoint = _make_checkpoint(checkpoint_id="cp-load")
    doc = _checkpoint_to_cosmos_document(checkpoint)
    mock_container.query_items.return_value = _to_async_iter([doc])

    storage = CosmosCheckpointStorage(container_client=mock_container)
    loaded = await storage.load("cp-load")

    assert loaded.checkpoint_id == "cp-load"
    assert loaded.workflow_name == "test-workflow"
    assert loaded.graph_signature_hash == "abc123"
    assert loaded.state["counter"] == 42
|
||||
|
||||
|
||||
async def test_load_nonexistent_raises(mock_container: MagicMock) -> None:
    """load() raises WorkflowCheckpointException when the query yields no documents."""
    no_documents = _to_async_iter([])
    mock_container.query_items.return_value = no_documents

    store = CosmosCheckpointStorage(container_client=mock_container)

    with pytest.raises(WorkflowCheckpointException, match="No checkpoint found"):
        await store.load("nonexistent-id")
|
||||
|
||||
|
||||
async def test_load_queries_without_partition_key(mock_container: MagicMock) -> None:
    """load() issues its query without passing a ``partition_key`` keyword."""
    mock_container.query_items.return_value = _to_async_iter([])
    store = CosmosCheckpointStorage(container_client=mock_container)

    # The empty result makes load() raise; only the query arguments matter here.
    with suppress(WorkflowCheckpointException):
        await store.load("cp-id")

    assert "partition_key" not in mock_container.query_items.call_args.kwargs
|
||||
|
||||
|
||||
async def test_load_multiple_workflows_same_checkpoint_id_raises(mock_container: MagicMock) -> None:
    """load() rejects a checkpoint ID that matches documents from two workflows."""
    colliding_docs = [
        _checkpoint_to_cosmos_document(_make_checkpoint(checkpoint_id="shared-id", workflow_name=workflow))
        for workflow in ("workflow-a", "workflow-b")
    ]
    mock_container.query_items.return_value = _to_async_iter(colliding_docs)

    store = CosmosCheckpointStorage(container_client=mock_container)

    with pytest.raises(WorkflowCheckpointException, match="Multiple checkpoints found"):
        await store.load("shared-id")
|
||||
|
||||
|
||||
# --- Tests for list_checkpoints ---
|
||||
|
||||
|
||||
async def test_list_checkpoints_returns_checkpoints_for_workflow(mock_container: MagicMock) -> None:
    """list_checkpoints() returns one checkpoint per queried document, in query order."""
    first = _make_checkpoint(checkpoint_id="cp-1", timestamp="2025-01-01T00:00:00+00:00")
    second = _make_checkpoint(checkpoint_id="cp-2", timestamp="2025-01-02T00:00:00+00:00")
    queried_docs = [_checkpoint_to_cosmos_document(cp) for cp in (first, second)]
    mock_container.query_items.return_value = _to_async_iter(queried_docs)

    store = CosmosCheckpointStorage(container_client=mock_container)
    listed = await store.list_checkpoints(workflow_name="test-workflow")

    assert len(listed) == 2
    assert [cp.checkpoint_id for cp in listed] == ["cp-1", "cp-2"]
|
||||
|
||||
|
||||
async def test_list_checkpoints_uses_partition_key(mock_container: MagicMock) -> None:
    """list_checkpoints() passes the workflow name as the query's ``partition_key``."""
    mock_container.query_items.return_value = _to_async_iter([])
    store = CosmosCheckpointStorage(container_client=mock_container)

    await store.list_checkpoints(workflow_name="my-workflow")

    assert mock_container.query_items.call_args.kwargs["partition_key"] == "my-workflow"
|
||||
|
||||
|
||||
async def test_list_checkpoints_empty_returns_empty(mock_container: MagicMock) -> None:
    """list_checkpoints() yields an empty list when the query returns nothing."""
    mock_container.query_items.return_value = _to_async_iter([])
    store = CosmosCheckpointStorage(container_client=mock_container)

    listed = await store.list_checkpoints(workflow_name="test-workflow")

    assert listed == []
|
||||
|
||||
|
||||
async def test_list_checkpoints_skips_malformed_documents(mock_container: MagicMock) -> None:
    """list_checkpoints() drops a document that is not a checkpoint and keeps the valid one."""
    malformed_doc = {"id": "bad_doc", "workflow_name": "test-workflow", "not_a_checkpoint": True}
    valid_doc = _checkpoint_to_cosmos_document(_make_checkpoint(checkpoint_id="cp-valid"))
    # The malformed document comes first so a failure on it would hide the valid one.
    mock_container.query_items.return_value = _to_async_iter([malformed_doc, valid_doc])

    store = CosmosCheckpointStorage(container_client=mock_container)
    listed = await store.list_checkpoints(workflow_name="test-workflow")

    assert len(listed) == 1
    assert listed[0].checkpoint_id == "cp-valid"
|
||||
|
||||
|
||||
# --- Tests for delete ---
|
||||
|
||||
|
||||
async def test_delete_existing_returns_true(mock_container: MagicMock) -> None:
    """delete() returns True and deletes the matched document by id and partition key."""
    matched_doc = {"id": "test-workflow_cp-del", "workflow_name": "test-workflow"}
    mock_container.query_items.return_value = _to_async_iter([matched_doc])

    store = CosmosCheckpointStorage(container_client=mock_container)
    was_deleted = await store.delete("cp-del")

    assert was_deleted is True
    mock_container.delete_item.assert_awaited_once_with(item="test-workflow_cp-del", partition_key="test-workflow")
|
||||
|
||||
|
||||
async def test_delete_nonexistent_returns_false(mock_container: MagicMock) -> None:
    """delete() returns False and never calls delete_item when the lookup finds nothing."""
    mock_container.query_items.return_value = _to_async_iter([])
    store = CosmosCheckpointStorage(container_client=mock_container)

    was_deleted = await store.delete("nonexistent")

    assert was_deleted is False
    mock_container.delete_item.assert_not_awaited()
|
||||
|
||||
|
||||
async def test_delete_cosmos_not_found_returns_false(mock_container: MagicMock) -> None:
    """delete() returns False when delete_item raises CosmosResourceNotFoundError."""
    matched_doc = {"id": "test-workflow_cp-del", "workflow_name": "test-workflow"}
    mock_container.query_items.return_value = _to_async_iter([matched_doc])
    # The lookup finds the document, but the delete call itself reports it missing.
    mock_container.delete_item = AsyncMock(side_effect=CosmosResourceNotFoundError)

    store = CosmosCheckpointStorage(container_client=mock_container)
    was_deleted = await store.delete("cp-del")

    assert was_deleted is False
|
||||
|
||||
|
||||
# --- Tests for get_latest ---
|
||||
|
||||
|
||||
async def test_get_latest_returns_latest_checkpoint(mock_container: MagicMock) -> None:
    """get_latest() decodes and returns the document produced by the query."""
    newest = _make_checkpoint(checkpoint_id="cp-latest", timestamp="2025-06-01T00:00:00+00:00")
    newest_doc = _checkpoint_to_cosmos_document(newest)
    mock_container.query_items.return_value = _to_async_iter([newest_doc])

    store = CosmosCheckpointStorage(container_client=mock_container)
    latest = await store.get_latest(workflow_name="test-workflow")

    assert latest is not None
    assert latest.checkpoint_id == "cp-latest"
|
||||
|
||||
|
||||
async def test_get_latest_returns_none_when_empty(mock_container: MagicMock) -> None:
    """get_latest() returns None when the query yields no documents."""
    mock_container.query_items.return_value = _to_async_iter([])
    store = CosmosCheckpointStorage(container_client=mock_container)

    latest = await store.get_latest(workflow_name="test-workflow")

    assert latest is None
|
||||
|
||||
|
||||
async def test_get_latest_uses_order_by_desc_with_limit(mock_container: MagicMock) -> None:
    """get_latest() queries newest-first by timestamp and limits the result to one row."""
    mock_container.query_items.return_value = _to_async_iter([])
    store = CosmosCheckpointStorage(container_client=mock_container)

    await store.get_latest(workflow_name="test-workflow")

    query_text = mock_container.query_items.call_args.kwargs["query"]
    for required_clause in ("ORDER BY c.timestamp DESC", "OFFSET 0 LIMIT 1"):
        assert required_clause in query_text
|
||||
|
||||
|
||||
# --- Tests for list_checkpoint_ids ---
|
||||
|
||||
|
||||
async def test_list_checkpoint_ids_returns_ids(mock_container: MagicMock) -> None:
    """list_checkpoint_ids() extracts ``checkpoint_id`` from each queried row, in order."""
    rows = [{"checkpoint_id": "cp-1"}, {"checkpoint_id": "cp-2"}]
    mock_container.query_items.return_value = _to_async_iter(rows)

    store = CosmosCheckpointStorage(container_client=mock_container)
    checkpoint_ids = await store.list_checkpoint_ids(workflow_name="test-workflow")

    assert checkpoint_ids == ["cp-1", "cp-2"]
|
||||
|
||||
|
||||
async def test_list_checkpoint_ids_empty_returns_empty(mock_container: MagicMock) -> None:
    """list_checkpoint_ids() yields an empty list when the query returns nothing."""
    mock_container.query_items.return_value = _to_async_iter([])
    store = CosmosCheckpointStorage(container_client=mock_container)

    checkpoint_ids = await store.list_checkpoint_ids(workflow_name="test-workflow")

    assert checkpoint_ids == []
|
||||
|
||||
|
||||
# --- Tests for close and context manager ---
|
||||
|
||||
|
||||
async def test_close_closes_owned_client(monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock) -> None:
    """close() closes a client the storage built itself from endpoint and credential."""
    # Swap the module's CosmosClient for a factory that returns the mock client.
    client_factory = MagicMock(return_value=mock_cosmos_client)
    monkeypatch.setattr(checkpoint_storage_module, "CosmosClient", client_factory)

    store = CosmosCheckpointStorage(
        endpoint="https://account.documents.azure.com:443/",
        credential="key-123",
        database_name="db1",
        container_name="checkpoints",
    )
    await store.close()

    mock_cosmos_client.close.assert_awaited_once()
|
||||
|
||||
|
||||
async def test_close_does_not_close_external_client(mock_cosmos_client: MagicMock) -> None:
    """close() leaves a caller-supplied client open."""
    store = CosmosCheckpointStorage(
        cosmos_client=mock_cosmos_client, database_name="db1", container_name="checkpoints"
    )

    await store.close()

    mock_cosmos_client.close.assert_not_awaited()
|
||||
|
||||
|
||||
async def test_context_manager_closes_owned_client(
    monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock
) -> None:
    """Leaving the ``async with`` block closes a client the storage built itself."""
    # Swap the module's CosmosClient for a factory that returns the mock client.
    client_factory = MagicMock(return_value=mock_cosmos_client)
    monkeypatch.setattr(checkpoint_storage_module, "CosmosClient", client_factory)

    store = CosmosCheckpointStorage(
        endpoint="https://account.documents.azure.com:443/",
        credential="key-123",
        database_name="db1",
        container_name="checkpoints",
    )
    async with store as storage:
        assert storage is not None

    mock_cosmos_client.close.assert_awaited_once()
|
||||
|
||||
|
||||
async def test_context_manager_preserves_original_exception(mock_container: MagicMock) -> None:
    """An error raised inside the block wins over a failure raised by close()."""
    storage = CosmosCheckpointStorage(container_client=mock_container)
    failing_close = AsyncMock(side_effect=RuntimeError("close failed"))

    with patch.object(storage, "close", failing_close), pytest.raises(ValueError, match="inner error"):
        async with storage:
            raise ValueError("inner error")
|
||||
|
||||
|
||||
async def test_context_manager_reraises_close_error(mock_container: MagicMock) -> None:
    """A close() failure propagates when the block body itself raised nothing."""
    storage = CosmosCheckpointStorage(container_client=mock_container)
    failing_close = AsyncMock(side_effect=RuntimeError("close failed"))

    with patch.object(storage, "close", failing_close), pytest.raises(RuntimeError, match="close failed"):
        async with storage:
            # Deliberately empty: with no inner exception, the close error must surface.
            pass
|
||||
|
||||
|
||||
# --- Tests for save/load round-trip ---
|
||||
|
||||
|
||||
async def test_round_trip_preserves_data(mock_container: MagicMock) -> None:
    """A checkpoint saved and then loaded from the captured document keeps all its fields."""
    original = _make_checkpoint(checkpoint_id="cp-roundtrip", previous_checkpoint_id="cp-parent")
    original.state = {"key": "value", "nested": {"a": 1}}
    original.metadata = {"superstep": 3}
    original.iteration_count = 5

    captured: dict[str, Any] = {}

    async def _record_upsert(body: dict[str, Any]) -> dict[str, Any]:
        # Keep a copy of what save() wrote so it can be fed back to load().
        captured.update(body)
        return body

    mock_container.upsert_item = AsyncMock(side_effect=_record_upsert)

    store = CosmosCheckpointStorage(container_client=mock_container)
    await store.save(original)

    # Feed the saved document back with extra underscore-prefixed properties
    # attached, so load() has to cope with fields that save() never wrote.
    extra_properties = {
        "_rid": "abc",
        "_self": "dbs/abc/colls/def/docs/ghi",
        "_etag": '"etag"',
        "_attachments": "attachments/",
        "_ts": 1700000000,
    }
    mock_container.query_items.return_value = _to_async_iter([{**captured, **extra_properties}])

    restored = await store.load("cp-roundtrip")

    assert restored.checkpoint_id == original.checkpoint_id
    assert restored.workflow_name == original.workflow_name
    assert restored.graph_signature_hash == original.graph_signature_hash
    assert restored.previous_checkpoint_id == "cp-parent"
    assert restored.state == {"key": "value", "nested": {"a": 1}}
    assert restored.metadata == {"superstep": 3}
    assert restored.iteration_count == 5
    assert restored.version == "1.0"
|
||||
|
||||
|
||||
# --- Integration test ---
|
||||
|
||||
|
||||
@pytest.mark.integration
@skip_if_cosmos_integration_tests_disabled
async def test_cosmos_checkpoint_storage_roundtrip_with_emulator() -> None:
    """Exercise save, load, list, get_latest and delete against a real Cosmos endpoint.

    Connection settings come from the ``AZURE_COSMOS_*`` environment variables.
    A uniquely named database is created for the run and deleted afterwards.
    """
    run_suffix = uuid.uuid4().hex[:8]
    endpoint = os.getenv("AZURE_COSMOS_ENDPOINT", "")
    key = os.getenv("AZURE_COSMOS_KEY", "")
    database_name = f"{os.getenv('AZURE_COSMOS_DATABASE_NAME', '')}-cp-{run_suffix}"
    container_name = f"{os.getenv('AZURE_COSMOS_CONTAINER_NAME', '')}-cp-{run_suffix}"

    async with CosmosClient(url=endpoint, credential=key) as cosmos_client:
        await cosmos_client.create_database_if_not_exists(id=database_name)

        storage = CosmosCheckpointStorage(
            cosmos_client=cosmos_client,
            database_name=database_name,
            container_name=container_name,
        )

        try:
            # Two checkpoints for one workflow; the second is newer and chained to the first.
            first = _make_checkpoint(
                checkpoint_id="cp-int-1",
                workflow_name="integration-wf",
                timestamp="2025-01-01T00:00:00+00:00",
            )
            second = _make_checkpoint(
                checkpoint_id="cp-int-2",
                workflow_name="integration-wf",
                previous_checkpoint_id="cp-int-1",
                timestamp="2025-01-02T00:00:00+00:00",
            )
            second.state = {"step": 2}

            await storage.save(first)
            await storage.save(second)

            # Load by ID.
            fetched = await storage.load("cp-int-1")
            assert fetched.checkpoint_id == "cp-int-1"
            assert fetched.workflow_name == "integration-wf"

            # List full checkpoints for the workflow.
            listed = await storage.list_checkpoints(workflow_name="integration-wf")
            assert len(listed) == 2

            # List only the checkpoint IDs.
            listed_ids = await storage.list_checkpoint_ids(workflow_name="integration-wf")
            for expected_id in ("cp-int-1", "cp-int-2"):
                assert expected_id in listed_ids

            # The newest checkpoint is the one with the later timestamp.
            newest = await storage.get_latest(workflow_name="integration-wf")
            assert newest is not None
            assert newest.checkpoint_id == "cp-int-2"
            assert newest.state == {"step": 2}

            # Deleting succeeds once; a second attempt reports nothing to delete.
            first_attempt = await storage.delete("cp-int-1")
            assert first_attempt is True
            second_attempt = await storage.delete("cp-int-1")
            assert second_attempt is False

            ids_after_delete = await storage.list_checkpoint_ids(workflow_name="integration-wf")
            assert ids_after_delete == ["cp-int-2"]

            # A checkpoint saved under another workflow must not show up in this listing.
            await storage.save(_make_checkpoint(checkpoint_id="cp-other", workflow_name="other-wf"))
            isolated = await storage.list_checkpoints(workflow_name="integration-wf")
            assert len(isolated) == 1
            assert isolated[0].checkpoint_id == "cp-int-2"

        finally:
            # Best-effort cleanup: a failed delete must not mask the test outcome.
            with suppress(Exception):
                await cosmos_client.delete_database(database_name)
|
||||
@@ -9,6 +9,9 @@ These samples demonstrate different approaches to managing conversation history
|
||||
| [`suspend_resume_session.py`](suspend_resume_session.py) | Suspend and resume conversation sessions, comparing service-managed sessions (Azure AI Foundry) with in-memory sessions (OpenAI). |
|
||||
| [`custom_history_provider.py`](custom_history_provider.py) | Implement a custom history provider by extending `HistoryProvider`, enabling conversation persistence in your preferred storage backend. |
|
||||
| [`cosmos_history_provider.py`](cosmos_history_provider.py) | Use Azure Cosmos DB as a history provider for durable conversation storage with `CosmosHistoryProvider`. |
|
||||
| [`cosmos_history_provider_conversation_persistence.py`](cosmos_history_provider_conversation_persistence.py) | Persist and resume conversations across application restarts using `CosmosHistoryProvider` — serialize session state, restore it, and continue with full Cosmos DB history. |
|
||||
| [`cosmos_history_provider_messages.py`](cosmos_history_provider_messages.py) | Direct message history operations — retrieve stored messages as a transcript, clear session history, and verify data deletion. |
|
||||
| [`cosmos_history_provider_sessions.py`](cosmos_history_provider_sessions.py) | Multi-session and multi-tenant management — per-tenant session isolation, `list_sessions()` to enumerate, switch between sessions, and resume specific conversations. |
|
||||
| [`redis_history_provider.py`](redis_history_provider.py) | Use Redis as a history provider for persistent conversation history storage across sessions. |
|
||||
|
||||
## Prerequisites
|
||||
@@ -22,7 +25,7 @@ These samples demonstrate different approaches to managing conversation history
|
||||
**For `custom_history_provider.py`:**
|
||||
- `OPENAI_API_KEY`: Your OpenAI API key
|
||||
|
||||
**For `cosmos_history_provider.py`:**
|
||||
**For Cosmos DB samples (`cosmos_history_provider*.py`):**
|
||||
- `FOUNDRY_PROJECT_ENDPOINT`: Your Azure AI Foundry project endpoint
|
||||
- `FOUNDRY_MODEL`: The Foundry model deployment name
|
||||
- `AZURE_COSMOS_ENDPOINT`: Your Azure Cosmos DB account endpoint
|
||||
|
||||
+165
@@ -0,0 +1,165 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
# ruff: noqa: T201
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from agent_framework import Agent, AgentSession
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from agent_framework_azure_cosmos import CosmosHistoryProvider
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file.
|
||||
load_dotenv()
|
||||
|
||||
"""
|
||||
This sample demonstrates persisting and resuming conversations across application
|
||||
restarts using CosmosHistoryProvider as the persistent backend.
|
||||
|
||||
Key components:
|
||||
- Phase 1: Run a conversation and serialize the session with session.to_dict()
|
||||
- Phase 2: Simulate an app restart — create new provider and agent instances,
|
||||
restore the session with AgentSession.from_dict(), and continue the conversation
|
||||
- Cosmos DB reloads the full message history, so the agent remembers everything
|
||||
|
||||
Environment variables:
|
||||
FOUNDRY_PROJECT_ENDPOINT
|
||||
FOUNDRY_MODEL
|
||||
AZURE_COSMOS_ENDPOINT
|
||||
AZURE_COSMOS_DATABASE_NAME
|
||||
AZURE_COSMOS_CONTAINER_NAME
|
||||
Optional:
|
||||
AZURE_COSMOS_KEY
|
||||
"""
|
||||
|
||||
|
||||
async def main() -> None:
    """Run a two-phase conversation that is resumed after a simulated restart.

    Phase 1 holds a short conversation and serializes the session. Phase 2
    builds fresh provider and agent instances, restores the session from the
    serialized form, asks a follow-up question, and prints the messages the
    history provider returns for that session.
    """
    project_endpoint = os.getenv("FOUNDRY_PROJECT_ENDPOINT")
    model = os.getenv("FOUNDRY_MODEL")
    cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT")
    cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME")
    cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME")
    # Optional: when unset, the Azure CLI credential is used for Cosmos DB too.
    cosmos_key = os.getenv("AZURE_COSMOS_KEY")

    if not (project_endpoint and model and cosmos_endpoint and cosmos_database_name and cosmos_container_name):
        print(
            "Please set FOUNDRY_PROJECT_ENDPOINT, FOUNDRY_MODEL, "
            "AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, and AZURE_COSMOS_CONTAINER_NAME."
        )
        return

    # ── Phase 1: Initial conversation ──

    print("=== Phase 1: Initial conversation ===\n")

    async with (
        AzureCliCredential() as credential,
        CosmosHistoryProvider(
            endpoint=cosmos_endpoint,
            database_name=cosmos_database_name,
            container_name=cosmos_container_name,
            credential=cosmos_key or credential,
        ) as history_provider,
        Agent(
            client=FoundryChatClient(
                project_endpoint=project_endpoint,
                model=model,
                credential=credential,
            ),
            name="PersistentAgent",
            instructions="You are a helpful assistant that remembers prior turns.",
            context_providers=[history_provider],
            default_options={"store": False},
        ) as agent,
    ):
        session = agent.create_session()

        opening_turns = (
            "My name is Ada. I'm building a distributed database in Rust.",
            "The hardest part is the consensus algorithm.",
        )
        for user_text in opening_turns:
            reply = await agent.run(user_text, session=session)
            print(f"User: {user_text}")
            print(f"Assistant: {reply.text}\n")

        # Capture the session so it can be rebuilt after the "restart".
        serialized_session = session.to_dict()
        print(f"Session serialized. Session ID: {session.session_id}")

    # ── Phase 2: Simulate app restart ──

    print("\n=== Phase 2: Resuming after 'restart' ===\n")

    # Everything below is built from scratch; only serialized_session carries over.
    async with (
        AzureCliCredential() as credential,
        CosmosHistoryProvider(
            endpoint=cosmos_endpoint,
            database_name=cosmos_database_name,
            container_name=cosmos_container_name,
            credential=cosmos_key or credential,
        ) as history_provider,
        Agent(
            client=FoundryChatClient(
                project_endpoint=project_endpoint,
                model=model,
                credential=credential,
            ),
            name="PersistentAgent",
            instructions="You are a helpful assistant that remembers prior turns.",
            context_providers=[history_provider],
            default_options={"store": False},
        ) as agent,
    ):
        restored_session = AgentSession.from_dict(serialized_session)
        print(f"Session restored. Session ID: {restored_session.session_id}\n")

        follow_up = "What was I working on and what was the challenge?"
        reply = await agent.run(follow_up, session=restored_session)
        print(f"User: {follow_up}")
        print(f"Assistant: {reply.text}\n")

        stored_messages = await history_provider.get_messages(restored_session.session_id)
        print(f"Messages stored in Cosmos DB: {len(stored_messages)}")
        for position, message in enumerate(stored_messages, start=1):
            print(f" {position}. [{message.role}] {message.text[:80]}...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run the async sample to completion.
    asyncio.run(main())
|
||||
|
||||
"""
|
||||
Sample output:
|
||||
=== Phase 1: Initial conversation ===
|
||||
|
||||
User: My name is Ada. I'm building a distributed database in Rust.
|
||||
Assistant: That sounds like a great project, Ada! Rust is an excellent choice for ...
|
||||
|
||||
User: The hardest part is the consensus algorithm.
|
||||
Assistant: Consensus algorithms can be tricky! Are you looking at Raft, Paxos, or ...
|
||||
|
||||
Session serialized. Session ID: <session-uuid>
|
||||
|
||||
=== Phase 2: Resuming after 'restart' ===
|
||||
|
||||
Session restored. Session ID: <session-uuid>
|
||||
|
||||
User: What was I working on and what was the challenge?
|
||||
Assistant: You told me you're building a distributed database in Rust and that the hardest
|
||||
part is the consensus algorithm.
|
||||
|
||||
Messages stored in Cosmos DB: 6
|
||||
1. [user] My name is Ada. I'm building a distributed database in Rust....
|
||||
2. [assistant] That sounds like a great project, Ada! Rust is an excellent ch...
|
||||
3. [user] The hardest part is the consensus algorithm....
|
||||
4. [assistant] Consensus algorithms can be tricky! Are you looking at Raft, Pa...
|
||||
5. [user] What was I working on and what was the challenge?...
|
||||
6. [assistant] You told me you're building a distributed database in Rust and ...
|
||||
"""
|
||||
@@ -0,0 +1,157 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
# ruff: noqa: T201
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from agent_framework import Agent
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from agent_framework_azure_cosmos import CosmosHistoryProvider
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file.
|
||||
load_dotenv()
|
||||
|
||||
"""
|
||||
This sample demonstrates direct message history operations using
|
||||
CosmosHistoryProvider — retrieving, displaying, and clearing stored messages.
|
||||
|
||||
Key components:
|
||||
- get_messages(session_id): Retrieve all stored messages as a chat transcript
|
||||
- clear(session_id): Delete all messages for a session (e.g., GDPR compliance)
|
||||
- Verifying that history is empty after clearing
|
||||
- Running a new conversation in the same session after clearing
|
||||
|
||||
Environment variables:
|
||||
FOUNDRY_PROJECT_ENDPOINT
|
||||
FOUNDRY_MODEL
|
||||
AZURE_COSMOS_ENDPOINT
|
||||
AZURE_COSMOS_DATABASE_NAME
|
||||
AZURE_COSMOS_CONTAINER_NAME
|
||||
Optional:
|
||||
AZURE_COSMOS_KEY
|
||||
"""
|
||||
|
||||
|
||||
async def main() -> None:
    """Build a conversation, print its stored transcript, clear it, and start over.

    The sample runs three queries in one session, reads the messages back from
    the history provider, clears the session's messages, checks the count
    afterwards, and finally asks one more question in the same session.
    """
    project_endpoint = os.getenv("FOUNDRY_PROJECT_ENDPOINT")
    model = os.getenv("FOUNDRY_MODEL")
    cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT")
    cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME")
    cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME")
    # Optional: when unset, the Azure CLI credential is used for Cosmos DB too.
    cosmos_key = os.getenv("AZURE_COSMOS_KEY")

    if not (project_endpoint and model and cosmos_endpoint and cosmos_database_name and cosmos_container_name):
        print(
            "Please set FOUNDRY_PROJECT_ENDPOINT, FOUNDRY_MODEL, "
            "AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, and AZURE_COSMOS_CONTAINER_NAME."
        )
        return

    async with (
        AzureCliCredential() as credential,
        CosmosHistoryProvider(
            endpoint=cosmos_endpoint,
            database_name=cosmos_database_name,
            container_name=cosmos_container_name,
            credential=cosmos_key or credential,
        ) as history_provider,
        Agent(
            client=FoundryChatClient(
                project_endpoint=project_endpoint,
                model=model,
                credential=credential,
            ),
            name="HistoryAgent",
            instructions="You are a helpful assistant that remembers prior turns.",
            context_providers=[history_provider],
            default_options={"store": False},
        ) as agent,
    ):
        session = agent.create_session()
        session_id = session.session_id

        # Step 1: hold a multi-turn conversation.
        print("=== Building a conversation ===\n")

        for user_text in (
            "Hi! My favorite programming language is Python.",
            "I also enjoy hiking in the mountains on weekends.",
            "What do you know about me so far?",
        ):
            reply = await agent.run(user_text, session=session)
            print(f"User: {user_text}")
            print(f"Assistant: {reply.text}\n")

        # Step 2: read the stored history back and print it as a transcript.
        print("=== Chat transcript from Cosmos DB ===\n")

        transcript = await history_provider.get_messages(session_id)
        print(f"Total messages stored: {len(transcript)}\n")
        for position, message in enumerate(transcript, start=1):
            print(f" {position}. [{message.role}] {message.text[:100]}")

        # Step 3: clear the history for this session.
        print("\n=== Clearing session history ===\n")

        await history_provider.clear(session_id)
        print(f"Cleared all messages for session: {session_id}")

        # Step 4: confirm how many messages remain after clearing.
        leftover = await history_provider.get_messages(session_id)
        print(f"Messages after clear: {len(leftover)}")

        # Step 5: ask again in the same session, now that its history was cleared.
        print("\n=== Fresh conversation (same session, no memory) ===\n")

        reply = await agent.run("What do you know about me?", session=session)
        print("User: What do you know about me?")
        print(f"Assistant: {reply.text}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run the async sample to completion.
    asyncio.run(main())
|
||||
|
||||
"""
|
||||
Sample output:
|
||||
=== Building a conversation ===
|
||||
|
||||
User: Hi! My favorite programming language is Python.
|
||||
Assistant: That's great! Python is a wonderful language. What do you like most about it?
|
||||
|
||||
User: I also enjoy hiking in the mountains on weekends.
|
||||
Assistant: Hiking sounds lovely! Do you have a favorite trail or mountain range?
|
||||
|
||||
User: What do you know about me so far?
|
||||
Assistant: You love Python as your favorite programming language and enjoy hiking in the mountains on weekends.
|
||||
|
||||
=== Chat transcript from Cosmos DB ===
|
||||
|
||||
Total messages stored: 6
|
||||
|
||||
1. [user] Hi! My favorite programming language is Python.
|
||||
2. [assistant] That's great! Python is a wonderful language. What do you like most about it?
|
||||
3. [user] I also enjoy hiking in the mountains on weekends.
|
||||
4. [assistant] Hiking sounds lovely! Do you have a favorite trail or mountain range?
|
||||
5. [user] What do you know about me so far?
|
||||
6. [assistant] You love Python as your favorite programming language and enjoy hiking ...
|
||||
|
||||
=== Clearing session history ===
|
||||
|
||||
Cleared all messages for session: <session-uuid>
|
||||
Messages after clear: 0
|
||||
|
||||
=== Fresh conversation (same session, no memory) ===
|
||||
|
||||
User: What do you know about me?
|
||||
Assistant: I don't have any information about you yet. Feel free to share anything you'd like!
|
||||
"""
|
||||
@@ -0,0 +1,197 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
# ruff: noqa: T201
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from agent_framework import Agent
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from agent_framework_azure_cosmos import CosmosHistoryProvider
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file.
|
||||
load_dotenv()
|
||||
|
||||
"""
|
||||
This sample demonstrates multi-session and multi-tenant management using
|
||||
CosmosHistoryProvider. Each tenant (user) gets isolated conversation sessions
|
||||
stored in the same Cosmos DB container, partitioned by session_id.
|
||||
|
||||
Key components:
|
||||
- Per-tenant session isolation using prefixed session IDs
|
||||
- list_sessions(): Enumerate all stored sessions across tenants
|
||||
- Switching between sessions for different users
|
||||
- Resuming a specific user's session — verifying data isolation
|
||||
|
||||
Environment variables:
|
||||
FOUNDRY_PROJECT_ENDPOINT
|
||||
FOUNDRY_MODEL
|
||||
AZURE_COSMOS_ENDPOINT
|
||||
AZURE_COSMOS_DATABASE_NAME
|
||||
AZURE_COSMOS_CONTAINER_NAME
|
||||
Optional:
|
||||
AZURE_COSMOS_KEY
|
||||
"""
|
||||
|
||||
|
||||
async def main() -> None:
    """Walk through multi-tenant session management backed by Cosmos DB.

    Reads the Foundry and Cosmos DB settings from environment variables and
    returns early with a hint if a required one is missing. Otherwise it runs
    two separate conversations under distinct session IDs, lists the stored
    sessions, re-creates each session by ID and asks a follow-up question,
    prints the number of stored messages per session, and clears both sessions.
    """
    project_endpoint = os.getenv("FOUNDRY_PROJECT_ENDPOINT")
    model = os.getenv("FOUNDRY_MODEL")
    cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT")
    cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME")
    cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME")
    # Optional: when absent, the Azure CLI credential is used for Cosmos DB too.
    cosmos_key = os.getenv("AZURE_COSMOS_KEY")

    if not (
        project_endpoint
        and model
        and cosmos_endpoint
        and cosmos_database_name
        and cosmos_container_name
    ):
        print(
            "Please set FOUNDRY_PROJECT_ENDPOINT, FOUNDRY_MODEL, "
            "AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, and AZURE_COSMOS_CONTAINER_NAME."
        )
        return

    # Session IDs carry a tenant prefix so each user's history is kept apart.
    alice_session_id = "tenant-alice-session-1"
    bob_session_id = "tenant-bob-session-1"

    async with (
        AzureCliCredential() as credential,
        CosmosHistoryProvider(
            endpoint=cosmos_endpoint,
            database_name=cosmos_database_name,
            container_name=cosmos_container_name,
            credential=cosmos_key or credential,
        ) as history_provider,
        Agent(
            client=FoundryChatClient(
                project_endpoint=project_endpoint,
                model=model,
                credential=credential,
            ),
            name="MultiTenantAgent",
            instructions="You are a helpful assistant that remembers prior turns.",
            context_providers=[history_provider],
            default_options={"store": False},
        ) as agent,
    ):

        async def exchange(session, speaker: str, sent: str, echoed: str) -> None:
            """Send ``sent`` on ``session``, then print the turn and the reply."""
            reply = await agent.run(sent, session=session)
            print(f"{speaker}: {echoed}")
            print(f"Assistant: {reply.text}\n")

        # 1. Alice talks about travel.
        print("=== Tenant: Alice — Travel conversation ===\n")
        alice_session = agent.create_session(session_id=alice_session_id)
        await exchange(
            alice_session,
            "Alice",
            "Hi! I'm planning a trip to Italy. I love Renaissance art.",
            "I'm planning a trip to Italy. I love Renaissance art.",
        )
        await exchange(
            alice_session,
            "Alice",
            "Which museums should I visit in Florence?",
            "Which museums should I visit in Florence?",
        )

        # 2. Bob talks about cooking in his own session.
        print("=== Tenant: Bob — Cooking conversation ===\n")
        bob_session = agent.create_session(session_id=bob_session_id)
        await exchange(
            bob_session,
            "Bob",
            "Hey! I'm learning to cook Thai food. I just made pad thai.",
            "I'm learning to cook Thai food. I just made pad thai.",
        )
        await exchange(
            bob_session,
            "Bob",
            "What Thai dish should I try next?",
            "What Thai dish should I try next?",
        )

        # 3. Enumerate every session the provider has stored.
        print("=== Listing all sessions ===\n")
        stored_sessions = await history_provider.list_sessions()
        print(f"Found {len(stored_sessions)} session(s):")
        for session_id in stored_sessions:
            print(f" - {session_id}")

        # 4. Re-create Alice's session by ID and ask about the earlier topic.
        print("\n=== Resuming Alice's session ===\n")
        alice_resumed = agent.create_session(session_id=alice_session_id)
        await exchange(
            alice_resumed,
            "Alice",
            "What were we discussing?",
            "What were we discussing?",
        )

        # 5. Same for Bob.
        print("=== Resuming Bob's session ===\n")
        bob_resumed = agent.create_session(session_id=bob_session_id)
        await exchange(
            bob_resumed,
            "Bob",
            "What was the last dish I mentioned?",
            "What was the last dish I mentioned?",
        )

        # 6. Report how many messages each session holds.
        print("=== Per-session message counts ===\n")
        alice_history = await history_provider.get_messages(alice_session_id)
        bob_history = await history_provider.get_messages(bob_session_id)
        print(f"Alice's session: {len(alice_history)} messages")
        print(f"Bob's session: {len(bob_history)} messages")

        # 7. Remove the sample data again.
        print("\n=== Cleaning up ===\n")
        for session_id in (alice_session_id, bob_session_id):
            await history_provider.clear(session_id)
        print("Cleared Alice's and Bob's sessions.")


if __name__ == "__main__":
    asyncio.run(main())
|
||||
|
||||
"""
|
||||
Sample output:
|
||||
=== Tenant: Alice — Travel conversation ===
|
||||
|
||||
Alice: I'm planning a trip to Italy. I love Renaissance art.
|
||||
Assistant: Italy is a dream for Renaissance art lovers! Florence, Rome, and Venice ...
|
||||
|
||||
Alice: Which museums should I visit in Florence?
|
||||
Assistant: In Florence, the Uffizi Gallery is a must — it has Botticelli's Birth of Venus ...
|
||||
|
||||
=== Tenant: Bob — Cooking conversation ===
|
||||
|
||||
Bob: I'm learning to cook Thai food. I just made pad thai.
|
||||
Assistant: Pad thai is a great start! How did it turn out?
|
||||
|
||||
Bob: What Thai dish should I try next?
|
||||
Assistant: I'd suggest trying green curry or tom yum soup — both are classic Thai dishes ...
|
||||
|
||||
=== Listing all sessions ===
|
||||
|
||||
Found 2 session(s):
|
||||
- tenant-alice-session-1
|
||||
- tenant-bob-session-1
|
||||
|
||||
=== Resuming Alice's session ===
|
||||
|
||||
Alice: What were we discussing?
|
||||
Assistant: We were discussing your trip to Italy and your love for Renaissance art ...
|
||||
|
||||
=== Resuming Bob's session ===
|
||||
|
||||
Bob: What was the last dish I mentioned?
|
||||
Assistant: You mentioned pad thai — it was the dish you just made!
|
||||
|
||||
=== Per-session message counts ===
|
||||
|
||||
Alice's session: 6 messages
|
||||
Bob's session: 6 messages
|
||||
|
||||
=== Cleaning up ===
|
||||
|
||||
Cleared Alice's and Bob's sessions.
|
||||
"""
|
||||
@@ -52,6 +52,8 @@ Once comfortable with these, explore the rest of the samples below.
|
||||
| Checkpointed Sub-Workflow | [checkpoint/sub_workflow_checkpoint.py](./checkpoint/sub_workflow_checkpoint.py) | Save and resume a sub-workflow that pauses for human approval |
|
||||
| Handoff + Tool Approval Resume | [orchestrations/handoff_with_tool_approval_checkpoint_resume.py](./orchestrations/handoff_with_tool_approval_checkpoint_resume.py) | Handoff workflow that captures tool-call approvals in checkpoints and resumes with human decisions |
|
||||
| Workflow as Agent Checkpoint | [checkpoint/workflow_as_agent_checkpoint.py](./checkpoint/workflow_as_agent_checkpoint.py) | Enable checkpointing when using workflow.as_agent() with checkpoint_storage parameter |
|
||||
| Cosmos DB Checkpoint Storage | [checkpoint/cosmos_workflow_checkpointing.py](./checkpoint/cosmos_workflow_checkpointing.py) | Use `CosmosCheckpointStorage` for durable workflow checkpointing backed by Azure Cosmos DB NoSQL |
|
||||
| Cosmos DB + Foundry Checkpoint | [checkpoint/cosmos_workflow_checkpointing_foundry.py](./checkpoint/cosmos_workflow_checkpointing_foundry.py) | Multi-agent workflow using `FoundryChatClient` with `CosmosCheckpointStorage` for durable pause/resume |
|
||||
|
||||
### composition
|
||||
|
||||
|
||||
@@ -0,0 +1,201 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
# ruff: noqa: T201
|
||||
|
||||
"""Sample: Workflow Checkpointing with Cosmos DB NoSQL.
|
||||
|
||||
Purpose:
|
||||
This sample shows how to use Azure Cosmos DB NoSQL as a persistent checkpoint
|
||||
storage backend for workflows, enabling durable pause-and-resume across
|
||||
process restarts.
|
||||
|
||||
What you learn:
|
||||
- How to configure CosmosCheckpointStorage for workflow checkpointing
|
||||
- How to run a workflow that automatically persists checkpoints to Cosmos DB
|
||||
- How to resume a workflow from a Cosmos DB checkpoint
|
||||
- How to list and inspect available checkpoints
|
||||
|
||||
Prerequisites:
|
||||
- An Azure Cosmos DB account (or local emulator)
|
||||
- Environment variables set (see below)
|
||||
|
||||
Environment variables:
|
||||
AZURE_COSMOS_ENDPOINT - Cosmos DB account endpoint
|
||||
AZURE_COSMOS_DATABASE_NAME - Database name
|
||||
AZURE_COSMOS_CONTAINER_NAME - Container name for checkpoints
|
||||
Optional:
|
||||
AZURE_COSMOS_KEY - Account key (if not using Azure credentials)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import (
|
||||
Executor,
|
||||
WorkflowBuilder,
|
||||
WorkflowCheckpoint,
|
||||
WorkflowContext,
|
||||
handler,
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import override # type: ignore[import] # pragma: no cover
|
||||
|
||||
from agent_framework_azure_cosmos import CosmosCheckpointStorage
|
||||
|
||||
|
||||
@dataclass
class ComputeTask:
    """Work item exchanged between executors: the numbers still to be processed."""

    # Numbers that have not been handled yet.
    remaining_numbers: list[int]
|
||||
|
||||
|
||||
class StartExecutor(Executor):
    """Entry executor: turns an upper limit into the initial ComputeTask."""

    @handler
    async def start(self, upper_limit: int, ctx: WorkflowContext[ComputeTask]) -> None:
        """Send a ComputeTask holding the integers 1..upper_limit (inclusive)."""
        print(f"StartExecutor: Starting computation up to {upper_limit}")
        numbers = list(range(1, upper_limit + 1))
        initial_task = ComputeTask(remaining_numbers=numbers)
        await ctx.send_message(initial_task)
|
||||
|
||||
|
||||
class WorkerExecutor(Executor):
    """Processes one number per invocation and keeps results for checkpointing.

    Results accumulate in ``self._results`` (number -> list of factor pairs) and
    are exposed through ``on_checkpoint_save`` / ``on_checkpoint_restore``.
    """

    def __init__(self, id: str) -> None:
        """Initialize the worker executor with an empty result map."""
        super().__init__(id=id)
        self._results: dict[int, list[tuple[int, int]]] = {}

    @handler
    async def compute(
        self,
        task: ComputeTask,
        ctx: WorkflowContext[ComputeTask, dict[int, list[tuple[int, int]]]],
    ) -> None:
        """Process the next number, computing its factor pairs.

        Takes the first entry of ``task.remaining_numbers``, records every pair
        ``(i, n // i)`` with ``1 <= i < n`` and ``n % i == 0``, then either
        forwards the task (more numbers left) or yields the accumulated results.
        """
        if not task.remaining_numbers:
            # An empty task (e.g. an upper limit below 1) has nothing to pop;
            # finish with whatever has been collected instead of raising IndexError.
            await ctx.yield_output(self._results)
            return

        next_number = task.remaining_numbers.pop(0)
        print(f"WorkerExecutor: Processing {next_number}")

        self._results[next_number] = [
            (i, next_number // i) for i in range(1, next_number) if next_number % i == 0
        ]

        if not task.remaining_numbers:
            await ctx.yield_output(self._results)
        else:
            # The worker -> worker edge delivers the task back to this handler.
            await ctx.send_message(task)

    @override
    async def on_checkpoint_save(self) -> dict[str, Any]:
        """Return the accumulated results under the ``"results"`` key."""
        return {"results": self._results}

    @override
    async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
        """Replace the result map with ``state["results"]`` (empty if absent)."""
        self._results = state.get("results", {})
|
||||
|
||||
|
||||
async def main() -> None:
    """Open a CosmosCheckpointStorage and run the checkpointing demo against it.

    Connection settings come from environment variables; the function prints a
    hint and returns when a required one is missing. With AZURE_COSMOS_KEY set
    the key is used as the credential; otherwise a DefaultAzureCredential is
    opened (and closed) around the storage.
    """
    cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT")
    cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME")
    cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME")
    cosmos_key = os.getenv("AZURE_COSMOS_KEY")

    if not (cosmos_endpoint and cosmos_database_name and cosmos_container_name):
        print(
            "Please set AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, "
            "and AZURE_COSMOS_CONTAINER_NAME."
        )
        return

    if not cosmos_key:
        # No account key: authenticate with Azure credentials instead. The
        # import is local so azure-identity is only needed on this path.
        from azure.identity.aio import DefaultAzureCredential

        async with DefaultAzureCredential() as credential:
            async with CosmosCheckpointStorage(
                endpoint=cosmos_endpoint,
                credential=credential,
                database_name=cosmos_database_name,
                container_name=cosmos_container_name,
            ) as checkpoint_storage:
                await _run_workflow(checkpoint_storage)
        return

    # Key-based authentication.
    async with CosmosCheckpointStorage(
        endpoint=cosmos_endpoint,
        credential=cosmos_key,
        database_name=cosmos_database_name,
        container_name=cosmos_container_name,
    ) as checkpoint_storage:
        await _run_workflow(checkpoint_storage)
|
||||
|
||||
|
||||
async def _run_workflow(checkpoint_storage: CosmosCheckpointStorage) -> None:
    """Run the factor-pair workflow once, then resume it from its latest checkpoint.

    Builds a start -> worker -> worker (self-loop) workflow that uses
    ``checkpoint_storage``, runs it with the input 8, lists the checkpoint IDs
    recorded for the workflow, and starts a second workflow instance from the
    most recent checkpoint.
    """

    async def last_output(events):
        """Drain an event stream and return the data of its last "output" event."""
        result = None
        async for event in events:
            if event.type == "output":
                result = event.data
        return result

    start = StartExecutor(id="start")
    worker = WorkerExecutor(id="worker")
    workflow_builder = WorkflowBuilder(start_executor=start, checkpoint_storage=checkpoint_storage)
    workflow_builder = workflow_builder.add_edge(start, worker).add_edge(worker, worker)

    # --- First run: execute the workflow from scratch ---
    print("\n=== First Run ===\n")
    workflow = workflow_builder.build()
    output = await last_output(workflow.run(message=8, stream=True))
    print(f"Factor pairs computed: {output}")

    # Show what the first run left behind in Cosmos DB.
    checkpoint_ids = await checkpoint_storage.list_checkpoint_ids(workflow_name=workflow.name)
    print(f"\nCheckpoints in Cosmos DB: {len(checkpoint_ids)}")
    for checkpoint_id in checkpoint_ids:
        print(f" - {checkpoint_id}")

    latest: WorkflowCheckpoint | None = await checkpoint_storage.get_latest(workflow_name=workflow.name)
    if latest is None:
        print("No checkpoint found to resume from.")
        return

    print(f"\nLatest checkpoint: {latest.checkpoint_id}")
    print(f" iteration_count: {latest.iteration_count}")
    print(f" timestamp: {latest.timestamp}")

    # --- Second run: a fresh workflow instance resumed from that checkpoint ---
    print("\n=== Resuming from Checkpoint ===\n")
    workflow2 = workflow_builder.build()
    output2 = await last_output(workflow2.run(checkpoint_id=latest.checkpoint_id, stream=True))

    if not output2:
        print("Resumed workflow completed (no remaining work — already finished).")
    else:
        print(f"Resumed workflow produced: {output2}")


if __name__ == "__main__":
    asyncio.run(main())
|
||||
@@ -0,0 +1,144 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
# ruff: noqa: T201
|
||||
|
||||
"""Sample: Workflow Checkpointing with Cosmos DB and Azure AI Foundry.
|
||||
|
||||
Purpose:
|
||||
This sample demonstrates how to use CosmosCheckpointStorage with agents built
|
||||
on Azure AI Foundry (via FoundryChatClient). It shows a multi-agent
|
||||
workflow where checkpoint state is persisted to Cosmos DB, enabling durable
|
||||
pause-and-resume across process restarts.
|
||||
|
||||
What you learn:
|
||||
- How to wire CosmosCheckpointStorage with FoundryChatClient agents
|
||||
- How to combine session history with workflow checkpointing
|
||||
- How to continue a workflow-as-agent conversation while its checkpoints are persisted to Cosmos DB
|
||||
|
||||
Key concepts:
|
||||
- AgentSession: Maintains conversation history across agent invocations
|
||||
- CosmosCheckpointStorage: Persists workflow execution state in Cosmos DB
|
||||
- These are complementary: sessions track conversation, checkpoints track workflow state
|
||||
|
||||
Environment variables:
|
||||
FOUNDRY_PROJECT_ENDPOINT - Azure AI Foundry project endpoint
|
||||
FOUNDRY_MODEL - Model deployment name
|
||||
AZURE_COSMOS_ENDPOINT - Cosmos DB account endpoint
|
||||
AZURE_COSMOS_DATABASE_NAME - Database name
|
||||
AZURE_COSMOS_CONTAINER_NAME - Container name for checkpoints
|
||||
Optional:
|
||||
AZURE_COSMOS_KEY - Account key (if not using Azure credentials)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import Agent
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from agent_framework.orchestrations import SequentialBuilder
|
||||
from agent_framework_azure_cosmos import CosmosCheckpointStorage
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
async def main() -> None:
    """Run a two-agent sequential workflow with checkpoints stored in Cosmos DB.

    Reads the Foundry and Cosmos DB settings from environment variables and
    returns early with a hint if a required one is missing. The workflow is
    wrapped as an agent and invoked twice on the same session; after each
    invocation the checkpoints recorded for the workflow are listed.
    """
    project_endpoint = os.getenv("FOUNDRY_PROJECT_ENDPOINT")
    model = os.getenv("FOUNDRY_MODEL")
    cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT")
    cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME")
    cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME")
    cosmos_key = os.getenv("AZURE_COSMOS_KEY")

    if not (project_endpoint and model):
        print("Please set FOUNDRY_PROJECT_ENDPOINT and FOUNDRY_MODEL.")
        return

    if not (cosmos_endpoint and cosmos_database_name and cosmos_container_name):
        print(
            "Please set AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, "
            "and AZURE_COSMOS_CONTAINER_NAME."
        )
        return

    def show_messages(result) -> None:
        """Print each message of an agent response, prefixed by its speaker."""
        for msg in result.messages:
            speaker = msg.author_name or msg.role
            print(f"[{speaker}]: {msg.text}")

    # One AzureCliCredential serves Foundry, and Cosmos DB too unless an
    # account key was supplied; the context manager closes it on exit.
    async with AzureCliCredential() as azure_credential:
        cosmos_credential: Any = cosmos_key or azure_credential

        async with CosmosCheckpointStorage(
            endpoint=cosmos_endpoint,
            credential=cosmos_credential,
            database_name=cosmos_database_name,
            container_name=cosmos_container_name,
        ) as checkpoint_storage:
            # Both agents share a single Foundry chat client.
            client = FoundryChatClient(
                project_endpoint=project_endpoint,
                model=model,
                credential=azure_credential,
            )
            assistant = Agent(
                name="assistant",
                instructions="You are a helpful assistant. Keep responses brief.",
                client=client,
            )
            reviewer = Agent(
                name="reviewer",
                instructions="You are a reviewer. Provide a one-sentence summary of the assistant's response.",
                client=client,
            )

            # assistant -> reviewer, exposed through the agent interface.
            workflow = SequentialBuilder(participants=[assistant, reviewer]).build()
            agent = workflow.as_agent(name="FoundryCheckpointedAgent")

            # --- First run: execute with Cosmos DB checkpointing ---
            print("=== First Run ===\n")
            session = agent.create_session()
            query = "What are the benefits of renewable energy?"
            print(f"User: {query}")
            response = await agent.run(query, session=session, checkpoint_storage=checkpoint_storage)
            show_messages(response)

            # List (at most five of) the checkpoints written so far.
            checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow.name)
            print(f"\nCheckpoints in Cosmos DB: {len(checkpoints)}")
            for position, checkpoint in enumerate(checkpoints[:5], 1):
                print(f" {position}. {checkpoint.checkpoint_id} (iteration={checkpoint.iteration_count})")

            # --- Second run: same session, same checkpoint storage ---
            print("\n=== Second Run (continuing conversation) ===\n")
            query2 = "Can you elaborate on the economic benefits?"
            print(f"User: {query2}")
            response2 = await agent.run(query2, session=session, checkpoint_storage=checkpoint_storage)
            show_messages(response2)

            all_checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow.name)
            print(f"\nTotal checkpoints after two runs: {len(all_checkpoints)}")

            latest = await checkpoint_storage.get_latest(workflow_name=workflow.name)
            if latest:
                print(f"Latest checkpoint: {latest.checkpoint_id}")
                print(f" iteration_count: {latest.iteration_count}")
                print(f" timestamp: {latest.timestamp}")


if __name__ == "__main__":
    asyncio.run(main())
|
||||
Reference in New Issue
Block a user