mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
94e00bd49a
* Added ChatClientAgentThread * Initial version of ChatClientAgent * Completed ChatClientAgent * Small fixes and unit tests * Fixes based on pre-commit * Small fixes * Small renaming * Small improvement * Small fixes * Addressed PR feedback * Small fix * Added method for AgentRunResponse from streaming conversion * Addressed PR feedback * Addressed PR feedback * Addressed PR feedback * Small fix * More fixes
531 lines
20 KiB
Python
531 lines
20 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
import asyncio
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import AsyncIterable, Awaitable, Callable, MutableSequence, Sequence
|
|
from functools import wraps
|
|
from typing import Annotated, Any, Generic, Literal, Protocol, TypeVar, runtime_checkable
|
|
|
|
from pydantic import BaseModel, PrivateAttr, StringConstraints
|
|
|
|
from ._logging import get_logger
|
|
from ._pydantic import AFBaseModel
|
|
from ._tools import AIFunction, AITool
|
|
from ._types import (
|
|
AIContents,
|
|
ChatMessage,
|
|
ChatOptions,
|
|
ChatResponse,
|
|
ChatResponseUpdate,
|
|
ChatToolMode,
|
|
FunctionCallContent,
|
|
FunctionResultContent,
|
|
GeneratedEmbeddings,
|
|
)
|
|
|
|
# Type variables used by the protocols, base class and decorators in this module.

# Input item type of an embedding generator (contravariant).
TInput = TypeVar("TInput", contravariant=True)
# Embedding value type produced by an embedding generator.
TEmbedding = TypeVar("TEmbedding")

# Callables accepted by the tool-calling decorators: the awaitable (non-streaming)
# variant and the async-iterable (streaming) variant.
TInnerGetResponse = TypeVar("TInnerGetResponse", bound=Callable[..., Awaitable[ChatResponse]])
TInnerGetStreamingResponse = TypeVar(
    "TInnerGetStreamingResponse",
    bound=Callable[..., AsyncIterable[ChatResponseUpdate]],
)

# Any subclass of ChatClientBase; lets ``use_tool_calling`` return the class type it was given.
TChatClientBase = TypeVar("TChatClientBase", bound="ChatClientBase")

# Module-level logger.
logger = get_logger()
|
|
|
|
# region: Tool Calling Functions and Decorators
|
|
|
|
|
|
def _merge_function_results(
    messages: list[ChatMessage],
) -> ChatMessage:
    """Collapse the function results of several messages into one tool message.

    Walks ``messages`` in order and keeps only the ``FunctionResultContent``
    items found in each message's ``contents``; every other content item is
    dropped.

    Args:
        messages: The chat messages to scan for function results.

    Returns:
        A new ``ChatMessage`` with role ``"tool"`` whose contents are the
        collected function results, in their original order.
    """
    function_results: list[Any] = [
        content
        for chat_message in messages
        for content in chat_message.contents
        if isinstance(content, FunctionResultContent)
    ]
    return ChatMessage(role="tool", contents=function_results)
|
|
|
|
|
|
async def _auto_invoke_function(
    function_call_content: FunctionCallContent,
    custom_args: dict[str, Any] | None = None,
    *,
    tool_map: dict[str, AIFunction[BaseModel, Any]],
    sequence_index: int | None = None,
    request_index: int | None = None,
) -> AIContents:
    """Look up and invoke the tool requested by a function call.

    Args:
        function_call_content: The function call to execute; its ``name`` selects
            the tool and its parsed arguments are the primary inputs.
        custom_args: Extra caller-supplied arguments. They are merged underneath
            the parsed arguments, so a parsed argument wins on a key conflict.
        tool_map: Mapping from tool name to the invocable ``AIFunction``.
        sequence_index: Position of the call within its batch (accepted but not
            used by this function).
        request_index: Index of the request round (accepted but not used by this
            function).

    Returns:
        A ``FunctionResultContent`` carrying the call id and either the tool's
        result or the exception raised while invoking it.

    Raises:
        KeyError: If ``tool_map`` has no entry for the requested name.

    Note:
        Argument parsing and ``model_validate`` run outside the ``try`` block, so
        any error they raise propagates to the caller instead of being captured
        in the result.
    """
    target = tool_map.get(function_call_content.name)
    if target is None:
        raise KeyError(f"No tool or function named '{function_call_content.name}'")

    model_supplied: dict[str, Any] = dict(function_call_content.parse_arguments() or {})
    # Later entries override earlier ones: model-supplied arguments beat custom_args.
    combined: dict[str, Any] = {**(custom_args or {}), **model_supplied}
    validated = target.input_model.model_validate(combined)

    try:
        outcome = await target.invoke(arguments=validated)
    except Exception as error:
        # Deliberate best-effort: the failure is reported back as content, not raised.
        return FunctionResultContent(
            call_id=function_call_content.call_id,
            exception=error,
            result=None,
        )
    return FunctionResultContent(
        call_id=function_call_content.call_id,
        exception=None,
        result=outcome,
    )
