Python: OpenAI Clients accepting api_key callback (#1139)

* openai clients accepting callable

* fixes

* version checks
This commit is contained in:
Giles Odigwe
2025-10-06 12:51:34 -07:00
committed by GitHub
Unverified
parent ee8d95b3dc
commit 4e6ae443cf
7 changed files with 107 additions and 14 deletions
@@ -2,7 +2,7 @@
import json
import sys
from collections.abc import AsyncIterable, Mapping, MutableMapping, MutableSequence
from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, MutableSequence
from typing import Any
from openai import AsyncOpenAI
@@ -65,7 +65,7 @@ class OpenAIAssistantsClient(OpenAIConfigMixin, BaseChatClient):
assistant_id: str | None = None,
assistant_name: str | None = None,
thread_id: str | None = None,
api_key: str | None = None,
api_key: str | Callable[[], str | Awaitable[str]] | None = None,
org_id: str | None = None,
base_url: str | None = None,
default_headers: Mapping[str, str] | None = None,
@@ -139,7 +139,7 @@ class OpenAIAssistantsClient(OpenAIConfigMixin, BaseChatClient):
super().__init__(
model_id=openai_settings.chat_model_id,
api_key=openai_settings.api_key.get_secret_value() if openai_settings.api_key else None,
api_key=self._get_api_key(openai_settings.api_key),
org_id=openai_settings.org_id,
default_headers=default_headers,
client=async_client,
@@ -2,7 +2,7 @@
import json
import sys
from collections.abc import AsyncIterable, Mapping, MutableMapping, MutableSequence, Sequence
from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, MutableSequence, Sequence
from datetime import datetime
from itertools import chain
from typing import Any, TypeVar
@@ -467,7 +467,7 @@ class OpenAIChatClient(OpenAIConfigMixin, OpenAIBaseChatClient):
self,
*,
model_id: str | None = None,
api_key: str | None = None,
api_key: str | Callable[[], str | Awaitable[str]] | None = None,
org_id: str | None = None,
default_headers: Mapping[str, str] | None = None,
async_client: AsyncOpenAI | None = None,
@@ -537,7 +537,7 @@ class OpenAIChatClient(OpenAIConfigMixin, OpenAIBaseChatClient):
super().__init__(
model_id=openai_settings.chat_model_id,
api_key=openai_settings.api_key.get_secret_value() if openai_settings.api_key else None,
api_key=self._get_api_key(openai_settings.api_key),
base_url=openai_settings.base_url if openai_settings.base_url else None,
org_id=openai_settings.org_id,
default_headers=default_headers,
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import AsyncIterable, Mapping, MutableMapping, MutableSequence, Sequence
from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, MutableSequence, Sequence
from datetime import datetime
from itertools import chain
from typing import Any, TypeVar
@@ -947,7 +947,7 @@ class OpenAIResponsesClient(OpenAIConfigMixin, OpenAIBaseResponsesClient):
self,
*,
model_id: str | None = None,
api_key: str | None = None,
api_key: str | Callable[[], str | Awaitable[str]] | None = None,
org_id: str | None = None,
base_url: str | None = None,
default_headers: Mapping[str, str] | None = None,
@@ -1018,7 +1018,7 @@ class OpenAIResponsesClient(OpenAIConfigMixin, OpenAIBaseResponsesClient):
super().__init__(
model_id=openai_settings.responses_model_id,
api_key=openai_settings.api_key.get_secret_value() if openai_settings.api_key else None,
api_key=self._get_api_key(openai_settings.api_key),
org_id=openai_settings.org_id,
default_headers=default_headers,
client=async_client,
@@ -1,10 +1,11 @@
# Copyright (c) Microsoft. All rights reserved.
import logging
from collections.abc import Mapping
from collections.abc import Awaitable, Callable, Mapping
from copy import copy
from typing import Any, ClassVar, Union
import openai
from openai import (
AsyncOpenAI,
AsyncStream,
@@ -16,6 +17,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.images_response import ImagesResponse
from openai.types.responses.response import Response
from openai.types.responses.response_stream_event import ResponseStreamEvent
from packaging import version
from pydantic import SecretStr
from .._logging import get_logger
@@ -49,6 +51,29 @@ __all__ = [
]
def _check_openai_version_for_callable_api_key() -> None:
"""Check if OpenAI version supports callable API keys.
Callable API keys require OpenAI >= 1.106.0.
If the version is too old, raise a ServiceInitializationError with helpful message.
"""
try:
current_version = version.parse(openai.__version__)
min_required_version = version.parse("1.106.0")
if current_version < min_required_version:
raise ServiceInitializationError(
f"Callable API keys require OpenAI SDK >= 1.106.0, but you have {openai.__version__}. "
f"Please upgrade with 'pip install openai>=1.106.0' or provide a string API key instead. "
f"Note: If you're using mem0ai, you may need to upgrade to mem0ai>=0.1.118 "
f"to allow newer OpenAI versions."
)
except ServiceInitializationError:
raise # Re-raise our own exception
except Exception as e:
logger.warning(f"Could not check OpenAI version for callable API key support: {e}")
class OpenAISettings(AFBaseSettings):
"""OpenAI environment settings.
@@ -90,7 +115,7 @@ class OpenAISettings(AFBaseSettings):
env_prefix: ClassVar[str] = "OPENAI_"
api_key: SecretStr | None = None
api_key: SecretStr | Callable[[], str | Awaitable[str]] | None = None
base_url: str | None = None
org_id: str | None = None
chat_model_id: str | None = None
@@ -137,6 +162,28 @@ class OpenAIBase(SerializationMixin):
for key, value in kwargs.items():
setattr(self, key, value)
def _get_api_key(
self, api_key: str | SecretStr | Callable[[], str | Awaitable[str]] | None
) -> str | Callable[[], str | Awaitable[str]] | None:
"""Get the appropriate API key value for client initialization.
Args:
api_key: The API key parameter which can be a string, SecretStr, callable, or None.
Returns:
For callable API keys: returns the callable directly.
For SecretStr API keys: returns the string value.
For string/None API keys: returns as-is.
"""
if isinstance(api_key, SecretStr):
return api_key.get_secret_value()
# Check version compatibility for callable API keys
if callable(api_key):
_check_openai_version_for_callable_api_key()
return api_key # Pass callable, string, or None directly to OpenAI SDK
class OpenAIConfigMixin(OpenAIBase):
"""Internal class for configuring a connection to an OpenAI service."""
@@ -146,7 +193,7 @@ class OpenAIConfigMixin(OpenAIBase):
def __init__(
self,
model_id: str,
api_key: str | None = None,
api_key: str | Callable[[], str | Awaitable[str]] | None = None,
org_id: str | None = None,
default_headers: Mapping[str, str] | None = None,
client: AsyncOpenAI | None = None,
@@ -162,7 +209,7 @@ class OpenAIConfigMixin(OpenAIBase):
Args:
model_id: OpenAI model identifier. Must be non-empty.
Default to a preset value.
api_key: OpenAI API key for authentication.
api_key: OpenAI API key for authentication, or a callable that returns an API key.
Must be non-empty. (Optional)
org_id: OpenAI organization ID. This is optional
unless the account belongs to multiple organizations.
@@ -182,10 +229,13 @@ class OpenAIConfigMixin(OpenAIBase):
merged_headers.update(APP_INFO)
merged_headers = prepend_agent_framework_to_user_agent(merged_headers)
# Handle callable API key using base class method
api_key_value = self._get_api_key(api_key)
if not client:
if not api_key:
raise ServiceInitializationError("Please provide an api_key")
args: dict[str, Any] = {"api_key": api_key, "default_headers": merged_headers}
args: dict[str, Any] = {"api_key": api_key_value, "default_headers": merged_headers}
if org_id:
args["organization"] = org_id
if base_url:
@@ -1258,3 +1258,18 @@ async def test_openai_assistants_client_agent_level_tool_persistence():
assert second_response.text is not None
# Should use the agent-level weather tool again
assert any(term in second_response.text.lower() for term in ["miami", "sunny", "72"])
# Callable API Key Tests
def test_openai_assistants_client_with_callable_api_key() -> None:
"""Test OpenAIAssistantsClient initialization with callable API key."""
async def get_api_key() -> str:
return "test-api-key-123"
client = OpenAIAssistantsClient(model_id="gpt-4o", api_key=get_api_key)
# Verify client was created successfully
assert client.model_id == "gpt-4o"
# OpenAI SDK now manages callable API keys internally
assert client.client is not None
@@ -783,3 +783,17 @@ def test_openai_content_parser_data_content_image(openai_unit_test_env: dict[str
# Should use custom filename
assert result["type"] == "file"
assert result["file"]["filename"] == "report.pdf"
def test_openai_chat_client_with_callable_api_key() -> None:
"""Test OpenAIChatClient initialization with callable API key."""
async def get_api_key() -> str:
return "test-api-key-123"
client = OpenAIChatClient(model_id="gpt-4o", api_key=get_api_key)
# Verify client was created successfully
assert client.model_id == "gpt-4o"
# OpenAI SDK now manages callable API keys internally
assert client.client is not None
@@ -2088,3 +2088,17 @@ def test_prepare_options_store_parameter_handling() -> None:
options = client._prepare_options(messages, chat_options) # type: ignore
assert options["store"] is False
assert "previous_response_id" not in options
def test_openai_responses_client_with_callable_api_key() -> None:
"""Test OpenAIResponsesClient initialization with callable API key."""
async def get_api_key() -> str:
return "test-api-key-123"
client = OpenAIResponsesClient(model_id="gpt-4o", api_key=get_api_key)
# Verify client was created successfully
assert client.model_id == "gpt-4o"
# OpenAI SDK now manages callable API keys internally
assert client.client is not None