fix: State Updates are not published without CheckpointManager (#501)

This commit is contained in:
Jacob Alber
2025-08-26 13:41:48 -04:00
committed by GitHub
Unverified
parent fa88641263
commit 8544cd9b03
4 changed files with 263 additions and 0 deletions
@@ -214,6 +214,8 @@ internal class InProcessRunner<TInput> : ISuperStepRunner, ICheckpointingRunner
{
if (this.CheckpointManager == null)
{
// Always publish the state updates, even in the absence of a CheckpointManager.
await this.RunContext.StateManager.PublishUpdatesAsync(this.StepTracer).ConfigureAwait(false);
return;
}
@@ -0,0 +1,11 @@
// Copyright (c) Microsoft. All rights reserved.
namespace Microsoft.Agents.Workflows.UnitTests;
internal sealed class ForwardMessageExecutor<TMessage> : Executor where TMessage : notnull
{
protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder)
{
return routeBuilder.AddHandler<TMessage>((message, ctx) => ctx.SendMessageAsync(message));
}
}
@@ -0,0 +1,162 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Threading;
using System.Threading.Tasks;
using FluentAssertions;
namespace Microsoft.Agents.Workflows.UnitTests;
public class InProcessStateTests
{
private sealed class TurnToken
{
public int Count { get; }
public TurnToken() : this(0)
{ }
private TurnToken(int count)
{
this.Count = count;
}
public TurnToken Next => new(this.Count + 1);
}
private sealed class StateTestExecutor<TState> : TestingExecutor<TurnToken, TurnToken>
{
private static Func<TurnToken, IWorkflowContext, CancellationToken, ValueTask<TurnToken>>[] WrapActions(ScopeKey stateKey, Func<TState?, TState?>[] stateActions)
{
Func<TurnToken, IWorkflowContext, CancellationToken, ValueTask<TurnToken>>[] result
= new Func<TurnToken, IWorkflowContext, CancellationToken, ValueTask<TurnToken>>[stateActions.Length];
for (int i = 0; i < stateActions.Length; i++)
{
result[i] = CreateWrapperAsync(stateActions[i]);
}
return result;
Func<TurnToken, IWorkflowContext, CancellationToken, ValueTask<TurnToken>> CreateWrapperAsync(Func<TState?, TState?> action)
{
return
async (turn, context, cancellation) =>
{
TState? state = await context.ReadStateAsync<TState>(stateKey.Key, stateKey.ScopeId.ScopeName)
.ConfigureAwait(false);
state = action(state);
await context.QueueStateUpdateAsync(stateKey.Key, state, stateKey.ScopeId.ScopeName);
return turn.Next;
};
}
}
public ScopeKey StateKey { get; }
public StateTestExecutor(ScopeKey stateKey, bool loop = false, params Func<TState?, TState?>[] stateActions)
: base(stateKey.ScopeId.ExecutorId, loop, WrapActions(stateKey, stateActions))
{
this.StateKey = stateKey;
}
}
private static Func<int?, int?> CreateOrIncrement(int defaultValue = default)
=> currState => currState.HasValue ? currState + 1 : defaultValue;
private static Func<int?, int?> ValidateState(int expectedValue, string? because = null, params object[] becauseArgs)
=> (int? currState) =>
{
currState.Should().Be(expectedValue, because, becauseArgs);
return currState;
};
private static Func<object?, bool> MaxTurns(int maxTurns)
=> (object? maybeTurn) => maybeTurn is not TurnToken turn || turn.Count < maxTurns;
[Fact]
public async Task InProcessRun_StateShouldPersist_NotCheckpointedAsync()
{
StateTestExecutor<int?> writer = new(
new ScopeKey("Writer", "TestScope", "TestKey"),
loop: false,
CreateOrIncrement(),
CreateOrIncrement()
);
StateTestExecutor<int?> validator = new(
new ScopeKey("Validator", "TestScope", "TestKey"),
loop: false,
ValidateState(0),
ValidateState(1)
);
Workflow<TurnToken> workflow =
new WorkflowBuilder(writer)
.AddEdge(writer, validator, MaxTurns(4))
.AddEdge(validator, writer, MaxTurns(4)).Build<TurnToken>();
Run run = await InProcessExecution.RunAsync(workflow, new());
run.Status.Should().Be(RunStatus.Idle);
}
[Fact]
public async Task InProcessRun_StateShouldPersist_CheckpointedAsync()
{
StateTestExecutor<int?> writer = new(
new ScopeKey("Writer", "TestScope", "TestKey"),
loop: false,
CreateOrIncrement(),
CreateOrIncrement()
);
StateTestExecutor<int?> validator = new(
new ScopeKey("Validator", "TestScope", "TestKey"),
loop: false,
ValidateState(0),
ValidateState(1)
);
Workflow<TurnToken> workflow =
new WorkflowBuilder(writer)
.AddEdge(writer, validator, MaxTurns(4))
.AddEdge(validator, writer, MaxTurns(4)).Build<TurnToken>();
Checkpointed<Run> checkpointed = await InProcessExecution.RunAsync(workflow, new(), new CheckpointManager());
checkpointed.Checkpoints.Should().HaveCount(6);
checkpointed.Run.Status.Should().Be(RunStatus.Idle);
}
[Fact]
public async Task InProcessRun_StateShouldError_TwoExecutorsAsync()
{
ForwardMessageExecutor<TurnToken> forward = new();
using StateTestExecutor<int?> testExecutor = new(
new ScopeKey("StateTestExecutor", "TestScope", "TestKey"),
loop: false,
CreateOrIncrement()
);
using StateTestExecutor<int?> testExecutor2 = new(
new ScopeKey("StateTestExecutor2", "TestScope", "TestKey"),
loop: false,
CreateOrIncrement()
);
Workflow<TurnToken> workflow =
new WorkflowBuilder(forward)
.AddFanOutEdge(forward, targets: [testExecutor, testExecutor2])
.Build<TurnToken>();
var act = async () => await InProcessExecution.RunAsync(workflow, new());
var result = await act.Should()
.ThrowAsync("multiple writers to the same shared scope key");
}
}
@@ -0,0 +1,88 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.Agents.Workflows.UnitTests;
internal class TestingExecutor<TIn, TOut> : Executor, IDisposable
{
private readonly bool _loop;
private readonly Func<TIn, IWorkflowContext, CancellationToken, ValueTask<TOut>>[] _actions;
private readonly HashSet<CancellationToken> _linkedTokens = new();
private CancellationTokenSource _internalCts = new();
public TestingExecutor(string? id = null, bool loop = false, params Func<TIn, IWorkflowContext, CancellationToken, ValueTask<TOut>>[] actions) : base(id)
{
this._loop = loop;
this._actions = actions;
}
public void UnlinkCancellation(CancellationToken token)
{
this._linkedTokens.Remove(token);
}
public void LinkCancellation(CancellationToken token)
{
this._linkedTokens.Add(token);
CancellationTokenSource tokenSource = CancellationTokenSource.CreateLinkedTokenSource(this._linkedTokens.ToArray());
tokenSource = Interlocked.Exchange(ref this._internalCts, tokenSource);
tokenSource.Dispose();
}
public void SetCancel()
{
Volatile.Read(ref this._internalCts).Cancel();
}
protected sealed override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder)
{
return routeBuilder.AddHandler<TIn, TOut>(this.RouteToActions);
}
private int _nextActionIndex = 0;
private ValueTask<TOut> RouteToActions(TIn message, IWorkflowContext context)
{
if (this._nextActionIndex >= this._actions.Length)
{
if (this._loop)
{
this._nextActionIndex = 0;
}
else
{
throw new InvalidOperationException("No more actions to execute and looping is disabled.");
}
}
try
{
Func<TIn, IWorkflowContext, CancellationToken, ValueTask<TOut>> action = this._actions[this._nextActionIndex];
return action(message, context, Volatile.Read(ref this._internalCts).Token);
}
finally
{
this._nextActionIndex++;
}
}
~TestingExecutor()
{
this.Dispose(false);
}
protected virtual void Dispose(bool disposing)
{
this._internalCts.Dispose();
}
public void Dispose()
{
this.Dispose(true);
GC.SuppressFinalize(this);
}
}