Python: Added ChatClientAgentThread and ChatClientAgent implementations (#150)

* Added ChatClientAgentThread

* Initial version of ChatClientAgent

* Completed ChatClientAgent

* Small fixes and unit tests

* Fixes based on pre-commit

* Small fixes

* Small renaming

* Small improvement

* Small fixes

* Addressed PR feedback

* Small fix

* Added method for AgentRunResponse from streaming conversion

* Addressed PR feedback

* Addressed PR feedback

* Addressed PR feedback

* Small fix

* More fixes
This commit is contained in:
Dmytro Struk
2025-07-11 11:09:09 -07:00
committed by GitHub
Unverified
parent df84675c0f
commit 94e00bd49a
6 changed files with 530 additions and 91 deletions
@@ -21,6 +21,9 @@ _IMPORTS = {
"ai_function": "._tools",
"AIContent": "._types",
"AIContents": "._types",
"ChatClientAgent": "._agents",
"ChatClientAgentThread": "._agents",
"ChatClientAgentThreadType": "._agents",
"TextContent": "._types",
"TextReasoningContent": "._types",
"DataContent": "._types",
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
from . import __version__ # type: ignore[attr-defined]
from ._agents import Agent, AgentThread
from ._agents import Agent, AgentThread, ChatClientAgent, ChatClientAgentThread, ChatClientAgentThreadType
from ._clients import ChatClient, ChatClientBase, EmbeddingGenerator, use_tool_calling
from ._logging import get_logger
from ._pydantic import AFBaseModel, AFBaseSettings
@@ -45,6 +45,9 @@ __all__ = [
"AgentRunResponseUpdate",
"AgentThread",
"ChatClient",
"ChatClientAgent",
"ChatClientAgentThread",
"ChatClientAgentThreadType",
"ChatClientBase",
"ChatFinishReason",
"ChatMessage",
+295 -39
View File
@@ -1,11 +1,18 @@
# Copyright (c) Microsoft. All rights reserved.
from abc import abstractmethod
from collections.abc import AsyncIterable, Sequence
from typing import Any, Protocol, runtime_checkable
from collections.abc import AsyncIterable, Callable, Sequence
from enum import Enum
from typing import Any, Protocol, TypeVar, runtime_checkable
from uuid import uuid4
from pydantic import Field
from ._clients import ChatClient
from ._pydantic import AFBaseModel
from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage
from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage, ChatResponse, ChatResponseUpdate, ChatRole
from .exceptions import AgentExecutionException
TThreadType = TypeVar("TThreadType", bound="AgentThread")
# region AgentThread
@@ -15,48 +22,28 @@ class AgentThread(AFBaseModel):
id: str | None = None
async def create(self) -> str | None:
"""Starts the thread and returns the thread ID."""
# If the thread ID is already set, we're done, just return the Id.
if self.id is not None:
return self.id
# Otherwise, create the thread.
self.id = await self._create()
return self.id
async def delete(self) -> None:
"""Ends the current thread."""
await self._delete()
self.id = None
async def on_new_message(
async def on_new_messages(
self,
new_messages: ChatMessage | Sequence[ChatMessage],
) -> None:
"""Invoked when a new message has been contributed to the chat by any participant."""
# If the thread is not created yet, create it.
if self.id is None:
await self.create()
await self._on_new_messages(new_messages=new_messages)
await self._on_new_message(new_messages=new_messages)
@abstractmethod
async def _create(self) -> str:
"""Starts the thread and returns the thread ID."""
...
@abstractmethod
async def _delete(self) -> None:
"""Ends the current thread."""
...
@abstractmethod
async def _on_new_message(
async def _on_new_messages(
self,
new_messages: ChatMessage | Sequence[ChatMessage],
) -> None:
"""Invoked when a new message has been contributed to the chat by any participant."""
pass
# region MessagesRetrievableThread
@runtime_checkable
class MessagesRetrievableThread(Protocol):
def get_messages(self) -> AsyncIterable[ChatMessage]:
"""Asynchronously retrieves all messages from thread."""
...
@@ -84,7 +71,7 @@ class Agent(Protocol):
async def run(
self,
messages: str | ChatMessage | list[ChatMessage] | None = None,
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
*,
thread: AgentThread | None = None,
**kwargs: Any,
@@ -113,7 +100,7 @@ class Agent(Protocol):
def run_stream(
self,
messages: str | ChatMessage | list[ChatMessage] | None = None,
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
*,
thread: AgentThread | None = None,
**kwargs: Any,
@@ -138,3 +125,272 @@ class Agent(Protocol):
def get_new_thread(self) -> AgentThread:
"""Creates a new conversation thread for the agent."""
...
# region AgentBase
class AgentBase(AFBaseModel):
"""Base class for all agents.
Attributes:
id: The unique identifier of the agent If no id is provided,
a new UUID will be generated.
name: The name of the agent
description: The description of the agent
"""
id: str = Field(default_factory=lambda: str(uuid4()))
name: str = Field(default="UnnamedAgent")
description: str | None = None
async def _notify_thread_of_new_messages(
self, thread: AgentThread, new_messages: ChatMessage | Sequence[ChatMessage]
) -> None:
"""Notify the thread of new messages."""
if isinstance(new_messages, ChatMessage) or len(new_messages) > 0:
await thread.on_new_messages(new_messages)
# region ChatClientAgentThread
class ChatClientAgentThreadType(Enum):
"""Defines the different supported storage locations for ChatClientAgentThread."""
IN_MEMORY_MESSAGES = "InMemoryMessages"
"""Messages are stored in memory inside the thread object."""
CONVERSATION_ID = "ConversationId"
"""Messages are stored in the service and the thread object just has an id reference to the service storage."""
class ChatClientAgentThread(AgentThread):
"""Chat client agent thread.
This class manages chat threads either locally (in-memory) or via a service based on initialization.
"""
chat_messages: list[ChatMessage] | None = None
storage_location: ChatClientAgentThreadType | None = None
def __init__(
self,
id: str | None = None,
messages: Sequence[ChatMessage] | None = None,
**kwargs: Any,
):
"""Initialize the chat client agent thread.
Args:
id: Service thread identifier. If provided, thread is managed by the service and messages are
not stored locally. Must not be empty or whitespace.
messages: Initial messages for local storage. If provided, thread is managed
locally in-memory.
kwargs: Additional keyword arguments.
Raises:
ValueError: If both id and messages are provided, or if id is empty/whitespace.
Notes:
- If id is set, _id is assigned and _chat_messages is None (service-managed).
- If messages is set, _chat_messages is populated and _id is None (local).
- If neither is provided, creates an empty local thread.
"""
processed_messages: list[ChatMessage] | None = None
storage_location: ChatClientAgentThreadType | None = None
if id and messages:
raise ValueError("Cannot specify both id and messages")
if id:
if not id.strip():
raise ValueError("ID cannot be empty or whitespace")
storage_location = ChatClientAgentThreadType.CONVERSATION_ID
elif messages:
processed_messages = []
processed_messages.extend(messages)
storage_location = ChatClientAgentThreadType.IN_MEMORY_MESSAGES
super().__init__(
id=id,
chat_messages=processed_messages, # type: ignore[reportCallIssue]
storage_location=storage_location, # type: ignore[reportCallIssue]
**kwargs,
)
async def get_messages(self) -> AsyncIterable[ChatMessage]:
"""Get all messages in the thread."""
for message in self.chat_messages or []:
yield message
async def _on_new_messages(self, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
"""Handle new messages."""
if self.storage_location == ChatClientAgentThreadType.IN_MEMORY_MESSAGES:
if self.chat_messages is None:
self.chat_messages = []
self.chat_messages.extend([new_messages] if isinstance(new_messages, ChatMessage) else new_messages)
# region ChatClientAgent
class ChatClientAgent(AgentBase):
"""A Chat Client Agent which depends on ChatClient."""
chat_client: ChatClient
async def run(
self,
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
*,
thread: AgentThread | None = None,
**kwargs: Any,
) -> AgentRunResponse:
thread, thread_messages = await self._prepare_thread_and_messages(
thread=thread,
input_messages=messages,
construct_thread=lambda: ChatClientAgentThread(),
expected_type=ChatClientAgentThread,
)
response = await self.chat_client.get_response(thread_messages, **kwargs)
self._update_thread_with_type_and_conversation_id(thread, response.conversation_id)
# Only notify the thread of new messages if the chatResponse was successful
# to avoid inconsistent messages state in the thread.
await self._notify_thread_of_new_messages(thread, thread_messages)
await self._notify_thread_of_new_messages(thread, response.messages)
return AgentRunResponse(
messages=response.messages,
response_id=response.response_id,
created_at=response.created_at,
usage_details=response.usage_details,
raw_representation=response.raw_representation,
additional_properties=response.additional_properties,
)
async def run_stream(
self,
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
*,
thread: AgentThread | None = None,
**kwargs: Any,
) -> AsyncIterable[AgentRunResponseUpdate]:
thread, thread_messages = await self._prepare_thread_and_messages(
thread=thread,
input_messages=messages,
construct_thread=lambda: ChatClientAgentThread(),
expected_type=ChatClientAgentThread,
)
response_updates: list[ChatResponseUpdate] = []
streaming_response: AsyncIterable[ChatResponseUpdate] = self.chat_client.get_streaming_response(thread_messages)
async for update in streaming_response:
response_updates.append(update)
yield AgentRunResponseUpdate(
contents=update.contents,
role=update.role,
author_name=update.author_name,
response_id=update.response_id,
message_id=update.message_id,
created_at=update.created_at,
additional_properties=update.additional_properties,
raw_representation=update.raw_representation,
)
response = ChatResponse.from_chat_response_updates(response_updates)
self._update_thread_with_type_and_conversation_id(thread, response.conversation_id)
# Only notify the thread of new messages if the chatResponse was successful
# to avoid inconsistent messages state in the thread.
await self._notify_thread_of_new_messages(thread, thread_messages)
await self._notify_thread_of_new_messages(thread, response.messages)
def get_new_thread(self) -> AgentThread:
return ChatClientAgentThread()
def _update_thread_with_type_and_conversation_id(
self, chat_client_thread: ChatClientAgentThread, responseConversationId: str | None
) -> None:
"""Update thread with storage type and conversation ID.
Args:
chat_client_thread: The thread to update.
responseConversationId: The conversation ID from the response, if any.
Raises:
AgentExecutionException: If conversation ID is missing for service-managed thread.
"""
# Set the thread's storage location, the first time that we use it.
if chat_client_thread.storage_location is None:
chat_client_thread.storage_location = (
ChatClientAgentThreadType.CONVERSATION_ID
if responseConversationId is not None
else ChatClientAgentThreadType.IN_MEMORY_MESSAGES
)
# If we got a conversation id back from the chat client, it means that the service supports server side thread
# storage so we should capture the id and update the thread with the new id.
if chat_client_thread.storage_location == ChatClientAgentThreadType.CONVERSATION_ID:
if responseConversationId is None:
raise AgentExecutionException(
"Service did not return a valid conversation id when using a service managed thread."
)
chat_client_thread.id = responseConversationId
async def _prepare_thread_and_messages(
self,
*,
thread: AgentThread | None,
input_messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None,
construct_thread: Callable[[], TThreadType],
expected_type: type[TThreadType],
) -> tuple[TThreadType, list[ChatMessage]]:
"""Prepare thread and messages for agent execution.
Args:
thread: The conversation thread, or None to create a new one.
input_messages: Messages to process, can be string, ChatMessage, or sequence.
construct_thread: Factory function to create a new thread.
expected_type: Expected thread type for validation.
Returns:
Tuple of the thread and normalized messages.
Raises:
AgentExecutionException: If thread type is incompatible.
"""
messages: list[ChatMessage] = []
if thread is None:
thread = construct_thread()
if not isinstance(thread, expected_type):
raise AgentExecutionException(
f"{self.__class__.__name__} currently only supports agent threads of type {expected_type.__name__}."
)
# Add any existing messages from the thread to the messages to be sent to the chat client.
if isinstance(thread, MessagesRetrievableThread):
async for message in thread.get_messages():
messages.append(message)
if input_messages is None:
input_messages = []
if isinstance(input_messages, (str, ChatMessage)):
input_messages = [input_messages]
normalized_messages = [
ChatMessage(role=ChatRole.USER, text=msg) if isinstance(msg, str) else msg for msg in input_messages
]
messages.extend(normalized_messages)
return thread, messages
@@ -256,7 +256,7 @@ class ChatClient(Protocol):
async def get_response(
self,
messages: str | ChatMessage | Sequence[ChatMessage],
messages: str | ChatMessage | list[ChatMessage],
**kwargs: Any,
) -> ChatResponse:
"""Sends input and returns the response.
@@ -274,9 +274,9 @@ class ChatClient(Protocol):
"""
...
async def get_streaming_response(
def get_streaming_response(
self,
messages: str | ChatMessage | Sequence[ChatMessage],
messages: str | ChatMessage | list[ChatMessage],
**kwargs: Any,
) -> AsyncIterable[ChatResponseUpdate]:
"""Sends input messages and streams the response.
@@ -1631,6 +1631,17 @@ class AgentRunResponse(AFBaseModel):
_finalize_response(msg)
return msg
@classmethod
async def from_agent_response_generator(
cls: type[TAgentRunResponse], updates: AsyncIterable["AgentRunResponseUpdate"]
) -> TAgentRunResponse:
"""Joins multiple updates into a single AgentRunResponse."""
msg = cls(messages=[])
async for update in updates:
_process_update(msg, update)
_finalize_response(msg)
return msg
def __str__(self) -> str:
return self.text
+214 -48
View File
@@ -1,46 +1,55 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import AsyncIterable, Sequence
from typing import Any, TypeVar
from collections.abc import AsyncIterable, MutableSequence, Sequence
from typing import Any
from uuid import uuid4
from pydantic import BaseModel, Field
from pytest import fixture
from pytest import fixture, raises
from agent_framework import (
Agent,
AgentRunResponse,
AgentRunResponseUpdate,
AgentThread,
ChatClient,
ChatClientAgent,
ChatClientAgentThread,
ChatClientAgentThreadType,
ChatClientBase,
ChatMessage,
ChatOptions,
ChatResponse,
ChatResponseUpdate,
ChatRole,
TextContent,
)
TThreadType = TypeVar("TThreadType", bound=AgentThread)
from agent_framework.exceptions import AgentExecutionException
# Mock AgentThread implementation for testing
class MockAgentThread(AgentThread):
async def _create(self) -> str:
return str(uuid4())
async def _delete(self) -> None:
pass
async def _on_new_message(self, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
async def _on_new_messages(self, new_messages: ChatMessage | Sequence[ChatMessage]) -> None:
pass
# Mock Agent implementation for testing
class MockAgent(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
name: str | None = None
description: str | None = None
class MockAgent(Agent):
@property
def id(self) -> str:
return str(uuid4())
@property
def name(self) -> str | None:
"""Returns the name of the agent."""
return "Name"
@property
def description(self) -> str | None:
return "Description"
async def run(
self,
messages: ChatMessage | str | list[ChatMessage] | None = None,
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
*,
thread: AgentThread | None = None,
**kwargs: Any,
@@ -49,7 +58,7 @@ class MockAgent(BaseModel):
async def run_stream(
self,
messages: str | ChatMessage | list[ChatMessage] | None = None,
messages: str | ChatMessage | list[str | ChatMessage] | None = None,
*,
thread: AgentThread | None = None,
**kwargs: Any,
@@ -60,6 +69,36 @@ class MockAgent(BaseModel):
return MockAgentThread()
# Mock ChatClient implementation for testing
class MockChatClient(ChatClientBase):
_mock_response: ChatResponse | None = None
def __init__(self, mock_response: ChatResponse | None = None) -> None:
self._mock_response = mock_response
async def _inner_get_response(
self,
*,
messages: MutableSequence[ChatMessage],
chat_options: ChatOptions,
**kwargs: Any,
) -> ChatResponse:
return (
self._mock_response
if self._mock_response
else ChatResponse(messages=ChatMessage(role=ChatRole.ASSISTANT, text="test response"))
)
async def _inner_get_streaming_response(
self,
*,
messages: MutableSequence[ChatMessage],
chat_options: ChatOptions,
**kwargs: Any,
) -> AsyncIterable[ChatResponseUpdate]:
yield ChatResponseUpdate(role=ChatRole.ASSISTANT, text=TextContent(text="test streaming response"))
@fixture
def agent_thread() -> AgentThread:
return MockAgentThread()
@@ -70,39 +109,15 @@ def agent() -> Agent:
return MockAgent()
@fixture
def chat_client() -> ChatClientBase:
return MockChatClient()
def test_agent_thread_type(agent_thread: AgentThread) -> None:
assert isinstance(agent_thread, AgentThread)
async def test_agent_thread_id_property(agent_thread: AgentThread) -> None:
assert agent_thread.id is None
await agent_thread.create()
assert isinstance(agent_thread.id, str)
async def test_agent_thread_create(agent_thread: AgentThread) -> None:
thread_id = await agent_thread.create()
assert thread_id == agent_thread.id
assert isinstance(thread_id, str)
async def test_agent_thread_create_already_exists(agent_thread: AgentThread) -> None:
thread_id = await agent_thread.create()
same_id = await agent_thread.create()
assert thread_id == same_id
async def test_agent_thread_delete_already_deleted(agent_thread: AgentThread) -> None:
await agent_thread.delete()
await agent_thread.delete() # Should not raise error
async def test_agent_thread_on_new_message_creates_thread(agent_thread: AgentThread) -> None:
message = ChatMessage(role=ChatRole.USER, contents=[TextContent("Hello")])
await agent_thread.on_new_message(message)
assert agent_thread.id is not None
def test_agent_type(agent: Agent) -> None:
assert isinstance(agent, Agent)
@@ -120,3 +135,154 @@ async def test_agent_run_stream(agent: Agent) -> None:
updates = await collect_updates(agent.run_stream(messages="test"))
assert len(updates) == 1
assert updates[0].text == "Response"
async def test_chat_client_agent_thread_init_in_memory() -> None:
messages = [ChatMessage(role=ChatRole.USER, contents=[TextContent("Hello")])]
thread = ChatClientAgentThread(messages=messages)
assert thread.storage_location == ChatClientAgentThreadType.IN_MEMORY_MESSAGES
assert thread.id is None
assert thread.chat_messages == messages
async def test_chat_client_agent_thread_empty() -> None:
thread = ChatClientAgentThread()
assert thread.storage_location is None
assert thread.id is None
assert thread.chat_messages is None
async def test_chat_client_agent_thread_init_invalid() -> None:
with raises(ValueError, match="Cannot specify both id and messages"):
ChatClientAgentThread(id="123", messages=[ChatMessage(role=ChatRole.USER, contents=[TextContent("Hello")])])
with raises(ValueError, match="ID cannot be empty or whitespace"):
ChatClientAgentThread(id=" ")
async def test_chat_client_agent_thread_init_conversation_id() -> None:
thread_id = str(uuid4())
thread = ChatClientAgentThread(id=thread_id)
assert thread.storage_location == ChatClientAgentThreadType.CONVERSATION_ID
assert thread.id == thread_id
assert thread.chat_messages is None
async def test_chat_client_agent_thread_get_messages() -> None:
messages = [ChatMessage(role=ChatRole.USER, contents=[TextContent("Hello")])]
thread = ChatClientAgentThread(messages=messages)
result = [msg async for msg in thread.get_messages()]
assert result == messages
async def test_chat_client_agent_thread_on_new_messages_in_memory() -> None:
initial_message = ChatMessage(role=ChatRole.USER, contents=[TextContent("Initial message")])
new_message = ChatMessage(role=ChatRole.USER, contents=[TextContent("New message")])
thread = ChatClientAgentThread(messages=[initial_message])
await thread._on_new_messages(new_message) # type: ignore[reportPrivateUsage]
assert thread.chat_messages == [initial_message, new_message]
def test_chat_client_agent_type(chat_client: ChatClient) -> None:
chat_client_agent = ChatClientAgent(chat_client=chat_client)
assert isinstance(chat_client_agent, Agent)
async def test_chat_client_agent_init(chat_client: ChatClient) -> None:
agent_id = str(uuid4())
agent = ChatClientAgent(chat_client=chat_client, id=agent_id, description="Test")
assert agent.id == agent_id
assert agent.name == "UnnamedAgent"
assert agent.description == "Test"
async def test_chat_client_agent_run(chat_client: ChatClient) -> None:
agent = ChatClientAgent(chat_client=chat_client)
result = await agent.run("Hello")
assert result.text == "test response"
async def test_chat_client_agent_run_stream(chat_client: ChatClient) -> None:
agent = ChatClientAgent(chat_client=chat_client)
result = await AgentRunResponse.from_agent_response_generator(agent.run_stream("Hello"))
assert result.text == "test streaming response"
async def test_chat_client_agent_get_new_thread(chat_client: ChatClient) -> None:
agent = ChatClientAgent(chat_client=chat_client)
thread = agent.get_new_thread()
assert isinstance(thread, ChatClientAgentThread)
assert thread.storage_location is None
async def test_chat_client_agent_prepare_thread_and_messages(chat_client: ChatClient) -> None:
agent = ChatClientAgent(chat_client=chat_client)
message = ChatMessage(role=ChatRole.USER, contents=[TextContent("Hello")])
thread = ChatClientAgentThread(messages=[message])
result_thread, result_messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage]
thread=thread,
input_messages="Test",
construct_thread=lambda: ChatClientAgentThread(),
expected_type=ChatClientAgentThread,
)
assert result_thread == thread
assert len(result_messages) == 2
assert result_messages[0] == message
assert result_messages[1].text == "Test"
async def test_chat_client_agent_update_thread_id() -> None:
chat_client = MockChatClient(
mock_response=ChatResponse(
messages=[ChatMessage(role=ChatRole.ASSISTANT, contents=[TextContent("test response")])],
conversation_id="123",
)
)
agent = ChatClientAgent(chat_client=chat_client)
thread = agent.get_new_thread()
result = await agent.run("Hello", thread=thread)
assert result.text == "test response"
assert thread.id == "123"
assert isinstance(thread, ChatClientAgentThread)
assert thread.storage_location == ChatClientAgentThreadType.CONVERSATION_ID
async def test_chat_client_agent_update_thread_messages(chat_client: ChatClient) -> None:
agent = ChatClientAgent(chat_client=chat_client)
thread = agent.get_new_thread()
result = await agent.run("Hello", thread=thread)
assert result.text == "test response"
assert thread.id is None
assert isinstance(thread, ChatClientAgentThread)
assert thread.storage_location == ChatClientAgentThreadType.IN_MEMORY_MESSAGES
assert thread.chat_messages is not None
assert len(thread.chat_messages) == 2
assert thread.chat_messages[0].text == "Hello"
assert thread.chat_messages[1].text == "test response"
async def test_chat_client_agent_update_thread_conversation_id_missing(chat_client: ChatClient) -> None:
agent = ChatClientAgent(chat_client=chat_client)
thread = ChatClientAgentThread(id="123")
with raises(AgentExecutionException, match="Service did not return a valid conversation id"):
agent._update_thread_with_type_and_conversation_id(thread, None) # type: ignore[reportPrivateUsage]