Skip to content

Commit f5a669e

Browse files
committed
Remove Deprecated Usages of RemoteJWKSet
Closes gh-16251 Signed-off-by: Daeho Kwon <trewq231@naver.com>
1 parent ed5cccc commit f5a669e

File tree

3 files changed

+105
-139
lines changed

3 files changed

+105
-139
lines changed

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java

Lines changed: 102 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2024 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -20,8 +20,8 @@
2020
import com.nimbusds.jose.jwk.JWK;
2121
import com.nimbusds.jose.jwk.JWKMatcher;
2222
import com.nimbusds.jose.jwk.JWKSelector;
23-
import com.nimbusds.jose.jwk.source.JWKSetCacheRefreshEvaluator;
24-
import com.nimbusds.jose.jwk.source.URLBasedJWKSetSource;
23+
import com.nimbusds.jose.jwk.source.JWKSetParseException;
24+
import com.nimbusds.jose.jwk.source.JWKSetRetrievalException;
2525
import java.io.IOException;
2626
import java.net.MalformedURLException;
2727
import java.net.URL;
@@ -35,6 +35,7 @@
3535
import java.util.List;
3636
import java.util.Map;
3737
import java.util.Set;
38+
import java.util.concurrent.locks.ReentrantLock;
3839
import java.util.function.Consumer;
3940
import java.util.function.Function;
4041

@@ -48,8 +49,6 @@
4849
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
4950
import com.nimbusds.jose.proc.SecurityContext;
5051
import com.nimbusds.jose.proc.SingleKeyJWSKeySelector;
51-
import com.nimbusds.jose.util.Resource;
52-
import com.nimbusds.jose.util.ResourceRetriever;
5352
import com.nimbusds.jwt.JWT;
5453
import com.nimbusds.jwt.JWTClaimsSet;
5554
import com.nimbusds.jwt.JWTParser;
@@ -61,6 +60,7 @@
6160
import org.apache.commons.logging.LogFactory;
6261

6362
import org.springframework.cache.Cache;
63+
import org.springframework.cache.support.NoOpCache;
6464
import org.springframework.core.convert.converter.Converter;
6565
import org.springframework.http.HttpHeaders;
6666
import org.springframework.http.HttpMethod;
@@ -278,7 +278,7 @@ public static final class JwkSetUriJwtDecoderBuilder {
278278

279279
private RestOperations restOperations = new RestTemplate();
280280

281-
private Cache cache;
281+
private Cache cache = new NoOpCache("default");
282282

283283
private Consumer<ConfigurableJWTProcessor<SecurityContext>> jwtProcessorCustomizer;
284284

@@ -381,19 +381,13 @@ JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSou
381381
return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource);
382382
}
383383

