Python: simple Agent sample (#180)

* tweaks to agents and sample

* updated clients and agents

* single line run and print

* improved tool handling

* added note on setting max iterations

* fixed streaming param name

* updated tools test

* made kwargs alphabetical

* added params to run methods

* tweak to ensure right overload
This commit is contained in:
Eduard van Valkenburg
2025-07-15 16:01:21 +02:00
committed by GitHub
Unverified
parent 59bf7343af
commit 62917ee5e5
8 changed files with 773 additions and 194 deletions
+281 -25
View File
@@ -1,15 +1,25 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import AsyncIterable, Callable, Sequence
from collections.abc import AsyncIterable, Callable, MutableMapping, Sequence
from enum import Enum
from typing import Any, Protocol, TypeVar, runtime_checkable
from typing import Any, Literal, Protocol, TypeVar, runtime_checkable
from uuid import uuid4
from pydantic import Field
from pydantic import BaseModel, Field
from ._clients import ChatClient
from ._pydantic import AFBaseModel
from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage, ChatResponse, ChatResponseUpdate, ChatRole
from ._tools import AITool
from ._types import (
AgentRunResponse,
AgentRunResponseUpdate,
ChatMessage,
ChatOptions,
ChatResponse,
ChatResponseUpdate,
ChatRole,
ChatToolMode,
)
from .exceptions import AgentExecutionException
TThreadType = TypeVar("TThreadType", bound="AgentThread")
@@ -71,7 +81,7 @@ class Agent(Protocol):
async def run(
self,
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
*,
thread: AgentThread | None = None,
**kwargs: Any,
@@ -100,7 +110,7 @@ class Agent(Protocol):
def run_stream(
self,
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
*,
thread: AgentThread | None = None,
**kwargs: Any,
@@ -236,17 +246,168 @@ class ChatClientAgentThread(AgentThread):
class ChatClientAgent(AgentBase):
"""A Chat Client Agent which depends on ChatClient."""
"""A Chat Client Agent."""
chat_client: ChatClient
instructions: str | None = None
chat_options: ChatOptions
def __init__(
self,
chat_client: ChatClient,
instructions: str | None = None,
*,
id: str | None = None,
name: str | None = None,
description: str | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str | int, float] | None = None,
max_tokens: int | None = None,
metadata: dict[str, Any] | None = None,
model: str | None = None,
presence_penalty: float | None = None,
response_format: type[BaseModel] | None = None,
seed: int | None = None,
stop: str | Sequence[str] | None = None,
store: bool | None = None,
temperature: float | None = None,
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
tools: AITool
| list[AITool]
| Callable[..., Any]
| list[Callable[..., Any]]
| MutableMapping[str, Any]
| list[MutableMapping[str, Any]]
| None = None,
top_p: float | None = None,
user: str | None = None,
additional_properties: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
"""Create a ChatClientAgent.
Remarks:
The set of attributes from frequency_penalty to additional_properties are used to
call the chat client, they can also be passed to both run methods.
When both are set, the ones passed to the run methods take precedence.
Args:
chat_client: The chat client to use for the agent.
instructions: Optional instructions for the agent.
These will be put into the messages sent to the chat client service as a system message.
id: The unique identifier for the agent, will be created automatically if not provided.
name: The name of the agent.
description: A brief description of the agent's purpose.
frequency_penalty: the frequency penalty to use.
logit_bias: the logit bias to use.
max_tokens: The maximum number of tokens to generate.
metadata: additional metadata to include in the request.
model: The model to use for the agent.
presence_penalty: the presence penalty to use.
response_format: the format of the response.
seed: the random seed to use.
stop: the stop sequence(s) for the request.
store: whether to store the response.
temperature: the sampling temperature to use.
tool_choice: the tool choice for the request.
tools: the tools to use for the request.
top_p: the nucleus sampling probability to use.
user: the user to associate with the request.
additional_properties: additional properties to include in the request.
kwargs: any additional keyword arguments.
Unused, can be used by subclasses of this Agent.
"""
args: dict[str, Any] = {
"chat_client": chat_client,
"chat_options": ChatOptions(
ai_model_id=model,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
max_tokens=max_tokens,
metadata=metadata,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
store=store,
temperature=temperature,
tool_choice=tool_choice,
tools=tools, # type: ignore
top_p=top_p,
user=user,
additional_properties=additional_properties or {},
),
}
if instructions is not None:
args["instructions"] = instructions
if name is not None:
args["name"] = name
if description is not None:
args["description"] = description
if id is not None:
args["id"] = id
super().__init__(**args)
async def run(
self,
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
*,
thread: AgentThread | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str | int, float] | None = None,
max_tokens: int | None = None,
metadata: dict[str, Any] | None = None,
model: str | None = None,
presence_penalty: float | None = None,
response_format: type[BaseModel] | None = None,
seed: int | None = None,
stop: str | Sequence[str] | None = None,
store: bool | None = None,
temperature: float | None = None,
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None,
tools: AITool
| list[AITool]
| Callable[..., Any]
| list[Callable[..., Any]]
| MutableMapping[str, Any]
| list[MutableMapping[str, Any]]
| None = None,
top_p: float | None = None,
user: str | None = None,
additional_properties: dict[str, Any] | None = None,
**kwargs: Any,
) -> AgentRunResponse:
"""Run the agent with the given messages and options.
Remarks:
Since you won't always call the agent.run directly, but it get's called
through orchestration, it is advised to set your default values for
all the chat client parameters in the agent constructor.
If both parameters are used, the ones passed to the run methods take precedence.
Args:
messages: The messages to process.
thread: The thread to use for the agent.
frequency_penalty: the frequency penalty to use.
logit_bias: the logit bias to use.
max_tokens: The maximum number of tokens to generate.
metadata: additional metadata to include in the request.
model: The model to use for the agent.
presence_penalty: the presence penalty to use.
response_format: the format of the response.
seed: the random seed to use.
stop: the stop sequence(s) for the request.
store: whether to store the response.
temperature: the sampling temperature to use.
tool_choice: the tool choice for the request.
tools: the tools to use for the request.
top_p: the nucleus sampling probability to use.
user: the user to associate with the request.
additional_properties: additional properties to include in the request.
kwargs: Additional keyword arguments for the agent.
will only be passed to functions that are called.
"""
thread, thread_messages = await self._prepare_thread_and_messages(
thread=thread,
input_messages=messages,
@@ -254,7 +415,29 @@ class ChatClientAgent(AgentBase):
expected_type=ChatClientAgentThread,
)
response = await self.chat_client.get_response(thread_messages, **kwargs)
response = await self.chat_client.get_response(
messages=thread_messages,
chat_options=self.chat_options
& ChatOptions(
ai_model_id=model,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
max_tokens=max_tokens,
metadata=metadata,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
store=store,
temperature=temperature,
tool_choice=tool_choice,
tools=tools, # type: ignore
top_p=top_p,
user=user,
additional_properties=additional_properties or {},
),
**kwargs,
)
self._update_thread_with_type_and_conversation_id(thread, response.conversation_id)
@@ -274,11 +457,64 @@ class ChatClientAgent(AgentBase):
async def run_stream(
self,
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
*,
thread: AgentThread | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str | int, float] | None = None,
max_tokens: int | None = None,
metadata: dict[str, Any] | None = None,
model: str | None = None,
presence_penalty: float | None = None,
response_format: type[BaseModel] | None = None,
seed: int | None = None,
stop: str | Sequence[str] | None = None,
store: bool | None = None,
temperature: float | None = None,
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None,
tools: AITool
| list[AITool]
| Callable[..., Any]
| list[Callable[..., Any]]
| MutableMapping[str, Any]
| list[MutableMapping[str, Any]]
| None = None,
top_p: float | None = None,
user: str | None = None,
additional_properties: dict[str, Any] | None = None,
**kwargs: Any,
) -> AsyncIterable[AgentRunResponseUpdate]:
"""Stream the agent with the given messages and options.
Remarks:
Since you won't always call the agent.run_stream directly, but it get's called
through orchestration, it is advised to set your default values for
all the chat client parameters in the agent constructor.
If both parameters are used, the ones passed to the run methods take precedence.
Args:
messages: The messages to process.
thread: The thread to use for the agent.
frequency_penalty: the frequency penalty to use.
logit_bias: the logit bias to use.
max_tokens: The maximum number of tokens to generate.
metadata: additional metadata to include in the request.
model: The model to use for the agent.
presence_penalty: the presence penalty to use.
response_format: the format of the response.
seed: the random seed to use.
stop: the stop sequence(s) for the request.
store: whether to store the response.
temperature: the sampling temperature to use.
tool_choice: the tool choice for the request.
tools: the tools to use for the request.
top_p: the nucleus sampling probability to use.
user: the user to associate with the request.
additional_properties: additional properties to include in the request.
kwargs: any additional keyword arguments.
will only be passed to functions that are called.
"""
thread, thread_messages = await self._prepare_thread_and_messages(
thread=thread,
input_messages=messages,
@@ -288,9 +524,29 @@ class ChatClientAgent(AgentBase):
response_updates: list[ChatResponseUpdate] = []
streaming_response: AsyncIterable[ChatResponseUpdate] = self.chat_client.get_streaming_response(thread_messages)
async for update in streaming_response:
async for update in self.chat_client.get_streaming_response(
messages=thread_messages,
chat_options=self.chat_options
& ChatOptions(
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
max_tokens=max_tokens,
metadata=metadata,
ai_model_id=model,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
store=store,
temperature=temperature,
tool_choice=tool_choice,
tools=tools, # type: ignore
top_p=top_p,
user=user,
additional_properties=additional_properties or {},
),
**kwargs,
):
response_updates.append(update)
yield AgentRunResponseUpdate(
contents=update.contents,
@@ -348,7 +604,7 @@ class ChatClientAgent(AgentBase):
self,
*,
thread: AgentThread | None,
input_messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None,
input_messages: str | ChatMessage | Sequence[str] | Sequence[ChatMessage] | None = None,
construct_thread: Callable[[], TThreadType],
expected_type: type[TThreadType],
) -> tuple[TThreadType, list[ChatMessage]]:
@@ -367,6 +623,8 @@ class ChatClientAgent(AgentBase):
AgentExecutionException: If thread type is incompatible.
"""
messages: list[ChatMessage] = []
if self.instructions:
messages.append(ChatMessage(role=ChatRole.SYSTEM, text=self.instructions))
if thread is None:
thread = construct_thread()
@@ -380,17 +638,15 @@ class ChatClientAgent(AgentBase):
if isinstance(thread, MessagesRetrievableThread):
async for message in thread.get_messages():
messages.append(message)
if input_messages is None:
input_messages = []
if isinstance(input_messages, (str, ChatMessage)):
input_messages = [input_messages]
normalized_messages = [
return thread, messages
if isinstance(input_messages, str):
messages.append(ChatMessage(role=ChatRole.USER, text=input_messages))
return thread, messages
if isinstance(input_messages, ChatMessage):
messages.append(input_messages)
return thread, messages
messages.extend([
ChatMessage(role=ChatRole.USER, text=msg) if isinstance(msg, str) else msg for msg in input_messages
]
messages.extend(normalized_messages)
])
return thread, messages
+234 -150
View File
@@ -2,11 +2,11 @@
import asyncio
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable, Awaitable, Callable, MutableSequence, Sequence
from collections.abc import AsyncIterable, Awaitable, Callable, MutableMapping, MutableSequence, Sequence
from functools import wraps
from typing import Annotated, Any, Generic, Literal, Protocol, TypeVar, runtime_checkable
from pydantic import BaseModel, PrivateAttr, StringConstraints
from pydantic import BaseModel, StringConstraints
from ._logging import get_logger
from ._pydantic import AFBaseModel
@@ -37,24 +37,6 @@ logger = get_logger()
# region: Tool Calling Functions and Decorators
def _merge_function_results(
messages: list[ChatMessage],
) -> ChatMessage:
"""Combine multiple function result content types to one chat message content type.
This method combines the FunctionResultContent items from separate ChatMessageContent messages,
and is used in the event that the `context.terminate = True` condition is met.
"""
contents: list[Any] = []
for message in messages:
contents.extend([item for item in message.contents if isinstance(item, FunctionResultContent)])
return ChatMessage(
role="tool",
contents=contents,
)
async def _auto_invoke_function(
function_call_content: FunctionCallContent,
custom_args: dict[str, Any] | None = None,
@@ -106,7 +88,7 @@ def _prepare_tools_and_tool_choice(chat_options: ChatOptions) -> None:
chat_options.tool_choice = ChatToolMode.NONE.mode
return
chat_options.tools = [
(_tool_to_json_schema_spec(t) if isinstance(t, AITool) else t) for t in chat_options.tools or []
(_tool_to_json_schema_spec(t) if isinstance(t, AITool) else t) for t in chat_options._ai_tools or []
]
if not chat_options.tools:
chat_options.tool_choice = ChatToolMode.NONE.mode
@@ -115,11 +97,7 @@ def _prepare_tools_and_tool_choice(chat_options: ChatOptions) -> None:
def _tool_call_non_streaming(func: TInnerGetResponse) -> TInnerGetResponse:
"""Decorate the internal _inner_get_response method to enable tool calls.
Remarks:
Relies on a class that has the _tool_map attribute for the executable tools to call.
"""
"""Decorate the internal _inner_get_response method to enable tool calls."""
@wraps(func)
async def wrapper(
@@ -131,7 +109,7 @@ def _tool_call_non_streaming(func: TInnerGetResponse) -> TInnerGetResponse:
) -> ChatResponse:
response: ChatResponse | None = None
fcc_messages: list[ChatMessage] = []
for attempt_idx in range(self.maximum_iterations_per_request):
for attempt_idx in range(getattr(self, "__maximum_iterations_per_request", 10)):
response = await func(self, messages=messages, chat_options=chat_options)
# if there are function calls, we will handle them first
function_calls = [it for it in response.messages[0].contents if isinstance(it, FunctionCallContent)]
@@ -141,7 +119,7 @@ def _tool_call_non_streaming(func: TInnerGetResponse) -> TInnerGetResponse:
_auto_invoke_function(
function_call,
custom_args=kwargs,
tool_map=self._tool_map,
tool_map={t.name: t for t in chat_options._ai_tools or [] if isinstance(t, AIFunction)},
sequence_index=seq_idx,
request_index=attempt_idx,
)
@@ -181,11 +159,7 @@ def _tool_call_non_streaming(func: TInnerGetResponse) -> TInnerGetResponse:
def _tool_call_streaming(func: TInnerGetStreamingResponse) -> TInnerGetStreamingResponse:
"""Decorate the internal _inner_get_response method to enable tool calls.
Remarks:
Relies on a class that has the _tool_map attribute for the executable tools to call.
"""
"""Decorate the internal _inner_get_response method to enable tool calls."""
@wraps(func)
async def wrapper(
@@ -195,7 +169,8 @@ def _tool_call_streaming(func: TInnerGetStreamingResponse) -> TInnerGetStreaming
chat_options: ChatOptions,
**kwargs: Any,
) -> AsyncIterable[ChatResponseUpdate]:
for attempt_idx in range(self.maximum_iterations_per_request):
"""Wrap the inner get streaming response method to handle tool calls."""
for attempt_idx in range(getattr(self, "__maximum_iterations_per_request", 10)):
function_call_returned = False
all_messages: list[ChatResponseUpdate] = []
async for update in func(self, messages=messages, chat_options=chat_options):
@@ -221,7 +196,7 @@ def _tool_call_streaming(func: TInnerGetStreamingResponse) -> TInnerGetStreaming
_auto_invoke_function(
function_call,
custom_args=kwargs,
tool_map=self._tool_map,
tool_map={t.name: t for t in chat_options._ai_tools or [] if isinstance(t, AIFunction)},
sequence_index=seq_idx,
request_index=attempt_idx,
)
@@ -247,10 +222,25 @@ def use_tool_calling(cls: type[TChatClientBase]) -> type[TChatClientBase]:
Remarks:
This only works on classes that derive from ChatClientBase
and have the _tool_map attribute as well as the _inner_get_response
and the _inner_get_response
and _inner_get_streaming_response methods.
It also sets a __maximum_iterations_per_request attribute on the class.
if you want to expose this to end_users, do a version of this:
```python
@use_tool_calling
class MyChatClient(ChatClientBase):
@property
def maximum_iterations_per_request(self):
return getattr(self, "__maximum_iterations_per_request", 10)
@maximum_iterations_per_request.setter
def maximum_iterations_per_request(self, value: int) -> None:
setattr(self, "__maximum_iterations_per_request", value)
```
"""
setattr(cls, "__maximum_iterations_per_request", 10)
if inner_response := getattr(cls, "_inner_get_response", None):
cls._inner_get_response = _tool_call_non_streaming(inner_response) # type: ignore
if inner_streaming_response := getattr(cls, "_inner_get_streaming_response", None):
@@ -267,15 +257,54 @@ class ChatClient(Protocol):
async def get_response(
self,
messages: str | ChatMessage | list[ChatMessage],
messages: str | ChatMessage | list[str] | list[ChatMessage],
*,
frequency_penalty: float | None = None,
logit_bias: dict[str | int, float] | None = None,
max_tokens: int | None = None,
metadata: dict[str, Any] | None = None,
model: str | None = None,
presence_penalty: float | None = None,
response_format: type[BaseModel] | None = None,
seed: int | None = None,
stop: str | Sequence[str] | None = None,
store: bool | None = None,
temperature: float | None = None,
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
tools: AITool
| list[AITool]
| Callable[..., Any]
| list[Callable[..., Any]]
| MutableMapping[str, Any]
| list[MutableMapping[str, Any]]
| None = None,
top_p: float | None = None,
user: str | None = None,
additional_properties: dict[str, Any] | None = None,
**kwargs: Any,
) -> ChatResponse:
"""Sends input and returns the response.
Args:
messages: The sequence of input messages to send.
**kwargs: Additional options for the request, such as ai_model_id, temperature, etc.
See `ChatOptions` for more details.
frequency_penalty: the frequency penalty to use.
logit_bias: the logit bias to use.
max_tokens: The maximum number of tokens to generate.
metadata: additional metadata to include in the request.
model: The model to use for the agent.
presence_penalty: the presence penalty to use.
response_format: the format of the response.
seed: the random seed to use.
stop: the stop sequence(s) for the request.
store: whether to store the response.
temperature: the sampling temperature to use.
tool_choice: the tool choice for the request.
tools: the tools to use for the request.
top_p: the nucleus sampling probability to use.
user: the user to associate with the request.
additional_properties: additional properties to include in the request
kwargs: any additional keyword arguments,
will only be passed to functions that are called.
Returns:
The response messages generated by the client.
@@ -287,15 +316,54 @@ class ChatClient(Protocol):
def get_streaming_response(
self,
messages: str | ChatMessage | list[ChatMessage],
messages: str | ChatMessage | list[str] | list[ChatMessage],
*,
frequency_penalty: float | None = None,
logit_bias: dict[str | int, float] | None = None,
max_tokens: int | None = None,
metadata: dict[str, Any] | None = None,
model: str | None = None,
presence_penalty: float | None = None,
response_format: type[BaseModel] | None = None,
seed: int | None = None,
stop: str | Sequence[str] | None = None,
store: bool | None = None,
temperature: float | None = None,
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
tools: AITool
| list[AITool]
| Callable[..., Any]
| list[Callable[..., Any]]
| MutableMapping[str, Any]
| list[MutableMapping[str, Any]]
| None = None,
top_p: float | None = None,
user: str | None = None,
additional_properties: dict[str, Any] | None = None,
**kwargs: Any,
) -> AsyncIterable[ChatResponseUpdate]:
"""Sends input messages and streams the response.
Args:
messages: The sequence of input messages to send.
**kwargs: Additional options for the request, such as ai_model_id, temperature, etc.
See `ChatOptions` for more details.
frequency_penalty: the frequency penalty to use.
logit_bias: the logit bias to use.
max_tokens: The maximum number of tokens to generate.
metadata: additional metadata to include in the request.
model: The model to use for the agent.
presence_penalty: the presence penalty to use.
response_format: the format of the response.
seed: the random seed to use.
stop: the stop sequence(s) for the request.
store: whether to store the response.
temperature: the sampling temperature to use.
tool_choice: the tool choice for the request.
tools: the tools to use for the request.
top_p: the nucleus sampling probability to use.
user: the user to associate with the request.
additional_properties: additional properties to include in the request
kwargs: any additional keyword arguments,
will only be passed to functions that are called.
Yields:
An async iterable of chat response updates containing the content of the response messages
@@ -311,19 +379,21 @@ class ChatClientBase(AFBaseModel, ABC):
"""Base class for chat clients."""
ai_model_id: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1)]
maximum_iterations_per_request: int = 10
_tool_map: dict[str, AIFunction[BaseModel, Any]] = PrivateAttr(default_factory=dict) # type: ignore
def _prepare_messages(self, messages: str | ChatMessage | list[str | ChatMessage]) -> MutableSequence[ChatMessage]:
def _prepare_messages(
self, messages: str | ChatMessage | list[str] | list[ChatMessage]
) -> MutableSequence[ChatMessage]:
"""Turn the allowed input into a list of chat messages."""
if isinstance(messages, str):
messages = [ChatMessage(role="user", text=messages)]
return [ChatMessage(role="user", text=messages)]
if isinstance(messages, ChatMessage):
messages = [messages]
for i, msg in enumerate(messages):
return [messages]
return_messages: list[ChatMessage] = []
for msg in messages:
if isinstance(msg, str):
messages[i] = ChatMessage(role="user", text=msg)
return messages # type: ignore[return-value]
msg = ChatMessage(role="user", text=msg)
return_messages.append(msg)
return return_messages
# region Internal methods to be implemented by the derived classes
@@ -377,23 +447,29 @@ class ChatClientBase(AFBaseModel, ABC):
async def get_response(
self,
messages: str | ChatMessage | list[str | ChatMessage],
messages: str | ChatMessage | list[str] | list[ChatMessage],
*,
model: str | None = None,
max_tokens: int | None = None,
temperature: float | None = None,
top_p: float | None = None,
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
tools: AITool | Sequence[AITool] | None = None,
response_format: type[BaseModel] | None = None,
user: str | None = None,
stop: str | Sequence[str] | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str | int, float] | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
store: bool | None = None,
max_tokens: int | None = None,
metadata: dict[str, Any] | None = None,
model: str | None = None,
presence_penalty: float | None = None,
response_format: type[BaseModel] | None = None,
seed: int | None = None,
stop: str | Sequence[str] | None = None,
store: bool | None = None,
temperature: float | None = None,
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
tools: AITool
| list[AITool]
| Callable[..., Any]
| list[Callable[..., Any]]
| MutableMapping[str, Any]
| list[MutableMapping[str, Any]]
| None = None,
top_p: float | None = None,
user: str | None = None,
additional_properties: dict[str, Any] | None = None,
**kwargs: Any,
) -> ChatResponse:
@@ -401,73 +477,80 @@ class ChatClientBase(AFBaseModel, ABC):
Args:
messages: the message or messages to send to the model
model: the model to use for the request
max_tokens: the maximum number of tokens to generate
temperature: the sampling temperature to use
top_p: the nucleus sampling probability to use
tool_choice: the tool choice for the request
tools: the tools to use for the request
response_format: the format of the response
user: the user to associate with the request
stop: the stop sequence(s) for the request
frequency_penalty: the frequency penalty to use
logit_bias: the logit bias to use
presence_penalty: the presence penalty to use
seed: the random seed to use
store: whether to store the response
metadata: additional metadata to include in the request
additional_properties: additional properties to include in the request
frequency_penalty: the frequency penalty to use.
logit_bias: the logit bias to use.
max_tokens: The maximum number of tokens to generate.
metadata: additional metadata to include in the request.
model: The model to use for the agent.
presence_penalty: the presence penalty to use.
response_format: the format of the response.
seed: the random seed to use.
stop: the stop sequence(s) for the request.
store: whether to store the response.
temperature: the sampling temperature to use.
tool_choice: the tool choice for the request.
tools: the tools to use for the request.
top_p: the nucleus sampling probability to use.
user: the user to associate with the request.
additional_properties: additional properties to include in the request.
kwargs: any additional keyword arguments,
will only be passed to functions that are called.
Returns:
A chat response from the model.
"""
if tools is not None:
if not isinstance(tools, Sequence):
tools = [tools]
self._tool_map = {tool.name: tool for tool in tools if isinstance(tool, AIFunction)}
chat_options = ChatOptions(
ai_model_id=model,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
tool_choice=tool_choice,
tools=tools,
response_format=response_format,
user=user,
stop=stop,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
presence_penalty=presence_penalty,
seed=seed,
store=store,
metadata=metadata,
additional_properties=additional_properties or {},
)
if "chat_options" in kwargs:
chat_options = kwargs.pop("chat_options")
if not isinstance(chat_options, ChatOptions):
raise TypeError("chat_options must be an instance of ChatOptions")
else:
chat_options = ChatOptions(
ai_model_id=model,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
max_tokens=max_tokens,
metadata=metadata,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
store=store,
temperature=temperature,
top_p=top_p,
tool_choice=tool_choice,
tools=tools, # type: ignore
user=user,
additional_properties=additional_properties or {},
)
prepped_messages = self._prepare_messages(messages)
_prepare_tools_and_tool_choice(chat_options=chat_options)
return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **kwargs)
async def get_streaming_response(
self,
messages: str | ChatMessage | list[str | ChatMessage],
messages: str | ChatMessage | list[str] | list[ChatMessage],
*,
model: str | None = None,
max_tokens: int | None = None,
temperature: float | None = None,
top_p: float | None = None,
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
tools: AITool | Sequence[AITool] | None = None,
response_format: type[BaseModel] | None = None,
user: str | None = None,
stop: str | Sequence[str] | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str | int, float] | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
store: bool | None = None,
max_tokens: int | None = None,
metadata: dict[str, Any] | None = None,
model: str | None = None,
presence_penalty: float | None = None,
response_format: type[BaseModel] | None = None,
seed: int | None = None,
stop: str | Sequence[str] | None = None,
store: bool | None = None,
temperature: float | None = None,
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
tools: AITool
| list[AITool]
| Callable[..., Any]
| list[Callable[..., Any]]
| MutableMapping[str, Any]
| list[MutableMapping[str, Any]]
| None = None,
top_p: float | None = None,
user: str | None = None,
additional_properties: dict[str, Any] | None = None,
**kwargs: Any,
) -> AsyncIterable[ChatResponseUpdate]:
@@ -475,50 +558,51 @@ class ChatClientBase(AFBaseModel, ABC):
Args:
messages: the message or messages to send to the model
model: the model to use for the request
max_tokens: the maximum number of tokens to generate
temperature: the sampling temperature to use
top_p: the nucleus sampling probability to use
tool_choice: the tool choice for the request
tools: the tools to use for the request
response_format: the format of the response
user: the user to associate with the request
stop: the stop sequence(s) for the request
frequency_penalty: the frequency penalty to use
logit_bias: the logit bias to use
presence_penalty: the presence penalty to use
seed: the random seed to use
store: whether to store the response
metadata: additional metadata to include in the request
max_tokens: The maximum number of tokens to generate.
metadata: additional metadata to include in the request.
model: The model to use for the agent.
presence_penalty: the presence penalty to use.
response_format: the format of the response.
seed: the random seed to use.
stop: the stop sequence(s) for the request.
store: whether to store the response.
temperature: the sampling temperature to use.
tool_choice: the tool choice for the request.
tools: the tools to use for the request.
top_p: the nucleus sampling probability to use.
user: the user to associate with the request.
additional_properties: additional properties to include in the request
kwargs: any additional keyword arguments
Yields:
A stream representing the response(s) from the LLM.
"""
if tools is not None:
if not isinstance(tools, Sequence):
tools = [tools]
self._tool_map = {tool.name: tool for tool in tools if isinstance(tool, AIFunction)}
chat_options = ChatOptions(
ai_model_id=model,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
tool_choice=tool_choice,
tools=tools,
response_format=response_format,
user=user,
stop=stop,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
presence_penalty=presence_penalty,
seed=seed,
store=store,
metadata=metadata,
additional_properties=additional_properties or {},
**kwargs,
)
if "chat_options" in kwargs:
chat_options = kwargs.pop("chat_options")
if not isinstance(chat_options, ChatOptions):
raise TypeError("chat_options must be an instance of ChatOptions")
else:
chat_options = ChatOptions(
ai_model_id=model,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
max_tokens=max_tokens,
metadata=metadata,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
store=store,
temperature=temperature,
top_p=top_p,
tool_choice=tool_choice,
tools=tools, # type: ignore
user=user,
additional_properties=additional_properties or {},
**kwargs,
)
prepped_messages = self._prepare_messages(messages)
_prepare_tools_and_tool_choice(chat_options=chat_options)
async for update in self._inner_get_streaming_response(
+10 -1
View File
@@ -10,7 +10,16 @@ from pydantic import BaseModel, create_model
@runtime_checkable
class AITool(Protocol):
"""Represents a generic tool that can be specified to an AI service."""
"""Represents a generic tool that can be specified to an AI service.
Attributes:
name: The name of the tool.
description: A description of the tool.
additional_properties: Additional properties associated with the tool.
Methods:
parameters: The parameters accepted by the tool, in a json schema format.
"""
name: str
"""The name of the tool."""
+85 -15
View File
@@ -4,13 +4,22 @@ import base64
import json
import re
import sys
from collections.abc import AsyncIterable, Iterable, Iterator, Mapping, MutableMapping, MutableSequence, Sequence
from collections.abc import (
AsyncIterable,
Callable,
Iterable,
Iterator,
Mapping,
MutableMapping,
MutableSequence,
Sequence,
)
from typing import Annotated, Any, ClassVar, Generic, Literal, TypeVar, overload
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, ValidationError, field_validator, model_validator
from ._pydantic import AFBaseModel
from ._tools import AITool
from ._tools import AITool, ai_function
from .exceptions import AgentFrameworkException
if sys.version_info >= (3, 11):
@@ -1399,34 +1408,69 @@ class ChatOptions(AFBaseModel):
"""Common request settings for AI services."""
ai_model_id: Annotated[str | None, Field(serialization_alias="model")] = None
frequency_penalty: Annotated[float | None, Field(ge=-2.0, le=2.0)] = None
logit_bias: MutableMapping[str | int, float] | None = None
max_tokens: Annotated[int | None, Field(gt=0)] = None
temperature: Annotated[float | None, Field(ge=0.0, le=2.0)] = None
top_p: Annotated[float | None, Field(ge=0.0, le=1.0)] = None
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | Mapping[str, Any] | None = None
tools: Sequence[AITool] | Sequence[MutableMapping[str, Any]] | None = None
metadata: MutableMapping[str, str] | None = None
presence_penalty: Annotated[float | None, Field(ge=-2.0, le=2.0)] = None
response_format: type[BaseModel] | None = Field(
default=None, description="Structured output response format schema. Must be a valid Pydantic model."
)
user: str | None = None
stop: str | Sequence[str] | None = None
frequency_penalty: Annotated[float | None, Field(ge=-2.0, le=2.0)] = None
logit_bias: MutableMapping[str | int, float] | None = None
presence_penalty: Annotated[float | None, Field(ge=-2.0, le=2.0)] = None
seed: int | None = None
stop: str | Sequence[str] | None = None
store: bool | None = None
metadata: MutableMapping[str, str] | None = None
temperature: Annotated[float | None, Field(ge=0.0, le=2.0)] = None
tool_choice: ChatToolMode | Literal["auto", "required", "none"] | Mapping[str, Any] | None = None
tools: list[AITool | MutableMapping[str, Any]] | None = None
_ai_tools: list[AITool | MutableMapping[str, Any]] | None = PrivateAttr(default=None)
top_p: Annotated[float | None, Field(ge=0.0, le=1.0)] = None
user: str | None = None
additional_properties: MutableMapping[str, Any] = Field(
default_factory=dict, description="Provider-specific additional properties."
)
@model_validator(mode="after")
def _copy_to_ai_tools(self) -> Self:
if self.tools and not self._ai_tools:
self._ai_tools = self.tools
return self
@field_validator("tools", mode="before")
@classmethod
def _validate_tools(
cls,
tools: (
AITool
| list[AITool]
| Callable[..., Any]
| list[Callable[..., Any]]
| MutableMapping[str, Any]
| list[MutableMapping[str, Any]]
| None
),
) -> list[AITool | MutableMapping[str, Any]] | None:
"""Parse the tools field.
All tools are stored in both tools and _ai_tools.
"""
if not tools:
return None
if not isinstance(tools, list):
tools = [tools] # type: ignore[reportAssignmentType, assignment]
for idx, tool in enumerate(tools): # type: ignore[reportArgumentType, arg-type]
if not isinstance(tool, (AITool, MutableMapping)):
# Convert to AITool if it's a function or callable
tools[idx] = ai_function(tool) # type: ignore[reportIndexIssues, reportCallIssue, reportArgumentType, index, call-overload, arg-type]
return tools # type: ignore[reportReturnType, return-value]
@field_validator("tool_choice", mode="before")
@classmethod
def _validate_tool_mode(
cls, tool_choice: ChatToolMode | Literal["auto", "required", "none"] | Mapping[str, Any] | None
) -> ChatToolMode:
) -> ChatToolMode | None:
"""Validates the tool_choice field to ensure it is a valid ChatToolMode."""
if not tool_choice:
return ChatToolMode.NONE
return None
if isinstance(tool_choice, str):
match tool_choice:
case "auto":
@@ -1461,6 +1505,32 @@ class ChatOptions(AFBaseModel):
settings.pop(key, None)
return settings
def __and__(self, other: object) -> Self:
"""Combines two ChatOptions instances.
The values from the other ChatOptions take precedence.
List and dicts are combined.
"""
if not isinstance(other, ChatOptions):
return self
ai_tools = other._ai_tools
updated_values = other.model_dump(exclude_none=True)
updated_values.pop("tools", [])
logit_bias = updated_values.pop("logit_bias", {})
metadata = updated_values.pop("metadata", {})
additional_properties = updated_values.pop("additional_properties", {})
combined = self.model_copy(update=updated_values)
if ai_tools:
if not combined._ai_tools:
combined._ai_tools = []
for tool in ai_tools:
if tool not in combined._ai_tools:
combined._ai_tools.append(tool)
combined.logit_bias = {**(combined.logit_bias or {}), **logit_bias}
combined.metadata = {**(combined.metadata or {}), **metadata}
combined.additional_properties = {**(combined.additional_properties or {}), **additional_properties}
return combined
# region: GeneratedEmbeddings
@@ -0,0 +1,38 @@
# Copyright (c) Microsoft. All rights reserved.
from typing import Any
from pydantic import BaseModel
from pytest import fixture
from agent_framework import AITool, ai_function
@fixture
def ai_tool() -> AITool:
"""Returns a generic AITool."""
class GenericTool(BaseModel):
name: str
description: str | None = None
additional_properties: dict[str, Any] | None = None
def parameters(self) -> dict[str, Any]:
"""Return the parameters of the tool as a JSON schema."""
return {
"name": {"type": "string"},
}
return GenericTool(name="generic_tool", description="A generic tool")
@fixture
def ai_function_tool() -> AITool:
"""Returns a executable AITool."""
@ai_function
def simple_function(x: int, y: int) -> int:
"""A simple function that adds two numbers."""
return x + y
return simple_function
@@ -209,7 +209,7 @@ async def test_base_client_with_function_calling(chat_client_base: MockChatClien
async def test_base_client_with_function_calling_disabled(chat_client_base: MockChatClientBase):
chat_client_base.maximum_iterations_per_request = 0
chat_client_base.__maximum_iterations_per_request = 0
exec_counter = 0
@ai_function(name="test_function")
@@ -273,7 +273,7 @@ async def test_base_client_with_streaming_function_calling(chat_client_base: Moc
async def test_base_client_with_streaming_function_calling_disabled(chat_client_base: MockChatClientBase):
chat_client_base.maximum_iterations_per_request = 0
chat_client_base.__maximum_iterations_per_request = 0
exec_counter = 0
@ai_function(name="test_function")
+96 -1
View File
@@ -10,7 +10,9 @@ from agent_framework import (
AgentRunResponseUpdate,
AIContent,
AIContents,
AITool,
ChatMessage,
ChatOptions,
ChatResponse,
ChatResponseUpdate,
ChatRole,
@@ -492,6 +494,99 @@ def test_generated_embeddings():
assert issubclass(GeneratedEmbeddings, MutableSequence)
# region: ChatOptions
def test_chat_options_init() -> None:
options = ChatOptions()
assert options.ai_model_id is None
def test_chat_options_init_with_args(ai_function_tool, ai_tool) -> None:
options = ChatOptions(
ai_model_id="gpt-4",
max_tokens=1024,
temperature=0.7,
top_p=0.9,
presence_penalty=0.0,
frequency_penalty=0.0,
user="user-123",
tools=[ai_function_tool, ai_tool],
)
assert options.ai_model_id == "gpt-4"
assert options.max_tokens == 1024
assert options.temperature == 0.7
assert options.top_p == 0.9
assert options.presence_penalty == 0.0
assert options.frequency_penalty == 0.0
assert options.user == "user-123"
for tool in options._ai_tools:
assert isinstance(tool, AITool)
assert tool.name is not None
assert tool.description is not None
assert tool.parameters() is not None
def test_chat_options_and(ai_function_tool, ai_tool) -> None:
options1 = ChatOptions(ai_model_id="gpt-4o", tools=[ai_function_tool])
options2 = ChatOptions(ai_model_id="gpt-4.1", tools=[ai_tool])
assert options1 != options2
options3 = options1 & options2
assert options3.ai_model_id == "gpt-4.1"
assert len(options3._ai_tools) == 2
assert options3._ai_tools == [ai_function_tool, ai_tool]
assert options3.tools == [ai_function_tool, ai_tool]
def test_chat_options_parsing_tools(ai_function_tool, ai_tool) -> None:
from agent_framework._clients import _prepare_tools_and_tool_choice
def echo() -> str:
"""Echo the input."""
return "Echo"
dict_function = {
"type": "function",
"function": {
"name": "get_weather",
"description": "Retrieves current weather for the given location.",
"parameters": {
"type": "object",
"properties": {
"location": {"type": "string", "description": "City and country e.g. Bogotá, Colombia"},
"units": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "Units the temperature will be returned in.",
},
},
"required": ["location", "units"],
"additionalProperties": False,
},
"strict": True,
},
}
options = ChatOptions(tools=[ai_function_tool, ai_tool, echo, dict_function], tool_choice="auto")
assert len(options.tools) == 4
assert options.tools[0] == ai_function_tool
assert options.tools[1] == ai_tool
assert options.tools[2] != echo
assert options.tools[3] == dict_function
# after prepare, the tools should be represented as dicts
# while ai_tools is still the same.
_prepare_tools_and_tool_choice(options)
assert options._ai_tools[0] == ai_function_tool
assert options._ai_tools[1] == ai_tool
assert options._ai_tools[3] == dict_function
assert len(options.tools) == 4
assert options.tools[0]["function"]["name"] == "simple_function"
assert options.tools[1]["function"]["name"] == "generic_tool"
assert options.tools[2]["function"]["name"] == "echo"
assert options.tools[3]["function"]["name"] == "get_weather"
# region Agent Response Fixtures
@@ -548,7 +643,7 @@ def test_agent_run_response_from_updates(agent_run_response_update: AgentRunResp
updates = [agent_run_response_update, agent_run_response_update]
response = AgentRunResponse.from_agent_run_response_updates(updates)
assert len(response.messages) > 0
assert response.text == "Test content\nTest content"
assert response.text == "Test content Test content"
def test_agent_run_response_str_method(chat_message: ChatMessage) -> None:
@@ -0,0 +1,27 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from random import randint
from typing import Annotated
from agent_framework import ChatClientAgent
from agent_framework.openai import OpenAIChatClient
from pydantic import Field
def get_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
"""Get the weather for a given location."""
conditions = ["sunny", "cloudy", "rainy", "stormy"]
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
async def main():
instructions = "You are a helpful assistant, you can help the user with weather information."
agent = ChatClientAgent(OpenAIChatClient(), instructions=instructions, tools=get_weather)
print(str(await agent.run("What's the weather in Amsterdam?")))
if __name__ == "__main__":
asyncio.run(main())