mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
More updates to assistant client
This commit is contained in:
@@ -96,7 +96,7 @@ class FoundrySettings(AFBaseSettings):
|
||||
class FoundryChatClient(ChatClientBase):
|
||||
"""Azure AI Foundry Chat client."""
|
||||
|
||||
MODEL_PROVIDER_NAME: ClassVar[str] = "azure_ai_foundry" # type: ignore[reportIncompatibleVariableOverride]
|
||||
MODEL_PROVIDER_NAME: ClassVar[str] = "azure_ai_foundry" # type: ignore[reportIncompatibleVariableOverride, misc]
|
||||
client: AIProjectClient = Field(...)
|
||||
credential: AsyncTokenCredential | None = Field(...)
|
||||
agent_id: str | None = Field(default=None)
|
||||
|
||||
@@ -1,22 +1,44 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import AsyncIterable, Mapping, MutableSequence
|
||||
import json
|
||||
from collections.abc import AsyncIterable, Mapping, MutableMapping, MutableSequence
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.beta.threads import (
|
||||
ImageURLContentBlockParam,
|
||||
ImageURLParam,
|
||||
MessageContentPartParam,
|
||||
Run,
|
||||
TextContentBlockParam,
|
||||
)
|
||||
from openai.types.beta.threads.run_create_params import AdditionalMessage
|
||||
from openai.types.beta.threads.run_submit_tool_outputs_params import ToolOutput
|
||||
from pydantic import Field, PrivateAttr, SecretStr, ValidationError
|
||||
|
||||
from .._clients import ChatClientBase, use_tool_calling
|
||||
from .._types import ChatMessage, ChatOptions, ChatResponse, ChatResponseUpdate, TextContent
|
||||
from .._clients import ChatClientBase, ai_function_to_json_schema_spec, use_tool_calling
|
||||
from .._tools import AIFunction, HostedCodeInterpreterTool
|
||||
from .._types import (
|
||||
AIContents,
|
||||
ChatMessage,
|
||||
ChatOptions,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
ChatToolMode,
|
||||
FunctionCallContent,
|
||||
FunctionResultContent,
|
||||
TextContent,
|
||||
UriContent,
|
||||
)
|
||||
from ..exceptions import ServiceInitializationError
|
||||
from ._shared import OpenAIConfigBase, OpenAIHandler, OpenAIModelTypes, OpenAISettings
|
||||
from ._shared import OpenAIConfigBase, OpenAISettings
|
||||
|
||||
|
||||
@use_tool_calling
|
||||
class OpenAIAssistantsClient(OpenAIConfigBase, ChatClientBase, OpenAIHandler):
|
||||
class OpenAIAssistantsClient(OpenAIConfigBase, ChatClientBase):
|
||||
"""OpenAI Assistants client."""
|
||||
|
||||
MODEL_PROVIDER_NAME: ClassVar[str] = "openai" # type: ignore[reportIncompatibleVariableOverride]
|
||||
MODEL_PROVIDER_NAME: ClassVar[str] = "openai" # type: ignore[reportIncompatibleVariableOverride, misc]
|
||||
assistant_id: str | None = Field(default=None)
|
||||
assistant_name: str | None = Field(default=None)
|
||||
thread_id: str | None = Field(default=None)
|
||||
@@ -38,24 +60,24 @@ class OpenAIAssistantsClient(OpenAIConfigBase, ChatClientBase, OpenAIHandler):
|
||||
"""Initialize an OpenAI Assistants client.
|
||||
|
||||
Args:
|
||||
ai_model_id (str): OpenAI model name, see
|
||||
ai_model_id: OpenAI model name, see
|
||||
https://platform.openai.com/docs/models
|
||||
assistant_id (str | None): The ID of an OpenAI assistant to use.
|
||||
assistant_id: The ID of an OpenAI assistant to use.
|
||||
If not provided, a new assistant will be created (and deleted after the request).
|
||||
assistant_name (str | None): The name to use when creating new assistants.
|
||||
assistant_name: The name to use when creating new assistants.
|
||||
thread_id: Default thread ID to use for conversations. Can be overridden by
|
||||
conversation_id property from ChatOptions, when making a request.
|
||||
If not provided, a new thread will be created (and deleted after the request).
|
||||
api_key (str | None): The optional API key to use. If provided will override,
|
||||
api_key: The optional API key to use. If provided will override,
|
||||
the env vars or .env file value.
|
||||
org_id (str | None): The optional org ID to use. If provided will override,
|
||||
org_id: The optional org ID to use. If provided will override,
|
||||
the env vars or .env file value.
|
||||
default_headers: The default headers mapping of string keys to
|
||||
string values for HTTP requests. (Optional)
|
||||
async_client (Optional[AsyncOpenAI]): An existing client to use. (Optional)
|
||||
env_file_path (str | None): Use the environment settings file as a fallback
|
||||
async_client: An existing client to use. (Optional)
|
||||
env_file_path: Use the environment settings file as a fallback
|
||||
to environment variables. (Optional)
|
||||
env_file_encoding (str | None): The encoding of the environment settings file. (Optional)
|
||||
env_file_encoding: The encoding of the environment settings file. (Optional)
|
||||
"""
|
||||
try:
|
||||
openai_settings = OpenAISettings(
|
||||
@@ -80,7 +102,6 @@ class OpenAIAssistantsClient(OpenAIConfigBase, ChatClientBase, OpenAIHandler):
|
||||
thread_id=thread_id, # type: ignore[reportCallIssue]
|
||||
api_key=openai_settings.api_key.get_secret_value() if openai_settings.api_key else None,
|
||||
org_id=openai_settings.org_id,
|
||||
ai_model_type=OpenAIModelTypes.ASSISTANT,
|
||||
default_headers=default_headers,
|
||||
client=async_client,
|
||||
)
|
||||
@@ -103,6 +124,20 @@ class OpenAIAssistantsClient(OpenAIConfigBase, ChatClientBase, OpenAIHandler):
|
||||
chat_options: ChatOptions,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[ChatResponseUpdate]:
|
||||
# Extract necessary state from messages and options
|
||||
run_options, tool_results = self._create_run_options(messages, chat_options, **kwargs)
|
||||
|
||||
# Get the thread ID
|
||||
thread_id: str | None = (
|
||||
chat_options.conversation_id if chat_options.conversation_id is not None else self.thread_id
|
||||
)
|
||||
|
||||
if thread_id is None and tool_results is not None:
|
||||
raise ValueError("No thread ID was provided, but chat messages includes tool results.")
|
||||
|
||||
# Determine which assistant to use and create if needed
|
||||
assistant_id = await self._get_assistant_id_or_create()
|
||||
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="test")])
|
||||
|
||||
async def _get_assistant_id_or_create(self) -> str:
|
||||
@@ -121,3 +156,178 @@ class OpenAIAssistantsClient(OpenAIConfigBase, ChatClientBase, OpenAIHandler):
|
||||
self._should_delete_assistant = True
|
||||
|
||||
return self.assistant_id
|
||||
|
||||
# TODO: _create_agent_stream
|
||||
|
||||
async def _get_active_thread_run(self, thread_id: str | None) -> Run | None:
|
||||
"""Get any active run for the given thread."""
|
||||
if thread_id is None:
|
||||
return None
|
||||
|
||||
async for run in self.client.beta.threads.runs.list(thread_id=thread_id, limit=1, order="desc"): # type: ignore[reportDeprecated]
|
||||
if run.status not in ["completed", "cancelled", "failed", "expired"]:
|
||||
return run
|
||||
return None
|
||||
|
||||
async def _prepare_thread(self, thread_id: str | None, thread_run: Run | None, run_options: dict[str, Any]) -> str:
|
||||
"""Prepare the thread for a new run, creating or cleaning up as needed."""
|
||||
if thread_id is None:
|
||||
# No thread ID was provided, so create a new thread.
|
||||
thread = await self.client.beta.threads.create( # type: ignore[reportDeprecated]
|
||||
messages=run_options["additional_messages"],
|
||||
tool_resources=run_options.get("tool_resources"),
|
||||
metadata=run_options.get("metadata"),
|
||||
)
|
||||
run_options["additional_messages"] = []
|
||||
return thread.id
|
||||
|
||||
if thread_run is not None:
|
||||
# There was an active run; we need to cancel it before starting a new run.
|
||||
await self.client.beta.threads.runs.cancel(run_id=thread_run.id, thread_id=thread_id) # type: ignore[reportDeprecated]
|
||||
|
||||
return thread_id
|
||||
|
||||
# TODO: _process_stream_events
|
||||
|
||||
# TODO: _process_stream_events_from_iterator
|
||||
|
||||
def _create_function_call_contents(self, event_data: Run, response_id: str | None) -> list[AIContents]:
|
||||
"""Create function call contents from a tool action event."""
|
||||
contents: list[AIContents] = []
|
||||
|
||||
if event_data.required_action is not None:
|
||||
for tool_call in event_data.required_action.submit_tool_outputs.tool_calls:
|
||||
call_id = json.dumps([response_id, tool_call.id])
|
||||
function_name = tool_call.function.name
|
||||
function_arguments = json.loads(tool_call.function.arguments)
|
||||
contents.append(FunctionCallContent(call_id=call_id, name=function_name, arguments=function_arguments))
|
||||
|
||||
return contents
|
||||
|
||||
async def _cleanup_assistant_if_needed(self) -> None:
|
||||
"""Clean up the assistant if we created it."""
|
||||
if self._should_delete_assistant and self.assistant_id is not None:
|
||||
await self.client.beta.assistants.delete(self.assistant_id)
|
||||
self.assistant_id = None
|
||||
self._should_delete_assistant = False
|
||||
|
||||
def _create_run_options(
|
||||
self,
|
||||
messages: MutableSequence[ChatMessage],
|
||||
chat_options: ChatOptions | None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[dict[str, Any], list[FunctionResultContent] | None]:
|
||||
run_options: dict[str, Any] = {**kwargs}
|
||||
|
||||
if chat_options is not None:
|
||||
run_options["max_completion_tokens"] = chat_options.max_tokens
|
||||
run_options["model"] = chat_options.ai_model_id
|
||||
run_options["top_p"] = chat_options.top_p
|
||||
run_options["temperature"] = chat_options.temperature
|
||||
run_options["parallel_tool_calls"] = chat_options.allow_multiple_tool_calls
|
||||
|
||||
if chat_options.tool_choice is not None:
|
||||
tool_definitions: list[MutableMapping[str, Any]] = []
|
||||
if chat_options.tool_choice != "none" and chat_options.tools is not None:
|
||||
for tool in chat_options.tools:
|
||||
if isinstance(tool, AIFunction):
|
||||
tool_definitions.append(ai_function_to_json_schema_spec(tool)) # type: ignore
|
||||
elif isinstance(tool, HostedCodeInterpreterTool):
|
||||
tool_definitions.append({"type": "code_interpreter"})
|
||||
elif isinstance(tool, MutableMapping):
|
||||
tool_definitions.append(tool)
|
||||
|
||||
if len(tool_definitions) > 0:
|
||||
run_options["tools"] = tool_definitions
|
||||
|
||||
if chat_options.tool_choice == "none" or chat_options.tool_choice == "auto":
|
||||
run_options["tool_choice"] = chat_options.tool_choice
|
||||
elif (
|
||||
isinstance(chat_options.tool_choice, ChatToolMode)
|
||||
and chat_options.tool_choice == "required"
|
||||
and chat_options.tool_choice.required_function_name is not None
|
||||
):
|
||||
run_options["tool_choice"] = {
|
||||
"type": "function",
|
||||
"function": {"name": chat_options.tool_choice.required_function_name},
|
||||
}
|
||||
|
||||
if chat_options.response_format is not None:
|
||||
run_options["response_format"] = {
|
||||
"type": "json_schema",
|
||||
"json_schema": chat_options.response_format.model_json_schema(),
|
||||
}
|
||||
|
||||
instructions: list[str] = []
|
||||
tool_results: list[FunctionResultContent] | None = None
|
||||
|
||||
additional_messages: list[AdditionalMessage] | None = None
|
||||
|
||||
# System/developer messages are turned into instructions,
|
||||
# since there is no such message roles in OpenAI Assistants.
|
||||
# All other messages are added 1:1.
|
||||
for chat_message in messages:
|
||||
if chat_message.role.value in ["system", "developer"]:
|
||||
for text_content in [content for content in chat_message.contents if isinstance(content, TextContent)]:
|
||||
instructions.append(text_content.text)
|
||||
|
||||
continue
|
||||
|
||||
message_contents: list[MessageContentPartParam] = []
|
||||
|
||||
for content in chat_message.contents:
|
||||
if isinstance(content, TextContent):
|
||||
message_contents.append(TextContentBlockParam(type="text", text=content.text))
|
||||
elif isinstance(content, UriContent) and content.has_top_level_media_type("image"):
|
||||
message_contents.append(
|
||||
ImageURLContentBlockParam(type="image_url", image_url=ImageURLParam(url=content.uri))
|
||||
)
|
||||
elif isinstance(content, FunctionResultContent):
|
||||
if tool_results is None:
|
||||
tool_results = []
|
||||
tool_results.append(content)
|
||||
|
||||
if len(message_contents) > 0:
|
||||
if additional_messages is None:
|
||||
additional_messages = []
|
||||
additional_messages.append(AdditionalMessage(role="assistant", content=message_contents))
|
||||
|
||||
if additional_messages is not None:
|
||||
run_options["additional_messages"] = additional_messages
|
||||
|
||||
if len(instructions) > 0:
|
||||
run_options["instructions"] = "".join(instructions)
|
||||
|
||||
return run_options, tool_results
|
||||
|
||||
def _convert_function_results_to_tool_output(
|
||||
self,
|
||||
tool_results: list[FunctionResultContent] | None,
|
||||
) -> tuple[str | None, list[ToolOutput] | None]:
|
||||
run_id: str | None = None
|
||||
tool_outputs: list[ToolOutput] | None = None
|
||||
|
||||
if tool_results:
|
||||
for function_result_content in tool_results:
|
||||
# When creating the FunctionCallContent, we created it with a CallId == [runId, callId].
|
||||
# We need to extract the run ID and ensure that the ToolOutput we send back to Azure
|
||||
# is only the call ID.
|
||||
run_and_call_ids: list[str] = json.loads(function_result_content.call_id)
|
||||
|
||||
if (
|
||||
not run_and_call_ids
|
||||
or len(run_and_call_ids) != 2
|
||||
or not run_and_call_ids[0]
|
||||
or not run_and_call_ids[1]
|
||||
or (run_id is not None and run_id != run_and_call_ids[0])
|
||||
):
|
||||
continue
|
||||
|
||||
run_id = run_and_call_ids[0]
|
||||
call_id = run_and_call_ids[1]
|
||||
|
||||
if tool_outputs is None:
|
||||
tool_outputs = []
|
||||
tool_outputs.append(ToolOutput(tool_call_id=call_id, output=str(function_result_content.result)))
|
||||
|
||||
return run_id, tool_outputs
|
||||
|
||||
@@ -117,7 +117,6 @@ class OpenAIModelTypes(Enum):
|
||||
TEXT_TO_SPEECH = "text-to-speech"
|
||||
REALTIME = "realtime"
|
||||
RESPONSE = "response"
|
||||
ASSISTANT = "assistant"
|
||||
|
||||
|
||||
class OpenAIHandler(AFBaseModel, ABC):
|
||||
|
||||
Reference in New Issue
Block a user