mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: add base_url parameter for openai chat client (#661)
* added base_url option for openai chat client * fix test url
This commit is contained in:
committed by
GitHub
Unverified
parent
0fa6883691
commit
383d51443c
@@ -387,6 +387,7 @@ class OpenAIChatClient(OpenAIConfigMixin, OpenAIBaseChatClient):
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
async_client: AsyncOpenAI | None = None,
|
||||
instruction_role: str | None = None,
|
||||
base_url: str | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
) -> None:
|
||||
@@ -404,6 +405,7 @@ class OpenAIChatClient(OpenAIConfigMixin, OpenAIBaseChatClient):
|
||||
async_client: An existing client to use. (Optional)
|
||||
instruction_role: The role to use for 'instruction' messages, for example,
|
||||
"system" or "developer". If not provided, the default is "system".
|
||||
base_url: The optional base URL to use. If provided will override the standard value for a OpenAI connector.
|
||||
env_file_path: Use the environment settings file as a fallback
|
||||
to environment variables. (Optional)
|
||||
env_file_encoding: The encoding of the environment settings file. (Optional)
|
||||
@@ -436,6 +438,7 @@ class OpenAIChatClient(OpenAIConfigMixin, OpenAIBaseChatClient):
|
||||
default_headers=default_headers,
|
||||
client=async_client,
|
||||
instruction_role=instruction_role,
|
||||
base_url=base_url,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -138,6 +138,7 @@ class OpenAIConfigMixin(OpenAIBase):
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
client: AsyncOpenAI | None = None,
|
||||
instruction_role: str | None = None,
|
||||
base_url: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a client for OpenAI services.
|
||||
@@ -146,17 +147,19 @@ class OpenAIConfigMixin(OpenAIBase):
|
||||
different types of AI model interactions, like chat or text completion.
|
||||
|
||||
Args:
|
||||
ai_model_id (str): OpenAI model identifier. Must be non-empty.
|
||||
ai_model_id: OpenAI model identifier. Must be non-empty.
|
||||
Default to a preset value.
|
||||
api_key (str): OpenAI API key for authentication.
|
||||
api_key: OpenAI API key for authentication.
|
||||
Must be non-empty. (Optional)
|
||||
org_id (str): OpenAI organization ID. This is optional
|
||||
org_id: OpenAI organization ID. This is optional
|
||||
unless the account belongs to multiple organizations.
|
||||
default_headers (Mapping[str, str]): Default headers
|
||||
default_headers: Default headers
|
||||
for HTTP requests. (Optional)
|
||||
client (AsyncOpenAI): An existing OpenAI client, optional.
|
||||
instruction_role (str): The role to use for 'instruction'
|
||||
client: An existing OpenAI client, optional.
|
||||
instruction_role: The role to use for 'instruction'
|
||||
messages, for example, summarization prompts could use `developer` or `system`. (Optional)
|
||||
base_url: The optional base URL to use. If provided will override the standard value for a OpenAI connector.
|
||||
Will not be used when supplying a custom client.
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
"""
|
||||
@@ -169,11 +172,12 @@ class OpenAIConfigMixin(OpenAIBase):
|
||||
if not client:
|
||||
if not api_key:
|
||||
raise ServiceInitializationError("Please provide an api_key")
|
||||
client = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
organization=org_id,
|
||||
default_headers=merged_headers,
|
||||
)
|
||||
args: dict[str, Any] = {"api_key": api_key, "default_headers": merged_headers}
|
||||
if org_id:
|
||||
args["organization"] = org_id
|
||||
if base_url:
|
||||
args["base_url"] = base_url
|
||||
client = AsyncOpenAI(**args)
|
||||
args = {
|
||||
"ai_model_id": ai_model_id,
|
||||
"client": client,
|
||||
|
||||
@@ -74,6 +74,12 @@ def test_init_with_default_header(openai_unit_test_env: dict[str, str]) -> None:
|
||||
assert open_ai_chat_completion.client.default_headers[key] == value
|
||||
|
||||
|
||||
def test_init_base_url(openai_unit_test_env: dict[str, str]) -> None:
|
||||
# Test successful initialization
|
||||
open_ai_chat_completion = OpenAIChatClient(base_url="http://localhost:1234/v1")
|
||||
assert str(open_ai_chat_completion.client.base_url) == "http://localhost:1234/v1/"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("exclude_list", [["OPENAI_CHAT_MODEL_ID"]], indirect=True)
|
||||
def test_init_with_empty_model_id(openai_unit_test_env: dict[str, str]) -> None:
|
||||
with pytest.raises(ServiceInitializationError):
|
||||
|
||||
@@ -125,7 +125,7 @@ notice-rgx = "^# Copyright \\(c\\) Microsoft\\. All rights reserved\\."
|
||||
min-file-size = 1
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
testpaths = '**/tests'
|
||||
addopts = "-ra -q -r fEX"
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
@@ -133,7 +133,7 @@ filterwarnings = []
|
||||
timeout = 120
|
||||
markers = [
|
||||
"azure: marks tests as Azure provider specific",
|
||||
"foundry: marks tests as Foundry provider specific",
|
||||
"foundry: marks tests as Foundry provider specific",
|
||||
"openai: marks tests as OpenAI provider specific",
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user