mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add Mistral AI embedding client package (#5480)
* Python: Add Mistral AI embedding client package Signed-off-by: Daria Korenieva <daric2612@gmail.com> * Address review feedback: fix dimensions check, sort embeddings by index, align docs Signed-off-by: Daria Korenieva <daric2612@gmail.com> * Address review feedback: downgrade to alpha, remove integration tests - Change version to 1.0.0a260505 (alpha) - Update classifier to Development Status :: 3 - Alpha - Update PACKAGE_STATUS.md to alpha - Remove Mistral from integration test workflows (no API keys yet) Signed-off-by: Daria Korenieva <daric2612@gmail.com> * Add samples directory for alpha package compliance Per python-package-management skill: alpha packages must include samples inside the package directory. Signed-off-by: Daria Korenieva <daric2612@gmail.com> * Fix ruff formatting in sample file Signed-off-by: Daria Korenieva <daric2612@gmail.com> --------- Signed-off-by: Daria Korenieva <daric2612@gmail.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
1fccf16f11
commit
d2d5384f28
@@ -44,6 +44,9 @@ GEMINI_MODEL=""
|
||||
# Ollama
|
||||
OLLAMA_ENDPOINT=""
|
||||
OLLAMA_MODEL=""
|
||||
# Mistral AI
|
||||
MISTRAL_API_KEY=""
|
||||
MISTRAL_EMBEDDING_MODEL=""
|
||||
# Observability (instrumentation is enabled by default; set "ENABLE_INSTRUMENTATION" to "false" to opt out)
|
||||
ENABLE_SENSITIVE_DATA=true
|
||||
OTEL_EXPORTER_OTLP_ENDPOINT="http://localhost:4317/"
|
||||
|
||||
@@ -37,6 +37,7 @@ Status is grouped into these buckets:
|
||||
| `agent-framework-hyperlight` | `python/packages/hyperlight` | `beta` |
|
||||
| `agent-framework-lab` | `python/packages/lab` | `beta` |
|
||||
| `agent-framework-mem0` | `python/packages/mem0` | `beta` |
|
||||
| `agent-framework-mistral` | `python/packages/mistral` | `alpha` |
|
||||
| `agent-framework-monty` | `python/packages/monty` | `alpha` |
|
||||
| `agent-framework-ollama` | `python/packages/ollama` | `beta` |
|
||||
| `agent-framework-openai` | `python/packages/openai` | `released` |
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
# Mistral Package (agent-framework-mistral)
|
||||
|
||||
Integration with Mistral AI for embedding generation.
|
||||
|
||||
## Main Classes
|
||||
|
||||
- **`MistralEmbeddingClient`** - Embedding client for Mistral AI models
|
||||
- **`MistralEmbeddingOptions`** - Options TypedDict for Mistral-specific embedding parameters
|
||||
- **`MistralEmbeddingSettings`** - TypedDict settings for Mistral configuration
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
from agent_framework_mistral import MistralEmbeddingClient
|
||||
|
||||
# Requires MISTRAL_API_KEY environment variable (or pass api_key= directly)
|
||||
client = MistralEmbeddingClient(model="mistral-embed")
|
||||
result = await client.get_embeddings(["Hello, world!"])
|
||||
print(result[0].vector)
|
||||
```
|
||||
|
||||
## Import Path
|
||||
|
||||
```python
|
||||
from agent_framework_mistral import MistralEmbeddingClient
|
||||
```
|
||||
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) Microsoft Corporation.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE
|
||||
@@ -0,0 +1,42 @@
|
||||
# Get Started with Microsoft Agent Framework Mistral AI
|
||||
|
||||
Please install this package:
|
||||
|
||||
```bash
|
||||
pip install agent-framework-mistral --pre
|
||||
```
|
||||
|
||||
and see the [README](https://github.com/microsoft/agent-framework/tree/main/python/README.md) for more information.
|
||||
|
||||
## Embedding Client
|
||||
|
||||
The `MistralEmbeddingClient` provides embedding generation using Mistral AI models.
|
||||
|
||||
### Quick Start
|
||||
|
||||
```python
|
||||
from agent_framework_mistral import MistralEmbeddingClient
|
||||
|
||||
# Using environment variables (MISTRAL_API_KEY, MISTRAL_EMBEDDING_MODEL)
|
||||
client = MistralEmbeddingClient()
|
||||
|
||||
# Or passing parameters directly
|
||||
client = MistralEmbeddingClient(
|
||||
model="mistral-embed",
|
||||
api_key="your-api-key",
|
||||
)
|
||||
|
||||
# Generate embeddings
|
||||
result = await client.get_embeddings(["Hello, world!", "How are you?"])
|
||||
for embedding in result:
|
||||
print(f"Dimensions: {embedding.dimensions}")
|
||||
print(f"Vector: {embedding.vector[:5]}...")
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
| Environment Variable | Description |
|
||||
|---|---|
|
||||
| `MISTRAL_API_KEY` | Your Mistral AI API key |
|
||||
| `MISTRAL_EMBEDDING_MODEL` | Embedding model name (e.g., `mistral-embed`) |
|
||||
| `MISTRAL_SERVER_URL` | Optional server URL override |
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import importlib.metadata
|
||||
|
||||
from ._embedding_client import MistralEmbeddingClient, MistralEmbeddingOptions, MistralEmbeddingSettings
|
||||
|
||||
# Publish the installed distribution's version as ``__version__``. When no
# package metadata can be found for this module name, use a placeholder.
try:
    __version__ = importlib.metadata.version(__name__)
except importlib.metadata.PackageNotFoundError:
    __version__ = "0.0.0"  # Fallback for development mode

# Names exported by ``from agent_framework_mistral import *``.
__all__ = [
    "MistralEmbeddingClient",
    "MistralEmbeddingOptions",
    "MistralEmbeddingSettings",
    "__version__",
]
|
||||
@@ -0,0 +1,250 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, ClassVar, Generic, TypedDict
|
||||
|
||||
from agent_framework import (
|
||||
BaseEmbeddingClient,
|
||||
Embedding,
|
||||
EmbeddingGenerationOptions,
|
||||
GeneratedEmbeddings,
|
||||
UsageDetails,
|
||||
load_settings,
|
||||
)
|
||||
from agent_framework._settings import SecretString
|
||||
from agent_framework.observability import EmbeddingTelemetryLayer
|
||||
from mistralai.client import Mistral
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypeVar # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
logger = logging.getLogger("agent_framework.mistral")
|
||||
|
||||
|
||||
class MistralEmbeddingOptions(EmbeddingGenerationOptions, total=False):
    """Options TypedDict for the Mistral AI embedding client.

    This subclass of ``EmbeddingGenerationOptions`` declares no keys of its
    own; it is the Mistral-specific options type used by the clients in this
    module.

    Examples:
        .. code-block:: python

            from agent_framework_mistral import MistralEmbeddingOptions

            options: MistralEmbeddingOptions = {
                "model": "mistral-embed",
                "dimensions": 1024,
            }
    """
|
||||
|
||||
|
||||
# Covariant type variable for the options TypedDict the Mistral clients are
# parameterized with. The default is given as a string naming
# ``MistralEmbeddingOptions``.
MistralEmbeddingOptionsT = TypeVar(
    "MistralEmbeddingOptionsT",
    covariant=True,
    default="MistralEmbeddingOptions",
    bound=TypedDict,  # type: ignore[valid-type]
)
|
||||
|
||||
|
||||
class MistralEmbeddingSettings(TypedDict, total=False):
|
||||
"""Mistral AI embedding settings.
|
||||
|
||||
Fields:
|
||||
api_key: Mistral API key. Resolved from ``MISTRAL_API_KEY``.
|
||||
embedding_model: Embedding model name. Resolved from ``MISTRAL_EMBEDDING_MODEL``.
|
||||
server_url: Optional server URL override. Resolved from ``MISTRAL_SERVER_URL``.
|
||||
"""
|
||||
|
||||
api_key: str | None
|
||||
embedding_model: str | None
|
||||
server_url: str | None
|
||||
|
||||
|
||||
class RawMistralEmbeddingClient(
    BaseEmbeddingClient[str, list[float], MistralEmbeddingOptionsT],
    Generic[MistralEmbeddingOptionsT],
):
    """Raw Mistral AI embedding client without telemetry.

    Keyword Args:
        model: The Mistral embedding model (e.g. "mistral-embed").
            Can also be set via environment variable ``MISTRAL_EMBEDDING_MODEL``.
        api_key: Mistral API key. Defaults to ``MISTRAL_API_KEY`` environment variable.
            Only required when no ``client`` is supplied.
        server_url: Optional server URL override. Defaults to ``MISTRAL_SERVER_URL``
            environment variable, or the Mistral default.
        client: Optional pre-configured ``Mistral`` client instance. When given it is
            used as-is and no new ``Mistral`` client is constructed.
        additional_properties: Additional properties stored on the client instance.
        env_file_path: Path to ``.env`` file for settings.
        env_file_encoding: Encoding for ``.env`` file.
    """

    INJECTABLE: ClassVar[set[str]] = {"client"}

    def __init__(
        self,
        *,
        model: str | None = None,
        api_key: str | SecretString | None = None,
        server_url: str | None = None,
        client: Mistral | None = None,
        additional_properties: dict[str, Any] | None = None,
        env_file_path: str | None = None,
        env_file_encoding: str | None = None,
    ) -> None:
        """Initialize a raw Mistral AI embedding client.

        Settings are resolved through ``load_settings`` using the ``MISTRAL_``
        environment prefix; explicitly passed arguments are handed to it as
        overrides.
        """
        # The resolved API key is only read when this class builds the
        # ``Mistral`` client itself, so it is not demanded for an injected one.
        required_fields: list[str] = ["embedding_model"]
        if client is None:
            required_fields.append("api_key")

        mistral_settings = load_settings(
            MistralEmbeddingSettings,
            env_prefix="MISTRAL_",
            required_fields=required_fields,
            api_key=str(api_key) if isinstance(api_key, SecretString) else api_key,
            embedding_model=model,
            server_url=server_url,
            env_file_path=env_file_path,
            env_file_encoding=env_file_encoding,
        )

        self.model: str = mistral_settings["embedding_model"]  # type: ignore[assignment]
        resolved_server_url = mistral_settings.get("server_url")

        if client is not None:
            self.client = client
        else:
            client_kwargs: dict[str, Any] = {"api_key": mistral_settings["api_key"]}
            # Only forward a server URL when one was resolved, so the SDK
            # default applies otherwise.
            if resolved_server_url:
                client_kwargs["server_url"] = resolved_server_url
            self.client = Mistral(**client_kwargs)

        self.server_url = resolved_server_url
        super().__init__(additional_properties=additional_properties)

    def service_url(self) -> str:
        """Return the resolved server URL, or ``https://api.mistral.ai`` when none was set."""
        return self.server_url or "https://api.mistral.ai"

    async def get_embeddings(
        self,
        values: Sequence[str],
        *,
        options: MistralEmbeddingOptionsT | None = None,
    ) -> GeneratedEmbeddings[list[float], MistralEmbeddingOptionsT]:
        """Call the Mistral AI embeddings API.

        Args:
            values: The text values to generate embeddings for.
            options: Optional embedding generation options. ``model`` overrides the
                client's model; ``dimensions`` is sent as ``output_dimension``.

        Returns:
            Generated embeddings with usage metadata. When ``values`` is empty, an
            empty result is returned without calling the API.

        Raises:
            ValueError: If no model is available from ``options`` or the client.
        """
        if not values:
            return GeneratedEmbeddings([], options=options)

        opts: dict[str, Any] = options or {}  # type: ignore
        model = opts.get("model") or self.model
        if not model:
            raise ValueError("model is required")

        kwargs: dict[str, Any] = {"model": model, "inputs": list(values)}
        if "dimensions" in opts:
            kwargs["output_dimension"] = opts["dimensions"]

        response = await self.client.embeddings.create_async(**kwargs)

        embeddings: list[Embedding[list[float]]] = []
        if response and response.data:
            # Prefer the model name reported by the response over the requested one.
            response_model = response.model or model
            # Order items by their index; an item without an index sorts as 0.
            items = sorted(response.data, key=lambda d: d.index if d.index is not None else 0)
            for item in items:
                vector = list(item.embedding) if item.embedding else []
                embeddings.append(
                    Embedding(
                        vector=vector,
                        dimensions=len(vector),
                        model=response_model,
                    )
                )

        usage_dict: UsageDetails | None = None
        if response and response.usage:
            usage_dict = {
                "input_token_count": response.usage.prompt_tokens,
                "total_token_count": response.usage.total_tokens,
            }

        return GeneratedEmbeddings(embeddings, options=options, usage=usage_dict)
|
||||
|
||||
|
||||
class MistralEmbeddingClient(
    EmbeddingTelemetryLayer[str, list[float], MistralEmbeddingOptionsT],
    RawMistralEmbeddingClient[MistralEmbeddingOptionsT],
    Generic[MistralEmbeddingOptionsT],
):
    """Mistral AI embedding client with telemetry support.

    Combines ``EmbeddingTelemetryLayer`` with ``RawMistralEmbeddingClient``.

    Keyword Args:
        model: The Mistral embedding model (e.g. "mistral-embed").
            Can also be set via environment variable ``MISTRAL_EMBEDDING_MODEL``.
        api_key: Mistral API key. Defaults to ``MISTRAL_API_KEY`` environment variable.
        server_url: Optional server URL override. Defaults to ``MISTRAL_SERVER_URL``
            environment variable, or the Mistral default.
        client: Optional pre-configured ``Mistral`` client instance.
        otel_provider_name: Optional telemetry provider name override.
        additional_properties: Additional properties passed on to the base initializer.
        env_file_path: Path to ``.env`` file for settings.
        env_file_encoding: Encoding for ``.env`` file.

    Examples:
        .. code-block:: python

            from agent_framework_mistral import MistralEmbeddingClient

            # Using environment variables
            # Set MISTRAL_API_KEY=your-key
            # Set MISTRAL_EMBEDDING_MODEL=mistral-embed
            client = MistralEmbeddingClient()

            # Or passing parameters directly
            client = MistralEmbeddingClient(
                model="mistral-embed",
                api_key="your-api-key",
            )

            # Generate embeddings
            result = await client.get_embeddings(["Hello, world!"])
            print(result[0].vector)
    """

    OTEL_PROVIDER_NAME: ClassVar[str] = "mistralai"

    def __init__(
        self,
        *,
        model: str | None = None,
        api_key: str | SecretString | None = None,
        server_url: str | None = None,
        client: Mistral | None = None,
        otel_provider_name: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        env_file_path: str | None = None,
        env_file_encoding: str | None = None,
    ) -> None:
        """Initialize a Mistral AI embedding client.

        Every argument is forwarded unchanged, by keyword, to the next
        initializer in the method resolution order.
        """
        super().__init__(
            otel_provider_name=otel_provider_name,
            additional_properties=additional_properties,
            client=client,
            model=model,
            api_key=api_key,
            server_url=server_url,
            env_file_path=env_file_path,
            env_file_encoding=env_file_encoding,
        )
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
[project]
|
||||
name = "agent-framework-mistral"
|
||||
description = "Mistral AI integration for Microsoft Agent Framework."
|
||||
authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}]
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
version = "1.0.0a260505"
|
||||
license-files = ["LICENSE"]
|
||||
urls.homepage = "https://learn.microsoft.com/en-us/agent-framework/"
|
||||
urls.source = "https://github.com/microsoft/agent-framework/tree/main/python"
|
||||
urls.release_notes = "https://github.com/microsoft/agent-framework/releases?q=tag%3Apython-1&expanded=true"
|
||||
urls.issues = "https://github.com/microsoft/agent-framework/issues"
|
||||
classifiers = [
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Framework :: Pydantic :: 2",
|
||||
"Typing :: Typed",
|
||||
]
|
||||
dependencies = [
|
||||
"agent-framework-core>=1.1.0,<2",
|
||||
"mistralai>=2.0.0,<3",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
prerelease = "if-necessary-or-explicit"
|
||||
environments = [
|
||||
"sys_platform == 'darwin'",
|
||||
"sys_platform == 'linux'",
|
||||
"sys_platform == 'win32'"
|
||||
]
|
||||
|
||||
[tool.uv-dynamic-versioning]
|
||||
fallback-version = "0.0.0"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
addopts = "-ra -q -r fEX"
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
filterwarnings = []
|
||||
markers = [
|
||||
"integration: marks tests as integration tests that require external services",
|
||||
]
|
||||
timeout = 120
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "N", "W"]
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = [
|
||||
"**/__init__.py"
|
||||
]
|
||||
|
||||
[tool.pyright]
|
||||
extends = "../../pyproject.toml"
|
||||
include = ["agent_framework_mistral"]
|
||||
exclude = ['tests']
|
||||
|
||||
[tool.mypy]
|
||||
plugins = ['pydantic.mypy']
|
||||
strict = true
|
||||
python_version = "3.10"
|
||||
ignore_missing_imports = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
check_untyped_defs = true
|
||||
warn_return_any = true
|
||||
show_error_codes = true
|
||||
warn_unused_ignores = false
|
||||
disallow_incomplete_defs = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_any_unimported = true
|
||||
|
||||
[tool.bandit]
|
||||
targets = ["agent_framework_mistral"]
|
||||
exclude_dirs = ["tests"]
|
||||
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks.mypy]
|
||||
help = "Run MyPy for this package."
|
||||
cmd = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_mistral"
|
||||
|
||||
[tool.poe.tasks.test]
|
||||
help = "Run the default unit test suite for this package."
|
||||
cmd = 'pytest -m "not integration" --cov=agent_framework_mistral --cov-report=term-missing:skip-covered tests'
|
||||
|
||||
[tool.uv.build-backend]
|
||||
module-name = "agent_framework_mistral"
|
||||
module-root = ""
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.8.2,<0.9.0"]
|
||||
build-backend = "uv_build"
|
||||
@@ -0,0 +1,15 @@
|
||||
# Mistral AI Embedding Examples
|
||||
|
||||
This folder contains examples demonstrating how to use Mistral AI embedding models with the Agent Framework.
|
||||
|
||||
## Examples
|
||||
|
||||
| File | Description |
|
||||
|------|-------------|
|
||||
| [`mistral_embeddings.py`](mistral_embeddings.py) | Basic embedding generation with the Mistral AI embedding client. |
|
||||
|
||||
## Environment Variables
|
||||
|
||||
- `MISTRAL_API_KEY`: Your Mistral AI API key
|
||||
- `MISTRAL_EMBEDDING_MODEL`: Embedding model name (e.g., `mistral-embed`)
|
||||
- `MISTRAL_SERVER_URL` (optional): Server URL override for custom deployments
|
||||
@@ -0,0 +1 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
@@ -0,0 +1,77 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Shows how to generate embeddings using the Mistral AI embedding client.
|
||||
|
||||
Requires ``MISTRAL_API_KEY`` and ``MISTRAL_EMBEDDING_MODEL`` environment variables.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from agent_framework_mistral import MistralEmbeddingClient
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
async def basic_embedding_example() -> None:
    """Embed a few texts and print each result plus token usage."""
    print("=== Basic Embedding Generation ===")

    # 1. Build the client from MISTRAL_API_KEY and MISTRAL_EMBEDDING_MODEL.
    embedding_client = MistralEmbeddingClient()

    # 2. Embed several texts in a single call.
    sample_texts = ["Hello, world!", "How are you?", "Agent Framework with Mistral AI"]
    generated = await embedding_client.get_embeddings(sample_texts)

    # 3. Report what came back, numbering the texts from 1.
    print(f"Generated {len(generated)} embeddings")
    for position, item in enumerate(generated, start=1):
        print(f" Text {position}: dimensions={item.dimensions}, vector={item.vector[:5]}...")

    usage = generated.usage
    if usage:
        print(
            f" Usage: {usage['input_token_count']} input tokens, "
            f"{usage['total_token_count']} total tokens"
        )
