mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
.NET: [BREAKING] Add consistent message filtering to all providers. (#3851)
* Add consistent message filtering to all providers. * Remove old chat history filtering classes * Fix merge issues * Fix unit test * Enforce non-nullable property * Fix merging bug and make troubleshooting source info easier by adding tostring implementation
This commit is contained in:
committed by
GitHub
Unverified
parent
c99df98547
commit
de82ffd40a
+7
-3
@@ -63,11 +63,15 @@ AIAgent agent = azureOpenAIClient
|
||||
{
|
||||
ChatOptions = new() { Instructions = "You are a helpful support specialist for Contoso Outdoors. Answer questions using the provided context and cite the source document when available." },
|
||||
AIContextProvider = new TextSearchProvider(SearchAdapter, textSearchOptions),
|
||||
// Since we are using ChatCompletion which stores chat history locally, we can also add a message removal policy
|
||||
// Since we are using ChatCompletion which stores chat history locally, we can also add a message filter
|
||||
// that removes messages produced by the TextSearchProvider before they are added to the chat history, so that
|
||||
// we don't bloat chat history with all the search result messages.
|
||||
ChatHistoryProvider = new InMemoryChatHistoryProvider()
|
||||
.WithAIContextProviderMessageRemoval(),
|
||||
// By default the chat history provider will store all messages, except for those that came from chat history in the first place.
|
||||
// We also want to maintain that exclusion here.
|
||||
ChatHistoryProvider = new InMemoryChatHistoryProvider(new InMemoryChatHistoryProviderOptions
|
||||
{
|
||||
StorageInputMessageFilter = messages => messages.Where(m => m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.AIContextProvider && m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.ChatHistory)
|
||||
}),
|
||||
});
|
||||
|
||||
AgentSession session = await agent.CreateSessionAsync();
|
||||
|
||||
+4
-2
@@ -144,9 +144,11 @@ namespace SampleApp
|
||||
var collection = this._vectorStore.GetCollection<string, ChatHistoryItem>("ChatHistory");
|
||||
await collection.EnsureCollectionExistsAsync(cancellationToken);
|
||||
|
||||
// Add both request and response messages to the store
|
||||
// Add both request and response messages to the store, excluding messages that came from chat history.
|
||||
// Optionally messages produced by the AIContextProvider can also be persisted (not shown).
|
||||
var allNewMessages = context.RequestMessages.Concat(context.ResponseMessages ?? []);
|
||||
var allNewMessages = context.RequestMessages
|
||||
.Where(m => m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.ChatHistory)
|
||||
.Concat(context.ResponseMessages ?? []);
|
||||
|
||||
await collection.UpsertAsync(allNewMessages.Select(x => new ChatHistoryItem()
|
||||
{
|
||||
|
||||
@@ -44,10 +44,14 @@ AIAgent agent = new AzureOpenAIClient(
|
||||
You manage a TODO list for the user. When the user has completed one of the tasks it can be removed from the TODO list. Only provide the list of TODO items if asked.
|
||||
You remind users of upcoming calendar events when the user interacts with you.
|
||||
""" },
|
||||
ChatHistoryProvider = new InMemoryChatHistoryProvider()
|
||||
// Use WithAIContextProviderMessageRemoval, so that we don't store the messages from the AI context provider in the chat history.
|
||||
ChatHistoryProvider = new InMemoryChatHistoryProvider(new InMemoryChatHistoryProviderOptions
|
||||
{
|
||||
// Use StorageInputMessageFilter to provide a custom filter for messages stored in chat history.
|
||||
// By default the chat history provider will store all messages, except for those that came from chat history in the first place.
|
||||
// In this case, we want to also exclude messages that came from AI context providers.
|
||||
// You may want to store these messages, depending on their content and your requirements.
|
||||
.WithAIContextProviderMessageRemoval(),
|
||||
StorageInputMessageFilter = messages => messages.Where(m => m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.AIContextProvider && m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.ChatHistory)
|
||||
}),
|
||||
// Add an AI context provider that maintains a todo list for the agent and one that provides upcoming calendar entries.
|
||||
// Wrap these in an AI context provider that aggregates the other two.
|
||||
AIContextProvider = new AggregatingAIContextProvider([
|
||||
|
||||
@@ -63,6 +63,17 @@ public readonly struct AgentRequestMessageSourceAttribution : IEquatable<AgentRe
|
||||
return obj is AgentRequestMessageSourceAttribution other && this.Equals(other);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns a string representation of the current instance.
|
||||
/// </summary>
|
||||
/// <returns>A string containing the source type and source identifier.</returns>
|
||||
public override string ToString()
|
||||
{
|
||||
return this.SourceId is null
|
||||
? $"{this.SourceType}"
|
||||
: $"{this.SourceType}:{this.SourceId}";
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns a hash code for the current instance.
|
||||
/// </summary>
|
||||
|
||||
@@ -58,6 +58,12 @@ public readonly struct AgentRequestMessageSourceType : IEquatable<AgentRequestMe
|
||||
/// <returns><see langword="true"/> if <paramref name="obj"/> is a <see cref="AgentRequestMessageSourceType"/> and its value is the same as this instance; otherwise, <see langword="false"/>.</returns>
|
||||
public override bool Equals(object? obj) => obj is AgentRequestMessageSourceType other && this.Equals(other);
|
||||
|
||||
/// <summary>
|
||||
/// Returns the string representation of this instance.
|
||||
/// </summary>
|
||||
/// <returns>The string value representing the source of the agent request message.</returns>
|
||||
public override string ToString() => this.Value;
|
||||
|
||||
/// <summary>
|
||||
/// Returns the hash code for this instance.
|
||||
/// </summary>
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using Microsoft.Extensions.AI;
|
||||
|
||||
namespace Microsoft.Agents.AI;
|
||||
|
||||
/// <summary>
|
||||
/// Contains extension methods for the <see cref="ChatHistoryProvider"/> class.
|
||||
/// </summary>
|
||||
public static class ChatHistoryProviderExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Adds message filtering to an existing <see cref="ChatHistoryProvider"/>, so that messages passed to the <see cref="ChatHistoryProvider"/> and messages
|
||||
/// provided by the <see cref="ChatHistoryProvider"/> can be filtered, updated or replaced.
|
||||
/// </summary>
|
||||
/// <param name="provider">The <see cref="ChatHistoryProvider"/> to add the message filter to.</param>
|
||||
/// <param name="invokingMessagesFilter">An optional filter function to apply to messages produced by the <see cref="ChatHistoryProvider"/>. If null, no filter is applied at this
|
||||
/// stage.</param>
|
||||
/// <param name="invokedMessagesFilter">An optional filter function to apply to the invoked context messages before they are passed to the <see cref="ChatHistoryProvider"/>. If null, no
|
||||
/// filter is applied at this stage.</param>
|
||||
/// <returns>The <see cref="ChatHistoryProvider"/> with filtering applied.</returns>
|
||||
public static ChatHistoryProvider WithMessageFilters(
|
||||
this ChatHistoryProvider provider,
|
||||
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? invokingMessagesFilter = null,
|
||||
Func<ChatHistoryProvider.InvokedContext, ChatHistoryProvider.InvokedContext>? invokedMessagesFilter = null)
|
||||
{
|
||||
return new ChatHistoryProviderMessageFilter(
|
||||
innerProvider: provider,
|
||||
invokingMessagesFilter: invokingMessagesFilter,
|
||||
invokedMessagesFilter: invokedMessagesFilter);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Decorates the provided <see cref="ChatHistoryProvider"/> so that it does not add
|
||||
/// messages with <see cref="AgentRequestMessageSourceType.AIContextProvider"/> to chat history.
|
||||
/// </summary>
|
||||
/// <param name="provider">The <see cref="ChatHistoryProvider"/> to add the message filter to.</param>
|
||||
/// <returns>A new <see cref="ChatHistoryProvider"/> instance that filters out <see cref="AIContextProvider"/> messages so they do not get added.</returns>
|
||||
public static ChatHistoryProvider WithAIContextProviderMessageRemoval(this ChatHistoryProvider provider)
|
||||
{
|
||||
return new ChatHistoryProviderMessageFilter(
|
||||
innerProvider: provider,
|
||||
invokedMessagesFilter: (ctx) =>
|
||||
{
|
||||
ctx.RequestMessages = ctx.RequestMessages.Where(x => x.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.AIContextProvider);
|
||||
return ctx;
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.AI;
|
||||
using Microsoft.Shared.Diagnostics;
|
||||
|
||||
namespace Microsoft.Agents.AI;
|
||||
|
||||
/// <summary>
|
||||
/// A <see cref="ChatHistoryProvider"/> decorator that allows filtering the messages
|
||||
/// passed into and out of an inner <see cref="ChatHistoryProvider"/>.
|
||||
/// </summary>
|
||||
public sealed class ChatHistoryProviderMessageFilter : ChatHistoryProvider
|
||||
{
|
||||
private readonly ChatHistoryProvider _innerProvider;
|
||||
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? _invokingMessagesFilter;
|
||||
private readonly Func<InvokedContext, InvokedContext>? _invokedMessagesFilter;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="ChatHistoryProviderMessageFilter"/> class.
|
||||
/// </summary>
|
||||
/// <remarks>Use this constructor to customize how messages are filtered before and after invocation by
|
||||
/// providing appropriate filter functions. If no filters are provided, the <see cref="ChatHistoryProvider"/> operates without
|
||||
/// additional filtering.</remarks>
|
||||
/// <param name="innerProvider">The underlying <see cref="ChatHistoryProvider"/> to be wrapped. Cannot be null.</param>
|
||||
/// <param name="invokingMessagesFilter">An optional filter function to apply to messages provided by the <see cref="ChatHistoryProvider"/>
|
||||
/// before they are used by the agent. If null, no filter is applied at this stage.</param>
|
||||
/// <param name="invokedMessagesFilter">An optional filter function to apply to the invocation context after messages have been produced. If null, no
|
||||
/// filter is applied at this stage.</param>
|
||||
/// <exception cref="ArgumentNullException">Thrown if <paramref name="innerProvider"/> is null.</exception>
|
||||
public ChatHistoryProviderMessageFilter(
|
||||
ChatHistoryProvider innerProvider,
|
||||
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? invokingMessagesFilter = null,
|
||||
Func<InvokedContext, InvokedContext>? invokedMessagesFilter = null)
|
||||
{
|
||||
this._innerProvider = Throw.IfNull(innerProvider);
|
||||
|
||||
if (invokingMessagesFilter == null && invokedMessagesFilter == null)
|
||||
{
|
||||
throw new ArgumentException("At least one filter function, invokingMessagesFilter or invokedMessagesFilter, must be provided.");
|
||||
}
|
||||
|
||||
this._invokingMessagesFilter = invokingMessagesFilter;
|
||||
this._invokedMessagesFilter = invokedMessagesFilter;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override string StateKey => this._innerProvider.StateKey;
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override async ValueTask<IEnumerable<ChatMessage>> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var messages = await this._innerProvider.InvokingAsync(context, cancellationToken).ConfigureAwait(false);
|
||||
return this._invokingMessagesFilter != null ? this._invokingMessagesFilter(messages) : messages;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (this._invokedMessagesFilter != null)
|
||||
{
|
||||
context = this._invokedMessagesFilter(context);
|
||||
}
|
||||
|
||||
return this._innerProvider.InvokedAsync(context, cancellationToken);
|
||||
}
|
||||
}
|
||||
@@ -27,9 +27,14 @@ namespace Microsoft.Agents.AI;
|
||||
/// </remarks>
|
||||
public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider
|
||||
{
|
||||
private static IEnumerable<ChatMessage> DefaultExcludeChatHistoryFilter(IEnumerable<ChatMessage> messages)
|
||||
=> messages.Where(m => m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.ChatHistory);
|
||||
|
||||
private readonly string _stateKey;
|
||||
private readonly Func<AgentSession?, State> _stateInitializer;
|
||||
private readonly JsonSerializerOptions _jsonSerializerOptions;
|
||||
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> _storageInputMessageFilter;
|
||||
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? _retrievalOutputMessageFilter;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="InMemoryChatHistoryProvider"/> class.
|
||||
@@ -45,6 +50,8 @@ public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider
|
||||
this.ReducerTriggerEvent = options?.ReducerTriggerEvent ?? InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.BeforeMessagesRetrieval;
|
||||
this._stateKey = options?.StateKey ?? base.StateKey;
|
||||
this._jsonSerializerOptions = options?.JsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions;
|
||||
this._storageInputMessageFilter = options?.StorageInputMessageFilter ?? DefaultExcludeChatHistoryFilter;
|
||||
this._retrievalOutputMessageFilter = options?.RetrievalOutputMessageFilter;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
@@ -115,7 +122,12 @@ public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider
|
||||
state.Messages = (await this.ChatReducer.ReduceAsync(state.Messages, cancellationToken).ConfigureAwait(false)).ToList();
|
||||
}
|
||||
|
||||
return state.Messages
|
||||
IEnumerable<ChatMessage> output = state.Messages;
|
||||
if (this._retrievalOutputMessageFilter is not null)
|
||||
{
|
||||
output = this._retrievalOutputMessageFilter(output);
|
||||
}
|
||||
return output
|
||||
.Select(message => message.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, this.GetType().FullName!))
|
||||
.Concat(context.RequestMessages);
|
||||
}
|
||||
@@ -133,7 +145,7 @@ public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider
|
||||
var state = this.GetOrInitializeState(context.Session);
|
||||
|
||||
// Add request and response messages to the provider
|
||||
var allNewMessages = context.RequestMessages.Concat(context.ResponseMessages ?? []);
|
||||
var allNewMessages = this._storageInputMessageFilter(context.RequestMessages).Concat(context.ResponseMessages ?? []);
|
||||
state.Messages.AddRange(allNewMessages);
|
||||
|
||||
if (this.ReducerTriggerEvent is InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.AfterMessageAdded && this.ChatReducer is not null)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Text.Json;
|
||||
using Microsoft.Extensions.AI;
|
||||
|
||||
@@ -47,6 +48,31 @@ public sealed class InMemoryChatHistoryProviderOptions
|
||||
/// </summary>
|
||||
public JsonSerializerOptions? JsonSerializerOptions { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets an optional filter function applied to request messages before they are added to storage
|
||||
/// during <see cref="ChatHistoryProvider.InvokedAsync"/>.
|
||||
/// </summary>
|
||||
/// <value>
|
||||
/// When <see langword="null"/>, the provider defaults to excluding messages with
|
||||
/// <see cref="AgentRequestMessageSourceType.ChatHistory"/> source type to avoid
|
||||
/// storing messages that came from chat history in the first place.
|
||||
/// Depending on your requirements, you could provide a different filter, that also excludes
|
||||
/// messages from e.g. AI context providers.
|
||||
/// </value>
|
||||
public Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? StorageInputMessageFilter { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets an optional filter function applied to messages produced by this provider
|
||||
/// during <see cref="ChatHistoryProvider.InvokingAsync"/>.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// This filter is only applied to the messages that the provider itself produces (from its internal storage).
|
||||
/// </remarks>
|
||||
/// <value>
|
||||
/// When <see langword="null"/>, no filtering is applied to the output messages.
|
||||
/// </value>
|
||||
public Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? RetrievalOutputMessageFilter { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Defines the events that can trigger a reducer in the <see cref="InMemoryChatHistoryProvider"/>.
|
||||
/// </summary>
|
||||
|
||||
@@ -21,6 +21,9 @@ namespace Microsoft.Agents.AI;
|
||||
[RequiresDynamicCode("The CosmosChatHistoryProvider uses JSON serialization which is incompatible with NativeAOT.")]
|
||||
public sealed class CosmosChatHistoryProvider : ChatHistoryProvider, IDisposable
|
||||
{
|
||||
private static IEnumerable<ChatMessage> DefaultExcludeChatHistoryFilter(IEnumerable<ChatMessage> messages)
|
||||
=> messages.Where(m => m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.ChatHistory);
|
||||
|
||||
private readonly CosmosClient _cosmosClient;
|
||||
private readonly Container _container;
|
||||
private readonly bool _ownsClient;
|
||||
@@ -81,6 +84,25 @@ public sealed class CosmosChatHistoryProvider : ChatHistoryProvider, IDisposable
|
||||
/// </summary>
|
||||
public string ContainerId { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// A filter function applied to request messages before they are stored
|
||||
/// during <see cref="ChatHistoryProvider.InvokedAsync"/>. The default filter excludes messages with the
|
||||
/// <see cref="AgentRequestMessageSourceType.ChatHistory"/> source type.
|
||||
/// </summary>
|
||||
public Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> StorageInputMessageFilter { get; set { field = Throw.IfNull(value); } } = DefaultExcludeChatHistoryFilter;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets an optional filter function applied to messages produced by this provider
|
||||
/// during <see cref="ChatHistoryProvider.InvokingAsync"/>.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// This filter is only applied to the messages that the provider itself produces (from its internal storage).
|
||||
/// </remarks>
|
||||
/// <value>
|
||||
/// When <see langword="null"/>, no filtering is applied to the output messages.
|
||||
/// </value>
|
||||
public Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? RetrievalOutputMessageFilter { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="CosmosChatHistoryProvider"/> class.
|
||||
/// </summary>
|
||||
@@ -257,7 +279,7 @@ public sealed class CosmosChatHistoryProvider : ChatHistoryProvider, IDisposable
|
||||
messages.Reverse();
|
||||
}
|
||||
|
||||
return messages
|
||||
return (this.RetrievalOutputMessageFilter is not null ? this.RetrievalOutputMessageFilter(messages) : messages)
|
||||
.Select(message => message.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, this.GetType().FullName!))
|
||||
.Concat(context.RequestMessages);
|
||||
}
|
||||
@@ -281,7 +303,7 @@ public sealed class CosmosChatHistoryProvider : ChatHistoryProvider, IDisposable
|
||||
#pragma warning restore CA1513
|
||||
|
||||
var state = this.GetOrInitializeState(context.Session);
|
||||
var messageList = context.RequestMessages.Concat(context.ResponseMessages ?? []).ToList();
|
||||
var messageList = this.StorageInputMessageFilter(context.RequestMessages).Concat(context.ResponseMessages ?? []).ToList();
|
||||
if (messageList.Count == 0)
|
||||
{
|
||||
return;
|
||||
|
||||
@@ -26,10 +26,15 @@ public sealed class Mem0Provider : AIContextProvider
|
||||
{
|
||||
private const string DefaultContextPrompt = "## Memories\nConsider the following memories when answering user questions:";
|
||||
|
||||
private static IEnumerable<ChatMessage> DefaultExternalOnlyFilter(IEnumerable<ChatMessage> messages)
|
||||
=> messages.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External);
|
||||
|
||||
private readonly string _contextPrompt;
|
||||
private readonly bool _enableSensitiveTelemetryData;
|
||||
private readonly string _stateKey;
|
||||
private readonly Func<AgentSession?, State> _stateInitializer;
|
||||
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> _searchInputMessageFilter;
|
||||
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> _storageInputMessageFilter;
|
||||
|
||||
private readonly Mem0Client _client;
|
||||
private readonly ILogger<Mem0Provider>? _logger;
|
||||
@@ -67,6 +72,8 @@ public sealed class Mem0Provider : AIContextProvider
|
||||
this._contextPrompt = options?.ContextPrompt ?? DefaultContextPrompt;
|
||||
this._enableSensitiveTelemetryData = options?.EnableSensitiveTelemetryData ?? false;
|
||||
this._stateKey = options?.StateKey ?? base.StateKey;
|
||||
this._searchInputMessageFilter = options?.SearchInputMessageFilter ?? DefaultExternalOnlyFilter;
|
||||
this._storageInputMessageFilter = options?.StorageInputMessageFilter ?? DefaultExternalOnlyFilter;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
@@ -114,8 +121,7 @@ public sealed class Mem0Provider : AIContextProvider
|
||||
|
||||
string queryText = string.Join(
|
||||
Environment.NewLine,
|
||||
(inputContext.Messages ?? [])
|
||||
.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External)
|
||||
this._searchInputMessageFilter(inputContext.Messages ?? [])
|
||||
.Where(m => !string.IsNullOrWhiteSpace(m.Text))
|
||||
.Select(m => m.Text));
|
||||
|
||||
@@ -205,8 +211,7 @@ public sealed class Mem0Provider : AIContextProvider
|
||||
// Persist request and response messages after invocation.
|
||||
await this.PersistMessagesAsync(
|
||||
storageScope,
|
||||
context.RequestMessages
|
||||
.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External)
|
||||
this._storageInputMessageFilter(context.RequestMessages)
|
||||
.Concat(context.ResponseMessages ?? []),
|
||||
cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using Microsoft.Extensions.AI;
|
||||
|
||||
namespace Microsoft.Agents.AI.Mem0;
|
||||
|
||||
/// <summary>
|
||||
@@ -24,4 +28,24 @@ public sealed class Mem0ProviderOptions
|
||||
/// </summary>
|
||||
/// <value>Defaults to the provider's type name.</value>
|
||||
public string? StateKey { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets an optional filter function applied to request messages when building the search text to use when
|
||||
/// searching for relevant memories during <see cref="AIContextProvider.InvokingAsync"/>.
|
||||
/// </summary>
|
||||
/// <value>
|
||||
/// When <see langword="null"/>, the provider defaults to including only
|
||||
/// <see cref="AgentRequestMessageSourceType.External"/> messages.
|
||||
/// </value>
|
||||
public Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? SearchInputMessageFilter { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets an optional filter function applied to request messages when determining which messages to
|
||||
/// extract memories from during <see cref="AIContextProvider.InvokedAsync"/>.
|
||||
/// </summary>
|
||||
/// <value>
|
||||
/// When <see langword="null"/>, the provider defaults to including only
|
||||
/// <see cref="AgentRequestMessageSourceType.External"/> messages.
|
||||
/// </value>
|
||||
public Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? StorageInputMessageFilter { get; set; }
|
||||
}
|
||||
|
||||
@@ -64,7 +64,9 @@ internal sealed class WorkflowChatHistoryProvider : ChatHistoryProvider
|
||||
return default;
|
||||
}
|
||||
|
||||
var allNewMessages = context.RequestMessages.Concat(context.ResponseMessages ?? []);
|
||||
var allNewMessages = context.RequestMessages
|
||||
.Where(m => m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.ChatHistory)
|
||||
.Concat(context.ResponseMessages ?? []);
|
||||
this.GetOrInitializeState(context.Session).Messages.AddRange(allNewMessages);
|
||||
|
||||
return default;
|
||||
|
||||
@@ -41,6 +41,9 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable
|
||||
private const string DefaultFunctionToolName = "Search";
|
||||
private const string DefaultFunctionToolDescription = "Allows searching for related previous chat history to help answer the user question.";
|
||||
|
||||
private static IEnumerable<ChatMessage> DefaultExternalOnlyFilter(IEnumerable<ChatMessage> messages)
|
||||
=> messages.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External);
|
||||
|
||||
#pragma warning disable CA2213 // VectorStore is not owned by this class - caller is responsible for disposal
|
||||
private readonly VectorStore _vectorStore;
|
||||
#pragma warning restore CA2213
|
||||
@@ -54,6 +57,8 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable
|
||||
private readonly ILogger<ChatHistoryMemoryProvider>? _logger;
|
||||
private readonly string _stateKey;
|
||||
private readonly Func<AgentSession?, State> _stateInitializer;
|
||||
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> _searchInputMessageFilter;
|
||||
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> _storageInputMessageFilter;
|
||||
|
||||
private bool _collectionInitialized;
|
||||
private readonly SemaphoreSlim _initializationLock = new(1, 1);
|
||||
@@ -89,6 +94,8 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable
|
||||
this._logger = loggerFactory?.CreateLogger<ChatHistoryMemoryProvider>();
|
||||
this._toolName = options.FunctionToolName ?? DefaultFunctionToolName;
|
||||
this._toolDescription = options.FunctionToolDescription ?? DefaultFunctionToolDescription;
|
||||
this._searchInputMessageFilter = options.SearchInputMessageFilter ?? DefaultExternalOnlyFilter;
|
||||
this._storageInputMessageFilter = options.StorageInputMessageFilter ?? DefaultExternalOnlyFilter;
|
||||
|
||||
// Create a definition so that we can use the dimensions provided at runtime.
|
||||
var definition = new VectorStoreCollectionDefinition
|
||||
@@ -171,8 +178,8 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable
|
||||
try
|
||||
{
|
||||
// Get the text from the current request messages
|
||||
var requestText = string.Join("\n", (inputContext.Messages ?? [])
|
||||
.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External)
|
||||
var requestText = string.Join("\n",
|
||||
this._searchInputMessageFilter(inputContext.Messages ?? [])
|
||||
.Where(m => m != null && !string.IsNullOrWhiteSpace(m.Text))
|
||||
.Select(m => m.Text));
|
||||
|
||||
@@ -238,8 +245,7 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable
|
||||
// Ensure the collection is initialized
|
||||
var collection = await this.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false);
|
||||
|
||||
List<Dictionary<string, object?>> itemsToStore = context.RequestMessages
|
||||
.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External)
|
||||
List<Dictionary<string, object?>> itemsToStore = this._storageInputMessageFilter(context.RequestMessages)
|
||||
.Concat(context.ResponseMessages ?? [])
|
||||
.Select(message => new Dictionary<string, object?>
|
||||
{
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using Microsoft.Extensions.AI;
|
||||
|
||||
namespace Microsoft.Agents.AI;
|
||||
|
||||
/// <summary>
|
||||
@@ -53,6 +57,26 @@ public sealed class ChatHistoryMemoryProviderOptions
|
||||
/// </value>
|
||||
public string? StateKey { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets an optional filter function applied to request messages when constructing the search text to use
|
||||
/// to search for relevant chat history during <see cref="AIContextProvider.InvokingAsync"/>.
|
||||
/// </summary>
|
||||
/// <value>
|
||||
/// When <see langword="null"/>, the provider defaults to including only
|
||||
/// <see cref="AgentRequestMessageSourceType.External"/> messages.
|
||||
/// </value>
|
||||
public Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? SearchInputMessageFilter { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets an optional filter function applied to request messages when storing recent chat history
|
||||
/// during <see cref="AIContextProvider.InvokedAsync"/>.
|
||||
/// </summary>
|
||||
/// <value>
|
||||
/// When <see langword="null"/>, the provider defaults to including only
|
||||
/// <see cref="AgentRequestMessageSourceType.External"/> messages.
|
||||
/// </value>
|
||||
public Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? StorageInputMessageFilter { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Behavior choices for the provider.
|
||||
/// </summary>
|
||||
|
||||
@@ -39,6 +39,9 @@ public sealed class TextSearchProvider : AIContextProvider
|
||||
private const string DefaultContextPrompt = "## Additional Context\nConsider the following information from source documents when responding to the user:";
|
||||
private const string DefaultCitationsPrompt = "Include citations to the source document with document name and link if document name and link is available.";
|
||||
|
||||
private static IEnumerable<ChatMessage> DefaultExternalOnlyFilter(IEnumerable<ChatMessage> messages)
|
||||
=> messages.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External);
|
||||
|
||||
private readonly Func<string, CancellationToken, Task<IEnumerable<TextSearchResult>>> _searchAsync;
|
||||
private readonly ILogger<TextSearchProvider>? _logger;
|
||||
private readonly AITool[] _tools;
|
||||
@@ -49,6 +52,8 @@ public sealed class TextSearchProvider : AIContextProvider
|
||||
private readonly string _citationsPrompt;
|
||||
private readonly string _stateKey;
|
||||
private readonly Func<IList<TextSearchResult>, string>? _contextFormatter;
|
||||
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> _searchInputMessageFilter;
|
||||
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> _storageInputMessageFilter;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="TextSearchProvider"/> class.
|
||||
@@ -72,6 +77,8 @@ public sealed class TextSearchProvider : AIContextProvider
|
||||
this._citationsPrompt = options?.CitationsPrompt ?? DefaultCitationsPrompt;
|
||||
this._stateKey = options?.StateKey ?? base.StateKey;
|
||||
this._contextFormatter = options?.ContextFormatter;
|
||||
this._searchInputMessageFilter = options?.SearchInputMessageFilter ?? DefaultExternalOnlyFilter;
|
||||
this._storageInputMessageFilter = options?.StorageInputMessageFilter ?? DefaultExternalOnlyFilter;
|
||||
|
||||
// Create the on-demand search tool (only used if behavior is OnDemandFunctionCalling)
|
||||
this._tools =
|
||||
@@ -108,8 +115,8 @@ public sealed class TextSearchProvider : AIContextProvider
|
||||
|
||||
// Aggregate text from memory + current request messages.
|
||||
var sbInput = new StringBuilder();
|
||||
var requestMessagesText = (inputContext.Messages ?? [])
|
||||
.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External)
|
||||
var requestMessagesText =
|
||||
this._searchInputMessageFilter(inputContext.Messages ?? [])
|
||||
.Where(x => !string.IsNullOrWhiteSpace(x?.Text)).Select(x => x.Text);
|
||||
foreach (var messageText in recentMessagesText.Concat(requestMessagesText))
|
||||
{
|
||||
@@ -189,8 +196,7 @@ public sealed class TextSearchProvider : AIContextProvider
|
||||
var recentMessagesText = context.Session.StateBag.GetValue<TextSearchProviderState>(this._stateKey, AgentJsonUtilities.DefaultOptions)?.RecentMessagesText
|
||||
?? [];
|
||||
|
||||
var newMessagesText = context.RequestMessages
|
||||
.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External)
|
||||
var newMessagesText = this._storageInputMessageFilter(context.RequestMessages)
|
||||
.Concat(context.ResponseMessages ?? [])
|
||||
.Where(m =>
|
||||
this._recentMessageRolesIncluded.Contains(m.Role) &&
|
||||
|
||||
@@ -68,6 +68,26 @@ public sealed class TextSearchProviderOptions
|
||||
/// </value>
|
||||
public string? StateKey { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets an optional filter function applied to request messages when constructing the search input
|
||||
/// text during <see cref="AIContextProvider.InvokingAsync"/>.
|
||||
/// </summary>
|
||||
/// <value>
|
||||
/// When <see langword="null"/>, the provider defaults to including only
|
||||
/// <see cref="AgentRequestMessageSourceType.External"/> messages.
|
||||
/// </value>
|
||||
public Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? SearchInputMessageFilter { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets an optional filter function applied to request messages when updating the recent message
|
||||
/// memory during <see cref="AIContextProvider.InvokedAsync"/>.
|
||||
/// </summary>
|
||||
/// <value>
|
||||
/// When <see langword="null"/>, the provider defaults to including only
|
||||
/// <see cref="AgentRequestMessageSourceType.External"/> messages.
|
||||
/// </value>
|
||||
public Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? StorageInputMessageFilter { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the list of <see cref="ChatRole"/> types to filter recent messages to
|
||||
/// when deciding which recent messages to include when constructing the search input.
|
||||
|
||||
+43
@@ -390,6 +390,49 @@ public sealed class AgentRequestMessageSourceAttributionTests
|
||||
|
||||
#endregion
|
||||
|
||||
#region ToString Tests
|
||||
|
||||
[Fact]
|
||||
public void ToString_WithSourceId_ReturnsTypeColonId()
|
||||
{
|
||||
// Arrange
|
||||
AgentRequestMessageSourceAttribution attribution = new(AgentRequestMessageSourceType.AIContextProvider, "MyProvider");
|
||||
|
||||
// Act
|
||||
string result = attribution.ToString();
|
||||
|
||||
// Assert
|
||||
Assert.Equal("AIContextProvider:MyProvider", result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ToString_WithNullSourceId_ReturnsTypeOnly()
|
||||
{
|
||||
// Arrange
|
||||
AgentRequestMessageSourceAttribution attribution = new(AgentRequestMessageSourceType.ChatHistory, null);
|
||||
|
||||
// Act
|
||||
string result = attribution.ToString();
|
||||
|
||||
// Assert
|
||||
Assert.Equal("ChatHistory", result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ToString_Default_ReturnsExternalOnly()
|
||||
{
|
||||
// Arrange
|
||||
AgentRequestMessageSourceAttribution attribution = default;
|
||||
|
||||
// Act
|
||||
string result = attribution.ToString();
|
||||
|
||||
// Assert
|
||||
Assert.Equal("External", result);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Inequality Operator Tests
|
||||
|
||||
[Fact]
|
||||
|
||||
+40
@@ -414,6 +414,46 @@ public sealed class AgentRequestMessageSourceTypeTests
|
||||
|
||||
#endregion
|
||||
|
||||
#region ToString Tests
|
||||
|
||||
[Fact]
|
||||
public void ToString_ReturnsValue()
|
||||
{
|
||||
// Arrange
|
||||
AgentRequestMessageSourceType source = new("CustomSource");
|
||||
|
||||
// Act
|
||||
string result = source.ToString();
|
||||
|
||||
// Assert
|
||||
Assert.Equal("CustomSource", result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ToString_StaticExternal_ReturnsExternal()
|
||||
{
|
||||
// Arrange & Act
|
||||
string result = AgentRequestMessageSourceType.External.ToString();
|
||||
|
||||
// Assert
|
||||
Assert.Equal("External", result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ToString_Default_ReturnsExternal()
|
||||
{
|
||||
// Arrange
|
||||
AgentRequestMessageSourceType source = default;
|
||||
|
||||
// Act
|
||||
string result = source.ToString();
|
||||
|
||||
// Assert
|
||||
Assert.Equal("External", result);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region IEquatable Tests
|
||||
|
||||
[Fact]
|
||||
|
||||
-141
@@ -1,141 +0,0 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.AI;
|
||||
using Moq;
|
||||
using Moq.Protected;
|
||||
|
||||
namespace Microsoft.Agents.AI.Abstractions.UnitTests;
|
||||
|
||||
/// <summary>
|
||||
/// Contains tests for the <see cref="ChatHistoryProviderExtensions"/> class.
|
||||
/// </summary>
|
||||
public sealed class ChatHistoryProviderExtensionsTests
|
||||
{
|
||||
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
|
||||
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
|
||||
|
||||
[Fact]
|
||||
public void WithMessageFilters_ReturnsChatHistoryProviderMessageFilter()
|
||||
{
|
||||
// Arrange
|
||||
Mock<ChatHistoryProvider> providerMock = new();
|
||||
|
||||
// Act
|
||||
ChatHistoryProvider result = providerMock.Object.WithMessageFilters(
|
||||
invokingMessagesFilter: msgs => msgs,
|
||||
invokedMessagesFilter: ctx => ctx);
|
||||
|
||||
// Assert
|
||||
Assert.IsType<ChatHistoryProviderMessageFilter>(result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WithMessageFilters_InvokingFilter_IsAppliedAsync()
|
||||
{
|
||||
// Arrange
|
||||
Mock<ChatHistoryProvider> providerMock = new();
|
||||
List<ChatMessage> innerMessages = [new(ChatRole.User, "Hello"), new(ChatRole.Assistant, "Hi")];
|
||||
ChatHistoryProvider.InvokingContext context = new(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]);
|
||||
|
||||
providerMock
|
||||
.Protected()
|
||||
.Setup<ValueTask<IEnumerable<ChatMessage>>>("InvokingCoreAsync", ItExpr.IsAny<ChatHistoryProvider.InvokingContext>(), ItExpr.IsAny<CancellationToken>())
|
||||
.ReturnsAsync(innerMessages);
|
||||
|
||||
ChatHistoryProvider filtered = providerMock.Object.WithMessageFilters(
|
||||
invokingMessagesFilter: msgs => msgs.Where(m => m.Role == ChatRole.User));
|
||||
|
||||
// Act
|
||||
List<ChatMessage> result = (await filtered.InvokingAsync(context, CancellationToken.None)).ToList();
|
||||
|
||||
// Assert
|
||||
Assert.Single(result);
|
||||
Assert.Equal(ChatRole.User, result[0].Role);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WithMessageFilters_InvokedFilter_IsAppliedAsync()
|
||||
{
|
||||
// Arrange
|
||||
Mock<ChatHistoryProvider> providerMock = new();
|
||||
List<ChatMessage> requestMessages =
|
||||
[
|
||||
new(ChatRole.System, "System") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "TestSource") } } },
|
||||
new(ChatRole.User, "Hello")
|
||||
];
|
||||
ChatHistoryProvider.InvokedContext context = new(s_mockAgent, s_mockSession, requestMessages)
|
||||
{
|
||||
ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Response")]
|
||||
};
|
||||
|
||||
ChatHistoryProvider.InvokedContext? capturedContext = null;
|
||||
providerMock
|
||||
.Protected()
|
||||
.Setup<ValueTask>("InvokedCoreAsync", ItExpr.IsAny<ChatHistoryProvider.InvokedContext>(), ItExpr.IsAny<CancellationToken>())
|
||||
.Callback<ChatHistoryProvider.InvokedContext, CancellationToken>((ctx, _) => capturedContext = ctx)
|
||||
.Returns(default(ValueTask));
|
||||
|
||||
ChatHistoryProvider filtered = providerMock.Object.WithMessageFilters(
|
||||
invokedMessagesFilter: ctx =>
|
||||
{
|
||||
ctx.ResponseMessages = null;
|
||||
return ctx;
|
||||
});
|
||||
|
||||
// Act
|
||||
await filtered.InvokedAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(capturedContext);
|
||||
Assert.Null(capturedContext.ResponseMessages);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void WithAIContextProviderMessageRemoval_ReturnsChatHistoryProviderMessageFilter()
|
||||
{
|
||||
// Arrange
|
||||
Mock<ChatHistoryProvider> providerMock = new();
|
||||
|
||||
// Act
|
||||
ChatHistoryProvider result = providerMock.Object.WithAIContextProviderMessageRemoval();
|
||||
|
||||
// Assert
|
||||
Assert.IsType<ChatHistoryProviderMessageFilter>(result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WithAIContextProviderMessageRemoval_RemovesAIContextProviderMessagesAsync()
|
||||
{
|
||||
// Arrange
|
||||
Mock<ChatHistoryProvider> providerMock = new();
|
||||
List<ChatMessage> requestMessages =
|
||||
[
|
||||
new(ChatRole.System, "System") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "TestSource") } } },
|
||||
new(ChatRole.User, "Hello"),
|
||||
new(ChatRole.System, "Context") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "TestContextSource") } } }
|
||||
];
|
||||
ChatHistoryProvider.InvokedContext context = new(s_mockAgent, s_mockSession, requestMessages);
|
||||
|
||||
ChatHistoryProvider.InvokedContext? capturedContext = null;
|
||||
providerMock
|
||||
.Protected()
|
||||
.Setup<ValueTask>("InvokedCoreAsync", ItExpr.IsAny<ChatHistoryProvider.InvokedContext>(), ItExpr.IsAny<CancellationToken>())
|
||||
.Callback<ChatHistoryProvider.InvokedContext, CancellationToken>((ctx, _) => capturedContext = ctx)
|
||||
.Returns(default(ValueTask));
|
||||
|
||||
ChatHistoryProvider filtered = providerMock.Object.WithAIContextProviderMessageRemoval();
|
||||
|
||||
// Act
|
||||
await filtered.InvokedAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(capturedContext);
|
||||
Assert.Equal(2, capturedContext.RequestMessages.Count());
|
||||
Assert.Contains("System", capturedContext.RequestMessages.Select(x => x.Text));
|
||||
Assert.Contains("Hello", capturedContext.RequestMessages.Select(x => x.Text));
|
||||
}
|
||||
}
|
||||
-213
@@ -1,213 +0,0 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.AI;
|
||||
using Moq;
|
||||
using Moq.Protected;
|
||||
|
||||
namespace Microsoft.Agents.AI.Abstractions.UnitTests;
|
||||
|
||||
/// <summary>
|
||||
/// Contains tests for the <see cref="ChatHistoryProviderMessageFilter"/> class.
|
||||
/// </summary>
|
||||
public sealed class ChatHistoryProviderMessageFilterTests
|
||||
{
|
||||
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
|
||||
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
|
||||
|
||||
[Fact]
|
||||
public void Constructor_WithNullInnerProvider_ThrowsArgumentNullException()
|
||||
{
|
||||
// Arrange, Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => new ChatHistoryProviderMessageFilter(null!));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Constructor_WithOnlyInnerProvider_Throws()
|
||||
{
|
||||
// Arrange
|
||||
var innerProviderMock = new Mock<ChatHistoryProvider>();
|
||||
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentException>(() => new ChatHistoryProviderMessageFilter(innerProviderMock.Object));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Constructor_WithAllParameters_CreatesInstance()
|
||||
{
|
||||
// Arrange
|
||||
var innerProviderMock = new Mock<ChatHistoryProvider>();
|
||||
|
||||
IEnumerable<ChatMessage> InvokingFilter(IEnumerable<ChatMessage> msgs) => msgs;
|
||||
ChatHistoryProvider.InvokedContext InvokedFilter(ChatHistoryProvider.InvokedContext ctx) => ctx;
|
||||
|
||||
// Act
|
||||
var filter = new ChatHistoryProviderMessageFilter(innerProviderMock.Object, InvokingFilter, InvokedFilter);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(filter);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void StateKey_DelegatesToInnerProvider()
|
||||
{
|
||||
// Arrange
|
||||
var innerProviderMock = new Mock<ChatHistoryProvider>();
|
||||
innerProviderMock.Setup(p => p.StateKey).Returns("inner-state-key");
|
||||
|
||||
// Act
|
||||
var filter = new ChatHistoryProviderMessageFilter(innerProviderMock.Object, x => x);
|
||||
|
||||
// Assert
|
||||
Assert.Equal("inner-state-key", filter.StateKey);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokingAsync_WithNoOpFilters_ReturnsInnerProviderMessagesAsync()
|
||||
{
|
||||
// Arrange
|
||||
var innerProviderMock = new Mock<ChatHistoryProvider>();
|
||||
var expectedMessages = new List<ChatMessage>
|
||||
{
|
||||
new(ChatRole.User, "Hello"),
|
||||
new(ChatRole.Assistant, "Hi there!")
|
||||
};
|
||||
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]);
|
||||
|
||||
innerProviderMock
|
||||
.Protected()
|
||||
.Setup<ValueTask<IEnumerable<ChatMessage>>>("InvokingCoreAsync", ItExpr.IsAny<ChatHistoryProvider.InvokingContext>(), ItExpr.IsAny<CancellationToken>())
|
||||
.ReturnsAsync(expectedMessages);
|
||||
|
||||
var filter = new ChatHistoryProviderMessageFilter(innerProviderMock.Object, x => x, x => x);
|
||||
|
||||
// Act
|
||||
var result = (await filter.InvokingAsync(context, CancellationToken.None)).ToList();
|
||||
|
||||
// Assert
|
||||
Assert.Equal(2, result.Count);
|
||||
Assert.Equal("Hello", result[0].Text);
|
||||
Assert.Equal("Hi there!", result[1].Text);
|
||||
innerProviderMock
|
||||
.Protected()
|
||||
.Verify<ValueTask<IEnumerable<ChatMessage>>>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny<ChatHistoryProvider.InvokingContext>(), ItExpr.IsAny<CancellationToken>());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokingAsync_WithInvokingFilter_AppliesFilterAsync()
|
||||
{
|
||||
// Arrange
|
||||
var innerProviderMock = new Mock<ChatHistoryProvider>();
|
||||
var innerMessages = new List<ChatMessage>
|
||||
{
|
||||
new(ChatRole.User, "Hello"),
|
||||
new(ChatRole.Assistant, "Hi there!"),
|
||||
new(ChatRole.User, "How are you?")
|
||||
};
|
||||
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]);
|
||||
|
||||
innerProviderMock
|
||||
.Protected()
|
||||
.Setup<ValueTask<IEnumerable<ChatMessage>>>("InvokingCoreAsync", ItExpr.IsAny<ChatHistoryProvider.InvokingContext>(), ItExpr.IsAny<CancellationToken>())
|
||||
.ReturnsAsync(innerMessages);
|
||||
|
||||
// Filter to only user messages
|
||||
IEnumerable<ChatMessage> InvokingFilter(IEnumerable<ChatMessage> msgs) => msgs.Where(m => m.Role == ChatRole.User);
|
||||
|
||||
var filter = new ChatHistoryProviderMessageFilter(innerProviderMock.Object, InvokingFilter);
|
||||
|
||||
// Act
|
||||
var result = (await filter.InvokingAsync(context, CancellationToken.None)).ToList();
|
||||
|
||||
// Assert
|
||||
Assert.Equal(2, result.Count);
|
||||
Assert.All(result, msg => Assert.Equal(ChatRole.User, msg.Role));
|
||||
innerProviderMock
|
||||
.Protected()
|
||||
.Verify<ValueTask<IEnumerable<ChatMessage>>>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny<ChatHistoryProvider.InvokingContext>(), ItExpr.IsAny<CancellationToken>());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokingAsync_WithInvokingFilter_CanModifyMessagesAsync()
|
||||
{
|
||||
// Arrange
|
||||
var innerProviderMock = new Mock<ChatHistoryProvider>();
|
||||
var innerMessages = new List<ChatMessage>
|
||||
{
|
||||
new(ChatRole.User, "Hello"),
|
||||
new(ChatRole.Assistant, "Hi there!")
|
||||
};
|
||||
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]);
|
||||
|
||||
innerProviderMock
|
||||
.Protected()
|
||||
.Setup<ValueTask<IEnumerable<ChatMessage>>>("InvokingCoreAsync", ItExpr.IsAny<ChatHistoryProvider.InvokingContext>(), ItExpr.IsAny<CancellationToken>())
|
||||
.ReturnsAsync(innerMessages);
|
||||
|
||||
// Filter that transforms messages
|
||||
IEnumerable<ChatMessage> InvokingFilter(IEnumerable<ChatMessage> msgs) =>
|
||||
msgs.Select(m => new ChatMessage(m.Role, $"[FILTERED] {m.Text}"));
|
||||
|
||||
var filter = new ChatHistoryProviderMessageFilter(innerProviderMock.Object, InvokingFilter);
|
||||
|
||||
// Act
|
||||
var result = (await filter.InvokingAsync(context, CancellationToken.None)).ToList();
|
||||
|
||||
// Assert
|
||||
Assert.Equal(2, result.Count);
|
||||
Assert.Equal("[FILTERED] Hello", result[0].Text);
|
||||
Assert.Equal("[FILTERED] Hi there!", result[1].Text);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokedAsync_WithInvokedFilter_AppliesFilterAsync()
|
||||
{
|
||||
// Arrange
|
||||
var innerProviderMock = new Mock<ChatHistoryProvider>();
|
||||
List<ChatMessage> requestMessages =
|
||||
[
|
||||
new(ChatRole.System, "System") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "TestSource") } } },
|
||||
new(ChatRole.User, "Hello"),
|
||||
];
|
||||
var responseMessages = new List<ChatMessage> { new(ChatRole.Assistant, "Response") };
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages)
|
||||
{
|
||||
ResponseMessages = responseMessages
|
||||
};
|
||||
|
||||
ChatHistoryProvider.InvokedContext? capturedContext = null;
|
||||
innerProviderMock
|
||||
.Protected()
|
||||
.Setup<ValueTask>("InvokedCoreAsync", ItExpr.IsAny<ChatHistoryProvider.InvokedContext>(), ItExpr.IsAny<CancellationToken>())
|
||||
.Callback<ChatHistoryProvider.InvokedContext, CancellationToken>((ctx, ct) => capturedContext = ctx)
|
||||
.Returns(default(ValueTask));
|
||||
|
||||
// Filter that modifies the context
|
||||
ChatHistoryProvider.InvokedContext InvokedFilter(ChatHistoryProvider.InvokedContext ctx)
|
||||
{
|
||||
var modifiedRequestMessages = ctx.RequestMessages.Where(x => x.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External).Select(m => new ChatMessage(m.Role, $"[FILTERED] {m.Text}")).ToList();
|
||||
return new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, modifiedRequestMessages)
|
||||
{
|
||||
ResponseMessages = ctx.ResponseMessages,
|
||||
InvokeException = ctx.InvokeException
|
||||
};
|
||||
}
|
||||
|
||||
var filter = new ChatHistoryProviderMessageFilter(innerProviderMock.Object, invokedMessagesFilter: InvokedFilter);
|
||||
|
||||
// Act
|
||||
await filter.InvokedAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(capturedContext);
|
||||
Assert.Single(capturedContext.RequestMessages);
|
||||
Assert.Equal("[FILTERED] Hello", capturedContext.RequestMessages.First().Text);
|
||||
innerProviderMock
|
||||
.Protected()
|
||||
.Verify<ValueTask>("InvokedCoreAsync", Times.Once(), ItExpr.IsAny<ChatHistoryProvider.InvokedContext>(), ItExpr.IsAny<CancellationToken>());
|
||||
}
|
||||
}
|
||||
+86
-1
@@ -83,7 +83,7 @@ public class InMemoryChatHistoryProviderTests
|
||||
};
|
||||
|
||||
var provider = new InMemoryChatHistoryProvider();
|
||||
provider.SetMessages(session, [providerMessages[0]]);
|
||||
provider.SetMessages(session, providerMessages);
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages)
|
||||
{
|
||||
ResponseMessages = responseMessages
|
||||
@@ -397,6 +397,91 @@ public class InMemoryChatHistoryProviderTests
|
||||
await Assert.ThrowsAsync<ArgumentNullException>(() => provider.InvokingAsync(null!, CancellationToken.None).AsTask());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokedAsync_DefaultFilter_ExcludesChatHistoryMessagesAsync()
|
||||
{
|
||||
// Arrange
|
||||
var session = CreateMockSession();
|
||||
var provider = new InMemoryChatHistoryProvider();
|
||||
var requestMessages = new List<ChatMessage>
|
||||
{
|
||||
new(ChatRole.User, "External message"),
|
||||
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
|
||||
new(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } },
|
||||
};
|
||||
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages)
|
||||
{
|
||||
ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Response")]
|
||||
};
|
||||
|
||||
// Act
|
||||
await provider.InvokedAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert - ChatHistory message excluded, AIContextProvider message included
|
||||
var messages = provider.GetMessages(session);
|
||||
Assert.Equal(3, messages.Count);
|
||||
Assert.Equal("External message", messages[0].Text);
|
||||
Assert.Equal("From context provider", messages[1].Text);
|
||||
Assert.Equal("Response", messages[2].Text);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokedAsync_CustomFilter_OverridesDefaultAsync()
|
||||
{
|
||||
// Arrange
|
||||
var session = CreateMockSession();
|
||||
var provider = new InMemoryChatHistoryProvider(new InMemoryChatHistoryProviderOptions
|
||||
{
|
||||
StorageInputMessageFilter = messages => messages.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External)
|
||||
});
|
||||
var requestMessages = new List<ChatMessage>
|
||||
{
|
||||
new(ChatRole.User, "External message"),
|
||||
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
|
||||
new(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } },
|
||||
};
|
||||
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages)
|
||||
{
|
||||
ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Response")]
|
||||
};
|
||||
|
||||
// Act
|
||||
await provider.InvokedAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert - Custom filter keeps only External messages (both ChatHistory and AIContextProvider excluded)
|
||||
var messages = provider.GetMessages(session);
|
||||
Assert.Equal(2, messages.Count);
|
||||
Assert.Equal("External message", messages[0].Text);
|
||||
Assert.Equal("Response", messages[1].Text);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokingAsync_OutputFilter_FiltersOutputMessagesAsync()
|
||||
{
|
||||
// Arrange
|
||||
var session = CreateMockSession();
|
||||
var provider = new InMemoryChatHistoryProvider(new InMemoryChatHistoryProviderOptions
|
||||
{
|
||||
RetrievalOutputMessageFilter = messages => messages.Where(m => m.Role == ChatRole.User)
|
||||
});
|
||||
provider.SetMessages(session,
|
||||
[
|
||||
new ChatMessage(ChatRole.User, "User message"),
|
||||
new ChatMessage(ChatRole.Assistant, "Assistant message"),
|
||||
new ChatMessage(ChatRole.System, "System message")
|
||||
]);
|
||||
|
||||
// Act
|
||||
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []);
|
||||
var result = (await provider.InvokingAsync(context, CancellationToken.None)).ToList();
|
||||
|
||||
// Assert - Only user messages pass through the output filter
|
||||
Assert.Single(result);
|
||||
Assert.Equal("User message", result[0].Text);
|
||||
}
|
||||
|
||||
public class TestAIContent(string testData) : AIContent
|
||||
{
|
||||
public string TestData => testData;
|
||||
|
||||
+124
@@ -841,4 +841,128 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Message Filter Tests
|
||||
|
||||
[SkippableFact]
|
||||
[Trait("Category", "CosmosDB")]
|
||||
public async Task InvokedAsync_DefaultFilter_ExcludesChatHistoryMessagesFromStorageAsync()
|
||||
{
|
||||
// Arrange
|
||||
this.SkipIfEmulatorNotAvailable();
|
||||
var session = CreateMockSession();
|
||||
var conversationId = Guid.NewGuid().ToString();
|
||||
using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId,
|
||||
_ => new CosmosChatHistoryProvider.State(conversationId));
|
||||
|
||||
var requestMessages = new[]
|
||||
{
|
||||
new ChatMessage(ChatRole.User, "External message"),
|
||||
new ChatMessage(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
|
||||
new ChatMessage(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } },
|
||||
};
|
||||
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages)
|
||||
{
|
||||
ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Response")]
|
||||
};
|
||||
|
||||
// Act
|
||||
await provider.InvokedAsync(context);
|
||||
|
||||
// Wait for eventual consistency
|
||||
await Task.Delay(100);
|
||||
|
||||
// Assert - ChatHistory message excluded, External + AIContextProvider + Response stored
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []);
|
||||
var messages = (await provider.InvokingAsync(invokingContext)).ToList();
|
||||
Assert.Equal(3, messages.Count);
|
||||
Assert.Equal("External message", messages[0].Text);
|
||||
Assert.Equal("From context provider", messages[1].Text);
|
||||
Assert.Equal("Response", messages[2].Text);
|
||||
}
|
||||
|
||||
[SkippableFact]
|
||||
[Trait("Category", "CosmosDB")]
|
||||
public async Task InvokedAsync_CustomStorageInputFilter_OverridesDefaultAsync()
|
||||
{
|
||||
// Arrange
|
||||
this.SkipIfEmulatorNotAvailable();
|
||||
var session = CreateMockSession();
|
||||
var conversationId = Guid.NewGuid().ToString();
|
||||
using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId,
|
||||
_ => new CosmosChatHistoryProvider.State(conversationId))
|
||||
{
|
||||
// Custom filter: only store External messages (also exclude AIContextProvider)
|
||||
StorageInputMessageFilter = messages => messages.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External)
|
||||
};
|
||||
|
||||
var requestMessages = new[]
|
||||
{
|
||||
new ChatMessage(ChatRole.User, "External message"),
|
||||
new ChatMessage(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
|
||||
new ChatMessage(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } },
|
||||
};
|
||||
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages)
|
||||
{
|
||||
ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Response")]
|
||||
};
|
||||
|
||||
// Act
|
||||
await provider.InvokedAsync(context);
|
||||
|
||||
// Wait for eventual consistency
|
||||
await Task.Delay(100);
|
||||
|
||||
// Assert - Custom filter: only External + Response stored (both ChatHistory and AIContextProvider excluded)
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []);
|
||||
var messages = (await provider.InvokingAsync(invokingContext)).ToList();
|
||||
Assert.Equal(2, messages.Count);
|
||||
Assert.Equal("External message", messages[0].Text);
|
||||
Assert.Equal("Response", messages[1].Text);
|
||||
}
|
||||
|
||||
[SkippableFact]
|
||||
[Trait("Category", "CosmosDB")]
|
||||
public async Task InvokingAsync_RetrievalOutputFilter_FiltersRetrievedMessagesAsync()
|
||||
{
|
||||
// Arrange
|
||||
this.SkipIfEmulatorNotAvailable();
|
||||
var session = CreateMockSession();
|
||||
var conversationId = Guid.NewGuid().ToString();
|
||||
using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId,
|
||||
_ => new CosmosChatHistoryProvider.State(conversationId))
|
||||
{
|
||||
// Only return User messages when retrieving
|
||||
RetrievalOutputMessageFilter = messages => messages.Where(m => m.Role == ChatRole.User)
|
||||
};
|
||||
|
||||
var requestMessages = new[]
|
||||
{
|
||||
new ChatMessage(ChatRole.User, "User message"),
|
||||
new ChatMessage(ChatRole.System, "System message"),
|
||||
};
|
||||
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages)
|
||||
{
|
||||
ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Assistant response")]
|
||||
};
|
||||
|
||||
await provider.InvokedAsync(context);
|
||||
|
||||
// Wait for eventual consistency
|
||||
await Task.Delay(100);
|
||||
|
||||
// Act
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []);
|
||||
var messages = (await provider.InvokingAsync(invokingContext)).ToList();
|
||||
|
||||
// Assert - Only User messages returned (System and Assistant filtered by RetrievalOutputMessageFilter)
|
||||
Assert.Single(messages);
|
||||
Assert.Equal("User message", messages[0].Text);
|
||||
Assert.Equal(ChatRole.User, messages[0].Role);
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
||||
|
||||
@@ -436,6 +436,116 @@ public sealed class Mem0ProviderTests : IDisposable
|
||||
Assert.NotNull(state);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokingAsync_DefaultFilter_ExcludesNonExternalMessagesFromSearchAsync()
|
||||
{
|
||||
// Arrange
|
||||
this._handler.EnqueueJsonResponse("[]"); // Empty search results
|
||||
var storageScope = new Mem0ProviderScope { ApplicationId = "app", AgentId = "agent", ThreadId = "session", UserId = "user" };
|
||||
var mockSession = new TestAgentSession();
|
||||
var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope));
|
||||
|
||||
var requestMessages = new List<ChatMessage>
|
||||
{
|
||||
new(ChatRole.User, "External message"),
|
||||
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
|
||||
new(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } },
|
||||
};
|
||||
|
||||
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, mockSession, new AIContext { Messages = requestMessages });
|
||||
|
||||
// Act
|
||||
await sut.InvokingAsync(invokingContext, CancellationToken.None);
|
||||
|
||||
// Assert - Search query should only contain the External message
|
||||
var searchRequest = Assert.Single(this._handler.Requests, r => r.RequestMessage.Method == HttpMethod.Post);
|
||||
using JsonDocument doc = JsonDocument.Parse(searchRequest.RequestBody);
|
||||
Assert.Equal("External message", doc.RootElement.GetProperty("query").GetString());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokingAsync_CustomSearchInputFilter_OverridesDefaultAsync()
|
||||
{
|
||||
// Arrange
|
||||
this._handler.EnqueueJsonResponse("[]"); // Empty search results
|
||||
var storageScope = new Mem0ProviderScope { ApplicationId = "app", AgentId = "agent", ThreadId = "session", UserId = "user" };
|
||||
var mockSession = new TestAgentSession();
|
||||
var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), options: new Mem0ProviderOptions
|
||||
{
|
||||
SearchInputMessageFilter = messages => messages // No filtering
|
||||
});
|
||||
|
||||
var requestMessages = new List<ChatMessage>
|
||||
{
|
||||
new(ChatRole.User, "External message"),
|
||||
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
|
||||
};
|
||||
|
||||
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, mockSession, new AIContext { Messages = requestMessages });
|
||||
|
||||
// Act
|
||||
await sut.InvokingAsync(invokingContext, CancellationToken.None);
|
||||
|
||||
// Assert - Search query should contain all messages (custom identity filter)
|
||||
var searchRequest = Assert.Single(this._handler.Requests, r => r.RequestMessage.Method == HttpMethod.Post);
|
||||
using JsonDocument doc = JsonDocument.Parse(searchRequest.RequestBody);
|
||||
var queryText = doc.RootElement.GetProperty("query").GetString();
|
||||
Assert.Contains("External message", queryText);
|
||||
Assert.Contains("From history", queryText);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokedAsync_DefaultFilter_ExcludesNonExternalMessagesFromStorageAsync()
|
||||
{
|
||||
// Arrange
|
||||
this._handler.EnqueueEmptyOk(); // For the one message that should be stored
|
||||
var storageScope = new Mem0ProviderScope { ApplicationId = "a", AgentId = "b", ThreadId = "c", UserId = "d" };
|
||||
var mockSession = new TestAgentSession();
|
||||
var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope));
|
||||
|
||||
var requestMessages = new List<ChatMessage>
|
||||
{
|
||||
new(ChatRole.User, "External message"),
|
||||
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
|
||||
};
|
||||
|
||||
// Act
|
||||
await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, requestMessages));
|
||||
|
||||
// Assert - Only the External message should be persisted
|
||||
var memoryPosts = this._handler.Requests.Where(r => r.RequestMessage.RequestUri!.AbsolutePath == "/v1/memories/" && r.RequestMessage.Method == HttpMethod.Post).ToList();
|
||||
Assert.Single(memoryPosts);
|
||||
Assert.Contains("External message", memoryPosts[0].RequestBody);
|
||||
Assert.DoesNotContain(memoryPosts, r => ContainsOrdinal(r.RequestBody, "From history"));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokedAsync_CustomStorageInputFilter_OverridesDefaultAsync()
|
||||
{
|
||||
// Arrange
|
||||
this._handler.EnqueueEmptyOk(); // For first CreateMemory
|
||||
this._handler.EnqueueEmptyOk(); // For second CreateMemory
|
||||
var storageScope = new Mem0ProviderScope { ApplicationId = "a", AgentId = "b", ThreadId = "c", UserId = "d" };
|
||||
var mockSession = new TestAgentSession();
|
||||
var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), options: new Mem0ProviderOptions
|
||||
{
|
||||
StorageInputMessageFilter = messages => messages // No filtering - store everything
|
||||
});
|
||||
|
||||
var requestMessages = new List<ChatMessage>
|
||||
{
|
||||
new(ChatRole.User, "External message"),
|
||||
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
|
||||
};
|
||||
|
||||
// Act
|
||||
await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, requestMessages));
|
||||
|
||||
// Assert - Both messages should be persisted (identity filter overrides default)
|
||||
var memoryPosts = this._handler.Requests.Where(r => r.RequestMessage.RequestUri!.AbsolutePath == "/v1/memories/" && r.RequestMessage.Method == HttpMethod.Post).ToList();
|
||||
Assert.Equal(2, memoryPosts.Count);
|
||||
}
|
||||
|
||||
private static bool ContainsOrdinal(string source, string value) => source.IndexOf(value, StringComparison.Ordinal) >= 0;
|
||||
|
||||
public void Dispose()
|
||||
|
||||
@@ -356,6 +356,140 @@ public sealed class TextSearchProviderTests
|
||||
Assert.Null(aiContext.Tools);
|
||||
}
|
||||
|
||||
#region Message Filter Tests
|
||||
|
||||
[Fact]
|
||||
public async Task InvokingAsync_DefaultFilter_ExcludesNonExternalMessagesFromSearchInputAsync()
|
||||
{
|
||||
// Arrange
|
||||
string? capturedInput = null;
|
||||
Task<IEnumerable<TextSearchProvider.TextSearchResult>> SearchDelegateAsync(string input, CancellationToken ct)
|
||||
{
|
||||
capturedInput = input;
|
||||
return Task.FromResult<IEnumerable<TextSearchProvider.TextSearchResult>>([]);
|
||||
}
|
||||
|
||||
var provider = new TextSearchProvider(SearchDelegateAsync);
|
||||
var requestMessages = new List<ChatMessage>
|
||||
{
|
||||
new(ChatRole.User, "External message"),
|
||||
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
|
||||
new(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } },
|
||||
};
|
||||
|
||||
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), new AIContext { Messages = requestMessages });
|
||||
|
||||
// Act
|
||||
await provider.InvokingAsync(invokingContext, CancellationToken.None);
|
||||
|
||||
// Assert - Only external messages should be used for search input
|
||||
Assert.Equal("External message", capturedInput);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokingAsync_CustomSearchInputFilter_OverridesDefaultAsync()
|
||||
{
|
||||
// Arrange
|
||||
string? capturedInput = null;
|
||||
Task<IEnumerable<TextSearchProvider.TextSearchResult>> SearchDelegateAsync(string input, CancellationToken ct)
|
||||
{
|
||||
capturedInput = input;
|
||||
return Task.FromResult<IEnumerable<TextSearchProvider.TextSearchResult>>([]);
|
||||
}
|
||||
|
||||
var provider = new TextSearchProvider(SearchDelegateAsync, new TextSearchProviderOptions
|
||||
{
|
||||
SearchInputMessageFilter = messages => messages.Where(m => m.Role == ChatRole.System)
|
||||
});
|
||||
var requestMessages = new List<ChatMessage>
|
||||
{
|
||||
new(ChatRole.User, "User message"),
|
||||
new(ChatRole.System, "System message"),
|
||||
};
|
||||
|
||||
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), new AIContext { Messages = requestMessages });
|
||||
|
||||
// Act
|
||||
await provider.InvokingAsync(invokingContext, CancellationToken.None);
|
||||
|
||||
// Assert - Custom filter keeps only System messages
|
||||
Assert.Equal("System message", capturedInput);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokedAsync_DefaultFilter_ExcludesNonExternalMessagesFromStorageAsync()
|
||||
{
|
||||
// Arrange
|
||||
var options = new TextSearchProviderOptions
|
||||
{
|
||||
RecentMessageMemoryLimit = 10,
|
||||
RecentMessageRolesIncluded = [ChatRole.User, ChatRole.System]
|
||||
};
|
||||
string? capturedInput = null;
|
||||
Task<IEnumerable<TextSearchProvider.TextSearchResult>> SearchDelegateAsync(string input, CancellationToken ct)
|
||||
{
|
||||
capturedInput = input;
|
||||
return Task.FromResult<IEnumerable<TextSearchProvider.TextSearchResult>>([]);
|
||||
}
|
||||
var provider = new TextSearchProvider(SearchDelegateAsync, options);
|
||||
var session = new TestAgentSession();
|
||||
|
||||
var requestMessages = new List<ChatMessage>
|
||||
{
|
||||
new(ChatRole.User, "External message"),
|
||||
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
|
||||
new(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } },
|
||||
};
|
||||
|
||||
// Store messages via InvokedAsync
|
||||
await provider.InvokedAsync(new(s_mockAgent, session, requestMessages));
|
||||
|
||||
// Now invoke to read stored memory
|
||||
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, session, new AIContext { Messages = [new ChatMessage(ChatRole.User, "Next")] });
|
||||
await provider.InvokingAsync(invokingContext, CancellationToken.None);
|
||||
|
||||
// Assert - Only "External message" was stored in memory, so search input = "External message" + "Next"
|
||||
Assert.Equal("External message\nNext", capturedInput);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokedAsync_CustomStorageInputFilter_OverridesDefaultAsync()
|
||||
{
|
||||
// Arrange
|
||||
var options = new TextSearchProviderOptions
|
||||
{
|
||||
RecentMessageMemoryLimit = 10,
|
||||
RecentMessageRolesIncluded = [ChatRole.User, ChatRole.System],
|
||||
StorageInputMessageFilter = messages => messages // No filtering - store everything
|
||||
};
|
||||
string? capturedInput = null;
|
||||
Task<IEnumerable<TextSearchProvider.TextSearchResult>> SearchDelegateAsync(string input, CancellationToken ct)
|
||||
{
|
||||
capturedInput = input;
|
||||
return Task.FromResult<IEnumerable<TextSearchProvider.TextSearchResult>>([]);
|
||||
}
|
||||
var provider = new TextSearchProvider(SearchDelegateAsync, options);
|
||||
var session = new TestAgentSession();
|
||||
|
||||
var requestMessages = new List<ChatMessage>
|
||||
{
|
||||
new(ChatRole.User, "External message"),
|
||||
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
|
||||
};
|
||||
|
||||
// Store messages via InvokedAsync
|
||||
await provider.InvokedAsync(new(s_mockAgent, session, requestMessages));
|
||||
|
||||
// Now invoke to read stored memory
|
||||
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, session, new AIContext { Messages = [new ChatMessage(ChatRole.User, "Next")] });
|
||||
await provider.InvokingAsync(invokingContext, CancellationToken.None);
|
||||
|
||||
// Assert - Both messages stored (identity filter), so search input includes all + current
|
||||
Assert.Equal("External message\nFrom history\nNext", capturedInput);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Recent Message Memory Tests
|
||||
|
||||
[Fact]
|
||||
|
||||
@@ -539,6 +539,188 @@ public class ChatHistoryMemoryProviderTests
|
||||
|
||||
#endregion
|
||||
|
||||
#region Message Filter Tests
|
||||
|
||||
[Fact]
|
||||
public async Task InvokingAsync_DefaultFilter_ExcludesNonExternalMessagesFromSearchAsync()
|
||||
{
|
||||
// Arrange
|
||||
var providerOptions = new ChatHistoryMemoryProviderOptions
|
||||
{
|
||||
SearchTime = ChatHistoryMemoryProviderOptions.SearchBehavior.BeforeAIInvoke,
|
||||
};
|
||||
|
||||
string? capturedQuery = null;
|
||||
this._vectorStoreCollectionMock
|
||||
.Setup(c => c.SearchAsync(
|
||||
It.IsAny<string>(),
|
||||
It.IsAny<int>(),
|
||||
It.IsAny<VectorSearchOptions<Dictionary<string, object?>>>(),
|
||||
It.IsAny<CancellationToken>()))
|
||||
.Callback<string, int, VectorSearchOptions<Dictionary<string, object?>>, CancellationToken>((query, _, _, _) => capturedQuery = query)
|
||||
.Returns(ToAsyncEnumerableAsync(new List<VectorSearchResult<Dictionary<string, object?>>>()));
|
||||
|
||||
var provider = new ChatHistoryMemoryProvider(
|
||||
this._vectorStoreMock.Object,
|
||||
TestCollectionName,
|
||||
1,
|
||||
_ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }),
|
||||
options: providerOptions);
|
||||
|
||||
var requestMessages = new List<ChatMessage>
|
||||
{
|
||||
new(ChatRole.User, "External message"),
|
||||
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
|
||||
new(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } },
|
||||
};
|
||||
|
||||
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), new AIContext { Messages = requestMessages });
|
||||
|
||||
// Act
|
||||
await provider.InvokingAsync(invokingContext, CancellationToken.None);
|
||||
|
||||
// Assert - Only External message used for search query
|
||||
Assert.Equal("External message", capturedQuery);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokingAsync_CustomSearchInputFilter_OverridesDefaultAsync()
|
||||
{
|
||||
// Arrange
|
||||
var providerOptions = new ChatHistoryMemoryProviderOptions
|
||||
{
|
||||
SearchTime = ChatHistoryMemoryProviderOptions.SearchBehavior.BeforeAIInvoke,
|
||||
SearchInputMessageFilter = messages => messages // No filtering
|
||||
};
|
||||
|
||||
string? capturedQuery = null;
|
||||
this._vectorStoreCollectionMock
|
||||
.Setup(c => c.SearchAsync(
|
||||
It.IsAny<string>(),
|
||||
It.IsAny<int>(),
|
||||
It.IsAny<VectorSearchOptions<Dictionary<string, object?>>>(),
|
||||
It.IsAny<CancellationToken>()))
|
||||
.Callback<string, int, VectorSearchOptions<Dictionary<string, object?>>, CancellationToken>((query, _, _, _) => capturedQuery = query)
|
||||
.Returns(ToAsyncEnumerableAsync(new List<VectorSearchResult<Dictionary<string, object?>>>()));
|
||||
|
||||
var provider = new ChatHistoryMemoryProvider(
|
||||
this._vectorStoreMock.Object,
|
||||
TestCollectionName,
|
||||
1,
|
||||
_ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }),
|
||||
options: providerOptions);
|
||||
|
||||
var requestMessages = new List<ChatMessage>
|
||||
{
|
||||
new(ChatRole.User, "External message"),
|
||||
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
|
||||
};
|
||||
|
||||
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), new AIContext { Messages = requestMessages });
|
||||
|
||||
// Act
|
||||
await provider.InvokingAsync(invokingContext, CancellationToken.None);
|
||||
|
||||
// Assert - Both messages should be included in search query (identity filter)
|
||||
Assert.NotNull(capturedQuery);
|
||||
Assert.Contains("External message", capturedQuery);
|
||||
Assert.Contains("From history", capturedQuery);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokedAsync_DefaultFilter_ExcludesNonExternalMessagesFromStorageAsync()
|
||||
{
|
||||
// Arrange
|
||||
var stored = new List<Dictionary<string, object?>>();
|
||||
|
||||
this._vectorStoreCollectionMock
|
||||
.Setup(c => c.UpsertAsync(It.IsAny<IEnumerable<Dictionary<string, object?>>>(), It.IsAny<CancellationToken>()))
|
||||
.Callback<IEnumerable<Dictionary<string, object?>>, CancellationToken>((items, ct) =>
|
||||
{
|
||||
if (items != null)
|
||||
{
|
||||
stored.AddRange(items);
|
||||
}
|
||||
})
|
||||
.Returns(Task.CompletedTask);
|
||||
|
||||
var provider = new ChatHistoryMemoryProvider(
|
||||
this._vectorStoreMock.Object,
|
||||
TestCollectionName,
|
||||
1,
|
||||
_ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }));
|
||||
|
||||
var requestMessages = new List<ChatMessage>
|
||||
{
|
||||
new(ChatRole.User, "External message"),
|
||||
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
|
||||
new(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } },
|
||||
};
|
||||
|
||||
var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, new TestAgentSession(), requestMessages)
|
||||
{
|
||||
ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Response")]
|
||||
};
|
||||
|
||||
// Act
|
||||
await provider.InvokedAsync(invokedContext, CancellationToken.None);
|
||||
|
||||
// Assert - Only External message + response stored (ChatHistory and AIContextProvider excluded by default)
|
||||
Assert.Equal(2, stored.Count);
|
||||
Assert.Equal("External message", stored[0]["Content"]);
|
||||
Assert.Equal("Response", stored[1]["Content"]);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokedAsync_CustomStorageInputFilter_OverridesDefaultAsync()
|
||||
{
|
||||
// Arrange
|
||||
var stored = new List<Dictionary<string, object?>>();
|
||||
|
||||
this._vectorStoreCollectionMock
|
||||
.Setup(c => c.UpsertAsync(It.IsAny<IEnumerable<Dictionary<string, object?>>>(), It.IsAny<CancellationToken>()))
|
||||
.Callback<IEnumerable<Dictionary<string, object?>>, CancellationToken>((items, ct) =>
|
||||
{
|
||||
if (items != null)
|
||||
{
|
||||
stored.AddRange(items);
|
||||
}
|
||||
})
|
||||
.Returns(Task.CompletedTask);
|
||||
|
||||
var provider = new ChatHistoryMemoryProvider(
|
||||
this._vectorStoreMock.Object,
|
||||
TestCollectionName,
|
||||
1,
|
||||
_ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }),
|
||||
options: new ChatHistoryMemoryProviderOptions
|
||||
{
|
||||
StorageInputMessageFilter = messages => messages // No filtering - store everything
|
||||
});
|
||||
|
||||
var requestMessages = new List<ChatMessage>
|
||||
{
|
||||
new(ChatRole.User, "External message"),
|
||||
new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } },
|
||||
};
|
||||
|
||||
var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, new TestAgentSession(), requestMessages)
|
||||
{
|
||||
ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Response")]
|
||||
};
|
||||
|
||||
// Act
|
||||
await provider.InvokedAsync(invokedContext, CancellationToken.None);
|
||||
|
||||
// Assert - All messages stored (identity filter overrides default)
|
||||
Assert.Equal(3, stored.Count);
|
||||
Assert.Equal("External message", stored[0]["Content"]);
|
||||
Assert.Equal("From history", stored[1]["Content"]);
|
||||
Assert.Equal("Response", stored[2]["Content"]);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
private static async IAsyncEnumerable<T> ToAsyncEnumerableAsync<T>(IEnumerable<T> values)
|
||||
{
|
||||
await Task.Yield();
|
||||
|
||||
Reference in New Issue
Block a user