Skip to content

Commit

Permalink
[fix][broker] Implement authenticateAsync for AuthenticationProviderL…
Browse files Browse the repository at this point in the history
…ist (apache#20132)

PIP: apache#12105 and apache#19771 

### Motivation

With the implementation of asynchronous authentication in PIP 97, I missed a case in the `AuthenticationProviderList` where we need to implement the `authenticateAsync` methods. This PR is necessary for making the `AuthenticationProviderToken` and the `AuthenticationProviderOpenID` work together, which is necessary for anyone transitioning to `AuthenticationProviderOpenID`.

### Modifications

* Implement `AuthenticationListState#authenticateAsync` using a recursive algorithm that first attempts to authenticate the client using the current `authState` and then tries the remaining options.
* Implement `AuthenticationProviderList#authenticateAsync` using a recursive algorithm that attempts each provider sequentially.
* Add test to `AuthenticationProviderListTest` that exercises this method. It didn't technically fail previously, but it's worth adding.
* Add test to `AuthenticationProviderOpenIDIntegrationTest` to cover the exact failures that were causing problems.
  • Loading branch information
michaeljmarshall authored Apr 19, 2023
1 parent 46a65fd commit 58ccf02
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import com.github.tomakehurst.wiremock.WireMockServer;
import io.jsonwebtoken.SignatureAlgorithm;
import io.jsonwebtoken.impl.DefaultJwtBuilder;
import io.jsonwebtoken.io.Decoders;
import io.jsonwebtoken.security.Keys;
import java.io.IOException;
import java.nio.file.Files;
Expand All @@ -41,13 +42,19 @@
import java.util.Base64;
import java.util.Date;
import java.util.HashMap;
import java.util.Optional;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import javax.naming.AuthenticationException;
import lombok.Cleanup;
import org.apache.pulsar.broker.ServiceConfiguration;
import org.apache.pulsar.broker.authentication.AuthenticationDataCommand;
import org.apache.pulsar.broker.authentication.AuthenticationProvider;
import org.apache.pulsar.broker.authentication.AuthenticationProviderToken;
import org.apache.pulsar.broker.authentication.AuthenticationService;
import org.apache.pulsar.broker.authentication.AuthenticationState;
import org.apache.pulsar.broker.authentication.utils.AuthTokenUtils;
import org.apache.pulsar.common.api.AuthData;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
Expand Down Expand Up @@ -438,6 +445,56 @@ public void testAuthenticationStateOpenIDForTokenExpiration() throws Exception {
assertTrue(state.isExpired());
}

/**
* This test covers the migration scenario where you have both the Token and OpenID providers. It ensures
* both kinds of authentication work.
* @throws Exception
*/
@Test
public void testAuthenticationProviderListStateSuccess() throws Exception {
ServiceConfiguration conf = new ServiceConfiguration();
conf.setAuthenticationEnabled(true);
conf.setAuthenticationProviders(Set.of(AuthenticationProviderOpenID.class.getName(),
AuthenticationProviderToken.class.getName()));
Properties props = conf.getProperties();
props.setProperty(AuthenticationProviderOpenID.REQUIRE_HTTPS, "false");
props.setProperty(AuthenticationProviderOpenID.ALLOWED_AUDIENCES, "allowed-audience");
props.setProperty(AuthenticationProviderOpenID.ALLOWED_TOKEN_ISSUERS, issuer);

// Set up static token
KeyPair keyPair = Keys.keyPairFor(SignatureAlgorithm.RS256);
// Use public key for validation
String publicKeyStr = AuthTokenUtils.encodeKeyBase64(keyPair.getPublic());
props.setProperty("tokenPublicKey", publicKeyStr);
// Use private key to generate token
String privateKeyStr = AuthTokenUtils.encodeKeyBase64(keyPair.getPrivate());
PrivateKey privateKey = AuthTokenUtils.decodePrivateKey(Decoders.BASE64.decode(privateKeyStr),
SignatureAlgorithm.RS256);
String staticToken = AuthTokenUtils.createToken(privateKey, "superuser", Optional.empty());

@Cleanup
AuthenticationService service = new AuthenticationService(conf);
AuthenticationProvider provider = service.getAuthenticationProvider("token");

// First, authenticate using OIDC
String role = "superuser";
String oidcToken = generateToken(validJwk, issuer, role, "allowed-audience", 0L, 0L, 10000L);
assertEquals(role, provider.authenticateAsync(new AuthenticationDataCommand(oidcToken)).get());

// Authenticate using the static token
assertEquals("superuser", provider.authenticateAsync(new AuthenticationDataCommand(staticToken)).get());

// Use authenticationState to authentication using OIDC
AuthenticationState state1 = service.getAuthenticationProvider("token").newAuthState(null, null, null);
assertNull(state1.authenticateAsync(AuthData.of(oidcToken.getBytes())).get());
assertEquals(state1.getAuthRole(), role);

// Use authenticationState to authentication using static token
AuthenticationState state2 = service.getAuthenticationProvider("token").newAuthState(null, null, null);
assertNull(state2.authenticateAsync(AuthData.of(staticToken.getBytes())).get());
assertEquals(state1.getAuthRole(), role);
}

@Test
void ensureRoleClaimForNonSubClaimReturnsRole() throws Exception {
AuthenticationProviderOpenID provider = new AuthenticationProviderOpenID();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import javax.naming.AuthenticationException;
import javax.net.ssl.SSLSession;
import javax.servlet.http.HttpServletRequest;
Expand Down Expand Up @@ -76,9 +77,12 @@ static <T, W> T applyAuthProcessor(List<W> processors, AuthProcessor<T, W> authF
private static class AuthenticationListState implements AuthenticationState {

private final List<AuthenticationState> states;
private AuthenticationState authState;
private volatile AuthenticationState authState;

AuthenticationListState(List<AuthenticationState> states) {
if (states == null || states.isEmpty()) {
throw new IllegalArgumentException("Authentication state requires at least one state");
}
this.states = states;
this.authState = states.get(0);
}
Expand All @@ -96,6 +100,61 @@ public String getAuthRole() throws AuthenticationException {
return getAuthState().getAuthRole();
}

@Override
public CompletableFuture<AuthData> authenticateAsync(AuthData authData) {
// First, attempt to authenticate with the current auth state
CompletableFuture<AuthData> authChallengeFuture = new CompletableFuture<>();
authState
.authenticateAsync(authData)
.whenComplete((authChallenge, ex) -> {
if (ex == null) {
// Current authState is still correct. Just need to return the authChallenge.
authChallengeFuture.complete(authChallenge);
} else {
if (log.isDebugEnabled()) {
log.debug("Authentication failed for auth provider " + authState.getClass() + ": ", ex);
}
authenticateRemainingAuthStates(authChallengeFuture, authData, ex, states.size() - 1);
}
});
return authChallengeFuture;
}

private void authenticateRemainingAuthStates(CompletableFuture<AuthData> authChallengeFuture,
AuthData clientAuthData,
Throwable previousException,
int index) {
if (index < 0) {
if (previousException == null) {
previousException = new AuthenticationException("Authentication required");
}
AuthenticationMetrics.authenticateFailure(AuthenticationProviderList.class.getSimpleName(),
"authentication-provider-list", "Authentication required");
authChallengeFuture.completeExceptionally(previousException);
return;
}
AuthenticationState state = states.get(index);
if (state == authState) {
// Skip the current auth state
authenticateRemainingAuthStates(authChallengeFuture, clientAuthData, null, index - 1);
} else {
state.authenticateAsync(clientAuthData)
.whenComplete((authChallenge, ex) -> {
if (ex == null) {
// Found the correct auth state
authState = state;
authChallengeFuture.complete(authChallenge);
} else {
if (log.isDebugEnabled()) {
log.debug("Authentication failed for auth provider "
+ authState.getClass() + ": ", ex);
}
authenticateRemainingAuthStates(authChallengeFuture, clientAuthData, ex, index - 1);
}
});
}
}

@Override
public AuthData authenticate(AuthData authData) throws AuthenticationException {
return applyAuthProcessor(
Expand Down Expand Up @@ -160,6 +219,40 @@ public String getAuthMethodName() {
return providers.get(0).getAuthMethodName();
}

@Override
public CompletableFuture<String> authenticateAsync(AuthenticationDataSource authData) {
CompletableFuture<String> roleFuture = new CompletableFuture<>();
authenticateRemainingAuthProviders(roleFuture, authData, null, providers.size() - 1);
return roleFuture;
}

private void authenticateRemainingAuthProviders(CompletableFuture<String> roleFuture,
AuthenticationDataSource authData,
Throwable previousException,
int index) {
if (index < 0) {
if (previousException == null) {
previousException = new AuthenticationException("Authentication required");
}
AuthenticationMetrics.authenticateFailure(AuthenticationProviderList.class.getSimpleName(),
"authentication-provider-list", "Authentication required");
roleFuture.completeExceptionally(previousException);
return;
}
AuthenticationProvider provider = providers.get(index);
provider.authenticateAsync(authData)
.whenComplete((role, ex) -> {
if (ex == null) {
roleFuture.complete(role);
} else {
if (log.isDebugEnabled()) {
log.debug("Authentication failed for auth provider " + provider.getClass() + ": ", ex);
}
authenticateRemainingAuthProviders(roleFuture, authData, ex, index - 1);
}
});
}

@Override
public String authenticate(AuthenticationDataSource authData) throws AuthenticationException {
return applyAuthProcessor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,30 @@ public void testAuthenticate() throws Exception {
testAuthenticate(tokenBB, SUBJECT_B);
}

private void testAuthenticateAsync(String token, String expectedSubject) throws Exception {
String actualSubject = authProvider.authenticateAsync(new AuthenticationDataSource() {
@Override
public boolean hasDataFromCommand() {
return true;
}

@Override
public String getCommandData() {
return token;
}
}).get();
assertEquals(actualSubject, expectedSubject);
}

@Test
public void testAuthenticateAsync() throws Exception {
testAuthenticateAsync(tokenAA, SUBJECT_A);
testAuthenticateAsync(tokenAB, SUBJECT_B);
testAuthenticateAsync(tokenBA, SUBJECT_A);
testAuthenticateAsync(tokenBB, SUBJECT_B);
}


private AuthenticationState newAuthState(String token, String expectedSubject) throws Exception {
// Must pass the token to the newAuthState for legacy reasons.
AuthenticationState authState = authProvider.newAuthState(
Expand Down

0 comments on commit 58ccf02

Please sign in to comment.