// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Agents.CopilotStudio.Client;
using Microsoft.Identity.Client;
using Microsoft.Identity.Client.Extensions.Msal;
using Microsoft.Shared.Diagnostics;
namespace CopilotStudio.IntegrationTests.Support;
#pragma warning disable CA1812 // Internal class that is apparently never instantiated.
/// <summary>
/// A <see cref="HttpClientHandler"/> that adds an authentication token to the request headers for Copilot Studio API calls.
/// </summary>
/// <remarks>
/// For more information on how to set up various authentication flows, see the Microsoft Identity documentation at https://aka.ms/msal.
/// </remarks>
internal sealed class CopilotStudioTokenHandler : HttpClientHandler
{
    private const string AuthenticationHeader = "Bearer";
    private const string CacheFolderName = "mcs_client_console";
    private const string KeyChainServiceName = "copilot_studio_client_app";
    private const string KeyChainAccountName = "copilot_studio_client";

    private readonly CopilotStudioConnectionSettings _settings;
    private readonly string[] _scopes;

    // Client applications are built lazily on first use and then reused, so their in-memory token
    // caches and the persistent cache registration survive across requests.
    // NOTE(review): initialization is not synchronized; concurrent first requests may each build an
    // application, and whichever is assigned last is kept.
    private IConfidentialClientApplication? _clientApplication;
    private IPublicClientApplication? _publicClientApplication;

    /// <summary>
    /// Initializes a new instance of the <see cref="CopilotStudioTokenHandler"/> class with the specified connection settings.
    /// </summary>
    /// <param name="settings">The connection settings for Copilot Studio.</param>
    public CopilotStudioTokenHandler(CopilotStudioConnectionSettings settings)
    {
        Throw.IfNull(settings);
        this._settings = settings;
        this._scopes = [CopilotClient.ScopeFromSettings(this._settings)];
    }

