diff --git a/python/packages/main/agent_framework/_agents.py b/python/packages/main/agent_framework/_agents.py index 0606d4ed4c..c4ec3cf6dc 100644 --- a/python/packages/main/agent_framework/_agents.py +++ b/python/packages/main/agent_framework/_agents.py @@ -1,15 +1,25 @@ # Copyright (c) Microsoft. All rights reserved. -from collections.abc import AsyncIterable, Callable, Sequence +from collections.abc import AsyncIterable, Callable, MutableMapping, Sequence from enum import Enum -from typing import Any, Protocol, TypeVar, runtime_checkable +from typing import Any, Literal, Protocol, TypeVar, runtime_checkable from uuid import uuid4 -from pydantic import Field +from pydantic import BaseModel, Field from ._clients import ChatClient from ._pydantic import AFBaseModel -from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage, ChatResponse, ChatResponseUpdate, ChatRole +from ._tools import AITool +from ._types import ( + AgentRunResponse, + AgentRunResponseUpdate, + ChatMessage, + ChatOptions, + ChatResponse, + ChatResponseUpdate, + ChatRole, + ChatToolMode, +) from .exceptions import AgentExecutionException TThreadType = TypeVar("TThreadType", bound="AgentThread") @@ -71,7 +81,7 @@ class Agent(Protocol): async def run( self, - messages: str | ChatMessage | list[str | ChatMessage] | None = None, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, *, thread: AgentThread | None = None, **kwargs: Any, @@ -100,7 +110,7 @@ class Agent(Protocol): def run_stream( self, - messages: str | ChatMessage | list[str | ChatMessage] | None = None, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, *, thread: AgentThread | None = None, **kwargs: Any, @@ -236,17 +246,168 @@ class ChatClientAgentThread(AgentThread): class ChatClientAgent(AgentBase): - """A Chat Client Agent which depends on ChatClient.""" + """A Chat Client Agent.""" chat_client: ChatClient + instructions: str | None = None + chat_options: ChatOptions + + def __init__( + self, + chat_client: ChatClient, + instructions: str | None = None, + *, + id: str | None = None, + name: str | None = None, + description: str | None = None, + frequency_penalty: float | None = None, + logit_bias: dict[str | int, float] | None = None, + max_tokens: int | None = None, + metadata: dict[str, Any] | None = None, + model: str | None = None, + presence_penalty: float | None = None, + response_format: type[BaseModel] | None = None, + seed: int | None = None, + stop: str | Sequence[str] | None = None, + store: bool | None = None, + temperature: float | None = None, + tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto", + tools: AITool + | list[AITool] + | Callable[..., Any] + | list[Callable[..., Any]] + | MutableMapping[str, Any] + | list[MutableMapping[str, Any]] + | None = None, + top_p: float | None = None, + user: str | None = None, + additional_properties: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + """Create a ChatClientAgent. + + Remarks: + The set of attributes from frequency_penalty to additional_properties are used to + call the chat client, they can also be passed to both run methods. + When both are set, the ones passed to the run methods take precedence. + + Args: + chat_client: The chat client to use for the agent. + instructions: Optional instructions for the agent. + These will be put into the messages sent to the chat client service as a system message. + id: The unique identifier for the agent, will be created automatically if not provided. + name: The name of the agent. + description: A brief description of the agent's purpose. + frequency_penalty: the frequency penalty to use. + logit_bias: the logit bias to use. + max_tokens: The maximum number of tokens to generate. + metadata: additional metadata to include in the request. + model: The model to use for the agent. + presence_penalty: the presence penalty to use. + response_format: the format of the response. + seed: the random seed to use. + stop: the stop sequence(s) for the request. + store: whether to store the response. + temperature: the sampling temperature to use. + tool_choice: the tool choice for the request. + tools: the tools to use for the request. + top_p: the nucleus sampling probability to use. + user: the user to associate with the request. + additional_properties: additional properties to include in the request. + kwargs: any additional keyword arguments. + Unused, can be used by subclasses of this Agent. + """ + args: dict[str, Any] = { + "chat_client": chat_client, + "chat_options": ChatOptions( + ai_model_id=model, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + max_tokens=max_tokens, + metadata=metadata, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + store=store, + temperature=temperature, + tool_choice=tool_choice, + tools=tools, # type: ignore + top_p=top_p, + user=user, + additional_properties=additional_properties or {}, + ), + } + if instructions is not None: + args["instructions"] = instructions + if name is not None: + args["name"] = name + if description is not None: + args["description"] = description + if id is not None: + args["id"] = id + + super().__init__(**args) async def run( self, - messages: str | ChatMessage | list[str | ChatMessage] | None = None, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, *, thread: AgentThread | None = None, + frequency_penalty: float | None = None, + logit_bias: dict[str | int, float] | None = None, + max_tokens: int | None = None, + metadata: dict[str, Any] | None = None, + model: str | None = None, + presence_penalty: float | None = None, + response_format: type[BaseModel] | None = None, + seed: int | None = None, + stop: str | Sequence[str] | None = None, + store: bool | None = None, + temperature: float | None = None, + tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None, + tools: AITool + | list[AITool] + | Callable[..., Any] + | list[Callable[..., Any]] + | MutableMapping[str, Any] + | list[MutableMapping[str, Any]] + | None = None, + top_p: float | None = None, + user: str | None = None, + additional_properties: dict[str, Any] | None = None, **kwargs: Any, ) -> AgentRunResponse: + """Run the agent with the given messages and options. + + Remarks: + Since you won't always call the agent.run directly, but it get's called + through orchestration, it is advised to set your default values for + all the chat client parameters in the agent constructor. + If both parameters are used, the ones passed to the run methods take precedence. + + Args: + messages: The messages to process. + thread: The thread to use for the agent. + frequency_penalty: the frequency penalty to use. + logit_bias: the logit bias to use. + max_tokens: The maximum number of tokens to generate. + metadata: additional metadata to include in the request. + model: The model to use for the agent. + presence_penalty: the presence penalty to use. + response_format: the format of the response. + seed: the random seed to use. + stop: the stop sequence(s) for the request. + store: whether to store the response. + temperature: the sampling temperature to use. + tool_choice: the tool choice for the request. + tools: the tools to use for the request. + top_p: the nucleus sampling probability to use. + user: the user to associate with the request. + additional_properties: additional properties to include in the request. + kwargs: Additional keyword arguments for the agent. + will only be passed to functions that are called. + """ thread, thread_messages = await self._prepare_thread_and_messages( thread=thread, input_messages=messages, @@ -254,7 +415,29 @@ class ChatClientAgent(AgentBase): expected_type=ChatClientAgentThread, ) - response = await self.chat_client.get_response(thread_messages, **kwargs) + response = await self.chat_client.get_response( + messages=thread_messages, + chat_options=self.chat_options + & ChatOptions( + ai_model_id=model, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + max_tokens=max_tokens, + metadata=metadata, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + store=store, + temperature=temperature, + tool_choice=tool_choice, + tools=tools, # type: ignore + top_p=top_p, + user=user, + additional_properties=additional_properties or {}, + ), + **kwargs, + ) self._update_thread_with_type_and_conversation_id(thread, response.conversation_id) @@ -274,11 +457,64 @@ class ChatClientAgent(AgentBase): async def run_stream( self, - messages: str | ChatMessage | list[str | ChatMessage] | None = None, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, *, thread: AgentThread | None = None, + frequency_penalty: float | None = None, + logit_bias: dict[str | int, float] | None = None, + max_tokens: int | None = None, + metadata: dict[str, Any] | None = None, + model: str | None = None, + presence_penalty: float | None = None, + response_format: type[BaseModel] | None = None, + seed: int | None = None, + stop: str | Sequence[str] | None = None, + store: bool | None = None, + temperature: float | None = None, + tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None, + tools: AITool + | list[AITool] + | Callable[..., Any] + | list[Callable[..., Any]] + | MutableMapping[str, Any] + | list[MutableMapping[str, Any]] + | None = None, + top_p: float | None = None, + user: str | None = None, + additional_properties: dict[str, Any] | None = None, **kwargs: Any, ) -> AsyncIterable[AgentRunResponseUpdate]: + """Stream the agent with the given messages and options. + + Remarks: + Since you won't always call the agent.run_stream directly, but it get's called + through orchestration, it is advised to set your default values for + all the chat client parameters in the agent constructor. + If both parameters are used, the ones passed to the run methods take precedence. + + Args: + messages: The messages to process. + thread: The thread to use for the agent. + frequency_penalty: the frequency penalty to use. + logit_bias: the logit bias to use. + max_tokens: The maximum number of tokens to generate. + metadata: additional metadata to include in the request. + model: The model to use for the agent. + presence_penalty: the presence penalty to use. + response_format: the format of the response. + seed: the random seed to use. + stop: the stop sequence(s) for the request. + store: whether to store the response. + temperature: the sampling temperature to use. + tool_choice: the tool choice for the request. + tools: the tools to use for the request. + top_p: the nucleus sampling probability to use. + user: the user to associate with the request. + additional_properties: additional properties to include in the request. + kwargs: any additional keyword arguments. + will only be passed to functions that are called. + + """ thread, thread_messages = await self._prepare_thread_and_messages( thread=thread, input_messages=messages, @@ -288,9 +524,29 @@ class ChatClientAgent(AgentBase): response_updates: list[ChatResponseUpdate] = [] - streaming_response: AsyncIterable[ChatResponseUpdate] = self.chat_client.get_streaming_response(thread_messages) - - async for update in streaming_response: + async for update in self.chat_client.get_streaming_response( + messages=thread_messages, + chat_options=self.chat_options + & ChatOptions( + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + max_tokens=max_tokens, + metadata=metadata, + ai_model_id=model, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + store=store, + temperature=temperature, + tool_choice=tool_choice, + tools=tools, # type: ignore + top_p=top_p, + user=user, + additional_properties=additional_properties or {}, + ), + **kwargs, + ): response_updates.append(update) yield AgentRunResponseUpdate( contents=update.contents, @@ -348,7 +604,7 @@ class ChatClientAgent(AgentBase): self, *, thread: AgentThread | None, - input_messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None, + input_messages: str | ChatMessage | Sequence[str] | Sequence[ChatMessage] | None = None, construct_thread: Callable[[], TThreadType], expected_type: type[TThreadType], ) -> tuple[TThreadType, list[ChatMessage]]: @@ -367,6 +623,8 @@ class ChatClientAgent(AgentBase): AgentExecutionException: If thread type is incompatible. """ messages: list[ChatMessage] = [] + if self.instructions: + messages.append(ChatMessage(role=ChatRole.SYSTEM, text=self.instructions)) if thread is None: thread = construct_thread() @@ -380,17 +638,15 @@ class ChatClientAgent(AgentBase): if isinstance(thread, MessagesRetrievableThread): async for message in thread.get_messages(): messages.append(message) - if input_messages is None: - input_messages = [] - - if isinstance(input_messages, (str, ChatMessage)): - input_messages = [input_messages] - - normalized_messages = [ + return thread, messages + if isinstance(input_messages, str): + messages.append(ChatMessage(role=ChatRole.USER, text=input_messages)) + return thread, messages + if isinstance(input_messages, ChatMessage): + messages.append(input_messages) + return thread, messages + messages.extend([ ChatMessage(role=ChatRole.USER, text=msg) if isinstance(msg, str) else msg for msg in input_messages - ] - - messages.extend(normalized_messages) - + ]) return thread, messages diff --git a/python/packages/main/agent_framework/_clients.py b/python/packages/main/agent_framework/_clients.py index 04cc7ad892..d8d10a3d5d 100644 --- a/python/packages/main/agent_framework/_clients.py +++ b/python/packages/main/agent_framework/_clients.py @@ -2,11 +2,11 @@ import asyncio from abc import ABC, abstractmethod -from collections.abc import AsyncIterable, Awaitable, Callable, MutableSequence, Sequence +from collections.abc import AsyncIterable, Awaitable, Callable, MutableMapping, MutableSequence, Sequence from functools import wraps from typing import Annotated, Any, Generic, Literal, Protocol, TypeVar, runtime_checkable -from pydantic import BaseModel, PrivateAttr, StringConstraints +from pydantic import BaseModel, StringConstraints from ._logging import get_logger from ._pydantic import AFBaseModel @@ -37,24 +37,6 @@ logger = get_logger() # region: Tool Calling Functions and Decorators -def _merge_function_results( - messages: list[ChatMessage], -) -> ChatMessage: - """Combine multiple function result content types to one chat message content type. - - This method combines the FunctionResultContent items from separate ChatMessageContent messages, - and is used in the event that the `context.terminate = True` condition is met. - """ - contents: list[Any] = [] - for message in messages: - contents.extend([item for item in message.contents if isinstance(item, FunctionResultContent)]) - - return ChatMessage( - role="tool", - contents=contents, - ) - - async def _auto_invoke_function( function_call_content: FunctionCallContent, custom_args: dict[str, Any] | None = None, @@ -106,7 +88,7 @@ def _prepare_tools_and_tool_choice(chat_options: ChatOptions) -> None: chat_options.tool_choice = ChatToolMode.NONE.mode return chat_options.tools = [ - (_tool_to_json_schema_spec(t) if isinstance(t, AITool) else t) for t in chat_options.tools or [] + (_tool_to_json_schema_spec(t) if isinstance(t, AITool) else t) for t in chat_options._ai_tools or [] ] if not chat_options.tools: chat_options.tool_choice = ChatToolMode.NONE.mode @@ -115,11 +97,7 @@ def _prepare_tools_and_tool_choice(chat_options: ChatOptions) -> None: def _tool_call_non_streaming(func: TInnerGetResponse) -> TInnerGetResponse: - """Decorate the internal _inner_get_response method to enable tool calls. - - Remarks: - Relies on a class that has the _tool_map attribute for the executable tools to call. - """ + """Decorate the internal _inner_get_response method to enable tool calls.""" @wraps(func) async def wrapper( @@ -131,7 +109,7 @@ def _tool_call_non_streaming(func: TInnerGetResponse) -> TInnerGetResponse: ) -> ChatResponse: response: ChatResponse | None = None fcc_messages: list[ChatMessage] = [] - for attempt_idx in range(self.maximum_iterations_per_request): + for attempt_idx in range(getattr(self, "__maximum_iterations_per_request", 10)): response = await func(self, messages=messages, chat_options=chat_options) # if there are function calls, we will handle them first function_calls = [it for it in response.messages[0].contents if isinstance(it, FunctionCallContent)] @@ -141,7 +119,7 @@ def _tool_call_non_streaming(func: TInnerGetResponse) -> TInnerGetResponse: _auto_invoke_function( function_call, custom_args=kwargs, - tool_map=self._tool_map, + tool_map={t.name: t for t in chat_options._ai_tools or [] if isinstance(t, AIFunction)}, sequence_index=seq_idx, request_index=attempt_idx, ) @@ -181,11 +159,7 @@ def _tool_call_non_streaming(func: TInnerGetResponse) -> TInnerGetResponse: def _tool_call_streaming(func: TInnerGetStreamingResponse) -> TInnerGetStreamingResponse: - """Decorate the internal _inner_get_response method to enable tool calls. - - Remarks: - Relies on a class that has the _tool_map attribute for the executable tools to call. - """ + """Decorate the internal _inner_get_response method to enable tool calls.""" @wraps(func) async def wrapper( @@ -195,7 +169,8 @@ def _tool_call_streaming(func: TInnerGetStreamingResponse) -> TInnerGetStreaming chat_options: ChatOptions, **kwargs: Any, ) -> AsyncIterable[ChatResponseUpdate]: - for attempt_idx in range(self.maximum_iterations_per_request): + """Wrap the inner get streaming response method to handle tool calls.""" + for attempt_idx in range(getattr(self, "__maximum_iterations_per_request", 10)): function_call_returned = False all_messages: list[ChatResponseUpdate] = [] async for update in func(self, messages=messages, chat_options=chat_options): @@ -221,7 +196,7 @@ def _tool_call_streaming(func: TInnerGetStreamingResponse) -> TInnerGetStreaming _auto_invoke_function( function_call, custom_args=kwargs, - tool_map=self._tool_map, + tool_map={t.name: t for t in chat_options._ai_tools or [] if isinstance(t, AIFunction)}, sequence_index=seq_idx, request_index=attempt_idx, ) @@ -247,10 +222,25 @@ def use_tool_calling(cls: type[TChatClientBase]) -> type[TChatClientBase]: Remarks: This only works on classes that derive from ChatClientBase - and have the _tool_map attribute as well as the _inner_get_response + and the _inner_get_response and _inner_get_streaming_response methods. + It also sets a __maximum_iterations_per_request attribute on the class. + if you want to expose this to end_users, do a version of this: + ```python + @use_tool_calling + class MyChatClient(ChatClientBase): + @property + def maximum_iterations_per_request(self): + return getattr(self, "__maximum_iterations_per_request", 10) + + @maximum_iterations_per_request.setter + def maximum_iterations_per_request(self, value: int) -> None: + setattr(self, "__maximum_iterations_per_request", value) + ``` """ + setattr(cls, "__maximum_iterations_per_request", 10) + if inner_response := getattr(cls, "_inner_get_response", None): cls._inner_get_response = _tool_call_non_streaming(inner_response) # type: ignore if inner_streaming_response := getattr(cls, "_inner_get_streaming_response", None): @@ -267,15 +257,54 @@ class ChatClient(Protocol): async def get_response( self, - messages: str | ChatMessage | list[ChatMessage], + messages: str | ChatMessage | list[str] | list[ChatMessage], + *, + frequency_penalty: float | None = None, + logit_bias: dict[str | int, float] | None = None, + max_tokens: int | None = None, + metadata: dict[str, Any] | None = None, + model: str | None = None, + presence_penalty: float | None = None, + response_format: type[BaseModel] | None = None, + seed: int | None = None, + stop: str | Sequence[str] | None = None, + store: bool | None = None, + temperature: float | None = None, + tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto", + tools: AITool + | list[AITool] + | Callable[..., Any] + | list[Callable[..., Any]] + | MutableMapping[str, Any] + | list[MutableMapping[str, Any]] + | None = None, + top_p: float | None = None, + user: str | None = None, + additional_properties: dict[str, Any] | None = None, **kwargs: Any, ) -> ChatResponse: """Sends input and returns the response. Args: messages: The sequence of input messages to send. - **kwargs: Additional options for the request, such as ai_model_id, temperature, etc. - See `ChatOptions` for more details. + frequency_penalty: the frequency penalty to use. + logit_bias: the logit bias to use. + max_tokens: The maximum number of tokens to generate. + metadata: additional metadata to include in the request. + model: The model to use for the agent. + presence_penalty: the presence penalty to use. + response_format: the format of the response. + seed: the random seed to use. + stop: the stop sequence(s) for the request. + store: whether to store the response. + temperature: the sampling temperature to use. + tool_choice: the tool choice for the request. + tools: the tools to use for the request. + top_p: the nucleus sampling probability to use. + user: the user to associate with the request. + additional_properties: additional properties to include in the request + kwargs: any additional keyword arguments, + will only be passed to functions that are called. Returns: The response messages generated by the client. @@ -287,15 +316,54 @@ class ChatClient(Protocol): def get_streaming_response( self, - messages: str | ChatMessage | list[ChatMessage], + messages: str | ChatMessage | list[str] | list[ChatMessage], + *, + frequency_penalty: float | None = None, + logit_bias: dict[str | int, float] | None = None, + max_tokens: int | None = None, + metadata: dict[str, Any] | None = None, + model: str | None = None, + presence_penalty: float | None = None, + response_format: type[BaseModel] | None = None, + seed: int | None = None, + stop: str | Sequence[str] | None = None, + store: bool | None = None, + temperature: float | None = None, + tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto", + tools: AITool + | list[AITool] + | Callable[..., Any] + | list[Callable[..., Any]] + | MutableMapping[str, Any] + | list[MutableMapping[str, Any]] + | None = None, + top_p: float | None = None, + user: str | None = None, + additional_properties: dict[str, Any] | None = None, **kwargs: Any, ) -> AsyncIterable[ChatResponseUpdate]: """Sends input messages and streams the response. Args: messages: The sequence of input messages to send. - **kwargs: Additional options for the request, such as ai_model_id, temperature, etc. - See `ChatOptions` for more details. + frequency_penalty: the frequency penalty to use. + logit_bias: the logit bias to use. + max_tokens: The maximum number of tokens to generate. + metadata: additional metadata to include in the request. + model: The model to use for the agent. + presence_penalty: the presence penalty to use. + response_format: the format of the response. + seed: the random seed to use. + stop: the stop sequence(s) for the request. + store: whether to store the response. + temperature: the sampling temperature to use. + tool_choice: the tool choice for the request. + tools: the tools to use for the request. + top_p: the nucleus sampling probability to use. + user: the user to associate with the request. + additional_properties: additional properties to include in the request + kwargs: any additional keyword arguments, + will only be passed to functions that are called. Yields: An async iterable of chat response updates containing the content of the response messages @@ -311,19 +379,21 @@ class ChatClientBase(AFBaseModel, ABC): """Base class for chat clients.""" ai_model_id: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1)] - maximum_iterations_per_request: int = 10 - _tool_map: dict[str, AIFunction[BaseModel, Any]] = PrivateAttr(default_factory=dict) # type: ignore - def _prepare_messages(self, messages: str | ChatMessage | list[str | ChatMessage]) -> MutableSequence[ChatMessage]: + def _prepare_messages( + self, messages: str | ChatMessage | list[str] | list[ChatMessage] + ) -> MutableSequence[ChatMessage]: """Turn the allowed input into a list of chat messages.""" if isinstance(messages, str): - messages = [ChatMessage(role="user", text=messages)] + return [ChatMessage(role="user", text=messages)] if isinstance(messages, ChatMessage): - messages = [messages] - for i, msg in enumerate(messages): + return [messages] + return_messages: list[ChatMessage] = [] + for msg in messages: if isinstance(msg, str): - messages[i] = ChatMessage(role="user", text=msg) - return messages # type: ignore[return-value] + msg = ChatMessage(role="user", text=msg) + return_messages.append(msg) + return return_messages # region Internal methods to be implemented by the derived classes @@ -377,23 +447,29 @@ class ChatClientBase(AFBaseModel, ABC): async def get_response( self, - messages: str | ChatMessage | list[str | ChatMessage], + messages: str | ChatMessage | list[str] | list[ChatMessage], *, - model: str | None = None, - max_tokens: int | None = None, - temperature: float | None = None, - top_p: float | None = None, - tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto", - tools: AITool | Sequence[AITool] | None = None, - response_format: type[BaseModel] | None = None, - user: str | None = None, - stop: str | Sequence[str] | None = None, frequency_penalty: float | None = None, logit_bias: dict[str | int, float] | None = None, - presence_penalty: float | None = None, - seed: int | None = None, - store: bool | None = None, + max_tokens: int | None = None, metadata: dict[str, Any] | None = None, + model: str | None = None, + presence_penalty: float | None = None, + response_format: type[BaseModel] | None = None, + seed: int | None = None, + stop: str | Sequence[str] | None = None, + store: bool | None = None, + temperature: float | None = None, + tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto", + tools: AITool + | list[AITool] + | Callable[..., Any] + | list[Callable[..., Any]] + | MutableMapping[str, Any] + | list[MutableMapping[str, Any]] + | None = None, + top_p: float | None = None, + user: str | None = None, additional_properties: dict[str, Any] | None = None, **kwargs: Any, ) -> ChatResponse: @@ -401,73 +477,80 @@ class ChatClientBase(AFBaseModel, ABC): Args: messages: the message or messages to send to the model - model: the model to use for the request - max_tokens: the maximum number of tokens to generate - temperature: the sampling temperature to use - top_p: the nucleus sampling probability to use - tool_choice: the tool choice for the request - tools: the tools to use for the request - response_format: the format of the response - user: the user to associate with the request - stop: the stop sequence(s) for the request - frequency_penalty: the frequency penalty to use - logit_bias: the logit bias to use - presence_penalty: the presence penalty to use - seed: the random seed to use - store: whether to store the response - metadata: additional metadata to include in the request - additional_properties: additional properties to include in the request + frequency_penalty: the frequency penalty to use. + logit_bias: the logit bias to use. + max_tokens: The maximum number of tokens to generate. + metadata: additional metadata to include in the request. + model: The model to use for the agent. + presence_penalty: the presence penalty to use. + response_format: the format of the response. + seed: the random seed to use. + stop: the stop sequence(s) for the request. + store: whether to store the response. + temperature: the sampling temperature to use. + tool_choice: the tool choice for the request. + tools: the tools to use for the request. + top_p: the nucleus sampling probability to use. + user: the user to associate with the request. + additional_properties: additional properties to include in the request. kwargs: any additional keyword arguments, will only be passed to functions that are called. Returns: A chat response from the model. """ - if tools is not None: - if not isinstance(tools, Sequence): - tools = [tools] - self._tool_map = {tool.name: tool for tool in tools if isinstance(tool, AIFunction)} - chat_options = ChatOptions( - ai_model_id=model, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - tool_choice=tool_choice, - tools=tools, - response_format=response_format, - user=user, - stop=stop, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - presence_penalty=presence_penalty, - seed=seed, - store=store, - metadata=metadata, - additional_properties=additional_properties or {}, - ) + if "chat_options" in kwargs: + chat_options = kwargs.pop("chat_options") + if not isinstance(chat_options, ChatOptions): + raise TypeError("chat_options must be an instance of ChatOptions") + else: + chat_options = ChatOptions( + ai_model_id=model, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + max_tokens=max_tokens, + metadata=metadata, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + store=store, + temperature=temperature, + top_p=top_p, + tool_choice=tool_choice, + tools=tools, # type: ignore + user=user, + additional_properties=additional_properties or {}, + ) prepped_messages = self._prepare_messages(messages) _prepare_tools_and_tool_choice(chat_options=chat_options) return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **kwargs) async def get_streaming_response( self, - messages: str | ChatMessage | list[str | ChatMessage], + messages: str | ChatMessage | list[str] | list[ChatMessage], *, - model: str | None = None, - max_tokens: int | None = None, - temperature: float | None = None, - top_p: float | None = None, - tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto", - tools: AITool | Sequence[AITool] | None = None, - response_format: type[BaseModel] | None = None, - user: str | None = None, - stop: str | Sequence[str] | None = None, frequency_penalty: float | None = None, logit_bias: dict[str | int, float] | None = None, - presence_penalty: float | None = None, - seed: int | None = None, - store: bool | None = None, + max_tokens: int | None = None, metadata: dict[str, Any] | None = None, + model: str | None = None, + presence_penalty: float | None = None, + response_format: type[BaseModel] | None = None, + seed: int | None = None, + stop: str | Sequence[str] | None = None, + store: bool | None = None, + temperature: float | None = None, + tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto", + tools: AITool + | list[AITool] + | Callable[..., Any] + | list[Callable[..., Any]] + | MutableMapping[str, Any] + | list[MutableMapping[str, Any]] + | None = None, + top_p: float | None = None, + user: str | None = None, additional_properties: dict[str, Any] | None = None, **kwargs: Any, ) -> AsyncIterable[ChatResponseUpdate]: @@ -475,50 +558,51 @@ class ChatClientBase(AFBaseModel, ABC): Args: messages: the message or messages to send to the model - model: the model to use for the request - max_tokens: the maximum number of tokens to generate - temperature: the sampling temperature to use - top_p: the nucleus sampling probability to use - tool_choice: the tool choice for the request - tools: the tools to use for the request - response_format: the format of the response - user: the user to associate with the request - stop: the stop sequence(s) for the request frequency_penalty: the frequency penalty to use logit_bias: the logit bias to use - presence_penalty: the presence penalty to use - seed: the random seed to use - store: whether to store the response - metadata: additional metadata to include in the request + max_tokens: The maximum number of tokens to generate. + metadata: additional metadata to include in the request. + model: The model to use for the agent. + presence_penalty: the presence penalty to use. + response_format: the format of the response. + seed: the random seed to use. + stop: the stop sequence(s) for the request. + store: whether to store the response. + temperature: the sampling temperature to use. + tool_choice: the tool choice for the request. + tools: the tools to use for the request. + top_p: the nucleus sampling probability to use. + user: the user to associate with the request. additional_properties: additional properties to include in the request kwargs: any additional keyword arguments Yields: A stream representing the response(s) from the LLM. """ - if tools is not None: - if not isinstance(tools, Sequence): - tools = [tools] - self._tool_map = {tool.name: tool for tool in tools if isinstance(tool, AIFunction)} - chat_options = ChatOptions( - ai_model_id=model, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - tool_choice=tool_choice, - tools=tools, - response_format=response_format, - user=user, - stop=stop, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - presence_penalty=presence_penalty, - seed=seed, - store=store, - metadata=metadata, - additional_properties=additional_properties or {}, - **kwargs, - ) + if "chat_options" in kwargs: + chat_options = kwargs.pop("chat_options") + if not isinstance(chat_options, ChatOptions): + raise TypeError("chat_options must be an instance of ChatOptions") + else: + chat_options = ChatOptions( + ai_model_id=model, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + max_tokens=max_tokens, + metadata=metadata, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + store=store, + temperature=temperature, + top_p=top_p, + tool_choice=tool_choice, + tools=tools, # type: ignore + user=user, + additional_properties=additional_properties or {}, + **kwargs, + ) prepped_messages = self._prepare_messages(messages) _prepare_tools_and_tool_choice(chat_options=chat_options) async for update in self._inner_get_streaming_response( diff --git a/python/packages/main/agent_framework/_tools.py b/python/packages/main/agent_framework/_tools.py index 0df9e68b24..88ef73f529 100644 --- a/python/packages/main/agent_framework/_tools.py +++ b/python/packages/main/agent_framework/_tools.py @@ -10,7 +10,16 @@ from pydantic import BaseModel, create_model @runtime_checkable class AITool(Protocol): - """Represents a generic tool that can be specified to an AI service.""" + """Represents a generic tool that can be specified to an AI service. + + Attributes: + name: The name of the tool. + description: A description of the tool. + additional_properties: Additional properties associated with the tool. + + Methods: + parameters: The parameters accepted by the tool, in a json schema format. + """ name: str """The name of the tool.""" diff --git a/python/packages/main/agent_framework/_types.py b/python/packages/main/agent_framework/_types.py index 95e9bdc9db..43896299b4 100644 --- a/python/packages/main/agent_framework/_types.py +++ b/python/packages/main/agent_framework/_types.py @@ -4,13 +4,22 @@ import base64 import json import re import sys -from collections.abc import AsyncIterable, Iterable, Iterator, Mapping, MutableMapping, MutableSequence, Sequence +from collections.abc import ( + AsyncIterable, + Callable, + Iterable, + Iterator, + Mapping, + MutableMapping, + MutableSequence, + Sequence, +) from typing import Annotated, Any, ClassVar, Generic, Literal, TypeVar, overload -from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, ValidationError, field_validator, model_validator from ._pydantic import AFBaseModel -from ._tools import AITool +from ._tools import AITool, ai_function from .exceptions import AgentFrameworkException if sys.version_info >= (3, 11): @@ -1399,34 +1408,69 @@ class ChatOptions(AFBaseModel): """Common request settings for AI services.""" ai_model_id: Annotated[str | None, Field(serialization_alias="model")] = None + frequency_penalty: Annotated[float | None, Field(ge=-2.0, le=2.0)] = None + logit_bias: MutableMapping[str | int, float] | None = None max_tokens: Annotated[int | None, Field(gt=0)] = None - temperature: Annotated[float | None, Field(ge=0.0, le=2.0)] = None - top_p: Annotated[float | None, Field(ge=0.0, le=1.0)] = None - tool_choice: ChatToolMode | Literal["auto", "required", "none"] | Mapping[str, Any] | None = None - tools: Sequence[AITool] | Sequence[MutableMapping[str, Any]] | None = None + metadata: MutableMapping[str, str] | None = None + presence_penalty: Annotated[float | None, Field(ge=-2.0, le=2.0)] = None response_format: type[BaseModel] | None = Field( default=None, description="Structured output response format schema. Must be a valid Pydantic model." ) - user: str | None = None - stop: str | Sequence[str] | None = None - frequency_penalty: Annotated[float | None, Field(ge=-2.0, le=2.0)] = None - logit_bias: MutableMapping[str | int, float] | None = None - presence_penalty: Annotated[float | None, Field(ge=-2.0, le=2.0)] = None seed: int | None = None + stop: str | Sequence[str] | None = None store: bool | None = None - metadata: MutableMapping[str, str] | None = None + temperature: Annotated[float | None, Field(ge=0.0, le=2.0)] = None + tool_choice: ChatToolMode | Literal["auto", "required", "none"] | Mapping[str, Any] | None = None + tools: list[AITool | MutableMapping[str, Any]] | None = None + _ai_tools: list[AITool | MutableMapping[str, Any]] | None = PrivateAttr(default=None) + top_p: Annotated[float | None, Field(ge=0.0, le=1.0)] = None + user: str | None = None additional_properties: MutableMapping[str, Any] = Field( default_factory=dict, description="Provider-specific additional properties." ) + @model_validator(mode="after") + def _copy_to_ai_tools(self) -> Self: + if self.tools and not self._ai_tools: + self._ai_tools = self.tools + return self + + @field_validator("tools", mode="before") + @classmethod + def _validate_tools( + cls, + tools: ( + AITool + | list[AITool] + | Callable[..., Any] + | list[Callable[..., Any]] + | MutableMapping[str, Any] + | list[MutableMapping[str, Any]] + | None + ), + ) -> list[AITool | MutableMapping[str, Any]] | None: + """Parse the tools field. + + All tools are stored in both tools and _ai_tools. + """ + if not tools: + return None + if not isinstance(tools, list): + tools = [tools] # type: ignore[reportAssignmentType, assignment] + for idx, tool in enumerate(tools): # type: ignore[reportArgumentType, arg-type] + if not isinstance(tool, (AITool, MutableMapping)): + # Convert to AITool if it's a function or callable + tools[idx] = ai_function(tool) # type: ignore[reportIndexIssues, reportCallIssue, reportArgumentType, index, call-overload, arg-type] + return tools # type: ignore[reportReturnType, return-value] + @field_validator("tool_choice", mode="before") @classmethod def _validate_tool_mode( cls, tool_choice: ChatToolMode | Literal["auto", "required", "none"] | Mapping[str, Any] | None - ) -> ChatToolMode: + ) -> ChatToolMode | None: """Validates the tool_choice field to ensure it is a valid ChatToolMode.""" if not tool_choice: - return ChatToolMode.NONE + return None if isinstance(tool_choice, str): match tool_choice: case "auto": @@ -1461,6 +1505,32 @@ class ChatOptions(AFBaseModel): settings.pop(key, None) return settings + def __and__(self, other: object) -> Self: + """Combines two ChatOptions instances. + + The values from the other ChatOptions take precedence. + List and dicts are combined. + """ + if not isinstance(other, ChatOptions): + return self + ai_tools = other._ai_tools + updated_values = other.model_dump(exclude_none=True) + updated_values.pop("tools", []) + logit_bias = updated_values.pop("logit_bias", {}) + metadata = updated_values.pop("metadata", {}) + additional_properties = updated_values.pop("additional_properties", {}) + combined = self.model_copy(update=updated_values) + if ai_tools: + if not combined._ai_tools: + combined._ai_tools = [] + for tool in ai_tools: + if tool not in combined._ai_tools: + combined._ai_tools.append(tool) + combined.logit_bias = {**(combined.logit_bias or {}), **logit_bias} + combined.metadata = {**(combined.metadata or {}), **metadata} + combined.additional_properties = {**(combined.additional_properties or {}), **additional_properties} + return combined + # region: GeneratedEmbeddings diff --git a/python/packages/main/tests/unit/conftest.py b/python/packages/main/tests/unit/conftest.py new file mode 100644 index 0000000000..905f4d9418 --- /dev/null +++ b/python/packages/main/tests/unit/conftest.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft. All rights reserved. + +from typing import Any + +from pydantic import BaseModel +from pytest import fixture + +from agent_framework import AITool, ai_function + + +@fixture +def ai_tool() -> AITool: + """Returns a generic AITool.""" + + class GenericTool(BaseModel): + name: str + description: str | None = None + additional_properties: dict[str, Any] | None = None + + def parameters(self) -> dict[str, Any]: + """Return the parameters of the tool as a JSON schema.""" + return { + "name": {"type": "string"}, + } + + return GenericTool(name="generic_tool", description="A generic tool") + + +@fixture +def ai_function_tool() -> AITool: + """Returns a executable AITool.""" + + @ai_function + def simple_function(x: int, y: int) -> int: + """A simple function that adds two numbers.""" + return x + y + + return simple_function diff --git a/python/packages/main/tests/unit/test_clients.py b/python/packages/main/tests/unit/test_clients.py index 68540d6877..180b4cf86d 100644 --- a/python/packages/main/tests/unit/test_clients.py +++ b/python/packages/main/tests/unit/test_clients.py @@ -209,7 +209,7 @@ async def test_base_client_with_function_calling(chat_client_base: MockChatClien async def test_base_client_with_function_calling_disabled(chat_client_base: MockChatClientBase): - chat_client_base.maximum_iterations_per_request = 0 + chat_client_base.__maximum_iterations_per_request = 0 exec_counter = 0 @ai_function(name="test_function") @@ -273,7 +273,7 @@ async def test_base_client_with_streaming_function_calling(chat_client_base: Moc async def test_base_client_with_streaming_function_calling_disabled(chat_client_base: MockChatClientBase): - chat_client_base.maximum_iterations_per_request = 0 + chat_client_base.__maximum_iterations_per_request = 0 exec_counter = 0 @ai_function(name="test_function") diff --git a/python/packages/main/tests/unit/test_types.py b/python/packages/main/tests/unit/test_types.py index 82378753e0..739b895fb8 100644 --- a/python/packages/main/tests/unit/test_types.py +++ b/python/packages/main/tests/unit/test_types.py @@ -10,7 +10,9 @@ from agent_framework import ( AgentRunResponseUpdate, AIContent, AIContents, + AITool, ChatMessage, + ChatOptions, ChatResponse, ChatResponseUpdate, ChatRole, @@ -492,6 +494,99 @@ def test_generated_embeddings(): assert issubclass(GeneratedEmbeddings, MutableSequence) +# region: ChatOptions + + +def test_chat_options_init() -> None: + options = ChatOptions() + assert options.ai_model_id is None + + +def test_chat_options_init_with_args(ai_function_tool, ai_tool) -> None: + options = ChatOptions( + ai_model_id="gpt-4", + max_tokens=1024, + temperature=0.7, + top_p=0.9, + presence_penalty=0.0, + frequency_penalty=0.0, + user="user-123", + tools=[ai_function_tool, ai_tool], + ) + assert options.ai_model_id == "gpt-4" + assert options.max_tokens == 1024 + assert options.temperature == 0.7 + assert options.top_p == 0.9 + assert options.presence_penalty == 0.0 + assert options.frequency_penalty == 0.0 + assert options.user == "user-123" + for tool in options._ai_tools: + assert isinstance(tool, AITool) + assert tool.name is not None + assert tool.description is not None + assert tool.parameters() is not None + + +def test_chat_options_and(ai_function_tool, ai_tool) -> None: + options1 = ChatOptions(ai_model_id="gpt-4o", tools=[ai_function_tool]) + options2 = ChatOptions(ai_model_id="gpt-4.1", tools=[ai_tool]) + assert options1 != options2 + options3 = options1 & options2 + + assert options3.ai_model_id == "gpt-4.1" + assert len(options3._ai_tools) == 2 + assert options3._ai_tools == [ai_function_tool, ai_tool] + assert options3.tools == [ai_function_tool, ai_tool] + + +def test_chat_options_parsing_tools(ai_function_tool, ai_tool) -> None: + from agent_framework._clients import _prepare_tools_and_tool_choice + + def echo() -> str: + """Echo the input.""" + return "Echo" + + dict_function = { + "type": "function", + "function": { + "name": "get_weather", + "description": "Retrieves current weather for the given location.", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "City and country e.g. Bogotá, Colombia"}, + "units": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Units the temperature will be returned in.", + }, + }, + "required": ["location", "units"], + "additionalProperties": False, + }, + "strict": True, + }, + } + + options = ChatOptions(tools=[ai_function_tool, ai_tool, echo, dict_function], tool_choice="auto") + assert len(options.tools) == 4 + assert options.tools[0] == ai_function_tool + assert options.tools[1] == ai_tool + assert options.tools[2] != echo + assert options.tools[3] == dict_function + # after prepare, the tools should be represented as dicts + # while ai_tools is still the same. + _prepare_tools_and_tool_choice(options) + assert options._ai_tools[0] == ai_function_tool + assert options._ai_tools[1] == ai_tool + assert options._ai_tools[3] == dict_function + assert len(options.tools) == 4 + assert options.tools[0]["function"]["name"] == "simple_function" + assert options.tools[1]["function"]["name"] == "generic_tool" + assert options.tools[2]["function"]["name"] == "echo" + assert options.tools[3]["function"]["name"] == "get_weather" + + # region Agent Response Fixtures @@ -548,7 +643,7 @@ def test_agent_run_response_from_updates(agent_run_response_update: AgentRunResp updates = [agent_run_response_update, agent_run_response_update] response = AgentRunResponse.from_agent_run_response_updates(updates) assert len(response.messages) > 0 - assert response.text == "Test content\nTest content" + assert response.text == "Test content Test content" def test_agent_run_response_str_method(chat_message: ChatMessage) -> None: diff --git a/python/samples/getting_started/agents/chat_client_agent.py b/python/samples/getting_started/agents/chat_client_agent.py new file mode 100644 index 0000000000..76192cf415 --- /dev/null +++ b/python/samples/getting_started/agents/chat_client_agent.py @@ -0,0 +1,27 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +from random import randint +from typing import Annotated + +from agent_framework import ChatClientAgent +from agent_framework.openai import OpenAIChatClient +from pydantic import Field + + +def get_weather( + location: Annotated[str, Field(description="The location to get the weather for.")], +) -> str: + """Get the weather for a given location.""" + conditions = ["sunny", "cloudy", "rainy", "stormy"] + return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." + + +async def main(): + instructions = "You are a helpful assistant, you can help the user with weather information." + agent = ChatClientAgent(OpenAIChatClient(), instructions=instructions, tools=get_weather) + print(str(await agent.run("What's the weather in Amsterdam?"))) + + +if __name__ == "__main__": + asyncio.run(main())