|
|
|
|
|
|
def _tool_to_json_schema_spec(tool: AITool) -> dict[str, Any]:
    """Describe an ``AITool`` as a JSON Schema style function specification.

    Args:
        tool: The tool to describe; its ``name``, ``description`` and the value
            returned by ``parameters()`` are copied into the specification.

    Returns:
        A dict of the form
        ``{"type": "function", "function": {"name", "description", "parameters"}}``.
    """
    function_spec: dict[str, Any] = {
        "name": tool.name,
        "description": tool.description,
        "parameters": tool.parameters(),
    }
    return {"type": "function", "function": function_spec}
|
|
|
|
|
|
def _prepare_tools_and_tool_choice(chat_options: ChatOptions) -> None:
    """Normalise ``tools`` and ``tool_choice`` on ``chat_options`` in place.

    When no tool mode is set, or the mode equals ``ChatToolMode.NONE``, the tools
    are cleared and ``tool_choice`` is set to the ``NONE`` mode value. Otherwise
    every ``AITool`` in ``tools`` is replaced by its JSON Schema function
    specification (other entries are kept untouched) and ``tool_choice`` is set
    to the ``mode`` attribute of the requested mode.

    Args:
        chat_options: The options object to update; it is mutated, nothing is returned.
    """
    requested_mode: ChatToolMode | None = chat_options.tool_choice  # type: ignore
    if requested_mode is None or requested_mode == ChatToolMode.NONE:
        # Tool calling is disabled: do not send any tool definitions.
        chat_options.tools = None
        chat_options.tool_choice = ChatToolMode.NONE.mode
        return

    declared_tools = chat_options.tools or []
    serialised_tools: list[Any] = [
        _tool_to_json_schema_spec(entry) if isinstance(entry, AITool) else entry
        for entry in declared_tools
    ]
    chat_options.tools = serialised_tools
    chat_options.tool_choice = requested_mode.mode
