.NET: [BREAKING] Provide agent and session to AIContextProvider & ChatHistoryProvider (#3695)

* Add a StateBag to AgentSession and pass Agent and AgentSession to AIContextProvider and ChatHistoryProviders

* Remove statebag code from this branch, to get the refactoring out of the way first

* Apply suggestion from @rogerbarreto

Co-authored-by: Roger Barreto <19890735+rogerbarreto@users.noreply.github.com>

* Apply suggestion from @westey-m

* Apply suggestion from @westey-m

---------

Co-authored-by: Roger Barreto <19890735+rogerbarreto@users.noreply.github.com>
This commit is contained in:
westey
2026-02-05 15:58:41 +00:00
committed by GitHub
Unverified
parent 2b66ca03b2
commit ec82ed15d2
26 changed files with 489 additions and 173 deletions
@@ -55,14 +55,14 @@ namespace SampleApp
}
// Get existing messages from the store
var invokingContext = new ChatHistoryProvider.InvokingContext(messages);
var invokingContext = new ChatHistoryProvider.InvokingContext(this, session, messages);
var storeMessages = await typedSession.ChatHistoryProvider.InvokingAsync(invokingContext, cancellationToken);
// Clone the input messages and turn them into response messages with upper case text.
List<ChatMessage> responseMessages = CloneAndToUpperCase(messages, this.Name).ToList();
// Notify the session of the input and output messages.
var invokedContext = new ChatHistoryProvider.InvokedContext(messages, storeMessages)
var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, messages, storeMessages)
{
ResponseMessages = responseMessages
};
@@ -87,14 +87,14 @@ namespace SampleApp
}
// Get existing messages from the store
var invokingContext = new ChatHistoryProvider.InvokingContext(messages);
var invokingContext = new ChatHistoryProvider.InvokingContext(this, session, messages);
var storeMessages = await typedSession.ChatHistoryProvider.InvokingAsync(invokingContext, cancellationToken);
// Clone the input messages and turn them into response messages with upper case text.
List<ChatMessage> responseMessages = CloneAndToUpperCase(messages, this.Name).ToList();
// Notify the session of the input and output messages.
var invokedContext = new ChatHistoryProvider.InvokedContext(messages, storeMessages)
var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, messages, storeMessages)
{
ResponseMessages = responseMessages
};
@@ -129,13 +129,30 @@ public abstract class AIContextProvider
/// <summary>
/// Initializes a new instance of the <see cref="InvokingContext"/> class with the specified request messages.
/// </summary>
/// <param name="agent">The agent being invoked.</param>
/// <param name="session">The session associated with the agent invocation.</param>
/// <param name="requestMessages">The messages to be used by the agent for this invocation.</param>
/// <exception cref="ArgumentNullException"><paramref name="requestMessages"/> is <see langword="null"/>.</exception>
public InvokingContext(IEnumerable<ChatMessage> requestMessages)
public InvokingContext(
AIAgent agent,
AgentSession? session,
IEnumerable<ChatMessage> requestMessages)
{
this.RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages));
this.Agent = Throw.IfNull(agent);
this.Session = session;
this.RequestMessages = Throw.IfNull(requestMessages);
}
/// <summary>
/// Gets the agent that is being invoked.
/// </summary>
public AIAgent Agent { get; }
/// <summary>
/// Gets the agent session associated with the agent invocation.
/// </summary>
public AgentSession? Session { get; }
/// <summary>
/// Gets the caller provided messages that will be used by the agent for this invocation.
/// </summary>
@@ -158,15 +175,33 @@ public abstract class AIContextProvider
/// <summary>
/// Initializes a new instance of the <see cref="InvokedContext"/> class with the specified request messages.
/// </summary>
/// <param name="agent">The agent being invoked.</param>
/// <param name="session">The session associated with the agent invocation.</param>
/// <param name="requestMessages">The caller provided messages that were used by the agent for this invocation.</param>
/// <param name="aiContextProviderMessages">The messages provided by the <see cref="AIContextProvider"/> for this invocation, if any.</param>
/// <exception cref="ArgumentNullException"><paramref name="requestMessages"/> is <see langword="null"/>.</exception>
public InvokedContext(IEnumerable<ChatMessage> requestMessages, IEnumerable<ChatMessage>? aiContextProviderMessages)
public InvokedContext(
AIAgent agent,
AgentSession? session,
IEnumerable<ChatMessage> requestMessages,
IEnumerable<ChatMessage>? aiContextProviderMessages)
{
this.RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages));
this.Agent = Throw.IfNull(agent);
this.Session = session;
this.RequestMessages = Throw.IfNull(requestMessages);
this.AIContextProviderMessages = aiContextProviderMessages;
}
/// <summary>
/// Gets the agent that is being invoked.
/// </summary>
public AIAgent Agent { get; }
/// <summary>
/// Gets the agent session associated with the agent invocation.
/// </summary>
public AgentSession? Session { get; }
/// <summary>
/// Gets the caller provided messages that were used by the agent for this invocation.
/// </summary>
@@ -143,13 +143,30 @@ public abstract class ChatHistoryProvider
/// <summary>
/// Initializes a new instance of the <see cref="InvokingContext"/> class with the specified request messages.
/// </summary>
/// <param name="agent">The agent being invoked.</param>
/// <param name="session">The session associated with the agent invocation.</param>
/// <param name="requestMessages">The new messages to be used by the agent for this invocation.</param>
/// <exception cref="ArgumentNullException"><paramref name="requestMessages"/> is <see langword="null"/>.</exception>
public InvokingContext(IEnumerable<ChatMessage> requestMessages)
public InvokingContext(
AIAgent agent,
AgentSession? session,
IEnumerable<ChatMessage> requestMessages)
{
this.RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages));
this.Agent = Throw.IfNull(agent);
this.Session = session;
this.RequestMessages = Throw.IfNull(requestMessages);
}
/// <summary>
/// Gets the agent that is being invoked.
/// </summary>
public AIAgent Agent { get; }
/// <summary>
/// Gets the agent session associated with the agent invocation.
/// </summary>
public AgentSession? Session { get; }
/// <summary>
/// Gets the caller provided messages that will be used by the agent for this invocation.
/// </summary>
@@ -172,15 +189,33 @@ public abstract class ChatHistoryProvider
/// <summary>
/// Initializes a new instance of the <see cref="InvokedContext"/> class with the specified request messages.
/// </summary>
/// <param name="agent">The agent being invoked.</param>
/// <param name="session">The session associated with the agent invocation.</param>
/// <param name="requestMessages">The caller provided messages that were used by the agent for this invocation.</param>
/// <param name="chatHistoryProviderMessages">The messages retrieved from the <see cref="ChatHistoryProvider"/> for this invocation.</param>
/// <exception cref="ArgumentNullException"><paramref name="requestMessages"/> is <see langword="null"/>.</exception>
public InvokedContext(IEnumerable<ChatMessage> requestMessages, IEnumerable<ChatMessage>? chatHistoryProviderMessages)
public InvokedContext(
AIAgent agent,
AgentSession? session,
IEnumerable<ChatMessage> requestMessages,
IEnumerable<ChatMessage>? chatHistoryProviderMessages)
{
this.Agent = Throw.IfNull(agent);
this.Session = session;
this.RequestMessages = Throw.IfNull(requestMessages);
this.ChatHistoryProviderMessages = chatHistoryProviderMessages;
}
/// <summary>
/// Gets the agent that is being invoked.
/// </summary>
public AIAgent Agent { get; }
/// <summary>
/// Gets the agent session associated with the agent invocation.
/// </summary>
public AgentSession? Session { get; }
/// <summary>
/// Gets the caller provided messages that were used by the agent for this invocation.
/// </summary>
@@ -231,8 +231,8 @@ public sealed partial class ChatClientAgent : AIAgent
}
catch (Exception ex)
{
await NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false);
await NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false);
await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
throw;
}
@@ -246,8 +246,8 @@ public sealed partial class ChatClientAgent : AIAgent
}
catch (Exception ex)
{
await NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false);
await NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false);
await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
throw;
}
@@ -273,8 +273,8 @@ public sealed partial class ChatClientAgent : AIAgent
}
catch (Exception ex)
{
await NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false);
await NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false);
await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
throw;
}
}
@@ -286,10 +286,10 @@ public sealed partial class ChatClientAgent : AIAgent
await this.UpdateSessionWithTypeAndConversationIdAsync(safeSession, chatResponse.ConversationId, cancellationToken).ConfigureAwait(false);
// To avoid inconsistent state we only notify the session of the input messages if no error occurs after the initial request.
await NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false);
await this.NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false);
// Notify the AIContextProvider of all new messages.
await NotifyAIContextProviderOfSuccessAsync(safeSession, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
await this.NotifyAIContextProviderOfSuccessAsync(safeSession, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
}
/// <inheritdoc/>
@@ -455,8 +455,8 @@ public sealed partial class ChatClientAgent : AIAgent
}
catch (Exception ex)
{
await NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, inputMessages, chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false);
await NotifyAIContextProviderOfFailureAsync(safeSession, ex, inputMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, inputMessages, chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false);
await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, inputMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
throw;
}
@@ -473,10 +473,10 @@ public sealed partial class ChatClientAgent : AIAgent
}
// Only notify the session of new messages if the chatResponse was successful to avoid inconsistent message state in the session.
await NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, inputMessages, chatHistoryProviderMessages, aiContextProviderMessages, chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false);
await this.NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, inputMessages, chatHistoryProviderMessages, aiContextProviderMessages, chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false);
// Notify the AIContextProvider of all new messages.
await NotifyAIContextProviderOfSuccessAsync(safeSession, inputMessages, aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
await this.NotifyAIContextProviderOfSuccessAsync(safeSession, inputMessages, aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
var agentResponse = agentResponseFactoryFunc(chatResponse);
@@ -488,7 +488,7 @@ public sealed partial class ChatClientAgent : AIAgent
/// <summary>
/// Notify the <see cref="AIContextProvider"/> when an agent run succeeded, if there is an <see cref="AIContextProvider"/>.
/// </summary>
private static async Task NotifyAIContextProviderOfSuccessAsync(
private async Task NotifyAIContextProviderOfSuccessAsync(
ChatClientAgentSession session,
IEnumerable<ChatMessage> inputMessages,
IList<ChatMessage>? aiContextProviderMessages,
@@ -497,7 +497,7 @@ public sealed partial class ChatClientAgent : AIAgent
{
if (session.AIContextProvider is not null)
{
await session.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProviderMessages) { ResponseMessages = responseMessages },
await session.AIContextProvider.InvokedAsync(new(this, session, inputMessages, aiContextProviderMessages) { ResponseMessages = responseMessages },
cancellationToken).ConfigureAwait(false);
}
}
@@ -505,7 +505,7 @@ public sealed partial class ChatClientAgent : AIAgent
/// <summary>
/// Notify the <see cref="AIContextProvider"/> of any failure during an agent run, if there is an <see cref="AIContextProvider"/>.
/// </summary>
private static async Task NotifyAIContextProviderOfFailureAsync(
private async Task NotifyAIContextProviderOfFailureAsync(
ChatClientAgentSession session,
Exception ex,
IEnumerable<ChatMessage> inputMessages,
@@ -514,7 +514,7 @@ public sealed partial class ChatClientAgent : AIAgent
{
if (session.AIContextProvider is not null)
{
await session.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProviderMessages) { InvokeException = ex },
await session.AIContextProvider.InvokedAsync(new(this, session, inputMessages, aiContextProviderMessages) { InvokeException = ex },
cancellationToken).ConfigureAwait(false);
}
}
@@ -726,7 +726,7 @@ public sealed partial class ChatClientAgent : AIAgent
// Add any existing messages from the session to the messages to be sent to the chat client.
if (chatHistoryProvider is not null)
{
var invokingContext = new ChatHistoryProvider.InvokingContext(inputMessages);
var invokingContext = new ChatHistoryProvider.InvokingContext(this, typedSession, inputMessages);
var providerMessages = await chatHistoryProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false);
inputMessagesForChatClient.AddRange(providerMessages);
chatHistoryProviderMessages = providerMessages as IList<ChatMessage> ?? providerMessages.ToList();
@@ -739,7 +739,7 @@ public sealed partial class ChatClientAgent : AIAgent
// messages and options with the additional context.
if (typedSession.AIContextProvider is not null)
{
var invokingContext = new AIContextProvider.InvokingContext(inputMessages);
var invokingContext = new AIContextProvider.InvokingContext(this, typedSession, inputMessages);
var aiContext = await typedSession.AIContextProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false);
if (aiContext.Messages is { Count: > 0 })
{
@@ -812,7 +812,7 @@ public sealed partial class ChatClientAgent : AIAgent
}
}
private static Task NotifyChatHistoryProviderOfFailureAsync(
private Task NotifyChatHistoryProviderOfFailureAsync(
ChatClientAgentSession session,
Exception ex,
IEnumerable<ChatMessage> requestMessages,
@@ -827,7 +827,7 @@ public sealed partial class ChatClientAgent : AIAgent
// If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages.
if (provider is not null)
{
var invokedContext = new ChatHistoryProvider.InvokedContext(requestMessages, chatHistoryProviderMessages!)
var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, requestMessages, chatHistoryProviderMessages!)
{
AIContextProviderMessages = aiContextProviderMessages,
InvokeException = ex
@@ -839,7 +839,7 @@ public sealed partial class ChatClientAgent : AIAgent
return Task.CompletedTask;
}
private static Task NotifyChatHistoryProviderOfNewMessagesAsync(
private Task NotifyChatHistoryProviderOfNewMessagesAsync(
ChatClientAgentSession session,
IEnumerable<ChatMessage> requestMessages,
IEnumerable<ChatMessage>? chatHistoryProviderMessages,
@@ -854,7 +854,7 @@ public sealed partial class ChatClientAgent : AIAgent
// If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages.
if (provider is not null)
{
var invokedContext = new ChatHistoryProvider.InvokedContext(requestMessages, chatHistoryProviderMessages!)
var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, requestMessages, chatHistoryProviderMessages!)
{
AIContextProviderMessages = aiContextProviderMessages,
ResponseMessages = responseMessages
@@ -15,7 +15,7 @@ public interface IAgentFixture : IAsyncLifetime
{
AIAgent Agent { get; }
Task<List<ChatMessage>> GetChatHistoryAsync(AgentSession session);
Task<List<ChatMessage>> GetChatHistoryAsync(AIAgent agent, AgentSession session);
Task DeleteSessionAsync(AgentSession session);
}
@@ -106,7 +106,7 @@ public abstract class RunStreamingTests<TAgentFixture>(Func<TAgentFixture> creat
Assert.Contains("Paris", response1Text);
Assert.Contains("Vienna", response2Text);
var chatHistory = await this.Fixture.GetChatHistoryAsync(session);
var chatHistory = await this.Fixture.GetChatHistoryAsync(agent, session);
Assert.Equal(4, chatHistory.Count);
Assert.Equal(2, chatHistory.Count(x => x.Role == ChatRole.User));
Assert.Equal(2, chatHistory.Count(x => x.Role == ChatRole.Assistant));
@@ -111,7 +111,7 @@ public abstract class RunTests<TAgentFixture>(Func<TAgentFixture> createAgentFix
Assert.Contains("Paris", result1.Text);
Assert.Contains("Vienna", result2.Text);
var chatHistory = await this.Fixture.GetChatHistoryAsync(session);
var chatHistory = await this.Fixture.GetChatHistoryAsync(agent, session);
Assert.Equal(4, chatHistory.Count);
Assert.Equal(2, chatHistory.Count(x => x.Role == ChatRole.User));
Assert.Equal(2, chatHistory.Count(x => x.Role == ChatRole.Assistant));
@@ -35,7 +35,7 @@ public class AnthropicChatCompletionFixture : IChatClientAgentFixture
public IChatClient ChatClient => this._agent.ChatClient;
public async Task<List<ChatMessage>> GetChatHistoryAsync(AgentSession session)
public async Task<List<ChatMessage>> GetChatHistoryAsync(AIAgent agent, AgentSession session)
{
var typedSession = (ChatClientAgentSession)session;
@@ -44,7 +44,7 @@ public class AnthropicChatCompletionFixture : IChatClientAgentFixture
return [];
}
return (await typedSession.ChatHistoryProvider.InvokingAsync(new([]))).ToList();
return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList();
}
public Task<ChatClientAgent> CreateChatClientAgentAsync(
@@ -33,7 +33,7 @@ public class AIProjectClientFixture : IChatClientAgentFixture
return response.Value.Id;
}
public async Task<List<ChatMessage>> GetChatHistoryAsync(AgentSession session)
public async Task<List<ChatMessage>> GetChatHistoryAsync(AIAgent agent, AgentSession session)
{
var chatClientSession = (ChatClientAgentSession)session;
@@ -53,7 +53,7 @@ public class AIProjectClientFixture : IChatClientAgentFixture
return [];
}
return (await chatClientSession.ChatHistoryProvider.InvokingAsync(new([]))).ToList();
return (await chatClientSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList();
}
private async Task<List<ChatMessage>> GetChatHistoryFromResponsesChainAsync(string conversationId)
@@ -24,7 +24,7 @@ public class AzureAIAgentsPersistentFixture : IChatClientAgentFixture
public AIAgent Agent => this._agent;
public async Task<List<ChatMessage>> GetChatHistoryAsync(AgentSession session)
public async Task<List<ChatMessage>> GetChatHistoryAsync(AIAgent agent, AgentSession session)
{
List<ChatMessage> messages = [];
var typedSession = (ChatClientAgentSession)session;
@@ -20,7 +20,7 @@ public class CopilotStudioFixture : IAgentFixture
{
public AIAgent Agent { get; private set; } = null!;
public Task<List<ChatMessage>> GetChatHistoryAsync(AgentSession session) =>
public Task<List<ChatMessage>> GetChatHistoryAsync(AIAgent agent, AgentSession session) =>
throw new NotSupportedException("CopilotStudio doesn't allow retrieval of chat history.");
public Task DeleteSessionAsync(AgentSession session) =>
@@ -6,17 +6,21 @@ using System.Collections.ObjectModel;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Moq;
namespace Microsoft.Agents.AI.Abstractions.UnitTests;
public class AIContextProviderTests
{
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
[Fact]
public async Task InvokedAsync_ReturnsCompletedTaskAsync()
{
var provider = new TestAIContextProvider();
var messages = new ReadOnlyCollection<ChatMessage>([]);
var task = provider.InvokedAsync(new(messages, aiContextProviderMessages: null));
var task = provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null));
Assert.Equal(default, task);
}
@@ -31,13 +35,13 @@ public class AIContextProviderTests
[Fact]
public void InvokingContext_Constructor_ThrowsForNullMessages()
{
Assert.Throws<ArgumentNullException>(() => new AIContextProvider.InvokingContext(null!));
Assert.Throws<ArgumentNullException>(() => new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, null!));
}
[Fact]
public void InvokedContext_Constructor_ThrowsForNullMessages()
{
Assert.Throws<ArgumentNullException>(() => new AIContextProvider.InvokedContext(null!, aiContextProviderMessages: null));
Assert.Throws<ArgumentNullException>(() => new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, null!, aiContextProviderMessages: null));
}
#region GetService Method Tests
@@ -163,7 +167,7 @@ public class AIContextProviderTests
{
// Arrange
var messages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
var context = new AIContextProvider.InvokingContext(messages);
var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, messages);
// Act & Assert
Assert.Throws<ArgumentNullException>(() => context.RequestMessages = null!);
@@ -175,7 +179,7 @@ public class AIContextProviderTests
// Arrange
var initialMessages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
var newMessages = new List<ChatMessage> { new(ChatRole.User, "New message") };
var context = new AIContextProvider.InvokingContext(initialMessages);
var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, initialMessages);
// Act
context.RequestMessages = newMessages;
@@ -184,6 +188,55 @@ public class AIContextProviderTests
Assert.Same(newMessages, context.RequestMessages);
}
[Fact]
public void InvokingContext_Agent_ReturnsConstructorValue()
{
// Arrange
var messages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
// Act
var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, messages);
// Assert
Assert.Same(s_mockAgent, context.Agent);
}
[Fact]
public void InvokingContext_Session_ReturnsConstructorValue()
{
// Arrange
var messages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
// Act
var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, messages);
// Assert
Assert.Same(s_mockSession, context.Session);
}
[Fact]
public void InvokingContext_Session_CanBeNull()
{
// Arrange
var messages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
// Act
var context = new AIContextProvider.InvokingContext(s_mockAgent, null, messages);
// Assert
Assert.Null(context.Session);
}
[Fact]
public void InvokingContext_Constructor_ThrowsForNullAgent()
{
// Arrange
var messages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
// Act & Assert
Assert.Throws<ArgumentNullException>(() => new AIContextProvider.InvokingContext(null!, s_mockSession, messages));
}
#endregion
#region InvokedContext Tests
@@ -193,7 +246,7 @@ public class AIContextProviderTests
{
// Arrange
var messages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
var context = new AIContextProvider.InvokedContext(messages, aiContextProviderMessages: null);
var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null);
// Act & Assert
Assert.Throws<ArgumentNullException>(() => context.RequestMessages = null!);
@@ -205,7 +258,7 @@ public class AIContextProviderTests
// Arrange
var initialMessages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
var newMessages = new List<ChatMessage> { new(ChatRole.User, "New message") };
var context = new AIContextProvider.InvokedContext(initialMessages, aiContextProviderMessages: null);
var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, initialMessages, aiContextProviderMessages: null);
// Act
context.RequestMessages = newMessages;
@@ -220,7 +273,7 @@ public class AIContextProviderTests
// Arrange
var requestMessages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
var aiContextMessages = new List<ChatMessage> { new(ChatRole.System, "AI context message") };
var context = new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null);
var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null);
// Act
context.AIContextProviderMessages = aiContextMessages;
@@ -235,7 +288,7 @@ public class AIContextProviderTests
// Arrange
var requestMessages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
var responseMessages = new List<ChatMessage> { new(ChatRole.Assistant, "Response message") };
var context = new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null);
var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null);
// Act
context.ResponseMessages = responseMessages;
@@ -250,7 +303,7 @@ public class AIContextProviderTests
// Arrange
var requestMessages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
var exception = new InvalidOperationException("Test exception");
var context = new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null);
var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null);
// Act
context.InvokeException = exception;
@@ -259,6 +312,55 @@ public class AIContextProviderTests
Assert.Same(exception, context.InvokeException);
}
[Fact]
public void InvokedContext_Agent_ReturnsConstructorValue()
{
// Arrange
var requestMessages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
// Act
var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null);
// Assert
Assert.Same(s_mockAgent, context.Agent);
}
[Fact]
public void InvokedContext_Session_ReturnsConstructorValue()
{
// Arrange
var requestMessages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
// Act
var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null);
// Assert
Assert.Same(s_mockSession, context.Session);
}
[Fact]
public void InvokedContext_Session_CanBeNull()
{
// Arrange
var requestMessages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
// Act
var context = new AIContextProvider.InvokedContext(s_mockAgent, null, requestMessages, aiContextProviderMessages: null);
// Assert
Assert.Null(context.Session);
}
[Fact]
public void InvokedContext_Constructor_ThrowsForNullAgent()
{
// Arrange
var requestMessages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
// Act & Assert
Assert.Throws<ArgumentNullException>(() => new AIContextProvider.InvokedContext(null!, s_mockSession, requestMessages, aiContextProviderMessages: null));
}
#endregion
private sealed class TestAIContextProvider : AIContextProvider
@@ -14,6 +14,9 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests;
/// </summary>
public sealed class ChatHistoryProviderExtensionsTests
{
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
[Fact]
public void WithMessageFilters_ReturnsChatHistoryProviderMessageFilter()
{
@@ -35,7 +38,7 @@ public sealed class ChatHistoryProviderExtensionsTests
// Arrange
Mock<ChatHistoryProvider> providerMock = new();
List<ChatMessage> innerMessages = [new(ChatRole.User, "Hello"), new(ChatRole.Assistant, "Hi")];
ChatHistoryProvider.InvokingContext context = new([new ChatMessage(ChatRole.User, "Test")]);
ChatHistoryProvider.InvokingContext context = new(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]);
providerMock
.Setup(p => p.InvokingAsync(context, It.IsAny<CancellationToken>()))
@@ -59,7 +62,7 @@ public sealed class ChatHistoryProviderExtensionsTests
Mock<ChatHistoryProvider> providerMock = new();
List<ChatMessage> requestMessages = [new(ChatRole.User, "Hello")];
List<ChatMessage> chatHistoryProviderMessages = [new(ChatRole.System, "System")];
ChatHistoryProvider.InvokedContext context = new(requestMessages, chatHistoryProviderMessages)
ChatHistoryProvider.InvokedContext context = new(s_mockAgent, s_mockSession, requestMessages, chatHistoryProviderMessages)
{
ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Response")]
};
@@ -106,7 +109,7 @@ public sealed class ChatHistoryProviderExtensionsTests
List<ChatMessage> requestMessages = [new(ChatRole.User, "Hello")];
List<ChatMessage> chatHistoryProviderMessages = [new(ChatRole.System, "System")];
List<ChatMessage> aiContextProviderMessages = [new(ChatRole.System, "Context")];
ChatHistoryProvider.InvokedContext context = new(requestMessages, chatHistoryProviderMessages)
ChatHistoryProvider.InvokedContext context = new(s_mockAgent, s_mockSession, requestMessages, chatHistoryProviderMessages)
{
AIContextProviderMessages = aiContextProviderMessages
};
@@ -16,6 +16,9 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests;
/// </summary>
public sealed class ChatHistoryProviderMessageFilterTests
{
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
[Fact]
public void Constructor_WithNullInnerProvider_ThrowsArgumentNullException()
{
@@ -59,7 +62,7 @@ public sealed class ChatHistoryProviderMessageFilterTests
new(ChatRole.User, "Hello"),
new(ChatRole.Assistant, "Hi there!")
};
var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Test")]);
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]);
innerProviderMock
.Setup(s => s.InvokingAsync(context, It.IsAny<CancellationToken>()))
@@ -88,7 +91,7 @@ public sealed class ChatHistoryProviderMessageFilterTests
new(ChatRole.Assistant, "Hi there!"),
new(ChatRole.User, "How are you?")
};
var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Test")]);
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]);
innerProviderMock
.Setup(s => s.InvokingAsync(context, It.IsAny<CancellationToken>()))
@@ -118,7 +121,7 @@ public sealed class ChatHistoryProviderMessageFilterTests
new(ChatRole.User, "Hello"),
new(ChatRole.Assistant, "Hi there!")
};
var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Test")]);
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]);
innerProviderMock
.Setup(s => s.InvokingAsync(context, It.IsAny<CancellationToken>()))
@@ -147,7 +150,7 @@ public sealed class ChatHistoryProviderMessageFilterTests
var requestMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
var chatHistoryProviderMessages = new List<ChatMessage> { new(ChatRole.System, "System") };
var responseMessages = new List<ChatMessage> { new(ChatRole.Assistant, "Response") };
var context = new ChatHistoryProvider.InvokedContext(requestMessages, chatHistoryProviderMessages)
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, chatHistoryProviderMessages)
{
ResponseMessages = responseMessages
};
@@ -162,7 +165,7 @@ public sealed class ChatHistoryProviderMessageFilterTests
ChatHistoryProvider.InvokedContext InvokedFilter(ChatHistoryProvider.InvokedContext ctx)
{
var modifiedRequestMessages = ctx.RequestMessages.Select(m => new ChatMessage(m.Role, $"[FILTERED] {m.Text}")).ToList();
return new ChatHistoryProvider.InvokedContext(modifiedRequestMessages, ctx.ChatHistoryProviderMessages)
return new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, modifiedRequestMessages, ctx.ChatHistoryProviderMessages)
{
ResponseMessages = ctx.ResponseMessages,
AIContextProviderMessages = ctx.AIContextProviderMessages,
@@ -6,6 +6,7 @@ using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Moq;
namespace Microsoft.Agents.AI.Abstractions.UnitTests;
@@ -14,6 +15,9 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests;
/// </summary>
public class ChatHistoryProviderTests
{
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
#region GetService Method Tests
[Fact]
@@ -82,7 +86,7 @@ public class ChatHistoryProviderTests
public void InvokingContext_Constructor_ThrowsForNullMessages()
{
// Arrange & Act & Assert
Assert.Throws<ArgumentNullException>(() => new ChatHistoryProvider.InvokingContext(null!));
Assert.Throws<ArgumentNullException>(() => new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, null!));
}
[Fact]
@@ -90,7 +94,7 @@ public class ChatHistoryProviderTests
{
// Arrange
var messages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
var context = new ChatHistoryProvider.InvokingContext(messages);
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, messages);
// Act & Assert
Assert.Throws<ArgumentNullException>(() => context.RequestMessages = null!);
@@ -102,7 +106,7 @@ public class ChatHistoryProviderTests
// Arrange
var initialMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
var newMessages = new List<ChatMessage> { new(ChatRole.User, "New message") };
var context = new ChatHistoryProvider.InvokingContext(initialMessages);
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, initialMessages);
// Act
context.RequestMessages = newMessages;
@@ -111,6 +115,55 @@ public class ChatHistoryProviderTests
Assert.Same(newMessages, context.RequestMessages);
}
[Fact]
public void InvokingContext_Agent_ReturnsConstructorValue()
{
// Arrange
var messages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
// Act
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, messages);
// Assert
Assert.Same(s_mockAgent, context.Agent);
}
[Fact]
public void InvokingContext_Session_ReturnsConstructorValue()
{
// Arrange
var messages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
// Act
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, messages);
// Assert
Assert.Same(s_mockSession, context.Session);
}
[Fact]
public void InvokingContext_Session_CanBeNull()
{
// Arrange
var messages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
// Act
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, null, messages);
// Assert
Assert.Null(context.Session);
}
[Fact]
public void InvokingContext_Constructor_ThrowsForNullAgent()
{
// Arrange
var messages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
// Act & Assert
Assert.Throws<ArgumentNullException>(() => new ChatHistoryProvider.InvokingContext(null!, s_mockSession, messages));
}
#endregion
#region InvokedContext Tests
@@ -119,7 +172,7 @@ public class ChatHistoryProviderTests
public void InvokedContext_Constructor_ThrowsForNullRequestMessages()
{
// Arrange & Act & Assert
Assert.Throws<ArgumentNullException>(() => new ChatHistoryProvider.InvokedContext(null!, []));
Assert.Throws<ArgumentNullException>(() => new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, null!, []));
}
[Fact]
@@ -127,7 +180,7 @@ public class ChatHistoryProviderTests
{
// Arrange
var requestMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
var context = new ChatHistoryProvider.InvokedContext(requestMessages, []);
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []);
// Act & Assert
Assert.Throws<ArgumentNullException>(() => context.RequestMessages = null!);
@@ -139,7 +192,7 @@ public class ChatHistoryProviderTests
// Arrange
var initialMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
var newMessages = new List<ChatMessage> { new(ChatRole.User, "New message") };
var context = new ChatHistoryProvider.InvokedContext(initialMessages, []);
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, initialMessages, []);
// Act
context.RequestMessages = newMessages;
@@ -154,7 +207,7 @@ public class ChatHistoryProviderTests
// Arrange
var requestMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
var newProviderMessages = new List<ChatMessage> { new(ChatRole.System, "System message") };
var context = new ChatHistoryProvider.InvokedContext(requestMessages, []);
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []);
// Act
context.ChatHistoryProviderMessages = newProviderMessages;
@@ -169,7 +222,7 @@ public class ChatHistoryProviderTests
// Arrange
var requestMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
var aiContextMessages = new List<ChatMessage> { new(ChatRole.System, "AI context message") };
var context = new ChatHistoryProvider.InvokedContext(requestMessages, []);
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []);
// Act
context.AIContextProviderMessages = aiContextMessages;
@@ -184,7 +237,7 @@ public class ChatHistoryProviderTests
// Arrange
var requestMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
var responseMessages = new List<ChatMessage> { new(ChatRole.Assistant, "Response message") };
var context = new ChatHistoryProvider.InvokedContext(requestMessages, []);
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []);
// Act
context.ResponseMessages = responseMessages;
@@ -199,7 +252,7 @@ public class ChatHistoryProviderTests
// Arrange
var requestMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
var exception = new InvalidOperationException("Test exception");
var context = new ChatHistoryProvider.InvokedContext(requestMessages, []);
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []);
// Act
context.InvokeException = exception;
@@ -208,6 +261,55 @@ public class ChatHistoryProviderTests
Assert.Same(exception, context.InvokeException);
}
[Fact]
public void InvokedContext_Agent_ReturnsConstructorValue()
{
// Arrange
var requestMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
// Act
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []);
// Assert
Assert.Same(s_mockAgent, context.Agent);
}
[Fact]
public void InvokedContext_Session_ReturnsConstructorValue()
{
// Arrange
var requestMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
// Act
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []);
// Assert
Assert.Same(s_mockSession, context.Session);
}
[Fact]
public void InvokedContext_Session_CanBeNull()
{
// Arrange
var requestMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
// Act
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, null, requestMessages, []);
// Assert
Assert.Null(context.Session);
}
[Fact]
public void InvokedContext_Constructor_ThrowsForNullAgent()
{
// Arrange
var requestMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
// Act & Assert
Assert.Throws<ArgumentNullException>(() => new ChatHistoryProvider.InvokedContext(null!, s_mockSession, requestMessages, []));
}
#endregion
private sealed class TestChatHistoryProvider : ChatHistoryProvider
@@ -18,6 +18,9 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests;
/// </summary>
public class InMemoryChatHistoryProviderTests
{
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
[Fact]
public void Constructor_Throws_ForNullReducer() =>
// Arrange & Act & Assert
@@ -68,7 +71,7 @@ public class InMemoryChatHistoryProviderTests
var provider = new InMemoryChatHistoryProvider();
provider.Add(providerMessages[0]);
var context = new ChatHistoryProvider.InvokedContext(requestMessages, providerMessages)
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, providerMessages)
{
AIContextProviderMessages = aiContextProviderMessages,
ResponseMessages = responseMessages
@@ -87,7 +90,7 @@ public class InMemoryChatHistoryProviderTests
{
var provider = new InMemoryChatHistoryProvider();
var context = new ChatHistoryProvider.InvokedContext([], []);
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [], []);
await provider.InvokedAsync(context, CancellationToken.None);
Assert.Empty(provider);
@@ -102,7 +105,7 @@ public class InMemoryChatHistoryProviderTests
new ChatMessage(ChatRole.Assistant, "Test2")
};
var context = new ChatHistoryProvider.InvokingContext([]);
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
var result = (await provider.InvokingAsync(context, CancellationToken.None)).ToList();
Assert.Equal(2, result.Count);
@@ -183,7 +186,7 @@ public class InMemoryChatHistoryProviderTests
var provider = new InMemoryChatHistoryProvider();
var messages = new List<ChatMessage>();
var context = new ChatHistoryProvider.InvokedContext(messages, []);
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []);
await provider.InvokedAsync(context, CancellationToken.None);
Assert.Empty(provider);
@@ -520,7 +523,7 @@ public class InMemoryChatHistoryProviderTests
var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded);
// Act
var context = new ChatHistoryProvider.InvokedContext(originalMessages, []);
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, originalMessages, []);
await provider.InvokedAsync(context, CancellationToken.None);
// Assert
@@ -556,7 +559,7 @@ public class InMemoryChatHistoryProviderTests
}
// Act
var invokingContext = new ChatHistoryProvider.InvokingContext(Array.Empty<ChatMessage>());
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, Array.Empty<ChatMessage>());
var result = (await provider.InvokingAsync(invokingContext, CancellationToken.None)).ToList();
// Assert
@@ -579,7 +582,7 @@ public class InMemoryChatHistoryProviderTests
var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.BeforeMessagesRetrieval);
// Act
var context = new ChatHistoryProvider.InvokedContext(originalMessages, []);
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, originalMessages, []);
await provider.InvokedAsync(context, CancellationToken.None);
// Assert
@@ -605,7 +608,7 @@ public class InMemoryChatHistoryProviderTests
};
// Act
var invokingContext = new ChatHistoryProvider.InvokingContext(Array.Empty<ChatMessage>());
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, Array.Empty<ChatMessage>());
var result = (await provider.InvokingAsync(invokingContext, CancellationToken.None)).ToList();
// Assert
@@ -627,7 +630,7 @@ public class InMemoryChatHistoryProviderTests
{
new(ChatRole.Assistant, "Hi there!")
};
var context = new ChatHistoryProvider.InvokedContext(requestMessages, [])
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, [])
{
ResponseMessages = responseMessages,
InvokeException = new InvalidOperationException("Test exception")
@@ -41,6 +41,9 @@ namespace Microsoft.Agents.AI.CosmosNoSql.UnitTests;
[Collection("CosmosDB")]
public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
{
private static readonly AIAgent s_mockAgent = new Moq.Mock<AIAgent>().Object;
private static readonly AgentSession s_mockSession = new Moq.Mock<AgentSession>().Object;
// Cosmos DB Emulator connection settings
private const string EmulatorEndpoint = "https://localhost:8081";
private const string EmulatorKey = "C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==";
@@ -214,7 +217,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversationId);
var message = new ChatMessage(ChatRole.User, "Hello, world!");
var context = new ChatHistoryProvider.InvokedContext([message], [])
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [message], [])
{
ResponseMessages = []
};
@@ -226,7 +229,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
await Task.Delay(100);
// Assert
var invokingContext = new ChatHistoryProvider.InvokingContext([]);
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
var messages = await provider.InvokingAsync(invokingContext);
var messageList = messages.ToList();
@@ -293,7 +296,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
new ChatMessage(ChatRole.Assistant, "Response message")
};
var context = new ChatHistoryProvider.InvokedContext(requestMessages, [])
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, [])
{
AIContextProviderMessages = aiContextProviderMessages,
ResponseMessages = responseMessages
@@ -303,7 +306,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
await provider.InvokedAsync(context);
// Assert
var invokingContext = new ChatHistoryProvider.InvokingContext([]);
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
var retrievedMessages = await provider.InvokingAsync(invokingContext);
var messageList = retrievedMessages.ToList();
Assert.Equal(5, messageList.Count);
@@ -327,7 +330,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, Guid.NewGuid().ToString());
// Act
var invokingContext = new ChatHistoryProvider.InvokingContext([]);
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
var messages = await provider.InvokingAsync(invokingContext);
// Assert
@@ -346,15 +349,15 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
using var store1 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversation1);
using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversation2);
var context1 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message for conversation 1")], []);
var context2 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message for conversation 2")], []);
var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message for conversation 1")], []);
var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message for conversation 2")], []);
await store1.InvokedAsync(context1);
await store2.InvokedAsync(context2);
// Act
var invokingContext1 = new ChatHistoryProvider.InvokingContext([]);
var invokingContext2 = new ChatHistoryProvider.InvokingContext([]);
var invokingContext1 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
var invokingContext2 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
var messages1 = await store1.InvokingAsync(invokingContext1);
var messages2 = await store2.InvokingAsync(invokingContext2);
@@ -391,11 +394,11 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
};
// Act 1: Add messages
var invokedContext = new ChatHistoryProvider.InvokedContext(messages, []);
var invokedContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []);
await originalStore.InvokedAsync(invokedContext);
// Act 2: Verify messages were added
var invokingContext = new ChatHistoryProvider.InvokingContext([]);
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
var retrievedMessages = await originalStore.InvokingAsync(invokingContext);
var retrievedList = retrievedMessages.ToList();
Assert.Equal(5, retrievedList.Count);
@@ -545,7 +548,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId);
var message = new ChatMessage(ChatRole.User, "Hello from hierarchical partitioning!");
var context = new ChatHistoryProvider.InvokedContext([message], []);
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [message], []);
// Act
await provider.InvokedAsync(context);
@@ -554,7 +557,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
await Task.Delay(100);
// Assert
var invokingContext = new ChatHistoryProvider.InvokingContext([]);
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
var messages = await provider.InvokingAsync(invokingContext);
var messageList = messages.ToList();
@@ -602,7 +605,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
new ChatMessage(ChatRole.User, "Third hierarchical message")
};
var context = new ChatHistoryProvider.InvokedContext(messages, []);
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []);
// Act
await provider.InvokedAsync(context);
@@ -611,7 +614,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
await Task.Delay(100);
// Assert
var invokingContext = new ChatHistoryProvider.InvokingContext([]);
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
var retrievedMessages = await provider.InvokingAsync(invokingContext);
var messageList = retrievedMessages.ToList();
@@ -637,8 +640,8 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId2, SessionId);
// Add messages to both stores
var context1 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message from user 1")], []);
var context2 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message from user 2")], []);
var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message from user 1")], []);
var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message from user 2")], []);
await store1.InvokedAsync(context1);
await store2.InvokedAsync(context2);
@@ -647,8 +650,8 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
await Task.Delay(100);
// Act & Assert
var invokingContext1 = new ChatHistoryProvider.InvokingContext([]);
var invokingContext2 = new ChatHistoryProvider.InvokingContext([]);
var invokingContext1 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
var invokingContext2 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
var messages1 = await store1.InvokingAsync(invokingContext1);
var messageList1 = messages1.ToList();
@@ -675,7 +678,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
using var originalStore = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId);
var context = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Test serialization message")], []);
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test serialization message")], []);
await originalStore.InvokedAsync(context);
// Act - Serialize the provider state
@@ -693,7 +696,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
await Task.Delay(100);
// Assert - The deserialized provider should have the same functionality
var invokingContext = new ChatHistoryProvider.InvokingContext([]);
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
var messages = await deserializedStore.InvokingAsync(invokingContext);
var messageList = messages.ToList();
@@ -717,8 +720,8 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
using var hierarchicalProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, "tenant-coexist", "user-coexist", SessionId);
// Add messages to both
var simpleContext = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Simple partitioning message")], []);
var hierarchicalContext = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Hierarchical partitioning message")], []);
var simpleContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Simple partitioning message")], []);
var hierarchicalContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Hierarchical partitioning message")], []);
await simpleProvider.InvokedAsync(simpleContext);
await hierarchicalProvider.InvokedAsync(hierarchicalContext);
@@ -727,7 +730,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
await Task.Delay(100);
// Act & Assert
var invokingContext = new ChatHistoryProvider.InvokingContext([]);
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
var simpleMessages = await simpleProvider.InvokingAsync(invokingContext);
var simpleMessageList = simpleMessages.ToList();
@@ -760,7 +763,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
await Task.Delay(10); // Small delay to ensure different timestamps
}
var context = new ChatHistoryProvider.InvokedContext(messages, []);
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []);
await provider.InvokedAsync(context);
// Wait for eventual consistency
@@ -768,7 +771,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
// Act - Set max to 5 and retrieve
provider.MaxMessagesToRetrieve = 5;
var invokingContext = new ChatHistoryProvider.InvokingContext([]);
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
var retrievedMessages = await provider.InvokingAsync(invokingContext);
var messageList = retrievedMessages.ToList();
@@ -798,14 +801,14 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
messages.Add(new ChatMessage(ChatRole.User, $"Message {i}"));
}
var context = new ChatHistoryProvider.InvokedContext(messages, []);
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []);
await provider.InvokedAsync(context);
// Wait for eventual consistency
await Task.Delay(100);
// Act - No limit set (default null)
var invokingContext = new ChatHistoryProvider.InvokingContext([]);
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
var retrievedMessages = await provider.InvokingAsync(invokingContext);
var messageList = retrievedMessages.ToList();
@@ -18,6 +18,9 @@ public sealed class Mem0ProviderTests : IDisposable
{
private const string SkipReason = "Requires a Mem0 service configured"; // Set to null to enable.
private static readonly AIAgent s_mockAgent = new Moq.Mock<AIAgent>().Object;
private static readonly AgentSession s_mockSession = new Moq.Mock<AgentSession>().Object;
private readonly HttpClient _httpClient;
public Mem0ProviderTests()
@@ -49,14 +52,14 @@ public sealed class Mem0ProviderTests : IDisposable
var sut = new Mem0Provider(this._httpClient, storageScope);
await sut.ClearStoredMemoriesAsync();
var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext([question]));
var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question]));
Assert.DoesNotContain("Caoimhe", ctxBefore.Messages?[0].Text ?? string.Empty);
// Act
await sut.InvokedAsync(new AIContextProvider.InvokedContext([input], aiContextProviderMessages: null));
await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [input], aiContextProviderMessages: null));
var ctxAfterAdding = await GetContextWithRetryAsync(sut, question);
await sut.ClearStoredMemoriesAsync();
var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext([question]));
var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question]));
// Assert
Assert.Contains("Caoimhe", ctxAfterAdding.Messages?[0].Text ?? string.Empty);
@@ -73,14 +76,14 @@ public sealed class Mem0ProviderTests : IDisposable
var sut = new Mem0Provider(this._httpClient, storageScope);
await sut.ClearStoredMemoriesAsync();
var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext([question]));
var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question]));
Assert.DoesNotContain("Caoimhe", ctxBefore.Messages?[0].Text ?? string.Empty);
// Act
await sut.InvokedAsync(new AIContextProvider.InvokedContext([assistantIntro], aiContextProviderMessages: null));
await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [assistantIntro], aiContextProviderMessages: null));
var ctxAfterAdding = await GetContextWithRetryAsync(sut, question);
await sut.ClearStoredMemoriesAsync();
var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext([question]));
var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question]));
// Assert
Assert.Contains("Caoimhe", ctxAfterAdding.Messages?[0].Text ?? string.Empty);
@@ -99,13 +102,13 @@ public sealed class Mem0ProviderTests : IDisposable
await sut1.ClearStoredMemoriesAsync();
await sut2.ClearStoredMemoriesAsync();
var ctxBefore1 = await sut1.InvokingAsync(new AIContextProvider.InvokingContext([question]));
var ctxBefore2 = await sut2.InvokingAsync(new AIContextProvider.InvokingContext([question]));
var ctxBefore1 = await sut1.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question]));
var ctxBefore2 = await sut2.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question]));
Assert.DoesNotContain("Caoimhe", ctxBefore1.Messages?[0].Text ?? string.Empty);
Assert.DoesNotContain("Caoimhe", ctxBefore2.Messages?[0].Text ?? string.Empty);
// Act
await sut1.InvokedAsync(new AIContextProvider.InvokedContext([assistantIntro], aiContextProviderMessages: null));
await sut1.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [assistantIntro], aiContextProviderMessages: null));
var ctxAfterAdding1 = await GetContextWithRetryAsync(sut1, question);
var ctxAfterAdding2 = await GetContextWithRetryAsync(sut2, question);
@@ -123,7 +126,7 @@ public sealed class Mem0ProviderTests : IDisposable
AIContext? ctx = null;
for (int i = 0; i < attempts; i++)
{
ctx = await provider.InvokingAsync(new AIContextProvider.InvokingContext([question]), CancellationToken.None);
ctx = await provider.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question]), CancellationToken.None);
var text = ctx.Messages?[0].Text;
if (!string.IsNullOrEmpty(text) && text.IndexOf("Caoimhe", StringComparison.OrdinalIgnoreCase) >= 0)
{
@@ -18,6 +18,9 @@ namespace Microsoft.Agents.AI.Mem0.UnitTests;
/// </summary>
public sealed class Mem0ProviderTests : IDisposable
{
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
private readonly Mock<ILogger<Mem0Provider>> _loggerMock;
private readonly Mock<ILoggerFactory> _loggerFactoryMock;
private readonly RecordingHandler _handler = new();
@@ -96,7 +99,7 @@ public sealed class Mem0ProviderTests : IDisposable
UserId = "user"
};
var sut = new Mem0Provider(this._httpClient, storageScope, options: new() { EnableSensitiveTelemetryData = true }, loggerFactory: this._loggerFactoryMock.Object);
var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "What is my name?")]);
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "What is my name?")]);
// Act
var aiContext = await sut.InvokingAsync(invokingContext);
@@ -161,7 +164,7 @@ public sealed class Mem0ProviderTests : IDisposable
var options = new Mem0ProviderOptions { EnableSensitiveTelemetryData = enableSensitiveTelemetryData };
var sut = new Mem0Provider(this._httpClient, storageScope, options: options, loggerFactory: this._loggerFactoryMock.Object);
var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Who am I?")]);
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Who am I?")]);
// Act
await sut.InvokingAsync(invokingContext, CancellationToken.None);
@@ -215,7 +218,7 @@ public sealed class Mem0ProviderTests : IDisposable
};
// Act
await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages });
await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages });
// Assert
var memoryPosts = this._handler.Requests.Where(r => r.RequestMessage.RequestUri!.AbsolutePath == "/v1/memories/" && r.RequestMessage.Method == HttpMethod.Post).ToList();
@@ -242,7 +245,7 @@ public sealed class Mem0ProviderTests : IDisposable
};
// Act
await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = null, InvokeException = new InvalidOperationException("Request Failed") });
await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = null, InvokeException = new InvalidOperationException("Request Failed") });
// Assert
Assert.Empty(this._handler.Requests);
@@ -268,7 +271,7 @@ public sealed class Mem0ProviderTests : IDisposable
};
// Act
await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages });
await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages });
// Assert
this._loggerMock.Verify(
@@ -318,7 +321,7 @@ public sealed class Mem0ProviderTests : IDisposable
};
// Act
await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages });
await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages });
// Assert
Assert.Equal(expectedLogCount, this._loggerMock.Invocations.Count);
@@ -400,7 +403,7 @@ public sealed class Mem0ProviderTests : IDisposable
// Arrange
var storageScope = new Mem0ProviderScope { ApplicationId = "app" };
var provider = new Mem0Provider(this._httpClient, storageScope, loggerFactory: this._loggerFactoryMock.Object);
var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]);
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]);
// Act
var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None);
@@ -327,4 +327,9 @@ public class ChatClientAgentSessionTests
}
#endregion
internal sealed class Animal
{
public string Name { get; set; } = string.Empty;
}
}
@@ -17,6 +17,9 @@ namespace Microsoft.Agents.AI.UnitTests.Data;
/// </summary>
public sealed class TextSearchProviderTests
{
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
private readonly Mock<ILogger<TextSearchProvider>> _loggerMock;
private readonly Mock<ILoggerFactory> _loggerFactoryMock;
@@ -64,10 +67,12 @@ public sealed class TextSearchProviderTests
var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options, withLogging ? this._loggerFactoryMock.Object : null);
var invokingContext = new AIContextProvider.InvokingContext(
[
new ChatMessage(ChatRole.User, "Sample user question?"),
new ChatMessage(ChatRole.User, "Additional part")
]);
s_mockAgent,
s_mockSession,
[
new ChatMessage(ChatRole.User, "Sample user question?"),
new ChatMessage(ChatRole.User, "Additional part")
]);
// Act
var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None);
@@ -139,7 +144,7 @@ public sealed class TextSearchProviderTests
FunctionToolDescription = overrideDescription
};
var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options);
var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]);
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]);
// Act
var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None);
@@ -158,7 +163,7 @@ public sealed class TextSearchProviderTests
{
// Arrange
var provider = new TextSearchProvider(this.FailingSearchAsync, default, null, loggerFactory: this._loggerFactoryMock.Object);
var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]);
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]);
// Act
var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None);
@@ -251,7 +256,7 @@ public sealed class TextSearchProviderTests
ContextFormatter = r => $"Custom formatted context with {r.Count} results."
};
var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options);
var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]);
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]);
// Act
var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None);
@@ -285,7 +290,7 @@ public sealed class TextSearchProviderTests
ContextFormatter = r => string.Join(",", r.Select(x => ((RawPayload)x.RawRepresentation!).Id))
};
var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options);
var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]);
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]);
// Act
var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None);
@@ -302,7 +307,7 @@ public sealed class TextSearchProviderTests
// Arrange
var options = new TextSearchProviderOptions { SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke };
var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options);
var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]);
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]);
// Act
var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None);
@@ -340,12 +345,14 @@ public sealed class TextSearchProviderTests
new ChatMessage(ChatRole.User, "C"),
new ChatMessage(ChatRole.Assistant, "D"),
};
await provider.InvokedAsync(new(initialMessages, aiContextProviderMessages: null) { InvokeException = new InvalidOperationException("Request Failed") });
await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages, aiContextProviderMessages: null) { InvokeException = new InvalidOperationException("Request Failed") });
var invokingContext = new AIContextProvider.InvokingContext(
[
new ChatMessage(ChatRole.User, "E")
]);
s_mockAgent,
s_mockSession,
[
new ChatMessage(ChatRole.User, "E")
]);
// Act
await provider.InvokingAsync(invokingContext, CancellationToken.None);
@@ -380,12 +387,14 @@ public sealed class TextSearchProviderTests
new ChatMessage(ChatRole.User, "C"),
new ChatMessage(ChatRole.Assistant, "D"),
};
await provider.InvokedAsync(new(initialMessages, aiContextProviderMessages: null));
await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages, aiContextProviderMessages: null));
var invokingContext = new AIContextProvider.InvokingContext(
[
new ChatMessage(ChatRole.User, "E")
]);
s_mockAgent,
s_mockSession,
[
new ChatMessage(ChatRole.User, "E")
]);
// Act
await provider.InvokingAsync(invokingContext, CancellationToken.None);
@@ -414,20 +423,24 @@ public sealed class TextSearchProviderTests
// First memory update (A,B)
await provider.InvokedAsync(new(
[
new ChatMessage(ChatRole.User, "A"),
new ChatMessage(ChatRole.Assistant, "B"),
], aiContextProviderMessages: null));
s_mockAgent,
s_mockSession,
[
new ChatMessage(ChatRole.User, "A"),
new ChatMessage(ChatRole.Assistant, "B"),
], aiContextProviderMessages: null));
// Second memory update (C,D,E)
await provider.InvokedAsync(new(
[
new ChatMessage(ChatRole.User, "C"),
new ChatMessage(ChatRole.Assistant, "D"),
new ChatMessage(ChatRole.User, "E"),
], aiContextProviderMessages: null));
s_mockAgent,
s_mockSession,
[
new ChatMessage(ChatRole.User, "C"),
new ChatMessage(ChatRole.Assistant, "D"),
new ChatMessage(ChatRole.User, "E"),
], aiContextProviderMessages: null));
var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "F")]);
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "F")]);
// Act
await provider.InvokingAsync(invokingContext, CancellationToken.None);
@@ -462,12 +475,14 @@ public sealed class TextSearchProviderTests
new ChatMessage(ChatRole.User, "U2"),
new ChatMessage(ChatRole.Assistant, "A2"),
};
await provider.InvokedAsync(new(initialMessages, null));
await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages, null));
var invokingContext = new AIContextProvider.InvokingContext(
[
new ChatMessage(ChatRole.User, "Question?") // Current request message always appended.
]);
s_mockAgent,
s_mockSession,
[
new ChatMessage(ChatRole.User, "Question?") // Current request message always appended.
]);
// Act
await provider.InvokingAsync(invokingContext, CancellationToken.None);
@@ -518,7 +533,7 @@ public sealed class TextSearchProviderTests
};
// Act
await provider.InvokedAsync(new(messages, aiContextProviderMessages: null)); // Populate recent memory.
await provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null)); // Populate recent memory.
var state = provider.Serialize();
// Assert
@@ -547,7 +562,7 @@ public sealed class TextSearchProviderTests
new ChatMessage(ChatRole.User, "C"),
new ChatMessage(ChatRole.Assistant, "D"),
};
await provider.InvokedAsync(new(messages, aiContextProviderMessages: null));
await provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null));
// Act
var state = provider.Serialize();
@@ -563,7 +578,7 @@ public sealed class TextSearchProviderTests
RecentMessageMemoryLimit = 4
});
var emptyMessages = Array.Empty<ChatMessage>();
await roundTrippedProvider.InvokingAsync(new(emptyMessages), CancellationToken.None); // Trigger search to read memory.
await roundTrippedProvider.InvokingAsync(new(s_mockAgent, s_mockSession, emptyMessages), CancellationToken.None); // Trigger search to read memory.
// Assert
Assert.NotNull(capturedInput);
@@ -588,7 +603,7 @@ public sealed class TextSearchProviderTests
new ChatMessage(ChatRole.Assistant, "L4"),
new ChatMessage(ChatRole.User, "L5"),
};
await initialProvider.InvokedAsync(new(messages, aiContextProviderMessages: null));
await initialProvider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null));
var state = initialProvider.Serialize();
string? capturedInput = null;
@@ -604,7 +619,7 @@ public sealed class TextSearchProviderTests
SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke,
RecentMessageMemoryLimit = 3 // Lower limit
});
await restoredProvider.InvokingAsync(new(Array.Empty<ChatMessage>()), CancellationToken.None);
await restoredProvider.InvokingAsync(new(s_mockAgent, s_mockSession, Array.Empty<ChatMessage>()), CancellationToken.None);
// Assert
Assert.NotNull(capturedInput);
@@ -631,7 +646,7 @@ public sealed class TextSearchProviderTests
RecentMessageMemoryLimit = 3
});
var emptyMessages = Array.Empty<ChatMessage>();
await provider.InvokingAsync(new(emptyMessages), CancellationToken.None);
await provider.InvokingAsync(new(s_mockAgent, s_mockSession, emptyMessages), CancellationToken.None);
// Assert
Assert.NotNull(capturedInput);
@@ -18,6 +18,9 @@ namespace Microsoft.Agents.AI.Memory.UnitTests;
/// </summary>
public class ChatHistoryMemoryProviderTests
{
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
private readonly Mock<ILogger<ChatHistoryMemoryProvider>> _loggerMock;
private readonly Mock<ILoggerFactory> _loggerFactoryMock;
@@ -116,7 +119,7 @@ public class ChatHistoryMemoryProviderTests
var requestMsgWithNulls = new ChatMessage(ChatRole.User, "request text nulls");
var responseMsg = new ChatMessage(ChatRole.Assistant, "response text") { MessageId = "resp-1", AuthorName = "assistant" };
var invokedContext = new AIContextProvider.InvokedContext([requestMsgWithValues, requestMsgWithNulls], aiContextProviderMessages: null)
var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsgWithValues, requestMsgWithNulls], aiContextProviderMessages: null)
{
ResponseMessages = [responseMsg]
};
@@ -174,7 +177,7 @@ public class ChatHistoryMemoryProviderTests
1,
new ChatHistoryMemoryProviderScope() { UserId = "UID" });
var requestMsg = new ChatMessage(ChatRole.User, "request text") { MessageId = "req-1" };
var invokedContext = new AIContextProvider.InvokedContext([requestMsg], aiContextProviderMessages: null)
var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg], aiContextProviderMessages: null)
{
InvokeException = new InvalidOperationException("Invoke failed")
};
@@ -203,7 +206,7 @@ public class ChatHistoryMemoryProviderTests
new ChatHistoryMemoryProviderScope() { UserId = "UID" },
loggerFactory: this._loggerFactoryMock.Object);
var requestMsg = new ChatMessage(ChatRole.User, "request text") { MessageId = "req-1" };
var invokedContext = new AIContextProvider.InvokedContext([requestMsg], aiContextProviderMessages: null);
var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg], aiContextProviderMessages: null);
// Act
await provider.InvokedAsync(invokedContext, CancellationToken.None);
@@ -254,7 +257,7 @@ public class ChatHistoryMemoryProviderTests
loggerFactory: this._loggerFactoryMock.Object);
var requestMsg = new ChatMessage(ChatRole.User, "request text");
var invokedContext = new AIContextProvider.InvokedContext([requestMsg], aiContextProviderMessages: null);
var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg], aiContextProviderMessages: null);
// Act
await provider.InvokedAsync(invokedContext, CancellationToken.None);
@@ -327,7 +330,7 @@ public class ChatHistoryMemoryProviderTests
options: providerOptions);
var requestMsg = new ChatMessage(ChatRole.User, "requesting relevant history");
var invokingContext = new AIContextProvider.InvokingContext([requestMsg]);
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [requestMsg]);
// Act
await provider.InvokingAsync(invokingContext, CancellationToken.None);
@@ -378,7 +381,7 @@ public class ChatHistoryMemoryProviderTests
var provider = new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, TestCollectionName, 1, options: providerOptions, storageScope: searchScope, searchScope: searchScope);
var requestMsg = new ChatMessage(ChatRole.User, "requesting relevant history");
var invokingContext = new AIContextProvider.InvokingContext([requestMsg]);
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [requestMsg]);
// Act
await provider.InvokingAsync(invokingContext, CancellationToken.None);
@@ -442,7 +445,7 @@ public class ChatHistoryMemoryProviderTests
options: options,
loggerFactory: this._loggerFactoryMock.Object);
var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "requesting relevant history")]);
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "requesting relevant history")]);
// Act
await provider.InvokingAsync(invokingContext, CancellationToken.None);
@@ -14,4 +14,5 @@ namespace Microsoft.Agents.AI.UnitTests;
[JsonSerializable(typeof(string))]
[JsonSerializable(typeof(string[]))]
[JsonSerializable(typeof(Dictionary<string, object?>))]
[JsonSerializable(typeof(ChatClientAgentSessionTests.Animal))]
internal sealed partial class TestJsonSerializerContext : JsonSerializerContext;
@@ -23,7 +23,7 @@ public class OpenAIAssistantFixture : IChatClientAgentFixture
public IChatClient ChatClient => this._agent.ChatClient;
public async Task<List<ChatMessage>> GetChatHistoryAsync(AgentSession session)
public async Task<List<ChatMessage>> GetChatHistoryAsync(AIAgent agent, AgentSession session)
{
var typedSession = (ChatClientAgentSession)session;
List<ChatMessage> messages = [];
@@ -28,7 +28,7 @@ public class OpenAIChatCompletionFixture : IChatClientAgentFixture
public IChatClient ChatClient => this._agent.ChatClient;
public async Task<List<ChatMessage>> GetChatHistoryAsync(AgentSession session)
public async Task<List<ChatMessage>> GetChatHistoryAsync(AIAgent agent, AgentSession session)
{
var typedSession = (ChatClientAgentSession)session;
@@ -37,7 +37,7 @@ public class OpenAIChatCompletionFixture : IChatClientAgentFixture
return [];
}
return (await typedSession.ChatHistoryProvider.InvokingAsync(new([]))).ToList();
return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList();
}
public Task<ChatClientAgent> CreateChatClientAgentAsync(
@@ -25,7 +25,7 @@ public class OpenAIResponseFixture(bool store) : IChatClientAgentFixture
public IChatClient ChatClient => this._agent.ChatClient;
public async Task<List<ChatMessage>> GetChatHistoryAsync(AgentSession session)
public async Task<List<ChatMessage>> GetChatHistoryAsync(AIAgent agent, AgentSession session)
{
var typedSession = (ChatClientAgentSession)session;
@@ -55,7 +55,7 @@ public class OpenAIResponseFixture(bool store) : IChatClientAgentFixture
return [];
}
return (await typedSession.ChatHistoryProvider.InvokingAsync(new([]))).ToList();
return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList();
}
private static ChatMessage ConvertToChatMessage(ResponseItem item)