Skip to content

Commit 052836e

Browse files
[FEATURE] usage of JWKS with JWT (w/o OpenID connect)
Signed-off-by: Sebastian Michalski <shekerama@gmail.com>
1 parent a580dfc commit 052836e

File tree

7 files changed

+215
-27
lines changed

7 files changed

+215
-27
lines changed

src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ public abstract class AbstractHTTPJwtAuthenticator implements HTTPAuthenticator
5555
private final String jwtUrlParameter;
5656
private final String subjectKey;
5757
private final String rolesKey;
58+
private final String requiredAudience;
59+
private final String requiredIssuer;
5860

5961
public static final int DEFAULT_CLOCK_SKEW_TOLERANCE_SECONDS = 30;
6062
private final int clockSkewToleranceSeconds ;
@@ -66,10 +68,12 @@ public AbstractHTTPJwtAuthenticator(Settings settings, Path configPath) {
6668
rolesKey = settings.get("roles_key");
6769
subjectKey = settings.get("subject_key");
6870
clockSkewToleranceSeconds = settings.getAsInt("jwt_clock_skew_tolerance_seconds", DEFAULT_CLOCK_SKEW_TOLERANCE_SECONDS);
71+
requiredAudience = settings.get("required_audience");
72+
requiredIssuer = settings.get("required_issuer");
6973

7074
try {
7175
this.keyProvider = this.initKeyProvider(settings, configPath);
72-
jwtVerifier = new JwtVerifier(keyProvider, clockSkewToleranceSeconds );
76+
jwtVerifier = new JwtVerifier(keyProvider, clockSkewToleranceSeconds, requiredIssuer, requiredAudience);
7377

7478
} catch (Exception e) {
7579
log.error("Error creating JWT authenticator. JWT authentication will not work", e);
@@ -233,4 +237,12 @@ public boolean reRequestAuthentication(RestChannel channel, AuthCredentials auth
233237
return true;
234238
}
235239

240+
public String getRequiredAudience() {
241+
return requiredAudience;
242+
}
243+
244+
public String getRequiredIssuer() {
245+
return requiredIssuer;
246+
}
247+
236248
}

src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticator.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,15 @@ protected KeyProvider initKeyProvider(Settings settings, Path configPath) throws
3232

3333
int refreshRateLimitTimeWindowMs = settings.getAsInt("refresh_rate_limit_time_window_ms", 10000);
3434
int refreshRateLimitCount = settings.getAsInt("refresh_rate_limit_count", 10);
35-
36-
KeySetRetriever keySetRetriever = new KeySetRetriever(settings.get("openid_connect_url"),
37-
getSSLConfig(settings, configPath), settings.getAsBoolean("cache_jwks_endpoint", false));
35+
var jwksUri = settings.get("jwks_uri");
36+
37+
KeySetRetriever keySetRetriever;
38+
if(jwksUri != null && !jwksUri.isBlank()) {
39+
keySetRetriever =
40+
new KeySetRetriever(getSSLConfig(settings, configPath), settings.getAsBoolean("cache_jwks_endpoint", false), jwksUri);
41+
} else {
42+
keySetRetriever = new KeySetRetriever(settings.get("openid_connect_url"), getSSLConfig(settings, configPath), settings.getAsBoolean("cache_jwks_endpoint", false));
43+
}
3844

3945
keySetRetriever.setRequestTimeoutMs(idpRequestTimeoutMs);
4046

src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/JwtVerifier.java

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,14 @@ public class JwtVerifier {
3333

3434
private final KeyProvider keyProvider;
3535
private final int clockSkewToleranceSeconds;
36-
37-
public JwtVerifier(KeyProvider keyProvider, int clockSkewToleranceSeconds ) {
36+
private final String requiredIssuer;
37+
private final String requiredAudience;
38+
39+
public JwtVerifier(KeyProvider keyProvider, int clockSkewToleranceSeconds, String requiredIssuer, String requiredAudience) {
3840
this.keyProvider = keyProvider;
3941
this.clockSkewToleranceSeconds = clockSkewToleranceSeconds;
42+
this.requiredIssuer = requiredIssuer;
43+
this.requiredAudience = requiredAudience;
4044
}
4145

4246
public JwtToken getVerifiedJwtToken(String encodedJwt) throws BadCredentialsException {
@@ -112,6 +116,20 @@ private void validateClaims(JwtToken jwt) throws BadCredentialsException, JwtExc
112116
if (claims != null) {
113117
JwtUtils.validateJwtExpiry(claims, clockSkewToleranceSeconds, false);
114118
JwtUtils.validateJwtNotBefore(claims, clockSkewToleranceSeconds, false);
119+
validateRequiredAudienceAndIssuer(claims);
120+
}
121+
}
122+
123+
private void validateRequiredAudienceAndIssuer(JwtClaims claims) {
124+
String audience = claims.getAudience();
125+
String issuer = claims.getIssuer();
126+
127+
if (!audience.equals(requiredAudience)) {
128+
throw new JwtException("Invalid issuer");
129+
}
130+
131+
if (!issuer.equals(requiredIssuer)) {
132+
throw new JwtException("Invalid issuer");
115133
}
116134
}
117135
}

src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/KeySetRetriever.java

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,20 @@ public class KeySetRetriever implements KeySetProvider {
5454
private int oidcCacheModuleResponses = 0;
5555
private long oidcRequests = 0;
5656
private long lastCacheStatusLog = 0;
57+
private String jwksUri;
5758

5859
KeySetRetriever(String openIdConnectEndpoint, SSLConfig sslConfig, boolean useCacheForOidConnectEndpoint) {
5960
this.openIdConnectEndpoint = openIdConnectEndpoint;
6061
this.sslConfig = sslConfig;
6162

62-
if (useCacheForOidConnectEndpoint) {
63-
cacheConfig = CacheConfig.custom().setMaxCacheEntries(10).setMaxObjectSize(1024L * 1024L).build();
64-
oidcHttpCacheStorage = new BasicHttpCacheStorage(cacheConfig);
65-
}
63+
configureCache(useCacheForOidConnectEndpoint);
64+
}
65+
66+
KeySetRetriever(SSLConfig sslConfig, boolean useCacheForOidConnectEndpoint, String jwksUri) {
67+
this.jwksUri = jwksUri;
68+
this.sslConfig = sslConfig;
69+
70+
configureCache(useCacheForOidConnectEndpoint);
6671
}
6772

6873
public JsonWebKeys get() throws AuthenticatorUnavailableException {
@@ -101,6 +106,10 @@ public JsonWebKeys get() throws AuthenticatorUnavailableException {
101106

102107
String getJwksUri() throws AuthenticatorUnavailableException {
103108

109+
if (jwksUri != null && !jwksUri.isBlank()) {
110+
return jwksUri;
111+
}
112+
104113
try (CloseableHttpClient httpClient = createHttpClient(oidcHttpCacheStorage)) {
105114

106115
HttpGet httpGet = new HttpGet(openIdConnectEndpoint);
@@ -204,6 +213,13 @@ private CloseableHttpClient createHttpClient(HttpCacheStorage httpCacheStorage)
204213
return builder.build();
205214
}
206215

216+
private void configureCache(boolean useCacheForOidConnectEndpoint) {
217+
if (useCacheForOidConnectEndpoint) {
218+
cacheConfig = CacheConfig.custom().setMaxCacheEntries(10).setMaxObjectSize(1024L * 1024L).build();
219+
oidcHttpCacheStorage = new BasicHttpCacheStorage(cacheConfig);
220+
}
221+
}
222+
207223
public int getOidcCacheHits() {
208224
return oidcCacheHits;
209225
}

src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java

Lines changed: 137 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.junit.BeforeClass;
1919
import org.junit.Test;
2020

21+
import org.opensearch.OpenSearchSecurityException;
2122
import org.opensearch.common.settings.Settings;
2223
import org.opensearch.security.user.AuthCredentials;
2324
import org.opensearch.security.util.FakeRestRequest;
@@ -44,7 +45,11 @@ public static void tearDown() {
4445

4546
@Test
4647
public void basicTest() {
47-
Settings settings = Settings.builder().put("openid_connect_url", mockIdpServer.getDiscoverUri()).build();
48+
Settings settings = Settings.builder()
49+
.put("openid_connect_url", mockIdpServer.getDiscoverUri())
50+
.put("required_issuer", TestJwts.TEST_ISSUER)
51+
.put("required_audience", TestJwts.TEST_AUDIENCE)
52+
.build();
4853

4954
HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null);
5055

@@ -55,12 +60,110 @@ public void basicTest() {
5560
Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername());
5661
Assert.assertEquals(TestJwts.TEST_AUDIENCE, creds.getAttributes().get("attr.jwt.aud"));
5762
Assert.assertEquals(0, creds.getBackendRoles().size());
58-
Assert.assertEquals(3, creds.getAttributes().size());
63+
Assert.assertEquals(4, creds.getAttributes().size());
64+
}
65+
66+
67+
@Test
68+
public void jwksUriTest() {
69+
Settings settings = Settings.builder()
70+
.put("jwks_uri", mockIdpServer.getJwksUri())
71+
.put("required_issuer", TestJwts.TEST_ISSUER)
72+
.put("required_audience", TestJwts.TEST_AUDIENCE)
73+
.build();
74+
75+
HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null);
76+
77+
AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(
78+
ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_2), new HashMap<>()), null);
79+
80+
Assert.assertNotNull(creds);
81+
Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername());
82+
Assert.assertEquals(TestJwts.TEST_AUDIENCE, creds.getAttributes().get("attr.jwt.aud"));
83+
Assert.assertEquals(0, creds.getBackendRoles().size());
84+
Assert.assertEquals(4, creds.getAttributes().size());
85+
}
86+
87+
@Test
88+
public void jwksMissingRequiredIssuerInClaimTest() {
89+
Settings settings = Settings.builder()
90+
.put("jwks_uri", mockIdpServer.getJwksUri())
91+
.put("required_audience", TestJwts.TEST_AUDIENCE)
92+
.build();
93+
94+
HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null);
95+
96+
AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(
97+
ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_2), new HashMap<>()), null);
98+
99+
Assert.assertNull(creds);
100+
}
101+
102+
@Test
103+
public void jwksNotMatchingRequiredIssuerInClaimTest() {
104+
Settings settings = Settings.builder()
105+
.put("jwks_uri", mockIdpServer.getJwksUri())
106+
.put("required_issuer", "Wrong Issuer")
107+
.build();
108+
109+
HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null);
110+
111+
AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(
112+
ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_2), new HashMap<>()), null);
113+
114+
Assert.assertNull(creds);
115+
}
116+
117+
@Test
118+
public void jwksMissingRequiredAudienceInClaimTest() {
119+
Settings settings = Settings.builder()
120+
.put("jwks_uri", mockIdpServer.getJwksUri())
121+
.put("required_issuer", TestJwts.TEST_ISSUER)
122+
.build();
123+
124+
HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null);
125+
126+
AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(
127+
ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_2), new HashMap<>()), null);
128+
129+
Assert.assertNull(creds);
130+
}
131+
132+
@Test
133+
public void jwksNotMatchingRequiredAudienceInClaimTest() {
134+
Settings settings = Settings.builder()
135+
.put("jwks_uri", mockIdpServer.getJwksUri())
136+
.put("required_audience", "Wrong Audience")
137+
.build();
138+
139+
HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null);
140+
141+
AuthCredentials creds = jwtAuth.extractCredentials(new FakeRestRequest(
142+
ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_2), new HashMap<>()), null);
143+
144+
Assert.assertNull(creds);
145+
}
146+
147+
@Test
148+
public void jwksUriMissingTest() {
149+
var exception = Assert.assertThrows(Exception.class, () -> {
150+
HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(Settings.builder().build(), null);
151+
jwtAuth.extractCredentials(
152+
new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1), new HashMap<>()),
153+
null);
154+
});
155+
156+
Assert.assertEquals("Authentication backend failed", exception.getMessage());
157+
Assert.assertEquals(OpenSearchSecurityException.class, exception.getClass());
59158
}
60159

