mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
.NET: Add AIContextProvider support (#691)
* Add AIContextProvider support * Address feedback. * Address PR comments. * Switch to valuetask and remove parallel calls for AIContextProvider * Remove Model from ModelInvokingAsync method name * Remove agent thread id again and remove it from context provider interface * Add AIContextProvider serialization support to AgentThread and update sample to show this feature * Address PR comments * Improve memory sample * Update sample comment. * Remove AggregateAIContextProvider for now since it makes too many assumptions. We can include it later as a sample if needed. * Update AIContextProviders to have an Invoked method instead of MessagesAddingAsync. * Remove unused using. * Address PR comments. * Address PR comment. * Update comment. * Update comment * Address PR comments.
This commit is contained in:
committed by
GitHub
Unverified
parent
89cb94b5c2
commit
66fe1c957c
@@ -50,6 +50,7 @@
|
||||
<Project Path="samples/GettingStarted/Agents/Agent_Step10_AsMcpTool/Agent_Step10_AsMcpTool.csproj" />
|
||||
<Project Path="samples/GettingStarted/Agents/Agent_Step11_UsingImages/Agent_Step11_UsingImages.csproj" />
|
||||
<Project Path="samples/GettingStarted/Agents/Agent_Step12_AsFunctionTool/Agent_Step12_AsFunctionTool.csproj" />
|
||||
<Project Path="samples/GettingStarted/Agents/Agent_Step13_Memory/Agent_Step13_Memory.csproj" />
|
||||
</Folder>
|
||||
<Folder Name="/Samples/GettingStarted/AgentWithOpenAI/">
|
||||
<File Path="samples/GettingStarted/AgentWithOpenAI/README.md" />
|
||||
@@ -190,6 +191,10 @@
|
||||
<File Path="src/LegacySupport/CallerAttributes/CallerArgumentExpressionAttribute.cs" />
|
||||
<File Path="src/LegacySupport/CallerAttributes/README.md" />
|
||||
</Folder>
|
||||
<Folder Name="/Solution Items/src/LegacySupport/CompilerFeatureRequiredAttribute/">
|
||||
<File Path="src/LegacySupport/CompilerFeatureRequiredAttribute/CompilerFeatureRequiredAttribute.cs" />
|
||||
<File Path="src/LegacySupport/CompilerFeatureRequiredAttribute/README.md" />
|
||||
</Folder>
|
||||
<Folder Name="/Solution Items/src/LegacySupport/DiagnosticAttributes/">
|
||||
<File Path="src/LegacySupport/DiagnosticAttributes/NullableAttributes.cs" />
|
||||
<File Path="src/LegacySupport/DiagnosticAttributes/README.md" />
|
||||
@@ -206,6 +211,10 @@
|
||||
<File Path="src/LegacySupport/IsExternalInit/IsExternalInit.cs" />
|
||||
<File Path="src/LegacySupport/IsExternalInit/README.md" />
|
||||
</Folder>
|
||||
<Folder Name="/Solution Items/src/LegacySupport/RequiredMemberAttribute/">
|
||||
<File Path="src/LegacySupport/RequiredMemberAttribute/README.md" />
|
||||
<File Path="src/LegacySupport/RequiredMemberAttribute/RequiredMemberAttribute.cs" />
|
||||
</Folder>
|
||||
<Folder Name="/Solution Items/src/LegacySupport/TrimAttributes/">
|
||||
<File Path="src/LegacySupport/TrimAttributes/DynamicallyAccessedMembersAttribute.cs" />
|
||||
<File Path="src/LegacySupport/TrimAttributes/DynamicallyAccessedMemberTypes.cs" />
|
||||
|
||||
@@ -22,4 +22,12 @@
|
||||
<ItemGroup Condition="'$(InjectTrimAttributesOnLegacy)' == 'true' AND !$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net8.0'))">
|
||||
<Compile Include="$(MSBuildThisFileDirectory)\..\..\src\LegacySupport\TrimAttributes\*.cs" LinkBase="LegacySupport\TrimAttributes" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup Condition="'$(InjectRequiredMemberOnLegacy)' == 'true' AND !$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net8.0'))">
|
||||
<Compile Include="$(MSBuildThisFileDirectory)\..\..\src\LegacySupport\RequiredMemberAttribute\*.cs" LinkBase="LegacySupport\RequiredMemberAttribute" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup Condition="'$(InjectCompilerFeatureRequiredOnLegacy)' == 'true' AND !$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net8.0'))">
|
||||
<Compile Include="$(MSBuildThisFileDirectory)\..\..\src\LegacySupport\CompilerFeatureRequiredAttribute\*.cs" LinkBase="LegacySupport\CompilerFeatureRequiredAttribute" />
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
+2
-4
@@ -39,8 +39,7 @@ namespace SampleApp
|
||||
List<ChatMessage> responseMessages = CloneAndToUpperCase(messages, this.DisplayName).ToList();
|
||||
|
||||
// Notify the thread of the input and output messages.
|
||||
await NotifyThreadOfNewMessagesAsync(thread, messages, cancellationToken);
|
||||
await NotifyThreadOfNewMessagesAsync(thread, responseMessages, cancellationToken);
|
||||
await NotifyThreadOfNewMessagesAsync(thread, messages.Concat(responseMessages), cancellationToken);
|
||||
|
||||
return new AgentRunResponse
|
||||
{
|
||||
@@ -59,8 +58,7 @@ namespace SampleApp
|
||||
List<ChatMessage> responseMessages = CloneAndToUpperCase(messages, this.DisplayName).ToList();
|
||||
|
||||
// Notify the thread of the input and output messages.
|
||||
await NotifyThreadOfNewMessagesAsync(thread, messages, cancellationToken);
|
||||
await NotifyThreadOfNewMessagesAsync(thread, responseMessages, cancellationToken);
|
||||
await NotifyThreadOfNewMessagesAsync(thread, messages.Concat(responseMessages), cancellationToken);
|
||||
|
||||
foreach (var message in responseMessages)
|
||||
{
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<OutputType>Exe</OutputType>
|
||||
<TargetFramework>net9.0</TargetFramework>
|
||||
|
||||
<Nullable>enable</Nullable>
|
||||
<ImplicitUsings>disable</ImplicitUsings>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Azure.AI.OpenAI" />
|
||||
<PackageReference Include="Azure.Identity" />
|
||||
<PackageReference Include="Microsoft.Extensions.AI.OpenAI" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\..\..\src\Microsoft.Extensions.AI.Agents.OpenAI\Microsoft.Extensions.AI.Agents.OpenAI.csproj" />
|
||||
<ProjectReference Include="..\..\..\..\src\Microsoft.Extensions.AI.Agents\Microsoft.Extensions.AI.Agents.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
@@ -0,0 +1,152 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
// This sample shows how to add a basic custom memory component to an agent.
|
||||
// The memory component subscribes to all messages added to the conversation and
|
||||
// extracts the user's name and age if provided.
|
||||
// The component adds a prompt to ask for this information if it is not already known
|
||||
// and provides it to the model before each invocation if known.
|
||||
|
||||
using System;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Azure.AI.OpenAI;
|
||||
using Azure.Identity;
|
||||
using Microsoft.Extensions.AI;
|
||||
using Microsoft.Extensions.AI.Agents;
|
||||
using OpenAI;
|
||||
using OpenAI.Chat;
|
||||
using SampleApp;
|
||||
|
||||
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new InvalidOperationException("AZURE_OPENAI_ENDPOINT is not set.");
|
||||
var deploymentName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOYMENT_NAME") ?? "gpt-4o-mini";
|
||||
|
||||
ChatClient chatClient = new AzureOpenAIClient(
|
||||
new Uri(endpoint),
|
||||
new AzureCliCredential())
|
||||
.GetChatClient(deploymentName);
|
||||
|
||||
// Create the agent and provide a factory to add our custom memory component to
|
||||
// all threads created by the agent. Here each new memory component will have its own
|
||||
// user info object, so each thread will have its own memory.
|
||||
// In real world applications/services, where the user info would be persisted in a database,
|
||||
// and preferably shared between multiple threads used by the same user, ensure that the
|
||||
// factory reads the user id from the current context and scopes the memory component
|
||||
// and its storage to that user id.
|
||||
AIAgent agent = chatClient.CreateAIAgent(new ChatClientAgentOptions()
|
||||
{
|
||||
Instructions = "You are a friendly assistant. Always address the user by their name.",
|
||||
AIContextProviderFactory = () => new SampleApp.UserInfoMemory(chatClient.AsIChatClient())
|
||||
});
|
||||
|
||||
// Create a new thread for the conversation.
|
||||
AgentThread thread = agent.GetNewThread();
|
||||
|
||||
Console.WriteLine(">> Use thread with blank memory\n");
|
||||
|
||||
// Invoke the agent and output the text result.
|
||||
Console.WriteLine(await agent.RunAsync("Hello, what is the square root of 9?", thread));
|
||||
Console.WriteLine(await agent.RunAsync("My name is Ruaidhrí", thread));
|
||||
Console.WriteLine(await agent.RunAsync("I am 20 years old", thread));
|
||||
|
||||
// We can serialize the thread. The serialized state will include the state of the memory component.
|
||||
var threadElement = await thread.SerializeAsync();
|
||||
|
||||
Console.WriteLine("\n>> Use deserialized thread with previously created memories\n");
|
||||
|
||||
// Later we can deserialize the thread and continue the conversation with the previous memory component state.
|
||||
var deserializedThread = await agent.DeserializeThreadAsync(threadElement);
|
||||
Console.WriteLine(await agent.RunAsync("What is my name and age?", deserializedThread));
|
||||
|
||||
Console.WriteLine("\n>> Read memories from memory component\n");
|
||||
|
||||
// It's possible to access the memory component via the thread's AIContextProvider property.
|
||||
var userInfo = ((UserInfoMemory)deserializedThread.AIContextProvider!).UserInfo;
|
||||
|
||||
// Output the user info that was captured by the memory component.
|
||||
Console.WriteLine($"MEMORY - User Name: {userInfo.UserName}");
|
||||
Console.WriteLine($"MEMORY - User Age: {userInfo.UserAge}");
|
||||
|
||||
Console.WriteLine("\n>> Use new thread with previously created memories\n");
|
||||
|
||||
// Create a new thread.
|
||||
thread = agent.GetNewThread();
|
||||
|
||||
// It is also possible to add the memory component to an individual thread only instead of all
|
||||
// threads via the factory above.
|
||||
// In this case we will also use the same user info object, so this thread will share the same
|
||||
// memories as the previous thread.
|
||||
thread.AIContextProvider = new UserInfoMemory(chatClient.AsIChatClient(), userInfo);
|
||||
|
||||
// Invoke the agent and output the text result.
|
||||
// This time the agent should remember the user's name and use it in the response.
|
||||
Console.WriteLine(await agent.RunAsync("What is my name and age?", thread));
|
||||
|
||||
namespace SampleApp
|
||||
{
|
||||
/// <summary>
|
||||
/// Sample memory component that can remember a user's name and age.
|
||||
/// </summary>
|
||||
internal sealed class UserInfoMemory(IChatClient chatClient, UserInfo? userInfo = null) : AIContextProvider
|
||||
{
|
||||
public UserInfo UserInfo { get; set; } = userInfo ?? new();
|
||||
|
||||
public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
// Try and extract the user name and age from the message if we don't have it already and it's a user message.
|
||||
if ((this.UserInfo.UserName == null || this.UserInfo.UserAge == null) && context.RequestMessages.Any(x => x.Role == ChatRole.User))
|
||||
{
|
||||
var result = await chatClient.GetResponseAsync<UserInfo>(
|
||||
context.RequestMessages,
|
||||
new ChatOptions()
|
||||
{
|
||||
Instructions = "Extract the user's name and age from the message if present. If not present return nulls."
|
||||
},
|
||||
cancellationToken: cancellationToken);
|
||||
|
||||
this.UserInfo.UserName ??= result.Result.UserName;
|
||||
this.UserInfo.UserAge ??= result.Result.UserAge;
|
||||
}
|
||||
}
|
||||
|
||||
public override ValueTask<AIContext> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
StringBuilder instructions = new();
|
||||
|
||||
// If we don't already know the user's name and age, add instructions to ask for them, otherwise just provide what we have to the context.
|
||||
instructions.AppendLine(
|
||||
this.UserInfo.UserName == null ?
|
||||
"Ask the user for their name and politely decline to answer any questions until they provide it." :
|
||||
$"The user's name is {this.UserInfo.UserName}.");
|
||||
|
||||
instructions.AppendLine(
|
||||
this.UserInfo.UserAge == null ?
|
||||
"Ask the user for their age and politely decline to answer any questions until they provide it." :
|
||||
$"The user's age is {this.UserInfo.UserAge}.");
|
||||
|
||||
return new ValueTask<AIContext>(new AIContext
|
||||
{
|
||||
Instructions = instructions.ToString()
|
||||
});
|
||||
}
|
||||
|
||||
public override ValueTask<JsonElement?> SerializeAsync(JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
return new ValueTask<JsonElement?>(JsonSerializer.SerializeToElement(this.UserInfo, jsonSerializerOptions));
|
||||
}
|
||||
|
||||
public override ValueTask DeserializeAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
this.UserInfo = JsonSerializer.Deserialize<UserInfo>(serializedState, jsonSerializerOptions) ?? new UserInfo();
|
||||
return default;
|
||||
}
|
||||
}
|
||||
|
||||
internal sealed class UserInfo
|
||||
{
|
||||
public string? UserName { get; set; }
|
||||
public int? UserAge { get; set; }
|
||||
}
|
||||
}
|
||||
@@ -37,7 +37,8 @@ Before you begin, ensure you have the following prerequisites:
|
||||
|[Dependency injection with a simple agent](./Agent_Step09_DependencyInjection/)|This sample demonstrates how to add and resolve an agent with a dependency injection container|
|
||||
|[Exposing a simple agent as MCP tool](./Agent_Step10_AsMcpTool/)|This sample demonstrates how to expose an agent as an MCP tool|
|
||||
|[Using images with a simple agent](./Agent_Step11_UsingImages/)|This sample demonstrates how to use image multi-modality with an AI agent|
|
||||
|[Exposing a simple agent a function tool](./Agent_Step12_AsFunctionTool/)|This sample demonstrates how to expose an agent as a function tool|
|
||||
|[Exposing a simple agent as a function tool](./Agent_Step12_AsFunctionTool/)|This sample demonstrates how to expose an agent as a function tool|
|
||||
|[Using memory with an agent](./Agent_Step12_Memory/)|This sample demonstrates how to create a simple memory component and use it with an agent|
|
||||
|
||||
## Running the samples from the console
|
||||
|
||||
|
||||
+37
@@ -0,0 +1,37 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
#pragma warning disable SA1623 // Property summary documentation should match accessors
|
||||
|
||||
namespace System.Runtime.CompilerServices;
|
||||
|
||||
/// <summary>
|
||||
/// Indicates that compiler support for a particular feature is required for the location where this attribute is applied.
|
||||
/// </summary>
|
||||
[AttributeUsage(AttributeTargets.All, AllowMultiple = true, Inherited = false)]
|
||||
internal sealed class CompilerFeatureRequiredAttribute : Attribute
|
||||
{
|
||||
public CompilerFeatureRequiredAttribute(string featureName)
|
||||
{
|
||||
FeatureName = featureName;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// The name of the compiler feature.
|
||||
/// </summary>
|
||||
public string FeatureName { get; }
|
||||
|
||||
/// <summary>
|
||||
/// If true, the compiler can choose to allow access to the location where this attribute is applied if it does not understand <see cref="FeatureName"/>.
|
||||
/// </summary>
|
||||
public bool IsOptional { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// The <see cref="FeatureName"/> used for the ref structs C# feature.
|
||||
/// </summary>
|
||||
public const string RefStructs = nameof(RefStructs);
|
||||
|
||||
/// <summary>
|
||||
/// The <see cref="FeatureName"/> used for the required members C# feature.
|
||||
/// </summary>
|
||||
public const string RequiredMembers = nameof(RequiredMembers);
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
Enables use of C# required members on older frameworks.
|
||||
|
||||
To use this source in your project, add the following to your `.csproj` file:
|
||||
|
||||
```xml
|
||||
<PropertyGroup>
|
||||
<InjectCompilerFeatureRequiredOnLegacy>true</InjectCompilerFeatureRequiredOnLegacy>
|
||||
</PropertyGroup>
|
||||
```
|
||||
@@ -0,0 +1,9 @@
|
||||
Enables use of C# required members on older frameworks.
|
||||
|
||||
To use this source in your project, add the following to your `.csproj` file:
|
||||
|
||||
```xml
|
||||
<PropertyGroup>
|
||||
<InjectRequiredMemberOnLegacy>true</InjectRequiredMemberOnLegacy>
|
||||
</PropertyGroup>
|
||||
```
|
||||
@@ -0,0 +1,10 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System.ComponentModel;
|
||||
|
||||
namespace System.Runtime.CompilerServices;
|
||||
|
||||
/// <summary>Specifies that a type has required members or that a member is required.</summary>
|
||||
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = false)]
|
||||
[EditorBrowsable(EditorBrowsableState.Never)]
|
||||
internal sealed class RequiredMemberAttribute : Attribute;
|
||||
@@ -290,6 +290,6 @@ public abstract class AIAgent
|
||||
_ = Throw.IfNull(thread);
|
||||
_ = Throw.IfNull(messages);
|
||||
|
||||
await thread.OnNewMessagesAsync(messages, cancellationToken).ConfigureAwait(false);
|
||||
await thread.MessagesReceivedAsync(messages, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace Microsoft.Extensions.AI.Agents;
|
||||
|
||||
/// <summary>
|
||||
/// A class containing any context that should be provided to the AI model
|
||||
/// as supplied by an <see cref="AIContextProvider"/>.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// Each <see cref="AIContextProvider"/> has the ability to provide its own context for each invocation.
|
||||
/// The <see cref="AIContext"/> class contains the additional context supplied by the <see cref="AIContextProvider"/>.
|
||||
/// This context will be combined with context supplied by other providers before being passed to the AI model.
|
||||
/// </remarks>
|
||||
public sealed class AIContext
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets or sets any instructions to pass to the AI model in addition to any other prompts
|
||||
/// that it may already have (in the case of an agent), or chat history that may
|
||||
/// already exist.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// These instructions will be transient and only apply to the current invocation.
|
||||
/// </remarks>
|
||||
public string? Instructions { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets a list of messages to add to the chat history.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// These messages will permanently be added to the chat history.
|
||||
/// </remarks>
|
||||
public IList<ChatMessage>? Messages { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets a list of functions/tools to make available to the AI model for the current invocation.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// These functions/tools will be transient and only apply to the current invocation.
|
||||
/// </remarks>
|
||||
public IList<AITool>? Tools { get; set; }
|
||||
}
|
||||
@@ -0,0 +1,118 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Text.Json;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.Extensions.AI.Agents;
|
||||
|
||||
/// <summary>
|
||||
/// Base class for all AI context providers.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// An AI context provider is a component that can be used to enhance the AI's context management.
|
||||
/// It can listen to changes in the conversation, provide additional context to
|
||||
/// the Model/Agent/etc. just before invocation and supply additional function tools.
|
||||
/// </remarks>
|
||||
public abstract class AIContextProvider
|
||||
{
|
||||
/// <summary>
|
||||
/// Called just before the Model/Agent/etc. is invoked
|
||||
/// Implementers can load any additional context required at this time,
|
||||
/// and they should return any context that should be passed to the Model/Agent/etc.
|
||||
/// </summary>
|
||||
/// <param name="context">Contains the event context.</param>
|
||||
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
|
||||
/// <returns>A task that completes when the context has been rendered and returned.</returns>
|
||||
public abstract ValueTask<AIContext> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default);
|
||||
|
||||
/// <summary>
|
||||
/// Called just before the Model/Agent/etc. is invoked
|
||||
/// Implementers can load any additional context required at this time,
|
||||
/// and they should return any context that should be passed to the Model/Agent/etc.
|
||||
/// </summary>
|
||||
/// <param name="context">Contains the event context.</param>
|
||||
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
|
||||
/// <returns>A task that completes when the context has been rendered and returned.</returns>
|
||||
public virtual ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
return default;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Serializes the current object's state to a <see cref="JsonElement"/> using the specified serialization options.
|
||||
/// </summary>
|
||||
/// <param name="jsonSerializerOptions">The JSON serialization options to use.</param>
|
||||
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
|
||||
/// <returns>A <see cref="JsonElement"/> representation of the object's state.</returns>
|
||||
public virtual ValueTask<JsonElement?> SerializeAsync(JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
return default;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Deserializes the state contained in the provided <see cref="JsonElement"/> into the properties on this object.
|
||||
/// </summary>
|
||||
/// <param name="serializedState">A <see cref="JsonElement"/> representing the state of the object.</param>
|
||||
/// <param name="jsonSerializerOptions">Optional settings for customizing the JSON deserialization process.</param>
|
||||
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
|
||||
/// <returns>A <see cref="ValueTask"/> that completes when the state has been deserialized.</returns>
|
||||
public virtual ValueTask DeserializeAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
return default;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Contains the event context provided to <see cref="AIContextProvider.InvokingAsync(InvokingContext, CancellationToken)"/>.
|
||||
/// </summary>
|
||||
public class InvokingContext
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="InvokingContext"/> class.
|
||||
/// </summary>
|
||||
/// <param name="requestMessages">The messages to be sent to the Model/Agent/etc. for this invocation.</param>
|
||||
/// <exception cref="ArgumentNullException">Thrown if <paramref name="requestMessages"/> is <see langword="null"/>.</exception>
|
||||
public InvokingContext(IEnumerable<ChatMessage> requestMessages)
|
||||
{
|
||||
RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the messages that will be sent to the Model/Agent/etc. for this invocation.
|
||||
/// </summary>
|
||||
public IEnumerable<ChatMessage> RequestMessages { get; private set; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Contains the event conext provided to <see cref="AIContextProvider.InvokedAsync(InvokedContext, CancellationToken)"/>.
|
||||
/// </summary>
|
||||
public class InvokedContext
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="InvokedContext"/> class with the specified request messages.
|
||||
/// </summary>
|
||||
/// <param name="requestMessages">The messages that were sent to the Model/Agent/etc. for this invocation.</param>
|
||||
/// <exception cref="ArgumentNullException">Thrown if <paramref name="requestMessages"/> is <see langword="null"/>.</exception>
|
||||
public InvokedContext(IEnumerable<ChatMessage> requestMessages)
|
||||
{
|
||||
RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the messages that were sent to the Model/Agent/etc. for this invocation.
|
||||
/// </summary>
|
||||
public IEnumerable<ChatMessage> RequestMessages { get; private set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the collection of response messages generated by Model/Agent/etc. if the invocation succeeded.
|
||||
/// </summary>
|
||||
public IEnumerable<ChatMessage>? ResponseMessages { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the <see cref="Exception"/> that was thrown during the invocation, if the invocation failed.
|
||||
/// </summary>
|
||||
public Exception? InvokeException { get; init; }
|
||||
}
|
||||
}
|
||||
@@ -27,7 +27,7 @@ public class AgentThread
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the id of the current thread to support cases where the thread is owned by the agent service.
|
||||
/// Gets or sets the ID of the underlying service thread to support cases where the chat history is stored by the agent service.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// <para>
|
||||
@@ -108,6 +108,11 @@ public class AgentThread
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the <see cref="AIContextProvider"/> used by this thread to provide additional context to the AI model before each invocation.
|
||||
/// </summary>
|
||||
public AIContextProvider? AIContextProvider { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Serializes the current object's state to a <see cref="JsonElement"/> using the specified serialization options.
|
||||
/// </summary>
|
||||
@@ -120,10 +125,15 @@ public class AgentThread
|
||||
null :
|
||||
await this._messageStore.SerializeStateAsync(jsonSerializerOptions, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
var aiContextProviderState = this.AIContextProvider is null ?
|
||||
null :
|
||||
await this.AIContextProvider.SerializeAsync(jsonSerializerOptions, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
var state = new ThreadState
|
||||
{
|
||||
ConversationId = this.ConversationId,
|
||||
StoreState = storeState
|
||||
StoreState = storeState,
|
||||
AIContextProviderState = aiContextProviderState
|
||||
};
|
||||
|
||||
return JsonSerializer.SerializeToElement(state, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ThreadState)));
|
||||
@@ -139,7 +149,7 @@ public class AgentThread
|
||||
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
|
||||
/// <returns>A task that completes when the context has been updated.</returns>
|
||||
/// <exception cref="InvalidOperationException">The thread has been deleted.</exception>
|
||||
protected internal virtual async Task OnNewMessagesAsync(IEnumerable<ChatMessage> newMessages, CancellationToken cancellationToken = default)
|
||||
protected internal virtual async Task MessagesReceivedAsync(IEnumerable<ChatMessage> newMessages, CancellationToken cancellationToken = default)
|
||||
{
|
||||
switch (this)
|
||||
{
|
||||
@@ -186,6 +196,11 @@ public class AgentThread
|
||||
return;
|
||||
}
|
||||
|
||||
if (state?.AIContextProviderState.HasValue is true && this.AIContextProvider is not null)
|
||||
{
|
||||
await this.AIContextProvider.DeserializeAsync(state.AIContextProviderState.Value, jsonSerializerOptions, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
// If we don't have any IChatMessageStore state return here.
|
||||
if (state?.StoreState is null || state?.StoreState.Value.ValueKind is JsonValueKind.Undefined or JsonValueKind.Null)
|
||||
{
|
||||
@@ -206,5 +221,7 @@ public class AgentThread
|
||||
public string? ConversationId { get; set; }
|
||||
|
||||
public JsonElement? StoreState { get; set; }
|
||||
|
||||
public JsonElement? AIContextProviderState { get; set; }
|
||||
}
|
||||
}
|
||||
|
||||
+3
@@ -12,6 +12,9 @@
|
||||
<InjectSharedThrow>true</InjectSharedThrow>
|
||||
<InjectDiagnosticClassesOnLegacy>true</InjectDiagnosticClassesOnLegacy>
|
||||
<InjectTrimAttributesOnLegacy>true</InjectTrimAttributesOnLegacy>
|
||||
<InjectIsExternalInitOnLegacy>true</InjectIsExternalInitOnLegacy>
|
||||
<InjectRequiredMemberOnLegacy>true</InjectRequiredMemberOnLegacy>
|
||||
<InjectCompilerFeatureRequiredOnLegacy>true</InjectCompilerFeatureRequiredOnLegacy>
|
||||
</PropertyGroup>
|
||||
|
||||
<Import Project="$(RepoRoot)/dotnet/nuget/nuget-package.props" />
|
||||
|
||||
@@ -114,7 +114,17 @@ public sealed class ChatClientAgent : AIAgent
|
||||
|
||||
this._logger.LogAgentChatClientInvokingAgent(nameof(RunAsync), this.Id, agentName, this._chatClientType);
|
||||
|
||||
ChatResponse chatResponse = await this.ChatClient.GetResponseAsync(threadMessages, chatOptions, cancellationToken).ConfigureAwait(false);
|
||||
// Call the IChatClient and notify the AIContextProvider of any failures.
|
||||
ChatResponse chatResponse;
|
||||
try
|
||||
{
|
||||
chatResponse = await this.ChatClient.GetResponseAsync(threadMessages, chatOptions, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, cancellationToken).ConfigureAwait(false);
|
||||
throw;
|
||||
}
|
||||
|
||||
this._logger.LogAgentChatClientInvokedAgent(nameof(RunAsync), this.Id, agentName, this._chatClientType, inputMessages.Count);
|
||||
|
||||
@@ -122,19 +132,17 @@ public sealed class ChatClientAgent : AIAgent
|
||||
// so let's update it and set the conversation id for the service thread case.
|
||||
this.UpdateThreadWithTypeAndConversationId(safeThread, chatResponse.ConversationId);
|
||||
|
||||
// Only notify the thread of new messages if the chatResponse was successful to avoid inconsistent messages state in the thread.
|
||||
await NotifyThreadOfNewMessagesAsync(safeThread, inputMessages, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
// Ensure that the author name is set for each message in the response.
|
||||
foreach (ChatMessage chatResponseMessage in chatResponse.Messages)
|
||||
{
|
||||
chatResponseMessage.AuthorName ??= agentName;
|
||||
}
|
||||
|
||||
// Convert the chat response messages to a valid IReadOnlyCollection for notification signatures below.
|
||||
var chatResponseMessages = chatResponse.Messages as IReadOnlyCollection<ChatMessage> ?? [.. chatResponse.Messages];
|
||||
// Only notify the thread of new messages if the chatResponse was successful to avoid inconsistent message state in the thread.
|
||||
await NotifyThreadOfNewMessagesAsync(safeThread, inputMessages.Concat(chatResponse.Messages), cancellationToken).ConfigureAwait(false);
|
||||
|
||||
await NotifyThreadOfNewMessagesAsync(safeThread, chatResponseMessages, cancellationToken).ConfigureAwait(false);
|
||||
// Notify the AIContextProvider of all new messages.
|
||||
await NotifyAIContextProviderOfSuccessAsync(safeThread, inputMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
return new(chatResponse) { AgentId = this.Id };
|
||||
}
|
||||
@@ -156,15 +164,34 @@ public sealed class ChatClientAgent : AIAgent
|
||||
|
||||
this._logger.LogAgentChatClientInvokingAgent(nameof(RunStreamingAsync), this.Id, loggingAgentName, this._chatClientType);
|
||||
|
||||
// Using the enumerator to ensure we consider the case where no updates are returned for notification.
|
||||
var responseUpdatesEnumerator = this.ChatClient.GetStreamingResponseAsync(threadMessages, chatOptions, cancellationToken).GetAsyncEnumerator(cancellationToken);
|
||||
List<ChatResponseUpdate> responseUpdates = [];
|
||||
|
||||
IAsyncEnumerator<ChatResponseUpdate> responseUpdatesEnumerator;
|
||||
|
||||
try
|
||||
{
|
||||
// Using the enumerator to ensure we consider the case where no updates are returned for notification.
|
||||
responseUpdatesEnumerator = this.ChatClient.GetStreamingResponseAsync(threadMessages, chatOptions, cancellationToken).GetAsyncEnumerator(cancellationToken);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, cancellationToken).ConfigureAwait(false);
|
||||
throw;
|
||||
}
|
||||
|
||||
this._logger.LogAgentChatClientInvokedStreamingAgent(nameof(RunStreamingAsync), this.Id, loggingAgentName, this._chatClientType);
|
||||
|
||||
List<ChatResponseUpdate> responseUpdates = [];
|
||||
|
||||
// Ensure we start the streaming request
|
||||
var hasUpdates = await responseUpdatesEnumerator.MoveNextAsync().ConfigureAwait(false);
|
||||
bool hasUpdates;
|
||||
try
|
||||
{
|
||||
// Ensure we start the streaming request
|
||||
hasUpdates = await responseUpdatesEnumerator.MoveNextAsync().ConfigureAwait(false);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, cancellationToken).ConfigureAwait(false);
|
||||
throw;
|
||||
}
|
||||
|
||||
while (hasUpdates)
|
||||
{
|
||||
@@ -176,20 +203,28 @@ public sealed class ChatClientAgent : AIAgent
|
||||
yield return new(update) { AgentId = this.Id };
|
||||
}
|
||||
|
||||
hasUpdates = await responseUpdatesEnumerator.MoveNextAsync().ConfigureAwait(false);
|
||||
try
|
||||
{
|
||||
hasUpdates = await responseUpdatesEnumerator.MoveNextAsync().ConfigureAwait(false);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, cancellationToken).ConfigureAwait(false);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
var chatResponse = responseUpdates.ToChatResponse();
|
||||
var chatResponseMessages = chatResponse.Messages as IReadOnlyCollection<ChatMessage> ?? [.. chatResponse.Messages];
|
||||
|
||||
// We can derive the type of supported thread from whether we have a conversation id,
|
||||
// so let's update it and set the conversation id for the service thread case.
|
||||
this.UpdateThreadWithTypeAndConversationId(safeThread, chatResponse.ConversationId);
|
||||
|
||||
// To avoid inconsistent state we only notify the thread of the input messages if no error occurs after the initial request.
|
||||
await NotifyThreadOfNewMessagesAsync(safeThread, inputMessages, cancellationToken).ConfigureAwait(false);
|
||||
await NotifyThreadOfNewMessagesAsync(safeThread, inputMessages.Concat(chatResponse.Messages), cancellationToken).ConfigureAwait(false);
|
||||
|
||||
await NotifyThreadOfNewMessagesAsync(safeThread, chatResponseMessages, cancellationToken).ConfigureAwait(false);
|
||||
// Notify the AIContextProvider of all new messages.
|
||||
await NotifyAIContextProviderOfSuccessAsync(safeThread, inputMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
@@ -204,12 +239,40 @@ public sealed class ChatClientAgent : AIAgent
|
||||
/// <inheritdoc/>
|
||||
public override AgentThread GetNewThread()
|
||||
{
|
||||
var thread = new AgentThread { MessageStore = this._agentOptions?.ChatMessageStoreFactory?.Invoke() };
|
||||
var thread = new AgentThread
|
||||
{
|
||||
MessageStore = this._agentOptions?.ChatMessageStoreFactory?.Invoke(),
|
||||
AIContextProvider = this._agentOptions?.AIContextProviderFactory?.Invoke()
|
||||
};
|
||||
return thread;
|
||||
}
|
||||
|
||||
#region Private
|
||||
|
||||
/// <summary>
|
||||
/// Notify the <see cref="AIContextProvider"/> when an agent run succeeded, if there is an <see cref="AIContextProvider"/>.
|
||||
/// </summary>
|
||||
private static async Task NotifyAIContextProviderOfSuccessAsync(AgentThread thread, IEnumerable<ChatMessage> inputMessages, IEnumerable<ChatMessage> responseMessages, CancellationToken cancellationToken)
|
||||
{
|
||||
if (thread.AIContextProvider is not null)
|
||||
{
|
||||
await thread.AIContextProvider.InvokedAsync(new(inputMessages) { ResponseMessages = responseMessages },
|
||||
cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Notify the <see cref="AIContextProvider"/> of any failure during an agent run, if there is an <see cref="AIContextProvider"/>.
|
||||
/// </summary>
|
||||
private static async Task NotifyAIContextProviderOfFailureAsync(AgentThread thread, Exception ex, IEnumerable<ChatMessage> inputMessages, CancellationToken cancellationToken)
|
||||
{
|
||||
if (thread.AIContextProvider is not null)
|
||||
{
|
||||
await thread.AIContextProvider.InvokedAsync(new(inputMessages) { InvokeException = ex },
|
||||
cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Configures and returns chat options by merging the provided run options with the agent's default chat options.
|
||||
/// </summary>
|
||||
@@ -350,6 +413,34 @@ public sealed class ChatClientAgent : AIAgent
|
||||
threadMessages.AddRange(await thread.MessageStore.GetMessagesAsync(cancellationToken).ConfigureAwait(false));
|
||||
}
|
||||
|
||||
// If we have an AIContextProvider, we should get context from it, and update our
|
||||
// messages and options with the additional context.
|
||||
if (thread.AIContextProvider is not null)
|
||||
{
|
||||
var invokingContext = new AIContextProvider.InvokingContext(inputMessages);
|
||||
var aiContext = await thread.AIContextProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false);
|
||||
if (aiContext.Messages is { Count: > 0 })
|
||||
{
|
||||
threadMessages.AddRange(aiContext.Messages);
|
||||
}
|
||||
|
||||
if (aiContext.Tools is { Count: > 0 })
|
||||
{
|
||||
chatOptions ??= new();
|
||||
chatOptions.Tools ??= [];
|
||||
foreach (AITool tool in aiContext.Tools)
|
||||
{
|
||||
chatOptions.Tools.Add(tool);
|
||||
}
|
||||
}
|
||||
|
||||
if (aiContext.Instructions is not null)
|
||||
{
|
||||
chatOptions ??= new();
|
||||
chatOptions.Instructions = string.IsNullOrWhiteSpace(chatOptions.Instructions) ? aiContext.Instructions : $"{chatOptions.Instructions}\n{aiContext.Instructions}";
|
||||
}
|
||||
}
|
||||
|
||||
// Add the input messages to the end of thread messages.
|
||||
threadMessages.AddRange(inputMessages);
|
||||
|
||||
|
||||
@@ -80,6 +80,13 @@ public class ChatClientAgentOptions
|
||||
/// </summary>
|
||||
public Func<IChatMessageStore>? ChatMessageStoreFactory { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets a factory function to create an instance of <see cref="AIContextProvider"/>
|
||||
/// which will be used to create a context provider for each new thread, and can then
|
||||
/// provide additional context for each agent run.
|
||||
/// </summary>
|
||||
public Func<AIContextProvider>? AIContextProviderFactory { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets a value indicating whether to use the provided <see cref="IChatClient"/> instance as is,
|
||||
/// without applying any default decorators.
|
||||
@@ -105,6 +112,7 @@ public class ChatClientAgentOptions
|
||||
Instructions = this.Instructions,
|
||||
Description = this.Description,
|
||||
ChatOptions = this.ChatOptions?.Clone(),
|
||||
ChatMessageStoreFactory = this.ChatMessageStoreFactory
|
||||
ChatMessageStoreFactory = this.ChatMessageStoreFactory,
|
||||
AIContextProviderFactory = this.AIContextProviderFactory,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -235,7 +235,7 @@ public class AIAgentTests
|
||||
|
||||
await MockAgent.NotifyThreadOfNewMessagesAsync(threadMock.Object, messages, cancellationToken);
|
||||
|
||||
threadMock.Protected().Verify("OnNewMessagesAsync", Times.Once(), messages, cancellationToken);
|
||||
threadMock.Protected().Verify("MessagesReceivedAsync", Times.Once(), messages, cancellationToken);
|
||||
}
|
||||
|
||||
#region GetService Method Tests
|
||||
|
||||
+56
@@ -0,0 +1,56 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System.Collections.Generic;
|
||||
using System.Collections.ObjectModel;
|
||||
using System.Text.Json;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.Extensions.AI.Agents.Abstractions.UnitTests;
|
||||
|
||||
public class AIContextProviderTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task InvokedAsync_ReturnsCompletedTaskAsync()
|
||||
{
|
||||
var provider = new TestAIContextProvider();
|
||||
var messages = new ReadOnlyCollection<ChatMessage>(new List<ChatMessage>());
|
||||
var task = provider.InvokedAsync(new(messages));
|
||||
Assert.Equal(default, task);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task SerializeAsync_ReturnsEmptyElementAsync()
|
||||
{
|
||||
var provider = new TestAIContextProvider();
|
||||
var actual = await provider.SerializeAsync();
|
||||
Assert.Equal(default, actual);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task DeserializeAsync_ReturnsCompletedTaskAsync()
|
||||
{
|
||||
var provider = new TestAIContextProvider();
|
||||
var element = default(JsonElement);
|
||||
var actual = provider.DeserializeAsync(element);
|
||||
Assert.Equal(default, actual);
|
||||
}
|
||||
|
||||
private sealed class TestAIContextProvider : AIContextProvider
|
||||
{
|
||||
public override ValueTask<AIContext> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
return default;
|
||||
}
|
||||
|
||||
public override async ValueTask<JsonElement?> SerializeAsync(JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
return await base.SerializeAsync(jsonSerializerOptions, cancellationToken);
|
||||
}
|
||||
|
||||
public override async ValueTask DeserializeAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
await base.DeserializeAsync(serializedState, jsonSerializerOptions, cancellationToken);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace Microsoft.Extensions.AI.Agents.Abstractions.UnitTests;
|
||||
|
||||
/// <summary>
|
||||
/// Unit tests for <see cref="AIContext"/>.
|
||||
/// </summary>
|
||||
public class AIContextTests
|
||||
{
|
||||
[Fact]
|
||||
public void SetInstructionsRoundtrips()
|
||||
{
|
||||
var context = new AIContext
|
||||
{
|
||||
Instructions = "Test Instructions"
|
||||
};
|
||||
|
||||
Assert.Equal("Test Instructions", context.Instructions);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void SetMessagesRoundtrips()
|
||||
{
|
||||
var context = new AIContext
|
||||
{
|
||||
Messages = new List<ChatMessage>
|
||||
{
|
||||
new(ChatRole.User, "Hello"),
|
||||
new(ChatRole.Assistant, "Hi there!")
|
||||
}
|
||||
};
|
||||
|
||||
Assert.NotNull(context.Messages);
|
||||
Assert.Equal(2, context.Messages.Count);
|
||||
Assert.Equal("Hello", context.Messages[0].Text);
|
||||
Assert.Equal("Hi there!", context.Messages[1].Text);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void SetAIFunctionsRoundtrips()
|
||||
{
|
||||
var context = new AIContext
|
||||
{
|
||||
Tools = new List<AITool>
|
||||
{
|
||||
AIFunctionFactory.Create(() => "Function1", "Function1", "Description1"),
|
||||
AIFunctionFactory.Create(() => "Function2", "Function2", "Description2"),
|
||||
}
|
||||
};
|
||||
|
||||
Assert.NotNull(context.Tools);
|
||||
Assert.Equal(2, context.Tools.Count);
|
||||
Assert.Equal("Function1", context.Tools[0].Name);
|
||||
Assert.Equal("Function2", context.Tools[1].Name);
|
||||
}
|
||||
}
|
||||
+49
-2
@@ -8,6 +8,8 @@ using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Moq;
|
||||
|
||||
#pragma warning disable CA1861 // Avoid constant arrays as arguments
|
||||
|
||||
namespace Microsoft.Extensions.AI.Agents.Abstractions.UnitTests;
|
||||
|
||||
public class AgentThreadTests
|
||||
@@ -102,7 +104,7 @@ public class AgentThreadTests
|
||||
};
|
||||
|
||||
// Act
|
||||
await thread.OnNewMessagesAsync(messages, CancellationToken.None);
|
||||
await thread.MessagesReceivedAsync(messages, CancellationToken.None);
|
||||
Assert.Equal("thread-123", thread.ConversationId);
|
||||
Assert.Null(thread.MessageStore);
|
||||
}
|
||||
@@ -120,7 +122,7 @@ public class AgentThreadTests
|
||||
};
|
||||
|
||||
// Act
|
||||
await thread.OnNewMessagesAsync(messages, CancellationToken.None);
|
||||
await thread.MessagesReceivedAsync(messages, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(2, store.Count);
|
||||
@@ -173,6 +175,26 @@ public class AgentThreadTests
|
||||
Assert.Null(thread.MessageStore);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task VerifyDeserializeWithAIContextProviderAsync()
|
||||
{
|
||||
// Arrange
|
||||
var json = JsonSerializer.Deserialize("""
|
||||
{
|
||||
"aiContextProviderState": ["CP1"]
|
||||
}
|
||||
""", TestJsonSerializerContext.Default.JsonElement);
|
||||
Mock<AIContextProvider> mockProvider = new();
|
||||
var thread = new AgentThread() { AIContextProvider = mockProvider.Object };
|
||||
|
||||
// Act
|
||||
await thread.DeserializeAsync(json);
|
||||
|
||||
// Assert
|
||||
Assert.Null(thread.MessageStore);
|
||||
mockProvider.Verify(m => m.DeserializeAsync(It.Is<JsonElement>(e => e.ValueKind == JsonValueKind.Array && e.GetArrayLength() == 1), It.IsAny<JsonSerializerOptions?>(), It.IsAny<CancellationToken>()), Times.Once);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task DeserializeWithInvalidJsonThrowsAsync()
|
||||
{
|
||||
@@ -245,6 +267,31 @@ public class AgentThreadTests
|
||||
Assert.Equal("TestContent", textContent.GetProperty("text").GetString());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task VerifyThreadSerializationWithWithAIContextProviderAsync()
|
||||
{
|
||||
// Arrange
|
||||
Mock<AIContextProvider> mockProvider = new();
|
||||
var providerStateElement = JsonSerializer.SerializeToElement(new[] { "CP1" }, TestJsonSerializerContext.Default.StringArray);
|
||||
mockProvider
|
||||
.Setup(m => m.SerializeAsync(It.IsAny<JsonSerializerOptions?>(), It.IsAny<CancellationToken>()))
|
||||
.ReturnsAsync(providerStateElement);
|
||||
|
||||
var thread = new AgentThread();
|
||||
thread.AIContextProvider = mockProvider.Object;
|
||||
|
||||
// Act
|
||||
var json = await thread.SerializeAsync();
|
||||
|
||||
// Assert
|
||||
Assert.Equal(JsonValueKind.Object, json.ValueKind);
|
||||
Assert.True(json.TryGetProperty("aiContextProviderState", out var providerStateProperty));
|
||||
Assert.Equal(JsonValueKind.Array, providerStateProperty.ValueKind);
|
||||
Assert.Single(providerStateProperty.EnumerateArray());
|
||||
Assert.Equal("CP1", providerStateProperty.EnumerateArray().First().GetString());
|
||||
mockProvider.Verify(m => m.SerializeAsync(It.IsAny<JsonSerializerOptions?>(), It.IsAny<CancellationToken>()), Times.Once);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Verify thread serialization to JSON with custom options.
|
||||
/// </summary>
|
||||
|
||||
+1
@@ -17,4 +17,5 @@ namespace Microsoft.Extensions.AI.Agents.Abstractions.UnitTests;
|
||||
[JsonSerializable(typeof(Animal))]
|
||||
[JsonSerializable(typeof(JsonElement))]
|
||||
[JsonSerializable(typeof(Dictionary<string, object?>))]
|
||||
[JsonSerializable(typeof(string[]))]
|
||||
internal sealed partial class TestJsonSerializerContext : JsonSerializerContext;
|
||||
|
||||
+8
-1
@@ -22,6 +22,7 @@ public class ChatClientAgentOptionsTests
|
||||
Assert.Null(options.Description);
|
||||
Assert.Null(options.ChatOptions);
|
||||
Assert.Null(options.ChatMessageStoreFactory);
|
||||
Assert.Null(options.AIContextProviderFactory);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
@@ -39,6 +40,7 @@ public class ChatClientAgentOptionsTests
|
||||
Assert.Null(options.Instructions);
|
||||
Assert.Null(options.Description);
|
||||
Assert.Null(options.ChatOptions);
|
||||
Assert.Null(options.AIContextProviderFactory);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
@@ -163,11 +165,13 @@ public class ChatClientAgentOptionsTests
|
||||
const string Description = "Test description";
|
||||
var tools = new List<AITool> { AIFunctionFactory.Create(() => "test") };
|
||||
static IChatMessageStore ChatMessageStoreFactory() => new Mock<IChatMessageStore>().Object;
|
||||
static AIContextProvider AIContextProviderFactory() => new Mock<AIContextProvider>().Object;
|
||||
|
||||
var original = new ChatClientAgentOptions(Instructions, Name, Description, tools)
|
||||
{
|
||||
Id = "test-id",
|
||||
ChatMessageStoreFactory = ChatMessageStoreFactory
|
||||
ChatMessageStoreFactory = ChatMessageStoreFactory,
|
||||
AIContextProviderFactory = AIContextProviderFactory
|
||||
};
|
||||
|
||||
// Act
|
||||
@@ -180,6 +184,7 @@ public class ChatClientAgentOptionsTests
|
||||
Assert.Equal(original.Instructions, clone.Instructions);
|
||||
Assert.Equal(original.Description, clone.Description);
|
||||
Assert.Same(original.ChatMessageStoreFactory, clone.ChatMessageStoreFactory);
|
||||
Assert.Same(original.AIContextProviderFactory, clone.AIContextProviderFactory);
|
||||
|
||||
// ChatOptions should be cloned, not the same reference
|
||||
Assert.NotSame(original.ChatOptions, clone.ChatOptions);
|
||||
@@ -209,5 +214,7 @@ public class ChatClientAgentOptionsTests
|
||||
Assert.Equal(original.Instructions, clone.Instructions);
|
||||
Assert.Equal(original.Description, clone.Description);
|
||||
Assert.Null(clone.ChatOptions);
|
||||
Assert.Null(clone.ChatMessageStoreFactory);
|
||||
Assert.Null(clone.AIContextProviderFactory);
|
||||
}
|
||||
}
|
||||
|
||||
+231
@@ -10,6 +10,8 @@ namespace Microsoft.Extensions.AI.Agents.UnitTests.ChatCompletion;
|
||||
|
||||
public class ChatClientAgentTests
|
||||
{
|
||||
#region Constructor Tests
|
||||
|
||||
/// <summary>
|
||||
/// Verify the invocation and response of <see cref="ChatClientAgent"/>.
|
||||
/// </summary>
|
||||
@@ -38,6 +40,10 @@ public class ChatClientAgentTests
|
||||
Assert.Equal("AgentInvokedChatClient", agent.ChatClient.GetType().Name);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region RunAsync Tests
|
||||
|
||||
/// <summary>
|
||||
/// Verify the invocation and response of <see cref="ChatClientAgent"/> using <see cref="IChatClient"/>.
|
||||
/// </summary>
|
||||
@@ -390,6 +396,176 @@ public class ChatClientAgentTests
|
||||
await Assert.ThrowsAsync<InvalidOperationException>(() => agent.RunAsync([new(ChatRole.User, "test")], thread));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Verify that RunAsync sets the ConversationId on the thread when the service returns one.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
public async Task RunAsyncSetsConversationIdOnThreadWhenReturnedByChatClientAsync()
|
||||
{
|
||||
// Arrange
|
||||
Mock<IChatClient> mockService = new();
|
||||
mockService.Setup(
|
||||
s => s.GetResponseAsync(
|
||||
It.IsAny<IEnumerable<ChatMessage>>(),
|
||||
It.IsAny<ChatOptions>(),
|
||||
It.IsAny<CancellationToken>())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")]) { ConversationId = "ConvId" });
|
||||
ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "test instructions" });
|
||||
AgentThread thread = new();
|
||||
|
||||
// Act
|
||||
await agent.RunAsync([new(ChatRole.User, "test")], thread);
|
||||
|
||||
// Assert
|
||||
Assert.Equal("ConvId", thread.ConversationId);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Verify that RunAsync invokes any provided AIContextProvider and uses the result.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
public async Task RunAsyncInvokesAIContextProviderAndUsesResultAsync()
|
||||
{
|
||||
// Arrange
|
||||
ChatMessage[] requestMessages = [new(ChatRole.User, "user message")];
|
||||
ChatMessage[] responseMessages = [new(ChatRole.Assistant, "response")];
|
||||
Mock<IChatClient> mockService = new();
|
||||
List<ChatMessage> capturedMessages = [];
|
||||
string capturedInstructions = string.Empty;
|
||||
List<AITool> capturedTools = [];
|
||||
mockService
|
||||
.Setup(s => s.GetResponseAsync(
|
||||
It.IsAny<IEnumerable<ChatMessage>>(),
|
||||
It.IsAny<ChatOptions>(),
|
||||
It.IsAny<CancellationToken>()))
|
||||
.Callback<IEnumerable<ChatMessage>, ChatOptions, CancellationToken>((msgs, opts, ct) =>
|
||||
{
|
||||
capturedMessages.AddRange(msgs);
|
||||
capturedInstructions = opts.Instructions ?? string.Empty;
|
||||
if (opts.Tools != null)
|
||||
{
|
||||
capturedTools.AddRange(opts.Tools);
|
||||
}
|
||||
})
|
||||
.ReturnsAsync(new ChatResponse(responseMessages));
|
||||
|
||||
var mockProvider = new Mock<AIContextProvider>();
|
||||
mockProvider
|
||||
.Setup(p => p.InvokingAsync(It.IsAny<AIContextProvider.InvokingContext>(), It.IsAny<CancellationToken>()))
|
||||
.ReturnsAsync(new AIContext
|
||||
{
|
||||
Messages = [new(ChatRole.System, "context provider message")],
|
||||
Instructions = "context provider instructions",
|
||||
Tools = [AIFunctionFactory.Create(() => { }, "context provider function")]
|
||||
});
|
||||
mockProvider
|
||||
.Setup(p => p.InvokedAsync(It.IsAny<AIContextProvider.InvokedContext>(), It.IsAny<CancellationToken>()))
|
||||
.Returns(new ValueTask());
|
||||
|
||||
ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "base instructions", AIContextProviderFactory = () => mockProvider.Object, ChatOptions = new() { Tools = [AIFunctionFactory.Create(() => { }, "base function")] } });
|
||||
|
||||
// Act
|
||||
await agent.RunAsync(requestMessages);
|
||||
|
||||
// Assert
|
||||
// Should contain: base instructions, context message, user message, base function, context function
|
||||
Assert.Equal(2, capturedMessages.Count);
|
||||
Assert.Equal("base instructions\ncontext provider instructions", capturedInstructions);
|
||||
Assert.Equal("context provider message", capturedMessages[0].Text);
|
||||
Assert.Equal(ChatRole.System, capturedMessages[0].Role);
|
||||
Assert.Equal("user message", capturedMessages[1].Text);
|
||||
Assert.Equal(ChatRole.User, capturedMessages[1].Role);
|
||||
Assert.Equal(2, capturedTools.Count);
|
||||
Assert.Contains(capturedTools, t => t.Name == "base function");
|
||||
Assert.Contains(capturedTools, t => t.Name == "context provider function");
|
||||
mockProvider.Verify(p => p.InvokingAsync(It.IsAny<AIContextProvider.InvokingContext>(), It.IsAny<CancellationToken>()), Times.Once);
|
||||
mockProvider.Verify(p => p.InvokedAsync(It.Is<AIContextProvider.InvokedContext>(x => x.RequestMessages == requestMessages && x.ResponseMessages == responseMessages && x.InvokeException == null), It.IsAny<CancellationToken>()), Times.Once);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Verify that RunAsync invokes any provided AIContextProvider when the downstream GetResponse call fails.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
public async Task RunAsyncInvokesAIContextProviderWhenGetResponseFailsAsync()
|
||||
{
|
||||
// Arrange
|
||||
ChatMessage[] requestMessages = [new(ChatRole.User, "user message")];
|
||||
ChatMessage[] responseMessages = [new(ChatRole.Assistant, "response")];
|
||||
Mock<IChatClient> mockService = new();
|
||||
mockService
|
||||
.Setup(s => s.GetResponseAsync(
|
||||
It.IsAny<IEnumerable<ChatMessage>>(),
|
||||
It.IsAny<ChatOptions>(),
|
||||
It.IsAny<CancellationToken>()))
|
||||
.Throws(new InvalidOperationException("downstream failure"));
|
||||
|
||||
var mockProvider = new Mock<AIContextProvider>();
|
||||
mockProvider
|
||||
.Setup(p => p.InvokingAsync(It.IsAny<AIContextProvider.InvokingContext>(), It.IsAny<CancellationToken>()))
|
||||
.ReturnsAsync(new AIContext());
|
||||
mockProvider
|
||||
.Setup(p => p.InvokedAsync(It.IsAny<AIContextProvider.InvokedContext>(), It.IsAny<CancellationToken>()))
|
||||
.Returns(new ValueTask());
|
||||
|
||||
ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "base instructions", AIContextProviderFactory = () => mockProvider.Object, ChatOptions = new() { Tools = [AIFunctionFactory.Create(() => { }, "base function")] } });
|
||||
|
||||
// Act
|
||||
await Assert.ThrowsAsync<InvalidOperationException>(() => agent.RunAsync(requestMessages));
|
||||
|
||||
// Assert
|
||||
mockProvider.Verify(p => p.InvokingAsync(It.IsAny<AIContextProvider.InvokingContext>(), It.IsAny<CancellationToken>()), Times.Once);
|
||||
mockProvider.Verify(p => p.InvokedAsync(It.Is<AIContextProvider.InvokedContext>(x => x.RequestMessages == requestMessages && x.ResponseMessages == null && x.InvokeException is InvalidOperationException), It.IsAny<CancellationToken>()), Times.Once);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Verify that RunAsync invokes any provided AIContextProvider and succeeds even when the AIContext is empty.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
public async Task RunAsyncInvokesAIContextProviderAndSucceedsWithEmptyAIContextAsync()
|
||||
{
|
||||
// Arrange
|
||||
Mock<IChatClient> mockService = new();
|
||||
List<ChatMessage> capturedMessages = [];
|
||||
string capturedInstructions = string.Empty;
|
||||
List<AITool> capturedTools = [];
|
||||
mockService
|
||||
.Setup(s => s.GetResponseAsync(
|
||||
It.IsAny<IEnumerable<ChatMessage>>(),
|
||||
It.IsAny<ChatOptions>(),
|
||||
It.IsAny<CancellationToken>()))
|
||||
.Callback<IEnumerable<ChatMessage>, ChatOptions, CancellationToken>((msgs, opts, ct) =>
|
||||
{
|
||||
capturedMessages.AddRange(msgs);
|
||||
capturedInstructions = opts.Instructions ?? string.Empty;
|
||||
if (opts.Tools != null)
|
||||
{
|
||||
capturedTools.AddRange(opts.Tools);
|
||||
}
|
||||
})
|
||||
.ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")]));
|
||||
|
||||
var mockProvider = new Mock<AIContextProvider>();
|
||||
mockProvider
|
||||
.Setup(p => p.InvokingAsync(It.IsAny<AIContextProvider.InvokingContext>(), It.IsAny<CancellationToken>()))
|
||||
.ReturnsAsync(new AIContext());
|
||||
|
||||
ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "base instructions", AIContextProviderFactory = () => mockProvider.Object, ChatOptions = new() { Tools = [AIFunctionFactory.Create(() => { }, "base function")] } });
|
||||
|
||||
// Act
|
||||
await agent.RunAsync([new(ChatRole.User, "user message")]);
|
||||
|
||||
// Assert
|
||||
// Should contain: base instructions, user message, base function
|
||||
Assert.Single(capturedMessages);
|
||||
Assert.Equal("base instructions", capturedInstructions);
|
||||
Assert.Equal("user message", capturedMessages[0].Text);
|
||||
Assert.Equal(ChatRole.User, capturedMessages[0].Role);
|
||||
Assert.Single(capturedTools);
|
||||
Assert.Contains(capturedTools, t => t.Name == "base function");
|
||||
mockProvider.Verify(p => p.InvokingAsync(It.IsAny<AIContextProvider.InvokingContext>(), It.IsAny<CancellationToken>()), Times.Once);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Property Override Tests
|
||||
|
||||
/// <summary>
|
||||
@@ -1524,6 +1700,61 @@ public class ChatClientAgentTests
|
||||
|
||||
#endregion
|
||||
|
||||
#region GetNewThread Tests
|
||||
|
||||
[Fact]
|
||||
public void GetNewThreadUsesChatMessageStoreFactoryIfProvided()
|
||||
{
|
||||
// Arrange
|
||||
var mockChatClient = new Mock<IChatClient>();
|
||||
var mockStore = new Mock<IChatMessageStore>();
|
||||
var factoryCalled = false;
|
||||
|
||||
var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions
|
||||
{
|
||||
Instructions = "Test instructions",
|
||||
ChatMessageStoreFactory = () =>
|
||||
{
|
||||
factoryCalled = true;
|
||||
return mockStore.Object;
|
||||
}
|
||||
});
|
||||
|
||||
// Act
|
||||
var thread = agent.GetNewThread();
|
||||
|
||||
// Assert
|
||||
Assert.True(factoryCalled, "ChatMessageStoreFactory was not called.");
|
||||
Assert.Same(mockStore.Object, thread.MessageStore);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetNewThreadUsesAIContextProviderFactoryIfProvided()
|
||||
{
|
||||
// Arrange
|
||||
var mockChatClient = new Mock<IChatClient>();
|
||||
var mockContextProvider = new Mock<AIContextProvider>();
|
||||
var factoryCalled = false;
|
||||
var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions
|
||||
{
|
||||
Instructions = "Test instructions",
|
||||
AIContextProviderFactory = () =>
|
||||
{
|
||||
factoryCalled = true;
|
||||
return mockContextProvider.Object;
|
||||
}
|
||||
});
|
||||
|
||||
// Act
|
||||
var thread = agent.GetNewThread();
|
||||
|
||||
// Assert
|
||||
Assert.True(factoryCalled, "AIContextProviderFactory was not called.");
|
||||
Assert.Same(mockContextProvider.Object, thread.AIContextProvider);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
private static async IAsyncEnumerable<T> ToAsyncEnumerableAsync<T>(IEnumerable<T> values)
|
||||
{
|
||||
await Task.Yield();
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Serialization;
|
||||
|
||||
namespace Microsoft.Extensions.AI.Agents.UnitTests;
|
||||
|
||||
[JsonSourceGenerationOptions(
|
||||
PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase,
|
||||
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
|
||||
UseStringEnumConverter = true)]
|
||||
[JsonSerializable(typeof(JsonElement))]
|
||||
[JsonSerializable(typeof(string))]
|
||||
internal sealed partial class TestJsonSerializerContext : JsonSerializerContext;
|
||||
Reference in New Issue
Block a user