.NET: [BREAKING] Support Checkpoint Serialization (#735)

* feat: Support Checkpoint Serialization

* Implements serialization roundtripping for checkpoints.
* Adds support for JSON serialization
* Adds FileSystem-based checkpoint persistence

* fix: Executor State does not deserialize correctly

The StateManager was not properly handling delay-deserialized values.

* Fix PortableValue handling in StateManager (this makes it delegate to PortableValue the uwnrapping)
* Fix UnitTest to actually test checkpoint serialization
* Additional review comment fixes

---------

Co-authored-by: Chris <66376200+crickman@users.noreply.github.com>
This commit is contained in:
Jacob Alber
2025-09-17 15:57:19 -04:00
committed by GitHub
Unverified
parent 2a04c4197e
commit 2015f0dc09
90 changed files with 3011 additions and 341 deletions
+3
View File
@@ -436,3 +436,6 @@ resharper_redundant_using_directive_highlighting = warning # Resharper's "Redund
resharper_inconsistent_naming_highlighting = warning # Resharper's "Inconsistent naming" highlighting
resharper_redundant_this_qualifier_highlighting = warning # Resharper's "Redundant 'this' qualifier" highlighting
resharper_arrange_this_qualifier_highlighting = warning # Resharper's "Arrange 'this' qualifier" highlighting
csharp_style_prefer_primary_constructors = true:suggestion
csharp_prefer_system_threading_lock = true:suggestion
csharp_style_prefer_simple_property_accessors = true:suggestion
@@ -32,7 +32,7 @@ public static class Program
var workflow = WorkflowHelper.GetWorkflow();
// Create checkpoint manager
var checkpointManager = new CheckpointManager();
var checkpointManager = CheckpointManager.Default;
var checkpoints = new List<CheckpointInfo>();
// Execute the workflow and save checkpoints
@@ -31,7 +31,7 @@ public static class Program
var workflow = WorkflowHelper.GetWorkflow();
// Create checkpoint manager
var checkpointManager = new CheckpointManager();
var checkpointManager = CheckpointManager.Default;
var checkpoints = new List<CheckpointInfo>();
// Execute the workflow and save checkpoints
@@ -34,7 +34,7 @@ public static class Program
var workflow = WorkflowHelper.GetWorkflow();
// Create checkpoint manager
var checkpointManager = new CheckpointManager();
var checkpointManager = CheckpointManager.Default;
var checkpoints = new List<CheckpointInfo>();
// Execute the workflow and save checkpoints
@@ -102,9 +102,9 @@ public static class Program
private static ExternalResponse HandleExternalRequest(ExternalRequest request)
{
if (request.Port.Request == typeof(SignalWithNumber))
var signal = request.DataAs<SignalWithNumber>();
if (signal is not null)
{
var signal = (SignalWithNumber)request.Data;
switch (signal.Signal)
{
case NumberSignal.Init:
@@ -119,7 +119,7 @@ public static class Program
}
}
throw new NotSupportedException($"Request {request.Port.Request} is not supported");
throw new NotSupportedException($"Request {request.PortInfo.RequestType} is not supported");
}
private static int ReadIntegerFromConsole(string prompt)
@@ -75,10 +75,10 @@ public static class Program
// After the email assistant writes a response, it will be sent to the send email executor
.AddEdge(emailAssistantExecutor, sendEmailExecutor)
// Save the analysis result to the database if summary is not needed
.AddEdge(
.AddEdge<AnalysisResult>(
emailAnalysisExecutor,
databaseAccessExecutor,
condition: analysisResult => analysisResult is AnalysisResult result && result.EmailLength <= LongEmailThreshold)
condition: analysisResult => analysisResult is not null && analysisResult.EmailLength <= LongEmailThreshold)
// Save the analysis result to the database with summary
.AddEdge(emailSummaryExecutor, databaseAccessExecutor);
var workflow = builder.Build<ChatMessage>();
@@ -107,21 +107,21 @@ public static class Program
/// Creates a partitioner for routing messages based on the analysis result.
/// </summary>
/// <returns>A function that takes an analysis result and returns the target partitions.</returns>
private static Func<object?, int, IEnumerable<int>> GetPartitioner()
private static Func<AnalysisResult?, int, IEnumerable<int>> GetPartitioner()
{
return (analysisResult, targetCount) =>
{
if (analysisResult is AnalysisResult result)
if (analysisResult is not null)
{
if (result.spamDecision == SpamDecision.Spam)
if (analysisResult.spamDecision == SpamDecision.Spam)
{
return [0]; // Route to spam handler
}
else if (result.spamDecision == SpamDecision.NotSpam)
else if (analysisResult.spamDecision == SpamDecision.NotSpam)
{
List<int> targets = [1]; // Route to the email assistant
if (result.EmailLength > LongEmailThreshold)
if (analysisResult.EmailLength > LongEmailThreshold)
{
targets.Add(2); // Route to the email summarizer too
}
@@ -53,7 +53,7 @@ internal sealed class Program
// Run the workflow, just like any other workflow
string input = this.GetWorkflowInput();
CheckpointManager checkpointManager = new();
CheckpointManager checkpointManager = CheckpointManager.Default;
Checkpointed<StreamingRun> run = await InProcessExecution.StreamAsync(workflow, input, checkpointManager);
bool isComplete = false;
@@ -151,7 +151,7 @@ internal sealed class Program
Debug.WriteLine($"ACTION EXIT #{actionComplete.ActionId} [{actionComplete.ActionType}]");
break;
case ExecutorFailureEvent executorFailure:
case ExecutorFailedEvent executorFailure:
Debug.WriteLine($"STEP ERROR #{executorFailure.ExecutorId}: {executorFailure.Data?.Message ?? "Unknown"}");
break;
@@ -256,7 +256,7 @@ internal sealed class Program
}
private static InputResponse HandleExternalRequest(ExternalRequest request)
{
InputRequest? message = request.Data as InputRequest;
InputRequest? message = request.Data.As<InputRequest>();
string? userInput = null;
do
{
@@ -50,9 +50,9 @@ public static class Program
private static ExternalResponse HandleExternalRequest(ExternalRequest request)
{
if (request.Port.Request == typeof(NumberSignal))
if (request.DataIs<NumberSignal>())
{
var signal = (NumberSignal)request.Data;
var signal = request.DataAs<NumberSignal>();
switch (signal)
{
case NumberSignal.Init:
@@ -67,7 +67,7 @@ public static class Program
}
}
throw new NotSupportedException($"Request {request.Port.Request} is not supported");
throw new NotSupportedException($"Request {request.PortInfo.RequestType} is not supported");
}
private static int ReadIntegerFromConsole(string prompt)
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Text.Json.Serialization;
using Microsoft.Shared.Diagnostics;
namespace Microsoft.Agents.Workflows;
@@ -10,14 +12,29 @@ namespace Microsoft.Agents.Workflows;
public class CheckpointInfo : IEquatable<CheckpointInfo>
{
/// <summary>
/// The unique identifier for the checkpoint.
/// Gets the unique identifier for the current run.
/// </summary>
public string CheckpointId { get; } = Guid.NewGuid().ToString("N");
public string RunId { get; }
/// <summary>
/// The date and time when the object was created, in Coordinated Universal Time (UTC).
/// The unique identifier for the checkpoint.
/// </summary>
public DateTimeOffset CreatedAt { get; } = DateTimeOffset.UtcNow;
public string CheckpointId { get; }
/// <summary>
/// Initializes a new instance of the <see cref="CheckpointInfo"/> class with a unique identifier and the current
/// UTC timestamp.
/// </summary>
/// <remarks>This constructor generates a new unique identifier using a GUID in a 32-character, lowercase,
/// hexadecimal format and sets the timestamp to the current UTC time.</remarks>
internal CheckpointInfo(string runId) : this(runId, Guid.NewGuid().ToString("N")) { }
[JsonConstructor]
internal CheckpointInfo(string runId, string checkpointId)
{
this.RunId = Throw.IfNullOrEmpty(runId);
this.CheckpointId = Throw.IfNullOrEmpty(checkpointId);
}
/// <inheritdoc/>
public bool Equals(CheckpointInfo? other)
@@ -27,8 +44,7 @@ public class CheckpointInfo : IEquatable<CheckpointInfo>
return false;
}
return this.CheckpointId == other.CheckpointId &&
this.CreatedAt == other.CreatedAt;
return this.RunId == other.RunId && this.CheckpointId == other.CheckpointId;
}
/// <inheritdoc/>
@@ -40,9 +56,9 @@ public class CheckpointInfo : IEquatable<CheckpointInfo>
/// <inheritdoc/>
public override int GetHashCode()
{
return HashCode.Combine(this.CheckpointId, this.CreatedAt);
return HashCode.Combine(this.RunId, this.CheckpointId);
}
/// <inheritdoc/>
public override string ToString() => $"CheckpointId: {this.CheckpointId}, CreatedAt: {this.CreatedAt:O}";
public override string ToString() => $"CheckpointInfo(RunId: {this.RunId}, CheckpointId: {this.CheckpointId})";
}
@@ -1,36 +1,57 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.Text.Json;
using System.Threading.Tasks;
using Microsoft.Agents.Workflows.Checkpointing;
using Microsoft.Shared.Diagnostics;
namespace Microsoft.Agents.Workflows;
/// <summary>
/// An in-memory implementation of <see cref="ICheckpointManager"/> that stores checkpoints in a dictionary.
/// A manager for storing and retrieving workflow execution checkpoints.
/// </summary>
public sealed class CheckpointManager : ICheckpointManager
{
private readonly Dictionary<CheckpointInfo, Checkpoint> _checkpoints = new();
private readonly ICheckpointManager _impl;
ValueTask<CheckpointInfo> ICheckpointManager.CommitCheckpointAsync(Checkpoint checkpoint)
private static CheckpointManagerImpl<TStoreObject> CreateImpl<TStoreObject>(
IWireMarshaller<TStoreObject> marshaller,
ICheckpointStore<TStoreObject> store)
{
Throw.IfNull(checkpoint);
this._checkpoints[checkpoint] = checkpoint;
return new(checkpoint);
return new CheckpointManagerImpl<TStoreObject>(marshaller, store);
}
ValueTask<Checkpoint> ICheckpointManager.LookupCheckpointAsync(CheckpointInfo checkpointInfo)
private CheckpointManager(ICheckpointManager impl)
{
Throw.IfNull(checkpointInfo);
if (!this._checkpoints.TryGetValue(checkpointInfo, out Checkpoint? checkpoint))
{
throw new KeyNotFoundException($"Checkpoint not found: {checkpointInfo}");
}
return new ValueTask<Checkpoint>(checkpoint);
this._impl = impl;
}
/// <summary>
/// Creates a new instance of <see cref="ICheckpointManager"/> that uses the specified marshaller and store.
/// </summary>
/// <returns></returns>
public static CheckpointManager CreateInMemory() => new(new InMemoryCheckpointManager());
/// <summary>
/// Gets the default in-memory checkpoint manager instance.
/// </summary>
public static CheckpointManager Default { get; } = CreateInMemory();
/// <summary>
/// Creates a new instance of the CheckpointManager that uses JSON serialization for checkpoint data.
/// </summary>
/// <param name="store">The checkpoint store to use for persisting and retrieving checkpoint data as JSON elements. Cannot be null.</param>
/// <param name="customOptions">Optional custom JSON serializer options to use for serialization and deserialization. Must be provided if
/// using custom types in messages or state.</param>
/// <returns>A CheckpointManager instance configured to serialize checkpoint data as JSON.</returns>
public static CheckpointManager CreateJson(ICheckpointStore<JsonElement> store, JsonSerializerOptions? customOptions = null)
{
JsonMarshaller marshaller = new(customOptions);
return new(CreateImpl(marshaller, store));
}
ValueTask<CheckpointInfo> ICheckpointManager.CommitCheckpointAsync(string runId, Checkpoint checkpoint)
=> this._impl.CommitCheckpointAsync(runId, checkpoint);
ValueTask<Checkpoint> ICheckpointManager.LookupCheckpointAsync(string runId, CheckpointInfo checkpointInfo)
=> this._impl.LookupCheckpointAsync(runId, checkpointInfo);
}
@@ -1,33 +1,40 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.Text.Json.Serialization;
using Microsoft.Agents.Workflows.Execution;
using Microsoft.Shared.Diagnostics;
namespace Microsoft.Agents.Workflows.Checkpointing;
internal class Checkpoint : CheckpointInfo
internal class Checkpoint
{
[JsonConstructor]
internal Checkpoint(
int stepNumber,
WorkflowInfo workflow,
RunnerStateData runnerData,
Dictionary<ScopeKey, ExportedState> stateData,
Dictionary<EdgeConnection, ExportedState> edgeStateData)
Dictionary<ScopeKey, PortableValue> stateData,
Dictionary<EdgeId, PortableValue> edgeStateData,
CheckpointInfo? parent = null)
{
this.StepNumber = Throw.IfLessThan(stepNumber, -1); // -1 is a special flag indicating the initial checkpoint.
this.Workflow = Throw.IfNull(workflow);
this.RunnerData = Throw.IfNull(runnerData);
this.State = Throw.IfNull(stateData);
this.EdgeState = Throw.IfNull(edgeStateData);
this.StateData = Throw.IfNull(stateData);
this.EdgeStateData = Throw.IfNull(edgeStateData);
this.Parent = parent;
}
[JsonIgnore]
public bool IsInitial => this.StepNumber == -1;
public int StepNumber { get; }
public WorkflowInfo Workflow { get; }
public RunnerStateData RunnerData { get; }
public readonly Dictionary<ScopeKey, ExportedState> State = new();
public readonly Dictionary<EdgeConnection, ExportedState> EdgeState = new();
public Dictionary<ScopeKey, PortableValue> StateData { get; } = new();
public Dictionary<EdgeId, PortableValue> EdgeStateData { get; } = new();
public CheckpointInfo? Parent { get; }
}
@@ -0,0 +1,30 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Threading.Tasks;
namespace Microsoft.Agents.Workflows.Checkpointing;
internal sealed class CheckpointManagerImpl<TStoreObject> : ICheckpointManager
{
private readonly IWireMarshaller<TStoreObject> _marshaller;
private readonly ICheckpointStore<TStoreObject> _store;
public CheckpointManagerImpl(IWireMarshaller<TStoreObject> marshaller, ICheckpointStore<TStoreObject> store)
{
this._marshaller = marshaller;
this._store = store;
}
public ValueTask<CheckpointInfo> CommitCheckpointAsync(string runId, Checkpoint checkpoint)
{
TStoreObject storeObject = this._marshaller.Marshal(checkpoint);
return this._store.CreateCheckpointAsync(runId, storeObject, checkpoint.Parent);
}
public async ValueTask<Checkpoint> LookupCheckpointAsync(string runId, CheckpointInfo checkpointInfo)
{
TStoreObject result = await this._store.RetrieveCheckpointAsync(runId, checkpointInfo).ConfigureAwait(false);
return this._marshaller.Marshal<Checkpoint>(result);
}
}
@@ -1,12 +1,29 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Text.Json.Serialization;
using Microsoft.Agents.Workflows.Execution;
namespace Microsoft.Agents.Workflows.Checkpointing;
internal class DirectEdgeInfo(DirectEdgeData data) : EdgeInfo(Edge.Type.Direct, data.Connection)
/// <summary>
/// Represents a direct <see cref="Edge"/> in the <see cref="Workflow"/>.
/// </summary>
public sealed class DirectEdgeInfo : EdgeInfo
{
public bool HasCondition => data.Condition != null;
internal DirectEdgeInfo(DirectEdgeData data) : this(data.Condition != null, data.Connection) { }
protected override bool IsMatchInternal(EdgeData edgeData)
[JsonConstructor]
internal DirectEdgeInfo(bool hasCondition, EdgeConnection connection) : base(EdgeKind.Direct, connection)
{
this.HasCondition = hasCondition;
}
/// <summary>
/// Gets a value indicating whether this direct edge has a condition associated with it.
/// </summary>
public bool HasCondition { get; }
internal override bool IsMatchInternal(EdgeData edgeData)
{
return edgeData is DirectEdgeData directEdge
&& this.HasCondition == (directEdge.Condition != null);
@@ -0,0 +1,30 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
namespace Microsoft.Agents.Workflows.Checkpointing;
/// <summary>
/// Provides support for using <see cref="EdgeId"/> values as dictionary keys when serializing and deserializing JSON.
/// </summary>
internal sealed class EdgeIdConverter : JsonConverterDictionarySupportBase<EdgeId>
{
protected override JsonTypeInfo<EdgeId> TypeInfo => WorkflowsJsonUtilities.JsonContext.Default.EdgeId;
protected override EdgeId Parse(string propertyName)
{
if (int.TryParse(propertyName, out int edgeId))
{
return new(edgeId);
}
throw new JsonException($"Cannot deserialize EdgeId from JSON propery name '{propertyName}'");
}
protected override string Stringify([DisallowNull] EdgeId value)
{
return value.EdgeIndex.ToString();
}
}
@@ -1,21 +1,43 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Text.Json.Serialization;
using Microsoft.Agents.Workflows.Execution;
using Microsoft.Shared.Diagnostics;
namespace Microsoft.Agents.Workflows.Checkpointing;
internal abstract class EdgeInfo(Edge.Type edgeType, EdgeConnection connection)
/// <summary>
/// Base class representing information about an edge in a workflow.
/// </summary>
[JsonPolymorphic(UnknownDerivedTypeHandling = JsonUnknownDerivedTypeHandling.FailSerialization)]
[JsonDerivedType(typeof(DirectEdgeInfo), (int)EdgeKind.Direct)]
[JsonDerivedType(typeof(FanOutEdgeInfo), (int)EdgeKind.FanOut)]
[JsonDerivedType(typeof(FanInEdgeInfo), (int)EdgeKind.FanIn)]
public class EdgeInfo
{
public Edge.Type EdgeType => edgeType;
public EdgeConnection Connection { get; } = Throw.IfNull(connection);
/// <summary>
/// The kind of edge.
/// </summary>
public EdgeKind Kind { get; }
public bool IsMatch(Edge edge)
/// <summary>
/// Gets the connection information associated with the edge.
/// </summary>
public EdgeConnection Connection { get; }
[JsonConstructor]
internal EdgeInfo(EdgeKind kind, EdgeConnection connection)
{
return this.EdgeType == edge.EdgeType
this.Kind = kind;
this.Connection = Throw.IfNull(connection);
}
internal bool IsMatch(Edge edge)
{
return this.Kind == edge.Kind
&& this.Connection.Equals(edge.Data.Connection)
&& this.IsMatchInternal(edge.Data);
}
protected virtual bool IsMatchInternal(EdgeData edgeData) => true;
internal virtual bool IsMatchInternal(EdgeData edgeData) => true;
}
@@ -0,0 +1,38 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using Microsoft.Agents.Workflows.Execution;
namespace Microsoft.Agents.Workflows.Checkpointing;
/// <summary>
/// Provides support for using <see cref="ExecutorIdentity"/> values as dictionary keys when serializing and deserializing JSON.
/// </summary>
internal sealed class ExecutorIdentityConverter() : JsonConverterDictionarySupportBase<ExecutorIdentity>
{
protected override JsonTypeInfo<ExecutorIdentity> TypeInfo
=> WorkflowsJsonUtilities.JsonContext.Default.ExecutorIdentity;
protected override ExecutorIdentity Parse(string propertyName)
{
if (propertyName.Length == 0)
{
return ExecutorIdentity.None;
}
if (propertyName[0] == '@')
{
return new() { Id = propertyName.Substring(1) };
}
throw new JsonException($"Invalid ExecutorIdentity key Expecting empty string or a value that is prefixed with '@'. Got '{propertyName}'");
}
protected override string Stringify(ExecutorIdentity value)
{
return value == ExecutorIdentity.None
? string.Empty
: $"@{value.Id}";
}
}
@@ -1,12 +0,0 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using Microsoft.Shared.Diagnostics;
namespace Microsoft.Agents.Workflows.Checkpointing;
internal class ExportedState(object state)
{
public Type RuntimeType => Throw.IfNull(state).GetType();
public object Value => Throw.IfNull(state);
}
@@ -1,5 +1,21 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Text.Json.Serialization;
using Microsoft.Agents.Workflows.Execution;
namespace Microsoft.Agents.Workflows.Checkpointing;
internal class FanInEdgeInfo(FanInEdgeData data) : EdgeInfo(Edge.Type.FanIn, data.Connection);
/// <summary>
/// Represents a fan-in <see cref="Edge"/> in the <see cref="Workflow"/>.
/// </summary>
public sealed class FanInEdgeInfo : EdgeInfo
{
internal FanInEdgeInfo(FanInEdgeData data) : base(EdgeKind.FanIn, data.Connection)
{
}
[JsonConstructor]
internal FanInEdgeInfo(EdgeConnection connection) : base(EdgeKind.FanIn, connection)
{
}
}
@@ -1,12 +1,29 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Text.Json.Serialization;
using Microsoft.Agents.Workflows.Execution;
namespace Microsoft.Agents.Workflows.Checkpointing;
internal class FanOutEdgeInfo(FanOutEdgeData data) : EdgeInfo(Edge.Type.FanOut, data.Connection)
/// <summary>
/// Represents a fan-out <see cref="Edge"/> in the <see cref="Workflow"/>.
/// </summary>
public sealed class FanOutEdgeInfo : EdgeInfo
{
public bool HasAssigner => data.EdgeAssigner != null;
internal FanOutEdgeInfo(FanOutEdgeData data) : this(data.EdgeAssigner != null, data.Connection) { }
protected override bool IsMatchInternal(EdgeData edgeData)
[JsonConstructor]
internal FanOutEdgeInfo(bool hasAssigner, EdgeConnection connection) : base(EdgeKind.FanOut, connection)
{
this.HasAssigner = hasAssigner;
}
/// <summary>
/// Gets a value indicating whether this fan-out edge has an edge-assigner associated with it.
/// </summary>
public bool HasAssigner { get; }
internal override bool IsMatchInternal(EdgeData edgeData)
{
return edgeData is FanOutEdgeData fanOutEdge
&& this.HasAssigner == (fanOutEdge.EdgeAssigner != null);
@@ -0,0 +1,167 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.Agents.Workflows.Checkpointing;
/// <summary>
/// Provides a file system-based implementation of a JSON checkpoint store that persists checkpoint data and index
/// information to disk using JSON files.
/// </summary>
/// <remarks>This class manages checkpoint storage by writing JSON files to a specified directory and maintaining
/// an index file for efficient retrieval. It is intended for scenarios where durable, process-exclusive checkpoint
/// persistence is required. Instances of this class are not thread-safe and should not be shared across multiple
/// threads without external synchronization. The class implements IDisposable; callers should ensure Dispose is called
/// to release file handles and system resources when the store is no longer needed.</remarks>
public sealed class FileSystemJsonCheckpointStore : JsonCheckpointStore, IDisposable
{
[System.Diagnostics.CodeAnalysis.SuppressMessage("Usage", "CA2213:Disposable fields should be disposed",
Justification = "It is disposed, the analyzer is just not picking it up properly")]
private FileStream? _indexFile;
internal DirectoryInfo Directory { get; }
internal HashSet<CheckpointInfo> CheckpointIndex { get; }
/// <summary>
/// Initializes a new instance of the <see cref="FileSystemJsonCheckpointStore"/> class that uses the specified directory
/// </summary>
/// <param name="directory"></param>
/// <exception cref="ArgumentNullException"></exception>
/// <exception cref="InvalidOperationException"></exception>
public FileSystemJsonCheckpointStore(DirectoryInfo directory)
{
this.Directory = directory ?? throw new ArgumentNullException(nameof(directory));
if (!directory.Exists)
{
directory.Create();
}
try
{
this._indexFile = File.Open(Path.Combine(directory.FullName, "index.jsonl"), FileMode.OpenOrCreate, FileAccess.ReadWrite, FileShare.None);
}
catch
{
throw new InvalidOperationException($"The store at '{directory.FullName}' is already in use by another process.");
}
try
{
// read the lines of indexfile and parse them as CheckpointInfos
this.CheckpointIndex = new HashSet<CheckpointInfo>();
using StreamReader reader = new(this._indexFile, encoding: Encoding.UTF8, detectEncodingFromByteOrderMarks: false, bufferSize: -1, leaveOpen: true);
while (reader.ReadLine() is string line)
{
CheckpointInfo? info = JsonSerializer.Deserialize<CheckpointInfo>(line, this.KeyTypeInfo);
if (info != null)
{
this.CheckpointIndex.Add(info);
}
}
}
catch
{
throw new InvalidOperationException($"Could not load store at '{directory.FullName}'. Index corrupted.");
}
}
/// <inheritdoc/>
public void Dispose()
{
FileStream? indexFileLocal = Interlocked.Exchange(ref this._indexFile, null);
indexFileLocal?.Dispose();
}
[System.Diagnostics.CodeAnalysis.SuppressMessage("Maintainability", "CA1513:Use ObjectDisposedException throw helper",
Justification = "Throw helper does not exist in NetFx 4.7.2")]
private void CheckDisposed()
{
if (this._indexFile == null)
{
throw new ObjectDisposedException($"{nameof(FileSystemJsonCheckpointStore)}({this.Directory.FullName})");
}
}
private string GetFileNameForCheckpoint(string runId, CheckpointInfo key)
=> Path.Combine(this.Directory.FullName, $"{runId}_{key.CheckpointId}.json");
private CheckpointInfo GetUnusedCheckpointInfo(string runId)
{
CheckpointInfo key;
do
{
key = new(runId);
} while (!this.CheckpointIndex.Add(key));
return key;
}
/// <inheritdoc/>
[System.Diagnostics.CodeAnalysis.SuppressMessage("Performance", "CA1835:Prefer the 'Memory'-based overloads for 'ReadAsync' and 'WriteAsync'",
Justification = "Memory-based overload is missing for 4.7.2")]
public override async ValueTask<CheckpointInfo> CreateCheckpointAsync(string runId, JsonElement value, CheckpointInfo? parent = null)
{
this.CheckDisposed();
CheckpointInfo key = this.GetUnusedCheckpointInfo(runId);
string fileName = this.GetFileNameForCheckpoint(runId, key);
try
{
using Stream checkpointStream = File.Open(fileName, FileMode.Create, FileAccess.Write, FileShare.None);
using Utf8JsonWriter jsonWriter = new(checkpointStream, new JsonWriterOptions() { Indented = false });
value.WriteTo(jsonWriter);
JsonSerializer.Serialize(this._indexFile!, key, this.KeyTypeInfo);
byte[] bytes = Encoding.UTF8.GetBytes(Environment.NewLine);
await this._indexFile!.WriteAsync(bytes, 0, bytes.Length, CancellationToken.None).ConfigureAwait(false);
return key;
}
catch (Exception ex)
{
this.CheckpointIndex.Remove(key);
try
{
// try to clean up after ourselves
File.Delete(fileName);
}
catch { }
throw new InvalidOperationException($"Could not create checkpoint in store at '{this.Directory.FullName}'.", ex);
}
}
/// <inheritdoc/>
public override async ValueTask<JsonElement> RetrieveCheckpointAsync(string runId, CheckpointInfo key)
{
this.CheckDisposed();
string fileName = this.GetFileNameForCheckpoint(runId, key);
if (!this.CheckpointIndex.Contains(key) ||
!File.Exists(fileName))
{
throw new KeyNotFoundException($"Checkpoint '{key.CheckpointId}' not found in store at '{this.Directory.FullName}'.");
}
using FileStream checkpointFileStream = File.Open(fileName, FileMode.Open, FileAccess.Read, FileShare.Read);
using JsonDocument document = await JsonDocument.ParseAsync(checkpointFileStream).ConfigureAwait(false);
return document.RootElement.Clone();
}
/// <inheritdoc/>
public override ValueTask<IEnumerable<CheckpointInfo>> RetrieveIndexAsync(string runId, CheckpointInfo? withParent = null)
{
this.CheckDisposed();
return new(this.CheckpointIndex);
}
}
@@ -13,16 +13,18 @@ internal interface ICheckpointManager
/// <summary>
/// Commits the specified checkpoint and returns information that can be used to retrieve it later.
/// </summary>
/// <param name="checkpoint">The <see cref="Checkpoint"/> to be committed.</param>
/// <param name="runId">The identifier for the current run or execution context.</param>
/// <param name="checkpoint">The checkpoint to commit.</param>
/// <returns>A <see cref="CheckpointInfo"/> representing the incoming checkpoint.</returns>
ValueTask<CheckpointInfo> CommitCheckpointAsync(Checkpoint checkpoint);
ValueTask<CheckpointInfo> CommitCheckpointAsync(string runId, Checkpoint checkpoint);
/// <summary>
/// Retrieves the checkpoint associated with the specified checkpoint information.
/// </summary>
/// <param name="runId">The identifier for the current run of execution context.</param>
/// <param name="checkpointInfo">The information used to identify the checkpoint.</param>
/// <returns>A <see cref="ValueTask{TResult}"/> representing the asynchronous operation. The result contains the <see
/// cref="Checkpoint"/> associated with the specified <paramref name="checkpointInfo"/>.</returns>
/// <exception cref="KeyNotFoundException">Thrown if the checkpoint is not found.</exception>
ValueTask<Checkpoint> LookupCheckpointAsync(CheckpointInfo checkpointInfo);
ValueTask<Checkpoint> LookupCheckpointAsync(string runId, CheckpointInfo checkpointInfo);
}
@@ -0,0 +1,49 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.Threading.Tasks;
namespace Microsoft.Agents.Workflows.Checkpointing;
/// <summary>
/// Defines a contract for storing and retrieving checkpoints associated with a specific run and key.
/// </summary>
/// <remarks>Implementations of this interface enable durable or in-memory storage of checkpoints, which can be
/// used to resume or audit long-running processes. The interface is generic to support different storage object types
/// depending on the application's requirements.</remarks>
/// <typeparam name="TStoreObject">The type of object to be stored as the value for each checkpoint.</typeparam>
public interface ICheckpointStore<TStoreObject>
{
/// <summary>
/// Asynchronously retrieves the collection of checkpoint information for the specified run identifier, optionally
/// filtered by a parent checkpoint.
/// </summary>
/// <param name="runId">The unique identifier of the run for which to retrieve checkpoint information. Cannot be null or empty.</param>
/// <param name="withParent">An optional parent checkpoint to filter the results. If specified, only checkpoints with the given parent are
/// returned; otherwise, all checkpoints for the run are included.</param>
/// <returns>A value task representing the asynchronous operation. The result contains a collection of <see
/// cref="CheckpointInfo"/> objects associated with the specified run. The collection is empty if no checkpoints are
/// found.</returns>
ValueTask<IEnumerable<CheckpointInfo>> RetrieveIndexAsync(string runId, CheckpointInfo? withParent = null);
/// <summary>
/// Asynchronously creates a checkpoint for the specified run and key, associating it with the provided value and
/// optional parent checkpoint.
/// </summary>
/// <param name="runId">The unique identifier of the run for which the checkpoint is being created. Cannot be null or empty.</param>
/// <param name="value">The value to associate with the checkpoint. Cannot be null.</param>
/// <param name="parent">The optional parent checkpoint information. If specified, the new checkpoint will be linked as a child of this
/// parent.</param>
/// <returns>A ValueTask that represents the asynchronous operation. The result contains the <see cref="CheckpointInfo"/>
/// object representing this stored checkpoint.</returns>
ValueTask<CheckpointInfo> CreateCheckpointAsync(string runId, TStoreObject value, CheckpointInfo? parent = null);
/// <summary>
/// Asynchronously retrieves a checkpoint object associated with the specified run and checkpoint key.
/// </summary>
/// <param name="runId">The unique identifier of the run for which the checkpoint is to be retrieved. Cannot be null or empty.</param>
/// <param name="key">The key identifying the specific checkpoint to retrieve. Cannot be null.</param>
/// <returns>A ValueTask that represents the asynchronous operation. The result contains the checkpoint object associated
/// with the specified run and key.</returns>
ValueTask<TStoreObject> RetrieveCheckpointAsync(string runId, CheckpointInfo key);
}
@@ -0,0 +1,27 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
namespace Microsoft.Agents.Workflows.Checkpointing;
/// <summary>
/// Implements an abstraction across serialization mechanisms to represent a lazily-deserialized value.
///
/// This can be used when the target-type information is not known at time of initial deserialization.
/// </summary>
internal interface IDelayedDeserialization
{
/// <summary>
/// Attempt to deserialize the value as the provided type.
/// </summary>
/// <typeparam name="TValue"></typeparam>
/// <returns></returns>
TValue Deserialize<TValue>();
/// <summary>
/// Attempt to deserialize the value as the provided type.
/// </summary>
/// <param name="targetType"></param>
/// <returns></returns>
object? Deserialize(Type targetType);
}
@@ -0,0 +1,44 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
namespace Microsoft.Agents.Workflows.Checkpointing;
/// <summary>
/// Defines methods for marshalling and unmarshalling objects to and from a wire format.
/// </summary>
/// <typeparam name="TWireContainer"></typeparam>
public interface IWireMarshaller<TWireContainer>
{
/// <summary>
/// Marshals the specified value of the given type into a wire format container.
/// </summary>
/// <param name="value"></param>
/// <param name="type"></param>
/// <returns></returns>
TWireContainer Marshal(object value, Type type);
/// <summary>
/// Marshals the specified value into a wire format container.
/// </summary>
/// <typeparam name="TValue"></typeparam>
/// <param name="value"></param>
/// <returns></returns>
TWireContainer Marshal<TValue>(TValue value);
/// <summary>
/// Unmarshals the specified wire format container into an object of the given type.
/// </summary>
/// <typeparam name="TValue"></typeparam>
/// <param name="data"></param>
/// <returns></returns>
TValue Marshal<TValue>(TWireContainer data);
/// <summary>
/// Unmarshals the specified wire format container into an object of the specified target type.
/// </summary>
/// <param name="targetType"></param>
/// <param name="data"></param>
/// <returns></returns>
object Marshal(Type targetType, TWireContainer data);
}
@@ -0,0 +1,47 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.Threading.Tasks;
namespace Microsoft.Agents.Workflows.Checkpointing;
/// <summary>
/// An in-memory implementation of <see cref="ICheckpointManager"/> that stores checkpoints in a dictionary.
/// </summary>
internal sealed class InMemoryCheckpointManager : ICheckpointManager
{
private readonly Dictionary<string, RunCheckpointCache<Checkpoint>> _store = new();
private RunCheckpointCache<Checkpoint> GetRunStore(string runId)
{
if (!this._store.TryGetValue(runId, out RunCheckpointCache<Checkpoint>? runStore))
{
runStore = this._store[runId] = new();
}
return runStore;
}
public ValueTask<CheckpointInfo> CommitCheckpointAsync(string runId, Checkpoint checkpoint)
{
RunCheckpointCache<Checkpoint> runStore = this.GetRunStore(runId);
CheckpointInfo key;
do
{
key = new(runId);
} while (!runStore.Add(key, checkpoint));
return new(key);
}
public ValueTask<Checkpoint> LookupCheckpointAsync(string runId, CheckpointInfo checkpointInfo)
{
if (!this.GetRunStore(runId).TryGet(checkpointInfo, out Checkpoint? value))
{
throw new KeyNotFoundException($"Could not retrieve checkpoint with id {checkpointInfo.CheckpointId} for run {runId}");
}
return new(value);
}
}
@@ -2,4 +2,12 @@
namespace Microsoft.Agents.Workflows.Checkpointing;
internal record class InputPortInfo(TypeId InputType, TypeId OutputType, string PortId);
/// <summary>
/// Information about an input port, including its input and output types.
/// </summary>
/// <param name="RequestType"></param>
/// <param name="ResponseType"></param>
/// <param name="PortId"></param>
public record class InputPortInfo(TypeId RequestType, TypeId ResponseType, string PortId)
{
}
@@ -0,0 +1,28 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using System.Threading.Tasks;
namespace Microsoft.Agents.Workflows.Checkpointing;
/// <summary>
/// An abstract base class for checkpoint stores that use JSON for serialization.
/// </summary>
public abstract class JsonCheckpointStore : ICheckpointStore<JsonElement>
{
/// <summary>
/// A default TypeInfo for serializing the <see cref="CheckpointInfo"/> type, if needed.
/// </summary>
protected JsonTypeInfo<CheckpointInfo> KeyTypeInfo => WorkflowsJsonUtilities.JsonContext.Default.CheckpointInfo;
/// <inheritdoc/>
public abstract ValueTask<CheckpointInfo> CreateCheckpointAsync(string runId, JsonElement value, CheckpointInfo? parent = null);
/// <inheritdoc/>
public abstract ValueTask<JsonElement> RetrieveCheckpointAsync(string runId, CheckpointInfo key);
/// <inheritdoc/>
public abstract ValueTask<IEnumerable<CheckpointInfo>> RetrieveIndexAsync(string runId, CheckpointInfo? withParent = null);
}
@@ -0,0 +1,35 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
namespace Microsoft.Agents.Workflows.Checkpointing;
/// <summary>
/// Provides support for JSON serialization and deserialization using a specified JsonTypeInfo.
/// </summary>
/// <typeparam name="T"></typeparam>
internal abstract class JsonConverterBase<T> : JsonConverter<T>
{
protected abstract JsonTypeInfo<T> TypeInfo { get; }
public override T? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
SequencePosition position = reader.Position;
T? maybeValue = JsonSerializer.Deserialize<T>(ref reader, this.TypeInfo);
if (maybeValue is null)
{
throw new JsonException($"Could not deserialize a {typeof(T).Name} from JSON at position {position}");
}
return maybeValue;
}
public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options)
{
JsonSerializer.Serialize(writer, value, this.TypeInfo);
}
}
@@ -0,0 +1,39 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
namespace Microsoft.Agents.Workflows.Checkpointing;
/// <summary>
/// Provides support for using <typeparamref name="T"/> values as dictionary keys when serializing and deserializing JSON.
/// It chains to the provided <see cref="JsonTypeInfo{T}"/> for serialization and deserialization when not used as a property
/// name.
/// </summary>
/// <typeparam name="T"></typeparam>
internal abstract class JsonConverterDictionarySupportBase<T> : JsonConverterBase<T>
{
protected abstract string Stringify([DisallowNull] T value);
protected abstract T Parse(string propertyName);
public override T ReadAsPropertyName(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
SequencePosition position = reader.Position;
string? propertyName = reader.GetString();
if (propertyName == null)
{
throw new JsonException($"Got null trying to read property name at position {position}");
}
return this.Parse(propertyName);
}
public override void WriteAsPropertyName(Utf8JsonWriter writer, [DisallowNull] T value, JsonSerializerOptions options)
{
string propertyName = this.Stringify(value);
writer.WritePropertyName(propertyName);
}
}
@@ -0,0 +1,79 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
namespace Microsoft.Agents.Workflows.Checkpointing;
internal class JsonMarshaller : IWireMarshaller<JsonElement>
{
private readonly JsonSerializerOptions _internalOptions;
private readonly JsonSerializerOptions? _externalOptions;
public JsonMarshaller(JsonSerializerOptions? serializerOptions = null)
{
this._internalOptions = new JsonSerializerOptions(WorkflowsJsonUtilities.DefaultOptions);
this._internalOptions.Converters.Add(new PortableValueConverter(this));
this._internalOptions.Converters.Add(new ExecutorIdentityConverter());
this._internalOptions.Converters.Add(new ScopeKeyConverter());
this._internalOptions.Converters.Add(new EdgeIdConverter());
this._externalOptions = serializerOptions;
}
private JsonTypeInfo LookupTypeInfo(Type type)
{
if (!this._internalOptions.TryGetTypeInfo(type, out JsonTypeInfo? typeInfo))
{
if (this._externalOptions == null ||
!this._externalOptions.TryGetTypeInfo(type, out typeInfo))
{
throw new InvalidOperationException($"No JSON type info is available for type '{type}'.");
}
}
return typeInfo;
}
public JsonElement Marshal(object value, Type type)
=> JsonSerializer.SerializeToElement(value, this.LookupTypeInfo(type));
public JsonElement Marshal<TValue>(TValue value)
=> JsonSerializer.SerializeToElement(value, this.LookupTypeInfo(typeof(TValue)));
public TValue Marshal<TValue>(JsonElement data)
{
Type type = typeof(TValue);
object? value = JsonSerializer.Deserialize(data, this.LookupTypeInfo(type));
if (value is null)
{
throw new InvalidOperationException($"Could not deserialize the value as the expected type {typeof(TValue)}.");
}
if (value is TValue typedValue)
{
return typedValue;
}
throw new InvalidOperationException($"Deserialized value is not of the expected type {typeof(TValue)}.");
}
public object Marshal(Type targetType, JsonElement data)
{
object? value = JsonSerializer.Deserialize(data, this.LookupTypeInfo(targetType));
if (value is null)
{
throw new InvalidOperationException($"Could not deserialize the value as the expected type {targetType}.");
}
if (targetType.IsInstanceOfType(value))
{
return value;
}
throw new InvalidOperationException($"Deserialized value is not of the expected type {targetType}.");
}
}
@@ -0,0 +1,63 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Text.Json;
namespace Microsoft.Agents.Workflows.Checkpointing;
/// <summary>
/// Represents a value serialized to the JSON format (<see cref="JsonMarshaller"/>).
/// When type information is not available during deserialization, this will wrap a clone of the
/// <see cref="JsonElement"/> to be deserialized later.
/// </summary>
/// <param name="serializer"></param>
/// <param name="data"></param>
/// <seealso cref="PortableValue"/>
internal sealed class JsonWireSerializedValue(JsonMarshaller serializer, JsonElement data) : IDelayedDeserialization
{
internal JsonElement Data { get; } = data.Clone();
public TValue Deserialize<TValue>() => serializer.Marshal<TValue>(data);
public object? Deserialize(Type targetType) => serializer.Marshal(targetType, data);
public override bool Equals(object? obj)
{
if (obj == null)
{
return false;
}
if (obj is JsonWireSerializedValue otherValue)
{
return JsonElement.DeepEquals(this.Data, otherValue.Data);
}
else if (obj is JsonElement element)
{
return this.Data.Equals(element);
}
else if (obj is not IDelayedDeserialization)
{
// Assume this has the target type of deserialization; serialize it using the explicit type
// and compare. Of course, this also means that if this is a supertype, it could encounter
// truncation.
try
{
JsonElement otherElement = serializer.Marshal(obj, obj.GetType());
return JsonElement.DeepEquals(this.Data, otherElement);
}
catch
{
return false;
}
}
return false;
}
public override int GetHashCode()
{
return this.Data.GetHashCode();
}
}
@@ -0,0 +1,33 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Text.Json.Serialization;
using Microsoft.Agents.Workflows.Execution;
namespace Microsoft.Agents.Workflows.Checkpointing;
internal sealed class PortableMessageEnvelope
{
public TypeId MessageType { get; }
public PortableValue Message { get; }
public string? TargetId { get; }
[JsonConstructor]
internal PortableMessageEnvelope(TypeId messageType, PortableValue message, string? targetId)
{
this.MessageType = messageType;
this.Message = message;
this.TargetId = targetId;
}
public PortableMessageEnvelope(MessageEnvelope envelope)
{
this.MessageType = envelope.MessageType;
this.Message = new PortableValue(envelope.Message);
this.TargetId = envelope.TargetId;
}
public MessageEnvelope ToMessageEnvelope()
{
return new MessageEnvelope(this.Message, this.MessageType, this.TargetId);
}
}
@@ -0,0 +1,71 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
namespace Microsoft.Agents.Workflows.Checkpointing;
/// <summary>
/// Provides special handling for <see cref="PortableValue"/> serialization and deserialization, enabling delayed deserialization
/// of the inner value. This is used to enable serialization/deserialization of objects whose type information is not available
/// at the time of initial deserialization, e.g. user-defined state types.
///
/// This operates in conjuction with <see cref="IDelayedDeserialization"/> and <see cref="PortableValue"/> to abstract
/// away the speicfics of a given serialization format in favor of <see cref="PortableValue.As{TValue}"/> and
/// <see cref="PortableValue.Is{TValue}"/>.
/// </summary>
/// <param name="marshaller"></param>
internal sealed class PortableValueConverter(JsonMarshaller marshaller) : JsonConverter<PortableValue>
{
public override PortableValue? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
SequencePosition initial = reader.Position;
JsonTypeInfo<PortableValue> baseTypeInfo = WorkflowsJsonUtilities.JsonContext.Default.PortableValue;
PortableValue? maybeValue = JsonSerializer.Deserialize<PortableValue>(ref reader, baseTypeInfo);
if (maybeValue is null)
{
throw new JsonException($"Could not deserialize a PortableValue from JSON at position {initial}.");
}
else if (maybeValue.Value is JsonElement element)
{
// This happens when we do not have the type information available to deserialize the value directly.
// We need to wrap it in a JsonWireSerializedValue so that we can deserialize it
return new PortableValue(maybeValue.TypeId, new JsonWireSerializedValue(marshaller, element));
}
else if (maybeValue.TypeId.IsMatch(maybeValue.Value.GetType()))
{
return maybeValue;
}
throw new JsonException($"Deserialized PortableValue contains a value of type {maybeValue.Value.GetType()} which does not match the expected type {maybeValue.TypeId} at position {initial}.");
}
public override void Write(Utf8JsonWriter writer, PortableValue value, JsonSerializerOptions options)
{
PortableValue proxyValue;
if (value.IsDelayedDeserialization && !value.IsDeserialized)
{
if (value.Value is JsonWireSerializedValue jsonWireValue)
{
proxyValue = new(value.TypeId, jsonWireValue.Data);
}
else
{
// Users should never see this unless they're trying to cross wire formats
throw new InvalidOperationException("Cannot serialize a PortableValue that has not been deserialized. Please deserialize it with .As/AsType() or Is/IsType() methods first.");
}
}
else
{
JsonElement element = marshaller.Marshal(value.Value, value.Value.GetType());
proxyValue = new(value.TypeId, element);
}
JsonTypeInfo<PortableValue> baseTypeInfo = WorkflowsJsonUtilities.JsonContext.Default.PortableValue;
JsonSerializer.Serialize(writer, proxyValue, baseTypeInfo);
}
}
@@ -18,12 +18,12 @@ internal static class RepresentationExtensions
public static EdgeInfo ToEdgeInfo(this Edge edge)
{
Throw.IfNull(edge);
return edge.EdgeType switch
return edge.Kind switch
{
Edge.Type.Direct => new DirectEdgeInfo(edge.DirectEdgeData!),
Edge.Type.FanOut => new FanOutEdgeInfo(edge.FanOutEdgeData!),
Edge.Type.FanIn => new FanInEdgeInfo(edge.FanInEdgeData!),
_ => throw new NotSupportedException($"Unsupported edge type: {edge.EdgeType}")
EdgeKind.Direct => new DirectEdgeInfo(edge.DirectEdgeData!),
EdgeKind.FanOut => new FanOutEdgeInfo(edge.FanOutEdgeData!),
EdgeKind.FanIn => new FanInEdgeInfo(edge.FanInEdgeData!),
_ => throw new NotSupportedException($"Unsupported edge type: {edge.Kind}")
};
}
@@ -54,6 +54,6 @@ internal static class RepresentationExtensions
public static WorkflowInfo ToWorkflowInfo<TInput>(this Workflow<TInput> workflow)
=> workflow.ToWorkflowInfo(outputType: null, outputExecutorId: null);
public static WorkflowInfo GetInfo<TInput, TResult>(this Workflow<TInput, TResult> workflow)
public static WorkflowInfo ToWorkflowInfo<TInput, TResult>(this Workflow<TInput, TResult> workflow)
=> workflow.ToWorkflowInfo(outputType: new TypeId(typeof(TResult)), outputExecutorId: workflow.OutputCollectorId);
}
@@ -0,0 +1,40 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
namespace Microsoft.Agents.Workflows.Checkpointing;
internal sealed class RunCheckpointCache<TStoreObject>
{
private readonly HashSet<CheckpointInfo> _checkpointIndex = new();
private readonly Dictionary<CheckpointInfo, TStoreObject> _cache = new();
public IEnumerable<CheckpointInfo> Index => this._checkpointIndex;
public bool IsInIndex(CheckpointInfo key) => this._checkpointIndex.Contains(key);
public bool TryGet(CheckpointInfo key, [MaybeNullWhen(false)] out TStoreObject value) => this._cache.TryGetValue(key, out value);
public CheckpointInfo Add(string runId, TStoreObject value)
{
CheckpointInfo key;
do
{
key = new(runId);
} while (!this.Add(key, value));
return key;
}
public bool Add(CheckpointInfo key, TStoreObject value)
{
bool added = this._checkpointIndex.Add(key);
if (added)
{
this._cache[key] = value;
}
return added;
}
}
@@ -0,0 +1,93 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using System.Text.RegularExpressions;
namespace Microsoft.Agents.Workflows.Checkpointing;
/// <summary>
/// Provides support for using <see cref="ScopeKey"/> values as dictionary keys when serializing and deserializing JSON.
/// </summary>
internal sealed class ScopeKeyConverter : JsonConverterDictionarySupportBase<ScopeKey>
{
protected override JsonTypeInfo<ScopeKey> TypeInfo => WorkflowsJsonUtilities.JsonContext.Default.ScopeKey;
public static readonly Regex ScopeKeyPropertyNamePattern =
new(@"^(?<executorId>(((\|\|)|([^\|]))*))\|(?<scopeName>(@(((\|\|)|([^\|]))*))?)\|(?<key>(((\|\|)|([^\|]))*)?)$",
RegexOptions.Compiled | RegexOptions.CultureInvariant | RegexOptions.ExplicitCapture);
protected override ScopeKey Parse(string propertyName)
{
Match scopeKeyPatternMatch = ScopeKeyPropertyNamePattern.Match(propertyName);
if (!scopeKeyPatternMatch.Success)
{
throw new JsonException($"Invalid ScopeKey property name format. Got '{propertyName}'.");
}
string executorId = scopeKeyPatternMatch.Groups["executorId"].Value;
string scopeName = scopeKeyPatternMatch.Groups["scopeName"].Value;
string key = scopeKeyPatternMatch.Groups["key"].Value;
return new ScopeKey(Unescape(executorId)!,
Unescape(scopeName, allowNullAndPad: true),
Unescape(key)!);
}
[return: NotNull]
private static string Escape(string? value, bool allowNullAndPad = false, [CallerArgumentExpression("value")] string componentName = "ScopeKey")
{
if (!allowNullAndPad && value == null)
{
throw new JsonException($"Invalid {componentName} '{value}'. Expecting non-null string.");
}
if (value == null)
{
return string.Empty;
}
if (allowNullAndPad)
{
return $"@{value.Replace("|", "||")}";
}
return $"{value.Replace("|", "||")}";
}
private static string? Unescape([DisallowNull] string value, bool allowNullAndPad = false, [CallerArgumentExpression("value")] string componentName = "ScopeKey")
{
if (value.Length == 0)
{
if (!allowNullAndPad)
{
throw new JsonException($"Invalid {componentName} '{value}'. Expecting empty string or a value that is prefixed with '@'.");
}
return null;
}
if (allowNullAndPad && value[0] != '@')
{
throw new JsonException($"Invalid {componentName} component '{value}'. Expecting empty string or a value that is prefixed with '@'.");
}
if (allowNullAndPad)
{
value = value.Substring(1);
}
return value.Replace("||", "|");
}
protected override string Stringify([DisallowNull] ScopeKey value)
{
string? executorIdEscaped = Escape(value.ScopeId.ExecutorId);
string? scopeNameEscaped = Escape(value.ScopeId.ScopeName, allowNullAndPad: true);
string? keyEscaped = Escape(value.Key);
return $"{executorIdEscaped}|{scopeNameEscaped}|{keyEscaped}";
}
}
@@ -1,20 +1,102 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Text.Json.Serialization;
using Microsoft.Shared.Diagnostics;
namespace Microsoft.Agents.Workflows.Checkpointing;
internal class TypeId(Type type)
/// <summary>
/// A representation of a type's identity, including its assembly and type names.
/// </summary>
public class TypeId
{
public string AssemblyName => Throw.IfNull(type.Assembly.FullName);
public string TypeName => Throw.IfNull(type.FullName);
/// <inheritdoc cref="System.Reflection.Assembly.FullName"/>
public string AssemblyName { get; }
/// <inheritdoc cref="Type.FullName"/>
public string TypeName { get; }
/// <summary>
/// Initializes a new instance of the <see cref="TypeId"/> class.
/// </summary>
/// <param name="assemblyName"></param>
/// <param name="typeName"></param>
[JsonConstructor]
public TypeId(string assemblyName, string typeName)
{
this.AssemblyName = Throw.IfNull(assemblyName);
this.TypeName = Throw.IfNull(typeName);
}
/// <summary>
/// Initializes a new instance of the TypeId class using the specified type.
/// </summary>
/// <param name="type">The type for which to create a unique identifier. Cannot be null.</param>
public TypeId(Type type)
: this(
Throw.IfNullOrMemberNull(type.Assembly,
type.Assembly.FullName),
Throw.IfMemberNull(type,
type.FullName))
{ }
/// <inheritdoc />
public override bool Equals(object? obj)
=> obj is TypeId other
&& this.AssemblyName == other.AssemblyName
&& this.TypeName == other.TypeName;
/// <inheritdoc />
public override int GetHashCode() => HashCode.Combine(this.AssemblyName, this.TypeName);
/// <inheritdoc />
public static bool operator ==(TypeId? left, TypeId? right) => object.ReferenceEquals(left, right) || (!object.ReferenceEquals(left, null) && left.Equals(right));
/// <inheritdoc />
public static bool operator !=(TypeId? left, TypeId? right) => !(left == right);
/// <summary>
/// Determines whether the specified type matches both the assembly name and type name represented by this instance.
/// </summary>
/// <param name="type">The type to compare against the stored assembly and type names. Cannot be null.</param>
/// <returns>true if the specified type's assembly and type names are equal to those stored in this instance; otherwise,
/// false.</returns>
public bool IsMatch(Type type)
{
return this.AssemblyName == type.Assembly.FullName
&& this.TypeName == type.FullName;
}
/// <summary>
/// Determines whether the current instance matches the specified type parameter.
/// </summary>
/// <typeparam name="T">The type to compare against the current instance.</typeparam>
/// <returns>true if the current instance matches the specified type; otherwise, false.</returns>
public bool IsMatch<T>() => this.IsMatch(typeof(T));
/// <summary>
/// Determines whether the specified type or any of its base types match the criteria defined by this instance.
/// </summary>
/// <param name="type">The type to evaluate for a match, including its inheritance hierarchy.</param>
/// <returns>true if the specified type or any of its base types satisfy the match criteria; otherwise, false.</returns>
public bool IsMatchPolymorphic(Type type)
{
Type? candidateType = type;
while (candidateType != null)
{
if (this.IsMatch(candidateType))
{
return true;
}
candidateType = candidateType.BaseType;
}
return false;
}
/// <inheritdoc/>
public override string ToString() => $"{this.TypeName}, {this.AssemblyName}";
}
@@ -3,20 +3,22 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json.Serialization;
using Microsoft.Shared.Diagnostics;
namespace Microsoft.Agents.Workflows.Checkpointing;
internal class WorkflowInfo
{
[JsonConstructor]
internal WorkflowInfo(
Dictionary<string, ExecutorInfo> executors,
Dictionary<string, List<EdgeInfo>> edges,
HashSet<InputPortInfo> inputPorts,
TypeId inputType,
string startExecutorId,
TypeId? outputType = null,
string? outputCollectorId = null)
TypeId? outputType,
string? outputCollectorId)
{
this.Executors = Throw.IfNull(executors);
this.Edges = Throw.IfNull(edges);
@@ -93,8 +95,8 @@ internal class WorkflowInfo
if (workflow.Ports.Count != this.InputPorts.Count ||
this.InputPorts.Any(portInfo =>
!workflow.Ports.TryGetValue(portInfo.PortId, out InputPort? port) ||
!portInfo.InputType.IsMatch(port.Request) ||
!portInfo.OutputType.IsMatch(port.Response)))
!portInfo.RequestType.IsMatch(port.Request) ||
!portInfo.ResponseType.IsMatch(port.Response)))
{
return false;
}
@@ -9,27 +9,32 @@ namespace Microsoft.Agents.Workflows;
/// Represents a directed edge between two nodes, optionally associated with a condition that determines whether the
/// edge is active.
/// </summary>
/// <param name="sourceId">The id of the source executor node.</param>
/// <param name="sinkId">The id of the target executor node.</param>
/// <param name="condition">A predicate determining whether the edge is active for a given message.</param>
public sealed class DirectEdgeData(string sourceId, string sinkId, PredicateT? condition = null) : EdgeData
public sealed class DirectEdgeData : EdgeData
{
internal DirectEdgeData(string sourceId, string sinkId, EdgeId id, PredicateT? condition = null) : base(id)
{
this.SourceId = sourceId;
this.SinkId = sinkId;
this.Condition = condition;
this.Connection = new([sourceId], [sinkId]);
}
/// <summary>
/// The Id of the source <see cref="Executor"/> node.
/// </summary>
public string SourceId => sourceId;
public string SourceId { get; }
/// <summary>
/// The Id of the destination <see cref="Executor"/> node.
/// </summary>
public string SinkId => sinkId;
public string SinkId { get; }
/// <summary>
/// An optional predicate determining whether the edge is active for a given message. If <see langword="null"/>,
/// the edge is always active when a message is generated by the source.
/// </summary>
public PredicateT? Condition => condition;
public PredicateT? Condition { get; }
/// <inheritdoc />
internal override EdgeConnection Connection { get; } = new([sourceId], [sinkId]);
internal override EdgeConnection Connection { get; }
}
+27 -27
View File
@@ -4,43 +4,43 @@ using Microsoft.Shared.Diagnostics;
namespace Microsoft.Agents.Workflows;
/// <summary>
/// Specified the edge type.
/// </summary>
public enum EdgeKind
{
/// <summary>
/// A direct connection from one node to another.
/// </summary>
Direct,
/// <summary>
/// A connection from one node to a set of nodes.
/// </summary>
FanOut,
/// <summary>
/// A connection from a set of nodes to a single node.
/// </summary>
FanIn
}
/// <summary>
/// Represents a connection or relationship between nodes, characterized by its type and associated data.
/// </summary>
/// <remarks>
/// An <see cref="Edge"/> can be of type <see cref="Type.Direct"/>, <see cref="Type.FanOut"/>, or <see
/// cref="Type.FanIn"/>, as specified by the <see cref="EdgeType"/> property. The <see cref="Data"/> property holds
/// An <see cref="Edge"/> can be of type <see cref="EdgeKind.Direct"/>, <see cref="EdgeKind.FanOut"/>, or <see
/// cref="EdgeKind.FanIn"/>, as specified by the <see cref="Kind"/> property. The <see cref="Data"/> property holds
/// additional information relevant to the edge, and its concrete type depends on the value of <see
/// cref="EdgeType"/>, functioning as a tagged union.
/// cref="Kind"/>, functioning as a tagged union.
/// </remarks>
public sealed class Edge
{
/// <summary>
/// Specified the edge type.
/// </summary>
public enum Type
{
/// <summary>
/// A direct connection from one node to another.
/// </summary>
Direct,
/// <summary>
/// A connection from one node to a set of nodes.
/// </summary>
FanOut,
/// <summary>
/// A connection from a set of nodes to a single node.
/// </summary>
FanIn
}
/// <summary>
/// Specifies the type of the edge, which determines how the edge is processed in the workflow.
/// </summary>
public Type EdgeType { get; init; }
public EdgeKind Kind { get; init; }
/// <summary>
/// The <see cref="Type"/>-dependent edge data.
/// The <see cref="EdgeKind"/>-dependent edge data.
/// </summary>
/// <seealso cref="DirectEdgeData"/>
/// <seealso cref="FanOutEdgeData"/>
@@ -51,21 +51,21 @@ public sealed class Edge
{
this.Data = Throw.IfNull(data);
this.EdgeType = Type.Direct;
this.Kind = EdgeKind.Direct;
}
internal Edge(FanOutEdgeData data)
{
this.Data = Throw.IfNull(data);
this.EdgeType = Type.FanOut;
this.Kind = EdgeKind.FanOut;
}
internal Edge(FanInEdgeData data)
{
this.Data = Throw.IfNull(data);
this.EdgeType = Type.FanIn;
this.Kind = EdgeKind.FanIn;
}
internal DirectEdgeData? DirectEdgeData => this.Data as DirectEdgeData;
@@ -13,4 +13,11 @@ public abstract class EdgeData
/// Gets the connection representation of the edge.
/// </summary>
internal abstract EdgeConnection Connection { get; }
internal EdgeData(EdgeId id)
{
this.Id = id;
}
internal EdgeId Id { get; }
}
@@ -0,0 +1,59 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Text.Json.Serialization;
namespace Microsoft.Agents.Workflows;
/// <summary>
/// A unique identifier of an <see cref="Edge"/> within a <see cref="Workflow"/>.
/// </summary>
public readonly struct EdgeId : IEquatable<EdgeId>
{
[JsonConstructor]
internal EdgeId(int edgeIndex)
{
this.EdgeIndex = edgeIndex;
}
internal int EdgeIndex { get; }
/// <inheritdoc />
public override bool Equals(object? obj)
{
if (obj == null)
{
return false;
}
if (obj is EdgeId edgeId)
{
return this.EdgeIndex == edgeId.EdgeIndex;
}
if (obj is int edgeIndex)
{
return this.EdgeIndex == edgeIndex;
}
return false;
}
/// <inheritdoc />
public bool Equals(EdgeId other)
{
return this.EdgeIndex == other.EdgeIndex;
}
/// <inheritdoc />
public override int GetHashCode()
{
return this.EdgeIndex.GetHashCode();
}
/// <inheritdoc />
public static bool operator ==(EdgeId left, EdgeId right) => left.Equals(right);
/// <inheritdoc />
public static bool operator !=(EdgeId left, EdgeId right) => !left.Equals(right);
}
@@ -15,10 +15,19 @@ namespace Microsoft.Agents.Workflows.Execution;
/// Ordering is relevant because in at least one case, the order of sinks is significant for the execution of
/// the edge: <see cref="FanOutEdgeData"/>.
/// </remarks>
/// <param name="sourceIds">An ordered list of unique identifiers of the sources connected by this edge.</param>
/// <param name="sinkIds">An ordered list of unique identifiers of the sinks connected by this edge.</param>
public class EdgeConnection(List<string> sourceIds, List<string> sinkIds) : IEquatable<EdgeConnection>
public class EdgeConnection : IEquatable<EdgeConnection>
{
/// <summary>
/// Create an <see cref="EdgeConnection"/> instance with the specified source and sink IDs.
/// </summary>
/// <param name="sourceIds">An ordered list of unique identifiers of the sources connected by this edge.</param>
/// <param name="sinkIds">An ordered list of unique identifiers of the sinks connected by this edge.</param>
public EdgeConnection(List<string> sourceIds, List<string> sinkIds)
{
this.SourceIds = Throw.IfNull(sourceIds);
this.SinkIds = Throw.IfNull(sinkIds);
}
/// <summary>
/// Creates a new <see cref="EdgeConnection"/> instance with the specified source and sink IDs, ensuring that all
/// IDs are unique.
@@ -82,13 +91,27 @@ public class EdgeConnection(List<string> sourceIds, List<string> sinkIds) : IEqu
);
}
/// <inheritdoc />
public static bool operator ==(EdgeConnection? left, EdgeConnection? right)
{
if (left is null)
{
return right is null;
}
return left.Equals(right);
}
/// <inheritdoc />
public static bool operator !=(EdgeConnection? left, EdgeConnection? right) => !(left == right);
/// <summary>
/// The unique identifiers of the sources connected by this edge.
/// </summary>
public List<string> SourceIds { get; } = sourceIds;
public List<string> SourceIds { get; }
/// <summary>
/// The unique identifiers of the sinks connected by this edge.
/// </summary>
public List<string> SinkIds { get; } = sinkIds;
public List<string> SinkIds { get; }
}
@@ -10,8 +10,8 @@ namespace Microsoft.Agents.Workflows.Execution;
internal class EdgeMap
{
private readonly Dictionary<EdgeConnection, object> _edgeRunners = new();
private readonly Dictionary<EdgeConnection, FanInEdgeState> _fanInState = new();
private readonly Dictionary<EdgeId, object> _edgeRunners = new();
private readonly Dictionary<EdgeId, FanInEdgeState> _fanInState = new();
private readonly Dictionary<string, InputEdgeRunner> _portEdgeRunners;
private readonly InputEdgeRunner _inputRunner;
private readonly IStepTracer? _stepTracer;
@@ -24,20 +24,20 @@ internal class EdgeMap
{
foreach (Edge edge in workflowEdges.Values.SelectMany(e => e))
{
object edgeRunner = edge.EdgeType switch
object edgeRunner = edge.Kind switch
{
Edge.Type.Direct => new DirectEdgeRunner(runContext, edge.DirectEdgeData!),
Edge.Type.FanOut => new FanOutEdgeRunner(runContext, edge.FanOutEdgeData!),
Edge.Type.FanIn => new FanInEdgeRunner(runContext, edge.FanInEdgeData!),
_ => throw new NotSupportedException($"Unsupported edge type: {edge.EdgeType}")
EdgeKind.Direct => new DirectEdgeRunner(runContext, edge.DirectEdgeData!),
EdgeKind.FanOut => new FanOutEdgeRunner(runContext, edge.FanOutEdgeData!),
EdgeKind.FanIn => new FanInEdgeRunner(runContext, edge.FanInEdgeData!),
_ => throw new NotSupportedException($"Unsupported edge type: {edge.Kind}")
};
if (edgeRunner is FanInEdgeRunner fanInRunner)
{
this._fanInState[edge.Data.Connection] = fanInRunner.CreateState();
this._fanInState[edge.Data.Id] = fanInRunner.CreateState();
}
this._edgeRunners[edge.Data.Connection] = edgeRunner;
this._edgeRunners[edge.Data.Id] = edgeRunner;
}
this._portEdgeRunners = workflowPorts.ToDictionary(
@@ -51,14 +51,14 @@ internal class EdgeMap
public async ValueTask<IEnumerable<object?>> InvokeEdgeAsync(Edge edge, string sourceId, MessageEnvelope message)
{
EdgeConnection connection = edge.Data.Connection;
if (!this._edgeRunners.TryGetValue(connection, out object? edgeRunner))
EdgeId id = edge.Data.Id;
if (!this._edgeRunners.TryGetValue(id, out object? edgeRunner))
{
throw new InvalidOperationException($"Edge {edge} not found in the edge map.");
}
IEnumerable<object?> edgeResults;
switch (edge.EdgeType)
switch (edge.Kind)
{
// We know the corresponding EdgeRunner type given the FlowEdge EdgeType, as
// established in the EdgeMap() ctor; this avoid doing an as-cast inside of
@@ -66,24 +66,24 @@ internal class EdgeMap
// in FanIn/Out cases)
// TODO: Once we have a fixed interface, if it is reasonably generalizable
// between the Runners, we can normalize it behind an IFace.
case Edge.Type.Direct:
case EdgeKind.Direct:
{
DirectEdgeRunner runner = (DirectEdgeRunner)this._edgeRunners[connection];
DirectEdgeRunner runner = (DirectEdgeRunner)this._edgeRunners[id];
edgeResults = await runner.ChaseAsync(message, this._stepTracer).ConfigureAwait(false);
break;
}
case Edge.Type.FanOut:
case EdgeKind.FanOut:
{
FanOutEdgeRunner runner = (FanOutEdgeRunner)this._edgeRunners[connection];
FanOutEdgeRunner runner = (FanOutEdgeRunner)this._edgeRunners[id];
edgeResults = await runner.ChaseAsync(message, this._stepTracer).ConfigureAwait(false);
break;
}
case Edge.Type.FanIn:
case EdgeKind.FanIn:
{
FanInEdgeState state = this._fanInState[connection];
FanInEdgeRunner runner = (FanInEdgeRunner)this._edgeRunners[connection];
FanInEdgeState state = this._fanInState[id];
FanInEdgeRunner runner = (FanInEdgeRunner)this._edgeRunners[id];
edgeResults = [await runner.ChaseAsync(sourceId, message, state, this._stepTracer).ConfigureAwait(false)];
break;
}
@@ -104,43 +104,45 @@ internal class EdgeMap
public async ValueTask<IEnumerable<object?>> InvokeResponseAsync(ExternalResponse response)
{
if (!this._portEdgeRunners.TryGetValue(response.Port.Id, out InputEdgeRunner? portRunner))
if (!this._portEdgeRunners.TryGetValue(response.PortInfo.PortId, out InputEdgeRunner? portRunner))
{
throw new InvalidOperationException($"Port {response.Port.Id} not found in the edge map.");
throw new InvalidOperationException($"Port {response.PortInfo.PortId} not found in the edge map.");
}
return [await portRunner.ChaseAsync(new MessageEnvelope(response), this._stepTracer).ConfigureAwait(false)];
}
internal ValueTask<Dictionary<EdgeConnection, ExportedState>> ExportStateAsync()
internal ValueTask<Dictionary<EdgeId, PortableValue>> ExportStateAsync()
{
Dictionary<EdgeConnection, ExportedState> exportedStates = new();
Dictionary<EdgeId, PortableValue> exportedStates = new();
// Right now there is only fan-in state
foreach (EdgeConnection connection in this._fanInState.Keys)
foreach (EdgeId id in this._fanInState.Keys)
{
FanInEdgeState state = this._fanInState[connection];
exportedStates[connection] = new ExportedState(state);
FanInEdgeState state = this._fanInState[id];
exportedStates[id] = new PortableValue(state);
}
return new ValueTask<Dictionary<EdgeConnection, ExportedState>>(exportedStates);
return new(exportedStates);
}
internal ValueTask ImportStateAsync(Checkpoint checkpoint)
{
Dictionary<EdgeConnection, ExportedState> importedState = checkpoint.EdgeState;
Dictionary<EdgeId, PortableValue> importedState = checkpoint.EdgeStateData;
this._fanInState.Clear();
foreach (EdgeConnection connection in importedState.Keys)
foreach (EdgeId id in importedState.Keys)
{
ExportedState exportedState = importedState[connection];
if (exportedState.Value is FanInEdgeState fanInState)
PortableValue exportedState = importedState[id];
FanInEdgeState? fanInState = exportedState.As<FanInEdgeState>();
if (fanInState is not null)
{
this._fanInState[connection] = fanInState;
this._fanInState[id] = fanInState;
}
else
{
throw new InvalidOperationException($"Unsupported exported state type: {exportedState.GetType()} for connection {connection}");
throw new InvalidOperationException($"Unsupported exported state type: {exportedState.GetType()}; {id}");
}
}
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading.Tasks;
namespace Microsoft.Agents.Workflows.Execution;
@@ -12,33 +13,44 @@ internal class FanInEdgeRunner(IRunnerContext runContext, FanInEdgeData edgeData
public FanInEdgeState CreateState() => new(this.EdgeData);
public async ValueTask<IEnumerable<object?>> ChaseAsync(string sourceId, MessageEnvelope envelope, FanInEdgeState state, IStepTracer? tracer)
public ValueTask<IEnumerable<object?>> ChaseAsync(string sourceId, MessageEnvelope envelope, FanInEdgeState state, IStepTracer? tracer)
{
if (envelope.TargetId != null && this.EdgeData.SinkId != envelope.TargetId)
{
// This message is not for us.
return [];
return new([]);
}
object message = envelope.Message;
IEnumerable<object>? releasedMessages = state.ProcessMessage(sourceId, message);
IEnumerable<MessageEnvelope>? releasedMessages = state.ProcessMessage(sourceId, envelope);
if (releasedMessages is null)
{
// Not ready to process yet.
return [];
return new([]);
}
return this.ForwardReleasedMessagesAsync(releasedMessages, tracer);
}
private async ValueTask<IEnumerable<object?>> ForwardReleasedMessagesAsync(IEnumerable<MessageEnvelope> releasedMessages, IStepTracer? tracer)
{
Executor target = await this.RunContext.EnsureExecutorAsync(this.EdgeData.SinkId, tracer)
.ConfigureAwait(false);
List<Task<object?>> messageTasks = [];
foreach (var messageTask in releasedMessages)
foreach (MessageEnvelope releasedEnvelope in releasedMessages)
{
if (target.CanHandle(messageTask.GetType()))
object message = releasedEnvelope.Message;
Debug.Assert(message is PortableValue, "It should not be possible to get messages released without roundtripping them through" +
"PortableValue via PortableMessageEnvelope.");
PortableValue portable = message as PortableValue ?? new PortableValue(releasedEnvelope.MessageType, message);
if (target.CanHandle(portable.TypeId))
{
tracer?.TraceActivated(target.Id);
messageTasks.Add(target.ExecuteAsync(messageTask, envelope.MessageType, this.BoundContext).AsTask());
messageTasks.Add(target.ExecuteAsync(portable, releasedEnvelope.MessageType, this.BoundContext).AsTask());
}
}
@@ -1,28 +1,53 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.Linq;
using System.Text.Json.Serialization;
using System.Threading;
using Microsoft.Agents.Workflows.Checkpointing;
namespace Microsoft.Agents.Workflows.Execution;
internal record FanInEdgeState(FanInEdgeData EdgeData)
internal class FanInEdgeState
{
private List<object>? _pendingMessages = [];
private HashSet<string>? _unseen = new(EdgeData.SourceIds);
public IEnumerable<object>? ProcessMessage(string sourceId, object message)
private List<PortableMessageEnvelope> _pendingMessages;
public FanInEdgeState(FanInEdgeData fanInEdge)
{
this._pendingMessages!.Add(message);
this._unseen!.Remove(sourceId);
this.SourceIds = fanInEdge.SourceIds.ToArray();
this.Unseen = new(this.SourceIds);
if (this._unseen.Count == 0)
this._pendingMessages = [];
}
public string[] SourceIds { get; }
public HashSet<string> Unseen { get; private set; }
public List<PortableMessageEnvelope> PendingMessages => this._pendingMessages;
[JsonConstructor]
public FanInEdgeState(string[] sourceIds, HashSet<string> unseen, List<PortableMessageEnvelope> pendingMessages)
{
this.SourceIds = sourceIds;
this.Unseen = unseen;
this._pendingMessages = pendingMessages;
}
public IEnumerable<MessageEnvelope>? ProcessMessage(string sourceId, MessageEnvelope envelope)
{
this.PendingMessages.Add(new(envelope));
this.Unseen.Remove(sourceId);
if (this.Unseen.Count == 0)
{
List<object> result = this._pendingMessages;
List<PortableMessageEnvelope> takenMessages = Interlocked.Exchange(ref this._pendingMessages, []);
this.Unseen = new(this.SourceIds);
this._pendingMessages = [];
this._unseen = new(this.EdgeData.SourceIds);
if (takenMessages.Count == 0)
{
return null;
}
return result;
return takenMessages.Select(portable => portable.ToMessageEnvelope());
}
return null;
@@ -35,7 +35,7 @@ internal class InputEdgeRunner(IRunnerContext runContext, string sinkId)
}
// TODO: Throw instead? / Log
Debug.WriteLine($"Executor {target.Id} cannot handle message of type {envelope.MessageType.FullName}. Dropping.");
Debug.WriteLine($"Executor {target.Id} cannot handle message of type {envelope.MessageType.TypeName}. Dropping.");
return null;
}
@@ -1,12 +1,22 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using Microsoft.Agents.Workflows.Checkpointing;
namespace Microsoft.Agents.Workflows.Execution;
internal sealed class MessageEnvelope(object message, Type? declaredType = null, string? targetId = null)
internal sealed class MessageEnvelope(object message, TypeId? declaredType = null, string? targetId = null)
{
public Type MessageType => declaredType ?? message.GetType();
public TypeId MessageType => declaredType ?? new(message.GetType());
public object Message => message;
public string? TargetId => targetId;
internal MessageEnvelope(object message, Type declaredType, string? targetId = null)
: this(message, new TypeId(declaredType), targetId)
{
if (!declaredType.IsAssignableFrom(message.GetType()))
{
throw new ArgumentException($"The declared type {declaredType} is not compatible with the message instance of type {message.GetType()}");
}
}
}
@@ -2,7 +2,9 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Agents.Workflows.Checkpointing;
using Microsoft.Shared.Diagnostics;
using MessageHandlerF =
@@ -17,24 +19,29 @@ namespace Microsoft.Agents.Workflows.Execution;
internal class MessageRouter
{
private readonly Dictionary<Type, MessageHandlerF> _typedHandlers;
private readonly Dictionary<TypeId, Type> _runtimeTypeMap;
private readonly bool _hasCatchall;
internal MessageRouter(Dictionary<Type, MessageHandlerF> handlers)
{
this._typedHandlers = Throw.IfNull(handlers);
this._hasCatchall = this._typedHandlers.ContainsKey(typeof(object));
Throw.IfNull(handlers);
this.IncomingTypes = [.. this._typedHandlers.Keys];
this._typedHandlers = handlers;
this._runtimeTypeMap = handlers.Keys.ToDictionary(t => new TypeId(t), t => t);
this._hasCatchall = handlers.ContainsKey(typeof(object));
this.IncomingTypes = [.. handlers.Keys];
}
public HashSet<Type> IncomingTypes { get; }
public bool CanHandle(object message) => this.CanHandle(Throw.IfNull(message).GetType());
public bool CanHandle(object message) => this.CanHandle(new TypeId(Throw.IfNull(message).GetType()));
public bool CanHandle(Type candidateType) => this.CanHandle(new TypeId(Throw.IfNull(candidateType)));
public bool CanHandle(Type candidateType)
public bool CanHandle(TypeId candidateType)
{
// For now we only support routing to handlers registered on the exact type (no base type delegation).
return this._hasCatchall || this._typedHandlers.ContainsKey(candidateType);
return this._hasCatchall || this._runtimeTypeMap.ContainsKey(candidateType);
}
public async ValueTask<CallResult?> RouteMessageAsync(object message, IWorkflowContext context, bool requireRoute = false)
@@ -43,6 +50,13 @@ internal class MessageRouter
CallResult? result = null;
if (message is PortableValue portableValue &&
this._runtimeTypeMap.TryGetValue(portableValue.TypeId, out Type? runtimeType))
{
// If we found a runtime type, we can use it
message = portableValue.AsType(runtimeType) ?? message;
}
try
{
if (this._typedHandlers.TryGetValue(message.GetType(), out MessageHandlerF? handler))
@@ -5,9 +5,9 @@ using Microsoft.Agents.Workflows.Checkpointing;
namespace Microsoft.Agents.Workflows.Execution;
internal class RunnerStateData(HashSet<string> instantiatedExecutors, Dictionary<ExecutorIdentity, List<ExportedState>> queuedMessages, List<ExternalRequest> outstandingRequests)
internal class RunnerStateData(HashSet<string> instantiatedExecutors, Dictionary<ExecutorIdentity, List<PortableMessageEnvelope>> queuedMessages, List<ExternalRequest> outstandingRequests)
{
public HashSet<string> InstantiatedExecutors { get; } = instantiatedExecutors;
public Dictionary<ExecutorIdentity, List<ExportedState>> QueuedMessages { get; } = queuedMessages;
public Dictionary<ExecutorIdentity, List<PortableMessageEnvelope>> QueuedMessages { get; } = queuedMessages;
public List<ExternalRequest> OutstandingRequests { get; } = outstandingRequests;
}
@@ -109,12 +109,12 @@ internal class StateManager
// What's the right thing to do when we have a state object, but it is the wrong type?
if (result.IsDelete)
{
return new ValueTask<T?>((T?)default);
return new((T?)default);
}
if (result.Value is T)
{
return new ValueTask<T?>((T?)result.Value);
return new((T?)result.Value);
}
throw new InvalidOperationException($"State for key '{key}' in scope '{scopeId}' is not of type '{typeof(T).Name}'.");
@@ -124,20 +124,31 @@ internal class StateManager
return scope.ReadStateAsync<T>(key);
}
public ValueTask WriteStateAsync<T>(string executorId, string? scopeName, string key, T? value)
public ValueTask WriteStateAsync<T>(string executorId, string? scopeName, string key, T value)
=> this.WriteStateAsync<T>(new ScopeId(Throw.IfNullOrEmpty(executorId), scopeName), key, value);
public ValueTask WriteStateAsync<T>(ScopeId scopeId, string key, T? value)
public ValueTask WriteStateAsync<T>(ScopeId scopeId, string key, T value)
{
Throw.IfNullOrEmpty(key);
UpdateKey stateKey = new(scopeId, key);
StateUpdate update = value == null ? StateUpdate.Delete(key) : StateUpdate.Update(key, value);
StateUpdate update = StateUpdate.Update(key, value);
this._queuedUpdates[stateKey] = update;
return default;
}
public ValueTask ClearStateAsync(string executorId, string? scopeName, string key)
=> this.ClearStateAsync(new ScopeId(Throw.IfNullOrEmpty(executorId), scopeName), key);
public ValueTask ClearStateAsync(ScopeId scopeId, string key)
{
Throw.IfNullOrEmpty(key);
UpdateKey stateKey = new(scopeId, key);
this._queuedUpdates[stateKey] = StateUpdate.Delete(key);
return default;
}
public async ValueTask PublishUpdatesAsync(IStepTracer? tracer)
{
Dictionary<ScopeId, Dictionary<string, List<StateUpdate>>> updatesByScope = [];
@@ -172,15 +183,15 @@ internal class StateManager
this._queuedUpdates.Clear();
}
private static IEnumerable<KeyValuePair<ScopeKey, ExportedState>> ExportScope(StateScope scope)
private static IEnumerable<KeyValuePair<ScopeKey, PortableValue>> ExportScope(StateScope scope)
{
foreach (KeyValuePair<string, ExportedState> state in scope.ExportStates())
foreach (KeyValuePair<string, PortableValue> state in scope.ExportStates())
{
yield return new(new ScopeKey(scope.ScopeId, state.Key), state.Value);
}
}
internal async ValueTask<Dictionary<ScopeKey, ExportedState>> ExportStateAsync()
internal async ValueTask<Dictionary<ScopeKey, PortableValue>> ExportStateAsync()
{
if (this._queuedUpdates.Count != 0)
{
@@ -201,7 +212,7 @@ internal class StateManager
this._queuedUpdates.Clear();
this._scopes.Clear();
Dictionary<ScopeKey, ExportedState> importedState = checkpoint.State;
Dictionary<ScopeKey, PortableValue> importedState = checkpoint.StateData;
foreach (ScopeKey scopeKey in importedState.Keys)
{
@@ -4,14 +4,13 @@ using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Agents.Workflows.Checkpointing;
using Microsoft.Shared.Diagnostics;
namespace Microsoft.Agents.Workflows.Execution;
internal class StateScope
{
private readonly Dictionary<string, object> _stateData = new();
private readonly Dictionary<string, PortableValue> _stateData = new();
public ScopeId ScopeId { get; }
public StateScope(ScopeId scopeId)
@@ -30,15 +29,32 @@ internal class StateScope
return new(keys);
}
public bool Contains<T>(string key)
{
Throw.IfNullOrEmpty(key);
if (this._stateData.TryGetValue(key, out PortableValue? value))
{
return value.Is<T>();
}
return false;
}
public bool ContainsKey(string key)
{
Throw.IfNullOrEmpty(key);
return this._stateData.ContainsKey(key);
}
public ValueTask<T?> ReadStateAsync<T>(string key)
{
Throw.IfNullOrEmpty(key);
if (this._stateData.TryGetValue(key, out object? value) && value is T typedValue)
if (this._stateData.TryGetValue(key, out PortableValue? value))
{
return new ValueTask<T?>(typedValue);
return new(value.As<T>());
}
return new ValueTask<T?>((T?)default);
return new((T?)default);
}
public ValueTask WriteStateAsync(Dictionary<string, List<StateUpdate>> updates)
@@ -64,28 +80,28 @@ internal class StateScope
}
else
{
this._stateData[key] = update.Value!;
this._stateData[key] = new PortableValue(update.Value!);
}
}
return default;
}
public IEnumerable<KeyValuePair<string, ExportedState>> ExportStates()
public IEnumerable<KeyValuePair<string, PortableValue>> ExportStates()
{
return this._stateData.Keys.Select(WrapStates);
KeyValuePair<string, ExportedState> WrapStates(string key)
KeyValuePair<string, PortableValue> WrapStates(string key)
{
return new(key, new(this._stateData[key]));
return new(key, this._stateData[key]);
}
}
public void ImportState(string key, ExportedState state)
public void ImportState(string key, PortableValue state)
{
Throw.IfNullOrEmpty(key);
Throw.IfNull(state);
this._stateData[key] = state.Value;
this._stateData[key] = state;
}
}
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.Agents.Workflows.Checkpointing;
@@ -25,25 +24,23 @@ internal class StepContext
// TODO: Create a MessageEnvelope class that extends from the ExportedState object (with appropriate rename) to avoid
// unnecessary wrapping and unwrapping of messages during checkpointing.
internal Dictionary<ExecutorIdentity, List<ExportedState>> ExportMessages()
internal Dictionary<ExecutorIdentity, List<PortableMessageEnvelope>> ExportMessages()
{
return this.QueuedMessages.Keys.ToDictionary(
keySelector: identity => identity,
elementSelector: identity => this.QueuedMessages[identity]
.Select(v => new ExportedState(v))
.Select(v => new PortableMessageEnvelope(v))
.ToList()
);
}
internal void ImportMessages(Dictionary<ExecutorIdentity, List<ExportedState>> messages)
internal void ImportMessages(Dictionary<ExecutorIdentity, List<PortableMessageEnvelope>> messages)
{
foreach (ExecutorIdentity identity in messages.Keys)
{
this.QueuedMessages[identity] = messages[identity].Select(UnwrapExportedState).ToList();
}
MessageEnvelope UnwrapExportedState(ExportedState es)
=> es.Value as MessageEnvelope
?? throw new InvalidDataException($"Expected a MessageEnvelope in the ExportedState. Got {es.RuntimeType}");
MessageEnvelope UnwrapExportedState(PortableMessageEnvelope es) => es.ToMessageEnvelope();
}
}
@@ -6,6 +6,7 @@ using System.Diagnostics;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Agents.Workflows.Checkpointing;
using Microsoft.Agents.Workflows.Execution;
using Microsoft.Agents.Workflows.Reflection;
@@ -65,7 +66,7 @@ public abstract class Executor : IIdentified
/// <returns>A ValueTask representing the asynchronous operation, wrapping the output from the executor.</returns>
/// <exception cref="NotSupportedException">No handler found for the message type.</exception>
/// <exception cref="TargetInvocationException">An exception is generated while handling the message.</exception>
public async ValueTask<object?> ExecuteAsync(object message, Type messageType, IWorkflowContext context)
public async ValueTask<object?> ExecuteAsync(object message, TypeId messageType, IWorkflowContext context)
{
await context.AddEventAsync(new ExecutorInvokedEvent(this.Id, message)).ConfigureAwait(false);
@@ -79,7 +80,7 @@ public abstract class Executor : IIdentified
}
else
{
executionResult = new ExecutorFailureEvent(this.Id, result.Exception);
executionResult = new ExecutorFailedEvent(this.Id, result.Exception);
}
await context.AddEventAsync(executionResult).ConfigureAwait(false);
@@ -141,6 +142,8 @@ public abstract class Executor : IIdentified
/// <param name="messageType"></param>
/// <returns></returns>
public bool CanHandle(Type messageType) => this.Router.CanHandle(messageType);
internal bool CanHandle(TypeId messageType) => this.Router.CanHandle(messageType);
}
/// <summary>
@@ -1,10 +1,15 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Text.Json.Serialization;
namespace Microsoft.Agents.Workflows;
/// <summary>
/// Base class for <see cref="Executor"/>-scoped events.
/// </summary>
[JsonDerivedType(typeof(ExecutorInvokedEvent))]
[JsonDerivedType(typeof(ExecutorCompletedEvent))]
[JsonDerivedType(typeof(ExecutorFailedEvent))]
public class ExecutorEvent(string executorId, object? data) : WorkflowEvent(data)
{
/// <summary>
@@ -9,7 +9,7 @@ namespace Microsoft.Agents.Workflows;
/// </summary>
/// <param name="executorId">The unique identifier of the executor that has failed.</param>
/// <param name="err">The exception representing the error.</param>
public sealed class ExecutorFailureEvent(string executorId, Exception? err)
public sealed class ExecutorFailedEvent(string executorId, Exception? err)
: ExecutorEvent(executorId, data: err)
{
/// <summary>
@@ -2,6 +2,7 @@
using System;
using System.Diagnostics.CodeAnalysis;
using Microsoft.Agents.Workflows.Checkpointing;
using Microsoft.Shared.Diagnostics;
namespace Microsoft.Agents.Workflows;
@@ -9,11 +10,25 @@ namespace Microsoft.Agents.Workflows;
/// <summary>
/// Represents a request to an external input port.
/// </summary>
/// <param name="Port">The port to invoke.</param>
/// <param name="PortInfo">The port to invoke.</param>
/// <param name="RequestId">A unique identifier for this request instance.</param>
/// <param name="Data">The data contained in the request.</param>
public record ExternalRequest(InputPort Port, string RequestId, object Data)
public record ExternalRequest(InputPortInfo PortInfo, string RequestId, PortableValue Data)
{
/// <summary>
/// Attempts to retrieve the underlying data as the specified type.
/// </summary>
/// <typeparam name="TValue">The type to which the data should be cast or converted.</typeparam>
/// <returns>The data cast to the specified type, or null if the data cannot be cast to the specified type.</returns>
public TValue? DataAs<TValue>() => this.Data.As<TValue>();
/// <summary>
/// Determines whether the underlying data is of the specified type.
/// </summary>
/// <typeparam name="TValue">The type to compare with the underlying data.</typeparam>
/// <returns>true if the underlying data is of type TValue; otherwise, false.</returns>
public bool DataIs<TValue>() => this.Data.Is<TValue>();
/// <summary>
/// Creates a new <see cref="ExternalRequest"/> for the specified input port and data payload.
/// </summary>
@@ -32,7 +47,7 @@ public record ExternalRequest(InputPort Port, string RequestId, object Data)
requestId ??= Guid.NewGuid().ToString("N");
return new ExternalRequest(port, requestId, data);
return new ExternalRequest(port.ToPortInfo(), requestId, new PortableValue(data));
}
/// <summary>
@@ -53,13 +68,13 @@ public record ExternalRequest(InputPort Port, string RequestId, object Data)
/// <exception cref="InvalidOperationException">Thrown when the input data object does not match the expected response type.</exception>
public ExternalResponse CreateResponse(object data)
{
if (!Throw.IfNull(this.Port).Response.IsAssignableFrom(Throw.IfNull(data).GetType()))
if (!Throw.IfNull(this.PortInfo).ResponseType.IsMatchPolymorphic(Throw.IfNull(data).GetType()))
{
throw new InvalidOperationException(
$"Message type {data.GetType().Name} is not assignable to the response type {this.Port.Response.Name} of input port {this.Port.Id}.");
$"Message type {data.GetType().Name} does not match expected response type {this.PortInfo.ResponseType.TypeName} of input port {this.PortInfo.PortId}.");
}
return new ExternalResponse(this.Port, this.RequestId, data);
return new ExternalResponse(this.PortInfo, this.RequestId, new PortableValue(data));
}
/// <summary>
@@ -1,13 +1,36 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using Microsoft.Agents.Workflows.Checkpointing;
namespace Microsoft.Agents.Workflows;
/// <summary>
/// Represents a request from an external input port.
/// </summary>
/// <param name="Port">The port invoked.</param>
/// <param name="PortInfo">The port invoked.</param>
/// <param name="RequestId">The unique identifier of the corresponding request.</param>
/// <param name="Data">The data contained in the response.</param>
public record ExternalResponse(InputPort Port, string RequestId, object Data)
public record ExternalResponse(InputPortInfo PortInfo, string RequestId, PortableValue Data)
{
/// <summary>
/// Attempts to retrieve the underlying data as the specified type.
/// </summary>
/// <typeparam name="TValue">The type to which the data should be cast or converted.</typeparam>
/// <returns>The data cast to the specified type, or null if the data cannot be cast to the specified type.</returns>
public TValue? DataAs<TValue>() => this.Data.As<TValue>();
/// <summary>
/// Determines whether the underlying data is of the specified type.
/// </summary>
/// <typeparam name="TValue">The type to compare with the underlying data.</typeparam>
/// <returns>true if the underlying data is of type TValue; otherwise, false.</returns>
public bool DataIs<TValue>() => this.Data.Is<TValue>();
/// <summary>
/// Attempts to retrieve the underlying data as the specified type.
/// </summary>
/// <param name="targetType">The type to which the data should be cast or converted.</param>
/// <returns>The data cast to the specified type, or null if the data cannot be cast to the specified type.</returns>
public object? DataAs(Type targetType) => this.Data.AsType(targetType);
}
@@ -8,20 +8,25 @@ namespace Microsoft.Agents.Workflows;
/// <summary>
/// Represents a connection from a set of nodes to a single node. It will trigger either when all edges have data.
/// </summary>
/// <param name="sourceIds">An enumeration of ids of the source executor nodes.</param>
/// <param name="sinkId">The id of the target executor node.</param>
public sealed class FanInEdgeData(List<string> sourceIds, string sinkId) : EdgeData
internal sealed class FanInEdgeData : EdgeData
{
internal FanInEdgeData(List<string> sourceIds, string sinkId, EdgeId id) : base(id)
{
this.SourceIds = sourceIds;
this.SinkId = sinkId;
this.Connection = new(sourceIds, [sinkId]);
}
/// <summary>
/// The ordered list of Ids of the source <see cref="Executor"/> nodes.
/// </summary>
public List<string> SourceIds => sourceIds;
public List<string> SourceIds { get; }
/// <summary>
/// The Id of the destination <see cref="Executor"/> node.
/// </summary>
public string SinkId => sinkId;
public string SinkId { get; }
/// <inheritdoc />
internal override EdgeConnection Connection { get; } = new(sourceIds, [sinkId]);
internal override EdgeConnection Connection { get; }
}
@@ -11,30 +11,32 @@ namespace Microsoft.Agents.Workflows;
/// Represents a connection from a single node to a set of nodes, optionally associated with a paritition selector
/// function which maps incoming messages to a subset of the target set.
/// </summary>
/// <param name="sourceId">The id of the source executor node.</param>
/// <param name="sinkIds">A list of ids of the target executor nodes.</param>
/// <param name="assigner">A function that maps an incoming message to a subset of the target executor nodes.</param>
public sealed class FanOutEdgeData(
string sourceId,
List<string> sinkIds,
AssignerF? assigner = null) : EdgeData
internal sealed class FanOutEdgeData : EdgeData
{
internal FanOutEdgeData(string sourceId, List<string> sinkIds, EdgeId edgeId, AssignerF? assigner = null) : base(edgeId)
{
this.SourceId = sourceId;
this.SinkIds = sinkIds;
this.EdgeAssigner = assigner;
this.Connection = new([sourceId], sinkIds);
}
/// <summary>
/// The Id of the source <see cref="Executor"/> node.
/// </summary>
public string SourceId => sourceId;
public string SourceId { get; }
/// <summary>
/// The ordered list of Ids of the destination <see cref="Executor"/> nodes.
/// </summary>
public List<string> SinkIds => sinkIds;
public List<string> SinkIds { get; }
/// <summary>
/// A function mapping an incoming message to a subset of the target executor nodes (or optionally all of them).
/// If <see langword="null"/>, all destination nodes are selected.
/// </summary>
public AssignerF? EdgeAssigner => assigner;
public AssignerF? EdgeAssigner { get; }
/// <inheritdoc />
internal override EdgeConnection Connection { get; } = new([sourceId], sinkIds);
internal override EdgeConnection Connection { get; }
}
@@ -22,17 +22,20 @@ namespace Microsoft.Agents.Workflows.InProc;
/// <typeparam name="TInput">The type of input accepted by the workflow. Must be non-nullable.</typeparam>
internal class InProcessRunner<TInput> : ISuperStepRunner, ICheckpointingRunner where TInput : notnull
{
public InProcessRunner(Workflow<TInput> workflow, ICheckpointManager? checkpointManager)
public InProcessRunner(Workflow<TInput> workflow, ICheckpointManager? checkpointManager, string? runId = null)
{
this.Workflow = Throw.IfNull(workflow);
this.RunContext = new InProcessRunnerContext<TInput>(workflow);
this.CheckpointManager = checkpointManager;
this.RunId = runId ?? Guid.NewGuid().ToString("N");
// Initialize the runners for each of the edges, along with the state for edges that
// need it.
this.EdgeMap = new EdgeMap(this.RunContext, this.Workflow.Edges, this.Workflow.Ports.Values, this.Workflow.StartExecutorId, this.StepTracer);
}
public string RunId { get; }
public async ValueTask<bool> IsValidInputAsync<TMessage>(TMessage message)
{
Throw.IfNull(message);
@@ -166,9 +169,22 @@ internal class InProcessRunner<TInput> : ISuperStepRunner, ICheckpointingRunner
return true;
}
this.EmitPendingEvents();
return false;
}
private void EmitPendingEvents()
{
if (this.RunContext.QueuedEvents.Count > 0)
{
foreach (WorkflowEvent @event in this.RunContext.QueuedEvents)
{
this.RaiseWorkflowEvent(@event);
}
this.RunContext.QueuedEvents.Clear();
}
}
private async ValueTask RunSuperstepAsync(StepContext currentStep)
{
this.RaiseWorkflowEvent(this.StepTracer.Advance(currentStep));
@@ -197,11 +213,7 @@ internal class InProcessRunner<TInput> : ISuperStepRunner, ICheckpointingRunner
IEnumerable<object?> results = (await Task.WhenAll(edgeTasks).ConfigureAwait(false)).SelectMany(r => r);
// After the message handler invocations, we may have some events to deliver
foreach (WorkflowEvent @event in this.RunContext.QueuedEvents)
{
this.RaiseWorkflowEvent(@event);
}
this.RunContext.QueuedEvents.Clear();
this.EmitPendingEvents();
await this.CheckpointAsync().ConfigureAwait(false);
@@ -228,16 +240,16 @@ internal class InProcessRunner<TInput> : ISuperStepRunner, ICheckpointingRunner
this._workflowInfoCache = this.Workflow.ToWorkflowInfo();
}
Dictionary<EdgeConnection, ExportedState> edgeData = await this.EdgeMap.ExportStateAsync().ConfigureAwait(false);
Dictionary<EdgeId, PortableValue> edgeData = await this.EdgeMap.ExportStateAsync().ConfigureAwait(false);
await prepareTask.ConfigureAwait(false);
await this.RunContext.StateManager.PublishUpdatesAsync(this.StepTracer).ConfigureAwait(false);
RunnerStateData runnerData = await this.RunContext.ExportStateAsync().ConfigureAwait(false);
Dictionary<ScopeKey, ExportedState> stateData = await this.RunContext.StateManager.ExportStateAsync().ConfigureAwait(false);
Dictionary<ScopeKey, PortableValue> stateData = await this.RunContext.StateManager.ExportStateAsync().ConfigureAwait(false);
Checkpoint checkpoint = new(this.StepTracer.StepNumber, this._workflowInfoCache, runnerData, stateData, edgeData);
CheckpointInfo checkpointInfo = await this.CheckpointManager.CommitCheckpointAsync(checkpoint).ConfigureAwait(false);
CheckpointInfo checkpointInfo = await this.CheckpointManager.CommitCheckpointAsync(this.RunId, checkpoint).ConfigureAwait(false);
this.StepTracer.TraceCheckpointCreated(checkpointInfo);
this._checkpoints.Add(checkpointInfo);
}
@@ -250,7 +262,7 @@ internal class InProcessRunner<TInput> : ISuperStepRunner, ICheckpointingRunner
throw new InvalidOperationException("This run was not configured with a CheckpointManager, so it cannot restore checkpoints.");
}
Checkpoint checkpoint = await this.CheckpointManager.LookupCheckpointAsync(checkpointInfo)
Checkpoint checkpoint = await this.CheckpointManager.LookupCheckpointAsync(this.RunId, checkpointInfo)
.ConfigureAwait(false);
// Validate the checkpoint is compatible with this workflow
@@ -283,11 +295,11 @@ internal class InProcessRunner<TInput, TResult> : IRunnerWithOutput<TResult>, IC
private readonly Workflow<TInput, TResult> _workflow;
private readonly InProcessRunner<TInput> _innerRunner;
public InProcessRunner(Workflow<TInput, TResult> workflow, CheckpointManager? checkpointManager)
public InProcessRunner(Workflow<TInput, TResult> workflow, CheckpointManager? checkpointManager, string? runId = null)
{
this._workflow = Throw.IfNull(workflow);
InProcessRunner<TInput> runner = new(workflow, checkpointManager);
InProcessRunner<TInput> runner = new(workflow, checkpointManager, runId);
this._innerRunner = runner;
}
@@ -135,7 +135,7 @@ internal class InProcessRunnerContext<TExternalInput> : IRunnerContext
throw new InvalidOperationException("Cannot export state when there are queued events. Please process or clear the events before exporting state.");
}
Dictionary<ExecutorIdentity, List<ExportedState>> queuedMessages = this._nextStep.ExportMessages();
Dictionary<ExecutorIdentity, List<PortableMessageEnvelope>> queuedMessages = this._nextStep.ExportMessages();
RunnerStateData result = new(instantiatedExecutors: [.. this._executors.Keys],
queuedMessages,
outstandingRequests: [.. this._externalRequests.Values]);
@@ -79,7 +79,7 @@ public static class InProcessExecution
CheckpointManager checkpointManager,
CancellationToken cancellation = default) where TInput : notnull
{
InProcessRunner<TInput> runner = new(workflow, checkpointManager);
InProcessRunner<TInput> runner = new(workflow, checkpointManager, runId: fromCheckpoint.RunId);
StreamingRun result = await runner.ResumeStreamAsync(fromCheckpoint, cancellation).ConfigureAwait(false);
return new(result, runner);
@@ -155,7 +155,7 @@ public static class InProcessExecution
CheckpointManager checkpointManager,
CancellationToken cancellation = default) where TInput : notnull
{
InProcessRunner<TInput, TResult> runner = new(workflow, checkpointManager);
InProcessRunner<TInput, TResult> runner = new(workflow, checkpointManager, runId: fromCheckpoint.RunId);
StreamingRun<TResult> result = await runner.ResumeStreamAsync(fromCheckpoint, cancellation).ConfigureAwait(false);
return new(result, runner);
@@ -225,7 +225,7 @@ public static class InProcessExecution
CheckpointManager checkpointManager,
CancellationToken cancellation = default) where TInput : notnull
{
InProcessRunner<TInput> runner = new(workflow, checkpointManager);
InProcessRunner<TInput> runner = new(workflow, checkpointManager, runId: fromCheckpoint.RunId);
Run result = await runner.ResumeAsync(fromCheckpoint, cancellation).ConfigureAwait(false);
return new(result, runner);
@@ -298,7 +298,7 @@ public static class InProcessExecution
CheckpointManager checkpointManager,
CancellationToken cancellation = default) where TInput : notnull
{
InProcessRunner<TInput, TResult> runner = new(workflow, checkpointManager);
InProcessRunner<TInput, TResult> runner = new(workflow, checkpointManager, runId: fromCheckpoint.RunId);
Run<TResult> result = await runner.ResumeAsync(fromCheckpoint, cancellation).ConfigureAwait(false);
return new(result, runner);
@@ -0,0 +1,176 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Text.Json.Serialization;
using Microsoft.Agents.Workflows.Checkpointing;
using Microsoft.Shared.Diagnostics;
namespace Microsoft.Agents.Workflows;
/// <summary>
/// Represents a value that can be exported / imported to a workflow, e.g. through an external request/response, or
/// through checkpointing. Abstracts away delayed deserialization and type conversion where appropriate.
/// </summary>
public sealed class PortableValue
{
internal PortableValue(object value)
{
this._value = value;
this.TypeId = new(value.GetType());
}
[JsonConstructor]
internal PortableValue(TypeId typeId, object value)
{
this.TypeId = Throw.IfNull(typeId);
this._value = value;
}
/// <inheritdoc />
public override bool Equals(object? obj)
{
if (obj == null)
{
return false;
}
if (obj is not PortableValue other)
{
Type targetType = obj.GetType();
return this.AsType(targetType)?.Equals(obj) ?? false;
}
return this.TypeId == other.TypeId
&& ((this.Value == null && other.Value == null)
|| this.Value != null && this.Value.Equals(other.Value));
}
/// <inheritdoc />
public override int GetHashCode()
{
return HashCode.Combine(this.TypeId, this.Value);
}
/// <inheritdoc />
public static bool operator ==(PortableValue? left, PortableValue? right)
{
if (left is null)
{
return right is null;
}
return left.Equals(right);
}
/// <inheritdoc />
public static bool operator !=(PortableValue? left, PortableValue? right) => !(left == right);
/// <summary>
/// The identifier of the type of the instance in <see cref="Value"/>.
/// </summary>
public TypeId TypeId { get; }
[JsonIgnore]
internal bool IsDelayedDeserialization => this.Value is IDelayedDeserialization;
[JsonIgnore]
internal bool IsDeserialized => this._deserializedValueCache != null;
private readonly object _value;
private object? _deserializedValueCache = null;
/// <summary>
/// Gets the raw underlying value represented by this instance.
/// </summary>
[JsonInclude]
internal object Value => this._deserializedValueCache ?? Throw.IfNull(this._value);
/// <summary>
/// Attempts to retrieve the underlying value as the specified type, deserializing if necessary.
/// </summary>
/// <remarks>If the underlying value implements delayed deserialization, this method will attempt to
/// deserialize it to the specified type. If the value is already of the requested type, it is returned directly.
/// Otherwise, the default value for TValue is returned.
///
/// For nullable value types, make sure to make <typeparamref name="TValue"/> be nullable, e.g. <c>int?</c>,
/// otherwise the default non-null value of the type is returned when the value is missing. Use <see cref="AsValue{TValue}"/>
/// to get the correct behavior when unable to pass in the explicit-nullable type.
/// </remarks>
/// <typeparam name="TValue">The type to which the value should be cast or deserialized.</typeparam>
/// <returns>The value cast or deserialized to type TValue if possible; otherwise, the default value for type TValue.</returns>
public TValue? As<TValue>()
{
if (this.Value is IDelayedDeserialization delayedDeserialization)
{
if (this._deserializedValueCache == null)
{
this._deserializedValueCache = delayedDeserialization.Deserialize<TValue>();
}
}
if (this.Value is TValue typedValue)
{
return typedValue;
}
return default;
}
/// <summary>
/// Attempts to retrieve the underlying value as the specified nullable value type, deserializing if
/// necessary.
/// </summary>
/// <remarks>If the underlying value implements delayed deserialization, this method will attempt to
/// deserialize it to the specified type. If the value is already of the requested type, it is returned directly.
/// Otherwise, null is returned.</remarks>
/// <typeparam name="TValue">The value type to which the value should be cast or deserialized.</typeparam>
/// <returns>The value cast or deserialized to type TValue if possible; otherwise, null.</returns>
public TValue? AsValue<TValue>() where TValue : struct
{
if (this.Value is IDelayedDeserialization delayedDeserialization)
{
this._deserializedValueCache ??= delayedDeserialization.Deserialize<TValue>();
}
if (this.Value is TValue typedValue)
{
return typedValue;
}
return default;
}
/// <summary>
/// Determines whether the current value can be represented as the specified type.
/// </summary>
/// <typeparam name="TValue">The type to test for compatibility with the current value.</typeparam>
/// <returns>true if the current value can be represented as type TValue; otherwise, false.</returns>
public bool Is<TValue>() => this.IsType(typeof(TValue));
/// <summary>
/// Attempts to retrieve the underlying value as the specified type, deserializing if necessary.
/// </summary>
/// <param name="targetType">The type to which the value should be cast or deserialized.</param>
/// <returns>The value cast or deserialized to type targetType if possible; otherwise, null.</returns>
public object? AsType(Type targetType)
{
Throw.IfNull(targetType);
if (this.Value is IDelayedDeserialization delayedDeserialization)
{
this._deserializedValueCache ??= delayedDeserialization.Deserialize(targetType);
}
return this.Value is not null && targetType.IsAssignableFrom(this.Value.GetType())
? this.Value
: this._deserializedValueCache = null;
}
/// <summary>
/// Determines whether the current instance can be assigned to the specified target type.
/// </summary>
/// <param name="targetType">The type to compare with the current instance. Cannot be null.</param>
/// <returns>true if the current instance can be assigned to targetType; otherwise, false.</returns>
public bool IsType(Type targetType) => this.AsType(targetType) != null;
}
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Text.Json.Serialization;
using Microsoft.Shared.Diagnostics;
namespace Microsoft.Agents.Workflows;
@@ -8,19 +9,17 @@ namespace Microsoft.Agents.Workflows;
/// <summary>
/// Represents a unique key within a specific scope, combining a scope identifier and a key string.
/// </summary>
/// <param name="scopeId">The <see cref="ScopeId"/> associated with this key.</param>
/// <param name="key">The unique key within the specified scope.</param>
public class ScopeKey(ScopeId scopeId, string key)
public class ScopeKey
{
/// <summary>
/// The identifier for the scope associated with this key.
/// </summary>
public ScopeId ScopeId { get; } = Throw.IfNull(scopeId);
public ScopeId ScopeId { get; }
/// <summary>
/// The unique key within the specified scope.
/// </summary>
public string Key { get; } = Throw.IfNullOrEmpty(key);
public string Key { get; }
/// <summary>
/// Initializes a new instance of the <see cref="ScopeKey"/> class.
@@ -32,6 +31,18 @@ public class ScopeKey(ScopeId scopeId, string key)
: this(new ScopeId(Throw.IfNullOrEmpty(executorId), scopeName), key)
{ }
/// <summary>
/// Iniitalizes a new instance of the <see cref="ScopeKey"/> class.
/// </summary>
/// <param name="scopeId">The <see cref="ScopeId"/> associated with this key.</param>
/// <param name="key">The unique key within the specified scope.</param>
[JsonConstructor]
public ScopeKey(ScopeId scopeId, string key)
{
this.ScopeId = Throw.IfNull(scopeId);
this.Key = Throw.IfNullOrEmpty(key);
}
/// <inheritdoc/>
public override string ToString()
{
@@ -65,14 +65,16 @@ internal class RequestInfoExecutor : Executor
Throw.IfNull(message);
Throw.IfNull(message.Data);
if (!this.Port.Response.IsAssignableFrom(message.Data.GetType()))
object? data = message.DataAs(this.Port.Response);
if (data == null)
{
throw new InvalidOperationException(
$"Message type {message.Data.GetType().Name} is not assignable to the response type {this.Port.Response.Name} of input port {this.Port.Id}.");
$"Message type {message.Data.TypeId} is not assignable to the response type {this.Port.Response.Name} of input port {this.Port.Id}.");
}
await context.SendMessageAsync(message).ConfigureAwait(false);
await context.SendMessageAsync(message.Data).ConfigureAwait(false);
await context.SendMessageAsync(data).ConfigureAwait(false);
return message;
}
@@ -1,10 +1,14 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Text.Json.Serialization;
namespace Microsoft.Agents.Workflows;
/// <summary>
/// Base class for SuperStep-scoped events, for example, <see cref="SuperStepCompletedEvent"/>
/// </summary>
[JsonDerivedType(typeof(SuperStepStartedEvent))]
[JsonDerivedType(typeof(SuperStepCompletedEvent))]
public class SuperStepEvent(int stepNumber, object? data = null) : WorkflowEvent(data)
{
/// <summary>
@@ -30,7 +30,7 @@ public sealed class SwitchBuilder
/// <param name="executors">One or more executors to associate with the predicate. Each executor will be invoked if the predicate matches.
/// Cannot be null.</param>
/// <returns>The current <see cref="SwitchBuilder"/> instance, allowing for method chaining.</returns>
public SwitchBuilder AddCase(Func<object?, bool> predicate, params ExecutorIsh[] executors)
public SwitchBuilder AddCase<T>(Func<T?, bool> predicate, params ExecutorIsh[] executors)
{
Throw.IfNull(predicate);
Throw.IfNull(executors);
@@ -49,7 +49,8 @@ public sealed class SwitchBuilder
indicies.Add(index);
}
this._caseMap.Add((predicate, indicies));
Func<object?, bool> casePredicate = WorkflowBuilder.CreateConditionFunc(predicate)!;
this._caseMap.Add((casePredicate, indicies));
return this;
}
@@ -83,7 +84,7 @@ public sealed class SwitchBuilder
List<(Func<object?, bool> Predicate, HashSet<int> OutgoingIndicies)> caseMap = this._caseMap;
HashSet<int> defaultIndicies = this._defaultIndicies;
return builder.AddFanOutEdge(source, CasePartitioner, [.. this._executors]);
return builder.AddFanOutEdge<object>(source, CasePartitioner, [.. this._executors]);
IEnumerable<int> CasePartitioner(object? input, int targetCount)
{
@@ -2,6 +2,8 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Agents.Workflows.Checkpointing;
using Microsoft.Agents.Workflows.Specialized;
using Microsoft.Shared.Diagnostics;
@@ -17,10 +19,20 @@ public class Workflow
/// </summary>
internal Dictionary<string, ExecutorRegistration> Registrations { get; init; } = new();
internal Dictionary<string, HashSet<Edge>> Edges { get; init; } = new();
/// <summary>
/// Gets the collection of edges grouped by their source node identifier.
/// </summary>
public Dictionary<string, HashSet<Edge>> Edges { get; internal init; } = new();
public Dictionary<string, HashSet<EdgeInfo>> ReflectEdges()
{
return this.Edges.Keys.ToDictionary(
keySelector: key => key,
elementSelector: key => new HashSet<EdgeInfo>(this.Edges[key].Select(RepresentationExtensions.ToEdgeInfo))
);
}
internal Dictionary<string, InputPort> Ports { get; init; } = new();
/// <summary>
/// Gets the collection of external request ports, keyed by their ID.
@@ -28,7 +40,13 @@ public class Workflow
/// <remarks>
/// Each port has a corresponding entry in the <see cref="Registrations"/> dictionary.
/// </remarks>
public Dictionary<string, InputPort> Ports { get; internal init; } = new();
public Dictionary<string, InputPortInfo> ReflectPorts()
{
return this.Ports.Keys.ToDictionary(
keySelector: key => key,
elementSelector: key => this.Ports[key].ToPortInfo()
);
}
/// <summary>
/// Gets the identifier of the starting executor of the workflow.
@@ -20,15 +20,16 @@ namespace Microsoft.Agents.Workflows;
/// <see cref="ExecutorIsh.Type.Unbound"/>.</remarks>
public class WorkflowBuilder
{
private record struct EdgeId(string SourceId, string TargetId)
private record struct EdgeConnection(string SourceId, string TargetId)
{
public override string ToString() => $"{this.SourceId} -> {this.TargetId}";
}
private int _edgeCount = 0;
private readonly Dictionary<string, ExecutorRegistration> _executors = new();
private readonly Dictionary<string, HashSet<Edge>> _edges = new();
private readonly HashSet<string> _unboundExecutors = new();
private readonly HashSet<EdgeId> _conditionlessEdges = new();
private readonly HashSet<EdgeConnection> _conditionlessConnections = new();
private readonly Dictionary<string, InputPort> _inputPorts = new();
private readonly string _startExecutorId;
@@ -121,6 +122,52 @@ public class WorkflowBuilder
return edges;
}
/// <summary>
/// Adds a directed edge from the specified source executor to the target executor, optionally guarded by a
/// condition.
/// </summary>
/// <param name="source">The executor that acts as the source node of the edge. Cannot be null.</param>
/// <param name="target">The executor that acts as the target node of the edge. Cannot be null.</param>
/// <returns>The current instance of <see cref="WorkflowBuilder"/>.</returns>
/// <exception cref="InvalidOperationException">Thrown if an unconditional edge between the specified source and target
/// executors already exists.</exception>
public WorkflowBuilder AddEdge(ExecutorIsh source, ExecutorIsh target)
=> this.AddEdge<object>(source, target, null);
internal static Func<object?, bool>? CreateConditionFunc<T>(Func<T?, bool>? condition)
{
if (condition == null)
{
return null;
}
return maybeObj =>
{
if (typeof(T) != typeof(object) && maybeObj is PortableValue portableValue)
{
maybeObj = portableValue.AsType(typeof(T));
}
return condition(maybeObj is T typed ? typed : default);
};
}
internal static Func<object?, bool>? CreateConditionFunc<T>(Func<object?, bool>? condition)
{
if (condition == null)
{
return null;
}
return maybeObj =>
{
if (typeof(T) != typeof(object) && maybeObj is PortableValue portableValue)
{
maybeObj = portableValue.AsType(typeof(T));
}
return condition(maybeObj);
};
}
private EdgeId TakeEdgeId() => new(Interlocked.Increment(ref this._edgeCount));
/// <summary>
/// Adds a directed edge from the specified source executor to the target executor, optionally guarded by a
/// condition.
@@ -132,7 +179,7 @@ public class WorkflowBuilder
/// <returns>The current instance of <see cref="WorkflowBuilder"/>.</returns>
/// <exception cref="InvalidOperationException">Thrown if an unconditional edge between the specified source and target
/// executors already exists.</exception>
public WorkflowBuilder AddEdge(ExecutorIsh source, ExecutorIsh target, Func<object?, bool>? condition = null)
public WorkflowBuilder AddEdge<T>(ExecutorIsh source, ExecutorIsh target, Func<T?, bool>? condition = null)
{
// Add an edge from source to target with an optional condition.
// This is a low-level builder method that does not enforce any specific executor type.
@@ -140,21 +187,51 @@ public class WorkflowBuilder
Throw.IfNull(source);
Throw.IfNull(target);
EdgeId id = new(source.Id, target.Id);
if (condition == null && this._conditionlessEdges.Contains(id))
EdgeConnection connection = new(source.Id, target.Id);
if (condition == null && this._conditionlessConnections.Contains(connection))
{
throw new InvalidOperationException(
$"An edge from '{source.Id}' to '{target.Id}' already exists without a condition. " +
"You cannot add another edge without a condition for the same source and target.");
}
DirectEdgeData directEdge = new(this.Track(source).Id, this.Track(target).Id, condition);
DirectEdgeData directEdge = new(this.Track(source).Id, this.Track(target).Id, this.TakeEdgeId(), CreateConditionFunc(condition));
this.EnsureEdgesFor(source.Id).Add(new(directEdge));
return this;
}
/// <summary>
/// Adds a fan-out edge from the specified source executor to one or more target executors, optionally using a
/// custom partitioning function.
/// </summary>
/// <remarks>If a partitioner function is provided, it will be used to distribute input across the target
/// executors. The order of targets determines their mapping in the partitioning process.</remarks>
/// <param name="source">The source executor from which the fan-out edge originates. Cannot be null.</param>
/// <param name="targets">One or more target executors that will receive the fan-out edge. Cannot be null or empty.</param>
/// <returns>The current instance of <see cref="WorkflowBuilder"/>.</returns>
public WorkflowBuilder AddFanOutEdge(ExecutorIsh source, params ExecutorIsh[] targets)
=> this.AddFanOutEdge<object>(source, null, targets);
internal static Func<object?, int, IEnumerable<int>>? CreateEdgeAssignerFunc<T>(Func<T?, int, IEnumerable<int>>? partitioner)
{
if (partitioner == null)
{
return null;
}
return (object? maybeObj, int count) =>
{
if (typeof(T) != typeof(object) && maybeObj is PortableValue portableValue)
{
maybeObj = portableValue.AsType(typeof(T));
}
return partitioner(maybeObj is T typed ? typed : default, count);
};
}
/// <summary>
/// Adds a fan-out edge from the specified source executor to one or more target executors, optionally using a
/// custom partitioning function.
@@ -166,7 +243,7 @@ public class WorkflowBuilder
/// If null, messages will route to all targets.</param>
/// <param name="targets">One or more target executors that will receive the fan-out edge. Cannot be null or empty.</param>
/// <returns>The current instance of <see cref="WorkflowBuilder"/>.</returns>
public WorkflowBuilder AddFanOutEdge(ExecutorIsh source, Func<object?, int, IEnumerable<int>>? partitioner = null, params ExecutorIsh[] targets)
public WorkflowBuilder AddFanOutEdge<T>(ExecutorIsh source, Func<T?, int, IEnumerable<int>>? partitioner = null, params ExecutorIsh[] targets)
{
Throw.IfNull(source);
Throw.IfNullOrEmpty(targets);
@@ -174,7 +251,8 @@ public class WorkflowBuilder
FanOutEdgeData fanOutEdge = new(
this.Track(source).Id,
targets.Select(target => this.Track(target).Id).ToList(),
partitioner);
this.TakeEdgeId(),
CreateEdgeAssignerFunc<T>(partitioner));
this.EnsureEdgesFor(source.Id).Add(new(fanOutEdge));
@@ -198,7 +276,8 @@ public class WorkflowBuilder
FanInEdgeData edgeData = new(
sources.Select(source => this.Track(source).Id).ToList(),
this.Track(target).Id);
this.Track(target).Id,
this.TakeEdgeId());
foreach (string sourceId in edgeData.SourceIds)
{
@@ -28,18 +28,22 @@ public static class WorkflowBuilderExtensions
{
Throw.IfNullOrEmpty(executors);
Func<object?, bool> predicate = WorkflowBuilder.CreateConditionFunc<TMessage>((Func<object?, bool>)IsAllowedType)!;
if (executors.Length == 1)
{
return builder.AddEdge(source, executors[0], IsAllowedType);
return builder.AddEdge(source, executors[0], predicate);
}
return builder.AddSwitch(source,
(switch_) =>
{
switch_.AddCase(IsAllowedType, executors);
switch_.AddCase(predicate, executors);
});
bool IsAllowedType(object? message) => message is TMessage;
// The reason we can check for "not null" here is that CreateConditionFunc<T> will do the correct unwrapping
// logic for PortableValues.
bool IsAllowedType(object? message) => message is not null;
}
/// <summary>
@@ -54,18 +58,22 @@ public static class WorkflowBuilderExtensions
{
Throw.IfNullOrEmpty(executors);
Func<object?, bool> predicate = WorkflowBuilder.CreateConditionFunc<TMessage>((Func<object?, bool>)IsAllowedType)!;
if (executors.Length == 1)
{
return builder.AddEdge(source, executors[0], IsAllowedType);
return builder.AddEdge(source, executors[0], predicate);
}
return builder.AddSwitch(source,
(switch_) =>
{
switch_.AddCase(IsAllowedType, executors);
switch_.AddCase(predicate, executors);
});
bool IsAllowedType(object? message) => message is not TMessage;
// The reason we can check for "null" here is that CreateConditionFunc<T> will do the correct unwrapping
// logic for PortableValues.
bool IsAllowedType(object? message) => message is null;
}
/// <summary>
@@ -158,6 +166,8 @@ public static class WorkflowBuilderExtensions
/// to access the aggregated output directly. The completion condition can be used to implement custom termination
/// logic, such as early stopping when a desired result is reached.</remarks>
/// <typeparam name="TInput">The type of input items processed by the workflow.</typeparam>
/// <typeparam name="TIntermediate">The type of items generated by the <paramref name="outputSource"/>,
/// and aggregated by the <paramref name="aggregator"/>.</typeparam>
/// <typeparam name="TResult">The type of aggregated result produced by the workflow.</typeparam>
/// <param name="builder">The workflow builder used to construct the workflow and define its execution graph.</param>
/// <param name="outputSource">The executor that produces output items to be collected and aggregated. Cannot be null.</param>
@@ -166,18 +176,18 @@ public static class WorkflowBuilderExtensions
/// aggregated result. If null, the workflow will not raise a <see cref="WorkflowCompletedEvent"/>.</param>
/// <returns>A workflow that collects output from the specified executor, aggregates results, and exposes the aggregated
/// output.</returns>
public static Workflow<TInput, TResult> BuildWithOutput<TInput, TResult>(
public static Workflow<TInput, TResult> BuildWithOutput<TInput, TIntermediate, TResult>(
this WorkflowBuilder builder,
ExecutorIsh outputSource,
StreamingAggregator<TInput, TResult> aggregator,
Func<TInput, TResult?, bool>? completionCondition = null)
StreamingAggregator<TIntermediate, TResult> aggregator,
Func<TIntermediate, TResult?, bool>? completionCondition = null)
{
Throw.IfNull(outputSource);
Throw.IfNull(aggregator);
OutputCollectorExecutor<TInput, TResult> outputSink = new(aggregator, completionCondition);
OutputCollectorExecutor<TIntermediate, TResult> outputSink = new(aggregator, completionCondition);
// TODO: Check taht the outputSource has a TResult output?
// TODO: Check that the outputSource has a TResult output?
builder.AddEdge(outputSource, outputSink);
Workflow<TInput> workflow = builder.Build<TInput>();
@@ -1,10 +1,19 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Text.Json.Serialization;
namespace Microsoft.Agents.Workflows;
/// <summary>
/// Base class for <see cref="Workflow"/>-scoped events.
/// </summary>
[JsonDerivedType(typeof(ExecutorEvent))]
[JsonDerivedType(typeof(SuperStepEvent))]
[JsonDerivedType(typeof(WorkflowStartedEvent))]
[JsonDerivedType(typeof(WorkflowCompletedEvent))]
[JsonDerivedType(typeof(WorkflowErrorEvent))]
[JsonDerivedType(typeof(WorkflowWarningEvent))]
[JsonDerivedType(typeof(RequestInfoEvent))]
public class WorkflowEvent(object? data = null)
{
/// <summary>
@@ -30,6 +30,6 @@ public static class WorkflowHostingExtensions
{ "data", request.Data}
};
return new FunctionCallContent(request.RequestId, request.Port.Id, parameters);
return new FunctionCallContent(request.RequestId, request.PortInfo.PortId, parameters);
}
}
@@ -3,6 +3,8 @@
using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
using System.Text.Json.Serialization;
using Microsoft.Agents.Workflows.Checkpointing;
using Microsoft.Agents.Workflows.Execution;
using Microsoft.Extensions.AI;
using static Microsoft.Agents.Workflows.WorkflowMessageStore;
@@ -53,9 +55,38 @@ internal static partial class WorkflowsJsonUtilities
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
NumberHandling = JsonNumberHandling.AllowReadingFromString)]
// Agent abstraction types
// Checkpointing Types
[JsonSerializable(typeof(Checkpoint))]
[JsonSerializable(typeof(CheckpointInfo))]
[JsonSerializable(typeof(PortableValue))]
[JsonSerializable(typeof(PortableMessageEnvelope))]
// Runtime State Types
[JsonSerializable(typeof(ScopeKey))]
[JsonSerializable(typeof(ScopeId))]
[JsonSerializable(typeof(ExecutorIdentity))]
[JsonSerializable(typeof(RunnerStateData))]
// Workflow Representation Types
[JsonSerializable(typeof(WorkflowInfo))]
[JsonSerializable(typeof(EdgeConnection))]
// Workflow-as-Agent
[JsonSerializable(typeof(StoreState))]
// Message Types
[JsonSerializable(typeof(ChatMessage))]
[JsonSerializable(typeof(ExternalRequest))]
[JsonSerializable(typeof(ExternalResponse))]
[JsonSerializable(typeof(TurnToken))]
// Event Types
//[JsonSerializable(typeof(WorkflowEvent))]
// Currently cannot be serialized because it includes Exceptions.
// We'll need a way to marshal this correct in the AgentRuntime case.
// For now this is okay, because we never serialize WorkflowEvents into
// checkpoints.
[JsonSerializable(typeof(JsonElement))]
[ExcludeFromCodeCoverage]
internal sealed partial class JsonContext : JsonSerializerContext;
}
@@ -19,7 +19,7 @@ public class EdgeMapSmokeTests
Dictionary<string, HashSet<Edge>> workflowEdges = new();
FanInEdgeData edgeData = new(["executor1", "executor2"], "executor3");
FanInEdgeData edgeData = new(["executor1", "executor2"], "executor3", new EdgeId(0));
Edge fanInEdge = new(edgeData);
workflowEdges["executor1"] = [fanInEdge];
@@ -31,7 +31,7 @@ public class EdgeRunnerTests
runContext.Executors["executor1"] = new ForwardMessageExecutor<string>("executor1");
runContext.Executors["executor2"] = new ForwardMessageExecutor<string>("executor2");
DirectEdgeData edgeData = new("executor1", "executor2", condition);
DirectEdgeData edgeData = new("executor1", "executor2", new EdgeId(0), condition);
DirectEdgeRunner runner = new(runContext, edgeData);
MessageEnvelope envelope = new(MessageVariant1, targetId: targetId);
@@ -90,7 +90,7 @@ public class EdgeRunnerTests
? (targetMatch.Value ? "executor2" : "executor1")
: null;
FanOutEdgeData edgeData = new("executor1", ["executor2", "executor3"], assigner);
FanOutEdgeData edgeData = new("executor1", ["executor2", "executor3"], new EdgeId(0), assigner);
FanOutEdgeRunner runner = new(runContext, edgeData);
MessageEnvelope envelope = new("test", targetId: targetId);
@@ -145,7 +145,7 @@ public class EdgeRunnerTests
runContext.Executors["executor2"] = new ForwardMessageExecutor<string>("executor2");
runContext.Executors["executor3"] = new ForwardMessageExecutor<string>("executor3");
FanInEdgeData edgeData = new(["executor1", "executor2"], "executor3");
FanInEdgeData edgeData = new(["executor1", "executor2"], "executor3", new EdgeId(0));
FanInEdgeRunner runner = new(runContext, edgeData);
// Step 1: Send message from executor1, should not forward yet.
@@ -0,0 +1,43 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.Text.Json;
using System.Threading.Tasks;
using Microsoft.Agents.Workflows.Checkpointing;
namespace Microsoft.Agents.Workflows.UnitTests;
internal sealed class InMemoryJsonStore : JsonCheckpointStore
{
private readonly Dictionary<string, RunCheckpointCache<JsonElement>> _store = new();
private RunCheckpointCache<JsonElement> EnsureRunStore(string runId)
{
if (!this._store.TryGetValue(runId, out RunCheckpointCache<JsonElement>? runStore))
{
runStore = this._store[runId] = new();
}
return runStore;
}
public override ValueTask<CheckpointInfo> CreateCheckpointAsync(string runId, JsonElement value, CheckpointInfo? parent = null)
{
return new(this.EnsureRunStore(runId).Add(runId, value));
}
public override ValueTask<JsonElement> RetrieveCheckpointAsync(string runId, CheckpointInfo key)
{
if (!this.EnsureRunStore(runId).TryGet(key, out JsonElement result))
{
throw new KeyNotFoundException("Could not retrieve checkpoint with id {key.CheckpointId} for run {runId}");
}
return new(result);
}
public override ValueTask<IEnumerable<CheckpointInfo>> RetrieveIndexAsync(string runId, CheckpointInfo? withParent = null)
{
return new(this.EnsureRunStore(runId).Index);
}
}
@@ -127,7 +127,7 @@ public class InProcessStateTests
.AddEdge(writer, validator, MaxTurns(4))
.AddEdge(validator, writer, MaxTurns(4)).Build<TurnToken>();
Checkpointed<Run> checkpointed = await InProcessExecution.RunAsync(workflow, new(), new CheckpointManager());
Checkpointed<Run> checkpointed = await InProcessExecution.RunAsync(workflow, new(), CheckpointManager.Default);
checkpointed.Checkpoints.Should().HaveCount(6);
checkpointed.Run.Status.Should().Be(RunStatus.Idle);
@@ -0,0 +1,653 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using System.Threading;
using FluentAssertions;
using Microsoft.Agents.Workflows.Checkpointing;
using Microsoft.Agents.Workflows.Execution;
using Microsoft.Extensions.AI;
namespace Microsoft.Agents.Workflows.UnitTests;
public class JsonSerializationTests
{
private static JsonSerializerOptions TestCustomSerializedJsonOptions
{
get
{
JsonSerializerOptions options = new(TestJsonContext.Default.Options);
options.MakeReadOnly();
return options;
}
}
private static int s_nextEdgeId = 0;
private static EdgeId TakeEdgeId() => new(Interlocked.Increment(ref s_nextEdgeId));
private static T RunJsonRoundtrip<T>(T value, JsonSerializerOptions? externalOptions = null, Expression<Func<T, bool>>? predicate = null)
{
JsonMarshaller marshaller = new(externalOptions);
JsonElement element = marshaller.Marshal<T>(value);
T deserialized = marshaller.Marshal<T>(element);
if (deserialized != null)
{
if (predicate != null)
{
deserialized.Should().Match<T>(predicate);
}
return deserialized;
}
Debug.Fail($"Could not roundtrip type '{typeof(T).Name}'. JSON = '{element}'.");
throw new NotSupportedException($"Could not roundtrip type '{typeof(T).Name}'.");
}
[Fact]
public void Test_EdgeConnection_JsonRoundtrip()
{
EdgeConnection connection = new(new List<string> { "Source1", "Source2" }, new List<string> { "Sink1", "Sink2" });
RunJsonRoundtrip(connection, predicate: connection.CreateValidator());
}
[Fact]
public void Test_TypeId_JsonRoundtrip()
{
TypeId type = new(typeof(Type));
RunJsonRoundtrip(type, predicate: CreateValidator());
Expression<Func<TypeId, bool>> CreateValidator()
{
return deserialized => deserialized.AssemblyName == type.AssemblyName &&
deserialized.TypeName == type.TypeName &&
deserialized.IsMatch(typeof(Type));
}
}
[Fact]
public void Test_ExecutorInfo_JsonRoundtrip()
{
ExecutorInfo executorInfo = new(new(typeof(ForwardMessageExecutor<string>)), "ForwardString");
RunJsonRoundtrip(executorInfo, predicate: CreateValidator());
Expression<Func<ExecutorInfo, bool>> CreateValidator()
{
return deserialized => deserialized.ExecutorId == executorInfo.ExecutorId &&
// Rely on the TypeId test to probe TypeId serialization - just validate that we got a functional TypeId
deserialized.ExecutorType.IsMatch(typeof(ForwardMessageExecutor<string>));
}
}
private static InputPort TestPort => InputPort.Create<string, int>("StringToInt");
private static InputPortInfo TestPortInfo => TestPort.ToPortInfo();
[Fact]
public void Test_InputPortInfo_JsonRoundtrip()
{
RunJsonRoundtrip(TestPortInfo, predicate: TestPort.CreatePortInfoValidator());
}
private static DirectEdgeInfo TestDirectEdgeInfo_NoCondition => new(new("SourceExecutor", "TargetExecutor", TakeEdgeId(), condition: null));
private static DirectEdgeInfo TestDirectEdgeInfo_Condition => new(new("SourceExecutor", "TargetExecutor", TakeEdgeId(), condition: msg => msg is not null));
[Fact]
public void Test_DirectEdgeInfo_JsonRoundtrip()
{
RunJsonRoundtrip(TestDirectEdgeInfo_NoCondition, predicate: TestDirectEdgeInfo_NoCondition.CreateValidator());
RunJsonRoundtrip(TestDirectEdgeInfo_Condition, predicate: TestDirectEdgeInfo_Condition.CreateValidator());
}
private static FanOutEdgeInfo TestFanOutEdgeInfo_NoAssigner => new(new("SourceExecutor", ["TargetExecutor1", "TargetExecutor2"], TakeEdgeId(), assigner: null));
private static FanOutEdgeInfo TestFanOutEdgeInfo_Assigner => new(new("SourceExecutor", ["TargetExecutor1", "TargetExecutor2"], TakeEdgeId(), assigner: (msg, count) => []));
[Fact]
public void Test_FanOutEdgeInfo_JsonRoundtrip()
{
RunJsonRoundtrip(TestFanOutEdgeInfo_NoAssigner, predicate: TestFanOutEdgeInfo_NoAssigner.CreateValidator());
RunJsonRoundtrip(TestFanOutEdgeInfo_Assigner, predicate: TestFanOutEdgeInfo_Assigner.CreateValidator());
}
private static FanInEdgeData TestFanInEdgeData => new(["SourceExecutor1", "SourceExecutor2"], "TargetExecutor", TakeEdgeId());
private static FanInEdgeInfo TestFanInEdgeInfo => new(TestFanInEdgeData);
[Fact]
public void Test_FanInEdgeInfo_JsonRoundtrip()
{
RunJsonRoundtrip(TestFanInEdgeInfo, predicate: TestFanInEdgeInfo.CreateValidator());
}
private static EdgeInfo TestEdgeInfo_DirectNoCondition { get; } = TestDirectEdgeInfo_NoCondition;
private static EdgeInfo TestEdgeInfo_DirectCondition { get; } = TestDirectEdgeInfo_Condition;
private static EdgeInfo TestEdgeInfo_FanOutNoAssigner { get; } = TestFanOutEdgeInfo_NoAssigner;
private static EdgeInfo TestEdgeInfo_FanOutAssigner { get; } = TestFanOutEdgeInfo_Assigner;
private static EdgeInfo TestEdgeInfo_FanIn { get; } = TestFanInEdgeInfo;
[Fact]
public void Test_EdgeInfoPolymorphism_JsonRoundtrip()
{
RunJsonRoundtrip(TestEdgeInfo_DirectNoCondition, predicate: TestEdgeInfo_DirectNoCondition.CreatePolyValidator());
RunJsonRoundtrip(TestEdgeInfo_DirectCondition, predicate: TestEdgeInfo_DirectCondition.CreatePolyValidator());
RunJsonRoundtrip(TestEdgeInfo_FanOutNoAssigner, predicate: TestEdgeInfo_FanOutNoAssigner.CreatePolyValidator());
RunJsonRoundtrip(TestEdgeInfo_FanOutAssigner, predicate: TestEdgeInfo_FanOutAssigner.CreatePolyValidator());
RunJsonRoundtrip(TestEdgeInfo_FanIn, predicate: TestEdgeInfo_FanIn.CreatePolyValidator());
}
private const string ForwardStringId = nameof(s_forwardString);
private const string ForwardIntId = nameof(s_forwardInt);
private static readonly ExecutorIdentity s_forwardString = new() { Id = ForwardStringId };
private static readonly ExecutorIdentity s_forwardInt = new() { Id = ForwardIntId };
private const string IntToStringId = nameof(IntToString);
private const string StringToIntId = nameof(StringToInt);
private static InputPortInfo IntToString => InputPort.Create<int, string>(IntToStringId).ToPortInfo();
private static InputPortInfo StringToInt => InputPort.Create<string, int>(StringToIntId).ToPortInfo();
private static Workflow<string, int> CreateTestWorkflow()
{
ForwardMessageExecutor<string> forwardString = new(ForwardStringId);
ForwardMessageExecutor<int> forwardInt = new(ForwardIntId);
InputPort stringToInt = InputPort.Create<string, int>(StringToIntId);
InputPort intToString = InputPort.Create<int, string>(IntToStringId);
WorkflowBuilder builder = new(forwardString);
builder.AddEdge(forwardString, stringToInt)
.AddEdge(stringToInt, forwardInt)
.AddEdge(forwardInt, intToString);
Workflow<string, int> workflow = builder.BuildWithOutput<string, int, int>(
intToString,
StreamingAggregators.Last<int>(), (int _, int __) => true);
return workflow;
}
private static WorkflowInfo TestWorkflowInfo => CreateTestWorkflow().ToWorkflowInfo();
private static void ValidateWorkflowInfo(WorkflowInfo actual, WorkflowInfo prototype)
{
ValidateExecutorDictionary(prototype.Executors, prototype.Edges, actual.Executors, actual.Edges);
ValidateInputPorts(prototype.InputPorts, actual.InputPorts);
actual.InputType.Should().Match<TypeId>(prototype.InputType.CreateValidator());
actual.StartExecutorId.Should().Be(prototype.StartExecutorId);
actual.OutputType.Should().NotBeNull().And.Match<TypeId>(prototype.OutputType!.CreateValidator());
actual.OutputCollectorId.Should().NotBeNull().And.Be(prototype.OutputCollectorId);
void ValidateExecutorDictionary(Dictionary<string, ExecutorInfo> expected,
Dictionary<string, List<EdgeInfo>> expectedEdges,
Dictionary<string, ExecutorInfo> actual,
Dictionary<string, List<EdgeInfo>> actualEdges)
{
actual.Should().HaveCount(expected.Count);
actualEdges.Should().HaveCount(expectedEdges.Count);
foreach (string key in expected.Keys)
{
actual.Should().ContainKey(key);
ExecutorInfo actualValue = actual[key];
ExecutorInfo expectedValue = expected[key];
actualValue.Should().Match<ExecutorInfo>(expectedValue.CreateValidator());
if (expectedEdges.TryGetValue(key, out List<EdgeInfo>? expectedEdgeList))
{
List<EdgeInfo>? actualEdgeList = actualEdges.Should().ContainKey(key).WhoseValue;
actualEdgeList.Should().NotBeNull();
ValidateExecutorEdges(expectedEdgeList, actualEdgeList);
}
}
}
void ValidateExecutorEdges(List<EdgeInfo> expected, List<EdgeInfo> actual)
{
actual.Should().HaveCount(expected.Count);
foreach (EdgeInfo expectedEdge in expected)
{
actual.Should().ContainSingle(edge => edge.CreatePolyValidator().Compile()(edge));
}
}
void ValidateInputPorts(HashSet<InputPortInfo> expected, HashSet<InputPortInfo> actual)
=> actual.Should().HaveCount(expected.Count).And.IntersectWith(expected);
}
[Fact]
public void Test_WorkflowInfo_JsonRoundtrip()
{
WorkflowInfo prototype = TestWorkflowInfo;
JsonMarshaller marshaller = new();
JsonElement jsonElement = marshaller.Marshal(prototype, typeof(WorkflowInfo));
WorkflowInfo deserialized = marshaller.Marshal<WorkflowInfo>(jsonElement);
ValidateWorkflowInfo(deserialized, prototype);
}
private static ExecutorIdentity TestIdentity => new() { Id = "Executor1" };
[Fact]
public void Test_ExecutorIdentity_JsonRoundtrip()
{
RunJsonRoundtrip(TestIdentity, predicate: TestIdentity.CreateValidator());
RunJsonRoundtrip(ExecutorIdentity.None, predicate: ExecutorIdentity.None.CreateValidator());
}
private static ScopeId TestScopeId_Private => new("Executor1", null);
private static ScopeId TestScopeId_Public => new("Executor1", "Scope1");
[Fact]
public void Test_ScopeId_JsonRoundtrip()
{
RunJsonRoundtrip(TestScopeId_Private, predicate: TestScopeId_Private.CreateValidator());
RunJsonRoundtrip(TestScopeId_Public, predicate: TestScopeId_Public.CreateValidator());
}
private static ScopeKey TestScopeKey_Private => new(TestScopeId_Private, "Key1");
private static ScopeKey TestScopeKey_Public => new(TestScopeId_Public, "Key1");
[Fact]
public void Test_ScopeKey_JsonRoundtrip()
{
RunJsonRoundtrip(TestScopeKey_Private, predicate: TestScopeKey_Private.CreateValidator());
RunJsonRoundtrip(TestScopeKey_Public, predicate: TestScopeKey_Public.CreateValidator());
}
private static ExternalRequest TestExternalRequest => ExternalRequest.Create(TestPort, "Request1", "TestData");
[Fact]
public void SanityCheck_JsonTypeInfo()
{
JsonTypeInfo? info = WorkflowsJsonUtilities.JsonContext.Default.GetTypeInfo(typeof(string));
info.Should().NotBeNull();
}
[Fact]
public void Test_PortableValue_JsonRoundtrip_BuiltInType()
{
PortableValue value = new("TestString");
PortableValue result = RunJsonRoundtrip(value);
result.Should().Be(value);
// Also validate that we can extract the value as the correct type
string? extracted = result.As<string>();
extracted.Should().Be("TestString");
// And that we can't extract it as an incorrect type
result.Is<int>().Should().BeFalse();
}
[Fact]
public void Test_PortableValue_JsonRoundTrip_InternalType()
{
ChatMessage message = new(ChatRole.User, "Hello, world!");
PortableValue value = new(message);
PortableValue result = RunJsonRoundtrip(value);
result.Should().Be(value);
// Also validate that we can extract the value as the correct type
ChatMessage? chatMessage = result.As<ChatMessage>();
chatMessage.Should().NotBeNull();
chatMessage.Role.Should().Be(ChatRole.User);
chatMessage.Text.Should().Be("Hello, world!");
// And that we can't extract it as an incorrect type
result.Is<int>().Should().BeFalse();
}
[Fact]
public void Test_PortableValue_JsonRoundTrip_CustomType()
{
TestJsonSerializable test = new() { Id = 42, Name = "Test" };
PortableValue value = new(test);
PortableValue result = RunJsonRoundtrip(value, TestCustomSerializedJsonOptions);
result.Should().Be(value);
// Also validate that we can extract the value as the correct type
TestJsonSerializable? extracted = result.As<TestJsonSerializable>();
extracted.Should().NotBeNull();
extracted.Id.Should().Be(42);
extracted.Name.Should().Be("Test");
// And that we can't extract it as an incorrect type
result.Is<int>().Should().BeFalse();
}
private static void ValidateExternalRequest(ExternalRequest actual, ExternalRequest expected)
{
bool isIdEqual = actual.RequestId == expected.RequestId;
bool isPortEqual = actual.PortInfo == expected.PortInfo;
bool isDataEqual = actual.Data == expected.Data;
isIdEqual.Should().BeTrue();
isPortEqual.Should().BeTrue();
isDataEqual.Should().BeTrue();
}
[Fact]
public void Test_ExternalRequest_JsonRoundtrip()
{
ExternalRequest result = RunJsonRoundtrip(TestExternalRequest);
ValidateExternalRequest(result, TestExternalRequest);
}
private static ExternalResponse TestExternalResponse => TestExternalRequest.CreateResponse(123);
[Fact]
public void Test_ExternalResponse_JsonRoundtrip()
{
ExternalResponse result = RunJsonRoundtrip(TestExternalResponse);
bool isIdEqual = result.RequestId == TestExternalResponse.RequestId;
bool isPortEqual = result.PortInfo == TestExternalResponse.PortInfo;
bool isDataEqual = result.Data == TestExternalResponse.Data;
isIdEqual.Should().BeTrue();
isPortEqual.Should().BeTrue();
isDataEqual.Should().BeTrue();
}
[Fact]
public void Test_PortableMessageEnvelope_JsonRoundtrip_BuiltInType()
{
string message = "TestMessage";
MessageEnvelope envelope = new(message, new TypeId(typeof(object)), targetId: "Target1");
PortableMessageEnvelope value = new(envelope);
PortableMessageEnvelope result = RunJsonRoundtrip(value);
bool isTypeEqual = result.MessageType == value.MessageType;
bool isTargetEqual = result.TargetId == value.TargetId;
bool isMessageEqual = result.Message == value.Message;
isTypeEqual.Should().BeTrue();
isTargetEqual.Should().BeTrue();
isMessageEqual.Should().BeTrue();
MessageEnvelope reconstructed = result.ToMessageEnvelope();
reconstructed.MessageType.Should().Be(envelope.MessageType);
reconstructed.TargetId.Should().Be(envelope.TargetId);
reconstructed.Message.Should().Be(envelope.Message);
}
[Fact]
public void Test_PortableMessageEnvelope_JsonRoundtrip_InternalType()
{
ChatMessage message = new(ChatRole.User, "Hello, world!");
MessageEnvelope envelope = new(message, new TypeId(typeof(object)), targetId: "Target1");
PortableMessageEnvelope value = new(envelope);
PortableMessageEnvelope result = RunJsonRoundtrip(value);
bool isTypeEqual = result.MessageType == value.MessageType;
bool isTargetEqual = result.TargetId == value.TargetId;
bool isMessageEqual = result.Message == value.Message;
isTypeEqual.Should().BeTrue();
isTargetEqual.Should().BeTrue();
isMessageEqual.Should().BeTrue();
MessageEnvelope reconstructed = result.ToMessageEnvelope();
reconstructed.MessageType.Should().Be(envelope.MessageType);
reconstructed.TargetId.Should().Be(envelope.TargetId);
// Unfortunately, ChatMessage does not contain an "equality" comparer, so we need to explicitly pull it out
// Simulate what PortableValue does in .Equals()
Type expectedType = envelope.Message.GetType();
object? maybeReconstructedMessage = ((PortableValue)reconstructed.Message)!.AsType(expectedType);
maybeReconstructedMessage.Should().NotBeNull()
.And.BeOfType<ChatMessage>()
.And.Match(message.CreateValidatorCheckingText());
}
[Fact]
public void Test_PortableMessageEnvelope_JsonRoundtrip_CustomType()
{
TestJsonSerializable message = new() { Id = 42, Name = "Test" };
MessageEnvelope envelope = new(message, new TypeId(typeof(object)), targetId: "Target1");
PortableMessageEnvelope value = new(envelope);
PortableMessageEnvelope result = RunJsonRoundtrip(value, TestCustomSerializedJsonOptions);
bool isTypeEqual = result.MessageType == value.MessageType;
bool isTargetEqual = result.TargetId == value.TargetId;
bool isMessageEqual = result.Message == value.Message;
isTypeEqual.Should().BeTrue();
isTargetEqual.Should().BeTrue();
isMessageEqual.Should().BeTrue();
MessageEnvelope reconstructed = result.ToMessageEnvelope();
reconstructed.MessageType.Should().Be(envelope.MessageType);
reconstructed.TargetId.Should().Be(envelope.TargetId);
reconstructed.Message.Should().Be(envelope.Message);
}
private static RunnerStateData TestRunnerStateData
{
get
{
return new(
[ForwardStringId, ForwardIntId],
CreateQueuedMessages(),
outstandingRequests: [TestExternalRequest]
);
Dictionary<ExecutorIdentity, List<PortableMessageEnvelope>> CreateQueuedMessages()
{
Dictionary<ExecutorIdentity, List<PortableMessageEnvelope>> result = new();
MessageEnvelope externalEnvelope = new(TestExternalResponse);
result.Add(ExecutorIdentity.None, [new(externalEnvelope)]);
MessageEnvelope internalEnvelope = new("InternalMessage");
result.Add("TestExecutor1", [new(internalEnvelope)]);
return result;
}
}
}
private static void ValidateRunnerStateData(RunnerStateData result, RunnerStateData prototype)
{
Assert.Collection(result.InstantiatedExecutors,
prototype.InstantiatedExecutors.Select(
prototype =>
(Action<string>)(actual => actual.Should().Be(prototype))).ToArray());
result.QueuedMessages.Should().HaveCount(prototype.QueuedMessages.Count);
foreach (ExecutorIdentity key in prototype.QueuedMessages.Keys)
{
result.QueuedMessages.Should().ContainKey(key);
List<PortableMessageEnvelope> actualList = result.QueuedMessages[key];
List<PortableMessageEnvelope> expectedList = prototype.QueuedMessages[key];
actualList.Should().HaveCount(expectedList.Count);
for (int i = 0; i < expectedList.Count; i++)
{
PortableMessageEnvelope actual = actualList[i];
PortableMessageEnvelope expected = expectedList[i];
actual.MessageType.Should().Be(expected.MessageType);
actual.TargetId.Should().Be(expected.TargetId);
actual.Message.Should().Be(expected.Message);
}
}
result.OutstandingRequests.Should().HaveCount(prototype.OutstandingRequests.Count);
Assert.Collection(result.OutstandingRequests,
prototype.OutstandingRequests.Select(
expected =>
(Action<ExternalRequest>)(actual => ValidateExternalRequest(actual, expected))).ToArray());
}
[Fact]
public void Test_RunnerStateData_JsonRoundtrip()
{
RunnerStateData prototype = TestRunnerStateData;
RunnerStateData result = RunJsonRoundtrip(prototype);
ValidateRunnerStateData(result, prototype);
}
private static FanInEdgeState TestFanInEdgeState => new(TestFanInEdgeData);
private static PortableValue CreateEdgeState<TMessage>(TMessage message) where TMessage : notnull
{
FanInEdgeState state = TestFanInEdgeState;
_ = state.ProcessMessage("SourceExecutor1", new MessageEnvelope(message, typeof(TMessage)));
return new(state);
}
private static TestJsonSerializable TestCustomSerializable => new() { Id = 42, Name = nameof(TestCustomSerializable) };
private static Dictionary<EdgeId, PortableValue> TestEdgeState
{
get
{
return new()
{
[TakeEdgeId()] = CreateEdgeState("Hello, world!"),
[TakeEdgeId()] = CreateEdgeState(TestExternalResponse),
[TakeEdgeId()] = CreateEdgeState(TestCustomSerializable)
};
}
}
private static void ValidateEdgeStateData(Dictionary<EdgeId, PortableValue> result, Dictionary<EdgeId, PortableValue> prototype)
{
result.Should().HaveCount(prototype.Count);
foreach (EdgeId id in prototype.Keys)
{
result.Should().ContainKey(id)
.And.Subject[id].Should().Be(prototype[id])
.And.Subject.As<PortableValue>()
.As<FanInEdgeState>().Should().NotBeNull()
.And.Match(CreateValidator(prototype[id].As<FanInEdgeState>()!));
}
Expression<Func<FanInEdgeState, bool>> CreateValidator(FanInEdgeState prototype)
{
return actual => actual.Unseen.SetEquals(prototype.Unseen) &&
actual.SourceIds.SequenceEqual(prototype.SourceIds) &&
actual.PendingMessages.Zip(prototype.PendingMessages,
(actualMessage, expectedMessage) => actualMessage.MessageType == expectedMessage.MessageType &&
actualMessage.TargetId == expectedMessage.TargetId &&
actualMessage.Message.Equals(expectedMessage.Message)).All(v => v);
}
}
[Fact]
public void Test_EdgeStateData_JsonRoundtrip()
{
Dictionary<EdgeId, PortableValue> value = TestEdgeState;
Dictionary<EdgeId, PortableValue> result = RunJsonRoundtrip(value, TestCustomSerializedJsonOptions);
ValidateEdgeStateData(result, value);
}
private static ScopeKey TestScopeKey1 => new(StringToIntId, null, "Key1");
private static ScopeKey TestScopeKey2 => new(StringToIntId, "Shared", "Key2");
private static ScopeKey TestScopeKey3 => new(IntToStringId, "Shared", "Key3");
private static ChatMessage TestUserMessage => new(ChatRole.User, "Hello");
private static Dictionary<ScopeKey, PortableValue> TestStateData
{
get
{
return new()
{
[TestScopeKey1] = new("Lorem Ipsum"),
[TestScopeKey2] = new(TestUserMessage),
[TestScopeKey3] = new(TestCustomSerializable)
};
}
}
private static void ValidateStateData(Dictionary<ScopeKey, PortableValue> result, Dictionary<ScopeKey, PortableValue> prototype)
{
result.Should().HaveCount(prototype.Count);
foreach (ScopeKey key in prototype.Keys)
{
PortableValue state =
result.Should().ContainKey(key)
.And.Subject[key].Should().Be(prototype[key])
.And.Subject.As<PortableValue>();
switch (key.Key)
{
case "Key1":
state.As<string>().Should().Be("Lorem Ipsum");
break;
case "Key2":
ChatMessage? maybeMessage = state.As<ChatMessage>();
maybeMessage.Should().NotBeNull()
.And.Match(TestUserMessage.CreateValidatorCheckingText());
break;
case "Key3":
state.As<TestJsonSerializable>().Should().Be(TestCustomSerializable);
break;
default:
throw new NotImplementedException($"Missing validation for key '{key.Key}'");
}
}
}
[Fact]
public void Test_ExecutorStateData_JsonRoundTrip()
{
Dictionary<ScopeKey, PortableValue> value = TestStateData;
Dictionary<ScopeKey, PortableValue> result = RunJsonRoundtrip(value, TestCustomSerializedJsonOptions);
ValidateStateData(result, value);
}
private static readonly string s_runId = Guid.NewGuid().ToString("N");
private static readonly string s_parentCheckpointId = Guid.NewGuid().ToString("N");
private static CheckpointInfo TestParentCheckpointInfo => new(s_runId, s_parentCheckpointId);
[Fact]
public void Test_Checkpoint_JsonRoundTrip()
{
Checkpoint prototype = new(12, TestWorkflowInfo, TestRunnerStateData, TestStateData, TestEdgeState, TestParentCheckpointInfo);
Checkpoint result = RunJsonRoundtrip(prototype, TestCustomSerializedJsonOptions);
result.Should().Match((Checkpoint checkpoint) => checkpoint.StepNumber == prototype.StepNumber);
result.Parent.Should().Be(prototype.Parent);
ValidateWorkflowInfo(result.Workflow, prototype.Workflow);
ValidateRunnerStateData(result.RunnerData, prototype.RunnerData);
ValidateStateData(result.StateData, prototype.StateData);
ValidateEdgeStateData(result.EdgeStateData, prototype.EdgeStateData);
}
}
@@ -106,51 +106,53 @@ public class RepresentationTests
[Fact]
public void Test_EdgeInfos()
{
int edgeId = 0;
// Direct Edges
Edge directEdgeNoCondition = new(new DirectEdgeData(Source(1), Sink(2)));
Edge directEdgeNoCondition = new(new DirectEdgeData(Source(1), Sink(2), TakeEdgeId()));
RunEdgeInfoMatchTest(directEdgeNoCondition);
Edge directEdgeNoCondition2 = new(new DirectEdgeData(Source(1), Sink(2)));
Edge directEdgeNoCondition2 = new(new DirectEdgeData(Source(1), Sink(2), TakeEdgeId()));
RunEdgeInfoMatchTest(directEdgeNoCondition, directEdgeNoCondition2);
Edge directEdgeNoCondition3 = new(new DirectEdgeData(Source(3), Sink(4)));
Edge directEdgeNoCondition3 = new(new DirectEdgeData(Source(3), Sink(4), TakeEdgeId()));
RunEdgeInfoMatchTest(directEdgeNoCondition, directEdgeNoCondition3, expect: false);
Edge directEdgeWithCondition = new(new DirectEdgeData(Source(3), Sink(4), Condition()));
Edge directEdgeWithCondition = new(new DirectEdgeData(Source(3), Sink(4), TakeEdgeId(), Condition()));
RunEdgeInfoMatchTest(directEdgeWithCondition);
RunEdgeInfoMatchTest(directEdgeNoCondition2, directEdgeWithCondition, expect: false);
RunEdgeInfoMatchTest(directEdgeNoCondition3, directEdgeWithCondition, expect: false);
// FanOut Edges
Edge fanOutEdgeNoAssigner = new(new FanOutEdgeData(Source(1), [Sink(2), Sink(3), Sink(4)]));
Edge fanOutEdgeNoAssigner = new(new FanOutEdgeData(Source(1), [Sink(2), Sink(3), Sink(4)], TakeEdgeId()));
RunEdgeInfoMatchTest(fanOutEdgeNoAssigner);
Edge fanOutEdgeNoAssigner2 = new(new FanOutEdgeData(Source(1), [Sink(2), Sink(3), Sink(4)]));
Edge fanOutEdgeNoAssigner2 = new(new FanOutEdgeData(Source(1), [Sink(2), Sink(3), Sink(4)], TakeEdgeId()));
RunEdgeInfoMatchTest(fanOutEdgeNoAssigner, fanOutEdgeNoAssigner2);
Edge fanOutEdgeNoAssigner3 = new(new FanOutEdgeData(Source(1), [Sink(3), Sink(4), Sink(2)]));
Edge fanOutEdgeNoAssigner3 = new(new FanOutEdgeData(Source(1), [Sink(3), Sink(4), Sink(2)], TakeEdgeId()));
RunEdgeInfoMatchTest(fanOutEdgeNoAssigner, fanOutEdgeNoAssigner3, expect: false); // Order matters (though without Assigner maybe it shouldn't?)
Edge fanOutEdgeNoAssigner4 = new(new FanOutEdgeData(Source(1), [Sink(2), Sink(3), Sink(5)]));
Edge fanOutEdgeNoAssigner5 = new(new FanOutEdgeData(Source(2), [Sink(2), Sink(3), Sink(4)]));
Edge fanOutEdgeNoAssigner4 = new(new FanOutEdgeData(Source(1), [Sink(2), Sink(3), Sink(5)], TakeEdgeId()));
Edge fanOutEdgeNoAssigner5 = new(new FanOutEdgeData(Source(2), [Sink(2), Sink(3), Sink(4)], TakeEdgeId()));
RunEdgeInfoMatchTest(fanOutEdgeNoAssigner, fanOutEdgeNoAssigner4, expect: false); // Identity matters
RunEdgeInfoMatchTest(fanOutEdgeNoAssigner, fanOutEdgeNoAssigner5, expect: false);
Edge fanOutEdgeWithAssigner = new(new FanOutEdgeData(Source(1), [Sink(2), Sink(3), Sink(4)], EdgeAssigner()));
Edge fanOutEdgeWithAssigner = new(new FanOutEdgeData(Source(1), [Sink(2), Sink(3), Sink(4)], TakeEdgeId(), EdgeAssigner()));
RunEdgeInfoMatchTest(fanOutEdgeWithAssigner);
// FanIn Edges
Edge fanInEdge = new(new FanInEdgeData([Source(1), Source(2), Source(3)], Sink(1)));
Edge fanInEdge = new(new FanInEdgeData([Source(1), Source(2), Source(3)], Sink(1), TakeEdgeId()));
RunEdgeInfoMatchTest(fanInEdge);
Edge fanInEdge2 = new(new FanInEdgeData([Source(1), Source(2), Source(3)], Sink(1)));
Edge fanInEdge2 = new(new FanInEdgeData([Source(1), Source(2), Source(3)], Sink(1), TakeEdgeId()));
RunEdgeInfoMatchTest(fanInEdge, fanInEdge2);
Edge fanInEdge3 = new(new FanInEdgeData([Source(2), Source(3), Source(1)], Sink(1)));
Edge fanInEdge3 = new(new FanInEdgeData([Source(2), Source(3), Source(1)], Sink(1), TakeEdgeId()));
RunEdgeInfoMatchTest(fanInEdge, fanInEdge3, expect: false); // Order matters (though for FanIn maybe it shouldn't?)
Edge fanInEdge4 = new(new FanInEdgeData([Source(1), Source(2), Source(4)], Sink(1)));
Edge fanInEdge5 = new(new FanInEdgeData([Source(1), Source(2), Source(3)], Sink(2)));
Edge fanInEdge4 = new(new FanInEdgeData([Source(1), Source(2), Source(4)], Sink(1), TakeEdgeId()));
Edge fanInEdge5 = new(new FanInEdgeData([Source(1), Source(2), Source(3)], Sink(2), TakeEdgeId()));
RunEdgeInfoMatchTest(fanInEdge, fanInEdge4, expect: false); // Identity matters
RunEdgeInfoMatchTest(fanInEdge, fanInEdge5, expect: false);
@@ -161,6 +163,8 @@ public class RepresentationTests
EdgeInfo info = edge.ToEdgeInfo();
info.IsMatch(comparatorEdge).Should().Be(expect);
}
EdgeId TakeEdgeId() => new(edgeId++);
}
[Fact]
@@ -21,8 +21,8 @@ internal static class Step2EntryPoint
RemoveSpamExecutor removeSpam = new();
return new WorkflowBuilder(detectSpam)
.AddEdge(detectSpam, respondToMessage, isSpam => isSpam is false) // If not spam, respond
.AddEdge(detectSpam, removeSpam, isSpam => isSpam is true) // If spam, remove
.AddEdge(detectSpam, respondToMessage, (bool isSpam) => isSpam is false) // If not spam, respond
.AddEdge(detectSpam, removeSpam, (bool isSpam) => isSpam is true) // If spam, remove
.Build<string>();
}
}
@@ -137,6 +137,6 @@ internal sealed class JudgeExecutor : ReflectingExecutor<JudgeExecutor>, IMessag
protected internal override async ValueTask OnCheckpointRestoredAsync(IWorkflowContext context, CancellationToken cancellation = default)
{
this.Tries = await context.ReadStateAsync<int>("TryCount").ConfigureAwait(false);
this.Tries = await context.ReadStateAsync<int?>("TryCount").ConfigureAwait(false) ?? 0;
}
}
@@ -15,8 +15,8 @@ internal static class Step4EntryPoint
return new WorkflowBuilder(guessNumber)
.AddEdge(guessNumber, judge)
.AddEdge(judge, guessNumber, (message) => message is NumberSignal signal && signal != NumberSignal.Matched)
.BuildWithOutput<NumberSignal, string>(judge, ComputeStreamingOutput, (NumberSignal s, string? _) => s == NumberSignal.Matched);
.AddEdge(judge, guessNumber, (NumberSignal signal) => signal != NumberSignal.Matched)
.BuildWithOutput<NumberSignal, NumberSignal, string>(judge, ComputeStreamingOutput, (NumberSignal s, string? _) => s == NumberSignal.Matched);
}
public static Workflow<NumberSignal, string> WorkflowInstance
@@ -60,10 +60,10 @@ internal static class Step4EntryPoint
Func<string, int> userGuessCallback,
string? runningState)
{
object result = request.Port.Id switch
object result = request.PortInfo.PortId switch
{
"GuessNumber" => userGuessCallback(runningState ?? "Guess the number."),
_ => throw new NotSupportedException($"Request {request.Port.Id} is not supported")
_ => throw new NotSupportedException($"Request {request.PortInfo.PortId} is not supported")
};
return request.CreateResponse(result);
@@ -11,13 +11,13 @@ namespace Microsoft.Agents.Workflows.Sample;
internal static class Step5EntryPoint
{
private static CheckpointManager CheckpointManager { get; } = new();
public static async ValueTask<string> RunAsync(TextWriter writer, Func<string, int> userGuessCallback, bool rehydrateToRestore = false)
public static async ValueTask<string> RunAsync(TextWriter writer, Func<string, int> userGuessCallback, bool rehydrateToRestore = false, CheckpointManager? checkpointManager = null)
{
checkpointManager ??= CheckpointManager.Default;
Workflow<NumberSignal, string> workflow = Step4EntryPoint.CreateWorkflowInstance(out JudgeExecutor judge);
Checkpointed<StreamingRun<string>> checkpointed =
await InProcessExecution.StreamAsync(workflow, NumberSignal.Init, CheckpointManager)
await InProcessExecution.StreamAsync(workflow, NumberSignal.Init, checkpointManager)
.ConfigureAwait(false);
List<CheckpointInfo> checkpoints = new();
@@ -34,7 +34,7 @@ internal static class Step5EntryPoint
if (rehydrateToRestore)
{
checkpointed = await InProcessExecution.ResumeStreamAsync(workflow, targetCheckpoint, CheckpointManager, CancellationToken.None)
checkpointed = await InProcessExecution.ResumeStreamAsync(workflow, targetCheckpoint, checkpointManager, CancellationToken.None)
.ConfigureAwait(false);
handle = checkpointed.Run;
}
@@ -105,10 +105,10 @@ internal static class Step5EntryPoint
Func<string, int> userGuessCallback,
string? runningState)
{
object result = request.Port.Id switch
object result = request.PortInfo.PortId switch
{
"GuessNumber" => userGuessCallback(runningState ?? "Guess the number."),
_ => throw new NotSupportedException($"Request {request.Port.Id} is not supported")
_ => throw new NotSupportedException($"Request {request.PortInfo.PortId} is not supported")
};
return request.CreateResponse(result);
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Diagnostics.CodeAnalysis;
using System.Text.Json.Serialization;
using Microsoft.Agents.Workflows.Sample;
namespace Microsoft.Agents.Workflows.UnitTests;
// Checkpointing Types
[JsonSerializable(typeof(NumberSignal))]
[ExcludeFromCodeCoverage]
internal sealed partial class SampleJsonContext : JsonSerializerContext;
@@ -3,6 +3,7 @@
using System;
using System.IO;
using System.Linq;
using System.Text.Json;
using System.Threading.Tasks;
using Microsoft.Agents.Workflows.Sample;
@@ -122,6 +123,29 @@ public class SampleSmokeTest
Assert.Equal("You guessed correctly! You Win!", guessResult);
}
[Fact]
public async Task Test_RunSample_Step5bAsync()
{
using StringWriter writer = new();
VerifyingPlaybackResponder<string, int> responder = new(
// Iteration 1
("Guess the number.", 50),
("Your guess was too high. Try again.", 23),
// Iteration 2
("Your guess was too high. Try again.", 23),
("Your guess was too low. Try again.", 42)
);
JsonSerializerOptions options = new(SampleJsonContext.Default.Options);
options.MakeReadOnly();
CheckpointManager memoryJsonManager = CheckpointManager.CreateJson(new InMemoryJsonStore(), options);
string guessResult = await Step5EntryPoint.RunAsync(writer, userGuessCallback: responder.InvokeNext, rehydrateToRestore: true, checkpointManager: memoryJsonManager);
Assert.Equal("You guessed correctly! You Win!", guessResult);
}
[Fact]
public async Task Test_RunSample_Step6Async()
{
@@ -406,7 +406,7 @@ public class StateManagerTests
// Act: Update the key from one executor and delete it from another
await manager.WriteStateAsync(scopeSelfView, Key1, "newValue");
await manager.WriteStateAsync<string>(scopeOtherView, Key1, null);
await manager.ClearStateAsync(scopeOtherView, Key1);
Func<Task> act = async () => await manager.PublishUpdatesAsync(tracer: null);
if (isSharedScope)
@@ -0,0 +1,21 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Linq.Expressions;
namespace Microsoft.Agents.Workflows.UnitTests;
internal sealed class SubstitutionVisitor(ParameterExpression parameter, Expression substitution) : ExpressionVisitor
{
private ParameterExpression Parameter => parameter;
private Expression Substitution => substitution;
protected override Expression VisitParameter(ParameterExpression node)
{
if (node.Name == this.Parameter.Name)
{
return this.Substitution;
}
return base.VisitParameter(node);
}
}
@@ -0,0 +1,11 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Diagnostics.CodeAnalysis;
using System.Text.Json.Serialization;
namespace Microsoft.Agents.Workflows.UnitTests;
// Checkpointing Types
[JsonSerializable(typeof(TestJsonSerializable))]
[ExcludeFromCodeCoverage]
internal sealed partial class TestJsonContext : JsonSerializerContext;
@@ -0,0 +1,34 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Text.Json;
using System.Text.Json.Serialization;
namespace Microsoft.Agents.Workflows.UnitTests;
[JsonSourceGenerationOptions(JsonSerializerDefaults.Web,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
NumberHandling = JsonNumberHandling.AllowReadingFromString)]
internal sealed class TestJsonSerializable
{
public int Id { get; set; }
public string Name { get; set; } = string.Empty;
public override bool Equals(object? obj)
{
if (obj == null)
{
return false;
}
if (obj is not TestJsonSerializable other)
{
return false;
}
return this.Id == other.Id && this.Name == other.Name;
}
public override int GetHashCode() => HashCode.Combine(this.Id, this.Name);
}
@@ -0,0 +1,167 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
using Microsoft.Agents.Workflows.Checkpointing;
using Microsoft.Agents.Workflows.Execution;
using Microsoft.Extensions.AI;
namespace Microsoft.Agents.Workflows.UnitTests;
internal static partial class ValidationExtensions
{
public static Expression<Func<EdgeConnection, bool>> CreateValidator(this EdgeConnection prototype)
{
return actual => actual.SourceIds.Count == prototype.SourceIds.Count &&
actual.SinkIds.Count == prototype.SinkIds.Count &&
prototype.SourceIds.SequenceEqual(actual.SourceIds) &&
prototype.SinkIds.SequenceEqual(actual.SinkIds);
}
public static Expression<Func<TypeId, bool>> CreateValidator(this TypeId prototype)
{
return actual => actual.AssemblyName == prototype.AssemblyName &&
actual.TypeName == prototype.TypeName;
}
public static Expression<Func<ExecutorInfo, bool>> CreateValidator(this ExecutorInfo prototype)
{
return actual => actual.ExecutorId == prototype.ExecutorId &&
// Rely on the TypeId test to probe TypeId serialization - just validate that we got a functional TypeId
actual.ExecutorType.Equals(prototype.ExecutorType);
}
public static Expression<Func<InputPortInfo, bool>> CreatePortInfoValidator(this InputPort prototype)
{
return actual => actual.PortId == prototype.Id &&
// Rely on the TypeId test to probe TypeId serialization - just validate that we got a functional TypeId
actual.RequestType.IsMatch(prototype.Request) &&
actual.ResponseType.IsMatch(prototype.Response);
}
public static Expression<Func<DirectEdgeInfo, bool>> CreateValidator(this DirectEdgeInfo prototype)
{
return actual => actual.Connection == prototype.Connection &&
actual.HasCondition == prototype.HasCondition;
}
public static Expression<Func<FanOutEdgeInfo, bool>> CreateValidator(this FanOutEdgeInfo prototype)
{
return actual => actual.Connection == prototype.Connection &&
actual.HasAssigner == prototype.HasAssigner;
}
public static Expression<Func<FanInEdgeInfo, bool>> CreateValidator(this FanInEdgeInfo prototype)
{
return actual => actual.Connection == prototype.Connection;
}
public static Expression<Func<EdgeInfo, bool>> CreatePolyValidator(this EdgeInfo prototype)
{
switch (prototype.Kind)
{
case EdgeKind.Direct:
{
var innerValidatorExpr = CreateValidator((DirectEdgeInfo)prototype);
// Check that incoming is of the correct type, and if so, chain to the body
Debug.Assert(innerValidatorExpr.Parameters.Count == 1, "Validator is of unexpected arity");
return CreateValidatorExpression(innerValidatorExpr);
}
case EdgeKind.FanOut:
{
var innerValidatorExpr = CreateValidator((FanOutEdgeInfo)prototype);
// Check that incoming is of the correct type, and if so, chain to the body
Debug.Assert(innerValidatorExpr.Parameters.Count == 1, "Validator is of unexpected arity");
return CreateValidatorExpression(innerValidatorExpr);
}
case EdgeKind.FanIn:
{
var innerValidatorExpr = CreateValidator((FanInEdgeInfo)prototype);
// Check that incoming is of the correct type, and if so, chain to the body
Debug.Assert(innerValidatorExpr.Parameters.Count == 1, "Validator is of unexpected arity");
return CreateValidatorExpression(innerValidatorExpr);
}
default:
throw new NotSupportedException($"Unsupported edge type: {prototype.Kind}");
}
Expression<Func<EdgeInfo, bool>> CreateValidatorExpression<TInner>(Expression<Func<TInner, bool>> innerValidator)
where TInner : EdgeInfo
{
var innerParam = innerValidator.Parameters[0];
var innerBody = innerValidator.Body;
var outerParam = Expression.Parameter(typeof(EdgeInfo), "actual");
var convertExpr = Expression.Convert(outerParam, typeof(TInner));
ExpressionVisitor visitor = new SubstitutionVisitor(innerParam, convertExpr);
Expression innerValidatorExpr = visitor.Visit(innerBody);
BinaryExpression bodyExpression = Expression.AndAlso(
Expression.AndAlso(
Expression.Equal(
Expression.Property(outerParam, nameof(EdgeInfo.Kind)),
Expression.Constant(prototype.Kind)
),
Expression.TypeIs(outerParam, typeof(TInner))
),
innerValidatorExpr
);
Expression<Func<EdgeInfo, bool>> validatorExpr = Expression.Lambda<Func<EdgeInfo, bool>>(
bodyExpression,
outerParam
);
return validatorExpr;
}
}
public static Expression<Func<ScopeId, bool>> CreateValidator(this ScopeId prototype)
{
return actual => actual.ExecutorId == prototype.ExecutorId &&
actual.ScopeName == prototype.ScopeName;
}
public static Expression<Func<ScopeKey, bool>> CreateValidator(this ScopeKey prototype)
{
return actual => actual.Key == prototype.Key &&
actual.ScopeId.ScopeName == prototype.ScopeId.ScopeName &&
actual.ScopeId.ExecutorId == prototype.ScopeId.ExecutorId;
}
public static Expression<Func<ExecutorIdentity, bool>> CreateValidator(this ExecutorIdentity prototype)
{
return actual => actual.Id == prototype.Id;
}
public static Expression<Func<ExternalRequest, bool>> CreateValidator(this ExternalRequest prototype)
{
return actual => actual.RequestId == prototype.RequestId &&
actual.PortInfo == prototype.PortInfo &&
actual.Data == prototype.Data;
}
public static Expression<Func<ExternalResponse, bool>> CreateValidator(this ExternalResponse prototype)
{
return actual => actual.RequestId == prototype.RequestId &&
actual.Data == prototype.Data;
}
public static Expression<Func<ChatMessage, bool>> CreateValidatorCheckingText(this ChatMessage prototype)
{
return actual => actual.Role == prototype.Role &&
actual.AuthorName == prototype.AuthorName &&
actual.CreatedAt == prototype.CreatedAt &&
actual.MessageId == prototype.MessageId &&
actual.Text == prototype.Text;
}
}