61160
@Test
62161
public void testEscapeKid() {
63-
Settings settings = Settings.builder().put("openid_connect_url", mockIdpServer.getDiscoverUri()).build();
162+
Settings settings = Settings.builder()
163+
.put("openid_connect_url", mockIdpServer.getDiscoverUri())
164+
.put("required_issuer", TestJwts.TEST_ISSUER)
165+
.put("required_audience", TestJwts.TEST_AUDIENCE)
166+
.build();
64167

65168
HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null);
66169

@@ -71,12 +174,16 @@ public void testEscapeKid() {
71174
Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername());
72175
Assert.assertEquals(TestJwts.TEST_AUDIENCE, creds.getAttributes().get("attr.jwt.aud"));
73176
Assert.assertEquals(0, creds.getBackendRoles().size());
74-
Assert.assertEquals(3, creds.getAttributes().size());
177+
Assert.assertEquals(4, creds.getAttributes().size());
75178
}
76179

77180
@Test
78181
public void bearerTest() {
79-
Settings settings = Settings.builder().put("openid_connect_url", mockIdpServer.getDiscoverUri()).build();
182+
Settings settings = Settings.builder()
183+
.put("openid_connect_url", mockIdpServer.getDiscoverUri())
184+
.put("required_issuer", TestJwts.TEST_ISSUER)
185+
.put("required_audience", TestJwts.TEST_AUDIENCE)
186+
.build();
80187

81188
HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null);
82189