|
||||
|
||||
|
||||
async def embedding_with_options_example() -> None:
    """Embed one text while requesting a custom output dimension."""
    print("\n=== Embedding with Custom Dimensions ===")

    from agent_framework_mistral import MistralEmbeddingOptions

    embedding_client = MistralEmbeddingClient()

    # Request a specific output dimension (model must support it).
    requested: MistralEmbeddingOptions = {"dimensions": 256}
    generated = await embedding_client.get_embeddings(
        ["Dimensionality reduction example"],
        options=requested,
    )

    first = generated[0]
    print(f" Dimensions: {first.dimensions}")
    print(f" Vector (first 5): {first.vector[:5]}...")
|
||||
|
||||
|
||||
async def main() -> None:
    """Run the embedding examples one after another."""
    for example in (basic_embedding_example, embedding_with_options_example):
        await example()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
|
||||
"""
|
||||
Sample output:
|
||||
=== Basic Embedding Generation ===
|
||||
Generated 3 embeddings
|
||||
Text 1: dimensions=1024, vector=[0.0123, -0.0456, 0.0789, -0.0012, 0.0345]...
|
||||
Text 2: dimensions=1024, vector=[0.0234, -0.0567, 0.0891, -0.0023, 0.0456]...
|
||||
Text 3: dimensions=1024, vector=[0.0345, -0.0678, 0.0912, -0.0034, 0.0567]...
|
||||
Usage: 15 input tokens, 15 total tokens
|
||||
|
||||
=== Embedding with Custom Dimensions ===
|
||||
Dimensions: 256
|
||||
Vector (first 5): [0.0456, -0.0789, 0.0123, -0.0456, 0.0789]...
|
||||
"""
|
||||
@@ -0,0 +1,267 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from agent_framework import Embedding, GeneratedEmbeddings
|
||||
|
||||
from agent_framework_mistral import MistralEmbeddingClient, MistralEmbeddingOptions
|
||||
|
||||
# region: Unit Tests
|
||||
|
||||
|
||||
def test_mistral_embedding_construction(monkeypatch: pytest.MonkeyPatch) -> None:
    """Construct the client purely from environment variables."""
    environment = {
        "MISTRAL_EMBEDDING_MODEL": "mistral-embed",
        "MISTRAL_API_KEY": "test-key",
    }
    for name, value in environment.items():
        monkeypatch.setenv(name, value)

    with patch("agent_framework_mistral._embedding_client.Mistral") as mistral_cls:
        mistral_cls.return_value = MagicMock()
        embedding_client = MistralEmbeddingClient()
        assert embedding_client.model == "mistral-embed"
