Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ abstract class AbstractManagedIdentitySource {

protected final ManagedIdentityRequest managedIdentityRequest;
protected final ServiceBundle serviceBundle;
private ManagedIdentitySourceType managedIdentitySourceType;
ManagedIdentitySourceType managedIdentitySourceType;

@Getter
@Setter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ AuthenticationResult execute() throws Exception {
scopes.add(this.managedIdentityParameters.resource);
SilentParameters parameters = SilentParameters
.builder(scopes)
.tenant(managedIdentityParameters.tenant())
.build();

RequestContext context = new RequestContext(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ class AppServiceManagedIdentitySource extends AbstractManagedIdentitySource{
private static final String APP_SERVICE_MSI_API_VERSION = "2019-08-01";
private static final String SECRET_HEADER_NAME = "X-IDENTITY-HEADER";

private final URI MSI_ENDPOINT;
private final String SECRET;
private final URI msiEndpoint;
private final String identityHeader;

@Override
public void createManagedIdentityRequest(String resource) {
managedIdentityRequest.baseEndpoint = MSI_ENDPOINT;
managedIdentityRequest.baseEndpoint = msiEndpoint;
managedIdentityRequest.method = HttpMethod.GET;

managedIdentityRequest.headers = new HashMap<>();
managedIdentityRequest.headers.put(SECRET_HEADER_NAME, SECRET);
managedIdentityRequest.headers.put(SECRET_HEADER_NAME, identityHeader);

managedIdentityRequest.queryParameters = new HashMap<>();
managedIdentityRequest.queryParameters.put("api-version", Collections.singletonList(APP_SERVICE_MSI_API_VERSION));
Expand All @@ -50,8 +50,8 @@ public void createManagedIdentityRequest(String resource) {
private AppServiceManagedIdentitySource(MsalRequest msalRequest, ServiceBundle serviceBundle, URI msiEndpoint, String secret)
{
super(msalRequest, serviceBundle, ManagedIdentitySourceType.APP_SERVICE);
this.MSI_ENDPOINT = msiEndpoint;
this.SECRET = secret;
this.msiEndpoint = msiEndpoint;
this.identityHeader = secret;
}

static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.microsoft.aad.msal4j;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.FileReader;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Collections;
import java.util.HashMap;

class AzureArcManagedIdentitySource extends AbstractManagedIdentitySource{

private final static Logger LOG = LoggerFactory.getLogger(AzureArcManagedIdentitySource.class);
private static final String ARC_API_VERSION = "2019-11-01";
private static final String AZURE_ARC = "Azure Arc";

private final URI MSI_ENDPOINT;

static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle)
{
IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
String identityEndpoint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT);
String imdsEndpoint = environmentVariables.getEnvironmentVariable(Constants.IMDS_ENDPOINT);

URI validatedUri = validateAndGetUri(identityEndpoint, imdsEndpoint);
return validatedUri == null ? null : new AzureArcManagedIdentitySource(validatedUri, msalRequest, serviceBundle );
}

private static URI validateAndGetUri(String identityEndpoint, String imdsEndpoint) {

// if BOTH the env vars IDENTITY_ENDPOINT and IMDS_ENDPOINT are set the MsiType is Azure Arc
if (StringHelper.isNullOrBlank(identityEndpoint) || StringHelper.isNullOrBlank(imdsEndpoint))
{
LOG.info("[Managed Identity] Azure Arc managed identity is unavailable.");
return null;
}

URI endpointUri;
try {
endpointUri = new URI(identityEndpoint);
} catch (URISyntaxException e) {
throw new MsalManagedIdentityException(MsalError.INVALID_MANAGED_IDENTITY_ENDPOINT, String.format(
MsalErrorMessage.MANAGED_IDENTITY_ENDPOINT_INVALID_URI_ERROR, "IDENTITY_ENDPOINT", identityEndpoint, AZURE_ARC),
ManagedIdentitySourceType.AZURE_ARC);
}

LOG.info("[Managed Identity] Creating Azure Arc managed identity. Endpoint URI: " + endpointUri);
return endpointUri;
}

private AzureArcManagedIdentitySource(URI endpoint, MsalRequest msalRequest, ServiceBundle serviceBundle){
super(msalRequest, serviceBundle, ManagedIdentitySourceType.AZURE_ARC);
this.MSI_ENDPOINT = endpoint;

ManagedIdentityIdType idType =
((ManagedIdentityApplication) msalRequest.application()).getManagedIdentityId().getIdType();
if (idType != ManagedIdentityIdType.SYSTEM_ASSIGNED) {
throw new MsalManagedIdentityException(MsalError.USER_ASSIGNED_MANAGED_IDENTITY_NOT_SUPPORTED,
String.format(MsalErrorMessage.MANAGED_IDENTITY_USER_ASSIGNED_NOT_SUPPORTED, AZURE_ARC),
ManagedIdentitySourceType.AZURE_ARC);
}
}

@Override
public void createManagedIdentityRequest(String resource)
{
managedIdentityRequest.baseEndpoint = MSI_ENDPOINT;
managedIdentityRequest.method = HttpMethod.GET;

managedIdentityRequest.headers = new HashMap<>();
managedIdentityRequest.headers.put("Metadata", "true");

managedIdentityRequest.queryParameters = new HashMap<>();
managedIdentityRequest.queryParameters.put("api-version", Collections.singletonList(ARC_API_VERSION));
managedIdentityRequest.queryParameters.put("resource", Collections.singletonList(resource));
}

@Override
public ManagedIdentityResponse handleResponse(
ManagedIdentityParameters parameters,
IHttpResponse response) {

LOG.info("[Managed Identity] Response received. Status code: {response.StatusCode}");

if (response.statusCode() == HttpURLConnection.HTTP_UNAUTHORIZED) {
if(!response.headers().containsKey("Www-Authenticate")) {
LOG.error("[Managed Identity] WWW-Authenticate header is expected but not found.");
throw new MsalManagedIdentityException(MsalError.MANAGED_IDENTITY_REQUEST_FAILED,
MsalErrorMessage.MANAGED_IDENTITY_NO_CHALLENGE_ERROR,
ManagedIdentitySourceType.AZURE_ARC);
}

String challenge = response.headers().get("Www-Authenticate").get(0);
String[] splitChallenge = challenge.split("=");

if (splitChallenge.length != 2) {
LOG.error("[Managed Identity] The WWW-Authenticate header for Azure arc managed identity is not an expected format.");
throw new MsalManagedIdentityException(MsalError.MANAGED_IDENTITY_REQUEST_FAILED,
MsalErrorMessage.MANAGED_IDENTITY_INVALID_CHALLENGE,
ManagedIdentitySourceType.AZURE_ARC);
}

Path path = Paths.get(splitChallenge[1]);

String authHeaderValue = null;
try {
authHeaderValue = "Basic " + new String(Files.readAllBytes(path), StandardCharsets.UTF_8);
} catch (IOException e) {
throw new MsalManagedIdentityException(MsalError.MANAGED_IDENTITY_FILE_READ_ERROR, e.getMessage(), ManagedIdentitySourceType.AZURE_ARC);
}

createManagedIdentityRequest(parameters.resource);

LOG.info("[Managed Identity] Adding authorization header to the request.");

managedIdentityRequest.headers.put("Authorization", authHeaderValue);

try {
response = HttpHelper.executeHttpRequest(
new HttpRequest(HttpMethod.GET, managedIdentityRequest.computeURI().toString(),
managedIdentityRequest.headers),
managedIdentityRequest.requestContext(),
serviceBundle);
} catch (URISyntaxException e) {
throw new MsalManagedIdentityException(MsalError.INVALID_MANAGED_IDENTITY_ENDPOINT,
MsalErrorMessage.MANAGED_IDENTITY_ENDPOINT_INVALID_URI_ERROR,
managedIdentitySourceType);
}

return super.handleResponse(parameters, response);
}

return super.handleResponse(parameters, response);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ class CloudShellManagedIdentitySource extends AbstractManagedIdentitySource{

private static final Logger LOG = LoggerFactory.getLogger(CloudShellManagedIdentitySource.class);

private final URI MSI_ENDPOINT;
private final URI msiEndpoint;

@Override
public void createManagedIdentityRequest(String resource) {
managedIdentityRequest.baseEndpoint = MSI_ENDPOINT;
managedIdentityRequest.baseEndpoint = msiEndpoint;
managedIdentityRequest.method = HttpMethod.POST;

managedIdentityRequest.headers = new HashMap<>();
Expand All @@ -33,7 +33,7 @@ public void createManagedIdentityRequest(String resource) {
private CloudShellManagedIdentitySource(MsalRequest msalRequest, ServiceBundle serviceBundle, URI msiEndpoint)
{
super(msalRequest, serviceBundle, ManagedIdentitySourceType.CLOUD_SHELL);
this.MSI_ENDPOINT = msiEndpoint;
this.msiEndpoint = msiEndpoint;

ManagedIdentityIdType idType =
((ManagedIdentityApplication) msalRequest.application()).getManagedIdentityId().getIdType();
Expand All @@ -57,28 +57,23 @@ static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBund
return null;
}

URI validatedUri = validateAndGetUri(msiEndpoint);
return validatedUri == null ? null
: new CloudShellManagedIdentitySource(msalRequest, serviceBundle, validatedUri);
return new CloudShellManagedIdentitySource(msalRequest, serviceBundle, validateAndGetUri(msiEndpoint));
}

private static URI validateAndGetUri(String msiEndpoint)
{
URI endpointUri = null;

try
{
endpointUri = new URI(msiEndpoint);
URI endpointUri = new URI(msiEndpoint);
LOG.info("[Managed Identity] Environment variables validation passed for cloud shell managed identity. Endpoint URI: " + endpointUri + ". Creating cloud shell managed identity.");
return endpointUri;
}
catch (URISyntaxException ex)
{
throw new MsalManagedIdentityException(MsalError.INVALID_MANAGED_IDENTITY_ENDPOINT, String.format(
MsalErrorMessage.MANAGED_IDENTITY_ENDPOINT_INVALID_URI_ERROR, "MSI_ENDPOINT", msiEndpoint, "Cloud Shell"),
ManagedIdentitySourceType.CLOUD_SHELL);
}

LOG.info("[Managed Identity] Environment variables validation passed for cloud shell managed identity. Endpoint URI: " + endpointUri + ". Creating cloud shell managed identity.");
return endpointUri;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ public ManagedIdentityResponse handleResponse(

message = message + " " + errorContentMessage;

LOG.error(String.format("Error message: %s Http status code: %s"), message, response.statusCode());
LOG.error(String.format("Error message: %s Http status code: %s", message, response.statusCode()));
throw new MsalManagedIdentityException(MsalError.MANAGED_IDENTITY_REQUEST_FAILED, message,
ManagedIdentitySourceType.IMDS);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,14 @@ ManagedIdentityResponse getManagedIdentityResponse(ManagedIdentityParameters par
private static AbstractManagedIdentitySource createManagedIdentitySource(MsalRequest msalRequest,
ServiceBundle serviceBundle) {
AbstractManagedIdentitySource managedIdentitySource;
if ((managedIdentitySource = AppServiceManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
if ((managedIdentitySource = ServiceFabricManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
} else if ((managedIdentitySource = AppServiceManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
} else if ((managedIdentitySource = CloudShellManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
} else if ((managedIdentitySource = AzureArcManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
} else {
return new IMDSManagedIdentitySource(msalRequest, serviceBundle);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public class ManagedIdentityParameters implements IAcquireTokenParameters {

@Getter
String resource;

boolean forceRefresh;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,6 @@ public class MsalError {
* Managed Identity endpoint is not reachable.
*/
public static final String MANAGED_IDENTITY_UNREACHABLE_NETWORK = "managed_identity_unreachable_network";

public static final String MANAGED_IDENTITY_FILE_READ_ERROR = "managed_identity_file_read_error";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.microsoft.aad.msal4j;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Collections;
import java.util.HashMap;

class ServiceFabricManagedIdentitySource extends AbstractManagedIdentitySource {

private static final Logger LOG = LoggerFactory.getLogger(ServiceFabricManagedIdentitySource.class);

private static final String SERVICE_FABRIC_MSI_API_VERSION = "2019-07-01-preview";

private final URI msiEndpoint;
private final String identityHeader;
private final ManagedIdentityIdType idType;
private final String userAssignedId;

@Override
public void createManagedIdentityRequest(String resource) {
managedIdentityRequest.baseEndpoint = msiEndpoint;
managedIdentityRequest.method = HttpMethod.GET;

managedIdentityRequest.headers = new HashMap<>();
managedIdentityRequest.headers.put("secret", identityHeader);

managedIdentityRequest.queryParameters = new HashMap<>();
managedIdentityRequest.queryParameters.put("resource", Collections.singletonList(resource));
managedIdentityRequest.queryParameters.put("api-version", Collections.singletonList(SERVICE_FABRIC_MSI_API_VERSION));

if (idType == ManagedIdentityIdType.CLIENT_ID) {
LOG.info("[Managed Identity] Adding user assigned client id to the request for Service Fabric Managed Identity.");
managedIdentityRequest.queryParameters.put(Constants.MANAGED_IDENTITY_CLIENT_ID, Collections.singletonList(userAssignedId));
} else if (idType == ManagedIdentityIdType.RESOURCE_ID) {
LOG.info("[Managed Identity] Adding user assigned resource id to the request for Service Fabric Managed Identity.");
managedIdentityRequest.queryParameters.put(Constants.MANAGED_IDENTITY_RESOURCE_ID, Collections.singletonList(userAssignedId));
}
}

private ServiceFabricManagedIdentitySource(MsalRequest msalRequest, ServiceBundle serviceBundle, URI msiEndpoint, String identityHeader)
{
super(msalRequest, serviceBundle, ManagedIdentitySourceType.SERVICE_FABRIC);
this.msiEndpoint = msiEndpoint;
this.identityHeader = identityHeader;

this.idType = ((ManagedIdentityApplication) msalRequest.application()).getManagedIdentityId().getIdType();
this.userAssignedId = ((ManagedIdentityApplication) msalRequest.application()).getManagedIdentityId().getUserAssignedId();
}

static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) {

IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
String msiEndpoint = environmentVariables.getEnvironmentVariable(Constants.MSI_ENDPOINT);
String identityHeader = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT);
String identityServerThumbprint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_SERVER_THUMBPRINT);


if (StringHelper.isNullOrBlank(msiEndpoint) || StringHelper.isNullOrBlank(identityHeader) || StringHelper.isNullOrBlank(identityServerThumbprint))
{
LOG.info("[Managed Identity] Service fabric managed identity is unavailable.");
return null;
}

return new ServiceFabricManagedIdentitySource(msalRequest, serviceBundle, validateAndGetUri(msiEndpoint), identityHeader);
}

private static URI validateAndGetUri(String msiEndpoint)
{
try
{
URI endpointUri = new URI(msiEndpoint);
LOG.info("[Managed Identity] Environment variables validation passed for Service Fabric Managed Identity. Endpoint URI: " + endpointUri);
return endpointUri;
}
catch (URISyntaxException ex)
{
throw new MsalManagedIdentityException(MsalError.INVALID_MANAGED_IDENTITY_ENDPOINT, String.format(
MsalErrorMessage.MANAGED_IDENTITY_ENDPOINT_INVALID_URI_ERROR, "MSI_ENDPOINT", msiEndpoint, "Service Fabric"),
ManagedIdentitySourceType.SERVICE_FABRIC);
}
}

}
Loading