mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Replace Typed Base Providers with Composition (#3988)
This commit is contained in:
committed by
GitHub
Unverified
parent
8015e00f56
commit
cd4e36ebf7
+13
-7
@@ -86,25 +86,31 @@ namespace SampleApp
|
||||
/// <summary>
|
||||
/// Sample memory component that can remember a user's name and age.
|
||||
/// </summary>
|
||||
internal sealed class UserInfoMemory : AIContextProvider<UserInfo>
|
||||
internal sealed class UserInfoMemory : AIContextProvider
|
||||
{
|
||||
private readonly ProviderSessionState<UserInfo> _sessionState;
|
||||
private readonly IChatClient _chatClient;
|
||||
|
||||
public UserInfoMemory(IChatClient chatClient, Func<AgentSession?, UserInfo>? stateInitializer = null)
|
||||
: base(stateInitializer ?? (_ => new UserInfo()), null, null, null, null)
|
||||
: base(null, null)
|
||||
{
|
||||
this._sessionState = new ProviderSessionState<UserInfo>(
|
||||
stateInitializer ?? (_ => new UserInfo()),
|
||||
this.GetType().Name);
|
||||
this._chatClient = chatClient;
|
||||
}
|
||||
|
||||
public override string StateKey => this._sessionState.StateKey;
|
||||
|
||||
public UserInfo GetUserInfo(AgentSession session)
|
||||
=> this.GetOrInitializeState(session);
|
||||
=> this._sessionState.GetOrInitializeState(session);
|
||||
|
||||
public void SetUserInfo(AgentSession session, UserInfo userInfo)
|
||||
=> this.SaveState(session, userInfo);
|
||||
=> this._sessionState.SaveState(session, userInfo);
|
||||
|
||||
protected override async ValueTask StoreAIContextAsync(InvokedContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var userInfo = this.GetOrInitializeState(context.Session);
|
||||
var userInfo = this._sessionState.GetOrInitializeState(context.Session);
|
||||
|
||||
// Try and extract the user name and age from the message if we don't have it already and it's a user message.
|
||||
if ((userInfo.UserName is null || userInfo.UserAge is null) && context.RequestMessages.Any(x => x.Role == ChatRole.User))
|
||||
@@ -121,12 +127,12 @@ namespace SampleApp
|
||||
userInfo.UserAge ??= result.Result.UserAge;
|
||||
}
|
||||
|
||||
this.SaveState(context.Session, userInfo);
|
||||
this._sessionState.SaveState(context.Session, userInfo);
|
||||
}
|
||||
|
||||
protected override ValueTask<AIContext> ProvideAIContextAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var userInfo = this.GetOrInitializeState(context.Session);
|
||||
var userInfo = this._sessionState.GetOrInitializeState(context.Session);
|
||||
|
||||
StringBuilder instructions = new();
|
||||
|
||||
|
||||
+11
-5
@@ -76,25 +76,31 @@ namespace SampleApp
|
||||
/// State (the session DB key) is stored in the <see cref="AgentSession.StateBag"/> so it roundtrips
|
||||
/// automatically with session serialization.
|
||||
/// </summary>
|
||||
internal sealed class VectorChatHistoryProvider : ChatHistoryProvider<VectorChatHistoryProvider.State>
|
||||
internal sealed class VectorChatHistoryProvider : ChatHistoryProvider
|
||||
{
|
||||
private readonly ProviderSessionState<State> _sessionState;
|
||||
private readonly VectorStore _vectorStore;
|
||||
|
||||
public VectorChatHistoryProvider(
|
||||
VectorStore vectorStore,
|
||||
Func<AgentSession?, State>? stateInitializer = null,
|
||||
string? stateKey = null)
|
||||
: base(stateInitializer: stateInitializer ?? (_ => new State(Guid.NewGuid().ToString("N"))), stateKey: stateKey, jsonSerializerOptions: null, provideOutputMessageFilter: null, storeInputMessageFilter: null)
|
||||
: base(provideOutputMessageFilter: null, storeInputMessageFilter: null)
|
||||
{
|
||||
this._sessionState = new ProviderSessionState<State>(
|
||||
stateInitializer ?? (_ => new State(Guid.NewGuid().ToString("N"))),
|
||||
stateKey ?? this.GetType().Name);
|
||||
this._vectorStore = vectorStore ?? throw new ArgumentNullException(nameof(vectorStore));
|
||||
}
|
||||
|
||||
public override string StateKey => this._sessionState.StateKey;
|
||||
|
||||
public string GetSessionDbKey(AgentSession session)
|
||||
=> this.GetOrInitializeState(session).SessionDbKey;
|
||||
=> this._sessionState.GetOrInitializeState(session).SessionDbKey;
|
||||
|
||||
protected override async ValueTask<IEnumerable<ChatMessage>> ProvideChatHistoryAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var state = this.GetOrInitializeState(context.Session);
|
||||
var state = this._sessionState.GetOrInitializeState(context.Session);
|
||||
var collection = this._vectorStore.GetCollection<string, ChatHistoryItem>("ChatHistory");
|
||||
await collection.EnsureCollectionExistsAsync(cancellationToken);
|
||||
|
||||
@@ -112,7 +118,7 @@ namespace SampleApp
|
||||
|
||||
protected override async ValueTask StoreChatHistoryAsync(InvokedContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var state = this.GetOrInitializeState(context.Session);
|
||||
var state = this._sessionState.GetOrInitializeState(context.Session);
|
||||
|
||||
var collection = this._vectorStore.GetCollection<string, ChatHistoryItem>("ChatHistory");
|
||||
await collection.EnsureCollectionExistsAsync(cancellationToken);
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Text.Json;
|
||||
using Microsoft.Extensions.AI;
|
||||
|
||||
namespace Microsoft.Agents.AI;
|
||||
|
||||
/// <summary>
|
||||
/// Provides an abstract base class for components that enhance AI context during agent invocations with support for maintaining provider state of type <typeparamref name="TState"/>.
|
||||
/// </summary>
|
||||
/// <typeparam name="TState">The type of the state to be maintained by the context provider. Must be a reference type.</typeparam>
|
||||
/// <remarks>
|
||||
/// This class extends <see cref="AIContextProvider"/> by introducing a strongly-typed state management mechanism, allowing derived classes to maintain and persist custom state information across invocations.
|
||||
/// The state is stored in the session's StateBag using a configurable key and JSON serialization options, enabling seamless integration with the agent session lifecycle.
|
||||
/// </remarks>
|
||||
public abstract class AIContextProvider<TState> : AIContextProvider
|
||||
where TState : class
|
||||
{
|
||||
private readonly Func<AgentSession?, TState> _stateInitializer;
|
||||
private readonly string _stateKey;
|
||||
private readonly JsonSerializerOptions _jsonSerializerOptions;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="AIContextProvider{TState}"/> class.
|
||||
/// </summary>
|
||||
/// <param name="stateInitializer">A function to initialize the state for the context provider.</param>
|
||||
/// <param name="stateKey">The key used to store the state in the session's StateBag.</param>
|
||||
/// <param name="jsonSerializerOptions">Options for JSON serialization and deserialization of the state.</param>
|
||||
/// <param name="provideInputMessageFilter">An optional filter function to apply to input messages before providing context. If not set, defaults to including only <see cref="AgentRequestMessageSourceType.External"/> messages.</param>
|
||||
/// <param name="storeInputMessageFilter">An optional filter function to apply to request messages before storing context. If not set, defaults to including only <see cref="AgentRequestMessageSourceType.External"/> messages.</param>
|
||||
protected AIContextProvider(
|
||||
Func<AgentSession?, TState> stateInitializer,
|
||||
string? stateKey,
|
||||
JsonSerializerOptions? jsonSerializerOptions,
|
||||
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? provideInputMessageFilter,
|
||||
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? storeInputMessageFilter)
|
||||
: base(provideInputMessageFilter, storeInputMessageFilter)
|
||||
{
|
||||
this._stateInitializer = stateInitializer;
|
||||
this._jsonSerializerOptions = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions;
|
||||
this._stateKey = stateKey ?? this.GetType().Name;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override string StateKey => this._stateKey;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the state from the session's StateBag, or initializes it using the state initializer if not present.
|
||||
/// </summary>
|
||||
/// <param name="session">The agent session containing the StateBag.</param>
|
||||
/// <returns>The provider state.</returns>
|
||||
protected virtual TState GetOrInitializeState(AgentSession? session)
|
||||
{
|
||||
if (session?.StateBag.TryGetValue<TState>(this._stateKey, out var state, this._jsonSerializerOptions) is true && state is not null)
|
||||
{
|
||||
return state;
|
||||
}
|
||||
|
||||
state = this._stateInitializer(session);
|
||||
if (session is not null)
|
||||
{
|
||||
session.StateBag.SetValue(this._stateKey, state, this._jsonSerializerOptions);
|
||||
}
|
||||
|
||||
return state;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Saves the specified state to the session's StateBag using the configured state key and JSON serializer options.
|
||||
/// If the session is null, this method does nothing.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// This method provides a convenient way for derived classes to persist state changes back to the session after processing.
|
||||
/// It abstracts away the details of how state is stored in the session, allowing derived classes to focus on their specific logic.
|
||||
/// </remarks>
|
||||
/// <param name="session">The agent session containing the StateBag.</param>
|
||||
/// <param name="state">The state to be saved.</param>
|
||||
protected virtual void SaveState(AgentSession? session, TState state)
|
||||
{
|
||||
if (session is not null)
|
||||
{
|
||||
session.StateBag.SetValue(this._stateKey, state, this._jsonSerializerOptions);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,87 +0,0 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Text.Json;
|
||||
using Microsoft.Extensions.AI;
|
||||
|
||||
namespace Microsoft.Agents.AI;
|
||||
|
||||
/// <summary>
|
||||
/// Provides an abstract base class for fetching chat messages from, and adding chat messages to, chat history for the purposes of agent execution with support for maintaining provider state of type <typeparamref name="TState"/>.
|
||||
/// </summary>
|
||||
/// <typeparam name="TState">The type of the state to be maintained by the chat history provider. Must be a reference type.</typeparam>
|
||||
/// <remarks>
|
||||
/// This class extends <see cref="ChatHistoryProvider"/> by introducing a strongly-typed state management mechanism, allowing derived classes to maintain and persist custom state information across invocations.
|
||||
/// The state is stored in the session's StateBag using a configurable key and JSON serialization options, enabling seamless integration with the agent session lifecycle.
|
||||
/// </remarks>
|
||||
public abstract class ChatHistoryProvider<TState> : ChatHistoryProvider
|
||||
where TState : class
|
||||
{
|
||||
private readonly Func<AgentSession?, TState> _stateInitializer;
|
||||
private readonly string _stateKey;
|
||||
private readonly JsonSerializerOptions _jsonSerializerOptions;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="ChatHistoryProvider{TState}"/> class.
|
||||
/// </summary>
|
||||
/// <param name="stateInitializer">A function to initialize the state for the chat history provider.</param>
|
||||
/// <param name="stateKey">The key used to store the state in the session's StateBag.</param>
|
||||
/// <param name="jsonSerializerOptions">Options for JSON serialization and deserialization of the state.</param>
|
||||
/// <param name="provideOutputMessageFilter">A filter function to apply to messages when retrieving them from the chat history.</param>
|
||||
/// <param name="storeInputMessageFilter">A filter function to apply to messages before storing them in the chat history.</param>
|
||||
protected ChatHistoryProvider(
|
||||
Func<AgentSession?, TState> stateInitializer,
|
||||
string? stateKey,
|
||||
JsonSerializerOptions? jsonSerializerOptions,
|
||||
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? provideOutputMessageFilter,
|
||||
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? storeInputMessageFilter)
|
||||
: base(provideOutputMessageFilter, storeInputMessageFilter)
|
||||
{
|
||||
this._stateInitializer = stateInitializer;
|
||||
this._jsonSerializerOptions = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions;
|
||||
this._stateKey = stateKey ?? this.GetType().Name;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override string StateKey => this._stateKey;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the state from the session's StateBag, or initializes it using the state initializer if not present.
|
||||
/// </summary>
|
||||
/// <param name="session">The agent session containing the StateBag.</param>
|
||||
/// <returns>The provider state, or null if no session is available.</returns>
|
||||
protected virtual TState GetOrInitializeState(AgentSession? session)
|
||||
{
|
||||
if (session?.StateBag.TryGetValue<TState>(this._stateKey, out var state, this._jsonSerializerOptions) is true && state is not null)
|
||||
{
|
||||
return state;
|
||||
}
|
||||
|
||||
state = this._stateInitializer(session);
|
||||
if (session is not null)
|
||||
{
|
||||
session.StateBag.SetValue(this._stateKey, state, this._jsonSerializerOptions);
|
||||
}
|
||||
|
||||
return state;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Saves the specified state to the session's StateBag using the configured state key and JSON serializer options.
|
||||
/// If the session is null, this method does nothing.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// This method provides a convenient way for derived classes to persist state changes back to the session after processing.
|
||||
/// It abstracts away the details of how state is stored in the session, allowing derived classes to focus on their specific logic.
|
||||
/// </remarks>
|
||||
/// <param name="session">The agent session containing the StateBag.</param>
|
||||
/// <param name="state">The state to be saved.</param>
|
||||
protected virtual void SaveState(AgentSession? session, TState state)
|
||||
{
|
||||
if (session is not null)
|
||||
{
|
||||
session.StateBag.SetValue(this._stateKey, state, this._jsonSerializerOptions);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -24,8 +24,10 @@ namespace Microsoft.Agents.AI;
|
||||
/// message reduction strategies or alternative storage implementations.
|
||||
/// </para>
|
||||
/// </remarks>
|
||||
public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider<InMemoryChatHistoryProvider.State>
|
||||
public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider
|
||||
{
|
||||
private readonly ProviderSessionState<State> _sessionState;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="InMemoryChatHistoryProvider"/> class.
|
||||
/// </summary>
|
||||
@@ -35,16 +37,20 @@ public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider<InMemoryCh
|
||||
/// </param>
|
||||
public InMemoryChatHistoryProvider(InMemoryChatHistoryProviderOptions? options = null)
|
||||
: base(
|
||||
options?.StateInitializer ?? (_ => new State()),
|
||||
options?.StateKey,
|
||||
options?.JsonSerializerOptions,
|
||||
options?.ProvideOutputMessageFilter,
|
||||
options?.StorageInputMessageFilter)
|
||||
{
|
||||
this._sessionState = new ProviderSessionState<State>(
|
||||
options?.StateInitializer ?? (_ => new State()),
|
||||
options?.StateKey ?? this.GetType().Name,
|
||||
options?.JsonSerializerOptions);
|
||||
this.ChatReducer = options?.ChatReducer;
|
||||
this.ReducerTriggerEvent = options?.ReducerTriggerEvent ?? InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.BeforeMessagesRetrieval;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override string StateKey => this._sessionState.StateKey;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the chat reducer used to process or reduce chat messages. If null, no reduction logic will be applied.
|
||||
/// </summary>
|
||||
@@ -61,7 +67,7 @@ public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider<InMemoryCh
|
||||
/// <param name="session">The agent session containing the state.</param>
|
||||
/// <returns>A list of chat messages, or an empty list if no state is found.</returns>
|
||||
public List<ChatMessage> GetMessages(AgentSession? session)
|
||||
=> this.GetOrInitializeState(session).Messages;
|
||||
=> this._sessionState.GetOrInitializeState(session).Messages;
|
||||
|
||||
/// <summary>
|
||||
/// Sets the chat messages for the specified session.
|
||||
@@ -73,14 +79,14 @@ public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider<InMemoryCh
|
||||
{
|
||||
_ = Throw.IfNull(messages);
|
||||
|
||||
var state = this.GetOrInitializeState(session);
|
||||
var state = this._sessionState.GetOrInitializeState(session);
|
||||
state.Messages = messages;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override async ValueTask<IEnumerable<ChatMessage>> ProvideChatHistoryAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var state = this.GetOrInitializeState(context.Session);
|
||||
var state = this._sessionState.GetOrInitializeState(context.Session);
|
||||
|
||||
if (this.ReducerTriggerEvent is InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.BeforeMessagesRetrieval && this.ChatReducer is not null)
|
||||
{
|
||||
@@ -93,7 +99,7 @@ public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider<InMemoryCh
|
||||
/// <inheritdoc />
|
||||
protected override async ValueTask StoreChatHistoryAsync(InvokedContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var state = this.GetOrInitializeState(context.Session);
|
||||
var state = this._sessionState.GetOrInitializeState(context.Session);
|
||||
|
||||
// Add request and response messages to the provider
|
||||
var allNewMessages = context.RequestMessages.Concat(context.ResponseMessages ?? []);
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Text.Json;
|
||||
|
||||
namespace Microsoft.Agents.AI;
|
||||
|
||||
/// <summary>
|
||||
/// Provides strongly-typed state management for providers, enabling reading and writing of provider-specific state
|
||||
/// to and from an <see cref="AgentSession"/>'s <see cref="AgentSessionStateBag"/>.
|
||||
/// </summary>
|
||||
/// <typeparam name="TState">The type of the state to be maintained. Must be a reference type.</typeparam>
|
||||
/// <remarks>
|
||||
/// <para>
|
||||
/// This class encapsulates the logic for initializing, retrieving, and persisting provider state in the session's StateBag
|
||||
/// using a configurable key and JSON serialization options. It is intended to be used as a composed field within provider
|
||||
/// implementations (e.g., <see cref="AIContextProvider"/> or <see cref="ChatHistoryProvider"/> subclasses) to avoid
|
||||
/// duplicating state management logic across provider type hierarchies.
|
||||
/// </para>
|
||||
/// <para>
|
||||
/// State is stored in the <see cref="AgentSession.StateBag"/> using the <see cref="StateKey"/> property as the key,
|
||||
/// enabling multiple providers to maintain independent state within the same session.
|
||||
/// </para>
|
||||
/// </remarks>
|
||||
public class ProviderSessionState<TState>
|
||||
where TState : class
|
||||
{
|
||||
private readonly Func<AgentSession?, TState> _stateInitializer;
|
||||
private readonly JsonSerializerOptions _jsonSerializerOptions;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="ProviderSessionState{TState}"/> class.
|
||||
/// </summary>
|
||||
/// <param name="stateInitializer">A function to initialize the state when it is not yet present in the session's StateBag.</param>
|
||||
/// <param name="stateKey">The key used to store the state in the session's StateBag.</param>
|
||||
/// <param name="jsonSerializerOptions">Options for JSON serialization and deserialization of the state.</param>
|
||||
public ProviderSessionState(
|
||||
Func<AgentSession?, TState> stateInitializer,
|
||||
string stateKey,
|
||||
JsonSerializerOptions? jsonSerializerOptions = null)
|
||||
{
|
||||
this._stateInitializer = stateInitializer;
|
||||
this.StateKey = stateKey;
|
||||
this._jsonSerializerOptions = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the key used to store the provider state in the <see cref="AgentSession.StateBag"/>.
|
||||
/// </summary>
|
||||
public string StateKey { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the state from the session's StateBag, or initializes it using the state initializer if not present.
|
||||
/// </summary>
|
||||
/// <param name="session">The agent session containing the StateBag.</param>
|
||||
/// <returns>The provider state.</returns>
|
||||
public TState GetOrInitializeState(AgentSession? session)
|
||||
{
|
||||
if (session?.StateBag.TryGetValue<TState>(this.StateKey, out var state, this._jsonSerializerOptions) is true && state is not null)
|
||||
{
|
||||
return state;
|
||||
}
|
||||
|
||||
state = this._stateInitializer(session);
|
||||
if (session is not null)
|
||||
{
|
||||
session.StateBag.SetValue(this.StateKey, state, this._jsonSerializerOptions);
|
||||
}
|
||||
|
||||
return state;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Saves the specified state to the session's StateBag using the configured state key and JSON serializer options.
|
||||
/// If the session is null, this method does nothing.
|
||||
/// </summary>
|
||||
/// <param name="session">The agent session containing the StateBag.</param>
|
||||
/// <param name="state">The state to be saved.</param>
|
||||
public void SaveState(AgentSession? session, TState state)
|
||||
{
|
||||
if (session is not null)
|
||||
{
|
||||
session.StateBag.SetValue(this.StateKey, state, this._jsonSerializerOptions);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -19,8 +19,9 @@ namespace Microsoft.Agents.AI;
|
||||
/// </summary>
|
||||
[RequiresUnreferencedCode("The CosmosChatHistoryProvider uses JSON serialization which is incompatible with trimming.")]
|
||||
[RequiresDynamicCode("The CosmosChatHistoryProvider uses JSON serialization which is incompatible with NativeAOT.")]
|
||||
public sealed class CosmosChatHistoryProvider : ChatHistoryProvider<CosmosChatHistoryProvider.State>, IDisposable
|
||||
public sealed class CosmosChatHistoryProvider : ChatHistoryProvider, IDisposable
|
||||
{
|
||||
private readonly ProviderSessionState<State> _sessionState;
|
||||
private readonly CosmosClient _cosmosClient;
|
||||
private readonly Container _container;
|
||||
private readonly bool _ownsClient;
|
||||
@@ -98,8 +99,11 @@ public sealed class CosmosChatHistoryProvider : ChatHistoryProvider<CosmosChatHi
|
||||
string? stateKey = null,
|
||||
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? provideOutputMessageFilter = null,
|
||||
Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? storeInputMessageFilter = null)
|
||||
: base(Throw.IfNull(stateInitializer), stateKey, null, provideOutputMessageFilter, storeInputMessageFilter)
|
||||
: base(provideOutputMessageFilter, storeInputMessageFilter)
|
||||
{
|
||||
this._sessionState = new ProviderSessionState<State>(
|
||||
Throw.IfNull(stateInitializer),
|
||||
stateKey ?? this.GetType().Name);
|
||||
this._cosmosClient = Throw.IfNull(cosmosClient);
|
||||
this.DatabaseId = Throw.IfNullOrWhitespace(databaseId);
|
||||
this.ContainerId = Throw.IfNullOrWhitespace(containerId);
|
||||
@@ -107,6 +111,9 @@ public sealed class CosmosChatHistoryProvider : ChatHistoryProvider<CosmosChatHi
|
||||
this._ownsClient = ownsClient;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override string StateKey => this._sessionState.StateKey;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="CosmosChatHistoryProvider"/> class using a connection string.
|
||||
/// </summary>
|
||||
@@ -190,7 +197,7 @@ public sealed class CosmosChatHistoryProvider : ChatHistoryProvider<CosmosChatHi
|
||||
}
|
||||
#pragma warning restore CA1513
|
||||
|
||||
var state = this.GetOrInitializeState(context.Session);
|
||||
var state = this._sessionState.GetOrInitializeState(context.Session);
|
||||
var partitionKey = BuildPartitionKey(state);
|
||||
|
||||
// Fetch most recent messages in descending order when limit is set, then reverse to ascending
|
||||
@@ -253,7 +260,7 @@ public sealed class CosmosChatHistoryProvider : ChatHistoryProvider<CosmosChatHi
|
||||
}
|
||||
#pragma warning restore CA1513
|
||||
|
||||
var state = this.GetOrInitializeState(context.Session);
|
||||
var state = this._sessionState.GetOrInitializeState(context.Session);
|
||||
var messageList = context.RequestMessages.Concat(context.ResponseMessages ?? []).ToList();
|
||||
if (messageList.Count == 0)
|
||||
{
|
||||
@@ -424,7 +431,7 @@ public sealed class CosmosChatHistoryProvider : ChatHistoryProvider<CosmosChatHi
|
||||
}
|
||||
#pragma warning restore CA1513
|
||||
|
||||
var state = this.GetOrInitializeState(session);
|
||||
var state = this._sessionState.GetOrInitializeState(session);
|
||||
var partitionKey = BuildPartitionKey(state);
|
||||
|
||||
// Efficient count query
|
||||
@@ -458,7 +465,7 @@ public sealed class CosmosChatHistoryProvider : ChatHistoryProvider<CosmosChatHi
|
||||
}
|
||||
#pragma warning restore CA1513
|
||||
|
||||
var state = this.GetOrInitializeState(session);
|
||||
var state = this._sessionState.GetOrInitializeState(session);
|
||||
var partitionKey = BuildPartitionKey(state);
|
||||
|
||||
// Batch delete for efficiency
|
||||
|
||||
@@ -22,10 +22,11 @@ namespace Microsoft.Agents.AI.Mem0;
|
||||
/// for new invocations using a semantic search endpoint. Retrieved memories are injected as user messages
|
||||
/// to the model, prefixed by a configurable context prompt.
|
||||
/// </remarks>
|
||||
public sealed class Mem0Provider : AIContextProvider<Mem0Provider.State>
|
||||
public sealed class Mem0Provider : AIContextProvider
|
||||
{
|
||||
private const string DefaultContextPrompt = "## Memories\nConsider the following memories when answering user questions:";
|
||||
|
||||
private readonly ProviderSessionState<State> _sessionState;
|
||||
private readonly string _contextPrompt;
|
||||
private readonly bool _enableSensitiveTelemetryData;
|
||||
|
||||
@@ -51,8 +52,12 @@ public sealed class Mem0Provider : AIContextProvider<Mem0Provider.State>
|
||||
/// </code>
|
||||
/// </remarks>
|
||||
public Mem0Provider(HttpClient httpClient, Func<AgentSession?, State> stateInitializer, Mem0ProviderOptions? options = null, ILoggerFactory? loggerFactory = null)
|
||||
: base(ValidateStateInitializer(Throw.IfNull(stateInitializer)), options?.StateKey, Mem0JsonUtilities.DefaultOptions, options?.SearchInputMessageFilter, options?.StorageInputMessageFilter)
|
||||
: base(options?.SearchInputMessageFilter, options?.StorageInputMessageFilter)
|
||||
{
|
||||
this._sessionState = new ProviderSessionState<State>(
|
||||
ValidateStateInitializer(Throw.IfNull(stateInitializer)),
|
||||
options?.StateKey ?? this.GetType().Name,
|
||||
Mem0JsonUtilities.DefaultOptions);
|
||||
Throw.IfNull(httpClient);
|
||||
if (string.IsNullOrWhiteSpace(httpClient.BaseAddress?.AbsoluteUri))
|
||||
{
|
||||
@@ -66,6 +71,9 @@ public sealed class Mem0Provider : AIContextProvider<Mem0Provider.State>
|
||||
this._enableSensitiveTelemetryData = options?.EnableSensitiveTelemetryData ?? false;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override string StateKey => this._sessionState.StateKey;
|
||||
|
||||
private static Func<AgentSession?, State> ValidateStateInitializer(Func<AgentSession?, State> stateInitializer) =>
|
||||
session =>
|
||||
{
|
||||
@@ -88,7 +96,7 @@ public sealed class Mem0Provider : AIContextProvider<Mem0Provider.State>
|
||||
{
|
||||
Throw.IfNull(context);
|
||||
|
||||
var state = this.GetOrInitializeState(context.Session);
|
||||
var state = this._sessionState.GetOrInitializeState(context.Session);
|
||||
var searchScope = state.SearchScope;
|
||||
|
||||
string queryText = string.Join(
|
||||
@@ -165,7 +173,7 @@ public sealed class Mem0Provider : AIContextProvider<Mem0Provider.State>
|
||||
/// <inheritdoc />
|
||||
protected override async ValueTask StoreAIContextAsync(InvokedContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var state = this.GetOrInitializeState(context.Session);
|
||||
var state = this._sessionState.GetOrInitializeState(context.Session);
|
||||
var storageScope = state.StorageScope;
|
||||
|
||||
try
|
||||
@@ -200,7 +208,7 @@ public sealed class Mem0Provider : AIContextProvider<Mem0Provider.State>
|
||||
public Task ClearStoredMemoriesAsync(AgentSession session, CancellationToken cancellationToken = default)
|
||||
{
|
||||
Throw.IfNull(session);
|
||||
var state = this.GetOrInitializeState(session);
|
||||
var state = this._sessionState.GetOrInitializeState(session);
|
||||
var storageScope = state.StorageScope;
|
||||
|
||||
return this._client.ClearMemoryAsync(
|
||||
|
||||
@@ -9,8 +9,10 @@ using Microsoft.Extensions.AI;
|
||||
|
||||
namespace Microsoft.Agents.AI.Workflows;
|
||||
|
||||
internal sealed class WorkflowChatHistoryProvider : ChatHistoryProvider<WorkflowChatHistoryProvider.StoreState>
|
||||
internal sealed class WorkflowChatHistoryProvider : ChatHistoryProvider
|
||||
{
|
||||
private readonly ProviderSessionState<StoreState> _sessionState;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="WorkflowChatHistoryProvider"/> class.
|
||||
/// </summary>
|
||||
@@ -20,10 +22,17 @@ internal sealed class WorkflowChatHistoryProvider : ChatHistoryProvider<Workflow
|
||||
/// and source generated serializers are required, or Native AOT / Trimming is required.
|
||||
/// </param>
|
||||
public WorkflowChatHistoryProvider(JsonSerializerOptions? jsonSerializerOptions = null)
|
||||
: base(stateInitializer: _ => new StoreState(), stateKey: null, jsonSerializerOptions: jsonSerializerOptions, provideOutputMessageFilter: null, storeInputMessageFilter: null)
|
||||
: base(provideOutputMessageFilter: null, storeInputMessageFilter: null)
|
||||
{
|
||||
this._sessionState = new ProviderSessionState<StoreState>(
|
||||
_ => new StoreState(),
|
||||
this.GetType().Name,
|
||||
jsonSerializerOptions);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override string StateKey => this._sessionState.StateKey;
|
||||
|
||||
internal sealed class StoreState
|
||||
{
|
||||
public int Bookmark { get; set; }
|
||||
@@ -31,21 +40,21 @@ internal sealed class WorkflowChatHistoryProvider : ChatHistoryProvider<Workflow
|
||||
}
|
||||
|
||||
internal void AddMessages(AgentSession session, params IEnumerable<ChatMessage> messages)
|
||||
=> this.GetOrInitializeState(session).Messages.AddRange(messages);
|
||||
=> this._sessionState.GetOrInitializeState(session).Messages.AddRange(messages);
|
||||
|
||||
protected override ValueTask<IEnumerable<ChatMessage>> ProvideChatHistoryAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
=> new(this.GetOrInitializeState(context.Session).Messages);
|
||||
=> new(this._sessionState.GetOrInitializeState(context.Session).Messages);
|
||||
|
||||
protected override ValueTask StoreChatHistoryAsync(InvokedContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var allNewMessages = context.RequestMessages.Concat(context.ResponseMessages ?? []);
|
||||
this.GetOrInitializeState(context.Session).Messages.AddRange(allNewMessages);
|
||||
this._sessionState.GetOrInitializeState(context.Session).Messages.AddRange(allNewMessages);
|
||||
return default;
|
||||
}
|
||||
|
||||
public IEnumerable<ChatMessage> GetFromBookmark(AgentSession session)
|
||||
{
|
||||
var state = this.GetOrInitializeState(session);
|
||||
var state = this._sessionState.GetOrInitializeState(session);
|
||||
|
||||
for (int i = state.Bookmark; i < state.Messages.Count; i++)
|
||||
{
|
||||
@@ -55,7 +64,7 @@ internal sealed class WorkflowChatHistoryProvider : ChatHistoryProvider<Workflow
|
||||
|
||||
public void UpdateBookmark(AgentSession session)
|
||||
{
|
||||
var state = this.GetOrInitializeState(session);
|
||||
var state = this._sessionState.GetOrInitializeState(session);
|
||||
state.Bookmark = state.Messages.Count;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,13 +34,15 @@ namespace Microsoft.Agents.AI;
|
||||
/// injecting them automatically on each invocation.
|
||||
/// </para>
|
||||
/// </remarks>
|
||||
public sealed class ChatHistoryMemoryProvider : AIContextProvider<ChatHistoryMemoryProvider.State>, IDisposable
|
||||
public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable
|
||||
{
|
||||
private const string DefaultContextPrompt = "## Memories\nConsider the following memories when answering user questions:";
|
||||
private const int DefaultMaxResults = 3;
|
||||
private const string DefaultFunctionToolName = "Search";
|
||||
private const string DefaultFunctionToolDescription = "Allows searching for related previous chat history to help answer the user question.";
|
||||
|
||||
private readonly ProviderSessionState<State> _sessionState;
|
||||
|
||||
#pragma warning disable CA2213 // VectorStore is not owned by this class - caller is responsible for disposal
|
||||
private readonly VectorStore _vectorStore;
|
||||
#pragma warning restore CA2213
|
||||
@@ -74,8 +76,12 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider<ChatHistoryMem
|
||||
Func<AgentSession?, State> stateInitializer,
|
||||
ChatHistoryMemoryProviderOptions? options = null,
|
||||
ILoggerFactory? loggerFactory = null)
|
||||
: base(Throw.IfNull(stateInitializer), options?.StateKey, AgentJsonUtilities.DefaultOptions, options?.SearchInputMessageFilter, options?.StorageInputMessageFilter)
|
||||
: base(options?.SearchInputMessageFilter, options?.StorageInputMessageFilter)
|
||||
{
|
||||
this._sessionState = new ProviderSessionState<State>(
|
||||
Throw.IfNull(stateInitializer),
|
||||
options?.StateKey ?? this.GetType().Name,
|
||||
AgentJsonUtilities.DefaultOptions);
|
||||
this._vectorStore = Throw.IfNull(vectorStore);
|
||||
|
||||
options ??= new ChatHistoryMemoryProviderOptions();
|
||||
@@ -109,12 +115,15 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider<ChatHistoryMem
|
||||
this._collection = this._vectorStore.GetDynamicCollection(Throw.IfNullOrWhitespace(collectionName), definition);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override string StateKey => this._sessionState.StateKey;
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override async ValueTask<AIContext> ProvideAIContextAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
_ = Throw.IfNull(context);
|
||||
|
||||
var state = this.GetOrInitializeState(context.Session);
|
||||
var state = this._sessionState.GetOrInitializeState(context.Session);
|
||||
var searchScope = state.SearchScope;
|
||||
|
||||
if (this._searchTime == ChatHistoryMemoryProviderOptions.SearchBehavior.OnDemandFunctionCalling)
|
||||
@@ -186,7 +195,7 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider<ChatHistoryMem
|
||||
{
|
||||
_ = Throw.IfNull(context);
|
||||
|
||||
var state = this.GetOrInitializeState(context.Session);
|
||||
var state = this._sessionState.GetOrInitializeState(context.Session);
|
||||
var storageScope = state.StorageScope;
|
||||
|
||||
try
|
||||
|
||||
@@ -32,13 +32,14 @@ namespace Microsoft.Agents.AI;
|
||||
/// multi-turn context to the retrieval layer without permanently altering the conversation history.
|
||||
/// </para>
|
||||
/// </remarks>
|
||||
public sealed class TextSearchProvider : AIContextProvider<TextSearchProvider.TextSearchProviderState>
|
||||
public sealed class TextSearchProvider : AIContextProvider
|
||||
{
|
||||
private const string DefaultPluginSearchFunctionName = "Search";
|
||||
private const string DefaultPluginSearchFunctionDescription = "Allows searching for additional information to help answer the user question.";
|
||||
private const string DefaultContextPrompt = "## Additional Context\nConsider the following information from source documents when responding to the user:";
|
||||
private const string DefaultCitationsPrompt = "Include citations to the source document with document name and link if document name and link is available.";
|
||||
|
||||
private readonly ProviderSessionState<TextSearchProviderState> _sessionState;
|
||||
private readonly Func<string, CancellationToken, Task<IEnumerable<TextSearchResult>>> _searchAsync;
|
||||
private readonly ILogger<TextSearchProvider>? _logger;
|
||||
private readonly AITool[] _tools;
|
||||
@@ -60,8 +61,12 @@ public sealed class TextSearchProvider : AIContextProvider<TextSearchProvider.Te
|
||||
Func<string, CancellationToken, Task<IEnumerable<TextSearchResult>>> searchAsync,
|
||||
TextSearchProviderOptions? options = null,
|
||||
ILoggerFactory? loggerFactory = null)
|
||||
: base(_ => new TextSearchProviderState(), options?.StateKey, AgentJsonUtilities.DefaultOptions, options?.SearchInputMessageFilter, options?.StorageInputMessageFilter)
|
||||
: base(options?.SearchInputMessageFilter, options?.StorageInputMessageFilter)
|
||||
{
|
||||
this._sessionState = new ProviderSessionState<TextSearchProviderState>(
|
||||
_ => new TextSearchProviderState(),
|
||||
options?.StateKey ?? this.GetType().Name,
|
||||
AgentJsonUtilities.DefaultOptions);
|
||||
// Validate and assign parameters
|
||||
this._searchAsync = Throw.IfNull(searchAsync);
|
||||
this._logger = loggerFactory?.CreateLogger<TextSearchProvider>();
|
||||
@@ -82,6 +87,9 @@ public sealed class TextSearchProvider : AIContextProvider<TextSearchProvider.Te
|
||||
];
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override string StateKey => this._sessionState.StateKey;
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override async ValueTask<AIContext> ProvideAIContextAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
@@ -95,7 +103,7 @@ public sealed class TextSearchProvider : AIContextProvider<TextSearchProvider.Te
|
||||
}
|
||||
|
||||
// Retrieve recent messages from the session state.
|
||||
var recentMessagesText = this.GetOrInitializeState(context.Session).RecentMessagesText
|
||||
var recentMessagesText = this._sessionState.GetOrInitializeState(context.Session).RecentMessagesText
|
||||
?? [];
|
||||
|
||||
// Aggregate text from memory + current request messages.
|
||||
@@ -165,7 +173,7 @@ public sealed class TextSearchProvider : AIContextProvider<TextSearchProvider.Te
|
||||
}
|
||||
|
||||
// Retrieve existing recent messages from the session state.
|
||||
var recentMessagesText = this.GetOrInitializeState(context.Session).RecentMessagesText
|
||||
var recentMessagesText = this._sessionState.GetOrInitializeState(context.Session).RecentMessagesText
|
||||
?? [];
|
||||
|
||||
var newMessagesText = context.RequestMessages
|
||||
@@ -182,7 +190,7 @@ public sealed class TextSearchProvider : AIContextProvider<TextSearchProvider.Te
|
||||
: allMessages;
|
||||
|
||||
// Store updated state back to the session.
|
||||
this.SaveState(
|
||||
this._sessionState.SaveState(
|
||||
context.Session,
|
||||
new TextSearchProviderState { RecentMessagesText = updatedMessages });
|
||||
|
||||
|
||||
-199
@@ -1,199 +0,0 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Linq;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.AI;
|
||||
using Moq;
|
||||
|
||||
namespace Microsoft.Agents.AI.Abstractions.UnitTests;
|
||||
|
||||
/// <summary>
|
||||
/// Contains tests for the <see cref="AIContextProvider{TState}"/> class.
|
||||
/// </summary>
|
||||
public class AIContextProviderTStateTests
|
||||
{
|
||||
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
|
||||
|
||||
#region GetOrInitializeState Tests
|
||||
|
||||
[Fact]
|
||||
public void GetOrInitializeState_InitializesFromStateInitializerOnFirstCall()
|
||||
{
|
||||
// Arrange
|
||||
var expectedState = new TestState { Value = "initialized" };
|
||||
var provider = new TestAIContextProvider(_ => expectedState);
|
||||
var session = new TestAgentSession();
|
||||
|
||||
// Act
|
||||
var state = provider.GetState(session);
|
||||
|
||||
// Assert
|
||||
Assert.Same(expectedState, state);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetOrInitializeState_ReturnsCachedStateFromStateBagOnSecondCall()
|
||||
{
|
||||
// Arrange
|
||||
var callCount = 0;
|
||||
var provider = new TestAIContextProvider(_ =>
|
||||
{
|
||||
callCount++;
|
||||
return new TestState { Value = $"init-{callCount}" };
|
||||
});
|
||||
var session = new TestAgentSession();
|
||||
|
||||
// Act
|
||||
var state1 = provider.GetState(session);
|
||||
var state2 = provider.GetState(session);
|
||||
|
||||
// Assert - initializer called only once; second call reads from StateBag
|
||||
Assert.Equal(1, callCount);
|
||||
Assert.Equal("init-1", state1.Value);
|
||||
Assert.Equal("init-1", state2.Value);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetOrInitializeState_WorksWhenSessionIsNull()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TestAIContextProvider(_ => new TestState { Value = "no-session" });
|
||||
|
||||
// Act
|
||||
var state = provider.GetState(null);
|
||||
|
||||
// Assert
|
||||
Assert.Equal("no-session", state.Value);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetOrInitializeState_ReInitializesWhenSessionIsNull()
|
||||
{
|
||||
// Arrange - without a session, state can't be cached in StateBag
|
||||
var callCount = 0;
|
||||
var provider = new TestAIContextProvider(_ =>
|
||||
{
|
||||
callCount++;
|
||||
return new TestState { Value = $"init-{callCount}" };
|
||||
});
|
||||
|
||||
// Act
|
||||
provider.GetState(null);
|
||||
provider.GetState(null);
|
||||
|
||||
// Assert - initializer called each time since there's no session to cache in
|
||||
Assert.Equal(2, callCount);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region SaveState Tests
|
||||
|
||||
[Fact]
|
||||
public void SaveState_SavesToStateBag()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TestAIContextProvider(_ => new TestState());
|
||||
var session = new TestAgentSession();
|
||||
var state = new TestState { Value = "saved" };
|
||||
|
||||
// Act
|
||||
provider.DoSaveState(session, state);
|
||||
var retrieved = provider.GetState(session);
|
||||
|
||||
// Assert
|
||||
Assert.Equal("saved", retrieved.Value);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void SaveState_NoOpWhenSessionIsNull()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TestAIContextProvider(_ => new TestState { Value = "default" });
|
||||
|
||||
// Act - should not throw
|
||||
provider.DoSaveState(null, new TestState { Value = "saved" });
|
||||
|
||||
// Assert - no exception; can't verify further without a session
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region StateKey Tests
|
||||
|
||||
[Fact]
|
||||
public void StateKey_DefaultsToTypeName()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TestAIContextProvider(_ => new TestState());
|
||||
|
||||
// Act & Assert
|
||||
Assert.Equal(nameof(TestAIContextProvider), provider.StateKey);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void StateKey_UsesCustomKeyWhenProvided()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TestAIContextProvider(_ => new TestState(), stateKey: "custom-key");
|
||||
|
||||
// Act & Assert
|
||||
Assert.Equal("custom-key", provider.StateKey);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Integration Tests
|
||||
|
||||
[Fact]
|
||||
public async Task InvokingCoreAsync_CanUseStateInProvideAIContextAsync()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TestAIContextProvider(_ => new TestState { Value = "state-value" });
|
||||
var session = new TestAgentSession();
|
||||
var inputContext = new AIContext { Messages = [new ChatMessage(ChatRole.User, "Hi")] };
|
||||
var context = new AIContextProvider.InvokingContext(s_mockAgent, session, inputContext);
|
||||
|
||||
// Act
|
||||
var result = await provider.InvokingAsync(context);
|
||||
|
||||
// Assert - the provider uses state to produce context messages
|
||||
var messages = result.Messages!.ToList();
|
||||
Assert.Equal(2, messages.Count);
|
||||
Assert.Contains("state-value", messages[1].Text);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
public sealed class TestState
|
||||
{
|
||||
public string Value { get; set; } = string.Empty;
|
||||
}
|
||||
|
||||
private sealed class TestAIContextProvider : AIContextProvider<TestState>
|
||||
{
|
||||
public TestAIContextProvider(
|
||||
Func<AgentSession?, TestState> stateInitializer,
|
||||
string? stateKey = null)
|
||||
: base(stateInitializer, stateKey, null, null, null)
|
||||
{
|
||||
}
|
||||
|
||||
public TestState GetState(AgentSession? session) => this.GetOrInitializeState(session);
|
||||
|
||||
public void DoSaveState(AgentSession? session, TestState state) => this.SaveState(session, state);
|
||||
|
||||
protected override ValueTask<AIContext> ProvideAIContextAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var state = this.GetOrInitializeState(context.Session);
|
||||
return new(new AIContext
|
||||
{
|
||||
Messages = [new ChatMessage(ChatRole.System, $"Context from state: {state.Value}")]
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private sealed class TestAgentSession : AgentSession;
|
||||
}
|
||||
-195
@@ -1,195 +0,0 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.AI;
|
||||
using Moq;
|
||||
|
||||
namespace Microsoft.Agents.AI.Abstractions.UnitTests;
|
||||
|
||||
/// <summary>
|
||||
/// Contains tests for the <see cref="ChatHistoryProvider{TState}"/> class.
|
||||
/// </summary>
|
||||
public class ChatHistoryProviderTStateTests
|
||||
{
|
||||
private static readonly AIAgent s_mockAgent = new Mock<AIAgent>().Object;
|
||||
|
||||
#region GetOrInitializeState Tests
|
||||
|
||||
[Fact]
|
||||
public void GetOrInitializeState_InitializesFromStateInitializerOnFirstCall()
|
||||
{
|
||||
// Arrange
|
||||
var expectedState = new TestState { Value = "initialized" };
|
||||
var provider = new TestChatHistoryProvider(_ => expectedState);
|
||||
var session = new TestAgentSession();
|
||||
|
||||
// Act
|
||||
var state = provider.GetState(session);
|
||||
|
||||
// Assert
|
||||
Assert.Same(expectedState, state);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetOrInitializeState_ReturnsCachedStateFromStateBagOnSecondCall()
|
||||
{
|
||||
// Arrange
|
||||
var callCount = 0;
|
||||
var provider = new TestChatHistoryProvider(_ =>
|
||||
{
|
||||
callCount++;
|
||||
return new TestState { Value = $"init-{callCount}" };
|
||||
});
|
||||
var session = new TestAgentSession();
|
||||
|
||||
// Act
|
||||
var state1 = provider.GetState(session);
|
||||
var state2 = provider.GetState(session);
|
||||
|
||||
// Assert - initializer called only once; second call reads from StateBag
|
||||
Assert.Equal(1, callCount);
|
||||
Assert.Equal("init-1", state1.Value);
|
||||
Assert.Equal("init-1", state2.Value);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetOrInitializeState_WorksWhenSessionIsNull()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TestChatHistoryProvider(_ => new TestState { Value = "no-session" });
|
||||
|
||||
// Act
|
||||
var state = provider.GetState(null);
|
||||
|
||||
// Assert
|
||||
Assert.Equal("no-session", state.Value);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetOrInitializeState_ReInitializesWhenSessionIsNull()
|
||||
{
|
||||
// Arrange - without a session, state can't be cached in StateBag
|
||||
var callCount = 0;
|
||||
var provider = new TestChatHistoryProvider(_ =>
|
||||
{
|
||||
callCount++;
|
||||
return new TestState { Value = $"init-{callCount}" };
|
||||
});
|
||||
|
||||
// Act
|
||||
_ = provider.GetState(null);
|
||||
provider.GetState(null);
|
||||
|
||||
// Assert - initializer called each time since there's no session to cache in
|
||||
Assert.Equal(2, callCount);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region SaveState Tests
|
||||
|
||||
[Fact]
|
||||
public void SaveState_SavesToStateBag()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TestChatHistoryProvider(_ => new TestState());
|
||||
var session = new TestAgentSession();
|
||||
var state = new TestState { Value = "saved" };
|
||||
|
||||
// Act
|
||||
provider.DoSaveState(session, state);
|
||||
var retrieved = provider.GetState(session);
|
||||
|
||||
// Assert
|
||||
Assert.Equal("saved", retrieved.Value);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void SaveState_NoOpWhenSessionIsNull()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TestChatHistoryProvider(_ => new TestState { Value = "default" });
|
||||
|
||||
// Act - should not throw
|
||||
provider.DoSaveState(null, new TestState { Value = "saved" });
|
||||
|
||||
// Assert - no exception; can't verify further without a session
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region StateKey Tests
|
||||
|
||||
[Fact]
|
||||
public void StateKey_DefaultsToTypeName()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TestChatHistoryProvider(_ => new TestState());
|
||||
|
||||
// Act & Assert
|
||||
Assert.Equal(nameof(TestChatHistoryProvider), provider.StateKey);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void StateKey_UsesCustomKeyWhenProvided()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TestChatHistoryProvider(_ => new TestState(), stateKey: "custom-key");
|
||||
|
||||
// Act & Assert
|
||||
Assert.Equal("custom-key", provider.StateKey);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Integration Tests
|
||||
|
||||
[Fact]
|
||||
public async Task InvokingCoreAsync_CanUseStateInProvideChatHistoryAsync()
|
||||
{
|
||||
// Arrange
|
||||
var provider = new TestChatHistoryProvider(_ => new TestState { Value = "state-value" });
|
||||
var session = new TestAgentSession();
|
||||
var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Hi")]);
|
||||
|
||||
// Act
|
||||
var result = (await provider.InvokingAsync(context)).ToList();
|
||||
|
||||
// Assert - the provider uses state to produce history messages
|
||||
Assert.Equal(2, result.Count);
|
||||
Assert.Contains("state-value", result[0].Text);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
public sealed class TestState
|
||||
{
|
||||
public string Value { get; set; } = string.Empty;
|
||||
}
|
||||
|
||||
private sealed class TestChatHistoryProvider : ChatHistoryProvider<TestState>
|
||||
{
|
||||
public TestChatHistoryProvider(
|
||||
Func<AgentSession?, TestState> stateInitializer,
|
||||
string? stateKey = null)
|
||||
: base(stateInitializer, stateKey, null, null, null)
|
||||
{
|
||||
}
|
||||
|
||||
public TestState GetState(AgentSession? session) => this.GetOrInitializeState(session);
|
||||
|
||||
public void DoSaveState(AgentSession? session, TestState state) => this.SaveState(session, state);
|
||||
|
||||
protected override ValueTask<IEnumerable<ChatMessage>> ProvideChatHistoryAsync(InvokingContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var state = this.GetOrInitializeState(context.Session);
|
||||
return new(new[] { new ChatMessage(ChatRole.System, $"History from state: {state.Value}") });
|
||||
}
|
||||
}
|
||||
|
||||
private sealed class TestAgentSession : AgentSession;
|
||||
}
|
||||
@@ -0,0 +1,189 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
namespace Microsoft.Agents.AI.Abstractions.UnitTests;
|
||||
|
||||
/// <summary>
|
||||
/// Contains tests for the <see cref="ProviderSessionState{TState}"/> class.
|
||||
/// </summary>
|
||||
public class ProviderSessionStateTests
|
||||
{
|
||||
#region GetOrInitializeState Tests
|
||||
|
||||
[Fact]
|
||||
public void GetOrInitializeState_InitializesFromStateInitializerOnFirstCall()
|
||||
{
|
||||
// Arrange
|
||||
var expectedState = new TestState { Value = "initialized" };
|
||||
var sessionState = new ProviderSessionState<TestState>(_ => expectedState, "test-key");
|
||||
var session = new TestAgentSession();
|
||||
|
||||
// Act
|
||||
var state = sessionState.GetOrInitializeState(session);
|
||||
|
||||
// Assert
|
||||
Assert.Same(expectedState, state);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetOrInitializeState_ReturnsCachedStateFromStateBagOnSecondCall()
|
||||
{
|
||||
// Arrange
|
||||
var callCount = 0;
|
||||
var sessionState = new ProviderSessionState<TestState>(_ =>
|
||||
{
|
||||
callCount++;
|
||||
return new TestState { Value = $"init-{callCount}" };
|
||||
}, "test-key");
|
||||
var session = new TestAgentSession();
|
||||
|
||||
// Act
|
||||
var state1 = sessionState.GetOrInitializeState(session);
|
||||
var state2 = sessionState.GetOrInitializeState(session);
|
||||
|
||||
// Assert - initializer called only once; second call reads from StateBag
|
||||
Assert.Equal(1, callCount);
|
||||
Assert.Equal("init-1", state1.Value);
|
||||
Assert.Equal("init-1", state2.Value);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetOrInitializeState_WorksWhenSessionIsNull()
|
||||
{
|
||||
// Arrange
|
||||
var sessionState = new ProviderSessionState<TestState>(_ => new TestState { Value = "no-session" }, "test-key");
|
||||
|
||||
// Act
|
||||
var state = sessionState.GetOrInitializeState(null);
|
||||
|
||||
// Assert
|
||||
Assert.Equal("no-session", state.Value);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetOrInitializeState_ReInitializesWhenSessionIsNull()
|
||||
{
|
||||
// Arrange - without a session, state can't be cached in StateBag
|
||||
var callCount = 0;
|
||||
var sessionState = new ProviderSessionState<TestState>(_ =>
|
||||
{
|
||||
callCount++;
|
||||
return new TestState { Value = $"init-{callCount}" };
|
||||
}, "test-key");
|
||||
|
||||
// Act
|
||||
sessionState.GetOrInitializeState(null);
|
||||
sessionState.GetOrInitializeState(null);
|
||||
|
||||
// Assert - initializer called each time since there's no session to cache in
|
||||
Assert.Equal(2, callCount);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region SaveState Tests
|
||||
|
||||
[Fact]
|
||||
public void SaveState_SavesToStateBag()
|
||||
{
|
||||
// Arrange
|
||||
var sessionState = new ProviderSessionState<TestState>(_ => new TestState(), "test-key");
|
||||
var session = new TestAgentSession();
|
||||
var state = new TestState { Value = "saved" };
|
||||
|
||||
// Act
|
||||
sessionState.SaveState(session, state);
|
||||
var retrieved = sessionState.GetOrInitializeState(session);
|
||||
|
||||
// Assert
|
||||
Assert.Equal("saved", retrieved.Value);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void SaveState_NoOpWhenSessionIsNull()
|
||||
{
|
||||
// Arrange
|
||||
var sessionState = new ProviderSessionState<TestState>(_ => new TestState { Value = "default" }, "test-key");
|
||||
|
||||
// Act - should not throw
|
||||
sessionState.SaveState(null, new TestState { Value = "saved" });
|
||||
|
||||
// Assert - no exception; can't verify further without a session
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region StateKey Tests
|
||||
|
||||
[Fact]
|
||||
public void StateKey_UsesProvidedKey()
|
||||
{
|
||||
// Arrange
|
||||
var sessionState = new ProviderSessionState<TestState>(_ => new TestState(), "my-provider-key");
|
||||
|
||||
// Act & Assert
|
||||
Assert.Equal("my-provider-key", sessionState.StateKey);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void StateKey_UsesCustomKeyWhenProvided()
|
||||
{
|
||||
// Arrange
|
||||
var sessionState = new ProviderSessionState<TestState>(_ => new TestState(), "custom-key");
|
||||
|
||||
// Act & Assert
|
||||
Assert.Equal("custom-key", sessionState.StateKey);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Isolation Tests
|
||||
|
||||
[Fact]
|
||||
public void GetOrInitializeState_IsolatesStateBetweenDifferentKeys()
|
||||
{
|
||||
// Arrange
|
||||
var sessionState1 = new ProviderSessionState<TestState>(_ => new TestState { Value = "state-1" }, "key-1");
|
||||
var sessionState2 = new ProviderSessionState<TestState>(_ => new TestState { Value = "state-2" }, "key-2");
|
||||
var session = new TestAgentSession();
|
||||
|
||||
// Act
|
||||
var state1 = sessionState1.GetOrInitializeState(session);
|
||||
var state2 = sessionState2.GetOrInitializeState(session);
|
||||
|
||||
// Assert - each key maintains independent state
|
||||
Assert.Equal("state-1", state1.Value);
|
||||
Assert.Equal("state-2", state2.Value);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetOrInitializeState_IsolatesStateBetweenDifferentSessions()
|
||||
{
|
||||
// Arrange
|
||||
var callCount = 0;
|
||||
var sessionState = new ProviderSessionState<TestState>(_ =>
|
||||
{
|
||||
callCount++;
|
||||
return new TestState { Value = $"init-{callCount}" };
|
||||
}, "test-key");
|
||||
var session1 = new TestAgentSession();
|
||||
var session2 = new TestAgentSession();
|
||||
|
||||
// Act
|
||||
var state1 = sessionState.GetOrInitializeState(session1);
|
||||
var state2 = sessionState.GetOrInitializeState(session2);
|
||||
|
||||
// Assert - each session gets its own state
|
||||
Assert.Equal(2, callCount);
|
||||
Assert.Equal("init-1", state1.Value);
|
||||
Assert.Equal("init-2", state2.Value);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
public sealed class TestState
|
||||
{
|
||||
public string Value { get; set; } = string.Empty;
|
||||
}
|
||||
|
||||
private sealed class TestAgentSession : AgentSession;
|
||||
}
|
||||
Reference in New Issue
Block a user