Skip to content

Remove Deprecated Usages of RemoteJWKSet #16537

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,9 +16,7 @@

package org.springframework.security.oauth2.jwt;

import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URI;
import java.security.interfaces.RSAPublicKey;
import java.text.ParseException;
import java.util.Arrays;
Expand All @@ -28,24 +26,25 @@
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import java.util.function.Function;

import javax.crypto.SecretKey;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.KeySourceException;
import com.nimbusds.jose.RemoteKeySourceException;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.source.JWKSetCache;
import com.nimbusds.jose.jwk.source.JWKSetCacheRefreshEvaluator;
import com.nimbusds.jose.jwk.source.JWKSetSource;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jose.jwk.source.JWKSourceBuilder;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jose.proc.SingleKeyJWSKeySelector;
import com.nimbusds.jose.util.Resource;
import com.nimbusds.jose.util.ResourceRetriever;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.JWTParser;
Expand All @@ -57,6 +56,7 @@
import org.apache.commons.logging.LogFactory;

import org.springframework.cache.Cache;
import org.springframework.cache.support.NoOpCache;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
Expand All @@ -80,6 +80,7 @@
* @author Josh Cummings
* @author Joe Grandja
* @author Mykyta Bezverkhyi
* @author Daeho Kwon
* @since 5.2
*/
public final class NimbusJwtDecoder implements JwtDecoder {
Expand Down Expand Up @@ -273,7 +274,7 @@ public static final class JwkSetUriJwtDecoderBuilder {

private RestOperations restOperations = new RestTemplate();

private Cache cache;
private Cache cache = new NoOpCache("default");

private Consumer<ConfigurableJWTProcessor<SecurityContext>> jwtProcessorCustomizer;

Expand Down Expand Up @@ -376,18 +377,17 @@ JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSou
return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource);
}

JWKSource<SecurityContext> jwkSource(ResourceRetriever jwkSetRetriever, String jwkSetUri) {
if (this.cache == null) {
return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever);
}
JWKSetCache jwkSetCache = new SpringJWKSetCache(jwkSetUri, this.cache);
return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever, jwkSetCache);
JWKSource<SecurityContext> jwkSource() {
String jwkSetUri = this.jwkSetUri.apply(this.restOperations);
return JWKSourceBuilder.create(new SpringJWKSource<>(this.restOperations, this.cache, jwkSetUri))
.refreshAheadCache(false)
.rateLimited(false)
.cache(this.cache instanceof NoOpCache)
.build();
}

