diff --git a/polaris-service/src/main/java/org/apache/polaris/service/auth/JWTBroker.java b/polaris-service/src/main/java/org/apache/polaris/service/auth/JWTBroker.java index 5caeac21d..19d0664eb 100644 --- a/polaris-service/src/main/java/org/apache/polaris/service/auth/JWTBroker.java +++ b/polaris-service/src/main/java/org/apache/polaris/service/auth/JWTBroker.java @@ -20,6 +20,7 @@ import com.auth0.jwt.JWT; import com.auth0.jwt.algorithms.Algorithm; +import com.auth0.jwt.exceptions.JWTVerificationException; import com.auth0.jwt.interfaces.DecodedJWT; import com.auth0.jwt.interfaces.JWTVerifier; import java.time.Instant; @@ -34,9 +35,12 @@ import org.apache.polaris.core.persistence.PolarisEntityManager; import org.apache.polaris.core.persistence.PolarisMetaStoreManager; import org.apache.polaris.service.types.TokenType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** Generates a JWT Token. */ abstract class JWTBroker implements TokenBroker { + private static final Logger LOGGER = LoggerFactory.getLogger(JWTBroker.class); private static final String ISSUER_KEY = "polaris"; private static final String CLAIM_KEY_ACTIVE = "active"; @@ -56,36 +60,36 @@ abstract class JWTBroker implements TokenBroker { @Override public DecodedToken verify(String token) { - JWTVerifier verifier = JWT.require(getAlgorithm()).build(); - DecodedJWT decodedJWT = verifier.verify(token); - Boolean isActive = decodedJWT.getClaim(CLAIM_KEY_ACTIVE).asBoolean(); - if (isActive == null || !isActive) { - throw new NotAuthorizedException("Token is not active"); + JWTVerifier verifier = JWT.require(getAlgorithm()).withClaim(CLAIM_KEY_ACTIVE, true).build(); + + try { + DecodedJWT decodedJWT = verifier.verify(token); + return new DecodedToken() { + @Override + public Long getPrincipalId() { + return decodedJWT.getClaim("principalId").asLong(); + } + + @Override + public String getClientId() { + return decodedJWT.getClaim("client_id").asString(); + } + + @Override + public String getSub() { + return decodedJWT.getSubject(); + } + + @Override + public String getScope() { + return decodedJWT.getClaim("scope").asString(); + } + }; + + } catch (JWTVerificationException e) { + LOGGER.error("Failed to verify the token with error", e); + throw new NotAuthorizedException("Failed to verify the token"); } - if (decodedJWT.getExpiresAtAsInstant().isBefore(Instant.now())) { - throw new NotAuthorizedException("Token has expired"); - } - return new DecodedToken() { - @Override - public Long getPrincipalId() { - return decodedJWT.getClaim("principalId").asLong(); - } - - @Override - public String getClientId() { - return decodedJWT.getClaim("client_id").asString(); - } - - @Override - public String getSub() { - return decodedJWT.getSubject(); - } - - @Override - public String getScope() { - return decodedJWT.getClaim("scope").asString(); - } - }; } @Override diff --git a/polaris-service/src/test/java/org/apache/polaris/service/admin/PolarisServiceImplIntegrationTest.java b/polaris-service/src/test/java/org/apache/polaris/service/admin/PolarisServiceImplIntegrationTest.java index 2eb5a4525..cbb6ed887 100644 --- a/polaris-service/src/test/java/org/apache/polaris/service/admin/PolarisServiceImplIntegrationTest.java +++ b/polaris-service/src/test/java/org/apache/polaris/service/admin/PolarisServiceImplIntegrationTest.java @@ -19,13 +19,20 @@ package org.apache.polaris.service.admin; import static io.dropwizard.jackson.Jackson.newObjectMapper; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.apache.polaris.service.context.DefaultContextResolver.REALM_PROPERTY_KEY; import static org.assertj.core.api.Assertions.assertThat; +import com.auth0.jwt.JWT; +import com.auth0.jwt.JWTCreator; +import com.auth0.jwt.algorithms.Algorithm; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; import io.dropwizard.testing.ConfigOverride; import io.dropwizard.testing.ResourceHelpers; import io.dropwizard.testing.junit5.DropwizardAppExtension; @@ -34,9 +41,14 @@ import jakarta.ws.rs.client.Invocation; import jakarta.ws.rs.core.Response; import java.io.IOException; +import java.time.Duration; +import java.time.Instant; +import java.time.temporal.ChronoUnit; import java.util.Arrays; +import java.util.Base64; import java.util.List; import java.util.Map; +import java.util.UUID; import org.apache.commons.lang3.RandomStringUtils; import org.apache.iceberg.catalog.Namespace; import org.apache.iceberg.rest.RESTUtil; @@ -79,6 +91,7 @@ import org.apache.polaris.core.admin.model.UpdatePrincipalRoleRequest; import org.apache.polaris.core.entity.PolarisEntityConstants; import org.apache.polaris.service.PolarisApplication; +import org.apache.polaris.service.auth.BasePolarisAuthenticator; import org.apache.polaris.service.auth.TokenUtils; import org.apache.polaris.service.config.PolarisApplicationConfig; import org.apache.polaris.service.test.PolarisConnectionExtension; @@ -89,10 +102,16 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.slf4j.LoggerFactory; +import org.testcontainers.shaded.org.awaitility.Awaitility; @ExtendWith({DropwizardExtensionsSupport.class, PolarisConnectionExtension.class}) public class PolarisServiceImplIntegrationTest { private static final int MAX_IDENTIFIER_LENGTH = 256; + private static final String ISSUER_KEY = "polaris"; + private static final String CLAIM_KEY_ACTIVE = "active"; + private static final String CLAIM_KEY_CLIENT_ID = "client_id"; + private static final String CLAIM_KEY_PRINCIPAL_ID = "principalId"; + private static final String CLAIM_KEY_SCOPE = "scope"; // TODO: Add a test-only hook that fully clobbers all persistence state so we can have a fresh // slate on every test case; otherwise, leftover state from one test from failures will interfere @@ -113,6 +132,7 @@ public class PolarisServiceImplIntegrationTest { ConfigOverride.config("gcp_credentials.expires_in", "12345")); private static String userToken; private static String realm; + private static String clientId; @BeforeAll public static void setup( @@ -121,6 +141,12 @@ public static void setup( userToken = adminToken.token(); realm = polarisRealm; + Base64.Decoder decoder = Base64.getUrlDecoder(); + String[] chunks = adminToken.token().split("\\."); + String payload = new String(decoder.decode(chunks[1]), UTF_8); + JsonElement jsonElement = JsonParser.parseString(payload); + clientId = String.valueOf(((JsonObject) jsonElement).get("client_id")); + // Set up test location PolarisConnectionExtension.createTestDir(realm); } @@ -2225,6 +2251,80 @@ public void testTableManageAccessCanGrantAndRevokeFromCatalogRoles() { Response.Status.FORBIDDEN); } + @Test + public void testTokenExpiry() { + // TokenExpiredException - if the token has expired. + String newToken = + defaultJwt() + .withExpiresAt(Instant.now().plus(1, ChronoUnit.SECONDS)) + .sign(Algorithm.HMAC256("polaris")); + Awaitility.await("expected list of records should be produced") + .atMost(Duration.ofSeconds(2)) + .pollDelay(Duration.ofSeconds(1)) + .pollInterval(Duration.ofSeconds(1)) + .untilAsserted( + () -> { + try (Response response = + newRequest( + "http://localhost:%d/api/management/v1/principals", "Bearer " + newToken) + .get()) { + assertThat(response) + .returns(Response.Status.UNAUTHORIZED.getStatusCode(), Response::getStatus); + } + }); + } + + @Test + public void testTokenInactive() { + // InvalidClaimException - if a claim contained a different value than the expected one. + String newToken = + defaultJwt().withClaim(CLAIM_KEY_ACTIVE, false).sign(Algorithm.HMAC256("polaris")); + try (Response response = + newRequest("http://localhost:%d/api/management/v1/principals", "Bearer " + newToken) + .get()) { + assertThat(response) + .returns(Response.Status.UNAUTHORIZED.getStatusCode(), Response::getStatus); + } + } + + @Test + public void testTokenInvalidSignature() { + // SignatureVerificationException - if the signature is invalid. + String newToken = defaultJwt().sign(Algorithm.HMAC256("invalid_secret")); + try (Response response = + newRequest("http://localhost:%d/api/management/v1/principals", "Bearer " + newToken) + .get()) { + assertThat(response) + .returns(Response.Status.UNAUTHORIZED.getStatusCode(), Response::getStatus); + } + } + + @Test + public void testTokenInvalidPrincipalId() { + String newToken = + defaultJwt().withClaim(CLAIM_KEY_PRINCIPAL_ID, 0).sign(Algorithm.HMAC256("polaris")); + try (Response response = + newRequest("http://localhost:%d/api/management/v1/principals", "Bearer " + newToken) + .get()) { + assertThat(response) + .returns(Response.Status.UNAUTHORIZED.getStatusCode(), Response::getStatus); + } + } + + public static JWTCreator.Builder defaultJwt() { + Instant now = Instant.now(); + return JWT.create() + .withIssuer(ISSUER_KEY) + .withSubject(String.valueOf(1)) + .withIssuedAt(now) + .withExpiresAt(now.plus(10, ChronoUnit.SECONDS)) + .withJWTId(UUID.randomUUID().toString()) + .withClaim(CLAIM_KEY_ACTIVE, true) + .withClaim(CLAIM_KEY_CLIENT_ID, clientId) + .withClaim(CLAIM_KEY_PRINCIPAL_ID, 1) + .withClaim(CLAIM_KEY_SCOPE, BasePolarisAuthenticator.PRINCIPAL_ROLE_ALL); + } + private static void createNamespace(String catalogName, String namespaceName) { try (Response response = newRequest("http://localhost:%d/api/catalog/v1/" + catalogName + "/namespaces", userToken)