mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
27f7af2160
* moved prepare tools into class * moved test * changed tool handling * fix test * second fix
353 lines
12 KiB
Python
353 lines
12 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
import asyncio
|
|
import sys
|
|
from collections.abc import AsyncIterable, MutableSequence, Sequence
|
|
from typing import Any
|
|
|
|
from pydantic import Field
|
|
from pytest import fixture
|
|
|
|
from agent_framework import (
|
|
ChatClient,
|
|
ChatClientBase,
|
|
ChatMessage,
|
|
ChatOptions,
|
|
ChatResponse,
|
|
ChatResponseUpdate,
|
|
ChatRole,
|
|
EmbeddingGenerator,
|
|
FunctionCallContent,
|
|
FunctionResultContent,
|
|
GeneratedEmbeddings,
|
|
TextContent,
|
|
ai_function,
|
|
use_tool_calling,
|
|
)
|
|
|
|
if sys.version_info >= (3, 12):
|
|
from typing import override # type: ignore
|
|
else:
|
|
from typing_extensions import override # type: ignore[import]
|
|
|
|
|
|
class MockChatClient:
    """Minimal chat client that does not inherit from ``ChatClient`` or ``ChatClientBase``.

    Both methods ignore their input and return fixed, canned content.
    """

    async def get_response(
        self,
        messages: ChatMessage | Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        """Return one canned assistant reply whose text is ``"test response"``.

        Args:
            messages: Ignored.
            kwargs: Ignored.
        """
        canned_reply = ChatMessage(role="assistant", text="test response")
        return ChatResponse(messages=canned_reply)

    async def get_streaming_response(
        self,
        messages: ChatMessage | Sequence[ChatMessage],
        **kwargs: Any,
    ) -> AsyncIterable[ChatResponseUpdate]:
        """Yield two canned assistant updates, regardless of ``messages``.

        The first update is built through the ``text=`` argument and the second
        through ``contents=``.
        """
        yield ChatResponseUpdate(text=TextContent(text="test streaming response"), role="assistant")
        second_chunk = [TextContent(text="another update")]
        yield ChatResponseUpdate(contents=second_chunk, role="assistant")
|
|
|
|
|
|
@use_tool_calling
class MockChatClientBase(ChatClientBase):
    """``ChatClientBase`` test double whose replies are scripted per test.

    Tests queue canned results on ``run_responses`` / ``streaming_responses``; each
    request consumes the oldest queued entry. When nothing is queued, or the request's
    ``tool_choice`` equals ``"none"``, the client instead echoes the text of the first
    message it was given (and leaves the queue untouched).
    """

    # Non-streaming replies, consumed front-first by ``_inner_get_response``.
    run_responses: list[ChatResponse] = Field(default_factory=list)
    # Streaming replies (each one a list of updates), consumed front-first by
    # ``_inner_get_streaming_response``.
    streaming_responses: list[list[ChatResponseUpdate]] = Field(default_factory=list)

    @override
    async def _inner_get_response(
        self,
        *,
        messages: MutableSequence[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> ChatResponse:
        """Return the next scripted response, or an echo of the first message.

        Args:
            messages: The chat messages to send; only ``messages[0].text`` is read (for the echo).
            chat_options: The options for the request; only ``tool_choice`` is inspected.
            kwargs: Ignored.

        Returns:
            The oldest queued ``ChatResponse``, or an assistant message reading
            ``"test response - <first message text>"`` when the queue is empty or
            ``tool_choice == "none"``.
        """
        nothing_queued = not self.run_responses
        # NOTE(review): presumably the tool-calling wrapper sends tool_choice="none" when it
        # stops iterating (the *_disabled tests rely on getting the echo then) — confirm.
        if nothing_queued or chat_options.tool_choice == "none":
            echo = ChatMessage(role="assistant", text=f"test response - {messages[0].text}")
            return ChatResponse(messages=echo)
        return self.run_responses.pop(0)

    @override
    async def _inner_get_streaming_response(
        self,
        *,
        messages: MutableSequence[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> AsyncIterable[ChatResponseUpdate]:
        """Stream the next scripted list of updates, or a single echo update.

        Args:
            messages: The chat messages to send; only ``messages[0].text`` is read (for the echo).
            chat_options: The options for the request; only ``tool_choice`` is inspected.
            kwargs: Ignored.

        Yields:
            Each update of the oldest queued stream in order, or one update reading
            ``"update - <first message text>"`` when no stream is queued or
            ``tool_choice == "none"``.
        """
        nothing_queued = not self.streaming_responses
        if nothing_queued or chat_options.tool_choice == "none":
            yield ChatResponseUpdate(text=f"update - {messages[0].text}", role="assistant")
            return
        for chunk in self.streaming_responses.pop(0):
            yield chunk
            # Hand control back to the event loop between chunks so the stream is consumed incrementally.
            await asyncio.sleep(0)
|
|
|
|
|
|
class MockEmbeddingGenerator:
    """Simple implementation of an embedding generator.

    Produces one deterministic five-component vector per input item. The vector
    depends only on the item's position in ``input_data``, not on its text.
    """

    async def generate(
        self,
        input_data: Sequence[str],
        **kwargs: Any,
    ) -> GeneratedEmbeddings[list[float]]:
        """Return one fake embedding per entry in ``input_data``.

        Args:
            input_data: The strings to "embed"; only their count and order matter.
            kwargs: Ignored.

        Returns:
            A ``GeneratedEmbeddings`` holding ``len(input_data)`` vectors of length 5.
            The vector for the item at index ``i`` is ``[0.0, 0.1, 0.2, 0.3, 0.4]``
            with every component multiplied by ``i``.
        """
        embeddings = GeneratedEmbeddings[list[float]]()
        for i, _ in enumerate(input_data):
            # Scale every component by the index. A ``* 1`` factor would be a no-op and
            # leave some components independent of the item's position.
            embeddings.append([0.0 * i, 0.1 * i, 0.2 * i, 0.3 * i, 0.4 * i])
        return embeddings
|
|
|
|
|
|
@fixture
def chat_client() -> MockChatClient:
    """Provide a new ``MockChatClient`` (default function scope: one per test)."""
    client = MockChatClient()
    return client
|
|
|
|
|
|
@fixture
def chat_client_base() -> MockChatClientBase:
    """Provide a new ``MockChatClientBase`` with empty response queues for each test."""
    client = MockChatClientBase()
    return client
|
|
|
|
|
|
@fixture
def embedding_generator() -> MockEmbeddingGenerator:
    """Provide a new ``MockEmbeddingGenerator`` for each test."""
    # The annotation lets a static type checker confirm the mock matches EmbeddingGenerator[str, list[float]].
    generator: EmbeddingGenerator[str, list[float]] = MockEmbeddingGenerator()
    return generator
|
|
|
|
|
|
def test_chat_client_type(chat_client: MockChatClient):
    """MockChatClient is recognised as a ChatClient even though it does not inherit from it."""
    is_chat_client = isinstance(chat_client, ChatClient)
    assert is_chat_client
|
|
|
|
|
|
async def test_chat_client_get_response(chat_client: MockChatClient):
    """get_response returns the canned assistant reply."""
    prompt = ChatMessage(role="user", text="Hello")
    reply = await chat_client.get_response(prompt)
    assert reply.text == "test response"
    first_message = reply.messages[0]
    assert first_message.role == ChatRole.ASSISTANT
|
|
|
|
|
|
async def test_chat_client_get_streaming_response(chat_client: MockChatClient):
    """Streaming from MockChatClient yields both canned assistant updates, in order."""
    prompt = ChatMessage(role="user", text="Hello")
    # Collect first, then assert: with the asserts inside the loop an empty stream passed vacuously.
    updates = [update async for update in chat_client.get_streaming_response(prompt)]
    assert [update.text for update in updates] == ["test streaming response", "another update"]
    assert all(update.role == ChatRole.ASSISTANT for update in updates)
|
|
|
|
|
|
def test_embedding_generator_type(embedding_generator: MockEmbeddingGenerator):
    """MockEmbeddingGenerator is recognised as an EmbeddingGenerator at runtime."""
    is_generator = isinstance(embedding_generator, EmbeddingGenerator)
    assert is_generator
|
|
|
|
|
|
async def test_embedding_generator_generate(embedding_generator: MockEmbeddingGenerator):
    """generate returns one five-component vector per input string."""
    words = ["Hello", "world"]
    result = await embedding_generator.generate(words)
    assert len(result) == len(words)
    assert all(len(vector) == 5 for vector in result)
|
|
|
|
|
|
def test_base_client(chat_client_base: MockChatClientBase):
    """The base mock is both a ChatClientBase subclass instance and a ChatClient."""
    for expected_type in (ChatClientBase, ChatClient):
        assert isinstance(chat_client_base, expected_type)
|
|
|
|
|
|
async def test_base_client_get_response(chat_client_base: MockChatClientBase):
    """With nothing queued, the base mock echoes the first message's text."""
    prompt = ChatMessage(role="user", text="Hello")
    response = await chat_client_base.get_response(prompt)
    first_message = response.messages[0]
    assert first_message.role == ChatRole.ASSISTANT
    assert first_message.text == "test response - Hello"
|
|
|
|
|
|
async def test_base_client_get_streaming_response(chat_client_base: MockChatClientBase):
    """With no stream queued, the base mock streams an echo of the first message's text."""
    prompt = ChatMessage(role="user", text="Hello")
    updates = [update async for update in chat_client_base.get_streaming_response(prompt)]
    # An empty stream must fail instead of passing vacuously.
    assert updates
    # Only the echo text is acceptable: "another update" is produced by MockChatClient,
    # never by MockChatClientBase with an empty streaming_responses queue.
    assert all(update.text == "update - Hello" for update in updates)
|
|
|
|
|
|
async def test_base_client_with_function_calling(chat_client_base: MockChatClientBase):
    """A scripted function call is executed once and followed by the scripted final answer."""
    calls = 0

    @ai_function(name="test_function")
    def ai_func(arg1: str) -> str:
        nonlocal calls
        calls += 1
        return f"Processed {arg1}"

    function_call = FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1": "value1"}')
    chat_client_base.run_responses = [
        ChatResponse(messages=ChatMessage(role="assistant", contents=[function_call])),
        ChatResponse(messages=ChatMessage(role="assistant", text="done")),
    ]

    response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[ai_func])

    assert calls == 1
    assert len(response.messages) == 3
    call_message, tool_message, final_message = response.messages

    # 1) the assistant's function call, passed through unchanged
    assert call_message.role == ChatRole.ASSISTANT
    call = call_message.contents[0]
    assert isinstance(call, FunctionCallContent)
    assert call.name == "test_function"
    assert call.arguments == '{"arg1": "value1"}'
    assert call.call_id == "1"

    # 2) the tool message carrying the function's result for the same call id
    assert tool_message.role == ChatRole.TOOL
    result = tool_message.contents[0]
    assert isinstance(result, FunctionResultContent)
    assert result.call_id == "1"
    assert result.result == "Processed value1"

    # 3) the assistant's final answer
    assert final_message.role == ChatRole.ASSISTANT
    assert final_message.text == "done"
|
|
|
|
|
|
async def test_base_client_with_function_calling_disabled(chat_client_base: MockChatClientBase):
    """With the iteration cap set to zero, the scripted function call is never executed."""
    # Name mangling only happens inside class bodies, so in this module-level function the
    # instance attribute set here is literally named "__maximum_iterations_per_request".
    # NOTE(review): presumably that exact name is what @use_tool_calling reads — confirm.
    chat_client_base.__maximum_iterations_per_request = 0
    calls = 0

    @ai_function(name="test_function")
    def ai_func(arg1: str) -> str:
        nonlocal calls
        calls += 1
        return f"Processed {arg1}"

    function_call = FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1": "value1"}')
    chat_client_base.run_responses = [
        ChatResponse(messages=ChatMessage(role="assistant", contents=[function_call])),
        ChatResponse(messages=ChatMessage(role="assistant", text="done")),
    ]

    response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[ai_func])

    assert calls == 0
    assert len(response.messages) == 1
    only_message = response.messages[0]
    assert only_message.role == ChatRole.ASSISTANT
    assert only_message.text == "test response - hello"
|
|
|
|
|
|
async def test_base_client_with_streaming_function_calling(chat_client_base: MockChatClientBase):
    """A function call streamed as two argument fragments triggers exactly one tool execution."""
    calls = 0

    @ai_function(name="test_function")
    def ai_func(arg1: str) -> str:
        nonlocal calls
        calls += 1
        return f"Processed {arg1}"

    def call_chunk(arguments: str) -> ChatResponseUpdate:
        # One streamed fragment of the call with id "1"; concatenated, the fragments form valid JSON.
        return ChatResponseUpdate(
            contents=[FunctionCallContent(call_id="1", name="test_function", arguments=arguments)],
            role="assistant",
        )

    chat_client_base.streaming_responses = [
        [call_chunk('{"arg1":'), call_chunk('"value1"}')],
        [ChatResponseUpdate(contents=[TextContent(text="Processed value1")], role="assistant")],
    ]

    stream = chat_client_base.get_streaming_response("hello", tool_choice="auto", tools=[ai_func])
    updates = [update async for update in stream]

    assert len(updates) == 4  # two updates with the function call, the function result and the final text
    # The first three updates (both call fragments and the function result) all refer to call id "1".
    assert all(update.contents[0].call_id == "1" for update in updates[:3])
    assert updates[3].text == "Processed value1"
    assert calls == 1
|
|
|
|
|
|
async def test_base_client_with_streaming_function_calling_disabled(chat_client_base: MockChatClientBase):
    """With the iteration cap set to zero, streaming yields a single update and runs no tool."""
    # Not name-mangled (module-level function), so this sets an attribute literally named
    # "__maximum_iterations_per_request". NOTE(review): presumably read by @use_tool_calling — confirm.
    chat_client_base.__maximum_iterations_per_request = 0
    calls = 0

    @ai_function(name="test_function")
    def ai_func(arg1: str) -> str:
        nonlocal calls
        calls += 1
        return f"Processed {arg1}"

    def call_chunk(arguments: str) -> ChatResponseUpdate:
        # One streamed fragment of the call with id "1".
        return ChatResponseUpdate(
            contents=[FunctionCallContent(call_id="1", name="test_function", arguments=arguments)],
            role="assistant",
        )

    chat_client_base.streaming_responses = [
        [call_chunk('{"arg1":'), call_chunk('"value1"}')],
        [ChatResponseUpdate(contents=[TextContent(text="Processed value1")], role="assistant")],
    ]

    stream = chat_client_base.get_streaming_response("hello", tool_choice="auto", tools=[ai_func])
    updates = [update async for update in stream]

    assert len(updates) == 1
    assert calls == 0
|
|
|
|
|
|
def test_chat_options_parsing_tools(chat_client_base, ai_function_tool) -> None:
    """ChatOptions accepts mixed tool kinds; after preparation ``tools`` holds dict specs.

    The originally supplied tools remain reachable through ``_ai_tools``.
    """

    def echo() -> str:
        """Echo the input."""
        return "Echo"

    weather_tool_spec = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Retrieves current weather for the given location.",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "City and country e.g. Bogotá, Colombia"},
                    "units": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "Units the temperature will be returned in.",
                    },
                },
                "required": ["location", "units"],
                "additionalProperties": False,
            },
            "strict": True,
        },
    }

    options = ChatOptions(tools=[ai_function_tool, echo, weather_tool_spec], tool_choice="auto")
    assert len(options.tools) == 3
    first, second, third = options.tools
    assert first == ai_function_tool
    # The bare ``echo`` callable is not stored as-is.
    assert second != echo
    assert third == weather_tool_spec

    # Preparing turns every entry of ``tools`` into a dict spec, while ``_ai_tools`` keeps the originals.
    chat_client_base._prepare_tools_and_tool_choice(chat_options=options)
    assert options._ai_tools[0] == ai_function_tool
    assert options._ai_tools[2] == weather_tool_spec
    assert len(options.tools) == 3
    prepared_names = [tool["function"]["name"] for tool in options.tools]
    assert prepared_names == ["simple_function", "echo", "get_weather"]
|