.NET: Add AIContextProvider support (#691)

* Add AIContextProvider support

* Address feedback.

* Address PR comments.

* Switch to valuetask and remove parallel calls for AIContextProvider

* Remove Model from ModelInvokingAsync method name

* Remove agent thread id again and remove it from context provider interface

* Add AIContextProvider serialization support to AgentThread and update sample to show this feature

* Address PR comments

* Improve memory sample

* Update sample comment.

* Remove AggregateAIContextProvider for now since it makes too many assumptions.  We can include it later as a sample if needed.

* Update AIContextProviders to have an Invoked method instead of MessagesAddingAsync.

* Remove unused using.

* Address PR comments.

* Address PR comment.

* Update comment.

* Update comment

* Address PR comments.
This commit is contained in:
westey
2025-09-16 10:54:18 +01:00
committed by GitHub
Unverified
parent 89cb94b5c2
commit 66fe1c957c
25 changed files with 981 additions and 32 deletions
+9
View File
@@ -50,6 +50,7 @@
<Project Path="samples/GettingStarted/Agents/Agent_Step10_AsMcpTool/Agent_Step10_AsMcpTool.csproj" />
<Project Path="samples/GettingStarted/Agents/Agent_Step11_UsingImages/Agent_Step11_UsingImages.csproj" />
<Project Path="samples/GettingStarted/Agents/Agent_Step12_AsFunctionTool/Agent_Step12_AsFunctionTool.csproj" />
<Project Path="samples/GettingStarted/Agents/Agent_Step13_Memory/Agent_Step13_Memory.csproj" />
</Folder>
<Folder Name="/Samples/GettingStarted/AgentWithOpenAI/">
<File Path="samples/GettingStarted/AgentWithOpenAI/README.md" />
@@ -190,6 +191,10 @@
<File Path="src/LegacySupport/CallerAttributes/CallerArgumentExpressionAttribute.cs" />
<File Path="src/LegacySupport/CallerAttributes/README.md" />
</Folder>
<Folder Name="/Solution Items/src/LegacySupport/CompilerFeatureRequiredAttribute/">
<File Path="src/LegacySupport/CompilerFeatureRequiredAttribute/CompilerFeatureRequiredAttribute.cs" />
<File Path="src/LegacySupport/CompilerFeatureRequiredAttribute/README.md" />
</Folder>
<Folder Name="/Solution Items/src/LegacySupport/DiagnosticAttributes/">
<File Path="src/LegacySupport/DiagnosticAttributes/NullableAttributes.cs" />
<File Path="src/LegacySupport/DiagnosticAttributes/README.md" />
@@ -206,6 +211,10 @@
<File Path="src/LegacySupport/IsExternalInit/IsExternalInit.cs" />
<File Path="src/LegacySupport/IsExternalInit/README.md" />
</Folder>
<Folder Name="/Solution Items/src/LegacySupport/RequiredMemberAttribute/">
<File Path="src/LegacySupport/RequiredMemberAttribute/README.md" />
<File Path="src/LegacySupport/RequiredMemberAttribute/RequiredMemberAttribute.cs" />
</Folder>
<Folder Name="/Solution Items/src/LegacySupport/TrimAttributes/">
<File Path="src/LegacySupport/TrimAttributes/DynamicallyAccessedMembersAttribute.cs" />
<File Path="src/LegacySupport/TrimAttributes/DynamicallyAccessedMemberTypes.cs" />
+8
View File
@@ -22,4 +22,12 @@
<ItemGroup Condition="'$(InjectTrimAttributesOnLegacy)' == 'true' AND !$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net8.0'))">
<Compile Include="$(MSBuildThisFileDirectory)\..\..\src\LegacySupport\TrimAttributes\*.cs" LinkBase="LegacySupport\TrimAttributes" />
</ItemGroup>
<ItemGroup Condition="'$(InjectRequiredMemberOnLegacy)' == 'true' AND !$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net8.0'))">
<Compile Include="$(MSBuildThisFileDirectory)\..\..\src\LegacySupport\RequiredMemberAttribute\*.cs" LinkBase="LegacySupport\RequiredMemberAttribute" />
</ItemGroup>
<ItemGroup Condition="'$(InjectCompilerFeatureRequiredOnLegacy)' == 'true' AND !$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net8.0'))">
<Compile Include="$(MSBuildThisFileDirectory)\..\..\src\LegacySupport\CompilerFeatureRequiredAttribute\*.cs" LinkBase="LegacySupport\CompilerFeatureRequiredAttribute" />
</ItemGroup>
</Project>
@@ -39,8 +39,7 @@ namespace SampleApp
List<ChatMessage> responseMessages = CloneAndToUpperCase(messages, this.DisplayName).ToList();
// Notify the thread of the input and output messages.
await NotifyThreadOfNewMessagesAsync(thread, messages, cancellationToken);
await NotifyThreadOfNewMessagesAsync(thread, responseMessages, cancellationToken);
await NotifyThreadOfNewMessagesAsync(thread, messages.Concat(responseMessages), cancellationToken);
return new AgentRunResponse
{
@@ -59,8 +58,7 @@ namespace SampleApp
List<ChatMessage> responseMessages = CloneAndToUpperCase(messages, this.DisplayName).ToList();
// Notify the thread of the input and output messages.
await NotifyThreadOfNewMessagesAsync(thread, messages, cancellationToken);
await NotifyThreadOfNewMessagesAsync(thread, responseMessages, cancellationToken);
await NotifyThreadOfNewMessagesAsync(thread, messages.Concat(responseMessages), cancellationToken);
foreach (var message in responseMessages)
{
@@ -0,0 +1,22 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net9.0</TargetFramework>
<Nullable>enable</Nullable>
<ImplicitUsings>disable</ImplicitUsings>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Azure.AI.OpenAI" />
<PackageReference Include="Azure.Identity" />
<PackageReference Include="Microsoft.Extensions.AI.OpenAI" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\..\..\src\Microsoft.Extensions.AI.Agents.OpenAI\Microsoft.Extensions.AI.Agents.OpenAI.csproj" />
<ProjectReference Include="..\..\..\..\src\Microsoft.Extensions.AI.Agents\Microsoft.Extensions.AI.Agents.csproj" />
</ItemGroup>
</Project>
@@ -0,0 +1,152 @@
// Copyright (c) Microsoft. All rights reserved.
// This sample shows how to add a basic custom memory component to an agent.
// The memory component subscribes to all messages added to the conversation and
// extracts the user's name and age if provided.
// The component adds a prompt to ask for this information if it is not already known
// and provides it to the model before each invocation if known.
using System;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Azure.AI.OpenAI;
using Azure.Identity;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.AI.Agents;
using OpenAI;
using OpenAI.Chat;
using SampleApp;
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new InvalidOperationException("AZURE_OPENAI_ENDPOINT is not set.");
var deploymentName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOYMENT_NAME") ?? "gpt-4o-mini";
ChatClient chatClient = new AzureOpenAIClient(
new Uri(endpoint),
new AzureCliCredential())
.GetChatClient(deploymentName);
// Create the agent and provide a factory to add our custom memory component to
// all threads created by the agent. Here each new memory component will have its own
// user info object, so each thread will have its own memory.
// In real world applications/services, where the user info would be persisted in a database,
// and preferably shared between multiple threads used by the same user, ensure that the
// factory reads the user id from the current context and scopes the memory component
// and its storage to that user id.
AIAgent agent = chatClient.CreateAIAgent(new ChatClientAgentOptions()
{
Instructions = "You are a friendly assistant. Always address the user by their name.",
AIContextProviderFactory = () => new SampleApp.UserInfoMemory(chatClient.AsIChatClient())
});
// Create a new thread for the conversation.
AgentThread thread = agent.GetNewThread();
Console.WriteLine(">> Use thread with blank memory\n");
// Invoke the agent and output the text result.
Console.WriteLine(await agent.RunAsync("Hello, what is the square root of 9?", thread));
Console.WriteLine(await agent.RunAsync("My name is Ruaidhrí", thread));
Console.WriteLine(await agent.RunAsync("I am 20 years old", thread));
// We can serialize the thread. The serialized state will include the state of the memory component.
var threadElement = await thread.SerializeAsync();
Console.WriteLine("\n>> Use deserialized thread with previously created memories\n");
// Later we can deserialize the thread and continue the conversation with the previous memory component state.
var deserializedThread = await agent.DeserializeThreadAsync(threadElement);
Console.WriteLine(await agent.RunAsync("What is my name and age?", deserializedThread));
Console.WriteLine("\n>> Read memories from memory component\n");
// It's possible to access the memory component via the thread's AIContextProvider property.
var userInfo = ((UserInfoMemory)deserializedThread.AIContextProvider!).UserInfo;
// Output the user info that was captured by the memory component.
Console.WriteLine($"MEMORY - User Name: {userInfo.UserName}");
Console.WriteLine($"MEMORY - User Age: {userInfo.UserAge}");
Console.WriteLine("\n>> Use new thread with previously created memories\n");
// Create a new thread.
thread = agent.GetNewThread();
// It is also possible to add the memory component to an individual thread only instead of all
// threads via the factory above.
// In this case we will also use the same user info object, so this thread will share the same
// memories as the previous thread.
thread.AIContextProvider = new UserInfoMemory(chatClient.AsIChatClient(), userInfo);
// Invoke the agent and output the text result.
// This time the agent should remember the user's name and use it in the response.
Console.WriteLine(await agent.RunAsync("What is my name and age?", thread));
namespace SampleApp
{
/// <summary>
/// Sample memory component that can remember a user's name and age.
/// </summary>
internal sealed class UserInfoMemory(IChatClient chatClient, UserInfo? userInfo = null) : AIContextProvider
{
public UserInfo UserInfo { get; set; } = userInfo ?? new();
public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default)
{
// Try and extract the user name and age from the message if we don't have it already and it's a user message.
if ((this.UserInfo.UserName == null || this.UserInfo.UserAge == null) && context.RequestMessages.Any(x => x.Role == ChatRole.User))
{
var result = await chatClient.GetResponseAsync<UserInfo>(
context.RequestMessages,
new ChatOptions()
{
Instructions = "Extract the user's name and age from the message if present. If not present return nulls."
},
cancellationToken: cancellationToken);
this.UserInfo.UserName ??= result.Result.UserName;
this.UserInfo.UserAge ??= result.Result.UserAge;
}
}
public override ValueTask<AIContext> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
StringBuilder instructions = new();
// If we don't already know the user's name and age, add instructions to ask for them, otherwise just provide what we have to the context.
instructions.AppendLine(
this.UserInfo.UserName == null ?
"Ask the user for their name and politely decline to answer any questions until they provide it." :
$"The user's name is {this.UserInfo.UserName}.");
instructions.AppendLine(
this.UserInfo.UserAge == null ?
"Ask the user for their age and politely decline to answer any questions until they provide it." :
$"The user's age is {this.UserInfo.UserAge}.");
return new ValueTask<AIContext>(new AIContext
{
Instructions = instructions.ToString()
});
}
public override ValueTask<JsonElement?> SerializeAsync(JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
{
return new ValueTask<JsonElement?>(JsonSerializer.SerializeToElement(this.UserInfo, jsonSerializerOptions));
}
public override ValueTask DeserializeAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
{
this.UserInfo = JsonSerializer.Deserialize<UserInfo>(serializedState, jsonSerializerOptions) ?? new UserInfo();
return default;
}
}
internal sealed class UserInfo
{
public string? UserName { get; set; }
public int? UserAge { get; set; }
}
}
@@ -37,7 +37,8 @@ Before you begin, ensure you have the following prerequisites:
|[Dependency injection with a simple agent](./Agent_Step09_DependencyInjection/)|This sample demonstrates how to add and resolve an agent with a dependency injection container|
|[Exposing a simple agent as MCP tool](./Agent_Step10_AsMcpTool/)|This sample demonstrates how to expose an agent as an MCP tool|
|[Using images with a simple agent](./Agent_Step11_UsingImages/)|This sample demonstrates how to use image multi-modality with an AI agent|
|[Exposing a simple agent a function tool](./Agent_Step12_AsFunctionTool/)|This sample demonstrates how to expose an agent as a function tool|
|[Exposing a simple agent as a function tool](./Agent_Step12_AsFunctionTool/)|This sample demonstrates how to expose an agent as a function tool|
|[Using memory with an agent](./Agent_Step12_Memory/)|This sample demonstrates how to create a simple memory component and use it with an agent|
## Running the samples from the console
@@ -0,0 +1,37 @@
// Copyright (c) Microsoft. All rights reserved.
#pragma warning disable SA1623 // Property summary documentation should match accessors
namespace System.Runtime.CompilerServices;
/// <summary>
/// Indicates that compiler support for a particular feature is required for the location where this attribute is applied.
/// </summary>
[AttributeUsage(AttributeTargets.All, AllowMultiple = true, Inherited = false)]
internal sealed class CompilerFeatureRequiredAttribute : Attribute
{
public CompilerFeatureRequiredAttribute(string featureName)
{
FeatureName = featureName;
}
/// <summary>
/// The name of the compiler feature.
/// </summary>
public string FeatureName { get; }
/// <summary>
/// If true, the compiler can choose to allow access to the location where this attribute is applied if it does not understand <see cref="FeatureName"/>.
/// </summary>
public bool IsOptional { get; init; }
/// <summary>
/// The <see cref="FeatureName"/> used for the ref structs C# feature.
/// </summary>
public const string RefStructs = nameof(RefStructs);
/// <summary>
/// The <see cref="FeatureName"/> used for the required members C# feature.
/// </summary>
public const string RequiredMembers = nameof(RequiredMembers);
}
@@ -0,0 +1,9 @@
Enables use of C# required members on older frameworks.
To use this source in your project, add the following to your `.csproj` file:
```xml
<PropertyGroup>
<InjectCompilerFeatureRequiredOnLegacy>true</InjectCompilerFeatureRequiredOnLegacy>
</PropertyGroup>
```
@@ -0,0 +1,9 @@
Enables use of C# required members on older frameworks.
To use this source in your project, add the following to your `.csproj` file:
```xml
<PropertyGroup>
<InjectRequiredMemberOnLegacy>true</InjectRequiredMemberOnLegacy>
</PropertyGroup>
```
@@ -0,0 +1,10 @@
// Copyright (c) Microsoft. All rights reserved.
using System.ComponentModel;
namespace System.Runtime.CompilerServices;
/// <summary>Specifies that a type has required members or that a member is required.</summary>
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = false)]
[EditorBrowsable(EditorBrowsableState.Never)]
internal sealed class RequiredMemberAttribute : Attribute;
@@ -290,6 +290,6 @@ public abstract class AIAgent
_ = Throw.IfNull(thread);
_ = Throw.IfNull(messages);
await thread.OnNewMessagesAsync(messages, cancellationToken).ConfigureAwait(false);
await thread.MessagesReceivedAsync(messages, cancellationToken).ConfigureAwait(false);
}
}
@@ -0,0 +1,43 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
namespace Microsoft.Extensions.AI.Agents;
/// <summary>
/// A class containing any context that should be provided to the AI model
/// as supplied by an <see cref="AIContextProvider"/>.
/// </summary>
/// <remarks>
/// Each <see cref="AIContextProvider"/> has the ability to provide its own context for each invocation.
/// The <see cref="AIContext"/> class contains the additional context supplied by the <see cref="AIContextProvider"/>.
/// This context will be combined with context supplied by other providers before being passed to the AI model.
/// </remarks>
public sealed class AIContext
{
/// <summary>
/// Gets or sets any instructions to pass to the AI model in addition to any other prompts
/// that it may already have (in the case of an agent), or chat history that may
/// already exist.
/// </summary>
/// <remarks>
/// These instructions will be transient and only apply to the current invocation.
/// </remarks>
public string? Instructions { get; set; }
/// <summary>
/// Gets or sets a list of messages to add to the chat history.
/// </summary>
/// <remarks>
/// These messages will permanently be added to the chat history.
/// </remarks>
public IList<ChatMessage>? Messages { get; set; }
/// <summary>
/// Gets or sets a list of functions/tools to make available to the AI model for the current invocation.
/// </summary>
/// <remarks>
/// These functions/tools will be transient and only apply to the current invocation.
/// </remarks>
public IList<AITool>? Tools { get; set; }
}
@@ -0,0 +1,118 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.Extensions.AI.Agents;
/// <summary>
/// Base class for all AI context providers.
/// </summary>
/// <remarks>
/// An AI context provider is a component that can be used to enhance the AI's context management.
/// It can listen to changes in the conversation, provide additional context to
/// the Model/Agent/etc. just before invocation and supply additional function tools.
/// </remarks>
public abstract class AIContextProvider
{
/// <summary>
/// Called just before the Model/Agent/etc. is invoked
/// Implementers can load any additional context required at this time,
/// and they should return any context that should be passed to the Model/Agent/etc.
/// </summary>
/// <param name="context">Contains the event context.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A task that completes when the context has been rendered and returned.</returns>
public abstract ValueTask<AIContext> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default);
/// <summary>
/// Called just before the Model/Agent/etc. is invoked
/// Implementers can load any additional context required at this time,
/// and they should return any context that should be passed to the Model/Agent/etc.
/// </summary>
/// <param name="context">Contains the event context.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A task that completes when the context has been rendered and returned.</returns>
public virtual ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default)
{
return default;
}
/// <summary>
/// Serializes the current object's state to a <see cref="JsonElement"/> using the specified serialization options.
/// </summary>
/// <param name="jsonSerializerOptions">The JSON serialization options to use.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A <see cref="JsonElement"/> representation of the object's state.</returns>
public virtual ValueTask<JsonElement?> SerializeAsync(JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
{
return default;
}
/// <summary>
/// Deserializes the state contained in the provided <see cref="JsonElement"/> into the properties on this object.
/// </summary>
/// <param name="serializedState">A <see cref="JsonElement"/> representing the state of the object.</param>
/// <param name="jsonSerializerOptions">Optional settings for customizing the JSON deserialization process.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A <see cref="ValueTask"/> that completes when the state has been deserialized.</returns>
public virtual ValueTask DeserializeAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
{
return default;
}
/// <summary>
/// Contains the event context provided to <see cref="AIContextProvider.InvokingAsync(InvokingContext, CancellationToken)"/>.
/// </summary>
public class InvokingContext
{
/// <summary>
/// Initializes a new instance of the <see cref="InvokingContext"/> class.
/// </summary>
/// <param name="requestMessages">The messages to be sent to the Model/Agent/etc. for this invocation.</param>
/// <exception cref="ArgumentNullException">Thrown if <paramref name="requestMessages"/> is <see langword="null"/>.</exception>
public InvokingContext(IEnumerable<ChatMessage> requestMessages)
{
RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages));
}
/// <summary>
/// Gets the messages that will be sent to the Model/Agent/etc. for this invocation.
/// </summary>
public IEnumerable<ChatMessage> RequestMessages { get; private set; }
}
/// <summary>
/// Contains the event conext provided to <see cref="AIContextProvider.InvokedAsync(InvokedContext, CancellationToken)"/>.
/// </summary>
public class InvokedContext
{
/// <summary>
/// Initializes a new instance of the <see cref="InvokedContext"/> class with the specified request messages.
/// </summary>
/// <param name="requestMessages">The messages that were sent to the Model/Agent/etc. for this invocation.</param>
/// <exception cref="ArgumentNullException">Thrown if <paramref name="requestMessages"/> is <see langword="null"/>.</exception>
public InvokedContext(IEnumerable<ChatMessage> requestMessages)
{
RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages));
}
/// <summary>
/// Gets the messages that were sent to the Model/Agent/etc. for this invocation.
/// </summary>
public IEnumerable<ChatMessage> RequestMessages { get; private set; }
/// <summary>
/// Gets the collection of response messages generated by Model/Agent/etc. if the invocation succeeded.
/// </summary>
public IEnumerable<ChatMessage>? ResponseMessages { get; init; }
/// <summary>
/// Gets the <see cref="Exception"/> that was thrown during the invocation, if the invocation failed.
/// </summary>
public Exception? InvokeException { get; init; }
}
}
@@ -27,7 +27,7 @@ public class AgentThread
}
/// <summary>
/// Gets or sets the id of the current thread to support cases where the thread is owned by the agent service.
/// Gets or sets the ID of the underlying service thread to support cases where the chat history is stored by the agent service.
/// </summary>
/// <remarks>
/// <para>
@@ -108,6 +108,11 @@ public class AgentThread
}
}
/// <summary>
/// Gets or sets the <see cref="AIContextProvider"/> used by this thread to provide additional context to the AI model before each invocation.
/// </summary>
public AIContextProvider? AIContextProvider { get; set; }
/// <summary>
/// Serializes the current object's state to a <see cref="JsonElement"/> using the specified serialization options.
/// </summary>
@@ -120,10 +125,15 @@ public class AgentThread
null :
await this._messageStore.SerializeStateAsync(jsonSerializerOptions, cancellationToken).ConfigureAwait(false);
var aiContextProviderState = this.AIContextProvider is null ?
null :
await this.AIContextProvider.SerializeAsync(jsonSerializerOptions, cancellationToken).ConfigureAwait(false);
var state = new ThreadState
{
ConversationId = this.ConversationId,
StoreState = storeState
StoreState = storeState,
AIContextProviderState = aiContextProviderState
};
return JsonSerializer.SerializeToElement(state, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ThreadState)));
@@ -139,7 +149,7 @@ public class AgentThread
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A task that completes when the context has been updated.</returns>
/// <exception cref="InvalidOperationException">The thread has been deleted.</exception>
protected internal virtual async Task OnNewMessagesAsync(IEnumerable<ChatMessage> newMessages, CancellationToken cancellationToken = default)
protected internal virtual async Task MessagesReceivedAsync(IEnumerable<ChatMessage> newMessages, CancellationToken cancellationToken = default)
{
switch (this)
{
@@ -186,6 +196,11 @@ public class AgentThread
return;
}
if (state?.AIContextProviderState.HasValue is true && this.AIContextProvider is not null)
{
await this.AIContextProvider.DeserializeAsync(state.AIContextProviderState.Value, jsonSerializerOptions, cancellationToken).ConfigureAwait(false);
}
// If we don't have any IChatMessageStore state return here.
if (state?.StoreState is null || state?.StoreState.Value.ValueKind is JsonValueKind.Undefined or JsonValueKind.Null)
{
@@ -206,5 +221,7 @@ public class AgentThread
public string? ConversationId { get; set; }
public JsonElement? StoreState { get; set; }
public JsonElement? AIContextProviderState { get; set; }
}
}
@@ -12,6 +12,9 @@
<InjectSharedThrow>true</InjectSharedThrow>
<InjectDiagnosticClassesOnLegacy>true</InjectDiagnosticClassesOnLegacy>
<InjectTrimAttributesOnLegacy>true</InjectTrimAttributesOnLegacy>
<InjectIsExternalInitOnLegacy>true</InjectIsExternalInitOnLegacy>
<InjectRequiredMemberOnLegacy>true</InjectRequiredMemberOnLegacy>
<InjectCompilerFeatureRequiredOnLegacy>true</InjectCompilerFeatureRequiredOnLegacy>
</PropertyGroup>
<Import Project="$(RepoRoot)/dotnet/nuget/nuget-package.props" />
@@ -114,7 +114,17 @@ public sealed class ChatClientAgent : AIAgent
this._logger.LogAgentChatClientInvokingAgent(nameof(RunAsync), this.Id, agentName, this._chatClientType);
ChatResponse chatResponse = await this.ChatClient.GetResponseAsync(threadMessages, chatOptions, cancellationToken).ConfigureAwait(false);
// Call the IChatClient and notify the AIContextProvider of any failures.
ChatResponse chatResponse;
try
{
chatResponse = await this.ChatClient.GetResponseAsync(threadMessages, chatOptions, cancellationToken).ConfigureAwait(false);
}
catch (Exception ex)
{
await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, cancellationToken).ConfigureAwait(false);
throw;
}
this._logger.LogAgentChatClientInvokedAgent(nameof(RunAsync), this.Id, agentName, this._chatClientType, inputMessages.Count);
@@ -122,19 +132,17 @@ public sealed class ChatClientAgent : AIAgent
// so let's update it and set the conversation id for the service thread case.
this.UpdateThreadWithTypeAndConversationId(safeThread, chatResponse.ConversationId);
// Only notify the thread of new messages if the chatResponse was successful to avoid inconsistent messages state in the thread.
await NotifyThreadOfNewMessagesAsync(safeThread, inputMessages, cancellationToken).ConfigureAwait(false);
// Ensure that the author name is set for each message in the response.
foreach (ChatMessage chatResponseMessage in chatResponse.Messages)
{
chatResponseMessage.AuthorName ??= agentName;
}
// Convert the chat response messages to a valid IReadOnlyCollection for notification signatures below.
var chatResponseMessages = chatResponse.Messages as IReadOnlyCollection<ChatMessage> ?? [.. chatResponse.Messages];
// Only notify the thread of new messages if the chatResponse was successful to avoid inconsistent message state in the thread.
await NotifyThreadOfNewMessagesAsync(safeThread, inputMessages.Concat(chatResponse.Messages), cancellationToken).ConfigureAwait(false);
await NotifyThreadOfNewMessagesAsync(safeThread, chatResponseMessages, cancellationToken).ConfigureAwait(false);
// Notify the AIContextProvider of all new messages.
await NotifyAIContextProviderOfSuccessAsync(safeThread, inputMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
return new(chatResponse) { AgentId = this.Id };
}
@@ -156,15 +164,34 @@ public sealed class ChatClientAgent : AIAgent
this._logger.LogAgentChatClientInvokingAgent(nameof(RunStreamingAsync), this.Id, loggingAgentName, this._chatClientType);
// Using the enumerator to ensure we consider the case where no updates are returned for notification.
var responseUpdatesEnumerator = this.ChatClient.GetStreamingResponseAsync(threadMessages, chatOptions, cancellationToken).GetAsyncEnumerator(cancellationToken);
List<ChatResponseUpdate> responseUpdates = [];
IAsyncEnumerator<ChatResponseUpdate> responseUpdatesEnumerator;
try
{
// Using the enumerator to ensure we consider the case where no updates are returned for notification.
responseUpdatesEnumerator = this.ChatClient.GetStreamingResponseAsync(threadMessages, chatOptions, cancellationToken).GetAsyncEnumerator(cancellationToken);
}
catch (Exception ex)
{
await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, cancellationToken).ConfigureAwait(false);
throw;
}
this._logger.LogAgentChatClientInvokedStreamingAgent(nameof(RunStreamingAsync), this.Id, loggingAgentName, this._chatClientType);
List<ChatResponseUpdate> responseUpdates = [];
// Ensure we start the streaming request
var hasUpdates = await responseUpdatesEnumerator.MoveNextAsync().ConfigureAwait(false);
bool hasUpdates;
try
{
// Ensure we start the streaming request
hasUpdates = await responseUpdatesEnumerator.MoveNextAsync().ConfigureAwait(false);
}
catch (Exception ex)
{
await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, cancellationToken).ConfigureAwait(false);
throw;
}
while (hasUpdates)
{
@@ -176,20 +203,28 @@ public sealed class ChatClientAgent : AIAgent
yield return new(update) { AgentId = this.Id };
}
hasUpdates = await responseUpdatesEnumerator.MoveNextAsync().ConfigureAwait(false);
try
{
hasUpdates = await responseUpdatesEnumerator.MoveNextAsync().ConfigureAwait(false);
}
catch (Exception ex)
{
await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, cancellationToken).ConfigureAwait(false);
throw;
}
}
var chatResponse = responseUpdates.ToChatResponse();
var chatResponseMessages = chatResponse.Messages as IReadOnlyCollection<ChatMessage> ?? [.. chatResponse.Messages];
// We can derive the type of supported thread from whether we have a conversation id,
// so let's update it and set the conversation id for the service thread case.
this.UpdateThreadWithTypeAndConversationId(safeThread, chatResponse.ConversationId);
// To avoid inconsistent state we only notify the thread of the input messages if no error occurs after the initial request.
await NotifyThreadOfNewMessagesAsync(safeThread, inputMessages, cancellationToken).ConfigureAwait(false);
await NotifyThreadOfNewMessagesAsync(safeThread, inputMessages.Concat(chatResponse.Messages), cancellationToken).ConfigureAwait(false);
await NotifyThreadOfNewMessagesAsync(safeThread, chatResponseMessages, cancellationToken).ConfigureAwait(false);
// Notify the AIContextProvider of all new messages.
await NotifyAIContextProviderOfSuccessAsync(safeThread, inputMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
}
/// <inheritdoc/>
@@ -204,12 +239,40 @@ public sealed class ChatClientAgent : AIAgent
/// <inheritdoc/>
public override AgentThread GetNewThread()
{
var thread = new AgentThread { MessageStore = this._agentOptions?.ChatMessageStoreFactory?.Invoke() };
var thread = new AgentThread
{
MessageStore = this._agentOptions?.ChatMessageStoreFactory?.Invoke(),
AIContextProvider = this._agentOptions?.AIContextProviderFactory?.Invoke()
};
return thread;
}
#region Private
/// <summary>
/// Notify the <see cref="AIContextProvider"/> when an agent run succeeded, if there is an <see cref="AIContextProvider"/>.
/// </summary>
private static async Task NotifyAIContextProviderOfSuccessAsync(AgentThread thread, IEnumerable<ChatMessage> inputMessages, IEnumerable<ChatMessage> responseMessages, CancellationToken cancellationToken)
{
if (thread.AIContextProvider is not null)
{
await thread.AIContextProvider.InvokedAsync(new(inputMessages) { ResponseMessages = responseMessages },
cancellationToken).ConfigureAwait(false);
}
}
/// <summary>
/// Notify the <see cref="AIContextProvider"/> of any failure during an agent run, if there is an <see cref="AIContextProvider"/>.
/// </summary>
private static async Task NotifyAIContextProviderOfFailureAsync(AgentThread thread, Exception ex, IEnumerable<ChatMessage> inputMessages, CancellationToken cancellationToken)
{
if (thread.AIContextProvider is not null)
{
await thread.AIContextProvider.InvokedAsync(new(inputMessages) { InvokeException = ex },
cancellationToken).ConfigureAwait(false);
}
}
/// <summary>
/// Configures and returns chat options by merging the provided run options with the agent's default chat options.
/// </summary>
@@ -350,6 +413,34 @@ public sealed class ChatClientAgent : AIAgent
threadMessages.AddRange(await thread.MessageStore.GetMessagesAsync(cancellationToken).ConfigureAwait(false));
}
// If we have an AIContextProvider, we should get context from it, and update our
// messages and options with the additional context.
if (thread.AIContextProvider is not null)
{
var invokingContext = new AIContextProvider.InvokingContext(inputMessages);
var aiContext = await thread.AIContextProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false);
if (aiContext.Messages is { Count: > 0 })
{
threadMessages.AddRange(aiContext.Messages);
}
if (aiContext.Tools is { Count: > 0 })
{
chatOptions ??= new();
chatOptions.Tools ??= [];
foreach (AITool tool in aiContext.Tools)
{
chatOptions.Tools.Add(tool);
}
}
if (aiContext.Instructions is not null)
{
chatOptions ??= new();
chatOptions.Instructions = string.IsNullOrWhiteSpace(chatOptions.Instructions) ? aiContext.Instructions : $"{chatOptions.Instructions}\n{aiContext.Instructions}";
}
}
// Add the input messages to the end of thread messages.
threadMessages.AddRange(inputMessages);
@@ -80,6 +80,13 @@ public class ChatClientAgentOptions
/// </summary>
public Func<IChatMessageStore>? ChatMessageStoreFactory { get; set; }
/// <summary>
/// Gets or sets a factory function to create an instance of <see cref="AIContextProvider"/>
/// which will be used to create a context provider for each new thread, and can then
/// provide additional context for each agent run.
/// </summary>
public Func<AIContextProvider>? AIContextProviderFactory { get; set; }
/// <summary>
/// Gets or sets a value indicating whether to use the provided <see cref="IChatClient"/> instance as is,
/// without applying any default decorators.
@@ -105,6 +112,7 @@ public class ChatClientAgentOptions
Instructions = this.Instructions,
Description = this.Description,
ChatOptions = this.ChatOptions?.Clone(),
ChatMessageStoreFactory = this.ChatMessageStoreFactory
ChatMessageStoreFactory = this.ChatMessageStoreFactory,
AIContextProviderFactory = this.AIContextProviderFactory,
};
}
@@ -235,7 +235,7 @@ public class AIAgentTests
await MockAgent.NotifyThreadOfNewMessagesAsync(threadMock.Object, messages, cancellationToken);
threadMock.Protected().Verify("OnNewMessagesAsync", Times.Once(), messages, cancellationToken);
threadMock.Protected().Verify("MessagesReceivedAsync", Times.Once(), messages, cancellationToken);
}
#region GetService Method Tests
@@ -0,0 +1,56 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.Extensions.AI.Agents.Abstractions.UnitTests;
public class AIContextProviderTests
{
[Fact]
public async Task InvokedAsync_ReturnsCompletedTaskAsync()
{
var provider = new TestAIContextProvider();
var messages = new ReadOnlyCollection<ChatMessage>(new List<ChatMessage>());
var task = provider.InvokedAsync(new(messages));
Assert.Equal(default, task);
}
[Fact]
public async Task SerializeAsync_ReturnsEmptyElementAsync()
{
var provider = new TestAIContextProvider();
var actual = await provider.SerializeAsync();
Assert.Equal(default, actual);
}
[Fact]
public async Task DeserializeAsync_ReturnsCompletedTaskAsync()
{
var provider = new TestAIContextProvider();
var element = default(JsonElement);
var actual = provider.DeserializeAsync(element);
Assert.Equal(default, actual);
}
private sealed class TestAIContextProvider : AIContextProvider
{
public override ValueTask<AIContext> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
return default;
}
public override async ValueTask<JsonElement?> SerializeAsync(JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
{
return await base.SerializeAsync(jsonSerializerOptions, cancellationToken);
}
public override async ValueTask DeserializeAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
{
await base.DeserializeAsync(serializedState, jsonSerializerOptions, cancellationToken);
}
}
}
@@ -0,0 +1,58 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
namespace Microsoft.Extensions.AI.Agents.Abstractions.UnitTests;
/// <summary>
/// Unit tests for <see cref="AIContext"/>.
/// </summary>
public class AIContextTests
{
[Fact]
public void SetInstructionsRoundtrips()
{
var context = new AIContext
{
Instructions = "Test Instructions"
};
Assert.Equal("Test Instructions", context.Instructions);
}
[Fact]
public void SetMessagesRoundtrips()
{
var context = new AIContext
{
Messages = new List<ChatMessage>
{
new(ChatRole.User, "Hello"),
new(ChatRole.Assistant, "Hi there!")
}
};
Assert.NotNull(context.Messages);
Assert.Equal(2, context.Messages.Count);
Assert.Equal("Hello", context.Messages[0].Text);
Assert.Equal("Hi there!", context.Messages[1].Text);
}
[Fact]
public void SetAIFunctionsRoundtrips()
{
var context = new AIContext
{
Tools = new List<AITool>
{
AIFunctionFactory.Create(() => "Function1", "Function1", "Description1"),
AIFunctionFactory.Create(() => "Function2", "Function2", "Description2"),
}
};
Assert.NotNull(context.Tools);
Assert.Equal(2, context.Tools.Count);
Assert.Equal("Function1", context.Tools[0].Name);
Assert.Equal("Function2", context.Tools[1].Name);
}
}
@@ -8,6 +8,8 @@ using System.Threading;
using System.Threading.Tasks;
using Moq;
#pragma warning disable CA1861 // Avoid constant arrays as arguments
namespace Microsoft.Extensions.AI.Agents.Abstractions.UnitTests;
public class AgentThreadTests
@@ -102,7 +104,7 @@ public class AgentThreadTests
};
// Act
await thread.OnNewMessagesAsync(messages, CancellationToken.None);
await thread.MessagesReceivedAsync(messages, CancellationToken.None);
Assert.Equal("thread-123", thread.ConversationId);
Assert.Null(thread.MessageStore);
}
@@ -120,7 +122,7 @@ public class AgentThreadTests
};
// Act
await thread.OnNewMessagesAsync(messages, CancellationToken.None);
await thread.MessagesReceivedAsync(messages, CancellationToken.None);
// Assert
Assert.Equal(2, store.Count);
@@ -173,6 +175,26 @@ public class AgentThreadTests
Assert.Null(thread.MessageStore);
}
[Fact]
public async Task VerifyDeserializeWithAIContextProviderAsync()
{
// Arrange
var json = JsonSerializer.Deserialize("""
{
"aiContextProviderState": ["CP1"]
}
""", TestJsonSerializerContext.Default.JsonElement);
Mock<AIContextProvider> mockProvider = new();
var thread = new AgentThread() { AIContextProvider = mockProvider.Object };
// Act
await thread.DeserializeAsync(json);
// Assert
Assert.Null(thread.MessageStore);
mockProvider.Verify(m => m.DeserializeAsync(It.Is<JsonElement>(e => e.ValueKind == JsonValueKind.Array && e.GetArrayLength() == 1), It.IsAny<JsonSerializerOptions?>(), It.IsAny<CancellationToken>()), Times.Once);
}
[Fact]
public async Task DeserializeWithInvalidJsonThrowsAsync()
{
@@ -245,6 +267,31 @@ public class AgentThreadTests
Assert.Equal("TestContent", textContent.GetProperty("text").GetString());
}
[Fact]
public async Task VerifyThreadSerializationWithWithAIContextProviderAsync()
{
// Arrange
Mock<AIContextProvider> mockProvider = new();
var providerStateElement = JsonSerializer.SerializeToElement(new[] { "CP1" }, TestJsonSerializerContext.Default.StringArray);
mockProvider
.Setup(m => m.SerializeAsync(It.IsAny<JsonSerializerOptions?>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(providerStateElement);
var thread = new AgentThread();
thread.AIContextProvider = mockProvider.Object;
// Act
var json = await thread.SerializeAsync();
// Assert
Assert.Equal(JsonValueKind.Object, json.ValueKind);
Assert.True(json.TryGetProperty("aiContextProviderState", out var providerStateProperty));
Assert.Equal(JsonValueKind.Array, providerStateProperty.ValueKind);
Assert.Single(providerStateProperty.EnumerateArray());
Assert.Equal("CP1", providerStateProperty.EnumerateArray().First().GetString());
mockProvider.Verify(m => m.SerializeAsync(It.IsAny<JsonSerializerOptions?>(), It.IsAny<CancellationToken>()), Times.Once);
}
/// <summary>
/// Verify thread serialization to JSON with custom options.
/// </summary>
@@ -17,4 +17,5 @@ namespace Microsoft.Extensions.AI.Agents.Abstractions.UnitTests;
[JsonSerializable(typeof(Animal))]
[JsonSerializable(typeof(JsonElement))]
[JsonSerializable(typeof(Dictionary<string, object?>))]
[JsonSerializable(typeof(string[]))]
internal sealed partial class TestJsonSerializerContext : JsonSerializerContext;
@@ -22,6 +22,7 @@ public class ChatClientAgentOptionsTests
Assert.Null(options.Description);
Assert.Null(options.ChatOptions);
Assert.Null(options.ChatMessageStoreFactory);
Assert.Null(options.AIContextProviderFactory);
}
[Fact]
@@ -39,6 +40,7 @@ public class ChatClientAgentOptionsTests
Assert.Null(options.Instructions);
Assert.Null(options.Description);
Assert.Null(options.ChatOptions);
Assert.Null(options.AIContextProviderFactory);
}
[Fact]
@@ -163,11 +165,13 @@ public class ChatClientAgentOptionsTests
const string Description = "Test description";
var tools = new List<AITool> { AIFunctionFactory.Create(() => "test") };
static IChatMessageStore ChatMessageStoreFactory() => new Mock<IChatMessageStore>().Object;
static AIContextProvider AIContextProviderFactory() => new Mock<AIContextProvider>().Object;
var original = new ChatClientAgentOptions(Instructions, Name, Description, tools)
{
Id = "test-id",
ChatMessageStoreFactory = ChatMessageStoreFactory
ChatMessageStoreFactory = ChatMessageStoreFactory,
AIContextProviderFactory = AIContextProviderFactory
};
// Act
@@ -180,6 +184,7 @@ public class ChatClientAgentOptionsTests
Assert.Equal(original.Instructions, clone.Instructions);
Assert.Equal(original.Description, clone.Description);
Assert.Same(original.ChatMessageStoreFactory, clone.ChatMessageStoreFactory);
Assert.Same(original.AIContextProviderFactory, clone.AIContextProviderFactory);
// ChatOptions should be cloned, not the same reference
Assert.NotSame(original.ChatOptions, clone.ChatOptions);
@@ -209,5 +214,7 @@ public class ChatClientAgentOptionsTests
Assert.Equal(original.Instructions, clone.Instructions);
Assert.Equal(original.Description, clone.Description);
Assert.Null(clone.ChatOptions);
Assert.Null(clone.ChatMessageStoreFactory);
Assert.Null(clone.AIContextProviderFactory);
}
}
@@ -10,6 +10,8 @@ namespace Microsoft.Extensions.AI.Agents.UnitTests.ChatCompletion;
public class ChatClientAgentTests
{
#region Constructor Tests
/// <summary>
/// Verify the invocation and response of <see cref="ChatClientAgent"/>.
/// </summary>
@@ -38,6 +40,10 @@ public class ChatClientAgentTests
Assert.Equal("AgentInvokedChatClient", agent.ChatClient.GetType().Name);
}
#endregion
#region RunAsync Tests
/// <summary>
/// Verify the invocation and response of <see cref="ChatClientAgent"/> using <see cref="IChatClient"/>.
/// </summary>
@@ -390,6 +396,176 @@ public class ChatClientAgentTests
await Assert.ThrowsAsync<InvalidOperationException>(() => agent.RunAsync([new(ChatRole.User, "test")], thread));
}
/// <summary>
/// Verify that RunAsync sets the ConversationId on the thread when the service returns one.
/// </summary>
[Fact]
public async Task RunAsyncSetsConversationIdOnThreadWhenReturnedByChatClientAsync()
{
// Arrange
Mock<IChatClient> mockService = new();
mockService.Setup(
s => s.GetResponseAsync(
It.IsAny<IEnumerable<ChatMessage>>(),
It.IsAny<ChatOptions>(),
It.IsAny<CancellationToken>())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")]) { ConversationId = "ConvId" });
ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "test instructions" });
AgentThread thread = new();
// Act
await agent.RunAsync([new(ChatRole.User, "test")], thread);
// Assert
Assert.Equal("ConvId", thread.ConversationId);
}
/// <summary>
/// Verify that RunAsync invokes any provided AIContextProvider and uses the result.
/// </summary>
[Fact]
public async Task RunAsyncInvokesAIContextProviderAndUsesResultAsync()
{
// Arrange
ChatMessage[] requestMessages = [new(ChatRole.User, "user message")];
ChatMessage[] responseMessages = [new(ChatRole.Assistant, "response")];
Mock<IChatClient> mockService = new();
List<ChatMessage> capturedMessages = [];
string capturedInstructions = string.Empty;
List<AITool> capturedTools = [];
mockService
.Setup(s => s.GetResponseAsync(
It.IsAny<IEnumerable<ChatMessage>>(),
It.IsAny<ChatOptions>(),
It.IsAny<CancellationToken>()))
.Callback<IEnumerable<ChatMessage>, ChatOptions, CancellationToken>((msgs, opts, ct) =>
{
capturedMessages.AddRange(msgs);
capturedInstructions = opts.Instructions ?? string.Empty;
if (opts.Tools != null)
{
capturedTools.AddRange(opts.Tools);
}
})
.ReturnsAsync(new ChatResponse(responseMessages));
var mockProvider = new Mock<AIContextProvider>();
mockProvider
.Setup(p => p.InvokingAsync(It.IsAny<AIContextProvider.InvokingContext>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(new AIContext
{
Messages = [new(ChatRole.System, "context provider message")],
Instructions = "context provider instructions",
Tools = [AIFunctionFactory.Create(() => { }, "context provider function")]
});
mockProvider
.Setup(p => p.InvokedAsync(It.IsAny<AIContextProvider.InvokedContext>(), It.IsAny<CancellationToken>()))
.Returns(new ValueTask());
ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "base instructions", AIContextProviderFactory = () => mockProvider.Object, ChatOptions = new() { Tools = [AIFunctionFactory.Create(() => { }, "base function")] } });
// Act
await agent.RunAsync(requestMessages);
// Assert
// Should contain: base instructions, context message, user message, base function, context function
Assert.Equal(2, capturedMessages.Count);
Assert.Equal("base instructions\ncontext provider instructions", capturedInstructions);
Assert.Equal("context provider message", capturedMessages[0].Text);
Assert.Equal(ChatRole.System, capturedMessages[0].Role);
Assert.Equal("user message", capturedMessages[1].Text);
Assert.Equal(ChatRole.User, capturedMessages[1].Role);
Assert.Equal(2, capturedTools.Count);
Assert.Contains(capturedTools, t => t.Name == "base function");
Assert.Contains(capturedTools, t => t.Name == "context provider function");
mockProvider.Verify(p => p.InvokingAsync(It.IsAny<AIContextProvider.InvokingContext>(), It.IsAny<CancellationToken>()), Times.Once);
mockProvider.Verify(p => p.InvokedAsync(It.Is<AIContextProvider.InvokedContext>(x => x.RequestMessages == requestMessages && x.ResponseMessages == responseMessages && x.InvokeException == null), It.IsAny<CancellationToken>()), Times.Once);
}
/// <summary>
/// Verify that RunAsync invokes any provided AIContextProvider when the downstream GetResponse call fails.
/// </summary>
[Fact]
public async Task RunAsyncInvokesAIContextProviderWhenGetResponseFailsAsync()
{
// Arrange
ChatMessage[] requestMessages = [new(ChatRole.User, "user message")];
ChatMessage[] responseMessages = [new(ChatRole.Assistant, "response")];
Mock<IChatClient> mockService = new();
mockService
.Setup(s => s.GetResponseAsync(
It.IsAny<IEnumerable<ChatMessage>>(),
It.IsAny<ChatOptions>(),
It.IsAny<CancellationToken>()))
.Throws(new InvalidOperationException("downstream failure"));
var mockProvider = new Mock<AIContextProvider>();
mockProvider
.Setup(p => p.InvokingAsync(It.IsAny<AIContextProvider.InvokingContext>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(new AIContext());
mockProvider
.Setup(p => p.InvokedAsync(It.IsAny<AIContextProvider.InvokedContext>(), It.IsAny<CancellationToken>()))
.Returns(new ValueTask());
ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "base instructions", AIContextProviderFactory = () => mockProvider.Object, ChatOptions = new() { Tools = [AIFunctionFactory.Create(() => { }, "base function")] } });
// Act
await Assert.ThrowsAsync<InvalidOperationException>(() => agent.RunAsync(requestMessages));
// Assert
mockProvider.Verify(p => p.InvokingAsync(It.IsAny<AIContextProvider.InvokingContext>(), It.IsAny<CancellationToken>()), Times.Once);
mockProvider.Verify(p => p.InvokedAsync(It.Is<AIContextProvider.InvokedContext>(x => x.RequestMessages == requestMessages && x.ResponseMessages == null && x.InvokeException is InvalidOperationException), It.IsAny<CancellationToken>()), Times.Once);
}
/// <summary>
/// Verify that RunAsync invokes any provided AIContextProvider and succeeds even when the AIContext is empty.
/// </summary>
[Fact]
public async Task RunAsyncInvokesAIContextProviderAndSucceedsWithEmptyAIContextAsync()
{
// Arrange
Mock<IChatClient> mockService = new();
List<ChatMessage> capturedMessages = [];
string capturedInstructions = string.Empty;
List<AITool> capturedTools = [];
mockService
.Setup(s => s.GetResponseAsync(
It.IsAny<IEnumerable<ChatMessage>>(),
It.IsAny<ChatOptions>(),
It.IsAny<CancellationToken>()))
.Callback<IEnumerable<ChatMessage>, ChatOptions, CancellationToken>((msgs, opts, ct) =>
{
capturedMessages.AddRange(msgs);
capturedInstructions = opts.Instructions ?? string.Empty;
if (opts.Tools != null)
{
capturedTools.AddRange(opts.Tools);
}
})
.ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")]));
var mockProvider = new Mock<AIContextProvider>();
mockProvider
.Setup(p => p.InvokingAsync(It.IsAny<AIContextProvider.InvokingContext>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(new AIContext());
ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "base instructions", AIContextProviderFactory = () => mockProvider.Object, ChatOptions = new() { Tools = [AIFunctionFactory.Create(() => { }, "base function")] } });
// Act
await agent.RunAsync([new(ChatRole.User, "user message")]);
// Assert
// Should contain: base instructions, user message, base function
Assert.Single(capturedMessages);
Assert.Equal("base instructions", capturedInstructions);
Assert.Equal("user message", capturedMessages[0].Text);
Assert.Equal(ChatRole.User, capturedMessages[0].Role);
Assert.Single(capturedTools);
Assert.Contains(capturedTools, t => t.Name == "base function");
mockProvider.Verify(p => p.InvokingAsync(It.IsAny<AIContextProvider.InvokingContext>(), It.IsAny<CancellationToken>()), Times.Once);
}
#endregion
#region Property Override Tests
/// <summary>
@@ -1524,6 +1700,61 @@ public class ChatClientAgentTests
#endregion
#region GetNewThread Tests
[Fact]
public void GetNewThreadUsesChatMessageStoreFactoryIfProvided()
{
// Arrange
var mockChatClient = new Mock<IChatClient>();
var mockStore = new Mock<IChatMessageStore>();
var factoryCalled = false;
var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions
{
Instructions = "Test instructions",
ChatMessageStoreFactory = () =>
{
factoryCalled = true;
return mockStore.Object;
}
});
// Act
var thread = agent.GetNewThread();
// Assert
Assert.True(factoryCalled, "ChatMessageStoreFactory was not called.");
Assert.Same(mockStore.Object, thread.MessageStore);
}
[Fact]
public void GetNewThreadUsesAIContextProviderFactoryIfProvided()
{
// Arrange
var mockChatClient = new Mock<IChatClient>();
var mockContextProvider = new Mock<AIContextProvider>();
var factoryCalled = false;
var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions
{
Instructions = "Test instructions",
AIContextProviderFactory = () =>
{
factoryCalled = true;
return mockContextProvider.Object;
}
});
// Act
var thread = agent.GetNewThread();
// Assert
Assert.True(factoryCalled, "AIContextProviderFactory was not called.");
Assert.Same(mockContextProvider.Object, thread.AIContextProvider);
}
#endregion
private static async IAsyncEnumerable<T> ToAsyncEnumerableAsync<T>(IEnumerable<T> values)
{
await Task.Yield();
@@ -0,0 +1,14 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Text.Json;
using System.Text.Json.Serialization;
namespace Microsoft.Extensions.AI.Agents.UnitTests;
[JsonSourceGenerationOptions(
PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
UseStringEnumConverter = true)]
[JsonSerializable(typeof(JsonElement))]
[JsonSerializable(typeof(string))]
internal sealed partial class TestJsonSerializerContext : JsonSerializerContext;