From 66fe1c957ca297b87c8f2425deac3e2a2e700640 Mon Sep 17 00:00:00 2001
From: westey <164392973+westey-m@users.noreply.github.com>
Date: Tue, 16 Sep 2025 10:54:18 +0100
Subject: [PATCH] .NET: Add AIContextProvider support (#691)
* Add AIContextProvider support
* Address feedback.
* Address PR comments.
* Switch to valuetask and remove parallel calls for AIContextProvider
* Remove Model from ModelInvokingAsync method name
* Remove agent thread id again and remove it from context provider interface
* Add AIContextProvider serialization support to AgentThread and update sample to show this feature
* Address PR comments
* Improve memory sample
* Update sample comment.
* Remove AggregateAIContextProvider for now since it makes too many assumptions. We can include it later as a sample if needed.
* Update AIContextProviders to have an Invoked method instead of MessagesAddingAsync.
* Remove unused using.
* Address PR comments.
* Address PR comment.
* Update comment.
* Update comment
* Address PR comments.
---
dotnet/agent-framework-dotnet.slnx | 9 +
dotnet/eng/MSBuild/LegacySupport.props | 8 +
.../Program.cs | 6 +-
.../Agent_Step13_Memory.csproj | 22 ++
.../Agents/Agent_Step13_Memory/Program.cs | 152 ++++++++++++
.../samples/GettingStarted/Agents/README.md | 3 +-
.../CompilerFeatureRequiredAttribute.cs | 37 +++
.../README.md | 9 +
.../RequiredMemberAttribute/README.md | 9 +
.../RequiredMemberAttribute.cs | 10 +
.../AIAgent.cs | 2 +-
.../AIContext.cs | 43 ++++
.../AIContextProvider.cs | 118 +++++++++
.../AgentThread.cs | 23 +-
...t.Extensions.AI.Agents.Abstractions.csproj | 3 +
.../ChatCompletion/ChatClientAgent.cs | 127 ++++++++--
.../ChatCompletion/ChatClientAgentOptions.cs | 10 +-
.../AIAgentTests.cs | 2 +-
.../AIContextProviderTests.cs | 56 +++++
.../AIContextTests.cs | 58 +++++
.../AgentThreadTests.cs | 51 +++-
.../TestJsonSerializerContext.cs | 1 +
.../ChatClientAgentOptionsTests.cs | 9 +-
.../ChatCompletion/ChatClientAgentTests.cs | 231 ++++++++++++++++++
.../TestJsonSerializerContext.cs | 14 ++
25 files changed, 981 insertions(+), 32 deletions(-)
create mode 100644 dotnet/samples/GettingStarted/Agents/Agent_Step13_Memory/Agent_Step13_Memory.csproj
create mode 100644 dotnet/samples/GettingStarted/Agents/Agent_Step13_Memory/Program.cs
create mode 100644 dotnet/src/LegacySupport/CompilerFeatureRequiredAttribute/CompilerFeatureRequiredAttribute.cs
create mode 100644 dotnet/src/LegacySupport/CompilerFeatureRequiredAttribute/README.md
create mode 100644 dotnet/src/LegacySupport/RequiredMemberAttribute/README.md
create mode 100644 dotnet/src/LegacySupport/RequiredMemberAttribute/RequiredMemberAttribute.cs
create mode 100644 dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIContext.cs
create mode 100644 dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIContextProvider.cs
create mode 100644 dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIContextProviderTests.cs
create mode 100644 dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIContextTests.cs
create mode 100644 dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/TestJsonSerializerContext.cs
diff --git a/dotnet/agent-framework-dotnet.slnx b/dotnet/agent-framework-dotnet.slnx
index 2b07e3346d..e662856548 100644
--- a/dotnet/agent-framework-dotnet.slnx
+++ b/dotnet/agent-framework-dotnet.slnx
@@ -50,6 +50,7 @@
+
@@ -190,6 +191,10 @@
+
+
+
+
@@ -206,6 +211,10 @@
+
+
+
+
diff --git a/dotnet/eng/MSBuild/LegacySupport.props b/dotnet/eng/MSBuild/LegacySupport.props
index c921bc59a2..54d65288bd 100644
--- a/dotnet/eng/MSBuild/LegacySupport.props
+++ b/dotnet/eng/MSBuild/LegacySupport.props
@@ -22,4 +22,12 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs
index a2713c80d2..04f8cc8b24 100644
--- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs
+++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs
@@ -39,8 +39,7 @@ namespace SampleApp
List responseMessages = CloneAndToUpperCase(messages, this.DisplayName).ToList();
// Notify the thread of the input and output messages.
- await NotifyThreadOfNewMessagesAsync(thread, messages, cancellationToken);
- await NotifyThreadOfNewMessagesAsync(thread, responseMessages, cancellationToken);
+ await NotifyThreadOfNewMessagesAsync(thread, messages.Concat(responseMessages), cancellationToken);
return new AgentRunResponse
{
@@ -59,8 +58,7 @@ namespace SampleApp
List responseMessages = CloneAndToUpperCase(messages, this.DisplayName).ToList();
// Notify the thread of the input and output messages.
- await NotifyThreadOfNewMessagesAsync(thread, messages, cancellationToken);
- await NotifyThreadOfNewMessagesAsync(thread, responseMessages, cancellationToken);
+ await NotifyThreadOfNewMessagesAsync(thread, messages.Concat(responseMessages), cancellationToken);
foreach (var message in responseMessages)
{
diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step13_Memory/Agent_Step13_Memory.csproj b/dotnet/samples/GettingStarted/Agents/Agent_Step13_Memory/Agent_Step13_Memory.csproj
new file mode 100644
index 0000000000..68b0f8d36b
--- /dev/null
+++ b/dotnet/samples/GettingStarted/Agents/Agent_Step13_Memory/Agent_Step13_Memory.csproj
@@ -0,0 +1,22 @@
+
+
+
+ Exe
+ net9.0
+
+ enable
+ disable
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step13_Memory/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step13_Memory/Program.cs
new file mode 100644
index 0000000000..c91c81bd4b
--- /dev/null
+++ b/dotnet/samples/GettingStarted/Agents/Agent_Step13_Memory/Program.cs
@@ -0,0 +1,152 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+// This sample shows how to add a basic custom memory component to an agent.
+// The memory component subscribes to all messages added to the conversation and
+// extracts the user's name and age if provided.
+// The component adds a prompt to ask for this information if it is not already known
+// and provides it to the model before each invocation if known.
+
+using System;
+using System.Linq;
+using System.Text;
+using System.Text.Json;
+using System.Threading;
+using System.Threading.Tasks;
+using Azure.AI.OpenAI;
+using Azure.Identity;
+using Microsoft.Extensions.AI;
+using Microsoft.Extensions.AI.Agents;
+using OpenAI;
+using OpenAI.Chat;
+using SampleApp;
+
+var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new InvalidOperationException("AZURE_OPENAI_ENDPOINT is not set.");
+var deploymentName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOYMENT_NAME") ?? "gpt-4o-mini";
+
+ChatClient chatClient = new AzureOpenAIClient(
+ new Uri(endpoint),
+ new AzureCliCredential())
+ .GetChatClient(deploymentName);
+
+// Create the agent and provide a factory to add our custom memory component to
+// all threads created by the agent. Here each new memory component will have its own
+// user info object, so each thread will have its own memory.
+// In real world applications/services, where the user info would be persisted in a database,
+// and preferably shared between multiple threads used by the same user, ensure that the
+// factory reads the user id from the current context and scopes the memory component
+// and its storage to that user id.
+AIAgent agent = chatClient.CreateAIAgent(new ChatClientAgentOptions()
+{
+ Instructions = "You are a friendly assistant. Always address the user by their name.",
+ AIContextProviderFactory = () => new SampleApp.UserInfoMemory(chatClient.AsIChatClient())
+});
+
+// Create a new thread for the conversation.
+AgentThread thread = agent.GetNewThread();
+
+Console.WriteLine(">> Use thread with blank memory\n");
+
+// Invoke the agent and output the text result.
+Console.WriteLine(await agent.RunAsync("Hello, what is the square root of 9?", thread));
+Console.WriteLine(await agent.RunAsync("My name is Ruaidhrí", thread));
+Console.WriteLine(await agent.RunAsync("I am 20 years old", thread));
+
+// We can serialize the thread. The serialized state will include the state of the memory component.
+var threadElement = await thread.SerializeAsync();
+
+Console.WriteLine("\n>> Use deserialized thread with previously created memories\n");
+
+// Later we can deserialize the thread and continue the conversation with the previous memory component state.
+var deserializedThread = await agent.DeserializeThreadAsync(threadElement);
+Console.WriteLine(await agent.RunAsync("What is my name and age?", deserializedThread));
+
+Console.WriteLine("\n>> Read memories from memory component\n");
+
+// It's possible to access the memory component via the thread's AIContextProvider property.
+var userInfo = ((UserInfoMemory)deserializedThread.AIContextProvider!).UserInfo;
+
+// Output the user info that was captured by the memory component.
+Console.WriteLine($"MEMORY - User Name: {userInfo.UserName}");
+Console.WriteLine($"MEMORY - User Age: {userInfo.UserAge}");
+
+Console.WriteLine("\n>> Use new thread with previously created memories\n");
+
+// Create a new thread.
+thread = agent.GetNewThread();
+
+// It is also possible to add the memory component to an individual thread only instead of all
+// threads via the factory above.
+// In this case we will also use the same user info object, so this thread will share the same
+// memories as the previous thread.
+thread.AIContextProvider = new UserInfoMemory(chatClient.AsIChatClient(), userInfo);
+
+// Invoke the agent and output the text result.
+// This time the agent should remember the user's name and use it in the response.
+Console.WriteLine(await agent.RunAsync("What is my name and age?", thread));
+
+namespace SampleApp
+{
+ ///
+ /// Sample memory component that can remember a user's name and age.
+ ///
+ internal sealed class UserInfoMemory(IChatClient chatClient, UserInfo? userInfo = null) : AIContextProvider
+ {
+ public UserInfo UserInfo { get; set; } = userInfo ?? new();
+
+ public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default)
+ {
+ // Try and extract the user name and age from the message if we don't have it already and it's a user message.
+ if ((this.UserInfo.UserName == null || this.UserInfo.UserAge == null) && context.RequestMessages.Any(x => x.Role == ChatRole.User))
+ {
+ var result = await chatClient.GetResponseAsync(
+ context.RequestMessages,
+ new ChatOptions()
+ {
+ Instructions = "Extract the user's name and age from the message if present. If not present return nulls."
+ },
+ cancellationToken: cancellationToken);
+
+ this.UserInfo.UserName ??= result.Result.UserName;
+ this.UserInfo.UserAge ??= result.Result.UserAge;
+ }
+ }
+
+ public override ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default)
+ {
+ StringBuilder instructions = new();
+
+ // If we don't already know the user's name and age, add instructions to ask for them, otherwise just provide what we have to the context.
+ instructions.AppendLine(
+ this.UserInfo.UserName == null ?
+ "Ask the user for their name and politely decline to answer any questions until they provide it." :
+ $"The user's name is {this.UserInfo.UserName}.");
+
+ instructions.AppendLine(
+ this.UserInfo.UserAge == null ?
+ "Ask the user for their age and politely decline to answer any questions until they provide it." :
+ $"The user's age is {this.UserInfo.UserAge}.");
+
+ return new ValueTask(new AIContext
+ {
+ Instructions = instructions.ToString()
+ });
+ }
+
+ public override ValueTask SerializeAsync(JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
+ {
+ return new ValueTask(JsonSerializer.SerializeToElement(this.UserInfo, jsonSerializerOptions));
+ }
+
+ public override ValueTask DeserializeAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
+ {
+ this.UserInfo = JsonSerializer.Deserialize(serializedState, jsonSerializerOptions) ?? new UserInfo();
+ return default;
+ }
+ }
+
+ internal sealed class UserInfo
+ {
+ public string? UserName { get; set; }
+ public int? UserAge { get; set; }
+ }
+}
diff --git a/dotnet/samples/GettingStarted/Agents/README.md b/dotnet/samples/GettingStarted/Agents/README.md
index 43da38e0ff..d831fab364 100644
--- a/dotnet/samples/GettingStarted/Agents/README.md
+++ b/dotnet/samples/GettingStarted/Agents/README.md
@@ -37,7 +37,8 @@ Before you begin, ensure you have the following prerequisites:
|[Dependency injection with a simple agent](./Agent_Step09_DependencyInjection/)|This sample demonstrates how to add and resolve an agent with a dependency injection container|
|[Exposing a simple agent as MCP tool](./Agent_Step10_AsMcpTool/)|This sample demonstrates how to expose an agent as an MCP tool|
|[Using images with a simple agent](./Agent_Step11_UsingImages/)|This sample demonstrates how to use image multi-modality with an AI agent|
-|[Exposing a simple agent a function tool](./Agent_Step12_AsFunctionTool/)|This sample demonstrates how to expose an agent as a function tool|
+|[Exposing a simple agent as a function tool](./Agent_Step12_AsFunctionTool/)|This sample demonstrates how to expose an agent as a function tool|
+|[Using memory with an agent](./Agent_Step12_Memory/)|This sample demonstrates how to create a simple memory component and use it with an agent|
## Running the samples from the console
diff --git a/dotnet/src/LegacySupport/CompilerFeatureRequiredAttribute/CompilerFeatureRequiredAttribute.cs b/dotnet/src/LegacySupport/CompilerFeatureRequiredAttribute/CompilerFeatureRequiredAttribute.cs
new file mode 100644
index 0000000000..74a67d7c6e
--- /dev/null
+++ b/dotnet/src/LegacySupport/CompilerFeatureRequiredAttribute/CompilerFeatureRequiredAttribute.cs
@@ -0,0 +1,37 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+#pragma warning disable SA1623 // Property summary documentation should match accessors
+
+namespace System.Runtime.CompilerServices;
+
+///
+/// Indicates that compiler support for a particular feature is required for the location where this attribute is applied.
+///
+[AttributeUsage(AttributeTargets.All, AllowMultiple = true, Inherited = false)]
+internal sealed class CompilerFeatureRequiredAttribute : Attribute
+{
+ public CompilerFeatureRequiredAttribute(string featureName)
+ {
+ FeatureName = featureName;
+ }
+
+ ///
+ /// The name of the compiler feature.
+ ///
+ public string FeatureName { get; }
+
+ ///
+ /// If true, the compiler can choose to allow access to the location where this attribute is applied if it does not understand .
+ ///
+ public bool IsOptional { get; init; }
+
+ ///
+ /// The used for the ref structs C# feature.
+ ///
+ public const string RefStructs = nameof(RefStructs);
+
+ ///
+ /// The used for the required members C# feature.
+ ///
+ public const string RequiredMembers = nameof(RequiredMembers);
+}
diff --git a/dotnet/src/LegacySupport/CompilerFeatureRequiredAttribute/README.md b/dotnet/src/LegacySupport/CompilerFeatureRequiredAttribute/README.md
new file mode 100644
index 0000000000..c30799eef0
--- /dev/null
+++ b/dotnet/src/LegacySupport/CompilerFeatureRequiredAttribute/README.md
@@ -0,0 +1,9 @@
+Enables use of C# required members on older frameworks.
+
+To use this source in your project, add the following to your `.csproj` file:
+
+```xml
+
+ true
+
+```
diff --git a/dotnet/src/LegacySupport/RequiredMemberAttribute/README.md b/dotnet/src/LegacySupport/RequiredMemberAttribute/README.md
new file mode 100644
index 0000000000..da8c9bc98c
--- /dev/null
+++ b/dotnet/src/LegacySupport/RequiredMemberAttribute/README.md
@@ -0,0 +1,9 @@
+Enables use of C# required members on older frameworks.
+
+To use this source in your project, add the following to your `.csproj` file:
+
+```xml
+
+ true
+
+```
diff --git a/dotnet/src/LegacySupport/RequiredMemberAttribute/RequiredMemberAttribute.cs b/dotnet/src/LegacySupport/RequiredMemberAttribute/RequiredMemberAttribute.cs
new file mode 100644
index 0000000000..1a82954140
--- /dev/null
+++ b/dotnet/src/LegacySupport/RequiredMemberAttribute/RequiredMemberAttribute.cs
@@ -0,0 +1,10 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.ComponentModel;
+
+namespace System.Runtime.CompilerServices;
+
+/// Specifies that a type has required members or that a member is required.
+[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = false)]
+[EditorBrowsable(EditorBrowsableState.Never)]
+internal sealed class RequiredMemberAttribute : Attribute;
diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIAgent.cs b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIAgent.cs
index 395f075029..b744a17892 100644
--- a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIAgent.cs
+++ b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIAgent.cs
@@ -290,6 +290,6 @@ public abstract class AIAgent
_ = Throw.IfNull(thread);
_ = Throw.IfNull(messages);
- await thread.OnNewMessagesAsync(messages, cancellationToken).ConfigureAwait(false);
+ await thread.MessagesReceivedAsync(messages, cancellationToken).ConfigureAwait(false);
}
}
diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIContext.cs b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIContext.cs
new file mode 100644
index 0000000000..f0c8e64bff
--- /dev/null
+++ b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIContext.cs
@@ -0,0 +1,43 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.Collections.Generic;
+
+namespace Microsoft.Extensions.AI.Agents;
+
+///
+/// A class containing any context that should be provided to the AI model
+/// as supplied by an .
+///
+///
+/// Each has the ability to provide its own context for each invocation.
+/// The class contains the additional context supplied by the .
+/// This context will be combined with context supplied by other providers before being passed to the AI model.
+///
+public sealed class AIContext
+{
+ ///
+ /// Gets or sets any instructions to pass to the AI model in addition to any other prompts
+ /// that it may already have (in the case of an agent), or chat history that may
+ /// already exist.
+ ///
+ ///
+ /// These instructions will be transient and only apply to the current invocation.
+ ///
+ public string? Instructions { get; set; }
+
+ ///
+ /// Gets or sets a list of messages to add to the chat history.
+ ///
+ ///
+ /// These messages will permanently be added to the chat history.
+ ///
+ public IList? Messages { get; set; }
+
+ ///
+ /// Gets or sets a list of functions/tools to make available to the AI model for the current invocation.
+ ///
+ ///
+ /// These functions/tools will be transient and only apply to the current invocation.
+ ///
+ public IList? Tools { get; set; }
+}
diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIContextProvider.cs b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIContextProvider.cs
new file mode 100644
index 0000000000..fc6d244051
--- /dev/null
+++ b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AIContextProvider.cs
@@ -0,0 +1,118 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Collections.Generic;
+using System.Text.Json;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Microsoft.Extensions.AI.Agents;
+
+///
+/// Base class for all AI context providers.
+///
+///
+/// An AI context provider is a component that can be used to enhance the AI's context management.
+/// It can listen to changes in the conversation, provide additional context to
+/// the Model/Agent/etc. just before invocation and supply additional function tools.
+///
+public abstract class AIContextProvider
+{
+ ///
+ /// Called just before the Model/Agent/etc. is invoked
+ /// Implementers can load any additional context required at this time,
+ /// and they should return any context that should be passed to the Model/Agent/etc.
+ ///
+ /// Contains the event context.
+ /// The to monitor for cancellation requests. The default is .
+ /// A task that completes when the context has been rendered and returned.
+ public abstract ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default);
+
+ ///
+ /// Called just before the Model/Agent/etc. is invoked
+ /// Implementers can load any additional context required at this time,
+ /// and they should return any context that should be passed to the Model/Agent/etc.
+ ///
+ /// Contains the event context.
+ /// The to monitor for cancellation requests. The default is .
+ /// A task that completes when the context has been rendered and returned.
+ public virtual ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default)
+ {
+ return default;
+ }
+
+ ///
+ /// Serializes the current object's state to a using the specified serialization options.
+ ///
+ /// The JSON serialization options to use.
+ /// The to monitor for cancellation requests. The default is .
+ /// A representation of the object's state.
+ public virtual ValueTask SerializeAsync(JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
+ {
+ return default;
+ }
+
+ ///
+ /// Deserializes the state contained in the provided into the properties on this object.
+ ///
+ /// A representing the state of the object.
+ /// Optional settings for customizing the JSON deserialization process.
+ /// The to monitor for cancellation requests. The default is .
+ /// A that completes when the state has been deserialized.
+ public virtual ValueTask DeserializeAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
+ {
+ return default;
+ }
+
+ ///
+ /// Contains the event context provided to .
+ ///
+ public class InvokingContext
+ {
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The messages to be sent to the Model/Agent/etc. for this invocation.
+ /// Thrown if is .
+ public InvokingContext(IEnumerable requestMessages)
+ {
+ RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages));
+ }
+
+ ///
+ /// Gets the messages that will be sent to the Model/Agent/etc. for this invocation.
+ ///
+ public IEnumerable RequestMessages { get; private set; }
+ }
+
+ ///
+ /// Contains the event conext provided to .
+ ///
+ public class InvokedContext
+ {
+ ///
+ /// Initializes a new instance of the class with the specified request messages.
+ ///
+ /// The messages that were sent to the Model/Agent/etc. for this invocation.
+ /// Thrown if is .
+ public InvokedContext(IEnumerable requestMessages)
+ {
+ RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages));
+ }
+
+ ///
+ /// Gets the messages that were sent to the Model/Agent/etc. for this invocation.
+ ///
+ public IEnumerable RequestMessages { get; private set; }
+
+ ///
+ /// Gets the collection of response messages generated by Model/Agent/etc. if the invocation succeeded.
+ ///
+ public IEnumerable? ResponseMessages { get; init; }
+
+ ///
+ /// Gets the that was thrown during the invocation, if the invocation failed.
+ ///
+ public Exception? InvokeException { get; init; }
+ }
+}
diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AgentThread.cs b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AgentThread.cs
index 74a5c8108e..9b1fbcdb5b 100644
--- a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AgentThread.cs
+++ b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/AgentThread.cs
@@ -27,7 +27,7 @@ public class AgentThread
}
///
- /// Gets or sets the id of the current thread to support cases where the thread is owned by the agent service.
+ /// Gets or sets the ID of the underlying service thread to support cases where the chat history is stored by the agent service.
///
///
///
@@ -108,6 +108,11 @@ public class AgentThread
}
}
+ ///
+ /// Gets or sets the used by this thread to provide additional context to the AI model before each invocation.
+ ///
+ public AIContextProvider? AIContextProvider { get; set; }
+
///
/// Serializes the current object's state to a using the specified serialization options.
///
@@ -120,10 +125,15 @@ public class AgentThread
null :
await this._messageStore.SerializeStateAsync(jsonSerializerOptions, cancellationToken).ConfigureAwait(false);
+ var aiContextProviderState = this.AIContextProvider is null ?
+ null :
+ await this.AIContextProvider.SerializeAsync(jsonSerializerOptions, cancellationToken).ConfigureAwait(false);
+
var state = new ThreadState
{
ConversationId = this.ConversationId,
- StoreState = storeState
+ StoreState = storeState,
+ AIContextProviderState = aiContextProviderState
};
return JsonSerializer.SerializeToElement(state, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ThreadState)));
@@ -139,7 +149,7 @@ public class AgentThread
/// The to monitor for cancellation requests. The default is .
/// A task that completes when the context has been updated.
/// The thread has been deleted.
- protected internal virtual async Task OnNewMessagesAsync(IEnumerable newMessages, CancellationToken cancellationToken = default)
+ protected internal virtual async Task MessagesReceivedAsync(IEnumerable newMessages, CancellationToken cancellationToken = default)
{
switch (this)
{
@@ -186,6 +196,11 @@ public class AgentThread
return;
}
+ if (state?.AIContextProviderState.HasValue is true && this.AIContextProvider is not null)
+ {
+ await this.AIContextProvider.DeserializeAsync(state.AIContextProviderState.Value, jsonSerializerOptions, cancellationToken).ConfigureAwait(false);
+ }
+
// If we don't have any IChatMessageStore state return here.
if (state?.StoreState is null || state?.StoreState.Value.ValueKind is JsonValueKind.Undefined or JsonValueKind.Null)
{
@@ -206,5 +221,7 @@ public class AgentThread
public string? ConversationId { get; set; }
public JsonElement? StoreState { get; set; }
+
+ public JsonElement? AIContextProviderState { get; set; }
}
}
diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/Microsoft.Extensions.AI.Agents.Abstractions.csproj b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/Microsoft.Extensions.AI.Agents.Abstractions.csproj
index 0bdc6a5b49..c4f910dbec 100644
--- a/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/Microsoft.Extensions.AI.Agents.Abstractions.csproj
+++ b/dotnet/src/Microsoft.Extensions.AI.Agents.Abstractions/Microsoft.Extensions.AI.Agents.Abstractions.csproj
@@ -12,6 +12,9 @@
true
true
true
+ true
+ true
+ true
diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgent.cs b/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgent.cs
index 3deafcdc6f..5878cc71c9 100644
--- a/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgent.cs
+++ b/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgent.cs
@@ -114,7 +114,17 @@ public sealed class ChatClientAgent : AIAgent
this._logger.LogAgentChatClientInvokingAgent(nameof(RunAsync), this.Id, agentName, this._chatClientType);
- ChatResponse chatResponse = await this.ChatClient.GetResponseAsync(threadMessages, chatOptions, cancellationToken).ConfigureAwait(false);
+ // Call the IChatClient and notify the AIContextProvider of any failures.
+ ChatResponse chatResponse;
+ try
+ {
+ chatResponse = await this.ChatClient.GetResponseAsync(threadMessages, chatOptions, cancellationToken).ConfigureAwait(false);
+ }
+ catch (Exception ex)
+ {
+ await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, cancellationToken).ConfigureAwait(false);
+ throw;
+ }
this._logger.LogAgentChatClientInvokedAgent(nameof(RunAsync), this.Id, agentName, this._chatClientType, inputMessages.Count);
@@ -122,19 +132,17 @@ public sealed class ChatClientAgent : AIAgent
// so let's update it and set the conversation id for the service thread case.
this.UpdateThreadWithTypeAndConversationId(safeThread, chatResponse.ConversationId);
- // Only notify the thread of new messages if the chatResponse was successful to avoid inconsistent messages state in the thread.
- await NotifyThreadOfNewMessagesAsync(safeThread, inputMessages, cancellationToken).ConfigureAwait(false);
-
// Ensure that the author name is set for each message in the response.
foreach (ChatMessage chatResponseMessage in chatResponse.Messages)
{
chatResponseMessage.AuthorName ??= agentName;
}
- // Convert the chat response messages to a valid IReadOnlyCollection for notification signatures below.
- var chatResponseMessages = chatResponse.Messages as IReadOnlyCollection ?? [.. chatResponse.Messages];
+ // Only notify the thread of new messages if the chatResponse was successful to avoid inconsistent message state in the thread.
+ await NotifyThreadOfNewMessagesAsync(safeThread, inputMessages.Concat(chatResponse.Messages), cancellationToken).ConfigureAwait(false);
- await NotifyThreadOfNewMessagesAsync(safeThread, chatResponseMessages, cancellationToken).ConfigureAwait(false);
+ // Notify the AIContextProvider of all new messages.
+ await NotifyAIContextProviderOfSuccessAsync(safeThread, inputMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
return new(chatResponse) { AgentId = this.Id };
}
@@ -156,15 +164,34 @@ public sealed class ChatClientAgent : AIAgent
this._logger.LogAgentChatClientInvokingAgent(nameof(RunStreamingAsync), this.Id, loggingAgentName, this._chatClientType);
- // Using the enumerator to ensure we consider the case where no updates are returned for notification.
- var responseUpdatesEnumerator = this.ChatClient.GetStreamingResponseAsync(threadMessages, chatOptions, cancellationToken).GetAsyncEnumerator(cancellationToken);
+ List responseUpdates = [];
+
+ IAsyncEnumerator responseUpdatesEnumerator;
+
+ try
+ {
+ // Using the enumerator to ensure we consider the case where no updates are returned for notification.
+ responseUpdatesEnumerator = this.ChatClient.GetStreamingResponseAsync(threadMessages, chatOptions, cancellationToken).GetAsyncEnumerator(cancellationToken);
+ }
+ catch (Exception ex)
+ {
+ await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, cancellationToken).ConfigureAwait(false);
+ throw;
+ }
this._logger.LogAgentChatClientInvokedStreamingAgent(nameof(RunStreamingAsync), this.Id, loggingAgentName, this._chatClientType);
- List responseUpdates = [];
-
- // Ensure we start the streaming request
- var hasUpdates = await responseUpdatesEnumerator.MoveNextAsync().ConfigureAwait(false);
+ bool hasUpdates;
+ try
+ {
+ // Ensure we start the streaming request
+ hasUpdates = await responseUpdatesEnumerator.MoveNextAsync().ConfigureAwait(false);
+ }
+ catch (Exception ex)
+ {
+ await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, cancellationToken).ConfigureAwait(false);
+ throw;
+ }
while (hasUpdates)
{
@@ -176,20 +203,28 @@ public sealed class ChatClientAgent : AIAgent
yield return new(update) { AgentId = this.Id };
}
- hasUpdates = await responseUpdatesEnumerator.MoveNextAsync().ConfigureAwait(false);
+ try
+ {
+ hasUpdates = await responseUpdatesEnumerator.MoveNextAsync().ConfigureAwait(false);
+ }
+ catch (Exception ex)
+ {
+ await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, cancellationToken).ConfigureAwait(false);
+ throw;
+ }
}
var chatResponse = responseUpdates.ToChatResponse();
- var chatResponseMessages = chatResponse.Messages as IReadOnlyCollection ?? [.. chatResponse.Messages];
// We can derive the type of supported thread from whether we have a conversation id,
// so let's update it and set the conversation id for the service thread case.
this.UpdateThreadWithTypeAndConversationId(safeThread, chatResponse.ConversationId);
// To avoid inconsistent state we only notify the thread of the input messages if no error occurs after the initial request.
- await NotifyThreadOfNewMessagesAsync(safeThread, inputMessages, cancellationToken).ConfigureAwait(false);
+ await NotifyThreadOfNewMessagesAsync(safeThread, inputMessages.Concat(chatResponse.Messages), cancellationToken).ConfigureAwait(false);
- await NotifyThreadOfNewMessagesAsync(safeThread, chatResponseMessages, cancellationToken).ConfigureAwait(false);
+ // Notify the AIContextProvider of all new messages.
+ await NotifyAIContextProviderOfSuccessAsync(safeThread, inputMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
}
///
@@ -204,12 +239,40 @@ public sealed class ChatClientAgent : AIAgent
///
public override AgentThread GetNewThread()
{
- var thread = new AgentThread { MessageStore = this._agentOptions?.ChatMessageStoreFactory?.Invoke() };
+ var thread = new AgentThread
+ {
+ MessageStore = this._agentOptions?.ChatMessageStoreFactory?.Invoke(),
+ AIContextProvider = this._agentOptions?.AIContextProviderFactory?.Invoke()
+ };
return thread;
}
#region Private
+ ///
+ /// Notify the when an agent run succeeded, if there is an .
+ ///
+ private static async Task NotifyAIContextProviderOfSuccessAsync(AgentThread thread, IEnumerable inputMessages, IEnumerable responseMessages, CancellationToken cancellationToken)
+ {
+ if (thread.AIContextProvider is not null)
+ {
+ await thread.AIContextProvider.InvokedAsync(new(inputMessages) { ResponseMessages = responseMessages },
+ cancellationToken).ConfigureAwait(false);
+ }
+ }
+
+ ///
+ /// Notify the of any failure during an agent run, if there is an .
+ ///
+ private static async Task NotifyAIContextProviderOfFailureAsync(AgentThread thread, Exception ex, IEnumerable inputMessages, CancellationToken cancellationToken)
+ {
+ if (thread.AIContextProvider is not null)
+ {
+ await thread.AIContextProvider.InvokedAsync(new(inputMessages) { InvokeException = ex },
+ cancellationToken).ConfigureAwait(false);
+ }
+ }
+
///
/// Configures and returns chat options by merging the provided run options with the agent's default chat options.
///
@@ -350,6 +413,34 @@ public sealed class ChatClientAgent : AIAgent
threadMessages.AddRange(await thread.MessageStore.GetMessagesAsync(cancellationToken).ConfigureAwait(false));
}
+ // If we have an AIContextProvider, we should get context from it, and update our
+ // messages and options with the additional context.
+ if (thread.AIContextProvider is not null)
+ {
+ var invokingContext = new AIContextProvider.InvokingContext(inputMessages);
+ var aiContext = await thread.AIContextProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false);
+ if (aiContext.Messages is { Count: > 0 })
+ {
+ threadMessages.AddRange(aiContext.Messages);
+ }
+
+ if (aiContext.Tools is { Count: > 0 })
+ {
+ chatOptions ??= new();
+ chatOptions.Tools ??= [];
+ foreach (AITool tool in aiContext.Tools)
+ {
+ chatOptions.Tools.Add(tool);
+ }
+ }
+
+ if (aiContext.Instructions is not null)
+ {
+ chatOptions ??= new();
+ chatOptions.Instructions = string.IsNullOrWhiteSpace(chatOptions.Instructions) ? aiContext.Instructions : $"{chatOptions.Instructions}\n{aiContext.Instructions}";
+ }
+ }
+
// Add the input messages to the end of thread messages.
threadMessages.AddRange(inputMessages);
diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgentOptions.cs b/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgentOptions.cs
index 2759523cc2..461ed01aa3 100644
--- a/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgentOptions.cs
+++ b/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientAgentOptions.cs
@@ -80,6 +80,13 @@ public class ChatClientAgentOptions
///
public Func? ChatMessageStoreFactory { get; set; }
+ ///
+ /// Gets or sets a factory function to create an instance of
+ /// which will be used to create a context provider for each new thread, and can then
+ /// provide additional context for each agent run.
+ ///
+ public Func? AIContextProviderFactory { get; set; }
+
///
/// Gets or sets a value indicating whether to use the provided instance as is,
/// without applying any default decorators.
@@ -105,6 +112,7 @@ public class ChatClientAgentOptions
Instructions = this.Instructions,
Description = this.Description,
ChatOptions = this.ChatOptions?.Clone(),
- ChatMessageStoreFactory = this.ChatMessageStoreFactory
+ ChatMessageStoreFactory = this.ChatMessageStoreFactory,
+ AIContextProviderFactory = this.AIContextProviderFactory,
};
}
diff --git a/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIAgentTests.cs b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIAgentTests.cs
index 04cd8584c8..9906b62a21 100644
--- a/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIAgentTests.cs
+++ b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIAgentTests.cs
@@ -235,7 +235,7 @@ public class AIAgentTests
await MockAgent.NotifyThreadOfNewMessagesAsync(threadMock.Object, messages, cancellationToken);
- threadMock.Protected().Verify("OnNewMessagesAsync", Times.Once(), messages, cancellationToken);
+ threadMock.Protected().Verify("MessagesReceivedAsync", Times.Once(), messages, cancellationToken);
}
#region GetService Method Tests
diff --git a/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIContextProviderTests.cs b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIContextProviderTests.cs
new file mode 100644
index 0000000000..f5eea0d5ed
--- /dev/null
+++ b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIContextProviderTests.cs
@@ -0,0 +1,56 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.Collections.Generic;
+using System.Collections.ObjectModel;
+using System.Text.Json;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Microsoft.Extensions.AI.Agents.Abstractions.UnitTests;
+
+public class AIContextProviderTests
+{
+ [Fact]
+ public async Task InvokedAsync_ReturnsCompletedTaskAsync()
+ {
+ var provider = new TestAIContextProvider();
+ var messages = new ReadOnlyCollection(new List());
+ var task = provider.InvokedAsync(new(messages));
+ Assert.Equal(default, task);
+ }
+
+ [Fact]
+ public async Task SerializeAsync_ReturnsEmptyElementAsync()
+ {
+ var provider = new TestAIContextProvider();
+ var actual = await provider.SerializeAsync();
+ Assert.Equal(default, actual);
+ }
+
+ [Fact]
+ public async Task DeserializeAsync_ReturnsCompletedTaskAsync()
+ {
+ var provider = new TestAIContextProvider();
+ var element = default(JsonElement);
+ var actual = provider.DeserializeAsync(element);
+ Assert.Equal(default, actual);
+ }
+
+ private sealed class TestAIContextProvider : AIContextProvider
+ {
+ public override ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default)
+ {
+ return default;
+ }
+
+ public override async ValueTask SerializeAsync(JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
+ {
+ return await base.SerializeAsync(jsonSerializerOptions, cancellationToken);
+ }
+
+ public override async ValueTask DeserializeAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
+ {
+ await base.DeserializeAsync(serializedState, jsonSerializerOptions, cancellationToken);
+ }
+ }
+}
diff --git a/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIContextTests.cs b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIContextTests.cs
new file mode 100644
index 0000000000..2a85a525e5
--- /dev/null
+++ b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AIContextTests.cs
@@ -0,0 +1,58 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.Collections.Generic;
+
+namespace Microsoft.Extensions.AI.Agents.Abstractions.UnitTests;
+
+///
+/// Unit tests for .
+///
+public class AIContextTests
+{
+ [Fact]
+ public void SetInstructionsRoundtrips()
+ {
+ var context = new AIContext
+ {
+ Instructions = "Test Instructions"
+ };
+
+ Assert.Equal("Test Instructions", context.Instructions);
+ }
+
+ [Fact]
+ public void SetMessagesRoundtrips()
+ {
+ var context = new AIContext
+ {
+ Messages = new List
+ {
+ new(ChatRole.User, "Hello"),
+ new(ChatRole.Assistant, "Hi there!")
+ }
+ };
+
+ Assert.NotNull(context.Messages);
+ Assert.Equal(2, context.Messages.Count);
+ Assert.Equal("Hello", context.Messages[0].Text);
+ Assert.Equal("Hi there!", context.Messages[1].Text);
+ }
+
+ [Fact]
+ public void SetAIFunctionsRoundtrips()
+ {
+ var context = new AIContext
+ {
+ Tools = new List
+ {
+ AIFunctionFactory.Create(() => "Function1", "Function1", "Description1"),
+ AIFunctionFactory.Create(() => "Function2", "Function2", "Description2"),
+ }
+ };
+
+ Assert.NotNull(context.Tools);
+ Assert.Equal(2, context.Tools.Count);
+ Assert.Equal("Function1", context.Tools[0].Name);
+ Assert.Equal("Function2", context.Tools[1].Name);
+ }
+}
diff --git a/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AgentThreadTests.cs b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AgentThreadTests.cs
index 305fc1cd45..843ea3fb0f 100644
--- a/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AgentThreadTests.cs
+++ b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/AgentThreadTests.cs
@@ -8,6 +8,8 @@ using System.Threading;
using System.Threading.Tasks;
using Moq;
+#pragma warning disable CA1861 // Avoid constant arrays as arguments
+
namespace Microsoft.Extensions.AI.Agents.Abstractions.UnitTests;
public class AgentThreadTests
@@ -102,7 +104,7 @@ public class AgentThreadTests
};
// Act
- await thread.OnNewMessagesAsync(messages, CancellationToken.None);
+ await thread.MessagesReceivedAsync(messages, CancellationToken.None);
Assert.Equal("thread-123", thread.ConversationId);
Assert.Null(thread.MessageStore);
}
@@ -120,7 +122,7 @@ public class AgentThreadTests
};
// Act
- await thread.OnNewMessagesAsync(messages, CancellationToken.None);
+ await thread.MessagesReceivedAsync(messages, CancellationToken.None);
// Assert
Assert.Equal(2, store.Count);
@@ -173,6 +175,26 @@ public class AgentThreadTests
Assert.Null(thread.MessageStore);
}
+ [Fact]
+ public async Task VerifyDeserializeWithAIContextProviderAsync()
+ {
+ // Arrange
+ var json = JsonSerializer.Deserialize("""
+ {
+ "aiContextProviderState": ["CP1"]
+ }
+ """, TestJsonSerializerContext.Default.JsonElement);
+ Mock mockProvider = new();
+ var thread = new AgentThread() { AIContextProvider = mockProvider.Object };
+
+ // Act
+ await thread.DeserializeAsync(json);
+
+ // Assert
+ Assert.Null(thread.MessageStore);
+ mockProvider.Verify(m => m.DeserializeAsync(It.Is(e => e.ValueKind == JsonValueKind.Array && e.GetArrayLength() == 1), It.IsAny(), It.IsAny()), Times.Once);
+ }
+
[Fact]
public async Task DeserializeWithInvalidJsonThrowsAsync()
{
@@ -245,6 +267,31 @@ public class AgentThreadTests
Assert.Equal("TestContent", textContent.GetProperty("text").GetString());
}
+ [Fact]
+ public async Task VerifyThreadSerializationWithWithAIContextProviderAsync()
+ {
+ // Arrange
+ Mock mockProvider = new();
+ var providerStateElement = JsonSerializer.SerializeToElement(new[] { "CP1" }, TestJsonSerializerContext.Default.StringArray);
+ mockProvider
+ .Setup(m => m.SerializeAsync(It.IsAny(), It.IsAny()))
+ .ReturnsAsync(providerStateElement);
+
+ var thread = new AgentThread();
+ thread.AIContextProvider = mockProvider.Object;
+
+ // Act
+ var json = await thread.SerializeAsync();
+
+ // Assert
+ Assert.Equal(JsonValueKind.Object, json.ValueKind);
+ Assert.True(json.TryGetProperty("aiContextProviderState", out var providerStateProperty));
+ Assert.Equal(JsonValueKind.Array, providerStateProperty.ValueKind);
+ Assert.Single(providerStateProperty.EnumerateArray());
+ Assert.Equal("CP1", providerStateProperty.EnumerateArray().First().GetString());
+ mockProvider.Verify(m => m.SerializeAsync(It.IsAny(), It.IsAny()), Times.Once);
+ }
+
///
/// Verify thread serialization to JSON with custom options.
///
diff --git a/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/TestJsonSerializerContext.cs b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/TestJsonSerializerContext.cs
index bd39c6638a..f256ad311c 100644
--- a/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/TestJsonSerializerContext.cs
+++ b/dotnet/tests/Microsoft.Extensions.AI.Agents.Abstractions.UnitTests/TestJsonSerializerContext.cs
@@ -17,4 +17,5 @@ namespace Microsoft.Extensions.AI.Agents.Abstractions.UnitTests;
[JsonSerializable(typeof(Animal))]
[JsonSerializable(typeof(JsonElement))]
[JsonSerializable(typeof(Dictionary))]
+[JsonSerializable(typeof(string[]))]
internal sealed partial class TestJsonSerializerContext : JsonSerializerContext;
diff --git a/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/ChatCompletion/ChatClientAgentOptionsTests.cs b/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/ChatCompletion/ChatClientAgentOptionsTests.cs
index ead3946fc2..d8dd5d52fc 100644
--- a/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/ChatCompletion/ChatClientAgentOptionsTests.cs
+++ b/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/ChatCompletion/ChatClientAgentOptionsTests.cs
@@ -22,6 +22,7 @@ public class ChatClientAgentOptionsTests
Assert.Null(options.Description);
Assert.Null(options.ChatOptions);
Assert.Null(options.ChatMessageStoreFactory);
+ Assert.Null(options.AIContextProviderFactory);
}
[Fact]
@@ -39,6 +40,7 @@ public class ChatClientAgentOptionsTests
Assert.Null(options.Instructions);
Assert.Null(options.Description);
Assert.Null(options.ChatOptions);
+ Assert.Null(options.AIContextProviderFactory);
}
[Fact]
@@ -163,11 +165,13 @@ public class ChatClientAgentOptionsTests
const string Description = "Test description";
var tools = new List { AIFunctionFactory.Create(() => "test") };
static IChatMessageStore ChatMessageStoreFactory() => new Mock().Object;
+ static AIContextProvider AIContextProviderFactory() => new Mock().Object;
var original = new ChatClientAgentOptions(Instructions, Name, Description, tools)
{
Id = "test-id",
- ChatMessageStoreFactory = ChatMessageStoreFactory
+ ChatMessageStoreFactory = ChatMessageStoreFactory,
+ AIContextProviderFactory = AIContextProviderFactory
};
// Act
@@ -180,6 +184,7 @@ public class ChatClientAgentOptionsTests
Assert.Equal(original.Instructions, clone.Instructions);
Assert.Equal(original.Description, clone.Description);
Assert.Same(original.ChatMessageStoreFactory, clone.ChatMessageStoreFactory);
+ Assert.Same(original.AIContextProviderFactory, clone.AIContextProviderFactory);
// ChatOptions should be cloned, not the same reference
Assert.NotSame(original.ChatOptions, clone.ChatOptions);
@@ -209,5 +214,7 @@ public class ChatClientAgentOptionsTests
Assert.Equal(original.Instructions, clone.Instructions);
Assert.Equal(original.Description, clone.Description);
Assert.Null(clone.ChatOptions);
+ Assert.Null(clone.ChatMessageStoreFactory);
+ Assert.Null(clone.AIContextProviderFactory);
}
}
diff --git a/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/ChatCompletion/ChatClientAgentTests.cs b/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/ChatCompletion/ChatClientAgentTests.cs
index 46afaace36..4ec63e04cd 100644
--- a/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/ChatCompletion/ChatClientAgentTests.cs
+++ b/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/ChatCompletion/ChatClientAgentTests.cs
@@ -10,6 +10,8 @@ namespace Microsoft.Extensions.AI.Agents.UnitTests.ChatCompletion;
public class ChatClientAgentTests
{
+ #region Constructor Tests
+
///
/// Verify the invocation and response of .
///
@@ -38,6 +40,10 @@ public class ChatClientAgentTests
Assert.Equal("AgentInvokedChatClient", agent.ChatClient.GetType().Name);
}
+ #endregion
+
+ #region RunAsync Tests
+
///
/// Verify the invocation and response of using .
///
@@ -390,6 +396,176 @@ public class ChatClientAgentTests
await Assert.ThrowsAsync(() => agent.RunAsync([new(ChatRole.User, "test")], thread));
}
+ ///
+ /// Verify that RunAsync sets the ConversationId on the thread when the service returns one.
+ ///
+ [Fact]
+ public async Task RunAsyncSetsConversationIdOnThreadWhenReturnedByChatClientAsync()
+ {
+ // Arrange
+ Mock mockService = new();
+ mockService.Setup(
+ s => s.GetResponseAsync(
+ It.IsAny>(),
+ It.IsAny(),
+ It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")]) { ConversationId = "ConvId" });
+ ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "test instructions" });
+ AgentThread thread = new();
+
+ // Act
+ await agent.RunAsync([new(ChatRole.User, "test")], thread);
+
+ // Assert
+ Assert.Equal("ConvId", thread.ConversationId);
+ }
+
+ ///
+ /// Verify that RunAsync invokes any provided AIContextProvider and uses the result.
+ ///
+ [Fact]
+ public async Task RunAsyncInvokesAIContextProviderAndUsesResultAsync()
+ {
+ // Arrange
+ ChatMessage[] requestMessages = [new(ChatRole.User, "user message")];
+ ChatMessage[] responseMessages = [new(ChatRole.Assistant, "response")];
+ Mock mockService = new();
+ List capturedMessages = [];
+ string capturedInstructions = string.Empty;
+ List capturedTools = [];
+ mockService
+ .Setup(s => s.GetResponseAsync(
+ It.IsAny>(),
+ It.IsAny(),
+ It.IsAny()))
+ .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) =>
+ {
+ capturedMessages.AddRange(msgs);
+ capturedInstructions = opts.Instructions ?? string.Empty;
+ if (opts.Tools != null)
+ {
+ capturedTools.AddRange(opts.Tools);
+ }
+ })
+ .ReturnsAsync(new ChatResponse(responseMessages));
+
+ var mockProvider = new Mock();
+ mockProvider
+ .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny()))
+ .ReturnsAsync(new AIContext
+ {
+ Messages = [new(ChatRole.System, "context provider message")],
+ Instructions = "context provider instructions",
+ Tools = [AIFunctionFactory.Create(() => { }, "context provider function")]
+ });
+ mockProvider
+ .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny()))
+ .Returns(new ValueTask());
+
+ ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "base instructions", AIContextProviderFactory = () => mockProvider.Object, ChatOptions = new() { Tools = [AIFunctionFactory.Create(() => { }, "base function")] } });
+
+ // Act
+ await agent.RunAsync(requestMessages);
+
+ // Assert
+ // Should contain: base instructions, context message, user message, base function, context function
+ Assert.Equal(2, capturedMessages.Count);
+ Assert.Equal("base instructions\ncontext provider instructions", capturedInstructions);
+ Assert.Equal("context provider message", capturedMessages[0].Text);
+ Assert.Equal(ChatRole.System, capturedMessages[0].Role);
+ Assert.Equal("user message", capturedMessages[1].Text);
+ Assert.Equal(ChatRole.User, capturedMessages[1].Role);
+ Assert.Equal(2, capturedTools.Count);
+ Assert.Contains(capturedTools, t => t.Name == "base function");
+ Assert.Contains(capturedTools, t => t.Name == "context provider function");
+ mockProvider.Verify(p => p.InvokingAsync(It.IsAny(), It.IsAny()), Times.Once);
+ mockProvider.Verify(p => p.InvokedAsync(It.Is(x => x.RequestMessages == requestMessages && x.ResponseMessages == responseMessages && x.InvokeException == null), It.IsAny()), Times.Once);
+ }
+
+ ///
+ /// Verify that RunAsync invokes any provided AIContextProvider when the downstream GetResponse call fails.
+ ///
+ [Fact]
+ public async Task RunAsyncInvokesAIContextProviderWhenGetResponseFailsAsync()
+ {
+ // Arrange
+ ChatMessage[] requestMessages = [new(ChatRole.User, "user message")];
+ ChatMessage[] responseMessages = [new(ChatRole.Assistant, "response")];
+ Mock mockService = new();
+ mockService
+ .Setup(s => s.GetResponseAsync(
+ It.IsAny>(),
+ It.IsAny(),
+ It.IsAny()))
+ .Throws(new InvalidOperationException("downstream failure"));
+
+ var mockProvider = new Mock();
+ mockProvider
+ .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny()))
+ .ReturnsAsync(new AIContext());
+ mockProvider
+ .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny()))
+ .Returns(new ValueTask());
+
+ ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "base instructions", AIContextProviderFactory = () => mockProvider.Object, ChatOptions = new() { Tools = [AIFunctionFactory.Create(() => { }, "base function")] } });
+
+ // Act
+ await Assert.ThrowsAsync(() => agent.RunAsync(requestMessages));
+
+ // Assert
+ mockProvider.Verify(p => p.InvokingAsync(It.IsAny(), It.IsAny()), Times.Once);
+ mockProvider.Verify(p => p.InvokedAsync(It.Is(x => x.RequestMessages == requestMessages && x.ResponseMessages == null && x.InvokeException is InvalidOperationException), It.IsAny()), Times.Once);
+ }
+
+ ///
+ /// Verify that RunAsync invokes any provided AIContextProvider and succeeds even when the AIContext is empty.
+ ///
+ [Fact]
+ public async Task RunAsyncInvokesAIContextProviderAndSucceedsWithEmptyAIContextAsync()
+ {
+ // Arrange
+ Mock mockService = new();
+ List capturedMessages = [];
+ string capturedInstructions = string.Empty;
+ List capturedTools = [];
+ mockService
+ .Setup(s => s.GetResponseAsync(
+ It.IsAny>(),
+ It.IsAny(),
+ It.IsAny()))
+ .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) =>
+ {
+ capturedMessages.AddRange(msgs);
+ capturedInstructions = opts.Instructions ?? string.Empty;
+ if (opts.Tools != null)
+ {
+ capturedTools.AddRange(opts.Tools);
+ }
+ })
+ .ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")]));
+
+ var mockProvider = new Mock();
+ mockProvider
+ .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny()))
+ .ReturnsAsync(new AIContext());
+
+ ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "base instructions", AIContextProviderFactory = () => mockProvider.Object, ChatOptions = new() { Tools = [AIFunctionFactory.Create(() => { }, "base function")] } });
+
+ // Act
+ await agent.RunAsync([new(ChatRole.User, "user message")]);
+
+ // Assert
+ // Should contain: base instructions, user message, base function
+ Assert.Single(capturedMessages);
+ Assert.Equal("base instructions", capturedInstructions);
+ Assert.Equal("user message", capturedMessages[0].Text);
+ Assert.Equal(ChatRole.User, capturedMessages[0].Role);
+ Assert.Single(capturedTools);
+ Assert.Contains(capturedTools, t => t.Name == "base function");
+ mockProvider.Verify(p => p.InvokingAsync(It.IsAny(), It.IsAny()), Times.Once);
+ }
+
+ #endregion
+
#region Property Override Tests
///
@@ -1524,6 +1700,61 @@ public class ChatClientAgentTests
#endregion
+ #region GetNewThread Tests
+
+ [Fact]
+ public void GetNewThreadUsesChatMessageStoreFactoryIfProvided()
+ {
+ // Arrange
+ var mockChatClient = new Mock();
+ var mockStore = new Mock();
+ var factoryCalled = false;
+
+ var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions
+ {
+ Instructions = "Test instructions",
+ ChatMessageStoreFactory = () =>
+ {
+ factoryCalled = true;
+ return mockStore.Object;
+ }
+ });
+
+ // Act
+ var thread = agent.GetNewThread();
+
+ // Assert
+ Assert.True(factoryCalled, "ChatMessageStoreFactory was not called.");
+ Assert.Same(mockStore.Object, thread.MessageStore);
+ }
+
+ [Fact]
+ public void GetNewThreadUsesAIContextProviderFactoryIfProvided()
+ {
+ // Arrange
+ var mockChatClient = new Mock();
+ var mockContextProvider = new Mock();
+ var factoryCalled = false;
+ var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions
+ {
+ Instructions = "Test instructions",
+ AIContextProviderFactory = () =>
+ {
+ factoryCalled = true;
+ return mockContextProvider.Object;
+ }
+ });
+
+ // Act
+ var thread = agent.GetNewThread();
+
+ // Assert
+ Assert.True(factoryCalled, "AIContextProviderFactory was not called.");
+ Assert.Same(mockContextProvider.Object, thread.AIContextProvider);
+ }
+
+ #endregion
+
private static async IAsyncEnumerable ToAsyncEnumerableAsync(IEnumerable values)
{
await Task.Yield();
diff --git a/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/TestJsonSerializerContext.cs b/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/TestJsonSerializerContext.cs
new file mode 100644
index 0000000000..d3d1f00118
--- /dev/null
+++ b/dotnet/tests/Microsoft.Extensions.AI.Agents.UnitTests/TestJsonSerializerContext.cs
@@ -0,0 +1,14 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
+namespace Microsoft.Extensions.AI.Agents.UnitTests;
+
+[JsonSourceGenerationOptions(
+ PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase,
+ DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
+ UseStringEnumConverter = true)]
+[JsonSerializable(typeof(JsonElement))]
+[JsonSerializable(typeof(string))]
+internal sealed partial class TestJsonSerializerContext : JsonSerializerContext;