mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Added author name logic to ChatClientAgent (#313)
This commit is contained in:
committed by
GitHub
Unverified
parent
f5b35d8403
commit
544d2ba48f
@@ -491,6 +491,7 @@ class ChatClientAgent(AgentBase):
|
||||
"""
|
||||
input_messages = self._normalize_messages(messages)
|
||||
thread, thread_messages = await self._prepare_thread_and_messages(thread=thread, input_messages=input_messages)
|
||||
agent_name = self._get_agent_name()
|
||||
|
||||
response = await self.chat_client.get_response(
|
||||
messages=thread_messages,
|
||||
@@ -519,6 +520,11 @@ class ChatClientAgent(AgentBase):
|
||||
|
||||
self._update_thread_with_type_and_conversation_id(thread, response.conversation_id)
|
||||
|
||||
# Ensure that the author name is set for each message in the response.
|
||||
for message in response.messages:
|
||||
if message.author_name is None:
|
||||
message.author_name = agent_name
|
||||
|
||||
# Only notify the thread of new messages if the chatResponse was successful
|
||||
# to avoid inconsistent messages state in the thread.
|
||||
await self._notify_thread_of_new_messages(thread, input_messages)
|
||||
@@ -595,6 +601,7 @@ class ChatClientAgent(AgentBase):
|
||||
"""
|
||||
input_messages = self._normalize_messages(messages)
|
||||
thread, thread_messages = await self._prepare_thread_and_messages(thread=thread, input_messages=input_messages)
|
||||
agent_name = self._get_agent_name()
|
||||
response_updates: list[ChatResponseUpdate] = []
|
||||
|
||||
async for update in self.chat_client.get_streaming_response(
|
||||
@@ -622,6 +629,10 @@ class ChatClientAgent(AgentBase):
|
||||
**kwargs,
|
||||
):
|
||||
response_updates.append(update)
|
||||
|
||||
if update.author_name is None:
|
||||
update.author_name = agent_name
|
||||
|
||||
yield AgentRunResponseUpdate(
|
||||
contents=update.contents,
|
||||
role=update.role,
|
||||
@@ -720,3 +731,6 @@ class ChatClientAgent(AgentBase):
|
||||
return [messages]
|
||||
|
||||
return [ChatMessage(role=ChatRole.USER, text=msg) if isinstance(msg, str) else msg for msg in messages]
|
||||
|
||||
def _get_agent_name(self) -> str:
|
||||
return self.name or "UnnamedAgent"
|
||||
|
||||
@@ -43,6 +43,11 @@ class MockAgent(AIAgent):
|
||||
"""Returns the name of the agent."""
|
||||
return "Name"
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
"""Returns the name of the agent."""
|
||||
return "Display Name"
|
||||
|
||||
@property
|
||||
def description(self) -> str | None:
|
||||
return "Description"
|
||||
@@ -243,7 +248,7 @@ async def test_chat_client_agent_prepare_thread_and_messages(chat_client: ChatCl
|
||||
message = ChatMessage(role=ChatRole.USER, text="Hello")
|
||||
thread = ChatClientAgentThread(messages=[message])
|
||||
|
||||
result_thread = agent._validate_or_create_thread_type(
|
||||
result_thread = agent._validate_or_create_thread_type( # type: ignore[reportPrivateUsage]
|
||||
thread, lambda: ChatClientAgentThread(), expected_type=ChatClientAgentThread
|
||||
) # type: ignore[reportPrivateUsage]
|
||||
|
||||
@@ -264,7 +269,7 @@ async def test_chat_client_agent_validate_or_create_thread(chat_client: ChatClie
|
||||
agent = ChatClientAgent(chat_client=chat_client)
|
||||
thread = None
|
||||
|
||||
result_thread = agent._validate_or_create_thread_type(
|
||||
result_thread = agent._validate_or_create_thread_type( # type: ignore[reportPrivateUsage]
|
||||
thread, lambda: ChatClientAgentThread(), expected_type=ChatClientAgentThread
|
||||
) # type: ignore[reportPrivateUsage]
|
||||
|
||||
@@ -313,3 +318,36 @@ async def test_chat_client_agent_update_thread_conversation_id_missing(chat_clie
|
||||
|
||||
with raises(AgentExecutionException, match="Service did not return a valid conversation id"):
|
||||
agent._update_thread_with_type_and_conversation_id(thread, None) # type: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_chat_client_agent_default_author_name(chat_client: ChatClient) -> None:
|
||||
# Name is not specified here, so default name should be used
|
||||
agent = ChatClientAgent(chat_client=chat_client)
|
||||
|
||||
result = await agent.run("Hello")
|
||||
assert result.text == "test response"
|
||||
assert result.messages[0].author_name == "UnnamedAgent"
|
||||
|
||||
|
||||
async def test_chat_client_agent_author_name_as_agent_name(chat_client: ChatClient) -> None:
|
||||
# Name is specified here, so it should be used as author name
|
||||
agent = ChatClientAgent(chat_client=chat_client, name="TestAgent")
|
||||
|
||||
result = await agent.run("Hello")
|
||||
assert result.text == "test response"
|
||||
assert result.messages[0].author_name == "TestAgent"
|
||||
|
||||
|
||||
async def test_chat_client_agent_author_name_is_used_from_response() -> None:
|
||||
chat_client = MockChatClient(
|
||||
mock_response=ChatResponse(
|
||||
messages=[
|
||||
ChatMessage(role=ChatRole.ASSISTANT, contents=[TextContent("test response")], author_name="TestAuthor")
|
||||
]
|
||||
)
|
||||
)
|
||||
agent = ChatClientAgent(chat_client=chat_client)
|
||||
|
||||
result = await agent.run("Hello")
|
||||
assert result.text == "test response"
|
||||
assert result.messages[0].author_name == "TestAuthor"
|
||||
|
||||
Reference in New Issue
Block a user