From 66fe1c957ca297b87c8f2425deac3e2a2e700640 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Tue, 16 Sep 2025 10:54:18 +0100 Subject: [PATCH] .NET: Add AIContextProvider support (#691) * Add AIContextProvider support * Address feedback. * Address PR comments. * Switch to valuetask and remove parallel calls for AIContextProvider * Remove Model from ModelInvokingAsync method name * Remove agent thread id again and remove it from context provider interface * Add AIContextProvider serialization support to AgentThread and update sample to show this feature * Address PR comments * Improve memory sample * Update sample comment. * Remove AggregateAIContextProvider for now since it makes too many assumptions. We can include it later as a sample if needed. * Update AIContextProviders to have an Invoked method instead of MessagesAddingAsync. * Remove unused using. * Address PR comments. * Address PR comment. * Update comment. * Update comment * Address PR comments. --- dotnet/agent-framework-dotnet.slnx | 9 + dotnet/eng/MSBuild/LegacySupport.props | 8 + .../Program.cs | 6 +- .../Agent_Step13_Memory.csproj | 22 ++ .../Agents/Agent_Step13_Memory/Program.cs | 152 ++++++++++++ .../samples/GettingStarted/Agents/README.md | 3 +- .../CompilerFeatureRequiredAttribute.cs | 37 +++ .../README.md | 9 + .../RequiredMemberAttribute/README.md | 9 + .../RequiredMemberAttribute.cs | 10 + .../AIAgent.cs | 2 +- .../AIContext.cs | 43 ++++ .../AIContextProvider.cs | 118 +++++++++ .../AgentThread.cs | 23 +- ...t.Extensions.AI.Agents.Abstractions.csproj | 3 + .../ChatCompletion/ChatClientAgent.cs | 127 ++++++++-- .../ChatCompletion/ChatClientAgentOptions.cs | 10 +- .../AIAgentTests.cs | 2 +- .../AIContextProviderTests.cs | 56 +++++ .../AIContextTests.cs | 58 +++++ .../AgentThreadTests.cs | 51 +++- .../TestJsonSerializerContext.cs | 1 + .../ChatClientAgentOptionsTests.cs | 9 +- .../ChatCompletion/ChatClientAgentTests.cs | 231 ++++++++++++++++++ .../TestJsonSerializerContext.cs | 14 ++ 25 files changed, 981 insertions(+), 32 deletions(-) create mode 100644 dotnet/samples/GettingStarted/Agents/Agent_Step13_Memory/Agent_Step13_Memory.csproj create mode 100644 dotnet/samples/GettingStarted/Agents/Agent_Step13_Memory/Program.cs create mode 100644 dotnet/src/LegacySupport/CompilerFeatureRequiredAttribute/CompilerFeatureRequiredAttribute.cs create mode 100644 dotnet/src/LegacySupport/CompilerFeatureRequiredAttribute/README.md create mode 100644 dotnet/src/LegacySupport/RequiredMemberAttribute/README.md create mode 100644 dotnet/src/LegacySupport/RequiredMemberAttribute/RequiredMemberAttribute.cs create mode 100644 dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIContext.cs create mode 100644 dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIContextProvider.cs create mode 100644 dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIContextProviderTests.cs create mode 100644 dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIContextTests.cs create mode 100644 dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/TestJsonSerializerContext.cs diff --git a/dotnet/agent-framework-dotnet.slnx b/dotnet/agent-framework-dotnet.slnx index 2b07e3346d..e662856548 100644 --- a/dotnet/agent-framework-dotnet.slnx +++ b/dotnet/agent-framework-dotnet.slnx @@ -50,6 +50,7 @@ + @@ -190,6 +191,10 @@ + + + + @@ -206,6 +211,10 @@ + + + + diff --git a/dotnet/eng/MSBuild/LegacySupport.props b/dotnet/eng/MSBuild/LegacySupport.props index c921bc59a2..54d65288bd 100644 --- a/dotnet/eng/MSBuild/LegacySupport.props +++ b/dotnet/eng/MSBuild/LegacySupport.props @@ -22,4 +22,12 @@ + + + + + + + + \ No newline at end of file diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs index a2713c80d2..04f8cc8b24 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs @@ -39,8 +39,7 @@ namespace SampleApp List responseMessages = CloneAndToUpperCase(messages, this.DisplayName).ToList(); // Notify the thread of the input and output messages. - await NotifyThreadOfNewMessagesAsync(thread, messages, cancellationToken); - await NotifyThreadOfNewMessagesAsync(thread, responseMessages, cancellationToken); + await NotifyThreadOfNewMessagesAsync(thread, messages.Concat(responseMessages), cancellationToken); return new AgentRunResponse { @@ -59,8 +58,7 @@ namespace SampleApp List responseMessages = CloneAndToUpperCase(messages, this.DisplayName).ToList(); // Notify the thread of the input and output messages. - await NotifyThreadOfNewMessagesAsync(thread, messages, cancellationToken); - await NotifyThreadOfNewMessagesAsync(thread, responseMessages, cancellationToken); + await NotifyThreadOfNewMessagesAsync(thread, messages.Concat(responseMessages), cancellationToken); foreach (var message in responseMessages) { diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step13_Memory/Agent_Step13_Memory.csproj b/dotnet/samples/GettingStarted/Agents/Agent_Step13_Memory/Agent_Step13_Memory.csproj new file mode 100644 index 0000000000..68b0f8d36b --- /dev/null +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step13_Memory/Agent_Step13_Memory.csproj @@ -0,0 +1,22 @@ + + + + Exe + net9.0 + + enable + disable + + + + + + + + + + + + + + diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step13_Memory/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step13_Memory/Program.cs new file mode 100644 index 0000000000..c91c81bd4b --- /dev/null +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step13_Memory/Program.cs @@ -0,0 +1,152 @@ +// Copyright (c) Microsoft. All rights reserved. + +// This sample shows how to add a basic custom memory component to an agent. +// The memory component subscribes to all messages added to the conversation and +// extracts the user's name and age if provided. +// The component adds a prompt to ask for this information if it is not already known +// and provides it to the model before each invocation if known. + +using System; +using System.Linq; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Azure.AI.OpenAI; +using Azure.Identity; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.AI.Agents; +using OpenAI; +using OpenAI.Chat; +using SampleApp; + +var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new InvalidOperationException("AZURE_OPENAI_ENDPOINT is not set."); +var deploymentName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOYMENT_NAME") ?? "gpt-4o-mini"; + +ChatClient chatClient = new AzureOpenAIClient( + new Uri(endpoint), + new AzureCliCredential()) + .GetChatClient(deploymentName); + +// Create the agent and provide a factory to add our custom memory component to +// all threads created by the agent. Here each new memory component will have its own +// user info object, so each thread will have its own memory. +// In real world applications/services, where the user info would be persisted in a database, +// and preferably shared between multiple threads used by the same user, ensure that the +// factory reads the user id from the current context and scopes the memory component +// and its storage to that user id. +AIAgent agent = chatClient.CreateAIAgent(new ChatClientAgentOptions() +{ + Instructions = "You are a friendly assistant. Always address the user by their name.", + AIContextProviderFactory = () => new SampleApp.UserInfoMemory(chatClient.AsIChatClient()) +}); + +// Create a new thread for the conversation. +AgentThread thread = agent.GetNewThread(); + +Console.WriteLine(">> Use thread with blank memory\n"); + +// Invoke the agent and output the text result. +Console.WriteLine(await agent.RunAsync("Hello, what is the square root of 9?", thread)); +Console.WriteLine(await agent.RunAsync("My name is Ruaidhrí", thread)); +Console.WriteLine(await agent.RunAsync("I am 20 years old", thread)); + +// We can serialize the thread. The serialized state will include the state of the memory component. +var threadElement = await thread.SerializeAsync(); + +Console.WriteLine("\n>> Use deserialized thread with previously created memories\n"); + +// Later we can deserialize the thread and continue the conversation with the previous memory component state. +var deserializedThread = await agent.DeserializeThreadAsync(threadElement); +Console.WriteLine(await agent.RunAsync("What is my name and age?", deserializedThread)); + +Console.WriteLine("\n>> Read memories from memory component\n"); + +// It's possible to access the memory component via the thread's AIContextProvider property. +var userInfo = ((UserInfoMemory)deserializedThread.AIContextProvider!).UserInfo; + +// Output the user info that was captured by the memory component. +Console.WriteLine($"MEMORY - User Name: {userInfo.UserName}"); +Console.WriteLine($"MEMORY - User Age: {userInfo.UserAge}"); + +Console.WriteLine("\n>> Use new thread with previously created memories\n"); + +// Create a new thread. +thread = agent.GetNewThread(); + +// It is also possible to add the memory component to an individual thread only instead of all +// threads via the factory above. +// In this case we will also use the same user info object, so this thread will share the same +// memories as the previous thread. +thread.AIContextProvider = new UserInfoMemory(chatClient.AsIChatClient(), userInfo); + +// Invoke the agent and output the text result. +// This time the agent should remember the user's name and use it in the response. +Console.WriteLine(await agent.RunAsync("What is my name and age?", thread)); + +namespace SampleApp +{ + /// + /// Sample memory component that can remember a user's name and age. + /// + internal sealed class UserInfoMemory(IChatClient chatClient, UserInfo? userInfo = null) : AIContextProvider + { + public UserInfo UserInfo { get; set; } = userInfo ?? new(); + + public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + { + // Try and extract the user name and age from the message if we don't have it already and it's a user message. + if ((this.UserInfo.UserName == null || this.UserInfo.UserAge == null) && context.RequestMessages.Any(x => x.Role == ChatRole.User)) + { + var result = await chatClient.GetResponseAsync( + context.RequestMessages, + new ChatOptions() + { + Instructions = "Extract the user's name and age from the message if present. If not present return nulls." + }, + cancellationToken: cancellationToken); + + this.UserInfo.UserName ??= result.Result.UserName; + this.UserInfo.UserAge ??= result.Result.UserAge; + } + } + + public override ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + { + StringBuilder instructions = new(); + + // If we don't already know the user's name and age, add instructions to ask for them, otherwise just provide what we have to the context. + instructions.AppendLine( + this.UserInfo.UserName == null ? + "Ask the user for their name and politely decline to answer any questions until they provide it." : + $"The user's name is {this.UserInfo.UserName}."); + + instructions.AppendLine( + this.UserInfo.UserAge == null ? + "Ask the user for their age and politely decline to answer any questions until they provide it." : + $"The user's age is {this.UserInfo.UserAge}."); + + return new ValueTask(new AIContext + { + Instructions = instructions.ToString() + }); + } + + public override ValueTask SerializeAsync(JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + { + return new ValueTask(JsonSerializer.SerializeToElement(this.UserInfo, jsonSerializerOptions)); + } + + public override ValueTask DeserializeAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + { + this.UserInfo = JsonSerializer.Deserialize(serializedState, jsonSerializerOptions) ?? new UserInfo(); + return default; + } + } + + internal sealed class UserInfo + { + public string? UserName { get; set; } + public int? UserAge { get; set; } + } +} diff --git a/dotnet/samples/GettingStarted/Agents/README.md b/dotnet/samples/GettingStarted/Agents/README.md index 43da38e0ff..d831fab364 100644 --- a/dotnet/samples/GettingStarted/Agents/README.md +++ b/dotnet/samples/GettingStarted/Agents/README.md @@ -37,7 +37,8 @@ Before you begin, ensure you have the following prerequisites: |[Dependency injection with a simple agent](./Agent_Step09_DependencyInjection/)|This sample demonstrates how to add and resolve an agent with a dependency injection container| |[Exposing a simple agent as MCP tool](./Agent_Step10_AsMcpTool/)|This sample demonstrates how to expose an agent as an MCP tool| |[Using images with a simple agent](./Agent_Step11_UsingImages/)|This sample demonstrates how to use image multi-modality with an AI agent| -|[Exposing a simple agent a function tool](./Agent_Step12_AsFunctionTool/)|This sample demonstrates how to expose an agent as a function tool| +|[Exposing a simple agent as a function tool](./Agent_Step12_AsFunctionTool/)|This sample demonstrates how to expose an agent as a function tool| +|[Using memory with an agent](./Agent_Step12_Memory/)|This sample demonstrates how to create a simple memory component and use it with an agent| ## Running the samples from the console diff --git a/dotnet/src/LegacySupport/CompilerFeatureRequiredAttribute/CompilerFeatureRequiredAttribute.cs b/dotnet/src/LegacySupport/CompilerFeatureRequiredAttribute/CompilerFeatureRequiredAttribute.cs new file mode 100644 index 0000000000..74a67d7c6e --- /dev/null +++ b/dotnet/src/LegacySupport/CompilerFeatureRequiredAttribute/CompilerFeatureRequiredAttribute.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft. All rights reserved. + +#pragma warning disable SA1623 // Property summary documentation should match accessors + +namespace System.Runtime.CompilerServices; + +/// +/// Indicates that compiler support for a particular feature is required for the location where this attribute is applied. +/// +[AttributeUsage(AttributeTargets.All, AllowMultiple = true, Inherited = false)] +internal sealed class CompilerFeatureRequiredAttribute : Attribute +{ + public CompilerFeatureRequiredAttribute(string featureName) + { + FeatureName = featureName; + } + + /// + /// The name of the compiler feature. + /// + public string FeatureName { get; } + + /// + /// If true, the compiler can choose to allow access to the location where this attribute is applied if it does not understand . + /// + public bool IsOptional { get; init; } + + /// + /// The used for the ref structs C# feature. + /// + public const string RefStructs = nameof(RefStructs); + + /// + /// The used for the required members C# feature. + /// + public const string RequiredMembers = nameof(RequiredMembers); +} diff --git a/dotnet/src/LegacySupport/CompilerFeatureRequiredAttribute/README.md b/dotnet/src/LegacySupport/CompilerFeatureRequiredAttribute/README.md new file mode 100644 index 0000000000..c30799eef0 --- /dev/null +++ b/dotnet/src/LegacySupport/CompilerFeatureRequiredAttribute/README.md @@ -0,0 +1,9 @@ +Enables use of C# required members on older frameworks. + +To use this source in your project, add the following to your `.csproj` file: + +```xml + + true + +``` diff --git a/dotnet/src/LegacySupport/RequiredMemberAttribute/README.md b/dotnet/src/LegacySupport/RequiredMemberAttribute/README.md new file mode 100644 index 0000000000..da8c9bc98c --- /dev/null +++ b/dotnet/src/LegacySupport/RequiredMemberAttribute/README.md @@ -0,0 +1,9 @@ +Enables use of C# required members on older frameworks. + +To use this source in your project, add the following to your `.csproj` file: + +```xml + + true + +``` diff --git a/dotnet/src/LegacySupport/RequiredMemberAttribute/RequiredMemberAttribute.cs b/dotnet/src/LegacySupport/RequiredMemberAttribute/RequiredMemberAttribute.cs new file mode 100644 index 0000000000..1a82954140 --- /dev/null +++ b/dotnet/src/LegacySupport/RequiredMemberAttribute/RequiredMemberAttribute.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.ComponentModel; + +namespace System.Runtime.CompilerServices; + +/// Specifies that a type has required members or that a member is required. +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = false)] +[EditorBrowsable(EditorBrowsableState.Never)] +internal sealed class RequiredMemberAttribute : Attribute; diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIAgent.cs b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIAgent.cs index 395f075029..b744a17892 100644 --- a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIAgent.cs +++ b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIAgent.cs @@ -290,6 +290,6 @@ public abstract class AIAgent _ = Throw.IfNull(thread); _ = Throw.IfNull(messages); - await thread.OnNewMessagesAsync(messages, cancellationToken).ConfigureAwait(false); + await thread.MessagesReceivedAsync(messages, cancellationToken).ConfigureAwait(false); } } diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIContext.cs b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIContext.cs new file mode 100644 index 0000000000..f0c8e64bff --- /dev/null +++ b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIContext.cs @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; + +namespace Microsoft.Extensions.AI.Agents; + +/// +/// A class containing any context that should be provided to the AI model +/// as supplied by an . +/// +/// +/// Each has the ability to provide its own context for each invocation. +/// The class contains the additional context supplied by the . +/// This context will be combined with context supplied by other providers before being passed to the AI model. +/// +public sealed class AIContext +{ + /// + /// Gets or sets any instructions to pass to the AI model in addition to any other prompts + /// that it may already have (in the case of an agent), or chat history that may + /// already exist. + /// + /// + /// These instructions will be transient and only apply to the current invocation. + /// + public string? Instructions { get; set; } + + /// + /// Gets or sets a list of messages to add to the chat history. + /// + /// + /// These messages will permanently be added to the chat history. + /// + public IList? Messages { get; set; } + + /// + /// Gets or sets a list of functions/tools to make available to the AI model for the current invocation. + /// + /// + /// These functions/tools will be transient and only apply to the current invocation. + /// + public IList? Tools { get; set; } +} diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIContextProvider.cs b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIContextProvider.cs new file mode 100644 index 0000000000..fc6d244051 --- /dev/null +++ b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIContextProvider.cs @@ -0,0 +1,118 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Extensions.AI.Agents; + +/// +/// Base class for all AI context providers. +/// +/// +/// An AI context provider is a component that can be used to enhance the AI's context management. +/// It can listen to changes in the conversation, provide additional context to +/// the Model/Agent/etc. just before invocation and supply additional function tools. +/// +public abstract class AIContextProvider +{ + /// + /// Called just before the Model/Agent/etc. is invoked + /// Implementers can load any additional context required at this time, + /// and they should return any context that should be passed to the Model/Agent/etc. + /// + /// Contains the event context. + /// The to monitor for cancellation requests. The default is . + /// A task that completes when the context has been rendered and returned. + public abstract ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default); + + /// + /// Called just before the Model/Agent/etc. is invoked + /// Implementers can load any additional context required at this time, + /// and they should return any context that should be passed to the Model/Agent/etc. + /// + /// Contains the event context. + /// The to monitor for cancellation requests. The default is . + /// A task that completes when the context has been rendered and returned. + public virtual ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + { + return default; + } + + /// + /// Serializes the current object's state to a using the specified serialization options. + /// + /// The JSON serialization options to use. + /// The to monitor for cancellation requests. The default is . + /// A representation of the object's state. + public virtual ValueTask SerializeAsync(JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + { + return default; + } + + /// + /// Deserializes the state contained in the provided into the properties on this object. + /// + /// A representing the state of the object. + /// Optional settings for customizing the JSON deserialization process. + /// The to monitor for cancellation requests. The default is . + /// A that completes when the state has been deserialized. + public virtual ValueTask DeserializeAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + { + return default; + } + + /// + /// Contains the event context provided to . + /// + public class InvokingContext + { + /// + /// Initializes a new instance of the class. + /// + /// The messages to be sent to the Model/Agent/etc. for this invocation. + /// Thrown if is . + public InvokingContext(IEnumerable requestMessages) + { + RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages)); + } + + /// + /// Gets the messages that will be sent to the Model/Agent/etc. for this invocation. + /// + public IEnumerable RequestMessages { get; private set; } + } + + /// + /// Contains the event conext provided to . + /// + public class InvokedContext + { + /// + /// Initializes a new instance of the class with the specified request messages. + /// + /// The messages that were sent to the Model/Agent/etc. for this invocation. + /// Thrown if is . + public InvokedContext(IEnumerable requestMessages) + { + RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages)); + } + + /// + /// Gets the messages that were sent to the Model/Agent/etc. for this invocation. + /// + public IEnumerable RequestMessages { get; private set; } + + /// + /// Gets the collection of response messages generated by Model/Agent/etc. if the invocation succeeded. + /// + public IEnumerable? ResponseMessages { get; init; } + + /// + /// Gets the that was thrown during the invocation, if the invocation failed. + /// + public Exception? InvokeException { get; init; } + } +} diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AgentThread.cs b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AgentThread.cs index 74a5c8108e..9b1fbcdb5b 100644 --- a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AgentThread.cs +++ b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AgentThread.cs @@ -27,7 +27,7 @@ public class AgentThread } /// - /// Gets or sets the id of the current thread to support cases where the thread is owned by the agent service. + /// Gets or sets the ID of the underlying service thread to support cases where the chat history is stored by the agent service. /// /// /// @@ -108,6 +108,11 @@ public class AgentThread } } + /// + /// Gets or sets the used by this thread to provide additional context to the AI model before each invocation. + /// + public AIContextProvider? AIContextProvider { get; set; } + /// /// Serializes the current object's state to a using the specified serialization options. /// @@ -120,10 +125,15 @@ public class AgentThread null : await this._messageStore.SerializeStateAsync(jsonSerializerOptions, cancellationToken).ConfigureAwait(false); + var aiContextProviderState = this.AIContextProvider is null ? + null : + await this.AIContextProvider.SerializeAsync(jsonSerializerOptions, cancellationToken).ConfigureAwait(false); + var state = new ThreadState { ConversationId = this.ConversationId, - StoreState = storeState + StoreState = storeState, + AIContextProviderState = aiContextProviderState }; return JsonSerializer.SerializeToElement(state, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ThreadState))); @@ -139,7 +149,7 @@ public class AgentThread /// The to monitor for cancellation requests. The default is . /// A task that completes when the context has been updated. /// The thread has been deleted. - protected internal virtual async Task OnNewMessagesAsync(IEnumerable newMessages, CancellationToken cancellationToken = default) + protected internal virtual async Task MessagesReceivedAsync(IEnumerable newMessages, CancellationToken cancellationToken = default) { switch (this) { @@ -186,6 +196,11 @@ public class AgentThread return; } + if (state?.AIContextProviderState.HasValue is true && this.AIContextProvider is not null) + { + await this.AIContextProvider.DeserializeAsync(state.AIContextProviderState.Value, jsonSerializerOptions, cancellationToken).ConfigureAwait(false); + } + // If we don't have any IChatMessageStore state return here. if (state?.StoreState is null || state?.StoreState.Value.ValueKind is JsonValueKind.Undefined or JsonValueKind.Null) { @@ -206,5 +221,7 @@ public class AgentThread public string? ConversationId { get; set; } public JsonElement? StoreState { get; set; } + + public JsonElement? AIContextProviderState { get; set; } } } diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/Microsoft.Extensions.AI.Agents.Abstractions.csproj b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/Microsoft.Extensions.AI.Agents.Abstractions.csproj index 0bdc6a5b49..c4f910dbec 100644 --- a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/Microsoft.Extensions.AI.Agents.Abstractions.csproj +++ b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/Microsoft.Extensions.AI.Agents.Abstractions.csproj @@ -12,6 +12,9 @@ true true true + true + true + true diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgent.cs b/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgent.cs index 3deafcdc6f..5878cc71c9 100644 --- a/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgent.cs @@ -114,7 +114,17 @@ public sealed class ChatClientAgent : AIAgent this._logger.LogAgentChatClientInvokingAgent(nameof(RunAsync), this.Id, agentName, this._chatClientType); - ChatResponse chatResponse = await this.ChatClient.GetResponseAsync(threadMessages, chatOptions, cancellationToken).ConfigureAwait(false); + // Call the IChatClient and notify the AIContextProvider of any failures. + ChatResponse chatResponse; + try + { + chatResponse = await this.ChatClient.GetResponseAsync(threadMessages, chatOptions, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, cancellationToken).ConfigureAwait(false); + throw; + } this._logger.LogAgentChatClientInvokedAgent(nameof(RunAsync), this.Id, agentName, this._chatClientType, inputMessages.Count); @@ -122,19 +132,17 @@ public sealed class ChatClientAgent : AIAgent // so let's update it and set the conversation id for the service thread case. this.UpdateThreadWithTypeAndConversationId(safeThread, chatResponse.ConversationId); - // Only notify the thread of new messages if the chatResponse was successful to avoid inconsistent messages state in the thread. - await NotifyThreadOfNewMessagesAsync(safeThread, inputMessages, cancellationToken).ConfigureAwait(false); - // Ensure that the author name is set for each message in the response. foreach (ChatMessage chatResponseMessage in chatResponse.Messages) { chatResponseMessage.AuthorName ??= agentName; } - // Convert the chat response messages to a valid IReadOnlyCollection for notification signatures below. - var chatResponseMessages = chatResponse.Messages as IReadOnlyCollection ?? [.. chatResponse.Messages]; + // Only notify the thread of new messages if the chatResponse was successful to avoid inconsistent message state in the thread. + await NotifyThreadOfNewMessagesAsync(safeThread, inputMessages.Concat(chatResponse.Messages), cancellationToken).ConfigureAwait(false); - await NotifyThreadOfNewMessagesAsync(safeThread, chatResponseMessages, cancellationToken).ConfigureAwait(false); + // Notify the AIContextProvider of all new messages. + await NotifyAIContextProviderOfSuccessAsync(safeThread, inputMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false); return new(chatResponse) { AgentId = this.Id }; } @@ -156,15 +164,34 @@ public sealed class ChatClientAgent : AIAgent this._logger.LogAgentChatClientInvokingAgent(nameof(RunStreamingAsync), this.Id, loggingAgentName, this._chatClientType); - // Using the enumerator to ensure we consider the case where no updates are returned for notification. - var responseUpdatesEnumerator = this.ChatClient.GetStreamingResponseAsync(threadMessages, chatOptions, cancellationToken).GetAsyncEnumerator(cancellationToken); + List responseUpdates = []; + + IAsyncEnumerator responseUpdatesEnumerator; + + try + { + // Using the enumerator to ensure we consider the case where no updates are returned for notification. + responseUpdatesEnumerator = this.ChatClient.GetStreamingResponseAsync(threadMessages, chatOptions, cancellationToken).GetAsyncEnumerator(cancellationToken); + } + catch (Exception ex) + { + await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, cancellationToken).ConfigureAwait(false); + throw; + } this._logger.LogAgentChatClientInvokedStreamingAgent(nameof(RunStreamingAsync), this.Id, loggingAgentName, this._chatClientType); - List responseUpdates = []; - - // Ensure we start the streaming request - var hasUpdates = await responseUpdatesEnumerator.MoveNextAsync().ConfigureAwait(false); + bool hasUpdates; + try + { + // Ensure we start the streaming request + hasUpdates = await responseUpdatesEnumerator.MoveNextAsync().ConfigureAwait(false); + } + catch (Exception ex) + { + await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, cancellationToken).ConfigureAwait(false); + throw; + } while (hasUpdates) { @@ -176,20 +203,28 @@ public sealed class ChatClientAgent : AIAgent yield return new(update) { AgentId = this.Id }; } - hasUpdates = await responseUpdatesEnumerator.MoveNextAsync().ConfigureAwait(false); + try + { + hasUpdates = await responseUpdatesEnumerator.MoveNextAsync().ConfigureAwait(false); + } + catch (Exception ex) + { + await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, cancellationToken).ConfigureAwait(false); + throw; + } } var chatResponse = responseUpdates.ToChatResponse(); - var chatResponseMessages = chatResponse.Messages as IReadOnlyCollection ?? [.. chatResponse.Messages]; // We can derive the type of supported thread from whether we have a conversation id, // so let's update it and set the conversation id for the service thread case. this.UpdateThreadWithTypeAndConversationId(safeThread, chatResponse.ConversationId); // To avoid inconsistent state we only notify the thread of the input messages if no error occurs after the initial request. - await NotifyThreadOfNewMessagesAsync(safeThread, inputMessages, cancellationToken).ConfigureAwait(false); + await NotifyThreadOfNewMessagesAsync(safeThread, inputMessages.Concat(chatResponse.Messages), cancellationToken).ConfigureAwait(false); - await NotifyThreadOfNewMessagesAsync(safeThread, chatResponseMessages, cancellationToken).ConfigureAwait(false); + // Notify the AIContextProvider of all new messages. + await NotifyAIContextProviderOfSuccessAsync(safeThread, inputMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false); } /// @@ -204,12 +239,40 @@ public sealed class ChatClientAgent : AIAgent /// public override AgentThread GetNewThread() { - var thread = new AgentThread { MessageStore = this._agentOptions?.ChatMessageStoreFactory?.Invoke() }; + var thread = new AgentThread + { + MessageStore = this._agentOptions?.ChatMessageStoreFactory?.Invoke(), + AIContextProvider = this._agentOptions?.AIContextProviderFactory?.Invoke() + }; return thread; } #region Private + /// + /// Notify the when an agent run succeeded, if there is an . + /// + private static async Task NotifyAIContextProviderOfSuccessAsync(AgentThread thread, IEnumerable inputMessages, IEnumerable responseMessages, CancellationToken cancellationToken) + { + if (thread.AIContextProvider is not null) + { + await thread.AIContextProvider.InvokedAsync(new(inputMessages) { ResponseMessages = responseMessages }, + cancellationToken).ConfigureAwait(false); + } + } + + /// + /// Notify the of any failure during an agent run, if there is an . + /// + private static async Task NotifyAIContextProviderOfFailureAsync(AgentThread thread, Exception ex, IEnumerable inputMessages, CancellationToken cancellationToken) + { + if (thread.AIContextProvider is not null) + { + await thread.AIContextProvider.InvokedAsync(new(inputMessages) { InvokeException = ex }, + cancellationToken).ConfigureAwait(false); + } + } + /// /// Configures and returns chat options by merging the provided run options with the agent's default chat options. /// @@ -350,6 +413,34 @@ public sealed class ChatClientAgent : AIAgent threadMessages.AddRange(await thread.MessageStore.GetMessagesAsync(cancellationToken).ConfigureAwait(false)); } + // If we have an AIContextProvider, we should get context from it, and update our + // messages and options with the additional context. + if (thread.AIContextProvider is not null) + { + var invokingContext = new AIContextProvider.InvokingContext(inputMessages); + var aiContext = await thread.AIContextProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false); + if (aiContext.Messages is { Count: > 0 }) + { + threadMessages.AddRange(aiContext.Messages); + } + + if (aiContext.Tools is { Count: > 0 }) + { + chatOptions ??= new(); + chatOptions.Tools ??= []; + foreach (AITool tool in aiContext.Tools) + { + chatOptions.Tools.Add(tool); + } + } + + if (aiContext.Instructions is not null) + { + chatOptions ??= new(); + chatOptions.Instructions = string.IsNullOrWhiteSpace(chatOptions.Instructions) ? aiContext.Instructions : $"{chatOptions.Instructions}\n{aiContext.Instructions}"; + } + } + // Add the input messages to the end of thread messages. threadMessages.AddRange(inputMessages); diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgentOptions.cs b/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgentOptions.cs index 2759523cc2..461ed01aa3 100644 --- a/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgentOptions.cs +++ b/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgentOptions.cs @@ -80,6 +80,13 @@ public class ChatClientAgentOptions /// public Func? ChatMessageStoreFactory { get; set; } + /// + /// Gets or sets a factory function to create an instance of + /// which will be used to create a context provider for each new thread, and can then + /// provide additional context for each agent run. + /// + public Func? AIContextProviderFactory { get; set; } + /// /// Gets or sets a value indicating whether to use the provided instance as is, /// without applying any default decorators. @@ -105,6 +112,7 @@ public class ChatClientAgentOptions Instructions = this.Instructions, Description = this.Description, ChatOptions = this.ChatOptions?.Clone(), - ChatMessageStoreFactory = this.ChatMessageStoreFactory + ChatMessageStoreFactory = this.ChatMessageStoreFactory, + AIContextProviderFactory = this.AIContextProviderFactory, }; } diff --git a/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIAgentTests.cs b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIAgentTests.cs index 04cd8584c8..9906b62a21 100644 --- a/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIAgentTests.cs +++ b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIAgentTests.cs @@ -235,7 +235,7 @@ public class AIAgentTests await MockAgent.NotifyThreadOfNewMessagesAsync(threadMock.Object, messages, cancellationToken); - threadMock.Protected().Verify("OnNewMessagesAsync", Times.Once(), messages, cancellationToken); + threadMock.Protected().Verify("MessagesReceivedAsync", Times.Once(), messages, cancellationToken); } #region GetService Method Tests diff --git a/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIContextProviderTests.cs b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIContextProviderTests.cs new file mode 100644 index 0000000000..f5eea0d5ed --- /dev/null +++ b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIContextProviderTests.cs @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Extensions.AI.Agents.Abstractions.UnitTests; + +public class AIContextProviderTests +{ + [Fact] + public async Task InvokedAsync_ReturnsCompletedTaskAsync() + { + var provider = new TestAIContextProvider(); + var messages = new ReadOnlyCollection(new List()); + var task = provider.InvokedAsync(new(messages)); + Assert.Equal(default, task); + } + + [Fact] + public async Task SerializeAsync_ReturnsEmptyElementAsync() + { + var provider = new TestAIContextProvider(); + var actual = await provider.SerializeAsync(); + Assert.Equal(default, actual); + } + + [Fact] + public async Task DeserializeAsync_ReturnsCompletedTaskAsync() + { + var provider = new TestAIContextProvider(); + var element = default(JsonElement); + var actual = provider.DeserializeAsync(element); + Assert.Equal(default, actual); + } + + private sealed class TestAIContextProvider : AIContextProvider + { + public override ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + { + return default; + } + + public override async ValueTask SerializeAsync(JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + { + return await base.SerializeAsync(jsonSerializerOptions, cancellationToken); + } + + public override async ValueTask DeserializeAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + { + await base.DeserializeAsync(serializedState, jsonSerializerOptions, cancellationToken); + } + } +} diff --git a/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIContextTests.cs b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIContextTests.cs new file mode 100644 index 0000000000..2a85a525e5 --- /dev/null +++ b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIContextTests.cs @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; + +namespace Microsoft.Extensions.AI.Agents.Abstractions.UnitTests; + +/// +/// Unit tests for . +/// +public class AIContextTests +{ + [Fact] + public void SetInstructionsRoundtrips() + { + var context = new AIContext + { + Instructions = "Test Instructions" + }; + + Assert.Equal("Test Instructions", context.Instructions); + } + + [Fact] + public void SetMessagesRoundtrips() + { + var context = new AIContext + { + Messages = new List + { + new(ChatRole.User, "Hello"), + new(ChatRole.Assistant, "Hi there!") + } + }; + + Assert.NotNull(context.Messages); + Assert.Equal(2, context.Messages.Count); + Assert.Equal("Hello", context.Messages[0].Text); + Assert.Equal("Hi there!", context.Messages[1].Text); + } + + [Fact] + public void SetAIFunctionsRoundtrips() + { + var context = new AIContext + { + Tools = new List + { + AIFunctionFactory.Create(() => "Function1", "Function1", "Description1"), + AIFunctionFactory.Create(() => "Function2", "Function2", "Description2"), + } + }; + + Assert.NotNull(context.Tools); + Assert.Equal(2, context.Tools.Count); + Assert.Equal("Function1", context.Tools[0].Name); + Assert.Equal("Function2", context.Tools[1].Name); + } +} diff --git a/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AgentThreadTests.cs b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AgentThreadTests.cs index 305fc1cd45..843ea3fb0f 100644 --- a/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AgentThreadTests.cs +++ b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AgentThreadTests.cs @@ -8,6 +8,8 @@ using System.Threading; using System.Threading.Tasks; using Moq; +#pragma warning disable CA1861 // Avoid constant arrays as arguments + namespace Microsoft.Extensions.AI.Agents.Abstractions.UnitTests; public class AgentThreadTests @@ -102,7 +104,7 @@ public class AgentThreadTests }; // Act - await thread.OnNewMessagesAsync(messages, CancellationToken.None); + await thread.MessagesReceivedAsync(messages, CancellationToken.None); Assert.Equal("thread-123", thread.ConversationId); Assert.Null(thread.MessageStore); } @@ -120,7 +122,7 @@ public class AgentThreadTests }; // Act - await thread.OnNewMessagesAsync(messages, CancellationToken.None); + await thread.MessagesReceivedAsync(messages, CancellationToken.None); // Assert Assert.Equal(2, store.Count); @@ -173,6 +175,26 @@ public class AgentThreadTests Assert.Null(thread.MessageStore); } + [Fact] + public async Task VerifyDeserializeWithAIContextProviderAsync() + { + // Arrange + var json = JsonSerializer.Deserialize(""" + { + "aiContextProviderState": ["CP1"] + } + """, TestJsonSerializerContext.Default.JsonElement); + Mock mockProvider = new(); + var thread = new AgentThread() { AIContextProvider = mockProvider.Object }; + + // Act + await thread.DeserializeAsync(json); + + // Assert + Assert.Null(thread.MessageStore); + mockProvider.Verify(m => m.DeserializeAsync(It.Is(e => e.ValueKind == JsonValueKind.Array && e.GetArrayLength() == 1), It.IsAny(), It.IsAny()), Times.Once); + } + [Fact] public async Task DeserializeWithInvalidJsonThrowsAsync() { @@ -245,6 +267,31 @@ public class AgentThreadTests Assert.Equal("TestContent", textContent.GetProperty("text").GetString()); } + [Fact] + public async Task VerifyThreadSerializationWithWithAIContextProviderAsync() + { + // Arrange + Mock mockProvider = new(); + var providerStateElement = JsonSerializer.SerializeToElement(new[] { "CP1" }, TestJsonSerializerContext.Default.StringArray); + mockProvider + .Setup(m => m.SerializeAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(providerStateElement); + + var thread = new AgentThread(); + thread.AIContextProvider = mockProvider.Object; + + // Act + var json = await thread.SerializeAsync(); + + // Assert + Assert.Equal(JsonValueKind.Object, json.ValueKind); + Assert.True(json.TryGetProperty("aiContextProviderState", out var providerStateProperty)); + Assert.Equal(JsonValueKind.Array, providerStateProperty.ValueKind); + Assert.Single(providerStateProperty.EnumerateArray()); + Assert.Equal("CP1", providerStateProperty.EnumerateArray().First().GetString()); + mockProvider.Verify(m => m.SerializeAsync(It.IsAny(), It.IsAny()), Times.Once); + } + /// /// Verify thread serialization to JSON with custom options. /// diff --git a/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/TestJsonSerializerContext.cs b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/TestJsonSerializerContext.cs index bd39c6638a..f256ad311c 100644 --- a/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/TestJsonSerializerContext.cs +++ b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/TestJsonSerializerContext.cs @@ -17,4 +17,5 @@ namespace Microsoft.Extensions.AI.Agents.Abstractions.UnitTests; [JsonSerializable(typeof(Animal))] [JsonSerializable(typeof(JsonElement))] [JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(string[]))] internal sealed partial class TestJsonSerializerContext : JsonSerializerContext; diff --git a/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/ChatCompletion/ChatClientAgentOptionsTests.cs b/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/ChatCompletion/ChatClientAgentOptionsTests.cs index ead3946fc2..d8dd5d52fc 100644 --- a/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/ChatCompletion/ChatClientAgentOptionsTests.cs +++ b/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/ChatCompletion/ChatClientAgentOptionsTests.cs @@ -22,6 +22,7 @@ public class ChatClientAgentOptionsTests Assert.Null(options.Description); Assert.Null(options.ChatOptions); Assert.Null(options.ChatMessageStoreFactory); + Assert.Null(options.AIContextProviderFactory); } [Fact] @@ -39,6 +40,7 @@ public class ChatClientAgentOptionsTests Assert.Null(options.Instructions); Assert.Null(options.Description); Assert.Null(options.ChatOptions); + Assert.Null(options.AIContextProviderFactory); } [Fact] @@ -163,11 +165,13 @@ public class ChatClientAgentOptionsTests const string Description = "Test description"; var tools = new List { AIFunctionFactory.Create(() => "test") }; static IChatMessageStore ChatMessageStoreFactory() => new Mock().Object; + static AIContextProvider AIContextProviderFactory() => new Mock().Object; var original = new ChatClientAgentOptions(Instructions, Name, Description, tools) { Id = "test-id", - ChatMessageStoreFactory = ChatMessageStoreFactory + ChatMessageStoreFactory = ChatMessageStoreFactory, + AIContextProviderFactory = AIContextProviderFactory }; // Act @@ -180,6 +184,7 @@ public class ChatClientAgentOptionsTests Assert.Equal(original.Instructions, clone.Instructions); Assert.Equal(original.Description, clone.Description); Assert.Same(original.ChatMessageStoreFactory, clone.ChatMessageStoreFactory); + Assert.Same(original.AIContextProviderFactory, clone.AIContextProviderFactory); // ChatOptions should be cloned, not the same reference Assert.NotSame(original.ChatOptions, clone.ChatOptions); @@ -209,5 +214,7 @@ public class ChatClientAgentOptionsTests Assert.Equal(original.Instructions, clone.Instructions); Assert.Equal(original.Description, clone.Description); Assert.Null(clone.ChatOptions); + Assert.Null(clone.ChatMessageStoreFactory); + Assert.Null(clone.AIContextProviderFactory); } } diff --git a/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/ChatCompletion/ChatClientAgentTests.cs b/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/ChatCompletion/ChatClientAgentTests.cs index 46afaace36..4ec63e04cd 100644 --- a/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/ChatCompletion/ChatClientAgentTests.cs +++ b/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/ChatCompletion/ChatClientAgentTests.cs @@ -10,6 +10,8 @@ namespace Microsoft.Extensions.AI.Agents.UnitTests.ChatCompletion; public class ChatClientAgentTests { + #region Constructor Tests + /// /// Verify the invocation and response of . /// @@ -38,6 +40,10 @@ public class ChatClientAgentTests Assert.Equal("AgentInvokedChatClient", agent.ChatClient.GetType().Name); } + #endregion + + #region RunAsync Tests + /// /// Verify the invocation and response of using . /// @@ -390,6 +396,176 @@ public class ChatClientAgentTests await Assert.ThrowsAsync(() => agent.RunAsync([new(ChatRole.User, "test")], thread)); } + /// + /// Verify that RunAsync sets the ConversationId on the thread when the service returns one. + /// + [Fact] + public async Task RunAsyncSetsConversationIdOnThreadWhenReturnedByChatClientAsync() + { + // Arrange + Mock mockService = new(); + mockService.Setup( + s => s.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")]) { ConversationId = "ConvId" }); + ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "test instructions" }); + AgentThread thread = new(); + + // Act + await agent.RunAsync([new(ChatRole.User, "test")], thread); + + // Assert + Assert.Equal("ConvId", thread.ConversationId); + } + + /// + /// Verify that RunAsync invokes any provided AIContextProvider and uses the result. + /// + [Fact] + public async Task RunAsyncInvokesAIContextProviderAndUsesResultAsync() + { + // Arrange + ChatMessage[] requestMessages = [new(ChatRole.User, "user message")]; + ChatMessage[] responseMessages = [new(ChatRole.Assistant, "response")]; + Mock mockService = new(); + List capturedMessages = []; + string capturedInstructions = string.Empty; + List capturedTools = []; + mockService + .Setup(s => s.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) => + { + capturedMessages.AddRange(msgs); + capturedInstructions = opts.Instructions ?? string.Empty; + if (opts.Tools != null) + { + capturedTools.AddRange(opts.Tools); + } + }) + .ReturnsAsync(new ChatResponse(responseMessages)); + + var mockProvider = new Mock(); + mockProvider + .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new AIContext + { + Messages = [new(ChatRole.System, "context provider message")], + Instructions = "context provider instructions", + Tools = [AIFunctionFactory.Create(() => { }, "context provider function")] + }); + mockProvider + .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) + .Returns(new ValueTask()); + + ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "base instructions", AIContextProviderFactory = () => mockProvider.Object, ChatOptions = new() { Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); + + // Act + await agent.RunAsync(requestMessages); + + // Assert + // Should contain: base instructions, context message, user message, base function, context function + Assert.Equal(2, capturedMessages.Count); + Assert.Equal("base instructions\ncontext provider instructions", capturedInstructions); + Assert.Equal("context provider message", capturedMessages[0].Text); + Assert.Equal(ChatRole.System, capturedMessages[0].Role); + Assert.Equal("user message", capturedMessages[1].Text); + Assert.Equal(ChatRole.User, capturedMessages[1].Role); + Assert.Equal(2, capturedTools.Count); + Assert.Contains(capturedTools, t => t.Name == "base function"); + Assert.Contains(capturedTools, t => t.Name == "context provider function"); + mockProvider.Verify(p => p.InvokingAsync(It.IsAny(), It.IsAny()), Times.Once); + mockProvider.Verify(p => p.InvokedAsync(It.Is(x => x.RequestMessages == requestMessages && x.ResponseMessages == responseMessages && x.InvokeException == null), It.IsAny()), Times.Once); + } + + /// + /// Verify that RunAsync invokes any provided AIContextProvider when the downstream GetResponse call fails. + /// + [Fact] + public async Task RunAsyncInvokesAIContextProviderWhenGetResponseFailsAsync() + { + // Arrange + ChatMessage[] requestMessages = [new(ChatRole.User, "user message")]; + ChatMessage[] responseMessages = [new(ChatRole.Assistant, "response")]; + Mock mockService = new(); + mockService + .Setup(s => s.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Throws(new InvalidOperationException("downstream failure")); + + var mockProvider = new Mock(); + mockProvider + .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new AIContext()); + mockProvider + .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) + .Returns(new ValueTask()); + + ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "base instructions", AIContextProviderFactory = () => mockProvider.Object, ChatOptions = new() { Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); + + // Act + await Assert.ThrowsAsync(() => agent.RunAsync(requestMessages)); + + // Assert + mockProvider.Verify(p => p.InvokingAsync(It.IsAny(), It.IsAny()), Times.Once); + mockProvider.Verify(p => p.InvokedAsync(It.Is(x => x.RequestMessages == requestMessages && x.ResponseMessages == null && x.InvokeException is InvalidOperationException), It.IsAny()), Times.Once); + } + + /// + /// Verify that RunAsync invokes any provided AIContextProvider and succeeds even when the AIContext is empty. + /// + [Fact] + public async Task RunAsyncInvokesAIContextProviderAndSucceedsWithEmptyAIContextAsync() + { + // Arrange + Mock mockService = new(); + List capturedMessages = []; + string capturedInstructions = string.Empty; + List capturedTools = []; + mockService + .Setup(s => s.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) => + { + capturedMessages.AddRange(msgs); + capturedInstructions = opts.Instructions ?? string.Empty; + if (opts.Tools != null) + { + capturedTools.AddRange(opts.Tools); + } + }) + .ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); + + var mockProvider = new Mock(); + mockProvider + .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new AIContext()); + + ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "base instructions", AIContextProviderFactory = () => mockProvider.Object, ChatOptions = new() { Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); + + // Act + await agent.RunAsync([new(ChatRole.User, "user message")]); + + // Assert + // Should contain: base instructions, user message, base function + Assert.Single(capturedMessages); + Assert.Equal("base instructions", capturedInstructions); + Assert.Equal("user message", capturedMessages[0].Text); + Assert.Equal(ChatRole.User, capturedMessages[0].Role); + Assert.Single(capturedTools); + Assert.Contains(capturedTools, t => t.Name == "base function"); + mockProvider.Verify(p => p.InvokingAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + #endregion + #region Property Override Tests /// @@ -1524,6 +1700,61 @@ public class ChatClientAgentTests #endregion + #region GetNewThread Tests + + [Fact] + public void GetNewThreadUsesChatMessageStoreFactoryIfProvided() + { + // Arrange + var mockChatClient = new Mock(); + var mockStore = new Mock(); + var factoryCalled = false; + + var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions + { + Instructions = "Test instructions", + ChatMessageStoreFactory = () => + { + factoryCalled = true; + return mockStore.Object; + } + }); + + // Act + var thread = agent.GetNewThread(); + + // Assert + Assert.True(factoryCalled, "ChatMessageStoreFactory was not called."); + Assert.Same(mockStore.Object, thread.MessageStore); + } + + [Fact] + public void GetNewThreadUsesAIContextProviderFactoryIfProvided() + { + // Arrange + var mockChatClient = new Mock(); + var mockContextProvider = new Mock(); + var factoryCalled = false; + var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions + { + Instructions = "Test instructions", + AIContextProviderFactory = () => + { + factoryCalled = true; + return mockContextProvider.Object; + } + }); + + // Act + var thread = agent.GetNewThread(); + + // Assert + Assert.True(factoryCalled, "AIContextProviderFactory was not called."); + Assert.Same(mockContextProvider.Object, thread.AIContextProvider); + } + + #endregion + private static async IAsyncEnumerable ToAsyncEnumerableAsync(IEnumerable values) { await Task.Yield(); diff --git a/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/TestJsonSerializerContext.cs b/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/TestJsonSerializerContext.cs new file mode 100644 index 0000000000..d3d1f00118 --- /dev/null +++ b/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/TestJsonSerializerContext.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Microsoft.Extensions.AI.Agents.UnitTests; + +[JsonSourceGenerationOptions( + PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + UseStringEnumConverter = true)] +[JsonSerializable(typeof(JsonElement))] +[JsonSerializable(typeof(string))] +internal sealed partial class TestJsonSerializerContext : JsonSerializerContext;