|
||||
|
||||
|
||||
def test_mistral_embedding_construction_with_params() -> None:
    """Construct the client from explicit keyword arguments."""
    with patch("agent_framework_mistral._embedding_client.Mistral") as mistral_cls:
        mistral_cls.return_value = MagicMock()
        embedding_client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key")

        assert embedding_client.model == "mistral-embed"
        # Without a server URL only the API key reaches the SDK constructor.
        mistral_cls.assert_called_once_with(api_key="test-key")
|
||||
|
||||
|
||||
def test_mistral_embedding_construction_with_server_url() -> None:
    """Construct the client with a custom server URL."""
    expected_sdk_kwargs = {
        "api_key": "test-key",
        "server_url": "https://custom.mistral.ai",
    }
    with patch("agent_framework_mistral._embedding_client.Mistral") as mistral_cls:
        mistral_cls.return_value = MagicMock()
        embedding_client = MistralEmbeddingClient(
            model="mistral-embed",
            api_key="test-key",
            server_url="https://custom.mistral.ai",
        )

        assert embedding_client.model == "mistral-embed"
        assert embedding_client.server_url == "https://custom.mistral.ai"
        mistral_cls.assert_called_once_with(**expected_sdk_kwargs)
|
||||
|
||||
|
||||
def test_mistral_embedding_construction_with_client() -> None:
    """Construct the client around a pre-configured SDK client."""
    injected = MagicMock()
    with patch("agent_framework_mistral._embedding_client.Mistral"):
        embedding_client = MistralEmbeddingClient(
            model="mistral-embed",
            api_key="test-key",
            client=injected,
        )
        # The injected instance must be kept as-is.
        assert embedding_client.client is injected
