mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add Azure Cosmos history provider package (#4271)
* Created cosmos history provider * add marker * Python: address Cosmos PR feedback - address provider/test/sample review feedback and cleanup typing - add cosmos integration test coverage and skip gating - add dedicated cosmos emulator jobs to python merge/integration workflows - switch cosmos workflow execution to package poe integration-tests task Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Python: handle empty Cosmos session id - replace default partition fallback for empty session_id - log warning and generate GUID when session_id is empty - update unit tests to validate GUID fallback behavior Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix sample * fix cross partition query --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
945933c351
commit
c37f74f898
@@ -247,6 +247,51 @@ jobs:
|
||||
timeout-minutes: 15
|
||||
run: uv run --directory packages/azure-ai poe integration-tests -n logical --dist worksteal --timeout=120 --session-timeout=900 --timeout_method thread --retries 2 --retry-delay 5
|
||||
|
||||
# Azure Cosmos integration tests
|
||||
python-tests-cosmos:
|
||||
name: Python Integration Tests - Cosmos
|
||||
runs-on: ubuntu-latest
|
||||
environment: integration
|
||||
timeout-minutes: 60
|
||||
services:
|
||||
cosmosdb:
|
||||
image: mcr.microsoft.com/cosmosdb/linux/azure-cosmos-emulator:vnext-preview
|
||||
ports:
|
||||
- 8081:8081
|
||||
env:
|
||||
AZURE_COSMOS_ENDPOINT: "http://localhost:8081/"
|
||||
# Static Azure Cosmos DB emulator key (documented): https://learn.microsoft.com/en-us/azure/cosmos-db/emulator
|
||||
AZURE_COSMOS_KEY: "C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="
|
||||
AZURE_COSMOS_DATABASE_NAME: "agent-framework-cosmos-it-db"
|
||||
AZURE_COSMOS_CONTAINER_NAME: "agent-framework-cosmos-it-container"
|
||||
defaults:
|
||||
run:
|
||||
working-directory: python
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ inputs.checkout-ref }}
|
||||
persist-credentials: false
|
||||
- name: Set up python and install the project
|
||||
id: python-setup
|
||||
uses: ./.github/actions/python-setup
|
||||
with:
|
||||
python-version: ${{ env.UV_PYTHON }}
|
||||
os: ${{ runner.os }}
|
||||
- name: Wait for Cosmos DB emulator
|
||||
run: |
|
||||
for i in {1..60}; do
|
||||
if curl --silent --show-error http://localhost:8081/ > /dev/null; then
|
||||
echo "Cosmos DB emulator is ready."
|
||||
exit 0
|
||||
fi
|
||||
sleep 2
|
||||
done
|
||||
echo "Cosmos DB emulator did not become ready in time." >&2
|
||||
exit 1
|
||||
- name: Test with pytest (Cosmos integration)
|
||||
run: uv run --directory packages/azure-cosmos poe integration-tests -n logical --dist worksteal --timeout=120 --session-timeout=900 --timeout_method thread --retries 2 --retry-delay 5
|
||||
|
||||
python-integration-tests-check:
|
||||
if: always()
|
||||
runs-on: ubuntu-latest
|
||||
@@ -257,7 +302,8 @@ jobs:
|
||||
python-tests-azure-openai,
|
||||
python-tests-misc-integration,
|
||||
python-tests-functions,
|
||||
python-tests-azure-ai
|
||||
python-tests-azure-ai,
|
||||
python-tests-cosmos
|
||||
]
|
||||
steps:
|
||||
- name: Fail workflow if tests failed
|
||||
|
||||
@@ -38,6 +38,7 @@ jobs:
|
||||
miscChanged: ${{ steps.filter.outputs.misc }}
|
||||
functionsChanged: ${{ steps.filter.outputs.functions }}
|
||||
azureAiChanged: ${{ steps.filter.outputs.azure-ai }}
|
||||
cosmosChanged: ${{ steps.filter.outputs.cosmos }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: dorny/paths-filter@v3
|
||||
@@ -67,6 +68,8 @@ jobs:
|
||||
- 'python/packages/durabletask/**'
|
||||
azure-ai:
|
||||
- 'python/packages/azure-ai/**'
|
||||
cosmos:
|
||||
- 'python/packages/azure-cosmos/**'
|
||||
# run only if 'python' files were changed
|
||||
- name: python tests
|
||||
if: steps.filter.outputs.python == 'true'
|
||||
@@ -390,6 +393,64 @@ jobs:
|
||||
|
||||
# TODO: Add python-tests-lab
|
||||
|
||||
# Azure Cosmos integration tests
|
||||
python-tests-cosmos:
|
||||
name: Python Tests - Cosmos Integration
|
||||
needs: paths-filter
|
||||
if: >
|
||||
github.event_name != 'pull_request' &&
|
||||
needs.paths-filter.outputs.pythonChanges == 'true' &&
|
||||
(github.event_name != 'merge_group' ||
|
||||
needs.paths-filter.outputs.cosmosChanged == 'true' ||
|
||||
needs.paths-filter.outputs.coreChanged == 'true')
|
||||
runs-on: ubuntu-latest
|
||||
environment: integration
|
||||
services:
|
||||
cosmosdb:
|
||||
image: mcr.microsoft.com/cosmosdb/linux/azure-cosmos-emulator:vnext-preview
|
||||
ports:
|
||||
- 8081:8081
|
||||
env:
|
||||
AZURE_COSMOS_ENDPOINT: "http://localhost:8081/"
|
||||
# Static Azure Cosmos DB emulator key (documented): https://learn.microsoft.com/en-us/azure/cosmos-db/emulator
|
||||
AZURE_COSMOS_KEY: "C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="
|
||||
AZURE_COSMOS_DATABASE_NAME: "agent-framework-cosmos-it-db"
|
||||
AZURE_COSMOS_CONTAINER_NAME: "agent-framework-cosmos-it-container"
|
||||
defaults:
|
||||
run:
|
||||
working-directory: python
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Set up python and install the project
|
||||
id: python-setup
|
||||
uses: ./.github/actions/python-setup
|
||||
with:
|
||||
python-version: ${{ env.UV_PYTHON }}
|
||||
os: ${{ runner.os }}
|
||||
- name: Wait for Cosmos DB emulator
|
||||
run: |
|
||||
for i in {1..60}; do
|
||||
if curl --silent --show-error http://localhost:8081/ > /dev/null; then
|
||||
echo "Cosmos DB emulator is ready."
|
||||
exit 0
|
||||
fi
|
||||
sleep 2
|
||||
done
|
||||
echo "Cosmos DB emulator did not become ready in time." >&2
|
||||
exit 1
|
||||
- name: Test with pytest (Cosmos integration)
|
||||
run: uv run --directory packages/azure-cosmos poe integration-tests -n logical --dist worksteal --timeout=120 --session-timeout=900 --timeout_method thread --retries 2 --retry-delay 5
|
||||
working-directory: ./python
|
||||
- name: Surface failing tests
|
||||
if: always()
|
||||
uses: pmeier/pytest-results-action@v0.7.2
|
||||
with:
|
||||
path: ./python/**.xml
|
||||
summary: true
|
||||
display-options: fEX
|
||||
fail-on-empty: false
|
||||
title: Cosmos integration test results
|
||||
|
||||
python-integration-tests-check:
|
||||
if: always()
|
||||
runs-on: ubuntu-latest
|
||||
@@ -401,6 +462,7 @@ jobs:
|
||||
python-tests-misc-integration,
|
||||
python-tests-functions,
|
||||
python-tests-azure-ai,
|
||||
python-tests-cosmos,
|
||||
]
|
||||
steps:
|
||||
- name: Fail workflow if tests failed
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
# Azure Cosmos DB Package (agent-framework-azure-cosmos)
|
||||
|
||||
Azure Cosmos DB history provider integration for Agent Framework.
|
||||
|
||||
## Main Classes
|
||||
|
||||
- **`CosmosHistoryProvider`** - Persistent conversation history storage backed by Azure Cosmos DB
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
from agent_framework_azure_cosmos import CosmosHistoryProvider
|
||||
|
||||
provider = CosmosHistoryProvider(
|
||||
endpoint="https://<account>.documents.azure.com:443/",
|
||||
credential="<key-or-token-credential>",
|
||||
database_name="agent-framework",
|
||||
container_name="chat-history",
|
||||
)
|
||||
```
|
||||
|
||||
Container name is configured on the provider. `session_id` is used as the partition key.
|
||||
|
||||
## Import Path
|
||||
|
||||
```python
|
||||
from agent_framework_azure_cosmos import CosmosHistoryProvider
|
||||
```
|
||||
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) Microsoft Corporation.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE
|
||||
@@ -0,0 +1,38 @@
|
||||
# Get Started with Microsoft Agent Framework Azure Cosmos DB
|
||||
|
||||
Please install this package via pip:
|
||||
|
||||
```bash
|
||||
pip install agent-framework-azure-cosmos --pre
|
||||
```
|
||||
|
||||
## Azure Cosmos DB History Provider
|
||||
|
||||
The Azure Cosmos DB integration provides `CosmosHistoryProvider` for persistent conversation history storage.
|
||||
|
||||
### Basic Usage Example
|
||||
|
||||
```python
|
||||
from azure.identity.aio import DefaultAzureCredential
|
||||
from agent_framework_azure_cosmos import CosmosHistoryProvider
|
||||
|
||||
provider = CosmosHistoryProvider(
|
||||
endpoint="https://<account>.documents.azure.com:443/",
|
||||
credential=DefaultAzureCredential(),
|
||||
database_name="agent-framework",
|
||||
container_name="chat-history",
|
||||
)
|
||||
```
|
||||
|
||||
Credentials follow the same pattern used by other Azure connectors in the repository:
|
||||
|
||||
- Pass a credential object (for example `DefaultAzureCredential`)
|
||||
- Or pass a key string directly
|
||||
- Or set `AZURE_COSMOS_KEY` in the environment
|
||||
|
||||
Container naming behavior:
|
||||
|
||||
- Container name is configured on the provider (`container_name` or `AZURE_COSMOS_CONTAINER_NAME`)
|
||||
- `session_id` is used as the Cosmos partition key for reads/writes
|
||||
|
||||
See `samples/cosmos_history_provider.py` for a runnable package-local example.
|
||||
@@ -0,0 +1,15 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import importlib.metadata
|
||||
|
||||
from ._history_provider import CosmosHistoryProvider
|
||||
|
||||
try:
|
||||
__version__ = importlib.metadata.version(__name__)
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
__version__ = "0.0.0" # Fallback for development mode
|
||||
|
||||
__all__ = [
|
||||
"CosmosHistoryProvider",
|
||||
"__version__",
|
||||
]
|
||||
@@ -0,0 +1,269 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Azure Cosmos DB history provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, ClassVar, TypedDict
|
||||
|
||||
from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Message
|
||||
from agent_framework._sessions import BaseHistoryProvider
|
||||
from agent_framework._settings import SecretString, load_settings
|
||||
from agent_framework.azure._entra_id_authentication import AzureCredentialTypes
|
||||
from azure.cosmos import PartitionKey
|
||||
from azure.cosmos.aio import ContainerProxy, CosmosClient, DatabaseProxy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AzureCosmosHistorySettings(TypedDict, total=False):
|
||||
"""Settings for CosmosHistoryProvider resolved from args and environment."""
|
||||
|
||||
endpoint: str | None
|
||||
database_name: str | None
|
||||
container_name: str | None
|
||||
key: SecretString | None
|
||||
|
||||
|
||||
class CosmosHistoryProvider(BaseHistoryProvider):
|
||||
"""Azure Cosmos DB-backed history provider using BaseHistoryProvider hooks."""
|
||||
|
||||
DEFAULT_SOURCE_ID: ClassVar[str] = "azure_cosmos_history"
|
||||
_BATCH_OPERATION_LIMIT: ClassVar[int] = 100
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source_id: str = DEFAULT_SOURCE_ID,
|
||||
*,
|
||||
load_messages: bool = True,
|
||||
store_outputs: bool = True,
|
||||
store_inputs: bool = True,
|
||||
store_context_messages: bool = False,
|
||||
store_context_from: set[str] | None = None,
|
||||
endpoint: str | None = None,
|
||||
database_name: str | None = None,
|
||||
container_name: str | None = None,
|
||||
credential: str | AzureCredentialTypes | None = None,
|
||||
cosmos_client: CosmosClient | None = None,
|
||||
container_client: ContainerProxy | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize the Azure Cosmos DB history provider.
|
||||
|
||||
Args:
|
||||
source_id: Unique identifier for this provider instance.
|
||||
load_messages: Whether to load messages before invocation.
|
||||
store_outputs: Whether to store response messages.
|
||||
store_inputs: Whether to store input messages.
|
||||
store_context_messages: Whether to store context from other providers.
|
||||
store_context_from: If set, only store context from these source_ids.
|
||||
endpoint: Cosmos DB account endpoint.
|
||||
Can be set via ``AZURE_COSMOS_ENDPOINT``.
|
||||
database_name: Cosmos DB database name.
|
||||
Can be set via ``AZURE_COSMOS_DATABASE_NAME``.
|
||||
container_name: Cosmos DB container name.
|
||||
Can be set via ``AZURE_COSMOS_CONTAINER_NAME``.
|
||||
credential: Credential to authenticate with Cosmos DB.
|
||||
Supports key string and Azure credential objects.
|
||||
Can be set via ``AZURE_COSMOS_KEY`` when omitted.
|
||||
cosmos_client: Pre-created Cosmos async client.
|
||||
container_client: Pre-created Cosmos container client for fixed-container usage.
|
||||
env_file_path: Path to environment file for loading settings.
|
||||
env_file_encoding: Encoding of the environment file.
|
||||
"""
|
||||
super().__init__(
|
||||
source_id,
|
||||
load_messages=load_messages,
|
||||
store_outputs=store_outputs,
|
||||
store_inputs=store_inputs,
|
||||
store_context_messages=store_context_messages,
|
||||
store_context_from=store_context_from,
|
||||
)
|
||||
|
||||
self._cosmos_client: CosmosClient | None = cosmos_client
|
||||
self._container_proxy: ContainerProxy | None = container_client
|
||||
self._owns_client = False
|
||||
self._database_client: DatabaseProxy | None = None
|
||||
|
||||
if self._container_proxy is not None:
|
||||
self.database_name: str = database_name or ""
|
||||
self.container_name: str = container_name or ""
|
||||
return
|
||||
|
||||
required_fields: list[str] = ["database_name", "container_name"]
|
||||
if cosmos_client is None:
|
||||
required_fields.append("endpoint")
|
||||
if credential is None:
|
||||
required_fields.append("key")
|
||||
|
||||
settings = load_settings(
|
||||
AzureCosmosHistorySettings,
|
||||
env_prefix="AZURE_COSMOS_",
|
||||
required_fields=required_fields,
|
||||
endpoint=endpoint,
|
||||
database_name=database_name,
|
||||
container_name=container_name,
|
||||
key=credential if isinstance(credential, str) else None,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
self.database_name = settings["database_name"] # type: ignore[assignment]
|
||||
self.container_name = settings["container_name"] # type: ignore[assignment]
|
||||
if self._cosmos_client is None:
|
||||
self._cosmos_client = CosmosClient(
|
||||
url=settings["endpoint"], # type: ignore[arg-type]
|
||||
credential=credential or settings["key"].get_secret_value(), # type: ignore[arg-type,union-attr]
|
||||
user_agent_suffix=AGENT_FRAMEWORK_USER_AGENT,
|
||||
)
|
||||
self._owns_client = True
|
||||
|
||||
self._database_client = self._cosmos_client.get_database_client(self.database_name)
|
||||
|
||||
|
||||
async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Message]:
|
||||
"""Retrieve stored messages for this session from Azure Cosmos DB."""
|
||||
await self._ensure_container_proxy()
|
||||
session_key = self._session_partition_key(session_id)
|
||||
|
||||
query = (
|
||||
"SELECT c.message FROM c "
|
||||
"WHERE c.session_id = @session_id AND c.source_id = @source_id "
|
||||
"ORDER BY c.sort_key ASC"
|
||||
)
|
||||
parameters: list[dict[str, object]] = [
|
||||
{"name": "@session_id", "value": session_key},
|
||||
{"name": "@source_id", "value": self.source_id},
|
||||
]
|
||||
items = self._container_proxy.query_items( # type: ignore[union-attr]
|
||||
query=query, parameters=parameters, partition_key=session_key
|
||||
)
|
||||
|
||||
messages: list[Message] = []
|
||||
async for item in items:
|
||||
message_payload = item.get("message")
|
||||
if isinstance(message_payload, dict):
|
||||
messages.append(Message.from_dict(message_payload))
|
||||
|
||||
return messages
|
||||
|
||||
async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs: Any) -> None:
|
||||
"""Persist messages for this session to Azure Cosmos DB."""
|
||||
if not messages:
|
||||
return
|
||||
|
||||
await self._ensure_container_proxy()
|
||||
session_key = self._session_partition_key(session_id)
|
||||
|
||||
base_sort_key = time.time_ns()
|
||||
operations: list[tuple[str, tuple[dict[str, Any]]]] = []
|
||||
for index, message in enumerate(messages):
|
||||
document = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"session_id": session_key,
|
||||
"sort_key": base_sort_key + index,
|
||||
"source_id": self.source_id,
|
||||
"message": message.to_dict(),
|
||||
}
|
||||
operations.append(("upsert", (document,)))
|
||||
|
||||
for start in range(0, len(operations), self._BATCH_OPERATION_LIMIT):
|
||||
batch = operations[start : start + self._BATCH_OPERATION_LIMIT]
|
||||
await self._container_proxy.execute_item_batch( # type: ignore[union-attr]
|
||||
batch_operations=batch, partition_key=session_key
|
||||
)
|
||||
|
||||
async def clear(self, session_id: str | None) -> None:
|
||||
"""Clear all messages for a session from Azure Cosmos DB."""
|
||||
await self._ensure_container_proxy()
|
||||
session_key = self._session_partition_key(session_id)
|
||||
query = "SELECT c.id FROM c WHERE c.session_id = @session_id AND c.source_id = @source_id"
|
||||
parameters: list[dict[str, object]] = [
|
||||
{"name": "@session_id", "value": session_key},
|
||||
{"name": "@source_id", "value": self.source_id},
|
||||
]
|
||||
items = self._container_proxy.query_items( # type: ignore[union-attr]
|
||||
query=query, parameters=parameters, partition_key=session_key
|
||||
)
|
||||
|
||||
delete_operations: list[tuple[str, tuple[str]]] = []
|
||||
async for item in items:
|
||||
item_id = item.get("id")
|
||||
if isinstance(item_id, str):
|
||||
delete_operations.append(("delete", (item_id,)))
|
||||
|
||||
for start in range(0, len(delete_operations), self._BATCH_OPERATION_LIMIT):
|
||||
batch = delete_operations[start : start + self._BATCH_OPERATION_LIMIT]
|
||||
await self._container_proxy.execute_item_batch( # type: ignore[union-attr]
|
||||
batch_operations=batch, partition_key=session_key
|
||||
)
|
||||
|
||||
async def list_sessions(self) -> list[str]:
|
||||
"""List all session IDs stored in this provider's Cosmos container."""
|
||||
await self._ensure_container_proxy()
|
||||
query = (
|
||||
"SELECT DISTINCT VALUE c.session_id FROM c WHERE c.source_id = @source_id"
|
||||
)
|
||||
parameters: list[dict[str, object]] = [
|
||||
{"name": "@source_id", "value": self.source_id}
|
||||
]
|
||||
# without a partition key, it is automatically a cross-partition query
|
||||
items = self._container_proxy.query_items(query=query, parameters=parameters) # type: ignore[union-attr]
|
||||
|
||||
session_ids: set[str] = set()
|
||||
async for item in items:
|
||||
if isinstance(item, str):
|
||||
session_ids.add(item)
|
||||
return sorted(session_ids)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the underlying Cosmos client when this provider owns it."""
|
||||
if self._owns_client and self._cosmos_client is not None:
|
||||
await self._cosmos_client.close()
|
||||
|
||||
async def __aenter__(self) -> CosmosHistoryProvider:
|
||||
"""Async context manager entry."""
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: Any,
|
||||
) -> None:
|
||||
"""Async context manager exit."""
|
||||
try:
|
||||
await self.close()
|
||||
except Exception:
|
||||
if exc_type is None:
|
||||
raise
|
||||
|
||||
async def _ensure_container_proxy(self) -> None:
|
||||
"""Get or create the Cosmos DB container for storing messages."""
|
||||
if self._container_proxy is not None:
|
||||
return
|
||||
if self._database_client is None:
|
||||
raise RuntimeError("Cosmos database client is not initialized.")
|
||||
|
||||
self._container_proxy = (
|
||||
await self._database_client.create_container_if_not_exists(
|
||||
id=self.container_name,
|
||||
partition_key=PartitionKey(path="/session_id"),
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _session_partition_key(session_id: str | None) -> str:
|
||||
if session_id:
|
||||
return session_id
|
||||
|
||||
generated_session_id = str(uuid.uuid4())
|
||||
logger.warning(
|
||||
"Received empty session_id; generated temporary session id '%s' for Cosmos partition key.",
|
||||
generated_session_id,
|
||||
)
|
||||
return generated_session_id
|
||||
@@ -0,0 +1,93 @@
|
||||
[project]
|
||||
name = "agent-framework-azure-cosmos"
|
||||
description = "Azure Cosmos DB history provider integration for Microsoft Agent Framework."
|
||||
authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}]
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
version = "1.0.0b260219"
|
||||
license-files = ["LICENSE"]
|
||||
urls.homepage = "https://aka.ms/agent-framework"
|
||||
urls.source = "https://github.com/microsoft/agent-framework/tree/main/python"
|
||||
urls.release_notes = "https://github.com/microsoft/agent-framework/releases?q=tag%3Apython-1&expanded=true"
|
||||
urls.issues = "https://github.com/microsoft/agent-framework/issues"
|
||||
classifiers = [
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Programming Language :: Python :: 3.14",
|
||||
"Typing :: Typed",
|
||||
]
|
||||
dependencies = [
|
||||
"agent-framework-core>=1.0.0rc1",
|
||||
"azure-cosmos>=4.9.0",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
prerelease = "if-necessary-or-explicit"
|
||||
environments = [
|
||||
"sys_platform == 'darwin'",
|
||||
"sys_platform == 'linux'",
|
||||
"sys_platform == 'win32'"
|
||||
]
|
||||
|
||||
[tool.uv-dynamic-versioning]
|
||||
fallback-version = "0.0.0"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
addopts = "-ra -q -r fEX"
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
filterwarnings = [
|
||||
"ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*"
|
||||
]
|
||||
timeout = 120
|
||||
markers = [
|
||||
"integration: marks tests as integration tests that require external services",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
extend = "../../pyproject.toml"
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = [
|
||||
"**/__init__.py"
|
||||
]
|
||||
|
||||
[tool.pyright]
|
||||
extends = "../../pyproject.toml"
|
||||
|
||||
[tool.mypy]
|
||||
plugins = ['pydantic.mypy']
|
||||
strict = true
|
||||
python_version = "3.10"
|
||||
ignore_missing_imports = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
check_untyped_defs = true
|
||||
warn_return_any = true
|
||||
show_error_codes = true
|
||||
warn_unused_ignores = false
|
||||
disallow_incomplete_defs = true
|
||||
disallow_untyped_decorators = true
|
||||
|
||||
[tool.bandit]
|
||||
targets = ["agent_framework_azure_cosmos"]
|
||||
exclude_dirs = ["tests"]
|
||||
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_azure_cosmos"
|
||||
test = "pytest --cov=agent_framework_azure_cosmos --cov-report=term-missing:skip-covered tests"
|
||||
integration-tests = "pytest tests/test_cosmos_history_provider.py -m integration"
|
||||
|
||||
[build-system]
|
||||
requires = ["flit-core >= 3.11,<4.0"]
|
||||
build-backend = "flit_core.buildapi"
|
||||
@@ -0,0 +1,20 @@
|
||||
# Azure Cosmos DB Package Samples
|
||||
|
||||
This folder contains samples for `agent-framework-azure-cosmos`.
|
||||
|
||||
| File | Description |
|
||||
| --- | --- |
|
||||
| [`cosmos_history_provider.py`](cosmos_history_provider.py) | Demonstrates an Agent using `CosmosHistoryProvider` with `AzureOpenAIResponsesClient` (project endpoint), provider-configured container name, and `session_id` partitioning. |
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- `AZURE_COSMOS_ENDPOINT`
|
||||
- `AZURE_COSMOS_DATABASE_NAME`
|
||||
- `AZURE_COSMOS_CONTAINER_NAME`
|
||||
- `AZURE_COSMOS_KEY` (or equivalent credential flow)
|
||||
|
||||
## Run
|
||||
|
||||
```bash
|
||||
uv run --directory packages/azure-cosmos python samples/cosmos_history_provider.py
|
||||
```
|
||||
@@ -0,0 +1,3 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Samples for the Azure Cosmos history provider package."""
|
||||
@@ -0,0 +1,100 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
# ruff: noqa: T201
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from agent_framework.azure import AzureOpenAIResponsesClient
|
||||
from agent_framework_azure_cosmos import CosmosHistoryProvider
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file.
|
||||
load_dotenv()
|
||||
|
||||
"""
|
||||
This sample demonstrates CosmosHistoryProvider as an agent context provider.
|
||||
|
||||
Key components:
|
||||
- AzureOpenAIResponsesClient configured with an Azure AI project endpoint
|
||||
- CosmosHistoryProvider configured for Cosmos DB-backed message history
|
||||
- Provider-configured container name with session_id as partition key
|
||||
|
||||
Environment variables:
|
||||
AZURE_AI_PROJECT_ENDPOINT
|
||||
AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME
|
||||
AZURE_COSMOS_ENDPOINT
|
||||
AZURE_COSMOS_DATABASE_NAME
|
||||
AZURE_COSMOS_CONTAINER_NAME
|
||||
Optional:
|
||||
AZURE_COSMOS_KEY
|
||||
"""
|
||||
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Run the Cosmos history provider sample with an Agent."""
|
||||
project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT")
|
||||
deployment_name = os.getenv("AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME")
|
||||
cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT")
|
||||
cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME")
|
||||
cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME")
|
||||
cosmos_key = os.getenv("AZURE_COSMOS_KEY")
|
||||
|
||||
if (
|
||||
not project_endpoint
|
||||
or not deployment_name
|
||||
or not cosmos_endpoint
|
||||
or not cosmos_database_name
|
||||
or not cosmos_container_name
|
||||
):
|
||||
print(
|
||||
"Please set AZURE_AI_PROJECT_ENDPOINT, AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME, "
|
||||
"AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, and AZURE_COSMOS_CONTAINER_NAME."
|
||||
)
|
||||
return
|
||||
|
||||
# 1. Create an Azure credential and Responses client using project endpoint auth.
|
||||
async with AzureCliCredential() as credential:
|
||||
client = AzureOpenAIResponsesClient(
|
||||
project_endpoint=project_endpoint,
|
||||
deployment_name=deployment_name,
|
||||
credential=credential,
|
||||
)
|
||||
|
||||
# 2. Create an agent that uses the history provider as a context provider.
|
||||
async with (
|
||||
CosmosHistoryProvider(
|
||||
endpoint=cosmos_endpoint,
|
||||
database_name=cosmos_database_name,
|
||||
container_name=cosmos_container_name,
|
||||
credential=cosmos_key or credential,
|
||||
) as history_provider,
|
||||
client.as_agent(
|
||||
name="CosmosHistoryAgent",
|
||||
instructions="You are a helpful assistant that remembers prior turns.",
|
||||
context_providers=[history_provider],
|
||||
default_options={"store": False},
|
||||
) as agent,
|
||||
):
|
||||
# 3. Create a session (session_id is used as the partition key).
|
||||
session = agent.create_session()
|
||||
|
||||
# 4. Run a multi-turn conversation; history is persisted by CosmosHistoryProvider.
|
||||
response1 = await agent.run("My name is Ada and I enjoy distributed systems.", session=session)
|
||||
print(f"Assistant: {response1.text}")
|
||||
|
||||
response2 = await agent.run("What do you remember about me?", session=session)
|
||||
print(f"Assistant: {response2.text}")
|
||||
print(f"Container: {history_provider.container_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
"""
|
||||
Sample output:
|
||||
Assistant: Nice to meet you, Ada! Distributed systems are a fascinating area.
|
||||
Assistant: You told me your name is Ada and that you enjoy distributed systems.
|
||||
Container: <AZURE_COSMOS_CONTAINER_NAME>
|
||||
"""
|
||||
@@ -0,0 +1,409 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import suppress
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import agent_framework_azure_cosmos._history_provider as history_provider_module
|
||||
import pytest
|
||||
from agent_framework import AgentResponse, Message
|
||||
from agent_framework._sessions import AgentSession, SessionContext
|
||||
from agent_framework.exceptions import SettingNotFoundError
|
||||
from agent_framework_azure_cosmos._history_provider import CosmosHistoryProvider
|
||||
from azure.cosmos.aio import CosmosClient
|
||||
from azure.cosmos.exceptions import CosmosResourceNotFoundError
|
||||
|
||||
skip_if_cosmos_integration_tests_disabled = pytest.mark.skipif(
|
||||
any(
|
||||
os.getenv(name, "") == ""
|
||||
for name in (
|
||||
"AZURE_COSMOS_ENDPOINT",
|
||||
"AZURE_COSMOS_KEY",
|
||||
"AZURE_COSMOS_DATABASE_NAME",
|
||||
"AZURE_COSMOS_CONTAINER_NAME",
|
||||
)
|
||||
),
|
||||
reason=(
|
||||
"AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_KEY, AZURE_COSMOS_DATABASE_NAME, and "
|
||||
"AZURE_COSMOS_CONTAINER_NAME are required for Cosmos integration tests."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _to_async_iter(items: list[Any]) -> AsyncIterator[Any]:
|
||||
async def _iterator() -> AsyncIterator[Any]:
|
||||
for item in items:
|
||||
yield item
|
||||
|
||||
return _iterator()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_container() -> MagicMock:
|
||||
container = MagicMock()
|
||||
container.query_items = MagicMock(return_value=_to_async_iter([]))
|
||||
container.execute_item_batch = AsyncMock(return_value=[])
|
||||
return container
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cosmos_client(mock_container: MagicMock) -> MagicMock:
|
||||
database_client = MagicMock()
|
||||
database_client.create_container_if_not_exists = AsyncMock(return_value=mock_container)
|
||||
|
||||
client = MagicMock()
|
||||
client.get_database_client.return_value = database_client
|
||||
client.close = AsyncMock()
|
||||
return client
|
||||
|
||||
|
||||
class TestCosmosHistoryProviderInit:
|
||||
def test_uses_provided_container_client(self, mock_container: MagicMock) -> None:
|
||||
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
|
||||
assert provider.source_id == "mem"
|
||||
assert provider.load_messages is True
|
||||
assert provider.store_outputs is True
|
||||
assert provider.store_inputs is True
|
||||
assert provider.database_name == ""
|
||||
assert provider.container_name == ""
|
||||
|
||||
def test_uses_provided_cosmos_client(self, mock_cosmos_client: MagicMock) -> None:
|
||||
provider = CosmosHistoryProvider(
|
||||
source_id="mem",
|
||||
cosmos_client=mock_cosmos_client,
|
||||
database_name="db1",
|
||||
container_name="history",
|
||||
)
|
||||
|
||||
mock_cosmos_client.get_database_client.assert_called_once_with("db1")
|
||||
assert provider.database_name == "db1"
|
||||
assert provider.container_name == "history"
|
||||
|
||||
def test_missing_required_settings_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.delenv("AZURE_COSMOS_ENDPOINT", raising=False)
|
||||
monkeypatch.delenv("AZURE_COSMOS_DATABASE_NAME", raising=False)
|
||||
monkeypatch.delenv("AZURE_COSMOS_CONTAINER_NAME", raising=False)
|
||||
monkeypatch.delenv("AZURE_COSMOS_KEY", raising=False)
|
||||
|
||||
with pytest.raises(SettingNotFoundError, match="database_name"):
|
||||
CosmosHistoryProvider()
|
||||
|
||||
def test_constructs_client_with_string_credential(
|
||||
self, monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock
|
||||
) -> None:
|
||||
mock_factory = MagicMock(return_value=mock_cosmos_client)
|
||||
monkeypatch.setattr(history_provider_module, "CosmosClient", mock_factory)
|
||||
|
||||
CosmosHistoryProvider(
|
||||
endpoint="https://account.documents.azure.com:443/",
|
||||
credential="key-123",
|
||||
database_name="db1",
|
||||
container_name="history",
|
||||
)
|
||||
|
||||
mock_factory.assert_called_once()
|
||||
kwargs = mock_factory.call_args.kwargs
|
||||
assert kwargs["url"] == "https://account.documents.azure.com:443/"
|
||||
assert kwargs["credential"] == "key-123"
|
||||
|
||||
|
||||
class TestCosmosHistoryProviderContainerConfig:
|
||||
async def test_provider_container_name_is_used(self, mock_cosmos_client: MagicMock) -> None:
|
||||
provider = CosmosHistoryProvider(
|
||||
source_id="mem",
|
||||
cosmos_client=mock_cosmos_client,
|
||||
database_name="db1",
|
||||
container_name="custom-history",
|
||||
)
|
||||
|
||||
await provider.get_messages("session-123")
|
||||
|
||||
database_client = mock_cosmos_client.get_database_client.return_value
|
||||
assert database_client.create_container_if_not_exists.await_count == 1
|
||||
kwargs = database_client.create_container_if_not_exists.await_args.kwargs
|
||||
assert kwargs["id"] == "custom-history"
|
||||
|
||||
|
||||
class TestCosmosHistoryProviderGetMessages:
|
||||
async def test_returns_deserialized_messages(self, mock_container: MagicMock) -> None:
|
||||
msg1 = Message(role="user", contents=["Hello"])
|
||||
msg2 = Message(role="assistant", contents=["Hi"])
|
||||
mock_container.query_items.return_value = _to_async_iter([
|
||||
{"message": msg1.to_dict()},
|
||||
{"message": msg2.to_dict()},
|
||||
])
|
||||
|
||||
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
|
||||
messages = await provider.get_messages("s1")
|
||||
|
||||
assert len(messages) == 2
|
||||
assert messages[0].role == "user"
|
||||
assert messages[0].text == "Hello"
|
||||
assert messages[1].role == "assistant"
|
||||
assert messages[1].text == "Hi"
|
||||
query_kwargs = mock_container.query_items.call_args.kwargs
|
||||
assert query_kwargs["partition_key"] == "s1"
|
||||
assert query_kwargs["query"] == (
|
||||
"SELECT c.message FROM c "
|
||||
"WHERE c.session_id = @session_id AND c.source_id = @source_id "
|
||||
"ORDER BY c.sort_key ASC"
|
||||
)
|
||||
assert query_kwargs["parameters"] == [
|
||||
{"name": "@session_id", "value": "s1"},
|
||||
{"name": "@source_id", "value": "mem"},
|
||||
]
|
||||
|
||||
async def test_empty_returns_empty(self, mock_container: MagicMock) -> None:
|
||||
mock_container.query_items.return_value = _to_async_iter([])
|
||||
|
||||
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
|
||||
messages = await provider.get_messages("s1")
|
||||
|
||||
assert messages == []
|
||||
|
||||
async def test_none_session_id_generates_guid_partition_key(
|
||||
self, mock_container: MagicMock, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
mock_container.query_items.return_value = _to_async_iter([])
|
||||
|
||||
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
|
||||
with caplog.at_level("WARNING"):
|
||||
await provider.get_messages(None)
|
||||
|
||||
query_kwargs = mock_container.query_items.call_args.kwargs
|
||||
session_key = query_kwargs["partition_key"]
|
||||
assert isinstance(session_key, str)
|
||||
assert session_key != ""
|
||||
assert session_key != "default"
|
||||
uuid.UUID(session_key)
|
||||
assert query_kwargs["parameters"] == [
|
||||
{"name": "@session_id", "value": session_key},
|
||||
{"name": "@source_id", "value": "mem"},
|
||||
]
|
||||
assert "Received empty session_id" in caplog.text
|
||||
|
||||
async def test_skips_non_dict_message_payload(self, mock_container: MagicMock) -> None:
|
||||
mock_container.query_items.return_value = _to_async_iter([{"message": "bad"}, {"message": None}])
|
||||
|
||||
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
|
||||
messages = await provider.get_messages("s1")
|
||||
|
||||
assert messages == []
|
||||
|
||||
|
||||
class TestCosmosHistoryProviderListSessions:
|
||||
async def test_list_sessions_returns_unique_sorted_ids(self, mock_container: MagicMock) -> None:
|
||||
mock_container.query_items.return_value = _to_async_iter(["s2", "s1", "s1", "s3"])
|
||||
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
|
||||
|
||||
sessions = await provider.list_sessions()
|
||||
|
||||
assert sessions == ["s1", "s2", "s3"]
|
||||
kwargs = mock_container.query_items.call_args.kwargs
|
||||
assert kwargs["query"] == "SELECT DISTINCT VALUE c.session_id FROM c WHERE c.source_id = @source_id"
|
||||
assert kwargs["parameters"] == [{"name": "@source_id", "value": "mem"}]
|
||||
|
||||
|
||||
class TestCosmosHistoryProviderSaveMessages:
|
||||
async def test_saves_messages(self, mock_container: MagicMock) -> None:
|
||||
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
|
||||
messages = [Message(role="user", contents=["Hello"]), Message(role="assistant", contents=["Hi"])]
|
||||
|
||||
await provider.save_messages("s1", messages)
|
||||
|
||||
mock_container.execute_item_batch.assert_awaited_once()
|
||||
batch_operations = mock_container.execute_item_batch.await_args.kwargs["batch_operations"]
|
||||
assert len(batch_operations) == 2
|
||||
first_operation, first_args = batch_operations[0]
|
||||
assert first_operation == "upsert"
|
||||
first_document = first_args[0]
|
||||
assert first_document["session_id"] == "s1"
|
||||
assert first_document["message"]["role"] == "user"
|
||||
assert mock_container.execute_item_batch.await_args.kwargs["partition_key"] == "s1"
|
||||
|
||||
async def test_empty_messages_noop(self, mock_container: MagicMock) -> None:
|
||||
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
|
||||
|
||||
await provider.save_messages("s1", [])
|
||||
|
||||
mock_container.execute_item_batch.assert_not_awaited()
|
||||
|
||||
async def test_batches_when_message_count_exceeds_limit(self, mock_container: MagicMock) -> None:
|
||||
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
|
||||
messages = [Message(role="user", contents=[f"msg-{index}"]) for index in range(101)]
|
||||
|
||||
await provider.save_messages("s1", messages)
|
||||
|
||||
assert mock_container.execute_item_batch.await_count == 2
|
||||
first_call = mock_container.execute_item_batch.await_args_list[0].kwargs
|
||||
second_call = mock_container.execute_item_batch.await_args_list[1].kwargs
|
||||
assert len(first_call["batch_operations"]) == 100
|
||||
assert len(second_call["batch_operations"]) == 1
|
||||
assert first_call["partition_key"] == "s1"
|
||||
assert second_call["partition_key"] == "s1"
|
||||
|
||||
|
||||
class TestCosmosHistoryProviderClear:
|
||||
async def test_clear_deletes_all_session_items(self, mock_container: MagicMock) -> None:
|
||||
mock_container.query_items.return_value = _to_async_iter([{"id": "1"}, {"id": "2"}])
|
||||
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
|
||||
|
||||
await provider.clear("s1")
|
||||
|
||||
mock_container.execute_item_batch.assert_awaited_once()
|
||||
batch_operations = mock_container.execute_item_batch.await_args.kwargs["batch_operations"]
|
||||
assert len(batch_operations) == 2
|
||||
assert batch_operations[0] == ("delete", ("1",))
|
||||
assert batch_operations[1] == ("delete", ("2",))
|
||||
assert mock_container.execute_item_batch.await_args.kwargs["partition_key"] == "s1"
|
||||
query_kwargs = mock_container.query_items.call_args.kwargs
|
||||
assert query_kwargs["query"] == (
|
||||
"SELECT c.id FROM c WHERE c.session_id = @session_id AND c.source_id = @source_id"
|
||||
)
|
||||
assert query_kwargs["parameters"] == [
|
||||
{"name": "@session_id", "value": "s1"},
|
||||
{"name": "@source_id", "value": "mem"},
|
||||
]
|
||||
|
||||
|
||||
class TestCosmosHistoryProviderBeforeAfterRun:
|
||||
async def test_before_run_loads_history(self, mock_container: MagicMock) -> None:
|
||||
msg = Message(role="user", contents=["old msg"])
|
||||
mock_container.query_items.return_value = _to_async_iter([{"message": msg.to_dict()}])
|
||||
|
||||
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
|
||||
session = AgentSession(session_id="test")
|
||||
context = SessionContext(input_messages=[Message(role="user", contents=["new msg"])], session_id="s1")
|
||||
|
||||
await provider.before_run(
|
||||
agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {})
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
assert "mem" in context.context_messages
|
||||
assert context.context_messages["mem"][0].text == "old msg"
|
||||
|
||||
async def test_after_run_stores_input_and_response(self, mock_container: MagicMock) -> None:
|
||||
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
|
||||
session = AgentSession(session_id="test")
|
||||
context = SessionContext(input_messages=[Message(role="user", contents=["hi"])], session_id="s1")
|
||||
context._response = AgentResponse(messages=[Message(role="assistant", contents=["hello"])])
|
||||
|
||||
await provider.after_run(
|
||||
agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {})
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
mock_container.execute_item_batch.assert_awaited_once()
|
||||
batch_operations = mock_container.execute_item_batch.await_args.kwargs["batch_operations"]
|
||||
assert len(batch_operations) == 2
|
||||
input_doc = batch_operations[0][1][0]
|
||||
response_doc = batch_operations[1][1][0]
|
||||
assert input_doc["message"]["role"] == "user"
|
||||
assert input_doc["message"]["contents"][0]["text"] == "hi"
|
||||
assert response_doc["message"]["role"] == "assistant"
|
||||
assert response_doc["message"]["contents"][0]["text"] == "hello"
|
||||
|
||||
|
||||
class TestCosmosHistoryProviderClose:
|
||||
async def test_close_closes_owned_client(
|
||||
self, monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock
|
||||
) -> None:
|
||||
mock_factory = MagicMock(return_value=mock_cosmos_client)
|
||||
monkeypatch.setattr(history_provider_module, "CosmosClient", mock_factory)
|
||||
|
||||
provider = CosmosHistoryProvider(
|
||||
endpoint="https://account.documents.azure.com:443/",
|
||||
credential="key-123",
|
||||
database_name="db1",
|
||||
container_name="history",
|
||||
)
|
||||
|
||||
await provider.close()
|
||||
|
||||
mock_cosmos_client.close.assert_awaited_once()
|
||||
|
||||
async def test_close_does_not_close_external_client(self, mock_cosmos_client: MagicMock) -> None:
|
||||
provider = CosmosHistoryProvider(
|
||||
source_id="mem",
|
||||
cosmos_client=mock_cosmos_client,
|
||||
database_name="db1",
|
||||
container_name="history",
|
||||
)
|
||||
|
||||
await provider.close()
|
||||
|
||||
mock_cosmos_client.close.assert_not_awaited()
|
||||
|
||||
async def test_async_context_manager_closes_owned_client(
|
||||
self, monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock
|
||||
) -> None:
|
||||
mock_factory = MagicMock(return_value=mock_cosmos_client)
|
||||
monkeypatch.setattr(history_provider_module, "CosmosClient", mock_factory)
|
||||
|
||||
async with CosmosHistoryProvider(
|
||||
endpoint="https://account.documents.azure.com:443/",
|
||||
credential="key-123",
|
||||
database_name="db1",
|
||||
container_name="history",
|
||||
) as provider:
|
||||
assert provider is not None
|
||||
|
||||
mock_cosmos_client.close.assert_awaited_once()
|
||||
|
||||
async def test_async_context_manager_preserves_original_exception(self, mock_container: MagicMock) -> None:
|
||||
provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container)
|
||||
|
||||
with patch.object(
|
||||
provider, "close", AsyncMock(side_effect=RuntimeError("close failed"))
|
||||
), pytest.raises(ValueError, match="inner error"):
|
||||
async with provider:
|
||||
raise ValueError("inner error")
|
||||
|
||||
|
||||
@pytest.mark.flaky
|
||||
@pytest.mark.integration
|
||||
@skip_if_cosmos_integration_tests_disabled
|
||||
async def test_cosmos_history_provider_roundtrip_with_emulator() -> None:
|
||||
endpoint = os.getenv("AZURE_COSMOS_ENDPOINT", "")
|
||||
key = os.getenv("AZURE_COSMOS_KEY", "")
|
||||
database_prefix = os.getenv("AZURE_COSMOS_DATABASE_NAME", "")
|
||||
container_prefix = os.getenv("AZURE_COSMOS_CONTAINER_NAME", "")
|
||||
unique = uuid.uuid4().hex[:8]
|
||||
database_name = f"{database_prefix}-{unique}"
|
||||
container_name = f"{container_prefix}-{unique}"
|
||||
session_id = f"session-{unique}"
|
||||
|
||||
async with CosmosClient(url=endpoint, credential=key) as cosmos_client:
|
||||
await cosmos_client.create_database_if_not_exists(id=database_name)
|
||||
provider = CosmosHistoryProvider(
|
||||
source_id="cosmos_integration",
|
||||
cosmos_client=cosmos_client,
|
||||
database_name=database_name,
|
||||
container_name=container_name,
|
||||
)
|
||||
|
||||
try:
|
||||
await provider.save_messages(
|
||||
session_id,
|
||||
[
|
||||
Message(role="user", contents=["Hello Cosmos"]),
|
||||
Message(role="assistant", contents=["Hi from Cosmos"]),
|
||||
],
|
||||
)
|
||||
|
||||
stored_messages = await provider.get_messages(session_id)
|
||||
assert [message.role for message in stored_messages] == ["user", "assistant"]
|
||||
assert [message.text for message in stored_messages] == ["Hello Cosmos", "Hi from Cosmos"]
|
||||
|
||||
sessions = await provider.list_sessions()
|
||||
assert session_id in sessions
|
||||
|
||||
await provider.clear(session_id)
|
||||
assert await provider.get_messages(session_id) == []
|
||||
finally:
|
||||
with suppress(CosmosResourceNotFoundError):
|
||||
await cosmos_client.delete_database(database_name)
|
||||
@@ -76,6 +76,7 @@ agent-framework-core = { workspace = true }
|
||||
agent-framework-a2a = { workspace = true }
|
||||
agent-framework-ag-ui = { workspace = true }
|
||||
agent-framework-azure-ai-search = { workspace = true }
|
||||
agent-framework-azure-cosmos = { workspace = true }
|
||||
agent-framework-anthropic = { workspace = true }
|
||||
agent-framework-azure-ai = { workspace = true }
|
||||
agent-framework-azurefunctions = { workspace = true }
|
||||
@@ -238,6 +239,7 @@ check = ["check-packages", "samples-lint", "samples-syntax", "test", "markdown-c
|
||||
[tool.poe.tasks.all-tests-cov]
|
||||
cmd = """
|
||||
pytest --import-mode=importlib
|
||||
-m "not integration"
|
||||
--cov=agent_framework
|
||||
--cov=agent_framework_core
|
||||
--cov=agent_framework_a2a
|
||||
@@ -265,6 +267,7 @@ pytest --import-mode=importlib
|
||||
[tool.poe.tasks.all-tests]
|
||||
cmd = """
|
||||
pytest --import-mode=importlib
|
||||
-m "not integration"
|
||||
--ignore-glob=packages/lab/**
|
||||
--ignore-glob=packages/devui/**
|
||||
-rs
|
||||
|
||||
Generated
+508
-444
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user