Python: OpenAI Connector (#144)

* Initial checkin of openai connector

* add tests

* extensions work

* chat completion client implicitly implementing ChatClient

* remove AIServiceClientBase

* remove PromptExecutionSettings

* consolidate chat completion types

* add integration test

* fix pre-commit check errors

* remove usage statistics from OpenAIHandler

* Update python/extensions/agent-framework-openai/agent_framework/openai/exceptions.py

Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com>

* PR comments

* fix merge

* fix test import

* remove tests for now because they just fail

* Remove fixed TODO

---------

Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com>
This commit is contained in:
peterychang
2025-07-11 03:34:25 -04:00
committed by GitHub
Unverified
parent fef4fd2c18
commit e70401e658
15 changed files with 982 additions and 8 deletions
@@ -2,6 +2,7 @@
import importlib
import importlib.metadata
from typing import Any
try:
__version__ = importlib.metadata.version(__name__)
@@ -10,6 +11,8 @@ except importlib.metadata.PackageNotFoundError:
_IMPORTS = {
"get_logger": "._logging",
"AFBaseModel": "._pydantic",
"AFBaseSettings": "._pydantic",
"Agent": "._agents",
"AgentThread": "._agents",
"AITool": "._tools",
@@ -40,10 +43,12 @@ _IMPORTS = {
"EmbeddingGenerator": "._clients",
"InputGuardrail": ".guard_rails",
"OutputGuardrail": ".guard_rails",
"TextToSpeechOptions": "._types",
"SpeechToTextOptions": "._types",
}
def __getattr__(name: str):
def __getattr__(name: str) -> Any:
if name == "__version__":
return __version__
if name in _IMPORTS:
@@ -53,5 +58,5 @@ def __getattr__(name: str):
raise AttributeError(f"module {__name__} has no attribute {name}")
def __dir__():
def __dir__() -> list[str]:
return [*list(_IMPORTS.keys()), "__version__"]
@@ -4,6 +4,7 @@ from . import __version__ # type: ignore[attr-defined]
from ._agents import Agent, AgentThread
from ._clients import ChatClient, ChatClientBase, EmbeddingGenerator, use_tool_calling
from ._logging import get_logger
from ._pydantic import AFBaseModel, AFBaseSettings
from ._tools import AITool, ai_function
from ._types import (
AIContent,
@@ -20,9 +21,11 @@ from ._types import (
FunctionCallContent,
FunctionResultContent,
GeneratedEmbeddings,
SpeechToTextOptions,
StructuredResponse,
TextContent,
TextReasoningContent,
TextToSpeechOptions,
UriContent,
UsageContent,
UsageDetails,
@@ -30,6 +33,8 @@ from ._types import (
from .guard_rails import InputGuardrail, OutputGuardrail
__all__ = [
"AFBaseModel",
"AFBaseSettings",
"AIContent",
"AIContents",
"AITool",
@@ -52,9 +57,11 @@ __all__ = [
"GeneratedEmbeddings",
"InputGuardrail",
"OutputGuardrail",
"SpeechToTextOptions",
"StructuredResponse",
"TextContent",
"TextReasoningContent",
"TextToSpeechOptions",
"UriContent",
"UsageContent",
"UsageDetails",
@@ -1,10 +1,69 @@
# Copyright (c) Microsoft. All rights reserved.
from pydantic import BaseModel, ConfigDict
from typing import Any, ClassVar, TypeVar
from pydantic import BaseModel, ConfigDict, Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class AFBaseModel(BaseModel):
"""Base class for all pydantic models in the Agent Framework."""
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True, validate_assignment=True)
TSettings = TypeVar("TSettings", bound="AFBaseSettings")
class AFBaseSettings(BaseSettings):
"""Base class for all settings classes in the Agent Framework.
A subclass creates it's fields and overrides the env_prefix class variable
with the prefix for the environment variables.
In the case where a value is specified for the same Settings field in multiple ways,
the selected value is determined as follows (in descending order of priority):
- Arguments passed to the Settings class initializer.
- Environment variables, e.g. my_prefix_special_function as described above.
- Variables loaded from a dotenv (.env) file.
- Variables loaded from the secrets directory.
- The default field values for the Settings model.
"""
env_prefix: ClassVar[str] = ""
env_file_path: str | None = Field(default=None, exclude=True)
env_file_encoding: str | None = Field(default="utf-8", exclude=True)
model_config = SettingsConfigDict(
extra="ignore",
case_sensitive=False,
)
def __init__(
self,
**kwargs: Any,
) -> None:
"""Initialize the settings class."""
# Remove any None values from the kwargs so that defaults are used.
kwargs = {k: v for k, v in kwargs.items() if v is not None}
super().__init__(**kwargs)
def __new__(cls: type["TSettings"], *args: Any, **kwargs: Any) -> "TSettings":
"""Override the __new__ method to set the env_prefix."""
# for both, if supplied but None, set to default
if "env_file_encoding" in kwargs and kwargs["env_file_encoding"] is not None:
env_file_encoding = kwargs["env_file_encoding"]
else:
env_file_encoding = "utf-8"
if "env_file_path" in kwargs and kwargs["env_file_path"] is not None:
env_file_path = kwargs["env_file_path"]
else:
env_file_path = ".env"
cls.model_config.update( # type: ignore
env_prefix=cls.env_prefix,
env_file=env_file_path,
env_file_encoding=env_file_encoding,
)
cls.model_rebuild()
return super().__new__(cls) # type: ignore[return-value]
@@ -1552,3 +1552,76 @@ class GeneratedEmbeddings(AFBaseModel, MutableSequence[TEmbedding], Generic[TEmb
else:
self.embeddings += values
return self
# region: SpeechToTextOptions
class SpeechToTextOptions(AFBaseModel):
"""Common request settings for Speech to Text AI services."""
ai_model_id: Annotated[str | None, Field(serialization_alias="model")] = None
speech_language: Annotated[str | None, Field(description="Language of the input speech.")] = None
text_language: Annotated[str | None, Field(description="Language of the output text.")] = None
speech_sample_rate: Annotated[int | None, Field(description="Sample rate of the input speech.")] = None
additional_properties: dict[str, Any] = Field(
default_factory=dict, description="Provider-specific additional properties."
)
def to_provider_settings(self, by_alias: bool = True, exclude: set[str] | None = None) -> dict[str, Any]:
"""Convert the SpeechToTextOptions to a dictionary suitable for provider requests.
Args:
by_alias: Use alias names for fields if True.
exclude: Additional keys to exclude from the output.
Returns:
Dictionary of settings for provider.
"""
default_exclude = {"additional_properties"}
merged_exclude = default_exclude if exclude is None else default_exclude | set(exclude)
settings: dict[str, Any] = self.model_dump(exclude_none=True, by_alias=by_alias, exclude=merged_exclude)
settings = {k: v for k, v in settings.items() if not (isinstance(v, dict) and not v)}
settings.update(self.additional_properties)
for key in merged_exclude:
settings.pop(key, None)
return settings
# region: TextToSpeechOptions
class TextToSpeechOptions(AFBaseModel):
"""Request settings for text to speech services.
Tailor this to be more general as more models (aside from OpenAI) are added.
"""
ai_model_id: str | None = Field(None, serialization_alias="model")
voice: Literal["alloy", "ash", "ballad", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"] = "alloy"
response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] | None = None
speed: Annotated[float | None, Field(ge=0.25, le=4.0)] = None
additional_properties: dict[str, Any] = Field(
default_factory=dict, description="Provider-specific additional properties."
)
def to_provider_settings(self, by_alias: bool = True, exclude: set[str] | None = None) -> dict[str, Any]:
"""Convert the SpeechToTextOptions to a dictionary suitable for provider requests.
Args:
by_alias: Use alias names for fields if True.
exclude: Additional keys to exclude from the output.
Returns:
Dictionary of settings for provider.
"""
default_exclude = {"additional_properties"}
merged_exclude = default_exclude if exclude is None else default_exclude | set(exclude)
settings: dict[str, Any] = self.model_dump(exclude_none=True, by_alias=by_alias, exclude=merged_exclude)
settings = {k: v for k, v in settings.items() if not (isinstance(v, dict) and not v)}
settings.update(self.additional_properties)
for key in merged_exclude:
settings.pop(key, None)
return settings
@@ -0,0 +1,5 @@
# Copyright (c) Microsoft. All rights reserved.
from typing import Final
USER_AGENT: Final[str] = "User-Agent"
@@ -17,3 +17,48 @@ class AgentExecutionException(AgentException):
"""An error occurred while executing the agent."""
pass
# region: Service Exceptions
class ServiceException(AgentFrameworkException):
"""Base class for all service exceptions."""
pass
class ServiceInitializationError(ServiceException):
"""An error occurred while initializing the service."""
pass
class ServiceResponseException(ServiceException):
"""Base class for all service response exceptions."""
pass
class ServiceContentFilterException(ServiceResponseException):
"""An error was raised by the content filter of the service."""
pass
class ServiceInvalidExecutionSettingsError(ServiceResponseException):
"""An error occurred while validating the execution settings of the service."""
pass
class ServiceInvalidRequestError(ServiceResponseException):
"""An error occurred while validating the request to the service."""
pass
class ServiceInvalidResponseError(ServiceResponseException):
"""An error occurred while validating the response from the service."""
pass
@@ -2,9 +2,15 @@
import importlib.metadata
from ._chat_completion import OpenAIChatCompletion, OpenAIChatCompletionBase
try:
__version__ = importlib.metadata.version(__name__)
except importlib.metadata.PackageNotFoundError:
__version__ = "0.0.0" # Fallback for development mode
__all__ = ["__version__"]
__all__ = [
"OpenAIChatCompletion",
"OpenAIChatCompletionBase",
"__version__",
]
@@ -0,0 +1,313 @@
# Copyright (c) Microsoft. All rights reserved.
import json
from collections.abc import AsyncIterable, Mapping, MutableSequence, Sequence
from datetime import datetime
from typing import Any, ClassVar, cast
from agent_framework.exceptions import ServiceInitializationError, ServiceInvalidResponseError
from pydantic import SecretStr, ValidationError
from agent_framework import (
ChatClientBase,
ChatFinishReason,
ChatMessage,
ChatOptions,
ChatResponse,
ChatResponseUpdate,
ChatRole,
FunctionCallContent,
TextContent,
UsageDetails,
)
from openai import AsyncOpenAI, AsyncStream
from openai.types import CompletionUsage
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDeltaToolCall
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
from ._openai_config_base import OpenAIConfigBase
from ._openai_handler import OpenAIHandler
from ._openai_model_types import OpenAIModelTypes
from ._openai_settings import OpenAISettings
# Implements agent_framework.ChatClient protocol
class OpenAIChatCompletionBase(OpenAIHandler, ChatClientBase):
"""OpenAI Chat completion class."""
MODEL_PROVIDER_NAME: ClassVar[str] = "openai"
SUPPORTS_FUNCTION_CALLING: ClassVar[bool] = True
# region Overriding base class methods
# most of the methods are overridden from the ChatCompletionClientBase class, otherwise it is mentioned
async def _inner_get_response(
self,
*,
messages: MutableSequence[ChatMessage],
chat_options: ChatOptions,
**kwargs: Any,
) -> ChatResponse:
# TODO(peterychang): Is there a better way to handle this?
chat_options.additional_properties = dict(chat_options.additional_properties)
chat_options.additional_properties.update({"stream": False})
chat_options.ai_model_id = chat_options.ai_model_id or self.ai_model_id
response = await self._send_request(chat_options, messages=self._prepare_chat_history_for_request(messages))
assert isinstance(response, ChatCompletion) # nosec # noqa: S101
response_metadata = self._get_metadata_from_chat_response(response)
return next(
self._create_chat_message_content(response, choice, response_metadata) for choice in response.choices
)
# @trace_streaming_chat_completion(MODEL_PROVIDER_NAME)
async def _inner_get_streaming_response(
self,
*,
messages: MutableSequence[ChatMessage],
chat_options: ChatOptions,
**kwargs: Any,
) -> AsyncIterable[ChatResponseUpdate]:
# TODO(peterychang): Is there a better way to handle this?
chat_options.additional_properties = dict(chat_options.additional_properties)
chat_options.additional_properties.update({"stream": True, "stream_options": {"include_usage": True}})
chat_options.ai_model_id = chat_options.ai_model_id or self.ai_model_id
response = await self._send_request(chat_options, messages=self._prepare_chat_history_for_request(messages))
if not isinstance(response, AsyncStream):
raise ServiceInvalidResponseError("Expected an AsyncStream[ChatCompletionChunk] response.")
async for chunk in response:
if len(chunk.choices) == 0 and chunk.usage is None:
continue
assert isinstance(chunk, ChatCompletionChunk) # nosec # noqa: S101
chunk_metadata = self._get_metadata_from_streaming_chat_response(chunk)
if chunk.usage is not None:
# Usage is contained in the last chunk where the choices are empty
# We are duplicating the usage metadata to all the choices in the response
yield ChatResponseUpdate(
role=ChatRole.ASSISTANT,
contents=[],
ai_model_id=chat_options.ai_model_id,
additional_properties=chunk_metadata,
)
else:
yield next(
self._create_streaming_chat_message_content(chunk, choice, chunk_metadata)
for choice in chunk.choices
)
# endregion
# region content creation
def _create_chat_message_content(
self, response: ChatCompletion, choice: Choice, response_metadata: dict[str, Any]
) -> "ChatResponse":
"""Create a chat message content object from a choice."""
metadata = self._get_metadata_from_chat_choice(choice)
metadata.update(response_metadata)
items: list[ChatMessage] = [
ChatMessage(role="assistant", contents=[tool]) for tool in self._get_tool_calls_from_chat_choice(choice)
]
if choice.message.content:
items.append(ChatMessage(role="assistant", text=choice.message.content))
elif hasattr(choice.message, "refusal") and choice.message.refusal:
items.append(ChatMessage(role="assistant", text=choice.message.refusal))
return ChatResponse(
response_id=response.id,
created_at=datetime.fromtimestamp(response.created).strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
usage_details=self._usage_details_from_openai(response.usage) if response.usage else None,
messages=items,
model_id=self.ai_model_id,
additional_properties=metadata,
finish_reason=(ChatFinishReason(value=choice.finish_reason) if choice.finish_reason else None),
)
def _create_streaming_chat_message_content(
self,
chunk: ChatCompletionChunk,
choice: ChunkChoice,
chunk_metadata: dict[str, Any],
) -> ChatResponseUpdate:
"""Create a streaming chat message content object from a choice."""
metadata = self._get_metadata_from_chat_choice(choice)
metadata.update(chunk_metadata)
items: list[Any] = self._get_tool_calls_from_chat_choice(choice)
if choice.delta and choice.delta.content is not None:
items.append(TextContent(text=choice.delta.content))
return ChatResponseUpdate(
created_at=datetime.fromtimestamp(chunk.created).strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
contents=items,
role=ChatRole.ASSISTANT,
ai_model_id=self.ai_model_id,
additional_properties=metadata,
finish_reason=(ChatFinishReason(value=choice.finish_reason) if choice.finish_reason else None),
)
def _usage_details_from_openai(self, usage: CompletionUsage) -> UsageDetails | None:
return UsageDetails(
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens,
total_tokens=usage.total_tokens,
)
def _get_metadata_from_chat_response(self, response: ChatCompletion) -> dict[str, Any]:
"""Get metadata from a chat response."""
return {
"system_fingerprint": response.system_fingerprint,
}
def _get_metadata_from_streaming_chat_response(self, response: ChatCompletionChunk) -> dict[str, Any]:
"""Get metadata from a streaming chat response."""
return {
"system_fingerprint": response.system_fingerprint,
}
def _get_metadata_from_chat_choice(self, choice: Choice | ChunkChoice) -> dict[str, Any]:
"""Get metadata from a chat choice."""
return {
"logprobs": getattr(choice, "logprobs", None),
}
def _get_tool_calls_from_chat_choice(self, choice: Choice | ChunkChoice) -> list[FunctionCallContent]:
"""Get tool calls from a chat choice."""
resp: list[FunctionCallContent] = []
content = choice.message if isinstance(choice, Choice) else choice.delta
if content and (tool_calls := getattr(content, "tool_calls", None)) is not None:
for tool in cast(list[ChatCompletionMessageToolCall] | list[ChoiceDeltaToolCall], tool_calls):
if tool.function:
fcc = FunctionCallContent(
call_id=tool.id if tool.id else "",
name=tool.function.name if tool.function and tool.function.name else "",
arguments=json.loads(tool.function.arguments)
if tool.function and tool.function.arguments
else {},
)
resp.append(fcc)
# When you enable asynchronous content filtering in Azure OpenAI, you may receive empty deltas
return resp
def _prepare_chat_history_for_request(
self,
chat_history: ChatMessage | Sequence[ChatMessage],
role_key: str = "role",
content_key: str = "content",
) -> list[dict[str, Any]]:
"""Prepare the chat history for a request.
Allowing customization of the key names for role/author, and optionally overriding the role.
ChatRole.TOOL messages need to be formatted different than system/user/assistant messages:
They require a "tool_call_id" and (function) "name" key, and the "metadata" key should
be removed. The "encoding" key should also be removed.
Override this method to customize the formatting of the chat history for a request.
Args:
chat_history (list[ChatMessage]): The chat history to prepare.
role_key (str): The key name for the role/author.
content_key (str): The key name for the content/message.
Returns:
prepared_chat_history (Any): The prepared chat history for a request.
"""
# TODO(peterychang): Chat history type is not finalized yet
if not isinstance(chat_history, Sequence):
chat_history = [chat_history]
# TODO(peterychang): This is the bare minimum to get the chat history into a format that OpenAI expects.
return [
{
"role": message.role.value if isinstance(message.role, ChatRole) else message.role,
"content": [content.model_dump() for content in message.contents],
"metadata": message.additional_properties or {},
}
for message in chat_history
]
# endregion
class OpenAIChatCompletion(OpenAIConfigBase, OpenAIChatCompletionBase):
"""OpenAI Chat completion class."""
def __init__(
self,
ai_model_id: str | None = None,
api_key: str | None = None,
org_id: str | None = None,
default_headers: Mapping[str, str] | None = None,
async_client: AsyncOpenAI | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
instruction_role: str | None = None,
) -> None:
"""Initialize an OpenAIChatCompletion service.
Args:
ai_model_id (str): OpenAI model name, see
https://platform.openai.com/docs/models
api_key (str | None): The optional API key to use. If provided will override,
the env vars or .env file value.
org_id (str | None): The optional org ID to use. If provided will override,
the env vars or .env file value.
default_headers: The default headers mapping of string keys to
string values for HTTP requests. (Optional)
async_client (Optional[AsyncOpenAI]): An existing client to use. (Optional)
env_file_path (str | None): Use the environment settings file as a fallback
to environment variables. (Optional)
env_file_encoding (str | None): The encoding of the environment settings file. (Optional)
instruction_role (str | None): The role to use for 'instruction' messages, for example,
"""
try:
if api_key:
openai_settings = OpenAISettings(
api_key=SecretStr(api_key),
org_id=org_id,
chat_model_id=ai_model_id,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
else:
openai_settings = OpenAISettings(
org_id=org_id,
chat_model_id=ai_model_id,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise ServiceInitializationError("Failed to create OpenAI settings.", ex) from ex
if not async_client and not openai_settings.api_key:
raise ServiceInitializationError("The OpenAI API key is required.")
if not openai_settings.chat_model_id:
raise ServiceInitializationError("The OpenAI model ID is required.")
super().__init__(
ai_model_id=openai_settings.chat_model_id,
api_key=openai_settings.api_key.get_secret_value() if openai_settings.api_key else None,
org_id=openai_settings.org_id,
ai_model_type=OpenAIModelTypes.CHAT,
default_headers=default_headers,
client=async_client,
instruction_role=instruction_role,
)
@classmethod
def from_dict(cls, settings: dict[str, Any]) -> "OpenAIChatCompletion":
"""Initialize an Open AI service from a dictionary of settings.
Args:
settings: A dictionary of settings for the service.
"""
return OpenAIChatCompletion(
ai_model_id=settings["ai_model_id"],
default_headers=settings.get("default_headers"),
)
@@ -0,0 +1,97 @@
# Copyright (c) Microsoft. All rights reserved.
import logging
from collections.abc import Mapping
from copy import copy
from typing import Any
from agent_framework.exceptions import ServiceInitializationError
from pydantic import ConfigDict, Field, validate_call
from openai import AsyncOpenAI
from ._openai_handler import OpenAIHandler
from ._openai_model_types import OpenAIModelTypes
logger: logging.Logger = logging.getLogger(__name__)
class OpenAIConfigBase(OpenAIHandler):
"""Internal class for configuring a connection to an OpenAI service."""
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def __init__(
self,
ai_model_id: str = Field(min_length=1),
api_key: str | None = Field(min_length=1),
ai_model_type: OpenAIModelTypes | None = OpenAIModelTypes.CHAT,
org_id: str | None = None,
default_headers: Mapping[str, str] | None = None,
client: AsyncOpenAI | None = None,
instruction_role: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize a client for OpenAI services.
This constructor sets up a client to interact with OpenAI's API, allowing for
different types of AI model interactions, like chat or text completion.
Args:
ai_model_id (str): OpenAI model identifier. Must be non-empty.
Default to a preset value.
api_key (str): OpenAI API key for authentication.
Must be non-empty. (Optional)
ai_model_type (OpenAIModelTypes): The type of OpenAI
model to interact with. Defaults to CHAT.
org_id (str): OpenAI organization ID. This is optional
unless the account belongs to multiple organizations.
default_headers (Mapping[str, str]): Default headers
for HTTP requests. (Optional)
client (AsyncOpenAI): An existing OpenAI client, optional.
instruction_role (str): The role to use for 'instruction'
messages, for example, summarization prompts could use `developer` or `system`. (Optional)
kwargs: Additional keyword arguments.
"""
# Merge APP_INFO into the headers if it exists
merged_headers = dict(copy(default_headers)) if default_headers else {}
if not client:
if not api_key:
raise ServiceInitializationError("Please provide an api_key")
client = AsyncOpenAI(
api_key=api_key,
organization=org_id,
default_headers=merged_headers,
)
args = {
"ai_model_id": ai_model_id,
"client": client,
"ai_model_type": ai_model_type,
}
if instruction_role:
args["instruction_role"] = instruction_role
super().__init__(**args, **kwargs)
def to_dict(self) -> dict[str, Any]:
"""Create a dict of the service settings."""
client_settings = {
"api_key": self.client.api_key,
"default_headers": {k: v for k, v in self.client.default_headers.items() if k != "User-Agent"},
}
if self.client.organization:
client_settings["org_id"] = self.client.organization
base = self.model_dump(
exclude={
"prompt_tokens",
"completion_tokens",
"total_tokens",
"api_type",
"ai_model_type",
"client",
},
by_alias=True,
exclude_none=True,
)
base.update(client_settings)
return base
@@ -0,0 +1,148 @@
# Copyright (c) Microsoft. All rights reserved.
import logging
from abc import ABC
from typing import Annotated, Any, Union
from agent_framework.exceptions import ServiceInvalidRequestError, ServiceResponseException
from pydantic import BaseModel
from pydantic.types import StringConstraints
from agent_framework import AFBaseModel, ChatOptions, SpeechToTextOptions, TextToSpeechOptions
from openai import (
AsyncOpenAI,
AsyncStream,
BadRequestError,
_legacy_response, # type: ignore
)
from openai.lib._parsing._completions import type_to_response_format_param
from openai.types import Completion
from openai.types.audio import Transcription
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.images_response import ImagesResponse
from ._openai_model_types import OpenAIModelTypes
from .exceptions import OpenAIContentFilterException
logger: logging.Logger = logging.getLogger(__name__)
RESPONSE_TYPE = Union[
ChatCompletion,
Completion,
AsyncStream[ChatCompletionChunk],
AsyncStream[Completion],
list[Any],
ImagesResponse,
Transcription,
_legacy_response.HttpxBinaryResponseContent,
]
# TODO(evmattso): update with proper Options types to move away from ExecutionSettings
OPTION_TYPE = Union[
ChatOptions,
SpeechToTextOptions,
TextToSpeechOptions,
]
class OpenAIHandler(AFBaseModel, ABC):
"""Internal class for calls to OpenAI API's."""
client: AsyncOpenAI
ai_model_id: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1)]
ai_model_type: OpenAIModelTypes = OpenAIModelTypes.CHAT
async def _send_request(self, options: OPTION_TYPE, messages: list[dict[str, Any]] | None = None) -> RESPONSE_TYPE:
"""Send a request to the OpenAI API."""
if self.ai_model_type == OpenAIModelTypes.CHAT:
assert isinstance(options, ChatOptions) # nosec # noqa: S101
return await self._send_completion_request(options, messages)
# TODO(evmattso): move other PromptExecutionSettings to a common options class
if self.ai_model_type == OpenAIModelTypes.EMBEDDING:
raise NotImplementedError("Embedding generation is not yet implemented in OpenAIHandler")
if self.ai_model_type == OpenAIModelTypes.TEXT_TO_IMAGE:
raise NotImplementedError("Text to image generation is not yet implemented in OpenAIHandler")
if self.ai_model_type == OpenAIModelTypes.SPEECH_TO_TEXT:
assert isinstance(options, SpeechToTextOptions) # nosec # noqa: S101
return await self._send_audio_to_text_request(options)
if self.ai_model_type == OpenAIModelTypes.TEXT_TO_SPEECH:
assert isinstance(options, TextToSpeechOptions) # nosec # noqa: S101
return await self._send_text_to_audio_request(options)
raise NotImplementedError(f"Model type {self.ai_model_type} is not supported")
async def _send_completion_request(
self,
chat_options: "ChatOptions",
messages: list[dict[str, Any]] | None = None,
) -> ChatCompletion | Completion | AsyncStream[ChatCompletionChunk] | AsyncStream[Completion]:
"""Execute the appropriate call to OpenAI models."""
try:
options_dict = chat_options.to_provider_settings()
if messages is not None:
options_dict["messages"] = messages
if self.ai_model_type == OpenAIModelTypes.CHAT:
self._handle_structured_outputs(chat_options, options_dict)
if chat_options.tools is None:
options_dict.pop("parallel_tool_calls", None)
response = await self.client.chat.completions.create(**options_dict) # type: ignore
else:
response = await self.client.completions.create(**options_dict) # type: ignore
assert isinstance(response, (ChatCompletion, Completion, AsyncStream)) # nosec # noqa: S101
return response # type: ignore
except BadRequestError as ex:
if ex.code == "content_filter":
raise OpenAIContentFilterException(
f"{type(self)} service encountered a content error",
ex,
) from ex
raise ServiceResponseException(
f"{type(self)} service failed to complete the prompt",
ex,
) from ex
except Exception as ex:
raise ServiceResponseException(
f"{type(self)} service failed to complete the prompt",
ex,
) from ex
async def _send_audio_to_text_request(self, options: SpeechToTextOptions) -> Transcription:
"""Send a request to the OpenAI audio to text endpoint."""
if not options.additional_properties["filename"]:
raise ServiceInvalidRequestError("Audio file is required for audio to text service")
try:
# TODO(peterychang): open isn't async safe
with open(options.additional_properties["filename"], "rb") as audio_file: # noqa: ASYNC230
return await self.client.audio.transcriptions.create(
file=audio_file,
**options.to_provider_settings(exclude={"filename"}),
) # type: ignore
except Exception as ex:
raise ServiceResponseException(
f"{type(self)} service failed to transcribe audio",
ex,
) from ex
async def _send_text_to_audio_request(
self, options: TextToSpeechOptions
) -> _legacy_response.HttpxBinaryResponseContent:
"""Send a request to the OpenAI text to audio endpoint.
The OpenAI API returns the content of the generated audio file.
"""
try:
return await self.client.audio.speech.create(
**options.to_provider_settings(),
)
except Exception as ex:
raise ServiceResponseException(
f"{type(self)} service failed to generate audio",
ex,
) from ex
def _handle_structured_outputs(self, chat_options: "ChatOptions", options_dict: dict[str, Any]) -> None:
response_format = getattr(chat_options, "response_format", None)
if response_format and isinstance(response_format, type) and issubclass(response_format, BaseModel):
options_dict["response_format"] = type_to_response_format_param(response_format)
@@ -0,0 +1,15 @@
# Copyright (c) Microsoft. All rights reserved.
from enum import Enum
class OpenAIModelTypes(Enum):
"""OpenAI model types, can be text, chat or embedding."""
CHAT = "chat"
EMBEDDING = "embedding"
TEXT_TO_IMAGE = "text-to-image"
SPEECH_TO_TEXT = "speech-to-text"
TEXT_TO_SPEECH = "text-to-speech"
REALTIME = "realtime"
RESPONSE = "response"
@@ -0,0 +1,54 @@
# Copyright (c) Microsoft. All rights reserved.
from typing import ClassVar
from pydantic import SecretStr
from agent_framework import AFBaseSettings
class OpenAISettings(AFBaseSettings):
"""OpenAI model settings.
The settings are first loaded from environment variables with the prefix 'OPENAI_'.
If the environment variables are not found, the settings can be loaded from a .env file with the
encoding 'utf-8'. If the settings are not found in the .env file, the settings are ignored;
however, validation will fail alerting that the settings are missing.
Optional settings for prefix 'OPENAI_' are:
- api_key: SecretStr - OpenAI API key, see https://platform.openai.com/account/api-keys
(Env var OPENAI_API_KEY)
- org_id: str | None - This is usually optional unless your account belongs to multiple organizations.
(Env var OPENAI_ORG_ID)
- chat_model_id: str | None - The OpenAI chat model ID to use, for example, gpt-3.5-turbo or gpt-4.
(Env var OPENAI_CHAT_MODEL_ID)
- responses_model_id: str | None - The OpenAI responses model ID to use, for example, gpt-4o or o1.
(Env var OPENAI_RESPONSES_MODEL_ID)
- text_model_id: str | None - The OpenAI text model ID to use, for example, gpt-3.5-turbo-instruct.
(Env var OPENAI_TEXT_MODEL_ID)
- embedding_model_id: str | None - The OpenAI embedding model ID to use, for example, text-embedding-ada-002.
(Env var OPENAI_EMBEDDING_MODEL_ID)
- text_to_image_model_id: str | None - The OpenAI text to image model ID to use, for example, dall-e-3.
(Env var OPENAI_TEXT_TO_IMAGE_MODEL_ID)
- audio_to_text_model_id: str | None - The OpenAI audio to text model ID to use, for example, whisper-1.
(Env var OPENAI_AUDIO_TO_TEXT_MODEL_ID)
- text_to_audio_model_id: str | None - The OpenAI text to audio model ID to use, for example, jukebox-1.
(Env var OPENAI_TEXT_TO_AUDIO_MODEL_ID)
- realtime_model_id: str | None - The OpenAI realtime model ID to use,
for example, gpt-4o-realtime-preview-2024-12-17.
(Env var OPENAI_REALTIME_MODEL_ID)
- env_file_path: str | None - if provided, the .env settings are read from this file path location
"""
env_prefix: ClassVar[str] = "OPENAI_"
api_key: SecretStr | None = None
org_id: str | None = None
chat_model_id: str | None = None
responses_model_id: str | None = None
text_model_id: str | None = None
embedding_model_id: str | None = None
text_to_image_model_id: str | None = None
audio_to_text_model_id: str | None = None
text_to_audio_model_id: str | None = None
realtime_model_id: str | None = None
@@ -0,0 +1,90 @@
# Copyright (c) Microsoft. All rights reserved.
from dataclasses import dataclass
from enum import Enum
from typing import Any
from agent_framework.exceptions import ServiceContentFilterException
from openai import BadRequestError
class ContentFilterResultSeverity(Enum):
"""The severity of the content filter result."""
HIGH = "high"
MEDIUM = "medium"
SAFE = "safe"
LOW = "low"
@dataclass
class ContentFilterResult:
"""The result of a content filter check."""
filtered: bool = False
detected: bool = False
severity: ContentFilterResultSeverity = ContentFilterResultSeverity.SAFE
@classmethod
def from_inner_error_result(cls, inner_error_results: dict[str, Any]) -> "ContentFilterResult":
"""Creates a ContentFilterResult from the inner error results.
Args:
key (str): The key to get the inner error result from.
inner_error_results (Dict[str, Any]): The inner error results.
Returns:
ContentFilterResult: The ContentFilterResult.
"""
return cls(
filtered=inner_error_results.get("filtered", False),
detected=inner_error_results.get("detected", False),
severity=ContentFilterResultSeverity(
inner_error_results.get("severity", ContentFilterResultSeverity.SAFE.value)
),
)
class ContentFilterCodes(Enum):
"""Content filter codes."""
RESPONSIBLE_AI_POLICY_VIOLATION = "ResponsibleAIPolicyViolation"
@dataclass
class OpenAIContentFilterException(ServiceContentFilterException):
"""AI exception for an error from Azure OpenAI's content filter."""
# The parameter that caused the error.
param: str | None
# The error code specific to the content filter.
content_filter_code: ContentFilterCodes
# The results of the different content filter checks.
content_filter_result: dict[str, ContentFilterResult]
def __init__(
self,
message: str,
inner_exception: BadRequestError,
) -> None:
"""Initializes a new instance of the ContentFilterAIException class.
Args:
message (str): The error message.
inner_exception (Exception): The inner exception.
"""
super().__init__(message)
self.param = inner_exception.param
if inner_exception.body is not None and isinstance(inner_exception.body, dict):
inner_error = inner_exception.body.get("innererror", {}) # type: ignore
self.content_filter_code = ContentFilterCodes(
inner_error.get("code", ContentFilterCodes.RESPONSIBLE_AI_POLICY_VIOLATION.value) # type: ignore
)
self.content_filter_result = {
key: ContentFilterResult.from_inner_error_result(values) # type: ignore
for key, values in inner_error.get("content_filter_result", {}).items() # type: ignore
}
+57
View File
@@ -0,0 +1,57 @@
# Copyright (c) Microsoft. All rights reserved.
from typing import Any
from pytest import fixture
from agent_framework import ChatMessage
# region: Connector Settings fixtures
@fixture
def exclude_list(request: Any) -> list[str]:
"""Fixture that returns a list of environment variables to exclude."""
return request.param if hasattr(request, "param") else []
@fixture
def override_env_param_dict(request: Any) -> dict[str, str]:
"""Fixture that returns a dict of environment variables to override."""
return request.param if hasattr(request, "param") else {}
@fixture()
def openai_unit_test_env(monkeypatch, exclude_list, override_env_param_dict): # type: ignore
"""Fixture to set environment variables for OpenAISettings."""
if exclude_list is None:
exclude_list = []
if override_env_param_dict is None:
override_env_param_dict = {}
env_vars = {
"OPENAI_API_KEY": "test_api_key",
"OPENAI_ORG_ID": "test_org_id",
"OPENAI_RESPONSES_MODEL_ID": "test_responses_model_id",
"OPENAI_CHAT_MODEL_ID": "test_chat_model_id",
"OPENAI_TEXT_MODEL_ID": "test_text_model_id",
"OPENAI_EMBEDDING_MODEL_ID": "test_embedding_model_id",
"OPENAI_TEXT_TO_IMAGE_MODEL_ID": "test_text_to_image_model_id",
"OPENAI_AUDIO_TO_TEXT_MODEL_ID": "test_audio_to_text_model_id",
"OPENAI_TEXT_TO_AUDIO_MODEL_ID": "test_text_to_audio_model_id",
"OPENAI_REALTIME_MODEL_ID": "test_realtime_model_id",
}
env_vars.update(override_env_param_dict) # type: ignore
for key, value in env_vars.items():
if key not in exclude_list:
monkeypatch.setenv(key, value) # type: ignore
else:
monkeypatch.delenv(key, raising=False) # type: ignore
return env_vars
@fixture(scope="function")
def chat_history() -> list["ChatMessage"]:
return []
+4 -4
View File
@@ -639,7 +639,7 @@ name = "exceptiongroup"
version = "1.3.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "typing-extensions", marker = "(python_full_version < '3.13' and sys_platform == 'darwin') or (python_full_version < '3.13' and sys_platform == 'linux') or (python_full_version < '3.13' and sys_platform == 'win32')" },
{ name = "typing-extensions", marker = "(python_full_version < '3.11' and sys_platform == 'darwin') or (python_full_version < '3.11' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform == 'win32')" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" }
wheels = [
@@ -1353,7 +1353,7 @@ wheels = [
[[package]]
name = "openai"
version = "1.93.3"
version = "1.94.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
@@ -1365,9 +1365,9 @@ dependencies = [
{ name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/e0/66/fadc0cad6a229c6a85c3aa5f222a786ec4d9bf14c2a004f80ffa21dbaf21/openai-1.93.3.tar.gz", hash = "sha256:488b76399238c694af7e4e30c58170ea55e6f65038ab27dbe95b5077a00f8af8", size = 487595, upload-time = "2025-07-09T14:08:27.789Z" }
sdist = { url = "https://files.pythonhosted.org/packages/c9/7e/2e36eb5d2e9a18028ee66f2e553c6392ae1775ef9f6aa11f15f1074c7e98/openai-1.94.0.tar.gz", hash = "sha256:31c6c213cc80365d54632296c4aef7cda1800003ca5c784ac50a05d6bc05c197", size = 487682, upload-time = "2025-07-10T14:21:08.686Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/8b/b9/0df6351b25c6bd494c534d2a8191dc9460fb5bb09c88b1427775d49fde05/openai-1.93.3-py3-none-any.whl", hash = "sha256:41aaa7594c7d141b46eed0a58dcd75d20edcc809fdd2c931ecbb4957dc98a892", size = 755132, upload-time = "2025-07-09T14:08:25.533Z" },
{ url = "https://files.pythonhosted.org/packages/b7/93/a20d43aa9c6d8b1b1f2a9262da6180b4420ff71fae2e5d14e496022cfe66/openai-1.94.0-py3-none-any.whl", hash = "sha256:159c43b811669abe9bb4aafdc57a049966dfde2eac94b151aac3eb63bf9825b4", size = 755167, upload-time = "2025-07-10T14:21:06.974Z" },
]
[[package]]