# Copyright (c) Microsoft. All rights reserved.

import asyncio
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable, Callable, MutableMapping, MutableSequence, Sequence
from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeVar, runtime_checkable

from pydantic import BaseModel, Field

from ._logging import get_logger
from ._mcp import MCPTool
from ._memory import AggregateContextProvider, ContextProvider
from ._middleware import (
    ChatMiddleware,
    ChatMiddlewareCallable,
    FunctionMiddleware,
    FunctionMiddlewareCallable,
    Middleware,
)
from ._pydantic import AFBaseModel
from ._threads import ChatMessageStore
from ._tools import ToolProtocol
from ._types import (
    ChatMessage,
    ChatOptions,
    ChatResponse,
    ChatResponseUpdate,
    ChatToolMode,
    GeneratedEmbeddings,
)

if TYPE_CHECKING:
    from ._agents import ChatAgent

TInput = TypeVar("TInput", contravariant=True)
TEmbedding = TypeVar("TEmbedding")
TBaseChatClient = TypeVar("TBaseChatClient", bound="BaseChatClient")

logger = get_logger()

__all__ = [
    "BaseChatClient",
    "ChatClientProtocol",
    "EmbeddingGenerator",
]

# region ChatClientProtocol Protocol


@runtime_checkable
class ChatClientProtocol(Protocol):
    """A protocol for a chat client that can generate responses."""

    @property
    def additional_properties(self) -> dict[str, Any]:
        """Get additional properties associated with the client."""
        ...

    async def get_response(
        self,
        messages: str | ChatMessage | list[str] | list[ChatMessage],
        *,
        frequency_penalty: float | None = None,
        logit_bias: dict[str | int, float] | None = None,
        max_tokens: int | None = None,
        metadata: dict[str, Any] | None = None,
        model: str | None = None,
        presence_penalty: float | None = None,
        response_format: type[BaseModel] | None = None,
        seed: int | None = None,
        stop: str | Sequence[str] | None = None,
        store: bool | None = None,
        temperature: float | None = None,
        tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
        tools: ToolProtocol
        | Callable[..., Any]
        | MutableMapping[str, Any]
        | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
        | None = None,
        top_p: float | None = None,
        user: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> ChatResponse:
        """Sends input and returns the response.

        Args:
            messages: The message or sequence of input messages to send.
            frequency_penalty: The frequency penalty to use.
            logit_bias: The logit bias to use.
            max_tokens: The maximum number of tokens to generate.
            metadata: Additional metadata to include in the request.
            model: The model to use for the request.
            presence_penalty: The presence penalty to use.
            response_format: The format of the response.
            seed: The random seed to use.
            stop: The stop sequence(s) for the request.
            store: Whether to store the response.
            temperature: The sampling temperature to use.
            tool_choice: The tool choice for the request.
            tools: The tools to use for the request.
            top_p: The nucleus sampling probability to use.
            user: The user to associate with the request.
            additional_properties: Additional properties to include in the request.
            kwargs: Any additional keyword arguments, will only be passed to functions that are called.

        Returns:
            The response messages generated by the client.

        Raises:
            ValueError: If the input message sequence is `None`.
        """
        ...

    def get_streaming_response(
        self,
        messages: str | ChatMessage | list[str] | list[ChatMessage],
        *,
        frequency_penalty: float | None = None,
        logit_bias: dict[str | int, float] | None = None,
        max_tokens: int | None = None,
        metadata: dict[str, Any] | None = None,
        model: str | None = None,
        presence_penalty: float | None = None,
        response_format: type[BaseModel] | None = None,
        seed: int | None = None,
        stop: str | Sequence[str] | None = None,
        store: bool | None = None,
        temperature: float | None = None,
        tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
        tools: ToolProtocol
        | Callable[..., Any]
        | MutableMapping[str, Any]
        | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
        | None = None,
        top_p: float | None = None,
        user: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> AsyncIterable[ChatResponseUpdate]:
        """Sends input messages and streams the response.

        Args:
            messages: The message or sequence of input messages to send.
            frequency_penalty: The frequency penalty to use.
            logit_bias: The logit bias to use.
            max_tokens: The maximum number of tokens to generate.
            metadata: Additional metadata to include in the request.
            model: The model to use for the request.
            presence_penalty: The presence penalty to use.
            response_format: The format of the response.
            seed: The random seed to use.
            stop: The stop sequence(s) for the request.
            store: Whether to store the response.
            temperature: The sampling temperature to use.
            tool_choice: The tool choice for the request.
            tools: The tools to use for the request.
            top_p: The nucleus sampling probability to use.
            user: The user to associate with the request.
            additional_properties: Additional properties to include in the request.
            kwargs: Any additional keyword arguments, will only be passed to functions that are called.

        Yields:
            An async iterable of chat response updates containing the content of the response messages
            generated by the client.

        Raises:
            ValueError: If the input message sequence is `None`.
        """
        ...


# region ChatClientBase


def prepare_messages(messages: str | ChatMessage | list[str] | list[ChatMessage]) -> list[ChatMessage]:
    """Turn the allowed input into a list of chat messages.

    Plain strings become ``ChatMessage(role="user", ...)``; ``ChatMessage`` instances are passed
    through unchanged. A list may contain strings and messages in any combination.

    Args:
        messages: A single string or message, or a list of strings and/or messages.

    Returns:
        A new list of ``ChatMessage`` objects, in input order.

    Raises:
        ValueError: If ``messages`` is ``None`` (the contract documented on ``ChatClientProtocol``).
    """
    if messages is None:
        # Without this guard, iterating None below would surface as an unrelated TypeError.
        raise ValueError("messages must not be None")
    if isinstance(messages, str):
        return [ChatMessage(role="user", text=messages)]
    if isinstance(messages, ChatMessage):
        return [messages]
    return [ChatMessage(role="user", text=msg) if isinstance(msg, str) else msg for msg in messages]


class BaseChatClient(AFBaseModel, ABC):
    """Base class for chat clients.

    Subclasses implement ``_inner_get_response`` and ``_inner_get_streaming_response``. The public
    ``get_response`` / ``get_streaming_response`` methods turn their arguments into a list of
    ``ChatMessage`` and a ``ChatOptions`` instance, normalize the tool choice, and then delegate to
    those inner methods.
    """

    additional_properties: dict[str, Any] = Field(default_factory=dict)
    middleware: (
        ChatMiddleware
        | ChatMiddlewareCallable
        | FunctionMiddleware
        | FunctionMiddlewareCallable
        | list[ChatMiddleware | ChatMiddlewareCallable | FunctionMiddleware | FunctionMiddlewareCallable]
        | None
    ) = None
    OTEL_PROVIDER_NAME: str = "unknown"  # This is used for OTel setup, should be overridden in subclasses

    def prepare_messages(
        self, messages: str | ChatMessage | list[str] | list[ChatMessage]
    ) -> MutableSequence[ChatMessage]:
        """Turn the allowed input into a list of chat messages.

        Delegates to the module-level ``prepare_messages``; subclasses may override this hook.
        """
        return prepare_messages(messages)

    @staticmethod
    def _normalize_tools(
        tools: ToolProtocol
        | MutableMapping[str, Any]
        | Callable[..., Any]
        | list[ToolProtocol | MutableMapping[str, Any] | Callable[..., Any]]
        | None = None,
    ) -> list[ToolProtocol | dict[str, Any] | Callable[..., Any]]:
        """Normalize the tools input to a flat list of tools.

        A single tool is wrapped in a list; an ``MCPTool`` is replaced by the entries of its
        ``functions`` attribute. Falsy input yields an empty list.
        """
        final_tools: list[ToolProtocol | dict[str, Any] | Callable[..., Any]] = []
        if not tools:
            return final_tools
        for tool in tools if isinstance(tools, list) else [tools]:  # type: ignore[reportUnknownType]
            if isinstance(tool, MCPTool):
                # Expand the MCP tool into the individual functions it exposes.
                final_tools.extend(tool.functions)  # type: ignore
                continue
            final_tools.append(tool)  # type: ignore
        return final_tools

    # region Internal methods to be implemented by the derived classes

    @abstractmethod
    async def _inner_get_response(
        self,
        *,
        messages: MutableSequence[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> ChatResponse:
        """Send a chat request to the AI service.

        Args:
            messages: The chat messages to send.
            chat_options: The options for the request.
            kwargs: Any additional keyword arguments.

        Returns:
            The chat response contents representing the response(s).
        """

    @abstractmethod
    async def _inner_get_streaming_response(
        self,
        *,
        messages: MutableSequence[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> AsyncIterable[ChatResponseUpdate]:
        """Send a streaming chat request to the AI service.

        Args:
            messages: The chat messages to send.
            chat_options: The chat_options for the request.
            kwargs: Any additional keyword arguments.

        Yields:
            ChatResponseUpdate: The streaming chat message contents.
        """
        # Below is needed for mypy: https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators
        if False:
            yield
        await asyncio.sleep(0)  # pragma: no cover
        # This is a no-op, but it allows the method to be async and return an AsyncIterable.
        # The actual implementation should yield ChatResponseUpdate instances as needed.

    # endregion

    # region Public method

    async def get_response(
        self,
        messages: str | ChatMessage | list[str] | list[ChatMessage],
        *,
        frequency_penalty: float | None = None,
        logit_bias: dict[str | int, float] | None = None,
        max_tokens: int | None = None,
        metadata: dict[str, Any] | None = None,
        model: str | None = None,
        presence_penalty: float | None = None,
        response_format: type[BaseModel] | None = None,
        seed: int | None = None,
        stop: str | Sequence[str] | None = None,
        store: bool | None = None,
        temperature: float | None = None,
        tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
        tools: ToolProtocol
        | Callable[..., Any]
        | MutableMapping[str, Any]
        | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
        | None = None,
        top_p: float | None = None,
        user: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> ChatResponse:
        """Get a response from a chat client.

        If ``kwargs`` contains ``chat_options`` (a ``ChatOptions`` instance), it is used as-is and the
        individual option parameters below are ignored.

        Args:
            messages: The message or messages to send to the model.
            frequency_penalty: The frequency penalty to use.
            logit_bias: The logit bias to use.
            max_tokens: The maximum number of tokens to generate.
            metadata: Additional metadata to include in the request.
            model: The model to use for the request.
            presence_penalty: The presence penalty to use.
            response_format: The format of the response.
            seed: The random seed to use.
            stop: The stop sequence(s) for the request.
            store: Whether to store the response.
            temperature: The sampling temperature to use.
            tool_choice: The tool choice for the request.
            tools: The tools to use for the request.
            top_p: The nucleus sampling probability to use.
            user: The user to associate with the request.
            additional_properties: Additional properties to include in the request.
            kwargs: Any additional keyword arguments, will only be passed to functions that are called.

        Returns:
            A chat response from the model.

        Raises:
            TypeError: If ``kwargs["chat_options"]`` is not a ``ChatOptions`` instance.
            ValueError: If ``messages`` is ``None``.
        """
        chat_options = self._resolve_chat_options(
            kwargs,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            max_tokens=max_tokens,
            metadata=metadata,
            model=model,
            presence_penalty=presence_penalty,
            response_format=response_format,
            seed=seed,
            stop=stop,
            store=store,
            temperature=temperature,
            tool_choice=tool_choice,
            tools=tools,
            top_p=top_p,
            user=user,
            additional_properties=additional_properties,
        )
        prepped_messages = self.prepare_messages(messages)
        self._prepare_tool_choice(chat_options=chat_options)
        return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **kwargs)

    async def get_streaming_response(
        self,
        messages: str | ChatMessage | list[str] | list[ChatMessage],
        *,
        frequency_penalty: float | None = None,
        logit_bias: dict[str | int, float] | None = None,
        max_tokens: int | None = None,
        metadata: dict[str, Any] | None = None,
        model: str | None = None,
        presence_penalty: float | None = None,
        response_format: type[BaseModel] | None = None,
        seed: int | None = None,
        stop: str | Sequence[str] | None = None,
        store: bool | None = None,
        temperature: float | None = None,
        tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
        tools: ToolProtocol
        | Callable[..., Any]
        | MutableMapping[str, Any]
        | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
        | None = None,
        top_p: float | None = None,
        user: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> AsyncIterable[ChatResponseUpdate]:
        """Get a streaming response from a chat client.

        If ``kwargs`` contains ``chat_options`` (a ``ChatOptions`` instance), it is used as-is and the
        individual option parameters below are ignored. Because this is an async generator, argument
        validation happens when iteration starts, not when the method is called.

        Args:
            messages: The message or messages to send to the model.
            frequency_penalty: The frequency penalty to use.
            logit_bias: The logit bias to use.
            max_tokens: The maximum number of tokens to generate.
            metadata: Additional metadata to include in the request.
            model: The model to use for the request.
            presence_penalty: The presence penalty to use.
            response_format: The format of the response.
            seed: The random seed to use.
            stop: The stop sequence(s) for the request.
            store: Whether to store the response.
            temperature: The sampling temperature to use.
            tool_choice: The tool choice for the request.
            tools: The tools to use for the request.
            top_p: The nucleus sampling probability to use.
            user: The user to associate with the request.
            additional_properties: Additional properties to include in the request.
            kwargs: Any additional keyword arguments.

        Yields:
            A stream representing the response(s) from the LLM.

        Raises:
            TypeError: If ``kwargs["chat_options"]`` is not a ``ChatOptions`` instance.
            ValueError: If ``messages`` is ``None``.
        """
        chat_options = self._resolve_chat_options(
            kwargs,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            max_tokens=max_tokens,
            metadata=metadata,
            model=model,
            presence_penalty=presence_penalty,
            response_format=response_format,
            seed=seed,
            stop=stop,
            store=store,
            temperature=temperature,
            tool_choice=tool_choice,
            tools=tools,
            top_p=top_p,
            user=user,
            additional_properties=additional_properties,
        )
        prepped_messages = self.prepare_messages(messages)
        self._prepare_tool_choice(chat_options=chat_options)
        async for update in self._inner_get_streaming_response(
            messages=prepped_messages, chat_options=chat_options, **kwargs
        ):
            yield update

    def _resolve_chat_options(
        self,
        kwargs: dict[str, Any],
        *,
        frequency_penalty: float | None,
        logit_bias: dict[str | int, float] | None,
        max_tokens: int | None,
        metadata: dict[str, Any] | None,
        model: str | None,
        presence_penalty: float | None,
        response_format: type[BaseModel] | None,
        seed: int | None,
        stop: str | Sequence[str] | None,
        store: bool | None,
        temperature: float | None,
        tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None,
        tools: ToolProtocol
        | Callable[..., Any]
        | MutableMapping[str, Any]
        | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
        | None,
        top_p: float | None,
        user: str | None,
        additional_properties: dict[str, Any] | None,
    ) -> ChatOptions:
        """Return the ``ChatOptions`` for a ``get_response`` / ``get_streaming_response`` call.

        If ``kwargs`` holds a ``chat_options`` entry, it is popped from ``kwargs`` (mutating the
        caller's dict so it is not forwarded again to the inner method) and returned unchanged.
        Otherwise a new ``ChatOptions`` is built from the individual option parameters.

        Args:
            kwargs: The caller's ``**kwargs`` dict; may be mutated.
            (remaining parameters): The individual request options, as documented on ``get_response``.

        Returns:
            The ``ChatOptions`` to use for the request.

        Raises:
            TypeError: If ``kwargs["chat_options"]`` is not a ``ChatOptions`` instance.
        """
        # TODO: Should we merge chat options instead of ignoring the input params?
        if "chat_options" in kwargs:
            chat_options = kwargs.pop("chat_options")
            if not isinstance(chat_options, ChatOptions):
                raise TypeError("chat_options must be an instance of ChatOptions")
            return chat_options
        return ChatOptions(
            ai_model_id=model,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            max_tokens=max_tokens,
            metadata=metadata,
            presence_penalty=presence_penalty,
            response_format=response_format,
            seed=seed,
            stop=stop,
            store=store,
            temperature=temperature,
            top_p=top_p,
            tool_choice=tool_choice,
            tools=self._normalize_tools(tools),  # type: ignore
            user=user,
            additional_properties=additional_properties or {},
        )

    def _prepare_tool_choice(self, chat_options: ChatOptions) -> None:
        """Normalize ``chat_options.tools`` and ``chat_options.tool_choice`` in place.

        - If ``tool_choice`` is ``None`` or ``ChatToolMode.NONE``: ``tools`` is cleared and
          ``tool_choice`` becomes ``ChatToolMode.NONE.mode``.
        - Else, if there are no tools: ``tool_choice`` becomes ``ChatToolMode.NONE.mode``.
        - Else: ``tool_choice`` becomes the ``.mode`` of the configured ``ChatToolMode``.

        Subclasses should override this to customize tool handling; per the original note, the
        default implementation is only intended for AIFunction tools.

        NOTE(review): this mutates ``chat_options`` in place, including a ``ChatOptions`` object the
        caller passed via the ``chat_options`` keyword argument.
        """
        # NOTE(review): assumes ChatOptions has already turned str/dict tool_choice values into a
        # ChatToolMode (``.mode`` is accessed below) — TODO confirm against ChatOptions validation.
        chat_tool_mode: ChatToolMode | None = chat_options.tool_choice  # type: ignore
        if chat_tool_mode is None or chat_tool_mode == ChatToolMode.NONE:
            chat_options.tools = None
            chat_options.tool_choice = ChatToolMode.NONE.mode
            return
        if not chat_options.tools:
            # A tool choice without any tools to choose from is meaningless; disable it.
            chat_options.tool_choice = ChatToolMode.NONE.mode
        else:
            chat_options.tool_choice = chat_tool_mode.mode

    def service_url(self) -> str:
        """Get the URL of the service.

        Override this in the subclass to return the proper URL. The base implementation returns the
        placeholder string ``"Unknown"`` (for services without a URL, for example).
        """
        return "Unknown"

    def create_agent(
        self,
        *,
        name: str | None = None,
        instructions: str | None = None,
        tools: ToolProtocol
        | Callable[..., Any]
        | MutableMapping[str, Any]
        | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
        | None = None,
        chat_message_store_factory: Callable[[], ChatMessageStore] | None = None,
        context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None,
        middleware: Middleware | list[Middleware] | None = None,
        **kwargs: Any,
    ) -> "ChatAgent":
        """Create an agent with the given name and instructions.

        Args:
            name: The name of the agent.
            instructions: The instructions for the agent.
            tools: Optional list of tools to associate with the agent.
            chat_message_store_factory: Factory function to create an instance of ChatMessageStore. If not provided,
                the default in-memory store will be used.
            context_providers: Context providers to include during agent invocation.
            middleware: List of middleware to intercept agent and function invocations.
            **kwargs: Additional keyword arguments to pass to the agent.
                See ChatAgent for all the available options.

        Returns:
            An instance of ChatAgent.
        """
        # Imported lazily: ChatAgent is only imported under TYPE_CHECKING at module level.
        from ._agents import ChatAgent

        return ChatAgent(
            chat_client=self,
            name=name,
            instructions=instructions,
            tools=tools,
            chat_message_store_factory=chat_message_store_factory,
            context_providers=context_providers,
            middleware=middleware,
            **kwargs,
        )


# region Embedding Client


@runtime_checkable
class EmbeddingGenerator(Protocol, Generic[TInput, TEmbedding]):
    """A protocol for an embedding generator that can create embeddings from input data."""

    async def generate(
        self,
        input_data: Sequence[TInput],
        **kwargs: Any,
    ) -> GeneratedEmbeddings[TEmbedding]:
        """Generates an embedding for the given input data.

        Args:
            input_data: The input data to generate an embedding for.
            **kwargs: Additional options for the request.

        Returns:
            The generated embedding, this acts like a list, but has additional metadata and usage details.
        """
        ...