JWTProcessor<SecurityContext> processor() {
ResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever(this.restOperations);
String jwkSetUri = this.jwkSetUri.apply(this.restOperations);
JWKSource<SecurityContext> jwkSource = jwkSource(jwkSetRetriever, jwkSetUri);
JWKSource<SecurityContext> jwkSource = jwkSource();
ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
jwtProcessor.setJWSKeySelector(jwsKeySelector(jwkSource));
// Spring Security validates the claim set independent from Nimbus
Expand All @@ -405,93 +405,73 @@ public NimbusJwtDecoder build() {
return new NimbusJwtDecoder(processor());
}

private static URL toURL(String url) {
try {
return new URL(url);
}
catch (MalformedURLException ex) {
throw new IllegalArgumentException("Invalid JWK Set URL \"" + url + "\" : " + ex.getMessage(), ex);
}
}
private static final class SpringJWKSource<C extends SecurityContext> implements JWKSetSource<C> {

private static final class SpringJWKSetCache implements JWKSetCache {
private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");

private final String jwkSetUri;
private final ReentrantLock reentrantLock = new ReentrantLock();

private final RestOperations restOperations;

private final Cache cache;

private final String jwkSetUri;

private JWKSet jwkSet;

SpringJWKSetCache(String jwkSetUri, Cache cache) {
this.jwkSetUri = jwkSetUri;
private SpringJWKSource(RestOperations restOperations, Cache cache, String jwkSetUri) {
Assert.notNull(restOperations, "restOperations cannot be null");
this.restOperations = restOperations;
this.cache = cache;
this.updateJwkSetFromCache();
}

private void updateJwkSetFromCache() {
String cachedJwkSet = this.cache.get(this.jwkSetUri, String.class);
if (cachedJwkSet != null) {
this.jwkSetUri = jwkSetUri;
String jwks = this.cache.get(this.jwkSetUri, String.class);
if (jwks != null) {
try {
this.jwkSet = JWKSet.parse(cachedJwkSet);
this.jwkSet = JWKSet.parse(jwks);
}
catch (ParseException ignored) {
// Ignore invalid cache value
}
}
}

// Note: Only called from inside a synchronized block in RemoteJWKSet.
@Override
public void put(JWKSet jwkSet) {
this.jwkSet = jwkSet;
this.cache.put(this.jwkSetUri, jwkSet.toString(false));
}

@Override
public JWKSet get() {
return (!requiresRefresh()) ? this.jwkSet : null;

}

@Override
public boolean requiresRefresh() {
return this.cache.get(this.jwkSetUri) == null;
}

}

private static class RestOperationsResourceRetriever implements ResourceRetriever {

private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");

private final RestOperations restOperations;

RestOperationsResourceRetriever(RestOperations restOperations) {
Assert.notNull(restOperations, "restOperations cannot be null");
this.restOperations = restOperations;
}

@Override
public Resource retrieveResource(URL url) throws IOException {
private String fetchJwks() throws Exception {
HttpHeaders headers = new HttpHeaders();
headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON));
ResponseEntity<String> response = getResponse(url, headers);
if (response.getStatusCode().value() != 200) {
throw new IOException(response.toString());
}
return new Resource(response.getBody(), "UTF-8");
RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, URI.create(this.jwkSetUri));
ResponseEntity<String> response = this.restOperations.exchange(request, String.class);
String jwks = response.getBody();
this.jwkSet = JWKSet.parse(jwks);
return jwks;
}

private ResponseEntity<String> getResponse(URL url, HttpHeaders headers) throws IOException {
@Override
public JWKSet getJWKSet(JWKSetCacheRefreshEvaluator refreshEvaluator, long currentTime, C context)
throws KeySourceException {
try {
RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, url.toURI());
return this.restOperations.exchange(request, String.class);
this.reentrantLock.lock();
if (refreshEvaluator.requiresRefresh(this.jwkSet)) {
this.cache.invalidate();
}
this.cache.get(this.jwkSetUri, this::fetchJwks);
return this.jwkSet;
}
catch (Exception ex) {
throw new IOException(ex);
catch (Cache.ValueRetrievalException ex) {
if (ex.getCause() instanceof RemoteKeySourceException keys) {
throw keys;
}
throw new RemoteKeySourceException(ex.getCause().getMessage(), ex.getCause());
}
finally {
this.reentrantLock.unlock();
}
}

@Override
public void close() {

}

}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -60,7 +60,6 @@

import org.springframework.cache.Cache;
import org.springframework.cache.concurrent.ConcurrentMapCache;
import org.springframework.cache.support.SimpleValueWrapper;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpStatus;
Expand Down Expand Up @@ -702,29 +701,25 @@ public void decodeWhenCacheStoredThenAbleToRetrieveJwkSetFromCache() {
@Test
public void decodeWhenCacheThenRetrieveFromCache() throws Exception {
RestOperations restOperations = mock(RestOperations.class);
Cache cache = mock(Cache.class);
given(cache.get(eq(JWK_SET_URI), eq(String.class))).willReturn(JWK_SET);
given(cache.get(eq(JWK_SET_URI))).willReturn(mock(Cache.ValueWrapper.class));
Cache cache = new ConcurrentMapCache("cache");
cache.put(JWK_SET_URI, JWK_SET);
// @formatter:off
NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI)
.cache(cache)
.restOperations(restOperations)
.build();
// @formatter:on
jwtDecoder.decode(SIGNED_JWT);
verify(cache).get(eq(JWK_SET_URI), eq(String.class));
verify(cache, times(2)).get(eq(JWK_SET_URI));
verifyNoMoreInteractions(cache);
assertThat(cache.get(JWK_SET_URI, String.class)).isSameAs(JWK_SET);
verifyNoInteractions(restOperations);
}

// gh-11621
@Test
public void decodeWhenCacheAndUnknownKidShouldTriggerFetchOfJwkSet() throws JOSEException {
RestOperations restOperations = mock(RestOperations.class);
Cache cache = mock(Cache.class);
given(cache.get(eq(JWK_SET_URI), eq(String.class))).willReturn(JWK_SET);
given(cache.get(eq(JWK_SET_URI))).willReturn(new SimpleValueWrapper(JWK_SET));
Cache cache = new ConcurrentMapCache("cache");
cache.put(JWK_SET_URI, JWK_SET);
given(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
.willReturn(new ResponseEntity<>(NEW_KID_JWK_SET, HttpStatus.OK));

Expand Down Expand Up @@ -794,19 +789,16 @@ public void decodeWhenCacheIsConfiguredAndValueLoaderErrorsThenThrowsJwtExceptio
@Test
public void decodeWhenCacheIsConfiguredAndParseFailsOnCachedValueThenExceptionIgnored() {
RestOperations restOperations = mock(RestOperations.class);
Cache cache = mock(Cache.class);
given(cache.get(eq(JWK_SET_URI), eq(String.class))).willReturn(JWK_SET);
given(cache.get(eq(JWK_SET_URI))).willReturn(mock(Cache.ValueWrapper.class));
Cache cache = new ConcurrentMapCache("cache");
cache.put(JWK_SET_URI, JWK_SET);
// @formatter:off
NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI)
.cache(cache)
.restOperations(restOperations)
.build();
// @formatter:on
jwtDecoder.decode(SIGNED_JWT);
verify(cache).get(eq(JWK_SET_URI), eq(String.class));
verify(cache, times(2)).get(eq(JWK_SET_URI));
verifyNoMoreInteractions(cache);
assertThat(cache.get(JWK_SET_URI, String.class)).isSameAs(JWK_SET);
verifyNoInteractions(restOperations);

}
Expand Down