|
||||
|
||||
|
||||
def test_mistral_embedding_construction_missing_model_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """A missing embedding model must fail construction."""
    from agent_framework.exceptions import SettingNotFoundError

    monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
    monkeypatch.delenv("MISTRAL_EMBEDDING_MODEL", raising=False)

    with pytest.raises(SettingNotFoundError):
        MistralEmbeddingClient()
|
||||
|
||||
|
||||
def test_mistral_embedding_construction_missing_api_key_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """A missing API key must fail construction."""
    from agent_framework.exceptions import SettingNotFoundError

    monkeypatch.setenv("MISTRAL_EMBEDDING_MODEL", "mistral-embed")
    monkeypatch.delenv("MISTRAL_API_KEY", raising=False)

    with pytest.raises(SettingNotFoundError):
        MistralEmbeddingClient()
|
||||
|
||||
|
||||
def test_mistral_embedding_service_url() -> None:
    """service_url falls back to the default Mistral endpoint."""
    with patch("agent_framework_mistral._embedding_client.Mistral") as mistral_cls:
        mistral_cls.return_value = MagicMock()
        embedding_client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key")
        reported_url = embedding_client.service_url()
        assert reported_url == "https://api.mistral.ai"
|
||||
|
||||
|
||||
def test_mistral_embedding_service_url_custom() -> None:
    """service_url reports the custom URL when one was supplied."""
    custom_url = "https://custom.mistral.ai"
    with patch("agent_framework_mistral._embedding_client.Mistral") as mistral_cls:
        mistral_cls.return_value = MagicMock()
        embedding_client = MistralEmbeddingClient(
            model="mistral-embed",
            api_key="test-key",
            server_url=custom_url,
        )
        assert embedding_client.service_url() == "https://custom.mistral.ai"
