diff --git a/tests/Microsoft.Identity.Web.Test/CacheExtensionsTests.cs b/tests/Microsoft.Identity.Web.Test/CacheExtensionsTests.cs
index 238e41a9b..04aa56333 100644
--- a/tests/Microsoft.Identity.Web.Test/CacheExtensionsTests.cs
+++ b/tests/Microsoft.Identity.Web.Test/CacheExtensionsTests.cs
@@ -5,13 +5,16 @@
using System.Globalization;
using System.Threading.Tasks;
using Microsoft.Extensions.Caching.Distributed;
+using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Options;
using Microsoft.Identity.Client;
using Microsoft.Identity.Client.Cache;
using Microsoft.Identity.Web.Test.Common;
using Microsoft.Identity.Web.Test.Common.Mocks;
using Microsoft.Identity.Web.TokenCacheProviders.Distributed;
+using Microsoft.Identity.Web.TokenCacheProviders.InMemory;
using Microsoft.IdentityModel.Abstractions;
using Xunit;
@@ -190,6 +193,111 @@ public async Task SingletonMsal_ResultsInCorrectCacheEntries_Test()
}
}
+ #region CacheKeyExtensibility test
+ private const int TokenCacheMemoryLimitInMb = 100;
+ private static MemoryCache s_memoryCache = InitiatlizeMemoryCache();
+
+ private static MemoryCache InitiatlizeMemoryCache()
+ {
+ // For 100 MB limit ... ~2KB per token entry means 50,000 entries
+ var options = Options.Create(new MemoryCacheOptions() { SizeLimit = (TokenCacheMemoryLimitInMb / 2) * 1000 });
+ var cache = new MemoryCache(options);
+
+ return cache;
+ }
+
+ ///
+ /// Token cache for MSAL based on MemoryCache, which can be partitioned by an additional key.
+ /// For app tokens, the default key is ClientID + TenantID (and MSAL also looks for resource).
+ ///
+ private class PartitionedMsalTokenMemoryCacheProvider : MsalMemoryTokenCacheProvider
+ {
+ private readonly string? _cacheKeySuffix;
+
+ ///
+ /// Ctor
+ ///
+ /// A memory cache which can be configured for max size etc.
+ /// Additional cache options, which canbe ignored for app tokens.
+ /// An aditional partition key. If let null, the original cache scoping is used (clientID, tenantID). MSAL also looks for resource.
+ public PartitionedMsalTokenMemoryCacheProvider(
+ IMemoryCache memoryCache,
+ IOptions cacheOptions,
+ string? cachePartition) : base(memoryCache, cacheOptions)
+ {
+ _cacheKeySuffix = cachePartition;
+ }
+
+ public override string GetSuggestedCacheKey(TokenCacheNotificationArgs args)
+ {
+ return base.GetSuggestedCacheKey(args) + (_cacheKeySuffix ?? "");
+ }
+ }
+
+ private async Task GetTokensAssociatedWithKey(string? cachePartition, bool expectCacheHit)
+ {
+ MockHttpMessageHandler? handler = null;
+ MockHttpClientFactory? mockHttpClient = null;
+ try
+ {
+
+ if (expectCacheHit == false)
+ {
+ mockHttpClient = new MockHttpClientFactory();
+ handler = mockHttpClient.AddMockHandler(MockHttpCreator.CreateClientCredentialTokenHandler());
+ }
+
+ var msalMemoryTokenCacheProvider =
+ new PartitionedMsalTokenMemoryCacheProvider(
+ s_memoryCache,
+ Options.Create(new MsalMemoryTokenCacheOptions()),
+ cachePartition: cachePartition);
+
+ var confidentialApp = ConfidentialClientApplicationBuilder
+ .Create(TestConstants.ClientId)
+ .WithAuthority(TestConstants.AuthorityCommonTenant)
+ .WithHttpClientFactory(mockHttpClient)
+ .WithInstanceDiscovery(false)
+ .WithClientSecret(TestConstants.ClientSecret)
+ .Build();
+
+ await msalMemoryTokenCacheProvider.InitializeAsync(confidentialApp.AppTokenCache).ConfigureAwait(false);
+
+ AuthenticationResult result = await confidentialApp
+ .AcquireTokenForClient(["https://graph.microsoft.com/.default"])
+ .ExecuteAsync()
+ .ConfigureAwait(false);
+
+ Assert.Equal(
+ expectCacheHit ?
+ TokenSource.Cache :
+ TokenSource.IdentityProvider,
+ result.AuthenticationResultMetadata.TokenSource);
+
+ return result;
+
+ }
+ finally
+ {
+ handler?.Dispose();
+ mockHttpClient?.Dispose();
+ }
+ }
+
+ #endregion
+
+ [Fact]
+ public async Task CacheKeyExtensibility()
+ {
+ var result = await GetTokensAssociatedWithKey("foo", expectCacheHit: false).ConfigureAwait(false);
+ result = await GetTokensAssociatedWithKey("bar", expectCacheHit: false).ConfigureAwait(false);
+ result = await GetTokensAssociatedWithKey(null, expectCacheHit: false).ConfigureAwait(false);
+
+ result = await GetTokensAssociatedWithKey("foo", expectCacheHit: true).ConfigureAwait(false);
+ result = await GetTokensAssociatedWithKey("bar", expectCacheHit: true).ConfigureAwait(false);
+ result = await GetTokensAssociatedWithKey(null, expectCacheHit: true).ConfigureAwait(false);
+ }
+
private enum CacheType
{
InMemory,