    /// <inheritdoc/>
    /// <remarks>
    /// A bearer token is acquired and attached only when the request does not already carry an
    /// <c>Authorization</c> header; an existing header is left untouched.
    /// </remarks>
    protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
    {
        if (request.Headers.Authorization is null)
        {
            AuthenticationResult authResponse = await this.AuthenticateAsync(cancellationToken).ConfigureAwait(false);
            request.Headers.Authorization = new AuthenticationHeaderValue(AuthenticationHeader, authResponse.AccessToken);
        }

        return await base.SendAsync(request, cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Acquires a token using the flow selected by <see cref="CopilotStudioConnectionSettings.UseInteractiveAuthentication"/>.
    /// </summary>
    /// <param name="cancellationToken">Token used to cancel the acquisition.</param>
    /// <returns>The MSAL authentication result carrying the access token.</returns>
    private Task<AuthenticationResult> AuthenticateAsync(CancellationToken cancellationToken) =>
        this._settings.UseInteractiveAuthentication
            ? this.AuthenticateInteractiveAsync(cancellationToken)
            : this.AuthenticateServiceAsync(cancellationToken);

    /// <summary>
    /// Acquires an app-only token with the client-credentials flow (client id, tenant id and client secret from settings).
    /// </summary>
    /// <param name="cancellationToken">Token used to cancel the acquisition.</param>
    /// <returns>The MSAL authentication result carrying the access token.</returns>
    private async Task<AuthenticationResult> AuthenticateServiceAsync(CancellationToken cancellationToken)
    {
        IConfidentialClientApplication app = await this.GetConfidentialClientApplicationAsync().ConfigureAwait(false);
        return await app.AcquireTokenForClient(this._scopes).ExecuteAsync(cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Returns the cached confidential client application, building it and attaching the persistent
    /// "AppTokenCache" on first use.
    /// </summary>
    private async Task<IConfidentialClientApplication> GetConfidentialClientApplicationAsync()
    {
        if (this._clientApplication is not null)
        {
            return this._clientApplication;
        }

        IConfidentialClientApplication app = ConfidentialClientApplicationBuilder.Create(this._settings.AppClientId)
            .WithAuthority(AzureCloudInstance.AzurePublic, this._settings.TenantId)
            .WithClientSecret(this._settings.AppClientSecret)
            .Build();

        MsalCacheHelper tokenCacheHelper = await CreateCacheHelperAsync("AppTokenCache").ConfigureAwait(false);
        tokenCacheHelper.RegisterCache(app.AppTokenCache);

        // Publish only after the persistent cache is attached, so a failure above is retried on the
        // next call instead of leaving behind an application without its persistent cache.
        this._clientApplication = app;
        return app;
    }

    /// <summary>
    /// Acquires a user token: first silently for the first cached account, falling back to an
    /// interactive (browser) sign-in when MSAL reports that user interaction is required.
    /// </summary>
    /// <param name="cancellationToken">Token used to cancel the acquisition.</param>
    /// <returns>The MSAL authentication result carrying the access token.</returns>
    private async Task<AuthenticationResult> AuthenticateInteractiveAsync(CancellationToken cancellationToken = default)
    {
        IPublicClientApplication app = await this.GetPublicClientApplicationAsync().ConfigureAwait(false);

        IEnumerable<IAccount> accounts = await app.GetAccountsAsync().ConfigureAwait(false);
        IAccount? account = accounts.FirstOrDefault();

        try
        {
            return await app.AcquireTokenSilent(this._scopes, account).ExecuteAsync(cancellationToken).ConfigureAwait(false);
        }
        catch (MsalUiRequiredException)
        {
            // No usable cached token (or no cached account): prompt the user.
            return await app.AcquireTokenInteractive(this._scopes).ExecuteAsync(cancellationToken).ConfigureAwait(false);
        }
    }

    /// <summary>
    /// Returns the cached public client application, building it and attaching the persistent
    /// "TokenCache" on first use.
    /// </summary>
    private async Task<IPublicClientApplication> GetPublicClientApplicationAsync()
    {
        if (this._publicClientApplication is not null)
        {
            return this._publicClientApplication;
        }

        IPublicClientApplication app = PublicClientApplicationBuilder.Create(this._settings.AppClientId)
            .WithAuthority(AadAuthorityAudience.AzureAdMyOrg)
            .WithTenantId(this._settings.TenantId)
            .WithRedirectUri("http://localhost")
            .Build();

        MsalCacheHelper tokenCacheHelper = await CreateCacheHelperAsync("TokenCache").ConfigureAwait(false);
        tokenCacheHelper.RegisterCache(app.UserTokenCache);

        // Publish only after the persistent cache is attached (see GetConfidentialClientApplicationAsync).
        this._publicClientApplication = app;
        return app;
    }

    /// <summary>
    /// Creates an <see cref="MsalCacheHelper"/> that persists a token cache file named
    /// <paramref name="cacheFileName"/> under a folder next to the application binaries.
    /// </summary>
    /// <param name="cacheFileName">The cache file name within the cache folder.</param>
    /// <returns>The cache helper, ready to be registered with an MSAL token cache.</returns>
    private static async Task<MsalCacheHelper> CreateCacheHelperAsync(string cacheFileName)
    {
        string cacheDirectory = Path.Combine(AppContext.BaseDirectory, CacheFolderName);

        // CreateDirectory does nothing when the directory already exists.
        Directory.CreateDirectory(cacheDirectory);

        StorageCreationPropertiesBuilder storageProperties = new(cacheFileName, cacheDirectory);
        if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
        {
            // Stores the cache as an unprotected (unencrypted) file.
            storageProperties.WithLinuxUnprotectedFile();
        }
        else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
        {
            storageProperties.WithMacKeyChain(KeyChainServiceName, KeyChainAccountName);
        }

        // Other platforms use the library's default storage protection.
        return await MsalCacheHelper.CreateAsync(storageProperties.Build()).ConfigureAwait(false);
    }
}