.NET: [BREAKING] Add consistent message filtering to all providers. (#3851)

* Add consistent message filtering to all providers.

* Remove old chat history filtering classes

* Fix merge issues

* Fix unit test

* Enforce non-nullable property

* Fix merging bug and make troubleshooting source info easier by adding tostring implementation
This commit is contained in:
westey
2026-02-12 10:50:13 +00:00
committed by GitHub
Unverified
parent c99df98547
commit de82ffd40a
26 changed files with 918 additions and 502 deletions
@@ -63,11 +63,15 @@ AIAgent agent = azureOpenAIClient
{
ChatOptions = new() { Instructions = "You are a helpful support specialist for Contoso Outdoors. Answer questions using the provided context and cite the source document when available." },
AIContextProvider = new TextSearchProvider(SearchAdapter, textSearchOptions),
// Since we are using ChatCompletion which stores chat history locally, we can also add a message removal policy
// Since we are using ChatCompletion which stores chat history locally, we can also add a message filter
// that removes messages produced by the TextSearchProvider before they are added to the chat history, so that
// we don't bloat chat history with all the search result messages.
ChatHistoryProvider = new InMemoryChatHistoryProvider()
.WithAIContextProviderMessageRemoval(),
// By default the chat history provider will store all messages, except for those that came from chat history in the first place.
// We also want to maintain that exclusion here.
ChatHistoryProvider = new InMemoryChatHistoryProvider(new InMemoryChatHistoryProviderOptions
{
StorageInputMessageFilter = messages => messages.Where(m => m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.AIContextProvider && m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.ChatHistory)
}),
});
AgentSession session = await agent.CreateSessionAsync();
@@ -144,9 +144,11 @@ namespace SampleApp
var collection = this._vectorStore.GetCollection<string, ChatHistoryItem>("ChatHistory");
await collection.EnsureCollectionExistsAsync(cancellationToken);
// Add both request and response messages to the store
// Add both request and response messages to the store, excluding messages that came from chat history.
// Optionally messages produced by the AIContextProvider can also be persisted (not shown).
var allNewMessages = context.RequestMessages.Concat(context.ResponseMessages ?? []);
var allNewMessages = context.RequestMessages
.Where(m => m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.ChatHistory)
.Concat(context.ResponseMessages ?? []);
await collection.UpsertAsync(allNewMessages.Select(x => new ChatHistoryItem()
{
@@ -44,10 +44,14 @@ AIAgent agent = new AzureOpenAIClient(
You manage a TODO list for the user. When the user has completed one of the tasks it can be removed from the TODO list. Only provide the list of TODO items if asked.
You remind users of upcoming calendar events when the user interacts with you.
""" },
ChatHistoryProvider = new InMemoryChatHistoryProvider()
// Use WithAIContextProviderMessageRemoval, so that we don't store the messages from the AI context provider in the chat history.
ChatHistoryProvider = new InMemoryChatHistoryProvider(new InMemoryChatHistoryProviderOptions
{
// Use StorageInputMessageFilter to provide a custom filter for messages stored in chat history.
// By default the chat history provider will store all messages, except for those that came from chat history in the first place.
// In this case, we want to also exclude messages that came from AI context providers.
// You may want to store these messages, depending on their content and your requirements.
.WithAIContextProviderMessageRemoval(),
StorageInputMessageFilter = messages => messages.Where(m => m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.AIContextProvider && m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.ChatHistory)
}),
// Add an AI context provider that maintains a todo list for the agent and one that provides upcoming calendar entries.
// Wrap these in an AI context provider that aggregates the other two.
AIContextProvider = new AggregatingAIContextProvider([
@@ -63,6 +63,17 @@ public readonly struct AgentRequestMessageSourceAttribution : IEquatable<AgentRe
return obj is AgentRequestMessageSourceAttribution other && this.Equals(other);
}
/// <summary>
/// Returns a string representation of the current instance.
/// </summary>
/// <returns>A string containing the source type and source identifier.</returns>
public override string ToString()
{
return this.SourceId is null
? $"{this.SourceType}"
: $"{this.SourceType}:{this.SourceId}";
}
/// <summary>
/// Returns a hash code for the current instance.
/// </summary>
@@ -58,6 +58,12 @@ public readonly struct AgentRequestMessageSourceType : IEquatable<AgentRequestMe
/// <returns><see langword="true"/> if <paramref name="obj"/> is a <see cref="AgentRequestMessageSourceType"/> and its value is the same as this instance; otherwise, <see langword="false"/>.</returns>
public override bool Equals(object? obj) => obj is AgentRequestMessageSourceType other && this.Equals(other);
/// <summary>
/// Returns the string representation of this instance.
/// </summary>
/// <returns>The string value representing the source of the agent request message.</returns>
public override string ToString() => this.Value;
/// <summary>
/// Returns the hash code for this instance.
/// </summary>
@@ -1,52 +0,0 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Extensions.AI;
namespace Microsoft.Agents.AI;
/// <summary>
/// Contains extension methods for the <see cref="ChatHistoryProvider"/> class.
/// </summary>
public static class ChatHistoryProviderExtensions
{
/// <summary>
/// Adds message filtering to an existing <see cref="ChatHistoryProvider"/>, so that messages passed to the <see cref="ChatHistoryProvider"/> and messages
/// provided by the <see cref="ChatHistoryProvider"/> can be filtered, updated or replaced.
/// </summary>
/// <param name="provider">The <see cref="ChatHistoryProvider"/> to add the message filter to.</param>
/// <param name="invokingMessagesFilter">An optional filter function to apply to messages produced by the <see cref="ChatHistoryProvider"/>. If null, no filter is applied at this
/// stage.</param>
/// <param name="invokedMessagesFilter">An optional filter function to apply to the invoked context messages before they are passed to the <see cref="ChatHistoryProvider"/>. If null, no
/// filter is applied at this stage.</param>
/// <returns>The <see cref="ChatHistoryProvider"/> with filtering applied.</returns>
public static ChatHistoryProvider WithMessageFilters(
this ChatHistoryProvider provider,
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? invokingMessagesFilter = null,
Func<ChatHistoryProvider.InvokedContext, ChatHistoryProvider.InvokedContext>? invokedMessagesFilter = null)
{
return new ChatHistoryProviderMessageFilter(
innerProvider: provider,
invokingMessagesFilter: invokingMessagesFilter,
invokedMessagesFilter: invokedMessagesFilter);
}
/// <summary>
/// Decorates the provided <see cref="ChatHistoryProvider"/> so that it does not add
/// messages with <see cref="AgentRequestMessageSourceType.AIContextProvider"/> to chat history.
/// </summary>
/// <param name="provider">The <see cref="ChatHistoryProvider"/> to add the message filter to.</param>
/// <returns>A new <see cref="ChatHistoryProvider"/> instance that filters out <see cref="AIContextProvider"/> messages so they do not get added.</returns>
public static ChatHistoryProvider WithAIContextProviderMessageRemoval(this ChatHistoryProvider provider)
{
return new ChatHistoryProviderMessageFilter(
innerProvider: provider,
invokedMessagesFilter: (ctx) =>
{
ctx.RequestMessages = ctx.RequestMessages.Where(x => x.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.AIContextProvider);
return ctx;
});
}
}
@@ -1,70 +0,0 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.Shared.Diagnostics;
namespace Microsoft.Agents.AI;
/// <summary>
/// A <see cref="ChatHistoryProvider"/> decorator that allows filtering the messages
/// passed into and out of an inner <see cref="ChatHistoryProvider"/>.
/// </summary>
public sealed class ChatHistoryProviderMessageFilter : ChatHistoryProvider
{
private readonly ChatHistoryProvider _innerProvider;
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? _invokingMessagesFilter;
private readonly Func<InvokedContext, InvokedContext>? _invokedMessagesFilter;
/// <summary>
/// Initializes a new instance of the <see cref="ChatHistoryProviderMessageFilter"/> class.
/// </summary>
/// <remarks>Use this constructor to customize how messages are filtered before and after invocation by
/// providing appropriate filter functions. If no filters are provided, the <see cref="ChatHistoryProvider"/> operates without
/// additional filtering.</remarks>
/// <param name="innerProvider">The underlying <see cref="ChatHistoryProvider"/> to be wrapped. Cannot be null.</param>
/// <param name="invokingMessagesFilter">An optional filter function to apply to messages provided by the <see cref="ChatHistoryProvider"/>
/// before they are used by the agent. If null, no filter is applied at this stage.</param>
/// <param name="invokedMessagesFilter">An optional filter function to apply to the invocation context after messages have been produced. If null, no
/// filter is applied at this stage.</param>
/// <exception cref="ArgumentNullException">Thrown if <paramref name="innerProvider"/> is null.</exception>
public ChatHistoryProviderMessageFilter(
ChatHistoryProvider innerProvider,
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? invokingMessagesFilter = null,
Func<InvokedContext, InvokedContext>? invokedMessagesFilter = null)
{
this._innerProvider = Throw.IfNull(innerProvider);
if (invokingMessagesFilter == null && invokedMessagesFilter == null)
{
throw new ArgumentException("At least one filter function, invokingMessagesFilter or invokedMessagesFilter, must be provided.");
}
this._invokingMessagesFilter = invokingMessagesFilter;
this._invokedMessagesFilter = invokedMessagesFilter;
}
/// <inheritdoc />
public override string StateKey => this._innerProvider.StateKey;
/// <inheritdoc />
protected override async ValueTask<IEnumerable<ChatMessage>> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
var messages = await this._innerProvider.InvokingAsync(context, cancellationToken).ConfigureAwait(false);
return this._invokingMessagesFilter != null ? this._invokingMessagesFilter(messages) : messages;
}
/// <inheritdoc />
protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default)
{
if (this._invokedMessagesFilter != null)
{
context = this._invokedMessagesFilter(context);
}
return this._innerProvider.InvokedAsync(context, cancellationToken);
}
}
@@ -27,9 +27,14 @@ namespace Microsoft.Agents.AI;
/// </remarks>
public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider
{
private static IEnumerable<ChatMessage> DefaultExcludeChatHistoryFilter(IEnumerable<ChatMessage> messages)
=> messages.Where(m => m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.ChatHistory);
private readonly string _stateKey;
private readonly Func<AgentSession?, State> _stateInitializer;
private readonly JsonSerializerOptions _jsonSerializerOptions;
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> _storageInputMessageFilter;
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? _retrievalOutputMessageFilter;
/// <summary>
/// Initializes a new instance of the <see cref="InMemoryChatHistoryProvider"/> class.
@@ -45,6 +50,8 @@ public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider
this.ReducerTriggerEvent = options?.ReducerTriggerEvent ?? InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.BeforeMessagesRetrieval;
this._stateKey = options?.StateKey ?? base.StateKey;
this._jsonSerializerOptions = options?.JsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions;
this._storageInputMessageFilter = options?.StorageInputMessageFilter ?? DefaultExcludeChatHistoryFilter;
this._retrievalOutputMessageFilter = options?.RetrievalOutputMessageFilter;
}
/// <inheritdoc />
@@ -115,7 +122,12 @@ public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider
state.Messages = (await this.ChatReducer.ReduceAsync(state.Messages, cancellationToken).ConfigureAwait(false)).ToList();
}
return state.Messages
IEnumerable<ChatMessage> output = state.Messages;
if (this._retrievalOutputMessageFilter is not null)
{
output = this._retrievalOutputMessageFilter(output);
}
return output
.Select(message => message.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, this.GetType().FullName!))
.Concat(context.RequestMessages);
}
@@ -133,7 +145,7 @@ public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider
var state = this.GetOrInitializeState(context.Session);
// Add request and response messages to the provider
var allNewMessages = context.RequestMessages.Concat(context.ResponseMessages ?? []);
var allNewMessages = this._storageInputMessageFilter(context.RequestMessages).Concat(context.ResponseMessages ?? []);
state.Messages.AddRange(allNewMessages);
if (this.ReducerTriggerEvent is InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.AfterMessageAdded && this.ChatReducer is not null)
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Text.Json;
using Microsoft.Extensions.AI;
@@ -47,6 +48,31 @@ public sealed class InMemoryChatHistoryProviderOptions
/// </summary>
public JsonSerializerOptions? JsonSerializerOptions { get; set; }
/// <summary>
/// Gets or sets an optional filter function applied to request messages before they are added to storage
/// during <see cref="ChatHistoryProvider.InvokedAsync"/>.
/// </summary>
/// <value>
/// When <see langword="null"/>, the provider defaults to excluding messages with
/// <see cref="AgentRequestMessageSourceType.ChatHistory"/> source type to avoid
/// storing messages that came from chat history in the first place.
/// Depending on your requirements, you could provide a different filter, that also excludes
/// messages from e.g. AI context providers.
/// </value>
public Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? StorageInputMessageFilter { get; set; }
/// <summary>
/// Gets or sets an optional filter function applied to messages produced by this provider
/// during <see cref="ChatHistoryProvider.InvokingAsync"/>.
/// </summary>
/// <remarks>
/// This filter is only applied to the messages that the provider itself produces (from its internal storage).
/// </remarks>
/// <value>
/// When <see langword="null"/>, no filtering is applied to the output messages.
/// </value>
public Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? RetrievalOutputMessageFilter { get; set; }
/// <summary>
/// Defines the events that can trigger a reducer in the <see cref="InMemoryChatHistoryProvider"/>.
/// </summary>
@@ -21,6 +21,9 @@ namespace Microsoft.Agents.AI;
[RequiresDynamicCode("The CosmosChatHistoryProvider uses JSON serialization which is incompatible with NativeAOT.")]
public sealed class CosmosChatHistoryProvider : ChatHistoryProvider, IDisposable
{
private static IEnumerable<ChatMessage> DefaultExcludeChatHistoryFilter(IEnumerable<ChatMessage> messages)
=> messages.Where(m => m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.ChatHistory);
private readonly CosmosClient _cosmosClient;
private readonly Container _container;
private readonly bool _ownsClient;
@@ -81,6 +84,25 @@ public sealed class CosmosChatHistoryProvider : ChatHistoryProvider, IDisposable
/// </summary>
public string ContainerId { get; init; }
/// <summary>
/// A filter function applied to request messages before they are stored
/// during <see cref="ChatHistoryProvider.InvokedAsync"/>. The default filter excludes messages with the
/// <see cref="AgentRequestMessageSourceType.ChatHistory"/> source type.
/// </summary>
public Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> StorageInputMessageFilter { get; set { field = Throw.IfNull(value); } } = DefaultExcludeChatHistoryFilter;
/// <summary>
/// Gets or sets an optional filter function applied to messages produced by this provider
/// during <see cref="ChatHistoryProvider.InvokingAsync"/>.
/// </summary>
/// <remarks>
/// This filter is only applied to the messages that the provider itself produces (from its internal storage).
/// </remarks>
/// <value>
/// When <see langword="null"/>, no filtering is applied to the output messages.
/// </value>
public Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? RetrievalOutputMessageFilter { get; set; }
/// <summary>
/// Initializes a new instance of the <see cref="CosmosChatHistoryProvider"/> class.
/// </summary>
@@ -257,7 +279,7 @@ public sealed class CosmosChatHistoryProvider : ChatHistoryProvider, IDisposable
messages.Reverse();
}
return messages
return (this.RetrievalOutputMessageFilter is not null ? this.RetrievalOutputMessageFilter(messages) : messages)
.Select(message => message.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, this.GetType().FullName!))
.Concat(context.RequestMessages);
}
@@ -281,7 +303,7 @@ public sealed class CosmosChatHistoryProvider : ChatHistoryProvider, IDisposable
#pragma warning restore CA1513
var state = this.GetOrInitializeState(context.Session);
var messageList = context.RequestMessages.Concat(context.ResponseMessages ?? []).ToList();
var messageList = this.StorageInputMessageFilter(context.RequestMessages).Concat(context.ResponseMessages ?? []).ToList();
if (messageList.Count == 0)
{
return;
@@ -26,10 +26,15 @@ public sealed class Mem0Provider : AIContextProvider
{
private const string DefaultContextPrompt = "## Memories\nConsider the following memories when answering user questions:";
private static IEnumerable<ChatMessage> DefaultExternalOnlyFilter(IEnumerable<ChatMessage> messages)
=> messages.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External);
private readonly string _contextPrompt;
private readonly bool _enableSensitiveTelemetryData;
private readonly string _stateKey;
private readonly Func<AgentSession?, State> _stateInitializer;
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> _searchInputMessageFilter;
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> _storageInputMessageFilter;
private readonly Mem0Client _client;
private readonly ILogger<Mem0Provider>? _logger;
@@ -67,6 +72,8 @@ public sealed class Mem0Provider : AIContextProvider
this._contextPrompt = options?.ContextPrompt ?? DefaultContextPrompt;
this._enableSensitiveTelemetryData = options?.EnableSensitiveTelemetryData ?? false;
this._stateKey = options?.StateKey ?? base.StateKey;
this._searchInputMessageFilter = options?.SearchInputMessageFilter ?? DefaultExternalOnlyFilter;
this._storageInputMessageFilter = options?.StorageInputMessageFilter ?? DefaultExternalOnlyFilter;
}
/// <inheritdoc />
@@ -114,8 +121,7 @@ public sealed class Mem0Provider : AIContextProvider
string queryText = string.Join(
Environment.NewLine,
(inputContext.Messages ?? [])
.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External)
this._searchInputMessageFilter(inputContext.Messages ?? [])
.Where(m => !string.IsNullOrWhiteSpace(m.Text))
.Select(m => m.Text));
@@ -205,8 +211,7 @@ public sealed class Mem0Provider : AIContextProvider
// Persist request and response messages after invocation.
await this.PersistMessagesAsync(
storageScope,
context.RequestMessages
.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External)
this._storageInputMessageFilter(context.RequestMessages)
.Concat(context.ResponseMessages ?? []),
cancellationToken).ConfigureAwait(false);
}
@@ -1,5 +1,9 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using Microsoft.Extensions.AI;
namespace Microsoft.Agents.AI.Mem0;
/// <summary>
@@ -24,4 +28,24 @@ public sealed class Mem0ProviderOptions
/// </summary>
/// <value>Defaults to the provider's type name.</value>
public string? StateKey { get; set; }
/// <summary>
/// Gets or sets an optional filter function applied to request messages when building the search text to use when
/// searching for relevant memories during <see cref="AIContextProvider.InvokingAsync"/>.
/// </summary>
/// <value>
/// When <see langword="null"/>, the provider defaults to including only
/// <see cref="AgentRequestMessageSourceType.External"/> messages.
/// </value>
public Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? SearchInputMessageFilter { get; set; }
/// <summary>
/// Gets or sets an optional filter function applied to request messages when determining which messages to
/// extract memories from during <see cref="AIContextProvider.InvokedAsync"/>.
/// </summary>
/// <value>
/// When <see langword="null"/>, the provider defaults to including only
/// <see cref="AgentRequestMessageSourceType.External"/> messages.
/// </value>
public Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? StorageInputMessageFilter { get; set; }
}
@@ -64,7 +64,9 @@ internal sealed class WorkflowChatHistoryProvider : ChatHistoryProvider
return default;
}
var allNewMessages = context.RequestMessages.Concat(context.ResponseMessages ?? []);
var allNewMessages = context.RequestMessages
.Where(m => m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.ChatHistory)
.Concat(context.ResponseMessages ?? []);
this.GetOrInitializeState(context.Session).Messages.AddRange(allNewMessages);
return default;
@@ -41,6 +41,9 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable
private const string DefaultFunctionToolName = "Search";
private const string DefaultFunctionToolDescription = "Allows searching for related previous chat history to help answer the user question.";
private static IEnumerable<ChatMessage> DefaultExternalOnlyFilter(IEnumerable<ChatMessage> messages)
=> messages.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External);
#pragma warning disable CA2213 // VectorStore is not owned by this class - caller is responsible for disposal
private readonly VectorStore _vectorStore;
#pragma warning restore CA2213
@@ -54,6 +57,8 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable
private readonly ILogger<ChatHistoryMemoryProvider>? _logger;
private readonly string _stateKey;
private readonly Func<AgentSession?, State> _stateInitializer;
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> _searchInputMessageFilter;
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> _storageInputMessageFilter;
private bool _collectionInitialized;
private readonly SemaphoreSlim _initializationLock = new(1, 1);
@@ -89,6 +94,8 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable
this._logger = loggerFactory?.CreateLogger<ChatHistoryMemoryProvider>();
this._toolName = options.FunctionToolName ?? DefaultFunctionToolName;
this._toolDescription = options.FunctionToolDescription ?? DefaultFunctionToolDescription;
this._searchInputMessageFilter = options.SearchInputMessageFilter ?? DefaultExternalOnlyFilter;
this._storageInputMessageFilter = options.StorageInputMessageFilter ?? DefaultExternalOnlyFilter;
// Create a definition so that we can use the dimensions provided at runtime.
var definition = new VectorStoreCollectionDefinition
@@ -171,8 +178,8 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable
try
{
// Get the text from the current request messages
var requestText = string.Join("\n", (inputContext.Messages ?? [])
.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External)
var requestText = string.Join("\n",
this._searchInputMessageFilter(inputContext.Messages ?? [])
.Where(m => m != null && !string.IsNullOrWhiteSpace(m.Text))
.Select(m => m.Text));
@@ -238,8 +245,7 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable
// Ensure the collection is initialized
var collection = await this.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false);
List<Dictionary<string, object?>> itemsToStore = context.RequestMessages
.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External)
List<Dictionary<string, object?>> itemsToStore = this._storageInputMessageFilter(context.RequestMessages)
.Concat(context.ResponseMessages ?? [])
.Select(message => new Dictionary<string, object?>
{
@@ -1,5 +1,9 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using Microsoft.Extensions.AI;
namespace Microsoft.Agents.AI;
/// <summary>
@@ -53,6 +57,26 @@ public sealed class ChatHistoryMemoryProviderOptions
/// </value>
public string? StateKey { get; set; }
/// <summary>
/// Gets or sets an optional filter function applied to request messages when constructing the search text to use
/// to search for relevant chat history during <see cref="AIContextProvider.InvokingAsync"/>.
/// </summary>
/// <value>
/// When <see langword="null"/>, the provider defaults to including only
/// <see cref="AgentRequestMessageSourceType.External"/> messages.
/// </value>
public Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? SearchInputMessageFilter { get; set; }
/// <summary>
/// Gets or sets an optional filter function applied to request messages when storing recent chat history
/// during <see cref="AIContextProvider.InvokedAsync"/>.
/// </summary>
/// <value>
/// When <see langword="null"/>, the provider defaults to including only
/// <see cref="AgentRequestMessageSourceType.External"/> messages.
/// </value>
public Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? StorageInputMessageFilter { get; set; }
/// <summary>
/// Behavior choices for the provider.
/// </summary>
@@ -39,6 +39,9 @@ public sealed class TextSearchProvider : AIContextProvider
private const string DefaultContextPrompt = "## Additional Context\nConsider the following information from source documents when responding to the user:";
private const string DefaultCitationsPrompt = "Include citations to the source document with document name and link if document name and link is available.";
private static IEnumerable<ChatMessage> DefaultExternalOnlyFilter(IEnumerable<ChatMessage> messages)
=> messages.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External);
private readonly Func<string, CancellationToken, Task<IEnumerable<TextSearchResult>>> _searchAsync;
private readonly ILogger<TextSearchProvider>? _logger;
private readonly AITool[] _tools;
@@ -49,6 +52,8 @@ public sealed class TextSearchProvider : AIContextProvider
private readonly string _citationsPrompt;
private readonly string _stateKey;
private readonly Func<IList<TextSearchResult>, string>? _contextFormatter;
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> _searchInputMessageFilter;
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> _storageInputMessageFilter;
/// <summary>
/// Initializes a new instance of the <see cref="TextSearchProvider"/> class.
@@ -72,6 +77,8 @@ public sealed class TextSearchProvider : AIContextProvider
this._citationsPrompt = options?.CitationsPrompt ?? DefaultCitationsPrompt;
this._stateKey = options?.StateKey ?? base.StateKey;
this._contextFormatter = options?.ContextFormatter;
this._searchInputMessageFilter = options?.SearchInputMessageFilter ?? DefaultExternalOnlyFilter;
this._storageInputMessageFilter = options?.StorageInputMessageFilter ?? DefaultExternalOnlyFilter;
// Create the on-demand search tool (only used if behavior is OnDemandFunctionCalling)
this._tools =
@@ -108,8 +115,8 @@ public sealed class TextSearchProvider : AIContextProvider
// Aggregate text from memory + current request messages.
var sbInput = new StringBuilder();
var requestMessagesText = (inputContext.Messages ?? [])
.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External)
var requestMessagesText =
this._searchInputMessageFilter(inputContext.Messages ?? [])
.Where(x => !string.IsNullOrWhiteSpace(x?.Text)).Select(x => x.Text);
foreach (var messageText in recentMessagesText.Concat(requestMessagesText))
{
@@ -189,8 +196,7 @@ public sealed class TextSearchProvider : AIContextProvider
var recentMessagesText = context.Session.StateBag.GetValue<TextSearchProviderState>(this._stateKey, AgentJsonUtilities.DefaultOptions)?.RecentMessagesText
?? [];
var newMessagesText = context.RequestMessages
.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External)
var newMessagesText = this._storageInputMessageFilter(context.RequestMessages)
.Concat(context.ResponseMessages ?? [])
.Where(m =>
this._recentMessageRolesIncluded.Contains(m.Role) &&
@@ -68,6 +68,26 @@ public sealed class TextSearchProviderOptions
/// </value>
public string? StateKey { get; set; }
/// <summary>
/// Gets or sets an optional filter function applied to request messages when constructing the search input
/// text during <see cref="AIContextProvider.InvokingAsync"/>.
/// </summary>
/// <value>
/// When <see langword="null"/>, the provider defaults to including only
/// <see cref="AgentRequestMessageSourceType.External"/> messages.
/// </value>
public Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? SearchInputMessageFilter { get; set; }
/// <summary>
/// Gets or sets an optional filter function applied to request messages when updating the recent message
/// memory during <see cref="AIContextProvider.InvokedAsync"/>.
/// </summary>
/// <value>
/// When <see langword="null"/>, the provider defaults to including only
/// <see cref="AgentRequestMessageSourceType.External"/> messages.
/// </value>
public Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? StorageInputMessageFilter { get; set; }
/// <summary>
/// Gets or sets the list of <see cref="ChatRole"/> types to filter recent messages to
/// when deciding which recent messages to include when constructing the search input.
@@ -390,6 +390,49 @@ public sealed class AgentRequestMessageSourceAttributionTests
#endregion
#region ToString Tests
[Fact]
public void ToString_WithSourceId_ReturnsTypeColonId()
{
// Arrange
AgentRequestMessageSourceAttribution attribution = new(AgentRequestMessageSourceType.AIContextProvider, "MyProvider");
// Act
string result = attribution.ToString();
// Assert
Assert.Equal("AIContextProvider:MyProvider", result);
}
[Fact]
public void ToString_WithNullSourceId_ReturnsTypeOnly()
{
// Arrange
AgentRequestMessageSourceAttribution attribution = new(AgentRequestMessageSourceType.ChatHistory, null);
// Act
string result = attribution.ToString();
// Assert
Assert.Equal("ChatHistory", result);
}
[Fact]
public void ToString_Default_ReturnsExternalOnly()
{
// Arrange
AgentRequestMessageSourceAttribution attribution = default;
// Act
string result = attribution.ToString();
// Assert
Assert.Equal("External", result);
}
#endregion
#region Inequality Operator Tests
[Fact]
@@ -414,6 +414,46 @@ public sealed class AgentRequestMessageSourceTypeTests
#endregion
#region ToString Tests
[Fact]
public void ToString_ReturnsValue()
{
// Arrange
AgentRequestMessageSourceType source = new("CustomSource");
// Act
string result = source.ToString();
// Assert
Assert.Equal("CustomSource", result);
}
[Fact]
public void ToString_StaticExternal_ReturnsExternal()
{
// Arrange & Act
string result = AgentRequestMessageSourceType.External.ToString();
// Assert
Assert.Equal("External", result);
}
[Fact]
public void ToString_Default_ReturnsExternal()
{
// Arrange
AgentRequestMessageSourceType source = default;
// Act
string result = source.ToString();
// Assert
Assert.Equal("External", result);
}
#endregion
#region IEquatable Tests
[Fact]
@@ -1,141 +0,0 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Moq;
using Moq.Protected;
namespace Microsoft.Agents.AI.Abstractions.UnitTests;
/// <summary>
/// Contains tests for the <see cref="ChatHistoryProviderExtensions"/> class.
/// </summary>
public sealed class ChatHistoryProviderExtensionsTests
{
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
[Fact]
public void WithMessageFilters_ReturnsChatHistoryProviderMessageFilter()
{
// Arrange
Mock<ChatHistoryProvider> providerMock = new();
// Act
ChatHistoryProvider result = providerMock.Object.WithMessageFilters(
invokingMessagesFilter: msgs => msgs,
invokedMessagesFilter: ctx => ctx);
// Assert
Assert.IsType<ChatHistoryProviderMessageFilter>(result);
}
[Fact]
public async Task WithMessageFilters_InvokingFilter_IsAppliedAsync()
{
// Arrange
Mock<ChatHistoryProvider> providerMock = new();
List<ChatMessage> innerMessages = [new(ChatRole.User, "Hello"), new(ChatRole.Assistant, "Hi")];
ChatHistoryProvider.InvokingContext context = new(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]);
providerMock
.Protected()
.Setup<ValueTask<IEnumerable<ChatMessage>>>("InvokingCoreAsync", ItExpr.IsAny<ChatHistoryProvider.InvokingContext>(), ItExpr.IsAny<CancellationToken>())
.ReturnsAsync(innerMessages);
ChatHistoryProvider filtered = providerMock.Object.WithMessageFilters(
invokingMessagesFilter: msgs => msgs.Where(m => m.Role == ChatRole.User));
// Act
List<ChatMessage> result = (await filtered.InvokingAsync(context, CancellationToken.None)).ToList();
// Assert
Assert.Single(result);
Assert.Equal(ChatRole.User, result[0].Role);
}
[Fact]
public async Task WithMessageFilters_InvokedFilter_IsAppliedAsync()
{
// Arrange
Mock<ChatHistoryProvider> providerMock = new();
List<ChatMessage> requestMessages =
[
new(ChatRole.System, "System") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "TestSource") } } },
new(ChatRole.User, "Hello")
];
ChatHistoryProvider.InvokedContext context = new(s_mockAgent, s_mockSession, requestMessages)
{
ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Response")]
};
ChatHistoryProvider.InvokedContext? capturedContext = null;
providerMock
.Protected()
.Setup<ValueTask>("InvokedCoreAsync", ItExpr.IsAny<ChatHistoryProvider.InvokedContext>(), ItExpr.IsAny<CancellationToken>())
.Callback<ChatHistoryProvider.InvokedContext, CancellationToken>((ctx, _) => capturedContext = ctx)
.Returns(default(ValueTask));
ChatHistoryProvider filtered = providerMock.Object.WithMessageFilters(
invokedMessagesFilter: ctx =>
{
ctx.ResponseMessages = null;
return ctx;
});
// Act
await filtered.InvokedAsync(context, CancellationToken.None);
// Assert
Assert.NotNull(capturedContext);
Assert.Null(capturedContext.ResponseMessages);
}
[Fact]
public void WithAIContextProviderMessageRemoval_ReturnsChatHistoryProviderMessageFilter()
{
// Arrange
Mock<ChatHistoryProvider> providerMock = new();
// Act
ChatHistoryProvider result = providerMock.Object.WithAIContextProviderMessageRemoval();
// Assert
Assert.IsType<ChatHistoryProviderMessageFilter>(result);
}
[Fact]
public async Task WithAIContextProviderMessageRemoval_RemovesAIContextProviderMessagesAsync()
{
// Arrange
Mock<ChatHistoryProvider> providerMock = new();
List<ChatMessage> requestMessages =
[
new(ChatRole.System, "System") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "TestSource") } } },
new(ChatRole.User, "Hello"),
new(ChatRole.System, "Context") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "TestContextSource") } } }
];
ChatHistoryProvider.InvokedContext context = new(s_mockAgent, s_mockSession, requestMessages);
ChatHistoryProvider.InvokedContext? capturedContext = null;
providerMock
.Protected()
.Setup<ValueTask>("InvokedCoreAsync", ItExpr.IsAny<ChatHistoryProvider.InvokedContext>(), ItExpr.IsAny<CancellationToken>())
.Callback<ChatHistoryProvider.InvokedContext, CancellationToken>((ctx, _) => capturedContext = ctx)
.Returns(default(ValueTask));
ChatHistoryProvider filtered = providerMock.Object.WithAIContextProviderMessageRemoval();
// Act
await filtered.InvokedAsync(context, CancellationToken.None);
// Assert
Assert.NotNull(capturedContext);
Assert.Equal(2, capturedContext.RequestMessages.Count());
Assert.Contains("System", capturedContext.RequestMessages.Select(x => x.Text));
Assert.Contains("Hello", capturedContext.RequestMessages.Select(x => x.Text));
}
}
@@ -1,213 +0,0 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Moq;
using Moq.Protected;
namespace Microsoft.Agents.AI.Abstractions.UnitTests;
/// <summary>
/// Contains tests for the <see cref="ChatHistoryProviderMessageFilter"/> class.
/// </summary>
public sealed class ChatHistoryProviderMessageFilterTests
{
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
[Fact]
public void Constructor_WithNullInnerProvider_ThrowsArgumentNullException()
{
// Arrange, Act & Assert
Assert.Throws<ArgumentNullException>(() => new ChatHistoryProviderMessageFilter(null!));
}
[Fact]
public void Constructor_WithOnlyInnerProvider_Throws()
{
// Arrange
var innerProviderMock = new Mock<ChatHistoryProvider>();
// Act & Assert
Assert.Throws<ArgumentException>(() => new ChatHistoryProviderMessageFilter(innerProviderMock.Object));
}
[Fact]
public void Constructor_WithAllParameters_CreatesInstance()
{
// Arrange
var innerProviderMock = new Mock<ChatHistoryProvider>();
IEnumerable<ChatMessage> InvokingFilter(IEnumerable<ChatMessage> msgs) => msgs;
ChatHistoryProvider.InvokedContext InvokedFilter(ChatHistoryProvider.InvokedContext ctx) => ctx;
// Act
var filter = new ChatHistoryProviderMessageFilter(innerProviderMock.Object, InvokingFilter, InvokedFilter);
// Assert
Assert.NotNull(filter);
}
[Fact]
public void StateKey_DelegatesToInnerProvider()
{
// Arrange
var innerProviderMock = new Mock<ChatHistoryProvider>();
innerProviderMock.Setup(p => p.StateKey).Returns("inner-state-key");
// Act
var filter = new ChatHistoryProviderMessageFilter(innerProviderMock.Object, x => x);
// Assert
Assert.Equal("inner-state-key", filter.StateKey);
}
[Fact]
public async Task InvokingAsync_WithNoOpFilters_ReturnsInnerProviderMessagesAsync()
{
// Arrange
var innerProviderMock = new Mock<ChatHistoryProvider>();
var expectedMessages = new List<ChatMessage>
{
new(ChatRole.User, "Hello"),
new(ChatRole.Assistant, "Hi there!")
};
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]);
innerProviderMock
.Protected()
.Setup<ValueTask<IEnumerable<ChatMessage>>>("InvokingCoreAsync", ItExpr.IsAny<ChatHistoryProvider.InvokingContext>(), ItExpr.IsAny<CancellationToken>())
.ReturnsAsync(expectedMessages);
var filter = new ChatHistoryProviderMessageFilter(innerProviderMock.Object, x => x, x => x);
// Act
var result = (await filter.InvokingAsync(context, CancellationToken.None)).ToList();
// Assert
Assert.Equal(2, result.Count);
Assert.Equal("Hello", result[0].Text);
Assert.Equal("Hi there!", result[1].Text);
innerProviderMock
.Protected()
.Verify<ValueTask<IEnumerable<ChatMessage>>>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny<ChatHistoryProvider.InvokingContext>(), ItExpr.IsAny<CancellationToken>());
}
[Fact]
public async Task InvokingAsync_WithInvokingFilter_AppliesFilterAsync()
{
// Arrange
var innerProviderMock = new Mock<ChatHistoryProvider>();
var innerMessages = new List<ChatMessage>
{
new(ChatRole.User, "Hello"),
new(ChatRole.Assistant, "Hi there!"),
new(ChatRole.User, "How are you?")
};
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]);
innerProviderMock
.Protected()
.Setup<ValueTask<IEnumerable<ChatMessage>>>("InvokingCoreAsync", ItExpr.IsAny<ChatHistoryProvider.InvokingContext>(), ItExpr.IsAny<CancellationToken>())
.ReturnsAsync(innerMessages);
// Filter to only user messages
IEnumerable<ChatMessage> InvokingFilter(IEnumerable<ChatMessage> msgs) => msgs.Where(m => m.Role == ChatRole.User);
var filter = new ChatHistoryProviderMessageFilter(innerProviderMock.Object, InvokingFilter);
// Act
var result = (await filter.InvokingAsync(context, CancellationToken.None)).ToList();
// Assert
Assert.Equal(2, result.Count);
Assert.All(result, msg => Assert.Equal(ChatRole.User, msg.Role));
innerProviderMock
.Protected()
.Verify<ValueTask<IEnumerable<ChatMessage>>>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny<ChatHistoryProvider.InvokingContext>(), ItExpr.IsAny<CancellationToken>());
}
[Fact]
public async Task InvokingAsync_WithInvokingFilter_CanModifyMessagesAsync()
{
// Arrange
var innerProviderMock = new Mock<ChatHistoryProvider>();
var innerMessages = new List<ChatMessage>
{
new(ChatRole.User, "Hello"),
new(ChatRole.Assistant, "Hi there!")
};
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]);
innerProviderMock
.Protected()
.Setup<ValueTask<IEnumerable<ChatMessage>>>("InvokingCoreAsync", ItExpr.IsAny<ChatHistoryProvider.InvokingContext>(), ItExpr.IsAny<CancellationToken>())
.ReturnsAsync(innerMessages);
// Filter that transforms messages
IEnumerable<ChatMessage> InvokingFilter(IEnumerable<ChatMessage> msgs) =>
msgs.Select(m => new ChatMessage(m.Role, $"[FILTERED] {m.Text}"));
var filter = new ChatHistoryProviderMessageFilter(innerProviderMock.Object, InvokingFilter);
// Act
var result = (await filter.InvokingAsync(context, CancellationToken.None)).ToList();
// Assert
Assert.Equal(2, result.Count);
Assert.Equal("[FILTERED] Hello", result[0].Text);
Assert.Equal("[FILTERED] Hi there!", result[1].Text);
}
[Fact]
public async Task InvokedAsync_WithInvokedFilter_AppliesFilterAsync()
{
// Arrange
var innerProviderMock = new Mock<ChatHistoryProvider>();
List<ChatMessage> requestMessages =
[
new(ChatRole.System, "System") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "TestSource") } } },
new(ChatRole.User, "Hello"),
];
var responseMessages = new List<ChatMessage> { new(ChatRole.Assistant, "Response") };
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages)
{
ResponseMessages = responseMessages
};
ChatHistoryProvider.InvokedContext? capturedContext = null;
innerProviderMock
.Protected()
.Setup<ValueTask>("InvokedCoreAsync", ItExpr.IsAny<ChatHistoryProvider.InvokedContext>(), ItExpr.IsAny<CancellationToken>())
.Callback<ChatHistoryProvider.InvokedContext, CancellationToken>((ctx, ct) => capturedContext = ctx)
.Returns(default(ValueTask));
// Filter that modifies the context
ChatHistoryProvider.InvokedContext InvokedFilter(ChatHistoryProvider.InvokedContext ctx)
{
var modifiedRequestMessages = ctx.RequestMessages.Where(x => x.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External).Select(m => new ChatMessage(m.Role, $"[FILTERED] {m.Text}")).ToList();
return new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, modifiedRequestMessages)
{
ResponseMessages = ctx.ResponseMessages,
InvokeException = ctx.InvokeException
};
}
var filter = new ChatHistoryProviderMessageFilter(innerProviderMock.Object, invokedMessagesFilter: InvokedFilter);
// Act
await filter.InvokedAsync(context, CancellationToken.None);
// Assert
Assert.NotNull(capturedContext);
Assert.Single(capturedContext.RequestMessages);
Assert.Equal("[FILTERED] Hello", capturedContext.RequestMessages.First().Text);
innerProviderMock
.Protected()
.Verify<ValueTask>("InvokedCoreAsync", Times.Once(), ItExpr.IsAny<ChatHistoryProvider.InvokedContext>(), ItExpr.IsAny<CancellationToken>());
}
}
@@ -83,7 +83,7 @@ public class InMemoryChatHistoryProviderTests
};
var provider = new InMemoryChatHistoryProvider();
provider.SetMessages(session, [providerMessages[0]]);
provider.SetMessages(session, providerMessages);
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages)
{
ResponseMessages = responseMessages
@@ -397,6 +397,91 @@ public class InMemoryChatHistoryProviderTests
await Assert.ThrowsAsync<ArgumentNullException>(() => provider.InvokingAsync(null!, CancellationToken.None).AsTask());
}
[Fact]
public async Task InvokedAsync_DefaultFilter_ExcludesChatHistoryMessagesAsync()
{
// Arrange
var session = CreateMockSession();
var provider = new InMemoryChatHistoryProvider();
var requestMessages = new List<ChatMessage>
{
new(ChatRole.User, "External message"),
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
new(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } },
};
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages)
{
ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Response")]
};
// Act
await provider.InvokedAsync(context, CancellationToken.None);
// Assert - ChatHistory message excluded, AIContextProvider message included
var messages = provider.GetMessages(session);
Assert.Equal(3, messages.Count);
Assert.Equal("External message", messages[0].Text);
Assert.Equal("From context provider", messages[1].Text);
Assert.Equal("Response", messages[2].Text);
}
[Fact]
public async Task InvokedAsync_CustomFilter_OverridesDefaultAsync()
{
// Arrange
var session = CreateMockSession();
var provider = new InMemoryChatHistoryProvider(new InMemoryChatHistoryProviderOptions
{
StorageInputMessageFilter = messages => messages.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External)
});
var requestMessages = new List<ChatMessage>
{
new(ChatRole.User, "External message"),
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
new(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } },
};
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages)
{
ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Response")]
};
// Act
await provider.InvokedAsync(context, CancellationToken.None);
// Assert - Custom filter keeps only External messages (both ChatHistory and AIContextProvider excluded)
var messages = provider.GetMessages(session);
Assert.Equal(2, messages.Count);
Assert.Equal("External message", messages[0].Text);
Assert.Equal("Response", messages[1].Text);
}
[Fact]
public async Task InvokingAsync_OutputFilter_FiltersOutputMessagesAsync()
{
// Arrange
var session = CreateMockSession();
var provider = new InMemoryChatHistoryProvider(new InMemoryChatHistoryProviderOptions
{
RetrievalOutputMessageFilter = messages => messages.Where(m => m.Role == ChatRole.User)
});
provider.SetMessages(session,
[
new ChatMessage(ChatRole.User, "User message"),
new ChatMessage(ChatRole.Assistant, "Assistant message"),
new ChatMessage(ChatRole.System, "System message")
]);
// Act
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []);
var result = (await provider.InvokingAsync(context, CancellationToken.None)).ToList();
// Assert - Only user messages pass through the output filter
Assert.Single(result);
Assert.Equal("User message", result[0].Text);
}
public class TestAIContent(string testData) : AIContent
{
public string TestData => testData;
@@ -841,4 +841,128 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
}
#endregion
#region Message Filter Tests
[SkippableFact]
[Trait("Category", "CosmosDB")]
public async Task InvokedAsync_DefaultFilter_ExcludesChatHistoryMessagesFromStorageAsync()
{
// Arrange
this.SkipIfEmulatorNotAvailable();
var session = CreateMockSession();
var conversationId = Guid.NewGuid().ToString();
using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId,
_ => new CosmosChatHistoryProvider.State(conversationId));
var requestMessages = new[]
{
new ChatMessage(ChatRole.User, "External message"),
new ChatMessage(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
new ChatMessage(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } },
};
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages)
{
ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Response")]
};
// Act
await provider.InvokedAsync(context);
// Wait for eventual consistency
await Task.Delay(100);
// Assert - ChatHistory message excluded, External + AIContextProvider + Response stored
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []);
var messages = (await provider.InvokingAsync(invokingContext)).ToList();
Assert.Equal(3, messages.Count);
Assert.Equal("External message", messages[0].Text);
Assert.Equal("From context provider", messages[1].Text);
Assert.Equal("Response", messages[2].Text);
}
[SkippableFact]
[Trait("Category", "CosmosDB")]
public async Task InvokedAsync_CustomStorageInputFilter_OverridesDefaultAsync()
{
// Arrange
this.SkipIfEmulatorNotAvailable();
var session = CreateMockSession();
var conversationId = Guid.NewGuid().ToString();
using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId,
_ => new CosmosChatHistoryProvider.State(conversationId))
{
// Custom filter: only store External messages (also exclude AIContextProvider)
StorageInputMessageFilter = messages => messages.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External)
};
var requestMessages = new[]
{
new ChatMessage(ChatRole.User, "External message"),
new ChatMessage(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
new ChatMessage(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } },
};
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages)
{
ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Response")]
};
// Act
await provider.InvokedAsync(context);
// Wait for eventual consistency
await Task.Delay(100);
// Assert - Custom filter: only External + Response stored (both ChatHistory and AIContextProvider excluded)
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []);
var messages = (await provider.InvokingAsync(invokingContext)).ToList();
Assert.Equal(2, messages.Count);
Assert.Equal("External message", messages[0].Text);
Assert.Equal("Response", messages[1].Text);
}
[SkippableFact]
[Trait("Category", "CosmosDB")]
public async Task InvokingAsync_RetrievalOutputFilter_FiltersRetrievedMessagesAsync()
{
// Arrange
this.SkipIfEmulatorNotAvailable();
var session = CreateMockSession();
var conversationId = Guid.NewGuid().ToString();
using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId,
_ => new CosmosChatHistoryProvider.State(conversationId))
{
// Only return User messages when retrieving
RetrievalOutputMessageFilter = messages => messages.Where(m => m.Role == ChatRole.User)
};
var requestMessages = new[]
{
new ChatMessage(ChatRole.User, "User message"),
new ChatMessage(ChatRole.System, "System message"),
};
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages)
{
ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Assistant response")]
};
await provider.InvokedAsync(context);
// Wait for eventual consistency
await Task.Delay(100);
// Act
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []);
var messages = (await provider.InvokingAsync(invokingContext)).ToList();
// Assert - Only User messages returned (System and Assistant filtered by RetrievalOutputMessageFilter)
Assert.Single(messages);
Assert.Equal("User message", messages[0].Text);
Assert.Equal(ChatRole.User, messages[0].Role);
}
#endregion
}
@@ -436,6 +436,116 @@ public sealed class Mem0ProviderTests : IDisposable
Assert.NotNull(state);
}
[Fact]
public async Task InvokingAsync_DefaultFilter_ExcludesNonExternalMessagesFromSearchAsync()
{
// Arrange
this._handler.EnqueueJsonResponse("[]"); // Empty search results
var storageScope = new Mem0ProviderScope { ApplicationId = "app", AgentId = "agent", ThreadId = "session", UserId = "user" };
var mockSession = new TestAgentSession();
var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope));
var requestMessages = new List<ChatMessage>
{
new(ChatRole.User, "External message"),
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
new(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } },
};
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, mockSession, new AIContext { Messages = requestMessages });
// Act
await sut.InvokingAsync(invokingContext, CancellationToken.None);
// Assert - Search query should only contain the External message
var searchRequest = Assert.Single(this._handler.Requests, r => r.RequestMessage.Method == HttpMethod.Post);
using JsonDocument doc = JsonDocument.Parse(searchRequest.RequestBody);
Assert.Equal("External message", doc.RootElement.GetProperty("query").GetString());
}
[Fact]
public async Task InvokingAsync_CustomSearchInputFilter_OverridesDefaultAsync()
{
// Arrange
this._handler.EnqueueJsonResponse("[]"); // Empty search results
var storageScope = new Mem0ProviderScope { ApplicationId = "app", AgentId = "agent", ThreadId = "session", UserId = "user" };
var mockSession = new TestAgentSession();
var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), options: new Mem0ProviderOptions
{
SearchInputMessageFilter = messages => messages // No filtering
});
var requestMessages = new List<ChatMessage>
{
new(ChatRole.User, "External message"),
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
};
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, mockSession, new AIContext { Messages = requestMessages });
// Act
await sut.InvokingAsync(invokingContext, CancellationToken.None);
// Assert - Search query should contain all messages (custom identity filter)
var searchRequest = Assert.Single(this._handler.Requests, r => r.RequestMessage.Method == HttpMethod.Post);
using JsonDocument doc = JsonDocument.Parse(searchRequest.RequestBody);
var queryText = doc.RootElement.GetProperty("query").GetString();
Assert.Contains("External message", queryText);
Assert.Contains("From history", queryText);
}
[Fact]
public async Task InvokedAsync_DefaultFilter_ExcludesNonExternalMessagesFromStorageAsync()
{
// Arrange
this._handler.EnqueueEmptyOk(); // For the one message that should be stored
var storageScope = new Mem0ProviderScope { ApplicationId = "a", AgentId = "b", ThreadId = "c", UserId = "d" };
var mockSession = new TestAgentSession();
var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope));
var requestMessages = new List<ChatMessage>
{
new(ChatRole.User, "External message"),
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
};
// Act
await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, requestMessages));
// Assert - Only the External message should be persisted
var memoryPosts = this._handler.Requests.Where(r => r.RequestMessage.RequestUri!.AbsolutePath == "/v1/memories/" && r.RequestMessage.Method == HttpMethod.Post).ToList();
Assert.Single(memoryPosts);
Assert.Contains("External message", memoryPosts[0].RequestBody);
Assert.DoesNotContain(memoryPosts, r => ContainsOrdinal(r.RequestBody, "From history"));
}
[Fact]
public async Task InvokedAsync_CustomStorageInputFilter_OverridesDefaultAsync()
{
// Arrange
this._handler.EnqueueEmptyOk(); // For first CreateMemory
this._handler.EnqueueEmptyOk(); // For second CreateMemory
var storageScope = new Mem0ProviderScope { ApplicationId = "a", AgentId = "b", ThreadId = "c", UserId = "d" };
var mockSession = new TestAgentSession();
var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), options: new Mem0ProviderOptions
{
StorageInputMessageFilter = messages => messages // No filtering - store everything
});
var requestMessages = new List<ChatMessage>
{
new(ChatRole.User, "External message"),
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
};
// Act
await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, requestMessages));
// Assert - Both messages should be persisted (identity filter overrides default)
var memoryPosts = this._handler.Requests.Where(r => r.RequestMessage.RequestUri!.AbsolutePath == "/v1/memories/" && r.RequestMessage.Method == HttpMethod.Post).ToList();
Assert.Equal(2, memoryPosts.Count);
}
private static bool ContainsOrdinal(string source, string value) => source.IndexOf(value, StringComparison.Ordinal) >= 0;
public void Dispose()
@@ -356,6 +356,140 @@ public sealed class TextSearchProviderTests
Assert.Null(aiContext.Tools);
}
#region Message Filter Tests
[Fact]
public async Task InvokingAsync_DefaultFilter_ExcludesNonExternalMessagesFromSearchInputAsync()
{
// Arrange
string? capturedInput = null;
Task<IEnumerable<TextSearchProvider.TextSearchResult>> SearchDelegateAsync(string input, CancellationToken ct)
{
capturedInput = input;
return Task.FromResult<IEnumerable<TextSearchProvider.TextSearchResult>>([]);
}
var provider = new TextSearchProvider(SearchDelegateAsync);
var requestMessages = new List<ChatMessage>
{
new(ChatRole.User, "External message"),
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
new(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } },
};
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), new AIContext { Messages = requestMessages });
// Act
await provider.InvokingAsync(invokingContext, CancellationToken.None);
// Assert - Only external messages should be used for search input
Assert.Equal("External message", capturedInput);
}
[Fact]
public async Task InvokingAsync_CustomSearchInputFilter_OverridesDefaultAsync()
{
// Arrange
string? capturedInput = null;
Task<IEnumerable<TextSearchProvider.TextSearchResult>> SearchDelegateAsync(string input, CancellationToken ct)
{
capturedInput = input;
return Task.FromResult<IEnumerable<TextSearchProvider.TextSearchResult>>([]);
}
var provider = new TextSearchProvider(SearchDelegateAsync, new TextSearchProviderOptions
{
SearchInputMessageFilter = messages => messages.Where(m => m.Role == ChatRole.System)
});
var requestMessages = new List<ChatMessage>
{
new(ChatRole.User, "User message"),
new(ChatRole.System, "System message"),
};
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), new AIContext { Messages = requestMessages });
// Act
await provider.InvokingAsync(invokingContext, CancellationToken.None);
// Assert - Custom filter keeps only System messages
Assert.Equal("System message", capturedInput);
}
[Fact]
public async Task InvokedAsync_DefaultFilter_ExcludesNonExternalMessagesFromStorageAsync()
{
// Arrange
var options = new TextSearchProviderOptions
{
RecentMessageMemoryLimit = 10,
RecentMessageRolesIncluded = [ChatRole.User, ChatRole.System]
};
string? capturedInput = null;
Task<IEnumerable<TextSearchProvider.TextSearchResult>> SearchDelegateAsync(string input, CancellationToken ct)
{
capturedInput = input;
return Task.FromResult<IEnumerable<TextSearchProvider.TextSearchResult>>([]);
}
var provider = new TextSearchProvider(SearchDelegateAsync, options);
var session = new TestAgentSession();
var requestMessages = new List<ChatMessage>
{
new(ChatRole.User, "External message"),
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
new(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } },
};
// Store messages via InvokedAsync
await provider.InvokedAsync(new(s_mockAgent, session, requestMessages));
// Now invoke to read stored memory
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, session, new AIContext { Messages = [new ChatMessage(ChatRole.User, "Next")] });
await provider.InvokingAsync(invokingContext, CancellationToken.None);
// Assert - Only "External message" was stored in memory, so search input = "External message" + "Next"
Assert.Equal("External message\nNext", capturedInput);
}
[Fact]
public async Task InvokedAsync_CustomStorageInputFilter_OverridesDefaultAsync()
{
// Arrange
var options = new TextSearchProviderOptions
{
RecentMessageMemoryLimit = 10,
RecentMessageRolesIncluded = [ChatRole.User, ChatRole.System],
StorageInputMessageFilter = messages => messages // No filtering - store everything
};
string? capturedInput = null;
Task<IEnumerable<TextSearchProvider.TextSearchResult>> SearchDelegateAsync(string input, CancellationToken ct)
{
capturedInput = input;
return Task.FromResult<IEnumerable<TextSearchProvider.TextSearchResult>>([]);
}
var provider = new TextSearchProvider(SearchDelegateAsync, options);
var session = new TestAgentSession();
var requestMessages = new List<ChatMessage>
{
new(ChatRole.User, "External message"),
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
};
// Store messages via InvokedAsync
await provider.InvokedAsync(new(s_mockAgent, session, requestMessages));
// Now invoke to read stored memory
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, session, new AIContext { Messages = [new ChatMessage(ChatRole.User, "Next")] });
await provider.InvokingAsync(invokingContext, CancellationToken.None);
// Assert - Both messages stored (identity filter), so search input includes all + current
Assert.Equal("External message\nFrom history\nNext", capturedInput);
}
#endregion
#region Recent Message Memory Tests
[Fact]
@@ -539,6 +539,188 @@ public class ChatHistoryMemoryProviderTests
#endregion
#region Message Filter Tests
[Fact]
public async Task InvokingAsync_DefaultFilter_ExcludesNonExternalMessagesFromSearchAsync()
{
// Arrange
var providerOptions = new ChatHistoryMemoryProviderOptions
{
SearchTime = ChatHistoryMemoryProviderOptions.SearchBehavior.BeforeAIInvoke,
};
string? capturedQuery = null;
this._vectorStoreCollectionMock
.Setup(c => c.SearchAsync(
It.IsAny<string>(),
It.IsAny<int>(),
It.IsAny<VectorSearchOptions<Dictionary<string, object?>>>(),
It.IsAny<CancellationToken>()))
.Callback<string, int, VectorSearchOptions<Dictionary<string, object?>>, CancellationToken>((query, _, _, _) => capturedQuery = query)
.Returns(ToAsyncEnumerableAsync(new List<VectorSearchResult<Dictionary<string, object?>>>()));
var provider = new ChatHistoryMemoryProvider(
this._vectorStoreMock.Object,
TestCollectionName,
1,
_ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }),
options: providerOptions);
var requestMessages = new List<ChatMessage>
{
new(ChatRole.User, "External message"),
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
new(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } },
};
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), new AIContext { Messages = requestMessages });
// Act
await provider.InvokingAsync(invokingContext, CancellationToken.None);
// Assert - Only External message used for search query
Assert.Equal("External message", capturedQuery);
}
[Fact]
public async Task InvokingAsync_CustomSearchInputFilter_OverridesDefaultAsync()
{
// Arrange
var providerOptions = new ChatHistoryMemoryProviderOptions
{
SearchTime = ChatHistoryMemoryProviderOptions.SearchBehavior.BeforeAIInvoke,
SearchInputMessageFilter = messages => messages // No filtering
};
string? capturedQuery = null;
this._vectorStoreCollectionMock
.Setup(c => c.SearchAsync(
It.IsAny<string>(),
It.IsAny<int>(),
It.IsAny<VectorSearchOptions<Dictionary<string, object?>>>(),
It.IsAny<CancellationToken>()))
.Callback<string, int, VectorSearchOptions<Dictionary<string, object?>>, CancellationToken>((query, _, _, _) => capturedQuery = query)
.Returns(ToAsyncEnumerableAsync(new List<VectorSearchResult<Dictionary<string, object?>>>()));
var provider = new ChatHistoryMemoryProvider(
this._vectorStoreMock.Object,
TestCollectionName,
1,
_ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }),
options: providerOptions);
var requestMessages = new List<ChatMessage>
{
new(ChatRole.User, "External message"),
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
};
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), new AIContext { Messages = requestMessages });
// Act
await provider.InvokingAsync(invokingContext, CancellationToken.None);
// Assert - Both messages should be included in search query (identity filter)
Assert.NotNull(capturedQuery);
Assert.Contains("External message", capturedQuery);
Assert.Contains("From history", capturedQuery);
}
[Fact]
public async Task InvokedAsync_DefaultFilter_ExcludesNonExternalMessagesFromStorageAsync()
{
// Arrange
var stored = new List<Dictionary<string, object?>>();
this._vectorStoreCollectionMock
.Setup(c => c.UpsertAsync(It.IsAny<IEnumerable<Dictionary<string, object?>>>(), It.IsAny<CancellationToken>()))
.Callback<IEnumerable<Dictionary<string, object?>>, CancellationToken>((items, ct) =>
{
if (items != null)
{
stored.AddRange(items);
}
})
.Returns(Task.CompletedTask);
var provider = new ChatHistoryMemoryProvider(
this._vectorStoreMock.Object,
TestCollectionName,
1,
_ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }));
var requestMessages = new List<ChatMessage>
{
new(ChatRole.User, "External message"),
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
new(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } },
};
var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, new TestAgentSession(), requestMessages)
{
ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Response")]
};
// Act
await provider.InvokedAsync(invokedContext, CancellationToken.None);
// Assert - Only External message + response stored (ChatHistory and AIContextProvider excluded by default)
Assert.Equal(2, stored.Count);
Assert.Equal("External message", stored[0]["Content"]);
Assert.Equal("Response", stored[1]["Content"]);
}
[Fact]
public async Task InvokedAsync_CustomStorageInputFilter_OverridesDefaultAsync()
{
// Arrange
var stored = new List<Dictionary<string, object?>>();
this._vectorStoreCollectionMock
.Setup(c => c.UpsertAsync(It.IsAny<IEnumerable<Dictionary<string, object?>>>(), It.IsAny<CancellationToken>()))
.Callback<IEnumerable<Dictionary<string, object?>>, CancellationToken>((items, ct) =>
{
if (items != null)
{
stored.AddRange(items);
}
})
.Returns(Task.CompletedTask);
var provider = new ChatHistoryMemoryProvider(
this._vectorStoreMock.Object,
TestCollectionName,
1,
_ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }),
options: new ChatHistoryMemoryProviderOptions
{
StorageInputMessageFilter = messages => messages // No filtering - store everything
});
var requestMessages = new List<ChatMessage>
{
new(ChatRole.User, "External message"),
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
};
var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, new TestAgentSession(), requestMessages)
{
ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Response")]
};
// Act
await provider.InvokedAsync(invokedContext, CancellationToken.None);
// Assert - All messages stored (identity filter overrides default)
Assert.Equal(3, stored.Count);
Assert.Equal("External message", stored[0]["Content"]);
Assert.Equal("From history", stored[1]["Content"]);
Assert.Equal("Response", stored[2]["Content"]);
}
#endregion
private static async IAsyncEnumerable<T> ToAsyncEnumerableAsync<T>(IEnumerable<T> values)
{
await Task.Yield();