mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Added ChatClientAgentThread and ChatClientAgent implementations (#150)
* Added ChatClientAgentThread * Initial version of ChatClientAgent * Completed ChatClientAgent * Small fixes and unit tests * Fixes based on pre-commit * Small fixes * Small renaming * Small improvement * Small fixes * Addressed PR feedback * Small fix * Added method for AgentRunResponse from streaming conversion * Addressed PR feedback * Addressed PR feedback * Addressed PR feedback * Small fix * More fixes
This commit is contained in:
committed by
GitHub
Unverified
parent
df84675c0f
commit
94e00bd49a
@@ -21,6 +21,9 @@ _IMPORTS = {
|
||||
"ai_function": "._tools",
|
||||
"AIContent": "._types",
|
||||
"AIContents": "._types",
|
||||
"ChatClientAgent": "._agents",
|
||||
"ChatClientAgentThread": "._agents",
|
||||
"ChatClientAgentThreadType": "._agents",
|
||||
"TextContent": "._types",
|
||||
"TextReasoningContent": "._types",
|
||||
"DataContent": "._types",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from . import __version__ # type: ignore[attr-defined]
|
||||
from ._agents import Agent, AgentThread
|
||||
from ._agents import Agent, AgentThread, ChatClientAgent, ChatClientAgentThread, ChatClientAgentThreadType
|
||||
from ._clients import ChatClient, ChatClientBase, EmbeddingGenerator, use_tool_calling
|
||||
from ._logging import get_logger
|
||||
from ._pydantic import AFBaseModel, AFBaseSettings
|
||||
@@ -45,6 +45,9 @@ __all__ = [
|
||||
"AgentRunResponseUpdate",
|
||||
"AgentThread",
|
||||
"ChatClient",
|
||||
"ChatClientAgent",
|
||||
"ChatClientAgentThread",
|
||||
"ChatClientAgentThreadType",
|
||||
"ChatClientBase",
|
||||
"ChatFinishReason",
|
||||
"ChatMessage",
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from abc import abstractmethod
|
||||
from collections.abc import AsyncIterable, Sequence
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
from collections.abc import AsyncIterable, Callable, Sequence
|
||||
from enum import Enum
|
||||
from typing import Any, Protocol, TypeVar, runtime_checkable
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from ._clients import ChatClient
|
||||
from ._pydantic import AFBaseModel
|
||||
from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage
|
||||
from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage, ChatResponse, ChatResponseUpdate, ChatRole
|
||||
from .exceptions import AgentExecutionException
|
||||
|
||||
TThreadType = TypeVar("TThreadType", bound="AgentThread")
|
||||
|
||||
# region AgentThread
|
||||
|
||||
@@ -15,48 +22,28 @@ class AgentThread(AFBaseModel):
|
||||
|
||||
id: str | None = None
|
||||
|
||||
async def create(self) -> str | None:
|
||||
"""Starts the thread and returns the thread ID."""
|
||||
# If the thread ID is already set, we're done, just return the Id.
|
||||
if self.id is not None:
|
||||
return self.id
|
||||
|
||||
# Otherwise, create the thread.
|
||||
self.id = await self._create()
|
||||
return self.id
|
||||
|
||||
async def delete(self) -> None:
|
||||
"""Ends the current thread."""
|
||||
await self._delete()
|
||||
self.id = None
|
||||
|
||||
async def on_new_message(
|
||||
async def on_new_messages(
|
||||
self,
|
||||
new_messages: ChatMessage | Sequence[ChatMessage],
|
||||
) -> None:
|
||||
"""Invoked when a new message has been contributed to the chat by any participant."""
|
||||
# If the thread is not created yet, create it.
|
||||
if self.id is None:
|
||||
await self.create()
|
||||
await self._on_new_messages(new_messages=new_messages)
|
||||
|
||||
await self._on_new_message(new_messages=new_messages)
|
||||
|
||||
@abstractmethod
|
||||
async def _create(self) -> str:
|
||||
"""Starts the thread and returns the thread ID."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def _delete(self) -> None:
|
||||
"""Ends the current thread."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def _on_new_message(
|
||||
async def _on_new_messages(
|
||||
self,
|
||||
new_messages: ChatMessage | Sequence[ChatMessage],
|
||||
) -> None:
|
||||
"""Invoked when a new message has been contributed to the chat by any participant."""
|
||||
pass
|
||||
|
||||
|
||||
# region MessagesRetrievableThread
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class MessagesRetrievableThread(Protocol):
|
||||
def get_messages(self) -> AsyncIterable[ChatMessage]:
|
||||
"""Asynchronously retrieves all messages from thread."""
|
||||
...
|
||||
|
||||
|
||||
@@ -84,7 +71,7 @@ class Agent(Protocol):
|
||||
|
||||
async def run(
|
||||
self,
|
||||
messages: str | ChatMessage | list[ChatMessage] | None = None,
|
||||
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
**kwargs: Any,
|
||||
@@ -113,7 +100,7 @@ class Agent(Protocol):
|
||||
|
||||
def run_stream(
|
||||
self,
|
||||
messages: str | ChatMessage | list[ChatMessage] | None = None,
|
||||
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
**kwargs: Any,
|
||||
@@ -138,3 +125,272 @@ class Agent(Protocol):
|
||||
def get_new_thread(self) -> AgentThread:
|
||||
"""Creates a new conversation thread for the agent."""
|
||||
...
|
||||
|
||||
|
||||
# region AgentBase
|
||||
|
||||
|
||||
class AgentBase(AFBaseModel):
|
||||
"""Base class for all agents.
|
||||
|
||||
Attributes:
|
||||
id: The unique identifier of the agent If no id is provided,
|
||||
a new UUID will be generated.
|
||||
name: The name of the agent
|
||||
description: The description of the agent
|
||||
"""
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
name: str = Field(default="UnnamedAgent")
|
||||
description: str | None = None
|
||||
|
||||
async def _notify_thread_of_new_messages(
|
||||
self, thread: AgentThread, new_messages: ChatMessage | Sequence[ChatMessage]
|
||||
) -> None:
|
||||
"""Notify the thread of new messages."""
|
||||
if isinstance(new_messages, ChatMessage) or len(new_messages) > 0:
|
||||
await thread.on_new_messages(new_messages)
|
||||
|
||||
|
||||
# region ChatClientAgentThread
|
||||
|
||||
|
||||
class ChatClientAgentThreadType(Enum):
|
||||
"""Defines the different supported storage locations for ChatClientAgentThread."""
|
||||
|
||||
IN_MEMORY_MESSAGES = "InMemoryMessages"
|
||||
"""Messages are stored in memory inside the thread object."""
|
||||
|
||||
CONVERSATION_ID = "ConversationId"
|
||||
"""Messages are stored in the service and the thread object just has an id reference to the service storage."""
|
||||
|
||||
|
||||
class ChatClientAgentThread(AgentThread):
|
||||
"""Chat client agent thread.
|
||||
|
||||
This class manages chat threads either locally (in-memory) or via a service based on initialization.
|
||||
"""
|
||||
|
||||
chat_messages: list[ChatMessage] | None = None
|
||||
storage_location: ChatClientAgentThreadType | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str | None = None,
|
||||
messages: Sequence[ChatMessage] | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize the chat client agent thread.
|
||||
|
||||
Args:
|
||||
id: Service thread identifier. If provided, thread is managed by the service and messages are
|
||||
not stored locally. Must not be empty or whitespace.
|
||||
messages: Initial messages for local storage. If provided, thread is managed
|
||||
locally in-memory.
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Raises:
|
||||
ValueError: If both id and messages are provided, or if id is empty/whitespace.
|
||||
|
||||
Notes:
|
||||
- If id is set, _id is assigned and _chat_messages is None (service-managed).
|
||||
- If messages is set, _chat_messages is populated and _id is None (local).
|
||||
- If neither is provided, creates an empty local thread.
|
||||
"""
|
||||
processed_messages: list[ChatMessage] | None = None
|
||||
storage_location: ChatClientAgentThreadType | None = None
|
||||
|
||||
if id and messages:
|
||||
raise ValueError("Cannot specify both id and messages")
|
||||
|
||||
if id:
|
||||
if not id.strip():
|
||||
raise ValueError("ID cannot be empty or whitespace")
|
||||
storage_location = ChatClientAgentThreadType.CONVERSATION_ID
|
||||
elif messages:
|
||||
processed_messages = []
|
||||
processed_messages.extend(messages)
|
||||
storage_location = ChatClientAgentThreadType.IN_MEMORY_MESSAGES
|
||||
|
||||
super().__init__(
|
||||
id=id,
|
||||
chat_messages=processed_messages, # type: ignore[reportCallIssue]
|
||||
storage_location=storage_location, # type: ignore[reportCallIssue]
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def get_messages(self) -> AsyncIterable[ChatMessage]:
|
||||
"""Get all messages in the thread."""
|
||||
for message in self.chat_messages or []:
|
||||
yield message
|
||||
|
||||
async def _on_new_messages(self, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
|
||||
"""Handle new messages."""
|
||||
if self.storage_location == ChatClientAgentThreadType.IN_MEMORY_MESSAGES:
|
||||
if self.chat_messages is None:
|
||||
self.chat_messages = []
|
||||
self.chat_messages.extend([new_messages] if isinstance(new_messages, ChatMessage) else new_messages)
|
||||
|
||||
|
||||
# region ChatClientAgent
|
||||
|
||||
|
||||
class ChatClientAgent(AgentBase):
|
||||
"""A Chat Client Agent which depends on ChatClient."""
|
||||
|
||||
chat_client: ChatClient
|
||||
|
||||
async def run(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentRunResponse:
|
||||
thread, thread_messages = await self._prepare_thread_and_messages(
|
||||
thread=thread,
|
||||
input_messages=messages,
|
||||
construct_thread=lambda: ChatClientAgentThread(),
|
||||
expected_type=ChatClientAgentThread,
|
||||
)
|
||||
|
||||
response = await self.chat_client.get_response(thread_messages, **kwargs)
|
||||
|
||||
self._update_thread_with_type_and_conversation_id(thread, response.conversation_id)
|
||||
|
||||
# Only notify the thread of new messages if the chatResponse was successful
|
||||
# to avoid inconsistent messages state in the thread.
|
||||
await self._notify_thread_of_new_messages(thread, thread_messages)
|
||||
await self._notify_thread_of_new_messages(thread, response.messages)
|
||||
|
||||
return AgentRunResponse(
|
||||
messages=response.messages,
|
||||
response_id=response.response_id,
|
||||
created_at=response.created_at,
|
||||
usage_details=response.usage_details,
|
||||
raw_representation=response.raw_representation,
|
||||
additional_properties=response.additional_properties,
|
||||
)
|
||||
|
||||
async def run_stream(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[AgentRunResponseUpdate]:
|
||||
thread, thread_messages = await self._prepare_thread_and_messages(
|
||||
thread=thread,
|
||||
input_messages=messages,
|
||||
construct_thread=lambda: ChatClientAgentThread(),
|
||||
expected_type=ChatClientAgentThread,
|
||||
)
|
||||
|
||||
response_updates: list[ChatResponseUpdate] = []
|
||||
|
||||
streaming_response: AsyncIterable[ChatResponseUpdate] = self.chat_client.get_streaming_response(thread_messages)
|
||||
|
||||
async for update in streaming_response:
|
||||
response_updates.append(update)
|
||||
yield AgentRunResponseUpdate(
|
||||
contents=update.contents,
|
||||
role=update.role,
|
||||
author_name=update.author_name,
|
||||
response_id=update.response_id,
|
||||
message_id=update.message_id,
|
||||
created_at=update.created_at,
|
||||
additional_properties=update.additional_properties,
|
||||
raw_representation=update.raw_representation,
|
||||
)
|
||||
|
||||
response = ChatResponse.from_chat_response_updates(response_updates)
|
||||
|
||||
self._update_thread_with_type_and_conversation_id(thread, response.conversation_id)
|
||||
|
||||
# Only notify the thread of new messages if the chatResponse was successful
|
||||
# to avoid inconsistent messages state in the thread.
|
||||
await self._notify_thread_of_new_messages(thread, thread_messages)
|
||||
await self._notify_thread_of_new_messages(thread, response.messages)
|
||||
|
||||
def get_new_thread(self) -> AgentThread:
|
||||
return ChatClientAgentThread()
|
||||
|
||||
def _update_thread_with_type_and_conversation_id(
|
||||
self, chat_client_thread: ChatClientAgentThread, responseConversationId: str | None
|
||||
) -> None:
|
||||
"""Update thread with storage type and conversation ID.
|
||||
|
||||
Args:
|
||||
chat_client_thread: The thread to update.
|
||||
responseConversationId: The conversation ID from the response, if any.
|
||||
|
||||
Raises:
|
||||
AgentExecutionException: If conversation ID is missing for service-managed thread.
|
||||
"""
|
||||
# Set the thread's storage location, the first time that we use it.
|
||||
if chat_client_thread.storage_location is None:
|
||||
chat_client_thread.storage_location = (
|
||||
ChatClientAgentThreadType.CONVERSATION_ID
|
||||
if responseConversationId is not None
|
||||
else ChatClientAgentThreadType.IN_MEMORY_MESSAGES
|
||||
)
|
||||
|
||||
# If we got a conversation id back from the chat client, it means that the service supports server side thread
|
||||
# storage so we should capture the id and update the thread with the new id.
|
||||
if chat_client_thread.storage_location == ChatClientAgentThreadType.CONVERSATION_ID:
|
||||
if responseConversationId is None:
|
||||
raise AgentExecutionException(
|
||||
"Service did not return a valid conversation id when using a service managed thread."
|
||||
)
|
||||
chat_client_thread.id = responseConversationId
|
||||
|
||||
async def _prepare_thread_and_messages(
|
||||
self,
|
||||
*,
|
||||
thread: AgentThread | None,
|
||||
input_messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None,
|
||||
construct_thread: Callable[[], TThreadType],
|
||||
expected_type: type[TThreadType],
|
||||
) -> tuple[TThreadType, list[ChatMessage]]:
|
||||
"""Prepare thread and messages for agent execution.
|
||||
|
||||
Args:
|
||||
thread: The conversation thread, or None to create a new one.
|
||||
input_messages: Messages to process, can be string, ChatMessage, or sequence.
|
||||
construct_thread: Factory function to create a new thread.
|
||||
expected_type: Expected thread type for validation.
|
||||
|
||||
Returns:
|
||||
Tuple of the thread and normalized messages.
|
||||
|
||||
Raises:
|
||||
AgentExecutionException: If thread type is incompatible.
|
||||
"""
|
||||
messages: list[ChatMessage] = []
|
||||
|
||||
if thread is None:
|
||||
thread = construct_thread()
|
||||
|
||||
if not isinstance(thread, expected_type):
|
||||
raise AgentExecutionException(
|
||||
f"{self.__class__.__name__} currently only supports agent threads of type {expected_type.__name__}."
|
||||
)
|
||||
|
||||
# Add any existing messages from the thread to the messages to be sent to the chat client.
|
||||
if isinstance(thread, MessagesRetrievableThread):
|
||||
async for message in thread.get_messages():
|
||||
messages.append(message)
|
||||
|
||||
if input_messages is None:
|
||||
input_messages = []
|
||||
|
||||
if isinstance(input_messages, (str, ChatMessage)):
|
||||
input_messages = [input_messages]
|
||||
|
||||
normalized_messages = [
|
||||
ChatMessage(role=ChatRole.USER, text=msg) if isinstance(msg, str) else msg for msg in input_messages
|
||||
]
|
||||
|
||||
messages.extend(normalized_messages)
|
||||
|
||||
return thread, messages
|
||||
|
||||
@@ -256,7 +256,7 @@ class ChatClient(Protocol):
|
||||
|
||||
async def get_response(
|
||||
self,
|
||||
messages: str | ChatMessage | Sequence[ChatMessage],
|
||||
messages: str | ChatMessage | list[ChatMessage],
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
"""Sends input and returns the response.
|
||||
@@ -274,9 +274,9 @@ class ChatClient(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
async def get_streaming_response(
|
||||
def get_streaming_response(
|
||||
self,
|
||||
messages: str | ChatMessage | Sequence[ChatMessage],
|
||||
messages: str | ChatMessage | list[ChatMessage],
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[ChatResponseUpdate]:
|
||||
"""Sends input messages and streams the response.
|
||||
|
||||
@@ -1631,6 +1631,17 @@ class AgentRunResponse(AFBaseModel):
|
||||
_finalize_response(msg)
|
||||
return msg
|
||||
|
||||
@classmethod
|
||||
async def from_agent_response_generator(
|
||||
cls: type[TAgentRunResponse], updates: AsyncIterable["AgentRunResponseUpdate"]
|
||||
) -> TAgentRunResponse:
|
||||
"""Joins multiple updates into a single AgentRunResponse."""
|
||||
msg = cls(messages=[])
|
||||
async for update in updates:
|
||||
_process_update(msg, update)
|
||||
_finalize_response(msg)
|
||||
return msg
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.text
|
||||
|
||||
|
||||
@@ -1,46 +1,55 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import AsyncIterable, Sequence
|
||||
from typing import Any, TypeVar
|
||||
from collections.abc import AsyncIterable, MutableSequence, Sequence
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pytest import fixture
|
||||
from pytest import fixture, raises
|
||||
|
||||
from agent_framework import (
|
||||
Agent,
|
||||
AgentRunResponse,
|
||||
AgentRunResponseUpdate,
|
||||
AgentThread,
|
||||
ChatClient,
|
||||
ChatClientAgent,
|
||||
ChatClientAgentThread,
|
||||
ChatClientAgentThreadType,
|
||||
ChatClientBase,
|
||||
ChatMessage,
|
||||
ChatOptions,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
ChatRole,
|
||||
TextContent,
|
||||
)
|
||||
|
||||
TThreadType = TypeVar("TThreadType", bound=AgentThread)
|
||||
from agent_framework.exceptions import AgentExecutionException
|
||||
|
||||
|
||||
# Mock AgentThread implementation for testing
|
||||
class MockAgentThread(AgentThread):
|
||||
async def _create(self) -> str:
|
||||
return str(uuid4())
|
||||
|
||||
async def _delete(self) -> None:
|
||||
pass
|
||||
|
||||
async def _on_new_message(self, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
|
||||
async def _on_new_messages(self, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
|
||||
pass
|
||||
|
||||
|
||||
# Mock Agent implementation for testing
|
||||
class MockAgent(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
class MockAgent(Agent):
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return str(uuid4())
|
||||
|
||||
@property
|
||||
def name(self) -> str | None:
|
||||
"""Returns the name of the agent."""
|
||||
return "Name"
|
||||
|
||||
@property
|
||||
def description(self) -> str | None:
|
||||
return "Description"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
messages: ChatMessage | str | list[ChatMessage] | None = None,
|
||||
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
**kwargs: Any,
|
||||
@@ -49,7 +58,7 @@ class MockAgent(BaseModel):
|
||||
|
||||
async def run_stream(
|
||||
self,
|
||||
messages: str | ChatMessage | list[ChatMessage] | None = None,
|
||||
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
**kwargs: Any,
|
||||
@@ -60,6 +69,36 @@ class MockAgent(BaseModel):
|
||||
return MockAgentThread()
|
||||
|
||||
|
||||
# Mock ChatClient implementation for testing
|
||||
class MockChatClient(ChatClientBase):
|
||||
_mock_response: ChatResponse | None = None
|
||||
|
||||
def __init__(self, mock_response: ChatResponse | None = None) -> None:
|
||||
self._mock_response = mock_response
|
||||
|
||||
async def _inner_get_response(
|
||||
self,
|
||||
*,
|
||||
messages: MutableSequence[ChatMessage],
|
||||
chat_options: ChatOptions,
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
return (
|
||||
self._mock_response
|
||||
if self._mock_response
|
||||
else ChatResponse(messages=ChatMessage(role=ChatRole.ASSISTANT, text="test response"))
|
||||
)
|
||||
|
||||
async def _inner_get_streaming_response(
|
||||
self,
|
||||
*,
|
||||
messages: MutableSequence[ChatMessage],
|
||||
chat_options: ChatOptions,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[ChatResponseUpdate]:
|
||||
yield ChatResponseUpdate(role=ChatRole.ASSISTANT, text=TextContent(text="test streaming response"))
|
||||
|
||||
|
||||
@fixture
|
||||
def agent_thread() -> AgentThread:
|
||||
return MockAgentThread()
|
||||
@@ -70,39 +109,15 @@ def agent() -> Agent:
|
||||
return MockAgent()
|
||||
|
||||
|
||||
@fixture
|
||||
def chat_client() -> ChatClientBase:
|
||||
return MockChatClient()
|
||||
|
||||
|
||||
def test_agent_thread_type(agent_thread: AgentThread) -> None:
|
||||
assert isinstance(agent_thread, AgentThread)
|
||||
|
||||
|
||||
async def test_agent_thread_id_property(agent_thread: AgentThread) -> None:
|
||||
assert agent_thread.id is None
|
||||
await agent_thread.create()
|
||||
assert isinstance(agent_thread.id, str)
|
||||
|
||||
|
||||
async def test_agent_thread_create(agent_thread: AgentThread) -> None:
|
||||
thread_id = await agent_thread.create()
|
||||
assert thread_id == agent_thread.id
|
||||
assert isinstance(thread_id, str)
|
||||
|
||||
|
||||
async def test_agent_thread_create_already_exists(agent_thread: AgentThread) -> None:
|
||||
thread_id = await agent_thread.create()
|
||||
same_id = await agent_thread.create()
|
||||
assert thread_id == same_id
|
||||
|
||||
|
||||
async def test_agent_thread_delete_already_deleted(agent_thread: AgentThread) -> None:
|
||||
await agent_thread.delete()
|
||||
await agent_thread.delete() # Should not raise error
|
||||
|
||||
|
||||
async def test_agent_thread_on_new_message_creates_thread(agent_thread: AgentThread) -> None:
|
||||
message = ChatMessage(role=ChatRole.USER, contents=[TextContent("Hello")])
|
||||
await agent_thread.on_new_message(message)
|
||||
assert agent_thread.id is not None
|
||||
|
||||
|
||||
def test_agent_type(agent: Agent) -> None:
|
||||
assert isinstance(agent, Agent)
|
||||
|
||||
@@ -120,3 +135,154 @@ async def test_agent_run_stream(agent: Agent) -> None:
|
||||
updates = await collect_updates(agent.run_stream(messages="test"))
|
||||
assert len(updates) == 1
|
||||
assert updates[0].text == "Response"
|
||||
|
||||
|
||||
async def test_chat_client_agent_thread_init_in_memory() -> None:
|
||||
messages = [ChatMessage(role=ChatRole.USER, contents=[TextContent("Hello")])]
|
||||
thread = ChatClientAgentThread(messages=messages)
|
||||
|
||||
assert thread.storage_location == ChatClientAgentThreadType.IN_MEMORY_MESSAGES
|
||||
assert thread.id is None
|
||||
assert thread.chat_messages == messages
|
||||
|
||||
|
||||
async def test_chat_client_agent_thread_empty() -> None:
|
||||
thread = ChatClientAgentThread()
|
||||
|
||||
assert thread.storage_location is None
|
||||
assert thread.id is None
|
||||
assert thread.chat_messages is None
|
||||
|
||||
|
||||
async def test_chat_client_agent_thread_init_invalid() -> None:
|
||||
with raises(ValueError, match="Cannot specify both id and messages"):
|
||||
ChatClientAgentThread(id="123", messages=[ChatMessage(role=ChatRole.USER, contents=[TextContent("Hello")])])
|
||||
|
||||
with raises(ValueError, match="ID cannot be empty or whitespace"):
|
||||
ChatClientAgentThread(id=" ")
|
||||
|
||||
|
||||
async def test_chat_client_agent_thread_init_conversation_id() -> None:
|
||||
thread_id = str(uuid4())
|
||||
thread = ChatClientAgentThread(id=thread_id)
|
||||
|
||||
assert thread.storage_location == ChatClientAgentThreadType.CONVERSATION_ID
|
||||
assert thread.id == thread_id
|
||||
assert thread.chat_messages is None
|
||||
|
||||
|
||||
async def test_chat_client_agent_thread_get_messages() -> None:
|
||||
messages = [ChatMessage(role=ChatRole.USER, contents=[TextContent("Hello")])]
|
||||
thread = ChatClientAgentThread(messages=messages)
|
||||
|
||||
result = [msg async for msg in thread.get_messages()]
|
||||
assert result == messages
|
||||
|
||||
|
||||
async def test_chat_client_agent_thread_on_new_messages_in_memory() -> None:
|
||||
initial_message = ChatMessage(role=ChatRole.USER, contents=[TextContent("Initial message")])
|
||||
new_message = ChatMessage(role=ChatRole.USER, contents=[TextContent("New message")])
|
||||
|
||||
thread = ChatClientAgentThread(messages=[initial_message])
|
||||
|
||||
await thread._on_new_messages(new_message) # type: ignore[reportPrivateUsage]
|
||||
assert thread.chat_messages == [initial_message, new_message]
|
||||
|
||||
|
||||
def test_chat_client_agent_type(chat_client: ChatClient) -> None:
|
||||
chat_client_agent = ChatClientAgent(chat_client=chat_client)
|
||||
assert isinstance(chat_client_agent, Agent)
|
||||
|
||||
|
||||
async def test_chat_client_agent_init(chat_client: ChatClient) -> None:
|
||||
agent_id = str(uuid4())
|
||||
agent = ChatClientAgent(chat_client=chat_client, id=agent_id, description="Test")
|
||||
|
||||
assert agent.id == agent_id
|
||||
assert agent.name == "UnnamedAgent"
|
||||
assert agent.description == "Test"
|
||||
|
||||
|
||||
async def test_chat_client_agent_run(chat_client: ChatClient) -> None:
|
||||
agent = ChatClientAgent(chat_client=chat_client)
|
||||
|
||||
result = await agent.run("Hello")
|
||||
|
||||
assert result.text == "test response"
|
||||
|
||||
|
||||
async def test_chat_client_agent_run_stream(chat_client: ChatClient) -> None:
|
||||
agent = ChatClientAgent(chat_client=chat_client)
|
||||
|
||||
result = await AgentRunResponse.from_agent_response_generator(agent.run_stream("Hello"))
|
||||
|
||||
assert result.text == "test streaming response"
|
||||
|
||||
|
||||
async def test_chat_client_agent_get_new_thread(chat_client: ChatClient) -> None:
|
||||
agent = ChatClientAgent(chat_client=chat_client)
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
assert isinstance(thread, ChatClientAgentThread)
|
||||
assert thread.storage_location is None
|
||||
|
||||
|
||||
async def test_chat_client_agent_prepare_thread_and_messages(chat_client: ChatClient) -> None:
|
||||
agent = ChatClientAgent(chat_client=chat_client)
|
||||
message = ChatMessage(role=ChatRole.USER, contents=[TextContent("Hello")])
|
||||
thread = ChatClientAgentThread(messages=[message])
|
||||
|
||||
result_thread, result_messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage]
|
||||
thread=thread,
|
||||
input_messages="Test",
|
||||
construct_thread=lambda: ChatClientAgentThread(),
|
||||
expected_type=ChatClientAgentThread,
|
||||
)
|
||||
|
||||
assert result_thread == thread
|
||||
assert len(result_messages) == 2
|
||||
assert result_messages[0] == message
|
||||
assert result_messages[1].text == "Test"
|
||||
|
||||
|
||||
async def test_chat_client_agent_update_thread_id() -> None:
|
||||
chat_client = MockChatClient(
|
||||
mock_response=ChatResponse(
|
||||
messages=[ChatMessage(role=ChatRole.ASSISTANT, contents=[TextContent("test response")])],
|
||||
conversation_id="123",
|
||||
)
|
||||
)
|
||||
agent = ChatClientAgent(chat_client=chat_client)
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
result = await agent.run("Hello", thread=thread)
|
||||
assert result.text == "test response"
|
||||
|
||||
assert thread.id == "123"
|
||||
assert isinstance(thread, ChatClientAgentThread)
|
||||
assert thread.storage_location == ChatClientAgentThreadType.CONVERSATION_ID
|
||||
|
||||
|
||||
async def test_chat_client_agent_update_thread_messages(chat_client: ChatClient) -> None:
|
||||
agent = ChatClientAgent(chat_client=chat_client)
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
result = await agent.run("Hello", thread=thread)
|
||||
assert result.text == "test response"
|
||||
|
||||
assert thread.id is None
|
||||
assert isinstance(thread, ChatClientAgentThread)
|
||||
assert thread.storage_location == ChatClientAgentThreadType.IN_MEMORY_MESSAGES
|
||||
|
||||
assert thread.chat_messages is not None
|
||||
assert len(thread.chat_messages) == 2
|
||||
assert thread.chat_messages[0].text == "Hello"
|
||||
assert thread.chat_messages[1].text == "test response"
|
||||
|
||||
|
||||
async def test_chat_client_agent_update_thread_conversation_id_missing(chat_client: ChatClient) -> None:
|
||||
agent = ChatClientAgent(chat_client=chat_client)
|
||||
thread = ChatClientAgentThread(id="123")
|
||||
|
||||
with raises(AgentExecutionException, match="Service did not return a valid conversation id"):
|
||||
agent._update_thread_with_type_and_conversation_id(thread, None) # type: ignore[reportPrivateUsage]
|
||||
|
||||
Reference in New Issue
Block a user