Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ abstract class AbstractManagedIdentitySource {

protected final ManagedIdentityRequest managedIdentityRequest;
protected final ServiceBundle serviceBundle;
private ManagedIdentitySourceType managedIdentitySourceType;
ManagedIdentitySourceType managedIdentitySourceType;

@Getter
@Setter
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.microsoft.aad.msal4j;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Collections;
import java.util.HashMap;

class AzureArcManagedIdentitySource extends AbstractManagedIdentitySource{

private final static Logger LOG = LoggerFactory.getLogger(AzureArcManagedIdentitySource.class);
private static final String ARC_API_VERSION = "2019-11-01";
private static final String AZURE_ARC = "Azure Arc";

private final URI MSI_ENDPOINT;

static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle)
{
IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
String identityEndpoint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT);
String imdsEndpoint = environmentVariables.getEnvironmentVariable(Constants.IMDS_ENDPOINT);

URI validatedUri = validateAndGetUri(identityEndpoint, imdsEndpoint);
return validatedUri == null ? null : new AzureArcManagedIdentitySource(validatedUri, msalRequest, serviceBundle );
}

private static URI validateAndGetUri(String identityEndpoint, String imdsEndpoint) {

// if BOTH the env vars IDENTITY_ENDPOINT and IMDS_ENDPOINT are set the MsiType is Azure Arc
if (StringHelper.isNullOrBlank(identityEndpoint) || StringHelper.isNullOrBlank(imdsEndpoint))
{
LOG.info("[Managed Identity] Azure Arc managed identity is unavailable.");
return null;
}

URI endpointUri;
try {
endpointUri = new URI(identityEndpoint);
} catch (URISyntaxException e) {
throw new MsalManagedIdentityException(MsalError.INVALID_MANAGED_IDENTITY_ENDPOINT, String.format(
MsalErrorMessage.MANAGED_IDENTITY_ENDPOINT_INVALID_URI_ERROR, "IDENTITY_ENDPOINT", identityEndpoint, AZURE_ARC),
ManagedIdentitySourceType.AzureArc);
}

LOG.info("[Managed Identity] Creating Azure Arc managed identity. Endpoint URI: " + endpointUri);
return endpointUri;
}

private AzureArcManagedIdentitySource(URI endpoint, MsalRequest msalRequest, ServiceBundle serviceBundle){
super(msalRequest, serviceBundle, ManagedIdentitySourceType.AzureArc);
this.MSI_ENDPOINT = endpoint;

ManagedIdentityIdType idType =
((ManagedIdentityApplication) msalRequest.application()).getManagedIdentityId().getIdType();
if (idType != ManagedIdentityIdType.SystemAssigned) {
throw new MsalManagedIdentityException(MsalError.USER_ASSIGNED_MANAGED_IDENTITY_NOT_SUPPORTED,
String.format(MsalErrorMessage.MANAGED_IDENTITY_USER_ASSIGNED_NOT_SUPPORTED, AZURE_ARC),
ManagedIdentitySourceType.CloudShell);
}
}

@Override
public void createManagedIdentityRequest(String resource)
{
managedIdentityRequest.baseEndpoint = MSI_ENDPOINT;
managedIdentityRequest.method = HttpMethod.GET;

managedIdentityRequest.headers = new HashMap<>();
managedIdentityRequest.headers.put("Metadata", "true");

managedIdentityRequest.queryParameters = new HashMap<>();
managedIdentityRequest.queryParameters.put("api-version", Collections.singletonList(ARC_API_VERSION));
managedIdentityRequest.queryParameters.put("resource", Collections.singletonList(resource));
}

