More updates to assistant client

This commit is contained in:
Dmytro Struk
2025-07-29 17:57:44 -07:00
Unverified
parent fc2eb3d52b
commit ccdf494e9f
3 changed files with 226 additions and 17 deletions
@@ -96,7 +96,7 @@ class FoundrySettings(AFBaseSettings):
class FoundryChatClient(ChatClientBase):
"""Azure AI Foundry Chat client."""
MODEL_PROVIDER_NAME: ClassVar[str] = "azure_ai_foundry" # type: ignore[reportIncompatibleVariableOverride]
MODEL_PROVIDER_NAME: ClassVar[str] = "azure_ai_foundry" # type: ignore[reportIncompatibleVariableOverride, misc]
client: AIProjectClient = Field(...)
credential: AsyncTokenCredential | None = Field(...)
agent_id: str | None = Field(default=None)
@@ -1,22 +1,44 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import AsyncIterable, Mapping, MutableSequence
import json
from collections.abc import AsyncIterable, Mapping, MutableMapping, MutableSequence
from typing import Any, ClassVar
from openai import AsyncOpenAI
from openai.types.beta.threads import (
ImageURLContentBlockParam,
ImageURLParam,
MessageContentPartParam,
Run,
TextContentBlockParam,
)
from openai.types.beta.threads.run_create_params import AdditionalMessage
from openai.types.beta.threads.run_submit_tool_outputs_params import ToolOutput
from pydantic import Field, PrivateAttr, SecretStr, ValidationError
from .._clients import ChatClientBase, use_tool_calling
from .._types import ChatMessage, ChatOptions, ChatResponse, ChatResponseUpdate, TextContent
from .._clients import ChatClientBase, ai_function_to_json_schema_spec, use_tool_calling
from .._tools import AIFunction, HostedCodeInterpreterTool
from .._types import (
AIContents,
ChatMessage,
ChatOptions,
ChatResponse,
ChatResponseUpdate,
ChatToolMode,
FunctionCallContent,
FunctionResultContent,
TextContent,
UriContent,
)
from ..exceptions import ServiceInitializationError
from ._shared import OpenAIConfigBase, OpenAIHandler, OpenAIModelTypes, OpenAISettings
from ._shared import OpenAIConfigBase, OpenAISettings
@use_tool_calling
class OpenAIAssistantsClient(OpenAIConfigBase, ChatClientBase, OpenAIHandler):
class OpenAIAssistantsClient(OpenAIConfigBase, ChatClientBase):
"""OpenAI Assistants client."""
MODEL_PROVIDER_NAME: ClassVar[str] = "openai" # type: ignore[reportIncompatibleVariableOverride]
MODEL_PROVIDER_NAME: ClassVar[str] = "openai" # type: ignore[reportIncompatibleVariableOverride, misc]
assistant_id: str | None = Field(default=None)
assistant_name: str | None = Field(default=None)
thread_id: str | None = Field(default=None)
@@ -38,24 +60,24 @@ class OpenAIAssistantsClient(OpenAIConfigBase, ChatClientBase, OpenAIHandler):
"""Initialize an OpenAI Assistants client.
Args:
ai_model_id (str): OpenAI model name, see
ai_model_id: OpenAI model name, see
https://platform.openai.com/docs/models
assistant_id (str | None): The ID of an OpenAI assistant to use.
assistant_id: The ID of an OpenAI assistant to use.
If not provided, a new assistant will be created (and deleted after the request).
assistant_name (str | None): The name to use when creating new assistants.
assistant_name: The name to use when creating new assistants.
thread_id: Default thread ID to use for conversations. Can be overridden by
conversation_id property from ChatOptions, when making a request.
If not provided, a new thread will be created (and deleted after the request).
api_key (str | None): The optional API key to use. If provided will override,
api_key: The optional API key to use. If provided will override,
the env vars or .env file value.
org_id (str | None): The optional org ID to use. If provided will override,
org_id: The optional org ID to use. If provided will override,
the env vars or .env file value.
default_headers: The default headers mapping of string keys to
string values for HTTP requests. (Optional)
async_client (Optional[AsyncOpenAI]): An existing client to use. (Optional)
env_file_path (str | None): Use the environment settings file as a fallback
async_client: An existing client to use. (Optional)
env_file_path: Use the environment settings file as a fallback
to environment variables. (Optional)
env_file_encoding (str | None): The encoding of the environment settings file. (Optional)
env_file_encoding: The encoding of the environment settings file. (Optional)
"""
try:
openai_settings = OpenAISettings(
@@ -80,7 +102,6 @@ class OpenAIAssistantsClient(OpenAIConfigBase, ChatClientBase, OpenAIHandler):
thread_id=thread_id, # type: ignore[reportCallIssue]
api_key=openai_settings.api_key.get_secret_value() if openai_settings.api_key else None,
org_id=openai_settings.org_id,
ai_model_type=OpenAIModelTypes.ASSISTANT,
default_headers=default_headers,
client=async_client,
)
@@ -103,6 +124,20 @@ class OpenAIAssistantsClient(OpenAIConfigBase, ChatClientBase, OpenAIHandler):
chat_options: ChatOptions,
**kwargs: Any,
) -> AsyncIterable[ChatResponseUpdate]:
# Extract necessary state from messages and options
run_options, tool_results = self._create_run_options(messages, chat_options, **kwargs)
# Get the thread ID
thread_id: str | None = (
chat_options.conversation_id if chat_options.conversation_id is not None else self.thread_id
)
if thread_id is None and tool_results is not None:
raise ValueError("No thread ID was provided, but chat messages includes tool results.")
# Determine which assistant to use and create if needed
assistant_id = await self._get_assistant_id_or_create()
yield ChatResponseUpdate(contents=[TextContent(text="test")])
async def _get_assistant_id_or_create(self) -> str:
@@ -121,3 +156,178 @@ class OpenAIAssistantsClient(OpenAIConfigBase, ChatClientBase, OpenAIHandler):
self._should_delete_assistant = True
return self.assistant_id
# TODO: _create_agent_stream
async def _get_active_thread_run(self, thread_id: str | None) -> Run | None:
"""Get any active run for the given thread."""
if thread_id is None:
return None
async for run in self.client.beta.threads.runs.list(thread_id=thread_id, limit=1, order="desc"): # type: ignore[reportDeprecated]
if run.status not in ["completed", "cancelled", "failed", "expired"]:
return run
return None
async def _prepare_thread(self, thread_id: str | None, thread_run: Run | None, run_options: dict[str, Any]) -> str:
"""Prepare the thread for a new run, creating or cleaning up as needed."""
if thread_id is None:
# No thread ID was provided, so create a new thread.
thread = await self.client.beta.threads.create( # type: ignore[reportDeprecated]
messages=run_options["additional_messages"],
tool_resources=run_options.get("tool_resources"),
metadata=run_options.get("metadata"),
)
run_options["additional_messages"] = []
return thread.id
if thread_run is not None:
# There was an active run; we need to cancel it before starting a new run.
await self.client.beta.threads.runs.cancel(run_id=thread_run.id, thread_id=thread_id) # type: ignore[reportDeprecated]
return thread_id
# TODO: _process_stream_events
# TODO: _process_stream_events_from_iterator
def _create_function_call_contents(self, event_data: Run, response_id: str | None) -> list[AIContents]:
"""Create function call contents from a tool action event."""
contents: list[AIContents] = []
if event_data.required_action is not None:
for tool_call in event_data.required_action.submit_tool_outputs.tool_calls:
call_id = json.dumps([response_id, tool_call.id])
function_name = tool_call.function.name
function_arguments = json.loads(tool_call.function.arguments)
contents.append(FunctionCallContent(call_id=call_id, name=function_name, arguments=function_arguments))
return contents
async def _cleanup_assistant_if_needed(self) -> None:
"""Clean up the assistant if we created it."""
if self._should_delete_assistant and self.assistant_id is not None:
await self.client.beta.assistants.delete(self.assistant_id)
self.assistant_id = None
self._should_delete_assistant = False
def _create_run_options(
self,
messages: MutableSequence[ChatMessage],
chat_options: ChatOptions | None,
**kwargs: Any,
) -> tuple[dict[str, Any], list[FunctionResultContent] | None]:
run_options: dict[str, Any] = {**kwargs}
if chat_options is not None:
run_options["max_completion_tokens"] = chat_options.max_tokens
run_options["model"] = chat_options.ai_model_id
run_options["top_p"] = chat_options.top_p
run_options["temperature"] = chat_options.temperature
run_options["parallel_tool_calls"] = chat_options.allow_multiple_tool_calls
if chat_options.tool_choice is not None:
tool_definitions: list[MutableMapping[str, Any]] = []
if chat_options.tool_choice != "none" and chat_options.tools is not None:
for tool in chat_options.tools:
if isinstance(tool, AIFunction):
tool_definitions.append(ai_function_to_json_schema_spec(tool)) # type: ignore
elif isinstance(tool, HostedCodeInterpreterTool):
tool_definitions.append({"type": "code_interpreter"})
elif isinstance(tool, MutableMapping):
tool_definitions.append(tool)
if len(tool_definitions) > 0:
run_options["tools"] = tool_definitions
if chat_options.tool_choice == "none" or chat_options.tool_choice == "auto":
run_options["tool_choice"] = chat_options.tool_choice
elif (
isinstance(chat_options.tool_choice, ChatToolMode)
and chat_options.tool_choice == "required"
and chat_options.tool_choice.required_function_name is not None
):
run_options["tool_choice"] = {
"type": "function",
"function": {"name": chat_options.tool_choice.required_function_name},
}
if chat_options.response_format is not None:
run_options["response_format"] = {
"type": "json_schema",
"json_schema": chat_options.response_format.model_json_schema(),
}
instructions: list[str] = []
tool_results: list[FunctionResultContent] | None = None
additional_messages: list[AdditionalMessage] | None = None
# System/developer messages are turned into instructions,
# since there is no such message roles in OpenAI Assistants.
# All other messages are added 1:1.
for chat_message in messages:
if chat_message.role.value in ["system", "developer"]:
for text_content in [content for content in chat_message.contents if isinstance(content, TextContent)]:
instructions.append(text_content.text)
continue
message_contents: list[MessageContentPartParam] = []
for content in chat_message.contents:
if isinstance(content, TextContent):
message_contents.append(TextContentBlockParam(type="text", text=content.text))
elif isinstance(content, UriContent) and content.has_top_level_media_type("image"):
message_contents.append(
ImageURLContentBlockParam(type="image_url", image_url=ImageURLParam(url=content.uri))
)
elif isinstance(content, FunctionResultContent):
if tool_results is None:
tool_results = []
tool_results.append(content)
if len(message_contents) > 0:
if additional_messages is None:
additional_messages = []
additional_messages.append(AdditionalMessage(role="assistant", content=message_contents))
if additional_messages is not None:
run_options["additional_messages"] = additional_messages
if len(instructions) > 0:
run_options["instructions"] = "".join(instructions)
return run_options, tool_results
def _convert_function_results_to_tool_output(
self,
tool_results: list[FunctionResultContent] | None,
) -> tuple[str | None, list[ToolOutput] | None]:
run_id: str | None = None
tool_outputs: list[ToolOutput] | None = None
if tool_results:
for function_result_content in tool_results:
# When creating the FunctionCallContent, we created it with a CallId == [runId, callId].
# We need to extract the run ID and ensure that the ToolOutput we send back to Azure
# is only the call ID.
run_and_call_ids: list[str] = json.loads(function_result_content.call_id)
if (
not run_and_call_ids
or len(run_and_call_ids) != 2
or not run_and_call_ids[0]
or not run_and_call_ids[1]
or (run_id is not None and run_id != run_and_call_ids[0])
):
continue
run_id = run_and_call_ids[0]
call_id = run_and_call_ids[1]
if tool_outputs is None:
tool_outputs = []
tool_outputs.append(ToolOutput(tool_call_id=call_id, output=str(function_result_content.result)))
return run_id, tool_outputs
@@ -117,7 +117,6 @@ class OpenAIModelTypes(Enum):
TEXT_TO_SPEECH = "text-to-speech"
REALTIME = "realtime"
RESPONSE = "response"
ASSISTANT = "assistant"
class OpenAIHandler(AFBaseModel, ABC):