Skip to content

Fix encryption key cache design for AKV provider #3464

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
using Azure.Security.KeyVault.Keys.Cryptography;
using System;
using System.Collections.Concurrent;
using System.Threading.Tasks;
using System.Threading;
using static Azure.Security.KeyVault.Keys.Cryptography.SignatureAlgorithm;

namespace Microsoft.Data.SqlClient.AlwaysEncrypted.AzureKeyVaultProvider
Expand All @@ -25,16 +25,14 @@ internal class AzureSqlKeyCryptographer
private readonly ConcurrentDictionary<Uri, KeyClient> _keyClientDictionary = new();

/// <summary>
/// Holds references to the fetch key tasks and maps them to their corresponding Azure Key Vault Key Identifier (URI).
/// These tasks will be used for returning the key in the event that the fetch task has not finished depositing the
/// key into the key dictionary.
/// Holds references to the Azure Key Vault keys and maps them to their corresponding Azure Key Vault Key Identifier (URI).
/// </summary>
private readonly ConcurrentDictionary<string, Task<Azure.Response<KeyVaultKey>>> _keyFetchTaskDictionary = new();
private readonly ConcurrentDictionary<string, KeyVaultKey> _keyDictionary = new();

/// <summary>
/// Holds references to the Azure Key Vault keys and maps them to their corresponding Azure Key Vault Key Identifier (URI).
/// SemaphoreSlim to ensure thread safety when accessing the key dictionary or making network calls to Azure Key Vault to fetch keys.
/// </summary>
private readonly ConcurrentDictionary<string, KeyVaultKey> _keyDictionary = new();
private SemaphoreSlim _keyDictionarySemaphore = new(1, 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SemaphoreSlim is IDisposable. Do we need to dispose of it?


/// <summary>
/// Holds references to the Azure Key Vault CryptographyClient objects and maps them to their corresponding Azure Key Vault Key Identifier (URI).
Expand All @@ -52,18 +50,34 @@ internal AzureSqlKeyCryptographer(TokenCredential tokenCredential)

/// <summary>
/// Adds the key, specified by the Key Identifier URI, to the cache.
/// Validates the key type and fetches the key from Azure Key Vault if it is not already cached.
/// </summary>
/// <param name="keyIdentifierUri"></param>
internal void AddKey(string keyIdentifierUri)
{
if (TheKeyHasNotBeenCached(keyIdentifierUri))
try
{
ParseAKVPath(keyIdentifierUri, out Uri vaultUri, out string keyName, out string keyVersion);
CreateKeyClient(vaultUri);
FetchKey(vaultUri, keyName, keyVersion, keyIdentifierUri);
}
// Allow only one thread to proceed to ensure thread safety
// as we will need to fetch key information from Azure Key Vault if the key is not found in cache.
_keyDictionarySemaphore.Wait();
Comment on lines +58 to +62
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move _keyDictionarySemaphore.Wait(); to before the try. We don't want to execute the finally if Wait() throws an exception.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if GetOrAdd() could be used here as well, and avoid the semaphore? Pass a lambda that fetches the key, and let ConcurrentDictionary block all other calls until the fetch + add is complete. Would that work?


bool TheKeyHasNotBeenCached(string k) => !_keyDictionary.ContainsKey(k) && !_keyFetchTaskDictionary.ContainsKey(k);
if (!_keyDictionary.ContainsKey(keyIdentifierUri))
{
ParseAKVPath(keyIdentifierUri, out Uri vaultUri, out string keyName, out string keyVersion);

// Fetch the KeyClient for the Key vault URI.
KeyClient keyClient = GetOrCreateKeyClient(vaultUri);

// Fetch the key from Azure Key Vault.
KeyVaultKey key = FetchKeyFromKeyVault(keyClient, keyName, keyVersion);

_keyDictionary.AddOrUpdate(keyIdentifierUri, key, (k, v) => key);
}
}
finally
{
_keyDictionarySemaphore.Release();
}
}

/// <summary>
Expand All @@ -79,12 +93,6 @@ internal KeyVaultKey GetKey(string keyIdentifierUri)
return key;
}

if (_keyFetchTaskDictionary.TryGetValue(keyIdentifierUri, out Task<Azure.Response<KeyVaultKey>> task))
{
AKVEventSource.Log.TryTraceEvent("New Master key fetched.");
return Task.Run(() => task).GetAwaiter().GetResult();
}

