mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: added embeddingsclient and redid chatclient (#132)
* added embeddingsclient and redid chatclient * added to init * added client tests * fixed typing * fixed slice import * fixed pyright
This commit is contained in:
committed by
GitHub
Unverified
parent
d6c829cddf
commit
09309c1239
@@ -29,9 +29,11 @@ _IMPORTS = {
|
||||
"ChatResponseUpdate": "._types",
|
||||
"ChatRole": "._types",
|
||||
"ErrorContent": "._types",
|
||||
"ModelClient": "._types",
|
||||
"GeneratedEmbeddings": "._types",
|
||||
"ChatOptions": "._types",
|
||||
"ChatToolMode": "._types",
|
||||
"ChatClient": "._clients",
|
||||
"EmbeddingGenerator": "._clients",
|
||||
"InputGuardrail": ".guard_rails",
|
||||
"OutputGuardrail": ".guard_rails",
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from . import __version__ # type: ignore[attr-defined]
|
||||
from ._clients import ChatClient, EmbeddingGenerator
|
||||
from ._logging import get_logger
|
||||
from ._tools import AITool, ai_function
|
||||
from ._types import (
|
||||
@@ -17,7 +18,7 @@ from ._types import (
|
||||
ErrorContent,
|
||||
FunctionCallContent,
|
||||
FunctionResultContent,
|
||||
ModelClient,
|
||||
GeneratedEmbeddings,
|
||||
StructuredResponse,
|
||||
TextContent,
|
||||
TextReasoningContent,
|
||||
@@ -31,6 +32,7 @@ __all__ = [
|
||||
"AIContent",
|
||||
"AIContents",
|
||||
"AITool",
|
||||
"ChatClient",
|
||||
"ChatFinishReason",
|
||||
"ChatMessage",
|
||||
"ChatOptions",
|
||||
@@ -39,11 +41,12 @@ __all__ = [
|
||||
"ChatRole",
|
||||
"ChatToolMode",
|
||||
"DataContent",
|
||||
"EmbeddingGenerator",
|
||||
"ErrorContent",
|
||||
"FunctionCallContent",
|
||||
"FunctionResultContent",
|
||||
"GeneratedEmbeddings",
|
||||
"InputGuardrail",
|
||||
"ModelClient",
|
||||
"OutputGuardrail",
|
||||
"StructuredResponse",
|
||||
"TextContent",
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import AsyncIterable, Sequence
|
||||
from typing import Any, Generic, Protocol, TypeVar, runtime_checkable
|
||||
|
||||
from ._types import ChatMessage, ChatResponse, ChatResponseUpdate, GeneratedEmbeddings
|
||||
|
||||
TInput = TypeVar("TInput", contravariant=True)
|
||||
TEmbedding = TypeVar("TEmbedding")
|
||||
|
||||
# region: ChatClient Protocol
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ChatClient(Protocol):
|
||||
"""A protocol for a chat client that can generate responses."""
|
||||
|
||||
async def get_response(
|
||||
self,
|
||||
messages: ChatMessage | Sequence[ChatMessage],
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
"""Sends input and returns the response.
|
||||
|
||||
Args:
|
||||
messages: The sequence of input messages to send.
|
||||
**kwargs: Additional options for the request, such as ai_model_id, temperature, etc.
|
||||
See `ChatOptions` for more details.
|
||||
|
||||
Returns:
|
||||
The response messages generated by the client.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input message sequence is `None`.
|
||||
"""
|
||||
...
|
||||
|
||||
async def get_streaming_response(
|
||||
self,
|
||||
messages: ChatMessage | Sequence[ChatMessage],
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[ChatResponseUpdate]:
|
||||
"""Sends input messages and streams the response.
|
||||
|
||||
Args:
|
||||
messages: The sequence of input messages to send.
|
||||
**kwargs: Additional options for the request, such as ai_model_id, temperature, etc.
|
||||
See `ChatOptions` for more details.
|
||||
|
||||
Yields:
|
||||
An async iterable of chat response updates containing the content of the response messages
|
||||
generated by the client.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input message sequence is `None`.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
# region: Embedding Client
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class EmbeddingGenerator(Protocol, Generic[TInput, TEmbedding]):
|
||||
"""A protocol for an embedding generator that can create embeddings from input data."""
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
input_data: Sequence[TInput],
|
||||
**kwargs: Any,
|
||||
) -> GeneratedEmbeddings[TEmbedding]:
|
||||
"""Generates an embedding for the given input data.
|
||||
|
||||
Args:
|
||||
input_data: The input data to generate an embedding for.
|
||||
**kwargs: Additional options for the request.
|
||||
|
||||
Returns:
|
||||
The generated embedding, this acts like a list, but has additional metadata and usage details.
|
||||
|
||||
"""
|
||||
...
|
||||
@@ -3,14 +3,13 @@
|
||||
import base64
|
||||
import re
|
||||
import sys
|
||||
from collections.abc import AsyncIterable, MutableSequence, Sequence
|
||||
from typing import Annotated, Any, ClassVar, Generic, Literal, Protocol, TypeVar, overload, runtime_checkable
|
||||
from collections.abc import AsyncIterable, Iterable, Iterator, MutableSequence, Sequence
|
||||
from typing import Annotated, Any, ClassVar, Generic, Literal, TypeVar, overload
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from ._pydantic import AFBaseModel
|
||||
from ._tools import AITool
|
||||
from .guard_rails import InputGuardrail, OutputGuardrail
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
pass # pragma: no cover
|
||||
@@ -22,8 +21,7 @@ else:
|
||||
# region: Constants and types
|
||||
_T = TypeVar("_T")
|
||||
TValue = TypeVar("TValue")
|
||||
TInput = TypeVar("TInput")
|
||||
TResponse = TypeVar("TResponse")
|
||||
TEmbedding = TypeVar("TEmbedding")
|
||||
TChatResponse = TypeVar("TChatResponse", bound="ChatResponse")
|
||||
TChatToolMode = TypeVar("TChatToolMode", bound="ChatToolMode")
|
||||
|
||||
@@ -118,8 +116,10 @@ class UsageDetails(AFBaseModel):
|
||||
"""
|
||||
return self.model_extra or {}
|
||||
|
||||
def __add__(self, other: "UsageDetails") -> "UsageDetails":
|
||||
def __add__(self, other: "UsageDetails | None") -> "UsageDetails":
|
||||
"""Combines two `UsageDetails` instances."""
|
||||
if not other:
|
||||
return self
|
||||
if not isinstance(other, UsageDetails):
|
||||
raise ValueError("Can only add two usage details objects together.")
|
||||
|
||||
@@ -135,6 +135,21 @@ class UsageDetails(AFBaseModel):
|
||||
**additional_counts,
|
||||
)
|
||||
|
||||
def __iadd__(self, other: "UsageDetails | None") -> Self:
|
||||
if not other:
|
||||
return self
|
||||
if not isinstance(other, UsageDetails):
|
||||
raise ValueError("Can only add usage details objects together.")
|
||||
|
||||
self.input_token_count = (self.input_token_count or 0) + (other.input_token_count or 0)
|
||||
self.output_token_count = (self.output_token_count or 0) + (other.output_token_count or 0)
|
||||
self.total_token_count = (self.total_token_count or 0) + (other.total_token_count or 0)
|
||||
|
||||
for key, value in other.additional_counts.items():
|
||||
self.additional_counts[key] = self.additional_counts.get(key, 0) + (value or 0)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def _process_update(response: "ChatResponse", update: "ChatResponseUpdate") -> None:
|
||||
"""Processes a single update and modifies the response in place."""
|
||||
@@ -1202,8 +1217,8 @@ class ChatResponseUpdate(AFBaseModel):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
contents: list[AIContent],
|
||||
role: ChatRole | None = None,
|
||||
contents: list[AIContents],
|
||||
role: ChatRole | Literal["system", "user", "assistant", "tool"] | None = None,
|
||||
author_name: str | None = None,
|
||||
response_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
@@ -1221,7 +1236,7 @@ class ChatResponseUpdate(AFBaseModel):
|
||||
self,
|
||||
*,
|
||||
text: TextContent | str,
|
||||
role: ChatRole | None = None,
|
||||
role: ChatRole | Literal["system", "user", "assistant", "tool"] | None = None,
|
||||
author_name: str | None = None,
|
||||
response_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
@@ -1237,9 +1252,9 @@ class ChatResponseUpdate(AFBaseModel):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
contents: list[AIContent] | None = None,
|
||||
contents: list[AIContents] | None = None,
|
||||
text: TextContent | str | None = None,
|
||||
role: ChatRole | None = None,
|
||||
role: ChatRole | Literal["system", "user", "assistant", "tool"] | None = None,
|
||||
author_name: str | None = None,
|
||||
response_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
@@ -1257,7 +1272,8 @@ class ChatResponseUpdate(AFBaseModel):
|
||||
if isinstance(text, str):
|
||||
text = TextContent(text=text)
|
||||
contents.append(text)
|
||||
|
||||
if role and isinstance(role, str):
|
||||
role = ChatRole(value=role)
|
||||
super().__init__(
|
||||
contents=contents, # type: ignore[reportCallIssue]
|
||||
additional_properties=additional_properties, # type: ignore[reportCallIssue]
|
||||
@@ -1379,66 +1395,103 @@ class ChatOptions(AFBaseModel):
|
||||
return settings
|
||||
|
||||
|
||||
# region: ModelClient Protocol
|
||||
# region: GeneratedEmbeddings
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ModelClient(Protocol, Generic[TInput, TResponse]):
|
||||
"""A protocol for a model client that can generate responses."""
|
||||
class GeneratedEmbeddings(AFBaseModel, MutableSequence[TEmbedding], Generic[TEmbedding]):
|
||||
"""A model representing generated embeddings."""
|
||||
|
||||
async def get_response(
|
||||
self,
|
||||
messages: TInput | Sequence[TInput],
|
||||
**kwargs: Any,
|
||||
) -> TResponse:
|
||||
"""Sends input and returns the response.
|
||||
embeddings: list[TEmbedding] = Field(default_factory=list, kw_only=False) # type: ignore[ReportUnknownVariableType]
|
||||
usage: UsageDetails | None = None
|
||||
additional_properties: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
Args:
|
||||
messages: The sequence of input messages to send.
|
||||
**kwargs: Additional options for the request, such as ai_model_id, temperature, etc.
|
||||
See `ChatOptions` for more details.
|
||||
def __contains__(self, value: object) -> bool:
|
||||
return value in self.embeddings
|
||||
|
||||
Returns:
|
||||
The response messages generated by the client.
|
||||
def __iter__(self) -> Iterator[TEmbedding]: # type: ignore[override] # overrides a method in BaseModel, ignoring
|
||||
return iter(self.embeddings)
|
||||
|
||||
Raises:
|
||||
ValueError: If the input message sequence is `None`.
|
||||
"""
|
||||
...
|
||||
def __len__(self) -> int:
|
||||
return len(self.embeddings)
|
||||
|
||||
async def get_streaming_response(
|
||||
self,
|
||||
messages: TInput | Sequence[TInput],
|
||||
**kwargs: Any, # kwargs?
|
||||
) -> AsyncIterable[TResponse]:
|
||||
"""Sends input messages and streams the response.
|
||||
def __reversed__(self) -> Iterator[TEmbedding]:
|
||||
return self.embeddings.__reversed__()
|
||||
|
||||
Args:
|
||||
messages: The sequence of input messages to send.
|
||||
**kwargs: Additional options for the request, such as ai_model_id, temperature, etc.
|
||||
See `ChatOptions` for more details.
|
||||
def index(self, value: TEmbedding, start: int = 0, stop: int | None = None) -> int:
|
||||
if start > 0:
|
||||
if stop is not None:
|
||||
return self.embeddings.index(value, start, stop)
|
||||
return self.embeddings.index(value, start)
|
||||
return self.embeddings.index(value)
|
||||
|
||||
Returns:
|
||||
An async iterable of chat response updates containing the content of the response messages
|
||||
generated by the client.
|
||||
def count(self, value: TEmbedding) -> int:
|
||||
return self.embeddings.count(value)
|
||||
|
||||
Raises:
|
||||
ValueError: If the input message sequence is `None`.
|
||||
"""
|
||||
...
|
||||
@overload
|
||||
def __getitem__(self, index: int) -> TEmbedding: ...
|
||||
|
||||
def add_input_guardrails(self, guardrails: list[InputGuardrail[TInput]]) -> None:
|
||||
"""Add input guardrails to the model client.
|
||||
@overload
|
||||
def __getitem__(self, index: slice) -> MutableSequence[TEmbedding]: ...
|
||||
|
||||
Args:
|
||||
guardrails: The list of input guardrails to add.
|
||||
"""
|
||||
...
|
||||
def __getitem__(self, index: int | slice) -> TEmbedding | MutableSequence[TEmbedding]:
|
||||
return self.embeddings[index]
|
||||
|
||||
def add_output_guardrails(self, guardrails: list[OutputGuardrail[TResponse | Sequence[TResponse]]]) -> None:
|
||||
"""Add output guardrails to the model client.
|
||||
@overload
|
||||
def __setitem__(self, index: int, value: TEmbedding) -> None: ...
|
||||
|
||||
Args:
|
||||
guardrails: The list of output guardrails to add.
|
||||
"""
|
||||
...
|
||||
@overload
|
||||
def __setitem__(self, index: slice, value: Iterable[TEmbedding]) -> None: ...
|
||||
|
||||
def __setitem__(self, index: int | slice, value: TEmbedding | Iterable[TEmbedding]) -> None:
|
||||
if isinstance(index, int):
|
||||
if isinstance(value, Iterable):
|
||||
raise TypeError("Value must be an iterable when setting a slice.")
|
||||
self.embeddings[index] = value
|
||||
return
|
||||
if not isinstance(value, Iterable):
|
||||
raise TypeError("Value must be an iterable when setting a slice.")
|
||||
self.embeddings[index] = value
|
||||
|
||||
@overload
|
||||
def __delitem__(self, index: int) -> None: ...
|
||||
|
||||
@overload
|
||||
def __delitem__(self, index: slice) -> None: ...
|
||||
|
||||
def __delitem__(self, index: int | slice) -> None:
|
||||
del self.embeddings[index]
|
||||
|
||||
def insert(self, index: int, value: TEmbedding) -> None:
|
||||
self.embeddings.insert(index, value)
|
||||
|
||||
def append(self, value: TEmbedding) -> None:
|
||||
self.embeddings.append(value)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.embeddings.clear()
|
||||
self.usage = None
|
||||
self.additional_properties = {}
|
||||
|
||||
def reverse(self) -> None:
|
||||
self.embeddings.reverse()
|
||||
|
||||
def extend(self, values: Iterable[TEmbedding]) -> None:
|
||||
self.embeddings.extend(values)
|
||||
|
||||
def pop(self, index: int = -1) -> TEmbedding:
|
||||
return self.embeddings.pop(index)
|
||||
|
||||
def remove(self, value: TEmbedding) -> None:
|
||||
self.embeddings.remove(value)
|
||||
|
||||
def __iadd__(self, values: Iterable[TEmbedding] | Self) -> Self:
|
||||
if isinstance(values, GeneratedEmbeddings):
|
||||
self.embeddings += values.embeddings
|
||||
if not self.usage:
|
||||
self.usage = values.usage
|
||||
else:
|
||||
self.usage += values.usage
|
||||
self.additional_properties.update(values.additional_properties)
|
||||
else:
|
||||
self.embeddings += values
|
||||
return self
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import AsyncIterable, Sequence
|
||||
from typing import Any
|
||||
|
||||
from pytest import fixture
|
||||
|
||||
from agent_framework import (
|
||||
ChatClient,
|
||||
ChatMessage,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
ChatRole,
|
||||
EmbeddingGenerator,
|
||||
GeneratedEmbeddings,
|
||||
TextContent,
|
||||
)
|
||||
|
||||
|
||||
class ImplementedChatClient:
|
||||
"""Simple implementation of a chat client."""
|
||||
|
||||
async def get_response(
|
||||
self,
|
||||
messages: ChatMessage | Sequence[ChatMessage],
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
# Implement the method
|
||||
|
||||
return ChatResponse(messages=ChatMessage(role="assistant", text="test response"))
|
||||
|
||||
async def get_streaming_response(
|
||||
self,
|
||||
messages: ChatMessage | Sequence[ChatMessage],
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[ChatResponseUpdate]:
|
||||
# Implement the method
|
||||
yield ChatResponseUpdate(text=TextContent(text="test streaming response"), role="assistant")
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="another update")], role="assistant")
|
||||
|
||||
|
||||
class ImplementedEmbeddingGenerator:
|
||||
"""Simple implementation of an embedding generator."""
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
input_data: Sequence[str],
|
||||
**kwargs: Any,
|
||||
) -> GeneratedEmbeddings[list[float]]:
|
||||
# Implement the method
|
||||
embeddings = GeneratedEmbeddings[list[float]]()
|
||||
for i, _ in enumerate(input_data):
|
||||
embeddings.append([0.0 * 1, 0.1 * 1, 0.2 * 1, 0.3 * i, 0.4 * i])
|
||||
return embeddings
|
||||
|
||||
|
||||
@fixture
|
||||
def chat_client() -> ImplementedChatClient:
|
||||
return ImplementedChatClient()
|
||||
|
||||
|
||||
@fixture
|
||||
def embedding_generator() -> ImplementedEmbeddingGenerator:
|
||||
gen: EmbeddingGenerator[str, list[float]] = ImplementedEmbeddingGenerator()
|
||||
return gen
|
||||
|
||||
|
||||
def test_chat_client_type(chat_client: ImplementedChatClient):
|
||||
assert isinstance(chat_client, ChatClient)
|
||||
|
||||
|
||||
async def test_chat_client_get_response(chat_client: ImplementedChatClient):
|
||||
response = await chat_client.get_response(ChatMessage(role="user", text="Hello"))
|
||||
assert response.text == "test response"
|
||||
assert response.messages[0].role == ChatRole.ASSISTANT
|
||||
|
||||
|
||||
async def test_chat_client_get_streaming_response(chat_client: ImplementedChatClient):
|
||||
async for update in chat_client.get_streaming_response(ChatMessage(role="user", text="Hello")):
|
||||
assert update.text == "test streaming response" or update.text == "another update"
|
||||
assert update.role == ChatRole.ASSISTANT
|
||||
|
||||
|
||||
def test_embedding_generator_type(embedding_generator: ImplementedEmbeddingGenerator):
|
||||
assert isinstance(embedding_generator, EmbeddingGenerator)
|
||||
|
||||
|
||||
async def test_embedding_generator_generate(embedding_generator: ImplementedEmbeddingGenerator):
|
||||
input_data = ["Hello", "world"]
|
||||
embeddings = await embedding_generator.generate(input_data)
|
||||
assert len(embeddings) == len(input_data)
|
||||
for emb in embeddings:
|
||||
assert len(emb) == 5
|
||||
@@ -1,5 +1,7 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import MutableSequence
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pytest import mark, raises
|
||||
|
||||
@@ -14,6 +16,7 @@ from agent_framework import (
|
||||
DataContent,
|
||||
FunctionCallContent,
|
||||
FunctionResultContent,
|
||||
GeneratedEmbeddings,
|
||||
StructuredResponse,
|
||||
TextContent,
|
||||
TextReasoningContent,
|
||||
@@ -464,3 +467,16 @@ def test_chat_tool_mode_from_dict():
|
||||
|
||||
# Ensure the instance is of type ChatToolMode
|
||||
assert isinstance(mode, ChatToolMode)
|
||||
|
||||
|
||||
def test_generated_embeddings():
|
||||
"""Test the GeneratedEmbeddings class to ensure it initializes correctly."""
|
||||
# Create an instance of GeneratedEmbeddings
|
||||
embeddings = GeneratedEmbeddings(embeddings=[[0.1, 0.2, 0.3]])
|
||||
|
||||
# Check the type and content
|
||||
assert embeddings.embeddings == [[0.1, 0.2, 0.3]]
|
||||
|
||||
# Ensure the instance is of type GeneratedEmbeddings
|
||||
assert isinstance(embeddings, GeneratedEmbeddings)
|
||||
assert issubclass(GeneratedEmbeddings, MutableSequence)
|
||||
|
||||
Generated
+3
-3
@@ -2627,11 +2627,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "typing-extensions"
|
||||
version = "4.14.0"
|
||||
version = "4.14.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d1/bc/51647cd02527e87d05cb083ccc402f93e441606ff1f01739a62c8ad09ba5/typing_extensions-4.14.0.tar.gz", hash = "sha256:8676b788e32f02ab42d9e7c61324048ae4c6d844a399eebace3d4979d75ceef4", size = 107423, upload-time = "2025-06-02T14:52:11.399Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/98/5a/da40306b885cc8c09109dc2e1abd358d5684b1425678151cdaed4731c822/typing_extensions-4.14.1.tar.gz", hash = "sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36", size = 107673, upload-time = "2025-07-04T13:28:34.16Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/69/e0/552843e0d356fbb5256d21449fa957fa4eff3bbc135a74a691ee70c7c5da/typing_extensions-4.14.0-py3-none-any.whl", hash = "sha256:a1514509136dd0b477638fc68d6a91497af5076466ad0fa6c338e44e359944af", size = 43839, upload-time = "2025-06-02T14:52:10.026Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b5/00/d631e67a838026495268c2f6884f3711a15a9a2a96cd244fdaea53b823fb/typing_extensions-4.14.1-py3-none-any.whl", hash = "sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76", size = 43906, upload-time = "2025-07-04T13:28:32.743Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user