# Copyright (c) Microsoft. All rights reserved.

"""Tests for AGUIChatClient."""

import base64
import json
from collections.abc import AsyncGenerator, AsyncIterable, Callable, MutableSequence
from typing import Any

from agent_framework import (
    ChatMessage,
    ChatOptions,
    ChatResponse,
    ChatResponseUpdate,
    Content,
    Role,
    ai_function,
)
from pytest import MonkeyPatch

from agent_framework_ag_ui._client import AGUIChatClient
from agent_framework_ag_ui._http_service import AGUIHttpService

# Endpoint handed to every client under test; no test in this module performs real network I/O
# against it because ``post_run`` is replaced wherever a request would be issued.
_ENDPOINT = "http://localhost:8888/"


def _state_message(payload: str) -> ChatMessage:
    """Build a user message whose only content is ``payload`` as a base64 JSON data URI.

    Args:
        payload: Raw text to embed. It is NOT validated as JSON, so callers can
            deliberately pass malformed content.

    Returns:
        A ``ChatMessage`` with role ``"user"`` and a single URI content item.
    """
    encoded = base64.b64encode(payload.encode("utf-8")).decode("utf-8")
    return ChatMessage(
        role="user",
        contents=[Content.from_uri(uri=f"data:application/json;base64,{encoded}")],
    )


def _make_post_run(
    events: list[dict[str, Any]],
    check_kwargs: Callable[[dict[str, Any]], None] | None = None,
) -> Callable[..., AsyncGenerator[dict[str, Any], None]]:
    """Create an async-generator stand-in for ``AGUIHttpService.post_run``.

    Args:
        events: Events yielded, in order, each time the stand-in is iterated.
        check_kwargs: Optional hook that receives the keyword arguments of the call.
            It runs inside the generator body, i.e. when iteration starts, before the
            first event is yielded.

    Returns:
        A callable accepting arbitrary arguments and returning an async generator of events.
    """

    async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]:
        if check_kwargs is not None:
            check_kwargs(kwargs)
        for event in events:
            yield event

    return mock_post_run


class TestableAGUIChatClient(AGUIChatClient):
    """Testable wrapper exposing protected helpers."""

    @property
    def http_service(self) -> AGUIHttpService:
        """Expose http service for monkeypatching."""
        return self._http_service

    def extract_state_from_messages(
        self, messages: list[ChatMessage]
    ) -> tuple[list[ChatMessage], dict[str, Any] | None]:
        """Expose state extraction helper."""
        return self._extract_state_from_messages(messages)

    def convert_messages_to_agui_format(self, messages: list[ChatMessage]) -> list[dict[str, Any]]:
        """Expose message conversion helper."""
        return self._convert_messages_to_agui_format(messages)

    def get_thread_id(self, options: dict[str, Any]) -> str:
        """Expose thread id helper."""
        return self._get_thread_id(options)

    async def inner_get_streaming_response(
        self, *, messages: MutableSequence[ChatMessage], options: dict[str, Any]
    ) -> AsyncIterable[ChatResponseUpdate]:
        """Proxy to protected streaming call."""
        async for update in self._inner_get_streaming_response(messages=messages, options=options):
            yield update

    async def inner_get_response(
        self, *, messages: MutableSequence[ChatMessage], options: dict[str, Any]
    ) -> ChatResponse:
        """Proxy to protected response call."""
        return await self._inner_get_response(messages=messages, options=options)


def _patched_client(
    monkeypatch: MonkeyPatch,
    events: list[dict[str, Any]],
    check_kwargs: Callable[[dict[str, Any]], None] | None = None,
) -> TestableAGUIChatClient:
    """Return a client whose ``http_service.post_run`` replays ``events`` instead of calling out.

    Args:
        monkeypatch: pytest fixture used to install the replacement.
        events: Events the replaced ``post_run`` yields.
        check_kwargs: Optional hook forwarded to ``_make_post_run``.

    Returns:
        A ``TestableAGUIChatClient`` bound to ``_ENDPOINT`` with ``post_run`` replaced.
    """
    client = TestableAGUIChatClient(endpoint=_ENDPOINT)
    monkeypatch.setattr(client.http_service, "post_run", _make_post_run(events, check_kwargs))
    return client