// Not a public exception - not likely to occur.
AKVEventSource.Log.TryTraceEvent("Master key not found.");
throw ADP.MasterKeyNotFound(keyIdentifierUri);
Expand All @@ -95,10 +103,7 @@ internal KeyVaultKey GetKey(string keyIdentifierUri)
/// </summary>
/// <param name="keyIdentifierUri">The key vault key identifier URI</param>
/// <returns></returns>
internal int GetKeySize(string keyIdentifierUri)
{
return GetKey(keyIdentifierUri).Key.N.Length;
}
internal int GetKeySize(string keyIdentifierUri) => GetKey(keyIdentifierUri).Key.N.Length;

/// <summary>
/// Generates signature based on RSA PKCS#v1.5 scheme using a specified Azure Key Vault Key URL.
Expand Down Expand Up @@ -142,49 +147,67 @@ private CryptographyClient GetCryptographyClient(string keyIdentifierUri)

CryptographyClient cryptographyClient = new(GetKey(keyIdentifierUri).Id, TokenCredential);
_cryptoClientDictionary.TryAdd(keyIdentifierUri, cryptographyClient);

return cryptographyClient;
}

/// <summary>
///
/// Fetches the column encryption key from the Azure Key Vault.
/// </summary>
/// <param name="vaultUri">The Azure Key Vault URI</param>
/// <param name="keyClient">The KeyClient instance</param>
/// <param name="keyName">The name of the Azure Key Vault key</param>
/// <param name="keyVersion">The version of the Azure Key Vault key</param>
/// <param name="keyResourceUri">The Azure Key Vault key identifier</param>
private void FetchKey(Uri vaultUri, string keyName, string keyVersion, string keyResourceUri)
private KeyVaultKey FetchKeyFromKeyVault(KeyClient keyClient, string keyName, string keyVersion)
{
Task<Azure.Response<KeyVaultKey>> fetchKeyTask = FetchKeyFromKeyVault(vaultUri, keyName, keyVersion);
_keyFetchTaskDictionary.AddOrUpdate(keyResourceUri, fetchKeyTask, (k, v) => fetchKeyTask);
AKVEventSource.Log.TryTraceEvent("Fetching requested master key: {0}", keyName);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Master key, or column encryption key?


fetchKeyTask
.ContinueWith(k => ValidateRsaKey(k.GetAwaiter().GetResult()))
.ContinueWith(k => _keyDictionary.AddOrUpdate(keyResourceUri, k.GetAwaiter().GetResult(), (key, v) => k.GetAwaiter().GetResult()));
Azure.Response<KeyVaultKey> keyResponse = keyClient?.GetKey(keyName, keyVersion);

Task.Run(() => fetchKeyTask);
// Handle the case where the key response is null or contains an error
// This can happen if the key does not exist or if there is an issue with the KeyClient.
// In such cases, we log the error and throw an exception.
if (keyResponse == null || keyResponse.Value == null || keyResponse.GetRawResponse().IsError)
{
AKVEventSource.Log.TryTraceEvent("Get Key failed to fetch Key from Azure Key Vault for key {0}, version {1}", keyName, keyVersion);
if (keyResponse?.GetRawResponse() is Azure.Response response)
{
AKVEventSource.Log.TryTraceEvent("Response status {0} : {1}", response.Status, response.ReasonPhrase);
}
throw ADP.GetKeyFailed(keyName);
}
else
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary else block.

{
KeyVaultKey key = keyResponse.Value;

// Validate that the key is of type RSA
key = ValidateRsaKey(key);
return key;
}
}

