Skip to content

Commit

Permalink
Merge pull request #397 from v-xiangs/update-code-to-compile-with-AKV-v1
Browse files Browse the repository at this point in the history
Update code to compile with AKV v1.0.0
  • Loading branch information
xiangyushawn authored Jul 21, 2017
2 parents f46f622 + ba4651f commit 0d7d78b
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 125 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
<dependency>
<groupId>com.microsoft.azure</groupId>
<artifactId>azure-keyvault</artifactId>
<version>0.9.7</version>
<version>1.0.0</version>
<optional>true</optional>
</dependency>

Expand Down
75 changes: 40 additions & 35 deletions src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultCredential.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@

package com.microsoft.sqlserver.jdbc;

import java.util.Map;

import org.apache.http.Header;
import org.apache.http.message.BasicHeader;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import com.microsoft.aad.adal4j.AuthenticationContext;
import com.microsoft.aad.adal4j.AuthenticationResult;
import com.microsoft.aad.adal4j.ClientCredential;
import com.microsoft.azure.keyvault.authentication.KeyVaultCredentials;
import com.microsoft.windowsazure.core.pipeline.filter.ServiceRequestContext;

/**
*
Expand All @@ -23,42 +24,46 @@
*/
class KeyVaultCredential extends KeyVaultCredentials {

// this is the only supported access token type
// https://msdn.microsoft.com/en-us/library/azure/dn645538.aspx
private final String accessTokenType = "Bearer";

SQLServerKeyVaultAuthenticationCallback authenticationCallback = null;
String clientId = null;
String clientKey = null;
String accessToken = null;

KeyVaultCredential(SQLServerKeyVaultAuthenticationCallback authenticationCallback) {
this.authenticationCallback = authenticationCallback;
KeyVaultCredential(String clientId,
String clientKey) {
this.clientId = clientId;
this.clientKey = clientKey;
}

/**
* Authenticates the service request
*
* @param request
* the ServiceRequestContext
* @param challenge
* used to get the accessToken
* @return BasicHeader
*/
@Override
public Header doAuthenticate(ServiceRequestContext request,
Map<String, String> challenge) {
assert null != challenge;

String authorization = challenge.get("authorization");
String resource = challenge.get("resource");

accessToken = authenticationCallback.getAccessToken(authorization, resource, "");
return new BasicHeader("Authorization", accessTokenType + " " + accessToken);
public String doAuthenticate(String authorization,
String resource,
String scope) {
AuthenticationResult token = getAccessTokenFromClientCredentials(authorization, resource, clientId, clientKey);
return token.getAccessToken();
}

void setAccessToken(String accessToken) {
this.accessToken = accessToken;
}
private static AuthenticationResult getAccessTokenFromClientCredentials(String authorization,
String resource,
String clientId,
String clientKey) {
AuthenticationContext context = null;
AuthenticationResult result = null;
ExecutorService service = null;
try {
service = Executors.newFixedThreadPool(1);
context = new AuthenticationContext(authorization, false, service);
ClientCredential credentials = new ClientCredential(clientId, clientKey);
Future<AuthenticationResult> future = context.acquireToken(resource, credentials, null);
result = future.get();
}
catch (Exception e) {
throw new RuntimeException(e);
}
finally {
service.shutdown();
}

if (result == null) {
throw new RuntimeException("authentication result was null");
}
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,12 @@
import java.security.NoSuchAlgorithmException;
import java.text.MessageFormat;
import java.util.Locale;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;

import org.apache.http.impl.client.HttpClientBuilder;

import com.microsoft.azure.keyvault.KeyVaultClient;
import com.microsoft.azure.keyvault.KeyVaultClientImpl;
import com.microsoft.azure.keyvault.models.KeyBundle;
import com.microsoft.azure.keyvault.models.KeyOperationResult;
import com.microsoft.azure.keyvault.models.KeyVerifyResult;
import com.microsoft.azure.keyvault.webkey.JsonWebKeyEncryptionAlgorithm;
import com.microsoft.azure.keyvault.webkey.JsonWebKeySignatureAlgorithm;

/**
Expand Down Expand Up @@ -67,26 +64,20 @@ public String getName() {
}

/**
* Constructor that takes a callback function to authenticate to AAD. This is used by KeyVaultClient at runtime to authenticate to Azure Key
* Constructor that authenticates to AAD. This is used by KeyVaultClient at runtime to authenticate to Azure Key
* Vault.
*
* @param authenticationCallback
* - Callback function used for authenticating to AAD.
* @param executorService
* - The ExecutorService used to create the keyVaultClient
* @param clientId
* Identifier of the client requesting the token.
* @param clientKey
* Key of the client requesting the token.
* @throws SQLServerException
* when an error occurs
*/
public SQLServerColumnEncryptionAzureKeyVaultProvider(SQLServerKeyVaultAuthenticationCallback authenticationCallback,
ExecutorService executorService) throws SQLServerException {
if (null == authenticationCallback) {
MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_NullValue"));
Object[] msgArgs1 = {"SQLServerKeyVaultAuthenticationCallback"};
throw new SQLServerException(form.format(msgArgs1), null);
}
credential = new KeyVaultCredential(authenticationCallback);
HttpClientBuilder builder = HttpClientBuilder.create();
keyVaultClient = new KeyVaultClientImpl(builder, executorService, credential);
public SQLServerColumnEncryptionAzureKeyVaultProvider(String clientId,
String clientKey) throws SQLServerException {
credential = new KeyVaultCredential(clientId, clientKey);
keyVaultClient = new KeyVaultClient(credential);
}

/**
Expand Down Expand Up @@ -309,7 +300,7 @@ public byte[] encryptColumnEncryptionKey(String masterKeyPath,
byte dataToSign[] = md.digest();

// Sign the hash
byte[] signedHash = AzureKeyVaultSignHashedData(dataToSign, masterKeyPath);
byte[] signedHash = AzureKeyVaultSignHashedData(dataToSign, masterKeyPath);

if (signedHash.length != keySizeInBytes) {
throw new SQLServerException(SQLServerException.getErrString("R_SignedHashLengthError"), null);
Expand Down Expand Up @@ -434,14 +425,10 @@ private byte[] AzureKeyVaultWrap(String masterKeyPath,
throw new SQLServerException(SQLServerException.getErrString("R_CEKNull"), null);
}

KeyOperationResult wrappedKey = null;
try {
wrappedKey = keyVaultClient.wrapKeyAsync(masterKeyPath, encryptionAlgorithm, columnEncryptionKey).get();
}
catch (InterruptedException | ExecutionException e) {
throw new SQLServerException(SQLServerException.getErrString("R_EncryptCEKError"), e);
}
return wrappedKey.getResult();
JsonWebKeyEncryptionAlgorithm jsonEncryptionAlgorithm = new JsonWebKeyEncryptionAlgorithm(encryptionAlgorithm);
KeyOperationResult wrappedKey = keyVaultClient.wrapKey(masterKeyPath, jsonEncryptionAlgorithm, columnEncryptionKey);

return wrappedKey.result();
}

/**
Expand All @@ -467,14 +454,10 @@ private byte[] AzureKeyVaultUnWrap(String masterKeyPath,
throw new SQLServerException(SQLServerException.getErrString("R_EmptyEncryptedCEK"), null);
}

KeyOperationResult unwrappedKey;
try {
unwrappedKey = keyVaultClient.unwrapKeyAsync(masterKeyPath, encryptionAlgorithm, encryptedColumnEncryptionKey).get();
}
catch (InterruptedException | ExecutionException e) {
throw new SQLServerException(SQLServerException.getErrString("R_DecryptCEKError"), e);
}
return unwrappedKey.getResult();
JsonWebKeyEncryptionAlgorithm jsonEncryptionAlgorithm = new JsonWebKeyEncryptionAlgorithm(encryptionAlgorithm);
KeyOperationResult unwrappedKey = keyVaultClient.unwrapKey(masterKeyPath, jsonEncryptionAlgorithm, encryptedColumnEncryptionKey);

return unwrappedKey.result();
}

/**
Expand All @@ -491,14 +474,9 @@ private byte[] AzureKeyVaultSignHashedData(byte[] dataToSign,
String masterKeyPath) throws SQLServerException {
assert ((null != dataToSign) && (0 != dataToSign.length));

KeyOperationResult signedData = null;
try {
signedData = keyVaultClient.signAsync(masterKeyPath, JsonWebKeySignatureAlgorithm.RS256, dataToSign).get();
}
catch (InterruptedException | ExecutionException e) {
throw new SQLServerException(SQLServerException.getErrString("R_GenerateSignature"), e);
}
return signedData.getResult();
KeyOperationResult signedData = keyVaultClient.sign(masterKeyPath, JsonWebKeySignatureAlgorithm.RS256, dataToSign);

return signedData.result();
}

/**
Expand All @@ -517,15 +495,9 @@ private boolean AzureKeyVaultVerifySignature(byte[] dataToVerify,
assert ((null != dataToVerify) && (0 != dataToVerify.length));
assert ((null != signature) && (0 != signature.length));

boolean valid = false;
try {
valid = keyVaultClient.verifyAsync(masterKeyPath, JsonWebKeySignatureAlgorithm.RS256, dataToVerify, signature).get();
}
catch (InterruptedException | ExecutionException e) {
throw new SQLServerException(SQLServerException.getErrString("R_VerifySignature"), e);
}
KeyVerifyResult valid = keyVaultClient.verify(masterKeyPath, JsonWebKeySignatureAlgorithm.RS256, dataToVerify, signature);

return valid;
return valid.value();
}

/**
Expand All @@ -538,21 +510,22 @@ private boolean AzureKeyVaultVerifySignature(byte[] dataToVerify,
* when an error occurs
*/
private int getAKVKeySize(String masterKeyPath) throws SQLServerException {
KeyBundle retrievedKey = keyVaultClient.getKey(masterKeyPath);

KeyBundle retrievedKey = null;
try {
retrievedKey = keyVaultClient.getKeyAsync(masterKeyPath).get();
}
catch (InterruptedException | ExecutionException e) {
throw new SQLServerException(SQLServerException.getErrString("R_GetAKVKeySize"), e);
if (null == retrievedKey) {
String[] keyTokens = masterKeyPath.split("/");

MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_AKVKeyNotFound"));
Object[] msgArgs = {keyTokens[keyTokens.length - 1]};
throw new SQLServerException(null, form.format(msgArgs), null, 0, false);
}

if (!"RSA".equalsIgnoreCase(retrievedKey.getKey().getKty()) && !"RSA-HSM".equalsIgnoreCase(retrievedKey.getKey().getKty())) {
if (!"RSA".equalsIgnoreCase(retrievedKey.key().kty().toString()) && !"RSA-HSM".equalsIgnoreCase(retrievedKey.key().kty().toString())) {
MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_NonRSAKey"));
Object[] msgArgs = {retrievedKey.getKey().getKty()};
Object[] msgArgs = {retrievedKey.key().kty().toString()};
throw new SQLServerException(null, form.format(msgArgs), null, 0, false);
}

return retrievedKey.getKey().getN().length;
return retrievedKey.key().n().length;
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -385,5 +385,6 @@ protected Object[][] getContents() {
{"R_kerberosLoginFailed", "Kerberos Login failed: {0} due to {1} ({2})"},
{"R_StoredProcedureNotFound", "Could not find stored procedure ''{0}''."},
{"R_jaasConfigurationNamePropertyDescription", "Login configuration file for Kerberos authentication."},
{"R_AKVKeyNotFound", "Key not found: {0}"},
};
}

0 comments on commit 0d7d78b

Please sign in to comment.