mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
fix: State Updates are not published without CheckpointManager (#501)
This commit is contained in:
committed by
GitHub
Unverified
parent
fa88641263
commit
8544cd9b03
@@ -214,6 +214,8 @@ internal class InProcessRunner<TInput> : ISuperStepRunner, ICheckpointingRunner
|
||||
{
|
||||
if (this.CheckpointManager == null)
|
||||
{
|
||||
// Always publish the state updates, even in the absence of a CheckpointManager.
|
||||
await this.RunContext.StateManager.PublishUpdatesAsync(this.StepTracer).ConfigureAwait(false);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
namespace Microsoft.Agents.Workflows.UnitTests;
|
||||
|
||||
internal sealed class ForwardMessageExecutor<TMessage> : Executor where TMessage : notnull
|
||||
{
|
||||
protected override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder)
|
||||
{
|
||||
return routeBuilder.AddHandler<TMessage>((message, ctx) => ctx.SendMessageAsync(message));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using FluentAssertions;
|
||||
|
||||
namespace Microsoft.Agents.Workflows.UnitTests;
|
||||
|
||||
public class InProcessStateTests
|
||||
{
|
||||
private sealed class TurnToken
|
||||
{
|
||||
public int Count { get; }
|
||||
|
||||
public TurnToken() : this(0)
|
||||
{ }
|
||||
|
||||
private TurnToken(int count)
|
||||
{
|
||||
this.Count = count;
|
||||
}
|
||||
|
||||
public TurnToken Next => new(this.Count + 1);
|
||||
}
|
||||
|
||||
private sealed class StateTestExecutor<TState> : TestingExecutor<TurnToken, TurnToken>
|
||||
{
|
||||
private static Func<TurnToken, IWorkflowContext, CancellationToken, ValueTask<TurnToken>>[] WrapActions(ScopeKey stateKey, Func<TState?, TState?>[] stateActions)
|
||||
{
|
||||
Func<TurnToken, IWorkflowContext, CancellationToken, ValueTask<TurnToken>>[] result
|
||||
= new Func<TurnToken, IWorkflowContext, CancellationToken, ValueTask<TurnToken>>[stateActions.Length];
|
||||
|
||||
for (int i = 0; i < stateActions.Length; i++)
|
||||
{
|
||||
result[i] = CreateWrapperAsync(stateActions[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
|
||||
Func<TurnToken, IWorkflowContext, CancellationToken, ValueTask<TurnToken>> CreateWrapperAsync(Func<TState?, TState?> action)
|
||||
{
|
||||
return
|
||||
async (turn, context, cancellation) =>
|
||||
{
|
||||
TState? state = await context.ReadStateAsync<TState>(stateKey.Key, stateKey.ScopeId.ScopeName)
|
||||
.ConfigureAwait(false);
|
||||
|
||||
state = action(state);
|
||||
|
||||
await context.QueueStateUpdateAsync(stateKey.Key, state, stateKey.ScopeId.ScopeName);
|
||||
|
||||
return turn.Next;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
public ScopeKey StateKey { get; }
|
||||
|
||||
public StateTestExecutor(ScopeKey stateKey, bool loop = false, params Func<TState?, TState?>[] stateActions)
|
||||
: base(stateKey.ScopeId.ExecutorId, loop, WrapActions(stateKey, stateActions))
|
||||
{
|
||||
this.StateKey = stateKey;
|
||||
}
|
||||
}
|
||||
|
||||
private static Func<int?, int?> CreateOrIncrement(int defaultValue = default)
|
||||
=> currState => currState.HasValue ? currState + 1 : defaultValue;
|
||||
|
||||
private static Func<int?, int?> ValidateState(int expectedValue, string? because = null, params object[] becauseArgs)
|
||||
=> (int? currState) =>
|
||||
{
|
||||
currState.Should().Be(expectedValue, because, becauseArgs);
|
||||
|
||||
return currState;
|
||||
};
|
||||
|
||||
private static Func<object?, bool> MaxTurns(int maxTurns)
|
||||
=> (object? maybeTurn) => maybeTurn is not TurnToken turn || turn.Count < maxTurns;
|
||||
|
||||
[Fact]
|
||||
public async Task InProcessRun_StateShouldPersist_NotCheckpointedAsync()
|
||||
{
|
||||
StateTestExecutor<int?> writer = new(
|
||||
new ScopeKey("Writer", "TestScope", "TestKey"),
|
||||
loop: false,
|
||||
CreateOrIncrement(),
|
||||
CreateOrIncrement()
|
||||
);
|
||||
|
||||
StateTestExecutor<int?> validator = new(
|
||||
new ScopeKey("Validator", "TestScope", "TestKey"),
|
||||
loop: false,
|
||||
ValidateState(0),
|
||||
ValidateState(1)
|
||||
);
|
||||
|
||||
Workflow<TurnToken> workflow =
|
||||
new WorkflowBuilder(writer)
|
||||
.AddEdge(writer, validator, MaxTurns(4))
|
||||
.AddEdge(validator, writer, MaxTurns(4)).Build<TurnToken>();
|
||||
|
||||
Run run = await InProcessExecution.RunAsync(workflow, new());
|
||||
|
||||
run.Status.Should().Be(RunStatus.Idle);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InProcessRun_StateShouldPersist_CheckpointedAsync()
|
||||
{
|
||||
StateTestExecutor<int?> writer = new(
|
||||
new ScopeKey("Writer", "TestScope", "TestKey"),
|
||||
loop: false,
|
||||
CreateOrIncrement(),
|
||||
CreateOrIncrement()
|
||||
);
|
||||
|
||||
StateTestExecutor<int?> validator = new(
|
||||
new ScopeKey("Validator", "TestScope", "TestKey"),
|
||||
loop: false,
|
||||
ValidateState(0),
|
||||
ValidateState(1)
|
||||
);
|
||||
|
||||
Workflow<TurnToken> workflow =
|
||||
new WorkflowBuilder(writer)
|
||||
.AddEdge(writer, validator, MaxTurns(4))
|
||||
.AddEdge(validator, writer, MaxTurns(4)).Build<TurnToken>();
|
||||
|
||||
Checkpointed<Run> checkpointed = await InProcessExecution.RunAsync(workflow, new(), new CheckpointManager());
|
||||
|
||||
checkpointed.Checkpoints.Should().HaveCount(6);
|
||||
checkpointed.Run.Status.Should().Be(RunStatus.Idle);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InProcessRun_StateShouldError_TwoExecutorsAsync()
|
||||
{
|
||||
ForwardMessageExecutor<TurnToken> forward = new();
|
||||
using StateTestExecutor<int?> testExecutor = new(
|
||||
new ScopeKey("StateTestExecutor", "TestScope", "TestKey"),
|
||||
loop: false,
|
||||
CreateOrIncrement()
|
||||
);
|
||||
|
||||
using StateTestExecutor<int?> testExecutor2 = new(
|
||||
new ScopeKey("StateTestExecutor2", "TestScope", "TestKey"),
|
||||
loop: false,
|
||||
CreateOrIncrement()
|
||||
);
|
||||
|
||||
Workflow<TurnToken> workflow =
|
||||
new WorkflowBuilder(forward)
|
||||
.AddFanOutEdge(forward, targets: [testExecutor, testExecutor2])
|
||||
.Build<TurnToken>();
|
||||
|
||||
var act = async () => await InProcessExecution.RunAsync(workflow, new());
|
||||
|
||||
var result = await act.Should()
|
||||
.ThrowAsync("multiple writers to the same shared scope key");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.Agents.Workflows.UnitTests;
|
||||
|
||||
internal class TestingExecutor<TIn, TOut> : Executor, IDisposable
|
||||
{
|
||||
private readonly bool _loop;
|
||||
private readonly Func<TIn, IWorkflowContext, CancellationToken, ValueTask<TOut>>[] _actions;
|
||||
private readonly HashSet<CancellationToken> _linkedTokens = new();
|
||||
private CancellationTokenSource _internalCts = new();
|
||||
|
||||
public TestingExecutor(string? id = null, bool loop = false, params Func<TIn, IWorkflowContext, CancellationToken, ValueTask<TOut>>[] actions) : base(id)
|
||||
{
|
||||
this._loop = loop;
|
||||
this._actions = actions;
|
||||
}
|
||||
|
||||
public void UnlinkCancellation(CancellationToken token)
|
||||
{
|
||||
this._linkedTokens.Remove(token);
|
||||
}
|
||||
|
||||
public void LinkCancellation(CancellationToken token)
|
||||
{
|
||||
this._linkedTokens.Add(token);
|
||||
CancellationTokenSource tokenSource = CancellationTokenSource.CreateLinkedTokenSource(this._linkedTokens.ToArray());
|
||||
tokenSource = Interlocked.Exchange(ref this._internalCts, tokenSource);
|
||||
tokenSource.Dispose();
|
||||
}
|
||||
|
||||
public void SetCancel()
|
||||
{
|
||||
Volatile.Read(ref this._internalCts).Cancel();
|
||||
}
|
||||
|
||||
protected sealed override RouteBuilder ConfigureRoutes(RouteBuilder routeBuilder)
|
||||
{
|
||||
return routeBuilder.AddHandler<TIn, TOut>(this.RouteToActions);
|
||||
}
|
||||
|
||||
private int _nextActionIndex = 0;
|
||||
private ValueTask<TOut> RouteToActions(TIn message, IWorkflowContext context)
|
||||
{
|
||||
if (this._nextActionIndex >= this._actions.Length)
|
||||
{
|
||||
if (this._loop)
|
||||
{
|
||||
this._nextActionIndex = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException("No more actions to execute and looping is disabled.");
|
||||
}
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
Func<TIn, IWorkflowContext, CancellationToken, ValueTask<TOut>> action = this._actions[this._nextActionIndex];
|
||||
return action(message, context, Volatile.Read(ref this._internalCts).Token);
|
||||
}
|
||||
finally
|
||||
{
|
||||
this._nextActionIndex++;
|
||||
}
|
||||
}
|
||||
|
||||
~TestingExecutor()
|
||||
{
|
||||
this.Dispose(false);
|
||||
}
|
||||
|
||||
protected virtual void Dispose(bool disposing)
|
||||
{
|
||||
this._internalCts.Dispose();
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
this.Dispose(true);
|
||||
GC.SuppressFinalize(this);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user