.NET: feat: Improve Support for AIAgent-as-Executor (#432)

* feat: Support Executor-targeted messages

This adds support for only sending a message to a given executor. Messages will still only route through connected edges.

* feat: Support sending all valid input types after starting a run

* feat: Normalize AIAgent-as-Executor Message Protocol to use MEAI types
This commit is contained in:
Jacob Alber
2025-08-19 16:04:03 -04:00
committed by GitHub
Unverified
parent 953ed7560d
commit baaa9c0aee
21 changed files with 595 additions and 52 deletions
@@ -16,17 +16,23 @@ internal class DirectEdgeRunner(IRunnerContext runContext, DirectEdgeData edgeDa
.ConfigureAwait(false);
}
public async ValueTask<IEnumerable<object?>> ChaseAsync(object message)
public async ValueTask<IEnumerable<object?>> ChaseAsync(MessageEnvelope envelope)
{
if (envelope.TargetId != null && this.EdgeData.SinkId != envelope.TargetId)
{
return [];
}
object message = envelope.Message;
if (this.EdgeData.Condition != null && !this.EdgeData.Condition(message))
{
return [];
}
Executor target = await this.FindRouterAsync().ConfigureAwait(false);
if (target.CanHandle(message.GetType()))
if (target.CanHandle(envelope.MessageType))
{
return [await target.ExecuteAsync(message, this.WorkflowContext).ConfigureAwait(false)];
return [await target.ExecuteAsync(message, envelope.MessageType, this.WorkflowContext).ConfigureAwait(false)];
}
return [];
@@ -40,7 +40,7 @@ internal class EdgeMap
this._inputRunner = new InputEdgeRunner(runContext, startExecutorId);
}
public async ValueTask<IEnumerable<object?>> InvokeEdgeAsync(Edge edge, string sourceId, object message)
public async ValueTask<IEnumerable<object?>> InvokeEdgeAsync(Edge edge, string sourceId, MessageEnvelope message)
{
if (!this._edgeRunners.TryGetValue(edge, out object? edgeRunner))
{
@@ -87,9 +87,9 @@ internal class EdgeMap
}
// TODO: Should we promote Input to a true "FlowEdge" type?
public async ValueTask<IEnumerable<object?>> InvokeInputAsync(object inputMessage)
public async ValueTask<IEnumerable<object?>> InvokeInputAsync(MessageEnvelope envelope)
{
return [await this._inputRunner.ChaseAsync(inputMessage).ConfigureAwait(false)];
return [await this._inputRunner.ChaseAsync(envelope).ConfigureAwait(false)];
}
public async ValueTask<IEnumerable<object?>> InvokeResponseAsync(ExternalResponse response)
@@ -99,6 +99,6 @@ internal class EdgeMap
throw new InvalidOperationException($"Port {response.Port.Id} not found in the edge map.");
}
return [await portRunner.ChaseAsync(response).ConfigureAwait(false)];
return [await portRunner.ChaseAsync(new MessageEnvelope(response)).ConfigureAwait(false)];
}
}
@@ -12,8 +12,15 @@ internal class FanInEdgeRunner(IRunnerContext runContext, FanInEdgeData edgeData
public FanInEdgeState CreateState() => new(this.EdgeData);
public async ValueTask<object?> ChaseAsync(string sourceId, object message, FanInEdgeState state)
public async ValueTask<object?> ChaseAsync(string sourceId, MessageEnvelope envelope, FanInEdgeState state)
{
if (envelope.TargetId != null && this.EdgeData.SinkId != envelope.TargetId)
{
// This message is not for us.
return null;
}
object message = envelope.Message;
IEnumerable<object>? releasedMessages = state.ProcessMessage(sourceId, message);
if (releasedMessages is null)
{
@@ -26,7 +33,7 @@ internal class FanInEdgeRunner(IRunnerContext runContext, FanInEdgeData edgeData
if (target.CanHandle(message.GetType()))
{
return await target.ExecuteAsync(message, this.BoundContext)
return await target.ExecuteAsync(message, envelope.MessageType, this.BoundContext)
.ConfigureAwait(false);
}
return null;
@@ -14,14 +14,20 @@ internal class FanOutEdgeRunner(IRunnerContext runContext, FanOutEdgeData edgeDa
sinkId => sinkId,
sinkId => runContext.Bind(sinkId));
public async ValueTask<IEnumerable<object?>> ChaseAsync(object message)
public async ValueTask<IEnumerable<object?>> ChaseAsync(MessageEnvelope envelope)
{
object message = envelope.Message;
List<string> targets =
this.EdgeData.PartitionAssigner == null
? this.EdgeData.SinkIds
: this.EdgeData.PartitionAssigner(message, this.BoundContexts.Count).Select(i => this.EdgeData.SinkIds[i]).ToList();
: this.EdgeData.PartitionAssigner(message, this.BoundContexts.Count)
.Select(i => this.EdgeData.SinkIds[i]).ToList();
object?[] result = await Task.WhenAll(targets.Select(ProcessTargetAsync)).ConfigureAwait(false);
IEnumerable<string> filteredTargets = envelope.TargetId != null
? targets.Where(IsValidTarget)
: targets;
object?[] result = await Task.WhenAll(filteredTargets.Select(ProcessTargetAsync)).ConfigureAwait(false);
return result.Where(r => r is not null);
async Task<object?> ProcessTargetAsync(string targetId)
@@ -31,11 +37,16 @@ internal class FanOutEdgeRunner(IRunnerContext runContext, FanOutEdgeData edgeDa
if (executor.CanHandle(message.GetType()))
{
return await executor.ExecuteAsync(message, this.BoundContexts[targetId])
return await executor.ExecuteAsync(message, envelope.MessageType, this.BoundContexts[targetId])
.ConfigureAwait(false);
}
return null;
}
bool IsValidTarget(string targetId)
{
return envelope.TargetId == null || targetId == envelope.TargetId;
}
}
}
@@ -7,7 +7,7 @@ namespace Microsoft.Agents.Workflows.Execution;
internal interface IRunnerContext : IExternalRequestSink
{
ValueTask AddEventAsync(WorkflowEvent workflowEvent);
ValueTask SendMessageAsync(string executorId, object message);
ValueTask SendMessageAsync(string sourceId, object message, string? targetId = null);
// TODO: State Management
@@ -11,7 +11,8 @@ internal interface ISuperStepRunner
bool HasUnservicedRequests { get; }
bool HasUnprocessedMessages { get; }
ValueTask EnqueueMessageAsync(object message);
ValueTask EnqueueResponseAsync(ExternalResponse response);
ValueTask<bool> EnqueueMessageAsync<T>(T message);
event EventHandler<WorkflowEvent>? WorkflowEvent;
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Diagnostics;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
@@ -23,16 +24,17 @@ internal class InputEdgeRunner(IRunnerContext runContext, string sinkId)
return await this.RunContext.EnsureExecutorAsync(this.EdgeData).ConfigureAwait(false);
}
public async ValueTask<object?> ChaseAsync(object message)
public async ValueTask<object?> ChaseAsync(MessageEnvelope envelope)
{
Executor target = await this.FindExecutorAsync().ConfigureAwait(false);
if (target.CanHandle(message.GetType()))
if (target.CanHandle(envelope.MessageType))
{
return await target.ExecuteAsync(message, this.WorkflowContext)
return await target.ExecuteAsync(envelope.Message, envelope.MessageType, this.WorkflowContext)
.ConfigureAwait(false);
}
// TODO: Throw instead?
// TODO: Throw instead? / Log
Debug.WriteLine($"Executor {target.Id} cannot handle message of type {envelope.MessageType.FullName}. Dropping.");
return null;
}
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
namespace Microsoft.Agents.Workflows.Execution;
internal sealed class MessageEnvelope(object message, Type? declaredType = null, string? targetId = null)
{
public Type MessageType => declaredType ?? message.GetType();
public object Message => message;
public string? TargetId => targetId;
}
@@ -23,9 +23,11 @@ internal class MessageRouter
{
this._typedHandlers = Throw.IfNull(handlers);
this._hasCatchall = this._typedHandlers.ContainsKey(typeof(object));
this.IncomingTypes = [.. this._typedHandlers.Keys];
}
public HashSet<Type> IncomingTypes => [.. this._typedHandlers.Keys];
public HashSet<Type> IncomingTypes { get; }
public bool CanHandle(object message) => this.CanHandle(Throw.IfNull(message).GetType());
@@ -7,16 +7,15 @@ namespace Microsoft.Agents.Workflows.Execution;
internal class StepContext
{
public Dictionary<ExecutorIdentity, List<object>> QueuedMessages { get; } = new();
public Dictionary<ExecutorIdentity, List<MessageEnvelope>> QueuedMessages { get; } = new();
public bool HasMessages => this.QueuedMessages.Values.Any(messageList => messageList.Count > 0);
public List<object> MessagesFor(string? executorId)
public List<MessageEnvelope> MessagesFor(string? executorId)
{
if (!this.QueuedMessages.TryGetValue(executorId, out var messages))
{
messages = new List<object>();
this.QueuedMessages[executorId] = messages;
this.QueuedMessages[executorId] = messages = new();
}
return messages;
@@ -54,11 +54,13 @@ public abstract class Executor : IIdentified
/// Process an incoming message using the registered handlers.
/// </summary>
/// <param name="message">The message to be processed by the executor.</param>
/// <param name="messageType">The "declared" type of the message (captured when it was being sent). This is
/// used to enable routing messages as their base types, in absence of true polymorphic type routing.</param>
/// <param name="context">The workflow context in which the executor executes.</param>
/// <returns>A ValueTask representing the asynchronous operation, wrapping the output from the executor.</returns>
/// <exception cref="NotSupportedException">No handler found for the message type.</exception>
/// <exception cref="TargetInvocationException">An exception is generated while handling the message.</exception>
public async ValueTask<object?> ExecuteAsync(object message, IWorkflowContext context)
public async ValueTask<object?> ExecuteAsync(object message, Type messageType, IWorkflowContext context)
{
await context.AddEventAsync(new ExecutorInvokeEvent(this.Id, message)).ConfigureAwait(false);
@@ -21,8 +21,11 @@ public interface IWorkflowContext
/// Queues a message to be sent to connected executors. The message will be sent during the next SuperStep.
/// </summary>
/// <param name="message">The message to be sent.</param>
/// <param name="targetId">An optional identifier of the target executor. If null, the message is sent to all connected
/// executors. If the target executor is not connected from this executor via an edge, it will still not receive the
/// message.</param>
/// <returns>A <see cref="ValueTask"/> representing the asynchronous operation.</returns>
ValueTask SendMessageAsync(object message);
ValueTask SendMessageAsync(object message, string? targetId = null);
/// <summary>
/// Reads a state value from the workflow's state store. If no scope is provided, the executor's private
@@ -2,6 +2,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
@@ -29,9 +30,38 @@ internal class InProcessRunner<TInput> : ISuperStepRunner where TInput : notnull
this.EdgeMap = new EdgeMap(this.RunContext, this.Workflow.Edges, this.Workflow.Ports.Values, this.Workflow.StartExecutorId);
}
ValueTask ISuperStepRunner.EnqueueMessageAsync(object message)
public async ValueTask<bool> IsValidInputAsync<TMessage>(TMessage message)
{
return this.RunContext.AddExternalMessageAsync(message);
Throw.IfNull(message);
Type type = typeof(TMessage);
// Short circuit the logic if the type is the input type
if (type == typeof(TInput))
{
return true;
}
Executor startingExecutor = await this.RunContext.EnsureExecutorAsync(this.Workflow.StartExecutorId).ConfigureAwait(false);
return startingExecutor.CanHandle(type);
}
async ValueTask<bool> ISuperStepRunner.EnqueueMessageAsync<T>(T message)
{
// Check that the type of the incoming message is compatible with the starting executor's
// input type.
if (!await this.IsValidInputAsync(message).ConfigureAwait(false))
{
return false;
}
await this.RunContext.AddExternalMessageAsync<T>(message).ConfigureAwait(false);
return true;
}
ValueTask ISuperStepRunner.EnqueueResponseAsync(ExternalResponse response)
{
return this.RunContext.AddExternalMessageAsync(response);
}
private Dictionary<string, string> PendingCalls { get; } = new();
@@ -57,11 +87,14 @@ internal class InProcessRunner<TInput> : ISuperStepRunner where TInput : notnull
return message is ExternalResponse;
}
private ValueTask<IEnumerable<object?>> RouteExternalMessageAsync(object message)
private ValueTask<IEnumerable<object?>> RouteExternalMessageAsync(MessageEnvelope envelope)
{
Debug.Assert(envelope.TargetId == null, "External Messages cannot be targeted to a specific executor.");
object message = envelope.Message;
return message is ExternalResponse response
? this.CompleteExternalResponseAsync(response)
: this.EdgeMap.InvokeInputAsync(message);
: this.EdgeMap.InvokeInputAsync(envelope);
}
private ValueTask<IEnumerable<object?>> CompleteExternalResponseAsync(ExternalResponse response)
@@ -113,16 +146,16 @@ internal class InProcessRunner<TInput> : ISuperStepRunner where TInput : notnull
List<Task<IEnumerable<object?>>> edgeTasks = new();
foreach (ExecutorIdentity sender in currentStep.QueuedMessages.Keys)
{
IEnumerable<object> senderMessages = currentStep.QueuedMessages[sender];
IEnumerable<MessageEnvelope> senderMessages = currentStep.QueuedMessages[sender];
if (sender.Id is null)
{
edgeTasks.AddRange(senderMessages.Select(message => this.RouteExternalMessageAsync(message).AsTask()));
edgeTasks.AddRange(senderMessages.Select(envelope => this.RouteExternalMessageAsync(envelope).AsTask()));
}
else if (this.Workflow.Edges.TryGetValue(sender.Id!, out HashSet<Edge>? outgoingEdges))
{
foreach (Edge outgoingEdge in outgoingEdges)
{
edgeTasks.AddRange(senderMessages.Select(message => this.EdgeMap.InvokeEdgeAsync(outgoingEdge, sender.Id, message).AsTask()));
edgeTasks.AddRange(senderMessages.Select(envelope => this.EdgeMap.InvokeEdgeAsync(outgoingEdge, sender.Id, envelope).AsTask()));
}
}
}
@@ -2,7 +2,6 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Agents.Workflows.Execution;
@@ -44,11 +43,19 @@ internal class InProcessRunnerContext<TExternalInput> : IRunnerContext
return executor;
}
public ValueTask AddExternalMessageAsync([NotNull] object message)
public ValueTask AddExternalMessageUntypedAsync(object message)
{
Throw.IfNull(message);
this._nextStep.MessagesFor(ExecutorIdentity.None).Add(message);
this._nextStep.MessagesFor(ExecutorIdentity.None).Add(new MessageEnvelope(message));
return default;
}
public ValueTask AddExternalMessageAsync<T>(T message)
{
Throw.IfNull(message);
this._nextStep.MessagesFor(ExecutorIdentity.None).Add(new MessageEnvelope(message, declaredType: typeof(T)));
return default;
}
@@ -66,9 +73,9 @@ internal class InProcessRunnerContext<TExternalInput> : IRunnerContext
return default;
}
public ValueTask SendMessageAsync(string executorId, object message)
public ValueTask SendMessageAsync(string sourceId, object message, string? targetId = null)
{
this._nextStep.MessagesFor(executorId).Add(message);
this._nextStep.MessagesFor(sourceId).Add(new MessageEnvelope(message, targetId: targetId));
return default;
}
@@ -92,7 +99,7 @@ internal class InProcessRunnerContext<TExternalInput> : IRunnerContext
private class BoundContext(InProcessRunnerContext<TExternalInput> RunnerContext, string ExecutorId) : IWorkflowContext
{
public ValueTask AddEventAsync(WorkflowEvent workflowEvent) => RunnerContext.AddEventAsync(workflowEvent);
public ValueTask SendMessageAsync(object message) => RunnerContext.SendMessageAsync(ExecutorId, message);
public ValueTask SendMessageAsync(object message, string? targetId = null) => RunnerContext.SendMessageAsync(ExecutorId, message, targetId);
public ValueTask QueueStateUpdateAsync<T>(string key, T? value, string? scopeName = null)
=> RunnerContext.StateManager.WriteStateAsync(ExecutorId, scopeName, key, value);
@@ -127,6 +127,28 @@ public class Run
return await this.RunToNextHaltAsync(cancellation).ConfigureAwait(false);
}
/// <summary>
/// Resume execution of the workflow with the provided external responses.
/// </summary>
/// <param name="cancellation">A <see cref="CancellationToken"/> that can be used to cancel the workflow execution.</param>
/// <param name="messages">An array of messages to send to the workflow. Messages will only be sent if they are valid
/// input types to the starting executor or a <see cref="ExternalResponse"/>.</param>
/// <returns><c>true</c> if the workflow had any output events, <c>false</c> otherwise.</returns>
public async ValueTask<bool> ResumeAsync<T>(CancellationToken cancellation = default, params T[] messages)
{
if (messages is ExternalResponse[] responses)
{
return await this.ResumeAsync(cancellation, responses).ConfigureAwait(false);
}
foreach (T message in messages)
{
await this._streamingRun.TrySendMessageAsync(message).ConfigureAwait(false);
}
return await this.RunToNextHaltAsync(cancellation).ConfigureAwait(false);
}
}
/// <summary>
@@ -3,30 +3,64 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Agents.Workflows.Reflection;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.AI.Agents;
namespace Microsoft.Agents.Workflows.Specialized;
internal class AIAgentHostExecutor : ReflectingExecutor<AIAgentHostExecutor>, IMessageHandler<IList<ChatMessage>>
internal class AIAgentHostExecutor : Executor
{
private AIAgent Agent { get; set; }
private readonly AIAgent _agent;
private readonly List<ChatMessage> _pendingMessages = new();
private AgentThread? _thread = null;
public AIAgentHostExecutor(AIAgent agent)
public AIAgentHostExecutor(AIAgent agent) : base(id: agent.Id)
{
this.Agent = agent;
this._agent = agent;
}
public async ValueTask HandleAsync(IList<ChatMessage> message, IWorkflowContext context)
private AgentThread EnsureThread()
{
IReadOnlyCollection<ChatMessage> messageList = (message as List<ChatMessage> ?? message.ToList()).AsReadOnly();
if (this._thread != null)
{
return this._thread;
}
return this._thread = this._agent.GetNewThread();
}
protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder)
{
return routeBuilder.AddHandler<ChatMessage>(this.QueueMessageAsync)
.AddHandler<List<ChatMessage>>(this.QueueMessagesAsync)
.AddHandler<TurnToken>(this.TakeTurnAsync);
}
public ValueTask QueueMessagesAsync(List<ChatMessage> messages, IWorkflowContext context)
{
this._pendingMessages.AddRange(messages);
return default;
}
public ValueTask QueueMessageAsync(ChatMessage message, IWorkflowContext context)
{
this._pendingMessages.Add(message);
return default;
}
public async ValueTask TakeTurnAsync(TurnToken token, IWorkflowContext context)
{
// TODO: Ideally we want to be able to split the Run across multiple super-steps so that we can stream out
// incremental updates from the chat model.
AgentRunResponse runResponse = await this.Agent.RunAsync(messageList).ConfigureAwait(false);
AgentRunResponse runResponse = await this._agent.RunAsync(this._pendingMessages, this.EnsureThread())
.ConfigureAwait(false);
await context.AddEventAsync(new AgentRunEvent(this.Id, runResponse)).ConfigureAwait(false);
await context.SendMessageAsync(runResponse).ConfigureAwait(false);
if (token.EmitEvents)
{
await context.AddEventAsync(new AgentRunEvent(this.Id, runResponse)).ConfigureAwait(false);
}
await context.SendMessageAsync(runResponse.Messages.ToList()).ConfigureAwait(false);
await context.SendMessageAsync(token).ConfigureAwait(false);
}
}
@@ -37,13 +37,35 @@ public class StreamingRun
/// </summary>
/// <remarks>The response will be queued for processing for the next superstep.</remarks>
/// <param name="response">The <see cref="ExternalResponse"/> to send. Must not be <c>null</c>.</param>
/// <returns>A <see cref="ValueTask"/> that represents the asynchronous send operation. The task completes when the response
/// has been enqueued for processing, but will not wait for processing to complete.</returns>
/// <returns>A <see cref="ValueTask"/> that represents the asynchronous send operation.</returns>
public ValueTask SendResponseAsync(ExternalResponse response)
{
this._waitForResponseSource?.TrySetResult(new());
return this._stepRunner.EnqueueMessageAsync(response);
return this._stepRunner.EnqueueResponseAsync(response);
}
/// <summary>
/// Attempts to send the specified message asynchronously and returns a value indicating whether the operation was
/// successful.
/// </summary>
/// <typeparam name="TMessage">The type of the message to send. Must be compatible with the expected message types for
/// the starting executor, or receiving port.</typeparam>
/// <param name="message">The message instance to send. Cannot be null.</param>
/// <returns>A <see cref="ValueTask{Boolean}"/> that represents the asynchronous send operation. It's
/// <see cref="ValueTask{Boolean}.Result"/> is <see langword="true"/> if the message was sent
/// successfully; otherwise, <see langword="false"/>.</returns>
public async ValueTask<bool> TrySendMessageAsync<TMessage>(TMessage message)
{
Throw.IfNull(message);
if (message is ExternalResponse response)
{
await this.SendResponseAsync(response).ConfigureAwait(false);
return true;
}
return await this._stepRunner.EnqueueMessageAsync(message).ConfigureAwait(false);
}
/// <summary>
@@ -0,0 +1,18 @@
// Copyright (c) Microsoft. All rights reserved.
using Microsoft.Extensions.AI.Agents;
namespace Microsoft.Agents.Workflows;
/// <summary>
/// Sent to an <see cref="AIAgent"/>-based executor to request
/// a response to accumulated <see cref="Microsoft.Extensions.AI.ChatMessage"/>.
/// </summary>
/// <param name="emitEvents">Whether to raise AgentRunEvents for this executor.</param>
public class TurnToken(bool emitEvents = false)
{
/// <summary>
/// Gets a value indicating whether events are emitted by the receiving executor.
/// </summary>
public bool EmitEvents => emitEvents;
}
@@ -15,6 +15,59 @@ namespace Microsoft.Agents.Workflows;
/// promote common patterns for chaining and aggregating workflow steps.</remarks>
public static class WorkflowBuilderExtensions
{
/// <summary>
/// Adds edges to the workflow that forward messages of the specified type from the source executor to
/// one or more target executors.
/// </summary>
/// <typeparam name="TMessage">The type of message to forward.</typeparam>
/// <param name="builder">The <see cref="WorkflowBuilder"/> to which the edges will be added.</param>
/// <param name="source">The source executor from which messages will be forwarded.</param>
/// <param name="executors">The target executors to which messages will be forwarded.</param>
/// <returns>The updated <see cref="WorkflowBuilder"/> instance.</returns>
public static WorkflowBuilder ForwardMessage<TMessage>(this WorkflowBuilder builder, ExecutorIsh source, params ExecutorIsh[] executors)
{
Throw.IfNullOrEmpty(executors);
if (executors.Length == 1)
{
return builder.AddEdge(source, executors[0], IsAllowedType);
}
return builder.AddSwitch(source,
(switch_) =>
{
switch_.AddCase(IsAllowedType, executors);
});
bool IsAllowedType(object? message) => message is TMessage;
}
/// <summary>
/// Adds edges from the specified source to the provided executors, excluding messages of a specified type.
/// </summary>
/// <typeparam name="TMessage">The type of messages to exclude from being forwarded to the executors.</typeparam>
/// <param name="builder">The <see cref="WorkflowBuilder"/> instance to which the edges will be added.</param>
/// <param name="source">The source executor from which messages will be forwarded.</param>
/// <param name="executors">The target executors to which messages, except those of type <typeparamref name="TMessage"/>, will be forwarded.</param>
/// <returns>The updated <see cref="WorkflowBuilder"/> instance with the added edges.</returns>
public static WorkflowBuilder ForwardExcept<TMessage>(this WorkflowBuilder builder, ExecutorIsh source, params ExecutorIsh[] executors)
{
Throw.IfNullOrEmpty(executors);
if (executors.Length == 1)
{
return builder.AddEdge(source, executors[0], IsAllowedType);
}
return builder.AddSwitch(source,
(switch_) =>
{
switch_.AddCase(IsAllowedType, executors);
});
bool IsAllowedType(object? message) => message is not TMessage;
}
/// <summary>
/// Adds a sequential chain of executors to the workflow, connecting each executor in order so that each is
/// executed after the previous one.
@@ -0,0 +1,292 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.AI.Agents;
namespace Microsoft.Agents.Workflows.UnitTests.Sample;
internal static class Step6EntryPoint
{
internal static int MaxSteps { get; set; }
public static async ValueTask RunAsync(TextWriter writer, int maxSteps = 2)
{
Step6EntryPoint.MaxSteps = maxSteps;
GroupChatBuilder builder = GroupChatBuilder.Create<RoundRobinGroupChatManager>()
.AddParticipant(new HelloAgent(), shouldEmitEvents: true)
.AddParticipant(new EchoAgent(), shouldEmitEvents: true);
Workflow<List<ChatMessage>> workflow = builder.ReduceToWorkflow();
StreamingRun run = await InProcessExecution.StreamAsync(workflow, [])
.ConfigureAwait(false);
await run.TrySendMessageAsync(new TurnToken(emitEvents: true));
await foreach (WorkflowEvent evt in run.WatchStreamAsync().ConfigureAwait(false))
{
if (evt is ExecutorCompleteEvent executorComplete)
{
Debug.WriteLine($"{executorComplete.ExecutorId}: {executorComplete.Data}");
}
else if (evt is AgentRunEvent agentRun && agentRun.Data is AgentRunResponse response)
{
foreach (ChatMessage message in response.Messages)
{
writer.WriteLine($"{agentRun.ExecutorId}: {message.Text}");
}
}
}
}
private sealed class RoundRobinGroupChatManager : GroupChatManager
{
public int TurnCount { get; private set; } = 0;
public int MaxTurns { get; init; } = Step6EntryPoint.MaxSteps;
public override int? GetNextTurnExecutor(GroupChatHistory history)
{
if (this.ParticipantIds.Length == 0)
{
throw new InvalidOperationException("No participants in the group chat.");
}
if (this.TurnCount >= this.MaxTurns)
{
return null;
}
return this.TurnCount++ % this.ParticipantIds.Length;
}
}
}
internal sealed class HelloAgent(string id = nameof(HelloAgent)) : AIAgent
{
public const string Greeting = "Hello World!";
public const string DefaultId = nameof(HelloAgent);
public override string Id => id;
public override Task<AgentRunResponse> RunAsync(IReadOnlyCollection<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
{
AgentRunResponse response = new(new ChatMessage(ChatRole.Assistant, "Hello World!"));
return Task.FromResult(response);
}
public override IAsyncEnumerable<AgentRunResponseUpdate> RunStreamingAsync(IReadOnlyCollection<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
}
}
internal sealed class EchoAgent(string id = nameof(EchoAgent)) : AIAgent
{
public const string Prefix = "You said: ";
public const string DefaultId = nameof(EchoAgent);
public override string Id => id;
public override Task<AgentRunResponse> RunAsync(IReadOnlyCollection<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
{
if (messages.Count == 0)
{
throw new ArgumentException("No messages provided to echo.", nameof(messages));
}
StringBuilder collectedText = new(Prefix);
foreach (string messageText in messages.Select(message => message.Text)
.Where(text => !string.IsNullOrEmpty(text)))
{
collectedText.AppendLine(messageText);
}
AgentRunResponse result = new(new ChatMessage(ChatRole.Assistant, collectedText.ToString()));
return Task.FromResult(result);
}
public override IAsyncEnumerable<AgentRunResponseUpdate> RunStreamingAsync(IReadOnlyCollection<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
}
}
internal sealed class GroupChatHistory
{
private readonly List<ChatMessage> _messages = new();
private int _bookmark = 0;
public void AddMessage(ChatMessage message)
{
this._messages.Add(message);
}
public void AddMessages(IEnumerable<ChatMessage> messages)
{
this._messages.AddRange(messages);
}
public void UpdateBookmark()
{
this._bookmark = this._messages.Count;
}
public IReadOnlyList<ChatMessage> FullHistory => this._messages.AsReadOnly();
public IEnumerable<ChatMessage> NewMessagesThisTurn => this._messages.Skip(this._bookmark);
}
internal abstract class GroupChatManager
{
public string[] ParticipantIds { get; internal init; } = [];
public abstract int? GetNextTurnExecutor(GroupChatHistory history);
}
internal sealed class GroupChatBuilder
{
private readonly List<ExecutorIsh> _participants = new();
private readonly List<bool> _shouldEmitEvents = new();
private readonly Func<string[], GroupChatManager> _managerFactory;
private GroupChatBuilder(Func<string[], GroupChatManager> managerFactory)
{
this._managerFactory = managerFactory;
}
public static GroupChatBuilder Create<TManager>() where TManager : GroupChatManager, new()
{
return new GroupChatBuilder(participantIds => new TManager() { ParticipantIds = participantIds });
}
public GroupChatBuilder AddParticipant(ExecutorIsh executor, bool shouldEmitEvents = false)
{
this._participants.Add(executor);
this._shouldEmitEvents.Add(shouldEmitEvents);
return this;
}
public GroupChatBuilder AddParticipants(params ExecutorIsh[] executors)
{
this._participants.AddRange(executors);
return this;
}
public Workflow<List<ChatMessage>> ReduceToWorkflow()
{
string[] participantIds = this._participants.Select(identified => identified.Id).ToArray();
GroupChatHost host = new(this._shouldEmitEvents.ToArray(), this._managerFactory(participantIds));
WorkflowBuilder builder = new WorkflowBuilder(host)
.AddFanOutEdge(host, targets: this._participants.ToArray());
foreach (ExecutorIsh participant in this._participants)
{
builder.AddEdge(participant, host);
}
return builder.Build<List<ChatMessage>>();
//bool IsMessageType(object? message) => message is ChatMessage || message is IEnumerable<ChatMessage>;
}
private sealed class TurnAssignedEvent(string executorId, string nextSpeakerId) : ExecutorEvent(executorId, data: nextSpeakerId);
private sealed class GroupChatHost : Executor
{
private readonly bool[] _shouldEmitEvents;
private readonly GroupChatManager _manager;
private readonly bool _autoStartConversation;
private readonly GroupChatHistory _history = new();
public GroupChatHost(bool[] shouldEmitEvents, GroupChatManager manager, bool autoStartConversation = false) : base(nameof(GroupChatHost))
{
this._shouldEmitEvents = shouldEmitEvents;
this._manager = manager ?? throw new ArgumentNullException(nameof(manager));
this._autoStartConversation = autoStartConversation;
}
protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder)
{
return routeBuilder.AddHandler<List<ChatMessage>>(this.HandleChatMessagesAsync)
.AddHandler<ChatMessage>(this.HandleChatMessageAsync)
.AddHandler<TurnToken>(this.AssignNextTurnAsync);
}
private async Task TryAutoStartConversationAsync(IWorkflowContext context)
{
if (this._autoStartConversation && this.TryEnterConversation())
{
await this.AssignNextTurnAsync(new TurnToken(emitEvents: false), context).ConfigureAwait(false);
}
}
private async ValueTask HandleChatMessagesAsync(List<ChatMessage> initialMessages, IWorkflowContext context)
{
this._history.AddMessages(initialMessages);
await context.SendMessageAsync(initialMessages).ConfigureAwait(false);
await this.TryAutoStartConversationAsync(context).ConfigureAwait(false);
}
private async ValueTask HandleChatMessageAsync(ChatMessage message, IWorkflowContext context)
{
// First, add the message to the history, then forward to all executors
this._history.AddMessage(message);
await context.SendMessageAsync(message).ConfigureAwait(false);
await this.TryAutoStartConversationAsync(context).ConfigureAwait(false);
}
private int _inConversationFlag = 0;
/// <summary>
/// Atomically switches to "in conversation" state if not already in that state.
/// </summary>
/// <returns><see langword="true"/> if the state was changed, <see langword="false"/> otherwise.</returns>
private bool TryEnterConversation()
{
return Interlocked.CompareExchange(ref this._inConversationFlag, 1, 0) == 0;
}
private bool _shouldHostEmitEvents = false;
private async ValueTask AssignNextTurnAsync(TurnToken token, IWorkflowContext context)
{
if (this.TryEnterConversation())
{
// Capture the initial turn token's EmitEvents setting
this._shouldHostEmitEvents = token.EmitEvents;
}
int? nextSpeakerIndex = this._manager.GetNextTurnExecutor(this._history);
if (nextSpeakerIndex == null)
{
await context.AddEventAsync(new WorkflowCompletedEvent())
.ConfigureAwait(false);
return;
}
string nextSpeakerId = this._manager.ParticipantIds[nextSpeakerIndex.Value];
if (this._shouldHostEmitEvents)
{
await context.AddEventAsync(new TurnAssignedEvent(this.Id, nextSpeakerId))
.ConfigureAwait(false);
}
await context.SendMessageAsync(new TurnToken(this._shouldEmitEvents[nextSpeakerIndex.Value]), nextSpeakerId)
.ConfigureAwait(false);
}
}
}
@@ -5,6 +5,7 @@ using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Agents.Workflows.Sample;
using Microsoft.Agents.Workflows.UnitTests.Sample;
namespace Microsoft.Agents.Workflows.UnitTests;
@@ -83,6 +84,22 @@ public class SampleSmokeTest
string guessResult = await Step5EntryPoint.RunAsync(writer, userGuessCallback: responder.InvokeNext);
Assert.Equal("You guessed correctly! You Win!", guessResult);
}
[Fact]
public async Task Test_RunSample_Step6Async()
{
using StringWriter writer = new();
await Step6EntryPoint.RunAsync(writer);
string result = writer.ToString();
string[] lines = result.Split([Environment.NewLine], StringSplitOptions.RemoveEmptyEntries);
Assert.Collection(lines,
line => Assert.Contains($"{HelloAgent.DefaultId}: {HelloAgent.Greeting}", line),
line => Assert.Contains($"{EchoAgent.DefaultId}: {EchoAgent.Prefix}{HelloAgent.Greeting}", line)
);
}
}
internal sealed class VerifyingPlaybackResponder<TInput, TResponse>