Skip to content

Commit

Permalink
ClientCertificateAzureServiceTokenProvider dispose of certificate obj… (
Browse files Browse the repository at this point in the history
Azure#17266)

…ects

Properly dispose the X509Certificate2 objects after use to prevent temporary private key files to filling up the disk.
Unless properly dispoed, they fill up the space between garbage collections and stay permanently in case of process kills.
  • Loading branch information
danicole authored and annelo-msft committed Feb 17, 2021
1 parent 8feae5e commit 998d764
Showing 1 changed file with 76 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
Expand Down Expand Up @@ -128,76 +129,95 @@ public override async Task<AppAuthenticationResult> GetAuthResultAsync(string re
}

List<X509Certificate2> certs = null;
switch (_certificateIdentifierType)
{
case CertificateIdentifierType.KeyVaultCertificateSecretIdentifier:
// Get certificate for the given Key Vault secret identifier
try
{
var keyVaultCert = await _keyVaultClient.GetCertificateAsync(_certificateIdentifier, cancellationToken).ConfigureAwait(false);
certs = new List<X509Certificate2>() { keyVaultCert };
Dictionary<string, string> exceptionDictionary = new Dictionary<string, string>();

// If authority is still not specified, create it using azureAdInstance and tenantId. Tenant ID comes from Key Vault access token.
if (string.IsNullOrWhiteSpace(authority))
try
{
switch (_certificateIdentifierType)
{
case CertificateIdentifierType.KeyVaultCertificateSecretIdentifier:
// Get certificate for the given Key Vault secret identifier
try
{
_tenantId = _keyVaultClient.PrincipalUsed.TenantId;
authority = $"{_azureAdInstance}{_tenantId}";
var keyVaultCert = await _keyVaultClient
.GetCertificateAsync(_certificateIdentifier, cancellationToken).ConfigureAwait(false);
certs = new List<X509Certificate2>() { keyVaultCert };

// If authority is still not specified, create it using azureAdInstance and tenantId. Tenant ID comes from Key Vault access token.
if (string.IsNullOrWhiteSpace(authority))
{
_tenantId = _keyVaultClient.PrincipalUsed.TenantId;
authority = $"{_azureAdInstance}{_tenantId}";
}
}
}
catch (Exception exp)
{
throw new AzureServiceTokenProviderException(ConnectionString, resource, authority,
$"{AzureServiceTokenProviderException.KeyVaultCertificateRetrievalError} {exp.Message}");
}
break;
case CertificateIdentifierType.SubjectName:
case CertificateIdentifierType.Thumbprint:
// Get certificates for the given thumbprint or subject name.
bool isThumbprint = _certificateIdentifierType == CertificateIdentifierType.Thumbprint;
certs = CertificateHelper.GetCertificates(_certificateIdentifier, isThumbprint,
_storeLocation);

if (certs == null || certs.Count == 0)
{
throw new AzureServiceTokenProviderException(ConnectionString, resource, authority,
AzureServiceTokenProviderException.LocalCertificateNotFound);
}
break;
}

// If multiple certs were found, use in order of most recently created.
// This helps if old cert is rolled over, but not removed.
certs = certs.OrderByDescending(p => p.NotBefore).ToList();
catch (Exception exp)
{
throw new AzureServiceTokenProviderException(ConnectionString, resource, authority,
$"{AzureServiceTokenProviderException.KeyVaultCertificateRetrievalError} {exp.Message}");
}
break;
case CertificateIdentifierType.SubjectName:
case CertificateIdentifierType.Thumbprint:
// Get certificates for the given thumbprint or subject name.
bool isThumbprint = _certificateIdentifierType == CertificateIdentifierType.Thumbprint;
certs = CertificateHelper.GetCertificates(_certificateIdentifier, isThumbprint,
_storeLocation);

if (certs == null || certs.Count == 0)
{
throw new AzureServiceTokenProviderException(ConnectionString, resource, authority,
AzureServiceTokenProviderException.LocalCertificateNotFound);
}
break;
}

// To hold reason why token could not be acquired per cert tried.
Dictionary<string, string> exceptionDictionary = new Dictionary<string, string>();
Debug.Assert(certs != null, "Probably wrong certificateIdentifierType was used to instantiate this class!");

foreach (X509Certificate2 cert in certs)
{
if (!string.IsNullOrEmpty(cert.Thumbprint))
// If multiple certs were found, use in order of most recently created.
// This helps if old cert is rolled over, but not removed.
// To hold reason why token could not be acquired per cert tried.
foreach (X509Certificate2 cert in certs.OrderByDescending(p => p.NotBefore))
{
try
if (!string.IsNullOrEmpty(cert.Thumbprint))
{
ClientAssertionCertificate certCred = new ClientAssertionCertificate(_clientId, cert);
try
{
ClientAssertionCertificate certCred = new ClientAssertionCertificate(_clientId, cert);

var authResult =
await _authenticationContext.AcquireTokenAsync(authority, resource, certCred).ConfigureAwait(false);
var authResult =
await _authenticationContext.AcquireTokenAsync(authority, resource, certCred)
.ConfigureAwait(false);

var accessToken = authResult?.AccessToken;
var accessToken = authResult?.AccessToken;

if (accessToken != null)
{
PrincipalUsed.CertificateThumbprint = cert.Thumbprint;
PrincipalUsed.IsAuthenticated = true;
PrincipalUsed.TenantId = AccessToken.Parse(accessToken).TenantId;
if (accessToken != null)
{
PrincipalUsed.CertificateThumbprint = cert.Thumbprint;
PrincipalUsed.IsAuthenticated = true;
PrincipalUsed.TenantId = AccessToken.Parse(accessToken).TenantId;

return authResult;
return authResult;
}
}
catch (Exception exp)
{
// If token cannot be acquired using a cert, try the next one
exceptionDictionary[cert.Thumbprint] = exp.Message;
}
}
catch (Exception exp)
}
}
finally
{
if (certs != null)
{
foreach (var cert in certs)
{
// If token cannot be acquired using a cert, try the next one
exceptionDictionary[cert.Thumbprint] = exp.Message;
#if net452
cert.Reset();
#else
cert.Dispose();
#endif
}
}
}
Expand Down

0 comments on commit 998d764

Please sign in to comment.