Skip to content

Commit 457f4db

Browse files
committed
Polish Nimbus JWK Source Implementation
Issue gh-16251
1 parent 2349247 commit 457f4db

File tree

2 files changed

+50
-117
lines changed

2 files changed

+50
-117
lines changed

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java

+49-115
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,14 @@
1616

1717
package org.springframework.security.oauth2.jwt;
1818

19-
import com.nimbusds.jose.KeySourceException;
20-
import com.nimbusds.jose.jwk.JWK;
21-
import com.nimbusds.jose.jwk.JWKMatcher;
22-
import com.nimbusds.jose.jwk.JWKSelector;
23-
import com.nimbusds.jose.jwk.source.JWKSetParseException;
24-
import com.nimbusds.jose.jwk.source.JWKSetRetrievalException;
25-
import java.io.IOException;
26-
import java.net.MalformedURLException;
27-
import java.net.URL;
19+
import java.net.URI;
2820
import java.security.interfaces.RSAPublicKey;
2921
import java.text.ParseException;
3022
import java.util.Arrays;
3123
import java.util.Collection;
3224
import java.util.Collections;
3325
import java.util.HashSet;
3426
import java.util.LinkedHashMap;
35-
import java.util.List;
3627
import java.util.Map;
3728
import java.util.Set;
3829
import java.util.concurrent.locks.ReentrantLock;
@@ -43,8 +34,13 @@
4334

