1
1
/*
2
- * Copyright 2002-2024 the original author or authors.
2
+ * Copyright 2002-2025 the original author or authors.
3
3
*
4
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
5
* you may not use this file except in compliance with the License.
20
20
import com .nimbusds .jose .jwk .JWK ;
21
21
import com .nimbusds .jose .jwk .JWKMatcher ;
22
22
import com .nimbusds .jose .jwk .JWKSelector ;
23
- import com .nimbusds .jose .jwk .source .JWKSetCacheRefreshEvaluator ;
24
- import com .nimbusds .jose .jwk .source .URLBasedJWKSetSource ;
23
+ import com .nimbusds .jose .jwk .source .JWKSetParseException ;
24
+ import com .nimbusds .jose .jwk .source .JWKSetRetrievalException ;
25
25
import java .io .IOException ;
26
26
import java .net .MalformedURLException ;
27
27
import java .net .URL ;
35
35
import java .util .List ;
36
36
import java .util .Map ;
37
37
import java .util .Set ;
38
+ import java .util .concurrent .locks .ReentrantLock ;
38
39
import java .util .function .Consumer ;
39
40
import java .util .function .Function ;
40
41
48
49
import com .nimbusds .jose .proc .JWSVerificationKeySelector ;
49
50
import com .nimbusds .jose .proc .SecurityContext ;
50
51
import com .nimbusds .jose .proc .SingleKeyJWSKeySelector ;
51
- import com .nimbusds .jose .util .Resource ;
52
- import com .nimbusds .jose .util .ResourceRetriever ;
53
52
import com .nimbusds .jwt .JWT ;
54
53
import com .nimbusds .jwt .JWTClaimsSet ;
55
54
import com .nimbusds .jwt .JWTParser ;
61
60
import org .apache .commons .logging .LogFactory ;
62
61
63
62
import org .springframework .cache .Cache ;
63
+ import org .springframework .cache .support .NoOpCache ;
64
64
import org .springframework .core .convert .converter .Converter ;
65
65
import org .springframework .http .HttpHeaders ;
66
66
import org .springframework .http .HttpMethod ;
@@ -278,7 +278,7 @@ public static final class JwkSetUriJwtDecoderBuilder {
278
278
279
279
private RestOperations restOperations = new RestTemplate ();
280
280
281
- private Cache cache ;
281
+ private Cache cache = new NoOpCache ( "default" ) ;
282
282
283
283
private Consumer <ConfigurableJWTProcessor <SecurityContext >> jwtProcessorCustomizer ;
284
284
@@ -381,19 +381,13 @@ JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSou
381
381
return new JWSVerificationKeySelector <>(jwsAlgorithms , jwkSource );
382
382
}
383
383
384
- JWKSource <SecurityContext > jwkSource (ResourceRetriever jwkSetRetriever , String jwkSetUri ) {
385
- URLBasedJWKSetSource urlBasedJWKSetSource = new URLBasedJWKSetSource (toURL (jwkSetUri ), jwkSetRetriever );
386
- if (this .cache == null ) {
387
- return new SpringURLBasedJWKSource (urlBasedJWKSetSource );
388
- }
389
- SpringJWKSetCache jwkSetCache = new SpringJWKSetCache (jwkSetUri , this .cache );
390
- return new SpringURLBasedJWKSource <>(urlBasedJWKSetSource , jwkSetCache );
384
+ JWKSource <SecurityContext > jwkSource () {
385
+ String jwkSetUri = this .jwkSetUri .apply (this .restOperations );
386
+ return new SpringJWKSource <>(this .restOperations , this .cache , toURL (jwkSetUri ), jwkSetUri );
391
387
}
392
388
393
389
JWTProcessor <SecurityContext > processor () {
394
- ResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever (this .restOperations );
395
- String jwkSetUri = this .jwkSetUri .apply (this .restOperations );
396
- JWKSource <SecurityContext > jwkSource = jwkSource (jwkSetRetriever , jwkSetUri );
390
+ JWKSource <SecurityContext > jwkSource = jwkSource ();
397
391
ConfigurableJWTProcessor <SecurityContext > jwtProcessor = new DefaultJWTProcessor <>();
398
392
jwtProcessor .setJWSKeySelector (jwsKeySelector (jwkSource ));
399
393
// Spring Security validates the claim set independent from Nimbus
@@ -420,153 +414,130 @@ private static URL toURL(String url) {
420
414
}
421
415
}
422
416
423
- private static final class SpringURLBasedJWKSource <C extends SecurityContext > implements JWKSource <C > {
417
+ private static final class SpringJWKSource <C extends SecurityContext > implements JWKSource <C > {
424
418
425
- private final URLBasedJWKSetSource urlBasedJWKSetSource ;
419
+ private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType ( "application" , "jwk-set+json" ) ;
426
420
427
- private final SpringJWKSetCache jwkSetCache ;
421
+ private final ReentrantLock reentrantLock = new ReentrantLock () ;
428
422
429
- private SpringURLBasedJWKSource (URLBasedJWKSetSource urlBasedJWKSetSource ) {
430
- this (urlBasedJWKSetSource , null );
431
- }
423
+ private final RestOperations restOperations ;
424
+
425
+ private final Cache cache ;
426
+
427
+ private final URL url ;
428
+
429
+ private final String jwkSetUri ;
432
430
433
- private SpringURLBasedJWKSource (URLBasedJWKSetSource urlBasedJWKSetSource , SpringJWKSetCache jwkSetCache ) {
434
- this .urlBasedJWKSetSource = urlBasedJWKSetSource ;
435
- this .jwkSetCache = jwkSetCache ;
431
+ private SpringJWKSource (RestOperations restOperations , Cache cache , URL url , String jwkSetUri ) {
432
+ Assert .notNull (restOperations , "restOperations cannot be null" );
433
+ this .restOperations = restOperations ;
434
+ this .cache = cache ;
435
+ this .url = url ;
436
+ this .jwkSetUri = jwkSetUri ;
436
437
}
437
438
439
+
438
440
@ Override
439
441
public List <JWK > get (JWKSelector jwkSelector , SecurityContext context ) throws KeySourceException {
440
- if (this .jwkSetCache != null ) {
441
- JWKSet jwkSet = this .jwkSetCache .get ();
442
- if (this .jwkSetCache .requiresRefresh () || jwkSet == null ) {
443
- synchronized (this ) {
444
- jwkSet = fetchJWKSet ();
445
- this .jwkSetCache .put (jwkSet );
446
- }
447
- }
448
- List <JWK > matches = jwkSelector .select (jwkSet );
449
- if (!matches .isEmpty ()) {
450
- return matches ;
451
- }
452
- String soughtKeyID = getFirstSpecifiedKeyID (jwkSelector .getMatcher ());
453
- if (soughtKeyID == null ) {
454
- return Collections .emptyList ();
455
- }
456
- if (jwkSet .getKeyByKeyId (soughtKeyID ) != null ) {
457
- return Collections .emptyList ();
458
- }
459
- synchronized (this ) {
460
- if (jwkSet == this .jwkSetCache .get ()) {
461
- jwkSet = fetchJWKSet ();
462
- this .jwkSetCache .put (jwkSet );
463
- } else {
464
- jwkSet = this .jwkSetCache .get ();
442
+ String cachedJwkSet = this .cache .get (this .jwkSetUri , String .class );
443
+ JWKSet jwkSet = null ;
444
+ if (cachedJwkSet != null ) {
445
+ jwkSet = parse (cachedJwkSet );
446
+ }
447
+ if (jwkSet == null ) {
448
+ if (reentrantLock .tryLock ()) {
449
+ try {
450
+ String cachedJwkSetAfterLock = this .cache .get (this .jwkSetUri , String .class );
451
+ if (cachedJwkSetAfterLock != null ) {
452
+ jwkSet = parse (cachedJwkSetAfterLock );
453
+ }
454
+ if (jwkSet == null ) {
455
+ try {
456
+ jwkSet = fetchJWKSet ();
457
+ } catch (IOException e ) {
458
+ throw new JWKSetRetrievalException ("Couldn't retrieve JWK set from URL: " + e .getMessage (), e );
459
+ }
460
+ }
461
+ } finally {
462
+ reentrantLock .unlock ();
465
463
}
466
464
}
467
- if (jwkSet == null ) {
468
- return Collections .emptyList ();
469
- }
470
- return jwkSelector .select (jwkSet );
471
465
}
472
- return jwkSelector .select (fetchJWKSet ());
473
- }
474
-
475
- private JWKSet fetchJWKSet () throws KeySourceException {
476
- return this .urlBasedJWKSetSource .getJWKSet (JWKSetCacheRefreshEvaluator .noRefresh (),
477
- System .currentTimeMillis (), null );
478
- }
479
-
480
- private String getFirstSpecifiedKeyID (JWKMatcher jwkMatcher ) {
481
- Set <String > keyIDs = jwkMatcher .getKeyIDs ();
482
-
483
- if (keyIDs == null || keyIDs .isEmpty ()) {
484
- return null ;
466
+ List <JWK > matches = jwkSelector .select (jwkSet );
467
+ if (!matches .isEmpty ()) {
468
+ return matches ;
485
469
}
486
-
487
- for (String id : keyIDs ) {
488
- if (id != null ) {
489
- return id ;
490
- }
470
+ String soughtKeyID = getFirstSpecifiedKeyID (jwkSelector .getMatcher ());
471
+ if (soughtKeyID == null ) {
472
+ return Collections .emptyList ();
473
+ }
474
+ if (jwkSet .getKeyByKeyId (soughtKeyID ) != null ) {
475
+ return Collections .emptyList ();
491
476
}
492
- return null ;
493
- }
494
- }
495
-
496
- private static final class SpringJWKSetCache {
497
-
498
- private final String jwkSetUri ;
499
-
500
- private final Cache cache ;
501
-
502
- private JWKSet jwkSet ;
503
-
504
- SpringJWKSetCache (String jwkSetUri , Cache cache ) {
505
- this .jwkSetUri = jwkSetUri ;
506
- this .cache = cache ;
507
- this .updateJwkSetFromCache ();
508
- }
509
477
510
- private void updateJwkSetFromCache () {
511
- String cachedJwkSet = this .cache .get (this .jwkSetUri , String .class );
512
- if (cachedJwkSet != null ) {
478
+ if (reentrantLock .tryLock ()) {
513
479
try {
514
- this .jwkSet = JWKSet .parse (cachedJwkSet );
515
- }
516
- catch (ParseException ignored ) {
517
- // Ignore invalid cache value
480
+ String jwkSetUri = this .cache .get (this .jwkSetUri , String .class );
481
+ JWKSet cacheJwkSet = parse (jwkSetUri );
482
+ if (jwkSetUri != null && cacheJwkSet .toString ().equals (jwkSet .toString ())) {
483
+ try {
484
+ jwkSet = fetchJWKSet ();
485
+ } catch (IOException e ) {
486
+ throw new JWKSetRetrievalException ("Couldn't retrieve JWK set from URL: " + e .getMessage (), e );
487
+ }
488
+ } else if (jwkSetUri != null ) {
489
+ jwkSet = parse (jwkSetUri );
490
+ }
491
+ } finally {
492
+ reentrantLock .unlock ();
518
493
}
519
494
}
495
+ if (jwkSet == null ) {
496
+ return Collections .emptyList ();
497
+ }
498
+ return jwkSelector .select (jwkSet );
520
499
}
521
500
522
- // Note: Only called from inside a synchronized block in SpringURLBasedJWKSource.
523
- public void put (JWKSet jwkSet ) {
524
- this .jwkSet = jwkSet ;
525
- this .cache .put (this .jwkSetUri , jwkSet .toString (false ));
526
- }
527
-
528
- public JWKSet get () {
529
- return (!requiresRefresh ()) ? this .jwkSet : null ;
530
- }
531
-
532
- public boolean requiresRefresh () {
533
- return this .cache .get (this .jwkSetUri ) == null ;
534
- }
535
-
536
- }
537
-
538
- private static class RestOperationsResourceRetriever implements ResourceRetriever {
539
-
540
- private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType ("application" , "jwk-set+json" );
541
-
542
- private final RestOperations restOperations ;
543
-
544
- RestOperationsResourceRetriever (RestOperations restOperations ) {
545
- Assert .notNull (restOperations , "restOperations cannot be null" );
546
- this .restOperations = restOperations ;
547
- }
548
-
549
- @ Override
550
- public Resource retrieveResource (URL url ) throws IOException {
501
+ private JWKSet fetchJWKSet () throws IOException , KeySourceException {
551
502
HttpHeaders headers = new HttpHeaders ();
552
503
headers .setAccept (Arrays .asList (MediaType .APPLICATION_JSON , APPLICATION_JWK_SET_JSON ));
553
- ResponseEntity <String > response = getResponse (url , headers );
504
+ ResponseEntity <String > response = getResponse (headers );
554
505
if (response .getStatusCode ().value () != 200 ) {
555
506
throw new IOException (response .toString ());
556
507
}
557
- return new Resource (response .getBody (), "UTF-8" );
508
+ try {
509
+ String jwkSet = response .getBody ();
510
+ this .cache .put (this .jwkSetUri , jwkSet );
511
+ return JWKSet .parse (jwkSet );
512
+ } catch (ParseException e ) {
513
+ throw new JWKSetParseException ("Unable to parse JWK set" , e );
514
+ }
558
515
}
559
516
560
- private ResponseEntity <String > getResponse (URL url , HttpHeaders headers ) throws IOException {
517
+ private ResponseEntity <String > getResponse (HttpHeaders headers ) throws IOException {
561
518
try {
562
- RequestEntity <Void > request = new RequestEntity <>(headers , HttpMethod .GET , url .toURI ());
519
+ RequestEntity <Void > request = new RequestEntity <>(headers , HttpMethod .GET , this . url .toURI ());
563
520
return this .restOperations .exchange (request , String .class );
564
- }
565
- catch (Exception ex ) {
521
+ } catch (Exception ex ) {
566
522
throw new IOException (ex );
567
523
}
568
524
}
569
525
526
+ private JWKSet parse (String cachedJwkSet ) {
527
+ JWKSet jwkSet = null ;
528
+ try {
529
+ jwkSet = JWKSet .parse (cachedJwkSet );
530
+ } catch (ParseException ignored ) {
531
+ // Ignore invalid cache value
532
+ }
533
+ return jwkSet ;
534
+ }
535
+
536
+ private String getFirstSpecifiedKeyID (JWKMatcher jwkMatcher ) {
537
+ Set <String > keyIDs = jwkMatcher .getKeyIDs ();
538
+ return (keyIDs == null || keyIDs .isEmpty ()) ?
539
+ null : keyIDs .stream ().filter (id -> id != null ).findFirst ().orElse (null );
540
+ }
570
541
}
571
542
572
543
}
0 commit comments