16
16
17
17
package org .springframework .security .oauth2 .jwt ;
18
18
19
- import com .nimbusds .jose .KeySourceException ;
20
- import com .nimbusds .jose .jwk .JWK ;
21
- import com .nimbusds .jose .jwk .JWKMatcher ;
22
- import com .nimbusds .jose .jwk .JWKSelector ;
23
- import com .nimbusds .jose .jwk .source .JWKSetParseException ;
24
- import com .nimbusds .jose .jwk .source .JWKSetRetrievalException ;
25
- import java .io .IOException ;
26
- import java .net .MalformedURLException ;
27
- import java .net .URL ;
19
+ import java .net .URI ;
28
20
import java .security .interfaces .RSAPublicKey ;
29
21
import java .text .ParseException ;
30
22
import java .util .Arrays ;
31
23
import java .util .Collection ;
32
24
import java .util .Collections ;
33
25
import java .util .HashSet ;
34
26
import java .util .LinkedHashMap ;
35
- import java .util .List ;
36
27
import java .util .Map ;
37
28
import java .util .Set ;
38
29
import java .util .concurrent .locks .ReentrantLock ;
43
34
44
35
import com .nimbusds .jose .JOSEException ;
45
36
import com .nimbusds .jose .JWSAlgorithm ;
37
+ import com .nimbusds .jose .KeySourceException ;
38
+ import com .nimbusds .jose .RemoteKeySourceException ;
46
39
import com .nimbusds .jose .jwk .JWKSet ;
40
+ import com .nimbusds .jose .jwk .source .JWKSetCacheRefreshEvaluator ;
41
+ import com .nimbusds .jose .jwk .source .JWKSetSource ;
47
42
import com .nimbusds .jose .jwk .source .JWKSource ;
43
+ import com .nimbusds .jose .jwk .source .JWKSourceBuilder ;
48
44
import com .nimbusds .jose .proc .JWSKeySelector ;
49
45
import com .nimbusds .jose .proc .JWSVerificationKeySelector ;
50
46
import com .nimbusds .jose .proc .SecurityContext ;
@@ -170,7 +166,7 @@ private Jwt createJwt(String token, JWT parsedJwt) {
170
166
.build ();
171
167
// @formatter:on
172
168
}
173
- catch (KeySourceException ex ) {
169
+ catch (RemoteKeySourceException ex ) {
174
170
this .logger .trace ("Failed to retrieve JWK set" , ex );
175
171
if (ex .getCause () instanceof ParseException ) {
176
172
throw new JwtException (String .format (DECODING_ERROR_MESSAGE_TEMPLATE , "Malformed Jwk set" ), ex );
@@ -383,7 +379,11 @@ JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSou
383
379
384
380
JWKSource <SecurityContext > jwkSource () {
385
381
String jwkSetUri = this .jwkSetUri .apply (this .restOperations );
386
- return new SpringJWKSource <>(this .restOperations , this .cache , toURL (jwkSetUri ), jwkSetUri );
382
+ return JWKSourceBuilder .create (new SpringJWKSource <>(this .restOperations , this .cache , jwkSetUri ))
383
+ .refreshAheadCache (false )
384
+ .rateLimited (false )
385
+ .cache (this .cache instanceof NoOpCache )
386
+ .build ();
387
387
}
388
388
389
389
JWTProcessor <SecurityContext > processor () {
@@ -405,16 +405,7 @@ public NimbusJwtDecoder build() {
405
405
return new NimbusJwtDecoder (processor ());
406
406
}
407
407
408
- private static URL toURL (String url ) {
409
- try {
410
- return new URL (url );
411
- }
412
- catch (MalformedURLException ex ) {
413
- throw new IllegalArgumentException ("Invalid JWK Set URL \" " + url + "\" : " + ex .getMessage (), ex );
414
- }
415
- }
416
-
417
- private static final class SpringJWKSource <C extends SecurityContext > implements JWKSource <C > {
408
+ private static final class SpringJWKSource <C extends SecurityContext > implements JWKSetSource <C > {
418
409
419
410
private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType ("application" , "jwk-set+json" );
420
411
@@ -424,120 +415,63 @@ private static final class SpringJWKSource<C extends SecurityContext> implements
424
415
425
416
private final Cache cache ;
426
417
427
- private final URL url ;
428
-
429
418
private final String jwkSetUri ;
430
419
431
- private SpringJWKSource (RestOperations restOperations , Cache cache , URL url , String jwkSetUri ) {
420
+ private JWKSet jwkSet ;
421
+
422
+ private SpringJWKSource (RestOperations restOperations , Cache cache , String jwkSetUri ) {
432
423
Assert .notNull (restOperations , "restOperations cannot be null" );
433
424
this .restOperations = restOperations ;
434
425
this .cache = cache ;
435
- this .url = url ;
436
426
this .jwkSetUri = jwkSetUri ;
437
- }
438
-
439
-
440
- @ Override
441
- public List <JWK > get (JWKSelector jwkSelector , SecurityContext context ) throws KeySourceException {
442
- String cachedJwkSet = this .cache .get (this .jwkSetUri , String .class );
443
- JWKSet jwkSet = null ;
444
- if (cachedJwkSet != null ) {
445
- jwkSet = parse (cachedJwkSet );
446
- }
447
- if (jwkSet == null ) {
448
- if (reentrantLock .tryLock ()) {
449
- try {
450
- String cachedJwkSetAfterLock = this .cache .get (this .jwkSetUri , String .class );
451
- if (cachedJwkSetAfterLock != null ) {
452
- jwkSet = parse (cachedJwkSetAfterLock );
453
- }
454
- if (jwkSet == null ) {
455
- try {
456
- jwkSet = fetchJWKSet ();
457
- } catch (IOException e ) {
458
- throw new JWKSetRetrievalException ("Couldn't retrieve JWK set from URL: " + e .getMessage (), e );
459
- }
460
- }
461
- } finally {
462
- reentrantLock .unlock ();
463
- }
464
- }
465
- }
466
- List <JWK > matches = jwkSelector .select (jwkSet );
467
- if (!matches .isEmpty ()) {
468
- return matches ;
469
- }
470
- String soughtKeyID = getFirstSpecifiedKeyID (jwkSelector .getMatcher ());
471
- if (soughtKeyID == null ) {
472
- return Collections .emptyList ();
473
- }
474
- if (jwkSet .getKeyByKeyId (soughtKeyID ) != null ) {
475
- return Collections .emptyList ();
476
- }
477
-
478
- if (reentrantLock .tryLock ()) {
427
+ String jwks = this .cache .get (this .jwkSetUri , String .class );
428
+ if (jwks != null ) {
479
429
try {
480
- String jwkSetUri = this .cache .get (this .jwkSetUri , String .class );
481
- JWKSet cacheJwkSet = parse (jwkSetUri );
482
- if (jwkSetUri != null && cacheJwkSet .toString ().equals (jwkSet .toString ())) {
483
- try {
484
- jwkSet = fetchJWKSet ();
485
- } catch (IOException e ) {
486
- throw new JWKSetRetrievalException ("Couldn't retrieve JWK set from URL: " + e .getMessage (), e );
487
- }
488
- } else if (jwkSetUri != null ) {
489
- jwkSet = parse (jwkSetUri );
490
- }
491
- } finally {
492
- reentrantLock .unlock ();
430
+ this .jwkSet = JWKSet .parse (jwks );
431
+ }
432
+ catch (ParseException ignored ) {
433
+ // Ignore invalid cache value
493
434
}
494
435
}
495
- if (jwkSet == null ) {
496
- return Collections .emptyList ();
497
- }
498
- return jwkSelector .select (jwkSet );
499
436
}
500
437
501
- private JWKSet fetchJWKSet () throws IOException , KeySourceException {
438
+ private String fetchJwks () throws Exception {
502
439
HttpHeaders headers = new HttpHeaders ();
503
440
headers .setAccept (Arrays .asList (MediaType .APPLICATION_JSON , APPLICATION_JWK_SET_JSON ));
504
- ResponseEntity <String > response = getResponse (headers );
505
- if (response .getStatusCode ().value () != 200 ) {
506
- throw new IOException (response .toString ());
507
- }
508
- try {
509
- String jwkSet = response .getBody ();
510
- this .cache .put (this .jwkSetUri , jwkSet );
511
- return JWKSet .parse (jwkSet );
512
- } catch (ParseException e ) {
513
- throw new JWKSetParseException ("Unable to parse JWK set" , e );
514
- }
441
+ RequestEntity <Void > request = new RequestEntity <>(headers , HttpMethod .GET , URI .create (this .jwkSetUri ));
442
+ ResponseEntity <String > response = this .restOperations .exchange (request , String .class );
443
+ String jwks = response .getBody ();
444
+ this .jwkSet = JWKSet .parse (jwks );
445
+ return jwks ;
515
446
}
516
447
517
- private ResponseEntity <String > getResponse (HttpHeaders headers ) throws IOException {
448
+ @ Override
449
+ public JWKSet getJWKSet (JWKSetCacheRefreshEvaluator refreshEvaluator , long currentTime , C context )
450
+ throws KeySourceException {
518
451
try {
519
- RequestEntity <Void > request = new RequestEntity <>(headers , HttpMethod .GET , this .url .toURI ());
520
- return this .restOperations .exchange (request , String .class );
521
- } catch (Exception ex ) {
522
- throw new IOException (ex );
452
+ this .reentrantLock .lock ();
453
+ if (refreshEvaluator .requiresRefresh (this .jwkSet )) {
454
+ this .cache .invalidate ();
455
+ }
456
+ this .cache .get (this .jwkSetUri , this ::fetchJwks );
457
+ return this .jwkSet ;
523
458
}
524
- }
525
-
526
- private JWKSet parse ( String cachedJwkSet ) {
527
- JWKSet jwkSet = null ;
528
- try {
529
- jwkSet = JWKSet . parse ( cachedJwkSet );
530
- } catch ( ParseException ignored ) {
531
- // Ignore invalid cache value
459
+ catch ( Cache . ValueRetrievalException ex ) {
460
+ if ( ex . getCause () instanceof RemoteKeySourceException keys ) {
461
+ throw keys ;
462
+ }
463
+ throw new RemoteKeySourceException ( ex . getCause (). getMessage (), ex . getCause ());
464
+ }
465
+ finally {
466
+ this . reentrantLock . unlock ();
532
467
}
533
- return jwkSet ;
534
468
}
535
469
536
- private String getFirstSpecifiedKeyID (JWKMatcher jwkMatcher ) {
537
- Set <String > keyIDs = jwkMatcher .getKeyIDs ();
538
- return (keyIDs == null || keyIDs .isEmpty ()) ?
539
- null : keyIDs .stream ().filter (id -> id != null ).findFirst ().orElse (null );
470
+ @ Override
471
+ public void close () {
472
+
540
473
}
474
+
541
475
}
542
476
543
477
}
0 commit comments