mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
32bd884bfd
* Add concrete AGUIChatClient * Update logging docstrings and conventions * PR feedback * Updates to support client-side tool calls
318 lines
12 KiB
Python
318 lines
12 KiB
Python
"""Tests for AGUIChatClient."""
|
|
|
|
import base64
import json

from agent_framework import ChatMessage, ChatOptions, DataContent, FunctionCallContent, Role, ai_function
from agent_framework_ag_ui._client import AGUIChatClient, ServerFunctionCallContent
|
|
|
|
|
|
def _replay_events(events, recorded_calls=None):
    """Build an async-generator stand-in for the HTTP service's ``post_run``.

    Args:
        events: AG-UI event dicts to yield, in order, whenever the stub is iterated.
        recorded_calls: Optional list. The keyword arguments of each invocation are
            appended to it once iteration starts, so a test can inspect what the
            client sent *after* the client call has returned.

    Returns:
        An async generator function that accepts arbitrary positional and keyword
        arguments and yields every item of ``events``.
    """

    async def _post_run(*args, **kwargs):
        # An async generator body only runs once it is iterated, so an entry in
        # ``recorded_calls`` proves the client actually consumed the stream.
        if recorded_calls is not None:
            recorded_calls.append(kwargs)
        for event in events:
            yield event

    return _post_run


def _state_message(payload):
    """Return a user ``ChatMessage`` carrying ``payload`` as a base64 JSON data URI.

    Args:
        payload: Text to embed; it is UTF-8 encoded and base64 encoded verbatim, so it
            does not have to be valid JSON.
    """
    encoded = base64.b64encode(payload.encode("utf-8")).decode("utf-8")
    return ChatMessage(role="user", contents=[DataContent(uri=f"data:application/json;base64,{encoded}")])