@Override
public ManagedIdentityResponse handleResponse(
ManagedIdentityParameters parameters,
IHttpResponse response)
{
LOG.info("[Managed Identity] Response received. Status code: {response.StatusCode}");

if (response.statusCode() == HttpURLConnection.HTTP_UNAUTHORIZED)
{
if(!response.headers().containsKey("WWW-Authenticate")){
LOG.error("[Managed Identity] WWW-Authenticate header is expected but not found.");
throw new MsalManagedIdentityException(MsalError.MANAGED_IDENTITY_REQUEST_FAILED,
MsalErrorMessage.MANAGED_IDENTITY_NO_CHALLENGE_ERROR,
ManagedIdentitySourceType.AzureArc);
}

String challenge = response.headers().get("WWW-Authenticate").get(0);
String[] splitChallenge = challenge.split("=");

if (splitChallenge.length != 2)
{
LOG.error("[Managed Identity] The WWW-Authenticate header for Azure arc managed identity is not an expected format.");
throw new MsalManagedIdentityException(MsalError.MANAGED_IDENTITY_REQUEST_FAILED,
MsalErrorMessage.MANAGED_IDENTITY_INVALID_CHALLENGE,
ManagedIdentitySourceType.AzureArc);
}

String authHeaderValue = "Basic " + splitChallenge[1];

createManagedIdentityRequest(parameters.resource);

LOG.info("[Managed Identity] Adding authorization header to the request.");

managedIdentityRequest.headers.put("Authorization", authHeaderValue);

try {
response = HttpHelper.executeHttpRequest(
new HttpRequest(HttpMethod.GET, managedIdentityRequest.computeURI().toString(),
managedIdentityRequest.headers),
managedIdentityRequest.requestContext(),
serviceBundle);
} catch (URISyntaxException e) {
throw new MsalManagedIdentityException(MsalError.INVALID_MANAGED_IDENTITY_ENDPOINT,
MsalErrorMessage.MANAGED_IDENTITY_ENDPOINT_INVALID_URI_ERROR,
managedIdentitySourceType);
}

return super.handleResponse(parameters, response);
}

return super.handleResponse(parameters, response);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ private static AbstractManagedIdentitySource createManagedIdentitySource(MsalReq
return managedIdentitySource;
} else if ((managedIdentitySource = CloudShellManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
} else if ((managedIdentitySource = AzureArcManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
} else {
return new IMDSManagedIdentitySource(msalRequest, serviceBundle);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public class ManagedIdentityParameters implements IAcquireTokenParameters {

@Getter
String resource;

boolean forceRefresh;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ public static Stream<Arguments> createData() {
ManagedIdentityTests.resource),
Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint,
ManagedIdentityTests.resourceDefaultSuffix),
Arguments.of(ManagedIdentitySourceType.AzureArc, ManagedIdentityTests.azureArcEndpoint,
ManagedIdentityTests.resource),
Arguments.of(ManagedIdentitySourceType.AzureArc, ManagedIdentityTests.azureArcEndpoint,
ManagedIdentityTests.resourceDefaultSuffix),
Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT,
ManagedIdentityTests.resource),
Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT,
Expand All @@ -45,7 +49,11 @@ public static Stream<Arguments> createDataUserAssignedNotSupported() {
return Stream.of(
Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint,
ManagedIdentityId.userAssignedClientId(CLIENT_ID)),
Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint,
Arguments.of(ManagedIdentitySourceType.CloudShell, ManagedIdentityTests.cloudShellEndpoint,
ManagedIdentityId.userAssignedResourceId(RESOURCE_ID)),
Arguments.of(ManagedIdentitySourceType.AzureArc, ManagedIdentityTests.azureArcEndpoint,
ManagedIdentityId.userAssignedClientId(CLIENT_ID)),
Arguments.of(ManagedIdentitySourceType.AzureArc, ManagedIdentityTests.azureArcEndpoint,
ManagedIdentityId.userAssignedResourceId(RESOURCE_ID)));
}

Expand All @@ -59,6 +67,10 @@ public static Stream<Arguments> createDataWrongScope() {
"user.read"),
Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint,
"https://management.core.windows.net//user_impersonation"),
Arguments.of(ManagedIdentitySourceType.AzureArc, ManagedIdentityTests.azureArcEndpoint,
"user.read"),
Arguments.of(ManagedIdentitySourceType.AzureArc, ManagedIdentityTests.azureArcEndpoint,
"https://management.core.windows.net//user_impersonation"),
Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT,
"user.read"),
Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT,
Expand All @@ -67,6 +79,7 @@ public static Stream<Arguments> createDataWrongScope() {

public static Stream<Arguments> createDataError() {
return Stream.of(
Arguments.of(ManagedIdentitySourceType.AzureArc, ManagedIdentityTests.azureArcEndpoint),
Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint),
Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint),
Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
package com.microsoft.aad.msal4j;

import com.nimbusds.oauth2.sdk.util.URLUtils;
import org.apache.http.HttpStatus;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
Expand Down Expand Up @@ -80,6 +82,15 @@ private HttpRequest expectedRequest(ManagedIdentitySourceType source, String res
headers.put("Metadata", "true");
break;
}
case AzureArc: {
endpoint = azureArcEndpoint;

queryParameters.put("api-version", Collections.singletonList("2019-11-01"));
queryParameters.put("resource", Collections.singletonList(resource));

headers.put("Metadata", "true");
break;
}
}

