mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
.NET: [BREAKING] Add ChatClient decorator for calling AIContextProviders (#4097)
* Add ChatClient decorator for calling AIContextProviders * Format new files * Address PR comments * Revert problematic change * Rename Use to UseAIContextProvider
This commit is contained in:
committed by
GitHub
Unverified
parent
6e4562e354
commit
e45e58108b
@@ -105,12 +105,27 @@ Console.WriteLine("\n\n=== Example 5: MessageAIContextProvider middleware ===");
|
||||
|
||||
var contextProviderAgent = originalAgent
|
||||
.AsBuilder()
|
||||
.Use([new DateTimeContextProvider()])
|
||||
.UseAIContextProviders(new DateTimeContextProvider())
|
||||
.Build();
|
||||
|
||||
var contextResponse = await contextProviderAgent.RunAsync("Is it almost time for lunch?");
|
||||
Console.WriteLine($"Context-enriched response: {contextResponse}");
|
||||
|
||||
// AIContextProvider at the chat client level. Unlike the agent-level MessageAIContextProvider,
|
||||
// this operates within the IChatClient pipeline and can also enrich tools and instructions.
|
||||
// It must be used within the context of a running AIAgent (uses AIAgent.CurrentRunContext).
|
||||
// In this case we are attaching an AIContextProvider that only adds messages.
|
||||
Console.WriteLine("\n\n=== Example 6: AIContextProvider on chat client pipeline ===");
|
||||
|
||||
var chatClientProviderAgent = azureOpenAIClient.AsIChatClient()
|
||||
.AsBuilder()
|
||||
.UseAIContextProviders(new DateTimeContextProvider())
|
||||
.BuildAIAgent(
|
||||
instructions: "You are an AI assistant that helps people find information.");
|
||||
|
||||
var chatClientContextResponse = await chatClientProviderAgent.RunAsync("Is it almost time for lunch?");
|
||||
Console.WriteLine($"Chat client context-enriched response: {chatClientContextResponse}");
|
||||
|
||||
// Function invocation middleware that logs before and after function calls.
|
||||
async ValueTask<object?> FunctionCallMiddleware(AIAgent agent, FunctionInvocationContext context, Func<FunctionInvocationContext, CancellationToken, ValueTask<object?>> next, CancellationToken cancellationToken)
|
||||
{
|
||||
@@ -278,7 +293,7 @@ async Task<ChatResponse> PerRequestChatClientMiddleware(IEnumerable<ChatMessage>
|
||||
/// <summary>
|
||||
/// A <see cref="MessageAIContextProvider"/> that injects the current date and time into the agent's context.
|
||||
/// This is a simple example of how to use a MessageAIContextProvider to enrich agent messages
|
||||
/// via the <see cref="AIAgentBuilder.Use(MessageAIContextProvider[])"/> extension method.
|
||||
/// via the <see cref="AIAgentBuilder.UseAIContextProviders(MessageAIContextProvider[])"/> extension method.
|
||||
/// </summary>
|
||||
internal sealed class DateTimeContextProvider : MessageAIContextProvider
|
||||
{
|
||||
|
||||
@@ -15,6 +15,7 @@ This sample demonstrates how to add middleware to intercept:
|
||||
6. Per‑request function pipeline with approval
|
||||
7. Combining agent‑level and per‑request middleware
|
||||
8. MessageAIContextProvider middleware via `AIAgentBuilder.Use(...)` for injecting additional context messages
|
||||
9. AIContextProvider middleware via `ChatClientBuilder.Use(...)` for enriching messages, tools, and instructions at the chat client level
|
||||
|
||||
## Function Invocation Middleware
|
||||
|
||||
|
||||
@@ -172,7 +172,7 @@ public sealed class AIAgentBuilder
|
||||
/// context enrichment, not just agents that natively support <see cref="AIContextProvider"/> instances.
|
||||
/// </para>
|
||||
/// </remarks>
|
||||
public AIAgentBuilder Use(MessageAIContextProvider[] providers)
|
||||
public AIAgentBuilder UseAIContextProviders(params MessageAIContextProvider[] providers)
|
||||
{
|
||||
return this.Use((innerAgent, _) => new MessageAIContextProviderAgent(innerAgent, providers));
|
||||
}
|
||||
|
||||
+215
@@ -0,0 +1,215 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.AI;
|
||||
using Microsoft.Shared.Diagnostics;
|
||||
|
||||
namespace Microsoft.Agents.AI;
|
||||
|
||||
/// <summary>
|
||||
/// A delegating chat client that enriches input messages, tools, and instructions by invoking a pipeline of
|
||||
/// <see cref="AIContextProvider"/> instances before delegating to the inner chat client, and notifies those
|
||||
/// providers after the inner client completes.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// <para>
|
||||
/// This chat client must be used within the context of a running <see cref="AIAgent"/>. It retrieves the current
|
||||
/// agent and session from <see cref="AIAgent.CurrentRunContext"/>, which is set automatically when an agent's
|
||||
/// <see cref="AIAgent.RunAsync(IEnumerable{ChatMessage}, AgentSession?, AgentRunOptions?, CancellationToken)"/> or
|
||||
/// <see cref="AIAgent.RunStreamingAsync(IEnumerable{ChatMessage}, AgentSession?, AgentRunOptions?, CancellationToken)"/> method is called.
|
||||
/// An <see cref="InvalidOperationException"/> is thrown if no run context is available.
|
||||
/// </para>
|
||||
/// </remarks>
|
||||
internal sealed class AIContextProviderChatClient : DelegatingChatClient
|
||||
{
|
||||
private readonly IReadOnlyList<AIContextProvider> _providers;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="AIContextProviderChatClient"/> class.
|
||||
/// </summary>
|
||||
/// <param name="innerClient">The underlying chat client that will handle the core operations.</param>
|
||||
/// <param name="providers">The AI context providers to invoke before and after the inner chat client.</param>
|
||||
public AIContextProviderChatClient(IChatClient innerClient, IReadOnlyList<AIContextProvider> providers)
|
||||
: base(innerClient)
|
||||
{
|
||||
Throw.IfNull(providers);
|
||||
|
||||
if (providers.Count == 0)
|
||||
{
|
||||
Throw.ArgumentException(nameof(providers), "At least one AIContextProvider must be provided.");
|
||||
}
|
||||
|
||||
this._providers = providers;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override async Task<ChatResponse> GetResponseAsync(
|
||||
IEnumerable<ChatMessage> messages,
|
||||
ChatOptions? options = null,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var runContext = GetRequiredRunContext();
|
||||
var (enrichedMessages, enrichedOptions) = await this.InvokeProvidersAsync(runContext, messages, options, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
ChatResponse response;
|
||||
try
|
||||
{
|
||||
response = await base.GetResponseAsync(enrichedMessages, enrichedOptions, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
await this.NotifyProvidersOfFailureAsync(runContext, enrichedMessages, ex, cancellationToken).ConfigureAwait(false);
|
||||
throw;
|
||||
}
|
||||
|
||||
await this.NotifyProvidersOfSuccessAsync(runContext, enrichedMessages, response.Messages, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
|
||||
IEnumerable<ChatMessage> messages,
|
||||
ChatOptions? options = null,
|
||||
[EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
var runContext = GetRequiredRunContext();
|
||||
var (enrichedMessages, enrichedOptions) = await this.InvokeProvidersAsync(runContext, messages, options, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
List<ChatResponseUpdate> responseUpdates = [];
|
||||
|
||||
IAsyncEnumerator<ChatResponseUpdate> enumerator;
|
||||
try
|
||||
{
|
||||
enumerator = base.GetStreamingResponseAsync(enrichedMessages, enrichedOptions, cancellationToken).GetAsyncEnumerator(cancellationToken);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
await this.NotifyProvidersOfFailureAsync(runContext, enrichedMessages, ex, cancellationToken).ConfigureAwait(false);
|
||||
throw;
|
||||
}
|
||||
|
||||
bool hasUpdates;
|
||||
try
|
||||
{
|
||||
hasUpdates = await enumerator.MoveNextAsync().ConfigureAwait(false);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
await this.NotifyProvidersOfFailureAsync(runContext, enrichedMessages, ex, cancellationToken).ConfigureAwait(false);
|
||||
throw;
|
||||
}
|
||||
|
||||
while (hasUpdates)
|
||||
{
|
||||
var update = enumerator.Current;
|
||||
responseUpdates.Add(update);
|
||||
yield return update;
|
||||
|
||||
try
|
||||
{
|
||||
hasUpdates = await enumerator.MoveNextAsync().ConfigureAwait(false);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
await this.NotifyProvidersOfFailureAsync(runContext, enrichedMessages, ex, cancellationToken).ConfigureAwait(false);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
var chatResponse = responseUpdates.ToChatResponse();
|
||||
await this.NotifyProvidersOfSuccessAsync(runContext, enrichedMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the current <see cref="AgentRunContext"/>, throwing if not available.
|
||||
/// </summary>
|
||||
private static AgentRunContext GetRequiredRunContext()
|
||||
{
|
||||
return AIAgent.CurrentRunContext
|
||||
?? throw new InvalidOperationException(
|
||||
$"{nameof(AIContextProviderChatClient)} can only be used within the context of a running AIAgent. " +
|
||||
"Ensure that the chat client is being invoked as part of an AIAgent.RunAsync or AIAgent.RunStreamingAsync call.");
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Invokes each provider's <see cref="AIContextProvider.InvokingAsync"/> in sequence,
|
||||
/// accumulating context (messages, tools, instructions) from each.
|
||||
/// </summary>
|
||||
private async Task<(IEnumerable<ChatMessage> Messages, ChatOptions? Options)> InvokeProvidersAsync(
|
||||
AgentRunContext runContext,
|
||||
IEnumerable<ChatMessage> messages,
|
||||
ChatOptions? options,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
var aiContext = new AIContext
|
||||
{
|
||||
Instructions = options?.Instructions,
|
||||
Messages = messages,
|
||||
Tools = options?.Tools
|
||||
};
|
||||
|
||||
foreach (var provider in this._providers)
|
||||
{
|
||||
var invokingContext = new AIContextProvider.InvokingContext(runContext.Agent, runContext.Session, aiContext);
|
||||
aiContext = await provider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
// Materialize the accumulated context back into messages and options.
|
||||
var enrichedMessages = aiContext.Messages ?? [];
|
||||
|
||||
var tools = aiContext.Tools as IList<AITool> ?? aiContext.Tools?.ToList();
|
||||
if (options?.Tools is { Count: > 0 } || tools is { Count: > 0 })
|
||||
{
|
||||
options ??= new();
|
||||
options.Tools = tools;
|
||||
}
|
||||
|
||||
if (options?.Instructions is not null || aiContext.Instructions is not null)
|
||||
{
|
||||
options ??= new();
|
||||
options.Instructions = aiContext.Instructions;
|
||||
}
|
||||
|
||||
return (enrichedMessages, options);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Notifies each provider of a successful invocation.
|
||||
/// </summary>
|
||||
private async Task NotifyProvidersOfSuccessAsync(
|
||||
AgentRunContext runContext,
|
||||
IEnumerable<ChatMessage> requestMessages,
|
||||
IEnumerable<ChatMessage> responseMessages,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
var invokedContext = new AIContextProvider.InvokedContext(runContext.Agent, runContext.Session, requestMessages, responseMessages);
|
||||
|
||||
foreach (var provider in this._providers)
|
||||
{
|
||||
await provider.InvokedAsync(invokedContext, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Notifies each provider of a failed invocation.
|
||||
/// </summary>
|
||||
private async Task NotifyProvidersOfFailureAsync(
|
||||
AgentRunContext runContext,
|
||||
IEnumerable<ChatMessage> requestMessages,
|
||||
Exception exception,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
var invokedContext = new AIContextProvider.InvokedContext(runContext.Agent, runContext.Session, requestMessages, exception);
|
||||
|
||||
foreach (var provider in this._providers)
|
||||
{
|
||||
await provider.InvokedAsync(invokedContext, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
+43
@@ -0,0 +1,43 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using Microsoft.Agents.AI;
|
||||
using Microsoft.Shared.Diagnostics;
|
||||
|
||||
namespace Microsoft.Extensions.AI;
|
||||
|
||||
/// <summary>
|
||||
/// Provides extension methods for adding <see cref="AIContextProvider"/> support to <see cref="ChatClientBuilder"/> instances.
|
||||
/// </summary>
|
||||
public static class AIContextProviderChatClientBuilderExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Adds one or more <see cref="AIContextProvider"/> instances to the chat client pipeline, enabling context enrichment
|
||||
/// (messages, tools, and instructions) for any <see cref="IChatClient"/>.
|
||||
/// </summary>
|
||||
/// <param name="builder">The <see cref="ChatClientBuilder"/> to which the providers will be added.</param>
|
||||
/// <param name="providers">
|
||||
/// The <see cref="AIContextProvider"/> instances to invoke before and after each chat client call.
|
||||
/// Providers are called in sequence, with each receiving the accumulated context from the previous provider.
|
||||
/// </param>
|
||||
/// <returns>The <see cref="ChatClientBuilder"/> with the providers added, enabling method chaining.</returns>
|
||||
/// <exception cref="System.ArgumentNullException"><paramref name="builder"/> or <paramref name="providers"/> is <see langword="null"/>.</exception>
|
||||
/// <exception cref="System.ArgumentException"><paramref name="providers"/> is empty.</exception>
|
||||
/// <remarks>
|
||||
/// <para>
|
||||
/// This method wraps the inner chat client with a decorator that calls each provider's
|
||||
/// <see cref="AIContextProvider.InvokingAsync"/> in sequence before the inner client is called,
|
||||
/// and calls <see cref="AIContextProvider.InvokedAsync"/> on each provider after the inner client completes.
|
||||
/// </para>
|
||||
/// <para>
|
||||
/// The chat client must be used within the context of a running <see cref="AIAgent"/>. The agent and session
|
||||
/// are retrieved from <see cref="AIAgent.CurrentRunContext"/>. An <see cref="System.InvalidOperationException"/>
|
||||
/// is thrown at invocation time if no run context is available.
|
||||
/// </para>
|
||||
/// </remarks>
|
||||
public static ChatClientBuilder UseAIContextProviders(this ChatClientBuilder builder, params AIContextProvider[] providers)
|
||||
{
|
||||
_ = Throw.IfNull(builder);
|
||||
|
||||
return builder.Use(innerClient => new AIContextProviderChatClient(innerClient, providers));
|
||||
}
|
||||
}
|
||||
+430
@@ -0,0 +1,430 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.AI;
|
||||
using Moq;
|
||||
|
||||
namespace Microsoft.Agents.AI.UnitTests;
|
||||
|
||||
/// <summary>
|
||||
/// Unit tests for the <see cref="AIContextProviderChatClient"/> class and
|
||||
/// the <see cref="AIContextProviderChatClientBuilderExtensions.UseAIContextProviders(ChatClientBuilder, AIContextProvider[])"/> builder extension.
|
||||
/// </summary>
|
||||
public class AIContextProviderChatClientTests
|
||||
{
|
||||
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
|
||||
|
||||
#region Constructor Tests
|
||||
|
||||
[Fact]
|
||||
public void Constructor_NullInnerClient_ThrowsArgumentNullException()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TestAIContextProvider("key1");
|
||||
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => new AIContextProviderChatClient(null!, [provider]));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Constructor_NullProviders_ThrowsArgumentNullException()
|
||||
{
|
||||
// Arrange
|
||||
var innerClient = new Mock<IChatClient>().Object;
|
||||
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => new AIContextProviderChatClient(innerClient, null!));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Constructor_EmptyProviders_ThrowsArgumentException()
|
||||
{
|
||||
// Arrange
|
||||
var innerClient = new Mock<IChatClient>().Object;
|
||||
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentException>(() => new AIContextProviderChatClient(innerClient, []));
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region GetResponseAsync Tests
|
||||
|
||||
[Fact]
|
||||
public async Task GetResponseAsync_NoRunContext_ThrowsInvalidOperationExceptionAsync()
|
||||
{
|
||||
// Arrange
|
||||
var innerClient = new Mock<IChatClient>();
|
||||
var provider = new TestAIContextProvider("key1");
|
||||
var chatClient = new AIContextProviderChatClient(innerClient.Object, [provider]);
|
||||
|
||||
// Act & Assert — no AIAgent.CurrentRunContext is set
|
||||
await Assert.ThrowsAsync<InvalidOperationException>(
|
||||
() => chatClient.GetResponseAsync([new ChatMessage(ChatRole.User, "Hello")]));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task GetResponseAsync_SingleProvider_EnrichesMessagesAsync()
|
||||
{
|
||||
// Arrange
|
||||
IEnumerable<ChatMessage>? capturedMessages = null;
|
||||
var innerClient = CreateMockChatClient(
|
||||
onGetResponse: (messages, _, _) =>
|
||||
{
|
||||
capturedMessages = messages;
|
||||
return Task.FromResult(new ChatResponse([new ChatMessage(ChatRole.Assistant, "Response")]));
|
||||
});
|
||||
|
||||
var provider = new TestAIContextProvider("key1", provideMessages: [new ChatMessage(ChatRole.System, "Extra context")]);
|
||||
var chatClient = new AIContextProviderChatClient(innerClient, [provider]);
|
||||
|
||||
// Act — run through an agent so CurrentRunContext is set
|
||||
await RunWithAgentContextAsync(chatClient);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(capturedMessages);
|
||||
var messageList = capturedMessages!.ToList();
|
||||
Assert.Equal(2, messageList.Count);
|
||||
Assert.Equal("Hello", messageList[0].Text);
|
||||
Assert.Contains("Extra context", messageList[1].Text);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task GetResponseAsync_MultipleProviders_CalledInSequenceAsync()
|
||||
{
|
||||
// Arrange
|
||||
IEnumerable<ChatMessage>? capturedMessages = null;
|
||||
var innerClient = CreateMockChatClient(
|
||||
onGetResponse: (messages, _, _) =>
|
||||
{
|
||||
capturedMessages = messages;
|
||||
return Task.FromResult(new ChatResponse([new ChatMessage(ChatRole.Assistant, "Response")]));
|
||||
});
|
||||
|
||||
var provider1 = new TestAIContextProvider("key1", provideMessages: [new ChatMessage(ChatRole.System, "From P1")]);
|
||||
var provider2 = new TestAIContextProvider("key2", provideMessages: [new ChatMessage(ChatRole.System, "From P2")]);
|
||||
var chatClient = new AIContextProviderChatClient(innerClient, [provider1, provider2]);
|
||||
|
||||
// Act
|
||||
await RunWithAgentContextAsync(chatClient);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(capturedMessages);
|
||||
var messageList = capturedMessages!.ToList();
|
||||
Assert.Equal(3, messageList.Count);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task GetResponseAsync_Provider_EnrichesToolsAndInstructionsAsync()
|
||||
{
|
||||
// Arrange
|
||||
ChatOptions? capturedOptions = null;
|
||||
var innerClient = CreateMockChatClient(
|
||||
onGetResponse: (_, options, _) =>
|
||||
{
|
||||
capturedOptions = options;
|
||||
return Task.FromResult(new ChatResponse([new ChatMessage(ChatRole.Assistant, "Response")]));
|
||||
});
|
||||
|
||||
var provider = new TestAIContextProvider("key1", provideInstructions: "Extra instructions", provideTools: [new TestAITool()]);
|
||||
var chatClient = new AIContextProviderChatClient(innerClient, [provider]);
|
||||
|
||||
// Act
|
||||
await RunWithAgentContextAsync(chatClient);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(capturedOptions);
|
||||
Assert.Equal("Extra instructions", capturedOptions!.Instructions);
|
||||
Assert.Single(capturedOptions.Tools!);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task GetResponseAsync_OnSuccess_InvokedAsyncCalledAsync()
|
||||
{
|
||||
// Arrange
|
||||
var innerClient = CreateMockChatClient(
|
||||
onGetResponse: (_, _, _) => Task.FromResult(new ChatResponse([new ChatMessage(ChatRole.Assistant, "Response")])));
|
||||
|
||||
var provider = new TestAIContextProvider("key1");
|
||||
var chatClient = new AIContextProviderChatClient(innerClient, [provider]);
|
||||
|
||||
// Act
|
||||
await RunWithAgentContextAsync(chatClient);
|
||||
|
||||
// Assert
|
||||
Assert.True(provider.InvokedAsyncCalled);
|
||||
Assert.Null(provider.LastInvokedContext!.InvokeException);
|
||||
Assert.NotNull(provider.LastInvokedContext.ResponseMessages);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task GetResponseAsync_OnFailure_InvokedAsyncCalledWithExceptionAsync()
|
||||
{
|
||||
// Arrange
|
||||
var expectedException = new InvalidOperationException("Chat failed");
|
||||
var innerClient = CreateMockChatClient(
|
||||
onGetResponse: (_, _, _) => throw expectedException);
|
||||
|
||||
var provider = new TestAIContextProvider("key1");
|
||||
var chatClient = new AIContextProviderChatClient(innerClient, [provider]);
|
||||
|
||||
// Act & Assert
|
||||
await Assert.ThrowsAsync<InvalidOperationException>(() => RunWithAgentContextAsync(chatClient));
|
||||
|
||||
Assert.True(provider.InvokedAsyncCalled);
|
||||
Assert.Same(expectedException, provider.LastInvokedContext!.InvokeException);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region GetStreamingResponseAsync Tests
|
||||
|
||||
[Fact]
|
||||
public async Task GetStreamingResponseAsync_SingleProvider_EnrichesAndStreamsAsync()
|
||||
{
|
||||
// Arrange
|
||||
IEnumerable<ChatMessage>? capturedMessages = null;
|
||||
var innerClient = CreateMockStreamingChatClient(
|
||||
onGetStreamingResponse: (messages, _, _) =>
|
||||
{
|
||||
capturedMessages = messages;
|
||||
return ToAsyncEnumerableAsync(
|
||||
new ChatResponseUpdate(ChatRole.Assistant, "Part1"),
|
||||
new ChatResponseUpdate(ChatRole.Assistant, "Part2"));
|
||||
});
|
||||
|
||||
var provider = new TestAIContextProvider("key1", provideMessages: [new ChatMessage(ChatRole.System, "Extra context")]);
|
||||
var chatClient = new AIContextProviderChatClient(innerClient, [provider]);
|
||||
|
||||
// Act
|
||||
var updates = new List<ChatResponseUpdate>();
|
||||
await RunStreamingWithAgentContextAsync(chatClient, updates);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(2, updates.Count);
|
||||
Assert.NotNull(capturedMessages);
|
||||
Assert.Equal(2, capturedMessages!.ToList().Count);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task GetStreamingResponseAsync_OnSuccess_InvokedAsyncCalledAsync()
|
||||
{
|
||||
// Arrange
|
||||
var innerClient = CreateMockStreamingChatClient(
|
||||
onGetStreamingResponse: (_, _, _) => ToAsyncEnumerableAsync(
|
||||
new ChatResponseUpdate(ChatRole.Assistant, "Response")));
|
||||
|
||||
var provider = new TestAIContextProvider("key1");
|
||||
var chatClient = new AIContextProviderChatClient(innerClient, [provider]);
|
||||
|
||||
// Act
|
||||
await RunStreamingWithAgentContextAsync(chatClient, []);
|
||||
|
||||
// Assert
|
||||
Assert.True(provider.InvokedAsyncCalled);
|
||||
Assert.Null(provider.LastInvokedContext!.InvokeException);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task GetStreamingResponseAsync_OnFailure_InvokedAsyncCalledWithExceptionAsync()
|
||||
{
|
||||
// Arrange
|
||||
var expectedException = new InvalidOperationException("Stream failed");
|
||||
var innerClient = CreateMockStreamingChatClient(
|
||||
onGetStreamingResponse: (_, _, _) => throw expectedException);
|
||||
|
||||
var provider = new TestAIContextProvider("key1");
|
||||
var chatClient = new AIContextProviderChatClient(innerClient, [provider]);
|
||||
|
||||
// Act & Assert
|
||||
await Assert.ThrowsAsync<InvalidOperationException>(
|
||||
() => RunStreamingWithAgentContextAsync(chatClient, []));
|
||||
|
||||
Assert.True(provider.InvokedAsyncCalled);
|
||||
Assert.Same(expectedException, provider.LastInvokedContext!.InvokeException);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Builder Extension Tests
|
||||
|
||||
[Fact]
|
||||
public void UseExtension_NullBuilder_ThrowsArgumentNullException()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TestAIContextProvider("key1");
|
||||
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() =>
|
||||
AIContextProviderChatClientBuilderExtensions.UseAIContextProviders(null!, provider));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task UseExtension_CreatesWorkingPipelineAsync()
|
||||
{
|
||||
// Arrange
|
||||
IEnumerable<ChatMessage>? capturedMessages = null;
|
||||
var innerClient = CreateMockChatClient(
|
||||
onGetResponse: (messages, _, _) =>
|
||||
{
|
||||
capturedMessages = messages;
|
||||
return Task.FromResult(new ChatResponse([new ChatMessage(ChatRole.Assistant, "Response")]));
|
||||
});
|
||||
|
||||
var provider = new TestAIContextProvider("key1", provideMessages: [new ChatMessage(ChatRole.System, "Pipeline context")]);
|
||||
|
||||
var pipeline = new ChatClientBuilder(innerClient)
|
||||
.UseAIContextProviders(provider)
|
||||
.Build();
|
||||
|
||||
// Act — wrap in an agent to set CurrentRunContext
|
||||
var agent = new TestAIAgent
|
||||
{
|
||||
RunAsyncFunc = async (messages, session, options, ct) =>
|
||||
{
|
||||
var response = await pipeline.GetResponseAsync(messages, cancellationToken: ct);
|
||||
return new AgentResponse(response);
|
||||
}
|
||||
};
|
||||
|
||||
await agent.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(capturedMessages);
|
||||
var messageList = capturedMessages!.ToList();
|
||||
Assert.Equal(2, messageList.Count);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Helpers
|
||||
|
||||
/// <summary>
|
||||
/// Runs a chat client within an agent context so that AIAgent.CurrentRunContext is set.
|
||||
/// </summary>
|
||||
private static async Task RunWithAgentContextAsync(AIContextProviderChatClient chatClient)
|
||||
{
|
||||
var agent = new TestAIAgent
|
||||
{
|
||||
RunAsyncFunc = async (messages, session, options, ct) =>
|
||||
{
|
||||
var response = await chatClient.GetResponseAsync(messages, cancellationToken: ct);
|
||||
return new AgentResponse(response);
|
||||
}
|
||||
};
|
||||
|
||||
await agent.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Runs a streaming chat client within an agent context so that AIAgent.CurrentRunContext is set.
|
||||
/// </summary>
|
||||
private static async Task RunStreamingWithAgentContextAsync(AIContextProviderChatClient chatClient, List<ChatResponseUpdate> updates)
|
||||
{
|
||||
var agent = new TestAIAgent
|
||||
{
|
||||
RunAsyncFunc = async (messages, session, options, ct) =>
|
||||
{
|
||||
await foreach (var update in chatClient.GetStreamingResponseAsync(messages, cancellationToken: ct))
|
||||
{
|
||||
updates.Add(update);
|
||||
}
|
||||
|
||||
return new AgentResponse([new ChatMessage(ChatRole.Assistant, "done")]);
|
||||
}
|
||||
};
|
||||
|
||||
await agent.RunAsync([new ChatMessage(ChatRole.User, "Hello")], s_mockSession);
|
||||
}
|
||||
|
||||
private static IChatClient CreateMockChatClient(
|
||||
Func<IEnumerable<ChatMessage>, ChatOptions?, CancellationToken, Task<ChatResponse>> onGetResponse)
|
||||
{
|
||||
var mock = new Mock<IChatClient>();
|
||||
mock.Setup(c => c.GetResponseAsync(
|
||||
It.IsAny<IEnumerable<ChatMessage>>(),
|
||||
It.IsAny<ChatOptions?>(),
|
||||
It.IsAny<CancellationToken>()))
|
||||
.Returns((IEnumerable<ChatMessage> m, ChatOptions? o, CancellationToken ct) => onGetResponse(m, o, ct));
|
||||
return mock.Object;
|
||||
}
|
||||
|
||||
private static IChatClient CreateMockStreamingChatClient(
|
||||
Func<IEnumerable<ChatMessage>, ChatOptions?, CancellationToken, IAsyncEnumerable<ChatResponseUpdate>> onGetStreamingResponse)
|
||||
{
|
||||
var mock = new Mock<IChatClient>();
|
||||
mock.Setup(c => c.GetStreamingResponseAsync(
|
||||
It.IsAny<IEnumerable<ChatMessage>>(),
|
||||
It.IsAny<ChatOptions?>(),
|
||||
It.IsAny<CancellationToken>()))
|
||||
.Returns((IEnumerable<ChatMessage> m, ChatOptions? o, CancellationToken ct) => onGetStreamingResponse(m, o, ct));
|
||||
return mock.Object;
|
||||
}
|
||||
|
||||
private static async IAsyncEnumerable<ChatResponseUpdate> ToAsyncEnumerableAsync(params ChatResponseUpdate[] updates)
|
||||
{
|
||||
foreach (var update in updates)
|
||||
{
|
||||
yield return update;
|
||||
}
|
||||
|
||||
await Task.CompletedTask;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A test AIContextProvider that provides configurable messages, tools, and instructions.
|
||||
/// </summary>
|
||||
private sealed class TestAIContextProvider : AIContextProvider
|
||||
{
|
||||
private readonly string _stateKey;
|
||||
private readonly IEnumerable<ChatMessage> _provideMessages;
|
||||
private readonly string? _provideInstructions;
|
||||
private readonly IEnumerable<AITool>? _provideTools;
|
||||
|
||||
public bool InvokedAsyncCalled { get; private set; }
|
||||
|
||||
public InvokedContext? LastInvokedContext { get; private set; }
|
||||
|
||||
public override string StateKey => this._stateKey;
|
||||
|
||||
public TestAIContextProvider(
|
||||
string stateKey,
|
||||
IEnumerable<ChatMessage>? provideMessages = null,
|
||||
string? provideInstructions = null,
|
||||
IEnumerable<AITool>? provideTools = null)
|
||||
{
|
||||
this._stateKey = stateKey;
|
||||
this._provideMessages = provideMessages ?? [];
|
||||
this._provideInstructions = provideInstructions;
|
||||
this._provideTools = provideTools;
|
||||
}
|
||||
|
||||
protected override ValueTask<AIContext> ProvideAIContextAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
return new ValueTask<AIContext>(new AIContext
|
||||
{
|
||||
Messages = this._provideMessages,
|
||||
Instructions = this._provideInstructions,
|
||||
Tools = this._provideTools,
|
||||
});
|
||||
}
|
||||
|
||||
protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
this.InvokedAsyncCalled = true;
|
||||
this.LastInvokedContext = context;
|
||||
return default;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A minimal AITool for testing.
|
||||
/// </summary>
|
||||
private sealed class TestAITool : AITool;
|
||||
|
||||
#endregion
|
||||
}
|
||||
+3
-3
@@ -12,7 +12,7 @@ namespace Microsoft.Agents.AI.UnitTests;
|
||||
|
||||
/// <summary>
|
||||
/// Unit tests for the <see cref="MessageAIContextProviderAgent"/> class and
|
||||
/// the <see cref="AIAgentBuilder.Use(MessageAIContextProvider[])"/> builder extension.
|
||||
/// the <see cref="AIAgentBuilder.UseAIContextProviders(MessageAIContextProvider[])"/> builder extension.
|
||||
/// </summary>
|
||||
public class MessageAIContextProviderAgentTests
|
||||
{
|
||||
@@ -355,7 +355,7 @@ public class MessageAIContextProviderAgentTests
|
||||
});
|
||||
|
||||
var pipeline = new AIAgentBuilder(innerAgent)
|
||||
.Use([provider])
|
||||
.UseAIContextProviders([provider])
|
||||
.Build();
|
||||
|
||||
// Act
|
||||
@@ -385,7 +385,7 @@ public class MessageAIContextProviderAgentTests
|
||||
});
|
||||
|
||||
var pipeline = new AIAgentBuilder(innerAgent)
|
||||
.Use([provider1, provider2])
|
||||
.UseAIContextProviders([provider1, provider2])
|
||||
.Build();
|
||||
|
||||
// Act
|
||||
Reference in New Issue
Block a user