From 09309c123907a67e3ed09048143af9f7fec509aa Mon Sep 17 00:00:00 2001 From: Eduard van Valkenburg Date: Mon, 7 Jul 2025 18:43:54 +0200 Subject: [PATCH] Python: added embeddingsclient and redid chatclient (#132) * added embeddingsclient and redid chatclient * added to init * added client tests * fixed typing * fixed slice import * fixed pyright --- python/agent_framework/__init__.py | 4 +- python/agent_framework/__init__.pyi | 7 +- python/agent_framework/_clients.py | 82 +++++++++++++ python/agent_framework/_types.py | 175 ++++++++++++++++++---------- python/tests/unit/test_clients.py | 93 +++++++++++++++ python/tests/unit/test_types.py | 16 +++ python/uv.lock | 6 +- 7 files changed, 316 insertions(+), 67 deletions(-) create mode 100644 python/agent_framework/_clients.py create mode 100644 python/tests/unit/test_clients.py diff --git a/python/agent_framework/__init__.py b/python/agent_framework/__init__.py index 7135cf7618..c933ae97e8 100644 --- a/python/agent_framework/__init__.py +++ b/python/agent_framework/__init__.py @@ -29,9 +29,11 @@ _IMPORTS = { "ChatResponseUpdate": "._types", "ChatRole": "._types", "ErrorContent": "._types", - "ModelClient": "._types", + "GeneratedEmbeddings": "._types", "ChatOptions": "._types", "ChatToolMode": "._types", + "ChatClient": "._clients", + "EmbeddingGenerator": "._clients", "InputGuardrail": ".guard_rails", "OutputGuardrail": ".guard_rails", } diff --git a/python/agent_framework/__init__.pyi b/python/agent_framework/__init__.pyi index dc995f7ee5..b4d78c6828 100644 --- a/python/agent_framework/__init__.pyi +++ b/python/agent_framework/__init__.pyi @@ -1,6 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. from . import __version__ # type: ignore[attr-defined] +from ._clients import ChatClient, EmbeddingGenerator from ._logging import get_logger from ._tools import AITool, ai_function from ._types import ( @@ -17,7 +18,7 @@ from ._types import ( ErrorContent, FunctionCallContent, FunctionResultContent, - ModelClient, + GeneratedEmbeddings, StructuredResponse, TextContent, TextReasoningContent, @@ -31,6 +32,7 @@ __all__ = [ "AIContent", "AIContents", "AITool", + "ChatClient", "ChatFinishReason", "ChatMessage", "ChatOptions", @@ -39,11 +41,12 @@ __all__ = [ "ChatRole", "ChatToolMode", "DataContent", + "EmbeddingGenerator", "ErrorContent", "FunctionCallContent", "FunctionResultContent", + "GeneratedEmbeddings", "InputGuardrail", - "ModelClient", "OutputGuardrail", "StructuredResponse", "TextContent", diff --git a/python/agent_framework/_clients.py b/python/agent_framework/_clients.py new file mode 100644 index 0000000000..db30ed1505 --- /dev/null +++ b/python/agent_framework/_clients.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft. All rights reserved. + +from collections.abc import AsyncIterable, Sequence +from typing import Any, Generic, Protocol, TypeVar, runtime_checkable + +from ._types import ChatMessage, ChatResponse, ChatResponseUpdate, GeneratedEmbeddings + +TInput = TypeVar("TInput", contravariant=True) +TEmbedding = TypeVar("TEmbedding") + +# region: ChatClient Protocol + + +@runtime_checkable +class ChatClient(Protocol): + """A protocol for a chat client that can generate responses.""" + + async def get_response( + self, + messages: ChatMessage | Sequence[ChatMessage], + **kwargs: Any, + ) -> ChatResponse: + """Sends input and returns the response. + + Args: + messages: The sequence of input messages to send. + **kwargs: Additional options for the request, such as ai_model_id, temperature, etc. + See `ChatOptions` for more details. + + Returns: + The response messages generated by the client. + + Raises: + ValueError: If the input message sequence is `None`. + """ + ... + + async def get_streaming_response( + self, + messages: ChatMessage | Sequence[ChatMessage], + **kwargs: Any, + ) -> AsyncIterable[ChatResponseUpdate]: + """Sends input messages and streams the response. + + Args: + messages: The sequence of input messages to send. + **kwargs: Additional options for the request, such as ai_model_id, temperature, etc. + See `ChatOptions` for more details. + + Yields: + An async iterable of chat response updates containing the content of the response messages + generated by the client. + + Raises: + ValueError: If the input message sequence is `None`. + """ + ... + + +# region: Embedding Client + + +@runtime_checkable +class EmbeddingGenerator(Protocol, Generic[TInput, TEmbedding]): + """A protocol for an embedding generator that can create embeddings from input data.""" + + async def generate( + self, + input_data: Sequence[TInput], + **kwargs: Any, + ) -> GeneratedEmbeddings[TEmbedding]: + """Generates an embedding for the given input data. + + Args: + input_data: The input data to generate an embedding for. + **kwargs: Additional options for the request. + + Returns: + The generated embedding, this acts like a list, but has additional metadata and usage details. + + """ + ... diff --git a/python/agent_framework/_types.py b/python/agent_framework/_types.py index c7e59f5464..df86ade3e1 100644 --- a/python/agent_framework/_types.py +++ b/python/agent_framework/_types.py @@ -3,14 +3,13 @@ import base64 import re import sys -from collections.abc import AsyncIterable, MutableSequence, Sequence -from typing import Annotated, Any, ClassVar, Generic, Literal, Protocol, TypeVar, overload, runtime_checkable +from collections.abc import AsyncIterable, Iterable, Iterator, MutableSequence, Sequence +from typing import Annotated, Any, ClassVar, Generic, Literal, TypeVar, overload from pydantic import BaseModel, ConfigDict, Field, field_validator from ._pydantic import AFBaseModel from ._tools import AITool -from .guard_rails import InputGuardrail, OutputGuardrail if sys.version_info >= (3, 12): pass # pragma: no cover @@ -22,8 +21,7 @@ else: # region: Constants and types _T = TypeVar("_T") TValue = TypeVar("TValue") -TInput = TypeVar("TInput") -TResponse = TypeVar("TResponse") +TEmbedding = TypeVar("TEmbedding") TChatResponse = TypeVar("TChatResponse", bound="ChatResponse") TChatToolMode = TypeVar("TChatToolMode", bound="ChatToolMode") @@ -118,8 +116,10 @@ class UsageDetails(AFBaseModel): """ return self.model_extra or {} - def __add__(self, other: "UsageDetails") -> "UsageDetails": + def __add__(self, other: "UsageDetails | None") -> "UsageDetails": """Combines two `UsageDetails` instances.""" + if not other: + return self if not isinstance(other, UsageDetails): raise ValueError("Can only add two usage details objects together.") @@ -135,6 +135,21 @@ class UsageDetails(AFBaseModel): **additional_counts, ) + def __iadd__(self, other: "UsageDetails | None") -> Self: + if not other: + return self + if not isinstance(other, UsageDetails): + raise ValueError("Can only add usage details objects together.") + + self.input_token_count = (self.input_token_count or 0) + (other.input_token_count or 0) + self.output_token_count = (self.output_token_count or 0) + (other.output_token_count or 0) + self.total_token_count = (self.total_token_count or 0) + (other.total_token_count or 0) + + for key, value in other.additional_counts.items(): + self.additional_counts[key] = self.additional_counts.get(key, 0) + (value or 0) + + return self + def _process_update(response: "ChatResponse", update: "ChatResponseUpdate") -> None: """Processes a single update and modifies the response in place.""" @@ -1202,8 +1217,8 @@ class ChatResponseUpdate(AFBaseModel): def __init__( self, *, - contents: list[AIContent], - role: ChatRole | None = None, + contents: list[AIContents], + role: ChatRole | Literal["system", "user", "assistant", "tool"] | None = None, author_name: str | None = None, response_id: str | None = None, message_id: str | None = None, @@ -1221,7 +1236,7 @@ class ChatResponseUpdate(AFBaseModel): self, *, text: TextContent | str, - role: ChatRole | None = None, + role: ChatRole | Literal["system", "user", "assistant", "tool"] | None = None, author_name: str | None = None, response_id: str | None = None, message_id: str | None = None, @@ -1237,9 +1252,9 @@ class ChatResponseUpdate(AFBaseModel): def __init__( self, *, - contents: list[AIContent] | None = None, + contents: list[AIContents] | None = None, text: TextContent | str | None = None, - role: ChatRole | None = None, + role: ChatRole | Literal["system", "user", "assistant", "tool"] | None = None, author_name: str | None = None, response_id: str | None = None, message_id: str | None = None, @@ -1257,7 +1272,8 @@ class ChatResponseUpdate(AFBaseModel): if isinstance(text, str): text = TextContent(text=text) contents.append(text) - + if role and isinstance(role, str): + role = ChatRole(value=role) super().__init__( contents=contents, # type: ignore[reportCallIssue] additional_properties=additional_properties, # type: ignore[reportCallIssue] @@ -1379,66 +1395,103 @@ class ChatOptions(AFBaseModel): return settings -# region: ModelClient Protocol +# region: GeneratedEmbeddings -@runtime_checkable -class ModelClient(Protocol, Generic[TInput, TResponse]): - """A protocol for a model client that can generate responses.""" +class GeneratedEmbeddings(AFBaseModel, MutableSequence[TEmbedding], Generic[TEmbedding]): + """A model representing generated embeddings.""" - async def get_response( - self, - messages: TInput | Sequence[TInput], - **kwargs: Any, - ) -> TResponse: - """Sends input and returns the response. + embeddings: list[TEmbedding] = Field(default_factory=list, kw_only=False) # type: ignore[ReportUnknownVariableType] + usage: UsageDetails | None = None + additional_properties: dict[str, Any] = Field(default_factory=dict) - Args: - messages: The sequence of input messages to send. - **kwargs: Additional options for the request, such as ai_model_id, temperature, etc. - See `ChatOptions` for more details. + def __contains__(self, value: object) -> bool: + return value in self.embeddings - Returns: - The response messages generated by the client. + def __iter__(self) -> Iterator[TEmbedding]: # type: ignore[override] # overrides a method in BaseModel, ignoring + return iter(self.embeddings) - Raises: - ValueError: If the input message sequence is `None`. - """ - ... + def __len__(self) -> int: + return len(self.embeddings) - async def get_streaming_response( - self, - messages: TInput | Sequence[TInput], - **kwargs: Any, # kwargs? - ) -> AsyncIterable[TResponse]: - """Sends input messages and streams the response. + def __reversed__(self) -> Iterator[TEmbedding]: + return self.embeddings.__reversed__() - Args: - messages: The sequence of input messages to send. - **kwargs: Additional options for the request, such as ai_model_id, temperature, etc. - See `ChatOptions` for more details. + def index(self, value: TEmbedding, start: int = 0, stop: int | None = None) -> int: + if start > 0: + if stop is not None: + return self.embeddings.index(value, start, stop) + return self.embeddings.index(value, start) + return self.embeddings.index(value) - Returns: - An async iterable of chat response updates containing the content of the response messages - generated by the client. + def count(self, value: TEmbedding) -> int: + return self.embeddings.count(value) - Raises: - ValueError: If the input message sequence is `None`. - """ - ... + @overload + def __getitem__(self, index: int) -> TEmbedding: ... - def add_input_guardrails(self, guardrails: list[InputGuardrail[TInput]]) -> None: - """Add input guardrails to the model client. + @overload + def __getitem__(self, index: slice) -> MutableSequence[TEmbedding]: ... - Args: - guardrails: The list of input guardrails to add. - """ - ... + def __getitem__(self, index: int | slice) -> TEmbedding | MutableSequence[TEmbedding]: + return self.embeddings[index] - def add_output_guardrails(self, guardrails: list[OutputGuardrail[TResponse | Sequence[TResponse]]]) -> None: - """Add output guardrails to the model client. + @overload + def __setitem__(self, index: int, value: TEmbedding) -> None: ... - Args: - guardrails: The list of output guardrails to add. - """ - ... + @overload + def __setitem__(self, index: slice, value: Iterable[TEmbedding]) -> None: ... + + def __setitem__(self, index: int | slice, value: TEmbedding | Iterable[TEmbedding]) -> None: + if isinstance(index, int): + if isinstance(value, Iterable): + raise TypeError("Value must be an iterable when setting a slice.") + self.embeddings[index] = value + return + if not isinstance(value, Iterable): + raise TypeError("Value must be an iterable when setting a slice.") + self.embeddings[index] = value + + @overload + def __delitem__(self, index: int) -> None: ... + + @overload + def __delitem__(self, index: slice) -> None: ... + + def __delitem__(self, index: int | slice) -> None: + del self.embeddings[index] + + def insert(self, index: int, value: TEmbedding) -> None: + self.embeddings.insert(index, value) + + def append(self, value: TEmbedding) -> None: + self.embeddings.append(value) + + def clear(self) -> None: + self.embeddings.clear() + self.usage = None + self.additional_properties = {} + + def reverse(self) -> None: + self.embeddings.reverse() + + def extend(self, values: Iterable[TEmbedding]) -> None: + self.embeddings.extend(values) + + def pop(self, index: int = -1) -> TEmbedding: + return self.embeddings.pop(index) + + def remove(self, value: TEmbedding) -> None: + self.embeddings.remove(value) + + def __iadd__(self, values: Iterable[TEmbedding] | Self) -> Self: + if isinstance(values, GeneratedEmbeddings): + self.embeddings += values.embeddings + if not self.usage: + self.usage = values.usage + else: + self.usage += values.usage + self.additional_properties.update(values.additional_properties) + else: + self.embeddings += values + return self diff --git a/python/tests/unit/test_clients.py b/python/tests/unit/test_clients.py new file mode 100644 index 0000000000..ac15bd8b67 --- /dev/null +++ b/python/tests/unit/test_clients.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft. All rights reserved. + +from collections.abc import AsyncIterable, Sequence +from typing import Any + +from pytest import fixture + +from agent_framework import ( + ChatClient, + ChatMessage, + ChatResponse, + ChatResponseUpdate, + ChatRole, + EmbeddingGenerator, + GeneratedEmbeddings, + TextContent, +) + + +class ImplementedChatClient: + """Simple implementation of a chat client.""" + + async def get_response( + self, + messages: ChatMessage | Sequence[ChatMessage], + **kwargs: Any, + ) -> ChatResponse: + # Implement the method + + return ChatResponse(messages=ChatMessage(role="assistant", text="test response")) + + async def get_streaming_response( + self, + messages: ChatMessage | Sequence[ChatMessage], + **kwargs: Any, + ) -> AsyncIterable[ChatResponseUpdate]: + # Implement the method + yield ChatResponseUpdate(text=TextContent(text="test streaming response"), role="assistant") + yield ChatResponseUpdate(contents=[TextContent(text="another update")], role="assistant") + + +class ImplementedEmbeddingGenerator: + """Simple implementation of an embedding generator.""" + + async def generate( + self, + input_data: Sequence[str], + **kwargs: Any, + ) -> GeneratedEmbeddings[list[float]]: + # Implement the method + embeddings = GeneratedEmbeddings[list[float]]() + for i, _ in enumerate(input_data): + embeddings.append([0.0 * 1, 0.1 * 1, 0.2 * 1, 0.3 * i, 0.4 * i]) + return embeddings + + +@fixture +def chat_client() -> ImplementedChatClient: + return ImplementedChatClient() + + +@fixture +def embedding_generator() -> ImplementedEmbeddingGenerator: + gen: EmbeddingGenerator[str, list[float]] = ImplementedEmbeddingGenerator() + return gen + + +def test_chat_client_type(chat_client: ImplementedChatClient): + assert isinstance(chat_client, ChatClient) + + +async def test_chat_client_get_response(chat_client: ImplementedChatClient): + response = await chat_client.get_response(ChatMessage(role="user", text="Hello")) + assert response.text == "test response" + assert response.messages[0].role == ChatRole.ASSISTANT + + +async def test_chat_client_get_streaming_response(chat_client: ImplementedChatClient): + async for update in chat_client.get_streaming_response(ChatMessage(role="user", text="Hello")): + assert update.text == "test streaming response" or update.text == "another update" + assert update.role == ChatRole.ASSISTANT + + +def test_embedding_generator_type(embedding_generator: ImplementedEmbeddingGenerator): + assert isinstance(embedding_generator, EmbeddingGenerator) + + +async def test_embedding_generator_generate(embedding_generator: ImplementedEmbeddingGenerator): + input_data = ["Hello", "world"] + embeddings = await embedding_generator.generate(input_data) + assert len(embeddings) == len(input_data) + for emb in embeddings: + assert len(emb) == 5 diff --git a/python/tests/unit/test_types.py b/python/tests/unit/test_types.py index 16f9a56caa..79bc570c46 100644 --- a/python/tests/unit/test_types.py +++ b/python/tests/unit/test_types.py @@ -1,5 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. +from collections.abc import MutableSequence + from pydantic import BaseModel, ValidationError from pytest import mark, raises @@ -14,6 +16,7 @@ from agent_framework import ( DataContent, FunctionCallContent, FunctionResultContent, + GeneratedEmbeddings, StructuredResponse, TextContent, TextReasoningContent, @@ -464,3 +467,16 @@ def test_chat_tool_mode_from_dict(): # Ensure the instance is of type ChatToolMode assert isinstance(mode, ChatToolMode) + + +def test_generated_embeddings(): + """Test the GeneratedEmbeddings class to ensure it initializes correctly.""" + # Create an instance of GeneratedEmbeddings + embeddings = GeneratedEmbeddings(embeddings=[[0.1, 0.2, 0.3]]) + + # Check the type and content + assert embeddings.embeddings == [[0.1, 0.2, 0.3]] + + # Ensure the instance is of type GeneratedEmbeddings + assert isinstance(embeddings, GeneratedEmbeddings) + assert issubclass(GeneratedEmbeddings, MutableSequence) diff --git a/python/uv.lock b/python/uv.lock index 23c367ec1b..1dfc66f560 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -2627,11 +2627,11 @@ wheels = [ [[package]] name = "typing-extensions" -version = "4.14.0" +version = "4.14.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d1/bc/51647cd02527e87d05cb083ccc402f93e441606ff1f01739a62c8ad09ba5/typing_extensions-4.14.0.tar.gz", hash = "sha256:8676b788e32f02ab42d9e7c61324048ae4c6d844a399eebace3d4979d75ceef4", size = 107423, upload-time = "2025-06-02T14:52:11.399Z" } +sdist = { url = "https://files.pythonhosted.org/packages/98/5a/da40306b885cc8c09109dc2e1abd358d5684b1425678151cdaed4731c822/typing_extensions-4.14.1.tar.gz", hash = "sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36", size = 107673, upload-time = "2025-07-04T13:28:34.16Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/69/e0/552843e0d356fbb5256d21449fa957fa4eff3bbc135a74a691ee70c7c5da/typing_extensions-4.14.0-py3-none-any.whl", hash = "sha256:a1514509136dd0b477638fc68d6a91497af5076466ad0fa6c338e44e359944af", size = 43839, upload-time = "2025-06-02T14:52:10.026Z" }, + { url = "https://files.pythonhosted.org/packages/b5/00/d631e67a838026495268c2f6884f3711a15a9a2a96cd244fdaea53b823fb/typing_extensions-4.14.1-py3-none-any.whl", hash = "sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76", size = 43906, upload-time = "2025-07-04T13:28:32.743Z" }, ] [[package]]