// Copyright (c) Microsoft. All rights reserved.

using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using Microsoft.Agents.AI;
using Microsoft.Extensions.AI;

namespace RecipeClient;

/// <summary>
/// A delegating agent that manages client-side state and automatically attaches it to requests.
/// </summary>
/// <remarks>
/// On every run the current <see cref="State"/> is serialized to JSON and appended to the outgoing
/// messages as a system message. While the response streams back, every JSON data content item that
/// deserializes to a non-null <typeparamref name="TState"/> replaces <see cref="State"/>.
/// </remarks>
/// <typeparam name="TState">The state type.</typeparam>
internal sealed class StatefulAgent<TState> : DelegatingAIAgent
    where TState : class, new()
{
    /// <summary>Media type used both for the outgoing state payload and to recognize incoming snapshots.</summary>
    private const string JsonMediaType = "application/json";

    private readonly JsonSerializerOptions _jsonSerializerOptions;

    /// <summary>
    /// Gets or sets the current state.
    /// </summary>
    public TState State { get; set; }

    /// <summary>
    /// Initializes a new instance of the <see cref="StatefulAgent{TState}"/> class.
    /// </summary>
    /// <param name="innerAgent">The underlying agent to delegate to.</param>
    /// <param name="jsonSerializerOptions">The JSON serializer options for state serialization.</param>
    /// <param name="initialState">The initial state. If null, a new instance will be created.</param>
    /// <exception cref="ArgumentNullException"><paramref name="jsonSerializerOptions"/> is null.</exception>
    public StatefulAgent(AIAgent innerAgent, JsonSerializerOptions jsonSerializerOptions, TState? initialState = null)
        : base(innerAgent)
    {
        // Fail at construction time instead of with a NullReferenceException on the first run.
        this._jsonSerializerOptions = jsonSerializerOptions ?? throw new ArgumentNullException(nameof(jsonSerializerOptions));
        this.State = initialState ?? new TState();
    }

    /// <inheritdoc />
    protected override Task<AgentResponse> RunCoreAsync(
        IEnumerable<ChatMessage> messages,
        AgentThread? thread = null,
        AgentRunOptions? options = null,
        CancellationToken cancellationToken = default)
    {
        // The non-streaming path is the streaming path collapsed into a single response,
        // so state attachment and state updates behave identically in both.
        return this.RunCoreStreamingAsync(messages, thread, options, cancellationToken)
            .ToAgentResponseAsync(cancellationToken);
    }

    /// <inheritdoc />
    protected override async IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingAsync(
        IEnumerable<ChatMessage> messages,
        AgentThread? thread = null,
        AgentRunOptions? options = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // Resolve the serialization metadata once; it is used for the outgoing state
        // and for every incoming snapshot.
        JsonTypeInfo stateTypeInfo = this._jsonSerializerOptions.GetTypeInfo(typeof(TState));

        // Copy the caller's messages so the state message is not added to their collection.
        List<ChatMessage> messagesWithState = [.. messages];

        // Serialize the current state directly as JSON and attach it as a trailing system message.
        byte[] stateBytes = JsonSerializer.SerializeToUtf8Bytes(this.State, stateTypeInfo);
        DataContent stateContent = new(stateBytes, JsonMediaType);
        ChatMessage stateMessage = new(ChatRole.System, [stateContent]);
        messagesWithState.Add(stateMessage);

        // Stream the response and update state when a snapshot is received.
        await foreach (AgentResponseUpdate update in this.InnerAgent.RunStreamingAsync(messagesWithState, thread, options, cancellationToken))
        {
            foreach (AIContent content in update.Contents)
            {
                // Any JSON data content is treated as a state snapshot. Media type names are
                // case-insensitive, so compare accordingly.
                if (content is DataContent dataContent
                    && string.Equals(dataContent.MediaType, JsonMediaType, StringComparison.OrdinalIgnoreCase)
                    && JsonSerializer.Deserialize(dataContent.Data.Span, stateTypeInfo) is TState newState)
                {
                    // A JSON "null" payload deserializes to null and leaves the current state untouched.
                    this.State = newState;
                }
            }

            // The update is always forwarded to the caller, including updates that carried a snapshot.
            yield return update;
        }
    }
}