@@ -57,6 +57,11 @@ public class KeySetRetriever implements KeySetProvider {
5757 private long lastCacheStatusLog = 0 ;
5858 private String jwksUri ;
5959
60+ // Security validation settings (optional, for JWKS endpoints)
61+ private long maxResponseSizeBytes = -1 ; // -1 means no limit
62+ private int maxKeyCount = -1 ; // -1 means no limit
63+ private boolean enableSecurityValidation = false ;
64+
6065 KeySetRetriever (String openIdConnectEndpoint , SSLConfig sslConfig , boolean useCacheForOidConnectEndpoint ) {
6166 this .openIdConnectEndpoint = openIdConnectEndpoint ;
6267 this .sslConfig = sslConfig ;
@@ -71,10 +76,41 @@ public class KeySetRetriever implements KeySetProvider {
7176 configureCache (useCacheForOidConnectEndpoint );
7277 }
7378
79+ /**
80+ * Factory method to create a KeySetRetriever for JWKS endpoint access.
81+ * This method provides a public API for creating KeySetRetriever instances
82+ * with built-in security validation to protect against malicious JWKS endpoints.
83+ *
84+ * @param sslConfig SSL configuration for HTTPS connections
85+ * @param useCacheForJwksEndpoint whether to enable caching for JWKS endpoint
86+ * When true, JWKS responses will be cached to improve performance
87+ * and reduce network calls to the JWKS endpoint.
88+ * @param jwksUri the JWKS endpoint URI
89+ * @param maxResponseSizeBytes maximum allowed HTTP response size in bytes
90+ * @param maxKeyCount maximum number of keys allowed in JWKS
91+ * @return a new KeySetRetriever instance with security validation enabled
92+ */
93+ public static KeySetRetriever createForJwksUri (
94+ SSLConfig sslConfig ,
95+ boolean useCacheForJwksEndpoint ,
96+ String jwksUri ,
97+ long maxResponseSizeBytes ,
98+ int maxKeyCount
99+ ) {
100+ KeySetRetriever retriever = new KeySetRetriever (sslConfig , useCacheForJwksEndpoint , jwksUri );
101+ retriever .enableSecurityValidation = true ;
102+ retriever .maxResponseSizeBytes = maxResponseSizeBytes ;
103+ retriever .maxKeyCount = maxKeyCount ;
104+ return retriever ;
105+ }
106+
74107 public JWKSet get () throws AuthenticatorUnavailableException {
75108 String uri = getJwksUri ();
76109
77- try (CloseableHttpClient httpClient = createHttpClient (null )) {
110+ // Use cache storage if it's configured
111+ HttpCacheStorage cacheStorage = oidcHttpCacheStorage ;
112+
113+ try (CloseableHttpClient httpClient = createHttpClient (cacheStorage )) {
78114
79115 HttpGet httpGet = new HttpGet (uri );
80116
@@ -85,7 +121,20 @@ public JWKSet get() throws AuthenticatorUnavailableException {
85121
86122 httpGet .setConfig (requestConfig );
87123
88- try (CloseableHttpResponse response = httpClient .execute (httpGet )) {
124+ // Configure HTTP client to only accept JSON responses for JWKS endpoints
125+ if (enableSecurityValidation ) {
126+ httpGet .setHeader ("Accept" , "application/json, application/jwk-set+json" );
127+ }
128+
129+ HttpCacheContext httpContext = null ;
130+ if (cacheStorage != null ) {
131+ httpContext = new HttpCacheContext ();
132+ }
133+
134+ try (CloseableHttpResponse response = httpClient .execute (httpGet , httpContext )) {
135+ if (httpContext != null ) {
136+ logCacheResponseStatus (httpContext , true );
137+ }
89138 if (response .getCode () < 200 || response .getCode () >= 300 ) {
90139 throw new AuthenticatorUnavailableException ("Error while getting " + uri + ": " + response .getReasonPhrase ());
91140 }
@@ -95,11 +144,41 @@ public JWKSet get() throws AuthenticatorUnavailableException {
95144 if (httpEntity == null ) {
96145 throw new AuthenticatorUnavailableException ("Error while getting " + uri + ": Empty response entity" );
97146 }
147+
148+ // Apply security validation if enabled (for JWKS endpoints)
149+ if (enableSecurityValidation ) {
150+ // Validate response size
151+ if (maxResponseSizeBytes > 0 ) {
152+ long contentLength = httpEntity .getContentLength ();
153+ if (contentLength > maxResponseSizeBytes ) {
154+ throw new AuthenticatorUnavailableException (
155+ String .format (
156+ "JWKS response too large from %s: %d bytes (max: %d)" ,
157+ uri ,
158+ contentLength ,
159+ maxResponseSizeBytes
160+ )
161+ );
162+ }
163+ }
164+ }
165+
166+ // Load JWKS using Nimbus JOSE (handles JSON parsing and validation)
98167 JWKSet keySet = JWKSet .load (httpEntity .getContent ());
99168
169+ // Apply minimal additional validation only for direct JWKS endpoints
170+ if (enableSecurityValidation ) {
171+ // Simple key count validation - HARD LIMIT
172+ if (maxKeyCount > 0 && keySet .getKeys ().size () > maxKeyCount ) {
173+ throw new AuthenticatorUnavailableException (
174+ String .format ("JWKS from %s contains %d keys, but max allowed is %d" , uri , keySet .getKeys ().size (), maxKeyCount )
175+ );
176+ }
177+ }
178+
100179 return keySet ;
101180 } catch (ParseException e ) {
102- throw new RuntimeException ( e );
181+ throw new AuthenticatorUnavailableException ( "Error parsing JWKS from " + uri + ": " + e . getMessage (), e );
103182 }
104183 } catch (IOException e ) {
105184 throw new AuthenticatorUnavailableException ("Error while getting " + uri + ": " + e , e );
@@ -177,21 +256,43 @@ public void setRequestTimeoutMs(int httpTimeoutMs) {
177256 }
178257
179258 private void logCacheResponseStatus (HttpCacheContext httpContext ) {
259+ logCacheResponseStatus (httpContext , false );
260+ }
261+
262+ private void logCacheResponseStatus (HttpCacheContext httpContext , boolean isJwksRequest ) {
180263 this .oidcRequests ++;
181264
182- switch (httpContext .getCacheResponseStatus ()) {
183- case CACHE_HIT :
184- this .oidcCacheHits ++;
185- break ;
186- case CACHE_MODULE_RESPONSE :
187- this .oidcCacheModuleResponses ++;
188- break ;
189- case CACHE_MISS :
265+ // Handle cache statistics based on the response status
266+ // For OIDC discovery flow, only count the JWKS request (not the discovery request)
267+ // For direct JWKS URI, count all requests
268+ boolean shouldCountStats = (jwksUri != null ) || isJwksRequest ;
269+
270+ if (!shouldCountStats ) {
271+ log .debug ("Skipping cache statistics for OIDC discovery request #{}" , this .oidcRequests );
272+ return ;
273+ }
274+
275+ if (httpContext .getCacheResponseStatus () == null ) {
276+ if (oidcHttpCacheStorage != null ) {
190277 this .oidcCacheMisses ++;
191- break ;
192- case VALIDATED :
193- this .oidcCacheHitsValidated ++;
194- break ;
278+ log .debug ("Null cache status - counting as cache miss. Total misses: {}" , this .oidcCacheMisses );
279+ }
280+ } else {
281+ switch (httpContext .getCacheResponseStatus ()) {
282+ case CACHE_HIT :
283+ this .oidcCacheHits ++;
284+ break ;
285+ case CACHE_MODULE_RESPONSE :
286+ this .oidcCacheModuleResponses ++;
287+ break ;
288+ case CACHE_MISS :
289+ this .oidcCacheMisses ++;
290+ break ;
291+ case VALIDATED :
292+ this .oidcCacheHits ++;
293+ this .oidcCacheHitsValidated ++;
294+ break ;
295+ }
195296 }
196297
197298 long now = System .currentTimeMillis ();
@@ -208,7 +309,6 @@ private void logCacheResponseStatus(HttpCacheContext httpContext) {
208309 );
209310 lastCacheStatusLog = now ;
210311 }
211-
212312 }
213313
214314 private CloseableHttpClient createHttpClient (HttpCacheStorage httpCacheStorage ) {
@@ -255,4 +355,5 @@ public int getOidcCacheHitsValidated() {
255355 public int getOidcCacheModuleResponses () {
256356 return oidcCacheModuleResponses ;
257357 }
358+
258359}
0 commit comments