Python: Phase 2: Embedding clients for Ollama, Bedrock, and Azure AI Inference (#4207)

* Phase 2: Embedding clients for Ollama, Bedrock, and Azure AI Inference

Add embedding client implementations to existing provider packages:

- OllamaEmbeddingClient: Text embeddings via Ollama's embed API
- BedrockEmbeddingClient: Text embeddings via Amazon Titan on Bedrock
- AzureAIInferenceEmbeddingClient: Text and image embeddings via Azure AI
  Inference, supporting Content | str input with separate model IDs for
  text (AZURE_AI_INFERENCE_EMBEDDING_MODEL_ID) and image
  (AZURE_AI_INFERENCE_IMAGE_EMBEDDING_MODEL_ID) endpoints
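
The model-ID resolution described above can be sketched roughly as follows. This is an illustration only, not the client's code: `resolve_model_ids` is a hypothetical helper that reads from a plain mapping instead of the framework's `load_settings`, and the precedence (explicit argument first, then environment variable) is an assumption. The variable names and the fallback of the image model to the text model come from this commit.

```python
from typing import Mapping, Optional, Tuple

TEXT_VAR = "AZURE_AI_INFERENCE_EMBEDDING_MODEL_ID"
IMAGE_VAR = "AZURE_AI_INFERENCE_IMAGE_EMBEDDING_MODEL_ID"


def resolve_model_ids(
    env: Mapping[str, str],
    model_id: Optional[str] = None,
    image_model_id: Optional[str] = None,
) -> Tuple[str, str]:
    """Hypothetical helper: return (text_model, image_model)."""
    # Assumed precedence: explicit argument, then environment variable.
    text_model = model_id or env.get(TEXT_VAR)
    if not text_model:
        raise ValueError(f"Set model_id or {TEXT_VAR}.")
    # The image model falls back to the text model when nothing else is set.
    image_model = image_model_id or env.get(IMAGE_VAR) or text_model
    return text_model, image_model
```

With only the text variable set, both models resolve to the same deployment name; setting the image variable splits them.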

Additional changes:
- Rename EmbeddingCoT -> EmbeddingT, EmbeddingOptionsCoT -> EmbeddingOptionsT
- Add otel_provider_name passthrough to all embedding clients
- Register integration pytest marker in all packages
- Add lazy-loading namespace exports for Ollama and Bedrock embeddings
- Add image embedding sample using Cohere-embed-v3-english
- Add azure-ai-inference dependency to azure-ai package
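
The namespace modules themselves are not shown in this section, so the following is only a sketch of one common way to implement lazy exports: a module-level `__getattr__` hook (PEP 562). Whether the Ollama and Bedrock namespaces use exactly this pattern is an assumption. The mapping and the stand-in module are hypothetical, and standard-library modules take the place of the provider packages.

```python
import importlib
import types
from typing import Any

# Hypothetical mapping of exported name -> module that defines it.
# Standard-library modules stand in for the provider packages here.
_LAZY_EXPORTS = {
    "OrderedDict": "collections",
    "JSONDecoder": "json",
}


def _lazy_getattr(name: str) -> Any:
    """Module-level __getattr__ hook: import the target on first access."""
    module_name = _LAZY_EXPORTS.get(name)
    if module_name is None:
        raise AttributeError(f"module has no attribute {name!r}")
    return getattr(importlib.import_module(module_name), name)


# Build a stand-in namespace module and attach the hook to it.
namespace = types.ModuleType("example_namespace")
namespace.__getattr__ = _lazy_getattr
```

Accessing `namespace.OrderedDict` triggers the import only at that point, so a provider's optional dependencies are not needed until one of its names is actually used.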

Part of #1188

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix mypy duplicate name and ruff lint issues

- Rename second 'vector' variable to 'img_vector' in image embedding loop
- Combine nested with statements in tests
- Remove unused result assignments in tests

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* updates from feedback

* Fix CI failures in embedding usage handling

- Fix Azure AI embedding mypy issues by normalizing vectors to list[float],
  safely accumulating optional usage token fields, and filtering None entries
  before constructing GeneratedEmbeddings
- Avoid Bandit false positive by initializing usage details as an empty dict
- Update OpenAI embedding tests to assert canonical usage keys
  (input_token_count/total_token_count)
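
The accumulation and filtering fixes listed above can be illustrated with a small standalone sketch. `add_tokens` is a hypothetical helper, and plain dicts and lists stand in for `UsageDetails` and `Embedding`; the pattern mirrors the one used in the Azure AI client.

```python
from typing import Optional


def add_tokens(usage: dict, key: str, count: Optional[int]) -> None:
    """Add a possibly missing token count to a running total."""
    usage[key] = (usage.get(key) or 0) + (count or 0)


usage = {}  # start from an empty dict
add_tokens(usage, "input_token_count", 10)    # text response
add_tokens(usage, "input_token_count", None)  # response without a count
add_tokens(usage, "input_token_count", 5)     # image response

# Slots that were never filled stay None and are dropped before the
# result is built; every kept vector is normalized to a list of floats.
slots = [[0.1, 0.2], None, [0.3, 0.4]]
vectors = [[float(v) for v in slot] for slot in slots if slot is not None]
```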

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Eduard van Valkenburg
2026-02-25 18:45:08 +01:00
committed by GitHub
parent e3a5b915a6
commit 6138487888
44 changed files with 1836 additions and 34 deletions
@@ -140,9 +140,10 @@ This feature ports the vector store abstractions, embedding generator abstractio
## Implementation Phases
### Phase 1: Core Embedding Abstractions & OpenAI Implementation
### Phase 1: Core Embedding Abstractions & OpenAI Implementation ✅ DONE
**Goal:** Establish the embedding generator abstraction and ship one working implementation.
**Mergeable:** Yes — adds new types/protocols, no breaking changes.
**Status:** Merged via PR #4153. Closes sub-issue #4163.
#### 1.1 — Embedding types in `_types.py`
- `EmbeddingInputT` TypeVar (default `str`) — generic input type for embedding generation
@@ -37,6 +37,7 @@ environments = [
[tool.uv-dynamic-versioning]
fallback-version = "0.0.0"
[tool.pytest.ini_options]
testpaths = 'tests'
addopts = "-ra -q -r fEX"
@@ -46,6 +47,9 @@ filterwarnings = [
"ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*"
]
timeout = 120
markers = [
"integration: marks tests as integration tests that require external services",
]
[tool.ruff]
extend = "../../pyproject.toml"
@@ -79,6 +83,7 @@ exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_a2a"
test = "pytest --cov=agent_framework_a2a --cov-report=term-missing:skip-covered tests"
@@ -45,6 +45,9 @@ packages = ["agent_framework_ag_ui", "agent_framework_ag_ui_examples"]
asyncio_mode = "auto"
testpaths = ["tests/ag_ui"]
pythonpath = ["."]
markers = [
"integration: marks tests as integration tests that require external services",
]
[tool.ruff]
line-length = 120
@@ -37,6 +37,7 @@ environments = [
[tool.uv-dynamic-versioning]
fallback-version = "0.0.0"
[tool.pytest.ini_options]
testpaths = 'tests'
addopts = "-ra -q -r fEX"
@@ -46,6 +47,9 @@ filterwarnings = [
"ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*"
]
timeout = 120
markers = [
"integration: marks tests as integration tests that require external services",
]
[tool.ruff]
extend = "../../pyproject.toml"
@@ -80,6 +84,7 @@ exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_anthropic"
test = "pytest --cov=agent_framework_anthropic --cov-report=term-missing:skip-covered -n auto --dist worksteal tests"
@@ -47,6 +47,9 @@ filterwarnings = [
"ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*"
]
timeout = 120
markers = [
"integration: marks tests as integration tests that require external services",
]
[tool.ruff]
extend = "../../pyproject.toml"
@@ -82,6 +85,7 @@ exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_azure_ai_search"
test = "pytest --cov=agent_framework_azure_ai_search --cov-report=term-missing:skip-covered tests"
@@ -5,6 +5,12 @@ import importlib.metadata
from ._agent_provider import AzureAIAgentsProvider
from ._chat_client import AzureAIAgentClient, AzureAIAgentOptions
from ._client import AzureAIClient, AzureAIProjectAgentOptions, RawAzureAIClient
from ._embedding_client import (
AzureAIInferenceEmbeddingClient,
AzureAIInferenceEmbeddingOptions,
AzureAIInferenceEmbeddingSettings,
RawAzureAIInferenceEmbeddingClient,
)
from ._foundry_memory_provider import FoundryMemoryProvider
from ._project_provider import AzureAIProjectAgentProvider
from ._shared import AzureAISettings
@@ -19,10 +25,14 @@ __all__ = [
"AzureAIAgentOptions",
"AzureAIAgentsProvider",
"AzureAIClient",
"AzureAIInferenceEmbeddingClient",
"AzureAIInferenceEmbeddingOptions",
"AzureAIInferenceEmbeddingSettings",
"AzureAIProjectAgentOptions",
"AzureAIProjectAgentProvider",
"AzureAISettings",
"FoundryMemoryProvider",
"RawAzureAIClient",
"RawAzureAIInferenceEmbeddingClient",
"__version__",
]
@@ -0,0 +1,396 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
import logging
import sys
from collections.abc import Sequence
from contextlib import suppress
from typing import Any, ClassVar, Generic, TypedDict
from agent_framework import (
BaseEmbeddingClient,
Content,
Embedding,
EmbeddingGenerationOptions,
GeneratedEmbeddings,
UsageDetails,
load_settings,
)
from agent_framework.observability import EmbeddingTelemetryLayer
from azure.ai.inference.aio import EmbeddingsClient, ImageEmbeddingsClient
from azure.ai.inference.models import ImageEmbeddingInput
from azure.core.credentials import AzureKeyCredential
if sys.version_info >= (3, 13):
from typing import TypeVar # type: ignore # pragma: no cover
else:
from typing_extensions import TypeVar # type: ignore # pragma: no cover
logger = logging.getLogger("agent_framework.azure_ai")
_IMAGE_MEDIA_PREFIXES = ("image/",)
class AzureAIInferenceEmbeddingOptions(EmbeddingGenerationOptions, total=False):
"""Azure AI Inference-specific embedding options.
Extends EmbeddingGenerationOptions with Azure AI Inference-specific fields.
Examples:
.. code-block:: python
from agent_framework_azure_ai import AzureAIInferenceEmbeddingOptions
options: AzureAIInferenceEmbeddingOptions = {
"model_id": "text-embedding-3-small",
"dimensions": 1536,
"input_type": "document",
"encoding_format": "float",
}
"""
input_type: str
"""Input type hint for the model. Common values: ``"text"``, ``"query"``, ``"document"``."""
image_model_id: str
"""Override model for image embeddings. Falls back to the client's ``image_model_id``."""
encoding_format: str
"""Output encoding format.
Common values: ``"float"``, ``"base64"``, ``"int8"``, ``"uint8"``,
``"binary"``, ``"ubinary"``.
"""
extra_parameters: dict[str, Any]
"""Additional model-specific parameters passed directly to the API."""
AzureAIInferenceEmbeddingOptionsT = TypeVar(
"AzureAIInferenceEmbeddingOptionsT",
bound=TypedDict, # type: ignore[valid-type]
default="AzureAIInferenceEmbeddingOptions",
covariant=True,
)
class AzureAIInferenceEmbeddingSettings(TypedDict, total=False):
"""Azure AI Inference embedding settings."""
endpoint: str | None
api_key: str | None
embedding_model_id: str | None
image_embedding_model_id: str | None
class RawAzureAIInferenceEmbeddingClient(
BaseEmbeddingClient[Content | str, list[float], AzureAIInferenceEmbeddingOptionsT],
Generic[AzureAIInferenceEmbeddingOptionsT],
):
"""Raw Azure AI Inference embedding client without telemetry.
Accepts both text (``str``) and image (``Content``) inputs. Text and image
inputs within a single batch are separated and dispatched to
``EmbeddingsClient`` and ``ImageEmbeddingsClient`` respectively. Results
are reassembled in the original input order.
Keyword Args:
model_id: The text embedding model deployment name (e.g. "text-embedding-3-small").
Can also be set via environment variable AZURE_AI_INFERENCE_EMBEDDING_MODEL_ID.
image_model_id: The image embedding model deployment name (e.g. "Cohere-embed-v3-english").
Can also be set via environment variable AZURE_AI_INFERENCE_IMAGE_EMBEDDING_MODEL_ID.
Falls back to ``model_id`` if not provided.
endpoint: The Azure AI Inference endpoint URL.
Can also be set via environment variable AZURE_AI_INFERENCE_ENDPOINT.
api_key: API key for authentication.
Can also be set via environment variable AZURE_AI_INFERENCE_API_KEY.
text_client: Optional pre-configured ``EmbeddingsClient``.
image_client: Optional pre-configured ``ImageEmbeddingsClient``.
credential: Optional ``AzureKeyCredential`` or token credential. If not provided,
one is created from ``api_key``.
env_file_path: Path to .env file for settings.
env_file_encoding: Encoding for .env file.
"""
def __init__(
self,
*,
model_id: str | None = None,
image_model_id: str | None = None,
endpoint: str | None = None,
api_key: str | None = None,
text_client: EmbeddingsClient | None = None,
image_client: ImageEmbeddingsClient | None = None,
credential: AzureKeyCredential | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize a raw Azure AI Inference embedding client."""
settings = load_settings(
AzureAIInferenceEmbeddingSettings,
env_prefix="AZURE_AI_INFERENCE_",
required_fields=["endpoint", "embedding_model_id"],
endpoint=endpoint,
api_key=api_key,
embedding_model_id=model_id,
image_embedding_model_id=image_model_id,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
self.model_id = settings["embedding_model_id"] # type: ignore[reportTypedDictNotRequiredAccess]
self.image_model_id: str = settings.get("image_embedding_model_id") or self.model_id # type: ignore[assignment]
resolved_endpoint = settings["endpoint"] # type: ignore[reportTypedDictNotRequiredAccess]
if credential is None and settings.get("api_key"):
credential = AzureKeyCredential(settings["api_key"]) # type: ignore[arg-type]
if credential is None and text_client is None and image_client is None:
raise ValueError("Either 'api_key', 'credential', or pre-configured client(s) must be provided.")
self._text_client = text_client or EmbeddingsClient(
endpoint=resolved_endpoint, # type: ignore[arg-type]
credential=credential, # type: ignore[arg-type]
)
self._image_client = image_client or ImageEmbeddingsClient(
endpoint=resolved_endpoint, # type: ignore[arg-type]
credential=credential, # type: ignore[arg-type]
)
self._endpoint = resolved_endpoint
super().__init__(**kwargs)
async def close(self) -> None:
"""Close the underlying SDK clients and release resources."""
with suppress(Exception):
await self._text_client.close()
with suppress(Exception):
await self._image_client.close()
async def __aenter__(self) -> RawAzureAIInferenceEmbeddingClient[AzureAIInferenceEmbeddingOptionsT]:
"""Enter the async context manager."""
return self
async def __aexit__(self, *args: Any) -> None:
"""Exit the async context manager and close clients."""
await self.close()
def service_url(self) -> str:
"""Get the URL of the service."""
return self._endpoint or ""
async def get_embeddings(
self,
values: Sequence[Content | str],
*,
options: AzureAIInferenceEmbeddingOptionsT | None = None,
) -> GeneratedEmbeddings[list[float]]:
"""Generate embeddings for text and/or image inputs.
Text inputs (``str`` or ``Content`` with ``type="text"``) are sent to the
text embeddings endpoint. Image inputs (``Content`` with an image
``media_type``) are sent to the image embeddings endpoint. Results are
returned in the same order as the input.
Args:
values: A sequence of text strings or ``Content`` instances.
options: Optional embedding generation options.
Returns:
Generated embeddings with usage metadata.
Raises:
ValueError: If model_id is not provided or an unsupported content type is encountered.
"""
if not values:
return GeneratedEmbeddings([], options=options) # type: ignore[reportReturnType]
opts: dict[str, Any] = dict(options) if options else {}
# Separate text and image inputs, tracking original indices.
text_items: list[tuple[int, str]] = []
image_items: list[tuple[int, ImageEmbeddingInput]] = []
for idx, value in enumerate(values):
if isinstance(value, str):
text_items.append((idx, value))
elif isinstance(value, Content):
if value.type == "text" and value.text is not None:
text_items.append((idx, value.text))
elif (
value.type in ("data", "uri")
and value.media_type
and value.media_type.startswith(_IMAGE_MEDIA_PREFIXES)
):
if not value.uri:
raise ValueError(f"Image Content at index {idx} has no URI.")
image_input = ImageEmbeddingInput(image=value.uri, text=value.text)
image_items.append((idx, image_input))
else:
raise ValueError(
f"Unsupported Content type '{value.type}' with media_type "
f"'{value.media_type}' at index {idx}. Expected text content or "
f"image content (media_type starting with 'image/')."
)
else:
raise ValueError(f"Unsupported input type {type(value).__name__} at index {idx}.")
# Build shared API kwargs (without model, which differs per client).
common_kwargs: dict[str, Any] = {}
if dimensions := opts.get("dimensions"):
common_kwargs["dimensions"] = dimensions
if encoding_format := opts.get("encoding_format"):
common_kwargs["encoding_format"] = encoding_format
if input_type := opts.get("input_type"):
common_kwargs["input_type"] = input_type
if extra_parameters := opts.get("extra_parameters"):
common_kwargs["model_extras"] = extra_parameters
# Allocate results array.
embeddings: list[Embedding[list[float]] | None] = [None] * len(values)
usage_details: UsageDetails = {}
# Embed text inputs.
if text_items:
if not (text_model := opts.get("model_id") or self.model_id):
raise ValueError("A model_id is required, either in the client or options, for text inputs.")
text_inputs = [t for _, t in text_items]
response = await self._text_client.embed(
input=text_inputs,
model=text_model,
**common_kwargs,
)
for i, item in enumerate(response.data):
original_idx = text_items[i][0]
vector: list[float] = [float(v) for v in item.embedding]
embeddings[original_idx] = Embedding(
vector=vector,
dimensions=len(vector),
model_id=response.model or text_model,
)
if response.usage:
usage_details["input_token_count"] = (usage_details.get("input_token_count") or 0) + (
response.usage.prompt_tokens or 0
)
usage_details["output_token_count"] = (usage_details.get("output_token_count") or 0) + (
getattr(response.usage, "completion_tokens", 0) or 0
)
# Embed image inputs.
if image_items:
if not (image_model := opts.get("image_model_id") or self.image_model_id):
raise ValueError("An image_model_id is required, either in the client or options, for image inputs.")
image_inputs = [img for _, img in image_items]
response = await self._image_client.embed(
input=image_inputs,
model=image_model,
**common_kwargs,
)
for i, item in enumerate(response.data):
original_idx = image_items[i][0]
image_vector: list[float] = [float(v) for v in item.embedding]
embeddings[original_idx] = Embedding(
vector=image_vector,
dimensions=len(image_vector),
model_id=response.model or image_model,
)
if response.usage:
usage_details["input_token_count"] = (usage_details.get("input_token_count") or 0) + (
response.usage.prompt_tokens or 0
)
usage_details["output_token_count"] = (usage_details.get("output_token_count") or 0) + (
getattr(response.usage, "completion_tokens", 0) or 0
)
return GeneratedEmbeddings(
[embedding for embedding in embeddings if embedding is not None],
options=options,
usage=usage_details,
) # type: ignore[reportReturnType]
class AzureAIInferenceEmbeddingClient(
EmbeddingTelemetryLayer[Content | str, list[float], AzureAIInferenceEmbeddingOptionsT],
RawAzureAIInferenceEmbeddingClient[AzureAIInferenceEmbeddingOptionsT],
Generic[AzureAIInferenceEmbeddingOptionsT],
):
"""Azure AI Inference embedding client with telemetry support.
Supports both text and image inputs in a single client. Pass plain strings
or ``Content`` instances created with ``Content.from_text()`` or
``Content.from_data()``.
Keyword Args:
model_id: The text embedding model deployment name (e.g. "text-embedding-3-small").
Can also be set via environment variable AZURE_AI_INFERENCE_EMBEDDING_MODEL_ID.
image_model_id: The image embedding model deployment name
(e.g. "Cohere-embed-v3-english"). Can also be set via environment variable
AZURE_AI_INFERENCE_IMAGE_EMBEDDING_MODEL_ID. Falls back to ``model_id``.
endpoint: The Azure AI Inference endpoint URL.
Can also be set via environment variable AZURE_AI_INFERENCE_ENDPOINT.
api_key: API key for authentication.
Can also be set via environment variable AZURE_AI_INFERENCE_API_KEY.
text_client: Optional pre-configured ``EmbeddingsClient``.
image_client: Optional pre-configured ``ImageEmbeddingsClient``.
credential: Optional ``AzureKeyCredential`` or token credential.
otel_provider_name: Override for the OpenTelemetry provider name.
env_file_path: Path to .env file for settings.
env_file_encoding: Encoding for .env file.
Examples:
.. code-block:: python
from agent_framework_azure_ai import AzureAIInferenceEmbeddingClient
# Using environment variables
# Set AZURE_AI_INFERENCE_ENDPOINT=https://your-endpoint.inference.ai.azure.com
# Set AZURE_AI_INFERENCE_API_KEY=your-key
# Set AZURE_AI_INFERENCE_EMBEDDING_MODEL_ID=text-embedding-3-small
# Set AZURE_AI_INFERENCE_IMAGE_EMBEDDING_MODEL_ID=Cohere-embed-v3-english
client = AzureAIInferenceEmbeddingClient()
# Text embeddings
result = await client.get_embeddings(["Hello, world!"])
# Image embeddings
from agent_framework import Content
image = Content.from_data(data=image_bytes, media_type="image/png")
result = await client.get_embeddings([image])
# Mixed text and image
result = await client.get_embeddings(["hello", image])
"""
OTEL_PROVIDER_NAME: ClassVar[str] = "azure.ai.inference" # type: ignore[reportIncompatibleVariableOverride, misc]
def __init__(
self,
*,
model_id: str | None = None,
image_model_id: str | None = None,
endpoint: str | None = None,
api_key: str | None = None,
text_client: EmbeddingsClient | None = None,
image_client: ImageEmbeddingsClient | None = None,
credential: AzureKeyCredential | None = None,
otel_provider_name: str | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize an Azure AI Inference embedding client."""
super().__init__(
model_id=model_id,
image_model_id=image_model_id,
endpoint=endpoint,
api_key=api_key,
text_client=text_client,
image_client=image_client,
credential=credential,
otel_provider_name=otel_provider_name,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
**kwargs,
)
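
The split-and-reassemble step in `get_embeddings` can be reduced to the following standalone sketch. All names here are hypothetical stand-ins: `FakeImage` replaces an image `Content`, and `embed_text` / `embed_images` replace the two SDK clients. Only the index-tracking idea is taken from the client above.

```python
import asyncio


class FakeImage:
    """Hypothetical stand-in for an image Content instance."""

    def __init__(self, uri):
        self.uri = uri


async def embed_text(texts):
    # Stand-in for the text endpoint: one 1-d vector per input.
    return [[float(len(t))] for t in texts]


async def embed_images(uris):
    # Stand-in for the image endpoint.
    return [[-1.0] for _ in uris]


async def embed_mixed(values):
    """Dispatch each input to its endpoint and restore the input order."""
    text_items = [(i, v) for i, v in enumerate(values) if isinstance(v, str)]
    image_items = [(i, v.uri) for i, v in enumerate(values) if isinstance(v, FakeImage)]
    results = [None] * len(values)
    if text_items:
        vectors = await embed_text([t for _, t in text_items])
        for (idx, _), vec in zip(text_items, vectors):
            results[idx] = vec  # write back at the original position
    if image_items:
        vectors = await embed_images([u for _, u in image_items])
        for (idx, _), vec in zip(image_items, vectors):
            results[idx] = vec
    return results
```

Each endpoint is called at most once per batch, and the original index stored with each item is what lets the results be written back in input order.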
@@ -25,6 +25,7 @@ classifiers = [
dependencies = [
"agent-framework-core>=1.0.0rc1",
"azure-ai-agents == 1.2.0b5",
"azure-ai-inference>=1.0.0b9",
"aiohttp",
]
@@ -38,6 +39,7 @@ environments = [
[tool.uv-dynamic-versioning]
fallback-version = "0.0.0"
[tool.pytest.ini_options]
testpaths = 'tests'
addopts = "-ra -q -r fEX"
@@ -45,6 +47,9 @@ asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
filterwarnings = []
timeout = 120
markers = [
"integration: marks tests as integration tests that require external services",
]
[tool.ruff]
extend = "../../pyproject.toml"
@@ -78,6 +83,7 @@ exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_azure_ai"
test = "pytest --cov=agent_framework_azure_ai --cov-report=term-missing:skip-covered tests"
@@ -0,0 +1,316 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
import os
from collections.abc import Sequence
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agent_framework import Content
from agent_framework_azure_ai import (
AzureAIInferenceEmbeddingClient,
AzureAIInferenceEmbeddingOptions,
RawAzureAIInferenceEmbeddingClient,
)
def _make_embed_response(
embeddings: Sequence[list[float]],
model: str = "test-model",
prompt_tokens: int = 10,
) -> MagicMock:
"""Create a mock EmbeddingsResult."""
data = []
for emb in embeddings:
item = MagicMock()
item.embedding = emb
data.append(item)
usage = MagicMock()
usage.prompt_tokens = prompt_tokens
usage.completion_tokens = 0
result = MagicMock()
result.data = data
result.model = model
result.usage = usage
return result
@pytest.fixture
def mock_text_client() -> AsyncMock:
"""Create a mock text EmbeddingsClient."""
client = AsyncMock()
client.embed = AsyncMock(return_value=_make_embed_response([[0.1, 0.2, 0.3]]))
return client
@pytest.fixture
def mock_image_client() -> AsyncMock:
"""Create a mock image ImageEmbeddingsClient."""
client = AsyncMock()
client.embed = AsyncMock(return_value=_make_embed_response([[0.4, 0.5, 0.6]]))
return client
@pytest.fixture
def raw_client(mock_text_client: AsyncMock, mock_image_client: AsyncMock) -> RawAzureAIInferenceEmbeddingClient[Any]:
"""Create a RawAzureAIInferenceEmbeddingClient with mocked SDK clients."""
return RawAzureAIInferenceEmbeddingClient(
model_id="test-model",
endpoint="https://test.inference.ai.azure.com",
api_key="test-key",
text_client=mock_text_client,
image_client=mock_image_client,
)
@pytest.fixture
def client(mock_text_client: AsyncMock, mock_image_client: AsyncMock) -> AzureAIInferenceEmbeddingClient[Any]:
"""Create an AzureAIInferenceEmbeddingClient with mocked SDK clients."""
return AzureAIInferenceEmbeddingClient(
model_id="test-model",
endpoint="https://test.inference.ai.azure.com",
api_key="test-key",
text_client=mock_text_client,
image_client=mock_image_client,
)
class TestRawAzureAIInferenceEmbeddingClient:
"""Tests for the raw Azure AI Inference embedding client."""
async def test_text_embeddings(
self, raw_client: RawAzureAIInferenceEmbeddingClient[Any], mock_text_client: AsyncMock
) -> None:
"""Text inputs are dispatched to the text client."""
result = await raw_client.get_embeddings(["hello", "world"])
assert result is not None
call_kwargs = mock_text_client.embed.call_args
assert call_kwargs.kwargs["input"] == ["hello", "world"]
assert call_kwargs.kwargs["model"] == "test-model"
async def test_text_content_embeddings(
self, raw_client: RawAzureAIInferenceEmbeddingClient[Any], mock_text_client: AsyncMock
) -> None:
"""Content.from_text() inputs are dispatched to the text client."""
text_content = Content.from_text("hello")
await raw_client.get_embeddings([text_content])
mock_text_client.embed.assert_called_once()
call_kwargs = mock_text_client.embed.call_args
assert call_kwargs.kwargs["input"] == ["hello"]
async def test_image_content_embeddings(
self, raw_client: RawAzureAIInferenceEmbeddingClient[Any], mock_image_client: AsyncMock
) -> None:
"""Image Content inputs are dispatched to the image client."""
image_content = Content.from_data(data=b"\x89PNG", media_type="image/png")
await raw_client.get_embeddings([image_content])
mock_image_client.embed.assert_called_once()
call_kwargs = mock_image_client.embed.call_args
image_inputs = call_kwargs.kwargs["input"]
assert len(image_inputs) == 1
assert image_inputs[0].image == image_content.uri
async def test_mixed_text_and_image(
self,
raw_client: RawAzureAIInferenceEmbeddingClient[Any],
mock_text_client: AsyncMock,
mock_image_client: AsyncMock,
) -> None:
"""Mixed text and image inputs are dispatched to the correct clients."""
mock_text_client.embed.return_value = _make_embed_response([[0.1, 0.2]])
mock_image_client.embed.return_value = _make_embed_response([[0.3, 0.4]])
image = Content.from_data(data=b"\x89PNG", media_type="image/png")
await raw_client.get_embeddings(["hello", image, "world"])
# Text client gets "hello" and "world"
text_call = mock_text_client.embed.call_args
assert text_call.kwargs["input"] == ["hello", "world"]
# Image client gets the image
image_call = mock_image_client.embed.call_args
assert len(image_call.kwargs["input"]) == 1
async def test_empty_input(self, raw_client: RawAzureAIInferenceEmbeddingClient[Any]) -> None:
"""Empty input returns empty result."""
result = await raw_client.get_embeddings([])
assert len(result) == 0
async def test_options_passed_through(
self, raw_client: RawAzureAIInferenceEmbeddingClient[Any], mock_text_client: AsyncMock
) -> None:
"""Options are passed through to the SDK."""
options: AzureAIInferenceEmbeddingOptions = {
"dimensions": 512,
"input_type": "document",
"encoding_format": "float",
}
await raw_client.get_embeddings(["hello"], options=options)
call_kwargs = mock_text_client.embed.call_args
assert call_kwargs.kwargs["dimensions"] == 512
assert call_kwargs.kwargs["input_type"] == "document"
assert call_kwargs.kwargs["encoding_format"] == "float"
async def test_model_override_in_options(
self, raw_client: RawAzureAIInferenceEmbeddingClient[Any], mock_text_client: AsyncMock
) -> None:
"""model_id in options overrides the default."""
options: AzureAIInferenceEmbeddingOptions = {"model_id": "custom-model"}
await raw_client.get_embeddings(["hello"], options=options)
call_kwargs = mock_text_client.embed.call_args
assert call_kwargs.kwargs["model"] == "custom-model"
async def test_unsupported_content_type_raises(self, raw_client: RawAzureAIInferenceEmbeddingClient[Any]) -> None:
"""Non-text, non-image Content raises ValueError."""
error_content = Content("error", message="fail")
with pytest.raises(ValueError, match="Unsupported Content type"):
await raw_client.get_embeddings([error_content])
async def test_usage_metadata(
self, raw_client: RawAzureAIInferenceEmbeddingClient[Any], mock_text_client: AsyncMock
) -> None:
"""Usage metadata is populated from the response."""
mock_text_client.embed.return_value = _make_embed_response([[0.1, 0.2]], prompt_tokens=42)
result = await raw_client.get_embeddings(["hello"])
assert result.usage is not None
assert result.usage["input_token_count"] == 42
def test_service_url(self, raw_client: RawAzureAIInferenceEmbeddingClient[Any]) -> None:
"""service_url returns the configured endpoint."""
assert raw_client.service_url() == "https://test.inference.ai.azure.com"
def test_settings_from_env(self) -> None:
"""Settings are loaded from environment variables."""
with (
patch.dict(
os.environ,
{
"AZURE_AI_INFERENCE_ENDPOINT": "https://env.inference.ai.azure.com",
"AZURE_AI_INFERENCE_API_KEY": "env-key",
"AZURE_AI_INFERENCE_EMBEDDING_MODEL_ID": "env-model",
},
),
patch("agent_framework_azure_ai._embedding_client.EmbeddingsClient"),
patch("agent_framework_azure_ai._embedding_client.ImageEmbeddingsClient"),
):
client = RawAzureAIInferenceEmbeddingClient()
assert client.model_id == "env-model"
assert client.image_model_id == "env-model" # falls back to model_id
def test_image_model_id_from_env(self) -> None:
"""image_model_id is loaded from its own environment variable."""
with (
patch.dict(
os.environ,
{
"AZURE_AI_INFERENCE_ENDPOINT": "https://env.inference.ai.azure.com",
"AZURE_AI_INFERENCE_API_KEY": "env-key",
"AZURE_AI_INFERENCE_EMBEDDING_MODEL_ID": "text-model",
"AZURE_AI_INFERENCE_IMAGE_EMBEDDING_MODEL_ID": "image-model",
},
),
patch("agent_framework_azure_ai._embedding_client.EmbeddingsClient"),
patch("agent_framework_azure_ai._embedding_client.ImageEmbeddingsClient"),
):
client = RawAzureAIInferenceEmbeddingClient()
assert client.model_id == "text-model"
assert client.image_model_id == "image-model"
def test_image_model_id_explicit(self, mock_text_client: AsyncMock, mock_image_client: AsyncMock) -> None:
"""image_model_id can be set explicitly."""
client = RawAzureAIInferenceEmbeddingClient(
model_id="text-model",
image_model_id="image-model",
endpoint="https://test.inference.ai.azure.com",
api_key="test-key",
text_client=mock_text_client,
image_client=mock_image_client,
)
assert client.model_id == "text-model"
assert client.image_model_id == "image-model"
async def test_image_model_id_sent_to_image_client(
self, mock_text_client: AsyncMock, mock_image_client: AsyncMock
) -> None:
"""image_model_id is passed to the image client embed call."""
client = RawAzureAIInferenceEmbeddingClient(
model_id="text-model",
image_model_id="image-model",
endpoint="https://test.inference.ai.azure.com",
api_key="test-key",
text_client=mock_text_client,
image_client=mock_image_client,
)
image_content = Content.from_data(data=b"\x89PNG", media_type="image/png")
await client.get_embeddings([image_content])
call_kwargs = mock_image_client.embed.call_args
assert call_kwargs.kwargs["model"] == "image-model"
class TestAzureAIInferenceEmbeddingClient:
"""Tests for the telemetry-enabled Azure AI Inference embedding client."""
async def test_text_embeddings(
self, client: AzureAIInferenceEmbeddingClient[Any], mock_text_client: AsyncMock
) -> None:
"""Text embeddings work through the telemetry layer."""
result = await client.get_embeddings(["hello"])
assert len(result) == 1
assert result[0].vector == [0.1, 0.2, 0.3]
async def test_otel_provider_name_default(self) -> None:
"""Default OTEL provider name is azure.ai.inference."""
assert AzureAIInferenceEmbeddingClient.OTEL_PROVIDER_NAME == "azure.ai.inference"
async def test_otel_provider_name_override(self, mock_text_client: AsyncMock, mock_image_client: AsyncMock) -> None:
"""OTEL provider name can be overridden."""
client = AzureAIInferenceEmbeddingClient(
model_id="test-model",
endpoint="https://test.inference.ai.azure.com",
api_key="test-key",
text_client=mock_text_client,
image_client=mock_image_client,
otel_provider_name="custom-provider",
)
assert client.otel_provider_name == "custom-provider"
_SKIP_REASON = "Azure AI Inference integration tests disabled"
def _integration_tests_enabled() -> bool:
return bool(
os.environ.get("AZURE_AI_INFERENCE_ENDPOINT")
and os.environ.get("AZURE_AI_INFERENCE_API_KEY")
and os.environ.get("AZURE_AI_INFERENCE_EMBEDDING_MODEL_ID")
)
skip_if_azure_ai_inference_integration_tests_disabled = pytest.mark.skipif(
not _integration_tests_enabled(),
reason=_SKIP_REASON,
)
class TestAzureAIInferenceEmbeddingIntegration:
"""Integration tests requiring a live Azure AI Inference endpoint."""
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_azure_ai_inference_integration_tests_disabled
async def test_text_embedding_live(self) -> None:
"""Generate text embeddings against a live endpoint."""
client = AzureAIInferenceEmbeddingClient()
result = await client.get_embeddings(["Hello, world!"])
assert len(result) == 1
assert len(result[0].vector) > 0
assert result[0].model_id is not None
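
The mocking pattern used by the unit tests above can be tried in isolation with only the standard library. In this sketch `embed_with` is a hypothetical caller, not part of the package; it shows how an `AsyncMock` records the keyword arguments of an awaited call so that they can be asserted afterwards.

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock


async def embed_with(client, texts, model):
    """Hypothetical caller that forwards keyword arguments to client.embed."""
    response = await client.embed(input=list(texts), model=model)
    return [item.embedding for item in response.data]


# Build a fake response shaped like the one the tests construct.
item = MagicMock()
item.embedding = [0.1, 0.2]
response = MagicMock()
response.data = [item]

# Attribute access on an AsyncMock yields awaitable child mocks.
mock_client = AsyncMock()
mock_client.embed.return_value = response

vectors = asyncio.run(embed_with(mock_client, ["hello"], "test-model"))
```

After the call, `mock_client.embed.call_args.kwargs` holds the `input` and `model` values that were passed, which is what the tests inspect.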
@@ -41,6 +41,7 @@ environments = [
[tool.uv-dynamic-versioning]
fallback-version = "0.0.0"
[tool.pytest.ini_options]
testpaths = 'tests'
pythonpath = ["tests/integration_tests"]
@@ -88,6 +89,7 @@ exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_azurefunctions"
test = "pytest --cov=agent_framework_azurefunctions --cov-report=term-missing:skip-covered tests"
@@ -3,6 +3,7 @@
import importlib.metadata
from ._chat_client import BedrockChatClient, BedrockChatOptions, BedrockGuardrailConfig, BedrockSettings
from ._embedding_client import BedrockEmbeddingClient, BedrockEmbeddingOptions, BedrockEmbeddingSettings
try:
__version__ = importlib.metadata.version(__name__)
@@ -12,6 +13,9 @@ except importlib.metadata.PackageNotFoundError:
__all__ = [
"BedrockChatClient",
"BedrockChatOptions",
"BedrockEmbeddingClient",
"BedrockEmbeddingOptions",
"BedrockEmbeddingSettings",
"BedrockGuardrailConfig",
"BedrockSettings",
"__version__",
@@ -0,0 +1,292 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
import asyncio
import json
import logging
import sys
from collections.abc import Sequence
from typing import Any, ClassVar, Generic, TypedDict
from agent_framework import (
AGENT_FRAMEWORK_USER_AGENT,
BaseEmbeddingClient,
Embedding,
EmbeddingGenerationOptions,
GeneratedEmbeddings,
SecretString,
UsageDetails,
load_settings,
)
from agent_framework.observability import EmbeddingTelemetryLayer
from boto3.session import Session as Boto3Session
from botocore.client import BaseClient
from botocore.config import Config as BotoConfig
if sys.version_info >= (3, 13):
from typing import TypeVar # type: ignore # pragma: no cover
else:
from typing_extensions import TypeVar # type: ignore # pragma: no cover
logger = logging.getLogger("agent_framework.bedrock")
DEFAULT_REGION = "us-east-1"
class BedrockEmbeddingSettings(TypedDict, total=False):
"""Bedrock embedding settings."""
region: str | None
embedding_model_id: str | None
access_key: SecretString | None
secret_key: SecretString | None
session_token: SecretString | None
class BedrockEmbeddingOptions(EmbeddingGenerationOptions, total=False):
"""Bedrock-specific embedding options.
Extends EmbeddingGenerationOptions with Bedrock-specific fields.
Examples:
.. code-block:: python
from agent_framework_bedrock import BedrockEmbeddingOptions
options: BedrockEmbeddingOptions = {
"model_id": "amazon.titan-embed-text-v2:0",
"dimensions": 1024,
"normalize": True,
}
"""
normalize: bool
BedrockEmbeddingOptionsT = TypeVar(
"BedrockEmbeddingOptionsT",
bound=TypedDict, # type: ignore[valid-type]
default="BedrockEmbeddingOptions",
covariant=True,
)
class RawBedrockEmbeddingClient(
BaseEmbeddingClient[str, list[float], BedrockEmbeddingOptionsT],
Generic[BedrockEmbeddingOptionsT],
):
"""Raw Bedrock embedding client without telemetry.
Keyword Args:
model_id: The Bedrock embedding model ID (e.g. "amazon.titan-embed-text-v2:0").
Can also be set via environment variable BEDROCK_EMBEDDING_MODEL_ID.
region: AWS region. Can also be set via environment variable BEDROCK_REGION.
If neither is set, the regular Boto3 configuration/loading applies
(other env vars, config files, or instance metadata), falling back to "us-east-1".
access_key: AWS access key for manual credential injection.
secret_key: AWS secret key paired with access_key.
session_token: AWS session token for temporary credentials.
client: Preconfigured Bedrock runtime client.
boto3_session: Custom boto3 session used to build the runtime client.
env_file_path: Path to .env file for settings.
env_file_encoding: Encoding for .env file.
"""
def __init__(
self,
*,
region: str | None = None,
model_id: str | None = None,
access_key: str | None = None,
secret_key: str | None = None,
session_token: str | None = None,
client: BaseClient | None = None,
boto3_session: Boto3Session | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize a raw Bedrock embedding client."""
settings = load_settings(
BedrockEmbeddingSettings,
env_prefix="BEDROCK_",
required_fields=["embedding_model_id"],
region=region,
embedding_model_id=model_id,
access_key=access_key,
secret_key=secret_key,
session_token=session_token,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
resolved_region = settings.get("region") or DEFAULT_REGION
if client is None:
if not boto3_session:
session_kwargs: dict[str, Any] = {}
if region := settings.get("region"):
session_kwargs["region_name"] = region
if (access_key := settings.get("access_key")) and (secret_key := settings.get("secret_key")):
session_kwargs["aws_access_key_id"] = access_key.get_secret_value() # type: ignore[union-attr]
session_kwargs["aws_secret_access_key"] = secret_key.get_secret_value() # type: ignore[union-attr]
if session_token := settings.get("session_token"):
session_kwargs["aws_session_token"] = session_token.get_secret_value() # type: ignore[union-attr]
boto3_session = Boto3Session(**session_kwargs)
client = boto3_session.client(
"bedrock-runtime",
region_name=boto3_session.region_name or resolved_region,
config=BotoConfig(user_agent_extra=AGENT_FRAMEWORK_USER_AGENT),
)
self._bedrock_client = client
self.model_id = settings["embedding_model_id"] # type: ignore[assignment]
self.region = resolved_region
super().__init__(**kwargs)
def service_url(self) -> str:
"""Get the URL of the service."""
return str(self._bedrock_client.meta.endpoint_url)
async def get_embeddings(
self,
values: Sequence[str],
*,
options: BedrockEmbeddingOptionsT | None = None,
) -> GeneratedEmbeddings[list[float]]:
"""Call the Bedrock invoke_model API for embeddings.
Uses the Amazon Titan Embeddings model format. Each value is embedded
individually since Titan's invoke_model API accepts one input at a time.
Args:
values: The text values to generate embeddings for.
options: Optional embedding generation options.
Returns:
Generated embeddings with usage metadata.
Raises:
ValueError: If no model_id is set on the client or in options.
"""
if not values:
return GeneratedEmbeddings([], options=options)
opts: dict[str, Any] = dict(options) if options else {}
model = opts.get("model_id") or self.model_id
if not model:
raise ValueError("model_id is required")
embedding_results = await asyncio.gather(
*(self._generate_embedding_for_text(opts, model, text) for text in values)
)
embeddings: list[Embedding[list[float]]] = []
total_input_tokens = 0
for embedding, input_tokens in embedding_results:
embeddings.append(embedding)
total_input_tokens += input_tokens
usage_dict: UsageDetails | None = None
if total_input_tokens > 0:
usage_dict = {"input_token_count": total_input_tokens}
return GeneratedEmbeddings(embeddings, options=options, usage=usage_dict)
async def _generate_embedding_for_text(
self,
opts: dict[str, Any],
model: str,
text: str,
) -> tuple[Embedding[list[float]], int]:
body: dict[str, Any] = {"inputText": text}
if dimensions := opts.get("dimensions"):
body["dimensions"] = dimensions
if (normalize := opts.get("normalize")) is not None:
body["normalize"] = normalize
response = await asyncio.to_thread(
self._bedrock_client.invoke_model,
modelId=model,
contentType="application/json",
accept="application/json",
body=json.dumps(body),
)
response_body = json.loads(response["body"].read())
embedding = Embedding(
vector=response_body["embedding"],
dimensions=len(response_body["embedding"]),
model_id=model,
)
input_tokens = int(response_body.get("inputTextTokenCount", 0))
return embedding, input_tokens
class BedrockEmbeddingClient(
EmbeddingTelemetryLayer[str, list[float], BedrockEmbeddingOptionsT],
RawBedrockEmbeddingClient[BedrockEmbeddingOptionsT],
Generic[BedrockEmbeddingOptionsT],
):
"""Bedrock embedding client with telemetry support.
Uses the Amazon Titan Embeddings model via Bedrock's invoke_model API.
Keyword Args:
model_id: The Bedrock embedding model ID (e.g. "amazon.titan-embed-text-v2:0").
Can also be set via environment variable BEDROCK_EMBEDDING_MODEL_ID.
region: AWS region. Defaults to "us-east-1".
Can also be set via environment variable BEDROCK_REGION.
access_key: AWS access key for manual credential injection.
secret_key: AWS secret key paired with access_key.
session_token: AWS session token for temporary credentials.
client: Preconfigured Bedrock runtime client.
boto3_session: Custom boto3 session used to build the runtime client.
env_file_path: Path to .env file for settings.
env_file_encoding: Encoding for .env file.
Examples:
.. code-block:: python
from agent_framework_bedrock import BedrockEmbeddingClient
# Using default AWS credentials
client = BedrockEmbeddingClient(
model_id="amazon.titan-embed-text-v2:0",
)
# Generate embeddings
result = await client.get_embeddings(["Hello, world!"])
print(result[0].vector)
"""
OTEL_PROVIDER_NAME: ClassVar[str] = "aws.bedrock" # type: ignore[reportIncompatibleVariableOverride, misc]
def __init__(
self,
*,
region: str | None = None,
model_id: str | None = None,
access_key: str | None = None,
secret_key: str | None = None,
session_token: str | None = None,
client: BaseClient | None = None,
boto3_session: Boto3Session | None = None,
otel_provider_name: str | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize a Bedrock embedding client."""
super().__init__(
region=region,
model_id=model_id,
access_key=access_key,
secret_key=secret_key,
session_token=session_token,
client=client,
boto3_session=boto3_session,
otel_provider_name=otel_provider_name,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
**kwargs,
)
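Reviewer note: the core of `get_embeddings` is a fan-out. It makes one blocking `invoke_model` call per input, pushes each onto a worker thread with `asyncio.to_thread`, collects the results with `asyncio.gather`, and sums the token counts. The sketch below isolates that pattern with the standard library only. `fake_invoke_model` is a hypothetical stand-in for the boto3 runtime client and its response shape is simplified, so read this as an illustration of the technique rather than the client's implementation.

```python
from __future__ import annotations

import asyncio
import json
from typing import Any


def fake_invoke_model(*, modelId: str, body: str) -> dict[str, Any]:
    """Hypothetical blocking call standing in for a Bedrock runtime client."""
    text = json.loads(body)["inputText"]
    return {"embedding": [float(len(text)), 0.0], "inputTextTokenCount": len(text.split())}


def build_body(text: str, opts: dict[str, Any]) -> dict[str, Any]:
    """Build the request body, mirroring the option handling in the diff."""
    body: dict[str, Any] = {"inputText": text}
    if dimensions := opts.get("dimensions"):  # falsy values (None, 0) are omitted
        body["dimensions"] = dimensions
    if (normalize := opts.get("normalize")) is not None:  # False is still sent
        body["normalize"] = normalize
    return body


async def embed_one(model: str, text: str, opts: dict[str, Any]) -> tuple[list[float], int]:
    # Run the blocking call on a worker thread so the event loop stays free.
    response = await asyncio.to_thread(
        fake_invoke_model, modelId=model, body=json.dumps(build_body(text, opts))
    )
    return response["embedding"], int(response.get("inputTextTokenCount", 0))


async def embed_all(
    model: str, texts: list[str], opts: dict[str, Any] | None = None
) -> tuple[list[list[float]], int]:
    # gather() returns results in input order, regardless of completion order.
    results = await asyncio.gather(*(embed_one(model, t, opts or {}) for t in texts))
    return [vector for vector, _ in results], sum(tokens for _, tokens in results)
```

Because `gather` preserves input order, the returned vectors line up with `texts` even when the per-item calls finish out of order. The unit test above asserts on a set of `inputText` values because the order in which the stub is called is not guaranteed.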
@@ -28,7 +28,6 @@ dependencies = [
"botocore>=1.35.0,<2.0.0",
]
[tool.uv]
prerelease = "if-necessary-or-explicit"
environments = [
@@ -46,6 +45,9 @@ addopts = "-ra -q -r fEX"
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
filterwarnings = []
markers = [
"integration: marks tests as integration tests that require external services",
]
timeout = 120
[tool.ruff]
@@ -0,0 +1,168 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
import json
import os
from typing import Any
from unittest.mock import MagicMock
import pytest
from agent_framework import Embedding, GeneratedEmbeddings
from agent_framework_bedrock import BedrockEmbeddingClient, BedrockEmbeddingOptions
class _StubBedrockEmbeddingRuntime:
"""Stub for the Bedrock runtime client that handles invoke_model for embeddings."""
def __init__(self) -> None:
self.calls: list[dict[str, Any]] = []
def invoke_model(self, **kwargs: Any) -> dict[str, Any]:
self.calls.append(kwargs)
body = json.loads(kwargs.get("body", "{}"))
# Simulate Titan embedding response
dimensions = body.get("dimensions", 3)
return {
"body": MagicMock(
read=lambda: json.dumps({
"embedding": [0.1 * (i + 1) for i in range(dimensions)],
"inputTextTokenCount": 5,
}).encode()
),
}
async def test_bedrock_embedding_construction() -> None:
"""Test construction with explicit parameters."""
stub = _StubBedrockEmbeddingRuntime()
client = BedrockEmbeddingClient(
model_id="amazon.titan-embed-text-v2:0",
region="us-west-2",
client=stub,
)
assert client.model_id == "amazon.titan-embed-text-v2:0"
assert client.region == "us-west-2"
async def test_bedrock_embedding_construction_missing_model_raises(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that missing model_id raises an error."""
monkeypatch.delenv("BEDROCK_EMBEDDING_MODEL_ID", raising=False)
from agent_framework.exceptions import SettingNotFoundError
with pytest.raises(SettingNotFoundError):
BedrockEmbeddingClient(region="us-west-2")
async def test_bedrock_embedding_get_embeddings() -> None:
"""Test generating embeddings via the Bedrock invoke_model API."""
stub = _StubBedrockEmbeddingRuntime()
client = BedrockEmbeddingClient(
model_id="amazon.titan-embed-text-v2:0",
region="us-west-2",
client=stub,
)
result = await client.get_embeddings(["hello", "world"])
assert isinstance(result, GeneratedEmbeddings)
assert len(result) == 2
assert len(result[0].vector) == 3
assert len(result[1].vector) == 3
assert result[0].model_id == "amazon.titan-embed-text-v2:0"
assert result.usage == {"input_token_count": 10}
# Two calls since Titan processes one input at a time
assert len(stub.calls) == 2
call_texts = {json.loads(call["body"])["inputText"] for call in stub.calls}
assert call_texts == {"hello", "world"}
async def test_bedrock_embedding_get_embeddings_empty_input() -> None:
"""Test generating embeddings with empty input."""
stub = _StubBedrockEmbeddingRuntime()
client = BedrockEmbeddingClient(
model_id="amazon.titan-embed-text-v2:0",
region="us-west-2",
client=stub,
)
result = await client.get_embeddings([])
assert isinstance(result, GeneratedEmbeddings)
assert len(result) == 0
assert len(stub.calls) == 0
async def test_bedrock_embedding_get_embeddings_with_options() -> None:
"""Test generating embeddings with custom options."""
stub = _StubBedrockEmbeddingRuntime()
client = BedrockEmbeddingClient(
model_id="amazon.titan-embed-text-v2:0",
region="us-west-2",
client=stub,
)
options: BedrockEmbeddingOptions = {
"dimensions": 5,
"normalize": True,
}
result = await client.get_embeddings(["hello"], options=options)
assert len(result) == 1
assert len(result[0].vector) == 5
body = json.loads(stub.calls[0]["body"])
assert body["dimensions"] == 5
assert body["normalize"] is True
async def test_bedrock_embedding_get_embeddings_no_model_raises() -> None:
"""Test that missing model_id at call time raises ValueError."""
stub = _StubBedrockEmbeddingRuntime()
client = BedrockEmbeddingClient(
model_id="amazon.titan-embed-text-v2:0",
region="us-west-2",
client=stub,
)
client.model_id = None # type: ignore[assignment]
with pytest.raises(ValueError, match="model_id is required"):
await client.get_embeddings(["hello"])
async def test_bedrock_embedding_default_region() -> None:
"""Test that default region is us-east-1."""
stub = _StubBedrockEmbeddingRuntime()
client = BedrockEmbeddingClient(
model_id="amazon.titan-embed-text-v2:0",
client=stub,
)
assert client.region == "us-east-1"
# region: Integration Tests
skip_if_bedrock_embedding_integration_tests_disabled = pytest.mark.skipif(
os.getenv("BEDROCK_EMBEDDING_MODEL_ID", "") in ("", "test-model")
or not (os.getenv("AWS_ACCESS_KEY_ID") or os.getenv("BEDROCK_ACCESS_KEY")),
reason="No real Bedrock embedding model or AWS credentials provided; skipping integration tests.",
)
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_bedrock_embedding_integration_tests_disabled
async def test_bedrock_embedding_integration() -> None:
"""Integration test for Bedrock embedding client."""
client = BedrockEmbeddingClient()
result = await client.get_embeddings(["Hello, world!", "How are you?"])
assert isinstance(result, GeneratedEmbeddings)
assert len(result) == 2
for embedding in result:
assert isinstance(embedding, Embedding)
assert isinstance(embedding.vector, list)
assert len(embedding.vector) > 0
assert all(isinstance(v, float) for v in embedding.vector)
@@ -36,6 +36,7 @@ environments = [
[tool.uv-dynamic-versioning]
fallback-version = "0.0.0"
[tool.pytest.ini_options]
testpaths = 'tests'
addopts = "-ra -q -r fEX"
@@ -43,6 +44,9 @@ asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
filterwarnings = []
timeout = 120
markers = [
"integration: marks tests as integration tests that require external services",
]
[tool.ruff]
extend = "../../pyproject.toml"
@@ -80,6 +84,7 @@ exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_chatkit"
test = "pytest --cov=agent_framework_chatkit --cov-report=term-missing:skip-covered tests"
@@ -37,6 +37,7 @@ environments = [
[tool.uv-dynamic-versioning]
fallback-version = "0.0.0"
[tool.pytest.ini_options]
testpaths = 'tests'
addopts = "-ra -q -r fEX"
@@ -46,6 +47,9 @@ filterwarnings = [
"ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*"
]
timeout = 120
markers = [
"integration: marks tests as integration tests that require external services",
]
[tool.ruff]
extend = "../../pyproject.toml"
@@ -80,6 +84,7 @@ exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_claude"
test = "pytest --cov=agent_framework_claude --cov-report=term-missing:skip-covered tests"
@@ -37,6 +37,7 @@ environments = [
[tool.uv-dynamic-versioning]
fallback-version = "0.0.0"
[tool.pytest.ini_options]
testpaths = 'tests'
addopts = "-ra -q -r fEX"
@@ -46,6 +47,9 @@ filterwarnings = [
"ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*"
]
timeout = 120
markers = [
"integration: marks tests as integration tests that require external services",
]
[tool.ruff]
extend = "../../pyproject.toml"
@@ -58,7 +62,6 @@ omit = [
[tool.pyright]
extends = "../../pyproject.toml"
[tool.mypy]
plugins = ['pydantic.mypy']
strict = true
@@ -80,6 +83,7 @@ exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_copilotstudio"
test = "pytest --cov=agent_framework_copilotstudio --cov-report=term-missing:skip-covered tests"
@@ -667,16 +667,12 @@ class SupportsFileSearchTool(Protocol):
# region SupportsGetEmbeddings Protocol
# Contravariant/covariant TypeVars for the Protocol
# Contravariant TypeVars for the Protocol
EmbeddingInputContraT = TypeVar(
"EmbeddingInputContraT",
default="str",
contravariant=True,
)
EmbeddingCoT = TypeVar(
"EmbeddingCoT",
default="list[float]",
)
EmbeddingOptionsContraT = TypeVar(
"EmbeddingOptionsContraT",
bound=TypedDict, # type: ignore[valid-type]
@@ -686,7 +682,7 @@ EmbeddingOptionsContraT = TypeVar(
@runtime_checkable
class SupportsGetEmbeddings(Protocol[EmbeddingInputContraT, EmbeddingCoT, EmbeddingOptionsContraT]):
class SupportsGetEmbeddings(Protocol[EmbeddingInputContraT, EmbeddingT, EmbeddingOptionsContraT]):
"""Protocol for an embedding client that can generate embeddings.
This protocol enables duck-typing for embedding generation. Any class that
@@ -714,7 +710,7 @@ class SupportsGetEmbeddings(Protocol[EmbeddingInputContraT, EmbeddingCoT, Embedd
values: Sequence[EmbeddingInputContraT],
*,
options: EmbeddingOptionsContraT | None = None,
) -> Awaitable[GeneratedEmbeddings[EmbeddingCoT]]:
) -> Awaitable[GeneratedEmbeddings[EmbeddingT]]:
"""Generate embeddings for the given values.
Args:
@@ -733,15 +729,15 @@ class SupportsGetEmbeddings(Protocol[EmbeddingInputContraT, EmbeddingCoT, Embedd
# region BaseEmbeddingClient
# Covariant for the BaseEmbeddingClient
EmbeddingOptionsCoT = TypeVar(
"EmbeddingOptionsCoT",
EmbeddingOptionsT = TypeVar(
"EmbeddingOptionsT",
bound=TypedDict, # type: ignore[valid-type]
default="EmbeddingGenerationOptions",
covariant=True,
)
class BaseEmbeddingClient(SerializationMixin, ABC, Generic[EmbeddingInputT, EmbeddingT, EmbeddingOptionsCoT]):
class BaseEmbeddingClient(SerializationMixin, ABC, Generic[EmbeddingInputT, EmbeddingT, EmbeddingOptionsT]):
"""Abstract base class for embedding clients.
Subclasses implement ``get_embeddings`` to provide the actual
@@ -785,7 +781,7 @@ class BaseEmbeddingClient(SerializationMixin, ABC, Generic[EmbeddingInputT, Embe
self,
values: Sequence[EmbeddingInputT],
*,
options: EmbeddingOptionsCoT | None = None,
options: EmbeddingOptionsT | None = None,
) -> GeneratedEmbeddings[EmbeddingT]:
"""Generate embeddings for the given values.
@@ -377,6 +377,12 @@ class UsageDetails(TypedDict, total=False):
This is a non-closed dictionary, so any specific provider fields can be added as needed.
Whenever they can be mapped to standard fields, they will be.
Keys:
input_token_count: The number of input tokens used.
output_token_count: The number of output tokens generated.
total_token_count: The total number of tokens (input + output).
"""
input_token_count: int | None
@@ -3289,7 +3295,7 @@ class GeneratedEmbeddings(list[Embedding[EmbeddingT]], Generic[EmbeddingT, Embed
embeddings: Iterable[Embedding[EmbeddingT]] | None = None,
*,
options: EmbeddingOptionsT | None = None,
usage: dict[str, Any] | None = None,
usage: UsageDetails | None = None,
additional_properties: dict[str, Any] | None = None,
) -> None:
super().__init__(embeddings or [])
@@ -8,6 +8,9 @@ This module lazily re-exports objects from:
Supported classes:
- BedrockChatClient
- BedrockChatOptions
- BedrockEmbeddingClient
- BedrockEmbeddingOptions
- BedrockEmbeddingSettings
- BedrockGuardrailConfig
- BedrockSettings
"""
@@ -17,7 +20,15 @@ from typing import Any
IMPORT_PATH = "agent_framework_bedrock"
PACKAGE_NAME = "agent-framework-bedrock"
_IMPORTS = ["BedrockChatClient", "BedrockChatOptions", "BedrockGuardrailConfig", "BedrockSettings"]
_IMPORTS = [
"BedrockChatClient",
"BedrockChatOptions",
"BedrockEmbeddingClient",
"BedrockEmbeddingOptions",
"BedrockEmbeddingSettings",
"BedrockGuardrailConfig",
"BedrockSettings",
]
def __getattr__(name: str) -> Any:
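Reviewer note: the body of `__getattr__` is outside this hunk, so the following is only a sketch of how a module-level `__getattr__` (PEP 562) can lazily re-export a fixed list of names. `make_lazy_namespace` is a hypothetical helper, and the standard-library `json` module stands in for the provider package so the example runs anywhere. It should not be read as the namespace module's real code.

```python
from __future__ import annotations

import importlib
import types
from typing import Any


def make_lazy_namespace(name: str, import_path: str, exported: list[str]) -> types.ModuleType:
    """Create a module whose listed attributes are resolved on first access."""
    module = types.ModuleType(name)

    def __getattr__(attr: str) -> Any:
        # Called only when normal attribute lookup on the module fails (PEP 562).
        if attr in exported:
            # The target package is imported here, not when the namespace is created.
            return getattr(importlib.import_module(import_path), attr)
        raise AttributeError(f"module {name!r} has no attribute {attr!r}")

    # Storing the function in the module's namespace activates the PEP 562 hook.
    module.__getattr__ = __getattr__  # type: ignore[method-assign]
    return module
```

The new embedding names need to be added to `_IMPORTS` for this kind of lookup to resolve them, which is what the hunk above does.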
@@ -3,6 +3,9 @@
from agent_framework_bedrock import (
BedrockChatClient,
BedrockChatOptions,
BedrockEmbeddingClient,
BedrockEmbeddingOptions,
BedrockEmbeddingSettings,
BedrockGuardrailConfig,
BedrockSettings,
)
@@ -10,6 +13,9 @@ from agent_framework_bedrock import (
__all__ = [
"BedrockChatClient",
"BedrockChatOptions",
"BedrockEmbeddingClient",
"BedrockEmbeddingOptions",
"BedrockEmbeddingSettings",
"BedrockGuardrailConfig",
"BedrockSettings",
]
@@ -99,6 +99,7 @@ class AzureOpenAIEmbeddingClient(
credential: AzureCredentialTypes | AzureTokenProvider | None = None,
default_headers: Mapping[str, str] | None = None,
async_client: AsyncAzureOpenAI | None = None,
otel_provider_name: str | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> None:
@@ -133,4 +134,5 @@ class AzureOpenAIEmbeddingClient(
credential=credential,
default_headers=default_headers,
client=async_client,
otel_provider_name=otel_provider_name,
)
@@ -1279,15 +1279,15 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
return _get_response()
EmbeddingOptionsCoT = TypeVar(
"EmbeddingOptionsCoT",
EmbeddingOptionsT = TypeVar(
"EmbeddingOptionsT",
bound=TypedDict, # type: ignore[valid-type]
default="EmbeddingGenerationOptions",
covariant=True,
)
class EmbeddingTelemetryLayer(Generic[EmbeddingInputT, EmbeddingT, EmbeddingOptionsCoT]):
class EmbeddingTelemetryLayer(Generic[EmbeddingInputT, EmbeddingT, EmbeddingOptionsT]):
"""Layer that wraps embedding client get_embeddings with OpenTelemetry tracing."""
def __init__(self, *args: Any, otel_provider_name: str | None = None, **kwargs: Any) -> None:
@@ -1301,7 +1301,7 @@ class EmbeddingTelemetryLayer(Generic[EmbeddingInputT, EmbeddingT, EmbeddingOpti
self,
values: Sequence[EmbeddingInputT],
*,
options: EmbeddingOptionsCoT | None = None,
options: EmbeddingOptionsT | None = None,
) -> GeneratedEmbeddings[EmbeddingT]:
"""Trace embedding generation with OpenTelemetry spans and metrics."""
global OBSERVABILITY_SETTINGS
@@ -7,6 +7,10 @@ This module lazily re-exports objects from:
Supported classes:
- OllamaChatClient
- OllamaChatOptions
- OllamaEmbeddingClient
- OllamaEmbeddingOptions
- OllamaEmbeddingSettings
- OllamaSettings
"""
@@ -15,7 +19,14 @@ from typing import Any
IMPORT_PATH = "agent_framework_ollama"
PACKAGE_NAME = "agent-framework-ollama"
_IMPORTS = ["OllamaChatClient", "OllamaSettings"]
_IMPORTS = [
"OllamaChatClient",
"OllamaChatOptions",
"OllamaEmbeddingClient",
"OllamaEmbeddingOptions",
"OllamaEmbeddingSettings",
"OllamaSettings",
]
def __getattr__(name: str) -> Any:
@@ -2,10 +2,18 @@
from agent_framework_ollama import (
OllamaChatClient,
OllamaChatOptions,
OllamaEmbeddingClient,
OllamaEmbeddingOptions,
OllamaEmbeddingSettings,
OllamaSettings,
)
__all__ = [
"OllamaChatClient",
"OllamaChatOptions",
"OllamaEmbeddingClient",
"OllamaEmbeddingOptions",
"OllamaEmbeddingSettings",
"OllamaSettings",
]
@@ -12,7 +12,7 @@ from openai import AsyncOpenAI
from .._clients import BaseEmbeddingClient
from .._settings import load_settings
from .._types import Embedding, EmbeddingGenerationOptions, GeneratedEmbeddings
from .._types import Embedding, EmbeddingGenerationOptions, GeneratedEmbeddings, UsageDetails
from ..observability import EmbeddingTelemetryLayer
from ._shared import OpenAIBase, OpenAIConfigMixin, OpenAISettings
@@ -116,11 +116,11 @@ class RawOpenAIEmbeddingClient(
)
)
usage_dict: dict[str, Any] | None = None
usage_dict: UsageDetails | None = None
if response.usage:
usage_dict = {
"prompt_tokens": response.usage.prompt_tokens,
"total_tokens": response.usage.total_tokens,
"input_token_count": response.usage.prompt_tokens,
"total_token_count": response.usage.total_tokens,
}
return GeneratedEmbeddings(embeddings, options=options, usage=usage_dict)
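Reviewer note: this hunk switches the usage payload from provider-native keys to the canonical `UsageDetails` keys. A small sketch of that mapping follows, for anyone updating downstream code that still reads `prompt_tokens`. `UsageSketch` and `normalize_usage` are hypothetical names used for illustration. The two key pairs are the ones visible in the diff, and unknown keys are passed through on the assumption that the dictionary stays open to provider-specific fields.

```python
from __future__ import annotations

from collections.abc import Mapping
from typing import Any, TypedDict


class UsageSketch(TypedDict, total=False):
    """Illustrative stand-in for the canonical usage keys."""

    input_token_count: int | None
    total_token_count: int | None


# Provider-native key -> canonical key, as changed in the hunk above.
_KEY_MAP = {
    "prompt_tokens": "input_token_count",
    "total_tokens": "total_token_count",
}


def normalize_usage(raw: Mapping[str, Any]) -> dict[str, Any]:
    """Rename known provider keys, drop None values, and keep unknown keys as-is."""
    usage: dict[str, Any] = {}
    for key, value in raw.items():
        if value is None:
            continue
        usage[_KEY_MAP.get(key, key)] = value
    return usage
```

Callers that asserted on `result.usage["prompt_tokens"]` need the same rename, as the test changes later in this diff show.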
@@ -143,6 +143,7 @@ class OpenAIEmbeddingClient(
default_headers: Additional HTTP headers.
async_client: Pre-configured AsyncOpenAI client.
base_url: Custom API base URL.
otel_provider_name: Override the OpenTelemetry provider name for telemetry.
env_file_path: Path to .env file for settings.
env_file_encoding: Encoding for .env file.
@@ -176,6 +177,7 @@ class OpenAIEmbeddingClient(
default_headers: Mapping[str, str] | None = None,
async_client: AsyncOpenAI | None = None,
base_url: str | None = None,
otel_provider_name: str | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> None:
@@ -208,4 +210,5 @@ class OpenAIEmbeddingClient(
org_id=openai_settings["org_id"],
default_headers=default_headers,
client=async_client,
otel_provider_name=otel_provider_name,
)
@@ -91,6 +91,9 @@ asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
filterwarnings = []
timeout = 120
markers = [
"integration: marks tests as integration tests that require external services",
]
[tool.coverage.run]
omit = [
@@ -124,6 +127,7 @@ exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework"
test = "pytest --cov=agent_framework --cov-report=term-missing:skip-covered -n auto --dist worksteal tests"
@@ -100,8 +100,8 @@ async def test_openai_get_embeddings_usage(openai_unit_test_env: None) -> None:
result = await client.get_embeddings(["test"])
assert result.usage is not None
assert result.usage["prompt_tokens"] == 10
assert result.usage["total_tokens"] == 10
assert result.usage["input_token_count"] == 10
assert result.usage["total_token_count"] == 10
async def test_openai_options_passthrough_dimensions(openai_unit_test_env: None) -> None:
@@ -284,7 +284,7 @@ async def test_integration_openai_get_embeddings() -> None:
assert all(isinstance(v, float) for v in result[0].vector)
assert result[0].model_id is not None
assert result.usage is not None
assert result.usage["prompt_tokens"] > 0
assert result.usage["input_token_count"] > 0
@skip_if_openai_integration_tests_disabled
@@ -327,7 +327,7 @@ async def test_integration_azure_openai_get_embeddings() -> None:
assert all(isinstance(v, float) for v in result[0].vector)
assert result[0].model_id is not None
assert result.usage is not None
assert result.usage["prompt_tokens"] > 0
assert result.usage["input_token_count"] > 0
@skip_if_azure_openai_integration_tests_disabled
@@ -39,9 +39,9 @@ environments = [
"sys_platform == 'win32'"
]
[tool.uv-dynamic-versioning]
fallback-version = "0.0.0"
[tool.pytest.ini_options]
testpaths = 'tests'
addopts = "-ra -q -r fEX"
@@ -51,6 +51,9 @@ filterwarnings = [
"ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*"
]
timeout = 120
markers = [
"integration: marks tests as integration tests that require external services",
]
[tool.ruff]
extend = "../../pyproject.toml"
@@ -88,6 +91,7 @@ exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_declarative"
test = "pytest --cov=agent_framework_declarative --cov-report=term-missing:skip-covered tests"
@@ -54,6 +54,9 @@ addopts = "-ra -q -r fEX"
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
filterwarnings = []
markers = [
"integration: marks tests as integration tests that require external services",
]
[tool.ruff]
extend = "../../pyproject.toml"
@@ -66,7 +69,6 @@ omit = [
[tool.pyright]
extends = "../../pyproject.toml"
[tool.mypy]
plugins = ['pydantic.mypy']
strict = true
@@ -89,6 +91,7 @@ exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_devui"
test = "pytest --cov=agent_framework_devui --cov-report=term-missing:skip-covered tests"
@@ -43,6 +43,7 @@ environments = [
[tool.uv-dynamic-versioning]
fallback-version = "0.0.0"
[tool.pytest.ini_options]
testpaths = 'tests'
pythonpath = ["tests/integration_tests"]
@@ -94,6 +95,7 @@ exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_durabletask"
test = "pytest --cov=agent_framework_durabletask --cov-report=term-missing:skip-covered tests"
@@ -37,6 +37,7 @@ environments = [
[tool.uv-dynamic-versioning]
fallback-version = "0.0.0"
[tool.pytest.ini_options]
testpaths = 'tests'
addopts = "-ra -q -r fEX"
@@ -44,6 +45,9 @@ asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
filterwarnings = []
timeout = 120
markers = [
"integration: marks tests as integration tests that require external services",
]
[tool.ruff]
extend = "../../pyproject.toml"
@@ -78,6 +82,7 @@ exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_foundry_local"
test = "pytest --cov=agent_framework_foundry_local --cov-report=term-missing:skip-covered tests"
@@ -37,6 +37,7 @@ environments = [
[tool.uv-dynamic-versioning]
fallback-version = "0.0.0"
[tool.pytest.ini_options]
testpaths = 'tests'
addopts = "-ra -q -r fEX"
@@ -46,6 +47,9 @@ filterwarnings = [
"ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*"
]
timeout = 120
markers = [
"integration: marks tests as integration tests that require external services",
]
[tool.ruff]
extend = "../../pyproject.toml"
@@ -58,7 +62,6 @@ omit = [
[tool.pyright]
extends = "../../pyproject.toml"
[tool.mypy]
plugins = ['pydantic.mypy']
strict = true
@@ -80,6 +83,7 @@ exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_github_copilot"
test = "pytest --cov=agent_framework_github_copilot --cov-report=term-missing:skip-covered tests"
@@ -37,6 +37,7 @@ environments = [
[tool.uv-dynamic-versioning]
fallback-version = "0.0.0"
[tool.pytest.ini_options]
testpaths = 'tests'
addopts = "-ra -q -r fEX"
@@ -46,6 +47,9 @@ filterwarnings = [
"ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*"
]
timeout = 120
markers = [
"integration: marks tests as integration tests that require external services",
]
[tool.ruff]
extend = "../../pyproject.toml"
@@ -58,7 +62,6 @@ omit = [
[tool.pyright]
extends = "../../pyproject.toml"
[tool.mypy]
plugins = ['pydantic.mypy']
strict = true
@@ -80,6 +83,7 @@ exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_mem0"
test = "pytest --cov=agent_framework_mem0 --cov-report=term-missing:skip-covered tests"
@@ -3,6 +3,7 @@
import importlib.metadata
from ._chat_client import OllamaChatClient, OllamaChatOptions, OllamaSettings
from ._embedding_client import OllamaEmbeddingClient, OllamaEmbeddingOptions, OllamaEmbeddingSettings
try:
__version__ = importlib.metadata.version(__name__)
@@ -12,6 +13,9 @@ except importlib.metadata.PackageNotFoundError:
__all__ = [
"OllamaChatClient",
"OllamaChatOptions",
"OllamaEmbeddingClient",
"OllamaEmbeddingOptions",
"OllamaEmbeddingSettings",
"OllamaSettings",
"__version__",
]
@@ -0,0 +1,230 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
import logging
import sys
from collections.abc import Sequence
from typing import Any, ClassVar, Generic, TypedDict
from agent_framework import (
BaseEmbeddingClient,
Embedding,
EmbeddingGenerationOptions,
GeneratedEmbeddings,
UsageDetails,
load_settings,
)
from agent_framework.observability import EmbeddingTelemetryLayer
from ollama import AsyncClient
if sys.version_info >= (3, 13):
from typing import TypeVar # type: ignore # pragma: no cover
else:
from typing_extensions import TypeVar # type: ignore # pragma: no cover
logger = logging.getLogger("agent_framework.ollama")
class OllamaEmbeddingOptions(EmbeddingGenerationOptions, total=False):
"""Ollama-specific embedding options.
Extends EmbeddingGenerationOptions with Ollama-specific fields.
Examples:
.. code-block:: python
from agent_framework_ollama import OllamaEmbeddingOptions
options: OllamaEmbeddingOptions = {
"model_id": "nomic-embed-text",
"dimensions": 768,
"truncate": True,
}
"""
truncate: bool
"""Whether to truncate input text that exceeds the model's context length.
When True, input that is too long will be silently truncated.
When False, the request fails if input exceeds the context length.
When omitted, the field is not sent and the Ollama server's own default applies.
"""
keep_alive: float | str
"""How long to keep the model loaded in memory (e.g. ``"5m"``, ``300``)."""
OllamaEmbeddingOptionsT = TypeVar(
"OllamaEmbeddingOptionsT",
bound=TypedDict, # type: ignore[valid-type]
default="OllamaEmbeddingOptions",
covariant=True,
)
class OllamaEmbeddingSettings(TypedDict, total=False):
"""Ollama embedding settings."""
host: str | None
embedding_model_id: str | None
class RawOllamaEmbeddingClient(
BaseEmbeddingClient[str, list[float], OllamaEmbeddingOptionsT],
Generic[OllamaEmbeddingOptionsT],
):
"""Raw Ollama embedding client without telemetry.
Keyword Args:
model_id: The Ollama embedding model ID (e.g. "nomic-embed-text").
Can also be set via environment variable OLLAMA_EMBEDDING_MODEL_ID.
host: Ollama server URL. Defaults to http://localhost:11434.
Can also be set via environment variable OLLAMA_HOST.
client: Optional pre-configured Ollama AsyncClient.
env_file_path: Path to .env file for settings.
env_file_encoding: Encoding for .env file.
"""
def __init__(
self,
*,
model_id: str | None = None,
host: str | None = None,
client: AsyncClient | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize a raw Ollama embedding client."""
ollama_settings = load_settings(
OllamaEmbeddingSettings,
env_prefix="OLLAMA_",
required_fields=["embedding_model_id"],
host=host,
embedding_model_id=model_id,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
self.model_id = ollama_settings["embedding_model_id"]
self.client = client or AsyncClient(host=ollama_settings.get("host"))
self.host = str(self.client._client.base_url) # pyright: ignore[reportUnknownMemberType,reportPrivateUsage,reportUnknownArgumentType]
super().__init__(**kwargs)
def service_url(self) -> str:
"""Get the URL of the service."""
return self.host
async def get_embeddings(
self,
values: Sequence[str],
*,
options: OllamaEmbeddingOptionsT | None = None,
) -> GeneratedEmbeddings[list[float]]:
"""Call the Ollama embed API.
Args:
values: The text values to generate embeddings for.
options: Optional embedding generation options.
Returns:
Generated embeddings with usage metadata.
Raises:
ValueError: If model_id is not provided. Empty ``values`` does not raise;
it returns an empty GeneratedEmbeddings without calling the API.
"""
if not values:
return GeneratedEmbeddings([], options=options)
opts: dict[str, Any] = dict(options) if options else {}
model = opts.get("model_id") or self.model_id
if not model:
raise ValueError("model_id is required")
kwargs: dict[str, Any] = {"model": model, "input": list(values)}
if (truncate := opts.get("truncate")) is not None:
kwargs["truncate"] = truncate
if (keep_alive := opts.get("keep_alive")) is not None:
kwargs["keep_alive"] = keep_alive
if dimensions := opts.get("dimensions"):
kwargs["dimensions"] = dimensions
response = await self.client.embed(**kwargs)
embeddings = [
Embedding(
vector=list(emb),
dimensions=len(emb),
model_id=response.get("model") or model,
)
for emb in response.get("embeddings", [])
]
usage_dict: UsageDetails | None = None
prompt_eval_count = response.get("prompt_eval_count")
if prompt_eval_count is not None:
usage_dict = {"input_token_count": prompt_eval_count}
return GeneratedEmbeddings(embeddings, options=options, usage=usage_dict)
class OllamaEmbeddingClient(
EmbeddingTelemetryLayer[str, list[float], OllamaEmbeddingOptionsT],
RawOllamaEmbeddingClient[OllamaEmbeddingOptionsT],
Generic[OllamaEmbeddingOptionsT],
):
"""Ollama embedding client with telemetry support.
Keyword Args:
model_id: The Ollama embedding model ID (e.g. "nomic-embed-text").
Can also be set via environment variable OLLAMA_EMBEDDING_MODEL_ID.
host: Ollama server URL. Defaults to http://localhost:11434.
Can also be set via environment variable OLLAMA_HOST.
client: Optional pre-configured Ollama AsyncClient.
env_file_path: Path to .env file for settings.
env_file_encoding: Encoding for .env file.
Examples:
.. code-block:: python
from agent_framework_ollama import OllamaEmbeddingClient
# Using environment variables
# Set OLLAMA_EMBEDDING_MODEL_ID=nomic-embed-text
client = OllamaEmbeddingClient()
# Or passing parameters directly
client = OllamaEmbeddingClient(
model_id="nomic-embed-text",
host="http://localhost:11434",
)
# Generate embeddings
result = await client.get_embeddings(["Hello, world!"])
print(result[0].vector)
"""
OTEL_PROVIDER_NAME: ClassVar[str] = "ollama"
def __init__(
self,
*,
model_id: str | None = None,
host: str | None = None,
client: AsyncClient | None = None,
otel_provider_name: str | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize an Ollama embedding client."""
super().__init__(
model_id=model_id,
host=host,
client=client,
otel_provider_name=otel_provider_name,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
**kwargs,
)
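Reviewer note: the option handling in `get_embeddings` can be sketched in isolation. The helper below is hypothetical (`build_embed_kwargs` is not part of the package) and only mirrors how the keyword arguments for `AsyncClient.embed` are assembled: the per-call `model_id` overrides the client default, and optional fields are forwarded only when set. As an assumption of this sketch, `keep_alive` is checked against `None` so that a value of `0` is still forwarded.

```python
from __future__ import annotations

from collections.abc import Mapping, Sequence
from typing import Any


def build_embed_kwargs(
    values: Sequence[str],
    options: Mapping[str, Any] | None,
    default_model: str | None,
) -> dict[str, Any]:
    """Hypothetical helper mirroring the kwargs assembly in get_embeddings."""
    opts = dict(options) if options else {}
    # A per-call model_id takes precedence over the client-level default.
    model = opts.get("model_id") or default_model
    if not model:
        raise ValueError("model_id is required")
    kwargs: dict[str, Any] = {"model": model, "input": list(values)}
    # Forward truncate whenever it is set, including an explicit False.
    if (truncate := opts.get("truncate")) is not None:
        kwargs["truncate"] = truncate
    # Assumption of this sketch: a None check keeps keep_alive=0 intact.
    if (keep_alive := opts.get("keep_alive")) is not None:
        kwargs["keep_alive"] = keep_alive
    # A dimensions value of 0 or None is treated as "not requested".
    if dimensions := opts.get("dimensions"):
        kwargs["dimensions"] = dimensions
    return kwargs
```

Calling `build_embed_kwargs(["hello"], {"truncate": True, "dimensions": 512}, "nomic-embed-text")` yields the same keyword set that the options unit test asserts on the mocked `embed` call.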
+5
@@ -37,12 +37,16 @@ environments = [
[tool.uv-dynamic-versioning]
fallback-version = "0.0.0"
[tool.pytest.ini_options]
testpaths = 'tests'
addopts = "-ra -q -r fEX"
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
filterwarnings = []
markers = [
"integration: marks tests as integration tests that require external services",
]
timeout = 120
[tool.ruff]
@@ -82,6 +86,7 @@ exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_ollama"
test = "pytest --cov=agent_framework_ollama --cov-report=term-missing:skip-covered tests"
@@ -0,0 +1,150 @@
# Copyright (c) Microsoft. All rights reserved.
import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agent_framework import Embedding, GeneratedEmbeddings
from agent_framework_ollama import OllamaEmbeddingClient, OllamaEmbeddingOptions
# region: Unit Tests
def test_ollama_embedding_construction(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test construction from environment variables."""
monkeypatch.setenv("OLLAMA_EMBEDDING_MODEL_ID", "nomic-embed-text")
with patch("agent_framework_ollama._embedding_client.AsyncClient") as mock_client_cls:
mock_client_cls.return_value = MagicMock()
client = OllamaEmbeddingClient()
assert client.model_id == "nomic-embed-text"
def test_ollama_embedding_construction_with_params() -> None:
"""Test construction with explicit parameters."""
with patch("agent_framework_ollama._embedding_client.AsyncClient") as mock_client_cls:
mock_client_cls.return_value = MagicMock()
client = OllamaEmbeddingClient(
model_id="nomic-embed-text",
host="http://localhost:11434",
)
assert client.model_id == "nomic-embed-text"
def test_ollama_embedding_construction_missing_model_raises(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that missing model_id raises an error."""
monkeypatch.delenv("OLLAMA_EMBEDDING_MODEL_ID", raising=False)
monkeypatch.delenv("OLLAMA_MODEL_ID", raising=False)
from agent_framework.exceptions import SettingNotFoundError
with pytest.raises(SettingNotFoundError):
OllamaEmbeddingClient()
async def test_ollama_embedding_get_embeddings() -> None:
"""Test generating embeddings via the Ollama API."""
mock_response = {
"model": "nomic-embed-text",
"embeddings": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
"prompt_eval_count": 10,
}
with patch("agent_framework_ollama._embedding_client.AsyncClient") as mock_client_cls:
mock_client = MagicMock()
mock_client.embed = AsyncMock(return_value=mock_response)
mock_client_cls.return_value = mock_client
client = OllamaEmbeddingClient(model_id="nomic-embed-text")
result = await client.get_embeddings(["hello", "world"])
assert isinstance(result, GeneratedEmbeddings)
assert len(result) == 2
assert result[0].vector == [0.1, 0.2, 0.3]
assert result[1].vector == [0.4, 0.5, 0.6]
assert result[0].model_id == "nomic-embed-text"
assert result.usage == {"input_token_count": 10}
mock_client.embed.assert_called_once_with(
model="nomic-embed-text",
input=["hello", "world"],
)
async def test_ollama_embedding_get_embeddings_empty_input() -> None:
"""Test generating embeddings with empty input."""
with patch("agent_framework_ollama._embedding_client.AsyncClient") as mock_client_cls:
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
client = OllamaEmbeddingClient(model_id="nomic-embed-text")
result = await client.get_embeddings([])
assert isinstance(result, GeneratedEmbeddings)
assert len(result) == 0
mock_client.embed.assert_not_called()
async def test_ollama_embedding_get_embeddings_with_options() -> None:
"""Test generating embeddings with custom options."""
mock_response = {
"model": "nomic-embed-text",
"embeddings": [[0.1, 0.2, 0.3]],
}
with patch("agent_framework_ollama._embedding_client.AsyncClient") as mock_client_cls:
mock_client = MagicMock()
mock_client.embed = AsyncMock(return_value=mock_response)
mock_client_cls.return_value = mock_client
client = OllamaEmbeddingClient(model_id="nomic-embed-text")
options: OllamaEmbeddingOptions = {
"truncate": True,
"dimensions": 512,
}
result = await client.get_embeddings(["hello"], options=options)
assert len(result) == 1
mock_client.embed.assert_called_once_with(
model="nomic-embed-text",
input=["hello"],
truncate=True,
dimensions=512,
)
async def test_ollama_embedding_get_embeddings_no_model_raises() -> None:
"""Test that missing model_id at call time raises ValueError."""
with patch("agent_framework_ollama._embedding_client.AsyncClient") as mock_client_cls:
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
client = OllamaEmbeddingClient(model_id="nomic-embed-text")
client.model_id = None # type: ignore[assignment]
with pytest.raises(ValueError, match="model_id is required"):
await client.get_embeddings(["hello"])
# region: Integration Tests
skip_if_ollama_embedding_integration_tests_disabled = pytest.mark.skipif(
os.getenv("OLLAMA_EMBEDDING_MODEL_ID", "") in ("", "test-model"),
reason="No real Ollama embedding model provided; skipping integration tests.",
)
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_ollama_embedding_integration_tests_disabled
async def test_ollama_embedding_integration() -> None:
"""Integration test for Ollama embedding client."""
client = OllamaEmbeddingClient()
result = await client.get_embeddings(["Hello, world!", "How are you?"])
assert isinstance(result, GeneratedEmbeddings)
assert len(result) == 2
for embedding in result:
assert isinstance(embedding, Embedding)
assert isinstance(embedding.vector, list)
assert len(embedding.vector) > 0
assert all(isinstance(v, float) for v in embedding.vector)
@@ -44,6 +44,9 @@ asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
filterwarnings = []
timeout = 120
markers = [
"integration: marks tests as integration tests that require external services",
]
[tool.ruff]
extend = "../../pyproject.toml"
@@ -78,6 +81,7 @@ exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_orchestrations"
test = "pytest --cov=agent_framework_orchestrations --cov-report=term-missing:skip-covered -n auto --dist worksteal tests"
+5 -1
@@ -39,12 +39,16 @@ environments = [
[tool.uv-dynamic-versioning]
fallback-version = "0.0.0"
[tool.pytest.ini_options]
testpaths = 'tests'
addopts = "-ra -q -r fEX"
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
filterwarnings = []
markers = [
"integration: marks tests as integration tests that require external services",
]
[tool.ruff]
extend = "../../pyproject.toml"
@@ -57,7 +61,6 @@ omit = [
[tool.pyright]
extends = "../../pyproject.toml"
[tool.mypy]
plugins = ['pydantic.mypy']
strict = true
@@ -79,6 +82,7 @@ exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_purview"
test = "pytest --cov=agent_framework_purview --cov-report=term-missing:skip-covered tests"
+5
@@ -39,6 +39,7 @@ environments = [
[tool.uv-dynamic-versioning]
fallback-version = "0.0.0"
[tool.pytest.ini_options]
testpaths = 'tests'
addopts = "-ra -q -r fEX"
@@ -48,6 +49,9 @@ filterwarnings = [
"ignore:Support for class-based `config` is deprecated:DeprecationWarning:pydantic.*"
]
timeout = 120
markers = [
"integration: marks tests as integration tests that require external services",
]
[tool.ruff]
extend = "../../pyproject.toml"
@@ -81,6 +85,7 @@ exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_redis"
test = "pytest --cov=agent_framework_redis --cov-report=term-missing:skip-covered tests"
@@ -0,0 +1,87 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "agent-framework-azure-ai",
# ]
# ///
# Run with: uv run samples/02-agents/embeddings/azure_ai_inference_embeddings.py
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import pathlib
from agent_framework import Content
from agent_framework_azure_ai import AzureAIInferenceEmbeddingClient
from dotenv import load_dotenv
load_dotenv()
"""Azure AI Inference Image Embedding Example
This sample demonstrates how to generate image embeddings using the
Azure AI Inference embedding client with the Cohere-embed-v3-english model.
Images are passed as ``Content`` objects created with ``Content.from_data()``.
Prerequisites:
Set the following environment variables or add them to a .env file:
- AZURE_AI_INFERENCE_ENDPOINT: Your Azure AI model inference endpoint URL
- AZURE_AI_INFERENCE_API_KEY: Your API key
- AZURE_AI_INFERENCE_EMBEDDING_MODEL_ID: The text embedding model name
(e.g. "text-embedding-3-small")
- AZURE_AI_INFERENCE_IMAGE_EMBEDDING_MODEL_ID: The image embedding model name
(e.g. "Cohere-embed-v3-english")
"""
SAMPLE_IMAGE_PATH = pathlib.Path(__file__).parent.parent.parent / "shared" / "sample_assets" / "sample_image.jpg"
async def main() -> None:
"""Generate image embeddings with Azure AI Inference."""
async with AzureAIInferenceEmbeddingClient() as client:
# 1. Generate an image embedding.
image_bytes = SAMPLE_IMAGE_PATH.read_bytes()
image_content = Content.from_data(data=image_bytes, media_type="image/jpeg")
result = await client.get_embeddings([image_content])
print(f"Image embedding dimensions: {result[0].dimensions}")
print(f"First 5 values: {result[0].vector[:5]}")
print(f"Model: {result[0].model_id}")
print(f"Usage: {result.usage}")
print()
# 2. Generate text and image embeddings in a single call.
# The client dispatches text to the text endpoint and images to the image
# endpoint, then reassembles results in the original input order.
result = await client.get_embeddings(["A half-timbered house in a forested valley", image_content])
print(f"Text embedding dimensions: {result[0].dimensions}")
print(f"First 5 values: {result[0].vector[:5]}")
print(f"Image embedding dimensions: {result[1].dimensions}")
print(f"First 5 values: {result[1].vector[:5]}")
print()
# 3. Generate image embeddings with input_type option.
result = await client.get_embeddings(
[image_content],
options={"input_type": "document"},
)
print(f"Document embedding dimensions: {result[0].dimensions}")
print(f"First 5 values: {result[0].vector[:5]}")
if __name__ == "__main__":
asyncio.run(main())
"""
Sample output (using Cohere-embed-v3-english):
Image embedding dimensions: 1024
First 5 values: [0.023, -0.045, 0.067, -0.089, 0.011]
Model: Cohere-embed-v3-english
Usage: {'prompt_tokens': 1, 'total_tokens': 1}
Image+text (separate) results:
Text embedding dimensions: 1536
Image embedding dimensions: 1024
Document embedding dimensions: 1024
"""
Binary file not shown. Size: 178 KiB
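Reviewer note: step 2 of the sample says the client sends text to the text endpoint and images to the image endpoint, then reassembles results in the original input order. The Azure client code itself is not shown in this part of the diff, so the following is only a minimal sketch of that dispatch-and-reassemble pattern. The function name `embed_mixed` and the two callables are hypothetical stand-ins, and `bytes` stands in for image `Content` objects.

```python
from __future__ import annotations

from collections.abc import Callable, Sequence


def embed_mixed(
    values: Sequence[str | bytes],
    embed_text: Callable[[list[str]], list[list[float]]],
    embed_image: Callable[[list[bytes]], list[list[float]]],
) -> list[list[float]]:
    """Dispatch inputs by kind, then restore the original input order."""
    # Remember the original position of every input, grouped by kind.
    text_idx = [i for i, v in enumerate(values) if isinstance(v, str)]
    image_idx = [i for i, v in enumerate(values) if not isinstance(v, str)]

    results: list[list[float] | None] = [None] * len(values)
    # One batched call per kind; each vector goes back to its original slot.
    if text_idx:
        for i, vec in zip(text_idx, embed_text([values[i] for i in text_idx])):
            results[i] = vec
    if image_idx:
        for i, vec in zip(image_idx, embed_image([values[i] for i in image_idx])):
            results[i] = vec
    # Drop slots a backend did not fill, so callers never see None.
    return [vec for vec in results if vec is not None]


# Usage with fake embedders that make the ordering easy to see.
def fake_text(items: list[str]) -> list[list[float]]:
    return [[float(len(item))] for item in items]


def fake_image(items: list[bytes]) -> list[list[float]]:
    return [[-1.0] for _ in items]


mixed = embed_mixed(["ab", b"img", "cde"], fake_text, fake_image)
```

Batching per kind keeps the number of backend calls at two at most, regardless of how text and image inputs are interleaved.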

+16
@@ -209,6 +209,7 @@ dependencies = [
{ name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "aiohttp", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "azure-ai-agents", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "azure-ai-inference", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
]
[package.metadata]
@@ -216,6 +217,7 @@ requires-dist = [
{ name = "agent-framework-core", editable = "packages/core" },
{ name = "aiohttp" },
{ name = "azure-ai-agents", specifier = "==1.2.0b5" },
{ name = "azure-ai-inference", specifier = ">=1.0.0b9" },
]
[[package]]
@@ -979,6 +981,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/6d/6d/15070d23d7a94833a210da09d5d7ed3c24838bb84f0463895e5d159f1695/azure_ai_agents-1.2.0b5-py3-none-any.whl", hash = "sha256:257d0d24a6bf13eed4819cfa5c12fb222e5908deafb3cbfd5711d3a511cc4e88", size = 217948, upload-time = "2025-09-30T01:55:04.155Z" },
]
[[package]]
name = "azure-ai-inference"
version = "1.0.0b9"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "azure-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "isodate", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/4e/6a/ed85592e5c64e08c291992f58b1a94dab6869f28fb0f40fd753dced73ba6/azure_ai_inference-1.0.0b9.tar.gz", hash = "sha256:1feb496bd84b01ee2691befc04358fa25d7c344d8288e99364438859ad7cd5a4", size = 182408, upload-time = "2025-02-15T00:37:28.464Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/4f/0f/27520da74769db6e58327d96c98e7b9a07ce686dff582c9a5ec60b03f9dd/azure_ai_inference-1.0.0b9-py3-none-any.whl", hash = "sha256:49823732e674092dad83bb8b0d1b65aa73111fab924d61349eb2a8cdc0493990", size = 124885, upload-time = "2025-02-15T00:37:29.964Z" },
]
[[package]]
name = "azure-ai-projects"
version = "2.0.0b3"