Skip to content

Commit

Permalink
Merge pull request pnp#1056 from gautamdsheth/feature/token-cache
Browse files Browse the repository at this point in the history
Feature: Improve authentication performance by caching on file system
  • Loading branch information
gautamdsheth authored Aug 30, 2024
2 parents 83825f4 + c8c35a3 commit 279e18c
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 24 deletions.
64 changes: 40 additions & 24 deletions src/lib/PnP.Framework/AuthenticationManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
using Microsoft.Identity.Client.AppConfig;
using Microsoft.Identity.Client.Broker;
using Microsoft.Identity.Client.Extensibility;
using Microsoft.Identity.Client.Extensions.Msal;
using Microsoft.SharePoint.Client;
using PnP.Core.Services;
using PnP.Framework.Http;
using PnP.Framework.Utilities;
using PnP.Framework.Utilities.Cache;
using PnP.Framework.Utilities.Context;
using System;
using System.Configuration;
Expand Down Expand Up @@ -347,14 +349,14 @@ public AuthenticationManager(SecureString accessToken)
/// <param name="managedIdentityUserAssignedIdentifier">The identifier of the User Assigned Managed Identity. Can be the clientId, objectId or resourceId. Mandatory when <paramref name="managedIdentityType"/> is not SystemAssigned. Should be omitted if it is SystemAssigned.</param>
public AuthenticationManager(string endpoint, string identityHeader, ManagedIdentityType managedIdentityType = ManagedIdentityType.SystemAssigned, string managedIdentityUserAssignedIdentifier = null)
{
if(managedIdentityType != ManagedIdentityType.SystemAssigned && string.IsNullOrWhiteSpace(managedIdentityUserAssignedIdentifier))
if (managedIdentityType != ManagedIdentityType.SystemAssigned && string.IsNullOrWhiteSpace(managedIdentityUserAssignedIdentifier))
{
throw new ArgumentException($"When {nameof(managedIdentityType)} is not SystemAssigned, {nameof(managedIdentityUserAssignedIdentifier)} must be provided", nameof(managedIdentityType));
}

authenticationType = managedIdentityType == ManagedIdentityType.SystemAssigned ? ClientContextType.SystemAssignedManagedIdentity : ClientContextType.UserAssignedManagedIdentity;
this.managedIdentityType = managedIdentityType;
this.managedIdentityUserAssignedIdentifier = managedIdentityUserAssignedIdentifier;
this.managedIdentityType = managedIdentityType;
this.managedIdentityUserAssignedIdentifier = managedIdentityUserAssignedIdentifier;

// Construct the URL to call to get the token based on the type of Managed Identity in use
switch (managedIdentityType)
Expand All @@ -379,7 +381,7 @@ public AuthenticationManager(string endpoint, string identityHeader, ManagedIden
Diagnostics.Log.Debug(Constants.LOGGING_SOURCE, "Using the system assigned managed identity");
mi = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned).WithHttpClientFactory(HttpClientFactory).Build();
break;
}
}

}

Expand Down Expand Up @@ -412,12 +414,14 @@ public AuthenticationManager(string clientId, string username, SecureString pass
if (!string.IsNullOrEmpty(redirectUrl))
{
builder = builder.WithRedirectUri(redirectUrl);
}
}
builder.WithLegacyCacheCompatibility(false);
this.username = username;
this.password = password;
publicClientApplication = builder.Build();