|
||||
|
||||
|
||||
async def test_mistral_embedding_get_embeddings() -> None:
    """Embeddings, model name and usage are mapped from the API response."""
    api_response = MagicMock()
    api_response.data = [
        MagicMock(embedding=[0.1, 0.2, 0.3], index=0, object="embedding"),
        MagicMock(embedding=[0.4, 0.5, 0.6], index=1, object="embedding"),
    ]
    api_response.model = "mistral-embed"
    api_response.usage = MagicMock(prompt_tokens=10, total_tokens=10)

    with patch("agent_framework_mistral._embedding_client.Mistral") as mistral_cls:
        sdk_client = MagicMock()
        sdk_client.embeddings = MagicMock()
        sdk_client.embeddings.create_async = AsyncMock(return_value=api_response)
        mistral_cls.return_value = sdk_client

        embedding_client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key")
        generated = await embedding_client.get_embeddings(["hello", "world"])

    assert isinstance(generated, GeneratedEmbeddings)
    assert len(generated) == 2
    first, second = generated[0], generated[1]
    assert first.vector == [0.1, 0.2, 0.3]
    assert second.vector == [0.4, 0.5, 0.6]
    assert first.model == "mistral-embed"
    assert generated.usage == {"input_token_count": 10, "total_token_count": 10}

    # No dimensions option was given, so only model and inputs are sent.
    sdk_client.embeddings.create_async.assert_called_once_with(
        model="mistral-embed",
        inputs=["hello", "world"],
    )
