From 6bc1c0ac0e2ac62ea1a897cf9207ac5ec54e8d04 Mon Sep 17 00:00:00 2001 From: Anu Thomas Chandy Date: Thu, 12 Apr 2018 16:46:02 -0700 Subject: [PATCH] [VM, VMSS MSI]: Adding support for IMDS based MSI token retrival (#263) * Adding support for IMDS based MSI token retrival * Addressing review comments: Move MSI versions and endpoint to const, using delay provider from SdkContext --- .../Authentication/MSILoginInformation.cs | 2 + .../Authentication/MSITokenProvider.cs | 241 ++++++++++++++++-- 2 files changed, 218 insertions(+), 25 deletions(-) diff --git a/src/ResourceManagement/ResourceManager/Authentication/MSILoginInformation.cs b/src/ResourceManagement/ResourceManager/Authentication/MSILoginInformation.cs index 158d3b4f3b6..e9b02b45009 100644 --- a/src/ResourceManagement/ResourceManager/Authentication/MSILoginInformation.cs +++ b/src/ResourceManagement/ResourceManager/Authentication/MSILoginInformation.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. See License.txt in the project root for license information. using Microsoft.Azure.Management.ResourceManager.Fluent.Core; +using System; namespace Microsoft.Azure.Management.ResourceManager.Fluent.Authentication { @@ -26,6 +27,7 @@ public MSIResourceType ResourceType /// /// Get or Set the MSI extension port to retrieve access token from. /// + [Obsolete("Port is used for MSI VM extension based login, login using MSI VM extension is deprecated infavour of IMDS based login")] public int? Port { get; set; diff --git a/src/ResourceManagement/ResourceManager/Authentication/MSITokenProvider.cs b/src/ResourceManagement/ResourceManager/Authentication/MSITokenProvider.cs index 4df7a4797c0..56cafab7917 100644 --- a/src/ResourceManagement/ResourceManager/Authentication/MSITokenProvider.cs +++ b/src/ResourceManagement/ResourceManager/Authentication/MSITokenProvider.cs @@ -5,7 +5,9 @@ using Microsoft.Rest; using Newtonsoft.Json; using System; +using System.Collections.Concurrent; using System.Collections.Generic; +using System.Linq; using System.Net.Http; using System.Net.Http.Headers; using System.Threading; @@ -14,13 +16,24 @@ namespace Microsoft.Azure.Management.ResourceManager.Fluent.Authentication { /// - /// TokenProvider that can retrieve AD acess token from the local MSI port. + /// TokenProvider that can retrieve AD acess token from the local MSI port & IMDS service (for VM & VMSS) and + /// from environment (for AppService). /// public class MSITokenProvider : ITokenProvider, IBeta { + private readonly IList retrySlots = new List(new int[] { 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597, 2584, 4181, 6765 }); + private readonly int maxRetry; + private ConcurrentDictionary cache = new ConcurrentDictionary(); + private static SemaphoreSlim semaphoreSlim = new SemaphoreSlim(1, 1); + private readonly string resource; private readonly MSILoginInformation msiLoginInformation; + private const string imdsEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"; + private const string imdsMsiApiVersion = "2018-02-01"; + private const string appServiceMsiApiVersion = "2017-09-01"; + + /// /// Creates MSITokenProvider. /// @@ -45,7 +58,49 @@ public async Task GetAuthenticationHeaderAsync(Cancel private async Task GetAuthenticationHeaderForVirtualMachineAsync(string resource, CancellationToken cancellationToken = default(CancellationToken)) { - int port = msiLoginInformation.Port == null ? 50342 : msiLoginInformation.Port.Value; + if (msiLoginInformation.Port != null) + { + // Token retrival from VM extension will be removed in the next v1.10 release as IMDS service + // replaces VM extension. MsiLoginInformation.Port is marked as deprecated and will be removed + // + var token = await GetTokenFromMSIExtensionAsync(msiLoginInformation.Port.Value, + resource, + cancellationToken); + return new AuthenticationHeaderValue(token.TokenType, token.AccessToken); + } + else + { + var token = await GetTokenFromIMDSEndpointAsync(resource, cancellationToken); + return new AuthenticationHeaderValue(token.TokenType, token.AccessToken); + } + } + + private async Task GetAuthenticationHeaderForAppServiceAsync(string resource, CancellationToken cancellationToken = default(CancellationToken)) + { + var endpoint = Environment.GetEnvironmentVariable("MSI_ENDPOINT") ?? throw new ArgumentNullException("MSI_ENDPOINT"); + var secret = Environment.GetEnvironmentVariable("MSI_SECRET") ?? throw new ArgumentNullException("MSI_SECRET"); + HttpRequestMessage msiRequest = new HttpRequestMessage(HttpMethod.Get, $"{endpoint}?resource={resource}&api-version={MSITokenProvider.appServiceMsiApiVersion}"); + msiRequest.Headers.Add("Metadata", "true"); + msiRequest.Headers.Add("Secret", secret); + + var msiResponse = await (new HttpClient()).SendAsync(msiRequest, cancellationToken); + string content = await msiResponse.Content.ReadAsStringAsync(); + dynamic loginInfo = JsonConvert.DeserializeObject(content); + string tokenType = loginInfo.token_type; + if (tokenType == null) + { + throw MSILoginException.TokenTypeNotFound(content); + } + string accessToken = loginInfo.access_token; + if (accessToken == null) + { + throw MSILoginException.AcessTokenNotFound(content); + } + return new AuthenticationHeaderValue(tokenType, accessToken); + } + + private async Task GetTokenFromMSIExtensionAsync(int port, string resource, CancellationToken cancellationToken) + { HttpRequestMessage msiRequest = new HttpRequestMessage(HttpMethod.Post, $"http://localhost:{port}/oauth2/token"); msiRequest.Headers.Add("Metadata", "true"); @@ -65,47 +120,176 @@ public async Task GetAuthenticationHeaderAsync(Cancel { parameters.Add("msi_res_id", this.msiLoginInformation.UserAssignedIdentityResourceId); } + else + { + throw new ArgumentException("MSI: UserAssignedIdentityObjectId, UserAssignedIdentityClientId or UserAssignedIdentityResourceId must be set"); + } msiRequest.Content = new FormUrlEncodedContent(parameters); var msiResponse = await (new HttpClient()).SendAsync(msiRequest, cancellationToken); string content = await msiResponse.Content.ReadAsStringAsync(); dynamic loginInfo = JsonConvert.DeserializeObject(content); - string tokenType = loginInfo.token_type; - if (tokenType == null) + if (loginInfo.access_token == null) { - throw MSILoginException.TokenTypeNotFound(content); + throw MSILoginException.AcessTokenNotFound(content); } - string accessToken = loginInfo.access_token; - if (accessToken == null) + if (loginInfo.token_type == null) { - throw MSILoginException.AcessTokenNotFound(content); + throw MSILoginException.TokenTypeNotFound(content); } - return new AuthenticationHeaderValue(tokenType, accessToken); + // + MSIToken msiToken = new MSIToken + { + AccessToken = loginInfo.access_token, + TokenType = loginInfo.token_type + }; + return msiToken; } - private async Task GetAuthenticationHeaderForAppServiceAsync(string resource, CancellationToken cancellationToken = default(CancellationToken)) + private async Task GetTokenFromIMDSEndpointAsync(string resource, CancellationToken cancellationToken) { - var endpoint = Environment.GetEnvironmentVariable("MSI_ENDPOINT") ?? throw new ArgumentNullException("MSI_ENDPOINT"); - var secret = Environment.GetEnvironmentVariable("MSI_SECRET") ?? throw new ArgumentNullException("MSI_SECRET"); - HttpRequestMessage msiRequest = new HttpRequestMessage(HttpMethod.Get, $"{endpoint}?resource={resource}&api-version=2017-09-01"); - msiRequest.Headers.Add("Metadata", "true"); - msiRequest.Headers.Add("Secret", secret); + // First hit cache + // + if (cache.TryGetValue(resource, out MSIToken token) == true && !token.IsExpired) + { + return token; + } - var msiResponse = await (new HttpClient()).SendAsync(msiRequest, cancellationToken); - string content = await msiResponse.Content.ReadAsStringAsync(); - dynamic loginInfo = JsonConvert.DeserializeObject(content); - string tokenType = loginInfo.token_type; - if (tokenType == null) + // if cache miss then retrieve from IMDS endpoint with retry + // + await semaphoreSlim.WaitAsync(); + try { - throw MSILoginException.TokenTypeNotFound(content); + // Try hit cache once again in case another thread already updated the cache while this thread was waiting + // + if (cache.TryGetValue(resource, out token) == true && !token.IsExpired) + { + return token; + } + else + { + token = await this.RetrieveTokenFromIMDSWithRetryAsync(resource, cancellationToken); + cache.AddOrUpdate(resource, token, (key, oldValue) => token); + return token; + } } - string accessToken = loginInfo.access_token; - if (accessToken == null) + finally { - throw MSILoginException.AcessTokenNotFound(content); + semaphoreSlim.Release(); + } + } + + private async Task RetrieveTokenFromIMDSWithRetryAsync(string resource, CancellationToken cancellationToken) + { + var uriBuilder = new UriBuilder(MSITokenProvider.imdsEndpoint); + // + var query = new Dictionary + { + ["api-version"] = MSITokenProvider.imdsMsiApiVersion, + ["resource"] = resource + }; + if (this.msiLoginInformation.UserAssignedIdentityObjectId != null) + { + query["object_id"] = this.msiLoginInformation.UserAssignedIdentityObjectId; + } + else if (this.msiLoginInformation.UserAssignedIdentityClientId != null) + { + query["client_id"] = this.msiLoginInformation.UserAssignedIdentityClientId; + } + else if (this.msiLoginInformation.UserAssignedIdentityResourceId != null) + { + query["msi_res_id"] = this.msiLoginInformation.UserAssignedIdentityResourceId; + } + else + { + throw new ArgumentException("MSI: UserAssignedIdentityObjectId, UserAssignedIdentityClientId or UserAssignedIdentityResourceId must be set"); + } + uriBuilder.Query = await new FormUrlEncodedContent(query).ReadAsStringAsync(); + string url = uriBuilder.ToString(); + // + int retry = 1; + while (retry <= maxRetry) + { + // + using (HttpRequestMessage msiRequest = new HttpRequestMessage(HttpMethod.Get, url)) + { + msiRequest.Headers.Add("Metadata", "true"); + using (HttpResponseMessage msiResponse = await (new HttpClient()).SendAsync(msiRequest, cancellationToken)) + { + int statusCode = ((int)msiResponse.StatusCode); + if (ShouldRetry(statusCode)) + { + + int retryTimeout = retrySlots[new Random().Next(retry)]; + await SdkContext.DelayProvider.DelayAsync(retryTimeout * 1000, cancellationToken); + retry++; + } + else if (statusCode != 200) + { + string content = await msiResponse.Content.ReadAsStringAsync(); + throw new HttpRequestException($"Code: {statusCode} ReasonReasonPhrase: {msiResponse.ReasonPhrase} Body: {content}"); + } + else + { + string content = await msiResponse.Content.ReadAsStringAsync(); + dynamic loginInfo = JsonConvert.DeserializeObject(content); + if (loginInfo.access_token == null) + { + throw MSILoginException.AcessTokenNotFound(content); + } + if (loginInfo.token_type == null) + { + throw MSILoginException.TokenTypeNotFound(content); + } + // + MSIToken msiToken = new MSIToken + { + AccessToken = loginInfo.access_token, + ExpireOn = loginInfo.expires_on, + TokenType = loginInfo.token_type + }; + return msiToken; + } + } + } + } + throw new MSIMaxRetryReachedException(maxRetry); + } + + private static bool ShouldRetry(int statusCode) + { + return (statusCode == 429 || statusCode == 404 || (statusCode >= 500 && statusCode <= 599)); + } + + private class MSIToken + { + private static DateTime epoch = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc); + + public string TokenType { get; set; } + + public string AccessToken { get; set; } + + public string ExpireOn { get; set; } + + public bool IsExpired + { + get + { + if (this.ExpireOn == null) + { + return true; + } + else if (!Int32.TryParse(this.ExpireOn, out int iexpireOn)) + { + return true; + } + else + { + return DateTime.UtcNow.AddMinutes(5).CompareTo(epoch.AddSeconds(iexpireOn)) > 0; + } + } } - return new AuthenticationHeaderValue(tokenType, accessToken); } } @@ -123,4 +307,11 @@ public MSITokenProvider Create(string resource) return new MSITokenProvider(resource, msiLoginInformation); } } + + public class MSIMaxRetryReachedException : Exception + { + public MSIMaxRetryReachedException(int maxRetry) : base($"MSI: Failed to acquire tokens after retrying %{ maxRetry} times") + { + } + } }