mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: simple Agent sample (#180)
* tweaks to agents and sample * updated clients and agents * single line run and print * improved tool handling * added note on setting max iterations * fixed streaming param name * updated tools test * made kwargs alphabetical * added params to run methods * tweak to ensure right overload
This commit is contained in:
committed by
GitHub
Unverified
parent
59bf7343af
commit
62917ee5e5
@@ -1,15 +1,25 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import AsyncIterable, Callable, Sequence
|
||||
from collections.abc import AsyncIterable, Callable, MutableMapping, Sequence
|
||||
from enum import Enum
|
||||
from typing import Any, Protocol, TypeVar, runtime_checkable
|
||||
from typing import Any, Literal, Protocol, TypeVar, runtime_checkable
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ._clients import ChatClient
|
||||
from ._pydantic import AFBaseModel
|
||||
from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage, ChatResponse, ChatResponseUpdate, ChatRole
|
||||
from ._tools import AITool
|
||||
from ._types import (
|
||||
AgentRunResponse,
|
||||
AgentRunResponseUpdate,
|
||||
ChatMessage,
|
||||
ChatOptions,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
ChatRole,
|
||||
ChatToolMode,
|
||||
)
|
||||
from .exceptions import AgentExecutionException
|
||||
|
||||
TThreadType = TypeVar("TThreadType", bound="AgentThread")
|
||||
@@ -71,7 +81,7 @@ class Agent(Protocol):
|
||||
|
||||
async def run(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
**kwargs: Any,
|
||||
@@ -100,7 +110,7 @@ class Agent(Protocol):
|
||||
|
||||
def run_stream(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
**kwargs: Any,
|
||||
@@ -236,17 +246,168 @@ class ChatClientAgentThread(AgentThread):
|
||||
|
||||
|
||||
class ChatClientAgent(AgentBase):
|
||||
"""A Chat Client Agent which depends on ChatClient."""
|
||||
"""A Chat Client Agent."""
|
||||
|
||||
chat_client: ChatClient
|
||||
instructions: str | None = None
|
||||
chat_options: ChatOptions
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chat_client: ChatClient,
|
||||
instructions: str | None = None,
|
||||
*,
|
||||
id: str | None = None,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str | int, float] | None = None,
|
||||
max_tokens: int | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model: str | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | Sequence[str] | None = None,
|
||||
store: bool | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
|
||||
tools: AITool
|
||||
| list[AITool]
|
||||
| Callable[..., Any]
|
||||
| list[Callable[..., Any]]
|
||||
| MutableMapping[str, Any]
|
||||
| list[MutableMapping[str, Any]]
|
||||
| None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create a ChatClientAgent.
|
||||
|
||||
Remarks:
|
||||
The set of attributes from frequency_penalty to additional_properties are used to
|
||||
call the chat client, they can also be passed to both run methods.
|
||||
When both are set, the ones passed to the run methods take precedence.
|
||||
|
||||
Args:
|
||||
chat_client: The chat client to use for the agent.
|
||||
instructions: Optional instructions for the agent.
|
||||
These will be put into the messages sent to the chat client service as a system message.
|
||||
id: The unique identifier for the agent, will be created automatically if not provided.
|
||||
name: The name of the agent.
|
||||
description: A brief description of the agent's purpose.
|
||||
frequency_penalty: the frequency penalty to use.
|
||||
logit_bias: the logit bias to use.
|
||||
max_tokens: The maximum number of tokens to generate.
|
||||
metadata: additional metadata to include in the request.
|
||||
model: The model to use for the agent.
|
||||
presence_penalty: the presence penalty to use.
|
||||
response_format: the format of the response.
|
||||
seed: the random seed to use.
|
||||
stop: the stop sequence(s) for the request.
|
||||
store: whether to store the response.
|
||||
temperature: the sampling temperature to use.
|
||||
tool_choice: the tool choice for the request.
|
||||
tools: the tools to use for the request.
|
||||
top_p: the nucleus sampling probability to use.
|
||||
user: the user to associate with the request.
|
||||
additional_properties: additional properties to include in the request.
|
||||
kwargs: any additional keyword arguments.
|
||||
Unused, can be used by subclasses of this Agent.
|
||||
"""
|
||||
args: dict[str, Any] = {
|
||||
"chat_client": chat_client,
|
||||
"chat_options": ChatOptions(
|
||||
ai_model_id=model,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
max_tokens=max_tokens,
|
||||
metadata=metadata,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
store=store,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools, # type: ignore
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
additional_properties=additional_properties or {},
|
||||
),
|
||||
}
|
||||
if instructions is not None:
|
||||
args["instructions"] = instructions
|
||||
if name is not None:
|
||||
args["name"] = name
|
||||
if description is not None:
|
||||
args["description"] = description
|
||||
if id is not None:
|
||||
args["id"] = id
|
||||
|
||||
super().__init__(**args)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str | int, float] | None = None,
|
||||
max_tokens: int | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model: str | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | Sequence[str] | None = None,
|
||||
store: bool | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None,
|
||||
tools: AITool
|
||||
| list[AITool]
|
||||
| Callable[..., Any]
|
||||
| list[Callable[..., Any]]
|
||||
| MutableMapping[str, Any]
|
||||
| list[MutableMapping[str, Any]]
|
||||
| None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentRunResponse:
|
||||
"""Run the agent with the given messages and options.
|
||||
|
||||
Remarks:
|
||||
Since you won't always call the agent.run directly, but it get's called
|
||||
through orchestration, it is advised to set your default values for
|
||||
all the chat client parameters in the agent constructor.
|
||||
If both parameters are used, the ones passed to the run methods take precedence.
|
||||
|
||||
Args:
|
||||
messages: The messages to process.
|
||||
thread: The thread to use for the agent.
|
||||
frequency_penalty: the frequency penalty to use.
|
||||
logit_bias: the logit bias to use.
|
||||
max_tokens: The maximum number of tokens to generate.
|
||||
metadata: additional metadata to include in the request.
|
||||
model: The model to use for the agent.
|
||||
presence_penalty: the presence penalty to use.
|
||||
response_format: the format of the response.
|
||||
seed: the random seed to use.
|
||||
stop: the stop sequence(s) for the request.
|
||||
store: whether to store the response.
|
||||
temperature: the sampling temperature to use.
|
||||
tool_choice: the tool choice for the request.
|
||||
tools: the tools to use for the request.
|
||||
top_p: the nucleus sampling probability to use.
|
||||
user: the user to associate with the request.
|
||||
additional_properties: additional properties to include in the request.
|
||||
kwargs: Additional keyword arguments for the agent.
|
||||
will only be passed to functions that are called.
|
||||
"""
|
||||
thread, thread_messages = await self._prepare_thread_and_messages(
|
||||
thread=thread,
|
||||
input_messages=messages,
|
||||
@@ -254,7 +415,29 @@ class ChatClientAgent(AgentBase):
|
||||
expected_type=ChatClientAgentThread,
|
||||
)
|
||||
|
||||
response = await self.chat_client.get_response(thread_messages, **kwargs)
|
||||
response = await self.chat_client.get_response(
|
||||
messages=thread_messages,
|
||||
chat_options=self.chat_options
|
||||
& ChatOptions(
|
||||
ai_model_id=model,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
max_tokens=max_tokens,
|
||||
metadata=metadata,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
store=store,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools, # type: ignore
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
additional_properties=additional_properties or {},
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._update_thread_with_type_and_conversation_id(thread, response.conversation_id)
|
||||
|
||||
@@ -274,11 +457,64 @@ class ChatClientAgent(AgentBase):
|
||||
|
||||
async def run_stream(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str | int, float] | None = None,
|
||||
max_tokens: int | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model: str | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | Sequence[str] | None = None,
|
||||
store: bool | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None,
|
||||
tools: AITool
|
||||
| list[AITool]
|
||||
| Callable[..., Any]
|
||||
| list[Callable[..., Any]]
|
||||
| MutableMapping[str, Any]
|
||||
| list[MutableMapping[str, Any]]
|
||||
| None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[AgentRunResponseUpdate]:
|
||||
"""Stream the agent with the given messages and options.
|
||||
|
||||
Remarks:
|
||||
Since you won't always call the agent.run_stream directly, but it get's called
|
||||
through orchestration, it is advised to set your default values for
|
||||
all the chat client parameters in the agent constructor.
|
||||
If both parameters are used, the ones passed to the run methods take precedence.
|
||||
|
||||
Args:
|
||||
messages: The messages to process.
|
||||
thread: The thread to use for the agent.
|
||||
frequency_penalty: the frequency penalty to use.
|
||||
logit_bias: the logit bias to use.
|
||||
max_tokens: The maximum number of tokens to generate.
|
||||
metadata: additional metadata to include in the request.
|
||||
model: The model to use for the agent.
|
||||
presence_penalty: the presence penalty to use.
|
||||
response_format: the format of the response.
|
||||
seed: the random seed to use.
|
||||
stop: the stop sequence(s) for the request.
|
||||
store: whether to store the response.
|
||||
temperature: the sampling temperature to use.
|
||||
tool_choice: the tool choice for the request.
|
||||
tools: the tools to use for the request.
|
||||
top_p: the nucleus sampling probability to use.
|
||||
user: the user to associate with the request.
|
||||
additional_properties: additional properties to include in the request.
|
||||
kwargs: any additional keyword arguments.
|
||||
will only be passed to functions that are called.
|
||||
|
||||
"""
|
||||
thread, thread_messages = await self._prepare_thread_and_messages(
|
||||
thread=thread,
|
||||
input_messages=messages,
|
||||
@@ -288,9 +524,29 @@ class ChatClientAgent(AgentBase):
|
||||
|
||||
response_updates: list[ChatResponseUpdate] = []
|
||||
|
||||
streaming_response: AsyncIterable[ChatResponseUpdate] = self.chat_client.get_streaming_response(thread_messages)
|
||||
|
||||
async for update in streaming_response:
|
||||
async for update in self.chat_client.get_streaming_response(
|
||||
messages=thread_messages,
|
||||
chat_options=self.chat_options
|
||||
& ChatOptions(
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
max_tokens=max_tokens,
|
||||
metadata=metadata,
|
||||
ai_model_id=model,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
store=store,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools, # type: ignore
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
additional_properties=additional_properties or {},
|
||||
),
|
||||
**kwargs,
|
||||
):
|
||||
response_updates.append(update)
|
||||
yield AgentRunResponseUpdate(
|
||||
contents=update.contents,
|
||||
@@ -348,7 +604,7 @@ class ChatClientAgent(AgentBase):
|
||||
self,
|
||||
*,
|
||||
thread: AgentThread | None,
|
||||
input_messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None,
|
||||
input_messages: str | ChatMessage | Sequence[str] | Sequence[ChatMessage] | None = None,
|
||||
construct_thread: Callable[[], TThreadType],
|
||||
expected_type: type[TThreadType],
|
||||
) -> tuple[TThreadType, list[ChatMessage]]:
|
||||
@@ -367,6 +623,8 @@ class ChatClientAgent(AgentBase):
|
||||
AgentExecutionException: If thread type is incompatible.
|
||||
"""
|
||||
messages: list[ChatMessage] = []
|
||||
if self.instructions:
|
||||
messages.append(ChatMessage(role=ChatRole.SYSTEM, text=self.instructions))
|
||||
|
||||
if thread is None:
|
||||
thread = construct_thread()
|
||||
@@ -380,17 +638,15 @@ class ChatClientAgent(AgentBase):
|
||||
if isinstance(thread, MessagesRetrievableThread):
|
||||
async for message in thread.get_messages():
|
||||
messages.append(message)
|
||||
|
||||
if input_messages is None:
|
||||
input_messages = []
|
||||
|
||||
if isinstance(input_messages, (str, ChatMessage)):
|
||||
input_messages = [input_messages]
|
||||
|
||||
normalized_messages = [
|
||||
return thread, messages
|
||||
if isinstance(input_messages, str):
|
||||
messages.append(ChatMessage(role=ChatRole.USER, text=input_messages))
|
||||
return thread, messages
|
||||
if isinstance(input_messages, ChatMessage):
|
||||
messages.append(input_messages)
|
||||
return thread, messages
|
||||
messages.extend([
|
||||
ChatMessage(role=ChatRole.USER, text=msg) if isinstance(msg, str) else msg for msg in input_messages
|
||||
]
|
||||
|
||||
messages.extend(normalized_messages)
|
||||
|
||||
])
|
||||
return thread, messages
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, MutableSequence, Sequence
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, MutableMapping, MutableSequence, Sequence
|
||||
from functools import wraps
|
||||
from typing import Annotated, Any, Generic, Literal, Protocol, TypeVar, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, PrivateAttr, StringConstraints
|
||||
from pydantic import BaseModel, StringConstraints
|
||||
|
||||
from ._logging import get_logger
|
||||
from ._pydantic import AFBaseModel
|
||||
@@ -37,24 +37,6 @@ logger = get_logger()
|
||||
# region: Tool Calling Functions and Decorators
|
||||
|
||||
|
||||
def _merge_function_results(
|
||||
messages: list[ChatMessage],
|
||||
) -> ChatMessage:
|
||||
"""Combine multiple function result content types to one chat message content type.
|
||||
|
||||
This method combines the FunctionResultContent items from separate ChatMessageContent messages,
|
||||
and is used in the event that the `context.terminate = True` condition is met.
|
||||
"""
|
||||
contents: list[Any] = []
|
||||
for message in messages:
|
||||
contents.extend([item for item in message.contents if isinstance(item, FunctionResultContent)])
|
||||
|
||||
return ChatMessage(
|
||||
role="tool",
|
||||
contents=contents,
|
||||
)
|
||||
|
||||
|
||||
async def _auto_invoke_function(
|
||||
function_call_content: FunctionCallContent,
|
||||
custom_args: dict[str, Any] | None = None,
|
||||
@@ -106,7 +88,7 @@ def _prepare_tools_and_tool_choice(chat_options: ChatOptions) -> None:
|
||||
chat_options.tool_choice = ChatToolMode.NONE.mode
|
||||
return
|
||||
chat_options.tools = [
|
||||
(_tool_to_json_schema_spec(t) if isinstance(t, AITool) else t) for t in chat_options.tools or []
|
||||
(_tool_to_json_schema_spec(t) if isinstance(t, AITool) else t) for t in chat_options._ai_tools or []
|
||||
]
|
||||
if not chat_options.tools:
|
||||
chat_options.tool_choice = ChatToolMode.NONE.mode
|
||||
@@ -115,11 +97,7 @@ def _prepare_tools_and_tool_choice(chat_options: ChatOptions) -> None:
|
||||
|
||||
|
||||
def _tool_call_non_streaming(func: TInnerGetResponse) -> TInnerGetResponse:
|
||||
"""Decorate the internal _inner_get_response method to enable tool calls.
|
||||
|
||||
Remarks:
|
||||
Relies on a class that has the _tool_map attribute for the executable tools to call.
|
||||
"""
|
||||
"""Decorate the internal _inner_get_response method to enable tool calls."""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(
|
||||
@@ -131,7 +109,7 @@ def _tool_call_non_streaming(func: TInnerGetResponse) -> TInnerGetResponse:
|
||||
) -> ChatResponse:
|
||||
response: ChatResponse | None = None
|
||||
fcc_messages: list[ChatMessage] = []
|
||||
for attempt_idx in range(self.maximum_iterations_per_request):
|
||||
for attempt_idx in range(getattr(self, "__maximum_iterations_per_request", 10)):
|
||||
response = await func(self, messages=messages, chat_options=chat_options)
|
||||
# if there are function calls, we will handle them first
|
||||
function_calls = [it for it in response.messages[0].contents if isinstance(it, FunctionCallContent)]
|
||||
@@ -141,7 +119,7 @@ def _tool_call_non_streaming(func: TInnerGetResponse) -> TInnerGetResponse:
|
||||
_auto_invoke_function(
|
||||
function_call,
|
||||
custom_args=kwargs,
|
||||
tool_map=self._tool_map,
|
||||
tool_map={t.name: t for t in chat_options._ai_tools or [] if isinstance(t, AIFunction)},
|
||||
sequence_index=seq_idx,
|
||||
request_index=attempt_idx,
|
||||
)
|
||||
@@ -181,11 +159,7 @@ def _tool_call_non_streaming(func: TInnerGetResponse) -> TInnerGetResponse:
|
||||
|
||||
|
||||
def _tool_call_streaming(func: TInnerGetStreamingResponse) -> TInnerGetStreamingResponse:
|
||||
"""Decorate the internal _inner_get_response method to enable tool calls.
|
||||
|
||||
Remarks:
|
||||
Relies on a class that has the _tool_map attribute for the executable tools to call.
|
||||
"""
|
||||
"""Decorate the internal _inner_get_response method to enable tool calls."""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(
|
||||
@@ -195,7 +169,8 @@ def _tool_call_streaming(func: TInnerGetStreamingResponse) -> TInnerGetStreaming
|
||||
chat_options: ChatOptions,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[ChatResponseUpdate]:
|
||||
for attempt_idx in range(self.maximum_iterations_per_request):
|
||||
"""Wrap the inner get streaming response method to handle tool calls."""
|
||||
for attempt_idx in range(getattr(self, "__maximum_iterations_per_request", 10)):
|
||||
function_call_returned = False
|
||||
all_messages: list[ChatResponseUpdate] = []
|
||||
async for update in func(self, messages=messages, chat_options=chat_options):
|
||||
@@ -221,7 +196,7 @@ def _tool_call_streaming(func: TInnerGetStreamingResponse) -> TInnerGetStreaming
|
||||
_auto_invoke_function(
|
||||
function_call,
|
||||
custom_args=kwargs,
|
||||
tool_map=self._tool_map,
|
||||
tool_map={t.name: t for t in chat_options._ai_tools or [] if isinstance(t, AIFunction)},
|
||||
sequence_index=seq_idx,
|
||||
request_index=attempt_idx,
|
||||
)
|
||||
@@ -247,10 +222,25 @@ def use_tool_calling(cls: type[TChatClientBase]) -> type[TChatClientBase]:
|
||||
|
||||
Remarks:
|
||||
This only works on classes that derive from ChatClientBase
|
||||
and have the _tool_map attribute as well as the _inner_get_response
|
||||
and the _inner_get_response
|
||||
and _inner_get_streaming_response methods.
|
||||
It also sets a __maximum_iterations_per_request attribute on the class.
|
||||
if you want to expose this to end_users, do a version of this:
|
||||
|
||||
```python
|
||||
@use_tool_calling
|
||||
class MyChatClient(ChatClientBase):
|
||||
@property
|
||||
def maximum_iterations_per_request(self):
|
||||
return getattr(self, "__maximum_iterations_per_request", 10)
|
||||
|
||||
@maximum_iterations_per_request.setter
|
||||
def maximum_iterations_per_request(self, value: int) -> None:
|
||||
setattr(self, "__maximum_iterations_per_request", value)
|
||||
```
|
||||
"""
|
||||
setattr(cls, "__maximum_iterations_per_request", 10)
|
||||
|
||||
if inner_response := getattr(cls, "_inner_get_response", None):
|
||||
cls._inner_get_response = _tool_call_non_streaming(inner_response) # type: ignore
|
||||
if inner_streaming_response := getattr(cls, "_inner_get_streaming_response", None):
|
||||
@@ -267,15 +257,54 @@ class ChatClient(Protocol):
|
||||
|
||||
async def get_response(
|
||||
self,
|
||||
messages: str | ChatMessage | list[ChatMessage],
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage],
|
||||
*,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str | int, float] | None = None,
|
||||
max_tokens: int | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model: str | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | Sequence[str] | None = None,
|
||||
store: bool | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
|
||||
tools: AITool
|
||||
| list[AITool]
|
||||
| Callable[..., Any]
|
||||
| list[Callable[..., Any]]
|
||||
| MutableMapping[str, Any]
|
||||
| list[MutableMapping[str, Any]]
|
||||
| None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
"""Sends input and returns the response.
|
||||
|
||||
Args:
|
||||
messages: The sequence of input messages to send.
|
||||
**kwargs: Additional options for the request, such as ai_model_id, temperature, etc.
|
||||
See `ChatOptions` for more details.
|
||||
frequency_penalty: the frequency penalty to use.
|
||||
logit_bias: the logit bias to use.
|
||||
max_tokens: The maximum number of tokens to generate.
|
||||
metadata: additional metadata to include in the request.
|
||||
model: The model to use for the agent.
|
||||
presence_penalty: the presence penalty to use.
|
||||
response_format: the format of the response.
|
||||
seed: the random seed to use.
|
||||
stop: the stop sequence(s) for the request.
|
||||
store: whether to store the response.
|
||||
temperature: the sampling temperature to use.
|
||||
tool_choice: the tool choice for the request.
|
||||
tools: the tools to use for the request.
|
||||
top_p: the nucleus sampling probability to use.
|
||||
user: the user to associate with the request.
|
||||
additional_properties: additional properties to include in the request
|
||||
kwargs: any additional keyword arguments,
|
||||
will only be passed to functions that are called.
|
||||
|
||||
Returns:
|
||||
The response messages generated by the client.
|
||||
@@ -287,15 +316,54 @@ class ChatClient(Protocol):
|
||||
|
||||
def get_streaming_response(
|
||||
self,
|
||||
messages: str | ChatMessage | list[ChatMessage],
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage],
|
||||
*,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str | int, float] | None = None,
|
||||
max_tokens: int | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model: str | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | Sequence[str] | None = None,
|
||||
store: bool | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
|
||||
tools: AITool
|
||||
| list[AITool]
|
||||
| Callable[..., Any]
|
||||
| list[Callable[..., Any]]
|
||||
| MutableMapping[str, Any]
|
||||
| list[MutableMapping[str, Any]]
|
||||
| None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[ChatResponseUpdate]:
|
||||
"""Sends input messages and streams the response.
|
||||
|
||||
Args:
|
||||
messages: The sequence of input messages to send.
|
||||
**kwargs: Additional options for the request, such as ai_model_id, temperature, etc.
|
||||
See `ChatOptions` for more details.
|
||||
frequency_penalty: the frequency penalty to use.
|
||||
logit_bias: the logit bias to use.
|
||||
max_tokens: The maximum number of tokens to generate.
|
||||
metadata: additional metadata to include in the request.
|
||||
model: The model to use for the agent.
|
||||
presence_penalty: the presence penalty to use.
|
||||
response_format: the format of the response.
|
||||
seed: the random seed to use.
|
||||
stop: the stop sequence(s) for the request.
|
||||
store: whether to store the response.
|
||||
temperature: the sampling temperature to use.
|
||||
tool_choice: the tool choice for the request.
|
||||
tools: the tools to use for the request.
|
||||
top_p: the nucleus sampling probability to use.
|
||||
user: the user to associate with the request.
|
||||
additional_properties: additional properties to include in the request
|
||||
kwargs: any additional keyword arguments,
|
||||
will only be passed to functions that are called.
|
||||
|
||||
Yields:
|
||||
An async iterable of chat response updates containing the content of the response messages
|
||||
@@ -311,19 +379,21 @@ class ChatClientBase(AFBaseModel, ABC):
|
||||
"""Base class for chat clients."""
|
||||
|
||||
ai_model_id: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1)]
|
||||
maximum_iterations_per_request: int = 10
|
||||
_tool_map: dict[str, AIFunction[BaseModel, Any]] = PrivateAttr(default_factory=dict) # type: ignore
|
||||
|
||||
def _prepare_messages(self, messages: str | ChatMessage | list[str | ChatMessage]) -> MutableSequence[ChatMessage]:
|
||||
def _prepare_messages(
|
||||
self, messages: str | ChatMessage | list[str] | list[ChatMessage]
|
||||
) -> MutableSequence[ChatMessage]:
|
||||
"""Turn the allowed input into a list of chat messages."""
|
||||
if isinstance(messages, str):
|
||||
messages = [ChatMessage(role="user", text=messages)]
|
||||
return [ChatMessage(role="user", text=messages)]
|
||||
if isinstance(messages, ChatMessage):
|
||||
messages = [messages]
|
||||
for i, msg in enumerate(messages):
|
||||
return [messages]
|
||||
return_messages: list[ChatMessage] = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, str):
|
||||
messages[i] = ChatMessage(role="user", text=msg)
|
||||
return messages # type: ignore[return-value]
|
||||
msg = ChatMessage(role="user", text=msg)
|
||||
return_messages.append(msg)
|
||||
return return_messages
|
||||
|
||||
# region Internal methods to be implemented by the derived classes
|
||||
|
||||
@@ -377,23 +447,29 @@ class ChatClientBase(AFBaseModel, ABC):
|
||||
|
||||
async def get_response(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str | ChatMessage],
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage],
|
||||
*,
|
||||
model: str | None = None,
|
||||
max_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
|
||||
tools: AITool | Sequence[AITool] | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
user: str | None = None,
|
||||
stop: str | Sequence[str] | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str | int, float] | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
store: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model: str | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | Sequence[str] | None = None,
|
||||
store: bool | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
|
||||
tools: AITool
|
||||
| list[AITool]
|
||||
| Callable[..., Any]
|
||||
| list[Callable[..., Any]]
|
||||
| MutableMapping[str, Any]
|
||||
| list[MutableMapping[str, Any]]
|
||||
| None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
@@ -401,73 +477,80 @@ class ChatClientBase(AFBaseModel, ABC):
|
||||
|
||||
Args:
|
||||
messages: the message or messages to send to the model
|
||||
model: the model to use for the request
|
||||
max_tokens: the maximum number of tokens to generate
|
||||
temperature: the sampling temperature to use
|
||||
top_p: the nucleus sampling probability to use
|
||||
tool_choice: the tool choice for the request
|
||||
tools: the tools to use for the request
|
||||
response_format: the format of the response
|
||||
user: the user to associate with the request
|
||||
stop: the stop sequence(s) for the request
|
||||
frequency_penalty: the frequency penalty to use
|
||||
logit_bias: the logit bias to use
|
||||
presence_penalty: the presence penalty to use
|
||||
seed: the random seed to use
|
||||
store: whether to store the response
|
||||
metadata: additional metadata to include in the request
|
||||
additional_properties: additional properties to include in the request
|
||||
frequency_penalty: the frequency penalty to use.
|
||||
logit_bias: the logit bias to use.
|
||||
max_tokens: The maximum number of tokens to generate.
|
||||
metadata: additional metadata to include in the request.
|
||||
model: The model to use for the agent.
|
||||
presence_penalty: the presence penalty to use.
|
||||
response_format: the format of the response.
|
||||
seed: the random seed to use.
|
||||
stop: the stop sequence(s) for the request.
|
||||
store: whether to store the response.
|
||||
temperature: the sampling temperature to use.
|
||||
tool_choice: the tool choice for the request.
|
||||
tools: the tools to use for the request.
|
||||
top_p: the nucleus sampling probability to use.
|
||||
user: the user to associate with the request.
|
||||
additional_properties: additional properties to include in the request.
|
||||
kwargs: any additional keyword arguments,
|
||||
will only be passed to functions that are called.
|
||||
|
||||
Returns:
|
||||
A chat response from the model.
|
||||
"""
|
||||
if tools is not None:
|
||||
if not isinstance(tools, Sequence):
|
||||
tools = [tools]
|
||||
self._tool_map = {tool.name: tool for tool in tools if isinstance(tool, AIFunction)}
|
||||
chat_options = ChatOptions(
|
||||
ai_model_id=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
response_format=response_format,
|
||||
user=user,
|
||||
stop=stop,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
store=store,
|
||||
metadata=metadata,
|
||||
additional_properties=additional_properties or {},
|
||||
)
|
||||
if "chat_options" in kwargs:
|
||||
chat_options = kwargs.pop("chat_options")
|
||||
if not isinstance(chat_options, ChatOptions):
|
||||
raise TypeError("chat_options must be an instance of ChatOptions")
|
||||
else:
|
||||
chat_options = ChatOptions(
|
||||
ai_model_id=model,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
max_tokens=max_tokens,
|
||||
metadata=metadata,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
store=store,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools, # type: ignore
|
||||
user=user,
|
||||
additional_properties=additional_properties or {},
|
||||
)
|
||||
prepped_messages = self._prepare_messages(messages)
|
||||
_prepare_tools_and_tool_choice(chat_options=chat_options)
|
||||
return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **kwargs)
|
||||
|
||||
async def get_streaming_response(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str | ChatMessage],
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage],
|
||||
*,
|
||||
model: str | None = None,
|
||||
max_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
|
||||
tools: AITool | Sequence[AITool] | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
user: str | None = None,
|
||||
stop: str | Sequence[str] | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str | int, float] | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
store: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model: str | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | Sequence[str] | None = None,
|
||||
store: bool | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
|
||||
tools: AITool
|
||||
| list[AITool]
|
||||
| Callable[..., Any]
|
||||
| list[Callable[..., Any]]
|
||||
| MutableMapping[str, Any]
|
||||
| list[MutableMapping[str, Any]]
|
||||
| None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[ChatResponseUpdate]:
|
||||
@@ -475,50 +558,51 @@ class ChatClientBase(AFBaseModel, ABC):
|
||||
|
||||
Args:
|
||||
messages: the message or messages to send to the model
|
||||
model: the model to use for the request
|
||||
max_tokens: the maximum number of tokens to generate
|
||||
temperature: the sampling temperature to use
|
||||
top_p: the nucleus sampling probability to use
|
||||
tool_choice: the tool choice for the request
|
||||
tools: the tools to use for the request
|
||||
response_format: the format of the response
|
||||
user: the user to associate with the request
|
||||
stop: the stop sequence(s) for the request
|
||||
frequency_penalty: the frequency penalty to use
|
||||
logit_bias: the logit bias to use
|
||||
presence_penalty: the presence penalty to use
|
||||
seed: the random seed to use
|
||||
store: whether to store the response
|
||||
metadata: additional metadata to include in the request
|
||||
max_tokens: The maximum number of tokens to generate.
|
||||
metadata: additional metadata to include in the request.
|
||||
model: The model to use for the agent.
|
||||
presence_penalty: the presence penalty to use.
|
||||
response_format: the format of the response.
|
||||
seed: the random seed to use.
|
||||
stop: the stop sequence(s) for the request.
|
||||
store: whether to store the response.
|
||||
temperature: the sampling temperature to use.
|
||||
tool_choice: the tool choice for the request.
|
||||
tools: the tools to use for the request.
|
||||
top_p: the nucleus sampling probability to use.
|
||||
user: the user to associate with the request.
|
||||
additional_properties: additional properties to include in the request
|
||||
kwargs: any additional keyword arguments
|
||||
|
||||
Yields:
|
||||
A stream representing the response(s) from the LLM.
|
||||
"""
|
||||
if tools is not None:
|
||||
if not isinstance(tools, Sequence):
|
||||
tools = [tools]
|
||||
self._tool_map = {tool.name: tool for tool in tools if isinstance(tool, AIFunction)}
|
||||
chat_options = ChatOptions(
|
||||
ai_model_id=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
response_format=response_format,
|
||||
user=user,
|
||||
stop=stop,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
store=store,
|
||||
metadata=metadata,
|
||||
additional_properties=additional_properties or {},
|
||||
**kwargs,
|
||||
)
|
||||
if "chat_options" in kwargs:
|
||||
chat_options = kwargs.pop("chat_options")
|
||||
if not isinstance(chat_options, ChatOptions):
|
||||
raise TypeError("chat_options must be an instance of ChatOptions")
|
||||
else:
|
||||
chat_options = ChatOptions(
|
||||
ai_model_id=model,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
max_tokens=max_tokens,
|
||||
metadata=metadata,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
store=store,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools, # type: ignore
|
||||
user=user,
|
||||
additional_properties=additional_properties or {},
|
||||
**kwargs,
|
||||
)
|
||||
prepped_messages = self._prepare_messages(messages)
|
||||
_prepare_tools_and_tool_choice(chat_options=chat_options)
|
||||
async for update in self._inner_get_streaming_response(
|
||||
|
||||
@@ -10,7 +10,16 @@ from pydantic import BaseModel, create_model
|
||||
|
||||
@runtime_checkable
|
||||
class AITool(Protocol):
|
||||
"""Represents a generic tool that can be specified to an AI service."""
|
||||
"""Represents a generic tool that can be specified to an AI service.
|
||||
|
||||
Attributes:
|
||||
name: The name of the tool.
|
||||
description: A description of the tool.
|
||||
additional_properties: Additional properties associated with the tool.
|
||||
|
||||
Methods:
|
||||
parameters: The parameters accepted by the tool, in a json schema format.
|
||||
"""
|
||||
|
||||
name: str
|
||||
"""The name of the tool."""
|
||||
|
||||
@@ -4,13 +4,22 @@ import base64
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from collections.abc import AsyncIterable, Iterable, Iterator, Mapping, MutableMapping, MutableSequence, Sequence
|
||||
from collections.abc import (
|
||||
AsyncIterable,
|
||||
Callable,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
MutableSequence,
|
||||
Sequence,
|
||||
)
|
||||
from typing import Annotated, Any, ClassVar, Generic, Literal, TypeVar, overload
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, ValidationError, field_validator, model_validator
|
||||
|
||||
from ._pydantic import AFBaseModel
|
||||
from ._tools import AITool
|
||||
from ._tools import AITool, ai_function
|
||||
from .exceptions import AgentFrameworkException
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
@@ -1399,34 +1408,69 @@ class ChatOptions(AFBaseModel):
|
||||
"""Common request settings for AI services."""
|
||||
|
||||
ai_model_id: Annotated[str | None, Field(serialization_alias="model")] = None
|
||||
frequency_penalty: Annotated[float | None, Field(ge=-2.0, le=2.0)] = None
|
||||
logit_bias: MutableMapping[str | int, float] | None = None
|
||||
max_tokens: Annotated[int | None, Field(gt=0)] = None
|
||||
temperature: Annotated[float | None, Field(ge=0.0, le=2.0)] = None
|
||||
top_p: Annotated[float | None, Field(ge=0.0, le=1.0)] = None
|
||||
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | Mapping[str, Any] | None = None
|
||||
tools: Sequence[AITool] | Sequence[MutableMapping[str, Any]] | None = None
|
||||
metadata: MutableMapping[str, str] | None = None
|
||||
presence_penalty: Annotated[float | None, Field(ge=-2.0, le=2.0)] = None
|
||||
response_format: type[BaseModel] | None = Field(
|
||||
default=None, description="Structured output response format schema. Must be a valid Pydantic model."
|
||||
)
|
||||
user: str | None = None
|
||||
stop: str | Sequence[str] | None = None
|
||||
frequency_penalty: Annotated[float | None, Field(ge=-2.0, le=2.0)] = None
|
||||
logit_bias: MutableMapping[str | int, float] | None = None
|
||||
presence_penalty: Annotated[float | None, Field(ge=-2.0, le=2.0)] = None
|
||||
seed: int | None = None
|
||||
stop: str | Sequence[str] | None = None
|
||||
store: bool | None = None
|
||||
metadata: MutableMapping[str, str] | None = None
|
||||
temperature: Annotated[float | None, Field(ge=0.0, le=2.0)] = None
|
||||
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | Mapping[str, Any] | None = None
|
||||
tools: list[AITool | MutableMapping[str, Any]] | None = None
|
||||
_ai_tools: list[AITool | MutableMapping[str, Any]] | None = PrivateAttr(default=None)
|
||||
top_p: Annotated[float | None, Field(ge=0.0, le=1.0)] = None
|
||||
user: str | None = None
|
||||
additional_properties: MutableMapping[str, Any] = Field(
|
||||
default_factory=dict, description="Provider-specific additional properties."
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _copy_to_ai_tools(self) -> Self:
|
||||
if self.tools and not self._ai_tools:
|
||||
self._ai_tools = self.tools
|
||||
return self
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
def _validate_tools(
|
||||
cls,
|
||||
tools: (
|
||||
AITool
|
||||
| list[AITool]
|
||||
| Callable[..., Any]
|
||||
| list[Callable[..., Any]]
|
||||
| MutableMapping[str, Any]
|
||||
| list[MutableMapping[str, Any]]
|
||||
| None
|
||||
),
|
||||
) -> list[AITool | MutableMapping[str, Any]] | None:
|
||||
"""Parse the tools field.
|
||||
|
||||
All tools are stored in both tools and _ai_tools.
|
||||
"""
|
||||
if not tools:
|
||||
return None
|
||||
if not isinstance(tools, list):
|
||||
tools = [tools] # type: ignore[reportAssignmentType, assignment]
|
||||
for idx, tool in enumerate(tools): # type: ignore[reportArgumentType, arg-type]
|
||||
if not isinstance(tool, (AITool, MutableMapping)):
|
||||
# Convert to AITool if it's a function or callable
|
||||
tools[idx] = ai_function(tool) # type: ignore[reportIndexIssues, reportCallIssue, reportArgumentType, index, call-overload, arg-type]
|
||||
return tools # type: ignore[reportReturnType, return-value]
|
||||
|
||||
@field_validator("tool_choice", mode="before")
|
||||
@classmethod
|
||||
def _validate_tool_mode(
|
||||
cls, tool_choice: ChatToolMode | Literal["auto", "required", "none"] | Mapping[str, Any] | None
|
||||
) -> ChatToolMode:
|
||||
) -> ChatToolMode | None:
|
||||
"""Validates the tool_choice field to ensure it is a valid ChatToolMode."""
|
||||
if not tool_choice:
|
||||
return ChatToolMode.NONE
|
||||
return None
|
||||
if isinstance(tool_choice, str):
|
||||
match tool_choice:
|
||||
case "auto":
|
||||
@@ -1461,6 +1505,32 @@ class ChatOptions(AFBaseModel):
|
||||
settings.pop(key, None)
|
||||
return settings
|
||||
|
||||
def __and__(self, other: object) -> Self:
|
||||
"""Combines two ChatOptions instances.
|
||||
|
||||
The values from the other ChatOptions take precedence.
|
||||
List and dicts are combined.
|
||||
"""
|
||||
if not isinstance(other, ChatOptions):
|
||||
return self
|
||||
ai_tools = other._ai_tools
|
||||
updated_values = other.model_dump(exclude_none=True)
|
||||
updated_values.pop("tools", [])
|
||||
logit_bias = updated_values.pop("logit_bias", {})
|
||||
metadata = updated_values.pop("metadata", {})
|
||||
additional_properties = updated_values.pop("additional_properties", {})
|
||||
combined = self.model_copy(update=updated_values)
|
||||
if ai_tools:
|
||||
if not combined._ai_tools:
|
||||
combined._ai_tools = []
|
||||
for tool in ai_tools:
|
||||
if tool not in combined._ai_tools:
|
||||
combined._ai_tools.append(tool)
|
||||
combined.logit_bias = {**(combined.logit_bias or {}), **logit_bias}
|
||||
combined.metadata = {**(combined.metadata or {}), **metadata}
|
||||
combined.additional_properties = {**(combined.additional_properties or {}), **additional_properties}
|
||||
return combined
|
||||
|
||||
|
||||
# region: GeneratedEmbeddings
|
||||
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pytest import fixture
|
||||
|
||||
from agent_framework import AITool, ai_function
|
||||
|
||||
|
||||
@fixture
|
||||
def ai_tool() -> AITool:
|
||||
"""Returns a generic AITool."""
|
||||
|
||||
class GenericTool(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
additional_properties: dict[str, Any] | None = None
|
||||
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
"""Return the parameters of the tool as a JSON schema."""
|
||||
return {
|
||||
"name": {"type": "string"},
|
||||
}
|
||||
|
||||
return GenericTool(name="generic_tool", description="A generic tool")
|
||||
|
||||
|
||||
@fixture
|
||||
def ai_function_tool() -> AITool:
|
||||
"""Returns a executable AITool."""
|
||||
|
||||
@ai_function
|
||||
def simple_function(x: int, y: int) -> int:
|
||||
"""A simple function that adds two numbers."""
|
||||
return x + y
|
||||
|
||||
return simple_function
|
||||
@@ -209,7 +209,7 @@ async def test_base_client_with_function_calling(chat_client_base: MockChatClien
|
||||
|
||||
|
||||
async def test_base_client_with_function_calling_disabled(chat_client_base: MockChatClientBase):
|
||||
chat_client_base.maximum_iterations_per_request = 0
|
||||
chat_client_base.__maximum_iterations_per_request = 0
|
||||
exec_counter = 0
|
||||
|
||||
@ai_function(name="test_function")
|
||||
@@ -273,7 +273,7 @@ async def test_base_client_with_streaming_function_calling(chat_client_base: Moc
|
||||
|
||||
|
||||
async def test_base_client_with_streaming_function_calling_disabled(chat_client_base: MockChatClientBase):
|
||||
chat_client_base.maximum_iterations_per_request = 0
|
||||
chat_client_base.__maximum_iterations_per_request = 0
|
||||
exec_counter = 0
|
||||
|
||||
@ai_function(name="test_function")
|
||||
|
||||
@@ -10,7 +10,9 @@ from agent_framework import (
|
||||
AgentRunResponseUpdate,
|
||||
AIContent,
|
||||
AIContents,
|
||||
AITool,
|
||||
ChatMessage,
|
||||
ChatOptions,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
ChatRole,
|
||||
@@ -492,6 +494,99 @@ def test_generated_embeddings():
|
||||
assert issubclass(GeneratedEmbeddings, MutableSequence)
|
||||
|
||||
|
||||
# region: ChatOptions
|
||||
|
||||
|
||||
def test_chat_options_init() -> None:
|
||||
options = ChatOptions()
|
||||
assert options.ai_model_id is None
|
||||
|
||||
|
||||
def test_chat_options_init_with_args(ai_function_tool, ai_tool) -> None:
|
||||
options = ChatOptions(
|
||||
ai_model_id="gpt-4",
|
||||
max_tokens=1024,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
presence_penalty=0.0,
|
||||
frequency_penalty=0.0,
|
||||
user="user-123",
|
||||
tools=[ai_function_tool, ai_tool],
|
||||
)
|
||||
assert options.ai_model_id == "gpt-4"
|
||||
assert options.max_tokens == 1024
|
||||
assert options.temperature == 0.7
|
||||
assert options.top_p == 0.9
|
||||
assert options.presence_penalty == 0.0
|
||||
assert options.frequency_penalty == 0.0
|
||||
assert options.user == "user-123"
|
||||
for tool in options._ai_tools:
|
||||
assert isinstance(tool, AITool)
|
||||
assert tool.name is not None
|
||||
assert tool.description is not None
|
||||
assert tool.parameters() is not None
|
||||
|
||||
|
||||
def test_chat_options_and(ai_function_tool, ai_tool) -> None:
|
||||
options1 = ChatOptions(ai_model_id="gpt-4o", tools=[ai_function_tool])
|
||||
options2 = ChatOptions(ai_model_id="gpt-4.1", tools=[ai_tool])
|
||||
assert options1 != options2
|
||||
options3 = options1 & options2
|
||||
|
||||
assert options3.ai_model_id == "gpt-4.1"
|
||||
assert len(options3._ai_tools) == 2
|
||||
assert options3._ai_tools == [ai_function_tool, ai_tool]
|
||||
assert options3.tools == [ai_function_tool, ai_tool]
|
||||
|
||||
|
||||
def test_chat_options_parsing_tools(ai_function_tool, ai_tool) -> None:
|
||||
from agent_framework._clients import _prepare_tools_and_tool_choice
|
||||
|
||||
def echo() -> str:
|
||||
"""Echo the input."""
|
||||
return "Echo"
|
||||
|
||||
dict_function = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Retrieves current weather for the given location.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "City and country e.g. Bogotá, Colombia"},
|
||||
"units": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "Units the temperature will be returned in.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "units"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"strict": True,
|
||||
},
|
||||
}
|
||||
|
||||
options = ChatOptions(tools=[ai_function_tool, ai_tool, echo, dict_function], tool_choice="auto")
|
||||
assert len(options.tools) == 4
|
||||
assert options.tools[0] == ai_function_tool
|
||||
assert options.tools[1] == ai_tool
|
||||
assert options.tools[2] != echo
|
||||
assert options.tools[3] == dict_function
|
||||
# after prepare, the tools should be represented as dicts
|
||||
# while ai_tools is still the same.
|
||||
_prepare_tools_and_tool_choice(options)
|
||||
assert options._ai_tools[0] == ai_function_tool
|
||||
assert options._ai_tools[1] == ai_tool
|
||||
assert options._ai_tools[3] == dict_function
|
||||
assert len(options.tools) == 4
|
||||
assert options.tools[0]["function"]["name"] == "simple_function"
|
||||
assert options.tools[1]["function"]["name"] == "generic_tool"
|
||||
assert options.tools[2]["function"]["name"] == "echo"
|
||||
assert options.tools[3]["function"]["name"] == "get_weather"
|
||||
|
||||
|
||||
# region Agent Response Fixtures
|
||||
|
||||
|
||||
@@ -548,7 +643,7 @@ def test_agent_run_response_from_updates(agent_run_response_update: AgentRunResp
|
||||
updates = [agent_run_response_update, agent_run_response_update]
|
||||
response = AgentRunResponse.from_agent_run_response_updates(updates)
|
||||
assert len(response.messages) > 0
|
||||
assert response.text == "Test content\nTest content"
|
||||
assert response.text == "Test content Test content"
|
||||
|
||||
|
||||
def test_agent_run_response_str_method(chat_message: ChatMessage) -> None:
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from random import randint
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import ChatClientAgent
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
def get_weather(
|
||||
location: Annotated[str, Field(description="The location to get the weather for.")],
|
||||
) -> str:
|
||||
"""Get the weather for a given location."""
|
||||
conditions = ["sunny", "cloudy", "rainy", "stormy"]
|
||||
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
|
||||
|
||||
|
||||
async def main():
|
||||
instructions = "You are a helpful assistant, you can help the user with weather information."
|
||||
agent = ChatClientAgent(OpenAIChatClient(), instructions=instructions, tools=get_weather)
|
||||
print(str(await agent.run("What's the weather in Amsterdam?")))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user