mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: OpenAI Clients accepting api_key callback (#1139)
* openai clients accepting callable * fixes * version checks
This commit is contained in:
committed by
GitHub
Unverified
parent
ee8d95b3dc
commit
4e6ae443cf
@@ -2,7 +2,7 @@
|
||||
|
||||
import json
|
||||
import sys
|
||||
from collections.abc import AsyncIterable, Mapping, MutableMapping, MutableSequence
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, MutableSequence
|
||||
from typing import Any
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
@@ -65,7 +65,7 @@ class OpenAIAssistantsClient(OpenAIConfigMixin, BaseChatClient):
|
||||
assistant_id: str | None = None,
|
||||
assistant_name: str | None = None,
|
||||
thread_id: str | None = None,
|
||||
api_key: str | None = None,
|
||||
api_key: str | Callable[[], str | Awaitable[str]] | None = None,
|
||||
org_id: str | None = None,
|
||||
base_url: str | None = None,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
@@ -139,7 +139,7 @@ class OpenAIAssistantsClient(OpenAIConfigMixin, BaseChatClient):
|
||||
|
||||
super().__init__(
|
||||
model_id=openai_settings.chat_model_id,
|
||||
api_key=openai_settings.api_key.get_secret_value() if openai_settings.api_key else None,
|
||||
api_key=self._get_api_key(openai_settings.api_key),
|
||||
org_id=openai_settings.org_id,
|
||||
default_headers=default_headers,
|
||||
client=async_client,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import json
|
||||
import sys
|
||||
from collections.abc import AsyncIterable, Mapping, MutableMapping, MutableSequence, Sequence
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, MutableSequence, Sequence
|
||||
from datetime import datetime
|
||||
from itertools import chain
|
||||
from typing import Any, TypeVar
|
||||
@@ -467,7 +467,7 @@ class OpenAIChatClient(OpenAIConfigMixin, OpenAIBaseChatClient):
|
||||
self,
|
||||
*,
|
||||
model_id: str | None = None,
|
||||
api_key: str | None = None,
|
||||
api_key: str | Callable[[], str | Awaitable[str]] | None = None,
|
||||
org_id: str | None = None,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
async_client: AsyncOpenAI | None = None,
|
||||
@@ -537,7 +537,7 @@ class OpenAIChatClient(OpenAIConfigMixin, OpenAIBaseChatClient):
|
||||
|
||||
super().__init__(
|
||||
model_id=openai_settings.chat_model_id,
|
||||
api_key=openai_settings.api_key.get_secret_value() if openai_settings.api_key else None,
|
||||
api_key=self._get_api_key(openai_settings.api_key),
|
||||
base_url=openai_settings.base_url if openai_settings.base_url else None,
|
||||
org_id=openai_settings.org_id,
|
||||
default_headers=default_headers,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import AsyncIterable, Mapping, MutableMapping, MutableSequence, Sequence
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, MutableSequence, Sequence
|
||||
from datetime import datetime
|
||||
from itertools import chain
|
||||
from typing import Any, TypeVar
|
||||
@@ -947,7 +947,7 @@ class OpenAIResponsesClient(OpenAIConfigMixin, OpenAIBaseResponsesClient):
|
||||
self,
|
||||
*,
|
||||
model_id: str | None = None,
|
||||
api_key: str | None = None,
|
||||
api_key: str | Callable[[], str | Awaitable[str]] | None = None,
|
||||
org_id: str | None = None,
|
||||
base_url: str | None = None,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
@@ -1018,7 +1018,7 @@ class OpenAIResponsesClient(OpenAIConfigMixin, OpenAIBaseResponsesClient):
|
||||
|
||||
super().__init__(
|
||||
model_id=openai_settings.responses_model_id,
|
||||
api_key=openai_settings.api_key.get_secret_value() if openai_settings.api_key else None,
|
||||
api_key=self._get_api_key(openai_settings.api_key),
|
||||
org_id=openai_settings.org_id,
|
||||
default_headers=default_headers,
|
||||
client=async_client,
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Awaitable, Callable, Mapping
|
||||
from copy import copy
|
||||
from typing import Any, ClassVar, Union
|
||||
|
||||
import openai
|
||||
from openai import (
|
||||
AsyncOpenAI,
|
||||
AsyncStream,
|
||||
@@ -16,6 +17,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.images_response import ImagesResponse
|
||||
from openai.types.responses.response import Response
|
||||
from openai.types.responses.response_stream_event import ResponseStreamEvent
|
||||
from packaging import version
|
||||
from pydantic import SecretStr
|
||||
|
||||
from .._logging import get_logger
|
||||
@@ -49,6 +51,29 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
def _check_openai_version_for_callable_api_key() -> None:
|
||||
"""Check if OpenAI version supports callable API keys.
|
||||
|
||||
Callable API keys require OpenAI >= 1.106.0.
|
||||
If the version is too old, raise a ServiceInitializationError with helpful message.
|
||||
"""
|
||||
try:
|
||||
current_version = version.parse(openai.__version__)
|
||||
min_required_version = version.parse("1.106.0")
|
||||
|
||||
if current_version < min_required_version:
|
||||
raise ServiceInitializationError(
|
||||
f"Callable API keys require OpenAI SDK >= 1.106.0, but you have {openai.__version__}. "
|
||||
f"Please upgrade with 'pip install openai>=1.106.0' or provide a string API key instead. "
|
||||
f"Note: If you're using mem0ai, you may need to upgrade to mem0ai>=0.1.118 "
|
||||
f"to allow newer OpenAI versions."
|
||||
)
|
||||
except ServiceInitializationError:
|
||||
raise # Re-raise our own exception
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not check OpenAI version for callable API key support: {e}")
|
||||
|
||||
|
||||
class OpenAISettings(AFBaseSettings):
|
||||
"""OpenAI environment settings.
|
||||
|
||||
@@ -90,7 +115,7 @@ class OpenAISettings(AFBaseSettings):
|
||||
|
||||
env_prefix: ClassVar[str] = "OPENAI_"
|
||||
|
||||
api_key: SecretStr | None = None
|
||||
api_key: SecretStr | Callable[[], str | Awaitable[str]] | None = None
|
||||
base_url: str | None = None
|
||||
org_id: str | None = None
|
||||
chat_model_id: str | None = None
|
||||
@@ -137,6 +162,28 @@ class OpenAIBase(SerializationMixin):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
def _get_api_key(
|
||||
self, api_key: str | SecretStr | Callable[[], str | Awaitable[str]] | None
|
||||
) -> str | Callable[[], str | Awaitable[str]] | None:
|
||||
"""Get the appropriate API key value for client initialization.
|
||||
|
||||
Args:
|
||||
api_key: The API key parameter which can be a string, SecretStr, callable, or None.
|
||||
|
||||
Returns:
|
||||
For callable API keys: returns the callable directly.
|
||||
For SecretStr API keys: returns the string value.
|
||||
For string/None API keys: returns as-is.
|
||||
"""
|
||||
if isinstance(api_key, SecretStr):
|
||||
return api_key.get_secret_value()
|
||||
|
||||
# Check version compatibility for callable API keys
|
||||
if callable(api_key):
|
||||
_check_openai_version_for_callable_api_key()
|
||||
|
||||
return api_key # Pass callable, string, or None directly to OpenAI SDK
|
||||
|
||||
|
||||
class OpenAIConfigMixin(OpenAIBase):
|
||||
"""Internal class for configuring a connection to an OpenAI service."""
|
||||
@@ -146,7 +193,7 @@ class OpenAIConfigMixin(OpenAIBase):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
api_key: str | None = None,
|
||||
api_key: str | Callable[[], str | Awaitable[str]] | None = None,
|
||||
org_id: str | None = None,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
client: AsyncOpenAI | None = None,
|
||||
@@ -162,7 +209,7 @@ class OpenAIConfigMixin(OpenAIBase):
|
||||
Args:
|
||||
model_id: OpenAI model identifier. Must be non-empty.
|
||||
Default to a preset value.
|
||||
api_key: OpenAI API key for authentication.
|
||||
api_key: OpenAI API key for authentication, or a callable that returns an API key.
|
||||
Must be non-empty. (Optional)
|
||||
org_id: OpenAI organization ID. This is optional
|
||||
unless the account belongs to multiple organizations.
|
||||
@@ -182,10 +229,13 @@ class OpenAIConfigMixin(OpenAIBase):
|
||||
merged_headers.update(APP_INFO)
|
||||
merged_headers = prepend_agent_framework_to_user_agent(merged_headers)
|
||||
|
||||
# Handle callable API key using base class method
|
||||
api_key_value = self._get_api_key(api_key)
|
||||
|
||||
if not client:
|
||||
if not api_key:
|
||||
raise ServiceInitializationError("Please provide an api_key")
|
||||
args: dict[str, Any] = {"api_key": api_key, "default_headers": merged_headers}
|
||||
args: dict[str, Any] = {"api_key": api_key_value, "default_headers": merged_headers}
|
||||
if org_id:
|
||||
args["organization"] = org_id
|
||||
if base_url:
|
||||
|
||||
@@ -1258,3 +1258,18 @@ async def test_openai_assistants_client_agent_level_tool_persistence():
|
||||
assert second_response.text is not None
|
||||
# Should use the agent-level weather tool again
|
||||
assert any(term in second_response.text.lower() for term in ["miami", "sunny", "72"])
|
||||
|
||||
|
||||
# Callable API Key Tests
|
||||
def test_openai_assistants_client_with_callable_api_key() -> None:
|
||||
"""Test OpenAIAssistantsClient initialization with callable API key."""
|
||||
|
||||
async def get_api_key() -> str:
|
||||
return "test-api-key-123"
|
||||
|
||||
client = OpenAIAssistantsClient(model_id="gpt-4o", api_key=get_api_key)
|
||||
|
||||
# Verify client was created successfully
|
||||
assert client.model_id == "gpt-4o"
|
||||
# OpenAI SDK now manages callable API keys internally
|
||||
assert client.client is not None
|
||||
|
||||
@@ -783,3 +783,17 @@ def test_openai_content_parser_data_content_image(openai_unit_test_env: dict[str
|
||||
# Should use custom filename
|
||||
assert result["type"] == "file"
|
||||
assert result["file"]["filename"] == "report.pdf"
|
||||
|
||||
|
||||
def test_openai_chat_client_with_callable_api_key() -> None:
|
||||
"""Test OpenAIChatClient initialization with callable API key."""
|
||||
|
||||
async def get_api_key() -> str:
|
||||
return "test-api-key-123"
|
||||
|
||||
client = OpenAIChatClient(model_id="gpt-4o", api_key=get_api_key)
|
||||
|
||||
# Verify client was created successfully
|
||||
assert client.model_id == "gpt-4o"
|
||||
# OpenAI SDK now manages callable API keys internally
|
||||
assert client.client is not None
|
||||
|
||||
@@ -2088,3 +2088,17 @@ def test_prepare_options_store_parameter_handling() -> None:
|
||||
options = client._prepare_options(messages, chat_options) # type: ignore
|
||||
assert options["store"] is False
|
||||
assert "previous_response_id" not in options
|
||||
|
||||
|
||||
def test_openai_responses_client_with_callable_api_key() -> None:
|
||||
"""Test OpenAIResponsesClient initialization with callable API key."""
|
||||
|
||||
async def get_api_key() -> str:
|
||||
return "test-api-key-123"
|
||||
|
||||
client = OpenAIResponsesClient(model_id="gpt-4o", api_key=get_api_key)
|
||||
|
||||
# Verify client was created successfully
|
||||
assert client.model_id == "gpt-4o"
|
||||
# OpenAI SDK now manages callable API keys internally
|
||||
assert client.client is not None
|
||||
|
||||
Reference in New Issue
Block a user