Skip to content

Enhanced jwks-provider tests #88

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.scalecube.security.tokens.jwt;

import static io.scalecube.security.tokens.jwt.Utils.toRsaPublicKey;

import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.PropertyAccessor;
Expand All @@ -24,9 +26,9 @@ public final class JwksKeyProvider implements KeyProvider {

private static final Logger LOGGER = LoggerFactory.getLogger(JwksKeyProvider.class);

private final Scheduler scheduler = Schedulers.newSingle("jwks-key-provider", true);
private static final ObjectMapper OBJECT_MAPPER = newObjectMapper();

private final ObjectMapper mapper;
private final Scheduler scheduler;
private final String jwksUri;
private final long connectTimeoutMillis;
private final long readTimeoutMillis;
Expand All @@ -37,22 +39,21 @@ public final class JwksKeyProvider implements KeyProvider {
* @param jwksUri jwksUri
*/
public JwksKeyProvider(String jwksUri) {
this.jwksUri = jwksUri;
this.mapper = initMapper();
this.connectTimeoutMillis = Duration.ofSeconds(10).toMillis();
this.readTimeoutMillis = Duration.ofSeconds(10).toMillis();
this(jwksUri, newScheduler(), Duration.ofSeconds(10), Duration.ofSeconds(10));
}

/**
* Constructor.
*
* @param jwksUri jwksUri
* @param scheduler scheduler
* @param connectTimeout connectTimeout
* @param readTimeout readTimeout
*/
public JwksKeyProvider(String jwksUri, Duration connectTimeout, Duration readTimeout) {
public JwksKeyProvider(
String jwksUri, Scheduler scheduler, Duration connectTimeout, Duration readTimeout) {
this.jwksUri = jwksUri;
this.mapper = initMapper();
this.scheduler = scheduler;
this.connectTimeoutMillis = connectTimeout.toMillis();
this.readTimeoutMillis = readTimeout.toMillis();
}
Expand Down Expand Up @@ -87,7 +88,7 @@ private Mono<InputStream> callJwksUri() {

private JwkInfoList toKeyList(InputStream stream) {
try (InputStream inputStream = new BufferedInputStream(stream)) {
return mapper.readValue(inputStream, JwkInfoList.class);
return OBJECT_MAPPER.readValue(inputStream, JwkInfoList.class);
} catch (IOException e) {
LOGGER.error("[toKeyList] Exception occurred: {}", e.toString());
throw new KeyProviderException(e);
Expand All @@ -98,10 +99,10 @@ private Optional<Key> findRsaKey(JwkInfoList list, String kid) {
return list.keys().stream()
.filter(k -> kid.equals(k.kid()))
.findFirst()
.map(info -> Utils.getRsaPublicKey(info.modulus(), info.exponent()));
.map(info -> toRsaPublicKey(info.modulus(), info.exponent()));
}

private static ObjectMapper initMapper() {
private static ObjectMapper newObjectMapper() {
ObjectMapper mapper = new ObjectMapper();
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
Expand All @@ -111,4 +112,8 @@ private static ObjectMapper initMapper() {
mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
return mapper;
}

private static Scheduler newScheduler() {
return Schedulers.newElastic("jwks-key-provider", 60, true);
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package io.scalecube.security.tokens.jwt;

import io.scalecube.security.tokens.jwt.jsonwebtoken.JsonwebtokenParserFactory;
import java.security.Key;
import java.time.Duration;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -18,37 +20,36 @@ public final class JwtTokenResolverImpl implements JwtTokenResolver {

private final KeyProvider keyProvider;
private final JwtTokenParserFactory tokenParserFactory;
private final int cleanupIntervalSec;
private final Scheduler scheduler;
private final Duration cleanupInterval;

private final Map<String, Mono<Key>> keyResolutions = new ConcurrentHashMap<>();

/**
* Constructor.
*
* @param keyProvider key provider
* @param tokenParserFactory token parser factoty
*/
public JwtTokenResolverImpl(KeyProvider keyProvider, JwtTokenParserFactory tokenParserFactory) {
this(keyProvider, tokenParserFactory, 3600, Schedulers.newSingle("caching-key-provider", true));
public JwtTokenResolverImpl(KeyProvider keyProvider) {
this(keyProvider, new JsonwebtokenParserFactory(), newScheduler(), Duration.ofSeconds(60));
}

/**
* Constructor.
*
* @param keyProvider key provider
* @param tokenParserFactory token parser factoty
* @param cleanupIntervalSec cleanup interval (in sec) for resolved cached keys
* @param scheduler cleanup scheduler
* @param cleanupInterval cleanup interval for resolved cached keys
*/
public JwtTokenResolverImpl(
KeyProvider keyProvider,
JwtTokenParserFactory tokenParserFactory,
int cleanupIntervalSec,
Scheduler scheduler) {
Scheduler scheduler,
Duration cleanupInterval) {
this.keyProvider = keyProvider;
this.tokenParserFactory = tokenParserFactory;
this.cleanupIntervalSec = cleanupIntervalSec;
this.cleanupInterval = cleanupInterval;
this.scheduler = scheduler;
}

Expand Down Expand Up @@ -107,12 +108,16 @@ private Mono<Key> findKey(String kid, AtomicReference<Mono<Key>> computedValueHo

private void scheduleCleanup(String kid, AtomicReference<Mono<Key>> computedValueHolder) {
scheduler.schedule(
() -> cleanup(kid, computedValueHolder), cleanupIntervalSec, TimeUnit.SECONDS);
() -> cleanup(kid, computedValueHolder), cleanupInterval.toMillis(), TimeUnit.MILLISECONDS);
}

private void cleanup(String kid, AtomicReference<Mono<Key>> computedValueHolder) {
if (computedValueHolder.get() != null) {
keyResolutions.remove(kid, computedValueHolder.get());
}
}

private static Scheduler newScheduler() {
return Schedulers.newElastic("token-resolver-cleaner", 60, true);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ private Utils() {
* @param e exponent (b64 url encoded)
* @return RSA public key instance
*/
public static Key getRsaPublicKey(String n, String e) {
public static Key toRsaPublicKey(String n, String e) {
Decoder b64Decoder = Base64.getUrlDecoder();
BigInteger modulus = new BigInteger(1, b64Decoder.decode(n));
BigInteger exponent = new BigInteger(1, b64Decoder.decode(e));
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package io.scalecube.security.tokens.jwt;

import java.io.IOException;
import static io.scalecube.security.tokens.jwt.Utils.toRsaPublicKey;

import java.security.Key;
import java.time.Duration;
import java.util.Collections;
import java.util.Map;
import java.util.Properties;
Expand All @@ -10,13 +12,14 @@
import org.mockito.Mockito;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import reactor.test.scheduler.VirtualTimeScheduler;

class JwtTokenResolverTests extends BaseTest {
class JwtTokenResolverTests {

private static final Map<String, Object> BODY = Collections.singletonMap("aud", "aud");

@Test
void testTokenResolver() throws IOException {
void testTokenResolver() throws Exception {
TokenWithKey tokenWithKey = new TokenWithKey("token-and-pubkey.properties");

JwtTokenParser tokenParser = Mockito.mock(JwtTokenParser.class);
Expand All @@ -32,7 +35,9 @@ void testTokenResolver() throws IOException {
KeyProvider keyProvider = Mockito.mock(KeyProvider.class);
Mockito.when(keyProvider.findKey(tokenWithKey.kid)).thenReturn(Mono.just(tokenWithKey.key));

JwtTokenResolverImpl tokenResolver = new JwtTokenResolverImpl(keyProvider, tokenParserFactory);
JwtTokenResolverImpl tokenResolver =
new JwtTokenResolverImpl(
keyProvider, tokenParserFactory, VirtualTimeScheduler.create(), Duration.ofSeconds(3));

// N times call resolve
StepVerifier.create(tokenResolver.resolve(tokenWithKey.token).repeat(3))
Expand All @@ -45,7 +50,7 @@ void testTokenResolver() throws IOException {
}

@Test
void testTokenResolverWithRotatingKey() throws IOException {
void testTokenResolverWithRotatingKey() throws Exception {
TokenWithKey tokenWithKey = new TokenWithKey("token-and-pubkey.properties");
TokenWithKey tokenWithKeyAfterRotation =
new TokenWithKey("token-and-pubkey.after-rotation.properties");
Expand All @@ -70,7 +75,9 @@ void testTokenResolverWithRotatingKey() throws IOException {
Mockito.when(keyProvider.findKey(tokenWithKeyAfterRotation.kid))
.thenReturn(Mono.just(tokenWithKeyAfterRotation.key));

JwtTokenResolverImpl tokenResolver = new JwtTokenResolverImpl(keyProvider, tokenParserFactory);
JwtTokenResolverImpl tokenResolver =
new JwtTokenResolverImpl(
keyProvider, tokenParserFactory, VirtualTimeScheduler.create(), Duration.ofSeconds(3));

// Call normal token first
StepVerifier.create(tokenResolver.resolve(tokenWithKey.token))
Expand All @@ -90,7 +97,7 @@ void testTokenResolverWithRotatingKey() throws IOException {
}

@Test
void testTokenResolverWithWrongKey() throws IOException {
void testTokenResolverWithWrongKey() throws Exception {
TokenWithKey tokenWithWrongKey = new TokenWithKey("token-and-wrong-pubkey.properties");

JwtTokenParser tokenParser = Mockito.mock(JwtTokenParser.class);
Expand All @@ -106,7 +113,9 @@ void testTokenResolverWithWrongKey() throws IOException {
Mockito.when(keyProvider.findKey(tokenWithWrongKey.kid))
.thenReturn(Mono.just(tokenWithWrongKey.key));

JwtTokenResolverImpl tokenResolver = new JwtTokenResolverImpl(keyProvider, tokenParserFactory);
JwtTokenResolverImpl tokenResolver =
new JwtTokenResolverImpl(
keyProvider, tokenParserFactory, VirtualTimeScheduler.create(), Duration.ofSeconds(3));

// Must fail (retry N times)
StepVerifier.create(tokenResolver.resolve(tokenWithWrongKey.token).retry(1))
Expand All @@ -118,7 +127,7 @@ void testTokenResolverWithWrongKey() throws IOException {
}

@Test
void testTokenResolverWhenKeyProviderFailing() throws IOException {
void testTokenResolverWhenKeyProviderFailing() throws Exception {
TokenWithKey tokenWithKey = new TokenWithKey("token-and-pubkey.properties");

JwtTokenParser tokenParser = Mockito.mock(JwtTokenParser.class);
Expand All @@ -134,7 +143,9 @@ void testTokenResolverWhenKeyProviderFailing() throws IOException {
KeyProvider keyProvider = Mockito.mock(KeyProvider.class);
Mockito.when(keyProvider.findKey(tokenWithKey.kid)).thenThrow(RuntimeException.class);

JwtTokenResolverImpl tokenResolver = new JwtTokenResolverImpl(keyProvider, tokenParserFactory);
JwtTokenResolverImpl tokenResolver =
new JwtTokenResolverImpl(
keyProvider, tokenParserFactory, VirtualTimeScheduler.create(), Duration.ofSeconds(3));

// Must fail with "hola" (retry N times)
StepVerifier.create(tokenResolver.resolve(tokenWithKey.token).retry(1)).expectError().verify();
Expand All @@ -149,13 +160,13 @@ static class TokenWithKey {
final Key key;
final String kid;

TokenWithKey(String s) throws IOException {
TokenWithKey(String s) throws Exception {
ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
Properties props = new Properties();
props.load(classLoader.getResourceAsStream(s));
this.token = props.getProperty("token");
this.kid = props.getProperty("kid");
this.key = Utils.getRsaPublicKey(props.getProperty("n"), props.getProperty("e"));
this.key = toRsaPublicKey(props.getProperty("n"), props.getProperty("e"));
}
}
}
Loading