// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Agents.CopilotStudio.Client;
using Microsoft.Identity.Client;
using Microsoft.Identity.Client.Extensions.Msal;
using Microsoft.Shared.Diagnostics;

namespace CopilotStudio.IntegrationTests.Support;

#pragma warning disable CA1812 // Internal class that is apparently never instantiated.

/// <summary>
/// A <see cref="HttpClientHandler"/> that adds an authentication token to the request headers for Copilot Studio API calls.
/// </summary>
/// <remarks>
/// For more information on how to setup various authentication flows, see the Microsoft Identity documentation at https://aka.ms/msal.
/// </remarks>
internal sealed class CopilotStudioTokenHandler : HttpClientHandler
{
    private const string AuthenticationHeader = "Bearer";
    private const string CacheFolderName = "mcs_client_console";
    private const string KeyChainServiceName = "copilot_studio_client_app";
    private const string KeyChainAccountName = "copilot_studio_client";

    private readonly CopilotStudioConnectionSettings _settings;
    private readonly string[] _scopes;

    // Lazily-built MSAL clients. Each is assigned only after its persistent token cache
    // has been registered, so a failed initialization is retried on the next request.
    private IConfidentialClientApplication? _clientApplication;
    private IPublicClientApplication? _publicClientApplication;

    /// <summary>
    /// Initializes a new instance of the <see cref="CopilotStudioTokenHandler"/> class with the specified connection settings.
    /// </summary>
    /// <param name="settings">The connection settings for Copilot Studio. Must not be <see langword="null"/>.</param>
    public CopilotStudioTokenHandler(CopilotStudioConnectionSettings settings)
    {
        Throw.IfNull(settings);
        this._settings = settings;
        this._scopes = [CopilotClient.ScopeFromSettings(this._settings)];
    }

    /// <inheritdoc/>
    /// <remarks>
    /// A bearer token is acquired only when the request does not already carry an <c>Authorization</c> header,
    /// so a caller-supplied header is left untouched.
    /// </remarks>
    protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
    {
        if (request.Headers.Authorization is null)
        {
            AuthenticationResult authResponse = await this.AuthenticateAsync(cancellationToken).ConfigureAwait(false);
            request.Headers.Authorization = new AuthenticationHeaderValue(AuthenticationHeader, authResponse.AccessToken);
        }

        return await base.SendAsync(request, cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Acquires a token using the flow selected by <see cref="CopilotStudioConnectionSettings.UseInteractiveAuthentication"/>.
    /// </summary>
    private Task<AuthenticationResult> AuthenticateAsync(CancellationToken cancellationToken) =>
        this._settings.UseInteractiveAuthentication
            ? this.AuthenticateInteractiveAsync(cancellationToken)
            : this.AuthenticateServiceAsync(cancellationToken);

    /// <summary>
    /// Acquires an app-only token with the client-credentials flow (client id + client secret).
    /// </summary>
    /// <param name="cancellationToken">Token used to cancel the MSAL request.</param>
    /// <returns>The MSAL authentication result.</returns>
    private async Task<AuthenticationResult> AuthenticateServiceAsync(CancellationToken cancellationToken)
    {
        if (this._clientApplication is null)
        {
            IConfidentialClientApplication clientApplication = ConfidentialClientApplicationBuilder.Create(this._settings.AppClientId)
                .WithAuthority(AzureCloudInstance.AzurePublic, this._settings.TenantId)
                .WithClientSecret(this._settings.AppClientSecret)
                .Build();

            MsalCacheHelper tokenCacheHelper = await CreateCacheHelperAsync("AppTokenCache").ConfigureAwait(false);
            tokenCacheHelper.RegisterCache(clientApplication.AppTokenCache);

            // Publish only once the cache is registered; if cache creation threw above,
            // the next call rebuilds the client instead of reusing one without a cache.
            this._clientApplication = clientApplication;
        }

        return await this._clientApplication
            .AcquireTokenForClient(this._scopes)
            .ExecuteAsync(cancellationToken)
            .ConfigureAwait(false);
    }

    /// <summary>
    /// Acquires a user token, first silently from the persistent cache and, if user interaction
    /// is required, through the interactive (browser) flow redirecting to <c>http://localhost</c>.
    /// </summary>
    /// <param name="cancellationToken">Token used to cancel the MSAL request.</param>
    /// <returns>The MSAL authentication result.</returns>
    private async Task<AuthenticationResult> AuthenticateInteractiveAsync(CancellationToken cancellationToken = default)
    {
        if (this._publicClientApplication is null)
        {
            IPublicClientApplication publicApplication = PublicClientApplicationBuilder.Create(this._settings.AppClientId)
                .WithAuthority(AadAuthorityAudience.AzureAdMyOrg)
                .WithTenantId(this._settings.TenantId)
                .WithRedirectUri("http://localhost")
                .Build();

            MsalCacheHelper tokenCacheHelper = await CreateCacheHelperAsync("TokenCache").ConfigureAwait(false);
            tokenCacheHelper.RegisterCache(publicApplication.UserTokenCache);

            // Reuse the client across requests instead of rebuilding it (and re-reading the
            // cache file) on every unauthenticated request.
            this._publicClientApplication = publicApplication;
        }

        IPublicClientApplication app = this._publicClientApplication;
        IEnumerable<IAccount> accounts = await app.GetAccountsAsync().ConfigureAwait(false);
        IAccount? account = accounts.FirstOrDefault();

        try
        {
            return await app.AcquireTokenSilent(this._scopes, account).ExecuteAsync(cancellationToken).ConfigureAwait(false);
        }
        catch (MsalUiRequiredException)
        {
            // No usable cached token (or no cached account): fall back to the interactive prompt.
            return await app.AcquireTokenInteractive(this._scopes).ExecuteAsync(cancellationToken).ConfigureAwait(false);
        }
    }

    /// <summary>
    /// Creates an MSAL cache helper that persists tokens to <paramref name="cacheFileName"/> inside a
    /// <c>mcs_client_console</c> folder under the application's base directory.
    /// </summary>
    /// <param name="cacheFileName">The name of the cache file.</param>
    /// <returns>The created <see cref="MsalCacheHelper"/>.</returns>
    private static async Task<MsalCacheHelper> CreateCacheHelperAsync(string cacheFileName)
    {
        string cacheDirectory = Path.Combine(AppContext.BaseDirectory, CacheFolderName);

        // CreateDirectory does nothing if the directory already exists, so no existence check is needed.
        Directory.CreateDirectory(cacheDirectory);

        StorageCreationPropertiesBuilder storageProperties = new(cacheFileName, cacheDirectory);
        if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
        {
            // NOTE: the Linux configuration here uses an unprotected (unencrypted) cache file.
            storageProperties.WithLinuxUnprotectedFile();
        }
        else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
        {
            storageProperties.WithMacKeyChain(KeyChainServiceName, KeyChainAccountName);
        }

        // Other platforms (e.g. Windows) use the builder's default storage settings.
        return await MsalCacheHelper.CreateAsync(storageProperties.Build()).ConfigureAwait(false);
    }
}