Files
agent-framework/dotnet/src/Microsoft.Agents.AI.DurableTask/AgentEntity.cs
T
westey 6fdf6111e6 .NET: [BREAKING] Rename GetNewSession to CreateSession (#3501)
* Rename GetNewSession to CreateSession

* Address copilot feedback

* Suppress warning

* Suppress warning

* Fix further warnings.
2026-02-03 11:33:50 +00:00

234 lines
10 KiB
C#

// Copyright (c) Microsoft. All rights reserved.
using Microsoft.Agents.AI.DurableTask.State;
using Microsoft.DurableTask.Client;
using Microsoft.DurableTask.Entities;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
namespace Microsoft.Agents.AI.DurableTask;
/// <summary>
/// Task entity that runs an agent for the session identified by the entity ID, records each
/// request/response exchange in <see cref="DurableAgentState"/>, and maintains an optional
/// time-to-live (TTL) expiration for that state.
/// </summary>
/// <param name="services">Service provider used to resolve the client, logging, options and agent factories.</param>
/// <param name="cancellationToken">
/// Optional token used for agent runs. When left as <c>default</c>, the host's
/// <see cref="IHostApplicationLifetime.ApplicationStopping"/> token is used if a lifetime service is
/// registered; otherwise <see cref="CancellationToken.None"/>.
/// </param>
internal class AgentEntity(IServiceProvider services, CancellationToken cancellationToken = default) : TaskEntity<DurableAgentState>
{
    private readonly IServiceProvider _services = services;
    private readonly DurableTaskClient _client = services.GetRequiredService<DurableTaskClient>();
    private readonly ILoggerFactory _loggerFactory = services.GetRequiredService<ILoggerFactory>();

    // Optional: when no handler is registered, responses are collected in one go instead of streamed.
    private readonly IAgentResponseHandler? _messageHandler = services.GetService<IAgentResponseHandler>();
    private readonly DurableAgentsOptions _options = services.GetRequiredService<DurableAgentsOptions>();

    // Resolved token: explicit token > ApplicationStopping (if a lifetime is registered) > None.
    // All asynchronous work in this class must use this field rather than the raw constructor
    // parameter, so that the fallback is honored everywhere.
    private readonly CancellationToken _cancellationToken = cancellationToken != default
        ? cancellationToken
        : services.GetService<IHostApplicationLifetime>()?.ApplicationStopping ?? CancellationToken.None;

    /// <summary>
    /// Runs the agent for the given request. Forwards directly to <see cref="Run(RunRequest)"/>.
    /// </summary>
    /// <param name="request">The request containing the messages to send to the agent.</param>
    /// <returns>The agent's response.</returns>
    public Task<AgentResponse> RunAgentAsync(RunRequest request)
    {
        return this.Run(request);
    }

    /// <summary>
    /// Runs the agent against the accumulated conversation history plus the new request, stores the
    /// response in the entity state, and updates the TTL expiration time.
    /// </summary>
    /// <param name="request">The request containing the messages to send to the agent.</param>
    /// <returns>
    /// The agent's response, or an empty <see cref="AgentResponse"/> when the request has no messages.
    /// </returns>
    /// <exception cref="InvalidOperationException">No agent is registered under the session's agent name.</exception>
    // IDE1006 and VSTHRD200 disabled to allow method name to match the common cross-platform entity operation name.
#pragma warning disable IDE1006
#pragma warning disable VSTHRD200
    public async Task<AgentResponse> Run(RunRequest request)
#pragma warning restore VSTHRD200
#pragma warning restore IDE1006
    {
        AgentSessionId sessionId = this.Context.Id;
        AIAgent agent = this.GetAgent(sessionId);
        EntityAgentWrapper agentWrapper = new(agent, this.Context, request, this._services);

        // Logger category is Microsoft.DurableTask.Agents.{agentName}.{sessionKey}
        ILogger logger = this.GetLogger(agent.Name!, sessionId.Key);

        if (request.Messages.Count == 0)
        {
            // Nothing to send: leave the entity state untouched.
            logger.LogInformation("Ignoring empty request");
            return new AgentResponse();
        }

        this.State.Data.ConversationHistory.Add(DurableAgentStateRequest.FromRunRequest(request));

        foreach (ChatMessage msg in request.Messages)
        {
            logger.LogAgentRequest(sessionId, msg.Role, msg.Text);
        }

        // Set the current agent context for the duration of the agent run. This will be exposed
        // to any tools that are invoked by the agent.
        DurableAgentContext agentContext = new(
            entityContext: this.Context,
            client: this._client,
            lifetime: this._services.GetRequiredService<IHostApplicationLifetime>(),
            services: this._services);
        DurableAgentContext.SetCurrent(agentContext);

        try
        {
            // Start the agent response stream. The session is created with the same resolved
            // cancellation token as the run itself (not the raw constructor parameter, which may
            // be 'default' even when an application-stopping token is available).
            IAsyncEnumerable<AgentResponseUpdate> responseStream = agentWrapper.RunStreamingAsync(
                this.State.Data.ConversationHistory.SelectMany(e => e.Messages).Select(m => m.ToChatMessage()),
                await agentWrapper.CreateSessionAsync(this._cancellationToken).ConfigureAwait(false),
                options: null,
                this._cancellationToken);

            AgentResponse response;
            if (this._messageHandler is null)
            {
                // If no message handler is provided, we can just get the full response at once.
                // This is expected to be the common case for non-interactive agents.
                response = await responseStream.ToAgentResponseAsync(this._cancellationToken).ConfigureAwait(false);
            }
            else
            {
                List<AgentResponseUpdate> responseUpdates = [];

                // To support interactive chat agents, we need to stream the responses to an IAgentResponseHandler.
                // The user-provided message handler can be implemented to send the responses to the user.
                async IAsyncEnumerable<AgentResponseUpdate> StreamResultsAsync()
                {
                    await foreach (AgentResponseUpdate update in responseStream.ConfigureAwait(false))
                    {
                        // We need the full response further down, so we piece it together as we go.
                        responseUpdates.Add(update);

                        // Yield the update to the message handler.
                        yield return update;
                    }
                }

                await this._messageHandler.OnStreamingResponseUpdateAsync(StreamResultsAsync(), this._cancellationToken).ConfigureAwait(false);

                // Only the updates the handler actually enumerated are part of the final response.
                response = responseUpdates.ToAgentResponse();
            }

            // Persist the agent response to the entity state for client polling
            this.State.Data.ConversationHistory.Add(
                DurableAgentStateResponse.FromResponse(request.CorrelationId, response));

            string responseText = response.Text;
            if (!string.IsNullOrEmpty(responseText))
            {
                logger.LogAgentResponse(
                    sessionId,
                    response.Messages.FirstOrDefault()?.Role ?? ChatRole.Assistant,
                    responseText,
                    response.Usage?.InputTokenCount,
                    response.Usage?.OutputTokenCount,
                    response.Usage?.TotalTokenCount);
            }

            // Update TTL expiration time. Only schedule deletion check on first interaction.
            // Subsequent interactions just update the expiration time; CheckAndDeleteIfExpired
            // will reschedule the deletion check when it runs.
            TimeSpan? timeToLive = this._options.GetTimeToLive(sessionId.Name);
            if (timeToLive.HasValue)
            {
                DateTime newExpirationTime = DateTime.UtcNow.Add(timeToLive.Value);

                // Must be captured before the expiration time is overwritten below.
                bool isFirstInteraction = this.State.Data.ExpirationTimeUtc is null;

                this.State.Data.ExpirationTimeUtc = newExpirationTime;
                logger.LogTTLExpirationTimeUpdated(sessionId, newExpirationTime);

                // A null expiration time means no deletion check has been scheduled for the
                // current state yet, so schedule one now. Otherwise the already-scheduled
                // CheckAndDeleteIfExpired reschedules itself if the entity hasn't expired.
                if (isFirstInteraction)
                {
                    this.ScheduleDeletionCheck(sessionId, logger, timeToLive.Value);
                }
            }
            else
            {
                // TTL is disabled. Clear the expiration time if it was previously set.
                if (this.State.Data.ExpirationTimeUtc.HasValue)
                {
                    logger.LogTTLExpirationTimeCleared(sessionId);
                    this.State.Data.ExpirationTimeUtc = null;
                }
            }

            return response;
        }
        finally
        {
            // Clear the current agent context
            DurableAgentContext.ClearCurrent();
        }
    }

    /// <summary>
    /// Checks if the entity has expired and deletes it if so, otherwise reschedules the deletion check.
    /// </summary>
    /// <remarks>
    /// This method is called by the durable task runtime when a <c>CheckAndDeleteIfExpired</c> signal is received.
    /// If no expiration time is stored, or TTL is no longer configured for the agent, nothing is rescheduled.
    /// </remarks>
    /// <exception cref="InvalidOperationException">No agent is registered under the session's agent name.</exception>
    public void CheckAndDeleteIfExpired()
    {
        AgentSessionId sessionId = this.Context.Id;

        // The agent is resolved only to obtain its name for the logger category.
        AIAgent agent = this.GetAgent(sessionId);
        ILogger logger = this.GetLogger(agent.Name!, sessionId.Key);

        DateTime currentTime = DateTime.UtcNow;
        DateTime? expirationTime = this.State.Data.ExpirationTimeUtc;
        logger.LogTTLDeletionCheck(sessionId, expirationTime, currentTime);

        if (expirationTime.HasValue)
        {
            if (currentTime >= expirationTime.Value)
            {
                // Entity has expired, delete it
                logger.LogTTLEntityExpired(sessionId, expirationTime.Value);
                this.State = null!;
            }
            else
            {
                // Entity hasn't expired yet, reschedule the deletion check
                TimeSpan? timeToLive = this._options.GetTimeToLive(sessionId.Name);
                if (timeToLive.HasValue)
                {
                    this.ScheduleDeletionCheck(sessionId, logger, timeToLive.Value);
                }
            }
        }
    }

    /// <summary>
    /// Schedules a delayed self-signal that invokes <see cref="CheckAndDeleteIfExpired"/>.
    /// </summary>
    /// <param name="sessionId">The session ID, used for logging.</param>
    /// <param name="logger">Logger that records the scheduled time.</param>
    /// <param name="timeToLive">
    /// Fallback TTL, used to compute the target time only when no expiration time is stored in the state.
    /// </param>
    private void ScheduleDeletionCheck(AgentSessionId sessionId, ILogger logger, TimeSpan timeToLive)
    {
        DateTime currentTime = DateTime.UtcNow;
        DateTime expirationTime = this.State.Data.ExpirationTimeUtc ?? currentTime.Add(timeToLive);
        TimeSpan minimumDelay = this._options.MinimumTimeToLiveSignalDelay;

        // To avoid excessive scheduling, we schedule the deletion check for no less than the minimum delay.
        DateTime scheduledTime = expirationTime > currentTime.Add(minimumDelay)
            ? expirationTime
            : currentTime.Add(minimumDelay);

        logger.LogTTLDeletionScheduled(sessionId, scheduledTime);

        // Schedule a signal to self to check for expiration
        this.Context.SignalEntity(
            this.Context.Id,
            nameof(CheckAndDeleteIfExpired), // self-signal
            options: new SignalEntityOptions { SignalTime = scheduledTime });
    }

    /// <summary>
    /// Resolves the agent registered under the session's agent name by invoking its factory.
    /// </summary>
    /// <param name="sessionId">The session whose <c>Name</c> selects the agent factory.</param>
    /// <returns>The agent produced by the registered factory.</returns>
    /// <exception cref="InvalidOperationException">No factory is registered for the agent name.</exception>
    private AIAgent GetAgent(AgentSessionId sessionId)
    {
        IReadOnlyDictionary<string, Func<IServiceProvider, AIAgent>> agents =
            this._services.GetRequiredService<IReadOnlyDictionary<string, Func<IServiceProvider, AIAgent>>>();

        if (!agents.TryGetValue(sessionId.Name, out Func<IServiceProvider, AIAgent>? agentFactory))
        {
            throw new InvalidOperationException($"Agent '{sessionId.Name}' not found");
        }

        return agentFactory(this._services);
    }

    /// <summary>
    /// Creates a logger whose category is <c>Microsoft.DurableTask.Agents.{agentName}.{sessionKey}</c>.
    /// </summary>
    /// <param name="agentName">The agent name segment of the category.</param>
    /// <param name="sessionKey">The session key segment of the category.</param>
    /// <returns>The logger for the given agent and session.</returns>
    private ILogger GetLogger(string agentName, string sessionKey)
    {
        return this._loggerFactory.CreateLogger($"Microsoft.DurableTask.Agents.{agentName}.{sessionKey}");
    }
}