Skip to content

Handle verification of JWT tokens gracefully. #350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import com.auth0.jwt.JWT;
import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.exceptions.JWTVerificationException;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.auth0.jwt.interfaces.JWTVerifier;
import java.time.Instant;
Expand All @@ -34,9 +35,12 @@
import org.apache.polaris.core.persistence.PolarisEntityManager;
import org.apache.polaris.core.persistence.PolarisMetaStoreManager;
import org.apache.polaris.service.types.TokenType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Generates a JWT Token. */
abstract class JWTBroker implements TokenBroker {
private static final Logger LOGGER = LoggerFactory.getLogger(JWTBroker.class);

private static final String ISSUER_KEY = "polaris";
private static final String CLAIM_KEY_ACTIVE = "active";
Expand All @@ -56,36 +60,36 @@ abstract class JWTBroker implements TokenBroker {

@Override
public DecodedToken verify(String token) {
JWTVerifier verifier = JWT.require(getAlgorithm()).build();
DecodedJWT decodedJWT = verifier.verify(token);
Boolean isActive = decodedJWT.getClaim(CLAIM_KEY_ACTIVE).asBoolean();
if (isActive == null || !isActive) {
throw new NotAuthorizedException("Token is not active");
JWTVerifier verifier = JWT.require(getAlgorithm()).withClaim(CLAIM_KEY_ACTIVE, true).build();

try {
DecodedJWT decodedJWT = verifier.verify(token);
return new DecodedToken() {
@Override
public Long getPrincipalId() {
return decodedJWT.getClaim("principalId").asLong();
}

@Override
public String getClientId() {
return decodedJWT.getClaim("client_id").asString();
}

@Override
public String getSub() {
return decodedJWT.getSubject();
}

@Override
public String getScope() {
return decodedJWT.getClaim("scope").asString();
}
};

} catch (JWTVerificationException e) {
LOGGER.error("Failed to verify the token with error", e);
throw new NotAuthorizedException("Failed to verify the token");
}
if (decodedJWT.getExpiresAtAsInstant().isBefore(Instant.now())) {
throw new NotAuthorizedException("Token has expired");
}
return new DecodedToken() {
@Override
public Long getPrincipalId() {
return decodedJWT.getClaim("principalId").asLong();
}

@Override
public String getClientId() {
return decodedJWT.getClaim("client_id").asString();
}

@Override
public String getSub() {
return decodedJWT.getSubject();
}

@Override
public String getScope() {
return decodedJWT.getClaim("scope").asString();
}
};
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,20 @@
package org.apache.polaris.service.admin;

import static io.dropwizard.jackson.Jackson.newObjectMapper;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.apache.polaris.service.context.DefaultContextResolver.REALM_PROPERTY_KEY;
import static org.assertj.core.api.Assertions.assertThat;

import com.auth0.jwt.JWT;
import com.auth0.jwt.JWTCreator;
import com.auth0.jwt.algorithms.Algorithm;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import io.dropwizard.testing.ConfigOverride;
import io.dropwizard.testing.ResourceHelpers;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
Expand All @@ -34,9 +41,14 @@
import jakarta.ws.rs.client.Invocation;
import jakarta.ws.rs.core.Response;
import java.io.IOException;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.apache.commons.lang3.RandomStringUtils;
import org.apache.iceberg.catalog.Namespace;
import org.apache.iceberg.rest.RESTUtil;
Expand Down Expand Up @@ -79,6 +91,7 @@
import org.apache.polaris.core.admin.model.UpdatePrincipalRoleRequest;
import org.apache.polaris.core.entity.PolarisEntityConstants;
import org.apache.polaris.service.PolarisApplication;
import org.apache.polaris.service.auth.BasePolarisAuthenticator;
import org.apache.polaris.service.auth.TokenUtils;
import org.apache.polaris.service.config.PolarisApplicationConfig;
import org.apache.polaris.service.test.PolarisConnectionExtension;
Expand All @@ -89,10 +102,16 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.slf4j.LoggerFactory;
import org.testcontainers.shaded.org.awaitility.Awaitility;

@ExtendWith({DropwizardExtensionsSupport.class, PolarisConnectionExtension.class})
public class PolarisServiceImplIntegrationTest {
private static final int MAX_IDENTIFIER_LENGTH = 256;
private static final String ISSUER_KEY = "polaris";
private static final String CLAIM_KEY_ACTIVE = "active";
private static final String CLAIM_KEY_CLIENT_ID = "client_id";
private static final String CLAIM_KEY_PRINCIPAL_ID = "principalId";
private static final String CLAIM_KEY_SCOPE = "scope";

// TODO: Add a test-only hook that fully clobbers all persistence state so we can have a fresh
// slate on every test case; otherwise, leftover state from one test from failures will interfere
Expand All @@ -113,6 +132,7 @@ public class PolarisServiceImplIntegrationTest {
ConfigOverride.config("gcp_credentials.expires_in", "12345"));
private static String userToken;
private static String realm;
private static String clientId;

@BeforeAll
public static void setup(
Expand All @@ -121,6 +141,12 @@ public static void setup(
userToken = adminToken.token();
realm = polarisRealm;

Base64.Decoder decoder = Base64.getUrlDecoder();
String[] chunks = adminToken.token().split("\\.");
String payload = new String(decoder.decode(chunks[1]), UTF_8);
JsonElement jsonElement = JsonParser.parseString(payload);
clientId = String.valueOf(((JsonObject) jsonElement).get("client_id"));

// Set up test location
PolarisConnectionExtension.createTestDir(realm);
}
Expand Down Expand Up @@ -2225,6 +2251,80 @@ public void testTableManageAccessCanGrantAndRevokeFromCatalogRoles() {
Response.Status.FORBIDDEN);
}

@Test
public void testTokenExpiry() {
// TokenExpiredException - if the token has expired.
String newToken =
defaultJwt()
.withExpiresAt(Instant.now().plus(1, ChronoUnit.SECONDS))
.sign(Algorithm.HMAC256("polaris"));
Awaitility.await("expected list of records should be produced")
.atMost(Duration.ofSeconds(2))
.pollDelay(Duration.ofSeconds(1))
.pollInterval(Duration.ofSeconds(1))
.untilAsserted(
() -> {
try (Response response =
newRequest(
"http://localhost:%d/api/management/v1/principals", "Bearer " + newToken)
.get()) {
assertThat(response)
.returns(Response.Status.UNAUTHORIZED.getStatusCode(), Response::getStatus);
}
});
}

@Test
public void testTokenInactive() {
// InvalidClaimException - if a claim contained a different value than the expected one.
String newToken =
defaultJwt().withClaim(CLAIM_KEY_ACTIVE, false).sign(Algorithm.HMAC256("polaris"));
try (Response response =
newRequest("http://localhost:%d/api/management/v1/principals", "Bearer " + newToken)
.get()) {
assertThat(response)
.returns(Response.Status.UNAUTHORIZED.getStatusCode(), Response::getStatus);
}
}

@Test
public void testTokenInvalidSignature() {
// SignatureVerificationException - if the signature is invalid.
String newToken = defaultJwt().sign(Algorithm.HMAC256("invalid_secret"));
try (Response response =
newRequest("http://localhost:%d/api/management/v1/principals", "Bearer " + newToken)
.get()) {
assertThat(response)
.returns(Response.Status.UNAUTHORIZED.getStatusCode(), Response::getStatus);
}
}

@Test
public void testTokenInvalidPrincipalId() {
String newToken =
defaultJwt().withClaim(CLAIM_KEY_PRINCIPAL_ID, 0).sign(Algorithm.HMAC256("polaris"));
try (Response response =
newRequest("http://localhost:%d/api/management/v1/principals", "Bearer " + newToken)
.get()) {
assertThat(response)
.returns(Response.Status.UNAUTHORIZED.getStatusCode(), Response::getStatus);
}
}

public static JWTCreator.Builder defaultJwt() {
Instant now = Instant.now();
return JWT.create()
.withIssuer(ISSUER_KEY)
.withSubject(String.valueOf(1))
.withIssuedAt(now)
.withExpiresAt(now.plus(10, ChronoUnit.SECONDS))
.withJWTId(UUID.randomUUID().toString())
.withClaim(CLAIM_KEY_ACTIVE, true)
.withClaim(CLAIM_KEY_CLIENT_ID, clientId)
.withClaim(CLAIM_KEY_PRINCIPAL_ID, 1)
.withClaim(CLAIM_KEY_SCOPE, BasePolarisAuthenticator.PRINCIPAL_ROLE_ALL);
}

private static void createNamespace(String catalogName, String namespaceName) {
try (Response response =
newRequest("http://localhost:%d/api/catalog/v1/" + catalogName + "/namespaces", userToken)
Expand Down