.NET: Add A2A input-request content for human-in-the-loop scenarios (#5743)

* .NET: Add A2A input-request content for human-in-the-loop scenarios

Adds first-class support for handling user input requests from A2A agents
when they return an `input-required` task state.

- Add `A2AInputRequestContent` (wraps the requested `AIContent`) and
  `A2AInputResponseContent` (wraps the user's `AIContent` reply), with
  `CreateResponse` helper overloads on the request type.
- Surface input requests on `AgentResponse` / `AgentResponseUpdate` via
  `AgentTask` and `TaskStatusUpdateEvent` mappings.
- Link follow-up messages containing `A2AInputResponseContent` to the
  existing task via `TaskId` instead of `ReferenceTaskIds`.
- Add `A2AAgent_HumanInTheLoop` sample and register it in the solution
  and parent README.
- Add unit tests for the new types, extensions, and `A2AAgent` paths.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Remove unnecessary using directive flagged by CI format check

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* address feedback

* Guard against null TaskId when sending A2AInputResponseContent

Throw InvalidOperationException if TaskId is missing when the message
contains A2AInputResponseContent, preventing silent no-op responses.
Also adds tests for both RunAsync and RunStreamingAsync paths.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Leave Contents null for non-InputRequired status updates

Remove unnecessary '?? []' fallback so Contents stays null when there
are no input requests, matching the other update mapping patterns.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Use consistent GUID format for request IDs

Use ToString("N") to match message ID format used elsewhere in
the A2A component.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Remove Debug build exclusion for the HumanInTheLoop sample so it                                                                                                                                                                                                               participates in normal solution validation.

* Add missing using Microsoft.Extensions.AI to A2AAgent_HumanInTheLoop

The sample uses ChatMessage, TextContent, and ChatRole types from
Microsoft.Extensions.AI but was missing the using directive, causing
CS0246 build errors on all CI jobs.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* change the way user input requests are handled based on pr review comments

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
SergeyMenshykh
2026-05-12 14:10:18 +01:00
committed by GitHub
Unverified
parent 939d4d0153
commit dfc3079d68
9 changed files with 395 additions and 21 deletions
+1 -1
View File
@@ -1,4 +1,4 @@
<Solution>
<Solution>
<Configurations>
<BuildType Name="Debug" />
<BuildType Name="Publish" />
+29 -14
View File
@@ -93,26 +93,26 @@ public sealed class A2AAgent : AIAgent
/// <inheritdoc/>
protected override async Task<AgentResponse> RunCoreAsync(IEnumerable<ChatMessage> messages, AgentSession? session = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
{
_ = Throw.IfNull(messages);
var inputMessages = Throw.IfNull(messages) as IReadOnlyCollection<ChatMessage> ?? messages.ToList();
A2AAgentSession typedSession = await this.GetA2ASessionAsync(session, options, cancellationToken).ConfigureAwait(false);
this._logger.LogA2AAgentInvokingAgent(nameof(RunAsync), this.Id, this.Name);
if (GetContinuationToken(messages, options) is { } token)
if (GetContinuationToken(inputMessages, options) is { } token)
{
AgentTask agentTask = await this._a2aClient.GetTaskAsync(new GetTaskRequest { Id = token.TaskId }, cancellationToken).ConfigureAwait(false);
this._logger.LogAgentChatClientInvokedAgent(nameof(RunAsync), this.Id, this.Name);
UpdateSession(typedSession, agentTask.ContextId, agentTask.Id);
UpdateSession(typedSession, agentTask.ContextId, agentTask.Id, agentTask.Status.State);
return this.ConvertToAgentResponse(agentTask);
}
SendMessageRequest sendParams = new()
{
Message = CreateA2AMessage(typedSession, messages),
Message = CreateA2AMessage(typedSession, inputMessages),
Metadata = options?.AdditionalProperties?.ToA2AMetadata(),
Configuration = new SendMessageConfiguration { ReturnImmediately = options?.AllowBackgroundResponses is true }
};
@@ -134,7 +134,7 @@ public sealed class A2AAgent : AIAgent
{
var agentTask = a2aResponse.Task!;
UpdateSession(typedSession, agentTask.ContextId, agentTask.Id);
UpdateSession(typedSession, agentTask.ContextId, agentTask.Id, agentTask.Status.State);
return this.ConvertToAgentResponse(agentTask);
}
@@ -145,7 +145,7 @@ public sealed class A2AAgent : AIAgent
/// <inheritdoc/>
protected override async IAsyncEnumerable<AgentResponseUpdate> RunCoreStreamingAsync(IEnumerable<ChatMessage> messages, AgentSession? session = null, AgentRunOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
_ = Throw.IfNull(messages);
var inputMessages = Throw.IfNull(messages) as IReadOnlyCollection<ChatMessage> ?? messages.ToList();
A2AAgentSession typedSession = await this.GetA2ASessionAsync(session, options, cancellationToken).ConfigureAwait(false);
@@ -153,7 +153,7 @@ public sealed class A2AAgent : AIAgent
ConfiguredCancelableAsyncEnumerable<StreamResponse> streamEvents;
if (GetContinuationToken(messages, options) is { } token)
if (GetContinuationToken(inputMessages, options) is { } token)
{
streamEvents = this.SubscribeToTaskWithFallbackAsync(token.TaskId, cancellationToken).ConfigureAwait(false);
}
@@ -161,7 +161,7 @@ public sealed class A2AAgent : AIAgent
{
SendMessageRequest sendParams = new()
{
Message = CreateA2AMessage(typedSession, messages),
Message = CreateA2AMessage(typedSession, inputMessages),
Metadata = options?.AdditionalProperties?.ToA2AMetadata()
};
@@ -172,6 +172,7 @@ public sealed class A2AAgent : AIAgent
string? contextId = null;
string? taskId = null;
TaskState? taskState = null;
await foreach (var streamResponse in streamEvents)
{
@@ -187,6 +188,7 @@ public sealed class A2AAgent : AIAgent
var task = streamResponse.Task!;
contextId = task.ContextId;
taskId = task.Id;
taskState = task.Status.State;
yield return this.ConvertToAgentResponseUpdate(task);
break;
@@ -194,6 +196,7 @@ public sealed class A2AAgent : AIAgent
var statusUpdate = streamResponse.StatusUpdate!;
contextId = statusUpdate.ContextId;
taskId = statusUpdate.TaskId;
taskState = statusUpdate.Status.State;
yield return this.ConvertToAgentResponseUpdate(statusUpdate);
break;
@@ -209,7 +212,7 @@ public sealed class A2AAgent : AIAgent
}
}
UpdateSession(typedSession, contextId, taskId);
UpdateSession(typedSession, contextId, taskId, taskState);
}
/// <inheritdoc/>
@@ -317,7 +320,7 @@ public sealed class A2AAgent : AIAgent
}
}
private static void UpdateSession(A2AAgentSession? session, string? contextId, string? taskId = null)
private static void UpdateSession(A2AAgentSession? session, string? contextId, string? taskId = null, TaskState? taskState = null)
{
if (session is null)
{
@@ -335,9 +338,10 @@ public sealed class A2AAgent : AIAgent
// Assign a server-generated context Id to the session if it's not already set.
session.ContextId ??= contextId;
session.TaskId = taskId;
session.TaskState = taskState;
}
private static Message CreateA2AMessage(A2AAgentSession typedSession, IEnumerable<ChatMessage> messages)
private static Message CreateA2AMessage(A2AAgentSession typedSession, IReadOnlyCollection<ChatMessage> messages)
{
var a2aMessage = messages.ToA2AMessage();
@@ -345,9 +349,19 @@ public sealed class A2AAgent : AIAgent
// See: https://github.com/a2aproject/A2A/blob/main/docs/topics/life-of-a-task.md#group-related-interactions
a2aMessage.ContextId = typedSession.ContextId;
// Link the message as a follow-up to an existing task, if any.
// See: https://github.com/a2aproject/A2A/blob/main/docs/topics/life-of-a-task.md#task-refinements
a2aMessage.ReferenceTaskIds = typedSession.TaskId is null ? null : [typedSession.TaskId];
if (typedSession.TaskState == TaskState.InputRequired)
{
// If the session indicates the task is waiting for user input,
// link the response to the existing task so it is treated as input
// for that task.
a2aMessage.TaskId = typedSession.TaskId;
}
else
{
// Link the message as a follow-up to an existing task, if any.
// See: https://github.com/a2aproject/A2A/blob/main/docs/topics/life-of-a-task.md#task-refinements
a2aMessage.ReferenceTaskIds = typedSession.TaskId is not null ? [typedSession.TaskId] : null;
}
return a2aMessage;
}
@@ -444,6 +458,7 @@ public sealed class A2AAgent : AIAgent
Role = ChatRole.Assistant,
FinishReason = MapTaskStateToFinishReason(statusUpdateEvent.Status.State),
AdditionalProperties = statusUpdateEvent.Metadata?.ToAdditionalProperties() ?? [],
Contents = statusUpdateEvent.Status.GetUserInputRequests(),
};
}
@@ -5,6 +5,8 @@ using System.Diagnostics;
using System.Text.Json;
using System.Text.Json.Serialization;
using TaskState = A2A.TaskState;
namespace Microsoft.Agents.AI.A2A;
/// <summary>
@@ -18,10 +20,11 @@ public sealed class A2AAgentSession : AgentSession
}
[JsonConstructor]
internal A2AAgentSession(string? contextId, string? taskId, AgentSessionStateBag? stateBag) : base(stateBag ?? new())
internal A2AAgentSession(string? contextId, string? taskId, TaskState? taskState, AgentSessionStateBag? stateBag) : base(stateBag ?? new())
{
this.ContextId = contextId;
this.TaskId = taskId;
this.TaskState = taskState;
}
/// <summary>
@@ -36,6 +39,12 @@ public sealed class A2AAgentSession : AgentSession
[JsonPropertyName("taskId")]
public string? TaskId { get; internal set; }
/// <summary>
/// Gets the state of the task the agent is currently working on.
/// </summary>
[JsonPropertyName("taskState")]
public TaskState? TaskState { get; internal set; }
/// <inheritdoc/>
internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null)
{
@@ -57,5 +66,5 @@ public sealed class A2AAgentSession : AgentSession
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
private string DebuggerDisplay =>
$"ContextId = {this.ContextId}, TaskId = {this.TaskId}, StateBag Count = {this.StateBag.Count}";
$"ContextId = {this.ContextId}, TaskId = {this.TaskId}, TaskState = {this.TaskState}, StateBag Count = {this.StateBag.Count}";
}
@@ -13,7 +13,7 @@ internal static class A2AAIContentExtensions
/// <summary>
/// Converts a collection of <see cref="AIContent"/> to a list of <see cref="Part"/> objects.
/// </summary>
/// <param name="contents">The collection of AI contents to convert.</param>"
/// <param name="contents">The collection of AI contents to convert.</param>
/// <returns>The list of A2A <see cref="Part"/> objects.</returns>
internal static List<Part>? ToParts(this IEnumerable<AIContent> contents)
{
@@ -21,8 +21,7 @@ internal static class A2AAIContentExtensions
foreach (var content in contents)
{
var part = content.ToPart();
if (part is not null)
if (content.ToPart() is { } part)
{
(parts ??= []).Add(part);
}
@@ -17,7 +17,7 @@ internal static class A2AAgentTaskExtensions
List<ChatMessage>? messages = null;
if (agentTask?.Artifacts is { Count: > 0 })
if (agentTask.Artifacts is { Count: > 0 })
{
foreach (var artifact in agentTask.Artifacts)
{
@@ -25,6 +25,14 @@ internal static class A2AAgentTaskExtensions
}
}
if (agentTask.Status?.GetUserInputRequests() is { } userInputRequests)
{
(messages ??= []).Add(new(ChatRole.Assistant, userInputRequests)
{
RawRepresentation = agentTask.Status,
});
}
return messages;
}
@@ -42,6 +50,11 @@ internal static class A2AAgentTaskExtensions
}
}
if (agentTask.Status?.GetUserInputRequests() is { } userInputRequests)
{
(aiContents ??= []).AddRange(userInputRequests);
}
return aiContents;
}
}
@@ -0,0 +1,35 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using Microsoft.Extensions.AI;
using Microsoft.Shared.Diagnostics;
namespace A2A;
/// <summary>
/// Extension methods for the <see cref="TaskStatus"/> class.
/// </summary>
internal static class AgentTaskStatusExtensions
{
internal static IList<AIContent>? GetUserInputRequests(this TaskStatus status)
{
_ = Throw.IfNull(status);
List<AIContent>? contents = null;
if (status.Message is null || status.State is not TaskState.InputRequired)
{
return contents;
}
foreach (var part in status.Message.Parts)
{
var aiContent = part.ToAIContent();
aiContent.RawRepresentation = part;
aiContent.AdditionalProperties = part.Metadata.ToAdditionalProperties();
(contents ??= []).Add(aiContent);
}
return contents;
}
}
@@ -493,6 +493,35 @@ public sealed class A2AAgentTests : IDisposable
Assert.Contains("task-123", message.ReferenceTaskIds);
}
[Fact]
public async Task RunAsync_WithInputRequiredTaskState_SetsTaskIdOnMessageAsync()
{
// Arrange
this._handler.ResponseToReturn = new SendMessageResponse
{
Message = new Message
{
MessageId = "response-456",
Role = Role.Agent,
Parts = [new Part { Text = "Booking confirmed" }]
}
};
var session = (A2AAgentSession)await this._agent.CreateSessionAsync();
session.TaskId = "task-123";
session.TaskState = TaskState.InputRequired;
var inputMessage = new ChatMessage(ChatRole.User, [new TextContent("New York to London")]);
// Act
await this._agent.RunAsync(inputMessage, session);
// Assert
var message = this._handler.CapturedSendMessageRequest?.Message;
Assert.Equal("task-123", message?.TaskId);
Assert.Null(message?.ReferenceTaskIds);
}
[Fact]
public async Task RunAsync_WithAgentTask_UpdatesSessionTaskIdAsync()
{
@@ -573,6 +602,7 @@ public sealed class A2AAgentTests : IDisposable
[InlineData(TaskState.Completed)]
[InlineData(TaskState.Failed)]
[InlineData(TaskState.Canceled)]
[InlineData(TaskState.InputRequired)]
public async Task RunAsync_WithVariousTaskStates_ReturnsCorrectTokenAsync(TaskState taskState)
{
// Arrange
@@ -842,6 +872,38 @@ public sealed class A2AAgentTests : IDisposable
Assert.Contains("task-123", message.ReferenceTaskIds);
}
[Fact]
public async Task RunStreamingAsync_WithInputRequiredTaskState_SetsTaskIdOnMessageAsync()
{
// Arrange
this._handler.StreamingResponseToReturn = new StreamResponse
{
Message = new Message
{
MessageId = "response-456",
Role = Role.Agent,
Parts = [new Part { Text = "Booking confirmed" }]
}
};
var session = (A2AAgentSession)await this._agent.CreateSessionAsync();
session.TaskId = "task-123";
session.TaskState = TaskState.InputRequired;
var inputMessage = new ChatMessage(ChatRole.User, [new TextContent("New York to London")]);
// Act
await foreach (var _ in this._agent.RunStreamingAsync([inputMessage], session))
{
// Just iterate through to trigger the logic
}
// Assert
var message = this._handler.CapturedSendMessageRequest?.Message;
Assert.Equal("task-123", message?.TaskId);
Assert.Null(message?.ReferenceTaskIds);
}
[Fact]
public async Task RunStreamingAsync_WithAgentTask_UpdatesSessionTaskIdAsync()
{
@@ -1004,6 +1066,50 @@ public sealed class A2AAgentTests : IDisposable
Assert.Equal(TaskId, a2aSession.TaskId);
}
[Fact]
public async Task RunStreamingAsync_WithInputRequiredStatusUpdate_YieldsStatusContentsAsync()
{
// Arrange
const string TaskId = "task-input-123";
const string ContextId = "ctx-input-456";
this._handler.StreamingResponseToReturn = new StreamResponse
{
StatusUpdate = new TaskStatusUpdateEvent
{
TaskId = TaskId,
ContextId = ContextId,
Status = new()
{
State = TaskState.InputRequired,
Message = new Message
{
Parts = [Part.FromText("Where would you like to fly?")]
}
}
}
};
var session = await this._agent.CreateSessionAsync();
// Act
var updates = new List<AgentResponseUpdate>();
await foreach (var update in this._agent.RunStreamingAsync("I'd like to book a flight.", session))
{
updates.Add(update);
}
// Assert
Assert.Single(updates);
var update0 = updates[0];
Assert.Equal(TaskId, update0.ResponseId);
Assert.Null(update0.FinishReason);
var textContent = Assert.Single(update0.Contents.OfType<TextContent>());
Assert.Equal("Where would you like to fly?", textContent.Text);
}
[Fact]
public async Task RunStreamingAsync_WithTaskArtifactUpdateEvent_YieldsResponseUpdateAsync()
{
@@ -2,6 +2,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using A2A;
using Microsoft.Extensions.AI;
@@ -166,4 +167,79 @@ public sealed class A2AAgentTaskExtensionsTests
Assert.Equal("content2", result[1].ToString());
Assert.Equal("content3", result[2].ToString());
}
[Fact]
public void ToChatMessages_WithInputRequiredStatus_IncludesStatusContents()
{
// Arrange
var agentTask = new AgentTask
{
Id = "task1",
Artifacts = null,
Status = new TaskStatus
{
State = TaskState.InputRequired,
Message = new Message { Parts = [Part.FromText("What is your destination?")] },
},
};
// Act
IList<ChatMessage>? result = agentTask.ToChatMessages();
// Assert
Assert.NotNull(result);
Assert.Single(result);
Assert.Equal(ChatRole.Assistant, result[0].Role);
var textContent = Assert.Single(result[0].Contents.OfType<TextContent>());
Assert.Equal("What is your destination?", textContent.Text);
}
[Fact]
public void ToAIContents_WithInputRequiredStatus_IncludesStatusContents()
{
// Arrange
var agentTask = new AgentTask
{
Id = "task1",
Artifacts = null,
Status = new TaskStatus
{
State = TaskState.InputRequired,
Message = new Message { Parts = [Part.FromText("What is your destination?")] },
},
};
// Act
IList<AIContent>? result = agentTask.ToAIContents();
// Assert
Assert.NotNull(result);
var textContent = Assert.Single(result.OfType<TextContent>());
Assert.Equal("What is your destination?", textContent.Text);
}
[Fact]
public void ToChatMessages_WithArtifactsAndInputRequired_IncludesBoth()
{
// Arrange
var agentTask = new AgentTask
{
Id = "task1",
Artifacts = [new Artifact { Parts = [Part.FromText("partial result")] }],
Status = new TaskStatus
{
State = TaskState.InputRequired,
Message = new Message { Parts = [Part.FromText("Need more info")] },
},
};
// Act
IList<ChatMessage>? result = agentTask.ToChatMessages();
// Assert
Assert.NotNull(result);
Assert.Equal(2, result.Count);
Assert.Equal("partial result", result[0].Text);
Assert.Single(result[1].Contents.OfType<TextContent>());
}
}
@@ -0,0 +1,121 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using A2A;
using Microsoft.Extensions.AI;
namespace Microsoft.Agents.AI.A2A.UnitTests;
/// <summary>
/// Unit tests for the <see cref="AgentTaskStatusExtensions"/> class.
/// </summary>
public sealed class AgentTaskStatusExtensionsTests
{
[Fact]
public void GetUserInputRequests_WithNullMessage_ReturnsNull()
{
// Arrange
var status = new TaskStatus
{
State = TaskState.InputRequired,
Message = null,
};
// Act
IList<AIContent>? result = status.GetUserInputRequests();
// Assert
Assert.Null(result);
}
[Fact]
public void GetUserInputRequests_WithNotInputRequiredState_ReturnsNull()
{
// Arrange
var status = new TaskStatus
{
State = TaskState.Completed,
Message = new Message { Parts = [Part.FromText("Some text")] },
};
// Act
IList<AIContent>? result = status.GetUserInputRequests();
// Assert
Assert.Null(result);
}
[Fact]
public void GetUserInputRequests_WithInputRequiredStateAndMultipleRequests_ReturnsAIContentList()
{
// Arrange
var status = new TaskStatus
{
State = TaskState.InputRequired,
Message = new Message
{
Parts =
[
Part.FromText("First request"),
Part.FromText("Second request"),
Part.FromText("Third request")
],
},
};
// Act
IList<AIContent>? result = status.GetUserInputRequests();
// Assert
Assert.NotNull(result);
Assert.Equal(3, result.Count);
Assert.Equal("First request", Assert.IsType<TextContent>(result[0]).Text);
Assert.Equal("Second request", Assert.IsType<TextContent>(result[1]).Text);
Assert.Equal("Third request", Assert.IsType<TextContent>(result[2]).Text);
}
[Fact]
public void GetUserInputRequests_WithTextParts_SetsRawRepresentationAndAdditionalPropertiesCorrectly()
{
// Arrange
var textPart = Part.FromText("Input request");
textPart.Metadata = new Dictionary<string, System.Text.Json.JsonElement>
{
{ "key1", System.Text.Json.JsonSerializer.SerializeToElement("value1") },
{ "key2", System.Text.Json.JsonSerializer.SerializeToElement("value2") }
};
var status = new TaskStatus
{
State = TaskState.InputRequired,
Message = new Message { Parts = [textPart] },
};
// Act
IList<AIContent>? result = status.GetUserInputRequests();
// Assert
Assert.NotNull(result);
var content = Assert.IsType<TextContent>(result[0]);
Assert.Equal(textPart, content.RawRepresentation);
Assert.NotNull(content.AdditionalProperties);
Assert.True(content.AdditionalProperties.ContainsKey("key1"));
Assert.True(content.AdditionalProperties.ContainsKey("key2"));
}
[Fact]
public void GetUserInputRequests_WithEmptyMessageParts_ReturnsNull()
{
// Arrange
var status = new TaskStatus
{
State = TaskState.InputRequired,
Message = new Message { Parts = [] },
};
// Act
IList<AIContent>? result = status.GetUserInputRequests();
// Assert
Assert.Null(result);
}
}