class TestAGUIChatClient:
    """Test suite for AGUIChatClient."""

    async def test_client_initialization(self) -> None:
        """Test client initialization."""
        client = TestableAGUIChatClient(endpoint=_ENDPOINT)

        assert client.http_service is not None
        assert client.http_service.endpoint.startswith("http://localhost:8888")

    async def test_client_context_manager(self) -> None:
        """Test client as async context manager."""
        async with TestableAGUIChatClient(endpoint=_ENDPOINT) as client:
            assert client is not None

    async def test_extract_state_from_messages_no_state(self) -> None:
        """Test state extraction when no state is present."""
        client = TestableAGUIChatClient(endpoint=_ENDPOINT)
        messages = [
            ChatMessage(role="user", text="Hello"),
            ChatMessage(role="assistant", text="Hi there"),
        ]

        result_messages, state = client.extract_state_from_messages(messages)

        assert result_messages == messages
        assert state is None

    async def test_extract_state_from_messages_with_state(self) -> None:
        """Test state extraction from last message."""
        client = TestableAGUIChatClient(endpoint=_ENDPOINT)
        state_data = {"key": "value", "count": 42}
        messages = [
            ChatMessage(role="user", text="Hello"),
            _state_message(json.dumps(state_data)),
        ]

        result_messages, state = client.extract_state_from_messages(messages)

        assert len(result_messages) == 1
        assert result_messages[0].text == "Hello"
        assert state == state_data

    async def test_extract_state_invalid_json(self) -> None:
        """Test state extraction with invalid JSON."""
        client = TestableAGUIChatClient(endpoint=_ENDPOINT)
        messages = [_state_message("not valid json")]

        result_messages, state = client.extract_state_from_messages(messages)

        # The undecodable payload must leave the message list untouched and report no state.
        assert result_messages == messages
        assert state is None

    async def test_convert_messages_to_agui_format(self) -> None:
        """Test message conversion to AG-UI format."""
        client = TestableAGUIChatClient(endpoint=_ENDPOINT)
        messages = [
            ChatMessage(role=Role.USER, text="What is the weather?"),
            ChatMessage(role=Role.ASSISTANT, text="Let me check.", message_id="msg_123"),
        ]

        agui_messages = client.convert_messages_to_agui_format(messages)

        assert len(agui_messages) == 2
        user_entry, assistant_entry = agui_messages
        assert user_entry["role"] == "user"
        assert user_entry["content"] == "What is the weather?"
        assert assistant_entry["role"] == "assistant"
        assert assistant_entry["content"] == "Let me check."
        assert assistant_entry["id"] == "msg_123"

    async def test_get_thread_id_from_metadata(self) -> None:
        """Test thread ID extraction from metadata."""
        client = TestableAGUIChatClient(endpoint=_ENDPOINT)
        chat_options = ChatOptions(metadata={"thread_id": "existing_thread_123"})

        thread_id = client.get_thread_id(chat_options)

        assert thread_id == "existing_thread_123"

    async def test_get_thread_id_generation(self) -> None:
        """Test automatic thread ID generation."""
        client = TestableAGUIChatClient(endpoint=_ENDPOINT)
        chat_options = ChatOptions()

        thread_id = client.get_thread_id(chat_options)

        assert thread_id.startswith("thread_")
        # "thread_" is 7 characters long, so anything longer carries a generated suffix.
        assert len(thread_id) > 7

    async def test_get_streaming_response(self, monkeypatch: MonkeyPatch) -> None:
        """Test streaming response method."""
        mock_events = [
            {"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
            {"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": "Hello"},
            {"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": " world"},
            {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
        ]
        client = _patched_client(monkeypatch, mock_events)
        messages = [ChatMessage(role="user", text="Test message")]
        chat_options = ChatOptions()

        updates: list[ChatResponseUpdate] = [
            update async for update in client.inner_get_streaming_response(messages=messages, options=chat_options)
        ]

        assert len(updates) == 4
        assert updates[0].additional_properties is not None
        assert updates[0].additional_properties["thread_id"] == "thread_1"

        first_content = updates[1].contents[0]
        second_content = updates[2].contents[0]
        assert first_content.type == "text"
        assert second_content.type == "text"
        assert first_content.text == "Hello"
        assert second_content.text == " world"

    async def test_get_response_non_streaming(self, monkeypatch: MonkeyPatch) -> None:
        """Test non-streaming response method."""
        mock_events = [
            {"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
            {"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": "Complete response"},
            {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
        ]
        client = _patched_client(monkeypatch, mock_events)
        messages = [ChatMessage(role="user", text="Test message")]
        # Deliberately a plain empty dict rather than ChatOptions(), as in the original test.
        chat_options: dict[str, Any] = {}

        response = await client.inner_get_response(messages=messages, options=chat_options)

        assert response is not None
        assert len(response.messages) > 0
        assert "Complete response" in response.text

    async def test_tool_handling(self, monkeypatch: MonkeyPatch) -> None:
        """Test that client tool metadata is sent to server.

        Client tool metadata (name, description, schema) is sent to server for planning.
        When server requests a client function, @use_function_invocation decorator
        intercepts and executes it locally. This matches .NET AG-UI implementation.
        """

        @ai_function
        def test_tool(param: str) -> str:
            """Test tool."""
            return "result"

        mock_events = [
            {"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
            {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
        ]

        def check_tools(kwargs: dict[str, Any]) -> None:
            # Client tool metadata should be sent to server
            tools: list[dict[str, Any]] | None = kwargs.get("tools")
            assert tools is not None
            assert len(tools) == 1
            tool_entry = tools[0]
            assert tool_entry["name"] == "test_tool"
            assert tool_entry["description"] == "Test tool."
            assert "parameters" in tool_entry

        client = _patched_client(monkeypatch, mock_events, check_tools)
        messages = [ChatMessage(role="user", text="Test with tools")]
        chat_options = ChatOptions(tools=[test_tool])

        response = await client.inner_get_response(messages=messages, options=chat_options)

        assert response is not None

    async def test_server_tool_calls_unwrapped_after_invocation(self, monkeypatch: MonkeyPatch) -> None:
        """Ensure server-side tool calls are exposed as FunctionCallContent after processing."""
        mock_events = [
            {"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
            {"type": "TOOL_CALL_START", "toolCallId": "call_1", "toolName": "get_time_zone"},
            {"type": "TOOL_CALL_ARGS", "toolCallId": "call_1", "delta": '{"location": "Seattle"}'},
            {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
        ]
        client = _patched_client(monkeypatch, mock_events)
        messages = [ChatMessage(role="user", text="Test server tool execution")]

        updates: list[ChatResponseUpdate] = [update async for update in client.get_streaming_response(messages)]

        function_calls = [
            content for update in updates for content in update.contents if content.type == "function_call"
        ]
        assert function_calls
        assert function_calls[0].name == "get_time_zone"
        assert not any(content.type == "server_function_call" for update in updates for content in update.contents)

    async def test_server_tool_calls_not_executed_locally(self, monkeypatch: MonkeyPatch) -> None:
        """Server tools should not trigger local function invocation even when client tools exist."""

        @ai_function
        def client_tool() -> str:
            """Client tool stub."""
            return "client"

        mock_events = [
            {"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
            {"type": "TOOL_CALL_START", "toolCallId": "call_1", "toolName": "get_time_zone"},
            {"type": "TOOL_CALL_ARGS", "toolCallId": "call_1", "delta": '{"location": "Seattle"}'},
            {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
        ]

        async def fake_auto_invoke(*args: object, **kwargs: Any) -> None:
            # Any call here means the framework tried to run the server's tool locally.
            function_call = kwargs.get("function_call_content") or args[0]
            raise AssertionError(f"Unexpected local execution of server tool: {getattr(function_call, 'name', '?')}")

        monkeypatch.setattr("agent_framework._tools._auto_invoke_function", fake_auto_invoke)
        client = _patched_client(monkeypatch, mock_events)
        messages = [ChatMessage(role="user", text="Test server tool execution")]

        # Draining the stream is the assertion: fake_auto_invoke raises if it is ever reached.
        async for _ in client.get_streaming_response(
            messages, options={"tool_choice": "auto", "tools": [client_tool]}
        ):
            pass

    async def test_state_transmission(self, monkeypatch: MonkeyPatch) -> None:
        """Test state is properly transmitted to server."""
        state_data = {"user_id": "123", "session": "abc"}
        messages = [
            ChatMessage(role="user", text="Hello"),
            _state_message(json.dumps(state_data)),
        ]
        mock_events = [
            {"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
            {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
        ]

        def check_state(kwargs: dict[str, Any]) -> None:
            assert kwargs.get("state") == state_data

        client = _patched_client(monkeypatch, mock_events, check_state)
        chat_options = ChatOptions()

        response = await client.inner_get_response(messages=messages, options=chat_options)

        assert response is not None