-
Notifications
You must be signed in to change notification settings - Fork 311
Fix encryption key cache design for AKV provider #3464
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
12df21b
bf7fa07
3bee038
8651abb
c0a96a8
672f959
9d1f09b
44671ab
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ | |
using Azure.Security.KeyVault.Keys.Cryptography; | ||
using System; | ||
using System.Collections.Concurrent; | ||
using System.Threading.Tasks; | ||
using System.Threading; | ||
using static Azure.Security.KeyVault.Keys.Cryptography.SignatureAlgorithm; | ||
|
||
namespace Microsoft.Data.SqlClient.AlwaysEncrypted.AzureKeyVaultProvider | ||
|
@@ -25,16 +25,14 @@ internal class AzureSqlKeyCryptographer | |
private readonly ConcurrentDictionary<Uri, KeyClient> _keyClientDictionary = new(); | ||
|
||
/// <summary> | ||
/// Holds references to the fetch key tasks and maps them to their corresponding Azure Key Vault Key Identifier (URI). | ||
/// These tasks will be used for returning the key in the event that the fetch task has not finished depositing the | ||
/// key into the key dictionary. | ||
/// Holds references to the Azure Key Vault keys and maps them to their corresponding Azure Key Vault Key Identifier (URI). | ||
/// </summary> | ||
private readonly ConcurrentDictionary<string, Task<Azure.Response<KeyVaultKey>>> _keyFetchTaskDictionary = new(); | ||
private readonly ConcurrentDictionary<string, KeyVaultKey> _keyDictionary = new(); | ||
|
||
/// <summary> | ||
/// Holds references to the Azure Key Vault keys and maps them to their corresponding Azure Key Vault Key Identifier (URI). | ||
/// SemaphoreSlim to ensure thread safety when accessing the key dictionary or making network calls to Azure Key Vault to fetch keys. | ||
/// </summary> | ||
private readonly ConcurrentDictionary<string, KeyVaultKey> _keyDictionary = new(); | ||
private SemaphoreSlim _keyDictionarySemaphore = new(1, 1); | ||
|
||
/// <summary> | ||
/// Holds references to the Azure Key Vault CryptographyClient objects and maps them to their corresponding Azure Key Vault Key Identifier (URI). | ||
|
@@ -52,18 +50,34 @@ internal AzureSqlKeyCryptographer(TokenCredential tokenCredential) | |
|
||
/// <summary> | ||
/// Adds the key, specified by the Key Identifier URI, to the cache. | ||
/// Validates the key type and fetches the key from Azure Key Vault if it is not already cached. | ||
/// </summary> | ||
/// <param name="keyIdentifierUri"></param> | ||
internal void AddKey(string keyIdentifierUri) | ||
{ | ||
if (TheKeyHasNotBeenCached(keyIdentifierUri)) | ||
try | ||
{ | ||
ParseAKVPath(keyIdentifierUri, out Uri vaultUri, out string keyName, out string keyVersion); | ||
CreateKeyClient(vaultUri); | ||
FetchKey(vaultUri, keyName, keyVersion, keyIdentifierUri); | ||
} | ||
// Allow only one thread to proceed to ensure thread safety | ||
// as we will need to fetch key information from Azure Key Vault if the key is not found in cache. | ||
_keyDictionarySemaphore.Wait(); | ||
Comment on lines
+58
to
+62
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if GetOrAdd() could be used here as well, and avoid the semaphore? Pass a lambda that fetches the key, and let ConcurrentDictionary block all other calls until the fetch + add is complete. Would that work? |
||
|
||
bool TheKeyHasNotBeenCached(string k) => !_keyDictionary.ContainsKey(k) && !_keyFetchTaskDictionary.ContainsKey(k); | ||
if (!_keyDictionary.ContainsKey(keyIdentifierUri)) | ||
{ | ||
ParseAKVPath(keyIdentifierUri, out Uri vaultUri, out string keyName, out string keyVersion); | ||
|
||
// Fetch the KeyClient for the Key vault URI. | ||
KeyClient keyClient = GetOrCreateKeyClient(vaultUri); | ||
|
||
// Fetch the key from Azure Key Vault. | ||
KeyVaultKey key = FetchKeyFromKeyVault(keyClient, keyName, keyVersion); | ||
|
||
_keyDictionary.AddOrUpdate(keyIdentifierUri, key, (k, v) => key); | ||
} | ||
} | ||
finally | ||
{ | ||
_keyDictionarySemaphore.Release(); | ||
} | ||
} | ||
|
||
/// <summary> | ||
|
@@ -79,12 +93,6 @@ internal KeyVaultKey GetKey(string keyIdentifierUri) | |
return key; | ||
} | ||
|
||
if (_keyFetchTaskDictionary.TryGetValue(keyIdentifierUri, out Task<Azure.Response<KeyVaultKey>> task)) | ||
{ | ||
AKVEventSource.Log.TryTraceEvent("New Master key fetched."); | ||
return Task.Run(() => task).GetAwaiter().GetResult(); | ||
} | ||
|
||
// Not a public exception - not likely to occur. | ||
AKVEventSource.Log.TryTraceEvent("Master key not found."); | ||
throw ADP.MasterKeyNotFound(keyIdentifierUri); | ||
|
@@ -95,10 +103,7 @@ internal KeyVaultKey GetKey(string keyIdentifierUri) | |
/// </summary> | ||
/// <param name="keyIdentifierUri">The key vault key identifier URI</param> | ||
/// <returns></returns> | ||
internal int GetKeySize(string keyIdentifierUri) | ||
{ | ||
return GetKey(keyIdentifierUri).Key.N.Length; | ||
} | ||
internal int GetKeySize(string keyIdentifierUri) => GetKey(keyIdentifierUri).Key.N.Length; | ||
|
||
/// <summary> | ||
/// Generates signature based on RSA PKCS#v1.5 scheme using a specified Azure Key Vault Key URL. | ||
|
@@ -142,49 +147,67 @@ private CryptographyClient GetCryptographyClient(string keyIdentifierUri) | |
|
||
CryptographyClient cryptographyClient = new(GetKey(keyIdentifierUri).Id, TokenCredential); | ||
_cryptoClientDictionary.TryAdd(keyIdentifierUri, cryptographyClient); | ||
|
||
return cryptographyClient; | ||
} | ||
|
||
/// <summary> | ||
/// | ||
/// Fetches the column encryption key from the Azure Key Vault. | ||
/// </summary> | ||
/// <param name="vaultUri">The Azure Key Vault URI</param> | ||
/// <param name="keyClient">The KeyClient instance</param> | ||
/// <param name="keyName">The name of the Azure Key Vault key</param> | ||
/// <param name="keyVersion">The version of the Azure Key Vault key</param> | ||
/// <param name="keyResourceUri">The Azure Key Vault key identifier</param> | ||
private void FetchKey(Uri vaultUri, string keyName, string keyVersion, string keyResourceUri) | ||
private KeyVaultKey FetchKeyFromKeyVault(KeyClient keyClient, string keyName, string keyVersion) | ||
{ | ||
Task<Azure.Response<KeyVaultKey>> fetchKeyTask = FetchKeyFromKeyVault(vaultUri, keyName, keyVersion); | ||
_keyFetchTaskDictionary.AddOrUpdate(keyResourceUri, fetchKeyTask, (k, v) => fetchKeyTask); | ||
AKVEventSource.Log.TryTraceEvent("Fetching requested master key: {0}", keyName); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Master key, or column encryption key? |
||
|
||
fetchKeyTask | ||
.ContinueWith(k => ValidateRsaKey(k.GetAwaiter().GetResult())) | ||
.ContinueWith(k => _keyDictionary.AddOrUpdate(keyResourceUri, k.GetAwaiter().GetResult(), (key, v) => k.GetAwaiter().GetResult())); | ||
Azure.Response<KeyVaultKey> keyResponse = keyClient?.GetKey(keyName, keyVersion); | ||
|
||
Task.Run(() => fetchKeyTask); | ||
// Handle the case where the key response is null or contains an error | ||
// This can happen if the key does not exist or if there is an issue with the KeyClient. | ||
// In such cases, we log the error and throw an exception. | ||
if (keyResponse == null || keyResponse.Value == null || keyResponse.GetRawResponse().IsError) | ||
{ | ||
AKVEventSource.Log.TryTraceEvent("Get Key failed to fetch Key from Azure Key Vault for key {0}, version {1}", keyName, keyVersion); | ||
if (keyResponse?.GetRawResponse() is Azure.Response response) | ||
{ | ||
AKVEventSource.Log.TryTraceEvent("Response status {0} : {1}", response.Status, response.ReasonPhrase); | ||
} | ||
throw ADP.GetKeyFailed(keyName); | ||
} | ||
else | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unnecessary else block. |
||
{ | ||
KeyVaultKey key = keyResponse.Value; | ||
|
||
// Validate that the key is of type RSA | ||
key = ValidateRsaKey(key); | ||
return key; | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// Looks up the KeyClient object by it's URI and then fetches the key by name. | ||
/// Gets or creates a KeyClient for the specified Azure Key Vault URI. | ||
/// </summary> | ||
/// <param name="vaultUri">The Azure Key Vault URI</param> | ||
/// <param name="keyName">Then name of the key</param> | ||
/// <param name="keyVersion">Then version of the key</param> | ||
/// <param name="vaultUri">Key Identifier URL</param> | ||
/// <returns></returns> | ||
private Task<Azure.Response<KeyVaultKey>> FetchKeyFromKeyVault(Uri vaultUri, string keyName, string keyVersion) | ||
private KeyClient GetOrCreateKeyClient(Uri vaultUri) | ||
{ | ||
_keyClientDictionary.TryGetValue(vaultUri, out KeyClient keyClient); | ||
AKVEventSource.Log.TryTraceEvent("Fetching requested master key: {0}", keyName); | ||
return keyClient?.GetKeyAsync(keyName, keyVersion); | ||
// Fetch the KeyClient for the specified vault URI. | ||
if (!_keyClientDictionary.TryGetValue(vaultUri, out KeyClient keyClient)) | ||
{ | ||
// If the KeyClient does not exist, create a new one and add it to the dictionary. | ||
keyClient = new KeyClient(vaultUri, TokenCredential); | ||
_keyClientDictionary.TryAdd(vaultUri, keyClient); | ||
} | ||
|
||
return keyClient; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If TryAdd() fails (because another thread just added a KeyClient for this vaultUri), then we're returning a different KeyClient instance than has been stored in the dictionary. Is this a problem? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you could replace this function body with:
|
||
} | ||
|
||
/// <summary> | ||
/// Validates that a key is of type RSA | ||
/// </summary> | ||
/// <param name="key"></param> | ||
/// <returns></returns> | ||
private KeyVaultKey ValidateRsaKey(KeyVaultKey key) | ||
private static KeyVaultKey ValidateRsaKey(KeyVaultKey key) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why does this function return KeyVaultKey? Not sure what the intention is here, since neither the return value nor exceptions are documented. |
||
{ | ||
if (key.KeyType != KeyType.Rsa && key.KeyType != KeyType.RsaHsm) | ||
{ | ||
|
@@ -195,26 +218,14 @@ private KeyVaultKey ValidateRsaKey(KeyVaultKey key) | |
return key; | ||
} | ||
|
||
/// <summary> | ||
/// Instantiates and adds a KeyClient to the KeyClient dictionary | ||
/// </summary> | ||
/// <param name="vaultUri">The Azure Key Vault URI</param> | ||
private void CreateKeyClient(Uri vaultUri) | ||
{ | ||
if (!_keyClientDictionary.ContainsKey(vaultUri)) | ||
{ | ||
_keyClientDictionary.TryAdd(vaultUri, new KeyClient(vaultUri, TokenCredential)); | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// Validates and parses the Azure Key Vault URI and key name. | ||
/// </summary> | ||
/// <param name="masterKeyPath">The Azure Key Vault key identifier</param> | ||
/// <param name="vaultUri">The Azure Key Vault URI</param> | ||
/// <param name="masterKeyName">The name of the key</param> | ||
/// <param name="masterKeyVersion">The version of the key</param> | ||
private void ParseAKVPath(string masterKeyPath, out Uri vaultUri, out string masterKeyName, out string masterKeyVersion) | ||
private static void ParseAKVPath(string masterKeyPath, out Uri vaultUri, out string masterKeyName, out string masterKeyVersion) | ||
{ | ||
Uri masterKeyPathUri = new(masterKeyPath); | ||
vaultUri = new Uri(masterKeyPathUri.GetLeftPart(UriPartial.Authority)); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
|
||
using System; | ||
using System.Text; | ||
using System.Threading; | ||
using Azure.Core; | ||
using Azure.Security.KeyVault.Keys.Cryptography; | ||
using static Microsoft.Data.SqlClient.AlwaysEncrypted.AzureKeyVaultProvider.Validator; | ||
|
@@ -55,6 +56,8 @@ public class SqlColumnEncryptionAzureKeyVaultProvider : SqlColumnEncryptionKeySt | |
|
||
private readonly static KeyWrapAlgorithm s_keyWrapAlgorithm = KeyWrapAlgorithm.RsaOaep; | ||
|
||
private SemaphoreSlim _cacheSemaphore = new(1, 1); | ||
|
||
/// <summary> | ||
/// List of Trusted Endpoints | ||
/// | ||
|
@@ -69,7 +72,7 @@ public class SqlColumnEncryptionAzureKeyVaultProvider : SqlColumnEncryptionKeySt | |
/// <summary> | ||
/// A cache for storing the results of signature verification of column master key metadata. | ||
/// </summary> | ||
private readonly LocalCache<Tuple<string, bool, string>, bool> _columnMasterKeyMetadataSignatureVerificationCache = | ||
private readonly LocalCache<Tuple<string, bool, string>, bool> _columnMasterKeyMetadataSignatureVerificationCache = | ||
new(maxSizeLimit: 2000) { TimeToLive = TimeSpan.FromDays(10) }; | ||
|
||
/// <summary> | ||
|
@@ -230,7 +233,7 @@ byte[] DecryptEncryptionKey() | |
// Get ciphertext | ||
byte[] cipherText = new byte[cipherTextLength]; | ||
Array.Copy(encryptedColumnEncryptionKey, currentIndex, cipherText, 0, cipherTextLength); | ||
|
||
currentIndex += cipherTextLength; | ||
|
||
// Get signature | ||
|
@@ -397,14 +400,7 @@ private byte[] CompileMasterKeyMetadata(string masterKeyPath, bool allowEnclaveC | |
/// Produces a string of hexadecimal character pairs preceded with "0x", where each pair represents the corresponding element in value; for example, "0x7F2C4A00". | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. value -> source |
||
/// </remarks> | ||
private string ToHexString(byte[] source) | ||
{ | ||
if (source is null) | ||
{ | ||
return null; | ||
} | ||
|
||
return "0x" + BitConverter.ToString(source).Replace("-", ""); | ||
} | ||
=> source is null ? null : "0x" + BitConverter.ToString(source).Replace("-", ""); | ||
|
||
/// <summary> | ||
/// Returns the cached decrypted column encryption key, or unwraps the encrypted column encryption key if not present. | ||
|
@@ -415,8 +411,20 @@ private string ToHexString(byte[] source) | |
/// <remarks> | ||
/// | ||
/// </remarks> | ||
private byte[] GetOrCreateColumnEncryptionKey(string encryptedColumnEncryptionKey, Func<byte[]> createItem) | ||
=> _columnEncryptionKeyCache.GetOrCreate(encryptedColumnEncryptionKey, createItem); | ||
private byte[] GetOrCreateColumnEncryptionKey(string encryptedColumnEncryptionKey, Func<byte[]> createItem) | ||
{ | ||
try | ||
{ | ||
// Allow only one thread to access the cache at a time. | ||
_cacheSemaphore.Wait(); | ||
Comment on lines
+416
to
+419
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move |
||
return _columnEncryptionKeyCache.GetOrCreate(encryptedColumnEncryptionKey, createItem); | ||
} | ||
finally | ||
{ | ||
// Release the semaphore to allow other threads to access the cache. | ||
_cacheSemaphore.Release(); | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// Returns the cached signature verification result, or proceeds to verify if not present. | ||
|
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -118,13 +118,16 @@ | |||||
<value>System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089</value> | ||||||
</resheader> | ||||||
<data name="NullOrWhitespaceForEach" xml:space="preserve"> | ||||||
<value>One or more of the elements in {0} are null or empty or consist of only whitespace.</value> | ||||||
<value>One or more of the elements in '{0}' are null or empty or consist of only whitespace.</value> | ||||||
</data> | ||||||
<data name="CipherTextLengthMismatch" xml:space="preserve"> | ||||||
<value>CipherText length does not match the RSA key size.</value> | ||||||
</data> | ||||||
<data name="EmptyArgumentInternal" xml:space="preserve"> | ||||||
<value>Internal error. Empty {0} specified.</value> | ||||||
<value>Internal error. Empty '{0}' specified.</value> | ||||||
</data> | ||||||
<data name="GetKeyFailed" xml:space="preserve"> | ||||||
<value>Fetching the key failed: '{0}'.</value> | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
</data> | ||||||
<data name="MasterKeyNotFound" xml:space="preserve"> | ||||||
<value>The key with identifier '{0}' was not found.</value> | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SemaphoreSlim is IDisposable. Do we need to dispose of it?