mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
.NET: [BREAKING] Provide agent and session to AIContextProvider & ChatHistoryProvider (#3695)
* Add a StateBag to AgentSession and pass Agent and AgentSession to AIContextProvider and ChatHistoryProviders * Remove statebag code from this branch, to get the refactoring out of the way first * Apply suggestion from @rogerbarreto Co-authored-by: Roger Barreto <19890735+rogerbarreto@users.noreply.github.com> * Apply suggestion from @westey-m * Apply suggestion from @westey-m --------- Co-authored-by: Roger Barreto <19890735+rogerbarreto@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
2b66ca03b2
commit
ec82ed15d2
+4
-4
@@ -55,14 +55,14 @@ namespace SampleApp
|
||||
}
|
||||
|
||||
// Get existing messages from the store
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext(messages);
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext(this, session, messages);
|
||||
var storeMessages = await typedSession.ChatHistoryProvider.InvokingAsync(invokingContext, cancellationToken);
|
||||
|
||||
// Clone the input messages and turn them into response messages with upper case text.
|
||||
List<ChatMessage> responseMessages = CloneAndToUpperCase(messages, this.Name).ToList();
|
||||
|
||||
// Notify the session of the input and output messages.
|
||||
var invokedContext = new ChatHistoryProvider.InvokedContext(messages, storeMessages)
|
||||
var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, messages, storeMessages)
|
||||
{
|
||||
ResponseMessages = responseMessages
|
||||
};
|
||||
@@ -87,14 +87,14 @@ namespace SampleApp
|
||||
}
|
||||
|
||||
// Get existing messages from the store
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext(messages);
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext(this, session, messages);
|
||||
var storeMessages = await typedSession.ChatHistoryProvider.InvokingAsync(invokingContext, cancellationToken);
|
||||
|
||||
// Clone the input messages and turn them into response messages with upper case text.
|
||||
List<ChatMessage> responseMessages = CloneAndToUpperCase(messages, this.Name).ToList();
|
||||
|
||||
// Notify the session of the input and output messages.
|
||||
var invokedContext = new ChatHistoryProvider.InvokedContext(messages, storeMessages)
|
||||
var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, messages, storeMessages)
|
||||
{
|
||||
ResponseMessages = responseMessages
|
||||
};
|
||||
|
||||
@@ -129,13 +129,30 @@ public abstract class AIContextProvider
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="InvokingContext"/> class with the specified request messages.
|
||||
/// </summary>
|
||||
/// <param name="agent">The agent being invoked.</param>
|
||||
/// <param name="session">The session associated with the agent invocation.</param>
|
||||
/// <param name="requestMessages">The messages to be used by the agent for this invocation.</param>
|
||||
/// <exception cref="ArgumentNullException"><paramref name="requestMessages"/> is <see langword="null"/>.</exception>
|
||||
public InvokingContext(IEnumerable<ChatMessage> requestMessages)
|
||||
public InvokingContext(
|
||||
AIAgent agent,
|
||||
AgentSession? session,
|
||||
IEnumerable<ChatMessage> requestMessages)
|
||||
{
|
||||
this.RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages));
|
||||
this.Agent = Throw.IfNull(agent);
|
||||
this.Session = session;
|
||||
this.RequestMessages = Throw.IfNull(requestMessages);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the agent that is being invoked.
|
||||
/// </summary>
|
||||
public AIAgent Agent { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the agent session associated with the agent invocation.
|
||||
/// </summary>
|
||||
public AgentSession? Session { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the caller provided messages that will be used by the agent for this invocation.
|
||||
/// </summary>
|
||||
@@ -158,15 +175,33 @@ public abstract class AIContextProvider
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="InvokedContext"/> class with the specified request messages.
|
||||
/// </summary>
|
||||
/// <param name="agent">The agent being invoked.</param>
|
||||
/// <param name="session">The session associated with the agent invocation.</param>
|
||||
/// <param name="requestMessages">The caller provided messages that were used by the agent for this invocation.</param>
|
||||
/// <param name="aiContextProviderMessages">The messages provided by the <see cref="AIContextProvider"/> for this invocation, if any.</param>
|
||||
/// <exception cref="ArgumentNullException"><paramref name="requestMessages"/> is <see langword="null"/>.</exception>
|
||||
public InvokedContext(IEnumerable<ChatMessage> requestMessages, IEnumerable<ChatMessage>? aiContextProviderMessages)
|
||||
public InvokedContext(
|
||||
AIAgent agent,
|
||||
AgentSession? session,
|
||||
IEnumerable<ChatMessage> requestMessages,
|
||||
IEnumerable<ChatMessage>? aiContextProviderMessages)
|
||||
{
|
||||
this.RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages));
|
||||
this.Agent = Throw.IfNull(agent);
|
||||
this.Session = session;
|
||||
this.RequestMessages = Throw.IfNull(requestMessages);
|
||||
this.AIContextProviderMessages = aiContextProviderMessages;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the agent that is being invoked.
|
||||
/// </summary>
|
||||
public AIAgent Agent { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the agent session associated with the agent invocation.
|
||||
/// </summary>
|
||||
public AgentSession? Session { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the caller provided messages that were used by the agent for this invocation.
|
||||
/// </summary>
|
||||
|
||||
@@ -143,13 +143,30 @@ public abstract class ChatHistoryProvider
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="InvokingContext"/> class with the specified request messages.
|
||||
/// </summary>
|
||||
/// <param name="agent">The agent being invoked.</param>
|
||||
/// <param name="session">The session associated with the agent invocation.</param>
|
||||
/// <param name="requestMessages">The new messages to be used by the agent for this invocation.</param>
|
||||
/// <exception cref="ArgumentNullException"><paramref name="requestMessages"/> is <see langword="null"/>.</exception>
|
||||
public InvokingContext(IEnumerable<ChatMessage> requestMessages)
|
||||
public InvokingContext(
|
||||
AIAgent agent,
|
||||
AgentSession? session,
|
||||
IEnumerable<ChatMessage> requestMessages)
|
||||
{
|
||||
this.RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages));
|
||||
this.Agent = Throw.IfNull(agent);
|
||||
this.Session = session;
|
||||
this.RequestMessages = Throw.IfNull(requestMessages);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the agent that is being invoked.
|
||||
/// </summary>
|
||||
public AIAgent Agent { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the agent session associated with the agent invocation.
|
||||
/// </summary>
|
||||
public AgentSession? Session { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the caller provided messages that will be used by the agent for this invocation.
|
||||
/// </summary>
|
||||
@@ -172,15 +189,33 @@ public abstract class ChatHistoryProvider
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="InvokedContext"/> class with the specified request messages.
|
||||
/// </summary>
|
||||
/// <param name="agent">The agent being invoked.</param>
|
||||
/// <param name="session">The session associated with the agent invocation.</param>
|
||||
/// <param name="requestMessages">The caller provided messages that were used by the agent for this invocation.</param>
|
||||
/// <param name="chatHistoryProviderMessages">The messages retrieved from the <see cref="ChatHistoryProvider"/> for this invocation.</param>
|
||||
/// <exception cref="ArgumentNullException"><paramref name="requestMessages"/> is <see langword="null"/>.</exception>
|
||||
public InvokedContext(IEnumerable<ChatMessage> requestMessages, IEnumerable<ChatMessage>? chatHistoryProviderMessages)
|
||||
public InvokedContext(
|
||||
AIAgent agent,
|
||||
AgentSession? session,
|
||||
IEnumerable<ChatMessage> requestMessages,
|
||||
IEnumerable<ChatMessage>? chatHistoryProviderMessages)
|
||||
{
|
||||
this.Agent = Throw.IfNull(agent);
|
||||
this.Session = session;
|
||||
this.RequestMessages = Throw.IfNull(requestMessages);
|
||||
this.ChatHistoryProviderMessages = chatHistoryProviderMessages;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the agent that is being invoked.
|
||||
/// </summary>
|
||||
public AIAgent Agent { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the agent session associated with the agent invocation.
|
||||
/// </summary>
|
||||
public AgentSession? Session { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the caller provided messages that were used by the agent for this invocation.
|
||||
/// </summary>
|
||||
|
||||
@@ -231,8 +231,8 @@ public sealed partial class ChatClientAgent : AIAgent
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
await NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false);
|
||||
await NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
|
||||
await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false);
|
||||
await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
|
||||
throw;
|
||||
}
|
||||
|
||||
@@ -246,8 +246,8 @@ public sealed partial class ChatClientAgent : AIAgent
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
await NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false);
|
||||
await NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
|
||||
await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false);
|
||||
await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
|
||||
throw;
|
||||
}
|
||||
|
||||
@@ -273,8 +273,8 @@ public sealed partial class ChatClientAgent : AIAgent
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
await NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false);
|
||||
await NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
|
||||
await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false);
|
||||
await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
@@ -286,10 +286,10 @@ public sealed partial class ChatClientAgent : AIAgent
|
||||
await this.UpdateSessionWithTypeAndConversationIdAsync(safeSession, chatResponse.ConversationId, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
// To avoid inconsistent state we only notify the session of the input messages if no error occurs after the initial request.
|
||||
await NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false);
|
||||
await this.NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
// Notify the AIContextProvider of all new messages.
|
||||
await NotifyAIContextProviderOfSuccessAsync(safeSession, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
|
||||
await this.NotifyAIContextProviderOfSuccessAsync(safeSession, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
@@ -455,8 +455,8 @@ public sealed partial class ChatClientAgent : AIAgent
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
await NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, inputMessages, chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false);
|
||||
await NotifyAIContextProviderOfFailureAsync(safeSession, ex, inputMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
|
||||
await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, inputMessages, chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false);
|
||||
await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, inputMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
|
||||
throw;
|
||||
}
|
||||
|
||||
@@ -473,10 +473,10 @@ public sealed partial class ChatClientAgent : AIAgent
|
||||
}
|
||||
|
||||
// Only notify the session of new messages if the chatResponse was successful to avoid inconsistent message state in the session.
|
||||
await NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, inputMessages, chatHistoryProviderMessages, aiContextProviderMessages, chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false);
|
||||
await this.NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, inputMessages, chatHistoryProviderMessages, aiContextProviderMessages, chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
// Notify the AIContextProvider of all new messages.
|
||||
await NotifyAIContextProviderOfSuccessAsync(safeSession, inputMessages, aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
|
||||
await this.NotifyAIContextProviderOfSuccessAsync(safeSession, inputMessages, aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
var agentResponse = agentResponseFactoryFunc(chatResponse);
|
||||
|
||||
@@ -488,7 +488,7 @@ public sealed partial class ChatClientAgent : AIAgent
|
||||
/// <summary>
|
||||
/// Notify the <see cref="AIContextProvider"/> when an agent run succeeded, if there is an <see cref="AIContextProvider"/>.
|
||||
/// </summary>
|
||||
private static async Task NotifyAIContextProviderOfSuccessAsync(
|
||||
private async Task NotifyAIContextProviderOfSuccessAsync(
|
||||
ChatClientAgentSession session,
|
||||
IEnumerable<ChatMessage> inputMessages,
|
||||
IList<ChatMessage>? aiContextProviderMessages,
|
||||
@@ -497,7 +497,7 @@ public sealed partial class ChatClientAgent : AIAgent
|
||||
{
|
||||
if (session.AIContextProvider is not null)
|
||||
{
|
||||
await session.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProviderMessages) { ResponseMessages = responseMessages },
|
||||
await session.AIContextProvider.InvokedAsync(new(this, session, inputMessages, aiContextProviderMessages) { ResponseMessages = responseMessages },
|
||||
cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
}
|
||||
@@ -505,7 +505,7 @@ public sealed partial class ChatClientAgent : AIAgent
|
||||
/// <summary>
|
||||
/// Notify the <see cref="AIContextProvider"/> of any failure during an agent run, if there is an <see cref="AIContextProvider"/>.
|
||||
/// </summary>
|
||||
private static async Task NotifyAIContextProviderOfFailureAsync(
|
||||
private async Task NotifyAIContextProviderOfFailureAsync(
|
||||
ChatClientAgentSession session,
|
||||
Exception ex,
|
||||
IEnumerable<ChatMessage> inputMessages,
|
||||
@@ -514,7 +514,7 @@ public sealed partial class ChatClientAgent : AIAgent
|
||||
{
|
||||
if (session.AIContextProvider is not null)
|
||||
{
|
||||
await session.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProviderMessages) { InvokeException = ex },
|
||||
await session.AIContextProvider.InvokedAsync(new(this, session, inputMessages, aiContextProviderMessages) { InvokeException = ex },
|
||||
cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
}
|
||||
@@ -726,7 +726,7 @@ public sealed partial class ChatClientAgent : AIAgent
|
||||
// Add any existing messages from the session to the messages to be sent to the chat client.
|
||||
if (chatHistoryProvider is not null)
|
||||
{
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext(inputMessages);
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext(this, typedSession, inputMessages);
|
||||
var providerMessages = await chatHistoryProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false);
|
||||
inputMessagesForChatClient.AddRange(providerMessages);
|
||||
chatHistoryProviderMessages = providerMessages as IList<ChatMessage> ?? providerMessages.ToList();
|
||||
@@ -739,7 +739,7 @@ public sealed partial class ChatClientAgent : AIAgent
|
||||
// messages and options with the additional context.
|
||||
if (typedSession.AIContextProvider is not null)
|
||||
{
|
||||
var invokingContext = new AIContextProvider.InvokingContext(inputMessages);
|
||||
var invokingContext = new AIContextProvider.InvokingContext(this, typedSession, inputMessages);
|
||||
var aiContext = await typedSession.AIContextProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false);
|
||||
if (aiContext.Messages is { Count: > 0 })
|
||||
{
|
||||
@@ -812,7 +812,7 @@ public sealed partial class ChatClientAgent : AIAgent
|
||||
}
|
||||
}
|
||||
|
||||
private static Task NotifyChatHistoryProviderOfFailureAsync(
|
||||
private Task NotifyChatHistoryProviderOfFailureAsync(
|
||||
ChatClientAgentSession session,
|
||||
Exception ex,
|
||||
IEnumerable<ChatMessage> requestMessages,
|
||||
@@ -827,7 +827,7 @@ public sealed partial class ChatClientAgent : AIAgent
|
||||
// If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages.
|
||||
if (provider is not null)
|
||||
{
|
||||
var invokedContext = new ChatHistoryProvider.InvokedContext(requestMessages, chatHistoryProviderMessages!)
|
||||
var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, requestMessages, chatHistoryProviderMessages!)
|
||||
{
|
||||
AIContextProviderMessages = aiContextProviderMessages,
|
||||
InvokeException = ex
|
||||
@@ -839,7 +839,7 @@ public sealed partial class ChatClientAgent : AIAgent
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
private static Task NotifyChatHistoryProviderOfNewMessagesAsync(
|
||||
private Task NotifyChatHistoryProviderOfNewMessagesAsync(
|
||||
ChatClientAgentSession session,
|
||||
IEnumerable<ChatMessage> requestMessages,
|
||||
IEnumerable<ChatMessage>? chatHistoryProviderMessages,
|
||||
@@ -854,7 +854,7 @@ public sealed partial class ChatClientAgent : AIAgent
|
||||
// If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages.
|
||||
if (provider is not null)
|
||||
{
|
||||
var invokedContext = new ChatHistoryProvider.InvokedContext(requestMessages, chatHistoryProviderMessages!)
|
||||
var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, requestMessages, chatHistoryProviderMessages!)
|
||||
{
|
||||
AIContextProviderMessages = aiContextProviderMessages,
|
||||
ResponseMessages = responseMessages
|
||||
|
||||
@@ -15,7 +15,7 @@ public interface IAgentFixture : IAsyncLifetime
|
||||
{
|
||||
AIAgent Agent { get; }
|
||||
|
||||
Task<List<ChatMessage>> GetChatHistoryAsync(AgentSession session);
|
||||
Task<List<ChatMessage>> GetChatHistoryAsync(AIAgent agent, AgentSession session);
|
||||
|
||||
Task DeleteSessionAsync(AgentSession session);
|
||||
}
|
||||
|
||||
@@ -106,7 +106,7 @@ public abstract class RunStreamingTests<TAgentFixture>(Func<TAgentFixture> creat
|
||||
Assert.Contains("Paris", response1Text);
|
||||
Assert.Contains("Vienna", response2Text);
|
||||
|
||||
var chatHistory = await this.Fixture.GetChatHistoryAsync(session);
|
||||
var chatHistory = await this.Fixture.GetChatHistoryAsync(agent, session);
|
||||
Assert.Equal(4, chatHistory.Count);
|
||||
Assert.Equal(2, chatHistory.Count(x => x.Role == ChatRole.User));
|
||||
Assert.Equal(2, chatHistory.Count(x => x.Role == ChatRole.Assistant));
|
||||
|
||||
@@ -111,7 +111,7 @@ public abstract class RunTests<TAgentFixture>(Func<TAgentFixture> createAgentFix
|
||||
Assert.Contains("Paris", result1.Text);
|
||||
Assert.Contains("Vienna", result2.Text);
|
||||
|
||||
var chatHistory = await this.Fixture.GetChatHistoryAsync(session);
|
||||
var chatHistory = await this.Fixture.GetChatHistoryAsync(agent, session);
|
||||
Assert.Equal(4, chatHistory.Count);
|
||||
Assert.Equal(2, chatHistory.Count(x => x.Role == ChatRole.User));
|
||||
Assert.Equal(2, chatHistory.Count(x => x.Role == ChatRole.Assistant));
|
||||
|
||||
+2
-2
@@ -35,7 +35,7 @@ public class AnthropicChatCompletionFixture : IChatClientAgentFixture
|
||||
|
||||
public IChatClient ChatClient => this._agent.ChatClient;
|
||||
|
||||
public async Task<List<ChatMessage>> GetChatHistoryAsync(AgentSession session)
|
||||
public async Task<List<ChatMessage>> GetChatHistoryAsync(AIAgent agent, AgentSession session)
|
||||
{
|
||||
var typedSession = (ChatClientAgentSession)session;
|
||||
|
||||
@@ -44,7 +44,7 @@ public class AnthropicChatCompletionFixture : IChatClientAgentFixture
|
||||
return [];
|
||||
}
|
||||
|
||||
return (await typedSession.ChatHistoryProvider.InvokingAsync(new([]))).ToList();
|
||||
return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList();
|
||||
}
|
||||
|
||||
public Task<ChatClientAgent> CreateChatClientAgentAsync(
|
||||
|
||||
@@ -33,7 +33,7 @@ public class AIProjectClientFixture : IChatClientAgentFixture
|
||||
return response.Value.Id;
|
||||
}
|
||||
|
||||
public async Task<List<ChatMessage>> GetChatHistoryAsync(AgentSession session)
|
||||
public async Task<List<ChatMessage>> GetChatHistoryAsync(AIAgent agent, AgentSession session)
|
||||
{
|
||||
var chatClientSession = (ChatClientAgentSession)session;
|
||||
|
||||
@@ -53,7 +53,7 @@ public class AIProjectClientFixture : IChatClientAgentFixture
|
||||
return [];
|
||||
}
|
||||
|
||||
return (await chatClientSession.ChatHistoryProvider.InvokingAsync(new([]))).ToList();
|
||||
return (await chatClientSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList();
|
||||
}
|
||||
|
||||
private async Task<List<ChatMessage>> GetChatHistoryFromResponsesChainAsync(string conversationId)
|
||||
|
||||
+1
-1
@@ -24,7 +24,7 @@ public class AzureAIAgentsPersistentFixture : IChatClientAgentFixture
|
||||
|
||||
public AIAgent Agent => this._agent;
|
||||
|
||||
public async Task<List<ChatMessage>> GetChatHistoryAsync(AgentSession session)
|
||||
public async Task<List<ChatMessage>> GetChatHistoryAsync(AIAgent agent, AgentSession session)
|
||||
{
|
||||
List<ChatMessage> messages = [];
|
||||
var typedSession = (ChatClientAgentSession)session;
|
||||
|
||||
@@ -20,7 +20,7 @@ public class CopilotStudioFixture : IAgentFixture
|
||||
{
|
||||
public AIAgent Agent { get; private set; } = null!;
|
||||
|
||||
public Task<List<ChatMessage>> GetChatHistoryAsync(AgentSession session) =>
|
||||
public Task<List<ChatMessage>> GetChatHistoryAsync(AIAgent agent, AgentSession session) =>
|
||||
throw new NotSupportedException("CopilotStudio doesn't allow retrieval of chat history.");
|
||||
|
||||
public Task DeleteSessionAsync(AgentSession session) =>
|
||||
|
||||
+112
-10
@@ -6,17 +6,21 @@ using System.Collections.ObjectModel;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.AI;
|
||||
using Moq;
|
||||
|
||||
namespace Microsoft.Agents.AI.Abstractions.UnitTests;
|
||||
|
||||
public class AIContextProviderTests
|
||||
{
|
||||
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
|
||||
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
|
||||
|
||||
[Fact]
|
||||
public async Task InvokedAsync_ReturnsCompletedTaskAsync()
|
||||
{
|
||||
var provider = new TestAIContextProvider();
|
||||
var messages = new ReadOnlyCollection<ChatMessage>([]);
|
||||
var task = provider.InvokedAsync(new(messages, aiContextProviderMessages: null));
|
||||
var task = provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null));
|
||||
Assert.Equal(default, task);
|
||||
}
|
||||
|
||||
@@ -31,13 +35,13 @@ public class AIContextProviderTests
|
||||
[Fact]
|
||||
public void InvokingContext_Constructor_ThrowsForNullMessages()
|
||||
{
|
||||
Assert.Throws<ArgumentNullException>(() => new AIContextProvider.InvokingContext(null!));
|
||||
Assert.Throws<ArgumentNullException>(() => new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, null!));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void InvokedContext_Constructor_ThrowsForNullMessages()
|
||||
{
|
||||
Assert.Throws<ArgumentNullException>(() => new AIContextProvider.InvokedContext(null!, aiContextProviderMessages: null));
|
||||
Assert.Throws<ArgumentNullException>(() => new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, null!, aiContextProviderMessages: null));
|
||||
}
|
||||
|
||||
#region GetService Method Tests
|
||||
@@ -163,7 +167,7 @@ public class AIContextProviderTests
|
||||
{
|
||||
// Arrange
|
||||
var messages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
|
||||
var context = new AIContextProvider.InvokingContext(messages);
|
||||
var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, messages);
|
||||
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => context.RequestMessages = null!);
|
||||
@@ -175,7 +179,7 @@ public class AIContextProviderTests
|
||||
// Arrange
|
||||
var initialMessages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
|
||||
var newMessages = new List<ChatMessage> { new(ChatRole.User, "New message") };
|
||||
var context = new AIContextProvider.InvokingContext(initialMessages);
|
||||
var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, initialMessages);
|
||||
|
||||
// Act
|
||||
context.RequestMessages = newMessages;
|
||||
@@ -184,6 +188,55 @@ public class AIContextProviderTests
|
||||
Assert.Same(newMessages, context.RequestMessages);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void InvokingContext_Agent_ReturnsConstructorValue()
|
||||
{
|
||||
// Arrange
|
||||
var messages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
|
||||
|
||||
// Act
|
||||
var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, messages);
|
||||
|
||||
// Assert
|
||||
Assert.Same(s_mockAgent, context.Agent);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void InvokingContext_Session_ReturnsConstructorValue()
|
||||
{
|
||||
// Arrange
|
||||
var messages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
|
||||
|
||||
// Act
|
||||
var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, messages);
|
||||
|
||||
// Assert
|
||||
Assert.Same(s_mockSession, context.Session);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void InvokingContext_Session_CanBeNull()
|
||||
{
|
||||
// Arrange
|
||||
var messages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
|
||||
|
||||
// Act
|
||||
var context = new AIContextProvider.InvokingContext(s_mockAgent, null, messages);
|
||||
|
||||
// Assert
|
||||
Assert.Null(context.Session);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void InvokingContext_Constructor_ThrowsForNullAgent()
|
||||
{
|
||||
// Arrange
|
||||
var messages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
|
||||
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => new AIContextProvider.InvokingContext(null!, s_mockSession, messages));
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region InvokedContext Tests
|
||||
@@ -193,7 +246,7 @@ public class AIContextProviderTests
|
||||
{
|
||||
// Arrange
|
||||
var messages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
|
||||
var context = new AIContextProvider.InvokedContext(messages, aiContextProviderMessages: null);
|
||||
var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null);
|
||||
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => context.RequestMessages = null!);
|
||||
@@ -205,7 +258,7 @@ public class AIContextProviderTests
|
||||
// Arrange
|
||||
var initialMessages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
|
||||
var newMessages = new List<ChatMessage> { new(ChatRole.User, "New message") };
|
||||
var context = new AIContextProvider.InvokedContext(initialMessages, aiContextProviderMessages: null);
|
||||
var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, initialMessages, aiContextProviderMessages: null);
|
||||
|
||||
// Act
|
||||
context.RequestMessages = newMessages;
|
||||
@@ -220,7 +273,7 @@ public class AIContextProviderTests
|
||||
// Arrange
|
||||
var requestMessages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
|
||||
var aiContextMessages = new List<ChatMessage> { new(ChatRole.System, "AI context message") };
|
||||
var context = new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null);
|
||||
var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null);
|
||||
|
||||
// Act
|
||||
context.AIContextProviderMessages = aiContextMessages;
|
||||
@@ -235,7 +288,7 @@ public class AIContextProviderTests
|
||||
// Arrange
|
||||
var requestMessages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
|
||||
var responseMessages = new List<ChatMessage> { new(ChatRole.Assistant, "Response message") };
|
||||
var context = new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null);
|
||||
var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null);
|
||||
|
||||
// Act
|
||||
context.ResponseMessages = responseMessages;
|
||||
@@ -250,7 +303,7 @@ public class AIContextProviderTests
|
||||
// Arrange
|
||||
var requestMessages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
|
||||
var exception = new InvalidOperationException("Test exception");
|
||||
var context = new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null);
|
||||
var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null);
|
||||
|
||||
// Act
|
||||
context.InvokeException = exception;
|
||||
@@ -259,6 +312,55 @@ public class AIContextProviderTests
|
||||
Assert.Same(exception, context.InvokeException);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void InvokedContext_Agent_ReturnsConstructorValue()
|
||||
{
|
||||
// Arrange
|
||||
var requestMessages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
|
||||
|
||||
// Act
|
||||
var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null);
|
||||
|
||||
// Assert
|
||||
Assert.Same(s_mockAgent, context.Agent);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void InvokedContext_Session_ReturnsConstructorValue()
|
||||
{
|
||||
// Arrange
|
||||
var requestMessages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
|
||||
|
||||
// Act
|
||||
var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null);
|
||||
|
||||
// Assert
|
||||
Assert.Same(s_mockSession, context.Session);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void InvokedContext_Session_CanBeNull()
|
||||
{
|
||||
// Arrange
|
||||
var requestMessages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
|
||||
|
||||
// Act
|
||||
var context = new AIContextProvider.InvokedContext(s_mockAgent, null, requestMessages, aiContextProviderMessages: null);
|
||||
|
||||
// Assert
|
||||
Assert.Null(context.Session);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void InvokedContext_Constructor_ThrowsForNullAgent()
|
||||
{
|
||||
// Arrange
|
||||
var requestMessages = new ReadOnlyCollection<ChatMessage>([new(ChatRole.User, "Hello")]);
|
||||
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => new AIContextProvider.InvokedContext(null!, s_mockSession, requestMessages, aiContextProviderMessages: null));
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
private sealed class TestAIContextProvider : AIContextProvider
|
||||
|
||||
+6
-3
@@ -14,6 +14,9 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests;
|
||||
/// </summary>
|
||||
public sealed class ChatHistoryProviderExtensionsTests
|
||||
{
|
||||
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
|
||||
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
|
||||
|
||||
[Fact]
|
||||
public void WithMessageFilters_ReturnsChatHistoryProviderMessageFilter()
|
||||
{
|
||||
@@ -35,7 +38,7 @@ public sealed class ChatHistoryProviderExtensionsTests
|
||||
// Arrange
|
||||
Mock<ChatHistoryProvider> providerMock = new();
|
||||
List<ChatMessage> innerMessages = [new(ChatRole.User, "Hello"), new(ChatRole.Assistant, "Hi")];
|
||||
ChatHistoryProvider.InvokingContext context = new([new ChatMessage(ChatRole.User, "Test")]);
|
||||
ChatHistoryProvider.InvokingContext context = new(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]);
|
||||
|
||||
providerMock
|
||||
.Setup(p => p.InvokingAsync(context, It.IsAny<CancellationToken>()))
|
||||
@@ -59,7 +62,7 @@ public sealed class ChatHistoryProviderExtensionsTests
|
||||
Mock<ChatHistoryProvider> providerMock = new();
|
||||
List<ChatMessage> requestMessages = [new(ChatRole.User, "Hello")];
|
||||
List<ChatMessage> chatHistoryProviderMessages = [new(ChatRole.System, "System")];
|
||||
ChatHistoryProvider.InvokedContext context = new(requestMessages, chatHistoryProviderMessages)
|
||||
ChatHistoryProvider.InvokedContext context = new(s_mockAgent, s_mockSession, requestMessages, chatHistoryProviderMessages)
|
||||
{
|
||||
ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Response")]
|
||||
};
|
||||
@@ -106,7 +109,7 @@ public sealed class ChatHistoryProviderExtensionsTests
|
||||
List<ChatMessage> requestMessages = [new(ChatRole.User, "Hello")];
|
||||
List<ChatMessage> chatHistoryProviderMessages = [new(ChatRole.System, "System")];
|
||||
List<ChatMessage> aiContextProviderMessages = [new(ChatRole.System, "Context")];
|
||||
ChatHistoryProvider.InvokedContext context = new(requestMessages, chatHistoryProviderMessages)
|
||||
ChatHistoryProvider.InvokedContext context = new(s_mockAgent, s_mockSession, requestMessages, chatHistoryProviderMessages)
|
||||
{
|
||||
AIContextProviderMessages = aiContextProviderMessages
|
||||
};
|
||||
|
||||
+8
-5
@@ -16,6 +16,9 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests;
|
||||
/// </summary>
|
||||
public sealed class ChatHistoryProviderMessageFilterTests
|
||||
{
|
||||
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
|
||||
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
|
||||
|
||||
[Fact]
|
||||
public void Constructor_WithNullInnerProvider_ThrowsArgumentNullException()
|
||||
{
|
||||
@@ -59,7 +62,7 @@ public sealed class ChatHistoryProviderMessageFilterTests
|
||||
new(ChatRole.User, "Hello"),
|
||||
new(ChatRole.Assistant, "Hi there!")
|
||||
};
|
||||
var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Test")]);
|
||||
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]);
|
||||
|
||||
innerProviderMock
|
||||
.Setup(s => s.InvokingAsync(context, It.IsAny<CancellationToken>()))
|
||||
@@ -88,7 +91,7 @@ public sealed class ChatHistoryProviderMessageFilterTests
|
||||
new(ChatRole.Assistant, "Hi there!"),
|
||||
new(ChatRole.User, "How are you?")
|
||||
};
|
||||
var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Test")]);
|
||||
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]);
|
||||
|
||||
innerProviderMock
|
||||
.Setup(s => s.InvokingAsync(context, It.IsAny<CancellationToken>()))
|
||||
@@ -118,7 +121,7 @@ public sealed class ChatHistoryProviderMessageFilterTests
|
||||
new(ChatRole.User, "Hello"),
|
||||
new(ChatRole.Assistant, "Hi there!")
|
||||
};
|
||||
var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Test")]);
|
||||
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]);
|
||||
|
||||
innerProviderMock
|
||||
.Setup(s => s.InvokingAsync(context, It.IsAny<CancellationToken>()))
|
||||
@@ -147,7 +150,7 @@ public sealed class ChatHistoryProviderMessageFilterTests
|
||||
var requestMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
|
||||
var chatHistoryProviderMessages = new List<ChatMessage> { new(ChatRole.System, "System") };
|
||||
var responseMessages = new List<ChatMessage> { new(ChatRole.Assistant, "Response") };
|
||||
var context = new ChatHistoryProvider.InvokedContext(requestMessages, chatHistoryProviderMessages)
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, chatHistoryProviderMessages)
|
||||
{
|
||||
ResponseMessages = responseMessages
|
||||
};
|
||||
@@ -162,7 +165,7 @@ public sealed class ChatHistoryProviderMessageFilterTests
|
||||
ChatHistoryProvider.InvokedContext InvokedFilter(ChatHistoryProvider.InvokedContext ctx)
|
||||
{
|
||||
var modifiedRequestMessages = ctx.RequestMessages.Select(m => new ChatMessage(m.Role, $"[FILTERED] {m.Text}")).ToList();
|
||||
return new ChatHistoryProvider.InvokedContext(modifiedRequestMessages, ctx.ChatHistoryProviderMessages)
|
||||
return new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, modifiedRequestMessages, ctx.ChatHistoryProviderMessages)
|
||||
{
|
||||
ResponseMessages = ctx.ResponseMessages,
|
||||
AIContextProviderMessages = ctx.AIContextProviderMessages,
|
||||
|
||||
+112
-10
@@ -6,6 +6,7 @@ using System.Text.Json;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.AI;
|
||||
using Moq;
|
||||
|
||||
namespace Microsoft.Agents.AI.Abstractions.UnitTests;
|
||||
|
||||
@@ -14,6 +15,9 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests;
|
||||
/// </summary>
|
||||
public class ChatHistoryProviderTests
|
||||
{
|
||||
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
|
||||
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
|
||||
|
||||
#region GetService Method Tests
|
||||
|
||||
[Fact]
|
||||
@@ -82,7 +86,7 @@ public class ChatHistoryProviderTests
|
||||
public void InvokingContext_Constructor_ThrowsForNullMessages()
|
||||
{
|
||||
// Arrange & Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => new ChatHistoryProvider.InvokingContext(null!));
|
||||
Assert.Throws<ArgumentNullException>(() => new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, null!));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
@@ -90,7 +94,7 @@ public class ChatHistoryProviderTests
|
||||
{
|
||||
// Arrange
|
||||
var messages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
|
||||
var context = new ChatHistoryProvider.InvokingContext(messages);
|
||||
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, messages);
|
||||
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => context.RequestMessages = null!);
|
||||
@@ -102,7 +106,7 @@ public class ChatHistoryProviderTests
|
||||
// Arrange
|
||||
var initialMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
|
||||
var newMessages = new List<ChatMessage> { new(ChatRole.User, "New message") };
|
||||
var context = new ChatHistoryProvider.InvokingContext(initialMessages);
|
||||
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, initialMessages);
|
||||
|
||||
// Act
|
||||
context.RequestMessages = newMessages;
|
||||
@@ -111,6 +115,55 @@ public class ChatHistoryProviderTests
|
||||
Assert.Same(newMessages, context.RequestMessages);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void InvokingContext_Agent_ReturnsConstructorValue()
|
||||
{
|
||||
// Arrange
|
||||
var messages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
|
||||
|
||||
// Act
|
||||
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, messages);
|
||||
|
||||
// Assert
|
||||
Assert.Same(s_mockAgent, context.Agent);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void InvokingContext_Session_ReturnsConstructorValue()
|
||||
{
|
||||
// Arrange
|
||||
var messages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
|
||||
|
||||
// Act
|
||||
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, messages);
|
||||
|
||||
// Assert
|
||||
Assert.Same(s_mockSession, context.Session);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void InvokingContext_Session_CanBeNull()
|
||||
{
|
||||
// Arrange
|
||||
var messages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
|
||||
|
||||
// Act
|
||||
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, null, messages);
|
||||
|
||||
// Assert
|
||||
Assert.Null(context.Session);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void InvokingContext_Constructor_ThrowsForNullAgent()
|
||||
{
|
||||
// Arrange
|
||||
var messages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
|
||||
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => new ChatHistoryProvider.InvokingContext(null!, s_mockSession, messages));
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region InvokedContext Tests
|
||||
@@ -119,7 +172,7 @@ public class ChatHistoryProviderTests
|
||||
public void InvokedContext_Constructor_ThrowsForNullRequestMessages()
|
||||
{
|
||||
// Arrange & Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => new ChatHistoryProvider.InvokedContext(null!, []));
|
||||
Assert.Throws<ArgumentNullException>(() => new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, null!, []));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
@@ -127,7 +180,7 @@ public class ChatHistoryProviderTests
|
||||
{
|
||||
// Arrange
|
||||
var requestMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
|
||||
var context = new ChatHistoryProvider.InvokedContext(requestMessages, []);
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []);
|
||||
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => context.RequestMessages = null!);
|
||||
@@ -139,7 +192,7 @@ public class ChatHistoryProviderTests
|
||||
// Arrange
|
||||
var initialMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
|
||||
var newMessages = new List<ChatMessage> { new(ChatRole.User, "New message") };
|
||||
var context = new ChatHistoryProvider.InvokedContext(initialMessages, []);
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, initialMessages, []);
|
||||
|
||||
// Act
|
||||
context.RequestMessages = newMessages;
|
||||
@@ -154,7 +207,7 @@ public class ChatHistoryProviderTests
|
||||
// Arrange
|
||||
var requestMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
|
||||
var newProviderMessages = new List<ChatMessage> { new(ChatRole.System, "System message") };
|
||||
var context = new ChatHistoryProvider.InvokedContext(requestMessages, []);
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []);
|
||||
|
||||
// Act
|
||||
context.ChatHistoryProviderMessages = newProviderMessages;
|
||||
@@ -169,7 +222,7 @@ public class ChatHistoryProviderTests
|
||||
// Arrange
|
||||
var requestMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
|
||||
var aiContextMessages = new List<ChatMessage> { new(ChatRole.System, "AI context message") };
|
||||
var context = new ChatHistoryProvider.InvokedContext(requestMessages, []);
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []);
|
||||
|
||||
// Act
|
||||
context.AIContextProviderMessages = aiContextMessages;
|
||||
@@ -184,7 +237,7 @@ public class ChatHistoryProviderTests
|
||||
// Arrange
|
||||
var requestMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
|
||||
var responseMessages = new List<ChatMessage> { new(ChatRole.Assistant, "Response message") };
|
||||
var context = new ChatHistoryProvider.InvokedContext(requestMessages, []);
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []);
|
||||
|
||||
// Act
|
||||
context.ResponseMessages = responseMessages;
|
||||
@@ -199,7 +252,7 @@ public class ChatHistoryProviderTests
|
||||
// Arrange
|
||||
var requestMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
|
||||
var exception = new InvalidOperationException("Test exception");
|
||||
var context = new ChatHistoryProvider.InvokedContext(requestMessages, []);
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []);
|
||||
|
||||
// Act
|
||||
context.InvokeException = exception;
|
||||
@@ -208,6 +261,55 @@ public class ChatHistoryProviderTests
|
||||
Assert.Same(exception, context.InvokeException);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void InvokedContext_Agent_ReturnsConstructorValue()
|
||||
{
|
||||
// Arrange
|
||||
var requestMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
|
||||
|
||||
// Act
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []);
|
||||
|
||||
// Assert
|
||||
Assert.Same(s_mockAgent, context.Agent);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void InvokedContext_Session_ReturnsConstructorValue()
|
||||
{
|
||||
// Arrange
|
||||
var requestMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
|
||||
|
||||
// Act
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []);
|
||||
|
||||
// Assert
|
||||
Assert.Same(s_mockSession, context.Session);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void InvokedContext_Session_CanBeNull()
|
||||
{
|
||||
// Arrange
|
||||
var requestMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
|
||||
|
||||
// Act
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, null, requestMessages, []);
|
||||
|
||||
// Assert
|
||||
Assert.Null(context.Session);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void InvokedContext_Constructor_ThrowsForNullAgent()
|
||||
{
|
||||
// Arrange
|
||||
var requestMessages = new List<ChatMessage> { new(ChatRole.User, "Hello") };
|
||||
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => new ChatHistoryProvider.InvokedContext(null!, s_mockSession, requestMessages, []));
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
private sealed class TestChatHistoryProvider : ChatHistoryProvider
|
||||
|
||||
+12
-9
@@ -18,6 +18,9 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests;
|
||||
/// </summary>
|
||||
public class InMemoryChatHistoryProviderTests
|
||||
{
|
||||
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
|
||||
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
|
||||
|
||||
[Fact]
|
||||
public void Constructor_Throws_ForNullReducer() =>
|
||||
// Arrange & Act & Assert
|
||||
@@ -68,7 +71,7 @@ public class InMemoryChatHistoryProviderTests
|
||||
|
||||
var provider = new InMemoryChatHistoryProvider();
|
||||
provider.Add(providerMessages[0]);
|
||||
var context = new ChatHistoryProvider.InvokedContext(requestMessages, providerMessages)
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, providerMessages)
|
||||
{
|
||||
AIContextProviderMessages = aiContextProviderMessages,
|
||||
ResponseMessages = responseMessages
|
||||
@@ -87,7 +90,7 @@ public class InMemoryChatHistoryProviderTests
|
||||
{
|
||||
var provider = new InMemoryChatHistoryProvider();
|
||||
|
||||
var context = new ChatHistoryProvider.InvokedContext([], []);
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [], []);
|
||||
await provider.InvokedAsync(context, CancellationToken.None);
|
||||
|
||||
Assert.Empty(provider);
|
||||
@@ -102,7 +105,7 @@ public class InMemoryChatHistoryProviderTests
|
||||
new ChatMessage(ChatRole.Assistant, "Test2")
|
||||
};
|
||||
|
||||
var context = new ChatHistoryProvider.InvokingContext([]);
|
||||
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
|
||||
var result = (await provider.InvokingAsync(context, CancellationToken.None)).ToList();
|
||||
|
||||
Assert.Equal(2, result.Count);
|
||||
@@ -183,7 +186,7 @@ public class InMemoryChatHistoryProviderTests
|
||||
var provider = new InMemoryChatHistoryProvider();
|
||||
var messages = new List<ChatMessage>();
|
||||
|
||||
var context = new ChatHistoryProvider.InvokedContext(messages, []);
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []);
|
||||
await provider.InvokedAsync(context, CancellationToken.None);
|
||||
|
||||
Assert.Empty(provider);
|
||||
@@ -520,7 +523,7 @@ public class InMemoryChatHistoryProviderTests
|
||||
var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded);
|
||||
|
||||
// Act
|
||||
var context = new ChatHistoryProvider.InvokedContext(originalMessages, []);
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, originalMessages, []);
|
||||
await provider.InvokedAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
@@ -556,7 +559,7 @@ public class InMemoryChatHistoryProviderTests
|
||||
}
|
||||
|
||||
// Act
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext(Array.Empty<ChatMessage>());
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, Array.Empty<ChatMessage>());
|
||||
var result = (await provider.InvokingAsync(invokingContext, CancellationToken.None)).ToList();
|
||||
|
||||
// Assert
|
||||
@@ -579,7 +582,7 @@ public class InMemoryChatHistoryProviderTests
|
||||
var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.BeforeMessagesRetrieval);
|
||||
|
||||
// Act
|
||||
var context = new ChatHistoryProvider.InvokedContext(originalMessages, []);
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, originalMessages, []);
|
||||
await provider.InvokedAsync(context, CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
@@ -605,7 +608,7 @@ public class InMemoryChatHistoryProviderTests
|
||||
};
|
||||
|
||||
// Act
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext(Array.Empty<ChatMessage>());
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, Array.Empty<ChatMessage>());
|
||||
var result = (await provider.InvokingAsync(invokingContext, CancellationToken.None)).ToList();
|
||||
|
||||
// Assert
|
||||
@@ -627,7 +630,7 @@ public class InMemoryChatHistoryProviderTests
|
||||
{
|
||||
new(ChatRole.Assistant, "Hi there!")
|
||||
};
|
||||
var context = new ChatHistoryProvider.InvokedContext(requestMessages, [])
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, [])
|
||||
{
|
||||
ResponseMessages = responseMessages,
|
||||
InvokeException = new InvalidOperationException("Test exception")
|
||||
|
||||
+31
-28
@@ -41,6 +41,9 @@ namespace Microsoft.Agents.AI.CosmosNoSql.UnitTests;
|
||||
[Collection("CosmosDB")]
|
||||
public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
|
||||
{
|
||||
private static readonly AIAgent s_mockAgent = new Moq.Mock<AIAgent>().Object;
|
||||
private static readonly AgentSession s_mockSession = new Moq.Mock<AgentSession>().Object;
|
||||
|
||||
// Cosmos DB Emulator connection settings
|
||||
private const string EmulatorEndpoint = "https://localhost:8081";
|
||||
private const string EmulatorKey = "C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==";
|
||||
@@ -214,7 +217,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
|
||||
using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversationId);
|
||||
var message = new ChatMessage(ChatRole.User, "Hello, world!");
|
||||
|
||||
var context = new ChatHistoryProvider.InvokedContext([message], [])
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [message], [])
|
||||
{
|
||||
ResponseMessages = []
|
||||
};
|
||||
@@ -226,7 +229,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
|
||||
await Task.Delay(100);
|
||||
|
||||
// Assert
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext([]);
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
|
||||
var messages = await provider.InvokingAsync(invokingContext);
|
||||
var messageList = messages.ToList();
|
||||
|
||||
@@ -293,7 +296,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
|
||||
new ChatMessage(ChatRole.Assistant, "Response message")
|
||||
};
|
||||
|
||||
var context = new ChatHistoryProvider.InvokedContext(requestMessages, [])
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, [])
|
||||
{
|
||||
AIContextProviderMessages = aiContextProviderMessages,
|
||||
ResponseMessages = responseMessages
|
||||
@@ -303,7 +306,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
|
||||
await provider.InvokedAsync(context);
|
||||
|
||||
// Assert
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext([]);
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
|
||||
var retrievedMessages = await provider.InvokingAsync(invokingContext);
|
||||
var messageList = retrievedMessages.ToList();
|
||||
Assert.Equal(5, messageList.Count);
|
||||
@@ -327,7 +330,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
|
||||
using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, Guid.NewGuid().ToString());
|
||||
|
||||
// Act
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext([]);
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
|
||||
var messages = await provider.InvokingAsync(invokingContext);
|
||||
|
||||
// Assert
|
||||
@@ -346,15 +349,15 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
|
||||
using var store1 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversation1);
|
||||
using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversation2);
|
||||
|
||||
var context1 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message for conversation 1")], []);
|
||||
var context2 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message for conversation 2")], []);
|
||||
var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message for conversation 1")], []);
|
||||
var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message for conversation 2")], []);
|
||||
|
||||
await store1.InvokedAsync(context1);
|
||||
await store2.InvokedAsync(context2);
|
||||
|
||||
// Act
|
||||
var invokingContext1 = new ChatHistoryProvider.InvokingContext([]);
|
||||
var invokingContext2 = new ChatHistoryProvider.InvokingContext([]);
|
||||
var invokingContext1 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
|
||||
var invokingContext2 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
|
||||
|
||||
var messages1 = await store1.InvokingAsync(invokingContext1);
|
||||
var messages2 = await store2.InvokingAsync(invokingContext2);
|
||||
@@ -391,11 +394,11 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
|
||||
};
|
||||
|
||||
// Act 1: Add messages
|
||||
var invokedContext = new ChatHistoryProvider.InvokedContext(messages, []);
|
||||
var invokedContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []);
|
||||
await originalStore.InvokedAsync(invokedContext);
|
||||
|
||||
// Act 2: Verify messages were added
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext([]);
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
|
||||
var retrievedMessages = await originalStore.InvokingAsync(invokingContext);
|
||||
var retrievedList = retrievedMessages.ToList();
|
||||
Assert.Equal(5, retrievedList.Count);
|
||||
@@ -545,7 +548,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
|
||||
using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId);
|
||||
var message = new ChatMessage(ChatRole.User, "Hello from hierarchical partitioning!");
|
||||
|
||||
var context = new ChatHistoryProvider.InvokedContext([message], []);
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [message], []);
|
||||
|
||||
// Act
|
||||
await provider.InvokedAsync(context);
|
||||
@@ -554,7 +557,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
|
||||
await Task.Delay(100);
|
||||
|
||||
// Assert
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext([]);
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
|
||||
var messages = await provider.InvokingAsync(invokingContext);
|
||||
var messageList = messages.ToList();
|
||||
|
||||
@@ -602,7 +605,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
|
||||
new ChatMessage(ChatRole.User, "Third hierarchical message")
|
||||
};
|
||||
|
||||
var context = new ChatHistoryProvider.InvokedContext(messages, []);
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []);
|
||||
|
||||
// Act
|
||||
await provider.InvokedAsync(context);
|
||||
@@ -611,7 +614,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
|
||||
await Task.Delay(100);
|
||||
|
||||
// Assert
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext([]);
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
|
||||
var retrievedMessages = await provider.InvokingAsync(invokingContext);
|
||||
var messageList = retrievedMessages.ToList();
|
||||
|
||||
@@ -637,8 +640,8 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
|
||||
using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId2, SessionId);
|
||||
|
||||
// Add messages to both stores
|
||||
var context1 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message from user 1")], []);
|
||||
var context2 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message from user 2")], []);
|
||||
var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message from user 1")], []);
|
||||
var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message from user 2")], []);
|
||||
|
||||
await store1.InvokedAsync(context1);
|
||||
await store2.InvokedAsync(context2);
|
||||
@@ -647,8 +650,8 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
|
||||
await Task.Delay(100);
|
||||
|
||||
// Act & Assert
|
||||
var invokingContext1 = new ChatHistoryProvider.InvokingContext([]);
|
||||
var invokingContext2 = new ChatHistoryProvider.InvokingContext([]);
|
||||
var invokingContext1 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
|
||||
var invokingContext2 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
|
||||
|
||||
var messages1 = await store1.InvokingAsync(invokingContext1);
|
||||
var messageList1 = messages1.ToList();
|
||||
@@ -675,7 +678,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
|
||||
|
||||
using var originalStore = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId);
|
||||
|
||||
var context = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Test serialization message")], []);
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test serialization message")], []);
|
||||
await originalStore.InvokedAsync(context);
|
||||
|
||||
// Act - Serialize the provider state
|
||||
@@ -693,7 +696,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
|
||||
await Task.Delay(100);
|
||||
|
||||
// Assert - The deserialized provider should have the same functionality
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext([]);
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
|
||||
var messages = await deserializedStore.InvokingAsync(invokingContext);
|
||||
var messageList = messages.ToList();
|
||||
|
||||
@@ -717,8 +720,8 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
|
||||
using var hierarchicalProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, "tenant-coexist", "user-coexist", SessionId);
|
||||
|
||||
// Add messages to both
|
||||
var simpleContext = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Simple partitioning message")], []);
|
||||
var hierarchicalContext = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Hierarchical partitioning message")], []);
|
||||
var simpleContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Simple partitioning message")], []);
|
||||
var hierarchicalContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Hierarchical partitioning message")], []);
|
||||
|
||||
await simpleProvider.InvokedAsync(simpleContext);
|
||||
await hierarchicalProvider.InvokedAsync(hierarchicalContext);
|
||||
@@ -727,7 +730,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
|
||||
await Task.Delay(100);
|
||||
|
||||
// Act & Assert
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext([]);
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
|
||||
|
||||
var simpleMessages = await simpleProvider.InvokingAsync(invokingContext);
|
||||
var simpleMessageList = simpleMessages.ToList();
|
||||
@@ -760,7 +763,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
|
||||
await Task.Delay(10); // Small delay to ensure different timestamps
|
||||
}
|
||||
|
||||
var context = new ChatHistoryProvider.InvokedContext(messages, []);
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []);
|
||||
await provider.InvokedAsync(context);
|
||||
|
||||
// Wait for eventual consistency
|
||||
@@ -768,7 +771,7 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
|
||||
|
||||
// Act - Set max to 5 and retrieve
|
||||
provider.MaxMessagesToRetrieve = 5;
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext([]);
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
|
||||
var retrievedMessages = await provider.InvokingAsync(invokingContext);
|
||||
var messageList = retrievedMessages.ToList();
|
||||
|
||||
@@ -798,14 +801,14 @@ public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable
|
||||
messages.Add(new ChatMessage(ChatRole.User, $"Message {i}"));
|
||||
}
|
||||
|
||||
var context = new ChatHistoryProvider.InvokedContext(messages, []);
|
||||
var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []);
|
||||
await provider.InvokedAsync(context);
|
||||
|
||||
// Wait for eventual consistency
|
||||
await Task.Delay(100);
|
||||
|
||||
// Act - No limit set (default null)
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext([]);
|
||||
var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []);
|
||||
var retrievedMessages = await provider.InvokingAsync(invokingContext);
|
||||
var messageList = retrievedMessages.ToList();
|
||||
|
||||
|
||||
@@ -18,6 +18,9 @@ public sealed class Mem0ProviderTests : IDisposable
|
||||
{
|
||||
private const string SkipReason = "Requires a Mem0 service configured"; // Set to null to enable.
|
||||
|
||||
private static readonly AIAgent s_mockAgent = new Moq.Mock<AIAgent>().Object;
|
||||
private static readonly AgentSession s_mockSession = new Moq.Mock<AgentSession>().Object;
|
||||
|
||||
private readonly HttpClient _httpClient;
|
||||
|
||||
public Mem0ProviderTests()
|
||||
@@ -49,14 +52,14 @@ public sealed class Mem0ProviderTests : IDisposable
|
||||
var sut = new Mem0Provider(this._httpClient, storageScope);
|
||||
|
||||
await sut.ClearStoredMemoriesAsync();
|
||||
var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext([question]));
|
||||
var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question]));
|
||||
Assert.DoesNotContain("Caoimhe", ctxBefore.Messages?[0].Text ?? string.Empty);
|
||||
|
||||
// Act
|
||||
await sut.InvokedAsync(new AIContextProvider.InvokedContext([input], aiContextProviderMessages: null));
|
||||
await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [input], aiContextProviderMessages: null));
|
||||
var ctxAfterAdding = await GetContextWithRetryAsync(sut, question);
|
||||
await sut.ClearStoredMemoriesAsync();
|
||||
var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext([question]));
|
||||
var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question]));
|
||||
|
||||
// Assert
|
||||
Assert.Contains("Caoimhe", ctxAfterAdding.Messages?[0].Text ?? string.Empty);
|
||||
@@ -73,14 +76,14 @@ public sealed class Mem0ProviderTests : IDisposable
|
||||
var sut = new Mem0Provider(this._httpClient, storageScope);
|
||||
|
||||
await sut.ClearStoredMemoriesAsync();
|
||||
var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext([question]));
|
||||
var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question]));
|
||||
Assert.DoesNotContain("Caoimhe", ctxBefore.Messages?[0].Text ?? string.Empty);
|
||||
|
||||
// Act
|
||||
await sut.InvokedAsync(new AIContextProvider.InvokedContext([assistantIntro], aiContextProviderMessages: null));
|
||||
await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [assistantIntro], aiContextProviderMessages: null));
|
||||
var ctxAfterAdding = await GetContextWithRetryAsync(sut, question);
|
||||
await sut.ClearStoredMemoriesAsync();
|
||||
var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext([question]));
|
||||
var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question]));
|
||||
|
||||
// Assert
|
||||
Assert.Contains("Caoimhe", ctxAfterAdding.Messages?[0].Text ?? string.Empty);
|
||||
@@ -99,13 +102,13 @@ public sealed class Mem0ProviderTests : IDisposable
|
||||
await sut1.ClearStoredMemoriesAsync();
|
||||
await sut2.ClearStoredMemoriesAsync();
|
||||
|
||||
var ctxBefore1 = await sut1.InvokingAsync(new AIContextProvider.InvokingContext([question]));
|
||||
var ctxBefore2 = await sut2.InvokingAsync(new AIContextProvider.InvokingContext([question]));
|
||||
var ctxBefore1 = await sut1.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question]));
|
||||
var ctxBefore2 = await sut2.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question]));
|
||||
Assert.DoesNotContain("Caoimhe", ctxBefore1.Messages?[0].Text ?? string.Empty);
|
||||
Assert.DoesNotContain("Caoimhe", ctxBefore2.Messages?[0].Text ?? string.Empty);
|
||||
|
||||
// Act
|
||||
await sut1.InvokedAsync(new AIContextProvider.InvokedContext([assistantIntro], aiContextProviderMessages: null));
|
||||
await sut1.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [assistantIntro], aiContextProviderMessages: null));
|
||||
var ctxAfterAdding1 = await GetContextWithRetryAsync(sut1, question);
|
||||
var ctxAfterAdding2 = await GetContextWithRetryAsync(sut2, question);
|
||||
|
||||
@@ -123,7 +126,7 @@ public sealed class Mem0ProviderTests : IDisposable
|
||||
AIContext? ctx = null;
|
||||
for (int i = 0; i < attempts; i++)
|
||||
{
|
||||
ctx = await provider.InvokingAsync(new AIContextProvider.InvokingContext([question]), CancellationToken.None);
|
||||
ctx = await provider.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question]), CancellationToken.None);
|
||||
var text = ctx.Messages?[0].Text;
|
||||
if (!string.IsNullOrEmpty(text) && text.IndexOf("Caoimhe", StringComparison.OrdinalIgnoreCase) >= 0)
|
||||
{
|
||||
|
||||
@@ -18,6 +18,9 @@ namespace Microsoft.Agents.AI.Mem0.UnitTests;
|
||||
/// </summary>
|
||||
public sealed class Mem0ProviderTests : IDisposable
|
||||
{
|
||||
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
|
||||
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
|
||||
|
||||
private readonly Mock<ILogger<Mem0Provider>> _loggerMock;
|
||||
private readonly Mock<ILoggerFactory> _loggerFactoryMock;
|
||||
private readonly RecordingHandler _handler = new();
|
||||
@@ -96,7 +99,7 @@ public sealed class Mem0ProviderTests : IDisposable
|
||||
UserId = "user"
|
||||
};
|
||||
var sut = new Mem0Provider(this._httpClient, storageScope, options: new() { EnableSensitiveTelemetryData = true }, loggerFactory: this._loggerFactoryMock.Object);
|
||||
var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "What is my name?")]);
|
||||
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "What is my name?")]);
|
||||
|
||||
// Act
|
||||
var aiContext = await sut.InvokingAsync(invokingContext);
|
||||
@@ -161,7 +164,7 @@ public sealed class Mem0ProviderTests : IDisposable
|
||||
var options = new Mem0ProviderOptions { EnableSensitiveTelemetryData = enableSensitiveTelemetryData };
|
||||
|
||||
var sut = new Mem0Provider(this._httpClient, storageScope, options: options, loggerFactory: this._loggerFactoryMock.Object);
|
||||
var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Who am I?")]);
|
||||
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Who am I?")]);
|
||||
|
||||
// Act
|
||||
await sut.InvokingAsync(invokingContext, CancellationToken.None);
|
||||
@@ -215,7 +218,7 @@ public sealed class Mem0ProviderTests : IDisposable
|
||||
};
|
||||
|
||||
// Act
|
||||
await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages });
|
||||
await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages });
|
||||
|
||||
// Assert
|
||||
var memoryPosts = this._handler.Requests.Where(r => r.RequestMessage.RequestUri!.AbsolutePath == "/v1/memories/" && r.RequestMessage.Method == HttpMethod.Post).ToList();
|
||||
@@ -242,7 +245,7 @@ public sealed class Mem0ProviderTests : IDisposable
|
||||
};
|
||||
|
||||
// Act
|
||||
await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = null, InvokeException = new InvalidOperationException("Request Failed") });
|
||||
await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = null, InvokeException = new InvalidOperationException("Request Failed") });
|
||||
|
||||
// Assert
|
||||
Assert.Empty(this._handler.Requests);
|
||||
@@ -268,7 +271,7 @@ public sealed class Mem0ProviderTests : IDisposable
|
||||
};
|
||||
|
||||
// Act
|
||||
await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages });
|
||||
await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages });
|
||||
|
||||
// Assert
|
||||
this._loggerMock.Verify(
|
||||
@@ -318,7 +321,7 @@ public sealed class Mem0ProviderTests : IDisposable
|
||||
};
|
||||
|
||||
// Act
|
||||
await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages });
|
||||
await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages });
|
||||
|
||||
// Assert
|
||||
Assert.Equal(expectedLogCount, this._loggerMock.Invocations.Count);
|
||||
@@ -400,7 +403,7 @@ public sealed class Mem0ProviderTests : IDisposable
|
||||
// Arrange
|
||||
var storageScope = new Mem0ProviderScope { ApplicationId = "app" };
|
||||
var provider = new Mem0Provider(this._httpClient, storageScope, loggerFactory: this._loggerFactoryMock.Object);
|
||||
var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]);
|
||||
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]);
|
||||
|
||||
// Act
|
||||
var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None);
|
||||
|
||||
@@ -327,4 +327,9 @@ public class ChatClientAgentSessionTests
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
internal sealed class Animal
|
||||
{
|
||||
public string Name { get; set; } = string.Empty;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,9 @@ namespace Microsoft.Agents.AI.UnitTests.Data;
|
||||
/// </summary>
|
||||
public sealed class TextSearchProviderTests
|
||||
{
|
||||
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
|
||||
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
|
||||
|
||||
private readonly Mock<ILogger<TextSearchProvider>> _loggerMock;
|
||||
private readonly Mock<ILoggerFactory> _loggerFactoryMock;
|
||||
|
||||
@@ -64,10 +67,12 @@ public sealed class TextSearchProviderTests
|
||||
var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options, withLogging ? this._loggerFactoryMock.Object : null);
|
||||
|
||||
var invokingContext = new AIContextProvider.InvokingContext(
|
||||
[
|
||||
new ChatMessage(ChatRole.User, "Sample user question?"),
|
||||
new ChatMessage(ChatRole.User, "Additional part")
|
||||
]);
|
||||
s_mockAgent,
|
||||
s_mockSession,
|
||||
[
|
||||
new ChatMessage(ChatRole.User, "Sample user question?"),
|
||||
new ChatMessage(ChatRole.User, "Additional part")
|
||||
]);
|
||||
|
||||
// Act
|
||||
var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None);
|
||||
@@ -139,7 +144,7 @@ public sealed class TextSearchProviderTests
|
||||
FunctionToolDescription = overrideDescription
|
||||
};
|
||||
var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options);
|
||||
var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]);
|
||||
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]);
|
||||
|
||||
// Act
|
||||
var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None);
|
||||
@@ -158,7 +163,7 @@ public sealed class TextSearchProviderTests
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TextSearchProvider(this.FailingSearchAsync, default, null, loggerFactory: this._loggerFactoryMock.Object);
|
||||
var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]);
|
||||
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]);
|
||||
|
||||
// Act
|
||||
var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None);
|
||||
@@ -251,7 +256,7 @@ public sealed class TextSearchProviderTests
|
||||
ContextFormatter = r => $"Custom formatted context with {r.Count} results."
|
||||
};
|
||||
var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options);
|
||||
var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]);
|
||||
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]);
|
||||
|
||||
// Act
|
||||
var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None);
|
||||
@@ -285,7 +290,7 @@ public sealed class TextSearchProviderTests
|
||||
ContextFormatter = r => string.Join(",", r.Select(x => ((RawPayload)x.RawRepresentation!).Id))
|
||||
};
|
||||
var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options);
|
||||
var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]);
|
||||
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]);
|
||||
|
||||
// Act
|
||||
var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None);
|
||||
@@ -302,7 +307,7 @@ public sealed class TextSearchProviderTests
|
||||
// Arrange
|
||||
var options = new TextSearchProviderOptions { SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke };
|
||||
var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options);
|
||||
var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]);
|
||||
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]);
|
||||
|
||||
// Act
|
||||
var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None);
|
||||
@@ -340,12 +345,14 @@ public sealed class TextSearchProviderTests
|
||||
new ChatMessage(ChatRole.User, "C"),
|
||||
new ChatMessage(ChatRole.Assistant, "D"),
|
||||
};
|
||||
await provider.InvokedAsync(new(initialMessages, aiContextProviderMessages: null) { InvokeException = new InvalidOperationException("Request Failed") });
|
||||
await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages, aiContextProviderMessages: null) { InvokeException = new InvalidOperationException("Request Failed") });
|
||||
|
||||
var invokingContext = new AIContextProvider.InvokingContext(
|
||||
[
|
||||
new ChatMessage(ChatRole.User, "E")
|
||||
]);
|
||||
s_mockAgent,
|
||||
s_mockSession,
|
||||
[
|
||||
new ChatMessage(ChatRole.User, "E")
|
||||
]);
|
||||
|
||||
// Act
|
||||
await provider.InvokingAsync(invokingContext, CancellationToken.None);
|
||||
@@ -380,12 +387,14 @@ public sealed class TextSearchProviderTests
|
||||
new ChatMessage(ChatRole.User, "C"),
|
||||
new ChatMessage(ChatRole.Assistant, "D"),
|
||||
};
|
||||
await provider.InvokedAsync(new(initialMessages, aiContextProviderMessages: null));
|
||||
await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages, aiContextProviderMessages: null));
|
||||
|
||||
var invokingContext = new AIContextProvider.InvokingContext(
|
||||
[
|
||||
new ChatMessage(ChatRole.User, "E")
|
||||
]);
|
||||
s_mockAgent,
|
||||
s_mockSession,
|
||||
[
|
||||
new ChatMessage(ChatRole.User, "E")
|
||||
]);
|
||||
|
||||
// Act
|
||||
await provider.InvokingAsync(invokingContext, CancellationToken.None);
|
||||
@@ -414,20 +423,24 @@ public sealed class TextSearchProviderTests
|
||||
|
||||
// First memory update (A,B)
|
||||
await provider.InvokedAsync(new(
|
||||
[
|
||||
new ChatMessage(ChatRole.User, "A"),
|
||||
new ChatMessage(ChatRole.Assistant, "B"),
|
||||
], aiContextProviderMessages: null));
|
||||
s_mockAgent,
|
||||
s_mockSession,
|
||||
[
|
||||
new ChatMessage(ChatRole.User, "A"),
|
||||
new ChatMessage(ChatRole.Assistant, "B"),
|
||||
], aiContextProviderMessages: null));
|
||||
|
||||
// Second memory update (C,D,E)
|
||||
await provider.InvokedAsync(new(
|
||||
[
|
||||
new ChatMessage(ChatRole.User, "C"),
|
||||
new ChatMessage(ChatRole.Assistant, "D"),
|
||||
new ChatMessage(ChatRole.User, "E"),
|
||||
], aiContextProviderMessages: null));
|
||||
s_mockAgent,
|
||||
s_mockSession,
|
||||
[
|
||||
new ChatMessage(ChatRole.User, "C"),
|
||||
new ChatMessage(ChatRole.Assistant, "D"),
|
||||
new ChatMessage(ChatRole.User, "E"),
|
||||
], aiContextProviderMessages: null));
|
||||
|
||||
var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "F")]);
|
||||
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "F")]);
|
||||
|
||||
// Act
|
||||
await provider.InvokingAsync(invokingContext, CancellationToken.None);
|
||||
@@ -462,12 +475,14 @@ public sealed class TextSearchProviderTests
|
||||
new ChatMessage(ChatRole.User, "U2"),
|
||||
new ChatMessage(ChatRole.Assistant, "A2"),
|
||||
};
|
||||
await provider.InvokedAsync(new(initialMessages, null));
|
||||
await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages, null));
|
||||
|
||||
var invokingContext = new AIContextProvider.InvokingContext(
|
||||
[
|
||||
new ChatMessage(ChatRole.User, "Question?") // Current request message always appended.
|
||||
]);
|
||||
s_mockAgent,
|
||||
s_mockSession,
|
||||
[
|
||||
new ChatMessage(ChatRole.User, "Question?") // Current request message always appended.
|
||||
]);
|
||||
|
||||
// Act
|
||||
await provider.InvokingAsync(invokingContext, CancellationToken.None);
|
||||
@@ -518,7 +533,7 @@ public sealed class TextSearchProviderTests
|
||||
};
|
||||
|
||||
// Act
|
||||
await provider.InvokedAsync(new(messages, aiContextProviderMessages: null)); // Populate recent memory.
|
||||
await provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null)); // Populate recent memory.
|
||||
var state = provider.Serialize();
|
||||
|
||||
// Assert
|
||||
@@ -547,7 +562,7 @@ public sealed class TextSearchProviderTests
|
||||
new ChatMessage(ChatRole.User, "C"),
|
||||
new ChatMessage(ChatRole.Assistant, "D"),
|
||||
};
|
||||
await provider.InvokedAsync(new(messages, aiContextProviderMessages: null));
|
||||
await provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null));
|
||||
|
||||
// Act
|
||||
var state = provider.Serialize();
|
||||
@@ -563,7 +578,7 @@ public sealed class TextSearchProviderTests
|
||||
RecentMessageMemoryLimit = 4
|
||||
});
|
||||
var emptyMessages = Array.Empty<ChatMessage>();
|
||||
await roundTrippedProvider.InvokingAsync(new(emptyMessages), CancellationToken.None); // Trigger search to read memory.
|
||||
await roundTrippedProvider.InvokingAsync(new(s_mockAgent, s_mockSession, emptyMessages), CancellationToken.None); // Trigger search to read memory.
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(capturedInput);
|
||||
@@ -588,7 +603,7 @@ public sealed class TextSearchProviderTests
|
||||
new ChatMessage(ChatRole.Assistant, "L4"),
|
||||
new ChatMessage(ChatRole.User, "L5"),
|
||||
};
|
||||
await initialProvider.InvokedAsync(new(messages, aiContextProviderMessages: null));
|
||||
await initialProvider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null));
|
||||
var state = initialProvider.Serialize();
|
||||
|
||||
string? capturedInput = null;
|
||||
@@ -604,7 +619,7 @@ public sealed class TextSearchProviderTests
|
||||
SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke,
|
||||
RecentMessageMemoryLimit = 3 // Lower limit
|
||||
});
|
||||
await restoredProvider.InvokingAsync(new(Array.Empty<ChatMessage>()), CancellationToken.None);
|
||||
await restoredProvider.InvokingAsync(new(s_mockAgent, s_mockSession, Array.Empty<ChatMessage>()), CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(capturedInput);
|
||||
@@ -631,7 +646,7 @@ public sealed class TextSearchProviderTests
|
||||
RecentMessageMemoryLimit = 3
|
||||
});
|
||||
var emptyMessages = Array.Empty<ChatMessage>();
|
||||
await provider.InvokingAsync(new(emptyMessages), CancellationToken.None);
|
||||
await provider.InvokingAsync(new(s_mockAgent, s_mockSession, emptyMessages), CancellationToken.None);
|
||||
|
||||
// Assert
|
||||
Assert.NotNull(capturedInput);
|
||||
|
||||
+10
-7
@@ -18,6 +18,9 @@ namespace Microsoft.Agents.AI.Memory.UnitTests;
|
||||
/// </summary>
|
||||
public class ChatHistoryMemoryProviderTests
|
||||
{
|
||||
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
|
||||
private static readonly AgentSession s_mockSession = new Mock<AgentSession>().Object;
|
||||
|
||||
private readonly Mock<ILogger<ChatHistoryMemoryProvider>> _loggerMock;
|
||||
private readonly Mock<ILoggerFactory> _loggerFactoryMock;
|
||||
|
||||
@@ -116,7 +119,7 @@ public class ChatHistoryMemoryProviderTests
|
||||
var requestMsgWithNulls = new ChatMessage(ChatRole.User, "request text nulls");
|
||||
var responseMsg = new ChatMessage(ChatRole.Assistant, "response text") { MessageId = "resp-1", AuthorName = "assistant" };
|
||||
|
||||
var invokedContext = new AIContextProvider.InvokedContext([requestMsgWithValues, requestMsgWithNulls], aiContextProviderMessages: null)
|
||||
var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsgWithValues, requestMsgWithNulls], aiContextProviderMessages: null)
|
||||
{
|
||||
ResponseMessages = [responseMsg]
|
||||
};
|
||||
@@ -174,7 +177,7 @@ public class ChatHistoryMemoryProviderTests
|
||||
1,
|
||||
new ChatHistoryMemoryProviderScope() { UserId = "UID" });
|
||||
var requestMsg = new ChatMessage(ChatRole.User, "request text") { MessageId = "req-1" };
|
||||
var invokedContext = new AIContextProvider.InvokedContext([requestMsg], aiContextProviderMessages: null)
|
||||
var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg], aiContextProviderMessages: null)
|
||||
{
|
||||
InvokeException = new InvalidOperationException("Invoke failed")
|
||||
};
|
||||
@@ -203,7 +206,7 @@ public class ChatHistoryMemoryProviderTests
|
||||
new ChatHistoryMemoryProviderScope() { UserId = "UID" },
|
||||
loggerFactory: this._loggerFactoryMock.Object);
|
||||
var requestMsg = new ChatMessage(ChatRole.User, "request text") { MessageId = "req-1" };
|
||||
var invokedContext = new AIContextProvider.InvokedContext([requestMsg], aiContextProviderMessages: null);
|
||||
var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg], aiContextProviderMessages: null);
|
||||
|
||||
// Act
|
||||
await provider.InvokedAsync(invokedContext, CancellationToken.None);
|
||||
@@ -254,7 +257,7 @@ public class ChatHistoryMemoryProviderTests
|
||||
loggerFactory: this._loggerFactoryMock.Object);
|
||||
|
||||
var requestMsg = new ChatMessage(ChatRole.User, "request text");
|
||||
var invokedContext = new AIContextProvider.InvokedContext([requestMsg], aiContextProviderMessages: null);
|
||||
var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg], aiContextProviderMessages: null);
|
||||
|
||||
// Act
|
||||
await provider.InvokedAsync(invokedContext, CancellationToken.None);
|
||||
@@ -327,7 +330,7 @@ public class ChatHistoryMemoryProviderTests
|
||||
options: providerOptions);
|
||||
|
||||
var requestMsg = new ChatMessage(ChatRole.User, "requesting relevant history");
|
||||
var invokingContext = new AIContextProvider.InvokingContext([requestMsg]);
|
||||
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [requestMsg]);
|
||||
|
||||
// Act
|
||||
await provider.InvokingAsync(invokingContext, CancellationToken.None);
|
||||
@@ -378,7 +381,7 @@ public class ChatHistoryMemoryProviderTests
|
||||
var provider = new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, TestCollectionName, 1, options: providerOptions, storageScope: searchScope, searchScope: searchScope);
|
||||
|
||||
var requestMsg = new ChatMessage(ChatRole.User, "requesting relevant history");
|
||||
var invokingContext = new AIContextProvider.InvokingContext([requestMsg]);
|
||||
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [requestMsg]);
|
||||
|
||||
// Act
|
||||
await provider.InvokingAsync(invokingContext, CancellationToken.None);
|
||||
@@ -442,7 +445,7 @@ public class ChatHistoryMemoryProviderTests
|
||||
options: options,
|
||||
loggerFactory: this._loggerFactoryMock.Object);
|
||||
|
||||
var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "requesting relevant history")]);
|
||||
var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "requesting relevant history")]);
|
||||
|
||||
// Act
|
||||
await provider.InvokingAsync(invokingContext, CancellationToken.None);
|
||||
|
||||
@@ -14,4 +14,5 @@ namespace Microsoft.Agents.AI.UnitTests;
|
||||
[JsonSerializable(typeof(string))]
|
||||
[JsonSerializable(typeof(string[]))]
|
||||
[JsonSerializable(typeof(Dictionary<string, object?>))]
|
||||
[JsonSerializable(typeof(ChatClientAgentSessionTests.Animal))]
|
||||
internal sealed partial class TestJsonSerializerContext : JsonSerializerContext;
|
||||
|
||||
@@ -23,7 +23,7 @@ public class OpenAIAssistantFixture : IChatClientAgentFixture
|
||||
|
||||
public IChatClient ChatClient => this._agent.ChatClient;
|
||||
|
||||
public async Task<List<ChatMessage>> GetChatHistoryAsync(AgentSession session)
|
||||
public async Task<List<ChatMessage>> GetChatHistoryAsync(AIAgent agent, AgentSession session)
|
||||
{
|
||||
var typedSession = (ChatClientAgentSession)session;
|
||||
List<ChatMessage> messages = [];
|
||||
|
||||
@@ -28,7 +28,7 @@ public class OpenAIChatCompletionFixture : IChatClientAgentFixture
|
||||
|
||||
public IChatClient ChatClient => this._agent.ChatClient;
|
||||
|
||||
public async Task<List<ChatMessage>> GetChatHistoryAsync(AgentSession session)
|
||||
public async Task<List<ChatMessage>> GetChatHistoryAsync(AIAgent agent, AgentSession session)
|
||||
{
|
||||
var typedSession = (ChatClientAgentSession)session;
|
||||
|
||||
@@ -37,7 +37,7 @@ public class OpenAIChatCompletionFixture : IChatClientAgentFixture
|
||||
return [];
|
||||
}
|
||||
|
||||
return (await typedSession.ChatHistoryProvider.InvokingAsync(new([]))).ToList();
|
||||
return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList();
|
||||
}
|
||||
|
||||
public Task<ChatClientAgent> CreateChatClientAgentAsync(
|
||||
|
||||
@@ -25,7 +25,7 @@ public class OpenAIResponseFixture(bool store) : IChatClientAgentFixture
|
||||
|
||||
public IChatClient ChatClient => this._agent.ChatClient;
|
||||
|
||||
public async Task<List<ChatMessage>> GetChatHistoryAsync(AgentSession session)
|
||||
public async Task<List<ChatMessage>> GetChatHistoryAsync(AIAgent agent, AgentSession session)
|
||||
{
|
||||
var typedSession = (ChatClientAgentSession)session;
|
||||
|
||||
@@ -55,7 +55,7 @@ public class OpenAIResponseFixture(bool store) : IChatClientAgentFixture
|
||||
return [];
|
||||
}
|
||||
|
||||
return (await typedSession.ChatHistoryProvider.InvokingAsync(new([]))).ToList();
|
||||
return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList();
|
||||
}
|
||||
|
||||
private static ChatMessage ConvertToChatMessage(ResponseItem item)
|
||||
|
||||
Reference in New Issue
Block a user