|
||||
|
||||
|
||||
async def test_mistral_embedding_get_embeddings_empty_input() -> None:
    """An empty input list yields an empty result."""
    with patch("agent_framework_mistral._embedding_client.Mistral") as mistral_cls:
        mistral_cls.return_value = MagicMock()

        embedding_client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key")
        generated = await embedding_client.get_embeddings([])

    assert isinstance(generated, GeneratedEmbeddings)
    assert len(generated) == 0
|
||||
|
||||
|
||||
async def test_mistral_embedding_get_embeddings_with_dimensions() -> None:
    """The dimensions option is forwarded as output_dimension."""
    api_response = MagicMock()
    api_response.data = [MagicMock(embedding=[0.1, 0.2], index=0, object="embedding")]
    api_response.model = "mistral-embed"
    api_response.usage = MagicMock(prompt_tokens=5, total_tokens=5)

    with patch("agent_framework_mistral._embedding_client.Mistral") as mistral_cls:
        sdk_client = MagicMock()
        sdk_client.embeddings = MagicMock()
        sdk_client.embeddings.create_async = AsyncMock(return_value=api_response)
        mistral_cls.return_value = sdk_client

        embedding_client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key")
        requested: MistralEmbeddingOptions = {"dimensions": 512}
        generated = await embedding_client.get_embeddings(["hello"], options=requested)

    assert len(generated) == 1
    sdk_client.embeddings.create_async.assert_called_once_with(
        model="mistral-embed",
        inputs=["hello"],
        output_dimension=512,
    )
