mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
.NET: Bug fixes for declarative workflows (#6427)
* declarative workflow approval flow fix * Update mcp handler cache construction * fix method argument. * Update dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeFunctionToolExecutor.cs Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fix identation --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
60cc5ee4e4
commit
3753d938f5
@@ -2,10 +2,10 @@
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Globalization;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Net.Http;
|
||||
using System.Security.Cryptography;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using System.Threading;
|
||||
@@ -39,7 +39,7 @@ public sealed class DefaultMcpToolHandler : IMcpToolHandler, IAsyncDisposable
|
||||
private static readonly JsonWriterOptions s_toolListJsonWriterOptions = new() { Indented = true };
|
||||
|
||||
private readonly Func<string, CancellationToken, Task<HttpClient?>>? _httpClientProvider;
|
||||
private readonly Dictionary<string, McpClient> _clients = [];
|
||||
private readonly Dictionary<(string Url, string Label, string Connection, string HeadersHash), McpClient> _clients = [];
|
||||
private readonly Dictionary<string, HttpClient> _ownedHttpClients = [];
|
||||
private readonly SemaphoreSlim _clientLock = new(1, 1);
|
||||
|
||||
@@ -66,16 +66,15 @@ public sealed class DefaultMcpToolHandler : IMcpToolHandler, IAsyncDisposable
|
||||
string? connectionName,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
// TODO: Handle connectionName and server label appropriately when Hosted scenario supports them. For now, ignore
|
||||
if (IsListToolsToolName(toolName))
|
||||
{
|
||||
ThrowIfListToolsArgumentsSpecified(arguments);
|
||||
McpClient listToolsClient = await this.GetOrCreateClientAsync(serverUrl, serverLabel, headers, cancellationToken).ConfigureAwait(false);
|
||||
McpClient listToolsClient = await this.GetOrCreateClientAsync(serverUrl, serverLabel, headers, connectionName, cancellationToken).ConfigureAwait(false);
|
||||
IList<McpClientTool> tools = await listToolsClient.ListToolsAsync(cancellationToken: cancellationToken).ConfigureAwait(false);
|
||||
return CreateListToolsResultContent(tools.Select(tool => tool.ProtocolTool));
|
||||
}
|
||||
|
||||
McpClient client = await this.GetOrCreateClientAsync(serverUrl, serverLabel, headers, cancellationToken).ConfigureAwait(false);
|
||||
McpClient client = await this.GetOrCreateClientAsync(serverUrl, serverLabel, headers, connectionName, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
McpServerToolResultContent resultContent = new(Guid.NewGuid().ToString());
|
||||
|
||||
@@ -145,10 +144,11 @@ public sealed class DefaultMcpToolHandler : IMcpToolHandler, IAsyncDisposable
|
||||
string serverUrl,
|
||||
string? serverLabel,
|
||||
IDictionary<string, string>? headers,
|
||||
string? connectionName,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
string normalizedUrl = serverUrl.Trim().ToUpperInvariant();
|
||||
string clientCacheKey = $"{normalizedUrl}|{ComputeHeadersHash(headers)}";
|
||||
string trimmedUrl = serverUrl.Trim();
|
||||
var clientCacheKey = BuildCacheKey(trimmedUrl, serverLabel, connectionName, headers);
|
||||
|
||||
await this._clientLock.WaitAsync(cancellationToken).ConfigureAwait(false);
|
||||
try
|
||||
@@ -158,7 +158,7 @@ public sealed class DefaultMcpToolHandler : IMcpToolHandler, IAsyncDisposable
|
||||
return existingClient;
|
||||
}
|
||||
|
||||
McpClient newClient = await this.CreateClientAsync(serverUrl, serverLabel, headers, normalizedUrl, cancellationToken).ConfigureAwait(false);
|
||||
McpClient newClient = await this.CreateClientAsync(trimmedUrl, serverLabel, headers, trimmedUrl, cancellationToken).ConfigureAwait(false);
|
||||
this._clients[clientCacheKey] = newClient;
|
||||
return newClient;
|
||||
}
|
||||
@@ -168,6 +168,19 @@ public sealed class DefaultMcpToolHandler : IMcpToolHandler, IAsyncDisposable
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Builds the per-client cache key as a 4-tuple of
|
||||
/// (trimmed serverUrl, serverLabel, connectionName, headers hash). All four components
|
||||
/// participate so that callers using different labels/connections/headers receive
|
||||
/// distinct <see cref="McpClient"/> instances even when targeting the same URL.
|
||||
/// </summary>
|
||||
internal static (string Url, string Label, string Connection, string HeadersHash) BuildCacheKey(
|
||||
string trimmedUrl,
|
||||
string? serverLabel,
|
||||
string? connectionName,
|
||||
IDictionary<string, string>? headers) =>
|
||||
(trimmedUrl, serverLabel ?? string.Empty, connectionName ?? string.Empty, ComputeHeadersHash(headers));
|
||||
|
||||
private async Task<McpClient> CreateClientAsync(
|
||||
string serverUrl,
|
||||
string? serverLabel,
|
||||
@@ -185,7 +198,12 @@ public sealed class DefaultMcpToolHandler : IMcpToolHandler, IAsyncDisposable
|
||||
|
||||
if (httpClient is null && !this._ownedHttpClients.TryGetValue(httpClientCacheKey, out httpClient))
|
||||
{
|
||||
httpClient = new HttpClient();
|
||||
// Disable cookies so handler-level state (cookie jar) cannot cross the cache-key
|
||||
// isolation boundary established by GetOrCreateClientAsync. The actual MCP auth
|
||||
// travels via AdditionalHeaders (set per-transport below), not session cookies.
|
||||
// CheckCertificateRevocationList satisfies CA5399 since we're explicitly constructing the handler.
|
||||
HttpClientHandler handler = new() { UseCookies = false, CheckCertificateRevocationList = true };
|
||||
httpClient = new HttpClient(handler);
|
||||
this._ownedHttpClients[httpClientCacheKey] = httpClient;
|
||||
}
|
||||
|
||||
@@ -202,26 +220,50 @@ public sealed class DefaultMcpToolHandler : IMcpToolHandler, IAsyncDisposable
|
||||
return await McpClient.CreateAsync(transport, cancellationToken: cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
private static string ComputeHeadersHash(IDictionary<string, string>? headers)
|
||||
/// <summary>
|
||||
/// Computes a deterministic, order-independent hash of the header set.
|
||||
/// Header names are lower-cased for case-insensitive matching (RFC 7230 §3.2).
|
||||
/// Header values remain case-sensitive (RFC 7235 — credentials are case-sensitive).
|
||||
/// </summary>
|
||||
#pragma warning disable CA1308 // RFC 7230 §3.2 requires lower-cased header names for case-insensitive comparison; CA1308's uppercase preference does not apply here
|
||||
internal static string ComputeHeadersHash(IDictionary<string, string>? headers)
|
||||
{
|
||||
if (headers is null || headers.Count == 0)
|
||||
{
|
||||
return string.Empty;
|
||||
}
|
||||
|
||||
// Build a deterministic, sorted representation of the headers
|
||||
// Within a single process lifetime, the hashcodes are consistent.
|
||||
// This will ensure that the same set of headers always produces the same hash, regardless of order.
|
||||
SortedDictionary<string, string> sorted = new(headers.ToDictionary(h => h.Key.ToUpperInvariant(), h => h.Value.ToUpperInvariant()));
|
||||
int hashCode = 17;
|
||||
foreach (KeyValuePair<string, string> kvp in sorted)
|
||||
// Sort by lower-cased key for deterministic ordering, preserving value case.
|
||||
SortedDictionary<string, string> sorted = new(StringComparer.Ordinal);
|
||||
foreach (KeyValuePair<string, string> header in headers)
|
||||
{
|
||||
hashCode = (hashCode * 31) + StringComparer.OrdinalIgnoreCase.GetHashCode(kvp.Key);
|
||||
hashCode = (hashCode * 31) + StringComparer.OrdinalIgnoreCase.GetHashCode(kvp.Value);
|
||||
sorted[header.Key.ToLowerInvariant()] = header.Value;
|
||||
}
|
||||
|
||||
return hashCode.ToString(CultureInfo.InvariantCulture);
|
||||
StringBuilder payload = new();
|
||||
foreach (KeyValuePair<string, string> kvp in sorted)
|
||||
{
|
||||
payload.Append(kvp.Key).Append(':').Append(kvp.Value).Append('\n');
|
||||
}
|
||||
|
||||
byte[] inputBytes = Encoding.UTF8.GetBytes(payload.ToString());
|
||||
#if NET5_0_OR_GREATER
|
||||
byte[] hashBytes = SHA256.HashData(inputBytes);
|
||||
#else
|
||||
using SHA256 sha256 = SHA256.Create();
|
||||
byte[] hashBytes = sha256.ComputeHash(inputBytes);
|
||||
#endif
|
||||
|
||||
// Convert to hex string (compatible with net472/netstandard2.0)
|
||||
StringBuilder hex = new(hashBytes.Length * 2);
|
||||
foreach (byte b in hashBytes)
|
||||
{
|
||||
hex.Append(b.ToString("X2", System.Globalization.CultureInfo.InvariantCulture));
|
||||
}
|
||||
|
||||
return hex.ToString();
|
||||
}
|
||||
#pragma warning restore CA1308
|
||||
|
||||
private static void ThrowIfListToolsArgumentsSpecified(IDictionary<string, object?>? arguments)
|
||||
{
|
||||
|
||||
+65
-3
@@ -13,6 +13,7 @@ using Microsoft.Agents.AI.Workflows.Declarative.Kit;
|
||||
using Microsoft.Agents.AI.Workflows.Declarative.PowerFx;
|
||||
using Microsoft.Agents.ObjectModel;
|
||||
using Microsoft.Extensions.AI;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Shared.Diagnostics;
|
||||
|
||||
namespace Microsoft.Agents.AI.Workflows.Declarative.ObjectModel;
|
||||
@@ -27,6 +28,13 @@ internal sealed class InvokeFunctionToolExecutor(
|
||||
WorkflowFormulaState state) :
|
||||
DeclarativeActionExecutor<InvokeFunctionTool>(model, state)
|
||||
{
|
||||
private const string ApprovalSnapshotStateKey = nameof(_approvalSnapshot);
|
||||
|
||||
/// <summary>
|
||||
/// Snapshot of evaluated parameters at approval-request time.
|
||||
/// </summary>
|
||||
private ApprovalSnapshot? _approvalSnapshot;
|
||||
|
||||
/// <summary>
|
||||
/// Step identifiers for the function tool invocation workflow.
|
||||
/// </summary>
|
||||
@@ -69,6 +77,10 @@ internal sealed class InvokeFunctionToolExecutor(
|
||||
// If approval is required, add user input request content
|
||||
if (requireApproval)
|
||||
{
|
||||
// Snapshot the evaluated parameters.
|
||||
// If state mutates during the approval window, the approved values are used on resume.
|
||||
this._approvalSnapshot = new ApprovalSnapshot(functionName, arguments);
|
||||
|
||||
requestMessage.Contents.Add(new ToolApprovalRequestContent(this.Id, functionCall));
|
||||
}
|
||||
|
||||
@@ -155,6 +167,31 @@ internal sealed class InvokeFunctionToolExecutor(
|
||||
|
||||
// Completes the action after processing the function result.
|
||||
await context.RaiseCompletionEventAsync(this.Model, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
// Clear the approval snapshot after the action completes so a subsequent
|
||||
// execution of the same executor instance doesn't reuse stale data.
|
||||
this._approvalSnapshot = null;
|
||||
await context.QueueStateUpdateAsync<ApprovalSnapshot?>(ApprovalSnapshotStateKey, null, null, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
/// <remarks>
|
||||
/// Persists the approval snapshot to workflow state so it survives checkpoint/restore cycles.
|
||||
/// </remarks>
|
||||
protected override async ValueTask OnCheckpointingAsync(IWorkflowContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
await context.QueueStateUpdateAsync(ApprovalSnapshotStateKey, this._approvalSnapshot, null, cancellationToken).ConfigureAwait(false);
|
||||
await base.OnCheckpointingAsync(context, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
/// <remarks>
|
||||
/// Restores the approval snapshot from workflow state after a checkpoint restore.
|
||||
/// </remarks>
|
||||
protected override async ValueTask OnCheckpointRestoredAsync(IWorkflowContext context, CancellationToken cancellationToken = default)
|
||||
{
|
||||
await base.OnCheckpointRestoredAsync(context, cancellationToken).ConfigureAwait(false);
|
||||
this._approvalSnapshot = await context.ReadStateAsync<ApprovalSnapshot>(ApprovalSnapshotStateKey, null, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
@@ -262,7 +299,24 @@ internal sealed class InvokeFunctionToolExecutor(
|
||||
|
||||
private async ValueTask<FunctionResultContent?> InvokeRegisteredFunctionAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
string functionName = this.GetFunctionName();
|
||||
string functionName;
|
||||
Dictionary<string, object?>? arguments;
|
||||
|
||||
if (this._approvalSnapshot is { } snapshot)
|
||||
{
|
||||
// Use the snapshot captured at approval-request time so we invoke exactly what
|
||||
// the user approved, even if Power Fx state has mutated during the approval window.
|
||||
functionName = snapshot.FunctionName;
|
||||
arguments = snapshot.Arguments;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Fallback for checkpoints created before approval snapshots were introduced.
|
||||
this.Logger.LogWarning("Approval snapshot missing for '{ActionId}'; falling back to expression re-evaluation.", this.Id);
|
||||
functionName = this.GetFunctionName();
|
||||
arguments = this.GetArguments();
|
||||
}
|
||||
|
||||
AIFunction? function = agentProvider.Functions?.FirstOrDefault(
|
||||
f => string.Equals(f.Name, functionName, StringComparison.Ordinal));
|
||||
|
||||
@@ -275,8 +329,7 @@ internal sealed class InvokeFunctionToolExecutor(
|
||||
};
|
||||
}
|
||||
|
||||
Dictionary<string, object?>? arguments = this.GetArguments();
|
||||
AIFunctionArguments? functionArguments = arguments is null ? null : new AIFunctionArguments(arguments);
|
||||
AIFunctionArguments? functionArguments = arguments is null ? null : new AIFunctionArguments(arguments.NormalizePortableValues());
|
||||
|
||||
object? result;
|
||||
try
|
||||
@@ -341,4 +394,13 @@ internal sealed class InvokeFunctionToolExecutor(
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Stores the evaluated parameters at approval-request time so that
|
||||
/// <see cref="CaptureResponseAsync"/> uses the values the user reviewed,
|
||||
/// even if <see cref="WorkflowFormulaState"/> mutates during the approval window.
|
||||
/// </summary>
|
||||
internal sealed record ApprovalSnapshot(
|
||||
string FunctionName,
|
||||
Dictionary<string, object?>? Arguments);
|
||||
}
|
||||
|
||||
+183
@@ -321,6 +321,189 @@ public sealed class DefaultMcpToolHandlerTests
|
||||
|
||||
#endregion
|
||||
|
||||
#region ComputeHeadersHash Tests
|
||||
|
||||
[Fact]
|
||||
public void ComputeHeadersHash_WithNullHeaders_ReturnsEmptyString()
|
||||
{
|
||||
// Act
|
||||
string result = DefaultMcpToolHandler.ComputeHeadersHash(null);
|
||||
|
||||
// Assert
|
||||
result.Should().BeEmpty();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ComputeHeadersHash_WithEmptyHeaders_ReturnsEmptyString()
|
||||
{
|
||||
// Act
|
||||
string result = DefaultMcpToolHandler.ComputeHeadersHash(new Dictionary<string, string>());
|
||||
|
||||
// Assert
|
||||
result.Should().BeEmpty();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ComputeHeadersHash_SameHeadersDifferentOrder_ReturnsSameHash()
|
||||
{
|
||||
// Arrange
|
||||
Dictionary<string, string> headers1 = new()
|
||||
{
|
||||
["Authorization"] = "Bearer token123",
|
||||
["X-Custom"] = "value1"
|
||||
};
|
||||
Dictionary<string, string> headers2 = new()
|
||||
{
|
||||
["X-Custom"] = "value1",
|
||||
["Authorization"] = "Bearer token123"
|
||||
};
|
||||
|
||||
// Act
|
||||
string hash1 = DefaultMcpToolHandler.ComputeHeadersHash(headers1);
|
||||
string hash2 = DefaultMcpToolHandler.ComputeHeadersHash(headers2);
|
||||
|
||||
// Assert
|
||||
hash1.Should().Be(hash2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ComputeHeadersHash_SameKeysDifferentCaseKeys_ReturnsSameHash()
|
||||
{
|
||||
// Arrange — RFC 7230: header names are case-insensitive
|
||||
Dictionary<string, string> headers1 = new() { ["Authorization"] = "Bearer token" };
|
||||
Dictionary<string, string> headers2 = new() { ["authorization"] = "Bearer token" };
|
||||
|
||||
// Act
|
||||
string hash1 = DefaultMcpToolHandler.ComputeHeadersHash(headers1);
|
||||
string hash2 = DefaultMcpToolHandler.ComputeHeadersHash(headers2);
|
||||
|
||||
// Assert
|
||||
hash1.Should().Be(hash2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ComputeHeadersHash_SameKeysDifferentCaseValues_ReturnsDifferentHash()
|
||||
{
|
||||
// Arrange — RFC 7235: credentials are case-sensitive
|
||||
Dictionary<string, string> headers1 = new() { ["Authorization"] = "Bearer ABC" };
|
||||
Dictionary<string, string> headers2 = new() { ["Authorization"] = "Bearer abc" };
|
||||
|
||||
// Act
|
||||
string hash1 = DefaultMcpToolHandler.ComputeHeadersHash(headers1);
|
||||
string hash2 = DefaultMcpToolHandler.ComputeHeadersHash(headers2);
|
||||
|
||||
// Assert
|
||||
hash1.Should().NotBe(hash2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ComputeHeadersHash_DifferentHeaders_ReturnsDifferentHash()
|
||||
{
|
||||
// Arrange
|
||||
Dictionary<string, string> headers1 = new() { ["Authorization"] = "Bearer token1" };
|
||||
Dictionary<string, string> headers2 = new() { ["Authorization"] = "Bearer token2" };
|
||||
|
||||
// Act
|
||||
string hash1 = DefaultMcpToolHandler.ComputeHeadersHash(headers1);
|
||||
string hash2 = DefaultMcpToolHandler.ComputeHeadersHash(headers2);
|
||||
|
||||
// Assert
|
||||
hash1.Should().NotBe(hash2);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Cache Key Discrimination Tests
|
||||
|
||||
// These tests exercise BuildCacheKey directly because the integration path
|
||||
// (InvokeToolAsync against a fake server) doesn't surface cache-hit behavior
|
||||
// without standing up a real MCP server — McpClient.CreateAsync fails before
|
||||
// _clients[key] = newClient runs, so nothing ever gets cached.
|
||||
// Tuple equality on the returned 4-tuple verifies that the dimensions
|
||||
// collectively discriminate cache entries.
|
||||
|
||||
[Fact]
|
||||
public void BuildCacheKey_SameInputs_ReturnsEqualKeys()
|
||||
{
|
||||
// Arrange
|
||||
Dictionary<string, string> headers = new() { ["Authorization"] = "Bearer token" };
|
||||
|
||||
// Act
|
||||
var key1 = DefaultMcpToolHandler.BuildCacheKey("http://localhost/mcp", "label", "conn", headers);
|
||||
var key2 = DefaultMcpToolHandler.BuildCacheKey("http://localhost/mcp", "label", "conn", headers);
|
||||
|
||||
// Assert
|
||||
key1.Should().Be(key2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void BuildCacheKey_DifferentConnectionName_ReturnsDifferentKeys()
|
||||
{
|
||||
// Act
|
||||
var key1 = DefaultMcpToolHandler.BuildCacheKey("http://localhost/mcp", "label", "connection-a", null);
|
||||
var key2 = DefaultMcpToolHandler.BuildCacheKey("http://localhost/mcp", "label", "connection-b", null);
|
||||
|
||||
// Assert
|
||||
key1.Should().NotBe(key2);
|
||||
key1.Connection.Should().Be("connection-a");
|
||||
key2.Connection.Should().Be("connection-b");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void BuildCacheKey_DifferentServerLabel_ReturnsDifferentKeys()
|
||||
{
|
||||
// Act
|
||||
var key1 = DefaultMcpToolHandler.BuildCacheKey("http://localhost/mcp", "label-a", null, null);
|
||||
var key2 = DefaultMcpToolHandler.BuildCacheKey("http://localhost/mcp", "label-b", null, null);
|
||||
|
||||
// Assert
|
||||
key1.Should().NotBe(key2);
|
||||
key1.Label.Should().Be("label-a");
|
||||
key2.Label.Should().Be("label-b");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void BuildCacheKey_CaseSensitiveUrlPath_ReturnsDifferentKeys()
|
||||
{
|
||||
// Arrange — RFC 3986: URL path is case-sensitive
|
||||
// Act
|
||||
var key1 = DefaultMcpToolHandler.BuildCacheKey("http://localhost/Tools", null, null, null);
|
||||
var key2 = DefaultMcpToolHandler.BuildCacheKey("http://localhost/tools", null, null, null);
|
||||
|
||||
// Assert
|
||||
key1.Should().NotBe(key2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void BuildCacheKey_HeaderValuesCaseSensitive_ReturnsDifferentKeys()
|
||||
{
|
||||
// Arrange — RFC 7235: credentials are case-sensitive
|
||||
Dictionary<string, string> headers1 = new() { ["Authorization"] = "Bearer ABC" };
|
||||
Dictionary<string, string> headers2 = new() { ["Authorization"] = "Bearer abc" };
|
||||
|
||||
// Act
|
||||
var key1 = DefaultMcpToolHandler.BuildCacheKey("http://localhost/mcp", null, null, headers1);
|
||||
var key2 = DefaultMcpToolHandler.BuildCacheKey("http://localhost/mcp", null, null, headers2);
|
||||
|
||||
// Assert — header value case must propagate into the cache key
|
||||
key1.Should().NotBe(key2);
|
||||
key1.HeadersHash.Should().NotBe(key2.HeadersHash);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void BuildCacheKey_NullLabelAndConnection_NormalizesToEmptyString()
|
||||
{
|
||||
// Act
|
||||
var key = DefaultMcpToolHandler.BuildCacheKey("http://localhost/mcp", null, null, null);
|
||||
|
||||
// Assert — verifies null-safety contract callers rely on
|
||||
key.Label.Should().BeEmpty();
|
||||
key.Connection.Should().BeEmpty();
|
||||
key.HeadersHash.Should().BeEmpty();
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Reserved Tools/List Tests
|
||||
|
||||
[Fact]
|
||||
|
||||
+355
@@ -1,11 +1,21 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Reflection;
|
||||
using System.Text.Json;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Agents.AI.Workflows.Declarative.Events;
|
||||
using Microsoft.Agents.AI.Workflows.Declarative.Kit;
|
||||
using Microsoft.Agents.AI.Workflows.Declarative.ObjectModel;
|
||||
using Microsoft.Agents.AI.Workflows.Declarative.PowerFx;
|
||||
using Microsoft.Agents.ObjectModel;
|
||||
using Microsoft.Extensions.AI;
|
||||
using Microsoft.PowerFx.Types;
|
||||
using Moq;
|
||||
using ApprovalSnapshot = Microsoft.Agents.AI.Workflows.Declarative.ObjectModel.InvokeFunctionToolExecutor.ApprovalSnapshot;
|
||||
|
||||
namespace Microsoft.Agents.AI.Workflows.Declarative.UnitTests.ObjectModel;
|
||||
|
||||
@@ -261,6 +271,323 @@ public sealed class InvokeFunctionToolExecutorTest(ITestOutputHelper output) : W
|
||||
|
||||
#endregion
|
||||
|
||||
#region Approval Snapshot Security Tests
|
||||
|
||||
/// <summary>
|
||||
/// Verifies that mutating the function-name variable after approval does not change
|
||||
/// which function is actually invoked. The originally-approved name must be used.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
public async Task InvokeFunctionToolCaptureResponseUsesApprovedFunctionNameNotMutatedAsync()
|
||||
{
|
||||
// Arrange
|
||||
const string ApprovedFunctionName = "safe_readonly_query";
|
||||
const string MutatedFunctionName = "dangerous_admin_tool";
|
||||
|
||||
this.State.Set("TargetFunction", FormulaValue.New(ApprovedFunctionName));
|
||||
this.State.InitializeSystem();
|
||||
this.State.Bind();
|
||||
|
||||
InvokeFunctionTool model = this.CreateModelWithVariableFunctionName(
|
||||
displayName: nameof(InvokeFunctionToolCaptureResponseUsesApprovedFunctionNameNotMutatedAsync),
|
||||
variableName: "TargetFunction");
|
||||
|
||||
string? capturedFunctionName = null;
|
||||
TestFunctionAgentProvider testAgentProvider = new(
|
||||
[
|
||||
AIFunctionFactory.Create(() => "safe-result", name: ApprovedFunctionName),
|
||||
AIFunctionFactory.Create(() => "dangerous-result", name: MutatedFunctionName),
|
||||
],
|
||||
onInvoke: name => capturedFunctionName = name);
|
||||
InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State);
|
||||
|
||||
// Act - trigger ExecuteAsync to store the approval snapshot
|
||||
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext();
|
||||
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
|
||||
|
||||
// Simulate parallel branch mutating state during the approval window
|
||||
this.State.Set("TargetFunction", FormulaValue.New(MutatedFunctionName));
|
||||
this.State.Bind();
|
||||
|
||||
// User clicks approve (they saw "safe_readonly_query" in the approval UI)
|
||||
ExternalInputResponse response = CreateApprovalResponse(action.Id, approved: true);
|
||||
|
||||
// Resume after approval
|
||||
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
|
||||
|
||||
// Assert - the originally-approved function must be invoked, not the mutated one
|
||||
Assert.NotNull(capturedFunctionName);
|
||||
Assert.Equal(ApprovedFunctionName, capturedFunctionName);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Verifies that mutating an argument variable after approval does not change
|
||||
/// the arguments actually passed to the invoked function.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
public async Task InvokeFunctionToolCaptureResponseUsesApprovedArgumentsNotMutatedAsync()
|
||||
{
|
||||
// Arrange
|
||||
const string FunctionName = "process_query";
|
||||
const string ArgumentKey = "query";
|
||||
const string ApprovedQuery = "SELECT * FROM users LIMIT 10";
|
||||
const string MutatedQuery = "DROP TABLE users CASCADE; --";
|
||||
|
||||
this.State.Set("SqlQuery", FormulaValue.New(ApprovedQuery));
|
||||
this.State.InitializeSystem();
|
||||
this.State.Bind();
|
||||
|
||||
InvokeFunctionTool model = this.CreateModelWithVariableArgument(
|
||||
displayName: nameof(InvokeFunctionToolCaptureResponseUsesApprovedArgumentsNotMutatedAsync),
|
||||
functionName: FunctionName,
|
||||
argumentKey: ArgumentKey,
|
||||
variableName: "SqlQuery");
|
||||
|
||||
AIFunctionArguments? capturedArguments = null;
|
||||
TestFunctionAgentProvider testAgentProvider = new(
|
||||
[AIFunctionFactory.Create((string query) => $"executed:{query}", name: FunctionName)],
|
||||
onInvokeArguments: args => capturedArguments = args);
|
||||
InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State);
|
||||
|
||||
// Act - trigger ExecuteAsync to store the approval snapshot
|
||||
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContext();
|
||||
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
|
||||
|
||||
// Simulate parallel branch mutating state during the approval window
|
||||
this.State.Set("SqlQuery", FormulaValue.New(MutatedQuery));
|
||||
this.State.Bind();
|
||||
|
||||
// User clicks approve
|
||||
ExternalInputResponse response = CreateApprovalResponse(action.Id, approved: true);
|
||||
|
||||
// Resume after approval
|
||||
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
|
||||
|
||||
// Assert - the originally-approved argument must be used, not the mutated one
|
||||
Assert.NotNull(capturedArguments);
|
||||
Assert.Equal(ApprovedQuery, capturedArguments[ArgumentKey]?.ToString());
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Verifies that the approval snapshot survives a checkpoint/restore cycle.
|
||||
/// After restore, the originally-approved function must still be used even if state was mutated.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
public async Task InvokeFunctionToolCaptureResponseUsesSnapshotAfterCheckpointRestoreAsync()
|
||||
{
|
||||
// Arrange
|
||||
const string ApprovedFunctionName = "safe_readonly_query";
|
||||
const string MutatedFunctionName = "dangerous_admin_tool";
|
||||
|
||||
this.State.Set("TargetFunction", FormulaValue.New(ApprovedFunctionName));
|
||||
this.State.InitializeSystem();
|
||||
this.State.Bind();
|
||||
|
||||
InvokeFunctionTool model = this.CreateModelWithVariableFunctionName(
|
||||
displayName: nameof(InvokeFunctionToolCaptureResponseUsesSnapshotAfterCheckpointRestoreAsync),
|
||||
variableName: "TargetFunction");
|
||||
|
||||
string? capturedFunctionName = null;
|
||||
TestFunctionAgentProvider testAgentProvider = new(
|
||||
[
|
||||
AIFunctionFactory.Create(() => "safe-result", name: ApprovedFunctionName),
|
||||
AIFunctionFactory.Create(() => "dangerous-result", name: MutatedFunctionName),
|
||||
],
|
||||
onInvoke: name => capturedFunctionName = name);
|
||||
InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State);
|
||||
|
||||
// Act - trigger ExecuteAsync to store the approval snapshot
|
||||
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContextWithStateStore();
|
||||
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
|
||||
|
||||
// Simulate checkpoint: persist to state store
|
||||
await InvokeProtectedMethodAsync(action, "OnCheckpointingAsync", mockContext.Object, CancellationToken.None);
|
||||
|
||||
// Simulate restore on a "new" executor instance by clearing the in-memory field via reflection
|
||||
// (In production, a new executor instance would be created with _approvalSnapshot == null)
|
||||
typeof(InvokeFunctionToolExecutor)
|
||||
.GetField("_approvalSnapshot", BindingFlags.NonPublic | BindingFlags.Instance)!
|
||||
.SetValue(action, null);
|
||||
|
||||
// Restore from state store
|
||||
await InvokeProtectedMethodAsync(action, "OnCheckpointRestoredAsync", mockContext.Object, CancellationToken.None);
|
||||
|
||||
// Mutate state after restore (simulating parallel branch)
|
||||
this.State.Set("TargetFunction", FormulaValue.New(MutatedFunctionName));
|
||||
this.State.Bind();
|
||||
|
||||
// User clicks approve
|
||||
ExternalInputResponse response = CreateApprovalResponse(action.Id, approved: true);
|
||||
|
||||
// Resume after approval
|
||||
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
|
||||
|
||||
// Assert - the originally-approved function must be invoked, not the mutated one
|
||||
Assert.NotNull(capturedFunctionName);
|
||||
Assert.Equal(ApprovedFunctionName, capturedFunctionName);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Verifies that the approval snapshot is cleared after a completed approval cycle,
|
||||
/// both in-memory and in the persisted state store. This prevents stale data from
|
||||
/// influencing a subsequent execution of the same executor instance.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
public async Task InvokeFunctionToolCaptureResponseClearsSnapshotAfterCompletionAsync()
|
||||
{
|
||||
// Arrange
|
||||
const string FunctionName = "any_function";
|
||||
|
||||
this.State.InitializeSystem();
|
||||
this.State.Bind();
|
||||
|
||||
InvokeFunctionTool model = this.CreateModel(
|
||||
displayName: nameof(InvokeFunctionToolCaptureResponseClearsSnapshotAfterCompletionAsync),
|
||||
functionName: FunctionName,
|
||||
requireApproval: true);
|
||||
|
||||
TestFunctionAgentProvider testAgentProvider = new(
|
||||
[AIFunctionFactory.Create(() => "result", name: FunctionName)]);
|
||||
InvokeFunctionToolExecutor action = new(model, testAgentProvider, this.State);
|
||||
|
||||
// Act - run the full approval cycle
|
||||
Dictionary<string, object?> stateStore = [];
|
||||
Mock<IWorkflowContext> mockContext = CreateMockWorkflowContextWithStateStore(stateStore);
|
||||
await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None);
|
||||
|
||||
// Sanity: snapshot was captured
|
||||
FieldInfo snapshotField = typeof(InvokeFunctionToolExecutor)
|
||||
.GetField("_approvalSnapshot", BindingFlags.NonPublic | BindingFlags.Instance)!;
|
||||
Assert.NotNull(snapshotField.GetValue(action));
|
||||
|
||||
ExternalInputResponse response = CreateApprovalResponse(action.Id, approved: true);
|
||||
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);
|
||||
|
||||
// Assert - both in-memory field and persisted state are cleared
|
||||
Assert.Null(snapshotField.GetValue(action));
|
||||
Assert.True(stateStore.ContainsKey("_approvalSnapshot"));
|
||||
Assert.Null(stateStore["_approvalSnapshot"]);
|
||||
}
|
||||
|
||||
private static ExternalInputResponse CreateApprovalResponse(string actionId, bool approved)
|
||||
{
|
||||
FunctionCallContent functionCall = new(callId: actionId, name: "ignored");
|
||||
ToolApprovalRequestContent approvalRequest = new(actionId, functionCall);
|
||||
ToolApprovalResponseContent approvalResponse = approvalRequest.CreateResponse(approved);
|
||||
return new ExternalInputResponse(new ChatMessage(ChatRole.User, [approvalResponse]));
|
||||
}
|
||||
|
||||
private static Mock<IWorkflowContext> CreateMockWorkflowContext()
|
||||
{
|
||||
Mock<IWorkflowContext> mockContext = new();
|
||||
mockContext.Setup(c => c.AddEventAsync(It.IsAny<WorkflowEvent>(), It.IsAny<CancellationToken>()))
|
||||
.Returns(default(ValueTask));
|
||||
mockContext.Setup(c => c.QueueStateUpdateAsync(It.IsAny<string>(), It.IsAny<object?>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
|
||||
.Returns(default(ValueTask));
|
||||
mockContext.Setup(c => c.SendMessageAsync(It.IsAny<object>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
|
||||
.Returns(default(ValueTask));
|
||||
return mockContext;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates a mock workflow context that actually stores state values (for checkpoint/restore tests).
|
||||
/// Optionally accepts an externally-owned dictionary so callers can inspect the persisted state.
|
||||
/// </summary>
|
||||
private static Mock<IWorkflowContext> CreateMockWorkflowContextWithStateStore(Dictionary<string, object?>? stateStore = null)
|
||||
{
|
||||
stateStore ??= [];
|
||||
Mock<IWorkflowContext> mockContext = new();
|
||||
mockContext.Setup(c => c.AddEventAsync(It.IsAny<WorkflowEvent>(), It.IsAny<CancellationToken>()))
|
||||
.Returns(default(ValueTask));
|
||||
mockContext.Setup(c => c.QueueStateUpdateAsync(It.IsAny<string>(), It.IsAny<ApprovalSnapshot?>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
|
||||
.Callback<string, ApprovalSnapshot?, string?, CancellationToken>((key, value, _, _) => stateStore[key] = value)
|
||||
.Returns(default(ValueTask));
|
||||
mockContext.Setup(c => c.SendMessageAsync(It.IsAny<object>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
|
||||
.Returns(default(ValueTask));
|
||||
mockContext.Setup(c => c.ReadStateAsync<ApprovalSnapshot>(It.IsAny<string>(), It.IsAny<string?>(), It.IsAny<CancellationToken>()))
|
||||
.Returns<string, string?, CancellationToken>((key, _, _) =>
|
||||
new ValueTask<ApprovalSnapshot?>(stateStore.TryGetValue(key, out object? val) ? val as ApprovalSnapshot : null));
|
||||
mockContext.Setup(c => c.ReadStateKeysAsync(It.IsAny<string?>(), It.IsAny<CancellationToken>()))
|
||||
.ReturnsAsync(new HashSet<string>());
|
||||
return mockContext;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Invokes a protected method on the executor via reflection (for testing checkpoint hooks).
|
||||
/// </summary>
|
||||
private static async ValueTask InvokeProtectedMethodAsync(InvokeFunctionToolExecutor action, string methodName, IWorkflowContext context, CancellationToken cancellationToken)
|
||||
{
|
||||
MethodInfo method = typeof(InvokeFunctionToolExecutor)
|
||||
.GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Instance)!;
|
||||
ValueTask result = (ValueTask)method.Invoke(action, [context, cancellationToken])!;
|
||||
await result.ConfigureAwait(false);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Minimal concrete <see cref="ResponseAgentProvider"/> that exposes an injected
|
||||
/// <see cref="AIFunction"/> registry and records which function got invoked.
|
||||
/// Used by the framework-invoke approval branch (<c>InvokeRegisteredFunctionAsync</c>).
|
||||
/// </summary>
|
||||
private sealed class TestFunctionAgentProvider : ResponseAgentProvider
|
||||
{
|
||||
private readonly Action<string>? _onInvoke;
|
||||
private readonly Action<AIFunctionArguments>? _onInvokeArguments;
|
||||
|
||||
public TestFunctionAgentProvider(
|
||||
IEnumerable<AIFunction> functions,
|
||||
Action<string>? onInvoke = null,
|
||||
Action<AIFunctionArguments>? onInvokeArguments = null)
|
||||
{
|
||||
this._onInvoke = onInvoke;
|
||||
this._onInvokeArguments = onInvokeArguments;
|
||||
this.Functions = functions.Select(f => (AIFunction)new RecordingAIFunction(f, this)).ToList();
|
||||
}
|
||||
|
||||
internal void RecordInvocation(string name, AIFunctionArguments? arguments)
|
||||
{
|
||||
this._onInvoke?.Invoke(name);
|
||||
if (arguments is not null)
|
||||
{
|
||||
this._onInvokeArguments?.Invoke(arguments);
|
||||
}
|
||||
}
|
||||
|
||||
public override Task<string> CreateConversationAsync(CancellationToken cancellationToken = default) =>
|
||||
throw new NotSupportedException();
|
||||
|
||||
public override Task<ChatMessage> CreateMessageAsync(string conversationId, ChatMessage conversationMessage, CancellationToken cancellationToken = default) =>
|
||||
throw new NotSupportedException();
|
||||
|
||||
public override Task<ChatMessage> GetMessageAsync(string conversationId, string messageId, CancellationToken cancellationToken = default) =>
|
||||
throw new NotSupportedException();
|
||||
|
||||
public override IAsyncEnumerable<AgentResponseUpdate> InvokeAgentAsync(
|
||||
string agentId, string? agentVersion, string? conversationId,
|
||||
IEnumerable<ChatMessage>? messages, IDictionary<string, object?>? inputArguments,
|
||||
CancellationToken cancellationToken = default) =>
|
||||
throw new NotSupportedException();
|
||||
|
||||
public override IAsyncEnumerable<ChatMessage> GetMessagesAsync(
|
||||
string conversationId, int? limit = null, string? after = null, string? before = null,
|
||||
bool newestFirst = false, CancellationToken cancellationToken = default) =>
|
||||
throw new NotSupportedException();
|
||||
|
||||
private sealed class RecordingAIFunction(AIFunction inner, TestFunctionAgentProvider owner) : AIFunction
|
||||
{
|
||||
public override string Name => inner.Name;
|
||||
public override string Description => inner.Description;
|
||||
public override JsonElement JsonSchema => inner.JsonSchema;
|
||||
|
||||
protected override ValueTask<object?> InvokeCoreAsync(AIFunctionArguments arguments, CancellationToken cancellationToken)
|
||||
{
|
||||
owner.RecordInvocation(inner.Name, arguments);
|
||||
return inner.InvokeAsync(arguments, cancellationToken);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Helper Methods
|
||||
|
||||
private async Task ExecuteTestAsync(InvokeFunctionTool model)
|
||||
@@ -318,5 +645,33 @@ public sealed class InvokeFunctionToolExecutorTest(ITestOutputHelper output) : W
|
||||
return AssignParent<InvokeFunctionTool>(builder);
|
||||
}
|
||||
|
||||
private InvokeFunctionTool CreateModelWithVariableFunctionName(string displayName, string variableName)
|
||||
{
|
||||
InvokeFunctionTool.Builder builder = new()
|
||||
{
|
||||
Id = this.CreateActionId(),
|
||||
DisplayName = this.FormatDisplayName(displayName),
|
||||
FunctionName = new StringExpression.Builder(
|
||||
StringExpression.Variable(PropertyPath.TopicVariable(variableName))),
|
||||
RequireApproval = new BoolExpression.Builder(BoolExpression.Literal(true)),
|
||||
};
|
||||
return AssignParent<InvokeFunctionTool>(builder);
|
||||
}
|
||||
|
||||
private InvokeFunctionTool CreateModelWithVariableArgument(
|
||||
string displayName, string functionName, string argumentKey, string variableName)
|
||||
{
|
||||
InvokeFunctionTool.Builder builder = new()
|
||||
{
|
||||
Id = this.CreateActionId(),
|
||||
DisplayName = this.FormatDisplayName(displayName),
|
||||
FunctionName = new StringExpression.Builder(StringExpression.Literal(functionName)),
|
||||
RequireApproval = new BoolExpression.Builder(BoolExpression.Literal(true)),
|
||||
};
|
||||
builder.Arguments.Add(argumentKey,
|
||||
ValueExpression.Variable(PropertyPath.TopicVariable(variableName)));
|
||||
return AssignParent<InvokeFunctionTool>(builder);
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user