Skip to content

Commit

Permalink
Updated KeyVaultCredentialPolicy to extend `BearerTokenAuthenticati…
Browse files Browse the repository at this point in the history
…onPolicy` in Key Vault clients. (#24199)

* Replaced all uses of KeyVaultCredentialPolicy with BearerTokenAuthenticationPolicy in Key Vault clients. Removed the KeyVaultCredentialPolicy and ScopeTokeCache classes from all Track 2 Key Vault libraries.

* We now pass the appropriate scope to BearerTokenAuthenticationPolicy creating a new instance in client builders, tests and samples.

* Added tests and recordings for KEK tests on MHSM. Fixed and cleaned up tests.

* Removed unused imports.

* Renamed MHSM_SCOPE to MANAGED_HSM_SCOPE in all client builders.

* Reintroduced KeyVaultCredentialPolicy and modified it to extend from BearerTokenAuthenticationPolicy while extracting the scope provided in bearer challenges returned by the Key Vault service.

* Fixed CvheckStyle errors.

* Made changes to KeyVaultCredentialPolicy so we don't set the body of a request as null, but an empty String instead.

* Removed scope constants from Key vault client builders.

* Attempted to fix flaky live tests.

* Removed verify test for HSM as the FromSource test already verifies the build's code coverage and running in parallel against the same HSM can cause problems for some tests.

* Reverted KeyVaultCredentialPolicy in all libraries to set the request body to null instead of an empty string when sending the first unauthenticated  request to get a bearer challenge. Also stored the value of the "Content-Length" header in the pipeline context for use in a subsequent request.

* Fixed KV Administration client live tests that failed due to the authentication policy changes. Also fixed some flaky live tests.

* Fixed CheckStyle issues.

* Fixed another CheckStyle issue.

* Fixed issue that caused an NPE in KeyVaultCredentialPolicy if the content of the request being originally sent were null from the beginning.

* Updated KeyVaultCredentialPolicy in all other libraries.

* Made an attempt at fixing the backup async live tests.

* Added sleep timer when running against service for restore operations.

* Applied PR feedback.
  • Loading branch information
vcolin7 authored Sep 30, 2021
1 parent d110970 commit 75ff342
Show file tree
Hide file tree
Showing 73 changed files with 2,373 additions and 1,515 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -294,10 +294,6 @@ the main ServiceBusClientBuilder. -->
<suppress checks="com.azure.tools.checkstyle.checks.HttpPipelinePolicy" files="com.azure.messaging.servicebus.implementation.ServiceBusTokenCredentialHttpPolicy.java"/>
<suppress checks="com.azure.tools.checkstyle.checks.HttpPipelinePolicy" files="com.azure.messaging.eventgrid.implementation.CloudEventTracingPipelinePolicy.java"/>
<suppress checks="com.azure.tools.checkstyle.checks.HttpPipelinePolicy" files="com.azure.storage.common.implementation.policy.SasTokenCredentialPolicy.java"/>
<suppress checks="com.azure.tools.checkstyle.checks.HttpPipelinePolicy" files="com.azure.security.keyvault.administration.implementation.KeyVaultCredentialPolicy.java"/>
<suppress checks="com.azure.tools.checkstyle.checks.HttpPipelinePolicy" files="com.azure.security.keyvault.certificates.implementation.KeyVaultCredentialPolicy.java"/>
<suppress checks="com.azure.tools.checkstyle.checks.HttpPipelinePolicy" files="com.azure.security.keyvault.keys.implementation.KeyVaultCredentialPolicy.java"/>
<suppress checks="com.azure.tools.checkstyle.checks.HttpPipelinePolicy" files="com.azure.security.keyvault.secrets.implementation.KeyVaultCredentialPolicy.java"/>
<suppress checks="com.azure.tools.checkstyle.checks.HttpPipelinePolicy" files="com.azure.storage.blob.implementation.util.BlobUserAgentModificationPolicy.java"/>


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -384,10 +384,10 @@ Mono<Response<KeyVaultRoleDefinition>> setRoleDefinitionWithResponse(SetRoleDefi
options.getRoleDefinitionName(), parameters,
context.addData(AZ_TRACING_NAMESPACE_KEY, KEYVAULT_TRACING_NAMESPACE_VALUE))
.doOnRequest(ignored ->
logger.verbose("Creating role assignment - {}", options.getRoleDefinitionName()))
.doOnSuccess(response -> logger.verbose("Created role assignment - {}", response.getValue().getName()))
logger.verbose("Creating role definition - {}", options.getRoleDefinitionName()))
.doOnSuccess(response -> logger.verbose("Created role definition - {}", response.getValue().getName()))
.doOnError(error ->
logger.warning("Failed to create role assignment - {}", options.getRoleDefinitionName(), error))
logger.warning("Failed to create role definition - {}", options.getRoleDefinitionName(), error))
.onErrorMap(KeyVaultAdministrationUtils::mapThrowableToKeyVaultAdministrationException)
.map(KeyVaultAccessControlAsyncClient::transformRoleDefinitionResponse);
} catch (RuntimeException e) {
Expand Down Expand Up @@ -471,11 +471,11 @@ Mono<Response<KeyVaultRoleDefinition>> getRoleDefinitionWithResponse(KeyVaultRol
return clientImpl.getRoleDefinitions()
.getWithResponseAsync(vaultUrl, roleScope.toString(), roleDefinitionName,
context.addData(AZ_TRACING_NAMESPACE_KEY, KEYVAULT_TRACING_NAMESPACE_VALUE))
.doOnRequest(ignored -> logger.verbose("Retrieving role assignment - {}", roleDefinitionName))
.doOnRequest(ignored -> logger.verbose("Retrieving role definition - {}", roleDefinitionName))
.doOnSuccess(response ->
logger.verbose("Retrieved role assignment - {}", response.getValue().getName()))
logger.verbose("Retrieved role definition - {}", response.getValue().getName()))
.doOnError(error ->
logger.warning("Failed to retrieved role assignment - {}", roleDefinitionName, error))
logger.warning("Failed to retrieved role definition - {}", roleDefinitionName, error))
.onErrorMap(KeyVaultAdministrationUtils::mapThrowableToKeyVaultAdministrationException)
.map(KeyVaultAccessControlAsyncClient::transformRoleDefinitionResponse);
} catch (RuntimeException e) {
Expand Down Expand Up @@ -555,9 +555,9 @@ Mono<Response<Void>> deleteRoleDefinitionWithResponse(KeyVaultRoleScope roleScop
return clientImpl.getRoleDefinitions()
.deleteWithResponseAsync(vaultUrl, roleScope.toString(), roleDefinitionName,
context.addData(AZ_TRACING_NAMESPACE_KEY, KEYVAULT_TRACING_NAMESPACE_VALUE))
.doOnRequest(ignored -> logger.verbose("Deleting role assignment - {}", roleDefinitionName))
.doOnSuccess(response -> logger.verbose("Deleted role assignment - {}", response.getValue().getName()))
.doOnError(error -> logger.warning("Failed to delete role assignment - {}", roleDefinitionName, error))
.doOnRequest(ignored -> logger.verbose("Deleting role definition - {}", roleDefinitionName))
.doOnSuccess(response -> logger.verbose("Deleted role definition - {}", response.getValue().getName()))
.doOnError(error -> logger.warning("Failed to delete role definition - {}", roleDefinitionName, error))
.onErrorMap(KeyVaultAdministrationUtils::mapThrowableToKeyVaultAdministrationException)
.map(response -> (Response<Void>) new SimpleResponse<Void>(response, null))
.onErrorResume(KeyVaultAdministrationException.class, e ->
Expand Down Expand Up @@ -897,7 +897,7 @@ Mono<Response<KeyVaultRoleAssignment>> getRoleAssignmentWithResponse(KeyVaultRol
.doOnSuccess(response ->
logger.verbose("Retrieved role assignment - {}", response.getValue().getName()))
.doOnError(error ->
logger.warning("Failed to retrieved role assignment - {}", roleAssignmentName, error))
logger.warning("Failed to retrieve role assignment - {}", roleAssignmentName, error))
.onErrorMap(KeyVaultAdministrationUtils::mapThrowableToKeyVaultAdministrationException)
.map(KeyVaultAccessControlAsyncClient::transformRoleAssignmentResponse);
} catch (RuntimeException e) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,125 +1,186 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.security.keyvault.administration.implementation;

import com.azure.core.credential.TokenCredential;
import com.azure.core.credential.TokenRequestContext;
import com.azure.core.http.HttpPipelineCallContext;
import com.azure.core.http.HttpPipelineNextPolicy;
import com.azure.core.http.HttpRequest;
import com.azure.core.http.HttpResponse;
import com.azure.core.http.policy.HttpPipelinePolicy;
import com.azure.core.http.policy.BearerTokenAuthenticationPolicy;
import com.azure.core.util.CoreUtils;
import com.azure.core.util.logging.ClientLogger;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.net.URL;
import java.nio.ByteBuffer;
import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

/**
* A policy that authenticates requests with Azure Key Vault service. The content added by this policy is leveraged
* in {@link TokenCredential} to get and set the correct "Authorization" header value.
* A policy that authenticates requests with the Azure Key Vault service. The content added by this policy is
* leveraged in {@link TokenCredential} to get and set the correct "Authorization" header value.
*
* @see TokenCredential
*/
public final class KeyVaultCredentialPolicy implements HttpPipelinePolicy {
private final ClientLogger logger = new ClientLogger(KeyVaultCredentialPolicy.class);
private static final String WWW_AUTHENTICATE = "WWW-Authenticate";
public class KeyVaultCredentialPolicy extends BearerTokenAuthenticationPolicy {
private static final String BEARER_TOKEN_PREFIX = "Bearer ";
private static final String AUTHORIZATION = "Authorization";
private final ScopeTokenCache cache;
private static final String CONTENT_LENGTH_HEADER = "Content-Length";
private static final String KEY_VAULT_STASHED_CONTENT_KEY = "KeyVaultCredentialPolicyStashedBody";
private static final String KEY_VAULT_STASHED_CONTENT_LENGTH_KEY = "KeyVaultCredentialPolicyStashedContentLength";
private static final String WWW_AUTHENTICATE = "WWW-Authenticate";
private static final ConcurrentMap<String, String> SCOPE_CACHE = new ConcurrentHashMap<>();
private String scope;

/**
* Creates KeyVaultCredentialPolicy.
* Creates a {@link KeyVaultCredentialPolicy}.
*
* @param credential the token credential to authenticate the request
* @param credential The token credential to authenticate the request.
*/
public KeyVaultCredentialPolicy(TokenCredential credential) {
Objects.requireNonNull(credential);

this.cache = new ScopeTokenCache(credential::getToken);
}

/**
* Adds the required header to authenticate a request to Azure Key Vault service.
*
* @param context The request {@link HttpPipelineCallContext context}.
* @param next The next HTTP pipeline policy to process the {@link HttpPipelineCallContext context's} request
* after this policy completes.
* @return A {@link Mono} representing the {@link HttpResponse HTTP response} that will arrive asynchronously.
*/
@Override
public Mono<HttpResponse> process(HttpPipelineCallContext context, HttpPipelineNextPolicy next) {
if (!context.getHttpRequest().getUrl().getProtocol().startsWith("https")) {
return Mono.error(new RuntimeException("Token credentials require a URL using the HTTPS protocol scheme"));
}

return next.clone().process()
.doOnNext(httpResponse -> {
// KV follows challenge based auth. Currently every service
// call hit the endpoint for challenge and then resend the
// request with token. The challenge response body is not
// consumed, not draining/closing the body will result in leak.
// Ref: https://github.com/Azure/azure-sdk-for-java/issues/7934
// https://github.com/Azure/azure-sdk-for-java/issues/10467
try {
httpResponse.getBody().subscribe().dispose();
} catch (RuntimeException ignored) {
logger.logExceptionAsWarning(ignored);
}
// The ReactorNettyHttpResponse::close() should be sufficient
// and should take care similar body disposal but looks like that
// is not happening, need to re-visit the close() method.
})
.map(res -> res.getHeaderValue(WWW_AUTHENTICATE))
.map(header -> extractChallenge(header, BEARER_TOKEN_PREFIX))
.flatMap(map -> {
cache.setTokenRequest(new TokenRequestContext().addScopes(map.get("resource") + "/.default"));
return cache.getToken();
})
.flatMap(token -> {
context.getHttpRequest().setHeader(AUTHORIZATION, BEARER_TOKEN_PREFIX + token.getToken());
return next.process();
});
super(credential);
}

/**
* Extracts the challenge off the authentication header.
* Extracts attributes off the bearer challenge in the authentication header.
*
* @param authenticateHeader The authentication header containing all the challenges.
* @param authenticateHeader The authentication header containing the challenge.
* @param authChallengePrefix The authentication challenge name.
* @return A challenge map.
*
* @return A challenge attributes map.
*/
private static Map<String, String> extractChallenge(String authenticateHeader, String authChallengePrefix) {
if (!isValidChallenge(authenticateHeader, authChallengePrefix)) {
return null;
private static Map<String, String> extractChallengeAttributes(String authenticateHeader,
String authChallengePrefix) {
if (!isBearerChallenge(authenticateHeader, authChallengePrefix)) {
return Collections.emptyMap();
}

authenticateHeader =
authenticateHeader.toLowerCase(Locale.ROOT).replace(authChallengePrefix.toLowerCase(Locale.ROOT), "");

String[] challenges = authenticateHeader.split(", ");
Map<String, String> challengeMap = new HashMap<>();
String[] attributes = authenticateHeader.split(", ");
Map<String, String> attributeMap = new HashMap<>();

for (String pair : challenges) {
for (String pair : attributes) {
String[] keyValue = pair.split("=");
challengeMap.put(keyValue[0].replaceAll("\"", ""), keyValue[1].replaceAll("\"", ""));

attributeMap.put(keyValue[0].replaceAll("\"", ""), keyValue[1].replaceAll("\"", ""));
}

return challengeMap;
return attributeMap;
}

/**
* Verifies whether a challenge is bearer or not.
*
* @param authenticateHeader The authentication header containing all the challenges.
* @param authenticateHeader The authentication header containing all the challenges.
* @param authChallengePrefix The authentication challenge name.
* @return A boolean indicating tha challenge is valid or not.
* @return A boolean indicating if the challenge is a bearer challenge or not.
*/
private static boolean isValidChallenge(String authenticateHeader, String authChallengePrefix) {
private static boolean isBearerChallenge(String authenticateHeader, String authChallengePrefix) {
return (!CoreUtils.isNullOrEmpty(authenticateHeader)
&& authenticateHeader.toLowerCase(Locale.ROOT).startsWith(authChallengePrefix.toLowerCase(Locale.ROOT)));
}

@Override
public Mono<Void> authorizeRequest(HttpPipelineCallContext context) {
return Mono.defer(() -> {
HttpRequest request = context.getHttpRequest();

// If this policy doesn't have an authorityScope cached try to get it from the static challenge cache.
if (this.scope == null) {
String authority = getRequestAuthority(request);
this.scope = SCOPE_CACHE.get(authority);
}

if (this.scope != null) {
// We fetched the scope from the cache, but we have not initialized the scopes in the base yet.
TokenRequestContext tokenRequestContext = new TokenRequestContext().addScopes(this.scope);

return setAuthorizationHeader(context, tokenRequestContext);
}

// The body is removed from the initial request because Key Vault supports other authentication schemes which
// also protect the body of the request. As a result, before we know the auth scheme we need to avoid sending
// an unprotected body to Key Vault. We don't currently support this enhanced auth scheme in the SDK but we
// still don't want to send any unprotected data to vaults which require it.

// Do not overwrite previous contents if retrying after initial request failed (e.g. timeout).
if (!context.getData(KEY_VAULT_STASHED_CONTENT_KEY).isPresent()) {
if (request.getBody() != null) {
context.setData(KEY_VAULT_STASHED_CONTENT_KEY, request.getBody());
context.setData(KEY_VAULT_STASHED_CONTENT_LENGTH_KEY,
request.getHeaders().getValue(CONTENT_LENGTH_HEADER));
request.setHeader(CONTENT_LENGTH_HEADER, "0");
request.setBody((Flux<ByteBuffer>) null);
}
}

return Mono.empty();
});
}

@SuppressWarnings("unchecked")
@Override
public Mono<Boolean> authorizeRequestOnChallenge(HttpPipelineCallContext context, HttpResponse response) {
return Mono.defer(() -> {
HttpRequest request = context.getHttpRequest();
Optional<Object> contentOptional = context.getData(KEY_VAULT_STASHED_CONTENT_KEY);
Optional<Object> contentLengthOptional = context.getData(KEY_VAULT_STASHED_CONTENT_LENGTH_KEY);

if (request.getBody() == null && contentOptional.isPresent() && contentLengthOptional.isPresent()) {
request.setBody((Flux<ByteBuffer>) contentOptional.get());
request.setHeader(CONTENT_LENGTH_HEADER, (String) contentLengthOptional.get());
}

String authority = getRequestAuthority(request);
Map<String, String> challengeAttributes =
extractChallengeAttributes(response.getHeaderValue(WWW_AUTHENTICATE), BEARER_TOKEN_PREFIX);
String scope = challengeAttributes.get("resource");

if (scope != null) {
scope = scope + "/.default";
} else {
scope = challengeAttributes.get("scope");
}

if (scope == null) {
this.scope = SCOPE_CACHE.get(authority);

if (this.scope == null) {
return Mono.just(false);
}
} else {
this.scope = scope;

SCOPE_CACHE.put(authority, this.scope);
}

TokenRequestContext tokenRequestContext = new TokenRequestContext().addScopes(this.scope);

return setAuthorizationHeader(context, tokenRequestContext)
.then(Mono.just(true));
});
}

static void clearCache() {
SCOPE_CACHE.clear();
}

private static String getRequestAuthority(HttpRequest request) {
URL url = request.getUrl();
String authority = url.getAuthority();
int port = url.getPort();

if (!authority.contains(":") && port > 0) {
authority = authority + ":" + port;
}

return authority;
}
}
Loading

0 comments on commit 75ff342

Please sign in to comment.