mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Phase 2: Embedding clients for Ollama, Bedrock, and Azure AI Inference (#4207)
* Phase 2: Embedding clients for Ollama, Bedrock, and Azure AI Inference Add embedding client implementations to existing provider packages: - OllamaEmbeddingClient: Text embeddings via Ollama's embed API - BedrockEmbeddingClient: Text embeddings via Amazon Titan on Bedrock - AzureAIInferenceEmbeddingClient: Text and image embeddings via Azure AI Inference, supporting Content | str input with separate model IDs for text (AZURE_AI_INFERENCE_EMBEDDING_MODEL_ID) and image (AZURE_AI_INFERENCE_IMAGE_EMBEDDING_MODEL_ID) endpoints Additional changes: - Rename EmbeddingCoT -> EmbeddingT, EmbeddingOptionsCoT -> EmbeddingOptionsT - Add otel_provider_name passthrough to all embedding clients - Register integration pytest marker in all packages - Add lazy-loading namespace exports for Ollama and Bedrock embeddings - Add image embedding sample using Cohere-embed-v3-english - Add azure-ai-inference dependency to azure-ai package Part of #1188 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix mypy duplicate name and ruff lint issues - Rename second 'vector' variable to 'img_vector' in image embedding loop - Combine nested with statements in tests - Remove unused result assignments in tests Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * updates from feedback * Fix CI failures in embedding usage handling - Fix Azure AI embedding mypy issues by normalizing vectors to list[float], safely accumulating optional usage token fields, and filtering None entries before constructing GeneratedEmbeddings - Avoid Bandit false positive by initializing usage details as an empty dict - Update OpenAI embedding tests to assert canonical usage keys (input_token_count/total_token_count) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
e3a5b915a6
commit
6138487888
@@ -140,9 +140,10 @@ This feature ports the vector store abstractions, embedding generator abstractio
|
||||
|
||||
## Implementation Phases
|
||||
|
||||
### Phase 1: Core Embedding Abstractions & OpenAI Implementation
|
||||
### Phase 1: Core Embedding Abstractions & OpenAI Implementation ✅ DONE
|
||||
**Goal:** Establish the embedding generator abstraction and ship one working implementation.
|
||||
**Mergeable:** Yes — adds new types/protocols, no breaking changes.
|
||||
**Status:** Merged via PR #4153. Closes sub-issue #4163.
|
||||
|
||||
#### 1.1 — Embedding types in `_types.py`
|
||||
- `EmbeddingInputT` TypeVar (default `str`) — generic input type for embedding generation
|
||||
|
||||
@@ -37,6 +37,7 @@ environments = [
|
||||
|
||||
[tool.uv-dynamic-versioning]
|
||||
fallback-version = "0.0.0"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
addopts = "-ra -q -r fEX"
|
||||
@@ -46,6 +47,9 @@ filterwarnings = [
|
||||
"ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*"
|
||||
]
|
||||
timeout = 120
|
||||
markers = [
|
||||
"integration: marks tests as integration tests that require external services",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
extend = "../../pyproject.toml"
|
||||
@@ -79,6 +83,7 @@ exclude_dirs = ["tests"]
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_a2a"
|
||||
test = "pytest --cov=agent_framework_a2a --cov-report=term-missing:skip-covered tests"
|
||||
|
||||
@@ -45,6 +45,9 @@ packages = ["agent_framework_ag_ui", "agent_framework_ag_ui_examples"]
|
||||
asyncio_mode = "auto"
|
||||
testpaths = ["tests/ag_ui"]
|
||||
pythonpath = ["."]
|
||||
markers = [
|
||||
"integration: marks tests as integration tests that require external services",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
|
||||
@@ -37,6 +37,7 @@ environments = [
|
||||
|
||||
[tool.uv-dynamic-versioning]
|
||||
fallback-version = "0.0.0"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
addopts = "-ra -q -r fEX"
|
||||
@@ -46,6 +47,9 @@ filterwarnings = [
|
||||
"ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*"
|
||||
]
|
||||
timeout = 120
|
||||
markers = [
|
||||
"integration: marks tests as integration tests that require external services",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
extend = "../../pyproject.toml"
|
||||
@@ -80,6 +84,7 @@ exclude_dirs = ["tests"]
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_anthropic"
|
||||
test = "pytest --cov=agent_framework_anthropic --cov-report=term-missing:skip-covered -n auto --dist worksteal tests"
|
||||
|
||||
@@ -47,6 +47,9 @@ filterwarnings = [
|
||||
"ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*"
|
||||
]
|
||||
timeout = 120
|
||||
markers = [
|
||||
"integration: marks tests as integration tests that require external services",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
extend = "../../pyproject.toml"
|
||||
@@ -82,6 +85,7 @@ exclude_dirs = ["tests"]
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_azure_ai_search"
|
||||
test = "pytest --cov=agent_framework_azure_ai_search --cov-report=term-missing:skip-covered tests"
|
||||
|
||||
@@ -5,6 +5,12 @@ import importlib.metadata
|
||||
from ._agent_provider import AzureAIAgentsProvider
|
||||
from ._chat_client import AzureAIAgentClient, AzureAIAgentOptions
|
||||
from ._client import AzureAIClient, AzureAIProjectAgentOptions, RawAzureAIClient
|
||||
from ._embedding_client import (
|
||||
AzureAIInferenceEmbeddingClient,
|
||||
AzureAIInferenceEmbeddingOptions,
|
||||
AzureAIInferenceEmbeddingSettings,
|
||||
RawAzureAIInferenceEmbeddingClient,
|
||||
)
|
||||
from ._foundry_memory_provider import FoundryMemoryProvider
|
||||
from ._project_provider import AzureAIProjectAgentProvider
|
||||
from ._shared import AzureAISettings
|
||||
@@ -19,10 +25,14 @@ __all__ = [
|
||||
"AzureAIAgentOptions",
|
||||
"AzureAIAgentsProvider",
|
||||
"AzureAIClient",
|
||||
"AzureAIInferenceEmbeddingClient",
|
||||
"AzureAIInferenceEmbeddingOptions",
|
||||
"AzureAIInferenceEmbeddingSettings",
|
||||
"AzureAIProjectAgentOptions",
|
||||
"AzureAIProjectAgentProvider",
|
||||
"AzureAISettings",
|
||||
"FoundryMemoryProvider",
|
||||
"RawAzureAIClient",
|
||||
"RawAzureAIInferenceEmbeddingClient",
|
||||
"__version__",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,396 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Sequence
|
||||
from contextlib import suppress
|
||||
from typing import Any, ClassVar, Generic, TypedDict
|
||||
|
||||
from agent_framework import (
|
||||
BaseEmbeddingClient,
|
||||
Content,
|
||||
Embedding,
|
||||
EmbeddingGenerationOptions,
|
||||
GeneratedEmbeddings,
|
||||
UsageDetails,
|
||||
load_settings,
|
||||
)
|
||||
from agent_framework.observability import EmbeddingTelemetryLayer
|
||||
from azure.ai.inference.aio import EmbeddingsClient, ImageEmbeddingsClient
|
||||
from azure.ai.inference.models import ImageEmbeddingInput
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypeVar # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
logger = logging.getLogger("agent_framework.azure_ai")
|
||||
|
||||
_IMAGE_MEDIA_PREFIXES = ("image/",)
|
||||
|
||||
|
||||
class AzureAIInferenceEmbeddingOptions(EmbeddingGenerationOptions, total=False):
    """Per-request options understood by the Azure AI Inference embedding clients.

    Adds Azure AI Inference-specific keys on top of ``EmbeddingGenerationOptions``.
    All keys are optional (``total=False``).

    Examples:
        .. code-block:: python

            from agent_framework_azure_ai import AzureAIInferenceEmbeddingOptions

            options: AzureAIInferenceEmbeddingOptions = {
                "model_id": "text-embedding-3-small",
                "dimensions": 1536,
                "input_type": "document",
                "encoding_format": "float",
            }
    """

    input_type: str
    """Hint describing the kind of input. Common values: ``"text"``, ``"query"``, ``"document"``."""

    image_model_id: str
    """Model to use for image embeddings. When absent, the client's ``image_model_id`` is used."""

    encoding_format: str
    """Encoding of the returned embeddings.

    Common values: ``"float"``, ``"base64"``, ``"int8"``, ``"uint8"``,
    ``"binary"``, ``"ubinary"``.
    """

    extra_parameters: dict[str, Any]
    """Extra model-specific parameters forwarded to the API unchanged."""
|
||||
|
||||
|
||||
# Options type parameter for the embedding clients below. The default is given as a
# string forward reference to ``AzureAIInferenceEmbeddingOptions``.
AzureAIInferenceEmbeddingOptionsT = TypeVar(
    "AzureAIInferenceEmbeddingOptionsT",
    covariant=True,
    default="AzureAIInferenceEmbeddingOptions",
    bound=TypedDict,  # type: ignore[valid-type]
)
|
||||
|
||||
|
||||
class AzureAIInferenceEmbeddingSettings(TypedDict, total=False):
    """Configuration values for the Azure AI Inference embedding clients.

    Every key is optional (``total=False``) and may also be ``None``.
    """

    # Azure AI Inference endpoint URL.
    endpoint: str | None
    # API key used for authentication.
    api_key: str | None
    # Model deployment name used for text embeddings.
    embedding_model_id: str | None
    # Model deployment name used for image embeddings.
    image_embedding_model_id: str | None
|
||||
|
||||
|
||||
class RawAzureAIInferenceEmbeddingClient(
    BaseEmbeddingClient[Content | str, list[float], AzureAIInferenceEmbeddingOptionsT],
    Generic[AzureAIInferenceEmbeddingOptionsT],
):
    """Raw Azure AI Inference embedding client without telemetry.

    Accepts both text (``str``) and image (``Content``) inputs. Text and image
    inputs within a single batch are separated and dispatched to
    ``EmbeddingsClient`` and ``ImageEmbeddingsClient`` respectively. Results
    are reassembled in the original input order.

    Keyword Args:
        model_id: The text embedding model deployment name (e.g. "text-embedding-3-small").
            Can also be set via environment variable AZURE_AI_INFERENCE_EMBEDDING_MODEL_ID.
        image_model_id: The image embedding model deployment name (e.g. "Cohere-embed-v3-english").
            Can also be set via environment variable AZURE_AI_INFERENCE_IMAGE_EMBEDDING_MODEL_ID.
            Falls back to ``model_id`` if not provided.
        endpoint: The Azure AI Inference endpoint URL.
            Can also be set via environment variable AZURE_AI_INFERENCE_ENDPOINT.
        api_key: API key for authentication.
            Can also be set via environment variable AZURE_AI_INFERENCE_API_KEY.
        text_client: Optional pre-configured ``EmbeddingsClient``.
        image_client: Optional pre-configured ``ImageEmbeddingsClient``.
        credential: Optional ``AzureKeyCredential`` or token credential. If not provided,
            one is created from ``api_key``.
        env_file_path: Path to .env file for settings.
        env_file_encoding: Encoding for .env file.
    """

    def __init__(
        self,
        *,
        model_id: str | None = None,
        image_model_id: str | None = None,
        endpoint: str | None = None,
        api_key: str | None = None,
        text_client: EmbeddingsClient | None = None,
        image_client: ImageEmbeddingsClient | None = None,
        credential: AzureKeyCredential | None = None,
        env_file_path: str | None = None,
        env_file_encoding: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize a raw Azure AI Inference embedding client.

        Settings are resolved through ``load_settings`` with the
        ``AZURE_AI_INFERENCE_`` prefix; ``endpoint`` and ``embedding_model_id``
        are declared as required fields. Remaining ``kwargs`` are forwarded to
        the base class initializer.

        Raises:
            ValueError: If no ``api_key``, no ``credential`` and no pre-configured
                client is available.
        """
        settings = load_settings(
            AzureAIInferenceEmbeddingSettings,
            env_prefix="AZURE_AI_INFERENCE_",
            required_fields=["endpoint", "embedding_model_id"],
            endpoint=endpoint,
            api_key=api_key,
            embedding_model_id=model_id,
            image_embedding_model_id=image_model_id,
            env_file_path=env_file_path,
            env_file_encoding=env_file_encoding,
        )

        self.model_id = settings["embedding_model_id"]  # type: ignore[reportTypedDictNotRequiredAccess]
        # Image requests reuse the text model unless a dedicated image model is configured.
        self.image_model_id: str = settings.get("image_embedding_model_id") or self.model_id  # type: ignore[assignment]
        resolved_endpoint = settings["endpoint"]  # type: ignore[reportTypedDictNotRequiredAccess]

        # An explicitly passed credential wins; otherwise derive one from the API key.
        if credential is None and settings.get("api_key"):
            credential = AzureKeyCredential(settings["api_key"])  # type: ignore[arg-type]

        # NOTE(review): this guard only fires when *both* clients are missing. If exactly
        # one pre-configured client is supplied without a credential, the other SDK client
        # below is still constructed with ``credential=None`` - confirm that is intended.
        if credential is None and text_client is None and image_client is None:
            raise ValueError("Either 'api_key', 'credential', or pre-configured client(s) must be provided.")

        self._text_client = text_client or EmbeddingsClient(
            endpoint=resolved_endpoint,  # type: ignore[arg-type]
            credential=credential,  # type: ignore[arg-type]
        )
        self._image_client = image_client or ImageEmbeddingsClient(
            endpoint=resolved_endpoint,  # type: ignore[arg-type]
            credential=credential,  # type: ignore[arg-type]
        )
        self._endpoint = resolved_endpoint
        super().__init__(**kwargs)

    async def close(self) -> None:
        """Close the underlying SDK clients and release resources.

        Closing is best-effort: an error raised by one client is suppressed so
        that the other client is still closed.
        """
        with suppress(Exception):
            await self._text_client.close()
        with suppress(Exception):
            await self._image_client.close()

    async def __aenter__(self) -> RawAzureAIInferenceEmbeddingClient[AzureAIInferenceEmbeddingOptionsT]:
        """Enter the async context manager."""
        return self

    async def __aexit__(self, *args: Any) -> None:
        """Exit the async context manager and close clients."""
        await self.close()

    def service_url(self) -> str:
        """Get the URL of the service (empty string when no endpoint is set)."""
        return self._endpoint or ""

    async def get_embeddings(
        self,
        values: Sequence[Content | str],
        *,
        options: AzureAIInferenceEmbeddingOptionsT | None = None,
    ) -> GeneratedEmbeddings[list[float]]:
        """Generate embeddings for text and/or image inputs.

        Text inputs (``str`` or ``Content`` with ``type="text"``) are sent to the
        text embeddings endpoint. Image inputs (``Content`` with an image
        ``media_type``) are sent to the image embeddings endpoint. Results are
        returned in the same order as the input.

        If the service returns a different number of embeddings than inputs
        sent, a warning is logged; inputs without a matching embedding are
        omitted from the result and surplus embeddings are ignored.

        Args:
            values: A sequence of text strings or ``Content`` instances.
            options: Optional embedding generation options.

        Returns:
            Generated embeddings with usage metadata.

        Raises:
            ValueError: If model_id is not provided or an unsupported content type is encountered.
        """
        if not values:
            return GeneratedEmbeddings([], options=options)  # type: ignore[reportReturnType]

        opts: dict[str, Any] = dict(options) if options else {}
        text_items, image_items = self._partition_inputs(values)

        # Build shared API kwargs (without model, which differs per client).
        # Falsy option values are skipped so the SDK defaults apply.
        common_kwargs: dict[str, Any] = {}
        for option_key, api_key in (
            ("dimensions", "dimensions"),
            ("encoding_format", "encoding_format"),
            ("input_type", "input_type"),
            ("extra_parameters", "model_extras"),
        ):
            if option_value := opts.get(option_key):
                common_kwargs[api_key] = option_value

        # One slot per input so results can be written back at their original index.
        embeddings: list[Embedding[list[float]] | None] = [None] * len(values)
        usage_details: UsageDetails = {}

        # Embed text inputs.
        if text_items:
            if not (text_model := opts.get("model_id") or self.model_id):
                raise ValueError("A model_id is required, either in the client or options, for text inputs.")
            text_response = await self._text_client.embed(
                input=[text for _, text in text_items],
                model=text_model,
                **common_kwargs,
            )
            self._record_response(text_response, text_items, text_model, embeddings, usage_details)

        # Embed image inputs.
        if image_items:
            if not (image_model := opts.get("image_model_id") or self.image_model_id):
                raise ValueError("An image_model_id is required, either in the client or options, for image inputs.")
            image_response = await self._image_client.embed(
                input=[image for _, image in image_items],
                model=image_model,
                **common_kwargs,
            )
            self._record_response(image_response, image_items, image_model, embeddings, usage_details)

        return GeneratedEmbeddings(
            [embedding for embedding in embeddings if embedding is not None],
            options=options,
            usage=usage_details,
        )  # type: ignore[reportReturnType]

    @staticmethod
    def _partition_inputs(
        values: Sequence[Content | str],
    ) -> tuple[list[tuple[int, str]], list[tuple[int, ImageEmbeddingInput]]]:
        """Split inputs into text and image items, each paired with its original index."""
        text_items: list[tuple[int, str]] = []
        image_items: list[tuple[int, ImageEmbeddingInput]] = []

        for idx, value in enumerate(values):
            if isinstance(value, str):
                text_items.append((idx, value))
                continue
            if not isinstance(value, Content):
                raise ValueError(f"Unsupported input type {type(value).__name__} at index {idx}.")

            if value.type == "text" and value.text is not None:
                text_items.append((idx, value.text))
            elif (
                value.type in ("data", "uri")
                and value.media_type
                # ``str.startswith`` accepts the whole tuple, so every configured prefix is honored.
                and value.media_type.startswith(_IMAGE_MEDIA_PREFIXES)
            ):
                if not value.uri:
                    raise ValueError(f"Image Content at index {idx} has no URI.")
                image_items.append((idx, ImageEmbeddingInput(image=value.uri, text=value.text)))
            else:
                raise ValueError(
                    f"Unsupported Content type '{value.type}' with media_type "
                    f"'{value.media_type}' at index {idx}. Expected text content or "
                    f"image content (media_type starting with 'image/')."
                )

        return text_items, image_items

    @staticmethod
    def _record_response(
        response: Any,
        items: Sequence[tuple[int, Any]],
        requested_model: str,
        embeddings: list[Embedding[list[float]] | None],
        usage_details: UsageDetails,
    ) -> None:
        """Write one SDK response into the result slots and accumulate its token usage."""
        data = list(response.data)
        if len(data) != len(items):
            logger.warning(
                "Azure AI Inference returned %d embedding(s) for %d input(s); the result may be incomplete.",
                len(data),
                len(items),
            )

        # Response rows are paired with inputs by position; ``zip`` stops at the shorter side.
        for (original_idx, _), item in zip(items, data, strict=False):
            vector: list[float] = [float(v) for v in item.embedding]
            embeddings[original_idx] = Embedding(
                vector=vector,
                dimensions=len(vector),
                model_id=response.model or requested_model,
            )

        if response.usage:
            usage_details["input_token_count"] = (usage_details.get("input_token_count") or 0) + (
                response.usage.prompt_tokens or 0
            )
            # ``completion_tokens`` is read defensively because the usage object may not define it.
            usage_details["output_token_count"] = (usage_details.get("output_token_count") or 0) + (
                getattr(response.usage, "completion_tokens", 0) or 0
            )
|
||||
|
||||
|
||||
class AzureAIInferenceEmbeddingClient(
    EmbeddingTelemetryLayer[Content | str, list[float], AzureAIInferenceEmbeddingOptionsT],
    RawAzureAIInferenceEmbeddingClient[AzureAIInferenceEmbeddingOptionsT],
    Generic[AzureAIInferenceEmbeddingOptionsT],
):
    """Telemetry-enabled Azure AI Inference embedding client.

    Combines ``EmbeddingTelemetryLayer`` with
    ``RawAzureAIInferenceEmbeddingClient``. A single client handles text and
    image inputs: pass plain strings or ``Content`` instances created with
    ``Content.from_text()`` or ``Content.from_data()``.

    Keyword Args:
        model_id: The text embedding model deployment name (e.g. "text-embedding-3-small").
            Can also be set via environment variable AZURE_AI_INFERENCE_EMBEDDING_MODEL_ID.
        image_model_id: The image embedding model deployment name
            (e.g. "Cohere-embed-v3-english"). Can also be set via environment variable
            AZURE_AI_INFERENCE_IMAGE_EMBEDDING_MODEL_ID. Falls back to ``model_id``.
        endpoint: The Azure AI Inference endpoint URL.
            Can also be set via environment variable AZURE_AI_INFERENCE_ENDPOINT.
        api_key: API key for authentication.
            Can also be set via environment variable AZURE_AI_INFERENCE_API_KEY.
        text_client: Optional pre-configured ``EmbeddingsClient``.
        image_client: Optional pre-configured ``ImageEmbeddingsClient``.
        credential: Optional ``AzureKeyCredential`` or token credential.
        otel_provider_name: Override for the OpenTelemetry provider name.
        env_file_path: Path to .env file for settings.
        env_file_encoding: Encoding for .env file.

    Examples:
        .. code-block:: python

            from agent_framework_azure_ai import AzureAIInferenceEmbeddingClient

            # Using environment variables
            # Set AZURE_AI_INFERENCE_ENDPOINT=https://your-endpoint.inference.ai.azure.com
            # Set AZURE_AI_INFERENCE_API_KEY=your-key
            # Set AZURE_AI_INFERENCE_EMBEDDING_MODEL_ID=text-embedding-3-small
            # Set AZURE_AI_INFERENCE_IMAGE_EMBEDDING_MODEL_ID=Cohere-embed-v3-english
            client = AzureAIInferenceEmbeddingClient()

            # Text embeddings
            result = await client.get_embeddings(["Hello, world!"])

            # Image embeddings
            from agent_framework import Content

            image = Content.from_data(data=image_bytes, media_type="image/png")
            result = await client.get_embeddings([image])

            # Mixed text and image
            result = await client.get_embeddings(["hello", image])
    """

    OTEL_PROVIDER_NAME: ClassVar[str] = "azure.ai.inference"  # type: ignore[reportIncompatibleVariableOverride, misc]

    def __init__(
        self,
        *,
        model_id: str | None = None,
        image_model_id: str | None = None,
        endpoint: str | None = None,
        api_key: str | None = None,
        text_client: EmbeddingsClient | None = None,
        image_client: ImageEmbeddingsClient | None = None,
        credential: AzureKeyCredential | None = None,
        otel_provider_name: str | None = None,
        env_file_path: str | None = None,
        env_file_encoding: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize an Azure AI Inference embedding client.

        Every argument, including ``otel_provider_name`` and any extra
        ``kwargs``, is forwarded unchanged to the parent initializers.
        """
        super().__init__(
            # Connection and authentication.
            endpoint=endpoint,
            api_key=api_key,
            credential=credential,
            # Model selection.
            model_id=model_id,
            image_model_id=image_model_id,
            # Pre-configured SDK clients.
            text_client=text_client,
            image_client=image_client,
            # Telemetry and settings loading.
            otel_provider_name=otel_provider_name,
            env_file_path=env_file_path,
            env_file_encoding=env_file_encoding,
            **kwargs,
        )
|
||||
@@ -25,6 +25,7 @@ classifiers = [
|
||||
dependencies = [
|
||||
"agent-framework-core>=1.0.0rc1",
|
||||
"azure-ai-agents == 1.2.0b5",
|
||||
"azure-ai-inference>=1.0.0b9",
|
||||
"aiohttp",
|
||||
]
|
||||
|
||||
@@ -38,6 +39,7 @@ environments = [
|
||||
|
||||
[tool.uv-dynamic-versioning]
|
||||
fallback-version = "0.0.0"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
addopts = "-ra -q -r fEX"
|
||||
@@ -45,6 +47,9 @@ asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
filterwarnings = []
|
||||
timeout = 120
|
||||
markers = [
|
||||
"integration: marks tests as integration tests that require external services",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
extend = "../../pyproject.toml"
|
||||
@@ -78,6 +83,7 @@ exclude_dirs = ["tests"]
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_azure_ai"
|
||||
test = "pytest --cov=agent_framework_azure_ai --cov-report=term-missing:skip-covered tests"
|
||||
|
||||
@@ -0,0 +1,316 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from agent_framework import Content
|
||||
|
||||
from agent_framework_azure_ai import (
|
||||
AzureAIInferenceEmbeddingClient,
|
||||
AzureAIInferenceEmbeddingOptions,
|
||||
RawAzureAIInferenceEmbeddingClient,
|
||||
)
|
||||
|
||||
|
||||
def _make_embed_response(
|
||||
embeddings: Sequence[list[float]],
|
||||
model: str = "test-model",
|
||||
prompt_tokens: int = 10,
|
||||
) -> MagicMock:
|
||||
"""Create a mock EmbeddingsResult."""
|
||||
data = []
|
||||
for emb in embeddings:
|
||||
item = MagicMock()
|
||||
item.embedding = emb
|
||||
data.append(item)
|
||||
|
||||
usage = MagicMock()
|
||||
usage.prompt_tokens = prompt_tokens
|
||||
usage.completion_tokens = 0
|
||||
|
||||
result = MagicMock()
|
||||
result.data = data
|
||||
result.model = model
|
||||
result.usage = usage
|
||||
return result
|
||||
|
||||
|
||||
@pytest.fixture
def mock_text_client() -> AsyncMock:
    """Provide a mock text ``EmbeddingsClient`` whose ``embed`` returns one 3-dimensional vector."""
    canned_response = _make_embed_response([[0.1, 0.2, 0.3]])
    sdk_client = AsyncMock()
    sdk_client.embed = AsyncMock(return_value=canned_response)
    return sdk_client
|
||||
|
||||
|
||||
@pytest.fixture
def mock_image_client() -> AsyncMock:
    """Provide a mock ``ImageEmbeddingsClient`` whose ``embed`` returns one 3-dimensional vector."""
    canned_response = _make_embed_response([[0.4, 0.5, 0.6]])
    sdk_client = AsyncMock()
    sdk_client.embed = AsyncMock(return_value=canned_response)
    return sdk_client
|
||||
|
||||
|
||||
@pytest.fixture
def raw_client(mock_text_client: AsyncMock, mock_image_client: AsyncMock) -> RawAzureAIInferenceEmbeddingClient[Any]:
    """Provide a ``RawAzureAIInferenceEmbeddingClient`` wired to the mocked SDK clients."""
    # Both SDK clients are injected, so the client under test does not construct real ones.
    embedding_client = RawAzureAIInferenceEmbeddingClient(
        text_client=mock_text_client,
        image_client=mock_image_client,
        model_id="test-model",
        endpoint="https://test.inference.ai.azure.com",
        api_key="test-key",
    )
    return embedding_client
|
||||
|
||||
|
||||
@pytest.fixture
def client(mock_text_client: AsyncMock, mock_image_client: AsyncMock) -> AzureAIInferenceEmbeddingClient[Any]:
    """Provide an ``AzureAIInferenceEmbeddingClient`` wired to the mocked SDK clients."""
    # Both SDK clients are injected, so the client under test does not construct real ones.
    embedding_client = AzureAIInferenceEmbeddingClient(
        text_client=mock_text_client,
        image_client=mock_image_client,
        model_id="test-model",
        endpoint="https://test.inference.ai.azure.com",
        api_key="test-key",
    )
    return embedding_client
|
||||
|
||||
|
||||
class TestRawAzureAIInferenceEmbeddingClient:
|
||||
"""Tests for the raw Azure AI Inference embedding client."""
|
||||
|
||||
async def test_text_embeddings(
|
||||
self, raw_client: RawAzureAIInferenceEmbeddingClient[Any], mock_text_client: AsyncMock
|
||||
) -> None:
|
||||
"""Text inputs are dispatched to the text client."""
|
||||
result = await raw_client.get_embeddings(["hello", "world"])
|
||||
assert result is not None
|
||||
call_kwargs = mock_text_client.embed.call_args
|
||||
assert call_kwargs.kwargs["input"] == ["hello", "world"]
|
||||
assert call_kwargs.kwargs["model"] == "test-model"
|
||||
|
||||
async def test_text_content_embeddings(
|
||||
self, raw_client: RawAzureAIInferenceEmbeddingClient[Any], mock_text_client: AsyncMock
|
||||
) -> None:
|
||||
"""Content.from_text() inputs are dispatched to the text client."""
|
||||
text_content = Content.from_text("hello")
|
||||
await raw_client.get_embeddings([text_content])
|
||||
|
||||
mock_text_client.embed.assert_called_once()
|
||||
call_kwargs = mock_text_client.embed.call_args
|
||||
assert call_kwargs.kwargs["input"] == ["hello"]
|
||||
|
||||
async def test_image_content_embeddings(
|
||||
self, raw_client: RawAzureAIInferenceEmbeddingClient[Any], mock_image_client: AsyncMock
|
||||
) -> None:
|
||||
"""Image Content inputs are dispatched to the image client."""
|
||||
image_content = Content.from_data(data=b"\x89PNG", media_type="image/png")
|
||||
await raw_client.get_embeddings([image_content])
|
||||
|
||||
mock_image_client.embed.assert_called_once()
|
||||
call_kwargs = mock_image_client.embed.call_args
|
||||
image_inputs = call_kwargs.kwargs["input"]
|
||||
assert len(image_inputs) == 1
|
||||
assert image_inputs[0].image == image_content.uri
|
||||
|
||||
async def test_mixed_text_and_image(
|
||||
self,
|
||||
raw_client: RawAzureAIInferenceEmbeddingClient[Any],
|
||||
mock_text_client: AsyncMock,
|
||||
mock_image_client: AsyncMock,
|
||||
) -> None:
|
||||
"""Mixed text and image inputs are dispatched to the correct clients."""
|
||||
mock_text_client.embed.return_value = _make_embed_response([[0.1, 0.2]])
|
||||
mock_image_client.embed.return_value = _make_embed_response([[0.3, 0.4]])
|
||||
|
||||
image = Content.from_data(data=b"\x89PNG", media_type="image/png")
|
||||
await raw_client.get_embeddings(["hello", image, "world"])
|
||||
|
||||
# Text client gets "hello" and "world"
|
||||
text_call = mock_text_client.embed.call_args
|
||||
assert text_call.kwargs["input"] == ["hello", "world"]
|
||||
|
||||
# Image client gets the image
|
||||
image_call = mock_image_client.embed.call_args
|
||||
assert len(image_call.kwargs["input"]) == 1
|
||||
|
||||
async def test_empty_input(self, raw_client: RawAzureAIInferenceEmbeddingClient[Any]) -> None:
|
||||
"""Empty input returns empty result."""
|
||||
result = await raw_client.get_embeddings([])
|
||||
assert len(result) == 0
|
||||
|
||||
async def test_options_passed_through(
|
||||
self, raw_client: RawAzureAIInferenceEmbeddingClient[Any], mock_text_client: AsyncMock
|
||||
) -> None:
|
||||
"""Options are passed through to the SDK."""
|
||||
options: AzureAIInferenceEmbeddingOptions = {
|
||||
"dimensions": 512,
|
||||
"input_type": "document",
|
||||
"encoding_format": "float",
|
||||
}
|
||||
await raw_client.get_embeddings(["hello"], options=options)
|
||||
|
||||
call_kwargs = mock_text_client.embed.call_args
|
||||
assert call_kwargs.kwargs["dimensions"] == 512
|
||||
assert call_kwargs.kwargs["input_type"] == "document"
|
||||
assert call_kwargs.kwargs["encoding_format"] == "float"
|
||||
|
||||
async def test_model_override_in_options(
|
||||
self, raw_client: RawAzureAIInferenceEmbeddingClient[Any], mock_text_client: AsyncMock
|
||||
) -> None:
|
||||
"""model_id in options overrides the default."""
|
||||
options: AzureAIInferenceEmbeddingOptions = {"model_id": "custom-model"}
|
||||
await raw_client.get_embeddings(["hello"], options=options)
|
||||
|
||||
call_kwargs = mock_text_client.embed.call_args
|
||||
assert call_kwargs.kwargs["model"] == "custom-model"
|
||||
|
||||
async def test_unsupported_content_type_raises(self, raw_client: RawAzureAIInferenceEmbeddingClient[Any]) -> None:
|
||||
"""Non-text, non-image Content raises ValueError."""
|
||||
error_content = Content("error", message="fail")
|
||||
with pytest.raises(ValueError, match="Unsupported Content type"):
|
||||
await raw_client.get_embeddings([error_content])
|
||||
|
||||
async def test_usage_metadata(
|
||||
self, raw_client: RawAzureAIInferenceEmbeddingClient[Any], mock_text_client: AsyncMock
|
||||
) -> None:
|
||||
"""Usage metadata is populated from the response."""
|
||||
mock_text_client.embed.return_value = _make_embed_response([[0.1, 0.2]], prompt_tokens=42)
|
||||
result = await raw_client.get_embeddings(["hello"])
|
||||
assert result.usage is not None
|
||||
assert result.usage["input_token_count"] == 42
|
||||
|
||||
def test_service_url(self, raw_client: RawAzureAIInferenceEmbeddingClient[Any]) -> None:
    """``service_url()`` reports the endpoint the client was configured with."""
    expected_endpoint = "https://test.inference.ai.azure.com"
    assert raw_client.service_url() == expected_endpoint
|
||||
|
||||
def test_settings_from_env(self) -> None:
    """Settings are loaded from environment variables.

    When no image-specific model variable is present, ``image_model_id`` is
    expected to fall back to the text ``model_id``.
    """
    with (
        patch.dict(
            os.environ,
            {
                "AZURE_AI_INFERENCE_ENDPOINT": "https://env.inference.ai.azure.com",
                "AZURE_AI_INFERENCE_API_KEY": "env-key",
                "AZURE_AI_INFERENCE_EMBEDDING_MODEL_ID": "env-model",
            },
        ),
        patch("agent_framework_azure_ai._embedding_client.EmbeddingsClient"),
        patch("agent_framework_azure_ai._embedding_client.ImageEmbeddingsClient"),
    ):
        # patch.dict only adds keys; an image model variable already present in the
        # real environment would defeat the fallback assertion below. Remove it here;
        # patch.dict restores the original environment when the context exits.
        os.environ.pop("AZURE_AI_INFERENCE_IMAGE_EMBEDDING_MODEL_ID", None)

        client = RawAzureAIInferenceEmbeddingClient()
        assert client.model_id == "env-model"
        assert client.image_model_id == "env-model"  # falls back to model_id
|
||||
|
||||
def test_image_model_id_from_env(self) -> None:
    """A dedicated environment variable supplies ``image_model_id`` independently of ``model_id``."""
    env_overrides = {
        "AZURE_AI_INFERENCE_ENDPOINT": "https://env.inference.ai.azure.com",
        "AZURE_AI_INFERENCE_API_KEY": "env-key",
        "AZURE_AI_INFERENCE_EMBEDDING_MODEL_ID": "text-model",
        "AZURE_AI_INFERENCE_IMAGE_EMBEDDING_MODEL_ID": "image-model",
    }
    with (
        patch.dict(os.environ, env_overrides),
        patch("agent_framework_azure_ai._embedding_client.EmbeddingsClient"),
        patch("agent_framework_azure_ai._embedding_client.ImageEmbeddingsClient"),
    ):
        configured = RawAzureAIInferenceEmbeddingClient()
        assert configured.model_id == "text-model"
        assert configured.image_model_id == "image-model"
|
||||
|
||||
def test_image_model_id_explicit(self, mock_text_client: AsyncMock, mock_image_client: AsyncMock) -> None:
    """An explicitly passed ``image_model_id`` is kept separately from ``model_id``."""
    init_kwargs = {
        "model_id": "text-model",
        "image_model_id": "image-model",
        "endpoint": "https://test.inference.ai.azure.com",
        "api_key": "test-key",
        "text_client": mock_text_client,
        "image_client": mock_image_client,
    }
    configured = RawAzureAIInferenceEmbeddingClient(**init_kwargs)

    assert configured.model_id == "text-model"
    assert configured.image_model_id == "image-model"
|
||||
|
||||
async def test_image_model_id_sent_to_image_client(
    self, mock_text_client: AsyncMock, mock_image_client: AsyncMock
) -> None:
    """Embedding image content calls the image client with ``image_model_id`` as the model."""
    embedding_client = RawAzureAIInferenceEmbeddingClient(
        model_id="text-model",
        image_model_id="image-model",
        endpoint="https://test.inference.ai.azure.com",
        api_key="test-key",
        text_client=mock_text_client,
        image_client=mock_image_client,
    )
    png_payload = Content.from_data(data=b"\x89PNG", media_type="image/png")

    await embedding_client.get_embeddings([png_payload])

    forwarded = mock_image_client.embed.call_args.kwargs
    assert forwarded["model"] == "image-model"
|
||||
|
||||
|
||||
class TestAzureAIInferenceEmbeddingClient:
    """Unit tests for ``AzureAIInferenceEmbeddingClient``, the telemetry-enabled client."""

    async def test_text_embeddings(
        self, client: AzureAIInferenceEmbeddingClient[Any], mock_text_client: AsyncMock
    ) -> None:
        """A single text input yields one embedding with the mocked vector."""
        generated = await client.get_embeddings(["hello"])

        assert len(generated) == 1
        first = generated[0]
        assert first.vector == [0.1, 0.2, 0.3]

    async def test_otel_provider_name_default(self) -> None:
        """The class-level OTEL provider name is ``azure.ai.inference``."""
        default_name = AzureAIInferenceEmbeddingClient.OTEL_PROVIDER_NAME
        assert default_name == "azure.ai.inference"

    async def test_otel_provider_name_override(self, mock_text_client: AsyncMock, mock_image_client: AsyncMock) -> None:
        """``otel_provider_name`` passed to the constructor is exposed on the instance."""
        init_kwargs = {
            "model_id": "test-model",
            "endpoint": "https://test.inference.ai.azure.com",
            "api_key": "test-key",
            "text_client": mock_text_client,
            "image_client": mock_image_client,
            "otel_provider_name": "custom-provider",
        }
        overridden = AzureAIInferenceEmbeddingClient(**init_kwargs)

        assert overridden.otel_provider_name == "custom-provider"
|
||||
|
||||
|
||||
_SKIP_REASON = "Azure AI Inference integration tests disabled"
|
||||
|
||||
|
||||
def _integration_tests_enabled() -> bool:
    """Return True only when endpoint, API key and embedding model variables are all non-empty."""
    required_variables = (
        "AZURE_AI_INFERENCE_ENDPOINT",
        "AZURE_AI_INFERENCE_API_KEY",
        "AZURE_AI_INFERENCE_EMBEDDING_MODEL_ID",
    )
    for variable in required_variables:
        # A missing variable and an empty string are both treated as "not configured".
        if not os.environ.get(variable):
            return False
    return True
|
||||
|
||||
|
||||
# Evaluated once at import time: tests carrying this marker are skipped unless the
# Azure AI Inference environment variables checked by _integration_tests_enabled() are set.
skip_if_azure_ai_inference_integration_tests_disabled = pytest.mark.skipif(
    not _integration_tests_enabled(), reason=_SKIP_REASON
)
|
||||
|
||||
|
||||
class TestAzureAIInferenceEmbeddingIntegration:
    """Tests that call a live Azure AI Inference endpoint; skipped when it is not configured."""

    @pytest.mark.flaky
    @pytest.mark.integration
    @skip_if_azure_ai_inference_integration_tests_disabled
    async def test_text_embedding_live(self) -> None:
        """One text input produces one non-empty embedding that reports a model id."""
        live_client = AzureAIInferenceEmbeddingClient()

        generated = await live_client.get_embeddings(["Hello, world!"])

        assert len(generated) == 1
        first = generated[0]
        assert len(first.vector) > 0
        assert first.model_id is not None
|
||||
@@ -41,6 +41,7 @@ environments = [
|
||||
|
||||
[tool.uv-dynamic-versioning]
|
||||
fallback-version = "0.0.0"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
pythonpath = ["tests/integration_tests"]
|
||||
@@ -88,6 +89,7 @@ exclude_dirs = ["tests"]
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_azurefunctions"
|
||||
test = "pytest --cov=agent_framework_azurefunctions --cov-report=term-missing:skip-covered tests"
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import importlib.metadata
|
||||
|
||||
from ._chat_client import BedrockChatClient, BedrockChatOptions, BedrockGuardrailConfig, BedrockSettings
|
||||
from ._embedding_client import BedrockEmbeddingClient, BedrockEmbeddingOptions, BedrockEmbeddingSettings
|
||||
|
||||
try:
|
||||
__version__ = importlib.metadata.version(__name__)
|
||||
@@ -12,6 +13,9 @@ except importlib.metadata.PackageNotFoundError:
|
||||
__all__ = [
|
||||
"BedrockChatClient",
|
||||
"BedrockChatOptions",
|
||||
"BedrockEmbeddingClient",
|
||||
"BedrockEmbeddingOptions",
|
||||
"BedrockEmbeddingSettings",
|
||||
"BedrockGuardrailConfig",
|
||||
"BedrockSettings",
|
||||
"__version__",
|
||||
|
||||
@@ -0,0 +1,292 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, ClassVar, Generic, TypedDict
|
||||
|
||||
from agent_framework import (
|
||||
AGENT_FRAMEWORK_USER_AGENT,
|
||||
BaseEmbeddingClient,
|
||||
Embedding,
|
||||
EmbeddingGenerationOptions,
|
||||
GeneratedEmbeddings,
|
||||
SecretString,
|
||||
UsageDetails,
|
||||
load_settings,
|
||||
)
|
||||
from agent_framework.observability import EmbeddingTelemetryLayer
|
||||
from boto3.session import Session as Boto3Session
|
||||
from botocore.client import BaseClient
|
||||
from botocore.config import Config as BotoConfig
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypeVar # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
logger = logging.getLogger("agent_framework.bedrock")
|
||||
DEFAULT_REGION = "us-east-1"
|
||||
|
||||
|
||||
class BedrockEmbeddingSettings(TypedDict, total=False):
    """Settings for the Bedrock embedding client; every key is optional."""

    # AWS region used for the Bedrock runtime client.
    region: str | None
    # Identifier of the embedding model to invoke.
    embedding_model_id: str | None
    # Credentials for manual injection, wrapped as secrets.
    access_key: SecretString | None
    secret_key: SecretString | None
    session_token: SecretString | None
|
||||
|
||||
|
||||
class BedrockEmbeddingOptions(EmbeddingGenerationOptions, total=False):
    """Embedding options accepted by the Bedrock client.

    Adds the Bedrock-specific ``normalize`` flag on top of the keys defined
    by ``EmbeddingGenerationOptions``.

    Examples:
        .. code-block:: python

            from agent_framework_bedrock import BedrockEmbeddingOptions

            options: BedrockEmbeddingOptions = {
                "model_id": "amazon.titan-embed-text-v2:0",
                "dimensions": 1024,
                "normalize": True,
            }
    """

    # Forwarded as the "normalize" field of the request body when provided.
    normalize: bool
|
||||
|
||||
|
||||
# Covariant options type variable for the Bedrock embedding clients; it is bound to
# TypedDict and defaults to BedrockEmbeddingOptions when left unparameterized.
BedrockEmbeddingOptionsT = TypeVar(
    "BedrockEmbeddingOptionsT",
    covariant=True,
    default="BedrockEmbeddingOptions",
    bound=TypedDict,  # type: ignore[valid-type]
)
|
||||
|
||||
|
||||
class RawBedrockEmbeddingClient(
    BaseEmbeddingClient[str, list[float], BedrockEmbeddingOptionsT],
    Generic[BedrockEmbeddingOptionsT],
):
    """Raw Bedrock embedding client without telemetry.

    Keyword Args:
        model_id: The Bedrock embedding model ID (e.g. "amazon.titan-embed-text-v2:0").
            Can also be set via environment variable BEDROCK_EMBEDDING_MODEL_ID.
        region: AWS region. Will try to load from BEDROCK_REGION env var,
            if not set, the regular Boto3 configuration/loading applies
            (which may include other env vars, config files, or instance metadata).
        access_key: AWS access key for manual credential injection.
        secret_key: AWS secret key paired with access_key.
        session_token: AWS session token for temporary credentials.
        client: Preconfigured Bedrock runtime client.
        boto3_session: Custom boto3 session used to build the runtime client.
        env_file_path: Path to .env file for settings.
        env_file_encoding: Encoding for .env file.
    """

    def __init__(
        self,
        *,
        region: str | None = None,
        model_id: str | None = None,
        access_key: str | None = None,
        secret_key: str | None = None,
        session_token: str | None = None,
        client: BaseClient | None = None,
        boto3_session: Boto3Session | None = None,
        env_file_path: str | None = None,
        env_file_encoding: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize a raw Bedrock embedding client.

        When ``client`` is not supplied, a ``bedrock-runtime`` client is built from
        ``boto3_session`` (or from a new session created from the resolved settings).
        ``self.region`` records the region that client was created with; when a
        preconfigured ``client`` is supplied it records the configured region,
        falling back to ``DEFAULT_REGION``.
        """
        settings = load_settings(
            BedrockEmbeddingSettings,
            env_prefix="BEDROCK_",
            required_fields=["embedding_model_id"],
            region=region,
            embedding_model_id=model_id,
            access_key=access_key,
            secret_key=secret_key,
            session_token=session_token,
            env_file_path=env_file_path,
            env_file_encoding=env_file_encoding,
        )
        resolved_region = settings.get("region") or DEFAULT_REGION

        if client is None:
            if not boto3_session:
                session_kwargs: dict[str, Any] = {}
                if region := settings.get("region"):
                    session_kwargs["region_name"] = region
                # Static credentials are only injected when both halves of the key pair are present.
                if (access_key := settings.get("access_key")) and (secret_key := settings.get("secret_key")):
                    session_kwargs["aws_access_key_id"] = access_key.get_secret_value()  # type: ignore[union-attr]
                    session_kwargs["aws_secret_access_key"] = secret_key.get_secret_value()  # type: ignore[union-attr]
                    if session_token := settings.get("session_token"):
                        session_kwargs["aws_session_token"] = session_token.get_secret_value()  # type: ignore[union-attr]
                boto3_session = Boto3Session(**session_kwargs)
            # The session's own region takes precedence; keep the value actually used so
            # that ``self.region`` cannot disagree with the client built below.
            resolved_region = boto3_session.region_name or resolved_region
            client = boto3_session.client(
                "bedrock-runtime",
                region_name=resolved_region,
                config=BotoConfig(user_agent_extra=AGENT_FRAMEWORK_USER_AGENT),
            )

        self._bedrock_client = client
        self.model_id = settings["embedding_model_id"]  # type: ignore[assignment]
        self.region = resolved_region
        super().__init__(**kwargs)

    def service_url(self) -> str:
        """Get the URL of the service."""
        return str(self._bedrock_client.meta.endpoint_url)

    async def get_embeddings(
        self,
        values: Sequence[str],
        *,
        options: BedrockEmbeddingOptionsT | None = None,
    ) -> GeneratedEmbeddings[list[float]]:
        """Call the Bedrock invoke_model API for embeddings.

        Uses the Amazon Titan Embeddings model format. Each value is embedded
        individually since Titan's invoke_model API accepts one input at a time;
        the per-value requests are awaited together with ``asyncio.gather``.

        Args:
            values: The text values to generate embeddings for. An empty
                sequence returns an empty result without calling the service.
            options: Optional embedding generation options.

        Returns:
            Generated embeddings, in the same order as ``values``, with the
            summed input token count as usage metadata when it is non-zero.

        Raises:
            ValueError: If no model_id is available from the options or the client.
        """
        if not values:
            return GeneratedEmbeddings([], options=options)

        opts: dict[str, Any] = dict(options) if options else {}
        model = opts.get("model_id") or self.model_id
        if not model:
            raise ValueError("model_id is required")

        embedding_results = await asyncio.gather(
            *(self._generate_embedding_for_text(opts, model, text) for text in values)
        )
        embeddings: list[Embedding[list[float]]] = [embedding for embedding, _ in embedding_results]
        total_input_tokens = sum(input_tokens for _, input_tokens in embedding_results)

        usage_dict: UsageDetails | None = None
        if total_input_tokens > 0:
            usage_dict = {"input_token_count": total_input_tokens}

        return GeneratedEmbeddings(embeddings, options=options, usage=usage_dict)

    async def _generate_embedding_for_text(
        self,
        opts: dict[str, Any],
        model: str,
        text: str,
    ) -> tuple[Embedding[list[float]], int]:
        """Embed a single text and return the embedding with its input token count.

        Args:
            opts: Options dictionary; ``dimensions`` and ``normalize`` are forwarded when set.
            model: Model ID passed to ``invoke_model``.
            text: The text to embed.

        Returns:
            A tuple of the embedding and the reported input token count
            (0 when the response does not report one).
        """
        body: dict[str, Any] = {"inputText": text}
        if dimensions := opts.get("dimensions"):
            body["dimensions"] = dimensions
        # ``normalize`` is compared against None so that an explicit False is still forwarded.
        if (normalize := opts.get("normalize")) is not None:
            body["normalize"] = normalize

        # The runtime client call is blocking, so it is run in a worker thread.
        response = await asyncio.to_thread(
            self._bedrock_client.invoke_model,
            modelId=model,
            contentType="application/json",
            accept="application/json",
            body=json.dumps(body),
        )

        response_body = json.loads(response["body"].read())
        vector = response_body["embedding"]
        embedding = Embedding(
            vector=vector,
            dimensions=len(vector),
            model_id=model,
        )
        # ``or 0`` also covers a token count that is present but null.
        input_tokens = int(response_body.get("inputTextTokenCount") or 0)
        return embedding, input_tokens
|
||||
|
||||
|
||||
class BedrockEmbeddingClient(
    EmbeddingTelemetryLayer[str, list[float], BedrockEmbeddingOptionsT],
    RawBedrockEmbeddingClient[BedrockEmbeddingOptionsT],
    Generic[BedrockEmbeddingOptionsT],
):
    """Bedrock embedding client layered with telemetry.

    Combines ``EmbeddingTelemetryLayer`` with ``RawBedrockEmbeddingClient``, which
    uses the Amazon Titan Embeddings model via Bedrock's invoke_model API.

    Keyword Args:
        model_id: The Bedrock embedding model ID (e.g. "amazon.titan-embed-text-v2:0").
            Can also be set via environment variable BEDROCK_EMBEDDING_MODEL_ID.
        region: AWS region. Defaults to "us-east-1".
            Can also be set via environment variable BEDROCK_REGION.
        access_key: AWS access key for manual credential injection.
        secret_key: AWS secret key paired with access_key.
        session_token: AWS session token for temporary credentials.
        client: Preconfigured Bedrock runtime client.
        boto3_session: Custom boto3 session used to build the runtime client.
        otel_provider_name: Optional provider name forwarded to the telemetry layer.
        env_file_path: Path to .env file for settings.
        env_file_encoding: Encoding for .env file.

    Examples:
        .. code-block:: python

            from agent_framework_bedrock import BedrockEmbeddingClient

            # Using default AWS credentials
            client = BedrockEmbeddingClient(
                model_id="amazon.titan-embed-text-v2:0",
            )

            # Generate embeddings
            result = await client.get_embeddings(["Hello, world!"])
            print(result[0].vector)
    """

    OTEL_PROVIDER_NAME: ClassVar[str] = "aws.bedrock"  # type: ignore[reportIncompatibleVariableOverride, misc]

    def __init__(
        self,
        *,
        region: str | None = None,
        model_id: str | None = None,
        access_key: str | None = None,
        secret_key: str | None = None,
        session_token: str | None = None,
        client: BaseClient | None = None,
        boto3_session: Boto3Session | None = None,
        otel_provider_name: str | None = None,
        env_file_path: str | None = None,
        env_file_encoding: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize a Bedrock embedding client by forwarding every argument up the MRO."""
        super().__init__(
            otel_provider_name=otel_provider_name,
            model_id=model_id,
            region=region,
            access_key=access_key,
            secret_key=secret_key,
            session_token=session_token,
            client=client,
            boto3_session=boto3_session,
            env_file_path=env_file_path,
            env_file_encoding=env_file_encoding,
            **kwargs,
        )
|
||||
@@ -28,7 +28,6 @@ dependencies = [
|
||||
"botocore>=1.35.0,<2.0.0",
|
||||
]
|
||||
|
||||
|
||||
[tool.uv]
|
||||
prerelease = "if-necessary-or-explicit"
|
||||
environments = [
|
||||
@@ -46,6 +45,9 @@ addopts = "-ra -q -r fEX"
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
filterwarnings = []
|
||||
markers = [
|
||||
"integration: marks tests as integration tests that require external services",
|
||||
]
|
||||
timeout = 120
|
||||
|
||||
[tool.ruff]
|
||||
|
||||
@@ -0,0 +1,168 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from agent_framework import Embedding, GeneratedEmbeddings
|
||||
|
||||
from agent_framework_bedrock import BedrockEmbeddingClient, BedrockEmbeddingOptions
|
||||
|
||||
|
||||
class _StubBedrockEmbeddingRuntime:
    """In-memory stand-in for the Bedrock runtime client's ``invoke_model`` call.

    Every call's keyword arguments are recorded in ``calls`` so tests can
    inspect what was sent.
    """

    def __init__(self) -> None:
        self.calls: list[dict[str, Any]] = []

    def invoke_model(self, **kwargs: Any) -> dict[str, Any]:
        """Record the call and return a Titan-style embedding response.

        The vector length follows the ``dimensions`` field of the JSON request
        body (3 when absent); the token count is always 5.
        """
        self.calls.append(kwargs)
        request = json.loads(kwargs.get("body", "{}"))
        size = request.get("dimensions", 3)
        payload = {
            "embedding": [0.1 * position for position in range(1, size + 1)],
            "inputTextTokenCount": 5,
        }
        # The real response exposes a stream-like "body"; only read() is emulated here.
        stream = MagicMock(read=lambda: json.dumps(payload).encode())
        return {"body": stream}
|
||||
|
||||
|
||||
async def test_bedrock_embedding_construction() -> None:
    """Explicit ``model_id`` and ``region`` are stored on the constructed client."""
    runtime_stub = _StubBedrockEmbeddingRuntime()
    embedding_client = BedrockEmbeddingClient(
        client=runtime_stub,
        region="us-west-2",
        model_id="amazon.titan-embed-text-v2:0",
    )

    assert embedding_client.model_id == "amazon.titan-embed-text-v2:0"
    assert embedding_client.region == "us-west-2"
|
||||
|
||||
|
||||
async def test_bedrock_embedding_construction_missing_model_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """Constructing without any model id raises ``SettingNotFoundError``."""
    from agent_framework.exceptions import SettingNotFoundError

    # Make sure the model id cannot be picked up from the surrounding environment.
    monkeypatch.delenv("BEDROCK_EMBEDDING_MODEL_ID", raising=False)

    with pytest.raises(SettingNotFoundError):
        BedrockEmbeddingClient(region="us-west-2")
|
||||
|
||||
|
||||
async def test_bedrock_embedding_get_embeddings() -> None:
    """Two inputs produce two embeddings, summed usage, and one ``invoke_model`` call each."""
    runtime_stub = _StubBedrockEmbeddingRuntime()
    embedding_client = BedrockEmbeddingClient(
        client=runtime_stub,
        region="us-west-2",
        model_id="amazon.titan-embed-text-v2:0",
    )

    generated = await embedding_client.get_embeddings(["hello", "world"])

    assert isinstance(generated, GeneratedEmbeddings)
    assert len(generated) == 2
    first, second = generated[0], generated[1]
    assert len(first.vector) == 3
    assert len(second.vector) == 3
    assert first.model_id == "amazon.titan-embed-text-v2:0"
    # The stub reports 5 tokens per call, so two calls add up to 10.
    assert generated.usage == {"input_token_count": 10}

    # Two calls since Titan processes one input at a time
    assert len(runtime_stub.calls) == 2
    sent_texts = set()
    for recorded in runtime_stub.calls:
        sent_texts.add(json.loads(recorded["body"])["inputText"])
    assert sent_texts == {"hello", "world"}
|
||||
|
||||
|
||||
async def test_bedrock_embedding_get_embeddings_empty_input() -> None:
    """An empty input list yields an empty result and no ``invoke_model`` calls."""
    runtime_stub = _StubBedrockEmbeddingRuntime()
    embedding_client = BedrockEmbeddingClient(
        client=runtime_stub,
        region="us-west-2",
        model_id="amazon.titan-embed-text-v2:0",
    )

    generated = await embedding_client.get_embeddings([])

    assert isinstance(generated, GeneratedEmbeddings)
    assert len(generated) == 0
    assert len(runtime_stub.calls) == 0
|
||||
|
||||
|
||||
async def test_bedrock_embedding_get_embeddings_with_options() -> None:
    """``dimensions`` and ``normalize`` options are written into the request body."""
    runtime_stub = _StubBedrockEmbeddingRuntime()
    embedding_client = BedrockEmbeddingClient(
        client=runtime_stub,
        region="us-west-2",
        model_id="amazon.titan-embed-text-v2:0",
    )
    options: BedrockEmbeddingOptions = {"dimensions": 5, "normalize": True}

    generated = await embedding_client.get_embeddings(["hello"], options=options)

    assert len(generated) == 1
    # The stub sizes its vector from the requested dimensions.
    assert len(generated[0].vector) == 5

    sent_body = json.loads(runtime_stub.calls[0]["body"])
    assert sent_body["dimensions"] == 5
    assert sent_body["normalize"] is True
|
||||
|
||||
|
||||
async def test_bedrock_embedding_get_embeddings_no_model_raises() -> None:
    """Clearing ``model_id`` after construction makes ``get_embeddings`` raise ``ValueError``."""
    runtime_stub = _StubBedrockEmbeddingRuntime()
    embedding_client = BedrockEmbeddingClient(
        client=runtime_stub,
        region="us-west-2",
        model_id="amazon.titan-embed-text-v2:0",
    )
    # Simulate a client whose model id is no longer available at call time.
    embedding_client.model_id = None  # type: ignore[assignment]

    with pytest.raises(ValueError, match="model_id is required"):
        await embedding_client.get_embeddings(["hello"])
|
||||
|
||||
|
||||
async def test_bedrock_embedding_default_region(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test that default region is us-east-1 when no region is configured."""
    # A BEDROCK_REGION set in the surrounding environment would otherwise be
    # loaded as the region and make this assertion fail.
    monkeypatch.delenv("BEDROCK_REGION", raising=False)

    stub = _StubBedrockEmbeddingRuntime()
    client = BedrockEmbeddingClient(
        model_id="amazon.titan-embed-text-v2:0",
        client=stub,
    )
    assert client.region == "us-east-1"
|
||||
|
||||
|
||||
# region: Integration Tests
|
||||
|
||||
# Skip when the model id is unset or the "test-model" placeholder, or when neither
# AWS_ACCESS_KEY_ID nor BEDROCK_ACCESS_KEY is present. Evaluated once at import time.
skip_if_bedrock_embedding_integration_tests_disabled = pytest.mark.skipif(
    os.getenv("BEDROCK_EMBEDDING_MODEL_ID", "") in ("", "test-model")
    or (not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("BEDROCK_ACCESS_KEY")),
    reason="No real Bedrock embedding model or AWS credentials provided; skipping integration tests.",
)
|
||||
|
||||
|
||||
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_bedrock_embedding_integration_tests_disabled
async def test_bedrock_embedding_integration() -> None:
    """Live check: two inputs return two embeddings whose vectors are non-empty lists of floats."""
    live_client = BedrockEmbeddingClient()

    generated = await live_client.get_embeddings(["Hello, world!", "How are you?"])

    assert isinstance(generated, GeneratedEmbeddings)
    assert len(generated) == 2
    for item in generated:
        assert isinstance(item, Embedding)
        vector = item.vector
        assert isinstance(vector, list)
        assert len(vector) > 0
        assert all(isinstance(component, float) for component in vector)
|
||||
@@ -36,6 +36,7 @@ environments = [
|
||||
|
||||
[tool.uv-dynamic-versioning]
|
||||
fallback-version = "0.0.0"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
addopts = "-ra -q -r fEX"
|
||||
@@ -43,6 +44,9 @@ asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
filterwarnings = []
|
||||
timeout = 120
|
||||
markers = [
|
||||
"integration: marks tests as integration tests that require external services",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
extend = "../../pyproject.toml"
|
||||
@@ -80,6 +84,7 @@ exclude_dirs = ["tests"]
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_chatkit"
|
||||
test = "pytest --cov=agent_framework_chatkit --cov-report=term-missing:skip-covered tests"
|
||||
|
||||
@@ -37,6 +37,7 @@ environments = [
|
||||
|
||||
[tool.uv-dynamic-versioning]
|
||||
fallback-version = "0.0.0"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
addopts = "-ra -q -r fEX"
|
||||
@@ -46,6 +47,9 @@ filterwarnings = [
|
||||
"ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*"
|
||||
]
|
||||
timeout = 120
|
||||
markers = [
|
||||
"integration: marks tests as integration tests that require external services",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
extend = "../../pyproject.toml"
|
||||
@@ -80,6 +84,7 @@ exclude_dirs = ["tests"]
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_claude"
|
||||
test = "pytest --cov=agent_framework_claude --cov-report=term-missing:skip-covered tests"
|
||||
|
||||
@@ -37,6 +37,7 @@ environments = [
|
||||
|
||||
[tool.uv-dynamic-versioning]
|
||||
fallback-version = "0.0.0"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
addopts = "-ra -q -r fEX"
|
||||
@@ -46,6 +47,9 @@ filterwarnings = [
|
||||
"ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*"
|
||||
]
|
||||
timeout = 120
|
||||
markers = [
|
||||
"integration: marks tests as integration tests that require external services",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
extend = "../../pyproject.toml"
|
||||
@@ -58,7 +62,6 @@ omit = [
|
||||
[tool.pyright]
|
||||
extends = "../../pyproject.toml"
|
||||
|
||||
|
||||
[tool.mypy]
|
||||
plugins = ['pydantic.mypy']
|
||||
strict = true
|
||||
@@ -80,6 +83,7 @@ exclude_dirs = ["tests"]
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_copilotstudio"
|
||||
test = "pytest --cov=agent_framework_copilotstudio --cov-report=term-missing:skip-covered tests"
|
||||
|
||||
@@ -667,16 +667,12 @@ class SupportsFileSearchTool(Protocol):
|
||||
|
||||
# region SupportsGetEmbeddings Protocol
|
||||
|
||||
# Contravariant/covariant TypeVars for the Protocol
|
||||
# Contravariant TypeVars for the Protocol
|
||||
EmbeddingInputContraT = TypeVar(
|
||||
"EmbeddingInputContraT",
|
||||
default="str",
|
||||
contravariant=True,
|
||||
)
|
||||
EmbeddingCoT = TypeVar(
|
||||
"EmbeddingCoT",
|
||||
default="list[float]",
|
||||
)
|
||||
EmbeddingOptionsContraT = TypeVar(
|
||||
"EmbeddingOptionsContraT",
|
||||
bound=TypedDict, # type: ignore[valid-type]
|
||||
@@ -686,7 +682,7 @@ EmbeddingOptionsContraT = TypeVar(
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsGetEmbeddings(Protocol[EmbeddingInputContraT, EmbeddingCoT, EmbeddingOptionsContraT]):
|
||||
class SupportsGetEmbeddings(Protocol[EmbeddingInputContraT, EmbeddingT, EmbeddingOptionsContraT]):
|
||||
"""Protocol for an embedding client that can generate embeddings.
|
||||
|
||||
This protocol enables duck-typing for embedding generation. Any class that
|
||||
@@ -714,7 +710,7 @@ class SupportsGetEmbeddings(Protocol[EmbeddingInputContraT, EmbeddingCoT, Embedd
|
||||
values: Sequence[EmbeddingInputContraT],
|
||||
*,
|
||||
options: EmbeddingOptionsContraT | None = None,
|
||||
) -> Awaitable[GeneratedEmbeddings[EmbeddingCoT]]:
|
||||
) -> Awaitable[GeneratedEmbeddings[EmbeddingT]]:
|
||||
"""Generate embeddings for the given values.
|
||||
|
||||
Args:
|
||||
@@ -733,15 +729,15 @@ class SupportsGetEmbeddings(Protocol[EmbeddingInputContraT, EmbeddingCoT, Embedd
|
||||
# region BaseEmbeddingClient
|
||||
|
||||
# Covariant for the BaseEmbeddingClient
|
||||
EmbeddingOptionsCoT = TypeVar(
|
||||
"EmbeddingOptionsCoT",
|
||||
EmbeddingOptionsT = TypeVar(
|
||||
"EmbeddingOptionsT",
|
||||
bound=TypedDict, # type: ignore[valid-type]
|
||||
default="EmbeddingGenerationOptions",
|
||||
covariant=True,
|
||||
)
|
||||
|
||||
|
||||
class BaseEmbeddingClient(SerializationMixin, ABC, Generic[EmbeddingInputT, EmbeddingT, EmbeddingOptionsCoT]):
|
||||
class BaseEmbeddingClient(SerializationMixin, ABC, Generic[EmbeddingInputT, EmbeddingT, EmbeddingOptionsT]):
|
||||
"""Abstract base class for embedding clients.
|
||||
|
||||
Subclasses implement ``get_embeddings`` to provide the actual
|
||||
@@ -785,7 +781,7 @@ class BaseEmbeddingClient(SerializationMixin, ABC, Generic[EmbeddingInputT, Embe
|
||||
self,
|
||||
values: Sequence[EmbeddingInputT],
|
||||
*,
|
||||
options: EmbeddingOptionsCoT | None = None,
|
||||
options: EmbeddingOptionsT | None = None,
|
||||
) -> GeneratedEmbeddings[EmbeddingT]:
|
||||
"""Generate embeddings for the given values.
|
||||
|
||||
|
||||
@@ -377,6 +377,12 @@ class UsageDetails(TypedDict, total=False):
|
||||
|
||||
This is a non-closed dictionary, so any specific provider fields can be added as needed.
|
||||
Whenever they can be mapped to standard fields, they will be.
|
||||
|
||||
Keys:
|
||||
input_token_count: The number of input tokens used.
|
||||
output_token_count: The number of output tokens generated.
|
||||
total_token_count: The total number of tokens (input + output).
|
||||
|
||||
"""
|
||||
|
||||
input_token_count: int | None
|
||||
@@ -3289,7 +3295,7 @@ class GeneratedEmbeddings(list[Embedding[EmbeddingT]], Generic[EmbeddingT, Embed
|
||||
embeddings: Iterable[Embedding[EmbeddingT]] | None = None,
|
||||
*,
|
||||
options: EmbeddingOptionsT | None = None,
|
||||
usage: dict[str, Any] | None = None,
|
||||
usage: UsageDetails | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(embeddings or [])
|
||||
|
||||
@@ -8,6 +8,9 @@ This module lazily re-exports objects from:
|
||||
Supported classes:
|
||||
- BedrockChatClient
|
||||
- BedrockChatOptions
|
||||
- BedrockEmbeddingClient
|
||||
- BedrockEmbeddingOptions
|
||||
- BedrockEmbeddingSettings
|
||||
- BedrockGuardrailConfig
|
||||
- BedrockSettings
|
||||
"""
|
||||
@@ -17,7 +20,15 @@ from typing import Any
|
||||
|
||||
IMPORT_PATH = "agent_framework_bedrock"
|
||||
PACKAGE_NAME = "agent-framework-bedrock"
|
||||
_IMPORTS = ["BedrockChatClient", "BedrockChatOptions", "BedrockGuardrailConfig", "BedrockSettings"]
|
||||
_IMPORTS = [
|
||||
"BedrockChatClient",
|
||||
"BedrockChatOptions",
|
||||
"BedrockEmbeddingClient",
|
||||
"BedrockEmbeddingOptions",
|
||||
"BedrockEmbeddingSettings",
|
||||
"BedrockGuardrailConfig",
|
||||
"BedrockSettings",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
from agent_framework_bedrock import (
|
||||
BedrockChatClient,
|
||||
BedrockChatOptions,
|
||||
BedrockEmbeddingClient,
|
||||
BedrockEmbeddingOptions,
|
||||
BedrockEmbeddingSettings,
|
||||
BedrockGuardrailConfig,
|
||||
BedrockSettings,
|
||||
)
|
||||
@@ -10,6 +13,9 @@ from agent_framework_bedrock import (
|
||||
__all__ = [
|
||||
"BedrockChatClient",
|
||||
"BedrockChatOptions",
|
||||
"BedrockEmbeddingClient",
|
||||
"BedrockEmbeddingOptions",
|
||||
"BedrockEmbeddingSettings",
|
||||
"BedrockGuardrailConfig",
|
||||
"BedrockSettings",
|
||||
]
|
||||
|
||||
@@ -99,6 +99,7 @@ class AzureOpenAIEmbeddingClient(
|
||||
credential: AzureCredentialTypes | AzureTokenProvider | None = None,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
async_client: AsyncAzureOpenAI | None = None,
|
||||
otel_provider_name: str | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
) -> None:
|
||||
@@ -133,4 +134,5 @@ class AzureOpenAIEmbeddingClient(
|
||||
credential=credential,
|
||||
default_headers=default_headers,
|
||||
client=async_client,
|
||||
otel_provider_name=otel_provider_name,
|
||||
)
|
||||
|
||||
@@ -1279,15 +1279,15 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
return _get_response()
|
||||
|
||||
|
||||
EmbeddingOptionsCoT = TypeVar(
|
||||
"EmbeddingOptionsCoT",
|
||||
EmbeddingOptionsT = TypeVar(
|
||||
"EmbeddingOptionsT",
|
||||
bound=TypedDict, # type: ignore[valid-type]
|
||||
default="EmbeddingGenerationOptions",
|
||||
covariant=True,
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingTelemetryLayer(Generic[EmbeddingInputT, EmbeddingT, EmbeddingOptionsCoT]):
|
||||
class EmbeddingTelemetryLayer(Generic[EmbeddingInputT, EmbeddingT, EmbeddingOptionsT]):
|
||||
"""Layer that wraps embedding client get_embeddings with OpenTelemetry tracing."""
|
||||
|
||||
def __init__(self, *args: Any, otel_provider_name: str | None = None, **kwargs: Any) -> None:
|
||||
@@ -1301,7 +1301,7 @@ class EmbeddingTelemetryLayer(Generic[EmbeddingInputT, EmbeddingT, EmbeddingOpti
|
||||
self,
|
||||
values: Sequence[EmbeddingInputT],
|
||||
*,
|
||||
options: EmbeddingOptionsCoT | None = None,
|
||||
options: EmbeddingOptionsT | None = None,
|
||||
) -> GeneratedEmbeddings[EmbeddingT]:
|
||||
"""Trace embedding generation with OpenTelemetry spans and metrics."""
|
||||
global OBSERVABILITY_SETTINGS
|
||||
|
||||
@@ -7,6 +7,10 @@ This module lazily re-exports objects from:
|
||||
|
||||
Supported classes:
|
||||
- OllamaChatClient
|
||||
- OllamaChatOptions
|
||||
- OllamaEmbeddingClient
|
||||
- OllamaEmbeddingOptions
|
||||
- OllamaEmbeddingSettings
|
||||
- OllamaSettings
|
||||
"""
|
||||
|
||||
@@ -15,7 +19,14 @@ from typing import Any
|
||||
|
||||
IMPORT_PATH = "agent_framework_ollama"
|
||||
PACKAGE_NAME = "agent-framework-ollama"
|
||||
_IMPORTS = ["OllamaChatClient", "OllamaSettings"]
|
||||
_IMPORTS = [
|
||||
"OllamaChatClient",
|
||||
"OllamaChatOptions",
|
||||
"OllamaEmbeddingClient",
|
||||
"OllamaEmbeddingOptions",
|
||||
"OllamaEmbeddingSettings",
|
||||
"OllamaSettings",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
|
||||
@@ -2,10 +2,18 @@
|
||||
|
||||
from agent_framework_ollama import (
|
||||
OllamaChatClient,
|
||||
OllamaChatOptions,
|
||||
OllamaEmbeddingClient,
|
||||
OllamaEmbeddingOptions,
|
||||
OllamaEmbeddingSettings,
|
||||
OllamaSettings,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OllamaChatClient",
|
||||
"OllamaChatOptions",
|
||||
"OllamaEmbeddingClient",
|
||||
"OllamaEmbeddingOptions",
|
||||
"OllamaEmbeddingSettings",
|
||||
"OllamaSettings",
|
||||
]
|
||||
|
||||
@@ -12,7 +12,7 @@ from openai import AsyncOpenAI
|
||||
|
||||
from .._clients import BaseEmbeddingClient
|
||||
from .._settings import load_settings
|
||||
from .._types import Embedding, EmbeddingGenerationOptions, GeneratedEmbeddings
|
||||
from .._types import Embedding, EmbeddingGenerationOptions, GeneratedEmbeddings, UsageDetails
|
||||
from ..observability import EmbeddingTelemetryLayer
|
||||
from ._shared import OpenAIBase, OpenAIConfigMixin, OpenAISettings
|
||||
|
||||
@@ -116,11 +116,11 @@ class RawOpenAIEmbeddingClient(
|
||||
)
|
||||
)
|
||||
|
||||
usage_dict: dict[str, Any] | None = None
|
||||
usage_dict: UsageDetails | None = None
|
||||
if response.usage:
|
||||
usage_dict = {
|
||||
"prompt_tokens": response.usage.prompt_tokens,
|
||||
"total_tokens": response.usage.total_tokens,
|
||||
"input_token_count": response.usage.prompt_tokens,
|
||||
"total_token_count": response.usage.total_tokens,
|
||||
}
|
||||
|
||||
return GeneratedEmbeddings(embeddings, options=options, usage=usage_dict)
|
||||
@@ -143,6 +143,7 @@ class OpenAIEmbeddingClient(
|
||||
default_headers: Additional HTTP headers.
|
||||
async_client: Pre-configured AsyncOpenAI client.
|
||||
base_url: Custom API base URL.
|
||||
otel_provider_name: Override the OpenTelemetry provider name for telemetry.
|
||||
env_file_path: Path to .env file for settings.
|
||||
env_file_encoding: Encoding for .env file.
|
||||
|
||||
@@ -176,6 +177,7 @@ class OpenAIEmbeddingClient(
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
async_client: AsyncOpenAI | None = None,
|
||||
base_url: str | None = None,
|
||||
otel_provider_name: str | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
) -> None:
|
||||
@@ -208,4 +210,5 @@ class OpenAIEmbeddingClient(
|
||||
org_id=openai_settings["org_id"],
|
||||
default_headers=default_headers,
|
||||
client=async_client,
|
||||
otel_provider_name=otel_provider_name,
|
||||
)
|
||||
|
||||
@@ -91,6 +91,9 @@ asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
filterwarnings = []
|
||||
timeout = 120
|
||||
markers = [
|
||||
"integration: marks tests as integration tests that require external services",
|
||||
]
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = [
|
||||
@@ -124,6 +127,7 @@ exclude_dirs = ["tests"]
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework"
|
||||
test = "pytest --cov=agent_framework --cov-report=term-missing:skip-covered -n auto --dist worksteal tests"
|
||||
|
||||
@@ -100,8 +100,8 @@ async def test_openai_get_embeddings_usage(openai_unit_test_env: None) -> None:
|
||||
result = await client.get_embeddings(["test"])
|
||||
|
||||
assert result.usage is not None
|
||||
assert result.usage["prompt_tokens"] == 10
|
||||
assert result.usage["total_tokens"] == 10
|
||||
assert result.usage["input_token_count"] == 10
|
||||
assert result.usage["total_token_count"] == 10
|
||||
|
||||
|
||||
async def test_openai_options_passthrough_dimensions(openai_unit_test_env: None) -> None:
|
||||
@@ -284,7 +284,7 @@ async def test_integration_openai_get_embeddings() -> None:
|
||||
assert all(isinstance(v, float) for v in result[0].vector)
|
||||
assert result[0].model_id is not None
|
||||
assert result.usage is not None
|
||||
assert result.usage["prompt_tokens"] > 0
|
||||
assert result.usage["input_token_count"] > 0
|
||||
|
||||
|
||||
@skip_if_openai_integration_tests_disabled
|
||||
@@ -327,7 +327,7 @@ async def test_integration_azure_openai_get_embeddings() -> None:
|
||||
assert all(isinstance(v, float) for v in result[0].vector)
|
||||
assert result[0].model_id is not None
|
||||
assert result.usage is not None
|
||||
assert result.usage["prompt_tokens"] > 0
|
||||
assert result.usage["input_token_count"] > 0
|
||||
|
||||
|
||||
@skip_if_azure_openai_integration_tests_disabled
|
||||
|
||||
@@ -39,9 +39,9 @@ environments = [
|
||||
"sys_platform == 'win32'"
|
||||
]
|
||||
|
||||
|
||||
[tool.uv-dynamic-versioning]
|
||||
fallback-version = "0.0.0"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
addopts = "-ra -q -r fEX"
|
||||
@@ -51,6 +51,9 @@ filterwarnings = [
|
||||
"ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*"
|
||||
]
|
||||
timeout = 120
|
||||
markers = [
|
||||
"integration: marks tests as integration tests that require external services",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
extend = "../../pyproject.toml"
|
||||
@@ -88,6 +91,7 @@ exclude_dirs = ["tests"]
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_declarative"
|
||||
test = "pytest --cov=agent_framework_declarative --cov-report=term-missing:skip-covered tests"
|
||||
|
||||
@@ -54,6 +54,9 @@ addopts = "-ra -q -r fEX"
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
filterwarnings = []
|
||||
markers = [
|
||||
"integration: marks tests as integration tests that require external services",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
extend = "../../pyproject.toml"
|
||||
@@ -66,7 +69,6 @@ omit = [
|
||||
[tool.pyright]
|
||||
extends = "../../pyproject.toml"
|
||||
|
||||
|
||||
[tool.mypy]
|
||||
plugins = ['pydantic.mypy']
|
||||
strict = true
|
||||
@@ -89,6 +91,7 @@ exclude_dirs = ["tests"]
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_devui"
|
||||
test = "pytest --cov=agent_framework_devui --cov-report=term-missing:skip-covered tests"
|
||||
|
||||
@@ -43,6 +43,7 @@ environments = [
|
||||
|
||||
[tool.uv-dynamic-versioning]
|
||||
fallback-version = "0.0.0"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
pythonpath = ["tests/integration_tests"]
|
||||
@@ -94,6 +95,7 @@ exclude_dirs = ["tests"]
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_durabletask"
|
||||
test = "pytest --cov=agent_framework_durabletask --cov-report=term-missing:skip-covered tests"
|
||||
|
||||
@@ -37,6 +37,7 @@ environments = [
|
||||
|
||||
[tool.uv-dynamic-versioning]
|
||||
fallback-version = "0.0.0"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
addopts = "-ra -q -r fEX"
|
||||
@@ -44,6 +45,9 @@ asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
filterwarnings = []
|
||||
timeout = 120
|
||||
markers = [
|
||||
"integration: marks tests as integration tests that require external services",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
extend = "../../pyproject.toml"
|
||||
@@ -78,6 +82,7 @@ exclude_dirs = ["tests"]
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_foundry_local"
|
||||
test = "pytest --cov=agent_framework_foundry_local --cov-report=term-missing:skip-covered tests"
|
||||
|
||||
@@ -37,6 +37,7 @@ environments = [
|
||||
|
||||
[tool.uv-dynamic-versioning]
|
||||
fallback-version = "0.0.0"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
addopts = "-ra -q -r fEX"
|
||||
@@ -46,6 +47,9 @@ filterwarnings = [
|
||||
"ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*"
|
||||
]
|
||||
timeout = 120
|
||||
markers = [
|
||||
"integration: marks tests as integration tests that require external services",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
extend = "../../pyproject.toml"
|
||||
@@ -58,7 +62,6 @@ omit = [
|
||||
[tool.pyright]
|
||||
extends = "../../pyproject.toml"
|
||||
|
||||
|
||||
[tool.mypy]
|
||||
plugins = ['pydantic.mypy']
|
||||
strict = true
|
||||
@@ -80,6 +83,7 @@ exclude_dirs = ["tests"]
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_github_copilot"
|
||||
test = "pytest --cov=agent_framework_github_copilot --cov-report=term-missing:skip-covered tests"
|
||||
|
||||
@@ -37,6 +37,7 @@ environments = [
|
||||
|
||||
[tool.uv-dynamic-versioning]
|
||||
fallback-version = "0.0.0"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
addopts = "-ra -q -r fEX"
|
||||
@@ -46,6 +47,9 @@ filterwarnings = [
|
||||
"ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*"
|
||||
]
|
||||
timeout = 120
|
||||
markers = [
|
||||
"integration: marks tests as integration tests that require external services",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
extend = "../../pyproject.toml"
|
||||
@@ -58,7 +62,6 @@ omit = [
|
||||
[tool.pyright]
|
||||
extends = "../../pyproject.toml"
|
||||
|
||||
|
||||
[tool.mypy]
|
||||
plugins = ['pydantic.mypy']
|
||||
strict = true
|
||||
@@ -80,6 +83,7 @@ exclude_dirs = ["tests"]
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_mem0"
|
||||
test = "pytest --cov=agent_framework_mem0 --cov-report=term-missing:skip-covered tests"
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import importlib.metadata
|
||||
|
||||
from ._chat_client import OllamaChatClient, OllamaChatOptions, OllamaSettings
|
||||
from ._embedding_client import OllamaEmbeddingClient, OllamaEmbeddingOptions, OllamaEmbeddingSettings
|
||||
|
||||
try:
|
||||
__version__ = importlib.metadata.version(__name__)
|
||||
@@ -12,6 +13,9 @@ except importlib.metadata.PackageNotFoundError:
|
||||
__all__ = [
|
||||
"OllamaChatClient",
|
||||
"OllamaChatOptions",
|
||||
"OllamaEmbeddingClient",
|
||||
"OllamaEmbeddingOptions",
|
||||
"OllamaEmbeddingSettings",
|
||||
"OllamaSettings",
|
||||
"__version__",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,230 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, ClassVar, Generic, TypedDict
|
||||
|
||||
from agent_framework import (
|
||||
BaseEmbeddingClient,
|
||||
Embedding,
|
||||
EmbeddingGenerationOptions,
|
||||
GeneratedEmbeddings,
|
||||
UsageDetails,
|
||||
load_settings,
|
||||
)
|
||||
from agent_framework.observability import EmbeddingTelemetryLayer
|
||||
from ollama import AsyncClient
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypeVar # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
logger = logging.getLogger("agent_framework.ollama")
|
||||
|
||||
|
||||
class OllamaEmbeddingOptions(EmbeddingGenerationOptions, total=False):
    """Ollama-specific embedding options.

    Extends EmbeddingGenerationOptions with the extra request fields accepted by
    Ollama's embed API. Every key is optional (``total=False``).

    Examples:
        .. code-block:: python

            from agent_framework_ollama import OllamaEmbeddingOptions

            options: OllamaEmbeddingOptions = {
                "model_id": "nomic-embed-text",
                "dimensions": 768,
                "truncate": True,
            }
    """

    truncate: bool
    """Whether to truncate input text that exceeds the model's context length.

    When True, input that is too long is truncated to fit.
    When False, the request fails if input exceeds the context length.
    When omitted, the key is not sent and the Ollama server applies its own
    default (truncation enabled, per the Ollama API documentation).
    """

    keep_alive: float | str
    """How long to keep the model loaded in memory (e.g. ``"5m"``, ``300``)."""
|
||||
|
||||
|
||||
# Type variable for the options TypedDict accepted by the Ollama embedding clients.
# It is covariant and defaults to OllamaEmbeddingOptions when left unparameterized.
OllamaEmbeddingOptionsT = TypeVar(
    "OllamaEmbeddingOptionsT",
    bound=TypedDict,  # type: ignore[valid-type]
    default="OllamaEmbeddingOptions",
    covariant=True,
)
|
||||
|
||||
|
||||
class OllamaEmbeddingSettings(TypedDict, total=False):
    """Ollama embedding settings.

    Populated by ``load_settings`` with the ``OLLAMA_`` environment prefix when an
    embedding client is constructed.
    """

    # Ollama server URL; passed to ``AsyncClient(host=...)`` when no client is supplied.
    host: str | None
    # Embedding model ID; the raw client declares this as a required field.
    embedding_model_id: str | None
|
||||
|
||||
|
||||
class RawOllamaEmbeddingClient(
    BaseEmbeddingClient[str, list[float], OllamaEmbeddingOptionsT],
    Generic[OllamaEmbeddingOptionsT],
):
    """Raw Ollama embedding client without telemetry.

    Keyword Args:
        model_id: The Ollama embedding model ID (e.g. "nomic-embed-text").
            Can also be set via environment variable OLLAMA_EMBEDDING_MODEL_ID.
        host: Ollama server URL. Defaults to http://localhost:11434.
            Can also be set via environment variable OLLAMA_HOST.
        client: Optional pre-configured Ollama AsyncClient.
        env_file_path: Path to .env file for settings.
        env_file_encoding: Encoding for .env file.
        kwargs: Additional keyword arguments forwarded to the base class initializer.
    """

    def __init__(
        self,
        *,
        model_id: str | None = None,
        host: str | None = None,
        client: AsyncClient | None = None,
        env_file_path: str | None = None,
        env_file_encoding: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize a raw Ollama embedding client."""
        ollama_settings = load_settings(
            OllamaEmbeddingSettings,
            env_prefix="OLLAMA_",
            required_fields=["embedding_model_id"],
            host=host,
            embedding_model_id=model_id,
            env_file_path=env_file_path,
            env_file_encoding=env_file_encoding,
        )

        self.model_id = ollama_settings["embedding_model_id"]
        # A caller-supplied client wins; otherwise build one from the resolved host.
        self.client = client or AsyncClient(host=ollama_settings.get("host"))
        # Record the URL the client actually targets, whichever way it was created.
        self.host = str(self.client._client.base_url)  # pyright: ignore[reportUnknownMemberType,reportPrivateUsage,reportUnknownArgumentType]
        super().__init__(**kwargs)

    def service_url(self) -> str:
        """Get the URL of the service."""
        return self.host

    async def get_embeddings(
        self,
        values: Sequence[str],
        *,
        options: OllamaEmbeddingOptionsT | None = None,
    ) -> GeneratedEmbeddings[list[float]]:
        """Call the Ollama embed API.

        Args:
            values: The text values to generate embeddings for.
            options: Optional embedding generation options.

        Returns:
            Generated embeddings with usage metadata. An empty ``values`` sequence
            yields an empty result without calling the API.

        Raises:
            ValueError: If no model ID is available from the options or the client.
        """
        if not values:
            return GeneratedEmbeddings([], options=options)

        opts: dict[str, Any] = dict(options) if options else {}
        # A per-call model_id overrides the one configured on the client.
        model = opts.get("model_id") or self.model_id
        if not model:
            raise ValueError("model_id is required")

        kwargs: dict[str, Any] = {"model": model, "input": list(values)}
        if (truncate := opts.get("truncate")) is not None:
            kwargs["truncate"] = truncate
        # Compare against None rather than truthiness: ``keep_alive=0`` is a meaningful
        # value (unload the model right after the request) and must reach the API.
        if (keep_alive := opts.get("keep_alive")) is not None:
            kwargs["keep_alive"] = keep_alive
        if dimensions := opts.get("dimensions"):
            kwargs["dimensions"] = dimensions

        response = await self.client.embed(**kwargs)

        embeddings = [
            Embedding(
                vector=list(emb),
                dimensions=len(emb),
                # Prefer the model name echoed by the server, else the requested one.
                model_id=response.get("model") or model,
            )
            for emb in response.get("embeddings", [])
        ]

        # Usage is reported only when the response carries a prompt token count.
        usage_dict: UsageDetails | None = None
        prompt_eval_count = response.get("prompt_eval_count")
        if prompt_eval_count is not None:
            usage_dict = {"input_token_count": prompt_eval_count}

        return GeneratedEmbeddings(embeddings, options=options, usage=usage_dict)
|
||||
|
||||
|
||||
class OllamaEmbeddingClient(
    EmbeddingTelemetryLayer[str, list[float], OllamaEmbeddingOptionsT],
    RawOllamaEmbeddingClient[OllamaEmbeddingOptionsT],
    Generic[OllamaEmbeddingOptionsT],
):
    """Ollama embedding client with telemetry support.

    Combines the raw Ollama embedding client with the embedding telemetry layer.

    Keyword Args:
        model_id: The Ollama embedding model ID (e.g. "nomic-embed-text").
            Can also be set via environment variable OLLAMA_EMBEDDING_MODEL_ID.
        host: Ollama server URL. Defaults to http://localhost:11434.
            Can also be set via environment variable OLLAMA_HOST.
        client: Optional pre-configured Ollama AsyncClient.
        otel_provider_name: Override the OpenTelemetry provider name for telemetry.
        env_file_path: Path to .env file for settings.
        env_file_encoding: Encoding for .env file.
        kwargs: Additional keyword arguments forwarded to the parent initializers.

    Examples:
        .. code-block:: python

            from agent_framework_ollama import OllamaEmbeddingClient

            # Using environment variables
            # Set OLLAMA_EMBEDDING_MODEL_ID=nomic-embed-text
            client = OllamaEmbeddingClient()

            # Or passing parameters directly
            client = OllamaEmbeddingClient(
                model_id="nomic-embed-text",
                host="http://localhost:11434",
            )

            # Generate embeddings
            result = await client.get_embeddings(["Hello, world!"])
            print(result[0].vector)
    """

    # Provider name declared for this client's telemetry.
    OTEL_PROVIDER_NAME: ClassVar[str] = "ollama"

    def __init__(
        self,
        *,
        model_id: str | None = None,
        host: str | None = None,
        client: AsyncClient | None = None,
        otel_provider_name: str | None = None,
        env_file_path: str | None = None,
        env_file_encoding: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize an Ollama embedding client."""
        # Every argument is passed up the MRO unchanged.
        super().__init__(
            model_id=model_id,
            host=host,
            client=client,
            otel_provider_name=otel_provider_name,
            env_file_path=env_file_path,
            env_file_encoding=env_file_encoding,
            **kwargs,
        )
|
||||
@@ -37,12 +37,16 @@ environments = [
|
||||
|
||||
[tool.uv-dynamic-versioning]
|
||||
fallback-version = "0.0.0"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
addopts = "-ra -q -r fEX"
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
filterwarnings = []
|
||||
markers = [
|
||||
"integration: marks tests as integration tests that require external services",
|
||||
]
|
||||
timeout = 120
|
||||
|
||||
[tool.ruff]
|
||||
@@ -82,6 +86,7 @@ exclude_dirs = ["tests"]
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_ollama"
|
||||
test = "pytest --cov=agent_framework_ollama --cov-report=term-missing:skip-covered tests"
|
||||
|
||||
@@ -0,0 +1,150 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from agent_framework import Embedding, GeneratedEmbeddings
|
||||
|
||||
from agent_framework_ollama import OllamaEmbeddingClient, OllamaEmbeddingOptions
|
||||
|
||||
# region: Unit Tests
|
||||
|
||||
|
||||
def test_ollama_embedding_construction(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test construction from the OLLAMA_EMBEDDING_MODEL_ID environment variable."""
    monkeypatch.setenv("OLLAMA_EMBEDDING_MODEL_ID", "nomic-embed-text")
    # Patch AsyncClient so construction uses a mock instead of a real Ollama client.
    with patch("agent_framework_ollama._embedding_client.AsyncClient") as mock_client_cls:
        mock_client_cls.return_value = MagicMock()
        client = OllamaEmbeddingClient()
        assert client.model_id == "nomic-embed-text"
|
||||
|
||||
|
||||
def test_ollama_embedding_construction_with_params() -> None:
    """Explicit ``model_id`` and ``host`` arguments configure the client."""
    target = "agent_framework_ollama._embedding_client.AsyncClient"
    with patch(target, return_value=MagicMock()):
        embedding_client = OllamaEmbeddingClient(
            model_id="nomic-embed-text",
            host="http://localhost:11434",
        )
        assert embedding_client.model_id == "nomic-embed-text"
|
||||
|
||||
|
||||
def test_ollama_embedding_construction_missing_model_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test that missing model_id raises an error."""
    # Clear the model variables so no model ID can come from the environment.
    monkeypatch.delenv("OLLAMA_EMBEDDING_MODEL_ID", raising=False)
    monkeypatch.delenv("OLLAMA_MODEL_ID", raising=False)
    from agent_framework.exceptions import SettingNotFoundError

    # No model_id argument and no environment value: construction must fail.
    with pytest.raises(SettingNotFoundError):
        OllamaEmbeddingClient()
|
||||
|
||||
|
||||
async def test_ollama_embedding_get_embeddings() -> None:
    """Check vectors, model ID, usage, and the request sent to ``embed``."""
    fake_ollama = MagicMock()
    fake_ollama.embed = AsyncMock(
        return_value={
            "model": "nomic-embed-text",
            "embeddings": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
            "prompt_eval_count": 10,
        }
    )

    with patch("agent_framework_ollama._embedding_client.AsyncClient", return_value=fake_ollama):
        embedding_client = OllamaEmbeddingClient(model_id="nomic-embed-text")
        generated = await embedding_client.get_embeddings(["hello", "world"])

        assert isinstance(generated, GeneratedEmbeddings)
        assert len(generated) == 2
        first, second = generated[0], generated[1]
        assert first.vector == [0.1, 0.2, 0.3]
        assert second.vector == [0.4, 0.5, 0.6]
        assert first.model_id == "nomic-embed-text"
        # prompt_eval_count is surfaced under the canonical usage key.
        assert generated.usage == {"input_token_count": 10}

        fake_ollama.embed.assert_called_once_with(
            model="nomic-embed-text",
            input=["hello", "world"],
        )
|
||||
|
||||
|
||||
async def test_ollama_embedding_get_embeddings_empty_input() -> None:
    """An empty input list yields an empty result and no ``embed`` call."""
    fake_ollama = MagicMock()
    with patch("agent_framework_ollama._embedding_client.AsyncClient", return_value=fake_ollama):
        embedding_client = OllamaEmbeddingClient(model_id="nomic-embed-text")
        generated = await embedding_client.get_embeddings([])

        assert isinstance(generated, GeneratedEmbeddings)
        assert len(generated) == 0
        fake_ollama.embed.assert_not_called()
|
||||
|
||||
|
||||
async def test_ollama_embedding_get_embeddings_with_options() -> None:
    """Per-call options are forwarded to ``embed`` as keyword arguments."""
    fake_ollama = MagicMock()
    fake_ollama.embed = AsyncMock(
        return_value={
            "model": "nomic-embed-text",
            "embeddings": [[0.1, 0.2, 0.3]],
        }
    )
    request_options: OllamaEmbeddingOptions = {
        "truncate": True,
        "dimensions": 512,
    }

    with patch("agent_framework_ollama._embedding_client.AsyncClient", return_value=fake_ollama):
        embedding_client = OllamaEmbeddingClient(model_id="nomic-embed-text")
        generated = await embedding_client.get_embeddings(["hello"], options=request_options)

        assert len(generated) == 1
        fake_ollama.embed.assert_called_once_with(
            model="nomic-embed-text",
            input=["hello"],
            truncate=True,
            dimensions=512,
        )
|
||||
|
||||
|
||||
async def test_ollama_embedding_get_embeddings_no_model_raises() -> None:
    """A client left without a model ID raises ValueError at call time."""
    with patch("agent_framework_ollama._embedding_client.AsyncClient", return_value=MagicMock()):
        embedding_client = OllamaEmbeddingClient(model_id="nomic-embed-text")
        # Clear the model after construction to reach the call-time check.
        embedding_client.model_id = None  # type: ignore[assignment]

        with pytest.raises(ValueError, match="model_id is required"):
            await embedding_client.get_embeddings(["hello"])
|
||||
|
||||
|
||||
# region: Integration Tests
|
||||
|
||||
# Skip marker for integration tests: applies when OLLAMA_EMBEDDING_MODEL_ID is unset,
# empty, or set to the placeholder value "test-model".
skip_if_ollama_embedding_integration_tests_disabled = pytest.mark.skipif(
    os.getenv("OLLAMA_EMBEDDING_MODEL_ID", "") in ("", "test-model"),
    reason="No real Ollama embedding model provided; skipping integration tests.",
)
|
||||
|
||||
|
||||
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_ollama_embedding_integration_tests_disabled
async def test_ollama_embedding_integration() -> None:
    """Run the embedding client end to end, configured from the environment."""
    texts = ["Hello, world!", "How are you?"]
    embedding_client = OllamaEmbeddingClient()
    generated = await embedding_client.get_embeddings(texts)

    assert isinstance(generated, GeneratedEmbeddings)
    assert len(generated) == 2
    for item in generated:
        assert isinstance(item, Embedding)
        vector = item.vector
        assert isinstance(vector, list)
        assert len(vector) > 0
        assert all(isinstance(component, float) for component in vector)
|
||||
@@ -44,6 +44,9 @@ asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
filterwarnings = []
|
||||
timeout = 120
|
||||
markers = [
|
||||
"integration: marks tests as integration tests that require external services",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
extend = "../../pyproject.toml"
|
||||
@@ -78,6 +81,7 @@ exclude_dirs = ["tests"]
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_orchestrations"
|
||||
test = "pytest --cov=agent_framework_orchestrations --cov-report=term-missing:skip-covered -n auto --dist worksteal tests"
|
||||
|
||||
@@ -39,12 +39,16 @@ environments = [
|
||||
|
||||
[tool.uv-dynamic-versioning]
|
||||
fallback-version = "0.0.0"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
addopts = "-ra -q -r fEX"
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
filterwarnings = []
|
||||
markers = [
|
||||
"integration: marks tests as integration tests that require external services",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
extend = "../../pyproject.toml"
|
||||
@@ -57,7 +61,6 @@ omit = [
|
||||
[tool.pyright]
|
||||
extends = "../../pyproject.toml"
|
||||
|
||||
|
||||
[tool.mypy]
|
||||
plugins = ['pydantic.mypy']
|
||||
strict = true
|
||||
@@ -79,6 +82,7 @@ exclude_dirs = ["tests"]
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_purview"
|
||||
test = "pytest --cov=agent_framework_purview --cov-report=term-missing:skip-covered tests"
|
||||
|
||||
@@ -39,6 +39,7 @@ environments = [
|
||||
|
||||
[tool.uv-dynamic-versioning]
|
||||
fallback-version = "0.0.0"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
addopts = "-ra -q -r fEX"
|
||||
@@ -48,6 +49,9 @@ filterwarnings = [
|
||||
"ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*"
|
||||
]
|
||||
timeout = 120
|
||||
markers = [
|
||||
"integration: marks tests as integration tests that require external services",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
extend = "../../pyproject.toml"
|
||||
@@ -81,6 +85,7 @@ exclude_dirs = ["tests"]
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_redis"
|
||||
test = "pytest --cov=agent_framework_redis --cov-report=term-missing:skip-covered tests"
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
# /// script
|
||||
# requires-python = ">=3.10"
|
||||
# dependencies = [
|
||||
# "agent-framework-azure-ai",
|
||||
# ]
|
||||
# ///
|
||||
# Run with: uv run samples/02-agents/embeddings/azure_ai_inference_embeddings.py
|
||||
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import pathlib
|
||||
|
||||
from agent_framework import Content
|
||||
from agent_framework_azure_ai import AzureAIInferenceEmbeddingClient
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
"""Azure AI Inference Image Embedding Example
|
||||
|
||||
This sample demonstrates how to generate image embeddings using the
|
||||
Azure AI Inference embedding client with the Cohere-embed-v3-english model.
|
||||
Images are passed as ``Content`` objects created with ``Content.from_data()``.
|
||||
|
||||
Prerequisites:
|
||||
Set the following environment variables or add them to a .env file:
|
||||
- AZURE_AI_INFERENCE_ENDPOINT: Your Azure AI model inference endpoint URL
|
||||
- AZURE_AI_INFERENCE_API_KEY: Your API key
|
||||
- AZURE_AI_INFERENCE_EMBEDDING_MODEL_ID: The text embedding model name
|
||||
(e.g. "text-embedding-3-small")
|
||||
- AZURE_AI_INFERENCE_IMAGE_EMBEDDING_MODEL_ID: The image embedding model name
|
||||
(e.g. "Cohere-embed-v3-english")
|
||||
"""
|
||||
|
||||
SAMPLE_IMAGE_PATH = pathlib.Path(__file__).parent.parent.parent / "shared" / "sample_assets" / "sample_image.jpg"
|
||||
|
||||
|
||||
async def main() -> None:
    """Embed the sample image three different ways and print each result.

    Opens an ``AzureAIInferenceEmbeddingClient`` constructed without arguments
    and issues three ``get_embeddings`` requests: the image alone, a text
    string together with the image, and the image again with the
    ``input_type`` option set to ``"document"``.
    """
    async with AzureAIInferenceEmbeddingClient() as client:
        # Wrap the raw JPEG bytes of the sample asset in a Content object;
        # the same object is reused by all three requests below.
        image_content = Content.from_data(
            data=SAMPLE_IMAGE_PATH.read_bytes(),
            media_type="image/jpeg",
        )

        # 1. Image on its own.
        image_only = await client.get_embeddings([image_content])
        image_embedding = image_only[0]
        print(f"Image embedding dimensions: {image_embedding.dimensions}")
        print(f"First 5 values: {image_embedding.vector[:5]}")
        print(f"Model: {image_embedding.model_id}")
        print(f"Usage: {image_only.usage}")
        print()

        # 2. Text and image submitted together in a single call.
        # The client dispatches text to the text endpoint and images to the
        # image endpoint, then reassembles results in the original input
        # order, so index 0 is the text and index 1 is the image.
        mixed_inputs = ["A half-timbered house in a forested valley", image_content]
        mixed = await client.get_embeddings(mixed_inputs)
        text_embedding = mixed[0]
        print(f"Text embedding dimensions: {text_embedding.dimensions}")
        print(f"First 5 values: {text_embedding.vector[:5]}")
        paired_image_embedding = mixed[1]
        print(f"Image embedding dimensions: {paired_image_embedding.dimensions}")
        print(f"First 5 values: {paired_image_embedding.vector[:5]}")
        print()

        # 3. Same image, this time passing the input_type option.
        document_options = {"input_type": "document"}
        as_document = await client.get_embeddings([image_content], options=document_options)
        document_embedding = as_document[0]
        print(f"Document embedding dimensions: {document_embedding.dimensions}")
        print(f"First 5 values: {document_embedding.vector[:5]}")
|
||||
|
||||
|
||||
# Script entry point: drive the async sample to completion when run directly.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
|
||||
|
||||
"""
|
||||
Sample output (using Cohere-embed-v3-english):
|
||||
Image embedding dimensions: 1024
|
||||
First 5 values: [0.023, -0.045, 0.067, -0.089, 0.011]
|
||||
Model: Cohere-embed-v3-english
|
||||
Usage: {'input_token_count': 1, 'total_token_count': 1}
|
||||
|
||||
Image+text (separate) results:
|
||||
Text embedding dimensions: 1536
|
||||
Image embedding dimensions: 1024
|
||||
|
||||
Document embedding dimensions: 1024
|
||||
"""
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 178 KiB |
Generated
+16
@@ -209,6 +209,7 @@ dependencies = [
|
||||
{ name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "aiohttp", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "azure-ai-agents", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "azure-ai-inference", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
@@ -216,6 +217,7 @@ requires-dist = [
|
||||
{ name = "agent-framework-core", editable = "packages/core" },
|
||||
{ name = "aiohttp" },
|
||||
{ name = "azure-ai-agents", specifier = "==1.2.0b5" },
|
||||
{ name = "azure-ai-inference", specifier = ">=1.0.0b9" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -979,6 +981,20 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/6d/6d/15070d23d7a94833a210da09d5d7ed3c24838bb84f0463895e5d159f1695/azure_ai_agents-1.2.0b5-py3-none-any.whl", hash = "sha256:257d0d24a6bf13eed4819cfa5c12fb222e5908deafb3cbfd5711d3a511cc4e88", size = 217948, upload-time = "2025-09-30T01:55:04.155Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "azure-ai-inference"
|
||||
version = "1.0.0b9"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "azure-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "isodate", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/4e/6a/ed85592e5c64e08c291992f58b1a94dab6869f28fb0f40fd753dced73ba6/azure_ai_inference-1.0.0b9.tar.gz", hash = "sha256:1feb496bd84b01ee2691befc04358fa25d7c344d8288e99364438859ad7cd5a4", size = 182408, upload-time = "2025-02-15T00:37:28.464Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/4f/0f/27520da74769db6e58327d96c98e7b9a07ce686dff582c9a5ec60b03f9dd/azure_ai_inference-1.0.0b9-py3-none-any.whl", hash = "sha256:49823732e674092dad83bb8b0d1b65aa73111fab924d61349eb2a8cdc0493990", size = 124885, upload-time = "2025-02-15T00:37:29.964Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "azure-ai-projects"
|
||||
version = "2.0.0b3"
|
||||
|
||||
Reference in New Issue
Block a user