4435
import com.nimbusds.jose.JOSEException;
4536
import com.nimbusds.jose.JWSAlgorithm;
37+
import com.nimbusds.jose.KeySourceException;
38+
import com.nimbusds.jose.RemoteKeySourceException;
4639
import com.nimbusds.jose.jwk.JWKSet;
40+
import com.nimbusds.jose.jwk.source.JWKSetCacheRefreshEvaluator;
41+
import com.nimbusds.jose.jwk.source.JWKSetSource;
4742
import com.nimbusds.jose.jwk.source.JWKSource;
43+
import com.nimbusds.jose.jwk.source.JWKSourceBuilder;
4844
import com.nimbusds.jose.proc.JWSKeySelector;
4945
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
5046
import com.nimbusds.jose.proc.SecurityContext;
@@ -170,7 +166,7 @@ private Jwt createJwt(String token, JWT parsedJwt) {
170166
.build();
171167
// @formatter:on
172168
}
173-
catch (KeySourceException ex) {
169+
catch (RemoteKeySourceException ex) {
174170
this.logger.trace("Failed to retrieve JWK set", ex);
175171
if (ex.getCause() instanceof ParseException) {
176172
throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"), ex);
@@ -383,7 +379,11 @@ JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSou
383379

384380
JWKSource<SecurityContext> jwkSource() {
385381
String jwkSetUri = this.jwkSetUri.apply(this.restOperations);
386-
return new SpringJWKSource<>(this.restOperations, this.cache, toURL(jwkSetUri), jwkSetUri);
382+
return JWKSourceBuilder.create(new SpringJWKSource<>(this.restOperations, this.cache, jwkSetUri))
383+
.refreshAheadCache(false)
384+
.rateLimited(false)
385+
.cache(this.cache instanceof NoOpCache)
386+
.build();
387387
}
388388

389389
JWTProcessor<SecurityContext> processor() {
@@ -405,16 +405,7 @@ public NimbusJwtDecoder build() {
405405
return new NimbusJwtDecoder(processor());
406406
}
407407

408-
private static URL toURL(String url) {
409-
try {
410-
return new URL(url);
411-
}
412-
catch (MalformedURLException ex) {
413-
throw new IllegalArgumentException("Invalid JWK Set URL \"" + url + "\" : " + ex.getMessage(), ex);
414-
}
415-
}
416-
417-
private static final class SpringJWKSource<C extends SecurityContext> implements JWKSource<C> {
408+
private static final class SpringJWKSource<C extends SecurityContext> implements JWKSetSource<C> {
418409

419410
private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");
420411

@@ -424,120 +415,63 @@ private static final class SpringJWKSource<C extends SecurityContext> implements
424415

425416
private final Cache cache;
426417

427-
private final URL url;
428-
429418
private final String jwkSetUri;
430419

431-
private SpringJWKSource(RestOperations restOperations, Cache cache, URL url, String jwkSetUri) {
420+
private JWKSet jwkSet;
421+
422+
private SpringJWKSource(RestOperations restOperations, Cache cache, String jwkSetUri) {
432423
Assert.notNull(restOperations, "restOperations cannot be null");
433424
this.restOperations = restOperations;
434425
this.cache = cache;
435-
this.url = url;
436426
this.jwkSetUri = jwkSetUri;
437-
}
438-
439-
440-
@Override
441-
public List<JWK> get(JWKSelector jwkSelector, SecurityContext context) throws KeySourceException {
442-
String cachedJwkSet = this.cache.get(this.jwkSetUri, String.class);
443-
JWKSet jwkSet = null;
444-
if (cachedJwkSet != null) {
445-
jwkSet = parse(cachedJwkSet);
446-
}
447-
if (jwkSet == null) {
448-
if(reentrantLock.tryLock()) {
449-
try {
450-
String cachedJwkSetAfterLock = this.cache.get(this.jwkSetUri, String.class);
451-
if (cachedJwkSetAfterLock != null) {
452-
jwkSet = parse(cachedJwkSetAfterLock);
453-
}
454-
if(jwkSet == null) {
455-
try {
456-
jwkSet = fetchJWKSet();
457-
} catch (IOException e) {
458-
throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e);
459-
}
460-
}
461-
} finally {
462-
reentrantLock.unlock();
463-
}
464-
}
465-
}
466-
List<JWK> matches = jwkSelector.select(jwkSet);
467-
if(!matches.isEmpty()) {
468-
return matches;
469-
}
470-
String soughtKeyID = getFirstSpecifiedKeyID(jwkSelector.getMatcher());
471-
if (soughtKeyID == null) {
472-
return Collections.emptyList();
473-
}
474-
if (jwkSet.getKeyByKeyId(soughtKeyID) != null) {
475-
return Collections.emptyList();
476-
}
477-
478-
if(reentrantLock.tryLock()) {
427+
String jwks = this.cache.get(this.jwkSetUri, String.class);
428+
if (jwks != null) {
479429
try {
480-
String jwkSetUri = this.cache.get(this.jwkSetUri, String.class);
481-
JWKSet cacheJwkSet = parse(jwkSetUri);
482-
if(jwkSetUri != null && cacheJwkSet.toString().equals(jwkSet.toString())) {
483-
try {
484-
jwkSet = fetchJWKSet();
485-
} catch (IOException e) {
486-
throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e);
487-
}
488-
} else if (jwkSetUri != null) {
489-
jwkSet = parse(jwkSetUri);
490-
}
491-
} finally {
492-
reentrantLock.unlock();
430+
this.jwkSet = JWKSet.parse(jwks);
431+
}
432+
catch (ParseException ignored) {
433+
// Ignore invalid cache value
493434
}
494435
}
495-
if(jwkSet == null) {
496-
return Collections.emptyList();
497-
}
498-
return jwkSelector.select(jwkSet);
499436
}
500437

501-
private JWKSet fetchJWKSet() throws IOException, KeySourceException {
438+
private String fetchJwks() throws Exception {
502439
HttpHeaders headers = new HttpHeaders();
503440
headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON));
504-
ResponseEntity<String> response = getResponse(headers);
505-
if (response.getStatusCode().value() != 200) {
506-
throw new IOException(response.toString());
507-
}
508-
try {
509-
String jwkSet = response.getBody();
510-
this.cache.put(this.jwkSetUri, jwkSet);
511-
return JWKSet.parse(jwkSet);
512-
} catch (ParseException e) {
513-
throw new JWKSetParseException("Unable to parse JWK set", e);
514-
}
441+
RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, URI.create(this.jwkSetUri));
442+
ResponseEntity<String> response = this.restOperations.exchange(request, String.class);
443+
String jwks = response.getBody();
444+
this.jwkSet = JWKSet.parse(jwks);
445+
return jwks;
515446
}
516447

517-
private ResponseEntity<String> getResponse(HttpHeaders headers) throws IOException {
448+
@Override
449+
public JWKSet getJWKSet(JWKSetCacheRefreshEvaluator refreshEvaluator, long currentTime, C context)
450+
throws KeySourceException {
518451
try {
519-
RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, this.url.toURI());
520-
return this.restOperations.exchange(request, String.class);
521-
} catch (Exception ex) {
522-
throw new IOException(ex);
452+
this.reentrantLock.lock();
453+
if (refreshEvaluator.requiresRefresh(this.jwkSet)) {
454+
this.cache.invalidate();
455+
}
456+
this.cache.get(this.jwkSetUri, this::fetchJwks);
457+
return this.jwkSet;
523458
}
524-
}
525-
526-
private JWKSet parse(String cachedJwkSet) {
527-
JWKSet jwkSet = null;
528-
try {
529-
jwkSet = JWKSet.parse(cachedJwkSet);
530-
} catch (ParseException ignored) {
531-
// Ignore invalid cache value
459+
catch (Cache.ValueRetrievalException ex) {
460+
if (ex.getCause() instanceof RemoteKeySourceException keys) {
461+
throw keys;
462+
}
463+
throw new RemoteKeySourceException(ex.getCause().getMessage(), ex.getCause());
464+
}
465+
finally {
466+
this.reentrantLock.unlock();
532467
}
533-
return jwkSet;
534468
}
535469

536-
private String getFirstSpecifiedKeyID(JWKMatcher jwkMatcher) {
537-
Set<String> keyIDs = jwkMatcher.getKeyIDs();
538-
return (keyIDs == null || keyIDs.isEmpty()) ?
539-
null : keyIDs.stream().filter(id -> id != null).findFirst().orElse(null);
470+
@Override
471+
public void close() {
472+
540473
}
474+
541475
}
542476

543477
}

oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2025 the original author or authors.
2+
* Copyright 2002-2019 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -308,7 +308,6 @@ private void prepareConfigurationResponse() {
308308
private void prepareConfigurationResponse(String body) {
309309
this.server.enqueue(response(body));
310310
this.server.enqueue(response(JWK_SET));
311-
this.server.enqueue(response(JWK_SET)); // default NoOpCache
312311
}
313312

314313
private void prepareConfigurationResponseOidc() {

0 commit comments

Comments
 (0)