diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index d1f31f3b5f..19c7cb6f84 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -55,6 +55,8 @@ + + diff --git a/dotnet/samples/GettingStarted/AgentSample.cs b/dotnet/samples/GettingStarted/AgentSample.cs index 57e2305c72..273749fc9a 100644 --- a/dotnet/samples/GettingStarted/AgentSample.cs +++ b/dotnet/samples/GettingStarted/AgentSample.cs @@ -167,7 +167,7 @@ public class AgentSample(ITestOutputHelper output) : BaseSample(output) // If a thread is provided, delete it as well. if (thread is not null) { - await persistentAgentsClient.Threads.DeleteThreadAsync(thread.Id, cancellationToken); + await persistentAgentsClient.Threads.DeleteThreadAsync(thread.ConversationId, cancellationToken); } } @@ -183,7 +183,7 @@ public class AgentSample(ITestOutputHelper output) : BaseSample(output) // If a thread is provided, delete it as well. if (thread is not null) { - await assistantClient.DeleteThreadAsync(thread.Id, cancellationToken); + await assistantClient.DeleteThreadAsync(thread.ConversationId, cancellationToken); } } diff --git a/dotnet/samples/GettingStarted/GettingStarted.csproj b/dotnet/samples/GettingStarted/GettingStarted.csproj index 0707da1ad5..3722c8294a 100644 --- a/dotnet/samples/GettingStarted/GettingStarted.csproj +++ b/dotnet/samples/GettingStarted/GettingStarted.csproj @@ -27,6 +27,7 @@ + diff --git a/dotnet/samples/GettingStarted/Providers/AIAgent_With_AzureAIAgentsPersistent.cs b/dotnet/samples/GettingStarted/Providers/AIAgent_With_AzureAIAgentsPersistent.cs index 89f4d704d5..9e3dccc7eb 100644 --- a/dotnet/samples/GettingStarted/Providers/AIAgent_With_AzureAIAgentsPersistent.cs +++ b/dotnet/samples/GettingStarted/Providers/AIAgent_With_AzureAIAgentsPersistent.cs @@ -51,7 +51,7 @@ public sealed class AIAgent_With_AzureAIAgentsPersistent(ITestOutputHelper outpu } // Cleanup - await persistentAgentsClient.Threads.DeleteThreadAsync(thread.Id); + await persistentAgentsClient.Threads.DeleteThreadAsync(thread.ConversationId); await persistentAgentsClient.Administration.DeleteAgentAsync(agent.Id); } @@ -85,7 +85,7 @@ public sealed class AIAgent_With_AzureAIAgentsPersistent(ITestOutputHelper outpu } // Cleanup - await persistentAgentsClient.Threads.DeleteThreadAsync(thread.Id); + await persistentAgentsClient.Threads.DeleteThreadAsync(thread.ConversationId); await persistentAgentsClient.Administration.DeleteAgentAsync(agent.Id); } } diff --git a/dotnet/samples/GettingStarted/Providers/AIAgent_With_OpenAIAssistant.cs b/dotnet/samples/GettingStarted/Providers/AIAgent_With_OpenAIAssistant.cs index df97729226..13fc621133 100644 --- a/dotnet/samples/GettingStarted/Providers/AIAgent_With_OpenAIAssistant.cs +++ b/dotnet/samples/GettingStarted/Providers/AIAgent_With_OpenAIAssistant.cs @@ -52,7 +52,7 @@ public sealed class AIAgent_With_OpenAIAssistant(ITestOutputHelper output) : Age // Cleanup var assistantClient = openAIClient.GetAssistantClient(); - await assistantClient.DeleteThreadAsync(thread.Id); + await assistantClient.DeleteThreadAsync(thread.ConversationId); await assistantClient.DeleteAssistantAsync(agent.Id); } } diff --git a/dotnet/samples/GettingStarted/Steps/Step08_ChatClientAgent_SuspendResumeThread.cs b/dotnet/samples/GettingStarted/Steps/Step08_ChatClientAgent_SuspendResumeThread.cs new file mode 100644 index 0000000000..515aae080a --- /dev/null +++ b/dotnet/samples/GettingStarted/Steps/Step08_ChatClientAgent_SuspendResumeThread.cs @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using Microsoft.Extensions.AI.Agents; + +namespace Steps; + +/// +/// Demonstrates how to suspend and resume a thread with the . +/// +public sealed class Step08_ChatClientAgent_SuspendResumeThread(ITestOutputHelper output) : AgentSample(output) +{ + private const string JokerName = "Joker"; + private const string JokerInstructions = "You are good at telling jokes."; + + /// + /// Demonstrate the usage of where a thread is suspended. + /// The thread is serialized and can be stored to a database, file, or any other storage mechanism, + /// and then deserialized later to resume the conversation with the agent. + /// + [Theory] + [InlineData(ChatClientProviders.AzureAIAgentsPersistent)] + [InlineData(ChatClientProviders.AzureOpenAI)] + [InlineData(ChatClientProviders.OpenAIAssistant)] + [InlineData(ChatClientProviders.OpenAIResponses_InMemoryMessageThread)] + [InlineData(ChatClientProviders.OpenAIResponses_ConversationIdThread)] + public async Task SuspendResumeThread(ChatClientProviders provider) + { + // Define the options for the chat client agent. + var agentOptions = new ChatClientAgentOptions + { + Name = JokerName, + Instructions = JokerInstructions, + + // Get chat options based on the store type, if needed. + ChatOptions = base.GetChatOptions(provider), + }; + + // Create the server-side agent Id when applicable (depending on the provider). + agentOptions.Id = await base.AgentCreateAsync(provider, agentOptions); + + // Get the chat client to use for the agent. + using var chatClient = base.GetChatClient(provider, agentOptions); + + // Define the agent + var agent = new ChatClientAgent(chatClient, agentOptions); + + // Start a new thread for the agent conversation. + AgentThread thread = agent.GetNewThread(); + + // Respond to user input + Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", thread)); + + // Serialize the thread state, so it can be stored for later use. + JsonElement serializedThread = await thread.SerializeAsync(); + + // The thread can now be saved to a database, file, or any other storage mechanism + // and loaded again later. + + // Deserialize the thread state after loading from storage. + AgentThread resumedThread = await agent.DeserializeThreadAsync(serializedThread); + + Console.WriteLine(await agent.RunAsync("Now tell the same joke in the voice of a pirate, and add some emojis to the joke.", resumedThread)); + + // Clean up the server-side agent and thread after use when applicable (depending on the provider). + await base.AgentCleanUpAsync(provider, agent, thread); + } +} diff --git a/dotnet/samples/GettingStarted/Steps/Step09_ChatClientAgent_3rdPartyThreadStorage.cs b/dotnet/samples/GettingStarted/Steps/Step09_ChatClientAgent_3rdPartyThreadStorage.cs new file mode 100644 index 0000000000..db58f6f862 --- /dev/null +++ b/dotnet/samples/GettingStarted/Steps/Step09_ChatClientAgent_3rdPartyThreadStorage.cs @@ -0,0 +1,156 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.AI.Agents; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.InMemory; + +namespace Steps; + +/// +/// Demonstrates how to store the chat history of a thread in a 3rd party store when using . +/// +public sealed class Step09_ChatClientAgent_3rdPartyThreadStorage(ITestOutputHelper output) : AgentSample(output) +{ + private const string JokerName = "Joker"; + private const string JokerInstructions = "You are good at telling jokes."; + + /// + /// Demonstrate storage of the chat history of a thread in a 3rd party store when using . + /// + /// + /// Note that this is only supported for services that do not already store the chat history in their own service. + /// + [Theory] + [InlineData(ChatClientProviders.AzureOpenAI)] + [InlineData(ChatClientProviders.OpenAIResponses_InMemoryMessageThread)] + public async Task ThirdPartyStorageThread(ChatClientProviders provider) + { + var inMemoryVectorStore = new InMemoryVectorStore(); + + // Define the options for the chat client agent. + var agentOptions = new ChatClientAgentOptions + { + Name = JokerName, + Instructions = JokerInstructions, + + // Get chat options based on the store type, if needed. + ChatOptions = base.GetChatOptions(provider), + + ChatMessageStoreFactory = () => + { + // Create a new chat message store for this agent that stores the messages in a vector store. + // Each thread must get its own copy of the VectorChatMessageStore, since the store + // also contains the id that the thread is stored under. + return new VectorChatMessageStore(inMemoryVectorStore); + } + }; + + // Get the chat client to use for the agent. + using var chatClient = base.GetChatClient(provider, agentOptions); + + // Define the agent + var agent = new ChatClientAgent(chatClient, agentOptions); + + // Start a new thread for the agent conversation. + AgentThread thread = agent.GetNewThread(); + + // Respond to user input + Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", thread)); + + // Serialize the thread state, so it can be stored for later use. + // Since the chat history is stored in the vector store, the serialized there + // only contains the guid that the messages are stored under in the vector store. + JsonElement serializedThread = await thread.SerializeAsync(); + + // The serialized thread can now be saved to a database, file, or any other storage mechanism + // and loaded again later. + + // Deserialize the thread state after loading from storage. + AgentThread resumedThread = await agent.DeserializeThreadAsync(serializedThread); + + Console.WriteLine(await agent.RunAsync("Now tell the same joke in the voice of a pirate, and add some emojis to the joke.", resumedThread)); + } + + /// + /// A sample implementation of that stores chat messages in a vector store. + /// + /// The vector store to store the messages in. + private sealed class VectorChatMessageStore(VectorStore vectorStore) : IChatMessageStore + { + private string? _threadId; + + public string? ThreadId => this._threadId; + + public async Task AddMessagesAsync(IReadOnlyCollection messages, CancellationToken cancellationToken) + { + this._threadId ??= Guid.NewGuid().ToString(); + + var collection = vectorStore.GetCollection("ChatHistory"); + await collection.EnsureCollectionExistsAsync(cancellationToken); + + await collection.UpsertAsync(messages.Select(x => new ChatHistoryItem() + { + Key = this._threadId + x.MessageId, + Timestamp = DateTimeOffset.UtcNow, + ThreadId = this._threadId, + SerializedMessage = JsonSerializer.Serialize(x), + MessageText = x.Text + }), cancellationToken); + } + + public async Task> GetMessagesAsync(CancellationToken cancellationToken) + { + var collection = vectorStore.GetCollection("ChatHistory"); + await collection.EnsureCollectionExistsAsync(cancellationToken); + + var records = await collection + .GetAsync( + x => x.ThreadId == this._threadId, 10, + new() { OrderBy = x => x.Descending(y => y.Timestamp) }, + cancellationToken) + .ToListAsync(cancellationToken); + + var messages = records + .Select(x => JsonSerializer.Deserialize(x.SerializedMessage!)!) + .ToList(); + messages.Reverse(); + return messages; + } + + public ValueTask SerializeStateAsync(JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + { + // We have to serialize the thread id, so that on deserialization we can retrieve the messages using the same thread id. + return new ValueTask(JsonSerializer.SerializeToElement(this._threadId)); + } + + public ValueTask DeserializeStateAsync(JsonElement? serializedStoreState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + { + // Here we can deserialize the thread id so that we can access the same messages as before the suspension. + this._threadId = JsonSerializer.Deserialize((JsonElement)serializedStoreState!); + return new ValueTask(); + } + + /// + /// The data structure used to store chat history items in the vector store. + /// + private sealed class ChatHistoryItem + { + [VectorStoreKey] + public string? Key { get; set; } + + [VectorStoreData] + public string? ThreadId { get; set; } + + [VectorStoreData] + public DateTimeOffset? Timestamp { get; set; } + + [VectorStoreData] + public string? SerializedMessage { get; set; } + + [VectorStoreData] + public string? MessageText { get; set; } + } + } +} diff --git a/dotnet/samples/HelloHttpApi/HelloHttpApi.ApiService/ChatClientAgentActor.cs b/dotnet/samples/HelloHttpApi/HelloHttpApi.ApiService/ChatClientAgentActor.cs index dc29511e73..275c4f3467 100644 --- a/dotnet/samples/HelloHttpApi/HelloHttpApi.ApiService/ChatClientAgentActor.cs +++ b/dotnet/samples/HelloHttpApi/HelloHttpApi.ApiService/ChatClientAgentActor.cs @@ -14,7 +14,7 @@ internal sealed class ChatClientAgentActor( ILogger logger) : IActor { private string? _etag; - private ChatClientAgentThread? _thread; + private AgentThread? _thread; public ValueTask DisposeAsync() => default; @@ -34,11 +34,11 @@ internal sealed class ChatClientAgentActor( if (threadResult.Value is { } threadJson) { // Deserialize the thread state if it exist - this._thread = threadJson.Deserialize(ChatClientAgentActorJsonContext.Default.ChatClientAgentThread); + this._thread = await agent.DeserializeThreadAsync(threadJson, cancellationToken: cancellationToken).ConfigureAwait(false); } } - this._thread ??= agent.GetNewThread() as ChatClientAgentThread ?? throw new InvalidOperationException("The agent did not provide a valid thread instance."); + this._thread ??= agent.GetNewThread(); Log.ThreadStateRestored(logger, context.ActorId.ToString(), response.Results[0] is GetValueResult { Value: not null }); while (!cancellationToken.IsCancellationRequested) diff --git a/dotnet/samples/HelloHttpApi/HelloHttpApi.ApiService/ChatClientAgentActorJsonContext.cs b/dotnet/samples/HelloHttpApi/HelloHttpApi.ApiService/ChatClientAgentActorJsonContext.cs index 52671785b1..fe8d72b533 100644 --- a/dotnet/samples/HelloHttpApi/HelloHttpApi.ApiService/ChatClientAgentActorJsonContext.cs +++ b/dotnet/samples/HelloHttpApi/HelloHttpApi.ApiService/ChatClientAgentActorJsonContext.cs @@ -2,7 +2,6 @@ using System.Text.Json; using System.Text.Json.Serialization; -using Microsoft.Extensions.AI.Agents; namespace HelloHttpApi.ApiService; @@ -14,6 +13,5 @@ namespace HelloHttpApi.ApiService; UseStringEnumConverter = true, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, WriteIndented = false)] -[JsonSerializable(typeof(ChatClientAgentThread))] [JsonSerializable(typeof(ChatClientAgentRunRequest))] internal sealed partial class ChatClientAgentActorJsonContext : JsonSerializerContext; diff --git a/dotnet/src/Microsoft.Agents.Orchestration/OrchestratingAgent.cs b/dotnet/src/Microsoft.Agents.Orchestration/OrchestratingAgent.cs index cf974cd6ef..733003d8c0 100644 --- a/dotnet/src/Microsoft.Agents.Orchestration/OrchestratingAgent.cs +++ b/dotnet/src/Microsoft.Agents.Orchestration/OrchestratingAgent.cs @@ -70,13 +70,13 @@ public abstract partial class OrchestratingAgent : AIAgent if (thread is not null) { - if (thread is not IMessagesRetrievableThread retrievableThread) + if (thread.MessageStore is null) { - throw new InvalidOperationException($"The thread type '{thread.GetType().Name}' is not supported by this agent. Use {nameof(GetNewThread)} to create a thread when needed."); + throw new InvalidOperationException("An agent service managed thread is not supported by this agent."); } List messagesList = []; - await foreach (var threadMessage in retrievableThread.GetMessagesAsync(cancellationToken).ConfigureAwait(false)) + await foreach (var threadMessage in thread.GetMessagesAsync(cancellationToken).ConfigureAwait(false)) { messagesList.Add(threadMessage); } @@ -101,9 +101,6 @@ public abstract partial class OrchestratingAgent : AIAgent } } - /// - public sealed override AgentThread GetNewThread() => new ChatClientAgentThread(); - /// /// Initiates processing of the orchestration. /// @@ -207,10 +204,6 @@ public abstract partial class OrchestratingAgent : AIAgent return response; } - /// - protected sealed override TThreadType ValidateOrCreateThreadType(AgentThread? thread, Func constructThread) => - base.ValidateOrCreateThreadType(thread, constructThread); - /// Writes the specified checkpoint state to the runtime. /// The state to persist. /// The context for the orchestrating operation. diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIAgent.cs b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIAgent.cs index d0cd678957..b09a3985cd 100644 --- a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIAgent.cs +++ b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIAgent.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Shared.Diagnostics; @@ -50,7 +51,21 @@ public abstract class AIAgent /// If the thread needs to be created via a service call it would be created on first use. /// /// - public abstract AgentThread GetNewThread(); + public virtual AgentThread GetNewThread() => new(); + + /// + /// Deserialize the thread from JSON. + /// + /// The representing the thread state. + /// Optional to use for deserializing the thread state. + /// The to monitor for cancellation requests. The default is . + /// The deserialized instance. + public async ValueTask DeserializeThreadAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + { + var thread = this.GetNewThread(); + await thread.DeserializeAsync(serializedThread, jsonSerializerOptions, cancellationToken).ConfigureAwait(false); + return thread; + } /// /// Run the agent with no message assuming that all required instructions are already provided to the agent or on the thread. @@ -192,30 +207,6 @@ public abstract class AIAgent AgentRunOptions? options = null, CancellationToken cancellationToken = default); - /// - /// Checks that the thread is of the expected type, or if null, creates the default thread type. - /// - /// The expected type of the thead. - /// The thread to create if it's null and validate its type if not null. - /// A callback to use to construct the thread if it's null. - /// An async task that completes once all update are complete. - protected virtual TThreadType ValidateOrCreateThreadType( - AgentThread? thread, - Func constructThread) - where TThreadType : AgentThread - { - Throw.IfNull(constructThread); - - thread ??= constructThread(); - - if (thread is not TThreadType concreteThreadType) - { - throw new NotSupportedException($"{this.GetType().Name} currently only supports agent threads of type {typeof(TThreadType).Name}."); - } - - return concreteThreadType; - } - /// /// Notfiy the given thread that new messages are available. /// diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AgentAbstractionsJsonUtilities.cs b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AgentAbstractionsJsonUtilities.cs index 65830d4b09..4f86e66994 100644 --- a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AgentAbstractionsJsonUtilities.cs +++ b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AgentAbstractionsJsonUtilities.cs @@ -57,7 +57,8 @@ public static partial class AgentAbstractionsJsonUtilities [JsonSerializable(typeof(AgentRunResponse[]))] [JsonSerializable(typeof(AgentRunResponseUpdate))] [JsonSerializable(typeof(AgentRunResponseUpdate[]))] - [JsonSerializable(typeof(AgentThread))] + [JsonSerializable(typeof(AgentThread.ThreadState))] + [JsonSerializable(typeof(InMemoryChatMessageStore.StoreState))] [ExcludeFromCodeCoverage] internal sealed partial class JsonContext : JsonSerializerContext; diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AgentRunResponseUpdate.cs b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AgentRunResponseUpdate.cs index 4e6ce5f22e..9cd3996021 100644 --- a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AgentRunResponseUpdate.cs +++ b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AgentRunResponseUpdate.cs @@ -141,9 +141,11 @@ public class AgentRunResponseUpdate /// Gets a object to display in the debugger display. [DebuggerBrowsable(DebuggerBrowsableState.Never)] + [ExcludeFromCodeCoverage] private AIContent? ContentForDebuggerDisplay => this._contents is { Count: > 0 } ? this._contents[0] : null; /// Gets an indication for the debugger display of whether there's more content. [DebuggerBrowsable(DebuggerBrowsableState.Never)] + [ExcludeFromCodeCoverage] private string EllipsesForDebuggerDisplay => this._contents is { Count: > 1 } ? ", ..." : string.Empty; } diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AgentThread.cs b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AgentThread.cs index b2fcb9d0a4..debc0abadc 100644 --- a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AgentThread.cs +++ b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AgentThread.cs @@ -2,8 +2,12 @@ using System; using System.Collections.Generic; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Text.Json; using System.Threading; using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI.Agents; @@ -13,21 +17,114 @@ namespace Microsoft.Extensions.AI.Agents; /// public class AgentThread { + private string? _conversationId; + private IChatMessageStore? _messageStore; + /// - /// Gets or sets the id of the current thread. + /// Initializes a new instance of the class. + /// + public AgentThread() + { + } + + /// + /// Gets or sets the id of the current thread to support cases where the thread is owned by the agent service. /// /// /// - /// This id may be null if the thread has no id, or - /// if it represents a service-owned thread but the service - /// has not yet been called to create the thread. + /// Note that either or may be set, but not both. + /// If is not null, and is set, + /// will be reverted to null, and vice versa. /// /// - /// The id may also change over time where the - /// is a proxy to a service owned thread that forks on each agent invocation. + /// This property may be null in the following cases: + /// + /// The thread stores messages via the and not in the agent service. + /// This thread object is new and a server managed thread has not yet been created in the agent service. + /// + /// + /// + /// The id may also change over time where the the id is pointing at a + /// agent service managed thread, and the default behavior of a service is + /// to fork the thread with each iteration. /// /// - public string? Id { get; set; } + public string? ConversationId + { + get { return this._conversationId; } + set + { + if (string.IsNullOrWhiteSpace(this._conversationId) && string.IsNullOrWhiteSpace(value)) + { + return; + } + + if (this._messageStore is not null) + { + // If we have a message store already, we shouldn't switch the thread to use a conversation id + // since it means that the thread contents will essentially be deleted, and the thread will not work + // with the original agent anymore. + throw new InvalidOperationException("Only the ConversationId or MessageStore may be set, but not both and switching from one to another is not supported."); + } + + this._conversationId = Throw.IfNullOrWhitespace(value); + } + } + + /// + /// Gets or sets the used by this thread, for cases where messages should be stored in a custom location. + /// + /// + /// + /// Note that either or may be set, but not both. + /// If is not null, and is set, + /// will be reverted to null, and vice versa. + /// + /// + /// This property may be null in the following cases: + /// + /// The thread stores messages in the agent service and just has an id to the remove thread, instead of in an . + /// This thread object is new it is not yet clear whether it will be backed by a server managed thread or an . + /// + /// + /// + public IChatMessageStore? MessageStore + { + get { return this._messageStore; } + set + { + if (this._messageStore is null && value is null) + { + return; + } + + if (!string.IsNullOrWhiteSpace(this._conversationId)) + { + // If we have a conversation id already, we shouldn't switch the thread to use a message store + // since it means that the thread will not work with the original agent anymore. + throw new InvalidOperationException("Only the ConversationId or MessageStore may be set, but not both and switching from one to another is not supported."); + } + + this._messageStore = Throw.IfNull(value); + } + } + + /// + /// Retrieves any messages stored in the of the thread, otherwise returns an empty collection. + /// + /// The to monitor for cancellation requests. The default is . + /// The messages from the in ascending chronological order, with the oldest message first. + public virtual async IAsyncEnumerable GetMessagesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + if (this._messageStore is not null) + { + var messages = await this._messageStore!.GetMessagesAsync(cancellationToken).ConfigureAwait(false); + foreach (var message in messages) + { + yield return message; + } + } + } /// /// This method is called when new messages have been contributed to the chat by any participant. @@ -39,8 +136,92 @@ public class AgentThread /// The to monitor for cancellation requests. The default is . /// A task that completes when the context has been updated. /// The thread has been deleted. - protected internal virtual Task OnNewMessagesAsync(IReadOnlyCollection newMessages, CancellationToken cancellationToken = default) + protected internal virtual async Task OnNewMessagesAsync(IReadOnlyCollection newMessages, CancellationToken cancellationToken = default) { - return Task.CompletedTask; + switch (this) + { + case { ConversationId: not null }: + // If the thread messages are stored in the service + // there is nothing to do here, since invoking the + // service should already update the thread. + break; + + case { MessageStore: null }: + // If there is no conversation id, and no store we can createa a default in memory store and add messages to it. + this._messageStore = new InMemoryChatMessageStore(); + await this._messageStore!.AddMessagesAsync(newMessages, cancellationToken).ConfigureAwait(false); + break; + + case { MessageStore: not null }: + // If a store has been provided, we need to add the messages to the store. + await this._messageStore!.AddMessagesAsync(newMessages, cancellationToken).ConfigureAwait(false); + break; + + default: + throw new UnreachableException(); + } + } + + /// + /// Deserializes the state contained in the provided into the properties on this thread. + /// + /// A representing the state of the thread. + /// Optional settings for customizing the JSON deserialization process. + /// The to monitor for cancellation requests. The default is . + protected internal virtual async Task DeserializeAsync(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + { + var state = JsonSerializer.Deserialize( + serializedThread, + AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ThreadState))) as ThreadState; + + if (state?.ConversationId is string threadId) + { + this.ConversationId = threadId; + + // Since we have an ID, we should not have a chat message store and we can return here. + return; + } + + // If we don't have any IChatMessageStore state return here. + if (state?.StoreState is null || state?.StoreState?.ValueKind is JsonValueKind.Undefined or JsonValueKind.Null) + { + return; + } + + if (this._messageStore is null) + { + // If we don't have a chat message store yet, create an in-memory one. + this._messageStore = new InMemoryChatMessageStore(); + } + + await this._messageStore.DeserializeStateAsync(state!.StoreState.Value, jsonSerializerOptions, cancellationToken).ConfigureAwait(false); + } + + /// + /// Serializes the current object's state to a using the specified serialization options. + /// + /// The JSON serialization options to use. + /// The to monitor for cancellation requests. The default is . + /// A representation of the object's state. + public virtual async Task SerializeAsync(JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + { + var storeState = this._messageStore is null ? + (JsonElement?)null : + await this._messageStore.SerializeStateAsync(jsonSerializerOptions, cancellationToken).ConfigureAwait(false); + + var state = new ThreadState + { + ConversationId = this.ConversationId, + StoreState = storeState + }; + + return JsonSerializer.SerializeToElement(state, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ThreadState))); + } + + internal class ThreadState + { + public string? ConversationId { get; set; } + + public JsonElement? StoreState { get; set; } } } diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/IChatMessageStore.cs b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/IChatMessageStore.cs new file mode 100644 index 0000000000..f54c091264 --- /dev/null +++ b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/IChatMessageStore.cs @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Extensions.AI.Agents; + +/// +/// Defines methods for storing and retrieving chat messages associated with a specific thread. +/// +/// +/// Implementations of this interface are responsible for managing the storage of chat messages, +/// including handling large volumes of data by truncating or summarizing messages as necessary. +/// +public interface IChatMessageStore +{ + /// + /// Gets all the messages from the store that should be used for the next agent invocation. + /// + /// The to monitor for cancellation requests. The default is . + /// A collection of chat messages. + /// + /// + /// Messages are returned in ascending chronological order, with the oldest message first. + /// + /// + /// If the messages stored in the store become very large, it is up to the store to + /// truncate, summarize or otherwise limit the number of messages returned. + /// + /// + /// When using implementations of , a new one should be created for each thread + /// since they may contain state that is specific to a thread. + /// + /// + Task> GetMessagesAsync(CancellationToken cancellationToken); + + /// + /// Adds messages to the store. + /// + /// The messages to add. + /// The to monitor for cancellation requests. The default is . + /// An async task. + Task AddMessagesAsync(IReadOnlyCollection messages, CancellationToken cancellationToken); + + /// + /// Deserializes the state contained in the provided into the properties on this store. + /// + /// A representing the state of the store. + /// Optional settings for customizing the JSON deserialization process. + /// The to monitor for cancellation requests. The default is . + /// + /// This method, together with can be used to save and load messages from a persistent store + /// if this store only has messages in memory. + /// + ValueTask DeserializeStateAsync(JsonElement? serializedStoreState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default); + + /// + /// Serializes the current object's state to a using the specified serialization options. + /// + /// The JSON serialization options to use. + /// The to monitor for cancellation requests. The default is . + /// A representation of the object's state. + /// + /// This method, together with can be used to save and load messages from a persistent store + /// if this store only has messages in memory. + /// + ValueTask SerializeStateAsync(JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default); +} diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/IMessagesRetrievableThread.cs b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/IMessagesRetrievableThread.cs deleted file mode 100644 index 85c0f8aaa9..0000000000 --- a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/IMessagesRetrievableThread.cs +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Threading; - -namespace Microsoft.Extensions.AI.Agents; - -/// -/// An interface for agent threads that allow retrieval of messages in the thread for agent invocation. -/// -/// -/// -/// Some agents need to be invoked with all relevant chat history messages in order to produce a result, while some must be invoked -/// with the id of a server side thread that contains the chat history. -/// -/// -/// This interface can be implemented by all thread types that support the case where the agent is invoked with the chat history. -/// Implementations must consider the size of the messages provided, so that they do not exceed the maximum size of the context window -/// of the agent they are used with. Where appropriate, implementations should truncate or summarize messages so that the size of messages -/// are constrained. -/// -/// -public interface IMessagesRetrievableThread -{ - /// - /// Asynchronously retrieves all messages to be used for the agent invocation. - /// - /// - /// Messages are returned in ascending chronological order. - /// - /// The to monitor for cancellation requests. The default is . - /// The messages in the thread. - /// The thread has been deleted. - IAsyncEnumerable GetMessagesAsync(CancellationToken cancellationToken = default); -} diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/InMemoryChatMessageStore.cs b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/InMemoryChatMessageStore.cs new file mode 100644 index 0000000000..57dce94c71 --- /dev/null +++ b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/InMemoryChatMessageStore.cs @@ -0,0 +1,121 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections; +using System.Collections.Generic; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI.Agents; + +/// +/// Represents an in-memory store for chat messages associated with a specific thread. +/// +internal class InMemoryChatMessageStore : IList, IChatMessageStore +{ + private readonly List _messages = new(); + + /// + public int Count => this._messages.Count; + + /// + public bool IsReadOnly => ((IList)this._messages).IsReadOnly; + + /// + public ChatMessage this[int index] + { + get => this._messages[index]; + set => this._messages[index] = value; + } + + /// + public Task AddMessagesAsync(IReadOnlyCollection messages, CancellationToken cancellationToken) + { + _ = Throw.IfNull(messages); + this._messages.AddRange(messages); + return Task.CompletedTask; + } + + /// + public Task> GetMessagesAsync(CancellationToken cancellationToken) + { + return Task.FromResult>(this._messages); + } + + /// + public ValueTask DeserializeStateAsync(JsonElement? serializedStoreState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + { + if (serializedStoreState is null) + { + return new ValueTask(); + } + + var state = JsonSerializer.Deserialize( + serializedStoreState.Value, + AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(StoreState))) as StoreState; + + if (state?.Messages is { Count: > 0 } messages) + { + this._messages.AddRange(messages); + } + + return new ValueTask(); + } + + /// + public ValueTask SerializeStateAsync(JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + { + StoreState state = new() + { + Messages = this._messages, + }; + + return new ValueTask(JsonSerializer.SerializeToElement(state, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(StoreState)))); + } + + /// + public int IndexOf(ChatMessage item) + => this._messages.IndexOf(item); + + /// + public void Insert(int index, ChatMessage item) + => this._messages.Insert(index, item); + + /// + public void RemoveAt(int index) + => this._messages.RemoveAt(index); + + /// + public void Add(ChatMessage item) + => this._messages.Add(item); + + /// + public void Clear() + => this._messages.Clear(); + + /// + public bool Contains(ChatMessage item) + => this._messages.Contains(item); + + /// + public void CopyTo(ChatMessage[] array, int arrayIndex) + => this._messages.CopyTo(array, arrayIndex); + + /// + public bool Remove(ChatMessage item) + => this._messages.Remove(item); + + /// + public IEnumerator GetEnumerator() + => this._messages.GetEnumerator(); + + /// + IEnumerator IEnumerable.GetEnumerator() + => this.GetEnumerator(); + + internal class StoreState + { + public IList Messages { get; set; } = new List(); + } +} diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/Microsoft.Extensions.AI.Agents.Abstractions.csproj b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/Microsoft.Extensions.AI.Agents.Abstractions.csproj index 89492227c5..e0ad6ff297 100644 --- a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/Microsoft.Extensions.AI.Agents.Abstractions.csproj +++ b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/Microsoft.Extensions.AI.Agents.Abstractions.csproj @@ -10,6 +10,7 @@ true + true true diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents.CopilotStudio/CopilotStudioAgent.cs b/dotnet/src/Microsoft.Extensions.AI.Agents.CopilotStudio/CopilotStudioAgent.cs index ede152af99..9705662786 100644 --- a/dotnet/src/Microsoft.Extensions.AI.Agents.CopilotStudio/CopilotStudioAgent.cs +++ b/dotnet/src/Microsoft.Extensions.AI.Agents.CopilotStudio/CopilotStudioAgent.cs @@ -37,12 +37,6 @@ public class CopilotStudioAgent : AIAgent this._logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); } - /// - public override AgentThread GetNewThread() - { - return new CopilotStudioAgentThread(); - } - /// public override async Task RunAsync( IReadOnlyCollection messages, @@ -54,12 +48,12 @@ public class CopilotStudioAgent : AIAgent // Ensure that we have a valid thread to work with. // If the thread ID is null, we need to start a new conversation and set the thread ID accordingly. - CopilotStudioAgentThread copilotStudioAgentThread = base.ValidateOrCreateThreadType(thread, () => new CopilotStudioAgentThread()); - copilotStudioAgentThread.Id ??= await this.StartNewConversationAsync(cancellationToken).ConfigureAwait(false); + thread ??= this.GetNewThread(); + thread.ConversationId ??= await this.StartNewConversationAsync(cancellationToken).ConfigureAwait(false); // Invoke the Copilot Studio agent with the provided messages. string question = string.Join("\n", messages.Select(m => m.Text)); - var responseMessages = ActivityProcessor.ProcessActivityAsync(this.Client.AskQuestionAsync(question, copilotStudioAgentThread.Id, cancellationToken), streaming: false, this._logger); + var responseMessages = ActivityProcessor.ProcessActivityAsync(this.Client.AskQuestionAsync(question, thread.ConversationId, cancellationToken), streaming: false, this._logger); var responseMessagesList = new List(); await foreach (var message in responseMessages.ConfigureAwait(false)) { @@ -87,12 +81,12 @@ public class CopilotStudioAgent : AIAgent // Ensure that we have a valid thread to work with. // If the thread ID is null, we need to start a new conversation and set the thread ID accordingly. - CopilotStudioAgentThread copilotStudioAgentThread = base.ValidateOrCreateThreadType(thread, () => new CopilotStudioAgentThread()); - copilotStudioAgentThread.Id ??= await this.StartNewConversationAsync(cancellationToken).ConfigureAwait(false); + thread ??= this.GetNewThread(); + thread.ConversationId ??= await this.StartNewConversationAsync(cancellationToken).ConfigureAwait(false); // Invoke the Copilot Studio agent with the provided messages. string question = string.Join("\n", messages.Select(m => m.Text)); - var responseMessages = ActivityProcessor.ProcessActivityAsync(this.Client.AskQuestionAsync(question, copilotStudioAgentThread.Id, cancellationToken), streaming: true, this._logger); + var responseMessages = ActivityProcessor.ProcessActivityAsync(this.Client.AskQuestionAsync(question, thread.ConversationId, cancellationToken), streaming: true, this._logger); // Enumerate the response messages await foreach (ChatMessage message in responseMessages.ConfigureAwait(false)) diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents.CopilotStudio/CopilotStudioAgentThread.cs b/dotnet/src/Microsoft.Extensions.AI.Agents.CopilotStudio/CopilotStudioAgentThread.cs deleted file mode 100644 index cff6d2c399..0000000000 --- a/dotnet/src/Microsoft.Extensions.AI.Agents.CopilotStudio/CopilotStudioAgentThread.cs +++ /dev/null @@ -1,8 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -namespace Microsoft.Extensions.AI.Agents.CopilotStudio; - -/// -/// Represents a thread for interacting with a Copilot Studio agent. -/// -public class CopilotStudioAgentThread : AgentThread; diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents/AgentsJsonContext.cs b/dotnet/src/Microsoft.Extensions.AI.Agents/AgentsJsonContext.cs deleted file mode 100644 index d222a34966..0000000000 --- a/dotnet/src/Microsoft.Extensions.AI.Agents/AgentsJsonContext.cs +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Collections.Generic; -using System.Text.Json; -using System.Text.Json.Serialization; - -namespace Microsoft.Extensions.AI.Agents; - -/// -/// Source-generated JSON type information for use by all Agents implementations. -/// -[JsonSourceGenerationOptions( - JsonSerializerDefaults.Web, - UseStringEnumConverter = true, - DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, - WriteIndented = false)] -[JsonSerializable(typeof(ChatMessage))] -[JsonSerializable(typeof(List))] -[JsonSerializable(typeof(ChatClientAgentThread))] -internal sealed partial class AgentsJsonContext : JsonSerializerContext; diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgent.cs b/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgent.cs index a664e214e0..37ac945ec6 100644 --- a/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgent.cs @@ -100,7 +100,7 @@ public sealed class ChatClientAgent : AIAgent { Throw.IfNull(messages); - (ChatClientAgentThread chatClientThread, ChatOptions? chatOptions, List threadMessages) = + (AgentThread safeThread, ChatOptions? chatOptions, List threadMessages) = await this.PrepareThreadAndMessagesAsync(thread, messages, options, cancellationToken).ConfigureAwait(false); var agentName = this.GetLoggingAgentName(); @@ -113,10 +113,10 @@ public sealed class ChatClientAgent : AIAgent // We can derive the type of supported thread from whether we have a conversation id, // so let's update it and set the conversation id for the service thread case. - this.UpdateThreadWithTypeAndConversationId(chatClientThread, chatResponse.ConversationId); + this.UpdateThreadWithTypeAndConversationId(safeThread, chatResponse.ConversationId); // Only notify the thread of new messages if the chatResponse was successful to avoid inconsistent messages state in the thread. - await this.NotifyThreadOfNewMessagesAsync(chatClientThread, messages, cancellationToken).ConfigureAwait(false); + await this.NotifyThreadOfNewMessagesAsync(safeThread, messages, cancellationToken).ConfigureAwait(false); // Ensure that the author name is set for each message in the response. foreach (ChatMessage chatResponseMessage in chatResponse.Messages) @@ -127,7 +127,7 @@ public sealed class ChatClientAgent : AIAgent // Convert the chat response messages to a valid IReadOnlyCollection for notification signatures below. var chatResponseMessages = chatResponse.Messages as IReadOnlyCollection ?? [.. chatResponse.Messages]; - await this.NotifyThreadOfNewMessagesAsync(chatClientThread, chatResponseMessages, cancellationToken).ConfigureAwait(false); + await this.NotifyThreadOfNewMessagesAsync(safeThread, chatResponseMessages, cancellationToken).ConfigureAwait(false); return new(chatResponse) { AgentId = this.Id }; } @@ -141,7 +141,7 @@ public sealed class ChatClientAgent : AIAgent { var inputMessages = Throw.IfNull(messages); - (ChatClientAgentThread chatClientThread, ChatOptions? chatOptions, List threadMessages) = + (AgentThread safeThread, ChatOptions? chatOptions, List threadMessages) = await this.PrepareThreadAndMessagesAsync(thread, inputMessages, options, cancellationToken).ConfigureAwait(false); int messageCount = threadMessages.Count; @@ -177,16 +177,20 @@ public sealed class ChatClientAgent : AIAgent // We can derive the type of supported thread from whether we have a conversation id, // so let's update it and set the conversation id for the service thread case. - this.UpdateThreadWithTypeAndConversationId(chatClientThread, chatResponse.ConversationId); + this.UpdateThreadWithTypeAndConversationId(safeThread, chatResponse.ConversationId); // To avoid inconsistent state we only notify the thread of the input messages if no error occurs after the initial request. - await this.NotifyThreadOfNewMessagesAsync(chatClientThread, inputMessages, cancellationToken).ConfigureAwait(false); + await this.NotifyThreadOfNewMessagesAsync(safeThread, inputMessages, cancellationToken).ConfigureAwait(false); - await this.NotifyThreadOfNewMessagesAsync(chatClientThread, chatResponseMessages, cancellationToken).ConfigureAwait(false); + await this.NotifyThreadOfNewMessagesAsync(safeThread, chatResponseMessages, cancellationToken).ConfigureAwait(false); } /// - public override AgentThread GetNewThread() => new ChatClientAgentThread(); + public override AgentThread GetNewThread() + { + var thread = new AgentThread() { MessageStore = this._agentOptions?.ChatMessageStoreFactory?.Invoke() }; + return thread; + } #region Private @@ -312,7 +316,7 @@ public sealed class ChatClientAgent : AIAgent /// Optional parameters for agent invocation. /// The cancellation token. /// A tuple containing the thread, chat options, and thread messages. - private async Task<(ChatClientAgentThread, ChatOptions?, List)> PrepareThreadAndMessagesAsync( + private async Task<(AgentThread, ChatOptions?, List)> PrepareThreadAndMessagesAsync( AgentThread? thread, IReadOnlyCollection inputMessages, AgentRunOptions? runOptions, @@ -320,16 +324,13 @@ public sealed class ChatClientAgent : AIAgent { ChatOptions? chatOptions = this.CreateConfiguredChatOptions(runOptions); - var chatClientThread = this.ValidateOrCreateThreadType(thread, () => new()); + thread ??= this.GetNewThread(); // Add any existing messages from the thread to the messages to be sent to the chat client. List threadMessages = []; - if (chatClientThread is IMessagesRetrievableThread messagesRetrievableThread) + await foreach (ChatMessage message in thread.GetMessagesAsync(cancellationToken).ConfigureAwait(false)) { - await foreach (ChatMessage message in messagesRetrievableThread.GetMessagesAsync(cancellationToken).ConfigureAwait(false)) - { - threadMessages.Add(message); - } + threadMessages.Add(message); } // Update the messages with agent instructions. @@ -340,39 +341,43 @@ public sealed class ChatClientAgent : AIAgent // If a user provided two different thread ids, via the thread object and options, we should throw // since we don't know which one to use. - if (!string.IsNullOrWhiteSpace(chatClientThread.Id) && !string.IsNullOrWhiteSpace(chatOptions?.ConversationId) && chatClientThread.Id != chatOptions.ConversationId) + if (!string.IsNullOrWhiteSpace(thread.ConversationId) && !string.IsNullOrWhiteSpace(chatOptions?.ConversationId) && thread.ConversationId != chatOptions.ConversationId) { throw new InvalidOperationException( $"The {nameof(chatOptions.ConversationId)} provided via {nameof(Microsoft.Extensions.AI.ChatOptions)} is different to the id of the provided {nameof(AgentThread)}. Only one thread id can be used for a run."); } // Only clone and update ChatOptions if we have an id on the thread and we don't have the same one already in ChatOptions. - if (!string.IsNullOrWhiteSpace(chatClientThread.Id) && chatClientThread.Id != chatOptions?.ConversationId) + if (!string.IsNullOrWhiteSpace(thread.ConversationId) && thread.ConversationId != chatOptions?.ConversationId) { chatOptions ??= new(); - chatOptions.ConversationId = chatClientThread.Id; + chatOptions.ConversationId = thread.ConversationId; } - return (chatClientThread, chatOptions, threadMessages); + return (thread, chatOptions, threadMessages); } - private void UpdateThreadWithTypeAndConversationId(ChatClientAgentThread chatClientThread, string? responseConversationId) + private void UpdateThreadWithTypeAndConversationId(AgentThread thread, string? responseConversationId) { - // Set the thread's storage location, the first time that we use it. - chatClientThread.StorageLocation ??= string.IsNullOrWhiteSpace(responseConversationId) - ? ChatClientAgentThreadType.InMemoryMessages - : ChatClientAgentThreadType.ConversationId; - - // If we got a conversation id back from the chat client, it means that the service supports server side thread storage - // so we should capture the id and update the thread with the new id. - if (chatClientThread.StorageLocation == ChatClientAgentThreadType.ConversationId) + if (string.IsNullOrWhiteSpace(responseConversationId) && !string.IsNullOrWhiteSpace(thread.ConversationId)) { - if (string.IsNullOrWhiteSpace(responseConversationId)) - { - throw new InvalidOperationException("Service did not return a valid conversation id when using a service managed thread."); - } + // We were passed a thread that is service managed, but we got no conversation id back from the chat client, + // meaning the service doesn't support service managed threads, so the thread cannot be used with this service. + throw new InvalidOperationException("Service did not return a valid conversation id when using a service managed thread."); + } - chatClientThread.Id = responseConversationId; + if (!string.IsNullOrWhiteSpace(responseConversationId)) + { + // If we got a conversation id back from the chat client, it means that the service supports server side thread storage + // so we should update the thread with the new id. + thread.ConversationId = responseConversationId; + } + else if (thread.MessageStore is null) + { + // If the service doesn't use service side thread storage (i.e. we got no id back from invocation), and + // the thread has no MessageStore yet, and we have a custom messages store, we should update the thread + // with the custom MessageStore so that it has somewhere to store the chat history. + thread.MessageStore = this._agentOptions?.ChatMessageStoreFactory?.Invoke(); } } diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgentOptions.cs b/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgentOptions.cs index a8bb9d144a..6c5cee3ab7 100644 --- a/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgentOptions.cs +++ b/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgentOptions.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; namespace Microsoft.Extensions.AI.Agents; @@ -72,6 +73,12 @@ public class ChatClientAgentOptions /// public ChatOptions? ChatOptions { get; set; } + /// + /// Gets or sets a factory function to create an instance of + /// which will be used to store chat messages for this agent. + /// + public Func? ChatMessageStoreFactory { get; set; } = null; + /// /// Creates a new instance of with the same values as this instance. /// @@ -82,6 +89,7 @@ public class ChatClientAgentOptions Name = this.Name, Instructions = this.Instructions, Description = this.Description, - ChatOptions = this.ChatOptions?.Clone() + ChatOptions = this.ChatOptions?.Clone(), + ChatMessageStoreFactory = this.ChatMessageStoreFactory }; } diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgentThread.cs b/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgentThread.cs deleted file mode 100644 index 80f822b91b..0000000000 --- a/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgentThread.cs +++ /dev/null @@ -1,185 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.ComponentModel; -using System.Diagnostics; -using System.Runtime.CompilerServices; -using System.Text.Json; -using System.Text.Json.Serialization; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.Shared.Diagnostics; - -namespace Microsoft.Extensions.AI.Agents; - -/// -/// Chat client agent thread. -/// -[JsonConverter(typeof(Converter))] -public sealed class ChatClientAgentThread : AgentThread, IMessagesRetrievableThread -{ - private readonly List _chatMessages = []; - - /// - /// Initializes a new instance of the class. - /// - public ChatClientAgentThread() - { - } - - /// - /// Initializes a new instance of the class. - /// - /// The id of an existing server side thread to continue. - /// - /// This constructor creates a that supports in-service message storage. - /// - public ChatClientAgentThread(string id) - { - Throw.IfNullOrWhitespace(id); - - this.Id = id; - this.StorageLocation = ChatClientAgentThreadType.ConversationId; - } - - /// - /// Initializes a new instance of the class. - /// - /// A set of initial messages to seed the thread with. - /// - /// This constructor creates a that supports local in-memory message storage. - /// - public ChatClientAgentThread(IEnumerable messages) - { - Throw.IfNull(messages); - - this._chatMessages.AddRange(messages); - this.StorageLocation = ChatClientAgentThreadType.InMemoryMessages; - } - - /// - /// Gets the location of the thread contents. - /// - internal ChatClientAgentThreadType? StorageLocation { get; set; } - -#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously - /// - public async IAsyncEnumerable GetMessagesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) - { - foreach (var message in this._chatMessages) - { - yield return message; - } - } - -#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously - - /// - protected override Task OnNewMessagesAsync(IReadOnlyCollection newMessages, CancellationToken cancellationToken = default) - { - switch (this.StorageLocation) - { - case ChatClientAgentThreadType.InMemoryMessages: - this._chatMessages.AddRange(newMessages); - break; - case ChatClientAgentThreadType.ConversationId: - // If the thread messages are stored in the service - // there is nothing to do here, since invoking the - // service should already update the thread. - break; - default: - throw new UnreachableException(); - } - - return Task.CompletedTask; - } - - /// - /// Provides a for objects. - /// - [EditorBrowsable(EditorBrowsableState.Never)] - public sealed class Converter : JsonConverter - { - /// - public override ChatClientAgentThread? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) - { - if (reader.TokenType != JsonTokenType.StartObject) - { - throw new JsonException("Expected StartObject token"); - } - - using var doc = JsonDocument.ParseValue(ref reader); - var root = doc.RootElement; - - // Extract properties from JSON - string? id = null; - if (root.TryGetProperty("id", out var idProperty)) - { - id = idProperty.GetString(); - } - - List? messages = null; - if (root.TryGetProperty("messages", out var messagesProperty)) - { - if (messagesProperty.ValueKind == JsonValueKind.Array) - { - messages = []; - foreach (var messageElement in messagesProperty.EnumerateArray()) - { - var message = messageElement.Deserialize(options.GetTypeInfo(AgentsJsonContext.Default)); - if (message != null) - { - messages.Add(message); - } - } - } - } - - // Create the appropriate instance based on available data - // StorageLocation will be set automatically by the constructors - ChatClientAgentThread thread; - if (messages?.Count > 0) - { - thread = new ChatClientAgentThread(messages); - } - else if (!string.IsNullOrWhiteSpace(id)) - { - thread = new ChatClientAgentThread(id); - } - else - { - thread = new ChatClientAgentThread(); - } - - // Override Id if it was explicitly set in JSON (for cases where messages exist but ID is also provided) - if (id != null) - { - thread.Id = id; - } - - return thread; - } - - /// - public override void Write(Utf8JsonWriter writer, ChatClientAgentThread value, JsonSerializerOptions options) - { - writer.WriteStartObject(); - - // Write base properties - if (value.Id != null) - { - writer.WriteString("id", value.Id); - } - - // Write messages if in memory storage (StorageLocation is determined by presence of messages vs ID) - if (value.StorageLocation == ChatClientAgentThreadType.InMemoryMessages) - { - writer.WritePropertyName("messages"); - JsonSerializer.Serialize(writer, value._chatMessages, options.GetTypeInfo>(AgentsJsonContext.Default)); - } - - writer.WriteEndObject(); - } - } -} diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgentThreadType.cs b/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgentThreadType.cs deleted file mode 100644 index ba4d1c7f59..0000000000 --- a/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgentThreadType.cs +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -namespace Microsoft.Extensions.AI.Agents; - -/// -/// Defines the different supported storage locations for . -/// -internal enum ChatClientAgentThreadType -{ - /// - /// Messages are stored in memory inside the thread object. - /// - InMemoryMessages, - - /// - /// Messages are stored in the service and the thread object just has an id reference the service storage. - /// - ConversationId -} diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents/OpenTelemetryAgent.cs b/dotnet/src/Microsoft.Extensions.AI.Agents/OpenTelemetryAgent.cs index 3a8cf78efe..b8fbef864e 100644 --- a/dotnet/src/Microsoft.Extensions.AI.Agents/OpenTelemetryAgent.cs +++ b/dotnet/src/Microsoft.Extensions.AI.Agents/OpenTelemetryAgent.cs @@ -210,9 +210,9 @@ public sealed class OpenTelemetryAgent : AIAgent, IDisposable } // Add conversation ID if thread is available (following gen_ai.conversation.id convention) - if (!string.IsNullOrWhiteSpace(thread?.Id)) + if (!string.IsNullOrWhiteSpace(thread?.ConversationId)) { - _ = activity.AddTag(AgentOpenTelemetryConsts.GenAI.ConversationId, thread.Id); + _ = activity.AddTag(AgentOpenTelemetryConsts.GenAI.ConversationId, thread.ConversationId); } // Add instructions if available (for ChatClientAgent) diff --git a/dotnet/tests/AzureAIAgentsPersistent.IntegrationTests/AzureAIAgentsPersistentFixture.cs b/dotnet/tests/AzureAIAgentsPersistent.IntegrationTests/AzureAIAgentsPersistentFixture.cs index 08d952fb0e..00519e40d4 100644 --- a/dotnet/tests/AzureAIAgentsPersistent.IntegrationTests/AzureAIAgentsPersistentFixture.cs +++ b/dotnet/tests/AzureAIAgentsPersistent.IntegrationTests/AzureAIAgentsPersistentFixture.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using System; using System.Collections.Generic; using System.Threading.Tasks; using AgentConformance.IntegrationTests; @@ -29,14 +28,9 @@ public class AzureAIAgentsPersistentFixture : IChatClientAgentFixture public async Task> GetChatHistoryAsync(AgentThread thread) { - if (thread is not ChatClientAgentThread chatClientThread) - { - throw new InvalidOperationException($"The thread must be of type {nameof(ChatClientAgentThread)} to retrieve chat history."); - } - List messages = []; - AsyncPageable threadMessages = this._persistentAgentsClient.Messages.GetMessagesAsync(threadId: thread.Id, order: ListSortOrder.Ascending); + AsyncPageable threadMessages = this._persistentAgentsClient.Messages.GetMessagesAsync(threadId: thread.ConversationId, order: ListSortOrder.Ascending); await foreach (var threadMessage in threadMessages) { @@ -87,9 +81,9 @@ public class AzureAIAgentsPersistentFixture : IChatClientAgentFixture public Task DeleteThreadAsync(AgentThread thread) { - if (thread?.Id is not null) + if (thread?.ConversationId is not null) { - return this._persistentAgentsClient.Threads.DeleteThreadAsync(thread.Id); + return this._persistentAgentsClient.Threads.DeleteThreadAsync(thread.ConversationId); } return Task.CompletedTask; diff --git a/dotnet/tests/Microsoft.Agents.Orchestration.UnitTests/MockAgent.cs b/dotnet/tests/Microsoft.Agents.Orchestration.UnitTests/MockAgent.cs index 78cd4e8680..bf0a428aa3 100644 --- a/dotnet/tests/Microsoft.Agents.Orchestration.UnitTests/MockAgent.cs +++ b/dotnet/tests/Microsoft.Agents.Orchestration.UnitTests/MockAgent.cs @@ -33,7 +33,7 @@ internal sealed class MockAgent(int index) : AIAgent public override AgentThread GetNewThread() { - return new AgentThread() { Id = Guid.NewGuid().ToString() }; + return new AgentThread() { ConversationId = Guid.NewGuid().ToString() }; } public override Task RunAsync(IReadOnlyCollection messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) diff --git a/dotnet/tests/Microsoft.Agents.Orchestration.UnitTests/OrchestrationResultTests.cs b/dotnet/tests/Microsoft.Agents.Orchestration.UnitTests/OrchestrationResultTests.cs index 42b1fc55cb..bbc79b661c 100644 --- a/dotnet/tests/Microsoft.Agents.Orchestration.UnitTests/OrchestrationResultTests.cs +++ b/dotnet/tests/Microsoft.Agents.Orchestration.UnitTests/OrchestrationResultTests.cs @@ -89,8 +89,6 @@ public class OrchestrationResultTests private sealed class MockAgent : AIAgent { - public override AgentThread GetNewThread() => - throw new NotSupportedException(); public override Task RunAsync(IReadOnlyCollection messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) => throw new NotSupportedException(); public override IAsyncEnumerable RunStreamingAsync(IReadOnlyCollection messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) => diff --git a/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AgentTests.cs b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AgentTests.cs index e9b4c2bd31..2da5604699 100644 --- a/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AgentTests.cs +++ b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AgentTests.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using System; using System.Collections.Generic; using System.Linq; using System.Threading; @@ -220,23 +219,6 @@ public class AgentTests Assert.Equal(id, agent.Id); } - [Fact] - public void ValidateOrCreateThreadTypeVerifiesAndCreatesThread() - { - // Custom thread type for type checking - var threadMock = new Mock() { CallBase = true }; - - var agent = new MockAgent(); - - // Should create - var result = agent.ValidateOrCreateThreadType(null, () => threadMock.Object); - Assert.Same(threadMock.Object, result); - - // Should throw if wrong type - var wrongThread = new Mock().Object; - Assert.Throws(() => agent.ValidateOrCreateThreadType(wrongThread, () => threadMock.Object)); - } - [Fact] public async Task NotifyThreadOfNewMessagesNotifiesThreadAsync() { @@ -245,6 +227,8 @@ public class AgentTests var messages = new[] { new ChatMessage(ChatRole.User, "msg1"), new ChatMessage(ChatRole.User, "msg2") }; var threadMock = new Mock() { CallBase = true }; + threadMock.SetupAllProperties(); + threadMock.Object.ConversationId = "test-thread-id"; var agent = new MockAgent(); await agent.NotifyThreadOfNewMessagesAsync(threadMock.Object, messages, cancellationToken); @@ -257,31 +241,13 @@ public class AgentTests /// public abstract class TestAgentThread : AgentThread; - /// - /// Mock class to test the method. - /// private sealed class MockAgent : AIAgent { - public new TThreadType ValidateOrCreateThreadType( - AgentThread? thread, - Func constructThread) - where TThreadType : AgentThread - { - return base.ValidateOrCreateThreadType( - thread, - constructThread); - } - public new Task NotifyThreadOfNewMessagesAsync(AgentThread thread, IReadOnlyCollection messages, CancellationToken cancellationToken) { return base.NotifyThreadOfNewMessagesAsync(thread, messages, cancellationToken); } - public override AgentThread GetNewThread() - { - throw new NotImplementedException(); - } - public override Task RunAsync(IReadOnlyCollection messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { throw new System.NotImplementedException(); diff --git a/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AgentThreadTests.cs b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AgentThreadTests.cs new file mode 100644 index 0000000000..3a598c36a6 --- /dev/null +++ b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AgentThreadTests.cs @@ -0,0 +1,329 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Moq; + +namespace Microsoft.Extensions.AI.Agents.Abstractions.UnitTests; + +public class AgentThreadTests +{ + #region Constructor and Property Tests + + [Fact] + public void ConstructorSetsDefaults() + { + // Arrange & Act + var thread = new AgentThread(); + + // Assert + Assert.Null(thread.ConversationId); + Assert.Null(thread.MessageStore); + } + + [Fact] + public void SetConversationIdRoundtrips() + { + // Arrange + var thread = new AgentThread(); + var conversationid = "test-thread-id"; + + // Act + thread.ConversationId = conversationid; + + // Assert + Assert.Equal(conversationid, thread.ConversationId); + Assert.Null(thread.MessageStore); + } + + [Fact] + public void SetChatMessageStoreRoundtrips() + { + // Arrange + var thread = new AgentThread(); + var messageStore = new InMemoryChatMessageStore(); + + // Act + thread.MessageStore = messageStore; + + // Assert + Assert.Same(messageStore, thread.MessageStore); + Assert.Null(thread.ConversationId); + } + + [Fact] + public void SetConversationIdThrowsWhenMessageStoreIsSet() + { + // Arrange + var thread = new AgentThread(); + thread.MessageStore = new InMemoryChatMessageStore(); + + // Act & Assert + var exception = Assert.Throws(() => thread.ConversationId = "new-thread-id"); + Assert.Equal("Only the ConversationId or MessageStore may be set, but not both and switching from one to another is not supported.", exception.Message); + Assert.NotNull(thread.MessageStore); + } + + [Fact] + public void SetChatMessageStoreThrowsWhenConversationIdIsSet() + { + // Arrange + var thread = new AgentThread(); + thread.ConversationId = "existing-thread-id"; + var store = new InMemoryChatMessageStore(); + + // Act & Assert + var exception = Assert.Throws(() => thread.MessageStore = store); + Assert.Equal("Only the ConversationId or MessageStore may be set, but not both and switching from one to another is not supported.", exception.Message); + Assert.NotNull(thread.ConversationId); + } + + #endregion Constructor and Property Tests + + #region GetMessagesAsync Tests + + [Fact] + public async Task GetMessagesAsyncReturnsEmptyListWhenNoStoreAsync() + { + // Arrange + var thread = new AgentThread(); + + // Act + var messages = await thread.GetMessagesAsync(CancellationToken.None).ToListAsync(); + + // Assert + Assert.Empty(messages); + } + + [Fact] + public async Task GetMessagesAsyncReturnsEmptyListWhenAgentServiceIdAsync() + { + // Arrange + var thread = new AgentThread() { ConversationId = "thread-123" }; + + // Act + var messages = await thread.GetMessagesAsync(CancellationToken.None).ToListAsync(); + + // Assert + Assert.Empty(messages); + } + + [Fact] + public async Task GetMessagesAsyncReturnsMessagesFromStoreAsync() + { + // Arrange + var store = new InMemoryChatMessageStore + { + new ChatMessage(ChatRole.User, "Hello"), + new ChatMessage(ChatRole.Assistant, "Hi there!") + }; + var thread = new AgentThread() { MessageStore = store }; + + // Act + var messages = await thread.GetMessagesAsync(CancellationToken.None).ToListAsync(); + + // Assert + Assert.Equal(2, messages.Count); + Assert.Equal("Hello", messages[0].Text); + Assert.Equal("Hi there!", messages[1].Text); + } + + #endregion GetMessagesAsync Tests + + #region OnNewMessagesAsync Tests + + [Fact] + public async Task OnNewMessagesAsyncDoesNothingWhenAgentServiceIdAsync() + { + // Arrange + var thread = new AgentThread() { ConversationId = "thread-123" }; + var messages = new List + { + new(ChatRole.User, "Hello"), + new(ChatRole.Assistant, "Hi there!") + }; + + // Act + await thread.OnNewMessagesAsync(messages, CancellationToken.None); + } + + [Fact] + public async Task OnNewMessagesAsyncAddsMessagesToStoreAsync() + { + // Arrange + var store = new InMemoryChatMessageStore(); + var thread = new AgentThread() { MessageStore = store }; + var messages = new List + { + new(ChatRole.User, "Hello"), + new(ChatRole.Assistant, "Hi there!") + }; + + // Act + await thread.OnNewMessagesAsync(messages, CancellationToken.None); + + // Assert + Assert.Equal(2, store.Count); + Assert.Equal("Hello", store[0].Text); + Assert.Equal("Hi there!", store[1].Text); + } + + #endregion OnNewMessagesAsync Tests + + #region Deserialize Tests + + [Fact] + public async Task VerifyDeserializeWithMessagesAsync() + { + // Arrange + var chatMessageStore = new InMemoryChatMessageStore(); + var json = JsonSerializer.Deserialize(""" + { + "storeState": { "messages": [{"authorName": "testAuthor"}] } + } + """); + var thread = new AgentThread() { MessageStore = chatMessageStore }; + + // Act. + await thread.DeserializeAsync(json); + + // Assert + Assert.Null(thread.ConversationId); + + Assert.Single(chatMessageStore); + Assert.Equal("testAuthor", chatMessageStore[0].AuthorName); + } + + [Fact] + public async Task VerifyDeserializeWithIdAsync() + { + // Arrange + var json = JsonSerializer.Deserialize(""" + { + "conversationId": "TestConvId" + } + """); + var thread = new AgentThread(); + + // Act + await thread.DeserializeAsync(json); + + // Assert + Assert.Equal("TestConvId", thread.ConversationId); + Assert.Null(thread.MessageStore); + } + + [Fact] + public async Task DeserializeWithInvalidJsonThrowsAsync() + { + // Arrange + var invalidJson = JsonSerializer.Deserialize("[42]"); + var thread = new AgentThread(); + + // Act & Assert + await Assert.ThrowsAsync(() => thread.DeserializeAsync(invalidJson)); + } + + #endregion Deserialize Tests + + #region Serialize Tests + + /// + /// Verify thread serialization to JSON when the thread has an id. + /// + [Fact] + public async Task VerifyThreadSerializationWithIdAsync() + { + // Arrange + var thread = new AgentThread() { ConversationId = "TestConvId" }; + + // Act + var json = await thread.SerializeAsync(); + + // Assert + Assert.Equal(JsonValueKind.Object, json.ValueKind); + + Assert.True(json.TryGetProperty("conversationId", out var idProperty)); + Assert.Equal("TestConvId", idProperty.GetString()); + + Assert.False(json.TryGetProperty("storeState", out var storeStateProperty)); + } + + /// + /// Verify thread serialization to JSON when the thread has messages. + /// + [Fact] + public async Task VerifyThreadSerializationWithMessagesAsync() + { + // Arrange + var store = new InMemoryChatMessageStore(); + store.Add(new ChatMessage(ChatRole.User, "TestContent") { AuthorName = "TestAuthor" }); + var thread = new AgentThread() { MessageStore = store }; + + // Act + var json = await thread.SerializeAsync(); + + // Assert + Assert.Equal(JsonValueKind.Object, json.ValueKind); + + Assert.False(json.TryGetProperty("conversationId", out var idProperty)); + + Assert.True(json.TryGetProperty("storeState", out var storeStateProperty)); + Assert.Equal(JsonValueKind.Object, storeStateProperty.ValueKind); + + Assert.True(storeStateProperty.TryGetProperty("messages", out var messagesProperty)); + Assert.Equal(JsonValueKind.Array, messagesProperty.ValueKind); + Assert.Single(messagesProperty.EnumerateArray()); + + var message = messagesProperty.EnumerateArray().First(); + Assert.Equal("TestAuthor", message.GetProperty("authorName").GetString()); + Assert.True(message.TryGetProperty("contents", out var contentsProperty)); + Assert.Equal(JsonValueKind.Array, contentsProperty.ValueKind); + Assert.Single(contentsProperty.EnumerateArray()); + + var textContent = contentsProperty.EnumerateArray().First(); + Assert.Equal("TestContent", textContent.GetProperty("text").GetString()); + } + + /// + /// Verify thread serialization to JSON with custom options. + /// + [Fact] + public async Task VerifyThreadSerializationWithCustomOptionsAsync() + { + // Arrange + var thread = new AgentThread(); + JsonSerializerOptions options = new() { PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower }; + options.TypeInfoResolverChain.Add(AgentAbstractionsJsonUtilities.DefaultOptions.TypeInfoResolver!); + + var storeStateElement = JsonSerializer.SerializeToElement(new { Key = "TestValue" }); + + var messageStoreMock = new Mock(); + messageStoreMock + .Setup(m => m.SerializeStateAsync(options, It.IsAny())) + .ReturnsAsync(storeStateElement); + thread.MessageStore = messageStoreMock.Object; + + // Act + var json = await thread.SerializeAsync(options); + + // Assert + Assert.Equal(JsonValueKind.Object, json.ValueKind); + + Assert.False(json.TryGetProperty("conversationId", out var idProperty)); + + Assert.True(json.TryGetProperty("storeState", out var storeStateProperty)); + Assert.Equal(JsonValueKind.Object, storeStateProperty.ValueKind); + + Assert.True(storeStateProperty.TryGetProperty("Key", out var keyProperty)); + Assert.Equal("TestValue", keyProperty.GetString()); + + messageStoreMock.Verify(m => m.SerializeStateAsync(options, It.IsAny()), Times.Once); + } + + #endregion Serialize Tests +} diff --git a/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/InMemoryChatMessageStoreTests.cs b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/InMemoryChatMessageStoreTests.cs new file mode 100644 index 0000000000..db7c975c65 --- /dev/null +++ b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/InMemoryChatMessageStoreTests.cs @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Extensions.AI.Agents.Abstractions.UnitTests; + +/// +/// Contains tests for the class. +/// +public class InMemoryChatMessageStoreTests +{ + [Fact] + public async Task AddMessagesAsyncAddsMessagesAndReturnsNullThreadIdAsync() + { + var store = new InMemoryChatMessageStore(); + var messages = new List + { + new(ChatRole.User, "Hello"), + new(ChatRole.Assistant, "Hi there!") + }; + + await store.AddMessagesAsync(messages, CancellationToken.None); + + Assert.Equal(2, store.Count); + Assert.Equal("Hello", store[0].Text); + Assert.Equal("Hi there!", store[1].Text); + } + + [Fact] + public async Task AddMessagesAsyncWithEmptyDoesNotFailAsync() + { + var store = new InMemoryChatMessageStore(); + + await store.AddMessagesAsync([], CancellationToken.None); + + Assert.Empty(store); + } + + [Fact] + public async Task GetMessagesAsyncReturnsAllMessagesAsync() + { + var store = new InMemoryChatMessageStore + { + new ChatMessage(ChatRole.User, "Test1"), + new ChatMessage(ChatRole.Assistant, "Test2") + }; + + var result = (await store.GetMessagesAsync(CancellationToken.None)).ToList(); + + Assert.Equal(2, result.Count); + Assert.Contains(result, m => m.Text == "Test1"); + Assert.Contains(result, m => m.Text == "Test2"); + } + + [Fact] + public async Task DeserializeWithEmptyElementAsync() + { + var newStore = new InMemoryChatMessageStore(); + + var emptyObject = JsonSerializer.Deserialize("{}"); + + await newStore.DeserializeStateAsync(emptyObject); + + Assert.Empty(newStore); + } + + [Fact] + public async Task SerializeAndDeserializeRoundtripsAsync() + { + var store = new InMemoryChatMessageStore + { + new ChatMessage(ChatRole.User, "A"), + new ChatMessage(ChatRole.Assistant, "B") + }; + + var jsonElement = await store.SerializeStateAsync(); + var newStore = new InMemoryChatMessageStore(); + + await newStore.DeserializeStateAsync(jsonElement); + + Assert.Equal(2, newStore.Count); + Assert.Equal("A", newStore[0].Text); + Assert.Equal("B", newStore[1].Text); + } + + [Fact] + public async Task AddMessagesAsyncWithEmptyMessagesDoesNotChangeStoreAsync() + { + var store = new InMemoryChatMessageStore(); + var messages = new List(); + + await store.AddMessagesAsync(messages, CancellationToken.None); + + Assert.Empty(store); + } +} diff --git a/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/ChatCompletion/ChatClientAgentTests.cs b/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/ChatCompletion/ChatClientAgentTests.cs index aa3bfd892d..bae62f4d15 100644 --- a/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/ChatCompletion/ChatClientAgentTests.cs +++ b/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/ChatCompletion/ChatClientAgentTests.cs @@ -321,7 +321,7 @@ public class ChatClientAgentTests ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "test instructions" }); - ChatClientAgentThread thread = new("ConvId"); + AgentThread thread = new() { ConversationId = "ConvId" }; // Act & Assert await agent.RunAsync([new(ChatRole.User, "test")], thread, chatOptions: chatOptions); @@ -340,7 +340,7 @@ public class ChatClientAgentTests ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "test instructions" }); - ChatClientAgentThread thread = new("ThreadId"); + AgentThread thread = new() { ConversationId = "ThreadId" }; // Act & Assert await Assert.ThrowsAsync(() => agent.RunAsync([new(ChatRole.User, "test")], thread, chatOptions: chatOptions)); @@ -363,7 +363,7 @@ public class ChatClientAgentTests ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "test instructions" }); - ChatClientAgentThread thread = new("ConvId"); + AgentThread thread = new() { ConversationId = "ConvId" }; // Act await agent.RunAsync([new(ChatRole.User, "test")], thread, chatOptions: chatOptions); @@ -388,7 +388,7 @@ public class ChatClientAgentTests ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "test instructions" }); - ChatClientAgentThread thread = new("ConvId"); + AgentThread thread = new() { ConversationId = "ConvId" }; // Act & Assert await Assert.ThrowsAsync(() => agent.RunAsync([new(ChatRole.User, "test")], thread)); diff --git a/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/ChatCompletion/ChatClientAgentThreadTests.cs b/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/ChatCompletion/ChatClientAgentThreadTests.cs deleted file mode 100644 index 9fd357eea0..0000000000 --- a/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/ChatCompletion/ChatClientAgentThreadTests.cs +++ /dev/null @@ -1,978 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text.Json; -using System.Threading; -using System.Threading.Tasks; -using Moq; - -#pragma warning disable CS0162 // Unreachable code detected - -namespace Microsoft.Extensions.AI.Agents.UnitTests.ChatCompletion; - -public class ChatClientAgentThreadTests -{ - /// - /// Verify that implements . - /// - [Fact] - public void VerifyChatClientAgentThreadImplementsIMessagesRetrievableThread() - { - // Arrange & Act - var thread = new ChatClientAgentThread(); - - // Assert - Assert.IsType(thread, exactMatch: false); - Assert.IsType(thread, exactMatch: false); - } - - /// - /// Verify that can retrieve messages through . - /// This test verifies the interface works correctly when no messages have been added. - /// - [Fact] - public async Task VerifyIMessagesRetrievableThreadGetMessagesAsyncWhenEmptyAsync() - { - // Arrange - var thread = new ChatClientAgentThread(); - - // Act - Retrieve messages when thread is empty - var retrievedMessages = new List(); - await foreach (var message in thread.GetMessagesAsync()) - { - retrievedMessages.Add(message); - } - - // Assert - Assert.Empty(retrievedMessages); - } - - /// - /// Verify that can retrieve messages through . - /// This test verifies the interface works correctly when messages have been added via ChatClientAgent. - /// - [Fact] - public async Task VerifyIMessagesRetrievableThreadGetMessagesAsyncWhenNotEmptyAsync() - { - // Arrange - var userMessage = new ChatMessage(ChatRole.User, "Hello, how are you?"); - var assistantMessage = new ChatMessage(ChatRole.Assistant, "I'm doing well, thank you!"); - - // Mock IChatClient to return the assistant message - var mockChatClient = new Mock(); - mockChatClient.Setup( - c => c.GetResponseAsync( - It.IsAny>(), - It.IsAny(), - It.IsAny())) - .ReturnsAsync(new ChatResponse([assistantMessage])); - - // Create ChatClientAgent with the mocked client - var agent = new ChatClientAgent(mockChatClient.Object, options: new() - { - Instructions = "You are a helpful assistant" - }); - - // Get a new thread from the agent - var thread = agent.GetNewThread(); - - // Run the agent again with the thread to populate it with messages - var responseWithThread = await agent.RunAsync([userMessage], thread); - var messagesRetrievableThread = (IMessagesRetrievableThread)thread; - - // Retrieve messages through the interface - var retrievedMessages = new List(); - await foreach (var message in messagesRetrievableThread.GetMessagesAsync()) - { - retrievedMessages.Add(message); - } - - // Assert - Assert.NotEmpty(retrievedMessages); - - // Verify that the messages include the assistant response - Assert.Collection(retrievedMessages, - m => Assert.Equal(ChatRole.User, m.Role), - m => Assert.Equal(ChatRole.Assistant, m.Role)); - - // Verify the content matches what we expect - Assert.Contains(retrievedMessages, m => m.Text == "Hello, how are you?" && m.Role == ChatRole.User); - Assert.Contains(retrievedMessages, m => m.Text == "I'm doing well, thank you!" && m.Role == ChatRole.Assistant); - } - - /// - /// Verify that works with cancellation token. - /// - [Fact] - public async Task VerifyGetMessagesAsyncWithCancellationTokenAsync() - { - // Arrange - var thread = new ChatClientAgentThread(); - using var cts = new CancellationTokenSource(); - - // Act - Test that GetMessagesAsync accepts cancellation token without throwing - var retrievedMessages = new List(); - await foreach (var msg in thread.GetMessagesAsync(cts.Token)) - { - retrievedMessages.Add(msg); - } - - // Assert - Should return empty list when no messages - Assert.Empty(retrievedMessages); - } - - /// - /// Verify that initializes with expected default values. - /// - [Fact] - public void VerifyThreadInitialState() - { - // Arrange & Act - var thread = new ChatClientAgentThread(); - - // Assert - Assert.Null(thread.Id); // Id should be null until created on first use. - Assert.Null(thread.StorageLocation); // StorageLocation should be null until first use - } - - /// - /// Verify that initializes with expected default values. - /// - [Fact] - public async Task VerifyThreadWithMessagesInitialStateAsync() - { - // Arrange - var message = new ChatMessage(ChatRole.User, "Hello"); - - // Act - var thread = new ChatClientAgentThread([message]); - - // Assert - Assert.Null(thread.Id); // Id should be null when we add messages, since it's a local thread. - Assert.Equal(ChatClientAgentThreadType.InMemoryMessages, thread.StorageLocation); // StorageLocation should be set to local since we are adding messages already. - - var messages = await thread.GetMessagesAsync().ToListAsync(); - Assert.Contains(message, messages); - } - - /// - /// Verify that initializes with expected default values. - /// - [Fact] - public async Task VerifyThreadWithIdInitialStateAsync() - { - // Act - var thread = new ChatClientAgentThread("TestConvId"); - - // Assert - Assert.Equal("TestConvId", thread.Id); - Assert.Equal(ChatClientAgentThreadType.ConversationId, thread.StorageLocation); - - var messages = await thread.GetMessagesAsync().ToListAsync(); - Assert.Empty(messages); - } - - #region Core Override Method Tests - - /// - /// Verify that thread creation generates a valid thread ID through integration with ChatClientAgent. - /// - [Fact] - public void ThreadCreationGeneratesValidThreadId() - { - // Arrange - var mockChatClient = new Mock(); - mockChatClient.Setup( - c => c.GetResponseAsync( - It.IsAny>(), - It.IsAny(), - It.IsAny())) - .ReturnsAsync(new ChatResponse([new ChatMessage(ChatRole.Assistant, "response")])); - - var agent = new ChatClientAgent(mockChatClient.Object, options: new()); - - // Act - var thread = agent.GetNewThread(); - - // Assert - Assert.NotNull(thread); - var chatClientAgentThread = Assert.IsType(thread); - Assert.Null(thread.Id); // Id should be null until created on first use. - Assert.Null(chatClientAgentThread.StorageLocation); // StorageLocation should be null until first use - } - - /// - /// Verify that thread creation generates unique instances. - /// - [Fact] - public void ThreadCreationGeneratesUniqueInstances() - { - // Arrange - var mockChatClient = new Mock(); - var agent = new ChatClientAgent(mockChatClient.Object, options: new()); - - // Act - var thread1 = agent.GetNewThread(); - var thread2 = agent.GetNewThread(); - - // Assert - Assert.NotSame(thread1, thread2); - Assert.IsType(thread1); - Assert.IsType(thread2); - } - - /// - /// Verify that messages are properly stored and retrieved through the thread lifecycle. - /// - [Theory] - [InlineData(null, true)] - [InlineData("TestConvid", false)] - public async Task ThreadLifecycleStoresAndRetrievesMessagesAsync(string? responseConversationId, bool messagesStored) - { - // Arrange - var userMessage = new ChatMessage(ChatRole.User, "Hello"); - var assistantMessage = new ChatMessage(ChatRole.Assistant, "Hi there!"); - - var mockChatClient = new Mock(); - mockChatClient.Setup( - c => c.GetResponseAsync( - It.IsAny>(), - It.IsAny(), - It.IsAny())) - .ReturnsAsync(new ChatResponse([assistantMessage]) { ConversationId = responseConversationId }); - - var agent = new ChatClientAgent(mockChatClient.Object, options: new() { Instructions = "Test instructions" }); - - // Act - var thread = agent.GetNewThread(); - - // Run the agent to populate the thread with messages - await agent.RunAsync([userMessage], thread); - - // Retrieve messages from the thread - var retrievedMessages = new List(); - await foreach (var message in ((IMessagesRetrievableThread)thread).GetMessagesAsync()) - { - retrievedMessages.Add(message); - } - - // Assert - Assert.Equal(messagesStored ? 2 : 0, retrievedMessages.Count); - if (messagesStored) - { - Assert.Contains(retrievedMessages, m => m.Text == "Hello" && m.Role == ChatRole.User); - Assert.Contains(retrievedMessages, m => m.Text == "Hi there!" && m.Role == ChatRole.Assistant); - } - - var chatClientAgentThread = Assert.IsType(thread); - Assert.Equal(responseConversationId, thread.Id); // Id should match the returned conversation id. - Assert.Equal( - messagesStored - ? ChatClientAgentThreadType.InMemoryMessages - : ChatClientAgentThreadType.ConversationId, - chatClientAgentThread.StorageLocation); // StorageLocation should be based on whether we got back a conversation id - } - - /// - /// Verify that multiple messages can be added and retrieved in order. - /// - [Fact] - public async Task ThreadMessageHandlingHandlesMultipleMessagesInOrderAsync() - { - // Arrange - var messages = new[] - { - new ChatMessage(ChatRole.User, "First message"), - new ChatMessage(ChatRole.Assistant, "First response"), - new ChatMessage(ChatRole.User, "Second message"), - new ChatMessage(ChatRole.Assistant, "Second response") - }; - - var mockChatClient = new Mock(); - mockChatClient.SetupSequence( - c => c.GetResponseAsync( - It.IsAny>(), - It.IsAny(), - It.IsAny())) - .ReturnsAsync(new ChatResponse([messages[1]])) - .ReturnsAsync(new ChatResponse([messages[3]])); - - var agent = new ChatClientAgent(mockChatClient.Object, options: new()); - var thread = agent.GetNewThread(); - - // Act - Add messages through multiple agent runs - await agent.RunAsync([messages[0]], thread); - await agent.RunAsync([messages[2]], thread); - - // Assert - Verify all messages are stored in order - var retrievedMessages = new List(); - await foreach (var message in ((IMessagesRetrievableThread)thread).GetMessagesAsync()) - { - retrievedMessages.Add(message); - } - - Assert.Equal(4, retrievedMessages.Count); - Assert.Equal("First message", retrievedMessages[0].Text); - Assert.Equal("First response", retrievedMessages[1].Text); - Assert.Equal("Second message", retrievedMessages[2].Text); - Assert.Equal("Second response", retrievedMessages[3].Text); - } - - #endregion - - #region RunStreamingAsync Thread Notification Tests - - /// - /// Verify that thread is notified of both input and response messages when invoking the streaming API with RunStreamingAsync. - /// - [Fact] - public async Task VerifyThreadNotificationDuringStreamingAsync() - { - // Arrange - var userMessage = new ChatMessage(ChatRole.User, "Hello, streaming!"); - var assistantMessage = new ChatMessage(ChatRole.Assistant, "Hi there, streaming response!"); - - // Create streaming response updates - ChatResponseUpdate[] returnUpdates = - [ - new ChatResponseUpdate(role: ChatRole.Assistant, content: "Hi there, "), - new ChatResponseUpdate(role: null, content: "streaming response!"), - ]; - - var mockChatClient = new Mock(); - mockChatClient.Setup( - c => c.GetStreamingResponseAsync( - It.IsAny>(), - It.IsAny(), - It.IsAny())) - .Returns(returnUpdates.ToAsyncEnumerable()); - - // Create ChatClientAgent with the mocked client - var agent = new ChatClientAgent(mockChatClient.Object, options: new() - { - Instructions = "You are a helpful assistant" - }); - - // Get a new thread from the agent - var thread = agent.GetNewThread(); - - // Act - Run the agent with streaming to populate the thread with messages - var streamingResults = new List(); - await foreach (var update in agent.RunStreamingAsync([userMessage], thread)) - { - streamingResults.Add(update); - } - - // Assert - Verify streaming worked - Assert.Equal(2, streamingResults.Count); - - // Retrieve messages from the thread to verify notification occurred - var messagesRetrievableThread = (IMessagesRetrievableThread)thread; - var retrievedMessages = new List(); - await foreach (var message in messagesRetrievableThread.GetMessagesAsync()) - { - retrievedMessages.Add(message); - } - - // Assert - Verify that the thread was notified and contains both user and assistant messages - Assert.NotEmpty(retrievedMessages); - Assert.Equal(2, retrievedMessages.Count); - Assert.Contains(retrievedMessages, m => m.Text == "Hello, streaming!" && m.Role == ChatRole.User); - Assert.Contains(retrievedMessages, m => m.Text == "Hi there, streaming response!" && m.Role == ChatRole.Assistant); - } - - /// - /// Verify that thread accumulates both input and response messages across multiple streaming calls. - /// - [Fact] - public async Task VerifyThreadAccumulatesMessagesAcrossMultipleStreamingCallsAsync() - { - // Arrange - var firstUserMessage = new ChatMessage(ChatRole.User, "First streaming message"); - var secondUserMessage = new ChatMessage(ChatRole.User, "Second streaming message"); - - // Create streaming response updates for first call - ChatResponseUpdate[] firstReturnUpdates = - [ - new ChatResponseUpdate(role: ChatRole.Assistant, content: "First "), - new ChatResponseUpdate(role: null, content: "response"), - ]; - - // Create streaming response updates for second call - ChatResponseUpdate[] secondReturnUpdates = - [ - new ChatResponseUpdate(role: ChatRole.Assistant, content: "Second "), - new ChatResponseUpdate(role: null, content: "response"), - ]; - - var mockChatClient = new Mock(); - mockChatClient.SetupSequence( - c => c.GetStreamingResponseAsync( - It.IsAny>(), - It.IsAny(), - It.IsAny())) - .Returns(firstReturnUpdates.ToAsyncEnumerable()) - .Returns(secondReturnUpdates.ToAsyncEnumerable()); - - var agent = new ChatClientAgent(mockChatClient.Object, options: new()); - var thread = agent.GetNewThread(); - - // Act - Make two streaming calls - var firstStreamingResults = new List(); - await foreach (var update in agent.RunStreamingAsync([firstUserMessage], thread)) - { - firstStreamingResults.Add(update); - } - - var secondStreamingResults = new List(); - await foreach (var update in agent.RunStreamingAsync([secondUserMessage], thread)) - { - secondStreamingResults.Add(update); - } - - // Assert - Verify both streaming calls worked - Assert.Equal(2, firstStreamingResults.Count); - Assert.Equal(2, secondStreamingResults.Count); - - // Retrieve all messages from the thread - var messagesRetrievableThread = (IMessagesRetrievableThread)thread; - var retrievedMessages = new List(); - await foreach (var message in messagesRetrievableThread.GetMessagesAsync()) - { - retrievedMessages.Add(message); - } - - // Assert - Verify that the thread contains all messages in order - Assert.Equal(4, retrievedMessages.Count); - Assert.Equal("First streaming message", retrievedMessages[0].Text); - Assert.Equal("First response", retrievedMessages[1].Text); - Assert.Equal("Second streaming message", retrievedMessages[2].Text); - Assert.Equal("Second response", retrievedMessages[3].Text); - } - - /// - /// Verify that thread notification works correctly when streaming with existing thread messages. - /// Both RunAsync and RunStreamingAsync should add both input and response messages to the thread. - /// - [Fact] - public async Task VerifyStreamingWithExistingThreadMessagesAsync() - { - // Arrange - var initialUserMessage = new ChatMessage(ChatRole.User, "Initial message"); - var initialAssistantMessage = new ChatMessage(ChatRole.Assistant, "Initial response"); - var newUserMessage = new ChatMessage(ChatRole.User, "New streaming message"); - - // Setup for initial non-streaming call - var mockChatClient = new Mock(); - mockChatClient.Setup( - c => c.GetResponseAsync( - It.IsAny>(), - It.IsAny(), - It.IsAny())) - .ReturnsAsync(new ChatResponse([initialAssistantMessage])); - - // Setup for streaming call - ChatResponseUpdate[] streamingUpdates = - [ - new ChatResponseUpdate(role: ChatRole.Assistant, content: "Streaming "), - new ChatResponseUpdate(role: null, content: "response"), - ]; - - mockChatClient.Setup( - c => c.GetStreamingResponseAsync( - It.IsAny>(), - It.IsAny(), - It.IsAny())) - .Returns(streamingUpdates.ToAsyncEnumerable()); - - var agent = new ChatClientAgent(mockChatClient.Object, options: new()); - var thread = agent.GetNewThread(); - - // Act - First, make a regular call to populate the thread - await agent.RunAsync([initialUserMessage], thread); - - // Then make a streaming call - var streamingResults = new List(); - await foreach (var update in agent.RunStreamingAsync([newUserMessage], thread)) - { - streamingResults.Add(update); - } - - // Assert - Verify streaming worked - Assert.Equal(2, streamingResults.Count); - - // Retrieve all messages from the thread - var messagesRetrievableThread = (IMessagesRetrievableThread)thread; - var retrievedMessages = new List(); - await foreach (var message in messagesRetrievableThread.GetMessagesAsync()) - { - retrievedMessages.Add(message); - } - - // Assert - Verify that the thread contains all messages including the new streaming ones - Assert.Equal(4, retrievedMessages.Count); - Assert.Equal("Initial message", retrievedMessages[0].Text); - Assert.Equal("Initial response", retrievedMessages[1].Text); - Assert.Equal("New streaming message", retrievedMessages[2].Text); - Assert.Equal("Streaming response", retrievedMessages[3].Text); - } - - /// - /// Verify that thread is notified of input messages even when zero streaming updates are received. - /// - [Fact] - public async Task VerifyThreadNotificationWithZeroStreamingUpdatesAsync() - { - // Arrange - var userMessage = new ChatMessage(ChatRole.User, "Hello with no response!"); - - // Create empty streaming response (no updates) - ChatResponseUpdate[] returnUpdates = []; - - var mockChatClient = new Mock(); - mockChatClient.Setup( - c => c.GetStreamingResponseAsync( - It.IsAny>(), - It.IsAny(), - It.IsAny())) - .Returns(returnUpdates.ToAsyncEnumerable()); - - var agent = new ChatClientAgent(mockChatClient.Object, options: new()); - var thread = agent.GetNewThread(); - - // Act - Run the agent with streaming that returns no updates - var streamingResults = new List(); - await foreach (var update in agent.RunStreamingAsync([userMessage], thread)) - { - streamingResults.Add(update); - } - - // Assert - Verify no streaming updates were received - Assert.Empty(streamingResults); - - // Retrieve messages from the thread to verify notification occurred - var messagesRetrievableThread = (IMessagesRetrievableThread)thread; - var retrievedMessages = new List(); - await foreach (var message in messagesRetrievableThread.GetMessagesAsync()) - { - retrievedMessages.Add(message); - } - - // Assert - Verify that the thread was notified of input messages even with zero updates - // The fallback mechanism should ensure input messages are added to the thread - Assert.Single(retrievedMessages); - Assert.Contains(retrievedMessages, m => m.Text == "Hello with no response!" && m.Role == ChatRole.User); - } - - /// - /// Verify that thread is notified of input messages only once even with multiple streaming updates. - /// - [Fact] - public async Task VerifyThreadNotificationWithMultipleStreamingUpdatesAsync() - { - // Arrange - var userMessage = new ChatMessage(ChatRole.User, "Hello with many updates!"); - - // Create multiple streaming response updates - ChatResponseUpdate[] returnUpdates = - [ - new ChatResponseUpdate(role: ChatRole.Assistant, content: "First "), - new ChatResponseUpdate(role: null, content: "update, "), - new ChatResponseUpdate(role: null, content: "second "), - new ChatResponseUpdate(role: null, content: "update, "), - new ChatResponseUpdate(role: null, content: "third "), - new ChatResponseUpdate(role: null, content: "update!"), - ]; - - var mockChatClient = new Mock(); - mockChatClient.Setup( - c => c.GetStreamingResponseAsync( - It.IsAny>(), - It.IsAny(), - It.IsAny())) - .Returns(returnUpdates.ToAsyncEnumerable()); - - var agent = new ChatClientAgent(mockChatClient.Object, options: new()); - var thread = agent.GetNewThread(); - - // Act - Run the agent with streaming that returns multiple updates - var streamingResults = new List(); - await foreach (var update in agent.RunStreamingAsync([userMessage], thread)) - { - streamingResults.Add(update); - } - - // Assert - Verify all streaming updates were received - Assert.Equal(6, streamingResults.Count); - - // Retrieve messages from the thread to verify notification occurred - var messagesRetrievableThread = (IMessagesRetrievableThread)thread; - var retrievedMessages = new List(); - await foreach (var message in messagesRetrievableThread.GetMessagesAsync()) - { - retrievedMessages.Add(message); - } - - // Assert - Verify that the thread contains both input and response messages - // Input message should be added only once despite multiple updates - Assert.Equal(2, retrievedMessages.Count); - Assert.Contains(retrievedMessages, m => m.Text == "Hello with many updates!" && m.Role == ChatRole.User); - Assert.Contains(retrievedMessages, m => m.Text == "First update, second update, third update!" && m.Role == ChatRole.Assistant); - } - - /// - /// Verify that thread is NOT notified of input messages when an exception occurs during streaming. - /// - [Fact] - public async Task VerifyThreadNotNotifiedWhenStreamingThrowsExceptionAsync() - { - // Arrange - var userMessage = new ChatMessage(ChatRole.User, "Hello that will fail!"); - - var mockChatClient = new Mock(); - mockChatClient.Setup( - c => c.GetStreamingResponseAsync( - It.IsAny>(), - It.IsAny(), - It.IsAny())) - .Throws(new InvalidOperationException("Streaming failed")); - - var agent = new ChatClientAgent(mockChatClient.Object, options: new()); - var thread = agent.GetNewThread(); - - // Act & Assert - Verify that streaming throws an exception - await Assert.ThrowsAsync(async () => - { - await foreach (var update in agent.RunStreamingAsync([userMessage], thread)) - { - Assert.Fail("Should not yield updates."); - } - }); - - // Retrieve messages from the thread to verify NO notification occurred - var messagesRetrievableThread = (IMessagesRetrievableThread)thread; - var retrievedMessages = new List(); - await foreach (var message in messagesRetrievableThread.GetMessagesAsync()) - { - retrievedMessages.Add(message); - } - - // Assert - Verify that the thread was NOT notified of any messages due to the exception - // This ensures that failed operations don't leave the thread in an inconsistent state - Assert.Empty(retrievedMessages); - } - - /// - /// Verify that thread is NOT notified of input messages when an exception occurs after some streaming updates. - /// - [Fact] - public async Task VerifyThreadNotNotifiedWhenStreamingThrowsExceptionAfterUpdatesAsync() - { - // Arrange - var userMessage = new ChatMessage(ChatRole.User, "Hello that will partially fail!"); - - // Create an async enumerable that yields some updates then throws - static async IAsyncEnumerable GetUpdatesWithExceptionAsync() - { - await Task.CompletedTask; // Simulate async operation - throw new InvalidOperationException("Streaming failed after partial response"); - yield break; - } - - var mockChatClient = new Mock(); - mockChatClient.Setup( - c => c.GetStreamingResponseAsync( - It.IsAny>(), - It.IsAny(), - It.IsAny())) - .Returns(GetUpdatesWithExceptionAsync()); - - var agent = new ChatClientAgent(mockChatClient.Object, options: new()); - var thread = agent.GetNewThread(); - - // Act & Assert - Verify that streaming throws an exception after some updates - var streamingResults = new List(); - await Assert.ThrowsAsync(async () => - { - await foreach (var update in agent.RunStreamingAsync([userMessage], thread)) - { - streamingResults.Add(update); - } - }); - - // Verify that some updates were received before the exception - Assert.Empty(streamingResults); - - // Retrieve messages from the thread to verify NO notification occurred - var messagesRetrievableThread = (IMessagesRetrievableThread)thread; - var retrievedMessages = new List(); - await foreach (var message in messagesRetrievableThread.GetMessagesAsync()) - { - retrievedMessages.Add(message); - } - - // Assert - Verify that the thread was NOT notified of any messages due to the exception - // Even though some updates were received, the exception should prevent thread notification - Assert.Empty(retrievedMessages); - } - - #endregion - - #region JSON Serialization Tests - - /// - /// Verify that can be serialized to JSON and deserialized back correctly. - /// - [Fact] - public void VerifyJsonSerializationRoundTrip_DefaultThread() - { - // Arrange - var originalThread = new ChatClientAgentThread(); - - // Act - string json = JsonSerializer.Serialize(originalThread); - var deserializedThread = JsonSerializer.Deserialize(json); - - // Assert - Assert.NotNull(deserializedThread); - Assert.Equal(originalThread.Id, deserializedThread.Id); - Assert.Equal(originalThread.StorageLocation, deserializedThread.StorageLocation); - } - - /// - /// Verify that with ID can be serialized to JSON and deserialized back correctly. - /// - [Fact] - public void VerifyJsonSerializationRoundTrip_ThreadWithId() - { - // Arrange - var originalThread = new ChatClientAgentThread("test-conversation-id"); - - // Act - string json = JsonSerializer.Serialize(originalThread); - var deserializedThread = JsonSerializer.Deserialize(json); - - // Assert - Assert.NotNull(deserializedThread); - Assert.Equal("test-conversation-id", deserializedThread.Id); - Assert.Equal(ChatClientAgentThreadType.ConversationId, deserializedThread.StorageLocation); - } - - /// - /// Verify that with messages can be serialized to JSON and deserialized back correctly. - /// - [Fact] - public async Task VerifyJsonSerializationRoundTrip_ThreadWithMessagesAsync() - { - // Arrange - var messages = new[] - { - new ChatMessage(ChatRole.User, "Hello, world!"), - new ChatMessage(ChatRole.Assistant, "Hi there! How can I help you?"), - new ChatMessage(ChatRole.User, "What's the weather like?") - }; - var originalThread = new ChatClientAgentThread(messages); - - // Act - string json = JsonSerializer.Serialize(originalThread); - var deserializedThread = JsonSerializer.Deserialize(json); - - // Assert - Assert.NotNull(deserializedThread); - Assert.Equal(originalThread.Id, deserializedThread.Id); - Assert.Equal(ChatClientAgentThreadType.InMemoryMessages, deserializedThread.StorageLocation); - - // Verify messages are preserved - var originalMessages = await originalThread.GetMessagesAsync().ToListAsync(); - var deserializedMessages = await deserializedThread.GetMessagesAsync().ToListAsync(); - - Assert.Equal(originalMessages.Count, deserializedMessages.Count); - for (int i = 0; i < originalMessages.Count; i++) - { - Assert.Equal(originalMessages[i].Role, deserializedMessages[i].Role); - Assert.Equal(originalMessages[i].Text, deserializedMessages[i].Text); - } - } - - /// - /// Verify that serialization handles null properties correctly. - /// - [Fact] - public void VerifyJsonSerialization_HandlesNullProperties() - { - // Arrange - var thread = new ChatClientAgentThread(); - - // Act - string json = JsonSerializer.Serialize(thread); - - // Assert - StorageLocation is no longer serialized independently - Assert.DoesNotContain("storageLocation", json, StringComparison.OrdinalIgnoreCase); - - // Verify deserialization handles empty JSON correctly - var deserializedThread = JsonSerializer.Deserialize(json); - Assert.NotNull(deserializedThread); - Assert.Null(deserializedThread.StorageLocation); - } - - /// - /// Verify that serialization only includes messages for InMemoryMessages storage type. - /// - [Fact] - public void VerifyJsonSerialization_OnlyIncludesMessagesForInMemoryStorage() - { - // Arrange - Create thread with conversation ID (server-side storage) - var threadWithId = new ChatClientAgentThread("test-id"); - - // Act - string json = JsonSerializer.Serialize(threadWithId); - - // Assert - Messages should not be included for ConversationId storage, and storageLocation is not serialized - Assert.DoesNotContain("\"messages\"", json, StringComparison.OrdinalIgnoreCase); - Assert.DoesNotContain("\"storageLocation\"", json, StringComparison.OrdinalIgnoreCase); - Assert.Contains("\"id\":\"test-id\"", json, StringComparison.OrdinalIgnoreCase); - } - - /// - /// Verify that serialization includes messages for InMemoryMessages storage type. - /// - [Fact] - public void VerifyJsonSerialization_IncludesMessagesForInMemoryStorage() - { - // Arrange - Create thread with messages (in-memory storage) - var messages = new[] { new ChatMessage(ChatRole.User, "Test message") }; - var threadWithMessages = new ChatClientAgentThread(messages); - - // Act - string json = JsonSerializer.Serialize(threadWithMessages); - - // Assert - Messages should be included for InMemoryMessages storage, but storageLocation is not serialized - Assert.Contains("\"messages\"", json, StringComparison.OrdinalIgnoreCase); - Assert.DoesNotContain("\"storageLocation\"", json, StringComparison.OrdinalIgnoreCase); - Assert.Contains("Test message", json, StringComparison.OrdinalIgnoreCase); - } - - /// - /// Verify that deserialization handles missing properties gracefully. - /// - [Fact] - public void VerifyJsonDeserialization_HandlesMissingProperties() - { - // Arrange - JSON with minimal properties - string minimalJson = "{}"; - - // Act - var thread = JsonSerializer.Deserialize(minimalJson); - - // Assert - Assert.NotNull(thread); - Assert.Null(thread.Id); - Assert.Null(thread.StorageLocation); - } - - /// - /// Verify that deserialization handles invalid JSON gracefully. - /// - [Fact] - public void VerifyJsonDeserialization_HandlesMalformedJson() - { - // Arrange - Invalid JSON structure -#pragma warning disable JSON001 // Invalid JSON pattern - string invalidJson = "{ invalid json"; -#pragma warning restore JSON001 // Invalid JSON pattern - - // Act & Assert - Assert.Throws(() => JsonSerializer.Deserialize(invalidJson)); - } - - /// - /// Verify that deserialization handles invalid storage location values. - /// This test is no longer relevant since storageLocation is not independently deserialized. - /// - [Fact] - public void VerifyJsonDeserialization_HandlesInvalidStorageLocation() - { - // Arrange - JSON with ID (which will set storage location to ConversationId) - string jsonWithId = @"{""id"":""test""}"; - - // Act - var thread = JsonSerializer.Deserialize(jsonWithId); - - // Assert - Storage location is determined by presence of ID - Assert.NotNull(thread); - Assert.Equal("test", thread.Id); - Assert.Equal(ChatClientAgentThreadType.ConversationId, thread.StorageLocation); - } - - /// - /// Verify that deserialization preserves messages correctly. - /// - [Fact] - public async Task VerifyJsonDeserialization_PreservesMessagesCorrectlyAsync() - { - // Arrange - Create a thread with messages and serialize it to get the correct format - var originalMessages = new[] - { - new ChatMessage(ChatRole.User, "Hello"), - new ChatMessage(ChatRole.Assistant, "Hi there!") - }; - var originalThread = new ChatClientAgentThread(originalMessages); - - // Serialize to get the actual format, then deserialize - string json = JsonSerializer.Serialize(originalThread); - var thread = JsonSerializer.Deserialize(json); - - // Assert - Assert.NotNull(thread); - Assert.Equal(ChatClientAgentThreadType.InMemoryMessages, thread.StorageLocation); - - var messages = await thread.GetMessagesAsync().ToListAsync(); - Assert.Equal(2, messages.Count); - Assert.Equal(ChatRole.User, messages[0].Role); - Assert.Equal("Hello", messages[0].Text); - Assert.Equal(ChatRole.Assistant, messages[1].Role); - Assert.Equal("Hi there!", messages[1].Text); - } - - /// - /// Verify that serialization and deserialization works with complex message content. - /// - [Fact] - public async Task VerifyJsonSerializationRoundTrip_ComplexMessageContentAsync() - { - // Arrange - Create messages with various content types - var messages = new[] - { - new ChatMessage(ChatRole.User, "Simple text message"), - new ChatMessage(ChatRole.Assistant, [ - new TextContent("Mixed content: "), - new TextContent("multiple parts") - ]), - new ChatMessage(ChatRole.User, "Message with special characters: ñáéíóú !@#$%^&*()") - }; - var originalThread = new ChatClientAgentThread(messages); - - // Act - string json = JsonSerializer.Serialize(originalThread); - var deserializedThread = JsonSerializer.Deserialize(json); - - // Assert - Assert.NotNull(deserializedThread); - - var originalMessages = await originalThread.GetMessagesAsync().ToListAsync(); - var deserializedMessages = await deserializedThread.GetMessagesAsync().ToListAsync(); - - Assert.Equal(originalMessages.Count, deserializedMessages.Count); - - // Verify complex content is preserved - for (int i = 0; i < originalMessages.Count; i++) - { - Assert.Equal(originalMessages[i].Role, deserializedMessages[i].Role); - Assert.Equal(originalMessages[i].Text, deserializedMessages[i].Text); - } - } - - #endregion -} diff --git a/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/OpenTelemetryAgentTests.cs b/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/OpenTelemetryAgentTests.cs index 653fde0eef..577ca94253 100644 --- a/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/OpenTelemetryAgentTests.cs +++ b/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/OpenTelemetryAgentTests.cs @@ -388,7 +388,7 @@ public class OpenTelemetryAgentTests new(ChatRole.User, "Hello") }; - var thread = new AgentThread { Id = "thread-123" }; + var thread = new AgentThread { ConversationId = "thread-123" }; // Act await telemetryAgent.RunAsync(messages, thread); diff --git a/dotnet/tests/OpenAIAssistant.IntegrationTests/OpenAIAssistantFixture.cs b/dotnet/tests/OpenAIAssistant.IntegrationTests/OpenAIAssistantFixture.cs index 8fbd3fa314..b5a44a7e6c 100644 --- a/dotnet/tests/OpenAIAssistant.IntegrationTests/OpenAIAssistantFixture.cs +++ b/dotnet/tests/OpenAIAssistant.IntegrationTests/OpenAIAssistantFixture.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using System; using System.Collections.Generic; using System.Threading.Tasks; using AgentConformance.IntegrationTests; @@ -28,13 +27,8 @@ public class OpenAIAssistantFixture : IChatClientAgentFixture public async Task> GetChatHistoryAsync(AgentThread thread) { - if (thread is not ChatClientAgentThread chatClientThread) - { - throw new InvalidOperationException("The thread must be of type ChatClientAgentThread to retrieve chat history."); - } - List messages = []; - await foreach (var agentMessage in this._assistantClient!.GetMessagesAsync(chatClientThread.Id, new() { Order = MessageCollectionOrder.Ascending })) + await foreach (var agentMessage in this._assistantClient!.GetMessagesAsync(thread.ConversationId, new() { Order = MessageCollectionOrder.Ascending })) { messages.Add(new() { @@ -79,9 +73,9 @@ public class OpenAIAssistantFixture : IChatClientAgentFixture public Task DeleteThreadAsync(AgentThread thread) { - if (thread?.Id is not null) + if (thread?.ConversationId is not null) { - return this._assistantClient!.DeleteThreadAsync(thread.Id); + return this._assistantClient!.DeleteThreadAsync(thread.ConversationId); } return Task.CompletedTask; diff --git a/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs b/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs index 7edec6af58..da5324e60e 100644 --- a/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs +++ b/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using System; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; @@ -33,12 +32,7 @@ public class OpenAIChatCompletionFixture : IChatClientAgentFixture public async Task> GetChatHistoryAsync(AgentThread thread) { - if (thread is not ChatClientAgentThread chatClientThread) - { - throw new InvalidOperationException("The thread must be of type ChatClientAgentThread to retrieve chat history."); - } - - return await chatClientThread.GetMessagesAsync().ToListAsync(); + return await thread.GetMessagesAsync().ToListAsync(); } public Task CreateChatClientAgentAsync( diff --git a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs index 3acb3b3cb5..77b964559b 100644 --- a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs +++ b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs @@ -29,15 +29,10 @@ public class OpenAIResponseFixture(bool store) : IChatClientAgentFixture public async Task> GetChatHistoryAsync(AgentThread thread) { - if (thread is not ChatClientAgentThread chatClientThread) - { - throw new InvalidOperationException("The thread must be of type ChatClientAgentThread to retrieve chat history."); - } - if (store) { - var inputItems = await this._openAIResponseClient.GetResponseInputItemsAsync(chatClientThread.Id).ToListAsync(); - var response = await this._openAIResponseClient.GetResponseAsync(chatClientThread.Id); + var inputItems = await this._openAIResponseClient.GetResponseInputItemsAsync(thread.ConversationId).ToListAsync(); + var response = await this._openAIResponseClient.GetResponseAsync(thread.ConversationId); var responseItem = response.Value.OutputItems.FirstOrDefault()!; // Take the messages that were the chat history leading up to the current response @@ -55,7 +50,7 @@ public class OpenAIResponseFixture(bool store) : IChatClientAgentFixture return [.. previousMessages, responseMessage]; } - return await chatClientThread.GetMessagesAsync().ToListAsync(); + return await thread.GetMessagesAsync().ToListAsync(); } private static ChatMessage ConvertToChatMessage(ResponseItem item)