@@ -89,13 +196,17 @@ public void bearerTest() {
89196
Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername());
90197
Assert.assertEquals(TestJwts.TEST_AUDIENCE, creds.getAttributes().get("attr.jwt.aud"));
91198
Assert.assertEquals(0, creds.getBackendRoles().size());
92-
Assert.assertEquals(3, creds.getAttributes().size());
199+
Assert.assertEquals(4, creds.getAttributes().size());
93200
}
94201

95202
@Test
96203
public void testRoles() throws Exception {
97-
Settings settings = Settings.builder().put("openid_connect_url", mockIdpServer.getDiscoverUri())
98-
.put("roles_key", TestJwts.ROLES_CLAIM).build();
204+
Settings settings = Settings.builder()
205+
.put("openid_connect_url", mockIdpServer.getDiscoverUri())
206+
.put("roles_key", TestJwts.ROLES_CLAIM)
207+
.put("required_issuer", TestJwts.TEST_ISSUER)
208+
.put("required_audience", TestJwts.TEST_AUDIENCE)
209+
.build();
99210

100211
HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null);
101212

@@ -126,6 +237,8 @@ public void testExpInSkew() throws Exception {
126237
Settings settings = Settings.builder()
127238
.put("openid_connect_url", mockIdpServer.getDiscoverUri())
128239
.put("jwt_clock_skew_tolerance_seconds", "10")
240+
.put("required_issuer", TestJwts.TEST_ISSUER)
241+
.put("required_audience", TestJwts.TEST_AUDIENCE)
129242
.build();
130243

131244
HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null);
@@ -149,6 +262,8 @@ public void testNbf() throws Exception {
149262
Settings settings = Settings.builder()
150263
.put("openid_connect_url", mockIdpServer.getDiscoverUri())
151264
.put("jwt_clock_skew_tolerance_seconds", "0")
265+
.put("required_issuer", TestJwts.TEST_ISSUER)
266+
.put("required_audience", TestJwts.TEST_AUDIENCE)
152267
.build();
153268

154269
HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null);
@@ -172,6 +287,8 @@ public void testNbfInSkew() throws Exception {
172287
Settings settings = Settings.builder()
173288
.put("openid_connect_url", mockIdpServer.getDiscoverUri())
174289
.put("jwt_clock_skew_tolerance_seconds", "10")
290+
.put("required_issuer", TestJwts.TEST_ISSUER)
291+
.put("required_audience", TestJwts.TEST_AUDIENCE)
175292
.build();
176293

177294
HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null);
@@ -192,7 +309,11 @@ public void testNbfInSkew() throws Exception {
192309
@Test
193310
public void testRS256() throws Exception {
194311

195-
Settings settings = Settings.builder().put("openid_connect_url", mockIdpServer.getDiscoverUri()).build();
312+
Settings settings = Settings.builder()
313+
.put("openid_connect_url", mockIdpServer.getDiscoverUri())
314+
.put("required_issuer", TestJwts.TEST_ISSUER)
315+
.put("required_audience", TestJwts.TEST_AUDIENCE)
316+
.build();
196317

197318
HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null);
198319

@@ -203,7 +324,7 @@ public void testRS256() throws Exception {
203324
Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername());
204325
Assert.assertEquals(TestJwts.TEST_AUDIENCE, creds.getAttributes().get("attr.jwt.aud"));
205326
Assert.assertEquals(0, creds.getBackendRoles().size());
206-
Assert.assertEquals(3, creds.getAttributes().size());
327+
Assert.assertEquals(4, creds.getAttributes().size());
207328
}
208329

