diff --git a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/TokenProviders/ClientCertificateAccessTokenProvider.cs b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/TokenProviders/ClientCertificateAccessTokenProvider.cs index 2e67a90ac9b9d..01756ec677928 100644 --- a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/TokenProviders/ClientCertificateAccessTokenProvider.cs +++ b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/TokenProviders/ClientCertificateAccessTokenProvider.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Security.Cryptography.X509Certificates; using System.Threading; @@ -128,76 +129,95 @@ public override async Task GetAuthResultAsync(string re } List certs = null; - switch (_certificateIdentifierType) - { - case CertificateIdentifierType.KeyVaultCertificateSecretIdentifier: - // Get certificate for the given Key Vault secret identifier - try - { - var keyVaultCert = await _keyVaultClient.GetCertificateAsync(_certificateIdentifier, cancellationToken).ConfigureAwait(false); - certs = new List() { keyVaultCert }; + Dictionary exceptionDictionary = new Dictionary(); - // If authority is still not specified, create it using azureAdInstance and tenantId. Tenant ID comes from Key Vault access token. - if (string.IsNullOrWhiteSpace(authority)) + try + { + switch (_certificateIdentifierType) + { + case CertificateIdentifierType.KeyVaultCertificateSecretIdentifier: + // Get certificate for the given Key Vault secret identifier + try { - _tenantId = _keyVaultClient.PrincipalUsed.TenantId; - authority = $"{_azureAdInstance}{_tenantId}"; + var keyVaultCert = await _keyVaultClient + .GetCertificateAsync(_certificateIdentifier, cancellationToken).ConfigureAwait(false); + certs = new List() { keyVaultCert }; + + // If authority is still not specified, create it using azureAdInstance and tenantId. Tenant ID comes from Key Vault access token. + if (string.IsNullOrWhiteSpace(authority)) + { + _tenantId = _keyVaultClient.PrincipalUsed.TenantId; + authority = $"{_azureAdInstance}{_tenantId}"; + } } - } - catch (Exception exp) - { - throw new AzureServiceTokenProviderException(ConnectionString, resource, authority, - $"{AzureServiceTokenProviderException.KeyVaultCertificateRetrievalError} {exp.Message}"); - } - break; - case CertificateIdentifierType.SubjectName: - case CertificateIdentifierType.Thumbprint: - // Get certificates for the given thumbprint or subject name. - bool isThumbprint = _certificateIdentifierType == CertificateIdentifierType.Thumbprint; - certs = CertificateHelper.GetCertificates(_certificateIdentifier, isThumbprint, - _storeLocation); - - if (certs == null || certs.Count == 0) - { - throw new AzureServiceTokenProviderException(ConnectionString, resource, authority, - AzureServiceTokenProviderException.LocalCertificateNotFound); - } - break; - } - - // If multiple certs were found, use in order of most recently created. - // This helps if old cert is rolled over, but not removed. - certs = certs.OrderByDescending(p => p.NotBefore).ToList(); + catch (Exception exp) + { + throw new AzureServiceTokenProviderException(ConnectionString, resource, authority, + $"{AzureServiceTokenProviderException.KeyVaultCertificateRetrievalError} {exp.Message}"); + } + break; + case CertificateIdentifierType.SubjectName: + case CertificateIdentifierType.Thumbprint: + // Get certificates for the given thumbprint or subject name. + bool isThumbprint = _certificateIdentifierType == CertificateIdentifierType.Thumbprint; + certs = CertificateHelper.GetCertificates(_certificateIdentifier, isThumbprint, + _storeLocation); + + if (certs == null || certs.Count == 0) + { + throw new AzureServiceTokenProviderException(ConnectionString, resource, authority, + AzureServiceTokenProviderException.LocalCertificateNotFound); + } + break; + } - // To hold reason why token could not be acquired per cert tried. - Dictionary exceptionDictionary = new Dictionary(); + Debug.Assert(certs != null, "Probably wrong certificateIdentifierType was used to instantiate this class!"); - foreach (X509Certificate2 cert in certs) - { - if (!string.IsNullOrEmpty(cert.Thumbprint)) + // If multiple certs were found, use in order of most recently created. + // This helps if old cert is rolled over, but not removed. + // To hold reason why token could not be acquired per cert tried. + foreach (X509Certificate2 cert in certs.OrderByDescending(p => p.NotBefore)) { - try + if (!string.IsNullOrEmpty(cert.Thumbprint)) { - ClientAssertionCertificate certCred = new ClientAssertionCertificate(_clientId, cert); + try + { + ClientAssertionCertificate certCred = new ClientAssertionCertificate(_clientId, cert); - var authResult = - await _authenticationContext.AcquireTokenAsync(authority, resource, certCred).ConfigureAwait(false); + var authResult = + await _authenticationContext.AcquireTokenAsync(authority, resource, certCred) + .ConfigureAwait(false); - var accessToken = authResult?.AccessToken; + var accessToken = authResult?.AccessToken; - if (accessToken != null) - { - PrincipalUsed.CertificateThumbprint = cert.Thumbprint; - PrincipalUsed.IsAuthenticated = true; - PrincipalUsed.TenantId = AccessToken.Parse(accessToken).TenantId; + if (accessToken != null) + { + PrincipalUsed.CertificateThumbprint = cert.Thumbprint; + PrincipalUsed.IsAuthenticated = true; + PrincipalUsed.TenantId = AccessToken.Parse(accessToken).TenantId; - return authResult; + return authResult; + } + } + catch (Exception exp) + { + // If token cannot be acquired using a cert, try the next one + exceptionDictionary[cert.Thumbprint] = exp.Message; } } - catch (Exception exp) + } + } + finally + { + if (certs != null) + { + foreach (var cert in certs) { - // If token cannot be acquired using a cert, try the next one - exceptionDictionary[cert.Thumbprint] = exp.Message; +#if net452 + cert.Reset(); +#else + cert.Dispose(); +#endif } } }