Skip to content

Commit

Permalink
[VM, VMSS MSI]: Adding support for IMDS based MSI token retrival (#263)
Browse files Browse the repository at this point in the history
* Adding support for IMDS based MSI token retrival

* Addressing review comments: Move MSI versions and endpoint to const, using delay provider from SdkContext
  • Loading branch information
anuchandy authored and Hovsep committed Apr 12, 2018
1 parent 0b9ffd1 commit 6bc1c0a
Show file tree
Hide file tree
Showing 2 changed files with 218 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License. See License.txt in the project root for license information.

using Microsoft.Azure.Management.ResourceManager.Fluent.Core;
using System;

namespace Microsoft.Azure.Management.ResourceManager.Fluent.Authentication
{
Expand All @@ -26,6 +27,7 @@ public MSIResourceType ResourceType
/// <summary>
/// Get or Set the MSI extension port to retrieve access token from.
/// </summary>
[Obsolete("Port is used for MSI VM extension based login, login using MSI VM extension is deprecated infavour of IMDS based login")]
public int? Port
{
get; set;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
using Microsoft.Rest;
using Newtonsoft.Json;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Threading;
Expand All @@ -14,13 +16,24 @@
namespace Microsoft.Azure.Management.ResourceManager.Fluent.Authentication
{
/// <summary>
/// TokenProvider that can retrieve AD acess token from the local MSI port.
/// TokenProvider that can retrieve AD acess token from the local MSI port & IMDS service (for VM & VMSS) and
/// from environment (for AppService).
/// </summary>
public class MSITokenProvider : ITokenProvider, IBeta
{
private readonly IList<int> retrySlots = new List<int>(new int[] { 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597, 2584, 4181, 6765 });
private readonly int maxRetry;
private ConcurrentDictionary<string, MSIToken> cache = new ConcurrentDictionary<string, MSIToken>();
private static SemaphoreSlim semaphoreSlim = new SemaphoreSlim(1, 1);

private readonly string resource;
private readonly MSILoginInformation msiLoginInformation;

private const string imdsEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token";
private const string imdsMsiApiVersion = "2018-02-01";
private const string appServiceMsiApiVersion = "2017-09-01";


/// <summary>
/// Creates MSITokenProvider.
/// </summary>
Expand All @@ -45,7 +58,49 @@ public async Task<AuthenticationHeaderValue> GetAuthenticationHeaderAsync(Cancel

private async Task<AuthenticationHeaderValue> GetAuthenticationHeaderForVirtualMachineAsync(string resource, CancellationToken cancellationToken = default(CancellationToken))
{
int port = msiLoginInformation.Port == null ? 50342 : msiLoginInformation.Port.Value;
if (msiLoginInformation.Port != null)
{
// Token retrival from VM extension will be removed in the next v1.10 release as IMDS service
// replaces VM extension. MsiLoginInformation.Port is marked as deprecated and will be removed
//
var token = await GetTokenFromMSIExtensionAsync(msiLoginInformation.Port.Value,
resource,
cancellationToken);
return new AuthenticationHeaderValue(token.TokenType, token.AccessToken);
}
else
{
var token = await GetTokenFromIMDSEndpointAsync(resource, cancellationToken);
return new AuthenticationHeaderValue(token.TokenType, token.AccessToken);
}
}

private async Task<AuthenticationHeaderValue> GetAuthenticationHeaderForAppServiceAsync(string resource, CancellationToken cancellationToken = default(CancellationToken))
{
var endpoint = Environment.GetEnvironmentVariable("MSI_ENDPOINT") ?? throw new ArgumentNullException("MSI_ENDPOINT");
var secret = Environment.GetEnvironmentVariable("MSI_SECRET") ?? throw new ArgumentNullException("MSI_SECRET");
HttpRequestMessage msiRequest = new HttpRequestMessage(HttpMethod.Get, $"{endpoint}?resource={resource}&api-version={MSITokenProvider.appServiceMsiApiVersion}");
msiRequest.Headers.Add("Metadata", "true");
msiRequest.Headers.Add("Secret", secret);

var msiResponse = await (new HttpClient()).SendAsync(msiRequest, cancellationToken);
string content = await msiResponse.Content.ReadAsStringAsync();
dynamic loginInfo = JsonConvert.DeserializeObject(content);
string tokenType = loginInfo.token_type;
if (tokenType == null)
{
throw MSILoginException.TokenTypeNotFound(content);
}
string accessToken = loginInfo.access_token;
if (accessToken == null)
{
throw MSILoginException.AcessTokenNotFound(content);
}
return new AuthenticationHeaderValue(tokenType, accessToken);
}

private async Task<MSIToken> GetTokenFromMSIExtensionAsync(int port, string resource, CancellationToken cancellationToken)
{
HttpRequestMessage msiRequest = new HttpRequestMessage(HttpMethod.Post, $"http://localhost:{port}/oauth2/token");
msiRequest.Headers.Add("Metadata", "true");

Expand All @@ -65,47 +120,176 @@ public async Task<AuthenticationHeaderValue> GetAuthenticationHeaderAsync(Cancel
{
parameters.Add("msi_res_id", this.msiLoginInformation.UserAssignedIdentityResourceId);
}
else
{
throw new ArgumentException("MSI: UserAssignedIdentityObjectId, UserAssignedIdentityClientId or UserAssignedIdentityResourceId must be set");
}

msiRequest.Content = new FormUrlEncodedContent(parameters);

var msiResponse = await (new HttpClient()).SendAsync(msiRequest, cancellationToken);
string content = await msiResponse.Content.ReadAsStringAsync();
dynamic loginInfo = JsonConvert.DeserializeObject(content);
string tokenType = loginInfo.token_type;
if (tokenType == null)
if (loginInfo.access_token == null)
{
throw MSILoginException.TokenTypeNotFound(content);
throw MSILoginException.AcessTokenNotFound(content);
}
string accessToken = loginInfo.access_token;
if (accessToken == null)
if (loginInfo.token_type == null)
{
throw MSILoginException.AcessTokenNotFound(content);
throw MSILoginException.TokenTypeNotFound(content);
}
return new AuthenticationHeaderValue(tokenType, accessToken);
//
MSIToken msiToken = new MSIToken
{
AccessToken = loginInfo.access_token,
TokenType = loginInfo.token_type
};
return msiToken;
}

private async Task<AuthenticationHeaderValue> GetAuthenticationHeaderForAppServiceAsync(string resource, CancellationToken cancellationToken = default(CancellationToken))
private async Task<MSIToken> GetTokenFromIMDSEndpointAsync(string resource, CancellationToken cancellationToken)
{
var endpoint = Environment.GetEnvironmentVariable("MSI_ENDPOINT") ?? throw new ArgumentNullException("MSI_ENDPOINT");
var secret = Environment.GetEnvironmentVariable("MSI_SECRET") ?? throw new ArgumentNullException("MSI_SECRET");
HttpRequestMessage msiRequest = new HttpRequestMessage(HttpMethod.Get, $"{endpoint}?resource={resource}&api-version=2017-09-01");
msiRequest.Headers.Add("Metadata", "true");
msiRequest.Headers.Add("Secret", secret);
// First hit cache
//
if (cache.TryGetValue(resource, out MSIToken token) == true && !token.IsExpired)
{
return token;
}

var msiResponse = await (new HttpClient()).SendAsync(msiRequest, cancellationToken);
string content = await msiResponse.Content.ReadAsStringAsync();
dynamic loginInfo = JsonConvert.DeserializeObject(content);
string tokenType = loginInfo.token_type;
if (tokenType == null)
// if cache miss then retrieve from IMDS endpoint with retry
//
await semaphoreSlim.WaitAsync();
try
{
throw MSILoginException.TokenTypeNotFound(content);
// Try hit cache once again in case another thread already updated the cache while this thread was waiting
//
if (cache.TryGetValue(resource, out token) == true && !token.IsExpired)
{
return token;
}
else
{
token = await this.RetrieveTokenFromIMDSWithRetryAsync(resource, cancellationToken);
cache.AddOrUpdate(resource, token, (key, oldValue) => token);
return token;
}
}
string accessToken = loginInfo.access_token;
if (accessToken == null)
finally
{
throw MSILoginException.AcessTokenNotFound(content);
semaphoreSlim.Release();
}
}

private async Task<MSIToken> RetrieveTokenFromIMDSWithRetryAsync(string resource, CancellationToken cancellationToken)
{
var uriBuilder = new UriBuilder(MSITokenProvider.imdsEndpoint);
//
var query = new Dictionary<string, string>
{
["api-version"] = MSITokenProvider.imdsMsiApiVersion,
["resource"] = resource
};
if (this.msiLoginInformation.UserAssignedIdentityObjectId != null)
{
query["object_id"] = this.msiLoginInformation.UserAssignedIdentityObjectId;
}
else if (this.msiLoginInformation.UserAssignedIdentityClientId != null)
{
query["client_id"] = this.msiLoginInformation.UserAssignedIdentityClientId;
}
else if (this.msiLoginInformation.UserAssignedIdentityResourceId != null)
{
query["msi_res_id"] = this.msiLoginInformation.UserAssignedIdentityResourceId;
}
else
{
throw new ArgumentException("MSI: UserAssignedIdentityObjectId, UserAssignedIdentityClientId or UserAssignedIdentityResourceId must be set");
}
uriBuilder.Query = await new FormUrlEncodedContent(query).ReadAsStringAsync();
string url = uriBuilder.ToString();
//
int retry = 1;
while (retry <= maxRetry)
{
//
using (HttpRequestMessage msiRequest = new HttpRequestMessage(HttpMethod.Get, url))
{
msiRequest.Headers.Add("Metadata", "true");
using (HttpResponseMessage msiResponse = await (new HttpClient()).SendAsync(msiRequest, cancellationToken))
{
int statusCode = ((int)msiResponse.StatusCode);
if (ShouldRetry(statusCode))
{

int retryTimeout = retrySlots[new Random().Next(retry)];
await SdkContext.DelayProvider.DelayAsync(retryTimeout * 1000, cancellationToken);
retry++;
}
else if (statusCode != 200)
{
string content = await msiResponse.Content.ReadAsStringAsync();
throw new HttpRequestException($"Code: {statusCode} ReasonReasonPhrase: {msiResponse.ReasonPhrase} Body: {content}");
}
else
{
string content = await msiResponse.Content.ReadAsStringAsync();
dynamic loginInfo = JsonConvert.DeserializeObject(content);
if (loginInfo.access_token == null)
{
throw MSILoginException.AcessTokenNotFound(content);
}
if (loginInfo.token_type == null)
{
throw MSILoginException.TokenTypeNotFound(content);
}
//
MSIToken msiToken = new MSIToken
{
AccessToken = loginInfo.access_token,
ExpireOn = loginInfo.expires_on,
TokenType = loginInfo.token_type
};
return msiToken;
}
}
}
}
throw new MSIMaxRetryReachedException(maxRetry);
}

private static bool ShouldRetry(int statusCode)
{
return (statusCode == 429 || statusCode == 404 || (statusCode >= 500 && statusCode <= 599));
}

private class MSIToken
{
private static DateTime epoch = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc);

public string TokenType { get; set; }

public string AccessToken { get; set; }

public string ExpireOn { get; set; }

public bool IsExpired
{
get
{
if (this.ExpireOn == null)
{
return true;
}
else if (!Int32.TryParse(this.ExpireOn, out int iexpireOn))
{
return true;
}
else
{
return DateTime.UtcNow.AddMinutes(5).CompareTo(epoch.AddSeconds(iexpireOn)) > 0;
}
}
}
return new AuthenticationHeaderValue(tokenType, accessToken);
}
}

Expand All @@ -123,4 +307,11 @@ public MSITokenProvider Create(string resource)
return new MSITokenProvider(resource, msiLoginInformation);
}
}

public class MSIMaxRetryReachedException : Exception
{
public MSIMaxRetryReachedException(int maxRetry) : base($"MSI: Failed to acquire tokens after retrying %{ maxRetry} times")
{
}
}
}

0 comments on commit 6bc1c0a

Please sign in to comment.