Skip to content

Commit cabae29

Browse files
authored
Merge pull request #988 from AzureAD/avdunn/improve-credentials
Improve behavior related to assertions
2 parents fb0f51e + dcf0369 commit cabae29

File tree

6 files changed

+270
-39
lines changed

6 files changed

+270
-39
lines changed

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/Authority.java

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,40 @@ static void validateAuthority(URL authorityUrl) {
131131
}
132132
}
133133

134+
/**
135+
* Creates a new Authority instance with a different tenant.
136+
* This is useful when overriding the tenant at request level.
137+
*
138+
* @param originalAuthority The original authority to base the new one on
139+
* @param newTenant The new tenant to use in the authority URL
140+
* @return A new Authority instance with the specified tenant
141+
* @throws MalformedURLException If the new authority URL is invalid
142+
* @throws NullPointerException If originalAuthority or newTenant is null
143+
*/
144+
static Authority replaceTenant(Authority originalAuthority, String newTenant) throws MalformedURLException {
145+
if (originalAuthority == null) {
146+
throw new NullPointerException("originalAuthority");
147+
}
148+
if (StringHelper.isBlank(newTenant)) {
149+
throw new NullPointerException("newTenant");
150+
}
151+
152+
URL originalUrl = originalAuthority.canonicalAuthorityUrl();
153+
String host = originalUrl.getHost();
154+
String protocol = originalUrl.getProtocol();
155+
int port = originalUrl.getPort();
156+
157+
// Build path with new tenant
158+
String newAuthority = String.format("%s://%s%s/%s/",
159+
protocol,
160+
host,
161+
(port == -1 ? "" : ":" + port),
162+
newTenant);
163+
164+
// Create proper authority instance with the tenant-specific URL
165+
return createAuthority(new URL(newAuthority));
166+
}
167+
134168
static String getTenant(URL authorityUrl, AuthorityType authorityType) {
135169
String[] segments = authorityUrl.getPath().substring(1).split("/");
136170
if (authorityType == AuthorityType.B2C) {

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ClientAssertion.java

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,66 @@
44
package com.microsoft.aad.msal4j;
55

66
import java.util.Objects;
7+
import java.util.concurrent.Callable;
78

89
final class ClientAssertion implements IClientAssertion {
910

1011
static final String ASSERTION_TYPE_JWT_BEARER = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
1112
private final String assertion;
13+
private final Callable<String> assertionProvider;
1214

15+
/**
16+
* Constructor that accepts a static assertion string
17+
*
18+
* @param assertion The JWT assertion string to use
19+
* @throws NullPointerException if assertion is null or empty
20+
*/
1321
ClientAssertion(final String assertion) {
1422
if (StringHelper.isBlank(assertion)) {
1523
throw new NullPointerException("assertion");
1624
}
1725

1826
this.assertion = assertion;
27+
this.assertionProvider = null;
1928
}
2029

30+
/**
31+
* Constructor that accepts a callable that provides the assertion string
32+
*
33+
* @param assertionProvider A callable that returns a JWT assertion string
34+
* @throws NullPointerException if assertionProvider is null
35+
*/
36+
ClientAssertion(final Callable<String> assertionProvider) {
37+
if (assertionProvider == null) {
38+
throw new NullPointerException("assertionProvider");
39+
}
40+
41+
this.assertion = null;
42+
this.assertionProvider = assertionProvider;
43+
}
44+
45+
/**
46+
* Gets the JWT assertion for client authentication.
47+
* If this ClientAssertion was created with a Callable, the callable will be
48+
* invoked each time this method is called to generate a fresh assertion.
49+
*
50+
* @return A JWT assertion string
51+
* @throws MsalClientException if the assertion provider returns null/empty or throws an exception
52+
*/
2153
public String assertion() {
54+
if (assertionProvider != null) {
55+
try {
56+
String generatedAssertion = assertionProvider.call();
57+
if (StringHelper.isBlank(generatedAssertion)) {
58+
throw new MsalClientException("Assertion provider returned null or empty assertion",
59+
AuthenticationErrorCode.INVALID_JWT);
60+
}
61+
return generatedAssertion;
62+
} catch (Exception ex) {
63+
throw new MsalClientException(ex);
64+
}
65+
}
66+
2267
return this.assertion;
2368
}
2469

@@ -30,11 +75,24 @@ public boolean equals(Object o) {
3075
if (!(o instanceof ClientAssertion)) return false;
3176

3277
ClientAssertion other = (ClientAssertion) o;
78+
79+
// For assertion providers, we consider them equal if they're the same object
80+
if (this.assertionProvider != null && other.assertionProvider != null) {
81+
return this.assertionProvider == other.assertionProvider;
82+
}
83+
84+
// For static assertions, compare the assertion strings
3385
return Objects.equals(assertion(), other.assertion());
3486
}
3587

3688
@Override
3789
public int hashCode() {
90+
// For assertion providers, use the provider's identity hash code
91+
if (assertionProvider != null) {
92+
return System.identityHashCode(assertionProvider);
93+
}
94+
95+
// For static assertions, hash the assertion string
3896
int result = 1;
3997
result = result * 59 + (this.assertion == null ? 43 : this.assertion.hashCode());
4098
return result;

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ClientCredentialFactory.java

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,15 +91,18 @@ public static IClientAssertion createFromClientAssertion(String clientAssertion)
9191

9292
/**
9393
* Static method to create a {@link ClientAssertion} instance from a provided Callable.
94+
* The callable will be invoked each time the assertion is needed, allowing for dynamic
95+
* generation of assertions.
9496
*
9597
* @param callable Callable that produces a JWT token encoded as a base64 URL encoded string
96-
* @return {@link ClientAssertion}
98+
* @return {@link ClientAssertion} that will invoke the callable each time assertion() is called
99+
* @throws NullPointerException if callable is null
97100
*/
98-
public static IClientAssertion createFromCallback(Callable<String> callable) throws ExecutionException, InterruptedException {
99-
ExecutorService executor = Executors.newSingleThreadExecutor();
100-
101-
Future<String> future = executor.submit(callable);
101+
public static IClientAssertion createFromCallback(Callable<String> callable) {
102+
if (callable == null) {
103+
throw new NullPointerException("callable");
104+
}
102105

103-
return new ClientAssertion(future.get());
106+
return new ClientAssertion(callable);
104107
}
105108
}

msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/TokenRequestExecutor.java

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -90,64 +90,63 @@ private void addQueryParameters(OAuthHttpRequest oauthHttpRequest) {
9090
if (msalRequest.application() instanceof ConfidentialClientApplication) {
9191
ConfidentialClientApplication application = (ConfidentialClientApplication) msalRequest.application();
9292

93-
// Determine which credential to use - either from the request or from the application
94-
IClientCredential credential = getCredentialToUse(application);
95-
96-
// Add appropriate authentication parameters based on the credential type
97-
addCredentialToRequest(queryParameters, credential, application);
93+
// Consolidated credential and tenant override handling
94+
addCredentialToRequest(queryParameters, application);
9895
}
9996

10097
oauthHttpRequest.setQuery(StringHelper.serializeQueryParameters(queryParameters));
10198
}
10299

103-
/**
104-
* Determines which credential to use for authentication:
105-
* - If the request is a ClientCredentialRequest with a specified credential, use that
106-
* - Otherwise use the application's credential
107-
*
108-
* @param application The confidential client application
109-
* @return The credential to use, may be null if no credential is available
110-
*/
111-
private IClientCredential getCredentialToUse(ConfidentialClientApplication application) {
112-
if (msalRequest instanceof ClientCredentialRequest &&
113-
((ClientCredentialRequest) msalRequest).parameters.clientCredential() != null) {
114-
return ((ClientCredentialRequest) msalRequest).parameters.clientCredential();
115-
}
116-
return application.clientCredential;
117-
}
118-
119100
/**
120101
* Adds the appropriate authentication parameters to the request based on credential type.
121102
* Handles different credential types (secret, assertion, certificate) by adding the appropriate
122103
* parameters to the request.
123104
*
124105
* @param queryParameters The map of query parameters to add to
125-
* @param credential The credential to use for authentication, may be null
126106
* @param application The confidential client application
127107
*/
128108
private void addCredentialToRequest(Map<String, String> queryParameters,
129-
IClientCredential credential,
130109
ConfidentialClientApplication application) {
131-
if (credential == null) {
110+
IClientCredential credentialToUse = application.clientCredential;
111+
Authority authorityToUse = application.authenticationAuthority;
112+
113+
// A ClientCredentialRequest may have parameters which override the credentials used to build the application.
114+
if (msalRequest instanceof ClientCredentialRequest) {
115+
ClientCredentialParameters parameters = ((ClientCredentialRequest) msalRequest).parameters;
116+
117+
if (parameters.clientCredential() != null) {
118+
credentialToUse = parameters.clientCredential();
119+
}
120+
121+
if (parameters.tenant() != null) {
122+
try {
123+
authorityToUse = Authority.replaceTenant(authorityToUse, parameters.tenant());
124+
} catch (MalformedURLException e) {
125+
log.warn("Could not create authority with tenant override: {}", e.getMessage());
126+
}
127+
}
128+
}
129+
130+
// Quick return if no credential is provided
131+
if (credentialToUse == null) {
132132
return;
133133
}
134134

135-
if (credential instanceof ClientSecret) {
135+
if (credentialToUse instanceof ClientSecret) {
136136
// For client secret, add client_secret parameter
137-
queryParameters.put("client_secret", ((ClientSecret) credential).clientSecret());
138-
} else if (credential instanceof ClientAssertion) {
137+
queryParameters.put("client_secret", ((ClientSecret) credentialToUse).clientSecret());
138+
} else if (credentialToUse instanceof ClientAssertion) {
139139
// For client assertion, add client_assertion and client_assertion_type parameters
140-
addJWTBearerAssertionParams(queryParameters, ((ClientAssertion) credential).assertion());
141-
} else if (credential instanceof ClientCertificate) {
140+
addJWTBearerAssertionParams(queryParameters, ((ClientAssertion) credentialToUse).assertion());
141+
} else if (credentialToUse instanceof ClientCertificate) {
142142
// For client certificate, generate a new assertion and add it to the request
143-
ClientCertificate certificate = (ClientCertificate) credential;
143+
ClientCertificate certificate = (ClientCertificate) credentialToUse;
144144
String assertion = certificate.getAssertion(
145-
application.authenticationAuthority,
145+
authorityToUse,
146146
application.clientId(),
147147
application.sendX5c());
148148
addJWTBearerAssertionParams(queryParameters, assertion);
149149
}
150-
// If credential is of an unknown type, no additional parameters are added
151150
}
152151

153152
/**

msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ClientCertificateTest.java

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import static org.junit.jupiter.api.Assertions.assertNotNull;
1313
import static org.junit.jupiter.api.Assertions.assertNull;
1414
import static org.junit.jupiter.api.Assertions.assertThrows;
15+
import static org.junit.jupiter.api.Assertions.assertTrue;
1516
import static org.mockito.ArgumentMatchers.any;
1617
import static org.mockito.Mockito.*;
1718

@@ -156,6 +157,108 @@ void testClientCertificate_GeneratesNewAssertionEachTime() throws Exception {
156157
"The access tokens from each request should be different");
157158
}
158159

160+
@Test
161+
void testClientCertificate_TenantOverride() throws Exception {
162+
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);
163+
Map<String, String> capturedTenants = new HashMap<>();
164+
165+
ConfidentialClientApplication cca =
166+
ConfidentialClientApplication.builder("clientId", ClientCredentialFactory.createFromCertificate(TestHelper.getPrivateKey(), TestHelper.getX509Cert()))
167+
.authority("https://login.microsoftonline.com/default-tenant")
168+
.instanceDiscovery(false)
169+
.validateAuthority(false)
170+
.httpClient(httpClientMock)
171+
.build();
172+
173+
// Mock the HTTP client to capture and analyze assertions from each request
174+
when(httpClientMock.send(any(HttpRequest.class))).thenAnswer(parameters -> {
175+
HttpRequest request = parameters.getArgument(0);
176+
String requestBody = request.body();
177+
String url = request.url().toString();
178+
179+
// Capture which tenant was used in the authority
180+
String tenant = extractTenantFromUrl(url);
181+
182+
// Extract the assertion to verify its audience claim
183+
String clientAssertion = extractClientAssertion(requestBody);
184+
if (clientAssertion != null) {
185+
SignedJWT signedJWT = SignedJWT.parse(clientAssertion);
186+
187+
// Get the audience claim to verify it matches the tenant
188+
String audience = signedJWT.getJWTClaimsSet().getAudience().get(0);
189+
190+
// Store the tenant and audience for verification
191+
capturedTenants.put(tenant, audience);
192+
193+
// Verify it's a valid JWT with proper headers
194+
if (signedJWT.getHeader().toJSONObject().containsKey("x5t#S256")) {
195+
HashMap<String, String> tokenResponseValues = new HashMap<>();
196+
tokenResponseValues.put("access_token", "access_token_for_" + tenant);
197+
return TestHelper.expectedResponse(200, TestHelper.getSuccessfulTokenResponse(tokenResponseValues));
198+
}
199+
}
200+
return null;
201+
});
202+
203+
// First request with default tenant
204+
ClientCredentialParameters defaultParameters = ClientCredentialParameters.builder(Collections.singleton("scopes"))
205+
.skipCache(true)
206+
.build();
207+
IAuthenticationResult resultDefault = cca.acquireToken(defaultParameters).get();
208+
209+
// Second request with override tenant
210+
String overrideTenant = "override-tenant";
211+
ClientCredentialParameters overrideParameters = ClientCredentialParameters.builder(Collections.singleton("scopes"))
212+
.skipCache(true)
213+
.tenant(overrideTenant)
214+
.build();
215+
IAuthenticationResult resultOverride = cca.acquireToken(overrideParameters).get();
216+
217+
// Verify both requests were processed
218+
assertEquals(2, capturedTenants.size(), "Two requests with different tenants should have been processed");
219+
220+
// Verify both tenants were used
221+
assertTrue(capturedTenants.containsKey("default-tenant"), "Default tenant should have been used");
222+
assertTrue(capturedTenants.containsKey(overrideTenant), "Override tenant should have been used");
223+
224+
// Verify the audience in the JWT assertions reflects the different tenants
225+
assertNotEquals(
226+
capturedTenants.get("default-tenant"),
227+
capturedTenants.get(overrideTenant),
228+
"JWT audience should differ between default and override tenant"
229+
);
230+
231+
// Verify the audience claims match the expected format with the correct tenant
232+
assertTrue(
233+
capturedTenants.get("default-tenant").contains("default-tenant"),
234+
"Audience for default tenant should contain the default tenant name"
235+
);
236+
assertTrue(
237+
capturedTenants.get(overrideTenant).contains(overrideTenant),
238+
"Audience for override tenant should contain the override tenant name"
239+
);
240+
241+
// Verify different access tokens were returned
242+
assertNotEquals(resultDefault.accessToken(), resultOverride.accessToken(),
243+
"Access tokens should differ when using different tenants");
244+
}
245+
246+
/**
247+
* Extracts the tenant name from an authority URL
248+
* @param url The full URL containing the tenant
249+
* @return The tenant name
250+
*/
251+
private String extractTenantFromUrl(String url) {
252+
// Authority URL format is typically https://login.microsoftonline.com/tenant/...
253+
String[] parts = url.split("/");
254+
for (int i = 0; i < parts.length; i++) {
255+
if (parts[i].equalsIgnoreCase("login.microsoftonline.com") && i + 1 < parts.length) {
256+
return parts[i + 1];
257+
}
258+
}
259+
return null;
260+
}
261+
159262
/**
160263
* Extracts the client_assertion value from a URL-encoded request body
161264
* @param requestBody The request body string

0 commit comments

Comments
 (0)