mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
.NET: Support a message only AIContextProvider as an AIAgent Decorator (#4009)
* Support a message only AIContextProvider as an AIAgent Decorator * Fix formatting * Address PR comments.
This commit is contained in:
committed by
GitHub
Unverified
parent
5ee06853a1
commit
40d3a0655c
@@ -2,8 +2,9 @@
|
||||
|
||||
// This sample shows multiple middleware layers working together with Azure OpenAI:
|
||||
// chat client (global/per-request), agent run (PII filtering and guardrails),
|
||||
// function invocation (logging and result overrides), and human-in-the-loop
|
||||
// approval workflows for sensitive function calls.
|
||||
// function invocation (logging and result overrides), human-in-the-loop
|
||||
// approval workflows for sensitive function calls, and MessageAIContextProvider
|
||||
// middleware for injecting additional context messages into the agent pipeline.
|
||||
|
||||
using System.ComponentModel;
|
||||
using System.Text.RegularExpressions;
|
||||
@@ -96,6 +97,20 @@ var response = await originalAgent // Using per-request middleware pipeline with
|
||||
|
||||
Console.WriteLine($"Per-request middleware response: {response}");
|
||||
|
||||
// MessageAIContextProvider middleware that injects additional messages into the agent request.
|
||||
// This allows any AIAgent (not just ChatClientAgent) to benefit from MessageAIContextProvider-based
|
||||
// context enrichment. Multiple providers can be passed to Use and they are called in sequence,
|
||||
// each receiving the output of the previous one.
|
||||
Console.WriteLine("\n\n=== Example 5: MessageAIContextProvider middleware ===");
|
||||
|
||||
var contextProviderAgent = originalAgent
|
||||
.AsBuilder()
|
||||
.Use([new DateTimeContextProvider()])
|
||||
.Build();
|
||||
|
||||
var contextResponse = await contextProviderAgent.RunAsync("Is it almost time for lunch?");
|
||||
Console.WriteLine($"Context-enriched response: {contextResponse}");
|
||||
|
||||
// Function invocation middleware that logs before and after function calls.
|
||||
async ValueTask<object?> FunctionCallMiddleware(AIAgent agent, FunctionInvocationContext context, Func<FunctionInvocationContext, CancellationToken, ValueTask<object?>> next, CancellationToken cancellationToken)
|
||||
{
|
||||
@@ -259,3 +274,23 @@ async Task<ChatResponse> PerRequestChatClientMiddleware(IEnumerable<ChatMessage>
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A <see cref="MessageAIContextProvider"/> that injects the current date and time into the agent's context.
|
||||
/// This is a simple example of how to use a MessageAIContextProvider to enrich agent messages
|
||||
/// via the <see cref="AIAgentBuilder.Use(MessageAIContextProvider[])"/> extension method.
|
||||
/// </summary>
|
||||
internal sealed class DateTimeContextProvider : MessageAIContextProvider
|
||||
{
|
||||
protected override ValueTask<IEnumerable<ChatMessage>> ProvideMessagesAsync(
|
||||
InvokingContext context,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
Console.WriteLine("DateTimeContextProvider - Injecting current date/time context");
|
||||
|
||||
return new ValueTask<IEnumerable<ChatMessage>>(
|
||||
[
|
||||
new ChatMessage(ChatRole.User, $"For reference, the current date and time is: {DateTimeOffset.Now}")
|
||||
]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ This sample demonstrates how to add middleware to intercept:
|
||||
5. Per‑request chat client middleware
|
||||
6. Per‑request function pipeline with approval
|
||||
7. Combining agent‑level and per‑request middleware
|
||||
8. MessageAIContextProvider middleware via `AIAgentBuilder.Use(...)` for injecting additional context messages
|
||||
|
||||
## Function Invocation Middleware
|
||||
|
||||
|
||||
@@ -146,11 +146,11 @@ namespace SampleApp
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// An <see cref="AIContextProvider"/> which searches for upcoming calendar events and adds them to the AI context.
|
||||
/// A <see cref="MessageAIContextProvider"/> which searches for upcoming calendar events and adds them to the AI context.
|
||||
/// </summary>
|
||||
internal sealed class CalendarSearchAIContextProvider(Func<Task<string[]>> loadNextThreeCalendarEvents) : AIContextProvider
|
||||
internal sealed class CalendarSearchAIContextProvider(Func<Task<string[]>> loadNextThreeCalendarEvents) : MessageAIContextProvider
|
||||
{
|
||||
protected override async ValueTask<AIContext> ProvideAIContextAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
protected override async ValueTask<IEnumerable<MEAI.ChatMessage>> ProvideMessagesAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var events = await loadNextThreeCalendarEvents();
|
||||
|
||||
@@ -161,10 +161,7 @@ namespace SampleApp
|
||||
outputMessageBuilder.AppendLine($" - {calendarEvent}");
|
||||
}
|
||||
|
||||
return new AIContext
|
||||
{
|
||||
Messages = [new MEAI.ChatMessage(ChatRole.User, outputMessageBuilder.ToString())]
|
||||
};
|
||||
return [new MEAI.ChatMessage(ChatRole.User, outputMessageBuilder.ToString())];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,9 +34,6 @@ public abstract class AIContextProvider
|
||||
private static IEnumerable<ChatMessage> DefaultExternalOnlyFilter(IEnumerable<ChatMessage> messages)
|
||||
=> messages.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External);
|
||||
|
||||
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> _provideInputMessageFilter;
|
||||
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> _storeInputMessageFilter;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="AIContextProvider"/> class.
|
||||
/// </summary>
|
||||
@@ -46,10 +43,20 @@ public abstract class AIContextProvider
|
||||
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? provideInputMessageFilter = null,
|
||||
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? storeInputMessageFilter = null)
|
||||
{
|
||||
this._provideInputMessageFilter = provideInputMessageFilter ?? DefaultExternalOnlyFilter;
|
||||
this._storeInputMessageFilter = storeInputMessageFilter ?? DefaultExternalOnlyFilter;
|
||||
this.ProvideInputMessageFilter = provideInputMessageFilter ?? DefaultExternalOnlyFilter;
|
||||
this.StoreInputMessageFilter = storeInputMessageFilter ?? DefaultExternalOnlyFilter;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the filter function to apply to input messages before providing context via <see cref="ProvideAIContextAsync"/>.
|
||||
/// </summary>
|
||||
protected Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> ProvideInputMessageFilter { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the filter function to apply to request messages before storing context via <see cref="StoreAIContextAsync"/>.
|
||||
/// </summary>
|
||||
protected Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> StoreInputMessageFilter { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the key used to store the provider state in the <see cref="AgentSession.StateBag"/>.
|
||||
/// </summary>
|
||||
@@ -120,7 +127,7 @@ public abstract class AIContextProvider
|
||||
new AIContext
|
||||
{
|
||||
Instructions = inputContext.Instructions,
|
||||
Messages = inputContext.Messages is not null ? this._provideInputMessageFilter(inputContext.Messages) : null,
|
||||
Messages = inputContext.Messages is not null ? this.ProvideInputMessageFilter(inputContext.Messages) : null,
|
||||
Tools = inputContext.Tools
|
||||
});
|
||||
|
||||
@@ -254,7 +261,7 @@ public abstract class AIContextProvider
|
||||
return default;
|
||||
}
|
||||
|
||||
var subContext = new InvokedContext(context.Agent, context.Session, this._storeInputMessageFilter(context.RequestMessages), context.ResponseMessages!);
|
||||
var subContext = new InvokedContext(context.Agent, context.Session, this.StoreInputMessageFilter(context.RequestMessages), context.ResponseMessages!);
|
||||
return this.StoreAIContextAsync(subContext, cancellationToken);
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,203 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.AI;
|
||||
using Microsoft.Shared.Diagnostics;
|
||||
|
||||
namespace Microsoft.Agents.AI;
|
||||
|
||||
/// <summary>
|
||||
/// Provides an abstract base class for components that enhance AI context during agent invocations by supplying additional chat messages.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// <para>
|
||||
/// A message AI context provider is a component that participates in the agent invocation lifecycle by:
|
||||
/// <list type="bullet">
|
||||
/// <item><description>Listening to changes in conversations</description></item>
|
||||
/// <item><description>Providing additional messages to agents during invocation</description></item>
|
||||
/// <item><description>Processing invocation results for state management or learning</description></item>
|
||||
/// </list>
|
||||
/// </para>
|
||||
/// <para>
|
||||
/// Context providers operate through a two-phase lifecycle: they are called at the start of invocation via
|
||||
/// <see cref="AIContextProvider.InvokingAsync"/> to provide context, and optionally called at the end of invocation via
|
||||
/// <see cref="AIContextProvider.InvokedAsync"/> to process results.
|
||||
/// </para>
|
||||
/// </remarks>
|
||||
public abstract class MessageAIContextProvider : AIContextProvider
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="MessageAIContextProvider"/> class.
|
||||
/// </summary>
|
||||
/// <param name="provideInputMessageFilter">An optional filter function to apply to input messages before providing messages via <see cref="ProvideMessagesAsync"/>. If not set, defaults to including only <see cref="AgentRequestMessageSourceType.External"/> messages.</param>
|
||||
/// <param name="storeInputMessageFilter">An optional filter function to apply to request messages before storing messages via <see cref="AIContextProvider.StoreAIContextAsync"/>. If not set, defaults to including only <see cref="AgentRequestMessageSourceType.External"/> messages.</param>
|
||||
protected MessageAIContextProvider(
|
||||
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? provideInputMessageFilter = null,
|
||||
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? storeInputMessageFilter = null)
|
||||
: base(provideInputMessageFilter, storeInputMessageFilter)
|
||||
{
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected override async ValueTask<AIContext> ProvideAIContextAsync(AIContextProvider.InvokingContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
// Call ProvideMessagesAsync directly to return only additional messages.
|
||||
// The base AIContextProvider.InvokingCoreAsync handles merging with the original input and stamping.
|
||||
return new AIContext
|
||||
{
|
||||
Messages = await this.ProvideMessagesAsync(
|
||||
new InvokingContext(context.Agent, context.Session, context.AIContext.Messages ?? []),
|
||||
cancellationToken).ConfigureAwait(false)
|
||||
};
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Called at the start of agent invocation to provide additional messages.
|
||||
/// </summary>
|
||||
/// <param name="context">Contains the request context including the caller provided messages that will be used by the agent for this invocation.</param>
|
||||
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
|
||||
/// <returns>A task that represents the asynchronous operation. The task result contains the <see cref="IEnumerable{ChatMessage}"/> to be used by the agent during this invocation.</returns>
|
||||
/// <remarks>
|
||||
/// <para>
|
||||
/// Implementers can load any additional messages required at this time, such as:
|
||||
/// <list type="bullet">
|
||||
/// <item><description>Retrieving relevant information from knowledge bases</description></item>
|
||||
/// <item><description>Adding system instructions or prompts</description></item>
|
||||
/// <item><description>Injecting contextual messages from conversation history</description></item>
|
||||
/// </list>
|
||||
/// </para>
|
||||
/// </remarks>
|
||||
public ValueTask<IEnumerable<ChatMessage>> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
=> this.InvokingCoreAsync(Throw.IfNull(context), cancellationToken);
|
||||
|
||||
/// <summary>
|
||||
/// Called at the start of agent invocation to provide additional messages.
|
||||
/// </summary>
|
||||
/// <param name="context">Contains the request context including the caller provided messages that will be used by the agent for this invocation.</param>
|
||||
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
|
||||
/// <returns>A task that represents the asynchronous operation. The task result contains the <see cref="IEnumerable{ChatMessage}"/> to be used by the agent during this invocation.</returns>
|
||||
/// <remarks>
|
||||
/// <para>
|
||||
/// Implementers can load any additional messages required at this time, such as:
|
||||
/// <list type="bullet">
|
||||
/// <item><description>Retrieving relevant information from knowledge bases</description></item>
|
||||
/// <item><description>Adding system instructions or prompts</description></item>
|
||||
/// <item><description>Injecting contextual messages from conversation history</description></item>
|
||||
/// </list>
|
||||
/// </para>
|
||||
/// <para>
|
||||
/// The default implementation of this method filters the input messages using the configured provide-input message filter
|
||||
/// (which defaults to including only <see cref="AgentRequestMessageSourceType.External"/> messages),
|
||||
/// then calls <see cref="ProvideMessagesAsync"/> to get additional messages,
|
||||
/// stamps any messages with <see cref="AgentRequestMessageSourceType.AIContextProvider"/> source attribution,
|
||||
/// and merges the returned messages with the original (unfiltered) input messages.
|
||||
/// For most scenarios, overriding <see cref="ProvideMessagesAsync"/> is sufficient to provide additional messages,
|
||||
/// while still benefiting from the default filtering, merging and source stamping behavior.
|
||||
/// However, for scenarios that require more control over message filtering, merging or source stamping, overriding this method
|
||||
/// allows you to directly control the full <see cref="IEnumerable{ChatMessage}"/> returned for the invocation.
|
||||
/// </para>
|
||||
/// </remarks>
|
||||
protected virtual async ValueTask<IEnumerable<ChatMessage>> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var inputMessages = context.RequestMessages;
|
||||
|
||||
// Create a filtered context for ProvideMessagesAsync, filtering input messages
|
||||
// to exclude non-external messages (e.g. chat history, other AI context provider messages).
|
||||
var filteredContext = new InvokingContext(
|
||||
context.Agent,
|
||||
context.Session,
|
||||
this.ProvideInputMessageFilter(inputMessages));
|
||||
|
||||
var providedMessages = await this.ProvideMessagesAsync(filteredContext, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
// Stamp and merge provided messages.
|
||||
providedMessages = providedMessages.Select(m => m.WithAgentRequestMessageSource(AgentRequestMessageSourceType.AIContextProvider, this.GetType().FullName!));
|
||||
return inputMessages.Concat(providedMessages);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// When overridden in a derived class, provides additional messages to be merged with the input messages for the current invocation.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// <para>
|
||||
/// This method is called from <see cref="InvokingCoreAsync(InvokingContext, CancellationToken)"/>.
|
||||
/// Note that <see cref="InvokingCoreAsync(InvokingContext, CancellationToken)"/> can be overridden to directly control messages merging and source stamping, in which case
|
||||
/// it is up to the implementer to call this method as needed to retrieve the additional messages.
|
||||
/// </para>
|
||||
/// <para>
|
||||
/// In contrast with <see cref="InvokingCoreAsync(InvokingContext, CancellationToken)"/>, this method only returns additional messages to be merged with the input,
|
||||
/// while <see cref="InvokingCoreAsync(InvokingContext, CancellationToken)"/> is responsible for returning the full merged <see cref="IEnumerable{ChatMessage}"/> for the invocation.
|
||||
/// </para>
|
||||
/// </remarks>
|
||||
/// <param name="context">Contains the request context including the caller provided messages that will be used by the agent for this invocation.</param>
|
||||
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
|
||||
/// <returns>
|
||||
/// A task that represents the asynchronous operation. The task result contains an <see cref="IEnumerable{ChatMessage}"/>
|
||||
/// with additional messages to be merged with the input messages.
|
||||
/// </returns>
|
||||
protected virtual ValueTask<IEnumerable<ChatMessage>> ProvideMessagesAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
return new ValueTask<IEnumerable<ChatMessage>>([]);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Contains the context information provided to <see cref="InvokingCoreAsync(InvokingContext, CancellationToken)"/>.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// This class provides context about the invocation before the underlying AI model is invoked, including the messages
|
||||
/// that will be used. Message AI Context providers can use this information to determine what additional messages
|
||||
/// should be provided for the invocation.
|
||||
/// </remarks>
|
||||
public new sealed class InvokingContext
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="InvokingContext"/> class with the specified request messages.
|
||||
/// </summary>
|
||||
/// <param name="agent">The agent being invoked.</param>
|
||||
/// <param name="session">The session associated with the agent invocation.</param>
|
||||
/// <param name="requestMessages">The messages to be used by the agent for this invocation.</param>
|
||||
/// <exception cref="ArgumentNullException"><paramref name="agent"/> or <paramref name="requestMessages"/> is <see langword="null"/>.</exception>
|
||||
public InvokingContext(
|
||||
AIAgent agent,
|
||||
AgentSession? session,
|
||||
IEnumerable<ChatMessage> requestMessages)
|
||||
{
|
||||
this.Agent = Throw.IfNull(agent);
|
||||
this.Session = session;
|
||||
this.RequestMessages = Throw.IfNull(requestMessages);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the agent that is being invoked.
|
||||
/// </summary>
|
||||
public AIAgent Agent { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the agent session associated with the agent invocation.
|
||||
/// </summary>
|
||||
public AgentSession? Session { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the messages that will be used by the agent for this invocation. <see cref="MessageAIContextProvider"/> instances can modify
|
||||
/// and return or return a new message list to add additional messages for the invocation.
|
||||
/// </summary>
|
||||
/// <value>
|
||||
/// A collection of <see cref="ChatMessage"/> instances representing the messages that will be used by the agent for this invocation.
|
||||
/// </value>
|
||||
/// <remarks>
|
||||
/// <para>
|
||||
/// If multiple <see cref="MessageAIContextProvider"/> instances are used in the same invocation, each <see cref="MessageAIContextProvider"/>
|
||||
/// will receive the messages returned by the previous <see cref="MessageAIContextProvider"/> allowing them to build on top of each other's context.
|
||||
/// </para>
|
||||
/// <para>
|
||||
/// The first <see cref="MessageAIContextProvider"/> in the invocation pipeline will receive the
|
||||
/// caller provided messages.
|
||||
/// </para>
|
||||
/// </remarks>
|
||||
public IEnumerable<ChatMessage> RequestMessages { get; set { field = Throw.IfNull(value); } }
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
using System;
|
||||
using System.Text.Json;
|
||||
using Microsoft.Shared.Diagnostics;
|
||||
|
||||
namespace Microsoft.Agents.AI;
|
||||
|
||||
@@ -39,8 +40,8 @@ public class ProviderSessionState<TState>
|
||||
string stateKey,
|
||||
JsonSerializerOptions? jsonSerializerOptions = null)
|
||||
{
|
||||
this._stateInitializer = stateInitializer;
|
||||
this.StateKey = stateKey;
|
||||
this._stateInitializer = Throw.IfNull(stateInitializer);
|
||||
this.StateKey = Throw.IfNullOrWhitespace(stateKey);
|
||||
this._jsonSerializerOptions = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions;
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ using Microsoft.Shared.Diagnostics;
|
||||
namespace Microsoft.Agents.AI.Mem0;
|
||||
|
||||
/// <summary>
|
||||
/// Provides a Mem0 backed <see cref="AIContextProvider"/> that persists conversation messages as memories
|
||||
/// Provides a Mem0 backed <see cref="MessageAIContextProvider"/> that persists conversation messages as memories
|
||||
/// and retrieves related memories to augment the agent invocation context.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
@@ -22,7 +22,7 @@ namespace Microsoft.Agents.AI.Mem0;
|
||||
/// for new invocations using a semantic search endpoint. Retrieved memories are injected as user messages
|
||||
/// to the model, prefixed by a configurable context prompt.
|
||||
/// </remarks>
|
||||
public sealed class Mem0Provider : AIContextProvider
|
||||
public sealed class Mem0Provider : MessageAIContextProvider
|
||||
{
|
||||
private const string DefaultContextPrompt = "## Memories\nConsider the following memories when answering user questions:";
|
||||
|
||||
@@ -92,7 +92,7 @@ public sealed class Mem0Provider : AIContextProvider
|
||||
};
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override async ValueTask<AIContext> ProvideAIContextAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
protected override async ValueTask<IEnumerable<ChatMessage>> ProvideMessagesAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
Throw.IfNull(context);
|
||||
|
||||
@@ -101,7 +101,7 @@ public sealed class Mem0Provider : AIContextProvider
|
||||
|
||||
string queryText = string.Join(
|
||||
Environment.NewLine,
|
||||
(context.AIContext.Messages ?? [])
|
||||
context.RequestMessages
|
||||
.Where(m => !string.IsNullOrWhiteSpace(m.Text))
|
||||
.Select(m => m.Text));
|
||||
|
||||
@@ -142,12 +142,9 @@ public sealed class Mem0Provider : AIContextProvider
|
||||
}
|
||||
}
|
||||
|
||||
return new AIContext
|
||||
{
|
||||
Messages = outputMessageText is not null
|
||||
? [new ChatMessage(ChatRole.User, outputMessageText)]
|
||||
: null
|
||||
};
|
||||
return outputMessageText is not null
|
||||
? [new ChatMessage(ChatRole.User, outputMessageText)]
|
||||
: [];
|
||||
}
|
||||
catch (ArgumentException)
|
||||
{
|
||||
@@ -166,7 +163,7 @@ public sealed class Mem0Provider : AIContextProvider
|
||||
this.SanitizeLogData(searchScope.UserId));
|
||||
}
|
||||
|
||||
return new AIContext();
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -151,6 +151,32 @@ public sealed class AIAgentBuilder
|
||||
return this.Use((innerAgent, _) => new AnonymousDelegatingAIAgent(innerAgent, runFunc, runStreamingFunc));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds one or more <see cref="MessageAIContextProvider"/> instances to the agent pipeline, enabling message enrichment
|
||||
/// for any <see cref="AIAgent"/>.
|
||||
/// </summary>
|
||||
/// <param name="providers">
|
||||
/// The <see cref="MessageAIContextProvider"/> instances to invoke before and after each agent invocation.
|
||||
/// Providers are called in sequence, with each receiving the output of the previous provider.
|
||||
/// </param>
|
||||
/// <returns>The <see cref="AIAgentBuilder"/> with the providers added, enabling method chaining.</returns>
|
||||
/// <exception cref="ArgumentException"><paramref name="providers"/> is empty.</exception>
|
||||
/// <remarks>
|
||||
/// <para>
|
||||
/// This method wraps the inner agent with a <see cref="DelegatingAIAgent"/> that calls each provider's
|
||||
/// <see cref="MessageAIContextProvider.InvokingAsync"/> in sequence before the inner agent runs,
|
||||
/// and calls <see cref="AIContextProvider.InvokedAsync"/> on each provider after the inner agent completes.
|
||||
/// </para>
|
||||
/// <para>
|
||||
/// This allows any <see cref="AIAgent"/> to benefit from <see cref="MessageAIContextProvider"/>-based
|
||||
/// context enrichment, not just agents that natively support <see cref="AIContextProvider"/> instances.
|
||||
/// </para>
|
||||
/// </remarks>
|
||||
public AIAgentBuilder Use(MessageAIContextProvider[] providers)
|
||||
{
|
||||
return this.Use((innerAgent, _) => new MessageAIContextProviderAgent(innerAgent, providers));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Provides an empty <see cref="IServiceProvider"/> implementation.
|
||||
/// </summary>
|
||||
|
||||
@@ -34,7 +34,7 @@ namespace Microsoft.Agents.AI;
|
||||
/// injecting them automatically on each invocation.
|
||||
/// </para>
|
||||
/// </remarks>
|
||||
public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable
|
||||
public sealed class ChatHistoryMemoryProvider : MessageAIContextProvider, IDisposable
|
||||
{
|
||||
private const string DefaultContextPrompt = "## Memories\nConsider the following memories when answering user questions:";
|
||||
private const int DefaultMaxResults = 3;
|
||||
@@ -119,7 +119,7 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable
|
||||
public override string StateKey => this._sessionState.StateKey;
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override async ValueTask<AIContext> ProvideAIContextAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
protected override async ValueTask<AIContext> ProvideAIContextAsync(AIContextProvider.InvokingContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
_ = Throw.IfNull(context);
|
||||
|
||||
@@ -147,17 +147,46 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable
|
||||
};
|
||||
}
|
||||
|
||||
return new AIContext
|
||||
{
|
||||
Messages = await this.ProvideMessagesAsync(
|
||||
new InvokingContext(context.Agent, context.Session, context.AIContext.Messages ?? []),
|
||||
cancellationToken).ConfigureAwait(false)
|
||||
};
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override ValueTask<IEnumerable<ChatMessage>> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
// This code path is invoked using InvokingAsync on MessageAIContextProvider, which does not support tools and instructions,
|
||||
// and OnDemandFunctionCalling requires tools.
|
||||
if (this._searchTime != ChatHistoryMemoryProviderOptions.SearchBehavior.BeforeAIInvoke)
|
||||
{
|
||||
throw new InvalidOperationException($"Using the {nameof(ChatHistoryMemoryProvider)} as a {nameof(MessageAIContextProvider)} is not supported when {nameof(ChatHistoryMemoryProviderOptions.SearchTime)} is set to {ChatHistoryMemoryProviderOptions.SearchBehavior.OnDemandFunctionCalling}.");
|
||||
}
|
||||
|
||||
return base.InvokingCoreAsync(context, cancellationToken);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override async ValueTask<IEnumerable<ChatMessage>> ProvideMessagesAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
_ = Throw.IfNull(context);
|
||||
|
||||
var state = this._sessionState.GetOrInitializeState(context.Session);
|
||||
var searchScope = state.SearchScope;
|
||||
|
||||
try
|
||||
{
|
||||
// Get the text from the current request messages
|
||||
var requestText = string.Join("\n",
|
||||
(context.AIContext.Messages ?? [])
|
||||
(context.RequestMessages ?? [])
|
||||
.Where(m => m != null && !string.IsNullOrWhiteSpace(m.Text))
|
||||
.Select(m => m.Text));
|
||||
|
||||
if (string.IsNullOrWhiteSpace(requestText))
|
||||
{
|
||||
return new AIContext();
|
||||
return [];
|
||||
}
|
||||
|
||||
// Search for relevant chat history
|
||||
@@ -165,13 +194,10 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable
|
||||
|
||||
if (string.IsNullOrWhiteSpace(contextText))
|
||||
{
|
||||
return new AIContext();
|
||||
return [];
|
||||
}
|
||||
|
||||
return new AIContext
|
||||
{
|
||||
Messages = [new ChatMessage(ChatRole.User, contextText)]
|
||||
};
|
||||
return [new ChatMessage(ChatRole.User, contextText)];
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
@@ -186,7 +212,7 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable
|
||||
this.SanitizeLogData(searchScope.UserId));
|
||||
}
|
||||
|
||||
return new AIContext();
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,167 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.AI;
|
||||
using Microsoft.Shared.Diagnostics;
|
||||
|
||||
namespace Microsoft.Agents.AI;
|
||||
|
||||
/// <summary>
|
||||
/// A delegating AI agent that enriches input messages by invoking a pipeline of <see cref="MessageAIContextProvider"/> instances
|
||||
/// before delegating to the inner agent, and notifies those providers after the inner agent completes.
|
||||
/// </summary>
|
||||
internal sealed class MessageAIContextProviderAgent : DelegatingAIAgent
|
||||
{
|
||||
private readonly IReadOnlyList<MessageAIContextProvider> _providers;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="MessageAIContextProviderAgent"/> class.
|
||||
/// </summary>
|
||||
/// <param name="innerAgent">The underlying agent instance that will handle the core operations.</param>
|
||||
/// <param name="providers">The message AI context providers to invoke before and after the inner agent.</param>
|
||||
public MessageAIContextProviderAgent(AIAgent innerAgent, IReadOnlyList<MessageAIContextProvider> providers)
|
||||
: base(innerAgent)
|
||||
{
|
||||
Throw.IfNull(providers);
|
||||
Throw.IfLessThanOrEqual(providers.Count, 0, nameof(providers));
|
||||
|
||||
this._providers = providers;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected override async Task<AgentResponse> RunCoreAsync(
|
||||
IEnumerable<ChatMessage> messages,
|
||||
AgentSession? session = null,
|
||||
AgentRunOptions? options = null,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var enrichedMessages = await this.InvokeProvidersAsync(messages, session, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
AgentResponse response;
|
||||
try
|
||||
{
|
||||
response = await this.InnerAgent.RunAsync(enrichedMessages, session, options, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
await this.NotifyProvidersOfFailureAsync(session, enrichedMessages, ex, cancellationToken).ConfigureAwait(false);
|
||||
throw;
|
||||
}
|
||||
|
||||
await this.NotifyProvidersOfSuccessAsync(session, enrichedMessages, response.Messages, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected override async IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingAsync(
|
||||
IEnumerable<ChatMessage> messages,
|
||||
AgentSession? session = null,
|
||||
AgentRunOptions? options = null,
|
||||
[EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
var enrichedMessages = await this.InvokeProvidersAsync(messages, session, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
List<AgentResponseUpdate> responseUpdates = [];
|
||||
|
||||
IAsyncEnumerator<AgentResponseUpdate> enumerator;
|
||||
try
|
||||
{
|
||||
enumerator = this.InnerAgent.RunStreamingAsync(enrichedMessages, session, options, cancellationToken).GetAsyncEnumerator(cancellationToken);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
await this.NotifyProvidersOfFailureAsync(session, enrichedMessages, ex, cancellationToken).ConfigureAwait(false);
|
||||
throw;
|
||||
}
|
||||
|
||||
bool hasUpdates;
|
||||
try
|
||||
{
|
||||
hasUpdates = await enumerator.MoveNextAsync().ConfigureAwait(false);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
await this.NotifyProvidersOfFailureAsync(session, enrichedMessages, ex, cancellationToken).ConfigureAwait(false);
|
||||
throw;
|
||||
}
|
||||
|
||||
while (hasUpdates)
|
||||
{
|
||||
var update = enumerator.Current;
|
||||
responseUpdates.Add(update);
|
||||
yield return update;
|
||||
|
||||
try
|
||||
{
|
||||
hasUpdates = await enumerator.MoveNextAsync().ConfigureAwait(false);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
await this.NotifyProvidersOfFailureAsync(session, enrichedMessages, ex, cancellationToken).ConfigureAwait(false);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
var agentResponse = responseUpdates.ToAgentResponse();
|
||||
await this.NotifyProvidersOfSuccessAsync(session, enrichedMessages, agentResponse.Messages, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Invokes each provider's <see cref="MessageAIContextProvider.InvokingAsync"/> in sequence,
|
||||
/// passing the output of each as input to the next.
|
||||
/// </summary>
|
||||
private async Task<IEnumerable<ChatMessage>> InvokeProvidersAsync(
|
||||
IEnumerable<ChatMessage> messages,
|
||||
AgentSession? session,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
var currentMessages = messages;
|
||||
|
||||
foreach (var provider in this._providers)
|
||||
{
|
||||
var context = new MessageAIContextProvider.InvokingContext(this, session, currentMessages);
|
||||
currentMessages = await provider.InvokingAsync(context, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
return currentMessages;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Notifies each provider of a successful invocation.
|
||||
/// </summary>
|
||||
private async Task NotifyProvidersOfSuccessAsync(
|
||||
AgentSession? session,
|
||||
IEnumerable<ChatMessage> requestMessages,
|
||||
IEnumerable<ChatMessage> responseMessages,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
var invokedContext = new AIContextProvider.InvokedContext(this, session, requestMessages, responseMessages);
|
||||
|
||||
foreach (var provider in this._providers)
|
||||
{
|
||||
await provider.InvokedAsync(invokedContext, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Notifies each provider of a failed invocation.
|
||||
/// </summary>
|
||||
private async Task NotifyProvidersOfFailureAsync(
|
||||
AgentSession? session,
|
||||
IEnumerable<ChatMessage> requestMessages,
|
||||
Exception exception,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
var invokedContext = new AIContextProvider.InvokedContext(this, session, requestMessages, exception);
|
||||
|
||||
foreach (var provider in this._providers)
|
||||
{
|
||||
await provider.InvokedAsync(invokedContext, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -32,7 +32,7 @@ namespace Microsoft.Agents.AI;
|
||||
/// multi-turn context to the retrieval layer without permanently altering the conversation history.
|
||||
/// </para>
|
||||
/// </remarks>
|
||||
public sealed class TextSearchProvider : AIContextProvider
|
||||
public sealed class TextSearchProvider : MessageAIContextProvider
|
||||
{
|
||||
private const string DefaultPluginSearchFunctionName = "Search";
|
||||
private const string DefaultPluginSearchFunctionDescription = "Allows searching for additional information to help answer the user question.";
|
||||
@@ -91,7 +91,7 @@ public sealed class TextSearchProvider : AIContextProvider
|
||||
public override string StateKey => this._sessionState.StateKey;
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override async ValueTask<AIContext> ProvideAIContextAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
protected override async ValueTask<AIContext> ProvideAIContextAsync(AIContextProvider.InvokingContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (this._searchTime != TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke)
|
||||
{
|
||||
@@ -102,6 +102,30 @@ public sealed class TextSearchProvider : AIContextProvider
|
||||
};
|
||||
}
|
||||
|
||||
return new AIContext
|
||||
{
|
||||
Messages = await this.ProvideMessagesAsync(
|
||||
new InvokingContext(context.Agent, context.Session, context.AIContext.Messages ?? []),
|
||||
cancellationToken).ConfigureAwait(false)
|
||||
};
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override ValueTask<IEnumerable<ChatMessage>> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
// This code path is invoked using InvokingAsync on MessageAIContextProvider, which does not support tools and instructions,
|
||||
// and OnDemandFunctionCalling requires tools.
|
||||
if (this._searchTime != TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke)
|
||||
{
|
||||
throw new InvalidOperationException($"Using the {nameof(TextSearchProvider)} as a {nameof(MessageAIContextProvider)} is not supported when {nameof(TextSearchProviderOptions.SearchTime)} is set to {TextSearchProviderOptions.TextSearchBehavior.OnDemandFunctionCalling}.");
|
||||
}
|
||||
|
||||
return base.InvokingCoreAsync(context, cancellationToken);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override async ValueTask<IEnumerable<ChatMessage>> ProvideMessagesAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
// Retrieve recent messages from the session state.
|
||||
var recentMessagesText = this._sessionState.GetOrInitializeState(context.Session).RecentMessagesText
|
||||
?? [];
|
||||
@@ -109,7 +133,7 @@ public sealed class TextSearchProvider : AIContextProvider
|
||||
// Aggregate text from memory + current request messages.
|
||||
var sbInput = new StringBuilder();
|
||||
var requestMessagesText =
|
||||
(context.AIContext.Messages ?? [])
|
||||
(context.RequestMessages ?? [])
|
||||
.Where(x => !string.IsNullOrWhiteSpace(x?.Text)).Select(x => x.Text);
|
||||
foreach (var messageText in recentMessagesText.Concat(requestMessagesText))
|
||||
{
|
||||
@@ -135,7 +159,7 @@ public sealed class TextSearchProvider : AIContextProvider
|
||||
|
||||
if (materialized.Count == 0)
|
||||
{
|
||||
return new AIContext();
|
||||
return [];
|
||||
}
|
||||
|
||||
// Format search results
|
||||
@@ -146,15 +170,12 @@ public sealed class TextSearchProvider : AIContextProvider
|
||||
this._logger.LogTrace("TextSearchProvider: Search Results\nInput:{Input}\nOutput:{MessageText}", input, formatted);
|
||||
}
|
||||
|
||||
return new AIContext
|
||||
{
|
||||
Messages = [new ChatMessage(ChatRole.User, formatted)]
|
||||
};
|
||||
return [new ChatMessage(ChatRole.User, formatted)];
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
this._logger?.LogError(ex, "TextSearchProvider: Failed to search for data due to error");
|
||||
return new AIContext();
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+323
@@ -0,0 +1,323 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.AI;
|
||||
using Moq;
|
||||
|
||||
namespace Microsoft.Agents.AI.Abstractions.UnitTests;
|
||||
|
||||
/// <summary>
|
||||
/// Contains tests for the <see cref="MessageAIContextProvider"/> class.
|
||||
/// </summary>
|
||||
public class MessageAIContextProviderTests
|
||||
{
|
||||
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
|
||||
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
|
||||
|
||||
#region InvokingAsync Tests
|
||||
|
||||
[Fact]
|
||||
public async Task InvokingAsync_NullContext_ThrowsArgumentNullExceptionAsync()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TestMessageProvider();
|
||||
|
||||
// Act & Assert
|
||||
await Assert.ThrowsAsync<ArgumentNullException>(() => provider.InvokingAsync(null!).AsTask());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokingAsync_ReturnsInputAndProvidedMessagesAsync()
|
||||
{
|
||||
// Arrange
|
||||
var providedMessages = new[] { new ChatMessage(ChatRole.System, "Context message") };
|
||||
var provider = new TestMessageProvider(provideMessages: providedMessages);
|
||||
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "User input")]);
|
||||
|
||||
// Act
|
||||
var result = (await provider.InvokingAsync(context)).ToList();
|
||||
|
||||
// Assert - input messages + provided messages merged
|
||||
Assert.Equal(2, result.Count);
|
||||
Assert.Equal("User input", result[0].Text);
|
||||
Assert.Equal("Context message", result[1].Text);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokingAsync_ReturnsOnlyInputMessages_WhenNoMessagesProvidedAsync()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new DefaultMessageProvider();
|
||||
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Hello")]);
|
||||
|
||||
// Act
|
||||
var result = (await provider.InvokingAsync(context)).ToList();
|
||||
|
||||
// Assert
|
||||
Assert.Single(result);
|
||||
Assert.Equal("Hello", result[0].Text);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokingAsync_StampsProvidedMessagesWithAIContextProviderSourceAsync()
|
||||
{
|
||||
// Arrange
|
||||
var providedMessages = new[] { new ChatMessage(ChatRole.System, "Provided") };
|
||||
var provider = new TestMessageProvider(provideMessages: providedMessages);
|
||||
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, []);
|
||||
|
||||
// Act
|
||||
var result = (await provider.InvokingAsync(context)).ToList();
|
||||
|
||||
// Assert
|
||||
Assert.Single(result);
|
||||
Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, result[0].GetAgentRequestMessageSourceType());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokingAsync_FiltersInputToExternalOnlyByDefaultAsync()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TestMessageProvider(captureFilteredContext: true);
|
||||
var externalMsg = new ChatMessage(ChatRole.User, "External");
|
||||
var chatHistoryMsg = new ChatMessage(ChatRole.User, "History")
|
||||
.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, "src");
|
||||
var contextProviderMsg = new ChatMessage(ChatRole.User, "ContextProvider")
|
||||
.WithAgentRequestMessageSource(AgentRequestMessageSourceType.AIContextProvider, "src");
|
||||
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [externalMsg, chatHistoryMsg, contextProviderMsg]);
|
||||
|
||||
// Act
|
||||
await provider.InvokingAsync(context);
|
||||
|
||||
// Assert - ProvideMessagesAsync received only External messages
|
||||
Assert.NotNull(provider.LastFilteredContext);
|
||||
var filteredMessages = provider.LastFilteredContext!.RequestMessages.ToList();
|
||||
Assert.Single(filteredMessages);
|
||||
Assert.Equal("External", filteredMessages[0].Text);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokingAsync_UsesCustomProvideInputFilterAsync()
|
||||
{
|
||||
// Arrange - filter that keeps all messages (not just External)
|
||||
var provider = new TestMessageProvider(
|
||||
captureFilteredContext: true,
|
||||
provideInputMessageFilter: msgs => msgs);
|
||||
var externalMsg = new ChatMessage(ChatRole.User, "External");
|
||||
var chatHistoryMsg = new ChatMessage(ChatRole.User, "History")
|
||||
.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, "src");
|
||||
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [externalMsg, chatHistoryMsg]);
|
||||
|
||||
// Act
|
||||
await provider.InvokingAsync(context);
|
||||
|
||||
// Assert - ProvideMessagesAsync received ALL messages (custom filter keeps everything)
|
||||
Assert.NotNull(provider.LastFilteredContext);
|
||||
var filteredMessages = provider.LastFilteredContext!.RequestMessages.ToList();
|
||||
Assert.Equal(2, filteredMessages.Count);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokingAsync_MergesWithOriginalUnfilteredMessagesAsync()
|
||||
{
|
||||
// Arrange - default filter is External-only, but the MERGED result should include
|
||||
// the original unfiltered input messages plus the provided messages
|
||||
var providedMessages = new[] { new ChatMessage(ChatRole.System, "Provided") };
|
||||
var provider = new TestMessageProvider(provideMessages: providedMessages);
|
||||
var externalMsg = new ChatMessage(ChatRole.User, "External");
|
||||
var chatHistoryMsg = new ChatMessage(ChatRole.User, "History")
|
||||
.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, "src");
|
||||
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [externalMsg, chatHistoryMsg]);
|
||||
|
||||
// Act
|
||||
var result = (await provider.InvokingAsync(context)).ToList();
|
||||
|
||||
// Assert - original 2 input messages + 1 provided message
|
||||
Assert.Equal(3, result.Count);
|
||||
Assert.Equal("External", result[0].Text);
|
||||
Assert.Equal("History", result[1].Text);
|
||||
Assert.Equal("Provided", result[2].Text);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region ProvideAIContextAsync Tests
|
||||
|
||||
[Fact]
|
||||
public async Task ProvideAIContextAsync_PreservesInstructionsAndToolsAsync()
|
||||
{
|
||||
// Arrange
|
||||
var providedMessages = new[] { new ChatMessage(ChatRole.System, "Context") };
|
||||
var provider = new TestMessageProvider(provideMessages: providedMessages);
|
||||
var inputTool = AIFunctionFactory.Create(() => "a", "inputTool");
|
||||
var inputContext = new AIContext
|
||||
{
|
||||
Messages = [new ChatMessage(ChatRole.User, "Hello")],
|
||||
Instructions = "Be helpful",
|
||||
Tools = [inputTool]
|
||||
};
|
||||
var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, inputContext);
|
||||
|
||||
// Act
|
||||
var result = await provider.InvokingAsync(context);
|
||||
|
||||
// Assert - instructions and tools are preserved
|
||||
Assert.Equal("Be helpful", result.Instructions);
|
||||
Assert.NotNull(result.Tools);
|
||||
Assert.Single(result.Tools!);
|
||||
Assert.Equal("inputTool", result.Tools!.First().Name);
|
||||
|
||||
// Messages include original input + provided messages (with stamping)
|
||||
var messages = result.Messages!.ToList();
|
||||
Assert.Equal(2, messages.Count);
|
||||
Assert.Equal("Hello", messages[0].Text);
|
||||
Assert.Equal("Context", messages[1].Text);
|
||||
Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, messages[1].GetAgentRequestMessageSourceType());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ProvideAIContextAsync_PreservesNullInstructionsAndToolsAsync()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new DefaultMessageProvider();
|
||||
var inputContext = new AIContext { Messages = [new ChatMessage(ChatRole.User, "Hello")] };
|
||||
var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, inputContext);
|
||||
|
||||
// Act
|
||||
var result = await provider.InvokingAsync(context);
|
||||
|
||||
// Assert
|
||||
Assert.Null(result.Instructions);
|
||||
Assert.Null(result.Tools);
|
||||
var messages = result.Messages!.ToList();
|
||||
Assert.Single(messages);
|
||||
Assert.Equal("Hello", messages[0].Text);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region InvokingContext Tests
|
||||
|
||||
[Fact]
|
||||
public void InvokingContext_Constructor_ThrowsForNullAgent()
|
||||
{
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => new MessageAIContextProvider.InvokingContext(null!, s_mockSession, []));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void InvokingContext_Constructor_ThrowsForNullRequestMessages()
|
||||
{
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, null!));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void InvokingContext_Constructor_AllowsNullSession()
|
||||
{
|
||||
// Act
|
||||
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, null, []);
|
||||
|
||||
// Assert
|
||||
Assert.Null(context.Session);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void InvokingContext_Properties_Roundtrip()
|
||||
{
|
||||
// Arrange
|
||||
var messages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
|
||||
|
||||
// Act
|
||||
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, messages);
|
||||
|
||||
// Assert
|
||||
Assert.Same(s_mockAgent, context.Agent);
|
||||
Assert.Same(s_mockSession, context.Session);
|
||||
Assert.Same(messages, context.RequestMessages);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void InvokingContext_RequestMessages_SetterThrowsForNull()
|
||||
{
|
||||
// Arrange
|
||||
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, []);
|
||||
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => context.RequestMessages = null!);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void InvokingContext_RequestMessages_SetterAcceptsValidValue()
|
||||
{
|
||||
// Arrange
|
||||
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, s_mockSession, []);
|
||||
var newMessages = new List<ChatMessage> { new(ChatRole.User, "Updated") };
|
||||
|
||||
// Act
|
||||
context.RequestMessages = newMessages;
|
||||
|
||||
// Assert
|
||||
Assert.Same(newMessages, context.RequestMessages);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region GetService Tests
|
||||
|
||||
[Fact]
|
||||
public void GetService_ReturnsProviderForMessageAIContextProviderType()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TestMessageProvider();
|
||||
|
||||
// Act & Assert
|
||||
Assert.Same(provider, provider.GetService(typeof(MessageAIContextProvider)));
|
||||
Assert.Same(provider, provider.GetService(typeof(AIContextProvider)));
|
||||
Assert.Same(provider, provider.GetService(typeof(TestMessageProvider)));
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Test helpers
|
||||
|
||||
private sealed class TestMessageProvider : MessageAIContextProvider
|
||||
{
|
||||
private readonly IEnumerable<ChatMessage>? _provideMessages;
|
||||
private readonly bool _captureFilteredContext;
|
||||
|
||||
public InvokingContext? LastFilteredContext { get; private set; }
|
||||
|
||||
public TestMessageProvider(
|
||||
IEnumerable<ChatMessage>? provideMessages = null,
|
||||
bool captureFilteredContext = false,
|
||||
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? provideInputMessageFilter = null,
|
||||
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? storeInputMessageFilter = null)
|
||||
: base(provideInputMessageFilter, storeInputMessageFilter)
|
||||
{
|
||||
this._provideMessages = provideMessages;
|
||||
this._captureFilteredContext = captureFilteredContext;
|
||||
}
|
||||
|
||||
protected override ValueTask<IEnumerable<ChatMessage>> ProvideMessagesAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (this._captureFilteredContext)
|
||||
{
|
||||
this.LastFilteredContext = context;
|
||||
}
|
||||
|
||||
return new(this._provideMessages ?? []);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A provider that uses only base class defaults (no overrides of ProvideMessagesAsync).
|
||||
/// </summary>
|
||||
private sealed class DefaultMessageProvider : MessageAIContextProvider;
|
||||
|
||||
#endregion
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
|
||||
namespace Microsoft.Agents.AI.Abstractions.UnitTests;
|
||||
|
||||
/// <summary>
|
||||
@@ -7,6 +9,56 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests;
|
||||
/// </summary>
|
||||
public class ProviderSessionStateTests
|
||||
{
|
||||
#region Constructor Tests
|
||||
|
||||
[Fact]
|
||||
public void Constructor_ThrowsForNullStateInitializer()
|
||||
{
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => new ProviderSessionState<TestState>(null!, "test-key"));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Constructor_ThrowsForNullStateKey()
|
||||
{
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => new ProviderSessionState<TestState>(_ => new TestState(), null!));
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData("")]
|
||||
[InlineData(" ")]
|
||||
public void Constructor_ThrowsForEmptyOrWhitespaceStateKey(string stateKey)
|
||||
{
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentException>(() => new ProviderSessionState<TestState>(_ => new TestState(), stateKey));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Constructor_AcceptsNullJsonSerializerOptions()
|
||||
{
|
||||
// Act - should not throw
|
||||
var sessionState = new ProviderSessionState<TestState>(_ => new TestState(), "test-key", jsonSerializerOptions: null);
|
||||
|
||||
// Assert - instance is created and functional
|
||||
Assert.Equal("test-key", sessionState.StateKey);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Constructor_AcceptsCustomJsonSerializerOptions()
|
||||
{
|
||||
// Arrange
|
||||
var customOptions = new System.Text.Json.JsonSerializerOptions();
|
||||
|
||||
// Act - should not throw
|
||||
var sessionState = new ProviderSessionState<TestState>(_ => new TestState(), "test-key", customOptions);
|
||||
|
||||
// Assert - instance is created and functional
|
||||
Assert.Equal("test-key", sessionState.StateKey);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region GetOrInitializeState Tests
|
||||
|
||||
[Fact]
|
||||
|
||||
@@ -547,6 +547,87 @@ public sealed class Mem0ProviderTests : IDisposable
|
||||
Assert.Equal(2, memoryPosts.Count);
|
||||
}
|
||||
|
||||
#region MessageAIContextProvider.InvokingAsync Tests
|
||||
|
||||
[Fact]
|
||||
public async Task MessageInvokingAsync_SearchesAndReturnsMergedMessagesAsync()
|
||||
{
|
||||
// Arrange
|
||||
this._handler.EnqueueJsonResponse("[ { \"id\": \"1\", \"memory\": \"Name is Caoimhe\", \"hash\": \"h\", \"metadata\": null, \"score\": 0.9, \"created_at\": \"2023-01-01T00:00:00Z\", \"updated_at\": null, \"user_id\": \"u\", \"app_id\": null, \"agent_id\": \"agent\", \"thread_id\": \"session\" } ]");
|
||||
var storageScope = new Mem0ProviderScope
|
||||
{
|
||||
ApplicationId = "app",
|
||||
AgentId = "agent",
|
||||
ThreadId = "session",
|
||||
UserId = "user"
|
||||
};
|
||||
var mockSession = new TestAgentSession();
|
||||
var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope));
|
||||
|
||||
var inputMsg = new ChatMessage(ChatRole.User, "What is my name?");
|
||||
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, mockSession, [inputMsg]);
|
||||
|
||||
// Act
|
||||
var messages = (await sut.InvokingAsync(context)).ToList();
|
||||
|
||||
// Assert - input message + memory message, with stamping
|
||||
Assert.Equal(2, messages.Count);
|
||||
Assert.Equal("What is my name?", messages[0].Text);
|
||||
Assert.Contains("Name is Caoimhe", messages[1].Text);
|
||||
Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, messages[1].GetAgentRequestMessageSourceType());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task MessageInvokingAsync_NoMemories_ReturnsOnlyInputMessagesAsync()
|
||||
{
|
||||
// Arrange
|
||||
this._handler.EnqueueJsonResponse("[]");
|
||||
var storageScope = new Mem0ProviderScope
|
||||
{
|
||||
UserId = "user"
|
||||
};
|
||||
var mockSession = new TestAgentSession();
|
||||
var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope));
|
||||
|
||||
var inputMsg = new ChatMessage(ChatRole.User, "Hello");
|
||||
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, mockSession, [inputMsg]);
|
||||
|
||||
// Act
|
||||
var messages = (await sut.InvokingAsync(context)).ToList();
|
||||
|
||||
// Assert
|
||||
Assert.Single(messages);
|
||||
Assert.Equal("Hello", messages[0].Text);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task MessageInvokingAsync_DefaultFilter_ExcludesNonExternalMessagesAsync()
|
||||
{
|
||||
// Arrange
|
||||
this._handler.EnqueueJsonResponse("[]");
|
||||
var storageScope = new Mem0ProviderScope
|
||||
{
|
||||
UserId = "user"
|
||||
};
|
||||
var mockSession = new TestAgentSession();
|
||||
var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope));
|
||||
|
||||
var externalMsg = new ChatMessage(ChatRole.User, "External question");
|
||||
var historyMsg = new ChatMessage(ChatRole.User, "History message")
|
||||
.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, "src");
|
||||
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, mockSession, [externalMsg, historyMsg]);
|
||||
|
||||
// Act
|
||||
await sut.InvokingAsync(context);
|
||||
|
||||
// Assert - Only External message used for search query
|
||||
var searchRequest = Assert.Single(this._handler.Requests, r => r.RequestMessage.Method == HttpMethod.Post && ContainsOrdinal(r.RequestMessage.RequestUri!.AbsoluteUri, "/v1/memories/search/"));
|
||||
using JsonDocument doc = JsonDocument.Parse(searchRequest.RequestBody);
|
||||
Assert.Equal("External question", doc.RootElement.GetProperty("query").GetString());
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
private static bool ContainsOrdinal(string source, string value) => source.IndexOf(value, StringComparison.Ordinal) >= 0;
|
||||
|
||||
public void Dispose()
|
||||
|
||||
@@ -743,7 +743,7 @@ public sealed class TextSearchProviderTests
|
||||
SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke,
|
||||
RecentMessageMemoryLimit = 4
|
||||
});
|
||||
await newProvider.InvokingAsync(new(s_mockAgent, restoredSession, new AIContext()), CancellationToken.None); // Trigger search to read memory.
|
||||
await newProvider.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, restoredSession, new AIContext()), CancellationToken.None); // Trigger search to read memory.
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(capturedInput);
|
||||
@@ -769,7 +769,7 @@ public sealed class TextSearchProviderTests
|
||||
SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke,
|
||||
RecentMessageMemoryLimit = 3
|
||||
});
|
||||
await provider.InvokingAsync(new(s_mockAgent, session, new AIContext()), CancellationToken.None);
|
||||
await provider.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, session, new AIContext()), CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(capturedInput);
|
||||
@@ -778,6 +778,101 @@ public sealed class TextSearchProviderTests
|
||||
|
||||
#endregion
|
||||
|
||||
#region MessageAIContextProvider.InvokingAsync Tests
|
||||
|
||||
[Fact]
|
||||
public async Task MessageInvokingAsync_BeforeAIInvoke_SearchesAndReturnsMergedMessagesAsync()
|
||||
{
|
||||
// Arrange
|
||||
List<TextSearchProvider.TextSearchResult> results =
|
||||
[
|
||||
new() { SourceName = "Doc1", Text = "Content of Doc1" }
|
||||
];
|
||||
|
||||
Task<IEnumerable<TextSearchProvider.TextSearchResult>> SearchDelegateAsync(string input, CancellationToken ct)
|
||||
=> Task.FromResult<IEnumerable<TextSearchProvider.TextSearchResult>>(results);
|
||||
|
||||
var provider = new TextSearchProvider(SearchDelegateAsync, new TextSearchProviderOptions
|
||||
{
|
||||
SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke
|
||||
});
|
||||
|
||||
var inputMsg = new ChatMessage(ChatRole.User, "Question?");
|
||||
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [inputMsg]);
|
||||
|
||||
// Act
|
||||
var messages = (await provider.InvokingAsync(context)).ToList();
|
||||
|
||||
// Assert - input message + search result message, with stamping
|
||||
Assert.Equal(2, messages.Count);
|
||||
Assert.Equal("Question?", messages[0].Text);
|
||||
Assert.Contains("Content of Doc1", messages[1].Text);
|
||||
Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, messages[1].GetAgentRequestMessageSourceType());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task MessageInvokingAsync_OnDemand_ThrowsInvalidOperationExceptionAsync()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TextSearchProvider(this.NoResultSearchAsync, new TextSearchProviderOptions
|
||||
{
|
||||
SearchTime = TextSearchProviderOptions.TextSearchBehavior.OnDemandFunctionCalling,
|
||||
});
|
||||
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [new ChatMessage(ChatRole.User, "Q?")]);
|
||||
|
||||
// Act & Assert
|
||||
await Assert.ThrowsAsync<InvalidOperationException>(() => provider.InvokingAsync(context).AsTask());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task MessageInvokingAsync_BeforeAIInvoke_NoResults_ReturnsOnlyInputMessagesAsync()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TextSearchProvider(this.NoResultSearchAsync, new TextSearchProviderOptions
|
||||
{
|
||||
SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke
|
||||
});
|
||||
var inputMsg = new ChatMessage(ChatRole.User, "Hello");
|
||||
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [inputMsg]);
|
||||
|
||||
// Act
|
||||
var messages = (await provider.InvokingAsync(context)).ToList();
|
||||
|
||||
// Assert
|
||||
Assert.Single(messages);
|
||||
Assert.Equal("Hello", messages[0].Text);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task MessageInvokingAsync_BeforeAIInvoke_DefaultFilter_ExcludesNonExternalMessagesAsync()
|
||||
{
|
||||
// Arrange
|
||||
string? capturedInput = null;
|
||||
Task<IEnumerable<TextSearchProvider.TextSearchResult>> SearchDelegateAsync(string input, CancellationToken ct)
|
||||
{
|
||||
capturedInput = input;
|
||||
return Task.FromResult<IEnumerable<TextSearchProvider.TextSearchResult>>([]);
|
||||
}
|
||||
|
||||
var provider = new TextSearchProvider(SearchDelegateAsync, new TextSearchProviderOptions
|
||||
{
|
||||
SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke
|
||||
});
|
||||
|
||||
var externalMsg = new ChatMessage(ChatRole.User, "External message");
|
||||
var historyMsg = new ChatMessage(ChatRole.System, "From history")
|
||||
.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, "src");
|
||||
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [externalMsg, historyMsg]);
|
||||
|
||||
// Act
|
||||
await provider.InvokingAsync(context);
|
||||
|
||||
// Assert - Only External message used for search query
|
||||
Assert.Equal("External message", capturedInput);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
private Task<IEnumerable<TextSearchProvider.TextSearchResult>> NoResultSearchAsync(string input, CancellationToken ct)
|
||||
{
|
||||
return Task.FromResult<IEnumerable<TextSearchProvider.TextSearchResult>>([]);
|
||||
|
||||
@@ -710,6 +710,147 @@ public class ChatHistoryMemoryProviderTests
|
||||
|
||||
#endregion
|
||||
|
||||
#region MessageAIContextProvider.InvokingAsync Tests
|
||||
|
||||
[Fact]
|
||||
public async Task MessageInvokingAsync_BeforeAIInvoke_SearchesAndReturnsMergedMessagesAsync()
|
||||
{
|
||||
// Arrange
|
||||
var storedItems = new List<VectorSearchResult<Dictionary<string, object?>>>
|
||||
{
|
||||
new(
|
||||
new Dictionary<string, object?>
|
||||
{
|
||||
["MessageId"] = "msg-1",
|
||||
["Content"] = "Previous message",
|
||||
["Role"] = ChatRole.User.ToString(),
|
||||
["CreatedAt"] = "2023-01-01T00:00:00.0000000+00:00"
|
||||
},
|
||||
0.9f)
|
||||
};
|
||||
|
||||
this._vectorStoreCollectionMock
|
||||
.Setup(c => c.SearchAsync(
|
||||
It.IsAny<string>(),
|
||||
It.IsAny<int>(),
|
||||
It.IsAny<VectorSearchOptions<Dictionary<string, object?>>>(),
|
||||
It.IsAny<CancellationToken>()))
|
||||
.Returns(ToAsyncEnumerableAsync(storedItems));
|
||||
|
||||
var provider = new ChatHistoryMemoryProvider(
|
||||
this._vectorStoreMock.Object,
|
||||
TestCollectionName,
|
||||
1,
|
||||
_ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }),
|
||||
options: new ChatHistoryMemoryProviderOptions
|
||||
{
|
||||
SearchTime = ChatHistoryMemoryProviderOptions.SearchBehavior.BeforeAIInvoke
|
||||
});
|
||||
|
||||
var inputMsg = new ChatMessage(ChatRole.User, "What was discussed?");
|
||||
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [inputMsg]);
|
||||
|
||||
// Act
|
||||
var messages = (await provider.InvokingAsync(context)).ToList();
|
||||
|
||||
// Assert - input message + search result message, with stamping
|
||||
Assert.Equal(2, messages.Count);
|
||||
Assert.Equal("What was discussed?", messages[0].Text);
|
||||
Assert.Contains("Previous message", messages[1].Text);
|
||||
Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, messages[1].GetAgentRequestMessageSourceType());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task MessageInvokingAsync_OnDemand_ThrowsInvalidOperationExceptionAsync()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new ChatHistoryMemoryProvider(
|
||||
this._vectorStoreMock.Object,
|
||||
TestCollectionName,
|
||||
1,
|
||||
_ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }),
|
||||
options: new ChatHistoryMemoryProviderOptions
|
||||
{
|
||||
SearchTime = ChatHistoryMemoryProviderOptions.SearchBehavior.OnDemandFunctionCalling
|
||||
});
|
||||
|
||||
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [new ChatMessage(ChatRole.User, "Q?")]);
|
||||
|
||||
// Act & Assert
|
||||
await Assert.ThrowsAsync<InvalidOperationException>(() => provider.InvokingAsync(context).AsTask());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task MessageInvokingAsync_BeforeAIInvoke_NoResults_ReturnsOnlyInputMessagesAsync()
|
||||
{
|
||||
// Arrange
|
||||
this._vectorStoreCollectionMock
|
||||
.Setup(c => c.SearchAsync(
|
||||
It.IsAny<string>(),
|
||||
It.IsAny<int>(),
|
||||
It.IsAny<VectorSearchOptions<Dictionary<string, object?>>>(),
|
||||
It.IsAny<CancellationToken>()))
|
||||
.Returns(ToAsyncEnumerableAsync(new List<VectorSearchResult<Dictionary<string, object?>>>()));
|
||||
|
||||
var provider = new ChatHistoryMemoryProvider(
|
||||
this._vectorStoreMock.Object,
|
||||
TestCollectionName,
|
||||
1,
|
||||
_ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }),
|
||||
options: new ChatHistoryMemoryProviderOptions
|
||||
{
|
||||
SearchTime = ChatHistoryMemoryProviderOptions.SearchBehavior.BeforeAIInvoke
|
||||
});
|
||||
|
||||
var inputMsg = new ChatMessage(ChatRole.User, "Hello");
|
||||
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [inputMsg]);
|
||||
|
||||
// Act
|
||||
var messages = (await provider.InvokingAsync(context)).ToList();
|
||||
|
||||
// Assert
|
||||
Assert.Single(messages);
|
||||
Assert.Equal("Hello", messages[0].Text);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task MessageInvokingAsync_BeforeAIInvoke_DefaultFilter_ExcludesNonExternalMessagesAsync()
|
||||
{
|
||||
// Arrange
|
||||
string? capturedQuery = null;
|
||||
this._vectorStoreCollectionMock
|
||||
.Setup(c => c.SearchAsync(
|
||||
It.IsAny<string>(),
|
||||
It.IsAny<int>(),
|
||||
It.IsAny<VectorSearchOptions<Dictionary<string, object?>>>(),
|
||||
It.IsAny<CancellationToken>()))
|
||||
.Callback<string, int, VectorSearchOptions<Dictionary<string, object?>>, CancellationToken>((query, _, _, _) => capturedQuery = query)
|
||||
.Returns(ToAsyncEnumerableAsync(new List<VectorSearchResult<Dictionary<string, object?>>>()));
|
||||
|
||||
var provider = new ChatHistoryMemoryProvider(
|
||||
this._vectorStoreMock.Object,
|
||||
TestCollectionName,
|
||||
1,
|
||||
_ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }),
|
||||
options: new ChatHistoryMemoryProviderOptions
|
||||
{
|
||||
SearchTime = ChatHistoryMemoryProviderOptions.SearchBehavior.BeforeAIInvoke
|
||||
});
|
||||
|
||||
var externalMsg = new ChatMessage(ChatRole.User, "External message");
|
||||
var historyMsg = new ChatMessage(ChatRole.System, "From history")
|
||||
.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, "src");
|
||||
var context = new MessageAIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [externalMsg, historyMsg]);
|
||||
|
||||
// Act
|
||||
await provider.InvokingAsync(context);
|
||||
|
||||
// Assert - Only External message used for search query
|
||||
Assert.Equal("External message", capturedQuery);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
private static async IAsyncEnumerable<T> ToAsyncEnumerableAsync<T>(IEnumerable<T> values)
|
||||
{
|
||||
await Task.Yield();
|
||||
|
||||
@@ -0,0 +1,469 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.AI;
|
||||
using Moq;
|
||||
|
||||
namespace Microsoft.Agents.AI.UnitTests;
|
||||
|
||||
/// <summary>
|
||||
/// Unit tests for the <see cref="MessageAIContextProviderAgent"/> class and
|
||||
/// the <see cref="AIAgentBuilder.Use(MessageAIContextProvider[])"/> builder extension.
|
||||
/// </summary>
|
||||
public class MessageAIContextProviderAgentTests
|
||||
{
|
||||
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
|
||||
|
||||
#region Constructor Tests
|
||||
|
||||
[Fact]
|
||||
public void Constructor_NullInnerAgent_ThrowsArgumentNullException()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TestProvider();
|
||||
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => new MessageAIContextProviderAgent(null!, [provider]));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Constructor_NullProviders_ThrowsArgumentNullException()
|
||||
{
|
||||
// Arrange
|
||||
var agent = CreateTestAgent();
|
||||
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => new MessageAIContextProviderAgent(agent, null!));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Constructor_EmptyProviders_ThrowsArgumentOutOfRangeException()
|
||||
{
|
||||
// Arrange
|
||||
var agent = CreateTestAgent();
|
||||
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentOutOfRangeException>(() => new MessageAIContextProviderAgent(agent, []));
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region RunAsync Tests
|
||||
|
||||
[Fact]
|
||||
public async Task RunAsync_SingleProvider_EnrichesMessagesAndDelegatesToInnerAgentAsync()
|
||||
{
|
||||
// Arrange
|
||||
var contextMessage = new ChatMessage(ChatRole.System, "Extra context");
|
||||
var provider = new TestProvider(provideMessages: [contextMessage]);
|
||||
|
||||
IEnumerable<ChatMessage>? capturedMessages = null;
|
||||
var innerAgent = CreateTestAgent(
|
||||
runFunc: (messages, _, _, _) =>
|
||||
{
|
||||
capturedMessages = messages;
|
||||
return Task.FromResult(new AgentResponse([new ChatMessage(ChatRole.Assistant, "Response")]));
|
||||
});
|
||||
|
||||
var agent = new MessageAIContextProviderAgent(innerAgent, [provider]);
|
||||
|
||||
// Act
|
||||
await agent.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession);
|
||||
|
||||
// Assert - inner agent received enriched messages (input + provider's message)
|
||||
Assert.NotNull(capturedMessages);
|
||||
var messageList = capturedMessages!.ToList();
|
||||
Assert.Equal(2, messageList.Count);
|
||||
Assert.Equal("Hello", messageList[0].Text);
|
||||
Assert.Contains("Extra context", messageList[1].Text);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task RunAsync_MultipleProviders_CalledInSequenceAsync()
|
||||
{
|
||||
// Arrange
|
||||
var provider1 = new TestProvider(provideMessages: [new ChatMessage(ChatRole.System, "From provider 1")]);
|
||||
var provider2 = new TestProvider(provideMessages: [new ChatMessage(ChatRole.System, "From provider 2")]);
|
||||
|
||||
IEnumerable<ChatMessage>? capturedMessages = null;
|
||||
var innerAgent = CreateTestAgent(
|
||||
runFunc: (messages, _, _, _) =>
|
||||
{
|
||||
capturedMessages = messages;
|
||||
return Task.FromResult(new AgentResponse([new ChatMessage(ChatRole.Assistant, "Response")]));
|
||||
});
|
||||
|
||||
var agent = new MessageAIContextProviderAgent(innerAgent, [provider1, provider2]);
|
||||
|
||||
// Act
|
||||
await agent.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession);
|
||||
|
||||
// Assert - inner agent received messages from both providers in sequence
|
||||
Assert.NotNull(capturedMessages);
|
||||
var messageList = capturedMessages!.ToList();
|
||||
Assert.Equal(3, messageList.Count);
|
||||
Assert.Equal("Hello", messageList[0].Text);
|
||||
Assert.Contains("From provider 1", messageList[1].Text);
|
||||
Assert.Contains("From provider 2", messageList[2].Text);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task RunAsync_SequentialProviders_EachReceivesPreviousOutputAsync()
|
||||
{
|
||||
// Arrange - provider 2 captures the filtered messages it receives in ProvideMessagesAsync.
|
||||
// The default filter only includes External messages, so provider 1's stamped messages
|
||||
// (marked as AIContextProvider) are filtered out before reaching provider 2's ProvideMessagesAsync.
|
||||
// However, the full unfiltered output from provider 1 is passed to provider 2's InvokingAsync,
|
||||
// and the inner agent receives the full merged output from both providers.
|
||||
IEnumerable<ChatMessage>? provider2ReceivedMessages = null;
|
||||
var provider1 = new TestProvider(provideMessages: [new ChatMessage(ChatRole.System, "From provider 1")]);
|
||||
var provider2 = new TestProvider(
|
||||
provideMessages: [new ChatMessage(ChatRole.System, "From provider 2")],
|
||||
onInvoking: messages => provider2ReceivedMessages = messages.ToList());
|
||||
|
||||
var innerAgent = CreateTestAgent(
|
||||
runFunc: (_, _, _, _) => Task.FromResult(new AgentResponse([new ChatMessage(ChatRole.Assistant, "Response")])));
|
||||
|
||||
var agent = new MessageAIContextProviderAgent(innerAgent, [provider1, provider2]);
|
||||
|
||||
// Act
|
||||
await agent.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession);
|
||||
|
||||
// Assert - provider 2's ProvideMessagesAsync received only External messages (filtered)
|
||||
Assert.NotNull(provider2ReceivedMessages);
|
||||
var received = provider2ReceivedMessages!.ToList();
|
||||
Assert.Single(received);
|
||||
Assert.Equal("Hello", received[0].Text);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task RunAsync_OnSuccess_InvokedAsyncCalledOnAllProvidersAsync()
|
||||
{
|
||||
// Arrange
|
||||
var provider1 = new TestProvider();
|
||||
var provider2 = new TestProvider();
|
||||
var innerAgent = CreateTestAgent(
|
||||
runFunc: (_, _, _, _) => Task.FromResult(new AgentResponse([new ChatMessage(ChatRole.Assistant, "Response")])));
|
||||
|
||||
var agent = new MessageAIContextProviderAgent(innerAgent, [provider1, provider2]);
|
||||
|
||||
// Act
|
||||
await agent.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession);
|
||||
|
||||
// Assert
|
||||
Assert.True(provider1.InvokedAsyncCalled);
|
||||
Assert.True(provider2.InvokedAsyncCalled);
|
||||
Assert.Null(provider1.LastInvokedContext!.InvokeException);
|
||||
Assert.Null(provider2.LastInvokedContext!.InvokeException);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task RunAsync_OnFailure_InvokedAsyncCalledWithExceptionAsync()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TestProvider();
|
||||
var expectedException = new InvalidOperationException("Agent failed");
|
||||
var innerAgent = CreateTestAgent(
|
||||
runFunc: (_, _, _, _) => throw expectedException);
|
||||
|
||||
var agent = new MessageAIContextProviderAgent(innerAgent, [provider]);
|
||||
|
||||
// Act & Assert
|
||||
await Assert.ThrowsAsync<InvalidOperationException>(() =>
|
||||
agent.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession));
|
||||
|
||||
Assert.True(provider.InvokedAsyncCalled);
|
||||
Assert.Same(expectedException, provider.LastInvokedContext!.InvokeException);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task RunAsync_OnSuccess_InvokedContextContainsResponseMessagesAsync()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TestProvider();
|
||||
var responseMessage = new ChatMessage(ChatRole.Assistant, "Response text");
|
||||
var innerAgent = CreateTestAgent(
|
||||
runFunc: (_, _, _, _) => Task.FromResult(new AgentResponse([responseMessage])));
|
||||
|
||||
var agent = new MessageAIContextProviderAgent(innerAgent, [provider]);
|
||||
|
||||
// Act
|
||||
await agent.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(provider.LastInvokedContext?.ResponseMessages);
|
||||
Assert.Contains(provider.LastInvokedContext!.ResponseMessages!, m => m.Text == "Response text");
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region RunStreamingAsync Tests
|
||||
|
||||
[Fact]
|
||||
public async Task RunStreamingAsync_SingleProvider_EnrichesMessagesAndStreamsAsync()
|
||||
{
|
||||
// Arrange
|
||||
var contextMessage = new ChatMessage(ChatRole.System, "Extra context");
|
||||
var provider = new TestProvider(provideMessages: [contextMessage]);
|
||||
|
||||
IEnumerable<ChatMessage>? capturedMessages = null;
|
||||
var innerAgent = CreateTestAgent(
|
||||
runStreamingFunc: (messages, _, _, _) =>
|
||||
{
|
||||
capturedMessages = messages;
|
||||
return ToAsyncEnumerableAsync(
|
||||
new AgentResponseUpdate(ChatRole.Assistant, "Part1"),
|
||||
new AgentResponseUpdate(ChatRole.Assistant, "Part2"));
|
||||
});
|
||||
|
||||
var agent = new MessageAIContextProviderAgent(innerAgent, [provider]);
|
||||
|
||||
// Act
|
||||
var updates = new List<AgentResponseUpdate>();
|
||||
await foreach (var update in agent.RunStreamingAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession))
|
||||
{
|
||||
updates.Add(update);
|
||||
}
|
||||
|
||||
// Assert - streaming updates received
|
||||
Assert.Equal(2, updates.Count);
|
||||
// Assert - inner agent received enriched messages
|
||||
Assert.NotNull(capturedMessages);
|
||||
var messageList = capturedMessages!.ToList();
|
||||
Assert.Equal(2, messageList.Count);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task RunStreamingAsync_OnSuccess_InvokedAsyncCalledAfterAllUpdatesAsync()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TestProvider();
|
||||
var innerAgent = CreateTestAgent(
|
||||
runStreamingFunc: (_, _, _, _) => ToAsyncEnumerableAsync(
|
||||
new AgentResponseUpdate(ChatRole.Assistant, "Response")));
|
||||
|
||||
var agent = new MessageAIContextProviderAgent(innerAgent, [provider]);
|
||||
|
||||
// Act - consume all updates
|
||||
await foreach (var _ in agent.RunStreamingAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession))
|
||||
{
|
||||
}
|
||||
|
||||
// Assert
|
||||
Assert.True(provider.InvokedAsyncCalled);
|
||||
Assert.Null(provider.LastInvokedContext!.InvokeException);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task RunStreamingAsync_OnSuccess_InvokedContextContainsAccumulatedResponseAsync()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TestProvider();
|
||||
var innerAgent = CreateTestAgent(
|
||||
runStreamingFunc: (_, _, _, _) => ToAsyncEnumerableAsync(
|
||||
new AgentResponseUpdate(ChatRole.Assistant, "Hello "),
|
||||
new AgentResponseUpdate(ChatRole.Assistant, "World")));
|
||||
|
||||
var agent = new MessageAIContextProviderAgent(innerAgent, [provider]);
|
||||
|
||||
// Act - consume all updates
|
||||
await foreach (var _ in agent.RunStreamingAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession))
|
||||
{
|
||||
}
|
||||
|
||||
// Assert - InvokedAsync received the accumulated response messages
|
||||
Assert.NotNull(provider.LastInvokedContext?.ResponseMessages);
|
||||
var responseMessages = provider.LastInvokedContext!.ResponseMessages!.ToList();
|
||||
Assert.True(responseMessages.Count > 0);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task RunStreamingAsync_OnFailure_InvokedAsyncCalledWithExceptionAsync()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TestProvider();
|
||||
var expectedException = new InvalidOperationException("Stream failed");
|
||||
var innerAgent = CreateTestAgent(
|
||||
runStreamingFunc: (_, _, _, _) => throw expectedException);
|
||||
|
||||
var agent = new MessageAIContextProviderAgent(innerAgent, [provider]);
|
||||
|
||||
// Act & Assert
|
||||
await Assert.ThrowsAsync<InvalidOperationException>(async () =>
|
||||
{
|
||||
await foreach (var _ in agent.RunStreamingAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession))
|
||||
{
|
||||
}
|
||||
});
|
||||
|
||||
Assert.True(provider.InvokedAsyncCalled);
|
||||
Assert.Same(expectedException, provider.LastInvokedContext!.InvokeException);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task RunStreamingAsync_MultipleProviders_CalledInSequenceAsync()
|
||||
{
|
||||
// Arrange
|
||||
var provider1 = new TestProvider(provideMessages: [new ChatMessage(ChatRole.System, "From provider 1")]);
|
||||
var provider2 = new TestProvider(provideMessages: [new ChatMessage(ChatRole.System, "From provider 2")]);
|
||||
|
||||
IEnumerable<ChatMessage>? capturedMessages = null;
|
||||
var innerAgent = CreateTestAgent(
|
||||
runStreamingFunc: (messages, _, _, _) =>
|
||||
{
|
||||
capturedMessages = messages;
|
||||
return ToAsyncEnumerableAsync(new AgentResponseUpdate(ChatRole.Assistant, "Response"));
|
||||
});
|
||||
|
||||
var agent = new MessageAIContextProviderAgent(innerAgent, [provider1, provider2]);
|
||||
|
||||
// Act
|
||||
await foreach (var _ in agent.RunStreamingAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession))
|
||||
{
|
||||
}
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(capturedMessages);
|
||||
var messageList = capturedMessages!.ToList();
|
||||
Assert.Equal(3, messageList.Count);
|
||||
Assert.Equal("Hello", messageList[0].Text);
|
||||
Assert.Contains("From provider 1", messageList[1].Text);
|
||||
Assert.Contains("From provider 2", messageList[2].Text);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Builder Extension Tests
|
||||
|
||||
[Fact]
|
||||
public async Task UseExtension_CreatesWorkingPipelineAsync()
|
||||
{
|
||||
// Arrange
|
||||
var contextMessage = new ChatMessage(ChatRole.System, "Pipeline context");
|
||||
var provider = new TestProvider(provideMessages: [contextMessage]);
|
||||
|
||||
IEnumerable<ChatMessage>? capturedMessages = null;
|
||||
var innerAgent = CreateTestAgent(
|
||||
runFunc: (messages, _, _, _) =>
|
||||
{
|
||||
capturedMessages = messages;
|
||||
return Task.FromResult(new AgentResponse([new ChatMessage(ChatRole.Assistant, "Response")]));
|
||||
});
|
||||
|
||||
var pipeline = new AIAgentBuilder(innerAgent)
|
||||
.Use([provider])
|
||||
.Build();
|
||||
|
||||
// Act
|
||||
await pipeline.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(capturedMessages);
|
||||
var messageList = capturedMessages!.ToList();
|
||||
Assert.Equal(2, messageList.Count);
|
||||
Assert.Equal("Hello", messageList[0].Text);
|
||||
Assert.Contains("Pipeline context", messageList[1].Text);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task UseExtension_MultipleProviders_AllAppliedAsync()
|
||||
{
|
||||
// Arrange
|
||||
var provider1 = new TestProvider(provideMessages: [new ChatMessage(ChatRole.System, "P1")]);
|
||||
var provider2 = new TestProvider(provideMessages: [new ChatMessage(ChatRole.System, "P2")]);
|
||||
|
||||
IEnumerable<ChatMessage>? capturedMessages = null;
|
||||
var innerAgent = CreateTestAgent(
|
||||
runFunc: (messages, _, _, _) =>
|
||||
{
|
||||
capturedMessages = messages;
|
||||
return Task.FromResult(new AgentResponse([new ChatMessage(ChatRole.Assistant, "Response")]));
|
||||
});
|
||||
|
||||
var pipeline = new AIAgentBuilder(innerAgent)
|
||||
.Use([provider1, provider2])
|
||||
.Build();
|
||||
|
||||
// Act
|
||||
await pipeline.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(capturedMessages);
|
||||
var messageList = capturedMessages!.ToList();
|
||||
Assert.Equal(3, messageList.Count);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Helpers
|
||||
|
||||
private static TestAIAgent CreateTestAgent(
|
||||
Func<IEnumerable<ChatMessage>, AgentSession?, AgentRunOptions?, CancellationToken, Task<AgentResponse>>? runFunc = null,
|
||||
Func<IEnumerable<ChatMessage>, AgentSession?, AgentRunOptions?, CancellationToken, IAsyncEnumerable<AgentResponseUpdate>>? runStreamingFunc = null)
|
||||
{
|
||||
var agent = new TestAIAgent();
|
||||
if (runFunc is not null)
|
||||
{
|
||||
agent.RunAsyncFunc = runFunc;
|
||||
}
|
||||
|
||||
if (runStreamingFunc is not null)
|
||||
{
|
||||
agent.RunStreamingAsyncFunc = runStreamingFunc;
|
||||
}
|
||||
|
||||
return agent;
|
||||
}
|
||||
|
||||
private static async IAsyncEnumerable<AgentResponseUpdate> ToAsyncEnumerableAsync(params AgentResponseUpdate[] updates)
|
||||
{
|
||||
foreach (var update in updates)
|
||||
{
|
||||
yield return update;
|
||||
}
|
||||
|
||||
await Task.CompletedTask;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A test implementation of <see cref="MessageAIContextProvider"/> that records invocation calls.
|
||||
/// </summary>
|
||||
private sealed class TestProvider : MessageAIContextProvider
|
||||
{
|
||||
private readonly IEnumerable<ChatMessage> _provideMessages;
|
||||
private readonly Action<IEnumerable<ChatMessage>>? _onInvoking;
|
||||
|
||||
public bool InvokedAsyncCalled { get; private set; }
|
||||
|
||||
public InvokedContext? LastInvokedContext { get; private set; }
|
||||
|
||||
public TestProvider(
|
||||
IEnumerable<ChatMessage>? provideMessages = null,
|
||||
Action<IEnumerable<ChatMessage>>? onInvoking = null)
|
||||
{
|
||||
this._provideMessages = provideMessages ?? [];
|
||||
this._onInvoking = onInvoking;
|
||||
}
|
||||
|
||||
protected override ValueTask<IEnumerable<ChatMessage>> ProvideMessagesAsync(
|
||||
InvokingContext context,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
this._onInvoking?.Invoke(context.RequestMessages);
|
||||
return new ValueTask<IEnumerable<ChatMessage>>(this._provideMessages);
|
||||
}
|
||||
|
||||
protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
this.InvokedAsyncCalled = true;
|
||||
this.LastInvokedContext = context;
|
||||
return default;
|
||||
}
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
||||
Reference in New Issue
Block a user