From 94e00bd49ae232ddd35a6472004847f835001ccf Mon Sep 17 00:00:00 2001 From: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Date: Fri, 11 Jul 2025 11:09:09 -0700 Subject: [PATCH] Python: Added ChatClientAgentThread and ChatClientAgent implementations (#150) * Added ChatClientAgentThread * Initial version of ChatClientAgent * Completed ChatClientAgent * Small fixes and unit tests * Fixes based on pre-commit * Small fixes * Small renaming * Small improvement * Small fixes * Addressed PR feedback * Small fix * Added method for AgentRunResponse from streaming conversion * Addressed PR feedback * Addressed PR feedback * Addressed PR feedback * Small fix * More fixes --- .../packages/main/agent_framework/__init__.py | 3 + .../main/agent_framework/__init__.pyi | 5 +- .../packages/main/agent_framework/_agents.py | 334 ++++++++++++++++-- .../packages/main/agent_framework/_clients.py | 6 +- .../packages/main/agent_framework/_types.py | 11 + .../packages/main/tests/unit/test_agents.py | 262 +++++++++++--- 6 files changed, 530 insertions(+), 91 deletions(-) diff --git a/python/packages/main/agent_framework/__init__.py b/python/packages/main/agent_framework/__init__.py index 408fa83351..822b65795f 100644 --- a/python/packages/main/agent_framework/__init__.py +++ b/python/packages/main/agent_framework/__init__.py @@ -21,6 +21,9 @@ _IMPORTS = { "ai_function": "._tools", "AIContent": "._types", "AIContents": "._types", + "ChatClientAgent": "._agents", + "ChatClientAgentThread": "._agents", + "ChatClientAgentThreadType": "._agents", "TextContent": "._types", "TextReasoningContent": "._types", "DataContent": "._types", diff --git a/python/packages/main/agent_framework/__init__.pyi b/python/packages/main/agent_framework/__init__.pyi index 834a7524e1..963940e966 100644 --- a/python/packages/main/agent_framework/__init__.pyi +++ b/python/packages/main/agent_framework/__init__.pyi @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. from . import __version__ # type: ignore[attr-defined] -from ._agents import Agent, AgentThread +from ._agents import Agent, AgentThread, ChatClientAgent, ChatClientAgentThread, ChatClientAgentThreadType from ._clients import ChatClient, ChatClientBase, EmbeddingGenerator, use_tool_calling from ._logging import get_logger from ._pydantic import AFBaseModel, AFBaseSettings @@ -45,6 +45,9 @@ __all__ = [ "AgentRunResponseUpdate", "AgentThread", "ChatClient", + "ChatClientAgent", + "ChatClientAgentThread", + "ChatClientAgentThreadType", "ChatClientBase", "ChatFinishReason", "ChatMessage", diff --git a/python/packages/main/agent_framework/_agents.py b/python/packages/main/agent_framework/_agents.py index a74b1cb920..0606d4ed4c 100644 --- a/python/packages/main/agent_framework/_agents.py +++ b/python/packages/main/agent_framework/_agents.py @@ -1,11 +1,18 @@ # Copyright (c) Microsoft. All rights reserved. -from abc import abstractmethod -from collections.abc import AsyncIterable, Sequence -from typing import Any, Protocol, runtime_checkable +from collections.abc import AsyncIterable, Callable, Sequence +from enum import Enum +from typing import Any, Protocol, TypeVar, runtime_checkable +from uuid import uuid4 +from pydantic import Field + +from ._clients import ChatClient from ._pydantic import AFBaseModel -from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage +from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage, ChatResponse, ChatResponseUpdate, ChatRole +from .exceptions import AgentExecutionException + +TThreadType = TypeVar("TThreadType", bound="AgentThread") # region AgentThread @@ -15,48 +22,28 @@ class AgentThread(AFBaseModel): id: str | None = None - async def create(self) -> str | None: - """Starts the thread and returns the thread ID.""" - # If the thread ID is already set, we're done, just return the Id. - if self.id is not None: - return self.id - - # Otherwise, create the thread. - self.id = await self._create() - return self.id - - async def delete(self) -> None: - """Ends the current thread.""" - await self._delete() - self.id = None - - async def on_new_message( + async def on_new_messages( self, new_messages: ChatMessage | Sequence[ChatMessage], ) -> None: """Invoked when a new message has been contributed to the chat by any participant.""" - # If the thread is not created yet, create it. - if self.id is None: - await self.create() + await self._on_new_messages(new_messages=new_messages) - await self._on_new_message(new_messages=new_messages) - - @abstractmethod - async def _create(self) -> str: - """Starts the thread and returns the thread ID.""" - ... - - @abstractmethod - async def _delete(self) -> None: - """Ends the current thread.""" - ... - - @abstractmethod - async def _on_new_message( + async def _on_new_messages( self, new_messages: ChatMessage | Sequence[ChatMessage], ) -> None: """Invoked when a new message has been contributed to the chat by any participant.""" + pass + + +# region MessagesRetrievableThread + + +@runtime_checkable +class MessagesRetrievableThread(Protocol): + def get_messages(self) -> AsyncIterable[ChatMessage]: + """Asynchronously retrieves all messages from thread.""" ... @@ -84,7 +71,7 @@ class Agent(Protocol): async def run( self, - messages: str | ChatMessage | list[ChatMessage] | None = None, + messages: str | ChatMessage | list[str | ChatMessage] | None = None, *, thread: AgentThread | None = None, **kwargs: Any, @@ -113,7 +100,7 @@ class Agent(Protocol): def run_stream( self, - messages: str | ChatMessage | list[ChatMessage] | None = None, + messages: str | ChatMessage | list[str | ChatMessage] | None = None, *, thread: AgentThread | None = None, **kwargs: Any, @@ -138,3 +125,272 @@ class Agent(Protocol): def get_new_thread(self) -> AgentThread: """Creates a new conversation thread for the agent.""" ... + + +# region AgentBase + + +class AgentBase(AFBaseModel): + """Base class for all agents. + + Attributes: + id: The unique identifier of the agent If no id is provided, + a new UUID will be generated. + name: The name of the agent + description: The description of the agent + """ + + id: str = Field(default_factory=lambda: str(uuid4())) + name: str = Field(default="UnnamedAgent") + description: str | None = None + + async def _notify_thread_of_new_messages( + self, thread: AgentThread, new_messages: ChatMessage | Sequence[ChatMessage] + ) -> None: + """Notify the thread of new messages.""" + if isinstance(new_messages, ChatMessage) or len(new_messages) > 0: + await thread.on_new_messages(new_messages) + + +# region ChatClientAgentThread + + +class ChatClientAgentThreadType(Enum): + """Defines the different supported storage locations for ChatClientAgentThread.""" + + IN_MEMORY_MESSAGES = "InMemoryMessages" + """Messages are stored in memory inside the thread object.""" + + CONVERSATION_ID = "ConversationId" + """Messages are stored in the service and the thread object just has an id reference to the service storage.""" + + +class ChatClientAgentThread(AgentThread): + """Chat client agent thread. + + This class manages chat threads either locally (in-memory) or via a service based on initialization. + """ + + chat_messages: list[ChatMessage] | None = None + storage_location: ChatClientAgentThreadType | None = None + + def __init__( + self, + id: str | None = None, + messages: Sequence[ChatMessage] | None = None, + **kwargs: Any, + ): + """Initialize the chat client agent thread. + + Args: + id: Service thread identifier. If provided, thread is managed by the service and messages are + not stored locally. Must not be empty or whitespace. + messages: Initial messages for local storage. If provided, thread is managed + locally in-memory. + kwargs: Additional keyword arguments. + + Raises: + ValueError: If both id and messages are provided, or if id is empty/whitespace. + + Notes: + - If id is set, _id is assigned and _chat_messages is None (service-managed). + - If messages is set, _chat_messages is populated and _id is None (local). + - If neither is provided, creates an empty local thread. + """ + processed_messages: list[ChatMessage] | None = None + storage_location: ChatClientAgentThreadType | None = None + + if id and messages: + raise ValueError("Cannot specify both id and messages") + + if id: + if not id.strip(): + raise ValueError("ID cannot be empty or whitespace") + storage_location = ChatClientAgentThreadType.CONVERSATION_ID + elif messages: + processed_messages = [] + processed_messages.extend(messages) + storage_location = ChatClientAgentThreadType.IN_MEMORY_MESSAGES + + super().__init__( + id=id, + chat_messages=processed_messages, # type: ignore[reportCallIssue] + storage_location=storage_location, # type: ignore[reportCallIssue] + **kwargs, + ) + + async def get_messages(self) -> AsyncIterable[ChatMessage]: + """Get all messages in the thread.""" + for message in self.chat_messages or []: + yield message + + async def _on_new_messages(self, new_messages: ChatMessage | Sequence[ChatMessage]) -> None: + """Handle new messages.""" + if self.storage_location == ChatClientAgentThreadType.IN_MEMORY_MESSAGES: + if self.chat_messages is None: + self.chat_messages = [] + self.chat_messages.extend([new_messages] if isinstance(new_messages, ChatMessage) else new_messages) + + +# region ChatClientAgent + + +class ChatClientAgent(AgentBase): + """A Chat Client Agent which depends on ChatClient.""" + + chat_client: ChatClient + + async def run( + self, + messages: str | ChatMessage | list[str | ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AgentRunResponse: + thread, thread_messages = await self._prepare_thread_and_messages( + thread=thread, + input_messages=messages, + construct_thread=lambda: ChatClientAgentThread(), + expected_type=ChatClientAgentThread, + ) + + response = await self.chat_client.get_response(thread_messages, **kwargs) + + self._update_thread_with_type_and_conversation_id(thread, response.conversation_id) + + # Only notify the thread of new messages if the chatResponse was successful + # to avoid inconsistent messages state in the thread. + await self._notify_thread_of_new_messages(thread, thread_messages) + await self._notify_thread_of_new_messages(thread, response.messages) + + return AgentRunResponse( + messages=response.messages, + response_id=response.response_id, + created_at=response.created_at, + usage_details=response.usage_details, + raw_representation=response.raw_representation, + additional_properties=response.additional_properties, + ) + + async def run_stream( + self, + messages: str | ChatMessage | list[str | ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AsyncIterable[AgentRunResponseUpdate]: + thread, thread_messages = await self._prepare_thread_and_messages( + thread=thread, + input_messages=messages, + construct_thread=lambda: ChatClientAgentThread(), + expected_type=ChatClientAgentThread, + ) + + response_updates: list[ChatResponseUpdate] = [] + + streaming_response: AsyncIterable[ChatResponseUpdate] = self.chat_client.get_streaming_response(thread_messages) + + async for update in streaming_response: + response_updates.append(update) + yield AgentRunResponseUpdate( + contents=update.contents, + role=update.role, + author_name=update.author_name, + response_id=update.response_id, + message_id=update.message_id, + created_at=update.created_at, + additional_properties=update.additional_properties, + raw_representation=update.raw_representation, + ) + + response = ChatResponse.from_chat_response_updates(response_updates) + + self._update_thread_with_type_and_conversation_id(thread, response.conversation_id) + + # Only notify the thread of new messages if the chatResponse was successful + # to avoid inconsistent messages state in the thread. + await self._notify_thread_of_new_messages(thread, thread_messages) + await self._notify_thread_of_new_messages(thread, response.messages) + + def get_new_thread(self) -> AgentThread: + return ChatClientAgentThread() + + def _update_thread_with_type_and_conversation_id( + self, chat_client_thread: ChatClientAgentThread, responseConversationId: str | None + ) -> None: + """Update thread with storage type and conversation ID. + + Args: + chat_client_thread: The thread to update. + responseConversationId: The conversation ID from the response, if any. + + Raises: + AgentExecutionException: If conversation ID is missing for service-managed thread. + """ + # Set the thread's storage location, the first time that we use it. + if chat_client_thread.storage_location is None: + chat_client_thread.storage_location = ( + ChatClientAgentThreadType.CONVERSATION_ID + if responseConversationId is not None + else ChatClientAgentThreadType.IN_MEMORY_MESSAGES + ) + + # If we got a conversation id back from the chat client, it means that the service supports server side thread + # storage so we should capture the id and update the thread with the new id. + if chat_client_thread.storage_location == ChatClientAgentThreadType.CONVERSATION_ID: + if responseConversationId is None: + raise AgentExecutionException( + "Service did not return a valid conversation id when using a service managed thread." + ) + chat_client_thread.id = responseConversationId + + async def _prepare_thread_and_messages( + self, + *, + thread: AgentThread | None, + input_messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None, + construct_thread: Callable[[], TThreadType], + expected_type: type[TThreadType], + ) -> tuple[TThreadType, list[ChatMessage]]: + """Prepare thread and messages for agent execution. + + Args: + thread: The conversation thread, or None to create a new one. + input_messages: Messages to process, can be string, ChatMessage, or sequence. + construct_thread: Factory function to create a new thread. + expected_type: Expected thread type for validation. + + Returns: + Tuple of the thread and normalized messages. + + Raises: + AgentExecutionException: If thread type is incompatible. + """ + messages: list[ChatMessage] = [] + + if thread is None: + thread = construct_thread() + + if not isinstance(thread, expected_type): + raise AgentExecutionException( + f"{self.__class__.__name__} currently only supports agent threads of type {expected_type.__name__}." + ) + + # Add any existing messages from the thread to the messages to be sent to the chat client. + if isinstance(thread, MessagesRetrievableThread): + async for message in thread.get_messages(): + messages.append(message) + + if input_messages is None: + input_messages = [] + + if isinstance(input_messages, (str, ChatMessage)): + input_messages = [input_messages] + + normalized_messages = [ + ChatMessage(role=ChatRole.USER, text=msg) if isinstance(msg, str) else msg for msg in input_messages + ] + + messages.extend(normalized_messages) + + return thread, messages diff --git a/python/packages/main/agent_framework/_clients.py b/python/packages/main/agent_framework/_clients.py index 4aa233d787..c143cf2fa1 100644 --- a/python/packages/main/agent_framework/_clients.py +++ b/python/packages/main/agent_framework/_clients.py @@ -256,7 +256,7 @@ class ChatClient(Protocol): async def get_response( self, - messages: str | ChatMessage | Sequence[ChatMessage], + messages: str | ChatMessage | list[ChatMessage], **kwargs: Any, ) -> ChatResponse: """Sends input and returns the response. @@ -274,9 +274,9 @@ class ChatClient(Protocol): """ ... - async def get_streaming_response( + def get_streaming_response( self, - messages: str | ChatMessage | Sequence[ChatMessage], + messages: str | ChatMessage | list[ChatMessage], **kwargs: Any, ) -> AsyncIterable[ChatResponseUpdate]: """Sends input messages and streams the response. diff --git a/python/packages/main/agent_framework/_types.py b/python/packages/main/agent_framework/_types.py index f3f6a92d13..d1ce0234a3 100644 --- a/python/packages/main/agent_framework/_types.py +++ b/python/packages/main/agent_framework/_types.py @@ -1631,6 +1631,17 @@ class AgentRunResponse(AFBaseModel): _finalize_response(msg) return msg + @classmethod + async def from_agent_response_generator( + cls: type[TAgentRunResponse], updates: AsyncIterable["AgentRunResponseUpdate"] + ) -> TAgentRunResponse: + """Joins multiple updates into a single AgentRunResponse.""" + msg = cls(messages=[]) + async for update in updates: + _process_update(msg, update) + _finalize_response(msg) + return msg + def __str__(self) -> str: return self.text diff --git a/python/packages/main/tests/unit/test_agents.py b/python/packages/main/tests/unit/test_agents.py index ee4435e6bf..0e732cf9bc 100644 --- a/python/packages/main/tests/unit/test_agents.py +++ b/python/packages/main/tests/unit/test_agents.py @@ -1,46 +1,55 @@ # Copyright (c) Microsoft. All rights reserved. -from collections.abc import AsyncIterable, Sequence -from typing import Any, TypeVar +from collections.abc import AsyncIterable, MutableSequence, Sequence +from typing import Any from uuid import uuid4 -from pydantic import BaseModel, Field -from pytest import fixture +from pytest import fixture, raises from agent_framework import ( Agent, AgentRunResponse, AgentRunResponseUpdate, AgentThread, + ChatClient, + ChatClientAgent, + ChatClientAgentThread, + ChatClientAgentThreadType, + ChatClientBase, ChatMessage, + ChatOptions, + ChatResponse, + ChatResponseUpdate, ChatRole, TextContent, ) - -TThreadType = TypeVar("TThreadType", bound=AgentThread) +from agent_framework.exceptions import AgentExecutionException # Mock AgentThread implementation for testing class MockAgentThread(AgentThread): - async def _create(self) -> str: - return str(uuid4()) - - async def _delete(self) -> None: - pass - - async def _on_new_message(self, new_messages: ChatMessage | Sequence[ChatMessage]) -> None: + async def _on_new_messages(self, new_messages: ChatMessage | Sequence[ChatMessage]) -> None: pass # Mock Agent implementation for testing -class MockAgent(BaseModel): - id: str = Field(default_factory=lambda: str(uuid4())) - name: str | None = None - description: str | None = None +class MockAgent(Agent): + @property + def id(self) -> str: + return str(uuid4()) + + @property + def name(self) -> str | None: + """Returns the name of the agent.""" + return "Name" + + @property + def description(self) -> str | None: + return "Description" async def run( self, - messages: ChatMessage | str | list[ChatMessage] | None = None, + messages: str | ChatMessage | list[str | ChatMessage] | None = None, *, thread: AgentThread | None = None, **kwargs: Any, @@ -49,7 +58,7 @@ class MockAgent(BaseModel): async def run_stream( self, - messages: str | ChatMessage | list[ChatMessage] | None = None, + messages: str | ChatMessage | list[str | ChatMessage] | None = None, *, thread: AgentThread | None = None, **kwargs: Any, @@ -60,6 +69,36 @@ class MockAgent(BaseModel): return MockAgentThread() +# Mock ChatClient implementation for testing +class MockChatClient(ChatClientBase): + _mock_response: ChatResponse | None = None + + def __init__(self, mock_response: ChatResponse | None = None) -> None: + self._mock_response = mock_response + + async def _inner_get_response( + self, + *, + messages: MutableSequence[ChatMessage], + chat_options: ChatOptions, + **kwargs: Any, + ) -> ChatResponse: + return ( + self._mock_response + if self._mock_response + else ChatResponse(messages=ChatMessage(role=ChatRole.ASSISTANT, text="test response")) + ) + + async def _inner_get_streaming_response( + self, + *, + messages: MutableSequence[ChatMessage], + chat_options: ChatOptions, + **kwargs: Any, + ) -> AsyncIterable[ChatResponseUpdate]: + yield ChatResponseUpdate(role=ChatRole.ASSISTANT, text=TextContent(text="test streaming response")) + + @fixture def agent_thread() -> AgentThread: return MockAgentThread() @@ -70,39 +109,15 @@ def agent() -> Agent: return MockAgent() +@fixture +def chat_client() -> ChatClientBase: + return MockChatClient() + + def test_agent_thread_type(agent_thread: AgentThread) -> None: assert isinstance(agent_thread, AgentThread) -async def test_agent_thread_id_property(agent_thread: AgentThread) -> None: - assert agent_thread.id is None - await agent_thread.create() - assert isinstance(agent_thread.id, str) - - -async def test_agent_thread_create(agent_thread: AgentThread) -> None: - thread_id = await agent_thread.create() - assert thread_id == agent_thread.id - assert isinstance(thread_id, str) - - -async def test_agent_thread_create_already_exists(agent_thread: AgentThread) -> None: - thread_id = await agent_thread.create() - same_id = await agent_thread.create() - assert thread_id == same_id - - -async def test_agent_thread_delete_already_deleted(agent_thread: AgentThread) -> None: - await agent_thread.delete() - await agent_thread.delete() # Should not raise error - - -async def test_agent_thread_on_new_message_creates_thread(agent_thread: AgentThread) -> None: - message = ChatMessage(role=ChatRole.USER, contents=[TextContent("Hello")]) - await agent_thread.on_new_message(message) - assert agent_thread.id is not None - - def test_agent_type(agent: Agent) -> None: assert isinstance(agent, Agent) @@ -120,3 +135,154 @@ async def test_agent_run_stream(agent: Agent) -> None: updates = await collect_updates(agent.run_stream(messages="test")) assert len(updates) == 1 assert updates[0].text == "Response" + + +async def test_chat_client_agent_thread_init_in_memory() -> None: + messages = [ChatMessage(role=ChatRole.USER, contents=[TextContent("Hello")])] + thread = ChatClientAgentThread(messages=messages) + + assert thread.storage_location == ChatClientAgentThreadType.IN_MEMORY_MESSAGES + assert thread.id is None + assert thread.chat_messages == messages + + +async def test_chat_client_agent_thread_empty() -> None: + thread = ChatClientAgentThread() + + assert thread.storage_location is None + assert thread.id is None + assert thread.chat_messages is None + + +async def test_chat_client_agent_thread_init_invalid() -> None: + with raises(ValueError, match="Cannot specify both id and messages"): + ChatClientAgentThread(id="123", messages=[ChatMessage(role=ChatRole.USER, contents=[TextContent("Hello")])]) + + with raises(ValueError, match="ID cannot be empty or whitespace"): + ChatClientAgentThread(id=" ") + + +async def test_chat_client_agent_thread_init_conversation_id() -> None: + thread_id = str(uuid4()) + thread = ChatClientAgentThread(id=thread_id) + + assert thread.storage_location == ChatClientAgentThreadType.CONVERSATION_ID + assert thread.id == thread_id + assert thread.chat_messages is None + + +async def test_chat_client_agent_thread_get_messages() -> None: + messages = [ChatMessage(role=ChatRole.USER, contents=[TextContent("Hello")])] + thread = ChatClientAgentThread(messages=messages) + + result = [msg async for msg in thread.get_messages()] + assert result == messages + + +async def test_chat_client_agent_thread_on_new_messages_in_memory() -> None: + initial_message = ChatMessage(role=ChatRole.USER, contents=[TextContent("Initial message")]) + new_message = ChatMessage(role=ChatRole.USER, contents=[TextContent("New message")]) + + thread = ChatClientAgentThread(messages=[initial_message]) + + await thread._on_new_messages(new_message) # type: ignore[reportPrivateUsage] + assert thread.chat_messages == [initial_message, new_message] + + +def test_chat_client_agent_type(chat_client: ChatClient) -> None: + chat_client_agent = ChatClientAgent(chat_client=chat_client) + assert isinstance(chat_client_agent, Agent) + + +async def test_chat_client_agent_init(chat_client: ChatClient) -> None: + agent_id = str(uuid4()) + agent = ChatClientAgent(chat_client=chat_client, id=agent_id, description="Test") + + assert agent.id == agent_id + assert agent.name == "UnnamedAgent" + assert agent.description == "Test" + + +async def test_chat_client_agent_run(chat_client: ChatClient) -> None: + agent = ChatClientAgent(chat_client=chat_client) + + result = await agent.run("Hello") + + assert result.text == "test response" + + +async def test_chat_client_agent_run_stream(chat_client: ChatClient) -> None: + agent = ChatClientAgent(chat_client=chat_client) + + result = await AgentRunResponse.from_agent_response_generator(agent.run_stream("Hello")) + + assert result.text == "test streaming response" + + +async def test_chat_client_agent_get_new_thread(chat_client: ChatClient) -> None: + agent = ChatClientAgent(chat_client=chat_client) + thread = agent.get_new_thread() + + assert isinstance(thread, ChatClientAgentThread) + assert thread.storage_location is None + + +async def test_chat_client_agent_prepare_thread_and_messages(chat_client: ChatClient) -> None: + agent = ChatClientAgent(chat_client=chat_client) + message = ChatMessage(role=ChatRole.USER, contents=[TextContent("Hello")]) + thread = ChatClientAgentThread(messages=[message]) + + result_thread, result_messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage] + thread=thread, + input_messages="Test", + construct_thread=lambda: ChatClientAgentThread(), + expected_type=ChatClientAgentThread, + ) + + assert result_thread == thread + assert len(result_messages) == 2 + assert result_messages[0] == message + assert result_messages[1].text == "Test" + + +async def test_chat_client_agent_update_thread_id() -> None: + chat_client = MockChatClient( + mock_response=ChatResponse( + messages=[ChatMessage(role=ChatRole.ASSISTANT, contents=[TextContent("test response")])], + conversation_id="123", + ) + ) + agent = ChatClientAgent(chat_client=chat_client) + thread = agent.get_new_thread() + + result = await agent.run("Hello", thread=thread) + assert result.text == "test response" + + assert thread.id == "123" + assert isinstance(thread, ChatClientAgentThread) + assert thread.storage_location == ChatClientAgentThreadType.CONVERSATION_ID + + +async def test_chat_client_agent_update_thread_messages(chat_client: ChatClient) -> None: + agent = ChatClientAgent(chat_client=chat_client) + thread = agent.get_new_thread() + + result = await agent.run("Hello", thread=thread) + assert result.text == "test response" + + assert thread.id is None + assert isinstance(thread, ChatClientAgentThread) + assert thread.storage_location == ChatClientAgentThreadType.IN_MEMORY_MESSAGES + + assert thread.chat_messages is not None + assert len(thread.chat_messages) == 2 + assert thread.chat_messages[0].text == "Hello" + assert thread.chat_messages[1].text == "test response" + + +async def test_chat_client_agent_update_thread_conversation_id_missing(chat_client: ChatClient) -> None: + agent = ChatClientAgent(chat_client=chat_client) + thread = ChatClientAgentThread(id="123") + + with raises(AgentExecutionException, match="Service did not return a valid conversation id"): + agent._update_thread_with_type_and_conversation_id(thread, None) # type: ignore[reportPrivateUsage]