diff --git a/tests/Microsoft.Identity.Web.Test/CacheExtensionsTests.cs b/tests/Microsoft.Identity.Web.Test/CacheExtensionsTests.cs index 238e41a9b..04aa56333 100644 --- a/tests/Microsoft.Identity.Web.Test/CacheExtensionsTests.cs +++ b/tests/Microsoft.Identity.Web.Test/CacheExtensionsTests.cs @@ -5,13 +5,16 @@ using System.Globalization; using System.Threading.Tasks; using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; using Microsoft.Identity.Client; using Microsoft.Identity.Client.Cache; using Microsoft.Identity.Web.Test.Common; using Microsoft.Identity.Web.Test.Common.Mocks; using Microsoft.Identity.Web.TokenCacheProviders.Distributed; +using Microsoft.Identity.Web.TokenCacheProviders.InMemory; using Microsoft.IdentityModel.Abstractions; using Xunit; @@ -190,6 +193,111 @@ public async Task SingletonMsal_ResultsInCorrectCacheEntries_Test() } } + #region CacheKeyExtensibility test + private const int TokenCacheMemoryLimitInMb = 100; + private static MemoryCache s_memoryCache = InitiatlizeMemoryCache(); + + private static MemoryCache InitiatlizeMemoryCache() + { + // For 100 MB limit ... ~2KB per token entry means 50,000 entries + var options = Options.Create(new MemoryCacheOptions() { SizeLimit = (TokenCacheMemoryLimitInMb / 2) * 1000 }); + var cache = new MemoryCache(options); + + return cache; + } + + /// + /// Token cache for MSAL based on MemoryCache, which can be partitioned by an additional key. + /// For app tokens, the default key is ClientID + TenantID (and MSAL also looks for resource). + /// + private class PartitionedMsalTokenMemoryCacheProvider : MsalMemoryTokenCacheProvider + { + private readonly string? _cacheKeySuffix; + + /// + /// Ctor + /// + /// A memory cache which can be configured for max size etc. + /// Additional cache options, which canbe ignored for app tokens. + /// An aditional partition key. If let null, the original cache scoping is used (clientID, tenantID). MSAL also looks for resource. + public PartitionedMsalTokenMemoryCacheProvider( + IMemoryCache memoryCache, + IOptions cacheOptions, + string? cachePartition) : base(memoryCache, cacheOptions) + { + _cacheKeySuffix = cachePartition; + } + + public override string GetSuggestedCacheKey(TokenCacheNotificationArgs args) + { + return base.GetSuggestedCacheKey(args) + (_cacheKeySuffix ?? ""); + } + } + + private async Task GetTokensAssociatedWithKey(string? cachePartition, bool expectCacheHit) + { + MockHttpMessageHandler? handler = null; + MockHttpClientFactory? mockHttpClient = null; + try + { + + if (expectCacheHit == false) + { + mockHttpClient = new MockHttpClientFactory(); + handler = mockHttpClient.AddMockHandler(MockHttpCreator.CreateClientCredentialTokenHandler()); + } + + var msalMemoryTokenCacheProvider = + new PartitionedMsalTokenMemoryCacheProvider( + s_memoryCache, + Options.Create(new MsalMemoryTokenCacheOptions()), + cachePartition: cachePartition); + + var confidentialApp = ConfidentialClientApplicationBuilder + .Create(TestConstants.ClientId) + .WithAuthority(TestConstants.AuthorityCommonTenant) + .WithHttpClientFactory(mockHttpClient) + .WithInstanceDiscovery(false) + .WithClientSecret(TestConstants.ClientSecret) + .Build(); + + await msalMemoryTokenCacheProvider.InitializeAsync(confidentialApp.AppTokenCache).ConfigureAwait(false); + + AuthenticationResult result = await confidentialApp + .AcquireTokenForClient(["https://graph.microsoft.com/.default"]) + .ExecuteAsync() + .ConfigureAwait(false); + + Assert.Equal( + expectCacheHit ? + TokenSource.Cache : + TokenSource.IdentityProvider, + result.AuthenticationResultMetadata.TokenSource); + + return result; + + } + finally + { + handler?.Dispose(); + mockHttpClient?.Dispose(); + } + } + + #endregion + + [Fact] + public async Task CacheKeyExtensibility() + { + var result = await GetTokensAssociatedWithKey("foo", expectCacheHit: false).ConfigureAwait(false); + result = await GetTokensAssociatedWithKey("bar", expectCacheHit: false).ConfigureAwait(false); + result = await GetTokensAssociatedWithKey(null, expectCacheHit: false).ConfigureAwait(false); + + result = await GetTokensAssociatedWithKey("foo", expectCacheHit: true).ConfigureAwait(false); + result = await GetTokensAssociatedWithKey("bar", expectCacheHit: true).ConfigureAwait(false); + result = await GetTokensAssociatedWithKey(null, expectCacheHit: true).ConfigureAwait(false); + } + private enum CacheType { InMemory,