/// <summary>
/// Looks up the KeyClient object by it's URI and then fetches the key by name.
/// Gets or creates a KeyClient for the specified Azure Key Vault URI.
/// </summary>
/// <param name="vaultUri">The Azure Key Vault URI</param>
/// <param name="keyName">Then name of the key</param>
/// <param name="keyVersion">Then version of the key</param>
/// <param name="vaultUri">Key Identifier URL</param>
/// <returns></returns>
private Task<Azure.Response<KeyVaultKey>> FetchKeyFromKeyVault(Uri vaultUri, string keyName, string keyVersion)
private KeyClient GetOrCreateKeyClient(Uri vaultUri)
{
_keyClientDictionary.TryGetValue(vaultUri, out KeyClient keyClient);
AKVEventSource.Log.TryTraceEvent("Fetching requested master key: {0}", keyName);
return keyClient?.GetKeyAsync(keyName, keyVersion);
// Fetch the KeyClient for the specified vault URI.
if (!_keyClientDictionary.TryGetValue(vaultUri, out KeyClient keyClient))
{
// If the KeyClient does not exist, create a new one and add it to the dictionary.
keyClient = new KeyClient(vaultUri, TokenCredential);
_keyClientDictionary.TryAdd(vaultUri, keyClient);
}

return keyClient;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If TryAdd() fails (because another thread just added a KeyClient for this vaultUri), then we're returning a different KeyClient instance than has been stored in the dictionary. Is this a problem?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you could replace this function body with:

return _keyClientDictionary.GetOrAdd(vaultUri, (_) => new KeyClient(vaultUri, TokenCredential));

}

/// <summary>
/// Validates that a key is of type RSA
/// </summary>
/// <param name="key"></param>
/// <returns></returns>
private KeyVaultKey ValidateRsaKey(KeyVaultKey key)
private static KeyVaultKey ValidateRsaKey(KeyVaultKey key)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this function return KeyVaultKey? Not sure what the intention is here, since neither the return value nor exceptions are documented.