384-
JWKSource<SecurityContext> jwkSource(ResourceRetriever jwkSetRetriever, String jwkSetUri) {
385-
URLBasedJWKSetSource urlBasedJWKSetSource = new URLBasedJWKSetSource(toURL(jwkSetUri), jwkSetRetriever);
386-
if(this.cache == null) {
387-
return new SpringURLBasedJWKSource(urlBasedJWKSetSource);
388-
}
389-
SpringJWKSetCache jwkSetCache = new SpringJWKSetCache(jwkSetUri, this.cache);
390-
return new SpringURLBasedJWKSource<>(urlBasedJWKSetSource, jwkSetCache);
384+
JWKSource<SecurityContext> jwkSource() {
385+
String jwkSetUri = this.jwkSetUri.apply(this.restOperations);
386+
return new SpringJWKSource<>(this.restOperations, this.cache, toURL(jwkSetUri), jwkSetUri);
391387
}
392388

393389
JWTProcessor<SecurityContext> processor() {
394-
ResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever(this.restOperations);
395-
String jwkSetUri = this.jwkSetUri.apply(this.restOperations);
396-
JWKSource<SecurityContext> jwkSource = jwkSource(jwkSetRetriever, jwkSetUri);
390+
JWKSource<SecurityContext> jwkSource = jwkSource();
397391
ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
398392
jwtProcessor.setJWSKeySelector(jwsKeySelector(jwkSource));
399393
// Spring Security validates the claim set independent from Nimbus
@@ -420,153 +414,130 @@ private static URL toURL(String url) {
420414
}
421415
}
422416

423-
private static final class SpringURLBasedJWKSource<C extends SecurityContext> implements JWKSource<C> {
417+
private static final class SpringJWKSource<C extends SecurityContext> implements JWKSource<C> {
424418

425-
private final URLBasedJWKSetSource urlBasedJWKSetSource;
419+
private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");
426420

427-
private final SpringJWKSetCache jwkSetCache;
421+
private final ReentrantLock reentrantLock = new ReentrantLock();
428422

429-
private SpringURLBasedJWKSource(URLBasedJWKSetSource urlBasedJWKSetSource) {
430-
this(urlBasedJWKSetSource, null);
431-
}
423+
private final RestOperations restOperations;
424+
425+
private final Cache cache;
426+
427+
private final URL url;
428+
429+
private final String jwkSetUri;
432430

433-
private SpringURLBasedJWKSource(URLBasedJWKSetSource urlBasedJWKSetSource, SpringJWKSetCache jwkSetCache) {
434-
this.urlBasedJWKSetSource = urlBasedJWKSetSource;
435-
this.jwkSetCache = jwkSetCache;
431+
private SpringJWKSource(RestOperations restOperations, Cache cache, URL url, String jwkSetUri) {
432+
Assert.notNull(restOperations, "restOperations cannot be null");
433+
this.restOperations = restOperations;
434+
this.cache = cache;
435+
this.url = url;
436+
this.jwkSetUri = jwkSetUri;
436437
}
437438

439+
438440
@Override
439441
public List<JWK> get(JWKSelector jwkSelector, SecurityContext context) throws KeySourceException {
440-
if (this.jwkSetCache != null) {
441-
JWKSet jwkSet = this.jwkSetCache.get();
442-
if (this.jwkSetCache.requiresRefresh() || jwkSet == null) {
443-
synchronized (this) {
444-
jwkSet = fetchJWKSet();
445-
this.jwkSetCache.put(jwkSet);
446-
}
447-
}
448-
List<JWK> matches = jwkSelector.select(jwkSet);
449-
if(!matches.isEmpty()) {
450-
return matches;
451-
}
452-
String soughtKeyID = getFirstSpecifiedKeyID(jwkSelector.getMatcher());
453-
if (soughtKeyID == null) {
454-
return Collections.emptyList();
455-
}
456-
if (jwkSet.getKeyByKeyId(soughtKeyID) != null) {
457-
return Collections.emptyList();
458-
}
459-
synchronized (this) {
460-
if(jwkSet == this.jwkSetCache.get()) {
461-
jwkSet = fetchJWKSet();
462-
this.jwkSetCache.put(jwkSet);
463-
} else {
464-
jwkSet = this.jwkSetCache.get();
442+
String cachedJwkSet = this.cache.get(this.jwkSetUri, String.class);
443+
JWKSet jwkSet = null;
444+
if (cachedJwkSet != null) {
445+
jwkSet = parse(cachedJwkSet);
446+
}
447+
if (jwkSet == null) {
448+
if(reentrantLock.tryLock()) {
449+
try {
450+
String cachedJwkSetAfterLock = this.cache.get(this.jwkSetUri, String.class);
451+
if (cachedJwkSetAfterLock != null) {
452+
jwkSet = parse(cachedJwkSetAfterLock);
453+
}
454+
if(jwkSet == null) {
455+
try {
456+
jwkSet = fetchJWKSet();
457+
} catch (IOException e) {
458+
throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e);
459+
}
460+
}
461+
} finally {
462+
reentrantLock.unlock();
465463
}
466464
}
467-
if(jwkSet == null) {
468-
return Collections.emptyList();
469-
}
470-
return jwkSelector.select(jwkSet);
471465
}
472-
return jwkSelector.select(fetchJWKSet());
473-
}
474-
475-
private JWKSet fetchJWKSet() throws KeySourceException {
476-
return this.urlBasedJWKSetSource.getJWKSet(JWKSetCacheRefreshEvaluator.noRefresh(),
477-
System.currentTimeMillis(), null);
478-
}
479-
480-
private String getFirstSpecifiedKeyID(JWKMatcher jwkMatcher) {
481-
Set<String> keyIDs = jwkMatcher.getKeyIDs();
482-
483-
if (keyIDs == null || keyIDs.isEmpty()) {
484-
return null;
466+
List<JWK> matches = jwkSelector.select(jwkSet);
467+
if(!matches.isEmpty()) {
468+
return matches;
485469
}
486-
487-
for (String id: keyIDs) {
488-
if (id != null) {
489-
return id;
490-
}
470+
String soughtKeyID = getFirstSpecifiedKeyID(jwkSelector.getMatcher());
471+
if (soughtKeyID == null) {
472+
return Collections.emptyList();
473+
}
474+
if (jwkSet.getKeyByKeyId(soughtKeyID) != null) {
475+
return Collections.emptyList();
491476
}
492-
return null;
493-
}
494-
}
495-
496-
private static final class SpringJWKSetCache {
497-
498-
private final String jwkSetUri;
499-
500-
private final Cache cache;
501-
502-
private JWKSet jwkSet;
503-
504-
SpringJWKSetCache(String jwkSetUri, Cache cache) {
505-
this.jwkSetUri = jwkSetUri;
506-
this.cache = cache;
507-
this.updateJwkSetFromCache();
508-
}
509477

510-
private void updateJwkSetFromCache() {
511-
String cachedJwkSet = this.cache.get(this.jwkSetUri, String.class);
512-
if (cachedJwkSet != null) {
478+
if(reentrantLock.tryLock()) {
513479
try {
514-
this.jwkSet = JWKSet.parse(cachedJwkSet);
515-
}
516-
catch (ParseException ignored) {
517-
// Ignore invalid cache value
480+
String jwkSetUri = this.cache.get(this.jwkSetUri, String.class);
481+
JWKSet cacheJwkSet = parse(jwkSetUri);
482+
if(jwkSetUri != null && cacheJwkSet.toString().equals(jwkSet.toString())) {
483+
try {
484+
jwkSet = fetchJWKSet();
485+
} catch (IOException e) {
486+
throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e);
487+
}
488+
} else if (jwkSetUri != null) {
489+
jwkSet = parse(jwkSetUri);
490+
}
491+
} finally {
492+
reentrantLock.unlock();
518493
}
519494
}
495+
if(jwkSet == null) {
496+
return Collections.emptyList();
497+
}
498+
return jwkSelector.select(jwkSet);
520499
}
521500

522-
// Note: Only called from inside a synchronized block in SpringURLBasedJWKSource.
523-
public void put(JWKSet jwkSet) {
524-
this.jwkSet = jwkSet;
525-
this.cache.put(this.jwkSetUri, jwkSet.toString(false));
526-
}
527-
528-
public JWKSet get() {
529-
return (!requiresRefresh()) ? this.jwkSet : null;
530-
}
531-
532-
public boolean requiresRefresh() {
533-
return this.cache.get(this.jwkSetUri) == null;
534-
}
535-
536-
}
537-
538-
private static class RestOperationsResourceRetriever implements ResourceRetriever {
539-
540-
private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");
541-
542-
private final RestOperations restOperations;
543-
544-
RestOperationsResourceRetriever(RestOperations restOperations) {
545-
Assert.notNull(restOperations, "restOperations cannot be null");
546-
this.restOperations = restOperations;
547-
}
548-
549-
@Override
550-
public Resource retrieveResource(URL url) throws IOException {
501+
private JWKSet fetchJWKSet() throws IOException, KeySourceException {
551502
HttpHeaders headers = new HttpHeaders();
552503
headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON));
553-
ResponseEntity<String> response = getResponse(url, headers);
504+
ResponseEntity<String> response = getResponse(headers);
554505
if (response.getStatusCode().value() != 200) {
555506
throw new IOException(response.toString());
556507
}
557-
return new Resource(response.getBody(), "UTF-8");
508+
try {
509+
String jwkSet = response.getBody();
510+
this.cache.put(this.jwkSetUri, jwkSet);
511+
return JWKSet.parse(jwkSet);
512+
} catch (ParseException e) {
513+
throw new JWKSetParseException("Unable to parse JWK set", e);
514+
}
558515
}
559516

560-
private ResponseEntity<String> getResponse(URL url, HttpHeaders headers) throws IOException {
517+
private ResponseEntity<String> getResponse(HttpHeaders headers) throws IOException {
561518
try {
562-
RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, url.toURI());
519+
RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, this.url.toURI());
563520
return this.restOperations.exchange(request, String.class);
564-
}
565-
catch (Exception ex) {
521+
} catch (Exception ex) {
566522
throw new IOException(ex);
567523
}
568524
}
569525

526+
private JWKSet parse(String cachedJwkSet) {
527+
JWKSet jwkSet = null;
528+
try {
529+
jwkSet = JWKSet.parse(cachedJwkSet);
530+
} catch (ParseException ignored) {
531+
// Ignore invalid cache value
532+
}
533+
return jwkSet;
534+
}
535+
536+
private String getFirstSpecifiedKeyID(JWKMatcher jwkMatcher) {
537+
Set<String> keyIDs = jwkMatcher.getKeyIDs();
538+
return (keyIDs == null || keyIDs.isEmpty()) ?
539+
null : keyIDs.stream().filter(id -> id != null).findFirst().orElse(null);
540+
}
570541
}
571542

572543
}

oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -308,6 +308,7 @@ private void prepareConfigurationResponse() {
308308
private void prepareConfigurationResponse(String body) {
309309
this.server.enqueue(response(body));
310310
this.server.enqueue(response(JWK_SET));
311+
this.server.enqueue(response(JWK_SET)); // default NoOpCache
311312
}
312313

313314
private void prepareConfigurationResponseOidc() {

oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2023 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -60,7 +60,6 @@
6060

6161
import org.springframework.cache.Cache;
6262
import org.springframework.cache.concurrent.ConcurrentMapCache;
63-
import org.springframework.cache.support.SimpleValueWrapper;
6463
import org.springframework.core.ParameterizedTypeReference;
6564
import org.springframework.core.convert.converter.Converter;
6665
import org.springframework.http.HttpStatus;
@@ -704,7 +703,6 @@ public void decodeWhenCacheThenRetrieveFromCache() throws Exception {
704703
RestOperations restOperations = mock(RestOperations.class);
705704
Cache cache = mock(Cache.class);
706705
given(cache.get(eq(JWK_SET_URI), eq(String.class))).willReturn(JWK_SET);
707-
given(cache.get(eq(JWK_SET_URI))).willReturn(mock(Cache.ValueWrapper.class));
708706
// @formatter:off
709707
NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI)
710708
.cache(cache)
@@ -713,7 +711,6 @@ public void decodeWhenCacheThenRetrieveFromCache() throws Exception {
713711
// @formatter:on
714712
jwtDecoder.decode(SIGNED_JWT);
715713
verify(cache).get(eq(JWK_SET_URI), eq(String.class));
716-
verify(cache, times(2)).get(eq(JWK_SET_URI));
717714
verifyNoMoreInteractions(cache);
718715
verifyNoInteractions(restOperations);
719716
}
@@ -724,7 +721,6 @@ public void decodeWhenCacheAndUnknownKidShouldTriggerFetchOfJwkSet() throws JOSE
724721
RestOperations restOperations = mock(RestOperations.class);
725722
Cache cache = mock(Cache.class);
726723
given(cache.get(eq(JWK_SET_URI), eq(String.class))).willReturn(JWK_SET);
727-
given(cache.get(eq(JWK_SET_URI))).willReturn(new SimpleValueWrapper(JWK_SET));
728724
given(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
729725
.willReturn(new ResponseEntity<>(NEW_KID_JWK_SET, HttpStatus.OK));
730726

@@ -796,7 +792,6 @@ public void decodeWhenCacheIsConfiguredAndParseFailsOnCachedValueThenExceptionIg
796792
RestOperations restOperations = mock(RestOperations.class);
797793
Cache cache = mock(Cache.class);
798794
given(cache.get(eq(JWK_SET_URI), eq(String.class))).willReturn(JWK_SET);
799-
given(cache.get(eq(JWK_SET_URI))).willReturn(mock(Cache.ValueWrapper.class));
800795
// @formatter:off
801796
NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI)
802797
.cache(cache)
@@ -805,7 +800,6 @@ public void decodeWhenCacheIsConfiguredAndParseFailsOnCachedValueThenExceptionIg
805800
// @formatter:on
806801
jwtDecoder.decode(SIGNED_JWT);
807802
verify(cache).get(eq(JWK_SET_URI), eq(String.class));
808-
verify(cache, times(2)).get(eq(JWK_SET_URI));
809803
verifyNoMoreInteractions(cache);
810804
verifyNoInteractions(restOperations);
811805

0 commit comments

Comments
 (0)