|
|
|
|
|
|
def _tool_call_non_streaming(func: TInnerGetResponse) -> TInnerGetResponse:
    """Wrap a non-streaming inner response method with automatic tool calling.

    The wrapper repeatedly calls ``func``; whenever the first message of the
    response contains function calls they are executed concurrently, their
    results are added to the conversation and the model is asked again. This
    happens at most ``self.maximum_iterations_per_request`` times, after which
    one last request is made with tool calling switched off.

    Remarks:
        Relies on a class that has the ``_tool_map`` attribute for the executable
        tools to call. Extra ``**kwargs`` given to the wrapper are handed to the
        invoked tools as custom arguments; they are not forwarded to ``func``.
        The ``messages`` sequence passed in is extended in place.
    """

    @wraps(func)
    async def wrapper(
        self: "ChatClientBase",
        *,
        messages: MutableSequence[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> ChatResponse:
        # Every function-call message and tool-result message produced so far.
        tool_round_messages: list[ChatMessage] = []

        def with_tool_history(final: ChatResponse) -> ChatResponse:
            # Put the earlier tool rounds in front so the final answer stays last.
            final.messages[:0] = tool_round_messages
            return final

        for attempt_idx in range(self.maximum_iterations_per_request):
            response = await func(self, messages=messages, chat_options=chat_options)
            requested_calls = [
                content for content in response.messages[0].contents if isinstance(content, FunctionCallContent)
            ]
            if not requested_calls:
                # Plain answer: nothing left to execute.
                # TODO (eavanvalkenburg): control this behavior?
                return with_tool_history(response)

            # Execute all requested functions concurrently.
            invocations = [
                _auto_invoke_function(
                    call,
                    custom_args=kwargs,
                    tool_map=self._tool_map,
                    sequence_index=seq_idx,
                    request_index=attempt_idx,
                )
                for seq_idx, call in enumerate(requested_calls)
            ]
            results = await asyncio.gather(*invocations)
            # The response now holds the function-call message followed by a single
            # tool message with the matching results.
            response.messages.append(ChatMessage(role="tool", contents=results))
            tool_round_messages.extend(response.messages)
            # Feed both back to the model as context for the next round.
            messages.extend(response.messages)

        # Failsafe: give up on tools, ask model for plain answer
        chat_options.tool_choice = "none"
        _prepare_tools_and_tool_choice(chat_options=chat_options)
        response = await func(self, messages=messages, chat_options=chat_options)
        return with_tool_history(response)

    return wrapper  # type: ignore[reportReturnType, return-value]
|
|
|
|
|
|
def _tool_call_streaming(func: TInnerGetStreamingResponse) -> TInnerGetStreamingResponse:
    """Decorate the internal _inner_get_streaming_response method to enable tool calls.

    Every update produced by ``func`` is yielded to the caller as it arrives.
    Updates that contain function calls are additionally collected; once the
    stream ends they are combined into a response, the requested functions are
    executed concurrently, a ``"tool"`` update with the results is yielded and
    the model is queried again. This happens at most
    ``self.maximum_iterations_per_request`` times, after which one last request
    is streamed with tool calling switched off.

    Remarks:
        Relies on a class that has the _tool_map attribute for the executable tools to call.
        The ``messages`` sequence passed in is extended in place with the
        function-call and tool-result messages of each round.
    """

    @wraps(func)
    async def wrapper(
        self: "ChatClientBase",
        *,
        messages: MutableSequence[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> AsyncIterable[ChatResponseUpdate]:
        for attempt_idx in range(self.maximum_iterations_per_request):
            # Only the updates carrying function calls are kept for post-processing.
            function_call_updates: list[ChatResponseUpdate] = []
            async for update in func(self, messages=messages, chat_options=chat_options):
                if update.contents and any(isinstance(item, FunctionCallContent) for item in update.contents):
                    function_call_updates.append(update)
                yield update

            if not function_call_updates:
                # The model answered without requesting a function: we are done.
                return

            # Combine the collected updates into one response; depending on the prompt,
            # the message may contain both function call content and others.
            response: ChatResponse = ChatResponse.from_chat_response_updates(function_call_updates)
            function_calls = [item for item in response.messages[0].contents if isinstance(item, FunctionCallContent)]

            if not function_calls:
                # No call in the first message: record it and query the model again.
                messages.append(response.messages[0])
                continue

            # Run all function calls concurrently
            results = await asyncio.gather(*[
                _auto_invoke_function(
                    function_call,
                    custom_args=kwargs,
                    tool_map=self._tool_map,
                    sequence_index=seq_idx,
                    request_index=attempt_idx,
                )
                for seq_idx, function_call in enumerate(function_calls)
            ])
            yield ChatResponseUpdate(contents=results, role="tool")
            response.messages.append(ChatMessage(role="tool", contents=results))
            # Add the function-call message(s) and the tool results exactly once.
            # Previously the first message was appended on its own and then again
            # through this extend, duplicating it in the conversation history.
            messages.extend(response.messages)

        # Failsafe: give up on tools, ask model for plain answer
        chat_options.tool_choice = "none"
        _prepare_tools_and_tool_choice(chat_options=chat_options)
        # NOTE(review): only this final call forwards **kwargs to func; the calls
        # inside the loop do not. Kept as-is - confirm which is intended.
        async for update in func(self, messages=messages, chat_options=chat_options, **kwargs):
            yield update

    return wrapper  # type: ignore[reportReturnType, return-value]
|
|
|
|
|
|
def use_tool_calling(cls: type[TChatClientBase]) -> type[TChatClientBase]:
    """Class decorator that enables automatic tool calling on a chat client.

    If the class defines ``_inner_get_response`` it is wrapped with
    ``_tool_call_non_streaming``; if it defines
    ``_inner_get_streaming_response`` it is wrapped with ``_tool_call_streaming``.
    Attributes that are missing (or ``None``) are left alone.

    Args:
        cls: The chat client class to modify in place.

    Returns:
        The same class object, with the wrapped methods installed.
    """
    decorators = (
        ("_inner_get_response", _tool_call_non_streaming),
        ("_inner_get_streaming_response", _tool_call_streaming),
    )
    for method_name, decorate in decorators:
        inner_method = getattr(cls, method_name, None)
        if inner_method is not None:
            setattr(cls, method_name, decorate(inner_method))  # type: ignore
    return cls
|
|
|
|
|
|
# region: ChatClient Protocol
|
|
|
|
|
|
@runtime_checkable
class ChatClient(Protocol):
    """Structural interface of a chat client.

    Any object that provides ``get_response`` and ``get_streaming_response``
    satisfies this protocol; because it is ``runtime_checkable`` it can also be
    used with ``isinstance``.
    """

    async def get_response(
        self,
        messages: str | ChatMessage | list[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        """Send the input and return the complete response.

        Args:
            messages: A single text, a single message or a list of messages to send.
            **kwargs: Additional options for the request, such as ai_model_id,
                temperature, etc. See `ChatOptions` for more details.

        Returns:
            The response messages generated by the client.

        Raises:
            ValueError: Documented contract for implementations when the input
                message sequence is `None`.
        """
        ...

    def get_streaming_response(
        self,
        messages: str | ChatMessage | list[ChatMessage],
        **kwargs: Any,
    ) -> AsyncIterable[ChatResponseUpdate]:
        """Send the input and stream the response as it is produced.

        Args:
            messages: A single text, a single message or a list of messages to send.
            **kwargs: Additional options for the request, such as ai_model_id,
                temperature, etc. See `ChatOptions` for more details.

        Yields:
            Chat response updates containing the content of the response messages
            generated by the client.

        Raises:
            ValueError: Documented contract for implementations when the input
                message sequence is `None`.
        """
        ...
|
|
|
|
|
|
class ChatClientBase(AFBaseModel, ABC):
    """Abstract base class for chat clients.

    Subclasses implement the two ``_inner_*`` methods that talk to the AI
    service. The public ``get_response`` / ``get_streaming_response`` methods
    build a ``ChatOptions`` object from their keyword arguments, normalise the
    input into a list of messages, prepare the tool settings and delegate to the
    inner methods.
    """

    # Model identifier; surrounding whitespace is stripped and it must not be empty.
    ai_model_id: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1)]
    # Read by the tool-calling decorators in this module as the cap on request rounds.
    maximum_iterations_per_request: int = 10
    # Tool name -> AIFunction lookup; replaced whenever a request passes ``tools``.
    _tool_map: dict[str, AIFunction[BaseModel, Any]] = PrivateAttr(default_factory=dict)  # type: ignore

    # region Internal methods to be implemented by the derived classes

    @abstractmethod
    async def _inner_get_response(
        self,
        *,
        messages: MutableSequence[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> ChatResponse:
        """Send a chat request to the AI service.

        Args:
            messages: The chat messages to send.
            chat_options: The options for the request.
            kwargs: Any additional keyword arguments.

        Returns:
            The chat response contents representing the response(s).
        """

    @abstractmethod
    async def _inner_get_streaming_response(
        self,
        *,
        messages: MutableSequence[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> AsyncIterable[ChatResponseUpdate]:
        """Send a streaming chat request to the AI service.

        Args:
            messages: The chat messages to send.
            chat_options: The chat_options for the request.
            kwargs: Any additional keyword arguments.

        Yields:
            ChatResponseUpdate: The streaming chat message contents.
        """
        # Below is needed for mypy: https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators
        # The unreachable ``yield`` turns this method into an async generator.
        if False:
            yield
        await asyncio.sleep(0)  # pragma: no cover
        # This is a no-op, but it allows the method to be async and return an AsyncIterable.
        # The actual implementation should yield ChatResponseUpdate instances as needed.

    # endregion

    # region Public method

    async def get_response(
        self,
        messages: str | ChatMessage | list[ChatMessage],
        *,
        model: str | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None,
        tools: Sequence[AITool] | None = None,
        response_format: type[BaseModel] | None = None,
        user: str | None = None,
        stop: str | Sequence[str] | None = None,
        frequency_penalty: float | None = None,
        logit_bias: dict[str | int, float] | None = None,
        presence_penalty: float | None = None,
        seed: int | None = None,
        store: bool | None = None,
        metadata: dict[str, Any] | None = None,
        additional_properties: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> ChatResponse:
        """Get a response from a chat client.

        Args:
            messages: the message or messages to send to the model; a plain string
                is wrapped in a single message with role "user"
            model: the model to use for the request
            max_tokens: the maximum number of tokens to generate
            temperature: the sampling temperature to use
            top_p: the nucleus sampling probability to use
            tool_choice: the tool choice for the request
            tools: the tools to use for the request; the ``AIFunction`` entries
                also replace the client's tool map
            response_format: the format of the response
            user: the user to associate with the request
            stop: the stop sequence(s) for the request
            frequency_penalty: the frequency penalty to use
            logit_bias: the logit bias to use
            presence_penalty: the presence penalty to use
            seed: the random seed to use
            store: whether to store the response
            metadata: additional metadata to include in the request
            additional_properties: additional properties to include in the request
            kwargs: any additional keyword arguments; they are forwarded to
                ``_inner_get_response`` and are not added to the chat options

        Returns:
            A chat response from the model.
        """
        if tools is not None:
            # Only AIFunction instances can be invoked, so only those are indexed.
            self._tool_map = {candidate.name: candidate for candidate in tools if isinstance(candidate, AIFunction)}

        options = ChatOptions(
            ai_model_id=model,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            tool_choice=tool_choice,
            tools=tools,
            response_format=response_format,
            user=user,
            stop=stop,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            presence_penalty=presence_penalty,
            seed=seed,
            store=store,
            metadata=metadata,
            additional_properties=additional_properties or {},
        )

        # Normalise the input into a list of messages.
        if isinstance(messages, str):
            messages = [ChatMessage(role="user", text=messages)]
        elif isinstance(messages, ChatMessage):
            messages = [messages]

        _prepare_tools_and_tool_choice(chat_options=options)
        result = await self._inner_get_response(messages=messages, chat_options=options, **kwargs)
        return result

    async def get_streaming_response(
        self,
        messages: str | ChatMessage | list[ChatMessage],
        *,
        model: str | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None,
        tools: Sequence[AITool] | None = None,
        response_format: type[BaseModel] | None = None,
        user: str | None = None,
        stop: str | Sequence[str] | None = None,
        frequency_penalty: float | None = None,
        logit_bias: dict[str | int, float] | None = None,
        presence_penalty: float | None = None,
        seed: int | None = None,
        store: bool | None = None,
        metadata: dict[str, Any] | None = None,
        additional_properties: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> AsyncIterable[ChatResponseUpdate]:
        """Get a streaming response from a chat client.

        Args:
            messages: the message or messages to send to the model; a plain string
                is wrapped in a single message with role "user"
            model: the model to use for the request
            max_tokens: the maximum number of tokens to generate
            temperature: the sampling temperature to use
            top_p: the nucleus sampling probability to use
            tool_choice: the tool choice for the request
            tools: the tools to use for the request; the ``AIFunction`` entries
                also replace the client's tool map
            response_format: the format of the response
            user: the user to associate with the request
            stop: the stop sequence(s) for the request
            frequency_penalty: the frequency penalty to use
            logit_bias: the logit bias to use
            presence_penalty: the presence penalty to use
            seed: the random seed to use
            store: whether to store the response
            metadata: additional metadata to include in the request
            additional_properties: additional properties to include in the request
            kwargs: any additional keyword arguments; they are passed both to the
                ``ChatOptions`` constructor and to ``_inner_get_streaming_response``

        Yields:
            A stream representing the response(s) from the LLM.
        """
        if tools is not None:
            # Only AIFunction instances can be invoked, so only those are indexed.
            self._tool_map = {candidate.name: candidate for candidate in tools if isinstance(candidate, AIFunction)}

        # NOTE(review): unlike get_response, **kwargs is also handed to ChatOptions
        # here. Kept as-is - confirm whether the difference is intended.
        options = ChatOptions(
            ai_model_id=model,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            tool_choice=tool_choice,
            tools=tools,
            response_format=response_format,
            user=user,
            stop=stop,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            presence_penalty=presence_penalty,
            seed=seed,
            store=store,
            metadata=metadata,
            additional_properties=additional_properties or {},
            **kwargs,
        )

        # Normalise the input into a list of messages.
        if isinstance(messages, str):
            messages = [ChatMessage(role="user", text=messages)]
        elif isinstance(messages, ChatMessage):
            messages = [messages]

        _prepare_tools_and_tool_choice(chat_options=options)
        stream = self._inner_get_streaming_response(messages=messages, chat_options=options, **kwargs)
        async for chunk in stream:
            yield chunk
|
|
|
|
|
|
# region: Embedding Client
|
|
|
|
|
|
@runtime_checkable
class EmbeddingGenerator(Protocol, Generic[TInput, TEmbedding]):
    """Structural interface of an embedding generator.

    ``TInput`` is the type of the items to embed and ``TEmbedding`` the type of
    the produced embeddings. Because the protocol is ``runtime_checkable`` it
    can also be used with ``isinstance``.
    """

    async def generate(
        self,
        input_data: Sequence[TInput],
        **kwargs: Any,
    ) -> GeneratedEmbeddings[TEmbedding]:
        """Generate embeddings for the given input data.

        Args:
            input_data: The sequence of input items to generate embeddings for.
            **kwargs: Additional options for the request.

        Returns:
            The generated embeddings; this acts like a list, but has additional
            metadata and usage details.
        """
        ...
|