mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fix OpenAIEmbeddingClient to use AsyncOpenAI for /openai/v1 endpoints (#5137)
* Fix OpenAIEmbeddingClient with /openai/v1 endpoint (#5068) When base_url ends with /openai/v1/ and a credential is provided, load_openai_service_settings was creating an AsyncAzureOpenAI client. The Azure SDK rewrites deployment-based endpoints (including /embeddings) by inserting /deployments/{model}/ into the URL, producing 404s on the OpenAI-compatible /openai/v1 endpoint. Use AsyncOpenAI instead of AsyncAzureOpenAI when the resolved base_url targets /openai/v1, converting the Azure token provider to an async api_key callable. The responses_mode path is unaffected because the Responses API (/responses) is not in the SDK's rewrite list. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Python: Fix OpenAIEmbeddingClient to use AsyncOpenAI for /openai/v1 endpoints Fixes #5068 * Address review feedback: improve test coverage and remove unrelated changes - Revert unrelated formatting change in test_a2a_agent.py - Fix test_init_with_openai_v1_base_url_and_api_key_uses_openai_client to exercise the Azure settings path (via AZURE_OPENAI_BASE_URL env var) instead of the plain OpenAI path, covering the elif api_key branch - Add _ensure_async_token_provider unit tests for both sync and async token providers Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address review feedback for #5068: Python: [Bug]: `OpenAIEmbeddingClient` does not work with `/openai/v1` endpoint --------- Co-authored-by: MAF Dashboard Bot <maf-dashboard-bot@users.noreply.github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Evan Mattson <35585003+moonbox3@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
aa582d021d
commit
57fa8ea902
@@ -282,9 +282,46 @@ def load_openai_service_settings(
|
||||
"Azure OpenAI client requires either an API key or an Azure AD token provider."
|
||||
" This can be provided either as a callable api_key or via the credential parameter."
|
||||
)
|
||||
|
||||
# The /openai/v1 endpoint exposes an OpenAI-compatible API surface.
|
||||
# AsyncAzureOpenAI rewrites certain request paths (e.g. /embeddings,
|
||||
# /chat/completions) by inserting /deployments/{model}/, which produces
|
||||
# 404s on this endpoint. Use AsyncOpenAI instead so request URLs are
|
||||
# sent as-is. responses_mode is excluded because the Responses API path
|
||||
# (/responses) is not rewritten by the Azure SDK.
|
||||
resolved_base_url = client_args.get("base_url", "")
|
||||
if not responses_mode and resolved_base_url and resolved_base_url.rstrip("/").endswith("/openai/v1"):
|
||||
openai_args: dict[str, Any] = {
|
||||
"base_url": resolved_base_url,
|
||||
"default_headers": client_args.get("default_headers"),
|
||||
}
|
||||
if "azure_ad_token_provider" in client_args:
|
||||
openai_args["api_key"] = _ensure_async_token_provider(client_args["azure_ad_token_provider"])
|
||||
elif "api_key" in client_args:
|
||||
openai_args["api_key"] = client_args["api_key"]
|
||||
return azure_settings, AsyncOpenAI(**openai_args), True # type: ignore[return-value]
|
||||
|
||||
return azure_settings, AsyncAzureOpenAI(**client_args), True # type: ignore[return-value]
|
||||
|
||||
|
||||
def _ensure_async_token_provider(
|
||||
provider: AzureTokenProvider,
|
||||
) -> Callable[[], Awaitable[str]]:
|
||||
"""Wrap a (possibly synchronous) token provider so it always returns an awaitable.
|
||||
|
||||
``AsyncOpenAI`` requires callable ``api_key`` values to return ``Awaitable[str]``.
|
||||
Azure token providers may return a plain ``str``, so this normalises them.
|
||||
"""
|
||||
|
||||
async def _wrapper() -> str:
|
||||
result = provider()
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
return await result
|
||||
|
||||
return _wrapper
|
||||
|
||||
|
||||
def _resolve_azure_credential_to_token_provider(
|
||||
credential: AzureCredentialTypes | AzureTokenProvider,
|
||||
) -> AzureTokenProvider:
|
||||
|
||||
@@ -11,7 +11,7 @@ import pytest
|
||||
from agent_framework.exceptions import SettingNotFoundError
|
||||
from azure.core.credentials_async import AsyncTokenCredential
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from openai import AsyncAzureOpenAI
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||
|
||||
from agent_framework_openai import OpenAIEmbeddingClient, OpenAIEmbeddingOptions
|
||||
|
||||
@@ -196,6 +196,78 @@ def test_openai_base_url_wins_over_azure_aliases(monkeypatch, azure_openai_unit_
|
||||
assert client.azure_endpoint is None
|
||||
|
||||
|
||||
def test_init_with_openai_v1_base_url_and_credential_uses_openai_client(monkeypatch) -> None:
|
||||
for env in [
|
||||
"OPENAI_API_KEY",
|
||||
"OPENAI_ORG_ID",
|
||||
"OPENAI_MODEL",
|
||||
"OPENAI_EMBEDDING_MODEL",
|
||||
"OPENAI_BASE_URL",
|
||||
"AZURE_OPENAI_ENDPOINT",
|
||||
"AZURE_OPENAI_BASE_URL",
|
||||
"AZURE_OPENAI_API_KEY",
|
||||
"AZURE_OPENAI_EMBEDDING_MODEL",
|
||||
"AZURE_OPENAI_MODEL",
|
||||
"AZURE_OPENAI_API_VERSION",
|
||||
"AZURE_OPENAI_CHAT_MODEL",
|
||||
"AZURE_OPENAI_CHAT_COMPLETION_MODEL",
|
||||
]:
|
||||
monkeypatch.delenv(env, raising=False)
|
||||
|
||||
client = OpenAIEmbeddingClient(
|
||||
base_url="https://myproject.openai.azure.com/openai/v1/",
|
||||
model="text-embedding-3-large",
|
||||
credential=lambda: "fake-token",
|
||||
)
|
||||
|
||||
assert client.model == "text-embedding-3-large"
|
||||
assert not isinstance(client.client, AsyncAzureOpenAI)
|
||||
assert isinstance(client.client, AsyncOpenAI)
|
||||
assert client.OTEL_PROVIDER_NAME == "azure.ai.openai"
|
||||
assert str(client.client.base_url).rstrip("/").endswith("/openai/v1")
|
||||
|
||||
|
||||
def test_init_with_openai_v1_base_url_and_api_key_uses_openai_client(monkeypatch) -> None:
|
||||
for env in [
|
||||
"OPENAI_API_KEY",
|
||||
"OPENAI_ORG_ID",
|
||||
"OPENAI_MODEL",
|
||||
"OPENAI_EMBEDDING_MODEL",
|
||||
"OPENAI_BASE_URL",
|
||||
"AZURE_OPENAI_ENDPOINT",
|
||||
"AZURE_OPENAI_BASE_URL",
|
||||
"AZURE_OPENAI_API_KEY",
|
||||
"AZURE_OPENAI_EMBEDDING_MODEL",
|
||||
"AZURE_OPENAI_MODEL",
|
||||
"AZURE_OPENAI_API_VERSION",
|
||||
"AZURE_OPENAI_CHAT_MODEL",
|
||||
"AZURE_OPENAI_CHAT_COMPLETION_MODEL",
|
||||
]:
|
||||
monkeypatch.delenv(env, raising=False)
|
||||
|
||||
# AZURE_OPENAI_BASE_URL + AZURE_OPENAI_API_KEY enter the Azure settings
|
||||
# path without an explicit endpoint parameter; the /openai/v1 suffix
|
||||
# should still produce AsyncOpenAI (not AsyncAzureOpenAI).
|
||||
monkeypatch.setenv("AZURE_OPENAI_BASE_URL", "https://myproject.openai.azure.com/openai/v1/")
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
|
||||
|
||||
client = OpenAIEmbeddingClient(model="text-embedding-3-large")
|
||||
|
||||
assert client.model == "text-embedding-3-large"
|
||||
assert not isinstance(client.client, AsyncAzureOpenAI)
|
||||
assert isinstance(client.client, AsyncOpenAI)
|
||||
assert str(client.client.base_url).rstrip("/").endswith("/openai/v1")
|
||||
|
||||
|
||||
def test_init_with_azure_endpoint_still_uses_azure_client(azure_openai_unit_test_env: dict[str, str]) -> None:
|
||||
client = OpenAIEmbeddingClient(
|
||||
azure_endpoint=azure_openai_unit_test_env["AZURE_OPENAI_ENDPOINT"],
|
||||
api_key=azure_openai_unit_test_env["AZURE_OPENAI_API_KEY"],
|
||||
)
|
||||
|
||||
assert isinstance(client.client, AsyncAzureOpenAI)
|
||||
|
||||
|
||||
@pytest.mark.flaky
|
||||
@pytest.mark.integration
|
||||
@skip_if_azure_openai_integration_tests_disabled
|
||||
|
||||
@@ -8,7 +8,11 @@ import pytest
|
||||
from azure.core.credentials import TokenCredential
|
||||
from azure.core.credentials_async import AsyncTokenCredential
|
||||
|
||||
from agent_framework_openai._shared import AZURE_OPENAI_TOKEN_SCOPE, _resolve_azure_credential_to_token_provider
|
||||
from agent_framework_openai._shared import (
|
||||
AZURE_OPENAI_TOKEN_SCOPE,
|
||||
_ensure_async_token_provider,
|
||||
_resolve_azure_credential_to_token_provider,
|
||||
)
|
||||
|
||||
|
||||
class _AsyncTokenCredentialStub(AsyncTokenCredential):
|
||||
@@ -52,3 +56,23 @@ def test_resolve_azure_callable_token_provider_passthrough() -> None:
|
||||
def test_resolve_azure_invalid_credential_raises() -> None:
|
||||
with pytest.raises(ValueError, match="credential"):
|
||||
_resolve_azure_credential_to_token_provider(object()) # type: ignore[arg-type]
|
||||
|
||||
|
||||
async def test_ensure_async_token_provider_wraps_sync_provider() -> None:
|
||||
def sync_provider() -> str:
|
||||
return "sync-token"
|
||||
|
||||
wrapper = _ensure_async_token_provider(sync_provider)
|
||||
result = await wrapper()
|
||||
|
||||
assert result == "sync-token"
|
||||
|
||||
|
||||
async def test_ensure_async_token_provider_wraps_async_provider() -> None:
|
||||
async def async_provider() -> str:
|
||||
return "async-token"
|
||||
|
||||
wrapper = _ensure_async_token_provider(async_provider)
|
||||
result = await wrapper()
|
||||
|
||||
assert result == "async-token"
|
||||
|
||||
Reference in New Issue
Block a user