mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
10d10364a9
* cleanup of threads and serialization * fix for sliding window * fix redis test * updated from comments * updated context provider and threads * updated lock * add asyncio default * fix redis tests * fix tests * fix tests * renamed to invoking * fixed tests * fix for instructions
556 lines
21 KiB
Python
556 lines
21 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
import asyncio
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import AsyncIterable, Callable, MutableMapping, MutableSequence, Sequence
|
|
from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeVar, runtime_checkable
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from ._logging import get_logger
|
|
from ._mcp import MCPTool
|
|
from ._memory import AggregateContextProvider, ContextProvider
|
|
from ._middleware import (
|
|
ChatMiddleware,
|
|
ChatMiddlewareCallable,
|
|
FunctionMiddleware,
|
|
FunctionMiddlewareCallable,
|
|
Middleware,
|
|
)
|
|
from ._pydantic import AFBaseModel
|
|
from ._threads import ChatMessageStoreProtocol
|
|
from ._tools import ToolProtocol
|
|
from ._types import (
|
|
ChatMessage,
|
|
ChatOptions,
|
|
ChatResponse,
|
|
ChatResponseUpdate,
|
|
ChatToolMode,
|
|
GeneratedEmbeddings,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from ._agents import ChatAgent
|
|
|
|
|
|
# Contravariant type variable for the input items accepted by EmbeddingGenerator.generate.
TInput = TypeVar("TInput", contravariant=True)
# Type variable for the embedding type carried by the GeneratedEmbeddings result.
TEmbedding = TypeVar("TEmbedding")
# Type variable bound to BaseChatClient (given as a string forward reference).
TBaseChatClient = TypeVar("TBaseChatClient", bound="BaseChatClient")

# Module-level logger obtained from the package's logging helper.
logger = get_logger()

# Names exported by this module; the module-level prepare_messages helper is not listed here.
__all__ = [
    "BaseChatClient",
    "ChatClientProtocol",
    "EmbeddingGenerator",
]
|
|
|
|
|
|
# region ChatClientProtocol Protocol
|
|
|
|
|
|
@runtime_checkable
class ChatClientProtocol(Protocol):
    """A protocol for a chat client that can generate responses.

    The protocol is structural: any object exposing ``additional_properties``,
    ``get_response`` and ``get_streaming_response`` satisfies it. Because it is
    decorated with ``runtime_checkable``, ``isinstance`` checks only verify that
    these members exist, not that their signatures match.
    """

    @property
    def additional_properties(self) -> dict[str, Any]:
        """Get additional properties associated with the client."""
        ...

    async def get_response(
        self,
        messages: str | ChatMessage | list[str] | list[ChatMessage],
        *,
        frequency_penalty: float | None = None,
        logit_bias: dict[str | int, float] | None = None,
        max_tokens: int | None = None,
        metadata: dict[str, Any] | None = None,
        model: str | None = None,
        presence_penalty: float | None = None,
        response_format: type[BaseModel] | None = None,
        seed: int | None = None,
        stop: str | Sequence[str] | None = None,
        store: bool | None = None,
        temperature: float | None = None,
        tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
        tools: ToolProtocol
        | Callable[..., Any]
        | MutableMapping[str, Any]
        | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
        | None = None,
        top_p: float | None = None,
        user: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> ChatResponse:
        """Send input messages and return the complete response.

        Args:
            messages: The message or sequence of input messages to send.
            frequency_penalty: The frequency penalty to use.
            logit_bias: The logit bias to use.
            max_tokens: The maximum number of tokens to generate.
            metadata: Additional metadata to include in the request.
            model: The model to use for the agent.
            presence_penalty: The presence penalty to use.
            response_format: The format of the response.
            seed: The random seed to use.
            stop: The stop sequence(s) for the request.
            store: Whether to store the response.
            temperature: The sampling temperature to use.
            tool_choice: The tool choice for the request.
            tools: The tools to use for the request.
            top_p: The nucleus sampling probability to use.
            user: The user to associate with the request.
            additional_properties: Additional properties to include in the request.
            kwargs: Any additional keyword arguments,
                will only be passed to functions that are called.

        Returns:
            The response messages generated by the client.

        Raises:
            ValueError: If the input message sequence is `None`.
        """
        ...

    # Declared as a plain ``def`` that returns an AsyncIterable, so callers iterate the
    # result with ``async for`` without awaiting the call itself.
    def get_streaming_response(
        self,
        messages: str | ChatMessage | list[str] | list[ChatMessage],
        *,
        frequency_penalty: float | None = None,
        logit_bias: dict[str | int, float] | None = None,
        max_tokens: int | None = None,
        metadata: dict[str, Any] | None = None,
        model: str | None = None,
        presence_penalty: float | None = None,
        response_format: type[BaseModel] | None = None,
        seed: int | None = None,
        stop: str | Sequence[str] | None = None,
        store: bool | None = None,
        temperature: float | None = None,
        tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
        tools: ToolProtocol
        | Callable[..., Any]
        | MutableMapping[str, Any]
        | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
        | None = None,
        top_p: float | None = None,
        user: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> AsyncIterable[ChatResponseUpdate]:
        """Send input messages and stream the response.

        Args:
            messages: The message or sequence of input messages to send.
            frequency_penalty: The frequency penalty to use.
            logit_bias: The logit bias to use.
            max_tokens: The maximum number of tokens to generate.
            metadata: Additional metadata to include in the request.
            model: The model to use for the agent.
            presence_penalty: The presence penalty to use.
            response_format: The format of the response.
            seed: The random seed to use.
            stop: The stop sequence(s) for the request.
            store: Whether to store the response.
            temperature: The sampling temperature to use.
            tool_choice: The tool choice for the request.
            tools: The tools to use for the request.
            top_p: The nucleus sampling probability to use.
            user: The user to associate with the request.
            additional_properties: Additional properties to include in the request.
            kwargs: Any additional keyword arguments,
                will only be passed to functions that are called.

        Yields:
            An async iterable of chat response updates containing the content of the response messages
            generated by the client.

        Raises:
            ValueError: If the input message sequence is `None`.
        """
        ...
|
|
|
|
|
|
# region ChatClientBase
|
|
|
|
|
|
def prepare_messages(messages: str | ChatMessage | list[str] | list[ChatMessage]) -> list[ChatMessage]:
    """Normalize the accepted message inputs into a new list of ChatMessage objects.

    A single string or ChatMessage is treated as a one-element batch. Every plain
    string is wrapped in a ChatMessage with the "user" role; ChatMessage instances
    are passed through unchanged.
    """
    # Wrap a lone message so that single and batched inputs share one conversion path.
    batch = [messages] if isinstance(messages, (str, ChatMessage)) else messages
    return [ChatMessage(role="user", text=entry) if isinstance(entry, str) else entry for entry in batch]
|
|
|
|
|
|
class BaseChatClient(AFBaseModel, ABC):
    """Base class for chat clients.

    The public ``get_response`` and ``get_streaming_response`` methods resolve the
    ``ChatOptions``, normalize the messages and tools, and then delegate to the
    abstract ``_inner_get_response`` / ``_inner_get_streaming_response`` hooks that
    derived classes must implement.
    """

    additional_properties: dict[str, Any] = Field(default_factory=dict)
    middleware: (
        ChatMiddleware
        | ChatMiddlewareCallable
        | FunctionMiddleware
        | FunctionMiddlewareCallable
        | list[ChatMiddleware | ChatMiddlewareCallable | FunctionMiddleware | FunctionMiddlewareCallable]
        | None
    ) = None
    OTEL_PROVIDER_NAME: str = "unknown"
    # This is used for OTel setup, should be overridden in subclasses

    def prepare_messages(
        self, messages: str | ChatMessage | list[str] | list[ChatMessage], chat_options: ChatOptions
    ) -> MutableSequence[ChatMessage]:
        """Turn the allowed input into a list of chat messages.

        Args:
            messages: A single string/ChatMessage or a list of them.
            chat_options: The options for the request. When ``instructions`` is set, a
                "system" message carrying those instructions is placed first.

        Returns:
            The normalized list of chat messages.
        """
        if chat_options.instructions:
            system_msg = ChatMessage(role="system", text=chat_options.instructions)
            return [system_msg, *prepare_messages(messages)]
        return prepare_messages(messages)

    def _filter_internal_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
        """Filter out internal framework parameters that shouldn't be passed to chat client implementations.

        Args:
            kwargs: The original kwargs dictionary.

        Returns:
            A new dictionary without the entries whose key starts with an underscore.
        """
        return {k: v for k, v in kwargs.items() if not k.startswith("_")}

    @staticmethod
    def _normalize_tools(
        tools: ToolProtocol
        | MutableMapping[str, Any]
        | Callable[..., Any]
        | list[ToolProtocol | MutableMapping[str, Any] | Callable[..., Any]]
        | None = None,
    ) -> list[ToolProtocol | dict[str, Any] | Callable[..., Any]]:
        """Normalize the tools input to a list of tools.

        Args:
            tools: A single tool, a list of tools, or None.

        Returns:
            A flat list of tools. A falsy input yields an empty list, a non-list input is
            treated as a single tool, and each MCPTool is replaced by its ``functions``.
        """
        final_tools: list[ToolProtocol | dict[str, Any] | Callable[..., Any]] = []
        if not tools:
            return final_tools
        for tool in tools if isinstance(tools, list) else [tools]:  # type: ignore[reportUnknownType]
            if isinstance(tool, MCPTool):
                # An MCP tool contributes the functions it exposes rather than itself.
                final_tools.extend(tool.functions)  # type: ignore
                continue
            final_tools.append(tool)  # type: ignore
        return final_tools

    def _resolve_chat_options(
        self,
        kwargs: dict[str, Any],
        *,
        tools: ToolProtocol
        | Callable[..., Any]
        | MutableMapping[str, Any]
        | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
        | None,
        **option_fields: Any,
    ) -> ChatOptions:
        """Return the ChatOptions to use for a request.

        Shared by ``get_response`` and ``get_streaming_response`` so that both resolve
        their options in exactly the same way.

        Args:
            kwargs: The extra keyword arguments received by the public method. If it
                contains a ``chat_options`` entry, that entry is removed from the
                dictionary and returned; ``tools`` and ``option_fields`` are then ignored.
            tools: The tools passed to the public method; normalized with
                ``_normalize_tools`` when a new ChatOptions is built.
            option_fields: The remaining ChatOptions fields, passed through unchanged.

        Returns:
            The caller-supplied ChatOptions, or a new one built from the given fields.

        Raises:
            TypeError: If ``kwargs["chat_options"]`` is not a ChatOptions instance.
        """
        # Should we merge chat options instead of ignoring the input params?
        if "chat_options" in kwargs:
            chat_options = kwargs.pop("chat_options")
            if not isinstance(chat_options, ChatOptions):
                raise TypeError("chat_options must be an instance of ChatOptions")
            return chat_options
        return ChatOptions(
            tools=self._normalize_tools(tools),  # type: ignore
            **option_fields,
        )

    # region Internal methods to be implemented by the derived classes

    @abstractmethod
    async def _inner_get_response(
        self,
        *,
        messages: MutableSequence[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> ChatResponse:
        """Send a chat request to the AI service.

        Args:
            messages: The chat messages to send.
            chat_options: The options for the request.
            kwargs: Any additional keyword arguments.

        Returns:
            The chat response contents representing the response(s).
        """

    @abstractmethod
    async def _inner_get_streaming_response(
        self,
        *,
        messages: MutableSequence[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> AsyncIterable[ChatResponseUpdate]:
        """Send a streaming chat request to the AI service.

        Args:
            messages: The chat messages to send.
            chat_options: The chat_options for the request.
            kwargs: Any additional keyword arguments.

        Yields:
            ChatResponseUpdate: The streaming chat message contents.
        """
        # Below is needed for mypy: https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators
        if False:
            yield
        await asyncio.sleep(0)  # pragma: no cover
        # This is a no-op, but it allows the method to be async and return an AsyncIterable.
        # The actual implementation should yield ChatResponseUpdate instances as needed.

    # endregion

    # region Public method

    async def get_response(
        self,
        messages: str | ChatMessage | list[str] | list[ChatMessage],
        *,
        frequency_penalty: float | None = None,
        logit_bias: dict[str | int, float] | None = None,
        max_tokens: int | None = None,
        metadata: dict[str, Any] | None = None,
        model: str | None = None,
        presence_penalty: float | None = None,
        response_format: type[BaseModel] | None = None,
        seed: int | None = None,
        stop: str | Sequence[str] | None = None,
        store: bool | None = None,
        temperature: float | None = None,
        tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
        tools: ToolProtocol
        | Callable[..., Any]
        | MutableMapping[str, Any]
        | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
        | None = None,
        top_p: float | None = None,
        user: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> ChatResponse:
        """Get a response from a chat client.

        If ``chat_options`` is supplied through ``kwargs`` it is used as-is and the
        individual option parameters below are ignored.

        Args:
            messages: The message or messages to send to the model.
            frequency_penalty: The frequency penalty to use.
            logit_bias: The logit bias to use.
            max_tokens: The maximum number of tokens to generate.
            metadata: Additional metadata to include in the request.
            model: The model to use for the agent.
            presence_penalty: The presence penalty to use.
            response_format: The format of the response.
            seed: The random seed to use.
            stop: The stop sequence(s) for the request.
            store: Whether to store the response.
            temperature: The sampling temperature to use.
            tool_choice: The tool choice for the request.
            tools: The tools to use for the request.
            top_p: The nucleus sampling probability to use.
            user: The user to associate with the request.
            additional_properties: Additional properties to include in the request.
            kwargs: Any additional keyword arguments,
                will only be passed to functions that are called.

        Returns:
            A chat response from the model.

        Raises:
            TypeError: If ``chat_options`` is passed in ``kwargs`` and is not a ChatOptions instance.
        """
        chat_options = self._resolve_chat_options(
            kwargs,
            ai_model_id=model,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            max_tokens=max_tokens,
            metadata=metadata,
            presence_penalty=presence_penalty,
            response_format=response_format,
            seed=seed,
            stop=stop,
            store=store,
            temperature=temperature,
            top_p=top_p,
            tool_choice=tool_choice,
            tools=tools,
            user=user,
            additional_properties=additional_properties or {},
        )
        prepped_messages = self.prepare_messages(messages, chat_options)
        self._prepare_tool_choice(chat_options=chat_options)

        # Keys starting with an underscore are framework-internal and are not forwarded.
        filtered_kwargs = self._filter_internal_kwargs(kwargs)
        return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **filtered_kwargs)

    async def get_streaming_response(
        self,
        messages: str | ChatMessage | list[str] | list[ChatMessage],
        *,
        frequency_penalty: float | None = None,
        logit_bias: dict[str | int, float] | None = None,
        max_tokens: int | None = None,
        metadata: dict[str, Any] | None = None,
        model: str | None = None,
        presence_penalty: float | None = None,
        response_format: type[BaseModel] | None = None,
        seed: int | None = None,
        stop: str | Sequence[str] | None = None,
        store: bool | None = None,
        temperature: float | None = None,
        tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
        tools: ToolProtocol
        | Callable[..., Any]
        | MutableMapping[str, Any]
        | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
        | None = None,
        top_p: float | None = None,
        user: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> AsyncIterable[ChatResponseUpdate]:
        """Get a streaming response from a chat client.

        If ``chat_options`` is supplied through ``kwargs`` it is used as-is and the
        individual option parameters below are ignored.

        Args:
            messages: The message or messages to send to the model.
            frequency_penalty: The frequency penalty to use.
            logit_bias: The logit bias to use.
            max_tokens: The maximum number of tokens to generate.
            metadata: Additional metadata to include in the request.
            model: The model to use for the agent.
            presence_penalty: The presence penalty to use.
            response_format: The format of the response.
            seed: The random seed to use.
            stop: The stop sequence(s) for the request.
            store: Whether to store the response.
            temperature: The sampling temperature to use.
            tool_choice: The tool choice for the request.
            tools: The tools to use for the request.
            top_p: The nucleus sampling probability to use.
            user: The user to associate with the request.
            additional_properties: Additional properties to include in the request.
            kwargs: Any additional keyword arguments.

        Yields:
            A stream representing the response(s) from the LLM.

        Raises:
            TypeError: If ``chat_options`` is passed in ``kwargs`` and is not a ChatOptions instance.
        """
        chat_options = self._resolve_chat_options(
            kwargs,
            ai_model_id=model,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            max_tokens=max_tokens,
            metadata=metadata,
            presence_penalty=presence_penalty,
            response_format=response_format,
            seed=seed,
            stop=stop,
            store=store,
            temperature=temperature,
            top_p=top_p,
            tool_choice=tool_choice,
            tools=tools,
            user=user,
            additional_properties=additional_properties or {},
        )
        prepped_messages = self.prepare_messages(messages, chat_options)
        self._prepare_tool_choice(chat_options=chat_options)

        # Keys starting with an underscore are framework-internal and are not forwarded.
        filtered_kwargs = self._filter_internal_kwargs(kwargs)
        async for update in self._inner_get_streaming_response(
            messages=prepped_messages, chat_options=chat_options, **filtered_kwargs
        ):
            yield update

    def _prepare_tool_choice(self, chat_options: ChatOptions) -> None:
        """Prepare the tools and tool choice for the chat options.

        The options are updated in place:

        - when the tool choice is None or ``ChatToolMode.NONE``, the tools are cleared and
          the tool choice is set to ``ChatToolMode.NONE.mode``;
        - when there are no tools, the tool choice is set to ``ChatToolMode.NONE.mode``;
        - otherwise the tool choice is set to the ``mode`` of the selected tool mode.

        This function should be overridden by subclasses to customize tool handling,
        because it currently parses only AIFunctions.
        """
        chat_tool_mode: ChatToolMode | None = chat_options.tool_choice  # type: ignore
        if chat_tool_mode is None or chat_tool_mode == ChatToolMode.NONE:
            chat_options.tools = None
            chat_options.tool_choice = ChatToolMode.NONE.mode
            return
        if not chat_options.tools:
            chat_options.tool_choice = ChatToolMode.NONE.mode
        else:
            chat_options.tool_choice = chat_tool_mode.mode

    def service_url(self) -> str:
        """Get the URL of the service.

        Override this in the subclass to return the proper URL.
        The base implementation returns the placeholder string "Unknown".
        """
        return "Unknown"

    def create_agent(
        self,
        *,
        name: str | None = None,
        instructions: str | None = None,
        tools: ToolProtocol
        | Callable[..., Any]
        | MutableMapping[str, Any]
        | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
        | None = None,
        chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None,
        context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None,
        middleware: Middleware | list[Middleware] | None = None,
        **kwargs: Any,
    ) -> "ChatAgent":
        """Create an agent with the given name and instructions.

        Args:
            name: The name of the agent.
            instructions: The instructions for the agent.
            tools: Optional list of tools to associate with the agent.
            chat_message_store_factory: Factory function to create an instance of ChatMessageStoreProtocol.
                If not provided, the default in-memory store will be used.
            context_providers: Context providers to include during agent invocation.
            middleware: List of middleware to intercept agent and function invocations.
            **kwargs: Additional keyword arguments to pass to the agent.
                See ChatAgent for all the available options.

        Returns:
            An instance of ChatAgent that uses this client as its chat client.
        """
        # Imported at call time; at module level ChatAgent is only imported under
        # TYPE_CHECKING (presumably to avoid a circular import - confirm).
        from ._agents import ChatAgent

        return ChatAgent(
            chat_client=self,
            name=name,
            instructions=instructions,
            tools=tools,
            chat_message_store_factory=chat_message_store_factory,
            context_providers=context_providers,
            middleware=middleware,
            **kwargs,
        )
|
|
|
|
|
|
# region Embedding Client
|
|
|
|
|
|
@runtime_checkable
class EmbeddingGenerator(Protocol, Generic[TInput, TEmbedding]):
    """A protocol for an embedding generator that can create embeddings from input data.

    The protocol is generic over the input item type (``TInput``) and the embedding
    type (``TEmbedding``) carried by the returned ``GeneratedEmbeddings``.
    """

    async def generate(
        self,
        input_data: Sequence[TInput],
        **kwargs: Any,
    ) -> GeneratedEmbeddings[TEmbedding]:
        """Generate embeddings for the given input data.

        Args:
            input_data: The sequence of input items to generate embeddings for.
            **kwargs: Additional options for the request.

        Returns:
            The generated embeddings; this acts like a list, but has additional metadata and usage details.
        """
        ...
|