switch (id.getIdType()) {
Expand Down Expand Up @@ -182,6 +193,7 @@ void managedIdentityTest_UserAssigned_NotSupported(ManagedIdentitySourceType sou
.build()).get();
} catch (Exception e) {
assertNotNull(e);
assertNotNull(e.getCause());
assertInstanceOf(MsalManagedIdentityException.class, e.getCause());

MsalManagedIdentityException msalMsiException = (MsalManagedIdentityException) e.getCause();
Expand Down Expand Up @@ -349,6 +361,39 @@ void managedIdentity_RequestFailed_UnreachableNetwork(ManagedIdentitySourceType
fail("MsalManagedIdentityException is expected but not thrown.");
}

@Test
void azureArcManagedIdentity_MissingAuthHeader() throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.AzureArc, azureArcEndpoint);
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

HttpResponse response = new HttpResponse();
response.statusCode(HttpStatus.SC_UNAUTHORIZED);

lenient().when(httpClientMock.send(any())).thenReturn(response);

ManagedIdentityApplication miApp = ManagedIdentityApplication
.builder(ManagedIdentityId.systemAssigned())
.httpClient(httpClientMock)
.build();

try {
miApp.acquireTokenForManagedIdentity(
ManagedIdentityParameters.builder(resource)
.environmentVariables(environmentVariables)
.build()).get();
} catch (Exception exception) {
assert(exception.getCause() instanceof MsalManagedIdentityException);

MsalManagedIdentityException miException = (MsalManagedIdentityException) exception.getCause();
assertEquals(ManagedIdentitySourceType.AzureArc, miException.managedIdentitySourceType);
assertEquals(MsalError.MANAGED_IDENTITY_REQUEST_FAILED, miException.errorCode());
assertEquals(MsalErrorMessage.MANAGED_IDENTITY_NO_CHALLENGE_ERROR, miException.getMessage());
return;
}

fail("MsalManagedIdentityException is expected but not thrown.");
}

@ParameterizedTest
@MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError")
void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoint) throws Exception {
Expand All @@ -361,13 +406,13 @@ void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoi
.builder(ManagedIdentityId.systemAssigned())
.httpClient(httpClientMock)
.build();

ManagedIdentityApplication miApp2 = ManagedIdentityApplication
.builder(ManagedIdentityId.systemAssigned())
.httpClient(httpClientMock)
.build();

IAuthenticationResult resultMiApp1 = miApp1.acquireTokenForManagedIdentity(
IAuthenticationResult resultMiApp1 = miApp1.acquireTokenForManagedIdentity(
ManagedIdentityParameters.builder(resource)
.environmentVariables(environmentVariables)
.build()).get();
Expand All @@ -386,4 +431,38 @@ void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoi
// should return the same token
assertEquals(resultMiApp1.accessToken(), resultMiApp2.accessToken());
}

@Test
void azureArcManagedIdentity_InvalidAuthHeader() throws Exception {
IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.AzureArc, azureArcEndpoint);
DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class);

HttpResponse response = new HttpResponse();
response.statusCode(HttpStatus.SC_UNAUTHORIZED);
response.headers().put("WWW-Authenticate", Collections.singletonList("Basic realm=filepath=somepath"));

lenient().when(httpClientMock.send(any())).thenReturn(response);

ManagedIdentityApplication miApp = ManagedIdentityApplication
.builder(ManagedIdentityId.systemAssigned())
.httpClient(httpClientMock)
.build();

try {
miApp.acquireTokenForManagedIdentity(
ManagedIdentityParameters.builder(resource)
.environmentVariables(environmentVariables)
.build()).get();
} catch (Exception exception) {
assert(exception.getCause() instanceof MsalManagedIdentityException);

MsalManagedIdentityException miException = (MsalManagedIdentityException) exception.getCause();
assertEquals(ManagedIdentitySourceType.AzureArc, miException.managedIdentitySourceType);
assertEquals(MsalError.MANAGED_IDENTITY_REQUEST_FAILED, miException.errorCode());
assertEquals(MsalErrorMessage.MANAGED_IDENTITY_INVALID_CHALLENGE, miException.getMessage());
return;
}

fail("MsalManagedIdentityException is expected but not thrown.");
}
}