|
||||
|
||||
|
||||
async def test_mistral_embedding_get_embeddings_no_model_raises() -> None:
    """Calling without any model available raises ValueError."""
    with patch("agent_framework_mistral._embedding_client.Mistral") as mistral_cls:
        mistral_cls.return_value = MagicMock()

        embedding_client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key")
        # Clear the model after construction to reach the call-time check.
        embedding_client.model = None  # type: ignore[assignment]

        with pytest.raises(ValueError, match="model is required"):
            await embedding_client.get_embeddings(["hello"])
|
||||
|
||||
|
||||
async def test_mistral_embedding_get_embeddings_model_override() -> None:
    """A model given in options replaces the client's model."""
    api_response = MagicMock()
    api_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3], index=0, object="embedding")]
    api_response.model = "custom-embed"
    api_response.usage = MagicMock(prompt_tokens=5, total_tokens=5)

    with patch("agent_framework_mistral._embedding_client.Mistral") as mistral_cls:
        sdk_client = MagicMock()
        sdk_client.embeddings = MagicMock()
        sdk_client.embeddings.create_async = AsyncMock(return_value=api_response)
        mistral_cls.return_value = sdk_client

        embedding_client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key")
        requested: MistralEmbeddingOptions = {"model": "custom-embed"}
        generated = await embedding_client.get_embeddings(["hello"], options=requested)

    assert len(generated) == 1
    assert generated[0].model == "custom-embed"
    sdk_client.embeddings.create_async.assert_called_once_with(
        model="custom-embed",
        inputs=["hello"],
    )