209330
@Test
@@ -221,7 +342,11 @@ public void testBadSignature() throws Exception {
221342

222343
@Test
223344
public void testPeculiarJsonEscaping() {
224-
Settings settings = Settings.builder().put("openid_connect_url", mockIdpServer.getDiscoverUri()).build();
345+
Settings settings = Settings.builder()
346+
.put("openid_connect_url", mockIdpServer.getDiscoverUri())
347+
.put("required_issuer", TestJwts.TEST_ISSUER)
348+
.put("required_audience", TestJwts.TEST_AUDIENCE)
349+
.build();
225350

226351
HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null);
227352

@@ -233,7 +358,7 @@ public void testPeculiarJsonEscaping() {
233358
Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername());
234359
Assert.assertEquals(TestJwts.TEST_AUDIENCE, creds.getAttributes().get("attr.jwt.aud"));
235360
Assert.assertEquals(0, creds.getBackendRoles().size());
236-
Assert.assertEquals(3, creds.getAttributes().size());
361+
Assert.assertEquals(4, creds.getAttributes().size());
237362
}
238363

239364
}

src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/MockIpdServer.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ public String getDiscoverUri() {
118118
return uri + CTX_DISCOVER;
119119
}
120120

121+
public String getJwksUri() {
122+
return uri + CTX_KEYS;
123+
}
124+
121125
public int getPort() {
122126
return port;
123127
}

0 commit comments

Comments
 (0)