.NET: Support a message only AIContextProvider as an AIAgent Decorator (#4009)

* Support a message only AIContextProvider as an AIAgent Decorator

* Fix formatting

* Address PR comments.
This commit is contained in:
westey
2026-02-19 19:03:56 +00:00
committed by GitHub
Unverified
parent 5ee06853a1
commit 40d3a0655c
17 changed files with 1692 additions and 50 deletions
@@ -2,8 +2,9 @@
// This sample shows multiple middleware layers working together with Azure OpenAI:
// chat client (global/per-request), agent run (PII filtering and guardrails),
// function invocation (logging and result overrides), and human-in-the-loop
// approval workflows for sensitive function calls.
// function invocation (logging and result overrides), human-in-the-loop
// approval workflows for sensitive function calls, and MessageAIContextProvider
// middleware for injecting additional context messages into the agent pipeline.
using System.ComponentModel;
using System.Text.RegularExpressions;
@@ -96,6 +97,20 @@ var response = await originalAgent // Using per-request middleware pipeline with
Console.WriteLine($"Per-request middleware response: {response}");
// MessageAIContextProvider middleware that injects additional messages into the agent request.
// This allows any AIAgent (not just ChatClientAgent) to benefit from MessageAIContextProvider-based
// context enrichment. Multiple providers can be passed to Use and they are called in sequence,
// each receiving the output of the previous one.
Console.WriteLine("\n\n=== Example 5: MessageAIContextProvider middleware ===");
var contextProviderAgent = originalAgent
.AsBuilder()
.Use([new DateTimeContextProvider()])
.Build();
var contextResponse = await contextProviderAgent.RunAsync("Is it almost time for lunch?");
Console.WriteLine($"Context-enriched response: {contextResponse}");
// Function invocation middleware that logs before and after function calls.
async ValueTask<object?> FunctionCallMiddleware(AIAgent agent, FunctionInvocationContext context, Func<FunctionInvocationContext, CancellationToken, ValueTask<object?>> next, CancellationToken cancellationToken)
{
@@ -259,3 +274,23 @@ async Task<ChatResponse> PerRequestChatClientMiddleware(IEnumerable<ChatMessage>
return response;
}
/// <summary>
/// A <see cref="MessageAIContextProvider"/> that injects the current date and time into the agent's context.
/// This is a simple example of how to use a MessageAIContextProvider to enrich agent messages
/// via the <see cref="AIAgentBuilder.Use(MessageAIContextProvider[])"/> extension method.
/// </summary>
internal sealed class DateTimeContextProvider : MessageAIContextProvider
{
protected override ValueTask<IEnumerable<ChatMessage>> ProvideMessagesAsync(
InvokingContext context,
CancellationToken cancellationToken = default)
{
Console.WriteLine("DateTimeContextProvider - Injecting current date/time context");
return new ValueTask<IEnumerable<ChatMessage>>(
[
new ChatMessage(ChatRole.User, $"For reference, the current date and time is: {DateTimeOffset.Now}")
]);
}
}
@@ -14,6 +14,7 @@ This sample demonstrates how to add middleware to intercept:
5. Perrequest chat client middleware
6. Perrequest function pipeline with approval
7. Combining agentlevel and perrequest middleware
8. MessageAIContextProvider middleware via `AIAgentBuilder.Use(...)` for injecting additional context messages
## Function Invocation Middleware
@@ -146,11 +146,11 @@ namespace SampleApp
}
/// <summary>
/// An <see cref="AIContextProvider"/> which searches for upcoming calendar events and adds them to the AI context.
/// A <see cref="MessageAIContextProvider"/> which searches for upcoming calendar events and adds them to the AI context.
/// </summary>
internal sealed class CalendarSearchAIContextProvider(Func<Task<string[]>> loadNextThreeCalendarEvents) : AIContextProvider
internal sealed class CalendarSearchAIContextProvider(Func<Task<string[]>> loadNextThreeCalendarEvents) : MessageAIContextProvider
{
protected override async ValueTask<AIContext> ProvideAIContextAsync(InvokingContext context, CancellationToken cancellationToken = default)
protected override async ValueTask<IEnumerable<MEAI.ChatMessage>> ProvideMessagesAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
var events = await loadNextThreeCalendarEvents();
@@ -161,10 +161,7 @@ namespace SampleApp
outputMessageBuilder.AppendLine($" - {calendarEvent}");
}
return new AIContext
{
Messages = [new MEAI.ChatMessage(ChatRole.User, outputMessageBuilder.ToString())]
};
return [new MEAI.ChatMessage(ChatRole.User, outputMessageBuilder.ToString())];
}
}
}
@@ -34,9 +34,6 @@ public abstract class AIContextProvider
private static IEnumerable<ChatMessage> DefaultExternalOnlyFilter(IEnumerable<ChatMessage> messages)
=> messages.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External);
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> _provideInputMessageFilter;
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> _storeInputMessageFilter;
/// <summary>
/// Initializes a new instance of the <see cref="AIContextProvider"/> class.
/// </summary>
@@ -46,10 +43,20 @@ public abstract class AIContextProvider
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? provideInputMessageFilter = null,
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? storeInputMessageFilter = null)
{
this._provideInputMessageFilter = provideInputMessageFilter ?? DefaultExternalOnlyFilter;
this._storeInputMessageFilter = storeInputMessageFilter ?? DefaultExternalOnlyFilter;
this.ProvideInputMessageFilter = provideInputMessageFilter ?? DefaultExternalOnlyFilter;
this.StoreInputMessageFilter = storeInputMessageFilter ?? DefaultExternalOnlyFilter;
}
/// <summary>
/// Gets the filter function to apply to input messages before providing context via <see cref="ProvideAIContextAsync"/>.
/// </summary>
protected Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> ProvideInputMessageFilter { get; }
/// <summary>
/// Gets the filter function to apply to request messages before storing context via <see cref="StoreAIContextAsync"/>.
/// </summary>
protected Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> StoreInputMessageFilter { get; }
/// <summary>
/// Gets the key used to store the provider state in the <see cref="AgentSession.StateBag"/>.
/// </summary>
@@ -120,7 +127,7 @@ public abstract class AIContextProvider
new AIContext
{
Instructions = inputContext.Instructions,
Messages = inputContext.Messages is not null ? this._provideInputMessageFilter(inputContext.Messages) : null,
Messages = inputContext.Messages is not null ? this.ProvideInputMessageFilter(inputContext.Messages) : null,
Tools = inputContext.Tools
});
@@ -254,7 +261,7 @@ public abstract class AIContextProvider
return default;
}
var subContext = new InvokedContext(context.Agent, context.Session, this._storeInputMessageFilter(context.RequestMessages), context.ResponseMessages!);
var subContext = new InvokedContext(context.Agent, context.Session, this.StoreInputMessageFilter(context.RequestMessages), context.ResponseMessages!);
return this.StoreAIContextAsync(subContext, cancellationToken);
}
@@ -0,0 +1,203 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.Shared.Diagnostics;
namespace Microsoft.Agents.AI;
/// <summary>
/// Provides an abstract base class for components that enhance AI context during agent invocations by supplying additional chat messages.
/// </summary>
/// <remarks>
/// <para>
/// A message AI context provider is a component that participates in the agent invocation lifecycle by:
/// <list type="bullet">
/// <item><description>Listening to changes in conversations</description></item>
/// <item><description>Providing additional messages to agents during invocation</description></item>
/// <item><description>Processing invocation results for state management or learning</description></item>
/// </list>
/// </para>
/// <para>
/// Context providers operate through a two-phase lifecycle: they are called at the start of invocation via
/// <see cref="AIContextProvider.InvokingAsync"/> to provide context, and optionally called at the end of invocation via
/// <see cref="AIContextProvider.InvokedAsync"/> to process results.
/// </para>
/// </remarks>
public abstract class MessageAIContextProvider : AIContextProvider
{
/// <summary>
/// Initializes a new instance of the <see cref="MessageAIContextProvider"/> class.
/// </summary>
/// <param name="provideInputMessageFilter">An optional filter function to apply to input messages before providing messages via <see cref="ProvideMessagesAsync"/>. If not set, defaults to including only <see cref="AgentRequestMessageSourceType.External"/> messages.</param>
/// <param name="storeInputMessageFilter">An optional filter function to apply to request messages before storing messages via <see cref="AIContextProvider.StoreAIContextAsync"/>. If not set, defaults to including only <see cref="AgentRequestMessageSourceType.External"/> messages.</param>
protected MessageAIContextProvider(
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? provideInputMessageFilter = null,
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? storeInputMessageFilter = null)
: base(provideInputMessageFilter, storeInputMessageFilter)
{
}
/// <inheritdoc/>
protected override async ValueTask<AIContext> ProvideAIContextAsync(AIContextProvider.InvokingContext context, CancellationToken cancellationToken = default)
{
// Call ProvideMessagesAsync directly to return only additional messages.
// The base AIContextProvider.InvokingCoreAsync handles merging with the original input and stamping.
return new AIContext
{
Messages = await this.ProvideMessagesAsync(
new InvokingContext(context.Agent, context.Session, context.AIContext.Messages ?? []),
cancellationToken).ConfigureAwait(false)
};
}
/// <summary>
/// Called at the start of agent invocation to provide additional messages.
/// </summary>
/// <param name="context">Contains the request context including the caller provided messages that will be used by the agent for this invocation.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A task that represents the asynchronous operation. The task result contains the <see cref="IEnumerable{ChatMessage}"/> to be used by the agent during this invocation.</returns>
/// <remarks>
/// <para>
/// Implementers can load any additional messages required at this time, such as:
/// <list type="bullet">
/// <item><description>Retrieving relevant information from knowledge bases</description></item>
/// <item><description>Adding system instructions or prompts</description></item>
/// <item><description>Injecting contextual messages from conversation history</description></item>
/// </list>
/// </para>
/// </remarks>
public ValueTask<IEnumerable<ChatMessage>> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default)
=> this.InvokingCoreAsync(Throw.IfNull(context), cancellationToken);
/// <summary>
/// Called at the start of agent invocation to provide additional messages.
/// </summary>
/// <param name="context">Contains the request context including the caller provided messages that will be used by the agent for this invocation.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A task that represents the asynchronous operation. The task result contains the <see cref="IEnumerable{ChatMessage}"/> to be used by the agent during this invocation.</returns>
/// <remarks>
/// <para>
/// Implementers can load any additional messages required at this time, such as:
/// <list type="bullet">
/// <item><description>Retrieving relevant information from knowledge bases</description></item>
/// <item><description>Adding system instructions or prompts</description></item>
/// <item><description>Injecting contextual messages from conversation history</description></item>
/// </list>
/// </para>
/// <para>
/// The default implementation of this method filters the input messages using the configured provide-input message filter
/// (which defaults to including only <see cref="AgentRequestMessageSourceType.External"/> messages),
/// then calls <see cref="ProvideMessagesAsync"/> to get additional messages,
/// stamps any messages with <see cref="AgentRequestMessageSourceType.AIContextProvider"/> source attribution,
/// and merges the returned messages with the original (unfiltered) input messages.
/// For most scenarios, overriding <see cref="ProvideMessagesAsync"/> is sufficient to provide additional messages,
/// while still benefiting from the default filtering, merging and source stamping behavior.
/// However, for scenarios that require more control over message filtering, merging or source stamping, overriding this method
/// allows you to directly control the full <see cref="IEnumerable{ChatMessage}"/> returned for the invocation.
/// </para>
/// </remarks>
protected virtual async ValueTask<IEnumerable<ChatMessage>> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
var inputMessages = context.RequestMessages;
// Create a filtered context for ProvideMessagesAsync, filtering input messages
// to exclude non-external messages (e.g. chat history, other AI context provider messages).
var filteredContext = new InvokingContext(
context.Agent,
context.Session,
this.ProvideInputMessageFilter(inputMessages));
var providedMessages = await this.ProvideMessagesAsync(filteredContext, cancellationToken).ConfigureAwait(false);
// Stamp and merge provided messages.
providedMessages = providedMessages.Select(m => m.WithAgentRequestMessageSource(AgentRequestMessageSourceType.AIContextProvider, this.GetType().FullName!));
return inputMessages.Concat(providedMessages);
}
/// <summary>
/// When overridden in a derived class, provides additional messages to be merged with the input messages for the current invocation.
/// </summary>
/// <remarks>
/// <para>
/// This method is called from <see cref="InvokingCoreAsync(InvokingContext, CancellationToken)"/>.
/// Note that <see cref="InvokingCoreAsync(InvokingContext, CancellationToken)"/> can be overridden to directly control messages merging and source stamping, in which case
/// it is up to the implementer to call this method as needed to retrieve the additional messages.
/// </para>
/// <para>
/// In contrast with <see cref="InvokingCoreAsync(InvokingContext, CancellationToken)"/>, this method only returns additional messages to be merged with the input,
/// while <see cref="InvokingCoreAsync(InvokingContext, CancellationToken)"/> is responsible for returning the full merged <see cref="IEnumerable{ChatMessage}"/> for the invocation.
/// </para>
/// </remarks>
/// <param name="context">Contains the request context including the caller provided messages that will be used by the agent for this invocation.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>
/// A task that represents the asynchronous operation. The task result contains an <see cref="IEnumerable{ChatMessage}"/>
/// with additional messages to be merged with the input messages.
/// </returns>
protected virtual ValueTask<IEnumerable<ChatMessage>> ProvideMessagesAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
return new ValueTask<IEnumerable<ChatMessage>>([]);
}
/// <summary>
/// Contains the context information provided to <see cref="InvokingCoreAsync(InvokingContext, CancellationToken)"/>.
/// </summary>
/// <remarks>
/// This class provides context about the invocation before the underlying AI model is invoked, including the messages
/// that will be used. Message AI Context providers can use this information to determine what additional messages
/// should be provided for the invocation.
/// </remarks>
public new sealed class InvokingContext
{
/// <summary>
/// Initializes a new instance of the <see cref="InvokingContext"/> class with the specified request messages.
/// </summary>
/// <param name="agent">The agent being invoked.</param>
/// <param name="session">The session associated with the agent invocation.</param>
/// <param name="requestMessages">The messages to be used by the agent for this invocation.</param>
/// <exception cref="ArgumentNullException"><paramref name="agent"/> or <paramref name="requestMessages"/> is <see langword="null"/>.</exception>
public InvokingContext(
AIAgent agent,
AgentSession? session,
IEnumerable<ChatMessage> requestMessages)
{
this.Agent = Throw.IfNull(agent);
this.Session = session;
this.RequestMessages = Throw.IfNull(requestMessages);
}
/// <summary>
/// Gets the agent that is being invoked.
/// </summary>
public AIAgent Agent { get; }
/// <summary>
/// Gets the agent session associated with the agent invocation.
/// </summary>
public AgentSession? Session { get; }
/// <summary>
/// Gets the messages that will be used by the agent for this invocation. <see cref="MessageAIContextProvider"/> instances can modify
/// and return or return a new message list to add additional messages for the invocation.
/// </summary>
/// <value>
/// A collection of <see cref="ChatMessage"/> instances representing the messages that will be used by the agent for this invocation.
/// </value>
/// <remarks>
/// <para>
/// If multiple <see cref="MessageAIContextProvider"/> instances are used in the same invocation, each <see cref="MessageAIContextProvider"/>
/// will receive the messages returned by the previous <see cref="MessageAIContextProvider"/> allowing them to build on top of each other's context.
/// </para>
/// <para>
/// The first <see cref="MessageAIContextProvider"/> in the invocation pipeline will receive the
/// caller provided messages.
/// </para>
/// </remarks>
public IEnumerable<ChatMessage> RequestMessages { get; set { field = Throw.IfNull(value); } }
}
}
@@ -2,6 +2,7 @@
using System;
using System.Text.Json;
using Microsoft.Shared.Diagnostics;
namespace Microsoft.Agents.AI;
@@ -39,8 +40,8 @@ public class ProviderSessionState<TState>
string stateKey,
JsonSerializerOptions? jsonSerializerOptions = null)
{
this._stateInitializer = stateInitializer;
this.StateKey = stateKey;
this._stateInitializer = Throw.IfNull(stateInitializer);
this.StateKey = Throw.IfNullOrWhitespace(stateKey);
this._jsonSerializerOptions = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions;
}
@@ -14,7 +14,7 @@ using Microsoft.Shared.Diagnostics;
namespace Microsoft.Agents.AI.Mem0;
/// <summary>
/// Provides a Mem0 backed <see cref="AIContextProvider"/> that persists conversation messages as memories
/// Provides a Mem0 backed <see cref="MessageAIContextProvider"/> that persists conversation messages as memories
/// and retrieves related memories to augment the agent invocation context.
/// </summary>
/// <remarks>
@@ -22,7 +22,7 @@ namespace Microsoft.Agents.AI.Mem0;
/// for new invocations using a semantic search endpoint. Retrieved memories are injected as user messages
/// to the model, prefixed by a configurable context prompt.
/// </remarks>
public sealed class Mem0Provider : AIContextProvider
public sealed class Mem0Provider : MessageAIContextProvider
{
private const string DefaultContextPrompt = "## Memories\nConsider the following memories when answering user questions:";
@@ -92,7 +92,7 @@ public sealed class Mem0Provider : AIContextProvider
};
/// <inheritdoc />
protected override async ValueTask<AIContext> ProvideAIContextAsync(InvokingContext context, CancellationToken cancellationToken = default)
protected override async ValueTask<IEnumerable<ChatMessage>> ProvideMessagesAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
Throw.IfNull(context);
@@ -101,7 +101,7 @@ public sealed class Mem0Provider : AIContextProvider
string queryText = string.Join(
Environment.NewLine,
(context.AIContext.Messages ?? [])
context.RequestMessages
.Where(m => !string.IsNullOrWhiteSpace(m.Text))
.Select(m => m.Text));
@@ -142,12 +142,9 @@ public sealed class Mem0Provider : AIContextProvider
}
}
return new AIContext
{
Messages = outputMessageText is not null
? [new ChatMessage(ChatRole.User, outputMessageText)]
: null
};
return outputMessageText is not null
? [new ChatMessage(ChatRole.User, outputMessageText)]
: [];
}
catch (ArgumentException)
{
@@ -166,7 +163,7 @@ public sealed class Mem0Provider : AIContextProvider
this.SanitizeLogData(searchScope.UserId));
}
return new AIContext();
return [];
}
}
@@ -151,6 +151,32 @@ public sealed class AIAgentBuilder
return this.Use((innerAgent, _) => new AnonymousDelegatingAIAgent(innerAgent, runFunc, runStreamingFunc));
}
/// <summary>
/// Adds one or more <see cref="MessageAIContextProvider"/> instances to the agent pipeline, enabling message enrichment
/// for any <see cref="AIAgent"/>.
/// </summary>
/// <param name="providers">
/// The <see cref="MessageAIContextProvider"/> instances to invoke before and after each agent invocation.
/// Providers are called in sequence, with each receiving the output of the previous provider.
/// </param>
/// <returns>The <see cref="AIAgentBuilder"/> with the providers added, enabling method chaining.</returns>
/// <exception cref="ArgumentException"><paramref name="providers"/> is empty.</exception>
/// <remarks>
/// <para>
/// This method wraps the inner agent with a <see cref="DelegatingAIAgent"/> that calls each provider's
/// <see cref="MessageAIContextProvider.InvokingAsync"/> in sequence before the inner agent runs,
/// and calls <see cref="AIContextProvider.InvokedAsync"/> on each provider after the inner agent completes.
/// </para>
/// <para>
/// This allows any <see cref="AIAgent"/> to benefit from <see cref="MessageAIContextProvider"/>-based
/// context enrichment, not just agents that natively support <see cref="AIContextProvider"/> instances.
/// </para>
/// </remarks>
public AIAgentBuilder Use(MessageAIContextProvider[] providers)
{
return this.Use((innerAgent, _) => new MessageAIContextProviderAgent(innerAgent, providers));
}
/// <summary>
/// Provides an empty <see cref="IServiceProvider"/> implementation.
/// </summary>
@@ -34,7 +34,7 @@ namespace Microsoft.Agents.AI;
/// injecting them automatically on each invocation.
/// </para>
/// </remarks>
public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable
public sealed class ChatHistoryMemoryProvider : MessageAIContextProvider, IDisposable
{
private const string DefaultContextPrompt = "## Memories\nConsider the following memories when answering user questions:";
private const int DefaultMaxResults = 3;
@@ -119,7 +119,7 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable
public override string StateKey => this._sessionState.StateKey;
/// <inheritdoc />
protected override async ValueTask<AIContext> ProvideAIContextAsync(InvokingContext context, CancellationToken cancellationToken = default)
protected override async ValueTask<AIContext> ProvideAIContextAsync(AIContextProvider.InvokingContext context, CancellationToken cancellationToken = default)
{
_ = Throw.IfNull(context);
@@ -147,17 +147,46 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable
};
}
return new AIContext
{
Messages = await this.ProvideMessagesAsync(
new InvokingContext(context.Agent, context.Session, context.AIContext.Messages ?? []),
cancellationToken).ConfigureAwait(false)
};
}
/// <inheritdoc />
protected override ValueTask<IEnumerable<ChatMessage>> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
// This code path is invoked using InvokingAsync on MessageAIContextProvider, which does not support tools and instructions,
// and OnDemandFunctionCalling requires tools.
if (this._searchTime != ChatHistoryMemoryProviderOptions.SearchBehavior.BeforeAIInvoke)
{
throw new InvalidOperationException($"Using the {nameof(ChatHistoryMemoryProvider)} as a {nameof(MessageAIContextProvider)} is not supported when {nameof(ChatHistoryMemoryProviderOptions.SearchTime)} is set to {ChatHistoryMemoryProviderOptions.SearchBehavior.OnDemandFunctionCalling}.");
}
return base.InvokingCoreAsync(context, cancellationToken);
}
/// <inheritdoc />
protected override async ValueTask<IEnumerable<ChatMessage>> ProvideMessagesAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
_ = Throw.IfNull(context);
var state = this._sessionState.GetOrInitializeState(context.Session);
var searchScope = state.SearchScope;
try
{
// Get the text from the current request messages
var requestText = string.Join("\n",
(context.AIContext.Messages ?? [])
(context.RequestMessages ?? [])
.Where(m => m != null && !string.IsNullOrWhiteSpace(m.Text))
.Select(m => m.Text));
if (string.IsNullOrWhiteSpace(requestText))
{
return new AIContext();
return [];
}
// Search for relevant chat history
@@ -165,13 +194,10 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable
if (string.IsNullOrWhiteSpace(contextText))
{
return new AIContext();
return [];
}
return new AIContext
{
Messages = [new ChatMessage(ChatRole.User, contextText)]
};
return [new ChatMessage(ChatRole.User, contextText)];
}
catch (Exception ex)
{
@@ -186,7 +212,7 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable
this.SanitizeLogData(searchScope.UserId));
}
return new AIContext();
return [];
}
}
@@ -0,0 +1,167 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.Shared.Diagnostics;
namespace Microsoft.Agents.AI;
/// <summary>
/// A delegating AI agent that enriches input messages by invoking a pipeline of <see cref="MessageAIContextProvider"/> instances
/// before delegating to the inner agent, and notifies those providers after the inner agent completes.
/// </summary>
internal sealed class MessageAIContextProviderAgent : DelegatingAIAgent
{
private readonly IReadOnlyList<MessageAIContextProvider> _providers;
/// <summary>
/// Initializes a new instance of the <see cref="MessageAIContextProviderAgent"/> class.
/// </summary>
/// <param name="innerAgent">The underlying agent instance that will handle the core operations.</param>
/// <param name="providers">The message AI context providers to invoke before and after the inner agent.</param>
public MessageAIContextProviderAgent(AIAgent innerAgent, IReadOnlyList<MessageAIContextProvider> providers)
: base(innerAgent)
{
Throw.IfNull(providers);
Throw.IfLessThanOrEqual(providers.Count, 0, nameof(providers));
this._providers = providers;
}
/// <inheritdoc/>
protected override async Task<AgentResponse> RunCoreAsync(
IEnumerable<ChatMessage> messages,
AgentSession? session = null,
AgentRunOptions? options = null,
CancellationToken cancellationToken = default)
{
var enrichedMessages = await this.InvokeProvidersAsync(messages, session, cancellationToken).ConfigureAwait(false);
AgentResponse response;
try
{
response = await this.InnerAgent.RunAsync(enrichedMessages, session, options, cancellationToken).ConfigureAwait(false);
}
catch (Exception ex)
{
await this.NotifyProvidersOfFailureAsync(session, enrichedMessages, ex, cancellationToken).ConfigureAwait(false);
throw;
}
await this.NotifyProvidersOfSuccessAsync(session, enrichedMessages, response.Messages, cancellationToken).ConfigureAwait(false);
return response;
}
/// <inheritdoc/>
protected override async IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingAsync(
IEnumerable<ChatMessage> messages,
AgentSession? session = null,
AgentRunOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var enrichedMessages = await this.InvokeProvidersAsync(messages, session, cancellationToken).ConfigureAwait(false);
List<AgentResponseUpdate> responseUpdates = [];
IAsyncEnumerator<AgentResponseUpdate> enumerator;
try
{
enumerator = this.InnerAgent.RunStreamingAsync(enrichedMessages, session, options, cancellationToken).GetAsyncEnumerator(cancellationToken);
}
catch (Exception ex)
{
await this.NotifyProvidersOfFailureAsync(session, enrichedMessages, ex, cancellationToken).ConfigureAwait(false);
throw;
}
bool hasUpdates;
try
{
hasUpdates = await enumerator.MoveNextAsync().ConfigureAwait(false);
}
catch (Exception ex)
{
await this.NotifyProvidersOfFailureAsync(session, enrichedMessages, ex, cancellationToken).ConfigureAwait(false);
throw;
}
while (hasUpdates)
{
var update = enumerator.Current;
responseUpdates.Add(update);
yield return update;
try
{
hasUpdates = await enumerator.MoveNextAsync().ConfigureAwait(false);
}
catch (Exception ex)
{
await this.NotifyProvidersOfFailureAsync(session, enrichedMessages, ex, cancellationToken).ConfigureAwait(false);
throw;
}
}
var agentResponse = responseUpdates.ToAgentResponse();
await this.NotifyProvidersOfSuccessAsync(session, enrichedMessages, agentResponse.Messages, cancellationToken).ConfigureAwait(false);
}
/// <summary>
/// Invokes each provider's <see cref="MessageAIContextProvider.InvokingAsync"/> in sequence,
/// passing the output of each as input to the next.
/// </summary>
private async Task<IEnumerable<ChatMessage>> InvokeProvidersAsync(
IEnumerable<ChatMessage> messages,
AgentSession? session,
CancellationToken cancellationToken)
{
var currentMessages = messages;
foreach (var provider in this._providers)
{
var context = new MessageAIContextProvider.InvokingContext(this, session, currentMessages);
currentMessages = await provider.InvokingAsync(context, cancellationToken).ConfigureAwait(false);
}
return currentMessages;
}
/// <summary>
/// Notifies each provider of a successful invocation.
/// </summary>
private async Task NotifyProvidersOfSuccessAsync(
AgentSession? session,
IEnumerable<ChatMessage> requestMessages,
IEnumerable<ChatMessage> responseMessages,
CancellationToken cancellationToken)
{
var invokedContext = new AIContextProvider.InvokedContext(this, session, requestMessages, responseMessages);
foreach (var provider in this._providers)
{
await provider.InvokedAsync(invokedContext, cancellationToken).ConfigureAwait(false);
}
}
/// <summary>
/// Notifies each provider of a failed invocation.
/// </summary>
private async Task NotifyProvidersOfFailureAsync(
AgentSession? session,
IEnumerable<ChatMessage> requestMessages,
Exception exception,
CancellationToken cancellationToken)
{
var invokedContext = new AIContextProvider.InvokedContext(this, session, requestMessages, exception);
foreach (var provider in this._providers)
{
await provider.InvokedAsync(invokedContext, cancellationToken).ConfigureAwait(false);
}
}
}
@@ -32,7 +32,7 @@ namespace Microsoft.Agents.AI;
/// multi-turn context to the retrieval layer without permanently altering the conversation history.
/// </para>
/// </remarks>
public sealed class TextSearchProvider : AIContextProvider
public sealed class TextSearchProvider : MessageAIContextProvider
{
private const string DefaultPluginSearchFunctionName = "Search";
private const string DefaultPluginSearchFunctionDescription = "Allows searching for additional information to help answer the user question.";
@@ -91,7 +91,7 @@ public sealed class TextSearchProvider : AIContextProvider
public override string StateKey => this._sessionState.StateKey;
/// <inheritdoc />
protected override async ValueTask<AIContext> ProvideAIContextAsync(InvokingContext context, CancellationToken cancellationToken = default)
protected override async ValueTask<AIContext> ProvideAIContextAsync(AIContextProvider.InvokingContext context, CancellationToken cancellationToken = default)
{
if (this._searchTime != TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke)
{
@@ -102,6 +102,30 @@ public sealed class TextSearchProvider : AIContextProvider
};
}
return new AIContext
{
Messages = await this.ProvideMessagesAsync(
new InvokingContext(context.Agent, context.Session, context.AIContext.Messages ?? []),
cancellationToken).ConfigureAwait(false)
};
}
/// <inheritdoc />
protected override ValueTask<IEnumerable<ChatMessage>> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
// This code path is invoked using InvokingAsync on MessageAIContextProvider, which does not support tools and instructions,
// and OnDemandFunctionCalling requires tools.
if (this._searchTime != TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke)
{
throw new InvalidOperationException($"Using the {nameof(TextSearchProvider)} as a {nameof(MessageAIContextProvider)} is not supported when {nameof(TextSearchProviderOptions.SearchTime)} is set to {TextSearchProviderOptions.TextSearchBehavior.OnDemandFunctionCalling}.");
}
return base.InvokingCoreAsync(context, cancellationToken);
}
/// <inheritdoc />
protected override async ValueTask<IEnumerable<ChatMessage>> ProvideMessagesAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
// Retrieve recent messages from the session state.
var recentMessagesText = this._sessionState.GetOrInitializeState(context.Session).RecentMessagesText
?? [];
@@ -109,7 +133,7 @@ public sealed class TextSearchProvider : AIContextProvider
// Aggregate text from memory + current request messages.
var sbInput = new StringBuilder();
var requestMessagesText =
(context.AIContext.Messages ?? [])
(context.RequestMessages ?? [])
.Where(x => !string.IsNullOrWhiteSpace(x?.Text)).Select(x => x.Text);
foreach (var messageText in recentMessagesText.Concat(requestMessagesText))
{
@@ -135,7 +159,7 @@ public sealed class TextSearchProvider : AIContextProvider
if (materialized.Count == 0)
{
return new AIContext();
return [];
}
// Format search results
@@ -146,15 +170,12 @@ public sealed class TextSearchProvider : AIContextProvider
this._logger.LogTrace("TextSearchProvider: Search Results\nInput:{Input}\nOutput:{MessageText}", input, formatted);
}
return new AIContext
{
Messages = [new ChatMessage(ChatRole.User, formatted)]
};
return [new ChatMessage(ChatRole.User, formatted)];
}
catch (Exception ex)
{
this._logger?.LogError(ex, "TextSearchProvider: Failed to search for data due to error");
return new AIContext();
return [];
}
}
@@ -0,0 +1,323 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Moq;
namespace Microsoft.Agents.AI.Abstractions.UnitTests;
/// <summary>
/// Contains tests for the <see cref="MessageAIContextProvider"/> class.
/// </summary>
public class MessageAIContextProviderTests
{
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
#region InvokingAsync Tests
[Fact]
public async Task InvokingAsync_NullContext_ThrowsArgumentNullExceptionAsync()
{
// Arrange
var provider = new TestMessageProvider();
// Act & Assert
await Assert.ThrowsAsync<ArgumentNullException>(() => provider.InvokingAsync(null!).AsTask());
}
[Fact]
public async Task InvokingAsync_ReturnsInputAndProvidedMessagesAsync()
{
// Arrange
var providedMessages = new[] { new ChatMessage(ChatRole.System, "Context message") };
var provider = new TestMessageProvider(provideMessages: providedMessages);
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "User input")]);
// Act
var result = (await provider.InvokingAsync(context)).ToList();
// Assert - input messages + provided messages merged
Assert.Equal(2, result.Count);
Assert.Equal("User input", result[0].Text);
Assert.Equal("Context message", result[1].Text);
}
[Fact]
public async Task InvokingAsync_ReturnsOnlyInputMessages_WhenNoMessagesProvidedAsync()
{
// Arrange
var provider = new DefaultMessageProvider();
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Hello")]);
// Act
var result = (await provider.InvokingAsync(context)).ToList();
// Assert
Assert.Single(result);
Assert.Equal("Hello", result[0].Text);
}
[Fact]
public async Task InvokingAsync_StampsProvidedMessagesWithAIContextProviderSourceAsync()
{
// Arrange
var providedMessages = new[] { new ChatMessage(ChatRole.System, "Provided") };
var provider = new TestMessageProvider(provideMessages: providedMessages);
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, []);
// Act
var result = (await provider.InvokingAsync(context)).ToList();
// Assert
Assert.Single(result);
Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, result[0].GetAgentRequestMessageSourceType());
}
[Fact]
public async Task InvokingAsync_FiltersInputToExternalOnlyByDefaultAsync()
{
// Arrange
var provider = new TestMessageProvider(captureFilteredContext: true);
var externalMsg = new ChatMessage(ChatRole.User, "External");
var chatHistoryMsg = new ChatMessage(ChatRole.User, "History")
.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, "src");
var contextProviderMsg = new ChatMessage(ChatRole.User, "ContextProvider")
.WithAgentRequestMessageSource(AgentRequestMessageSourceType.AIContextProvider, "src");
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [externalMsg, chatHistoryMsg, contextProviderMsg]);
// Act
await provider.InvokingAsync(context);
// Assert - ProvideMessagesAsync received only External messages
Assert.NotNull(provider.LastFilteredContext);
var filteredMessages = provider.LastFilteredContext!.RequestMessages.ToList();
Assert.Single(filteredMessages);
Assert.Equal("External", filteredMessages[0].Text);
}
[Fact]
public async Task InvokingAsync_UsesCustomProvideInputFilterAsync()
{
// Arrange - filter that keeps all messages (not just External)
var provider = new TestMessageProvider(
captureFilteredContext: true,
provideInputMessageFilter: msgs => msgs);
var externalMsg = new ChatMessage(ChatRole.User, "External");
var chatHistoryMsg = new ChatMessage(ChatRole.User, "History")
.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, "src");
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [externalMsg, chatHistoryMsg]);
// Act
await provider.InvokingAsync(context);
// Assert - ProvideMessagesAsync received ALL messages (custom filter keeps everything)
Assert.NotNull(provider.LastFilteredContext);
var filteredMessages = provider.LastFilteredContext!.RequestMessages.ToList();
Assert.Equal(2, filteredMessages.Count);
}
[Fact]
public async Task InvokingAsync_MergesWithOriginalUnfilteredMessagesAsync()
{
// Arrange - default filter is External-only, but the MERGED result should include
// the original unfiltered input messages plus the provided messages
var providedMessages = new[] { new ChatMessage(ChatRole.System, "Provided") };
var provider = new TestMessageProvider(provideMessages: providedMessages);
var externalMsg = new ChatMessage(ChatRole.User, "External");
var chatHistoryMsg = new ChatMessage(ChatRole.User, "History")
.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, "src");
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [externalMsg, chatHistoryMsg]);
// Act
var result = (await provider.InvokingAsync(context)).ToList();
// Assert - original 2 input messages + 1 provided message
Assert.Equal(3, result.Count);
Assert.Equal("External", result[0].Text);
Assert.Equal("History", result[1].Text);
Assert.Equal("Provided", result[2].Text);
}
#endregion
#region ProvideAIContextAsync Tests
[Fact]
public async Task ProvideAIContextAsync_PreservesInstructionsAndToolsAsync()
{
// Arrange
var providedMessages = new[] { new ChatMessage(ChatRole.System, "Context") };
var provider = new TestMessageProvider(provideMessages: providedMessages);
var inputTool = AIFunctionFactory.Create(() => "a", "inputTool");
var inputContext = new AIContext
{
Messages = [new ChatMessage(ChatRole.User, "Hello")],
Instructions = "Be helpful",
Tools = [inputTool]
};
var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, inputContext);
// Act
var result = await provider.InvokingAsync(context);
// Assert - instructions and tools are preserved
Assert.Equal("Be helpful", result.Instructions);
Assert.NotNull(result.Tools);
Assert.Single(result.Tools!);
Assert.Equal("inputTool", result.Tools!.First().Name);
// Messages include original input + provided messages (with stamping)
var messages = result.Messages!.ToList();
Assert.Equal(2, messages.Count);
Assert.Equal("Hello", messages[0].Text);
Assert.Equal("Context", messages[1].Text);
Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, messages[1].GetAgentRequestMessageSourceType());
}
[Fact]
public async Task ProvideAIContextAsync_PreservesNullInstructionsAndToolsAsync()
{
// Arrange
var provider = new DefaultMessageProvider();
var inputContext = new AIContext { Messages = [new ChatMessage(ChatRole.User, "Hello")] };
var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, inputContext);
// Act
var result = await provider.InvokingAsync(context);
// Assert
Assert.Null(result.Instructions);
Assert.Null(result.Tools);
var messages = result.Messages!.ToList();
Assert.Single(messages);
Assert.Equal("Hello", messages[0].Text);
}
#endregion
#region InvokingContext Tests
[Fact]
public void InvokingContext_Constructor_ThrowsForNullAgent()
{
// Act & Assert
Assert.Throws<ArgumentNullException>(() => new MessageAIContextProvider.InvokingContext(null!, s_mockSession, []));
}
[Fact]
public void InvokingContext_Constructor_ThrowsForNullRequestMessages()
{
// Act & Assert
Assert.Throws<ArgumentNullException>(() => new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, null!));
}
[Fact]
public void InvokingContext_Constructor_AllowsNullSession()
{
// Act
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, null, []);
// Assert
Assert.Null(context.Session);
}
[Fact]
public void InvokingContext_Properties_Roundtrip()
{
// Arrange
var messages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
// Act
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, messages);
// Assert
Assert.Same(s_mockAgent, context.Agent);
Assert.Same(s_mockSession, context.Session);
Assert.Same(messages, context.RequestMessages);
}
[Fact]
public void InvokingContext_RequestMessages_SetterThrowsForNull()
{
// Arrange
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, []);
// Act & Assert
Assert.Throws<ArgumentNullException>(() => context.RequestMessages = null!);
}
[Fact]
public void InvokingContext_RequestMessages_SetterAcceptsValidValue()
{
// Arrange
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, []);
var newMessages = new List<ChatMessage> { new(ChatRole.User, "Updated") };
// Act
context.RequestMessages = newMessages;
// Assert
Assert.Same(newMessages, context.RequestMessages);
}
#endregion
#region GetService Tests
[Fact]
public void GetService_ReturnsProviderForMessageAIContextProviderType()
{
// Arrange
var provider = new TestMessageProvider();
// Act & Assert
Assert.Same(provider, provider.GetService(typeof(MessageAIContextProvider)));
Assert.Same(provider, provider.GetService(typeof(AIContextProvider)));
Assert.Same(provider, provider.GetService(typeof(TestMessageProvider)));
}
#endregion
#region Test helpers
private sealed class TestMessageProvider : MessageAIContextProvider
{
private readonly IEnumerable<ChatMessage>? _provideMessages;
private readonly bool _captureFilteredContext;
public InvokingContext? LastFilteredContext { get; private set; }
public TestMessageProvider(
IEnumerable<ChatMessage>? provideMessages = null,
bool captureFilteredContext = false,
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? provideInputMessageFilter = null,
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? storeInputMessageFilter = null)
: base(provideInputMessageFilter, storeInputMessageFilter)
{
this._provideMessages = provideMessages;
this._captureFilteredContext = captureFilteredContext;
}
protected override ValueTask<IEnumerable<ChatMessage>> ProvideMessagesAsync(InvokingContext context, CancellationToken cancellationToken = default)
{
if (this._captureFilteredContext)
{
this.LastFilteredContext = context;
}
return new(this._provideMessages ?? []);
}
}
/// <summary>
/// A provider that uses only base class defaults (no overrides of ProvideMessagesAsync).
/// </summary>
private sealed class DefaultMessageProvider : MessageAIContextProvider;
#endregion
}
@@ -1,5 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
namespace Microsoft.Agents.AI.Abstractions.UnitTests;
/// <summary>
@@ -7,6 +9,56 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests;
/// </summary>
public class ProviderSessionStateTests
{
#region Constructor Tests
[Fact]
public void Constructor_ThrowsForNullStateInitializer()
{
// Act & Assert
Assert.Throws<ArgumentNullException>(() => new ProviderSessionState<TestState>(null!, "test-key"));
}
[Fact]
public void Constructor_ThrowsForNullStateKey()
{
// Act & Assert
Assert.Throws<ArgumentNullException>(() => new ProviderSessionState<TestState>(_ => new TestState(), null!));
}
[Theory]
[InlineData("")]
[InlineData(" ")]
public void Constructor_ThrowsForEmptyOrWhitespaceStateKey(string stateKey)
{
// Act & Assert
Assert.Throws<ArgumentException>(() => new ProviderSessionState<TestState>(_ => new TestState(), stateKey));
}
[Fact]
public void Constructor_AcceptsNullJsonSerializerOptions()
{
// Act - should not throw
var sessionState = new ProviderSessionState<TestState>(_ => new TestState(), "test-key", jsonSerializerOptions: null);
// Assert - instance is created and functional
Assert.Equal("test-key", sessionState.StateKey);
}
[Fact]
public void Constructor_AcceptsCustomJsonSerializerOptions()
{
// Arrange
var customOptions = new System.Text.Json.JsonSerializerOptions();
// Act - should not throw
var sessionState = new ProviderSessionState<TestState>(_ => new TestState(), "test-key", customOptions);
// Assert - instance is created and functional
Assert.Equal("test-key", sessionState.StateKey);
}
#endregion
#region GetOrInitializeState Tests
[Fact]
@@ -547,6 +547,87 @@ public sealed class Mem0ProviderTests : IDisposable
Assert.Equal(2, memoryPosts.Count);
}
#region MessageAIContextProvider.InvokingAsync Tests
[Fact]
public async Task MessageInvokingAsync_SearchesAndReturnsMergedMessagesAsync()
{
// Arrange
this._handler.EnqueueJsonResponse("[ { \"id\": \"1\", \"memory\": \"Name is Caoimhe\", \"hash\": \"h\", \"metadata\": null, \"score\": 0.9, \"created_at\": \"2023-01-01T00:00:00Z\", \"updated_at\": null, \"user_id\": \"u\", \"app_id\": null, \"agent_id\": \"agent\", \"thread_id\": \"session\" } ]");
var storageScope = new Mem0ProviderScope
{
ApplicationId = "app",
AgentId = "agent",
ThreadId = "session",
UserId = "user"
};
var mockSession = new TestAgentSession();
var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope));
var inputMsg = new ChatMessage(ChatRole.User, "What is my name?");
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, mockSession, [inputMsg]);
// Act
var messages = (await sut.InvokingAsync(context)).ToList();
// Assert - input message + memory message, with stamping
Assert.Equal(2, messages.Count);
Assert.Equal("What is my name?", messages[0].Text);
Assert.Contains("Name is Caoimhe", messages[1].Text);
Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, messages[1].GetAgentRequestMessageSourceType());
}
[Fact]
public async Task MessageInvokingAsync_NoMemories_ReturnsOnlyInputMessagesAsync()
{
// Arrange
this._handler.EnqueueJsonResponse("[]");
var storageScope = new Mem0ProviderScope
{
UserId = "user"
};
var mockSession = new TestAgentSession();
var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope));
var inputMsg = new ChatMessage(ChatRole.User, "Hello");
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, mockSession, [inputMsg]);
// Act
var messages = (await sut.InvokingAsync(context)).ToList();
// Assert
Assert.Single(messages);
Assert.Equal("Hello", messages[0].Text);
}
[Fact]
public async Task MessageInvokingAsync_DefaultFilter_ExcludesNonExternalMessagesAsync()
{
// Arrange
this._handler.EnqueueJsonResponse("[]");
var storageScope = new Mem0ProviderScope
{
UserId = "user"
};
var mockSession = new TestAgentSession();
var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope));
var externalMsg = new ChatMessage(ChatRole.User, "External question");
var historyMsg = new ChatMessage(ChatRole.User, "History message")
.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, "src");
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, mockSession, [externalMsg, historyMsg]);
// Act
await sut.InvokingAsync(context);
// Assert - Only External message used for search query
var searchRequest = Assert.Single(this._handler.Requests, r => r.RequestMessage.Method == HttpMethod.Post && ContainsOrdinal(r.RequestMessage.RequestUri!.AbsoluteUri, "/v1/memories/search/"));
using JsonDocument doc = JsonDocument.Parse(searchRequest.RequestBody);
Assert.Equal("External question", doc.RootElement.GetProperty("query").GetString());
}
#endregion
private static bool ContainsOrdinal(string source, string value) => source.IndexOf(value, StringComparison.Ordinal) >= 0;
public void Dispose()
@@ -743,7 +743,7 @@ public sealed class TextSearchProviderTests
SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke,
RecentMessageMemoryLimit = 4
});
await newProvider.InvokingAsync(new(s_mockAgent, restoredSession, new AIContext()), CancellationToken.None); // Trigger search to read memory.
await newProvider.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, restoredSession, new AIContext()), CancellationToken.None); // Trigger search to read memory.
// Assert
Assert.NotNull(capturedInput);
@@ -769,7 +769,7 @@ public sealed class TextSearchProviderTests
SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke,
RecentMessageMemoryLimit = 3
});
await provider.InvokingAsync(new(s_mockAgent, session, new AIContext()), CancellationToken.None);
await provider.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, session, new AIContext()), CancellationToken.None);
// Assert
Assert.NotNull(capturedInput);
@@ -778,6 +778,101 @@ public sealed class TextSearchProviderTests
#endregion
#region MessageAIContextProvider.InvokingAsync Tests
[Fact]
public async Task MessageInvokingAsync_BeforeAIInvoke_SearchesAndReturnsMergedMessagesAsync()
{
// Arrange
List<TextSearchProvider.TextSearchResult> results =
[
new() { SourceName = "Doc1", Text = "Content of Doc1" }
];
Task<IEnumerable<TextSearchProvider.TextSearchResult>> SearchDelegateAsync(string input, CancellationToken ct)
=> Task.FromResult<IEnumerable<TextSearchProvider.TextSearchResult>>(results);
var provider = new TextSearchProvider(SearchDelegateAsync, new TextSearchProviderOptions
{
SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke
});
var inputMsg = new ChatMessage(ChatRole.User, "Question?");
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [inputMsg]);
// Act
var messages = (await provider.InvokingAsync(context)).ToList();
// Assert - input message + search result message, with stamping
Assert.Equal(2, messages.Count);
Assert.Equal("Question?", messages[0].Text);
Assert.Contains("Content of Doc1", messages[1].Text);
Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, messages[1].GetAgentRequestMessageSourceType());
}
[Fact]
public async Task MessageInvokingAsync_OnDemand_ThrowsInvalidOperationExceptionAsync()
{
// Arrange
var provider = new TextSearchProvider(this.NoResultSearchAsync, new TextSearchProviderOptions
{
SearchTime = TextSearchProviderOptions.TextSearchBehavior.OnDemandFunctionCalling,
});
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [new ChatMessage(ChatRole.User, "Q?")]);
// Act & Assert
await Assert.ThrowsAsync<InvalidOperationException>(() => provider.InvokingAsync(context).AsTask());
}
[Fact]
public async Task MessageInvokingAsync_BeforeAIInvoke_NoResults_ReturnsOnlyInputMessagesAsync()
{
// Arrange
var provider = new TextSearchProvider(this.NoResultSearchAsync, new TextSearchProviderOptions
{
SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke
});
var inputMsg = new ChatMessage(ChatRole.User, "Hello");
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [inputMsg]);
// Act
var messages = (await provider.InvokingAsync(context)).ToList();
// Assert
Assert.Single(messages);
Assert.Equal("Hello", messages[0].Text);
}
[Fact]
public async Task MessageInvokingAsync_BeforeAIInvoke_DefaultFilter_ExcludesNonExternalMessagesAsync()
{
// Arrange
string? capturedInput = null;
Task<IEnumerable<TextSearchProvider.TextSearchResult>> SearchDelegateAsync(string input, CancellationToken ct)
{
capturedInput = input;
return Task.FromResult<IEnumerable<TextSearchProvider.TextSearchResult>>([]);
}
var provider = new TextSearchProvider(SearchDelegateAsync, new TextSearchProviderOptions
{
SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke
});
var externalMsg = new ChatMessage(ChatRole.User, "External message");
var historyMsg = new ChatMessage(ChatRole.System, "From history")
.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, "src");
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [externalMsg, historyMsg]);
// Act
await provider.InvokingAsync(context);
// Assert - Only External message used for search query
Assert.Equal("External message", capturedInput);
}
#endregion
private Task<IEnumerable<TextSearchProvider.TextSearchResult>> NoResultSearchAsync(string input, CancellationToken ct)
{
return Task.FromResult<IEnumerable<TextSearchProvider.TextSearchResult>>([]);
@@ -710,6 +710,147 @@ public class ChatHistoryMemoryProviderTests
#endregion
#region MessageAIContextProvider.InvokingAsync Tests
[Fact]
public async Task MessageInvokingAsync_BeforeAIInvoke_SearchesAndReturnsMergedMessagesAsync()
{
// Arrange
var storedItems = new List<VectorSearchResult<Dictionary<string, object?>>>
{
new(
new Dictionary<string, object?>
{
["MessageId"] = "msg-1",
["Content"] = "Previous message",
["Role"] = ChatRole.User.ToString(),
["CreatedAt"] = "2023-01-01T00:00:00.0000000+00:00"
},
0.9f)
};
this._vectorStoreCollectionMock
.Setup(c => c.SearchAsync(
It.IsAny<string>(),
It.IsAny<int>(),
It.IsAny<VectorSearchOptions<Dictionary<string, object?>>>(),
It.IsAny<CancellationToken>()))
.Returns(ToAsyncEnumerableAsync(storedItems));
var provider = new ChatHistoryMemoryProvider(
this._vectorStoreMock.Object,
TestCollectionName,
1,
_ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }),
options: new ChatHistoryMemoryProviderOptions
{
SearchTime = ChatHistoryMemoryProviderOptions.SearchBehavior.BeforeAIInvoke
});
var inputMsg = new ChatMessage(ChatRole.User, "What was discussed?");
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [inputMsg]);
// Act
var messages = (await provider.InvokingAsync(context)).ToList();
// Assert - input message + search result message, with stamping
Assert.Equal(2, messages.Count);
Assert.Equal("What was discussed?", messages[0].Text);
Assert.Contains("Previous message", messages[1].Text);
Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, messages[1].GetAgentRequestMessageSourceType());
}
[Fact]
public async Task MessageInvokingAsync_OnDemand_ThrowsInvalidOperationExceptionAsync()
{
// Arrange
var provider = new ChatHistoryMemoryProvider(
this._vectorStoreMock.Object,
TestCollectionName,
1,
_ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }),
options: new ChatHistoryMemoryProviderOptions
{
SearchTime = ChatHistoryMemoryProviderOptions.SearchBehavior.OnDemandFunctionCalling
});
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [new ChatMessage(ChatRole.User, "Q?")]);
// Act & Assert
await Assert.ThrowsAsync<InvalidOperationException>(() => provider.InvokingAsync(context).AsTask());
}
[Fact]
public async Task MessageInvokingAsync_BeforeAIInvoke_NoResults_ReturnsOnlyInputMessagesAsync()
{
// Arrange
this._vectorStoreCollectionMock
.Setup(c => c.SearchAsync(
It.IsAny<string>(),
It.IsAny<int>(),
It.IsAny<VectorSearchOptions<Dictionary<string, object?>>>(),
It.IsAny<CancellationToken>()))
.Returns(ToAsyncEnumerableAsync(new List<VectorSearchResult<Dictionary<string, object?>>>()));
var provider = new ChatHistoryMemoryProvider(
this._vectorStoreMock.Object,
TestCollectionName,
1,
_ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }),
options: new ChatHistoryMemoryProviderOptions
{
SearchTime = ChatHistoryMemoryProviderOptions.SearchBehavior.BeforeAIInvoke
});
var inputMsg = new ChatMessage(ChatRole.User, "Hello");
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [inputMsg]);
// Act
var messages = (await provider.InvokingAsync(context)).ToList();
// Assert
Assert.Single(messages);
Assert.Equal("Hello", messages[0].Text);
}
[Fact]
public async Task MessageInvokingAsync_BeforeAIInvoke_DefaultFilter_ExcludesNonExternalMessagesAsync()
{
// Arrange
string? capturedQuery = null;
this._vectorStoreCollectionMock
.Setup(c => c.SearchAsync(
It.IsAny<string>(),
It.IsAny<int>(),
It.IsAny<VectorSearchOptions<Dictionary<string, object?>>>(),
It.IsAny<CancellationToken>()))
.Callback<string, int, VectorSearchOptions<Dictionary<string, object?>>, CancellationToken>((query, _, _, _) => capturedQuery = query)
.Returns(ToAsyncEnumerableAsync(new List<VectorSearchResult<Dictionary<string, object?>>>()));
var provider = new ChatHistoryMemoryProvider(
this._vectorStoreMock.Object,
TestCollectionName,
1,
_ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }),
options: new ChatHistoryMemoryProviderOptions
{
SearchTime = ChatHistoryMemoryProviderOptions.SearchBehavior.BeforeAIInvoke
});
var externalMsg = new ChatMessage(ChatRole.User, "External message");
var historyMsg = new ChatMessage(ChatRole.System, "From history")
.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, "src");
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [externalMsg, historyMsg]);
// Act
await provider.InvokingAsync(context);
// Assert - Only External message used for search query
Assert.Equal("External message", capturedQuery);
}
#endregion
private static async IAsyncEnumerable<T> ToAsyncEnumerableAsync<T>(IEnumerable<T> values)
{
await Task.Yield();
@@ -0,0 +1,469 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Moq;
namespace Microsoft.Agents.AI.UnitTests;
/// <summary>
/// Unit tests for the <see cref="MessageAIContextProviderAgent"/> class and
/// the <see cref="AIAgentBuilder.Use(MessageAIContextProvider[])"/> builder extension.
/// </summary>
public class MessageAIContextProviderAgentTests
{
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
#region Constructor Tests
[Fact]
public void Constructor_NullInnerAgent_ThrowsArgumentNullException()
{
// Arrange
var provider = new TestProvider();
// Act & Assert
Assert.Throws<ArgumentNullException>(() => new MessageAIContextProviderAgent(null!, [provider]));
}
[Fact]
public void Constructor_NullProviders_ThrowsArgumentNullException()
{
// Arrange
var agent = CreateTestAgent();
// Act & Assert
Assert.Throws<ArgumentNullException>(() => new MessageAIContextProviderAgent(agent, null!));
}
[Fact]
public void Constructor_EmptyProviders_ThrowsArgumentOutOfRangeException()
{
// Arrange
var agent = CreateTestAgent();
// Act & Assert
Assert.Throws<ArgumentOutOfRangeException>(() => new MessageAIContextProviderAgent(agent, []));
}
#endregion
#region RunAsync Tests
[Fact]
public async Task RunAsync_SingleProvider_EnrichesMessagesAndDelegatesToInnerAgentAsync()
{
// Arrange
var contextMessage = new ChatMessage(ChatRole.System, "Extra context");
var provider = new TestProvider(provideMessages: [contextMessage]);
IEnumerable<ChatMessage>? capturedMessages = null;
var innerAgent = CreateTestAgent(
runFunc: (messages, _, _, _) =>
{
capturedMessages = messages;
return Task.FromResult(new AgentResponse([new ChatMessage(ChatRole.Assistant, "Response")]));
});
var agent = new MessageAIContextProviderAgent(innerAgent, [provider]);
// Act
await agent.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession);
// Assert - inner agent received enriched messages (input + provider's message)
Assert.NotNull(capturedMessages);
var messageList = capturedMessages!.ToList();
Assert.Equal(2, messageList.Count);
Assert.Equal("Hello", messageList[0].Text);
Assert.Contains("Extra context", messageList[1].Text);
}
[Fact]
public async Task RunAsync_MultipleProviders_CalledInSequenceAsync()
{
// Arrange
var provider1 = new TestProvider(provideMessages: [new ChatMessage(ChatRole.System, "From provider 1")]);
var provider2 = new TestProvider(provideMessages: [new ChatMessage(ChatRole.System, "From provider 2")]);
IEnumerable<ChatMessage>? capturedMessages = null;
var innerAgent = CreateTestAgent(
runFunc: (messages, _, _, _) =>
{
capturedMessages = messages;
return Task.FromResult(new AgentResponse([new ChatMessage(ChatRole.Assistant, "Response")]));
});
var agent = new MessageAIContextProviderAgent(innerAgent, [provider1, provider2]);
// Act
await agent.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession);
// Assert - inner agent received messages from both providers in sequence
Assert.NotNull(capturedMessages);
var messageList = capturedMessages!.ToList();
Assert.Equal(3, messageList.Count);
Assert.Equal("Hello", messageList[0].Text);
Assert.Contains("From provider 1", messageList[1].Text);
Assert.Contains("From provider 2", messageList[2].Text);
}
[Fact]
public async Task RunAsync_SequentialProviders_EachReceivesPreviousOutputAsync()
{
// Arrange - provider 2 captures the filtered messages it receives in ProvideMessagesAsync.
// The default filter only includes External messages, so provider 1's stamped messages
// (marked as AIContextProvider) are filtered out before reaching provider 2's ProvideMessagesAsync.
// However, the full unfiltered output from provider 1 is passed to provider 2's InvokingAsync,
// and the inner agent receives the full merged output from both providers.
IEnumerable<ChatMessage>? provider2ReceivedMessages = null;
var provider1 = new TestProvider(provideMessages: [new ChatMessage(ChatRole.System, "From provider 1")]);
var provider2 = new TestProvider(
provideMessages: [new ChatMessage(ChatRole.System, "From provider 2")],
onInvoking: messages => provider2ReceivedMessages = messages.ToList());
var innerAgent = CreateTestAgent(
runFunc: (_, _, _, _) => Task.FromResult(new AgentResponse([new ChatMessage(ChatRole.Assistant, "Response")])));
var agent = new MessageAIContextProviderAgent(innerAgent, [provider1, provider2]);
// Act
await agent.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession);
// Assert - provider 2's ProvideMessagesAsync received only External messages (filtered)
Assert.NotNull(provider2ReceivedMessages);
var received = provider2ReceivedMessages!.ToList();
Assert.Single(received);
Assert.Equal("Hello", received[0].Text);
}
[Fact]
public async Task RunAsync_OnSuccess_InvokedAsyncCalledOnAllProvidersAsync()
{
// Arrange
var provider1 = new TestProvider();
var provider2 = new TestProvider();
var innerAgent = CreateTestAgent(
runFunc: (_, _, _, _) => Task.FromResult(new AgentResponse([new ChatMessage(ChatRole.Assistant, "Response")])));
var agent = new MessageAIContextProviderAgent(innerAgent, [provider1, provider2]);
// Act
await agent.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession);
// Assert
Assert.True(provider1.InvokedAsyncCalled);
Assert.True(provider2.InvokedAsyncCalled);
Assert.Null(provider1.LastInvokedContext!.InvokeException);
Assert.Null(provider2.LastInvokedContext!.InvokeException);
}
[Fact]
public async Task RunAsync_OnFailure_InvokedAsyncCalledWithExceptionAsync()
{
// Arrange
var provider = new TestProvider();
var expectedException = new InvalidOperationException("Agent failed");
var innerAgent = CreateTestAgent(
runFunc: (_, _, _, _) => throw expectedException);
var agent = new MessageAIContextProviderAgent(innerAgent, [provider]);
// Act & Assert
await Assert.ThrowsAsync<InvalidOperationException>(() =>
agent.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession));
Assert.True(provider.InvokedAsyncCalled);
Assert.Same(expectedException, provider.LastInvokedContext!.InvokeException);
}
[Fact]
public async Task RunAsync_OnSuccess_InvokedContextContainsResponseMessagesAsync()
{
// Arrange
var provider = new TestProvider();
var responseMessage = new ChatMessage(ChatRole.Assistant, "Response text");
var innerAgent = CreateTestAgent(
runFunc: (_, _, _, _) => Task.FromResult(new AgentResponse([responseMessage])));
var agent = new MessageAIContextProviderAgent(innerAgent, [provider]);
// Act
await agent.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession);
// Assert
Assert.NotNull(provider.LastInvokedContext?.ResponseMessages);
Assert.Contains(provider.LastInvokedContext!.ResponseMessages!, m => m.Text == "Response text");
}
#endregion
#region RunStreamingAsync Tests
[Fact]
public async Task RunStreamingAsync_SingleProvider_EnrichesMessagesAndStreamsAsync()
{
// Arrange
var contextMessage = new ChatMessage(ChatRole.System, "Extra context");
var provider = new TestProvider(provideMessages: [contextMessage]);
IEnumerable<ChatMessage>? capturedMessages = null;
var innerAgent = CreateTestAgent(
runStreamingFunc: (messages, _, _, _) =>
{
capturedMessages = messages;
return ToAsyncEnumerableAsync(
new AgentResponseUpdate(ChatRole.Assistant, "Part1"),
new AgentResponseUpdate(ChatRole.Assistant, "Part2"));
});
var agent = new MessageAIContextProviderAgent(innerAgent, [provider]);
// Act
var updates = new List<AgentResponseUpdate>();
await foreach (var update in agent.RunStreamingAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession))
{
updates.Add(update);
}
// Assert - streaming updates received
Assert.Equal(2, updates.Count);
// Assert - inner agent received enriched messages
Assert.NotNull(capturedMessages);
var messageList = capturedMessages!.ToList();
Assert.Equal(2, messageList.Count);
}
[Fact]
public async Task RunStreamingAsync_OnSuccess_InvokedAsyncCalledAfterAllUpdatesAsync()
{
// Arrange
var provider = new TestProvider();
var innerAgent = CreateTestAgent(
runStreamingFunc: (_, _, _, _) => ToAsyncEnumerableAsync(
new AgentResponseUpdate(ChatRole.Assistant, "Response")));
var agent = new MessageAIContextProviderAgent(innerAgent, [provider]);
// Act - consume all updates
await foreach (var _ in agent.RunStreamingAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession))
{
}
// Assert
Assert.True(provider.InvokedAsyncCalled);
Assert.Null(provider.LastInvokedContext!.InvokeException);
}
[Fact]
public async Task RunStreamingAsync_OnSuccess_InvokedContextContainsAccumulatedResponseAsync()
{
// Arrange
var provider = new TestProvider();
var innerAgent = CreateTestAgent(
runStreamingFunc: (_, _, _, _) => ToAsyncEnumerableAsync(
new AgentResponseUpdate(ChatRole.Assistant, "Hello "),
new AgentResponseUpdate(ChatRole.Assistant, "World")));
var agent = new MessageAIContextProviderAgent(innerAgent, [provider]);
// Act - consume all updates
await foreach (var _ in agent.RunStreamingAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession))
{
}
// Assert - InvokedAsync received the accumulated response messages
Assert.NotNull(provider.LastInvokedContext?.ResponseMessages);
var responseMessages = provider.LastInvokedContext!.ResponseMessages!.ToList();
Assert.True(responseMessages.Count > 0);
}
[Fact]
public async Task RunStreamingAsync_OnFailure_InvokedAsyncCalledWithExceptionAsync()
{
// Arrange
var provider = new TestProvider();
var expectedException = new InvalidOperationException("Stream failed");
var innerAgent = CreateTestAgent(
runStreamingFunc: (_, _, _, _) => throw expectedException);
var agent = new MessageAIContextProviderAgent(innerAgent, [provider]);
// Act & Assert
await Assert.ThrowsAsync<InvalidOperationException>(async () =>
{
await foreach (var _ in agent.RunStreamingAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession))
{
}
});
Assert.True(provider.InvokedAsyncCalled);
Assert.Same(expectedException, provider.LastInvokedContext!.InvokeException);
}
[Fact]
public async Task RunStreamingAsync_MultipleProviders_CalledInSequenceAsync()
{
// Arrange
var provider1 = new TestProvider(provideMessages: [new ChatMessage(ChatRole.System, "From provider 1")]);
var provider2 = new TestProvider(provideMessages: [new ChatMessage(ChatRole.System, "From provider 2")]);
IEnumerable<ChatMessage>? capturedMessages = null;
var innerAgent = CreateTestAgent(
runStreamingFunc: (messages, _, _, _) =>
{
capturedMessages = messages;
return ToAsyncEnumerableAsync(new AgentResponseUpdate(ChatRole.Assistant, "Response"));
});
var agent = new MessageAIContextProviderAgent(innerAgent, [provider1, provider2]);
// Act
await foreach (var _ in agent.RunStreamingAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession))
{
}
// Assert
Assert.NotNull(capturedMessages);
var messageList = capturedMessages!.ToList();
Assert.Equal(3, messageList.Count);
Assert.Equal("Hello", messageList[0].Text);
Assert.Contains("From provider 1", messageList[1].Text);
Assert.Contains("From provider 2", messageList[2].Text);
}
#endregion
#region Builder Extension Tests
[Fact]
public async Task UseExtension_CreatesWorkingPipelineAsync()
{
// Arrange
var contextMessage = new ChatMessage(ChatRole.System, "Pipeline context");
var provider = new TestProvider(provideMessages: [contextMessage]);
IEnumerable<ChatMessage>? capturedMessages = null;
var innerAgent = CreateTestAgent(
runFunc: (messages, _, _, _) =>
{
capturedMessages = messages;
return Task.FromResult(new AgentResponse([new ChatMessage(ChatRole.Assistant, "Response")]));
});
var pipeline = new AIAgentBuilder(innerAgent)
.Use([provider])
.Build();
// Act
await pipeline.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession);
// Assert
Assert.NotNull(capturedMessages);
var messageList = capturedMessages!.ToList();
Assert.Equal(2, messageList.Count);
Assert.Equal("Hello", messageList[0].Text);
Assert.Contains("Pipeline context", messageList[1].Text);
}
[Fact]
public async Task UseExtension_MultipleProviders_AllAppliedAsync()
{
// Arrange
var provider1 = new TestProvider(provideMessages: [new ChatMessage(ChatRole.System, "P1")]);
var provider2 = new TestProvider(provideMessages: [new ChatMessage(ChatRole.System, "P2")]);
IEnumerable<ChatMessage>? capturedMessages = null;
var innerAgent = CreateTestAgent(
runFunc: (messages, _, _, _) =>
{
capturedMessages = messages;
return Task.FromResult(new AgentResponse([new ChatMessage(ChatRole.Assistant, "Response")]));
});
var pipeline = new AIAgentBuilder(innerAgent)
.Use([provider1, provider2])
.Build();
// Act
await pipeline.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession);
// Assert
Assert.NotNull(capturedMessages);
var messageList = capturedMessages!.ToList();
Assert.Equal(3, messageList.Count);
}
#endregion
#region Helpers
private static TestAIAgent CreateTestAgent(
Func<IEnumerable<ChatMessage>, AgentSession?, AgentRunOptions?, CancellationToken, Task<AgentResponse>>? runFunc = null,
Func<IEnumerable<ChatMessage>, AgentSession?, AgentRunOptions?, CancellationToken, IAsyncEnumerable<AgentResponseUpdate>>? runStreamingFunc = null)
{
var agent = new TestAIAgent();
if (runFunc is not null)
{
agent.RunAsyncFunc = runFunc;
}
if (runStreamingFunc is not null)
{
agent.RunStreamingAsyncFunc = runStreamingFunc;
}
return agent;
}
private static async IAsyncEnumerable<AgentResponseUpdate> ToAsyncEnumerableAsync(params AgentResponseUpdate[] updates)
{
foreach (var update in updates)
{
yield return update;
}
await Task.CompletedTask;
}
/// <summary>
/// A test implementation of <see cref="MessageAIContextProvider"/> that records invocation calls.
/// </summary>
private sealed class TestProvider : MessageAIContextProvider
{
private readonly IEnumerable<ChatMessage> _provideMessages;
private readonly Action<IEnumerable<ChatMessage>>? _onInvoking;
public bool InvokedAsyncCalled { get; private set; }
public InvokedContext? LastInvokedContext { get; private set; }
public TestProvider(
IEnumerable<ChatMessage>? provideMessages = null,
Action<IEnumerable<ChatMessage>>? onInvoking = null)
{
this._provideMessages = provideMessages ?? [];
this._onInvoking = onInvoking;
}
protected override ValueTask<IEnumerable<ChatMessage>> ProvideMessagesAsync(
InvokingContext context,
CancellationToken cancellationToken = default)
{
this._onInvoking?.Invoke(context.RequestMessages);
return new ValueTask<IEnumerable<ChatMessage>>(this._provideMessages);
}
protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default)
{
this.InvokedAsyncCalled = true;
this.LastInvokedContext = context;
return default;
}
}
#endregion
}