|
||||
|
||||
|
||||
async def test_mistral_embedding_get_embeddings_no_usage() -> None:
    """A response without usage information yields usage=None."""
    api_response = MagicMock()
    api_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3], index=0, object="embedding")]
    api_response.model = "mistral-embed"
    api_response.usage = None

    with patch("agent_framework_mistral._embedding_client.Mistral") as mistral_cls:
        sdk_client = MagicMock()
        sdk_client.embeddings = MagicMock()
        sdk_client.embeddings.create_async = AsyncMock(return_value=api_response)
        mistral_cls.return_value = sdk_client

        embedding_client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key")
        generated = await embedding_client.get_embeddings(["hello"])

    assert len(generated) == 1
    assert generated.usage is None
|
||||
|
||||
|
||||
# region: Integration Tests
|
||||
|
||||
# Skip integration tests unless a real model name (not empty, not the
# "test-model" placeholder) and a non-empty API key are set in the environment.
skip_if_mistral_embedding_integration_tests_disabled = pytest.mark.skipif(
    not os.getenv("MISTRAL_API_KEY", "") or os.getenv("MISTRAL_EMBEDDING_MODEL", "") in ("", "test-model"),
    reason="No real Mistral embedding model or API key provided; skipping integration tests.",
)
|
||||
|
||||
|
||||
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_mistral_embedding_integration_tests_disabled
async def test_mistral_embedding_integration() -> None:
    """Exercise the client against the real Mistral AI service."""
    embedding_client = MistralEmbeddingClient()
    generated = await embedding_client.get_embeddings(["Hello, world!", "How are you?"])

    assert isinstance(generated, GeneratedEmbeddings)
    assert len(generated) == 2
    for item in generated:
        assert isinstance(item, Embedding)
        vector = item.vector
        assert isinstance(vector, list)
        assert len(vector) > 0
        assert all(isinstance(component, float) for component in vector)

    usage = generated.usage
    assert usage is not None
    assert usage["input_token_count"] is not None
    assert usage["input_token_count"] > 0
|
||||
@@ -54,7 +54,9 @@ package = false
|
||||
prerelease = "if-necessary-or-explicit"
|
||||
# Security floors for transitive deps; overrides bypass litellm[proxy]'s strict pins.
|
||||
constraint-dependencies = ["litellm>=1.83.7", "fastapi-sso>=0.19.0"]
|
||||
override-dependencies = ["mcp[ws]>=1.27.0", "uvicorn[standard]>=0.34.0"]
|
||||
# Allow opentelemetry-semantic-conventions 0.61b0 for mistralai compatibility
|
||||
# (mistralai pins <0.61 but 0.61b0 is compatible at runtime).
|
||||
override-dependencies = ["mcp[ws]>=1.27.0", "uvicorn[standard]>=0.34.0", "opentelemetry-semantic-conventions>=0.60b1"]
|
||||
environments = [
|
||||
"sys_platform == 'darwin'",
|
||||
"sys_platform == 'linux'",
|
||||
@@ -88,6 +90,7 @@ agent-framework-github-copilot = { workspace = true }
|
||||
agent-framework-hyperlight = { workspace = true }
|
||||
agent-framework-lab = { workspace = true }
|
||||
agent-framework-mem0 = { workspace = true }
|
||||
agent-framework-mistral = { workspace = true }
|
||||
agent-framework-monty = { workspace = true }
|
||||
agent-framework-ollama = { workspace = true }
|
||||
agent-framework-openai = { workspace = true }
|
||||
@@ -211,6 +214,7 @@ executionEnvironments = [
|
||||
{ root = "packages/lab/lightning/tests", reportPrivateUsage = "none" },
|
||||
{ root = "packages/lab/tau2/tests", reportPrivateUsage = "none" },
|
||||
{ root = "packages/mem0/tests", reportPrivateUsage = "none" },
|
||||
{ root = "packages/mistral/tests", reportPrivateUsage = "none" },
|
||||
{ root = "packages/ollama/tests", reportPrivateUsage = "none" },
|
||||
{ root = "packages/orchestrations/tests", reportPrivateUsage = "none" },
|
||||
{ root = "packages/purview/tests", reportPrivateUsage = "none" },
|
||||
|
||||
Generated
+1642
-1709
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user