From 2015f0dc09b1162631bd25d9d5cc0d9e6c7a7df7 Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Wed, 17 Sep 2025 15:57:19 -0400 Subject: [PATCH] .NET: [BREAKING] Support Checkpoint Serialization (#735) * feat: Support Checkpoint Serialization * Implements serialization roundtripping for checkpoints. * Adds support for JSON serialization * Adds FileSystem-based checkpoint persistence * fix: Executor State does not deserialize correctly The StateManager was not properly handling delay-deserialized values. * Fix PortableValue handling in StateManager (this makes it delegate to PortableValue the uwnrapping) * Fix UnitTest to actually test checkpoint serialization * Additional review comment fixes --------- Co-authored-by: Chris <66376200+crickman@users.noreply.github.com> --- dotnet/.editorconfig | 3 + .../CheckpointAndRehydrate/Program.cs | 2 +- .../Checkpoint/CheckpointAndResume/Program.cs | 2 +- .../CheckpointWithHumanInTheLoop/Program.cs | 8 +- .../03_MultiSelection/Program.cs | 14 +- .../Workflows/Declarative/Program.cs | 6 +- .../HumanInTheLoopBasic/Program.cs | 6 +- .../CheckpointInfo.cs | 32 +- .../CheckpointManager.cs | 57 +- .../Checkpointing/Checkpoint.cs | 21 +- .../Checkpointing/CheckpointManagerImpl.cs | 30 + .../Checkpointing/DirectEdgeInfo.cs | 23 +- .../Checkpointing/EdgeIdConverter.cs | 30 + .../Checkpointing/EdgeInfo.cs | 34 +- .../ExecutorIdentityConverter.cs | 38 + .../Checkpointing/ExportedState.cs | 12 - .../Checkpointing/FanInEdgeInfo.cs | 18 +- .../Checkpointing/FanOutEdgeInfo.cs | 23 +- .../FileSystemJsonCheckpointStore.cs | 167 +++++ .../Checkpointing/ICheckpointManager.cs | 8 +- .../Checkpointing/ICheckpointStore.cs | 49 ++ .../Checkpointing/IDelayedDeserialization.cs | 27 + .../Checkpointing/IWireMarshaller.cs | 44 ++ .../InMemoryCheckpointManager.cs | 47 ++ .../Checkpointing/InputPortInfo.cs | 10 +- .../Checkpointing/JsonCheckpointStore.cs | 28 + .../Checkpointing/JsonConverterBase.cs | 35 + .../JsonConverterDictionarySupportBase.cs | 39 ++ .../Checkpointing/JsonMarshaller.cs | 79 +++ .../Checkpointing/JsonWireSerializedValue.cs | 63 ++ .../Checkpointing/PortableMessageEnvelope.cs | 33 + .../Checkpointing/PortableValueConverter.cs | 71 ++ .../Checkpointing/RepresentationExtensions.cs | 12 +- .../Checkpointing/RunCheckpointCache.cs | 40 ++ .../Checkpointing/ScopeKeyConverter.cs | 93 +++ .../Checkpointing/TypeId.cs | 88 ++- .../Checkpointing/WorkflowInfo.cs | 10 +- .../DirectEdgeData.cs | 21 +- dotnet/src/Microsoft.Agents.Workflows/Edge.cs | 54 +- .../Microsoft.Agents.Workflows/EdgeData.cs | 7 + .../src/Microsoft.Agents.Workflows/EdgeId.cs | 59 ++ .../Execution/EdgeConnection.cs | 33 +- .../Execution/EdgeMap.cs | 68 +- .../Execution/FanInEdgeRunner.cs | 26 +- .../Execution/FanInEdgeState.cs | 51 +- .../Execution/InputEdgeRunner.cs | 2 +- .../Execution/MessageEnvelope.cs | 14 +- .../Execution/MessageRouter.cs | 28 +- .../Execution/RunnerStateData.cs | 4 +- .../Execution/StateManager.cs | 29 +- .../Execution/StateScope.cs | 38 +- .../Execution/StepContext.cs | 11 +- .../Microsoft.Agents.Workflows/Executor.cs | 7 +- .../ExecutorEvent.cs | 5 + ...FailureEvent.cs => ExecutorFailedEvent.cs} | 2 +- .../ExternalRequest.cs | 27 +- .../ExternalResponse.cs | 27 +- .../FanInEdgeData.cs | 17 +- .../FanOutEdgeData.cs | 24 +- .../InProc/InProcessRunner.cs | 36 +- .../InProc/InProcessRunnerContext.cs | 2 +- .../InProcessExecution.cs | 8 +- .../PortableValue.cs | 176 +++++ .../Microsoft.Agents.Workflows/ScopeKey.cs | 21 +- .../Specialized/RequestInfoExecutor.cs | 8 +- .../SuperStepEvent.cs | 4 + .../SwitchBuilder.cs | 7 +- .../Microsoft.Agents.Workflows/Workflow.cs | 22 +- .../WorkflowBuilder.cs | 97 ++- .../WorkflowBuilderExtensions.cs | 32 +- .../WorkflowEvent.cs | 9 + .../WorkflowHostingExtensions.cs | 2 +- .../WorkflowsJsonUtilities.cs | 33 +- .../EdgeMapSmokeTests.cs | 2 +- .../EdgeRunnerTests.cs | 6 +- .../InMemoryJsonStore.cs | 43 ++ .../InProcessStateTests.cs | 2 +- .../JsonSerializationTests.cs | 653 ++++++++++++++++++ .../RepresentationTests.cs | 34 +- .../Sample/02_Simple_Workflow_Condition.cs | 4 +- .../Sample/03_Simple_Workflow_Loop.cs | 2 +- .../04_Simple_Workflow_ExternalRequest.cs | 8 +- .../05_Simple_Workflow_Checkpointing.cs | 14 +- .../SampleJsonContext.cs | 12 + .../SampleSmokeTest.cs | 24 + .../StateManagerTests.cs | 2 +- .../SubstitutionVisitor.cs | 21 + .../TestJsonContext.cs | 11 + .../TestJsonSerializable.cs | 34 + .../ValidationExtensions.cs | 167 +++++ 90 files changed, 3011 insertions(+), 341 deletions(-) create mode 100644 dotnet/src/Microsoft.Agents.Workflows/Checkpointing/CheckpointManagerImpl.cs create mode 100644 dotnet/src/Microsoft.Agents.Workflows/Checkpointing/EdgeIdConverter.cs create mode 100644 dotnet/src/Microsoft.Agents.Workflows/Checkpointing/ExecutorIdentityConverter.cs delete mode 100644 dotnet/src/Microsoft.Agents.Workflows/Checkpointing/ExportedState.cs create mode 100644 dotnet/src/Microsoft.Agents.Workflows/Checkpointing/FileSystemJsonCheckpointStore.cs create mode 100644 dotnet/src/Microsoft.Agents.Workflows/Checkpointing/ICheckpointStore.cs create mode 100644 dotnet/src/Microsoft.Agents.Workflows/Checkpointing/IDelayedDeserialization.cs create mode 100644 dotnet/src/Microsoft.Agents.Workflows/Checkpointing/IWireMarshaller.cs create mode 100644 dotnet/src/Microsoft.Agents.Workflows/Checkpointing/InMemoryCheckpointManager.cs create mode 100644 dotnet/src/Microsoft.Agents.Workflows/Checkpointing/JsonCheckpointStore.cs create mode 100644 dotnet/src/Microsoft.Agents.Workflows/Checkpointing/JsonConverterBase.cs create mode 100644 dotnet/src/Microsoft.Agents.Workflows/Checkpointing/JsonConverterDictionarySupportBase.cs create mode 100644 dotnet/src/Microsoft.Agents.Workflows/Checkpointing/JsonMarshaller.cs create mode 100644 dotnet/src/Microsoft.Agents.Workflows/Checkpointing/JsonWireSerializedValue.cs create mode 100644 dotnet/src/Microsoft.Agents.Workflows/Checkpointing/PortableMessageEnvelope.cs create mode 100644 dotnet/src/Microsoft.Agents.Workflows/Checkpointing/PortableValueConverter.cs create mode 100644 dotnet/src/Microsoft.Agents.Workflows/Checkpointing/RunCheckpointCache.cs create mode 100644 dotnet/src/Microsoft.Agents.Workflows/Checkpointing/ScopeKeyConverter.cs create mode 100644 dotnet/src/Microsoft.Agents.Workflows/EdgeId.cs rename dotnet/src/Microsoft.Agents.Workflows/{ExecutorFailureEvent.cs => ExecutorFailedEvent.cs} (88%) create mode 100644 dotnet/src/Microsoft.Agents.Workflows/PortableValue.cs create mode 100644 dotnet/tests/Microsoft.Agents.Workflows.UnitTests/InMemoryJsonStore.cs create mode 100644 dotnet/tests/Microsoft.Agents.Workflows.UnitTests/JsonSerializationTests.cs create mode 100644 dotnet/tests/Microsoft.Agents.Workflows.UnitTests/SampleJsonContext.cs create mode 100644 dotnet/tests/Microsoft.Agents.Workflows.UnitTests/SubstitutionVisitor.cs create mode 100644 dotnet/tests/Microsoft.Agents.Workflows.UnitTests/TestJsonContext.cs create mode 100644 dotnet/tests/Microsoft.Agents.Workflows.UnitTests/TestJsonSerializable.cs create mode 100644 dotnet/tests/Microsoft.Agents.Workflows.UnitTests/ValidationExtensions.cs diff --git a/dotnet/.editorconfig b/dotnet/.editorconfig index 76f90b6f33..59d57c4818 100644 --- a/dotnet/.editorconfig +++ b/dotnet/.editorconfig @@ -436,3 +436,6 @@ resharper_redundant_using_directive_highlighting = warning # Resharper's "Redund resharper_inconsistent_naming_highlighting = warning # Resharper's "Inconsistent naming" highlighting resharper_redundant_this_qualifier_highlighting = warning # Resharper's "Redundant 'this' qualifier" highlighting resharper_arrange_this_qualifier_highlighting = warning # Resharper's "Arrange 'this' qualifier" highlighting +csharp_style_prefer_primary_constructors = true:suggestion +csharp_prefer_system_threading_lock = true:suggestion +csharp_style_prefer_simple_property_accessors = true:suggestion diff --git a/dotnet/samples/GettingStarted/Workflows/Checkpoint/CheckpointAndRehydrate/Program.cs b/dotnet/samples/GettingStarted/Workflows/Checkpoint/CheckpointAndRehydrate/Program.cs index d039b889b8..7c20b773df 100644 --- a/dotnet/samples/GettingStarted/Workflows/Checkpoint/CheckpointAndRehydrate/Program.cs +++ b/dotnet/samples/GettingStarted/Workflows/Checkpoint/CheckpointAndRehydrate/Program.cs @@ -32,7 +32,7 @@ public static class Program var workflow = WorkflowHelper.GetWorkflow(); // Create checkpoint manager - var checkpointManager = new CheckpointManager(); + var checkpointManager = CheckpointManager.Default; var checkpoints = new List(); // Execute the workflow and save checkpoints diff --git a/dotnet/samples/GettingStarted/Workflows/Checkpoint/CheckpointAndResume/Program.cs b/dotnet/samples/GettingStarted/Workflows/Checkpoint/CheckpointAndResume/Program.cs index ebabba2141..4b5b5fe425 100644 --- a/dotnet/samples/GettingStarted/Workflows/Checkpoint/CheckpointAndResume/Program.cs +++ b/dotnet/samples/GettingStarted/Workflows/Checkpoint/CheckpointAndResume/Program.cs @@ -31,7 +31,7 @@ public static class Program var workflow = WorkflowHelper.GetWorkflow(); // Create checkpoint manager - var checkpointManager = new CheckpointManager(); + var checkpointManager = CheckpointManager.Default; var checkpoints = new List(); // Execute the workflow and save checkpoints diff --git a/dotnet/samples/GettingStarted/Workflows/Checkpoint/CheckpointWithHumanInTheLoop/Program.cs b/dotnet/samples/GettingStarted/Workflows/Checkpoint/CheckpointWithHumanInTheLoop/Program.cs index 167cc760d3..f0f193a0e0 100644 --- a/dotnet/samples/GettingStarted/Workflows/Checkpoint/CheckpointWithHumanInTheLoop/Program.cs +++ b/dotnet/samples/GettingStarted/Workflows/Checkpoint/CheckpointWithHumanInTheLoop/Program.cs @@ -34,7 +34,7 @@ public static class Program var workflow = WorkflowHelper.GetWorkflow(); // Create checkpoint manager - var checkpointManager = new CheckpointManager(); + var checkpointManager = CheckpointManager.Default; var checkpoints = new List(); // Execute the workflow and save checkpoints @@ -102,9 +102,9 @@ public static class Program private static ExternalResponse HandleExternalRequest(ExternalRequest request) { - if (request.Port.Request == typeof(SignalWithNumber)) + var signal = request.DataAs(); + if (signal is not null) { - var signal = (SignalWithNumber)request.Data; switch (signal.Signal) { case NumberSignal.Init: @@ -119,7 +119,7 @@ public static class Program } } - throw new NotSupportedException($"Request {request.Port.Request} is not supported"); + throw new NotSupportedException($"Request {request.PortInfo.RequestType} is not supported"); } private static int ReadIntegerFromConsole(string prompt) diff --git a/dotnet/samples/GettingStarted/Workflows/ConditionalEdges/03_MultiSelection/Program.cs b/dotnet/samples/GettingStarted/Workflows/ConditionalEdges/03_MultiSelection/Program.cs index 4a0ae93763..67d90a0705 100644 --- a/dotnet/samples/GettingStarted/Workflows/ConditionalEdges/03_MultiSelection/Program.cs +++ b/dotnet/samples/GettingStarted/Workflows/ConditionalEdges/03_MultiSelection/Program.cs @@ -75,10 +75,10 @@ public static class Program // After the email assistant writes a response, it will be sent to the send email executor .AddEdge(emailAssistantExecutor, sendEmailExecutor) // Save the analysis result to the database if summary is not needed - .AddEdge( + .AddEdge( emailAnalysisExecutor, databaseAccessExecutor, - condition: analysisResult => analysisResult is AnalysisResult result && result.EmailLength <= LongEmailThreshold) + condition: analysisResult => analysisResult is not null && analysisResult.EmailLength <= LongEmailThreshold) // Save the analysis result to the database with summary .AddEdge(emailSummaryExecutor, databaseAccessExecutor); var workflow = builder.Build(); @@ -107,21 +107,21 @@ public static class Program /// Creates a partitioner for routing messages based on the analysis result. /// /// A function that takes an analysis result and returns the target partitions. - private static Func> GetPartitioner() + private static Func> GetPartitioner() { return (analysisResult, targetCount) => { - if (analysisResult is AnalysisResult result) + if (analysisResult is not null) { - if (result.spamDecision == SpamDecision.Spam) + if (analysisResult.spamDecision == SpamDecision.Spam) { return [0]; // Route to spam handler } - else if (result.spamDecision == SpamDecision.NotSpam) + else if (analysisResult.spamDecision == SpamDecision.NotSpam) { List targets = [1]; // Route to the email assistant - if (result.EmailLength > LongEmailThreshold) + if (analysisResult.EmailLength > LongEmailThreshold) { targets.Add(2); // Route to the email summarizer too } diff --git a/dotnet/samples/GettingStarted/Workflows/Declarative/Program.cs b/dotnet/samples/GettingStarted/Workflows/Declarative/Program.cs index 478331efba..858ddf7157 100644 --- a/dotnet/samples/GettingStarted/Workflows/Declarative/Program.cs +++ b/dotnet/samples/GettingStarted/Workflows/Declarative/Program.cs @@ -53,7 +53,7 @@ internal sealed class Program // Run the workflow, just like any other workflow string input = this.GetWorkflowInput(); - CheckpointManager checkpointManager = new(); + CheckpointManager checkpointManager = CheckpointManager.Default; Checkpointed run = await InProcessExecution.StreamAsync(workflow, input, checkpointManager); bool isComplete = false; @@ -151,7 +151,7 @@ internal sealed class Program Debug.WriteLine($"ACTION EXIT #{actionComplete.ActionId} [{actionComplete.ActionType}]"); break; - case ExecutorFailureEvent executorFailure: + case ExecutorFailedEvent executorFailure: Debug.WriteLine($"STEP ERROR #{executorFailure.ExecutorId}: {executorFailure.Data?.Message ?? "Unknown"}"); break; @@ -256,7 +256,7 @@ internal sealed class Program } private static InputResponse HandleExternalRequest(ExternalRequest request) { - InputRequest? message = request.Data as InputRequest; + InputRequest? message = request.Data.As(); string? userInput = null; do { diff --git a/dotnet/samples/GettingStarted/Workflows/HumanInTheLoop/HumanInTheLoopBasic/Program.cs b/dotnet/samples/GettingStarted/Workflows/HumanInTheLoop/HumanInTheLoopBasic/Program.cs index 07ffe8c06d..98bc65e4ad 100644 --- a/dotnet/samples/GettingStarted/Workflows/HumanInTheLoop/HumanInTheLoopBasic/Program.cs +++ b/dotnet/samples/GettingStarted/Workflows/HumanInTheLoop/HumanInTheLoopBasic/Program.cs @@ -50,9 +50,9 @@ public static class Program private static ExternalResponse HandleExternalRequest(ExternalRequest request) { - if (request.Port.Request == typeof(NumberSignal)) + if (request.DataIs()) { - var signal = (NumberSignal)request.Data; + var signal = request.DataAs(); switch (signal) { case NumberSignal.Init: @@ -67,7 +67,7 @@ public static class Program } } - throw new NotSupportedException($"Request {request.Port.Request} is not supported"); + throw new NotSupportedException($"Request {request.PortInfo.RequestType} is not supported"); } private static int ReadIntegerFromConsole(string prompt) diff --git a/dotnet/src/Microsoft.Agents.Workflows/CheckpointInfo.cs b/dotnet/src/Microsoft.Agents.Workflows/CheckpointInfo.cs index 5a6f31b2ad..a03b59e2a3 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/CheckpointInfo.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/CheckpointInfo.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Text.Json.Serialization; +using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.Workflows; @@ -10,14 +12,29 @@ namespace Microsoft.Agents.Workflows; public class CheckpointInfo : IEquatable { /// - /// The unique identifier for the checkpoint. + /// Gets the unique identifier for the current run. /// - public string CheckpointId { get; } = Guid.NewGuid().ToString("N"); + public string RunId { get; } /// - /// The date and time when the object was created, in Coordinated Universal Time (UTC). + /// The unique identifier for the checkpoint. /// - public DateTimeOffset CreatedAt { get; } = DateTimeOffset.UtcNow; + public string CheckpointId { get; } + + /// + /// Initializes a new instance of the class with a unique identifier and the current + /// UTC timestamp. + /// + /// This constructor generates a new unique identifier using a GUID in a 32-character, lowercase, + /// hexadecimal format and sets the timestamp to the current UTC time. + internal CheckpointInfo(string runId) : this(runId, Guid.NewGuid().ToString("N")) { } + + [JsonConstructor] + internal CheckpointInfo(string runId, string checkpointId) + { + this.RunId = Throw.IfNullOrEmpty(runId); + this.CheckpointId = Throw.IfNullOrEmpty(checkpointId); + } /// public bool Equals(CheckpointInfo? other) @@ -27,8 +44,7 @@ public class CheckpointInfo : IEquatable return false; } - return this.CheckpointId == other.CheckpointId && - this.CreatedAt == other.CreatedAt; + return this.RunId == other.RunId && this.CheckpointId == other.CheckpointId; } /// @@ -40,9 +56,9 @@ public class CheckpointInfo : IEquatable /// public override int GetHashCode() { - return HashCode.Combine(this.CheckpointId, this.CreatedAt); + return HashCode.Combine(this.RunId, this.CheckpointId); } /// - public override string ToString() => $"CheckpointId: {this.CheckpointId}, CreatedAt: {this.CreatedAt:O}"; + public override string ToString() => $"CheckpointInfo(RunId: {this.RunId}, CheckpointId: {this.CheckpointId})"; } diff --git a/dotnet/src/Microsoft.Agents.Workflows/CheckpointManager.cs b/dotnet/src/Microsoft.Agents.Workflows/CheckpointManager.cs index c1caeab66a..e2c035ba34 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/CheckpointManager.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/CheckpointManager.cs @@ -1,36 +1,57 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Collections.Generic; +using System.Text.Json; using System.Threading.Tasks; using Microsoft.Agents.Workflows.Checkpointing; -using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.Workflows; /// -/// An in-memory implementation of that stores checkpoints in a dictionary. +/// A manager for storing and retrieving workflow execution checkpoints. /// public sealed class CheckpointManager : ICheckpointManager { - private readonly Dictionary _checkpoints = new(); + private readonly ICheckpointManager _impl; - ValueTask ICheckpointManager.CommitCheckpointAsync(Checkpoint checkpoint) + private static CheckpointManagerImpl CreateImpl( + IWireMarshaller marshaller, + ICheckpointStore store) { - Throw.IfNull(checkpoint); - - this._checkpoints[checkpoint] = checkpoint; - return new(checkpoint); + return new CheckpointManagerImpl(marshaller, store); } - ValueTask ICheckpointManager.LookupCheckpointAsync(CheckpointInfo checkpointInfo) + private CheckpointManager(ICheckpointManager impl) { - Throw.IfNull(checkpointInfo); - - if (!this._checkpoints.TryGetValue(checkpointInfo, out Checkpoint? checkpoint)) - { - throw new KeyNotFoundException($"Checkpoint not found: {checkpointInfo}"); - } - - return new ValueTask(checkpoint); + this._impl = impl; } + + /// + /// Creates a new instance of that uses the specified marshaller and store. + /// + /// + public static CheckpointManager CreateInMemory() => new(new InMemoryCheckpointManager()); + + /// + /// Gets the default in-memory checkpoint manager instance. + /// + public static CheckpointManager Default { get; } = CreateInMemory(); + + /// + /// Creates a new instance of the CheckpointManager that uses JSON serialization for checkpoint data. + /// + /// The checkpoint store to use for persisting and retrieving checkpoint data as JSON elements. Cannot be null. + /// Optional custom JSON serializer options to use for serialization and deserialization. Must be provided if + /// using custom types in messages or state. + /// A CheckpointManager instance configured to serialize checkpoint data as JSON. + public static CheckpointManager CreateJson(ICheckpointStore store, JsonSerializerOptions? customOptions = null) + { + JsonMarshaller marshaller = new(customOptions); + return new(CreateImpl(marshaller, store)); + } + + ValueTask ICheckpointManager.CommitCheckpointAsync(string runId, Checkpoint checkpoint) + => this._impl.CommitCheckpointAsync(runId, checkpoint); + + ValueTask ICheckpointManager.LookupCheckpointAsync(string runId, CheckpointInfo checkpointInfo) + => this._impl.LookupCheckpointAsync(runId, checkpointInfo); } diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/Checkpoint.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/Checkpoint.cs index 7390c38255..79335b64e3 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/Checkpoint.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/Checkpoint.cs @@ -1,33 +1,40 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; +using System.Text.Json.Serialization; using Microsoft.Agents.Workflows.Execution; using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.Workflows.Checkpointing; -internal class Checkpoint : CheckpointInfo +internal class Checkpoint { + [JsonConstructor] internal Checkpoint( int stepNumber, WorkflowInfo workflow, RunnerStateData runnerData, - Dictionary stateData, - Dictionary edgeStateData) + Dictionary stateData, + Dictionary edgeStateData, + CheckpointInfo? parent = null) { this.StepNumber = Throw.IfLessThan(stepNumber, -1); // -1 is a special flag indicating the initial checkpoint. this.Workflow = Throw.IfNull(workflow); this.RunnerData = Throw.IfNull(runnerData); - this.State = Throw.IfNull(stateData); - this.EdgeState = Throw.IfNull(edgeStateData); + this.StateData = Throw.IfNull(stateData); + this.EdgeStateData = Throw.IfNull(edgeStateData); + this.Parent = parent; } + [JsonIgnore] public bool IsInitial => this.StepNumber == -1; public int StepNumber { get; } public WorkflowInfo Workflow { get; } public RunnerStateData RunnerData { get; } - public readonly Dictionary State = new(); - public readonly Dictionary EdgeState = new(); + public Dictionary StateData { get; } = new(); + public Dictionary EdgeStateData { get; } = new(); + + public CheckpointInfo? Parent { get; } } diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/CheckpointManagerImpl.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/CheckpointManagerImpl.cs new file mode 100644 index 0000000000..67faf4a481 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/CheckpointManagerImpl.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Threading.Tasks; + +namespace Microsoft.Agents.Workflows.Checkpointing; + +internal sealed class CheckpointManagerImpl : ICheckpointManager +{ + private readonly IWireMarshaller _marshaller; + private readonly ICheckpointStore _store; + + public CheckpointManagerImpl(IWireMarshaller marshaller, ICheckpointStore store) + { + this._marshaller = marshaller; + this._store = store; + } + + public ValueTask CommitCheckpointAsync(string runId, Checkpoint checkpoint) + { + TStoreObject storeObject = this._marshaller.Marshal(checkpoint); + + return this._store.CreateCheckpointAsync(runId, storeObject, checkpoint.Parent); + } + + public async ValueTask LookupCheckpointAsync(string runId, CheckpointInfo checkpointInfo) + { + TStoreObject result = await this._store.RetrieveCheckpointAsync(runId, checkpointInfo).ConfigureAwait(false); + return this._marshaller.Marshal(result); + } +} diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/DirectEdgeInfo.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/DirectEdgeInfo.cs index 1bcd6c3e7f..48a725ba9b 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/DirectEdgeInfo.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/DirectEdgeInfo.cs @@ -1,12 +1,29 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Text.Json.Serialization; +using Microsoft.Agents.Workflows.Execution; + namespace Microsoft.Agents.Workflows.Checkpointing; -internal class DirectEdgeInfo(DirectEdgeData data) : EdgeInfo(Edge.Type.Direct, data.Connection) +/// +/// Represents a direct in the . +/// +public sealed class DirectEdgeInfo : EdgeInfo { - public bool HasCondition => data.Condition != null; + internal DirectEdgeInfo(DirectEdgeData data) : this(data.Condition != null, data.Connection) { } - protected override bool IsMatchInternal(EdgeData edgeData) + [JsonConstructor] + internal DirectEdgeInfo(bool hasCondition, EdgeConnection connection) : base(EdgeKind.Direct, connection) + { + this.HasCondition = hasCondition; + } + + /// + /// Gets a value indicating whether this direct edge has a condition associated with it. + /// + public bool HasCondition { get; } + + internal override bool IsMatchInternal(EdgeData edgeData) { return edgeData is DirectEdgeData directEdge && this.HasCondition == (directEdge.Condition != null); diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/EdgeIdConverter.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/EdgeIdConverter.cs new file mode 100644 index 0000000000..6ac8e80041 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/EdgeIdConverter.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; + +namespace Microsoft.Agents.Workflows.Checkpointing; + +/// +/// Provides support for using values as dictionary keys when serializing and deserializing JSON. +/// +internal sealed class EdgeIdConverter : JsonConverterDictionarySupportBase +{ + protected override JsonTypeInfo TypeInfo => WorkflowsJsonUtilities.JsonContext.Default.EdgeId; + + protected override EdgeId Parse(string propertyName) + { + if (int.TryParse(propertyName, out int edgeId)) + { + return new(edgeId); + } + + throw new JsonException($"Cannot deserialize EdgeId from JSON propery name '{propertyName}'"); + } + + protected override string Stringify([DisallowNull] EdgeId value) + { + return value.EdgeIndex.ToString(); + } +} diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/EdgeInfo.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/EdgeInfo.cs index c9aa4dd50a..ef2e05c87c 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/EdgeInfo.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/EdgeInfo.cs @@ -1,21 +1,43 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Text.Json.Serialization; using Microsoft.Agents.Workflows.Execution; using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.Workflows.Checkpointing; -internal abstract class EdgeInfo(Edge.Type edgeType, EdgeConnection connection) +/// +/// Base class representing information about an edge in a workflow. +/// +[JsonPolymorphic(UnknownDerivedTypeHandling = JsonUnknownDerivedTypeHandling.FailSerialization)] +[JsonDerivedType(typeof(DirectEdgeInfo), (int)EdgeKind.Direct)] +[JsonDerivedType(typeof(FanOutEdgeInfo), (int)EdgeKind.FanOut)] +[JsonDerivedType(typeof(FanInEdgeInfo), (int)EdgeKind.FanIn)] +public class EdgeInfo { - public Edge.Type EdgeType => edgeType; - public EdgeConnection Connection { get; } = Throw.IfNull(connection); + /// + /// The kind of edge. + /// + public EdgeKind Kind { get; } - public bool IsMatch(Edge edge) + /// + /// Gets the connection information associated with the edge. + /// + public EdgeConnection Connection { get; } + + [JsonConstructor] + internal EdgeInfo(EdgeKind kind, EdgeConnection connection) { - return this.EdgeType == edge.EdgeType + this.Kind = kind; + this.Connection = Throw.IfNull(connection); + } + + internal bool IsMatch(Edge edge) + { + return this.Kind == edge.Kind && this.Connection.Equals(edge.Data.Connection) && this.IsMatchInternal(edge.Data); } - protected virtual bool IsMatchInternal(EdgeData edgeData) => true; + internal virtual bool IsMatchInternal(EdgeData edgeData) => true; } diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/ExecutorIdentityConverter.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/ExecutorIdentityConverter.cs new file mode 100644 index 0000000000..3656307009 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/ExecutorIdentityConverter.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; +using Microsoft.Agents.Workflows.Execution; + +namespace Microsoft.Agents.Workflows.Checkpointing; + +/// +/// Provides support for using values as dictionary keys when serializing and deserializing JSON. +/// +internal sealed class ExecutorIdentityConverter() : JsonConverterDictionarySupportBase +{ + protected override JsonTypeInfo TypeInfo + => WorkflowsJsonUtilities.JsonContext.Default.ExecutorIdentity; + + protected override ExecutorIdentity Parse(string propertyName) + { + if (propertyName.Length == 0) + { + return ExecutorIdentity.None; + } + + if (propertyName[0] == '@') + { + return new() { Id = propertyName.Substring(1) }; + } + + throw new JsonException($"Invalid ExecutorIdentity key Expecting empty string or a value that is prefixed with '@'. Got '{propertyName}'"); + } + + protected override string Stringify(ExecutorIdentity value) + { + return value == ExecutorIdentity.None + ? string.Empty + : $"@{value.Id}"; + } +} diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/ExportedState.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/ExportedState.cs deleted file mode 100644 index fcdbf8bdc2..0000000000 --- a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/ExportedState.cs +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using Microsoft.Shared.Diagnostics; - -namespace Microsoft.Agents.Workflows.Checkpointing; - -internal class ExportedState(object state) -{ - public Type RuntimeType => Throw.IfNull(state).GetType(); - public object Value => Throw.IfNull(state); -} diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/FanInEdgeInfo.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/FanInEdgeInfo.cs index 8ea914bb3b..3b9f89be55 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/FanInEdgeInfo.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/FanInEdgeInfo.cs @@ -1,5 +1,21 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Text.Json.Serialization; +using Microsoft.Agents.Workflows.Execution; + namespace Microsoft.Agents.Workflows.Checkpointing; -internal class FanInEdgeInfo(FanInEdgeData data) : EdgeInfo(Edge.Type.FanIn, data.Connection); +/// +/// Represents a fan-in in the . +/// +public sealed class FanInEdgeInfo : EdgeInfo +{ + internal FanInEdgeInfo(FanInEdgeData data) : base(EdgeKind.FanIn, data.Connection) + { + } + + [JsonConstructor] + internal FanInEdgeInfo(EdgeConnection connection) : base(EdgeKind.FanIn, connection) + { + } +} diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/FanOutEdgeInfo.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/FanOutEdgeInfo.cs index 7e4aa65502..74b6bf333d 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/FanOutEdgeInfo.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/FanOutEdgeInfo.cs @@ -1,12 +1,29 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Text.Json.Serialization; +using Microsoft.Agents.Workflows.Execution; + namespace Microsoft.Agents.Workflows.Checkpointing; -internal class FanOutEdgeInfo(FanOutEdgeData data) : EdgeInfo(Edge.Type.FanOut, data.Connection) +/// +/// Represents a fan-out in the . +/// +public sealed class FanOutEdgeInfo : EdgeInfo { - public bool HasAssigner => data.EdgeAssigner != null; + internal FanOutEdgeInfo(FanOutEdgeData data) : this(data.EdgeAssigner != null, data.Connection) { } - protected override bool IsMatchInternal(EdgeData edgeData) + [JsonConstructor] + internal FanOutEdgeInfo(bool hasAssigner, EdgeConnection connection) : base(EdgeKind.FanOut, connection) + { + this.HasAssigner = hasAssigner; + } + + /// + /// Gets a value indicating whether this fan-out edge has an edge-assigner associated with it. + /// + public bool HasAssigner { get; } + + internal override bool IsMatchInternal(EdgeData edgeData) { return edgeData is FanOutEdgeData fanOutEdge && this.HasAssigner == (fanOutEdge.EdgeAssigner != null); diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/FileSystemJsonCheckpointStore.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/FileSystemJsonCheckpointStore.cs new file mode 100644 index 0000000000..46187e10bb --- /dev/null +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/FileSystemJsonCheckpointStore.cs @@ -0,0 +1,167 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Agents.Workflows.Checkpointing; + +/// +/// Provides a file system-based implementation of a JSON checkpoint store that persists checkpoint data and index +/// information to disk using JSON files. +/// +/// This class manages checkpoint storage by writing JSON files to a specified directory and maintaining +/// an index file for efficient retrieval. It is intended for scenarios where durable, process-exclusive checkpoint +/// persistence is required. Instances of this class are not thread-safe and should not be shared across multiple +/// threads without external synchronization. The class implements IDisposable; callers should ensure Dispose is called +/// to release file handles and system resources when the store is no longer needed. +public sealed class FileSystemJsonCheckpointStore : JsonCheckpointStore, IDisposable +{ + [System.Diagnostics.CodeAnalysis.SuppressMessage("Usage", "CA2213:Disposable fields should be disposed", + Justification = "It is disposed, the analyzer is just not picking it up properly")] + private FileStream? _indexFile; + + internal DirectoryInfo Directory { get; } + internal HashSet CheckpointIndex { get; } + + /// + /// Initializes a new instance of the class that uses the specified directory + /// + /// + /// + /// + public FileSystemJsonCheckpointStore(DirectoryInfo directory) + { + this.Directory = directory ?? throw new ArgumentNullException(nameof(directory)); + + if (!directory.Exists) + { + directory.Create(); + } + + try + { + this._indexFile = File.Open(Path.Combine(directory.FullName, "index.jsonl"), FileMode.OpenOrCreate, FileAccess.ReadWrite, FileShare.None); + } + catch + { + throw new InvalidOperationException($"The store at '{directory.FullName}' is already in use by another process."); + } + + try + { + // read the lines of indexfile and parse them as CheckpointInfos + this.CheckpointIndex = new HashSet(); + using StreamReader reader = new(this._indexFile, encoding: Encoding.UTF8, detectEncodingFromByteOrderMarks: false, bufferSize: -1, leaveOpen: true); + while (reader.ReadLine() is string line) + { + CheckpointInfo? info = JsonSerializer.Deserialize(line, this.KeyTypeInfo); + if (info != null) + { + this.CheckpointIndex.Add(info); + } + } + } + catch + { + throw new InvalidOperationException($"Could not load store at '{directory.FullName}'. Index corrupted."); + } + } + + /// + public void Dispose() + { + FileStream? indexFileLocal = Interlocked.Exchange(ref this._indexFile, null); + indexFileLocal?.Dispose(); + } + + [System.Diagnostics.CodeAnalysis.SuppressMessage("Maintainability", "CA1513:Use ObjectDisposedException throw helper", + Justification = "Throw helper does not exist in NetFx 4.7.2")] + private void CheckDisposed() + { + if (this._indexFile == null) + { + throw new ObjectDisposedException($"{nameof(FileSystemJsonCheckpointStore)}({this.Directory.FullName})"); + } + } + + private string GetFileNameForCheckpoint(string runId, CheckpointInfo key) + => Path.Combine(this.Directory.FullName, $"{runId}_{key.CheckpointId}.json"); + + private CheckpointInfo GetUnusedCheckpointInfo(string runId) + { + CheckpointInfo key; + do + { + key = new(runId); + } while (!this.CheckpointIndex.Add(key)); + + return key; + } + + /// + [System.Diagnostics.CodeAnalysis.SuppressMessage("Performance", "CA1835:Prefer the 'Memory'-based overloads for 'ReadAsync' and 'WriteAsync'", + Justification = "Memory-based overload is missing for 4.7.2")] + public override async ValueTask CreateCheckpointAsync(string runId, JsonElement value, CheckpointInfo? parent = null) + { + this.CheckDisposed(); + + CheckpointInfo key = this.GetUnusedCheckpointInfo(runId); + string fileName = this.GetFileNameForCheckpoint(runId, key); + try + { + using Stream checkpointStream = File.Open(fileName, FileMode.Create, FileAccess.Write, FileShare.None); + using Utf8JsonWriter jsonWriter = new(checkpointStream, new JsonWriterOptions() { Indented = false }); + value.WriteTo(jsonWriter); + + JsonSerializer.Serialize(this._indexFile!, key, this.KeyTypeInfo); + byte[] bytes = Encoding.UTF8.GetBytes(Environment.NewLine); + await this._indexFile!.WriteAsync(bytes, 0, bytes.Length, CancellationToken.None).ConfigureAwait(false); + + return key; + } + catch (Exception ex) + { + this.CheckpointIndex.Remove(key); + + try + { + // try to clean up after ourselves + File.Delete(fileName); + } + catch { } + + throw new InvalidOperationException($"Could not create checkpoint in store at '{this.Directory.FullName}'.", ex); + } + } + + /// + public override async ValueTask RetrieveCheckpointAsync(string runId, CheckpointInfo key) + { + this.CheckDisposed(); + string fileName = this.GetFileNameForCheckpoint(runId, key); + + if (!this.CheckpointIndex.Contains(key) || + !File.Exists(fileName)) + { + throw new KeyNotFoundException($"Checkpoint '{key.CheckpointId}' not found in store at '{this.Directory.FullName}'."); + } + + using FileStream checkpointFileStream = File.Open(fileName, FileMode.Open, FileAccess.Read, FileShare.Read); + using JsonDocument document = await JsonDocument.ParseAsync(checkpointFileStream).ConfigureAwait(false); + + return document.RootElement.Clone(); + } + + /// + public override ValueTask> RetrieveIndexAsync(string runId, CheckpointInfo? withParent = null) + { + this.CheckDisposed(); + + return new(this.CheckpointIndex); + } +} diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/ICheckpointManager.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/ICheckpointManager.cs index be93c391ea..29ea5076fa 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/ICheckpointManager.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/ICheckpointManager.cs @@ -13,16 +13,18 @@ internal interface ICheckpointManager /// /// Commits the specified checkpoint and returns information that can be used to retrieve it later. /// - /// The to be committed. + /// The identifier for the current run or execution context. + /// The checkpoint to commit. /// A representing the incoming checkpoint. - ValueTask CommitCheckpointAsync(Checkpoint checkpoint); + ValueTask CommitCheckpointAsync(string runId, Checkpoint checkpoint); /// /// Retrieves the checkpoint associated with the specified checkpoint information. /// + /// The identifier for the current run of execution context. /// The information used to identify the checkpoint. /// A representing the asynchronous operation. The result contains the associated with the specified . /// Thrown if the checkpoint is not found. - ValueTask LookupCheckpointAsync(CheckpointInfo checkpointInfo); + ValueTask LookupCheckpointAsync(string runId, CheckpointInfo checkpointInfo); } diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/ICheckpointStore.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/ICheckpointStore.cs new file mode 100644 index 0000000000..4964d37d15 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/ICheckpointStore.cs @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Threading.Tasks; + +namespace Microsoft.Agents.Workflows.Checkpointing; + +/// +/// Defines a contract for storing and retrieving checkpoints associated with a specific run and key. +/// +/// Implementations of this interface enable durable or in-memory storage of checkpoints, which can be +/// used to resume or audit long-running processes. The interface is generic to support different storage object types +/// depending on the application's requirements. +/// The type of object to be stored as the value for each checkpoint. +public interface ICheckpointStore +{ + /// + /// Asynchronously retrieves the collection of checkpoint information for the specified run identifier, optionally + /// filtered by a parent checkpoint. + /// + /// The unique identifier of the run for which to retrieve checkpoint information. Cannot be null or empty. + /// An optional parent checkpoint to filter the results. If specified, only checkpoints with the given parent are + /// returned; otherwise, all checkpoints for the run are included. + /// A value task representing the asynchronous operation. The result contains a collection of objects associated with the specified run. The collection is empty if no checkpoints are + /// found. + ValueTask> RetrieveIndexAsync(string runId, CheckpointInfo? withParent = null); + + /// + /// Asynchronously creates a checkpoint for the specified run and key, associating it with the provided value and + /// optional parent checkpoint. + /// + /// The unique identifier of the run for which the checkpoint is being created. Cannot be null or empty. + /// The value to associate with the checkpoint. Cannot be null. + /// The optional parent checkpoint information. If specified, the new checkpoint will be linked as a child of this + /// parent. + /// A ValueTask that represents the asynchronous operation. The result contains the + /// object representing this stored checkpoint. + ValueTask CreateCheckpointAsync(string runId, TStoreObject value, CheckpointInfo? parent = null); + + /// + /// Asynchronously retrieves a checkpoint object associated with the specified run and checkpoint key. + /// + /// The unique identifier of the run for which the checkpoint is to be retrieved. Cannot be null or empty. + /// The key identifying the specific checkpoint to retrieve. Cannot be null. + /// A ValueTask that represents the asynchronous operation. The result contains the checkpoint object associated + /// with the specified run and key. + ValueTask RetrieveCheckpointAsync(string runId, CheckpointInfo key); +} diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/IDelayedDeserialization.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/IDelayedDeserialization.cs new file mode 100644 index 0000000000..f643937d43 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/IDelayedDeserialization.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; + +namespace Microsoft.Agents.Workflows.Checkpointing; + +/// +/// Implements an abstraction across serialization mechanisms to represent a lazily-deserialized value. +/// +/// This can be used when the target-type information is not known at time of initial deserialization. +/// +internal interface IDelayedDeserialization +{ + /// + /// Attempt to deserialize the value as the provided type. + /// + /// + /// + TValue Deserialize(); + + /// + /// Attempt to deserialize the value as the provided type. + /// + /// + /// + object? Deserialize(Type targetType); +} diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/IWireMarshaller.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/IWireMarshaller.cs new file mode 100644 index 0000000000..ab15de72f4 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/IWireMarshaller.cs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; + +namespace Microsoft.Agents.Workflows.Checkpointing; + +/// +/// Defines methods for marshalling and unmarshalling objects to and from a wire format. +/// +/// +public interface IWireMarshaller +{ + /// + /// Marshals the specified value of the given type into a wire format container. + /// + /// + /// + /// + TWireContainer Marshal(object value, Type type); + + /// + /// Marshals the specified value into a wire format container. + /// + /// + /// + /// + TWireContainer Marshal(TValue value); + + /// + /// Unmarshals the specified wire format container into an object of the given type. + /// + /// + /// + /// + TValue Marshal(TWireContainer data); + + /// + /// Unmarshals the specified wire format container into an object of the specified target type. + /// + /// + /// + /// + object Marshal(Type targetType, TWireContainer data); +} diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/InMemoryCheckpointManager.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/InMemoryCheckpointManager.cs new file mode 100644 index 0000000000..2a2e0f2f54 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/InMemoryCheckpointManager.cs @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Threading.Tasks; + +namespace Microsoft.Agents.Workflows.Checkpointing; + +/// +/// An in-memory implementation of that stores checkpoints in a dictionary. +/// +internal sealed class InMemoryCheckpointManager : ICheckpointManager +{ + private readonly Dictionary> _store = new(); + + private RunCheckpointCache GetRunStore(string runId) + { + if (!this._store.TryGetValue(runId, out RunCheckpointCache? runStore)) + { + runStore = this._store[runId] = new(); + } + + return runStore; + } + + public ValueTask CommitCheckpointAsync(string runId, Checkpoint checkpoint) + { + RunCheckpointCache runStore = this.GetRunStore(runId); + + CheckpointInfo key; + do + { + key = new(runId); + } while (!runStore.Add(key, checkpoint)); + + return new(key); + } + + public ValueTask LookupCheckpointAsync(string runId, CheckpointInfo checkpointInfo) + { + if (!this.GetRunStore(runId).TryGet(checkpointInfo, out Checkpoint? value)) + { + throw new KeyNotFoundException($"Could not retrieve checkpoint with id {checkpointInfo.CheckpointId} for run {runId}"); + } + + return new(value); + } +} diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/InputPortInfo.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/InputPortInfo.cs index abe2dddc91..ee46dcae6c 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/InputPortInfo.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/InputPortInfo.cs @@ -2,4 +2,12 @@ namespace Microsoft.Agents.Workflows.Checkpointing; -internal record class InputPortInfo(TypeId InputType, TypeId OutputType, string PortId); +/// +/// Information about an input port, including its input and output types. +/// +/// +/// +/// +public record class InputPortInfo(TypeId RequestType, TypeId ResponseType, string PortId) +{ +} diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/JsonCheckpointStore.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/JsonCheckpointStore.cs new file mode 100644 index 0000000000..41b4331c4d --- /dev/null +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/JsonCheckpointStore.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; +using System.Threading.Tasks; + +namespace Microsoft.Agents.Workflows.Checkpointing; + +/// +/// An abstract base class for checkpoint stores that use JSON for serialization. +/// +public abstract class JsonCheckpointStore : ICheckpointStore +{ + /// + /// A default TypeInfo for serializing the type, if needed. + /// + protected JsonTypeInfo KeyTypeInfo => WorkflowsJsonUtilities.JsonContext.Default.CheckpointInfo; + + /// + public abstract ValueTask CreateCheckpointAsync(string runId, JsonElement value, CheckpointInfo? parent = null); + + /// + public abstract ValueTask RetrieveCheckpointAsync(string runId, CheckpointInfo key); + + /// + public abstract ValueTask> RetrieveIndexAsync(string runId, CheckpointInfo? withParent = null); +} diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/JsonConverterBase.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/JsonConverterBase.cs new file mode 100644 index 0000000000..fbcd75bd9c --- /dev/null +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/JsonConverterBase.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Text.Json.Serialization.Metadata; + +namespace Microsoft.Agents.Workflows.Checkpointing; + +/// +/// Provides support for JSON serialization and deserialization using a specified JsonTypeInfo. +/// +/// +internal abstract class JsonConverterBase : JsonConverter +{ + protected abstract JsonTypeInfo TypeInfo { get; } + + public override T? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + SequencePosition position = reader.Position; + + T? maybeValue = JsonSerializer.Deserialize(ref reader, this.TypeInfo); + if (maybeValue is null) + { + throw new JsonException($"Could not deserialize a {typeof(T).Name} from JSON at position {position}"); + } + + return maybeValue; + } + + public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options) + { + JsonSerializer.Serialize(writer, value, this.TypeInfo); + } +} diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/JsonConverterDictionarySupportBase.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/JsonConverterDictionarySupportBase.cs new file mode 100644 index 0000000000..d4ddefbeff --- /dev/null +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/JsonConverterDictionarySupportBase.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; + +namespace Microsoft.Agents.Workflows.Checkpointing; + +/// +/// Provides support for using values as dictionary keys when serializing and deserializing JSON. +/// It chains to the provided for serialization and deserialization when not used as a property +/// name. +/// +/// +internal abstract class JsonConverterDictionarySupportBase : JsonConverterBase +{ + protected abstract string Stringify([DisallowNull] T value); + protected abstract T Parse(string propertyName); + + public override T ReadAsPropertyName(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + SequencePosition position = reader.Position; + string? propertyName = reader.GetString(); + + if (propertyName == null) + { + throw new JsonException($"Got null trying to read property name at position {position}"); + } + + return this.Parse(propertyName); + } + + public override void WriteAsPropertyName(Utf8JsonWriter writer, [DisallowNull] T value, JsonSerializerOptions options) + { + string propertyName = this.Stringify(value); + writer.WritePropertyName(propertyName); + } +} diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/JsonMarshaller.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/JsonMarshaller.cs new file mode 100644 index 0000000000..14c313cc90 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/JsonMarshaller.cs @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; + +namespace Microsoft.Agents.Workflows.Checkpointing; + +internal class JsonMarshaller : IWireMarshaller +{ + private readonly JsonSerializerOptions _internalOptions; + private readonly JsonSerializerOptions? _externalOptions; + + public JsonMarshaller(JsonSerializerOptions? serializerOptions = null) + { + this._internalOptions = new JsonSerializerOptions(WorkflowsJsonUtilities.DefaultOptions); + this._internalOptions.Converters.Add(new PortableValueConverter(this)); + this._internalOptions.Converters.Add(new ExecutorIdentityConverter()); + this._internalOptions.Converters.Add(new ScopeKeyConverter()); + this._internalOptions.Converters.Add(new EdgeIdConverter()); + + this._externalOptions = serializerOptions; + } + + private JsonTypeInfo LookupTypeInfo(Type type) + { + if (!this._internalOptions.TryGetTypeInfo(type, out JsonTypeInfo? typeInfo)) + { + if (this._externalOptions == null || + !this._externalOptions.TryGetTypeInfo(type, out typeInfo)) + { + throw new InvalidOperationException($"No JSON type info is available for type '{type}'."); + } + } + + return typeInfo; + } + + public JsonElement Marshal(object value, Type type) + => JsonSerializer.SerializeToElement(value, this.LookupTypeInfo(type)); + + public JsonElement Marshal(TValue value) + => JsonSerializer.SerializeToElement(value, this.LookupTypeInfo(typeof(TValue))); + + public TValue Marshal(JsonElement data) + { + Type type = typeof(TValue); + object? value = JsonSerializer.Deserialize(data, this.LookupTypeInfo(type)); + + if (value is null) + { + throw new InvalidOperationException($"Could not deserialize the value as the expected type {typeof(TValue)}."); + } + + if (value is TValue typedValue) + { + return typedValue; + } + + throw new InvalidOperationException($"Deserialized value is not of the expected type {typeof(TValue)}."); + } + + public object Marshal(Type targetType, JsonElement data) + { + object? value = JsonSerializer.Deserialize(data, this.LookupTypeInfo(targetType)); + + if (value is null) + { + throw new InvalidOperationException($"Could not deserialize the value as the expected type {targetType}."); + } + + if (targetType.IsInstanceOfType(value)) + { + return value; + } + + throw new InvalidOperationException($"Deserialized value is not of the expected type {targetType}."); + } +} diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/JsonWireSerializedValue.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/JsonWireSerializedValue.cs new file mode 100644 index 0000000000..d07d874cc5 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/JsonWireSerializedValue.cs @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; + +namespace Microsoft.Agents.Workflows.Checkpointing; + +/// +/// Represents a value serialized to the JSON format (). +/// When type information is not available during deserialization, this will wrap a clone of the +/// to be deserialized later. +/// +/// +/// +/// +internal sealed class JsonWireSerializedValue(JsonMarshaller serializer, JsonElement data) : IDelayedDeserialization +{ + internal JsonElement Data { get; } = data.Clone(); + + public TValue Deserialize() => serializer.Marshal(data); + + public object? Deserialize(Type targetType) => serializer.Marshal(targetType, data); + + public override bool Equals(object? obj) + { + if (obj == null) + { + return false; + } + + if (obj is JsonWireSerializedValue otherValue) + { + return JsonElement.DeepEquals(this.Data, otherValue.Data); + } + else if (obj is JsonElement element) + { + return this.Data.Equals(element); + } + else if (obj is not IDelayedDeserialization) + { + // Assume this has the target type of deserialization; serialize it using the explicit type + // and compare. Of course, this also means that if this is a supertype, it could encounter + // truncation. + try + { + JsonElement otherElement = serializer.Marshal(obj, obj.GetType()); + + return JsonElement.DeepEquals(this.Data, otherElement); + } + catch + { + return false; + } + } + + return false; + } + + public override int GetHashCode() + { + return this.Data.GetHashCode(); + } +} diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/PortableMessageEnvelope.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/PortableMessageEnvelope.cs new file mode 100644 index 0000000000..c2e73de73d --- /dev/null +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/PortableMessageEnvelope.cs @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json.Serialization; +using Microsoft.Agents.Workflows.Execution; + +namespace Microsoft.Agents.Workflows.Checkpointing; + +internal sealed class PortableMessageEnvelope +{ + public TypeId MessageType { get; } + public PortableValue Message { get; } + public string? TargetId { get; } + + [JsonConstructor] + internal PortableMessageEnvelope(TypeId messageType, PortableValue message, string? targetId) + { + this.MessageType = messageType; + this.Message = message; + this.TargetId = targetId; + } + + public PortableMessageEnvelope(MessageEnvelope envelope) + { + this.MessageType = envelope.MessageType; + this.Message = new PortableValue(envelope.Message); + this.TargetId = envelope.TargetId; + } + + public MessageEnvelope ToMessageEnvelope() + { + return new MessageEnvelope(this.Message, this.MessageType, this.TargetId); + } +} diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/PortableValueConverter.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/PortableValueConverter.cs new file mode 100644 index 0000000000..c4456f5635 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/PortableValueConverter.cs @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Text.Json.Serialization.Metadata; + +namespace Microsoft.Agents.Workflows.Checkpointing; + +/// +/// Provides special handling for serialization and deserialization, enabling delayed deserialization +/// of the inner value. This is used to enable serialization/deserialization of objects whose type information is not available +/// at the time of initial deserialization, e.g. user-defined state types. +/// +/// This operates in conjuction with and to abstract +/// away the speicfics of a given serialization format in favor of and +/// . +/// +/// +internal sealed class PortableValueConverter(JsonMarshaller marshaller) : JsonConverter +{ + public override PortableValue? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + SequencePosition initial = reader.Position; + + JsonTypeInfo baseTypeInfo = WorkflowsJsonUtilities.JsonContext.Default.PortableValue; + PortableValue? maybeValue = JsonSerializer.Deserialize(ref reader, baseTypeInfo); + + if (maybeValue is null) + { + throw new JsonException($"Could not deserialize a PortableValue from JSON at position {initial}."); + } + else if (maybeValue.Value is JsonElement element) + { + // This happens when we do not have the type information available to deserialize the value directly. + // We need to wrap it in a JsonWireSerializedValue so that we can deserialize it + return new PortableValue(maybeValue.TypeId, new JsonWireSerializedValue(marshaller, element)); + } + else if (maybeValue.TypeId.IsMatch(maybeValue.Value.GetType())) + { + return maybeValue; + } + + throw new JsonException($"Deserialized PortableValue contains a value of type {maybeValue.Value.GetType()} which does not match the expected type {maybeValue.TypeId} at position {initial}."); + } + + public override void Write(Utf8JsonWriter writer, PortableValue value, JsonSerializerOptions options) + { + PortableValue proxyValue; + if (value.IsDelayedDeserialization && !value.IsDeserialized) + { + if (value.Value is JsonWireSerializedValue jsonWireValue) + { + proxyValue = new(value.TypeId, jsonWireValue.Data); + } + else + { + // Users should never see this unless they're trying to cross wire formats + throw new InvalidOperationException("Cannot serialize a PortableValue that has not been deserialized. Please deserialize it with .As/AsType() or Is/IsType() methods first."); + } + } + else + { + JsonElement element = marshaller.Marshal(value.Value, value.Value.GetType()); + proxyValue = new(value.TypeId, element); + } + + JsonTypeInfo baseTypeInfo = WorkflowsJsonUtilities.JsonContext.Default.PortableValue; + JsonSerializer.Serialize(writer, proxyValue, baseTypeInfo); + } +} diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/RepresentationExtensions.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/RepresentationExtensions.cs index 96e3822d30..79051fefc8 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/RepresentationExtensions.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/RepresentationExtensions.cs @@ -18,12 +18,12 @@ internal static class RepresentationExtensions public static EdgeInfo ToEdgeInfo(this Edge edge) { Throw.IfNull(edge); - return edge.EdgeType switch + return edge.Kind switch { - Edge.Type.Direct => new DirectEdgeInfo(edge.DirectEdgeData!), - Edge.Type.FanOut => new FanOutEdgeInfo(edge.FanOutEdgeData!), - Edge.Type.FanIn => new FanInEdgeInfo(edge.FanInEdgeData!), - _ => throw new NotSupportedException($"Unsupported edge type: {edge.EdgeType}") + EdgeKind.Direct => new DirectEdgeInfo(edge.DirectEdgeData!), + EdgeKind.FanOut => new FanOutEdgeInfo(edge.FanOutEdgeData!), + EdgeKind.FanIn => new FanInEdgeInfo(edge.FanInEdgeData!), + _ => throw new NotSupportedException($"Unsupported edge type: {edge.Kind}") }; } @@ -54,6 +54,6 @@ internal static class RepresentationExtensions public static WorkflowInfo ToWorkflowInfo(this Workflow workflow) => workflow.ToWorkflowInfo(outputType: null, outputExecutorId: null); - public static WorkflowInfo GetInfo(this Workflow workflow) + public static WorkflowInfo ToWorkflowInfo(this Workflow workflow) => workflow.ToWorkflowInfo(outputType: new TypeId(typeof(TResult)), outputExecutorId: workflow.OutputCollectorId); } diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/RunCheckpointCache.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/RunCheckpointCache.cs new file mode 100644 index 0000000000..2771d97be7 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/RunCheckpointCache.cs @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.Agents.Workflows.Checkpointing; + +internal sealed class RunCheckpointCache +{ + private readonly HashSet _checkpointIndex = new(); + private readonly Dictionary _cache = new(); + + public IEnumerable Index => this._checkpointIndex; + + public bool IsInIndex(CheckpointInfo key) => this._checkpointIndex.Contains(key); + public bool TryGet(CheckpointInfo key, [MaybeNullWhen(false)] out TStoreObject value) => this._cache.TryGetValue(key, out value); + + public CheckpointInfo Add(string runId, TStoreObject value) + { + CheckpointInfo key; + + do + { + key = new(runId); + } while (!this.Add(key, value)); + + return key; + } + + public bool Add(CheckpointInfo key, TStoreObject value) + { + bool added = this._checkpointIndex.Add(key); + if (added) + { + this._cache[key] = value; + } + + return added; + } +} diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/ScopeKeyConverter.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/ScopeKeyConverter.cs new file mode 100644 index 0000000000..bfbe6742bc --- /dev/null +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/ScopeKeyConverter.cs @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; +using System.Text.RegularExpressions; + +namespace Microsoft.Agents.Workflows.Checkpointing; + +/// +/// Provides support for using values as dictionary keys when serializing and deserializing JSON. +/// +internal sealed class ScopeKeyConverter : JsonConverterDictionarySupportBase +{ + protected override JsonTypeInfo TypeInfo => WorkflowsJsonUtilities.JsonContext.Default.ScopeKey; + + public static readonly Regex ScopeKeyPropertyNamePattern = + new(@"^(?(((\|\|)|([^\|]))*))\|(?(@(((\|\|)|([^\|]))*))?)\|(?(((\|\|)|([^\|]))*)?)$", + RegexOptions.Compiled | RegexOptions.CultureInvariant | RegexOptions.ExplicitCapture); + + protected override ScopeKey Parse(string propertyName) + { + Match scopeKeyPatternMatch = ScopeKeyPropertyNamePattern.Match(propertyName); + if (!scopeKeyPatternMatch.Success) + { + throw new JsonException($"Invalid ScopeKey property name format. Got '{propertyName}'."); + } + + string executorId = scopeKeyPatternMatch.Groups["executorId"].Value; + string scopeName = scopeKeyPatternMatch.Groups["scopeName"].Value; + string key = scopeKeyPatternMatch.Groups["key"].Value; + + return new ScopeKey(Unescape(executorId)!, + Unescape(scopeName, allowNullAndPad: true), + Unescape(key)!); + } + + [return: NotNull] + private static string Escape(string? value, bool allowNullAndPad = false, [CallerArgumentExpression("value")] string componentName = "ScopeKey") + { + if (!allowNullAndPad && value == null) + { + throw new JsonException($"Invalid {componentName} '{value}'. Expecting non-null string."); + } + + if (value == null) + { + return string.Empty; + } + + if (allowNullAndPad) + { + return $"@{value.Replace("|", "||")}"; + } + + return $"{value.Replace("|", "||")}"; + } + + private static string? Unescape([DisallowNull] string value, bool allowNullAndPad = false, [CallerArgumentExpression("value")] string componentName = "ScopeKey") + { + if (value.Length == 0) + { + if (!allowNullAndPad) + { + throw new JsonException($"Invalid {componentName} '{value}'. Expecting empty string or a value that is prefixed with '@'."); + } + + return null; + } + + if (allowNullAndPad && value[0] != '@') + { + throw new JsonException($"Invalid {componentName} component '{value}'. Expecting empty string or a value that is prefixed with '@'."); + } + + if (allowNullAndPad) + { + value = value.Substring(1); + } + + return value.Replace("||", "|"); + } + + protected override string Stringify([DisallowNull] ScopeKey value) + { + string? executorIdEscaped = Escape(value.ScopeId.ExecutorId); + string? scopeNameEscaped = Escape(value.ScopeId.ScopeName, allowNullAndPad: true); + string? keyEscaped = Escape(value.Key); + + return $"{executorIdEscaped}|{scopeNameEscaped}|{keyEscaped}"; + } +} diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/TypeId.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/TypeId.cs index 077484442f..b38b135529 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/TypeId.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/TypeId.cs @@ -1,20 +1,102 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Text.Json.Serialization; using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.Workflows.Checkpointing; -internal class TypeId(Type type) +/// +/// A representation of a type's identity, including its assembly and type names. +/// +public class TypeId { - public string AssemblyName => Throw.IfNull(type.Assembly.FullName); - public string TypeName => Throw.IfNull(type.FullName); + /// + public string AssemblyName { get; } + /// + public string TypeName { get; } + + /// + /// Initializes a new instance of the class. + /// + /// + /// + [JsonConstructor] + public TypeId(string assemblyName, string typeName) + { + this.AssemblyName = Throw.IfNull(assemblyName); + this.TypeName = Throw.IfNull(typeName); + } + + /// + /// Initializes a new instance of the TypeId class using the specified type. + /// + /// The type for which to create a unique identifier. Cannot be null. + public TypeId(Type type) + : this( + Throw.IfNullOrMemberNull(type.Assembly, + type.Assembly.FullName), + Throw.IfMemberNull(type, + type.FullName)) + { } + + /// + public override bool Equals(object? obj) + => obj is TypeId other + && this.AssemblyName == other.AssemblyName + && this.TypeName == other.TypeName; + + /// + public override int GetHashCode() => HashCode.Combine(this.AssemblyName, this.TypeName); + + /// + public static bool operator ==(TypeId? left, TypeId? right) => object.ReferenceEquals(left, right) || (!object.ReferenceEquals(left, null) && left.Equals(right)); + + /// + public static bool operator !=(TypeId? left, TypeId? right) => !(left == right); + + /// + /// Determines whether the specified type matches both the assembly name and type name represented by this instance. + /// + /// The type to compare against the stored assembly and type names. Cannot be null. + /// true if the specified type's assembly and type names are equal to those stored in this instance; otherwise, + /// false. public bool IsMatch(Type type) { return this.AssemblyName == type.Assembly.FullName && this.TypeName == type.FullName; } + /// + /// Determines whether the current instance matches the specified type parameter. + /// + /// The type to compare against the current instance. + /// true if the current instance matches the specified type; otherwise, false. public bool IsMatch() => this.IsMatch(typeof(T)); + + /// + /// Determines whether the specified type or any of its base types match the criteria defined by this instance. + /// + /// The type to evaluate for a match, including its inheritance hierarchy. + /// true if the specified type or any of its base types satisfy the match criteria; otherwise, false. + public bool IsMatchPolymorphic(Type type) + { + Type? candidateType = type; + + while (candidateType != null) + { + if (this.IsMatch(candidateType)) + { + return true; + } + + candidateType = candidateType.BaseType; + } + + return false; + } + + /// + public override string ToString() => $"{this.TypeName}, {this.AssemblyName}"; } diff --git a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/WorkflowInfo.cs b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/WorkflowInfo.cs index e879b3f71a..0f5fdd723c 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/WorkflowInfo.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Checkpointing/WorkflowInfo.cs @@ -3,20 +3,22 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Text.Json.Serialization; using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.Workflows.Checkpointing; internal class WorkflowInfo { + [JsonConstructor] internal WorkflowInfo( Dictionary executors, Dictionary> edges, HashSet inputPorts, TypeId inputType, string startExecutorId, - TypeId? outputType = null, - string? outputCollectorId = null) + TypeId? outputType, + string? outputCollectorId) { this.Executors = Throw.IfNull(executors); this.Edges = Throw.IfNull(edges); @@ -93,8 +95,8 @@ internal class WorkflowInfo if (workflow.Ports.Count != this.InputPorts.Count || this.InputPorts.Any(portInfo => !workflow.Ports.TryGetValue(portInfo.PortId, out InputPort? port) || - !portInfo.InputType.IsMatch(port.Request) || - !portInfo.OutputType.IsMatch(port.Response))) + !portInfo.RequestType.IsMatch(port.Request) || + !portInfo.ResponseType.IsMatch(port.Response))) { return false; } diff --git a/dotnet/src/Microsoft.Agents.Workflows/DirectEdgeData.cs b/dotnet/src/Microsoft.Agents.Workflows/DirectEdgeData.cs index 165a43390e..74a73a3a4c 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/DirectEdgeData.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/DirectEdgeData.cs @@ -9,27 +9,32 @@ namespace Microsoft.Agents.Workflows; /// Represents a directed edge between two nodes, optionally associated with a condition that determines whether the /// edge is active. /// -/// The id of the source executor node. -/// The id of the target executor node. -/// A predicate determining whether the edge is active for a given message. -public sealed class DirectEdgeData(string sourceId, string sinkId, PredicateT? condition = null) : EdgeData +public sealed class DirectEdgeData : EdgeData { + internal DirectEdgeData(string sourceId, string sinkId, EdgeId id, PredicateT? condition = null) : base(id) + { + this.SourceId = sourceId; + this.SinkId = sinkId; + this.Condition = condition; + this.Connection = new([sourceId], [sinkId]); + } + /// /// The Id of the source node. /// - public string SourceId => sourceId; + public string SourceId { get; } /// /// The Id of the destination node. /// - public string SinkId => sinkId; + public string SinkId { get; } /// /// An optional predicate determining whether the edge is active for a given message. If , /// the edge is always active when a message is generated by the source. /// - public PredicateT? Condition => condition; + public PredicateT? Condition { get; } /// - internal override EdgeConnection Connection { get; } = new([sourceId], [sinkId]); + internal override EdgeConnection Connection { get; } } diff --git a/dotnet/src/Microsoft.Agents.Workflows/Edge.cs b/dotnet/src/Microsoft.Agents.Workflows/Edge.cs index 46452dc741..8feddc02de 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Edge.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Edge.cs @@ -4,43 +4,43 @@ using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.Workflows; +/// +/// Specified the edge type. +/// +public enum EdgeKind +{ + /// + /// A direct connection from one node to another. + /// + Direct, + /// + /// A connection from one node to a set of nodes. + /// + FanOut, + /// + /// A connection from a set of nodes to a single node. + /// + FanIn +} + /// /// Represents a connection or relationship between nodes, characterized by its type and associated data. /// /// -/// An can be of type , , or , as specified by the property. The property holds +/// An can be of type , , or , as specified by the property. The property holds /// additional information relevant to the edge, and its concrete type depends on the value of , functioning as a tagged union. +/// cref="Kind"/>, functioning as a tagged union. /// public sealed class Edge { - /// - /// Specified the edge type. - /// - public enum Type - { - /// - /// A direct connection from one node to another. - /// - Direct, - /// - /// A connection from one node to a set of nodes. - /// - FanOut, - /// - /// A connection from a set of nodes to a single node. - /// - FanIn - } - /// /// Specifies the type of the edge, which determines how the edge is processed in the workflow. /// - public Type EdgeType { get; init; } + public EdgeKind Kind { get; init; } /// - /// The -dependent edge data. + /// The -dependent edge data. /// /// /// @@ -51,21 +51,21 @@ public sealed class Edge { this.Data = Throw.IfNull(data); - this.EdgeType = Type.Direct; + this.Kind = EdgeKind.Direct; } internal Edge(FanOutEdgeData data) { this.Data = Throw.IfNull(data); - this.EdgeType = Type.FanOut; + this.Kind = EdgeKind.FanOut; } internal Edge(FanInEdgeData data) { this.Data = Throw.IfNull(data); - this.EdgeType = Type.FanIn; + this.Kind = EdgeKind.FanIn; } internal DirectEdgeData? DirectEdgeData => this.Data as DirectEdgeData; diff --git a/dotnet/src/Microsoft.Agents.Workflows/EdgeData.cs b/dotnet/src/Microsoft.Agents.Workflows/EdgeData.cs index 0ff69077cf..4cfbe5a5ce 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/EdgeData.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/EdgeData.cs @@ -13,4 +13,11 @@ public abstract class EdgeData /// Gets the connection representation of the edge. /// internal abstract EdgeConnection Connection { get; } + + internal EdgeData(EdgeId id) + { + this.Id = id; + } + + internal EdgeId Id { get; } } diff --git a/dotnet/src/Microsoft.Agents.Workflows/EdgeId.cs b/dotnet/src/Microsoft.Agents.Workflows/EdgeId.cs new file mode 100644 index 0000000000..691f28b479 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.Workflows/EdgeId.cs @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json.Serialization; + +namespace Microsoft.Agents.Workflows; + +/// +/// A unique identifier of an within a . +/// +public readonly struct EdgeId : IEquatable +{ + [JsonConstructor] + internal EdgeId(int edgeIndex) + { + this.EdgeIndex = edgeIndex; + } + + internal int EdgeIndex { get; } + + /// + public override bool Equals(object? obj) + { + if (obj == null) + { + return false; + } + + if (obj is EdgeId edgeId) + { + return this.EdgeIndex == edgeId.EdgeIndex; + } + + if (obj is int edgeIndex) + { + return this.EdgeIndex == edgeIndex; + } + + return false; + } + + /// + public bool Equals(EdgeId other) + { + return this.EdgeIndex == other.EdgeIndex; + } + + /// + public override int GetHashCode() + { + return this.EdgeIndex.GetHashCode(); + } + + /// + public static bool operator ==(EdgeId left, EdgeId right) => left.Equals(right); + + /// + public static bool operator !=(EdgeId left, EdgeId right) => !left.Equals(right); +} diff --git a/dotnet/src/Microsoft.Agents.Workflows/Execution/EdgeConnection.cs b/dotnet/src/Microsoft.Agents.Workflows/Execution/EdgeConnection.cs index bbd4ccd4cf..819c25038c 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Execution/EdgeConnection.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Execution/EdgeConnection.cs @@ -15,10 +15,19 @@ namespace Microsoft.Agents.Workflows.Execution; /// Ordering is relevant because in at least one case, the order of sinks is significant for the execution of /// the edge: . /// -/// An ordered list of unique identifiers of the sources connected by this edge. -/// An ordered list of unique identifiers of the sinks connected by this edge. -public class EdgeConnection(List sourceIds, List sinkIds) : IEquatable +public class EdgeConnection : IEquatable { + /// + /// Create an instance with the specified source and sink IDs. + /// + /// An ordered list of unique identifiers of the sources connected by this edge. + /// An ordered list of unique identifiers of the sinks connected by this edge. + public EdgeConnection(List sourceIds, List sinkIds) + { + this.SourceIds = Throw.IfNull(sourceIds); + this.SinkIds = Throw.IfNull(sinkIds); + } + /// /// Creates a new instance with the specified source and sink IDs, ensuring that all /// IDs are unique. @@ -82,13 +91,27 @@ public class EdgeConnection(List sourceIds, List sinkIds) : IEqu ); } + /// + public static bool operator ==(EdgeConnection? left, EdgeConnection? right) + { + if (left is null) + { + return right is null; + } + + return left.Equals(right); + } + + /// + public static bool operator !=(EdgeConnection? left, EdgeConnection? right) => !(left == right); + /// /// The unique identifiers of the sources connected by this edge. /// - public List SourceIds { get; } = sourceIds; + public List SourceIds { get; } /// /// The unique identifiers of the sinks connected by this edge. /// - public List SinkIds { get; } = sinkIds; + public List SinkIds { get; } } diff --git a/dotnet/src/Microsoft.Agents.Workflows/Execution/EdgeMap.cs b/dotnet/src/Microsoft.Agents.Workflows/Execution/EdgeMap.cs index 27bfe7e9ba..695b39748a 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Execution/EdgeMap.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Execution/EdgeMap.cs @@ -10,8 +10,8 @@ namespace Microsoft.Agents.Workflows.Execution; internal class EdgeMap { - private readonly Dictionary _edgeRunners = new(); - private readonly Dictionary _fanInState = new(); + private readonly Dictionary _edgeRunners = new(); + private readonly Dictionary _fanInState = new(); private readonly Dictionary _portEdgeRunners; private readonly InputEdgeRunner _inputRunner; private readonly IStepTracer? _stepTracer; @@ -24,20 +24,20 @@ internal class EdgeMap { foreach (Edge edge in workflowEdges.Values.SelectMany(e => e)) { - object edgeRunner = edge.EdgeType switch + object edgeRunner = edge.Kind switch { - Edge.Type.Direct => new DirectEdgeRunner(runContext, edge.DirectEdgeData!), - Edge.Type.FanOut => new FanOutEdgeRunner(runContext, edge.FanOutEdgeData!), - Edge.Type.FanIn => new FanInEdgeRunner(runContext, edge.FanInEdgeData!), - _ => throw new NotSupportedException($"Unsupported edge type: {edge.EdgeType}") + EdgeKind.Direct => new DirectEdgeRunner(runContext, edge.DirectEdgeData!), + EdgeKind.FanOut => new FanOutEdgeRunner(runContext, edge.FanOutEdgeData!), + EdgeKind.FanIn => new FanInEdgeRunner(runContext, edge.FanInEdgeData!), + _ => throw new NotSupportedException($"Unsupported edge type: {edge.Kind}") }; if (edgeRunner is FanInEdgeRunner fanInRunner) { - this._fanInState[edge.Data.Connection] = fanInRunner.CreateState(); + this._fanInState[edge.Data.Id] = fanInRunner.CreateState(); } - this._edgeRunners[edge.Data.Connection] = edgeRunner; + this._edgeRunners[edge.Data.Id] = edgeRunner; } this._portEdgeRunners = workflowPorts.ToDictionary( @@ -51,14 +51,14 @@ internal class EdgeMap public async ValueTask> InvokeEdgeAsync(Edge edge, string sourceId, MessageEnvelope message) { - EdgeConnection connection = edge.Data.Connection; - if (!this._edgeRunners.TryGetValue(connection, out object? edgeRunner)) + EdgeId id = edge.Data.Id; + if (!this._edgeRunners.TryGetValue(id, out object? edgeRunner)) { throw new InvalidOperationException($"Edge {edge} not found in the edge map."); } IEnumerable edgeResults; - switch (edge.EdgeType) + switch (edge.Kind) { // We know the corresponding EdgeRunner type given the FlowEdge EdgeType, as // established in the EdgeMap() ctor; this avoid doing an as-cast inside of @@ -66,24 +66,24 @@ internal class EdgeMap // in FanIn/Out cases) // TODO: Once we have a fixed interface, if it is reasonably generalizable // between the Runners, we can normalize it behind an IFace. - case Edge.Type.Direct: + case EdgeKind.Direct: { - DirectEdgeRunner runner = (DirectEdgeRunner)this._edgeRunners[connection]; + DirectEdgeRunner runner = (DirectEdgeRunner)this._edgeRunners[id]; edgeResults = await runner.ChaseAsync(message, this._stepTracer).ConfigureAwait(false); break; } - case Edge.Type.FanOut: + case EdgeKind.FanOut: { - FanOutEdgeRunner runner = (FanOutEdgeRunner)this._edgeRunners[connection]; + FanOutEdgeRunner runner = (FanOutEdgeRunner)this._edgeRunners[id]; edgeResults = await runner.ChaseAsync(message, this._stepTracer).ConfigureAwait(false); break; } - case Edge.Type.FanIn: + case EdgeKind.FanIn: { - FanInEdgeState state = this._fanInState[connection]; - FanInEdgeRunner runner = (FanInEdgeRunner)this._edgeRunners[connection]; + FanInEdgeState state = this._fanInState[id]; + FanInEdgeRunner runner = (FanInEdgeRunner)this._edgeRunners[id]; edgeResults = [await runner.ChaseAsync(sourceId, message, state, this._stepTracer).ConfigureAwait(false)]; break; } @@ -104,43 +104,45 @@ internal class EdgeMap public async ValueTask> InvokeResponseAsync(ExternalResponse response) { - if (!this._portEdgeRunners.TryGetValue(response.Port.Id, out InputEdgeRunner? portRunner)) + if (!this._portEdgeRunners.TryGetValue(response.PortInfo.PortId, out InputEdgeRunner? portRunner)) { - throw new InvalidOperationException($"Port {response.Port.Id} not found in the edge map."); + throw new InvalidOperationException($"Port {response.PortInfo.PortId} not found in the edge map."); } return [await portRunner.ChaseAsync(new MessageEnvelope(response), this._stepTracer).ConfigureAwait(false)]; } - internal ValueTask> ExportStateAsync() + internal ValueTask> ExportStateAsync() { - Dictionary exportedStates = new(); + Dictionary exportedStates = new(); // Right now there is only fan-in state - foreach (EdgeConnection connection in this._fanInState.Keys) + foreach (EdgeId id in this._fanInState.Keys) { - FanInEdgeState state = this._fanInState[connection]; - exportedStates[connection] = new ExportedState(state); + FanInEdgeState state = this._fanInState[id]; + exportedStates[id] = new PortableValue(state); } - return new ValueTask>(exportedStates); + return new(exportedStates); } internal ValueTask ImportStateAsync(Checkpoint checkpoint) { - Dictionary importedState = checkpoint.EdgeState; + Dictionary importedState = checkpoint.EdgeStateData; this._fanInState.Clear(); - foreach (EdgeConnection connection in importedState.Keys) + foreach (EdgeId id in importedState.Keys) { - ExportedState exportedState = importedState[connection]; - if (exportedState.Value is FanInEdgeState fanInState) + PortableValue exportedState = importedState[id]; + + FanInEdgeState? fanInState = exportedState.As(); + if (fanInState is not null) { - this._fanInState[connection] = fanInState; + this._fanInState[id] = fanInState; } else { - throw new InvalidOperationException($"Unsupported exported state type: {exportedState.GetType()} for connection {connection}"); + throw new InvalidOperationException($"Unsupported exported state type: {exportedState.GetType()}; {id}"); } } diff --git a/dotnet/src/Microsoft.Agents.Workflows/Execution/FanInEdgeRunner.cs b/dotnet/src/Microsoft.Agents.Workflows/Execution/FanInEdgeRunner.cs index f61d0cfe2c..24dd2788bc 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Execution/FanInEdgeRunner.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Execution/FanInEdgeRunner.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; +using System.Diagnostics; using System.Threading.Tasks; namespace Microsoft.Agents.Workflows.Execution; @@ -12,33 +13,44 @@ internal class FanInEdgeRunner(IRunnerContext runContext, FanInEdgeData edgeData public FanInEdgeState CreateState() => new(this.EdgeData); - public async ValueTask> ChaseAsync(string sourceId, MessageEnvelope envelope, FanInEdgeState state, IStepTracer? tracer) + public ValueTask> ChaseAsync(string sourceId, MessageEnvelope envelope, FanInEdgeState state, IStepTracer? tracer) { if (envelope.TargetId != null && this.EdgeData.SinkId != envelope.TargetId) { // This message is not for us. - return []; + return new([]); } object message = envelope.Message; - IEnumerable? releasedMessages = state.ProcessMessage(sourceId, message); + IEnumerable? releasedMessages = state.ProcessMessage(sourceId, envelope); if (releasedMessages is null) { // Not ready to process yet. - return []; + return new([]); } + return this.ForwardReleasedMessagesAsync(releasedMessages, tracer); + } + + private async ValueTask> ForwardReleasedMessagesAsync(IEnumerable releasedMessages, IStepTracer? tracer) + { Executor target = await this.RunContext.EnsureExecutorAsync(this.EdgeData.SinkId, tracer) .ConfigureAwait(false); List> messageTasks = []; - foreach (var messageTask in releasedMessages) + foreach (MessageEnvelope releasedEnvelope in releasedMessages) { - if (target.CanHandle(messageTask.GetType())) + object message = releasedEnvelope.Message; + Debug.Assert(message is PortableValue, "It should not be possible to get messages released without roundtripping them through" + + "PortableValue via PortableMessageEnvelope."); + + PortableValue portable = message as PortableValue ?? new PortableValue(releasedEnvelope.MessageType, message); + + if (target.CanHandle(portable.TypeId)) { tracer?.TraceActivated(target.Id); - messageTasks.Add(target.ExecuteAsync(messageTask, envelope.MessageType, this.BoundContext).AsTask()); + messageTasks.Add(target.ExecuteAsync(portable, releasedEnvelope.MessageType, this.BoundContext).AsTask()); } } diff --git a/dotnet/src/Microsoft.Agents.Workflows/Execution/FanInEdgeState.cs b/dotnet/src/Microsoft.Agents.Workflows/Execution/FanInEdgeState.cs index 2f31e7ab2e..cb6c5576c7 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Execution/FanInEdgeState.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Execution/FanInEdgeState.cs @@ -1,28 +1,53 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; +using System.Linq; +using System.Text.Json.Serialization; +using System.Threading; +using Microsoft.Agents.Workflows.Checkpointing; namespace Microsoft.Agents.Workflows.Execution; -internal record FanInEdgeState(FanInEdgeData EdgeData) +internal class FanInEdgeState { - private List? _pendingMessages = []; - - private HashSet? _unseen = new(EdgeData.SourceIds); - - public IEnumerable? ProcessMessage(string sourceId, object message) + private List _pendingMessages; + public FanInEdgeState(FanInEdgeData fanInEdge) { - this._pendingMessages!.Add(message); - this._unseen!.Remove(sourceId); + this.SourceIds = fanInEdge.SourceIds.ToArray(); + this.Unseen = new(this.SourceIds); - if (this._unseen.Count == 0) + this._pendingMessages = []; + } + + public string[] SourceIds { get; } + public HashSet Unseen { get; private set; } + public List PendingMessages => this._pendingMessages; + + [JsonConstructor] + public FanInEdgeState(string[] sourceIds, HashSet unseen, List pendingMessages) + { + this.SourceIds = sourceIds; + this.Unseen = unseen; + + this._pendingMessages = pendingMessages; + } + + public IEnumerable? ProcessMessage(string sourceId, MessageEnvelope envelope) + { + this.PendingMessages.Add(new(envelope)); + this.Unseen.Remove(sourceId); + + if (this.Unseen.Count == 0) { - List result = this._pendingMessages; + List takenMessages = Interlocked.Exchange(ref this._pendingMessages, []); + this.Unseen = new(this.SourceIds); - this._pendingMessages = []; - this._unseen = new(this.EdgeData.SourceIds); + if (takenMessages.Count == 0) + { + return null; + } - return result; + return takenMessages.Select(portable => portable.ToMessageEnvelope()); } return null; diff --git a/dotnet/src/Microsoft.Agents.Workflows/Execution/InputEdgeRunner.cs b/dotnet/src/Microsoft.Agents.Workflows/Execution/InputEdgeRunner.cs index 903c2864de..7de23a000a 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Execution/InputEdgeRunner.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Execution/InputEdgeRunner.cs @@ -35,7 +35,7 @@ internal class InputEdgeRunner(IRunnerContext runContext, string sinkId) } // TODO: Throw instead? / Log - Debug.WriteLine($"Executor {target.Id} cannot handle message of type {envelope.MessageType.FullName}. Dropping."); + Debug.WriteLine($"Executor {target.Id} cannot handle message of type {envelope.MessageType.TypeName}. Dropping."); return null; } diff --git a/dotnet/src/Microsoft.Agents.Workflows/Execution/MessageEnvelope.cs b/dotnet/src/Microsoft.Agents.Workflows/Execution/MessageEnvelope.cs index 829a91625a..aac534d799 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Execution/MessageEnvelope.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Execution/MessageEnvelope.cs @@ -1,12 +1,22 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using Microsoft.Agents.Workflows.Checkpointing; namespace Microsoft.Agents.Workflows.Execution; -internal sealed class MessageEnvelope(object message, Type? declaredType = null, string? targetId = null) +internal sealed class MessageEnvelope(object message, TypeId? declaredType = null, string? targetId = null) { - public Type MessageType => declaredType ?? message.GetType(); + public TypeId MessageType => declaredType ?? new(message.GetType()); public object Message => message; public string? TargetId => targetId; + + internal MessageEnvelope(object message, Type declaredType, string? targetId = null) + : this(message, new TypeId(declaredType), targetId) + { + if (!declaredType.IsAssignableFrom(message.GetType())) + { + throw new ArgumentException($"The declared type {declaredType} is not compatible with the message instance of type {message.GetType()}"); + } + } } diff --git a/dotnet/src/Microsoft.Agents.Workflows/Execution/MessageRouter.cs b/dotnet/src/Microsoft.Agents.Workflows/Execution/MessageRouter.cs index 2731f02b52..ae54c397e9 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Execution/MessageRouter.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Execution/MessageRouter.cs @@ -2,7 +2,9 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Threading.Tasks; +using Microsoft.Agents.Workflows.Checkpointing; using Microsoft.Shared.Diagnostics; using MessageHandlerF = @@ -17,24 +19,29 @@ namespace Microsoft.Agents.Workflows.Execution; internal class MessageRouter { private readonly Dictionary _typedHandlers; + private readonly Dictionary _runtimeTypeMap; private readonly bool _hasCatchall; internal MessageRouter(Dictionary handlers) { - this._typedHandlers = Throw.IfNull(handlers); - this._hasCatchall = this._typedHandlers.ContainsKey(typeof(object)); + Throw.IfNull(handlers); - this.IncomingTypes = [.. this._typedHandlers.Keys]; + this._typedHandlers = handlers; + this._runtimeTypeMap = handlers.Keys.ToDictionary(t => new TypeId(t), t => t); + + this._hasCatchall = handlers.ContainsKey(typeof(object)); + + this.IncomingTypes = [.. handlers.Keys]; } public HashSet IncomingTypes { get; } - public bool CanHandle(object message) => this.CanHandle(Throw.IfNull(message).GetType()); + public bool CanHandle(object message) => this.CanHandle(new TypeId(Throw.IfNull(message).GetType())); + public bool CanHandle(Type candidateType) => this.CanHandle(new TypeId(Throw.IfNull(candidateType))); - public bool CanHandle(Type candidateType) + public bool CanHandle(TypeId candidateType) { - // For now we only support routing to handlers registered on the exact type (no base type delegation). - return this._hasCatchall || this._typedHandlers.ContainsKey(candidateType); + return this._hasCatchall || this._runtimeTypeMap.ContainsKey(candidateType); } public async ValueTask RouteMessageAsync(object message, IWorkflowContext context, bool requireRoute = false) @@ -43,6 +50,13 @@ internal class MessageRouter CallResult? result = null; + if (message is PortableValue portableValue && + this._runtimeTypeMap.TryGetValue(portableValue.TypeId, out Type? runtimeType)) + { + // If we found a runtime type, we can use it + message = portableValue.AsType(runtimeType) ?? message; + } + try { if (this._typedHandlers.TryGetValue(message.GetType(), out MessageHandlerF? handler)) diff --git a/dotnet/src/Microsoft.Agents.Workflows/Execution/RunnerStateData.cs b/dotnet/src/Microsoft.Agents.Workflows/Execution/RunnerStateData.cs index 877518718a..08ef92828b 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Execution/RunnerStateData.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Execution/RunnerStateData.cs @@ -5,9 +5,9 @@ using Microsoft.Agents.Workflows.Checkpointing; namespace Microsoft.Agents.Workflows.Execution; -internal class RunnerStateData(HashSet instantiatedExecutors, Dictionary> queuedMessages, List outstandingRequests) +internal class RunnerStateData(HashSet instantiatedExecutors, Dictionary> queuedMessages, List outstandingRequests) { public HashSet InstantiatedExecutors { get; } = instantiatedExecutors; - public Dictionary> QueuedMessages { get; } = queuedMessages; + public Dictionary> QueuedMessages { get; } = queuedMessages; public List OutstandingRequests { get; } = outstandingRequests; } diff --git a/dotnet/src/Microsoft.Agents.Workflows/Execution/StateManager.cs b/dotnet/src/Microsoft.Agents.Workflows/Execution/StateManager.cs index 6d625df92d..c4554fbf7f 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Execution/StateManager.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Execution/StateManager.cs @@ -109,12 +109,12 @@ internal class StateManager // What's the right thing to do when we have a state object, but it is the wrong type? if (result.IsDelete) { - return new ValueTask((T?)default); + return new((T?)default); } if (result.Value is T) { - return new ValueTask((T?)result.Value); + return new((T?)result.Value); } throw new InvalidOperationException($"State for key '{key}' in scope '{scopeId}' is not of type '{typeof(T).Name}'."); @@ -124,20 +124,31 @@ internal class StateManager return scope.ReadStateAsync(key); } - public ValueTask WriteStateAsync(string executorId, string? scopeName, string key, T? value) + public ValueTask WriteStateAsync(string executorId, string? scopeName, string key, T value) => this.WriteStateAsync(new ScopeId(Throw.IfNullOrEmpty(executorId), scopeName), key, value); - public ValueTask WriteStateAsync(ScopeId scopeId, string key, T? value) + public ValueTask WriteStateAsync(ScopeId scopeId, string key, T value) { Throw.IfNullOrEmpty(key); UpdateKey stateKey = new(scopeId, key); - StateUpdate update = value == null ? StateUpdate.Delete(key) : StateUpdate.Update(key, value); + StateUpdate update = StateUpdate.Update(key, value); this._queuedUpdates[stateKey] = update; return default; } + public ValueTask ClearStateAsync(string executorId, string? scopeName, string key) + => this.ClearStateAsync(new ScopeId(Throw.IfNullOrEmpty(executorId), scopeName), key); + + public ValueTask ClearStateAsync(ScopeId scopeId, string key) + { + Throw.IfNullOrEmpty(key); + UpdateKey stateKey = new(scopeId, key); + this._queuedUpdates[stateKey] = StateUpdate.Delete(key); + return default; + } + public async ValueTask PublishUpdatesAsync(IStepTracer? tracer) { Dictionary>> updatesByScope = []; @@ -172,15 +183,15 @@ internal class StateManager this._queuedUpdates.Clear(); } - private static IEnumerable> ExportScope(StateScope scope) + private static IEnumerable> ExportScope(StateScope scope) { - foreach (KeyValuePair state in scope.ExportStates()) + foreach (KeyValuePair state in scope.ExportStates()) { yield return new(new ScopeKey(scope.ScopeId, state.Key), state.Value); } } - internal async ValueTask> ExportStateAsync() + internal async ValueTask> ExportStateAsync() { if (this._queuedUpdates.Count != 0) { @@ -201,7 +212,7 @@ internal class StateManager this._queuedUpdates.Clear(); this._scopes.Clear(); - Dictionary importedState = checkpoint.State; + Dictionary importedState = checkpoint.StateData; foreach (ScopeKey scopeKey in importedState.Keys) { diff --git a/dotnet/src/Microsoft.Agents.Workflows/Execution/StateScope.cs b/dotnet/src/Microsoft.Agents.Workflows/Execution/StateScope.cs index ce5d559337..3fcb754895 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Execution/StateScope.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Execution/StateScope.cs @@ -4,14 +4,13 @@ using System; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; -using Microsoft.Agents.Workflows.Checkpointing; using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.Workflows.Execution; internal class StateScope { - private readonly Dictionary _stateData = new(); + private readonly Dictionary _stateData = new(); public ScopeId ScopeId { get; } public StateScope(ScopeId scopeId) @@ -30,15 +29,32 @@ internal class StateScope return new(keys); } + public bool Contains(string key) + { + Throw.IfNullOrEmpty(key); + if (this._stateData.TryGetValue(key, out PortableValue? value)) + { + return value.Is(); + } + + return false; + } + + public bool ContainsKey(string key) + { + Throw.IfNullOrEmpty(key); + return this._stateData.ContainsKey(key); + } + public ValueTask ReadStateAsync(string key) { Throw.IfNullOrEmpty(key); - if (this._stateData.TryGetValue(key, out object? value) && value is T typedValue) + if (this._stateData.TryGetValue(key, out PortableValue? value)) { - return new ValueTask(typedValue); + return new(value.As()); } - return new ValueTask((T?)default); + return new((T?)default); } public ValueTask WriteStateAsync(Dictionary> updates) @@ -64,28 +80,28 @@ internal class StateScope } else { - this._stateData[key] = update.Value!; + this._stateData[key] = new PortableValue(update.Value!); } } return default; } - public IEnumerable> ExportStates() + public IEnumerable> ExportStates() { return this._stateData.Keys.Select(WrapStates); - KeyValuePair WrapStates(string key) + KeyValuePair WrapStates(string key) { - return new(key, new(this._stateData[key])); + return new(key, this._stateData[key]); } } - public void ImportState(string key, ExportedState state) + public void ImportState(string key, PortableValue state) { Throw.IfNullOrEmpty(key); Throw.IfNull(state); - this._stateData[key] = state.Value; + this._stateData[key] = state; } } diff --git a/dotnet/src/Microsoft.Agents.Workflows/Execution/StepContext.cs b/dotnet/src/Microsoft.Agents.Workflows/Execution/StepContext.cs index 2aa8438f52..53eb22d0e3 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Execution/StepContext.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Execution/StepContext.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; -using System.IO; using System.Linq; using Microsoft.Agents.Workflows.Checkpointing; @@ -25,25 +24,23 @@ internal class StepContext // TODO: Create a MessageEnvelope class that extends from the ExportedState object (with appropriate rename) to avoid // unnecessary wrapping and unwrapping of messages during checkpointing. - internal Dictionary> ExportMessages() + internal Dictionary> ExportMessages() { return this.QueuedMessages.Keys.ToDictionary( keySelector: identity => identity, elementSelector: identity => this.QueuedMessages[identity] - .Select(v => new ExportedState(v)) + .Select(v => new PortableMessageEnvelope(v)) .ToList() ); } - internal void ImportMessages(Dictionary> messages) + internal void ImportMessages(Dictionary> messages) { foreach (ExecutorIdentity identity in messages.Keys) { this.QueuedMessages[identity] = messages[identity].Select(UnwrapExportedState).ToList(); } - MessageEnvelope UnwrapExportedState(ExportedState es) - => es.Value as MessageEnvelope - ?? throw new InvalidDataException($"Expected a MessageEnvelope in the ExportedState. Got {es.RuntimeType}"); + MessageEnvelope UnwrapExportedState(PortableMessageEnvelope es) => es.ToMessageEnvelope(); } } diff --git a/dotnet/src/Microsoft.Agents.Workflows/Executor.cs b/dotnet/src/Microsoft.Agents.Workflows/Executor.cs index 8b813c08c6..9635d97174 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Executor.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Executor.cs @@ -6,6 +6,7 @@ using System.Diagnostics; using System.Reflection; using System.Threading; using System.Threading.Tasks; +using Microsoft.Agents.Workflows.Checkpointing; using Microsoft.Agents.Workflows.Execution; using Microsoft.Agents.Workflows.Reflection; @@ -65,7 +66,7 @@ public abstract class Executor : IIdentified /// A ValueTask representing the asynchronous operation, wrapping the output from the executor. /// No handler found for the message type. /// An exception is generated while handling the message. - public async ValueTask ExecuteAsync(object message, Type messageType, IWorkflowContext context) + public async ValueTask ExecuteAsync(object message, TypeId messageType, IWorkflowContext context) { await context.AddEventAsync(new ExecutorInvokedEvent(this.Id, message)).ConfigureAwait(false); @@ -79,7 +80,7 @@ public abstract class Executor : IIdentified } else { - executionResult = new ExecutorFailureEvent(this.Id, result.Exception); + executionResult = new ExecutorFailedEvent(this.Id, result.Exception); } await context.AddEventAsync(executionResult).ConfigureAwait(false); @@ -141,6 +142,8 @@ public abstract class Executor : IIdentified /// /// public bool CanHandle(Type messageType) => this.Router.CanHandle(messageType); + + internal bool CanHandle(TypeId messageType) => this.Router.CanHandle(messageType); } /// diff --git a/dotnet/src/Microsoft.Agents.Workflows/ExecutorEvent.cs b/dotnet/src/Microsoft.Agents.Workflows/ExecutorEvent.cs index 9785bbfe29..9e67c49097 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/ExecutorEvent.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/ExecutorEvent.cs @@ -1,10 +1,15 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Text.Json.Serialization; + namespace Microsoft.Agents.Workflows; /// /// Base class for -scoped events. /// +[JsonDerivedType(typeof(ExecutorInvokedEvent))] +[JsonDerivedType(typeof(ExecutorCompletedEvent))] +[JsonDerivedType(typeof(ExecutorFailedEvent))] public class ExecutorEvent(string executorId, object? data) : WorkflowEvent(data) { /// diff --git a/dotnet/src/Microsoft.Agents.Workflows/ExecutorFailureEvent.cs b/dotnet/src/Microsoft.Agents.Workflows/ExecutorFailedEvent.cs similarity index 88% rename from dotnet/src/Microsoft.Agents.Workflows/ExecutorFailureEvent.cs rename to dotnet/src/Microsoft.Agents.Workflows/ExecutorFailedEvent.cs index b3a3168cd0..6b849b0af2 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/ExecutorFailureEvent.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/ExecutorFailedEvent.cs @@ -9,7 +9,7 @@ namespace Microsoft.Agents.Workflows; /// /// The unique identifier of the executor that has failed. /// The exception representing the error. -public sealed class ExecutorFailureEvent(string executorId, Exception? err) +public sealed class ExecutorFailedEvent(string executorId, Exception? err) : ExecutorEvent(executorId, data: err) { /// diff --git a/dotnet/src/Microsoft.Agents.Workflows/ExternalRequest.cs b/dotnet/src/Microsoft.Agents.Workflows/ExternalRequest.cs index dd84ce04ee..316a6edb75 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/ExternalRequest.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/ExternalRequest.cs @@ -2,6 +2,7 @@ using System; using System.Diagnostics.CodeAnalysis; +using Microsoft.Agents.Workflows.Checkpointing; using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.Workflows; @@ -9,11 +10,25 @@ namespace Microsoft.Agents.Workflows; /// /// Represents a request to an external input port. /// -/// The port to invoke. +/// The port to invoke. /// A unique identifier for this request instance. /// The data contained in the request. -public record ExternalRequest(InputPort Port, string RequestId, object Data) +public record ExternalRequest(InputPortInfo PortInfo, string RequestId, PortableValue Data) { + /// + /// Attempts to retrieve the underlying data as the specified type. + /// + /// The type to which the data should be cast or converted. + /// The data cast to the specified type, or null if the data cannot be cast to the specified type. + public TValue? DataAs() => this.Data.As(); + + /// + /// Determines whether the underlying data is of the specified type. + /// + /// The type to compare with the underlying data. + /// true if the underlying data is of type TValue; otherwise, false. + public bool DataIs() => this.Data.Is(); + /// /// Creates a new for the specified input port and data payload. /// @@ -32,7 +47,7 @@ public record ExternalRequest(InputPort Port, string RequestId, object Data) requestId ??= Guid.NewGuid().ToString("N"); - return new ExternalRequest(port, requestId, data); + return new ExternalRequest(port.ToPortInfo(), requestId, new PortableValue(data)); } /// @@ -53,13 +68,13 @@ public record ExternalRequest(InputPort Port, string RequestId, object Data) /// Thrown when the input data object does not match the expected response type. public ExternalResponse CreateResponse(object data) { - if (!Throw.IfNull(this.Port).Response.IsAssignableFrom(Throw.IfNull(data).GetType())) + if (!Throw.IfNull(this.PortInfo).ResponseType.IsMatchPolymorphic(Throw.IfNull(data).GetType())) { throw new InvalidOperationException( - $"Message type {data.GetType().Name} is not assignable to the response type {this.Port.Response.Name} of input port {this.Port.Id}."); + $"Message type {data.GetType().Name} does not match expected response type {this.PortInfo.ResponseType.TypeName} of input port {this.PortInfo.PortId}."); } - return new ExternalResponse(this.Port, this.RequestId, data); + return new ExternalResponse(this.PortInfo, this.RequestId, new PortableValue(data)); } /// diff --git a/dotnet/src/Microsoft.Agents.Workflows/ExternalResponse.cs b/dotnet/src/Microsoft.Agents.Workflows/ExternalResponse.cs index c80fbfd561..b39b7dc9fb 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/ExternalResponse.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/ExternalResponse.cs @@ -1,13 +1,36 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using Microsoft.Agents.Workflows.Checkpointing; + namespace Microsoft.Agents.Workflows; /// /// Represents a request from an external input port. /// -/// The port invoked. +/// The port invoked. /// The unique identifier of the corresponding request. /// The data contained in the response. -public record ExternalResponse(InputPort Port, string RequestId, object Data) +public record ExternalResponse(InputPortInfo PortInfo, string RequestId, PortableValue Data) { + /// + /// Attempts to retrieve the underlying data as the specified type. + /// + /// The type to which the data should be cast or converted. + /// The data cast to the specified type, or null if the data cannot be cast to the specified type. + public TValue? DataAs() => this.Data.As(); + + /// + /// Determines whether the underlying data is of the specified type. + /// + /// The type to compare with the underlying data. + /// true if the underlying data is of type TValue; otherwise, false. + public bool DataIs() => this.Data.Is(); + + /// + /// Attempts to retrieve the underlying data as the specified type. + /// + /// The type to which the data should be cast or converted. + /// The data cast to the specified type, or null if the data cannot be cast to the specified type. + public object? DataAs(Type targetType) => this.Data.AsType(targetType); } diff --git a/dotnet/src/Microsoft.Agents.Workflows/FanInEdgeData.cs b/dotnet/src/Microsoft.Agents.Workflows/FanInEdgeData.cs index 33c6fdc241..5d937fa331 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/FanInEdgeData.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/FanInEdgeData.cs @@ -8,20 +8,25 @@ namespace Microsoft.Agents.Workflows; /// /// Represents a connection from a set of nodes to a single node. It will trigger either when all edges have data. /// -/// An enumeration of ids of the source executor nodes. -/// The id of the target executor node. -public sealed class FanInEdgeData(List sourceIds, string sinkId) : EdgeData +internal sealed class FanInEdgeData : EdgeData { + internal FanInEdgeData(List sourceIds, string sinkId, EdgeId id) : base(id) + { + this.SourceIds = sourceIds; + this.SinkId = sinkId; + this.Connection = new(sourceIds, [sinkId]); + } + /// /// The ordered list of Ids of the source nodes. /// - public List SourceIds => sourceIds; + public List SourceIds { get; } /// /// The Id of the destination node. /// - public string SinkId => sinkId; + public string SinkId { get; } /// - internal override EdgeConnection Connection { get; } = new(sourceIds, [sinkId]); + internal override EdgeConnection Connection { get; } } diff --git a/dotnet/src/Microsoft.Agents.Workflows/FanOutEdgeData.cs b/dotnet/src/Microsoft.Agents.Workflows/FanOutEdgeData.cs index 1fe83edae2..2cd47a10b5 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/FanOutEdgeData.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/FanOutEdgeData.cs @@ -11,30 +11,32 @@ namespace Microsoft.Agents.Workflows; /// Represents a connection from a single node to a set of nodes, optionally associated with a paritition selector /// function which maps incoming messages to a subset of the target set. /// -/// The id of the source executor node. -/// A list of ids of the target executor nodes. -/// A function that maps an incoming message to a subset of the target executor nodes. -public sealed class FanOutEdgeData( - string sourceId, - List sinkIds, - AssignerF? assigner = null) : EdgeData +internal sealed class FanOutEdgeData : EdgeData { + internal FanOutEdgeData(string sourceId, List sinkIds, EdgeId edgeId, AssignerF? assigner = null) : base(edgeId) + { + this.SourceId = sourceId; + this.SinkIds = sinkIds; + this.EdgeAssigner = assigner; + this.Connection = new([sourceId], sinkIds); + } + /// /// The Id of the source node. /// - public string SourceId => sourceId; + public string SourceId { get; } /// /// The ordered list of Ids of the destination nodes. /// - public List SinkIds => sinkIds; + public List SinkIds { get; } /// /// A function mapping an incoming message to a subset of the target executor nodes (or optionally all of them). /// If , all destination nodes are selected. /// - public AssignerF? EdgeAssigner => assigner; + public AssignerF? EdgeAssigner { get; } /// - internal override EdgeConnection Connection { get; } = new([sourceId], sinkIds); + internal override EdgeConnection Connection { get; } } diff --git a/dotnet/src/Microsoft.Agents.Workflows/InProc/InProcessRunner.cs b/dotnet/src/Microsoft.Agents.Workflows/InProc/InProcessRunner.cs index 4bf2f50e8c..05d5fa27be 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/InProc/InProcessRunner.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/InProc/InProcessRunner.cs @@ -22,17 +22,20 @@ namespace Microsoft.Agents.Workflows.InProc; /// The type of input accepted by the workflow. Must be non-nullable. internal class InProcessRunner : ISuperStepRunner, ICheckpointingRunner where TInput : notnull { - public InProcessRunner(Workflow workflow, ICheckpointManager? checkpointManager) + public InProcessRunner(Workflow workflow, ICheckpointManager? checkpointManager, string? runId = null) { this.Workflow = Throw.IfNull(workflow); this.RunContext = new InProcessRunnerContext(workflow); this.CheckpointManager = checkpointManager; + this.RunId = runId ?? Guid.NewGuid().ToString("N"); // Initialize the runners for each of the edges, along with the state for edges that // need it. this.EdgeMap = new EdgeMap(this.RunContext, this.Workflow.Edges, this.Workflow.Ports.Values, this.Workflow.StartExecutorId, this.StepTracer); } + public string RunId { get; } + public async ValueTask IsValidInputAsync(TMessage message) { Throw.IfNull(message); @@ -166,9 +169,22 @@ internal class InProcessRunner : ISuperStepRunner, ICheckpointingRunner return true; } + this.EmitPendingEvents(); return false; } + private void EmitPendingEvents() + { + if (this.RunContext.QueuedEvents.Count > 0) + { + foreach (WorkflowEvent @event in this.RunContext.QueuedEvents) + { + this.RaiseWorkflowEvent(@event); + } + this.RunContext.QueuedEvents.Clear(); + } + } + private async ValueTask RunSuperstepAsync(StepContext currentStep) { this.RaiseWorkflowEvent(this.StepTracer.Advance(currentStep)); @@ -197,11 +213,7 @@ internal class InProcessRunner : ISuperStepRunner, ICheckpointingRunner IEnumerable results = (await Task.WhenAll(edgeTasks).ConfigureAwait(false)).SelectMany(r => r); // After the message handler invocations, we may have some events to deliver - foreach (WorkflowEvent @event in this.RunContext.QueuedEvents) - { - this.RaiseWorkflowEvent(@event); - } - this.RunContext.QueuedEvents.Clear(); + this.EmitPendingEvents(); await this.CheckpointAsync().ConfigureAwait(false); @@ -228,16 +240,16 @@ internal class InProcessRunner : ISuperStepRunner, ICheckpointingRunner this._workflowInfoCache = this.Workflow.ToWorkflowInfo(); } - Dictionary edgeData = await this.EdgeMap.ExportStateAsync().ConfigureAwait(false); + Dictionary edgeData = await this.EdgeMap.ExportStateAsync().ConfigureAwait(false); await prepareTask.ConfigureAwait(false); await this.RunContext.StateManager.PublishUpdatesAsync(this.StepTracer).ConfigureAwait(false); RunnerStateData runnerData = await this.RunContext.ExportStateAsync().ConfigureAwait(false); - Dictionary stateData = await this.RunContext.StateManager.ExportStateAsync().ConfigureAwait(false); + Dictionary stateData = await this.RunContext.StateManager.ExportStateAsync().ConfigureAwait(false); Checkpoint checkpoint = new(this.StepTracer.StepNumber, this._workflowInfoCache, runnerData, stateData, edgeData); - CheckpointInfo checkpointInfo = await this.CheckpointManager.CommitCheckpointAsync(checkpoint).ConfigureAwait(false); + CheckpointInfo checkpointInfo = await this.CheckpointManager.CommitCheckpointAsync(this.RunId, checkpoint).ConfigureAwait(false); this.StepTracer.TraceCheckpointCreated(checkpointInfo); this._checkpoints.Add(checkpointInfo); } @@ -250,7 +262,7 @@ internal class InProcessRunner : ISuperStepRunner, ICheckpointingRunner throw new InvalidOperationException("This run was not configured with a CheckpointManager, so it cannot restore checkpoints."); } - Checkpoint checkpoint = await this.CheckpointManager.LookupCheckpointAsync(checkpointInfo) + Checkpoint checkpoint = await this.CheckpointManager.LookupCheckpointAsync(this.RunId, checkpointInfo) .ConfigureAwait(false); // Validate the checkpoint is compatible with this workflow @@ -283,11 +295,11 @@ internal class InProcessRunner : IRunnerWithOutput, IC private readonly Workflow _workflow; private readonly InProcessRunner _innerRunner; - public InProcessRunner(Workflow workflow, CheckpointManager? checkpointManager) + public InProcessRunner(Workflow workflow, CheckpointManager? checkpointManager, string? runId = null) { this._workflow = Throw.IfNull(workflow); - InProcessRunner runner = new(workflow, checkpointManager); + InProcessRunner runner = new(workflow, checkpointManager, runId); this._innerRunner = runner; } diff --git a/dotnet/src/Microsoft.Agents.Workflows/InProc/InProcessRunnerContext.cs b/dotnet/src/Microsoft.Agents.Workflows/InProc/InProcessRunnerContext.cs index b2d30f21f4..aec275b2ef 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/InProc/InProcessRunnerContext.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/InProc/InProcessRunnerContext.cs @@ -135,7 +135,7 @@ internal class InProcessRunnerContext : IRunnerContext throw new InvalidOperationException("Cannot export state when there are queued events. Please process or clear the events before exporting state."); } - Dictionary> queuedMessages = this._nextStep.ExportMessages(); + Dictionary> queuedMessages = this._nextStep.ExportMessages(); RunnerStateData result = new(instantiatedExecutors: [.. this._executors.Keys], queuedMessages, outstandingRequests: [.. this._externalRequests.Values]); diff --git a/dotnet/src/Microsoft.Agents.Workflows/InProcessExecution.cs b/dotnet/src/Microsoft.Agents.Workflows/InProcessExecution.cs index e435927f9e..96819a4b56 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/InProcessExecution.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/InProcessExecution.cs @@ -79,7 +79,7 @@ public static class InProcessExecution CheckpointManager checkpointManager, CancellationToken cancellation = default) where TInput : notnull { - InProcessRunner runner = new(workflow, checkpointManager); + InProcessRunner runner = new(workflow, checkpointManager, runId: fromCheckpoint.RunId); StreamingRun result = await runner.ResumeStreamAsync(fromCheckpoint, cancellation).ConfigureAwait(false); return new(result, runner); @@ -155,7 +155,7 @@ public static class InProcessExecution CheckpointManager checkpointManager, CancellationToken cancellation = default) where TInput : notnull { - InProcessRunner runner = new(workflow, checkpointManager); + InProcessRunner runner = new(workflow, checkpointManager, runId: fromCheckpoint.RunId); StreamingRun result = await runner.ResumeStreamAsync(fromCheckpoint, cancellation).ConfigureAwait(false); return new(result, runner); @@ -225,7 +225,7 @@ public static class InProcessExecution CheckpointManager checkpointManager, CancellationToken cancellation = default) where TInput : notnull { - InProcessRunner runner = new(workflow, checkpointManager); + InProcessRunner runner = new(workflow, checkpointManager, runId: fromCheckpoint.RunId); Run result = await runner.ResumeAsync(fromCheckpoint, cancellation).ConfigureAwait(false); return new(result, runner); @@ -298,7 +298,7 @@ public static class InProcessExecution CheckpointManager checkpointManager, CancellationToken cancellation = default) where TInput : notnull { - InProcessRunner runner = new(workflow, checkpointManager); + InProcessRunner runner = new(workflow, checkpointManager, runId: fromCheckpoint.RunId); Run result = await runner.ResumeAsync(fromCheckpoint, cancellation).ConfigureAwait(false); return new(result, runner); diff --git a/dotnet/src/Microsoft.Agents.Workflows/PortableValue.cs b/dotnet/src/Microsoft.Agents.Workflows/PortableValue.cs new file mode 100644 index 0000000000..ab3fe80db9 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.Workflows/PortableValue.cs @@ -0,0 +1,176 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json.Serialization; + +using Microsoft.Agents.Workflows.Checkpointing; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Agents.Workflows; + +/// +/// Represents a value that can be exported / imported to a workflow, e.g. through an external request/response, or +/// through checkpointing. Abstracts away delayed deserialization and type conversion where appropriate. +/// +public sealed class PortableValue +{ + internal PortableValue(object value) + { + this._value = value; + this.TypeId = new(value.GetType()); + } + + [JsonConstructor] + internal PortableValue(TypeId typeId, object value) + { + this.TypeId = Throw.IfNull(typeId); + this._value = value; + } + + /// + public override bool Equals(object? obj) + { + if (obj == null) + { + return false; + } + + if (obj is not PortableValue other) + { + Type targetType = obj.GetType(); + return this.AsType(targetType)?.Equals(obj) ?? false; + } + + return this.TypeId == other.TypeId + && ((this.Value == null && other.Value == null) + || this.Value != null && this.Value.Equals(other.Value)); + } + + /// + public override int GetHashCode() + { + return HashCode.Combine(this.TypeId, this.Value); + } + + /// + public static bool operator ==(PortableValue? left, PortableValue? right) + { + if (left is null) + { + return right is null; + } + + return left.Equals(right); + } + + /// + public static bool operator !=(PortableValue? left, PortableValue? right) => !(left == right); + + /// + /// The identifier of the type of the instance in . + /// + public TypeId TypeId { get; } + + [JsonIgnore] + internal bool IsDelayedDeserialization => this.Value is IDelayedDeserialization; + + [JsonIgnore] + internal bool IsDeserialized => this._deserializedValueCache != null; + + private readonly object _value; + private object? _deserializedValueCache = null; + + /// + /// Gets the raw underlying value represented by this instance. + /// + [JsonInclude] + internal object Value => this._deserializedValueCache ?? Throw.IfNull(this._value); + + /// + /// Attempts to retrieve the underlying value as the specified type, deserializing if necessary. + /// + /// If the underlying value implements delayed deserialization, this method will attempt to + /// deserialize it to the specified type. If the value is already of the requested type, it is returned directly. + /// Otherwise, the default value for TValue is returned. + /// + /// For nullable value types, make sure to make be nullable, e.g. int?, + /// otherwise the default non-null value of the type is returned when the value is missing. Use + /// to get the correct behavior when unable to pass in the explicit-nullable type. + /// + /// The type to which the value should be cast or deserialized. + /// The value cast or deserialized to type TValue if possible; otherwise, the default value for type TValue. + public TValue? As() + { + if (this.Value is IDelayedDeserialization delayedDeserialization) + { + if (this._deserializedValueCache == null) + { + this._deserializedValueCache = delayedDeserialization.Deserialize(); + } + } + + if (this.Value is TValue typedValue) + { + return typedValue; + } + + return default; + } + + /// + /// Attempts to retrieve the underlying value as the specified nullable value type, deserializing if + /// necessary. + /// + /// If the underlying value implements delayed deserialization, this method will attempt to + /// deserialize it to the specified type. If the value is already of the requested type, it is returned directly. + /// Otherwise, null is returned. + /// The value type to which the value should be cast or deserialized. + /// The value cast or deserialized to type TValue if possible; otherwise, null. + public TValue? AsValue() where TValue : struct + { + if (this.Value is IDelayedDeserialization delayedDeserialization) + { + this._deserializedValueCache ??= delayedDeserialization.Deserialize(); + } + + if (this.Value is TValue typedValue) + { + return typedValue; + } + + return default; + } + + /// + /// Determines whether the current value can be represented as the specified type. + /// + /// The type to test for compatibility with the current value. + /// true if the current value can be represented as type TValue; otherwise, false. + public bool Is() => this.IsType(typeof(TValue)); + + /// + /// Attempts to retrieve the underlying value as the specified type, deserializing if necessary. + /// + /// The type to which the value should be cast or deserialized. + /// The value cast or deserialized to type targetType if possible; otherwise, null. + public object? AsType(Type targetType) + { + Throw.IfNull(targetType); + + if (this.Value is IDelayedDeserialization delayedDeserialization) + { + this._deserializedValueCache ??= delayedDeserialization.Deserialize(targetType); + } + + return this.Value is not null && targetType.IsAssignableFrom(this.Value.GetType()) + ? this.Value + : this._deserializedValueCache = null; + } + + /// + /// Determines whether the current instance can be assigned to the specified target type. + /// + /// The type to compare with the current instance. Cannot be null. + /// true if the current instance can be assigned to targetType; otherwise, false. + public bool IsType(Type targetType) => this.AsType(targetType) != null; +} diff --git a/dotnet/src/Microsoft.Agents.Workflows/ScopeKey.cs b/dotnet/src/Microsoft.Agents.Workflows/ScopeKey.cs index ce5be0a899..68a3af93ca 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/ScopeKey.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/ScopeKey.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Text.Json.Serialization; using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.Workflows; @@ -8,19 +9,17 @@ namespace Microsoft.Agents.Workflows; /// /// Represents a unique key within a specific scope, combining a scope identifier and a key string. /// -/// The associated with this key. -/// The unique key within the specified scope. -public class ScopeKey(ScopeId scopeId, string key) +public class ScopeKey { /// /// The identifier for the scope associated with this key. /// - public ScopeId ScopeId { get; } = Throw.IfNull(scopeId); + public ScopeId ScopeId { get; } /// /// The unique key within the specified scope. /// - public string Key { get; } = Throw.IfNullOrEmpty(key); + public string Key { get; } /// /// Initializes a new instance of the class. @@ -32,6 +31,18 @@ public class ScopeKey(ScopeId scopeId, string key) : this(new ScopeId(Throw.IfNullOrEmpty(executorId), scopeName), key) { } + /// + /// Iniitalizes a new instance of the class. + /// + /// The associated with this key. + /// The unique key within the specified scope. + [JsonConstructor] + public ScopeKey(ScopeId scopeId, string key) + { + this.ScopeId = Throw.IfNull(scopeId); + this.Key = Throw.IfNullOrEmpty(key); + } + /// public override string ToString() { diff --git a/dotnet/src/Microsoft.Agents.Workflows/Specialized/RequestInfoExecutor.cs b/dotnet/src/Microsoft.Agents.Workflows/Specialized/RequestInfoExecutor.cs index 9b12ff1515..1ae0fdeb01 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Specialized/RequestInfoExecutor.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Specialized/RequestInfoExecutor.cs @@ -65,14 +65,16 @@ internal class RequestInfoExecutor : Executor Throw.IfNull(message); Throw.IfNull(message.Data); - if (!this.Port.Response.IsAssignableFrom(message.Data.GetType())) + object? data = message.DataAs(this.Port.Response); + + if (data == null) { throw new InvalidOperationException( - $"Message type {message.Data.GetType().Name} is not assignable to the response type {this.Port.Response.Name} of input port {this.Port.Id}."); + $"Message type {message.Data.TypeId} is not assignable to the response type {this.Port.Response.Name} of input port {this.Port.Id}."); } await context.SendMessageAsync(message).ConfigureAwait(false); - await context.SendMessageAsync(message.Data).ConfigureAwait(false); + await context.SendMessageAsync(data).ConfigureAwait(false); return message; } diff --git a/dotnet/src/Microsoft.Agents.Workflows/SuperStepEvent.cs b/dotnet/src/Microsoft.Agents.Workflows/SuperStepEvent.cs index f0b64f2169..53c3dadce9 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/SuperStepEvent.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/SuperStepEvent.cs @@ -1,10 +1,14 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Text.Json.Serialization; + namespace Microsoft.Agents.Workflows; /// /// Base class for SuperStep-scoped events, for example, /// +[JsonDerivedType(typeof(SuperStepStartedEvent))] +[JsonDerivedType(typeof(SuperStepCompletedEvent))] public class SuperStepEvent(int stepNumber, object? data = null) : WorkflowEvent(data) { /// diff --git a/dotnet/src/Microsoft.Agents.Workflows/SwitchBuilder.cs b/dotnet/src/Microsoft.Agents.Workflows/SwitchBuilder.cs index ddfe157604..2411f62294 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/SwitchBuilder.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/SwitchBuilder.cs @@ -30,7 +30,7 @@ public sealed class SwitchBuilder /// One or more executors to associate with the predicate. Each executor will be invoked if the predicate matches. /// Cannot be null. /// The current instance, allowing for method chaining. - public SwitchBuilder AddCase(Func predicate, params ExecutorIsh[] executors) + public SwitchBuilder AddCase(Func predicate, params ExecutorIsh[] executors) { Throw.IfNull(predicate); Throw.IfNull(executors); @@ -49,7 +49,8 @@ public sealed class SwitchBuilder indicies.Add(index); } - this._caseMap.Add((predicate, indicies)); + Func casePredicate = WorkflowBuilder.CreateConditionFunc(predicate)!; + this._caseMap.Add((casePredicate, indicies)); return this; } @@ -83,7 +84,7 @@ public sealed class SwitchBuilder List<(Func Predicate, HashSet OutgoingIndicies)> caseMap = this._caseMap; HashSet defaultIndicies = this._defaultIndicies; - return builder.AddFanOutEdge(source, CasePartitioner, [.. this._executors]); + return builder.AddFanOutEdge(source, CasePartitioner, [.. this._executors]); IEnumerable CasePartitioner(object? input, int targetCount) { diff --git a/dotnet/src/Microsoft.Agents.Workflows/Workflow.cs b/dotnet/src/Microsoft.Agents.Workflows/Workflow.cs index 2b1549c5bd..d7af9497b8 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/Workflow.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/Workflow.cs @@ -2,6 +2,8 @@ using System; using System.Collections.Generic; +using System.Linq; +using Microsoft.Agents.Workflows.Checkpointing; using Microsoft.Agents.Workflows.Specialized; using Microsoft.Shared.Diagnostics; @@ -17,10 +19,20 @@ public class Workflow /// internal Dictionary Registrations { get; init; } = new(); + internal Dictionary> Edges { get; init; } = new(); + /// /// Gets the collection of edges grouped by their source node identifier. /// - public Dictionary> Edges { get; internal init; } = new(); + public Dictionary> ReflectEdges() + { + return this.Edges.Keys.ToDictionary( + keySelector: key => key, + elementSelector: key => new HashSet(this.Edges[key].Select(RepresentationExtensions.ToEdgeInfo)) + ); + } + + internal Dictionary Ports { get; init; } = new(); /// /// Gets the collection of external request ports, keyed by their ID. @@ -28,7 +40,13 @@ public class Workflow /// /// Each port has a corresponding entry in the dictionary. /// - public Dictionary Ports { get; internal init; } = new(); + public Dictionary ReflectPorts() + { + return this.Ports.Keys.ToDictionary( + keySelector: key => key, + elementSelector: key => this.Ports[key].ToPortInfo() + ); + } /// /// Gets the identifier of the starting executor of the workflow. diff --git a/dotnet/src/Microsoft.Agents.Workflows/WorkflowBuilder.cs b/dotnet/src/Microsoft.Agents.Workflows/WorkflowBuilder.cs index 49ad08c33b..2c29173741 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/WorkflowBuilder.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/WorkflowBuilder.cs @@ -20,15 +20,16 @@ namespace Microsoft.Agents.Workflows; /// . public class WorkflowBuilder { - private record struct EdgeId(string SourceId, string TargetId) + private record struct EdgeConnection(string SourceId, string TargetId) { public override string ToString() => $"{this.SourceId} -> {this.TargetId}"; } + private int _edgeCount = 0; private readonly Dictionary _executors = new(); private readonly Dictionary> _edges = new(); private readonly HashSet _unboundExecutors = new(); - private readonly HashSet _conditionlessEdges = new(); + private readonly HashSet _conditionlessConnections = new(); private readonly Dictionary _inputPorts = new(); private readonly string _startExecutorId; @@ -121,6 +122,52 @@ public class WorkflowBuilder return edges; } + /// + /// Adds a directed edge from the specified source executor to the target executor, optionally guarded by a + /// condition. + /// + /// The executor that acts as the source node of the edge. Cannot be null. + /// The executor that acts as the target node of the edge. Cannot be null. + /// The current instance of . + /// Thrown if an unconditional edge between the specified source and target + /// executors already exists. + public WorkflowBuilder AddEdge(ExecutorIsh source, ExecutorIsh target) + => this.AddEdge(source, target, null); + + internal static Func? CreateConditionFunc(Func? condition) + { + if (condition == null) + { + return null; + } + return maybeObj => + { + if (typeof(T) != typeof(object) && maybeObj is PortableValue portableValue) + { + maybeObj = portableValue.AsType(typeof(T)); + } + return condition(maybeObj is T typed ? typed : default); + }; + } + + internal static Func? CreateConditionFunc(Func? condition) + { + if (condition == null) + { + return null; + } + return maybeObj => + { + if (typeof(T) != typeof(object) && maybeObj is PortableValue portableValue) + { + maybeObj = portableValue.AsType(typeof(T)); + } + return condition(maybeObj); + }; + } + + private EdgeId TakeEdgeId() => new(Interlocked.Increment(ref this._edgeCount)); + /// /// Adds a directed edge from the specified source executor to the target executor, optionally guarded by a /// condition. @@ -132,7 +179,7 @@ public class WorkflowBuilder /// The current instance of . /// Thrown if an unconditional edge between the specified source and target /// executors already exists. - public WorkflowBuilder AddEdge(ExecutorIsh source, ExecutorIsh target, Func? condition = null) + public WorkflowBuilder AddEdge(ExecutorIsh source, ExecutorIsh target, Func? condition = null) { // Add an edge from source to target with an optional condition. // This is a low-level builder method that does not enforce any specific executor type. @@ -140,21 +187,51 @@ public class WorkflowBuilder Throw.IfNull(source); Throw.IfNull(target); - EdgeId id = new(source.Id, target.Id); - if (condition == null && this._conditionlessEdges.Contains(id)) + EdgeConnection connection = new(source.Id, target.Id); + if (condition == null && this._conditionlessConnections.Contains(connection)) { throw new InvalidOperationException( $"An edge from '{source.Id}' to '{target.Id}' already exists without a condition. " + "You cannot add another edge without a condition for the same source and target."); } - DirectEdgeData directEdge = new(this.Track(source).Id, this.Track(target).Id, condition); + DirectEdgeData directEdge = new(this.Track(source).Id, this.Track(target).Id, this.TakeEdgeId(), CreateConditionFunc(condition)); this.EnsureEdgesFor(source.Id).Add(new(directEdge)); return this; } + /// + /// Adds a fan-out edge from the specified source executor to one or more target executors, optionally using a + /// custom partitioning function. + /// + /// If a partitioner function is provided, it will be used to distribute input across the target + /// executors. The order of targets determines their mapping in the partitioning process. + /// The source executor from which the fan-out edge originates. Cannot be null. + /// One or more target executors that will receive the fan-out edge. Cannot be null or empty. + /// The current instance of . + public WorkflowBuilder AddFanOutEdge(ExecutorIsh source, params ExecutorIsh[] targets) + => this.AddFanOutEdge(source, null, targets); + + internal static Func>? CreateEdgeAssignerFunc(Func>? partitioner) + { + if (partitioner == null) + { + return null; + } + + return (object? maybeObj, int count) => + { + if (typeof(T) != typeof(object) && maybeObj is PortableValue portableValue) + { + maybeObj = portableValue.AsType(typeof(T)); + } + + return partitioner(maybeObj is T typed ? typed : default, count); + }; + } + /// /// Adds a fan-out edge from the specified source executor to one or more target executors, optionally using a /// custom partitioning function. @@ -166,7 +243,7 @@ public class WorkflowBuilder /// If null, messages will route to all targets. /// One or more target executors that will receive the fan-out edge. Cannot be null or empty. /// The current instance of . - public WorkflowBuilder AddFanOutEdge(ExecutorIsh source, Func>? partitioner = null, params ExecutorIsh[] targets) + public WorkflowBuilder AddFanOutEdge(ExecutorIsh source, Func>? partitioner = null, params ExecutorIsh[] targets) { Throw.IfNull(source); Throw.IfNullOrEmpty(targets); @@ -174,7 +251,8 @@ public class WorkflowBuilder FanOutEdgeData fanOutEdge = new( this.Track(source).Id, targets.Select(target => this.Track(target).Id).ToList(), - partitioner); + this.TakeEdgeId(), + CreateEdgeAssignerFunc(partitioner)); this.EnsureEdgesFor(source.Id).Add(new(fanOutEdge)); @@ -198,7 +276,8 @@ public class WorkflowBuilder FanInEdgeData edgeData = new( sources.Select(source => this.Track(source).Id).ToList(), - this.Track(target).Id); + this.Track(target).Id, + this.TakeEdgeId()); foreach (string sourceId in edgeData.SourceIds) { diff --git a/dotnet/src/Microsoft.Agents.Workflows/WorkflowBuilderExtensions.cs b/dotnet/src/Microsoft.Agents.Workflows/WorkflowBuilderExtensions.cs index fcfebaebd4..4c2c547a2e 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/WorkflowBuilderExtensions.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/WorkflowBuilderExtensions.cs @@ -28,18 +28,22 @@ public static class WorkflowBuilderExtensions { Throw.IfNullOrEmpty(executors); + Func predicate = WorkflowBuilder.CreateConditionFunc((Func)IsAllowedType)!; + if (executors.Length == 1) { - return builder.AddEdge(source, executors[0], IsAllowedType); + return builder.AddEdge(source, executors[0], predicate); } return builder.AddSwitch(source, (switch_) => { - switch_.AddCase(IsAllowedType, executors); + switch_.AddCase(predicate, executors); }); - bool IsAllowedType(object? message) => message is TMessage; + // The reason we can check for "not null" here is that CreateConditionFunc will do the correct unwrapping + // logic for PortableValues. + bool IsAllowedType(object? message) => message is not null; } /// @@ -54,18 +58,22 @@ public static class WorkflowBuilderExtensions { Throw.IfNullOrEmpty(executors); + Func predicate = WorkflowBuilder.CreateConditionFunc((Func)IsAllowedType)!; + if (executors.Length == 1) { - return builder.AddEdge(source, executors[0], IsAllowedType); + return builder.AddEdge(source, executors[0], predicate); } return builder.AddSwitch(source, (switch_) => { - switch_.AddCase(IsAllowedType, executors); + switch_.AddCase(predicate, executors); }); - bool IsAllowedType(object? message) => message is not TMessage; + // The reason we can check for "null" here is that CreateConditionFunc will do the correct unwrapping + // logic for PortableValues. + bool IsAllowedType(object? message) => message is null; } /// @@ -158,6 +166,8 @@ public static class WorkflowBuilderExtensions /// to access the aggregated output directly. The completion condition can be used to implement custom termination /// logic, such as early stopping when a desired result is reached. /// The type of input items processed by the workflow. + /// The type of items generated by the , + /// and aggregated by the . /// The type of aggregated result produced by the workflow. /// The workflow builder used to construct the workflow and define its execution graph. /// The executor that produces output items to be collected and aggregated. Cannot be null. @@ -166,18 +176,18 @@ public static class WorkflowBuilderExtensions /// aggregated result. If null, the workflow will not raise a . /// A workflow that collects output from the specified executor, aggregates results, and exposes the aggregated /// output. - public static Workflow BuildWithOutput( + public static Workflow BuildWithOutput( this WorkflowBuilder builder, ExecutorIsh outputSource, - StreamingAggregator aggregator, - Func? completionCondition = null) + StreamingAggregator aggregator, + Func? completionCondition = null) { Throw.IfNull(outputSource); Throw.IfNull(aggregator); - OutputCollectorExecutor outputSink = new(aggregator, completionCondition); + OutputCollectorExecutor outputSink = new(aggregator, completionCondition); - // TODO: Check taht the outputSource has a TResult output? + // TODO: Check that the outputSource has a TResult output? builder.AddEdge(outputSource, outputSink); Workflow workflow = builder.Build(); diff --git a/dotnet/src/Microsoft.Agents.Workflows/WorkflowEvent.cs b/dotnet/src/Microsoft.Agents.Workflows/WorkflowEvent.cs index c469cd1c92..b7f91a67ee 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/WorkflowEvent.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/WorkflowEvent.cs @@ -1,10 +1,19 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Text.Json.Serialization; + namespace Microsoft.Agents.Workflows; /// /// Base class for -scoped events. /// +[JsonDerivedType(typeof(ExecutorEvent))] +[JsonDerivedType(typeof(SuperStepEvent))] +[JsonDerivedType(typeof(WorkflowStartedEvent))] +[JsonDerivedType(typeof(WorkflowCompletedEvent))] +[JsonDerivedType(typeof(WorkflowErrorEvent))] +[JsonDerivedType(typeof(WorkflowWarningEvent))] +[JsonDerivedType(typeof(RequestInfoEvent))] public class WorkflowEvent(object? data = null) { /// diff --git a/dotnet/src/Microsoft.Agents.Workflows/WorkflowHostingExtensions.cs b/dotnet/src/Microsoft.Agents.Workflows/WorkflowHostingExtensions.cs index abc679c844..d1e6260e09 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/WorkflowHostingExtensions.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/WorkflowHostingExtensions.cs @@ -30,6 +30,6 @@ public static class WorkflowHostingExtensions { "data", request.Data} }; - return new FunctionCallContent(request.RequestId, request.Port.Id, parameters); + return new FunctionCallContent(request.RequestId, request.PortInfo.PortId, parameters); } } diff --git a/dotnet/src/Microsoft.Agents.Workflows/WorkflowsJsonUtilities.cs b/dotnet/src/Microsoft.Agents.Workflows/WorkflowsJsonUtilities.cs index 5f4efc332d..ee9335fffc 100644 --- a/dotnet/src/Microsoft.Agents.Workflows/WorkflowsJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.Workflows/WorkflowsJsonUtilities.cs @@ -3,6 +3,8 @@ using System.Diagnostics.CodeAnalysis; using System.Text.Json; using System.Text.Json.Serialization; +using Microsoft.Agents.Workflows.Checkpointing; +using Microsoft.Agents.Workflows.Execution; using Microsoft.Extensions.AI; using static Microsoft.Agents.Workflows.WorkflowMessageStore; @@ -53,9 +55,38 @@ internal static partial class WorkflowsJsonUtilities DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, NumberHandling = JsonNumberHandling.AllowReadingFromString)] - // Agent abstraction types + // Checkpointing Types + [JsonSerializable(typeof(Checkpoint))] + [JsonSerializable(typeof(CheckpointInfo))] + [JsonSerializable(typeof(PortableValue))] + [JsonSerializable(typeof(PortableMessageEnvelope))] + + // Runtime State Types + [JsonSerializable(typeof(ScopeKey))] + [JsonSerializable(typeof(ScopeId))] + [JsonSerializable(typeof(ExecutorIdentity))] + [JsonSerializable(typeof(RunnerStateData))] + + // Workflow Representation Types + [JsonSerializable(typeof(WorkflowInfo))] + [JsonSerializable(typeof(EdgeConnection))] + + // Workflow-as-Agent [JsonSerializable(typeof(StoreState))] + // Message Types + [JsonSerializable(typeof(ChatMessage))] + [JsonSerializable(typeof(ExternalRequest))] + [JsonSerializable(typeof(ExternalResponse))] + [JsonSerializable(typeof(TurnToken))] + + // Event Types + //[JsonSerializable(typeof(WorkflowEvent))] + // Currently cannot be serialized because it includes Exceptions. + // We'll need a way to marshal this correct in the AgentRuntime case. + // For now this is okay, because we never serialize WorkflowEvents into + // checkpoints. + [JsonSerializable(typeof(JsonElement))] [ExcludeFromCodeCoverage] internal sealed partial class JsonContext : JsonSerializerContext; } diff --git a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/EdgeMapSmokeTests.cs b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/EdgeMapSmokeTests.cs index da00a17554..590fcce60d 100644 --- a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/EdgeMapSmokeTests.cs +++ b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/EdgeMapSmokeTests.cs @@ -19,7 +19,7 @@ public class EdgeMapSmokeTests Dictionary> workflowEdges = new(); - FanInEdgeData edgeData = new(["executor1", "executor2"], "executor3"); + FanInEdgeData edgeData = new(["executor1", "executor2"], "executor3", new EdgeId(0)); Edge fanInEdge = new(edgeData); workflowEdges["executor1"] = [fanInEdge]; diff --git a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/EdgeRunnerTests.cs b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/EdgeRunnerTests.cs index 9e0281e744..57f77c371e 100644 --- a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/EdgeRunnerTests.cs +++ b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/EdgeRunnerTests.cs @@ -31,7 +31,7 @@ public class EdgeRunnerTests runContext.Executors["executor1"] = new ForwardMessageExecutor("executor1"); runContext.Executors["executor2"] = new ForwardMessageExecutor("executor2"); - DirectEdgeData edgeData = new("executor1", "executor2", condition); + DirectEdgeData edgeData = new("executor1", "executor2", new EdgeId(0), condition); DirectEdgeRunner runner = new(runContext, edgeData); MessageEnvelope envelope = new(MessageVariant1, targetId: targetId); @@ -90,7 +90,7 @@ public class EdgeRunnerTests ? (targetMatch.Value ? "executor2" : "executor1") : null; - FanOutEdgeData edgeData = new("executor1", ["executor2", "executor3"], assigner); + FanOutEdgeData edgeData = new("executor1", ["executor2", "executor3"], new EdgeId(0), assigner); FanOutEdgeRunner runner = new(runContext, edgeData); MessageEnvelope envelope = new("test", targetId: targetId); @@ -145,7 +145,7 @@ public class EdgeRunnerTests runContext.Executors["executor2"] = new ForwardMessageExecutor("executor2"); runContext.Executors["executor3"] = new ForwardMessageExecutor("executor3"); - FanInEdgeData edgeData = new(["executor1", "executor2"], "executor3"); + FanInEdgeData edgeData = new(["executor1", "executor2"], "executor3", new EdgeId(0)); FanInEdgeRunner runner = new(runContext, edgeData); // Step 1: Send message from executor1, should not forward yet. diff --git a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/InMemoryJsonStore.cs b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/InMemoryJsonStore.cs new file mode 100644 index 0000000000..162e791397 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/InMemoryJsonStore.cs @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Text.Json; +using System.Threading.Tasks; +using Microsoft.Agents.Workflows.Checkpointing; + +namespace Microsoft.Agents.Workflows.UnitTests; + +internal sealed class InMemoryJsonStore : JsonCheckpointStore +{ + private readonly Dictionary> _store = new(); + + private RunCheckpointCache EnsureRunStore(string runId) + { + if (!this._store.TryGetValue(runId, out RunCheckpointCache? runStore)) + { + runStore = this._store[runId] = new(); + } + + return runStore; + } + + public override ValueTask CreateCheckpointAsync(string runId, JsonElement value, CheckpointInfo? parent = null) + { + return new(this.EnsureRunStore(runId).Add(runId, value)); + } + + public override ValueTask RetrieveCheckpointAsync(string runId, CheckpointInfo key) + { + if (!this.EnsureRunStore(runId).TryGet(key, out JsonElement result)) + { + throw new KeyNotFoundException("Could not retrieve checkpoint with id {key.CheckpointId} for run {runId}"); + } + + return new(result); + } + + public override ValueTask> RetrieveIndexAsync(string runId, CheckpointInfo? withParent = null) + { + return new(this.EnsureRunStore(runId).Index); + } +} diff --git a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/InProcessStateTests.cs b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/InProcessStateTests.cs index 63bdf4455d..1fbbe76799 100644 --- a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/InProcessStateTests.cs +++ b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/InProcessStateTests.cs @@ -127,7 +127,7 @@ public class InProcessStateTests .AddEdge(writer, validator, MaxTurns(4)) .AddEdge(validator, writer, MaxTurns(4)).Build(); - Checkpointed checkpointed = await InProcessExecution.RunAsync(workflow, new(), new CheckpointManager()); + Checkpointed checkpointed = await InProcessExecution.RunAsync(workflow, new(), CheckpointManager.Default); checkpointed.Checkpoints.Should().HaveCount(6); checkpointed.Run.Status.Should().Be(RunStatus.Idle); diff --git a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/JsonSerializationTests.cs b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/JsonSerializationTests.cs new file mode 100644 index 0000000000..6a0a40e3c2 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/JsonSerializationTests.cs @@ -0,0 +1,653 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Linq.Expressions; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; +using System.Threading; +using FluentAssertions; +using Microsoft.Agents.Workflows.Checkpointing; +using Microsoft.Agents.Workflows.Execution; +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.Workflows.UnitTests; + +public class JsonSerializationTests +{ + private static JsonSerializerOptions TestCustomSerializedJsonOptions + { + get + { + JsonSerializerOptions options = new(TestJsonContext.Default.Options); + options.MakeReadOnly(); + + return options; + } + } + + private static int s_nextEdgeId = 0; + + private static EdgeId TakeEdgeId() => new(Interlocked.Increment(ref s_nextEdgeId)); + + private static T RunJsonRoundtrip(T value, JsonSerializerOptions? externalOptions = null, Expression>? predicate = null) + { + JsonMarshaller marshaller = new(externalOptions); + + JsonElement element = marshaller.Marshal(value); + T deserialized = marshaller.Marshal(element); + + if (deserialized != null) + { + if (predicate != null) + { + deserialized.Should().Match(predicate); + } + + return deserialized; + } + + Debug.Fail($"Could not roundtrip type '{typeof(T).Name}'. JSON = '{element}'."); + throw new NotSupportedException($"Could not roundtrip type '{typeof(T).Name}'."); + } + + [Fact] + public void Test_EdgeConnection_JsonRoundtrip() + { + EdgeConnection connection = new(new List { "Source1", "Source2" }, new List { "Sink1", "Sink2" }); + RunJsonRoundtrip(connection, predicate: connection.CreateValidator()); + } + + [Fact] + public void Test_TypeId_JsonRoundtrip() + { + TypeId type = new(typeof(Type)); + RunJsonRoundtrip(type, predicate: CreateValidator()); + + Expression> CreateValidator() + { + return deserialized => deserialized.AssemblyName == type.AssemblyName && + deserialized.TypeName == type.TypeName && + deserialized.IsMatch(typeof(Type)); + } + } + + [Fact] + public void Test_ExecutorInfo_JsonRoundtrip() + { + ExecutorInfo executorInfo = new(new(typeof(ForwardMessageExecutor)), "ForwardString"); + RunJsonRoundtrip(executorInfo, predicate: CreateValidator()); + + Expression> CreateValidator() + { + return deserialized => deserialized.ExecutorId == executorInfo.ExecutorId && + // Rely on the TypeId test to probe TypeId serialization - just validate that we got a functional TypeId + deserialized.ExecutorType.IsMatch(typeof(ForwardMessageExecutor)); + } + } + + private static InputPort TestPort => InputPort.Create("StringToInt"); + private static InputPortInfo TestPortInfo => TestPort.ToPortInfo(); + + [Fact] + public void Test_InputPortInfo_JsonRoundtrip() + { + RunJsonRoundtrip(TestPortInfo, predicate: TestPort.CreatePortInfoValidator()); + } + + private static DirectEdgeInfo TestDirectEdgeInfo_NoCondition => new(new("SourceExecutor", "TargetExecutor", TakeEdgeId(), condition: null)); + private static DirectEdgeInfo TestDirectEdgeInfo_Condition => new(new("SourceExecutor", "TargetExecutor", TakeEdgeId(), condition: msg => msg is not null)); + + [Fact] + public void Test_DirectEdgeInfo_JsonRoundtrip() + { + RunJsonRoundtrip(TestDirectEdgeInfo_NoCondition, predicate: TestDirectEdgeInfo_NoCondition.CreateValidator()); + RunJsonRoundtrip(TestDirectEdgeInfo_Condition, predicate: TestDirectEdgeInfo_Condition.CreateValidator()); + } + + private static FanOutEdgeInfo TestFanOutEdgeInfo_NoAssigner => new(new("SourceExecutor", ["TargetExecutor1", "TargetExecutor2"], TakeEdgeId(), assigner: null)); + private static FanOutEdgeInfo TestFanOutEdgeInfo_Assigner => new(new("SourceExecutor", ["TargetExecutor1", "TargetExecutor2"], TakeEdgeId(), assigner: (msg, count) => [])); + + [Fact] + public void Test_FanOutEdgeInfo_JsonRoundtrip() + { + RunJsonRoundtrip(TestFanOutEdgeInfo_NoAssigner, predicate: TestFanOutEdgeInfo_NoAssigner.CreateValidator()); + RunJsonRoundtrip(TestFanOutEdgeInfo_Assigner, predicate: TestFanOutEdgeInfo_Assigner.CreateValidator()); + } + + private static FanInEdgeData TestFanInEdgeData => new(["SourceExecutor1", "SourceExecutor2"], "TargetExecutor", TakeEdgeId()); + private static FanInEdgeInfo TestFanInEdgeInfo => new(TestFanInEdgeData); + + [Fact] + public void Test_FanInEdgeInfo_JsonRoundtrip() + { + RunJsonRoundtrip(TestFanInEdgeInfo, predicate: TestFanInEdgeInfo.CreateValidator()); + } + + private static EdgeInfo TestEdgeInfo_DirectNoCondition { get; } = TestDirectEdgeInfo_NoCondition; + private static EdgeInfo TestEdgeInfo_DirectCondition { get; } = TestDirectEdgeInfo_Condition; + private static EdgeInfo TestEdgeInfo_FanOutNoAssigner { get; } = TestFanOutEdgeInfo_NoAssigner; + private static EdgeInfo TestEdgeInfo_FanOutAssigner { get; } = TestFanOutEdgeInfo_Assigner; + private static EdgeInfo TestEdgeInfo_FanIn { get; } = TestFanInEdgeInfo; + + [Fact] + public void Test_EdgeInfoPolymorphism_JsonRoundtrip() + { + RunJsonRoundtrip(TestEdgeInfo_DirectNoCondition, predicate: TestEdgeInfo_DirectNoCondition.CreatePolyValidator()); + RunJsonRoundtrip(TestEdgeInfo_DirectCondition, predicate: TestEdgeInfo_DirectCondition.CreatePolyValidator()); + RunJsonRoundtrip(TestEdgeInfo_FanOutNoAssigner, predicate: TestEdgeInfo_FanOutNoAssigner.CreatePolyValidator()); + RunJsonRoundtrip(TestEdgeInfo_FanOutAssigner, predicate: TestEdgeInfo_FanOutAssigner.CreatePolyValidator()); + RunJsonRoundtrip(TestEdgeInfo_FanIn, predicate: TestEdgeInfo_FanIn.CreatePolyValidator()); + } + + private const string ForwardStringId = nameof(s_forwardString); + private const string ForwardIntId = nameof(s_forwardInt); + + private static readonly ExecutorIdentity s_forwardString = new() { Id = ForwardStringId }; + private static readonly ExecutorIdentity s_forwardInt = new() { Id = ForwardIntId }; + + private const string IntToStringId = nameof(IntToString); + private const string StringToIntId = nameof(StringToInt); + + private static InputPortInfo IntToString => InputPort.Create(IntToStringId).ToPortInfo(); + private static InputPortInfo StringToInt => InputPort.Create(StringToIntId).ToPortInfo(); + + private static Workflow CreateTestWorkflow() + { + ForwardMessageExecutor forwardString = new(ForwardStringId); + ForwardMessageExecutor forwardInt = new(ForwardIntId); + + InputPort stringToInt = InputPort.Create(StringToIntId); + InputPort intToString = InputPort.Create(IntToStringId); + + WorkflowBuilder builder = new(forwardString); + builder.AddEdge(forwardString, stringToInt) + .AddEdge(stringToInt, forwardInt) + .AddEdge(forwardInt, intToString); + + Workflow workflow = builder.BuildWithOutput( + intToString, + StreamingAggregators.Last(), (int _, int __) => true); + + return workflow; + } + + private static WorkflowInfo TestWorkflowInfo => CreateTestWorkflow().ToWorkflowInfo(); + + private static void ValidateWorkflowInfo(WorkflowInfo actual, WorkflowInfo prototype) + { + ValidateExecutorDictionary(prototype.Executors, prototype.Edges, actual.Executors, actual.Edges); + ValidateInputPorts(prototype.InputPorts, actual.InputPorts); + + actual.InputType.Should().Match(prototype.InputType.CreateValidator()); + actual.StartExecutorId.Should().Be(prototype.StartExecutorId); + + actual.OutputType.Should().NotBeNull().And.Match(prototype.OutputType!.CreateValidator()); + actual.OutputCollectorId.Should().NotBeNull().And.Be(prototype.OutputCollectorId); + + void ValidateExecutorDictionary(Dictionary expected, + Dictionary> expectedEdges, + Dictionary actual, + Dictionary> actualEdges) + { + actual.Should().HaveCount(expected.Count); + actualEdges.Should().HaveCount(expectedEdges.Count); + + foreach (string key in expected.Keys) + { + actual.Should().ContainKey(key); + + ExecutorInfo actualValue = actual[key]; + ExecutorInfo expectedValue = expected[key]; + + actualValue.Should().Match(expectedValue.CreateValidator()); + + if (expectedEdges.TryGetValue(key, out List? expectedEdgeList)) + { + List? actualEdgeList = actualEdges.Should().ContainKey(key).WhoseValue; + actualEdgeList.Should().NotBeNull(); + + ValidateExecutorEdges(expectedEdgeList, actualEdgeList); + } + } + } + + void ValidateExecutorEdges(List expected, List actual) + { + actual.Should().HaveCount(expected.Count); + foreach (EdgeInfo expectedEdge in expected) + { + actual.Should().ContainSingle(edge => edge.CreatePolyValidator().Compile()(edge)); + } + } + + void ValidateInputPorts(HashSet expected, HashSet actual) + => actual.Should().HaveCount(expected.Count).And.IntersectWith(expected); + } + + [Fact] + public void Test_WorkflowInfo_JsonRoundtrip() + { + WorkflowInfo prototype = TestWorkflowInfo; + + JsonMarshaller marshaller = new(); + + JsonElement jsonElement = marshaller.Marshal(prototype, typeof(WorkflowInfo)); + WorkflowInfo deserialized = marshaller.Marshal(jsonElement); + + ValidateWorkflowInfo(deserialized, prototype); + } + + private static ExecutorIdentity TestIdentity => new() { Id = "Executor1" }; + + [Fact] + public void Test_ExecutorIdentity_JsonRoundtrip() + { + RunJsonRoundtrip(TestIdentity, predicate: TestIdentity.CreateValidator()); + RunJsonRoundtrip(ExecutorIdentity.None, predicate: ExecutorIdentity.None.CreateValidator()); + } + + private static ScopeId TestScopeId_Private => new("Executor1", null); + private static ScopeId TestScopeId_Public => new("Executor1", "Scope1"); + + [Fact] + public void Test_ScopeId_JsonRoundtrip() + { + RunJsonRoundtrip(TestScopeId_Private, predicate: TestScopeId_Private.CreateValidator()); + RunJsonRoundtrip(TestScopeId_Public, predicate: TestScopeId_Public.CreateValidator()); + } + + private static ScopeKey TestScopeKey_Private => new(TestScopeId_Private, "Key1"); + private static ScopeKey TestScopeKey_Public => new(TestScopeId_Public, "Key1"); + + [Fact] + public void Test_ScopeKey_JsonRoundtrip() + { + RunJsonRoundtrip(TestScopeKey_Private, predicate: TestScopeKey_Private.CreateValidator()); + RunJsonRoundtrip(TestScopeKey_Public, predicate: TestScopeKey_Public.CreateValidator()); + } + + private static ExternalRequest TestExternalRequest => ExternalRequest.Create(TestPort, "Request1", "TestData"); + + [Fact] + public void SanityCheck_JsonTypeInfo() + { + JsonTypeInfo? info = WorkflowsJsonUtilities.JsonContext.Default.GetTypeInfo(typeof(string)); + info.Should().NotBeNull(); + } + + [Fact] + public void Test_PortableValue_JsonRoundtrip_BuiltInType() + { + PortableValue value = new("TestString"); + PortableValue result = RunJsonRoundtrip(value); + + result.Should().Be(value); + + // Also validate that we can extract the value as the correct type + string? extracted = result.As(); + + extracted.Should().Be("TestString"); + + // And that we can't extract it as an incorrect type + result.Is().Should().BeFalse(); + } + + [Fact] + public void Test_PortableValue_JsonRoundTrip_InternalType() + { + ChatMessage message = new(ChatRole.User, "Hello, world!"); + + PortableValue value = new(message); + PortableValue result = RunJsonRoundtrip(value); + + result.Should().Be(value); + + // Also validate that we can extract the value as the correct type + ChatMessage? chatMessage = result.As(); + + chatMessage.Should().NotBeNull(); + chatMessage.Role.Should().Be(ChatRole.User); + chatMessage.Text.Should().Be("Hello, world!"); + + // And that we can't extract it as an incorrect type + result.Is().Should().BeFalse(); + } + + [Fact] + public void Test_PortableValue_JsonRoundTrip_CustomType() + { + TestJsonSerializable test = new() { Id = 42, Name = "Test" }; + + PortableValue value = new(test); + PortableValue result = RunJsonRoundtrip(value, TestCustomSerializedJsonOptions); + + result.Should().Be(value); + + // Also validate that we can extract the value as the correct type + TestJsonSerializable? extracted = result.As(); + + extracted.Should().NotBeNull(); + extracted.Id.Should().Be(42); + extracted.Name.Should().Be("Test"); + + // And that we can't extract it as an incorrect type + result.Is().Should().BeFalse(); + } + + private static void ValidateExternalRequest(ExternalRequest actual, ExternalRequest expected) + { + bool isIdEqual = actual.RequestId == expected.RequestId; + bool isPortEqual = actual.PortInfo == expected.PortInfo; + bool isDataEqual = actual.Data == expected.Data; + + isIdEqual.Should().BeTrue(); + isPortEqual.Should().BeTrue(); + isDataEqual.Should().BeTrue(); + } + + [Fact] + public void Test_ExternalRequest_JsonRoundtrip() + { + ExternalRequest result = RunJsonRoundtrip(TestExternalRequest); + ValidateExternalRequest(result, TestExternalRequest); + } + + private static ExternalResponse TestExternalResponse => TestExternalRequest.CreateResponse(123); + + [Fact] + public void Test_ExternalResponse_JsonRoundtrip() + { + ExternalResponse result = RunJsonRoundtrip(TestExternalResponse); + + bool isIdEqual = result.RequestId == TestExternalResponse.RequestId; + bool isPortEqual = result.PortInfo == TestExternalResponse.PortInfo; + bool isDataEqual = result.Data == TestExternalResponse.Data; + + isIdEqual.Should().BeTrue(); + isPortEqual.Should().BeTrue(); + isDataEqual.Should().BeTrue(); + } + + [Fact] + public void Test_PortableMessageEnvelope_JsonRoundtrip_BuiltInType() + { + string message = "TestMessage"; + + MessageEnvelope envelope = new(message, new TypeId(typeof(object)), targetId: "Target1"); + PortableMessageEnvelope value = new(envelope); + PortableMessageEnvelope result = RunJsonRoundtrip(value); + + bool isTypeEqual = result.MessageType == value.MessageType; + bool isTargetEqual = result.TargetId == value.TargetId; + bool isMessageEqual = result.Message == value.Message; + + isTypeEqual.Should().BeTrue(); + isTargetEqual.Should().BeTrue(); + isMessageEqual.Should().BeTrue(); + + MessageEnvelope reconstructed = result.ToMessageEnvelope(); + + reconstructed.MessageType.Should().Be(envelope.MessageType); + reconstructed.TargetId.Should().Be(envelope.TargetId); + reconstructed.Message.Should().Be(envelope.Message); + } + + [Fact] + public void Test_PortableMessageEnvelope_JsonRoundtrip_InternalType() + { + ChatMessage message = new(ChatRole.User, "Hello, world!"); + + MessageEnvelope envelope = new(message, new TypeId(typeof(object)), targetId: "Target1"); + PortableMessageEnvelope value = new(envelope); + PortableMessageEnvelope result = RunJsonRoundtrip(value); + + bool isTypeEqual = result.MessageType == value.MessageType; + bool isTargetEqual = result.TargetId == value.TargetId; + bool isMessageEqual = result.Message == value.Message; + + isTypeEqual.Should().BeTrue(); + isTargetEqual.Should().BeTrue(); + isMessageEqual.Should().BeTrue(); + + MessageEnvelope reconstructed = result.ToMessageEnvelope(); + + reconstructed.MessageType.Should().Be(envelope.MessageType); + reconstructed.TargetId.Should().Be(envelope.TargetId); + + // Unfortunately, ChatMessage does not contain an "equality" comparer, so we need to explicitly pull it out + // Simulate what PortableValue does in .Equals() + Type expectedType = envelope.Message.GetType(); + object? maybeReconstructedMessage = ((PortableValue)reconstructed.Message)!.AsType(expectedType); + maybeReconstructedMessage.Should().NotBeNull() + .And.BeOfType() + .And.Match(message.CreateValidatorCheckingText()); + } + + [Fact] + public void Test_PortableMessageEnvelope_JsonRoundtrip_CustomType() + { + TestJsonSerializable message = new() { Id = 42, Name = "Test" }; + + MessageEnvelope envelope = new(message, new TypeId(typeof(object)), targetId: "Target1"); + PortableMessageEnvelope value = new(envelope); + PortableMessageEnvelope result = RunJsonRoundtrip(value, TestCustomSerializedJsonOptions); + + bool isTypeEqual = result.MessageType == value.MessageType; + bool isTargetEqual = result.TargetId == value.TargetId; + bool isMessageEqual = result.Message == value.Message; + + isTypeEqual.Should().BeTrue(); + isTargetEqual.Should().BeTrue(); + isMessageEqual.Should().BeTrue(); + + MessageEnvelope reconstructed = result.ToMessageEnvelope(); + + reconstructed.MessageType.Should().Be(envelope.MessageType); + reconstructed.TargetId.Should().Be(envelope.TargetId); + reconstructed.Message.Should().Be(envelope.Message); + } + + private static RunnerStateData TestRunnerStateData + { + get + { + return new( + [ForwardStringId, ForwardIntId], + CreateQueuedMessages(), + outstandingRequests: [TestExternalRequest] + ); + + Dictionary> CreateQueuedMessages() + { + Dictionary> result = new(); + + MessageEnvelope externalEnvelope = new(TestExternalResponse); + result.Add(ExecutorIdentity.None, [new(externalEnvelope)]); + + MessageEnvelope internalEnvelope = new("InternalMessage"); + result.Add("TestExecutor1", [new(internalEnvelope)]); + + return result; + } + } + } + + private static void ValidateRunnerStateData(RunnerStateData result, RunnerStateData prototype) + { + Assert.Collection(result.InstantiatedExecutors, + prototype.InstantiatedExecutors.Select( + prototype => + (Action)(actual => actual.Should().Be(prototype))).ToArray()); + + result.QueuedMessages.Should().HaveCount(prototype.QueuedMessages.Count); + foreach (ExecutorIdentity key in prototype.QueuedMessages.Keys) + { + result.QueuedMessages.Should().ContainKey(key); + + List actualList = result.QueuedMessages[key]; + List expectedList = prototype.QueuedMessages[key]; + + actualList.Should().HaveCount(expectedList.Count); + for (int i = 0; i < expectedList.Count; i++) + { + PortableMessageEnvelope actual = actualList[i]; + PortableMessageEnvelope expected = expectedList[i]; + actual.MessageType.Should().Be(expected.MessageType); + actual.TargetId.Should().Be(expected.TargetId); + actual.Message.Should().Be(expected.Message); + } + } + + result.OutstandingRequests.Should().HaveCount(prototype.OutstandingRequests.Count); + + Assert.Collection(result.OutstandingRequests, + prototype.OutstandingRequests.Select( + expected => + (Action)(actual => ValidateExternalRequest(actual, expected))).ToArray()); + } + + [Fact] + public void Test_RunnerStateData_JsonRoundtrip() + { + RunnerStateData prototype = TestRunnerStateData; + RunnerStateData result = RunJsonRoundtrip(prototype); + + ValidateRunnerStateData(result, prototype); + } + + private static FanInEdgeState TestFanInEdgeState => new(TestFanInEdgeData); + private static PortableValue CreateEdgeState(TMessage message) where TMessage : notnull + { + FanInEdgeState state = TestFanInEdgeState; + _ = state.ProcessMessage("SourceExecutor1", new MessageEnvelope(message, typeof(TMessage))); + + return new(state); + } + + private static TestJsonSerializable TestCustomSerializable => new() { Id = 42, Name = nameof(TestCustomSerializable) }; + + private static Dictionary TestEdgeState + { + get + { + return new() + { + [TakeEdgeId()] = CreateEdgeState("Hello, world!"), + [TakeEdgeId()] = CreateEdgeState(TestExternalResponse), + [TakeEdgeId()] = CreateEdgeState(TestCustomSerializable) + }; + } + } + + private static void ValidateEdgeStateData(Dictionary result, Dictionary prototype) + { + result.Should().HaveCount(prototype.Count); + foreach (EdgeId id in prototype.Keys) + { + result.Should().ContainKey(id) + .And.Subject[id].Should().Be(prototype[id]) + .And.Subject.As() + .As().Should().NotBeNull() + .And.Match(CreateValidator(prototype[id].As()!)); + } + Expression> CreateValidator(FanInEdgeState prototype) + { + return actual => actual.Unseen.SetEquals(prototype.Unseen) && + actual.SourceIds.SequenceEqual(prototype.SourceIds) && + actual.PendingMessages.Zip(prototype.PendingMessages, + (actualMessage, expectedMessage) => actualMessage.MessageType == expectedMessage.MessageType && + actualMessage.TargetId == expectedMessage.TargetId && + actualMessage.Message.Equals(expectedMessage.Message)).All(v => v); + } + } + + [Fact] + public void Test_EdgeStateData_JsonRoundtrip() + { + Dictionary value = TestEdgeState; + Dictionary result = RunJsonRoundtrip(value, TestCustomSerializedJsonOptions); + + ValidateEdgeStateData(result, value); + } + + private static ScopeKey TestScopeKey1 => new(StringToIntId, null, "Key1"); + private static ScopeKey TestScopeKey2 => new(StringToIntId, "Shared", "Key2"); + private static ScopeKey TestScopeKey3 => new(IntToStringId, "Shared", "Key3"); + + private static ChatMessage TestUserMessage => new(ChatRole.User, "Hello"); + + private static Dictionary TestStateData + { + get + { + return new() + { + [TestScopeKey1] = new("Lorem Ipsum"), + [TestScopeKey2] = new(TestUserMessage), + [TestScopeKey3] = new(TestCustomSerializable) + }; + } + } + + private static void ValidateStateData(Dictionary result, Dictionary prototype) + { + result.Should().HaveCount(prototype.Count); + + foreach (ScopeKey key in prototype.Keys) + { + PortableValue state = + result.Should().ContainKey(key) + .And.Subject[key].Should().Be(prototype[key]) + .And.Subject.As(); + switch (key.Key) + { + case "Key1": + state.As().Should().Be("Lorem Ipsum"); + break; + case "Key2": + ChatMessage? maybeMessage = state.As(); + maybeMessage.Should().NotBeNull() + .And.Match(TestUserMessage.CreateValidatorCheckingText()); + break; + case "Key3": + state.As().Should().Be(TestCustomSerializable); + break; + default: + throw new NotImplementedException($"Missing validation for key '{key.Key}'"); + } + } + } + + [Fact] + public void Test_ExecutorStateData_JsonRoundTrip() + { + Dictionary value = TestStateData; + Dictionary result = RunJsonRoundtrip(value, TestCustomSerializedJsonOptions); + + ValidateStateData(result, value); + } + + private static readonly string s_runId = Guid.NewGuid().ToString("N"); + private static readonly string s_parentCheckpointId = Guid.NewGuid().ToString("N"); + + private static CheckpointInfo TestParentCheckpointInfo => new(s_runId, s_parentCheckpointId); + + [Fact] + public void Test_Checkpoint_JsonRoundTrip() + { + Checkpoint prototype = new(12, TestWorkflowInfo, TestRunnerStateData, TestStateData, TestEdgeState, TestParentCheckpointInfo); + Checkpoint result = RunJsonRoundtrip(prototype, TestCustomSerializedJsonOptions); + + result.Should().Match((Checkpoint checkpoint) => checkpoint.StepNumber == prototype.StepNumber); + + result.Parent.Should().Be(prototype.Parent); + + ValidateWorkflowInfo(result.Workflow, prototype.Workflow); + ValidateRunnerStateData(result.RunnerData, prototype.RunnerData); + ValidateStateData(result.StateData, prototype.StateData); + ValidateEdgeStateData(result.EdgeStateData, prototype.EdgeStateData); + } +} diff --git a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/RepresentationTests.cs b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/RepresentationTests.cs index ef0978db75..81b7d0a3e5 100644 --- a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/RepresentationTests.cs +++ b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/RepresentationTests.cs @@ -106,51 +106,53 @@ public class RepresentationTests [Fact] public void Test_EdgeInfos() { + int edgeId = 0; + // Direct Edges - Edge directEdgeNoCondition = new(new DirectEdgeData(Source(1), Sink(2))); + Edge directEdgeNoCondition = new(new DirectEdgeData(Source(1), Sink(2), TakeEdgeId())); RunEdgeInfoMatchTest(directEdgeNoCondition); - Edge directEdgeNoCondition2 = new(new DirectEdgeData(Source(1), Sink(2))); + Edge directEdgeNoCondition2 = new(new DirectEdgeData(Source(1), Sink(2), TakeEdgeId())); RunEdgeInfoMatchTest(directEdgeNoCondition, directEdgeNoCondition2); - Edge directEdgeNoCondition3 = new(new DirectEdgeData(Source(3), Sink(4))); + Edge directEdgeNoCondition3 = new(new DirectEdgeData(Source(3), Sink(4), TakeEdgeId())); RunEdgeInfoMatchTest(directEdgeNoCondition, directEdgeNoCondition3, expect: false); - Edge directEdgeWithCondition = new(new DirectEdgeData(Source(3), Sink(4), Condition())); + Edge directEdgeWithCondition = new(new DirectEdgeData(Source(3), Sink(4), TakeEdgeId(), Condition())); RunEdgeInfoMatchTest(directEdgeWithCondition); RunEdgeInfoMatchTest(directEdgeNoCondition2, directEdgeWithCondition, expect: false); RunEdgeInfoMatchTest(directEdgeNoCondition3, directEdgeWithCondition, expect: false); // FanOut Edges - Edge fanOutEdgeNoAssigner = new(new FanOutEdgeData(Source(1), [Sink(2), Sink(3), Sink(4)])); + Edge fanOutEdgeNoAssigner = new(new FanOutEdgeData(Source(1), [Sink(2), Sink(3), Sink(4)], TakeEdgeId())); RunEdgeInfoMatchTest(fanOutEdgeNoAssigner); - Edge fanOutEdgeNoAssigner2 = new(new FanOutEdgeData(Source(1), [Sink(2), Sink(3), Sink(4)])); + Edge fanOutEdgeNoAssigner2 = new(new FanOutEdgeData(Source(1), [Sink(2), Sink(3), Sink(4)], TakeEdgeId())); RunEdgeInfoMatchTest(fanOutEdgeNoAssigner, fanOutEdgeNoAssigner2); - Edge fanOutEdgeNoAssigner3 = new(new FanOutEdgeData(Source(1), [Sink(3), Sink(4), Sink(2)])); + Edge fanOutEdgeNoAssigner3 = new(new FanOutEdgeData(Source(1), [Sink(3), Sink(4), Sink(2)], TakeEdgeId())); RunEdgeInfoMatchTest(fanOutEdgeNoAssigner, fanOutEdgeNoAssigner3, expect: false); // Order matters (though without Assigner maybe it shouldn't?) - Edge fanOutEdgeNoAssigner4 = new(new FanOutEdgeData(Source(1), [Sink(2), Sink(3), Sink(5)])); - Edge fanOutEdgeNoAssigner5 = new(new FanOutEdgeData(Source(2), [Sink(2), Sink(3), Sink(4)])); + Edge fanOutEdgeNoAssigner4 = new(new FanOutEdgeData(Source(1), [Sink(2), Sink(3), Sink(5)], TakeEdgeId())); + Edge fanOutEdgeNoAssigner5 = new(new FanOutEdgeData(Source(2), [Sink(2), Sink(3), Sink(4)], TakeEdgeId())); RunEdgeInfoMatchTest(fanOutEdgeNoAssigner, fanOutEdgeNoAssigner4, expect: false); // Identity matters RunEdgeInfoMatchTest(fanOutEdgeNoAssigner, fanOutEdgeNoAssigner5, expect: false); - Edge fanOutEdgeWithAssigner = new(new FanOutEdgeData(Source(1), [Sink(2), Sink(3), Sink(4)], EdgeAssigner())); + Edge fanOutEdgeWithAssigner = new(new FanOutEdgeData(Source(1), [Sink(2), Sink(3), Sink(4)], TakeEdgeId(), EdgeAssigner())); RunEdgeInfoMatchTest(fanOutEdgeWithAssigner); // FanIn Edges - Edge fanInEdge = new(new FanInEdgeData([Source(1), Source(2), Source(3)], Sink(1))); + Edge fanInEdge = new(new FanInEdgeData([Source(1), Source(2), Source(3)], Sink(1), TakeEdgeId())); RunEdgeInfoMatchTest(fanInEdge); - Edge fanInEdge2 = new(new FanInEdgeData([Source(1), Source(2), Source(3)], Sink(1))); + Edge fanInEdge2 = new(new FanInEdgeData([Source(1), Source(2), Source(3)], Sink(1), TakeEdgeId())); RunEdgeInfoMatchTest(fanInEdge, fanInEdge2); - Edge fanInEdge3 = new(new FanInEdgeData([Source(2), Source(3), Source(1)], Sink(1))); + Edge fanInEdge3 = new(new FanInEdgeData([Source(2), Source(3), Source(1)], Sink(1), TakeEdgeId())); RunEdgeInfoMatchTest(fanInEdge, fanInEdge3, expect: false); // Order matters (though for FanIn maybe it shouldn't?) - Edge fanInEdge4 = new(new FanInEdgeData([Source(1), Source(2), Source(4)], Sink(1))); - Edge fanInEdge5 = new(new FanInEdgeData([Source(1), Source(2), Source(3)], Sink(2))); + Edge fanInEdge4 = new(new FanInEdgeData([Source(1), Source(2), Source(4)], Sink(1), TakeEdgeId())); + Edge fanInEdge5 = new(new FanInEdgeData([Source(1), Source(2), Source(3)], Sink(2), TakeEdgeId())); RunEdgeInfoMatchTest(fanInEdge, fanInEdge4, expect: false); // Identity matters RunEdgeInfoMatchTest(fanInEdge, fanInEdge5, expect: false); @@ -161,6 +163,8 @@ public class RepresentationTests EdgeInfo info = edge.ToEdgeInfo(); info.IsMatch(comparatorEdge).Should().Be(expect); } + + EdgeId TakeEdgeId() => new(edgeId++); } [Fact] diff --git a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/Sample/02_Simple_Workflow_Condition.cs b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/Sample/02_Simple_Workflow_Condition.cs index 9a5cbf64c0..a51d34aa5f 100644 --- a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/Sample/02_Simple_Workflow_Condition.cs +++ b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/Sample/02_Simple_Workflow_Condition.cs @@ -21,8 +21,8 @@ internal static class Step2EntryPoint RemoveSpamExecutor removeSpam = new(); return new WorkflowBuilder(detectSpam) - .AddEdge(detectSpam, respondToMessage, isSpam => isSpam is false) // If not spam, respond - .AddEdge(detectSpam, removeSpam, isSpam => isSpam is true) // If spam, remove + .AddEdge(detectSpam, respondToMessage, (bool isSpam) => isSpam is false) // If not spam, respond + .AddEdge(detectSpam, removeSpam, (bool isSpam) => isSpam is true) // If spam, remove .Build(); } } diff --git a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/Sample/03_Simple_Workflow_Loop.cs b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/Sample/03_Simple_Workflow_Loop.cs index b56efaff50..fa3edddc5e 100644 --- a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/Sample/03_Simple_Workflow_Loop.cs +++ b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/Sample/03_Simple_Workflow_Loop.cs @@ -137,6 +137,6 @@ internal sealed class JudgeExecutor : ReflectingExecutor, IMessag protected internal override async ValueTask OnCheckpointRestoredAsync(IWorkflowContext context, CancellationToken cancellation = default) { - this.Tries = await context.ReadStateAsync("TryCount").ConfigureAwait(false); + this.Tries = await context.ReadStateAsync("TryCount").ConfigureAwait(false) ?? 0; } } diff --git a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/Sample/04_Simple_Workflow_ExternalRequest.cs b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/Sample/04_Simple_Workflow_ExternalRequest.cs index 5f94475cb3..4e793b3398 100644 --- a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/Sample/04_Simple_Workflow_ExternalRequest.cs +++ b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/Sample/04_Simple_Workflow_ExternalRequest.cs @@ -15,8 +15,8 @@ internal static class Step4EntryPoint return new WorkflowBuilder(guessNumber) .AddEdge(guessNumber, judge) - .AddEdge(judge, guessNumber, (message) => message is NumberSignal signal && signal != NumberSignal.Matched) - .BuildWithOutput(judge, ComputeStreamingOutput, (NumberSignal s, string? _) => s == NumberSignal.Matched); + .AddEdge(judge, guessNumber, (NumberSignal signal) => signal != NumberSignal.Matched) + .BuildWithOutput(judge, ComputeStreamingOutput, (NumberSignal s, string? _) => s == NumberSignal.Matched); } public static Workflow WorkflowInstance @@ -60,10 +60,10 @@ internal static class Step4EntryPoint Func userGuessCallback, string? runningState) { - object result = request.Port.Id switch + object result = request.PortInfo.PortId switch { "GuessNumber" => userGuessCallback(runningState ?? "Guess the number."), - _ => throw new NotSupportedException($"Request {request.Port.Id} is not supported") + _ => throw new NotSupportedException($"Request {request.PortInfo.PortId} is not supported") }; return request.CreateResponse(result); diff --git a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/Sample/05_Simple_Workflow_Checkpointing.cs b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/Sample/05_Simple_Workflow_Checkpointing.cs index 9093645810..78692a2a40 100644 --- a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/Sample/05_Simple_Workflow_Checkpointing.cs +++ b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/Sample/05_Simple_Workflow_Checkpointing.cs @@ -11,13 +11,13 @@ namespace Microsoft.Agents.Workflows.Sample; internal static class Step5EntryPoint { - private static CheckpointManager CheckpointManager { get; } = new(); - - public static async ValueTask RunAsync(TextWriter writer, Func userGuessCallback, bool rehydrateToRestore = false) + public static async ValueTask RunAsync(TextWriter writer, Func userGuessCallback, bool rehydrateToRestore = false, CheckpointManager? checkpointManager = null) { + checkpointManager ??= CheckpointManager.Default; + Workflow workflow = Step4EntryPoint.CreateWorkflowInstance(out JudgeExecutor judge); Checkpointed> checkpointed = - await InProcessExecution.StreamAsync(workflow, NumberSignal.Init, CheckpointManager) + await InProcessExecution.StreamAsync(workflow, NumberSignal.Init, checkpointManager) .ConfigureAwait(false); List checkpoints = new(); @@ -34,7 +34,7 @@ internal static class Step5EntryPoint if (rehydrateToRestore) { - checkpointed = await InProcessExecution.ResumeStreamAsync(workflow, targetCheckpoint, CheckpointManager, CancellationToken.None) + checkpointed = await InProcessExecution.ResumeStreamAsync(workflow, targetCheckpoint, checkpointManager, CancellationToken.None) .ConfigureAwait(false); handle = checkpointed.Run; } @@ -105,10 +105,10 @@ internal static class Step5EntryPoint Func userGuessCallback, string? runningState) { - object result = request.Port.Id switch + object result = request.PortInfo.PortId switch { "GuessNumber" => userGuessCallback(runningState ?? "Guess the number."), - _ => throw new NotSupportedException($"Request {request.Port.Id} is not supported") + _ => throw new NotSupportedException($"Request {request.PortInfo.PortId} is not supported") }; return request.CreateResponse(result); diff --git a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/SampleJsonContext.cs b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/SampleJsonContext.cs new file mode 100644 index 0000000000..d9148e1407 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/SampleJsonContext.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Serialization; +using Microsoft.Agents.Workflows.Sample; + +namespace Microsoft.Agents.Workflows.UnitTests; + +// Checkpointing Types +[JsonSerializable(typeof(NumberSignal))] +[ExcludeFromCodeCoverage] +internal sealed partial class SampleJsonContext : JsonSerializerContext; diff --git a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/SampleSmokeTest.cs b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/SampleSmokeTest.cs index ed1fec105b..7facfabc2b 100644 --- a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/SampleSmokeTest.cs +++ b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/SampleSmokeTest.cs @@ -3,6 +3,7 @@ using System; using System.IO; using System.Linq; +using System.Text.Json; using System.Threading.Tasks; using Microsoft.Agents.Workflows.Sample; @@ -122,6 +123,29 @@ public class SampleSmokeTest Assert.Equal("You guessed correctly! You Win!", guessResult); } + [Fact] + public async Task Test_RunSample_Step5bAsync() + { + using StringWriter writer = new(); + + VerifyingPlaybackResponder responder = new( + // Iteration 1 + ("Guess the number.", 50), + ("Your guess was too high. Try again.", 23), + + // Iteration 2 + ("Your guess was too high. Try again.", 23), + ("Your guess was too low. Try again.", 42) + ); + + JsonSerializerOptions options = new(SampleJsonContext.Default.Options); + options.MakeReadOnly(); + + CheckpointManager memoryJsonManager = CheckpointManager.CreateJson(new InMemoryJsonStore(), options); + string guessResult = await Step5EntryPoint.RunAsync(writer, userGuessCallback: responder.InvokeNext, rehydrateToRestore: true, checkpointManager: memoryJsonManager); + Assert.Equal("You guessed correctly! You Win!", guessResult); + } + [Fact] public async Task Test_RunSample_Step6Async() { diff --git a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/StateManagerTests.cs b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/StateManagerTests.cs index e9554da976..5de5dd39b7 100644 --- a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/StateManagerTests.cs +++ b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/StateManagerTests.cs @@ -406,7 +406,7 @@ public class StateManagerTests // Act: Update the key from one executor and delete it from another await manager.WriteStateAsync(scopeSelfView, Key1, "newValue"); - await manager.WriteStateAsync(scopeOtherView, Key1, null); + await manager.ClearStateAsync(scopeOtherView, Key1); Func act = async () => await manager.PublishUpdatesAsync(tracer: null); if (isSharedScope) diff --git a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/SubstitutionVisitor.cs b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/SubstitutionVisitor.cs new file mode 100644 index 0000000000..913848f1d0 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/SubstitutionVisitor.cs @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Linq.Expressions; + +namespace Microsoft.Agents.Workflows.UnitTests; + +internal sealed class SubstitutionVisitor(ParameterExpression parameter, Expression substitution) : ExpressionVisitor +{ + private ParameterExpression Parameter => parameter; + private Expression Substitution => substitution; + + protected override Expression VisitParameter(ParameterExpression node) + { + if (node.Name == this.Parameter.Name) + { + return this.Substitution; + } + + return base.VisitParameter(node); + } +} diff --git a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/TestJsonContext.cs b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/TestJsonContext.cs new file mode 100644 index 0000000000..21ba03a62e --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/TestJsonContext.cs @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Serialization; + +namespace Microsoft.Agents.Workflows.UnitTests; + +// Checkpointing Types +[JsonSerializable(typeof(TestJsonSerializable))] +[ExcludeFromCodeCoverage] +internal sealed partial class TestJsonContext : JsonSerializerContext; diff --git a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/TestJsonSerializable.cs b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/TestJsonSerializable.cs new file mode 100644 index 0000000000..98d1aa9618 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/TestJsonSerializable.cs @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Microsoft.Agents.Workflows.UnitTests; + +[JsonSourceGenerationOptions(JsonSerializerDefaults.Web, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + NumberHandling = JsonNumberHandling.AllowReadingFromString)] + +internal sealed class TestJsonSerializable +{ + public int Id { get; set; } + public string Name { get; set; } = string.Empty; + + public override bool Equals(object? obj) + { + if (obj == null) + { + return false; + } + + if (obj is not TestJsonSerializable other) + { + return false; + } + + return this.Id == other.Id && this.Name == other.Name; + } + + public override int GetHashCode() => HashCode.Combine(this.Id, this.Name); +} diff --git a/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/ValidationExtensions.cs b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/ValidationExtensions.cs new file mode 100644 index 0000000000..ac77efcf8f --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.Workflows.UnitTests/ValidationExtensions.cs @@ -0,0 +1,167 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics; +using System.Linq; +using System.Linq.Expressions; +using Microsoft.Agents.Workflows.Checkpointing; +using Microsoft.Agents.Workflows.Execution; +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.Workflows.UnitTests; + +internal static partial class ValidationExtensions +{ + public static Expression> CreateValidator(this EdgeConnection prototype) + { + return actual => actual.SourceIds.Count == prototype.SourceIds.Count && + actual.SinkIds.Count == prototype.SinkIds.Count && + prototype.SourceIds.SequenceEqual(actual.SourceIds) && + prototype.SinkIds.SequenceEqual(actual.SinkIds); + } + + public static Expression> CreateValidator(this TypeId prototype) + { + return actual => actual.AssemblyName == prototype.AssemblyName && + actual.TypeName == prototype.TypeName; + } + + public static Expression> CreateValidator(this ExecutorInfo prototype) + { + return actual => actual.ExecutorId == prototype.ExecutorId && + // Rely on the TypeId test to probe TypeId serialization - just validate that we got a functional TypeId + actual.ExecutorType.Equals(prototype.ExecutorType); + } + + public static Expression> CreatePortInfoValidator(this InputPort prototype) + { + return actual => actual.PortId == prototype.Id && + // Rely on the TypeId test to probe TypeId serialization - just validate that we got a functional TypeId + actual.RequestType.IsMatch(prototype.Request) && + actual.ResponseType.IsMatch(prototype.Response); + } + + public static Expression> CreateValidator(this DirectEdgeInfo prototype) + { + return actual => actual.Connection == prototype.Connection && + actual.HasCondition == prototype.HasCondition; + } + + public static Expression> CreateValidator(this FanOutEdgeInfo prototype) + { + return actual => actual.Connection == prototype.Connection && + actual.HasAssigner == prototype.HasAssigner; + } + + public static Expression> CreateValidator(this FanInEdgeInfo prototype) + { + return actual => actual.Connection == prototype.Connection; + } + + public static Expression> CreatePolyValidator(this EdgeInfo prototype) + { + switch (prototype.Kind) + { + case EdgeKind.Direct: + { + var innerValidatorExpr = CreateValidator((DirectEdgeInfo)prototype); + + // Check that incoming is of the correct type, and if so, chain to the body + Debug.Assert(innerValidatorExpr.Parameters.Count == 1, "Validator is of unexpected arity"); + + return CreateValidatorExpression(innerValidatorExpr); + } + case EdgeKind.FanOut: + { + var innerValidatorExpr = CreateValidator((FanOutEdgeInfo)prototype); + + // Check that incoming is of the correct type, and if so, chain to the body + Debug.Assert(innerValidatorExpr.Parameters.Count == 1, "Validator is of unexpected arity"); + + return CreateValidatorExpression(innerValidatorExpr); + } + case EdgeKind.FanIn: + { + var innerValidatorExpr = CreateValidator((FanInEdgeInfo)prototype); + + // Check that incoming is of the correct type, and if so, chain to the body + Debug.Assert(innerValidatorExpr.Parameters.Count == 1, "Validator is of unexpected arity"); + + return CreateValidatorExpression(innerValidatorExpr); + } + default: + throw new NotSupportedException($"Unsupported edge type: {prototype.Kind}"); + } + + Expression> CreateValidatorExpression(Expression> innerValidator) + where TInner : EdgeInfo + { + var innerParam = innerValidator.Parameters[0]; + var innerBody = innerValidator.Body; + + var outerParam = Expression.Parameter(typeof(EdgeInfo), "actual"); + var convertExpr = Expression.Convert(outerParam, typeof(TInner)); + + ExpressionVisitor visitor = new SubstitutionVisitor(innerParam, convertExpr); + Expression innerValidatorExpr = visitor.Visit(innerBody); + + BinaryExpression bodyExpression = Expression.AndAlso( + Expression.AndAlso( + Expression.Equal( + Expression.Property(outerParam, nameof(EdgeInfo.Kind)), + Expression.Constant(prototype.Kind) + ), + Expression.TypeIs(outerParam, typeof(TInner)) + ), + innerValidatorExpr + ); + + Expression> validatorExpr = Expression.Lambda>( + bodyExpression, + outerParam + ); + + return validatorExpr; + } + } + + public static Expression> CreateValidator(this ScopeId prototype) + { + return actual => actual.ExecutorId == prototype.ExecutorId && + actual.ScopeName == prototype.ScopeName; + } + + public static Expression> CreateValidator(this ScopeKey prototype) + { + return actual => actual.Key == prototype.Key && + actual.ScopeId.ScopeName == prototype.ScopeId.ScopeName && + actual.ScopeId.ExecutorId == prototype.ScopeId.ExecutorId; + } + + public static Expression> CreateValidator(this ExecutorIdentity prototype) + { + return actual => actual.Id == prototype.Id; + } + + public static Expression> CreateValidator(this ExternalRequest prototype) + { + return actual => actual.RequestId == prototype.RequestId && + actual.PortInfo == prototype.PortInfo && + actual.Data == prototype.Data; + } + + public static Expression> CreateValidator(this ExternalResponse prototype) + { + return actual => actual.RequestId == prototype.RequestId && + actual.Data == prototype.Data; + } + + public static Expression> CreateValidatorCheckingText(this ChatMessage prototype) + { + return actual => actual.Role == prototype.Role && + actual.AuthorName == prototype.AuthorName && + actual.CreatedAt == prototype.CreatedAt && + actual.MessageId == prototype.MessageId && + actual.Text == prototype.Text; + } +}