Python: added embeddingsclient and redid chatclient (#132)

* added embeddingsclient and redid chatclient

* added to init

* added client tests

* fixed typing

* fixed slice import

* fixed pyright
This commit is contained in:
Eduard van Valkenburg
2025-07-07 18:43:54 +02:00
committed by GitHub
Unverified
parent d6c829cddf
commit 09309c1239
7 changed files with 316 additions and 67 deletions
+3 -1
View File
@@ -29,9 +29,11 @@ _IMPORTS = {
"ChatResponseUpdate": "._types",
"ChatRole": "._types",
"ErrorContent": "._types",
"ModelClient": "._types",
"GeneratedEmbeddings": "._types",
"ChatOptions": "._types",
"ChatToolMode": "._types",
"ChatClient": "._clients",
"EmbeddingGenerator": "._clients",
"InputGuardrail": ".guard_rails",
"OutputGuardrail": ".guard_rails",
}
+5 -2
View File
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
from . import __version__ # type: ignore[attr-defined]
from ._clients import ChatClient, EmbeddingGenerator
from ._logging import get_logger
from ._tools import AITool, ai_function
from ._types import (
@@ -17,7 +18,7 @@ from ._types import (
ErrorContent,
FunctionCallContent,
FunctionResultContent,
ModelClient,
GeneratedEmbeddings,
StructuredResponse,
TextContent,
TextReasoningContent,
@@ -31,6 +32,7 @@ __all__ = [
"AIContent",
"AIContents",
"AITool",
"ChatClient",
"ChatFinishReason",
"ChatMessage",
"ChatOptions",
@@ -39,11 +41,12 @@ __all__ = [
"ChatRole",
"ChatToolMode",
"DataContent",
"EmbeddingGenerator",
"ErrorContent",
"FunctionCallContent",
"FunctionResultContent",
"GeneratedEmbeddings",
"InputGuardrail",
"ModelClient",
"OutputGuardrail",
"StructuredResponse",
"TextContent",
+82
View File
@@ -0,0 +1,82 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import AsyncIterable, Sequence
from typing import Any, Generic, Protocol, TypeVar, runtime_checkable
from ._types import ChatMessage, ChatResponse, ChatResponseUpdate, GeneratedEmbeddings
TInput = TypeVar("TInput", contravariant=True)
TEmbedding = TypeVar("TEmbedding")
# region: ChatClient Protocol
@runtime_checkable
class ChatClient(Protocol):
    """A protocol for a chat client that can generate responses.

    Any object that exposes ``get_response`` and ``get_streaming_response`` satisfies
    this protocol structurally; no inheritance is required. Because the protocol is
    ``runtime_checkable``, ``isinstance`` checks only verify that both members exist,
    not that their signatures match.
    """

    async def get_response(
        self,
        messages: ChatMessage | Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        """Sends input and returns the complete response.

        Args:
            messages: A single input message or a sequence of input messages to send.
            **kwargs: Additional options for the request, such as ai_model_id, temperature, etc.
                See `ChatOptions` for more details.

        Returns:
            The response messages generated by the client.

        Raises:
            ValueError: If the input message sequence is `None`.
        """
        ...

    # Deliberately a plain ``def``: implementations are async generators, which return
    # the async iterable directly when called. Declaring this member ``async def`` would
    # type it as a coroutine that must be awaited before iterating, which no async
    # generator implementation satisfies.
    def get_streaming_response(
        self,
        messages: ChatMessage | Sequence[ChatMessage],
        **kwargs: Any,
    ) -> AsyncIterable[ChatResponseUpdate]:
        """Sends input messages and streams the response.

        Args:
            messages: A single input message or a sequence of input messages to send.
            **kwargs: Additional options for the request, such as ai_model_id, temperature, etc.
                See `ChatOptions` for more details.

        Yields:
            Chat response updates containing the content of the response messages
            generated by the client. Consume them with ``async for``, without awaiting
            the call itself.

        Raises:
            ValueError: If the input message sequence is `None`.
        """
        ...
# region: Embedding Client
@runtime_checkable
class EmbeddingGenerator(Protocol[TInput, TEmbedding]):
    """Structural interface for objects that turn input data into embeddings.

    ``TInput`` is the type of a single input item and ``TEmbedding`` the type of a
    single produced embedding.
    """

    async def generate(
        self,
        input_data: Sequence[TInput],
        **kwargs: Any,
    ) -> GeneratedEmbeddings[TEmbedding]:
        """Generates embeddings for the given input data.

        Args:
            input_data: The sequence of input items to generate embeddings for.
            **kwargs: Additional options for the request.

        Returns:
            The generated embeddings. The result acts like a list, but also carries
            additional metadata and usage details.
        """
        ...
+114 -61
View File
@@ -3,14 +3,13 @@
import base64
import re
import sys
from collections.abc import AsyncIterable, MutableSequence, Sequence
from typing import Annotated, Any, ClassVar, Generic, Literal, Protocol, TypeVar, overload, runtime_checkable
from collections.abc import AsyncIterable, Iterable, Iterator, MutableSequence, Sequence
from typing import Annotated, Any, ClassVar, Generic, Literal, TypeVar, overload
from pydantic import BaseModel, ConfigDict, Field, field_validator
from ._pydantic import AFBaseModel
from ._tools import AITool
from .guard_rails import InputGuardrail, OutputGuardrail
if sys.version_info >= (3, 12):
pass # pragma: no cover
@@ -22,8 +21,7 @@ else:
# region: Constants and types
_T = TypeVar("_T")
TValue = TypeVar("TValue")
TInput = TypeVar("TInput")
TResponse = TypeVar("TResponse")
TEmbedding = TypeVar("TEmbedding")
TChatResponse = TypeVar("TChatResponse", bound="ChatResponse")
TChatToolMode = TypeVar("TChatToolMode", bound="ChatToolMode")
@@ -118,8 +116,10 @@ class UsageDetails(AFBaseModel):
"""
return self.model_extra or {}
def __add__(self, other: "UsageDetails") -> "UsageDetails":
def __add__(self, other: "UsageDetails | None") -> "UsageDetails":
"""Combines two `UsageDetails` instances."""
if not other:
return self
if not isinstance(other, UsageDetails):
raise ValueError("Can only add two usage details objects together.")
@@ -135,6 +135,21 @@ class UsageDetails(AFBaseModel):
**additional_counts,
)
def __iadd__(self, other: "UsageDetails | None") -> Self:
    """Adds another `UsageDetails` into this instance in place.

    Args:
        other: The usage details to add. A falsy value (such as `None`) leaves this
            instance unchanged.

    Returns:
        This same instance, so that ``usage += other`` keeps the original object.

    Raises:
        ValueError: If `other` is truthy but is not a `UsageDetails` instance.
    """
    if not other:
        return self
    if not isinstance(other, UsageDetails):
        raise ValueError("Can only add usage details objects together.")

    # A missing count on either side is treated as zero.
    self.input_token_count = (self.input_token_count or 0) + (other.input_token_count or 0)
    self.output_token_count = (self.output_token_count or 0) + (other.output_token_count or 0)
    self.total_token_count = (self.total_token_count or 0) + (other.total_token_count or 0)

    # NOTE(review): `additional_counts` appears to return `self.model_extra or {}`, i.e. a
    # fresh dict whenever no extra counts are stored yet, so item assignment on it would be
    # silently lost. Persist each sum through attribute assignment instead; this relies on
    # the model accepting extra fields, which the `**additional_counts` construction in
    # `__add__` suggests - confirm against the model config.
    current_counts = self.additional_counts
    for key, value in other.additional_counts.items():
        setattr(self, key, current_counts.get(key, 0) + (value or 0))
    return self
def _process_update(response: "ChatResponse", update: "ChatResponseUpdate") -> None:
"""Processes a single update and modifies the response in place."""
@@ -1202,8 +1217,8 @@ class ChatResponseUpdate(AFBaseModel):
def __init__(
self,
*,
contents: list[AIContent],
role: ChatRole | None = None,
contents: list[AIContents],
role: ChatRole | Literal["system", "user", "assistant", "tool"] | None = None,
author_name: str | None = None,
response_id: str | None = None,
message_id: str | None = None,
@@ -1221,7 +1236,7 @@ class ChatResponseUpdate(AFBaseModel):
self,
*,
text: TextContent | str,
role: ChatRole | None = None,
role: ChatRole | Literal["system", "user", "assistant", "tool"] | None = None,
author_name: str | None = None,
response_id: str | None = None,
message_id: str | None = None,
@@ -1237,9 +1252,9 @@ class ChatResponseUpdate(AFBaseModel):
def __init__(
self,
*,
contents: list[AIContent] | None = None,
contents: list[AIContents] | None = None,
text: TextContent | str | None = None,
role: ChatRole | None = None,
role: ChatRole | Literal["system", "user", "assistant", "tool"] | None = None,
author_name: str | None = None,
response_id: str | None = None,
message_id: str | None = None,
@@ -1257,7 +1272,8 @@ class ChatResponseUpdate(AFBaseModel):
if isinstance(text, str):
text = TextContent(text=text)
contents.append(text)
if role and isinstance(role, str):
role = ChatRole(value=role)
super().__init__(
contents=contents, # type: ignore[reportCallIssue]
additional_properties=additional_properties, # type: ignore[reportCallIssue]
@@ -1379,66 +1395,103 @@ class ChatOptions(AFBaseModel):
return settings
# region: ModelClient Protocol
# region: GeneratedEmbeddings
@runtime_checkable
class ModelClient(Protocol, Generic[TInput, TResponse]):
"""A protocol for a model client that can generate responses."""
class GeneratedEmbeddings(AFBaseModel, MutableSequence[TEmbedding], Generic[TEmbedding]):
"""A model representing generated embeddings."""
async def get_response(
self,
messages: TInput | Sequence[TInput],
**kwargs: Any,
) -> TResponse:
"""Sends input and returns the response.
embeddings: list[TEmbedding] = Field(default_factory=list, kw_only=False) # type: ignore[ReportUnknownVariableType]
usage: UsageDetails | None = None
additional_properties: dict[str, Any] = Field(default_factory=dict)
Args:
messages: The sequence of input messages to send.
**kwargs: Additional options for the request, such as ai_model_id, temperature, etc.
See `ChatOptions` for more details.
def __contains__(self, value: object) -> bool:
return value in self.embeddings
Returns:
The response messages generated by the client.
def __iter__(self) -> Iterator[TEmbedding]: # type: ignore[override] # overrides a method in BaseModel, ignoring
return iter(self.embeddings)
Raises:
ValueError: If the input message sequence is `None`.
"""
...
def __len__(self) -> int:
return len(self.embeddings)
async def get_streaming_response(
self,
messages: TInput | Sequence[TInput],
**kwargs: Any, # kwargs?
) -> AsyncIterable[TResponse]:
"""Sends input messages and streams the response.
def __reversed__(self) -> Iterator[TEmbedding]:
return self.embeddings.__reversed__()
Args:
messages: The sequence of input messages to send.
**kwargs: Additional options for the request, such as ai_model_id, temperature, etc.
See `ChatOptions` for more details.
def index(self, value: TEmbedding, start: int = 0, stop: int | None = None) -> int:
    """Returns the position of the first embedding equal to `value`.

    Mirrors `list.index`, including support for negative `start` and `stop`.

    Args:
        value: The embedding to look for.
        start: The position at which the search begins.
        stop: The position before which the search ends; `None` searches to the end.

    Returns:
        The index of the first match within the requested range.

    Raises:
        ValueError: If `value` does not occur within the requested range.
    """
    # `stop` must be honoured regardless of `start`; only omit it when it was not given.
    if stop is None:
        return self.embeddings.index(value, start)
    return self.embeddings.index(value, start, stop)
Returns:
An async iterable of chat response updates containing the content of the response messages
generated by the client.
def count(self, value: TEmbedding) -> int:
return self.embeddings.count(value)
Raises:
ValueError: If the input message sequence is `None`.
"""
...
@overload
def __getitem__(self, index: int) -> TEmbedding: ...
def add_input_guardrails(self, guardrails: list[InputGuardrail[TInput]]) -> None:
"""Add input guardrails to the model client.
@overload
def __getitem__(self, index: slice) -> MutableSequence[TEmbedding]: ...
Args:
guardrails: The list of input guardrails to add.
"""
...
def __getitem__(self, index: int | slice) -> TEmbedding | MutableSequence[TEmbedding]:
return self.embeddings[index]
def add_output_guardrails(self, guardrails: list[OutputGuardrail[TResponse | Sequence[TResponse]]]) -> None:
"""Add output guardrails to the model client.
@overload
def __setitem__(self, index: int, value: TEmbedding) -> None: ...
Args:
guardrails: The list of output guardrails to add.
"""
...
@overload
def __setitem__(self, index: slice, value: Iterable[TEmbedding]) -> None: ...
def __setitem__(self, index: int | slice, value: TEmbedding | Iterable[TEmbedding]) -> None:
    """Replaces one embedding, or a slice of embeddings.

    Args:
        index: An integer position or a slice of positions to overwrite.
        value: A single embedding for an integer index, or an iterable of embeddings
            for a slice.

    Raises:
        TypeError: If `index` is a slice and `value` is not iterable.
        IndexError: If an integer `index` is out of range.
    """
    if isinstance(index, int):
        # A single embedding is commonly itself iterable (for example a list of floats),
        # so no iterability check may be applied when assigning one item.
        self.embeddings[index] = value
        return
    if not isinstance(value, Iterable):
        raise TypeError("Value must be an iterable when setting a slice.")
    self.embeddings[index] = value
@overload
def __delitem__(self, index: int) -> None: ...
@overload
def __delitem__(self, index: slice) -> None: ...
def __delitem__(self, index: int | slice) -> None:
del self.embeddings[index]
def insert(self, index: int, value: TEmbedding) -> None:
self.embeddings.insert(index, value)
def append(self, value: TEmbedding) -> None:
self.embeddings.append(value)
def clear(self) -> None:
self.embeddings.clear()
self.usage = None
self.additional_properties = {}
def reverse(self) -> None:
self.embeddings.reverse()
def extend(self, values: Iterable[TEmbedding]) -> None:
self.embeddings.extend(values)
def pop(self, index: int = -1) -> TEmbedding:
return self.embeddings.pop(index)
def remove(self, value: TEmbedding) -> None:
self.embeddings.remove(value)
def __iadd__(self, values: Iterable[TEmbedding] | Self) -> Self:
    """Extends this container in place with more embeddings.

    Args:
        values: Either a plain iterable of embeddings, or another `GeneratedEmbeddings`
            whose embeddings, usage and additional properties are merged into this one.

    Returns:
        This same instance.
    """
    if not isinstance(values, GeneratedEmbeddings):
        self.embeddings += values
        return self

    self.embeddings += values.embeddings
    if self.usage:
        self.usage += values.usage
    else:
        # Take a copy instead of sharing the object: usage is accumulated in place, so
        # holding a reference to `values.usage` would let a later `+=` on this container
        # silently change the usage reported by `values`.
        self.usage = values.usage.model_copy() if values.usage is not None else None
    self.additional_properties.update(values.additional_properties)
    return self
+93
View File
@@ -0,0 +1,93 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import AsyncIterable, Sequence
from typing import Any
from pytest import fixture
from agent_framework import (
ChatClient,
ChatMessage,
ChatResponse,
ChatResponseUpdate,
ChatRole,
EmbeddingGenerator,
GeneratedEmbeddings,
TextContent,
)
class ImplementedChatClient:
    """Minimal chat client returning canned responses, used to exercise the protocol."""

    async def get_response(
        self,
        messages: ChatMessage | Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        # The input is ignored; a fixed assistant reply is always returned.
        reply = ChatMessage(role="assistant", text="test response")
        return ChatResponse(messages=reply)

    async def get_streaming_response(
        self,
        messages: ChatMessage | Sequence[ChatMessage],
        **kwargs: Any,
    ) -> AsyncIterable[ChatResponseUpdate]:
        # Two fixed updates: the first built from `text`, the second from `contents`.
        first_update = ChatResponseUpdate(text=TextContent(text="test streaming response"), role="assistant")
        yield first_update
        second_update = ChatResponseUpdate(contents=[TextContent(text="another update")], role="assistant")
        yield second_update
class ImplementedEmbeddingGenerator:
    """Minimal embedding generator producing one fixed-size vector per input item."""

    async def generate(
        self,
        input_data: Sequence[str],
        **kwargs: Any,
    ) -> GeneratedEmbeddings[list[float]]:
        # One five-element vector per input; only the last two components depend on
        # the position of the input item.
        result = GeneratedEmbeddings[list[float]]()
        result.extend(
            [0.0 * 1, 0.1 * 1, 0.2 * 1, 0.3 * position, 0.4 * position] for position in range(len(input_data))
        )
        return result
@fixture
def chat_client() -> ImplementedChatClient:
return ImplementedChatClient()
@fixture
def embedding_generator() -> ImplementedEmbeddingGenerator:
gen: EmbeddingGenerator[str, list[float]] = ImplementedEmbeddingGenerator()
return gen
def test_chat_client_type(chat_client: ImplementedChatClient):
assert isinstance(chat_client, ChatClient)
async def test_chat_client_get_response(chat_client: ImplementedChatClient):
response = await chat_client.get_response(ChatMessage(role="user", text="Hello"))
assert response.text == "test response"
assert response.messages[0].role == ChatRole.ASSISTANT
async def test_chat_client_get_streaming_response(chat_client: ImplementedChatClient):
    """Checks that the streaming response yields both expected updates, in order."""
    # Collect first and assert afterwards: assertions placed inside the `async for`
    # body would pass vacuously if the client yielded no updates at all.
    updates = [update async for update in chat_client.get_streaming_response(ChatMessage(role="user", text="Hello"))]
    assert [update.text for update in updates] == ["test streaming response", "another update"]
    assert all(update.role == ChatRole.ASSISTANT for update in updates)
def test_embedding_generator_type(embedding_generator: ImplementedEmbeddingGenerator):
assert isinstance(embedding_generator, EmbeddingGenerator)
async def test_embedding_generator_generate(embedding_generator: ImplementedEmbeddingGenerator):
    """Checks that one five-element embedding is produced per input item."""
    texts = ["Hello", "world"]
    result = await embedding_generator.generate(texts)
    assert len(result) == len(texts)
    assert all(len(vector) == 5 for vector in result)
+16
View File
@@ -1,5 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import MutableSequence
from pydantic import BaseModel, ValidationError
from pytest import mark, raises
@@ -14,6 +16,7 @@ from agent_framework import (
DataContent,
FunctionCallContent,
FunctionResultContent,
GeneratedEmbeddings,
StructuredResponse,
TextContent,
TextReasoningContent,
@@ -464,3 +467,16 @@ def test_chat_tool_mode_from_dict():
# Ensure the instance is of type ChatToolMode
assert isinstance(mode, ChatToolMode)
def test_generated_embeddings():
    """Checks that GeneratedEmbeddings stores its input and is a MutableSequence."""
    vectors = [[0.1, 0.2, 0.3]]
    generated = GeneratedEmbeddings(embeddings=vectors)

    # The stored embeddings compare equal to what was passed in.
    assert generated.embeddings == [[0.1, 0.2, 0.3]]

    # The instance has the expected type, and the class is a MutableSequence.
    assert isinstance(generated, GeneratedEmbeddings)
    assert issubclass(GeneratedEmbeddings, MutableSequence)
+3 -3
View File
@@ -2627,11 +2627,11 @@ wheels = [
[[package]]
name = "typing-extensions"
version = "4.14.0"
version = "4.14.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d1/bc/51647cd02527e87d05cb083ccc402f93e441606ff1f01739a62c8ad09ba5/typing_extensions-4.14.0.tar.gz", hash = "sha256:8676b788e32f02ab42d9e7c61324048ae4c6d844a399eebace3d4979d75ceef4", size = 107423, upload-time = "2025-06-02T14:52:11.399Z" }
sdist = { url = "https://files.pythonhosted.org/packages/98/5a/da40306b885cc8c09109dc2e1abd358d5684b1425678151cdaed4731c822/typing_extensions-4.14.1.tar.gz", hash = "sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36", size = 107673, upload-time = "2025-07-04T13:28:34.16Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/69/e0/552843e0d356fbb5256d21449fa957fa4eff3bbc135a74a691ee70c7c5da/typing_extensions-4.14.0-py3-none-any.whl", hash = "sha256:a1514509136dd0b477638fc68d6a91497af5076466ad0fa6c338e44e359944af", size = 43839, upload-time = "2025-06-02T14:52:10.026Z" },
{ url = "https://files.pythonhosted.org/packages/b5/00/d631e67a838026495268c2f6884f3711a15a9a2a96cd244fdaea53b823fb/typing_extensions-4.14.1-py3-none-any.whl", hash = "sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76", size = 43906, upload-time = "2025-07-04T13:28:32.743Z" },
]
[[package]]