Files
westey 08275f657b .NET: Improve session cast error message quality and consistency (#3973)
* Improve session cast error message consistency

* Update changelog
2026-02-17 15:51:30 +00:00

451 lines
20 KiB
C#

// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using FluentAssertions;
using Microsoft.Agents.AI.AGUI;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting.Server;
using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
namespace Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests;
/// <summary>
/// Integration tests for AG-UI shared state. Client state is sent as <c>application/json</c>
/// <see cref="DataContent"/> in a system message, and the agent's state snapshot is expected back
/// as <c>application/json</c> <see cref="DataContent"/>.
/// </summary>
/// <remarks>
/// Each test hosts a <see cref="FakeStateAgent"/> at <c>/agent</c> on an ASP.NET Core <see cref="TestServer"/>
/// and talks to it through an <see cref="AGUIChatClient"/>.
/// </remarks>
public sealed class SharedStateTests : IAsyncDisposable
{
    private const string JsonMediaType = "application/json";

    private WebApplication? _app;
    private HttpClient? _client;

    [Fact]
    public async Task StateSnapshot_IsReturnedAsDataContent_WithCorrectMediaTypeAsync()
    {
        // Arrange
        var initialState = new { counter = 42, status = "active" };
        (AIAgent agent, ChatClientAgentSession session) = await this.CreateAgentAsync();
        ChatMessage stateMessage = CreateStateMessage(initialState);
        ChatMessage userMessage = new(ChatRole.User, "update state");

        // Act
        List<AgentResponseUpdate> updates = await CollectUpdatesAsync(agent, session, userMessage, stateMessage);

        // Assert
        updates.Should().NotBeEmpty();

        // Should receive state snapshot as DataContent with application/json media type
        AgentResponseUpdate? stateUpdate = FindStateUpdate(updates);
        stateUpdate.Should().NotBeNull("should receive state snapshot update");
        DataContent? dataContent = FindStateContent(stateUpdate!.Contents);
        dataContent.Should().NotBeNull();

        // Verify the state content
        JsonElement receivedState = ParseState(dataContent!);
        receivedState.GetProperty("counter").GetInt32().Should().Be(43, "state should be incremented");
        receivedState.GetProperty("status").GetString().Should().Be("active");
    }

    [Fact]
    public async Task StateSnapshot_HasCorrectAdditionalPropertiesAsync()
    {
        // Arrange
        var initialState = new { step = 1 };
        (AIAgent agent, ChatClientAgentSession session) = await this.CreateAgentAsync();
        ChatMessage stateMessage = CreateStateMessage(initialState);
        ChatMessage userMessage = new(ChatRole.User, "process");

        // Act
        List<AgentResponseUpdate> updates = await CollectUpdatesAsync(agent, session, userMessage, stateMessage);

        // Assert
        AgentResponseUpdate? stateUpdate = FindStateUpdate(updates);
        stateUpdate.Should().NotBeNull();
        ChatResponseUpdate chatUpdate = stateUpdate!.AsChatResponseUpdate();
        chatUpdate.AdditionalProperties.Should().NotBeNull();
        chatUpdate.AdditionalProperties.Should().ContainKey("is_state_snapshot");
        ((bool)chatUpdate.AdditionalProperties!["is_state_snapshot"]!).Should().BeTrue();
    }

    [Fact]
    public async Task ComplexState_WithNestedObjectsAndArrays_RoundTripsCorrectlyAsync()
    {
        // Arrange
        var complexState = new
        {
            sessionId = "test-123",
            nested = new { value = "test", count = 10 },
            array = new[] { 1, 2, 3 },
            tags = new[] { "tag1", "tag2" }
        };
        (AIAgent agent, ChatClientAgentSession session) = await this.CreateAgentAsync();
        ChatMessage stateMessage = CreateStateMessage(complexState);
        ChatMessage userMessage = new(ChatRole.User, "process complex state");

        // Act
        List<AgentResponseUpdate> updates = await CollectUpdatesAsync(agent, session, userMessage, stateMessage);

        // Assert
        AgentResponseUpdate? stateUpdate = FindStateUpdate(updates);
        stateUpdate.Should().NotBeNull();
        DataContent? dataContent = FindStateContent(stateUpdate!.Contents);
        dataContent.Should().NotBeNull();

        JsonElement receivedState = ParseState(dataContent!);
        receivedState.GetProperty("sessionId").GetString().Should().Be("test-123");
        receivedState.GetProperty("nested").GetProperty("value").GetString().Should().Be("test");
        receivedState.GetProperty("nested").GetProperty("count").GetInt32().Should().Be(10);
        receivedState.GetProperty("array").GetArrayLength().Should().Be(3);
        receivedState.GetProperty("tags").GetArrayLength().Should().Be(2);
    }

    [Fact]
    public async Task StateSnapshot_CanBeUsedInSubsequentRequest_ForStateRoundTripAsync()
    {
        // Arrange
        var initialState = new { counter = 1, sessionId = "round-trip-test" };
        (AIAgent agent, ChatClientAgentSession session) = await this.CreateAgentAsync();
        ChatMessage stateMessage = CreateStateMessage(initialState);
        ChatMessage userMessage = new(ChatRole.User, "increment");

        // Act - First round
        List<AgentResponseUpdate> firstRoundUpdates = await CollectUpdatesAsync(agent, session, userMessage, stateMessage);

        // Extract state snapshot from first round
        AgentResponseUpdate? firstStateUpdate = FindStateUpdate(firstRoundUpdates);
        firstStateUpdate.Should().NotBeNull();
        DataContent? firstStateContent = FindStateContent(firstStateUpdate!.Contents);
        firstStateContent.Should().NotBeNull();

        // Second round - send the returned snapshot back as the new state
        ChatMessage secondStateMessage = new(ChatRole.System, [firstStateContent!]);
        ChatMessage secondUserMessage = new(ChatRole.User, "increment again");
        List<AgentResponseUpdate> secondRoundUpdates = await CollectUpdatesAsync(agent, session, secondUserMessage, secondStateMessage);

        // Assert - Second round should have incremented counter again
        AgentResponseUpdate? secondStateUpdate = FindStateUpdate(secondRoundUpdates);
        secondStateUpdate.Should().NotBeNull();
        DataContent? secondStateContent = FindStateContent(secondStateUpdate!.Contents);
        secondStateContent.Should().NotBeNull();

        JsonElement secondState = ParseState(secondStateContent!);
        secondState.GetProperty("counter").GetInt32().Should().Be(3, "counter should be incremented twice: 1 -> 2 -> 3");
    }

    [Fact]
    public async Task WithoutState_AgentBehavesNormally_NoStateSnapshotReturnedAsync()
    {
        // Arrange
        (AIAgent agent, ChatClientAgentSession session) = await this.CreateAgentAsync();
        ChatMessage userMessage = new(ChatRole.User, "hello");

        // Act
        List<AgentResponseUpdate> updates = await CollectUpdatesAsync(agent, session, userMessage);

        // Assert
        updates.Should().NotBeEmpty();

        // Should NOT have state snapshot when no state is sent
        bool hasStateSnapshot = updates.Any(u => HasStateContent(u.Contents));
        hasStateSnapshot.Should().BeFalse("should not return state snapshot when no state is provided");

        // Should have normal text response
        updates.Should().Contain(u => u.Contents.Any(c => c is TextContent));
    }

    [Fact]
    public async Task EmptyState_DoesNotTriggerStateHandlingAsync()
    {
        // Arrange
        var emptyState = new { };
        (AIAgent agent, ChatClientAgentSession session) = await this.CreateAgentAsync();
        ChatMessage stateMessage = CreateStateMessage(emptyState);
        ChatMessage userMessage = new(ChatRole.User, "hello");

        // Act
        List<AgentResponseUpdate> updates = await CollectUpdatesAsync(agent, session, userMessage, stateMessage);

        // Assert
        updates.Should().NotBeEmpty();

        // Empty state {} should not trigger state snapshot mechanism
        bool hasEmptyStateSnapshot = updates.Any(u => HasStateContent(u.Contents));
        hasEmptyStateSnapshot.Should().BeFalse("empty state should be treated as no state");

        // Should have normal response
        updates.Should().Contain(u => u.Contents.Any(c => c is TextContent));
    }

    [Fact]
    public async Task NonStreamingRunAsync_WithState_ReturnsStateInResponseAsync()
    {
        // Arrange
        var initialState = new { counter = 5 };
        (AIAgent agent, ChatClientAgentSession session) = await this.CreateAgentAsync();
        ChatMessage stateMessage = CreateStateMessage(initialState);
        ChatMessage userMessage = new(ChatRole.User, "process");

        // Act
        AgentResponse response = await agent.RunAsync([userMessage, stateMessage], session, new AgentRunOptions(), CancellationToken.None);

        // Assert
        response.Should().NotBeNull();
        response.Messages.Should().NotBeEmpty();

        // Should have message with DataContent containing state
        bool hasStateMessage = response.Messages.Any(m => HasStateContent(m.Contents));
        hasStateMessage.Should().BeTrue("response should contain state message");

        ChatMessage? stateResponseMessage = response.Messages.FirstOrDefault(m => HasStateContent(m.Contents));
        stateResponseMessage.Should().NotBeNull();
        DataContent? dataContent = FindStateContent(stateResponseMessage!.Contents);
        dataContent.Should().NotBeNull();

        JsonElement receivedState = ParseState(dataContent!);
        receivedState.GetProperty("counter").GetInt32().Should().Be(6);
    }

    /// <summary>
    /// Hosts a fresh <see cref="FakeStateAgent"/> and returns a client-side AG-UI agent bound to it,
    /// together with a newly created session.
    /// </summary>
    private async Task<(AIAgent Agent, ChatClientAgentSession Session)> CreateAgentAsync()
    {
        await this.SetupTestServerAsync(new FakeStateAgent());

        var chatClient = new AGUIChatClient(this._client!, "", null);
        AIAgent agent = chatClient.AsAIAgent(instructions: null, name: "assistant", description: "Sample assistant", tools: []);

        // The explicit cast throws if the agent does not hand out ChatClientAgentSession instances.
        ChatClientAgentSession session = (ChatClientAgentSession)await agent.CreateSessionAsync();
        return (agent, session);
    }

    /// <summary>
    /// Serializes <paramref name="state"/> to UTF-8 JSON and wraps it in a system message as
    /// <c>application/json</c> <see cref="DataContent"/>, which is how these tests send client state.
    /// </summary>
    private static ChatMessage CreateStateMessage<TState>(TState state)
    {
        byte[] stateBytes = JsonSerializer.SerializeToUtf8Bytes(state);
        return new ChatMessage(ChatRole.System, [new DataContent(stateBytes, JsonMediaType)]);
    }

    /// <summary>
    /// Runs <paramref name="agent"/> in streaming mode with <paramref name="messages"/> and buffers every update.
    /// </summary>
    private static async Task<List<AgentResponseUpdate>> CollectUpdatesAsync(AIAgent agent, ChatClientAgentSession session, params ChatMessage[] messages)
    {
        List<AgentResponseUpdate> updates = [];
        await foreach (AgentResponseUpdate update in agent.RunStreamingAsync(messages, session, new AgentRunOptions(), CancellationToken.None))
        {
            updates.Add(update);
        }

        return updates;
    }

    /// <summary>Returns whether any item in <paramref name="contents"/> is <c>application/json</c> <see cref="DataContent"/>.</summary>
    private static bool HasStateContent(IEnumerable<object> contents) =>
        contents.Any(c => c is DataContent dc && dc.MediaType == JsonMediaType);

    /// <summary>Returns the first <c>application/json</c> <see cref="DataContent"/> in <paramref name="contents"/>, or <see langword="null"/>.</summary>
    private static DataContent? FindStateContent(IEnumerable<object> contents) =>
        contents.OfType<DataContent>().FirstOrDefault(dc => dc.MediaType == JsonMediaType);

    /// <summary>Returns the first update carrying <c>application/json</c> <see cref="DataContent"/>, or <see langword="null"/>.</summary>
    private static AgentResponseUpdate? FindStateUpdate(IEnumerable<AgentResponseUpdate> updates) =>
        updates.FirstOrDefault(u => HasStateContent(u.Contents));

    /// <summary>Decodes the UTF-8 payload of <paramref name="content"/> and parses it as JSON.</summary>
    private static JsonElement ParseState(DataContent content) =>
        JsonElement.Parse(System.Text.Encoding.UTF8.GetString(content.Data.ToArray()));

    /// <summary>
    /// Starts an in-process web application that maps <paramref name="fakeAgent"/> at <c>/agent</c> and
    /// points <see cref="_client"/> at it.
    /// </summary>
    private async Task SetupTestServerAsync(FakeStateAgent fakeAgent)
    {
        WebApplicationBuilder builder = WebApplication.CreateBuilder();
        builder.Services.AddAGUI();
        builder.WebHost.UseTestServer();

        this._app = builder.Build();
        this._app.MapAGUI("/agent", fakeAgent);
        await this._app.StartAsync();

        TestServer testServer = this._app.Services.GetRequiredService<IServer>() as TestServer
            ?? throw new InvalidOperationException("TestServer not found");

        this._client = testServer.CreateClient();
        this._client.BaseAddress = new Uri("http://localhost/agent");
    }

    /// <summary>Disposes the HTTP client and, if one was started, the web application.</summary>
    public async ValueTask DisposeAsync()
    {
        this._client?.Dispose();

        if (this._app != null)
        {
            await this._app.DisposeAsync();
        }
    }
}
/// <summary>
/// Test agent that echoes AG-UI shared state back to the caller.
/// </summary>
/// <remarks>
/// When the run options are <see cref="ChatClientAgentRunOptions"/> whose <c>ChatOptions.AdditionalProperties</c>
/// contain a non-empty JSON object under the <c>ag_ui_state</c> key, the agent first yields an update whose only
/// content is an <c>application/json</c> <see cref="DataContent"/>. That content holds a copy of the object in which
/// an integral <c>counter</c> property is incremented by one; every other property is copied unchanged.
/// Every run then yields a fixed "State processed" text update.
/// </remarks>
[SuppressMessage("Performance", "CA1812:Avoid uninstantiated internal classes", Justification = "Instantiated in tests")]
internal sealed class FakeStateAgent : AIAgent
{
    private const string StateKey = "ag_ui_state";
    private const string CounterPropertyName = "counter";

    public override string? Description => "Agent for state testing";

    /// <summary>Runs the streaming implementation and aggregates its updates into a single response.</summary>
    protected override Task<AgentResponse> RunCoreAsync(IEnumerable<ChatMessage> messages, AgentSession? session = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
    {
        return this.RunCoreStreamingAsync(messages, session, options, cancellationToken).ToAgentResponseAsync(cancellationToken);
    }

    /// <summary>
    /// Yields a state snapshot update (only when non-empty object state is present), followed by a text update.
    /// </summary>
    protected override async IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingAsync(
        IEnumerable<ChatMessage> messages,
        AgentSession? session = null,
        AgentRunOptions? options = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // State arrives in ChatOptions.AdditionalProperties (set by AG-UI hosting layer).
        // An empty object {} is treated the same as no state at all.
        if (options is ChatClientAgentRunOptions { ChatOptions.AdditionalProperties: { } properties } &&
            properties.TryGetValue(StateKey, out object? stateObj) &&
            stateObj is JsonElement { ValueKind: JsonValueKind.Object } state &&
            state.EnumerateObject().Any())
        {
            yield return CreateStateSnapshotUpdate(state);
        }

        // Always return a text response
        yield return new AgentResponseUpdate
        {
            MessageId = Guid.NewGuid().ToString("N"),
            Role = ChatRole.Assistant,
            Contents = [new TextContent("State processed")]
        };

        // Keeps the async iterator free of a "lacks await" warning.
        await Task.CompletedTask;
    }

    /// <summary>
    /// Builds an assistant update carrying a copy of <paramref name="state"/> as <c>application/json</c> data.
    /// An integral <c>counter</c> property is incremented by one; every other property (numbers, strings,
    /// booleans, <c>null</c>, objects, arrays) is copied verbatim so the state round-trips without loss.
    /// </summary>
    private static AgentResponseUpdate CreateStateSnapshotUpdate(JsonElement state)
    {
        Dictionary<string, object?> modifiedState = [];
        foreach (JsonProperty prop in state.EnumerateObject())
        {
            // TryGetInt64 throws for non-number kinds, so the ValueKind check must come first.
            if (prop.Name == CounterPropertyName &&
                prop.Value.ValueKind == JsonValueKind.Number &&
                prop.Value.TryGetInt64(out long counter))
            {
                modifiedState[prop.Name] = counter + 1;
            }
            else
            {
                modifiedState[prop.Name] = prop.Value;
            }
        }

        byte[] modifiedStateBytes = JsonSerializer.SerializeToUtf8Bytes(modifiedState);
        return new AgentResponseUpdate
        {
            MessageId = Guid.NewGuid().ToString("N"),
            Role = ChatRole.Assistant,
            Contents = [new DataContent(modifiedStateBytes, "application/json")]
        };
    }

    protected override ValueTask<AgentSession> CreateSessionCoreAsync(CancellationToken cancellationToken = default) =>
        new(new FakeAgentSession());

    /// <summary>Restores a <see cref="FakeAgentSession"/> from JSON.</summary>
    /// <exception cref="InvalidOperationException">The JSON deserializes to <see langword="null"/>.</exception>
    protected override ValueTask<AgentSession> DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
    {
        FakeAgentSession session = serializedState.Deserialize<FakeAgentSession>(jsonSerializerOptions)
            ?? throw new InvalidOperationException($"The serialized state could not be deserialized into a session of type '{nameof(FakeAgentSession)}'.");
        return new(session);
    }

    /// <summary>Serializes a session previously created by this agent.</summary>
    /// <exception cref="InvalidOperationException"><paramref name="session"/> is not a <see cref="FakeAgentSession"/>.</exception>
    protected override ValueTask<JsonElement> SerializeSessionCoreAsync(AgentSession session, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
    {
        if (session is not FakeAgentSession fakeSession)
        {
            throw new InvalidOperationException($"The provided session type '{session.GetType().Name}' is not compatible with this agent. Only sessions of type '{nameof(FakeAgentSession)}' can be serialized by this agent.");
        }

        return new(JsonSerializer.SerializeToElement(fakeSession, jsonSerializerOptions));
    }

    /// <summary>Minimal session type used by <see cref="FakeStateAgent"/>.</summary>
    private sealed class FakeAgentSession : AgentSession
    {
        public FakeAgentSession()
        {
        }

        [JsonConstructor]
        public FakeAgentSession(AgentSessionStateBag stateBag) : base(stateBag)
        {
        }
    }

    /// <summary>Always returns <see langword="null"/>; this test agent exposes no services.</summary>
    public override object? GetService(Type serviceType, object? serviceKey = null) => null;
}