class TestAGUIChatClient:
    """Test suite for AGUIChatClient."""

    async def test_client_initialization(self) -> None:
        """Test client initialization."""
        client = AGUIChatClient(endpoint="http://localhost:8888/")

        assert client._http_service is not None
        assert client._http_service.endpoint.startswith("http://localhost:8888")

    async def test_client_context_manager(self) -> None:
        """Test client as async context manager."""
        async with AGUIChatClient(endpoint="http://localhost:8888/") as client:
            assert client is not None

    async def test_extract_state_from_messages_no_state(self) -> None:
        """Test state extraction when no state is present."""
        client = AGUIChatClient(endpoint="http://localhost:8888/")
        messages = [
            ChatMessage(role="user", text="Hello"),
            ChatMessage(role="assistant", text="Hi there"),
        ]

        result_messages, state = client._extract_state_from_messages(messages)

        assert result_messages == messages
        assert state is None

    async def test_extract_state_from_messages_with_state(self) -> None:
        """Test state extraction from last message."""
        client = AGUIChatClient(endpoint="http://localhost:8888/")
        state_data = {"key": "value", "count": 42}
        messages = [
            ChatMessage(role="user", text="Hello"),
            _state_message(json.dumps(state_data)),
        ]

        result_messages, state = client._extract_state_from_messages(messages)

        # The trailing state-carrying message is expected to be stripped from the result.
        assert len(result_messages) == 1
        assert result_messages[0].text == "Hello"
        assert state == state_data

    async def test_extract_state_invalid_json(self) -> None:
        """Test state extraction with invalid JSON."""
        client = AGUIChatClient(endpoint="http://localhost:8888/")
        messages = [_state_message("not valid json")]

        result_messages, state = client._extract_state_from_messages(messages)

        # An undecodable payload must leave the messages untouched and yield no state.
        assert result_messages == messages
        assert state is None

    async def test_convert_messages_to_agui_format(self) -> None:
        """Test message conversion to AG-UI format."""
        client = AGUIChatClient(endpoint="http://localhost:8888/")
        messages = [
            ChatMessage(role=Role.USER, text="What is the weather?"),
            ChatMessage(role=Role.ASSISTANT, text="Let me check.", message_id="msg_123"),
        ]

        agui_messages = client._convert_messages_to_agui_format(messages)

        assert len(agui_messages) == 2
        user_message, assistant_message = agui_messages
        assert user_message["role"] == "user"
        assert user_message["content"] == "What is the weather?"
        assert assistant_message["role"] == "assistant"
        assert assistant_message["content"] == "Let me check."
        assert assistant_message["id"] == "msg_123"

    async def test_get_thread_id_from_metadata(self) -> None:
        """Test thread ID extraction from metadata."""
        client = AGUIChatClient(endpoint="http://localhost:8888/")
        chat_options = ChatOptions(metadata={"thread_id": "existing_thread_123"})

        thread_id = client._get_thread_id(chat_options)

        assert thread_id == "existing_thread_123"

    async def test_get_thread_id_generation(self) -> None:
        """Test automatic thread ID generation."""
        client = AGUIChatClient(endpoint="http://localhost:8888/")
        chat_options = ChatOptions()

        thread_id = client._get_thread_id(chat_options)

        assert thread_id.startswith("thread_")
        # "thread_" is 7 characters long, so something must follow the prefix.
        assert len(thread_id) > 7

    async def test_get_streaming_response(self, monkeypatch) -> None:
        """Test streaming response method."""
        mock_events = [
            {"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
            {"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": "Hello"},
            {"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": " world"},
            {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
        ]

        client = AGUIChatClient(endpoint="http://localhost:8888/")
        monkeypatch.setattr(client._http_service, "post_run", _replay_events(mock_events))

        messages = [ChatMessage(role="user", text="Test message")]
        chat_options = ChatOptions()

        updates = [
            update
            async for update in client._inner_get_streaming_response(messages=messages, chat_options=chat_options)
        ]

        assert len(updates) == 4
        assert updates[0].additional_properties["thread_id"] == "thread_1"
        assert updates[1].contents[0].text == "Hello"
        assert updates[2].contents[0].text == " world"

    async def test_get_response_non_streaming(self, monkeypatch) -> None:
        """Test non-streaming response method."""
        mock_events = [
            {"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
            {"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": "Complete response"},
            {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
        ]

        client = AGUIChatClient(endpoint="http://localhost:8888/")
        monkeypatch.setattr(client._http_service, "post_run", _replay_events(mock_events))

        messages = [ChatMessage(role="user", text="Test message")]
        chat_options = ChatOptions()

        response = await client._inner_get_response(messages=messages, chat_options=chat_options)

        assert response is not None
        assert len(response.messages) > 0
        assert "Complete response" in response.text

    async def test_tool_handling(self, monkeypatch) -> None:
        """Test that client tool metadata is sent to server.

        Client tool metadata (name, description, schema) is sent to server for planning.
        When server requests a client function, @use_function_invocation decorator
        intercepts and executes it locally. This matches .NET AG-UI implementation.
        """

        @ai_function
        def test_tool(param: str) -> str:
            """Test tool."""
            return "result"

        mock_events = [
            {"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
            {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
        ]
        recorded_calls: list = []

        client = AGUIChatClient(endpoint="http://localhost:8888/")
        monkeypatch.setattr(client._http_service, "post_run", _replay_events(mock_events, recorded_calls))

        messages = [ChatMessage(role="user", text="Test with tools")]
        chat_options = ChatOptions(tools=[test_tool])

        response = await client._inner_get_response(messages=messages, chat_options=chat_options)

        assert response is not None
        # Checking after the call (rather than inside the stub) makes the test fail
        # if the client never reaches post_run, instead of passing vacuously.
        assert recorded_calls, "post_run was never invoked"
        for call_kwargs in recorded_calls:
            # Client tool metadata should be sent to server
            tools = call_kwargs.get("tools")
            assert tools is not None
            assert len(tools) == 1
            assert tools[0]["name"] == "test_tool"
            assert tools[0]["description"] == "Test tool."
            assert "parameters" in tools[0]

    async def test_server_tool_calls_unwrapped_after_invocation(self, monkeypatch) -> None:
        """Ensure server-side tool calls are exposed as FunctionCallContent after processing."""
        mock_events = [
            {"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
            {"type": "TOOL_CALL_START", "toolCallId": "call_1", "toolName": "get_time_zone"},
            {"type": "TOOL_CALL_ARGS", "toolCallId": "call_1", "delta": '{"location": "Seattle"}'},
            {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
        ]

        client = AGUIChatClient(endpoint="http://localhost:8888/")
        monkeypatch.setattr(client._http_service, "post_run", _replay_events(mock_events))

        messages = [ChatMessage(role="user", text="Test server tool execution")]
        chat_options = ChatOptions()

        updates = [update async for update in client.get_streaming_response(messages, chat_options=chat_options)]
        all_contents = [content for update in updates for content in update.contents]

        function_calls = [content for content in all_contents if isinstance(content, FunctionCallContent)]
        assert function_calls
        assert function_calls[0].name == "get_time_zone"
        # The server-specific wrapper type must not leak out of the public streaming API.
        assert not any(isinstance(content, ServerFunctionCallContent) for content in all_contents)

    async def test_server_tool_calls_not_executed_locally(self, monkeypatch) -> None:
        """Server tools should not trigger local function invocation even when client tools exist."""

        @ai_function
        def client_tool() -> str:
            """Client tool stub."""
            return "client"

        mock_events = [
            {"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
            {"type": "TOOL_CALL_START", "toolCallId": "call_1", "toolName": "get_time_zone"},
            {"type": "TOOL_CALL_ARGS", "toolCallId": "call_1", "delta": '{"location": "Seattle"}'},
            {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
        ]

        async def fake_auto_invoke(*args, **kwargs):
            # Any call here means the framework tried to run the server's tool locally.
            function_call = kwargs.get("function_call_content") or args[0]
            raise AssertionError(f"Unexpected local execution of server tool: {getattr(function_call, 'name', '?')}")

        monkeypatch.setattr("agent_framework._tools._auto_invoke_function", fake_auto_invoke)

        client = AGUIChatClient(endpoint="http://localhost:8888/")
        monkeypatch.setattr(client._http_service, "post_run", _replay_events(mock_events))

        messages = [ChatMessage(role="user", text="Test server tool execution")]
        chat_options = ChatOptions(tool_choice="auto", tools=[client_tool])

        # Draining the stream is the whole test: it passes as long as fake_auto_invoke never fires.
        async for _ in client.get_streaming_response(messages, chat_options=chat_options):
            pass

    async def test_state_transmission(self, monkeypatch) -> None:
        """Test state is properly transmitted to server."""
        state_data = {"user_id": "123", "session": "abc"}
        messages = [
            ChatMessage(role="user", text="Hello"),
            _state_message(json.dumps(state_data)),
        ]

        mock_events = [
            {"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
            {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
        ]
        recorded_calls: list = []

        client = AGUIChatClient(endpoint="http://localhost:8888/")
        monkeypatch.setattr(client._http_service, "post_run", _replay_events(mock_events, recorded_calls))

        chat_options = ChatOptions()

        response = await client._inner_get_response(messages=messages, chat_options=chat_options)

        assert response is not None
        # Verify after the call so a client that never calls post_run cannot pass silently.
        assert recorded_calls, "post_run was never invoked"
        for call_kwargs in recorded_calls:
            assert call_kwargs.get("state") == state_data