.Net: Add support for 3rd party thread storage and thread serialization (#203)

* Add thread storage and serialization POC

* Switch to using JsonElement and add unit tests

* Add additional unit tests.

* Exclude private debugger properties from CodeCoverage.

* Rename IChatMessagesStorable to IChatMessageStore

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Improve xml doc.

* Update the message storing thread to always use external store for both local and remote storage.

* Remove threadid from the IChatMessageStore interface, since the store should own the thread id itself, if it requires one.

* Switch GetMessages to IEnumerable

* Address pr comments.

* Make jsonserializer options default consistent on DeserializeThreadAsync

* Move message storing thread functionality into AgentThread and simplify AgentThread behavior.

* Remove embedding generation from VectorStore chat history sample.

* Remove unecessary code and fix formatting.

* Make GetNewThread and DeserializeThread virtual with default implementations.
Remove unsued json utilities.

* Fix formatting

* Remove problem test.

* Add more unit tests

* Remove unused using clause.

* Address pr feedback.

* Address PR comments.

* Make InMemory store internal

* Switch InMemoryChatMessageStore to implement IList instead of inheriting from List.

* Rename store deserialize param.

* Update serialization based on PR comments.

* Remove confusing comment.

* Address Deserialization PR comments in the same way as Serialization

* Add State to IChatMessageStore Serialize and Deserialize names.
Make Thread Deserialize internal.
Make AgentThread type switching fobidden.

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Chris <66376200+crickman@users.noreply.github.com>
This commit is contained in:
westey
2025-08-05 18:24:25 +01:00
committed by GitHub
Unverified
parent c1d306ec95
commit ff3e13c2aa
38 changed files with 1143 additions and 1427 deletions
+2
View File
@@ -55,6 +55,8 @@
<PackageVersion Include="Microsoft.Extensions.Logging" Version="9.0.7" />
<PackageVersion Include="Microsoft.Extensions.Logging.Abstractions" Version="9.0.7" />
<PackageVersion Include="Microsoft.Extensions.Logging.Testing" Version="9.0.7" />
<!-- Vector Stores -->
<PackageVersion Include="Microsoft.SemanticKernel.Connectors.InMemory" Version="1.61.0-preview" />
<!-- Agent SDKs -->
<PackageVersion Include="Microsoft.Agents.CopilotStudio.Client" Version="1.1.151" />
<!-- Identity -->
+2 -2
View File
@@ -167,7 +167,7 @@ public class AgentSample(ITestOutputHelper output) : BaseSample(output)
// If a thread is provided, delete it as well.
if (thread is not null)
{
await persistentAgentsClient.Threads.DeleteThreadAsync(thread.Id, cancellationToken);
await persistentAgentsClient.Threads.DeleteThreadAsync(thread.ConversationId, cancellationToken);
}
}
@@ -183,7 +183,7 @@ public class AgentSample(ITestOutputHelper output) : BaseSample(output)
// If a thread is provided, delete it as well.
if (thread is not null)
{
await assistantClient.DeleteThreadAsync(thread.Id, cancellationToken);
await assistantClient.DeleteThreadAsync(thread.ConversationId, cancellationToken);
}
}
@@ -27,6 +27,7 @@
<PackageReference Include="Microsoft.Extensions.Configuration.Json" />
<PackageReference Include="Microsoft.Extensions.Configuration.UserSecrets" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" />
<PackageReference Include="Microsoft.SemanticKernel.Connectors.InMemory" />
<PackageReference Include="System.Diagnostics.DiagnosticSource" />
<PackageReference Include="OpenTelemetry" />
<PackageReference Include="OpenTelemetry.Exporter.Console" />
@@ -51,7 +51,7 @@ public sealed class AIAgent_With_AzureAIAgentsPersistent(ITestOutputHelper outpu
}
// Cleanup
await persistentAgentsClient.Threads.DeleteThreadAsync(thread.Id);
await persistentAgentsClient.Threads.DeleteThreadAsync(thread.ConversationId);
await persistentAgentsClient.Administration.DeleteAgentAsync(agent.Id);
}
@@ -85,7 +85,7 @@ public sealed class AIAgent_With_AzureAIAgentsPersistent(ITestOutputHelper outpu
}
// Cleanup
await persistentAgentsClient.Threads.DeleteThreadAsync(thread.Id);
await persistentAgentsClient.Threads.DeleteThreadAsync(thread.ConversationId);
await persistentAgentsClient.Administration.DeleteAgentAsync(agent.Id);
}
}
@@ -52,7 +52,7 @@ public sealed class AIAgent_With_OpenAIAssistant(ITestOutputHelper output) : Age
// Cleanup
var assistantClient = openAIClient.GetAssistantClient();
await assistantClient.DeleteThreadAsync(thread.Id);
await assistantClient.DeleteThreadAsync(thread.ConversationId);
await assistantClient.DeleteAssistantAsync(agent.Id);
}
}
@@ -0,0 +1,68 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Text.Json;
using Microsoft.Extensions.AI.Agents;
namespace Steps;
/// <summary>
/// Demonstrates how to suspend and resume a thread with the <see cref="ChatClientAgent"/>.
/// </summary>
public sealed class Step08_ChatClientAgent_SuspendResumeThread(ITestOutputHelper output) : AgentSample(output)
{
private const string JokerName = "Joker";
private const string JokerInstructions = "You are good at telling jokes.";
/// <summary>
/// Demonstrate the usage of <see cref="ChatClientAgent"/> where a thread is suspended.
/// The thread is serialized and can be stored to a database, file, or any other storage mechanism,
/// and then deserialized later to resume the conversation with the agent.
/// </summary>
[Theory]
[InlineData(ChatClientProviders.AzureAIAgentsPersistent)]
[InlineData(ChatClientProviders.AzureOpenAI)]
[InlineData(ChatClientProviders.OpenAIAssistant)]
[InlineData(ChatClientProviders.OpenAIResponses_InMemoryMessageThread)]
[InlineData(ChatClientProviders.OpenAIResponses_ConversationIdThread)]
public async Task SuspendResumeThread(ChatClientProviders provider)
{
// Define the options for the chat client agent.
var agentOptions = new ChatClientAgentOptions
{
Name = JokerName,
Instructions = JokerInstructions,
// Get chat options based on the store type, if needed.
ChatOptions = base.GetChatOptions(provider),
};
// Create the server-side agent Id when applicable (depending on the provider).
agentOptions.Id = await base.AgentCreateAsync(provider, agentOptions);
// Get the chat client to use for the agent.
using var chatClient = base.GetChatClient(provider, agentOptions);
// Define the agent
var agent = new ChatClientAgent(chatClient, agentOptions);
// Start a new thread for the agent conversation.
AgentThread thread = agent.GetNewThread();
// Respond to user input
Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", thread));
// Serialize the thread state, so it can be stored for later use.
JsonElement serializedThread = await thread.SerializeAsync();
// The thread can now be saved to a database, file, or any other storage mechanism
// and loaded again later.
// Deserialize the thread state after loading from storage.
AgentThread resumedThread = await agent.DeserializeThreadAsync(serializedThread);
Console.WriteLine(await agent.RunAsync("Now tell the same joke in the voice of a pirate, and add some emojis to the joke.", resumedThread));
// Clean up the server-side agent and thread after use when applicable (depending on the provider).
await base.AgentCleanUpAsync(provider, agent, thread);
}
}
@@ -0,0 +1,156 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Text.Json;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.AI.Agents;
using Microsoft.Extensions.VectorData;
using Microsoft.SemanticKernel.Connectors.InMemory;
namespace Steps;
/// <summary>
/// Demonstrates how to store the chat history of a thread in a 3rd party store when using <see cref="ChatClientAgent"/>.
/// </summary>
public sealed class Step09_ChatClientAgent_3rdPartyThreadStorage(ITestOutputHelper output) : AgentSample(output)
{
private const string JokerName = "Joker";
private const string JokerInstructions = "You are good at telling jokes.";
/// <summary>
/// Demonstrate storage of the chat history of a thread in a 3rd party store when using <see cref="ChatClientAgent"/>.
/// </summary>
/// <remarks>
/// Note that this is only supported for services that do not already store the chat history in their own service.
/// </remarks>
[Theory]
[InlineData(ChatClientProviders.AzureOpenAI)]
[InlineData(ChatClientProviders.OpenAIResponses_InMemoryMessageThread)]
public async Task ThirdPartyStorageThread(ChatClientProviders provider)
{
var inMemoryVectorStore = new InMemoryVectorStore();
// Define the options for the chat client agent.
var agentOptions = new ChatClientAgentOptions
{
Name = JokerName,
Instructions = JokerInstructions,
// Get chat options based on the store type, if needed.
ChatOptions = base.GetChatOptions(provider),
ChatMessageStoreFactory = () =>
{
// Create a new chat message store for this agent that stores the messages in a vector store.
// Each thread must get its own copy of the VectorChatMessageStore, since the store
// also contains the id that the thread is stored under.
return new VectorChatMessageStore(inMemoryVectorStore);
}
};
// Get the chat client to use for the agent.
using var chatClient = base.GetChatClient(provider, agentOptions);
// Define the agent
var agent = new ChatClientAgent(chatClient, agentOptions);
// Start a new thread for the agent conversation.
AgentThread thread = agent.GetNewThread();
// Respond to user input
Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", thread));
// Serialize the thread state, so it can be stored for later use.
// Since the chat history is stored in the vector store, the serialized there
// only contains the guid that the messages are stored under in the vector store.
JsonElement serializedThread = await thread.SerializeAsync();
// The serialized thread can now be saved to a database, file, or any other storage mechanism
// and loaded again later.
// Deserialize the thread state after loading from storage.
AgentThread resumedThread = await agent.DeserializeThreadAsync(serializedThread);
Console.WriteLine(await agent.RunAsync("Now tell the same joke in the voice of a pirate, and add some emojis to the joke.", resumedThread));
}
/// <summary>
/// A sample implementation of <see cref="IChatMessageStore"/> that stores chat messages in a vector store.
/// </summary>
/// <param name="vectorStore">The vector store to store the messages in.</param>
private sealed class VectorChatMessageStore(VectorStore vectorStore) : IChatMessageStore
{
private string? _threadId;
public string? ThreadId => this._threadId;
public async Task AddMessagesAsync(IReadOnlyCollection<ChatMessage> messages, CancellationToken cancellationToken)
{
this._threadId ??= Guid.NewGuid().ToString();
var collection = vectorStore.GetCollection<string, ChatHistoryItem>("ChatHistory");
await collection.EnsureCollectionExistsAsync(cancellationToken);
await collection.UpsertAsync(messages.Select(x => new ChatHistoryItem()
{
Key = this._threadId + x.MessageId,
Timestamp = DateTimeOffset.UtcNow,
ThreadId = this._threadId,
SerializedMessage = JsonSerializer.Serialize(x),
MessageText = x.Text
}), cancellationToken);
}
public async Task<IEnumerable<ChatMessage>> GetMessagesAsync(CancellationToken cancellationToken)
{
var collection = vectorStore.GetCollection<string, ChatHistoryItem>("ChatHistory");
await collection.EnsureCollectionExistsAsync(cancellationToken);
var records = await collection
.GetAsync(
x => x.ThreadId == this._threadId, 10,
new() { OrderBy = x => x.Descending(y => y.Timestamp) },
cancellationToken)
.ToListAsync(cancellationToken);
var messages = records
.Select(x => JsonSerializer.Deserialize<ChatMessage>(x.SerializedMessage!)!)
.ToList();
messages.Reverse();
return messages;
}
public ValueTask<JsonElement?> SerializeStateAsync(JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
{
// We have to serialize the thread id, so that on deserialization we can retrieve the messages using the same thread id.
return new ValueTask<JsonElement?>(JsonSerializer.SerializeToElement(this._threadId));
}
public ValueTask DeserializeStateAsync(JsonElement? serializedStoreState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
{
// Here we can deserialize the thread id so that we can access the same messages as before the suspension.
this._threadId = JsonSerializer.Deserialize<string>((JsonElement)serializedStoreState!);
return new ValueTask();
}
/// <summary>
/// The data structure used to store chat history items in the vector store.
/// </summary>
private sealed class ChatHistoryItem
{
[VectorStoreKey]
public string? Key { get; set; }
[VectorStoreData]
public string? ThreadId { get; set; }
[VectorStoreData]
public DateTimeOffset? Timestamp { get; set; }
[VectorStoreData]
public string? SerializedMessage { get; set; }
[VectorStoreData]
public string? MessageText { get; set; }
}
}
}
@@ -14,7 +14,7 @@ internal sealed class ChatClientAgentActor(
ILogger<ChatClientAgentActor> logger) : IActor
{
private string? _etag;
private ChatClientAgentThread? _thread;
private AgentThread? _thread;
public ValueTask DisposeAsync() => default;
@@ -34,11 +34,11 @@ internal sealed class ChatClientAgentActor(
if (threadResult.Value is { } threadJson)
{
// Deserialize the thread state if it exist
this._thread = threadJson.Deserialize(ChatClientAgentActorJsonContext.Default.ChatClientAgentThread);
this._thread = await agent.DeserializeThreadAsync(threadJson, cancellationToken: cancellationToken).ConfigureAwait(false);
}
}
this._thread ??= agent.GetNewThread() as ChatClientAgentThread ?? throw new InvalidOperationException("The agent did not provide a valid thread instance.");
this._thread ??= agent.GetNewThread();
Log.ThreadStateRestored(logger, context.ActorId.ToString(), response.Results[0] is GetValueResult { Value: not null });
while (!cancellationToken.IsCancellationRequested)
@@ -2,7 +2,6 @@
using System.Text.Json;
using System.Text.Json.Serialization;
using Microsoft.Extensions.AI.Agents;
namespace HelloHttpApi.ApiService;
@@ -14,6 +13,5 @@ namespace HelloHttpApi.ApiService;
UseStringEnumConverter = true,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
WriteIndented = false)]
[JsonSerializable(typeof(ChatClientAgentThread))]
[JsonSerializable(typeof(ChatClientAgentRunRequest))]
internal sealed partial class ChatClientAgentActorJsonContext : JsonSerializerContext;
@@ -70,13 +70,13 @@ public abstract partial class OrchestratingAgent : AIAgent
if (thread is not null)
{
if (thread is not IMessagesRetrievableThread retrievableThread)
if (thread.MessageStore is null)
{
throw new InvalidOperationException($"The thread type '{thread.GetType().Name}' is not supported by this agent. Use {nameof(GetNewThread)} to create a thread when needed.");
throw new InvalidOperationException("An agent service managed thread is not supported by this agent.");
}
List<ChatMessage> messagesList = [];
await foreach (var threadMessage in retrievableThread.GetMessagesAsync(cancellationToken).ConfigureAwait(false))
await foreach (var threadMessage in thread.GetMessagesAsync(cancellationToken).ConfigureAwait(false))
{
messagesList.Add(threadMessage);
}
@@ -101,9 +101,6 @@ public abstract partial class OrchestratingAgent : AIAgent
}
}
/// <inheritdoc />
public sealed override AgentThread GetNewThread() => new ChatClientAgentThread();
/// <summary>
/// Initiates processing of the orchestration.
/// </summary>
@@ -207,10 +204,6 @@ public abstract partial class OrchestratingAgent : AIAgent
return response;
}
/// <inheritdoc />
protected sealed override TThreadType ValidateOrCreateThreadType<TThreadType>(AgentThread? thread, Func<TThreadType> constructThread) =>
base.ValidateOrCreateThreadType(thread, constructThread);
/// <summary>Writes the specified checkpoint state to the runtime.</summary>
/// <param name="state">The state to persist.</param>
/// <param name="context">The context for the orchestrating operation.</param>
@@ -2,6 +2,7 @@
using System;
using System.Collections.Generic;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
@@ -50,7 +51,21 @@ public abstract class AIAgent
/// If the thread needs to be created via a service call it would be created on first use.
/// </para>
/// </remarks>
public abstract AgentThread GetNewThread();
public virtual AgentThread GetNewThread() => new();
/// <summary>
/// Deserialize the thread from JSON.
/// </summary>
/// <param name="serializedThread">The <see cref="JsonElement"/> representing the thread state.</param>
/// <param name="jsonSerializerOptions">Optional <see cref="JsonSerializerOptions"/> to use for deserializing the thread state.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>The deserialized <see cref="AgentThread"/> instance.</returns>
public async ValueTask<AgentThread> DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
{
var thread = this.GetNewThread();
await thread.DeserializeAsync(serializedThread, jsonSerializerOptions, cancellationToken).ConfigureAwait(false);
return thread;
}
/// <summary>
/// Run the agent with no message assuming that all required instructions are already provided to the agent or on the thread.
@@ -192,30 +207,6 @@ public abstract class AIAgent
AgentRunOptions? options = null,
CancellationToken cancellationToken = default);
/// <summary>
/// Checks that the thread is of the expected type, or if null, creates the default thread type.
/// </summary>
/// <typeparam name="TThreadType">The expected type of the thead.</typeparam>
/// <param name="thread">The thread to create if it's null and validate its type if not null.</param>
/// <param name="constructThread">A callback to use to construct the thread if it's null.</param>
/// <returns>An async task that completes once all update are complete.</returns>
protected virtual TThreadType ValidateOrCreateThreadType<TThreadType>(
AgentThread? thread,
Func<TThreadType> constructThread)
where TThreadType : AgentThread
{
Throw.IfNull(constructThread);
thread ??= constructThread();
if (thread is not TThreadType concreteThreadType)
{
throw new NotSupportedException($"{this.GetType().Name} currently only supports agent threads of type {typeof(TThreadType).Name}.");
}
return concreteThreadType;
}
/// <summary>
/// Notfiy the given thread that new messages are available.
/// </summary>
@@ -57,7 +57,8 @@ public static partial class AgentAbstractionsJsonUtilities
[JsonSerializable(typeof(AgentRunResponse[]))]
[JsonSerializable(typeof(AgentRunResponseUpdate))]
[JsonSerializable(typeof(AgentRunResponseUpdate[]))]
[JsonSerializable(typeof(AgentThread))]
[JsonSerializable(typeof(AgentThread.ThreadState))]
[JsonSerializable(typeof(InMemoryChatMessageStore.StoreState))]
[ExcludeFromCodeCoverage]
internal sealed partial class JsonContext : JsonSerializerContext;
@@ -141,9 +141,11 @@ public class AgentRunResponseUpdate
/// <summary>Gets a <see cref="AIContent"/> object to display in the debugger display.</summary>
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
[ExcludeFromCodeCoverage]
private AIContent? ContentForDebuggerDisplay => this._contents is { Count: > 0 } ? this._contents[0] : null;
/// <summary>Gets an indication for the debugger display of whether there's more content.</summary>
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
[ExcludeFromCodeCoverage]
private string EllipsesForDebuggerDisplay => this._contents is { Count: > 1 } ? ", ..." : string.Empty;
}
@@ -2,8 +2,12 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
namespace Microsoft.Extensions.AI.Agents;
@@ -13,21 +17,114 @@ namespace Microsoft.Extensions.AI.Agents;
/// </summary>
public class AgentThread
{
private string? _conversationId;
private IChatMessageStore? _messageStore;
/// <summary>
/// Gets or sets the id of the current thread.
/// Initializes a new instance of the <see cref="AgentThread"/> class.
/// </summary>
public AgentThread()
{
}
/// <summary>
/// Gets or sets the id of the current thread to support cases where the thread is owned by the agent service.
/// </summary>
/// <remarks>
/// <para>
/// This id may be null if the thread has no id, or
/// if it represents a service-owned thread but the service
/// has not yet been called to create the thread.
/// Note that either <see cref="ConversationId"/> or <see cref="MessageStore "/> may be set, but not both.
/// If <see cref="MessageStore "/> is not null, and <see cref="ConversationId"/> is set, <see cref="MessageStore "/>
/// will be reverted to null, and vice versa.
/// </para>
/// <para>
/// The id may also change over time where the <see cref="AgentThread"/>
/// is a proxy to a service owned thread that forks on each agent invocation.
/// This property may be null in the following cases:
/// <list type="bullet">
/// <item>The thread stores messages via the <see cref="IChatMessageStore"/> and not in the agent service.</item>
/// <item>This thread object is new and a server managed thread has not yet been created in the agent service.</item>
/// </list>
/// </para>
/// <para>
/// The id may also change over time where the the id is pointing at a
/// agent service managed thread, and the default behavior of a service is
/// to fork the thread with each iteration.
/// </para>
/// </remarks>
public string? Id { get; set; }
public string? ConversationId
{
get { return this._conversationId; }
set
{
if (string.IsNullOrWhiteSpace(this._conversationId) && string.IsNullOrWhiteSpace(value))
{
return;
}
if (this._messageStore is not null)
{
// If we have a message store already, we shouldn't switch the thread to use a conversation id
// since it means that the thread contents will essentially be deleted, and the thread will not work
// with the original agent anymore.
throw new InvalidOperationException("Only the ConversationId or MessageStore may be set, but not both and switching from one to another is not supported.");
}
this._conversationId = Throw.IfNullOrWhitespace(value);
}
}
/// <summary>
/// Gets or sets the <see cref="IChatMessageStore"/> used by this thread, for cases where messages should be stored in a custom location.
/// </summary>
/// <remarks>
/// <para>
/// Note that either <see cref="ConversationId"/> or <see cref="MessageStore "/> may be set, but not both.
/// If <see cref="ConversationId"/> is not null, and <see cref="MessageStore "/> is set, <see cref="ConversationId"/>
/// will be reverted to null, and vice versa.
/// </para>
/// <para>
/// This property may be null in the following cases:
/// <list type="bullet">
/// <item>The thread stores messages in the agent service and just has an id to the remove thread, instead of in an <see cref="IChatMessageStore"/>.</item>
/// <item>This thread object is new it is not yet clear whether it will be backed by a server managed thread or an <see cref="IChatMessageStore"/>.</item>
/// </list>
/// </para>
/// </remarks>
public IChatMessageStore? MessageStore
{
get { return this._messageStore; }
set
{
if (this._messageStore is null && value is null)
{
return;
}
if (!string.IsNullOrWhiteSpace(this._conversationId))
{
// If we have a conversation id already, we shouldn't switch the thread to use a message store
// since it means that the thread will not work with the original agent anymore.
throw new InvalidOperationException("Only the ConversationId or MessageStore may be set, but not both and switching from one to another is not supported.");
}
this._messageStore = Throw.IfNull(value);
}
}
/// <summary>
/// Retrieves any messages stored in the <see cref="IChatMessageStore"/> of the thread, otherwise returns an empty collection.
/// </summary>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>The messages from the <see cref="IChatMessageStore"/> in ascending chronological order, with the oldest message first.</returns>
public virtual async IAsyncEnumerable<ChatMessage> GetMessagesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if (this._messageStore is not null)
{
var messages = await this._messageStore!.GetMessagesAsync(cancellationToken).ConfigureAwait(false);
foreach (var message in messages)
{
yield return message;
}
}
}
/// <summary>
/// This method is called when new messages have been contributed to the chat by any participant.
@@ -39,8 +136,92 @@ public class AgentThread
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A task that completes when the context has been updated.</returns>
/// <exception cref="InvalidOperationException">The thread has been deleted.</exception>
protected internal virtual Task OnNewMessagesAsync(IReadOnlyCollection<ChatMessage> newMessages, CancellationToken cancellationToken = default)
protected internal virtual async Task OnNewMessagesAsync(IReadOnlyCollection<ChatMessage> newMessages, CancellationToken cancellationToken = default)
{
return Task.CompletedTask;
switch (this)
{
case { ConversationId: not null }:
// If the thread messages are stored in the service
// there is nothing to do here, since invoking the
// service should already update the thread.
break;
case { MessageStore: null }:
// If there is no conversation id, and no store we can createa a default in memory store and add messages to it.
this._messageStore = new InMemoryChatMessageStore();
await this._messageStore!.AddMessagesAsync(newMessages, cancellationToken).ConfigureAwait(false);
break;
case { MessageStore: not null }:
// If a store has been provided, we need to add the messages to the store.
await this._messageStore!.AddMessagesAsync(newMessages, cancellationToken).ConfigureAwait(false);
break;
default:
throw new UnreachableException();
}
}
/// <summary>
/// Deserializes the state contained in the provided <see cref="JsonElement"/> into the properties on this thread.
/// </summary>
/// <param name="serializedThread">A <see cref="JsonElement"/> representing the state of the thread.</param>
/// <param name="jsonSerializerOptions">Optional settings for customizing the JSON deserialization process.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
protected internal virtual async Task DeserializeAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
{
var state = JsonSerializer.Deserialize(
serializedThread,
AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ThreadState))) as ThreadState;
if (state?.ConversationId is string threadId)
{
this.ConversationId = threadId;
// Since we have an ID, we should not have a chat message store and we can return here.
return;
}
// If we don't have any IChatMessageStore state return here.
if (state?.StoreState is null || state?.StoreState?.ValueKind is JsonValueKind.Undefined or JsonValueKind.Null)
{
return;
}
if (this._messageStore is null)
{
// If we don't have a chat message store yet, create an in-memory one.
this._messageStore = new InMemoryChatMessageStore();
}
await this._messageStore.DeserializeStateAsync(state!.StoreState.Value, jsonSerializerOptions, cancellationToken).ConfigureAwait(false);
}
/// <summary>
/// Serializes the current object's state to a <see cref="JsonElement"/> using the specified serialization options.
/// </summary>
/// <param name="jsonSerializerOptions">The JSON serialization options to use.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A <see cref="JsonElement"/> representation of the object's state.</returns>
public virtual async Task<JsonElement> SerializeAsync(JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
{
var storeState = this._messageStore is null ?
(JsonElement?)null :
await this._messageStore.SerializeStateAsync(jsonSerializerOptions, cancellationToken).ConfigureAwait(false);
var state = new ThreadState
{
ConversationId = this.ConversationId,
StoreState = storeState
};
return JsonSerializer.SerializeToElement(state, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ThreadState)));
}
internal class ThreadState
{
public string? ConversationId { get; set; }
public JsonElement? StoreState { get; set; }
}
}
@@ -0,0 +1,70 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.Extensions.AI.Agents;
/// <summary>
/// Defines methods for storing and retrieving chat messages associated with a specific thread.
/// </summary>
/// <remarks>
/// Implementations of this interface are responsible for managing the storage of chat messages,
/// including handling large volumes of data by truncating or summarizing messages as necessary.
/// </remarks>
public interface IChatMessageStore
{
/// <summary>
/// Gets all the messages from the store that should be used for the next agent invocation.
/// </summary>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A collection of chat messages.</returns>
/// <remarks>
/// <para>
/// Messages are returned in ascending chronological order, with the oldest message first.
/// </para>
/// <para>
/// If the messages stored in the store become very large, it is up to the store to
/// truncate, summarize or otherwise limit the number of messages returned.
/// </para>
/// <para>
/// When using implementations of <see cref="IChatMessageStore"/>, a new one should be created for each thread
/// since they may contain state that is specific to a thread.
/// </para>
/// </remarks>
Task<IEnumerable<ChatMessage>> GetMessagesAsync(CancellationToken cancellationToken);
/// <summary>
/// Adds messages to the store.
/// </summary>
/// <param name="messages">The messages to add.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>An async task.</returns>
Task AddMessagesAsync(IReadOnlyCollection<ChatMessage> messages, CancellationToken cancellationToken);
/// <summary>
/// Deserializes the state contained in the provided <see cref="JsonElement"/> into the properties on this store.
/// </summary>
/// <param name="serializedStoreState">A <see cref="JsonElement"/> representing the state of the store.</param>
/// <param name="jsonSerializerOptions">Optional settings for customizing the JSON deserialization process.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <remarks>
/// This method, together with <see cref="SerializeStateAsync(JsonSerializerOptions?, CancellationToken)"/> can be used to save and load messages from a persistent store
/// if this store only has messages in memory.
/// </remarks>
ValueTask DeserializeStateAsync(JsonElement? serializedStoreState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default);
/// <summary>
/// Serializes the current object's state to a <see cref="JsonElement"/> using the specified serialization options.
/// </summary>
/// <param name="jsonSerializerOptions">The JSON serialization options to use.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A <see cref="JsonElement"/> representation of the object's state.</returns>
/// <remarks>
/// This method, together with <see cref="DeserializeStateAsync(JsonElement?, JsonSerializerOptions?, CancellationToken)"/> can be used to save and load messages from a persistent store
/// if this store only has messages in memory.
/// </remarks>
ValueTask<JsonElement?> SerializeStateAsync(JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default);
}
@@ -1,36 +0,0 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Threading;
namespace Microsoft.Extensions.AI.Agents;
/// <summary>
/// An interface for agent threads that allow retrieval of messages in the thread for agent invocation.
/// </summary>
/// <remarks>
/// <para>
/// Some agents need to be invoked with all relevant chat history messages in order to produce a result, while some must be invoked
/// with the id of a server side thread that contains the chat history.
/// </para>
/// <para>
/// This interface can be implemented by all thread types that support the case where the agent is invoked with the chat history.
/// Implementations must consider the size of the messages provided, so that they do not exceed the maximum size of the context window
/// of the agent they are used with. Where appropriate, implementations should truncate or summarize messages so that the size of messages
/// are constrained.
/// </para>
/// </remarks>
public interface IMessagesRetrievableThread
{
/// <summary>
/// Asynchronously retrieves all messages to be used for the agent invocation.
/// </summary>
/// <remarks>
/// Messages are returned in ascending chronological order.
/// </remarks>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>The messages in the thread.</returns>
/// <exception cref="InvalidOperationException">The thread has been deleted.</exception>
IAsyncEnumerable<ChatMessage> GetMessagesAsync(CancellationToken cancellationToken = default);
}
@@ -0,0 +1,121 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections;
using System.Collections.Generic;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
namespace Microsoft.Extensions.AI.Agents;
/// <summary>
/// Represents an in-memory store for chat messages associated with a specific thread.
/// </summary>
internal class InMemoryChatMessageStore : IList<ChatMessage>, IChatMessageStore
{
private readonly List<ChatMessage> _messages = new();
/// <inheritdoc />
public int Count => this._messages.Count;
/// <inheritdoc />
public bool IsReadOnly => ((IList)this._messages).IsReadOnly;
/// <inheritdoc />
public ChatMessage this[int index]
{
get => this._messages[index];
set => this._messages[index] = value;
}
/// <inheritdoc />
public Task AddMessagesAsync(IReadOnlyCollection<ChatMessage> messages, CancellationToken cancellationToken)
{
_ = Throw.IfNull(messages);
this._messages.AddRange(messages);
return Task.CompletedTask;
}
/// <inheritdoc />
public Task<IEnumerable<ChatMessage>> GetMessagesAsync(CancellationToken cancellationToken)
{
return Task.FromResult<IEnumerable<ChatMessage>>(this._messages);
}
/// <inheritdoc />
public ValueTask DeserializeStateAsync(JsonElement? serializedStoreState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
{
if (serializedStoreState is null)
{
return new ValueTask();
}
var state = JsonSerializer.Deserialize(
serializedStoreState.Value,
AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(StoreState))) as StoreState;
if (state?.Messages is { Count: > 0 } messages)
{
this._messages.AddRange(messages);
}
return new ValueTask();
}
/// <inheritdoc />
public ValueTask<JsonElement?> SerializeStateAsync(JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
{
StoreState state = new()
{
Messages = this._messages,
};
return new ValueTask<JsonElement?>(JsonSerializer.SerializeToElement(state, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(StoreState))));
}
/// <inheritdoc />
public int IndexOf(ChatMessage item)
=> this._messages.IndexOf(item);
/// <inheritdoc />
public void Insert(int index, ChatMessage item)
=> this._messages.Insert(index, item);
/// <inheritdoc />
public void RemoveAt(int index)
=> this._messages.RemoveAt(index);
/// <inheritdoc />
public void Add(ChatMessage item)
=> this._messages.Add(item);
/// <inheritdoc />
public void Clear()
=> this._messages.Clear();
/// <inheritdoc />
public bool Contains(ChatMessage item)
=> this._messages.Contains(item);
/// <inheritdoc />
public void CopyTo(ChatMessage[] array, int arrayIndex)
=> this._messages.CopyTo(array, arrayIndex);
/// <inheritdoc />
public bool Remove(ChatMessage item)
=> this._messages.Remove(item);
/// <inheritdoc />
public IEnumerator<ChatMessage> GetEnumerator()
=> this._messages.GetEnumerator();
/// <inheritdoc />
IEnumerator IEnumerable.GetEnumerator()
=> this.GetEnumerator();
internal class StoreState
{
public IList<ChatMessage> Messages { get; set; } = new List<ChatMessage>();
}
}
@@ -10,6 +10,7 @@
<PropertyGroup>
<InjectSharedThrow>true</InjectSharedThrow>
<InjectDiagnosticClassesOnLegacy>true</InjectDiagnosticClassesOnLegacy>
<InjectTrimAttributesOnLegacy>true</InjectTrimAttributesOnLegacy>
</PropertyGroup>
@@ -37,12 +37,6 @@ public class CopilotStudioAgent : AIAgent
this._logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger<CopilotStudioAgent>();
}
/// <inheritdoc/>
public override AgentThread GetNewThread()
{
return new CopilotStudioAgentThread();
}
/// <inheritdoc/>
public override async Task<AgentRunResponse> RunAsync(
IReadOnlyCollection<ChatMessage> messages,
@@ -54,12 +48,12 @@ public class CopilotStudioAgent : AIAgent
// Ensure that we have a valid thread to work with.
// If the thread ID is null, we need to start a new conversation and set the thread ID accordingly.
CopilotStudioAgentThread copilotStudioAgentThread = base.ValidateOrCreateThreadType(thread, () => new CopilotStudioAgentThread());
copilotStudioAgentThread.Id ??= await this.StartNewConversationAsync(cancellationToken).ConfigureAwait(false);
thread ??= this.GetNewThread();
thread.ConversationId ??= await this.StartNewConversationAsync(cancellationToken).ConfigureAwait(false);
// Invoke the Copilot Studio agent with the provided messages.
string question = string.Join("\n", messages.Select(m => m.Text));
var responseMessages = ActivityProcessor.ProcessActivityAsync(this.Client.AskQuestionAsync(question, copilotStudioAgentThread.Id, cancellationToken), streaming: false, this._logger);
var responseMessages = ActivityProcessor.ProcessActivityAsync(this.Client.AskQuestionAsync(question, thread.ConversationId, cancellationToken), streaming: false, this._logger);
var responseMessagesList = new List<ChatMessage>();
await foreach (var message in responseMessages.ConfigureAwait(false))
{
@@ -87,12 +81,12 @@ public class CopilotStudioAgent : AIAgent
// Ensure that we have a valid thread to work with.
// If the thread ID is null, we need to start a new conversation and set the thread ID accordingly.
CopilotStudioAgentThread copilotStudioAgentThread = base.ValidateOrCreateThreadType(thread, () => new CopilotStudioAgentThread());
copilotStudioAgentThread.Id ??= await this.StartNewConversationAsync(cancellationToken).ConfigureAwait(false);
thread ??= this.GetNewThread();
thread.ConversationId ??= await this.StartNewConversationAsync(cancellationToken).ConfigureAwait(false);
// Invoke the Copilot Studio agent with the provided messages.
string question = string.Join("\n", messages.Select(m => m.Text));
var responseMessages = ActivityProcessor.ProcessActivityAsync(this.Client.AskQuestionAsync(question, copilotStudioAgentThread.Id, cancellationToken), streaming: true, this._logger);
var responseMessages = ActivityProcessor.ProcessActivityAsync(this.Client.AskQuestionAsync(question, thread.ConversationId, cancellationToken), streaming: true, this._logger);
// Enumerate the response messages
await foreach (ChatMessage message in responseMessages.ConfigureAwait(false))
@@ -1,8 +0,0 @@
// Copyright (c) Microsoft. All rights reserved.
namespace Microsoft.Extensions.AI.Agents.CopilotStudio;
/// <summary>
/// Represents a thread for interacting with a Copilot Studio agent.
/// </summary>
public class CopilotStudioAgentThread : AgentThread;
@@ -1,20 +0,0 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.Text.Json;
using System.Text.Json.Serialization;
namespace Microsoft.Extensions.AI.Agents;
/// <summary>
/// Source-generated JSON type information for use by all Agents implementations.
/// </summary>
[JsonSourceGenerationOptions(
JsonSerializerDefaults.Web,
UseStringEnumConverter = true,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
WriteIndented = false)]
[JsonSerializable(typeof(ChatMessage))]
[JsonSerializable(typeof(List<ChatMessage>))]
[JsonSerializable(typeof(ChatClientAgentThread))]
internal sealed partial class AgentsJsonContext : JsonSerializerContext;
@@ -100,7 +100,7 @@ public sealed class ChatClientAgent : AIAgent
{
Throw.IfNull(messages);
(ChatClientAgentThread chatClientThread, ChatOptions? chatOptions, List<ChatMessage> threadMessages) =
(AgentThread safeThread, ChatOptions? chatOptions, List<ChatMessage> threadMessages) =
await this.PrepareThreadAndMessagesAsync(thread, messages, options, cancellationToken).ConfigureAwait(false);
var agentName = this.GetLoggingAgentName();
@@ -113,10 +113,10 @@ public sealed class ChatClientAgent : AIAgent
// We can derive the type of supported thread from whether we have a conversation id,
// so let's update it and set the conversation id for the service thread case.
this.UpdateThreadWithTypeAndConversationId(chatClientThread, chatResponse.ConversationId);
this.UpdateThreadWithTypeAndConversationId(safeThread, chatResponse.ConversationId);
// Only notify the thread of new messages if the chatResponse was successful to avoid inconsistent messages state in the thread.
await this.NotifyThreadOfNewMessagesAsync(chatClientThread, messages, cancellationToken).ConfigureAwait(false);
await this.NotifyThreadOfNewMessagesAsync(safeThread, messages, cancellationToken).ConfigureAwait(false);
// Ensure that the author name is set for each message in the response.
foreach (ChatMessage chatResponseMessage in chatResponse.Messages)
@@ -127,7 +127,7 @@ public sealed class ChatClientAgent : AIAgent
// Convert the chat response messages to a valid IReadOnlyCollection for notification signatures below.
var chatResponseMessages = chatResponse.Messages as IReadOnlyCollection<ChatMessage> ?? [.. chatResponse.Messages];
await this.NotifyThreadOfNewMessagesAsync(chatClientThread, chatResponseMessages, cancellationToken).ConfigureAwait(false);
await this.NotifyThreadOfNewMessagesAsync(safeThread, chatResponseMessages, cancellationToken).ConfigureAwait(false);
return new(chatResponse) { AgentId = this.Id };
}
@@ -141,7 +141,7 @@ public sealed class ChatClientAgent : AIAgent
{
var inputMessages = Throw.IfNull(messages);
(ChatClientAgentThread chatClientThread, ChatOptions? chatOptions, List<ChatMessage> threadMessages) =
(AgentThread safeThread, ChatOptions? chatOptions, List<ChatMessage> threadMessages) =
await this.PrepareThreadAndMessagesAsync(thread, inputMessages, options, cancellationToken).ConfigureAwait(false);
int messageCount = threadMessages.Count;
@@ -177,16 +177,20 @@ public sealed class ChatClientAgent : AIAgent
// We can derive the type of supported thread from whether we have a conversation id,
// so let's update it and set the conversation id for the service thread case.
this.UpdateThreadWithTypeAndConversationId(chatClientThread, chatResponse.ConversationId);
this.UpdateThreadWithTypeAndConversationId(safeThread, chatResponse.ConversationId);
// To avoid inconsistent state we only notify the thread of the input messages if no error occurs after the initial request.
await this.NotifyThreadOfNewMessagesAsync(chatClientThread, inputMessages, cancellationToken).ConfigureAwait(false);
await this.NotifyThreadOfNewMessagesAsync(safeThread, inputMessages, cancellationToken).ConfigureAwait(false);
await this.NotifyThreadOfNewMessagesAsync(chatClientThread, chatResponseMessages, cancellationToken).ConfigureAwait(false);
await this.NotifyThreadOfNewMessagesAsync(safeThread, chatResponseMessages, cancellationToken).ConfigureAwait(false);
}
/// <inheritdoc/>
public override AgentThread GetNewThread() => new ChatClientAgentThread();
public override AgentThread GetNewThread()
{
var thread = new AgentThread() { MessageStore = this._agentOptions?.ChatMessageStoreFactory?.Invoke() };
return thread;
}
#region Private
@@ -312,7 +316,7 @@ public sealed class ChatClientAgent : AIAgent
/// <param name="runOptions">Optional parameters for agent invocation.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A tuple containing the thread, chat options, and thread messages.</returns>
private async Task<(ChatClientAgentThread, ChatOptions?, List<ChatMessage>)> PrepareThreadAndMessagesAsync(
private async Task<(AgentThread, ChatOptions?, List<ChatMessage>)> PrepareThreadAndMessagesAsync(
AgentThread? thread,
IReadOnlyCollection<ChatMessage> inputMessages,
AgentRunOptions? runOptions,
@@ -320,16 +324,13 @@ public sealed class ChatClientAgent : AIAgent
{
ChatOptions? chatOptions = this.CreateConfiguredChatOptions(runOptions);
var chatClientThread = this.ValidateOrCreateThreadType<ChatClientAgentThread>(thread, () => new());
thread ??= this.GetNewThread();
// Add any existing messages from the thread to the messages to be sent to the chat client.
List<ChatMessage> threadMessages = [];
if (chatClientThread is IMessagesRetrievableThread messagesRetrievableThread)
await foreach (ChatMessage message in thread.GetMessagesAsync(cancellationToken).ConfigureAwait(false))
{
await foreach (ChatMessage message in messagesRetrievableThread.GetMessagesAsync(cancellationToken).ConfigureAwait(false))
{
threadMessages.Add(message);
}
threadMessages.Add(message);
}
// Update the messages with agent instructions.
@@ -340,39 +341,43 @@ public sealed class ChatClientAgent : AIAgent
// If a user provided two different thread ids, via the thread object and options, we should throw
// since we don't know which one to use.
if (!string.IsNullOrWhiteSpace(chatClientThread.Id) && !string.IsNullOrWhiteSpace(chatOptions?.ConversationId) && chatClientThread.Id != chatOptions.ConversationId)
if (!string.IsNullOrWhiteSpace(thread.ConversationId) && !string.IsNullOrWhiteSpace(chatOptions?.ConversationId) && thread.ConversationId != chatOptions.ConversationId)
{
throw new InvalidOperationException(
$"The {nameof(chatOptions.ConversationId)} provided via {nameof(Microsoft.Extensions.AI.ChatOptions)} is different to the id of the provided {nameof(AgentThread)}. Only one thread id can be used for a run.");
}
// Only clone and update ChatOptions if we have an id on the thread and we don't have the same one already in ChatOptions.
if (!string.IsNullOrWhiteSpace(chatClientThread.Id) && chatClientThread.Id != chatOptions?.ConversationId)
if (!string.IsNullOrWhiteSpace(thread.ConversationId) && thread.ConversationId != chatOptions?.ConversationId)
{
chatOptions ??= new();
chatOptions.ConversationId = chatClientThread.Id;
chatOptions.ConversationId = thread.ConversationId;
}
return (chatClientThread, chatOptions, threadMessages);
return (thread, chatOptions, threadMessages);
}
private void UpdateThreadWithTypeAndConversationId(ChatClientAgentThread chatClientThread, string? responseConversationId)
private void UpdateThreadWithTypeAndConversationId(AgentThread thread, string? responseConversationId)
{
// Set the thread's storage location, the first time that we use it.
chatClientThread.StorageLocation ??= string.IsNullOrWhiteSpace(responseConversationId)
? ChatClientAgentThreadType.InMemoryMessages
: ChatClientAgentThreadType.ConversationId;
// If we got a conversation id back from the chat client, it means that the service supports server side thread storage
// so we should capture the id and update the thread with the new id.
if (chatClientThread.StorageLocation == ChatClientAgentThreadType.ConversationId)
if (string.IsNullOrWhiteSpace(responseConversationId) && !string.IsNullOrWhiteSpace(thread.ConversationId))
{
if (string.IsNullOrWhiteSpace(responseConversationId))
{
throw new InvalidOperationException("Service did not return a valid conversation id when using a service managed thread.");
}
// We were passed a thread that is service managed, but we got no conversation id back from the chat client,
// meaning the service doesn't support service managed threads, so the thread cannot be used with this service.
throw new InvalidOperationException("Service did not return a valid conversation id when using a service managed thread.");
}
chatClientThread.Id = responseConversationId;
if (!string.IsNullOrWhiteSpace(responseConversationId))
{
// If we got a conversation id back from the chat client, it means that the service supports server side thread storage
// so we should update the thread with the new id.
thread.ConversationId = responseConversationId;
}
else if (thread.MessageStore is null)
{
// If the service doesn't use service side thread storage (i.e. we got no id back from invocation), and
// the thread has no MessageStore yet, and we have a custom messages store, we should update the thread
// with the custom MessageStore so that it has somewhere to store the chat history.
thread.MessageStore = this._agentOptions?.ChatMessageStoreFactory?.Invoke();
}
}
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
namespace Microsoft.Extensions.AI.Agents;
@@ -72,6 +73,12 @@ public class ChatClientAgentOptions
/// </summary>
public ChatOptions? ChatOptions { get; set; }
/// <summary>
/// Gets or sets a factory function to create an instance of <see cref="IChatMessageStore"/>
/// which will be used to store chat messages for this agent.
/// </summary>
public Func<IChatMessageStore>? ChatMessageStoreFactory { get; set; } = null;
/// <summary>
/// Creates a new instance of <see cref="ChatClientAgentOptions"/> with the same values as this instance.
/// </summary>
@@ -82,6 +89,7 @@ public class ChatClientAgentOptions
Name = this.Name,
Instructions = this.Instructions,
Description = this.Description,
ChatOptions = this.ChatOptions?.Clone()
ChatOptions = this.ChatOptions?.Clone(),
ChatMessageStoreFactory = this.ChatMessageStoreFactory
};
}
@@ -1,185 +0,0 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
namespace Microsoft.Extensions.AI.Agents;
/// <summary>
/// Chat client agent thread.
/// </summary>
[JsonConverter(typeof(Converter))]
public sealed class ChatClientAgentThread : AgentThread, IMessagesRetrievableThread
{
private readonly List<ChatMessage> _chatMessages = [];
/// <summary>
/// Initializes a new instance of the <see cref="ChatClientAgentThread"/> class.
/// </summary>
public ChatClientAgentThread()
{
}
/// <summary>
/// Initializes a new instance of the <see cref="ChatClientAgentThread"/> class.
/// </summary>
/// <param name="id">The id of an existing server side thread to continue.</param>
/// <remarks>
/// This constructor creates a <see cref="ChatClientAgentThread"/> that supports in-service message storage.
/// </remarks>
public ChatClientAgentThread(string id)
{
Throw.IfNullOrWhitespace(id);
this.Id = id;
this.StorageLocation = ChatClientAgentThreadType.ConversationId;
}
/// <summary>
/// Initializes a new instance of the <see cref="ChatClientAgentThread"/> class.
/// </summary>
/// <param name="messages">A set of initial messages to seed the thread with.</param>
/// <remarks>
/// This constructor creates a <see cref="ChatClientAgentThread"/> that supports local in-memory message storage.
/// </remarks>
public ChatClientAgentThread(IEnumerable<ChatMessage> messages)
{
Throw.IfNull(messages);
this._chatMessages.AddRange(messages);
this.StorageLocation = ChatClientAgentThreadType.InMemoryMessages;
}
/// <summary>
/// Gets the location of the thread contents.
/// </summary>
internal ChatClientAgentThreadType? StorageLocation { get; set; }
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc/>
public async IAsyncEnumerable<ChatMessage> GetMessagesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default)
{
foreach (var message in this._chatMessages)
{
yield return message;
}
}
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
/// <inheritdoc/>
protected override Task OnNewMessagesAsync(IReadOnlyCollection<ChatMessage> newMessages, CancellationToken cancellationToken = default)
{
switch (this.StorageLocation)
{
case ChatClientAgentThreadType.InMemoryMessages:
this._chatMessages.AddRange(newMessages);
break;
case ChatClientAgentThreadType.ConversationId:
// If the thread messages are stored in the service
// there is nothing to do here, since invoking the
// service should already update the thread.
break;
default:
throw new UnreachableException();
}
return Task.CompletedTask;
}
/// <summary>
/// Provides a <see cref="JsonConverter"/> for <see cref="ChatClientAgentThread"/> objects.
/// </summary>
[EditorBrowsable(EditorBrowsableState.Never)]
public sealed class Converter : JsonConverter<ChatClientAgentThread>
{
/// <inheritdoc/>
public override ChatClientAgentThread? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
if (reader.TokenType != JsonTokenType.StartObject)
{
throw new JsonException("Expected StartObject token");
}
using var doc = JsonDocument.ParseValue(ref reader);
var root = doc.RootElement;
// Extract properties from JSON
string? id = null;
if (root.TryGetProperty("id", out var idProperty))
{
id = idProperty.GetString();
}
List<ChatMessage>? messages = null;
if (root.TryGetProperty("messages", out var messagesProperty))
{
if (messagesProperty.ValueKind == JsonValueKind.Array)
{
messages = [];
foreach (var messageElement in messagesProperty.EnumerateArray())
{
var message = messageElement.Deserialize(options.GetTypeInfo<ChatMessage>(AgentsJsonContext.Default));
if (message != null)
{
messages.Add(message);
}
}
}
}
// Create the appropriate instance based on available data
// StorageLocation will be set automatically by the constructors
ChatClientAgentThread thread;
if (messages?.Count > 0)
{
thread = new ChatClientAgentThread(messages);
}
else if (!string.IsNullOrWhiteSpace(id))
{
thread = new ChatClientAgentThread(id);
}
else
{
thread = new ChatClientAgentThread();
}
// Override Id if it was explicitly set in JSON (for cases where messages exist but ID is also provided)
if (id != null)
{
thread.Id = id;
}
return thread;
}
/// <inheritdoc/>
public override void Write(Utf8JsonWriter writer, ChatClientAgentThread value, JsonSerializerOptions options)
{
writer.WriteStartObject();
// Write base properties
if (value.Id != null)
{
writer.WriteString("id", value.Id);
}
// Write messages if in memory storage (StorageLocation is determined by presence of messages vs ID)
if (value.StorageLocation == ChatClientAgentThreadType.InMemoryMessages)
{
writer.WritePropertyName("messages");
JsonSerializer.Serialize(writer, value._chatMessages, options.GetTypeInfo<List<ChatMessage>>(AgentsJsonContext.Default));
}
writer.WriteEndObject();
}
}
}
@@ -1,19 +0,0 @@
// Copyright (c) Microsoft. All rights reserved.
namespace Microsoft.Extensions.AI.Agents;
/// <summary>
/// Defines the different supported storage locations for <see cref="ChatClientAgentThread"/>.
/// </summary>
internal enum ChatClientAgentThreadType
{
/// <summary>
/// Messages are stored in memory inside the thread object.
/// </summary>
InMemoryMessages,
/// <summary>
/// Messages are stored in the service and the thread object just has an id reference the service storage.
/// </summary>
ConversationId
}
@@ -210,9 +210,9 @@ public sealed class OpenTelemetryAgent : AIAgent, IDisposable
}
// Add conversation ID if thread is available (following gen_ai.conversation.id convention)
if (!string.IsNullOrWhiteSpace(thread?.Id))
if (!string.IsNullOrWhiteSpace(thread?.ConversationId))
{
_ = activity.AddTag(AgentOpenTelemetryConsts.GenAI.ConversationId, thread.Id);
_ = activity.AddTag(AgentOpenTelemetryConsts.GenAI.ConversationId, thread.ConversationId);
}
// Add instructions if available (for ChatClientAgent)
@@ -1,6 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using AgentConformance.IntegrationTests;
@@ -29,14 +28,9 @@ public class AzureAIAgentsPersistentFixture : IChatClientAgentFixture
public async Task<List<ChatMessage>> GetChatHistoryAsync(AgentThread thread)
{
if (thread is not ChatClientAgentThread chatClientThread)
{
throw new InvalidOperationException($"The thread must be of type {nameof(ChatClientAgentThread)} to retrieve chat history.");
}
List<ChatMessage> messages = [];
AsyncPageable<PersistentThreadMessage> threadMessages = this._persistentAgentsClient.Messages.GetMessagesAsync(threadId: thread.Id, order: ListSortOrder.Ascending);
AsyncPageable<PersistentThreadMessage> threadMessages = this._persistentAgentsClient.Messages.GetMessagesAsync(threadId: thread.ConversationId, order: ListSortOrder.Ascending);
await foreach (var threadMessage in threadMessages)
{
@@ -87,9 +81,9 @@ public class AzureAIAgentsPersistentFixture : IChatClientAgentFixture
public Task DeleteThreadAsync(AgentThread thread)
{
if (thread?.Id is not null)
if (thread?.ConversationId is not null)
{
return this._persistentAgentsClient.Threads.DeleteThreadAsync(thread.Id);
return this._persistentAgentsClient.Threads.DeleteThreadAsync(thread.ConversationId);
}
return Task.CompletedTask;
@@ -33,7 +33,7 @@ internal sealed class MockAgent(int index) : AIAgent
public override AgentThread GetNewThread()
{
return new AgentThread() { Id = Guid.NewGuid().ToString() };
return new AgentThread() { ConversationId = Guid.NewGuid().ToString() };
}
public override Task<AgentRunResponse> RunAsync(IReadOnlyCollection<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
@@ -89,8 +89,6 @@ public class OrchestrationResultTests
private sealed class MockAgent : AIAgent
{
public override AgentThread GetNewThread() =>
throw new NotSupportedException();
public override Task<AgentRunResponse> RunAsync(IReadOnlyCollection<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) =>
throw new NotSupportedException();
public override IAsyncEnumerable<AgentRunResponseUpdate> RunStreamingAsync(IReadOnlyCollection<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) =>
@@ -1,6 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
@@ -220,23 +219,6 @@ public class AgentTests
Assert.Equal(id, agent.Id);
}
[Fact]
public void ValidateOrCreateThreadTypeVerifiesAndCreatesThread()
{
// Custom thread type for type checking
var threadMock = new Mock<TestAgentThread>() { CallBase = true };
var agent = new MockAgent();
// Should create
var result = agent.ValidateOrCreateThreadType<TestAgentThread>(null, () => threadMock.Object);
Assert.Same(threadMock.Object, result);
// Should throw if wrong type
var wrongThread = new Mock<AgentThread>().Object;
Assert.Throws<NotSupportedException>(() => agent.ValidateOrCreateThreadType<TestAgentThread>(wrongThread, () => threadMock.Object));
}
[Fact]
public async Task NotifyThreadOfNewMessagesNotifiesThreadAsync()
{
@@ -245,6 +227,8 @@ public class AgentTests
var messages = new[] { new ChatMessage(ChatRole.User, "msg1"), new ChatMessage(ChatRole.User, "msg2") };
var threadMock = new Mock<TestAgentThread>() { CallBase = true };
threadMock.SetupAllProperties();
threadMock.Object.ConversationId = "test-thread-id";
var agent = new MockAgent();
await agent.NotifyThreadOfNewMessagesAsync(threadMock.Object, messages, cancellationToken);
@@ -257,31 +241,13 @@ public class AgentTests
/// </summary>
public abstract class TestAgentThread : AgentThread;
/// <summary>
/// Mock class to test the <see cref="AIAgent.ValidateOrCreateThreadType{TThreadType}"/> method.
/// </summary>
private sealed class MockAgent : AIAgent
{
public new TThreadType ValidateOrCreateThreadType<TThreadType>(
AgentThread? thread,
Func<TThreadType> constructThread)
where TThreadType : AgentThread
{
return base.ValidateOrCreateThreadType<TThreadType>(
thread,
constructThread);
}
public new Task NotifyThreadOfNewMessagesAsync(AgentThread thread, IReadOnlyCollection<ChatMessage> messages, CancellationToken cancellationToken)
{
return base.NotifyThreadOfNewMessagesAsync(thread, messages, cancellationToken);
}
public override AgentThread GetNewThread()
{
throw new NotImplementedException();
}
public override Task<AgentRunResponse> RunAsync(IReadOnlyCollection<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
{
throw new System.NotImplementedException();
@@ -0,0 +1,329 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Moq;
namespace Microsoft.Extensions.AI.Agents.Abstractions.UnitTests;
public class AgentThreadTests
{
#region Constructor and Property Tests
[Fact]
public void ConstructorSetsDefaults()
{
// Arrange & Act
var thread = new AgentThread();
// Assert
Assert.Null(thread.ConversationId);
Assert.Null(thread.MessageStore);
}
[Fact]
public void SetConversationIdRoundtrips()
{
// Arrange
var thread = new AgentThread();
var conversationid = "test-thread-id";
// Act
thread.ConversationId = conversationid;
// Assert
Assert.Equal(conversationid, thread.ConversationId);
Assert.Null(thread.MessageStore);
}
[Fact]
public void SetChatMessageStoreRoundtrips()
{
// Arrange
var thread = new AgentThread();
var messageStore = new InMemoryChatMessageStore();
// Act
thread.MessageStore = messageStore;
// Assert
Assert.Same(messageStore, thread.MessageStore);
Assert.Null(thread.ConversationId);
}
[Fact]
public void SetConversationIdThrowsWhenMessageStoreIsSet()
{
// Arrange
var thread = new AgentThread();
thread.MessageStore = new InMemoryChatMessageStore();
// Act & Assert
var exception = Assert.Throws<InvalidOperationException>(() => thread.ConversationId = "new-thread-id");
Assert.Equal("Only the ConversationId or MessageStore may be set, but not both and switching from one to another is not supported.", exception.Message);
Assert.NotNull(thread.MessageStore);
}
[Fact]
public void SetChatMessageStoreThrowsWhenConversationIdIsSet()
{
// Arrange
var thread = new AgentThread();
thread.ConversationId = "existing-thread-id";
var store = new InMemoryChatMessageStore();
// Act & Assert
var exception = Assert.Throws<InvalidOperationException>(() => thread.MessageStore = store);
Assert.Equal("Only the ConversationId or MessageStore may be set, but not both and switching from one to another is not supported.", exception.Message);
Assert.NotNull(thread.ConversationId);
}
#endregion Constructor and Property Tests
#region GetMessagesAsync Tests
[Fact]
public async Task GetMessagesAsyncReturnsEmptyListWhenNoStoreAsync()
{
// Arrange
var thread = new AgentThread();
// Act
var messages = await thread.GetMessagesAsync(CancellationToken.None).ToListAsync();
// Assert
Assert.Empty(messages);
}
[Fact]
public async Task GetMessagesAsyncReturnsEmptyListWhenAgentServiceIdAsync()
{
// Arrange
var thread = new AgentThread() { ConversationId = "thread-123" };
// Act
var messages = await thread.GetMessagesAsync(CancellationToken.None).ToListAsync();
// Assert
Assert.Empty(messages);
}
[Fact]
public async Task GetMessagesAsyncReturnsMessagesFromStoreAsync()
{
// Arrange
var store = new InMemoryChatMessageStore
{
new ChatMessage(ChatRole.User, "Hello"),
new ChatMessage(ChatRole.Assistant, "Hi there!")
};
var thread = new AgentThread() { MessageStore = store };
// Act
var messages = await thread.GetMessagesAsync(CancellationToken.None).ToListAsync();
// Assert
Assert.Equal(2, messages.Count);
Assert.Equal("Hello", messages[0].Text);
Assert.Equal("Hi there!", messages[1].Text);
}
#endregion GetMessagesAsync Tests
#region OnNewMessagesAsync Tests
[Fact]
public async Task OnNewMessagesAsyncDoesNothingWhenAgentServiceIdAsync()
{
// Arrange
var thread = new AgentThread() { ConversationId = "thread-123" };
var messages = new List<ChatMessage>
{
new(ChatRole.User, "Hello"),
new(ChatRole.Assistant, "Hi there!")
};
// Act
await thread.OnNewMessagesAsync(messages, CancellationToken.None);
}
[Fact]
public async Task OnNewMessagesAsyncAddsMessagesToStoreAsync()
{
// Arrange
var store = new InMemoryChatMessageStore();
var thread = new AgentThread() { MessageStore = store };
var messages = new List<ChatMessage>
{
new(ChatRole.User, "Hello"),
new(ChatRole.Assistant, "Hi there!")
};
// Act
await thread.OnNewMessagesAsync(messages, CancellationToken.None);
// Assert
Assert.Equal(2, store.Count);
Assert.Equal("Hello", store[0].Text);
Assert.Equal("Hi there!", store[1].Text);
}
#endregion OnNewMessagesAsync Tests
#region Deserialize Tests
[Fact]
public async Task VerifyDeserializeWithMessagesAsync()
{
// Arrange
var chatMessageStore = new InMemoryChatMessageStore();
var json = JsonSerializer.Deserialize<JsonElement>("""
{
"storeState": { "messages": [{"authorName": "testAuthor"}] }
}
""");
var thread = new AgentThread() { MessageStore = chatMessageStore };
// Act.
await thread.DeserializeAsync(json);
// Assert
Assert.Null(thread.ConversationId);
Assert.Single(chatMessageStore);
Assert.Equal("testAuthor", chatMessageStore[0].AuthorName);
}
[Fact]
public async Task VerifyDeserializeWithIdAsync()
{
// Arrange
var json = JsonSerializer.Deserialize<JsonElement>("""
{
"conversationId": "TestConvId"
}
""");
var thread = new AgentThread();
// Act
await thread.DeserializeAsync(json);
// Assert
Assert.Equal("TestConvId", thread.ConversationId);
Assert.Null(thread.MessageStore);
}
[Fact]
public async Task DeserializeWithInvalidJsonThrowsAsync()
{
// Arrange
var invalidJson = JsonSerializer.Deserialize<JsonElement>("[42]");
var thread = new AgentThread();
// Act & Assert
await Assert.ThrowsAsync<JsonException>(() => thread.DeserializeAsync(invalidJson));
}
#endregion Deserialize Tests
#region Serialize Tests
/// <summary>
/// Verify thread serialization to JSON when the thread has an id.
/// </summary>
[Fact]
public async Task VerifyThreadSerializationWithIdAsync()
{
// Arrange
var thread = new AgentThread() { ConversationId = "TestConvId" };
// Act
var json = await thread.SerializeAsync();
// Assert
Assert.Equal(JsonValueKind.Object, json.ValueKind);
Assert.True(json.TryGetProperty("conversationId", out var idProperty));
Assert.Equal("TestConvId", idProperty.GetString());
Assert.False(json.TryGetProperty("storeState", out var storeStateProperty));
}
/// <summary>
/// Verify thread serialization to JSON when the thread has messages.
/// </summary>
[Fact]
public async Task VerifyThreadSerializationWithMessagesAsync()
{
// Arrange
var store = new InMemoryChatMessageStore();
store.Add(new ChatMessage(ChatRole.User, "TestContent") { AuthorName = "TestAuthor" });
var thread = new AgentThread() { MessageStore = store };
// Act
var json = await thread.SerializeAsync();
// Assert
Assert.Equal(JsonValueKind.Object, json.ValueKind);
Assert.False(json.TryGetProperty("conversationId", out var idProperty));
Assert.True(json.TryGetProperty("storeState", out var storeStateProperty));
Assert.Equal(JsonValueKind.Object, storeStateProperty.ValueKind);
Assert.True(storeStateProperty.TryGetProperty("messages", out var messagesProperty));
Assert.Equal(JsonValueKind.Array, messagesProperty.ValueKind);
Assert.Single(messagesProperty.EnumerateArray());
var message = messagesProperty.EnumerateArray().First();
Assert.Equal("TestAuthor", message.GetProperty("authorName").GetString());
Assert.True(message.TryGetProperty("contents", out var contentsProperty));
Assert.Equal(JsonValueKind.Array, contentsProperty.ValueKind);
Assert.Single(contentsProperty.EnumerateArray());
var textContent = contentsProperty.EnumerateArray().First();
Assert.Equal("TestContent", textContent.GetProperty("text").GetString());
}
/// <summary>
/// Verify thread serialization to JSON with custom options.
/// </summary>
[Fact]
public async Task VerifyThreadSerializationWithCustomOptionsAsync()
{
// Arrange
var thread = new AgentThread();
JsonSerializerOptions options = new() { PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower };
options.TypeInfoResolverChain.Add(AgentAbstractionsJsonUtilities.DefaultOptions.TypeInfoResolver!);
var storeStateElement = JsonSerializer.SerializeToElement(new { Key = "TestValue" });
var messageStoreMock = new Mock<IChatMessageStore>();
messageStoreMock
.Setup(m => m.SerializeStateAsync(options, It.IsAny<CancellationToken>()))
.ReturnsAsync(storeStateElement);
thread.MessageStore = messageStoreMock.Object;
// Act
var json = await thread.SerializeAsync(options);
// Assert
Assert.Equal(JsonValueKind.Object, json.ValueKind);
Assert.False(json.TryGetProperty("conversationId", out var idProperty));
Assert.True(json.TryGetProperty("storeState", out var storeStateProperty));
Assert.Equal(JsonValueKind.Object, storeStateProperty.ValueKind);
Assert.True(storeStateProperty.TryGetProperty("Key", out var keyProperty));
Assert.Equal("TestValue", keyProperty.GetString());
messageStoreMock.Verify(m => m.SerializeStateAsync(options, It.IsAny<CancellationToken>()), Times.Once);
}
#endregion Serialize Tests
}
@@ -0,0 +1,100 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.Extensions.AI.Agents.Abstractions.UnitTests;
/// <summary>
/// Contains tests for the <see cref="InMemoryChatMessageStore"/> class.
/// </summary>
public class InMemoryChatMessageStoreTests
{
[Fact]
public async Task AddMessagesAsyncAddsMessagesAndReturnsNullThreadIdAsync()
{
var store = new InMemoryChatMessageStore();
var messages = new List<ChatMessage>
{
new(ChatRole.User, "Hello"),
new(ChatRole.Assistant, "Hi there!")
};
await store.AddMessagesAsync(messages, CancellationToken.None);
Assert.Equal(2, store.Count);
Assert.Equal("Hello", store[0].Text);
Assert.Equal("Hi there!", store[1].Text);
}
[Fact]
public async Task AddMessagesAsyncWithEmptyDoesNotFailAsync()
{
var store = new InMemoryChatMessageStore();
await store.AddMessagesAsync([], CancellationToken.None);
Assert.Empty(store);
}
[Fact]
public async Task GetMessagesAsyncReturnsAllMessagesAsync()
{
var store = new InMemoryChatMessageStore
{
new ChatMessage(ChatRole.User, "Test1"),
new ChatMessage(ChatRole.Assistant, "Test2")
};
var result = (await store.GetMessagesAsync(CancellationToken.None)).ToList();
Assert.Equal(2, result.Count);
Assert.Contains(result, m => m.Text == "Test1");
Assert.Contains(result, m => m.Text == "Test2");
}
[Fact]
public async Task DeserializeWithEmptyElementAsync()
{
var newStore = new InMemoryChatMessageStore();
var emptyObject = JsonSerializer.Deserialize<JsonElement>("{}");
await newStore.DeserializeStateAsync(emptyObject);
Assert.Empty(newStore);
}
[Fact]
public async Task SerializeAndDeserializeRoundtripsAsync()
{
var store = new InMemoryChatMessageStore
{
new ChatMessage(ChatRole.User, "A"),
new ChatMessage(ChatRole.Assistant, "B")
};
var jsonElement = await store.SerializeStateAsync();
var newStore = new InMemoryChatMessageStore();
await newStore.DeserializeStateAsync(jsonElement);
Assert.Equal(2, newStore.Count);
Assert.Equal("A", newStore[0].Text);
Assert.Equal("B", newStore[1].Text);
}
[Fact]
public async Task AddMessagesAsyncWithEmptyMessagesDoesNotChangeStoreAsync()
{
var store = new InMemoryChatMessageStore();
var messages = new List<ChatMessage>();
await store.AddMessagesAsync(messages, CancellationToken.None);
Assert.Empty(store);
}
}
@@ -321,7 +321,7 @@ public class ChatClientAgentTests
ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "test instructions" });
ChatClientAgentThread thread = new("ConvId");
AgentThread thread = new() { ConversationId = "ConvId" };
// Act & Assert
await agent.RunAsync([new(ChatRole.User, "test")], thread, chatOptions: chatOptions);
@@ -340,7 +340,7 @@ public class ChatClientAgentTests
ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "test instructions" });
ChatClientAgentThread thread = new("ThreadId");
AgentThread thread = new() { ConversationId = "ThreadId" };
// Act & Assert
await Assert.ThrowsAsync<InvalidOperationException>(() => agent.RunAsync([new(ChatRole.User, "test")], thread, chatOptions: chatOptions));
@@ -363,7 +363,7 @@ public class ChatClientAgentTests
ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "test instructions" });
ChatClientAgentThread thread = new("ConvId");
AgentThread thread = new() { ConversationId = "ConvId" };
// Act
await agent.RunAsync([new(ChatRole.User, "test")], thread, chatOptions: chatOptions);
@@ -388,7 +388,7 @@ public class ChatClientAgentTests
ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "test instructions" });
ChatClientAgentThread thread = new("ConvId");
AgentThread thread = new() { ConversationId = "ConvId" };
// Act & Assert
await Assert.ThrowsAsync<InvalidOperationException>(() => agent.RunAsync([new(ChatRole.User, "test")], thread));
@@ -1,978 +0,0 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Moq;
#pragma warning disable CS0162 // Unreachable code detected
namespace Microsoft.Extensions.AI.Agents.UnitTests.ChatCompletion;
public class ChatClientAgentThreadTests
{
/// <summary>
/// Verify that <see cref="ChatClientAgentThread"/> implements <see cref="IMessagesRetrievableThread"/>.
/// </summary>
[Fact]
public void VerifyChatClientAgentThreadImplementsIMessagesRetrievableThread()
{
// Arrange & Act
var thread = new ChatClientAgentThread();
// Assert
Assert.IsType<IMessagesRetrievableThread>(thread, exactMatch: false);
Assert.IsType<AgentThread>(thread, exactMatch: false);
}
/// <summary>
/// Verify that <see cref="ChatClientAgentThread"/> can retrieve messages through <see cref="IMessagesRetrievableThread.GetMessagesAsync"/>.
/// This test verifies the interface works correctly when no messages have been added.
/// </summary>
[Fact]
public async Task VerifyIMessagesRetrievableThreadGetMessagesAsyncWhenEmptyAsync()
{
// Arrange
var thread = new ChatClientAgentThread();
// Act - Retrieve messages when thread is empty
var retrievedMessages = new List<ChatMessage>();
await foreach (var message in thread.GetMessagesAsync())
{
retrievedMessages.Add(message);
}
// Assert
Assert.Empty(retrievedMessages);
}
/// <summary>
/// Verify that <see cref="ChatClientAgentThread"/> can retrieve messages through <see cref="IMessagesRetrievableThread.GetMessagesAsync"/>.
/// This test verifies the interface works correctly when messages have been added via ChatClientAgent.
/// </summary>
[Fact]
public async Task VerifyIMessagesRetrievableThreadGetMessagesAsyncWhenNotEmptyAsync()
{
// Arrange
var userMessage = new ChatMessage(ChatRole.User, "Hello, how are you?");
var assistantMessage = new ChatMessage(ChatRole.Assistant, "I'm doing well, thank you!");
// Mock IChatClient to return the assistant message
var mockChatClient = new Mock<IChatClient>();
mockChatClient.Setup(
c => c.GetResponseAsync(
It.IsAny<IEnumerable<ChatMessage>>(),
It.IsAny<ChatOptions>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync(new ChatResponse([assistantMessage]));
// Create ChatClientAgent with the mocked client
var agent = new ChatClientAgent(mockChatClient.Object, options: new()
{
Instructions = "You are a helpful assistant"
});
// Get a new thread from the agent
var thread = agent.GetNewThread();
// Run the agent again with the thread to populate it with messages
var responseWithThread = await agent.RunAsync([userMessage], thread);
var messagesRetrievableThread = (IMessagesRetrievableThread)thread;
// Retrieve messages through the interface
var retrievedMessages = new List<ChatMessage>();
await foreach (var message in messagesRetrievableThread.GetMessagesAsync())
{
retrievedMessages.Add(message);
}
// Assert
Assert.NotEmpty(retrievedMessages);
// Verify that the messages include the assistant response
Assert.Collection(retrievedMessages,
m => Assert.Equal(ChatRole.User, m.Role),
m => Assert.Equal(ChatRole.Assistant, m.Role));
// Verify the content matches what we expect
Assert.Contains(retrievedMessages, m => m.Text == "Hello, how are you?" && m.Role == ChatRole.User);
Assert.Contains(retrievedMessages, m => m.Text == "I'm doing well, thank you!" && m.Role == ChatRole.Assistant);
}
/// <summary>
/// Verify that <see cref="ChatClientAgentThread.GetMessagesAsync"/> works with cancellation token.
/// </summary>
[Fact]
public async Task VerifyGetMessagesAsyncWithCancellationTokenAsync()
{
// Arrange
var thread = new ChatClientAgentThread();
using var cts = new CancellationTokenSource();
// Act - Test that GetMessagesAsync accepts cancellation token without throwing
var retrievedMessages = new List<ChatMessage>();
await foreach (var msg in thread.GetMessagesAsync(cts.Token))
{
retrievedMessages.Add(msg);
}
// Assert - Should return empty list when no messages
Assert.Empty(retrievedMessages);
}
/// <summary>
/// Verify that <see cref="ChatClientAgentThread"/> initializes with expected default values.
/// </summary>
[Fact]
public void VerifyThreadInitialState()
{
// Arrange & Act
var thread = new ChatClientAgentThread();
// Assert
Assert.Null(thread.Id); // Id should be null until created on first use.
Assert.Null(thread.StorageLocation); // StorageLocation should be null until first use
}
/// <summary>
/// Verify that <see cref="ChatClientAgentThread"/> initializes with expected default values.
/// </summary>
[Fact]
public async Task VerifyThreadWithMessagesInitialStateAsync()
{
// Arrange
var message = new ChatMessage(ChatRole.User, "Hello");
// Act
var thread = new ChatClientAgentThread([message]);
// Assert
Assert.Null(thread.Id); // Id should be null when we add messages, since it's a local thread.
Assert.Equal(ChatClientAgentThreadType.InMemoryMessages, thread.StorageLocation); // StorageLocation should be set to local since we are adding messages already.
var messages = await thread.GetMessagesAsync().ToListAsync();
Assert.Contains(message, messages);
}
/// <summary>
/// Verify that <see cref="ChatClientAgentThread"/> initializes with expected default values.
/// </summary>
[Fact]
public async Task VerifyThreadWithIdInitialStateAsync()
{
// Act
var thread = new ChatClientAgentThread("TestConvId");
// Assert
Assert.Equal("TestConvId", thread.Id);
Assert.Equal(ChatClientAgentThreadType.ConversationId, thread.StorageLocation);
var messages = await thread.GetMessagesAsync().ToListAsync();
Assert.Empty(messages);
}
#region Core Override Method Tests
/// <summary>
/// Verify that thread creation generates a valid thread ID through integration with ChatClientAgent.
/// </summary>
[Fact]
public void ThreadCreationGeneratesValidThreadId()
{
// Arrange
var mockChatClient = new Mock<IChatClient>();
mockChatClient.Setup(
c => c.GetResponseAsync(
It.IsAny<IEnumerable<ChatMessage>>(),
It.IsAny<ChatOptions>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync(new ChatResponse([new ChatMessage(ChatRole.Assistant, "response")]));
var agent = new ChatClientAgent(mockChatClient.Object, options: new());
// Act
var thread = agent.GetNewThread();
// Assert
Assert.NotNull(thread);
var chatClientAgentThread = Assert.IsType<ChatClientAgentThread>(thread);
Assert.Null(thread.Id); // Id should be null until created on first use.
Assert.Null(chatClientAgentThread.StorageLocation); // StorageLocation should be null until first use
}
/// <summary>
/// Verify that thread creation generates unique instances.
/// </summary>
[Fact]
public void ThreadCreationGeneratesUniqueInstances()
{
// Arrange
var mockChatClient = new Mock<IChatClient>();
var agent = new ChatClientAgent(mockChatClient.Object, options: new());
// Act
var thread1 = agent.GetNewThread();
var thread2 = agent.GetNewThread();
// Assert
Assert.NotSame(thread1, thread2);
Assert.IsType<ChatClientAgentThread>(thread1);
Assert.IsType<ChatClientAgentThread>(thread2);
}
/// <summary>
/// Verify that messages are properly stored and retrieved through the thread lifecycle.
/// </summary>
[Theory]
[InlineData(null, true)]
[InlineData("TestConvid", false)]
public async Task ThreadLifecycleStoresAndRetrievesMessagesAsync(string? responseConversationId, bool messagesStored)
{
// Arrange
var userMessage = new ChatMessage(ChatRole.User, "Hello");
var assistantMessage = new ChatMessage(ChatRole.Assistant, "Hi there!");
var mockChatClient = new Mock<IChatClient>();
mockChatClient.Setup(
c => c.GetResponseAsync(
It.IsAny<IEnumerable<ChatMessage>>(),
It.IsAny<ChatOptions>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync(new ChatResponse([assistantMessage]) { ConversationId = responseConversationId });
var agent = new ChatClientAgent(mockChatClient.Object, options: new() { Instructions = "Test instructions" });
// Act
var thread = agent.GetNewThread();
// Run the agent to populate the thread with messages
await agent.RunAsync([userMessage], thread);
// Retrieve messages from the thread
var retrievedMessages = new List<ChatMessage>();
await foreach (var message in ((IMessagesRetrievableThread)thread).GetMessagesAsync())
{
retrievedMessages.Add(message);
}
// Assert
Assert.Equal(messagesStored ? 2 : 0, retrievedMessages.Count);
if (messagesStored)
{
Assert.Contains(retrievedMessages, m => m.Text == "Hello" && m.Role == ChatRole.User);
Assert.Contains(retrievedMessages, m => m.Text == "Hi there!" && m.Role == ChatRole.Assistant);
}
var chatClientAgentThread = Assert.IsType<ChatClientAgentThread>(thread);
Assert.Equal(responseConversationId, thread.Id); // Id should match the returned conversation id.
Assert.Equal(
messagesStored
? ChatClientAgentThreadType.InMemoryMessages
: ChatClientAgentThreadType.ConversationId,
chatClientAgentThread.StorageLocation); // StorageLocation should be based on whether we got back a conversation id
}
/// <summary>
/// Verify that multiple messages can be added and retrieved in order.
/// </summary>
[Fact]
public async Task ThreadMessageHandlingHandlesMultipleMessagesInOrderAsync()
{
// Arrange
var messages = new[]
{
new ChatMessage(ChatRole.User, "First message"),
new ChatMessage(ChatRole.Assistant, "First response"),
new ChatMessage(ChatRole.User, "Second message"),
new ChatMessage(ChatRole.Assistant, "Second response")
};
var mockChatClient = new Mock<IChatClient>();
mockChatClient.SetupSequence(
c => c.GetResponseAsync(
It.IsAny<IEnumerable<ChatMessage>>(),
It.IsAny<ChatOptions>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync(new ChatResponse([messages[1]]))
.ReturnsAsync(new ChatResponse([messages[3]]));
var agent = new ChatClientAgent(mockChatClient.Object, options: new());
var thread = agent.GetNewThread();
// Act - Add messages through multiple agent runs
await agent.RunAsync([messages[0]], thread);
await agent.RunAsync([messages[2]], thread);
// Assert - Verify all messages are stored in order
var retrievedMessages = new List<ChatMessage>();
await foreach (var message in ((IMessagesRetrievableThread)thread).GetMessagesAsync())
{
retrievedMessages.Add(message);
}
Assert.Equal(4, retrievedMessages.Count);
Assert.Equal("First message", retrievedMessages[0].Text);
Assert.Equal("First response", retrievedMessages[1].Text);
Assert.Equal("Second message", retrievedMessages[2].Text);
Assert.Equal("Second response", retrievedMessages[3].Text);
}
#endregion
#region RunStreamingAsync Thread Notification Tests
/// <summary>
/// Verify that thread is notified of both input and response messages when invoking the streaming API with RunStreamingAsync.
/// </summary>
[Fact]
public async Task VerifyThreadNotificationDuringStreamingAsync()
{
// Arrange
var userMessage = new ChatMessage(ChatRole.User, "Hello, streaming!");
var assistantMessage = new ChatMessage(ChatRole.Assistant, "Hi there, streaming response!");
// Create streaming response updates
ChatResponseUpdate[] returnUpdates =
[
new ChatResponseUpdate(role: ChatRole.Assistant, content: "Hi there, "),
new ChatResponseUpdate(role: null, content: "streaming response!"),
];
var mockChatClient = new Mock<IChatClient>();
mockChatClient.Setup(
c => c.GetStreamingResponseAsync(
It.IsAny<IEnumerable<ChatMessage>>(),
It.IsAny<ChatOptions>(),
It.IsAny<CancellationToken>()))
.Returns(returnUpdates.ToAsyncEnumerable());
// Create ChatClientAgent with the mocked client
var agent = new ChatClientAgent(mockChatClient.Object, options: new()
{
Instructions = "You are a helpful assistant"
});
// Get a new thread from the agent
var thread = agent.GetNewThread();
// Act - Run the agent with streaming to populate the thread with messages
var streamingResults = new List<AgentRunResponseUpdate>();
await foreach (var update in agent.RunStreamingAsync([userMessage], thread))
{
streamingResults.Add(update);
}
// Assert - Verify streaming worked
Assert.Equal(2, streamingResults.Count);
// Retrieve messages from the thread to verify notification occurred
var messagesRetrievableThread = (IMessagesRetrievableThread)thread;
var retrievedMessages = new List<ChatMessage>();
await foreach (var message in messagesRetrievableThread.GetMessagesAsync())
{
retrievedMessages.Add(message);
}
// Assert - Verify that the thread was notified and contains both user and assistant messages
Assert.NotEmpty(retrievedMessages);
Assert.Equal(2, retrievedMessages.Count);
Assert.Contains(retrievedMessages, m => m.Text == "Hello, streaming!" && m.Role == ChatRole.User);
Assert.Contains(retrievedMessages, m => m.Text == "Hi there, streaming response!" && m.Role == ChatRole.Assistant);
}
/// <summary>
/// Verify that thread accumulates both input and response messages across multiple streaming calls.
/// </summary>
[Fact]
public async Task VerifyThreadAccumulatesMessagesAcrossMultipleStreamingCallsAsync()
{
// Arrange
var firstUserMessage = new ChatMessage(ChatRole.User, "First streaming message");
var secondUserMessage = new ChatMessage(ChatRole.User, "Second streaming message");
// Create streaming response updates for first call
ChatResponseUpdate[] firstReturnUpdates =
[
new ChatResponseUpdate(role: ChatRole.Assistant, content: "First "),
new ChatResponseUpdate(role: null, content: "response"),
];
// Create streaming response updates for second call
ChatResponseUpdate[] secondReturnUpdates =
[
new ChatResponseUpdate(role: ChatRole.Assistant, content: "Second "),
new ChatResponseUpdate(role: null, content: "response"),
];
var mockChatClient = new Mock<IChatClient>();
mockChatClient.SetupSequence(
c => c.GetStreamingResponseAsync(
It.IsAny<IEnumerable<ChatMessage>>(),
It.IsAny<ChatOptions>(),
It.IsAny<CancellationToken>()))
.Returns(firstReturnUpdates.ToAsyncEnumerable())
.Returns(secondReturnUpdates.ToAsyncEnumerable());
var agent = new ChatClientAgent(mockChatClient.Object, options: new());
var thread = agent.GetNewThread();
// Act - Make two streaming calls
var firstStreamingResults = new List<AgentRunResponseUpdate>();
await foreach (var update in agent.RunStreamingAsync([firstUserMessage], thread))
{
firstStreamingResults.Add(update);
}
var secondStreamingResults = new List<AgentRunResponseUpdate>();
await foreach (var update in agent.RunStreamingAsync([secondUserMessage], thread))
{
secondStreamingResults.Add(update);
}
// Assert - Verify both streaming calls worked
Assert.Equal(2, firstStreamingResults.Count);
Assert.Equal(2, secondStreamingResults.Count);
// Retrieve all messages from the thread
var messagesRetrievableThread = (IMessagesRetrievableThread)thread;
var retrievedMessages = new List<ChatMessage>();
await foreach (var message in messagesRetrievableThread.GetMessagesAsync())
{
retrievedMessages.Add(message);
}
// Assert - Verify that the thread contains all messages in order
Assert.Equal(4, retrievedMessages.Count);
Assert.Equal("First streaming message", retrievedMessages[0].Text);
Assert.Equal("First response", retrievedMessages[1].Text);
Assert.Equal("Second streaming message", retrievedMessages[2].Text);
Assert.Equal("Second response", retrievedMessages[3].Text);
}
/// <summary>
/// Verify that thread notification works correctly when streaming with existing thread messages.
/// Both RunAsync and RunStreamingAsync should add both input and response messages to the thread.
/// </summary>
[Fact]
public async Task VerifyStreamingWithExistingThreadMessagesAsync()
{
// Arrange
var initialUserMessage = new ChatMessage(ChatRole.User, "Initial message");
var initialAssistantMessage = new ChatMessage(ChatRole.Assistant, "Initial response");
var newUserMessage = new ChatMessage(ChatRole.User, "New streaming message");
// Setup for initial non-streaming call
var mockChatClient = new Mock<IChatClient>();
mockChatClient.Setup(
c => c.GetResponseAsync(
It.IsAny<IEnumerable<ChatMessage>>(),
It.IsAny<ChatOptions>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync(new ChatResponse([initialAssistantMessage]));
// Setup for streaming call
ChatResponseUpdate[] streamingUpdates =
[
new ChatResponseUpdate(role: ChatRole.Assistant, content: "Streaming "),
new ChatResponseUpdate(role: null, content: "response"),
];
mockChatClient.Setup(
c => c.GetStreamingResponseAsync(
It.IsAny<IEnumerable<ChatMessage>>(),
It.IsAny<ChatOptions>(),
It.IsAny<CancellationToken>()))
.Returns(streamingUpdates.ToAsyncEnumerable());
var agent = new ChatClientAgent(mockChatClient.Object, options: new());
var thread = agent.GetNewThread();
// Act - First, make a regular call to populate the thread
await agent.RunAsync([initialUserMessage], thread);
// Then make a streaming call
var streamingResults = new List<AgentRunResponseUpdate>();
await foreach (var update in agent.RunStreamingAsync([newUserMessage], thread))
{
streamingResults.Add(update);
}
// Assert - Verify streaming worked
Assert.Equal(2, streamingResults.Count);
// Retrieve all messages from the thread
var messagesRetrievableThread = (IMessagesRetrievableThread)thread;
var retrievedMessages = new List<ChatMessage>();
await foreach (var message in messagesRetrievableThread.GetMessagesAsync())
{
retrievedMessages.Add(message);
}
// Assert - Verify that the thread contains all messages including the new streaming ones
Assert.Equal(4, retrievedMessages.Count);
Assert.Equal("Initial message", retrievedMessages[0].Text);
Assert.Equal("Initial response", retrievedMessages[1].Text);
Assert.Equal("New streaming message", retrievedMessages[2].Text);
Assert.Equal("Streaming response", retrievedMessages[3].Text);
}
/// <summary>
/// Verify that thread is notified of input messages even when zero streaming updates are received.
/// </summary>
[Fact]
public async Task VerifyThreadNotificationWithZeroStreamingUpdatesAsync()
{
// Arrange
var userMessage = new ChatMessage(ChatRole.User, "Hello with no response!");
// Create empty streaming response (no updates)
ChatResponseUpdate[] returnUpdates = [];
var mockChatClient = new Mock<IChatClient>();
mockChatClient.Setup(
c => c.GetStreamingResponseAsync(
It.IsAny<IEnumerable<ChatMessage>>(),
It.IsAny<ChatOptions>(),
It.IsAny<CancellationToken>()))
.Returns(returnUpdates.ToAsyncEnumerable());
var agent = new ChatClientAgent(mockChatClient.Object, options: new());
var thread = agent.GetNewThread();
// Act - Run the agent with streaming that returns no updates
var streamingResults = new List<AgentRunResponseUpdate>();
await foreach (var update in agent.RunStreamingAsync([userMessage], thread))
{
streamingResults.Add(update);
}
// Assert - Verify no streaming updates were received
Assert.Empty(streamingResults);
// Retrieve messages from the thread to verify notification occurred
var messagesRetrievableThread = (IMessagesRetrievableThread)thread;
var retrievedMessages = new List<ChatMessage>();
await foreach (var message in messagesRetrievableThread.GetMessagesAsync())
{
retrievedMessages.Add(message);
}
// Assert - Verify that the thread was notified of input messages even with zero updates
// The fallback mechanism should ensure input messages are added to the thread
Assert.Single(retrievedMessages);
Assert.Contains(retrievedMessages, m => m.Text == "Hello with no response!" && m.Role == ChatRole.User);
}
/// <summary>
/// Verify that thread is notified of input messages only once even with multiple streaming updates.
/// </summary>
[Fact]
public async Task VerifyThreadNotificationWithMultipleStreamingUpdatesAsync()
{
// Arrange
var userMessage = new ChatMessage(ChatRole.User, "Hello with many updates!");
// Create multiple streaming response updates
ChatResponseUpdate[] returnUpdates =
[
new ChatResponseUpdate(role: ChatRole.Assistant, content: "First "),
new ChatResponseUpdate(role: null, content: "update, "),
new ChatResponseUpdate(role: null, content: "second "),
new ChatResponseUpdate(role: null, content: "update, "),
new ChatResponseUpdate(role: null, content: "third "),
new ChatResponseUpdate(role: null, content: "update!"),
];
var mockChatClient = new Mock<IChatClient>();
mockChatClient.Setup(
c => c.GetStreamingResponseAsync(
It.IsAny<IEnumerable<ChatMessage>>(),
It.IsAny<ChatOptions>(),
It.IsAny<CancellationToken>()))
.Returns(returnUpdates.ToAsyncEnumerable());
var agent = new ChatClientAgent(mockChatClient.Object, options: new());
var thread = agent.GetNewThread();
// Act - Run the agent with streaming that returns multiple updates
var streamingResults = new List<AgentRunResponseUpdate>();
await foreach (var update in agent.RunStreamingAsync([userMessage], thread))
{
streamingResults.Add(update);
}
// Assert - Verify all streaming updates were received
Assert.Equal(6, streamingResults.Count);
// Retrieve messages from the thread to verify notification occurred
var messagesRetrievableThread = (IMessagesRetrievableThread)thread;
var retrievedMessages = new List<ChatMessage>();
await foreach (var message in messagesRetrievableThread.GetMessagesAsync())
{
retrievedMessages.Add(message);
}
// Assert - Verify that the thread contains both input and response messages
// Input message should be added only once despite multiple updates
Assert.Equal(2, retrievedMessages.Count);
Assert.Contains(retrievedMessages, m => m.Text == "Hello with many updates!" && m.Role == ChatRole.User);
Assert.Contains(retrievedMessages, m => m.Text == "First update, second update, third update!" && m.Role == ChatRole.Assistant);
}
/// <summary>
/// Verify that thread is NOT notified of input messages when an exception occurs during streaming.
/// </summary>
[Fact]
public async Task VerifyThreadNotNotifiedWhenStreamingThrowsExceptionAsync()
{
// Arrange
var userMessage = new ChatMessage(ChatRole.User, "Hello that will fail!");
var mockChatClient = new Mock<IChatClient>();
mockChatClient.Setup(
c => c.GetStreamingResponseAsync(
It.IsAny<IEnumerable<ChatMessage>>(),
It.IsAny<ChatOptions>(),
It.IsAny<CancellationToken>()))
.Throws(new InvalidOperationException("Streaming failed"));
var agent = new ChatClientAgent(mockChatClient.Object, options: new());
var thread = agent.GetNewThread();
// Act & Assert - Verify that streaming throws an exception
await Assert.ThrowsAsync<InvalidOperationException>(async () =>
{
await foreach (var update in agent.RunStreamingAsync([userMessage], thread))
{
Assert.Fail("Should not yield updates.");
}
});
// Retrieve messages from the thread to verify NO notification occurred
var messagesRetrievableThread = (IMessagesRetrievableThread)thread;
var retrievedMessages = new List<ChatMessage>();
await foreach (var message in messagesRetrievableThread.GetMessagesAsync())
{
retrievedMessages.Add(message);
}
// Assert - Verify that the thread was NOT notified of any messages due to the exception
// This ensures that failed operations don't leave the thread in an inconsistent state
Assert.Empty(retrievedMessages);
}
/// <summary>
/// Verify that thread is NOT notified of input messages when an exception occurs after some streaming updates.
/// </summary>
[Fact]
public async Task VerifyThreadNotNotifiedWhenStreamingThrowsExceptionAfterUpdatesAsync()
{
// Arrange
var userMessage = new ChatMessage(ChatRole.User, "Hello that will partially fail!");
// Create an async enumerable that yields some updates then throws
static async IAsyncEnumerable<ChatResponseUpdate> GetUpdatesWithExceptionAsync()
{
await Task.CompletedTask; // Simulate async operation
throw new InvalidOperationException("Streaming failed after partial response");
yield break;
}
var mockChatClient = new Mock<IChatClient>();
mockChatClient.Setup(
c => c.GetStreamingResponseAsync(
It.IsAny<IEnumerable<ChatMessage>>(),
It.IsAny<ChatOptions>(),
It.IsAny<CancellationToken>()))
.Returns(GetUpdatesWithExceptionAsync());
var agent = new ChatClientAgent(mockChatClient.Object, options: new());
var thread = agent.GetNewThread();
// Act & Assert - Verify that streaming throws an exception after some updates
var streamingResults = new List<AgentRunResponseUpdate>();
await Assert.ThrowsAsync<InvalidOperationException>(async () =>
{
await foreach (var update in agent.RunStreamingAsync([userMessage], thread))
{
streamingResults.Add(update);
}
});
// Verify that some updates were received before the exception
Assert.Empty(streamingResults);
// Retrieve messages from the thread to verify NO notification occurred
var messagesRetrievableThread = (IMessagesRetrievableThread)thread;
var retrievedMessages = new List<ChatMessage>();
await foreach (var message in messagesRetrievableThread.GetMessagesAsync())
{
retrievedMessages.Add(message);
}
// Assert - Verify that the thread was NOT notified of any messages due to the exception
// Even though some updates were received, the exception should prevent thread notification
Assert.Empty(retrievedMessages);
}
#endregion
#region JSON Serialization Tests
/// <summary>
/// Verify that <see cref="ChatClientAgentThread"/> can be serialized to JSON and deserialized back correctly.
/// </summary>
[Fact]
public void VerifyJsonSerializationRoundTrip_DefaultThread()
{
// Arrange
var originalThread = new ChatClientAgentThread();
// Act
string json = JsonSerializer.Serialize(originalThread);
var deserializedThread = JsonSerializer.Deserialize<ChatClientAgentThread>(json);
// Assert
Assert.NotNull(deserializedThread);
Assert.Equal(originalThread.Id, deserializedThread.Id);
Assert.Equal(originalThread.StorageLocation, deserializedThread.StorageLocation);
}
/// <summary>
/// Verify that <see cref="ChatClientAgentThread"/> with ID can be serialized to JSON and deserialized back correctly.
/// </summary>
[Fact]
public void VerifyJsonSerializationRoundTrip_ThreadWithId()
{
// Arrange
var originalThread = new ChatClientAgentThread("test-conversation-id");
// Act
string json = JsonSerializer.Serialize(originalThread);
var deserializedThread = JsonSerializer.Deserialize<ChatClientAgentThread>(json);
// Assert
Assert.NotNull(deserializedThread);
Assert.Equal("test-conversation-id", deserializedThread.Id);
Assert.Equal(ChatClientAgentThreadType.ConversationId, deserializedThread.StorageLocation);
}
/// <summary>
/// Verify that <see cref="ChatClientAgentThread"/> with messages can be serialized to JSON and deserialized back correctly.
/// </summary>
[Fact]
public async Task VerifyJsonSerializationRoundTrip_ThreadWithMessagesAsync()
{
// Arrange
var messages = new[]
{
new ChatMessage(ChatRole.User, "Hello, world!"),
new ChatMessage(ChatRole.Assistant, "Hi there! How can I help you?"),
new ChatMessage(ChatRole.User, "What's the weather like?")
};
var originalThread = new ChatClientAgentThread(messages);
// Act
string json = JsonSerializer.Serialize(originalThread);
var deserializedThread = JsonSerializer.Deserialize<ChatClientAgentThread>(json);
// Assert
Assert.NotNull(deserializedThread);
Assert.Equal(originalThread.Id, deserializedThread.Id);
Assert.Equal(ChatClientAgentThreadType.InMemoryMessages, deserializedThread.StorageLocation);
// Verify messages are preserved
var originalMessages = await originalThread.GetMessagesAsync().ToListAsync();
var deserializedMessages = await deserializedThread.GetMessagesAsync().ToListAsync();
Assert.Equal(originalMessages.Count, deserializedMessages.Count);
for (int i = 0; i < originalMessages.Count; i++)
{
Assert.Equal(originalMessages[i].Role, deserializedMessages[i].Role);
Assert.Equal(originalMessages[i].Text, deserializedMessages[i].Text);
}
}
/// <summary>
/// Verify that <see cref="ChatClientAgentThread"/> serialization handles null properties correctly.
/// </summary>
[Fact]
public void VerifyJsonSerialization_HandlesNullProperties()
{
// Arrange
var thread = new ChatClientAgentThread();
// Act
string json = JsonSerializer.Serialize(thread);
// Assert - StorageLocation is no longer serialized independently
Assert.DoesNotContain("storageLocation", json, StringComparison.OrdinalIgnoreCase);
// Verify deserialization handles empty JSON correctly
var deserializedThread = JsonSerializer.Deserialize<ChatClientAgentThread>(json);
Assert.NotNull(deserializedThread);
Assert.Null(deserializedThread.StorageLocation);
}
/// <summary>
/// Verify that <see cref="ChatClientAgentThread"/> serialization only includes messages for InMemoryMessages storage type.
/// </summary>
[Fact]
public void VerifyJsonSerialization_OnlyIncludesMessagesForInMemoryStorage()
{
// Arrange - Create thread with conversation ID (server-side storage)
var threadWithId = new ChatClientAgentThread("test-id");
// Act
string json = JsonSerializer.Serialize(threadWithId);
// Assert - Messages should not be included for ConversationId storage, and storageLocation is not serialized
Assert.DoesNotContain("\"messages\"", json, StringComparison.OrdinalIgnoreCase);
Assert.DoesNotContain("\"storageLocation\"", json, StringComparison.OrdinalIgnoreCase);
Assert.Contains("\"id\":\"test-id\"", json, StringComparison.OrdinalIgnoreCase);
}
/// <summary>
/// Verify that <see cref="ChatClientAgentThread"/> serialization includes messages for InMemoryMessages storage type.
/// </summary>
[Fact]
public void VerifyJsonSerialization_IncludesMessagesForInMemoryStorage()
{
// Arrange - Create thread with messages (in-memory storage)
var messages = new[] { new ChatMessage(ChatRole.User, "Test message") };
var threadWithMessages = new ChatClientAgentThread(messages);
// Act
string json = JsonSerializer.Serialize(threadWithMessages);
// Assert - Messages should be included for InMemoryMessages storage, but storageLocation is not serialized
Assert.Contains("\"messages\"", json, StringComparison.OrdinalIgnoreCase);
Assert.DoesNotContain("\"storageLocation\"", json, StringComparison.OrdinalIgnoreCase);
Assert.Contains("Test message", json, StringComparison.OrdinalIgnoreCase);
}
/// <summary>
/// Verify that <see cref="ChatClientAgentThread"/> deserialization handles missing properties gracefully.
/// </summary>
[Fact]
public void VerifyJsonDeserialization_HandlesMissingProperties()
{
// Arrange - JSON with minimal properties
string minimalJson = "{}";
// Act
var thread = JsonSerializer.Deserialize<ChatClientAgentThread>(minimalJson);
// Assert
Assert.NotNull(thread);
Assert.Null(thread.Id);
Assert.Null(thread.StorageLocation);
}
/// <summary>
/// Verify that <see cref="ChatClientAgentThread"/> deserialization handles invalid JSON gracefully.
/// </summary>
[Fact]
public void VerifyJsonDeserialization_HandlesMalformedJson()
{
// Arrange - Invalid JSON structure
#pragma warning disable JSON001 // Invalid JSON pattern
string invalidJson = "{ invalid json";
#pragma warning restore JSON001 // Invalid JSON pattern
// Act & Assert
Assert.Throws<JsonException>(() => JsonSerializer.Deserialize<ChatClientAgentThread>(invalidJson));
}
/// <summary>
/// Verify that <see cref="ChatClientAgentThread"/> deserialization handles invalid storage location values.
/// This test is no longer relevant since storageLocation is not independently deserialized.
/// </summary>
[Fact]
public void VerifyJsonDeserialization_HandlesInvalidStorageLocation()
{
// Arrange - JSON with ID (which will set storage location to ConversationId)
string jsonWithId = @"{""id"":""test""}";
// Act
var thread = JsonSerializer.Deserialize<ChatClientAgentThread>(jsonWithId);
// Assert - Storage location is determined by presence of ID
Assert.NotNull(thread);
Assert.Equal("test", thread.Id);
Assert.Equal(ChatClientAgentThreadType.ConversationId, thread.StorageLocation);
}
/// <summary>
/// Verify that <see cref="ChatClientAgentThread"/> deserialization preserves messages correctly.
/// </summary>
[Fact]
public async Task VerifyJsonDeserialization_PreservesMessagesCorrectlyAsync()
{
// Arrange - Create a thread with messages and serialize it to get the correct format
var originalMessages = new[]
{
new ChatMessage(ChatRole.User, "Hello"),
new ChatMessage(ChatRole.Assistant, "Hi there!")
};
var originalThread = new ChatClientAgentThread(originalMessages);
// Serialize to get the actual format, then deserialize
string json = JsonSerializer.Serialize(originalThread);
var thread = JsonSerializer.Deserialize<ChatClientAgentThread>(json);
// Assert
Assert.NotNull(thread);
Assert.Equal(ChatClientAgentThreadType.InMemoryMessages, thread.StorageLocation);
var messages = await thread.GetMessagesAsync().ToListAsync();
Assert.Equal(2, messages.Count);
Assert.Equal(ChatRole.User, messages[0].Role);
Assert.Equal("Hello", messages[0].Text);
Assert.Equal(ChatRole.Assistant, messages[1].Role);
Assert.Equal("Hi there!", messages[1].Text);
}
/// <summary>
/// Verify that <see cref="ChatClientAgentThread"/> serialization and deserialization works with complex message content.
/// </summary>
[Fact]
public async Task VerifyJsonSerializationRoundTrip_ComplexMessageContentAsync()
{
// Arrange - Create messages with various content types
var messages = new[]
{
new ChatMessage(ChatRole.User, "Simple text message"),
new ChatMessage(ChatRole.Assistant, [
new TextContent("Mixed content: "),
new TextContent("multiple parts")
]),
new ChatMessage(ChatRole.User, "Message with special characters: ñáéíóú !@#$%^&*()")
};
var originalThread = new ChatClientAgentThread(messages);
// Act
string json = JsonSerializer.Serialize(originalThread);
var deserializedThread = JsonSerializer.Deserialize<ChatClientAgentThread>(json);
// Assert
Assert.NotNull(deserializedThread);
var originalMessages = await originalThread.GetMessagesAsync().ToListAsync();
var deserializedMessages = await deserializedThread.GetMessagesAsync().ToListAsync();
Assert.Equal(originalMessages.Count, deserializedMessages.Count);
// Verify complex content is preserved
for (int i = 0; i < originalMessages.Count; i++)
{
Assert.Equal(originalMessages[i].Role, deserializedMessages[i].Role);
Assert.Equal(originalMessages[i].Text, deserializedMessages[i].Text);
}
}
#endregion
}
@@ -388,7 +388,7 @@ public class OpenTelemetryAgentTests
new(ChatRole.User, "Hello")
};
var thread = new AgentThread { Id = "thread-123" };
var thread = new AgentThread { ConversationId = "thread-123" };
// Act
await telemetryAgent.RunAsync(messages, thread);
@@ -1,6 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using AgentConformance.IntegrationTests;
@@ -28,13 +27,8 @@ public class OpenAIAssistantFixture : IChatClientAgentFixture
public async Task<List<ChatMessage>> GetChatHistoryAsync(AgentThread thread)
{
if (thread is not ChatClientAgentThread chatClientThread)
{
throw new InvalidOperationException("The thread must be of type ChatClientAgentThread to retrieve chat history.");
}
List<ChatMessage> messages = [];
await foreach (var agentMessage in this._assistantClient!.GetMessagesAsync(chatClientThread.Id, new() { Order = MessageCollectionOrder.Ascending }))
await foreach (var agentMessage in this._assistantClient!.GetMessagesAsync(thread.ConversationId, new() { Order = MessageCollectionOrder.Ascending }))
{
messages.Add(new()
{
@@ -79,9 +73,9 @@ public class OpenAIAssistantFixture : IChatClientAgentFixture
public Task DeleteThreadAsync(AgentThread thread)
{
if (thread?.Id is not null)
if (thread?.ConversationId is not null)
{
return this._assistantClient!.DeleteThreadAsync(thread.Id);
return this._assistantClient!.DeleteThreadAsync(thread.ConversationId);
}
return Task.CompletedTask;
@@ -1,6 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
@@ -33,12 +32,7 @@ public class OpenAIChatCompletionFixture : IChatClientAgentFixture
public async Task<List<ChatMessage>> GetChatHistoryAsync(AgentThread thread)
{
if (thread is not ChatClientAgentThread chatClientThread)
{
throw new InvalidOperationException("The thread must be of type ChatClientAgentThread to retrieve chat history.");
}
return await chatClientThread.GetMessagesAsync().ToListAsync();
return await thread.GetMessagesAsync().ToListAsync();
}
public Task<ChatClientAgent> CreateChatClientAgentAsync(
@@ -29,15 +29,10 @@ public class OpenAIResponseFixture(bool store) : IChatClientAgentFixture
public async Task<List<ChatMessage>> GetChatHistoryAsync(AgentThread thread)
{
if (thread is not ChatClientAgentThread chatClientThread)
{
throw new InvalidOperationException("The thread must be of type ChatClientAgentThread to retrieve chat history.");
}
if (store)
{
var inputItems = await this._openAIResponseClient.GetResponseInputItemsAsync(chatClientThread.Id).ToListAsync();
var response = await this._openAIResponseClient.GetResponseAsync(chatClientThread.Id);
var inputItems = await this._openAIResponseClient.GetResponseInputItemsAsync(thread.ConversationId).ToListAsync();
var response = await this._openAIResponseClient.GetResponseAsync(thread.ConversationId);
var responseItem = response.Value.OutputItems.FirstOrDefault()!;
// Take the messages that were the chat history leading up to the current response
@@ -55,7 +50,7 @@ public class OpenAIResponseFixture(bool store) : IChatClientAgentFixture
return [.. previousMessages, responseMessage];
}
return await chatClientThread.GetMessagesAsync().ToListAsync();
return await thread.GetMessagesAsync().ToListAsync();
}
private static ChatMessage ConvertToChatMessage(ResponseItem item)