From baaa9c0aee7276ea07d16286abc9801e935fb608 Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Tue, 19 Aug 2025 16:04:03 -0400 Subject: [PATCH] .NET: feat: Improve Support for AIAgent-as-Executor (#432) * feat: Support Executor-targeted messages This adds support for only sending a message to a given executor. Messages will still only route through connected edges. * feat: Support sending all valid input types after starting a run * feat: Normalize AIAgent-as-Executor Message Protocol to use MEAI types --- .../Execution/DirectEdgeRunner.cs | 12 +- .../Execution/EdgeMap.cs | 8 +- .../Execution/FanInEdgeRunner.cs | 11 +- .../Execution/FanOutEdgeRunner.cs | 19 +- .../Execution/IRunnerContext.cs | 2 +- .../Execution/ISuperStepRunner.cs | 3 +- .../Execution/InputEdgeRunner.cs | 10 +- .../Execution/MessageEnvelope.cs | 12 + .../Execution/MessageRouter.cs | 4 +- .../Execution/StepContext.cs | 7 +- .../Microsoft.Agents.Workflows/Executor.cs | 4 +- .../IWorkflowContext.cs | 5 +- .../InProc/InProcessRunner.cs | 47 ++- .../InProc/InProcessRunnerContext.cs | 19 +- dotnet/src/Microsoft.Agents.Workflows/Run.cs | 22 ++ .../Specialized/AIAgentHostExecutor.cs | 54 +++- .../StreamingRun.cs | 28 +- .../Microsoft.Agents.Workflows/TurnToken.cs | 18 ++ .../WorkflowBuilderExtensions.cs | 53 ++++ .../Sample/06_GroupChat_Workflow.cs | 292 ++++++++++++++++++ .../SampleSmokeTest.cs | 17 + 21 files changed, 595 insertions(+), 52 deletions(-) create mode 100644 dotnet/src/Microsoft.Agents.Workflows/Execution/MessageEnvelope.cs create mode 100644 dotnet/src/Microsoft.Agents.Workflows/TurnToken.cs create mode 100644 dotnet/tests/Microsoft.Agents.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs diff --git a/dotnet/src/Microsoft.Agents.Workflows/Execution/DirectEdgeRunner.cs b/dotnet/src/Microsoft.Agents.Workflows/Execution/DirectEdgeRunner.cs index 5b9efc1fa3..30936d116d 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Execution/DirectEdgeRunner.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Execution/DirectEdgeRunner.cs @@ -16,17 +16,23 @@ internal class DirectEdgeRunner(IRunnerContext runContext, DirectEdgeData edgeDa .ConfigureAwait(false); } - public async ValueTask> ChaseAsync(object message) + public async ValueTask> ChaseAsync(MessageEnvelope envelope) { + if (envelope.TargetId != null && this.EdgeData.SinkId != envelope.TargetId) + { + return []; + } + + object message = envelope.Message; if (this.EdgeData.Condition != null && !this.EdgeData.Condition(message)) { return []; } Executor target = await this.FindRouterAsync().ConfigureAwait(false); - if (target.CanHandle(message.GetType())) + if (target.CanHandle(envelope.MessageType)) { - return [await target.ExecuteAsync(message, this.WorkflowContext).ConfigureAwait(false)]; + return [await target.ExecuteAsync(message, envelope.MessageType, this.WorkflowContext).ConfigureAwait(false)]; } return []; diff --git a/dotnet/src/Microsoft.Agents.Workflows/Execution/EdgeMap.cs b/dotnet/src/Microsoft.Agents.Workflows/Execution/EdgeMap.cs index 046382631d..7e6f961182 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Execution/EdgeMap.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Execution/EdgeMap.cs @@ -40,7 +40,7 @@ internal class EdgeMap this._inputRunner = new InputEdgeRunner(runContext, startExecutorId); } - public async ValueTask> InvokeEdgeAsync(Edge edge, string sourceId, object message) + public async ValueTask> InvokeEdgeAsync(Edge edge, string sourceId, MessageEnvelope message) { if (!this._edgeRunners.TryGetValue(edge, out object? edgeRunner)) { @@ -87,9 +87,9 @@ internal class EdgeMap } // TODO: Should we promote Input to a true "FlowEdge" type? - public async ValueTask> InvokeInputAsync(object inputMessage) + public async ValueTask> InvokeInputAsync(MessageEnvelope envelope) { - return [await this._inputRunner.ChaseAsync(inputMessage).ConfigureAwait(false)]; + return [await this._inputRunner.ChaseAsync(envelope).ConfigureAwait(false)]; } public async ValueTask> InvokeResponseAsync(ExternalResponse response) @@ -99,6 +99,6 @@ internal class EdgeMap throw new InvalidOperationException($"Port {response.Port.Id} not found in the edge map."); } - return [await portRunner.ChaseAsync(response).ConfigureAwait(false)]; + return [await portRunner.ChaseAsync(new MessageEnvelope(response)).ConfigureAwait(false)]; } } diff --git a/dotnet/src/Microsoft.Agents.Workflows/Execution/FanInEdgeRunner.cs b/dotnet/src/Microsoft.Agents.Workflows/Execution/FanInEdgeRunner.cs index 1ac7a09a6f..dfce8c7984 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Execution/FanInEdgeRunner.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Execution/FanInEdgeRunner.cs @@ -12,8 +12,15 @@ internal class FanInEdgeRunner(IRunnerContext runContext, FanInEdgeData edgeData public FanInEdgeState CreateState() => new(this.EdgeData); - public async ValueTask ChaseAsync(string sourceId, object message, FanInEdgeState state) + public async ValueTask ChaseAsync(string sourceId, MessageEnvelope envelope, FanInEdgeState state) { + if (envelope.TargetId != null && this.EdgeData.SinkId != envelope.TargetId) + { + // This message is not for us. + return null; + } + + object message = envelope.Message; IEnumerable? releasedMessages = state.ProcessMessage(sourceId, message); if (releasedMessages is null) { @@ -26,7 +33,7 @@ internal class FanInEdgeRunner(IRunnerContext runContext, FanInEdgeData edgeData if (target.CanHandle(message.GetType())) { - return await target.ExecuteAsync(message, this.BoundContext) + return await target.ExecuteAsync(message, envelope.MessageType, this.BoundContext) .ConfigureAwait(false); } return null; diff --git a/dotnet/src/Microsoft.Agents.Workflows/Execution/FanOutEdgeRunner.cs b/dotnet/src/Microsoft.Agents.Workflows/Execution/FanOutEdgeRunner.cs index 5d6d6f4780..32a91b1f45 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Execution/FanOutEdgeRunner.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Execution/FanOutEdgeRunner.cs @@ -14,14 +14,20 @@ internal class FanOutEdgeRunner(IRunnerContext runContext, FanOutEdgeData edgeDa sinkId => sinkId, sinkId => runContext.Bind(sinkId)); - public async ValueTask> ChaseAsync(object message) + public async ValueTask> ChaseAsync(MessageEnvelope envelope) { + object message = envelope.Message; List targets = this.EdgeData.PartitionAssigner == null ? this.EdgeData.SinkIds - : this.EdgeData.PartitionAssigner(message, this.BoundContexts.Count).Select(i => this.EdgeData.SinkIds[i]).ToList(); + : this.EdgeData.PartitionAssigner(message, this.BoundContexts.Count) + .Select(i => this.EdgeData.SinkIds[i]).ToList(); - object?[] result = await Task.WhenAll(targets.Select(ProcessTargetAsync)).ConfigureAwait(false); + IEnumerable filteredTargets = envelope.TargetId != null + ? targets.Where(IsValidTarget) + : targets; + + object?[] result = await Task.WhenAll(filteredTargets.Select(ProcessTargetAsync)).ConfigureAwait(false); return result.Where(r => r is not null); async Task ProcessTargetAsync(string targetId) @@ -31,11 +37,16 @@ internal class FanOutEdgeRunner(IRunnerContext runContext, FanOutEdgeData edgeDa if (executor.CanHandle(message.GetType())) { - return await executor.ExecuteAsync(message, this.BoundContexts[targetId]) + return await executor.ExecuteAsync(message, envelope.MessageType, this.BoundContexts[targetId]) .ConfigureAwait(false); } return null; } + + bool IsValidTarget(string targetId) + { + return envelope.TargetId == null || targetId == envelope.TargetId; + } } } diff --git a/dotnet/src/Microsoft.Agents.Workflows/Execution/IRunnerContext.cs b/dotnet/src/Microsoft.Agents.Workflows/Execution/IRunnerContext.cs index d9cfc09d74..8ee732b02c 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Execution/IRunnerContext.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Execution/IRunnerContext.cs @@ -7,7 +7,7 @@ namespace Microsoft.Agents.Workflows.Execution; internal interface IRunnerContext : IExternalRequestSink { ValueTask AddEventAsync(WorkflowEvent workflowEvent); - ValueTask SendMessageAsync(string executorId, object message); + ValueTask SendMessageAsync(string sourceId, object message, string? targetId = null); // TODO: State Management diff --git a/dotnet/src/Microsoft.Agents.Workflows/Execution/ISuperStepRunner.cs b/dotnet/src/Microsoft.Agents.Workflows/Execution/ISuperStepRunner.cs index 13f7a47346..4a1da0f3be 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Execution/ISuperStepRunner.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Execution/ISuperStepRunner.cs @@ -11,7 +11,8 @@ internal interface ISuperStepRunner bool HasUnservicedRequests { get; } bool HasUnprocessedMessages { get; } - ValueTask EnqueueMessageAsync(object message); + ValueTask EnqueueResponseAsync(ExternalResponse response); + ValueTask EnqueueMessageAsync(T message); event EventHandler? WorkflowEvent; diff --git a/dotnet/src/Microsoft.Agents.Workflows/Execution/InputEdgeRunner.cs b/dotnet/src/Microsoft.Agents.Workflows/Execution/InputEdgeRunner.cs index aaff94df5d..33d9b0e0f0 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Execution/InputEdgeRunner.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Execution/InputEdgeRunner.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Diagnostics; using System.Threading.Tasks; using Microsoft.Shared.Diagnostics; @@ -23,16 +24,17 @@ internal class InputEdgeRunner(IRunnerContext runContext, string sinkId) return await this.RunContext.EnsureExecutorAsync(this.EdgeData).ConfigureAwait(false); } - public async ValueTask ChaseAsync(object message) + public async ValueTask ChaseAsync(MessageEnvelope envelope) { Executor target = await this.FindExecutorAsync().ConfigureAwait(false); - if (target.CanHandle(message.GetType())) + if (target.CanHandle(envelope.MessageType)) { - return await target.ExecuteAsync(message, this.WorkflowContext) + return await target.ExecuteAsync(envelope.Message, envelope.MessageType, this.WorkflowContext) .ConfigureAwait(false); } - // TODO: Throw instead? + // TODO: Throw instead? / Log + Debug.WriteLine($"Executor {target.Id} cannot handle message of type {envelope.MessageType.FullName}. Dropping."); return null; } diff --git a/dotnet/src/Microsoft.Agents.Workflows/Execution/MessageEnvelope.cs b/dotnet/src/Microsoft.Agents.Workflows/Execution/MessageEnvelope.cs new file mode 100644 index 0000000000..829a91625a --- /dev/null +++ b/dotnet/src/Microsoft.Agents.Workflows/Execution/MessageEnvelope.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; + +namespace Microsoft.Agents.Workflows.Execution; + +internal sealed class MessageEnvelope(object message, Type? declaredType = null, string? targetId = null) +{ + public Type MessageType => declaredType ?? message.GetType(); + public object Message => message; + public string? TargetId => targetId; +} diff --git a/dotnet/src/Microsoft.Agents.Workflows/Execution/MessageRouter.cs b/dotnet/src/Microsoft.Agents.Workflows/Execution/MessageRouter.cs index fcf3f20412..2731f02b52 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Execution/MessageRouter.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Execution/MessageRouter.cs @@ -23,9 +23,11 @@ internal class MessageRouter { this._typedHandlers = Throw.IfNull(handlers); this._hasCatchall = this._typedHandlers.ContainsKey(typeof(object)); + + this.IncomingTypes = [.. this._typedHandlers.Keys]; } - public HashSet IncomingTypes => [.. this._typedHandlers.Keys]; + public HashSet IncomingTypes { get; } public bool CanHandle(object message) => this.CanHandle(Throw.IfNull(message).GetType()); diff --git a/dotnet/src/Microsoft.Agents.Workflows/Execution/StepContext.cs b/dotnet/src/Microsoft.Agents.Workflows/Execution/StepContext.cs index 07d30267ed..8b07d6339d 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Execution/StepContext.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Execution/StepContext.cs @@ -7,16 +7,15 @@ namespace Microsoft.Agents.Workflows.Execution; internal class StepContext { - public Dictionary> QueuedMessages { get; } = new(); + public Dictionary> QueuedMessages { get; } = new(); public bool HasMessages => this.QueuedMessages.Values.Any(messageList => messageList.Count > 0); - public List MessagesFor(string? executorId) + public List MessagesFor(string? executorId) { if (!this.QueuedMessages.TryGetValue(executorId, out var messages)) { - messages = new List(); - this.QueuedMessages[executorId] = messages; + this.QueuedMessages[executorId] = messages = new(); } return messages; diff --git a/dotnet/src/Microsoft.Agents.Workflows/Executor.cs b/dotnet/src/Microsoft.Agents.Workflows/Executor.cs index 4dc81b18b1..194632b56a 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Executor.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Executor.cs @@ -54,11 +54,13 @@ public abstract class Executor : IIdentified /// Process an incoming message using the registered handlers. /// /// The message to be processed by the executor. + /// The "declared" type of the message (captured when it was being sent). This is + /// used to enable routing messages as their base types, in absence of true polymorphic type routing. /// The workflow context in which the executor executes. /// A ValueTask representing the asynchronous operation, wrapping the output from the executor. /// No handler found for the message type. /// An exception is generated while handling the message. - public async ValueTask ExecuteAsync(object message, IWorkflowContext context) + public async ValueTask ExecuteAsync(object message, Type messageType, IWorkflowContext context) { await context.AddEventAsync(new ExecutorInvokeEvent(this.Id, message)).ConfigureAwait(false); diff --git a/dotnet/src/Microsoft.Agents.Workflows/IWorkflowContext.cs b/dotnet/src/Microsoft.Agents.Workflows/IWorkflowContext.cs index 1c83d81108..b3199d0c32 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/IWorkflowContext.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/IWorkflowContext.cs @@ -21,8 +21,11 @@ public interface IWorkflowContext /// Queues a message to be sent to connected executors. The message will be sent during the next SuperStep. /// /// The message to be sent. + /// An optional identifier of the target executor. If null, the message is sent to all connected + /// executors. If the target executor is not connected from this executor via an edge, it will still not receive the + /// message. /// A representing the asynchronous operation. - ValueTask SendMessageAsync(object message); + ValueTask SendMessageAsync(object message, string? targetId = null); /// /// Reads a state value from the workflow's state store. If no scope is provided, the executor's private diff --git a/dotnet/src/Microsoft.Agents.Workflows/InProc/InProcessRunner.cs b/dotnet/src/Microsoft.Agents.Workflows/InProc/InProcessRunner.cs index 8a2879de37..29217a099c 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/InProc/InProcessRunner.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/InProc/InProcessRunner.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -29,9 +30,38 @@ internal class InProcessRunner : ISuperStepRunner where TInput : notnull this.EdgeMap = new EdgeMap(this.RunContext, this.Workflow.Edges, this.Workflow.Ports.Values, this.Workflow.StartExecutorId); } - ValueTask ISuperStepRunner.EnqueueMessageAsync(object message) + public async ValueTask IsValidInputAsync(TMessage message) { - return this.RunContext.AddExternalMessageAsync(message); + Throw.IfNull(message); + + Type type = typeof(TMessage); + + // Short circuit the logic if the type is the input type + if (type == typeof(TInput)) + { + return true; + } + + Executor startingExecutor = await this.RunContext.EnsureExecutorAsync(this.Workflow.StartExecutorId).ConfigureAwait(false); + return startingExecutor.CanHandle(type); + } + + async ValueTask ISuperStepRunner.EnqueueMessageAsync(T message) + { + // Check that the type of the incoming message is compatible with the starting executor's + // input type. + if (!await this.IsValidInputAsync(message).ConfigureAwait(false)) + { + return false; + } + + await this.RunContext.AddExternalMessageAsync(message).ConfigureAwait(false); + return true; + } + + ValueTask ISuperStepRunner.EnqueueResponseAsync(ExternalResponse response) + { + return this.RunContext.AddExternalMessageAsync(response); } private Dictionary PendingCalls { get; } = new(); @@ -57,11 +87,14 @@ internal class InProcessRunner : ISuperStepRunner where TInput : notnull return message is ExternalResponse; } - private ValueTask> RouteExternalMessageAsync(object message) + private ValueTask> RouteExternalMessageAsync(MessageEnvelope envelope) { + Debug.Assert(envelope.TargetId == null, "External Messages cannot be targeted to a specific executor."); + + object message = envelope.Message; return message is ExternalResponse response ? this.CompleteExternalResponseAsync(response) - : this.EdgeMap.InvokeInputAsync(message); + : this.EdgeMap.InvokeInputAsync(envelope); } private ValueTask> CompleteExternalResponseAsync(ExternalResponse response) @@ -113,16 +146,16 @@ internal class InProcessRunner : ISuperStepRunner where TInput : notnull List>> edgeTasks = new(); foreach (ExecutorIdentity sender in currentStep.QueuedMessages.Keys) { - IEnumerable senderMessages = currentStep.QueuedMessages[sender]; + IEnumerable senderMessages = currentStep.QueuedMessages[sender]; if (sender.Id is null) { - edgeTasks.AddRange(senderMessages.Select(message => this.RouteExternalMessageAsync(message).AsTask())); + edgeTasks.AddRange(senderMessages.Select(envelope => this.RouteExternalMessageAsync(envelope).AsTask())); } else if (this.Workflow.Edges.TryGetValue(sender.Id!, out HashSet? outgoingEdges)) { foreach (Edge outgoingEdge in outgoingEdges) { - edgeTasks.AddRange(senderMessages.Select(message => this.EdgeMap.InvokeEdgeAsync(outgoingEdge, sender.Id, message).AsTask())); + edgeTasks.AddRange(senderMessages.Select(envelope => this.EdgeMap.InvokeEdgeAsync(outgoingEdge, sender.Id, envelope).AsTask())); } } } diff --git a/dotnet/src/Microsoft.Agents.Workflows/InProc/InProcessRunnerContext.cs b/dotnet/src/Microsoft.Agents.Workflows/InProc/InProcessRunnerContext.cs index ff25293cb6..a797832524 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/InProc/InProcessRunnerContext.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/InProc/InProcessRunnerContext.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Tasks; using Microsoft.Agents.Workflows.Execution; @@ -44,11 +43,19 @@ internal class InProcessRunnerContext : IRunnerContext return executor; } - public ValueTask AddExternalMessageAsync([NotNull] object message) + public ValueTask AddExternalMessageUntypedAsync(object message) { Throw.IfNull(message); - this._nextStep.MessagesFor(ExecutorIdentity.None).Add(message); + this._nextStep.MessagesFor(ExecutorIdentity.None).Add(new MessageEnvelope(message)); + return default; + } + + public ValueTask AddExternalMessageAsync(T message) + { + Throw.IfNull(message); + + this._nextStep.MessagesFor(ExecutorIdentity.None).Add(new MessageEnvelope(message, declaredType: typeof(T))); return default; } @@ -66,9 +73,9 @@ internal class InProcessRunnerContext : IRunnerContext return default; } - public ValueTask SendMessageAsync(string executorId, object message) + public ValueTask SendMessageAsync(string sourceId, object message, string? targetId = null) { - this._nextStep.MessagesFor(executorId).Add(message); + this._nextStep.MessagesFor(sourceId).Add(new MessageEnvelope(message, targetId: targetId)); return default; } @@ -92,7 +99,7 @@ internal class InProcessRunnerContext : IRunnerContext private class BoundContext(InProcessRunnerContext RunnerContext, string ExecutorId) : IWorkflowContext { public ValueTask AddEventAsync(WorkflowEvent workflowEvent) => RunnerContext.AddEventAsync(workflowEvent); - public ValueTask SendMessageAsync(object message) => RunnerContext.SendMessageAsync(ExecutorId, message); + public ValueTask SendMessageAsync(object message, string? targetId = null) => RunnerContext.SendMessageAsync(ExecutorId, message, targetId); public ValueTask QueueStateUpdateAsync(string key, T? value, string? scopeName = null) => RunnerContext.StateManager.WriteStateAsync(ExecutorId, scopeName, key, value); diff --git a/dotnet/src/Microsoft.Agents.Workflows/Run.cs b/dotnet/src/Microsoft.Agents.Workflows/Run.cs index fa520e092e..93e031c10f 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Run.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Run.cs @@ -127,6 +127,28 @@ public class Run return await this.RunToNextHaltAsync(cancellation).ConfigureAwait(false); } + + /// + /// Resume execution of the workflow with the provided external responses. + /// + /// A that can be used to cancel the workflow execution. + /// An array of messages to send to the workflow. Messages will only be sent if they are valid + /// input types to the starting executor or a . + /// true if the workflow had any output events, false otherwise. + public async ValueTask ResumeAsync(CancellationToken cancellation = default, params T[] messages) + { + if (messages is ExternalResponse[] responses) + { + return await this.ResumeAsync(cancellation, responses).ConfigureAwait(false); + } + + foreach (T message in messages) + { + await this._streamingRun.TrySendMessageAsync(message).ConfigureAwait(false); + } + + return await this.RunToNextHaltAsync(cancellation).ConfigureAwait(false); + } } /// diff --git a/dotnet/src/Microsoft.Agents.Workflows/Specialized/AIAgentHostExecutor.cs b/dotnet/src/Microsoft.Agents.Workflows/Specialized/AIAgentHostExecutor.cs index deb7216b73..fbb97d8802 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Specialized/AIAgentHostExecutor.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Specialized/AIAgentHostExecutor.cs @@ -3,30 +3,64 @@ using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; -using Microsoft.Agents.Workflows.Reflection; using Microsoft.Extensions.AI; using Microsoft.Extensions.AI.Agents; namespace Microsoft.Agents.Workflows.Specialized; -internal class AIAgentHostExecutor : ReflectingExecutor, IMessageHandler> +internal class AIAgentHostExecutor : Executor { - private AIAgent Agent { get; set; } + private readonly AIAgent _agent; + private readonly List _pendingMessages = new(); + private AgentThread? _thread = null; - public AIAgentHostExecutor(AIAgent agent) + public AIAgentHostExecutor(AIAgent agent) : base(id: agent.Id) { - this.Agent = agent; + this._agent = agent; } - public async ValueTask HandleAsync(IList message, IWorkflowContext context) + private AgentThread EnsureThread() { - IReadOnlyCollection messageList = (message as List ?? message.ToList()).AsReadOnly(); + if (this._thread != null) + { + return this._thread; + } + return this._thread = this._agent.GetNewThread(); + } + + protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) + { + return routeBuilder.AddHandler(this.QueueMessageAsync) + .AddHandler>(this.QueueMessagesAsync) + .AddHandler(this.TakeTurnAsync); + } + + public ValueTask QueueMessagesAsync(List messages, IWorkflowContext context) + { + this._pendingMessages.AddRange(messages); + return default; + } + + public ValueTask QueueMessageAsync(ChatMessage message, IWorkflowContext context) + { + this._pendingMessages.Add(message); + return default; + } + + public async ValueTask TakeTurnAsync(TurnToken token, IWorkflowContext context) + { // TODO: Ideally we want to be able to split the Run across multiple super-steps so that we can stream out // incremental updates from the chat model. - AgentRunResponse runResponse = await this.Agent.RunAsync(messageList).ConfigureAwait(false); + AgentRunResponse runResponse = await this._agent.RunAsync(this._pendingMessages, this.EnsureThread()) + .ConfigureAwait(false); - await context.AddEventAsync(new AgentRunEvent(this.Id, runResponse)).ConfigureAwait(false); - await context.SendMessageAsync(runResponse).ConfigureAwait(false); + if (token.EmitEvents) + { + await context.AddEventAsync(new AgentRunEvent(this.Id, runResponse)).ConfigureAwait(false); + } + + await context.SendMessageAsync(runResponse.Messages.ToList()).ConfigureAwait(false); + await context.SendMessageAsync(token).ConfigureAwait(false); } } diff --git a/dotnet/src/Microsoft.Agents.Workflows/StreamingRun.cs b/dotnet/src/Microsoft.Agents.Workflows/StreamingRun.cs index 3a6c51d85b..f5c6943148 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/StreamingRun.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/StreamingRun.cs @@ -37,13 +37,35 @@ public class StreamingRun /// /// The response will be queued for processing for the next superstep. /// The to send. Must not be null. - /// A that represents the asynchronous send operation. The task completes when the response - /// has been enqueued for processing, but will not wait for processing to complete. + /// A that represents the asynchronous send operation. public ValueTask SendResponseAsync(ExternalResponse response) { this._waitForResponseSource?.TrySetResult(new()); - return this._stepRunner.EnqueueMessageAsync(response); + return this._stepRunner.EnqueueResponseAsync(response); + } + + /// + /// Attempts to send the specified message asynchronously and returns a value indicating whether the operation was + /// successful. + /// + /// The type of the message to send. Must be compatible with the expected message types for + /// the starting executor, or receiving port. + /// The message instance to send. Cannot be null. + /// A that represents the asynchronous send operation. It's + /// is if the message was sent + /// successfully; otherwise, . + public async ValueTask TrySendMessageAsync(TMessage message) + { + Throw.IfNull(message); + + if (message is ExternalResponse response) + { + await this.SendResponseAsync(response).ConfigureAwait(false); + return true; + } + + return await this._stepRunner.EnqueueMessageAsync(message).ConfigureAwait(false); } /// diff --git a/dotnet/src/Microsoft.Agents.Workflows/TurnToken.cs b/dotnet/src/Microsoft.Agents.Workflows/TurnToken.cs new file mode 100644 index 0000000000..a561607799 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.Workflows/TurnToken.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI.Agents; + +namespace Microsoft.Agents.Workflows; + +/// +/// Sent to an -based executor to request +/// a response to accumulated . +/// +/// Whether to raise AgentRunEvents for this executor. +public class TurnToken(bool emitEvents = false) +{ + /// + /// Gets a value indicating whether events are emitted by the receiving executor. + /// + public bool EmitEvents => emitEvents; +} diff --git a/dotnet/src/Microsoft.Agents.Workflows/WorkflowBuilderExtensions.cs b/dotnet/src/Microsoft.Agents.Workflows/WorkflowBuilderExtensions.cs index cbae3e1c45..fcfebaebd4 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/WorkflowBuilderExtensions.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/WorkflowBuilderExtensions.cs @@ -15,6 +15,59 @@ namespace Microsoft.Agents.Workflows; /// promote common patterns for chaining and aggregating workflow steps. public static class WorkflowBuilderExtensions { + /// + /// Adds edges to the workflow that forward messages of the specified type from the source executor to + /// one or more target executors. + /// + /// The type of message to forward. + /// The to which the edges will be added. + /// The source executor from which messages will be forwarded. + /// The target executors to which messages will be forwarded. + /// The updated instance. + public static WorkflowBuilder ForwardMessage(this WorkflowBuilder builder, ExecutorIsh source, params ExecutorIsh[] executors) + { + Throw.IfNullOrEmpty(executors); + + if (executors.Length == 1) + { + return builder.AddEdge(source, executors[0], IsAllowedType); + } + + return builder.AddSwitch(source, + (switch_) => + { + switch_.AddCase(IsAllowedType, executors); + }); + + bool IsAllowedType(object? message) => message is TMessage; + } + + /// + /// Adds edges from the specified source to the provided executors, excluding messages of a specified type. + /// + /// The type of messages to exclude from being forwarded to the executors. + /// The instance to which the edges will be added. + /// The source executor from which messages will be forwarded. + /// The target executors to which messages, except those of type , will be forwarded. + /// The updated instance with the added edges. + public static WorkflowBuilder ForwardExcept(this WorkflowBuilder builder, ExecutorIsh source, params ExecutorIsh[] executors) + { + Throw.IfNullOrEmpty(executors); + + if (executors.Length == 1) + { + return builder.AddEdge(source, executors[0], IsAllowedType); + } + + return builder.AddSwitch(source, + (switch_) => + { + switch_.AddCase(IsAllowedType, executors); + }); + + bool IsAllowedType(object? message) => message is not TMessage; + } + /// /// Adds a sequential chain of executors to the workflow, connecting each executor in order so that each is /// executed after the previous one. diff --git a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs new file mode 100644 index 0000000000..02ebc76ab0 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs @@ -0,0 +1,292 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.AI.Agents; + +namespace Microsoft.Agents.Workflows.UnitTests.Sample; + +internal static class Step6EntryPoint +{ + internal static int MaxSteps { get; set; } + + public static async ValueTask RunAsync(TextWriter writer, int maxSteps = 2) + { + Step6EntryPoint.MaxSteps = maxSteps; + + GroupChatBuilder builder = GroupChatBuilder.Create() + .AddParticipant(new HelloAgent(), shouldEmitEvents: true) + .AddParticipant(new EchoAgent(), shouldEmitEvents: true); + + Workflow> workflow = builder.ReduceToWorkflow(); + + StreamingRun run = await InProcessExecution.StreamAsync(workflow, []) + .ConfigureAwait(false); + await run.TrySendMessageAsync(new TurnToken(emitEvents: true)); + + await foreach (WorkflowEvent evt in run.WatchStreamAsync().ConfigureAwait(false)) + { + if (evt is ExecutorCompleteEvent executorComplete) + { + Debug.WriteLine($"{executorComplete.ExecutorId}: {executorComplete.Data}"); + } + else if (evt is AgentRunEvent agentRun && agentRun.Data is AgentRunResponse response) + { + foreach (ChatMessage message in response.Messages) + { + writer.WriteLine($"{agentRun.ExecutorId}: {message.Text}"); + } + } + } + } + + private sealed class RoundRobinGroupChatManager : GroupChatManager + { + public int TurnCount { get; private set; } = 0; + public int MaxTurns { get; init; } = Step6EntryPoint.MaxSteps; + + public override int? GetNextTurnExecutor(GroupChatHistory history) + { + if (this.ParticipantIds.Length == 0) + { + throw new InvalidOperationException("No participants in the group chat."); + } + + if (this.TurnCount >= this.MaxTurns) + { + return null; + } + + return this.TurnCount++ % this.ParticipantIds.Length; + } + } +} + +internal sealed class HelloAgent(string id = nameof(HelloAgent)) : AIAgent +{ + public const string Greeting = "Hello World!"; + public const string DefaultId = nameof(HelloAgent); + + public override string Id => id; + + public override Task RunAsync(IReadOnlyCollection messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + { + AgentRunResponse response = new(new ChatMessage(ChatRole.Assistant, "Hello World!")); + + return Task.FromResult(response); + } + + public override IAsyncEnumerable RunStreamingAsync(IReadOnlyCollection messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } +} + +internal sealed class EchoAgent(string id = nameof(EchoAgent)) : AIAgent +{ + public const string Prefix = "You said: "; + public const string DefaultId = nameof(EchoAgent); + + public override string Id => id; + + public override Task RunAsync(IReadOnlyCollection messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + { + if (messages.Count == 0) + { + throw new ArgumentException("No messages provided to echo.", nameof(messages)); + } + + StringBuilder collectedText = new(Prefix); + foreach (string messageText in messages.Select(message => message.Text) + .Where(text => !string.IsNullOrEmpty(text))) + { + collectedText.AppendLine(messageText); + } + + AgentRunResponse result = new(new ChatMessage(ChatRole.Assistant, collectedText.ToString())); + return Task.FromResult(result); + } + + public override IAsyncEnumerable RunStreamingAsync(IReadOnlyCollection messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } +} + +internal sealed class GroupChatHistory +{ + private readonly List _messages = new(); + private int _bookmark = 0; + + public void AddMessage(ChatMessage message) + { + this._messages.Add(message); + } + + public void AddMessages(IEnumerable messages) + { + this._messages.AddRange(messages); + } + + public void UpdateBookmark() + { + this._bookmark = this._messages.Count; + } + + public IReadOnlyList FullHistory => this._messages.AsReadOnly(); + public IEnumerable NewMessagesThisTurn => this._messages.Skip(this._bookmark); +} + +internal abstract class GroupChatManager +{ + public string[] ParticipantIds { get; internal init; } = []; + + public abstract int? GetNextTurnExecutor(GroupChatHistory history); +} + +internal sealed class GroupChatBuilder +{ + private readonly List _participants = new(); + private readonly List _shouldEmitEvents = new(); + private readonly Func _managerFactory; + + private GroupChatBuilder(Func managerFactory) + { + this._managerFactory = managerFactory; + } + + public static GroupChatBuilder Create() where TManager : GroupChatManager, new() + { + return new GroupChatBuilder(participantIds => new TManager() { ParticipantIds = participantIds }); + } + + public GroupChatBuilder AddParticipant(ExecutorIsh executor, bool shouldEmitEvents = false) + { + this._participants.Add(executor); + this._shouldEmitEvents.Add(shouldEmitEvents); + + return this; + } + + public GroupChatBuilder AddParticipants(params ExecutorIsh[] executors) + { + this._participants.AddRange(executors); + return this; + } + + public Workflow> ReduceToWorkflow() + { + string[] participantIds = this._participants.Select(identified => identified.Id).ToArray(); + GroupChatHost host = new(this._shouldEmitEvents.ToArray(), this._managerFactory(participantIds)); + + WorkflowBuilder builder = new WorkflowBuilder(host) + .AddFanOutEdge(host, targets: this._participants.ToArray()); + + foreach (ExecutorIsh participant in this._participants) + { + builder.AddEdge(participant, host); + } + + return builder.Build>(); + + //bool IsMessageType(object? message) => message is ChatMessage || message is IEnumerable; + } + + private sealed class TurnAssignedEvent(string executorId, string nextSpeakerId) : ExecutorEvent(executorId, data: nextSpeakerId); + + private sealed class GroupChatHost : Executor + { + private readonly bool[] _shouldEmitEvents; + private readonly GroupChatManager _manager; + private readonly bool _autoStartConversation; + + private readonly GroupChatHistory _history = new(); + + public GroupChatHost(bool[] shouldEmitEvents, GroupChatManager manager, bool autoStartConversation = false) : base(nameof(GroupChatHost)) + { + this._shouldEmitEvents = shouldEmitEvents; + this._manager = manager ?? throw new ArgumentNullException(nameof(manager)); + this._autoStartConversation = autoStartConversation; + } + + protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) + { + return routeBuilder.AddHandler>(this.HandleChatMessagesAsync) + .AddHandler(this.HandleChatMessageAsync) + .AddHandler(this.AssignNextTurnAsync); + } + + private async Task TryAutoStartConversationAsync(IWorkflowContext context) + { + if (this._autoStartConversation && this.TryEnterConversation()) + { + await this.AssignNextTurnAsync(new TurnToken(emitEvents: false), context).ConfigureAwait(false); + } + } + + private async ValueTask HandleChatMessagesAsync(List initialMessages, IWorkflowContext context) + { + this._history.AddMessages(initialMessages); + + await context.SendMessageAsync(initialMessages).ConfigureAwait(false); + await this.TryAutoStartConversationAsync(context).ConfigureAwait(false); + } + + private async ValueTask HandleChatMessageAsync(ChatMessage message, IWorkflowContext context) + { + // First, add the message to the history, then forward to all executors + this._history.AddMessage(message); + + await context.SendMessageAsync(message).ConfigureAwait(false); + await this.TryAutoStartConversationAsync(context).ConfigureAwait(false); + } + + private int _inConversationFlag = 0; + + /// + /// Atomically switches to "in conversation" state if not already in that state. + /// + /// if the state was changed, otherwise. + private bool TryEnterConversation() + { + return Interlocked.CompareExchange(ref this._inConversationFlag, 1, 0) == 0; + } + + private bool _shouldHostEmitEvents = false; + private async ValueTask AssignNextTurnAsync(TurnToken token, IWorkflowContext context) + { + if (this.TryEnterConversation()) + { + // Capture the initial turn token's EmitEvents setting + this._shouldHostEmitEvents = token.EmitEvents; + } + + int? nextSpeakerIndex = this._manager.GetNextTurnExecutor(this._history); + if (nextSpeakerIndex == null) + { + await context.AddEventAsync(new WorkflowCompletedEvent()) + .ConfigureAwait(false); + + return; + } + + string nextSpeakerId = this._manager.ParticipantIds[nextSpeakerIndex.Value]; + + if (this._shouldHostEmitEvents) + { + await context.AddEventAsync(new TurnAssignedEvent(this.Id, nextSpeakerId)) + .ConfigureAwait(false); + } + + await context.SendMessageAsync(new TurnToken(this._shouldEmitEvents[nextSpeakerIndex.Value]), nextSpeakerId) + .ConfigureAwait(false); + } + } +} diff --git a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/SampleSmokeTest.cs b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/SampleSmokeTest.cs index dc08df598e..7eca842077 100644 --- a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/SampleSmokeTest.cs +++ b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/SampleSmokeTest.cs @@ -5,6 +5,7 @@ using System.IO; using System.Linq; using System.Threading.Tasks; using Microsoft.Agents.Workflows.Sample; +using Microsoft.Agents.Workflows.UnitTests.Sample; namespace Microsoft.Agents.Workflows.UnitTests; @@ -83,6 +84,22 @@ public class SampleSmokeTest string guessResult = await Step5EntryPoint.RunAsync(writer, userGuessCallback: responder.InvokeNext); Assert.Equal("You guessed correctly! You Win!", guessResult); } + + [Fact] + public async Task Test_RunSample_Step6Async() + { + using StringWriter writer = new(); + + await Step6EntryPoint.RunAsync(writer); + + string result = writer.ToString(); + string[] lines = result.Split([Environment.NewLine], StringSplitOptions.RemoveEmptyEntries); + + Assert.Collection(lines, + line => Assert.Contains($"{HelloAgent.DefaultId}: {HelloAgent.Greeting}", line), + line => Assert.Contains($"{EchoAgent.DefaultId}: {EchoAgent.Prefix}{HelloAgent.Greeting}", line) + ); + } } internal sealed class VerifyingPlaybackResponder