Declarative workflow bugfix

This commit is contained in:
Peter Ibekwe
2026-06-15 16:43:10 -07:00
Unverified
parent 8b0405de1b
commit 61378eee01
5 changed files with 1509 additions and 166 deletions
@@ -62,7 +62,7 @@ internal abstract class DeclarativeActionExecutor : Executor<ActionExecutorResul
protected virtual bool EmitResultEvent => true;
/// <inheritdoc/>
public ValueTask ResetAsync()
public virtual ValueTask ResetAsync()
{
return default;
}
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
@@ -28,12 +29,23 @@ internal sealed class InvokeFunctionToolExecutor(
WorkflowFormulaState state) :
DeclarativeActionExecutor<InvokeFunctionTool>(model, state)
{
private const string ApprovalSnapshotStateKey = nameof(_approvalSnapshot);
private const string ApprovalSnapshotStateKey = nameof(_approvalSnapshots);
private const string PendingCallIdsStateKey = nameof(_pendingNonApprovalCallIds);
private const string LegacyApprovalSnapshotStateKey = "_approvalSnapshot";
/// <summary>
/// Snapshot of evaluated parameters at approval-request time.
/// Snapshots of evaluated parameters captured at approval-request time, keyed by
/// per-invocation request id. Each pending approval lives here until the matching
/// response is captured.
/// </summary>
private ApprovalSnapshot? _approvalSnapshot;
private readonly ConcurrentDictionary<string, ApprovalSnapshot> _approvalSnapshots = new(StringComparer.Ordinal);
/// <summary>
/// Per-invocation call ids for in-flight non-approval requests; used to match the
/// returning <see cref="FunctionResultContent"/> on the response path. Used as a set;
/// the byte value is ignored.
/// </summary>
private readonly ConcurrentDictionary<string, byte> _pendingNonApprovalCallIds = new(StringComparer.Ordinal);
/// <summary>
/// Step identifiers for the function tool invocation workflow.
@@ -65,9 +77,12 @@ internal sealed class InvokeFunctionToolExecutor(
bool requireApproval = this.GetRequireApproval();
Dictionary<string, object?>? arguments = this.GetArguments();
// Per-invocation request id stamped on the outbound content.
string requestId = Guid.NewGuid().ToString("N");
// Create the function call content to send to the caller
FunctionCallContent functionCall = new(
callId: this.Id,
callId: requestId,
name: functionName,
arguments: arguments);
@@ -77,11 +92,15 @@ internal sealed class InvokeFunctionToolExecutor(
// If approval is required, add user input request content
if (requireApproval)
{
// Snapshot the evaluated parameters.
// If state mutates during the approval window, the approved values are used on resume.
this._approvalSnapshot = new ApprovalSnapshot(functionName, arguments);
// Capture the evaluated parameters keyed by request id; the matching response
// resumes from this snapshot.
this._approvalSnapshots[requestId] = new ApprovalSnapshot(functionName, arguments);
requestMessage.Contents.Add(new ToolApprovalRequestContent(this.Id, functionCall));
requestMessage.Contents.Add(new ToolApprovalRequestContent(requestId, functionCall));
}
else
{
this._pendingNonApprovalCallIds.TryAdd(requestId, 0);
}
AgentResponse agentResponse = new([requestMessage]);
@@ -108,13 +127,24 @@ internal sealed class InvokeFunctionToolExecutor(
bool autoSend = this.GetAutoSendValue();
string? conversationId = this.GetConversationId();
// Extract function results from the response
IEnumerable<FunctionResultContent> functionResults = response.Messages
// Match the inbound result by its per-invocation call id.
FunctionResultContent? matchingResult = response.Messages
.SelectMany(m => m.Contents)
.OfType<FunctionResultContent>();
.OfType<FunctionResultContent>()
.FirstOrDefault(r => this.IsKnownPendingId(r.CallId));
FunctionResultContent? matchingResult = functionResults
.FirstOrDefault(r => r.CallId == this.Id);
// Legacy non-approval backstop: when no pendings are tracked, accept a result
// whose CallId equals this.Id. The runtime has already routed the response to
// this executor's port and the framework does not invoke a function here.
if (matchingResult is null
&& this._pendingNonApprovalCallIds.IsEmpty
&& this._approvalSnapshots.IsEmpty)
{
matchingResult = response.Messages
.SelectMany(m => m.Contents)
.OfType<FunctionResultContent>()
.FirstOrDefault(r => string.Equals(r.CallId, this.Id, StringComparison.Ordinal));
}
// When the caller approved an approval-required function call but didn't execute it
// locally (the hosted Foundry scenario, where mcp_approval_response is converted to a
@@ -123,14 +153,43 @@ internal sealed class InvokeFunctionToolExecutor(
// SendActivity/PropertyPath consumers like {Local.Result}).
if (matchingResult is null)
{
ToolApprovalResponseContent? approval = response.Messages
List<ToolApprovalResponseContent> approvals = response.Messages
.SelectMany(m => m.Contents)
.OfType<ToolApprovalResponseContent>()
.FirstOrDefault(r => r.RequestId == this.Id);
.ToList();
if (approval is { Approved: true })
// Prefer an approval matching a pending snapshot; otherwise take the first
// present approval.
ToolApprovalResponseContent? approval =
approvals.FirstOrDefault(r => this._approvalSnapshots.ContainsKey(r.RequestId))
?? approvals.FirstOrDefault();
if (approval is not null)
{
matchingResult = await this.InvokeRegisteredFunctionAsync(cancellationToken).ConfigureAwait(false);
if (!this._approvalSnapshots.ContainsKey(approval.RequestId))
{
this.Logger.LogWarning(
"Approval response '{RequestId}' did not match any pending invocation on '{ActionId}'.",
approval.RequestId, this.Id);
await this.AssignErrorAsync(context, "Function invocation was not approved by user.").ConfigureAwait(false);
}
else if (!approval.Approved)
{
this._approvalSnapshots.TryRemove(approval.RequestId, out _);
await this.AssignErrorAsync(context, "Function invocation was not approved by user.").ConfigureAwait(false);
}
else if (this._approvalSnapshots.TryRemove(approval.RequestId, out ApprovalSnapshot? snapshot))
{
matchingResult = await this.InvokeRegisteredFunctionAsync(approval.RequestId, snapshot, cancellationToken).ConfigureAwait(false);
}
else
{
// Snapshot was consumed by a concurrent delivery; surface the not-approved error.
this.Logger.LogWarning(
"Approval response '{RequestId}' had no remaining pending snapshot on '{ActionId}'.",
approval.RequestId, this.Id);
await this.AssignErrorAsync(context, "Function invocation was not approved by user.").ConfigureAwait(false);
}
}
}
@@ -145,6 +204,10 @@ internal sealed class InvokeFunctionToolExecutor(
AgentResponse resultResponse = new([new ChatMessage(ChatRole.Tool, [matchingResult])]);
await context.AddEventAsync(new AgentResponseEvent(this.Id, resultResponse), cancellationToken).ConfigureAwait(false);
}
// Drop the per-invocation entry now that the response has been processed.
this._pendingNonApprovalCallIds.TryRemove(matchingResult.CallId, out _);
this._approvalSnapshots.TryRemove(matchingResult.CallId, out _);
}
// Store messages if output path is configured
@@ -167,31 +230,76 @@ internal sealed class InvokeFunctionToolExecutor(
// Completes the action after processing the function result.
await context.RaiseCompletionEventAsync(this.Model, cancellationToken).ConfigureAwait(false);
}
// Clear the approval snapshot after the action completes so a subsequent
// execution of the same executor instance doesn't reuse stale data.
this._approvalSnapshot = null;
await context.QueueStateUpdateAsync<ApprovalSnapshot?>(ApprovalSnapshotStateKey, null, null, cancellationToken).ConfigureAwait(false);
private bool IsKnownPendingId(string callId) =>
this._pendingNonApprovalCallIds.ContainsKey(callId) || this._approvalSnapshots.ContainsKey(callId);
/// <inheritdoc/>
public override ValueTask ResetAsync()
{
this._approvalSnapshots.Clear();
this._pendingNonApprovalCallIds.Clear();
return default;
}
/// <inheritdoc/>
/// <remarks>
/// Persists the approval snapshot to workflow state so it survives checkpoint/restore cycles.
/// Persists pending approval snapshots and non-approval call ids so they survive
/// checkpoint/restore cycles.
/// </remarks>
protected override async ValueTask OnCheckpointingAsync(IWorkflowContext context, CancellationToken cancellationToken = default)
{
await context.QueueStateUpdateAsync(ApprovalSnapshotStateKey, this._approvalSnapshot, null, cancellationToken).ConfigureAwait(false);
Dictionary<string, ApprovalSnapshot> snapshotCopy = this._approvalSnapshots.ToDictionary(kvp => kvp.Key, kvp => kvp.Value, StringComparer.Ordinal);
await context.QueueStateUpdateAsync(ApprovalSnapshotStateKey, snapshotCopy, null, cancellationToken).ConfigureAwait(false);
List<string> pendingCopy = [.. this._pendingNonApprovalCallIds.Keys];
await context.QueueStateUpdateAsync(PendingCallIdsStateKey, pendingCopy, null, cancellationToken).ConfigureAwait(false);
await base.OnCheckpointingAsync(context, cancellationToken).ConfigureAwait(false);
}
/// <inheritdoc/>
/// <remarks>
/// Restores the approval snapshot from workflow state after a checkpoint restore.
/// Restores pending approval snapshots and non-approval call ids from workflow state
/// after a checkpoint restore.
/// </remarks>
protected override async ValueTask OnCheckpointRestoredAsync(IWorkflowContext context, CancellationToken cancellationToken = default)
{
await base.OnCheckpointRestoredAsync(context, cancellationToken).ConfigureAwait(false);
this._approvalSnapshot = await context.ReadStateAsync<ApprovalSnapshot>(ApprovalSnapshotStateKey, null, cancellationToken).ConfigureAwait(false);
this._approvalSnapshots.Clear();
Dictionary<string, ApprovalSnapshot>? snapshots = await context.ReadStateAsync<Dictionary<string, ApprovalSnapshot>>(
ApprovalSnapshotStateKey, null, cancellationToken).ConfigureAwait(false);
if (snapshots is not null)
{
foreach (KeyValuePair<string, ApprovalSnapshot> entry in snapshots)
{
this._approvalSnapshots[entry.Key] = entry.Value;
}
}
this._pendingNonApprovalCallIds.Clear();
List<string>? pending = await context.ReadStateAsync<List<string>>(
PendingCallIdsStateKey, null, cancellationToken).ConfigureAwait(false);
if (pending is not null)
{
foreach (string id in pending)
{
this._pendingNonApprovalCallIds.TryAdd(id, 0);
}
}
// Migrate a single ApprovalSnapshot at the legacy key under this.Id so the
// legacy approval response matches the per-invocation map; clear the legacy key.
ApprovalSnapshot? legacy = await context.ReadStateAsync<ApprovalSnapshot>(
LegacyApprovalSnapshotStateKey, null, cancellationToken).ConfigureAwait(false);
if (legacy is not null)
{
this._approvalSnapshots.TryAdd(this.Id, legacy);
await context.QueueStateUpdateAsync<ApprovalSnapshot?>(
LegacyApprovalSnapshotStateKey, null, null, cancellationToken).ConfigureAwait(false);
}
}
/// <summary>
@@ -280,6 +388,14 @@ internal sealed class InvokeFunctionToolExecutor(
await this.AssignAsync(this.Model.Output.Result?.Path, resultValue.ToFormula(), context).ConfigureAwait(false);
}
private async ValueTask AssignErrorAsync(IWorkflowContext context, string errorMessage)
{
if (this.Model.Output?.Result is not null)
{
await this.AssignAsync(this.Model.Output.Result?.Path, $"Error: {errorMessage}".ToFormula(), context).ConfigureAwait(false);
}
}
private string GetFunctionName() =>
this.Evaluator.GetValue(
Throw.IfNull(
@@ -297,32 +413,19 @@ internal sealed class InvokeFunctionToolExecutor(
return conversationIdValue.Length == 0 ? null : conversationIdValue;
}
private async ValueTask<FunctionResultContent?> InvokeRegisteredFunctionAsync(CancellationToken cancellationToken)
private async ValueTask<FunctionResultContent?> InvokeRegisteredFunctionAsync(string callId, ApprovalSnapshot snapshot, CancellationToken cancellationToken)
{
string functionName;
Dictionary<string, object?>? arguments;
if (this._approvalSnapshot is { } snapshot)
{
// Use the snapshot captured at approval-request time so we invoke exactly what
// the user approved, even if Power Fx state has mutated during the approval window.
functionName = snapshot.FunctionName;
arguments = snapshot.Arguments;
}
else
{
// Fallback for checkpoints created before approval snapshots were introduced.
this.Logger.LogWarning("Approval snapshot missing for '{ActionId}'; falling back to expression re-evaluation.", this.Id);
functionName = this.GetFunctionName();
arguments = this.GetArguments();
}
// Use the snapshot captured at approval-request time so we invoke exactly what
// the user approved, even if Power Fx state has mutated during the approval window.
string functionName = snapshot.FunctionName;
Dictionary<string, object?>? arguments = snapshot.Arguments;
AIFunction? function = agentProvider.Functions?.FirstOrDefault(
f => string.Equals(f.Name, functionName, StringComparison.Ordinal));
if (function is null)
{
return new FunctionResultContent(this.Id, result: null)
return new FunctionResultContent(callId, result: null)
{
Exception = new InvalidOperationException(
$"Function '{functionName}' is not registered with the agent provider."),
@@ -338,7 +441,7 @@ internal sealed class InvokeFunctionToolExecutor(
}
catch (Exception ex) when (ex is not OperationCanceledException)
{
return new FunctionResultContent(this.Id, result: null) { Exception = ex };
return new FunctionResultContent(callId, result: null) { Exception = ex };
}
// Match FunctionInvokingChatClient's serialization: pass strings through as-is and
@@ -352,7 +455,7 @@ internal sealed class InvokeFunctionToolExecutor(
_ => JsonSerializer.Serialize(result, AIJsonUtilities.DefaultOptions.GetTypeInfo(result.GetType())),
};
return new FunctionResultContent(this.Id, serialized);
return new FunctionResultContent(callId, serialized);
}
private bool GetRequireApproval()
@@ -396,9 +499,8 @@ internal sealed class InvokeFunctionToolExecutor(
}
/// <summary>
/// Stores the evaluated parameters at approval-request time so that
/// <see cref="CaptureResponseAsync"/> uses the values the user reviewed,
/// even if <see cref="WorkflowFormulaState"/> mutates during the approval window.
/// Captured invocation parameters used by <see cref="CaptureResponseAsync"/> on
/// resume so the approved values are invoked regardless of subsequent state changes.
/// </summary>
internal sealed record ApprovalSnapshot(
string FunctionName,
@@ -1,5 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
@@ -27,13 +29,15 @@ internal sealed class InvokeMcpToolExecutor(
WorkflowFormulaState state) :
DeclarativeActionExecutor<InvokeMcpTool>(model, state)
{
private const string ApprovalSnapshotStateKey = nameof(_approvalSnapshot);
private const string ApprovalSnapshotStateKey = nameof(_approvalSnapshots);
private const string LegacyApprovalSnapshotStateKey = "_approvalSnapshot";
/// <summary>
/// Snapshot of evaluated parameters at approval-request time.
/// Used to prevent TOCTOU attacks where state mutates during the approval window.
/// Snapshots of evaluated parameters captured at approval-request time, keyed by
/// per-invocation request id. Each pending approval lives here until the matching
/// response is captured.
/// </summary>
private ApprovalSnapshot? _approvalSnapshot;
private readonly ConcurrentDictionary<string, ApprovalSnapshot> _approvalSnapshots = new(StringComparer.Ordinal);
/// <summary>
/// Step identifiers for the MCP tool invocation workflow.
@@ -83,19 +87,22 @@ internal sealed class InvokeMcpToolExecutor(
if (requireApproval)
{
// Snapshot the evaluated parameters to prevent TOCTOU attacks.
// If state mutates during the approval window, the approved values are used on resume.
this._approvalSnapshot = new ApprovalSnapshot(serverUrl, serverLabel, toolName, arguments, connectionName);
// Per-invocation request id stamped on the outbound content.
string requestId = Guid.NewGuid().ToString("N");
// Capture the evaluated parameters keyed by request id; the matching response
// resumes from this snapshot.
this._approvalSnapshots[requestId] = new ApprovalSnapshot(serverUrl, serverLabel, toolName, arguments, connectionName);
// Create tool call content for approval request.
// Transport headers (e.g. Authorization) are intentionally excluded from the
// approval event: they must not cross into the externally-surfaced approval request.
McpServerToolCallContent toolCall = new(this.Id, toolName, serverLabel ?? serverUrl)
McpServerToolCallContent toolCall = new(requestId, toolName, serverLabel ?? serverUrl)
{
Arguments = arguments
};
ToolApprovalRequestContent approvalRequest = new(this.Id, toolCall);
ToolApprovalRequestContent approvalRequest = new(requestId, toolCall);
ChatMessage requestMessage = new(ChatRole.Assistant, [approvalRequest]);
AgentResponse agentResponse = new([requestMessage]);
@@ -140,31 +147,37 @@ internal sealed class InvokeMcpToolExecutor(
ToolApprovalResponseContent? approvalResponse = response.Messages
.SelectMany(m => m.Contents)
.OfType<ToolApprovalResponseContent>()
.FirstOrDefault(r => r.RequestId == this.Id);
.FirstOrDefault(r => this._approvalSnapshots.ContainsKey(r.RequestId));
if (approvalResponse?.Approved != true)
{
// Tool call was rejected
// Rejected, no matching pending, or unknown request id: surface a not-approved
// error and drop any matched snapshot.
if (approvalResponse is not null)
{
this._approvalSnapshots.TryRemove(approvalResponse.RequestId, out _);
}
await this.AssignErrorAsync(context, "MCP tool invocation was not approved by user.").ConfigureAwait(false);
return;
}
// Source invocation fields from the snapshot captured at approval-request time.
// Headers are re-evaluated (they may contain auth secrets not persisted to state).
if (!this._approvalSnapshots.TryRemove(approvalResponse.RequestId, out ApprovalSnapshot? snapshot))
{
await this.AssignErrorAsync(context, "MCP tool invocation was not approved by user.").ConfigureAwait(false);
return;
}
// Approved - use the snapshot from approval-request time to prevent TOCTOU attacks.
// Headers are re-evaluated (they may contain auth secrets that should not be persisted).
string serverUrl = this._approvalSnapshot?.ServerUrl ?? this.GetServerUrl();
string? serverLabel = this._approvalSnapshot?.ServerLabel ?? this.GetServerLabel();
string toolName = this._approvalSnapshot?.ToolName ?? this.GetToolName();
Dictionary<string, object?>? arguments = this._approvalSnapshot?.Arguments ?? this.GetArguments();
Dictionary<string, string>? headers = this.GetHeaders();
string? connectionName = this._approvalSnapshot?.ConnectionName ?? this.GetConnectionName();
McpServerToolResultContent resultContent = await mcpToolHandler.InvokeToolAsync(
serverUrl,
serverLabel,
toolName,
arguments,
snapshot.ServerUrl,
snapshot.ServerLabel,
snapshot.ToolName,
snapshot.Arguments,
headers,
connectionName,
snapshot.ConnectionName,
cancellationToken).ConfigureAwait(false);
await this.ProcessResultAsync(context, resultContent, cancellationToken).ConfigureAwait(false);
@@ -175,31 +188,57 @@ internal sealed class InvokeMcpToolExecutor(
/// </summary>
public async ValueTask CompleteAsync(IWorkflowContext context, ActionExecutorResult message, CancellationToken cancellationToken)
{
// Clear the approval snapshot after successful completion.
this._approvalSnapshot = null;
await ClearSnapshotStateAsync(context, cancellationToken).ConfigureAwait(false);
await context.RaiseCompletionEventAsync(this.Model, cancellationToken).ConfigureAwait(false);
}
/// <inheritdoc/>
public override ValueTask ResetAsync()
{
this._approvalSnapshots.Clear();
return default;
}
/// <inheritdoc/>
/// <remarks>
/// Persists the approval snapshot to workflow state so it survives checkpoint/restore cycles.
/// Persists pending approval snapshots to workflow state so they survive
/// checkpoint/restore cycles.
/// </remarks>
protected override async ValueTask OnCheckpointingAsync(IWorkflowContext context, CancellationToken cancellationToken = default)
{
await context.QueueStateUpdateAsync(ApprovalSnapshotStateKey, this._approvalSnapshot, null, cancellationToken).ConfigureAwait(false);
Dictionary<string, ApprovalSnapshot> snapshotCopy = this._approvalSnapshots.ToDictionary(kvp => kvp.Key, kvp => kvp.Value, StringComparer.Ordinal);
await context.QueueStateUpdateAsync(ApprovalSnapshotStateKey, snapshotCopy, null, cancellationToken).ConfigureAwait(false);
await base.OnCheckpointingAsync(context, cancellationToken).ConfigureAwait(false);
}
/// <inheritdoc/>
/// <remarks>
/// Restores the approval snapshot from workflow state after a checkpoint restore.
/// Restores pending approval snapshots from workflow state after a checkpoint restore.
/// </remarks>
protected override async ValueTask OnCheckpointRestoredAsync(IWorkflowContext context, CancellationToken cancellationToken = default)
{
await base.OnCheckpointRestoredAsync(context, cancellationToken).ConfigureAwait(false);
this._approvalSnapshot = await context.ReadStateAsync<ApprovalSnapshot>(ApprovalSnapshotStateKey, null, cancellationToken).ConfigureAwait(false);
this._approvalSnapshots.Clear();
Dictionary<string, ApprovalSnapshot>? snapshots = await context.ReadStateAsync<Dictionary<string, ApprovalSnapshot>>(
ApprovalSnapshotStateKey, null, cancellationToken).ConfigureAwait(false);
if (snapshots is not null)
{
foreach (KeyValuePair<string, ApprovalSnapshot> entry in snapshots)
{
this._approvalSnapshots[entry.Key] = entry.Value;
}
}
// Migrate a single ApprovalSnapshot at the legacy key under this.Id so the
// legacy approval response matches the per-invocation map; clear the legacy key.
ApprovalSnapshot? legacy = await context.ReadStateAsync<ApprovalSnapshot>(
LegacyApprovalSnapshotStateKey, null, cancellationToken).ConfigureAwait(false);
if (legacy is not null)
{
this._approvalSnapshots.TryAdd(this.Id, legacy);
await context.QueueStateUpdateAsync<ApprovalSnapshot?>(
LegacyApprovalSnapshotStateKey, null, null, cancellationToken).ConfigureAwait(false);
}
}
private async ValueTask ProcessResultAsync(IWorkflowContext context, McpServerToolResultContent resultContent, CancellationToken cancellationToken)
@@ -404,17 +443,8 @@ internal sealed class InvokeMcpToolExecutor(
}
/// <summary>
/// Clears the persisted approval snapshot state after a successful tool invocation.
/// </summary>
private static async ValueTask ClearSnapshotStateAsync(IWorkflowContext context, CancellationToken cancellationToken)
{
await context.QueueStateUpdateAsync<ApprovalSnapshot?>(ApprovalSnapshotStateKey, null, null, cancellationToken).ConfigureAwait(false);
}
/// <summary>
/// Stores the evaluated parameters at approval-request time so that
/// <see cref="CaptureResponseAsync"/> uses the values the user reviewed,
/// even if <see cref="WorkflowFormulaState"/> mutates during the approval window.
/// Captured invocation parameters used by <see cref="CaptureResponseAsync"/> on
/// resume so the approved values are invoked regardless of subsequent state changes.
/// </summary>
internal sealed record ApprovalSnapshot(
string ServerUrl,
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
@@ -301,8 +302,9 @@ public sealed class InvokeFunctionToolExecutorTest(ITestOutputHelper output) : W
onInvoke: name => capturedFunctionName = name);
InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State);
// Act - trigger ExecuteAsync to store the approval snapshot
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext();
// Act - trigger ExecuteAsync to emit the approval request
List<ExternalInputRequest> emittedRequests = [];
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext(emittedRequests);
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
// Simulate parallel branch mutating state during the approval window
@@ -310,7 +312,7 @@ public sealed class InvokeFunctionToolExecutorTest(ITestOutputHelper output) : W
this.State.Bind();
// User clicks approve (they saw "safe_readonly_query" in the approval UI)
ExternalInputResponse response = CreateApprovalResponse(action.Id, approved: true);
ExternalInputResponse response = CreateApprovalResponseFor(emittedRequests, approved: true);
// Resume after approval
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
@@ -349,8 +351,9 @@ public sealed class InvokeFunctionToolExecutorTest(ITestOutputHelper output) : W
onInvokeArguments: args => capturedArguments = args);
InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State);
// Act - trigger ExecuteAsync to store the approval snapshot
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext();
// Act - trigger ExecuteAsync to emit the approval request
List<ExternalInputRequest> emittedRequests = [];
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext(emittedRequests);
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
// Simulate parallel branch mutating state during the approval window
@@ -358,7 +361,7 @@ public sealed class InvokeFunctionToolExecutorTest(ITestOutputHelper output) : W
this.State.Bind();
// User clicks approve
ExternalInputResponse response = CreateApprovalResponse(action.Id, approved: true);
ExternalInputResponse response = CreateApprovalResponseFor(emittedRequests, approved: true);
// Resume after approval
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
@@ -396,18 +399,20 @@ public sealed class InvokeFunctionToolExecutorTest(ITestOutputHelper output) : W
onInvoke: name => capturedFunctionName = name);
InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State);
// Act - trigger ExecuteAsync to store the approval snapshot
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContextWithStateStore();
// Act - trigger ExecuteAsync to emit the approval request and capture the snapshot
List<ExternalInputRequest> emittedRequests = [];
Dictionary<string, object?> stateStore = [];
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContextWithStateStore(stateStore, emittedRequests);
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
// Simulate checkpoint: persist to state store
await InvokeProtectedMethodAsync(action, "OnCheckpointingAsync", mockContext.Object, CancellationToken.None);
// Simulate restore on a "new" executor instance by clearing the in-memory field via reflection
// (In production, a new executor instance would be created with _approvalSnapshot == null)
typeof(InvokeFunctionToolExecutor)
.GetField("_approvalSnapshot", BindingFlags.NonPublic | BindingFlags.Instance)!
.SetValue(action, null);
// Simulate restore on a "new" executor instance by clearing the in-memory dictionary via reflection
ConcurrentDictionary<string, ApprovalSnapshot> liveSnapshots = (ConcurrentDictionary<string, ApprovalSnapshot>)typeof(InvokeFunctionToolExecutor)
.GetField("_approvalSnapshots", BindingFlags.NonPublic | BindingFlags.Instance)!
.GetValue(action)!;
liveSnapshots.Clear();
// Restore from state store
await InvokeProtectedMethodAsync(action, "OnCheckpointRestoredAsync", mockContext.Object, CancellationToken.None);
@@ -417,7 +422,7 @@ public sealed class InvokeFunctionToolExecutorTest(ITestOutputHelper output) : W
this.State.Bind();
// User clicks approve
ExternalInputResponse response = CreateApprovalResponse(action.Id, approved: true);
ExternalInputResponse response = CreateApprovalResponseFor(emittedRequests, approved: true);
// Resume after approval
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
@@ -428,9 +433,7 @@ public sealed class InvokeFunctionToolExecutorTest(ITestOutputHelper output) : W
}
/// <summary>
/// Verifies that the approval snapshot is cleared after a completed approval cycle,
/// both in-memory and in the persisted state store. This prevents stale data from
/// influencing a subsequent execution of the same executor instance.
/// Verifies that the approval snapshot entry is removed after a completed approval cycle.
/// </summary>
[Fact]
public async Task InvokeFunctionToolCaptureResponseClearsSnapshotAfterCompletionAsync()
@@ -451,33 +454,723 @@ public sealed class InvokeFunctionToolExecutorTest(ITestOutputHelper output) : W
InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State);
// Act - run the full approval cycle
List<ExternalInputRequest> emittedRequests = [];
Dictionary<string, object?> stateStore = [];
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContextWithStateStore(stateStore);
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContextWithStateStore(stateStore, emittedRequests);
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
// Sanity: snapshot was captured
FieldInfo snapshotField = typeof(InvokeFunctionToolExecutor)
.GetField("_approvalSnapshot", BindingFlags.NonPublic | BindingFlags.Instance)!;
Assert.NotNull(snapshotField.GetValue(action));
// Sanity: snapshot dict has exactly one entry
FieldInfo snapshotsField = typeof(InvokeFunctionToolExecutor)
.GetField("_approvalSnapshots", BindingFlags.NonPublic | BindingFlags.Instance)!;
ConcurrentDictionary<string, ApprovalSnapshot> snapshots = (ConcurrentDictionary<string, ApprovalSnapshot>)snapshotsField.GetValue(action)!;
Assert.Single(snapshots);
ExternalInputResponse response = CreateApprovalResponse(action.Id, approved: true);
ExternalInputResponse response = CreateApprovalResponseFor(emittedRequests, approved: true);
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
// Assert - both in-memory field and persisted state are cleared
Assert.Null(snapshotField.GetValue(action));
Assert.True(stateStore.ContainsKey("_approvalSnapshot"));
Assert.Null(stateStore["_approvalSnapshot"]);
// Assert - in-memory dict is empty after the matching response is captured
Assert.Empty(snapshots);
}
private static ExternalInputResponse CreateApprovalResponse(string actionId, bool approved)
/// <summary>
/// Each ExecuteAsync invocation must produce a unique per-invocation request id on
/// both the FunctionCallContent.CallId and the ToolApprovalRequestContent.RequestId.
/// </summary>
[Fact]
public async Task InvokeFunctionToolEmitsUniqueRequestIdPerInvocationAsync()
{
FunctionCallContent functionCall = new(callId: actionId, name: "ignored");
ToolApprovalRequestContent approvalRequest = new(actionId, functionCall);
// Arrange
this.State.InitializeSystem();
this.State.Bind();
InvokeFunctionTool model = this.CreateModel(
displayName: nameof(InvokeFunctionToolEmitsUniqueRequestIdPerInvocationAsync),
functionName: "any_function",
requireApproval: true);
TestFunctionAgentProvider testAgentProvider = new(
[AIFunctionFactory.Create(() => "result", name: "any_function")]);
InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State);
List<ExternalInputRequest> emittedRequests = [];
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext(emittedRequests);
// Act - emit two approval requests from the same executor instance
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
// Assert - two distinct request ids surfaced
Assert.Equal(2, emittedRequests.Count);
string id1 = emittedRequests[0].AgentResponse.Messages
.SelectMany(m => m.Contents).OfType<ToolApprovalRequestContent>().Single().RequestId;
string id2 = emittedRequests[1].AgentResponse.Messages
.SelectMany(m => m.Contents).OfType<ToolApprovalRequestContent>().Single().RequestId;
Assert.NotEqual(id1, id2);
Assert.NotEqual(action.Id, id1);
Assert.NotEqual(action.Id, id2);
// And the matching inner FunctionCallContent uses the same id
FunctionCallContent fcc1 = emittedRequests[0].AgentResponse.Messages
.SelectMany(m => m.Contents).OfType<FunctionCallContent>().Single();
Assert.Equal(id1, fcc1.CallId);
}
/// <summary>
/// Two concurrent pending approvals on the same executor must each resume with their
/// own approved arguments — out-of-order responses must not swap which invocation gets
/// which set of arguments.
/// </summary>
[Fact]
public async Task InvokeFunctionToolConcurrentPendingApprovalsDoNotSwapAsync()
{
// Arrange
const string FunctionName = "process_query";
const string ArgumentKey = "query";
const string ArgumentsA = "A-args";
const string ArgumentsB = "B-args";
this.State.InitializeSystem();
this.State.Bind();
InvokeFunctionTool model = this.CreateModel(
displayName: nameof(InvokeFunctionToolConcurrentPendingApprovalsDoNotSwapAsync),
functionName: FunctionName,
requireApproval: true,
argumentKey: ArgumentKey,
argumentValue: ArgumentsA);
InvokeFunctionTool modelB = this.CreateModel(
displayName: nameof(InvokeFunctionToolConcurrentPendingApprovalsDoNotSwapAsync) + "B",
functionName: FunctionName,
requireApproval: true,
argumentKey: ArgumentKey,
argumentValue: ArgumentsB);
List<string?> capturedQueries = [];
TestFunctionAgentProvider testAgentProvider = new(
[AIFunctionFactory.Create((string query) => $"executed:{query}", name: FunctionName)],
onInvokeArguments: args => capturedQueries.Add(args[ArgumentKey]?.ToString()));
// Two executor instances simulating concurrent fan-in scenarios with different YAML-evaluated args
InvokeFunctionToolExecutor actionA = new(model, testAgentProvider, this.State);
InvokeFunctionToolExecutor actionB = new(modelB, testAgentProvider, this.State);
List<ExternalInputRequest> emittedA = [];
List<ExternalInputRequest> emittedB = [];
Mock<IWorkflowContext> ctxA = CreateMockWorkflowContext(emittedA);
Mock<IWorkflowContext> ctxB = CreateMockWorkflowContext(emittedB);
// Act - both executors emit approval requests
await actionA.HandleAsync(new ActionExecutorResult(actionA.Id), ctxA.Object, CancellationToken.None);
await actionB.HandleAsync(new ActionExecutorResult(actionB.Id), ctxB.Object, CancellationToken.None);
// Deliver responses out of order: B first, then A
await actionB.CaptureResponseAsync(ctxB.Object, CreateApprovalResponseFor(emittedB, approved: true), CancellationToken.None);
await actionA.CaptureResponseAsync(ctxA.Object, CreateApprovalResponseFor(emittedA, approved: true), CancellationToken.None);
// Assert - each invocation executed with its own approved arguments
Assert.Equal([ArgumentsB, ArgumentsA], capturedQueries);
}
/// <summary>
/// When the approval response references a request id that is not in the snapshot map,
/// the executor must surface a structured error and must not invoke any function.
/// </summary>
[Fact]
public async Task InvokeFunctionToolMissingSnapshotReturnsStructuredErrorAsync()
{
// Arrange
const string FunctionName = "any_function";
this.State.InitializeSystem();
this.State.Bind();
InvokeFunctionTool model = this.CreateModel(
displayName: nameof(InvokeFunctionToolMissingSnapshotReturnsStructuredErrorAsync),
functionName: FunctionName,
requireApproval: true);
bool functionWasInvoked = false;
TestFunctionAgentProvider testAgentProvider = new(
[AIFunctionFactory.Create(() => { functionWasInvoked = true; return "result"; }, name: FunctionName)]);
InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State);
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext();
// Act - deliver an approval response whose RequestId has no matching snapshot
FunctionCallContent fcc = new(callId: "stale-id", name: FunctionName);
ToolApprovalRequestContent staleRequest = new("stale-id", fcc);
ToolApprovalResponseContent staleResponse = staleRequest.CreateResponse(approved: true);
ExternalInputResponse response = new(new ChatMessage(ChatRole.User, [staleResponse]));
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
// Assert - the registered function must NOT have been invoked. The
// ToolApprovalResponseContent.RequestId did not match any snapshot in the executor's
// map, so the executor does not attempt to invoke the function at all (no silent
// state re-evaluation).
Assert.False(functionWasInvoked);
}
/// <summary>
/// Two non-approval invocations of the same executor must emit distinct per-invocation
/// CallIds so each response is matched to its originating request.
/// </summary>
[Fact]
public async Task InvokeFunctionToolNonApprovalCallIdsAreDistinctPerInvocationAsync()
{
// Arrange
this.State.InitializeSystem();
this.State.Bind();
InvokeFunctionTool model = this.CreateModel(
displayName: nameof(InvokeFunctionToolNonApprovalCallIdsAreDistinctPerInvocationAsync),
functionName: "any_function",
requireApproval: false);
TestFunctionAgentProvider testAgentProvider = new(
[AIFunctionFactory.Create(() => "result", name: "any_function")]);
InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State);
List<ExternalInputRequest> emittedRequests = [];
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext(emittedRequests);
// Act - emit two non-approval function-call requests
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
// Assert - distinct CallIds were stamped on the two emitted FunctionCallContents
Assert.Equal(2, emittedRequests.Count);
FunctionCallContent fcc1 = emittedRequests[0].AgentResponse.Messages
.SelectMany(m => m.Contents).OfType<FunctionCallContent>().Single();
FunctionCallContent fcc2 = emittedRequests[1].AgentResponse.Messages
.SelectMany(m => m.Contents).OfType<FunctionCallContent>().Single();
Assert.NotEqual(fcc1.CallId, fcc2.CallId);
Assert.NotEqual(action.Id, fcc1.CallId);
Assert.NotEqual(action.Id, fcc2.CallId);
}
/// <summary>
/// A snapshot persisted at the legacy <c>"_approvalSnapshot"</c> key must be migrated
/// under <c>this.Id</c> after restore so an approval response carrying
/// <c>RequestId == this.Id</c> resumes with the snapshot's arguments.
/// </summary>
[Fact]
public async Task InvokeFunctionToolLegacySingleSnapshotCheckpointIsMigratedAsync()
{
// Arrange
const string FunctionName = "any_function";
const string ArgumentKey = "query";
const string LegacyApprovedArg = "legacy-approved";
this.State.InitializeSystem();
this.State.Bind();
InvokeFunctionTool model = this.CreateModel(
displayName: nameof(InvokeFunctionToolLegacySingleSnapshotCheckpointIsMigratedAsync),
functionName: FunctionName,
requireApproval: true);
AIFunctionArguments? capturedArguments = null;
TestFunctionAgentProvider testAgentProvider = new(
[AIFunctionFactory.Create((string query) => $"executed:{query}", name: FunctionName)],
onInvokeArguments: args => capturedArguments = args);
InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State);
// Seed the state store with a single ApprovalSnapshot at the legacy key.
Dictionary<string, object?> stateStore = new()
{
["_approvalSnapshot"] = new ApprovalSnapshot(
FunctionName,
new Dictionary<string, object?> { [ArgumentKey] = LegacyApprovedArg }),
};
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContextWithStateStore(stateStore);
// Act - restore migrates the legacy snapshot under this.Id.
await InvokeProtectedMethodAsync(action, "OnCheckpointRestoredAsync", mockContext.Object, CancellationToken.None);
ConcurrentDictionary<string, ApprovalSnapshot> snapshots = (ConcurrentDictionary<string, ApprovalSnapshot>)typeof(InvokeFunctionToolExecutor)
.GetField("_approvalSnapshots", BindingFlags.NonPublic | BindingFlags.Instance)!
.GetValue(action)!;
Assert.True(snapshots.ContainsKey(action.Id));
// Deliver an approval response with RequestId == action.Id and resume.
FunctionCallContent fcc = new(callId: action.Id, name: FunctionName);
ToolApprovalRequestContent legacyRequest = new(action.Id, fcc);
ToolApprovalResponseContent legacyResponse = legacyRequest.CreateResponse(approved: true);
ExternalInputResponse response = new(new ChatMessage(ChatRole.User, [legacyResponse]));
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
// Assert - the function was invoked with the snapshot arguments.
Assert.NotNull(capturedArguments);
Assert.Equal(LegacyApprovedArg, capturedArguments[ArgumentKey]?.ToString());
}
/// <summary>
/// The legacy <c>"_approvalSnapshot"</c> key is removed from the state store after
/// migration so subsequent checkpoints do not carry stale data.
/// </summary>
[Fact]
public async Task InvokeFunctionToolLegacyKeyIsClearedAfterMigrationAsync()
{
// Arrange
this.State.InitializeSystem();
this.State.Bind();
InvokeFunctionTool model = this.CreateModel(
displayName: nameof(InvokeFunctionToolLegacyKeyIsClearedAfterMigrationAsync),
functionName: "any_function",
requireApproval: true);
TestFunctionAgentProvider testAgentProvider = new(
[AIFunctionFactory.Create(() => "result", name: "any_function")]);
InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State);
Dictionary<string, object?> stateStore = new()
{
["_approvalSnapshot"] = new ApprovalSnapshot("any_function", new Dictionary<string, object?>()),
};
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContextWithStateStore(stateStore);
// Act
await InvokeProtectedMethodAsync(action, "OnCheckpointRestoredAsync", mockContext.Object, CancellationToken.None);
// Assert - legacy key was cleared via QueueStateUpdateAsync<ApprovalSnapshot?>(null).
Assert.False(stateStore.ContainsKey("_approvalSnapshot"));
}
/// <summary>
/// Drives ExecuteAsync → checkpoint → ResetAsync → restore → CaptureResponseAsync on a
/// single pending approval and asserts the originally-approved arguments are used,
/// even though ResetAsync cleared the in-memory dict between checkpoint and restore.
/// </summary>
[Fact]
public async Task InvokeFunctionToolResumeAfterResetUsesPersistedSnapshotAsync()
{
// Arrange
const string FunctionName = "process_query";
const string ArgumentKey = "query";
const string ApprovedQuery = "SELECT * FROM users LIMIT 10";
this.State.Set("SqlQuery", FormulaValue.New(ApprovedQuery));
this.State.InitializeSystem();
this.State.Bind();
InvokeFunctionTool model = this.CreateModelWithVariableArgument(
displayName: nameof(InvokeFunctionToolResumeAfterResetUsesPersistedSnapshotAsync),
functionName: FunctionName,
argumentKey: ArgumentKey,
variableName: "SqlQuery");
AIFunctionArguments? capturedArguments = null;
TestFunctionAgentProvider testAgentProvider = new(
[AIFunctionFactory.Create((string query) => $"executed:{query}", name: FunctionName)],
onInvokeArguments: args => capturedArguments = args);
InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State);
List<ExternalInputRequest> emittedRequests = [];
Dictionary<string, object?> stateStore = [];
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContextWithStateStore(stateStore, emittedRequests);
ConcurrentDictionary<string, ApprovalSnapshot> liveSnapshots = (ConcurrentDictionary<string, ApprovalSnapshot>)typeof(InvokeFunctionToolExecutor)
.GetField("_approvalSnapshots", BindingFlags.NonPublic | BindingFlags.Instance)!
.GetValue(action)!;
// Act - emit, checkpoint, reset (simulates runner end), restore, then capture.
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
await InvokeProtectedMethodAsync(action, "OnCheckpointingAsync", mockContext.Object, CancellationToken.None);
await action.ResetAsync();
Assert.Empty(liveSnapshots);
await InvokeProtectedMethodAsync(action, "OnCheckpointRestoredAsync", mockContext.Object, CancellationToken.None);
Assert.Single(liveSnapshots);
ExternalInputResponse response = CreateApprovalResponseFor(emittedRequests, approved: true);
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
// Assert - the originally-approved argument was used and the entry was removed.
Assert.NotNull(capturedArguments);
Assert.Equal(ApprovedQuery, capturedArguments[ArgumentKey]?.ToString());
Assert.Empty(liveSnapshots);
}
/// <summary>
/// Two pending invocations (A then B) are interleaved with checkpoint/reset/restore
/// cycles; A's snapshot must survive both reset cycles and route A's response to
/// A's arguments, while B remains pending and is later resolved correctly.
/// </summary>
[Fact]
public async Task InvokeFunctionToolMultiplePendingInvocationsSurviveCheckpointResetRestoreAsync()
{
// Arrange
const string FunctionName = "process_query";
const string ArgumentKey = "query";
const string ArgumentsA = "A-args";
const string ArgumentsB = "B-args";
this.State.Set("SqlQuery", FormulaValue.New(ArgumentsA));
this.State.InitializeSystem();
this.State.Bind();
InvokeFunctionTool model = this.CreateModelWithVariableArgument(
displayName: nameof(InvokeFunctionToolMultiplePendingInvocationsSurviveCheckpointResetRestoreAsync),
functionName: FunctionName,
argumentKey: ArgumentKey,
variableName: "SqlQuery");
List<string?> capturedQueries = [];
TestFunctionAgentProvider testAgentProvider = new(
[AIFunctionFactory.Create((string query) => $"executed:{query}", name: FunctionName)],
onInvokeArguments: args => capturedQueries.Add(args[ArgumentKey]?.ToString()));
InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State);
List<ExternalInputRequest> emittedRequests = [];
Dictionary<string, object?> stateStore = [];
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContextWithStateStore(stateStore, emittedRequests);
ConcurrentDictionary<string, ApprovalSnapshot> liveSnapshots = (ConcurrentDictionary<string, ApprovalSnapshot>)typeof(InvokeFunctionToolExecutor)
.GetField("_approvalSnapshots", BindingFlags.NonPublic | BindingFlags.Instance)!
.GetValue(action)!;
// Act - invocation A with ArgumentsA, then full checkpoint/reset/restore.
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
await InvokeProtectedMethodAsync(action, "OnCheckpointingAsync", mockContext.Object, CancellationToken.None);
await action.ResetAsync();
Assert.Empty(liveSnapshots);
await InvokeProtectedMethodAsync(action, "OnCheckpointRestoredAsync", mockContext.Object, CancellationToken.None);
Assert.Single(liveSnapshots);
// Mutate the source variable, then invocation B with ArgumentsB.
this.State.Set("SqlQuery", FormulaValue.New(ArgumentsB));
this.State.Bind();
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
await InvokeProtectedMethodAsync(action, "OnCheckpointingAsync", mockContext.Object, CancellationToken.None);
await action.ResetAsync();
Assert.Empty(liveSnapshots);
await InvokeProtectedMethodAsync(action, "OnCheckpointRestoredAsync", mockContext.Object, CancellationToken.None);
Assert.Equal(2, liveSnapshots.Count);
// Capture A's response. State has been mutated to ArgumentsB but the per-invocation
// snapshot must still drive invocation with ArgumentsA.
Assert.Equal(2, emittedRequests.Count);
ExternalInputResponse responseA = CreateApprovalResponseForRequest(emittedRequests[0], approved: true);
await action.CaptureResponseAsync(mockContext.Object, responseA, CancellationToken.None);
Assert.Single(liveSnapshots);
Assert.Equal([ArgumentsA], capturedQueries);
// Another checkpoint/reset/restore cycle - B's snapshot survives.
await InvokeProtectedMethodAsync(action, "OnCheckpointingAsync", mockContext.Object, CancellationToken.None);
await action.ResetAsync();
Assert.Empty(liveSnapshots);
await InvokeProtectedMethodAsync(action, "OnCheckpointRestoredAsync", mockContext.Object, CancellationToken.None);
Assert.Single(liveSnapshots);
// Capture B's response.
ExternalInputResponse responseB = CreateApprovalResponseForRequest(emittedRequests[1], approved: true);
await action.CaptureResponseAsync(mockContext.Object, responseB, CancellationToken.None);
// Assert - both invocations executed with their own approved arguments; nothing pending.
Assert.Equal([ArgumentsA, ArgumentsB], capturedQueries);
Assert.Empty(liveSnapshots);
}
/// <summary>
/// An approval response whose RequestId does not match any pending snapshot must
/// NOT invoke the function and must assign a not-approved error to Output.Result.
/// </summary>
[Fact]
public async Task InvokeFunctionToolUnmatchedApprovalAssignsErrorAsync()
{
// Arrange
const string FunctionName = "any_function";
const string ResultVariable = "Result";
this.State.InitializeSystem();
this.State.Bind();
InvokeFunctionTool model = this.CreateModel(
displayName: nameof(InvokeFunctionToolUnmatchedApprovalAssignsErrorAsync),
functionName: FunctionName,
requireApproval: true,
outputResultVariable: ResultVariable);
bool functionWasInvoked = false;
TestFunctionAgentProvider testAgentProvider = new(
[AIFunctionFactory.Create(() => { functionWasInvoked = true; return "result"; }, name: FunctionName)]);
InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State);
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext();
// Act - deliver an approval response whose RequestId has no matching snapshot.
FunctionCallContent fcc = new(callId: "stale-id", name: FunctionName);
ToolApprovalRequestContent staleRequest = new("stale-id", fcc);
ToolApprovalResponseContent staleResponse = staleRequest.CreateResponse(approved: true);
ExternalInputResponse response = new(new ChatMessage(ChatRole.User, [staleResponse]));
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
// Assert - function was NOT invoked AND the error string landed at Output.Result.
Assert.False(functionWasInvoked);
Assert.Contains(mockContext.Invocations, i =>
i.Method.Name == nameof(IWorkflowContext.QueueStateUpdateAsync)
&& i.Arguments.Count >= 2
&& i.Arguments[1] is StringValue sv
&& sv.Value.Contains("not approved by user"));
}
/// <summary>
/// An approval response whose RequestId matches a pending snapshot but is
/// Approved == false must NOT invoke the function, must remove the snapshot, and
/// must assign a not-approved error to Output.Result.
/// </summary>
[Fact]
public async Task InvokeFunctionToolRejectedApprovalAssignsErrorAsync()
{
// Arrange
const string FunctionName = "any_function";
const string ResultVariable = "Result";
this.State.InitializeSystem();
this.State.Bind();
InvokeFunctionTool model = this.CreateModel(
displayName: nameof(InvokeFunctionToolRejectedApprovalAssignsErrorAsync),
functionName: FunctionName,
requireApproval: true,
outputResultVariable: ResultVariable);
bool functionWasInvoked = false;
TestFunctionAgentProvider testAgentProvider = new(
[AIFunctionFactory.Create(() => { functionWasInvoked = true; return "result"; }, name: FunctionName)]);
InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State);
List<ExternalInputRequest> emittedRequests = [];
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext(emittedRequests);
ConcurrentDictionary<string, ApprovalSnapshot> liveSnapshots = (ConcurrentDictionary<string, ApprovalSnapshot>)typeof(InvokeFunctionToolExecutor)
.GetField("_approvalSnapshots", BindingFlags.NonPublic | BindingFlags.Instance)!
.GetValue(action)!;
// Act - emit the approval request, then deliver a rejection for it.
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
Assert.Single(liveSnapshots);
ExternalInputResponse response = CreateApprovalResponseFor(emittedRequests, approved: false);
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
// Assert - function not invoked, snapshot removed, error assigned.
Assert.False(functionWasInvoked);
Assert.Empty(liveSnapshots);
Assert.Contains(mockContext.Invocations, i =>
i.Method.Name == nameof(IWorkflowContext.QueueStateUpdateAsync)
&& i.Arguments.Count >= 2
&& i.Arguments[1] is StringValue sv
&& sv.Value.Contains("not approved by user"));
}
/// <summary>
/// When a response contains multiple <see cref="ToolApprovalResponseContent"/> items —
/// e.g. an unrelated / stale approval followed by the valid one — the executor must
/// select the approval whose RequestId matches a pending snapshot and invoke the
/// function, not silently drop the valid approval because a stale one appeared first.
/// </summary>
[Fact]
public async Task InvokeFunctionToolApprovalMatchPrefersPendingSnapshotAsync()
{
// Arrange
const string FunctionName = "any_function";
const string ResultVariable = "Result";
this.State.InitializeSystem();
this.State.Bind();
InvokeFunctionTool model = this.CreateModel(
displayName: nameof(InvokeFunctionToolApprovalMatchPrefersPendingSnapshotAsync),
functionName: FunctionName,
requireApproval: true,
outputResultVariable: ResultVariable);
bool functionWasInvoked = false;
TestFunctionAgentProvider testAgentProvider = new(
[AIFunctionFactory.Create(() => { functionWasInvoked = true; return "result"; }, name: FunctionName)]);
InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State);
List<ExternalInputRequest> emittedRequests = [];
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext(emittedRequests);
// Emit one valid approval request from this executor.
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
ExternalInputRequest emitted = Assert.Single(emittedRequests);
ToolApprovalRequestContent validRequest = emitted.AgentResponse.Messages
.SelectMany(m => m.Contents)
.OfType<ToolApprovalRequestContent>()
.Single();
// Build a batched response: a stale (unrelated) approval first, then the valid one.
ToolApprovalRequestContent staleRequest = new("stale-id", new FunctionCallContent("stale-id", FunctionName));
ToolApprovalResponseContent staleResponse = staleRequest.CreateResponse(approved: true);
ToolApprovalResponseContent validResponse = validRequest.CreateResponse(approved: true);
ExternalInputResponse response = new(new ChatMessage(ChatRole.User, [staleResponse, validResponse]));
// Act
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
// Assert - the valid approval drove invocation; no not-approved error was assigned.
Assert.True(functionWasInvoked);
Assert.DoesNotContain(mockContext.Invocations, i =>
i.Method.Name == nameof(IWorkflowContext.QueueStateUpdateAsync)
&& i.Arguments.Count >= 2
&& i.Arguments[1] is StringValue sv
&& sv.Value.Contains("not approved by user"));
}
/// <summary>
/// Delivering the same approval response twice must invoke the registered function
/// exactly once; the second delivery surfaces the not-approved error path because the
/// snapshot has already been consumed.
/// </summary>
[Fact]
public async Task InvokeFunctionToolDuplicateApprovalDeliveryInvokesFunctionOnceAsync()
{
// Arrange
const string FunctionName = "any_function";
const string ResultVariable = "Result";
this.State.InitializeSystem();
this.State.Bind();
InvokeFunctionTool model = this.CreateModel(
displayName: nameof(InvokeFunctionToolDuplicateApprovalDeliveryInvokesFunctionOnceAsync),
functionName: FunctionName,
requireApproval: true,
outputResultVariable: ResultVariable);
int invocationCount = 0;
TestFunctionAgentProvider testAgentProvider = new(
[AIFunctionFactory.Create(() => { Interlocked.Increment(ref invocationCount); return "result"; }, name: FunctionName)]);
InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State);
List<ExternalInputRequest> emittedRequests = [];
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext(emittedRequests);
// Emit one approval request.
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
// Act - deliver the SAME approval response twice.
ExternalInputResponse response = CreateApprovalResponseFor(emittedRequests, approved: true);
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
// Assert - the registered AIFunction was invoked exactly once.
Assert.Equal(1, invocationCount);
// The second delivery surfaced the not-approved error path.
Assert.Contains(mockContext.Invocations, i =>
i.Method.Name == nameof(IWorkflowContext.QueueStateUpdateAsync)
&& i.Arguments.Count >= 2
&& i.Arguments[1] is StringValue sv
&& sv.Value.Contains("not approved by user"));
}
/// <summary>
/// A non-approval <c>FunctionResultContent</c> whose CallId equals <c>this.Id</c> is
/// consumed and assigned to <c>Output.Result</c> when no pendings are tracked.
/// </summary>
[Fact]
public async Task InvokeFunctionToolLegacyNonApprovalResultIsAcceptedAsync()
{
// Arrange - a fresh executor has no tracked pendings.
const string FunctionName = "any_function";
const string ResultVariable = "Result";
const string HostResult = "host-computed-result";
this.State.InitializeSystem();
this.State.Bind();
InvokeFunctionTool model = this.CreateModel(
displayName: nameof(InvokeFunctionToolLegacyNonApprovalResultIsAcceptedAsync),
functionName: FunctionName,
requireApproval: false,
outputResultVariable: ResultVariable);
TestFunctionAgentProvider testAgentProvider = new(
[AIFunctionFactory.Create(() => "should-not-be-called", name: FunctionName)]);
InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State);
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext();
// Act - deliver a FunctionResultContent with CallId == action.Id.
FunctionResultContent legacyResult = new(action.Id, HostResult);
ExternalInputResponse response = new(new ChatMessage(ChatRole.Tool, [legacyResult]));
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
// Assert - the host-computed result was assigned to Output.Result and no
// not-approved error was emitted.
Assert.Contains(mockContext.Invocations, i =>
i.Method.Name == nameof(IWorkflowContext.QueueStateUpdateAsync)
&& i.Arguments.Count >= 2
&& i.Arguments[1] is StringValue sv
&& sv.Value == HostResult);
Assert.DoesNotContain(mockContext.Invocations, i =>
i.Method.Name == nameof(IWorkflowContext.QueueStateUpdateAsync)
&& i.Arguments.Count >= 2
&& i.Arguments[1] is StringValue sv
&& sv.Value.Contains("not approved by user"));
}
/// <summary>
/// The legacy non-approval backstop must NOT fire when the executor has a tracked
/// pending invocation; a <c>FunctionResultContent</c> with <c>CallId == this.Id</c>
/// is rejected in that state.
/// </summary>
[Fact]
public async Task InvokeFunctionToolLegacyNonApprovalBackstopGatedOnEmptyStateAsync()
{
// Arrange - emit a non-approval call so a per-invocation CallId is tracked.
const string FunctionName = "any_function";
const string ResultVariable = "Result";
this.State.InitializeSystem();
this.State.Bind();
InvokeFunctionTool model = this.CreateModel(
displayName: nameof(InvokeFunctionToolLegacyNonApprovalBackstopGatedOnEmptyStateAsync),
functionName: FunctionName,
requireApproval: false,
outputResultVariable: ResultVariable);
TestFunctionAgentProvider testAgentProvider = new(
[AIFunctionFactory.Create(() => "result", name: FunctionName)]);
InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State);
List<ExternalInputRequest> emittedRequests = [];
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext(emittedRequests);
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
// Act - deliver a FunctionResultContent with CallId == action.Id (not the emitted GUID).
FunctionResultContent staleLegacyResult = new(action.Id, "should-be-rejected");
ExternalInputResponse response = new(new ChatMessage(ChatRole.Tool, [staleLegacyResult]));
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
// Assert - Output.Result was NOT assigned with the rejected result.
Assert.DoesNotContain(mockContext.Invocations, i =>
i.Method.Name == nameof(IWorkflowContext.QueueStateUpdateAsync)
&& i.Arguments.Count >= 2
&& i.Arguments[1] is StringValue sv
&& sv.Value == "should-be-rejected");
}
/// <summary>
/// Builds an approval response paired to the inner <c>ToolApprovalRequestContent.RequestId</c>
/// of a specific emitted request. Used when multiple requests are emitted and the
/// caller needs to address one by position.
/// </summary>
private static ExternalInputResponse CreateApprovalResponseForRequest(ExternalInputRequest emitted, bool approved)
{
ToolApprovalRequestContent approvalRequest = emitted.AgentResponse.Messages
.SelectMany(m => m.Contents)
.OfType<ToolApprovalRequestContent>()
.Single();
ToolApprovalResponseContent approvalResponse = approvalRequest.CreateResponse(approved);
return new ExternalInputResponse(new ChatMessage(ChatRole.User, [approvalResponse]));
}
private static Mock<IWorkflowContext> CreateMockWorkflowContext()
/// <summary>
/// Extracts the inner <c>ToolApprovalRequestContent.RequestId</c> from the
/// approval request the executor emitted, and builds a paired response. This mirrors
/// the framework's symmetric content-id rewriting at the envelope boundary.
/// </summary>
private static ExternalInputResponse CreateApprovalResponseFor(IReadOnlyList<ExternalInputRequest> emittedRequests, bool approved)
{
ExternalInputRequest emitted = Assert.Single(emittedRequests);
return CreateApprovalResponseForRequest(emitted, approved);
}
private static Mock<IWorkflowContext> CreateMockWorkflowContext(List<ExternalInputRequest>? emittedRequests = null)
{
Mock<IWorkflowContext> mockContext = new();
mockContext.Setup(c => c.AddEventAsync(It.IsAny<WorkflowEvent>(), It.IsAny<CancellationToken>()))
@@ -485,25 +1178,64 @@ public sealed class InvokeFunctionToolExecutorTest(ITestOutputHelper output) : W
mockContext.Setup(c => c.QueueStateUpdateAsync(It.IsAny<string>(), It.IsAny<object?>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
.Returns(default(ValueTask));
mockContext.Setup(c => c.SendMessageAsync(It.IsAny<object>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
.Callback<object, string?, CancellationToken>((msg, _, _) =>
{
if (emittedRequests is not null && msg is ExternalInputRequest request)
{
emittedRequests.Add(request);
}
})
.Returns(default(ValueTask));
return mockContext;
}
/// <summary>
/// Creates a mock workflow context that actually stores state values (for checkpoint/restore tests).
/// Optionally accepts an externally-owned dictionary so callers can inspect the persisted state.
/// Optionally accepts an externally-owned dictionary so callers can inspect the persisted state,
/// and an optional emitted-request list so tests can build matching responses.
/// </summary>
private static Mock<IWorkflowContext> CreateMockWorkflowContextWithStateStore(Dictionary<string, object?>? stateStore = null)
private static Mock<IWorkflowContext> CreateMockWorkflowContextWithStateStore(
Dictionary<string, object?>? stateStore = null,
List<ExternalInputRequest>? emittedRequests = null)
{
stateStore ??= [];
Mock<IWorkflowContext> mockContext = new();
mockContext.Setup(c => c.AddEventAsync(It.IsAny<WorkflowEvent>(), It.IsAny<CancellationToken>()))
.Returns(default(ValueTask));
mockContext.Setup(c => c.QueueStateUpdateAsync(It.IsAny<string>(), It.IsAny<Dictionary<string, ApprovalSnapshot>>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
.Callback<string, Dictionary<string, ApprovalSnapshot>, string?, CancellationToken>((key, value, _, _) => stateStore[key] = value)
.Returns(default(ValueTask));
mockContext.Setup(c => c.QueueStateUpdateAsync(It.IsAny<string>(), It.IsAny<List<string>>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
.Callback<string, List<string>, string?, CancellationToken>((key, value, _, _) => stateStore[key] = value)
.Returns(default(ValueTask));
mockContext.Setup(c => c.QueueStateUpdateAsync(It.IsAny<string>(), It.IsAny<ApprovalSnapshot?>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
.Callback<string, ApprovalSnapshot?, string?, CancellationToken>((key, value, _, _) => stateStore[key] = value)
.Callback<string, ApprovalSnapshot?, string?, CancellationToken>((key, value, _, _) =>
{
if (value is null)
{
stateStore.Remove(key);
}
else
{
stateStore[key] = value;
}
})
.Returns(default(ValueTask));
mockContext.Setup(c => c.SendMessageAsync(It.IsAny<object>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
.Callback<object, string?, CancellationToken>((msg, _, _) =>
{
if (emittedRequests is not null && msg is ExternalInputRequest request)
{
emittedRequests.Add(request);
}
})
.Returns(default(ValueTask));
mockContext.Setup(c => c.ReadStateAsync<Dictionary<string, ApprovalSnapshot>>(It.IsAny<string>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
.Returns<string, string?, CancellationToken>((key, _, _) =>
new ValueTask<Dictionary<string, ApprovalSnapshot>?>(stateStore.TryGetValue(key, out object? val) ? val as Dictionary<string, ApprovalSnapshot> : null));
mockContext.Setup(c => c.ReadStateAsync<List<string>>(It.IsAny<string>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
.Returns<string, string?, CancellationToken>((key, _, _) =>
new ValueTask<List<string>?>(stateStore.TryGetValue(key, out object? val) ? val as List<string> : null));
mockContext.Setup(c => c.ReadStateAsync<ApprovalSnapshot>(It.IsAny<string>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
.Returns<string, string?, CancellationToken>((key, _, _) =>
new ValueTask<ApprovalSnapshot?>(stateStore.TryGetValue(key, out object? val) ? val as ApprovalSnapshot : null));
@@ -622,7 +1354,8 @@ public sealed class InvokeFunctionToolExecutorTest(ITestOutputHelper output) : W
bool? requireApproval = false,
string? conversationId = null,
string? argumentKey = null,
string? argumentValue = null)
string? argumentValue = null,
string? outputResultVariable = null)
{
InvokeFunctionTool.Builder builder = new()
{
@@ -642,6 +1375,14 @@ public sealed class InvokeFunctionToolExecutorTest(ITestOutputHelper output) : W
builder.Arguments.Add(argumentKey, ValueExpression.Literal(new StringDataValue(argumentValue)));
}
if (outputResultVariable is not null)
{
builder.Output = new InvokeToolOutput.Builder
{
Result = new InitializablePropertyPath(PropertyPath.TopicVariable(outputResultVariable), isInitializer: false),
};
}
return AssignParent<InvokeFunctionTool>(builder);
}
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
@@ -423,15 +424,15 @@ public sealed class InvokeMcpToolExecutorTest(ITestOutputHelper output) : Workfl
MockAgentProvider mockAgentProvider = new();
InvokeMcpToolExecutor action = new(model, mockProvider.Object, mockAgentProvider.Object, this.State);
Mock<IWorkflowContext> mockContext = new(MockBehavior.Loose);
// Emit the approval request so the executor records the per-invocation snapshot.
List<ExternalInputRequest> emittedRequests = [];
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext(emittedRequests);
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
// Build an approved response matching this action's request id.
McpServerToolCallContent toolCall = new(action.Id, TestToolName, TestServerLabel);
ToolApprovalRequestContent approvalRequest = new(action.Id, toolCall);
ToolApprovalResponseContent approvalResponse = approvalRequest.CreateResponse(approved: true);
ExternalInputResponse response = new(new ChatMessage(ChatRole.User, [approvalResponse]));
// Build the matching approved response from the emitted request.
ExternalInputResponse response = CreateApprovalResponseFor(emittedRequests, approved: true);
// Act - call CaptureResponseAsync directly so the post-approval branch actually executes.
// Act - call CaptureResponseAsync so the post-approval branch actually executes.
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
// Assert - headers reach the transport invocation on the approved path.
@@ -887,7 +888,8 @@ public sealed class InvokeMcpToolExecutorTest(ITestOutputHelper output) : Workfl
InvokeMcpToolExecutor action = new(model, mockProvider.Object, mockAgentProvider.Object, this.State);
// Act - trigger ExecuteAsync to store the approval snapshot
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext();
List<ExternalInputRequest> emittedRequests = [];
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext(emittedRequests);
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
// Simulate parallel branch mutating state during the approval window
@@ -895,10 +897,7 @@ public sealed class InvokeMcpToolExecutorTest(ITestOutputHelper output) : Workfl
this.State.Bind();
// User clicks approve (they saw "safe_readonly_query" in the approval UI)
McpServerToolCallContent toolCall = new(action.Id, ApprovedToolName, TestServerUrl);
ToolApprovalRequestContent approvalRequest = new(action.Id, toolCall);
ToolApprovalResponseContent approvalResponse = approvalRequest.CreateResponse(approved: true);
ExternalInputResponse response = new(new ChatMessage(ChatRole.User, [approvalResponse]));
ExternalInputResponse response = CreateApprovalResponseFor(emittedRequests, approved: true);
// Resume after approval
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
@@ -950,7 +949,8 @@ public sealed class InvokeMcpToolExecutorTest(ITestOutputHelper output) : Workfl
InvokeMcpToolExecutor action = new(model, mockProvider.Object, mockAgentProvider.Object, this.State);
// Act - trigger ExecuteAsync to store the approval snapshot
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext();
List<ExternalInputRequest> emittedRequests = [];
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext(emittedRequests);
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
// Simulate parallel branch mutating state during the approval window
@@ -958,10 +958,7 @@ public sealed class InvokeMcpToolExecutorTest(ITestOutputHelper output) : Workfl
this.State.Bind();
// User clicks approve
McpServerToolCallContent toolCall = new(action.Id, TestToolName, TestServerUrl);
ToolApprovalRequestContent approvalRequest = new(action.Id, toolCall);
ToolApprovalResponseContent approvalResponse = approvalRequest.CreateResponse(approved: true);
ExternalInputResponse response = new(new ChatMessage(ChatRole.User, [approvalResponse]));
ExternalInputResponse response = CreateApprovalResponseFor(emittedRequests, approved: true);
// Resume after approval
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
@@ -1011,7 +1008,8 @@ public sealed class InvokeMcpToolExecutorTest(ITestOutputHelper output) : Workfl
InvokeMcpToolExecutor action = new(model, mockProvider.Object, mockAgentProvider.Object, this.State);
// Act - trigger ExecuteAsync to store the approval snapshot
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext();
List<ExternalInputRequest> emittedRequests = [];
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext(emittedRequests);
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
// Simulate parallel branch mutating state during the approval window
@@ -1019,10 +1017,7 @@ public sealed class InvokeMcpToolExecutorTest(ITestOutputHelper output) : Workfl
this.State.Bind();
// User clicks approve
McpServerToolCallContent toolCall = new(action.Id, TestToolName, ApprovedServerUrl);
ToolApprovalRequestContent approvalRequest = new(action.Id, toolCall);
ToolApprovalResponseContent approvalResponse = approvalRequest.CreateResponse(approved: true);
ExternalInputResponse response = new(new ChatMessage(ChatRole.User, [approvalResponse]));
ExternalInputResponse response = CreateApprovalResponseFor(emittedRequests, approved: true);
// Resume after approval
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
@@ -1072,17 +1067,18 @@ public sealed class InvokeMcpToolExecutorTest(ITestOutputHelper output) : Workfl
InvokeMcpToolExecutor action = new(model, mockProvider.Object, mockAgentProvider.Object, this.State);
// Act - trigger ExecuteAsync to store the approval snapshot
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContextWithStateStore();
List<ExternalInputRequest> emittedRequests = [];
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContextWithStateStore(emittedRequests);
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
// Simulate checkpoint: persist to state store
await InvokeProtectedMethodAsync(action, "OnCheckpointingAsync", mockContext.Object, CancellationToken.None);
// Simulate restore on a "new" executor instance by clearing the in-memory field via reflection
// (In production, a new executor instance would be created with _approvalSnapshot == null)
typeof(InvokeMcpToolExecutor)
.GetField("_approvalSnapshot", BindingFlags.NonPublic | BindingFlags.Instance)!
.SetValue(action, null);
// Simulate restore on a "new" executor instance by clearing the in-memory dictionary via reflection
ConcurrentDictionary<string, ApprovalSnapshot> liveSnapshots = (ConcurrentDictionary<string, ApprovalSnapshot>)typeof(InvokeMcpToolExecutor)
.GetField("_approvalSnapshots", BindingFlags.NonPublic | BindingFlags.Instance)!
.GetValue(action)!;
liveSnapshots.Clear();
// Restore from state store
await InvokeProtectedMethodAsync(action, "OnCheckpointRestoredAsync", mockContext.Object, CancellationToken.None);
@@ -1092,10 +1088,7 @@ public sealed class InvokeMcpToolExecutorTest(ITestOutputHelper output) : Workfl
this.State.Bind();
// User clicks approve
McpServerToolCallContent toolCall = new(action.Id, ApprovedToolName, TestServerUrl);
ToolApprovalRequestContent approvalRequest = new(action.Id, toolCall);
ToolApprovalResponseContent approvalResponse = approvalRequest.CreateResponse(approved: true);
ExternalInputResponse response = new(new ChatMessage(ChatRole.User, [approvalResponse]));
ExternalInputResponse response = CreateApprovalResponseFor(emittedRequests, approved: true);
// Resume after approval
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
@@ -1105,7 +1098,440 @@ public sealed class InvokeMcpToolExecutorTest(ITestOutputHelper output) : Workfl
Assert.Equal(ApprovedToolName, capturedToolName);
}
private static Mock<IWorkflowContext> CreateMockWorkflowContext()
/// <summary>
/// Each ExecuteAsync invocation must produce a unique per-invocation request id on
/// both the McpServerToolCallContent and the wrapping ToolApprovalRequestContent.
/// </summary>
[Fact]
public async Task InvokeMcpToolEmitsUniqueRequestIdPerInvocationAsync()
{
// Arrange
this.State.InitializeSystem();
this.State.Bind();
InvokeMcpTool model = this.CreateModelWithApproval(
displayName: nameof(InvokeMcpToolEmitsUniqueRequestIdPerInvocationAsync),
serverUrl: TestServerUrl,
toolName: TestToolName);
Mock<IMcpToolHandler> mockProvider = new();
MockAgentProvider mockAgentProvider = new();
InvokeMcpToolExecutor action = new(model, mockProvider.Object, mockAgentProvider.Object, this.State);
List<ExternalInputRequest> emittedRequests = [];
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext(emittedRequests);
// Act
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
// Assert - two distinct request ids surfaced
Assert.Equal(2, emittedRequests.Count);
string id1 = emittedRequests[0].AgentResponse.Messages
.SelectMany(m => m.Contents).OfType<ToolApprovalRequestContent>().Single().RequestId;
string id2 = emittedRequests[1].AgentResponse.Messages
.SelectMany(m => m.Contents).OfType<ToolApprovalRequestContent>().Single().RequestId;
Assert.NotEqual(id1, id2);
Assert.NotEqual(action.Id, id1);
Assert.NotEqual(action.Id, id2);
}
/// <summary>
/// Two concurrent pending MCP approvals on different executor instances (representing
/// concurrent fan-in or interleaved invocations) must each resume with their own
/// approved parameters when responses are delivered out of order.
/// </summary>
[Fact]
public async Task InvokeMcpToolConcurrentPendingApprovalsDoNotSwapAsync()
{
// Arrange
const string ToolA = "tool_alpha";
const string ToolB = "tool_beta";
this.State.InitializeSystem();
this.State.Bind();
InvokeMcpTool modelA = this.CreateModelWithApproval(
displayName: nameof(InvokeMcpToolConcurrentPendingApprovalsDoNotSwapAsync) + "A",
serverUrl: TestServerUrl,
toolName: ToolA);
InvokeMcpTool modelB = this.CreateModelWithApproval(
displayName: nameof(InvokeMcpToolConcurrentPendingApprovalsDoNotSwapAsync) + "B",
serverUrl: TestServerUrl,
toolName: ToolB);
List<string?> capturedToolNames = [];
Mock<IMcpToolHandler> mockProvider = new();
mockProvider.Setup(p => p.InvokeToolAsync(
It.IsAny<string>(),
It.IsAny<string?>(),
It.IsAny<string>(),
It.IsAny<IDictionary<string, object?>?>(),
It.IsAny<IDictionary<string, string>?>(),
It.IsAny<string?>(),
It.IsAny<CancellationToken>()))
.Callback<string, string?, string, IDictionary<string, object?>?, IDictionary<string, string>?, string?, CancellationToken>(
(_, _, toolName, _, _, _, _) => capturedToolNames.Add(toolName))
.ReturnsAsync(new McpServerToolResultContent("capture-call-id")
{
Outputs = [new TextContent("ok")]
});
MockAgentProvider mockAgentProvider = new();
InvokeMcpToolExecutor actionA = new(modelA, mockProvider.Object, mockAgentProvider.Object, this.State);
InvokeMcpToolExecutor actionB = new(modelB, mockProvider.Object, mockAgentProvider.Object, this.State);
List<ExternalInputRequest> emittedA = [];
List<ExternalInputRequest> emittedB = [];
Mock<IWorkflowContext> ctxA = CreateMockWorkflowContext(emittedA);
Mock<IWorkflowContext> ctxB = CreateMockWorkflowContext(emittedB);
// Act - both executors emit approval requests
await actionA.HandleAsync(new ActionExecutorResult(actionA.Id), ctxA.Object, CancellationToken.None);
await actionB.HandleAsync(new ActionExecutorResult(actionB.Id), ctxB.Object, CancellationToken.None);
// Deliver responses out of order
await actionB.CaptureResponseAsync(ctxB.Object, CreateApprovalResponseFor(emittedB, approved: true), CancellationToken.None);
await actionA.CaptureResponseAsync(ctxA.Object, CreateApprovalResponseFor(emittedA, approved: true), CancellationToken.None);
// Assert - each invocation invoked its own approved tool name
Assert.Equal([ToolB, ToolA], capturedToolNames);
}
/// <summary>
/// When the approval response references a request id that is not in the snapshot map,
/// the executor must NOT invoke the MCP tool.
/// </summary>
[Fact]
public async Task InvokeMcpToolMissingSnapshotAssignsErrorAsync()
{
// Arrange
this.State.InitializeSystem();
this.State.Bind();
InvokeMcpTool model = this.CreateModelWithApproval(
displayName: nameof(InvokeMcpToolMissingSnapshotAssignsErrorAsync),
serverUrl: TestServerUrl,
toolName: TestToolName);
Mock<IMcpToolHandler> mockProvider = new();
MockAgentProvider mockAgentProvider = new();
InvokeMcpToolExecutor action = new(model, mockProvider.Object, mockAgentProvider.Object, this.State);
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext();
// Act - deliver an approval response whose RequestId has no matching snapshot
McpServerToolCallContent toolCall = new("stale-id", TestToolName, TestServerUrl);
ToolApprovalRequestContent staleRequest = new("stale-id", toolCall);
ToolApprovalResponseContent staleResponse = staleRequest.CreateResponse(approved: true);
ExternalInputResponse response = new(new ChatMessage(ChatRole.User, [staleResponse]));
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
// Assert - mcpToolHandler.InvokeToolAsync must NOT have been called
mockProvider.Verify(p => p.InvokeToolAsync(
It.IsAny<string>(),
It.IsAny<string?>(),
It.IsAny<string>(),
It.IsAny<IDictionary<string, object?>?>(),
It.IsAny<IDictionary<string, string>?>(),
It.IsAny<string?>(),
It.IsAny<CancellationToken>()), Times.Never);
}
/// <summary>
/// A snapshot persisted at the legacy <c>"_approvalSnapshot"</c> key must be migrated
/// under <c>this.Id</c> after restore so an approval response carrying
/// <c>RequestId == this.Id</c> resumes with the snapshot's tool name.
/// </summary>
[Fact]
public async Task InvokeMcpToolLegacySingleSnapshotCheckpointIsMigratedAsync()
{
// Arrange
const string LegacyApprovedToolName = "legacy_approved_tool";
this.State.InitializeSystem();
this.State.Bind();
InvokeMcpTool model = this.CreateModelWithApproval(
displayName: nameof(InvokeMcpToolLegacySingleSnapshotCheckpointIsMigratedAsync),
serverUrl: TestServerUrl,
toolName: TestToolName);
string? capturedToolName = null;
Mock<IMcpToolHandler> mockProvider = new();
mockProvider.Setup(p => p.InvokeToolAsync(
It.IsAny<string>(),
It.IsAny<string?>(),
It.IsAny<string>(),
It.IsAny<IDictionary<string, object?>?>(),
It.IsAny<IDictionary<string, string>?>(),
It.IsAny<string?>(),
It.IsAny<CancellationToken>()))
.Callback<string, string?, string, IDictionary<string, object?>?, IDictionary<string, string>?, string?, CancellationToken>(
(_, _, toolName, _, _, _, _) => capturedToolName = toolName)
.ReturnsAsync(new McpServerToolResultContent("capture-call-id")
{
Outputs = [new TextContent("ok")]
});
MockAgentProvider mockAgentProvider = new();
InvokeMcpToolExecutor action = new(model, mockProvider.Object, mockAgentProvider.Object, this.State);
// Seed the state store with a single ApprovalSnapshot at the legacy key.
Dictionary<string, object?> stateStore = new()
{
["_approvalSnapshot"] = new ApprovalSnapshot(
TestServerUrl, null, LegacyApprovedToolName, new Dictionary<string, object?>(), null),
};
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContextWithStateStoreSeeded(stateStore);
// Act - restore migrates the legacy snapshot under this.Id.
await InvokeProtectedMethodAsync(action, "OnCheckpointRestoredAsync", mockContext.Object, CancellationToken.None);
ConcurrentDictionary<string, ApprovalSnapshot> snapshots = (ConcurrentDictionary<string, ApprovalSnapshot>)typeof(InvokeMcpToolExecutor)
.GetField("_approvalSnapshots", BindingFlags.NonPublic | BindingFlags.Instance)!
.GetValue(action)!;
Assert.True(snapshots.ContainsKey(action.Id));
// Deliver an approval response with RequestId == action.Id and resume.
McpServerToolCallContent toolCall = new(action.Id, LegacyApprovedToolName, TestServerUrl);
ToolApprovalRequestContent legacyRequest = new(action.Id, toolCall);
ToolApprovalResponseContent legacyResponse = legacyRequest.CreateResponse(approved: true);
ExternalInputResponse response = new(new ChatMessage(ChatRole.User, [legacyResponse]));
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
// Assert - the MCP tool was invoked with the snapshot's tool name.
Assert.Equal(LegacyApprovedToolName, capturedToolName);
}
/// <summary>
/// The legacy <c>"_approvalSnapshot"</c> key is removed from the state store after
/// migration so subsequent checkpoints do not carry stale data.
/// </summary>
[Fact]
public async Task InvokeMcpToolLegacyKeyIsClearedAfterMigrationAsync()
{
// Arrange
this.State.InitializeSystem();
this.State.Bind();
InvokeMcpTool model = this.CreateModelWithApproval(
displayName: nameof(InvokeMcpToolLegacyKeyIsClearedAfterMigrationAsync),
serverUrl: TestServerUrl,
toolName: TestToolName);
Mock<IMcpToolHandler> mockProvider = new();
MockAgentProvider mockAgentProvider = new();
InvokeMcpToolExecutor action = new(model, mockProvider.Object, mockAgentProvider.Object, this.State);
Dictionary<string, object?> stateStore = new()
{
["_approvalSnapshot"] = new ApprovalSnapshot(
TestServerUrl, null, TestToolName, new Dictionary<string, object?>(), null),
};
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContextWithStateStoreSeeded(stateStore);
// Act
await InvokeProtectedMethodAsync(action, "OnCheckpointRestoredAsync", mockContext.Object, CancellationToken.None);
// Assert - legacy key was cleared via QueueStateUpdateAsync<ApprovalSnapshot?>(null).
Assert.False(stateStore.ContainsKey("_approvalSnapshot"));
}
/// <summary>
/// Variant of CreateMockWorkflowContextWithStateStore that accepts a pre-seeded state
/// store and supports the read/write operations exercised by the legacy-migration path.
/// </summary>
private static Mock<IWorkflowContext> CreateMockWorkflowContextWithStateStoreSeeded(Dictionary<string, object?> stateStore)
{
Mock<IWorkflowContext> mockContext = new();
mockContext.Setup(c => c.AddEventAsync(It.IsAny<WorkflowEvent>(), It.IsAny<CancellationToken>()))
.Returns(default(ValueTask));
mockContext.Setup(c => c.QueueStateUpdateAsync(It.IsAny<string>(), It.IsAny<ApprovalSnapshot?>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
.Callback<string, ApprovalSnapshot?, string?, CancellationToken>((key, value, _, _) =>
{
if (value is null)
{
stateStore.Remove(key);
}
else
{
stateStore[key] = value;
}
})
.Returns(default(ValueTask));
mockContext.Setup(c => c.ReadStateAsync<Dictionary<string, ApprovalSnapshot>>(It.IsAny<string>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
.Returns<string, string?, CancellationToken>((key, _, _) =>
new ValueTask<Dictionary<string, ApprovalSnapshot>?>(stateStore.TryGetValue(key, out object? val) ? val as Dictionary<string, ApprovalSnapshot> : null));
mockContext.Setup(c => c.ReadStateAsync<ApprovalSnapshot>(It.IsAny<string>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
.Returns<string, string?, CancellationToken>((key, _, _) =>
new ValueTask<ApprovalSnapshot?>(stateStore.TryGetValue(key, out object? val) ? val as ApprovalSnapshot : null));
mockContext.Setup(c => c.ReadStateKeysAsync(It.IsAny<string?>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(new HashSet<string>());
return mockContext;
}
/// <summary>
/// Drives ExecuteAsync → checkpoint → ResetAsync → restore → CaptureResponseAsync on a
/// single pending approval and asserts the originally-approved tool name is used,
/// even though ResetAsync cleared the in-memory dict between checkpoint and restore.
/// </summary>
[Fact]
public async Task InvokeMcpToolResumeAfterResetUsesPersistedSnapshotAsync()
{
// Arrange
const string ApprovedToolName = "approved_tool";
this.State.Set("TargetTool", FormulaValue.New(ApprovedToolName));
this.State.InitializeSystem();
this.State.Bind();
InvokeMcpTool model = this.CreateModelWithVariableToolName(
displayName: nameof(InvokeMcpToolResumeAfterResetUsesPersistedSnapshotAsync),
serverUrl: TestServerUrl,
variableName: "TargetTool");
string? capturedToolName = null;
Mock<IMcpToolHandler> mockProvider = new();
mockProvider.Setup(p => p.InvokeToolAsync(
It.IsAny<string>(),
It.IsAny<string?>(),
It.IsAny<string>(),
It.IsAny<IDictionary<string, object?>?>(),
It.IsAny<IDictionary<string, string>?>(),
It.IsAny<string?>(),
It.IsAny<CancellationToken>()))
.Callback<string, string?, string, IDictionary<string, object?>?, IDictionary<string, string>?, string?, CancellationToken>(
(_, _, toolName, _, _, _, _) => capturedToolName = toolName)
.ReturnsAsync(new McpServerToolResultContent("capture-call-id")
{
Outputs = [new TextContent("ok")]
});
MockAgentProvider mockAgentProvider = new();
InvokeMcpToolExecutor action = new(model, mockProvider.Object, mockAgentProvider.Object, this.State);
List<ExternalInputRequest> emittedRequests = [];
Dictionary<string, object?> stateStore = [];
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContextWithStateStore(emittedRequests, stateStore);
ConcurrentDictionary<string, ApprovalSnapshot> liveSnapshots = (ConcurrentDictionary<string, ApprovalSnapshot>)typeof(InvokeMcpToolExecutor)
.GetField("_approvalSnapshots", BindingFlags.NonPublic | BindingFlags.Instance)!
.GetValue(action)!;
// Act - emit, checkpoint, reset (simulates runner end), restore, then capture.
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
await InvokeProtectedMethodAsync(action, "OnCheckpointingAsync", mockContext.Object, CancellationToken.None);
await action.ResetAsync();
Assert.Empty(liveSnapshots);
await InvokeProtectedMethodAsync(action, "OnCheckpointRestoredAsync", mockContext.Object, CancellationToken.None);
Assert.Single(liveSnapshots);
ExternalInputResponse response = CreateApprovalResponseFor(emittedRequests, approved: true);
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
// Assert - the originally-approved tool name was used and the entry was removed.
Assert.Equal(ApprovedToolName, capturedToolName);
Assert.Empty(liveSnapshots);
}
/// <summary>
/// Two pending invocations (A then B) are interleaved with checkpoint/reset/restore
/// cycles; A's snapshot must survive both reset cycles and route A's response to
/// A's tool name, while B remains pending and is later resolved correctly.
/// </summary>
[Fact]
public async Task InvokeMcpToolMultiplePendingInvocationsSurviveCheckpointResetRestoreAsync()
{
// Arrange
const string ToolA = "tool_alpha";
const string ToolB = "tool_beta";
this.State.Set("TargetTool", FormulaValue.New(ToolA));
this.State.InitializeSystem();
this.State.Bind();
InvokeMcpTool model = this.CreateModelWithVariableToolName(
displayName: nameof(InvokeMcpToolMultiplePendingInvocationsSurviveCheckpointResetRestoreAsync),
serverUrl: TestServerUrl,
variableName: "TargetTool");
List<string?> capturedToolNames = [];
Mock<IMcpToolHandler> mockProvider = new();
mockProvider.Setup(p => p.InvokeToolAsync(
It.IsAny<string>(),
It.IsAny<string?>(),
It.IsAny<string>(),
It.IsAny<IDictionary<string, object?>?>(),
It.IsAny<IDictionary<string, string>?>(),
It.IsAny<string?>(),
It.IsAny<CancellationToken>()))
.Callback<string, string?, string, IDictionary<string, object?>?, IDictionary<string, string>?, string?, CancellationToken>(
(_, _, toolName, _, _, _, _) => capturedToolNames.Add(toolName))
.ReturnsAsync(new McpServerToolResultContent("capture-call-id")
{
Outputs = [new TextContent("ok")]
});
MockAgentProvider mockAgentProvider = new();
InvokeMcpToolExecutor action = new(model, mockProvider.Object, mockAgentProvider.Object, this.State);
List<ExternalInputRequest> emittedRequests = [];
Dictionary<string, object?> stateStore = [];
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContextWithStateStore(emittedRequests, stateStore);
ConcurrentDictionary<string, ApprovalSnapshot> liveSnapshots = (ConcurrentDictionary<string, ApprovalSnapshot>)typeof(InvokeMcpToolExecutor)
.GetField("_approvalSnapshots", BindingFlags.NonPublic | BindingFlags.Instance)!
.GetValue(action)!;
// Act - invocation A with ToolA, then full checkpoint/reset/restore.
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
await InvokeProtectedMethodAsync(action, "OnCheckpointingAsync", mockContext.Object, CancellationToken.None);
await action.ResetAsync();
Assert.Empty(liveSnapshots);
await InvokeProtectedMethodAsync(action, "OnCheckpointRestoredAsync", mockContext.Object, CancellationToken.None);
Assert.Single(liveSnapshots);
// Mutate the source variable, then invocation B with ToolB.
this.State.Set("TargetTool", FormulaValue.New(ToolB));
this.State.Bind();
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
await InvokeProtectedMethodAsync(action, "OnCheckpointingAsync", mockContext.Object, CancellationToken.None);
await action.ResetAsync();
Assert.Empty(liveSnapshots);
await InvokeProtectedMethodAsync(action, "OnCheckpointRestoredAsync", mockContext.Object, CancellationToken.None);
Assert.Equal(2, liveSnapshots.Count);
// Capture A's response. State has been mutated to ToolB but the per-invocation
// snapshot must still drive invocation with ToolA.
Assert.Equal(2, emittedRequests.Count);
ExternalInputResponse responseA = CreateApprovalResponseForRequest(emittedRequests[0], approved: true);
await action.CaptureResponseAsync(mockContext.Object, responseA, CancellationToken.None);
Assert.Single(liveSnapshots);
Assert.Equal([ToolA], capturedToolNames);
// Another checkpoint/reset/restore cycle - B's snapshot survives.
await InvokeProtectedMethodAsync(action, "OnCheckpointingAsync", mockContext.Object, CancellationToken.None);
await action.ResetAsync();
Assert.Empty(liveSnapshots);
await InvokeProtectedMethodAsync(action, "OnCheckpointRestoredAsync", mockContext.Object, CancellationToken.None);
Assert.Single(liveSnapshots);
// Capture B's response.
ExternalInputResponse responseB = CreateApprovalResponseForRequest(emittedRequests[1], approved: true);
await action.CaptureResponseAsync(mockContext.Object, responseB, CancellationToken.None);
// Assert - both invocations executed with their own approved tool names; nothing pending.
Assert.Equal([ToolA, ToolB], capturedToolNames);
Assert.Empty(liveSnapshots);
}
private InvokeMcpTool CreateModelWithApproval(string displayName, string serverUrl, string toolName)
{
InvokeMcpTool.Builder builder = new()
{
Id = this.CreateActionId(),
DisplayName = this.FormatDisplayName(displayName),
ServerUrl = new StringExpression.Builder(StringExpression.Literal(serverUrl)),
ToolName = new StringExpression.Builder(StringExpression.Literal(toolName)),
RequireApproval = new BoolExpression.Builder(BoolExpression.Literal(true)),
};
return AssignParent<InvokeMcpTool>(builder);
}
private static Mock<IWorkflowContext> CreateMockWorkflowContext(List<ExternalInputRequest>? emittedRequests = null)
{
Mock<IWorkflowContext> mockContext = new();
mockContext.Setup(c => c.AddEventAsync(It.IsAny<WorkflowEvent>(), It.IsAny<CancellationToken>()))
@@ -1113,32 +1539,76 @@ public sealed class InvokeMcpToolExecutorTest(ITestOutputHelper output) : Workfl
mockContext.Setup(c => c.QueueStateUpdateAsync(It.IsAny<string>(), It.IsAny<object?>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
.Returns(default(ValueTask));
mockContext.Setup(c => c.SendMessageAsync(It.IsAny<object>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
.Callback<object, string?, CancellationToken>((msg, _, _) =>
{
if (emittedRequests is not null && msg is ExternalInputRequest request)
{
emittedRequests.Add(request);
}
})
.Returns(default(ValueTask));
return mockContext;
}
/// <summary>
/// Creates a mock workflow context that actually stores state values (for checkpoint/restore tests).
/// Optionally accepts an externally-owned state store so callers can drive multi-step
/// checkpoint/reset/restore sequences against the same persisted state.
/// </summary>
private static Mock<IWorkflowContext> CreateMockWorkflowContextWithStateStore()
private static Mock<IWorkflowContext> CreateMockWorkflowContextWithStateStore(
List<ExternalInputRequest>? emittedRequests = null,
Dictionary<string, object?>? stateStore = null)
{
Dictionary<string, object?> stateStore = new();
stateStore ??= new Dictionary<string, object?>();
Mock<IWorkflowContext> mockContext = new();
mockContext.Setup(c => c.AddEventAsync(It.IsAny<WorkflowEvent>(), It.IsAny<CancellationToken>()))
.Returns(default(ValueTask));
mockContext.Setup(c => c.QueueStateUpdateAsync(It.IsAny<string>(), It.IsAny<ApprovalSnapshot?>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
.Callback<string, ApprovalSnapshot?, string?, CancellationToken>((key, value, _, _) => stateStore[key] = value)
mockContext.Setup(c => c.QueueStateUpdateAsync(It.IsAny<string>(), It.IsAny<Dictionary<string, ApprovalSnapshot>>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
.Callback<string, Dictionary<string, ApprovalSnapshot>, string?, CancellationToken>((key, value, _, _) => stateStore[key] = value)
.Returns(default(ValueTask));
mockContext.Setup(c => c.SendMessageAsync(It.IsAny<object>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
.Callback<object, string?, CancellationToken>((msg, _, _) =>
{
if (emittedRequests is not null && msg is ExternalInputRequest request)
{
emittedRequests.Add(request);
}
})
.Returns(default(ValueTask));
mockContext.Setup(c => c.ReadStateAsync<ApprovalSnapshot>(It.IsAny<string>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
mockContext.Setup(c => c.ReadStateAsync<Dictionary<string, ApprovalSnapshot>>(It.IsAny<string>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
.Returns<string, string?, CancellationToken>((key, _, _) =>
new ValueTask<ApprovalSnapshot?>(stateStore.TryGetValue(key, out object? val) ? val as ApprovalSnapshot : null));
new ValueTask<Dictionary<string, ApprovalSnapshot>?>(stateStore.TryGetValue(key, out object? val) ? val as Dictionary<string, ApprovalSnapshot> : null));
mockContext.Setup(c => c.ReadStateKeysAsync(It.IsAny<string?>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(new HashSet<string>());
return mockContext;
}
/// <summary>
/// Builds an approval response paired to the request id stamped on the emitted
/// <c>MCPToolApprovalRequestContent</c>. Mirrors the framework's symmetric
/// content-id rewriting at the envelope boundary.
/// </summary>
private static ExternalInputResponse CreateApprovalResponseFor(IReadOnlyList<ExternalInputRequest> emittedRequests, bool approved)
{
ExternalInputRequest emitted = Assert.Single(emittedRequests);
return CreateApprovalResponseForRequest(emitted, approved);
}
/// <summary>
/// Builds an approval response paired to the inner <c>ToolApprovalRequestContent.RequestId</c>
/// of a specific emitted request. Used when multiple requests are emitted and the
/// caller needs to address one by position.
/// </summary>
private static ExternalInputResponse CreateApprovalResponseForRequest(ExternalInputRequest emitted, bool approved)
{
ToolApprovalRequestContent approvalRequest = emitted.AgentResponse.Messages
.SelectMany(m => m.Contents)
.OfType<ToolApprovalRequestContent>()
.Single();
ToolApprovalResponseContent approvalResponse = approvalRequest.CreateResponse(approved);
return new ExternalInputResponse(new ChatMessage(ChatRole.User, [approvalResponse]));
}
/// <summary>
/// Invokes a protected method on an executor via reflection (for testing checkpoint hooks).
/// </summary>