diff --git a/dotnet/src/Microsoft.Agents.Workflows/InProc/InProcessRunner.cs b/dotnet/src/Microsoft.Agents.Workflows/InProc/InProcessRunner.cs index d1291ff982..a6d7081c1c 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/InProc/InProcessRunner.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/InProc/InProcessRunner.cs @@ -214,6 +214,8 @@ internal class InProcessRunner : ISuperStepRunner, ICheckpointingRunner { if (this.CheckpointManager == null) { + // Always publish the state updates, even in the absence of a CheckpointManager. + await this.RunContext.StateManager.PublishUpdatesAsync(this.StepTracer).ConfigureAwait(false); return; } diff --git a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/ForwardMessageExecutor.cs b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/ForwardMessageExecutor.cs new file mode 100644 index 0000000000..7724f73630 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/ForwardMessageExecutor.cs @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.Agents.Workflows.UnitTests; + +internal sealed class ForwardMessageExecutor : Executor where TMessage : notnull +{ + protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) + { + return routeBuilder.AddHandler((message, ctx) => ctx.SendMessageAsync(message)); + } +} diff --git a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/InProcessStateTests.cs b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/InProcessStateTests.cs new file mode 100644 index 0000000000..63bdf4455d --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/InProcessStateTests.cs @@ -0,0 +1,162 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Threading; +using System.Threading.Tasks; +using FluentAssertions; + +namespace Microsoft.Agents.Workflows.UnitTests; + +public class InProcessStateTests +{ + private sealed class TurnToken + { + public int Count { get; } + + public TurnToken() : this(0) + { } + + private TurnToken(int count) + { + this.Count = count; + } + + public TurnToken Next => new(this.Count + 1); + } + + private sealed class StateTestExecutor : TestingExecutor + { + private static Func>[] WrapActions(ScopeKey stateKey, Func[] stateActions) + { + Func>[] result + = new Func>[stateActions.Length]; + + for (int i = 0; i < stateActions.Length; i++) + { + result[i] = CreateWrapperAsync(stateActions[i]); + } + + return result; + + Func> CreateWrapperAsync(Func action) + { + return + async (turn, context, cancellation) => + { + TState? state = await context.ReadStateAsync(stateKey.Key, stateKey.ScopeId.ScopeName) + .ConfigureAwait(false); + + state = action(state); + + await context.QueueStateUpdateAsync(stateKey.Key, state, stateKey.ScopeId.ScopeName); + + return turn.Next; + }; + } + } + + public ScopeKey StateKey { get; } + + public StateTestExecutor(ScopeKey stateKey, bool loop = false, params Func[] stateActions) + : base(stateKey.ScopeId.ExecutorId, loop, WrapActions(stateKey, stateActions)) + { + this.StateKey = stateKey; + } + } + + private static Func CreateOrIncrement(int defaultValue = default) + => currState => currState.HasValue ? currState + 1 : defaultValue; + + private static Func ValidateState(int expectedValue, string? because = null, params object[] becauseArgs) + => (int? currState) => + { + currState.Should().Be(expectedValue, because, becauseArgs); + + return currState; + }; + + private static Func MaxTurns(int maxTurns) + => (object? maybeTurn) => maybeTurn is not TurnToken turn || turn.Count < maxTurns; + + [Fact] + public async Task InProcessRun_StateShouldPersist_NotCheckpointedAsync() + { + StateTestExecutor writer = new( + new ScopeKey("Writer", "TestScope", "TestKey"), + loop: false, + CreateOrIncrement(), + CreateOrIncrement() + ); + + StateTestExecutor validator = new( + new ScopeKey("Validator", "TestScope", "TestKey"), + loop: false, + ValidateState(0), + ValidateState(1) + ); + + Workflow workflow = + new WorkflowBuilder(writer) + .AddEdge(writer, validator, MaxTurns(4)) + .AddEdge(validator, writer, MaxTurns(4)).Build(); + + Run run = await InProcessExecution.RunAsync(workflow, new()); + + run.Status.Should().Be(RunStatus.Idle); + } + + [Fact] + public async Task InProcessRun_StateShouldPersist_CheckpointedAsync() + { + StateTestExecutor writer = new( + new ScopeKey("Writer", "TestScope", "TestKey"), + loop: false, + CreateOrIncrement(), + CreateOrIncrement() + ); + + StateTestExecutor validator = new( + new ScopeKey("Validator", "TestScope", "TestKey"), + loop: false, + ValidateState(0), + ValidateState(1) + ); + + Workflow workflow = + new WorkflowBuilder(writer) + .AddEdge(writer, validator, MaxTurns(4)) + .AddEdge(validator, writer, MaxTurns(4)).Build(); + + Checkpointed checkpointed = await InProcessExecution.RunAsync(workflow, new(), new CheckpointManager()); + + checkpointed.Checkpoints.Should().HaveCount(6); + checkpointed.Run.Status.Should().Be(RunStatus.Idle); + } + + [Fact] + public async Task InProcessRun_StateShouldError_TwoExecutorsAsync() + { + ForwardMessageExecutor forward = new(); + using StateTestExecutor testExecutor = new( + new ScopeKey("StateTestExecutor", "TestScope", "TestKey"), + loop: false, + CreateOrIncrement() + ); + + using StateTestExecutor testExecutor2 = new( + new ScopeKey("StateTestExecutor2", "TestScope", "TestKey"), + loop: false, + CreateOrIncrement() + ); + + Workflow workflow = + new WorkflowBuilder(forward) + .AddFanOutEdge(forward, targets: [testExecutor, testExecutor2]) + .Build(); + + var act = async () => await InProcessExecution.RunAsync(workflow, new()); + + var result = await act.Should() + .ThrowAsync("multiple writers to the same shared scope key"); + } +} diff --git a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/TestingExecutor.cs b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/TestingExecutor.cs new file mode 100644 index 0000000000..c90e88416c --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/TestingExecutor.cs @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Agents.Workflows.UnitTests; + +internal class TestingExecutor : Executor, IDisposable +{ + private readonly bool _loop; + private readonly Func>[] _actions; + private readonly HashSet _linkedTokens = new(); + private CancellationTokenSource _internalCts = new(); + + public TestingExecutor(string? id = null, bool loop = false, params Func>[] actions) : base(id) + { + this._loop = loop; + this._actions = actions; + } + + public void UnlinkCancellation(CancellationToken token) + { + this._linkedTokens.Remove(token); + } + + public void LinkCancellation(CancellationToken token) + { + this._linkedTokens.Add(token); + CancellationTokenSource tokenSource = CancellationTokenSource.CreateLinkedTokenSource(this._linkedTokens.ToArray()); + tokenSource = Interlocked.Exchange(ref this._internalCts, tokenSource); + tokenSource.Dispose(); + } + + public void SetCancel() + { + Volatile.Read(ref this._internalCts).Cancel(); + } + + protected sealed override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder) + { + return routeBuilder.AddHandler(this.RouteToActions); + } + + private int _nextActionIndex = 0; + private ValueTask RouteToActions(TIn message, IWorkflowContext context) + { + if (this._nextActionIndex >= this._actions.Length) + { + if (this._loop) + { + this._nextActionIndex = 0; + } + else + { + throw new InvalidOperationException("No more actions to execute and looping is disabled."); + } + } + + try + { + Func> action = this._actions[this._nextActionIndex]; + return action(message, context, Volatile.Read(ref this._internalCts).Token); + } + finally + { + this._nextActionIndex++; + } + } + + ~TestingExecutor() + { + this.Dispose(false); + } + + protected virtual void Dispose(bool disposing) + { + this._internalCts.Dispose(); + } + + public void Dispose() + { + this.Dispose(true); + GC.SuppressFinalize(this); + } +}