var cacheHelper = MsalCacheHelperUtility.CreateCacheHelper();
cacheHelper?.RegisterCache(publicClientApplication.UserTokenCache);
// register tokencache if callback provided
tokenCacheCallback?.Invoke(publicClientApplication.UserTokenCache);
authenticationType = ClientContextType.AzureADCredentials;
Expand All @@ -434,7 +438,7 @@ public AuthenticationManager(string clientId, string username, SecureString pass
/// <param name="azureEnvironment">The azure environment to use. Defaults to AzureEnvironment.Production</param>
/// <param name="tokenCacheCallback">If present, after setting up the base flow for authentication this callback will be called to register a custom tokencache. See https://aka.ms/msal-net-token-cache-serialization.</param>
/// <param name="useWAM">If true, uses WAM for authentication. Works only on Windows OS</param>
public AuthenticationManager(string clientId, Action<string, int> openBrowserCallback, string tenantId = null, string successMessageHtml = null, string failureMessageHtml = null, AzureEnvironment azureEnvironment = AzureEnvironment.Production, Action<ITokenCache> tokenCacheCallback = null, bool useWAM = false) : this(clientId, Utilities.OAuth.DefaultBrowserUi.FindFreeLocalhostRedirectUri(), tenantId, azureEnvironment, tokenCacheCallback , new Utilities.OAuth.DefaultBrowserUi(openBrowserCallback, successMessageHtml, failureMessageHtml), useWAM = false)
public AuthenticationManager(string clientId, Action<string, int> openBrowserCallback, string tenantId = null, string successMessageHtml = null, string failureMessageHtml = null, AzureEnvironment azureEnvironment = AzureEnvironment.Production, Action<ITokenCache> tokenCacheCallback = null, bool useWAM = false) : this(clientId, Utilities.OAuth.DefaultBrowserUi.FindFreeLocalhostRedirectUri(), tenantId, azureEnvironment, tokenCacheCallback, new Utilities.OAuth.DefaultBrowserUi(openBrowserCallback, successMessageHtml, failureMessageHtml), useWAM = false)
{
}

Expand All @@ -452,30 +456,39 @@ public AuthenticationManager(string clientId, string redirectUrl = null, string
{
this.azureEnvironment = azureEnvironment;

var builder = PublicClientApplicationBuilder.Create(clientId).WithHttpClientFactory(HttpClientFactory);
if (useWAM && Environment.OSVersion.Platform == PlatformID.Win32NT)
PublicClientApplicationBuilder builder = PublicClientApplicationBuilder.Create(clientId).WithHttpClientFactory(HttpClientFactory); ;
builder = GetBuilderWithAuthority(builder, azureEnvironment);
if (useWAM && SharedUtilities.IsWindowsPlatform())
{
BrokerOptions brokerOptions = new(BrokerOptions.OperatingSystems.Windows)
{
Title = "Login with M365 PnP"
Title = "Login with M365 PnP",
ListOperatingSystemAccounts = true,
};
builder = builder.WithBroker(brokerOptions).WithDefaultRedirectUri().WithParentActivityOrWindow(WindowHandleUtilities.GetConsoleOrTerminalWindow).WithHttpClientFactory(HttpClientFactory);
}

builder = GetBuilderWithAuthority(builder, azureEnvironment);
builder = builder.WithBroker(brokerOptions).WithDefaultRedirectUri().WithParentActivityOrWindow(WindowHandleUtilities.GetConsoleOrTerminalWindow);

if (!string.IsNullOrEmpty(redirectUrl))
{
builder = builder.WithRedirectUri(redirectUrl);
if (!string.IsNullOrEmpty(tenantId))
{
builder = builder.WithTenantId(tenantId);
}
}
if (!string.IsNullOrEmpty(tenantId))
else
{
builder = builder.WithTenantId(tenantId);
if (!string.IsNullOrEmpty(redirectUrl))
{
builder = builder.WithRedirectUri(redirectUrl);
}
if (!string.IsNullOrEmpty(tenantId))
{
builder = builder.WithTenantId(tenantId);
}
this.customWebUi = customWebUi;
}
builder.WithLegacyCacheCompatibility(false);
publicClientApplication = builder.Build();

this.customWebUi = customWebUi;
var cacheHelper = MsalCacheHelperUtility.CreateCacheHelper();
cacheHelper?.RegisterCache(publicClientApplication.UserTokenCache);

// register tokencache if callback provided
tokenCacheCallback?.Invoke(publicClientApplication.UserTokenCache);
Expand Down Expand Up @@ -524,6 +537,9 @@ public AuthenticationManager(string clientId, string tenantId, Func<DeviceCodeRe
builder.WithLegacyCacheCompatibility(false);
publicClientApplication = builder.Build();

var cacheHelper = MsalCacheHelperUtility.CreateCacheHelper();
cacheHelper?.RegisterCache(publicClientApplication.UserTokenCache);

// register tokencache if callback provided
tokenCacheCallback?.Invoke(publicClientApplication.UserTokenCache);

Expand Down Expand Up @@ -831,7 +847,7 @@ public async Task<string> GetAccessTokenAsync(string[] scopes, CancellationToken
{
AuthenticationResult authResult = null;


Diagnostics.Log.Debug("GetAccessTokenAsync", $"Authentication type: {authenticationType}");

switch (authenticationType)
Expand Down Expand Up @@ -954,7 +970,7 @@ public async Task<string> GetAccessTokenAsync(string[] scopes, CancellationToken
// If it is a Uri, we're going to assume the audience is the root part of the Uri, i.e. tenant.sharepoint.com
var audienceUri = new Uri(scopes.FirstOrDefault(s => Uri.IsWellFormedUriString(s, UriKind.Absolute)) ?? $"https://{GetGraphEndPoint()}");
return GetManagedIdentityToken($"{audienceUri.Scheme}://{audienceUri.Authority}");
}
}
case ClientContextType.PnPCoreSdk:
{
return await this.authenticationProvider.GetAccessTokenAsync(uri, scopes).ConfigureAwait(false);
Expand Down Expand Up @@ -1490,7 +1506,7 @@ public static string GetACSEndPoint(AzureEnvironment environment)
AzureEnvironment.Production => "accesscontrol.windows.net",
AzureEnvironment.Germany => "microsoftonline.de",
AzureEnvironment.China => "accesscontrol.chinacloudapi.cn",
AzureEnvironment.USGovernment => "accesscontrol.windows.net",
AzureEnvironment.USGovernment => "accesscontrol.windows.net",
AzureEnvironment.USGovernmentHigh => "microsoftonline.us",
AzureEnvironment.USGovernmentDoD => "microsoftonline.us",
AzureEnvironment.PPE => "windows-ppe.net",
Expand Down Expand Up @@ -1928,7 +1944,7 @@ public ConfidentialClientApplicationBuilder GetBuilderWithAuthority(Confidential
{
switch (azureEnvironment)
{
case AzureEnvironment.USGovernment:
case AzureEnvironment.USGovernment:
{
builder = builder.WithAuthority(AzureCloudInstance.AzurePublic, AadAuthorityAudience.AzureAdMyOrg);
break;
Expand Down
86 changes: 86 additions & 0 deletions src/lib/PnP.Framework/Utilities/Cache/MsalCacheHelperUtility.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
using Microsoft.Identity.Client.Extensions.Msal;
using System;
using System.Collections.Generic;
using System.IO;

namespace PnP.Framework.Utilities.Cache
{
public class MsalCacheHelperUtility
{

private static MsalCacheHelper MsalCacheHelper;
private static readonly object ObjectLock = new();

private static class Config
{
// Cache settings
public const string CacheFileName = "m365pnpmsal.cache";
public readonly static string CacheDir = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData), ".M365PnPAuthService");

public const string KeyChainServiceName = "M365.PnP.Framework";
public const string KeyChainAccountName = "M365PnPAuthCache";

public const string LinuxKeyRingSchema = "com.m365.pnp.auth.tokencache";
public const string LinuxKeyRingCollection = MsalCacheHelper.LinuxKeyRingDefaultCollection;
public const string LinuxKeyRingLabel = "MSAL token cache for M365 PnP Framework.";
public static readonly KeyValuePair<string, string> LinuxKeyRingAttr1 = new KeyValuePair<string, string>("Version", "1");
public static readonly KeyValuePair<string, string> LinuxKeyRingAttr2 = new KeyValuePair<string, string>("Product", "M365PnPAuth");
}

public static MsalCacheHelper CreateCacheHelper()
{
if (MsalCacheHelper == null)
{
lock (ObjectLock)
{
if (MsalCacheHelper == null)
{
StorageCreationProperties storageProperties;

try
{
storageProperties = new StorageCreationPropertiesBuilder(
Config.CacheFileName,
Config.CacheDir)
.WithLinuxKeyring(
Config.LinuxKeyRingSchema,
Config.LinuxKeyRingCollection,
Config.LinuxKeyRingLabel,
Config.LinuxKeyRingAttr1,
Config.LinuxKeyRingAttr2)
.WithMacKeyChain(
Config.KeyChainServiceName,
Config.KeyChainAccountName)
.Build();

var cacheHelper = MsalCacheHelper.CreateAsync(storageProperties).ConfigureAwait(false).GetAwaiter().GetResult();

cacheHelper.VerifyPersistence();
MsalCacheHelper = cacheHelper;

}
catch (MsalCachePersistenceException)
{
// do not use the same file name so as not to overwrite the encrypted version
storageProperties = new StorageCreationPropertiesBuilder(
Config.CacheFileName + ".plaintext",
Config.CacheDir)
.WithUnprotectedFile()
.Build();

var cacheHelper = MsalCacheHelper.CreateAsync(storageProperties).ConfigureAwait(false).GetAwaiter().GetResult();
cacheHelper.VerifyPersistence();

MsalCacheHelper = cacheHelper;
}
catch
{
MsalCacheHelper = null;
}
}
}
}
return MsalCacheHelper;
}
}
}

0 comments on commit 279e18c

Please sign in to comment.