Python: Add Azure Cosmos history provider package (#4271)

* Created cosmos history provider

* add marker

* Python: address Cosmos PR feedback

- address provider/test/sample review feedback and cleanup typing
- add cosmos integration test coverage and skip gating
- add dedicated cosmos emulator jobs to python merge/integration workflows
- switch cosmos workflow execution to package poe integration-tests task

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Python: handle empty Cosmos session id

- replace default partition fallback for empty session_id
- log warning and generate GUID when session_id is empty
- update unit tests to validate GUID fallback behavior

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix sample

* fix cross partition query

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Eduard van Valkenburg
2026-03-03 13:29:32 +01:00
committed by GitHub
Unverified
parent 945933c351
commit c37f74f898
14 changed files with 1616 additions and 445 deletions
+47 -1
View File
@@ -247,6 +247,51 @@ jobs:
timeout-minutes: 15
run: uv run --directory packages/azure-ai poe integration-tests -n logical --dist worksteal --timeout=120 --session-timeout=900 --timeout_method thread --retries 2 --retry-delay 5
# Azure Cosmos integration tests
python-tests-cosmos:
name: Python Integration Tests - Cosmos
runs-on: ubuntu-latest
environment: integration
timeout-minutes: 60
services:
cosmosdb:
image: mcr.microsoft.com/cosmosdb/linux/azure-cosmos-emulator:vnext-preview
ports:
- 8081:8081
env:
AZURE_COSMOS_ENDPOINT: "http://localhost:8081/"
# Static Azure Cosmos DB emulator key (documented): https://learn.microsoft.com/en-us/azure/cosmos-db/emulator
AZURE_COSMOS_KEY: "C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="
AZURE_COSMOS_DATABASE_NAME: "agent-framework-cosmos-it-db"
AZURE_COSMOS_CONTAINER_NAME: "agent-framework-cosmos-it-container"
defaults:
run:
working-directory: python
steps:
- uses: actions/checkout@v6
with:
ref: ${{ inputs.checkout-ref }}
persist-credentials: false
- name: Set up python and install the project
id: python-setup
uses: ./.github/actions/python-setup
with:
python-version: ${{ env.UV_PYTHON }}
os: ${{ runner.os }}
- name: Wait for Cosmos DB emulator
run: |
for i in {1..60}; do
if curl --silent --show-error http://localhost:8081/ > /dev/null; then
echo "Cosmos DB emulator is ready."
exit 0
fi
sleep 2
done
echo "Cosmos DB emulator did not become ready in time." >&2
exit 1
- name: Test with pytest (Cosmos integration)
run: uv run --directory packages/azure-cosmos poe integration-tests -n logical --dist worksteal --timeout=120 --session-timeout=900 --timeout_method thread --retries 2 --retry-delay 5
python-integration-tests-check:
if: always()
runs-on: ubuntu-latest
@@ -257,7 +302,8 @@ jobs:
python-tests-azure-openai,
python-tests-misc-integration,
python-tests-functions,
python-tests-azure-ai
python-tests-azure-ai,
python-tests-cosmos
]
steps:
- name: Fail workflow if tests failed
+62
View File
@@ -38,6 +38,7 @@ jobs:
miscChanged: ${{ steps.filter.outputs.misc }}
functionsChanged: ${{ steps.filter.outputs.functions }}
azureAiChanged: ${{ steps.filter.outputs.azure-ai }}
cosmosChanged: ${{ steps.filter.outputs.cosmos }}
steps:
- uses: actions/checkout@v6
- uses: dorny/paths-filter@v3
@@ -67,6 +68,8 @@ jobs:
- 'python/packages/durabletask/**'
azure-ai:
- 'python/packages/azure-ai/**'
cosmos:
- 'python/packages/azure-cosmos/**'
# run only if 'python' files were changed
- name: python tests
if: steps.filter.outputs.python == 'true'
@@ -390,6 +393,64 @@ jobs:
# TODO: Add python-tests-lab
# Azure Cosmos integration tests
python-tests-cosmos:
name: Python Tests - Cosmos Integration
needs: paths-filter
if: >
github.event_name != 'pull_request' &&
needs.paths-filter.outputs.pythonChanges == 'true' &&
(github.event_name != 'merge_group' ||
needs.paths-filter.outputs.cosmosChanged == 'true' ||
needs.paths-filter.outputs.coreChanged == 'true')
runs-on: ubuntu-latest
environment: integration
services:
cosmosdb:
image: mcr.microsoft.com/cosmosdb/linux/azure-cosmos-emulator:vnext-preview
ports:
- 8081:8081
env:
AZURE_COSMOS_ENDPOINT: "http://localhost:8081/"
# Static Azure Cosmos DB emulator key (documented): https://learn.microsoft.com/en-us/azure/cosmos-db/emulator
AZURE_COSMOS_KEY: "C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="
AZURE_COSMOS_DATABASE_NAME: "agent-framework-cosmos-it-db"
AZURE_COSMOS_CONTAINER_NAME: "agent-framework-cosmos-it-container"
defaults:
run:
working-directory: python
steps:
- uses: actions/checkout@v6
- name: Set up python and install the project
id: python-setup
uses: ./.github/actions/python-setup
with:
python-version: ${{ env.UV_PYTHON }}
os: ${{ runner.os }}
- name: Wait for Cosmos DB emulator
run: |
for i in {1..60}; do
if curl --silent --show-error http://localhost:8081/ > /dev/null; then
echo "Cosmos DB emulator is ready."
exit 0
fi
sleep 2
done
echo "Cosmos DB emulator did not become ready in time." >&2
exit 1
- name: Test with pytest (Cosmos integration)
run: uv run --directory packages/azure-cosmos poe integration-tests -n logical --dist worksteal --timeout=120 --session-timeout=900 --timeout_method thread --retries 2 --retry-delay 5
working-directory: ./python
- name: Surface failing tests
if: always()
uses: pmeier/pytest-results-action@v0.7.2
with:
path: ./python/**.xml
summary: true
display-options: fEX
fail-on-empty: false
title: Cosmos integration test results
python-integration-tests-check:
if: always()
runs-on: ubuntu-latest
@@ -401,6 +462,7 @@ jobs:
python-tests-misc-integration,
python-tests-functions,
python-tests-azure-ai,
python-tests-cosmos,
]
steps:
- name: Fail workflow if tests failed
+28
View File
@@ -0,0 +1,28 @@
# Azure Cosmos DB Package (agent-framework-azure-cosmos)
Azure Cosmos DB history provider integration for Agent Framework.
## Main Classes
- **`CosmosHistoryProvider`** - Persistent conversation history storage backed by Azure Cosmos DB
## Usage
```python
from agent_framework_azure_cosmos import CosmosHistoryProvider
provider = CosmosHistoryProvider(
endpoint="https://<account>.documents.azure.com:443/",
credential="<key-or-token-credential>",
database_name="agent-framework",
container_name="chat-history",
)
```
Container name is configured on the provider. `session_id` is used as the partition key.
## Import Path
```python
from agent_framework_azure_cosmos import CosmosHistoryProvider
```
+21
View File
@@ -0,0 +1,21 @@
MIT License
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
+38
View File
@@ -0,0 +1,38 @@
# Get Started with Microsoft Agent Framework Azure Cosmos DB
Please install this package via pip:
```bash
pip install agent-framework-azure-cosmos --pre
```
## Azure Cosmos DB History Provider
The Azure Cosmos DB integration provides `CosmosHistoryProvider` for persistent conversation history storage.
### Basic Usage Example
```python
from azure.identity.aio import DefaultAzureCredential
from agent_framework_azure_cosmos import CosmosHistoryProvider
provider = CosmosHistoryProvider(
endpoint="https://<account>.documents.azure.com:443/",
credential=DefaultAzureCredential(),
database_name="agent-framework",
container_name="chat-history",
)
```
Credentials follow the same pattern used by other Azure connectors in the repository:
- Pass a credential object (for example `DefaultAzureCredential`)
- Or pass a key string directly
- Or set `AZURE_COSMOS_KEY` in the environment
Container naming behavior:
- Container name is configured on the provider (`container_name` or `AZURE_COSMOS_CONTAINER_NAME`)
- `session_id` is used as the Cosmos partition key for reads/writes
See `samples/cosmos_history_provider.py` for a runnable package-local example.
@@ -0,0 +1,15 @@
# Copyright (c) Microsoft. All rights reserved.
import importlib.metadata
from ._history_provider import CosmosHistoryProvider
try:
__version__ = importlib.metadata.version(__name__)
except importlib.metadata.PackageNotFoundError:
__version__ = "0.0.0" # Fallback for development mode
__all__ = [
"CosmosHistoryProvider",
"__version__",
]
@@ -0,0 +1,269 @@
# Copyright (c) Microsoft. All rights reserved.
"""Azure Cosmos DB history provider."""
from __future__ import annotations
import logging
import time
import uuid
from collections.abc import Sequence
from typing import Any, ClassVar, TypedDict
from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Message
from agent_framework._sessions import BaseHistoryProvider
from agent_framework._settings import SecretString, load_settings
from agent_framework.azure._entra_id_authentication import AzureCredentialTypes
from azure.cosmos import PartitionKey
from azure.cosmos.aio import ContainerProxy, CosmosClient, DatabaseProxy
logger = logging.getLogger(__name__)
class AzureCosmosHistorySettings(TypedDict, total=False):
"""Settings for CosmosHistoryProvider resolved from args and environment."""
endpoint: str | None
database_name: str | None
container_name: str | None
key: SecretString | None
class CosmosHistoryProvider(BaseHistoryProvider):
"""Azure Cosmos DB-backed history provider using BaseHistoryProvider hooks."""
DEFAULT_SOURCE_ID: ClassVar[str] = "azure_cosmos_history"
_BATCH_OPERATION_LIMIT: ClassVar[int] = 100
def __init__(
self,
source_id: str = DEFAULT_SOURCE_ID,
*,
load_messages: bool = True,
store_outputs: bool = True,
store_inputs: bool = True,
store_context_messages: bool = False,
store_context_from: set[str] | None = None,
endpoint: str | None = None,
database_name: str | None = None,
container_name: str | None = None,
credential: str | AzureCredentialTypes | None = None,
cosmos_client: CosmosClient | None = None,
container_client: ContainerProxy | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> None:
"""Initialize the Azure Cosmos DB history provider.
Args:
source_id: Unique identifier for this provider instance.
load_messages: Whether to load messages before invocation.
store_outputs: Whether to store response messages.
store_inputs: Whether to store input messages.
store_context_messages: Whether to store context from other providers.
store_context_from: If set, only store context from these source_ids.
endpoint: Cosmos DB account endpoint.
Can be set via ``AZURE_COSMOS_ENDPOINT``.
database_name: Cosmos DB database name.
Can be set via ``AZURE_COSMOS_DATABASE_NAME``.
container_name: Cosmos DB container name.
Can be set via ``AZURE_COSMOS_CONTAINER_NAME``.
credential: Credential to authenticate with Cosmos DB.
Supports key string and Azure credential objects.
Can be set via ``AZURE_COSMOS_KEY`` when omitted.
cosmos_client: Pre-created Cosmos async client.
container_client: Pre-created Cosmos container client for fixed-container usage.
env_file_path: Path to environment file for loading settings.
env_file_encoding: Encoding of the environment file.
"""
super().__init__(
source_id,
load_messages=load_messages,
store_outputs=store_outputs,
store_inputs=store_inputs,
store_context_messages=store_context_messages,
store_context_from=store_context_from,
)
self._cosmos_client: CosmosClient | None = cosmos_client
self._container_proxy: ContainerProxy | None = container_client
self._owns_client = False
self._database_client: DatabaseProxy | None = None
if self._container_proxy is not None:
self.database_name: str = database_name or ""
self.container_name: str = container_name or ""
return
required_fields: list[str] = ["database_name", "container_name"]
if cosmos_client is None:
required_fields.append("endpoint")
if credential is None:
required_fields.append("key")
settings = load_settings(
AzureCosmosHistorySettings,
env_prefix="AZURE_COSMOS_",
required_fields=required_fields,
endpoint=endpoint,
database_name=database_name,
container_name=container_name,
key=credential if isinstance(credential, str) else None,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
self.database_name = settings["database_name"] # type: ignore[assignment]
self.container_name = settings["container_name"] # type: ignore[assignment]
if self._cosmos_client is None:
self._cosmos_client = CosmosClient(
url=settings["endpoint"], # type: ignore[arg-type]
credential=credential or settings["key"].get_secret_value(), # type: ignore[arg-type,union-attr]
user_agent_suffix=AGENT_FRAMEWORK_USER_AGENT,
)
self._owns_client = True
self._database_client = self._cosmos_client.get_database_client(self.database_name)
async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Message]:
"""Retrieve stored messages for this session from Azure Cosmos DB."""
await self._ensure_container_proxy()
session_key = self._session_partition_key(session_id)
query = (
"SELECT c.message FROM c "
"WHERE c.session_id = @session_id AND c.source_id = @source_id "
"ORDER BY c.sort_key ASC"
)
parameters: list[dict[str, object]] = [
{"name": "@session_id", "value": session_key},
{"name": "@source_id", "value": self.source_id},
]
items = self._container_proxy.query_items( # type: ignore[union-attr]
query=query, parameters=parameters, partition_key=session_key
)
messages: list[Message] = []
async for item in items:
message_payload = item.get("message")
if isinstance(message_payload, dict):
messages.append(Message.from_dict(message_payload))
return messages
async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs: Any) -> None:
"""Persist messages for this session to Azure Cosmos DB."""
if not messages:
return
await self._ensure_container_proxy()
session_key = self._session_partition_key(session_id)
base_sort_key = time.time_ns()
operations: list[tuple[str, tuple[dict[str, Any]]]] = []
for index, message in enumerate(messages):
document = {
"id": str(uuid.uuid4()),
"session_id": session_key,
"sort_key": base_sort_key + index,
"source_id": self.source_id,
"message": message.to_dict(),
}
operations.append(("upsert", (document,)))
for start in range(0, len(operations), self._BATCH_OPERATION_LIMIT):
batch = operations[start : start + self._BATCH_OPERATION_LIMIT]
await self._container_proxy.execute_item_batch( # type: ignore[union-attr]
batch_operations=batch, partition_key=session_key
)
async def clear(self, session_id: str | None) -> None:
"""Clear all messages for a session from Azure Cosmos DB."""
await self._ensure_container_proxy()
session_key = self._session_partition_key(session_id)
query = "SELECT c.id FROM c WHERE c.session_id = @session_id AND c.source_id = @source_id"
parameters: list[dict[str, object]] = [
{"name": "@session_id", "value": session_key},
{"name": "@source_id", "value": self.source_id},
]
items = self._container_proxy.query_items( # type: ignore[union-attr]
query=query, parameters=parameters, partition_key=session_key
)
delete_operations: list[tuple[str, tuple[str]]] = []
async for item in items:
item_id = item.get("id")
if isinstance(item_id, str):
delete_operations.append(("delete", (item_id,)))
for start in range(0, len(delete_operations), self._BATCH_OPERATION_LIMIT):
batch = delete_operations[start : start + self._BATCH_OPERATION_LIMIT]
await self._container_proxy.execute_item_batch( # type: ignore[union-attr]
batch_operations=batch, partition_key=session_key
)
async def list_sessions(self) -> list[str]:
"""List all session IDs stored in this provider's Cosmos container."""
await self._ensure_container_proxy()
query = (
"SELECT DISTINCT VALUE c.session_id FROM c WHERE c.source_id = @source_id"
)
parameters: list[dict[str, object]] = [
{"name": "@source_id", "value": self.source_id}
]
# without a partition key, it is automatically a cross-partition query
items = self._container_proxy.query_items(query=query, parameters=parameters) # type: ignore[union-attr]
session_ids: set[str] = set()
async for item in items:
if isinstance(item, str):
session_ids.add(item)
return sorted(session_ids)
async def close(self) -> None:
"""Close the underlying Cosmos client when this provider owns it."""
if self._owns_client and self._cosmos_client is not None:
await self._cosmos_client.close()
async def __aenter__(self) -> CosmosHistoryProvider:
"""Async context manager entry."""
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: Any,
) -> None:
"""Async context manager exit."""
try:
await self.close()
except Exception:
if exc_type is None:
raise
async def _ensure_container_proxy(self) -> None:
"""Get or create the Cosmos DB container for storing messages."""
if self._container_proxy is not None:
return
if self._database_client is None:
raise RuntimeError("Cosmos database client is not initialized.")
self._container_proxy = (
await self._database_client.create_container_if_not_exists(
id=self.container_name,
partition_key=PartitionKey(path="/session_id"),
)
)
@staticmethod
def _session_partition_key(session_id: str | None) -> str:
if session_id:
return session_id
generated_session_id = str(uuid.uuid4())
logger.warning(
"Received empty session_id; generated temporary session id '%s' for Cosmos partition key.",
generated_session_id,
)
return generated_session_id
@@ -0,0 +1,93 @@
[project]
name = "agent-framework-azure-cosmos"
description = "Azure Cosmos DB history provider integration for Microsoft Agent Framework."
authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}]
readme = "README.md"
requires-python = ">=3.10"
version = "1.0.0b260219"
license-files = ["LICENSE"]
urls.homepage = "https://aka.ms/agent-framework"
urls.source = "https://github.com/microsoft/agent-framework/tree/main/python"
urls.release_notes = "https://github.com/microsoft/agent-framework/releases?q=tag%3Apython-1&expanded=true"
urls.issues = "https://github.com/microsoft/agent-framework/issues"
classifiers = [
"License :: OSI Approved :: MIT License",
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Typing :: Typed",
]
dependencies = [
"agent-framework-core>=1.0.0rc1",
"azure-cosmos>=4.9.0",
]
[tool.uv]
prerelease = "if-necessary-or-explicit"
environments = [
"sys_platform == 'darwin'",
"sys_platform == 'linux'",
"sys_platform == 'win32'"
]
[tool.uv-dynamic-versioning]
fallback-version = "0.0.0"
[tool.pytest.ini_options]
testpaths = 'tests'
addopts = "-ra -q -r fEX"
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
filterwarnings = [
"ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*"
]
timeout = 120
markers = [
"integration: marks tests as integration tests that require external services",
]
[tool.ruff]
extend = "../../pyproject.toml"
[tool.coverage.run]
omit = [
"**/__init__.py"
]
[tool.pyright]
extends = "../../pyproject.toml"
[tool.mypy]
plugins = ['pydantic.mypy']
strict = true
python_version = "3.10"
ignore_missing_imports = true
disallow_untyped_defs = true
no_implicit_optional = true
check_untyped_defs = true
warn_return_any = true
show_error_codes = true
warn_unused_ignores = false
disallow_incomplete_defs = true
disallow_untyped_decorators = true
[tool.bandit]
targets = ["agent_framework_azure_cosmos"]
exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_azure_cosmos"
test = "pytest --cov=agent_framework_azure_cosmos --cov-report=term-missing:skip-covered tests"
integration-tests = "pytest tests/test_cosmos_history_provider.py -m integration"
[build-system]
requires = ["flit-core >= 3.11,<4.0"]
build-backend = "flit_core.buildapi"
@@ -0,0 +1,20 @@
# Azure Cosmos DB Package Samples
This folder contains samples for `agent-framework-azure-cosmos`.
| File | Description |
| --- | --- |
| [`cosmos_history_provider.py`](cosmos_history_provider.py) | Demonstrates an Agent using `CosmosHistoryProvider` with `AzureOpenAIResponsesClient` (project endpoint), provider-configured container name, and `session_id` partitioning. |
## Prerequisites
- `AZURE_COSMOS_ENDPOINT`
- `AZURE_COSMOS_DATABASE_NAME`
- `AZURE_COSMOS_CONTAINER_NAME`
- `AZURE_COSMOS_KEY` (or equivalent credential flow)
## Run
```bash
uv run --directory packages/azure-cosmos python samples/cosmos_history_provider.py
```
@@ -0,0 +1,3 @@
# Copyright (c) Microsoft. All rights reserved.
"""Samples for the Azure Cosmos history provider package."""
@@ -0,0 +1,100 @@
# Copyright (c) Microsoft. All rights reserved.
# ruff: noqa: T201
import asyncio
import os
from agent_framework.azure import AzureOpenAIResponsesClient
from agent_framework_azure_cosmos import CosmosHistoryProvider
from azure.identity.aio import AzureCliCredential
from dotenv import load_dotenv
# Load environment variables from .env file.
load_dotenv()
"""
This sample demonstrates CosmosHistoryProvider as an agent context provider.
Key components:
- AzureOpenAIResponsesClient configured with an Azure AI project endpoint
- CosmosHistoryProvider configured for Cosmos DB-backed message history
- Provider-configured container name with session_id as partition key
Environment variables:
AZURE_AI_PROJECT_ENDPOINT
AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME
AZURE_COSMOS_ENDPOINT
AZURE_COSMOS_DATABASE_NAME
AZURE_COSMOS_CONTAINER_NAME
Optional:
AZURE_COSMOS_KEY
"""
async def main() -> None:
"""Run the Cosmos history provider sample with an Agent."""
project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT")
deployment_name = os.getenv("AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME")
cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT")
cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME")
cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME")
cosmos_key = os.getenv("AZURE_COSMOS_KEY")
if (
not project_endpoint
or not deployment_name
or not cosmos_endpoint
or not cosmos_database_name
or not cosmos_container_name
):
print(
"Please set AZURE_AI_PROJECT_ENDPOINT, AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME, "
"AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, and AZURE_COSMOS_CONTAINER_NAME."
)
return
# 1. Create an Azure credential and Responses client using project endpoint auth.
async with AzureCliCredential() as credential:
client = AzureOpenAIResponsesClient(
project_endpoint=project_endpoint,
deployment_name=deployment_name,
credential=credential,
)
# 2. Create an agent that uses the history provider as a context provider.
async with (
CosmosHistoryProvider(
endpoint=cosmos_endpoint,
database_name=cosmos_database_name,
container_name=cosmos_container_name,
credential=cosmos_key or credential,
) as history_provider,
client.as_agent(
name="CosmosHistoryAgent",
instructions="You are a helpful assistant that remembers prior turns.",
context_providers=[history_provider],
default_options={"store": False},
) as agent,
):
# 3. Create a session (session_id is used as the partition key).
session = agent.create_session()
# 4. Run a multi-turn conversation; history is persisted by CosmosHistoryProvider.
response1 = await agent.run("My name is Ada and I enjoy distributed systems.", session=session)
print(f"Assistant: {response1.text}")
response2 = await agent.run("What do you remember about me?", session=session)
print(f"Assistant: {response2.text}")
print(f"Container: {history_provider.container_name}")
if __name__ == "__main__":
asyncio.run(main())
"""
Sample output:
Assistant: Nice to meet you, Ada! Distributed systems are a fascinating area.
Assistant: You told me your name is Ada and that you enjoy distributed systems.
Container: <AZURE_COSMOS_CONTAINER_NAME>
"""
@@ -0,0 +1,409 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
import os
import uuid
from collections.abc import AsyncIterator
from contextlib import suppress
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import agent_framework_azure_cosmos._history_provider as history_provider_module
import pytest
from agent_framework import AgentResponse, Message
from agent_framework._sessions import AgentSession, SessionContext
from agent_framework.exceptions import SettingNotFoundError
from agent_framework_azure_cosmos._history_provider import CosmosHistoryProvider
from azure.cosmos.aio import CosmosClient
from azure.cosmos.exceptions import CosmosResourceNotFoundError
skip_if_cosmos_integration_tests_disabled = pytest.mark.skipif(
any(
os.getenv(name, "") == ""
for name in (
"AZURE_COSMOS_ENDPOINT",
"AZURE_COSMOS_KEY",
"AZURE_COSMOS_DATABASE_NAME",
"AZURE_COSMOS_CONTAINER_NAME",
)
),
reason=(
"AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_KEY, AZURE_COSMOS_DATABASE_NAME, and "
"AZURE_COSMOS_CONTAINER_NAME are required for Cosmos integration tests."
),
)
def _to_async_iter(items: list[Any]) -> AsyncIterator[Any]:
async def _iterator() -> AsyncIterator[Any]:
for item in items:
yield item
return _iterator()
@pytest.fixture
def mock_container() -> MagicMock:
container = MagicMock()
container.query_items = MagicMock(return_value=_to_async_iter([]))
container.execute_item_batch = AsyncMock(return_value=[])
return container
@pytest.fixture
def mock_cosmos_client(mock_container: MagicMock) -> MagicMock:
database_client = MagicMock()
database_client.create_container_if_not_exists = AsyncMock(return_value=mock_container)
client = MagicMock()
client.get_database_client.return_value = database_client
client.close = AsyncMock()
return client
class TestCosmosHistoryProviderInit:
def test_uses_provided_container_client(self, mock_container: MagicMock) -> None:
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
assert provider.source_id == "mem"
assert provider.load_messages is True
assert provider.store_outputs is True
assert provider.store_inputs is True
assert provider.database_name == ""
assert provider.container_name == ""
def test_uses_provided_cosmos_client(self, mock_cosmos_client: MagicMock) -> None:
provider = CosmosHistoryProvider(
source_id="mem",
cosmos_client=mock_cosmos_client,
database_name="db1",
container_name="history",
)
mock_cosmos_client.get_database_client.assert_called_once_with("db1")
assert provider.database_name == "db1"
assert provider.container_name == "history"
def test_missing_required_settings_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.delenv("AZURE_COSMOS_ENDPOINT", raising=False)
monkeypatch.delenv("AZURE_COSMOS_DATABASE_NAME", raising=False)
monkeypatch.delenv("AZURE_COSMOS_CONTAINER_NAME", raising=False)
monkeypatch.delenv("AZURE_COSMOS_KEY", raising=False)
with pytest.raises(SettingNotFoundError, match="database_name"):
CosmosHistoryProvider()
def test_constructs_client_with_string_credential(
self, monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock
) -> None:
mock_factory = MagicMock(return_value=mock_cosmos_client)
monkeypatch.setattr(history_provider_module, "CosmosClient", mock_factory)
CosmosHistoryProvider(
endpoint="https://account.documents.azure.com:443/",
credential="key-123",
database_name="db1",
container_name="history",
)
mock_factory.assert_called_once()
kwargs = mock_factory.call_args.kwargs
assert kwargs["url"] == "https://account.documents.azure.com:443/"
assert kwargs["credential"] == "key-123"
class TestCosmosHistoryProviderContainerConfig:
async def test_provider_container_name_is_used(self, mock_cosmos_client: MagicMock) -> None:
provider = CosmosHistoryProvider(
source_id="mem",
cosmos_client=mock_cosmos_client,
database_name="db1",
container_name="custom-history",
)
await provider.get_messages("session-123")
database_client = mock_cosmos_client.get_database_client.return_value
assert database_client.create_container_if_not_exists.await_count == 1
kwargs = database_client.create_container_if_not_exists.await_args.kwargs
assert kwargs["id"] == "custom-history"
class TestCosmosHistoryProviderGetMessages:
async def test_returns_deserialized_messages(self, mock_container: MagicMock) -> None:
msg1 = Message(role="user", contents=["Hello"])
msg2 = Message(role="assistant", contents=["Hi"])
mock_container.query_items.return_value = _to_async_iter([
{"message": msg1.to_dict()},
{"message": msg2.to_dict()},
])
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
messages = await provider.get_messages("s1")
assert len(messages) == 2
assert messages[0].role == "user"
assert messages[0].text == "Hello"
assert messages[1].role == "assistant"
assert messages[1].text == "Hi"
query_kwargs = mock_container.query_items.call_args.kwargs
assert query_kwargs["partition_key"] == "s1"
assert query_kwargs["query"] == (
"SELECT c.message FROM c "
"WHERE c.session_id = @session_id AND c.source_id = @source_id "
"ORDER BY c.sort_key ASC"
)
assert query_kwargs["parameters"] == [
{"name": "@session_id", "value": "s1"},
{"name": "@source_id", "value": "mem"},
]
async def test_empty_returns_empty(self, mock_container: MagicMock) -> None:
mock_container.query_items.return_value = _to_async_iter([])
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
messages = await provider.get_messages("s1")
assert messages == []
async def test_none_session_id_generates_guid_partition_key(
self, mock_container: MagicMock, caplog: pytest.LogCaptureFixture
) -> None:
mock_container.query_items.return_value = _to_async_iter([])
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
with caplog.at_level("WARNING"):
await provider.get_messages(None)
query_kwargs = mock_container.query_items.call_args.kwargs
session_key = query_kwargs["partition_key"]
assert isinstance(session_key, str)
assert session_key != ""
assert session_key != "default"
uuid.UUID(session_key)
assert query_kwargs["parameters"] == [
{"name": "@session_id", "value": session_key},
{"name": "@source_id", "value": "mem"},
]
assert "Received empty session_id" in caplog.text
async def test_skips_non_dict_message_payload(self, mock_container: MagicMock) -> None:
mock_container.query_items.return_value = _to_async_iter([{"message": "bad"}, {"message": None}])
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
messages = await provider.get_messages("s1")
assert messages == []
class TestCosmosHistoryProviderListSessions:
async def test_list_sessions_returns_unique_sorted_ids(self, mock_container: MagicMock) -> None:
mock_container.query_items.return_value = _to_async_iter(["s2", "s1", "s1", "s3"])
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
sessions = await provider.list_sessions()
assert sessions == ["s1", "s2", "s3"]
kwargs = mock_container.query_items.call_args.kwargs
assert kwargs["query"] == "SELECT DISTINCT VALUE c.session_id FROM c WHERE c.source_id = @source_id"
assert kwargs["parameters"] == [{"name": "@source_id", "value": "mem"}]
class TestCosmosHistoryProviderSaveMessages:
async def test_saves_messages(self, mock_container: MagicMock) -> None:
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
messages = [Message(role="user", contents=["Hello"]), Message(role="assistant", contents=["Hi"])]
await provider.save_messages("s1", messages)
mock_container.execute_item_batch.assert_awaited_once()
batch_operations = mock_container.execute_item_batch.await_args.kwargs["batch_operations"]
assert len(batch_operations) == 2
first_operation, first_args = batch_operations[0]
assert first_operation == "upsert"
first_document = first_args[0]
assert first_document["session_id"] == "s1"
assert first_document["message"]["role"] == "user"
assert mock_container.execute_item_batch.await_args.kwargs["partition_key"] == "s1"
async def test_empty_messages_noop(self, mock_container: MagicMock) -> None:
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
await provider.save_messages("s1", [])
mock_container.execute_item_batch.assert_not_awaited()
async def test_batches_when_message_count_exceeds_limit(self, mock_container: MagicMock) -> None:
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
messages = [Message(role="user", contents=[f"msg-{index}"]) for index in range(101)]
await provider.save_messages("s1", messages)
assert mock_container.execute_item_batch.await_count == 2
first_call = mock_container.execute_item_batch.await_args_list[0].kwargs
second_call = mock_container.execute_item_batch.await_args_list[1].kwargs
assert len(first_call["batch_operations"]) == 100
assert len(second_call["batch_operations"]) == 1
assert first_call["partition_key"] == "s1"
assert second_call["partition_key"] == "s1"
class TestCosmosHistoryProviderClear:
async def test_clear_deletes_all_session_items(self, mock_container: MagicMock) -> None:
mock_container.query_items.return_value = _to_async_iter([{"id": "1"}, {"id": "2"}])
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
await provider.clear("s1")
mock_container.execute_item_batch.assert_awaited_once()
batch_operations = mock_container.execute_item_batch.await_args.kwargs["batch_operations"]
assert len(batch_operations) == 2
assert batch_operations[0] == ("delete", ("1",))
assert batch_operations[1] == ("delete", ("2",))
assert mock_container.execute_item_batch.await_args.kwargs["partition_key"] == "s1"
query_kwargs = mock_container.query_items.call_args.kwargs
assert query_kwargs["query"] == (
"SELECT c.id FROM c WHERE c.session_id = @session_id AND c.source_id = @source_id"
)
assert query_kwargs["parameters"] == [
{"name": "@session_id", "value": "s1"},
{"name": "@source_id", "value": "mem"},
]
class TestCosmosHistoryProviderBeforeAfterRun:
async def test_before_run_loads_history(self, mock_container: MagicMock) -> None:
msg = Message(role="user", contents=["old msg"])
mock_container.query_items.return_value = _to_async_iter([{"message": msg.to_dict()}])
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
session = AgentSession(session_id="test")
context = SessionContext(input_messages=[Message(role="user", contents=["new msg"])], session_id="s1")
await provider.before_run(
agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {})
) # type: ignore[arg-type]
assert "mem" in context.context_messages
assert context.context_messages["mem"][0].text == "old msg"
async def test_after_run_stores_input_and_response(self, mock_container: MagicMock) -> None:
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
session = AgentSession(session_id="test")
context = SessionContext(input_messages=[Message(role="user", contents=["hi"])], session_id="s1")
context._response = AgentResponse(messages=[Message(role="assistant", contents=["hello"])])
await provider.after_run(
agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {})
) # type: ignore[arg-type]
mock_container.execute_item_batch.assert_awaited_once()
batch_operations = mock_container.execute_item_batch.await_args.kwargs["batch_operations"]
assert len(batch_operations) == 2
input_doc = batch_operations[0][1][0]
response_doc = batch_operations[1][1][0]
assert input_doc["message"]["role"] == "user"
assert input_doc["message"]["contents"][0]["text"] == "hi"
assert response_doc["message"]["role"] == "assistant"
assert response_doc["message"]["contents"][0]["text"] == "hello"
class TestCosmosHistoryProviderClose:
async def test_close_closes_owned_client(
self, monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock
) -> None:
mock_factory = MagicMock(return_value=mock_cosmos_client)
monkeypatch.setattr(history_provider_module, "CosmosClient", mock_factory)
provider = CosmosHistoryProvider(
endpoint="https://account.documents.azure.com:443/",
credential="key-123",
database_name="db1",
container_name="history",
)
await provider.close()
mock_cosmos_client.close.assert_awaited_once()
async def test_close_does_not_close_external_client(self, mock_cosmos_client: MagicMock) -> None:
provider = CosmosHistoryProvider(
source_id="mem",
cosmos_client=mock_cosmos_client,
database_name="db1",
container_name="history",
)
await provider.close()
mock_cosmos_client.close.assert_not_awaited()
async def test_async_context_manager_closes_owned_client(
self, monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock
) -> None:
mock_factory = MagicMock(return_value=mock_cosmos_client)
monkeypatch.setattr(history_provider_module, "CosmosClient", mock_factory)
async with CosmosHistoryProvider(
endpoint="https://account.documents.azure.com:443/",
credential="key-123",
database_name="db1",
container_name="history",
) as provider:
assert provider is not None
mock_cosmos_client.close.assert_awaited_once()
async def test_async_context_manager_preserves_original_exception(self, mock_container: MagicMock) -> None:
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
with patch.object(
provider, "close", AsyncMock(side_effect=RuntimeError("close failed"))
), pytest.raises(ValueError, match="inner error"):
async with provider:
raise ValueError("inner error")
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_cosmos_integration_tests_disabled
async def test_cosmos_history_provider_roundtrip_with_emulator() -> None:
endpoint = os.getenv("AZURE_COSMOS_ENDPOINT", "")
key = os.getenv("AZURE_COSMOS_KEY", "")
database_prefix = os.getenv("AZURE_COSMOS_DATABASE_NAME", "")
container_prefix = os.getenv("AZURE_COSMOS_CONTAINER_NAME", "")
unique = uuid.uuid4().hex[:8]
database_name = f"{database_prefix}-{unique}"
container_name = f"{container_prefix}-{unique}"
session_id = f"session-{unique}"
async with CosmosClient(url=endpoint, credential=key) as cosmos_client:
await cosmos_client.create_database_if_not_exists(id=database_name)
provider = CosmosHistoryProvider(
source_id="cosmos_integration",
cosmos_client=cosmos_client,
database_name=database_name,
container_name=container_name,
)
try:
await provider.save_messages(
session_id,
[
Message(role="user", contents=["Hello Cosmos"]),
Message(role="assistant", contents=["Hi from Cosmos"]),
],
)
stored_messages = await provider.get_messages(session_id)
assert [message.role for message in stored_messages] == ["user", "assistant"]
assert [message.text for message in stored_messages] == ["Hello Cosmos", "Hi from Cosmos"]
sessions = await provider.list_sessions()
assert session_id in sessions
await provider.clear(session_id)
assert await provider.get_messages(session_id) == []
finally:
with suppress(CosmosResourceNotFoundError):
await cosmos_client.delete_database(database_name)
+3
View File
@@ -76,6 +76,7 @@ agent-framework-core = { workspace = true }
agent-framework-a2a = { workspace = true }
agent-framework-ag-ui = { workspace = true }
agent-framework-azure-ai-search = { workspace = true }
agent-framework-azure-cosmos = { workspace = true }
agent-framework-anthropic = { workspace = true }
agent-framework-azure-ai = { workspace = true }
agent-framework-azurefunctions = { workspace = true }
@@ -238,6 +239,7 @@ check = ["check-packages", "samples-lint", "samples-syntax", "test", "markdown-c
[tool.poe.tasks.all-tests-cov]
cmd = """
pytest --import-mode=importlib
-m "not integration"
--cov=agent_framework
--cov=agent_framework_core
--cov=agent_framework_a2a
@@ -265,6 +267,7 @@ pytest --import-mode=importlib
[tool.poe.tasks.all-tests]
cmd = """
pytest --import-mode=importlib
-m "not integration"
--ignore-glob=packages/lab/**
--ignore-glob=packages/devui/**
-rs
+508 -444
View File
File diff suppressed because it is too large Load Diff