mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
.NET: fix: preserve AG-UI session history (#5904)
* fix: preserve AG-UI session history * refactor: use static AG-UI provider check
This commit is contained in:
committed by
GitHub
Unverified
parent
e89e745bc0
commit
d222079df9
@@ -38,6 +38,8 @@ namespace Microsoft.Agents.AI;
|
||||
/// </remarks>
|
||||
public sealed partial class ChatClientAgent : AIAgent
|
||||
{
|
||||
private const string AGUIProviderName = "ag-ui";
|
||||
|
||||
private readonly ChatClientAgentOptions? _agentOptions;
|
||||
private readonly HashSet<string> _aiContextProviderStateKeys;
|
||||
private readonly AIAgentMetadata _agentMetadata;
|
||||
@@ -815,7 +817,7 @@ public sealed partial class ChatClientAgent : AIAgent
|
||||
|
||||
if (!string.IsNullOrWhiteSpace(responseConversationId))
|
||||
{
|
||||
if (this._agentOptions?.ChatHistoryProvider is not null)
|
||||
if (!IsAGUIProviderName(this._agentMetadata.ProviderName) && this._agentOptions?.ChatHistoryProvider is not null)
|
||||
{
|
||||
// The agent has a ChatHistoryProvider configured, but the service returned a conversation id,
|
||||
// meaning the service manages chat history server-side. Both cannot be used simultaneously.
|
||||
@@ -929,6 +931,9 @@ public sealed partial class ChatClientAgent : AIAgent
|
||||
}
|
||||
}
|
||||
|
||||
private static bool IsAGUIProviderName(string? providerName) =>
|
||||
string.Equals(providerName, AGUIProviderName, StringComparison.Ordinal);
|
||||
|
||||
/// <summary>
|
||||
/// Ensures that <see cref="AIAgent.CurrentRunContext"/> contains the resolved session.
|
||||
/// </summary>
|
||||
@@ -976,12 +981,17 @@ public sealed partial class ChatClientAgent : AIAgent
|
||||
|
||||
private ChatHistoryProvider? ResolveChatHistoryProvider(ChatOptions? chatOptions)
|
||||
{
|
||||
ChatHistoryProvider? provider = chatOptions?.ConversationId is null ? this.ChatHistoryProvider : null;
|
||||
ChatHistoryProvider? provider =
|
||||
chatOptions?.ConversationId is null || IsAGUIProviderName(this._agentMetadata.ProviderName)
|
||||
? this.ChatHistoryProvider
|
||||
: null;
|
||||
|
||||
// If someone provided an override ChatHistoryProvider via AdditionalProperties, we should use that instead.
|
||||
if (chatOptions?.AdditionalProperties?.TryGetValue(out ChatHistoryProvider? overrideProvider) is true)
|
||||
{
|
||||
if (this._agentOptions?.ThrowOnChatHistoryProviderConflict is true && string.IsNullOrWhiteSpace(chatOptions?.ConversationId) is false)
|
||||
if (!IsAGUIProviderName(this._agentMetadata.ProviderName) &&
|
||||
this._agentOptions?.ThrowOnChatHistoryProviderConflict is true &&
|
||||
string.IsNullOrWhiteSpace(chatOptions?.ConversationId) is false)
|
||||
{
|
||||
throw new InvalidOperationException(
|
||||
$"Only {nameof(ChatClientAgentSession.ConversationId)} or {nameof(this.ChatHistoryProvider)} may be used, but not both. The current {nameof(ChatClientAgentSession)} has a {nameof(ChatClientAgentSession.ConversationId)} indicating server-side chat history management, but an override {nameof(this.ChatHistoryProvider)} was provided via {nameof(AgentRunOptions.AdditionalProperties)}.");
|
||||
|
||||
@@ -243,6 +243,46 @@ public sealed class AGUIAgentTests
|
||||
Assert.Contains(updates, u => u.Text == "Hello");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task RunStreamingAsync_WithSession_SendsFullHistoryAfterThreadIdIsSetAsync()
|
||||
{
|
||||
// Arrange
|
||||
var captureHandler = new StateCapturingTestDelegatingHandler();
|
||||
captureHandler.AddResponse(
|
||||
[
|
||||
new RunStartedEvent { ThreadId = "thread1", RunId = "run1" },
|
||||
new TextMessageStartEvent { MessageId = "msg1", Role = AGUIRoles.Assistant },
|
||||
new TextMessageContentEvent { MessageId = "msg1", Delta = "First response" },
|
||||
new TextMessageEndEvent { MessageId = "msg1" },
|
||||
new RunFinishedEvent { ThreadId = "thread1", RunId = "run1" }
|
||||
]);
|
||||
captureHandler.AddResponse(
|
||||
[
|
||||
new RunStartedEvent { ThreadId = "thread1", RunId = "run2" },
|
||||
new TextMessageStartEvent { MessageId = "msg2", Role = AGUIRoles.Assistant },
|
||||
new TextMessageContentEvent { MessageId = "msg2", Delta = "Second response" },
|
||||
new TextMessageEndEvent { MessageId = "msg2" },
|
||||
new RunFinishedEvent { ThreadId = "thread1", RunId = "run2" }
|
||||
]);
|
||||
using HttpClient httpClient = new(captureHandler);
|
||||
|
||||
var chatClient = new AGUIChatClient(httpClient, "http://localhost/agent", null, AGUIJsonSerializerContext.Default.Options);
|
||||
AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "agent1", description: "Test agent", tools: []);
|
||||
AgentSession session = await agent.CreateSessionAsync();
|
||||
|
||||
// Act
|
||||
await foreach (var _ in agent.RunStreamingAsync([new ChatMessage(ChatRole.User, "First")], session))
|
||||
{
|
||||
}
|
||||
|
||||
await foreach (var _ in agent.RunStreamingAsync([new ChatMessage(ChatRole.User, "Second")], session))
|
||||
{
|
||||
}
|
||||
|
||||
// Assert
|
||||
Assert.Equal([1, 3], captureHandler.CapturedMessageCounts);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task DeserializeSession_WithValidState_ReturnsChatClientAgentSessionAsync()
|
||||
{
|
||||
@@ -1686,10 +1726,12 @@ internal sealed class CapturingTestDelegatingHandler : DelegatingHandler
|
||||
internal sealed class StateCapturingTestDelegatingHandler : DelegatingHandler
|
||||
{
|
||||
private readonly Queue<Func<HttpRequestMessage, Task<HttpResponseMessage>>> _responseFactories = new();
|
||||
private readonly List<int> _capturedMessageCounts = [];
|
||||
|
||||
public bool RequestWasMade { get; private set; }
|
||||
public JsonElement? CapturedState { get; private set; }
|
||||
public int CapturedMessageCount { get; private set; }
|
||||
public IReadOnlyList<int> CapturedMessageCounts => this._capturedMessageCounts;
|
||||
|
||||
public void AddResponse(BaseEvent[] events)
|
||||
{
|
||||
@@ -1714,6 +1756,7 @@ internal sealed class StateCapturingTestDelegatingHandler : DelegatingHandler
|
||||
this.CapturedState = input.State;
|
||||
}
|
||||
this.CapturedMessageCount = input.Messages.Count();
|
||||
this._capturedMessageCounts.Add(this.CapturedMessageCount);
|
||||
}
|
||||
|
||||
if (this._responseFactories.Count == 0)
|
||||
|
||||
Reference in New Issue
Block a user