{
if (key.KeyType != KeyType.Rsa && key.KeyType != KeyType.RsaHsm)
{
Expand All @@ -195,26 +218,14 @@ private KeyVaultKey ValidateRsaKey(KeyVaultKey key)
return key;
}

/// <summary>
/// Instantiates and adds a KeyClient to the KeyClient dictionary
/// </summary>
/// <param name="vaultUri">The Azure Key Vault URI</param>
private void CreateKeyClient(Uri vaultUri)
{
if (!_keyClientDictionary.ContainsKey(vaultUri))
{
_keyClientDictionary.TryAdd(vaultUri, new KeyClient(vaultUri, TokenCredential));
}
}

/// <summary>
/// Validates and parses the Azure Key Vault URI and key name.
/// </summary>
/// <param name="masterKeyPath">The Azure Key Vault key identifier</param>
/// <param name="vaultUri">The Azure Key Vault URI</param>
/// <param name="masterKeyName">The name of the key</param>
/// <param name="masterKeyVersion">The version of the key</param>
private void ParseAKVPath(string masterKeyPath, out Uri vaultUri, out string masterKeyName, out string masterKeyVersion)
private static void ParseAKVPath(string masterKeyPath, out Uri vaultUri, out string masterKeyName, out string masterKeyVersion)
{
Uri masterKeyPathUri = new(masterKeyPath);
vaultUri = new Uri(masterKeyPathUri.GetLeftPart(UriPartial.Authority));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.Extensions.Caching.Memory;
using System;
using Microsoft.Extensions.Caching.Memory;
using static System.Math;

namespace Microsoft.Data.SqlClient.AlwaysEncrypted.AzureKeyVaultProvider
Expand Down Expand Up @@ -92,6 +92,7 @@ internal TValue GetOrCreate(TKey key, Func<TValue> createItem)

/// <summary>
/// Determines whether the <see cref="LocalCache{TKey, TValue}">LocalCache</see> contains the specified key.
/// Used in unit tests to verify that the cache contains the expected entries.
/// </summary>
/// <param name="key"></param>
/// <returns></returns>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Text;
using System.Threading;
using Azure.Core;
using Azure.Security.KeyVault.Keys.Cryptography;
using static Microsoft.Data.SqlClient.AlwaysEncrypted.AzureKeyVaultProvider.Validator;
Expand Down Expand Up @@ -55,6 +56,8 @@ public class SqlColumnEncryptionAzureKeyVaultProvider : SqlColumnEncryptionKeySt

private readonly static KeyWrapAlgorithm s_keyWrapAlgorithm = KeyWrapAlgorithm.RsaOaep;

private SemaphoreSlim _cacheSemaphore = new(1, 1);

/// <summary>
/// List of Trusted Endpoints
///
Expand All @@ -69,7 +72,7 @@ public class SqlColumnEncryptionAzureKeyVaultProvider : SqlColumnEncryptionKeySt
/// <summary>
/// A cache for storing the results of signature verification of column master key metadata.
/// </summary>
private readonly LocalCache<Tuple<string, bool, string>, bool> _columnMasterKeyMetadataSignatureVerificationCache =
private readonly LocalCache<Tuple<string, bool, string>, bool> _columnMasterKeyMetadataSignatureVerificationCache =
new(maxSizeLimit: 2000) { TimeToLive = TimeSpan.FromDays(10) };

/// <summary>
Expand Down Expand Up @@ -230,7 +233,7 @@ byte[] DecryptEncryptionKey()
// Get ciphertext
byte[] cipherText = new byte[cipherTextLength];
Array.Copy(encryptedColumnEncryptionKey, currentIndex, cipherText, 0, cipherTextLength);

currentIndex += cipherTextLength;

// Get signature
Expand Down Expand Up @@ -397,14 +400,7 @@ private byte[] CompileMasterKeyMetadata(string masterKeyPath, bool allowEnclaveC
/// Produces a string of hexadecimal character pairs preceded with "0x", where each pair represents the corresponding element in value; for example, "0x7F2C4A00".
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

value -> source

/// </remarks>
private string ToHexString(byte[] source)
{
if (source is null)
{
return null;
}

return "0x" + BitConverter.ToString(source).Replace("-", "");
}
=> source is null ? null : "0x" + BitConverter.ToString(source).Replace("-", "");

/// <summary>
/// Returns the cached decrypted column encryption key, or unwraps the encrypted column encryption key if not present.
Expand All @@ -415,8 +411,20 @@ private string ToHexString(byte[] source)
/// <remarks>
///
/// </remarks>
private byte[] GetOrCreateColumnEncryptionKey(string encryptedColumnEncryptionKey, Func<byte[]> createItem)
=> _columnEncryptionKeyCache.GetOrCreate(encryptedColumnEncryptionKey, createItem);
private byte[] GetOrCreateColumnEncryptionKey(string encryptedColumnEncryptionKey, Func<byte[]> createItem)
{
try
{
// Allow only one thread to access the cache at a time.
_cacheSemaphore.Wait();
Comment on lines +416 to +419
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move _cacheSemaphore.Wait(); to before the try. We don't want to execute the finally if Wait() throws an exception.

return _columnEncryptionKeyCache.GetOrCreate(encryptedColumnEncryptionKey, createItem);
}
finally
{
// Release the semaphore to allow other threads to access the cache.
_cacheSemaphore.Release();
}
}

/// <summary>
/// Returns the cached signature verification result, or proceeds to verify if not present.
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,16 @@
<value>System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089</value>
</resheader>
<data name="NullOrWhitespaceForEach" xml:space="preserve">
<value>One or more of the elements in {0} are null or empty or consist of only whitespace.</value>
<value>One or more of the elements in '{0}' are null or empty or consist of only whitespace.</value>
</data>
<data name="CipherTextLengthMismatch" xml:space="preserve">
<value>CipherText length does not match the RSA key size.</value>
</data>
<data name="EmptyArgumentInternal" xml:space="preserve">
<value>Internal error. Empty {0} specified.</value>
<value>Internal error. Empty '{0}' specified.</value>
</data>
<data name="GetKeyFailed" xml:space="preserve">
<value>Fetching the key failed: '{0}'.</value>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
<value>Fetching the key failed: '{0}'.</value>
<value>Failed to fetch Key from Azure Key Vault. Key: {0}</value>

</data>
<data name="MasterKeyNotFound" xml:space="preserve">
<value>The key with identifier '{0}' was not found.</value>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ internal static ArgumentException NullOrWhitespaceForEach(string name) =>
new(string.Format(Strings.NullOrWhitespaceForEach, name));

internal static KeyNotFoundException MasterKeyNotFound(string masterKeyPath) =>
new(string.Format(CultureInfo.InvariantCulture, Strings.InvalidSignatureTemplate, masterKeyPath));
new(string.Format(CultureInfo.InvariantCulture, Strings.MasterKeyNotFound, masterKeyPath));

internal static KeyNotFoundException GetKeyFailed(string masterKeyPath) =>
new(string.Format(CultureInfo.InvariantCulture, Strings.GetKeyFailed, masterKeyPath));

internal static FormatException NonRsaKeyFormat(string keyType) =>
new(string.Format(CultureInfo.InvariantCulture, Strings.NonRsaKeyTemplate, keyType));
Expand Down
Loading
Loading