From db253ccda03360a24844768105c0f76d2757e8a0 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 15 Aug 2025 16:41:45 +0100 Subject: [PATCH] .NET: Bring some MEAI types to AF temporarily. (#426) * Bring some MEAI types to AF temporarily. * Exclude the copied files from code completion. --- .../ChatCompletion/ChatClientExtensions.cs | 2 +- .../MEAI/LoggingHelpers.cs | 40 + .../MEAI/NewFunctionInvokingChatClient.cs | 1008 +++++++++++++++++ .../MEAI/OpenTelemetryConsts.cs | 144 +++ 4 files changed, 1193 insertions(+), 1 deletion(-) create mode 100644 dotnet/src/Microsoft.Extensions.AI.Agents/MEAI/LoggingHelpers.cs create mode 100644 dotnet/src/Microsoft.Extensions.AI.Agents/MEAI/NewFunctionInvokingChatClient.cs create mode 100644 dotnet/src/Microsoft.Extensions.AI.Agents/MEAI/OpenTelemetryConsts.cs diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientExtensions.cs b/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientExtensions.cs index b0a61e12c0..e459c62297 100644 --- a/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientExtensions.cs +++ b/dotnet/src/Microsoft.Extensions.AI.Agents/ChatCompletion/ChatClientExtensions.cs @@ -13,7 +13,7 @@ internal static class ChatClientExtensions chatBuilder.UseAgentInvocation(); } - if (chatClient.GetService() is null) + if (chatClient.GetService() is null) { chatBuilder.UseFunctionInvocation(); } diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents/MEAI/LoggingHelpers.cs b/dotnet/src/Microsoft.Extensions.AI.Agents/MEAI/LoggingHelpers.cs new file mode 100644 index 0000000000..6a0e13677f --- /dev/null +++ b/dotnet/src/Microsoft.Extensions.AI.Agents/MEAI/LoggingHelpers.cs @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft. All rights reserved. + +// WARNING: +// This class has been temporarily copied here from MEAI, to allow prototyping +// functionality that will be moved to MEAI in the future. 
+// This file is not intended to be modified.
+
+#pragma warning disable CA1031 // Do not catch general exception types
+#pragma warning disable S108 // Nested blocks of code should not be left empty
+#pragma warning disable S2486 // Generic exceptions should not be ignored
+
+using System.Diagnostics.CodeAnalysis;
+using System.Text.Json;
+
+namespace Microsoft.Extensions.AI;
+
+/// <summary>Provides internal helpers for implementing logging.</summary>
+[ExcludeFromCodeCoverage]
+internal static class LoggingHelpers
+{
+    /// <summary>Serializes <paramref name="value"/> as JSON for logging purposes, looking up type info for <typeparamref name="T"/> in <paramref name="options"/> first and then in the default options; returns "{}" when no type info is found or serialization fails.</summary>
+    public static string AsJson<T>(T value, JsonSerializerOptions? options)
+    {
+        if (options?.TryGetTypeInfo(typeof(T), out var typeInfo) is true ||
+            AIJsonUtilities.DefaultOptions.TryGetTypeInfo(typeof(T), out typeInfo))
+        {
+            try
+            {
+                return JsonSerializer.Serialize(value, typeInfo);
+            }
+            catch // Deliberately swallowed: a serialization failure falls through to the "{}" fallback below.
+            {
+            }
+        }
+
+        // If we're unable to get a type info for the value, or if we fail to serialize,
+        // return an empty JSON object. We do not want lack of type info to disrupt application behavior with exceptions.
+        return "{}";
+    }
+}
diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents/MEAI/NewFunctionInvokingChatClient.cs b/dotnet/src/Microsoft.Extensions.AI.Agents/MEAI/NewFunctionInvokingChatClient.cs
new file mode 100644
index 0000000000..2b98879adc
--- /dev/null
+++ b/dotnet/src/Microsoft.Extensions.AI.Agents/MEAI/NewFunctionInvokingChatClient.cs
@@ -0,0 +1,1008 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+// WARNING:
+// This class is a copy of FunctionInvokingChatClient from MEAI, and is intended to be modified with
+// changes that we want to prototype here, before updating FunctionInvokingChatClient in MEAI.
+// The intention is to keep the changes in this file to a minimum, so that we can easily
+// merge them back into MEAI when ready.
+
+// AF repo suppressions for code copied from MEAI.
+#pragma warning disable IDE1006 // Naming Styles
+#pragma warning disable IDE0009 // Member access should be qualified.
+#pragma warning disable CA2007 // Consider calling ConfigureAwait on the awaited task +#pragma warning disable VSTHRD111 // Use ConfigureAwait(bool) + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Runtime.ExceptionServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable CA2213 // Disposable fields should be disposed +#pragma warning disable EA0002 // Use 'System.TimeProvider' to make the code easier to test +#pragma warning disable SA1202 // 'protected' members should come before 'private' members +#pragma warning disable S107 // Methods should not have too many parameters + +namespace Microsoft.Extensions.AI; + +/// +/// A delegating chat client that invokes functions defined on . +/// Include this in a chat pipeline to resolve function calls automatically. +/// +/// +/// +/// When this client receives a in a chat response, it responds +/// by calling the corresponding defined in , +/// producing a that it sends back to the inner client. This loop +/// is repeated until there are no more function calls to make, or until another stop condition is met, +/// such as hitting . +/// +/// +/// The provided implementation of is thread-safe for concurrent use so long as the +/// instances employed as part of the supplied are also safe. +/// The property can be used to control whether multiple function invocation +/// requests as part of the same request are invocable concurrently, but even with that set to +/// (the default), multiple concurrent requests to this same instance and using the same tools could result in those +/// tools being used concurrently (one per request). 
For example, a function that accesses the HttpContext of a specific +/// ASP.NET web request should only be used as part of a single at a time, and only with +/// set to , in case the inner client decided to issue multiple +/// invocation requests to that same function. +/// +/// +[ExcludeFromCodeCoverage] +public partial class NewFunctionInvokingChatClient : DelegatingChatClient +{ + /// The for the current function invocation. + private static readonly AsyncLocal _currentContext = new(); + + /// Gets the specified when constructing the , if any. + protected IServiceProvider? FunctionInvocationServices { get; } + + /// The logger to use for logging information about function invocation. + private readonly ILogger _logger; + + /// The to use for telemetry. + /// This component does not own the instance and should not dispose it. + private readonly ActivitySource? _activitySource; + + /// Maximum number of roundtrips allowed to the inner client. + private int _maximumIterationsPerRequest = 40; // arbitrary default to prevent runaway execution + + /// Maximum number of consecutive iterations that are allowed to contain at least one exception result. If the limit is exceeded, we rethrow the exception instead of continuing. + private int _maximumConsecutiveErrorsPerRequest = 3; + + /// + /// Initializes a new instance of the class. + /// + /// The underlying , or the next instance in a chain of clients. + /// An to use for logging information about function invocation. + /// An optional to use for resolving services required by the instances being invoked. + public NewFunctionInvokingChatClient(IChatClient innerClient, ILoggerFactory? loggerFactory = null, IServiceProvider? functionInvocationServices = null) + : base(innerClient) + { + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? 
NullLogger.Instance; + _activitySource = innerClient.GetService(); + FunctionInvocationServices = functionInvocationServices; + } + + /// + /// Gets or sets the for the current function invocation. + /// + /// + /// This value flows across async calls. + /// + public static FunctionInvocationContext? CurrentContext + { + get => _currentContext.Value; + protected set => _currentContext.Value = value; + } + + /// + /// Gets or sets a value indicating whether detailed exception information should be included + /// in the chat history when calling the underlying . + /// + /// + /// if the full exception message is added to the chat history + /// when calling the underlying . + /// if a generic error message is included in the chat history. + /// The default value is . + /// + /// + /// + /// Setting the value to prevents the underlying language model from disclosing + /// raw exception details to the end user, since it doesn't receive that information. Even in this + /// case, the raw object is available to application code by inspecting + /// the property. + /// + /// + /// Setting the value to can help the underlying bypass problems on + /// its own, for example by retrying the function call with different arguments. However it might + /// result in disclosing the raw exception information to external users, which can be a security + /// concern depending on the application scenario. + /// + /// + /// Changing the value of this property while the client is in use might result in inconsistencies + /// as to whether detailed errors are provided during an in-flight request. + /// + /// + public bool IncludeDetailedErrors { get; set; } + + /// + /// Gets or sets a value indicating whether to allow concurrent invocation of functions. + /// + /// + /// if multiple function calls can execute in parallel. + /// if function calls are processed serially. + /// The default value is . 
+ /// + /// + /// An individual response from the inner client might contain multiple function call requests. + /// By default, such function calls are processed serially. Set to + /// to enable concurrent invocation such that multiple function calls can execute in parallel. + /// + public bool AllowConcurrentInvocation { get; set; } + + /// + /// Gets or sets the maximum number of iterations per request. + /// + /// + /// The maximum number of iterations per request. + /// The default value is 40. + /// + /// + /// + /// Each request to this might end up making + /// multiple requests to the inner client. Each time the inner client responds with + /// a function call request, this client might perform that invocation and send the results + /// back to the inner client in a new request. This property limits the number of times + /// such a roundtrip is performed. The value must be at least one, as it includes the initial request. + /// + /// + /// Changing the value of this property while the client is in use might result in inconsistencies + /// as to how many iterations are allowed for an in-flight request. + /// + /// + public int MaximumIterationsPerRequest + { + get => _maximumIterationsPerRequest; + set + { + if (value < 1) + { + Throw.ArgumentOutOfRangeException(nameof(value)); + } + + _maximumIterationsPerRequest = value; + } + } + + /// + /// Gets or sets the maximum number of consecutive iterations that are allowed to fail with an error. + /// + /// + /// The maximum number of consecutive iterations that are allowed to fail with an error. + /// The default value is 3. + /// + /// + /// + /// When function invocations fail with an exception, the + /// continues to make requests to the inner client, optionally supplying exception information (as + /// controlled by ). This allows the to + /// recover from errors by trying other function parameters that may succeed. 
+ /// + /// + /// However, in case function invocations continue to produce exceptions, this property can be used to + /// limit the number of consecutive failing attempts. When the limit is reached, the exception will be + /// rethrown to the caller. + /// + /// + /// If the value is set to zero, all function calling exceptions immediately terminate the function + /// invocation loop and the exception will be rethrown to the caller. + /// + /// + /// Changing the value of this property while the client is in use might result in inconsistencies + /// as to how many iterations are allowed for an in-flight request. + /// + /// + public int MaximumConsecutiveErrorsPerRequest + { + get => _maximumConsecutiveErrorsPerRequest; + set => _maximumConsecutiveErrorsPerRequest = Throw.IfLessThan(value, 0); + } + + /// Gets or sets a collection of additional tools the client is able to invoke. + /// + /// These will not impact the requests sent by the , which will pass through the + /// unmodified. However, if the inner client requests the invocation of a tool + /// that was not in , this collection will also be consulted + /// to look for a corresponding tool to invoke. This is useful when the service may have been pre-configured to be aware + /// of certain tools that aren't also sent on each individual request. + /// + public IList? AdditionalTools { get; set; } + + /// Gets or sets a delegate used to invoke instances. + /// + /// By default, the protected method is called for each to be invoked, + /// invoking the instance and returning its result. If this delegate is set to a non- value, + /// will replace its normal invocation with a call to this delegate, enabling + /// this delegate to assume all invocation handling of the function. + /// + public Func>? FunctionInvoker { get; set; } + + /// + public override async Task GetResponseAsync( + IEnumerable messages, ChatOptions? 
options = null, CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(messages); + + // A single request into this GetResponseAsync may result in multiple requests to the inner client. + // Create an activity to group them together for better observability. + using Activity? activity = _activitySource?.StartActivity($"{nameof(FunctionInvokingChatClient)}.{nameof(GetResponseAsync)}"); + + // Copy the original messages in order to avoid enumerating the original messages multiple times. + // The IEnumerable can represent an arbitrary amount of work. + List originalMessages = [.. messages]; + messages = originalMessages; + + List? augmentedHistory = null; // the actual history of messages sent on turns other than the first + ChatResponse? response = null; // the response from the inner client, which is possibly modified and then eventually returned + List? responseMessages = null; // tracked list of messages, across multiple turns, to be used for the final response + UsageDetails? totalUsage = null; // tracked usage across all turns, to be used for the final response + List? functionCallContents = null; // function call contents that need responding to in the current turn + bool lastIterationHadConversationId = false; // whether the last iteration's response had a ConversationId set + int consecutiveErrorCount = 0; + + for (int iteration = 0; ; iteration++) + { + functionCallContents?.Clear(); + + // Make the call to the inner client. + response = await base.GetResponseAsync(messages, options, cancellationToken); + if (response is null) + { + Throw.InvalidOperationException($"The inner {nameof(IChatClient)} returned a null {nameof(ChatResponse)}."); + } + + // Any function call work to do? If yes, ensure we're tracking that work in functionCallContents. 
+ bool requiresFunctionInvocation = + (options?.Tools is { Count: > 0 } || AdditionalTools is { Count: > 0 }) && + iteration < MaximumIterationsPerRequest && + CopyFunctionCalls(response.Messages, ref functionCallContents); + + // In a common case where we make a request and there's no function calling work required, + // fast path out by just returning the original response. + if (iteration == 0 && !requiresFunctionInvocation) + { + return response; + } + + // Track aggregate details from the response, including all of the response messages and usage details. + (responseMessages ??= []).AddRange(response.Messages); + if (response.Usage is not null) + { + if (totalUsage is not null) + { + totalUsage.Add(response.Usage); + } + else + { + totalUsage = response.Usage; + } + } + + // If there are no tools to call, or for any other reason we should stop, we're done. + // Break out of the loop and allow the handling at the end to configure the response + // with aggregated data from previous requests. + if (!requiresFunctionInvocation) + { + break; + } + + // Prepare the history for the next iteration. + FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadConversationId); + + // Add the responses from the function calls into the augmented history and also into the tracked + // list of response messages. 
+ var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents!, iteration, consecutiveErrorCount, isStreaming: false, cancellationToken); + responseMessages.AddRange(modeAndMessages.MessagesAdded); + consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; + + if (modeAndMessages.ShouldTerminate) + { + break; + } + + UpdateOptionsForNextIteration(ref options, response.ConversationId); + } + + Debug.Assert(responseMessages is not null, "Expected to only be here if we have response messages."); + response.Messages = responseMessages!; + response.Usage = totalUsage; + + AddUsageTags(activity, totalUsage); + + return response; + } + + /// + public override async IAsyncEnumerable GetStreamingResponseAsync( + IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(messages); + + // A single request into this GetStreamingResponseAsync may result in multiple requests to the inner client. + // Create an activity to group them together for better observability. + using Activity? activity = _activitySource?.StartActivity($"{nameof(FunctionInvokingChatClient)}.{nameof(GetStreamingResponseAsync)}"); + UsageDetails? totalUsage = activity is { IsAllDataRequested: true } ? new() : null; // tracked usage across all turns, to be used for activity purposes + + // Copy the original messages in order to avoid enumerating the original messages multiple times. + // The IEnumerable can represent an arbitrary amount of work. + List originalMessages = [.. messages]; + messages = originalMessages; + + List? augmentedHistory = null; // the actual history of messages sent on turns other than the first + List? functionCallContents = null; // function call contents that need responding to in the current turn + List? 
responseMessages = null; // tracked list of messages, across multiple turns, to be used in fallback cases to reconstitute history + bool lastIterationHadConversationId = false; // whether the last iteration's response had a ConversationId set + List updates = []; // updates from the current response + int consecutiveErrorCount = 0; + + for (int iteration = 0; ; iteration++) + { + updates.Clear(); + functionCallContents?.Clear(); + + await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken)) + { + if (update is null) + { + Throw.InvalidOperationException($"The inner {nameof(IChatClient)} streamed a null {nameof(ChatResponseUpdate)}."); + } + + updates.Add(update); + + _ = CopyFunctionCalls(update.Contents, ref functionCallContents); + + if (totalUsage is not null) + { + IList contents = update.Contents; + int contentsCount = contents.Count; + for (int i = 0; i < contentsCount; i++) + { + if (contents[i] is UsageContent uc) + { + totalUsage.Add(uc.Details); + } + } + } + + yield return update; + Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 + } + + // If there are no tools to call, or for any other reason we should stop, return the response. + if (functionCallContents is not { Count: > 0 } || + (options?.Tools is not { Count: > 0 } && AdditionalTools is not { Count: > 0 }) || + iteration >= _maximumIterationsPerRequest) + { + break; + } + + // Reconstitute a response from the response updates. + var response = updates.ToChatResponse(); + (responseMessages ??= []).AddRange(response.Messages); + + // Prepare the history for the next iteration. + FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadConversationId); + + // Process all of the functions, adding their results into the history. 
+ var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, consecutiveErrorCount, isStreaming: true, cancellationToken); + responseMessages.AddRange(modeAndMessages.MessagesAdded); + consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; + + // This is a synthetic ID since we're generating the tool messages instead of getting them from + // the underlying provider. When emitting the streamed chunks, it's perfectly valid for us to + // use the same message ID for all of them within a given iteration, as this is a single logical + // message with multiple content items. We could also use different message IDs per tool content, + // but there's no benefit to doing so. + string toolResponseId = Guid.NewGuid().ToString("N"); + + // Stream any generated function results. This mirrors what's done for GetResponseAsync, where the returned messages + // includes all activities, including generated function results. + foreach (var message in modeAndMessages.MessagesAdded) + { + var toolResultUpdate = new ChatResponseUpdate + { + AdditionalProperties = message.AdditionalProperties, + AuthorName = message.AuthorName, + ConversationId = response.ConversationId, + CreatedAt = DateTimeOffset.UtcNow, + Contents = message.Contents, + RawRepresentation = message.RawRepresentation, + ResponseId = toolResponseId, + MessageId = toolResponseId, // See above for why this can be the same as ResponseId + Role = message.Role, + }; + + yield return toolResultUpdate; + Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 + } + + if (modeAndMessages.ShouldTerminate) + { + break; + } + + UpdateOptionsForNextIteration(ref options, response.ConversationId); + } + + AddUsageTags(activity, totalUsage); + } + + /// Adds tags to for usage details in . + private static void AddUsageTags(Activity? activity, UsageDetails? 
usage) + { + if (usage is not null && activity is { IsAllDataRequested: true }) + { + if (usage.InputTokenCount is long inputTokens) + { + _ = activity.AddTag(OpenTelemetryConsts.GenAI.Usage.InputTokens, (int)inputTokens); + } + + if (usage.OutputTokenCount is long outputTokens) + { + _ = activity.AddTag(OpenTelemetryConsts.GenAI.Usage.OutputTokens, (int)outputTokens); + } + } + } + + /// Prepares the various chat message lists after a response from the inner client and before invoking functions. + /// The original messages provided by the caller. + /// The messages reference passed to the inner client. + /// The augmented history containing all the messages to be sent. + /// The most recent response being handled. + /// A list of all response messages received up until this point. + /// Whether the previous iteration's response had a conversation ID. + private static void FixupHistories( + IEnumerable originalMessages, + ref IEnumerable messages, + [NotNull] ref List? augmentedHistory, + ChatResponse response, + List allTurnsResponseMessages, + ref bool lastIterationHadConversationId) + { + // We're now going to need to augment the history with function result contents. + // That means we need a separate list to store the augmented history. + if (response.ConversationId is not null) + { + // The response indicates the inner client is tracking the history, so we don't want to send + // anything we've already sent or received. + if (augmentedHistory is not null) + { + augmentedHistory.Clear(); + } + else + { + augmentedHistory = []; + } + + lastIterationHadConversationId = true; + } + else if (lastIterationHadConversationId) + { + // In the very rare case where the inner client returned a response with a conversation ID but then + // returned a subsequent response without one, we want to reconstitute the full history. 
To do that, + // we can populate the history with the original chat messages and then all of the response + // messages up until this point, which includes the most recent ones. + augmentedHistory ??= []; + augmentedHistory.Clear(); + augmentedHistory.AddRange(originalMessages); + augmentedHistory.AddRange(allTurnsResponseMessages); + + lastIterationHadConversationId = false; + } + else + { + // If augmentedHistory is already non-null, then we've already populated it with everything up + // until this point (except for the most recent response). If it's null, we need to seed it with + // the chat history provided by the caller. + augmentedHistory ??= originalMessages.ToList(); + + // Now add the most recent response messages. + augmentedHistory.AddMessages(response); + + lastIterationHadConversationId = false; + } + + // Use the augmented history as the new set of messages to send. + messages = augmentedHistory; + } + + /// Copies any from to . + private static bool CopyFunctionCalls( + IList messages, [NotNullWhen(true)] ref List? functionCalls) + { + bool any = false; + int count = messages.Count; + for (int i = 0; i < count; i++) + { + any |= CopyFunctionCalls(messages[i].Contents, ref functionCalls); + } + + return any; + } + + /// Copies any from to . + private static bool CopyFunctionCalls( + IList content, [NotNullWhen(true)] ref List? functionCalls) + { + bool any = false; + int count = content.Count; + for (int i = 0; i < count; i++) + { + if (content[i] is FunctionCallContent functionCall) + { + (functionCalls ??= []).Add(functionCall); + any = true; + } + } + + return any; + } + + private static void UpdateOptionsForNextIteration(ref ChatOptions? options, string? 
conversationId) + { + if (options is null) + { + if (conversationId is not null) + { + options = new() { ConversationId = conversationId }; + } + } + else if (options.ToolMode is RequiredChatToolMode) + { + // We have to reset the tool mode to be non-required after the first iteration, + // as otherwise we'll be in an infinite loop. + options = options.Clone(); + options.ToolMode = null; + options.ConversationId = conversationId; + } + else if (options.ConversationId != conversationId) + { + // As with the other modes, ensure we've propagated the chat conversation ID to the options. + // We only need to clone the options if we're actually mutating it. + options = options.Clone(); + options.ConversationId = conversationId; + } + } + + /// + /// Processes the function calls in the list. + /// + /// The current chat contents, inclusive of the function call contents being processed. + /// The options used for the response being processed. + /// The function call contents representing the functions to be invoked. + /// The iteration number of how many roundtrips have been made to the inner client. + /// The number of consecutive iterations, prior to this one, that were recorded as having function invocation errors. + /// Whether the function calls are being processed in a streaming context. + /// The to monitor for cancellation requests. + /// A value indicating how the caller should proceed. + private async Task<(bool ShouldTerminate, int NewConsecutiveErrorCount, IList MessagesAdded)> ProcessFunctionCallsAsync( + List messages, ChatOptions? options, List functionCallContents, int iteration, int consecutiveErrorCount, + bool isStreaming, CancellationToken cancellationToken) + { + // We must add a response for every tool call, regardless of whether we successfully executed it or not. + // If we successfully execute it, we'll add the result. If we don't, we'll add an error. 
+ + Debug.Assert(functionCallContents.Count > 0, "Expected at least one function call."); + var shouldTerminate = false; + var captureCurrentIterationExceptions = consecutiveErrorCount < _maximumConsecutiveErrorsPerRequest; + + // Process all functions. If there's more than one and concurrent invocation is enabled, do so in parallel. + if (functionCallContents.Count == 1) + { + FunctionInvocationResult result = await ProcessFunctionCallAsync( + messages, options, functionCallContents, + iteration, 0, captureCurrentIterationExceptions, isStreaming, cancellationToken); + + IList addedMessages = CreateResponseMessages([result]); + ThrowIfNoFunctionResultsAdded(addedMessages); + UpdateConsecutiveErrorCountOrThrow(addedMessages, ref consecutiveErrorCount); + messages.AddRange(addedMessages); + + return (result.Terminate, consecutiveErrorCount, addedMessages); + } + else + { + List results = []; + + if (AllowConcurrentInvocation) + { + // Rather than awaiting each function before invoking the next, invoke all of them + // and then await all of them. We avoid forcibly introducing parallelism via Task.Run, + // but if a function invocation completes asynchronously, its processing can overlap + // with the processing of the other invocations. + results.AddRange(await Task.WhenAll( + from callIndex in Enumerable.Range(0, functionCallContents.Count) + select ProcessFunctionCallAsync( + messages, options, functionCallContents, + iteration, callIndex, captureExceptions: true, isStreaming, cancellationToken))); + + shouldTerminate = results.Any(r => r.Terminate); + } + else + { + // Invoke each function serially. 
+ for (int callIndex = 0; callIndex < functionCallContents.Count; callIndex++) + { + var functionResult = await ProcessFunctionCallAsync( + messages, options, functionCallContents, + iteration, callIndex, captureCurrentIterationExceptions, isStreaming, cancellationToken); + + results.Add(functionResult); + + // If any function requested termination, we should stop right away. + if (functionResult.Terminate) + { + shouldTerminate = true; + break; + } + } + } + + IList addedMessages = CreateResponseMessages(results.ToArray()); + ThrowIfNoFunctionResultsAdded(addedMessages); + UpdateConsecutiveErrorCountOrThrow(addedMessages, ref consecutiveErrorCount); + messages.AddRange(addedMessages); + + return (shouldTerminate, consecutiveErrorCount, addedMessages); + } + } + +#pragma warning disable CA1851 // Possible multiple enumerations of 'IEnumerable' collection + /// + /// Updates the consecutive error count, and throws an exception if the count exceeds the maximum. + /// + /// Added messages. + /// Consecutive error count. + /// Thrown if the maximum consecutive error count is exceeded. + private void UpdateConsecutiveErrorCountOrThrow(IList added, ref int consecutiveErrorCount) + { + var allExceptions = added.SelectMany(m => m.Contents.OfType()) + .Select(frc => frc.Exception!) + .Where(e => e is not null); + + if (allExceptions.Any()) + { + consecutiveErrorCount++; + if (consecutiveErrorCount > _maximumConsecutiveErrorsPerRequest) + { + var allExceptionsArray = allExceptions.ToArray(); + if (allExceptionsArray.Length == 1) + { + ExceptionDispatchInfo.Capture(allExceptionsArray[0]).Throw(); + } + else + { + throw new AggregateException(allExceptionsArray); + } + } + } + else + { + consecutiveErrorCount = 0; + } + } +#pragma warning restore CA1851 + + /// + /// Throws an exception if doesn't create any messages. + /// + private void ThrowIfNoFunctionResultsAdded(IList? 
messages)
+    {
+        if (messages is null || messages.Count == 0)
+        {
+            Throw.InvalidOperationException($"{GetType().Name}.{nameof(CreateResponseMessages)} returned null or an empty collection of messages.");
+        }
+    }
+
+    /// <summary>Processes the function call described in <paramref name="callContents"/>[<paramref name="functionCallIndex"/>].</summary>
+    /// <param name="messages">The current chat contents, inclusive of the function call contents being processed.</param>
+    /// <param name="options">The options used for the response being processed.</param>
+    /// <param name="callContents">The function call contents representing all the functions being invoked.</param>
+    /// <param name="iteration">The iteration number of how many roundtrips have been made to the inner client.</param>
+    /// <param name="functionCallIndex">The 0-based index of the function being called out of <paramref name="callContents"/>.</param>
+    /// <param name="captureExceptions">If true, handles function-invocation exceptions by returning a value with <see cref="FunctionInvocationStatus.Exception"/>. Otherwise, rethrows.</param>
+    /// <param name="isStreaming">Whether the function calls are being processed in a streaming context.</param>
+    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
+    /// <returns>A value indicating how the caller should proceed.</returns>
+    private async Task<FunctionInvocationResult> ProcessFunctionCallAsync(
+        List<ChatMessage> messages, ChatOptions? options, List<FunctionCallContent> callContents,
+        int iteration, int functionCallIndex, bool captureExceptions, bool isStreaming, CancellationToken cancellationToken)
+    {
+        var callContent = callContents[functionCallIndex];
+
+        // Look up the AIFunction for the function call. If the requested function isn't available, send back an error.
+        AIFunction? aiFunction = FindAIFunction(options?.Tools, callContent.Name) ?? FindAIFunction(AdditionalTools, callContent.Name);
+        if (aiFunction is null)
+        {
+            return new(terminate: false, FunctionInvocationStatus.NotFound, callContent, result: null, exception: null);
+        }
+
+        FunctionInvocationContext context = new()
+        {
+            Function = aiFunction,
+            Arguments = new(callContent.Arguments) { Services = FunctionInvocationServices },
+            Messages = messages,
+            Options = options,
+            CallContent = callContent,
+            Iteration = iteration,
+            FunctionCallIndex = functionCallIndex,
+            FunctionCount = callContents.Count,
+            IsStreaming = isStreaming
+        };
+
+        object? result;
+        try
+        {
+            result = await InstrumentedInvokeFunctionAsync(context, cancellationToken);
+        }
+        catch (Exception e) when (!cancellationToken.IsCancellationRequested) // once cancellation is requested the filter is false and the exception propagates
+        {
+            if (!captureExceptions)
+            {
+                throw;
+            }
+
+            return new(
+                terminate: false,
+                FunctionInvocationStatus.Exception,
+                callContent,
+                result: null,
+                exception: e);
+        }
+
+        return new(
+            terminate: context.Terminate,
+            FunctionInvocationStatus.RanToCompletion,
+            callContent,
+            result,
+            exception: null);
+
+        static AIFunction? FindAIFunction(IList<AITool>? tools, string functionName)
+        {
+            if (tools is not null)
+            {
+                int count = tools.Count;
+                for (int i = 0; i < count; i++)
+                {
+                    if (tools[i] is AIFunction function && function.Name == functionName) // first match wins; non-function tools are skipped
+                    {
+                        return function;
+                    }
+                }
+            }
+
+            return null;
+        }
+    }
+
+    /// <summary>Creates one or more response messages for function invocation results.</summary>
+    /// <param name="results">Information about the function call invocations and results.</param>
+    /// <returns>A list of all chat messages created from <paramref name="results"/>.</returns>
+    protected virtual IList<ChatMessage> CreateResponseMessages(
+        ReadOnlySpan<FunctionInvocationResult> results)
+    {
+        var contents = new List<AIContent>(results.Length);
+        for (int i = 0; i < results.Length; i++)
+        {
+            contents.Add(CreateFunctionResultContent(results[i]));
+        }
+
+        return [new(ChatRole.Tool, contents)]; // one tool-role message carrying every result
+
+        FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult result)
+        {
+            _ = Throw.IfNull(result);
+
+            object? functionResult;
+            if (result.Status == FunctionInvocationStatus.RanToCompletion)
+            {
+                functionResult = result.Result ?? "Success: Function completed."; // placeholder text when the function returned null
+            }
+            else
+            {
+                string message = result.Status switch
+                {
+                    FunctionInvocationStatus.NotFound => $"Error: Requested function \"{result.CallContent.Name}\" not found.",
+                    FunctionInvocationStatus.Exception => "Error: Function failed.",
+                    _ => "Error: Unknown error.",
+                };
+
+                if (IncludeDetailedErrors && result.Exception is not null)
+                {
+                    message = $"{message} Exception: {result.Exception.Message}";
+                }
+
+                functionResult = message;
+            }
+
+            return new FunctionResultContent(result.CallContent.CallId, functionResult) { Exception = result.Exception };
+        }
+    }
+
+    /// <summary>Invokes the function asynchronously, wrapping the call in an activity and invocation logging.</summary>
+    /// <param name="context">
+    /// The function invocation context detailing the function to be invoked and its arguments along with additional request information.
+    /// </param>
+    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
+    /// <returns>The result of the function invocation, or <see langword="null"/> if the function invocation returned <see langword="null"/>.</returns>
+    /// <exception cref="ArgumentNullException"><paramref name="context"/> is <see langword="null"/>.</exception>
+    private async Task<object?> InstrumentedInvokeFunctionAsync(FunctionInvocationContext context, CancellationToken cancellationToken)
+    {
+        _ = Throw.IfNull(context);
+
+        using Activity? activity = _activitySource?.StartActivity(
+            $"{OpenTelemetryConsts.GenAI.ExecuteTool} {context.Function.Name}",
+            ActivityKind.Internal,
+            default(ActivityContext),
+            [
+                new(OpenTelemetryConsts.GenAI.Operation.Name, OpenTelemetryConsts.GenAI.ExecuteTool),
+                new(OpenTelemetryConsts.GenAI.Tool.Call.Id, context.CallContent.CallId),
+                new(OpenTelemetryConsts.GenAI.Tool.Name, context.Function.Name),
+                new(OpenTelemetryConsts.GenAI.Tool.Description, context.Function.Description),
+            ]);
+
+        long startingTimestamp = 0; // only captured when Debug logging is enabled
+        if (_logger.IsEnabled(LogLevel.Debug))
+        {
+            startingTimestamp = Stopwatch.GetTimestamp();
+            if (_logger.IsEnabled(LogLevel.Trace))
+            {
+                LogInvokingSensitive(context.Function.Name, LoggingHelpers.AsJson(context.Arguments, context.Function.JsonSerializerOptions));
+            }
+            else
+            {
+                LogInvoking(context.Function.Name);
+            }
+        }
+
+        object? result = null;
+        try
+        {
+            CurrentContext = context; // doesn't need to be explicitly reset after, as that's handled automatically at async method exit
+            result = await InvokeFunctionAsync(context, cancellationToken);
+        }
+        catch (Exception e)
+        {
+            if (activity is not null)
+            {
+                _ = activity.SetTag(OpenTelemetryConsts.Error.Type, e.GetType().FullName)
+                    .SetStatus(ActivityStatusCode.Error, e.Message);
+            }
+
+            if (e is OperationCanceledException)
+            {
+                LogInvocationCanceled(context.Function.Name);
+            }
+            else
+            {
+                LogInvocationFailed(context.Function.Name, e);
+            }
+
+            throw;
+        }
+        finally
+        {
+            if (_logger.IsEnabled(LogLevel.Debug))
+            {
+                TimeSpan elapsed = GetElapsedTime(startingTimestamp);
+
+                if (result is not null && _logger.IsEnabled(LogLevel.Trace)) // a null result (including the failure path) takes the non-sensitive message
+                {
+                    LogInvocationCompletedSensitive(context.Function.Name, elapsed, LoggingHelpers.AsJson(result, context.Function.JsonSerializerOptions));
+                }
+                else
+                {
+                    LogInvocationCompleted(context.Function.Name, elapsed);
+                }
+            }
+        }
+
+        return result;
+    }
+
+    /// <summary>This method will invoke the function within the try block.</summary>
+    /// <param name="context">The function invocation context.</param>
+    /// <param name="cancellationToken">Cancellation token.</param>
+    /// <returns>The function result.</returns>
+    protected virtual ValueTask<object?> InvokeFunctionAsync(FunctionInvocationContext context, CancellationToken cancellationToken)
+    {
+        _ = Throw.IfNull(context);
+
+        return FunctionInvoker is { } invoker ?
+            invoker(context, cancellationToken) : // a configured FunctionInvoker delegate takes precedence
+            context.Function.InvokeAsync(context.Arguments, cancellationToken);
+    }
+
+    private static TimeSpan GetElapsedTime(long startingTimestamp) =>
+#if NET
+        Stopwatch.GetElapsedTime(startingTimestamp);
+#else
+        new((long)((Stopwatch.GetTimestamp() - startingTimestamp) * ((double)TimeSpan.TicksPerSecond / Stopwatch.Frequency))); // fallback when NET is not defined: scale Stopwatch ticks to TimeSpan ticks by hand
+#endif
+
+    [LoggerMessage(LogLevel.Debug, "Invoking {MethodName}.", SkipEnabledCheck = true)]
+    private partial void LogInvoking(string methodName);
+
+    [LoggerMessage(LogLevel.Trace, "Invoking {MethodName}({Arguments}).", SkipEnabledCheck = true)]
+    private partial void LogInvokingSensitive(string methodName, string arguments); // "Sensitive": the message includes the serialized arguments
+
+    [LoggerMessage(LogLevel.Debug, "{MethodName} invocation completed. Duration: {Duration}", SkipEnabledCheck = true)]
+    private partial void LogInvocationCompleted(string methodName, TimeSpan duration);
+
+    [LoggerMessage(LogLevel.Trace, "{MethodName} invocation completed. Duration: {Duration}. Result: {Result}", SkipEnabledCheck = true)]
+    private partial void LogInvocationCompletedSensitive(string methodName, TimeSpan duration, string result); // "Sensitive": the message includes the serialized result
+
+    [LoggerMessage(LogLevel.Debug, "{MethodName} invocation canceled.")]
+    private partial void LogInvocationCanceled(string methodName);
+
+    [LoggerMessage(LogLevel.Error, "{MethodName} invocation failed.")]
+    private partial void LogInvocationFailed(string methodName, Exception error);
+
+    /// <summary>Provides information about the invocation of a function call.</summary>
+    public sealed class FunctionInvocationResult
+    {
+        /// <summary>
+        /// Initializes a new instance of the <see cref="FunctionInvocationResult"/> class.
+        /// </summary>
+        /// <param name="terminate">Indicates whether the caller should terminate the processing loop.</param>
+        /// <param name="status">Indicates the status of the function invocation.</param>
+        /// <param name="callContent">Contains information about the function call.</param>
+        /// <param name="result">The result of the function call.</param>
+        /// <param name="exception">The exception thrown by the function call, if any.</param>
+        internal FunctionInvocationResult(bool terminate, FunctionInvocationStatus status, FunctionCallContent callContent, object? result, Exception? exception)
+        {
+            Terminate = terminate;
+            Status = status;
+            CallContent = callContent;
+            Result = result;
+            Exception = exception;
+        }
+
+        /// <summary>Gets status about how the function invocation completed.</summary>
+        public FunctionInvocationStatus Status { get; }
+
+        /// <summary>Gets the function call content information associated with this invocation.</summary>
+        public FunctionCallContent CallContent { get; }
+
+        /// <summary>Gets the result of the function call.</summary>
+        public object? Result { get; }
+
+        /// <summary>Gets any exception the function call threw.</summary>
+        public Exception? Exception { get; }
+
+        /// <summary>Gets a value indicating whether the caller should terminate the processing loop.</summary>
+        public bool Terminate { get; }
+    }
+
+    /// <summary>Provides error codes for when errors occur as part of the function calling loop.</summary>
+    public enum FunctionInvocationStatus
+    {
+        /// <summary>The operation completed successfully.</summary>
+        RanToCompletion,
+
+        /// <summary>The requested function could not be found.</summary>
+        NotFound,
+
+        /// <summary>The function call failed with an exception.</summary>
+        Exception,
+    }
+}
diff --git a/dotnet/src/Microsoft.Extensions.AI.Agents/MEAI/OpenTelemetryConsts.cs b/dotnet/src/Microsoft.Extensions.AI.Agents/MEAI/OpenTelemetryConsts.cs
new file mode 100644
index 0000000000..89eebe9f61
--- /dev/null
+++ b/dotnet/src/Microsoft.Extensions.AI.Agents/MEAI/OpenTelemetryConsts.cs
@@ -0,0 +1,144 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+// WARNING:
+// This class has been temporarily copied here from MEAI, to allow prototyping
+// functionality that will be moved to MEAI in the future.
+// This file is not intended to be modified.
+
+using System.Diagnostics.CodeAnalysis;
+
+namespace Microsoft.Extensions.AI;
+
+#pragma warning disable CA1716 // Identifiers should not match keywords
+#pragma warning disable S4041 // Type names should not match namespaces
+
+/// <summary>Provides constants used by various telemetry services.</summary>
+[ExcludeFromCodeCoverage]
+internal static class OpenTelemetryConsts
+{
+    public const string DefaultSourceName = "Experimental.Microsoft.Extensions.AI"; // NOTE(review): presumably the default ActivitySource/Meter name - confirm against consumers
+
+    public const string SecondsUnit = "s";
+    public const string TokensUnit = "token";
+
+    public static class Event
+    {
+        public const string Name = "event.name";
+    }
+
+    public static class Error
+    {
+        public const string Type = "error.type";
+    }
+
+    public static class GenAI // names and values under the "gen_ai." prefix
+    {
+        public const string Choice = "gen_ai.choice";
+        public const string SystemName = "gen_ai.system";
+
+        public const string Chat = "chat";
+        public const string Embeddings = "embeddings";
+        public const string ExecuteTool = "execute_tool"; // prefixes the tool-invocation activity name in NewFunctionInvokingChatClient
+
+        public static class Assistant
+        {
+            public const string Message = "gen_ai.assistant.message";
+        }
+
+        public static class Client
+        {
+            public static class OperationDuration
+            {
+                public const string Description = "Measures the duration of a GenAI operation";
+                public const string Name = "gen_ai.client.operation.duration";
+                public static readonly double[] ExplicitBucketBoundaries = [0.01, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64, 1.28, 2.56, 5.12, 10.24, 20.48, 40.96, 81.92]; // each boundary doubles the previous one
+            }
+
+            public static class TokenUsage
+            {
+                public const string Description = "Measures number of input and output tokens used";
+                public const string Name = "gen_ai.client.token.usage";
+                public static readonly int[] ExplicitBucketBoundaries = [1, 4, 16, 64, 256, 1_024, 4_096, 16_384, 65_536, 262_144, 1_048_576, 4_194_304, 16_777_216, 67_108_864]; // each boundary is 4x the previous one
+            }
+        }
+
+        public static class Conversation
+        {
+            public const string Id = "gen_ai.conversation.id";
+        }
+
+        public static class Operation
+        {
+            public const string Name = "gen_ai.operation.name";
+        }
+
+        public static class Output
+        {
+            public const string Type = "gen_ai.output.type";
+        }
+
+        public static class Request
+        {
+            public const string EmbeddingDimensions = "gen_ai.request.embedding.dimensions";
+            public const string FrequencyPenalty = "gen_ai.request.frequency_penalty";
+            public const string Model = "gen_ai.request.model";
+            public const string MaxTokens = "gen_ai.request.max_tokens";
+            public const string PresencePenalty = "gen_ai.request.presence_penalty";
+            public const string Seed = "gen_ai.request.seed";
+            public const string StopSequences = "gen_ai.request.stop_sequences";
+            public const string Temperature = "gen_ai.request.temperature";
+            public const string TopK = "gen_ai.request.top_k";
+            public const string TopP = "gen_ai.request.top_p";
+
+            public static string PerProvider(string providerName, string parameterName) => $"gen_ai.{providerName}.request.{parameterName}"; // inputs are interpolated as-is, with no validation or normalization
+        }
+
+        public static class Response
+        {
+            public const string FinishReasons = "gen_ai.response.finish_reasons";
+            public const string Id = "gen_ai.response.id";
+            public const string Model = "gen_ai.response.model";
+
+            public static string PerProvider(string providerName, string parameterName) => $"gen_ai.{providerName}.response.{parameterName}"; // inputs are interpolated as-is, with no validation or normalization
+        }
+
+        public static class System // NOTE(review): presumably the reason for the S4041 suppression above - confirm
+        {
+            public const string Message = "gen_ai.system.message";
+        }
+
+        public static class Token
+        {
+            public const string Type = "gen_ai.token.type";
+        }
+
+        public static class Tool
+        {
+            public const string Name = "gen_ai.tool.name";
+            public const string Description = "gen_ai.tool.description";
+            public const string Message = "gen_ai.tool.message";
+
+            public static class Call
+            {
+                public const string Id = "gen_ai.tool.call.id";
+            }
+        }
+
+        public static class Usage
+        {
+            public const string InputTokens = "gen_ai.usage.input_tokens";
+            public const string OutputTokens = "gen_ai.usage.output_tokens";
+        }
+
+        public static class User
+        {
+            public const string Message = "gen_ai.user.message";
+        }
+    }
+
+    public static class Server
+    {
+        public const string Address = "server.address";
+        public const string Port = "server.port";
+    }
+}