From 998d7647df3b415380e0e89788d0b4245b030f18 Mon Sep 17 00:00:00 2001 From: Dan Nicolescu Date: Thu, 3 Dec 2020 13:17:30 -0800 Subject: [PATCH] =?UTF-8?q?ClientCertificateAzureServiceTokenProvider=20di?= =?UTF-8?q?spose=20of=20certificate=20obj=E2=80=A6=20(#17266)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ects Properly dispose the X509Certificate2 objects after use to prevent temporary private key files to filling up the disk. Unless properly dispoed, they fill up the space between garbage collections and stay permanently in case of process kills. --- .../ClientCertificateAccessTokenProvider.cs | 132 ++++++++++-------- 1 file changed, 76 insertions(+), 56 deletions(-) diff --git a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/TokenProviders/ClientCertificateAccessTokenProvider.cs b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/TokenProviders/ClientCertificateAccessTokenProvider.cs index 2e67a90ac9b9..01756ec67792 100644 --- a/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/TokenProviders/ClientCertificateAccessTokenProvider.cs +++ b/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/TokenProviders/ClientCertificateAccessTokenProvider.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Security.Cryptography.X509Certificates; using System.Threading; @@ -128,76 +129,95 @@ public override async Task GetAuthResultAsync(string re } List certs = null; - switch (_certificateIdentifierType) - { - case CertificateIdentifierType.KeyVaultCertificateSecretIdentifier: - // Get certificate for the given Key Vault secret identifier - try - { - var keyVaultCert = await _keyVaultClient.GetCertificateAsync(_certificateIdentifier, cancellationToken).ConfigureAwait(false); - certs = new List() { keyVaultCert }; + Dictionary exceptionDictionary = new Dictionary(); - // If authority is still not specified, create it using azureAdInstance and tenantId. Tenant ID comes from Key Vault access token. - if (string.IsNullOrWhiteSpace(authority)) + try + { + switch (_certificateIdentifierType) + { + case CertificateIdentifierType.KeyVaultCertificateSecretIdentifier: + // Get certificate for the given Key Vault secret identifier + try { - _tenantId = _keyVaultClient.PrincipalUsed.TenantId; - authority = $"{_azureAdInstance}{_tenantId}"; + var keyVaultCert = await _keyVaultClient + .GetCertificateAsync(_certificateIdentifier, cancellationToken).ConfigureAwait(false); + certs = new List() { keyVaultCert }; + + // If authority is still not specified, create it using azureAdInstance and tenantId. Tenant ID comes from Key Vault access token. + if (string.IsNullOrWhiteSpace(authority)) + { + _tenantId = _keyVaultClient.PrincipalUsed.TenantId; + authority = $"{_azureAdInstance}{_tenantId}"; + } } - } - catch (Exception exp) - { - throw new AzureServiceTokenProviderException(ConnectionString, resource, authority, - $"{AzureServiceTokenProviderException.KeyVaultCertificateRetrievalError} {exp.Message}"); - } - break; - case CertificateIdentifierType.SubjectName: - case CertificateIdentifierType.Thumbprint: - // Get certificates for the given thumbprint or subject name. - bool isThumbprint = _certificateIdentifierType == CertificateIdentifierType.Thumbprint; - certs = CertificateHelper.GetCertificates(_certificateIdentifier, isThumbprint, - _storeLocation); - - if (certs == null || certs.Count == 0) - { - throw new AzureServiceTokenProviderException(ConnectionString, resource, authority, - AzureServiceTokenProviderException.LocalCertificateNotFound); - } - break; - } - - // If multiple certs were found, use in order of most recently created. - // This helps if old cert is rolled over, but not removed. - certs = certs.OrderByDescending(p => p.NotBefore).ToList(); + catch (Exception exp) + { + throw new AzureServiceTokenProviderException(ConnectionString, resource, authority, + $"{AzureServiceTokenProviderException.KeyVaultCertificateRetrievalError} {exp.Message}"); + } + break; + case CertificateIdentifierType.SubjectName: + case CertificateIdentifierType.Thumbprint: + // Get certificates for the given thumbprint or subject name. + bool isThumbprint = _certificateIdentifierType == CertificateIdentifierType.Thumbprint; + certs = CertificateHelper.GetCertificates(_certificateIdentifier, isThumbprint, + _storeLocation); + + if (certs == null || certs.Count == 0) + { + throw new AzureServiceTokenProviderException(ConnectionString, resource, authority, + AzureServiceTokenProviderException.LocalCertificateNotFound); + } + break; + } - // To hold reason why token could not be acquired per cert tried. - Dictionary exceptionDictionary = new Dictionary(); + Debug.Assert(certs != null, "Probably wrong certificateIdentifierType was used to instantiate this class!"); - foreach (X509Certificate2 cert in certs) - { - if (!string.IsNullOrEmpty(cert.Thumbprint)) + // If multiple certs were found, use in order of most recently created. + // This helps if old cert is rolled over, but not removed. + // To hold reason why token could not be acquired per cert tried. + foreach (X509Certificate2 cert in certs.OrderByDescending(p => p.NotBefore)) { - try + if (!string.IsNullOrEmpty(cert.Thumbprint)) { - ClientAssertionCertificate certCred = new ClientAssertionCertificate(_clientId, cert); + try + { + ClientAssertionCertificate certCred = new ClientAssertionCertificate(_clientId, cert); - var authResult = - await _authenticationContext.AcquireTokenAsync(authority, resource, certCred).ConfigureAwait(false); + var authResult = + await _authenticationContext.AcquireTokenAsync(authority, resource, certCred) + .ConfigureAwait(false); - var accessToken = authResult?.AccessToken; + var accessToken = authResult?.AccessToken; - if (accessToken != null) - { - PrincipalUsed.CertificateThumbprint = cert.Thumbprint; - PrincipalUsed.IsAuthenticated = true; - PrincipalUsed.TenantId = AccessToken.Parse(accessToken).TenantId; + if (accessToken != null) + { + PrincipalUsed.CertificateThumbprint = cert.Thumbprint; + PrincipalUsed.IsAuthenticated = true; + PrincipalUsed.TenantId = AccessToken.Parse(accessToken).TenantId; - return authResult; + return authResult; + } + } + catch (Exception exp) + { + // If token cannot be acquired using a cert, try the next one + exceptionDictionary[cert.Thumbprint] = exp.Message; } } - catch (Exception exp) + } + } + finally + { + if (certs != null) + { + foreach (var cert in certs) { - // If token cannot be acquired using a cert, try the next one - exceptionDictionary[cert.Thumbprint] = exp.Message; +#if net452 + cert.Reset(); +#else + cert.Dispose(); +#endif } } }