.NET: Bug fixes for declarative workflows (#6427)

* declarative workflow approval flow fix

* Update mcp handler cache construction

* fix method argument.

* Update dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeFunctionToolExecutor.cs

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Fix indentation

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Peter Ibekwe
2026-06-10 11:08:32 -07:00
committed by GitHub
Unverified
parent 60cc5ee4e4
commit 3753d938f5
4 changed files with 664 additions and 22 deletions
@@ -2,10 +2,10 @@
using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Security.Cryptography;
using System.Text;
using System.Text.Json;
using System.Threading;
@@ -39,7 +39,7 @@ public sealed class DefaultMcpToolHandler : IMcpToolHandler, IAsyncDisposable
private static readonly JsonWriterOptions s_toolListJsonWriterOptions = new() { Indented = true };
private readonly Func<string, CancellationToken, Task<HttpClient?>>? _httpClientProvider;
private readonly Dictionary<string, McpClient> _clients = [];
private readonly Dictionary<(string Url, string Label, string Connection, string HeadersHash), McpClient> _clients = [];
private readonly Dictionary<string, HttpClient> _ownedHttpClients = [];
private readonly SemaphoreSlim _clientLock = new(1, 1);
@@ -66,16 +66,15 @@ public sealed class DefaultMcpToolHandler : IMcpToolHandler, IAsyncDisposable
string? connectionName,
CancellationToken cancellationToken = default)
{
// TODO: Handle connectionName and server label appropriately when Hosted scenario supports them. For now, ignore
if (IsListToolsToolName(toolName))
{
ThrowIfListToolsArgumentsSpecified(arguments);
McpClient listToolsClient = await this.GetOrCreateClientAsync(serverUrl, serverLabel, headers, cancellationToken).ConfigureAwait(false);
McpClient listToolsClient = await this.GetOrCreateClientAsync(serverUrl, serverLabel, headers, connectionName, cancellationToken).ConfigureAwait(false);
IList<McpClientTool> tools = await listToolsClient.ListToolsAsync(cancellationToken: cancellationToken).ConfigureAwait(false);
return CreateListToolsResultContent(tools.Select(tool => tool.ProtocolTool));
}
McpClient client = await this.GetOrCreateClientAsync(serverUrl, serverLabel, headers, cancellationToken).ConfigureAwait(false);
McpClient client = await this.GetOrCreateClientAsync(serverUrl, serverLabel, headers, connectionName, cancellationToken).ConfigureAwait(false);
McpServerToolResultContent resultContent = new(Guid.NewGuid().ToString());
@@ -145,10 +144,11 @@ public sealed class DefaultMcpToolHandler : IMcpToolHandler, IAsyncDisposable
string serverUrl,
string? serverLabel,
IDictionary<string, string>? headers,
string? connectionName,
CancellationToken cancellationToken)
{
string normalizedUrl = serverUrl.Trim().ToUpperInvariant();
string clientCacheKey = $"{normalizedUrl}|{ComputeHeadersHash(headers)}";
string trimmedUrl = serverUrl.Trim();
var clientCacheKey = BuildCacheKey(trimmedUrl, serverLabel, connectionName, headers);
await this._clientLock.WaitAsync(cancellationToken).ConfigureAwait(false);
try
@@ -158,7 +158,7 @@ public sealed class DefaultMcpToolHandler : IMcpToolHandler, IAsyncDisposable
return existingClient;
}
McpClient newClient = await this.CreateClientAsync(serverUrl, serverLabel, headers, normalizedUrl, cancellationToken).ConfigureAwait(false);
McpClient newClient = await this.CreateClientAsync(trimmedUrl, serverLabel, headers, trimmedUrl, cancellationToken).ConfigureAwait(false);
this._clients[clientCacheKey] = newClient;
return newClient;
}
@@ -168,6 +168,19 @@ public sealed class DefaultMcpToolHandler : IMcpToolHandler, IAsyncDisposable
}
}
/// <summary>
/// Composes the key under which an <see cref="McpClient"/> is cached: a 4-tuple of
/// (trimmed serverUrl, serverLabel, connectionName, headers hash). Because every component
/// takes part in tuple equality, callers that differ in label, connection, or headers
/// get distinct <see cref="McpClient"/> instances even when they target the same URL.
/// </summary>
/// <param name="trimmedUrl">The server URL; used verbatim, so the caller is expected to have trimmed it.</param>
/// <param name="serverLabel">Optional server label; <see langword="null"/> is stored as an empty string.</param>
/// <param name="connectionName">Optional connection name; <see langword="null"/> is stored as an empty string.</param>
/// <param name="headers">Optional headers; contribute to the key through <see cref="ComputeHeadersHash"/>.</param>
/// <returns>The tuple used as the client cache key.</returns>
internal static (string Url, string Label, string Connection, string HeadersHash) BuildCacheKey(
    string trimmedUrl,
    string? serverLabel,
    string? connectionName,
    IDictionary<string, string>? headers)
{
    string label = serverLabel ?? string.Empty;
    string connection = connectionName ?? string.Empty;
    string headersHash = ComputeHeadersHash(headers);

    return (trimmedUrl, label, connection, headersHash);
}
private async Task<McpClient> CreateClientAsync(
string serverUrl,
string? serverLabel,
@@ -185,7 +198,12 @@ public sealed class DefaultMcpToolHandler : IMcpToolHandler, IAsyncDisposable
if (httpClient is null && !this._ownedHttpClients.TryGetValue(httpClientCacheKey, out httpClient))
{
httpClient = new HttpClient();
// Disable cookies so handler-level state (cookie jar) cannot cross the cache-key
// isolation boundary established by GetOrCreateClientAsync. The actual MCP auth
// travels via AdditionalHeaders (set per-transport below), not session cookies.
// CheckCertificateRevocationList satisfies CA5399 since we're explicitly constructing the handler.
HttpClientHandler handler = new() { UseCookies = false, CheckCertificateRevocationList = true };
httpClient = new HttpClient(handler);
this._ownedHttpClients[httpClientCacheKey] = httpClient;
}
@@ -202,26 +220,50 @@ public sealed class DefaultMcpToolHandler : IMcpToolHandler, IAsyncDisposable
return await McpClient.CreateAsync(transport, cancellationToken: cancellationToken).ConfigureAwait(false);
}
private static string ComputeHeadersHash(IDictionary<string, string>? headers)
/// <summary>
/// Computes a deterministic, order-independent hash of the header set.
/// Header names are lower-cased for case-insensitive matching (RFC 7230 §3.2).
/// Header values remain case-sensitive (RFC 7235 — credentials are case-sensitive).
/// </summary>
#pragma warning disable CA1308 // RFC 7230 §3.2 requires lower-cased header names for case-insensitive comparison; CA1308's uppercase preference does not apply here
internal static string ComputeHeadersHash(IDictionary<string, string>? headers)
{
if (headers is null || headers.Count == 0)
{
return string.Empty;
}
// Build a deterministic, sorted representation of the headers
// Within a single process lifetime, the hashcodes are consistent.
// This will ensure that the same set of headers always produces the same hash, regardless of order.
SortedDictionary<string, string> sorted = new(headers.ToDictionary(h => h.Key.ToUpperInvariant(), h => h.Value.ToUpperInvariant()));
int hashCode = 17;
foreach (KeyValuePair<string, string> kvp in sorted)
// Sort by lower-cased key for deterministic ordering, preserving value case.
SortedDictionary<string, string> sorted = new(StringComparer.Ordinal);
foreach (KeyValuePair<string, string> header in headers)
{
hashCode = (hashCode * 31) + StringComparer.OrdinalIgnoreCase.GetHashCode(kvp.Key);
hashCode = (hashCode * 31) + StringComparer.OrdinalIgnoreCase.GetHashCode(kvp.Value);
sorted[header.Key.ToLowerInvariant()] = header.Value;
}
return hashCode.ToString(CultureInfo.InvariantCulture);
StringBuilder payload = new();
foreach (KeyValuePair<string, string> kvp in sorted)
{
payload.Append(kvp.Key).Append(':').Append(kvp.Value).Append('\n');
}
byte[] inputBytes = Encoding.UTF8.GetBytes(payload.ToString());
#if NET5_0_OR_GREATER
byte[] hashBytes = SHA256.HashData(inputBytes);
#else
using SHA256 sha256 = SHA256.Create();
byte[] hashBytes = sha256.ComputeHash(inputBytes);
#endif
// Convert to hex string (compatible with net472/netstandard2.0)
StringBuilder hex = new(hashBytes.Length * 2);
foreach (byte b in hashBytes)
{
hex.Append(b.ToString("X2", System.Globalization.CultureInfo.InvariantCulture));
}
return hex.ToString();
}
#pragma warning restore CA1308
private static void ThrowIfListToolsArgumentsSpecified(IDictionary<string, object?>? arguments)
{
@@ -13,6 +13,7 @@ using Microsoft.Agents.AI.Workflows.Declarative.Kit;
using Microsoft.Agents.AI.Workflows.Declarative.PowerFx;
using Microsoft.Agents.ObjectModel;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Logging;
using Microsoft.Shared.Diagnostics;
namespace Microsoft.Agents.AI.Workflows.Declarative.ObjectModel;
@@ -27,6 +28,13 @@ internal sealed class InvokeFunctionToolExecutor(
WorkflowFormulaState state) :
DeclarativeActionExecutor<InvokeFunctionTool>(model, state)
{
private const string ApprovalSnapshotStateKey = nameof(_approvalSnapshot);
/// <summary>
/// Snapshot of evaluated parameters at approval-request time.
/// </summary>
private ApprovalSnapshot? _approvalSnapshot;
/// <summary>
/// Step identifiers for the function tool invocation workflow.
/// </summary>
@@ -69,6 +77,10 @@ internal sealed class InvokeFunctionToolExecutor(
// If approval is required, add user input request content
if (requireApproval)
{
// Snapshot the evaluated parameters.
// If state mutates during the approval window, the approved values are used on resume.
this._approvalSnapshot = new ApprovalSnapshot(functionName, arguments);
requestMessage.Contents.Add(new ToolApprovalRequestContent(this.Id, functionCall));
}
@@ -155,6 +167,31 @@ internal sealed class InvokeFunctionToolExecutor(
// Completes the action after processing the function result.
await context.RaiseCompletionEventAsync(this.Model, cancellationToken).ConfigureAwait(false);
// Clear the approval snapshot after the action completes so a subsequent
// execution of the same executor instance doesn't reuse stale data.
this._approvalSnapshot = null;
await context.QueueStateUpdateAsync<ApprovalSnapshot?>(ApprovalSnapshotStateKey, null, null, cancellationToken).ConfigureAwait(false);
}
/// <inheritdoc/>
/// <remarks>
/// Queues the current approval snapshot (which may be <see langword="null"/>) as a state update
/// before delegating to the base implementation, so the snapshot can be read back after a
/// checkpoint restore.
/// </remarks>
protected override async ValueTask OnCheckpointingAsync(IWorkflowContext context, CancellationToken cancellationToken = default)
{
    ApprovalSnapshot? pendingSnapshot = this._approvalSnapshot;
    await context.QueueStateUpdateAsync(ApprovalSnapshotStateKey, pendingSnapshot, null, cancellationToken).ConfigureAwait(false);

    await base.OnCheckpointingAsync(context, cancellationToken).ConfigureAwait(false);
}
/// <inheritdoc/>
/// <remarks>
/// Runs the base restore logic first, then reloads the approval snapshot from workflow state
/// into the in-memory field.
/// </remarks>
protected override async ValueTask OnCheckpointRestoredAsync(IWorkflowContext context, CancellationToken cancellationToken = default)
{
    await base.OnCheckpointRestoredAsync(context, cancellationToken).ConfigureAwait(false);

    ApprovalSnapshot? restoredSnapshot = await context.ReadStateAsync<ApprovalSnapshot>(ApprovalSnapshotStateKey, null, cancellationToken).ConfigureAwait(false);
    this._approvalSnapshot = restoredSnapshot;
}
/// <summary>
@@ -262,7 +299,24 @@ internal sealed class InvokeFunctionToolExecutor(
private async ValueTask<FunctionResultContent?> InvokeRegisteredFunctionAsync(CancellationToken cancellationToken)
{
string functionName = this.GetFunctionName();
string functionName;
Dictionary<string, object?>? arguments;
if (this._approvalSnapshot is { } snapshot)
{
// Use the snapshot captured at approval-request time so we invoke exactly what
// the user approved, even if Power Fx state has mutated during the approval window.
functionName = snapshot.FunctionName;
arguments = snapshot.Arguments;
}
else
{
// Fallback for checkpoints created before approval snapshots were introduced.
this.Logger.LogWarning("Approval snapshot missing for '{ActionId}'; falling back to expression re-evaluation.", this.Id);
functionName = this.GetFunctionName();
arguments = this.GetArguments();
}
AIFunction? function = agentProvider.Functions?.FirstOrDefault(
f => string.Equals(f.Name, functionName, StringComparison.Ordinal));
@@ -275,8 +329,7 @@ internal sealed class InvokeFunctionToolExecutor(
};
}
Dictionary<string, object?>? arguments = this.GetArguments();
AIFunctionArguments? functionArguments = arguments is null ? null : new AIFunctionArguments(arguments);
AIFunctionArguments? functionArguments = arguments is null ? null : new AIFunctionArguments(arguments.NormalizePortableValues());
object? result;
try
@@ -341,4 +394,13 @@ internal sealed class InvokeFunctionToolExecutor(
return result;
}
/// <summary>
/// Stores the evaluated parameters at approval-request time so that
/// <see cref="CaptureResponseAsync"/> uses the values the user reviewed,
/// even if <see cref="WorkflowFormulaState"/> mutates during the approval window.
/// </summary>
/// <param name="FunctionName">The function name as evaluated when the approval request was created.</param>
/// <param name="Arguments">
/// The arguments as evaluated when the approval request was created, or <see langword="null"/> when none were produced.
/// NOTE(review): the dictionary is held by reference, not copied — confirm no caller mutates it after the snapshot is taken.
/// </param>
internal sealed record ApprovalSnapshot(
string FunctionName,
Dictionary<string, object?>? Arguments);
}
@@ -321,6 +321,189 @@ public sealed class DefaultMcpToolHandlerTests
#endregion
#region ComputeHeadersHash Tests
[Fact]
public void ComputeHeadersHash_WithNullHeaders_ReturnsEmptyString()
{
    // Act + Assert: a null header set hashes to the empty string.
    DefaultMcpToolHandler.ComputeHeadersHash(headers: null).Should().BeEmpty();
}
[Fact]
public void ComputeHeadersHash_WithEmptyHeaders_ReturnsEmptyString()
{
    // Arrange
    Dictionary<string, string> noHeaders = new();

    // Act
    string hash = DefaultMcpToolHandler.ComputeHeadersHash(noHeaders);

    // Assert: an empty header set hashes to the empty string, same as null.
    hash.Should().BeEmpty();
}
[Fact]
public void ComputeHeadersHash_SameHeadersDifferentOrder_ReturnsSameHash()
{
    // Arrange: identical entries, inserted in opposite order.
    Dictionary<string, string> authorizationFirst = new()
    {
        ["Authorization"] = "Bearer token123",
        ["X-Custom"] = "value1"
    };
    Dictionary<string, string> customFirst = new()
    {
        ["X-Custom"] = "value1",
        ["Authorization"] = "Bearer token123"
    };

    // Act
    string authorizationFirstHash = DefaultMcpToolHandler.ComputeHeadersHash(authorizationFirst);
    string customFirstHash = DefaultMcpToolHandler.ComputeHeadersHash(customFirst);

    // Assert: insertion order must not influence the hash.
    authorizationFirstHash.Should().Be(customFirstHash);
}
[Fact]
public void ComputeHeadersHash_SameKeysDifferentCaseKeys_ReturnsSameHash()
{
    // Arrange — RFC 7230: header names are case-insensitive
    Dictionary<string, string> pascalCaseName = new() { ["Authorization"] = "Bearer token" };
    Dictionary<string, string> lowerCaseName = new() { ["authorization"] = "Bearer token" };

    // Act
    string pascalCaseHash = DefaultMcpToolHandler.ComputeHeadersHash(pascalCaseName);
    string lowerCaseHash = DefaultMcpToolHandler.ComputeHeadersHash(lowerCaseName);

    // Assert: the casing of the header name must not change the hash.
    pascalCaseHash.Should().Be(lowerCaseHash);
}
[Fact]
public void ComputeHeadersHash_SameKeysDifferentCaseValues_ReturnsDifferentHash()
{
    // Arrange — RFC 7235: credentials are case-sensitive
    Dictionary<string, string> upperCaseValue = new() { ["Authorization"] = "Bearer ABC" };
    Dictionary<string, string> lowerCaseValue = new() { ["Authorization"] = "Bearer abc" };

    // Act
    string upperCaseHash = DefaultMcpToolHandler.ComputeHeadersHash(upperCaseValue);
    string lowerCaseHash = DefaultMcpToolHandler.ComputeHeadersHash(lowerCaseValue);

    // Assert: values that differ only by case must hash differently.
    upperCaseHash.Should().NotBe(lowerCaseHash);
}
[Fact]
public void ComputeHeadersHash_DifferentHeaders_ReturnsDifferentHash()
{
    // Arrange: same header name, different token.
    Dictionary<string, string> firstToken = new() { ["Authorization"] = "Bearer token1" };
    Dictionary<string, string> secondToken = new() { ["Authorization"] = "Bearer token2" };

    // Act
    string firstHash = DefaultMcpToolHandler.ComputeHeadersHash(firstToken);
    string secondHash = DefaultMcpToolHandler.ComputeHeadersHash(secondToken);

    // Assert
    firstHash.Should().NotBe(secondHash);
}
#endregion
#region Cache Key Discrimination Tests
// These tests exercise BuildCacheKey directly because the integration path
// (InvokeToolAsync against a fake server) doesn't surface cache-hit behavior
// without standing up a real MCP server — McpClient.CreateAsync fails before
// _clients[key] = newClient runs, so nothing ever gets cached.
// Tuple equality on the returned 4-tuple verifies that the dimensions
// collectively discriminate cache entries.
[Fact]
public void BuildCacheKey_SameInputs_ReturnsEqualKeys()
{
    // Arrange
    const string Url = "http://localhost/mcp";
    Dictionary<string, string> sharedHeaders = new() { ["Authorization"] = "Bearer token" };

    // Act: build the key twice from identical inputs.
    var first = DefaultMcpToolHandler.BuildCacheKey(Url, "label", "conn", sharedHeaders);
    var second = DefaultMcpToolHandler.BuildCacheKey(Url, "label", "conn", sharedHeaders);

    // Assert
    first.Should().Be(second);
}
[Fact]
public void BuildCacheKey_DifferentConnectionName_ReturnsDifferentKeys()
{
    // Arrange: only the connection name varies between the two keys.
    const string Url = "http://localhost/mcp";
    const string Label = "label";

    // Act
    var keyA = DefaultMcpToolHandler.BuildCacheKey(Url, Label, "connection-a", null);
    var keyB = DefaultMcpToolHandler.BuildCacheKey(Url, Label, "connection-b", null);

    // Assert
    keyA.Should().NotBe(keyB);
    keyA.Connection.Should().Be("connection-a");
    keyB.Connection.Should().Be("connection-b");
}
[Fact]
public void BuildCacheKey_DifferentServerLabel_ReturnsDifferentKeys()
{
    // Arrange: only the server label varies between the two keys.
    const string Url = "http://localhost/mcp";

    // Act
    var keyA = DefaultMcpToolHandler.BuildCacheKey(Url, "label-a", null, null);
    var keyB = DefaultMcpToolHandler.BuildCacheKey(Url, "label-b", null, null);

    // Assert
    keyA.Should().NotBe(keyB);
    keyA.Label.Should().Be("label-a");
    keyB.Label.Should().Be("label-b");
}
[Fact]
public void BuildCacheKey_CaseSensitiveUrlPath_ReturnsDifferentKeys()
{
    // Arrange — RFC 3986: URL path is case-sensitive
    const string CapitalizedPath = "http://localhost/Tools";
    const string LowerCasePath = "http://localhost/tools";

    // Act
    var capitalizedKey = DefaultMcpToolHandler.BuildCacheKey(CapitalizedPath, null, null, null);
    var lowerCaseKey = DefaultMcpToolHandler.BuildCacheKey(LowerCasePath, null, null, null);

    // Assert: URLs differing only by path case must not share a cache entry.
    capitalizedKey.Should().NotBe(lowerCaseKey);
}
[Fact]
public void BuildCacheKey_HeaderValuesCaseSensitive_ReturnsDifferentKeys()
{
    // Arrange — RFC 7235: credentials are case-sensitive
    const string Url = "http://localhost/mcp";
    Dictionary<string, string> upperCaseCredential = new() { ["Authorization"] = "Bearer ABC" };
    Dictionary<string, string> lowerCaseCredential = new() { ["Authorization"] = "Bearer abc" };

    // Act
    var upperCaseKey = DefaultMcpToolHandler.BuildCacheKey(Url, null, null, upperCaseCredential);
    var lowerCaseKey = DefaultMcpToolHandler.BuildCacheKey(Url, null, null, lowerCaseCredential);

    // Assert — header value case must propagate into the cache key
    upperCaseKey.Should().NotBe(lowerCaseKey);
    upperCaseKey.HeadersHash.Should().NotBe(lowerCaseKey.HeadersHash);
}
[Fact]
public void BuildCacheKey_NullLabelAndConnection_NormalizesToEmptyString()
{
    // Act: every optional input is null.
    var (_, label, connection, headersHash) = DefaultMcpToolHandler.BuildCacheKey("http://localhost/mcp", null, null, null);

    // Assert — verifies null-safety contract callers rely on
    label.Should().BeEmpty();
    connection.Should().BeEmpty();
    headersHash.Should().BeEmpty();
}
#endregion
#region Reserved Tools/List Tests
[Fact]
@@ -1,11 +1,21 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Agents.AI.Workflows.Declarative.Events;
using Microsoft.Agents.AI.Workflows.Declarative.Kit;
using Microsoft.Agents.AI.Workflows.Declarative.ObjectModel;
using Microsoft.Agents.AI.Workflows.Declarative.PowerFx;
using Microsoft.Agents.ObjectModel;
using Microsoft.Extensions.AI;
using Microsoft.PowerFx.Types;
using Moq;
using ApprovalSnapshot = Microsoft.Agents.AI.Workflows.Declarative.ObjectModel.InvokeFunctionToolExecutor.ApprovalSnapshot;
namespace Microsoft.Agents.AI.Workflows.Declarative.UnitTests.ObjectModel;
@@ -261,6 +271,323 @@ public sealed class InvokeFunctionToolExecutorTest(ITestOutputHelper output) : W
#endregion
#region Approval Snapshot Security Tests
/// <summary>
/// Changing the variable that supplies the function name after the approval request
/// has been issued must not redirect the call: the function invoked on resume is the
/// one whose name was evaluated at approval-request time.
/// </summary>
[Fact]
public async Task InvokeFunctionToolCaptureResponseUsesApprovedFunctionNameNotMutatedAsync()
{
    // Arrange
    const string ApprovedFunctionName = "safe_readonly_query";
    const string MutatedFunctionName = "dangerous_admin_tool";
    const string VariableName = "TargetFunction";

    this.State.Set(VariableName, FormulaValue.New(ApprovedFunctionName));
    this.State.InitializeSystem();
    this.State.Bind();

    InvokeFunctionTool model = this.CreateModelWithVariableFunctionName(
        displayName: nameof(InvokeFunctionToolCaptureResponseUsesApprovedFunctionNameNotMutatedAsync),
        variableName: VariableName);

    // Both functions are registered so that a wrong dispatch is observable instead of failing lookup.
    string? invokedFunctionName = null;
    TestFunctionAgentProvider provider = new(
        [
            AIFunctionFactory.Create(() => "safe-result", name: ApprovedFunctionName),
            AIFunctionFactory.Create(() => "dangerous-result", name: MutatedFunctionName),
        ],
        onInvoke: name => invokedFunctionName = name);

    InvokeFunctionToolExecutor executor = new(model, provider, this.State);
    Mock<IWorkflowContext> workflowContext = CreateMockWorkflowContext();

    // Act - trigger ExecuteAsync to store the approval snapshot
    await executor.HandleAsync(new ActionExecutorResult(executor.Id), workflowContext.Object, CancellationToken.None);

    // Simulate parallel branch mutating state during the approval window
    this.State.Set(VariableName, FormulaValue.New(MutatedFunctionName));
    this.State.Bind();

    // User clicks approve (they saw "safe_readonly_query" in the approval UI), then the executor resumes
    ExternalInputResponse approval = CreateApprovalResponse(executor.Id, approved: true);
    await executor.CaptureResponseAsync(workflowContext.Object, approval, CancellationToken.None);

    // Assert - the originally-approved function must be invoked, not the mutated one
    Assert.NotNull(invokedFunctionName);
    Assert.Equal(ApprovedFunctionName, invokedFunctionName);
}
/// <summary>
/// Changing the variable that supplies an argument after the approval request has been
/// issued must not alter the call: the arguments passed on resume are the ones evaluated
/// at approval-request time.
/// </summary>
[Fact]
public async Task InvokeFunctionToolCaptureResponseUsesApprovedArgumentsNotMutatedAsync()
{
    // Arrange
    const string FunctionName = "process_query";
    const string ArgumentKey = "query";
    const string VariableName = "SqlQuery";
    const string ApprovedQuery = "SELECT * FROM users LIMIT 10";
    const string MutatedQuery = "DROP TABLE users CASCADE; --";

    this.State.Set(VariableName, FormulaValue.New(ApprovedQuery));
    this.State.InitializeSystem();
    this.State.Bind();

    InvokeFunctionTool model = this.CreateModelWithVariableArgument(
        displayName: nameof(InvokeFunctionToolCaptureResponseUsesApprovedArgumentsNotMutatedAsync),
        functionName: FunctionName,
        argumentKey: ArgumentKey,
        variableName: VariableName);

    AIFunctionArguments? receivedArguments = null;
    TestFunctionAgentProvider provider = new(
        [AIFunctionFactory.Create((string query) => $"executed:{query}", name: FunctionName)],
        onInvokeArguments: args => receivedArguments = args);

    InvokeFunctionToolExecutor executor = new(model, provider, this.State);
    Mock<IWorkflowContext> workflowContext = CreateMockWorkflowContext();

    // Act - trigger ExecuteAsync to store the approval snapshot
    await executor.HandleAsync(new ActionExecutorResult(executor.Id), workflowContext.Object, CancellationToken.None);

    // Simulate parallel branch mutating state during the approval window
    this.State.Set(VariableName, FormulaValue.New(MutatedQuery));
    this.State.Bind();

    // User clicks approve, then the executor resumes
    ExternalInputResponse approval = CreateApprovalResponse(executor.Id, approved: true);
    await executor.CaptureResponseAsync(workflowContext.Object, approval, CancellationToken.None);

    // Assert - the originally-approved argument must be used, not the mutated one
    Assert.NotNull(receivedArguments);
    Assert.Equal(ApprovedQuery, receivedArguments[ArgumentKey]?.ToString());
}
/// <summary>
/// Verifies that the approval snapshot survives a checkpoint/restore cycle.
/// After restore, the originally-approved function must still be used even if state was mutated.
/// </summary>
/// <remarks>
/// The steps are order-dependent: capture the snapshot, persist it through the checkpoint hook,
/// clear the in-memory field, restore it through the restore hook, and only then mutate the
/// variable. The final assertion therefore proves the restored snapshot, not the live state,
/// selects the function.
/// </remarks>
[Fact]
public async Task InvokeFunctionToolCaptureResponseUsesSnapshotAfterCheckpointRestoreAsync()
{
// Arrange
const string ApprovedFunctionName = "safe_readonly_query";
const string MutatedFunctionName = "dangerous_admin_tool";
this.State.Set("TargetFunction", FormulaValue.New(ApprovedFunctionName));
this.State.InitializeSystem();
this.State.Bind();
InvokeFunctionTool model = this.CreateModelWithVariableFunctionName(
displayName: nameof(InvokeFunctionToolCaptureResponseUsesSnapshotAfterCheckpointRestoreAsync),
variableName: "TargetFunction");
// Records which registered function actually ran; both names are registered so a wrong dispatch is observable.
string? capturedFunctionName = null;
TestFunctionAgentProvider testAgentProvider = new(
[
AIFunctionFactory.Create(() => "safe-result", name: ApprovedFunctionName),
AIFunctionFactory.Create(() => "dangerous-result", name: MutatedFunctionName),
],
onInvoke: name => capturedFunctionName = name);
InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State);
// Act - trigger ExecuteAsync to store the approval snapshot
// The state-store-backed mock is required here: it keeps queued snapshot updates so the restore hook can read them back.
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContextWithStateStore();
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
// Simulate checkpoint: persist to state store
await InvokeProtectedMethodAsync(action, "OnCheckpointingAsync", mockContext.Object, CancellationToken.None);
// Simulate restore on a "new" executor instance by clearing the in-memory field via reflection
// (In production, a new executor instance would be created with _approvalSnapshot == null)
typeof(InvokeFunctionToolExecutor)
.GetField("_approvalSnapshot", BindingFlags.NonPublic | BindingFlags.Instance)!
.SetValue(action, null);
// Restore from state store
await InvokeProtectedMethodAsync(action, "OnCheckpointRestoredAsync", mockContext.Object, CancellationToken.None);
// Mutate state after restore (simulating parallel branch)
this.State.Set("TargetFunction", FormulaValue.New(MutatedFunctionName));
this.State.Bind();
// User clicks approve
ExternalInputResponse response = CreateApprovalResponse(action.Id, approved: true);
// Resume after approval
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
// Assert - the originally-approved function must be invoked, not the mutated one
Assert.NotNull(capturedFunctionName);
Assert.Equal(ApprovedFunctionName, capturedFunctionName);
}
/// <summary>
/// Verifies that the approval snapshot is cleared after a completed approval cycle,
/// both in-memory and in the persisted state store. This prevents stale data from
/// influencing a subsequent execution of the same executor instance.
/// </summary>
[Fact]
public async Task InvokeFunctionToolCaptureResponseClearsSnapshotAfterCompletionAsync()
{
    // Arrange
    const string FunctionName = "any_function";

    // The executor derives its state key from nameof(_approvalSnapshot), so the private
    // field name and the persisted state key are the same string.
    const string SnapshotMemberName = "_approvalSnapshot";

    this.State.InitializeSystem();
    this.State.Bind();

    InvokeFunctionTool model = this.CreateModel(
        displayName: nameof(InvokeFunctionToolCaptureResponseClearsSnapshotAfterCompletionAsync),
        functionName: FunctionName,
        requireApproval: true);

    TestFunctionAgentProvider testAgentProvider = new(
        [AIFunctionFactory.Create(() => "result", name: FunctionName)]);

    InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State);

    // Act - run the full approval cycle
    Dictionary<string, object?> stateStore = [];
    Mock<IWorkflowContext> mockContext = CreateMockWorkflowContextWithStateStore(stateStore);
    await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);

    // Fail with a descriptive error (instead of a NullReferenceException) if the field is ever renamed.
    FieldInfo snapshotField = typeof(InvokeFunctionToolExecutor)
        .GetField(SnapshotMemberName, BindingFlags.NonPublic | BindingFlags.Instance)
        ?? throw new MissingFieldException(nameof(InvokeFunctionToolExecutor), SnapshotMemberName);

    // Sanity: snapshot was captured
    Assert.NotNull(snapshotField.GetValue(action));

    ExternalInputResponse response = CreateApprovalResponse(action.Id, approved: true);
    await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);

    // Assert - both in-memory field and persisted state are cleared.
    // The key must be present (an explicit clear was queued) and its value must be null.
    Assert.Null(snapshotField.GetValue(action));
    Assert.True(stateStore.TryGetValue(SnapshotMemberName, out object? persistedSnapshot));
    Assert.Null(persistedSnapshot);
}
/// <summary>
/// Builds an <see cref="ExternalInputResponse"/> carrying a single tool-approval answer for the
/// given action. The function call attached to the request uses the placeholder name "ignored".
/// </summary>
private static ExternalInputResponse CreateApprovalResponse(string actionId, bool approved)
{
    ToolApprovalRequestContent request = new(actionId, new FunctionCallContent(callId: actionId, name: "ignored"));
    ToolApprovalResponseContent answer = request.CreateResponse(approved);
    ChatMessage userMessage = new(ChatRole.User, [answer]);

    return new ExternalInputResponse(userMessage);
}
/// <summary>
/// Creates a mock workflow context whose event, state-update, and message calls are
/// configured to return an already-completed <see cref="ValueTask"/> and store nothing.
/// </summary>
private static Mock<IWorkflowContext> CreateMockWorkflowContext()
{
    ValueTask completed = default;
    Mock<IWorkflowContext> context = new();

    context
        .Setup(c => c.AddEventAsync(It.IsAny<WorkflowEvent>(), It.IsAny<CancellationToken>()))
        .Returns(completed);
    context
        .Setup(c => c.QueueStateUpdateAsync(It.IsAny<string>(), It.IsAny<object?>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
        .Returns(completed);
    context
        .Setup(c => c.SendMessageAsync(It.IsAny<object>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
        .Returns(completed);

    return context;
}
/// <summary>
/// Creates a mock workflow context backed by a dictionary, so <see cref="ApprovalSnapshot"/>
/// values queued through state updates can be read back (for checkpoint/restore tests).
/// Pass <paramref name="stateStore"/> to inspect what was persisted; otherwise a private
/// dictionary is used.
/// </summary>
private static Mock<IWorkflowContext> CreateMockWorkflowContextWithStateStore(Dictionary<string, object?>? stateStore = null)
{
    Dictionary<string, object?> store = stateStore ?? new Dictionary<string, object?>();
    ValueTask completed = default;
    Mock<IWorkflowContext> context = new();

    context
        .Setup(c => c.AddEventAsync(It.IsAny<WorkflowEvent>(), It.IsAny<CancellationToken>()))
        .Returns(completed);

    // Snapshot writes land in the dictionary under the key they were queued with.
    context
        .Setup(c => c.QueueStateUpdateAsync(It.IsAny<string>(), It.IsAny<ApprovalSnapshot?>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
        .Callback<string, ApprovalSnapshot?, string?, CancellationToken>((key, value, _, _) => store[key] = value)
        .Returns(completed);

    context
        .Setup(c => c.SendMessageAsync(It.IsAny<object>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
        .Returns(completed);

    // Snapshot reads yield the stored value, or null when the key is absent or holds a non-snapshot.
    context
        .Setup(c => c.ReadStateAsync<ApprovalSnapshot>(It.IsAny<string>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
        .Returns<string, string?, CancellationToken>((key, _, _) =>
        {
            store.TryGetValue(key, out object? stored);
            return new ValueTask<ApprovalSnapshot?>(stored as ApprovalSnapshot);
        });

    context
        .Setup(c => c.ReadStateKeysAsync(It.IsAny<string?>(), It.IsAny<CancellationToken>()))
        .ReturnsAsync(new HashSet<string>());

    return context;
}
/// <summary>
/// Invokes a non-public instance method on the executor via reflection (for testing checkpoint hooks)
/// and awaits the <see cref="ValueTask"/> it returns.
/// </summary>
/// <param name="action">The executor instance to invoke the method on.</param>
/// <param name="methodName">Name of the non-public instance method, e.g. <c>OnCheckpointingAsync</c>.</param>
/// <param name="context">Workflow context passed as the method's first argument.</param>
/// <param name="cancellationToken">Cancellation token passed as the method's second argument.</param>
/// <exception cref="MissingMethodException">No non-public instance method named <paramref name="methodName"/> exists.</exception>
/// <exception cref="InvalidOperationException">The method did not return a <see cref="ValueTask"/>.</exception>
private static async ValueTask InvokeProtectedMethodAsync(InvokeFunctionToolExecutor action, string methodName, IWorkflowContext context, CancellationToken cancellationToken)
{
    // A renamed or removed hook should fail with the method name, not an opaque NullReferenceException.
    MethodInfo method = typeof(InvokeFunctionToolExecutor)
        .GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Instance)
        ?? throw new MissingMethodException(nameof(InvokeFunctionToolExecutor), methodName);

    // Pattern-match instead of a hard cast so an unexpected return type is reported clearly.
    if (method.Invoke(action, [context, cancellationToken]) is not ValueTask pending)
    {
        throw new InvalidOperationException($"'{methodName}' did not return a {nameof(ValueTask)}.");
    }

    await pending.ConfigureAwait(false);
}
/// <summary>
/// Test double for <see cref="ResponseAgentProvider"/>: wraps each supplied
/// <see cref="AIFunction"/> in a recording decorator and reports every invocation
/// (name and, when present, arguments) to optional callbacks.
/// Used by the framework-invoke approval branch (<c>InvokeRegisteredFunctionAsync</c>).
/// All conversation and agent members throw <see cref="NotSupportedException"/>.
/// </summary>
private sealed class TestFunctionAgentProvider : ResponseAgentProvider
{
    private readonly Action<string>? _nameObserver;
    private readonly Action<AIFunctionArguments>? _argumentsObserver;

    public TestFunctionAgentProvider(
        IEnumerable<AIFunction> functions,
        Action<string>? onInvoke = null,
        Action<AIFunctionArguments>? onInvokeArguments = null)
    {
        this._nameObserver = onInvoke;
        this._argumentsObserver = onInvokeArguments;

        // Wrap every function so invocations are routed through RecordInvocation.
        List<AIFunction> recordingFunctions = [];
        foreach (AIFunction function in functions)
        {
            recordingFunctions.Add(new RecordingAIFunction(function, this));
        }

        this.Functions = recordingFunctions;
    }

    /// <summary>
    /// Forwards the invoked function name to the name callback and, when
    /// <paramref name="arguments"/> is not null, the arguments to the arguments callback.
    /// </summary>
    internal void RecordInvocation(string name, AIFunctionArguments? arguments)
    {
        this._nameObserver?.Invoke(name);

        if (arguments is null)
        {
            return;
        }

        this._argumentsObserver?.Invoke(arguments);
    }

    public override Task<string> CreateConversationAsync(CancellationToken cancellationToken = default)
    {
        throw new NotSupportedException();
    }

    public override Task<ChatMessage> CreateMessageAsync(string conversationId, ChatMessage conversationMessage, CancellationToken cancellationToken = default)
    {
        throw new NotSupportedException();
    }

    public override Task<ChatMessage> GetMessageAsync(string conversationId, string messageId, CancellationToken cancellationToken = default)
    {
        throw new NotSupportedException();
    }

    public override IAsyncEnumerable<AgentResponseUpdate> InvokeAgentAsync(
        string agentId, string? agentVersion, string? conversationId,
        IEnumerable<ChatMessage>? messages, IDictionary<string, object?>? inputArguments,
        CancellationToken cancellationToken = default)
    {
        throw new NotSupportedException();
    }

    public override IAsyncEnumerable<ChatMessage> GetMessagesAsync(
        string conversationId, int? limit = null, string? after = null, string? before = null,
        bool newestFirst = false, CancellationToken cancellationToken = default)
    {
        throw new NotSupportedException();
    }

    /// <summary>
    /// Decorator that notifies the owning provider before delegating to the wrapped function.
    /// </summary>
    private sealed class RecordingAIFunction(AIFunction inner, TestFunctionAgentProvider owner) : AIFunction
    {
        public override string Name
        {
            get { return inner.Name; }
        }

        public override string Description
        {
            get { return inner.Description; }
        }

        public override JsonElement JsonSchema
        {
            get { return inner.JsonSchema; }
        }

        protected override ValueTask<object?> InvokeCoreAsync(AIFunctionArguments arguments, CancellationToken cancellationToken)
        {
            // Record first so the observation happens even if the wrapped function throws.
            owner.RecordInvocation(inner.Name, arguments);
            return inner.InvokeAsync(arguments, cancellationToken);
        }
    }
}
#endregion
#region Helper Methods
private async Task ExecuteTestAsync(InvokeFunctionTool model)
@@ -318,5 +645,33 @@ public sealed class InvokeFunctionToolExecutorTest(ITestOutputHelper output) : W
return AssignParent<InvokeFunctionTool>(builder);
}
/// <summary>
/// Builds an approval-required <see cref="InvokeFunctionTool"/> whose function name is
/// read from the topic variable <paramref name="variableName"/> rather than a literal.
/// </summary>
private InvokeFunctionTool CreateModelWithVariableFunctionName(string displayName, string variableName)
{
    StringExpression.Builder functionNameFromVariable = new(
        StringExpression.Variable(PropertyPath.TopicVariable(variableName)));
    BoolExpression.Builder approvalRequired = new(BoolExpression.Literal(true));

    InvokeFunctionTool.Builder toolBuilder = new()
    {
        Id = this.CreateActionId(),
        DisplayName = this.FormatDisplayName(displayName),
        FunctionName = functionNameFromVariable,
        RequireApproval = approvalRequired,
    };

    return AssignParent<InvokeFunctionTool>(toolBuilder);
}
/// <summary>
/// Builds an approval-required <see cref="InvokeFunctionTool"/> with a literal function name
/// and a single argument whose value is read from the topic variable <paramref name="variableName"/>.
/// </summary>
private InvokeFunctionTool CreateModelWithVariableArgument(
    string displayName, string functionName, string argumentKey, string variableName)
{
    StringExpression.Builder literalFunctionName = new(StringExpression.Literal(functionName));
    BoolExpression.Builder approvalRequired = new(BoolExpression.Literal(true));

    InvokeFunctionTool.Builder toolBuilder = new()
    {
        Id = this.CreateActionId(),
        DisplayName = this.FormatDisplayName(displayName),
        FunctionName = literalFunctionName,
        RequireApproval = approvalRequired,
    };

    // The argument value is bound to the variable, so it is evaluated from state at execution time.
    toolBuilder.Arguments.Add(
        argumentKey,
        ValueExpression.Variable(PropertyPath.TopicVariable(variableName)));

    return AssignParent<InvokeFunctionTool>(toolBuilder);
}
#endregion
}