Files
agent-framework/python/packages/main/agent_framework/_clients.py
T
Dmytro Struk eec7f192eb Python: Added chat middleware and more examples (#883)
* Added example with stateful middleware

* Added chat middleware

* Updated middleware example with override scenario

* Small revert

* Small fixes

* Added kwargs to context objects

* Added README

* Added function middleware to chat client

* Small refactoring

* Reverted example files

* Made MiddlewareWrapper generic

* Added Middleware exception

* Small refactoring

* Small fix
2025-09-26 15:10:56 +00:00

540 lines
20 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
import asyncio
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable, Callable, MutableMapping, MutableSequence, Sequence
from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeVar, runtime_checkable
from pydantic import BaseModel, Field
from ._logging import get_logger
from ._mcp import MCPTool
from ._memory import AggregateContextProvider, ContextProvider
from ._middleware import (
ChatMiddleware,
ChatMiddlewareCallable,
FunctionMiddleware,
FunctionMiddlewareCallable,
Middleware,
)
from ._pydantic import AFBaseModel
from ._threads import ChatMessageStore
from ._tools import ToolProtocol
from ._types import (
ChatMessage,
ChatOptions,
ChatResponse,
ChatResponseUpdate,
ChatToolMode,
GeneratedEmbeddings,
)
if TYPE_CHECKING:
from ._agents import ChatAgent
# Input element type for EmbeddingGenerator; declared contravariant.
TInput = TypeVar("TInput", contravariant=True)
# Embedding element type carried by the GeneratedEmbeddings result.
TEmbedding = TypeVar("TEmbedding")
# Type variable bound to BaseChatClient (string forward reference: the class is defined below).
TBaseChatClient = TypeVar("TBaseChatClient", bound="BaseChatClient")
# Module-level logger obtained from the package logging helper.
logger = get_logger()
# Names exported by this module.
__all__ = [
    "BaseChatClient",
    "ChatClientProtocol",
    "EmbeddingGenerator",
]
# region ChatClientProtocol Protocol
@runtime_checkable
class ChatClientProtocol(Protocol):
    """A protocol for a chat client that can generate responses.

    The protocol is structural: it declares an ``additional_properties``
    property, an awaitable ``get_response`` and a ``get_streaming_response``
    that returns an async iterable of updates. Because the class is decorated
    with ``runtime_checkable`` it can be used with ``isinstance``.
    """

    @property
    def additional_properties(self) -> dict[str, Any]:
        """Get additional properties associated with the client."""
        ...

    async def get_response(
        self,
        messages: str | ChatMessage | list[str] | list[ChatMessage],
        *,
        frequency_penalty: float | None = None,
        logit_bias: dict[str | int, float] | None = None,
        max_tokens: int | None = None,
        metadata: dict[str, Any] | None = None,
        model: str | None = None,
        presence_penalty: float | None = None,
        response_format: type[BaseModel] | None = None,
        seed: int | None = None,
        stop: str | Sequence[str] | None = None,
        store: bool | None = None,
        temperature: float | None = None,
        tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
        tools: ToolProtocol
        | Callable[..., Any]
        | MutableMapping[str, Any]
        | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
        | None = None,
        top_p: float | None = None,
        user: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> ChatResponse:
        """Sends input and returns the response.

        Args:
            messages: The sequence of input messages to send.
            frequency_penalty: the frequency penalty to use.
            logit_bias: the logit bias to use.
            max_tokens: The maximum number of tokens to generate.
            metadata: additional metadata to include in the request.
            model: The model to use for the agent.
            presence_penalty: the presence penalty to use.
            response_format: the format of the response.
            seed: the random seed to use.
            stop: the stop sequence(s) for the request.
            store: whether to store the response.
            temperature: the sampling temperature to use.
            tool_choice: the tool choice for the request.
            tools: the tools to use for the request.
            top_p: the nucleus sampling probability to use.
            user: the user to associate with the request.
            additional_properties: additional properties to include in the request
            kwargs: any additional keyword arguments,
                will only be passed to functions that are called.

        Returns:
            The response messages generated by the client.

        Raises:
            ValueError: If the input message sequence is `None`.
        """
        ...

    # Declared as a plain ``def`` (not ``async def``): the call itself returns
    # the AsyncIterable that the caller then iterates with ``async for``.
    def get_streaming_response(
        self,
        messages: str | ChatMessage | list[str] | list[ChatMessage],
        *,
        frequency_penalty: float | None = None,
        logit_bias: dict[str | int, float] | None = None,
        max_tokens: int | None = None,
        metadata: dict[str, Any] | None = None,
        model: str | None = None,
        presence_penalty: float | None = None,
        response_format: type[BaseModel] | None = None,
        seed: int | None = None,
        stop: str | Sequence[str] | None = None,
        store: bool | None = None,
        temperature: float | None = None,
        tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
        tools: ToolProtocol
        | Callable[..., Any]
        | MutableMapping[str, Any]
        | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
        | None = None,
        top_p: float | None = None,
        user: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> AsyncIterable[ChatResponseUpdate]:
        """Sends input messages and streams the response.

        Args:
            messages: The sequence of input messages to send.
            frequency_penalty: the frequency penalty to use.
            logit_bias: the logit bias to use.
            max_tokens: The maximum number of tokens to generate.
            metadata: additional metadata to include in the request.
            model: The model to use for the agent.
            presence_penalty: the presence penalty to use.
            response_format: the format of the response.
            seed: the random seed to use.
            stop: the stop sequence(s) for the request.
            store: whether to store the response.
            temperature: the sampling temperature to use.
            tool_choice: the tool choice for the request.
            tools: the tools to use for the request.
            top_p: the nucleus sampling probability to use.
            user: the user to associate with the request.
            additional_properties: additional properties to include in the request
            kwargs: any additional keyword arguments,
                will only be passed to functions that are called.

        Yields:
            An async iterable of chat response updates containing the content of the response messages
            generated by the client.

        Raises:
            ValueError: If the input message sequence is `None`.
        """
        ...
# region ChatClientBase
def prepare_messages(messages: str | ChatMessage | list[str] | list[ChatMessage]) -> list[ChatMessage]:
    """Normalize the accepted message inputs into a new list of ChatMessage objects.

    A single string or a single ChatMessage is treated as a one-element
    sequence. Every string item is wrapped in a ChatMessage with role
    "user"; every other item is passed through unchanged.
    """
    # Treat a lone value as a one-element sequence so a single pass handles every input shape.
    items = [messages] if isinstance(messages, (str, ChatMessage)) else messages
    return [ChatMessage(role="user", text=item) if isinstance(item, str) else item for item in items]
class BaseChatClient(AFBaseModel, ABC):
    """Base class for chat clients.

    Subclasses implement ``_inner_get_response`` and
    ``_inner_get_streaming_response``. The public ``get_response`` and
    ``get_streaming_response`` methods turn the input into a list of chat
    messages, build (or accept) the ``ChatOptions``, settle the tool choice
    and then delegate to those inner methods.
    """

    additional_properties: dict[str, Any] = Field(default_factory=dict)
    middleware: (
        ChatMiddleware
        | ChatMiddlewareCallable
        | FunctionMiddleware
        | FunctionMiddlewareCallable
        | list[ChatMiddleware | ChatMiddlewareCallable | FunctionMiddleware | FunctionMiddlewareCallable]
        | None
    ) = None
    OTEL_PROVIDER_NAME: str = "unknown"
    # This is used for OTel setup, should be overridden in subclasses

    def prepare_messages(
        self, messages: str | ChatMessage | list[str] | list[ChatMessage]
    ) -> MutableSequence[ChatMessage]:
        """Turn the allowed input into a list of chat messages.

        Delegates to the module-level ``prepare_messages`` function.
        """
        return prepare_messages(messages)

    @staticmethod
    def _normalize_tools(
        tools: ToolProtocol
        | MutableMapping[str, Any]
        | Callable[..., Any]
        | list[ToolProtocol | MutableMapping[str, Any] | Callable[..., Any]]
        | None = None,
    ) -> list[ToolProtocol | dict[str, Any] | Callable[..., Any]]:
        """Normalize the tools input to a list of tools.

        Args:
            tools: A single tool, a list of tools, or None.

        Returns:
            A new flat list. A falsy input yields an empty list; each MCPTool
            is replaced by the entries of its ``functions`` attribute; every
            other tool is kept as is.
        """
        final_tools: list[ToolProtocol | dict[str, Any] | Callable[..., Any]] = []
        if not tools:
            return final_tools
        for tool in tools if isinstance(tools, list) else [tools]:  # type: ignore[reportUnknownType]
            if isinstance(tool, MCPTool):
                # An MCPTool stands for several functions; expose those instead of the wrapper.
                final_tools.extend(tool.functions)  # type: ignore
                continue
            final_tools.append(tool)  # type: ignore
        return final_tools

    def _build_chat_options(
        self,
        call_kwargs: dict[str, Any],
        *,
        tools: ToolProtocol
        | Callable[..., Any]
        | MutableMapping[str, Any]
        | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
        | None,
        **option_fields: Any,
    ) -> ChatOptions:
        """Return the ChatOptions for one request.

        Shared by ``get_response`` and ``get_streaming_response`` so both
        resolve their options in exactly the same way.

        Args:
            call_kwargs: The caller's ``**kwargs`` dict. When it contains a
                ``chat_options`` entry, that entry is popped (so it is not
                forwarded a second time) and returned; the other arguments
                are then ignored.
            tools: The tools passed by the caller; normalized with
                ``_normalize_tools`` when a new ChatOptions is built.
            **option_fields: The remaining ChatOptions fields, passed through
                by keyword.

        Returns:
            The supplied or the newly built ChatOptions.

        Raises:
            TypeError: If ``chat_options`` was supplied but is not a ChatOptions.
        """
        # Should we merge chat options instead of ignoring the input params?
        if "chat_options" in call_kwargs:
            chat_options = call_kwargs.pop("chat_options")
            if not isinstance(chat_options, ChatOptions):
                raise TypeError("chat_options must be an instance of ChatOptions")
            return chat_options
        return ChatOptions(
            tools=self._normalize_tools(tools),  # type: ignore
            **option_fields,
        )

    # region Internal methods to be implemented by the derived classes
    @abstractmethod
    async def _inner_get_response(
        self,
        *,
        messages: MutableSequence[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> ChatResponse:
        """Send a chat request to the AI service.

        Args:
            messages: The chat messages to send.
            chat_options: The options for the request.
            kwargs: Any additional keyword arguments.

        Returns:
            The chat response contents representing the response(s).
        """

    @abstractmethod
    async def _inner_get_streaming_response(
        self,
        *,
        messages: MutableSequence[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> AsyncIterable[ChatResponseUpdate]:
        """Send a streaming chat request to the AI service.

        Args:
            messages: The chat messages to send.
            chat_options: The chat_options for the request.
            kwargs: Any additional keyword arguments.

        Yields:
            ChatResponseUpdate: The streaming chat message contents.
        """
        # Below is needed for mypy: https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators
        if False:
            yield
        await asyncio.sleep(0)  # pragma: no cover
        # This is a no-op, but it allows the method to be async and return an AsyncIterable.
        # The actual implementation should yield ChatResponseUpdate instances as needed.

    # endregion
    # region Public method
    async def get_response(
        self,
        messages: str | ChatMessage | list[str] | list[ChatMessage],
        *,
        frequency_penalty: float | None = None,
        logit_bias: dict[str | int, float] | None = None,
        max_tokens: int | None = None,
        metadata: dict[str, Any] | None = None,
        model: str | None = None,
        presence_penalty: float | None = None,
        response_format: type[BaseModel] | None = None,
        seed: int | None = None,
        stop: str | Sequence[str] | None = None,
        store: bool | None = None,
        temperature: float | None = None,
        tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
        tools: ToolProtocol
        | Callable[..., Any]
        | MutableMapping[str, Any]
        | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
        | None = None,
        top_p: float | None = None,
        user: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> ChatResponse:
        """Get a response from a chat client.

        If ``kwargs`` contains a ``chat_options`` entry it is used as the
        request options and the individual option arguments are ignored.

        Args:
            messages: the message or messages to send to the model
            frequency_penalty: the frequency penalty to use.
            logit_bias: the logit bias to use.
            max_tokens: The maximum number of tokens to generate.
            metadata: additional metadata to include in the request.
            model: The model to use for the agent.
            presence_penalty: the presence penalty to use.
            response_format: the format of the response.
            seed: the random seed to use.
            stop: the stop sequence(s) for the request.
            store: whether to store the response.
            temperature: the sampling temperature to use.
            tool_choice: the tool choice for the request.
            tools: the tools to use for the request.
            top_p: the nucleus sampling probability to use.
            user: the user to associate with the request.
            additional_properties: additional properties to include in the request.
            kwargs: any additional keyword arguments,
                will only be passed to functions that are called.

        Returns:
            A chat response from the model.

        Raises:
            TypeError: If ``chat_options`` is passed in ``kwargs`` but is not a ChatOptions.
        """
        chat_options = self._build_chat_options(
            kwargs,
            tools=tools,
            ai_model_id=model,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            max_tokens=max_tokens,
            metadata=metadata,
            presence_penalty=presence_penalty,
            response_format=response_format,
            seed=seed,
            stop=stop,
            store=store,
            temperature=temperature,
            top_p=top_p,
            tool_choice=tool_choice,
            user=user,
            additional_properties=additional_properties or {},
        )
        prepped_messages = self.prepare_messages(messages)
        self._prepare_tool_choice(chat_options=chat_options)
        return await self._inner_get_response(messages=prepped_messages, chat_options=chat_options, **kwargs)

    async def get_streaming_response(
        self,
        messages: str | ChatMessage | list[str] | list[ChatMessage],
        *,
        frequency_penalty: float | None = None,
        logit_bias: dict[str | int, float] | None = None,
        max_tokens: int | None = None,
        metadata: dict[str, Any] | None = None,
        model: str | None = None,
        presence_penalty: float | None = None,
        response_format: type[BaseModel] | None = None,
        seed: int | None = None,
        stop: str | Sequence[str] | None = None,
        store: bool | None = None,
        temperature: float | None = None,
        tool_choice: ChatToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
        tools: ToolProtocol
        | Callable[..., Any]
        | MutableMapping[str, Any]
        | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
        | None = None,
        top_p: float | None = None,
        user: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> AsyncIterable[ChatResponseUpdate]:
        """Get a streaming response from a chat client.

        If ``kwargs`` contains a ``chat_options`` entry it is used as the
        request options and the individual option arguments are ignored.

        Args:
            messages: the message or messages to send to the model
            frequency_penalty: the frequency penalty to use
            logit_bias: the logit bias to use
            max_tokens: The maximum number of tokens to generate.
            metadata: additional metadata to include in the request.
            model: The model to use for the agent.
            presence_penalty: the presence penalty to use.
            response_format: the format of the response.
            seed: the random seed to use.
            stop: the stop sequence(s) for the request.
            store: whether to store the response.
            temperature: the sampling temperature to use.
            tool_choice: the tool choice for the request.
            tools: the tools to use for the request.
            top_p: the nucleus sampling probability to use.
            user: the user to associate with the request.
            additional_properties: additional properties to include in the request
            kwargs: any additional keyword arguments

        Yields:
            A stream representing the response(s) from the LLM.

        Raises:
            TypeError: If ``chat_options`` is passed in ``kwargs`` but is not a ChatOptions.
        """
        chat_options = self._build_chat_options(
            kwargs,
            tools=tools,
            ai_model_id=model,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            max_tokens=max_tokens,
            metadata=metadata,
            presence_penalty=presence_penalty,
            response_format=response_format,
            seed=seed,
            stop=stop,
            store=store,
            temperature=temperature,
            top_p=top_p,
            tool_choice=tool_choice,
            user=user,
            additional_properties=additional_properties or {},
        )
        prepped_messages = self.prepare_messages(messages)
        self._prepare_tool_choice(chat_options=chat_options)
        async for update in self._inner_get_streaming_response(
            messages=prepped_messages, chat_options=chat_options, **kwargs
        ):
            yield update

    def _prepare_tool_choice(self, chat_options: ChatOptions) -> None:
        """Prepare the tools and tool choice for the chat options.

        Mutates ``chat_options`` in place:

        - tool choice is None or equal to ``ChatToolMode.NONE``: the tools are
          cleared and the tool choice is set to the NONE mode value;
        - otherwise, with no tools present: the tool choice is set to the
          NONE mode value;
        - otherwise: the tool choice is replaced by its ``mode`` value.

        This function should be overridden by subclasses to customize tool handling.
        Because it currently parses only AIFunctions.
        """
        # NOTE(review): `.mode` is read below, so this assumes ChatOptions has already turned
        # a string/dict tool_choice into a ChatToolMode - confirm against ChatOptions.
        chat_tool_mode: ChatToolMode | None = chat_options.tool_choice  # type: ignore
        if chat_tool_mode is None or chat_tool_mode == ChatToolMode.NONE:
            chat_options.tools = None
            chat_options.tool_choice = ChatToolMode.NONE.mode
            return
        if not chat_options.tools:
            # A tool choice without any tools is meaningless; fall back to "none".
            chat_options.tool_choice = ChatToolMode.NONE.mode
        else:
            chat_options.tool_choice = chat_tool_mode.mode

    def service_url(self) -> str:
        """Get the URL of the service.

        Override this in the subclass to return the proper URL.
        This base implementation returns the placeholder string "Unknown".
        """
        return "Unknown"

    def create_agent(
        self,
        *,
        name: str | None = None,
        instructions: str | None = None,
        tools: ToolProtocol
        | Callable[..., Any]
        | MutableMapping[str, Any]
        | list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
        | None = None,
        chat_message_store_factory: Callable[[], ChatMessageStore] | None = None,
        context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None,
        middleware: Middleware | list[Middleware] | None = None,
        **kwargs: Any,
    ) -> "ChatAgent":
        """Create an agent with the given name and instructions.

        Args:
            name: The name of the agent.
            instructions: The instructions for the agent.
            tools: Optional list of tools to associate with the agent.
            chat_message_store_factory: Factory function to create an instance of ChatMessageStore. If not provided,
                the default in-memory store will be used.
            context_providers: Context providers to include during agent invocation.
            middleware: List of middleware to intercept agent and function invocations.
            **kwargs: Additional keyword arguments to pass to the agent.
                See ChatAgent for all the available options.

        Returns:
            An instance of ChatAgent.
        """
        # Imported here rather than at module level, where ChatAgent is only
        # available under TYPE_CHECKING (presumably to avoid a circular import).
        from ._agents import ChatAgent

        return ChatAgent(
            chat_client=self,
            name=name,
            instructions=instructions,
            tools=tools,
            chat_message_store_factory=chat_message_store_factory,
            context_providers=context_providers,
            middleware=middleware,
            **kwargs,
        )
# region Embedding Client
@runtime_checkable
class EmbeddingGenerator(Protocol, Generic[TInput, TEmbedding]):
    """A protocol for an embedding generator that can create embeddings from input data.

    Generic over the input element type (``TInput``) and the embedding type
    (``TEmbedding``). Decorated with ``runtime_checkable`` so it can be used
    with ``isinstance``.
    """

    async def generate(
        self,
        input_data: Sequence[TInput],
        **kwargs: Any,
    ) -> GeneratedEmbeddings[TEmbedding]:
        """Generates an embedding for the given input data.

        Args:
            input_data: The input data to generate an embedding for.
            **kwargs: Additional options for the request.

        Returns:
            The generated embedding, this acts like a list, but has additional metadata and usage details.
        """
        ...