mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
8f27e63df6
* added create_agent * add minimal sample * allow multiple annotations * improved docstring
701 lines
28 KiB
Python
701 lines
28 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
import asyncio
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import AsyncIterable, Awaitable, Callable, MutableMapping, MutableSequence, Sequence
|
|
from functools import wraps
|
|
from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeVar, runtime_checkable
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from ._logging import get_logger
|
|
from ._pydantic import AFBaseModel
|
|
from ._tools import AIFunction, AITool
|
|
from ._types import (
|
|
AIContents,
|
|
ChatMessage,
|
|
ChatOptions,
|
|
ChatResponse,
|
|
ChatResponseUpdate,
|
|
ChatToolMode,
|
|
FunctionCallContent,
|
|
FunctionResultContent,
|
|
GeneratedEmbeddings,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from ._agents import ChatClientAgent
|
|
|
|
# Module-level logger used by the chat client helpers in this file.
logger = get_logger()

# Type parameters of the EmbeddingGenerator protocol: the input item type (declared
# contravariant) and the produced embedding type.
TInput = TypeVar("TInput", contravariant=True)
TEmbedding = TypeVar("TEmbedding")
# Bound to ChatClientBase so `use_tool_calling` hands back the same class type it receives.
TChatClientBase = TypeVar("TChatClientBase", bound="ChatClientBase")

# Names exported by `from <module> import *`.
__all__ = [
    "ChatClient",
    "ChatClientBase",
    "EmbeddingGenerator",
    "use_tool_calling",
]
|
|
|
|
# region: Tool Calling Functions and Decorators
|
|
|
|
|
|
async def _auto_invoke_function(
    function_call_content: FunctionCallContent,
    custom_args: dict[str, Any] | None = None,
    *,
    tool_map: dict[str, AIFunction[BaseModel, Any]],
    sequence_index: int | None = None,
    request_index: int | None = None,
) -> AIContents:
    """Invoke one function call requested by the model and wrap the outcome in a result content.

    Args:
        function_call_content: The function call emitted by the model. Its ``name`` selects the
            tool from ``tool_map`` and its parsed arguments are validated against the tool's
            input model.
        custom_args: Extra keyword arguments supplied by the caller of the chat client. They are
            merged underneath the model-provided arguments (model arguments win on conflicts).
        tool_map: Mapping from function name to the ``AIFunction`` implementing it.
        sequence_index: Position of this call among the calls of the current response.
            Accepted for the caller's bookkeeping; not used by this function.
        request_index: Index of the tool-calling iteration this call belongs to.
            Accepted for the caller's bookkeeping; not used by this function.

    Returns:
        A ``FunctionResultContent`` with the call id and either the function's return value, or
        ``result=None`` plus the exception raised while validating the arguments or invoking the
        function.

    Raises:
        KeyError: If ``tool_map`` has no entry named ``function_call_content.name``.
    """
    tool: AIFunction[BaseModel, Any] | None = tool_map.get(function_call_content.name)
    if tool is None:
        raise KeyError(f"No tool or function named '{function_call_content.name}'")

    parsed_args: dict[str, Any] = dict(function_call_content.parse_arguments() or {})
    # The right-hand side of `|` wins, so model-provided arguments override caller-supplied ones.
    merged_args: dict[str, Any] = (custom_args or {}) | parsed_args

    exception: Exception | None = None
    function_result: Any = None
    try:
        # Validation lives inside the try block because the arguments come from the model and
        # may not match the tool's schema. Returning the failure as a result (like an invocation
        # error) lets it be sent back to the model instead of aborting every concurrent call.
        args = tool.input_model.model_validate(merged_args)
        function_result = await tool.invoke(arguments=args, tool_call_id=function_call_content.call_id)
    except Exception as ex:
        exception = ex
    return FunctionResultContent(
        call_id=function_call_content.call_id,
        exception=exception,
        result=function_result,
    )
|
|
|
|
|
|
def ai_function_to_json_schema_spec(function: AIFunction[BaseModel, Any]) -> dict[str, Any]:
    """Describe an AIFunction as a JSON-Schema style ``function`` tool specification.

    Args:
        function: The AIFunction whose name, description and parameter schema are exported.

    Returns:
        A dict shaped as ``{"type": "function", "function": {"name": ..., "description": ...,
        "parameters": ...}}`` where ``parameters`` is the result of ``function.parameters()``.
    """
    spec_body: dict[str, Any] = {
        "name": function.name,
        "description": function.description,
        "parameters": function.parameters(),
    }
    return {"type": "function", "function": spec_body}
|
|
|
|
|
|
def _tool_call_non_streaming(
    func: Callable[..., Awaitable["ChatResponse"]],
) -> Callable[..., Awaitable["ChatResponse"]]:
    """Decorate the internal ``_inner_get_response`` method to enable automatic tool calls.

    The wrapped method is called repeatedly. Whenever the first message of a response contains
    function calls that do not already have a matching result, those calls are executed
    concurrently. The call and result messages are added to the conversation, and the model is
    queried again. This repeats up to ``__maximum_iterations_per_request`` times (10 when the
    attribute is absent). After that, the model is asked once more with tool choice ``"none"``.
    """

    def _prepend_history(response: ChatResponse, history: list[ChatMessage]) -> None:
        """Insert the accumulated call/result messages in front of the final response messages."""
        for msg in reversed(history):
            response.messages.insert(0, msg)

    @wraps(func)
    async def wrapper(
        self: "ChatClientBase",
        *,
        messages: MutableSequence[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> ChatResponse:
        # Every function-call / function-result message produced along the way; they are
        # returned ahead of the final answer so the caller sees the whole exchange.
        fcc_messages: list[ChatMessage] = []
        for attempt_idx in range(getattr(self, "__maximum_iterations_per_request", 10)):
            # kwargs are intentionally not forwarded here: they are only passed to invoked functions.
            response = await func(self, messages=messages, chat_options=chat_options)
            # Guard against an empty response instead of failing with an IndexError.
            first_contents = (response.messages[0].contents or []) if response.messages else []
            # Calls that already have a result in the same message (e.g. executed by the
            # service) must not be executed again.
            function_results = {it.call_id for it in first_contents if isinstance(it, FunctionResultContent)}
            function_calls = [
                it
                for it in first_contents
                if isinstance(it, FunctionCallContent) and it.call_id not in function_results
            ]
            if not function_calls:
                # No more calls to handle: this is the final answer.
                # TODO (eavanvalkenburg): control this behavior?
                _prepend_history(response, fcc_messages)
                return response

            # The available tools do not change between the calls of one round; build the map once.
            tool_map = {t.name: t for t in chat_options._ai_tools or [] if isinstance(t, AIFunction)}  # type: ignore[reportPrivateUsage]
            # Run all function calls concurrently.
            results = await asyncio.gather(*[
                _auto_invoke_function(
                    function_call,
                    custom_args=kwargs,
                    tool_map=tool_map,
                    sequence_index=seq_idx,
                    request_index=attempt_idx,
                )
                for seq_idx, function_call in enumerate(function_calls)
            ])
            # One tool message carries all results; call ids line up with the calls above.
            tool_message = ChatMessage(role="tool", contents=results)
            response.messages.append(tool_message)
            fcc_messages.extend(response.messages)
            if response.conversation_id is not None:
                # The service hosts the conversation: remember its id and send only the new
                # tool results next time instead of re-sending the whole history.
                chat_options.conversation_id = response.conversation_id
                messages = [tool_message]
            else:
                messages.extend(response.messages)

        # Failsafe: give up on tools, ask the model for a plain answer.
        chat_options.tool_choice = "none"
        self._prepare_tools_and_tool_choice(chat_options=chat_options)  # type: ignore[reportPrivateUsage]
        response = await func(self, messages=messages, chat_options=chat_options)
        _prepend_history(response, fcc_messages)
        return response

    return wrapper
|
|
|
|
|
|
def _tool_call_streaming(
    func: Callable[..., AsyncIterable["ChatResponseUpdate"]],
) -> Callable[..., AsyncIterable["ChatResponseUpdate"]]:
    """Decorate the internal ``_inner_get_streaming_response`` method to enable automatic tool calls.

    Every update from the wrapped method is yielded to the caller as it arrives. When a streamed
    turn contains function calls without a matching result, those calls are executed
    concurrently. Their results are yielded as a ``role="tool"`` update and appended to the
    conversation, and the model is streamed again. This repeats up to
    ``__maximum_iterations_per_request`` times (10 when the attribute is absent). After that, one
    final stream is requested with tool choice ``"none"``.
    """

    @wraps(func)
    async def wrapper(
        self: "ChatClientBase",
        *,
        messages: MutableSequence[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> AsyncIterable[ChatResponseUpdate]:
        """Wrap the inner get streaming response method to handle tool calls."""
        for attempt_idx in range(getattr(self, "__maximum_iterations_per_request", 10)):
            # Keep every update, not only those with function calls. Text, metadata and the
            # conversation id may arrive in separate updates and belong in the history message.
            all_updates: list[ChatResponseUpdate] = []
            # kwargs are intentionally not forwarded here: they are only passed to invoked functions.
            async for update in func(self, messages=messages, chat_options=chat_options):
                all_updates.append(update)
                yield update

            if not any(
                isinstance(item, FunctionCallContent) for upd in all_updates for item in (upd.contents or [])
            ):
                return

            # Combine the streamed updates into the complete assistant turn and add it to history.
            response: ChatResponse = ChatResponse.from_chat_response_updates(all_updates)
            messages.extend(response.messages)
            contents = [item for msg in response.messages for item in (msg.contents or [])]
            # Calls that already have a result (e.g. executed by the service) must not run again.
            function_results = {it.call_id for it in contents if isinstance(it, FunctionResultContent)}
            function_calls = [
                it for it in contents if isinstance(it, FunctionCallContent) and it.call_id not in function_results
            ]

            # When a conversation id is present, the messages are hosted on the server: update
            # ChatOptions with the id and send only new messages from now on.
            if response.conversation_id is not None:
                chat_options.conversation_id = response.conversation_id
                messages = []

            if not function_calls:
                return

            # The available tools do not change between the calls of one round; build the map once.
            tool_map = {t.name: t for t in chat_options._ai_tools or [] if isinstance(t, AIFunction)}  # type: ignore[reportPrivateUsage]
            # Run all function calls concurrently.
            results = await asyncio.gather(*[
                _auto_invoke_function(
                    function_call,
                    custom_args=kwargs,
                    tool_map=tool_map,
                    sequence_index=seq_idx,
                    request_index=attempt_idx,
                )
                for seq_idx, function_call in enumerate(function_calls)
            ])
            yield ChatResponseUpdate(contents=results, role="tool")
            messages.append(ChatMessage(role="tool", contents=results))

        # Failsafe: give up on tools, ask the model for a plain answer. kwargs are not forwarded,
        # matching the calls inside the loop above.
        chat_options.tool_choice = "none"
        self._prepare_tools_and_tool_choice(chat_options=chat_options)  # type: ignore[reportPrivateUsage]
        async for update in func(self, messages=messages, chat_options=chat_options):
            yield update

    return wrapper
|
|
|
|
|
|
def use_tool_calling(cls: type[TChatClientBase]) -> type[TChatClientBase]:
    """Class decorator that enables automatic tool calling for a chat client.

    Remarks:
        Intended for classes deriving from ChatClientBase. The class's ``_inner_get_response``
        and ``_inner_get_streaming_response`` methods, when present, are replaced by wrappers
        that execute requested function calls and query the model again.

        The decorator also sets a ``__maximum_iterations_per_request`` class attribute to 10.
        It is set with ``setattr`` using a string name, so Python name mangling does not apply.
        To expose it to end users, add a property along these lines::

            @property
            def maximum_iterations_per_request(self):
                return getattr(self, "__maximum_iterations_per_request", 10)

            @maximum_iterations_per_request.setter
            def maximum_iterations_per_request(self, value: int) -> None:
                setattr(self, "__maximum_iterations_per_request", value)

        NOTE(review): ChatClientBase is built on a pydantic-style base model; confirm that its
        model config allows setting this non-field attribute on an instance.

    Args:
        cls: The chat client class to modify in place.

    Returns:
        The same class object, with the wrapping methods installed.
    """
    setattr(cls, "__maximum_iterations_per_request", 10)

    # (method name, decorator) pairs, applied in this order.
    method_decorators = (
        ("_inner_get_response", _tool_call_non_streaming),
        ("_inner_get_streaming_response", _tool_call_streaming),
    )
    for method_name, decorate in method_decorators:
        original_method = getattr(cls, method_name, None)
        if original_method:
            setattr(cls, method_name, decorate(original_method))
    return cls
|
|
|
|
|
|
# region: ChatClient Protocol
|
|
|
|
|
|
@runtime_checkable
class ChatClient(Protocol):
    """Structural interface for objects that produce chat responses.

    Implementations do not need to inherit from this class. With ``@runtime_checkable``,
    ``isinstance`` checks only verify that ``get_response`` and ``get_streaming_response`` exist.
    """

    async def get_response(
        self,
        messages: str | ChatMessage | list[str] | list[ChatMessage],
        *,
        frequency_penalty: float | None = None,
        logit_bias: dict[str | int, float] | None = None,
        max_tokens: int | None = None,
        metadata: dict[str, Any] | None = None,
        model: str | None = None,
        presence_penalty: float | None = None,
        response_format: type[BaseModel] | None = None,
        seed: int | None = None,
        stop: str | Sequence[str] | None = None,
        store: bool | None = None,
        temperature: float | None = None,
        tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
        tools: AITool
        | list[AITool]
        | Callable[..., Any]
        | list[Callable[..., Any]]
        | MutableMapping[str, Any]
        | list[MutableMapping[str, Any]]
        | None = None,
        top_p: float | None = None,
        user: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> ChatResponse:
        """Send the input messages and return the complete response.

        Args:
            messages: A single message or a list of messages; plain strings are user text.
            frequency_penalty: Frequency penalty for sampling.
            logit_bias: Per-token logit bias.
            max_tokens: Upper bound on the number of generated tokens.
            metadata: Extra metadata attached to the request.
            model: The model to use for the request.
            presence_penalty: Presence penalty for sampling.
            response_format: Pydantic model describing the desired response format.
            seed: Random seed for sampling.
            stop: One or more stop sequences.
            store: Whether the service should store the response.
            temperature: Sampling temperature.
            tool_choice: How the model may choose among the tools.
            tools: The tools made available for this request.
            top_p: Nucleus sampling probability.
            user: Identifier of the end user associated with the request.
            additional_properties: Extra properties to include in the request.
            kwargs: Any additional keyword arguments;
                these are only passed to the functions that get called.

        Returns:
            The response messages produced by the client.

        Raises:
            ValueError: If the input message sequence is `None`.
        """
        ...

    def get_streaming_response(
        self,
        messages: str | ChatMessage | list[str] | list[ChatMessage],
        *,
        frequency_penalty: float | None = None,
        logit_bias: dict[str | int, float] | None = None,
        max_tokens: int | None = None,
        metadata: dict[str, Any] | None = None,
        model: str | None = None,
        presence_penalty: float | None = None,
        response_format: type[BaseModel] | None = None,
        seed: int | None = None,
        stop: str | Sequence[str] | None = None,
        store: bool | None = None,
        temperature: float | None = None,
        tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
        tools: AITool
        | list[AITool]
        | Callable[..., Any]
        | list[Callable[..., Any]]
        | MutableMapping[str, Any]
        | list[MutableMapping[str, Any]]
        | None = None,
        top_p: float | None = None,
        user: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> AsyncIterable[ChatResponseUpdate]:
        """Send the input messages and stream the response back incrementally.

        Args:
            messages: A single message or a list of messages; plain strings are user text.
            frequency_penalty: Frequency penalty for sampling.
            logit_bias: Per-token logit bias.
            max_tokens: Upper bound on the number of generated tokens.
            metadata: Extra metadata attached to the request.
            model: The model to use for the request.
            presence_penalty: Presence penalty for sampling.
            response_format: Pydantic model describing the desired response format.
            seed: Random seed for sampling.
            stop: One or more stop sequences.
            store: Whether the service should store the response.
            temperature: Sampling temperature.
            tool_choice: How the model may choose among the tools.
            tools: The tools made available for this request.
            top_p: Nucleus sampling probability.
            user: Identifier of the end user associated with the request.
            additional_properties: Extra properties to include in the request.
            kwargs: Any additional keyword arguments;
                these are only passed to the functions that get called.

        Yields:
            Chat response updates carrying the content of the response messages as it is
            generated by the client.

        Raises:
            ValueError: If the input message sequence is `None`.
        """
        ...
|
|
|
|
|
|
class ChatClientBase(AFBaseModel, ABC):
    """Base class for chat clients.

    Subclasses implement ``_inner_get_response`` and ``_inner_get_streaming_response``.
    The public ``get_response`` and ``get_streaming_response`` methods do the common work and
    then delegate to them:

    - turn the input into a list of ``ChatMessage`` objects;
    - build (or accept) a ``ChatOptions`` instance;
    - convert the configured tools into specifications.
    """

    # Provider name used for OTel setup; should be overridden in subclasses.
    MODEL_PROVIDER_NAME: str = "unknown"

    def _prepare_messages(
        self, messages: str | ChatMessage | list[str] | list[ChatMessage]
    ) -> MutableSequence[ChatMessage]:
        """Turn the allowed input into a list of chat messages.

        Plain strings become ``ChatMessage(role="user", text=...)``. ChatMessage objects are
        kept as-is.
        """
        if isinstance(messages, str):
            return [ChatMessage(role="user", text=messages)]
        if isinstance(messages, ChatMessage):
            return [messages]
        return_messages: list[ChatMessage] = []
        for msg in messages:
            if isinstance(msg, str):
                msg = ChatMessage(role="user", text=msg)
            return_messages.append(msg)
        return return_messages

    # region Internal methods to be implemented by the derived classes

    @abstractmethod
    async def _inner_get_response(
        self,
        *,
        messages: MutableSequence[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> ChatResponse:
        """Send a chat request to the AI service.

        Args:
            messages: The chat messages to send.
            chat_options: The options for the request.
            kwargs: Any additional keyword arguments.

        Returns:
            The chat response contents representing the response(s).
        """

    @abstractmethod
    async def _inner_get_streaming_response(
        self,
        *,
        messages: MutableSequence[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> AsyncIterable[ChatResponseUpdate]:
        """Send a streaming chat request to the AI service.

        Args:
            messages: The chat messages to send.
            chat_options: The chat_options for the request.
            kwargs: Any additional keyword arguments.

        Yields:
            ChatResponseUpdate: The streaming chat message contents.
        """
        # Below is needed for mypy: https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators
        if False:
            yield
            await asyncio.sleep(0)  # pragma: no cover
        # This is a no-op, but it allows the method to be async and return an AsyncIterable.
        # The actual implementation should yield ChatResponseUpdate instances as needed.

    # endregion

    # region Public method

    async def get_response(
        self,
        messages: str | ChatMessage | list[str] | list[ChatMessage],
        *,
        frequency_penalty: float | None = None,
        logit_bias: dict[str | int, float] | None = None,
        max_tokens: int | None = None,
        metadata: dict[str, Any] | None = None,
        model: str | None = None,
        presence_penalty: float | None = None,
        response_format: type[BaseModel] | None = None,
        seed: int | None = None,
        stop: str | Sequence[str] | None = None,
        store: bool | None = None,
        temperature: float | None = None,
        tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
        tools: AITool
        | list[AITool]
        | Callable[..., Any]
        | list[Callable[..., Any]]
        | MutableMapping[str, Any]
        | list[MutableMapping[str, Any]]
        | None = None,
        top_p: float | None = None,
        user: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> ChatResponse:
        """Get a response from a chat client.

        Args:
            messages: the message or messages to send to the model
            frequency_penalty: the frequency penalty to use.
            logit_bias: the logit bias to use.
            max_tokens: The maximum number of tokens to generate.
            metadata: additional metadata to include in the request.
            model: The model to use for the request.
            presence_penalty: the presence penalty to use.
            response_format: the format of the response.
            seed: the random seed to use.
            stop: the stop sequence(s) for the request.
            store: whether to store the response.
            temperature: the sampling temperature to use.
            tool_choice: the tool choice for the request.
            tools: the tools to use for the request.
            top_p: the nucleus sampling probability to use.
            user: the user to associate with the request.
            additional_properties: additional properties to include in the request.
            kwargs: any additional keyword arguments,
                will only be passed to functions that are called.
                A ``chat_options`` keyword holding a ``ChatOptions`` instance is used as-is,
                and the individual option parameters above are then ignored.

        Returns:
            A chat response from the model.

        Raises:
            TypeError: If a ``chat_options`` keyword is given that is not a ``ChatOptions``.
        """
        if "chat_options" in kwargs:
            chat_options = kwargs.pop("chat_options")
            if not isinstance(chat_options, ChatOptions):
                raise TypeError("chat_options must be an instance of ChatOptions")
        else:
            chat_options = ChatOptions(
                ai_model_id=model,
                frequency_penalty=frequency_penalty,
                logit_bias=logit_bias,
                max_tokens=max_tokens,
                metadata=metadata,
                presence_penalty=presence_penalty,
                response_format=response_format,
                seed=seed,
                stop=stop,
                store=store,
                temperature=temperature,
                top_p=top_p,
                tool_choice=tool_choice,
                tools=tools,  # type: ignore
                user=user,
                additional_properties=additional_properties or {},
            )
        prepped_messages = self._prepare_messages(messages)
        self._prepare_tools_and_tool_choice(chat_options=chat_options)
        return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **kwargs)

    async def get_streaming_response(
        self,
        messages: str | ChatMessage | list[str] | list[ChatMessage],
        *,
        frequency_penalty: float | None = None,
        logit_bias: dict[str | int, float] | None = None,
        max_tokens: int | None = None,
        metadata: dict[str, Any] | None = None,
        model: str | None = None,
        presence_penalty: float | None = None,
        response_format: type[BaseModel] | None = None,
        seed: int | None = None,
        stop: str | Sequence[str] | None = None,
        store: bool | None = None,
        temperature: float | None = None,
        tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
        tools: AITool
        | list[AITool]
        | Callable[..., Any]
        | list[Callable[..., Any]]
        | MutableMapping[str, Any]
        | list[MutableMapping[str, Any]]
        | None = None,
        top_p: float | None = None,
        user: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> AsyncIterable[ChatResponseUpdate]:
        """Get a streaming response from a chat client.

        Args:
            messages: the message or messages to send to the model
            frequency_penalty: the frequency penalty to use
            logit_bias: the logit bias to use
            max_tokens: The maximum number of tokens to generate.
            metadata: additional metadata to include in the request.
            model: The model to use for the request.
            presence_penalty: the presence penalty to use.
            response_format: the format of the response.
            seed: the random seed to use.
            stop: the stop sequence(s) for the request.
            store: whether to store the response.
            temperature: the sampling temperature to use.
            tool_choice: the tool choice for the request.
            tools: the tools to use for the request.
            top_p: the nucleus sampling probability to use.
            user: the user to associate with the request.
            additional_properties: additional properties to include in the request
            kwargs: any additional keyword arguments; like in ``get_response`` they are forwarded
                to ``_inner_get_streaming_response`` and not mixed into the ChatOptions.
                A ``chat_options`` keyword holding a ``ChatOptions`` instance is used as-is,
                and the individual option parameters above are then ignored.

        Yields:
            A stream representing the response(s) from the LLM.

        Raises:
            TypeError: If a ``chat_options`` keyword is given that is not a ``ChatOptions``.
        """
        if "chat_options" in kwargs:
            chat_options = kwargs.pop("chat_options")
            if not isinstance(chat_options, ChatOptions):
                raise TypeError("chat_options must be an instance of ChatOptions")
        else:
            # Built exactly like in get_response. Caller kwargs are NOT splatted into ChatOptions:
            # they are meant for the invoked functions and must not leak into (or fail validation
            # of) the request options.
            chat_options = ChatOptions(
                ai_model_id=model,
                frequency_penalty=frequency_penalty,
                logit_bias=logit_bias,
                max_tokens=max_tokens,
                metadata=metadata,
                presence_penalty=presence_penalty,
                response_format=response_format,
                seed=seed,
                stop=stop,
                store=store,
                temperature=temperature,
                top_p=top_p,
                tool_choice=tool_choice,
                tools=tools,  # type: ignore
                user=user,
                additional_properties=additional_properties or {},
            )
        prepped_messages = self._prepare_messages(messages)
        self._prepare_tools_and_tool_choice(chat_options=chat_options)
        async for update in self._inner_get_streaming_response(
            messages=prepped_messages, chat_options=chat_options, **kwargs
        ):
            yield update

    def _prepare_tools_and_tool_choice(self, chat_options: ChatOptions) -> None:
        """Prepare the tools and tool choice for the chat options.

        If the tool mode is ``None`` or ``ChatToolMode.NONE``, tools are cleared and the choice
        is set to the NONE mode. Otherwise each AIFunction from ``chat_options._ai_tools`` is
        converted to its JSON-schema function spec; other tools are passed through unchanged.
        The tool choice falls back to NONE when no tools remain.

        This function should be overridden by subclasses to customize tool handling,
        because it currently parses only AIFunctions.
        """
        chat_tool_mode: ChatToolMode | None = chat_options.tool_choice  # type: ignore
        if chat_tool_mode is None or chat_tool_mode == ChatToolMode.NONE:
            chat_options.tools = None
            chat_options.tool_choice = ChatToolMode.NONE.mode
            return
        chat_options.tools = [
            (ai_function_to_json_schema_spec(t) if isinstance(t, AIFunction) else t)  # type: ignore[reportUnknownArgumentType]
            for t in chat_options._ai_tools or []  # type: ignore[reportPrivateUsage]
        ]
        if not chat_options.tools:
            chat_options.tool_choice = ChatToolMode.NONE.mode
        else:
            chat_options.tool_choice = chat_tool_mode.mode

    def service_url(self) -> str | None:
        """Get the URL of the service.

        Override this in the subclass to return the proper URL.
        If the service does not have a URL, return None.
        """
        return None

    def create_agent(
        self,
        *,
        name: str,
        instructions: str,
        tools: AITool
        | list[AITool]
        | Callable[..., Any]
        | list[Callable[..., Any]]
        | MutableMapping[str, Any]
        | list[MutableMapping[str, Any]]
        | None = None,
        **kwargs: Any,
    ) -> "ChatClientAgent":
        """Create an agent with the given name and instructions.

        Args:
            name: The name of the agent.
            instructions: The instructions for the agent.
            tools: Optional list of tools to associate with the agent.
            **kwargs: Additional keyword arguments to pass to the agent.
                See ChatClientAgent for all the available options.

        Returns:
            An instance of ChatClientAgent that uses this client.
        """
        # Imported lazily; at module level the import only happens under TYPE_CHECKING.
        from ._agents import ChatClientAgent

        return ChatClientAgent(chat_client=self, name=name, instructions=instructions, tools=tools, **kwargs)
|
|
|
|
|
|
# region: Embedding Client
|
|
|
|
|
|
@runtime_checkable
class EmbeddingGenerator(Protocol, Generic[TInput, TEmbedding]):
    """Structural interface for objects that create embeddings from input data.

    Type parameters:
        TInput: The type of each input item.
        TEmbedding: The type of the produced embeddings.
    """

    async def generate(
        self,
        input_data: Sequence[TInput],
        **kwargs: Any,
    ) -> GeneratedEmbeddings[TEmbedding]:
        """Generate embeddings for the given input data.

        Args:
            input_data: The input items to embed.
            **kwargs: Additional options for the request.

        Returns:
            The generated embeddings. The result behaves like a list but also carries metadata
            and usage details.
        """
        ...
|