mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: [BREAKING] Python: move Azure AI embeddings to Foundry (#5056)
* renamed AzureAIInferenceEmbeddings and lazy load azure-cosmos and env var rename * updated coverage * fix readme
This commit is contained in:
committed by
GitHub
Unverified
parent
47d82911c0
commit
95fd5ec658
@@ -1,3 +1,3 @@
|
||||
# Agent Framework Foundry
|
||||
|
||||
This package contains the cloud Azure AI Foundry integrations for Microsoft Agent Framework, including Foundry chat clients, preconfigured Foundry agents, and Foundry memory providers.
|
||||
This package contains the Microsoft Foundry integrations for Microsoft Agent Framework, including Foundry chat clients, preconfigured Foundry agents, Foundry embedding clients, and Foundry memory providers.
|
||||
|
||||
@@ -4,6 +4,12 @@ import importlib.metadata
|
||||
|
||||
from ._agent import FoundryAgent, RawFoundryAgent, RawFoundryAgentChatClient
|
||||
from ._chat_client import FoundryChatClient, FoundryChatOptions, RawFoundryChatClient
|
||||
from ._embedding_client import (
|
||||
FoundryEmbeddingClient,
|
||||
FoundryEmbeddingOptions,
|
||||
FoundryEmbeddingSettings,
|
||||
RawFoundryEmbeddingClient,
|
||||
)
|
||||
from ._foundry_evals import (
|
||||
FoundryEvals,
|
||||
evaluate_foundry_target,
|
||||
@@ -20,11 +26,15 @@ __all__ = [
|
||||
"FoundryAgent",
|
||||
"FoundryChatClient",
|
||||
"FoundryChatOptions",
|
||||
"FoundryEmbeddingClient",
|
||||
"FoundryEmbeddingOptions",
|
||||
"FoundryEmbeddingSettings",
|
||||
"FoundryEvals",
|
||||
"FoundryMemoryProvider",
|
||||
"RawFoundryAgent",
|
||||
"RawFoundryAgentChatClient",
|
||||
"RawFoundryChatClient",
|
||||
"RawFoundryEmbeddingClient",
|
||||
"__version__",
|
||||
"evaluate_foundry_target",
|
||||
"evaluate_traces",
|
||||
|
||||
@@ -0,0 +1,396 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Sequence
|
||||
from contextlib import suppress
|
||||
from typing import Any, ClassVar, Generic, TypedDict
|
||||
|
||||
from agent_framework import (
|
||||
BaseEmbeddingClient,
|
||||
Content,
|
||||
Embedding,
|
||||
EmbeddingGenerationOptions,
|
||||
GeneratedEmbeddings,
|
||||
UsageDetails,
|
||||
load_settings,
|
||||
)
|
||||
from agent_framework.observability import EmbeddingTelemetryLayer
|
||||
from azure.ai.inference.aio import EmbeddingsClient, ImageEmbeddingsClient
|
||||
from azure.ai.inference.models import ImageEmbeddingInput
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypeVar # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
# Module logger, registered under the "agent_framework.foundry" name.
logger = logging.getLogger("agent_framework.foundry")

# Media-type prefixes that identify a ``Content`` item as an image input.
_IMAGE_MEDIA_PREFIXES = ("image/",)
|
||||
|
||||
|
||||
class FoundryEmbeddingOptions(EmbeddingGenerationOptions, total=False):
    """Embedding options accepted by the Foundry inference embedding clients.

    Adds Foundry inference-specific keys on top of
    ``EmbeddingGenerationOptions``. The keys declared here are all optional
    (``total=False``).

    Examples:
        .. code-block:: python

            from agent_framework_foundry import FoundryEmbeddingOptions

            options: FoundryEmbeddingOptions = {
                "model": "text-embedding-3-small",
                "dimensions": 1536,
                "input_type": "document",
                "encoding_format": "float",
            }
    """

    input_type: str
    """Hint describing the kind of input. Common values: ``"text"``, ``"query"``, ``"document"``."""

    image_model: str
    """Model to use for image embeddings; when absent the client's ``image_model`` is used."""

    encoding_format: str
    """Encoding of the returned embeddings.

    Common values: ``"float"``, ``"base64"``, ``"int8"``, ``"uint8"``,
    ``"binary"``, ``"ubinary"``.
    """

    extra_parameters: dict[str, Any]
    """Extra model-specific parameters forwarded to the API as-is."""
|
||||
|
||||
|
||||
# Covariant type variable for the options TypedDict a client is parameterized
# with; it defaults to ``FoundryEmbeddingOptions`` when not specified.
FoundryEmbeddingOptionsT = TypeVar(
    "FoundryEmbeddingOptionsT",
    bound=TypedDict,  # type: ignore[valid-type]
    default="FoundryEmbeddingOptions",
    covariant=True,
)
|
||||
|
||||
|
||||
class FoundryEmbeddingSettings(TypedDict, total=False):
    """Settings consumed by the Foundry inference embedding clients.

    Every key is optional (``total=False``) and may be ``None``.
    """

    # Foundry inference endpoint URL.
    models_endpoint: str | None
    # API key used to authenticate against the endpoint.
    models_api_key: str | None
    # Model used for text embeddings.
    embedding_model: str | None
    # Model used for image embeddings.
    image_embedding_model: str | None
|
||||
|
||||
|
||||
class RawFoundryEmbeddingClient(
    BaseEmbeddingClient[Content | str, list[float], FoundryEmbeddingOptionsT],
    Generic[FoundryEmbeddingOptionsT],
):
    """Raw Foundry embedding client without telemetry.

    Accepts both text (``str``) and image (``Content``) inputs. Text and image
    inputs within a single batch are separated and dispatched to
    ``EmbeddingsClient`` and ``ImageEmbeddingsClient`` respectively. Results
    are reassembled in the original input order.

    Keyword Args:
        model: The text embedding model (e.g. "text-embedding-3-small").
            Can also be set via environment variable FOUNDRY_EMBEDDING_MODEL.
        image_model: The image embedding model (e.g. "Cohere-embed-v3-english").
            Can also be set via environment variable FOUNDRY_IMAGE_EMBEDDING_MODEL.
            Falls back to ``model`` if not provided.
        endpoint: The Foundry inference endpoint URL.
            Can also be set via environment variable FOUNDRY_MODELS_ENDPOINT.
        api_key: API key for authentication.
            Can also be set via environment variable FOUNDRY_MODELS_API_KEY.
        text_client: Optional pre-configured ``EmbeddingsClient``.
        image_client: Optional pre-configured ``ImageEmbeddingsClient``.
        credential: Optional ``AzureKeyCredential`` or token credential. If not provided,
            one is created from ``api_key``.
        additional_properties: Optional properties forwarded to the base client.
        env_file_path: Path to .env file for settings.
        env_file_encoding: Encoding for .env file.
    """

    def __init__(
        self,
        *,
        model: str | None = None,
        image_model: str | None = None,
        endpoint: str | None = None,
        api_key: str | None = None,
        text_client: EmbeddingsClient | None = None,
        image_client: ImageEmbeddingsClient | None = None,
        credential: AzureKeyCredential | None = None,
        additional_properties: dict[str, Any] | None = None,
        env_file_path: str | None = None,
        env_file_encoding: str | None = None,
    ) -> None:
        """Initialize a raw Foundry embedding client.

        Raises:
            ValueError: If neither an API key, a credential, nor any
                pre-configured client is available.
        """
        settings = load_settings(
            FoundryEmbeddingSettings,
            env_prefix="FOUNDRY_",
            required_fields=["models_endpoint", "embedding_model"],
            models_endpoint=endpoint,
            models_api_key=api_key,
            embedding_model=model,
            image_embedding_model=image_model,
            env_file_path=env_file_path,
            env_file_encoding=env_file_encoding,
        )

        self.model = settings["embedding_model"]  # type: ignore[reportTypedDictNotRequiredAccess]
        # The image model falls back to the text model when not configured.
        self.image_model: str = settings.get("image_embedding_model") or self.model  # type: ignore[assignment]
        resolved_endpoint = settings["models_endpoint"]  # type: ignore[reportTypedDictNotRequiredAccess]

        # An explicit credential wins; otherwise derive one from the API key.
        if credential is None and settings.get("models_api_key"):
            credential = AzureKeyCredential(settings["models_api_key"])  # type: ignore[arg-type]

        if credential is None and text_client is None and image_client is None:
            raise ValueError("Either 'api_key', 'credential', or pre-configured client(s) must be provided.")

        # NOTE(review): when only one pre-configured client is supplied and no
        # credential is available, the other client below is still constructed
        # with ``credential=None`` -- confirm how the SDK handles that.
        self._text_client = text_client or EmbeddingsClient(
            endpoint=resolved_endpoint,  # type: ignore[arg-type]
            credential=credential,  # type: ignore[arg-type]
        )
        self._image_client = image_client or ImageEmbeddingsClient(
            endpoint=resolved_endpoint,  # type: ignore[arg-type]
            credential=credential,  # type: ignore[arg-type]
        )
        self._endpoint = resolved_endpoint
        super().__init__(additional_properties=additional_properties)

    async def close(self) -> None:
        """Close the underlying SDK clients and release resources.

        Closing is best-effort: an exception raised while closing one client
        is suppressed so the other client is still closed.
        """
        with suppress(Exception):
            await self._text_client.close()
        with suppress(Exception):
            await self._image_client.close()

    async def __aenter__(self) -> RawFoundryEmbeddingClient[FoundryEmbeddingOptionsT]:
        """Enter the async context manager."""
        return self

    async def __aexit__(self, *args: Any) -> None:
        """Exit the async context manager and close clients."""
        await self.close()

    def service_url(self) -> str:
        """Get the URL of the service."""
        return self._endpoint or ""

    @staticmethod
    def _partition_inputs(
        values: Sequence[Content | str],
    ) -> tuple[list[tuple[int, str]], list[tuple[int, ImageEmbeddingInput]]]:
        """Split inputs into text and image items, each paired with its original index.

        Raises:
            ValueError: If an input is neither text nor image content, or an
                image ``Content`` has no URI.
        """
        text_items: list[tuple[int, str]] = []
        image_items: list[tuple[int, ImageEmbeddingInput]] = []

        for idx, value in enumerate(values):
            if isinstance(value, str):
                text_items.append((idx, value))
                continue
            if not isinstance(value, Content):
                raise ValueError(f"Unsupported input type {type(value).__name__} at index {idx}.")

            if value.type == "text" and value.text is not None:
                text_items.append((idx, value.text))
            elif (
                value.type in ("data", "uri")
                and value.media_type
                # Pass the whole tuple so every configured prefix is honored.
                and value.media_type.startswith(_IMAGE_MEDIA_PREFIXES)
            ):
                if not value.uri:
                    raise ValueError(f"Image Content at index {idx} has no URI.")
                image_items.append((idx, ImageEmbeddingInput(image=value.uri, text=value.text)))
            else:
                raise ValueError(
                    f"Unsupported Content type '{value.type}' with media_type "
                    f"'{value.media_type}' at index {idx}. Expected text content or "
                    f"image content (media_type starting with 'image/')."
                )

        return text_items, image_items

    @staticmethod
    async def _embed_into(
        sdk_client: Any,
        items: Sequence[tuple[int, Any]],
        model: str,
        api_kwargs: dict[str, Any],
        embeddings: list[Embedding[list[float]] | None],
        usage_details: UsageDetails,
    ) -> None:
        """Embed one homogeneous batch, storing vectors at their original indices and adding up usage."""
        response = await sdk_client.embed(
            input=[payload for _, payload in items],
            model=model,
            **api_kwargs,
        )
        # Pair each returned item with the input at the same position. zip()
        # stops at the shorter side, so a response carrying more items than
        # were requested cannot index past the batch.
        for (original_idx, _), item in zip(items, response.data, strict=False):
            vector: list[float] = [float(v) for v in item.embedding]
            embeddings[original_idx] = Embedding(
                vector=vector,
                dimensions=len(vector),
                model=response.model or model,
            )
        if response.usage:
            usage_details["input_token_count"] = (usage_details.get("input_token_count") or 0) + (
                response.usage.prompt_tokens or 0
            )
            usage_details["output_token_count"] = (usage_details.get("output_token_count") or 0) + (
                getattr(response.usage, "completion_tokens", 0) or 0
            )

    async def get_embeddings(
        self,
        values: Sequence[Content | str],
        *,
        options: FoundryEmbeddingOptionsT | None = None,
    ) -> GeneratedEmbeddings[list[float], FoundryEmbeddingOptionsT]:
        """Generate embeddings for text and/or image inputs.

        Text inputs (``str`` or ``Content`` with ``type="text"``) are sent to the
        text embeddings endpoint. Image inputs (``Content`` with an image
        ``media_type``) are sent to the image embeddings endpoint. Results are
        returned in the same order as the input; inputs for which the service
        returned no embedding are omitted from the result.

        Args:
            values: A sequence of text strings or ``Content`` instances.
            options: Optional embedding generation options.

        Returns:
            Generated embeddings with usage metadata.

        Raises:
            ValueError: If model is not provided or an unsupported content type is encountered.
        """
        if not values:
            return GeneratedEmbeddings([], options=options)  # type: ignore[reportReturnType]

        opts: dict[str, Any] = dict(options) if options else {}

        # Separate text and image inputs, tracking original indices.
        text_items, image_items = self._partition_inputs(values)

        # Build shared API kwargs (without model, which differs per client).
        # Each pair maps an option key to the SDK keyword it is sent as.
        common_kwargs: dict[str, Any] = {}
        for option_key, sdk_keyword in (
            ("dimensions", "dimensions"),
            ("encoding_format", "encoding_format"),
            ("input_type", "input_type"),
            ("extra_parameters", "model_extras"),
        ):
            if option_value := opts.get(option_key):
                common_kwargs[sdk_keyword] = option_value

        # One slot per input so results can be written back by original index.
        embeddings: list[Embedding[list[float]] | None] = [None] * len(values)
        usage_details: UsageDetails = {}

        # Embed text inputs.
        if text_items:
            if not (text_model := opts.get("model") or self.model):
                raise ValueError("A model is required, either in the client or options, for text inputs.")
            await self._embed_into(
                self._text_client, text_items, text_model, common_kwargs, embeddings, usage_details
            )

        # Embed image inputs.
        if image_items:
            if not (image_model := opts.get("image_model") or self.image_model):
                raise ValueError("An image_model is required, either in the client or options, for image inputs.")
            await self._embed_into(
                self._image_client, image_items, image_model, common_kwargs, embeddings, usage_details
            )

        return GeneratedEmbeddings(
            [embedding for embedding in embeddings if embedding is not None],
            options=options,
            usage=usage_details,
        )  # type: ignore[reportReturnType]
|
||||
|
||||
|
||||
class FoundryEmbeddingClient(
    EmbeddingTelemetryLayer[Content | str, list[float], FoundryEmbeddingOptionsT],
    RawFoundryEmbeddingClient[FoundryEmbeddingOptionsT],
    Generic[FoundryEmbeddingOptionsT],
):
    """Foundry embedding client with telemetry support.

    A single client handles text and image inputs. Inputs may be plain
    strings or ``Content`` instances built with ``Content.from_text()`` or
    ``Content.from_data()``.

    Keyword Args:
        model: The text embedding model (e.g. "text-embedding-3-small").
            Can also be set via environment variable FOUNDRY_EMBEDDING_MODEL.
        image_model: The image embedding model
            (e.g. "Cohere-embed-v3-english"). Can also be set via environment variable
            FOUNDRY_IMAGE_EMBEDDING_MODEL. Falls back to ``model``.
        endpoint: The Foundry inference endpoint URL.
            Can also be set via environment variable FOUNDRY_MODELS_ENDPOINT.
        api_key: API key for authentication.
            Can also be set via environment variable FOUNDRY_MODELS_API_KEY.
        text_client: Optional pre-configured ``EmbeddingsClient``.
        image_client: Optional pre-configured ``ImageEmbeddingsClient``.
        credential: Optional ``AzureKeyCredential`` or token credential.
        otel_provider_name: Override for the OpenTelemetry provider name.
        additional_properties: Optional properties forwarded to the parent initializer.
        env_file_path: Path to .env file for settings.
        env_file_encoding: Encoding for .env file.

    Examples:
        .. code-block:: python

            from agent_framework_foundry import FoundryEmbeddingClient

            # Using environment variables
            # Set FOUNDRY_MODELS_ENDPOINT=https://your-endpoint.inference.ai.azure.com
            # Set FOUNDRY_MODELS_API_KEY=your-key
            # Set FOUNDRY_EMBEDDING_MODEL=text-embedding-3-small
            # Set FOUNDRY_IMAGE_EMBEDDING_MODEL=Cohere-embed-v3-english
            client = FoundryEmbeddingClient()

            # Text embeddings
            result = await client.get_embeddings(["Hello, world!"])

            # Image embeddings
            from agent_framework import Content

            image = Content.from_data(data=image_bytes, media_type="image/png")
            result = await client.get_embeddings([image])

            # Mixed text and image
            result = await client.get_embeddings(["hello", image])
    """

    OTEL_PROVIDER_NAME: ClassVar[str] = "azure.ai.inference"  # type: ignore[reportIncompatibleVariableOverride, misc]

    def __init__(
        self,
        *,
        model: str | None = None,
        image_model: str | None = None,
        endpoint: str | None = None,
        api_key: str | None = None,
        text_client: EmbeddingsClient | None = None,
        image_client: ImageEmbeddingsClient | None = None,
        credential: AzureKeyCredential | None = None,
        otel_provider_name: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        env_file_path: str | None = None,
        env_file_encoding: str | None = None,
    ) -> None:
        """Initialize a Foundry embedding client.

        Every argument is forwarded unchanged to the parent initializers.
        """
        super().__init__(
            model=model,
            image_model=image_model,
            endpoint=endpoint,
            api_key=api_key,
            text_client=text_client,
            image_client=image_client,
            credential=credential,
            additional_properties=additional_properties,
            otel_provider_name=otel_provider_name,
            env_file_path=env_file_path,
            env_file_encoding=env_file_encoding,
        )
|
||||
@@ -25,6 +25,7 @@ classifiers = [
|
||||
dependencies = [
|
||||
"agent-framework-core>=1.0.0rc6",
|
||||
"agent-framework-openai>=1.0.0rc6",
|
||||
"azure-ai-inference>=1.0.0b9,<1.0.0b10",
|
||||
"azure-ai-projects>=2.0.0,<3.0",
|
||||
]
|
||||
|
||||
|
||||
@@ -0,0 +1,314 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from agent_framework import Content
|
||||
|
||||
from agent_framework_foundry import (
|
||||
FoundryEmbeddingClient,
|
||||
FoundryEmbeddingOptions,
|
||||
RawFoundryEmbeddingClient,
|
||||
)
|
||||
|
||||
|
||||
def _make_embed_response(
    embeddings: Sequence[list[float]],
    model: str = "test-model",
    prompt_tokens: int = 10,
) -> MagicMock:
    """Build a mock embeddings response exposing ``data``, ``model`` and ``usage``.

    Each entry of ``embeddings`` becomes one mock item whose ``embedding``
    attribute holds that vector. ``usage.completion_tokens`` is always 0.
    """
    response = MagicMock()
    response.data = [MagicMock(embedding=vector) for vector in embeddings]
    response.model = model
    response.usage = MagicMock(prompt_tokens=prompt_tokens, completion_tokens=0)
    return response
|
||||
|
||||
|
||||
@pytest.fixture
def mock_text_client() -> AsyncMock:
    """Provide a mock text ``EmbeddingsClient`` whose ``embed`` yields one 3-dim vector."""
    text_client = AsyncMock()
    canned_response = _make_embed_response([[0.1, 0.2, 0.3]])
    text_client.embed = AsyncMock(return_value=canned_response)
    return text_client
|
||||
|
||||
|
||||
@pytest.fixture
def mock_image_client() -> AsyncMock:
    """Provide a mock ``ImageEmbeddingsClient`` whose ``embed`` yields one 3-dim vector."""
    image_client = AsyncMock()
    canned_response = _make_embed_response([[0.4, 0.5, 0.6]])
    image_client.embed = AsyncMock(return_value=canned_response)
    return image_client
|
||||
|
||||
|
||||
@pytest.fixture
def raw_client(mock_text_client: AsyncMock, mock_image_client: AsyncMock) -> RawFoundryEmbeddingClient[Any]:
    """Build a ``RawFoundryEmbeddingClient`` wired to the mocked SDK clients."""
    init_kwargs: dict[str, Any] = {
        "model": "test-model",
        "endpoint": "https://test.inference.ai.azure.com",
        "api_key": "test-key",
        "text_client": mock_text_client,
        "image_client": mock_image_client,
    }
    return RawFoundryEmbeddingClient(**init_kwargs)
|
||||
|
||||
|
||||
@pytest.fixture
def client(mock_text_client: AsyncMock, mock_image_client: AsyncMock) -> FoundryEmbeddingClient[Any]:
    """Build a telemetry-enabled ``FoundryEmbeddingClient`` wired to the mocked SDK clients."""
    init_kwargs: dict[str, Any] = {
        "model": "test-model",
        "endpoint": "https://test.inference.ai.azure.com",
        "api_key": "test-key",
        "text_client": mock_text_client,
        "image_client": mock_image_client,
    }
    return FoundryEmbeddingClient(**init_kwargs)
|
||||
|
||||
|
||||
class TestRawFoundryEmbeddingClient:
    """Tests for the raw Foundry embedding client."""

    async def test_text_embeddings(
        self, raw_client: RawFoundryEmbeddingClient[Any], mock_text_client: AsyncMock
    ) -> None:
        """Text inputs are dispatched to the text client and returned in order."""
        # Return one vector per input so the result can be checked item by item.
        mock_text_client.embed.return_value = _make_embed_response([[0.1, 0.2], [0.3, 0.4]])

        result = await raw_client.get_embeddings(["hello", "world"])

        assert len(result) == 2
        assert result[0].vector == [0.1, 0.2]
        assert result[1].vector == [0.3, 0.4]
        call_kwargs = mock_text_client.embed.call_args
        assert call_kwargs.kwargs["input"] == ["hello", "world"]
        assert call_kwargs.kwargs["model"] == "test-model"

    async def test_text_content_embeddings(
        self, raw_client: RawFoundryEmbeddingClient[Any], mock_text_client: AsyncMock
    ) -> None:
        """Content.from_text() inputs are dispatched to the text client."""
        text_content = Content.from_text("hello")
        await raw_client.get_embeddings([text_content])

        mock_text_client.embed.assert_called_once()
        call_kwargs = mock_text_client.embed.call_args
        assert call_kwargs.kwargs["input"] == ["hello"]

    async def test_image_content_embeddings(
        self, raw_client: RawFoundryEmbeddingClient[Any], mock_image_client: AsyncMock
    ) -> None:
        """Image Content inputs are dispatched to the image client."""
        image_content = Content.from_data(data=b"\x89PNG", media_type="image/png")
        await raw_client.get_embeddings([image_content])

        mock_image_client.embed.assert_called_once()
        call_kwargs = mock_image_client.embed.call_args
        image_inputs = call_kwargs.kwargs["input"]
        assert len(image_inputs) == 1
        assert image_inputs[0].image == image_content.uri

    async def test_mixed_text_and_image(
        self,
        raw_client: RawFoundryEmbeddingClient[Any],
        mock_text_client: AsyncMock,
        mock_image_client: AsyncMock,
    ) -> None:
        """Mixed inputs go to the correct clients and come back in input order."""
        # Two text vectors and one image vector, matching the three inputs below.
        mock_text_client.embed.return_value = _make_embed_response([[0.1, 0.2], [0.5, 0.6]])
        mock_image_client.embed.return_value = _make_embed_response([[0.3, 0.4]])

        image = Content.from_data(data=b"\x89PNG", media_type="image/png")
        result = await raw_client.get_embeddings(["hello", image, "world"])

        # Text client gets "hello" and "world"
        text_call = mock_text_client.embed.call_args
        assert text_call.kwargs["input"] == ["hello", "world"]

        # Image client gets the image
        image_call = mock_image_client.embed.call_args
        assert len(image_call.kwargs["input"]) == 1

        # Results are interleaved back into the original input order.
        assert len(result) == 3
        assert result[0].vector == [0.1, 0.2]
        assert result[1].vector == [0.3, 0.4]
        assert result[2].vector == [0.5, 0.6]

    async def test_empty_input(self, raw_client: RawFoundryEmbeddingClient[Any]) -> None:
        """Empty input returns empty result."""
        result = await raw_client.get_embeddings([])
        assert len(result) == 0

    async def test_options_passed_through(
        self, raw_client: RawFoundryEmbeddingClient[Any], mock_text_client: AsyncMock
    ) -> None:
        """Options are passed through to the SDK."""
        options: FoundryEmbeddingOptions = {
            "dimensions": 512,
            "input_type": "document",
            "encoding_format": "float",
        }
        await raw_client.get_embeddings(["hello"], options=options)

        call_kwargs = mock_text_client.embed.call_args
        assert call_kwargs.kwargs["dimensions"] == 512
        assert call_kwargs.kwargs["input_type"] == "document"
        assert call_kwargs.kwargs["encoding_format"] == "float"

    async def test_model_override_in_options(
        self, raw_client: RawFoundryEmbeddingClient[Any], mock_text_client: AsyncMock
    ) -> None:
        """model in options overrides the default."""
        options: FoundryEmbeddingOptions = {"model": "custom-model"}
        await raw_client.get_embeddings(["hello"], options=options)

        call_kwargs = mock_text_client.embed.call_args
        assert call_kwargs.kwargs["model"] == "custom-model"

    async def test_unsupported_content_type_raises(self, raw_client: RawFoundryEmbeddingClient[Any]) -> None:
        """Non-text, non-image Content raises ValueError."""
        error_content = Content("error", message="fail")
        with pytest.raises(ValueError, match="Unsupported Content type"):
            await raw_client.get_embeddings([error_content])

    async def test_usage_metadata(
        self, raw_client: RawFoundryEmbeddingClient[Any], mock_text_client: AsyncMock
    ) -> None:
        """Usage metadata is populated from the response."""
        mock_text_client.embed.return_value = _make_embed_response([[0.1, 0.2]], prompt_tokens=42)
        result = await raw_client.get_embeddings(["hello"])
        assert result.usage is not None
        assert result.usage["input_token_count"] == 42

    def test_service_url(self, raw_client: RawFoundryEmbeddingClient[Any]) -> None:
        """service_url returns the configured endpoint."""
        assert raw_client.service_url() == "https://test.inference.ai.azure.com"

    def test_settings_from_env(self) -> None:
        """Settings are loaded from environment variables."""
        with (
            patch.dict(
                os.environ,
                {
                    "FOUNDRY_MODELS_ENDPOINT": "https://env.inference.ai.azure.com",
                    "FOUNDRY_MODELS_API_KEY": "env-key",
                    "FOUNDRY_EMBEDDING_MODEL": "env-model",
                },
            ),
            patch("agent_framework_foundry._embedding_client.EmbeddingsClient"),
            patch("agent_framework_foundry._embedding_client.ImageEmbeddingsClient"),
        ):
            # An image model set in the ambient environment would defeat the
            # fallback assertion below; patch.dict restores it on exit.
            os.environ.pop("FOUNDRY_IMAGE_EMBEDDING_MODEL", None)
            client = RawFoundryEmbeddingClient()
            assert client.model == "env-model"
            assert client.image_model == "env-model"  # falls back to model

    def test_image_model_from_env(self) -> None:
        """image_model is loaded from its own environment variable."""
        with (
            patch.dict(
                os.environ,
                {
                    "FOUNDRY_MODELS_ENDPOINT": "https://env.inference.ai.azure.com",
                    "FOUNDRY_MODELS_API_KEY": "env-key",
                    "FOUNDRY_EMBEDDING_MODEL": "text-model",
                    "FOUNDRY_IMAGE_EMBEDDING_MODEL": "image-model",
                },
            ),
            patch("agent_framework_foundry._embedding_client.EmbeddingsClient"),
            patch("agent_framework_foundry._embedding_client.ImageEmbeddingsClient"),
        ):
            client = RawFoundryEmbeddingClient()
            assert client.model == "text-model"
            assert client.image_model == "image-model"

    def test_image_model_explicit(self, mock_text_client: AsyncMock, mock_image_client: AsyncMock) -> None:
        """image_model can be set explicitly."""
        client = RawFoundryEmbeddingClient(
            model="text-model",
            image_model="image-model",
            endpoint="https://test.inference.ai.azure.com",
            api_key="test-key",
            text_client=mock_text_client,
            image_client=mock_image_client,
        )
        assert client.model == "text-model"
        assert client.image_model == "image-model"

    async def test_image_model_sent_to_image_client(
        self, mock_text_client: AsyncMock, mock_image_client: AsyncMock
    ) -> None:
        """image_model is passed to the image client embed call."""
        client = RawFoundryEmbeddingClient(
            model="text-model",
            image_model="image-model",
            endpoint="https://test.inference.ai.azure.com",
            api_key="test-key",
            text_client=mock_text_client,
            image_client=mock_image_client,
        )
        image_content = Content.from_data(data=b"\x89PNG", media_type="image/png")
        await client.get_embeddings([image_content])
        call_kwargs = mock_image_client.embed.call_args
        assert call_kwargs.kwargs["model"] == "image-model"
|
||||
|
||||
|
||||
class TestFoundryEmbeddingClient:
    """Tests for the telemetry-enabled Foundry embedding client."""

    async def test_text_embeddings(self, client: FoundryEmbeddingClient[Any], mock_text_client: AsyncMock) -> None:
        """A text input passed through the telemetry layer yields the mocked vector."""
        embeddings = await client.get_embeddings(["hello"])
        assert len(embeddings) == 1
        assert embeddings[0].vector == [0.1, 0.2, 0.3]

    async def test_otel_provider_name_default(self) -> None:
        """The class-level OTEL provider name is azure.ai.inference."""
        assert FoundryEmbeddingClient.OTEL_PROVIDER_NAME == "azure.ai.inference"

    async def test_otel_provider_name_override(self, mock_text_client: AsyncMock, mock_image_client: AsyncMock) -> None:
        """An explicit ``otel_provider_name`` is exposed on the instance."""
        init_kwargs: dict[str, Any] = {
            "model": "test-model",
            "endpoint": "https://test.inference.ai.azure.com",
            "api_key": "test-key",
            "text_client": mock_text_client,
            "image_client": mock_image_client,
            "otel_provider_name": "custom-provider",
        }
        overridden = FoundryEmbeddingClient(**init_kwargs)
        assert overridden.otel_provider_name == "custom-provider"
|
||||
|
||||
|
||||
# Reason reported by pytest when the live integration tests are skipped.
_SKIP_REASON = "Foundry inference integration tests disabled"


def _foundry_integration_tests_enabled() -> bool:
    """Return True only when every required Foundry environment variable is set and non-empty."""
    required_vars = (
        "FOUNDRY_MODELS_ENDPOINT",
        "FOUNDRY_MODELS_API_KEY",
        "FOUNDRY_EMBEDDING_MODEL",
    )
    return all(os.environ.get(var_name) for var_name in required_vars)
|
||||
|
||||
|
||||
# Skip marker for live tests; the condition is evaluated once, when this module is imported.
skip_if_foundry_inference_integration_tests_disabled = pytest.mark.skipif(
    not _foundry_integration_tests_enabled(),
    reason=_SKIP_REASON,
)
|
||||
|
||||
|
||||
class TestFoundryEmbeddingIntegration:
    """Integration tests requiring a live Foundry inference endpoint."""

    @pytest.mark.flaky
    @pytest.mark.integration
    @skip_if_foundry_inference_integration_tests_disabled
    async def test_text_embedding_live(self) -> None:
        """Generate text embeddings against a live endpoint."""
        # Use the client as an async context manager so the underlying SDK
        # clients are closed even when the request or an assertion fails.
        async with FoundryEmbeddingClient() as client:
            result = await client.get_embeddings(["Hello, world!"])

        assert len(result) == 1
        assert len(result[0].vector) > 0
        assert result[0].model is not None
|
||||
Reference in New Issue
Block a user