Skip to content

Commit 11f2364

Browse files
authored
Merge branch 'main' into search-relevance
Signed-off-by: Craig Perkins <cwperx@amazon.com>
2 parents 6bb2b64 + 228744a commit 11f2364

File tree

10 files changed

+292
-42
lines changed

10 files changed

+292
-42
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
99
- Github workflow for changelog verification ([#5318](https://github.com/opensearch-project/security/pull/5318))
1010
- Register cluster settings listener for `plugins.security.cache.ttl_minutes` ([#5324](https://github.com/opensearch-project/security/pull/5324))
1111
- Add flush cache endpoint for individual user ([#5337](https://github.com/opensearch-project/security/pull/5337))
12+
- Handle roles in nested claim for JWT auth backends ([#5355](https://github.com/opensearch-project/security/pull/5355))
1213
- Integrate search-relevance functionalities with security plugin ([#5376](https://github.com/opensearch-project/security/pull/5376))
1314

1415
### Changed
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* The OpenSearch Contributors require contributions made to
6+
* this file be licensed under the Apache-2.0 license or a
7+
* compatible open source license.
8+
*
9+
*/
10+
package org.opensearch.security.http;
11+
12+
import java.security.KeyPair;
13+
import java.util.Arrays;
14+
import java.util.Base64;
15+
import java.util.HashMap;
16+
import java.util.List;
17+
import java.util.Map;
18+
19+
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
20+
import org.apache.hc.core5.http.Header;
21+
import org.junit.ClassRule;
22+
import org.junit.Rule;
23+
import org.junit.Test;
24+
import org.junit.runner.RunWith;
25+
26+
import org.opensearch.test.framework.JwtConfigBuilder;
27+
import org.opensearch.test.framework.TestSecurityConfig;
28+
import org.opensearch.test.framework.cluster.ClusterManager;
29+
import org.opensearch.test.framework.cluster.LocalCluster;
30+
import org.opensearch.test.framework.cluster.TestRestClient;
31+
import org.opensearch.test.framework.cluster.TestRestClient.HttpResponse;
32+
import org.opensearch.test.framework.log.LogsRule;
33+
34+
import io.jsonwebtoken.SignatureAlgorithm;
35+
import io.jsonwebtoken.security.Keys;
36+
37+
import static java.nio.charset.StandardCharsets.US_ASCII;
38+
import static org.hamcrest.MatcherAssert.assertThat;
39+
import static org.hamcrest.Matchers.containsInAnyOrder;
40+
import static org.hamcrest.Matchers.equalTo;
41+
import static org.hamcrest.Matchers.hasSize;
42+
import static org.opensearch.security.http.JwtAuthenticationTests.POINTER_BACKEND_ROLES;
43+
import static org.opensearch.security.http.JwtAuthenticationTests.POINTER_USERNAME;
44+
import static org.opensearch.test.framework.TestSecurityConfig.AuthcDomain.BASIC_AUTH_DOMAIN_ORDER;
45+
46+
@RunWith(com.carrotsearch.randomizedtesting.RandomizedRunner.class)
47+
@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
48+
public class JwtAuthenticationNestedClaimsTests {
49+
50+
public static final String CLAIM_USERNAME = "preferred-username";
51+
public static final List<String> CLAIM_ROLES = List.of("attributes", "roles");
52+
53+
public static final String USER_SUPERHERO = "superhero";
54+
private static final KeyPair KEY_PAIR1 = Keys.keyPairFor(SignatureAlgorithm.RS256);
55+
private static final String PUBLIC_KEY1 = new String(Base64.getEncoder().encode(KEY_PAIR1.getPublic().getEncoded()), US_ASCII);
56+
private static final String JWT_AUTH_HEADER = "jwt-auth";
57+
58+
private static final JwtAuthorizationHeaderFactory tokenFactory1 = new JwtAuthorizationHeaderFactory(
59+
KEY_PAIR1.getPrivate(),
60+
CLAIM_USERNAME,
61+
CLAIM_ROLES,
62+
JWT_AUTH_HEADER
63+
);
64+
public static final TestSecurityConfig.AuthcDomain JWT_AUTH_DOMAIN = new TestSecurityConfig.AuthcDomain(
65+
"jwt",
66+
BASIC_AUTH_DOMAIN_ORDER - 1
67+
).jwtHttpAuthenticator(
68+
new JwtConfigBuilder().jwtHeader(JWT_AUTH_HEADER).signingKey(List.of(PUBLIC_KEY1)).subjectKey(CLAIM_USERNAME).rolesKey(CLAIM_ROLES)
69+
).backend("noop");
70+
71+
@ClassRule
72+
public static final LocalCluster cluster = new LocalCluster.Builder().clusterManager(ClusterManager.SINGLENODE)
73+
.anonymousAuth(false)
74+
.authc(JWT_AUTH_DOMAIN)
75+
.build();
76+
77+
@Rule
78+
public LogsRule logsRule = new LogsRule("org.opensearch.security.auth.http.jwt.HTTPJwtAuthenticator");
79+
80+
// TODO write tests for scenarios where roles are in nested claim. i.e. rolesKey: ['attributes', 'roles']
81+
@Test
82+
public void shouldAuthenticateWithNestedRolesClaim() {
83+
// Create nested claims structure
84+
Map<String, Object> attributes = new HashMap<>();
85+
List<String> rolesClaim = Arrays.asList("all_access", "securitymanager");
86+
attributes.put("roles", rolesClaim);
87+
88+
Map<String, Object> nestedClaims = new HashMap<>();
89+
nestedClaims.put("attributes", attributes);
90+
91+
// Generate token with nested claims
92+
Header header = tokenFactory1.generateValidTokenWithCustomClaims(USER_SUPERHERO, null, nestedClaims);
93+
94+
try (TestRestClient client = cluster.getRestClient(header)) {
95+
HttpResponse response = client.getAuthInfo();
96+
97+
response.assertStatusCode(200);
98+
String username = response.getTextFromJsonBody(POINTER_USERNAME);
99+
assertThat(username, equalTo(USER_SUPERHERO));
100+
List<String> roles = response.getTextArrayFromJsonBody(POINTER_BACKEND_ROLES);
101+
assertThat(roles, hasSize(2));
102+
assertThat(roles, containsInAnyOrder("all_access", "securitymanager"));
103+
}
104+
}
105+
106+
@Test
107+
public void shouldHandleMissingNestedRolesClaim() {
108+
// Create invalid nested claims structure
109+
Map<String, Object> attributes = new HashMap<>();
110+
attributes.put("wrong", "missing"); // Invalid format - should be a list
111+
112+
Map<String, Object> nestedClaims = new HashMap<>();
113+
nestedClaims.put("attributes", attributes);
114+
115+
Header header = tokenFactory1.generateValidTokenWithCustomClaims(USER_SUPERHERO, null, nestedClaims);
116+
117+
try (TestRestClient client = cluster.getRestClient(header)) {
118+
HttpResponse response = client.getAuthInfo();
119+
120+
response.assertStatusCode(200);
121+
String username = response.getTextFromJsonBody(POINTER_USERNAME);
122+
assertThat(username, equalTo(USER_SUPERHERO));
123+
List<String> roles = response.getTextArrayFromJsonBody(POINTER_BACKEND_ROLES);
124+
assertThat(roles, hasSize(0));
125+
}
126+
}
127+
}

src/integrationTest/java/org/opensearch/security/http/JwtAuthenticationTests.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
public class JwtAuthenticationTests {
7070

7171
public static final String CLAIM_USERNAME = "preferred-username";
72-
public static final String CLAIM_ROLES = "backend-user-roles";
72+
public static final List<String> CLAIM_ROLES = List.of("backend-user-roles");
7373

7474
public static final String USER_SUPERHERO = "superhero";
7575
public static final String USERNAME_ROOT = "root";
@@ -305,5 +305,4 @@ public void secondKeypairShouldAuthenticateWithJwtToken_positiveWithAnotherUsern
305305
assertThat(username, equalTo(USERNAME_ROOT));
306306
}
307307
}
308-
309308
}

src/integrationTest/java/org/opensearch/security/http/JwtAuthenticationWithUrlParamTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
public class JwtAuthenticationWithUrlParamTests {
5252

5353
public static final String CLAIM_USERNAME = "preferred-username";
54-
public static final String CLAIM_ROLES = "backend-user-roles";
54+
public static final List<String> CLAIM_ROLES = List.of("backend-user-roles");
5555
public static final String POINTER_USERNAME = "/user_name";
5656

5757
private static final KeyPair KEY_PAIR = Keys.keyPairFor(SignatureAlgorithm.RS256);

src/integrationTest/java/org/opensearch/security/http/JwtAuthorizationHeaderFactory.java

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,10 @@
1010
package org.opensearch.security.http;
1111

1212
import java.security.PrivateKey;
13-
import java.util.Arrays;
1413
import java.util.Date;
1514
import java.util.HashMap;
15+
import java.util.List;
1616
import java.util.Map;
17-
import java.util.stream.Collectors;
1817

1918
import com.google.common.collect.ImmutableMap;
2019
import org.apache.commons.lang3.StringUtils;
@@ -33,11 +32,11 @@ class JwtAuthorizationHeaderFactory {
3332

3433
private final String usernameClaimName;
3534

36-
private final String rolesClaimName;
35+
private final List<String> rolesClaimName;
3736

3837
private final String headerName;
3938

40-
public JwtAuthorizationHeaderFactory(PrivateKey privateKey, String usernameClaimName, String rolesClaimName, String headerName) {
39+
public JwtAuthorizationHeaderFactory(PrivateKey privateKey, String usernameClaimName, List<String> rolesClaimName, String headerName) {
4140
this.privateKey = requireNonNull(privateKey, "Private key is required");
4241
this.usernameClaimName = requireNonNull(usernameClaimName, "Username claim name is required");
4342
this.rolesClaimName = requireNonNull(rolesClaimName, "Roles claim name is required.");
@@ -64,8 +63,28 @@ private Map<String, Object> customClaimsMap(String username, String[] roles) {
6463
if (StringUtils.isNoneEmpty(username)) {
6564
builder.put(usernameClaimName, username);
6665
}
67-
if ((roles != null) && (roles.length > 0)) {
68-
builder.put(rolesClaimName, Arrays.stream(roles).collect(Collectors.joining(",")));
66+
if (roles != null && roles.length > 0) {
67+
if (rolesClaimName.size() == 1) {
68+
// Simple case - no nesting
69+
builder.put(rolesClaimName.get(0), String.join(",", roles));
70+
} else {
71+
// Handle nested claims
72+
Map<String, Object> nestedMap = new HashMap<>();
73+
Map<String, Object> currentMap = nestedMap;
74+
75+
// Build the nested structure
76+
for (int i = 0; i < rolesClaimName.size() - 1; i++) {
77+
Map<String, Object> nextMap = new HashMap<>();
78+
currentMap.put(rolesClaimName.get(i), nextMap);
79+
currentMap = nextMap;
80+
}
81+
82+
// Add the roles array at the deepest level
83+
currentMap.put(rolesClaimName.get(rolesClaimName.size() - 1), String.join(",", roles));
84+
85+
// Add the entire nested structure to the builder
86+
builder.putAll(nestedMap);
87+
}
6988
}
7089
return builder.build();
7190
}

src/integrationTest/java/org/opensearch/test/framework/JwtConfigBuilder.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ public class JwtConfigBuilder {
2222
private String jwtUrlParameter;
2323
private List<String> signingKeys;
2424
private String subjectKey;
25-
private String rolesKey;
25+
private List<String> rolesKey;
2626

2727
public JwtConfigBuilder jwtHeader(String jwtHeader) {
2828
this.jwtHeader = jwtHeader;
@@ -45,6 +45,11 @@ public JwtConfigBuilder subjectKey(String subjectKey) {
4545
}
4646

4747
public JwtConfigBuilder rolesKey(String rolesKey) {
48+
this.rolesKey = List.of(rolesKey);
49+
return this;
50+
}
51+
52+
public JwtConfigBuilder rolesKey(List<String> rolesKey) {
4853
this.rolesKey = rolesKey;
4954
return this;
5055
}
@@ -64,7 +69,7 @@ public Map<String, Object> build() {
6469
if (isNoneBlank(subjectKey)) {
6570
builder.put("subject_key", subjectKey);
6671
}
67-
if (isNoneBlank(rolesKey)) {
72+
if (rolesKey != null && !rolesKey.isEmpty()) {
6873
builder.put("roles_key", rolesKey);
6974
}
7075
return builder.build();

src/main/java/org/opensearch/security/auth/http/jwt/AbstractHTTPJwtAuthenticator.java

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public abstract class AbstractHTTPJwtAuthenticator implements HTTPAuthenticator
6161
private final boolean isDefaultAuthHeader;
6262
private final String jwtUrlParameter;
6363
private final String subjectKey;
64-
private final String rolesKey;
64+
private final List<String> rolesKey;
6565
private final List<String> requiredAudience;
6666
private final String requiredIssuer;
6767

@@ -72,7 +72,7 @@ public AbstractHTTPJwtAuthenticator(Settings settings, Path configPath) {
7272
jwtUrlParameter = settings.get("jwt_url_parameter");
7373
jwtHeaderName = settings.get("jwt_header", AUTHORIZATION);
7474
isDefaultAuthHeader = AUTHORIZATION.equalsIgnoreCase(jwtHeaderName);
75-
rolesKey = settings.get("roles_key");
75+
rolesKey = settings.getAsList("roles_key");
7676
subjectKey = settings.get("subject_key");
7777
clockSkewToleranceSeconds = settings.getAsInt("jwt_clock_skew_tolerance_seconds", DEFAULT_CLOCK_SKEW_TOLERANCE_SECONDS);
7878
requiredAudience = settings.getAsList("required_audience");
@@ -219,7 +219,21 @@ public String[] extractRoles(JWTClaimsSet claims) {
219219
return new String[0];
220220
}
221221

222-
Object rolesObject = claims.getClaim(rolesKey);
222+
Object rolesObject = null;
223+
Map<String, Object> claimsMap = claims.getClaims();
224+
for (int i = 0; i < rolesKey.size(); i++) {
225+
if (i == rolesKey.size() - 1) {
226+
rolesObject = claimsMap.get(rolesKey.get(i));
227+
} else if (claimsMap.get(rolesKey.get(i)) instanceof Map) {
228+
claimsMap = (Map<String, Object>) claimsMap.get(rolesKey.get(i));
229+
} else {
230+
log.warn(
231+
"Failed to get roles from JWT claims with roles_key '{}'. Check if this key is correct and available in the JWT payload.",
232+
rolesKey
233+
);
234+
return new String[0];
235+
}
236+
}
223237

224238
if (rolesObject == null) {
225239
log.warn(

0 commit comments

Comments
 (0)