Skip to content

Commit

Permalink
Add cloud instance name validation (#2804)
Browse files Browse the repository at this point in the history
* Added EnableAadSigningKeyValidation extension method to include cloud instance validation

* Added tests to validate cloud instance validation

* Added tests to check that Issuer valdiation is still used

* SecurityTokenInvalidCloudInstanceException extends SecurityTokenInvalidSigningKeyException and used Ordinal string comparison

* Removed duplicated tests

* Suppress the obsolete warning CS0618

* Rid of cloudInstanceName parameter for extension method

* Extracted conditions to improve readability

* call ValidateSigningKeyCloudInstanceName before custom delegates

* Reverted changes in JsonWebTokenHandler.ValidateSignatureTests.cs

* Renamed exception

* Added GetJsonWebKeyBySecurityKey that using a loop to find a key

* Renamed CloudInstanceName to CloudInstance for public members

* Reverted renaming of ValidateIssuerSigningKeyTests to avoid breaking changes

* Fix tests by create new instance of TokenValidationParameters in order to avoid a cycling.

* Assigned delegates properly to avoid infinite recursion

---------

Co-authored-by: Alex Holub <alexholub@microsoft.com>
  • Loading branch information
alexholub113 and Alex Holub authored Sep 20, 2024
1 parent a9380ab commit f0d09d4
Show file tree
Hide file tree
Showing 7 changed files with 525 additions and 79 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Runtime.Serialization;

namespace Microsoft.IdentityModel.Tokens
{
/// <summary>
/// This exception is thrown when the cloud instance of the signing key was not matched with the cloud instance from configuration.
/// </summary>
[Serializable]
public class SecurityTokenInvalidCloudInstanceException : SecurityTokenInvalidSigningKeyException
{
[NonSerialized]
const string _Prefix = "Microsoft.IdentityModel." + nameof(SecurityTokenInvalidCloudInstanceException) + ".";

[NonSerialized]
const string _SigningKeyCloudInstanceNameKey = _Prefix + nameof(SigningKeyCloudInstanceName);

[NonSerialized]
const string _ConfigurationCloudInstanceNameKey = _Prefix + nameof(ConfigurationCloudInstanceName);

/// <summary>
/// Gets or sets the cloud instance name of the signing key that created the validation exception.
/// </summary>
public string SigningKeyCloudInstanceName { get; set; }

/// <summary>
/// Gets or sets the cloud instance name from the configuration that did not match the cloud instance name of the signing key.
/// </summary>
public string ConfigurationCloudInstanceName { get; set; }

/// <summary>
/// Initializes a new instance of the <see cref="SecurityTokenInvalidCloudInstanceException"/> class.
/// </summary>
public SecurityTokenInvalidCloudInstanceException()
: base()
{
}

/// <summary>
/// Initializes a new instance of the <see cref="SecurityTokenInvalidCloudInstanceException"/> class.
/// </summary>
/// <param name="message">Addtional information to be included in the exception and displayed to user.</param>
public SecurityTokenInvalidCloudInstanceException(string message)
: base(message)
{
}

/// <summary>
/// Initializes a new instance of the <see cref="SecurityTokenInvalidCloudInstanceException"/> class.
/// </summary>
/// <param name="message">Addtional information to be included in the exception and displayed to user.</param>
/// <param name="innerException">A <see cref="Exception"/> that represents the root cause of the exception.</param>
public SecurityTokenInvalidCloudInstanceException(string message, Exception innerException)
: base(message, innerException)
{
}

/// <summary>
/// Initializes a new instance of the <see cref="SecurityTokenInvalidCloudInstanceException"/> class.
/// </summary>
/// <param name="info">the <see cref="SerializationInfo"/> that holds the serialized object data.</param>
/// <param name="context">The contextual information about the source or destination.</param>
#if NET8_0_OR_GREATER
[Obsolete("Formatter-based serialization is obsolete", DiagnosticId = "SYSLIB0051")]
#endif
protected SecurityTokenInvalidCloudInstanceException(SerializationInfo info, StreamingContext context)
: base(info, context)
{
SerializationInfoEnumerator enumerator = info.GetEnumerator();
while (enumerator.MoveNext())
{
switch (enumerator.Name)
{
case _SigningKeyCloudInstanceNameKey:
SigningKeyCloudInstanceName = info.GetString(_SigningKeyCloudInstanceNameKey);
break;

case _ConfigurationCloudInstanceNameKey:
ConfigurationCloudInstanceName = info.GetString(_ConfigurationCloudInstanceNameKey);
break;

default:
// Ignore other fields.
break;
}
}
}

/// <inheritdoc/>
#if NET8_0_OR_GREATER
[Obsolete("Formatter-based serialization is obsolete", DiagnosticId = "SYSLIB0051")]
#endif
public override void GetObjectData(SerializationInfo info, StreamingContext context)
{
base.GetObjectData(info, context);

if (!string.IsNullOrEmpty(SigningKeyCloudInstanceName))
info.AddValue(_SigningKeyCloudInstanceNameKey, SigningKeyCloudInstanceName);

if (!string.IsNullOrEmpty(ConfigurationCloudInstanceName))
info.AddValue(_ConfigurationCloudInstanceNameKey, ConfigurationCloudInstanceName);
}
}
}
2 changes: 1 addition & 1 deletion src/Microsoft.IdentityModel.Tokens/SecurityKey.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
// Licensed under the MIT License.

using System;
using Microsoft.IdentityModel.Logging;
using System.Text.Json.Serialization;
using Microsoft.IdentityModel.Logging;

namespace Microsoft.IdentityModel.Tokens
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System;
using System.Collections.Generic;
using System.IdentityModel.Tokens.Jwt;
using System.Linq;
using Microsoft.IdentityModel.JsonWebTokens;
using Microsoft.IdentityModel.Logging;
using Microsoft.IdentityModel.Protocols.OpenIdConnect;
Expand All @@ -17,6 +16,35 @@ namespace Microsoft.IdentityModel.Validators
/// </summary>
public static class AadTokenValidationParametersExtension
{
private const string CloudInstanceNameKey = "cloud_instance_name";

/// <summary>
/// Enables validation of the cloud instance of the Microsoft Entra ID token signing keys.
/// </summary>
/// <param name="tokenValidationParameters">The <see cref="TokenValidationParameters"/> that are used to validate the token.</param>
public static void EnableEntraIdSigningKeyCloudInstanceValidation(this TokenValidationParameters tokenValidationParameters)
{
if (tokenValidationParameters == null)
throw LogHelper.LogArgumentNullException(nameof(tokenValidationParameters));

IssuerSigningKeyValidatorUsingConfiguration userProvidedIssuerSigningKeyValidatorUsingConfiguration = tokenValidationParameters.IssuerSigningKeyValidatorUsingConfiguration;
IssuerSigningKeyValidator userProvidedIssuerSigningKeyValidator = tokenValidationParameters.IssuerSigningKeyValidator;

tokenValidationParameters.IssuerSigningKeyValidatorUsingConfiguration = (securityKey, securityToken, tvp, config) =>
{
ValidateSigningKeyCloudInstance(securityKey, config);
// preserve and run provided logic
if (userProvidedIssuerSigningKeyValidatorUsingConfiguration != null)
return userProvidedIssuerSigningKeyValidatorUsingConfiguration(securityKey, securityToken, tvp, config);
if (userProvidedIssuerSigningKeyValidator != null)
return userProvidedIssuerSigningKeyValidator(securityKey, securityToken, tvp);
return true;
};
}

/// <summary>
/// Enables the validation of the issuer of the signing keys used by the Microsoft identity platform (AAD) against the issuer of the token.
/// </summary>
Expand All @@ -26,8 +54,8 @@ public static void EnableAadSigningKeyIssuerValidation(this TokenValidationParam
if (tokenValidationParameters == null)
throw LogHelper.LogArgumentNullException(nameof(tokenValidationParameters));

var userProvidedIssuerSigningKeyValidatorUsingConfiguration = tokenValidationParameters.IssuerSigningKeyValidatorUsingConfiguration;
var userProvidedIssuerSigningKeyValidator = tokenValidationParameters.IssuerSigningKeyValidator;
IssuerSigningKeyValidatorUsingConfiguration userProvidedIssuerSigningKeyValidatorUsingConfiguration = tokenValidationParameters.IssuerSigningKeyValidatorUsingConfiguration;
IssuerSigningKeyValidator userProvidedIssuerSigningKeyValidator = tokenValidationParameters.IssuerSigningKeyValidator;

tokenValidationParameters.IssuerSigningKeyValidatorUsingConfiguration = (securityKey, securityToken, tvp, config) =>
{
Expand All @@ -49,8 +77,8 @@ public static void EnableAadSigningKeyIssuerValidation(this TokenValidationParam
/// </summary>
/// <param name="securityKey">The <see cref="SecurityKey"/> that signed the <see cref="SecurityToken"/>.</param>
/// <param name="securityToken">The <see cref="SecurityToken"/> being validated, could be a JwtSecurityToken or JsonWebToken.</param>
/// <param name="configuration">The <see cref="OpenIdConnectConfiguration"/> provided.</param>
/// <returns><c>true</c> if the issuer signing key is valid; otherwise, <c>false</c>.</returns>
/// <param name="configuration">The <see cref="BaseConfiguration"/> provided.</param>
/// <returns><c>true</c> if the issuer of the signing key is valid; otherwise, <c>false</c>.</returns>
internal static bool ValidateIssuerSigningKey(SecurityKey securityKey, SecurityToken securityToken, BaseConfiguration configuration)
{
if (securityKey == null)
Expand All @@ -59,18 +87,17 @@ internal static bool ValidateIssuerSigningKey(SecurityKey securityKey, SecurityT
if (securityToken == null)
throw LogHelper.LogArgumentNullException(nameof(securityToken));

var openIdConnectConfiguration = configuration as OpenIdConnectConfiguration;
if (openIdConnectConfiguration == null)
if (configuration is not OpenIdConnectConfiguration openIdConnectConfiguration)
return true;

var matchedKeyFromConfig = openIdConnectConfiguration.JsonWebKeySet?.Keys.FirstOrDefault(x => x != null && x.Kid == securityKey.KeyId);
JsonWebKey matchedKeyFromConfig = GetJsonWebKeyBySecurityKey(openIdConnectConfiguration, securityKey);
if (matchedKeyFromConfig != null && matchedKeyFromConfig.AdditionalData.TryGetValue(OpenIdProviderMetadataNames.Issuer, out object value))
{
var signingKeyIssuer = value as string;
string signingKeyIssuer = value as string;
if (string.IsNullOrWhiteSpace(signingKeyIssuer))
return true;

var tenantIdFromToken = GetTid(securityToken);
string tenantIdFromToken = GetTid(securityToken);
if (string.IsNullOrEmpty(tenantIdFromToken))
{
if (AppContextSwitches.DontFailOnMissingTid)
Expand All @@ -79,22 +106,22 @@ internal static bool ValidateIssuerSigningKey(SecurityKey securityKey, SecurityT
throw LogHelper.LogExceptionMessage(new SecurityTokenInvalidIssuerException(LogMessages.IDX40009));
}

var tokenIssuer = securityToken.Issuer;
string tokenIssuer = securityToken.Issuer;

#if NET6_0_OR_GREATER
if (!string.IsNullOrEmpty(tokenIssuer) && !tokenIssuer.Contains(tenantIdFromToken, StringComparison.Ordinal))
throw LogHelper.LogExceptionMessage(new SecurityTokenInvalidIssuerException(LogHelper.FormatInvariant(LogMessages.IDX40004, LogHelper.MarkAsNonPII(tokenIssuer), LogHelper.MarkAsNonPII(tenantIdFromToken))));

// creating an effectiveSigningKeyIssuer is required as signingKeyIssuer might contain {tenantid}
var effectiveSigningKeyIssuer = signingKeyIssuer.Replace(AadIssuerValidator.TenantIdTemplate, tenantIdFromToken, StringComparison.Ordinal);
var v2TokenIssuer = openIdConnectConfiguration.Issuer?.Replace(AadIssuerValidator.TenantIdTemplate, tenantIdFromToken, StringComparison.Ordinal);
string effectiveSigningKeyIssuer = signingKeyIssuer.Replace(AadIssuerValidator.TenantIdTemplate, tenantIdFromToken, StringComparison.Ordinal);
string v2TokenIssuer = openIdConnectConfiguration.Issuer?.Replace(AadIssuerValidator.TenantIdTemplate, tenantIdFromToken, StringComparison.Ordinal);
#else
if (!string.IsNullOrEmpty(tokenIssuer) && !tokenIssuer.Contains(tenantIdFromToken))
throw LogHelper.LogExceptionMessage(new SecurityTokenInvalidIssuerException(LogHelper.FormatInvariant(LogMessages.IDX40004, LogHelper.MarkAsNonPII(tokenIssuer), LogHelper.MarkAsNonPII(tenantIdFromToken))));

// creating an effectiveSigningKeyIssuer is required as signingKeyIssuer might contain {tenantid}
var effectiveSigningKeyIssuer = signingKeyIssuer.Replace(AadIssuerValidator.TenantIdTemplate, tenantIdFromToken);
var v2TokenIssuer = openIdConnectConfiguration.Issuer?.Replace(AadIssuerValidator.TenantIdTemplate, tenantIdFromToken);
string effectiveSigningKeyIssuer = signingKeyIssuer.Replace(AadIssuerValidator.TenantIdTemplate, tenantIdFromToken);
string v2TokenIssuer = openIdConnectConfiguration.Issuer?.Replace(AadIssuerValidator.TenantIdTemplate, tenantIdFromToken);
#endif

// comparing effectiveSigningKeyIssuer with v2TokenIssuer is required as well because of the following scenario:
Expand All @@ -109,6 +136,58 @@ internal static bool ValidateIssuerSigningKey(SecurityKey securityKey, SecurityT
return true;
}

/// <summary>
/// Validates the cloud instance of the signing key.
/// </summary>
/// <param name="securityKey">The <see cref="SecurityKey"/> that signed the <see cref="SecurityToken"/>.</param>
/// <param name="configuration">The <see cref="BaseConfiguration"/> provided.</param>
internal static void ValidateSigningKeyCloudInstance(SecurityKey securityKey, BaseConfiguration configuration)
{
if (securityKey == null)
return;

if (configuration is not OpenIdConnectConfiguration openIdConnectConfiguration)
return;

JsonWebKey matchedKeyFromConfig = GetJsonWebKeyBySecurityKey(openIdConnectConfiguration, securityKey);
if (matchedKeyFromConfig != null && matchedKeyFromConfig.AdditionalData.TryGetValue(CloudInstanceNameKey, out object value))
{
string signingKeyCloudInstanceName = value as string;
if (string.IsNullOrWhiteSpace(signingKeyCloudInstanceName))
return;

if (openIdConnectConfiguration.AdditionalData.TryGetValue(CloudInstanceNameKey, out object configurationCloudInstanceNameObjectValue))
{
string configurationCloudInstanceName = configurationCloudInstanceNameObjectValue as string;
if (string.IsNullOrWhiteSpace(configurationCloudInstanceName))
return;

if (!string.Equals(signingKeyCloudInstanceName, configurationCloudInstanceName, StringComparison.Ordinal))
throw LogHelper.LogExceptionMessage(
new SecurityTokenInvalidCloudInstanceException(LogHelper.FormatInvariant(LogMessages.IDX40012, LogHelper.MarkAsNonPII(signingKeyCloudInstanceName), LogHelper.MarkAsNonPII(configurationCloudInstanceName)))
{
ConfigurationCloudInstanceName = configurationCloudInstanceName,
SigningKeyCloudInstanceName = signingKeyCloudInstanceName,
SigningKey = securityKey,
});
}
}
}

private static JsonWebKey GetJsonWebKeyBySecurityKey(OpenIdConnectConfiguration configuration, SecurityKey securityKey)
{
if (configuration.JsonWebKeySet == null)
return null;

foreach (JsonWebKey key in configuration.JsonWebKeySet.Keys)
{
if (key.Kid == securityKey.KeyId)
return key;
}

return null;
}

private static string GetTid(SecurityToken securityToken)
{
switch (securityToken)
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.IdentityModel.Validators/LogMessages.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@ internal static class LogMessages
public const string IDX40009 = "IDX40009: Either the 'tid' claim was not found or it didn't have a value.";
public const string IDX40010 = "IDX40010: The SecurityToken must be a 'JsonWebToken' or 'JwtSecurityToken'";
public const string IDX40011 = "IDX40011: The SecurityToken has multiple instances of the '{0}' claim.";
public const string IDX40012 = "IDX40012: The cloud instance of the signing key: '{0}', does not match cloud instance from configuration: '{1}'.";
}
}
5 changes: 5 additions & 0 deletions test/Microsoft.IdentityModel.TestUtils/Default.cs
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,11 @@ public static string Issuer
get => "http://Default.Issuer.com";
}

public static string CloudInstanceName
{
get => "microsoftonline.com";
}

public static IEnumerable<string> Issuers
{
get => new List<string> {
Expand Down
5 changes: 5 additions & 0 deletions test/Microsoft.IdentityModel.TestUtils/ExpectedException.cs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,11 @@ public static ExpectedException SecurityTokenInvalidIssuerException(string subst
return new ExpectedException(typeof(SecurityTokenInvalidIssuerException), substringExpected, innerTypeExpected, propertiesExpected: propertiesExpected);
}

public static ExpectedException SecurityTokenInvalidCloudInstanceException(string substringExpected = null, Type innerTypeExpected = null, Dictionary<string, object> propertiesExpected = null)
{
return new ExpectedException(typeof(SecurityTokenInvalidCloudInstanceException), substringExpected, innerTypeExpected, propertiesExpected: propertiesExpected);
}

public static ExpectedException SecurityTokenKeyWrapException(string substringExpected = null, Type innerTypeExpected = null, Dictionary<string, object> propertiesExpected = null)
{
return new ExpectedException(typeof(SecurityTokenKeyWrapException), substringExpected, innerTypeExpected, propertiesExpected: propertiesExpected);
Expand Down
Loading

0 comments on commit f0d09d4

Please sign in to comment.