Update ChatClientAgentThread to support both in-memory and service storage. (#70)

* Update ChatClientAgentThread to support both in-memory and service storage.

* Fix typos.

* Address PR comments

* Move code to reusable sections.

* Remove DefaultThreadStorageLocation

* Change thread type naming and make it internal

* Fix bug for streaming case.

* Add additional unit tests

* Add more unit tests to verify agent's thread update behavior.
This commit is contained in:
westey
2025-06-12 14:28:08 +01:00
committed by GitHub
Unverified
parent 2c75f13337
commit d6dc360215
11 changed files with 359 additions and 14 deletions
+4
View File
@@ -53,6 +53,10 @@
<File Path="src/LegacySupport/DiagnosticAttributes/NullableAttributes.cs" />
<File Path="src/LegacySupport/DiagnosticAttributes/README.md" />
</Folder>
<Folder Name="/Solution Items/src/LegacySupport/DiagnosticClasses/">
<File Path="src/LegacySupport/DiagnosticClasses/UnreachableException.cs" />
<File Path="src/LegacySupport/DiagnosticClasses/README.md" />
</Folder>
<Folder Name="/Solution Items/src/LegacySupport/ExperimentalAttribute/">
<File Path="src/LegacySupport/ExperimentalAttribute/ExperimentalAttribute.cs" />
<File Path="src/LegacySupport/ExperimentalAttribute/README.md" />
+4
View File
@@ -1,4 +1,8 @@
<Project>
<ItemGroup Condition="'$(InjectDiagnosticClassesOnLegacy)' == 'true' AND !$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net8.0'))">
<Compile Include="$(MSBuildThisFileDirectory)\..\..\src\LegacySupport\DiagnosticClasses\UnreachableException.cs" LinkBase="LegacySupport\DiagnosticClasses" />
</ItemGroup>
<ItemGroup Condition="'$(InjectDiagnosticAttributesOnLegacy)' == 'true' AND !$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net8.0'))">
<Compile Include="$(MSBuildThisFileDirectory)\..\..\src\LegacySupport\DiagnosticAttributes\*.cs" LinkBase="LegacySupport\DiagnosticAttributes" />
</ItemGroup>
@@ -0,0 +1,9 @@
# DiagnosticClasses
To use this source in your project, add the following to your `.csproj` file:
```xml
<PropertyGroup>
<InjectDiagnosticClassesOnLegacy>true</InjectDiagnosticClassesOnLegacy>
</PropertyGroup>
```
@@ -0,0 +1,46 @@
// Copyright (c) Microsoft. All rights reserved.
// Polyfill for using UnreachableException with .NET Standard 2.0
namespace System.Diagnostics;
#pragma warning disable CA1064 // Exceptions should be public
#pragma warning disable CA1812 // Internal class that is (sometimes) never instantiated.
/// <summary>
/// Exception thrown when the program executes an instruction that was thought to be unreachable.
/// </summary>
internal sealed class UnreachableException : Exception
{
private const string MessageText = "The program executed an instruction that was thought to be unreachable.";
/// <summary>
/// Initializes a new instance of the <see cref="UnreachableException"/> class with the default error message.
/// </summary>
public UnreachableException()
: base(MessageText)
{
}
/// <summary>
/// Initializes a new instance of the <see cref="UnreachableException"/>
/// class with a specified error message.
/// </summary>
/// <param name="message">The error message that explains the reason for the exception.</param>
public UnreachableException(string? message)
: base(message ?? MessageText)
{
}
/// <summary>
/// Initializes a new instance of the <see cref="UnreachableException"/>
/// class with a specified error message and a reference to the inner exception that is the cause of
/// this exception.
/// </summary>
/// <param name="message">The error message that explains the reason for the exception.</param>
/// <param name="innerException">The exception that is the cause of the current exception.</param>
public UnreachableException(string? message, Exception? innerException)
: base(message ?? MessageText, innerException)
{
}
}
@@ -17,6 +17,17 @@ public class AgentThread
/// <summary>
/// Gets or sets the id of the current thread.
/// </summary>
/// <remarks>
/// <para>
/// This id may be null if the thread has no id, or
/// if it represents a service-owned thread but the service
/// has not yet been called to create the thread.
/// </para>
/// <para>
/// The id may also change over time where the <see cref="AgentThread"/>
/// is a proxy to a service owned thread that forks on each agent invocation.
/// </para>
/// </remarks>
public string? Id { get; set; }
/// <summary>
@@ -75,6 +75,10 @@ public sealed class ChatClientAgent : Agent
this._logger.LogAgentChatClientInvokedAgent(nameof(RunAsync), this.Id, agentName, this._chatClientType, messages.Count);
// We can derive the type of supported thread from whether we have a conversation id,
// so let's update it and set the conversation id for the service thread case.
this.UpdateThreadWithTypeAndConversationId(chatClientThread, chatResponse.ConversationId);
// Only notify the thread of new messages if the chatResponse was successful to avoid inconsistent messages state in the thread.
await this.NotifyThreadOfNewMessagesAsync(chatClientThread, messages, cancellationToken).ConfigureAwait(false);
@@ -123,9 +127,6 @@ public sealed class ChatClientAgent : Agent
// Ensure we start the streaming request
var hasUpdates = await responseUpdatesEnumerator.MoveNextAsync().ConfigureAwait(false);
// To avoid inconsistent state we only notify the thread of the input messages if no error occurs after the initial request.
await this.NotifyThreadOfNewMessagesAsync(chatClientThread, messages, cancellationToken).ConfigureAwait(false);
while (hasUpdates)
{
var update = responseUpdatesEnumerator.Current;
@@ -142,6 +143,13 @@ public sealed class ChatClientAgent : Agent
var chatResponse = responseUpdates.ToChatResponse();
var chatResponseMessages = chatResponse.Messages as IReadOnlyCollection<ChatMessage> ?? chatResponse.Messages.ToArray();
// We can derive the type of supported thread from whether we have a conversation id,
// so let's update it and set the conversation id for the service thread case.
this.UpdateThreadWithTypeAndConversationId(chatClientThread, chatResponse.ConversationId);
// To avoid inconsistent state we only notify the thread of the input messages if no error occurs after the initial request.
await this.NotifyThreadOfNewMessagesAsync(chatClientThread, messages, cancellationToken).ConfigureAwait(false);
await this.NotifyThreadOfNewMessagesAsync(chatClientThread, chatResponseMessages, cancellationToken).ConfigureAwait(false);
if (options?.OnIntermediateMessages is not null)
{
@@ -189,9 +197,47 @@ public sealed class ChatClientAgent : Agent
// Add the input messages to the end of thread messages.
threadMessages.AddRange(inputMessages);
// If a user provided two different thread ids, via the thread object and options, we should throw
// since we don't know which one to use.
if (!string.IsNullOrWhiteSpace(chatClientThread.Id) && !string.IsNullOrWhiteSpace(chatOptions?.ConversationId) && chatClientThread.Id != chatOptions.ConversationId)
{
throw new InvalidOperationException(
$"The {nameof(ChatOptions.ConversationId)} provided via {nameof(ChatOptions)} is different to the id of the provided {nameof(AgentThread)}. Only one thread id can be used for a run.");
}
// Only clone and update ChatOptions if we have an id on the thread and we don't have the same one already in ChatOptions.
if (!string.IsNullOrWhiteSpace(chatClientThread.Id) && chatClientThread.Id != chatOptions?.ConversationId)
{
chatOptions = chatOptions is null ? new ChatOptions() : chatOptions.Clone();
chatOptions.ConversationId = chatClientThread.Id;
}
return (chatClientThread, chatOptions, threadMessages);
}
private void UpdateThreadWithTypeAndConversationId(ChatClientAgentThread chatClientThread, string? responseConversationId)
{
// Set the thread's storage location, the first time that we use it.
if (chatClientThread.StorageLocation is null)
{
chatClientThread.StorageLocation = string.IsNullOrWhiteSpace(responseConversationId)
? ChatClientAgentThreadType.InMemoryMessages
: ChatClientAgentThreadType.ConversationId;
}
// If we got a conversation id back from the chat client, it means that the service supports server side thread storage
// so we should capture the id and update the thread with the new id.
if (chatClientThread.StorageLocation == ChatClientAgentThreadType.ConversationId)
{
if (string.IsNullOrWhiteSpace(responseConversationId))
{
throw new InvalidOperationException("Service did not return a valid conversation id when using a service managed thread.");
}
chatClientThread.Id = responseConversationId;
}
}
private void UpdateThreadMessagesWithAgentInstructions(List<ChatMessage> threadMessages, AgentRunOptions? options)
{
if (!string.IsNullOrWhiteSpace(options?.AdditionalInstructions))
@@ -1,10 +1,12 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.Shared.Diagnostics;
namespace Microsoft.Agents;
@@ -15,8 +17,50 @@ public sealed class ChatClientAgentThread : AgentThread, IMessagesRetrievableThr
{
private readonly List<ChatMessage> _chatMessages = [];
/// <inheritdoc/>
/// <summary>
/// Initializes a new instance of the <see cref="ChatClientAgentThread"/> class.
/// </summary>
public ChatClientAgentThread()
{
}
/// <summary>
/// Initializes a new instance of the <see cref="ChatClientAgentThread"/> class.
/// </summary>
/// <param name="id">The id of an existing server side thread to continue.</param>
/// <remarks>
/// This constructor creates a <see cref="ChatClientAgentThread"/> that supports in-service message storage.
/// </remarks>
public ChatClientAgentThread(string id)
{
Throw.IfNullOrWhitespace(id);
this.Id = id;
this.StorageLocation = ChatClientAgentThreadType.ConversationId;
}
/// <summary>
/// Initializes a new instance of the <see cref="ChatClientAgentThread"/> class.
/// </summary>
/// <param name="messages">A set of initial messages to seed the thread with.</param>
/// <remarks>
/// This constructor creates a <see cref="ChatClientAgentThread"/> that supports local in-memory message storage.
/// </remarks>
public ChatClientAgentThread(IEnumerable<ChatMessage> messages)
{
Throw.IfNull(messages);
this._chatMessages.AddRange(messages);
this.StorageLocation = ChatClientAgentThreadType.InMemoryMessages;
}
/// <summary>
/// Gets the location of the thread contents.
/// </summary>
internal ChatClientAgentThreadType? StorageLocation { get; set; }
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc/>
public async IAsyncEnumerable<ChatMessage> GetMessagesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default)
{
foreach (var message in this._chatMessages)
@@ -24,12 +68,26 @@ public sealed class ChatClientAgentThread : AgentThread, IMessagesRetrievableThr
yield return message;
}
}
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc/>
protected override Task OnNewMessagesAsync(IReadOnlyCollection<ChatMessage> newMessages, CancellationToken cancellationToken = default)
{
this._chatMessages.AddRange(newMessages);
switch (this.StorageLocation)
{
case ChatClientAgentThreadType.InMemoryMessages:
this._chatMessages.AddRange(newMessages);
break;
case ChatClientAgentThreadType.ConversationId:
// If the thread messages are stored in the service
// there is nothing to do here, since invoking the
// service should already update the thread.
break;
default:
throw new UnreachableException();
}
return Task.CompletedTask;
}
}
@@ -0,0 +1,19 @@
// Copyright (c) Microsoft. All rights reserved.
namespace Microsoft.Agents;
/// <summary>
/// Defines the different supported storage locations for <see cref="ChatClientAgentThread"/>.
/// </summary>
internal enum ChatClientAgentThreadType
{
/// <summary>
/// Messages are stored in memory inside the thread object.
/// </summary>
InMemoryMessages,
/// <summary>
/// Messages are stored in the service and the thread object just has an id reference the service storage.
/// </summary>
ConversationId
}
@@ -7,6 +7,7 @@
<PropertyGroup>
<InjectSharedThrow>true</InjectSharedThrow>
<InjectDiagnosticClassesOnLegacy>true</InjectDiagnosticClassesOnLegacy>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)' == 'Debug'">
@@ -26,4 +27,8 @@
<Description>Contains the Microsoft Agent Framework core functionality.</Description>
</PropertyGroup>
<ItemGroup>
<InternalsVisibleTo Include="Microsoft.Agents.UnitTests" />
</ItemGroup>
</Project>
@@ -346,6 +346,97 @@ public class ChatClientAgentTests
Assert.Equal(ChatRole.System, capturedMessages[0].Role);
}
/// <summary>
/// Verify that RunAsync does not throw when providing a thread with a ThreadId and a Conversationid
/// via ChatOptions and the two are the same.
/// </summary>
[Fact]
public async Task RunAsyncDoesNotThrowWhenSpecifyingTwoSameThreadIdsAsync()
{
// Arrange
var chatOptions = new ChatOptions { ConversationId = "ConvId" };
Mock<IChatClient> mockService = new();
mockService.Setup(
s => s.GetResponseAsync(
It.IsAny<IEnumerable<ChatMessage>>(),
It.Is<ChatOptions>(opts => opts.ConversationId == "ConvId"),
It.IsAny<CancellationToken>())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")]) { ConversationId = "ConvId" });
ChatClientAgent agent = new(mockService.Object, new() { Instructions = "test instructions" });
ChatClientAgentThread thread = new("ConvId");
// Act & Assert
await agent.RunAsync([new(ChatRole.User, "test")], thread, chatOptions: chatOptions);
}
/// <summary>
/// Verify that RunAsync throws when providing a thread with a ThreadId and a Conversationid
/// via ChatOptions and the two are different.
/// </summary>
[Fact]
public async Task RunAsyncThrowsWhenSpecifyingTwoDifferentThreadIdsAsync()
{
// Arrange
var chatOptions = new ChatOptions { ConversationId = "ConvId" };
Mock<IChatClient> mockService = new();
ChatClientAgent agent = new(mockService.Object, new() { Instructions = "test instructions" });
ChatClientAgentThread thread = new("ThreadId");
// Act & Assert
await Assert.ThrowsAsync<InvalidOperationException>(() => agent.RunAsync([new(ChatRole.User, "test")], thread, chatOptions: chatOptions));
}
/// <summary>
/// Verify that RunAsync clones the ChatOptions when providing a thread with a ThreadId and a ChatOptions.
/// </summary>
[Fact]
public async Task RunAsyncClonesChatOptionsToAddThreadIdAsync()
{
// Arrange
var chatOptions = new ChatOptions { MaxOutputTokens = 100 };
Mock<IChatClient> mockService = new();
mockService.Setup(
s => s.GetResponseAsync(
It.IsAny<IEnumerable<ChatMessage>>(),
It.Is<ChatOptions>(opts => opts.MaxOutputTokens == 100 && opts.ConversationId == "ConvId"),
It.IsAny<CancellationToken>())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")]) { ConversationId = "ConvId" });
ChatClientAgent agent = new(mockService.Object, new() { Instructions = "test instructions" });
ChatClientAgentThread thread = new("ConvId");
// Act
await agent.RunAsync([new(ChatRole.User, "test")], thread, chatOptions: chatOptions);
// Assert
Assert.Null(chatOptions.ConversationId);
}
/// <summary>
/// Verify that RunAsync throws if a thread is provided that uses a conversation id already, but the service does not return one on invoke.
/// </summary>
[Fact]
public async Task RunAsyncThrowsForMissingConversationIdWithConversationIdThreadAsync()
{
// Arrange
Mock<IChatClient> mockService = new();
mockService.Setup(
s => s.GetResponseAsync(
It.IsAny<IEnumerable<ChatMessage>>(),
It.IsAny<ChatOptions>(),
It.IsAny<CancellationToken>())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")]));
ChatClientAgent agent = new(mockService.Object, new() { Instructions = "test instructions" });
ChatClientAgentThread thread = new("ConvId");
// Act & Assert
await Assert.ThrowsAsync<InvalidOperationException>(() => agent.RunAsync([new(ChatRole.User, "test")], thread));
}
#region Property Override Tests
/// <summary>
@@ -133,7 +133,45 @@ public class ChatClientAgentThreadTests
var thread = new ChatClientAgentThread();
// Assert
Assert.Null(thread.Id); // Id should be null until created
Assert.Null(thread.Id); // Id should be null until created on first use.
Assert.Null(thread.StorageLocation); // StorageLocation should be null until first use
}
/// <summary>
/// Verify that <see cref="ChatClientAgentThread"/> initializes with expected default values.
/// </summary>
[Fact]
public async Task VerifyThreadWithMessagesInitialStateAsync()
{
// Arrange
var message = new ChatMessage(ChatRole.User, "Hello");
// Act
var thread = new ChatClientAgentThread([message]);
// Assert
Assert.Null(thread.Id); // Id should be null when we add messages, since it's a local thread.
Assert.Equal(ChatClientAgentThreadType.InMemoryMessages, thread.StorageLocation); // StorageLocation should be set to local since we are adding messages already.
var messages = await thread.GetMessagesAsync().ToListAsync();
Assert.Contains(message, messages);
}
/// <summary>
/// Verify that <see cref="ChatClientAgentThread"/> initializes with expected default values.
/// </summary>
[Fact]
public async Task VerifyThreadWithIdInitialStateAsync()
{
// Act
var thread = new ChatClientAgentThread("TestConvId");
// Assert
Assert.Equal("TestConvId", thread.Id);
Assert.Equal(ChatClientAgentThreadType.ConversationId, thread.StorageLocation);
var messages = await thread.GetMessagesAsync().ToListAsync();
Assert.Empty(messages);
}
#region Core Override Method Tests
@@ -160,8 +198,9 @@ public class ChatClientAgentThreadTests
// Assert
Assert.NotNull(thread);
Assert.IsType<ChatClientAgentThread>(thread);
Assert.Null(thread.Id); // Id should be null until the thread is actually used
var chatClientAgentThread = Assert.IsType<ChatClientAgentThread>(thread);
Assert.Null(thread.Id); // Id should be null until created on first use.
Assert.Null(chatClientAgentThread.StorageLocation); // StorageLocation should be null until first use
}
/// <summary>
@@ -187,8 +226,10 @@ public class ChatClientAgentThreadTests
/// <summary>
/// Verify that messages are properly stored and retrieved through the thread lifecycle.
/// </summary>
[Fact]
public async Task ThreadLifecycleStoresAndRetrievesMessagesAsync()
[Theory]
[InlineData(null, true)]
[InlineData("TestConvid", false)]
public async Task ThreadLifecycleStoresAndRetrievesMessagesAsync(string? responseConversationId, bool messagesStored)
{
// Arrange
var userMessage = new ChatMessage(ChatRole.User, "Hello");
@@ -200,7 +241,7 @@ public class ChatClientAgentThreadTests
It.IsAny<IEnumerable<ChatMessage>>(),
It.IsAny<ChatOptions>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync(new ChatResponse([assistantMessage]));
.ReturnsAsync(new ChatResponse([assistantMessage]) { ConversationId = responseConversationId });
var agent = new ChatClientAgent(mockChatClient.Object, new() { Instructions = "Test instructions" });
@@ -218,9 +259,20 @@ public class ChatClientAgentThreadTests
}
// Assert
Assert.Equal(2, retrievedMessages.Count);
Assert.Contains(retrievedMessages, m => m.Text == "Hello" && m.Role == ChatRole.User);
Assert.Contains(retrievedMessages, m => m.Text == "Hi there!" && m.Role == ChatRole.Assistant);
Assert.Equal(messagesStored ? 2 : 0, retrievedMessages.Count);
if (messagesStored)
{
Assert.Contains(retrievedMessages, m => m.Text == "Hello" && m.Role == ChatRole.User);
Assert.Contains(retrievedMessages, m => m.Text == "Hi there!" && m.Role == ChatRole.Assistant);
}
var chatClientAgentThread = Assert.IsType<ChatClientAgentThread>(thread);
Assert.Equal(responseConversationId, thread.Id); // Id should match the returned conversation id.
Assert.Equal(
messagesStored
? ChatClientAgentThreadType.InMemoryMessages
: ChatClientAgentThreadType.ConversationId,
chatClientAgentThread.StorageLocation); // StorageLocation should be based on whether we got back a conversation id
}
/// <summary>